Commit 0905afea authored by s.gretchko's avatar s.gretchko
Browse files

Added ADMM algorithm, iterate monitor, prox operators & JWST demo

parent 58db158e
# Demo driver: phase retrieval on JWST data using the ADMM algorithm.
# SetProxPythonPath presumably adjusts sys.path so proxtoolbox is importable
# when running from the demos directory — TODO confirm.
import SetProxPythonPath
from proxtoolbox.experiments.phase.JWST_Experiment import JWST_Experiment
# lambda_0 == lambda_max pins the relaxation parameter at 3.0 (no annealing)
JWST = JWST_Experiment(algorithm='ADMM', lambda_0=3.0, lambda_max=3.0)
JWST.run()
JWST.show()
from proxtoolbox.algorithms.algorithm import Algorithm
from proxtoolbox import algorithms
from proxtoolbox.utils.cell import Cell, isCell
from proxtoolbox.utils.size import size_matlab
import numpy as np
from numpy import pi, zeros, conj
from numpy.fft import fft2, ifft2, fft, ifft
class ADMM(Algorithm):
    """
    Alternating directions method of multipliers for solving problems
    of the form:

        minimize f(x) + g(y),
        subject to Ax = y

    Based on Matlab code written by Russell Luke (Inst. Fuer
    Numerische und Angewandte Mathematik, Universitaet
    Gottingen) on October 2, 2017.
    """

    def __init__(self, experiment, iterateMonitor, accelerator):
        """
        Parameters
        ----------
        experiment : experiment instance
            Must carry a `lagrange_mult_update` attribute naming the
            Lagrange-multiplier update class in the algorithms package.
        iterateMonitor : iterate monitor instance
        accelerator : accelerator instance

        Raises
        ------
        ValueError
            If the experiment does not define `lagrange_mult_update`.
        """
        super(ADMM, self).__init__(experiment, iterateMonitor, accelerator)
        # instantiate Lagrange multiplier update object; its class is
        # looked up by name in the algorithms package
        if hasattr(experiment, 'lagrange_mult_update'):
            lagrange_mult_update_name = experiment.lagrange_mult_update
            lagrange_mult_update_class = getattr(algorithms, lagrange_mult_update_name)
            self.lagrange_mult_update = lagrange_mult_update_class(experiment)
        else:
            # BUG FIX: the original raised a bare string, which is itself a
            # TypeError in Python 3; raise a proper exception instead
            raise ValueError('Lagrange multiplier update not defined')

    def evaluate(self, u):
        """
        Update of the ADMM algorithm.

        Parameters
        ----------
        u : a 3-dimensional cell
            The current iterate. 3 blocks of variables, primal (x),
            auxiliary (y) and dual variables corresponding to Lagrange
            multipliers for the constraint Ax=y.

        Returns
        -------
        u_new : a 3-dimensional cell
            The new iterate (same type as input parameter `u`).
        """
        lmbda = self.computeRelaxation()
        # shallow copy is enough: each block is reassigned below
        u_new = u.copy()
        self.prox1.lmbda = lmbda
        u_new[0] = self.prox1.eval(u)      # primal update (prox of f)
        self.prox2.lmbda = lmbda
        u_new[1] = self.prox2.eval(u_new)  # auxiliary update (prox of g)
        # explicit ascent step on the Lagrange multipliers
        u_new[2] = self.lagrange_mult_update.eval(u_new)
        return u_new

    def getDescription(self):
        """Return a formatted description of this algorithm for display."""
        return self.getDescriptionHelper("\\lambda", self.lambda_0, self.lambda_max)
\ No newline at end of file
from proxtoolbox.utils.size import size_matlab
from proxtoolbox.algorithms.iterateMonitor import IterateMonitor
from proxtoolbox.utils.cell import Cell, isCell
import numpy as np
from numpy import zeros, angle, trace, exp, sqrt, sum, matmul, array, reshape
from numpy.linalg import norm
class ADMM_IterateMonitor(IterateMonitor):
    """
    Algorithm analyzer for monitoring iterates of
    projection algorithms for the ADMM algorithm.
    Specialization of the IterateMonitor class.
    """

    def __init__(self, experiment):
        super(ADMM_IterateMonitor, self).__init__(experiment)
        self.gaps = None        # objective ("gap") value per iteration
        self.shadows = None     # primal (shadow) sequence step size per iteration
        self.rel_errors = None  # relative error w.r.t. the known truth per iteration
        self.product_space_dimension = experiment.product_space_dimension
        if hasattr(experiment, 'norm_data'):
            # normalize statistics by the array size to make them
            # independent of the image resolution
            self.normM = np.sqrt(experiment.Nx*experiment.Ny)
        else:
            # won't normalize data to make it independent of array size
            self.normM = 1.0

    def preprocess(self, alg):
        """
        Allocate data structures needed to collect
        statistics. Called before running the algorithm.

        Parameters
        ----------
        alg : instance of Algorithm class
            Algorithm to be monitored.
        """
        super(ADMM_IterateMonitor, self).preprocess(alg)
        # In ADMM, the iterate u is a cell of blocks of variables.
        # In the analysis of Aspelmeier, Charitha & Luke, SIAM J. Imaging
        # Sciences, 2016, only the step difference on the "dual"
        # blocks of variables is monitored.
        self.u_monitor = self.u0.copy()
        # set up diagnostic arrays; 999 / sqrt(999) are sentinels meaning
        # "no statistic available yet" for iteration 0
        if self.diagnostic:
            self.gaps = self.changes.copy()
            self.gaps[0] = 999
            self.shadows = self.changes.copy()
            self.shadows[0] = sqrt(999)
            if self.truth is not None:
                self.rel_errors = self.changes.copy()
                self.rel_errors[0] = sqrt(999)

    def updateStatistics(self, alg):
        """
        Update statistics. Called at each iteration
        while the algorithm is running.

        Parameters
        ----------
        alg : instance of Algorithm class
            Algorithm being monitored.
        """
        # 'u' is a cell of three blocks of variables, primal, auxiliary and
        # dual, each of these containing ARRAYS. For ADMM, convergence of the
        # algorithm is determined only by the behavior of the 2nd and 3rd
        # blocks of variables (see Aspelmeier, Charitha & Luke, SIAM J. Imaging
        # Sciences, 2016). The first block of primal variables is monitored as
        # a "shadow sequence" in analogy with the shadow sequence of the
        # Douglas-Rachford iteration.
        u = alg.u_new
        prev_u = self.u_monitor
        tmp_change = 0
        normM = self.normM
        for j in range(1, 3):
            # the first cell is the block (assumed single) of
            # primal variables - assumed an array.
            # All other cells are the dual blocks, stored as
            # cells or arrays
            if isCell(u[j]):
                k = len(u[j])
                _m, _n, p, q = size_matlab(u[j])
            else:
                m, n, p, q = size_matlab(u[j])
                k = self.n_product_Prox
            for kk in range(k):
                if p == 1 and q == 1:
                    if isCell(u[j]):
                        # the next line was added on the hunch that a global
                        # phase rotation each iteration is making the algorithm
                        # look like it is converging more slowly than it really is
                        tmp_change += (norm(u[j][kk] - prev_u[j][kk])/normM)**2
                    else:  # dealing with 1D arrays on the product space
                        if n == k:
                            tmp_change += (norm(u[j][:, kk] - prev_u[j][:, kk])/normM)**2
                        elif m == k:
                            tmp_change += (norm(u[j][kk, :] - prev_u[j][kk, :])/normM)**2
                elif q == 1:
                    if isCell(u[j]):  # this means cells of cells of 3D arrays...I hope not!
                        for jj in range(p):
                            tmp_change += (norm(u[j][kk][:, :, jj] - prev_u[j][kk][:, :, jj])/normM)**2
                    else:
                        # we have a 3D array, the third dimension being the
                        # product space
                        tmp_change += (norm(u[j][:, :, kk] - prev_u[j][:, :, kk])/normM)**2
                else:  # cells of 4D arrays?!!!
                    # NOTE(review): this branch indexes u[kk] rather than
                    # u[j][kk], unlike every other branch — looks suspicious;
                    # confirm against the original Matlab before relying on it
                    for jj in range(q):
                        for pp in range(p):
                            tmp_change += (norm(u[kk][:, :, pp, jj] - prev_u[kk][:, :, pp, jj])/normM)**2
        self.changes[alg.iter] = sqrt(tmp_change)
        if self.diagnostic:
            tmp_shadow = 0
            rel_error = 0
            # the 'gap' slot holds the function values of the ADMM objective.
            # This is problem/implementation specific and will be found in the
            # drivers/*/ProxOperators subdirectories where the problem
            # families are defined.
            self.gaps[alg.iter] = self.calculateObjective(alg)
            # Next, we look at the progress of the primal sequence, or
            # 'shadows' of the ADMM algorithm. These are just the steps in
            # the primal sequence.
            _m, _n, p, q = size_matlab(u[0])
            if p == 1 and q == 1:
                tmp_shadow += (norm(u[0] - prev_u[0])/normM)**2
                if self.truth is not None:
                    # factor out the best global phase rotation before
                    # comparing with the truth (assumes a 2D array)
                    rel_error += norm(self.truth - exp(-1j*angle(trace(matmul(self.truth.T.conj(), u[0])))) * u[0]) / self.norm_truth
            elif q == 1:
                for jj in range(p):
                    tmp_shadow += (norm(u[0][:, :, jj] - prev_u[0][:, :, jj])/normM)**2
                    if self.truth is not None:
                        # BUG FIX: the original omitted the trailing
                        # "* u[0][:,:,jj]" factor, comparing the truth against
                        # a bare scalar phase factor (cf. the 2D branch above)
                        rel_error += norm(self.truth - exp(-1j*angle(trace(matmul(self.truth.T.conj(), u[0][:, :, jj])))) * u[0][:, :, jj]) / self.norm_truth
            else:  # cells of 4D arrays?!!!
                for jj in range(q):
                    for k in range(p):
                        tmp_shadow += (norm(u[0][:, :, k, jj] - prev_u[0][:, :, k, jj])/normM)**2
                        if self.truth is not None:
                            # BUG FIX: same missing factor as in the branch above
                            rel_error += norm(self.truth - exp(-1j*angle(trace(matmul(self.truth.T.conj(), u[0][:, :, k, jj])))) * u[0][:, :, k, jj]) / self.norm_truth
            # end shadow sequence monitor
            # The following is the Euclidean norm of the gap to
            # the unregularized set. To monitor the Euclidean norm of the gap
            # to the regularized set is expensive to calculate, so we use this
            # surrogate. Since the stopping criteria is on the change in the
            # iterates, this does not matter.
            self.shadows[alg.iter] = sqrt(tmp_shadow)
            if self.truth is not None:
                self.rel_errors[alg.iter] = rel_error
        self.u_monitor = u.copy()

    def postprocess(self, alg, output):
        """
        Called after the algorithm stops. Store statistics in
        the given 'output' dictionary.

        Parameters
        ----------
        alg : instance of Algorithm class
            Algorithm that was monitored.
        output : dictionary
            Contains the last iterate and various statistics that
            were collected while the algorithm was running.

        Returns
        -------
        output : dictionary into which the following entries are added
            (when diagnostics are required)
            gaps : ndarray
                ADMM objective (gap) values per iteration
            shadows : ndarray
                Step sizes of the primal (shadow) sequence per iteration
            rel_errors : ndarray
                Relative errors w.r.t. the truth (only when a truth is given)
        """
        output = super(ADMM_IterateMonitor, self).postprocess(alg, output)
        if self.diagnostic:
            stats = output['stats']
            stats['gaps'] = self.gaps[1:alg.iter+1]
            # BUG FIX: the original stored self.gaps under 'shadows' as well
            # (copy-paste error); store the shadow statistics instead
            stats['shadows'] = self.shadows[1:alg.iter+1]
            if self.truth is not None:
                stats['rel_errors'] = self.rel_errors[1:alg.iter+1]
        return output
\ No newline at end of file
from proxtoolbox.proxoperators.proxoperator import ProxOperator
from proxtoolbox.proxoperators.ADMM_prox import ADMM_Context
from proxtoolbox.utils.cell import Cell, isCell
from proxtoolbox.utils.size import size_matlab
import numpy as np
from numpy import pi, zeros, conj
from numpy.fft import fft2, ifft2, fft, ifft
import copy
class ADMM_PhaseLagrangeUpdate(ADMM_Context):
    """
    Explicit step with respect to the Lagrange multipliers
    in the ADMM algorithm.

    Based on Matlab code written by Russell Luke (Inst. Fuer
    Numerische und Angewandte Mathematik, Universitaet
    Gottingen) on Oct 1, 2017.
    """

    def __init__(self, experiment):
        super(ADMM_PhaseLagrangeUpdate, self).__init__(experiment)

    def eval(self, u):
        """
        Explicit step with respect to the Lagrange multipliers in the
        ADMM algorithm: v[j] += F_j(u[0]) - u[1][j] for each
        measurement block j.

        Parameters
        ----------
        u : cell
            Input data: u[0] primal, u[1] auxiliary, u[2] dual
            (Lagrange multiplier) variables.

        Returns
        -------
        v : cell
            The updated Lagrange multipliers.
        """
        v = copy.deepcopy(u[2])  # deep copy: multiplier arrays are updated in place below
        m1, n1, p1, _q1 = size_matlab(u[0])
        m2, n2, _p2, _q2 = size_matlab(u[1])
        k2 = np.amax(u[2].shape)
        if m1 > 1 and n1 > 1 and p1 == 1:
            # 2D signals: use 2D transforms
            FFT = lambda x: fft2(x)
            IFFT = lambda x: ifft2(x)
            for j in range(k2):
                if self.farfield:
                    if self.fresnel_nr is not None and self.fresnel_nr[j] > 0:
                        u_hat = -1j*self.fresnel_nr[j]/(self.Nx*self.Ny*2*pi)*FFT(u[0]-self.illumination[j]) + self.FT_conv_kernel[j]
                    elif self.FT_conv_kernel is not None:
                        u_hat = FFT(self.FT_conv_kernel[j]*u[0]) / (m1*n1)
                    else:
                        u_hat = FFT(u[0]) / (m1*n1)
                else:  # near field
                    # BUG FIX: the original transformed the whole cell `u`
                    # instead of the primal block u[0] (the far-field branch
                    # consistently uses u[0])
                    if self.beam is not None:
                        u_hat = IFFT(self.FT_conv_kernel[j]*FFT(u[0]*self.beam[j]))/self.magn[j]
                    else:
                        u_hat = IFFT(self.FT_conv_kernel[j]*FFT(u[0]))/self.magn[j]
                tmp = u_hat - u[1][j]
                v[j] += tmp
        elif m1 == 1:
            # 1D row signals: use 1D transforms
            FFT = lambda x: fft(x)
            IFFT = lambda x: ifft(x)
            for j in range(m2):
                if self.farfield:
                    if self.fresnel_nr is not None and self.fresnel_nr[j] > 0:
                        raise NotImplementedError('Error ADMM_PhaseLagrangeUpdate: complicated far field set-ups for 1D signals not implemented')
                    elif self.FT_conv_kernel is not None:
                        u_hat = FFT(self.FT_conv_kernel[j]*u[0]) / n1
                    else:
                        u_hat = FFT(u[0]) / n1
                else:
                    raise NotImplementedError('Error ADMM_PhaseLagrangeUpdate: near field set-ups for 1D signals not implemented')
                tmp = u_hat - u[1][j, :]
                v[j, :] += tmp
        elif n1 == 1:
            # 1D column signals: use 1D transforms
            FFT = lambda x: fft(x)
            IFFT = lambda x: ifft(x)
            for j in range(n2):
                if self.farfield:
                    if self.fresnel_nr is not None and self.fresnel_nr[j] > 0:
                        raise NotImplementedError('Error ADMM_PhaseLagrangeUpdate: complicated far field set-ups for 1D signals not implemented')
                    elif self.FT_conv_kernel is not None:
                        u_hat = FFT(self.FT_conv_kernel[j]*u[0]) / m1
                    else:
                        u_hat = FFT(u[0]) / m1
                else:
                    raise NotImplementedError('Error ADMM_PhaseLagrangeUpdate: near field set-ups for 1D signals not implemented')
                tmp = u_hat - u[1][:, j]
                v[:, j] += tmp
        return v
from proxtoolbox.proxoperators.proxoperator import ProxOperator
from proxtoolbox.proxoperators.ADMM_prox import ADMM_Context
from proxtoolbox.utils.cell import Cell, isCell
from proxtoolbox.utils.size import size_matlab
import numpy as np
from numpy import pi, zeros, conj
from numpy.fft import fft2, ifft2, fft, ifft
from numpy.linalg import norm
import copy
class ADMM_phase_indicator_objective(ADMM_Context):
    """
    Augmented Lagrangian for the ADMM algorithm applied to the
    phase retrieval problem with an indicator function for the primal
    objective:

        Lagrangian(u[0], u[1], u[2]) = iota_0(u[0])
            + sum_{j=1}^m iota_j(u[1][j]) + <u[2][j], F_j(u[0]) - u[1][j]>
            + 1/2 ||F_j(u[0]) - u[1][j]||^2

    We assume that the point u is feasible, so the indicator functions
    will be zero and all that need be computed is:

        Lagrangian(u[0], u[1], u[2]) = sum_{j=1}^m <u[2][j],
            F_j(u[0]) - u[1][j]> + 1/2 ||F_j(u[0]) - u[1][j]||^2

    Based on Matlab code written by Russell Luke (Inst. Fuer
    Numerische und Angewandte Mathematik, Universitaet
    Gottingen) on Oct 3, 2017.
    """

    def __init__(self, experiment):
        super(ADMM_phase_indicator_objective, self).__init__(experiment)
        if hasattr(experiment, 'norm_data'):
            self.normM = experiment.norm_data
        else:
            self.normM = 1.0

    def calculateObjective(self, alg):
        """
        Evaluate ADMM objective function.

        Parameters
        ----------
        alg : algorithm instance
            The algorithm that is running.

        Returns
        -------
        lagrangian : real
            The value of the objective function
            which will in this case measure the gap.
        """
        # The implementation of this function is in many ways similar to
        # the ADMM_PhaseLagrangeUpdate code. The structure is similar
        # even if the calculation is different. It would be nice to
        # use one common implementation.
        lagrangian = 0
        u = alg.u_new
        m1, n1, p1, _q1 = size_matlab(u[0])
        m2, n2, _p2, _q2 = size_matlab(u[1])
        k2 = self.product_space_dimension
        eta = self.lmbda
        if m1 > 1 and n1 > 1 and p1 == 1:
            # 2D signals: use 2D transforms
            FFT = lambda x: fft2(x)
            IFFT = lambda x: ifft2(x)
            for j in range(k2):
                if self.farfield:
                    if self.fresnel_nr is not None and self.fresnel_nr[j] > 0:
                        u_hat = -1j*self.fresnel_nr[j]/(self.Nx*self.Ny*2*pi)*FFT(u[0]-self.illumination[j]) + self.FT_conv_kernel[j]
                    elif self.FT_conv_kernel is not None:
                        u_hat = FFT(self.FT_conv_kernel[j]*u[0]) / (m1*n1)
                    else:
                        u_hat = FFT(u[0]) / (m1*n1)
                else:  # near field
                    # BUG FIX: the original transformed the whole cell `u`
                    # instead of the primal block u[0] (the far-field branch
                    # consistently uses u[0])
                    if self.beam is not None:
                        u_hat = IFFT(self.FT_conv_kernel[j]*FFT(u[0]*self.beam[j]))/self.magn[j]
                    else:
                        u_hat = IFFT(self.FT_conv_kernel[j]*FFT(u[0]))/self.magn[j]
                tmp = u_hat - u[1][j]
                # real part of <u[2][j], tmp> plus the quadratic penalty
                lagrangian += np.real(np.trace(u[2][j].T.conj() @ tmp)) / self.normM**2
                lagrangian += 0.5 * eta * (norm(tmp)/self.normM)**2
        elif m1 == 1:  # 1D signals (row layout)
            FFT = lambda x: fft(x)
            IFFT = lambda x: ifft(x)
            for j in range(m2):  # in Matlab this is k2, is it correct?
                if self.farfield:
                    if self.fresnel_nr is not None and self.fresnel_nr[j] > 0:
                        raise NotImplementedError('Error ADMM_phase_indicator_objective: complicated far field set-ups for 1D signals not implemented')
                    elif self.FT_conv_kernel is not None:
                        u_hat = FFT(self.FT_conv_kernel[j]*u[0]) / n1
                    else:
                        u_hat = FFT(u[0]) / n1
                else:
                    raise NotImplementedError('Error ADMM_phase_indicator_objective: near field set-ups for 1D signals not implemented')
                if isCell(u[1]):
                    tmp = u_hat - u[1][j]
                    # BUG FIX: np.real added for consistency with the 2D
                    # branch (the docstring promises a real return value)
                    lagrangian += np.real(u[2][j].T.conj() @ tmp) / self.normM**2
                else:
                    tmp = u_hat - u[1][j, :]
                    lagrangian += np.real(u[2][j, :].T.conj() @ tmp) / self.normM**2
                # NOTE(review): unlike the 2D branch, the penalty term here is
                # not scaled by eta — confirm against the Matlab original
                lagrangian += 0.5 * (norm(tmp)/self.normM)**2
        elif n1 == 1:  # 1D signals (column layout)
            FFT = lambda x: fft(x)
            IFFT = lambda x: ifft(x)
            for j in range(n2):  # in Matlab this is k2, is it correct?
                if self.farfield:
                    if self.FT_conv_kernel is not None:
                        u_hat = FFT(self.FT_conv_kernel[j]*u[0]) / m1
                    else:
                        u_hat = FFT(u[0]) / m1
                else:
                    raise NotImplementedError('Error ADMM_phase_indicator_objective: near field set-ups for 1D signals not implemented')
                if isCell(u[1]):
                    tmp = u_hat - u[1][j]
                    # BUG FIX: np.real added for consistency with the 2D branch
                    lagrangian += np.real(u[2][j].T.conj() @ tmp) / self.normM**2
                else:
                    tmp = u_hat - u[1][j, :]
                    lagrangian += np.real(u[2][j, :].T.conj() @ tmp) / self.normM**2
                # NOTE(review): eta factor missing here as well — confirm
                lagrangian += 0.5 * (norm(tmp)/self.normM)**2
        else:
            print('Error ADMM_phase_indicator_objective: Not designed to handle 4D arrays - throwing a dummy value 999')
            lagrangian = 999
        return lagrangian
......@@ -21,22 +21,8 @@ from .iterateMonitor import *
from .feasibilityIterateMonitor import *
from .PHeBIE_IterateMonitor import*
from .CT_IterateMonitor import*
'''
from .AP_expert import *
# from .PALM import *
from .RAAR_expert import *
from .GRAAL import *
# from .HAAR import *
from .ADMM import *
from .RRR import *
from .CAARL import *
from .CDRlrand import *
from .CPrand import *
from .DRl import *
from .GRAAL_F import *
from .GRAAL_objective import *
from .KM import *
from .Wirtinger import *
'''
#__all__ = ["AP","HPR","RAAR", "AP_expert", "GRAAL", "RAAR_expert", "DRl", "ADMM", "RRR", "CAARL", "CADRl", "CDRl", "CP", "CPrand", "DRAP", "DRl", "GRAAL_F", "GRAAL_objective", "KM"]
from .ADMM_IterateMonitor import *
from .ADMM_PhaseLagrangeUpdate import *
from .ADMM_phase_indicator_objective import *
from proxtoolbox import algorithms
from proxtoolbox.utils.cell import Cell, isCell
from proxtoolbox.utils.size import size_matlab
from numpy import zeros, angle, trace, exp, sqrt, sum
......@@ -24,6 +25,14 @@ class IterateMonitor:
self.rotate = experiment.rotate
self.n_product_Prox = experiment.n_product_Prox
self.changes = None
# instantiate optimality monitor if it exists
self.optimality_monitor = None
if hasattr(experiment, 'optimality_monitor'):
optimality_monitor_name = experiment.optimality_monitor
optimality_monitor_class = getattr(algorithms, optimality_monitor_name)
self.optimality_monitor = optimality_monitor_class(experiment)
def preprocess(self, alg):
"""
......@@ -185,6 +194,26 @@ class IterateMonitor:
for k in range(p):
tmp_change += (norm(u[:,:,k,j] - u_new[:,:,k,j])/self.norm_data)**2
return tmp_change
def calculateObjective(self, alg):
    """
    Calculate objective value. The default implementation
    delegates to the optimality monitor if it exists.

    Parameters
    ----------
    alg : instance of Algorithm class
        Algorithm that was monitored.

    Returns
    -------
    objValue : real
        Objective value.

    Raises
    ------
    ValueError
        If no optimality monitor was provided by the experiment.
    """
    if self.optimality_monitor is not None:
        return self.optimality_monitor.calculateObjective(alg)
    else:
        # BUG FIX: the original raised a bare string, which is itself a
        # TypeError in Python 3; raise a proper exception instead
        raise ValueError("optimality_monitor was not provided")
......@@ -223,6 +223,8 @@ class Experiment(metaclass=ExperimentMetaClass):
self.loadData()
self.reshapeData(self.Nx, self.Ny, self.Nz) # TODO need to revisit this
# - arguments not needed in particular
if self.TOL2 is None:
self.TOL2 = 1e-20
# define the prox operators to be used for this experiment (list their names)
# default implementation chooses prox operators based on constraint
......@@ -231,9 +233,6 @@ class Experiment(metaclass=ExperimentMetaClass):
# retrieve the classes corresponding to the prox operators' names given above
self.retrieveProxOperatorClasses()
if self.TOL2 is None:
self.TOL2 = 1e-20