planar_molecule.py 11.9 KB
Newer Older
1
2
3
4
5
from skimage.io import imread
from scipy.ndimage import binary_dilation
import numpy as np
import matplotlib.pyplot as plt

6
from proxtoolbox.experiments.orbitaltomography.orbitExperiment import OrbitalTomographyExperiment
7
from proxtoolbox.utils.visualization.complex_field_visualization import complex_to_rgb
8
from proxtoolbox.utils.orbitaltomog import bin_array, shifted_fft, shifted_ifft, fourier_interpolate, roll_to_pos
9

10
from proxtoolbox.utils.GetData import datadir
11
12
13

class PlanarMolecule(OrbitalTomographyExperiment):
    """Orbital-tomography phase-retrieval experiment for a planar molecule.

    The measurement is a 2D photoelectron (ARPES) pattern read from an image file. ``loadData``
    optionally bins it to ``(Ny, Nx)``, converts it to a field amplitude, determines an object
    support and builds a random-phase initial guess; ``show`` plots the reconstruction.
    """

    @staticmethod
    def getDefaultParameters():
        """Return the default configuration dictionary for this experiment.

        Keys read directly by this class: ``data_filename``, ``from_intensity_data``,
        ``support_filename``, ``sparsity_parameter``, ``use_sparsity_with_support``,
        ``threshold_for_support``, ``interpolate_and_zoom``, ``progressbar`` (and ``Ny``/``Nx``
        via attributes). The remaining keys are presumably consumed by the parent experiment
        class and the algorithm (not visible here).
        """
        # TODO: optimize parameters and proxoperators to get good & consistent phase retrieval using the demo
        defaultParams = {
            'experiment_name': 'noisy 2D ARPES',  # '2D ARPES', #
            'data_filename': datadir / 'OrbitalTomog' / 'coronene_homo1.tif',
            'from_intensity_data': False,
            'object': 'real',
            'constraint': 'sparse real',
            'sparsity_parameter': 40,
            'use_sparsity_with_support': False,
            'threshold_for_support': 0.1,
            'support_filename': None,
            'Nx': None,
            'Ny': None,
            'Nz': 1,
            'MAXIT': 500,
            'TOL': 1e-10,
            'lambda_0': 0.85,
            'lambda_max': 0.50,
            'lambda_switch': 50,
            'data_ball': .999826e-30,
            'TOL2': 1e-15,
            'diagnostic': True,
            'algorithm': 'DRl',
            'iterate_monitor_name': 'FeasibilityIterateMonitor',  # 'IterateMonitor',  #
            'rotate': False,
            'verbose': 1,
            'graphics': 1,
            'interpolate_and_zoom': True,
            'debug': True,
            'progressbar': None,  # Valid options: None, 'tqdm' or 'tqdm_notebook'
        }
        return defaultParams

    def __init__(self, **kwargs):
        """Store the experiment-specific settings from ``kwargs`` (see ``getDefaultParameters``).

        Raises:
            KeyError: if one of the experiment-specific keys read below is missing from ``kwargs``.
        """
        super(PlanarMolecule, self).__init__(**kwargs)

        # do here any data member initialization
        self.data_filename = kwargs['data_filename']  # given as input to class normally
        self.from_intensity_data = kwargs['from_intensity_data']
        self.support_filename = kwargs['support_filename']  # optional class argument
        self.sparsity_parameter = kwargs['sparsity_parameter']
        self.use_sparsity_with_support = kwargs['use_sparsity_with_support']
        self.threshold_for_support = kwargs['threshold_for_support']  # to determine support from the autocorrelation.
        self.interpolate_and_zoom = kwargs['interpolate_and_zoom']  # For plotting
        self.progressbar = kwargs['progressbar']

        # the following data members are set by loadData(), in addition to those specified in parent classes
        self.support = None  # object support
        self.sparsity_support = None  # optional support used together with the sparsity constraint
        self.data_zeros = None  # indices (np.where output) at which the data amplitude is exactly zero

    def loadData(self):
        """
        Load data and set in the correct format for reconstruction

        Parameters are taken from experiment class (self) properties, which must include::

            - data_filename: str or path, path to the data file. If None, or if the file is not
                found, the user is prompted for a path on stdin.
            - from_intensity_data: bool, if the data file gives intensities rather than field amplitude
            - support_filename: str, optional path to file with object support. If None, the support
                is determined by thresholding the autocorrelation of the data.
            - use_sparsity_with_support: bool, if true, use a support before the sparsity constraint.
                The support is calculated by thresholding the object autocorrelation, and dilate the result
            - threshold_for_support: float, in range [0,1], fraction of the maximum at which to threshold when
                determining support or support for sparsity
            - Ny, Nx: int or None. If both are given, the data is binned down to shape (Ny, Nx);
                the data dimensions must be integer multiples of them. Otherwise the resolution is kept.

        Sets ``data`` (field amplitude), ``Ny``, ``Nx``, ``norm_data``, ``support``,
        ``sparsity_support``, ``u0`` (via ``createRandomGuess``), ``data_sq`` and ``data_zeros``.

        Raises:
            ValueError: if Ny and Nx are given but do not evenly divide the data dimensions.
        """
        # load data
        if self.data_filename is None:
            self.data_filename = input('Please enter the path to the datafile: ')
        try:
            self.data = imread(self.data_filename)
        except FileNotFoundError:
            print("Tried path %s, found nothing. " % self.data_filename)
            self.data_filename = input('Please enter a valid path to the datafile: ')
            self.data = imread(self.data_filename)
        if np.issubdtype(self.data.dtype, np.integer):
            # Integer images (e.g. uint16 tif) would wrap around when squared below.
            self.data = self.data.astype(np.float64)

        # Keep the same resolution, unless a target shape (Ny, Nx) is configured
        if self.Ny is not None and self.Nx is not None:
            ny, nx = self.data.shape
            if ny % self.Ny != 0 or nx % self.Nx != 0:
                # TODO: use flexibility allowed by the binning function to prevent this error
                raise ValueError('Incompatible values for Ny and Nx given in configuration dict')
            # binning must be done for the intensity-data, as that preserves the normalization
            if self.from_intensity_data:
                self.data = bin_array(self.data, (self.Ny, self.Nx))
            else:
                self.data = np.sqrt(bin_array(self.data ** 2, (self.Ny, self.Nx)))
        self.Ny, self.Nx = self.data.shape

        # Calculate electric field and norm of the data
        if self.from_intensity_data:
            # avoid sqrt of negative numbers (due to background subtraction)
            self.data = np.where(self.data > 0, np.sqrt(abs(self.data)), np.zeros_like(self.data))
        self.norm_data = np.sqrt(np.sum(self.data ** 2))

        # Object support determination: from file if one is configured, otherwise from the data
        if self.support_filename is not None:
            self.support = imread(self.support_filename)
        else:
            self.support = support_from_autocorrelation(self.data,
                                                        threshold=self.threshold_for_support,
                                                        absolute_autocorrelation=True,
                                                        binary_dilate_support=1)
        if self.use_sparsity_with_support:
            self.sparsity_support = support_from_autocorrelation(self.data,
                                                                 threshold=self.threshold_for_support,
                                                                 binary_dilate_support=1)

        self.createRandomGuess()

        # some variables which are necessary for the algorithm:
        self.data_sq = self.data ** 2
        self.data_zeros = np.where(self.data == 0)

    def createRandomGuess(self):
        """
        Taking the measured data, add a uniformly distributed random phase and store the
        n-dimensional FFT of the resulting complex field in ``self.u0``.

        Uses numpy's global random state, so results are only reproducible if it is seeded.
        """
        ph_init = 2 * np.pi * np.random.random_sample(self.data.shape)
        self.u0 = self.data * np.exp(1j * ph_init)
        self.u0 = np.fft.fftn(self.u0)

    def plotInputData(self):
        """Quick plotting routine to show the data, initial guess and the sparsity support"""
        fig, ax = plt.subplots(1, 3, figsize=(12, 3.5))
        im = ax[0].imshow(self.data)
        fig.colorbar(im, ax=ax[0])
        # Title each subplot explicitly: plt.title() only targets the current axes.
        ax[0].set_title("Photoelectron spectrum")
        ax[1].imshow(complex_to_rgb(self.u0))
        ax[1].set_title("Initial guess")
        if self.sparsity_support is not None:
            ax[2].imshow(self.sparsity_support, cmap='gray')
            ax[2].set_title("Sparsity support")
        plt.show()

    @staticmethod
    def _center_maximum(array):
        """Apply ``roll_to_pos`` towards the array center, first with ``move_maximum=True``, then plainly.

        Presumably this moves the maximum to the center and then re-centers the array; the exact
        semantics are defined by ``roll_to_pos`` in proxtoolbox.utils.orbitaltomog.
        """
        center = tuple(s // 2 for s in array.shape)
        array = roll_to_pos(array, pos=center, move_maximum=True)
        return roll_to_pos(array, pos=center)

    def show(self):
        """
        Create basic result plots of the phase retrieval procedure

        Stores the prox1/prox2 projections of the final iterate in ``self.output['u1']`` and
        ``self.output['u2']`` (centered, 2x Fourier-interpolated and cropped to the central half when
        ``interpolate_and_zoom`` is set), draws three figures, shows them and appends them to
        ``self.output['plots']``.
        """
        self.output['plots'] = []
        self.output['u1'] = self.algorithm.prox1.eval(self.algorithm.u)
        self.output['u2'] = self.algorithm.prox2.eval(self.algorithm.u)
        if self.interpolate_and_zoom:
            for key in ["u1", "u2"]:
                arr = fourier_interpolate(self._center_maximum(self.output[key]), 2)
                # Crop to the central half; an explicit stop index avoids an empty slice when a margin is 0.
                zmy, zmx = tuple(s // 4 for s in arr.shape)
                self.output[key] = arr[zmy:arr.shape[0] - zmy, zmx:arr.shape[1] - zmx]
        u1 = self.output['u1']
        u2 = self.output['u2']
        change = self.output['stats']['changes']

        # Figure f: amplitude and phase of both constraint projections
        f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9, 6))
        im = ax1.imshow(abs(u1), cmap='gray')
        f.colorbar(im, ax=ax1)
        ax1.set_title('best approximation amplitude: \nphysical constraint satisfied')
        ax2.imshow(complex_to_rgb(u1))
        ax2.set_title('best approximation phase: \nphysical constraint satisfied')
        im = ax3.imshow(abs(u2), cmap='gray')
        f.colorbar(im, ax=ax3)
        ax3.set_title('best approximation amplitude: \nFourier constraint satisfied')
        ax4.imshow(complex_to_rgb(u2))
        ax4.set_title('best approximation phase: \nFourier constraint satisfied')
        f.tight_layout()

        # Figure g: physical-constraint iterate and convergence diagnostics
        g, ((bx1, bx2), (bx3, bx4)) = plt.subplots(2, 2, figsize=(9, 6))
        im = bx1.imshow(abs(u1), cmap='gray')
        g.colorbar(im, ax=bx1)
        bx1.set_title('best approximation amplitude: \nphysical constraint satisfied')
        im = bx2.imshow(u1.real, cmap='gray')
        g.colorbar(im, ax=bx2)
        bx2.set_title('best approximation real part: \nphysical constraint satisfied')
        bx3.semilogy(change)
        bx3.set_xlabel('iteration')
        bx3.set_ylabel('Change')
        if 'gaps' in self.output['stats']:
            gaps = self.output['stats']['gaps']
            bx4.semilogy(gaps)
            bx4.set_xlabel('iteration')
            bx4.set_ylabel('Gap')
        g.tight_layout()

        # Figure h: measured data next to the measurement predicted from the centered result
        h, ax = plt.subplots(1, 3, figsize=(9, 3))
        ax[0].imshow(self.data)
        ax[0].set_title("Measured data")
        prop = self.propagator(self)
        guess = self._center_maximum(self.algorithm.prox2.eval(self.algorithm.u))
        u_hat = prop.eval(self.algorithm.prox1.eval(guess))
        ax[1].imshow(abs(u_hat))
        ax[1].set_title("Predicted measurement intensity")
        ax[2].imshow(complex_to_rgb(u_hat))
        ax[2].set_title("Predicted phase (by color)")
        h.tight_layout()
        plt.show()

        self.output['plots'].append(f)
        self.output['plots'].append(g)
        self.output['plots'].append(h)
def support_from_autocorrelation(input_array: np.ndarray,
                                 threshold: float = 0.1,
                                 relative_threshold: bool = True,
                                 input_in_fourier_domain: bool = True,
                                 absolute_autocorrelation: bool = True,
                                 binary_dilate_support: int = 0) -> np.ndarray:
    """
    Determine an initial support by thresholding an autocorrelation-like function of the object.

    Args:
        input_array: either the measured diffraction (arpes pattern) or a guess of the object
        threshold: support is everywhere where the autocorrelation is higher than the threshold
        relative_threshold: If true, threshold at threshold*np.amax(autocorrelation);
            otherwise ``threshold`` is used as an absolute value.
        input_in_fourier_domain: False if a guess of the object is given in input_array
            (it is then transformed with ``shifted_fft`` first)
        absolute_autocorrelation: Take the absolute value of the autocorrelation? (Generally a
            good idea for objects which are not non-negative)
        binary_dilate_support: number of binary dilation iterations to apply to the support
            (0 disables dilation).

    Returns:
        support array with the same shape as the input, dtype np.uint, containing 1 inside the
        support and 0 elsewhere.
    """
    if input_in_fourier_domain:
        kspace = input_array
    else:
        kspace = shifted_fft(input_array)

    # NOTE(review): by the convolution (Wiener-Khinchin) theorem the autocorrelation of the object is
    # the inverse FFT of |F|**2. This uses the inverse FFT of |F| instead. The behavior is kept because
    # the default threshold was presumably tuned for it; confirm before switching to abs(kspace) ** 2.
    autocorrelation = shifted_ifft(abs(kspace))
    if absolute_autocorrelation:
        autocorrelation = abs(autocorrelation)

    threshold_val = threshold * np.amax(autocorrelation) if relative_threshold else threshold
    support = (autocorrelation > threshold_val).astype(np.uint)
    if binary_dilate_support > 0:
        # binary_dilation returns a bool array; cast back so the dtype is np.uint either way
        support = binary_dilation(support, iterations=binary_dilate_support).astype(np.uint)

    return support