Commit 824773c6 by luckypeter.okonun

parent 0d370598
 from numpy.random import randn, random_sample from numpy.linalg import norm from numpy import sqrt, conj, tile, mean, exp, angle, trace from numpy.fft import fft, ifft import numpy as np import time import random def CDP_Candes_1D(config): # Implementation of the Wirtinger Flow (WF) algorithm presented in the paper # "Phase Retrieval via Wirtinger Flow: Theory and Algorithms" # by E. J. Candes, X. Li, and M. Soltanolkotabi # The input data are coded diffraction patterns about a random complex # valued 1D signal. ## make signal n = 128 x = randn(n,1) + 1j*randn(n,1) ## Make masks and Linear sampling operators L = 6 # Number of masks # Sample phases : each symbol in alphabet [1, -1, i,-i] has equal prob. Masks = np.random.choice(np.array[1j, -1j, 1, -1], [n,L]) #Sample magnitudes and make masks Temp = rand(size(Masks)) Masks = Masks *( (temp <= 0.2)*sqrt(3) + (temp > 0.2)/sqrt(2) ) # Make Linear operators, A is forward map and At its scaled adjoints (At(Y)*numel(Y) is the adjoint) A = lambda I: fft(conj(Masks)) * repmat(I,[1, L]) # Input is n x 1 signal, output is n x L array At = lambda Y: mean(Masks * ifft(Y), 2) # I nput is n x L array, output is n 1 signal # Data Y = abs(A(x))**2 ## Initiazation npower_iter = 50 # Number of power iterattons z0 = randn(n,1) z0 = z0/(z0,'fro') # Initial guess for tt in range(npower_iter): z0 = At(Y*A(z0)) z0 = z0/norm(z0, 'fro') normest = sqrt(sum(Y.flatten)/numel(Y)) # Estimate norm to scale eigenvector Z = normest * z0 # Apply scaling Relerrs = norm(x - exp(-1j *angle(trace(x*z))) * z, 'fro')/norm(x,'fro') #Initial rel error ## Loop T = 2500 # Max number of iterations Tau0 = 330 # Time constant for step size for t in range(T): Bz = A(z) C = (abs(Bz)**2-Y) * Bz grad = At(C) # Wirtinger gradient z = z - mu(t)/normest**2 * grad # Gradient update Relerrs = [Relerrs, norm(x - exp(-1j*angle(trace(x*z))) * z, 'fro')/norm(x,'fro')] ## Check results print("Relative error after initialization:" + Relerrs[1]) print(" Relative error after iterations:" + Relerrs[T+1]) X = range(T) Y = Relerrs fig = plt.figure() ax.plot(X, Y) plt.title('Relative error vs. iteration count') plt.xlabel('Iteration') plt.ylabel('Relative error (log10)') ax.set_yscale('log') plt.show()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!