Commit c798055c authored by alessio.quaresima
Browse files

Clean up git

parent 48705fa0
......@@ -8,6 +8,9 @@ module SpikeTimit
using Plots
using ColorSchemes
using Random
using StatsBase
using Distributions
## This is the sample rate used in the scripts.
sr = 16000
## This is the rescaling factor for the spike-time discussed in the PDF
......@@ -64,7 +67,6 @@ module SpikeTimit
for line in readlines(file)
if !startswith(line, ";")
word, sounds = split(line,"/")
# @show word, sounds
push!(dict, word=>sounds)
end
end
......@@ -79,11 +81,11 @@ module SpikeTimit
# # , allowcomments=true, commentmark='%')
"""
    get_dialect(root)

Return the TIMIT dialect region of a path as an `Int`.

The path component starting with "dr" (e.g. "dr3") carries the region id in
its last character; it is parsed to an integer so it matches the
`dialect=Int[]` column used by `create_dataset`.
"""
function get_dialect(root)
    # NOTE(review): the diff residue contained an earlier, String-returning
    # version that made this line unreachable; only the Int version is kept.
    return split(root, "/") |> x -> parse(Int, filter(startswith("dr"), x)[1][end])
end
function create_dataset(;dir)
df = DataFrame(speaker = String[] , senID = String[], dialect=String[], gender=Char[], path=String[], words=Array{Union{String, Float64},2}[], phones=Array{Union{String, Float64},2}[])
df = DataFrame(speaker = String[] , senID = String[], dialect=Int[], gender=Char[], path=String[], words=Array{Union{String, Float64},2}[], phones=Array{Union{String, Float64},2}[])
for (root, dirs, files) in walkdir(dir)
for file in files
if endswith(file,"wav")
......@@ -91,8 +93,8 @@ module SpikeTimit
senID = split(file,".")[1]
words = get_words(root, senID)
phones = get_phones(root, senID)
dr = get_dialect(root)
gender = speaker[1]
dr = get_dialect(root)
gender = speaker[1]
push!(df,(speaker,senID,dr,gender,joinpath(root,senID),words,phones))
end
end
......@@ -100,16 +102,16 @@ module SpikeTimit
return df
end
"""
    find_word(; df::DataFrame, word::String)

Return the unique rows of `df` whose `words` annotation contains `word`.
"""
function find_word(;df::DataFrame, word::String)
    # `word .∈ :words` broadcasts membership over the per-row word arrays.
    # (The "∈" operator was lost in the scraped diff and is restored here.)
    return @linq df |> where(word .∈ :words)
end
########################
## Extract .mat files ##
########################
"""
    get_matrix(; df::DataFrame)

Return, for every dataset entry, the path of its binary-matrix .mat file,
stored next to the recording referenced by `df.path` as
"<path>_binary_matrix.mat".
"""
function get_matrix(;df::DataFrame)
    # diff residue left two signature lines (untyped + typed); the typed one
    # from the newer revision is kept.
    get_matrix_file(file) = file*"_binary_matrix.mat"
    return get_matrix_file.(df.path)
end
......@@ -120,7 +122,8 @@ module SpikeTimit
Spiketimes = Vector{Vector{Float64}}
function get_spiketimes(;df)
function get_spiketimes(;df::DataFrame)
get_spiketimes_file(file)= file*"_spike_timing.mat"
get_array(value::Float64) = begin x=zeros(Float64, 1); x[1] = value; x end
if length(size(df)) == 1 # only 1 row
......@@ -135,59 +138,20 @@ module SpikeTimit
end
## make fancy raste plot
# Container bundling a Plots.jl scatter plot with the per-neuron series,
# used for incremental raster plotting.
struct TracePlot{I,T}
    indices::I              # neuron indices shown (possibly a random subsample)
    plt::Plot{T}            # the Plots.Plot object
    sp::Subplot{T}          # subplot holding the scatter series
    series::Vector{Series}  # one series per plotted neuron
end
# Build a TracePlot with `n` scatter series (one per neuron), colored along
# the viridis colormap. At most `maxn` randomly chosen neurons are kept;
# pass an existing subplot `sp` to draw into it instead of a new figure.
function TracePlot(n::Int = 1; maxn::Int = typemax(Int), sp = nothing, kw...)
    # one viridis color per neuron
    clist= get(ColorSchemes.colorschemes[:viridis],range(0,1,length=n))
    # subsample the neurons when there are more than `maxn`
    indices = if n > maxn
        shuffle(1:n)[1:maxn]
    else
        1:n
    end
    # NOTE(review): `sp == nothing` would be more idiomatic as `isnothing(sp)`
    if sp == nothing
        plt = scatter(length(indices);kw...)
        sp = plt[1]
    else
        plt = scatter!(sp, length(indices); kw...)
    end
    # apply the per-neuron colors (this loop's `n` shadows the argument `n`)
    for n in indices
        sp.series_list[n].plotattributes[:markercolor]=clist[n]
        sp.series_list[n].plotattributes[:markerstrokecolor]=clist[n]
    end
    TracePlot(indices, plt, sp, sp.series_list)
end
# Append spikes of neuron `x`: plot the points (t, x) for every time t in `y`
# on the series belonging to neuron `x`.
function Base.push!(tp::TracePlot, x::Number, y::AbstractVector)
    push!(tp.series[x], y, x .*ones(length(y)))
end
# Scalar convenience overload: wrap the single spike time in a vector and
# forward to the vector method. Bug fix: the arguments were previously passed
# as `push!(tp, [y], x)`, which matches no method (the vector method takes
# `(tp, x::Number, y::AbstractVector)`) and raised a MethodError.
Base.push!(tp::TracePlot, x::Number, y::Number) = push!(tp, x, [y])
# Scatter-plot a raster of `spikes` (a vector of per-neuron spike-time
# vectors). When `ax` is nothing a fresh TracePlot with one series per neuron
# is created; otherwise the points are pushed into the provided plot.
function raster_plot(spikes; ax=nothing, alpha=0.5)
    # NOTE(review): `alpha` here is a second positional argument; it looks like
    # it was intended as the keyword `alpha=alpha` — confirm against callers.
    s = TracePlot(length(spikes), legend=false, alpha)
    ax = isnothing(ax) ? s : ax
    for (n,z) in enumerate(spikes)
        if length(z)>0
            push!(ax,n,z[:])
        end
    end
    return ax
end
########################
## Stack spikes together
########################
"""
Extract all the firing times and the corresponding neurons from an array with all the neurons and their relative firing times. i.e. the reverse_dictionary
Extract all the firing times and the corresponding neurons from an array with
all the neurons and their relative firing times. i.e. the inverse_dictionary
"""
function reverse_dictionary(spikes,dt::Float64)
function inverse_dictionary(spikes::Spiketimes)
all_times = Dict()
for n in eachindex(spikes)
if !isempty(spikes[n])
for t in spikes[n]
tt = round(Int,t*1000/dt) ## from seconds to timesteps
for tt in spikes[n]
# tt = round(Int,t*1000/dt) ## from seconds to timesteps
if haskey(all_times,tt)
all_times[tt] = [all_times[tt]..., n]
else
......@@ -201,7 +165,7 @@ module SpikeTimit
"""
From the reverse_dictionary data structure obtain 2 arrays that are faster to access in the simulation loop.
From the inverse_dictionary data structure obtain 2 arrays that are faster to access in the simulation loop.
1. First array contains the sorted spike times.
2. The second array contains vectors with each firing neuron
"""
......@@ -216,124 +180,186 @@ module SpikeTimit
"""
Stack together different input maps: (now using memory allocation)
- spikes is an array with inputs
Stack together spiketimes sequences:
- spikes is an array with inputs in the form Vectr
- durations is the duration in seconds of each encoding
"""
function stack_spiketimes(spikes, durations, silence_time; dt=0.1)
function stack_spiketimes(spikes::Vector{Spiketimes}, durations::Vector{Float64}, silence_time::Float64)
# for the memory allocation
nr_unique_fts = 0
for spike_times in spikes
dict = reverse_dictionary(spike_times, dt)
nr_unique_fts += length(collect(Set(dict)))
nr_unique_fts +=length(inverse_dictionary(spike_times))
end
all_neurons = Vector{Vector{Int}}(undef, nr_unique_fts)
all_ft = Vector{Float64}(undef, nr_unique_fts)
global_time_step = 0
# memory allocation!! Otherwise stack overflow.
all_neurons = Vector{Any}(undef, nr_unique_fts)
all_ft = Vector{Any}(undef, nr_unique_fts)
global_time = 0
filled_indices = 0
for (spike_times,dd) in zip(spikes, durations)
dictionary = reverse_dictionary(spike_times, dt)
sorted, neurons = sort_spikes(dictionary)
#shift time:
sorted .+= global_time_step + silence_time
for (spike_times, dd) in zip(spikes, durations)
# get spiketimes
sorted, neurons = sort_spikes(inverse_dictionary(spike_times))
#shift time for each neuron:
sorted .+= global_time
## put them together
lower_bound = filled_indices + 1
lower_bound = filled_indices +1
filled_indices += size(sorted,1)
all_ft[lower_bound:filled_indices] = sorted
all_neurons[lower_bound:filled_indices] = neurons
global_time_step += round(Int, dd*1000/dt)
global_time += dd
global_time += silence_time
end
@assert(size(all_ft) == size(all_neurons))
return all_ft, all_neurons
end
"""
# Concatenate per-sample word/phone labels onto one global time axis.
# - `labels`: one Word per sample (with nested Phone annotations)
# - `durations`: duration in seconds of each sample (must equal label.duration)
# - `silence_time`: silent gap (seconds) inserted between consecutive samples
# Returns `(words_transcripts, phones_transcripts)`, two Transcription objects.
function stack_labels(labels, durations, silence_time)
    phones_transcripts = Transcription()
    words_transcripts = Transcription()
    global_time = 0.
    for (label, dd) in zip(labels, durations)
        # internal invariant: the duration list must match the labels
        @assert(label.duration == dd)
        push!(words_transcripts.intervals,(global_time, global_time+dd))
        push!(words_transcripts.sign,label.word)
        # phone times are stored relative to the word onset, so shift them
        # by the current global offset
        for ph in label.phones
            push!(phones_transcripts.intervals, (global_time+ph.t0, global_time+ph.t1))
            push!(phones_transcripts.sign,ph.ph)
        end
        global_time += dd + silence_time
    end
    return words_transcripts,phones_transcripts
end
#="""
Return the input of the spike trains corrsponding to the rows, in the (spike-time -> neuron ) format
"""
function inputs_from_df(df, rows; dt=0.1)
function inputs_from_df(df, rows)
_df = df[rows,:]
## Get the duration of each frame: it corresponds to the (last) symbol h# in the phones array
durations = _df.phones |> phones->map(phone->phone[end], phones)
spikes = SpikeTimit.get_spiketimes(df=_df)
all_ft, all_n = stack_spiketimes(spikes, durations, dt)
all_ft, all_n = stack_spiketimes(spikes, durations)
@assert(size(all_ft) == size(all_n))
return all_ft, all_n, durations
end
=#
"""
Get spiketimes corresponding to the word
"""
function get_word_spiketimes(;df, word::String)
df = find_word(;df=df,word=word)
intervals = get_interval_word(; df=df, word=word)
phones = get_phones_in_word(;df=df, word=word)
##########################
## Get words and phonemes
##########################
# One phoneme annotation: symbol plus its time interval in seconds
# (relative to the onset of the containing word — see get_word_labels,
# which subtracts the word's t0).
struct Phone
    ph::String   # phoneme symbol
    t0::Float64  # onset time (s)
    t1::Float64  # offset time (s)
end
# One word annotation together with the phonemes it contains.
struct Word
    word::String          # the word label
    phones::Vector{Phone} # phones inside the word, times relative to `t0`
    duration::Float64     # word duration in seconds (t1 - t0)
    t0::Float64           # absolute onset time within the sentence (s)
end
spikes = get_spiketimes(;df)
spiketimes, duration = get_spikes_in_interval(; spiketimes=spikes,intervals=intervals)
# A labelled timeline: parallel vectors of time intervals (seconds), their
# discretized counterparts in simulation steps (filled by convert_to_dt),
# and the corresponding label strings.
struct Transcription
    intervals::Vector{Tuple{Float64,Float64}} # (t0, t1) in seconds
    steps::Vector{Tuple{Int,Int}}             # (t0, t1) in integer steps
    sign::Vector{String}                      # label of each interval
    function Transcription()
        new([],[],[])
    end
end
return spiketimes, duration, phones
"""
    convert_to_dt(data::Transcription, dt::Float64)

Convert every `(t0, t1)` interval of `data` from seconds into integer
simulation steps of size `dt` (in ms) and append them to `data.steps`.
Returns the same (mutated) Transcription.
"""
function convert_to_dt(data::Transcription, dt::Float64)
    # seconds → ms, then divide by the step size and round to the nearest step
    to_step(t) = round(Int, t * 1000 / dt)
    for (t_start, t_stop) in data.intervals
        push!(data.steps, (to_step(t_start), to_step(t_stop)))
    end
    return data
end
"""
    convert_to_dt(data::Vector{Float64}, dt::Float64)

Convert a vector of times in seconds into integer simulation steps of size
`dt` (in ms).
"""
function convert_to_dt(data::Vector{Float64}, dt::Float64)
    # same arithmetic order as the broadcast form `data ./ dt .* 1000`
    return [round(Int, t / dt * 1000) for t in data]
end
"""
For each word extract the interval corresponding to it
Extract the word and the phones contained in the datataset and matching with the target word.
Each row addresses a df entry, each row contains all the phones and word labels, with time intervals.
"""
function get_interval_word(;df, word::String)
intervals = []
function get_word_labels(;df, word::String)
df_phones = Vector{Vector{Word}}()
for row in eachrow(df)
interval = [0.,0.]
all_phones = Vector{Word}()
for my_word in eachrow(row.words)
if my_word[1] == word
interval = my_word[2:3]
t0,t1 = my_word[2:3]
word_phones = Vector{Phone}()
for phone in eachrow(row.phones)
if (phone[2] >= t0) && (phone[3]<= t1)
ph = Phone(phone[1], phone[2]-t0,phone[3]-t0)
push!(word_phones, ph)
end
end
push!(all_phones, Word(String(my_word[1]), word_phones,t1-t0,t0))
end
end
push!(intervals, interval)
push!(df_phones, all_phones)
end
isempty(intervals)
return intervals
return df_phones
end
"""
    get_spikes_in_word(; df, word::String)

Return the spike trains and their durations restricted to every occurrence
of `word` in the dataset `df`.
"""
function get_spikes_in_word(; df, word::String)
    word_intervals = get_interval_word(df=df, word=word)
    all_spiketimes = get_spiketimes(df=df)
    return get_spikes_in_interval(spiketimes=all_spiketimes, df_intervals=word_intervals)
end
"""
For each word extract the interval corresponding to it
For each realization of a word in a dataset entry, extract the interval corresponding to it
Return all the intervals for each dataset entry
"""
function get_phones_in_word(;df, word::String)
intervals = []
function get_interval_word(;df, word::String)
df_intervals = []
for row in eachrow(df)
intervals = []
interval = [0.,0.]
n = 0
for my_word in eachrow(row.words)
if my_word[1] == word
interval = my_word[2:3]
push!(intervals, interval)
end
end
push!(intervals, interval)
push!(df_intervals, intervals)
end
isempty(intervals)
return intervals
return df_intervals
end
"""
Return the spiketimes subset corresponding to the selected interval
Return the spiketimes subset corresponding to the selected interval, for vectors of Spiketimes
"""
function get_spikes_in_interval(; spiketimes::Union{Spiketimes, Array{Spiketimes}}, intervals)
function get_spikes_in_interval(; spiketimes::Union{Spiketimes, Array{Spiketimes}}, df_intervals)
new_spiketimes = Vector{Spiketimes}()
durations = Vector()
if isa(spiketimes,Spiketimes)
for interval in intervals
new_spikes, duration = _get_spikes_in_interval(spiketimes, interval)
push!(new_spiketimes, new_spikes)
push!(durations, duration)
for intervals in df_intervals
for interval in intervals
new_spikes, duration = _get_spikes_in_interval(spiketimes, interval)
push!(new_spiketimes, new_spikes)
push!(durations, duration)
end
end
else
@assert(length(spiketimes) == length(intervals))
for (spikes, interval) in zip(spiketimes, intervals)
new_spikes, duration = _get_spikes_in_interval(spikes, interval)
push!(new_spiketimes, new_spikes)
push!(durations, duration)
@assert(length(spiketimes) == length(df_intervals))
for (spikes, intervals) in zip(spiketimes, df_intervals)
for interval in intervals
new_spikes, duration = _get_spikes_in_interval(spikes, interval)
push!(new_spiketimes, new_spikes)
push!(durations, duration)
end
end
end
return new_spiketimes, durations
......@@ -341,7 +367,7 @@ module SpikeTimit
"""
Return the spiketimes subset corresponding to the selected interval, for one sentence
Return the spiketimes subset corresponding to the selected interval, for one Spiketimes
"""
function _get_spikes_in_interval(spikes, interval)
new_spiketimes=Spiketimes()
......@@ -358,23 +384,104 @@ module SpikeTimit
end
"""
    select_inputs(; df, words, samples=10, n_feat=7)

Collect `samples` spike-train realizations for each word in `words`.

For every word, `samples` occurrences are drawn at random from `df`; their
spiketimes are extracted, resampled down to `n_feat` frequency features, and
collected together with their durations and Word labels.

Returns `(all_durations, all_spikes, all_labels)`.

Aborts (assertion failure) when a word has fewer than `samples` occurrences.
"""
function select_inputs(; df, words, samples=10, n_feat = 7)
    # NOTE(review): this span also contained an orphaned fragment of the old
    # `resample_spikes` and a dead `return spiketimes` line — diff residue,
    # removed here. The unused `enumerate` index is dropped as well.
    all_spikes = Vector{Spiketimes}()
    all_durations = Vector{Float64}()
    all_labels = []
    for word in words
        df_word = find_word(word=word, df=df)
        n_occurences = size(df_word,1)
        @show n_occurences
        # randomly choose which occurrences to sample
        if samples <= n_occurences
            inds = randperm(n_occurences)[1:samples]
        else
            message = string("WARNING: for word: '", word, "', samples per word (", samples, ") exceeds the number of occurences (", n_occurences, ")")
            @assert false message
        end
        ### Get intervals and phonemes for each dataset entry (some can have more than one!)
        spiketimes, durations = get_spikes_in_word(; df=df_word[inds,:], word)
        spiketimes = resample_spikes(spiketimes=spiketimes, n_feat=n_feat)
        labels = vcat(get_word_labels(;df=df_word[inds,:], word=word)...)
        @show length(labels), length(spiketimes)
        @assert(length(spiketimes) == length(labels))
        push!(all_spikes, spiketimes...)
        push!(all_durations, durations...)
        push!(all_labels,labels...)
    end
    return all_durations, all_spikes, all_labels
end
"""
    mix_inputs(; durations, spikes, labels, repetitions, silence_time)

Build one long input sequence: every sample is repeated `repetitions` times,
the order is shuffled, and the spiketimes and labels are stacked onto a
single global time axis with `silence_time` between samples.

Returns `(all_ft, all_n, words_transcription, phones_transcription)`.
"""
function mix_inputs(;durations, spikes, labels, repetitions, silence_time)
    order = shuffle(repeat(1:length(durations), repetitions))
    all_ft, all_n = stack_spiketimes(spikes[order], durations[order], silence_time)
    words_t, phones_t = SpikeTimit.stack_labels(labels[order], durations[order], silence_time)
    return all_ft, all_n, words_t, phones_t
end
"""
    get_savepoints(; trans::Transcription, n_measure::Int)

For each labelled step-interval in `trans`, compute `n_measure` equally
spaced measurement steps inside the interval and replicate its label.

Returns `(measures, labels)`: `measures[k]` is a vector of `n_measure` step
indices, `labels[k]` the corresponding label repeated `n_measure` times.
"""
function get_savepoints(;trans::Transcription, n_measure::Int)
    measures = []
    labels = []
    for (step, sign) in zip(trans.steps, trans.sign)
        # integer spacing between consecutive measurement points
        spacing = floor(Int, (step[2] - step[1]) / n_measure)
        push!(measures, [step[1] + k * spacing for k in 1:n_measure])
        push!(labels, fill(sign, n_measure))
    end
    return measures, labels
end
"""
Unify frequencies bin
Input: spiketimes array with elements of size 620 elements
Return array with sorted spikes in less classes
nfeat new_bins Rounding last_bin
2 15,5 <-- 2*15=30 <DO> add last to last bin
3 10,333333 <-- 3*10=30 <DO> add last to last bin
4 7,75 --> 4*8=32 <DO>
5 6,2 <-- 5*6=30 <DO> add last to last bin
6 5,16666 <-- 6*5=30 <DO> add last to last bin
7 4,429 <-- 7*4=28 <DO> add last 3 to last bin
8 3,875 --> 8*4=32 <DO>
9 3,44444 <-- 9*3=27 <DO> add last 4 to last bin
10 3,1 <-- 10*3=30 <DO> add last to last bin
11 2,818181 --> 11*3=33 <DO>
"""
"""
    resample_spikes(; spiketimes::Vector{Spiketimes}, n_feat)

Resample every Spiketimes entry in place, reducing its frequency channels to
`n_feat` features per frequency band (see `_resample_spikes`). Returns the
mutated vector.
"""
function resample_spikes(;spiketimes::Vector{Spiketimes},n_feat)
    # diff residue left a stray old `_resample_spikes` signature line above
    # this function; removed. `eachindex` replaces `1:length(...)`.
    for s in eachindex(spiketimes)
        spiketimes[s] = _resample_spikes(spiketimes=spiketimes[s], n_feat=n_feat)
    end
    return spiketimes
end
function _resample_spikes(;spiketimes::Spiketimes, n_feat)
# If we don't reduce the bins
if n_feat == 1
return spike_times
elseif n_feat > 11 || n_feat < 1
prinln("WARNING; you are crazy, returning original spike_times")
return spike_times
end
FREQUENCIES = 20
old_bins = length(spiketimes)÷FREQUENCIES
new_bins = old_bins÷n_feat+1
old_bins = convert(Int64, length(spiketimes)/FREQUENCIES)
@assert (old_bins==31) "WARNING: old_bins != 31, this function is probably broken for other values than 31 (super hardcoded)"
new_bins = round(Int, old_bins/n_feat - 0.1)
add_last = 0
if n_feat*new_bins < 31
add_last = 31-n_feat*new_bins
end
new_spikes = map(x->Vector{Float64}(),1:new_bins*FREQUENCIES)
for freq in 1:FREQUENCIES
......@@ -382,6 +489,9 @@ module SpikeTimit
new_freq = (freq-1)*new_bins
for new_bin in 1:new_bins
last_bin = new_bin*n_feat <32 ? new_bin*n_feat : 31
if new_bin == new_bins
last_bin += add_last
end
bins = 1+(new_bin-1)*n_feat : last_bin
for old_bin in bins
push!(new_spikes[new_bin+new_freq], spiketimes[old_bin+old_freq]...)
......@@ -392,31 +502,6 @@ module SpikeTimit
end
"""
Extract all the phonemes contained in the word and their time intervals.
Run over all the sentences in the dataset and return a list of relevant phonemes,
each row is the corresponding sentence (df entry), each row contains all the phones.
"""
function get_phones_in_word(;df, word)
all_phones = []
for row in eachrow(df)
for my_word in eachrow(row.words)
word_phones = []
if my_word[1] == word
t0,t1 = my_word[2:3]
phones = []
for phone in eachrow(row.phones)
if (phone[2] >= t0) && (phone[3]<= t1)
push!(phones, collect(phone))
end
end
push!(word_phones, phones)
end
push!(all_phones, word_phones...)
end
end
return all_phones
end
"""
Get firing times for each phones list.
......@@ -438,180 +523,94 @@ module SpikeTimit
end
end
"""
    transform_into_bursts(all_ft, all_neurons)

Replace every spike with a short burst of spikes.

For each firing time, the burst size is drawn from an empirical distribution
(plot 1B, 0.7 nA, Oswald, Doiron & Maler (2007)); each extra spike follows
the data spike after a 4 ms offset plus an Exponential(5) interval. Firing
times are then rounded to whole time steps and spikes landing on the same
step are merged into one event.

Returns `(ft, neurons)` sorted by firing time.
"""
function transform_into_bursts(all_ft, all_neurons)
    new_all_ft = []
    new_all_neurons = []
    expdist = Exponential(5)
    # Burst-size distribution — hoisted out of the loop (loop-invariant).
    # `sample` accepts unnormalized weights, so the manual normalization the
    # old code did each iteration is unnecessary.
    values = [2,3,4,5,6]
    weights = Weights([0.8, 0.15, 0.075, 0.035, 0.03]) # plot 1B 0.7 nA (Oswald, Doiron & Maler (2007))
    for (i, time) in enumerate(all_ft)
        # -1 because the first spike of the burst is the data spike itself
        number_of_spikes = sample(values, weights) - 1
        push!(new_all_ft, time)
        push!(new_all_neurons, all_neurons[i])
        for _ in 1:number_of_spikes
            interval = rand(expdist)
            push!(new_all_ft, time + 4 + interval)
            push!(new_all_neurons, all_neurons[i])
        end
    end
    # round every firing time to an integer step, stored back as Float64
    for i in eachindex(new_all_ft)
        new_all_ft[i] = convert(Float64, round(Int, new_all_ft[i]))
    end
    # sort all events by firing time
    zipped = DataFrame(ft = new_all_ft, neurons = new_all_neurons)
    zipped = sort!(zipped, [:ft])
    # merge consecutive rows with identical times: push this row's neurons
    # into the next row and mark this row with ft = -1 for removal below
    for (i, row) in enumerate(eachrow(zipped))
        if i != size(zipped,1) && row.ft == zipped.ft[i+1]
            zipped.neurons[i+1] = vcat(copy(row.neurons), copy(zipped.neurons[i+1]))
            zipped.ft[i] = -1.0
        end
    end
    zipped = filter(row -> row[:ft] != -1.0, zipped)
    return zipped.ft, zipped.neurons
end
########################
## Raster Plot ####
########################
# Container bundling a Plots.jl scatter plot with the per-neuron series,
# used for incremental raster plotting.
struct TracePlot{I,T}
    indices::I              # neuron indices shown (possibly a random subsample)
    plt::Plot{T}            # the Plots.Plot object
    sp::Subplot{T}          # subplot holding the scatter series
    series::Vector{Series}  # one series per plotted neuron
end
function TracePlot(n::Int = 1; maxn::Int = typemax(Int), sp = nothing, kw...)
clist= get(ColorSchemes.colorschemes[:viridis],range(0,1,length=n))