Commit 004f39c9 authored by alessio.quaresima

Merge branch 'master' of gitlab.socsci.ru.nl:alessio.quaresima1/spiketimit.jl

parents 25733b47 e84476fd
build
DOC
TIMIT
build
__pycache__
using Base
module SpikeTimit
using DataFrames
using DataFramesMeta
using OrderedCollections
using MAT
using Plots
using ColorSchemes
using Random
## This is the sample rate used in the scripts.
sr = 16000
## This is the rescaling factor for the spike-time discussed in the PDF
correction = 5.
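## Worked example (units as assumed from the PDF): a raw value of 0.02 read from the
## *_spike_timing.mat files becomes 0.02 * correction = 0.1 s after rescaling, i.e. spike
## times end up in seconds, on the same axis as the word/phone intervals below (raw TIMIT
## sample indices divided by sr).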
import Plots: Series, Plot, Subplot
"""
Get the words and their time intervals for a given sentence ID (senID)
"""
function get_words(root, senID)
path = joinpath(root, senID*".wrd")
times0 = []
times1 = []
words = []
for line in readlines(path)
t0,tf,w = split(line)
push!(times0,parse(Int,t0))
push!(times1,parse(Int,tf))
push!(words, String(w))
end
_data = Array{Union{String,Float64},2}(undef,length(times0),3)
_data[:,1] = words
_data[:,2] = times0 ./ sr
_data[:,3] = times1 ./ sr
return _data
end
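## Usage sketch (hypothetical root/senID; each line of the .wrd file is "t0 t1 word"):
# words = get_words("TIMIT/train/dr1/fcjf0", "sa1")
# words[1, 1]    # first word label (String)
# words[1, 2:3]  # its onset/offset in seconds (sample indices divided by sr)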
"""
Get the phonemes and their time intervals for a given sentence ID (senID)
"""
function get_phones(root, senID)
path = joinpath(root, senID*".phn")
times0 = []
times1 = []
phones = []
for line in readlines(path)
t0,tf,p = split(line)
push!(times0,parse(Int,t0))
push!(times1,parse(Int,tf))
push!(phones, String(p))
end
_data = Array{Union{String,Float64},2}(undef,length(times0),3)
_data[:,1] = phones
_data[:,2] = times0 ./ sr
_data[:,3] = times1 ./ sr
return _data
end
## Create a dictionary with all words in the dataset
function create_dictionary(;file)
dict = OrderedDict()
for line in readlines(file)
if !startswith(line, ";")
word, sounds = split(line,"/")
# @show word, sounds
push!(dict, word=>sounds)
end
end
return dict
end
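## Minimal sketch (hypothetical lexicon path; lines starting with ";" are skipped, all other
## lines are split on "/" into word => phonetic transcription):
# dict = create_dictionary(file="DOC/TIMITDIC.TXT")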
# function create_speaker_dataset(;dir)
# df = DataFrame(dir)
# dir = "/home/cocconat/Documents/Research/phd_project/speech/litwin-kumar_model_thesis/DOC/SPKRINFO.TXT"
# using DataFrames
# DataFrames.readtable(dir)
# # , allowcomments=true, commentmark='%')
function get_dialect(root)
return String(split(root,"/") |> x->filter(startswith("dr"),x)[1])
end
function create_dataset(;dir)
df = DataFrame(speaker = String[] , senID = String[], dialect=String[], gender=Char[], path=String[], words=Array{Union{String, Float64},2}[], phones=Array{Union{String, Float64},2}[])
for (root, dirs, files) in walkdir(dir)
for file in files
if endswith(file,"wav")
speaker = split(root, "/")[end]
senID = split(file,".")[1]
words = get_words(root, senID)
phones = get_phones(root, senID)
dr = get_dialect(root)
gender = speaker[1]
push!(df,(speaker,senID,dr,gender,joinpath(root,senID),words,phones))
end
end
end
return df
end
function find_word(;word, df)
"""Search for word in the dataset and return the items"""
return @linq transform(filter(:words => x -> word ∈ x, df), :senID) |> unique
end
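## Usage sketch (hypothetical TIMIT directory; create_dataset walks it looking for .wav files
## and stores speaker, sentence ID, dialect, gender, path, word and phone annotations per row):
# df = create_dataset(dir="TIMIT/train")
# water_df = find_word(word="water", df=df)   # unique rows whose word annotations contain "water"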
########################
## Extract .mat files ##
########################
function get_matrix(;df)
get_matrix_file(file) = file*"_binary_matrix.mat"
return get_matrix_file.(df.path)
end
function get_file(file, ext)
return file*"."*ext
end
const Spiketimes = Vector{Vector{Float64}}
function get_spiketimes(;df)
get_spiketimes_file(file)= file*"_spike_timing.mat"
get_array(value::Float64) = begin x=zeros(Float64, 1); x[1] = value; x end
if length(size(df)) == 1 # only 1 row
spikes = [read(matopen(get_spiketimes_file(df.path)))["spike_output"][1,:]]
else
spikes = map(x->x[1,:],get_spiketimes_file.(df.path) |> x->matopen.(x) |> x-> read.(x) |> x->get.(x,"spike_output", nothing) )
end
map(spike->map(row->spike[row] = spike[row][:], findall(typeof.(spike) .==Array{Float64,2})), spikes)
map(spike->map(row->spike[row] = get_array(spike[row]), findall(typeof.(spike) .==Float64)), spikes)
map(spike->findall(isa.(spike,Array{Float64,1})) |> x->spike[x] = spike[x]*correction, spikes)
return map(x->Spiketimes(x), spikes)
end
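## Usage sketch: one Spiketimes object per selected row, i.e. a vector of spike-time vectors
## (one per input channel), already rescaled by `correction`.
# spikes = get_spiketimes(df=water_df)   # water_df as in the find_word sketch above
# length(spikes[1])                      # number of input channels of the first sentence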
## make fancy raster plot
struct TracePlot{I,T}
indices::I
plt::Plot{T}
sp::Subplot{T}
series::Vector{Series}
end
function TracePlot(n::Int = 1; maxn::Int = typemax(Int), sp = nothing, kw...)
clist= get(ColorSchemes.colorschemes[:viridis],range(0,1,length=n))
indices = if n > maxn
shuffle(1:n)[1:maxn]
else
1:n
end
if isnothing(sp)
plt = scatter(length(indices);kw...)
sp = plt[1]
else
plt = scatter!(sp, length(indices); kw...)
end
for n in indices
sp.series_list[n].plotattributes[:markercolor]=clist[n]
sp.series_list[n].plotattributes[:markerstrokecolor]=clist[n]
end
TracePlot(indices, plt, sp, sp.series_list)
end
function Base.push!(tp::TracePlot, x::Number, y::AbstractVector)
push!(tp.series[x], y, x .*ones(length(y)))
end
Base.push!(tp::TracePlot, x::Number, y::Number) = push!(tp, x, [y])
function raster_plot(spikes; ax=nothing, alpha=0.5)
s = TracePlot(length(spikes); legend=false, alpha=alpha)
ax = isnothing(ax) ? s : ax
for (n,z) in enumerate(spikes)
if length(z)>0
push!(ax,n,z[:])
end
end
return ax
end
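## Plotting sketch (a single Spiketimes object, e.g. one entry returned by get_spiketimes):
# tp = raster_plot(spikes[1])
# tp.plt   # scatter with time on the x axis and the channel index on the y axis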
"""
Extract all the firing times and the corresponding neurons from an array that lists, for each neuron, its firing times; i.e. build the reverse dictionary (timestep => firing neurons).
"""
function reverse_dictionary(spikes,dt::Float64)
all_times = Dict()
for n in eachindex(spikes)
if !isempty(spikes[n])
for t in spikes[n]
tt = round(Int,t*1000/dt) ## from seconds to timesteps
if haskey(all_times,tt)
all_times[tt] = [all_times[tt]..., n]
else
push!(all_times, tt=>[n])
end
end
end
end
return all_times
end
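## Worked example: with dt = 0.1 ms a spike at 0.0155 s maps to timestep
## round(Int, 0.0155*1000/0.1) = 155; channels firing in the same timestep share one entry.
# reverse_dictionary([[0.0155], [], [0.0155]], 0.1)   # -> Dict(155 => [1, 3])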
"""
From the reverse_dictionary data structure obtain two arrays that are faster to access in the simulation loop:
1. The first array contains the sorted spike times.
2. The second array contains, for each spike time, the vector of neurons firing at that time.
"""
function sort_spikes(dictionary)
neurons = Array{Vector{Int}}(undef, length(keys(dictionary)))
sorted = sort(collect(keys(dictionary)))
for (n,k) in enumerate(sorted)
neurons[n] = dictionary[k]
end
return sorted, neurons
end
"""
Stack together different inputs (with pre-allocated memory):
- spikes is an array of inputs
- durations contains the duration in seconds of each encoding
- silence_time is the silence (in timesteps) inserted in front of the inputs
"""
function stack_spiketimes(spikes, durations, silence_time; dt=0.1)
# for the memory allocation
nr_unique_fts = 0
for spike_times in spikes
dict = reverse_dictionary(spike_times, dt)
nr_unique_fts += length(dict)
end
global_time_step = 0
# memory allocation!! Otherwise stack overflow.
all_neurons = Vector{Any}(undef, nr_unique_fts)
all_ft = Vector{Any}(undef, nr_unique_fts)
filled_indices = 0
for (spike_times,dd) in zip(spikes, durations)
dictionary = reverse_dictionary(spike_times, dt)
sorted, neurons = sort_spikes(dictionary)
#shift time:
sorted .+= global_time_step + silence_time
## put them together
lower_bound = filled_indices + 1
filled_indices += size(sorted,1)
all_ft[lower_bound:filled_indices] = sorted
all_neurons[lower_bound:filled_indices] = neurons
global_time_step += round(Int, dd*1000/dt)
end
@assert(size(all_ft) == size(all_neurons))
return all_ft, all_neurons
end
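## Usage sketch: each input is shifted by the accumulated duration of the previous ones plus
## `silence_time` (both in timesteps of dt ms); the returned firing times are timesteps, each
## paired with the vector of channels firing at that step.
# all_ft, all_neurons = stack_spiketimes(spikes, durations, 1000; dt=0.1)   # 1000 steps = 100 ms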
"""
Return the input spike trains corresponding to the selected rows, in the (spike-time -> neurons) format
"""
function inputs_from_df(df, rows; dt=0.1)
_df = df[rows,:]
## Get the duration of each sentence: it corresponds to the end time of the (last) h# symbol in the phones array
durations = _df.phones |> phones->map(phone->phone[end], phones)
spikes = SpikeTimit.get_spiketimes(df=_df)
all_ft, all_n = stack_spiketimes(spikes, durations, 0; dt=dt) ## silence_time assumed to be 0 here; dt passed as keyword
@assert(size(all_ft) == size(all_n))
return all_ft, all_n, durations
end
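## Usage sketch (rows index the dataset built by create_dataset):
# all_ft, all_n, durations = inputs_from_df(df, [1, 2, 3]; dt=0.1)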
"""
Get spiketimes corresponding to the word
"""
function get_word_spiketimes(;df, word::String)
df = find_word(;df=df,word=word)
intervals = get_interval_word(; df=df, word=word)
phones = get_phones_in_word(;df=df, word=word)
spikes = get_spiketimes(;df)
spiketimes, duration = get_spikes_in_interval(; spiketimes=spikes,intervals=intervals)
return spiketimes, duration, phones
end
"""
For each word extract the interval corresponding to it
"""
function get_interval_word(;df, word::String)
intervals = []
for row in eachrow(df)
interval = [0.,0.]
for my_word in eachrow(row.words)
if my_word[1] == word
interval = my_word[2:3]
end
end
push!(intervals, interval)
end
return intervals
end
"""
For each word extract the interval corresponding to it
"""
function get_phones_in_word(;df, word::String)
intervals = []
for row in eachrow(df)
interval = [0.,0.]
for my_word in eachrow(row.words)
if my_word[1] == word
interval = my_word[2:3]
end
end
push!(intervals, interval)
end
isempty(intervals)
return intervals
end
"""
Return the spiketimes subset corresponding to the selected interval
"""
function get_spikes_in_interval(; spiketimes::Union{Spiketimes, Array{Spiketimes}}, intervals)
new_spiketimes = Vector{Spiketimes}()
durations = Vector()
if isa(spiketimes,Spiketimes)
for interval in intervals
new_spikes, duration = _get_spikes_in_interval(spiketimes, interval)
push!(new_spiketimes, new_spikes)
push!(durations, duration)
end
else
@assert(length(spiketimes) == length(intervals))
for (spikes, interval) in zip(spiketimes, intervals)
new_spikes, duration = _get_spikes_in_interval(spikes, interval)
push!(new_spiketimes, new_spikes)
push!(durations, duration)
end
end
return new_spiketimes, durations
end
"""
Return the spiketimes subset corresponding to the selected interval, for one sentence
"""
function _get_spikes_in_interval(spikes, interval)
new_spiketimes=Spiketimes()
for neuron in spikes
neuron_spiketimes = Vector()
for st in eachindex(neuron)
if (neuron[st] > interval[1]) && (neuron[st] < interval[end])
push!(neuron_spiketimes, neuron[st]-interval[1])
end
end
push!(new_spiketimes, neuron_spiketimes)
end
return new_spiketimes, interval[end] - interval[1]
end
function resample_spikes(;spiketimes::Array{Spiketimes,1}, n_feat=8)
for s in eachindex(spiketimes)
spiketimes[s] = _resample_spikes(spiketimes=spiketimes[s], n_feat=n_feat)
end
return spiketimes
end
"""
Unify the frequency bins.
Input: an array of spiketimes whose elements have 620 channels (20 frequencies × 31 bins each).
Return an array with the spikes sorted into fewer classes.
"""
function _resample_spikes(;spiketimes::Spiketimes, n_feat=8)
FREQUENCIES = 20
old_bins = length(spiketimes)÷FREQUENCIES
new_bins = old_bins÷n_feat+1
new_spikes = map(x->Vector{Float64}(),1:new_bins*FREQUENCIES)
for freq in 1:FREQUENCIES
old_freq = (freq-1)*old_bins
new_freq = (freq-1)*new_bins
for new_bin in 1:new_bins
last_bin = new_bin*n_feat < 32 ? new_bin*n_feat : 31 ## cap at 31 = 620 ÷ 20 old bins per frequency
bins = 1+(new_bin-1)*n_feat : last_bin
for old_bin in bins
push!(new_spikes[new_bin+new_freq], spiketimes[old_bin+old_freq]...)
end
end
end
return sort!.(new_spikes)
end
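## Worked example with the default layout (620 channels = 20 frequencies × 31 bins) and
## n_feat = 8: old_bins = 31, new_bins = 31 ÷ 8 + 1 = 4, so the output has 4 × 20 = 80
## channels; within each frequency, new bin b collects old bins (b-1)*8+1 : min(b*8, 31),
## i.e. 1:8, 9:16, 17:24 and 25:31.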
"""
Extract all the phonemes contained in the word and their time intervals.
Run over all the sentences in the dataset and return a list of relevant phonemes;
each entry corresponds to a sentence (df entry) and contains all of its phones.
"""
function get_phones_in_word(;df, word)
all_phones = []
for row in eachrow(df)
for my_word in eachrow(row.words)
word_phones = []
if my_word[1] == word
t0,t1 = my_word[2:3]
phones = []
for phone in eachrow(row.phones)
if (phone[2] >= t0) && (phone[3]<= t1)
push!(phones, collect(phone))
end
end
push!(word_phones, phones)
end
push!(all_phones, word_phones...)
end
end
return all_phones
end
"""
Get the firing times for each phones list.
A phones list contains all the relevant phones (and their time intervals) in a given df entry.
"""
function get_phones_spiketimes(spikes, phones_list)
phone_spiketimes = []
phone_labels = []
for (phones, spike) in zip(phones_list, spikes)
intervals=[]
for phone in phones
push!(intervals,phone[2:3])
push!(phone_labels,phone[1])
end
spiketime, duration = SpikeTimit.get_spikes_in_interval(; spiketimes=spike,intervals=intervals)
push!(phone_spiketimes,spiketime)
end
return phone_labels, vcat(phone_spiketimes...)
end
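## Usage sketch, chaining the helpers above (word-level spikes plus their phone segmentation):
# phones_list = get_phones_in_word(df=water_df, word="water")
# spikes = get_spiketimes(df=water_df)
# phone_labels, phone_spikes = get_phones_spiketimes(spikes, phones_list)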
end
#
# function get_labels_and_intervals(df, indices, durations; dt=0.1)
# sentences = df[indices,:].words
# word_interval_ends = reshape([],0,1)
# word_labels = reshape([],0,1)
# prev_sents_duration = 0
# for (i,sentence) in enumerate(sentences)
# word_interval_ends = vcat(word_interval_ends, (copy(sentence[:,3]) .+= prev_sents_duration))
# word_labels = vcat(word_labels, sentence[:,1])
# prev_sents_duration += durations[i] #add the duration of the entire sentence to the duration of all previous sentences
# end
# for i = 1:size(word_interval_ends, 1)
# word_interval_ends[i,1] = round(Int, word_interval_ends[i,1]*1000/dt) #from seconds to timesteps
# end
# return word_labels, word_interval_ends
# end
#
# # function input_from_Dong(df, silence_time, dt; filter_speakers=true, n_speakers=10, repetitions=1, train_or_test="train", filter_dialects=true, target_dialects=["dr1"], filter_genders=true, target_gender="f", samples=10, n_feat = 7) # Dong_words = ["that", "she", "all", "your", "me", "had", "like", "don't", "year", "water", "dark", "rag", "oily", "wash", "ask", "carry", "suit"]
# # Dong_words = ["that", "had", "she", "me", "your", "year"]
# # intervals = []
# # df_words = []
# # samples_per_word = samples
# #
# # for (i, word) in enumerate(Dong_words)
# # df_word = copy(find_word(word=word, df=df))
# # if filter_dialects
# # df_word = filter_dialect(df_word, train_or_test, target_dialects)
# # end
# # if filter_genders
# # df_word = filter_gender(df_word, target_gender)
# # end
# # if filter_speakers
# # df_word = filter_speaker(df_word, n_speakers)
# # end
# #
# #
# # n_occurences = size(df_word,1)
# #
# # #randomly order the number of occurences to sample
# # if samples_per_word <= n_occurences
# # inds = randperm(n_occurences)[1:samples_per_word]
# # df_word = df_word[inds, :]
# # else
# # message = string("WARNING: for word: '", word, "', samples_per_word (", samples_per_word, ") exceeds the number of occurences (", n_occurences, ")")
# # @assert false message
# # end
# #
# # interval = get_interval_word(df=df_word, word=word) #get interval
# # interval_copy = copy(interval)
# # df_word_copy = copy(df_word)
# # for repetition in 2:repetitions
# # df_word = [df_word; df_word_copy]
# # interval = [interval; interval_copy]
# # end
# # push!(df_words, df_word)
# # push!(intervals, interval)
# # end
# #
# # #Because we repeat samples, set samples_per_word to the appropriate value
# # samples_per_word = samples_per_word*repetitions
# #
# # Dong_items = Vector()
# # for i = 1:size(Dong_words)[1]
# # spiketimes = get_spiketimes(; df=df_words[i])
# # spikes, _ = get_spikes_in_interval(spiketimes = spiketimes, intervals = intervals[i])
# # new_spikes = []
# # for spike_times in spikes
# # push!(new_spikes, resample_spikes(;spike_times=spike_times, n_feat=n_feat))
# # end
# # Dong_items = vcat(Dong_items, new_spikes)
# # end
# #
# # labels = Vector{String}(undef, size(Dong_words,1)*samples_per_word)
# # for (i,label) in enumerate(Dong_words)
# # lower_bound = (i-1)*samples_per_word+1
# # upper_bound = lower_bound+samples_per_word-1
# # labels[lower_bound: upper_bound] .= label
# # end
# #
# # durations = []
# # for word_i in intervals
# # for i in 1:size(word_i,1)
# # push!(durations, word_i[i][2] - word_i[i][1])
# # end
# # end
# #
# # # get the phones corresponding to the word by extracting phones from the word intervals
# # phone_labels, phone_durations = phones_in_intervals(samples_per_word, size(Dong_words,1), intervals, df_words)
# #
# # Npop = length(Dong_items[1])
# #
# # # Let's shuffle the data
# # our_data = DataFrame(labels = labels, phone_labels = phone_labels, durations = durations, phone_durations = phone_durations, spikes = Dong_items)
# # our_data = our_data[randperm(nrow(our_data)),:]
# #
# # all_ft, all_n = stack_spiketimes(our_data.spikes, our_data.durations, dt, silence_time*1000/dt)
# # # add silence to the data (in front of each word, also done in stack_spiketimes())
# # our_data.durations .+= silence_time
# #
# # #We don't need spikes anymore
# # select!(our_data, Not(:spikes))
# # return all_ft, all_n, our_data, Npop
# # end
# #
# # function phones_in_intervals(samples_per_word, dong_size, intervals, df_words)
# # phone_labels = Vector{Array{String}}(undef, dong_size*samples_per_word)
# # phone_durations = Vector{Array{Float64}}(undef, dong_size*samples_per_word)
# # for i in eachindex(intervals)
# # for j in eachindex(intervals[i])
# # phone_label = Vector{String}()
# # phone_duration = Vector{Float64}()
# # phones = df_words[i][j, :].phones
# # interval = intervals[i][j]
# # for r in eachindex(phones[:, 1])
# # if phones[r,2] >= interval[1] && phones[r,3] <= interval[2]
# # push!(phone_label, phones[r,1])
# # push!(phone_duration, phones[r,3] - phones[r,2])
# # end
# # end
# # phone_labels[(i-1)*samples_per_word + j] = phone_label
# # phone_durations[(i-1)*samples_per_word + j] = phone_duration
# # end
# # end
# # return phone_labels, phone_durations
# # end
# #
# # function get_labels_and_save_points_Dong(our_data, silence_time; dt=0.1, measurements_per_word=1, measurements_per_phone=1)
# # word_save_points = Array{Float64,2}(undef, size(our_data.labels,1),measurements_per_word)
# # word_labels = our_data.labels
# #
# # prev_words_duration = 0
# # for (i,occurrence) in enumerate(eachrow(our_data))
# # save_step = occurrence.durations/measurements_per_word
# # for j in 1:measurements_per_word
# # word_save_points[i,j] = save_step*j + prev_words_duration
# # end
# # prev_words_duration += occurrence.durations
# # end
# #
# # phone_labels = our_data.phone_labels
# # phone_labels = [(phone_labels...)...] # flatten the list
# # phone_durations = our_data.phone_durations
# # phone_save_points = Array{Float64,2}(undef, size(phone_labels,1),measurements_per_phone)
# #
# # prev_phone_duration = 0
# # elapsed_phones = 1
# # for (i,word) in enumerate(phone_durations)
# # prev_phone_duration += silence_time
# # for (j,duration) in enumerate(word)
# # save_step = duration/measurements_per_phone
# # for n in 1:measurements_per_phone
# # phone_save_points[elapsed_phones, n] = save_step*n + prev_phone_duration
# # end
# # prev_phone_duration += duration
# # elapsed_phones += 1
# # end
# # end
# #
# # # Convert from seconds to timesteps
# # for i = 1:size(word_save_points, 1)
# # for j = 1:size(word_save_points, 2)
# # word_save_points[i, j] = round(Int, word_save_points[i, j]*1000/dt) #from seconds to timesteps
# # end
# # end
# #
# # # Convert from seconds to timesteps
# # for i = 1:size(phone_save_points, 1)
# # for j = 1:size(phone_save_points, 2)
# # phone_save_points[i,j] = round(Int, phone_save_points[i,j]*1000/dt) #from seconds to timesteps
# # end
# # end
# #
# # return word_labels, word_save_points, phone_labels, phone_save_points
# # end