# phase.py
# -*- coding: utf-8 -*-

from proxtoolbox.Problems.problems import Problem
from proxtoolbox import Algorithms
from proxtoolbox import ProxOperators
from proxtoolbox.ProxOperators.proxoperators import ProxOperator
from proxtoolbox.Problems.Phase import Graphics
8
9
# from .proxoperators import ProxOperator
#from phase_retrieval.back_end_utilities import norm
10
from numpy.linalg import norm
11
from numpy import square, sqrt
12
13
14
15
16
17


class Phase(Problem):
    """
    Phase Problem

    Sets up a phase-retrieval run from a configuration dictionary: selects the
    prox operators for the requested constraint, estimates the initial gap
    between the two constraint sets, scales the data ball by that gap and
    instantiates the requested algorithm.
    """

    # Class-level default; every instance replaces it with its own dict in
    # __init__, so Phase itself never writes to this shared dict.
    config = {}

    # config['constraint'] -> name of the object-space (real-space) prox
    # operator, looked up as an attribute of proxtoolbox.ProxOperators.
    _OBJECT_SPACE_PROXOPERATORS = {
        'support only': 'P_S',
        'real and support': 'P_S_real',
        'nonnegative and support': 'P_SP',
        'amplitude only': 'P_amp',
    }

    # Name of the Fourier-space (k-space) prox operator.
    _FOURIER_SPACE_PROXOPERATOR = 'Approx_P_FreFra_Poisson'

    def __init__(self, new_config):
        """
        Build a Phase problem from ``new_config``.

        ``new_config`` is stored as ``self.config`` and mutated in place: it is
        not copied, and no defaults are merged in other than those set below.

        Parameters
        ----------
        new_config : dict
            Required keys: constraint, algorithm, Nx, Ny, u_0, norm_data,
            data_ball (relative size of the data ball; replaced here by
            data_ball * initial gap).
            Optional keys: formulation (default 'product space'),
            Nz (default 1), product_space_dimension (default 1).
            TOL, MAXIT and graphics_display are read later by _solve / show.
            Keys written: Nz, product_space_dimension (defaults), proxoperators,
            animation, TOL2, data_ball.

        Raises
        ------
        ValueError
            If ``new_config['constraint']`` is not a supported constraint.
        NotImplementedError
            If ``new_config['formulation']`` is 'sequential'.
        """
        self.config = new_config

        self.config.setdefault('Nz', 1)

        # Without an explicit formulation, use the product space.
        formulation = self.config.get('formulation', 'product space')

        # Prox operator classes in order: [object space, Fourier space].
        self.config['proxoperators'] = [
            getattr(ProxOperators, name) for name in self._proxoperator_names()
        ]

        # If you are only working with two sets but want to do averaged
        # projections (= alternating projections on the product space) or RAAR
        # on the product space (= swarming), set product_space_dimension=2 and
        # adjust your input files and projectors accordingly (this could also
        # be done within the data processor).
        self.config.setdefault('product_space_dimension', 1)

        self.config['animation'] = 'Phase_animation'

        # Provisional second tolerance: the prox operators built in
        # _initial_gap receive this config, so TOL2 must already exist there.
        self.config['TOL2'] = 1e-15

        gap_0 = self._initial_gap(formulation)

        # Fatten the data set by a percentage (the user-supplied data_ball) of
        # the initial gap, measured in the relevant metric.
        self.config['data_ball'] = self.config['data_ball'] * gap_0
        # Second tolerance relative to the order of magnitude of the metric.
        self.config['TOL2'] = self.config['data_ball'] * 1e-15

        algorithm_class = getattr(Algorithms, self.config['algorithm'])
        self.algorithm = algorithm_class(self.config)

    def _proxoperator_names(self):
        """
        Return the prox operator names for this problem, in order:
        object space (chosen by config['constraint']) then Fourier space.

        Raises
        ------
        ValueError
            If config['constraint'] is not one of the supported constraints.
        """
        constraint = self.config['constraint']
        if constraint not in self._OBJECT_SPACE_PROXOPERATORS:
            # Previously an unknown constraint was silently skipped, which
            # shifted the Fourier operator into slot 0 and failed later with
            # an IndexError; fail early with a clear message instead.
            raise ValueError(
                "Unsupported constraint %r; expected one of %s"
                % (constraint, sorted(self._OBJECT_SPACE_PROXOPERATORS)))
        return [self._OBJECT_SPACE_PROXOPERATORS[constraint],
                self._FOURIER_SPACE_PROXOPERATOR]

    def _initial_gap(self, formulation):
        """
        Apply both prox operators once, starting from config['u_0'], and
        return the gap between the two results normalized by
        config['norm_data'].

        For product_space_dimension > 1 (and a non-1D problem) the squared
        normalized gaps of the slices u[:, :, j] are summed before the
        square root is taken.

        Raises
        ------
        NotImplementedError
            For the 'sequential' formulation (not yet ported from MATLAB).
        """
        if formulation == 'sequential':
            # The sequential formulation would build the appropriate point in
            # the product space so the gap code can be reused (for sequential
            # diversity diffraction, Proj1 is the "RCAAR" version).
            raise NotImplementedError("This needs careful matlab -> python translation")

        # i.e. formulation == 'product space'
        proxoperators = self.config['proxoperators']
        u_1 = proxoperators[0](self.config).work(self.config['u_0'])
        u_2 = proxoperators[1](self.config).work(u_1)

        norm_data = self.config['norm_data']
        if (self.config['Nx'] == 1 or self.config['Ny'] == 1
                or self.config['product_space_dimension'] == 1):
            # 1D problem or a single product-space component.
            tmp_gap = (norm(u_1 - u_2) / norm_data) ** 2
        else:
            # Sum of (||P_S x - P_M x|| / norm_data)^2 over the components.
            tmp_gap = sum(
                (norm(u_1[:, :, j] - u_2[:, :, j]) / norm_data) ** 2
                for j in range(self.config['product_space_dimension'])
            )
        return sqrt(tmp_gap)

    def _presolve(self):
        """
        Prepares argument for actual solving routine
        """
        pass  # Nothing to prepare for the phase problem.

    def _solve(self):
        """
        Run the configured algorithm from config['u_0'] with config['TOL'] and
        config['MAXIT'] and store its result in self.output.

        The result is indexed with 'iter' below, so it is presumably a dict
        (verify against the algorithm classes).
        """
        self.output = self.algorithm.run(self.config['u_0'], self.config['TOL'], self.config['MAXIT'])
        print('Iterations:' + str(self.output['iter']))

    def _postsolve(self):
        """
        Processes the solution and generates the output
        """
        pass

    def show(self):
        """
        Print the elapsed time and render the solution with the graphics
        routine named by config['graphics_display'].

        self.elapsed_time is not set in this class; presumably the Problem
        base class sets it when solving (TODO confirm).
        """
        print("Calculation time:")
        print(self.elapsed_time)
        display = getattr(Graphics, self.config['graphics_display'])
        display(self.config, self.output)