Commit a60add9f authored by Alessio Quaresima's avatar Alessio Quaresima
Browse files

cochlear encoding added to the SpikeTimit

parent b87b119b
using MAT
using NPZ
"""
Get words and time intervals from the SenID
"""
......@@ -109,18 +111,37 @@ end
function get_spiketimes(;df::DataFrame)
get_spiketimes_file(file)= file*"_spike_timing.mat"
get_array(value::Float64) = begin x=zeros(Float64, 1); x[1] = value; x end
spikes =[]
for f in df.path
fp = matopen(get_spiketimes_file(f))
function get_spiketimes(;df::DataFrame, encoding::String, channels::Union{Int, Nothing}=nothing)
spiketimes=Vector{Spiketimes}()
for file in df.path
if encoding =="BAE"
push!(spiketimes, bae2spiketimes(file))
elseif encoding =="cochlea"
@assert !isnothing(channels)
push!(spiketimes, npz2spiketimes(file, channels))
end
end
return spiketimes
end
function bae2spiketimes(file)
get_array(value::Float64) = begin x=zeros(Float64, 1); x[1] = value; x end
fp = matopen(file*"_spike_timing.mat")
spike = read(fp)|> x->get(x,"spike_output", nothing)[1,:]
map(row->spike[row] = spike[row][:], findall(typeof.(spike) .==Array{Float64,2}))
map(row->spike[row] = get_array(spike[row]), findall(typeof.(spike) .==Float64))
@assert length(spike) == length(findall(x-> isa(x, Vector{Float64}), spike))
push!(spikes,spike .* correction)
close(fp)
end
map(x->Spiketimes(x), spikes)
return Spiketimes(spike .* correction)
end
function npz2spiketimes(file::String, channels::Int)
fp = file*"_cochlea$channels.npz"
matrix = npzread(fp)["arr_0"]
spiketimes=Spiketimes()
for n in 1:channels
push!(spiketimes, matrix[1,findall(x-> x==n, matrix[2,:])])
end
return spiketimes
end
......@@ -172,15 +172,15 @@ function get_word_labels(df, word::String)
return df_phones
end
function get_word_spikes(df::DataFrame, words::Union{String, Vector{String}})
function get_word_spikes(df::DataFrame, words::Union{String, Vector{String}}; encoding, channels=70)
if isa(words, String)
return get_spikes_in_interval(get_spiketimes(df=df), _get_word_interval(df,words))
return get_spikes_in_interval(get_spiketimes(df=df, encoding=encoding, channels=channels), _get_word_interval(df,words))
else
spikes = Vector{Spiketimes}()
durations = Vector{Float64}()
_durations=
for word in words
_spikes, _durations = get_spikes_in_interval(get_spiketimes(df=df), _get_word_interval(df,word))
_spikes, _durations = get_spikes_in_interval(get_spiketimes(df=df, encoding=encoding, channels=channels), _get_word_interval(df,word))
push!(spikes, _spikes...)
push!(durations, _durations...)
end
......@@ -259,7 +259,7 @@ end
## Select and prepare spike inputs for network simulation
function select_inputs( df::DataFrame, words::Vector{String};
samples::Union{Int, Nothing}=nothing,
samples::Union{Int, Nothing}=nothing, encoding, channels=nothing,
kwargs...)
"""
Select inputs in the dataset and pre-process them for the simulation
......@@ -292,7 +292,8 @@ function select_inputs( df::DataFrame, words::Vector{String};
end
end
###Get intervals and phonemes for each dataset entry (some can have more than one!)
spiketimes, durations = get_word_spikes(df_word[inds,:], word)
spiketimes, durations = get_word_spikes(df_word[inds,:], word,
encoding=encoding, channels=channels)
resample_spikes!(spiketimes; kwargs...)
transform_into_bursts!(spiketimes; kwargs...)
labels = vcat(get_word_labels(df_word[inds,:],word)...)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment