Skip to content
Snippets Groups Projects

Make Padder work for torch tensors

Merged Paul Meyer requested to merge extend_padder into master
Files
3
+ 25
17
@@ -6,7 +6,7 @@ from scipy.signal import get_window
from scipy.fft import fft2, ifft2, rfft2, irfft2
from scipy.ndimage import fourier_gaussian
from ..utils import rfftshape
from ..utils import rfftshape, Padder
def _broadcast_to_dim(i):
@@ -80,15 +80,14 @@ class FourierFilter:
def kernel(self, val):
self._kernel = as_tensor(val, device=self.device, dtype=self.dtype)
def __init__(self, shape, *args, real=True, norm=None, dtype=None, device=None):
self.shape = tuple(
shape
) # image shape (note: x can be stacked images of in shape ``(n, *shape)``)
def __init__(self, shape, *args, real=True, norm=None, dtype=None, device=None, pad=0):
self.real = real # use rfft?
self.norm = norm # fft normalization
self.dtype = dtype
self.device = device
self._kernel = None
self.padder = Padder(shape, pad, mode="edge")
self.shape = self.padder.padded_shape
if real:
self.kernel_shape = rfftshape(self.shape)
@@ -134,9 +133,11 @@ class FourierFilter:
Filtered image as Tensor. If NumPy ndarray is required, call ``.numpy()``.
"""
x = as_tensor(x, device=self.device)
x = self.padder(x)
X = self._fft(x)
Y = self.apply_filter(X)
return self._ifft(Y)
y = self._ifft(Y)
return self.padder.crop(y)
class GaussianBlur(FourierFilter):
@@ -158,7 +159,7 @@ class GaussianBandpass(FourierFilter):
Parameters
----------
shape: tuple
Shape of image(s) to be filtered.
Shape of image(s) to be filtered. Without padding and without (optional) batch dimension.
sigma_low: float, tuple
Standard deviation of Gaussian used for kept lower frequencies. Supports standard deviation per axis given as
tuple. See ``scipy.ndimage.fourier_gaussian`` for details.
@@ -173,21 +174,28 @@ class GaussianBandpass(FourierFilter):
Input signal/image is real-valued (Default ``True``). If signal is complex, set ``False``.
norm: None, str
Normalization of FFT. Choices are ``"backward"`` or ``None``, i.e. the default, ``"forward"``, or ``"ortho"``
dtype: torch.dtype
datatype for the filter kernel.
pad: int, tuple (optional)
padding to be applied before applying the filter.
accepts arguments like ``pad_width`` of ``numpy.pad``. Defaults to 0.
Example
-------
>>> from hotopy.image import GaussianBandpass
>>> from scipy.datasets import ascent
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> img = ascent()
>>> lowpass = GaussianBandpass(img.shape, 10, None)
>>> highpass = GaussianBandpass(img.shape, None, 40)
>>> bandpass = GaussianBandpass(img.shape, 10, 40)
>>> img_filtered_sequential = lowpass(highpass(img)).numpy()
>>> img_filtered_bandpass = bandpass(img).numpy()
>>> np.allclose(img_filtered_bandpass, img_filtered_sequential)
True
>>> image_stack = np.random.random((3, 64, 64)) + np.linspace(0, 2, 64)[None,:,None]
>>> sigma = 1
>>> pad = int(4 * sigma + 1)
>>> # pad = 0 # comparison without padding
>>> lowpass = GaussianBandpass(image_stack.shape[-2:], sigma, None, pad=pad)
>>> highpass = GaussianBandpass(image_stack.shape[-2:], None, 2 * sigma, pad=pad)
>>> bandpass = GaussianBandpass(image_stack.shape[-2:], sigma, 2 * sigma, pad=pad)
>>> plt.subplot(221); plt.imshow(image_stack[0])
>>> plt.subplot(222); plt.imshow(lowpass(image_stack)[0])
>>> plt.subplot(223); plt.imshow(highpass(image_stack)[0])
>>> plt.subplot(224); plt.imshow(bandpass(image_stack)[0])
"""
def _init_kernel(self, sigma_low, sigma_high):
Loading