Commit d066a35c authored by alessio.quaresima

Merge branch 'librosa' into 'librosa'

load_spectra works on windows and our SpikeTimit file

See merge request alessio.quaresima1/spiketimit.jl!1
parents ae2c0e91 0180f130
......@@ -3,24 +3,26 @@ using DataFrames
using DataFramesMeta
using Pandas
-cd(@__DIR__)
+cd(joinpath(@__DIR__,".."))
py"""
import sys
import os
sys.path.insert(0, os.getcwd())
print(sys.path)
"""
pwd()
TIMIT = pyimport("TIMIT_loader")
pyimport("importlib")."reload"(TIMIT)
path = "/home/cocconat/Documents/Research/phd_project/speech/litwin-kumar_model_thesis/Spike TIMIT"
path = "C:\\Users\\leoni\\Desktop\\3rd_year_AI\\1_Thesis\\litwin-kumar_model_thesis\\Spike TIMIT"
dataset = TIMIT.create_dataset(joinpath(path,"train"))
spkrinfo, spkrsent = TIMIT.create_spkrdata(path)
# dataset |> Pandas.DataFrame |> DataFrames.DataFrame
##
include("src/SpikeTimit.jl")
include("../src/SpikeTimit.jl")
#Create the path strings leading to folders in the data set
test_path = joinpath(path, "test");
......
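Review note: the lines above wire Julia to the Python loader through PyCall. Inserting the working directory into Python's sys.path is what lets pyimport("TIMIT_loader") find the module next to the script, and the importlib reload picks up edits to the .py file without restarting Julia. A minimal sketch of the same pattern, with a hypothetical module name:

```julia
using PyCall

cd(joinpath(@__DIR__, ".."))     # run from the repository root
py"""
import sys, os
sys.path.insert(0, os.getcwd())  # make sibling .py files importable
"""
mod = pyimport("my_module")              # hypothetical module name
pyimport("importlib")."reload"(mod)      # re-read the file after edits
```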
......@@ -73,7 +73,7 @@ def count_dataset(path, dtype="train"):
yield 1
def get_dialect(path):
return int(path.split("/")[-3][-1])
return int(path.split("\\")[-3][-1])
# | > x->parse(Int, filter(startswith("dr"), x)[1][end])
......
......@@ -81,7 +81,7 @@ module SpikeTimit
# # , allowcomments=true, commentmark='%')
function get_dialect(root)
return split(root,"/") |> x->parse(Int,filter(startswith("dr"),x)[1][end])
return split(root,"\\") |> x->parse(Int,filter(startswith("dr"),x)[1][end])
end
function create_dataset(;dir)
......@@ -89,14 +89,14 @@ module SpikeTimit
for (root, dirs, files) in walkdir(dir)
for file in files
if endswith(file,"wav")
speaker = split(root, "/")[end]
speaker = split(root, "\\")[end]
senID = split(file,".")[1]
words = get_words(root, senID)
phones = get_phones(root, senID)
-dr = get_dialect(root)
-gender = speaker[1]
+dr = get_dialect(root)
+gender = speaker[1]
sentence = String.(words[:,1])
-push!(df,(speaker,senID,dr,gender,joinpath(root,senID),words,phones, sentence))
+push!(df,(speaker,senID,dr,gender,joinpath(root,senID),words,phones,sentence))
end
end
end
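Review note: replacing "/" with "\\" in get_dialect and in the speaker split above fixes Windows but breaks the Linux setup the repository ran on before. A hedged, portable alternative (not part of this MR): Base.splitpath applies the platform's own separator rules, so the same code runs on both systems; the function name below is hypothetical.

```julia
# Portable sketch: splitpath splits on the running platform's separators,
# so the "drN" component is found without hardcoding "/" or "\\".
function get_dialect_portable(root)      # hypothetical name
    dr = filter(startswith("dr"), splitpath(root))[1]
    return parse(Int, dr[end])
end
```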
......@@ -185,7 +185,7 @@ module SpikeTimit
- spikes is an array with inputs in the form Vector
- durations is the duration in seconds of each encoding
"""
-function stack_spiketimes(spikes::Vector{Spiketimes}, durations::Vector{Float64}, silence_time::Float64)
+function stack_spiketimes(spikes, durations, silence_time::Float64)
# for the memory allocation
nr_unique_fts = 0
for spike_times in spikes
......@@ -201,13 +201,13 @@ module SpikeTimit
sorted, neurons = sort_spikes(inverse_dictionary(spike_times))
#shift time for each neuron:
sorted .+= global_time
-## put them together
-lower_bound = filled_indices +1
+## put them together
+lower_bound = filled_indices + 1
filled_indices += size(sorted,1)
all_ft[lower_bound:filled_indices] = sorted
all_neurons[lower_bound:filled_indices] = neurons
global_time += dd
global_time += silence_time
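Review note: the loop above implements what the docstring promises: every word's spike times are shifted by the running global time, and silence_time seconds are inserted after each word. A toy illustration with assumed numbers:

```julia
# Two words of 0.5 s and 0.3 s, separated by 0.1 s of silence.
durations    = [0.5, 0.3]
silence_time = 0.1
spikes       = [[0.05, 0.20, 0.45], [0.02, 0.25]]  # assumed spike times

global_time = 0.0
all_ft = Float64[]
for (spk, dd) in zip(spikes, durations)
    append!(all_ft, spk .+ global_time)  # shift into the global timeline
    global_time += dd + silence_time     # advance past word + silence
end
# all_ft == [0.05, 0.20, 0.45, 0.62, 0.85]
```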
......@@ -234,21 +234,6 @@ module SpikeTimit
return words_transcripts,phones_transcripts
end
#="""
Return the input of the spike trains corrsponding to the rows, in the (spike-time -> neuron ) format
"""
function inputs_from_df(df, rows)
_df = df[rows,:]
## Get the duration of each frame: it corresponds to the (last) symbol h# in the phones array
durations = _df.phones |> phones->map(phone->phone[end], phones)
spikes = SpikeTimit.get_spiketimes(df=_df)
all_ft, all_n = stack_spiketimes(spikes, durations)
@assert(size(all_ft) == size(all_n))
return all_ft, all_n, durations
end
=#
##########################
## Get words and phonemes
##########################
......@@ -314,7 +299,7 @@ module SpikeTimit
push!(word_phones, ph)
end
end
-push!(all_phones, Word(String(my_word[1]), word_phones,t1-t0, t0, t1))
+push!(all_phones, Word(String(my_word[1]), word_phones,t1-t0,t0,t1))
end
end
push!(df_phones, all_phones)
......@@ -395,17 +380,17 @@ module SpikeTimit
function select_inputs(; df, words, samples=10, n_feat = 7)
-all_spikes = Vector{Spiketimes}()
-all_durations = Vector{Float64}()
+all_spikes = []
+all_durations = []
all_labels = []
for (i, word) in enumerate(words)
df_word = find_word(word=word, df=df)
n_occurences = size(df_word,1)
-@show n_occurences
+#@show word, n_occurences
#randomly order the number of occurences to sample
-if samples <= n_occurences
+if samples <= n_occurences
inds = randperm(n_occurences)[1:samples]
else
message = string("WARNING: for word: '", word, "', samples per word (", samples, ") exceeds the number of occurences (", n_occurences, ")")
......@@ -415,7 +400,7 @@ module SpikeTimit
spiketimes, durations = get_spikes_in_word(; df=df_word[inds,:], word)
spiketimes = resample_spikes(spiketimes=spiketimes, n_feat=n_feat)
labels = vcat(get_word_labels(;df=df_word[inds,:], word=word)...)
-@show length(labels), length(spiketimes)
+#@show length(labels), length(spiketimes)
@assert(length(spiketimes) == length(labels))
......@@ -429,24 +414,20 @@ module SpikeTimit
function mix_inputs(;durations, spikes, labels, repetitions, silence_time)
ids = shuffle(repeat(1:length(durations), repetitions))
all_ft, all_n = stack_spiketimes(spikes[ids], durations[ids], silence_time)
words_t, phones_t = SpikeTimit.stack_labels(labels[ids],durations[ids],silence_time)
return all_ft, all_n, words_t, phones_t
end
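Review note: mix_inputs builds its presentation order with a single index trick: repeat every sample `repetitions` times, then shuffle, so the copies interleave instead of clustering. Sketch:

```julia
using Random

n, repetitions = 3, 2
ids = shuffle(repeat(1:n, repetitions))  # e.g. [2, 1, 3, 3, 1, 2]
# spikes[ids] and durations[ids] then present each sample twice, in random order
```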
function get_savepoints(;trans::Transcription, n_measure::Int)
-measures = []
-labels = []
-for s in eachindex(trans.steps)
-step = trans.steps[s]
-sign = trans.sign[s]
+measures = Array{Int64,2}(undef, size(trans.steps,1), n_measure)
+for (i,step) in enumerate(trans.steps)
l = step[2] - step[1]
l_single = floor(Int, l/n_measure)
-push!(measures,step[1] .+ collect(1:n_measure).* l_single)
-push!(labels, repeat([sign], n_measure))
+measures[i,1:n_measure] = (step[1] .+ collect(1:n_measure).* l_single)
+# push!(measures,step[1] .+ collect(1:n_measure).* l_single)
end
-return measures, labels
+return measures
end
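Review note: the rewrite above swaps the growing `measures`/`labels` vectors for a preallocated steps × n_measure matrix and drops the labels output, so get_savepoints now returns a single matrix with one row of evenly spaced save points per step. A toy check with assumed values:

```julia
# One step spanning indices 100..160, sampled at 3 points.
step, n_measure = (100, 160), 3
l_single   = floor(Int, (step[2] - step[1]) / n_measure)    # 20
savepoints = step[1] .+ collect(1:n_measure) .* l_single    # [120, 140, 160]
```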
"""
......@@ -477,10 +458,10 @@ module SpikeTimit
function _resample_spikes(;spiketimes::Spiketimes, n_feat)
# If we don't reduce the bins
if n_feat == 1
-return spike_times
+return spiketimes
elseif n_feat > 11 || n_feat < 1
prinln("WARNING; you are crazy, returning original spike_times")
return spike_times
println("WARNING; you are crazy, returning original spike_times")
return spiketimes
end
FREQUENCIES = 20
......@@ -533,13 +514,13 @@ module SpikeTimit
end
-function transform_into_bursts(all_ft, all_neurons)
+function transform_into_bursts(all_ft, all_neurons; spikes_per_burst_increase=0)
new_all_ft = []
new_all_neurons = []
expdist = Exponential(5)
for (i, time) in enumerate(all_ft)
# determine X (amount of spikes in burst) -> bias dice
-values = [2,3,4,5,6]
+values = [2,3,4,5,6] .+ spikes_per_burst_increase
weights = [0.8, 0.15, 0.075, 0.035, 0.03] # based on plot 1B 0.7 nA (Oswald, Doiron & Maler (2007))
weights = weights ./ sum(weights) # normalized weights
number_of_spikes = sample(values, Weights(weights)) - 1 # -1 because first spike is determined from data
......
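Review note: the new spikes_per_burst_increase keyword shifts the whole burst-size distribution upward while keeping the published weights. A sketch of the draw as the diff defines it; the Exponential(5) scale comes from the code above and its unit (ms) is an assumption:

```julia
using Distributions, StatsBase

spikes_per_burst_increase = 0
values  = [2, 3, 4, 5, 6] .+ spikes_per_burst_increase   # spikes per burst
weights = [0.8, 0.15, 0.075, 0.035, 0.03]  # plot 1B, 0.7 nA (Oswald, Doiron & Maler, 2007)
weights = weights ./ sum(weights)          # normalize the biased dice
number_of_spikes = sample(values, Weights(weights)) - 1  # first spike comes from the data
isi = rand(Exponential(5), number_of_spikes)  # intra-burst intervals (assumed ms)
```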