# -*- coding: utf-8 -*-

from proxtoolbox.Problems.problems import Problem
from proxtoolbox import Algorithms
from proxtoolbox import ProxOperators
from proxtoolbox.ProxOperators.proxoperators import ProxOperator
Russell Luke's avatar
Russell Luke committed
7
from proxtoolbox.Problems.OrbitalTomog import Graphics
8
from numpy.linalg import norm
9
from numpy import square, sqrt
Matthijs's avatar
Matthijs committed
10
from proxtoolbox.Utilities.OrbitalTomog import interpolation, array_tools, binning
11
12
13
14
15
16


class Phase(Problem):
    """
    Phase Problem
    """
17
    def __init__(self, new_config):
18
19
20
21
22
23
        """
        The initialization of a Phase instance takes the default configuration
        and updates the parameters with the arguments in new_config.
        
        Parameters
        ----------
24
25
26
27
28
29
30
        new_config : dict with non-empty keys:
            object, constraint, experiment, distance, algorithm
            Ny, Nx
            noise
            MAXIT, TOL
            diagnostic
            data, data_norm, support,
31
        """
32
        self.config = new_config
33
34
35
36

        if 'Nz' not in self.config:
            self.config['Nz'] = 1

37
        # If method_config['formulation'] does not exist, use the product space as the default.
38
39
40
41
42
        if 'formulation' in self.config:
            formulation = self.config['formulation']
        else:
            formulation = 'product space'

43
44
        # Set the projectors and inputs based on the types of constraints and experiments
        used_proxoperators = ['', '', '']
45

46
        # Projector 1 (real / object space)
47
        if self.config['constraint'] == 'support only':
48
            used_proxoperators[0] = 'P_S'
49
        elif self.config['constraint'] == 'real and support':
50
51
52
53
54
            used_proxoperators[0] = 'P_S_real'
        elif self.config['constraint'] == 'nonnegative and support':
            used_proxoperators[0] = 'P_SP'
        elif self.config['constraint'] == 'amplitude only':
            used_proxoperators[0] = 'P_amp'
55
56
57
58
        elif self.config['constraint'] == 'sparse real':
            used_proxoperators[0] = 'P_Sparsity_real'
        elif self.config['constraint'] == 'sparse complex':
            used_proxoperators[0] = 'P_Sparsity'
59
60
61
62
63
64
        elif self.config['constraint'] in ['symmetric sparse real', 'sparse symmetric real']:
            used_proxoperators[0] = 'P_Sparsity_Symmetric_real'
        elif self.config['constraint'] in ['symmetric sparse complex', 'symmetric sparse complex']:
            used_proxoperators[0] = 'P_Sparsity_Symmetric'
        else:
            raise ValueError('Constraint not recognized')
65

66
        # Projector 2 (k / Fourier space)
Matthijs's avatar
Matthijs committed
67
68
69
70
        if self.config['experiment'] == '3D ARPES':
            used_proxoperators[1] = 'P_M_masked'
        else:
            used_proxoperators[1] = 'Approx_P_FreFra_Poisson'
71
72
73

        self.config['proxoperators'] = []

74
        for prox in used_proxoperators:
75
76
77
78
79
80
81
            if prox != '':
                self.config['proxoperators'].append(getattr(ProxOperators, prox))

        if 'product_space_dimension' not in self.config:
            self.config['product_space_dimension'] = 1

        # set the animation program:
82
83
84
85
86
87
        self.config['animation'] = 'Phase_animation'

        # if you are only working with two sets but want to do averaged projections (= alternating projections on the
        # product space) or RAAR on the product space (=swarming), then you will want to change
        # product_space_dimension=2 and adjust your input files and projectors accordingly. you could also do this
        # within the data processor
88
89
90

        self.config['TOL2'] = 1e-15

91
92
93
        # To estimate the gap in the sequential formulation, we build the appropriate point in the product space.
        # This allows for code reuse. Note for sequential diversity diffraction, input.Proj1 is the "RCAAR" version
        # of the function.
94
        if formulation == 'sequential':
95
96
97
98
99
100
101
102
103
104
            raise NotImplementedError("This needs careful matlab -> python translation")
            # for j in range(self.config['product_space_dimension']):
            #     self.config['proj_iter'] = j
            #     proj1 = self.config['proxoperators'][0](self.config)
            #     u_1[:, :, j] = proj1.work(self.config['u_0'])
            #     self.config['proj_iter'] = mod(j, config['product_space_dimension']) + 1
            #     proj1 = self.config['proxoperators'][0](self.config)
            #     u_1[:, :, j] = proj1.work(self.config['u_0'])
            # end
        else:  # i.e. formulation=='product space'
Matthijs's avatar
Matthijs committed
105
            proj_1 = self.config['proxoperators'][0](self.config)  # , self.back_end)
106
            u_1 = proj_1.work(self.config['u_0'])
Matthijs's avatar
Matthijs committed
107
            proj_2 = self.config['proxoperators'][1](self.config)  # , self.back_end)
108
            u_2 = proj_2.work(u_1)
109
110

        # estimate the gap in the relevant metric
111
112
        if self.config['Nx'] == 1 or self.config['Ny'] == 1:  # 1D problem
            tmp_gap = square(norm(u_1 - u_2) / self.config['norm_data'])  # norm_rt_data
113
        elif self.config['product_space_dimension'] == 1:
114
            tmp_gap = (norm(u_1 - u_2) / self.config['norm_data']) ** 2  # norm_rt_data
115
        else:
116
            tmp_gap = 0
117
118
            for j in range(self.config['product_space_dimension']):
                # compute (||P_Sx-P_Mx||/norm_data)^2:
119
120
121
122
123
124
125
126
127
128
                tmp_gap = tmp_gap + (norm(u_1[:, :, j] - u_2[:, :, j]) / self.config['norm_data']) ** 2  # norm_rt_data

        gap_0 = sqrt(tmp_gap)

        # sets the set fattening to be a percentage of the initial gap to the unfattened set with respect to the
        # relevant metric (KL or L2), that percentage given by input.data_ball input by the user.
        self.config['data_ball'] = self.config['data_ball'] * gap_0
        # the second tolerance relative to the order of magnitude of the metric
        self.config['TOL2'] = self.config['data_ball'] * 1e-15
        # self.config['proxoperators']
Russell Luke's avatar
Russell Luke committed
129
        self.algorithm = getattr(Algorithms, self.config['algorithm'])(self.config)#, self.back_end)
130
131
132
133
134

    def _presolve(self):
        """
        Prepares argument for actual solving routine
        """
Matthijs's avatar
Matthijs committed
135
        pass  # Actually nothing for the phase problem
136

137
138
    def _solve(self):
        """
139
140
141
        Runs the algorithm to solve the given problem

        Output can be accessed from self.output.
142
        """
143
144
        # Call to the algorithm, specifically run in SimpleAlgorithm
        self.output = self.algorithm.run(self.config['u_0'], self.config['TOL'], self.config['MAXIT'])
145
146
147
148
149
150
        print('Iterations:' + str(self.output['iter']))

    def _postsolve(self):
        """
        Centre the reconstructed arrays and optionally interpolate / crop them.

        The entries 'u', 'u1' and 'u2' of ``self.output`` are each rolled to
        the array centre (position is a degree of freedom of the solution),
        then

        * Fourier-interpolated if ``config['interpolate_result']`` is truthy,
        * cropped if ``config['zoomin_on_result']`` is truthy: for 2D and 3D
          arrays a quarter of every axis is removed on both sides.

        The processed arrays replace the originals in ``self.output``.
        """
        # NOTE(review): the centre is derived from config['u'], not from the
        # output arrays themselves -- confirm that they share the same shape.
        center = tuple(extent // 2 for extent in self.config['u'].shape)
        interpolate = self.config.get('interpolate_result')
        zoom_in = self.config.get('zoomin_on_result')

        for key in ('u', 'u1', 'u2'):
            # Rolling by the maximum first and then again without it works for
            # objects *with a small support* even if they lie over the edge of
            # the array.
            result = array_tools.roll_to_pos(self.output[key], pos=center, move_maximum=True)
            result = array_tools.roll_to_pos(result, pos=center)

            if interpolate:
                result = interpolation.fourier_interpolate(result)

            if zoom_in and result.ndim in (2, 3):
                # The margin is a quarter of the array's own extent, so the
                # crop stays correct after interpolation has resized it.
                # Explicit stop indices avoid the empty result that
                # ``a[0:-0]`` would give on very short axes.
                window = tuple(
                    slice(extent // 4, extent - extent // 4)
                    for extent in result.shape
                )
                result = result[window]

            self.output[key] = result

        # # Old code, developed for 2D orbital imaging:
        # yc, xc = int(self.config['Ny'] / 2), int(self.config["Nx"] / 2)
        # for key in ['u', 'u1', 'u2']:
        #     self.output[key] = array_tools.roll_to_pos(self.output[key], yc, xc, move_maximum=True)
        #     self.output[key] = array_tools.roll_to_pos(self.output[key], yc, xc)
        #     # This sequence will work for objects *with a small support* even if they lie over the edge of the array
        #     if 'interpolate_result' in self.config and self.config['interpolate_result']:
        #         self.output[key] = interpolation.fourier_interpolate(self.output[key])
        #     if 'zoomin_on_result' in self.config and self.config['zoomin_on_result']:
        #         zmy, zmx = int(self.config['Ny'] / 4), int(self.config["Nx"] / 4)
        #         self.output[key] = self.output[key][zmy:-zmy, zmx:-zmx]
180

181
182
183
184
185
186
    def show(self):
        """
        Print the elapsed calculation time and display the solution.

        The display routine is the attribute of the ``Graphics`` module named
        by ``config['graphics_display']``; it is called with the configuration
        and the output dictionary.
        """
        print("Calculation time:")
        elapsed = self.elapsed_time
        print(elapsed)

        render = getattr(Graphics, self.config['graphics_display'])
        render(self.config, self.output)
Matthijs's avatar
Matthijs committed
189
190
191
192
193
194
195
196
197
198
199
200

    def save(self):
        """
        Saves inputs and outputs of the reconstruction procedure

        Necessary inputs: save directory,

        :return:
        """
        # TODO: some basic procedure. perhaps just a python pickle? (of the config and output dictionaries)
        raise NotImplementedError