Commit fd346541 authored by Leon Merten Lohse's avatar Leon Merten Lohse
Browse files

use black for code formatting

parent 6fa01584
[flake8]
max-line-length = 100
ignore = E501
select = C,E,F,W,B,B9
...@@ -3,7 +3,7 @@ import numpy as np ...@@ -3,7 +3,7 @@ import numpy as np
from . import _solver from . import _solver
class Solver2d(): class Solver2d:
""" """
Solver for 2d PDEs of the form Solver for 2d PDEs of the form
Az(x) * u_z = -Axx(x) * u_xx + F(x,z) * u Az(x) * u_z = -Axx(x) * u_xx + F(x,z) * u
...@@ -30,10 +30,10 @@ class Solver2d(): ...@@ -30,10 +30,10 @@ class Solver2d():
return Az * self._ones return Az * self._ones
def _compute_rxx(self, Axx): def _compute_rxx(self, Axx):
return -Axx * self.dz / 2. / self.dx**2 * self._ones return -Axx * self.dz / 2.0 / self.dx ** 2 * self._ones
def _compute_f(self, F): def _compute_f(self, F):
return F * self.dz / 2. * self._ones return F * self.dz / 2.0 * self._ones
def step(self, F, boundary): def step(self, F, boundary):
...@@ -53,7 +53,7 @@ class Solver2d(): ...@@ -53,7 +53,7 @@ class Solver2d():
return self.u return self.u
class Solver2dfull(): class Solver2dfull:
""" """
Solver for 2d PDEs of the form Solver for 2d PDEs of the form
Az(x) * u_z = Axx(x) * u_xx + Ax(x) * u_x + F(x,z) * u Az(x) * u_z = Axx(x) * u_xx + Ax(x) * u_x + F(x,z) * u
...@@ -81,13 +81,13 @@ class Solver2dfull(): ...@@ -81,13 +81,13 @@ class Solver2dfull():
return Az * self._ones return Az * self._ones
def _compute_rxx(self, Axx): def _compute_rxx(self, Axx):
return -Axx * self.dz / 2. / self.dx**2 * self._ones return -Axx * self.dz / 2.0 / self.dx ** 2 * self._ones
def _compute_rx(self, Ax): def _compute_rx(self, Ax):
return -Ax * self.dz / 4. / self.dx * self._ones return -Ax * self.dz / 4.0 / self.dx * self._ones
def _compute_f(self, F): def _compute_f(self, F):
return F * self.dz / 2. * self._ones return F * self.dz / 2.0 * self._ones
def step(self, F, boundary): def step(self, F, boundary):
...@@ -102,13 +102,12 @@ class Solver2dfull(): ...@@ -102,13 +102,12 @@ class Solver2dfull():
u[0] = boundary[0] u[0] = boundary[0]
u[-1] = boundary[1] u[-1] = boundary[1]
self.u = _solver.step1d_AAF(self.rz, self.rxx, self.rx, self.u = _solver.step1d_AAF(self.rz, self.rxx, self.rx, fp, self.f, up, u)
fp, self.f, up, u)
return self.u return self.u
class Solver3d(): class Solver3d:
""" """
Solver for equations of the form Solver for equations of the form
Az * u_z = Axx * u_xx + Ayy * u_yy + F(x,y,z) * u Az * u_z = Axx * u_xx + Ayy * u_yy + F(x,y,z) * u
...@@ -122,7 +121,7 @@ class Solver3d(): ...@@ -122,7 +121,7 @@ class Solver3d():
self._nx = u0.shape[-1] self._nx = u0.shape[-1]
self._ny = u0.shape[-2] self._ny = u0.shape[-2]
self._ones = np.ones((self._ny, self._nx,), dtype=self.dtype) self._ones = np.ones(u0.shape, dtype=self.dtype)
self._boundary_slice = np.zeros_like(self._ones) self._boundary_slice = np.zeros_like(self._ones)
self.u = u0 * self._ones self.u = u0 * self._ones
...@@ -137,16 +136,16 @@ class Solver3d(): ...@@ -137,16 +136,16 @@ class Solver3d():
self.f = self._compute_f(F0) self.f = self._compute_f(F0)
def _compute_rz(self, Az): def _compute_rz(self, Az):
return Az * (1+0j) return Az * (1 + 0j)
def _compute_rxx(self, Axx): def _compute_rxx(self, Axx):
return -Axx * self.dz / 2. / self.dx**2 * (1+0j) return -Axx * self.dz / 2.0 / self.dx ** 2 * (1 + 0j)
def _compute_ryy(self, Ayy): def _compute_ryy(self, Ayy):
return -Ayy * self.dz / 2. / self.dy**2 * (1+0j) return -Ayy * self.dz / 2.0 / self.dy ** 2 * (1 + 0j)
def _compute_f(self, F): def _compute_f(self, F):
return F * self.dz / 4. * self._ones return F * self.dz / 4.0 * self._ones
def step(self, F, boundary): def step(self, F, boundary):
......
...@@ -3,7 +3,7 @@ import scipy.special ...@@ -3,7 +3,7 @@ import scipy.special
def hankelMatrix(N, n=0): def hankelMatrix(N, n=0):
''' """
returns a N x N matrix for discrete Hankel transfrom of n-th order. returns a N x N matrix for discrete Hankel transfrom of n-th order.
N: number of pixels N: number of pixels
...@@ -15,7 +15,7 @@ def hankelMatrix(N, n=0): ...@@ -15,7 +15,7 @@ def hankelMatrix(N, n=0):
As in: Theory and operational rules for the discreteHankel transform As in: Theory and operational rules for the discreteHankel transform
by Natalie Baddour* and Ugo Chouinard by Natalie Baddour* and Ugo Chouinard
https://doi.org/10.1364/JOSAA.32.000611 https://doi.org/10.1364/JOSAA.32.000611
''' """
jn = np.array(scipy.special.jn_zeros(n, N + 1)) jn = np.array(scipy.special.jn_zeros(n, N + 1))
...@@ -31,13 +31,13 @@ def hankelMatrix(N, n=0): ...@@ -31,13 +31,13 @@ def hankelMatrix(N, n=0):
def hankelFreq(N, n=0, kmax=0.5): def hankelFreq(N, n=0, kmax=0.5):
''' """
Returns the Hankel space (frequency) sampling grid for the inverse discrete Returns the Hankel space (frequency) sampling grid for the inverse discrete
Hankel transfrom (of order n) of a signal with N pixels. Hankel transfrom (of order n) of a signal with N pixels.
kmax is the maximum sampling frequency in dimensionless units, i.e. kmax is the maximum sampling frequency in dimensionless units, i.e.
minimal sampled realspace oscillation 2px -> max. sampled frequency 1/(2px) minimal sampled realspace oscillation 2px -> max. sampled frequency 1/(2px)
-> 0.5 dimensionless -> 0.5 dimensionless
''' """
jn = np.array(scipy.special.jn_zeros(n, N + 1)) jn = np.array(scipy.special.jn_zeros(n, N + 1))
...@@ -45,21 +45,20 @@ def hankelFreq(N, n=0, kmax=0.5): ...@@ -45,21 +45,20 @@ def hankelFreq(N, n=0, kmax=0.5):
def hankelSamples(N, n=0, kmax=0.5): def hankelSamples(N, n=0, kmax=0.5):
''' """
Returns the real space sampling grid for the forward discrete Hankel Returns the real space sampling grid for the forward discrete Hankel
transfrom (of order n) of a signal with N pixels. transfrom (of order n) of a signal with N pixels.
kmax is the maximum sampling frequency in dimensionless units, i.e. kmax is the maximum sampling frequency in dimensionless units, i.e.
minimal sampled realspace oscillation 2px -> max. sampled frequency 1/(2px) minimal sampled realspace oscillation 2px -> max. sampled frequency 1/(2px)
-> 0.5 dimensionless -> 0.5 dimensionless
''' """
jn = np.array(scipy.special.jn_zeros(n, N)) jn = np.array(scipy.special.jn_zeros(n, N))
return jn / (kmax*2*np.pi) return jn / (kmax * 2 * np.pi)
class DiscreteHankelTransform: class DiscreteHankelTransform:
def __init__(self, N, n=0, kmax=0.5): def __init__(self, N, n=0, kmax=0.5):
self._matrix = hankelMatrix(N, n) self._matrix = hankelMatrix(N, n)
......
...@@ -8,7 +8,7 @@ def fftfreqn(N, d=1.0): ...@@ -8,7 +8,7 @@ def fftfreqn(N, d=1.0):
d *= np.ones(ndim) d *= np.ones(ndim)
# index trick: n'th line is a list of ones with a -1 on n'th position # index trick: n'th line is a list of ones with a -1 on n'th position
shapes = np.ones((ndim, ndim), dtype='int') - 2 * np.eye(ndim, dtype='int') shapes = np.ones((ndim, ndim), dtype="int") - 2 * np.eye(ndim, dtype="int")
xi = [np.reshape(fftfreq(N[dim], d[dim]), tuple(shapes[dim, :])) for dim in range(ndim)] xi = [np.reshape(fftfreq(N[dim], d[dim]), tuple(shapes[dim, :])) for dim in range(ndim)]
...@@ -19,9 +19,9 @@ def gridn(N): ...@@ -19,9 +19,9 @@ def gridn(N):
ndim = len(N) ndim = len(N)
# index trick: n'th line is a list of ones with a -1 on n'th position # index trick: n'th line is a list of ones with a -1 on n'th position
shapes = np.ones((ndim, ndim), dtype='int') - 2 * np.eye(ndim, dtype='int') shapes = np.ones((ndim, ndim), dtype="int") - 2 * np.eye(ndim, dtype="int")
x = [np.reshape(np.arange(-N[dim]+1, N[dim]), tuple(shapes[dim, :])) for dim in range(ndim)] x = [np.reshape(np.arange(-N[dim] + 1, N[dim]), tuple(shapes[dim, :])) for dim in range(ndim)]
return x return x
......
...@@ -5,11 +5,12 @@ from pyfftw.interfaces.numpy_fft import fftn, ifftn, ifftshift, fftshift ...@@ -5,11 +5,12 @@ from pyfftw.interfaces.numpy_fft import fftn, ifftn, ifftshift, fftshift
from . import finite_differences as fd from . import finite_differences as fd
from .misc import fftfreqn from .misc import fftfreqn
_lambda0 = 1. # all lengths are in units of the wavelength _lambda0 = 1.0 # all lengths are in units of the wavelength
_k0 = 2 * np.pi / _lambda0 _k0 = 2 * np.pi / _lambda0
def _compute_potential(n_, k_=_k0): return k_*k_ * (1 - n_*n_) def _compute_potential(n_, k_=_k0):
return k_ * k_ * (1 - n_ * n_)
def squaresum(a): def squaresum(a):
...@@ -23,16 +24,16 @@ def rayleighSommerfeldTF(shape, dperp, k, dz): ...@@ -23,16 +24,16 @@ def rayleighSommerfeldTF(shape, dperp, k, dz):
f = fftfreqn(shape, dperp) # spatial frequencies f = fftfreqn(shape, dperp) # spatial frequencies
f2 = squaresum(f) f2 = squaresum(f)
mask = (f2 < 1 / wl) mask = f2 < 1 / wl
phasechirp = np.sqrt(1 - wl**2 * (mask * f2)) phasechirp = np.sqrt(1 - wl ** 2 * (mask * f2))
TF = mask * np.exp(1j * k * dz * phasechirp) TF = mask * np.exp(1j * k * dz * phasechirp)
return TF return TF
class MultislicePropagator(): class MultislicePropagator:
""" """
Multi-Slice approximation Multi-Slice approximation
Paganin, Coherent X-Ray Optics, p101 Paganin, Coherent X-Ray Optics, p101
...@@ -42,7 +43,7 @@ class MultislicePropagator(): ...@@ -42,7 +43,7 @@ class MultislicePropagator():
dtype = np.complex128 dtype = np.complex128
def __init__(self, u0, d, wl=1.): def __init__(self, u0, d, wl=1.0):
# ndim = len(d) # ndim = len(d)
...@@ -76,8 +77,7 @@ class MultislicePropagator(): ...@@ -76,8 +77,7 @@ class MultislicePropagator():
class FDPropagator2d(fd.Solver2d): class FDPropagator2d(fd.Solver2d):
def __init__(self, n0, u0, dz, dx, wl=1.0):
def __init__(self, n0, u0, dz, dx, wl=1.):
self._k = _k0 / wl self._k = _k0 / wl
...@@ -96,8 +96,7 @@ class FDPropagator2d(fd.Solver2d): ...@@ -96,8 +96,7 @@ class FDPropagator2d(fd.Solver2d):
class FDPropagatorCS(fd.Solver2dfull): class FDPropagatorCS(fd.Solver2dfull):
def __init__(self, n0, u0, dz, dx, wl=1.0):
def __init__(self, n0, u0, dz, dx, wl=1.):
nx = u0.shape[-1] nx = u0.shape[-1]
self._x = np.linspace(-nx * dx / 2, nx * dx / 2, nx) self._x = np.linspace(-nx * dx / 2, nx * dx / 2, nx)
...@@ -119,8 +118,7 @@ class FDPropagatorCS(fd.Solver2dfull): ...@@ -119,8 +118,7 @@ class FDPropagatorCS(fd.Solver2dfull):
class FDPropagator3d(fd.Solver3d): class FDPropagator3d(fd.Solver3d):
def __init__(self, n0, u0, dz, dy, dx, wl=1.0):
def __init__(self, n0, u0, dz, dy, dx, wl=1.):
self._k = _k0 / wl self._k = _k0 / wl
......
...@@ -9,13 +9,12 @@ from .misc import fftfreqn, gridn ...@@ -9,13 +9,12 @@ from .misc import fftfreqn, gridn
def fresnelKernelCS(N, fresnelNumber): def fresnelKernelCS(N, fresnelNumber):
hFreq = hankel.hankelFreq(N) hFreq = hankel.hankelFreq(N)
kern = np.exp(-1j * np.pi * hFreq**2 / fresnelNumber) kern = np.exp(-1j * np.pi * hFreq ** 2 / fresnelNumber)
return kern return kern
class FresnelPropagatorCS: class FresnelPropagatorCS:
def __init__(self, N, fresnelNumber): def __init__(self, N, fresnelNumber):
self._N = N self._N = N
...@@ -43,13 +42,12 @@ def fresnelTFKernel(shape, fresnelNumbers): ...@@ -43,13 +42,12 @@ def fresnelTFKernel(shape, fresnelNumbers):
f = fftfreqn(shape) f = fftfreqn(shape)
kernel = np.ones(shape, dtype=np.complex128) kernel = np.ones(shape, dtype=np.complex128)
for dim in range(ndim): for dim in range(ndim):
kernel *= np.exp((-1j * np.pi/(fresnelNumbers[dim])) * f[dim]**2) kernel *= np.exp((-1j * np.pi / (fresnelNumbers[dim])) * f[dim] ** 2)
return kernel return kernel
class FresnelTFPropagator: class FresnelTFPropagator:
def __init__(self, shape, fresnelNumbers): def __init__(self, shape, fresnelNumbers):
self._shape = np.array(shape) self._shape = np.array(shape)
...@@ -77,18 +75,19 @@ def fresnelIRKernel(shape, fresnelNumbers): ...@@ -77,18 +75,19 @@ def fresnelIRKernel(shape, fresnelNumbers):
kernel = np.ones(2 * np.array(shape, dtype=int) - 1, dtype=np.complex128) kernel = np.ones(2 * np.array(shape, dtype=int) - 1, dtype=np.complex128)
# phase factor # phase factor
kernel *= np.prod(np.sqrt(np.abs(fresnelNumbers)) * np.exp((-1j * np.pi / 4) * np.sign(fresnelNumbers))) kernel *= np.prod(
np.sqrt(np.abs(fresnelNumbers)) * np.exp((-1j * np.pi / 4) * np.sign(fresnelNumbers))
)
x = gridn(shape) x = gridn(shape)
for dim in range(ndim): for dim in range(ndim):
kernel *= np.exp(1j*np.pi*fresnelNumbers[dim] * x[dim]**2) kernel *= np.exp(1j * np.pi * fresnelNumbers[dim] * x[dim] ** 2)
return kernel return kernel
class FresnelIRPropagator: class FresnelIRPropagator:
def __init__(self, shape, fresnelNumbers): def __init__(self, shape, fresnelNumbers):
self._shape = np.array(shape) self._shape = np.array(shape)
...@@ -101,6 +100,6 @@ class FresnelIRPropagator: ...@@ -101,6 +100,6 @@ class FresnelIRPropagator:
def __call__(self, u): def __call__(self, u):
uprop = fftconvolve(u, self._kernel, mode='valid') uprop = fftconvolve(u, self._kernel, mode="valid")
return uprop return uprop
[build-system] [build-system]
requires = ["setuptools", "wheel", "scikit-build", "cmake", "ninja"] requires = ["setuptools", "wheel", "scikit-build", "cmake", "ninja"]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
[tool.black]
line-length = 100
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment