Commit 6a359126 authored by s.gretchko's avatar s.gretchko

Added Wirtinger alg., iter. monitor, corresp. JWST, CDP_1D, CDP_2D demos + JWST_Phebie demos

parent 9774c651
import SetProxPythonPath
from proxtoolbox.experiments.phase.CDP_Experiment import CDP_Experiment
CDP = CDP_Experiment(algorithm = 'Wirtinger')
CDP.run()
CDP.show()
import SetProxPythonPath
from proxtoolbox.experiments.phase.CDP_Experiment import CDP_Experiment
CDP = CDP_Experiment(algorithm = 'Wirtinger', Nx=256, TOL=1e-8, debug=True)
CDP.run()
CDP.show()
import SetProxPythonPath
from proxtoolbox.experiments.phase.JWST_Experiment import JWST_Experiment
JWST = JWST_Experiment(algorithm='PHeBIE', anim=True)
JWST.run()
JWST.show()
import SetProxPythonPath
from proxtoolbox.experiments.phase.JWST_Experiment import JWST_Experiment
JWST = JWST_Experiment(algorithm='PHeBIE', anim=True, noise=True)
JWST.run()
JWST.show()
import SetProxPythonPath
from proxtoolbox.experiments.phase.JWST_Experiment import JWST_Experiment
# Note: this demo has the same behavior as Matlab's with the same initial
# conditions. The algorithm does not converge.
JWST = JWST_Experiment(algorithm='Wirtinger')
JWST.run()
JWST.show()
import SetProxPythonPath
from proxtoolbox.experiments.phase.JWST_Experiment import JWST_Experiment
# Note: this demo has the same behavior as Matlab's with the same initial
# conditions. The algorithm does not converge.
JWST = JWST_Experiment(algorithm='Wirtinger', noise=True)
JWST.run()
JWST.show()
from proxtoolbox.algorithms.algorithm import Algorithm
from proxtoolbox import proxoperators
from proxtoolbox.utils.cell import Cell, isCell
from numpy import ones
......@@ -34,10 +35,13 @@ class PHeBIE(Algorithm):
# instantiate explicit prox operators
self.explicit = []
for proxClass in experiment.explicit:
if proxClass != '':
prox = proxClass(experiment)
self.explicit.append(prox)
for prox_item in experiment.explicit:
if isinstance(prox_item, str):
proxClass = getattr(proxoperators, prox_item)
else:
proxClass = prox_item
prox = proxClass(experiment)
self.explicit.append(prox)
def evaluate(self, u):
"""
......@@ -53,7 +57,9 @@ class PHeBIE(Algorithm):
u_new : type ???
The new iterate.
"""
print("Iteration: ", self.iter)
if self.debug:
print("Iteration: ", self.iter)
u_new = u.copy()
# Loop through the block-wise implicit-explicit update steps.
......
......@@ -18,7 +18,6 @@ class PHeBIE_IterateMonitor(IterateMonitor):
self.rel_errors = None
self.product_space_dimension = experiment.product_space_dimension
self.Nz = experiment.Nz
self.positions = experiment.positions
def preprocess(self, alg):
......@@ -86,25 +85,6 @@ class PHeBIE_IterateMonitor(IterateMonitor):
self.u_monitor = u.copy()
def calculateObjective(self, alg):
objValue = 0
u = alg.u_new
normM = self.norm_data
self.rangeNx = np.arange(alg.Nx, dtype=np.int)
self.rangeNy = np.arange(alg.Ny, dtype=np.int)
if isCell(u[2]):
for jj in range(len(u[2])):
indy = (self.positions[jj,0] + self.rangeNy).astype(int)
indx = (self.positions[jj,1] + self.rangeNx).astype(int)
objValue += (norm(u[0]*u[1][indy,:][:,indx] - u[2][jj], 'fro')/normM)**2
else:
p, _q = self.getIterateSize(u[2])
for jj in range(p):
indy = (self.positions[jj,0] + self.rangeNy).astype(int)
indx = (self.positions[jj,1] + self.rangeNx).astype(int)
objValue += (norm(u[0]*u[1][indy,:][:,indx] - u[2][:,:,jj], 'fro')/normM)**2
return objValue
def postprocess(self, alg, output):
"""
......
from proxtoolbox.utils.cell import Cell, isCell
from proxtoolbox.utils.size import size_matlab
import numpy as np
from numpy.linalg import norm
class PHeBIE_phase_objective:
"""
objective function for the PHeBIE algorithm applied to the
phase retrieval problem
Based on Matlab code written by Russell Luke (Inst. Fuer
Numerische und Angewandte Mathematik, Universitaet
Gottingen) on Feb. 12, 2019.
"""
def __init__(self, experiment):
if hasattr(experiment, 'norm_data'):
self.norm_data = experiment.norm_data
else:
self.norm_data = 1.0
def calculateObjective(self, alg):
"""
Evaluate PHeBIE objective function
Parameters
----------
alg : algorithm instance
The algorithm that is running
Returns
-------
objValue : real
The value of the objective function
"""
objValue = 0
u = alg.u_new
normM = self.norm_data
m1, n1, p1, _q1 = size_matlab(u[0])
if isCell(u[1]):
for jj in range(len(u[1])):
objValue += (norm(u[0] - u[1][jj])/normM)**2
else:
m2, n2, p2, q2 = size_matlab(u[1])
if m1 == 1:
objValue = (norm(np.ones(m2,1) @ u[0] - u[1]))**2
elif n1 == 1:
objValue = (norm(u[0] @ np.ones(1,n2)*u[0] - u[1]))**2
elif p1 == 1:
for jj in range(p2):
objValue += (norm(u[0] - u[1][:,:,jj]))**2
else: # cells of 4D arrays?!!!
for jj in range(p1):
for k in range(q2):
objValue += (norm(u[0][:,:,k] - u[1][:,:,k,jj]))**2
return objValue
from proxtoolbox.utils.cell import Cell, isCell
from proxtoolbox.utils.size import size_matlab
import numpy as np
from numpy.linalg import norm
class PHeBIE_ptychography_objective:
"""
objective function for the PHeBIE algorithm applied to the
ptychography problem
Based on Matlab code written by Russell Luke (Inst. Fuer
Numerische und Angewandte Mathematik, Universitaet
Gottingen) on Feb. 12, 2019.
"""
def __init__(self, experiment):
if hasattr(experiment, 'norm_data'):
self.norm_data = experiment.norm_data
else:
self.norm_data = 1.0
self.positions = experiment.positions
def calculateObjective(self, alg):
"""
Evaluate PHeBIE objective function
Parameters
----------
alg : algorithm instance
The algorithm that is running
Returns
-------
objValue : real
The value of the objective function
"""
objValue = 0
u = alg.u_new
normM = self.norm_data
self.rangeNx = np.arange(alg.Nx, dtype=np.int)
self.rangeNy = np.arange(alg.Ny, dtype=np.int)
if isCell(u[2]):
for jj in range(len(u[2])):
indy = (self.positions[jj,0] + self.rangeNy).astype(int)
indx = (self.positions[jj,1] + self.rangeNx).astype(int)
objValue += (norm(u[0]*u[1][indy,:][:,indx] - u[2][jj], 'fro')/normM)**2
else:
_m, _n, p, _q = size_matlab(u[2])
for jj in range(p):
indy = (self.positions[jj,0] + self.rangeNy).astype(int)
indx = (self.positions[jj,1] + self.rangeNx).astype(int)
objValue += (norm(u[0]*u[1][indy,:][:,indx] - u[2][:,:,jj], 'fro')/normM)**2
return objValue
from proxtoolbox.algorithms.iterateMonitor import IterateMonitor
from proxtoolbox.utils.cell import Cell, isCell
import numpy as np
from numpy import zeros, angle, trace, exp, sqrt, sum, matmul, array, reshape
from numpy.linalg import norm
import time
class Wirt_IterateMonitor(IterateMonitor):
"""
Algorithm analyzer for monitoring iterates of
projection algorithms for the Wirtinger algorithm.
Specialization of the IterateMonitor class.
"""
def __init__(self, experiment):
super(Wirt_IterateMonitor, self).__init__(experiment)
self.gaps = None
self.rel_errors = None
self.times = None
self.product_space_dimension = experiment.product_space_dimension
self.Nz = experiment.Nz
self.startTime = 0
def preprocess(self, alg):
"""
Allocate data structures needed to collect
statistics. Called before running the algorithm.
Parameters
----------
alg : instance of Algorithm class
Algorithm to be monitored.
"""
super(Wirt_IterateMonitor, self).preprocess(alg)
# In PHeBIE, the iterate u is a cell of blocks of variables.
# In the analysis of Hesse, Luke, Sabach&Tam, SIAM J. Imaging Sciences, 2015
# all blocks are monitored for convergence.
self.u_monitor = self.u0.copy()
# set up diagnostic arrays
if self.diagnostic:
self.gaps = self.changes.copy()
self.gaps[0] = 999
if self.truth is not None:
self.rel_errors = self.changes.copy()
self.rel_errors[0] = sqrt(999)
self.times = self.changes.copy()
self.times[0] = 0
self.startTime = time.time()
def updateStatistics(self, alg):
"""
Update statistics. Called at each iteration
while the algorithm is running.
Parameters
----------
alg : instance of Algorithm class
Algorithm being monitored.
"""
u = alg.u_new
normM = self.norm_data
tmp_change = 0
tmp_gap = 0
rel_error = 0
if self.diagnostic:
tmp_u = alg.prox1.eval(u)
tmp1 = alg.prox2.eval(tmp_u)
if isCell(u):
for j in range(len(u)):
tmp_change += (norm(alg.step[j])/normM)**2
if self.diagnostic:
tmp_gap += (norm(tmp1[j] - tmp_u[j])/normM)**2
if self.diagnostic and self.truth is not None:
z = u[0]
rel_error = norm(self.truth - exp(-1j*angle(trace(self.truth.T.conj() @ z))) * z) / self.norm_truth
else:
tmp_change = sum(abs(alg.step)**2)/normM**2
if self.diagnostic:
tmp_gap = sum(abs(tmp1 - tmp_u)**2)/normM**2
if self.truth is not None:
if self.truth_dim[0] == 1:
z = u[0,:]
z = reshape(z, (1,len(z))) # we want a true matrix not just a vector
elif self.truth_dim[1] == 1:
z = u[:,0]
z = reshape(z, (len(z),1)) # we want a true matrix not just a vector
else:
z = u[:,:,0]
rel_error = norm(self.truth - exp(-1j*angle(trace(self.truth.T.conj() @ z))) * z) / self.norm_truth
self.changes[alg.iter] = sqrt(tmp_change)
if self.diagnostic:
self.gaps[alg.iter] = sqrt(tmp_gap)
if self.truth is not None:
self.rel_errors[alg.iter] = rel_error
self.times[alg.iter] = time.time() - self.startTime
self.u_monitor = u.copy()
def postprocess(self, alg, output):
"""
Called after the algorithm stops. Store statistics in
the given 'output' dictionary
Parameters
----------
alg : instance of Algorithm class
Algorithm that was monitored.
output : dictionary
Contains the last iterate and various statistics that
were collected while the algorithm was running.
Returns
-------
output : dictionary into which the following entries are added
(when diagnostics are required)
gaps : ndarray
Squared gap distance normalized by the magnitude
constraint
"""
output = super(Wirt_IterateMonitor, self).postprocess(alg, output)
if self.diagnostic:
stats = output['stats']
stats['gaps'] = self.gaps[1:alg.iter+1]
if self.truth is not None:
stats['rel_errors'] = self.rel_errors[1:alg.iter+1]
stats['times'] = self.times[1:alg.iter+1]
return output
\ No newline at end of file
from proxtoolbox.algorithms.algorithm import Algorithm
#from proxtoolbox.utils import fft, ifft
from proxtoolbox.utils.cell import Cell, isCell
from proxtoolbox.utils.size import size_matlab
import numpy as np
from numpy import sqrt, conj, tile, mean, exp, angle, trace, reshape
from numpy.linalg import norm
from numpy.fft import fft2, ifft2, fft, ifft
class Wirtinger(Algorithm):
"""
Wirtinger flow algorithm as implemented
by E. J. Candes, X. Li, and M. Soltanolkotabi
"Phase Retrieval via Wirtinger Flow: Theory and Algorithms"
adapted for ProxToolbox
Based on Matlab code written by Russell Luke (Inst. Fuer
Numerische und Angewandte Mathematik, Universitaet
Gottingen) on June 29, 2017.
"""
def __init__(self, experiment, iterateMonitor, accelerator):
super(Wirtinger, self).__init__(experiment, iterateMonitor, accelerator)
self.masks = experiment.masks
self.data_sq = experiment.data_sq
self.tau0 = 330
self.mu = lambda t: min(1-exp(-t/self.tau0), 0.4)
def evaluate(self, u):
"""
Update for Wirtinger algorithm.
Parameters
----------
u : ndarray
The current iterate.
Returns
-------
u_new : ndarray
The new iterate (same type as input parameter `u`).
"""
n1 = self.Ny
n2 = self.Nx
L = self.product_space_dimension
if isCell(u):
# assumes 2D
u_new = Cell(L)
self.step = Cell(L)
get = lambda u, j: u[j]
set = self.set0
data_shape = u[0].shape
else:
u_new = np.empty(u.shape, dtype=np.complex128)
self.step = np.empty(u.shape, dtype=np.complex128)
if n2 == 1:
# u is n1 x L array
get = lambda u, j: u[:,j]
set = self.set1
data_shape = (n1)
elif n1 == 1:
# u is L x n2 array
get = lambda u, j: u[j,:]
set = self.set2
data_shape = (n2)
else:
# assumes 2D case
get = lambda u, j: u[:,:,j]
set = self.set3
data_shape = (n1, n2)
grad = np.zeros(data_shape, dtype=np.complex128)
for j in range(L):
grad += self.evaluateGradHelper(get(u, j), get(self.masks, j), get(self.data_sq, j))
grad /= L
step = self.mu(self.iter+1)/self.norm_data**2 * grad
for j in range(L):
set(self.step, j, step)
set(u_new, j, get(u, j)-step)
return u_new
# the following functions are helper functions used
# by the evaluate() method. They should have been lambda
# functions, but Python does not support these kind of
# functions that perform an assignment using subscripts
def set0(self, dest, j, val):
dest[j] = val
def set1(self, dest, j, val):
dest[:,j] = val
def set2(self, dest, j, val):
dest[j,:] = val
def set3(self, dest, j, val):
dest[:,:,j] = val
def evaluateGradHelper(self, u, masks, data_sq):
m, n, _p, _q = size_matlab(u)
if m > 1 and n > 1:
FFT = lambda u: fft2(u)
IFFT = lambda u: ifft2(u)
else:
FFT = lambda u: fft(u)
IFFT = lambda u: ifft(u)
Bz = FFT(masks * u)
C = (abs(Bz)**2 - data_sq) * Bz
grad = conj(masks) * IFFT(C)
return grad
......@@ -29,3 +29,7 @@ from .AvP2 import *
from .AvP2_IterateMonitor import *
from .GenericAccelerator import *
from .DyRePr import*
from .PHeBIE_ptychography_objective import *
from .PHeBIE_phase_objective import *
from .Wirtinger import *
from .Wirt_IterateMonitor import *
......@@ -45,6 +45,10 @@ class Algorithm:
self.u = None
self.u_new = None
# for debugging
self.debug = experiment.debug
def preprocess(self):
"""
The default implementation calls the iterate monitor's
......
......@@ -32,7 +32,7 @@ class IterateMonitor:
# instantiate optimality monitor if it exists
self.optimality_monitor = None
if hasattr(experiment, 'optimality_monitor'):
if hasattr(experiment, 'optimality_monitor') and experiment.optimality_monitor is not None:
optimality_monitor_name = experiment.optimality_monitor
optimality_monitor_class = getattr(algorithms, optimality_monitor_name)
self.optimality_monitor = optimality_monitor_class(experiment)
......
......@@ -192,7 +192,9 @@ class Experiment(metaclass=ExperimentMetaClass):
self.truth = None
self.truth_dim = None
self.norm_truth = None
self.optimality_monitor = None
self.proj_iter = None # not sure what it does
self.proxOperators = []
......
from proxtoolbox.experiments.phase.phaseExperiment import PhaseExperiment
from proxtoolbox.utils import fft, ifft
from proxtoolbox.utils.loadMatFile import loadMatFile
import proxtoolbox.utils as utils
import numpy as np
from numpy import sqrt, conj, tile, mean, exp, angle, trace, reshape, matmul
from numpy.random import randn, random_sample
......@@ -79,7 +79,7 @@ class CDP_Experiment(PhaseExperiment):
# used predefined randomly generated data for the case where n1 = 128, n2 = 1
n1 = self.Ny
n2 = self.Nx
debug = self.debug and n1 == 128 and n2 == 1 \
debug = self.debug and n1 == 128 and (n2 == 1 or n2 == 256) \
and self.product_space_dimension == 10
debug_image = None
debug_masks = None
......@@ -87,7 +87,16 @@ class CDP_Experiment(PhaseExperiment):
# make image
if debug:
debug_image, debug_masks, debug_z0 = self.createTestImage()
if n2 == 1:
debug_image, debug_masks, debug_z0 = self.createTestImage()
elif n2 == 256:
x_dict = loadMatFile('../InputData/Phase/CDP_2D_x.mat')
debug_image = x_dict['x']
masks_dict = loadMatFile('../InputData/Phase/CDP_2D_Masks.mat')
debug_masks = masks_dict['Masks']
z0_dict = loadMatFile('../InputData/Phase/CDP_2D_z0.mat')
debug_z0 = z0_dict['z0']
x = debug_image
else:
x = randn(n1,n2) + 1j*randn(n1,n2)
......@@ -123,8 +132,8 @@ class CDP_Experiment(PhaseExperiment):
A = lambda I: fft(conj(masks) * tile(I,[L, 1])) # Input is 1 x n signal, output is L x n array
At = lambda Y: mean(masks * ifft(Y), 0).reshape((n1,n2)) # Input is L x n array, output is 1 x n signal
else:
A = lambda I: fft2(masks * reshape(tile(I,[1, L]), (I.shape[0],I.shape[1], L))) # Input is n1 x n2 image, output is n1 x n2 x L array
At = lambda Y: mean(masks * ifft2(Y), 2).reshape((n1,n2)) # Input is n1 x n2 X L array, output is n1 x n2 image
A = lambda I: fft2(masks * reshape(tile(I,[1, L]), (I.shape[0],I.shape[1], L), order='F')) # Input is n1 x n2 image, output is n1 x n2 x L array
At = lambda Y: mean(masks * ifft2(Y), 2).reshape((n1,n2)) # Input is n1 x n2 X L array, output is n1 x n2 image
# Support constraint: none
self.indicator_ampl = 1
......@@ -135,7 +144,16 @@ class CDP_Experiment(PhaseExperiment):
#alt_support_idx = find(Xi_A.T, 0, "!=")
# Data
Y = abs(A(x))
if n2 == 1 or n1 == 1:
Y = abs(A(x))
else:
tmp_x = reshape(tile(x,[1, L]), (x.shape[0],x.shape[1], L), order='F')
masks_x = self.masks*tmp_x
fft2_masks_x = np.empty(masks.shape, dtype=masks.dtype)
for j in range(L):
fft2_masks_x[:,:,j] = fft2(masks_x[:,:,j])
Y = abs(fft2_masks_x)
self.rt_data = Y
Y = Y**2
self.data = Y
......@@ -169,7 +187,7 @@ class CDP_Experiment(PhaseExperiment):
self.u0 = tile(z,[L,1])
else:
_Relerrs = norm(x - exp(-1j*angle(trace(matmul(x.T.conj(), z)))) * z, 'fro')/norm(x,'fro')
self.u0 = reshape(tile(