Commit aff72aa5 authored by jansen31's avatar jansen31
Browse files

new experiment class

parent 949f352d
from .planar_molecule import PlanarMolecule
# from .molecule_3d import Molecule3D
from .degenerate_orbits import DegenerateOrbital
__all__ = ['PlanarMolecule',
# 'Molecule3D',
'DegenerateOrbital']
\ No newline at end of file
from skimage.io import imread
from scipy.ndimage import binary_dilation
import numpy as np
import matplotlib.pyplot as plt
......@@ -16,6 +14,7 @@ class DegenerateOrbital(PlanarMolecule):
'data_filename': None,
'from_intensity_data': False,
'object': 'real',
'degeneracy': 2, # Number of degenerate states to reconstruct
'constraint': 'sparse real',
'sparsity_parameter': 40,
'use_sparsity_with_support': False,
......@@ -32,7 +31,7 @@ class DegenerateOrbital(PlanarMolecule):
'data_ball': .999826,
'TOL2': 1e-15,
'diagnostic': True,
'algorithm': 'DRl',
'algorithm': 'CP',
'iterate_monitor_name': 'FeasibilityIterateMonitor', # 'IterateMonitor', #
'rotate': False,
'verbose': 1,
......@@ -44,6 +43,7 @@ class DegenerateOrbital(PlanarMolecule):
def __init__(self, **kwargs):
super(DegenerateOrbital, self).__init__(**kwargs)
self.degeneracy = kwargs.pop("degeneracy", 2)
def loadData(self):
"""
......@@ -57,76 +57,71 @@ class DegenerateOrbital(PlanarMolecule):
- threshold_for_support: float, in range [0,1], fraction of the maximum at which to threshold when
determining support or support for sparsity
"""
# load data
if self.data_filename is None:
self.data_filename = input('Please enter the path to the datafile: ')
try:
self.data = imread(self.data_filename)
except FileNotFoundError:
print("Tried path %s, found nothing. " % self.data_filename)
self.data_filename = input('Please enter a valid path to the datafile: ')
self.data = imread(self.data_filename)
# Keep the same resolution?
ny, nx = self.data.shape
try:
if ny % self.Ny == 0 and nx % self.Nx == 0:
# binning must be done for the intensity-data, as that preserves the normalization
if self.from_intensity_data:
self.data = bin_array(self.data, (self.Ny, self.Nx))
else:
self.data = np.sqrt(bin_array(self.data ** 2, (self.Ny, self.Nx)))
else:
# TODO: use flexibility allowed by the binning function to prevent this error
raise ValueError('Incompatible values for Ny and Nx given in configuration dict')
except TypeError:
pass
self.Ny, self.Nx = self.data.shape
# Calculate electric field and norm of the data
if self.from_intensity_data:
# avoid sqrt of negative numbers (due to background subtraction)
self.data = np.where(self.data > 0, np.sqrt(abs(self.data)), np.zeros_like(self.data))
self.norm_data = np.sqrt(np.sum(self.data ** 2))
# Object support determination
if self.support is not None:
self.support = imread(self.support_filename)
else:
self.support = support_from_autocorrelation(self.data,
threshold=self.threshold_for_support,
absolute_autocorrelation=True,
binary_dilate_support=1)
if self.use_sparsity_with_support:
self.sparsity_support = support_from_autocorrelation(self.data,
threshold=self.threshold_for_support,
binary_dilate_support=1)
self.createRandomGuess()
# some variables wich are necessary for the algorithm:
self.data_sq = self.data ** 2
self.data_zeros = np.where(self.data == 0)
super(DegenerateOrbital, self).loadData()
def createRandomGuess(self):
"""
Taking the measured data, add a random phase and calculate the resulting iterate guess
"""
ph_init = 2 * np.pi * np.random.random_sample(self.data.shape)
self.u0 = self.data * np.exp(1j * ph_init)
self.u0 = np.fft.fftn(self.u0)
self.u0 = np.array([self.data * np.exp(1j * 2 * np.pi * np.random.random_sample(self.data.shape))
for i in range(self.degeneracy)])
self.u0 = np.fft.fftn(self.u0, axes=(-1, -2))
def setupProxOperators(self):
"""
Determine the prox operators to be used based on the given constraint.
This method is called during the initialization process.
sets the parameters:
- self.proxOperators
- self.propagator and self.inverse_propagator
"""
# Apply orthonormality constraint first
self.proxOperators.append('P_orthonorm')
# Select the right real space operator sparsity-based proxoperators
if self.constraint == 'sparse real':
self.proxOperators.append('P_Sparsity_real')
elif self.constraint == 'sparse complex':
self.proxOperators.append('P_Sparsity')
elif self.constraint in ['symmetric sparse real', 'sparse symmetric real']:
self.proxOperators.append('P_Sparsity_Symmetric_real')
elif self.constraint in ['symmetric sparse complex', 'symmetric sparse complex']:
self.proxOperators.append('P_Sparsity_Symmetric')
# Modulus proxoperator (normally the second operator)
if self.experiment_name == '3D ARPES':
self.proxOperators.append('P_M_masked')
self.propagator = 'PropagatorFFTn'
self.inverse_propagator = 'InvPropagatorFFTn'
elif self.experiment_name == 'noisy 2D ARPES':
# Apply modulus constraint when the difference is large enough (Approx_ prefix)
self.proxOperators.append('Approx_P_M')
self.propagator = 'PropagatorFFT2'
self.inverse_propagator = 'InvPropagatorFFT2'
else: # No noise case: always apply the modulus constraint
self.proxOperators.append('P_M')
self.propagator = 'PropagatorFFT2'
self.inverse_propagator = 'InvPropagatorFFT2'
self.nProx = len(self.proxOperators)
def plotInputData(self):
"""Quick plotting routine to show the data, initial guess and the sparsity support"""
fig, ax = plt.subplots(1, 3, figsize=(12, 3.5))
im = ax[0].imshow(self.data)
fig, ax = plt.subplots(2, self.degeneracy + 1, figsize=(12, 7))
im = ax[0][0].imshow(self.data)
plt.colorbar(im, ax=ax[0])
plt.title("Photoelectron spectrum")
ax[1].imshow(complex_to_rgb(self.u0))
plt.title("Initial guess")
ax[0][0].set_title("Photoelectron spectrum")
if self.sparsity_support is not None:
im = ax[2].imshow(self.sparsity_support, cmap='gray')
im = ax[0][-1].imshow(self.sparsity_support, cmap='gray')
# plt.colorbar(im, ax=ax[2])
plt.title("Sparsity support")
ax[0][-1].set_title("Sparsity support")
for ii in range(self.degeneracy):
im = ax[1][ii].imshow(complex_to_rgb(self.u0[ii]))
plt.colorbar(im, ax=ax[1][ii])
ax[1][ii].set_title("Degenerate orbit %d" % ii)
ax[1][-1].imshow(np.sum(abs(self.u0) ** 2, axis=0))
ax[1][-1].set_title("Local density of states")
plt.show()
def show(self):
......@@ -136,37 +131,40 @@ class DegenerateOrbital(PlanarMolecule):
self.output['plots'] = []
self.output['u1'] = self.algorithm.prox1.eval(self.algorithm.u)
self.output['u2'] = self.algorithm.prox2.eval(self.algorithm.u)
if self.interpolate_and_zoom:
for key in ["u1", "u2"]:
center = tuple([s // 2 for s in self.output[key].shape])
self.output[key] = roll_to_pos(self.output[key], pos=center, move_maximum=True)
self.output[key] = roll_to_pos(self.output[key], pos=center)
self.output[key] = fourier_interpolate(self.output[key], 2)
zmy, zmx = tuple([s // 4 for s in self.output[key].shape])
self.output[key] = self.output[key][zmy:-zmy, zmx:-zmx]
# if self.interpolate_and_zoom:
# for key in ["u1", "u2"]:
# center = tuple([s // 2 for s in self.output[key].shape])
# self.output[key] = roll_to_pos(self.output[key], pos=center, move_maximum=True)
# self.output[key] = roll_to_pos(self.output[key], pos=center)
# self.output[key] = fourier_interpolate(self.output[key], 2)
# zmy, zmx = tuple([s // 4 for s in self.output[key].shape])
# self.output[key] = self.output[key][zmy:-zmy, zmx:-zmx]
u1 = self.output['u1']
u2 = self.output['u2']
change = self.output['stats']['changes']
f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9, 6))
im = ax1.imshow(abs(u1), cmap='gray')
f.colorbar(im, ax=ax1)
ax1.set_title('best approximation amplitude: \nphysical constraint satisfied')
ax2.imshow(complex_to_rgb(u1))
ax2.set_title('best approximation phase: \nphysical constraint satisfied')
im = ax3.imshow(abs(u2), cmap='gray')
f.colorbar(im, ax=ax3)
ax3.set_title('best approximation amplitude: \nFourier constraint satisfied')
ax4.imshow(complex_to_rgb(u2))
ax4.set_title('best approximation amplitude: \nFourier constraint satisfied')
f.tight_layout()
self.plot_guess(u1, name='Best approximation: physical constraint satisfied', show=False)
self.plot_guess(u2, name='Best approximation: Fourier constraint satisfied', show=False)
# f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(9, 6))
# im = ax1.imshow(abs(u1), cmap='gray')
# f.colorbar(im, ax=ax1)
# ax1.set_title('best approximation amplitude: \nphysical constraint satisfied')
# ax2.imshow(complex_to_rgb(u1))
# ax2.set_title('best approximation phase: \nphysical constraint satisfied')
# im = ax3.imshow(abs(u2), cmap='gray')
# f.colorbar(im, ax=ax3)
# ax3.set_title('best approximation amplitude: \nFourier constraint satisfied')
# ax4.imshow(complex_to_rgb(u2))
# ax4.set_title('best approximation amplitude: \nFourier constraint satisfied')
# f.tight_layout()
g, ((bx1, bx2), (bx3, bx4)) = plt.subplots(2, 2, figsize=(9, 6))
im = bx1.imshow(abs(u1), cmap='gray')
f.colorbar(im, ax=bx1)
g.colorbar(im, ax=bx1)
bx1.set_title('best approximation amplitude: \nphysical constraint satisfied')
im = bx2.imshow(u1.real, cmap='gray')
f.colorbar(im, ax=bx2)
g.colorbar(im, ax=bx2)
bx2.set_title('best approximation phase: \nphysical constraint satisfied')
bx3.semilogy(change)
bx3.set_xlabel('iteration')
......@@ -176,12 +174,12 @@ class DegenerateOrbital(PlanarMolecule):
bx4.semilogy(gaps)
bx4.set_xlabel('iteration')
bx4.set_ylabel('Gap')
f.tight_layout()
g.tight_layout()
h, ax = plt.subplots(1, 3, figsize=(9,3))
h, ax = plt.subplots(1, 3, figsize=(9, 3))
ax[0].imshow(self.data)
ax[0].set_title("Measured data")
prop = self.propagator(self)
prop = self.propagator.eval(self)
u_hat = prop.eval(self.algorithm.prox1.eval(self.algorithm.u))
ax[1].imshow(abs(u_hat))
ax[1].set_title("Predicted measurement intensity")
......@@ -197,43 +195,15 @@ class DegenerateOrbital(PlanarMolecule):
# def saveOutput(self, **kwargs):
# super(PlanarMolecule, self).saveOutput(**kwargs)
def support_from_autocorrelation(input_array: np.ndarray,
threshold: float = 0.1,
relative_threshold: bool = True,
input_in_fourier_domain: bool = True,
absolute_autocorrelation: bool = True,
binary_dilate_support: int = 0) -> np.ndarray:
"""
Determine an initial support from the autocorrelation of an object.
Args:
input_array: either the measured diffraction (arpes pattern) or a guess of the object
threshold: support is everywhere where the autocorrelation is higher than the threshold
relative_threshold: If true, threshold at threshold*np.amax(autocorrelation)
input_in_fourier_domain: False if a guess of the object is given in input_array
absolute_autocorrelation: Take the absolute value of the autocorrelation? (Generally a
good idea for objects which are not non-negative)
binary_dilate_support: number of dilation operations to apply to the support.
Returns:
support array (same dimensions as input, dtype=np.int)
"""
if not input_in_fourier_domain:
kspace = shifted_fft(input_array)
else:
kspace = input_array
autocorrelation = shifted_ifft(abs(kspace)) # Taking absolute value yields autocorrelation by conv. theorem)
if absolute_autocorrelation:
autocorrelation = abs(autocorrelation)
maxval = np.amax(autocorrelation)
if relative_threshold:
threshold_val = threshold * maxval
else:
threshold_val = threshold
support = (autocorrelation > threshold_val).astype(np.uint)
if binary_dilate_support > 0:
support = binary_dilation(support, iterations=binary_dilate_support).astype(np.uint)
return support
def plot_guess(self, u, name=None, show=True):
""""Given a list of fields, plot the individual fields and the combined intensity"""
fig, ax = plt.subplots(1, len(u) + 1, figsize=(12, 3.5), num=name)
for ii in range(self.degeneracy):
im = ax[ii].imshow(complex_to_rgb(u[ii]))
plt.colorbar(im, ax=ax[ii])
ax[ii].set_title("Degenerate orbit %d" % ii)
ax[-1].imshow(np.sum(abs(u) ** 2, axis=0))
ax[-1].set_title("Local density of states")
plt.tight_layout()
if show:
plt.show()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment