Commit 1236de9a authored by Leon Merten Lohse's avatar Leon Merten Lohse
Browse files

code style

parent 902d39a3
Pipeline #174171 passed with stage
in 41 seconds
...@@ -2,71 +2,68 @@ import numpy as np ...@@ -2,71 +2,68 @@ import numpy as np
import scipy.special import scipy.special
def hankelMatrix(N, n=0): def hankelMatrix(N, n=0):
''' '''
returns a N x N matrix for discrete Hankel transfrom of n-th order. returns a N x N matrix for discrete Hankel transfrom of n-th order.
N: number of pixels N: number of pixels
n: order of Hankel transform n: order of Hankel transform
The Hankel matrix is self-inverse! I.e. HH = Id The Hankel matrix is self-inverse! I.e. HH = Id
For forward and backward Hankel transfrom different prefactors have to be considered! For forward and backward Hankel transfrom different prefactors have to be considered!
As in: Theory and operational rules for the discreteHankel transform As in: Theory and operational rules for the discreteHankel transform
by Natalie Baddour* and Ugo Chouinard by Natalie Baddour* and Ugo Chouinard
https://doi.org/10.1364/JOSAA.32.000611 https://doi.org/10.1364/JOSAA.32.000611
''' '''
jn = np.array(scipy.special.jn_zeros(n, N + 1)) jn = np.array(scipy.special.jn_zeros(n, N + 1))
k = np.expand_dims(np.arange(N), axis=0) k = np.expand_dims(np.arange(N), axis=0)
m = np.expand_dims(np.arange(N), axis=1) m = np.expand_dims(np.arange(N), axis=1)
jN = jn[-1] jN = jn[-1]
Y = scipy.special.jn(n, jn[m] * jn[k] / jN) # matrix Y = scipy.special.jn(n, jn[m] * jn[k] / jN) # matrix
Y *= 2 / (jN * scipy.special.jn(n + 1, jn[k]) ** 2) # prefactor Y *= 2 / (jN * scipy.special.jn(n + 1, jn[k]) ** 2) # prefactor
return Y return Y
def hankelFreq(N, n=0, kmax=0.5): def hankelFreq(N, n=0, kmax=0.5):
''' '''
Returns the Hankel space (frequency) sampling grid for the inverse discrete Returns the Hankel space (frequency) sampling grid for the inverse discrete
Hankel transfrom (of order n) of a signal with N pixels. Hankel transfrom (of order n) of a signal with N pixels.
kmax is the maximum sampling frequency in dimensionless units, i.e. kmax is the maximum sampling frequency in dimensionless units, i.e.
minimal sampled realspace oscillation 2px -> max. sampled frequency 1/(2px) minimal sampled realspace oscillation 2px -> max. sampled frequency 1/(2px)
-> 0.5 dimensionless -> 0.5 dimensionless
''' '''
jn = np.array(scipy.special.jn_zeros(n, N + 1)) jn = np.array(scipy.special.jn_zeros(n, N + 1))
return jn[:-1] * kmax / jn[N] return jn[:-1] * kmax / jn[N]
def hankelSamples(N, n=0, kmax=0.5): def hankelSamples(N, n=0, kmax=0.5):
''' '''
Returns the real space sampling grid for the forward discrete Hankel Returns the real space sampling grid for the forward discrete Hankel
transfrom (of order n) of a signal with N pixels. transfrom (of order n) of a signal with N pixels.
kmax is the maximum sampling frequency in dimensionless units, i.e. kmax is the maximum sampling frequency in dimensionless units, i.e.
minimal sampled realspace oscillation 2px -> max. sampled frequency 1/(2px) minimal sampled realspace oscillation 2px -> max. sampled frequency 1/(2px)
-> 0.5 dimensionless -> 0.5 dimensionless
''' '''
jn = np.array(scipy.special.jn_zeros(n, N)) jn = np.array(scipy.special.jn_zeros(n, N))
return jn / (kmax*2*np.pi) return jn / (kmax*2*np.pi)
class DiscreteHankelTransform: class DiscreteHankelTransform:
def __init__(self, N, n=0, kmax=0.5): def __init__(self, N, n=0, kmax=0.5):
self._matrix = hankelMatrix(N, n) self._matrix = hankelMatrix(N, n)
def __call__(self, x):
return self._matrix @ x
def __call__(self, x):
return self._matrix @ x
from pyfftw.interfaces.numpy_fft import fft,fftn,fftfreq,ifftn,ifftshift,fftshift from pyfftw.interfaces.numpy_fft import fftn, ifftn, ifftshift, fftshift
import numpy as np import numpy as np
from functools import reduce from functools import reduce
...@@ -7,7 +7,10 @@ from .misc import fftfreqn ...@@ -7,7 +7,10 @@ from .misc import fftfreqn
_lambda0 = 1. # all lengths are in units of the wavelength _lambda0 = 1. # all lengths are in units of the wavelength
_k0 = 2 * np.pi / _lambda0 _k0 = 2 * np.pi / _lambda0
_compute_potential = lambda n_, k_=_k0 : k_*k_ * (1 - n_*n_)
def _compute_potential(n_, k_=_k0):
return k_*k_ * (1 - n_*n_)
def squaresum(a): def squaresum(a):
...@@ -16,17 +19,17 @@ def squaresum(a): ...@@ -16,17 +19,17 @@ def squaresum(a):
def rayleighSommerfeldTF(shape, dperp, k, dz): def rayleighSommerfeldTF(shape, dperp, k, dz):
# construct Rayleigh-Sommerfeld transfer function (Goodman: Fourier Optics, 4.2.3) # construct Rayleigh-Sommerfeld transfer function (Goodman: Fourier Optics, 4.2.3)
wl = _k0 / k # relative wavelength wl = _k0 / k # relative wavelength
f = fftfreqn(shape, dperp) # spatial frequencies f = fftfreqn(shape, dperp) # spatial frequencies
f2 = squaresum(f) f2 = squaresum(f)
mask = (f2 < 1 / wl) mask = (f2 < 1 / wl)
phasechirp = np.sqrt(1 - wl**2 * (mask * f2)) phasechirp = np.sqrt(1 - wl**2 * (mask * f2))
TF = mask * np.exp(1j * k * dz * phasechirp) TF = mask * np.exp(1j * k * dz * phasechirp)
return TF return TF
...@@ -34,17 +37,16 @@ class MultislicePropagator(): ...@@ -34,17 +37,16 @@ class MultislicePropagator():
""" """
Multi-Slice approximation Multi-Slice approximation
Paganin, Coherent X-Ray Optics, p101 Paganin, Coherent X-Ray Optics, p101
See Kenan Li, Michael Wojcik, and Chris Jacobsen 2017 See Kenan Li, Michael Wojcik, and Chris Jacobsen 2017
""" """
dtype = np.complex128 dtype = np.complex128
def __init__(self, u0, d, wl=1.): def __init__(self, u0, d, wl=1.):
ndim = len(d) # ndim = len(d)
# TODO: check input # TODO: check input
# assert u0 is array # assert u0 is array
# assert d is tuple # assert d is tuple
...@@ -53,42 +55,40 @@ class MultislicePropagator(): ...@@ -53,42 +55,40 @@ class MultislicePropagator():
self._dz = d[0] self._dz = d[0]
self._dperp = np.array(d[1:]) self._dperp = np.array(d[1:])
self._ones = np.ones(u0.shape, dtype = self.dtype) self._ones = np.ones(u0.shape, dtype=self.dtype)
self._k = _k0 / wl self._k = _k0 / wl
self._fourierKernel = rayleighSommerfeldTF(self._ones.shape, self._dperp, self._k, self._dz) self._fourierKernel = rayleighSommerfeldTF(self._ones.shape, self._dperp, self._k, self._dz)
self.u = ifftshift(u0) self.u = ifftshift(u0)
def step(self, n): def step(self, n):
nshift = ifftshift(n) nshift = ifftshift(n)
F = _compute_potential(nshift, self._k) * self._ones F = _compute_potential(nshift, self._k) * self._ones
self._realKernel = np.exp(-1j * self._dz / (2 * self._k) * F ) self._realKernel = np.exp(-1j * self._dz / (2 * self._k) * F)
up = self.u up = self.u
self.u = ifftn(self._fourierKernel * fftn(self._realKernel * up)) self.u = ifftn(self._fourierKernel * fftn(self._realKernel * up))
return fftshift(self.u) return fftshift(self.u)
class FDPropagator2d(fd.Solver2d): class FDPropagator2d(fd.Solver2d):
def __init__(self, n0, u0, dz, dx, wl=1.): def __init__(self, n0, u0, dz, dx, wl=1.):
self._k = _k0 / wl self._k = _k0 / wl
Az = 2j * self._k Az = 2j * self._k
Axx = 1 Axx = 1
F0 = _compute_potential(n0, self._k) F0 = _compute_potential(n0, self._k)
super().__init__(Az, Axx, F0, u0, dz, dx) super().__init__(Az, Axx, F0, u0, dz, dx)
def step(self, n, boundary): def step(self, n, boundary):
F = _compute_potential(n, self._k) F = _compute_potential(n, self._k)
...@@ -98,32 +98,31 @@ class FDPropagator2d(fd.Solver2d): ...@@ -98,32 +98,31 @@ class FDPropagator2d(fd.Solver2d):
class FDPropagatorCS(fd.Solver2dfull): class FDPropagatorCS(fd.Solver2dfull):
def __init__(self, n0, u0, dz, dx, wl=1.): def __init__(self, n0, u0, dz, dx, wl=1.):
nx = u0.shape[-1] nx = u0.shape[-1]
self._x = np.linspace(-nx * dx / 2, nx * dx / 2, nx) self._x = np.linspace(-nx * dx / 2, nx * dx / 2, nx)
self._k = _k0 / wl self._k = _k0 / wl
Az = 2j * self._k * self._x Az = 2j * self._k * self._x
Axx = self._x Axx = self._x
Ax = 1 Ax = 1
F0 = _compute_potential(n0, self._k) * self._x F0 = _compute_potential(n0, self._k) * self._x
super().__init__(Az, Axx, Ax, F0, u0, dz, dx) super().__init__(Az, Axx, Ax, F0, u0, dz, dx)
def step(self, n, boundary): def step(self, n, boundary):
F = _compute_potential(n, self._k) * self._x F = _compute_potential(n, self._k) * self._x
return super().step(F, boundary) return super().step(F, boundary)
class FDPropagator3d(fd.Solver3d): class FDPropagator3d(fd.Solver3d):
def __init__(self, n0, u0, dz, dy, dx, wl=1.): def __init__(self, n0, u0, dz, dy, dx, wl=1.):
self._k = _k0 / wl self._k = _k0 / wl
Az = 2j * self._k Az = 2j * self._k
...@@ -133,10 +132,8 @@ class FDPropagator3d(fd.Solver3d): ...@@ -133,10 +132,8 @@ class FDPropagator3d(fd.Solver3d):
super().__init__(Az, Axx, Ayy, F0, u0, dz, dy, dx) super().__init__(Az, Axx, Ayy, F0, u0, dz, dy, dx)
def step(self, n, boundary): def step(self, n, boundary):
F = _compute_potential(n, self._k) F = _compute_potential(n, self._k)
return super().step(F, boundary) return super().step(F, boundary)
...@@ -2,7 +2,7 @@ from . import hankel ...@@ -2,7 +2,7 @@ from . import hankel
from .misc import fftfreqn, gridn from .misc import fftfreqn, gridn
import numpy as np import numpy as np
from pyfftw.interfaces.numpy_fft import fft,fftn,fftfreq,ifftn,ifftshift,fftshift from pyfftw.interfaces.numpy_fft import fftn, ifftn
from scipy.signal import fftconvolve from scipy.signal import fftconvolve
_lambda0 = 1. # all lengths are in units of the wavelength _lambda0 = 1. # all lengths are in units of the wavelength
...@@ -12,30 +12,29 @@ _k0 = 2 * np.pi / _lambda0 ...@@ -12,30 +12,29 @@ _k0 = 2 * np.pi / _lambda0
def fresnelKernelCS(N, fresnelNumber): def fresnelKernelCS(N, fresnelNumber):
hFreq = hankel.hankelFreq(N) hFreq = hankel.hankelFreq(N)
kern = np.exp(-1j * np.pi * hFreq**2 / fresnelNumber ) kern = np.exp(-1j * np.pi * hFreq**2 / fresnelNumber)
return kern return kern
class FresnelPropagatorCS: class FresnelPropagatorCS:
def __init__(self, N, fresnelNumber): def __init__(self, N, fresnelNumber):
self._N = N self._N = N
Y = hankel.hankelMatrix(N, 0) Y = hankel.hankelMatrix(N, 0)
kern = fresnelKernelCS(N, fresnelNumber) kern = fresnelKernelCS(N, fresnelNumber)
# TODO: express this nicer (kern is a diagonal matrix) # TODO: express this nicer (kern is a diagonal matrix)
self._matrix = Y @ (kern[:,None] * Y) self._matrix = Y @ (kern[:, None] * Y)
def __call__(self, u): def __call__(self, u):
uprop = self._matrix @ u uprop = self._matrix @ u
return uprop return uprop
def fresnelTFKernel(shape, fresnelNumbers): def fresnelTFKernel(shape, fresnelNumbers):
...@@ -45,7 +44,7 @@ def fresnelTFKernel(shape, fresnelNumbers): ...@@ -45,7 +44,7 @@ def fresnelTFKernel(shape, fresnelNumbers):
fresnelNumbers = fresnelNumbers * np.ones(ndim) fresnelNumbers = fresnelNumbers * np.ones(ndim)
f = fftfreqn(shape) f = fftfreqn(shape)
kernel = np.ones(shape, dtype=np.complex128) kernel = np.ones(shape, dtype=np.complex128)
for dim in range(ndim): for dim in range(ndim):
kernel *= np.exp((-1j * np.pi/(fresnelNumbers[dim])) * f[dim]**2) kernel *= np.exp((-1j * np.pi/(fresnelNumbers[dim])) * f[dim]**2)
...@@ -55,22 +54,20 @@ def fresnelTFKernel(shape, fresnelNumbers): ...@@ -55,22 +54,20 @@ def fresnelTFKernel(shape, fresnelNumbers):
class FresnelTFPropagator: class FresnelTFPropagator:
def __init__(self, shape, fresnelNumbers): def __init__(self, shape, fresnelNumbers):
self._shape = np.array(shape) self._shape = np.array(shape)
self._ndim = len(self._shape) self._ndim = len(self._shape)
if np.isscalar(fresnelNumbers): if np.isscalar(fresnelNumbers):
fresnelNumbers = fresnelNumbers * np.ones(self._ndim) fresnelNumbers = fresnelNumbers * np.ones(self._ndim)
self._kernel = fresnelTFKernel(shape, fresnelNumbers) self._kernel = fresnelTFKernel(shape, fresnelNumbers)
def __call__(self, u): def __call__(self, u):
uprop = ifftn( self._kernel * fftn(u))
return uprop uprop = ifftn(self._kernel * fftn(u))
return uprop
def fresnelIRKernel(shape, fresnelNumbers): def fresnelIRKernel(shape, fresnelNumbers):
...@@ -81,34 +78,32 @@ def fresnelIRKernel(shape, fresnelNumbers): ...@@ -81,34 +78,32 @@ def fresnelIRKernel(shape, fresnelNumbers):
fresnelNumbers = fresnelNumbers * np.ones(ndim) fresnelNumbers = fresnelNumbers * np.ones(ndim)
kernel = np.ones(2 * np.array(shape, dtype=int) - 1, dtype=np.complex128) kernel = np.ones(2 * np.array(shape, dtype=int) - 1, dtype=np.complex128)
# phase factor # phase factor
kernel *= np.prod(np.sqrt(np.abs(fresnelNumbers)) * np.exp((-1j * np.pi / 4) * np.sign(fresnelNumbers)) ) kernel *= np.prod(np.sqrt(np.abs(fresnelNumbers)) * np.exp((-1j * np.pi / 4) * np.sign(fresnelNumbers)))
x = gridn(shape) x = gridn(shape)
for dim in range(ndim): for dim in range(ndim):
xdim = x[dim] # / shape[dim] kernel *= np.exp(1j*np.pi*fresnelNumbers[dim] * x[dim]**2)
kernel *= np.exp(1j*np.pi*fresnelNumbers[dim] * xdim**2)
return kernel return kernel
class FresnelIRPropagator: class FresnelIRPropagator:
def __init__(self, shape, fresnelNumbers): def __init__(self, shape, fresnelNumbers):
self._shape = np.array(shape) self._shape = np.array(shape)
self._ndim = len(self._shape) self._ndim = len(self._shape)
if np.isscalar(fresnelNumbers): if np.isscalar(fresnelNumbers):
fresnelNumbers = fresnelNumbers * np.ones(self._ndim) fresnelNumbers = fresnelNumbers * np.ones(self._ndim)
self._kernel = fresnelIRKernel(shape, fresnelNumbers) self._kernel = fresnelIRKernel(shape, fresnelNumbers)
def __call__(self, u): def __call__(self, u):
uprop = fftconvolve(u, self._kernel, mode='valid') uprop = fftconvolve(u, self._kernel, mode='valid')
return uprop return uprop
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment