array_tools.py 4.06 KB
Newer Older
1
from numpy import roll, ndarray, floor, iscomplexobj, round, any, isnan, nan_to_num
Matthijs's avatar
Matthijs committed
2
3
from scipy.ndimage.measurements import maximum_position, center_of_mass
from scipy.fftpack import fftn, fftshift, ifftn, ifftshift
4
from warnings import warn
5
from numpy.lib.stride_tricks import as_strided
6

7
__all__ = ["shift_array", 'roll_to_pos', 'shifted_ifft', 'shifted_fft', 'tile_array']
Matthijs's avatar
Matthijs committed
8

9

Matthijs's avatar
Matthijs committed
10
11
12
13
14
15
16
17
18
def shift_array(arr: ndarray, dy: int, dx: int) -> ndarray:
    """
    Cyclically shift an array along its first and second dimensions using numpy.roll

    Elements pushed past one edge re-enter at the opposite edge; the array needs at least two axes.

    :param arr: numpy array
    :param dy: shift in first dimension (axis 0)
    :param dx: shift in second dimension (axis 1)
    :return: array like arr
    """
    return roll(arr, (dy, dx), (0, 1))


Matthijs's avatar
Matthijs committed
22
23
def roll_to_pos(arr: ndarray, y: int = 0, x: int = 0, pos: tuple = None, move_maximum: bool = False,
                by_abs_val: bool = True) -> ndarray:
    """
    Shift the center of mass (or the maximum) of an array to the given position by cyclic permutation

    The current position is floored to whole indices before the shift is computed.

    :param arr: 2d array, works best for well-centered feature with limited support
    :param y: target index along the first dimension (used only when pos is None)
    :param x: target index along the second dimension (used only when pos is None)
    :param pos: tuple with the new position, overriding y,x values; the first len(pos) axes are rolled.
                should be used for higher-dimensional arrays
    :param move_maximum: if true, move the position of the maximum instead of the center of mass
    :param by_abs_val: take abs value for the determination of max-val or center-of-mass
                       (the abs value is always used for complex input)
    :return: array like original
    :raises Exception: if the rolled array does not have the shape of the input (defensive check)
    """
    # Complex values cannot be ordered or used as weights directly, so fall back to the magnitude.
    reference = abs(arr) if (by_abs_val or iscomplexobj(arr)) else arr
    locate = maximum_position if move_maximum else center_of_mass
    old = floor(locate(reference))

    if any(isnan(old)):
        old = nan_to_num(old)
        warn(Warning("Unexpected error in the calculation of the center of mass, casting NaNs to num"),
             stacklevel=2)

    if pos is not None:  # dimension-independent method
        shifts = tuple(int(round(target - old[axis])) for axis, target in enumerate(pos))
        axes = tuple(range(len(pos)))
    else:  # two-dimensional method driven by y and x
        shifts = (int(y - old[0]), int(x - old[1]))
        axes = (0, 1)

    temp = roll(arr, shift=shifts, axis=axes)

    if temp.shape != arr.shape:
        raise Exception('Non-matching input and output shapes')

    return temp


59
60
61
62
63
64
65
def shifted_fft(arr, axes=None):
    """
    Centred forward FFT: combined ifftshift, fftn and fftshift routine, based on scipy.fftpack

    The input is treated as having its origin at the array centre (index N // 2 along each
    transformed axis). ifftshift moves that sample to index 0 before the transform and fftshift
    re-centres the zero-frequency component afterwards. This ordering is the exact inverse of
    fftshift(ifftn(ifftshift(...))) for both even and odd axis lengths.

    Args:
        arr: numpy array
        axes: axes to shift and transform, identical to the axes argument of scipy.fftpack.fftn
              (None transforms all axes)

    Returns:
        transformed array
    """
    return fftshift(fftn(ifftshift(arr, axes=axes), axes=axes), axes=axes)
72
73


74
def shifted_ifft(arr, axes=None):
    """
    Centred inverse FFT: combined ifftshift, ifftn and fftshift routine, based on scipy.fftpack

    The input is treated as having its zero-frequency component at the array centre (index N // 2
    along each transformed axis). ifftshift moves it to index 0 before the inverse transform and
    fftshift re-centres the result afterwards.

    Args:
        arr: numpy array
        axes: axes to shift and transform, identical to the axes argument of scipy.fftpack.ifftn
              (None transforms all axes)

    Returns:
        transformed array
    """
    return fftshift(ifftn(ifftshift(arr, axes=axes), axes=axes), axes=axes)
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117


def tile_array(a: ndarray, shape):
    """
    Upsample an array by nearest-neighbour interpolation, i.e. [1,2] -> [1,1,2,2]

    Every element is repeated along each axis by the matching tile factor. The repetition is set up
    as a zero-stride view of the input, which the final reshape turns into the enlarged array.

    :param a: numpy array, ndim = [2,3]
    :param shape: tile size, a single integer to use the same factor on every axis, otherwise a tuple
                  with one factor per axis
    :return: resampled array
    """
    ndim = a.ndim
    if ndim not in (2, 3):
        raise NotImplementedError("Arrays of dimensions other than 2 and 3 are not implemented yet")

    try:
        if ndim == 2:
            rep0, rep1 = shape
            reps = (rep0, rep1)
        else:
            rep0, rep1, rep2 = shape
            reps = (rep0, rep1, rep2)
    except TypeError:
        # shape is not iterable: use the same factor for every axis
        reps = (shape,) * ndim

    # Follow every real axis by a zero-stride axis of length `rep`, which re-reads the same element.
    view_shape = tuple(entry for pair in zip(a.shape, reps) for entry in pair)
    view_strides = tuple(entry for step in a.strides for entry in (step, 0))
    expanded = as_strided(a, view_shape, view_strides)

    # Merge each (axis, repeat) pair into a single enlarged axis.
    return expanded.reshape(tuple(size * rep for size, rep in zip(a.shape, reps)))