interpolation.py 4.5 KB
Newer Older
1
2
3
from typing import Sequence, Union

import numpy as np

from .array_tools import shifted_fft, shifted_ifft, roll_to_pos

Matthijs's avatar
Matthijs committed
4
# Names exported by a star-import of this module: only the dispatching entry point is listed.
__all__ = ['fourier_interpolate']
5

Matthijs's avatar
Matthijs committed
6

7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def fourier_interpolate(arr: np.ndarray, factor: Union[float, Sequence[float]] = 2., **kwargs) -> np.ndarray:
    """
    Interpolation by padding in the Fourier domain.

    Dispatches on ``factor``: a sized ``factor`` (e.g. a tuple with 2 or 3 entries) selects the per-axis
    "nonsquare" routine, a scalar ``factor`` selects the uniform 2d or 3d routine based on ``arr.ndim``.
    Keyword arguments get passed to the selected routine. Options are for example:
    - center_data: whether to move the center of mass to the center of the array before interpolation,
    this is generally a good idea

    NOTE: testing by simple 2x interpolation and subsequent 2x2 binning shows that errors are introduced in the procedure

    :param arr: numpy array, 2 or 3-dimensional
    :param factor: tuple or float to indicate by which factor to interpolate, default to factor 2
    :returns: interpolated array, cast to identical dtype as arr (for real input the imaginary part is
        dropped; the cast uses ``casting='unsafe'`` and no rounding is applied beforehand)
    :raises ValueError: if a sized factor does not have 2 or 3 entries, or its length differs from ``arr.ndim``
    :raises NotImplementedError: if a scalar factor is combined with an array that is not 2d or 3d
    """
    # Only the len() probe is guarded: a TypeError raised *inside* one of the interpolation
    # routines (e.g. an unsupported keyword argument) must reach the caller unchanged instead of
    # being mistaken for "factor is a scalar".
    try:
        n_factors = len(factor)
    except TypeError:
        n_factors = None

    if n_factors is None:
        # Scalar factor: same factor along every axis.
        if arr.ndim == 2:
            out = fourier_interpolate_2d(arr, factor=factor, **kwargs)
        elif arr.ndim == 3:
            out = fourier_interpolate_3d(arr, factor=factor, **kwargs)
        else:
            raise NotImplementedError('Can only interpolate 2d or 3d arrays as of now')
    else:
        # Sized factor: one entry per axis.
        if n_factors not in (2, 3):
            raise ValueError(f"factor must be a scalar or have 2 or 3 entries, got {n_factors}")
        if arr.ndim != n_factors:
            raise ValueError(f"factor has {n_factors} entries but arr has {arr.ndim} dimensions")
        if n_factors == 2:
            out = fourier_interpolate_2d_nonsquare(arr, factor=factor, **kwargs)
        else:
            out = fourier_interpolate_3d_nonsquare(arr, factor=factor, **kwargs)

    # Real input is returned as real data again; complex input keeps its complex values.
    if np.isrealobj(arr):
        return out.real.astype(arr.dtype, casting='unsafe')
    return out.astype(arr.dtype, casting='unsafe')
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112


def fourier_interpolate_2d_nonsquare(arr: np.ndarray, factor: tuple = (1., 2.),
                                     center_data: bool = False) -> np.ndarray:
    """
    Interpolate 2d array (can be complex-valued, should be well-sampled) with one factor per axis.

    The shifted Fourier transform of the data is zero-padded symmetrically along each axis and
    transformed back.

    :param arr: numpy array of rectangular size
    :param factor: tuple (factor_y, factor_x) of interpolation factors
        (2 for doubling of the number of sampling points)
    :param center_data: if True, the data is re-positioned with roll_to_pos before the transform
    :return: interpolated array (result of shifted_ifft, not cast back to the dtype of arr)
    """
    ny, nx = arr.shape
    # Padding on each side of an axis; int() truncates, so the output length is n + 2 * pad.
    pdy = int((ny / 2) * (factor[0] - 1))
    pdx = int((nx / 2) * (factor[1] - 1))

    data = arr
    if center_data:
        # Two passes towards the array centre: first with move_maximum=True, then with the
        # roll_to_pos defaults (presumably maximum first, then centre of mass -- see array_tools).
        data = roll_to_pos(arr, ny // 2, nx // 2, move_maximum=True)
        data = roll_to_pos(data, ny // 2, nx // 2)

    # Transform the (optionally centred) data; previously the centred copy was computed but ignored.
    fd = shifted_fft(data)
    padded = np.pad(fd, ((pdy, pdy), (pdx, pdx)), mode='constant')
    return shifted_ifft(padded)


def fourier_interpolate_3d_nonsquare(arr: np.ndarray, factor: tuple = (1., 2., 3.),
                                     center_data: bool = False) -> np.ndarray:
    """
    Interpolate 3d data (can be complex-valued, should be oversampled) with one factor per axis.

    The shifted Fourier transform of the data is zero-padded symmetrically along each axis and
    transformed back.

    :param arr: dataset to be interpolated.
    :param factor: tuple with one interpolation factor per axis, in the axis order of arr
        (2 for doubling of the number of sampling points)
    :param center_data: not implemented: accepted for signature parity with the other
        interpolation routines so the keyword can be passed through uniformly
    :return: interpolated data (result of shifted_ifft, not cast back to the dtype of arr)
    :raises NotImplementedError: if center_data is True
    """
    ny, nx, nz = arr.shape
    if center_data:
        raise NotImplementedError
    # Padding on each side of an axis; int() truncates, so the output length is n + 2 * pad.
    pads = tuple(
        (int((n / 2) * (f - 1)),) * 2
        for n, f in zip((ny, nx, nz), (factor[0], factor[1], factor[2]))
    )
    fd = shifted_fft(arr)
    padded = np.pad(fd, pads, mode='constant')
    return shifted_ifft(padded)


def fourier_interpolate_2d(arr: np.ndarray, factor: float = 2., center_data=False) -> np.ndarray:
    """
    Interpolate a square 2d array (can be complex-valued, should be oversampled).

    The shifted Fourier transform of the data is zero-padded by the same amount on every side
    and transformed back.

    :param arr: numpy array of square size (ny == nx)
    :param factor: interpolation factor (2 for doubling of the number of sampling points)
    :param center_data: if True, the data is re-positioned with roll_to_pos before the transform
    :return: interpolated array (result of shifted_ifft, not cast back to the dtype of arr)
    :raises AssertionError: if the array is not square
    """
    ny, nx = arr.shape
    # A single pad width is used for both axes, which is only meaningful for square input.
    assert ny == nx, "Accepts only square arrays currently"
    # Padding on each side; int() truncates, so the output length is n + 2 * pad.
    pad = int((ny / 2) * (factor - 1))

    data = arr
    if center_data:
        # Two passes towards the array centre: first with move_maximum=True, then with the
        # roll_to_pos defaults (presumably maximum first, then centre of mass -- see array_tools).
        data = roll_to_pos(arr, ny // 2, nx // 2, move_maximum=True)
        data = roll_to_pos(data, ny // 2, nx // 2)

    # Transform the (optionally centred) data; previously the centred copy was computed but ignored.
    fd = shifted_fft(data)
    padded = np.pad(fd, ((pad, pad), (pad, pad)), mode='constant')
    return shifted_ifft(padded)


def fourier_interpolate_3d(arr: np.ndarray, center_data: bool = False, factor: float = 2.) -> np.ndarray:
    """
    Interpolate 3d data (can be complex-valued, should be oversampled).

    The shifted Fourier transform of the data is zero-padded symmetrically along each axis and
    transformed back. The pad width is derived per axis, so every axis is scaled by ``factor``
    even when the volume is not cubic (for cubic input this equals a single shared pad width).

    :param arr: dataset to be interpolated.
    :param center_data: not implemented: would enable automatic centering of the data
    :param factor: interpolation factor (2 for doubling of the number of sampling points)
    :return: interpolated data (result of shifted_ifft, not cast back to the dtype of arr)
    :raises NotImplementedError: if center_data is True
    """
    ny, nx, nz = arr.shape
    # Padding on each side of an axis; int() truncates, so the output length is n + 2 * pad.
    pads = tuple((int((n / 2) * (factor - 1)),) * 2 for n in (ny, nx, nz))
    if center_data:
        raise NotImplementedError
    fd = shifted_fft(arr)
    padded = np.pad(fd, pads, mode='constant')
    return shifted_ifft(padded)