Commit 46200831 authored by s.gretchko's avatar s.gretchko
Browse files

Added Ptychography algorithm, prox operators & iterate monitor

parent f6c795e7
......@@ -2,6 +2,6 @@
import SetProxPythonPath
from proxtoolbox.experiments.ptychography.ptychographyExperiment import PtychographyExperiment
Pty = PtychographyExperiment(debug = False)
Pty = PtychographyExperiment(debug = False, warmup = False)
Pty.run()
Pty.show()
from proxtoolbox.algorithms.algorithm import Algorithm
from proxtoolbox.Utilities.cell import Cell, isCell
from numpy import ones
class PHeBIE(Algorithm):
"""
Proximal Heterogenious Block Implicit-Explicit (PHeBIE)
minimzation algorithm as proposed in the paper
"Proximal Heterogeneous Block Implicit-Explicit Method and
Application to Blind Ptychographic Diffraction Imaging",
R. Hesse, D. R. Luke, S. Sabach and M. K. Tam,
SIAM J. on Imaging Sciences, 8(1):426--457 (2015).
Based on Matlab code written by Russell Luke (Inst. Fuer
Numerische und Angewandte Mathematik, Universitaet
Gottingen) on Aug. 31, 2017.
"""
def __init__(self, experiment, iterateMonitor, accelerator):
super(PHeBIE, self).__init__(experiment, iterateMonitor, accelerator)
# Sets the PHeBIE block stepsizes
if hasattr(experiment, 'beta'):
self.beta = experiment.beta
else:
self.beta = ones(experiment.nProx-1) # TODO check if correct
# the stepsize parameter gamma is only needed for the analysis.
#For the implmentation, this can be taken to be machine zero.
if hasattr(experiment, 'gamma'):
self.gamma = experiment.gamma
else:
self.gamma = 1e-30
# instantiate explicit prox operators
self.explicit = []
for proxClass in experiment.explicit:
if proxClass != '':
prox = proxClass(experiment)
self.explicit.append(prox)
def evaluate(self, u):
"""
Update for the PHeBIE algorithm
Parameters
----------
u : a cell whose cell elements are the blocks
The current iterate.
Returns
-------
u_new : type ???
The new iterate.
"""
u_new = u
# Loop through the block-wise implicit-explicit update steps.
K = len(self.proxOperators)
for s in range(K):
u_new[s] = self.explicit[s].eval(u, s) # explicit prox evaluation
u_new[s] = self.proxOperators[s].eval(u_new[s], s) # implicit prox evauation
# apply first order acceleration if required
if self.accelerator is not None:
u_new = self.accelerator.evaluate(u_new, self)
return u_new
from proxtoolbox.algorithms.iterateMonitor import IterateMonitor
from proxtoolbox.Utilities.cell import Cell, isCell
import numpy as np
from numpy import zeros, angle, trace, exp, sqrt, sum, matmul, array, reshape
from numpy.linalg import norm
class PHeBIE_IterateMonitor(IterateMonitor):
"""
Algorithm analyzer for monitoring iterates of
projection algorithms for the PHeBIE algorithm.
Specialization of the IterateMonitor class.
"""
def __init__(self, experiment):
super(PHeBIE_IterateMonitor, self).__init__(experiment)
self.gaps = None
self.rel_errors = None
self.product_space_dimension = experiment.product_space_dimension
self.Nz = experiment.Nz
self.positions = experiment.positions
def preprocess(self, alg):
"""
Allocate data structures needed to collect
statistics. Called before running the algorithm.
Parameters
----------
alg : instance of Algorithm class
Algorithm to be monitored.
"""
super(PHeBIE_IterateMonitor, self).preprocess(alg)
# In PHeBIE, the iterate u is a cell of blocks of variables.
# In the analysis of Hesse, Luke, Sabach&Tam, SIAM J. Imaging Sciences, 2015
# all blocks are monitored for convergence.
self.u_monitor = self.u0
# set up diagnostic arrays
if self.diagnostic:
self.gaps = self.changes.copy()
self.gaps[0] = sqrt(999)
if self.truth is not None:
self.rel_errors = self.changes.copy()
self.rel_errors[0] = sqrt(999)
def updateStatistics(self, alg):
"""
Update statistics. Called at each iteration
while the algorithm is running.
Parameters
----------
alg : instance of Algorithm class
Algorithm being monitored.
"""
u = alg.u_new
prev_u = self.u_monitor
normM = self.norm_data
nProx = len(alg.proxOperators)
tmp_change = 0
for j in range(nProx):
if isCell(u[j]):
for jj in range(len(u[j])):
tmp_change += (norm(u[j][jj] - prev_u[j][jj], 'fro')/normM)**2
else:
p, q = self.getIterateSize(u[j])
if p == 1 and q == 1:
tmp_change += (norm(u[j] - prev_u[j], 'fro')/normM)**2
elif q == 1:
for jj in range(self.product_space_dimension):
tmp_change += (norm(u[j][:,:,jj] - prev_u[j][:,:,jj], 'fro')/normM)**2
else: # cells of 4D arrays?!!!
for jj in range(self.product_space_dimension):
for k in range(self.Nz):
tmp_change += (norm(u[j][:,:,k,jj] - prev_u[j][:,:,k,jj], 'fro')/normM)**2
self.changes[alg.iter] = sqrt(tmp_change)
if self.diagnostic:
self.gaps[alg.iter] = self.calculateObjective(alg)
if self.truth is not None:
rel_error = norm(self.truth - exp(-1j*angle(trace(matmul(self.truth.T.conj(), u[0])))) * u[0],'fro') / self.norm_truth
self.rel_errors[alg.iter] = rel_error
def calculateObjective(self, alg):
objValue = 0
u = alg.u_new
normM = self.norm_data
self.rangeNx = np.arange(alg.Nx, dtype=np.int)
self.rangeNy = np.arange(alg.Ny, dtype=np.int)
if isCell(u[2]):
for jj in range(len(u[2])):
indy = (self.positions[jj,0] + self.rangeNy).astype(int)
indx = (self.positions[jj,1] + self.rangeNx).astype(int)
objValue += (norm(u[0]*u[1][indy, indx] - u[2][jj], 'fro')/normM)**2
else:
p, _q = self.getIterateSize(u[2])
for jj in range(p):
indy = (self.positions[jj,0] + self.rangeNy).astype(int)
indx = (self.positions[jj,1] + self.rangeNx).astype(int)
objValue += (norm(u[0]*u[1][indy, indx] - u[2][:,:,jj], 'fro')/normM)**2
return objValue
def postprocess(self, alg, output):
"""
Called after the algorithm stops. Store statistics in
the given 'output' dictionary
Parameters
----------
alg : instance of Algorithm class
Algorithm that was monitored.
output : dictionary
Contains the last iterate and various statistics that
were collected while the algorithm was running.
Returns
-------
output : dictionary into which the following entries are added
(when diagnostics are required)
gaps : ndarray
Squared gap distance normalized by the magnitude
constraint
"""
output = super(PHeBIE_IterateMonitor, self).postprocess(alg, output)
if self.diagnostic:
stats = output['stats']
stats['gaps'] = self.gaps[1:alg.iter+1]
if self.truth is not None:
stats['rel_errors'] = self.rel_errors[1:alg.iter+1]
return output
\ No newline at end of file
......@@ -17,9 +17,10 @@ from .AvP import *
from .QNAvP import *
from .DRl import *
from .CDRl import *
from .PHeBIE import *
from .iterateMonitor import *
from .feasibilityIterateMonitor import *
from .PHeBIE_IterateMonitor import*
'''
from .AP_expert import *
......
#SG
# New Cone_and_Sphere code
from proxtoolbox.algorithms.iterateMonitor import IterateMonitor
from proxtoolbox.Utilities.cell import Cell, isCell
......@@ -95,7 +93,7 @@ class FeasibilityIterateMonitor(IterateMonitor):
# is not converging to something reasonable. We don't know how to characterize this
# yet (as of 2018) but a rotation of the iterates to a fixed global reference
# reveals the RELEVANT convergence of the (shadow sequence of the) algorithm:
if self.isBlock:
if isCell(u):
for j in range(len(u)):
self.u_monitor[0][j] = exp(-1j*angle(trace(matmul(prev_u_mon[0][j].T.conj(), u[j]))))*u[j]
else:
......
......@@ -50,6 +50,7 @@ class IterateMonitor:
alg : instance of Algorithm class
Algorithm being monitored.
"""
self.u_monitor = alg.u_new # store the last iterate in u_monitor
if not self.isCell:
tmp_change = self.evaluateChange(alg.u, alg.u_new)
else:
......@@ -88,7 +89,6 @@ class IterateMonitor:
"""
output['u_monitor'] = self.u_monitor
# TODO: This appears to not have been updated anywhere, meaning that this is the initial random guess
if self.diagnostic:
if isCell(self.u_monitor):
u_m = self.u_monitor[0]
......
......@@ -26,11 +26,12 @@ class ExperimentMetaClass(type):
"""
Called when any instance of Experiment class is created.
Does a few things:
- set the experiment parameters based on the given keyword arguments and the experiment defaults
(self.defaultParameters)
- create the class object by calling type
- call the function "self.initialize", which will initialize the experiment by loading data, setting up
algorithms and so on
- set the experiment parameters based on the given keyword
arguments and the default parameters given by the static method
getDefaultParameters()
- create the instance of the experiment class
- call the initialize() method on the newly created instance
(to load the data, set-up the prox operators and the algorithm).
"""
actualArgs = cls.getDefaultParameters() # retrieve default arguments (specified by derived class)
actualArgs.update(kwargs) # update these arguments based the on caller's arguments
......@@ -215,28 +216,15 @@ class Experiment(metaclass=ExperimentMetaClass):
# load data
self.loadData()
self.reshapeData(self.Nx, self.Ny, self.Nz)
self.reshapeData(self.Nx, self.Ny, self.Nz) # TODO need to revisit this
# - arguments not needed in particular
# define the prox operators to be used for this experiment (list their names)
# default implementation chooses prox operators based on constraint
self.setupProxOperators()
# find prox operators classes based on their name
proxOperatorClasses = []
for prox in self.proxOperators:
if prox != '':
proxOperatorClasses.append(getattr(proxoperators, prox))
self.proxOperators = proxOperatorClasses
# same with product prox operators
proxOperatorClasses = []
for prox in self.productProxOperators:
if prox != '':
proxOperatorClasses.append(getattr(proxoperators, prox))
self.productProxOperators = proxOperatorClasses
# and propagator and inverse propagator
self.propagator = getattr(proxoperators, self.propagator)
self.inverse_propagator = getattr(proxoperators, self.inverse_propagator)
# retrieve the classes corresponding to the prox operators' names given above
self.retrieveProxOperatorClasses()
if self.TOL2 is None:
self.TOL2 = 1e-20
......@@ -252,24 +240,16 @@ class Experiment(metaclass=ExperimentMetaClass):
# self.TOL2 = self.data_ball * 1e-15
# -----------
# instantiate iterate monitor and algorithm
algorithmClass = getattr(algorithms, self.algorithm_name)
iterateMonitorClass = getattr(algorithms, self.iterate_monitor_name)
iterateMonitor = iterateMonitorClass(self)
if self.accelerator_name is not None:
acceleratorClass = getattr(algorithms, self.accelerator_name)
accelerator = acceleratorClass(self)
else:
accelerator = None
self.algorithm = algorithmClass(self, iterateMonitor, accelerator)
# instantiate iterate monitor, accelerator and algorithm
self.instanciateAlgorithm()
self.animation = 'Phase_animation' # TODO difference with anim???
self.animation = None # TODO difference with anim???
def run(self):
"""
Run the algorithm associated with this experiment.
The initial iterate is obtained from the data processor.
The initial iterate is obtained from the loadData() method.
TODO: implement multiple runs
"""
print("Running " + self.algorithm_name + " on " + self.experiment_name + "...")
......@@ -291,7 +271,7 @@ class Experiment(metaclass=ExperimentMetaClass):
if self.verbose:
self.printBenchmark()
def loadData(self, *args, **kwargs):
def loadData(self):
"""
Load or generate the dataset that will be used for
this experiment. Create the initial iterate.
......@@ -330,6 +310,35 @@ class Experiment(metaclass=ExperimentMetaClass):
self.proxOperators.append(None)
self.proxOperators.append('Approx_PM_Poisson') # Patrick: This is just to monitor the change of phases!
def retrieveProxOperatorClasses(self):
"""
Retrieve the Python classes corresponding to the prox operators
defined by the setupProxOperators().
This is a helper method called during the initialization
process. It should not be overriden, unless the derived
experiment class uses an additional category of prox operators
not covered in the Experiment class.
"""
# find prox operators classes based on their name
# and replace names by actual classes
proxOperatorClasses = []
for prox in self.proxOperators:
if prox != '':
proxOperatorClasses.append(getattr(proxoperators, prox))
self.proxOperators = proxOperatorClasses
# same with product prox operators
proxOperatorClasses = []
for prox in self.productProxOperators:
if prox != '':
proxOperatorClasses.append(getattr(proxoperators, prox))
self.productProxOperators = proxOperatorClasses
# and propagator and inverse propagator
self.propagator = getattr(proxoperators, self.propagator)
self.inverse_propagator = getattr(proxoperators, self.inverse_propagator)
def reshapeData(self, Nx, Ny, Nz):
"""
Reshape data based on the given arguments. This method is called
......@@ -367,6 +376,26 @@ class Experiment(metaclass=ExperimentMetaClass):
# the prox algorithms work with the square root of the measurement:
self.data = self.data.reshape(Nx, Ny)
def instanciateAlgorithm(self):
"""
Instanciate algorithm and some objects used by the
algorithm, namely the iterate monitor and accelerator,
based on the names provided.
This is a helper method called during the initialization
process. It should not be overriden.
"""
algorithmClass = getattr(algorithms, self.algorithm_name)
iterateMonitorClass = getattr(algorithms, self.iterate_monitor_name)
iterateMonitor = iterateMonitorClass(self)
if self.accelerator_name is not None:
acceleratorClass = getattr(algorithms, self.accelerator_name)
accelerator = acceleratorClass(self)
else:
accelerator = None
self.algorithm = algorithmClass(self, iterateMonitor, accelerator)
def saveOutput(self):
# Create directory (if needed)
......@@ -520,8 +549,6 @@ class Experiment(metaclass=ExperimentMetaClass):
@staticmethod
def getDefaultParameters():
defaultParams = {
'data_processor_name': 'JWST_DataProcessor',
'data_processor_package': 'proxtoolbox.Problems_old.Phase',
'object': 'complex',
'constraint': 'amplitude only',
'distance': 'far field',
......
......@@ -2,6 +2,7 @@
# pylint: disable=no-member # for dynamically created variables
# pylint: disable=access-member-before-definition # for dynamically created variables
from proxtoolbox.experiments.experiment import Experiment
from proxtoolbox import proxoperators
from proxtoolbox.Utilities.cell import Cell, isCell
from proxtoolbox.Utilities.loadMatFile import loadMatFile
from proxtoolbox.experiments.ptychography.ptychographyUtils import circ, \
......@@ -25,6 +26,7 @@ import matplotlib.pyplot as plt
from matplotlib.pyplot import subplots, show, figure
import os.path
import copy
class PtychographyExperiment(Experiment):
......@@ -62,7 +64,7 @@ class PtychographyExperiment(Experiment):
'lambda_switch': 4,
'data_ball': 1e-30,
'diagnostic': True,
'iterate_monitor_name': 'FeasibilityIterateMonitor',
'iterate_monitor_name': 'PHeBIE_IterateMonitor',
'verbose': 1,
'graphics': 1,
'anim': 0
......@@ -89,7 +91,7 @@ class PtychographyExperiment(Experiment):
def __init__(self,
data_dir = '../InputData/Ptychography/',
datafile = 'data_NTT_01_26210_192x192',
farfield = False,
farfield = True,
poissonfactor = 5,
plot = True,
switch_probemask = True,
......@@ -148,6 +150,13 @@ class PtychographyExperiment(Experiment):
self.RodenburgInnerIt = RodenburgInnerIt
self.error_type = error_type
# internal data members
self.proxSequence = None
# TODO override iterate monitor and optimality_monitor based
# algorithm and/or prox configuration being used
def loadData(self):
"""
Load Ptychography dataset. Create the initial iterate.
......@@ -255,7 +264,9 @@ class PtychographyExperiment(Experiment):
# Generate cell array of normalized intensities
# Normalizeation such, that for I = I_max (as matrix) the average pixel
# intensity is 1 and total is Nx*Ny, for all other I the values are lower
N_pie = int(self.N_pie)
self.N_pie = int(self.N_pie) # for now because this is read
# as a floating point value
N_pie = self.N_pie
self.data = Cell(N_pie)
self.data_sq = self.data
self.data_zeros = self.data
......@@ -366,7 +377,7 @@ class PtychographyExperiment(Experiment):
probe_guess = np.ones_like(probe)
else: # TODO implement missing cases
errMsg = "Probe guess type " + self.probe_guess_type \
+ "is not yet implemented"
+ " is not yet implemented"
raise NotImplementedError(errMsg)
# Initial guess for the object
......@@ -391,7 +402,7 @@ class PtychographyExperiment(Experiment):
np.pi)/2)
else:
errMsg = "Object guess type " + self.object_guess_type \
+ "is not yet implemented"
+ " is not yet implemented"
raise NotImplementedError(errMsg)
# Legacy variables
......@@ -427,7 +438,6 @@ class PtychographyExperiment(Experiment):
for y in range(int(self.Ny)):
object_support[indy[y],indx] += probe_mask[y,:]
#object_support[indy,indx] += probe_mask
self.object_support = (object_support > 0).astype(object_support.dtype)
object_guess *= self.object_support
......@@ -468,14 +478,168 @@ class PtychographyExperiment(Experiment):
"""
super(PtychographyExperiment, self).setupProxOperators() # call parent's method
self.nProx = 3
self.propagator = 'Propagator_FreFra'
self.inverse_propagator = 'InvPropagator_FreFra'
self.explicit = []
self.proxOperators = []
self.proxOperators.append('P_Amod')
self.proxOperators.append('Approx_Pphase_FreFra_Poisson')
self.sets = self.sets
self.nProx = self.sets
self.productProxOperators = []
if self.ptychography_prox == 'PHeBIE_Ptwise':
self.n_product_Prox = self.product_space_dimension
for _j in range(self.n_product_Prox):
self.productProxOperators.append('Approx_Pphase_FreFra_Poisson')
# prox 0
self.explicit.append('Explicit_ptychography_scan_farfield_probe_ptwise')
self.proxOperators.append('P_ptychography_scan_farfield_probe')
# prox 1
self.explicit.append('Explicit_ptychography_scan_farfield_object_ptwise')
self.proxOperators.append('P_supp_amp_band')
# prox 3
self.explicit.append('Explicit_ptychography_image')
self.proxOperators.append('Prox_product_space')
self.proxSequence = [1, 2, 3]
else:
errMsg = "Prox operator configuration " + self.ptychography_prox \
+ " is not yet implemented"
raise NotImplementedError(errMsg)
'''
switch input.ptychography_prox
case 'PHeBIE'
input.iterate_monitor = 'PHeBIE_iterate_monitor';
input.optimality_monitor = 'PHeBIE_ptychography_objective';
input.Explicit{3}='Explicit_ptychography_image';
input.Prox{3} = 'Prox_product_space';
input.n_product_Prox=input.product_space_dimension;
n_product_Prox=input.n_product_Prox;
input.product_Prox=cell(1, n_product_Prox);
for j=1:input.n_product_Prox
input.product_Prox{j} = 'Approx_Pphase_FreFra_Poisson';
end
input.Explicit{1}='Explicit_ptychography_scan_farfield_probe'; %pairs with the Implicit/prox mappings below
input.Prox{1} = 'P_ptychography_scan_farfield_probe';
input.Explicit{2}='Explicit_ptychography_scan_farfield_object'; %pairs with the Implicit/prox mappings below
input.Prox{2} = 'P_supp_amp_band';
input.ProxSequence = [1 2 3];
% below might be legacy/junk code
if strcmp(input.method,'Nesterov_PHeBIE')
input.fnames = {'probe','object','phi'};
end
case 'PHeBIE_Ptwise'
input.iterate_monitor = 'PHeBIE_iterate_monitor';
input.optimality_monitor = 'PHeBIE_ptychography_objective';
input.Explicit{3}='Explicit_ptychography_image';
input.Prox{3} = 'Prox_product_space';
input.n_product_Prox=input.product_space_dimension;
n_product_Prox=input.n_product_Prox;
input.product_Prox=cell(1, n_product_Prox);
for j=1:input.n_product_Prox
input.product_Prox{j} = 'Approx_Pphase_FreFra_Poisson';
end
input.Explicit{1}='Explicit_ptychography_scan_farfield_probe_ptwise'; %pairs with the Implicit/prox mappings below
input.Prox{1} = 'P_ptychography_scan_farfield_probe';
input.Explicit{2}='Explicit_ptychography_scan_farfield_object_ptwise'; %pairs with the Implicit/prox mappings below
input.Prox{2} = 'P_supp_amp_band';
input.ProxSequence = [1 2 3];
% below might be legacy/junk code
if strcmp(input.method,'Nesterov_PHeBIE')
input.fnames = {'probe','object','phi'};
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% the cases below have not been updated (DRL, 19.02.2019)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
case 'DRlPHeBIE_Ptwise'
input.iterate_monitor = 'PHeBIE_iterate_monitor';
input.optimality_monitor = 'PHeBIE_ptychography_objective';
input.Prox{1} = 'P_ptychography_PHeBIE_D_ptwise';
input.Prox{2} = 'P_ptychography_PHeBIE_phi';
input.ProxSequence = [1 2];
input.DRlwith = {'phi'}; %a list of file name on which reflections are to be performed with.
case 'Rodenburg'
input.iterate_monitor = 'generic_iterate_monitor';
input.optimality_monitor = '';
input.Prox{1} = 'P_ptychography_Rodenburg';
input.ProxSequence = [1];
case 'Thibault'
input.iterate_monitor = 'generic_iterate_monitor';
input.optimality_monitor = '';
input.Prox{1} = 'P_ptychography_Thibault_OP';
input.Prox{2} = 'P_ptychography_Thibault_F';
input.ProxSequence = [1 2];
input.warmup_Prox{1} = 'P_ptychography_Thibault_O';
input.warmup_Prox{2} = 'P_ptychography_Thibault_F';