# planar_molecule.py
from skimage.io import imread
from scipy.ndimage import binary_dilation
import numpy as np
import matplotlib.pyplot as plt

6
from proxtoolbox.experiments.orbitaltomography.orbitExperiment import OrbitalTomographyExperiment
7
from proxtoolbox.utils.visualization.complex_field_visualization import complex_to_rgb
8
from proxtoolbox.utils.orbitaltomog import bin_array, shifted_fft, shifted_ifft, fourier_interpolate, roll_to_pos
9

10
11
12

class PlanarMolecule(OrbitalTomographyExperiment):
    """
    Orbital tomography experiment for a planar (2D) molecule from a photoelectron (ARPES) spectrum.

    The measured spectrum is read from an image file (``data_filename``) and converted to a field
    amplitude. An object support is read from ``support_filename`` or estimated from the
    autocorrelation of the data. The remaining configuration is given by ``getDefaultParameters()``.
    """

    @staticmethod
    def getDefaultParameters():
        """Return the default configuration dictionary for this experiment."""
        defaultParams = {
            'experiment_name': '2D ARPES',
            'data_filename': None,
            'from_intensity_data': False,
            'object': 'real',
            'constraint': 'sparse real',
            'sparsity_parameter': 40,
            'use_sparsity_with_support': False,
            'threshold_for_support': 0.1,
            'support_filename': None,
            'Nx': None,
            'Ny': None,
            'Nz': 1,
            'MAXIT': 500,
            'TOL': 1e-10,
            'lambda_0': 0.85,
            'lambda_max': 0.50,
            'lambda_switch': 50,
            'data_ball': .999826,
            'TOL2': 1e-15,
            'diagnostic': True,
            'algorithm': 'DRl',
            'iterate_monitor_name': 'FeasibilityIterateMonitor',  # alternative: 'IterateMonitor'
            'rotate': False,
            'verbose': 1,
            'graphics': 1,
            'interpolate_and_zoom': True,
            'debug': True,
        }
        return defaultParams

    def __init__(self, **kwargs):
        super(PlanarMolecule, self).__init__(**kwargs)

        # Any parameter the caller did not pass falls back on the class default instead of raising
        # a KeyError. Explicitly passed values always win.
        params = {**self.getDefaultParameters(), **kwargs}

        self.data_filename = params['data_filename']  # given as input to class normally
        self.from_intensity_data = params['from_intensity_data']
        self.support_filename = params['support_filename']  # optional class argument
        self.sparsity_parameter = params['sparsity_parameter']
        self.use_sparsity_with_support = params['use_sparsity_with_support']
        # fraction of the maximum used to determine a support from the autocorrelation
        self.threshold_for_support = params['threshold_for_support']
        self.interpolate_and_zoom = params['interpolate_and_zoom']  # For plotting

        # the following data members are set by loadData(), in addition to those specified in parent classes
        self.support = None  # object support
        self.sparsity_support = None  # optional support used together with the sparsity constraint
        self.data_zeros = None  # indices where the data amplitude is exactly zero

    def loadData(self):
        """
        Load data and set in the correct format for reconstruction.

        Parameters are taken from experiment class (self) properties, which must include::

            - data_filename: str, path to the data file
            - from_intensity_data: bool, if the data file gives intensities rather than field amplitude
            - support_filename: str, optional path to file with object support
            - use_sparsity_with_support: bool, if true, use a support before the sparsity constraint.
                The support is calculated by thresholding the object autocorrelation, and dilating the result
            - threshold_for_support: float, in range [0,1], fraction of the maximum at which to threshold when
                determining support or support for sparsity

        Sets ``data`` (float field amplitude), ``Ny``/``Nx``, ``norm_data``, ``support``,
        ``sparsity_support`` (only when ``use_sparsity_with_support``), ``data_sq``, ``data_zeros``
        and, through ``createRandomGuess()``, ``u0``.

        Raises:
            ValueError: if the data is not 2D, or if Ny and Nx are both set but do not evenly
                divide the data shape.
        """
        # load data
        if self.data_filename is None:
            self.data_filename = input('Please enter the path to the datafile: ')
        try:
            raw_data = imread(self.data_filename)
        except FileNotFoundError:
            print("Tried path %s, found nothing. " % self.data_filename)
            self.data_filename = input('Please enter a valid path to the datafile: ')
            raw_data = imread(self.data_filename)
        # Images are frequently stored as (unsigned) integers. Squaring those below (binning, norm,
        # data_sq) would silently overflow, so work in floating point from here on.
        self.data = np.asarray(raw_data, dtype=np.float64)
        if self.data.ndim != 2:
            raise ValueError('Expected 2D data, got an array of shape %s' % (self.data.shape,))

        # Bin down to the requested resolution, if one was configured
        ny, nx = self.data.shape
        if self.Ny is not None and self.Nx is not None:
            if ny % self.Ny != 0 or nx % self.Nx != 0:
                # TODO: use flexibility allowed by the binning function to prevent this error
                raise ValueError('Incompatible values for Ny and Nx given in configuration dict')
            # binning must be done for the intensity-data, as that preserves the normalization
            if self.from_intensity_data:
                self.data = bin_array(self.data, (self.Ny, self.Nx))
            else:
                self.data = np.sqrt(bin_array(self.data ** 2, (self.Ny, self.Nx)))
        self.Ny, self.Nx = self.data.shape

        # Calculate electric field and norm of the data
        if self.from_intensity_data:
            # avoid sqrt of negative numbers (due to background subtraction)
            self.data = np.where(self.data > 0, np.sqrt(abs(self.data)), np.zeros_like(self.data))
        self.norm_data = np.sqrt(np.sum(self.data ** 2))

        # Object support determination: an explicit support file takes precedence over the estimate
        if self.support_filename is not None:
            self.support = imread(self.support_filename)
        else:
            self.support = support_from_autocorrelation(self.data,
                                                        threshold=self.threshold_for_support,
                                                        absolute_autocorrelation=True,
                                                        binary_dilate_support=1)
        if self.use_sparsity_with_support:
            self.sparsity_support = support_from_autocorrelation(self.data,
                                                                 threshold=self.threshold_for_support,
                                                                 binary_dilate_support=1)

        self.createRandomGuess()

        # some variables which are necessary for the algorithm:
        self.data_sq = self.data ** 2
        self.data_zeros = np.where(self.data == 0)

    def createRandomGuess(self):
        """
        Build the initial iterate ``u0`` from the measured amplitude ``self.data``.

        A phase drawn uniformly from [0, 2*pi) is attached to each data point and the
        result is transformed with ``np.fft.fftn``.
        """
        # NOTE(review): forward (fftn) rather than inverse transform; whether that matches the
        # propagator convention used by the algorithm is not visible here — confirm.
        ph_init = 2 * np.pi * np.random.random_sample(self.data.shape)
        self.u0 = self.data * np.exp(1j * ph_init)
        self.u0 = np.fft.fftn(self.u0)

    def plotInputData(self):
        """Quick plotting routine to show the data, initial guess and the sparsity support."""
        fig, ax = plt.subplots(1, 3, figsize=(12, 3.5))
        im = ax[0].imshow(self.data)
        fig.colorbar(im, ax=ax[0])
        # per-axes titles: plt.title would put every title on the most recently created axes
        ax[0].set_title("Photoelectron spectrum")
        ax[1].imshow(complex_to_rgb(self.u0))
        ax[1].set_title("Initial guess")
        if self.sparsity_support is not None:
            ax[2].imshow(self.sparsity_support, cmap='gray')
            ax[2].set_title("Sparsity support")
        plt.show()

    @staticmethod
    def _center_and_zoom(field):
        """Centre a 2D ``field``, Fourier-interpolate it by a factor 2 and crop to its central half."""
        center = tuple(s // 2 for s in field.shape)
        # first move the maximum to the centre, then roll again with roll_to_pos's default mode
        field = roll_to_pos(field, pos=center, move_maximum=True)
        field = roll_to_pos(field, pos=center)
        field = fourier_interpolate(field, 2)
        ny, nx = field.shape
        zmy, zmx = ny // 4, nx // 4
        # explicit stop indices: a `-0` stop (tiny arrays) would otherwise give an empty slice
        return field[zmy:ny - zmy, zmx:nx - zmx]

    def show(self):
        """
        Create basic result plots of the phase retrieval procedure.

        Evaluates both prox operators on the final iterate and stores the results in
        ``self.output['u1']`` and ``self.output['u2']``. When ``interpolate_and_zoom`` is set, they
        are centred, interpolated and cropped first. Three figures are shown and appended to
        ``self.output['plots']``.
        """
        self.output['plots'] = []
        self.output['u1'] = self.algorithm.prox1.eval(self.algorithm.u)
        self.output['u2'] = self.algorithm.prox2.eval(self.algorithm.u)
        if self.interpolate_and_zoom:
            for key in ("u1", "u2"):
                self.output[key] = self._center_and_zoom(self.output[key])
        u1 = self.output['u1']
        u2 = self.output['u2']
        change = self.output['stats']['changes']

        # Figure 1: amplitude and complex-colour view of both constraint projections
        f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9, 6))
        im = ax1.imshow(abs(u1), cmap='gray')
        f.colorbar(im, ax=ax1)
        ax1.set_title('best approximation amplitude: \nphysical constraint satisfied')
        ax2.imshow(complex_to_rgb(u1))
        ax2.set_title('best approximation phase: \nphysical constraint satisfied')
        im = ax3.imshow(abs(u2), cmap='gray')
        f.colorbar(im, ax=ax3)
        ax3.set_title('best approximation amplitude: \nFourier constraint satisfied')
        ax4.imshow(complex_to_rgb(u2))
        ax4.set_title('best approximation phase: \nFourier constraint satisfied')
        f.tight_layout()

        # Figure 2: physical-constraint result plus convergence diagnostics
        g, ((bx1, bx2), (bx3, bx4)) = plt.subplots(2, 2, figsize=(9, 6))
        im = bx1.imshow(abs(u1), cmap='gray')
        g.colorbar(im, ax=bx1)
        bx1.set_title('best approximation amplitude: \nphysical constraint satisfied')
        im = bx2.imshow(u1.real, cmap='gray')
        g.colorbar(im, ax=bx2)
        bx2.set_title('best approximation real part: \nphysical constraint satisfied')
        bx3.semilogy(change)
        bx3.set_xlabel('iteration')
        bx3.set_ylabel('Change')
        if 'gaps' in self.output['stats']:
            gaps = self.output['stats']['gaps']
            bx4.semilogy(gaps)
            bx4.set_xlabel('iteration')
            bx4.set_ylabel('Gap')
        g.tight_layout()

        # Figure 3: measured data next to the propagated physical-constraint projection
        h, ax = plt.subplots(1, 3, figsize=(9, 3))
        ax[0].imshow(self.data)
        prop = self.propagator(self)
        u_hat = prop.eval(self.algorithm.prox1.eval(self.algorithm.u))
        ax[1].imshow(abs(u_hat))
        ax[2].imshow(complex_to_rgb(u_hat))
        h.tight_layout()
        plt.show()

        self.output['plots'].extend([f, g, h])


def support_from_autocorrelation(input_array: np.ndarray,
                                 threshold: float = 0.1,
                                 relative_threshold: bool = True,
                                 input_in_fourier_domain: bool = True,
                                 absolute_autocorrelation: bool = True,
                                 binary_dilate_support: int = 0) -> np.ndarray:
    """
    Determine an initial support from the autocorrelation of an object.

    The autocorrelation is obtained via the Wiener-Khinchin theorem as the inverse Fourier
    transform of the squared Fourier modulus. It is then thresholded and optionally dilated.

    Args:
        input_array: either the measured Fourier-domain field amplitude (e.g. the square root of an
            arpes pattern) or a guess of the object
        threshold: support is everywhere where the autocorrelation is higher than the threshold
        relative_threshold: If true, threshold at threshold*np.amax(autocorrelation)
        input_in_fourier_domain: False if a guess of the object is given in input_array
        absolute_autocorrelation: Take the absolute value of the autocorrelation? (Generally a
            good idea for objects which are not non-negative). If False, the real part is used.
        binary_dilate_support: number of dilation operations to apply to the support.

    Returns:
        support array (same shape as input, dtype=np.uint): 1 inside the support, 0 outside
    """
    if input_in_fourier_domain:
        kspace = input_array
    else:
        kspace = shifted_fft(input_array)

    # Wiener-Khinchin: autocorrelation = IFT(|F|^2). The inverse transform of the bare modulus |F|
    # is not the autocorrelation, so the square is essential here.
    autocorrelation = shifted_ifft(np.abs(kspace) ** 2)
    if absolute_autocorrelation:
        autocorrelation = np.abs(autocorrelation)
    else:
        # ordering comparisons on complex values are not meaningful; threshold the real part
        autocorrelation = np.real(autocorrelation)

    threshold_val = threshold * np.amax(autocorrelation) if relative_threshold else threshold
    support = (autocorrelation > threshold_val).astype(np.uint)
    if binary_dilate_support > 0:
        support = binary_dilation(support, iterations=binary_dilate_support).astype(np.uint)

    return support