Commit 3abaab93 authored by Alessio Quaresima

Merge branch 'master' of gitlab.socsci.ru.nl:alessio.quaresima1/spiketimit.jl

parents a3aab63d 0e8b3983
using PyCall
using DataFrames
using DataFramesMeta
using Pandas
cd(joinpath(@__DIR__,".."))
py"""
import sys
import os
sys.path.insert(0, os.getcwd())
print(sys.path)
"""
TIMIT = pyimport("TIMIT_loader")
pyimport("importlib")."reload"(TIMIT)
##
#path = "C:\\Users\\leoni\\Desktop\\3rd_year_AI\\1_Thesis\\litwin-kumar_model_thesis\\Spike TIMIT"
path = "/home/cocconat/Documents/Research/phd_project/speech/litwin-kumar_model_thesis/Spike TIMIT"
dataset = TIMIT.create_dataset(joinpath(path,"train"))
spkrinfo, spkrsent = TIMIT.create_spkrdata(path)
##
include("../src/SpikeTimit.jl")
#Create the path strings leading to folders in the data set
test_path = joinpath(path, "test");
train_path = joinpath(path, "train");
dict_path = joinpath(path, "DOC/TIMITDIC.TXT");
train = SpikeTimit.create_dataset(;dir= train_path)
test = SpikeTimit.create_dataset(;dir= test_path)
dict = SpikeTimit.create_dictionary(file=dict_path)
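# Quick sanity check on the loaded tables (a minimal sketch; it assumes
# `create_dataset` returns a DataFrame with the columns used later in this
# script, e.g. :speaker, :dialect, :gender and :words).
first(train, 5)
names(train)
combine(groupby(train, :dialect), nrow)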
##
const NYC = 6
const SOUTH = 5
const male = "m"
const female = "f"
##
words = ["that", "had", "she", "me", "your", "all", "like", "don't", "year", "water", "dark", "rag", "oily", "wash", "ask", "carry", "suit"]
# nyc_male = @where(train,in_dialect.(:dialect, target=NYC), in_gender.(:gender, target=male), in_words.(:words, target=words))
# nyc_female = @where(train,in_dialect.(:dialect, target=NYC), in_gender.(:gender, target=female), in_words.(:words, target=words))
# south_male = @where(train,in_dialect.(:dialect, target=SOUTH), in_gender.(:gender, target=male), in_words.(:words, target=words))
# south_female = @where(train,in_dialect.(:dialect, target=SOUTH), in_gender.(:gender, target=female), in_words.(:words, target=words))
##
##
##
# words[1].phones[1].db
using StatsBase
using Plots
function merge_phones(phone1, phone2)
    ph = Phone(phone1.ph * phone2.ph, phone1.t0, phone2.t1)
    ph.osc = vcat(phone1.osc, phone2.osc)
    ph.db = vcat(phone1.db, phone2.db)
    return ph
end
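# Usage sketch for `merge_phones` (it assumes `Phone` is the mutable phone type
# defined by the TIMIT_loader/SpikeTimit code, with fields `ph`, `t0`, `t1`,
# `osc` and `db`): the merged phone carries the concatenated label, spans
# phone1.t0 to phone2.t1, and stacks the `osc`/`db` arrays of the two phones.
# merged = merge_phones(some_word.phones[1], some_word.phones[2])   # `some_word` is hypothetical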
function get_all_phones(dataset, words)
    all_phones = Dict()
    isa(words, String) && (words = [words])
    for word in words
        matches = SpikeTimit.find_word(;df=dataset, word=word) |> x -> TIMIT.get_spectra(x |> Pandas.DataFrame, target_words=word)
        for word in SpikeTimit.py2j_words(matches)
            # println("\n")
            for phone in word.phones
                # print(phone.ph, " ")
                if haskey(all_phones, phone.ph)
                    push!(all_phones[phone.ph], phone)
                else
                    push!(all_phones, phone.ph => [phone])
                end
            end
        end
    end
    return all_phones
end
function padding(phone; pad_length=400, pad_zero = -80)
    # pad (or truncate) the 20-band phone spectrogram to a fixed length, filling with pad_zero dB
    mat = ones(20, pad_length) .* pad_zero
    if size(phone.db)[2] > pad_length
        mat[:, 1:pad_length] .= phone.db[:, 1:pad_length]
    else
        mat[:, 1:size(phone.db)[2]] .= phone.db[:, 1:end]
    end
    return mat
end
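# Usage sketch: `padding` maps every phone spectrogram onto a fixed 20 x pad_length
# matrix filled with pad_zero dB, so phones of different durations can be compared
# element-wise. `all_phones` stands for a Dict as returned by `get_all_phones`.
# padded = padding(all_phones["ae"][1])
# size(padded)   # (20, 400)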
function get_padded_features(all_phones)
    phones = Vector{String}(undef, length(all_phones))
    cases = Vector{Int}(undef, length(all_phones))
    features = Vector{Matrix{Float64}}(undef, length(all_phones))
    phones_labels = collect(keys(all_phones))
    Threads.@threads for idx in 1:length(phones_labels)
        ph = phones_labels[idx]
        samples = length(all_phones[ph])
        phones[idx] = ph
        cases[idx] = samples
        data_phone = zeros(400*20, samples)
        for n in 1:samples
            data_phone[:, n] .= padding(all_phones[ph][n])[:]
        end
        features[idx] = data_phone
    end
    return phones, cases, features
end
function compare_sounds(samples1, samples2)
    all_phones = []
    push!(all_phones, samples1[1]...)
    push!(all_phones, samples2[1]...)
    all_phones = collect(Set(all_phones))
    measures = zeros(length(all_phones), 2)
    for (n, ph) in enumerate(all_phones)
        ex1 = findfirst(x -> x == ph, samples1[1])
        ex2 = findfirst(x -> x == ph, samples2[1])
        measures[n, 1] = isnothing(ex1) ? 1 : samples1[2][ex1]
        measures[n, 2] = isnothing(ex2) ? 1 : samples2[2][ex2]
    end
    return jsd(measures[:, 1], measures[:, 2])
end
jsd(x,y)= 0.5*(kldivergence(x, y) + kldivergence(y,x))
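# Note: strictly speaking this is the symmetrised KL (J) divergence rather than the
# Jensen-Shannon divergence; `kldivergence` comes from StatsBase and assumes its
# inputs are probability vectors. A toy check:
# p = [0.2, 0.3, 0.5]; q = [0.3, 0.3, 0.4]
# jsd(p, q) == jsd(q, p)   # symmetric by construction
# jsd(p, p) == 0.0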
# crossentropy(dr1, dr2)
##
males = []
females = []
for dr in 1:8
    push!(males, @where(train, SpikeTimit.in_dialect.(:dialect, target=dr), SpikeTimit.in_gender.(:gender, target=male), SpikeTimit.in_words.(:words, target=words)) |> x -> get_padded_features(get_all_phones(x, words)))
    push!(females, @where(train, SpikeTimit.in_dialect.(:dialect, target=dr), SpikeTimit.in_gender.(:gender, target=female), SpikeTimit.in_words.(:words, target=words)) |> x -> get_padded_features(get_all_phones(x, words)))
end
# ##
# data = [males..., females...]
# entropy = zeros(16,16)
# for x in 1:16
# for y in 1:16
# entropy[x,y] = compare_sounds(data[x], data[y])
# end
# end
# hclust(entropy)
# # heatmap(entropy)
##
data = [males..., females...]
all_sounds = []
for (sounds, samples, _) in data
    push!(all_sounds, sounds)
end
common_sounds = collect(intersect(Set.(all_sounds)...))
mean_data = zeros(8000, 16, length(common_sounds))
std_data = zeros(8000, 16, length(common_sounds))
for (n, s) in enumerate(common_sounds)
    speaker = 0
    for (sounds, samples, feats) in data
        speaker += 1
        id = findfirst(x -> x == s, sounds)
        mean_data[:, speaker, n] = mean(feats[id], dims=2)[:, 1]
        std_data[:, speaker, n] = std(feats[id], dims=2)[:, 1]
    end
end
# heatmap(reshape(std_data[:,16,10],20,400))
# heatmap(reshape(mean_data[:,16,10],20,400))
##
phone_div = zeros(35, 35, 2)
for speaker in 1:16
    for x in 1:35
        for y in 1:35
            phone_div[x, y, 1] += jsd(abs.(mean_data[:, speaker, x]), abs.(mean_data[:, speaker, y]))
            phone_div[x, y, 2] += jsd(abs.(std_data[:, speaker, x]), abs.(std_data[:, speaker, y]))
        end
    end
end
speaker_div = zeros(16, 16, 2)
for phone in 1:35
    for spx in 1:16
        for spy in 1:16
            speaker_div[spx, spy, 1] += jsd(abs.(mean_data[:, spx, phone]), abs.(mean_data[:, spy, phone]))
            speaker_div[spx, spy, 2] += jsd(abs.(std_data[:, spx, phone]), abs.(std_data[:, spy, phone]))
        end
    end
end
##
using Clustering
drs = [
"m_New England",
"m_Northern",
"m_North Midland",
"m_South Midland",
"m_Southern",
"m_New York City",
"m_Western",
"m_Moved around",
"f_New England",
"f_Northern",
"f_North Midland",
"f_South Midland",
"f_Southern",
"f_New York City",
"f_Western",
"f_Moved around"]
drs= repeat(drs,2)
order = hclust(speaker_div[:,:,1],linkage=:complete).order
heatmap(speaker_div[order,order,1])
heatmap(speaker_div[order,order,2])
pdialects = plot!(xticks=(1:16,drs[order]),yticks=(1:16, drs[order]), xrotation=45, colorbar=false, tickfontsize=10, title="Dialects cross-entropy")
##
heatmap(phone_div[:, :, 1])
heatmap(reshape(mean_data[:, 6, 10], 20, 400))
mean_data[:, 6, 10]
order = hclust(phone_div[:, :, 1], linkage=:complete).order
heatmap(phone_div[order, order, 1])
common_sounds
pphones = plot!(xticks=(1:35, common_sounds[order]), yticks=(1:35, common_sounds[order]), xrotation=45, colorbar=false, tickfontsize=10, title="Phones cross-entropy")
##
savefig(pphones, "phones_similarities.png")
savefig(pdialects, "dialects_crossentropy.png")
##
##
variations = zeros(8000, 16, 35)
for phone in 1:35
    for sp in 1:16
        variations[:, sp, phone] = std_data[:, sp, phone]
    end
end
#
data[1][3]
#
# mean(data[1][3][10], dims=2)[:,1]
# # , dims=2
# # , dims=2
#
# data[1][3][10]
# heatmap(map(x->reshape(x, 20,400), mean.(data[1][3][2], dims=2))[1])
# heatmap(map(x->reshape(x, 20,400), mean.(nyc_male_data[3], dims=2))[2])
# heatmap(map(x->reshape(x, 20,400), mean.(nyc_male_data[3], dims=2))[3])
# heatmap(map(x->reshape(x, 20,400), mean.(male_data[3], dims=2))[4])
# heatmap(map(x->reshape(x, 20,400), mean.(nyc_male_data[3], dims=2))[5])
# heatmap(map(x->reshape(x, 20,400), mean.(nyc_male_data[3], dims=2))[6])
# heatmap(map(x->reshape(x, 20,400), mean.(nyc_male_data[3], dims=2))[35])
##
phone_div = fill(-1.0, 16, 16, 35, 35)
for sp1 in 1:16
    for sp2 in 1:sp1
        for x in 1:35
            for y in 1:x
                phone_div[sp1, sp2, x, y] = jsd(abs.(mean_data[:, sp1, x]), abs.(mean_data[:, sp2, y]))
                phone_div[sp2, sp1, y, x] = phone_div[sp1, sp2, x, y]
            end
        end
    end
end
##
heatmap(phone_div[1,:,4,:])
hclust(phone_div)
##
word = "little"
matches = SpikeTimit.find_word(;df=train, word=word)
water_f = @where(matches, SpikeTimit.in_gender.(:gender,target="f")) |> x-> TIMIT.get_spectra(x |> Pandas.DataFrame, target_words=word)[1]
water_m = @where(matches, SpikeTimit.in_gender.(:gender,target="m")) |> x-> TIMIT.get_spectra(x |> Pandas.DataFrame, target_words=word)[2]
function stack_phones(word)
    phones = []
    for ph in word.phones
        push!(phones, ph.db)
    end
    return phones
end
using Plots
p = Plots.plot(heatmap(hcat(stack_phones(little_f)...), title="Word: " * word), heatmap(hcat(stack_phones(little_m)...), ylabel="Frequency", xlabel="Duration (ms)"), layout=(2, 1), colorbar=false, guidefontsize=18, tickfontsize=15, titlefontsize=20)
savefig(p, "word_little.png")
q = Plots.plot(heatmap(little_f.phones[4].db, title="Phone: EL"), heatmap(little_m.phones[4].db, ylabel="Frequency", xlabel="Duration (ms)"), layout=(2, 1), colorbar=false, guidefontsize=18, tickfontsize=15, titlefontsize=20)
savefig(q, "phone_little.png")
##
heatmap(padding(all_phones["ae"][1]))
heatmap(padding(all_phones["ae"][2]))
heatmap(padding(all_phones["ae"][3]))
##
function rate_coding_word(word::SpikeTimit.Word)
    times = []
    encoding = Matrix{Float64}(undef, 20, length(word.phones))
    for (n, ph) in enumerate(word.phones)
        encoding[:, n] = mean(ph.db, dims=2)[:, 1]
        push!(times, ph.t0 - word.t0)
    end
    return times, encoding
end
using Plots
times, phs = rate_coding_word(words[1])
a = heatmap(words[1].phones[1].db)
b = heatmap(words[1].phones[2].db)
c = heatmap(words[1].phones[3].db)
words[1].word
Plots.plot(a,b,c, layout=(1,3), colorbar=false, axes=nothing, ticks=nothing)
times, phs = rate_coding_word(words[9])
heatmap(phs)
words[1].phones[1].ph
### A Pluto.jl notebook ###
# v0.12.21
using Markdown
using InteractiveUtils
# ╔═╡ 70951726-8063-11eb-2ba9-1fa3822b9e91
using .SpikeTimit
# ╔═╡ c0d57c40-8010-11eb-3004-195f0590db26
md"""
This notebook will show how to import data from SpikeTimit[1] database and run it as a stimulus for the LKD network[2]. The SpikeTimit.jl module import the standard daataset released with the publication [1].
"""
# ╔═╡ 62beb3c0-8011-11eb-1ab5-8de2f870d0b2
md""" Import the module and the relevant packages"""
# ╔═╡ 8015ec4a-8011-11eb-04d6-9bcd09fada86
PATH = joinpath(@__DIR__,"SpikeTimit.jl")
# ╔═╡ 5e4f6080-8063-11eb-39b1-ab78ccdbf423
include(PATH)
# ╔═╡ 24e8c3e2-8064-11eb-23dc-2f6848c499e5
# ╔═╡ 693fb58a-8063-11eb-3e59-c9249009d1d6
# ╔═╡ 4566f194-8012-11eb-39d7-ad467788a78b
md"""
Import the dataset. Notice, the spike-time are not imported, only the files are stored and ready to be read. You have to set your PATH fort the dataset
"""
# ╔═╡ 42f1c31e-8063-11eb-0059-75ffb8aa555a
begin
    test_path = joinpath(@__DIR__, "Spike TIMIT", "test");
    train_path = joinpath(@__DIR__, "Spike TIMIT", "train");
    dict_path = joinpath(@__DIR__, "DOC", "TIMITDIC.TXT");
end
# ╔═╡ 39885efc-8064-11eb-071d-c1eaa5f8892b
# ╔═╡ d3e653d0-8063-11eb-15e8-019ae2ff331a
md""" dict is a list of all words with the correesponding phones."""
# ╔═╡ 204518b2-8012-11eb-09f7-1f5da1ba4e1d
begin
    train = SpikeTimit.create_dataset(;dir=train_path)
    test = SpikeTimit.create_dataset(;dir=test_path)
    dict = SpikeTimit.create_dictionary(file=dict_path)
end
# ╔═╡ ae1af4c8-8011-11eb-0797-8d37104dcef5
md"""Select a subset of the dataset. Here for convenience I created a query to find all the sentences that contain the word "spring" in the train set.
ou can look up at the query and do others on the base of the DataFrame style. I suggest you to use the @linq macro and read the documentation carefully."""
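# A hedged sketch of such a query (assuming DataFramesMeta is loaded and using the
# `SpikeTimit.in_gender` / `SpikeTimit.in_words` helpers shown elsewhere in this repo):
# female_spring = @linq train |>
#     where(SpikeTimit.in_gender.(:gender, target="f"), SpikeTimit.in_words.(:words, target=["spring"])) |>
#     select(:speaker, :path, :words)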
# ╔═╡ f24740dc-8063-11eb-309d-e578fe133f5b
d_word = SpikeTimit.find_word(word="spring", df=train)
# ╔═╡ 4504b352-8011-11eb-0b7b-e97d88db44c9
md"""
Obviously, you can also choose the sentences by selecting some specific rows. Careful, the dataset has not an ordering.
"""
# ╔═╡ 40474e38-8011-11eb-0a99-917702896ff5
d_number = train[1,:]
# ╔═╡ 2ed1509a-8011-11eb-243e-e5df7b658803
md"""
Once you have sub-selected some columns, you can import the spike times. they are already corrected as explained in the PDF"""
# ╔═╡ 570de36a-8064-11eb-3bcb-076383a55a63
spikes= SpikeTimit.get_spiketimes(df=d_number)
# ╔═╡ 5c4b7dae-8064-11eb-1e77-b55a9c0588c0
plt1 = SpikeTimit.raster_plot(spikes[1])
plt1.plt
md"""
Also, the dataframe contains these fields:
speaker : the ID of the speaker
senID : the ID of the sentence
path : the path to access it (so you can retrieve the correct file adding the .EXT )
words : the words and their respective timing in ms
phones : the phones and their respective timing in ms
You can access it in this way:
"""
speakerID_firstsent = train[1,:speaker]
words_firstsent = train[1,:words]
# ╔═╡ 02a71dd8-8011-11eb-3f32-bb85b0c102f5
md"""
References
----------
[1] _Pan, Zihan, Yansong Chua, Jibin Wu, Malu Zhang, Haizhou Li, and Eliathamby Ambikairajah. “An Efficient and Perceptually Motivated Auditory Neural Encoding and Decoding Algorithm for Spiking Neural Networks.” Frontiers in Neuroscience 13 (2020). https://doi.org/10.3389/fnins.2019.01420._
[2] _Litwin-Kumar, Ashok, and Brent Doiron. “Formation and Maintenance of Neuronal Assemblies through Synaptic Plasticity.” Nature Communications 5, no. 1 (December 2014). https://doi.org/10.1038/ncomms6319._
"""
"""
Extract all the firing times and the corresponding neurons from an array with all the neurons and their relative firing times. i.e. the reverse_dictionary
"""
function reverse_dictionary(spikes)
    all_times = Dict()
    for n in eachindex(spikes)
        if !isempty(spikes[n])
            for t in spikes[n]
                tt = round(Int, t * 10000)
                if haskey(all_times, tt)
                    all_times[tt] = [all_times[tt]..., n]
                else
                    push!(all_times, tt => [n])
                end
            end
        end
    end
    return all_times
end
"""
From the reverse_dictionary data structure obtain 2 arrays that are faster to access in the simulation loop.
1. First array contains the sorted spike times.
2. The second array contains vectors with each firing neuron
"""
function sort_spikes(dictionary)
    neurons = Array{Vector{Int}}(undef, length(keys(dictionary)))
    sorted = sort(collect(keys(dictionary)))
    for (n, k) in enumerate(sorted)
        neurons[n] = dictionary[k]
    end
    return sorted, neurons
end
dictionary = reverse_dictionary(spikes[1])
sorted, neurons = sort_spikes(dictionary)
## This index runs over all the firing events that have happened so far
firing_index = 1
next_firing_time = sorted[firing_index]
# This is the loop in the simulation
for tt in 1:10000
    global firing_index, next_firing_time
    if tt == next_firing_time
        firing_neurons = neurons[firing_index]
        firing_index += 1
        # guard against stepping past the last firing event
        if firing_index <= length(sorted)
            next_firing_time = sorted[firing_index]
        end
        println("At time step: ", tt)
        println("These neurons fire: ", firing_neurons)
    end
end
#this file is part of litwin-kumar_doiron_formation_2014
#Copyright (C) 2014 Ashok Litwin-Kumar
#see README for more information
#==
This file looks into the properties of the Spike TIMIT dataset.
I will:
1. Import the train set.
2. Select a subspace of the dataset with fixed:
   regional accent (dr1)
   gender (f)
   set of words (["that", "had", "she", "me", "your", "year"])
3. Show the raster plot of a single realization of these words in the dataset.
4. Train a classifier to distinguish between these words on a reduced feature space.
5. Expand the dataset to 4 regional accents and both genders and
   extract the time of the phones that compose the word "year".
6. Train a classifier to distinguish between these phones on a reduced feature space.
7. Extract an average map of these phones.
==#
using Plots
using Random
using MLDataUtils
using MLJLinearModels
using StatsBase
## 1
include("SpikeTimit.jl")
test_path = joinpath(@__DIR__,"Spike TIMIT", "test");
train_path = joinpath(@__DIR__,"Spike TIMIT", "train");
dict_path = joinpath(@__DIR__,"DOC", "TIMITDIC.TXT");
train = SpikeTimit.create_dataset(;dir= train_path)
test = SpikeTimit.create_dataset(;dir= test_path)
dict = SpikeTimit.create_dictionary(file=dict_path)
## 2
Dong_words = ["that", "had", "she", "me", "your", "year"]
speakers = []
for word in Dong_words
    speaker = @linq train |>
        where(:dialect .== "dr1", :gender .== 'f', SpikeTimit.in_words.(:words, target=word)) |>
        select(:speaker) |> unique
    push!(speakers, Set(speaker.speaker))
end
speakers = collect(intersect(speakers...))
single_speaker_train = @where(train,:speaker.==speakers[1])
## 3
spikes = []
plots = []
for word in Dong_words
    spiketimes, duration, phones = SpikeTimit.get_word_spiketimes(;df=single_speaker_train, word=word)
    @show typeof(spiketimes)
    spiketimes = SpikeTimit.resample_spikes(;spiketimes=spiketimes, n_feat=4)
    push!(spikes, spiketimes[1])
    rst = SpikeTimit.raster_plot(spiketimes[1])
    plot(rst.plt, title=word)
    push!(plots, rst.plt)
end
plot(plots...)
spiketimes[1]
## 4. Compute average time of each word
data = zeros(6, length(spikes[1]))
for st in eachindex(spikes)
    for n in eachindex(spikes[st])
        if !isempty(spikes[st][n])
            data[st, n] = mean(spikes[st][n])
        end
    end
end
plots = []
for n = 1:6
    plt = scatter(data[n, :], label=false)
    push!(plots, plt)
end
plot(plots...)
## 4.bis Compute the distinguishability
train_dataset = zeros(size(data)[2], 1000)
labels= zeros(Int,1000)
for x in 1:1000
    n = rand(1:6)
    train_dataset[:, x] .= data[n, :]
    labels[x] = n
end
train_dataset
labels
MultinomialLogisticRegression(train_dataset,labels)
## 4 Define classifier
function make_set_index(y::Int64, ratio::Float64)
    train, test = Vector{Int64}(), Vector{Int64}()
    ratio_0 = length(train) / y
    for index in 1:y
        # if !(index ∈ train)
        if rand() < ratio - ratio_0
            push!(train, index)
        else
            push!(test, index)
        end
    end
    return train, test
end
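# Usage sketch: randomly split 1000 sample indices into roughly 80% train / 20% test
# (purely random, no stratification).
# train_idx, test_idx = make_set_index(1000, 0.8)
# length(train_idx) / 1000   # ≈ 0.8 on average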
function labels_to_y(labels)
    _labels = collect(Set(labels))
    z = zeros(Int, maximum(_labels))
    z[_labels] .= 1:length(_labels)
    return z[labels]