Commit 5dfa6715 authored by Leon Merten Lohse's avatar Leon Merten Lohse
Browse files

major refactor of vacuum propagation code

parent ad301752
Pipeline #210873 passed with stages
in 1 minute and 34 seconds
......@@ -32,15 +32,6 @@ def fftfreqn(n, dx=1):
return np.meshgrid(*xi, sparse=True, indexing="ij")
def gridn(n):
n = np.asarray(n)
ndim = n.size
x = [np.arange(-n[dim] + 1, n[dim]) for dim in range(ndim)]
return np.meshgrid(*x, sparse=True, indexing="ij")
def crop(a, crop_width):
slices = [slice(w[0], -w[1] if w[1] > 0 else None) for w in crop_width]
......
......@@ -8,11 +8,29 @@ from scipy.fft import fftn, ifftn
from . import hankel
from .misc import fftfreqn, gridn, crop, squaresum
from .misc import fftfreqn, crop, squaresum
def _verify_fresnel_numbers(fresnel_numbers, ndim):
fresnel_numbers = np.asarray(fresnel_numbers)
if np.ndim(fresnel_numbers) == 0:
fresnel_numbers = fresnel_numbers * np.ones((1, ndim))
elif np.ndim(fresnel_numbers) == 1:
fresnel_numbers = fresnel_numbers[:, np.newaxis] * np.ones((1, ndim))
elif np.ndim(fresnel_numbers) == 2:
if fresnel_numbers.shape[1] != ndim:
raise ValueError("Parameter `fresnel_numbers` incompatible with data dimension.")
# fresnel_numbers has the correct shape
else:
raise ValueError("Parameter `fresnel_numbers` must either be scalar, 1d, or 2d array.")
return fresnel_numbers
class FresnelPropagatorCS:
def __init__(self, nsamples, fresnel_numbers, npad=2):
def __init__(self, nsamples, fresnel_numbers, npad=1):
# TODO(Leon): verify input
self._nsamples = nsamples
self._npad = npad
......@@ -20,29 +38,21 @@ class FresnelPropagatorCS:
if np.any(self._npad < 1):
raise ValueError("Padding factor `npad` cannot be less than 1.")
fresnel_numbers = np.asarray(fresnel_numbers)
if np.ndim(fresnel_numbers) == 0:
self._ndist = 1
elif np.ndim(fresnel_numbers) == 1:
self._ndist = fresnel_numbers.shape[0]
else:
raise ValueError(
"Parameter 'fresnel_numbers' must either be scalar or one-dimensional."
)
self.fresnel_numbers = fresnel_numbers * np.ones((self._ndist,))
self.fresnel_numbers = _verify_fresnel_numbers(fresnel_numbers, 1)[:, 0] # this is 1d
self._ndist = self.fresnel_numbers.shape[0]
# calculate amount of padding
self._pad_width = (self._nsamples * (self._npad - 1) / 2).astype(int)
self._pad_width = np.asarray(self._nsamples * (self._npad - 1) / 2, dtype=int)
self._nsamples_pad = self._nsamples + 2 * self._pad_width
self._kern = self._init_kernel()
self._dht = hankel.DiscreteHankelTransform(nsamples, 0)
self.kernel = self._init_kernel()
self._dht = hankel.DiscreteHankelTransform(self._nsamples_pad, 0)
def _init_kernel(self):
samples_pad = self._nsamples + 2 * self._pad_width
freq = hankel.hankel_freq(samples_pad)
kern = np.exp(-1j * np.pi / self.fresnel_numbers * np.square(freq))
freq = hankel.hankel_freq(self._nsamples_pad)
kern = np.exp(
-1j * np.pi / self.fresnel_numbers[:, np.newaxis] * np.square(freq)[np.newaxis, :]
)
return kern
......@@ -53,25 +63,37 @@ class FresnelPropagatorCS:
raise ValueError("Invalid shape.")
# pad with zeros
upad = np.pad(u, self._pad_width)
upad = np.pad(u, (0, 2 * self._pad_width))
hf_u = self._dht(upad)
hf_uprop = self._kern * hf_u
hf_uprop = self.kernel * hf_u[np.newaxis, :]
uprop = np.zeros_like(u, shape=(self._ndist, *u.shape))
for dist in self._ndist:
for dist in range(self._ndist):
# TODO(Leon): modify dht so that it can process the entire stack
tmp = self._dht(hf_uprop[dist, :])
uprop[dist, :] = crop(tmp, (self._pad_width, self._pad_width))
uprop[dist, :] = crop(
tmp,
[
(0, 2 * self._pad_width),
],
)
return uprop
if self._ndist == 1:
return uprop[0, :]
else:
return uprop
class FTConvolutionPropagator:
"""
Blub
Base class for propagation methods that are based on Fourier-transform (FT) convolutions.
This class is not intended to be used directly and does nothing.
"""
def __init__(self, shape, npad):
self._shape = shape
self._shape = tuple(shape)
self._ndim = len(self._shape)
self._npad = np.asarray(npad) * np.ones((self._ndim,))
......@@ -80,8 +102,10 @@ class FTConvolutionPropagator:
# calculate amount of padding
self._pad_width = (np.asarray(self._shape) * (self._npad - 1) / 2).astype(int)
self._shape_pad = self._shape + 2 * self._pad_width
self.kernel = 1
self._ndist = 1
def __call__(self, u, workers=-1):
u = np.asarray(u)
......@@ -92,9 +116,12 @@ class FTConvolutionPropagator:
# pad with zeros
upad = np.pad(u, self._pad_width)
ft_u = fftn(upad, axes=(-2, -1), workers=workers)
# compute fft over last ndim axes
axes = list(range(-self._ndim, 0))
ft_u = fftn(upad, axes=axes, workers=workers)
ft_uprop = ft_u * self.kernel
uproppad = ifftn(ft_uprop, axes=(-2, -1), workers=workers)
uproppad = ifftn(ft_uprop, axes=axes, workers=workers)
# crop central part
pad_width_full = [
......@@ -102,32 +129,29 @@ class FTConvolutionPropagator:
] + list(self._pad_width)
uprop = crop(uproppad, zip(pad_width_full, pad_width_full))
return uprop
if self._ndist == 1:
return uprop[0, :]
else:
return uprop
class FresnelTFPropagator(FTConvolutionPropagator):
"""
Blub
Fresnel Transfer Function (Fres-TF) method.
"""
def __init__(self, shape, fresnel_numbers, npad=2):
super(FresnelTFPropagator, self).__init__(shape, npad)
fresnel_numbers = np.asarray(fresnel_numbers)
if np.ndim(fresnel_numbers) == 0:
self._ndist = 1
else:
self._ndist = fresnel_numbers.shape[0]
self.fresnel_numbers = fresnel_numbers * np.ones((self._ndist, self._ndim))
self.fresnel_numbers = _verify_fresnel_numbers(fresnel_numbers, self._ndim)
self._ndist = self.fresnel_numbers.shape[0]
self.kernel = self._init_kernel()
def _init_kernel(self):
shape_pad = np.asarray(self._shape) + 2 * self._pad_width
f = fftfreqn(shape_pad)
f = fftfreqn(self._shape_pad)
kernel = np.ones((self._ndist, *shape_pad), dtype=np.complex128)
kernel = np.ones((self._ndist, *self._shape_pad), dtype=np.complex128)
for dist in range(self._ndist):
for dim in range(self._ndim):
kernel[dist, :] *= np.exp(
......@@ -137,79 +161,108 @@ class FresnelTFPropagator(FTConvolutionPropagator):
return kernel
def _gridn(n):
n = np.asarray(n)
ndim = n.size
# magic from fftfreq
x = [
np.concatenate([np.arange(0, (n[dim] - 1) // 2 + 1), np.arange(-(n[dim] // 2), 0)])
for dim in range(ndim)
]
return np.meshgrid(*x, sparse=True, indexing="ij")
class FresnelIRPropagator(FTConvolutionPropagator):
"""
Blub
Fresnel Impulse Response (Fres-IR) method. Goodman calls this the "Convolution approach".
"""
def __init__(self, shape, fresnel_numbers, npad=2):
super(FresnelIRPropagator, self).__init__(shape, npad)
fresnel_numbers = np.asarray(fresnel_numbers)
if np.ndim(fresnel_numbers) == 0:
self._ndist = 1
else:
self._ndist = fresnel_numbers.shape[0]
self.fresnel_numbers = fresnel_numbers * np.ones((self._ndist, self._ndim))
self.fresnel_numbers = _verify_fresnel_numbers(fresnel_numbers, self._ndim)
self._ndist = self.fresnel_numbers.shape[0]
self.kernel = self._init_kernel()
def _init_kernel(self):
shape_pad = np.asarray(self._shape) + 2 * self._pad_width
# 1. assemble convolution kernel
conv_kernel = np.ones((self._ndist, *shape_pad), dtype=np.complex128)
# 1. assemble real-space convolution kernel (impulse response)
conv_kernel = np.ones((self._ndist, *self._shape_pad), dtype=np.complex128)
x = gridn(shape_pad)
x = _gridn(self._shape_pad)
for dist in range(self._ndist):
# phase factor
conv_kernel *= np.prod(
# scaling and phase factor
conv_kernel[dist, ...] *= np.prod(
np.sqrt(np.abs(self.fresnel_numbers[dist, :]))
* np.exp((-1j * np.pi / 4) * np.sign(self.fresnel_numbers[dist, :]))
)
) * np.prod(np.exp((-1j * np.pi / 4) * np.sign(self.fresnel_numbers[dist, :])))
# chirp function
for dim in range(self._ndim):
conv_kernel *= np.exp(
conv_kernel[dist, ...] *= np.exp(
1j * np.pi * self.fresnel_numbers[dist, dim] * np.square(x[dim])
)
# 2. compute TF from convolution kernel
kernel = fftn(conv_kernel, axes=(-2, -1))
# 2. compute transfer function from convolution kernel
axes = list(range(-self._ndim, 0))
kernel = fftn(conv_kernel, axes=axes)
return kernel
class ASMPropagator(FTConvolutionPropagator):
"""
Blub
Angular Spectrum Method (ASM).
Goodman calls this the "Exact Transfer Function approach".
Voelz and Roggeman call this the "Rayleigh-Sommerfeld Transfer Function solution".
"""
_lambda0 = 1.0 # all lengths are in units of the wavelength
_k0 = 2 * np.pi / _lambda0
def __init__(self, shape, dperp, k, dz, npad=2):
def __init__(self, shape, dperp, dz, wl=1, npad=2, mask_evanescent=True):
super(ASMPropagator, self).__init__(shape, npad)
self._dperp = dperp
self._k = k
self._dz = dz
self._dperp = dperp # pixel size
self._wl = wl # relative wavelength
self.mask_evanescent = mask_evanescent
# calculate amount of padding
self._pad_width = (np.asarray(self._shape) * (self._npad - 1) / 2).astype(int)
dz = np.asarray(dz)
if np.ndim(dz) == 0:
self._dz = dz * np.ones((1,))
elif np.ndim(dz) == 1:
self._dz = dz
else:
raise ValueError("Parameter `dz` must either be scalar or 1d array.")
self._ndist = self._dz.shape[0]
self.kernel = self._init_kernel()
def _init_kernel(self):
# construct Rayleigh-Sommerfeld transfer function (Goodman: Fourier Optics, 4.2.3)
wl = self._k0 / self._k # relative wavelength
shape_pad = np.asarray(self._shape) + 2 * self._pad_width
f = fftfreqn(shape_pad, self._dperp) # spatial frequencies
f = fftfreqn(self._shape_pad, self._dperp) # spatial frequencies
f2 = squaresum(f)
mask = f2 < 1 / wl
phasechirp = np.sqrt(1 - np.square(wl) * (mask * f2))
kernel = mask * np.exp(1j * self._k * self._dz * phasechirp)
phase_chirp = np.sqrt(1 / np.square(self._wl) - f2.astype(np.complex))
mask = f2 < 1 / np.square(self._wl) if self.mask_evanescent else 1
# expand dimension of dz
dz = self._dz.reshape(
[
-1,
]
+ self._ndim
* [
1,
]
)
kernel = mask * np.exp(1j * self._k0 * dz * phase_chirp[np.newaxis, ...])
return kernel
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment