iterateMonitor.py 7.88 KB
Newer Older
jansen31's avatar
jansen31 committed
1
2
3
from numpy import zeros, sqrt
from numpy.linalg import norm

4
from proxtoolbox import algorithms
jansen31's avatar
jansen31 committed
5
from proxtoolbox.utils.cell import isCell
6
from proxtoolbox.utils.size import size_matlab
7

Matthijs's avatar
Matthijs committed
8

9
10
11
12
13
14
15
16
class IterateMonitor:
    """
    Base class for iterate monitors.

    This is the default algorithm analyzer for monitoring
    iterates of fixed point algorithms. It records the normalized
    change between successive iterates and, optionally, delegates
    the objective value computation to an optimality monitor.
    """

    def __init__(self, experiment):
        """
        Copy the monitoring configuration from the experiment.

        Parameters
        ----------
        experiment : instance of Experiment class
            Experiment providing the initial guess 'u0', the
            normalization data, the display options and, optionally,
            the name of an optimality monitor class that is looked
            up in the 'algorithms' package.

        Raises
        ------
        AssertionError
            If the experiment does not provide an initial guess.
        """
        self.u0 = experiment.u0
        # Raised explicitly (rather than with an 'assert' statement) so the
        # check is still performed when Python runs with -O. The exception
        # type is kept for backward compatibility.
        if self.u0 is None:
            raise AssertionError('No valid initial guess given')
        # the part of the sequence that is being monitored; u0 by default
        self.u_monitor = self.u0
        self.isCell = isCell(self.u0)
        self.norm_data = experiment.norm_data
        self.truth = experiment.truth
        self.truth_dim = experiment.truth_dim
        self.formulation = experiment.formulation
        self.norm_truth = experiment.norm_truth
        self.diagnostic = experiment.diagnostic
        self.rotate = experiment.rotate
        self.n_product_Prox = experiment.n_product_Prox
        self.anim = experiment.anim
        self.anim_callback = experiment.animate
        self.anim_step = experiment.anim_step
        self.silent = experiment.silent

        # allocated in preprocess(), once the iteration budget is known
        self.changes = None

        # instantiate optimality monitor if it exists
        self.optimality_monitor = None
        optimality_monitor_name = getattr(experiment, 'optimality_monitor', None)
        if optimality_monitor_name is not None:
            optimality_monitor_class = getattr(algorithms, optimality_monitor_name)
            self.optimality_monitor = optimality_monitor_class(experiment)

    def preprocess(self, alg):
        """
        Allocate data structures needed to collect
        statistics. Called before running the algorithm.

        Parameters
        ----------
        alg : instance of Algorithm class
            Algorithm to be monitored.
        """
        self.maxIter = alg.maxIter
        self.changes = zeros(self.maxIter + 1)
        # Entry 0 is only a placeholder: postprocess() reports the
        # entries starting at index 1.
        self.changes[0] = sqrt(999)

    def updateStatistics(self, alg):
        """
        Update statistics. Called at each iteration
        while the algorithm is running.

        Parameters
        ----------
        alg : instance of Algorithm class
            Algorithm being monitored.
        """
        self.u_monitor = alg.u_new  # store the last iterate in u_monitor
        if not self.isCell:
            tmp_change = self.evaluateChange(alg.u, alg.u_new)
        else:
            # cell: accumulate the squared change of every element
            tmp_change = sum(self.evaluateChange(u_elem, u_new_elem)
                             for u_elem, u_new_elem in zip(alg.u, alg.u_new))
        self.changes[alg.iter] = sqrt(tmp_change)

    def displayProgress(self, alg):
        """
        Display progress. This method is called after
        updateStatistics(). The default implementation invokes the
        animation callback every 'anim_step' iterations, provided
        that the monitor is not silent, animation is enabled and a
        callback is available. May be overridden in derived classes.

        Parameters
        ----------
        alg : instance of Algorithm class
            Algorithm being monitored.
        """
        if self.silent or not self.anim or self.anim_callback is None:
            return
        if alg.iter % self.anim_step == 0:
            self.anim_callback(alg)

    def postprocess(self, alg, output):
        """
        Called after the algorithm stops. Store statistics in
        the given 'output' dictionary

        Parameters
        ----------
        alg : instance of Algorithm class
            Algorithm that was monitored.
        output : dictionary
            Contains the last iterate and various statistics that
            were collected while the algorithm was running.

        Returns
        -------
        output : dictionary into which the following entries are added
            u_monitor : ndarray or list of ndarrays.
                Part of the sequence that is being monitored. Default
                is to monitor the entire variable 'u'.
            u1 : ndarray
                Part of 'u_monitor'. Used for display.
                Only added when 'diagnostic' is set.
            u2 : ndarray
                Part of 'u_monitor'. Used for display.
                Only added when 'diagnostic' is set.
            changes: ndarray
                normalized change in successive iterates, stored in
                output['stats'] when that entry already exists
        """
        output['u_monitor'] = self.u_monitor
        if self.diagnostic:
            if isCell(self.u_monitor):
                u_m = self.u_monitor[0]
            else:
                u_m = self.u_monitor
            output['u1'], output['u2'] = self._selectDisplayParts(u_m)

        if 'stats' in output:
            output['stats']['changes'] = self.changes[1:alg.iter + 1]
        return output

    def _selectDisplayParts(self, u_m):
        """
        Pick the two parts (u1, u2) of the monitored iterate that
        are used for display.

        Parameters
        ----------
        u_m : ndarray or cell
            Monitored iterate (or its first element when the
            monitored iterate is itself a cell).

        Returns
        -------
        u1, u2 : first and last block of 'u_m' along the dimension
            whose size equals 'n_product_Prox'; 'u_m' itself (twice)
            when no dimension matches.
        """
        if isCell(u_m):
            # in matlab this corresponded to a cell
            # here we assume that such a cell is m by n where m = 1
            # so far this seems to be always the case
            # the following tests attempt to match the ndarray case below
            n = len(u_m)
            if n == self.n_product_Prox:
                return u_m[0], u_m[n - 1]
            # although the following case corresponds to Matlab's
            # code, this is suspicious because u1 and u2 do not
            # have the same structure as in the case above
            return u_m, u_m

        # ndarray
        if u_m.ndim == 1:
            return u_m, u_m

        # Python code for [m,n,p] = size(u_m)
        m, n, p, _q = size_matlab(u_m)
        if n == self.n_product_Prox:
            return u_m[:, 0], u_m[:, n - 1]
        if m == self.n_product_Prox:
            return u_m[0, :], u_m[m - 1, :]
        if p == self.n_product_Prox:
            return u_m[:, :, 0], u_m[:, :, p - 1]
        return u_m, u_m

    def getIterateSize(self, u):
        """
        Given an iterate, determine p and q parameters

        p is the size of the third dimension of 'u' and q the size
        of its fourth dimension; a dimension that 'u' does not have
        counts as 1.

        :param u: iterate (ndarray)
        :return: p,q (ints)
        """
        if u.ndim < 3:
            return 1, 1
        if u.ndim == 3:
            return u.shape[2], 1
        return u.shape[2], u.shape[3]

    def evaluateChange(self, u, u_new):
        """
        Given an old and new iterate calculate the total absolute
        squared difference, normalized to self.norm_data

        :param u: iterate
        :param u_new: new iterate
        :return: sum of abs squared differences = frobenius norm squared
        """
        # With neither 'ord' nor 'axis' given, norm() returns the 2-norm of
        # the flattened array whatever its number of dimensions. Its square
        # equals the sum of the squared norms of the individual 2D slices,
        # so no explicit loop over the trailing dimensions is needed.
        return (norm(u - u_new) / self.norm_data) ** 2

    def calculateObjective(self, alg):
        """
        Calculate objective value. The default implementation
        uses the optimality monitor if it exists

        Parameters
        ----------
        alg : instance of Algorithm class
            Algorithm that was monitored.

        Returns
        -------
        objValue : real
            objective value

        Raises
        ------
        AttributeError
            If no optimality monitor was provided.
        """
        if self.optimality_monitor is None:
            raise AttributeError("optimality_monitor was not provided")
        return self.optimality_monitor.calculateObjective(alg)