Commit fd346541 authored by Leon Merten Lohse's avatar Leon Merten Lohse
Browse files

use black for code formatting

parent 6fa01584
[flake8]
max-line-length = 100
ignore = E501
select = C,E,F,W,B,B9
......@@ -3,7 +3,7 @@ import numpy as np
from . import _solver
class Solver2d():
class Solver2d:
"""
Solver for 2d PDEs of the form
Az(x) * u_z = -Axx(x) * u_xx + F(x,z) * u
......@@ -30,10 +30,10 @@ class Solver2d():
return Az * self._ones
def _compute_rxx(self, Axx):
return -Axx * self.dz / 2. / self.dx**2 * self._ones
return -Axx * self.dz / 2.0 / self.dx ** 2 * self._ones
def _compute_f(self, F):
return F * self.dz / 2. * self._ones
return F * self.dz / 2.0 * self._ones
def step(self, F, boundary):
......@@ -53,7 +53,7 @@ class Solver2d():
return self.u
class Solver2dfull():
class Solver2dfull:
"""
Solver for 2d PDEs of the form
Az(x) * u_z = Axx(x) * u_xx + Ax(x) * u_x + F(x,z) * u
......@@ -81,13 +81,13 @@ class Solver2dfull():
return Az * self._ones
def _compute_rxx(self, Axx):
return -Axx * self.dz / 2. / self.dx**2 * self._ones
return -Axx * self.dz / 2.0 / self.dx ** 2 * self._ones
def _compute_rx(self, Ax):
return -Ax * self.dz / 4. / self.dx * self._ones
return -Ax * self.dz / 4.0 / self.dx * self._ones
def _compute_f(self, F):
return F * self.dz / 2. * self._ones
return F * self.dz / 2.0 * self._ones
def step(self, F, boundary):
......@@ -102,13 +102,12 @@ class Solver2dfull():
u[0] = boundary[0]
u[-1] = boundary[1]
self.u = _solver.step1d_AAF(self.rz, self.rxx, self.rx,
fp, self.f, up, u)
self.u = _solver.step1d_AAF(self.rz, self.rxx, self.rx, fp, self.f, up, u)
return self.u
class Solver3d():
class Solver3d:
"""
Solver for equations of the form
Az * u_z = Axx * u_xx + Ayy * u_yy + F(x,y,z) * u
......@@ -122,7 +121,7 @@ class Solver3d():
self._nx = u0.shape[-1]
self._ny = u0.shape[-2]
self._ones = np.ones((self._ny, self._nx,), dtype=self.dtype)
self._ones = np.ones(u0.shape, dtype=self.dtype)
self._boundary_slice = np.zeros_like(self._ones)
self.u = u0 * self._ones
......@@ -137,16 +136,16 @@ class Solver3d():
self.f = self._compute_f(F0)
def _compute_rz(self, Az):
return Az * (1+0j)
return Az * (1 + 0j)
def _compute_rxx(self, Axx):
return -Axx * self.dz / 2. / self.dx**2 * (1+0j)
return -Axx * self.dz / 2.0 / self.dx ** 2 * (1 + 0j)
def _compute_ryy(self, Ayy):
return -Ayy * self.dz / 2. / self.dy**2 * (1+0j)
return -Ayy * self.dz / 2.0 / self.dy ** 2 * (1 + 0j)
def _compute_f(self, F):
return F * self.dz / 4. * self._ones
return F * self.dz / 4.0 * self._ones
def step(self, F, boundary):
......
......@@ -3,7 +3,7 @@ import scipy.special
def hankelMatrix(N, n=0):
'''
"""
returns a N x N matrix for discrete Hankel transfrom of n-th order.
N: number of pixels
......@@ -15,7 +15,7 @@ def hankelMatrix(N, n=0):
As in: Theory and operational rules for the discreteHankel transform
by Natalie Baddour* and Ugo Chouinard
https://doi.org/10.1364/JOSAA.32.000611
'''
"""
jn = np.array(scipy.special.jn_zeros(n, N + 1))
......@@ -31,13 +31,13 @@ def hankelMatrix(N, n=0):
def hankelFreq(N, n=0, kmax=0.5):
'''
"""
Returns the Hankel space (frequency) sampling grid for the inverse discrete
Hankel transfrom (of order n) of a signal with N pixels.
kmax is the maximum sampling frequency in dimensionless units, i.e.
minimal sampled realspace oscillation 2px -> max. sampled frequency 1/(2px)
-> 0.5 dimensionless
'''
"""
jn = np.array(scipy.special.jn_zeros(n, N + 1))
......@@ -45,21 +45,20 @@ def hankelFreq(N, n=0, kmax=0.5):
def hankelSamples(N, n=0, kmax=0.5):
'''
"""
Returns the real space sampling grid for the forward discrete Hankel
transfrom (of order n) of a signal with N pixels.
kmax is the maximum sampling frequency in dimensionless units, i.e.
minimal sampled realspace oscillation 2px -> max. sampled frequency 1/(2px)
-> 0.5 dimensionless
'''
"""
jn = np.array(scipy.special.jn_zeros(n, N))
return jn / (kmax*2*np.pi)
return jn / (kmax * 2 * np.pi)
class DiscreteHankelTransform:
def __init__(self, N, n=0, kmax=0.5):
self._matrix = hankelMatrix(N, n)
......
......@@ -8,7 +8,7 @@ def fftfreqn(N, d=1.0):
d *= np.ones(ndim)
# index trick: n'th line is a list of ones with a -1 on n'th position
shapes = np.ones((ndim, ndim), dtype='int') - 2 * np.eye(ndim, dtype='int')
shapes = np.ones((ndim, ndim), dtype="int") - 2 * np.eye(ndim, dtype="int")
xi = [np.reshape(fftfreq(N[dim], d[dim]), tuple(shapes[dim, :])) for dim in range(ndim)]
......@@ -19,9 +19,9 @@ def gridn(N):
ndim = len(N)
# index trick: n'th line is a list of ones with a -1 on n'th position
shapes = np.ones((ndim, ndim), dtype='int') - 2 * np.eye(ndim, dtype='int')
shapes = np.ones((ndim, ndim), dtype="int") - 2 * np.eye(ndim, dtype="int")
x = [np.reshape(np.arange(-N[dim]+1, N[dim]), tuple(shapes[dim, :])) for dim in range(ndim)]
x = [np.reshape(np.arange(-N[dim] + 1, N[dim]), tuple(shapes[dim, :])) for dim in range(ndim)]
return x
......
......@@ -5,11 +5,12 @@ from pyfftw.interfaces.numpy_fft import fftn, ifftn, ifftshift, fftshift
from . import finite_differences as fd
from .misc import fftfreqn
_lambda0 = 1. # all lengths are in units of the wavelength
_lambda0 = 1.0 # all lengths are in units of the wavelength
_k0 = 2 * np.pi / _lambda0
def _compute_potential(n_, k_=_k0): return k_*k_ * (1 - n_*n_)
def _compute_potential(n_, k_=_k0):
return k_ * k_ * (1 - n_ * n_)
def squaresum(a):
......@@ -23,16 +24,16 @@ def rayleighSommerfeldTF(shape, dperp, k, dz):
f = fftfreqn(shape, dperp) # spatial frequencies
f2 = squaresum(f)
mask = (f2 < 1 / wl)
mask = f2 < 1 / wl
phasechirp = np.sqrt(1 - wl**2 * (mask * f2))
phasechirp = np.sqrt(1 - wl ** 2 * (mask * f2))
TF = mask * np.exp(1j * k * dz * phasechirp)
return TF
class MultislicePropagator():
class MultislicePropagator:
"""
Multi-Slice approximation
Paganin, Coherent X-Ray Optics, p101
......@@ -42,7 +43,7 @@ class MultislicePropagator():
dtype = np.complex128
def __init__(self, u0, d, wl=1.):
def __init__(self, u0, d, wl=1.0):
# ndim = len(d)
......@@ -76,8 +77,7 @@ class MultislicePropagator():
class FDPropagator2d(fd.Solver2d):
def __init__(self, n0, u0, dz, dx, wl=1.):
def __init__(self, n0, u0, dz, dx, wl=1.0):
self._k = _k0 / wl
......@@ -96,8 +96,7 @@ class FDPropagator2d(fd.Solver2d):
class FDPropagatorCS(fd.Solver2dfull):
def __init__(self, n0, u0, dz, dx, wl=1.):
def __init__(self, n0, u0, dz, dx, wl=1.0):
nx = u0.shape[-1]
self._x = np.linspace(-nx * dx / 2, nx * dx / 2, nx)
......@@ -119,8 +118,7 @@ class FDPropagatorCS(fd.Solver2dfull):
class FDPropagator3d(fd.Solver3d):
def __init__(self, n0, u0, dz, dy, dx, wl=1.):
def __init__(self, n0, u0, dz, dy, dx, wl=1.0):
self._k = _k0 / wl
......
......@@ -9,13 +9,12 @@ from .misc import fftfreqn, gridn
def fresnelKernelCS(N, fresnelNumber):
hFreq = hankel.hankelFreq(N)
kern = np.exp(-1j * np.pi * hFreq**2 / fresnelNumber)
kern = np.exp(-1j * np.pi * hFreq ** 2 / fresnelNumber)
return kern
class FresnelPropagatorCS:
def __init__(self, N, fresnelNumber):
self._N = N
......@@ -43,13 +42,12 @@ def fresnelTFKernel(shape, fresnelNumbers):
f = fftfreqn(shape)
kernel = np.ones(shape, dtype=np.complex128)
for dim in range(ndim):
kernel *= np.exp((-1j * np.pi/(fresnelNumbers[dim])) * f[dim]**2)
kernel *= np.exp((-1j * np.pi / (fresnelNumbers[dim])) * f[dim] ** 2)
return kernel
class FresnelTFPropagator:
def __init__(self, shape, fresnelNumbers):
self._shape = np.array(shape)
......@@ -77,18 +75,19 @@ def fresnelIRKernel(shape, fresnelNumbers):
kernel = np.ones(2 * np.array(shape, dtype=int) - 1, dtype=np.complex128)
# phase factor
kernel *= np.prod(np.sqrt(np.abs(fresnelNumbers)) * np.exp((-1j * np.pi / 4) * np.sign(fresnelNumbers)))
kernel *= np.prod(
np.sqrt(np.abs(fresnelNumbers)) * np.exp((-1j * np.pi / 4) * np.sign(fresnelNumbers))
)
x = gridn(shape)
for dim in range(ndim):
kernel *= np.exp(1j*np.pi*fresnelNumbers[dim] * x[dim]**2)
kernel *= np.exp(1j * np.pi * fresnelNumbers[dim] * x[dim] ** 2)
return kernel
class FresnelIRPropagator:
def __init__(self, shape, fresnelNumbers):
self._shape = np.array(shape)
......@@ -101,6 +100,6 @@ class FresnelIRPropagator:
def __call__(self, u):
uprop = fftconvolve(u, self._kernel, mode='valid')
uprop = fftconvolve(u, self._kernel, mode="valid")
return uprop
[build-system]
requires = ["setuptools", "wheel", "scikit-build", "cmake", "ninja"]
build-backend = "setuptools.build_meta"
[tool.black]
line-length = 100
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment