Commit 58bf802a authored by Matthijs's avatar Matthijs
Browse files

towards 3d with negative z momenta added

parent 93581134
from proxtoolbox.Problems.OrbitalTomog.planar_molecule.orbitaltomog_data_processor import support_from_autocorrelation from proxtoolbox.Problems.OrbitalTomog.planar_molecule.orbitaltomog_data_processor import support_from_autocorrelation
from proxtoolbox.Utilities.OrbitalTomog.array_tools import shifted_fft from proxtoolbox.Utilities.OrbitalTomog import shifted_fft, bin_array, pad_to_square
from proxtoolbox.Utilities.OrbitalTomog.binning import bin_array
from proxtoolbox.Problems.OrbitalTomog.Graphics.stack_viewer import XYZStackViewer from proxtoolbox.Problems.OrbitalTomog.Graphics.stack_viewer import XYZStackViewer
import numpy as np import numpy as np
from skimage.io import imread from skimage.io import imread
from itertools import combinations
def data_processor(config): def data_processor(config):
# Load data
inp = imread(config['data_filename']) inp = imread(config['data_filename'])
# Load measurement sensitivity, determine data_zeros
sensitivity = imread(config['sensitivity_filename'])
# Add the negative K_z part to determine a real-valued orbital?
if 'add_negative_kz' in config and config['add_negative_kz']:
inp, sens = mirror_kspace(inp, sensitivity)
ny, nx, nz = inp.shape ny, nx, nz = inp.shape
config['data'] = abs(inp) config['data'] = abs(inp)
config['data_zeros'] = sensitivity == 0
# Keep the same resolution? # Keep the same resolution?
if 'Ny' not in config or 'Nx' not in config or 'Nz' not in config: if 'Ny' not in config or 'Nx' not in config or 'Nz' not in config:
print('Setting problem dimensions based on data') print('Setting problem dimensions based on data')
config['Ny'], config['Nx'], config['Nz'] = ny, nx, nz config['Ny'], config['Nx'], config['Nz'] = ny, nx, nz
elif ny != config['Ny'] or nx != config['Nx'] or nz != config['Nz']: elif ny != config['Ny'] or nx != config['Nx'] or nz != config['Nz']:
nandata = np.where(config['data_zeros'], np.nan, config['data'])
if not ('from intensity data' in config and config['from intensity data']): if not ('from intensity data' in config and config['from intensity data']):
# binning must be done for the intensity-data, as that preserves the normalization # binning must be done for the intensity-data, as that preserves the normalization
config['data'] = np.sqrt(bin_array(config['data'] ** 2, (config['Ny'], config["Nx"], config['Nz']))) config['data'] = np.sqrt(np.nan_to_num(bin_array(nandata ** 2, (config['Ny'], config["Nx"], config['Nz']))))
else: else:
config['data'] = bin_array(config['data'], (config['Ny'], config["Nx"])) config['data'] = np.nan_to_num(bin_array(nandata, (config['Ny'], config["Nx"])))
# elif ny == config['Ny'] and nx == config['Nx'] or nz == config['Nz']: config["data_zeros"] = bin_array(config['data_zeros'], (config['Ny'], config["Nx"], config['Nz'])) == 0
# pass
# else:
# raise ValueError('Incompatible values for Ny, Nx, Nz given in configuration dict')
# Load measurement sensitivity, determine data_zeros
sensitivity = imread(config['sensitivity_filename'])
config['data_zeros'] = sensitivity == 0
assert config['data_zeros'].shape == config['data'].shape, 'Non-matching sensitivity and data arrays' assert config['data_zeros'].shape == config['data'].shape, 'Non-matching sensitivity and data arrays'
# Calculate electric field # Calculate electric field
...@@ -60,15 +62,20 @@ def data_processor(config): ...@@ -60,15 +62,20 @@ def data_processor(config):
binary_dilate_support=1) binary_dilate_support=1)
# Initial guess # Initial guess
ph_init = 2 * np.pi * np.random.random_sample(inp.shape) ph_init = 2 * np.pi * np.random.random_sample(config['data'].shape)
config['u_0'] = inp * np.exp(1j * ph_init) config['u_0'] = config['data'] * np.exp(1j * ph_init)
if config['dataprocessor_plotting']: if config['dataprocessor_plotting']:
input_viewer = XYZStackViewer(inp, cmap='viridis') input_viewer = XYZStackViewer(config['data'], cmap='viridis')
config['input viewer'] = input_viewer # the reference keeps the matplotlib object alive config['input viewer'] = input_viewer # the reference keeps the matplotlib object alive
support_viewer = XYZStackViewer(shifted_fft(inp).real) support_viewer = XYZStackViewer(shifted_fft(config['data']).real)
config['support viewer'] = support_viewer # the reference keeps the matplotlib object alive config['support viewer'] = support_viewer # the reference keeps the matplotlib object alive
if 'fourier_shift_arrays' in config and config['fourier_shift_arrays']:
for key in ['u_0', 'support', 'sparsity_support', 'data', 'data_zeros']:
if key in config:
config[key] = np.fft.fftshift(config[key])
# Other settings # Other settings
config['fresnel_nr'] = 0 config['fresnel_nr'] = 0
config['FT_conv_kernel'] = 1 config['FT_conv_kernel'] = 1
...@@ -77,3 +84,54 @@ def data_processor(config): ...@@ -77,3 +84,54 @@ def data_processor(config):
config['data_sq'] = abs(config['data']) ** 2 config['data_sq'] = abs(config['data']) ** 2
return config return config
def mirror_kspace(kspace: np.ndarray, sensitivity: np.ndarray = None, shift_mirror: tuple = (1, 1, 1),
square_array: bool = True):
"""
Mirror a 3d kspace array in the kz=0 plane.
:param kspace: kspace array, e.g. from arpes_wvlscan_to_kspace
:param sensitivity: if given (from arpes_wvlscan_to_kspace), test
:param shift_mirror: array rolling done to get good centering. is tested by sensitivity
:param square_array: pad array to be square
:return: 3d arrays kspace and sensitivity (if given)
"""
def mirror_array(arr: np.ndarray, roll=shift_mirror) -> np.ndarray:
arr = np.concatenate([np.roll(np.flip(arr), (roll[0], roll[1]), axis=(0, 1)), arr[:, :, 1:]],
axis=2)
arr = np.roll(arr, roll[2], axis=2)
return arr
full_kspace = mirror_array(kspace)
if square_array:
if np.any(np.isnan(full_kspace)):
cv = np.nan
else:
cv = 0
full_kspace = pad_to_square(full_kspace, constant_values=cv)
if sensitivity is None:
return full_kspace
else:
mirrored_sens = mirror_array(sensitivity)
if np.any(np.isnan(mirrored_sens)):
cv = np.nan
else:
cv = 0
mirrored_sens = pad_to_square(mirrored_sens, constant_values=cv)
testslice = [len(mirrored_sens) // i for i in [2, 3, 4]]
test_array = np.where(np.isnan(mirrored_sens), 0, 1)
for tsl in testslice:
testslices = [test_array[tsl],
test_array[:, tsl],
test_array[:, :, tsl],
test_array[-tsl],
test_array[:, -tsl],
test_array[:, :, -tsl]]
test_res = []
for a, b in combinations(testslices, 2):
test_res += [np.all(np.isclose(a, b))]
assert np.all(
test_res), 'Non-matching test slices, indicating that shift_mirror parameter should be changed'
return full_kspace, mirrored_sens
...@@ -11,6 +11,7 @@ new_config = { ...@@ -11,6 +11,7 @@ new_config = {
'data_filename': 'pentacene_3d_arpes.tif', # In the directory '../../../InputData/OrbitalTomog/' 'data_filename': 'pentacene_3d_arpes.tif', # In the directory '../../../InputData/OrbitalTomog/'
'from intensity data': False, # File gives field amplitudes 'from intensity data': False, # File gives field amplitudes
'sensitivity_filename': 'pentacene_3d_arpes_sensitivity.tif', 'sensitivity_filename': 'pentacene_3d_arpes_sensitivity.tif',
'add_negative_kz': False,
# What type of object are we working with? # What type of object are we working with?
# Options are: 'phase', 'real', 'nonnegative', 'complex' # Options are: 'phase', 'real', 'nonnegative', 'complex'
...@@ -20,7 +21,7 @@ new_config = { ...@@ -20,7 +21,7 @@ new_config = {
# Options are: 'support only', 'real and support', 'nonnegative and support', # Options are: 'support only', 'real and support', 'nonnegative and support',
# 'amplitude only', 'sparse real', 'sparse complex', and 'hybrid' # 'amplitude only', 'sparse real', 'sparse complex', and 'hybrid'
# 'symmetric sparse real', 'symmetric sparse complex' # 'symmetric sparse real', 'symmetric sparse complex'
'constraint': 'sparse real', 'constraint': 'sparse complex',
# What type of measurements are we working with? # What type of measurements are we working with?
# Options are: 'single diffraction', 'diversity diffraction', # Options are: 'single diffraction', 'diversity diffraction',
...@@ -41,16 +42,6 @@ new_config = { ...@@ -41,16 +42,6 @@ new_config = {
# 'Nx': 64, # 'Nx': 64,
# 'Nz': 64, # 'Nz': 64,
# 'fresnel_nr' : 0,
# moved this to phase
# if(strcmp('distance,'near field'))
# 'fresnel_nr' : 1*2*pi*'Nx,
# else
# 'fresnel_nr' : 0, #1*2*pi*'Nx,
# 'magn' : 1,
# What are the noise characteristics (Poisson or Gaussian)? # What are the noise characteristics (Poisson or Gaussian)?
'noise': 'Poisson', 'noise': 'Poisson',
# ========================================== # ==========================================
...@@ -156,8 +147,8 @@ new_config = { ...@@ -156,8 +147,8 @@ new_config = {
# default is 1. # default is 1.
'graphics_display': 'Phase_graphics_3d', # name of the plotting routine 'graphics_display': 'Phase_graphics_3d', # name of the plotting routine
'dataprocessor_plotting': True, 'dataprocessor_plotting': True,
'interpolate_result': False, 'interpolate_result': True, # default interpolate by a factor 2
'zoomin_on_result': False 'zoomin_on_result': True, # Default zoom in to 50% of the field of view
'fourier_shift_arrays': True # If the data is centered in the middle of the array(image
} }
...@@ -7,6 +7,7 @@ from proxtoolbox.ProxOperators.proxoperators import ProxOperator ...@@ -7,6 +7,7 @@ from proxtoolbox.ProxOperators.proxoperators import ProxOperator
from proxtoolbox.Problems.OrbitalTomog import Graphics from proxtoolbox.Problems.OrbitalTomog import Graphics
from numpy.linalg import norm from numpy.linalg import norm
from numpy import square, sqrt from numpy import square, sqrt
from numpy.fft import fftshift
from proxtoolbox.Utilities.OrbitalTomog import interpolation, array_tools, binning from proxtoolbox.Utilities.OrbitalTomog import interpolation, array_tools, binning
...@@ -148,22 +149,23 @@ class Phase(Problem): ...@@ -148,22 +149,23 @@ class Phase(Problem):
""" """
Processes the solution and generates the output Processes the solution and generates the output
""" """
# Center the solution (since position is a degree of freedom, # Center the solution (since position is a degree of freedom, and if desired, interpolate the results.
# and if desired, interpolate the results.
center = tuple([s//2 for s in self.config['u'].shape]) center = tuple([s//2 for s in self.config['u'].shape])
for key in ['u', 'u1', 'u2']: for key in ['u', 'u1', 'u2']:
if 'fourier_shift_arrays' in self.config and self.config['fourier_shift_arrays']:
self.output[key] = fftshift(self.output[key])
self.output[key] = array_tools.roll_to_pos(self.output[key], pos=center, move_maximum=True) self.output[key] = array_tools.roll_to_pos(self.output[key], pos=center, move_maximum=True)
self.output[key] = array_tools.roll_to_pos(self.output[key], pos=center) self.output[key] = array_tools.roll_to_pos(self.output[key], pos=center)
# This sequence will work for objects *with a small support* even if they lie over the edge of the array # This sequence will work for objects *with a small support* even if they lie over the edge of the array
if 'interpolate_result' in self.config and self.config['interpolate_result']: if 'interpolate_result' in self.config and self.config['interpolate_result']:
self.output[key] = interpolation.fourier_interpolate(self.output[key]) self.output[key] = interpolation.fourier_interpolate(self.output[key])
if 'zoomin_on_result' in self.config and self.config['zoomin_on_result']: if 'zoomin_on_result' in self.config and self.config['zoomin_on_result']:
# TODO: if some support given, use bounding box of the support as zoom in region
if self.output[key].ndim == 2: if self.output[key].ndim == 2:
zmy, zmx = tuple([s//4 for s in self.output[key].shape]) zmy, zmx = tuple([s//4 for s in self.output[key].shape])
self.output[key] = self.output[key][zmy:-zmy, zmx:-zmx] self.output[key] = self.output[key][zmy:-zmy, zmx:-zmx]
elif self.output[key].ndim == 3: elif self.output[key].ndim == 3:
zmy, zmx, zmz = tuple([s//4 for s in self.output[key].shape]) zmy, zmx, zmz = tuple([s//4 for s in self.output[key].shape])
# (self.config['Ny'] // 4, self.config["Nx"] // 4, self.config['Nz'] // 4)
self.output[key] = self.output[key][zmy:-zmy, zmx:-zmx, zmz:-zmz] self.output[key] = self.output[key][zmy:-zmy, zmx:-zmx, zmz:-zmz]
def show(self): def show(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment