phase.py 7.9 KB
Newer Older
1
2
3
4
5
6
# -*- coding: utf-8 -*-

from proxtoolbox.Problems.problems import Problem
from proxtoolbox import Algorithms
from proxtoolbox import ProxOperators
from proxtoolbox.ProxOperators.proxoperators import ProxOperator
Russell Luke's avatar
Russell Luke committed
7
from proxtoolbox.Problems.OrbitalTomog import Graphics
8
9
# from .proxoperators import ProxOperator
#from phase_retrieval.back_end_utilities import norm
10
from numpy.linalg import norm
11
from numpy import square, sqrt
12
13
14
15
16
17


class Phase(Problem):
    """
    Phase retrieval problem.

    On construction, the proximal operators are selected from the
    configuration: an object-space operator chosen by ``config['constraint']``
    and the Fourier-space operator ``Approx_P_FreFra_Poisson``. The initial
    gap between the two operators applied to ``config['u_0']`` is measured
    and used to rescale ``data_ball`` and ``TOL2``. Finally, the algorithm
    named by ``config['algorithm']`` is instantiated.
    """

    # Replaced by the instance's own configuration dict in __init__.
    config = {}

    # Maps each accepted value of config['constraint'] to the name of the
    # object-space (projector 1) operator in proxtoolbox.ProxOperators.
    _CONSTRAINT_PROXOPERATORS = {
        'support only': 'P_S',
        'real and support': 'P_S_real',
        'nonnegative and support': 'P_SP',
        'amplitude only': 'P_amp',
        'sparse real': 'P_Sparsity_real',
        'sparse complex': 'P_Sparsity',
        'symmetric sparse real': 'P_Sparsity_Symmetric_real',
        'sparse symmetric real': 'P_Sparsity_Symmetric_real',
        'symmetric sparse complex': 'P_Sparsity_Symmetric',
        # Word-order alias mirroring 'sparse symmetric real' above (the alias
        # list previously repeated 'symmetric sparse complex' instead).
        'sparse symmetric complex': 'P_Sparsity_Symmetric',
    }

    # Name of the k / Fourier space (projector 2) operator in
    # proxtoolbox.ProxOperators. An earlier, commented-out choice was 'P_M'.
    _FOURIER_PROXOPERATOR = 'Approx_P_FreFra_Poisson'

    def __init__(self, new_config):
        """
        Set up the phase retrieval problem from a configuration dict.

        The dict is stored by reference, not copied. Defaults and derived
        values are written back into it, so the caller's dict is modified.

        Parameters
        ----------
        new_config : dict
            Keys read directly here:

            constraint : str
                Object-space constraint; one of the keys of
                ``Phase._CONSTRAINT_PROXOPERATORS``.
            formulation : str, optional
                ``'sequential'`` (not implemented) or anything else. Any
                other value is treated as ``'product space'``, which is also
                the default.
            u_0 : array
                Initial iterate.
            Nx, Ny : int
                Array dimensions. If either is 1, the gap is computed over
                the whole array (1D problem).
            Nz : int, optional
                Defaults to 1.
            product_space_dimension : int, optional
                Defaults to 1.
            norm_data : float
                Normalization applied to the gap.
            data_ball : float
                Relative set-fattening; multiplied by the initial gap.
            algorithm : str
                Name of the class in ``proxtoolbox.Algorithms`` to run.

            The proximal-operator and algorithm constructors receive the
            same dict and may read further keys not listed here.

        Keys written: ``Nz`` and ``product_space_dimension`` (defaults only),
        ``proxoperators``, ``animation``, ``TOL2``, ``data_ball``.

        Raises
        ------
        ValueError
            If ``constraint`` is not recognized.
        NotImplementedError
            If ``formulation`` is ``'sequential'``.
        """
        self.config = new_config
        self.config.setdefault('Nz', 1)

        # Without an explicit formulation, use the product space.
        formulation = self.config.get('formulation', 'product space')

        # Projector 1 (real / object space), then projector 2 (k / Fourier space).
        self.config['proxoperators'] = [
            getattr(ProxOperators, name)
            for name in (self._object_proxoperator_name(), self._FOURIER_PROXOPERATOR)
        ]

        self.config.setdefault('product_space_dimension', 1)

        # Set the animation program.
        self.config['animation'] = 'Phase_animation'

        # If you are only working with two sets but want averaged projections
        # (= alternating projections on the product space) or RAAR on the
        # product space (= swarming), set product_space_dimension=2 and adjust
        # the input files and projectors accordingly (possibly in the data
        # processor).

        # Provisional second tolerance, rescaled below once the initial gap is
        # known. It is set before the prox operators are constructed, because
        # they receive this config and may read it.
        self.config['TOL2'] = 1e-15

        if formulation == 'sequential':
            # Estimating the gap here requires building the product-space
            # point with the sequential ("RCAAR") version of projector 1 for
            # every j; that code has not been ported yet.
            raise NotImplementedError("This needs careful matlab -> python translation")

        # Product-space formulation: one pass of each projector starting at u_0.
        u_1 = self.config['proxoperators'][0](self.config).work(self.config['u_0'])
        u_2 = self.config['proxoperators'][1](self.config).work(u_1)

        gap_0 = sqrt(self._initial_gap_squared(u_1, u_2))

        # Fatten the data set by a fraction (the user's data_ball) of the
        # initial gap, and make the second tolerance relative to that
        # magnitude.
        self.config['data_ball'] = self.config['data_ball'] * gap_0
        self.config['TOL2'] = self.config['data_ball'] * 1e-15

        self.algorithm = getattr(Algorithms, self.config['algorithm'])(self.config)

    def _object_proxoperator_name(self):
        """
        Return the object-space proximal operator name for config['constraint'].

        Raises ValueError if the constraint is not one of the recognized values.
        """
        constraint = self.config['constraint']
        try:
            return self._CONSTRAINT_PROXOPERATORS[constraint]
        except (KeyError, TypeError):  # TypeError: unhashable constraint value
            raise ValueError('Constraint not recognized') from None

    def _initial_gap_squared(self, u_1, u_2):
        """
        Return (||u_1 - u_2|| / norm_data)**2.

        For a multi-dimensional product space, the squared norms of the slices
        u[:, :, j] are summed over j.
        """
        norm_data = self.config['norm_data']
        product_space_dimension = self.config['product_space_dimension']
        # A 1D problem, or a single product-space slice: use the whole array.
        if (self.config['Nx'] == 1 or self.config['Ny'] == 1
                or product_space_dimension == 1):
            return square(norm(u_1 - u_2) / norm_data)
        # Sum over j of (||P_S x_j - P_M x_j|| / norm_data)^2.
        return sum(
            square(norm(u_1[:, :, j] - u_2[:, :, j]) / norm_data)
            for j in range(product_space_dimension)
        )

    def _presolve(self):
        """
        Prepare arguments for the actual solving routine.

        Nothing is needed for the phase problem.
        """
        pass

    def _solve(self):
        """
        Run the algorithm to solve the problem.

        Calls ``self.algorithm.run`` with config['u_0'], config['TOL'] and
        config['MAXIT']. The result is stored in ``self.output`` and its
        'iter' entry is printed.
        """
        self.output = self.algorithm.run(self.config['u_0'], self.config['TOL'], self.config['MAXIT'])
        print('Iterations:' + str(self.output['iter']))

    def _postsolve(self):
        """
        Post-process the outputs 'u', 'u1' and 'u2' in ``self.output``.

        Position is a degree of freedom of the solution. Each output is
        therefore rolled so that first its maximum, then its center of mass,
        lands at (Ny / 2, Nx / 2). This sequence works for objects with a
        small support even if they lie over the edge of the array.

        Two optional steps follow:
        - if config['interpolate_result'] is truthy, the output is
          Fourier-interpolated;
        - if config['zoomin_on_result'] is truthy, int(Ny / 4) rows and
          int(Nx / 4) columns are cropped from each side.
        """
        ny, nx = self.config['Ny'], self.config['Nx']
        yc, xc = int(ny / 2), int(nx / 2)
        zmy, zmx = int(ny / 4), int(nx / 4)
        interpolate = self.config.get('interpolate_result')
        zoom = self.config.get('zoomin_on_result')
        for key in ('u', 'u1', 'u2'):
            result = Graphics.roll_to_pos(self.output[key], yc, xc, move_maximum=True)  # first move maximum
            result = Graphics.roll_to_pos(result, yc, xc)  # then move center of mass
            if interpolate:
                result = Graphics.fourier_interpolate(result)
            if zoom:
                # `-m or None` keeps the full axis when m == 0 (Ny or Nx < 4).
                # A plain [m:-m] would become [0:-0] == [0:0], an empty array.
                result = result[zmy:-zmy or None, zmx:-zmx or None]
            self.output[key] = result

    def show(self):
        """
        Print the elapsed calculation time and display the solution.

        The display uses the Graphics function named by
        config['graphics_display']. ``self.elapsed_time`` is not set in this
        class (presumably by ``Problem`` — confirm).
        """
        print("Calculation time:")
        print(self.elapsed_time)
        _graphics = getattr(Graphics, self.config['graphics_display'])
        _graphics(self.config, self.output)