Commit d7885374 authored by markus.meier01's avatar markus.meier01
Browse files

fixed some things in Wirtinger algorithm

parent e0f42d12
......@@ -4,50 +4,51 @@ from numpy.fft import fft, ifft, fft2, ifft2
from scipy.linalg import norm
from .algorithms import Algorithm
import numpy as np
## Doesn't work yet
import cmath
## Doesn't work yet, relative error increases over the iterations instead of decreasing.
class Wirtinger(Algorithm):
def __init__(self, config):
self.Prox1 = config['proxoperators'][0](config)
self.Prox2 = config['proxoperators'][1](config)
normest = config['norm_data']
self.norm_data = config['norm_data']
self.Ny = config['Ny']
self.Nx = config['Nx']
self.Nz = config['Nz']
L = config['product_space_dimension']
u = config['u_0']
self.product_space_dimension = config['product_space_dimension']
self.u = config['u_0']
if 'truth' in config:
x = config['truth']
x_dim = config['truth_dim']
normM = config['norm_truth']
self.truth = config['truth']
self.truth_dim = config['truth_dim']
self.norm_truth = config['norm_truth']
MAXIT = config['MAXIT']
TOL = config['TOL']
Masks = np.zeros((self.Ny,self.Nx,L))
if 'JWST' in config:
Masks = np.zeros(Ny,Nx,L)
indicator_ampl = config['indicator_ampl']
illumination_phase = config['illumination_phase']
for j in range(L):
Masks[:,:,j] = indicator_ampl * exp(1j*illumination_phase[:,:,j])
Y = config['data_sq']
elif 'CDP' in config:
Masks = config['Masks']
Y = config['data_sq']
else:
Y = config['data_sq']
normM = config['norm_rt_data']
Masks[0,0,0:L] = 1/normM
## def run(self, u, TOL, MAXIT):
if ('experiment' in config):
self.experiment = config['experiment']
if self.experiment == 'CDP':
self.Masks = config['Masks']
self.data_sq = config['data_sq']
elif self.experiment == 'JWST':
self.Masks = np.zeros(self.Ny,self.Nx,self.product_space_dimension)
self.indicator_ampl = config['indicator_ampl']
self.illumination_phase = config['illumination_phase']
self.data_sq = config['data_sq']
for j in range(product_space_dimension):
Masks[:,:,j] = indicator_ampl * exp(1j*illumination_phase[:,:,j])
else:
self.data_sq = config['data_sq']
self.norm_rt_data = config['norm_rt_data']
self.Masks[0,0,0:self.product_space_dimension] = 1/self.norm_truth
def run(self, u, TOL, MAXIT):
Masks = self.Masks
Prox1 = self.Prox1
Prox2 = self.Prox2
......@@ -62,77 +63,90 @@ class Wirtinger(Algorithm):
q = u.shape[3]
iter = 0
change = zeros(MAXIT + 1, dtype=u.dtype)
change = zeros(MAXIT, dtype=u.dtype)
change[0] = 999
gap = change.copy()
if self.Nx==1:
A = lambda I: fft(Masks*I)
At = lambda Y: mean(conj(Masks)*ifft(Y),1)
At = lambda Y: np.repeat(mean(conj(Masks) * ifft(Y),1)[:,np.newaxis], self.product_space_dimension, axis = 1)
elif self.Ny==1:
A = lambda I: fft(Masks*I)
At = lambda Y: mean(conj(Masks)*ifft(Y),0)
At = lambda Y: np.repeat(mean(conj(Masks) * ifft(Y),0)[np.newaxis,:], self.product_space_dimension, axis = 0)
else:
A = lambda I: fft2(Masks*I)
h = (mean(conj(Masks)*ifft2(Y),2))
At = lambda Y: np.repeat(h[:,:,np.newaxis],L,axis=2)
At = lambda Y: np.repeat(mean(conj(Masks) * ifft2(Y),2)[:,:,np.newaxis], self.product_space_dimension, axis = 2)
if hasattr(self, 'truth'):
Relerrs = change.copy()
tau0 = 330
mu = lambda t: min(1-exp(-t/tau0), 0.4)
while iter < MAXIT and change[iter]>=TOL:
while iter < MAXIT-1 and change[iter]>=TOL:
iter +=1
Bz = A(u)
C = (abs(Bz)**2 -Y) * Bz
C = (abs(Bz)**2 -self.data_sq)* Bz
grad = At(C)
step = mu(iter)/normest**2 * grad
u = u - step
tmp_change=0
step = mu(iter)/self.norm_data**2 * grad
u = u - step
#print(norm(u,'fro'))
#print("u")
#print(norm(Bz,'fro'))
#print("Bz")
#print(norm(C,'fro'))
#print("C")
#print(norm(grad,'fro'))
#print("grad")
#print(norm(step,'fro'))
#print("step")
#print(norm(u,'fro'))
#print("newu")
tmp_change = 0
tmp_gap = 0
if p==1 and q==1:
tmp_change = (norm(step, 'fro')/normM)**2
tmp_change = (norm(step, 'fro')/self.norm_truth)**2
tmp_u = Prox1.work(u)
tmp1 = Prox2.work(tmp_u)
tmp_gap = (norm(tmp1-tmp_u, 'fro')/normM)**2
if x_dim[0]==1:
tmp_gap = (norm(tmp1-tmp_u, 'fro')/self.norm_truth)**2
if self.truth_dim[0]==1:
z = u[0,:]
elif x_dim[1]==1:
elif self.truth_dim[1]==1:
z = u[:,0]
else:
z = u[:,:,0]
Relerrs[iter] = norm(x - exp(-1j*angle(trace(x.transpose()*z)))*z,'fro')/normM
Relerrs[iter] = norm(self.truth - cmath.exp(-1j*angle(trace(self.truth*z)))*z,'fro')/self.norm_truth
elif q==1:
tmp_u = Prox1.work(u)
tmp1 = Prox2.work(tmp_u)
tmp_change = L*(norm(step[:,:,1], 'fro') /normM)**2
for j in range(L):
tmp_gap = tmp_gap + (norm(tmp1[:,:,j] - tmp_u[:,:,j], 'fro') / normM)**2
Relerrs[iter] = norm(x - exp(-1j*angle(trace(x*tmp_u[:,:,0]))) *tmp_u[:,:,0], 'fro')/normM
tmp_change = self.product_space_dimension*(norm(step[:,:,0], 'fro') /self.norm_truth)**2
for j in range(self.product_space_dimension):
tmp_gap = tmp_gap + (norm(tmp1[:,:,j] - tmp_u[:,:,j], 'fro') / self.norm_truth)**2
Relerrs[iter] = norm(self.truth - cmath.exp(-1j*angle(trace(self.truth*tmp_u[:,:,0]))) *tmp_u[:,:,0], 'fro') /self.norm_truth
else:
Times[iter] = toc
tmp_u = Prox1(u)
tmp1 = Prox2(tmp_u)
for j in range(L):
for k in range(Nz):
tmp_change = tmp_change + (norm(step[:,:,k,j], 'fro') /normM)**2
tmp_gap = tmp_gap + (norm(step[:,:,k,j], 'fro')/normM)**2
Relerrs[iter] = norm(x-exp(-1j*angle(trace(x.transpose()*tmp_u[:,:,:,0])))*tmp_u[:,:,:,0],'fro')/normM
tmp_u = Prox1.work(u)
tmp1 = Prox2.work(tmp_u)
for j in range(self.product_space_dimension):
for k in range(self.Nz):
tmp_change = tmp_change + (norm(step[:,:,k,j], 'fro') /self.norm_truth)**2
tmp_gap = tmp_gap + (norm(step[:,:,k,j], 'fro')/self.norm_truth)**2
Relerrs[iter] = norm(self.truth-cmath.exp(-1j*angle(trace(self.truth*tmp_u[:,:,:,0])))*tmp_u[:,:,:,0],'fro')/self.norm_truth
gap[iter] = sqrt(tmp_gap)
change[iter] = sqrt(tmp_change)
print(14)
tmp = Prox1(u)
tmp2 = Prox2(u)
tmp = Prox1.work(u)
tmp2 = Prox2.work(u)
if self.Nx==1:
u1 = tmp[:,0]
u2 = tmp2[:,0]
......@@ -146,7 +160,9 @@ class Wirtinger(Algorithm):
u1 = tmp
u2 = tmp2
change = change[1:iter]
change = change[0:iter]
return {'u1': u1, 'u2': u2, 'iter': iter, 'change': change, 'gap': gap}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment