# array_tools.py: helpers for cyclic array shifts and centred FFTs.
from numpy import roll, ndarray, floor, iscomplexobj, round
from scipy.ndimage.measurements import maximum_position, center_of_mass
from scipy.fftpack import fftn, fftshift, ifftn, ifftshift
4

5
# Public API: the names exported by a star-import of this module.
__all__ = ["shift_array", 'roll_to_pos', 'shifted_ifft', 'shifted_fft']
6

7
8
9
10
11
12
13
14
15
16

def shift_array(arr: ndarray, dy: int, dx: int) -> ndarray:
    """
    Cyclically shift an array along its first and second dimensions using numpy.roll.

    Elements that are pushed past an edge re-enter on the opposite side, so the
    result has the same shape as the input.

    :param arr: numpy array with at least two dimensions
    :param dy: shift (in elements) along the first dimension
    :param dx: shift (in elements) along the second dimension
    :return: shifted array with the same shape as arr
    """
    return roll(arr, shift=(dy, dx), axis=(0, 1))


20
21
def roll_to_pos(arr: ndarray, y: int = 0, x: int = 0, pos: tuple = None, move_maximum: bool = False,
                by_abs_val: bool = True) -> ndarray:
    """
    Cyclically shift an array so that its center of mass (or its maximum) lands on a given position.

    The reference point (center of mass or position of the maximum) is floored to whole
    indices, and the array is then rolled by the difference between the requested
    position and that reference point.

    :param arr: array to shift; works best for a well-centered feature with limited support.
        Must have at least two dimensions when pos is not given.
    :param y: target index along the first dimension (ignored when pos is given)
    :param x: target index along the second dimension (ignored when pos is given)
    :param pos: tuple with the target position, overriding y and x. Its length sets how many
        leading axes are shifted; should be used for arrays that are not two-dimensional
    :param move_maximum: if true, use the position of the maximum instead of the center of mass
    :param by_abs_val: take the absolute value before locating the maximum / center of mass.
        Complex input is always reduced to its absolute value
    :return: shifted array with the same shape as arr
    :raises ValueError: if the reference point is undefined, e.g. the center of mass of an
        array whose weights sum to zero
    """
    # Complex values cannot serve as weights or be ordered, so they always go through abs().
    weights = abs(arr) if (by_abs_val or iscomplexobj(arr)) else arr
    locate = maximum_position if move_maximum else center_of_mass
    old = floor(locate(weights))

    try:
        if pos is not None:  # dimension-independent method
            shifts = tuple(int(round(target - old[i])) for i, target in enumerate(pos))
        else:  # old two-dimensional method
            shifts = (int(y - old[0]), int(x - old[1]))
    except (ValueError, OverflowError) as err:
        # int() fails on NaN/inf, which appear when the weights sum to zero.
        raise ValueError('Cannot shift array: its reference position is undefined '
                         '(do the weights sum to zero?)') from err

    # numpy.roll always preserves the shape, so no shape check is needed afterwards.
    return roll(arr, shift=shifts, axis=tuple(range(len(shifts))))


54
55
56
57
58
59
60
def shifted_fft(arr, axes=None):
    """
    Centered n-dimensional FFT: combined shift and fftn routine, based on scipy.fftpack

    The element at index n // 2 along each transformed axis is treated as the origin,
    both for the input and for the output. This makes the function the exact inverse
    of ifftshift -> ifftn -> fftshift for even and odd axis lengths alike.

    Args:
        arr: numpy array
        axes: axes to transform, identical to the axes argument of scipy.fftpack.fftn;
            None transforms (and shifts) all axes

    Returns:
        transformed array
    """
    # ifftshift (not fftshift) moves index n // 2 to index 0; the two only agree for
    # even lengths, and using fftshift here adds a phase ramp for odd lengths.
    origin_first = ifftshift(arr, axes=axes)
    return fftshift(fftn(origin_first, axes=axes), axes=axes)
67
68


69
def shifted_ifft(arr, axes=None):
    """
    Centered n-dimensional inverse FFT: combined shift and ifftn routine, based on scipy.fftpack

    The element at index n // 2 along each transformed axis is treated as the origin,
    both for the input and for the output.

    Args:
        arr: numpy array
        axes: axes to transform, identical to the axes argument of scipy.fftpack.ifftn;
            None transforms (and shifts) all axes

    Returns:
        inverse-transformed array
    """
    # Move the centered origin (index n // 2) to index 0, where ifftn expects it.
    origin_first = ifftshift(arr, axes=axes)
    return fftshift(ifftn(origin_first, axes=axes), axes=axes)