Commit b1d5fcb6 authored by jansen31

start towards momentum microscopy full data reconstruction

parent d7e71f66
import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imread
from scipy.ndimage import binary_dilation, shift, center_of_mass
from proxtoolbox.experiments.orbitaltomography.planar_molecule import PlanarMolecule
from proxtoolbox.utils.orbitaltomog import shifted_fft, fourier_interpolate, bin_array, shifted_ifft
from proxtoolbox.utils.visualization.complex_field_visualization import complex_to_rgb
class OrbitalMomentumMicroscope(PlanarMolecule):
"""
Orbit
"""
@staticmethod
def getDefaultParameters():
defaultParams = {
'experiment_name': 'Momentum microscopy',
'data_filename': '..\\Inputdata\\OrbitalTomog'
+ '\\2020_11_05_Coronene_HOMO_near_degenerate_noisefree.tif',
'from_intensity_data': False,
'object': 'real',
'constraint': 'sparse real',
'orthonormality': True,
'recover_background': False,
'n_states': 2,
'sparsity_parameter': 75,
'use_sparsity_with_support': True,
'threshold_for_support': 0.1,
'support_filename': None,
'Nx': None,
'Ny': None,
'Nz': None,
'MAXIT': 500,
'TOL': 1e-10,
'diagnostic': True,
'algorithm': 'CP',  # Cyclic Projections; reduces to alternating projections (AP) when only two prox operators are given
'iterate_monitor_name': 'FeasibilityIterateMonitor', # 'IterateMonitor', #
'verbose': 1,
'interpolate_and_zoom': True,
'debug': True,
'progressbar': None
}
return defaultParams
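# Hypothetical usage sketch (names below are illustrative, not part of the module): the
# defaults above would typically be copied and selectively overridden before construction:
#   params = OrbitalMomentumMicroscope.getDefaultParameters()
#   params['n_states'] = 3
#   params['MAXIT'] = 200
#   experiment = OrbitalMomentumMicroscope(**params)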
def __init__(self, **kwargs):
super(OrbitalMomentumMicroscope, self).__init__(**kwargs)
self.orthonormality = kwargs['orthonormality']
self.recover_background = kwargs['recover_background']
self.n_states = kwargs['n_states']
self.state_axis = 0 # The individual orbitals (and background) are contained in this axis
self.energy_axis = 1 # This dimension of the iterate contains the energy dependence
self.momentum_axes = (-1, -2)
def loadData(self):
"""
Load data and set in the correct format for reconstruction
Parameters are taken from experiment class (self) properties, which must include::
- data_filename: str, path to the data file, or list of file names
- from_intensity_data: bool, if the data file gives intensities rather than field amplitude
- support_filename: str, optional path to file with object support
- use_sparsity_with_support: bool, if true, apply a support before the sparsity constraint.
The support is calculated by thresholding the object autocorrelation and dilating the result
- threshold_for_support: float, in range [0,1], fraction of the maximum at which to threshold when
determining support or support for sparsity
"""
# load data
if self.data_filename is None:
self.data_filename = input('Please enter the path to the datafile: ')
try:
if isinstance(self.data_filename, str):
self.data = imread(self.data_filename)
else:
self.data = np.array([imread(fname) for fname in self.data_filename])
except FileNotFoundError:
print("Tried path %s, found nothing. " % self.data_filename)
self.data_filename = input('Please enter a valid path to the datafile: ')
self.data = imread(self.data_filename)
# If the data is corrected for A.k, it should be well centered; we can check and correct that here
for i in range(len(self.data)):
cm = center_of_mass(self.data[i] ** 2)
to_shift = tuple(s // 2 - cm[d] for d, s in enumerate(self.data[i].shape))
self.data[i] = shift(self.data[i], to_shift, mode='nearest', order=1)
# Keep the same resolution?
self.Nz, ny, nx = self.data.shape
if self.Ny is None:
self.Ny = ny
if self.Nx is None:
self.Nx = nx
if ny != self.Ny or nx != self.Nx:
# binning must be done for the intensity-data, as that preserves the normalization
if self.from_intensity_data:
self.data = bin_array(self.data, (self.Nz, self.Ny, self.Nx))
else:
self.data = np.sqrt(bin_array(self.data ** 2, (self.Nz, self.Ny, self.Nx)))
self.Nz, self.Ny, self.Nx = self.data.shape
# Calculate electric field and norm of the data
if self.from_intensity_data:
# avoid sqrt of negative numbers (due to background subtraction)
self.data = np.where(self.data > 0, np.sqrt(abs(self.data)), np.zeros_like(self.data))
self.norm_data = np.sqrt(np.sum(self.data ** 2))
# Object support determination
if self.support_filename is not None:
self.support = imread(self.support_filename)
else:
self.support = support_from_stack(self.data,
threshold=self.threshold_for_support,
absolute_autocorrelation=True,
binary_dilate_support=1)
if self.use_sparsity_with_support:
self.sparsity_support = self.support
self.createRandomGuess()
# some variables which are necessary for the algorithm:
self.data_sq = self.data ** 2
self.data_zeros = np.where(self.data == 0)
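# data_zeros records the pixels with zero measured amplitude; the modulus constraint
# typically leaves such unmeasured (or background-subtracted) pixels unconstrained.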
def createRandomGuess(self):
"""
Take the measured data, add an independent random phase per pixel and compute the resulting initial iterate
"""
if self.recover_background:
n = self.n_states + 1
else:
n = self.n_states
self.u0 = np.array([self.data * np.exp(1j * 2 * np.pi * np.random.random_sample(self.data.shape))
for i in range(n)])
self.u0 = shifted_fft(self.u0, axes=(-1, -2))
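# Each state starts as the measured amplitude with a uniformly random phase per pixel;
# the single shifted_fft then maps these initial fields out of the data (momentum) domain
# before the iteration begins.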
def setupProxOperators(self):
"""
Determine the prox operators to be used based on the given constraint.
This method is called during the initialization process.
sets the parameters:
- self.proxOperators
- self.propagator and self.inverse_propagator
"""
# Real-space constraint: sparsity-based prox operator
self.proxOperators.append('P_Sparsity_real_incoherent')
# Apply orthonormality constraint
self.proxOperators.append('P_orthonorm')
# Modulus proxoperator (normally the second operator)
self.proxOperators.append('P_M')
self.propagator = 'PropagatorFFT2'
self.inverse_propagator = 'InvPropagatorFFT2'
self.nProx = len(self.proxOperators)
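# With the 'CP' (cyclic projections) algorithm this yields the cycle
# P_Sparsity_real_incoherent -> P_orthonorm -> P_M in every iteration.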
def plotInputData(self):
"""Quick plotting routine to show the data, initial guess and the sparsity support"""
fig, ax = plt.subplots(2, self.Nz + 1, figsize=(12, 7))
for ii in range(self.Nz):
im = ax[0][ii].imshow(self.data[ii])
plt.colorbar(im, ax=ax[0][ii])
ax[0][ii].set_title("Photoelectron spectrum %d" % ii)
if self.sparsity_support is not None:
im = ax[0][-1].imshow(self.sparsity_support, cmap='gray')
# plt.colorbar(im, ax=ax[2])
ax[0][-1].set_title("Sparsity support")
for ii in range(self.Nz):
im = ax[1][ii].imshow(complex_to_rgb(self.u0[ii]))
plt.colorbar(im, ax=ax[1][ii])
ax[1][ii].set_title("Degenerate orbit %d" % ii)
ax[1][-1].imshow(np.sum(abs(self.u0) ** 2, axis=0))
ax[1][-1].set_title("Integrated density of states")
plt.show()
def show(self, **kwargs):
"""
Create basic result plots of the phase retrieval procedure
"""
super(PlanarMolecule, self).show()
self.output['u1'] = self.algorithm.prox1.eval(self.algorithm.u)
self.output['u2'] = self.algorithm.prox2.eval(self.algorithm.u)
figsize = kwargs.pop("figsize", (12, 6))
for i, operator in enumerate(self.algorithm.proxOperators):
operator_name = self.proxOperators[i].__name__
f = self.plot_guess(operator.eval(self.algorithm.u),
name="%s satisfied" % operator_name,
show=False,
interpolate_and_zoom=self.interpolate_and_zoom,
figsize=figsize)
self.output['plots'].append(f)
plt.show()
def plot_guess(self, u, name=None, show=True, interpolate_and_zoom=False, figsize=(12, 6)):
""""Given a list of fields, plot the individual fields and the combined intensity"""
prop = self.propagator(self) # This is not a string but the indicated class itself, to be instantiated
u_hat = prop.eval(u)
fourier_intensity = np.sqrt(np.sum(abs(u_hat) ** 2, axis=0))
if interpolate_and_zoom:
u_show = interp_zoom_field(u)
else:
u_show = u
fig, ax = plt.subplots(2, len(u) + 1, figsize=figsize, num=name)
for ii in range(len(u_show)):
im = ax[0][ii].imshow(complex_to_rgb(u_show[ii]))
ax[0][ii].set_title("Degenerate orbital %d" % ii)
im = ax[0][-1].imshow(np.sum(abs(u_show) ** 2, axis=0))
ax[0][-1].set_title("Local density of states")
for ii in range(len(u_hat)):
im = ax[1][ii].imshow(complex_to_rgb(u_hat[ii]))
ax[1][ii].set_title("Fourier domain %d" % ii)
# plt.colorbar(im, ax=ax[-2], shrink=0.7)
im = ax[1][-1].imshow(fourier_intensity)
ax[1][-1].set_title("Total Fourier domain intensity")
# plt.colorbar(im, ax=ax[-1], shrink=0.7)
plt.tight_layout()
if show:
plt.show()
return fig
def support_from_stack(input_array: np.ndarray,
threshold: float = 0.1,
relative_threshold: bool = True,
input_in_fourier_domain: bool = True,
absolute_autocorrelation: bool = True,
binary_dilate_support: int = 0) -> np.ndarray:
"""
Determine an initial support from a list of autocorrelations.
Args:
input_array: either the measured diffraction patterns (ARPES patterns) or guesses of the objects
threshold: the support is everywhere the autocorrelation exceeds the threshold value
relative_threshold: If true, threshold at threshold * np.max(autocorrelation)
input_in_fourier_domain: False if a guess of the object is given in input_array
absolute_autocorrelation: Take the absolute value of the autocorrelation? (Generally a
good idea for objects which are not non-negative)
binary_dilate_support: number of dilation operations to apply to the support.
Returns:
support array (dimensions: input_array.shape[1:], dtype: np.uint)
"""
_axes = tuple(range(-1 * input_array.ndim + 1, 0))
if not input_in_fourier_domain:
kspace = shifted_fft(input_array, axes=_axes)
else:
kspace = input_array
# The inverse FFT of the Fourier-domain modulus yields the object autocorrelation (convolution theorem)
autocorrelation = shifted_ifft(abs(kspace), axes=_axes)
if absolute_autocorrelation:
autocorrelation = abs(autocorrelation)
# Sum along the first axis to combine the individual autocorrelations
autocorrelation = np.sum(autocorrelation, axis=0)
# Determine thresholding
maxval = np.amax(autocorrelation)
if relative_threshold:
threshold_val = threshold * maxval
else:
threshold_val = threshold
support = (autocorrelation > threshold_val).astype(np.uint)
# Dilate support to make it a bit too big (also fills small gaps)
if binary_dilate_support > 0:
support = binary_dilation(support, iterations=binary_dilate_support).astype(np.uint)
return support
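# Illustrative sketch, not called anywhere in this module: build a synthetic two-state stack
# of Fourier magnitudes and derive a support from it with support_from_stack. Array sizes,
# seed and threshold are arbitrary example choices.
def _example_support_from_stack():
    rng = np.random.default_rng(seed=0)
    obj = np.zeros((2, 64, 64))
    obj[:, 24:40, 24:40] = rng.random((2, 16, 16))  # two compact, random "objects"
    kspace = abs(shifted_fft(obj, axes=(-1, -2)))  # magnitude only, as for measured amplitudes
    return support_from_stack(kspace, threshold=0.1, binary_dilate_support=2)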
def interp_zoom_field(u, interpolation=2, zoom=0.5):
"""
interpolate a field and zoom in to the center
"""
nt, ny, nx = u.shape
cm = center_of_mass(np.sum(abs(u) ** 2, axis=0))
to_shift = (0, -1 * int(np.round(cm[0] - ny / 2)), -1 * int(np.round(cm[1] - nx / 2)))
centered = np.roll(u, to_shift, axis=(0, 1, 2))
zmy = int(ny * zoom) // 2
zmx = int(nx * zoom) // 2
zoomed = centered[:, zmy:ny - zmy, zmx:nx - zmx]
interpolated = np.array([fourier_interpolate(u_i, factor=interpolation) for u_i in zoomed])
return interpolated
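# Display-only sketch: `u` below stands for a hypothetical complex stack of shape
# (n_states, ny, nx); with the defaults, interp_zoom_field crops the central half of the
# field and doubles the sampling before plotting, e.g.
#   u_zoomed = interp_zoom_field(u, interpolation=2, zoom=0.5)
#   plt.imshow(complex_to_rgb(u_zoomed[0])); plt.show()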
"""
Written by Matthijs Jansen, November 2020.
Contains prox operators to apply to multidimensional photoelectron spectroscopy data, in order to reconstruct a set
of molecular orbitals as well as background contributions.
"""
from numpy import moveaxis, sqrt, sum
from proxtoolbox.proxoperators import ProxOperator, P_M, P_Sparsity_real
__all__ = ['P_nondispersive']
class P_nondispersive(ProxOperator):
"""
Constrain molecular orbitals to be non-dispersive, i.e. u(k) should not depend on energy
"""
def __init__(self, experiment):
super(P_nondispersive, self).__init__(experiment)
self.energy_axis = experiment.energy_axis
self.state_axis = experiment.state_axis
self.momentum_axes = experiment.momentum_axes
self.ns = experiment.u0.shape[self.state_axis]
self.ne = experiment.u0.shape[self.energy_axis]
def eval(self, u, prox_idx=None):
# Bring the iterate to (state, energy, ky, kx) ordering if it is not already
if self.state_axis != 0 or self.energy_axis != 1:
u = moveaxis(u, [self.state_axis, self.energy_axis], [0, 1])
out = u.copy()
profiles = sum(u, axis=1)  # Integrate each state over the energy axis
# Normalise the energy-integrated profile of each state (keepdims allows broadcasting below)
profiles /= sqrt(sum(abs(profiles) ** 2, axis=self.momentum_axes, keepdims=True))
norms = sqrt(sum(abs(u) ** 2, axis=self.momentum_axes))  # Norm of every (state, energy) slice
for s in range(self.ns):
for e in range(self.ne):
# Every energy slice of state s gets the same unit-norm momentum profile,
# rescaled to the norm of the original slice
out[s, e] = profiles[s] * norms[s, e]
if self.state_axis != 0 or self.energy_axis != 1:
out = moveaxis(out, [0, 1], [self.state_axis, self.energy_axis])
return out
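# Sketch of the intended behaviour (assumes the ProxOperator base class only stores the
# experiment it is given): for a stack u of shape (n_states, n_energies, ny, nx) the output
# satisfies out[s, e] = p_s(k) * ||u[s, e]||, with p_s(k) the unit-norm, energy-integrated
# momentum profile of state s. A minimal, hypothetical check along these lines:
#   from types import SimpleNamespace
#   from numpy.random import rand
#   exp = SimpleNamespace(state_axis=0, energy_axis=1, momentum_axes=(-1, -2),
#                         u0=rand(2, 3, 8, 8))
#   out = P_nondispersive(exp).eval(rand(2, 3, 8, 8))
#   # each out[s, e] is now proportional to the same pattern out[s, 0]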