Commit 824773c6 authored by luckypeter.okonun's avatar luckypeter.okonun
Browse files

Uploaded CDP_Candes_1D

parent 0d370598
from numpy.random import randn, random_sample
from numpy.linalg import norm
from numpy import sqrt, conj, tile, mean, exp, angle, trace
from numpy.fft import fft, ifft
import numpy as np
import time
import random
def CDP_Candes_1D(config):
# Implementation of the Wirtinger Flow (WF) algorithm presented in the paper
# "Phase Retrieval via Wirtinger Flow: Theory and Algorithms"
# by E. J. Candes, X. Li, and M. Soltanolkotabi
# The input data are coded diffraction patterns about a random complex
# valued 1D signal.
## make signal
n = 128
x = randn(n,1) + 1j*randn(n,1)
## Make masks and Linear sampling operators
L = 6 # Number of masks
# Sample phases : each symbol in alphabet [1, -1, i,-i] has equal prob.
Masks = np.random.choice(np.array[1j, -1j, 1, -1], [n,L])
#Sample magnitudes and make masks
Temp = rand(size(Masks))
Masks = Masks *( (temp <= 0.2)*sqrt(3) + (temp > 0.2)/sqrt(2) )
# Make Linear operators, A is forward map and At its scaled adjoints (At(Y)*numel(Y) is the adjoint)
A = lambda I: fft(conj(Masks)) * repmat(I,[1, L]) # Input is n x 1 signal, output is n x L array
At = lambda Y: mean(Masks * ifft(Y), 2) # I nput is n x L array, output is n 1 signal
# Data
Y = abs(A(x))**2
## Initiazation
npower_iter = 50 # Number of power iterattons
z0 = randn(n,1)
z0 = z0/(z0,'fro') # Initial guess
for tt in range(npower_iter):
z0 = At(Y*A(z0))
z0 = z0/norm(z0, 'fro')
normest = sqrt(sum(Y.flatten)/numel(Y)) # Estimate norm to scale eigenvector
Z = normest * z0 # Apply scaling
Relerrs = norm(x - exp(-1j *angle(trace(x*z))) * z, 'fro')/norm(x,'fro') #Initial rel error
## Loop
T = 2500 # Max number of iterations
Tau0 = 330 # Time constant for step size
for t in range(T):
Bz = A(z)
C = (abs(Bz)**2-Y) * Bz
grad = At(C) # Wirtinger gradient
z = z - mu(t)/normest**2 * grad # Gradient update
Relerrs = [Relerrs, norm(x - exp(-1j*angle(trace(x*z))) * z, 'fro')/norm(x,'fro')]
## Check results
print("Relative error after initialization:" + Relerrs[1])
print(" Relative error after iterations:" + Relerrs[T+1])
X = range(T)
Y = Relerrs
fig = plt.figure()
ax.plot(X, Y)
plt.title('Relative error vs. iteration count')
plt.xlabel('Iteration')
plt.ylabel('Relative error (log10)')
ax.set_yscale('log')
plt.show()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment