# CDP_processor.py
1
2
from numpy.random import randn, random_sample
from numpy.linalg import norm
3
4
from numpy import sqrt, conj, tile, mean
from numpy.fft import fft, ifft
5
import numpy as np
6
import time
7

8
9
10
11
12
13
14
15
16
17
18
19
20
def CDP_processor(config):
    """Generate coded-diffraction-pattern (CDP) data and a spectral initial guess.

    Implementation of the data generation and initialization stage of the
    Wirtinger Flow (WF) algorithm presented in the paper
    "Phase Retrieval via Wirtinger Flow: Theory and Algorithms"
    by E. J. Candes, X. Li, and M. Soltanolkotabi,
    integrated into the ProxToolbox by Russell Luke, September 2016.

    The input data are coded diffraction patterns about a random complex
    valued image.

    Parameters
    ----------
    config : dict
        Updated in place.  Keys read:

        * ``'Ny'``, ``'Nx'`` -- image dimensions (``Nx == 1`` or ``Ny == 1``
          selects the 1D column / row signal case).
        * ``'product_space_dimension'`` -- number of masks ``L``.
        * ``'warmup_iter'`` -- number of power iterations.

        Keys written: ``'truth'``, ``'norm_truth'``, ``'truth_dim'``,
        ``'Masks'`` (the *conjugated* masks), ``'rt_data'`` (measured
        magnitudes), ``'data'`` (measured intensities), ``'norm_data'``
        (mean intensity, a scalar), ``'norm_rt_data'`` (its square root) and
        ``'u_0'`` (the scaled initial guess replicated once per mask).

    Returns
    -------
    None
        All results are stored in ``config``.
    """
    ## Make image
    n1 = config['Ny']
    n2 = config['Nx']  # for 1D signals, this will be 1
    x = randn(n1, n2) + 1j * randn(n1, n2)
    config['truth'] = x
    config['norm_truth'] = norm(x, 'fro')
    config['truth_dim'] = x.shape

    ## Make masks and linear sampling operators
    L = config['product_space_dimension']  # Number of masks

    # Layout of the stack of masks: the signal axes are Fourier transformed,
    # the remaining axis indexes the L masks.
    if n2 == 1:
        mask_shape = (n1, L)        # column signal: n x L
    elif n1 == 1:
        mask_shape = (L, n2)        # row signal: L x n
    else:
        mask_shape = (n1, n2, L)    # image: n1 x n2 x L

    # Sample phases: each symbol in alphabet {1, -1, i, -i} has equal prob.
    # Drawing directly from the complex alphabet keeps the array complex, so
    # no imaginary part is lost.
    Masks = np.random.choice(np.array([1j, -1j, 1, -1]), mask_shape)

    # Sample magnitudes and make masks
    temp = random_sample(Masks.shape)  # works like rand but accepts a tuple
    Masks = Masks * ((temp <= 0.2) * sqrt(3) + (temp > 0.2) / sqrt(2))

    # Saving the conjugate of the mask saves on computing the conjugate
    # every time the mapping A (below) is applied.
    conj_masks = conj(Masks)
    config['Masks'] = conj_masks

    # Make linear operators; A is the forward map and At its scaled adjoint
    # (At(Y) * Y.size is the adjoint).  The FFT axes are given explicitly:
    # NumPy transforms the trailing axes by default, which would mix in the
    # mask axis.
    if n2 == 1:
        # Input is n x 1 signal, output is n x L array
        A = lambda I: fft(conj_masks * I, axis=0)
        # Input is n x L array, output is n x 1 signal
        At = lambda Y: mean(Masks * ifft(Y, axis=0), axis=1, keepdims=True)
        replicate = lambda I: tile(I, [1, L])
    elif n1 == 1:
        # Input is 1 x n signal, output is L x n array
        A = lambda I: fft(conj_masks * I, axis=1)
        # Input is L x n array, output is 1 x n signal
        At = lambda Y: mean(Masks * ifft(Y, axis=1), axis=0, keepdims=True)
        replicate = lambda I: tile(I, [L, 1])
    else:
        # Input is n1 x n2 image, output is n1 x n2 x L array
        A = lambda I: np.fft.fft2(conj_masks * I[:, :, np.newaxis],
                                  axes=(0, 1))
        # Input is n1 x n2 x L array, output is n1 x n2 image
        At = lambda Y: mean(Masks * np.fft.ifft2(Y, axes=(0, 1)), axis=2)
        replicate = lambda I: np.repeat(I[:, :, np.newaxis], L, axis=2)

    # Data
    Y = abs(A(x))
    config['rt_data'] = Y
    Y = Y**2
    config['data'] = Y
    # Mean over *all* entries (the builtin sum would only reduce axis 0).
    config['norm_data'] = Y.sum() / Y.size
    normest = sqrt(config['norm_data'])  # Estimate norm to scale eigenvector
    config['norm_rt_data'] = normest

    ## Initialization
    npower_iter = config['warmup_iter']  # Number of power iterations
    z0 = randn(n1, n2)
    z0 = z0 / norm(z0, 'fro')            # Initial guess
    tic = time.perf_counter()
    for _ in range(npower_iter):         # Power iterations
        z0 = At(Y * A(z0))
        z0 = z0 / norm(z0, 'fro')
    toc = time.perf_counter()

    z = normest * z0                     # Apply scaling

    # Relative error up to a global phase.  np.vdot(x, z) = trace(x^H z) is
    # the inner product whose phase gives the optimal alignment of z with x.
    phase = np.exp(-1j * np.angle(np.vdot(x, z)))
    Relerrs = norm(x - phase * z, 'fro') / norm(x, 'fro')
    config['u_0'] = replicate(z)

    print('Run time of initialization: %.2f  seconds' % (toc - tic))
    print('Relative error after initialization: %.2f' % Relerrs)
    print('\n')