Commit 59bfba7a authored by jansen31's avatar jansen31
Browse files

orthonormality constraint

parent a2312cd7
from skimage.io import imread
from scipy.ndimage import binary_dilation
import numpy as np
import matplotlib.pyplot as plt
from proxtoolbox.experiments.orbitaltomography.planar_molecule import PlanarMolecule
from proxtoolbox.utils.visualization.complex_field_visualization import complex_to_rgb
from proxtoolbox.utils.orbitaltomog import bin_array, shifted_fft, shifted_ifft, fourier_interpolate, roll_to_pos
class DegenerateOrbital(PlanarMolecule):
@staticmethod
def getDefaultParameters():
defaultParams = {
'experiment_name': '2D ARPES',
'data_filename': None,
'from_intensity_data': False,
'object': 'real',
'constraint': 'sparse real',
'sparsity_parameter': 40,
'use_sparsity_with_support': False,
'threshold_for_support': 0.1,
'support_filename': None,
'Nx': None,
'Ny': None,
'Nz': 1,
'MAXIT': 500,
'TOL': 1e-10,
'lambda_0': 0.85,
'lambda_max': 0.50,
'lambda_switch': 50,
'data_ball': .999826,
'TOL2': 1e-15,
'diagnostic': True,
'algorithm': 'DRl',
'iterate_monitor_name': 'FeasibilityIterateMonitor', # 'IterateMonitor', #
'rotate': False,
'verbose': 1,
'graphics': 1,
'interpolate_and_zoom': True,
'debug': True,
}
return defaultParams
def __init__(self, **kwargs):
super(DegenerateOrbital, self).__init__(**kwargs)
def loadData(self):
"""
Load data and set in the correct format for reconstruction
Parameters are taken from experiment class (self) properties, which must include::
- data_filename: str, path to the data file
- from_intensity_data: bool, if the data file gives intensities rather than field amplitude
- support_filename: str, optional path to file with object support
- use_sparsity_with_support: bool, if true, use a support before the sparsity constraint.
The support is calculated by thresholding the object autocorrelation, and dilate the result
- threshold_for_support: float, in range [0,1], fraction of the maximum at which to threshold when
determining support or support for sparsity
"""
# load data
if self.data_filename is None:
self.data_filename = input('Please enter the path to the datafile: ')
try:
self.data = imread(self.data_filename)
except FileNotFoundError:
print("Tried path %s, found nothing. " % self.data_filename)
self.data_filename = input('Please enter a valid path to the datafile: ')
self.data = imread(self.data_filename)
# Keep the same resolution?
ny, nx = self.data.shape
try:
if ny % self.Ny == 0 and nx % self.Nx == 0:
# binning must be done for the intensity-data, as that preserves the normalization
if self.from_intensity_data:
self.data = bin_array(self.data, (self.Ny, self.Nx))
else:
self.data = np.sqrt(bin_array(self.data ** 2, (self.Ny, self.Nx)))
else:
# TODO: use flexibility allowed by the binning function to prevent this error
raise ValueError('Incompatible values for Ny and Nx given in configuration dict')
except TypeError:
pass
self.Ny, self.Nx = self.data.shape
# Calculate electric field and norm of the data
if self.from_intensity_data:
# avoid sqrt of negative numbers (due to background subtraction)
self.data = np.where(self.data > 0, np.sqrt(abs(self.data)), np.zeros_like(self.data))
self.norm_data = np.sqrt(np.sum(self.data ** 2))
# Object support determination
if self.support is not None:
self.support = imread(self.support_filename)
else:
self.support = support_from_autocorrelation(self.data,
threshold=self.threshold_for_support,
absolute_autocorrelation=True,
binary_dilate_support=1)
if self.use_sparsity_with_support:
self.sparsity_support = support_from_autocorrelation(self.data,
threshold=self.threshold_for_support,
binary_dilate_support=1)
self.createRandomGuess()
# some variables wich are necessary for the algorithm:
self.data_sq = self.data ** 2
self.data_zeros = np.where(self.data == 0)
def createRandomGuess(self):
"""
Taking the measured data, add a random phase and calculate the resulting iterate guess
"""
ph_init = 2 * np.pi * np.random.random_sample(self.data.shape)
self.u0 = self.data * np.exp(1j * ph_init)
self.u0 = np.fft.fftn(self.u0)
def plotInputData(self):
"""Quick plotting routine to show the data, initial guess and the sparsity support"""
fig, ax = plt.subplots(1, 3, figsize=(12, 3.5))
im = ax[0].imshow(self.data)
plt.colorbar(im, ax=ax[0])
plt.title("Photoelectron spectrum")
ax[1].imshow(complex_to_rgb(self.u0))
plt.title("Initial guess")
if self.sparsity_support is not None:
im = ax[2].imshow(self.sparsity_support, cmap='gray')
# plt.colorbar(im, ax=ax[2])
plt.title("Sparsity support")
plt.show()
def show(self):
"""
Create basic result plots of the phase retrieval procedure
"""
self.output['plots'] = []
self.output['u1'] = self.algorithm.prox1.eval(self.algorithm.u)
self.output['u2'] = self.algorithm.prox2.eval(self.algorithm.u)
if self.interpolate_and_zoom:
for key in ["u1", "u2"]:
center = tuple([s // 2 for s in self.output[key].shape])
self.output[key] = roll_to_pos(self.output[key], pos=center, move_maximum=True)
self.output[key] = roll_to_pos(self.output[key], pos=center)
self.output[key] = fourier_interpolate(self.output[key], 2)
zmy, zmx = tuple([s // 4 for s in self.output[key].shape])
self.output[key] = self.output[key][zmy:-zmy, zmx:-zmx]
u1 = self.output['u1']
u2 = self.output['u2']
change = self.output['stats']['changes']
f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9, 6))
im = ax1.imshow(abs(u1), cmap='gray')
f.colorbar(im, ax=ax1)
ax1.set_title('best approximation amplitude: \nphysical constraint satisfied')
ax2.imshow(complex_to_rgb(u1))
ax2.set_title('best approximation phase: \nphysical constraint satisfied')
im = ax3.imshow(abs(u2), cmap='gray')
f.colorbar(im, ax=ax3)
ax3.set_title('best approximation amplitude: \nFourier constraint satisfied')
ax4.imshow(complex_to_rgb(u2))
ax4.set_title('best approximation amplitude: \nFourier constraint satisfied')
f.tight_layout()
g, ((bx1, bx2), (bx3, bx4)) = plt.subplots(2, 2, figsize=(9, 6))
im = bx1.imshow(abs(u1), cmap='gray')
f.colorbar(im, ax=bx1)
bx1.set_title('best approximation amplitude: \nphysical constraint satisfied')
im = bx2.imshow(u1.real, cmap='gray')
f.colorbar(im, ax=bx2)
bx2.set_title('best approximation phase: \nphysical constraint satisfied')
bx3.semilogy(change)
bx3.set_xlabel('iteration')
bx3.set_ylabel('Change')
if 'gaps' in self.output['stats']:
gaps = self.output['stats']['gaps']
bx4.semilogy(gaps)
bx4.set_xlabel('iteration')
bx4.set_ylabel('Gap')
f.tight_layout()
h, ax = plt.subplots(1, 3, figsize=(9,3))
ax[0].imshow(self.data)
ax[0].set_title("Measured data")
prop = self.propagator(self)
u_hat = prop.eval(self.algorithm.prox1.eval(self.algorithm.u))
ax[1].imshow(abs(u_hat))
ax[1].set_title("Predicted measurement intensity")
ax[2].imshow(complex_to_rgb(u_hat))
ax[2].set_title("Predicted phase (by color)")
h.tight_layout()
plt.show()
self.output['plots'].append(f)
self.output['plots'].append(g)
self.output['plots'].append(h)
# def saveOutput(self, **kwargs):
# super(PlanarMolecule, self).saveOutput(**kwargs)
def support_from_autocorrelation(input_array: np.ndarray,
threshold: float = 0.1,
relative_threshold: bool = True,
input_in_fourier_domain: bool = True,
absolute_autocorrelation: bool = True,
binary_dilate_support: int = 0) -> np.ndarray:
"""
Determine an initial support from the autocorrelation of an object.
Args:
input_array: either the measured diffraction (arpes pattern) or a guess of the object
threshold: support is everywhere where the autocorrelation is higher than the threshold
relative_threshold: If true, threshold at threshold*np.amax(autocorrelation)
input_in_fourier_domain: False if a guess of the object is given in input_array
absolute_autocorrelation: Take the absolute value of the autocorrelation? (Generally a
good idea for objects which are not non-negative)
binary_dilate_support: number of dilation operations to apply to the support.
Returns:
support array (same dimensions as input, dtype=np.int)
"""
if not input_in_fourier_domain:
kspace = shifted_fft(input_array)
else:
kspace = input_array
autocorrelation = shifted_ifft(abs(kspace)) # Taking absolute value yields autocorrelation by conv. theorem)
if absolute_autocorrelation:
autocorrelation = abs(autocorrelation)
maxval = np.amax(autocorrelation)
if relative_threshold:
threshold_val = threshold * maxval
else:
threshold_val = threshold
support = (autocorrelation > threshold_val).astype(np.uint)
if binary_dilate_support > 0:
support = binary_dilation(support, iterations=binary_dilate_support).astype(np.uint)
return support
from numpy import sum, arccos, array, cbrt, sqrt
from numpy import sum, array, cbrt, sqrt, zeros_like
from proxtoolbox.proxoperators.proxoperator import ProxOperator
__all__ = ['P_orthonorm', "P_norm"]
class P_orthonorm(ProxOperator):
def __init__(self, experiment):
......@@ -15,23 +17,27 @@ class P_orthonorm(ProxOperator):
def eval(self, u, **kwargs):
# normalize u[0] and u[1]
u_norm = self.p_norm.eval(u, )
u_norm = self.p_norm.eval(u)
norms = [sqrt(sum(abs(u) ** 2)) for u in u_norm]
# determine angle _a_ between u[0] and u[1]
a = arccos(u_norm[0] * u_norm[1])
# determine root of y^3 - 3/2 a y^2 + 1/2 a = 0
y_part = cbrt(2 * sqrt(a ** 4 + a ** 2) - 2 * a - a ** 3)
y = 0.5 * (a ** 2 / y_part + y_part - a)
# apply projection
u_new = u_norm.copy()
u_new[0] = u_norm[0] - (y / (y ** 2 - 1)) * (u_norm[1] - y * u_norm[0])
u_new[1] = (1 / (y ** 2 - 1)) * (u_norm[1] - y * u_norm[0])
return u_new
a = sum(u_norm[0] * u_norm[1]) / (norms[0] * norms[1])
if a != 0: # for non-orthogonal iterates, apply change
# determine root of y^3 - 3/2 a y^2 + 1/2 a = 0
y_part = cbrt(2 * sqrt(a ** 4 + a ** 2) - 2 * a - a ** 3)
y = 0.5 * (a ** 2 / y_part + y_part - a)
# apply projection
u_new = zeros_like(u_norm)
u_new[0] = u_norm[0] - (y / (y ** 2 - 1)) * (u_norm[1] - y * u_norm[0])
u_new[1] = (1 / (y ** 2 - 1)) * (u_norm[1] - y * u_norm[0])
return u_new
else:
return u_norm
class P_norm(ProxOperator):
def __init__(self, experiment):
"""
Normalize iterates to
Normalize iterates, such the incoherent sum [sum(abs(u)**2, axis=0)] adds up to experiment.norm
"""
super(P_norm, self).__init__(experiment)
self.norm = experiment.norm_data
......@@ -39,4 +45,27 @@ class P_norm(ProxOperator):
def eval(self, u, **kwargs):
# Normalize components of u
return array([self.norm * u_n / sum(abs(u_n)) / sqrt(len(u)) for u_n in u])
return array([self.norm * u_n / sqrt(sum(abs(u_n) ** 2)) / sqrt(len(u)) for u_n in u])
if __name__ == "__main__":
import numpy as np
class DummyExperiment:
def __init__(self, data=1, norm=sqrt(2)):
self.data = data
self.norm_data = norm
exp = DummyExperiment()
portho = P_orthonorm(exp)
pnorm = P_norm(exp)
th = 0.55
inp = np.array([[1, 0], [np.cos(th * np.pi), np.sin(th * np.pi)]])
out = portho.eval(inp)
print("Input:", inp)
print("Output:", out)
print("Inner product of the output: ", np.sum(out[0]*out[1]) )
\ No newline at end of file
......@@ -27,3 +27,4 @@ from .P_S import *
from .P_CDP_cyclic import *
from .sourceLocProx import *
from .Pphase_phasepack import *
from .P_orthonorm import *
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment