# hrt_pipe_sub.py
# (Text recovered from a repository blame view; last committed by dcalc.)
import numpy as np
from astropy.io import fits
from scipy.ndimage import gaussian_filter
from operator import itemgetter
from utils import *
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import os
import time
import subprocess

def demod_hrt(data,pmp_temp):
    '''
    Use constant demodulation matrices to demodulate input data.

    The 4x4 demodulation matrix is applied with ``np.matmul`` to the last two
    axes of each pixel's data, i.e. it mixes axis 2 of a
    (rows, cols, pol, wave[, scan]) array (layout assumed from the indexing
    used by callers in this module - confirm).

    Parameters
    ----------
    data : numpy.ndarray
        4-D array (a single scan) or 5-D array (several scans, scan index on
        the last axis).
    pmp_temp : str or number
        PMP temperature in degrees. Only '40' and '50' have a stored matrix;
        integers and integral floats (e.g. 50 or 50.0) are also accepted.

    Returns
    -------
    data : numpy.ndarray
        Demodulated data with the same shape as the input.
    demod : numpy.ndarray
        Demodulation matrix tiled to shape (rows, cols, 4, 4).

    Raises
    ------
    ValueError
        If no matrix is stored for ``pmp_temp`` or ``data`` is not 4-D/5-D.
    '''
    matrices = {
        '50': [[ 0.28037298,  0.18741922,  0.25307596,  0.28119895],
               [ 0.40408596,  0.10412157, -0.7225681,   0.20825675],
               [-0.19126636, -0.5348939,   0.08181918,  0.64422774],
               [-0.56897295,  0.58620095, -0.2579202,   0.2414017 ]],
        '40': [[ 0.26450154,  0.2839626,   0.12642948,  0.3216773 ],
               [ 0.59873885,  0.11278069, -0.74991184,  0.03091451],
               [ 0.10833212, -0.5317737,  -0.1677862,   0.5923593 ],
               [-0.46916953,  0.47738808, -0.43824592,  0.42579797]],
    }

    # Accept 50 / 50.0 as well as the historical string form '50'.
    key = str(pmp_temp)
    if isinstance(pmp_temp, float) and pmp_temp.is_integer():
        key = str(int(pmp_temp))

    if key not in matrices:
        printc(f"Demodulation Matrix for PMP TEMP of {pmp_temp} deg is not available", color = bcolors.FAIL)
        raise ValueError(f"No demodulation matrix available for PMP TEMP of {pmp_temp} deg")

    if data.ndim not in (4, 5):
        raise ValueError(f"data must be 4-D (one scan) or 5-D (several scans), got {data.ndim}-D")

    printc(f'Using a constant demodulation matrix for a PMP TEMP of {pmp_temp} deg',color = bcolors.OKGREEN)

    demod_data = np.array(matrices[key])
    shape = data.shape
    # One copy of the matrix per pixel; returned so callers can invert it.
    demod = np.tile(demod_data, (shape[0],shape[1],1,1))

    if data.ndim == 5:
        # Move scans to the front so matmul broadcasts the per-pixel matrices
        # over every scan, then move them back to the end.
        data = np.moveaxis(data,-1,0)
        data = np.matmul(demod,data)
        data = np.moveaxis(data,0,-1)
    else:
        data = np.matmul(demod,data)

    return data, demod


def unsharp_masking(flat,sigma,flat_pmp_temp,cpos_arr,clean_mode,pol_end=4):
    """
    Smooth the demodulated flat fields to blur out polarimetric structures
    caused by solar rotation.

    The flat is demodulated with ``demod_hrt`` and normalised by the mean
    Stokes I at the continuum position over the central half of the frame.
    For each selected Stokes parameter and each non-continuum wavelength the
    demodulated flat is clipped to [-0.02, 0.02] and replaced by its
    Gaussian-smoothed version (the low-pass part of the unsharp mask,
    ``a - (a - blur(a)) == blur(a)``). The result is de-normalised and
    re-modulated with the inverse demodulation matrix.

    Parameters
    ----------
    flat : numpy.ndarray
        Modulated flat, indexed [:, :, pol, wave] (4-D).
    sigma : float or sequence of float
        Kernel width passed to ``scipy.ndimage.gaussian_filter``.
    flat_pmp_temp : str or number
        PMP temperature forwarded to ``demod_hrt``.
    cpos_arr : sequence of int
        ``cpos_arr[0]`` is the continuum wavelength index; must be 0 or 5.
    clean_mode : str
        "QUV", "UV" or "V": which Stokes parameters are smoothed.
    pol_end : int, optional
        Exclusive upper bound of the Stokes indices smoothed (default 4).

    Returns
    -------
    numpy.ndarray
        Re-modulated, smoothed flat with the same shape as ``flat``.

    Raises
    ------
    ValueError
        If ``cpos_arr[0]`` is not 0 or 5, or ``clean_mode`` is unknown.
    """
    # Non-continuum wavelength indices for each supported continuum position.
    wv_ranges = {0: range(1, 6), 5: range(5)}
    # First Stokes index to smooth for each mode, and the message announcing it.
    clean_modes = {
        "QUV": (1, "Unsharp Masking Q,U,V"),
        "UV": (2, "Unsharp Masking U,V"),
        "V": (3, "Unsharp Masking V"),
    }

    if cpos_arr[0] not in wv_ranges:
        raise ValueError(f"cpos_arr[0] must be 0 or 5, got {cpos_arr[0]}")
    if clean_mode not in clean_modes:
        raise ValueError(f"clean_mode must be one of {list(clean_modes)}, got {clean_mode!r}")

    wv_range = wv_ranges[cpos_arr[0]]
    start_clean_pol, message = clean_modes[clean_mode]

    flat_demod, demodM = demod_hrt(flat, flat_pmp_temp)

    # Central half of the frame (rows/cols 512:1536 for a 2048x2048 flat).
    n_rows, n_cols = flat_demod.shape[0], flat_demod.shape[1]
    ceny = slice(n_rows//2 - n_rows//4, n_rows//2 + n_rows//4)
    cenx = slice(n_cols//2 - n_cols//4, n_cols//2 + n_cols//4)
    norm_factor = np.mean(flat_demod[ceny,cenx,0,cpos_arr[0]])

    flat_demod /= norm_factor

    new_demod_flats = np.copy(flat_demod)

    print(message)

    for pol in range(start_clean_pol,pol_end):
        for wv in wv_range:
            clipped = np.clip(flat_demod[:,:,pol,wv], -0.02, 0.02)
            # Keep only the low-pass component of the clipped flat.
            new_demod_flats[:,:,pol,wv] = gaussian_filter(clipped, sigma)

    invM = np.linalg.inv(demodM)

    return np.matmul(invM, new_demod_flats*norm_factor)


def flat_correction(data,flat,flat_states,rows,cols):
    """
    Correct science data with flat fields.

    Parameters
    ----------
    data : numpy.ndarray
        Science data indexed (rows, cols, pol, wave, scan); the flat is
        broadcast over the trailing scan axis.
    flat : numpy.ndarray
        Flat field indexed (rows, cols, pol, wave).
    flat_states : int
        6  -> one flat per wavelength (flat averaged over the pol states),
        24 -> one flat per pol state and wavelength (used as is),
        4  -> one flat per pol state (flat averaged over wavelength).
    rows, cols : slice or index
        Region of ``flat`` that matches ``data``.

    Returns
    -------
    numpy.ndarray
        Flat-corrected data (a new array). If ``flat_states`` is not 4, 6 or
        24 a warning is printed and ``data`` is returned uncorrected, so the
        pipeline can continue instead of receiving ``None``.
    """
    if flat_states == 6:

        printc("Dividing by 6 flats, one for each wavelength",color=bcolors.OKGREEN)

        tmp = np.mean(flat,axis=-2) #avg over pol states for the wavelength

        return data / tmp[rows,cols, np.newaxis, :, np.newaxis]

    elif flat_states == 24:

        printc("Dividing by 24 flats, one for each image",color=bcolors.OKGREEN)

        return data / flat[rows,cols, :, :, np.newaxis] #only one new axis for the scans

    elif flat_states == 4:

        printc("Dividing by 4 flats, one for each pol state",color=bcolors.OKGREEN)

        tmp = np.mean(flat,axis=-1) #avg over wavelength

        return data / tmp[rows,cols, :, np.newaxis, np.newaxis]

    print(" ")
    printc('-->>>>>>> Unable to apply flat correction. Please insert valid flat_states',color=bcolors.WARNING)
    return data


def prefilter_correction(data,voltagesData_arr,prefilter,prefilter_voltages):
    """
    Apply the prefilter correction.

    For every scan and wavelength, the prefilter map is linearly interpolated
    in voltage between the two tabulated voltages that bracket the tuning
    voltage, and the data are divided by it.

    Parameters
    ----------
    data : numpy.ndarray
        Float data indexed (rows, cols, pol, wave, scan); divided in place.
    voltagesData_arr : sequence
        ``voltagesData_arr[scan][wv]`` is the tuning voltage of wavelength
        ``wv`` in scan ``scan``.
    prefilter : numpy.ndarray
        Prefilter maps indexed (rows, cols, voltage_index).
    prefilter_voltages : sequence of float
        Voltages of the prefilter maps, assumed sorted in ascending order
        (the bracketing search below relies on it).

    Returns
    -------
    numpy.ndarray
        ``data`` (the same object), corrected.

    Notes
    -----
    The interpolation weight of each tabulated map is the distance of the
    voltage to the *other* map's voltage. A voltage equal to a tabulated
    voltage uses that map exactly. A voltage outside the tabulated range
    uses the nearest map (no extrapolation).
    """
    pf_volts = np.asarray(prefilter_voltages, dtype=float)
    n_pf = pf_volts.size
    n_wv = data.shape[-2]

    for scan in range(data.shape[-1]):

        voltage_list = voltagesData_arr[scan]

        for wv in range(n_wv):

            vdif = voltage_list[wv] - pf_volts
            # Nearest tabulated voltage (first one on ties).
            index1 = int(np.argmin(np.abs(vdif)))
            # Its neighbour on the other side of the tuning voltage.
            index2 = index1 + 1 if vdif[index1] >= 0 else index1 - 1

            if vdif[index1] == 0 or not 0 <= index2 < n_pf:
                # Exactly on a grid point, or outside the grid: nearest map.
                imprefilter = prefilter[:,:, index1]
            else:
                d1 = abs(vdif[index1])
                d2 = abs(vdif[index2])
                imprefilter = (prefilter[:,:, index1]*d2 + prefilter[:,:, index2]*d1)/(d1 + d2)

            data[:,:,:,wv,scan] /= imprefilter[...,np.newaxis]

    return data


def CT_ItoQUV(data, ctalk_params, norm_stokes, cpos_arr):
    """
    Perform cross-talk correction from Stokes I into Q, U and V.

    For every wavelength, the cross-talk slope is scaled by the ratio of the
    mean Stokes I at that wavelength to the mean continuum Stokes I. Both
    means are taken over the central half of the frame. Q, U and V are then
    corrected as ``X - I * slope - offset``.

    Parameters
    ----------
    data : numpy.ndarray
        Stokes data indexed (rows, cols, stokes, wave, scan); modified in place.
    ctalk_params : numpy.ndarray
        Indexed ``[coef, k, scan]``: ``coef`` 0 is the slope and 1 the
        offset, and ``k`` 0/1/2 selects Q/U/V. The last axis is broadcast
        against the per-scan means, so a length-1 last axis is also accepted.
    norm_stokes : bool
        True for normalised data: the offset is scaled by the same I ratio as
        the slope. False for raw counts: the offset is multiplied by the mean
        Stokes I at that wavelength.
    cpos_arr : sequence of int
        ``cpos_arr[0]`` is the continuum wavelength index.

    Returns
    -------
    numpy.ndarray
        ``data`` (the same object), corrected.
    """
    before_ctalk_data = np.copy(data)
    n_rows, n_cols = data.shape[0], data.shape[1]
    ceny = slice(n_rows//2 - n_rows//4, n_rows//2 + n_rows//4)
    cenx = slice(n_cols//2 - n_cols//4, n_cols//2 + n_cols//4)
    cont_stokes = np.mean(data[ceny,cenx,0,cpos_arr[0],:], axis = (0,1))

    for i in range(data.shape[3]):

        stokes_i_wv_avg = np.mean(data[ceny,cenx,0,i,:], axis = (0,1))
        i_ratio = np.divide(stokes_i_wv_avg,cont_stokes)
        stokes_i = before_ctalk_data[:,:,0,i,:]

        for k in range(3):  # k = 0, 1, 2 -> Q, U, V at Stokes index k + 1
            slope = ctalk_params[0,k,:]*i_ratio
            if norm_stokes:
                offset = ctalk_params[1,k,:]*i_ratio
            else:
                offset = ctalk_params[1,k,:]*stokes_i_wv_avg

            data[:,:,k+1,i,:] = before_ctalk_data[:,:,k+1,i,:] - stokes_i*slope - offset

    return data


def cmilos(data_f, hdr_arr, wve_axis_arr, data_shape, cpos_arr, data, rte, field_stop, start_row, start_col, out_rte_filename, out_dir):
    """
    RTE inversion using CMILOS (ASCII input/output version).

    Each scan is written to ``dummy_in.txt`` in the current working
    directory. It is inverted with the ``milos`` executable located in the
    ``cmilos/`` directory next to this module, and the results are read back
    from ``dummy_out.txt``. Both temporary files are deleted. Four FITS
    files are written per scan into ``out_dir``:
    ``<root>_rte_data_products.fits`` (continuum, |B|, inclination, azimuth,
    vlos, Blos), ``<root>_blos_rte.fits``, ``<root>_vlos_rte.fits`` and
    ``<root>_Icont_rte.fits``.

    Parameters
    ----------
    data_f : list of str
        Input science file paths, one per scan; each output file is built
        from the corresponding input file.
    hdr_arr : list of astropy.io.fits.Header
        Headers, one per scan; a copy is written into the output files.
    wve_axis_arr : list of numpy.ndarray
        Wavelength axis (Angstrom) of each scan, one entry per wavelength.
    data_shape : tuple
        Shape of ``data``; the last entry is the number of scans.
    cpos_arr : sequence of int
        ``cpos_arr[0]`` is the continuum position, 0 or 5.
    data : numpy.ndarray
        Stokes data indexed (rows, cols, stokes, wave, scan).
    rte : str
        'RTE', 'CE' or 'CE+RTE' (cmilos mode 0, 2 or 1).
    field_stop : numpy.ndarray
        2-D field-stop mask; the region starting at (start_row, start_col)
        with the size of ``data`` is applied to the products.
    start_row, start_col : int
        Offset of ``data`` inside ``field_stop``.
    out_rte_filename : None, str or list of str
        Output filename root(s); by default the last 10 characters of the
        input file name (without '.fits').
    out_dir : str
        Output directory prefix (concatenated directly with the file name).

    Raises
    ------
    ValueError
        If ``rte`` or ``cpos_arr[0]`` is not a supported value.
    """
    print(" ")
    printc('-->>>>>>> RUNNING CMILOS ',color=bcolors.OKGREEN)

    # Locate the executable relative to this module rather than slicing a
    # fixed number of characters off the file name.
    CMILOS_LOC = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'cmilos', '')

    if not os.path.isfile(CMILOS_LOC+'milos'):
        printc('Cannot find cmilos:',color=bcolors.FAIL)
        printc(CMILOS_LOC,color=bcolors.FAIL)
        return

    printc("Cmilos executable located at:", CMILOS_LOC,color=bcolors.WARNING)

    rte_modes = {'RTE': '0', 'CE': '2', 'CE+RTE': '1'}  # cmilos mode argument
    if rte not in rte_modes:
        raise ValueError(f"rte must be one of {list(rte_modes)}, got {rte!r}")

    # Index of the wavelength sample aligned to `wavelength`, per continuum position.
    ref_indices = {0: 3, 5: 2}
    if cpos_arr[0] not in ref_indices:
        raise ValueError(f"cpos_arr[0] must be 0 or 5, got {cpos_arr[0]}")
    ref_index = ref_indices[cpos_arr[0]]

    wavelength = 6173.3356
    cmd = fix_path(CMILOS_LOC+"./milos")

    for scan in range(int(data_shape[-1])):

        start_time = time.time()

        file_path = data_f[scan]
        hdr = hdr_arr[scan]

        # cmilos only takes one dataset at a time, so each scan is inverted
        # independently. The wavelength axis comes from the scan header and is
        # shifted so that the reference sample sits exactly at `wavelength`.
        wave_axis = wve_axis_arr[scan]
        wave_axis = wave_axis - (wave_axis[ref_index] - wavelength)

        print('It is assumed the wavelength array is given by the hdr')
        print("Wave axis is: ", (wave_axis - wavelength)*1000.)
        print('Saving data into dummy_in.txt for RTE input')

        sdata = data[:,:,:,:,scan]
        y,x,p,l = sdata.shape
        waves = np.asarray(wave_axis)[:l]

        # One line "wavelength I Q U V" per sample, ordered column (x)
        # outermost, then row (y), then wavelength. Written one image column
        # at a time: bounded memory, no Python loop over every sample.
        with open('dummy_in.txt',"w") as f:
            for i in range(x):
                column = np.empty((y, l, 5))
                column[:, :, 0] = waves
                column[:, :, 1:] = np.transpose(sdata[:, i, :, :], (0, 2, 1))
                np.savetxt(f, column.reshape(-1, 5), fmt='%e %e %e %e %e ')
        del sdata

        printc(f'  ---- >>>>> Inverting data scan number: {scan} .... ',color=bcolors.OKGREEN)

        subprocess.call(cmd+f" 6 15 {rte_modes[rte]} 0 dummy_in.txt  >  dummy_out.txt",shell=True)

        printc('  ---- >>>>> Reading results.... ',color=bcolors.OKGREEN)
        os.remove('dummy_in.txt')

        res = np.loadtxt('dummy_out.txt')
        os.remove('dummy_out.txt')

        # 12 values per pixel, pixels in the input order (column outermost);
        # rearrange to (parameter, row, column). Parameters 0..11:
        # counter (px id), iterations, strength, inclination, azimuth, eta0,
        # Doppler width, damping, los velocity, constant source function,
        # source function slope, minimum chisqr.
        rte_invs_noth = np.transpose(np.ravel(res)[:12*x*y].reshape(x, y, 12), (2, 1, 0))

        # TODO: the spacecraft line-of-sight velocity correction of vlos is not applied.

        rte_data_products = np.zeros((6,rte_invs_noth.shape[1],rte_invs_noth.shape[2]))

        rte_data_products[0,:,:] = rte_invs_noth[9,:,:] + rte_invs_noth[10,:,:] #continuum
        rte_data_products[1,:,:] = rte_invs_noth[2,:,:] #b mag strength
        rte_data_products[2,:,:] = rte_invs_noth[3,:,:] #inclination
        rte_data_products[3,:,:] = rte_invs_noth[4,:,:] #azimuth
        rte_data_products[4,:,:] = rte_invs_noth[8,:,:] #vlos
        rte_data_products[5,:,:] = rte_invs_noth[2,:,:]*np.cos(rte_invs_noth[3,:,:]*np.pi/180.) #blos

        rte_data_products *= field_stop[np.newaxis,start_row:start_row + data.shape[0],start_col:start_col + data.shape[1]] #field stop, set outside to 0

        if isinstance(out_rte_filename, list):
            filename_root = out_rte_filename[scan]
        elif isinstance(out_rte_filename, str):
            filename_root = out_rte_filename
        else:
            filename_root = str(file_path.split('.fits')[0][-10:])
            if out_rte_filename is not None:
                print(f"out_rte_filename neither string nor list, reverting to default: {filename_root}")

        outputs = (
            ('_rte_data_products.fits', rte_data_products),
            ('_blos_rte.fits', rte_data_products[5,:,:]),
            ('_vlos_rte.fits', rte_data_products[4,:,:]),
            ('_Icont_rte.fits', rte_data_products[0,:,:]),
        )
        for suffix, product in outputs:
            with fits.open(file_path) as hdu_list:
                # Copy so that astropy's NAXIS/BITPIX updates do not touch the caller's header.
                hdu_list[0].header = hdr.copy()
                hdu_list[0].data = product
                hdu_list.writeto(out_dir+filename_root+suffix, overwrite=True)

        printc('--------------------------------------------------------------',bcolors.OKGREEN)
        printc(f"------------- CMILOS RTE Run Time: {np.round(time.time() - start_time,3)} seconds ",bcolors.OKGREEN)
        printc('--------------------------------------------------------------',bcolors.OKGREEN)


def cmilos_fits(data_f, hdr_arr, wve_axis_arr, data_shape, cpos_arr, data, rte, field_stop, start_row, start_col, out_rte_filename, out_dir):
    """
    RTE inversion using CMILOS-FITS (FITS input/output version).

    Each scan is written to ``<out_dir>temp_cmilos_io.fits``. The primary HDU
    holds the data cube transposed to (wave, stokes, rows, cols), with
    LAMBDA0..LAMBDA5 header keywords. A second HDU holds a mask of ones.
    The cube is inverted in place by the ``milos`` executable located in
    the ``cmilos-fits/`` directory next to this module. Four FITS files are
    written per scan into ``out_dir``:
    ``<root>_rte_data_products.fits`` (continuum, |B|, inclination, azimuth,
    vlos, Blos), ``<root>_blos_rte.fits``, ``<root>_vlos_rte.fits`` and
    ``<root>_Icont_rte.fits``.

    Parameters are the same as for ``cmilos``:
    data_f (input file per scan), hdr_arr (header per scan, copied into the
    outputs), wve_axis_arr (wavelength axis per scan, Angstrom),
    data_shape (last entry = number of scans), cpos_arr (``cpos_arr[0]``
    is the continuum position, 0 or 5), data (rows, cols, stokes, wave, scan),
    rte ('RTE', 'CE' or 'CE+RTE'), field_stop / start_row / start_col (field-stop
    mask and the offset of ``data`` in it), out_rte_filename (None, str or
    list of str) and out_dir (output directory prefix).

    Raises
    ------
    ValueError
        If ``rte`` or ``cpos_arr[0]`` is not a supported value.
    """
    print(" ")
    printc('-->>>>>>> RUNNING CMILOS ',color=bcolors.OKGREEN)

    # Locate the executable relative to this module rather than slicing a
    # fixed number of characters off the file name.
    CMILOS_LOC = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'cmilos-fits', '')

    if not os.path.isfile(CMILOS_LOC+'milos'):
        printc('Cannot find cmilos-fits:',color=bcolors.FAIL)
        printc(CMILOS_LOC,color=bcolors.FAIL)
        return

    printc("Cmilos-fits executable located at:", CMILOS_LOC,color=bcolors.WARNING)

    rte_modes = {'RTE': '0', 'CE': '2', 'CE+RTE': '1'}  # cmilos mode argument
    if rte not in rte_modes:
        raise ValueError(f"rte must be one of {list(rte_modes)}, got {rte!r}")

    # Index of the wavelength sample aligned to `wavelength`, per continuum position.
    ref_indices = {0: 3, 5: 2}
    if cpos_arr[0] not in ref_indices:
        raise ValueError(f"cpos_arr[0] must be 0 or 5, got {cpos_arr[0]}")
    ref_index = ref_indices[cpos_arr[0]]

    wavelength = 6173.3356
    cmd = CMILOS_LOC+"milos"
    tmp_path = out_dir+'temp_cmilos_io.fits'

    for scan in range(int(data_shape[-1])):

        start_time = time.time()

        file_path = data_f[scan]
        hdr = hdr_arr[scan]

        # cmilos only takes one dataset at a time, so each scan is inverted
        # independently. The wavelength axis comes from the scan header and is
        # shifted so that the reference sample sits exactly at `wavelength`.
        wave_axis = wve_axis_arr[scan]
        wave_axis = wave_axis - (wave_axis[ref_index] - wavelength)

        print('It is assumed the wavelength array is given by the hdr')
        print("Wave axis is: ", (wave_axis - wavelength)*1000.)
        print(f'Saving data into {tmp_path} for RTE input')

        sdata = data[:,:,:,:,scan]

        # Wavelength positions for cmilos-fits (Angstrom, e.g. 6173.1). Kept
        # separate from `hdr`, which is needed later for the output files.
        lambda_hdr = fits.Header()
        for k in range(6):
            lambda_hdr[f'LAMBDA{k}'] = wave_axis[k]

        input_arr = np.transpose(sdata, axes = (3,2,0,1)) # FITS/cfitsio axis order is reversed
        hdu1 = fits.PrimaryHDU(data=input_arr, header = lambda_hdr)
        hdu2 = fits.ImageHDU(data=np.ones((sdata.shape[0],sdata.shape[1]))) # mask; change this for fdt

        fits.HDUList([hdu1, hdu2]).writeto(tmp_path, overwrite=True)

        del sdata, input_arr

        printc(f'  ---- >>>>> Inverting data scan number: {scan} .... ',color=bcolors.OKGREEN)

        rte_on = subprocess.call(cmd+f" 6 15 {rte_modes[rte]} {tmp_path}",shell=True)
        if rte_on != 0:
            printc(f'  ---- >>>>> cmilos-fits exited with status {rte_on}',color=bcolors.WARNING)

        printc('  ---- >>>>> Reading results.... ',color=bcolors.OKGREEN)

        with fits.open(tmp_path) as hdu_list:
            # Copy: the temp file is overwritten by the next scan, so we must not
            # keep a memory map onto it.
            rte_out = np.array(hdu_list[0].data)

        # Output planes, in the allocation order of cmilos-fits/milos.c:
        # 0 iter, 1 B, 2 gm (inclination), 3 az, 4 eta0, 5 dopp, 6 aa,
        # 7 vlos (km/s), 8 alfa (stray light factor), 9 S0, 10 S1, 11 nchisqrf.

        rte_data_products = np.zeros((6,rte_out.shape[1],rte_out.shape[2]))

        rte_data_products[0,:,:] = rte_out[9,:,:] + rte_out[10,:,:] #continuum
        rte_data_products[1,:,:] = rte_out[1,:,:] #b mag strength
        rte_data_products[2,:,:] = rte_out[2,:,:] #inclination
        rte_data_products[3,:,:] = rte_out[3,:,:] #azimuth
        rte_data_products[4,:,:] = rte_out[7,:,:] #vlos
        rte_data_products[5,:,:] = rte_out[1,:,:]*np.cos(rte_out[2,:,:]*np.pi/180.) #blos

        rte_data_products *= field_stop[np.newaxis,start_row:start_row + data.shape[0],start_col:start_col + data.shape[1]] #field stop, set outside to 0

        if isinstance(out_rte_filename, list):
            filename_root = out_rte_filename[scan]
        elif isinstance(out_rte_filename, str):
            filename_root = out_rte_filename
        else:
            filename_root = str(file_path.split('.fits')[0][-10:])
            if out_rte_filename is not None:
                print(f"out_rte_filename neither string nor list, reverting to default: {filename_root}")

        outputs = (
            ('_rte_data_products.fits', rte_data_products),
            ('_blos_rte.fits', rte_data_products[5,:,:]),
            ('_vlos_rte.fits', rte_data_products[4,:,:]),
            ('_Icont_rte.fits', rte_data_products[0,:,:]),
        )
        for suffix, product in outputs:
            with fits.open(file_path) as hdu_list:
                # Copy so that astropy's NAXIS/BITPIX updates do not touch the caller's header.
                hdu_list[0].header = hdr.copy()
                hdu_list[0].data = product
                hdu_list.writeto(out_dir+filename_root+suffix, overwrite=True)

        printc('--------------------------------------------------------------',bcolors.OKGREEN)
        printc(f"------------- CMILOS RTE Run Time: {np.round(time.time() - start_time,3)} seconds ",bcolors.OKGREEN)
        printc('--------------------------------------------------------------',bcolors.OKGREEN)


def pmilos(data_f, wve_axis_arr, data_shape, cpos_arr, data, rte, field_stop, start_row, start_col, out_rte_filename, out_dir, n_proc=10):
    """
    RTE inversion using PMILOS (MPI version).

    Uses the ``P-MILOS/`` directory next to this module. Per scan:
    - ``run/wavelength_tmp.fits`` receives the wavelength grid;
    - ``run/data/input_tmp.fits`` receives the Stokes data;
    - ``mpiexec -np <n_proc> ../pmilos.x <settings>.minit`` runs in ``run/``;
    - the result is read from ``run/results/`` and that file is deleted.
    Four FITS files are written per scan into ``out_dir``:
    ``<root>_rte_data_products.fits`` (continuum, |B|, inclination, azimuth,
    vlos, Blos), ``<root>_blos_rte.fits``, ``<root>_vlos_rte.fits`` and
    ``<root>_Icont_rte.fits``. The header of the corresponding input file is
    kept in the outputs.

    Parameters
    ----------
    data_f : list of str
        Input science file paths, one per scan.
    wve_axis_arr : list of numpy.ndarray
        Wavelength axis (Angstrom, 6 entries) of each scan.
    data_shape : tuple
        Shape of ``data``; the last entry is the number of scans.
    cpos_arr : sequence of int
        ``cpos_arr[0]`` is the continuum position, 0 or 5.
    data : numpy.ndarray
        Stokes data indexed (rows, cols, stokes, wave, scan).
    rte : str
        'RTE' or 'CE'. 'CE+RTE' is not supported by PMILOS and falls back to 'RTE'.
    field_stop : numpy.ndarray
        2-D field-stop mask; the region starting at (start_row, start_col)
        with the size of ``data`` is applied to the products.
    start_row, start_col : int
        Offset of ``data`` inside ``field_stop``.
    out_rte_filename : None, str or list of str
        Output filename root(s); by default the last 10 characters of the
        input file name (without '.fits').
    out_dir : str
        Output directory prefix (concatenated directly with the file name).
    n_proc : int, optional
        Number of MPI processes (default 10).

    Raises
    ------
    ValueError
        If ``rte`` or ``cpos_arr[0]`` is not a supported value.
    """
    print(" ")
    printc('-->>>>>>> RUNNING PMILOS ',color=bcolors.OKGREEN)

    PMILOS_LOC = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'P-MILOS', '')

    if not os.path.isfile(PMILOS_LOC+'pmilos.x'):
        printc('Cannot find pmilos:',color=bcolors.FAIL)
        printc(PMILOS_LOC,color=bcolors.FAIL)
        return

    printc("Pmilos executable located at:", PMILOS_LOC,color=bcolors.WARNING)

    # All PMILOS I/O happens under the same directory the executable was found in.
    run_dir = os.path.join(PMILOS_LOC, 'run')

    # Settings file and result file for each inversion mode.
    settings = {
        'RTE': ('pmilos.minit', 'inv_input_tmp_mod.fits'),
        'CE': ('pmilos_ce.minit', 'inv_input_tmp_mod_ce.fits'),
        'CE+RTE': ('pmilos.minit', 'inv_input_tmp_mod.fits'),
    }
    if rte not in settings:
        raise ValueError(f"rte must be one of {list(settings)}, got {rte!r}")
    if rte == 'CE+RTE':
        print("CE+RTE not possible on PMILOS, performing RTE instead")
    minit_file, out_file = settings[rte]
    cmd = f"mpiexec -np {n_proc} ../pmilos.x {minit_file}"

    # Index of the wavelength sample aligned to `wavelength`, per continuum position.
    ref_indices = {0: 3, 5: 2}
    if cpos_arr[0] not in ref_indices:
        raise ValueError(f"cpos_arr[0] must be 0 or 5, got {cpos_arr[0]}")
    ref_index = ref_indices[cpos_arr[0]]

    wavelength = 6173.3356
    input_path = os.path.join(run_dir, 'data', 'input_tmp.fits')
    result_path = os.path.join(run_dir, 'results', out_file)

    for scan in range(int(data_shape[-1])):

        start_time = time.time()

        file_path = data_f[scan]

        # PMILOS inverts one dataset at a time, so each scan is inverted
        # independently. The wavelength axis comes from the scan header and is
        # shifted so that the reference sample sits exactly at `wavelength`.
        wave_axis = wve_axis_arr[scan]
        wave_axis = wave_axis - (wave_axis[ref_index] - wavelength)

        print('It is assumed the wavelength array is given by the hdr')
        print("Wave axis is: ", (wave_axis - wavelength)*1000.)
        print(f'Saving data into {input_path} for pmilos RTE input')

        # Wavelength grid for the settings: row 0 = line index (1), row 1 = wavelengths.
        wave_input = np.zeros((2,6)) #cfitsio reads dimensions in opposite order
        wave_input[0,:] = 1
        wave_input[1,:] = wave_axis
        fits.HDUList([fits.PrimaryHDU(wave_input)]).writeto(os.path.join(run_dir, 'wavelength_tmp.fits'), overwrite=True)

        input_hdr = fits.Header()
        input_hdr['CTYPE1'] = 'HPLT-TAN'
        input_hdr['CTYPE2'] = 'HPLN-TAN'
        input_hdr['CTYPE3'] = 'STOKES' #check order of stokes
        input_hdr['CTYPE4'] = 'WAVE-GRI'
        fits.HDUList([fits.PrimaryHDU(data[:,:,:,:,scan], header=input_hdr)]).writeto(input_path, overwrite=True)

        printc(f'  ---- >>>>> Inverting data scan number: {scan} .... ',color=bcolors.OKGREEN)

        rte_on = subprocess.call(cmd, shell=True, cwd=run_dir)
        if rte_on != 0:
            printc(f'  ---- >>>>> pmilos exited with status {rte_on}',color=bcolors.WARNING)

        with fits.open(result_path) as hdu_list:
            result = np.array(hdu_list[0].data)  # copy before the file is deleted
        os.remove(result_path)

        # result has dimensions [rows, cols, 13]; PMILOS output columns:
        # 0 eta0, 1 B [G], 2 vlos [km/s], 3 Doppler width [A], 4 damping,
        # 5 inclination [deg], 6 azimuth [deg], 7 S0, 8 S1,
        # 9 macroturbulence [km/s], 10 filling factor, 11 iterations, 12 chisqr.

        rte_data_products = np.zeros((6,result.shape[0],result.shape[1]))

        rte_data_products[0,:,:] = result[:,:,7] + result[:,:,8] #continuum
        rte_data_products[1,:,:] = result[:,:,1] #b mag strength
        rte_data_products[2,:,:] = result[:,:,5] #inclination
        rte_data_products[3,:,:] = result[:,:,6] #azimuth
        rte_data_products[4,:,:] = result[:,:,2] #vlos
        rte_data_products[5,:,:] = result[:,:,1]*np.cos(result[:,:,5]*np.pi/180.) #blos

        rte_data_products *= field_stop[np.newaxis,start_row:start_row + data.shape[0],start_col:start_col + data.shape[1]] #field stop, set outside to 0

        #flipping taken care of for the field stop in the hrt_pipe

        if isinstance(out_rte_filename, list):
            filename_root = out_rte_filename[scan]
        elif isinstance(out_rte_filename, str):
            filename_root = out_rte_filename
        else:
            filename_root = str(file_path.split('.fits')[0][-10:])
            if out_rte_filename is not None:
                print(f"out_rte_filename neither string nor list, reverting to default: {filename_root}")

        outputs = (
            ('_rte_data_products.fits', rte_data_products),
            ('_blos_rte.fits', rte_data_products[5,:,:]),
            ('_vlos_rte.fits', rte_data_products[4,:,:]),
            ('_Icont_rte.fits', rte_data_products[0,:,:]),
        )
        for suffix, product in outputs:
            with fits.open(file_path) as hdu_list:
                hdu_list[0].data = product
                hdu_list.writeto(out_dir+filename_root+suffix, overwrite=True)

        printc('--------------------------------------------------------------',bcolors.OKGREEN)
        printc(f"------------- PMILOS RTE Run Time: {np.round(time.time() - start_time,3)} seconds ",bcolors.OKGREEN)
        printc('--------------------------------------------------------------',bcolors.OKGREEN)