Commit a60add9f authored by Alessio Quaresima

cochlear encoding added to the SpikeTimit

parent b87b119b
 using MAT
+using NPZ
 """
 Get words and time intervals from the SenID
 """
@@ -109,18 +111,37 @@ end
-function get_spiketimes(;df::DataFrame)
-    get_spiketimes_file(file) = file*"_spike_timing.mat"
-    get_array(value::Float64) = begin x = zeros(Float64, 1); x[1] = value; x end
-    spikes = []
-    for f in df.path
-        fp = matopen(get_spiketimes_file(f))
+function get_spiketimes(;df::DataFrame, encoding::String, channels::Union{Int, Nothing}=nothing)
+    spiketimes = Vector{Spiketimes}()
+    for file in df.path
+        if encoding == "BAE"
+            push!(spiketimes, bae2spiketimes(file))
+        elseif encoding == "cochlea"
+            @assert !isnothing(channels)
+            push!(spiketimes, npz2spiketimes(file, channels))
+        end
+    end
+    return spiketimes
+end
+
+function bae2spiketimes(file)
+    get_array(value::Float64) = begin x = zeros(Float64, 1); x[1] = value; x end
+    fp = matopen(file*"_spike_timing.mat")
     spike = read(fp) |> x -> get(x, "spike_output", nothing)[1,:]
     map(row -> spike[row] = spike[row][:], findall(typeof.(spike) .== Array{Float64,2}))
     map(row -> spike[row] = get_array(spike[row]), findall(typeof.(spike) .== Float64))
     @assert length(spike) == length(findall(x -> isa(x, Vector{Float64}), spike))
-        push!(spikes, spike .* correction)
     close(fp)
-    end
-    map(x -> Spiketimes(x), spikes)
+    return Spiketimes(spike .* correction)
+end
+
+function npz2spiketimes(file::String, channels::Int)
+    fp = file*"_cochlea$channels.npz"
+    matrix = npzread(fp)["arr_0"]
+    spiketimes = Spiketimes()
+    for n in 1:channels
+        push!(spiketimes, matrix[1, findall(x -> x == n, matrix[2,:])])
+    end
+    return spiketimes
 end
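Note: a minimal, self-contained sketch of the grouping rule that npz2spiketimes applies, under the assumption that the `*_cochlea$channels.npz` array is a 2×N matrix whose first row holds spike times and whose second row holds 1-based channel indices (this layout is inferred from the indexing above, not documented in the diff). A plain Vector{Vector{Float64}} stands in for the package's Spiketimes container.

# Assumed layout: row 1 = spike times, row 2 = channel index of each spike.
channels = 3
matrix = [0.01 0.02 0.05 0.03 0.04 0.06;
          1.0  1.0  1.0  2.0  3.0  3.0]

# Stand-in for Spiketimes: one vector of spike times per channel.
spiketimes = Vector{Vector{Float64}}()
for n in 1:channels
    # Same rule as npz2spiketimes: collect the times of all spikes on channel n.
    push!(spiketimes, matrix[1, findall(x -> x == n, matrix[2, :])])
end

@assert spiketimes == [[0.01, 0.02, 0.05], [0.03], [0.04, 0.06]]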
@@ -172,15 +172,15 @@ function get_word_labels(df, word::String)
     return df_phones
 end
-function get_word_spikes(df::DataFrame, words::Union{String, Vector{String}})
+function get_word_spikes(df::DataFrame, words::Union{String, Vector{String}}; encoding, channels=70)
     if isa(words, String)
-        return get_spikes_in_interval(get_spiketimes(df=df), _get_word_interval(df, words))
+        return get_spikes_in_interval(get_spiketimes(df=df, encoding=encoding, channels=channels), _get_word_interval(df, words))
     else
         spikes = Vector{Spiketimes}()
         durations = Vector{Float64}()
         _durations=
         for word in words
-            _spikes, _durations = get_spikes_in_interval(get_spiketimes(df=df), _get_word_interval(df, word))
+            _spikes, _durations = get_spikes_in_interval(get_spiketimes(df=df, encoding=encoding, channels=channels), _get_word_interval(df, word))
             push!(spikes, _spikes...)
             push!(durations, _durations...)
         end
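Note: a hypothetical call sketch for the updated get_word_spikes signature (not runnable without the pre-processed TIMIT files; the dataframe `df` and the SpikeTimit module prefix are assumptions). `encoding` is a mandatory keyword argument, while `channels` defaults to 70 and is only consulted by the cochlea branch.

# Cochlea-encoded spikes for a list of words (assumes a TIMIT dataframe `df`).
spikes, durations = SpikeTimit.get_word_spikes(df, ["water", "greasy"];
                                               encoding="cochlea", channels=70)

# The BAE encoding reads the *_spike_timing.mat files, so `channels` is ignored.
spikes_bae, durations_bae = SpikeTimit.get_word_spikes(df, "water"; encoding="BAE")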
@@ -259,7 +259,7 @@ end
 ## Select and prepare spike inputs for network simulation
 function select_inputs( df::DataFrame, words::Vector{String};
-                        samples::Union{Int, Nothing}=nothing,
+                        samples::Union{Int, Nothing}=nothing, encoding, channels=nothing,
                         kwargs...)
     """
     Select inputs in the dataset and pre-process them for the simulation
@@ -292,7 +292,8 @@ function select_inputs( df::DataFrame, words::Vector{String};
         end
     end
     ###Get intervals and phonemes for each dataset entry (some can have more than one!)
-    spiketimes, durations = get_word_spikes(df_word[inds,:], word)
+    spiketimes, durations = get_word_spikes(df_word[inds,:], word,
+                                            encoding=encoding, channels=channels)
     resample_spikes!(spiketimes; kwargs...)
     transform_into_bursts!(spiketimes; kwargs...)
     labels = vcat(get_word_labels(df_word[inds,:], word)...)
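Note: a hypothetical end-to-end sketch of how the new keywords are threaded from select_inputs down to get_spiketimes (the dataframe `df`, the word list, the `samples` value, and the SpikeTimit prefix are assumptions; the function's return value is not shown in this hunk and is left unbound here).

# Cochlea encoding: `channels` must be given explicitly, since its default is `nothing`
# and get_spiketimes asserts !isnothing(channels) for this encoding.
SpikeTimit.select_inputs(df, ["water", "greasy"]; samples=10,
                         encoding="cochlea", channels=70)

# BAE encoding: `channels` can stay `nothing`; the .mat spike files are used instead.
SpikeTimit.select_inputs(df, ["water", "greasy"]; samples=10, encoding="BAE")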
...