Commit 876384f3 authored by alessio.quaresima's avatar alessio.quaresima

New SpikeTimit

parent 00d0bf6f
### A Pluto.jl notebook ###
# v0.12.21
using Markdown
using InteractiveUtils
# ╔═╡ 70951726-8063-11eb-2ba9-1fa3822b9e91
using .SpikeTimit
# ╔═╡ c0d57c40-8010-11eb-3004-195f0590db26
md"""
This notebook shows how to import data from the Spike TIMIT [1] database and run it as a stimulus for the LKD network [2]. The SpikeTimit.jl module imports the standard dataset released with publication [1].
"""
# ╔═╡ 62beb3c0-8011-11eb-1ab5-8de2f870d0b2
md""" Import the module and the relevant packages"""
# ╔═╡ 8015ec4a-8011-11eb-04d6-9bcd09fada86
PATH = joinpath(@__DIR__,"SpikeTimit.jl")
# ╔═╡ 5e4f6080-8063-11eb-39b1-ab78ccdbf423
include(PATH)
# ╔═╡ 24e8c3e2-8064-11eb-23dc-2f6848c499e5
# ╔═╡ 693fb58a-8063-11eb-3e59-c9249009d1d6
# ╔═╡ 4566f194-8012-11eb-39d7-ad467788a78b
md"""
Import the dataset. Notice that the spike times are not imported yet; only the file paths are stored, ready to be read. You have to set your own PATH for the dataset.
"""
# ╔═╡ 42f1c31e-8063-11eb-0059-75ffb8aa555a
begin
test_path = joinpath(@__DIR__,"Spike TIMIT", "test" );
train_path = joinpath(@__DIR__,"Spike TIMIT", "train" );
dict_path = joinpath(@__DIR__,"DOC","TIMITDIC.TXT");
end
# ╔═╡ 39885efc-8064-11eb-071d-c1eaa5f8892b
# ╔═╡ d3e653d0-8063-11eb-15e8-019ae2ff331a
md""" dict is a list of all words with the correesponding phones."""
# ╔═╡ 204518b2-8012-11eb-09f7-1f5da1ba4e1d
begin
train = SpikeTimit.create_dataset(;dir= train_path)
test = SpikeTimit.create_dataset(;dir= test_path)
dict = SpikeTimit.create_dictionary(file=dict_path)
end
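Both `train` and `test` are DataFrames; a minimal sketch to peek at them (assuming the DataFrames package is loaded):

```julia
using DataFrames

first(train, 3)   # first three rows: speaker, senID, path, words, phones, ...
names(train)      # list the column names of the dataset table
```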
# ╔═╡ ae1af4c8-8011-11eb-0797-8d37104dcef5
md"""Select a subset of the dataset. Here for convenience I created a query to find all the sentences that contain the word "spring" in the train set.
ou can look up at the query and do others on the base of the DataFrame style. I suggest you to use the @linq macro and read the documentation carefully."""
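As a sketch of such a query, assuming DataFramesMeta is loaded (the `:gender` and `:speaker` columns appear in the queries later in this commit):

```julia
using DataFramesMeta

# hypothetical example: all sentences spoken by female speakers,
# keeping only the speaker ID and the word annotations
female_sentences = @linq train |>
    where(:gender .== 'f') |>
    select(:speaker, :words)
```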
# ╔═╡ f24740dc-8063-11eb-309d-e578fe133f5b
d_word = SpikeTimit.find_word(word="spring", df=train)
# ╔═╡ 4504b352-8011-11eb-0b7b-e97d88db44c9
md"""
Obviously, you can also choose sentences by selecting specific rows. Be careful: the dataset has no intrinsic ordering.
"""
# ╔═╡ 40474e38-8011-11eb-0a99-917702896ff5
d_number = train[1,:]
# ╔═╡ 2ed1509a-8011-11eb-243e-e5df7b658803
md"""
Once you have sub-selected some rows, you can import the spike times. They are already corrected, as explained in the PDF."""
# ╔═╡ 570de36a-8064-11eb-3bcb-076383a55a63
spikes= SpikeTimit.get_spiketimes(df=d_number)
# ╔═╡ 5c4b7dae-8064-11eb-1e77-b55a9c0588c0
plt1 = SpikeTimit.raster_plot(spikes[1])
plt1.plt
md"""
Also, the dataframe contains these fields:
- speaker : the ID of the speaker
- senID : the ID of the sentence
- path : the path to access it (so you can retrieve the correct file by adding the .EXT extension)
- words : the words and their respective timing in ms
- phones : the phones and their respective timing in ms
You can access them in this way:
"""
speakerID_firstsent = train[1,:speaker]
words_firstsent = train[1,:words]
# ╔═╡ 02a71dd8-8011-11eb-3f32-bb85b0c102f5
md"""
References
----------
[1] _Pan, Zihan, Yansong Chua, Jibin Wu, Malu Zhang, Haizhou Li, and Eliathamby Ambikairajah. “An Efficient and Perceptually Motivated Auditory Neural Encoding and Decoding Algorithm for Spiking Neural Networks.” Frontiers in Neuroscience 13 (2020). https://doi.org/10.3389/fnins.2019.01420._
[2] _Litwin-Kumar, Ashok, and Brent Doiron. “Formation and Maintenance of Neuronal Assemblies through Synaptic Plasticity.” Nature Communications 5, no. 1 (December 2014). https://doi.org/10.1038/ncomms6319._
"""
"""
Extract all the firing times and the corresponding neurons from an array that maps each neuron to its firing times, i.e. build the reverse dictionary.
"""
function reverse_dictionary(spikes)
    all_times = Dict()
    for n in eachindex(spikes)
        if !isempty(spikes[n])
            for t in spikes[n]
                tt = round(Int, t * 10000) # discretize to 0.1 ms steps
                if haskey(all_times, tt)
                    push!(all_times[tt], n)
                else
                    push!(all_times, tt => [n])
                end
            end
        end
    end
    return all_times
end
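A toy usage sketch of `reverse_dictionary` (hypothetical spike times, in seconds; the keys are time steps of 0.1 ms):

```julia
toy_spikes = [[0.001, 0.002], Float64[], [0.002]]  # three neurons and their firing times
rev = reverse_dictionary(toy_spikes)
# rev == Dict(10 => [1], 20 => [1, 3]):
# at step 10 neuron 1 fires, at step 20 neurons 1 and 3 fire
```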
"""
From the reverse_dictionary data structure obtain 2 arrays that are faster to access in the simulation loop.
1. First array contains the sorted spike times.
2. The second array contains vectors with each firing neuron
"""
function sort_spikes(dictionary)
    neurons = Array{Vector{Int}}(undef, length(keys(dictionary)))
    sorted = sort(collect(keys(dictionary)))
    for (n, k) in enumerate(sorted)
        neurons[n] = dictionary[k]
    end
    return sorted, neurons
end
dictionary = reverse_dictionary(spikes[1])
sorted, neurons = sort_spikes(dictionary)
## This counts all the firing events that have happened
firing_index= 1
next_firing_time= sorted[firing_index]
# This is the loop in the simulation
for tt in 1:10000
    if tt == next_firing_time
        firing_neurons = neurons[firing_index]
        firing_index += 1
        # guard against indexing past the last spike
        if firing_index <= length(sorted)
            next_firing_time = sorted[firing_index]
        end
        println("At time step: ", tt)
        println("These neurons fire: ", firing_neurons)
    end
end
# SpikeTimit module
```@contents
```
```@meta
CurrentModule = SpikeTimit
DocTestSetup = quote
using SpikeTimit
end
```
```@docs
SpikeTimit.find_word
```
#https://towardsdatascience.com/how-to-automate-julia-documentation-with-documenter-jl-21a44d4a188f
#this file is part of litwin-kumar_doiron_formation_2014
#Copyright (C) 2014 Ashok Litwin-Kumar
#see README for more information
#==
This file looks into the properties of the Spike TIMIT dataset.
I will:
1. Import the train set.
2. Select a subspace of the dataset with fixed:
regional accent (dr1)
gender (f)
set of words (["that", "had", "she", "me", "your", "year"])
3. Show the raster plot of a single realization of these words in the dataset
4. Train a classifier to distinguish between these words on a reduced feature space.
5. Expand the dataset to 4 regional accents and both genders and
extract the time of the phones that compose the word "year"
6. Train a classifier to distinguish between these phones on a reduced feature space.
7. Extract an average map of these phones.
==#
using Plots
using Random
using MLDataUtils
using MLJLinearModels
using StatsBase
using DataFrames, DataFramesMeta # needed for the @linq / @where queries below
## 1
include("SpikeTimit.jl")
test_path = joinpath(@__DIR__,"Spike TIMIT", "test");
train_path = joinpath(@__DIR__,"Spike TIMIT", "train");
dict_path = joinpath(@__DIR__,"DOC", "TIMITDIC.TXT");
train = SpikeTimit.create_dataset(;dir= train_path)
test = SpikeTimit.create_dataset(;dir= test_path)
dict = SpikeTimit.create_dictionary(file=dict_path)
## 2
Dong_words = ["that", "had", "she", "me", "your", "year"]
speakers = []
for word in Dong_words
    speaker = @linq train |>
        where(:dialect .== "dr1", :gender .== 'f', occursin.(word, :words)) |>
        select(:speaker) |> unique
    push!(speakers, Set(speaker.speaker))
end
speakers = collect(intersect(speakers...))
single_speaker_train = @where(train,:speaker.==speakers[1])
## 3
spikes = []
plots = []
for word in Dong_words
    spiketimes, duration, phones = SpikeTimit.get_word_spiketimes(;df=single_speaker_train, word=word)
    @show typeof(spiketimes)
    spiketimes = SpikeTimit.resample_spikes(;spiketimes=spiketimes, n_feat=4)
    push!(spikes, spiketimes[1])
    rst = SpikeTimit.raster_plot(spiketimes[1])
    plot(rst.plt, title=word)
    push!(plots, rst.plt)
end
plot(plots...)
spikes[1] # inspect the spike trains collected for the first word
## 4. Compute average time of each word
data = zeros(6,length(spikes[1]))
for st in eachindex(spikes)
    for n in eachindex(spikes[st])
        if !isempty(spikes[st][n])
            data[st, n] = mean(spikes[st][n])
        end
    end
end
plots = []
for n = 1:6
    plt = scatter(data[n,:], label=false)
    push!(plots, plt)
end
plot(plots...)
## 4.bis Compute the distinguishability
train_dataset = zeros(size(data, 2), 1000)
labels = zeros(Int, 1000)
for x in 1:1000
    n = rand(1:6)
    train_dataset[:,x] .= data[n,:]
    labels[x] = n
end
train_dataset
labels
MultinomialLogisticRegression(train_dataset,labels)
## 4 Define classifier
function make_set_index(y::Int64, ratio::Float64)
    train, test = Vector{Int64}(), Vector{Int64}()
    ratio_0 = length(train) / y
    for index in 1:y
        # if !(index ∈ train)
        if rand() < ratio - ratio_0
            push!(train, index)
        else
            push!(test, index)
        end
    end
    return train, test
end
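For example, a random 50/50 split over ten indices (the exact partition changes on every call):

```julia
train_idx, test_idx = make_set_index(10, 0.5)
# e.g. train_idx == [2, 3, 7, 9], test_idx == [1, 4, 5, 6, 8, 10];
# on average half of the indices end up in each set
```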
function labels_to_y(labels)
    _labels = collect(Set(labels))
    z = zeros(Int, maximum(_labels))
    z[_labels] .= 1:length(_labels)
    return z[labels]
end
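`labels_to_y` remaps arbitrary positive integer labels onto the compact range 1:K expected by the regression, e.g.:

```julia
labels_to_y([3, 7, 3, 7, 9])
# -> a consistent relabelling such as [1, 2, 1, 2, 3]
# (the exact mapping depends on the iteration order of Set)
```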
function MultinomialLogisticRegression(X::Matrix{Float64}, labels::Array{Int64}; λ::Float64=0.5, test_ratio=0.5)
    n_classes = length(Set(labels))
    while length(labels) > size(X, 2)
        pop!(labels)
    end
    y = labels_to_y(labels)
    n_features = size(X, 1)
    train, test = make_set_index(length(y), test_ratio)
    @show length(train)
    train_std = StatsBase.fit(ZScoreTransform, X[:, train], dims=2)
    StatsBase.transform!(train_std, X)
    intercept = false
    X[isnan.(X)] .= 0
    # deploy MultinomialRegression from MLJLinearModels, λ being the strength of the regulariser
    mnr = MultinomialRegression(λ; fit_intercept=intercept)
    # Fit the model
    θ = MLJLinearModels.fit(mnr, X[:, train]', y[train])
    # The model parameters are organised such that we can apply X⋅θ; the following is only to clarify
    # Get the predictions X⋅θ and map each vector to its maximal element
    preds = MLJLinearModels.softmax(MLJLinearModels.apply_X(X[:, test]', θ, n_classes))
    targets = map(x -> argmax(x), eachrow(preds))
    # and evaluate the model over the labels
    scores = mean(targets .== y[test])
    params = reshape(θ, n_features + Int(intercept), n_classes)
    return scores, params
end
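A self-contained sketch on random data (shapes only; the names `X_demo`/`labels_demo` are hypothetical and λ is the default above):

```julia
X_demo = rand(10, 200)        # 10 features × 200 samples
labels_demo = rand(1:4, 200)  # 4 hypothetical classes
scores, params = MultinomialLogisticRegression(X_demo, labels_demo; λ=0.5)
# scores: test-set accuracy; params: n_features × n_classes weight matrix
```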
## Select a subset of regional dialect and the word "year".
single_dialect = @where(train, occursin.(:dialect, "dr1"*"dr2"*"dr3"), occursin.("year", :words))
df = SpikeTimit.find_word(word="year",df=single_dialect)
## Get the time intervals of all the phones contained in the word, within the dataset
phones = SpikeTimit.get_phones_in_word(;df=df, word="year")
spikes = SpikeTimit.get_spiketimes(;df)
spikes = SpikeTimit.resample_spikes(;spiketimes=spikes, n_feat=2)
phone_labels, spiketimes = SpikeTimit.get_phones_spiketimes(spikes, phones)
## Compute the mean for each phone and set the classification stage
get_mean(x) = isempty(x) ? 0. : mean(x)
data = map(neuron -> get_mean.(neuron), spiketimes)
all_labels = collect(Set(phone_labels))
train_dataset = zeros(length(data[1]), 1000)
labels= zeros(Int,1000)
for x in 1:1000
    n = rand(1:length(data))
    phone = phone_labels[n]
    train_dataset[:,x] .= data[n]
    labels[x] = findfirst(l -> phone == l, all_labels)
end
MultinomialLogisticRegression(train_dataset,labels)
## Show the average qualities of all the phones
plots=[]
_labels=[]
for n = 1:length(all_labels)
    realizations = findall(l -> all_labels[n] == l, phone_labels)
    if length(realizations) > 5
        traceplot = SpikeTimit.TracePlot(length(spiketimes[1]), alpha=1/length(realizations))
        plot(traceplot.plt, title=all_labels[n])
        push!(plots, traceplot)
        push!(_labels, all_labels[n])
    end
end
plots
# plots = repeat([traceplot],length(all_labels))
for (_l, ax) in zip(_labels, plots)
    realizations = findall(l -> _l == l, phone_labels)
    @show length(realizations)
    for r in realizations
        SpikeTimit.raster_plot(spiketimes[r], ax=ax, alpha=0.1)
    end
end
plot([p.plt for p in plots]..., legend=false, alpha=0.0001)
## Import the dataset. Notice that the spike times are not imported; only the files are stored, ready to be read.
include("src/SpikeTimit.jl")
using .SpikeTimit
using Plots
basedir = "/home/cocconat/Documents/Research/phd_project/speech/litwin-kumar_model_thesis/Spike TIMIT"
test_path = joinpath(basedir, "test")
train_path = joinpath(basedir, "train")
dict_path = joinpath(basedir, "DOC", "TIMITDIC.TXT")
dict = SpikeTimit.create_dictionary(file=dict_path)
train = SpikeTimit.create_dataset(;dir=train_path)
spikes = SpikeTimit.get_spiketimes(;df=my_sent)
new_spikes = SpikeTimit.resample_spikes(spike_times=spikes[1], n_feat=8)
histogram(vcat(new_spikes...), bins=0:0.01:3)
## Verify spike resampling is correct
z = 0
for neurons in spikes[1]
    z += length(neurons)
end
y = 0
for neurons in new_spikes
    y += length(neurons)
end
@show z, y
##
my_sent.words[:,3]
## Count number of spikes
counter = 0
for x in new_spikes
    counter += length(x)
end
##
project = SpikeTimit.find_word(; word="project", df=train)
intervals = SpikeTimit.get_interval_word(;df=project, word="project")
spiketimes = SpikeTimit.get_spiketimes(; df=project)
vcat(spiketimes[1]...)
h = histogram(vcat(spiketimes[1]...), bins=0:0.01:3)
##
plots = []
for spike in spiketimes
    # spikes = SpikeTimit.resample_spikes(spike_times=spike, n_feat=8)
    h = histogram(vcat(spike...), bins=0:0.01:3)
    push!(plots, h)
end
plot(plots...)
sts, durations = SpikeTimit.get_spikes_in_interval(;spiketimes=spiketimes, intervals=intervals)
SpikeTimit.stack_spiketimes(sts, durations, 0.1)