orbital_tomog_3d_data_processor.py 6.42 KB
Newer Older
1
from proxtoolbox.Problems.OrbitalTomog.planar_molecule.orbitaltomog_data_processor import support_from_autocorrelation
2
from proxtoolbox.Utilities.OrbitalTomog import shifted_fft, bin_array, pad_to_square
Matthijs's avatar
Matthijs committed
3
from proxtoolbox.Problems.OrbitalTomog.Graphics.stack_viewer import XYZStackViewer
4
5
import numpy as np
from skimage.io import imread
6
from itertools import combinations
7
8
9


def data_processor(config):
    """
    Load and preprocess the 3D orbital tomography data set described by ``config``.

    Steps: read the measured data and the measurement sensitivity, optionally add the
    mirrored negative-k_z half, rebin to the requested resolution, convert intensities
    to field amplitudes, determine the object support, draw a random-phase initial
    guess and fill in the remaining propagation settings.

    :param config: problem configuration dict (updated in place). Required keys:
        'data_filename', 'sensitivity_filename'. Optional keys read here:
        'add_negative_kz', 'Ny'/'Nx'/'Nz', 'from intensity data', 'support_filename',
        'threshold for support', 'use_sparsity_with_support', 'sparsity_support',
        'dataprocessor_plotting', 'fourier_shift_arrays'.
    :return: the same dict, with 'data', 'data_zeros', 'norm_data', 'support', 'u_0',
        'data_sq' and the propagation settings added.
    :raises AssertionError: if data and sensitivity arrays have different shapes.
    """
    # Load data and measurement sensitivity
    inp = imread(config['data_filename'])
    sensitivity = imread(config['sensitivity_filename'])

    # Add the negative K_z part to determine a real-valued orbital?
    if config.get('add_negative_kz', False):
        inp, sensitivity = mirror_kspace(inp, sensitivity)

    assert sensitivity.shape == inp.shape, 'Non-matching sensitivity and data arrays'

    ny, nx, nz = inp.shape
    # Unmeasured voxels: zero sensitivity, or NaN (mirror_kspace pads NaN-containing arrays with NaN)
    config['data_zeros'] = (sensitivity == 0) | np.isnan(sensitivity)
    # NaN entries carry no data; replace them by 0 so that norms and sums below stay finite
    config['data'] = np.where(np.isnan(inp), 0, abs(inp))

    from_intensity = config.get('from intensity data', False)

    # Keep the same resolution?
    if not all(key in config for key in ('Ny', 'Nx', 'Nz')):
        config['Ny'], config['Nx'], config['Nz'] = ny, nx, nz
        print('Setting problem dimensions based on data, i.e. %s' % str((ny, nx, nz)))
    elif (ny, nx, nz) != (config['Ny'], config['Nx'], config['Nz']):
        new_shape = (config['Ny'], config['Nx'], config['Nz'])
        # Unmeasured voxels become NaN so that they do not contribute as zeros to a bin
        nandata = np.where(config['data_zeros'], np.nan, config['data'])
        if from_intensity:
            config['data'] = np.nan_to_num(bin_array(nandata, new_shape))
        else:
            # binning must be done for the intensity-data, as that preserves the normalization
            config['data'] = np.sqrt(np.nan_to_num(bin_array(nandata ** 2, new_shape)))
        # A binned voxel is unmeasured if any of its input voxels was unmeasured.
        # NOTE(review): assumes bin_array sums/averages (and propagates NaN) - confirm.
        config['data_zeros'] = bin_array(config['data_zeros'].astype(float), new_shape) > 0

    # Calculate electric field
    if from_intensity:
        # avoid sqrt of negative numbers (due to background subtraction)
        data = config['data']
        config['data'] = np.where(data > 0, np.sqrt(abs(data)), np.zeros_like(data))
    config['norm_data'] = np.sqrt(np.sum(config['data'] ** 2))

    # Object support determination: from file if given, else from the autocorrelation
    if 'support_filename' in config:
        config['support'] = imread(config['support_filename'])
    else:
        threshold = config.setdefault('threshold for support', 0.1)
        config['support'] = support_from_autocorrelation(config['data'],
                                                         threshold=threshold,
                                                         absolute_autocorrelation=True,
                                                         binary_dilate_support=1)

    if config.get('use_sparsity_with_support', False) and 'sparsity_support' not in config:
        threshold = config.setdefault('threshold for support', 0.1)
        config['sparsity_support'] = support_from_autocorrelation(config['data'],
                                                                  threshold=threshold,
                                                                  binary_dilate_support=1)

    # Initial guess: measured amplitudes with uniformly random phases
    ph_init = 2 * np.pi * np.random.random_sample(config['data'].shape)
    config['u_0'] = config['data'] * np.exp(1j * ph_init)

    if config.get('dataprocessor_plotting', False):
        input_viewer = XYZStackViewer(config['data'], cmap='viridis')
        config['input viewer'] = input_viewer  # the reference keeps the matplotlib object alive
        # NOTE(review): this shows the FT of the amplitudes rather than config['support'] - intended?
        support_viewer = XYZStackViewer(shifted_fft(config['data']).real)
        config['support viewer'] = support_viewer  # the reference keeps the matplotlib object alive

    if config.get('fourier_shift_arrays', False):
        for key in ('u_0', 'support', 'sparsity_support', 'data', 'data_zeros'):
            if key in config:
                config[key] = np.fft.fftshift(config[key])

    # Other settings
    config['fresnel_nr'] = 0
    config['FT_conv_kernel'] = 1
    config['use_farfield_formula'] = True
    config['magn'] = 1
    config['data_sq'] = abs(config['data']) ** 2

    return config
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137


def _padding_value(arr: np.ndarray) -> float:
    """Return the padding value for `arr`: NaN if it already contains NaN, else 0."""
    return np.nan if np.any(np.isnan(arr)) else 0


def _check_sensitivity_centering(sens: np.ndarray) -> None:
    """
    Check that the NaN pattern of a cubic sensitivity array agrees for slices along all axes.

    Slices at 1/2, 1/3 and 1/4 of the edge length (and their negative counterparts) are
    compared pairwise; a mismatch indicates a badly chosen ``shift_mirror``.
    """
    measured = np.where(np.isnan(sens), 0, 1)
    for idx in (len(measured) // i for i in (2, 3, 4)):
        slices = [measured[idx], measured[:, idx], measured[:, :, idx],
                  measured[-idx], measured[:, -idx], measured[:, :, -idx]]
        assert all(np.all(np.isclose(a, b)) for a, b in combinations(slices, 2)), \
            'Non-matching test slices, indicating that shift_mirror parameter should be changed'


def mirror_kspace(kspace: np.ndarray, sensitivity: np.ndarray = None, shift_mirror: tuple = (1, 1, 1),
                  square_array: bool = True):
    """
    Mirror a 3d kspace array in the kz=0 plane.

    The fully flipped (point-inverted) copy of the array, rolled by ``shift_mirror[:2]``
    in axes 0 and 1, is placed before the original along axis 2; the original's first
    axis-2 plane is dropped so that the shared kz=0 plane is not duplicated. The result
    is then rolled by ``shift_mirror[2]`` along axis 2.

    :param kspace: 3d kspace array, e.g. from arpes_wvlscan_to_kspace
    :param sensitivity: optional (may be None) 3d sensitivity array of the same shape
        (e.g. from arpes_wvlscan_to_kspace). If given, it is mirrored identically and its
        NaN pattern is used to test the centering chosen by ``shift_mirror``.
    :param shift_mirror: rolls (axis 0, axis 1, axis 2) applied to get good centering
    :param square_array: pad the returned array(s) with pad_to_square (NaN if the array
        contains NaN, else 0)
    :return: mirrored kspace, or the tuple (kspace, sensitivity) if sensitivity is given
    :raises AssertionError: if the sensitivity centering test fails
    """

    def mirror_array(arr: np.ndarray) -> np.ndarray:
        flipped = np.roll(np.flip(arr), (shift_mirror[0], shift_mirror[1]), axis=(0, 1))
        return np.roll(np.concatenate([flipped, arr[:, :, 1:]], axis=2), shift_mirror[2], axis=2)

    full_kspace = mirror_array(kspace)
    if square_array:
        full_kspace = pad_to_square(full_kspace, constant_values=_padding_value(full_kspace))
    if sensitivity is None:
        return full_kspace

    mirrored_sens = mirror_array(sensitivity)
    # The centering test compares slices along different axes, which needs a cube;
    # only return the padded version when the caller asked for square arrays, so that
    # kspace and sensitivity always have matching shapes.
    padded_sens = pad_to_square(mirrored_sens, constant_values=_padding_value(mirrored_sens))
    _check_sensitivity_centering(padded_sens)
    return full_kspace, (padded_sens if square_array else mirrored_sens)