using Base

module SpikeTimit
    using DataFrames
    using DataFramesMeta
    using OrderedCollections
    using MAT
    using Plots
    using ColorSchemes
    using Random
    using StatsBase
    using Distributions

    ## This is the sample rate used in the scripts.
    sr = 16000
    ## This is the rescaling factor for the spike-time discussed in the PDF
    correction = 5.

    import Plots: Series, Plot, Subplot


	"""
	Get words and time intervals from the SenID
	"""
    function get_words(root, senID)
      path = joinpath(root, senID*".wrd")
      times0 = []
      times1 = []
      words = []
      for line in readlines(path)
          t0,tf,w = split(line)
          push!(times0,parse(Int,t0))
          push!(times1,parse(Int,tf))
          push!(words, String(w))
      end
      _data = Array{Union{String,Float64},2}(undef,length(times0),3)
      _data[:,1] = words
      _data[:,2] = times0 ./ sr
      _data[:,3] = times1 ./ sr
      return _data
    end

	"""
	Get phonemes and time intervals from the SenID
	"""
    function get_phones(root, senID)
      path = joinpath(root, senID*".phn")
      times0 = []
      times1 = []
      phones = []
      for line in readlines(path)
          t0,tf,p = split(line)
          push!(times0,parse(Int,t0))
          push!(times1,parse(Int,tf))
          push!(phones, String(p))
      end
      _data = Array{Union{String,Float64},2}(undef,length(times0),3)
      _data[:,1] = phones
      _data[:,2] = times0 ./ sr
      _data[:,3] = times1 ./ sr
      return _data
    end
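
	#=
	Usage sketch (not part of the original module): read the word and phone annotations of a
	single utterance. The path and sentence ID below are hypothetical; `root` must contain the
	corresponding `senID.wrd` and `senID.phn` files.

	    root  = "TIMIT/train/dr1/fcjf0"     # hypothetical TIMIT directory
	    senID = "sa1"
	    words  = get_words(root, senID)     # rows of [word  t0  t1], times in seconds
	    phones = get_phones(root, senID)    # rows of [phone t0  t1], times in seconds
	=#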

    ## Create a dictionary mapping each word in the dataset to its phonetic transcription
    function create_dictionary(;file)
        dict = OrderedDict()
        for line in readlines(file)
            if !startswith(line, ";")
                word, sounds = split(line,"/")
                push!(dict, word=>sounds)
            end
        end
        return dict
    end
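
    #=
    Usage sketch (assumption: `dict_file` points at the TIMIT lexicon file, e.g. DOC/TIMITDIC.TXT):

        dict = create_dictionary(file=dict_file)   # OrderedDict: word => phonetic transcription
        dict["water"]                              # phone sequence of "water", if present
    =#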

# 	function create_speaker_dataset(;dir)
# df = DataFrame(dir)
# dir = "/home/cocconat/Documents/Research/phd_project/speech/litwin-kumar_model_thesis/DOC/SPKRINFO.TXT"
# using DataFrames
# DataFrames.readtable(dir)
# # , allowcomments=true, commentmark='%')

	function get_dialect(root)
		return splitpath(root) |> x->parse(Int,filter(startswith("dr"),x)[1][end])
	end

    function create_dataset(;dir)
        df = DataFrame(speaker = String[] , senID = String[], dialect=Int[], gender=Char[], path=String[], words=Array{Union{String, Float64},2}[], phones=Array{Union{String, Float64},2}[], sentence=Vector{String}[])
        for (root, dirs, files) in walkdir(dir)
            for file in files
                if endswith(file,"wav")
                    speaker = splitpath(root)[end]
                    senID   = split(file,".")[1]
                    words   = get_words(root, senID)
                    phones  = get_phones(root, senID)
                    dr = get_dialect(root)
                    gender = speaker[1]
                    sentence = String.(words[:,1])
                    push!(df,(speaker,senID,dr,gender,joinpath(root,senID),words,phones,sentence))
                end
            end
        end
        return df
    end
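
    #=
    Usage sketch (the directory is a hypothetical local TIMIT copy):

        train_df = create_dataset(dir="TIMIT/train")
        # one row per .wav file: speaker, senID, dialect, gender, path, words, phones, sentence
    =#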

    """
    Search for a word in the dataset and return the matching rows.
    """
    function find_word(;df::DataFrame, word::String)
        # a row matches when the word appears in its :sentence column (the utterance's word labels)
        return @linq df |> where(in.(word, :sentence))
    end

	########################
	## Extract .mat files ##
	########################

    function get_matrix(;df::DataFrame)
        get_matrix_file(file) =  file*"_binary_matrix.mat"
        return get_matrix_file.(df.path)
    end

    function get_file(file, ext)
        return file*"."*ext
    end

	Spiketimes = Vector{Vector{Float64}}

    function get_spiketimes(;df::DataFrame)
        get_spiketimes_file(file) = file*"_spike_timing.mat"
        get_array(value::Float64) = begin x = zeros(Float64, 1); x[1] = value; x end
        if length(size(df)) == 1 # df is a single row (DataFrameRow)
            spikes = [read(matopen(get_spiketimes_file(df.path)))["spike_output"][1,:]]
        else
            spikes = map(x->x[1,:], get_spiketimes_file.(df.path) |> x->matopen.(x) |> x->read.(x) |> x->get.(x,"spike_output", nothing))
        end
        # normalize the entries read from the .mat files: matrices become flat vectors,
        # lone Float64 values become one-element vectors
        map(spike->map(row->spike[row] = spike[row][:], findall(typeof.(spike) .== Array{Float64,2})), spikes)
        map(spike->map(row->spike[row] = get_array(spike[row]), findall(typeof.(spike) .== Float64)), spikes)
        # rescale all spike times by the `correction` factor defined at the top of the module
        map(spike->findall(isa.(spike,Array{Float64,1})) |> x->spike[x] = spike[x]*correction, spikes)
		return map(x->Spiketimes(x), spikes)
    end
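
	#=
	Usage sketch: load the spike encodings (from the companion *_spike_timing.mat files) for
	all entries whose transcription contains a given word. Assumes `train_df` was built with
	create_dataset and that the .mat files sit next to each .wav file; the word is hypothetical.

	    df_word    = find_word(df=train_df, word="water")
	    spiketimes = get_spiketimes(df=df_word)   # Vector{Spiketimes}, times rescaled by `correction`
	=#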


	########################
	## Stack spikes together
	########################

	"""
alessio.quaresima's avatar
alessio.quaresima committed
147
148
	Extract all the firing times and the corresponding neurons from an array with
	all the neurons and their relative firing times. i.e. the inverse_dictionary
alessio.quaresima's avatar
alessio.quaresima committed
149
	"""
	function inverse_dictionary(spikes::Spiketimes)
		all_times = Dict()
		for n in eachindex(spikes)
			if !isempty(spikes[n])
				for tt in spikes[n]
					# tt = round(Int,t*1000/dt) ## from seconds to timesteps
					if haskey(all_times,tt)
						all_times[tt] = [all_times[tt]..., n]
					else
						push!(all_times, tt=>[n])
					end
				end
			end
		end
		return all_times
	end


	"""
	From the inverse_dictionary data structure obtain 2 arrays that are faster to access in the simulation loop.
	1. First array contains the sorted spike times.
	2. The second array contains vectors with each firing neuron
	"""
	function sort_spikes(dictionary)
		neurons = Array{Vector{Int}}(undef, length(keys(dictionary)))
		sorted = sort(collect(keys(dictionary)))
		for (n,k) in enumerate(sorted)
			neurons[n] = dictionary[k]
		end
		return sorted, neurons
	end
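
	#=
	Usage sketch: convert one Spiketimes entry (per-neuron spike lists) into the two arrays
	used by the simulation loop. `spiketimes` is assumed to come from get_spiketimes.

	    dict               = inverse_dictionary(spiketimes[1])  # time => neurons firing at that time
	    sorted_ft, neurons = sort_spikes(dict)                  # sorted times and, per time, the firing neurons
	=#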


	"""
	Stack together spiketime sequences:
		- spikes: a vector of Spiketimes, one entry per encoding
		- durations: the duration in seconds of each encoding
		- silence_time: the silent gap in seconds inserted between consecutive encodings
	"""
	function stack_spiketimes(spikes::Vector{Spiketimes}, durations::Vector{Float64}, silence_time::Float64)
        # for the memory allocation
        nr_unique_fts = 0
        for spike_times in spikes
            nr_unique_fts +=length(inverse_dictionary(spike_times))
        end
		all_neurons = Vector{Vector{Int}}(undef, nr_unique_fts)
		all_ft = Vector{Float64}(undef, nr_unique_fts)

        global_time = 0
        filled_indices = 0
		for (spike_times, dd) in zip(spikes, durations)
			# get spiketimes
			sorted, neurons = sort_spikes(inverse_dictionary(spike_times))
			#shift time for each neuron:
			sorted .+= global_time


            ## put them together
            lower_bound = filled_indices + 1
            filled_indices += size(sorted,1)
            all_ft[lower_bound:filled_indices] = sorted
            all_neurons[lower_bound:filled_indices] = neurons

			global_time += dd
			global_time += silence_time

		end
        @assert(size(all_ft) == size(all_neurons))
		return all_ft, all_neurons
	end
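
	#=
	Usage sketch: concatenate two encodings in time, separated by a silent gap. `spikes` and
	`durations` are assumed to be a Vector{Spiketimes} and a Vector{Float64} (e.g. as returned
	by select_inputs); the silence value is hypothetical.

	    silence_time  = 0.1   # seconds
	    all_ft, all_n = stack_spiketimes(spikes[1:2], durations[1:2], silence_time)
	=#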


	function stack_labels(labels, durations, silence_time)
		phones_transcripts = Transcription()
		words_transcripts = Transcription()
		global_time = 0.
		for (label, dd) in zip(labels, durations)
			@assert(label.duration == dd)
			push!(words_transcripts.intervals,(global_time, global_time+dd))
			push!(words_transcripts.sign,label.word)
			for ph in label.phones
				push!(phones_transcripts.intervals, (global_time+ph.t0, global_time+ph.t1))
				push!(phones_transcripts.sign,ph.ph)
			end
			global_time += dd + silence_time
		end
		return words_transcripts,phones_transcripts
	end

	##########################
	## Get words and phonemes
	##########################

	struct Phone
		ph::String
		t0::Float64
		t1::Float64
		osc::Array{Float64}
		db::Matrix{Float64}
		function Phone(ph::String, t0::Float64, t1::Float64)
			new(ph, t0, t1, zeros(1), zeros(1,1))
		end
		function Phone(ph::String, t0::Float64, t1::Float64, osc::Vector{Float64}, db::Matrix{Float64})
			new(ph, t0, t1, osc, db)
		end
	end

	struct Word
		word::String
		phones::Vector{Phone}
		duration::Float64
		t0::Float64
		t1::Float64
	end

	struct Transcription
		intervals::Vector{Tuple{Float64,Float64}}
		steps::Vector{Tuple{Int,Int}}
		sign::Vector{String}
		function Transcription()
			new([],[],[])
		end
	end

	function convert_to_dt(data::Transcription,dt::Float64)
		for n in eachindex(data.intervals)
			t0,t1 = data.intervals[n]
			push!(data.steps,(round(Int,t0*1000/dt), round(Int, t1*1000/dt)))
		end
		return data
	end

	function convert_to_dt(data::Vector{Float64},dt::Float64)
		return round.(Int,data ./dt .*1000)
	end
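
	#=
	Usage sketch: convert transcriptions and firing times from seconds into integer timesteps,
	for a hypothetical simulation step of dt = 0.1 ms.

	    dt      = 0.1
	    words_t = convert_to_dt(words_t, dt)   # fills words_t.steps with (start, stop) timesteps
	    all_ft  = convert_to_dt(all_ft, dt)    # Vector{Int} of timesteps
	=#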

	"""
	Extract the words (and their phones) in the dataset that match the target word.
	Each returned row corresponds to a df entry and contains all matching word labels, with their phones and time intervals.
	"""
	function get_word_labels(;df, word::String)
		df_phones = Vector{Vector{Word}}()
		for row in eachrow(df)
			all_phones = Vector{Word}()
			for my_word in eachrow(row.words)
				if my_word[1] == word
					t0,t1 = my_word[2:3]
					word_phones = Vector{Phone}()
					for phone  in eachrow(row.phones)
						if (phone[2] >= t0) && (phone[3]<= t1)
							ph = Phone(phone[1], phone[2]-t0,phone[3]-t0)
							push!(word_phones, ph)
						end
					end
					push!(all_phones, Word(String(my_word[1]), word_phones,t1-t0,t0,t1))
				end
			end
			push!(df_phones, all_phones)
		end
		return df_phones
	end

	function get_spikes_in_word(; df, word::String)
        spikes, durations = get_spikes_in_interval(spiketimes = get_spiketimes(df=df), df_intervals = get_interval_word(df=df, word=word))
		return spikes, durations
	end
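
	#=
	Usage sketch: extract, for every occurrence of a target word in `df_word`, the spike subset
	of that word (times shifted to the word onset) and its duration in seconds. `df_word` and the
	word are hypothetical, e.g. the output of find_word.

	    spikes, durations = get_spikes_in_word(df=df_word, word="water")
	=#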

	"""
	For each realization of a word in a dataset entry, extract the corresponding time interval.
	Return all the intervals for each dataset entry.
	"""
	function get_interval_word(;df, word::String)
		df_intervals = []
		for row in eachrow(df)
			intervals = []
			interval = [0.,0.]
			n = 0
			for my_word in eachrow(row.words)
				if my_word[1] == word
					interval = my_word[2:3]
					push!(intervals, interval)
				end
			end
			push!(df_intervals, intervals)
		end
		return df_intervals
	end

	"""
	Return the spiketimes subset corresponding to the selected interval, for vectors of Spiketimes
	"""
	function get_spikes_in_interval(; spiketimes::Union{Spiketimes, Array{Spiketimes}}, df_intervals)
		new_spiketimes = Vector{Spiketimes}()
		durations = Vector()
		if isa(spiketimes,Spiketimes)
			for intervals in df_intervals
				for interval in intervals
					new_spikes, duration  = _get_spikes_in_interval(spiketimes, interval)
					push!(new_spiketimes, new_spikes)
					push!(durations, duration)
				end
			end
		else
			@assert(length(spiketimes) == length(df_intervals))
			for (spikes, intervals) in zip(spiketimes, df_intervals)
				for interval in intervals
					new_spikes, duration  = _get_spikes_in_interval(spikes, interval)
					push!(new_spiketimes, new_spikes)
					push!(durations, duration)
				end
			end
		end
		return new_spiketimes, durations
	end


	"""
	Return the spiketimes subset corresponding to the selected interval, for one Spiketimes
	"""
	function _get_spikes_in_interval(spikes, interval)
		new_spiketimes=Spiketimes()
		for neuron in spikes
			neuron_spiketimes = Vector()
			for st in eachindex(neuron)
				if (neuron[st] > interval[1]) && (neuron[st] < interval[end])
					push!(neuron_spiketimes, neuron[st]-interval[1])
				end
			end
			push!(new_spiketimes, neuron_spiketimes)
		end
		return new_spiketimes, interval[end] - interval[1]
	end


    function select_inputs(; df, words, samples=10, n_feat = 7)
        all_spikes = Vector{Spiketimes}()
        all_durations = Vector{Float64}()
        all_labels = []

        for (i, word) in enumerate(words)
            df_word = find_word(word=word, df=df)
            n_occurences = size(df_word,1)
			#@show word, n_occurences

            # randomly choose which occurrences of the word to sample
            if samples <= n_occurences
                inds = randperm(n_occurences)[1:samples]
            else
                error("for word '$word', samples per word ($samples) exceeds the number of occurrences ($n_occurences)")
            end
			###Get intervals and phonemes for each dataset entry (some can have more than one!)
			spiketimes, durations = get_spikes_in_word(; df=df_word[inds,:], word)
			spiketimes = resample_spikes(spiketimes=spiketimes, n_feat=n_feat)
			labels = vcat(get_word_labels(;df=df_word[inds,:], word=word)...)
			#@show length(labels), length(spiketimes)

			@assert(length(spiketimes) == length(labels))

			push!(all_spikes, spiketimes...)
			push!(all_durations, durations...)
			push!(all_labels,labels...)

		end
		return all_durations, all_spikes, all_labels
	end

    function mix_inputs(;durations, spikes, labels, repetitions, silence_time)
		ids = shuffle(repeat(1:length(durations), repetitions))

		all_ft, all_n = stack_spiketimes(spikes[ids], durations[ids], silence_time)
		words_t, phones_t = SpikeTimit.stack_labels(labels[ids],durations[ids],silence_time)
		return all_ft, all_n, words_t, phones_t
	end
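
	#=
	Usage sketch of the full input pipeline (word list, sample count, repetitions and silence
	are hypothetical values; `train_df` comes from create_dataset):

	    words = ["water", "greasy"]
	    durations, spikes, labels = select_inputs(df=train_df, words=words, samples=5, n_feat=7)
	    all_ft, all_n, words_t, phones_t = mix_inputs(durations=durations, spikes=spikes,
	                                                  labels=labels, repetitions=3, silence_time=0.1)
	=#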

	function get_savepoints(;trans::Transcription, n_measure::Int)
		measures = Array{Int64,2}(undef, size(trans.steps,1), n_measure)
		for (i,step) in enumerate(trans.steps)
			l = step[2] - step[1]
			l_single = floor(Int, l/n_measure)
			measures[i,1:n_measure] = (step[1] .+ collect(1:n_measure).* l_single)
			# push!(measures,step[1] .+ collect(1:n_measure).* l_single)
		end
		return measures
	end
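
	#=
	Usage sketch: after converting the word transcription to timesteps, pick n_measure equally
	spaced save points inside each word interval (values are hypothetical).

	    words_t    = convert_to_dt(words_t, dt)
	    savepoints = get_savepoints(trans=words_t, n_measure=10)
	=#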

    """
	Unify frequency bins.
	Input: a Spiketimes array with 620 entries (31 bins x 20 frequencies).
	Return an array with the spikes sorted into fewer bins per frequency.

    nfeat   new_bins    rounding    last_bin
    2       15.5        <--         2*15=30   add remainder to last bin
    3       10.333      <--         3*10=30   add remainder to last bin
    4       7.75        -->         4*8 =32
    5       6.2         <--         5*6 =30   add remainder to last bin
    6       5.167       <--         6*5 =30   add remainder to last bin
    7       4.429       <--         7*4 =28   add last 3 to last bin
    8       3.875       -->         8*4 =32
    9       3.444       <--         9*3 =27   add last 4 to last bin
    10      3.1         <--         10*3=30   add remainder to last bin
    11      2.818       -->         11*3=33

	"""
	function resample_spikes(;spiketimes::Vector{Spiketimes},n_feat)
		for s in 1:length(spiketimes)
			spiketimes[s] = _resample_spikes(spiketimes=spiketimes[s], n_feat=n_feat)
		end
		return spiketimes
	end

	function _resample_spikes(;spiketimes::Spiketimes, n_feat)
        # If we don't reduce the bins
        if n_feat == 1
            return spiketimes
        elseif n_feat > 11 || n_feat < 1
            println("WARNING: n_feat must be between 1 and 11; returning original spike_times")
            return spiketimes
        end

		FREQUENCIES = 20

        old_bins = convert(Int64, length(spiketimes)/FREQUENCIES)
        @assert (old_bins==31) "WARNING: old_bins != 31, this function is probably broken for other values than 31 (super hardcoded)"
        new_bins = round(Int, old_bins/n_feat - 0.1)
        add_last = 0
        if n_feat*new_bins < 31
            add_last = 31-n_feat*new_bins
        end
		new_spikes = map(x->Vector{Float64}(),1:new_bins*FREQUENCIES)

        for freq in 1:FREQUENCIES
            old_freq = (freq-1)*old_bins
            new_freq = (freq-1)*new_bins
            for new_bin in 1:new_bins
                last_bin = new_bin*n_feat <32 ? new_bin*n_feat : 31
                if new_bin == new_bins
                    last_bin += add_last
                end
                bins = 1+(new_bin-1)*n_feat : last_bin
                for old_bin in bins
                    push!(new_spikes[new_bin+new_freq], spiketimes[old_bin+old_freq]...)
                end
            end
        end
		return sort!.(new_spikes)
	end
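
	#=
	Usage sketch: merge the 31 bins of each of the 20 frequency channels into fewer bins,
	e.g. n_feat = 7 gives 4 bins per frequency (620 input channels -> 80 channels).
	`spikes` is assumed to be a Vector{Spiketimes}.

	    spikes7 = resample_spikes(spiketimes=spikes, n_feat=7)
	=#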



	"""
	Get the firing times for each phone list.
	A phone list contains all the relevant phones (and their time intervals) in a given df entry.
	"""
	function get_phones_spiketimes(spikes, phones_list)
		phone_spiketimes = []
		phone_labels = []
		for (phones, spike) in zip(phones_list, spikes)
			intervals=[]
			for phone in phones
				push!(intervals,phone[2:3])
				push!(phone_labels,phone[1])
			end
			spiketime, duration = SpikeTimit.get_spikes_in_interval(; spiketimes=spike, df_intervals=[intervals])
			push!(phone_spiketimes,spiketime)
		end
		return phone_labels, vcat(phone_spiketimes...)
	end


	function transform_into_bursts(all_ft, all_neurons; spikes_per_burst_increase=0)
        new_all_ft = []
        new_all_neurons = []
        expdist = Exponential(5)
        for (i, time) in enumerate(all_ft)
            # determine X (amount of spikes in the burst) -> biased dice
            values = [2,3,4,5,6] .+ spikes_per_burst_increase
            weights = [0.8, 0.15, 0.075, 0.035, 0.03] # based on plot 1B 0.7 nA (Oswald, Doiron & Maler (2007))
            weights = weights ./ sum(weights) # normalized weights
            number_of_spikes = sample(values, Weights(weights)) - 1 # -1 because the first spike is determined from the data
            # determine the interval from time to spike (for all X new spikes)
            push!(new_all_ft, time)
            push!(new_all_neurons, all_neurons[i])
            for j in 1:number_of_spikes
                interval = rand(expdist)
                new_time = time + 4 + interval
                push!(new_all_ft, new_time)
                push!(new_all_neurons, all_neurons[i])
            end
        end
        # rounding
        for i in 1:size(new_all_ft,1)
            new_all_ft[i] = round(Int, new_all_ft[i])
            new_all_ft[i] = convert(Float64, new_all_ft[i])
        end

        # sorting
        zipped = DataFrame(ft = new_all_ft, neurons = new_all_neurons)
        zipped = sort!(zipped, [:ft])
        # if two rows have same time combine into 1 row with both their neurons
        for (i, row) in enumerate(eachrow(zipped))
            if i != size(zipped,1)
                if row.ft == zipped.ft[i+1]     #compare current spike time with next time
                    next_row = copy(zipped.neurons[i+1])    #spiking neurons of next row
                    new_row = vcat(copy(row.neurons), copy(next_row))   # concatenating current neurons + next neurons
                    zipped.neurons[i+1] = copy(new_row)   # assigning all neurons to next time
                    zipped.ft[i] = -1.0    # setting time of the row to -1 (so filter can take it out later)
                end
            end
        end
        zipped = filter(row -> row[:ft] != -1.0, zipped)
        return zipped.ft, zipped.neurons
    end
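
    #=
    Usage sketch: replace every stacked spike with a short burst (spike count drawn from the
    biased dice above, intra-burst intervals drawn from Exponential(5)). Firing times are
    assumed to be expressed in timesteps at this point, since the result is rounded to integers.

        all_ft, all_n = transform_into_bursts(all_ft, all_n)
    =#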

	########################
    ## Raster Plot 		####
	########################

    struct TracePlot{I,T}
        indices::I
        plt::Plot{T}
        sp::Subplot{T}
        series::Vector{Series}
    end
    function TracePlot(n::Int = 1; maxn::Int = typemax(Int), sp = nothing, kw...)
        clist= get(ColorSchemes.colorschemes[:viridis],range(0,1,length=n))
        indices = if n > maxn
            shuffle(1:n)[1:maxn]
        else
            1:n
        end
        if sp == nothing
            plt = scatter(length(indices);kw...)
            sp = plt[1]
        else
            plt = scatter!(sp, length(indices); kw...)
        end
        for n in indices
            sp.series_list[n].plotattributes[:markercolor]=clist[n]
            sp.series_list[n].plotattributes[:markerstrokecolor]=clist[n]
        end
        TracePlot(indices, plt, sp, sp.series_list)
    end

    function Base.push!(tp::TracePlot, x::Number, y::AbstractVector)
        push!(tp.series[x], y, x .*ones(length(y)))
    end
    Base.push!(tp::TracePlot, x::Number, y::Number) = push!(tp, x, [y])

    function raster_plot(spikes::Spiketimes; ax=nothing,  alpha=0.5)
        s = TracePlot(length(spikes); legend=false, alpha=alpha)
		ax = isnothing(ax) ? s : ax
        for (n,z) in enumerate(spikes)
            if length(z)>0
                push!(ax,n,z[:])
            end
        end
        return ax
    end
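
    #=
    Usage sketch: raster plot of one Spiketimes entry (one colored series per input channel).
    `spikes` is assumed to be a Vector{Spiketimes}.

        tp = raster_plot(spikes[1])
        display(tp.plt)
    =#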

end