Commit 73a6277f authored by Matthijs's avatar Matthijs
Browse files

rearrange imports, documentation

parent 387f4ab0
import numpy as np
from scipy.ndimage import measurements
import scipy.fftpack as fft
from numpy import roll, ndarray, floor, iscomplexobj, round
from scipy.ndimage.measurements import maximum_position, center_of_mass
from scipy.fftpack import fftn, fftshift, ifftn, ifftshift
def shift_array(arr, dy, dx):
temp = np.roll(arr, (dy, dx), (0, 1))
def shift_array(arr: ndarray, dy: int, dx: int):
"""
Use numpy.roll to shift an array in the first and second dimensions
:param arr: numpy array
:param dy: shift in first dimension
:param dx: shift in second dimension
:return: array like arr
"""
temp = roll(arr, (dy, dx), (0, 1))
return temp
def roll_to_pos(arr: np.ndarray, y: int = 0, x: int = 0, pos: tuple = None, move_maximum: bool = False,
by_abs_val: bool = True) -> np.ndarray:
def roll_to_pos(arr: ndarray, y: int = 0, x: int = 0, pos: tuple = None, move_maximum: bool = False,
by_abs_val: bool = True) -> ndarray:
"""
Shift the center of mass of an array to the given position by cyclic permutation
......@@ -22,19 +29,19 @@ def roll_to_pos(arr: np.ndarray, y: int = 0, x: int = 0, pos: tuple = None, move
:return: array like original
"""
if move_maximum:
if by_abs_val or arr.dtype in [np.complex64, np.complex128]:
old = np.floor(measurements.maximum_position(abs(arr)))
if by_abs_val or iscomplexobj(arr):
old = floor(maximum_position(abs(arr)))
else:
old = np.floor(measurements.maximum_position(arr))
old = floor(maximum_position(arr))
else:
if by_abs_val or arr.dtype in [np.complex64, np.complex128]:
old = np.floor(measurements.center_of_mass(abs(arr)))
if by_abs_val or iscomplexobj(arr):
old = floor(center_of_mass(abs(arr)))
else:
old = np.floor(measurements.center_of_mass(arr))
old = floor(center_of_mass(arr))
if pos is not None: # dimension-independent method
shifts = tuple([int(np.round(pos[i]-old[i])) for i in range(len(pos))])
shifts = tuple([int(round(pos[i] - old[i])) for i in range(len(pos))])
dims = tuple([i for i in range(len(pos))])
temp = np.roll(arr, shift=shifts, axis=dims)
temp = roll(arr, shift=shifts, axis=dims)
else: # old method
temp = shift_array(arr, int(y - old[0]), int(x - old[1]))
if temp.shape != arr.shape:
......@@ -54,7 +61,7 @@ def shifted_fft(arr, axes=None):
transformed array
"""
return fft.ifftshift(fft.fftn(fft.fftshift(arr, axes=axes), axes=axes), axes=axes)
return ifftshift(fftn(fftshift(arr, axes=axes), axes=axes), axes=axes)
def shifted_ifft(arr, axes=None):
......@@ -68,4 +75,4 @@ def shifted_ifft(arr, axes=None):
Returns:
transformed array
"""
return fft.fftshift(fft.ifftn(fft.ifftshift(arr, axes=axes), axes=axes), axes=axes)
return fftshift(ifftn(ifftshift(arr, axes=axes), axes=axes), axes=axes)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment