# RAAR.py
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 14 12:48:26 2015

@author: rebecca
"""

from math import exp, sqrt
from numpy import zeros

from scipy.linalg import norm

from .algorithms import Algorithm

class RAAR(Algorithm):
    """
    Relaxed Averaged Alternating Reflection algorithm.

    With reflectors R_i = 2*P_i - I built from the two prox operators,
    each iteration computes

        u_{k+1} = beta_k/2 * (R1(R2(u_k)) + u_k) + (1 - beta_k) * P2(u_k)

    where beta_k blends smoothly from beta0 towards beta_max
    (see ``_relaxation``).
    """

    def __init__(self, config):
        """
        Parameters
        ----------
        config : dict
                 Dictionary containing the problem configuration.
                 It must contain the following mappings:

                proxoperators: 2 ProxOperators
                    Tuple of ProxOperators (the class, no instance);
                    each one is instantiated here with ``config``.
                beta0: number
                    Starting relaxation parameter
                beta_max: number
                    Maximum relaxation parameter
                beta_switch: int
                    Iteration scale of the smooth transition of beta
                    from beta0 -> beta_max
                normM: number
                    Normalization constant; the norms reported in
                    ``change`` and ``gap`` are divided by it.
                Nx: int
                    Row-dim of the product space elements
                Ny: int
                    Column-dim of the product space elements
                Nz: int
                    Depth-dim of the product space elements
                dim: int
                    Size of the product space
        """
        self.proj1 = config['proxoperators'][0](config)
        self.proj2 = config['proxoperators'][1](config)
        self.normM = config['normM']
        self.beta0 = config['beta0']
        self.beta_max = config['beta_max']
        self.beta_switch = config['beta_switch']
        self.Nx = config['Nx']
        self.Ny = config['Ny']
        self.Nz = config['Nz']
        self.dim = config['dim']
        self.iters = 0

    def _relaxation(self, iters):
        """Return beta for iteration count ``iters``.

        beta = w*beta0 + (1-w)*beta_max with w = exp(-(iters/beta_switch)^3),
        so beta starts at beta0 and approaches beta_max as iters grows.
        """
        # float() keeps the division true division even under Python 2.
        weight = exp((-float(iters) / self.beta_switch) ** 3)
        return (weight * self.beta0) + ((1 - weight) * self.beta_max)

    def _sq_dist(self, a, b):
        """Return the squared, normM-normalized Frobenius distance of a and b.

        The distance is accumulated over the product-space slices
        according to the Nx/Ny/Nz layout.
        """
        normM = self.normM
        if self.Ny == 1 or self.Nx == 1:
            return (norm(a - b, 'fro') / normM) ** 2
        total = 0
        if self.Nz == 1:
            for j in range(self.dim):
                total += (norm(a[:, :, j] - b[:, :, j], 'fro') / normM) ** 2
        else:
            for j in range(self.dim):
                for k in range(self.Nz):
                    total += (norm(a[:, :, k, j] - b[:, :, k, j], 'fro')
                              / normM) ** 2
        return total

    def _first_component(self, arr):
        """Return the first product-space component of ``arr``.

        NumPy indexing is zero-based; the MATLAB original used index 1.
        When Nx, Ny and Nz all differ from 1, ``arr`` is returned whole.
        """
        if self.Ny == 1:
            return arr[:, 0]
        if self.Nx == 1:
            return arr[0, :]
        if self.Nz == 1:
            return arr[:, :, 0]
        return arr

    def run(self, u, tol, maxiter):
        """
        Runs the algorithm for the specified input data.

        Parameters
        ----------
        u : array_like
            Initial iterate. Its layout is assumed to follow Nx/Ny/Nz/dim
            (2-D when Nx or Ny is 1, slices u[:, :, j] when Nz is 1,
            slices u[:, :, k, j] otherwise) -- TODO confirm with callers.
        tol : number
            Iteration stops as soon as ``change`` drops below tol.
        maxiter : int
            Maximum number of iterations.

        Returns
        -------
        u1, u2 : array_like
            First product-space component of P1(x) and P2(x), where x is
            the last P2 iterate (whole arrays when Nx, Ny, Nz are all != 1).
        iters : int
            Number of iterations performed.
        change : ndarray of float, length iters
            Normalized step ||u_{k+1} - u_k|| / normM per iteration.
        gap : ndarray of float, length iters
            Normalized ||P1(P2(u_{k+1})) - P2(u_{k+1})|| / normM per iteration.
        """
        ##### PREPROCESSING
        iters = self.iters
        # Float storage: the entries are norms, so they must not inherit an
        # integer (truncating) or complex dtype from u.
        change = zeros(maxiter + 1, dtype=float)
        change[0] = 999  # sentinel so the first loop test passes
        gap = change.copy()

        # tmp2 holds P2(u) at every loop boundary; initializing it here
        # keeps the postprocessing valid even if the loop never runs.
        tmp2 = self.proj2.work(u)
        tmp1 = 2 * tmp2 - u  # R2(u)

        ##### LOOP
        while iters < maxiter and change[iters] >= tol:
            beta = self._relaxation(iters)
            iters += 1

            tmp3 = self.proj1.work(tmp1)
            tmp_u = ((beta * (2 * tmp3 - tmp1)) + ((1 - beta) * tmp1) + u) / 2
            tmp2 = self.proj2.work(tmp_u)
            tmp3 = self.proj1.work(tmp2)

            change[iters] = sqrt(self._sq_dist(u, tmp_u))
            gap[iters] = sqrt(self._sq_dist(tmp3, tmp2))

            u = tmp_u
            tmp1 = (2 * tmp2) - tmp_u  # R2 of the new iterate

        ##### POSTPROCESSING
        u = tmp2
        tmp = self.proj1.work(u)
        tmp2 = self.proj2.work(u)
        u1 = self._first_component(tmp)
        u2 = self._first_component(tmp2)
        change = change[1:iters + 1]
        gap = gap[1:iters + 1]

        return u1, u2, iters, change, gap