Commit 876384f3 authored by alessio.quaresima's avatar alessio.quaresima

New SpikeTimit

parent 00d0bf6f
......@@ -7,6 +7,7 @@ module SpikeTimit
using MAT
using Plots
using ColorSchemes
using Random
## This is the sample rate used in the scripts.
sr = 16000
## This is the rescaling factor for the spike times, discussed in the PDF
......@@ -14,19 +15,10 @@ module SpikeTimit
import Plots: Series, Plot, Subplot
## Create a dictionary with all words in the dataset
function create_dictionary(;file)
dict = OrderedDict()
for line in readlines(file)
if !startswith(line, ";")
word, sounds = split(line,"/")
# @show word, sounds
push!(dict, word=>sounds)
end
end
return dict
end
"""
Get words and time intervals from the SenID
"""
function get_words(root, senID)
path = joinpath(root, senID*".wrd")
times0 = []
......@@ -45,6 +37,9 @@ module SpikeTimit
return _data
end
"""
Get phonemes and time intervals from the SenID
"""
function get_phones(root, senID)
path = joinpath(root, senID*".phn")
times0 = []
......@@ -62,17 +57,43 @@ module SpikeTimit
_data[:,3] = times1 ./ sr
return _data
end
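## Usage sketch (illustrative, not part of this commit): both helpers return an
## N×3 matrix with columns [label, start, stop], the times already rescaled to
## seconds via `sr`. For a hypothetical sentence file pair "si1234.wrd"/"si1234.phn":
# words  = get_words(root, "si1234")    # e.g. ["she" 0.21 0.43; ...]
# phones = get_phones(root, "si1234")   # e.g. ["sh" 0.21 0.29; ...]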
## Create a dictionary with all words in the dataset
function create_dictionary(;file)
dict = OrderedDict()
for line in readlines(file)
if !startswith(line, ";")
word, sounds = split(line,"/")
# @show word, sounds
push!(dict, word=>sounds)
end
end
return dict
end
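## Hypothetical usage (the lexicon path and `root_dir` are assumptions): every
## non-comment line of the file is split at "/" into the word and its sounds.
# dict = create_dictionary(file=joinpath(root_dir, "DOC", "TIMITDIC.TXT"))
# dict["water"]    # -> the transcription stored for "water"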
# function create_speaker_dataset(;dir)
# df = DataFrame(dir)
# dir = "/home/cocconat/Documents/Research/phd_project/speech/litwin-kumar_model_thesis/DOC/SPKRINFO.TXT"
# using DataFrames
# DataFrames.readtable(dir)
# # , allowcomments=true, commentmark='%')
function get_dialect(root)
return String(split(root,"/") |> x->filter(startswith("dr"),x)[1])
end
function create_dataset(;dir)
df = DataFrame(speaker = String[] , senID = String[], path=String[], words=Array{Union{String, Float64},2}[], phones=Array{Union{String, Float64},2}[])
df = DataFrame(speaker = String[] , senID = String[], dialect=String[], gender=Char[], path=String[], words=Array{Union{String, Float64},2}[], phones=Array{Union{String, Float64},2}[])
for (root, dirs, files) in walkdir(dir)
for file in files
if endswith(file,"wav")
speaker = split(root, "/")[end]
senID = split(file,".")[1]
words = get_words(root, senID)
phones = get_phones(root, senID)
# @show phones
push!(df,(speaker,senID,joinpath(root,senID),words,phones))
senID = split(file,".")[1]
words = get_words(root, senID)
phones = get_phones(root, senID)
dr = get_dialect(root)
gender = speaker[1]
push!(df,(speaker,senID,dr,gender,joinpath(root,senID),words,phones))
end
end
end
......@@ -84,18 +105,9 @@ module SpikeTimit
return @linq transform(filter(:words => x-> word ∈ x , df), :senID) |> unique
end
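## Minimal end-to-end sketch (assuming `test_path` points to the
## "Spike TIMIT/test" folder, as in the notebook below):
# df = create_dataset(dir=test_path)          # one row per .wav sentence
# df_water = find_word(word="water", df=df)   # unique sentences containing "water"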
function get_spiketimes(;df)
get_spiketimes_file(file)= file*"_spike_timing.mat"
if length(size(df)) == 1 # only 1 row
spikes = [read(matopen(get_spiketimes_file(df.path)))["spike_output"][1,:]]
else
spikes = map(x->x[1,:],get_spiketimes_file.(df.path) |> x->matopen.(x) |> x-> read.(x) |> x->get.(x,"spike_output", nothing) )
end
map(spike->map(row->spike[row] = [spike[row]], findall(typeof.(spike) .==Float64)), spikes)
map(spike->findall(isa.(spike,Matrix)) |> x->spike[x] = spike[x]*correction, spikes)
return spikes
end
########################
## Extract .mat files ##
########################
function get_matrix(;df)
get_matrix_file(file) = file*"_binary_matrix.mat"
......@@ -106,6 +118,23 @@ module SpikeTimit
return file*"."*ext
end
Spiketimes = Vector{Vector{Float64}}
function get_spiketimes(;df)
get_spiketimes_file(file)= file*"_spike_timing.mat"
get_array(value::Float64) = begin x=zeros(Float64, 1); x[1] = value; x end
if length(size(df)) == 1 # only 1 row
spikes = [read(matopen(get_spiketimes_file(df.path)))["spike_output"][1,:]]
else
spikes = map(x->x[1,:],get_spiketimes_file.(df.path) |> x->matopen.(x) |> x-> read.(x) |> x->get.(x,"spike_output", nothing) )
end
map(spike->map(row->spike[row] = spike[row][:], findall(typeof.(spike) .==Array{Float64,2})), spikes)
map(spike->map(row->spike[row] = get_array(spike[row]), findall(typeof.(spike) .==Float64)), spikes)
map(spike->findall(isa.(spike,Array{Float64,1})) |> x->spike[x] = spike[x]*correction, spikes)
return map(x->Spiketimes(x), spikes)
end
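## Sketch of the expected output (assuming each `*_spike_timing.mat` stores a
## 1×620 "spike_output" cell array): one `Spiketimes` per dataframe row, i.e. a
## Vector of 620 Float64 vectors, already rescaled by `correction`.
# spikes = get_spiketimes(df=df_water)
# length(spikes[1])   # 620 input channels for the first sentence
# spikes[1][1]        # spike times of channel 1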
## make fancy raster plot
struct TracePlot{I,T}
indices::I
......@@ -138,16 +167,16 @@ module SpikeTimit
end
Base.push!(tp::TracePlot, x::Number, y::Number) = push!(tp, [y], x)
function raster_plot(spikes; alpha=0.5)
s = TracePlot(length(spikes), legend=false)
function raster_plot(spikes; ax=nothing, alpha=0.5)
s = TracePlot(length(spikes), legend=false, alpha=alpha)
ax = isnothing(ax) ? s : ax
for (n,z) in enumerate(spikes)
if length(z)>0
push!(s,n,z[1,:])
push!(ax,n,z[:])
end
end
return s
return ax
end
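## Usage sketch (hypothetical): draw one sentence, then overlay a second one by
## passing the returned TracePlot back in through `ax`.
# rp = raster_plot(spikes[1], alpha=0.3)
# raster_plot(spikes[2], ax=rp)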
###########################################################
"""
......@@ -171,7 +200,6 @@ module SpikeTimit
end
"""
From the reverse_dictionary data structure obtain 2 arrays that are faster to access in the simulation loop.
1. First array contains the sorted spike times.
......@@ -188,25 +216,39 @@ module SpikeTimit
"""
Stack together different input maps:
Stack together different input maps (now pre-allocating the output arrays):
- spikes is an array with inputs (one Spiketimes per encoding)
- duration is the duration in seconds of each encoding
- durations is the duration (in seconds) of each encoding
- silence_time is the silent gap inserted before each encoding (in timesteps of dt)
"""
function stack_spiketimes(spikes, duration, dt)
global_time_step = 0
all_neurons = []
all_ft = []
for (spike_times,dd) in zip(spikes, duration)
dictionary = reverse_dictionary(spikes[1], dt)
function stack_spiketimes(spikes, durations, silence_time; dt=0.1)
# count the unique firing times per sentence, to size the pre-allocated arrays
nr_unique_fts = 0
for spike_times in spikes
dict = reverse_dictionary(spike_times, dt)
nr_unique_fts += length(collect(Set(dict)))
end
global_time_step = 0
# pre-allocate the output arrays; repeated vcat here caused a stack overflow
all_neurons = Vector{Any}(undef, nr_unique_fts)
all_ft = Vector{Any}(undef, nr_unique_fts)
filled_indices = 0
for (spike_times,dd) in zip(spikes, durations)
dictionary = reverse_dictionary(spike_times, dt)
sorted, neurons = sort_spikes(dictionary)
#shift time:
sorted .+= global_time_step
sorted .+= global_time_step + silence_time
## put them together
all_neurons = vcat(all_neurons, neurons)
all_ft = vcat(all_ft..., sorted)
lower_bound = filled_indices + 1
filled_indices += size(sorted,1)
all_ft[lower_bound:filled_indices] = sorted
all_neurons[lower_bound:filled_indices] = neurons
global_time_step += round(Int, dd*1000/dt)
end
@assert(size(all_ft) == size(all_neurons))
return all_ft, all_neurons
end
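## Hedged example (hypothetical inputs: `spikes` from get_spiketimes, `durations`
## in seconds as computed in inputs_from_df below): concatenate two encoded
## sentences, inserting a silent gap before each one. Judging from
## `global_time_step`, the returned firing times and `silence_time` are in
## timesteps of `dt`, while `durations` stays in seconds.
# all_ft, all_n = stack_spiketimes(spikes[1:2], durations[1:2], 1000; dt=0.1)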
......@@ -216,13 +258,28 @@ module SpikeTimit
function inputs_from_df(df, rows; silence_time=0, dt=0.1) # silence_time (timesteps) is forwarded to stack_spiketimes
_df = df[rows,:]
## Get the duration of each sentence: it is the end time of the last symbol (h#) in the phones array
duration = df.phones |> phones->map(phone->phone[end], phones)
spikes= SpikeTimit.get_spiketimes(df=_df)
all_ft, all_n = stack_spiketimes(spikes, duration, dt)
durations = _df.phones |> phones->map(phone->phone[end], phones)
spikes = SpikeTimit.get_spiketimes(df=_df)
all_ft, all_n = stack_spiketimes(spikes, durations, silence_time; dt=dt)
@assert(size(all_ft) == size(all_n))
return all_ft, all_n
return all_ft, all_n, durations
end
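## Sketch (hypothetical row selection): build the full input for the first ten
## dataset rows; returns the sorted firing times, the firing neurons, and the
## per-sentence durations in seconds.
# all_ft, all_n, durations = inputs_from_df(df, 1:10)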
"""
Get spiketimes corresponding to the word
"""
function get_word_spiketimes(;df, word::String)
df = find_word(;df=df,word=word)
intervals = get_interval_word(; df=df, word=word)
phones = get_phones_in_word(;df=df, word=word)
spikes = get_spiketimes(;df)
spiketimes, duration = get_spikes_in_interval(; spiketimes=spikes,intervals=intervals)
return spiketimes, duration, phones
end
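## Sketch of the intended call chain (hypothetical word): each entry of the
## returned spiketimes holds only the spikes falling inside one occurrence of
## the word, shifted so that the word onset is at time 0.
# word_spikes, word_durations, word_phones = get_word_spiketimes(df=df, word="water")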
"""
For each word extract the interval corresponding to it
"""
......@@ -242,15 +299,42 @@ module SpikeTimit
end
"""
Return the spiketimes subset corresponding to the selected interval, for each sentence
For each word extract the interval corresponding to it
"""
function get_spikes_in_interval(; spiketimes, intervals)
new_spiketimes = Vector()
function get_phones_in_word(;df, word::String)
intervals = []
for row in eachrow(df)
interval = [0.,0.]
for my_word in eachrow(row.words)
if my_word[1] == word
interval = my_word[2:3]
end
end
push!(intervals, interval)
end
isempty(intervals)
return intervals
end
"""
Return the spiketimes subset corresponding to the selected interval
"""
function get_spikes_in_interval(; spiketimes::Union{Spiketimes, Array{Spiketimes}}, intervals)
new_spiketimes = Vector{Spiketimes}()
durations = Vector()
for (spikes, interval) in zip(spiketimes, intervals)
new_spikes, duration = _get_spikes_in_interval(spikes, interval)
push!(new_spiketimes, new_spikes)
push!(durations, duration)
if isa(spiketimes,Spiketimes)
for interval in intervals
new_spikes, duration = _get_spikes_in_interval(spiketimes, interval)
push!(new_spiketimes, new_spikes)
push!(durations, duration)
end
else
@assert(length(spiketimes) == length(intervals))
for (spikes, interval) in zip(spiketimes, intervals)
new_spikes, duration = _get_spikes_in_interval(spikes, interval)
push!(new_spiketimes, new_spikes)
push!(durations, duration)
end
end
return new_spiketimes, durations
end
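## Hedged example (hypothetical: `spikes` from get_spiketimes, `intervals` the
## [start, stop] pairs in seconds produced by get_interval_word on the same
## rows): cut one word occurrence out of each matching sentence.
# word_spikes, word_durations =
#     get_spikes_in_interval(spiketimes=spikes, intervals=intervals)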
......@@ -260,38 +344,274 @@ module SpikeTimit
Return the spiketimes subset corresponding to the selected interval, for one sentence
"""
function _get_spikes_in_interval(spikes, interval)
for n in eachindex(spikes)
in_interval(x) = (x > interval[1])*(x < interval[end])
spikes[n] = filter(in_interval,spikes[n]) .- interval[1]
new_spiketimes=Spiketimes()
for neuron in spikes
neuron_spiketimes = Vector()
for st in eachindex(neuron)
if (neuron[st] > interval[1]) && (neuron[st] < interval[end])
push!(neuron_spiketimes, neuron[st]-interval[1])
end
end
push!(new_spiketimes, neuron_spiketimes)
end
return spikes, interval[end] - interval[1]
return new_spiketimes, interval[end] - interval[1]
end
"""
function resample_spikes(;spiketimes::Array{Spiketimes,1}, n_feat=8)
for s in eachindex(spiketimes)
spiketimes[s] = _resample_spikes(spiketimes=spiketimes[s], n_feat=n_feat)
end
return spiketimes
end
"""
Unify frequency bins.
Input: spiketimes array with 620 elements
Input: a Spiketimes array with 620 elements (20 frequency channels × 31 bins)
Return an array with the spikes grouped into fewer channels.
"""
function resample_spikes(;spike_times, n_feat=7)
function _resample_spikes(;spiketimes::Spiketimes, n_feat=8)
FREQUENCIES = 20
old_bins = length(spike_times)÷FREQUENCIES
new_bins = 31÷n_feat+1
old_bins = length(spiketimes)÷FREQUENCIES
new_bins = old_bins÷n_feat+1
new_spikes = map(x->Vector{Float64}(),1:new_bins*FREQUENCIES)
for freq in 1:FREQUENCIES
old_freq = (freq-1)*old_bins
new_freq = (freq-1)*new_bins
for new_bin in 1:new_bins
last_bin = new_bin*n_feat <32 ? new_bin*n_feat : 31
bins = 1+(new_bin-1)*n_feat : last_bin
for old_bin in bins
push!(new_spikes[new_bin+new_freq], spike[old_bin+old_freq]...)
for freq in 1:FREQUENCIES
old_freq = (freq-1)*old_bins
new_freq = (freq-1)*new_bins
for new_bin in 1:new_bins
last_bin = new_bin*n_feat <32 ? new_bin*n_feat : 31
bins = 1+(new_bin-1)*n_feat : last_bin
for old_bin in bins
push!(new_spikes[new_bin+new_freq], spiketimes[old_bin+old_freq]...)
end
end
end
return sort!.(new_spikes)
end
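## Worked example of the bin arithmetic (assuming 620 = 20 frequencies × 31 bins,
## with `word_spikes` a hypothetical Vector{Spiketimes}): for n_feat=8,
## new_bins = 31÷8+1 = 4, so each frequency keeps 4 channels (merging bins 1–8,
## 9–16, 17–24 and 25–31) and the output has 20×4 = 80 channels, each sorted in time.
# coarse = resample_spikes(spiketimes=word_spikes, n_feat=8)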
"""
Extract all the phonemes contained in the word, together with their time intervals.
Runs over all the sentences in the dataset and returns a list of the relevant phonemes:
each entry corresponds to one sentence (df row) and contains all of its matching phones.
"""
function get_phones_in_word(;df, word)
all_phones = []
for row in eachrow(df)
for my_word in eachrow(row.words)
word_phones = []
if my_word[1] == word
t0,t1 = my_word[2:3]
phones = []
for phone in eachrow(row.phones)
if (phone[2] >= t0) && (phone[3]<= t1)
push!(phones, collect(phone))
end
end
push!(word_phones, phones)
end
push!(all_phones, word_phones...)
end
end
return sort!.(new_spikes)
return all_phones
end
"""
Get the firing times for each phone list.
A phone list contains all the relevant phones (and their time intervals) of a given df entry.
"""
function get_phones_spiketimes(spikes, phones_list)
phone_spiketimes = []
phone_labels = []
for (phones, spike) in zip(phones_list, spikes)
intervals=[]
for phone in phones
push!(intervals,phone[2:3])
push!(phone_labels,phone[1])
end
spiketime, duration = SpikeTimit.get_spikes_in_interval(; spiketimes=spike,intervals=intervals)
push!(phone_spiketimes,spiketime)
end
return phone_labels, vcat(phone_spiketimes...)
end
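## Hedged sketch of the phone-level pipeline (hypothetical dataframe and word):
# phones_list = get_phones_in_word(df=df_water, word="water")
# spikes      = get_spiketimes(df=df_water)
# labels, phone_spikes = get_phones_spiketimes(spikes, phones_list)
## `labels[i]` is the phone symbol whose spike trains are stored in `phone_spikes[i]`.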
end
#
# function get_labels_and_intervals(df, indices, durations; dt=0.1)
# sentences = df[indices,:].words
# word_interval_ends = reshape([],0,1)
# word_labels = reshape([],0,1)
# prev_sents_duration = 0
# for (i,sentence) in enumerate(sentences)
# word_interval_ends = vcat(word_interval_ends, (copy(sentence[:,3]) .+= prev_sents_duration))
# word_labels = vcat(word_labels, sentence[:,1])
# prev_sents_duration += durations[i] #add the duration of the entire sentence to the duration of all previous sentences
# end
# for i = 1:size(word_interval_ends, 1)
# word_interval_ends[i,1] = round(Int, word_interval_ends[i,1]*1000/dt) #from seconds to timesteps
# end
# return word_labels, word_interval_ends
# end
#
# # function input_from_Dong(df, silence_time, dt; filter_speakers=true, n_speakers=10, repetitions=1, train_or_test="train", filter_dialects=true, target_dialects=["dr1"], filter_genders=true, target_gender="f", samples=10, n_feat = 7) # Dong_words = ["that", "she", "all", "your", "me", "had", "like", "don't", "year", "water", "dark", "rag", "oily", "wash", "ask", "carry", "suit"]
# # Dong_words = ["that", "had", "she", "me", "your", "year"]
# # intervals = []
# # df_words = []
# # samples_per_word = samples
# #
# # for (i, word) in enumerate(Dong_words)
# # df_word = copy(find_word(word=word, df=df))
# # if filter_dialects
# # df_word = filter_dialect(df_word, train_or_test, target_dialects)
# # end
# # if filter_genders
# # df_word = filter_gender(df_word, target_gender)
# # end
# # if filter_speakers
# # df_word = filter_speaker(df_word, n_speakers)
# # end
# #
# #
# # n_occurences = size(df_word,1)
# #
# # #randomly order the number of occurences to sample
# # if samples_per_word <= n_occurences
# # inds = randperm(n_occurences)[1:samples_per_word]
# # df_word = df_word[inds, :]
# # else
# # message = string("WARNING: for word: '", word, "', samples_per_word (", samples_per_word, ") exceeds the number of occurences (", n_occurences, ")")
# # @assert false message
# # end
# #
# # interval = get_interval_word(df=df_word, word=word) #get interval
# # interval_copy = copy(interval)
# # df_word_copy = copy(df_word)
# # for repetition in 2:repetitions
# # df_word = [df_word; df_word_copy]
# # interval = [interval; interval_copy]
# # end
# # push!(df_words, df_word)
# # push!(intervals, interval)
# # end
# #
# # #Because we repeat samples, set samples_per_word to the appropriate value
# # samples_per_word = samples_per_word*repetitions
# #
# # Dong_items = Vector()
# # for i = 1:size(Dong_words)[1]
# # spiketimes = get_spiketimes(; df=df_words[i])
# # spikes, _ = get_spikes_in_interval(spiketimes = spiketimes, intervals = intervals[i])
# # new_spikes = []
# # for spike_times in spikes
# # push!(new_spikes, resample_spikes(;spike_times=spike_times, n_feat=n_feat))
# # end
# # Dong_items = vcat(Dong_items, new_spikes)
# # end
# #
# # labels = Vector{String}(undef, size(Dong_words,1)*samples_per_word)
# # for (i,label) in enumerate(Dong_words)
# # lower_bound = (i-1)*samples_per_word+1
# # upper_bound = lower_bound+samples_per_word-1
# # labels[lower_bound: upper_bound] .= label
# # end
# #
# # durations = []
# # for word_i in intervals
# # for i in 1:size(word_i,1)
# # push!(durations, word_i[i][2] - word_i[i][1])
# # end
# # end
# #
# # # get the phones corresponding to the word by extracting phones from the word intervals
# # phone_labels, phone_durations = phones_in_intervals(samples_per_word, size(Dong_words,1), intervals, df_words)
# #
# # Npop = length(Dong_items[1])
# #
# # # Let's shuffle the data
# # our_data = DataFrame(labels = labels, phone_labels = phone_labels, durations = durations, phone_durations = phone_durations, spikes = Dong_items)
# # our_data = our_data[randperm(nrow(our_data)),:]
# #
# # all_ft, all_n = stack_spiketimes(our_data.spikes, our_data.durations, dt, silence_time*1000/dt)
# # # add silence to the data (in front of each word, also done in stack_spiketimes())
# # our_data.durations .+= silence_time
# #
# # #We don't need spikes anymore
# # select!(our_data, Not(:spikes))
# # return all_ft, all_n, our_data, Npop
# # end
# #
# # function phones_in_intervals(samples_per_word, dong_size, intervals, df_words)
# # phone_labels = Vector{Array{String}}(undef, dong_size*samples_per_word)
# # phone_durations = Vector{Array{Float64}}(undef, dong_size*samples_per_word)
# # for i in eachindex(intervals)
# # for j in eachindex(intervals[i])
# # phone_label = Vector{String}()
# # phone_duration = Vector{Float64}()
# # phones = df_words[i][j, :].phones
# # interval = intervals[i][j]
# # for r in eachindex(phones[:, 1])
# # if phones[r,2] >= interval[1] && phones[r,3] <= interval[2]
# # push!(phone_label, phones[r,1])
# # push!(phone_duration, phones[r,3] - phones[r,2])
# # end
# # end
# # phone_labels[(i-1)*samples_per_word + j] = phone_label
# # phone_durations[(i-1)*samples_per_word + j] = phone_duration
# # end
# # end
# # return phone_labels, phone_durations
# # end
# #
# # function get_labels_and_save_points_Dong(our_data, silence_time; dt=0.1, measurements_per_word=1, measurements_per_phone=1)
# # word_save_points = Array{Float64,2}(undef, size(our_data.labels,1),measurements_per_word)
# # word_labels = our_data.labels
# #
# # prev_words_duration = 0
# # for (i,occurrence) in enumerate(eachrow(our_data))
# # save_step = occurrence.durations/measurements_per_word
# # for j in 1:measurements_per_word
# # word_save_points[i,j] = save_step*j + prev_words_duration
# # end
# # prev_words_duration += occurrence.durations
# # end
# #
# # phone_labels = our_data.phone_labels
# # phone_labels = [(phone_labels...)...] # flatten the list
# # phone_durations = our_data.phone_durations
# # phone_save_points = Array{Float64,2}(undef, size(phone_labels,1),measurements_per_phone)
# #
# # prev_phone_duration = 0
# # elapsed_phones = 1
# # for (i,word) in enumerate(phone_durations)
# # prev_phone_duration += silence_time
# # for (j,duration) in enumerate(word)
# # save_step = duration/measurements_per_phone
# # for n in 1:measurements_per_phone
# # phone_save_points[elapsed_phones, n] = save_step*n + prev_phone_duration
# # end
# # prev_phone_duration += duration
# # elapsed_phones += 1
# # end
# # end
# #
# # # Convert from seconds to timesteps
# # for i = 1:size(word_save_points, 1)
# # for j = 1:size(word_save_points, 2)
# # word_save_points[i, j] = round(Int, word_save_points[i, j]*1000/dt) #from seconds to timesteps
# # end
# # end
# #
# # # Convert from seconds to timesteps
# # for i = 1:size(phone_save_points, 1)
# # for j = 1:size(phone_save_points, 2)
# # phone_save_points[i,j] = round(Int, phone_save_points[i,j]*1000/dt) #from seconds to timesteps
# # end
# # end
# #
# # return word_labels, word_save_points, phone_labels, phone_save_points
# # end
### A Pluto.jl notebook ###
# v0.12.21
using Markdown
using InteractiveUtils
# ╔═╡ 70951726-8063-11eb-2ba9-1fa3822b9e91
using .SpikeTimit
# ╔═╡ c0d57c40-8010-11eb-3004-195f0590db26
md"""
This notebook shows how to import data from the Spike TIMIT [1] database and use it as a stimulus for the LKD network [2]. The SpikeTimit.jl module imports the standard dataset released with publication [1].
"""
# ╔═╡ 62beb3c0-8011-11eb-1ab5-8de2f870d0b2
md""" Import the module and the relevant packages"""
# ╔═╡ 8015ec4a-8011-11eb-04d6-9bcd09fada86
PATH = joinpath(@__DIR__,"SpikeTimit.jl")
# ╔═╡ 5e4f6080-8063-11eb-39b1-ab78ccdbf423
include(PATH)
# ╔═╡ 24e8c3e2-8064-11eb-23dc-2f6848c499e5
# ╔═╡ 693fb58a-8063-11eb-3e59-c9249009d1d6
# ╔═╡ 4566f194-8012-11eb-39d7-ad467788a78b
md"""
Import the dataset. Note that the spike times are not loaded at this stage; only the file paths are stored, ready to be read. You have to set your PATH for the dataset.
"""
# ╔═╡ 42f1c31e-8063-11eb-0059-75ffb8aa555a
begin
test_path = joinpath(@__DIR__,"Spike TIMIT", "test" );