Commit 02aee98a authored by Leonie1.Wagner
Browse files

start of energy encoding set up

parent ccdeaad9
...@@ -34,17 +34,56 @@ dict = SpikeTimit.create_dictionary(file=dict_path) ...@@ -34,17 +34,56 @@ dict = SpikeTimit.create_dictionary(file=dict_path)
## ##
# Parameters to change to filter the data
measurements_per_word = 1       # how often the states are stored per word
measurements_per_phone = 10     # how often the states are stored per phone
timestep = 0.1                  # NOTE(review): unit is not stated in this script -- confirm
spikes_per_burst_increase = -1  # NOTE(review): -1 looks like a "disabled" sentinel -- confirm
samples_per_word = 25
repetitions = 2                 # amount of times you present the network with each unique stimulus
silence_time = 0.15             # in seconds
n_features = 1                  # number of features combined from input frequencies
random_seed = 10
nbr_of_pop = -1                 # NOTE(review): -1 looks like a "disabled" sentinel -- confirm
words = ["that", "had", "she", "me", "your", "all", "like", "don't", "year", "water", "dark", "rag", "oily", "wash", "ask", "carry", "suit"]
# The dialect and gender filters are plain strings: the speaker selection below
# tests each row with `occursin`, i.e. as a substring of these values.
target_dialects = "dr1 dr2 dr3 dr4 dr5 dr6 dr7 dr8"
target_gender = "f m"
# Filtering the dataframe: keep only the speakers that utter every target word.
speakers_per_word = Set[]
for word in words
    # Rows matching the dialect / gender filters whose word list contains `word`.
    matches = @linq train |> # SET THIS TO TRAIN/TEST ACCORDINGLY
        where(occursin.(:dialect, target_dialects), occursin.(:gender, target_gender), word .∈ :words) |>
        select(:speaker) |> unique
    push!(speakers_per_word, Set(matches.speaker))
end
# Speakers common to all words. `reduce` avoids splatting the whole list into varargs.
common_speakers = reduce(intersect, speakers_per_word)
speakers = collect(common_speakers)
# This is the filtering of the dataframe where you only select the speakers you want.
filtered_df = filter(:speaker => in(common_speakers), train)
# NOTE(review): `SpikeTimit` is already used earlier in this script
# (`SpikeTimit.create_dictionary`), so this include re-evaluates the module --
# presumably to pick up local edits; confirm it is still wanted here.
include("../src/SpikeTimit.jl")

# `select_inputs` returns the sampled rows collected per target word, in the order of `words`.
speaker = SpikeTimit.select_inputs(df=filtered_df, words=words, samples=samples_per_word,
                                   n_feat=n_features, random_seed=random_seed);

# Extract the spectra word by word. `target_words` is passed as a one-element
# list, matching the list default of the Python `get_spectra` signature; a bare
# string would presumably be matched by substring on the Python side (body not
# visible here -- confirm).
py_words = []
for (word, df_word) in zip(words, speaker)
    push!(py_words, TIMIT.get_spectra(df_word |> Pandas.DataFrame, target_words=[word]));
end

# From here on `words` holds the flattened list of word objects returned by
# Python, no longer the target strings.
words = collect(Iterators.flatten(py_words))
labels = [w.word for w in words]
labels
##
##
words[1].phones[1].db
...@@ -62,25 +101,80 @@ function py2j_words(words) ...@@ -62,25 +101,80 @@ function py2j_words(words)
end end
return jwords return jwords
end end
# Convert the Python word objects to Julia words; kept under a separate name so
# `words` still refers to the Python objects.
jwords = py2j_words(words)
"""
    rate_coding_word(word::SpikeTimit.Word)

Compute a per-phone rate encoding of `word`.

# Returns
A tuple `(times, durations, encoding)`:
- `times`: onset of each phone relative to the word onset (`ph.t0 - word.t0`).
- `durations`: length of each phone (`ph.t1 - ph.t0`).
- `encoding`: `Float64` matrix with one column per phone; column `n` is the
  row-wise mean of the `db` matrix of phone `n`.
"""
function rate_coding_word(word::SpikeTimit.Word)
    phones = word.phones
    times = [ph.t0 - word.t0 for ph in phones]
    durations = [ph.t1 - ph.t0 for ph in phones]
    # The row count follows the data instead of being fixed at 20; 20 is kept
    # only as the fallback shape for a word without phones.
    n_rows = isempty(phones) ? 20 : size(first(phones).db, 1)
    encoding = Matrix{Float64}(undef, n_rows, length(phones))
    for (n, ph) in enumerate(phones)
        encoding[:, n] = vec(mean(ph.db, dims=2))
    end
    return times, durations, encoding
end
# Encode every word; each entry of `encodings` is a (times, durations, rates) tuple.
encodings = [rate_coding_word(word) for word in jwords]
times = [enc[1] for enc in encodings]
durations = [enc[2] for enc in encodings]
rates = [enc[3] for enc in encodings]

times[1] # starting times (in s?) of each phone in the first word
durations[1]
rates[1] # the rates to each population (rows) for each phone (columns) in the word
typeof(rates)
size(rates[2], 2)
# Normalise every phone (column) so that its rates sum to `total_rate`.
# New matrices are allocated on purpose: `copy(rates)` is a shallow copy, so
# writing into its elements would also overwrite the raw matrices in `rates`.
# NOTE(review): a column that sums to zero yields NaN -- confirm this cannot occur.
total_rate = 8
new_rates = [(rate ./ sum(rate, dims=1)) .* total_rate for rate in rates]
new_rates[1] # the rates to each population (rows) for each phone (columns) in the word, normalized to sum to 8
# repeat
all_times = repeat(times, repetitions)
all_durations = repeat(durations, repetitions)
all_rates = repeat(new_rates, repetitions)
all_labels = repeat(labels, repetitions)
# shuffle: one permutation is applied to all four vectors so they stay aligned.
using Random
ind = shuffle(eachindex(all_times))
# Index with `[ind]`, not `[ind, :]`: the trailing colon would turn each
# vector into an N×1 matrix.
all_times = all_times[ind]
all_durations = all_durations[ind]
all_rates = all_rates[ind]
all_labels = all_labels[ind]
# to do:
# stack
#
# NOTE(review): the draft of this step was an unterminated top-level loop. It
# pushed into `words_transcripts` / `phones_transcripts`, which are not defined
# in the visible part of this script, and read `.duration`, `.word` and
# `.phones` from the entries of `labels`. It is wrapped in a function so the
# file parses and nothing runs until those inputs exist.
"""
    stack_transcripts!(words_transcripts, phones_transcripts, labels, durations, silence_time)

Lay the words out one after another on a global time axis, separated by
`silence_time`, and append the resulting intervals and signs to the two
transcripts (each must have `intervals` and `sign` fields supporting `push!`).

Each label must provide `duration`, `word` and `phones`; each phone must provide
`t0`, `t1` and `ph`. Phone times are added to the word's start on the global axis.

Returns the global time reached after the last word and its trailing silence.
Throws `ArgumentError` if a label's duration differs from the paired entry of
`durations`.
"""
function stack_transcripts!(words_transcripts, phones_transcripts, labels, durations, silence_time)
    global_time = 0.0
    for (label, dd) in zip(labels, durations)
        label.duration == dd ||
            throw(ArgumentError("duration mismatch for word '$(label.word)': $(label.duration) != $dd"))
        push!(words_transcripts.intervals, (global_time, global_time + dd))
        push!(words_transcripts.sign, label.word)
        for ph in label.phones
            push!(phones_transcripts.intervals, (global_time + ph.t0, global_time + ph.t1))
            push!(phones_transcripts.sign, ph.ph)
        end
        global_time += dd + silence_time
    end
    return global_time
end
# make the stim
...@@ -228,7 +228,7 @@ def get_spectra(dataframe, target_words=[], cqt_p=BAE): ...@@ -228,7 +228,7 @@ def get_spectra(dataframe, target_words=[], cqt_p=BAE):
paths = dataframe.path paths = dataframe.path
if isinstance(dataframe.path,str): if isinstance(dataframe.path,str):
paths = [dataframe.path] paths = [dataframe.path]
print(paths) #print(paths)
for my_path in paths: for my_path in paths:
oscillogram, sr = librosa.load(my_path + ".wav") oscillogram, sr = librosa.load(my_path + ".wav")
...@@ -245,7 +245,7 @@ def get_spectra(dataframe, target_words=[], cqt_p=BAE): ...@@ -245,7 +245,7 @@ def get_spectra(dataframe, target_words=[], cqt_p=BAE):
duration = len(oscillogram) / sr duration = len(oscillogram) / sr
osc_sr = len(oscillogram) / duration osc_sr = len(oscillogram) / duration
db_sr = cqt.shape[1] / duration db_sr = cqt.shape[1] / duration
print(final_time/duration, TRANSCRIPT_SR) #print(final_time/duration, TRANSCRIPT_SR)
# %% # %%
words, word_times = [], [] words, word_times = [], []
......
...@@ -81,32 +81,32 @@ module SpikeTimit ...@@ -81,32 +81,32 @@ module SpikeTimit
# # , allowcomments=true, commentmark='%') # # , allowcomments=true, commentmark='%')
"""
    get_dialect(root)

Return the dialect directory name (the first path component starting with
`"dr"`, e.g. `"dr3"`) contained in the path `root`.

Both `\\` and `/` are accepted as separators, so the function works on paths
produced on any platform. Throws `ArgumentError` when no component starts
with `"dr"`.
"""
function get_dialect(root)
    parts = split(root, r"[\\/]")
    idx = findfirst(startswith("dr"), parts)
    idx === nothing &&
        throw(ArgumentError("no dialect directory (dr*) found in path: $root"))
    return String(parts[idx])
end
"""
    create_dataset(; dir)

Walk `dir` recursively and return a `DataFrame` with one row per `.wav` file.

Columns: `speaker` (name of the directory holding the file), `senID` (file name
without extension), `dialect` (from `get_dialect`), `gender` (first character
of the speaker name), `path` (directory joined with `senID`, no extension),
`words` and `phones` (from `get_words` / `get_phones`).
"""
function create_dataset(; dir)
    df = DataFrame(
        speaker = String[],
        senID = String[],
        dialect = String[],
        gender = Char[],
        path = String[],
        words = Array{Union{String, Float64}, 2}[],
        phones = Array{Union{String, Float64}, 2}[],
    )
    for (root, _, files) in walkdir(dir)
        for file in files
            endswith(file, ".wav") || continue
            # `basename` honours the platform separator; splitting on "\\"
            # only worked for Windows paths.
            speaker = basename(root)
            senID = first(splitext(file))
            words = get_words(root, senID)
            phones = get_phones(root, senID)
            dr = get_dialect(root)
            gender = first(speaker)
            push!(df, (speaker, senID, dr, gender, joinpath(root, senID), words, phones))
        end
    end
    return df
end
"""
    find_word(; df::DataFrame, word::AbstractString)

Search for `word` in the dataset and return the rows of `df` whose `:words`
entry contains it.
"""
function find_word(; df::DataFrame, word::AbstractString)
    return @linq df |> where(word .∈ :words)
end
######################## ########################
## Extract .mat files ## ## Extract .mat files ##
...@@ -379,10 +379,11 @@ module SpikeTimit ...@@ -379,10 +379,11 @@ module SpikeTimit
end end
function select_inputs(; df, words, samples=10, n_feat = 7) function select_inputs(; df, words, samples=10, n_feat = 7, random_seed=NaN)
all_spikes = [] all_spikes = []
all_durations = [] all_durations = []
all_labels = [] all_labels = []
new_df = []
for (i, word) in enumerate(words) for (i, word) in enumerate(words)
df_word = find_word(word=word, df=df) df_word = find_word(word=word, df=df)
...@@ -391,6 +392,9 @@ module SpikeTimit ...@@ -391,6 +392,9 @@ module SpikeTimit
#randomly order the number of occurences to sample #randomly order the number of occurences to sample
if samples <= n_occurences if samples <= n_occurences
if !isnan(random_seed)
Random.seed!(random_seed)
end
inds = randperm(n_occurences)[1:samples] inds = randperm(n_occurences)[1:samples]
else else
message = string("WARNING: for word: '", word, "', samples per word (", samples, ") exceeds the number of occurences (", n_occurences, ")") message = string("WARNING: for word: '", word, "', samples per word (", samples, ") exceeds the number of occurences (", n_occurences, ")")
...@@ -407,9 +411,15 @@ module SpikeTimit ...@@ -407,9 +411,15 @@ module SpikeTimit
push!(all_spikes, spiketimes...) push!(all_spikes, spiketimes...)
push!(all_durations, durations...) push!(all_durations, durations...)
push!(all_labels,labels...) push!(all_labels,labels...)
push!(new_df, df_word[inds,:])
# if i == 1
# new_df = df_word[inds,:]
# else
# append!(new_df, df_word[inds,:])
# end
end end
return all_durations, all_spikes, all_labels return new_df
end end
function mix_inputs(;durations, spikes, labels, repetitions, silence_time) function mix_inputs(;durations, spikes, labels, repetitions, silence_time)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment