Commit b6fc835a authored by alessio.quaresima

first commit

module SpikeTimit
using DataFrames
using DataFramesMeta
using OrderedCollections
using MAT
using Plots
using ColorSchemes
using Random ## needed for `shuffle` in TracePlot
## This is the sample rate used in the scripts.
sr = 16000
## This is the rescaling factor for the spike times discussed in the PDF.
correction = 5.0
import Plots: Series, Plot, Subplot
## Create a dictionary with all words in the dataset
function create_dictionary(;file)
    dict = OrderedDict()
    for line in readlines(file)
        if !startswith(line, ";")
            word, sounds = split(line, "/")
            push!(dict, word => sounds)
        end
    end
    return dict
end
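
## Hedged usage sketch, assuming TIMITDIC.TXT lines of the form
## `word /phone transcription/` with ";" marking comment lines:
# dict = create_dictionary(file=joinpath("DOC", "TIMITDIC.TXT"))
# collect(keys(dict))[1:5]  # first few words of the lexicon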
function get_words(root, senID)
    path = joinpath(root, senID * ".wrd")
    times0 = []
    times1 = []
    words = []
    for line in readlines(path)
        t0, tf, w = split(line)
        push!(times0, parse(Int, t0))
        push!(times1, parse(Int, tf))
        push!(words, String(w))
    end
    ## one row per word: [word, onset (s), offset (s)]
    _data = Array{Union{String,Float64},2}(undef, length(times0), 3)
    _data[:, 1] = words
    _data[:, 2] = times0 ./ sr
    _data[:, 3] = times1 ./ sr
    return _data
end
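
## Hedged sketch of the `.wrd` annotation format this parser assumes
## (one word per line, sample offsets at sr = 16 kHz; same layout for `.phn`):
#   2260 4070 she
#   4070 6640 had
## get_words then converts the sample counts to seconds, e.g.
# get_words(root, senID)  # => ["she" 0.141 0.254; "had" 0.254 0.415; ...]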
function get_phones(root, senID)
    path = joinpath(root, senID * ".phn")
    times0 = []
    times1 = []
    phones = []
    for line in readlines(path)
        t0, tf, p = split(line)
        push!(times0, parse(Int, t0))
        push!(times1, parse(Int, tf))
        push!(phones, String(p))
    end
    _data = Array{Union{String,Float64},2}(undef, length(times0), 3)
    _data[:, 1] = phones
    _data[:, 2] = times0 ./ sr
    _data[:, 3] = times1 ./ sr
    return _data
end
function create_dataset(;dir)
    df = DataFrame(speaker=String[], senID=String[], path=String[],
                   words=Array{Union{String,Float64},2}[],
                   phones=Array{Union{String,Float64},2}[])
    for (root, dirs, files) in walkdir(dir)
        for file in files
            if endswith(file, "wav")
                speaker = split(root, "/")[end]
                senID = split(file, ".")[1]
                words = get_words(root, senID)
                phones = get_phones(root, senID)
                push!(df, (speaker, senID, joinpath(root, senID), words, phones))
            end
        end
    end
    return df
end
"""Search for a word in the dataset and return the matching rows"""
function find_word(;word, df)
    ## a row matches if `word` appears in its words annotation matrix
    return @linq transform(filter(:words => x -> word ∈ x, df), :senID) |> unique
end
function get_spiketimes(;df)
    get_spiketimes_file(file) = file * "_spike_timing.mat"
    if length(size(df)) == 1 # only 1 row (a DataFrameRow)
        spikes = [read(matopen(get_spiketimes_file(df.path)))["spike_output"][1, :]]
    else
        spikes = map(x -> x[1, :], get_spiketimes_file.(df.path) |> x -> matopen.(x) |> x -> read.(x) |> x -> get.(x, "spike_output", nothing))
    end
    ## wrap isolated scalar spike times (neurons with a single spike) into 1-element arrays
    map(spike -> map(row -> spike[row] = [spike[row]], findall(typeof.(spike) .== Float64)), spikes)
    ## rescale the spike times by the correction factor discussed in the PDF
    map(spike -> findall(isa.(spike, Matrix)) |> x -> spike[x] = spike[x] * correction, spikes)
    return spikes
end
function get_matrix(;df)
    get_matrix_file(file) = file * "_binary_matrix.mat"
    return get_matrix_file.(df.path)
end

function get_file(file, ext)
    return file * "." * ext
end
## Make a fancy raster plot
struct TracePlot{I,T}
    indices::I
    plt::Plot{T}
    sp::Subplot{T}
    series::Vector{Series}
end
function TracePlot(n::Int = 1; maxn::Int = typemax(Int), sp = nothing, kw...)
    clist = get(ColorSchemes.colorschemes[:viridis], range(0, 1, length=n))
    indices = if n > maxn
        shuffle(1:n)[1:maxn]
    else
        1:n
    end
    if isnothing(sp)
        plt = scatter(length(indices); kw...)
        sp = plt[1]
    else
        plt = scatter!(sp, length(indices); kw...)
    end
    for i in indices
        sp.series_list[i].plotattributes[:markercolor] = clist[i]
        sp.series_list[i].plotattributes[:markerstrokecolor] = clist[i]
    end
    TracePlot(indices, plt, sp, sp.series_list)
end
function Base.push!(tp::TracePlot, x::Number, y::AbstractVector)
    ## spike times on the x-axis, the series index (neuron) on the y-axis
    push!(tp.series[x], y, x .* ones(length(y)))
end
Base.push!(tp::TracePlot, x::Number, y::Number) = push!(tp, x, [y])
function raster_plot(spikes; alpha=0.5)
    s = TracePlot(length(spikes), legend=false)
    for (n, z) in enumerate(spikes)
        if length(z) > 0
            push!(s, n, z[1, :])
        end
    end
    return s
end
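
## Hedged usage sketch: get_spiketimes returns one spike-time array per
## sentence, so a raster of the first sentence of a (hypothetical) `dataset`
## might look like this.
# spikes = get_spiketimes(df=dataset[1, :])
# tp = raster_plot(spikes[1])
# display(tp.plt)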
###########################################################
"""
Extract all the firing times and the corresponding neurons from an array with all the neurons and their relative firing times, i.e. build the reverse dictionary (time step => firing neurons).
"""
function reverse_dictionary(spikes, dt::Float64)
    all_times = Dict()
    for n in eachindex(spikes)
        if !isempty(spikes[n])
            for t in spikes[n]
                tt = round(Int, t * 1000 / dt) ## from seconds to time steps
                if haskey(all_times, tt)
                    all_times[tt] = [all_times[tt]..., n]
                else
                    push!(all_times, tt => [n])
                end
            end
        end
    end
    return all_times
end
"""
From the reverse_dictionary data structure obtain 2 arrays that are faster to access in the simulation loop.
1. First array contains the sorted spike times.
2. The second array contains vectors with each firing neuron
"""
function sort_spikes(dictionary)
neurons = Array{Vector{Int}}(undef, length(keys(dictionary)))
sorted = sort(collect(keys(dictionary)))
for (n,k) in enumerate(sorted)
neurons[n] = dictionary[k]
end
return sorted, neurons
end
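
## Hedged toy example for the two functions above (three neurons, dt = 0.1 ms):
# toy = [[0.001], [0.001, 0.002], Float64[]]  # spike times in seconds
# d = reverse_dictionary(toy, 0.1)            # => Dict(10 => [1, 2], 20 => [2])
# sort_spikes(d)                              # => ([10, 20], [[1, 2], [2]])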
"""
Stack together different input maps:
- spikes is an array with inputs
- duration is the duration in seconds of each encoding
"""
function stack_spiketimes(spikes, duration, dt)
global_time_step = 0
all_neurons = []
all_ft = []
for (spike_times,dd) in zip(spikes, duration)
dictionary = reverse_dictionary(spikes[1], dt)
sorted, neurons = sort_spikes(dictionary)
#shift time:
sorted .+= global_time_step
## put them together
all_neurons = vcat(all_neurons, neurons)
all_ft = vcat(all_ft..., sorted)
global_time_step += round(Int, dd*1000/dt)
end
return all_ft, all_neurons
end
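
## Hedged sketch of stacking two half-second inputs with dt = 0.1 ms:
## the second input is shifted by round(Int, 0.5 * 1000 / 0.1) = 5000 steps.
# s1 = [[0.001], [0.002]]; s2 = [[0.001], Float64[]]
# ft, nn = stack_spiketimes([s1, s2], [0.5, 0.5], 0.1)
# ft  # => [10, 20, 5010]
# nn  # => [[1], [2], [1]]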
"""
Return the input of the spike trains corrsponding to the rows, in the (spike-time -> neuron ) format
"""
function inputs_from_df(df, rows; dt=0.1)
_df = df[rows,:]
## Get the duration of each frame: it corresponds to the (last) symbol h# in the phones array
duration = df.phones |> phones->map(phone->phone[end], phones)
spikes= SpikeTimit.get_spiketimes(df=_df)
all_ft, all_n = stack_spiketimes(spikes, duration, dt)
@assert(size(all_ft) == size(all_n))
return all_ft, all_n
end
"""
For each word extract the interval corresponding to it
"""
function get_interval_word(;df, word::String)
intervals = []
for row in eachrow(df)
interval = [0.,0.]
for my_word in eachrow(row.words)
if my_word[1] == word
interval = my_word[2:3]
end
end
push!(intervals, interval)
end
isempty(intervals)
return intervals
end
"""
Return the spiketimes subset corresponding to the selected interval, for each sentence
"""
function get_spikes_in_interval(; spiketimes, intervals)
new_spiketimes = Vector()
durations = Vector()
for (spikes, interval) in zip(spiketimes, intervals)
new_spikes, duration = _get_spikes_in_interval(spikes, interval)
push!(new_spiketimes, new_spikes)
push!(durations, duration)
end
return new_spiketimes, durations
end
"""
Return the spiketimes subset corresponding to the selected interval, for one sentence
"""
function _get_spikes_in_interval(spikes, interval)
for n in eachindex(spikes)
in_interval(x) = (x > interval[1])*(x < interval[end])
spikes[n] = filter(in_interval,spikes[n]) .- interval[1]
end
return spikes, interval[end] - interval[1]
end
end # module SpikeTimit
## Import the dataset. Note: the spike times themselves are not imported here; only the file paths are stored, ready to be read.
include("SpikeTimit.jl")
using .SpikeTimit

test_path = joinpath(@__DIR__, "Spike TIMIT", "test")
train_path = joinpath(@__DIR__, "Spike TIMIT", "train")
dict_path = joinpath(@__DIR__, "DOC", "TIMITDIC.TXT")

dict = SpikeTimit.create_dictionary(file=dict_path)
train = SpikeTimit.create_dataset(;dir=train_path)
test = SpikeTimit.create_dataset(;dir=test_path)
## Inspect a single sentence and its spike times
my_sent = train[1, :]
my_sent.words
# e.g. we want the word "dark"
spikes = SpikeTimit.get_spiketimes(;df=my_sent)
my_sent.words[:, 3]

## count the total number of spikes in the sample
counter = 0
for x in spikes[1]
    global counter += length(x)
end
counter
## Extract the spikes of a single word ("project") and stack them
using .SpikeTimit
project = SpikeTimit.find_word(;word="project", df=train)
intervals = SpikeTimit.get_interval_word(;df=project, word="project")
spiketimes = SpikeTimit.get_spiketimes(;df=project)
sts, durations = SpikeTimit.get_spikes_in_interval(;spiketimes=spiketimes, intervals=intervals)
SpikeTimit.stack_spiketimes(sts, durations, 0.1)
intervals[1]
##
prompts = joinpath(@__DIR__,"DOC","PROMPTS.TXT")
## Count the occurrences of each word in the prompt sentences
function get_sentences(;prompts)
    dictionary = Dict()
    for line in readlines(prompts)
        if !startswith(line, ";")
            words = split(line, " ")
            for word in words
                a = lowercase(replace(word, [',', '.', ';'] => ""))
                if !haskey(dictionary, a)
                    push!(dictionary, a => 1)
                else
                    dictionary[a] += 1
                end
            end
        end
    end
    return dictionary
end
words = get_sentences(prompts=prompts)
sort(collect(words), by=x -> x[2], rev=true) ## words sorted by frequency, most frequent first
# Select a subset of the dataset. For convenience, here is a query that finds all the sentences containing the word "spring" in the train set.
# You can inspect the query and write others in the DataFrames style. I suggest using the @linq macro and reading the documentation carefully.
d_word = SpikeTimit.find_word(word="spring", df=train)
# Obviously, you can also choose sentences by selecting specific rows. Careful: the dataset has no intrinsic ordering.
d_number = train[1, :]
# Once you have sub-selected some rows, you can import the spike times.
# They are already rescaled by the correction factor explained in the PDF.
spikes = SpikeTimit.get_spiketimes(df=d_number)
## Measure the spike frequency on a sample
counter = 0
for x in spikes[1]
    global counter += length(x)
end
duration = d_number.phones[end, 3] ## sentence duration (s): end time of the last phone
## Data from LDK
jex =1.78
rx_n = 4.5e3
g_noise = jex*rx_n
spike_per_pop = counter/620/2.92 ## Hz per population
pmemb = 0.2
jex_s = 50
g_signal = spike_per_pop *620*pmemb*jex_s
## Now let's convert the spike times into a convenient input.
test_path = "/run/media/cocconat/data/speech_encoding/Spike TIMIT/Spike TIMIT/test"
test = SpikeTimit.create_dataset(;dir= test_path)
## select two samples
d_number = test[1:2,:]
spikes= SpikeTimit.get_spiketimes(df=d_number)
## Show how it works for a single sample
dictionary = SpikeTimit.reverse_dictionary(spikes[1], 0.1)
sorted, neurons = SpikeTimit.sort_spikes(dictionary)
## check that the lengths of these arrays are the same
all_ft, all_n = SpikeTimit.inputs_from_df(test, [1,2])
## This is the loop for the simulation; place it in the input section
firing_index = 1
next_firing_time = all_ft[firing_index]
for tt in 1:10000
    ## needed at top level: the loop body assigns to these globals
    global firing_index, next_firing_time
    if tt == next_firing_time
        firing_neurons = all_n[firing_index]
        firing_index += 1
        ## guard against indexing past the last event
        next_firing_time = firing_index <= length(all_ft) ? all_ft[firing_index] : 0
        println("At time step: ", tt)
        println("These neurons fire: ", firing_neurons)
    end
end
using Plots
plot(length.(all_n)) ## number of neurons firing at each spike event