Commit 9abb1308 authored by s.gretchko's avatar s.gretchko
Browse files

Added Phasepack experiment

parent 044ae881
from proxtoolbox.experiments.phase.phaseExperiment import PhaseExperiment
from proxtoolbox import proxoperators
from proxtoolbox.utils.loadMatFile import loadMatFile
import numpy as np
from numpy import fromfile, exp, nonzero, zeros, pi, resize, real, angle
from numpy.random import rand
from numpy.linalg import norm, pinv
import proxtoolbox.utils as utils
from proxtoolbox.utils.cell import Cell, isCell
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import subplots, show, figure
import os.path
import time
class Phasepack_Experiment(PhaseExperiment):
'''
Phasepack experiment class
'''
@staticmethod
def getDefaultParameters():
defaultParams = {
'experiment_name' : 'Phasepack',
'object': 'nonnegative',
'constraint': 'nonnegative and support',
'dataset' : 'PhaseSLM_40x40',
'Nx': 256,
'Ny': 256,
'Nz': 1,
'MAXIT': 200,
'TOL': 1e-4,
'lambda_0': 0.85,
'lambda_max': 0.85,
'lambda_switch': 20,
'data_ball': .999826e-30,
'diagnostic': True,
'iterate_monitor_name': 'FeasibilityIterateMonitor',
'verbose': 0,
'graphics': 1
}
return defaultParams
def __init__(self, dataset = 'PhaseSLM_40x40', **kwargs):
super(Phasepack_Experiment, self).__init__(**kwargs)
self.dataset = dataset
# additional attributes
self.phasepack_A = None
self.phasepack_Apinv = None
self.abs_illumination = None
self.supp_phase = None
self.data_zeros = None
self.support_idx = None
def loadData(self):
"""
Load Phasepack dataset. Create the initial iterate.
"""
# check if data exists
# TODO: All the data corresponding to this experiment should be placed
# in a zip file on the http://vaopt.math.uni-goettingen.de/data/ website
# so that we can use the automatic data loading feature (i.e., GetData.getData("Phasepack")).
# Another possibility is to write code that downloads the data files directly from
# the original website (https://rice.app.box.com/v/TransmissionMatrices)
data_dir = "../InputData/Phase/" + self.dataset + '/'
filenames = ['A_GS.mat', 'YH_squared_test.mat', 'XH_test.mat']
for filename in filenames:
data_path = data_dir + filename
if not os.path.isfile(data_path):
errMsg = "File " + data_path + " was not found"
print(errMsg)
print('*************************************************************************')
print('* INPUT DATA MISSING. Please download the phase pack input data from ')
print('* https://rice.app.box.com/v/TransmissionMatrices ')
print('* and save it in the directory', data_dir)
print('*************************************************************************')
raise IOError(errMsg)
print("Loading data. This may take a while...")
# transmission matrix
XH_test_dict = loadMatFile(data_dir + filenames[2])
XH_test = XH_test_dict['XH_test']
self.truth = XH_test[0,:].astype(float)
del XH_test
del XH_test_dict
Nphys = len(self.truth)
S = np.ones(Nphys)
#t = time.time()
A_GS_dict = loadMatFile(data_dir + filenames[0])
#elapsed = time.time() - t
#print("Took", elapsed, " seconds to load A_GS.mat file.")
self.phasepack_A = A_GS_dict['A']
del A_GS_dict
self.phasepack_Apinv = pinv(self.phasepack_A)
# diffraction pattern
YH_squared_test_dict = loadMatFile(data_dir + filenames[1])
YH_squared_test = YH_squared_test_dict['YH_squared_test']
self.data_sq = YH_squared_test[0,:].astype(float)
del YH_squared_test
del YH_squared_test_dict
orig_res = len(self.data_sq) # actual data size
workres = np.sqrt(orig_res)
N = workres
self.data = np.sqrt(self.data_sq)
# standard for the main program is that
# the data field is the magnitude SQUARED
self.norm_rt_data = norm(self.data)
self.norm_data = self.norm_rt_data
Xtrue = self.phasepack_A @ self.truth
Xtrue = self.norm_rt_data / norm(Xtrue) * Xtrue # gets the scaling right
self.truth = self.phasepack_Apinv @ Xtrue
self.truth = np.reshape(self.truth, (len(self.truth), 1)) # we need a matrix later on, not just a vector
self.truth_dim = self.truth.shape
self.norm_truth = norm(self.truth)
del Xtrue
# define support
self.data_zeros = np.where(self.data == 0)
self.support_idx = np.where(S != 0)
self.sets = 2
# use the abs_illumination field to represent the
# support constraint.
self.abs_illumination = S
self.supp_phase = []
# initial guess
self.u0 = S*rand(Nphys)
self.u0 = self.u0 / norm(self.u0) * self.norm_rt_data
def setupProxOperators(self):
"""
Determine the prox operators to be used for this experiment
"""
super(Phasepack_Experiment, self).setupProxOperators() # call parent's method
self.proxOperators = []
self.productProxOperators = []
if self.formulation == 'cyclic':
# there are as many prox operators as there are sets
self.nProx = self.sets
self.product_space_dimension = 1
for _j in range(self.nProx-1):
self.proxOperators.append('Pphase_phasepack')
self.proxOperators.append('P_SP')
else: # product space formulation
# add prox operators
self.nProx = 2
self.product_space_dimension = self.sets
self.proxOperators.append('P_diag')
self.proxOperators.append('Prox_product_space')
# add product prox operators
self.n_product_Prox = self.product_space_dimension
for _j in range(self.n_product_Prox-1):
self.productProxOperators.append('Pphase_phasepack')
self.productProxOperators.append('P_SP')
def show(self):
"""
Generate graphical output from the solution
"""
u_m = self.output['u_monitor']
if isCell(u_m):
u = u_m[0]
if isCell(u):
u = u[0]
u2 = u_m[len(u_m)-1]
if isCell(u2):
u2 = u2[len(u2)-1]
else:
u2 = u_m
if u2.ndim > 2:
u2 = u2[:,:,0]
u = self.output['u']
if isCell(u):
u = u[0]
elif u.ndim > 2:
u = u[:,:,0]
n = int(np.sqrt(len(u)))
u = np.reshape(u, (n, n), order='F')
u2 = np.reshape(u2, (n, n), order='F')
algo_desc = self.algorithm.getDescription()
title = "Algorithm " + algo_desc
# figure 904
titles = ["Best approximation amplitude - physical constraint satisfied",
"Best approximation phase - physical constraint satisfied",
"Best approximation amplitude - Fourier constraint satisfied",
"Best approximation phase - Fourier constraint satisfied"]
f = self.createFourImageFigure(u, u2, titles)
f.suptitle(title)
plt.subplots_adjust(hspace = 0.3) # adjust vertical space (height) between subplots (default = 0.2)
# figure 900
changes = self.output['stats']['changes']
time = self.output['stats']['time']
time_str = "{:.{}f}".format(time, 5) # 5 is precision
label = "Iterations (time = " + time_str + " s)"
f, ((ax1, ax2), (ax3, ax4)) = subplots(2, 2, \
figsize = (self.figure_width, self.figure_height),
dpi = self.figure_dpi)
self.createImageSubFigure(f, ax1, abs(u), titles[0])
self.createImageSubFigure(f, ax2, real(u), titles[1])
ax3.semilogy(changes)
ax3.set_xlabel(label)
ax3.set_ylabel('Log of change in iterates')
if 'gaps' in self.output['stats']:
gaps = self.output['stats']['gaps']
ax4.semilogy(gaps)
ax4.set_xlabel(label)
ax4.set_ylabel('Log of the gap distance')
f.suptitle(title)
plt.subplots_adjust(hspace = 0.3) # adjust vertical space (height) between subplots (default = 0.2)
plt.subplots_adjust(wspace = 0.3) # adjust horizontal space (width) between subplots (default = 0.2)
show()
from proxtoolbox.proxoperators.proxoperator import ProxOperator
import numpy as np
from numpy import zeros, real, shape
class Pphase_phasepack(ProxOperator):
"""
Projection onto phase magnitude constraints as done in
the phasepack toolbox. See the main
[PhasePack page](http://cs.umd.edu/~tomg/projects/phasepack/)
for complete information, or check out the [user guide]
(https://arxiv.org/abs/1711.09777).
Based on Matlab code written by Russell Luke (Inst. Fuer
Numerische und Angewandte Mathematik, Universitaet
Gottingen) on Jan 20, 2019.
"""
def __init__(self, experiment):
self.phasepack_A = experiment.phasepack_A
self.phasepack_Apinv = experiment.phasepack_Apinv
self.data = experiment.data
def eval(self, u, prox_idx=None):
"""
Projects the input data onto onto phase magnitude constraints
Parameters
----------
u : ndarray
Function in the physical domain to be projected
prox_idx : int, optional
Index of this prox operator
Returns
-------
u_new : ndarray
The projection in the physical (time) domain in the
same format as `u`
"""
u_new = self.phasepack_A @ u
u_new = self.phasepack_Apinv @ (self.data * u_new / abs(u_new))
return u_new
......@@ -26,3 +26,4 @@ from .P_CDP_ADMM import *
from .P_S import *
from .P_CDP_cyclic import *
from .sourceLocProx import *
from .Pphase_phasepack import *
......@@ -96,21 +96,12 @@ class MatFile73Reader:
else:
value = arr.T.squeeze()
elif matlab_type == 'complex':
# TODO simple implementation but not efficient memorywise
arr_real = np.array(dataset, dtype=dataset.dtype)
shape = arr_real.shape
arr_flat = arr_real.flatten()
n = len(arr_flat)
arr_flat_comp = np.ndarray(n, dtype = np.complex)
for i in range(n):
val = arr_flat[i]
c = complex(val[0], val[1])
arr_flat_comp[i] = c
arr_comp = arr_flat_comp.reshape(shape)
if arr_comp.size == 1:
value = arr_comp.item()
else:
value = arr_comp.T.squeeze()
arr = np.array(dataset)
arr = (arr['real'] + arr['imag']*1j).astype(np.complex128)
if arr.size == 1:
value = arr.item()
else:
value = arr.T.squeeze()
elif matlab_type == 'char':
st = ''
for c_item in dataset:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment