Commit 0819a3fd authored by s.gretchko's avatar s.gretchko
Browse files

Added AvP2 algorithm, iterate monitor & JWST demo

parent 293c19fb
import SetProxPythonPath
from proxtoolbox.experiments.phase.JWST_Experiment import JWST_Experiment
JWST = JWST_Experiment(algorithm='ADMM', lambda_0=3.0, lambda_max=3.0,
MAXIT=300, noise = True, rotate = True)
JWST.run()
JWST.show()
import SetProxPythonPath
from proxtoolbox.experiments.phase.JWST_Experiment import JWST_Experiment
JWST = JWST_Experiment(algorithm='AvP2', lambda_0=0.75, lambda_max=0.75)
JWST.run()
JWST.show()
import SetProxPythonPath
from proxtoolbox.experiments.phase.JWST_Experiment import JWST_Experiment
JWST = JWST_Experiment(algorithm='AvP2', lambda_0=0.75, lambda_max=0.75, noise=True)
JWST.run()
JWST.show()
from proxtoolbox.algorithms.algorithm import Algorithm
from proxtoolbox import proxoperators
from proxtoolbox.utils.cell import Cell, isCell
from proxtoolbox.utils.size import size_matlab
import numpy as np
import copy
from numpy import exp
class AvP2(Algorithm):
"""
Alternating prox with modification as described
in Luke,Sabach&Teboulle "Optimization on Spheres:
Models and Proximal Algorithms
with Computational Performance Comparisons"
D. R. Luke, S. Sabach and M. Teboulle. 2019.
This is an averaged projections algorithm with a
2-step recursion (x^{k-1}, x^{k}, u^{k})
Based on Matlab code written by Russell Luke (Inst. Fuer
Numerische und Angewandte Mathematik, Universitaet
Gottingen) on September 08, 2016.
"""
def __init__(self, experiment, iterateMonitor, accelerator):
super(AvP2, self).__init__(experiment, iterateMonitor, accelerator)
if not hasattr(experiment, 'shift_data'):
if isCell(experiment.u0[2]):
length = len(experiment.u0[2])
self.shift_data = Cell(length)
for j in range(length):
self.shift_data[j] = 0
else:
self.shift_data = 0
def evaluate(self, u):
"""
Update of the AvP2 algorithm
Parameters
----------
u : ndarray or a cell of ndarray objects
The current iterate.
Returns
-------
u_new : ndarray or a cell of ndarray objects
The new iterate (same type as input parameter `u`).
"""
lmbda = self.computeRelaxation()
u_new = Cell(len(u))
if isCell(u[2]):
length = len(u[2])
u2_tmp = Cell(length)
for j in range(length):
u2_tmp[j] = u[2][j] + self.shift_data[j] + (u[1]-u[0])*lmbda
u_new2 = self.prox1.eval(u2_tmp)
u_new[1] = u_new2[0]
for j in range(length):
u2_tmp[j] = u[2][j] + (2*u[1]-u[0]-self.shift_data[j])*lmbda
u_new[2] = self.prox2.eval(u2_tmp)
else:
mm, nn, pp, _qq = size_matlab(u[2])
if pp == 1:
if mm == self.product_space_dimension:
u2_tmp = u[2] + self.shift_data + np.ones(mm)*(u[1]-u[0])*lmbda
u_new2 = self.prox1.eval(u2_tmp)
u_new[1] = u_new2[0,:]
u2_tmp = u[2] + (np.ones(mm)*(2*u[1]-u[0]) - self.shift_data[j])*lmbda
u_new[2] = self.prox2.eval(u2_tmp)
else:
u2_tmp = u[2] + self.shift_data + np.ones(nn)*(u[1]-u[0])*lmbda
u_new2 = self.prox1.eval(u2_tmp)
u_new[1] = u_new2[:,0]
u2_tmp = u[2] + (np.ones(nn)*(2*u[1]-u[0]) - self.shift_data[j])*lmbda
u_new[2] = self.prox2.eval(u2_tmp)
else:
# below is a hack: the shift data is assumed a scalar, which
# only makes sense if this case is a phase problem.
for j in range(pp):
u2_tmp = u[2][:,:,j] + self.shift_data + (u[1]-u[0])*lmbda
u_new2 = self.prox1(u2_tmp)
u_new[1] = u_new2[:,:,0]
for j in range(pp):
u2_tmp[:,:,j] = u[2][:,:,j] + (2*u[1]-u[0]-self.shift_data)*lmbda
u_new[2] = self.prox2(u2_tmp)
u_new[0] = copy.deepcopy(u[1]) # deep copy is needed since u[1] may be a cell
return u_new
from proxtoolbox.algorithms.iterateMonitor import IterateMonitor
from proxtoolbox.utils.cell import Cell, isCell
from proxtoolbox.utils.size import size_matlab
import numpy as np
from numpy import zeros, angle, trace, exp, sqrt, sum, matmul, array, reshape
from numpy.linalg import norm
class AvP2_IterateMonitor(IterateMonitor):
"""
Algorithm analyzer for monitoring iterates of
projection algorithms for the PHeBIE algorithm.
Specialization of the IterateMonitor class.
"""
def __init__(self, experiment):
super(AvP2_IterateMonitor, self).__init__(experiment)
self.gaps = None
self.rel_errors = None
self.product_space_dimension = experiment.product_space_dimension
self.Nz = experiment.Nz
name = experiment.experiment_name
self.useExpForRelError = (name == 'CDP' or name == 'JWST' or name == 'Phase')
def preprocess(self, alg):
"""
Allocate data structures needed to collect
statistics. Called before running the algorithm.
Parameters
----------
alg : instance of Algorithm class
Algorithm to be monitored.
"""
super(AvP2_IterateMonitor, self).preprocess(alg)
# In PHeBIE, the iterate u is a cell of blocks of variables.
# In the analysis of Hesse, Luke, Sabach&Tam, SIAM J. Imaging Sciences, 2015
# all blocks are monitored for convergence.
self.u_monitor = self.u0[1].copy()
# set up diagnostic arrays
if self.diagnostic:
self.gaps = self.changes.copy()
self.gaps[0] = 999
if self.truth is not None:
self.rel_errors = self.changes.copy()
self.rel_errors[0] = sqrt(999)
def updateStatistics(self, alg):
"""
Update statistics. Called at each iteration
while the algorithm is running.
Parameters
----------
alg : instance of Algorithm class
Algorithm being monitored.
"""
u = alg.u_new
prev_u1 = self.u_monitor
normM = self.norm_data
m, n, p, q = size_matlab(u[1])
tmp_change = 0
if p == 1 and q == 1:
tmp_change += (norm(u[1] - prev_u1)/normM)**2
elif q == 1:
for jj in range(self.product_space_dimension):
tmp_change += (norm(u[1][:,:,jj] - prev_u1)/normM)**2
else: # cells of 4D arrays?!!!
for jj in range(self.product_space_dimension):
for k in range(self.Nz):
tmp_change += (norm(u[1][:,:,k,jj] - prev_u1)/normM)**2
self.changes[alg.iter] = sqrt(tmp_change)
if self.diagnostic:
self.gaps[alg.iter] = self.calculateObjective(alg)
if self.truth is not None:
mm, nn, _pp, _qq = size_matlab(u[2])
# the main_ProxToolbox
# looks for the gap. We use this to hold the function values of the
# objective. Here we assume that the underlying objective (for
# experiment, NSLS) only depends on the current x-block, u{2}, in the APplus
# algorithm.
if not isCell(u[2]):
m = mm
n = nn
if m == self.product_space_dimension:
u_tmp = u[1][0,:]
elif n == self.product_space_dimension:
u_tmp = u[1][:,0]
else:
if p == 1:
u_tmp = u[1]
else:
u_tmp = u[1][:,:,p-1]
if self.useExpForRelError:
rel_error = norm(self.truth - exp(-1j*angle(trace(matmul(self.truth.T.conj(), u_tmp)))) * u_tmp) / self.norm_truth
else:
rel_error = norm(self.truth - u_tmp)
self.rel_errors[alg.iter] = rel_error
self.u_monitor = u[1].copy()
def calculateObjective(self, alg):
# TODO Matlab code:
# There is an objective function behind this method. This has been
# implemented for source localization, but not yet for phaes retrieval
# if strcmp(method_input.problem_family, 'Source_localization')
# method_input.gap(iter) = feval('NSLS_sourceloc_objective', u{3}, method_input);
# else
#
# method_input.gap(iter) = 999; % just a dummy value.
# for now use dummy value
objValue = 999
return objValue
def postprocess(self, alg, output):
"""
Called after the algorithm stops. Store statistics in
the given 'output' dictionary
Parameters
----------
alg : instance of Algorithm class
Algorithm that was monitored.
output : dictionary
Contains the last iterate and various statistics that
were collected while the algorithm was running.
Returns
-------
output : dictionary into which the following entries are added
(when diagnostics are required)
gaps : ndarray
Squared gap distance normalized by the magnitude
constraint
"""
output = super(AvP2_IterateMonitor, self).postprocess(alg, output)
if self.diagnostic:
stats = output['stats']
stats['gaps'] = self.gaps[1:alg.iter+1]
if self.truth is not None:
stats['rel_errors'] = self.rel_errors[1:alg.iter+1]
return output
\ No newline at end of file
......@@ -25,4 +25,6 @@ from .ADMM import *
from .ADMM_IterateMonitor import *
from .ADMM_PhaseLagrangeUpdate import *
from .ADMM_phase_indicator_objective import *
from .AvP2 import *
from .AvP2_IterateMonitor import *
......@@ -209,33 +209,34 @@ class JWST_Experiment(PhaseExperiment):
self.propagator = 'Propagator_FreFra'
self.inverse_propagator = 'InvPropagator_FreFra'
# This may depend on which algorithm is used
if self.algorithm_name == 'ADMM':
#----------------------------
# This block of code is independent of experiment: could move to PhaseEperiment::setupProxOperators
# There are two main prox mappings, one for each of the first two
# blocks. These prox mappings may be further decomposed, but that
# is not handled at this level.
self.proxOperators = []
self.productProxOperators = []
if self.formulation == 'cyclic':
# there are as many prox operators as there are sets
self.nProx = self.sets
self.product_space_dimension = 1
for _j in range(self.nProx-1):
self.proxOperators.append('Approx_Pphase_FreFra_Poisson')
self.proxOperators.append('P_amp_support')
else: # product space formulation
# add prox operators
self.nProx = 2
# Set up Prox0, prox operator used in the primal prox block
if not hasattr(self, 'Prox0'):
self.Prox0 = self.proxOperators[0] # given by previous call to parent
self.iterate_monitor_name = 'ADMM_IterateMonitor'
#if not hasattr(self, 'optimality_monitor'):
self.optimality_monitor = 'ADMM_phase_indicator_objective'
self.lagrange_mult_update = 'ADMM_PhaseLagrangeUpdate'
#-------------------------------
# JWST specific code
self.proxOperators.append('P_diag')
self.proxOperators.append('Prox_product_space')
# add product prox operators
self.n_product_Prox = self.product_space_dimension
for _j in range(self.n_product_Prox-1):
self.productProxOperators.append('Approx_Pphase_FreFra_Poisson')
self.productProxOperators.append('P_amp_support')
# Now we adjust data structures and prox operators according to the algorithm
if self.algorithm_name == 'ADMM':
# set-up variables in cells for block-wise algorithms
# ADMM has three blocks of variables, primal, auxilliary (primal
# variables in the image space of some linear mapping) and dual.
# we sort these into a 3D cell.
u0 = self.u0
self.u0 = Cell(3)
self.proxOperators = []
self.productProxOperators = []
K = self.product_space_dimension - 1
......@@ -251,29 +252,18 @@ class JWST_Experiment(PhaseExperiment):
self.u0[2][j] = self.u0[1][j] / self.lambda_0
self.proxOperators.append('Prox_primal_ADMM_indicator')
self.proxOperators.append('Approx_Pphase_FreFra_ADMM_Poisson')
else:
# general case
self.proxOperators = []
self.productProxOperators = []
if self.formulation == 'cyclic':
# there are as many prox operators as there are sets
self.nProx = self.sets
self.product_space_dimension = 1
for _j in range(self.nProx-1):
self.proxOperators.append('Approx_Pphase_FreFra_Poisson')
self.proxOperators.append('P_amp_support')
else: # product space formulation
# add prox operators
self.nProx = 2
self.proxOperators.append('P_diag')
self.proxOperators.append('Prox_product_space')
# add product prox operators
self.n_product_Prox = self.product_space_dimension
for _j in range(self.n_product_Prox-1):
self.productProxOperators.append('Approx_Pphase_FreFra_Poisson')
self.productProxOperators.append('P_amp_support')
elif self.algorithm_name == 'AvP2':
# u_0 should be a cell...we change it into a cell of cells
u0 = self.u0
self.u0 = Cell(3)
prox2_class = getattr(proxoperators, self.proxOperators[1])
prox2 = prox2_class(self)
self.u0[2] = prox2.eval(u0)
P_Diag_class = getattr(proxoperators, 'P_diag')
P_Diag_prox = P_Diag_class(self)
tmp_u = P_Diag_prox.eval(self.u0[2])
self.u0[1] = tmp_u[0]
self.u0[0] = u0[self.product_space_dimension-1]
def show(self):
"""
......
......@@ -61,6 +61,37 @@ class PhaseExperiment(Experiment):
# the prox algorithms work with the square root of the measurement:
self.data = self.data.reshape(Nx, Ny)
def setupProxOperators(self):
"""
Determine the prox operators to be used for this experiment
"""
super(PhaseExperiment, self).setupProxOperators() # call parent's method
# Algorithm specific initialization
# The following code should be independent from the actual experiment
if self.algorithm_name == 'ADMM':
# There are two main prox mappings, one for each of the first two
# blocks. These prox mappings may be further decomposed, but that
# is not handled at this level.
# Set up Prox0, prox operator used in the primal prox block
if not hasattr(self, 'Prox0'):
self.Prox0 = self.proxOperators[0] # given by previous call to parent
self.iterate_monitor_name = 'ADMM_IterateMonitor'
#if not hasattr(self, 'optimality_monitor'):
self.optimality_monitor = 'ADMM_phase_indicator_objective'
self.lagrange_mult_update = 'ADMM_PhaseLagrangeUpdate'
elif self.algorithm_name == 'AvP2':
# There are two main prox mappings, one for each of the first two
# blocks. These prox mappings may be further decomposed, but that
# is not handled at this level.
# Set up Prox0, prox operator used in the primal prox block
if not hasattr(self, 'Prox0'):
self.Prox0 = self.proxOperators[0] # given by previous call to parent
self.iterate_monitor_name = 'AvP2_IterateMonitor'
def show(self):
"""
Generate graphical output from the solution
......
......@@ -2,6 +2,7 @@
import numpy as np
from numpy import zeros
from proxtoolbox.proxoperators.proxoperator import ProxOperator
from proxtoolbox import proxoperators
from proxtoolbox.utils.cell import Cell, isCell
from proxtoolbox.utils.size import size_matlab
......@@ -18,10 +19,13 @@ class Prox_product_space(ProxOperator):
def __init__(self, experiment):
# instantiate product prox operators
self.proxOps = []
for proxClass in experiment.productProxOperators:
if proxClass != '':
prox = proxClass(experiment)
self.proxOps.append(prox)
for prox_item in experiment.productProxOperators:
if isinstance(prox_item, str):
proxClass = getattr(proxoperators, prox_item)
else:
proxClass = prox_item
prox = proxClass(experiment)
self.proxOps.append(prox)
def eval(self, u, prox_index = None):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment