Commit 9c4bc48f authored by jansen31

reconstruct coupled orbitals

parent ff4aca49
@@ -12,13 +12,13 @@ class DegenerateOrbital(PlanarMolecule):
defaultParams = {
'experiment_name': '2D ARPES',
'data_filename': None,
- 'from_intensity_data': False,
+ 'from_intensity_data': True,
'object': 'real',
'degeneracy': 2, # Number of degenerate states to reconstruct
'constraint': 'sparse real',
'sparsity_parameter': 40,
'use_sparsity_with_support': True,
- 'threshold_for_support': 0.1,
+ 'threshold_for_support': 0.05,
'support_filename': None,
'Nx': None,
'Ny': None,
@@ -156,10 +156,10 @@ class DegenerateOrbital(PlanarMolecule):
ax[ii].set_title("Degenerate orbit %d" % ii)
im = ax[-2].imshow(np.sum(abs(u_show) ** 2, axis=0))
ax[-2].set_title("Local density of states")
- plt.colorbar(im, ax=ax[-2])
+ # plt.colorbar(im, ax=ax[-2], shrink=0.7)
im = ax[-1].imshow(fourier_intensity)
ax[-1].set_title("Fourier domain intensity")
- plt.colorbar(im, ax=ax[-1])
+ # plt.colorbar(im, ax=ax[-1], shrink=0.7)
plt.tight_layout()
if show:
plt.show()
import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imread
from scipy.ndimage import binary_dilation, shift, center_of_mass
from proxtoolbox.experiments.orbitaltomography.planar_molecule import PlanarMolecule
from proxtoolbox.utils.orbitaltomog import shifted_fft, fourier_interpolate, bin_array, shifted_ifft
from proxtoolbox.utils.visualization.complex_field_visualization import complex_to_rgb
class OrthogonalOrbitals(PlanarMolecule):
@staticmethod
def getDefaultParameters():
defaultParams = {
'experiment_name': '2D ARPES',
'data_filename': None,
'from_intensity_data': True,
'object': 'real',
'degeneracy': 2, # Number of degenerate states to reconstruct
'constraint': 'sparse real',
'sparsity_parameter': 40,
'use_sparsity_with_support': True,
'threshold_for_support': 0.01,
'support_filename': None,
'Nx': None,
'Ny': None,
'Nz': None,
'MAXIT': 500,
'TOL': 1e-10,
'lambda_0': 0.85,
'lambda_max': 0.50,
'lambda_switch': 50,
'data_ball': .999826,
'TOL2': 1e-15,
'diagnostic': True,
'algorithm': 'CP',
'iterate_monitor_name': 'FeasibilityIterateMonitor',  # alternative: 'IterateMonitor'
'rotate': False,
'verbose': 1,
'graphics': 1,
'interpolate_and_zoom': True,
'debug': True,
'progressbar': None
}
return defaultParams
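# Hedged note: the dictionary above only provides defaults. Following the
# PlanarMolecule pattern, individual parameters can presumably be overridden per
# run through keyword arguments, e.g.
#   OrthogonalOrbitals(data_filename='arpes_stack.tif', MAXIT=1000)
# where the file name is purely illustrative.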
def __init__(self, **kwargs):
super(OrthogonalOrbitals, self).__init__(**kwargs)
def loadData(self):
"""
Load data and set in the correct format for reconstruction
Parameters are taken from experiment class (self) properties, which must include::
- data_filename: str, path to the data file, or list of file names
- from_intensity_data: bool, if the data file gives intensities rather than field amplitude
- support_filename: str, optional path to file with object support
- use_sparsity_with_support: bool, if true, use a support before the sparsity constraint.
The support is calculated by thresholding the object autocorrelation and dilating the result
- threshold_for_support: float, in range [0,1], fraction of the maximum at which to threshold when
determining support or support for sparsity
"""
# load data
if self.data_filename is None:
self.data_filename = input('Please enter the path to the datafile: ')
try:
if isinstance(self.data_filename, str):
self.data = imread(self.data_filename)
else:
self.data = np.array([imread(fname) for fname in self.data_filename])
except FileNotFoundError:
print("Tried path %s, found nothing. " % self.data_filename)
self.data_filename = input('Please enter a valid path to the datafile: ')
self.data = imread(self.data_filename)
# If the data has been corrected for the A.k factor, it should be well centered; we check that here
for i in range(len(self.data)):
cm = center_of_mass(self.data[i] ** 2)
to_shift = tuple(s // 2 - cm[d] for d, s in enumerate(self.data[i].shape))
self.data[i] = shift(self.data[i], to_shift, mode='nearest', order=1)
# Keep the same resolution?
self.Nz, ny, nx = self.data.shape
if self.Ny is None:
self.Ny = ny
if self.Nx is None:
self.Nx = nx
if ny != self.Ny or nx != self.Nx:
# binning must be done for the intensity-data, as that preserves the normalization
if self.from_intensity_data:
self.data = bin_array(self.data, (self.Nz, self.Ny, self.Nx))
else:
self.data = np.sqrt(bin_array(self.data ** 2, (self.Nz, self.Ny, self.Nx)))
self.Nz, self.Ny, self.Nx = self.data.shape
# Calculate electric field and norm of the data
if self.from_intensity_data:
# avoid sqrt of negative numbers (due to background subtraction)
self.data = np.where(self.data > 0, np.sqrt(abs(self.data)), np.zeros_like(self.data))
self.norm_data = np.sqrt(np.sum(self.data ** 2))
# Object support determination
if self.support_filename is not None:
self.support = imread(self.support_filename)
else:
self.support = support_from_stack(self.data,
threshold=self.threshold_for_support,
absolute_autocorrelation=True,
binary_dilate_support=1)
if self.use_sparsity_with_support:
self.sparsity_support = support_from_stack(self.data,
threshold=self.threshold_for_support,
binary_dilate_support=1)
self.createRandomGuess()
# some variables which are necessary for the algorithm:
self.data_sq = self.data ** 2
self.data_zeros = np.where(self.data == 0)
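# Hedged note: the loader assumes the data file(s) decode to a stack with shape
# (n_states, ny, nx); a single 2D image would need an added leading axis before
# the per-state centering loop above.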
def createRandomGuess(self):
"""
Take the measured data, add a random phase and calculate the resulting initial iterate
"""
self.u0 = self.data * np.exp(1j * 2 * np.pi * np.random.random_sample(self.data.shape))
self.u0 = shifted_fft(self.u0, axes=(-1, -2))
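# Hedged note: shifted_fft/shifted_ifft are assumed to wrap numpy's FFT with
# fftshift/ifftshift so that the zero-frequency component stays centred;
# applying shifted_fft to the randomly phased data above yields the
# object-domain starting iterate u0.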
def setupProxOperators(self):
"""
Determine the prox operators to be used based on the given constraint.
This method is called during the initialization process.
Sets the parameters:
- self.proxOperators
- self.propagator and self.inverse_propagator
"""
# Real-space constraint: sparsity-based prox operator
self.proxOperators.append('P_Sparsity_real_incoherent')
# Apply orthonormality constraint
self.proxOperators.append('P_orthonorm')
# Modulus proxoperator (normally the second operator)
self.proxOperators.append('P_M')
self.propagator = 'PropagatorFFT2'
self.inverse_propagator = 'InvPropagatorFFT2'
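# Hedged note: PropagatorFFT2 / InvPropagatorFFT2 are assumed to apply a 2D FFT
# (and its inverse) to each state in the stack, mapping the object-domain
# iterate to the Fourier domain where P_M enforces the measured amplitudes.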
self.nProx = len(self.proxOperators)
def plotInputData(self):
"""Quick plotting routine to show the data, initial guess and the sparsity support"""
fig, ax = plt.subplots(2, self.Nz + 1, figsize=(12, 7))
for ii in range(self.Nz):
im = ax[0][ii].imshow(self.data[ii])
plt.colorbar(im, ax=ax[0][ii])
ax[0][ii].set_title("Photoelectron spectrum %d" % ii)
if self.sparsity_support is not None:
im = ax[0][-1].imshow(self.sparsity_support, cmap='gray')
# plt.colorbar(im, ax=ax[2])
ax[0][-1].set_title("Sparsity support")
for ii in range(self.Nz):
im = ax[1][ii].imshow(complex_to_rgb(self.u0[ii]))
plt.colorbar(im, ax=ax[1][ii])
ax[1][ii].set_title("Degenerate orbit %d" % ii)
ax[1][-1].imshow(np.sum(abs(self.u0) ** 2, axis=0))
ax[1][-1].set_title("Integrated density of states")
plt.show()
def show(self, **kwargs):
"""
Create basic result plots of the phase retrieval procedure
"""
super(PlanarMolecule, self).show()
self.output['u1'] = self.algorithm.prox1.eval(self.algorithm.u)
self.output['u2'] = self.algorithm.prox2.eval(self.algorithm.u)
figsize = kwargs.pop("figsize", (12, 3))
for i, operator in enumerate(self.algorithm.proxOperators):
operator_name = self.proxOperators[i].__name__
f = self.plot_guess(operator.eval(self.algorithm.u),
name="%s satisfied" % operator_name,
show=False,
interpolate_and_zoom=self.interpolate_and_zoom,
figsize=figsize)
self.output['plots'].append(f)
# f1 = self.plot_guess(self.output['u1'], name='Best approximation: physical constraint satisfied', show=False)
# f2 = self.plot_guess(self.output['u2'], name='Best approximation: Fourier constraint satisfied', show=False)
# prop = self.propagator(self)
# u_hat = prop.eval(self.algorithm.prox1.eval(self.algorithm.u))
# h = self.plot_guess(u_hat, show=False, name="Fourier domain measurement projection")
# self.output['plots'].append(f1)
# self.output['plots'].append(f2)
# self.output['plots'].append(h)
plt.show()
# def saveOutput(self, **kwargs):
# super(PlanarMolecule, self).saveOutput(**kwargs)
def plot_guess(self, u, name=None, show=True, interpolate_and_zoom=False, figsize=(12, 6)):
""""Given a list of fields, plot the individual fields and the combined intensity"""
prop = self.propagator(self) # This is not a string but the indicated class itself, to be instantiated
u_hat = prop.eval(u)
fourier_intensity = np.sqrt(np.sum(abs(u_hat) ** 2, axis=0))
if interpolate_and_zoom:
u_show = self.interp_zoom_field(u)
else:
u_show = u
fig, ax = plt.subplots(2, len(u) + 1, figsize=figsize, num=name)
for ii in range(self.Nz):
im = ax[0][ii].imshow(complex_to_rgb(u_show[ii]))
ax[0][ii].set_title("Degenerate orbit %d" % ii)
im = ax[0][-1].imshow(np.sum(abs(u_show) ** 2, axis=0))
ax[0][-1].set_title("Local density of states")
for ii in range(self.Nz):
im = ax[1][ii].imshow(complex_to_rgb(u_hat[ii]))
ax[1][ii].set_title("Fourier domain %d" % ii)
# plt.colorbar(im, ax=ax[-2], shrink=0.7)
im = ax[1][-1].imshow(fourier_intensity)
ax[1][-1].set_title("Total Fourier domain intensity")
# plt.colorbar(im, ax=ax[-1], shrink=0.7)
plt.tight_layout()
if show:
plt.show()
return fig
def interp_zoom_field(self, u, interpolation=2, zoom=0.5):
"""
interpolate a field and zoom in to the center
"""
nt, ny, nx = u.shape
cm = center_of_mass(np.sum(abs(u) ** 2, axis=0))
to_shift = (0, -1*int(np.round(cm[0] - ny / 2)), -1*int(np.round(cm[1] - nx / 2)))
centered = np.roll(u, to_shift, axis=(0, 1, 2))
zmy = int(ny * zoom) // 2
zmx = int(nx * zoom) // 2
zoomed = centered[:, zmy:ny - zmy, zmx:nx - zmx]
interpolated = np.array([fourier_interpolate(u_i, factor=interpolation) for u_i in zoomed])
return interpolated
def support_from_stack(input_array: np.ndarray,
threshold: float = 0.1,
relative_threshold: bool = True,
input_in_fourier_domain: bool = True,
absolute_autocorrelation: bool = True,
binary_dilate_support: int = 0) -> np.ndarray:
"""
Determine an initial support from a list of autocorrelations.
Args:
input_array: either the measured diffraction patterns (arpes patterns) or guesses of the objects
threshold: support is everywhere where the autocorrelation is higher than the threshold
relative_threshold: If true, threshold at threshold*np.amax(autocorrelation)
input_in_fourier_domain: False if a guess of the object is given in input_array
absolute_autocorrelation: Take the absolute value of the autocorrelation? (Generally a
good idea for objects which are not non-negative)
binary_dilate_support: number of dilation operations to apply to the support.
Returns:
support array (shape: input_array.shape[1:], dtype: np.uint)
"""
_axes = tuple(range(-1 * input_array.ndim + 1, 0))
if not input_in_fourier_domain:
kspace = shifted_fft(input_array, axes=_axes)
else:
kspace = input_array
# By the convolution theorem, the inverse FFT of the Fourier-domain magnitude gives an autocorrelation-like map
autocorrelation = shifted_ifft(abs(kspace), axes=_axes)
if absolute_autocorrelation:
autocorrelation = abs(autocorrelation)
# Sum along the first axis to combine the autocorrelations of the individual states
autocorrelation = np.sum(autocorrelation, axis=0)
# Determine the threshold value
maxval = np.amax(autocorrelation)
if relative_threshold:
threshold_val = threshold * maxval
else:
threshold_val = threshold
support = (autocorrelation > threshold_val).astype(np.uint)
# Dilate support to make it a bit too big (also fills small gaps)
if binary_dilate_support > 0:
support = binary_dilation(support, iterations=binary_dilate_support).astype(np.uint)
return support
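# Minimal usage sketch (illustrative only): `demo_patterns` below is hypothetical
# random data standing in for a stack of measured ARPES patterns; it merely
# demonstrates the expected input/output shapes of support_from_stack.
if __name__ == "__main__":
    demo_patterns = np.random.rand(2, 64, 64)  # (n_states, ny, nx), hypothetical
    demo_support = support_from_stack(demo_patterns,
                                      threshold=0.05,
                                      binary_dilate_support=1)
    print("support shape:", demo_support.shape)  # expected: (64, 64)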
@@ -37,8 +37,8 @@ class P_orthonorm(ProxOperator):
raise Exception("This should never rise, check calculation of a")
# apply projection
u_new = zeros_like(u_norm)
- u_new[0] = u_norm[0] - (y / (y ** 2 - 1)) * (y * u_norm[0] - u_norm[1])
- u_new[1] = (1 / (y ** 2 - 1)) * (y * u_norm[0] - u_norm[1])
+ u_new[1] = u_norm[0] - (y / (y ** 2 - 1)) * (y * u_norm[0] - u_norm[1])
+ u_new[0] = (1 / (y ** 2 - 1)) * (y * u_norm[0] - u_norm[1])
return u_new