from typing import Sequence, Union

import numpy as np

from .array_tools import shifted_fft, shifted_ifft, roll_to_pos

__all__ = ['fourier_interpolate']


def _pad_width(size: int, factor: float) -> int:
    """
    One-sided zero-padding width for an axis of length ``size``.

    :param size: number of samples along the axis
    :param factor: interpolation factor for that axis
    :return: number of zeros to append on each side of the spectrum
    """
    return int((size / 2) * (factor - 1))


def _pad_spectrum_and_invert(arr: np.ndarray, pads: Sequence[int]) -> np.ndarray:
    """
    Transform ``arr``, zero-pad the spectrum symmetrically and transform back.

    :param arr: array to interpolate
    :param pads: one-sided padding width per axis, one entry per axis of ``arr``
    :return: result of ``shifted_ifft`` applied to the padded spectrum
    """
    # NOTE(review): shifted_fft/shifted_ifft are presumably a centred
    # (fftshift-ed) transform pair, so padding the borders adds high
    # frequencies - confirm against array_tools.
    spectrum = shifted_fft(arr)
    padded = np.pad(spectrum, tuple((p, p) for p in pads), mode='constant')
    return shifted_ifft(padded)


def _center_2d(arr: np.ndarray) -> np.ndarray:
    """
    Roll a 2d array towards its centre pixel using ``roll_to_pos``.

    :param arr: 2d array
    :return: the rolled array
    """
    ny, nx = arr.shape
    # NOTE(review): presumably the first call moves the maximum to the centre
    # and the second one refines using the centre of mass - confirm against
    # roll_to_pos.
    centered = roll_to_pos(arr, ny // 2, nx // 2, move_maximum=True)
    return roll_to_pos(centered, ny // 2, nx // 2, )


def fourier_interpolate(arr: np.ndarray,
                        factor: Union[float, Sequence[float]] = 2.,
                        **kwargs) -> np.ndarray:
    """
    Interpolation by padding in the Fourier domain.

    A scalar ``factor`` dispatches on the dimensionality of ``arr`` (2d or 3d);
    a sequence ``factor`` dispatches on its length (2 or 3 entries, one factor
    per axis).

    Keyword arguments get passed to sub-functionalities. Options are for example:

    - center_data: whether to move the center of mass to the center of the array
      before interpolation, this is generally a good idea

    NOTE: testing by simple 2x interpolation and subsequent 2x2 binning shows that errors are
    introduced in the procedure

    :param arr: numpy array, 2 or 3-dimensional
    :param factor: tuple or float to indicate by which factor to interpolate, default to factor 2
    :returns: interpolated array, cast to identical dtype as arr
    :raises ValueError: if ``factor`` is a sequence whose length is not 2 or 3
    :raises NotImplementedError: if ``factor`` is a scalar and ``arr`` is not 2d or 3d
    """
    # Only the len() probe is guarded: a TypeError raised *inside* one of the
    # workers must propagate instead of being mistaken for "factor is scalar".
    try:
        n_factors = len(factor)
    except TypeError:
        n_factors = None

    if n_factors is None:
        n_dims = len(arr.shape)
        if n_dims == 2:
            out = fourier_interpolate_2d(arr, factor=factor, **kwargs)
        elif n_dims == 3:
            out = fourier_interpolate_3d(arr, factor=factor, **kwargs)
        else:
            raise NotImplementedError('Can only interpolate 2d or 3d arrays as of now')
    elif n_factors == 2:
        out = fourier_interpolate_2d_nonsquare(arr, factor=factor, **kwargs)
    elif n_factors == 3:
        out = fourier_interpolate_3d_nonsquare(arr, factor=factor, **kwargs)
    else:
        raise ValueError(
            'factor must be a scalar or a sequence of length 2 or 3, '
            'got a sequence of length {}'.format(n_factors)
        )

    # 'unsafe' casting is deliberate: the inverse transform may return another
    # dtype than the input and the result is forced back to the input dtype.
    return out.astype(arr.dtype, casting='unsafe')


def fourier_interpolate_2d_nonsquare(arr: np.ndarray, factor: tuple = (1., 2.),
                                     center_data: bool = False) -> np.ndarray:
    """
    Interpolate 2d array (can be complex-valued, should be well-sampled)

    :param arr: numpy array of rectangular size
    :param factor: tuple of interpolation factor per axis
                   (2 for doubling of the number of sampling points)
    :param center_data: automatic centering of the data before interpolation
    :return: interpolated array
    """
    ny, nx = arr.shape
    pads = (_pad_width(ny, factor[0]), _pad_width(nx, factor[1]))
    # The centred copy is what gets transformed; previously it was computed
    # and then discarded, which made center_data a no-op.
    data = _center_2d(arr) if center_data else arr
    return _pad_spectrum_and_invert(data, pads)


def fourier_interpolate_3d_nonsquare(arr: np.ndarray, factor: tuple = (1., 2., 3.),
                                     center_data: bool = False) -> np.ndarray:
    """
    Interpolate 3d data (can be complex-valued, should be oversampled)

    :param arr: dataset to be interpolated.
    :param factor: interpolation factor per axis
                   (2 for doubling of the number of sampling points)
    :param center_data: not implemented: would enable automatic centering of the data
    :return: interpolated data
    :raises NotImplementedError: if ``center_data`` is true
    """
    if center_data:
        raise NotImplementedError
    pads = tuple(_pad_width(size, f) for size, f in zip(arr.shape, factor[:3]))
    if len(arr.shape) != 3 or len(pads) != 3:
        # Same failure mode as the tuple unpacking it replaces.
        raise ValueError('expected a 3d array and three interpolation factors')
    return _pad_spectrum_and_invert(arr, pads)


def fourier_interpolate_2d(arr: np.ndarray, factor: float = 2.,
                           center_data=False) -> np.ndarray:
    """
    Interpolate 2d array (can be complex-valued, should be oversampled)

    :param arr: numpy array of square size
    :param factor: interpolation factor (2 for doubling of the number of sampling points)
    :param center_data: automatic centering of the data before interpolation
    :return: interpolated array
    """
    ny, nx = arr.shape
    assert ny == nx, "Accepts only square arrays currently"
    pad = _pad_width(ny, factor)
    # Transform the centred copy when requested (it used to be discarded).
    data = _center_2d(arr) if center_data else arr
    return _pad_spectrum_and_invert(data, (pad, pad))


def fourier_interpolate_3d(arr: np.ndarray, center_data: bool = False,
                           factor: float = 2.) -> np.ndarray:
    """
    Interpolate 3d data (can be complex-valued, should be oversampled)

    :param arr: dataset to be interpolated.
    :param center_data: not implemented: would enable automatic centering of the data
    :param factor: interpolation factor (2 for doubling of the number of sampling points)
    :return: interpolated data
    :raises NotImplementedError: if ``center_data`` is true
    """
    ny, nx, nz = arr.shape
    if center_data:
        raise NotImplementedError
    # Pad every axis according to its own length so that non-cubic volumes are
    # scaled by ``factor`` along each axis; identical to before for cubes.
    pads = (_pad_width(ny, factor), _pad_width(nx, factor), _pad_width(nz, factor))
    return _pad_spectrum_and_invert(arr, pads)