Commit aaa0521d authored by Matthijs's avatar Matthijs
Browse files

clean-up, documentation

parent bb9db09b
...@@ -17,10 +17,8 @@ class P_Sparsity(ProxOperator): ...@@ -17,10 +17,8 @@ class P_Sparsity(ProxOperator):
---------- ----------
config : dict - Dictionary containing the problem configuration. It must contain the following mappings: config : dict - Dictionary containing the problem configuration. It must contain the following mappings:
'sparsity_parameter' : int 'sparsity_parameter' : int
'Ny' : int
""" """
self.sparsity_parameter = config['sparsity_parameter'] self.sparsity_parameter = config['sparsity_parameter']
self.ny = config['Ny']
if 'sparsity_support' in config: if 'sparsity_support' in config:
self.support = config['sparsity_support'].real.astype(np.uint) self.support = config['sparsity_support'].real.astype(np.uint)
...@@ -53,13 +51,13 @@ class P_Sparsity(ProxOperator): ...@@ -53,13 +51,13 @@ class P_Sparsity(ProxOperator):
------- -------
p_Sparsity : array_like, the projection p_Sparsity : array_like, the projection
""" """
u *= self.support # apply support u *= self.support # apply support (simple 1 if no support)
p_Sparsity = 0 * u p_sparse = 0 * u
sorting = np.argsort(abs(u), axis=None) # gives indices of sorted array in ascending order sorting = np.argsort(abs(u), axis=None) # gives indices of sorted array in ascending order
indices = np.asarray([unravel_index(sorting[i], u.shape) for i in range(-1 * self.sparsity_parameter, 0)]) indices = np.asarray([unravel_index(sorting[i], u.shape) for i in range(-1 * self.sparsity_parameter, 0)])
p_Sparsity[indices[:, 0], indices[:, 1]] = u[indices[:, 0], indices[:, 1]] p_sparse[indices[:, 0], indices[:, 1]] = u[indices[:, 0], indices[:, 1]]
return p_Sparsity return p_sparse
class P_Sparsity_real(P_Sparsity): class P_Sparsity_real(P_Sparsity):
......
...@@ -345,7 +345,7 @@ class P_amp(ProxOperator): ...@@ -345,7 +345,7 @@ class P_amp(ProxOperator):
self.amplitude = config['amplitude'] self.amplitude = config['amplitude']
def work(self, u): def work(self, u):
return magproj(u, self.amplitude) # argument order changed compared to matlab implementation!!! return magproj(u, self.amplitude) # argument order changed compared to matlab implementation!!!
# P_SP.m # P_SP.m
...@@ -386,7 +386,7 @@ class P_SP(ProxOperator): ...@@ -386,7 +386,7 @@ class P_SP(ProxOperator):
# DESCRIPTION: Projection subroutine for projecting onto Fourier # DESCRIPTION: Projection subroutine for projecting onto Fourier
# magnitude constraints # magnitude constraints
# #
# INPUT: Func_params = a data structure with .data = nonegative real FOURIER DOMAIN CONSTRAINT # INPUT: Func_params = a data structure with .data = non-negative real FOURIER DOMAIN CONSTRAINT
# .data_ball is the regularization parameter described in # .data_ball is the regularization parameter described in
# D. R. Luke, Nonlinear Analysis 75 (2012) 1531–1546. # D. R. Luke, Nonlinear Analysis 75 (2012) 1531–1546.
# .TOL2 is an extra tolerance. # .TOL2 is an extra tolerance.
...@@ -448,6 +448,11 @@ class Approx_PM_Gaussian(ProxOperator): ...@@ -448,6 +448,11 @@ class Approx_PM_Gaussian(ProxOperator):
class Approx_PM_Poisson(ProxOperator): class Approx_PM_Poisson(ProxOperator):
def __init__(self, config): def __init__(self, config):
"""
Test whether guess is close to the measured data, if not, applies magnitude projection to the data
:param config:
"""
self.TOL2 = config['TOL2'] self.TOL2 = config['TOL2']
self.M = config['data'] self.M = config['data']
self.b = config['data_sq'] self.b = config['data_sq']
...@@ -461,15 +466,15 @@ class Approx_PM_Poisson(ProxOperator): ...@@ -461,15 +466,15 @@ class Approx_PM_Poisson(ProxOperator):
Ib = self.Ib Ib = self.Ib
epsilon = self.epsilon epsilon = self.epsilon
U = fft2(u) U = fft2(u) # guess in Fourier space
U_sq = U * conj(U) U_sq = U * conj(U) # modulus squared guess in Fourier space
tmp = U_sq / b tmp = U_sq / b # relative error of the guess (1=good)
tmp[Ib] = 1 tmp[Ib] = 1 # where data_zeros, set rel.error to 1
U_sq[Ib] = 0 U_sq[Ib] = 0 # where data_zeros, set guess to zero
IU = tmp == 0 IU = tmp == 0
tmp[IU] = 1 tmp[IU] = 1 # catch invalid values for the logarithm next line
tmp = log(tmp) tmp = log(tmp)
hU = sum(sum(U_sq * tmp + b - U_sq)) hU = sum(sum(U_sq * tmp + b - U_sq)) # Kullback-Leibler divergence?
if hU >= epsilon + TOL2: if hU >= epsilon + TOL2:
U0 = magproj(U, M) U0 = magproj(U, M)
return ifft2(U0) return ifft2(U0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment