Commit e0f42d12 authored by markus.meier01's avatar markus.meier01
Browse files

Uploaded algorithm for Wirtinger Flow, removed some printouts and undid the...

Uploaded algorithm for Wirtinger Flow, removed some printouts and undid the namechange for truth/Truth
parent 47c6e1d9
......@@ -10,7 +10,7 @@ from numpy import zeros, angle
from numpy import all as a
from scipy.linalg import norm
from .algorithms import Algorithm
from trace import Trace
class AP_expert(Algorithm):
"""
......@@ -53,9 +53,9 @@ class AP_expert(Algorithm):
self.iters = 0
if 'truth' in config:
self.Truth = config['truth']
self.Truth_dim = config['truth_dim']
self.norm_Truth = config['norm_truth']
self.truth = config['truth']
self.truth_dim = config['truth_dim']
self.norm_truth = config['norm_truth']
def run(self, u, tol, maxiter):
......@@ -99,13 +99,13 @@ class AP_expert(Algorithm):
tmp_gap = (norm(tmp1-tmp_u,'fro')/norm_data)**2
if hasattr(self, 'truth'):
if self.Truth_dim[0] == 1:
if self.truth_dim[0] == 1:
z=tmp_u[0,:]
elif self.Truth_dim[1]==1:
elif self.truth_dim[1]==1:
z=tmp_u[:,0]
else:
z=tmp_u
Relerrs[iter] = norm(self.Truth - exp(-1j*angle(Trace(self.Truth.T*z))) * z, 'fro')/self.norm_Truth
Relerrs[iter] = norm(self.truth - exp(-1j*angle(trace(self.truth.T*z))) * z, 'fro')/self.norm_truth
elif q==1:
for j in range(self.product_space_dimension):
......
......@@ -53,9 +53,9 @@ class SimpleAlgorithm:
self.config = config
if 'truth' in config:
self.Truth = config['truth']
self.Truth_dim = config['truth_dim']
self.norm_Truth = config['norm_truth']
self.truth = config['truth']
self.truth_dim = config['truth_dim']
self.norm_truth = config['norm_truth']
if 'diagnostic' in config:
self.diagnostic = True
......@@ -143,15 +143,15 @@ class SimpleAlgorithm:
tmp_gap = phase_offset_compensated_norm(u1, u2, norm_factor=norm_data, norm_type='fro') ** 2
if hasattr(self, 'truth'):
if self.Truth_dim[0] == 1:
if self.truth_dim[0] == 1:
z = u1[0, :]
elif self.Truth_dim[1] == 1:
elif self.truth_dim[1] == 1:
z = u1[:, 0]
else:
z = u1
# Relerrs[iter] = norm(self.truth - exp(-1j * angle(trace(self.truth.T * z))) * z,
# 'fro') / self.norm_truth
Relerrs[iter] = phase_offset_compensated_norm(self.Truth, z, norm_factor=self.norm_Truth,
Relerrs[iter] = phase_offset_compensated_norm(self.truth, z, norm_factor=self.norm_truth,
norm_type='fro')
elif q == 1:
......@@ -164,9 +164,8 @@ class SimpleAlgorithm:
if 'diagnostic' in self.config:
if hasattr(self, 'truth'):
z = u1[:, :, 0]
print(z.shape)
print(self.Truth.T.shape)
Relerrs[iter] = norm((self.Truth - exp(-1j * angle(trace(self.Truth.T * z))) * z),'fro') / self.norm_Truth
Relerrs[iter] = norm((self.truth - exp(-1j * angle(trace(self.truth.T.transpose() * z))) * z),'fro') / self.norm_truth
else:
if 'diagnostic' in self.config:
......@@ -182,7 +181,7 @@ class SimpleAlgorithm:
norm(u2[:, :, k, j] - shadow[:, :, k, j], 'fro') / (norm_data)) ** 2
if hasattr(self, 'truth') and (j == 0):
Relerrs[iter] = Relerrs[iter] + norm(
self.Truth - exp(-1j * angle(trace(self.truth.T * u1[:, :, k, 1]))) * u1[:, :, k, 1],
self.truth - exp(-1j * angle(trace(self.truth.T * u1[:, :, k, 1]))) * u1[:, :, k, 1],
'fro') / self.norm_Truth
change[iter] = sqrt(tmp_change)
......
from math import sqrt, exp
from numpy import zeros, angle, tile, mean, conj, reshape, repeat, newaxis, trace
from numpy.fft import fft, ifft, fft2, ifft2
from scipy.linalg import norm
from .algorithms import Algorithm
import numpy as np
## Doesn't work yet
class Wirtinger(Algorithm):
def __init__(self, config):
self.Prox1 = config['proxoperators'][0](config)
self.Prox2 = config['proxoperators'][1](config)
normest = config['norm_data']
self.Ny = config['Ny']
self.Nx = config['Nx']
self.Nz = config['Nz']
L = config['product_space_dimension']
u = config['u_0']
if 'truth' in config:
x = config['truth']
x_dim = config['truth_dim']
normM = config['norm_truth']
MAXIT = config['MAXIT']
TOL = config['TOL']
Masks = np.zeros((self.Ny,self.Nx,L))
if 'JWST' in config:
Masks = np.zeros(Ny,Nx,L)
indicator_ampl = config['indicator_ampl']
illumination_phase = config['illumination_phase']
for j in range(L):
Masks[:,:,j] = indicator_ampl * exp(1j*illumination_phase[:,:,j])
Y = config['data_sq']
elif 'CDP' in config:
Masks = config['Masks']
Y = config['data_sq']
else:
Y = config['data_sq']
normM = config['norm_rt_data']
Masks[0,0,0:L] = 1/normM
## def run(self, u, TOL, MAXIT):
Prox1 = self.Prox1
Prox2 = self.Prox2
if u.ndim < 3:
p = 1
q = 1
elif u.ndim == 3:
p = u.shape[2]
q = 1
else:
p = u.shape[2]
q = u.shape[3]
iter = 0
change = zeros(MAXIT + 1, dtype=u.dtype)
change[0] = 999
gap = change.copy()
if self.Nx==1:
A = lambda I: fft(Masks*I)
At = lambda Y: mean(conj(Masks)*ifft(Y),1)
elif self.Ny==1:
A = lambda I: fft(Masks*I)
At = lambda Y: mean(conj(Masks)*ifft(Y),0)
else:
A = lambda I: fft2(Masks*I)
h = (mean(conj(Masks)*ifft2(Y),2))
At = lambda Y: np.repeat(h[:,:,np.newaxis],L,axis=2)
if hasattr(self, 'truth'):
Relerrs = change.copy()
tau0 = 330
mu = lambda t: min(1-exp(-t/tau0), 0.4)
while iter < MAXIT and change[iter]>=TOL:
iter +=1
Bz = A(u)
C = (abs(Bz)**2 -Y) * Bz
grad = At(C)
step = mu(iter)/normest**2 * grad
u = u - step
tmp_change=0
tmp_gap = 0
if p==1 and q==1:
tmp_change = (norm(step, 'fro')/normM)**2
tmp_u = Prox1.work(u)
tmp1 = Prox2.work(tmp_u)
tmp_gap = (norm(tmp1-tmp_u, 'fro')/normM)**2
if x_dim[0]==1:
z = u[0,:]
elif x_dim[1]==1:
z = u[:,0]
else:
z = u[:,:,0]
Relerrs[iter] = norm(x - exp(-1j*angle(trace(x.transpose()*z)))*z,'fro')/normM
elif q==1:
tmp_u = Prox1.work(u)
tmp1 = Prox2.work(tmp_u)
tmp_change = L*(norm(step[:,:,1], 'fro') /normM)**2
for j in range(L):
tmp_gap = tmp_gap + (norm(tmp1[:,:,j] - tmp_u[:,:,j], 'fro') / normM)**2
Relerrs[iter] = norm(x - exp(-1j*angle(trace(x*tmp_u[:,:,0]))) *tmp_u[:,:,0], 'fro')/normM
else:
Times[iter] = toc
tmp_u = Prox1(u)
tmp1 = Prox2(tmp_u)
for j in range(L):
for k in range(Nz):
tmp_change = tmp_change + (norm(step[:,:,k,j], 'fro') /normM)**2
tmp_gap = tmp_gap + (norm(step[:,:,k,j], 'fro')/normM)**2
Relerrs[iter] = norm(x-exp(-1j*angle(trace(x.transpose()*tmp_u[:,:,:,0])))*tmp_u[:,:,:,0],'fro')/normM
gap[iter] = sqrt(tmp_gap)
change[iter] = sqrt(tmp_change)
tmp = Prox1(u)
tmp2 = Prox2(u)
if self.Nx==1:
u1 = tmp[:,0]
u2 = tmp2[:,0]
elif self.Ny==1:
u1 = tmp[0,:]
u2 = tmp2[0,:]
elif self.Nz==1:
u1 = tmp[:,:,0]
u2 = tmp2[:,:,0]
else:
u1 = tmp
u2 = tmp2
change = change[1:iter]
return {'u1': u1, 'u2': u2, 'iter': iter, 'change': change, 'gap': gap}
\ No newline at end of file
......@@ -31,5 +31,6 @@ from .GRAAL_F import *
from .GRAAL_objective import *
from .KM import *
from .QNAP import *
from .Wirtinger import *
__all__ = ["AP","HPR","RAAR", "AP_expert", "GRAAL", "RAAR_expert", "DRl", "ADMM", "RRR", "CAARL", "CADRl", "CDRl", "CP", "CPrand", "DRAP", "DRl", "GRAAL_F", "GRAAL_objective", "KM", "QNAP"]
__all__ = ["AP","HPR","RAAR", "AP_expert", "GRAAL", "RAAR_expert", "DRl", "ADMM", "RRR", "CAARL", "CADRl", "CDRl", "CP", "CPrand", "DRAP", "DRl", "GRAAL_F", "GRAAL_objective", "KM", "QNAP", "Wirtinger"]
......@@ -60,7 +60,7 @@ new_config = {
# able to control ('without too much damage')
# 'Algorithm':
'method' : 'Wirtinger',#'Accelerated_AP_product_space',
'algorithm' : 'Wirtinger',#'Accelerated_AP_product_space',
'numruns' :100, # the only time this parameter will
# be different than 1 is when we are
# benchmarking...not something a normal user
......
......@@ -59,7 +59,7 @@ new_config = {
## IMPORTANT: algorithm used to be Wirtinger, not sure if AP also works
'algorithm' : 'AP', #'Accelerated_AP_product_space';
'algorithm' : 'Wirtinger', #'Accelerated_AP_product_space';
'numruns' : 100, # the only time this parameter will
# be different than 1 is when we are
# benchmarking...not something a normal user
......
......@@ -39,10 +39,6 @@ def Phase_graphics(config, output):
#beta_max = config['beta_max']
u_0 = config['u_0']
u = output['u1']
print(u.shape)
if output['u1'].ndim == 2:
u = output['u1']
u2 = output['u2']
......
......@@ -127,7 +127,6 @@ class P_diag(ProxOperator):
m = self.m;
p = self.p;
K = self.K;
if m == 1:
tmp = sum(u, axis=0, dtype=u.dtype)
elif n == 1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment