Commit b87b119b authored by Alessio Quaresima's avatar Alessio Quaresima

SpikeTimit.jl structure reorganized into 4 files.

General improvements and checks of function consistency
parent e3753bcb
using MAT
# Assumption: these packages provide the names used below (DataFrame, @where,
# OrderedDict, Exponential, sample, Weights, shuffle, randperm); in the
# reorganized package they are presumably loaded by the top-level module file.
using DataFrames, DataFramesMeta
using DataStructures
using Distributions
using StatsBase
using Random
"""
Get words and their time intervals from the senID
"""
function _get_words(root, senID)
path = joinpath(root, senID*".wrd")
times0 = Int[]
times1 = Int[]
words = String[]
for line in readlines(path)
t0,tf,w = split(line)
push!(times0,parse(Int,t0))
push!(times1,parse(Int,tf))
push!(words, String(w))
end
_data = Array{Union{String,Float64},2}(undef,length(times0),3)
_data[:,1] = words
_data[:,2] = times0 ./ sr # `sr`: audio sampling rate (16 kHz in TIMIT), assumed defined elsewhere in the package
_data[:,3] = times1 ./ sr
return _data
end
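# Usage sketch (hypothetical paths, assuming the standard TIMIT layout where
# each sentence `senID.wav` sits next to `senID.wrd` and `senID.phn`):
#   words = _get_words("TIMIT/train/dr1/fcjf0", "sa1")
#   words[1, :]   # [label::String, onset::Float64 (s), offset::Float64 (s)]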
"""
Get phonemes and their time intervals from the senID
"""
function _get_phones(root, senID)
path = joinpath(root, senID*".phn")
times0 = Int[]
times1 = Int[]
phones = String[]
for line in readlines(path)
t0,tf,p = split(line)
push!(times0,parse(Int,t0))
push!(times1,parse(Int,tf))
push!(phones, String(p))
end
_data = Array{Union{String,Float64},2}(undef,length(times0),3)
_data[:,1] = phones
_data[:,2] = times0 ./ sr
_data[:,3] = times1 ./ sr
return _data
end
# function create_speaker_dataset(;dir)
# df = DataFrame(dir)
# dir = "/home/cocconat/Documents/Research/phd_project/speech/litwin-kumar_model_thesis/DOC/SPKRINFO.TXT"
# using DataFrames
# DataFrames.readtable(dir)
# # , allowcomments=true, commentmark='%')
function _get_dialect(root)
# the dialect region is encoded in the path as e.g. "dr1"; take the trailing digit
return splitpath(root) |> x->parse(Int, filter(startswith("dr"), x)[1][end])
end
function create_dataset(;dir)
df = DataFrame(speaker = String[], senID = String[], dialect = Int[],
gender = Char[], path = String[],
words = Array{Union{String,Float64},2}[],
phones = Array{Union{String,Float64},2}[],
sentence = Vector{String}[])
for (root, dirs, files) in walkdir(dir)
for file in files
if endswith(file,"wav")
speaker = splitpath(root)[end]
senID = split(file,".")[1]
words = _get_words(root, senID)
phones = _get_phones(root, senID)
dr = _get_dialect(root)
gender = speaker[1]
sentence = String.(words[:,1])
push!(df,(speaker,senID,dr,gender,joinpath(root,senID),words,phones,sentence))
end
end
end
return df
end
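# Minimal sketch, assuming `dir` points at a TIMIT-style directory tree:
#   df = create_dataset(dir="TIMIT/train")
#   names(df)  # speaker, senID, dialect, gender, path, words, phones, sentence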
## Create a dictionary with all words in the dataset
function create_dictionary(;file)
dict = OrderedDict()
for line in readlines(file)
if !startswith(line, ";")
word, sounds = split(line,"/")
push!(dict, strip(word)=>sounds)
end
end
return dict
end
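# The parsed file is expected to hold one `word / pronunciation` entry per
# line, with `;`-prefixed comment lines skipped. Sketch with the TIMIT lexicon
# (the path is an assumption):
#   dict = create_dictionary(file="TIMIT/DOC/TIMITDIC.TXT")
#   dict["water"]   # the pronunciation field after the "/"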
## Obsolete. Use MetaFrame to subselect the dataset.
"""Search for the given word(s) in the dataset and return the matching entries"""
function find_word(;df::DataFrame, words::Union{Vector{String}, String})
words = isa(words, String) ? [words] : words
in_words(df_words) = !isempty(intersect(Set(df_words), Set(words)))
return @where(df, in_words.(:words))
end
########################
## Extract .mat files ##
########################
function _get_matrix(;df::DataFrame)
get_matrix_file(file) = file*"_binary_matrix.mat"
return get_matrix_file.(df.path)
end
function _get_file(file, ext)
return file*"."*ext
end
function get_spiketimes(;df::DataFrame)
get_spiketimes_file(file)= file*"_spike_timing.mat"
get_array(value::Float64) = [value] # wrap a scalar spike time into a 1-element vector
spikes = []
for f in df.path
fp = matopen(get_spiketimes_file(f))
spike = read(fp) |> x->get(x, "spike_output", nothing)[1,:]
map(row->spike[row] = spike[row][:], findall(typeof.(spike) .== Array{Float64,2}))
map(row->spike[row] = get_array(spike[row]), findall(typeof.(spike) .== Float64))
@assert length(spike) == length(findall(x->isa(x, Vector{Float64}), spike))
push!(spikes, spike .* correction) # `correction`: time-unit rescaling factor, assumed defined elsewhere in the package
close(fp)
end
map(x->Spiketimes(x), spikes)
end
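# Sketch: for each entry of `df` this loads the `*_spike_timing.mat` file
# sitting next to the .wav file and returns one Spiketimes per entry
# (a vector of spike-time vectors, one per input channel):
#   spiketimes = get_spiketimes(df=df[1:2, :])
#   length(spiketimes) == 2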
"""
Unify frequencies bin
Input: spiketimes array with elements of size 620 elements
Return array with sorted spikes in less classes
nfeat new_bins Rounding last_bin
2 15,5 <-- 2*15=30 <DO> add last to last bin
3 10,333333 <-- 3*10=30 <DO> add last to last bin
4 7,75 --> 4*8=32 <DO>
5 6,2 <-- 5*6=30 <DO> add last to last bin
6 5,16666 <-- 6*5=30 <DO> add last to last bin
7 4,429 <-- 7*4=28 <DO> add last 3 to last bin
8 3,875 --> 8*4=32 <DO>
9 3,44444 <-- 9*3=27 <DO> add last 4 to last bin
10 3,1 <-- 10*3=30 <DO> add last to last bin
11 2,818181 --> 11*3=33 <DO>
"""
function resample_spikes!(spiketimes::Vector{Spiketimes}; n_feat = 8, kwargs... )
for s in 1:length(spiketimes)
spiketimes[s] = _resample_spikes!(spiketimes[s], n_feat=n_feat)
end
end
function _resample_spikes!(spiketimes::Spiketimes; n_feat)
# If we don't reduce the bins
if n_feat == 1
return spiketimes
end
@assert 1 <= n_feat <= 11 "Impossible resampling value n_feat = $n_feat (supported range: 1:11)"
FREQUENCIES = 20
all_bins = length(spiketimes) ## should be 620
old_bins = convert(Int64, all_bins/FREQUENCIES) ## should be 31
@assert (old_bins==31) "WARNING: old_bins != 31, this function is probably broken for other values than 31 (super hardcoded)"
new_bins = ceil(Int, old_bins/n_feat )
new_spikes = Spiketimes(map(x->Vector{Float64}(),1:new_bins*FREQUENCIES))
for freq in 1:FREQUENCIES
old_freq = (freq-1)*old_bins
new_freq = (freq-1)*new_bins
for old_bin in 1:old_bins
new_bin = ceil(Int,(old_bin)/n_feat)
push!(new_spikes[new_bin+new_freq], spiketimes[old_bin+old_freq]...)
end
end
return new_spikes
end
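# Worked example: with the hard-coded 20 frequencies x 31 bins (620 channels)
# and n_feat = 8, new_bins = ceil(31/8) = 4, so the output has 20*4 = 80
# channels; old bins 1:8 merge into new bin 1, 9:16 into 2, 17:24 into 3, and
# the leftover 25:31 into 4.
#   resample_spikes!(spiketimes, n_feat=8)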
function transform_into_bursts!(spiketimes::Vector{Spiketimes}; spikes_per_burst_increase=0, kwargs...)
for s in 1:length(spiketimes)
_transform_into_bursts!(spiketimes[s], spikes_per_burst_increase=spikes_per_burst_increase)
end
end
function _transform_into_bursts!(spiketimes::Spiketimes; spikes_per_burst_increase=0)
##Spike distributions
# based on plot 1B 0.7 nA (Oswald, Doiron & Maler (2007))
expdist = Exponential(0.005)
values = [2,3,4,5,6] .+ spikes_per_burst_increase
weights = [0.8, 0.15, 0.075, 0.035, 0.03]
weights = Weights(weights ./ sum(weights)) # normalized weights
for neuron in 1:length(spiketimes)
if !isempty(spiketimes[neuron])
temp = copy(spiketimes[neuron])
for spike_time in temp
number_of_spikes = sample(values, weights) - 1 # -1 because first spike is determined from data
for j in 1:number_of_spikes
interval = rand(expdist)
new_time = spike_time + interval + 0.004 # 4 ms offset; times are in seconds
push!(spiketimes[neuron], new_time)
end
end
end
sort!(spiketimes[neuron])
end
end
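# Sketch: every original spike seeds a burst. With the default weights the
# burst is most often a doublet (value 2, weight 0.8/1.09), and each of the
# burst's extra spikes is placed at the seed time plus 4 ms plus an
# independent Exponential(5 ms) draw:
#   transform_into_bursts!(spiketimes)
#   transform_into_bursts!(spiketimes, spikes_per_burst_increase=1) # bigger bursts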
function select_frequencies!(spiketimes::Vector{Spiketimes}; frequencies::Union{Vector, UnitRange{Int64}})
old_freqs = 1:length(spiketimes[1])
del_freqs = []
for freq in old_freqs
if !(freq in frequencies)
push!(del_freqs,freq)
end
end
remove_frequencies!(spiketimes, frequencies=del_freqs)
end
function remove_frequencies!(spiketimes::Vector{Spiketimes}; frequencies::Union{Vector, UnitRange{Int64}})
for s in 1:length(spiketimes)
for f in frequencies
spiketimes[s][f] = Vector{Float64}()
end
end
end
function _remove_frequencies!(spiketime::Spiketimes; frequencies::Union{Vector, UnitRange{Int64}})
for f in frequencies
spiketime[f] = Vector{Float64}()
end
end
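# Sketch: keep only the given indices of every Spiketimes vector, emptying all
# the others in place:
#   select_frequencies!(spiketimes, frequencies=1:10)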
#=======================
Raster Plot
=======================#
import Plots: Series, Plot, Subplot
using Plots
using ColorSchemes
struct TracePlot{I,T}
indices::I
plt::Plot{T}
sp::Subplot{T}
series::Vector{Series}
end
function TracePlot(n::Int = 1; maxn::Int = typemax(Int), sp = nothing, kw...)
clist = get(ColorSchemes.colorschemes[:viridis], range(0, 1, length=n))
indices = if n > maxn
shuffle(1:n)[1:maxn]
else
1:n
end
if isnothing(sp)
plt = scatter(length(indices);kw...)
sp = plt[1]
else
plt = scatter!(sp, length(indices); kw...)
end
for n in indices
sp.series_list[n].plotattributes[:markercolor]=clist[n]
sp.series_list[n].plotattributes[:markerstrokecolor]=clist[n]
end
TracePlot(indices, plt, sp, sp.series_list)
end
function Base.push!(tp::TracePlot, x::Number, y::AbstractVector)
push!(tp.series[x], y, x .*ones(length(y)))
end
Base.push!(tp::TracePlot, x::Number, y::Number) = push!(tp, x, [y])
function raster_plot(spikes::Spiketimes; ax=nothing, kwargs...)
s = TracePlot(length(spikes); legend=false, alpha=0.5, kwargs...)
ax = isnothing(ax) ? s : ax
for (n,z) in enumerate(spikes)
if length(z)>0
push!(ax,n,z[:])
end
end
return ax
end
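# Sketch (assumes a Plots backend is active): one dot per spike, neurons on
# the y axis and time on the x axis, one viridis color per neuron:
#   tp = raster_plot(spiketimes[1])
#   display(tp.plt)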
##########################
## Words and Phones structs
##########################
const Interval = Vector{Float64}
const Intervals = Vector{Interval}
struct Phone
ph::String
t0::Float64
t1::Float64
osc::Array{Float64}
db::Matrix{Float64}
function Phone(ph::String, t0::Float64, t1::Float64)
new(ph, t0, t1, zeros(1), zeros(1,1))
end
function Phone(ph::String, t0::Float64, t1::Float64, osc::Vector{Float64}, db::Matrix{Float64})
new(ph, t0, t1, osc, db)
end
end
struct Word
word::String
phones::Vector{Phone}
duration::Float64
t0::Float64
t1::Float64
end
struct Transcription
intervals::Vector{Tuple{Float64,Float64}}
steps::Vector{Tuple{Int,Int}}
sign::Vector{String}
function Transcription()
new([],[],[])
end
end
########################
## Stack spikes together
########################
"""
Extract all the firing times and the corresponding neurons from an array with
all the neurons and their relative firing times. i.e. the inverse_dictionary
"""
function _inverse_dictionary(spikes::Spiketimes)
all_times = Dict()
for n in eachindex(spikes)
if !isempty(spikes[n])
for tt in spikes[n]
# tt = round(Int,t*1000/dt) ## from seconds to timesteps
if haskey(all_times,tt)
push!(all_times[tt], n)
else
push!(all_times, tt=>[n])
end
end
end
end
return all_times
end
"""
From the inverse_dictionary data structure obtain 2 arrays that are faster to access in the simulation loop.
1. First array contains the sorted spike times.
2. The second array contains vectors with each firing neuron
"""
function _sort_spikes(dictionary)
neurons = Array{Vector{Int}}(undef, length(keys(dictionary)))
sorted = sort(collect(keys(dictionary)))
for (n,k) in enumerate(sorted)
neurons[n] = dictionary[k]
end
return sorted, neurons
end
function _sort_spikes(spiketimes::Vector{Float64}, neurons::Vector{Vector{Int64}})
order = sortperm(spiketimes)
return spiketimes[order], neurons[order]
end
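# Worked example with the dictionary method:
#   _sort_spikes(Dict(0.2 => [1], 0.1 => [2, 5]))  # -> ([0.1, 0.2], [[2, 5], [1]])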
"""
Stack together spiketimes sequences:
- spikes is an array with inputs in the form Vectr
- durations is the duration in seconds of each encoding
"""
function stack_spiketimes(spikes::Vector{Spiketimes}, durations::Vector{Float64}, silence_time::Float64)
# for the memory allocation
nr_unique_fts = 0
for spike_times in spikes
nr_unique_fts +=length(_inverse_dictionary(spike_times))
end
all_neurons = Vector{Vector{Int}}(undef, nr_unique_fts)
all_ft = Vector{Float64}(undef, nr_unique_fts)
global_time = 0.
filled_indices = 0
for (spike_times, dd) in zip(spikes, durations)
# get spiketimes
sorted, neurons = _sort_spikes(_inverse_dictionary(spike_times))
#shift time for each neuron:
sorted .+= global_time
## put them together
lower_bound = filled_indices +1
filled_indices = lower_bound + size(sorted,1) -1
all_ft[lower_bound: filled_indices] = sorted
all_neurons[lower_bound:filled_indices] = neurons
global_time += dd
global_time += silence_time
end
@assert(size(all_ft) == size(all_neurons))
return _sort_spikes(all_ft, all_neurons)
end
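# Sketch (assuming `spikes::Vector{Spiketimes}` with matching `durations`, as
# returned by get_word_spikes): concatenate the encodings with 0.15 s of
# silence in between; the n-th encoding is shifted by the summed durations and
# silences before it.
#   all_ft, all_neurons = stack_spiketimes(spikes, durations, 0.15)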
"""
Stack labels sequences:
- labels is an array with Strings
- durations is the duration in seconds of each encoding
"""
function stack_labels(labels::Vector{Word}, durations::Vector{Float64}, silence_time::Float64)
phones_transcripts = Transcription()
words_transcripts = Transcription()
global_time = 0.
for (label, dd) in zip(labels, durations)
@assert(label.duration == dd)
push!(words_transcripts.intervals,(global_time, global_time+dd))
push!(words_transcripts.sign,label.word)
for ph in label.phones
push!(phones_transcripts.intervals, (global_time+ph.t0, global_time+ph.t1))
push!(phones_transcripts.sign,ph.ph)
end
global_time += dd + silence_time
end
return words_transcripts,phones_transcripts
end
"""
Extract the word and the phones contained in the datataset and matching with the target word.
Each row addresses a df entry, each row contains all the phones and word labels, with time intervals.
"""
function get_word_labels(df, word::String)
df_phones = Vector{Vector{Word}}()
for row in eachrow(df)
all_phones = Vector{Word}()
for my_word in eachrow(row.words)
if my_word[1] == word
t0,t1 = my_word[2:3]
word_phones = Vector{Phone}()
for phone in eachrow(row.phones)
if (phone[2] >= t0) && (phone[3]<= t1)
ph = Phone(phone[1], phone[2]-t0,phone[3]-t0)
push!(word_phones, ph)
end
end
push!(all_phones, Word(String(my_word[1]), word_phones,t1-t0,t0,t1))
end
end
push!(df_phones, all_phones)
end
return df_phones
end
function get_word_spikes(df::DataFrame, words::Union{String, Vector{String}})
if isa(words, String)
return get_spikes_in_interval(get_spiketimes(df=df), _get_word_interval(df,words))
else
spikes = Vector{Spiketimes}()
durations = Vector{Float64}()
for word in words
_spikes, _durations = get_spikes_in_interval(get_spiketimes(df=df), _get_word_interval(df,word))
push!(spikes, _spikes...)
push!(durations, _durations...)
end
return spikes, durations
end
end
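# Sketch: pull every realization of one or more words out of `df`, already cut
# to the word boundaries and re-referenced to the word onset:
#   spikes, durations = get_word_spikes(df, ["water", "greasy"])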
"""
For each realization of a word in a dataset entry, extract the interval corresponding to it
Return all the intervals for each dataset entry
"""
function _get_word_interval(df, word::String)
df_intervals = Vector{Intervals}()
for row in eachrow(df)
intervals = Intervals()
for my_word in eachrow(row.words)
if my_word[1] == word
interval = [float(my_word[2]),float(my_word[3])]
push!(intervals, interval)
end
end
push!(df_intervals, intervals)
end
return df_intervals
end
"""
Return the spiketimes subset corresponding to the selected interval, for vectors of Spiketimes
"""
function get_spikes_in_interval(spiketimes::Union{Spiketimes, Array{Spiketimes}},
df_intervals::Union{Intervals, Array{Intervals}})
new_spiketimes = Vector{Spiketimes}()
durations = Vector{Float64}()
if isa(spiketimes,Spiketimes)
for intervals in df_intervals
for interval in intervals
new_spikes, duration = _get_spikes_in_interval(spiketimes, interval)
push!(new_spiketimes, new_spikes)
push!(durations, duration)
end
end
else
@assert(length(spiketimes) == length(df_intervals))
for (spikes, intervals) in zip(spiketimes, df_intervals)
for interval in intervals
new_spikes, duration = _get_spikes_in_interval(spikes, interval)
push!(new_spiketimes, new_spikes)
push!(durations, duration)
end
end
end
return new_spiketimes, durations
end
"""
Return the spiketimes subset corresponding to the selected interval, for one Spiketimes
"""
function _get_spikes_in_interval(spikes, interval)
new_spiketimes=Spiketimes()
for neuron in spikes
neuron_spiketimes = Vector{Float64}()
for st in eachindex(neuron)
if (neuron[st] > interval[1]) && (neuron[st] < interval[end])
push!(neuron_spiketimes, neuron[st]-interval[1])
end
end
push!(new_spiketimes, neuron_spiketimes)
end
return new_spiketimes, interval[end] - interval[1]
end
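# Worked example: a spike at 1.25 s inside the interval [1.2, 1.5] is kept and
# re-referenced to the interval onset, becoming 0.05 s; the returned duration
# is interval[end] - interval[1] = 0.3 s.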
## Select and prepare spike inputs for network simulation
"""
Select inputs in the dataset and pre-process them for the simulation.
Parameters:
==========
df::DataFrame Input dataset
words::Vector{String} Words to be selected
Optional:
--------
samples::Int64 Number of samples of each word to select from the dataset;
if the required samples exceed the available realizations, an error is raised.
"""
function select_inputs(df::DataFrame, words::Vector{String};
samples::Union{Int, Nothing}=nothing,
kwargs...)
all_spikes = Vector{Spiketimes}()
all_durations = Vector{Float64}()
all_labels = Vector{Word}()
for (i, word) in enumerate(words)
df_word = find_word(words=word, df=df)
n_occurences = size(df_word,1)
inds = randperm(n_occurences)
if !isnothing(samples)
if samples <= n_occurences
inds = inds[1:samples]
else
error(string("for word '", word, "': the required samples per word (", samples,
") exceed the number of occurrences (", n_occurences, ")"))
end
end
###Get intervals and phonemes for each dataset entry (some can have more than one!)
spiketimes, durations = get_word_spikes(df_word[inds,:], word)
resample_spikes!(spiketimes; kwargs...)
transform_into_bursts!(spiketimes; kwargs...)
labels = vcat(get_word_labels(df_word[inds,:],word)...)
@assert(length(spiketimes) == length(labels))
push!(all_spikes, spiketimes...)
push!(all_durations, durations...)
push!(all_labels,labels...)
end
return all_durations, all_spikes, all_labels
end
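# Sketch of the full preprocessing pipeline (extra kwargs are forwarded to
# resample_spikes! and transform_into_bursts!):
#   durations, spikes, labels = select_inputs(df, ["water"], samples=10, n_feat=8)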
function transcriptions_to_dt(data::Transcription,dt::Float64)
for n in eachindex(data.intervals)
t0,t1 = data.intervals[n]
push!(data.steps, (round(Int, t0*1000/dt), round(Int, t1*1000/dt))) # seconds -> integer timesteps, with dt in ms
end
return data
end
function all_ft_to_dt(data::Vector{Float64},dt::Float64)
return round.(Int,data ./dt .*1000)
end
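# Worked example: with dt = 0.1 ms, a transcription interval (0.0 s, 0.5 s)
# maps to steps (0, 5000), and all_ft_to_dt([0.5], 0.1) == [5000].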
function mix_inputs(durations::Vector{Float64}, spikes::Vector{Spiketimes}, labels::Vector{Word};
repetitions::Int64, silence_time::Float64, ids=false)
_ids = shuffle(repeat(1:length(durations), repetitions))