Skip to content
Snippets Groups Projects

Little changes to Padder

Merged Jens Lucht requested to merge fix-padder-doc into master
1 file changed: +56 −46
Compare changes
  • Side-by-side
  • Inline
+56 −46
@@ -31,20 +31,40 @@ def pad_width_to_torch(pad_width, ndim):
return tuple(chain(*reversed(pad_pairs)))
class IdentityPadder:
def __init__(self, imshape):
self.padded_shape = imshape
class _PadderBase:
def __init__(self, imshape, pad_width, mode="constant", value=0):
self.imshape = tuple(imshape)
self.imdim = len(imshape)
self.pad_width = pad_width_as_pairs(pad_width, self.imdim)
self.crop_slice = (...,) + tuple(slice(pl, -pr or None) for pl, pr in self.pad_width)
def __call__(self, array):
return array
pass
def crop(self, array):
return array
return array[self.crop_slice]
@property
def padded_shape(self):
"""Shape of padded image."""
return tuple(s + pl + pr for s, (pl, pr) in zip(self.imshape, self.pad_width))
inv = crop
"""Alias for crop."""
class IdentityPadder(_PadderBase):
def __init__(self, imshape, pad_width, **kwargs):
super().__init__(imshape, 0, **kwargs) # pad_width is fixed as 0
def __call__(self, array):
return array
class Padder:
def crop(self, array):
return array
class Padder(_PadderBase):
"""
Pad numpy arrays and torch tensors, but also have access to invert the padding (crop).
Parameter input aligns with numpy.pad and is translated for torch calls.
@@ -61,25 +81,25 @@ class Padder:
combines "stat_length" "constant_values", "end_values" parameters of numpy.pad.
value for constant padding for tensors.
Example:
--------
from hotopy.utils import Padder
import numpy as np
import torch
print(arr := np.arange(9).reshape(3, 3))
p = Padder(arr.shape, 1)
print(f"{p.padded_shape = }")
print(padded := p(arr))
print(unpadded := p.inv(padded))
print(arr := torch.as_tensor(arr))
p = Padder(arr.shape, ((2, 0), (3, 1)), mode="edge")
print(padded := p(arr))
print(unpadded := p.inv(padded))
Example
-------
>>> import numpy as np
>>> import torch
>>> from hotopy.utils import Padder
>>> print(arr := np.arange(9).reshape(3, 3))
>>> p = Padder(arr.shape, 1)
>>> print(f"{p.padded_shape = }")
>>> print(padded := p(arr))
>>> print(unpadded := p.inv(padded))
>>> print(arr := torch.as_tensor(arr))
>>> p = Padder(arr.shape, ((2, 0), (3, 1)), mode="edge")
>>> print(padded := p(arr))
>>> print(unpadded := p.inv(padded))
"""
torch_padmode_from_numpy = {
_torch_padmode_from_numpy = {
"edge": "replicate",
"wrap": "circular",
}
@@ -91,9 +111,7 @@ class Padder:
return IdentityPadder(imshape)
def __init__(self, imshape, pad_width, mode="constant", value=0):
self.imshape = imshape
self.imdim = len(imshape)
self.pad_width = pad_width_as_pairs(pad_width, self.imdim)
super().__init__(imshape, pad_width, mode=mode, value=value)
# build arguments for numpy pad calls
self.np_args = {"mode": mode}
@@ -107,31 +125,23 @@ class Padder:
# build arguments for torch pad calls
self.torch_args = {
"mode": self.torch_padmode_from_numpy.get(mode, mode),
"mode": self._torch_padmode_from_numpy.get(mode, mode),
"value": value,
}
self.pad_width_torch = pad_width_to_torch(self.pad_width, self.imdim)
self.crop_slice = (...,) + tuple(slice(pl, -pr or None) for pl, pr in self.pad_width)
def __call__(self, array):
match array:
case np.ndarray():
pad_width = [[0, 0]] * (array.ndim - self.imdim) + self.pad_width
return np.pad(array, pad_width, **self.np_args)
case torch.Tensor():
# batch dimension is needed for non constant padding in torch
return torch.nn.functional.pad(
array[None], self.pad_width_torch, **self.torch_args
)[0]
case _:
raise ValueError("Padder called with incompatible array type: %s", type(array))
array_class = type(array)
def crop(self, array):
return array[self.crop_slice]
if array_class is torch.Tensor:
# batch dimension is needed for non constant padding in torch
padded = torch.nn.functional.pad(array[None], self.pad_width_torch, **self.torch_args)
padded = padded[0] # remove auxiliary batch dimension
else:
if array_class is not np.ndarray:
array = np.asarray(array)
@property
def padded_shape(self):
return tuple(s + pl + pr for s, (pl, pr) in zip(self.imshape, self.pad_width))
pad_width = [[0, 0]] * (array.ndim - self.imdim) + self.pad_width
padded = np.pad(array, pad_width, **self.np_args)
inv = crop
"""Alias for inverse padding (aka cropping)."""
return padded
Loading