Commit 0f544b26 authored by Matthijs's avatar Matthijs
Browse files

simple parts from orbital_tomography-3d branch

parent 98cad1a1
......@@ -13,6 +13,16 @@ from numpy import exp
class DRl(SimpleAlgorithm):
def __init__(self, config):
super(DRl, self).__init__(config)
for parameter_key_extension in ['_0', '_max', '_switch']:
key = 'lambda'+parameter_key_extension
alternative = 'beta'+parameter_key_extension
try:
if key not in config:
config[key] = config[alternative]
except KeyError:
raise KeyError('config should contain %s, or %s must be given as alternative')
def evaluate(self, u):
iter = self.config['iter'] + 1
......
from .array_tools import *
from .binning import *
from .interpolation import *
import numpy as np
from scipy.ndimage import measurements
import scipy.fftpack as fft
from numpy import roll, ndarray, floor, iscomplexobj, round
from scipy.ndimage.measurements import maximum_position, center_of_mass
from scipy.fftpack import fftn, fftshift, ifftn, ifftshift
__all__ = ["shift_array", 'roll_to_pos', 'shifted_ifft', 'shifted_fft']
def shift_array(arr, dy, dx):
temp = np.roll(arr, (dy, dx), (0, 1))
def shift_array(arr: ndarray, dy: int, dx: int):
"""
Use numpy.roll to shift an array in the first and second dimensions
:param arr: numpy array
:param dy: shift in first dimension
:param dx: shift in second dimension
:return: array like arr
"""
temp = roll(arr, (dy, dx), (0, 1))
return temp
def roll_to_pos(arr: np.ndarray, y: int = 0, x: int = 0, pos: tuple = None, move_maximum: bool = False,
by_abs_val: bool = True) -> np.ndarray:
def roll_to_pos(arr: ndarray, y: int = 0, x: int = 0, pos: tuple = None, move_maximum: bool = False,
by_abs_val: bool = True) -> ndarray:
"""
Shift the center of mass of an array to the given position by cyclic permutation
......@@ -22,21 +31,23 @@ def roll_to_pos(arr: np.ndarray, y: int = 0, x: int = 0, pos: tuple = None, move
:return: array like original
"""
if move_maximum:
if by_abs_val or arr.dtype in [np.complex64, np.complex128]:
old = np.floor(measurements.maximum_position(abs(arr)))
if by_abs_val or iscomplexobj(arr):
old = floor(maximum_position(abs(arr)))
else:
old = np.floor(measurements.maximum_position(arr))
old = floor(maximum_position(arr))
else:
if by_abs_val or arr.dtype in [np.complex64, np.complex128]:
old = np.floor(measurements.center_of_mass(abs(arr)))
if by_abs_val or iscomplexobj(arr):
old = floor(center_of_mass(abs(arr)))
else:
old = np.floor(measurements.center_of_mass(arr))
old = floor(center_of_mass(arr))
if pos is not None: # dimension-independent method
shifts = tuple([int(np.round(pos[i]-old[i])) for i in range(len(pos))])
shifts = tuple([int(round(pos[i] - old[i])) for i in range(len(pos))])
dims = tuple([i for i in range(len(pos))])
temp = np.roll(arr, shift=shifts, axis=dims)
temp = roll(arr, shift=shifts, axis=dims)
else: # old method
temp = shift_array(arr, int(y - old[0]), int(x - old[1]))
if temp.shape != arr.shape:
raise Exception('Non-matching input and output shapes')
return temp
......@@ -52,7 +63,7 @@ def shifted_fft(arr, axes=None):
transformed array
"""
return fft.ifftshift(fft.fftn(fft.fftshift(arr, axes=axes), axes=axes), axes=axes)
return ifftshift(fftn(fftshift(arr, axes=axes), axes=axes), axes=axes)
def shifted_ifft(arr, axes=None):
......@@ -66,4 +77,4 @@ def shifted_ifft(arr, axes=None):
Returns:
transformed array
"""
return fft.fftshift(fft.ifftn(fft.ifftshift(arr, axes=axes), axes=axes), axes=axes)
return fftshift(ifftn(ifftshift(arr, axes=axes), axes=axes), axes=axes)
import numpy as np
__all__ = ['bin_array']
def bin_array(arr: np.ndarray, new_shape: any, pad_zeros=True) -> np.ndarray:
"""
......@@ -20,7 +22,7 @@ def bin_array(arr: np.ndarray, new_shape: any, pad_zeros=True) -> np.ndarray:
padding = tuple([(0, (binfactor[i] - s % binfactor[i]) % binfactor[i]) for i, s in enumerate(arr.shape)])
if pad_zeros and np.any(np.array(padding) != 0):
_arr = np.pad(arr, padding, mode='constant', constant_values=0) # pad array
_shape = tuple([s//binfactor[i] for i, s in enumerate(_arr.shape)]) # update binned size due to padding
_shape = tuple([s // binfactor[i] for i, s in enumerate(_arr.shape)]) # update binned size due to padding
else:
_arr = arr # expected to fail if padding has non-zeros
# send to 2d or 3d padding functions
......
import numpy as np
from .array_tools import shifted_fft, shifted_ifft, roll_to_pos
__all__ = ['fourier_interpolate']
def fourier_interpolate(arr: np.ndarray, factor: any = 2., **kwargs) -> np.ndarray:
"""
......
......@@ -112,5 +112,5 @@ setup(
},
# Some packages which are required for the good operation of ProxPython
install_requires=['numpy', 'scikit-image', 'matplotlib']
install_requires=['numpy', 'scikit-image', 'matplotlib', 'scipy']
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment