# CDP_processor.py
from numpy.random import randn, random_sample
from numpy.linalg import norm
from numpy import sqrt, conj, fft, tile
import numpy as np

def _make_masks(n1, n2, L):
    """Draw the L random CDP masks (un-conjugated).

    Each entry is phase * magnitude: the phase is drawn uniformly from
    {1j, -1j, 1, -1} and the magnitude is sqrt(3) with probability 0.2 and
    1/sqrt(2) otherwise.

    Returns an (n1, L) array when n2 == 1, an (L, n2) array when n1 == 1,
    and an (n1, n2, L) complex array otherwise.
    """
    alphabet = np.array([1j, -1j, 1, -1])
    if n2 == 1:
        masks = np.random.choice(alphabet, (n1, L))
    elif n1 == 1:
        masks = np.random.choice(alphabet, (L, n2))
    else:
        # Complex storage: a real array would silently drop the +-i phases.
        masks = np.empty((n1, n2, L), dtype=complex)
        for ll in range(L):
            masks[:, :, ll] = np.random.choice(alphabet, (n1, n2))
    # Sample magnitudes and make masks
    temp = random_sample(masks.shape)  # works like rand but accepts tuple as argument
    return masks * ((temp <= 0.2) * sqrt(3) + (temp > 0.2) / sqrt(2))


def _make_operators(masks, n1, n2):
    """Build the CDP forward map A and its scaled adjoint At.

    A(I) multiplies the signal by each conjugated mask and takes the
    (unnormalized) DFT; At(Y) takes the inverse DFT, multiplies by the masks
    and averages over the L masks, so At(Y) * Y.size is the exact adjoint of A.
    """
    cmasks = conj(masks)
    if n2 == 1:
        # Input is n1 x 1 signal, output is n1 x L array (transform down columns).
        def A(I):
            return np.fft.fft(cmasks * I, axis=0)

        def At(Y):
            return np.mean(masks * np.fft.ifft(Y, axis=0), axis=1, keepdims=True)
    elif n1 == 1:
        # Input is 1 x n2 signal, output is L x n2 array (transform along rows).
        def A(I):
            return np.fft.fft(cmasks * I, axis=1)

        def At(Y):
            return np.mean(masks * np.fft.ifft(Y, axis=1), axis=0, keepdims=True)
    else:
        # Input is n1 x n2 image, output is n1 x n2 x L array; the 2D DFT must
        # act on the image axes (0, 1), not NumPy's default last two axes.
        def A(I):
            return np.fft.fft2(cmasks * I[:, :, np.newaxis], axes=(0, 1))

        def At(Y):
            return np.mean(masks * np.fft.ifft2(Y, axes=(0, 1)), axis=2)
    return A, At


def _replicate(z, n1, n2, L):
    """Stack L copies of z along the mask dimension (same layout as the data)."""
    if n2 == 1:
        return tile(z, (1, L))
    if n1 == 1:
        return tile(z, (L, 1))
    return np.repeat(z[:, :, np.newaxis], L, axis=2)


def CDP_processor(config):
    """Generate coded-diffraction-pattern (CDP) data and a spectral initial guess.

    Implementation of the Wirtinger Flow (WF) algorithm presented in the paper
    "Phase Retrieval via Wirtinger Flow: Theory and Algorithms"
    by E. J. Candes, X. Li, and M. Soltanolkotabi,
    integrated into the ProxToolbox by Russell Luke, September 2016.

    The input data are coded diffraction patterns of a random complex-valued
    image.

    Reads from ``config``: 'Ny', 'Nx' (image size; one of them is 1 for 1D
    signals), 'product_space_dimension' (number of masks L) and
    'warmup_iter' (number of power iterations).

    Writes into ``config`` (in place; nothing is returned): 'truth',
    'norm_truth', 'truth_dim', 'Masks' (conjugated masks), 'rt_data' (|A(x)|),
    'data' (|A(x)|**2), 'norm_data' (mean of 'data'), 'norm_rt_data'
    (sqrt of 'norm_data') and 'u_0' (the initial guess replicated L times).
    """
    import time

    ## Make image
    n1 = config['Ny']
    n2 = config['Nx']  # for 1D signals, this will be 1
    x = randn(n1, n2) + 1j * randn(n1, n2)
    config['truth'] = x
    config['norm_truth'] = norm(x, 'fro')
    config['truth_dim'] = x.shape

    ## Make masks and linear sampling operators
    L = config['product_space_dimension']  # Number of masks
    Masks = _make_masks(n1, n2, L)
    # Saving the conjugate of the mask saves on computing the conjugate
    # every time the mapping A is applied.
    config['Masks'] = conj(Masks)
    A, At = _make_operators(Masks, n1, n2)

    ## Data
    rt_data = np.abs(A(x))
    config['rt_data'] = rt_data
    Y = rt_data ** 2
    config['data'] = Y
    # np.sum, not the builtin sum: the builtin only reduces the first axis.
    config['norm_data'] = np.sum(Y) / Y.size
    normest = sqrt(config['norm_data'])  # Estimate norm to scale eigenvector
    config['norm_rt_data'] = normest

    ## Initialization by power iterations
    npower_iter = config['warmup_iter']  # Number of power iterations
    z0 = randn(n1, n2)
    z0 = z0 / norm(z0, 'fro')  # Initial guess
    tic = time.perf_counter()
    for _ in range(npower_iter):
        z0 = At(Y * A(z0))
        z0 = z0 / norm(z0, 'fro')
    toc = time.perf_counter()

    z = normest * z0  # Apply scaling
    # Relative error up to the global phase ambiguity: np.vdot conjugates x
    # and sums over all entries, i.e. MATLAB's trace(x'*z).
    phase = np.exp(-1j * np.angle(np.vdot(x, z)))
    Relerrs = norm(x - phase * z, 'fro') / norm(x, 'fro')
    config['u_0'] = _replicate(z, n1, n2, L)

    print('Run time of initialization: %.2f  seconds' % (toc - tic))
    print('Relative error after initialization: %.2f' % Relerrs)
    print('\n')