Commit fdebd298 authored by Matthijs's avatar Matthijs
Browse files

added superpixel sparsity proxoperator (functioning)

parent b8fd790a
......@@ -13,8 +13,9 @@ def data_processor(config):
sensitivity = imread(config['sensitivity_filename'])
if 'crop data' in config:
cy, cx, cz = config['crop data']
inp = inp[cy:-1 - cy, cx:-1 - cx, cz:-1 - cz]
sensitivity = sensitivity[cy:-1 - cy, cx:-1 - cx, cz:-1 - cz]
ny, nx, nz = inp.shape
inp = inp[cy:ny - cy, cx:nx - cx, cz:nz - cz]
sensitivity = sensitivity[cy:ny - cy, cx:nx - cx, cz:nz - cz]
# Add the negative K_z part to determine a real-valued orbital?
if 'add_negative_kz' in config and config['add_negative_kz']:
if 'kz_dimension' not in config:
......@@ -89,7 +90,6 @@ def data_processor(config):
config['support viewer'] = support_viewer # the reference keeps the matplotlib object alive
if 'fourier_shift_arrays' in config and config['fourier_shift_arrays']:
print('Fourier shifting the arrays')
for key in ['u_0', 'support', 'sparsity_support', 'data', 'data_zeros']:
if key in config:
config[key] = np.fft.fftshift(config[key])
......
......@@ -62,11 +62,12 @@ new_config = {
'TOL': 1e-10,
# relaxaton parameters in RAAR, HPR and HAAR
'beta_0': 0.85, # starting relaxation prameter (only used with HAAR, HPR and RAAR)
'beta_max': 0.50, # maximum relaxation prameter (only used with HAAR, RAAR, and HPR)
'beta_0': 0.85, # starting relaxation parameter (only used with HAAR, HPR and RAAR)
'beta_max': 0.50, # maximum relaxation parameter (only used with HAAR, RAAR, and HPR)
'beta_switch': 30, # iteration at which beta moves from beta_0 -> beta_max
'sparsity_parameter': 100,
'superpixel size': 1,
'use_sparsity_with_support': True,
'symmetry_type': 1, # -1 for antisymmetric functions, 1 for symmetric ones.
'symmetry_axis': -1, # which axis is symmetric. (mirror plane perpendicular to this axis)
......
......@@ -6,8 +6,9 @@ from proxtoolbox import ProxOperators
from proxtoolbox.ProxOperators.proxoperators import ProxOperator
from proxtoolbox.Problems.OrbitalTomog import Graphics
from numpy.linalg import norm
from numpy import square, sqrt
from numpy.fft import fftshift
from numpy import square, sqrt, exp, pi
from numpy.fft import fftn, fftshift
from numpy.random import random_sample
from proxtoolbox.Utilities.OrbitalTomog import interpolation, array_tools, binning
......@@ -54,9 +55,15 @@ class Phase(Problem):
elif self.config['constraint'] == 'amplitude only':
used_proxoperators[0] = 'P_amp'
elif self.config['constraint'] == 'sparse real':
used_proxoperators[0] = 'P_Sparsity_real'
if 'superpixel size' in self.config and self.config['superpixel size'] != 1:
used_proxoperators[0] = 'P_Sparsity_Superpixel_real'
else:
used_proxoperators[0] = 'P_Sparsity_real'
elif self.config['constraint'] == 'sparse complex':
used_proxoperators[0] = 'P_Sparsity'
if 'superpixel size' in self.config and self.config['superpixel size'] != 1:
used_proxoperators[0] = 'P_Sparsity_Superpixel'
else:
used_proxoperators[0] = 'P_Sparsity'
elif self.config['constraint'] in ['symmetric sparse real', 'sparse symmetric real']:
used_proxoperators[0] = 'P_Sparsity_Symmetric_real'
elif self.config['constraint'] in ['symmetric sparse complex', 'symmetric sparse complex']:
......@@ -188,3 +195,6 @@ class Phase(Problem):
# TODO: some basic saving procedure. perhaps just a python pickle? (of the config and output dictionaries)
raise NotImplementedError
def random_guess(self):
ph_init = 2 * pi * random_sample(self.config['data'].shape)
self.config['u_0'] = fftn(self.config['data'] * exp(1j * ph_init))
......@@ -28,7 +28,7 @@ class P_Sparsity(ProxOperator):
else:
self.support = 1
if self.sparsity_parameter > 30 or len(config['u0'].shape) != 2:
if self.sparsity_parameter > 30 or len(config['u_0'].shape) != 2:
def value_selection(original, indices, sparsity_parameter):
idx_for_threshold = unravel_index(indices[-sparsity_parameter], original.shape)
threshold_val = abs(original[idx_for_threshold].get())
......@@ -173,9 +173,10 @@ class P_Sparsity_Superpixel(P_Sparsity):
super(P_Sparsity_Superpixel, self).__init__(config=config)
# Set the superpixel size:
self.superpixel_size = config['superpixel size']
# Bin the support:
if self.support is not 1:
self.support = bin_array(self.support, self.superpixel_size)
self.support /= max(self.support)
self.support = self.support / max(self.support)
# Assert that the binning+upsampling conserves the array size
test_shape = tile_array(bin_array(config['u_0'], self.superpixel_size, pad_zeros=False),
self.superpixel_size).shape
......@@ -183,9 +184,10 @@ class P_Sparsity_Superpixel(P_Sparsity):
# TODO: allow for padding, then cut of the remainder after tile_array
def work(self, u):
binned = bin_array(u, self.superpixel_size)
constrained = super(P_Sparsity_Superpixel, self).work(binned)
return tile_array(constrained, self.superpixel_size)
binned = bin_array(abs(u), self.superpixel_size)
sparse_array = super(P_Sparsity_Superpixel, self).work(binned)
mask = tile_array(sparse_array, self.superpixel_size, normalize=True) > 0
return np.where(mask, u, 0)
class P_Sparsity_Superpixel_real(P_Sparsity_Superpixel):
......@@ -193,5 +195,5 @@ class P_Sparsity_Superpixel_real(P_Sparsity_Superpixel):
Apply real-valued sparsity on superpixels, i.e. on the binned array
"""
def work(self, u):
super(P_Sparsity_Superpixel, self).work(u.real)
return super(P_Sparsity_Superpixel_real, self).work(u.real)
......@@ -85,7 +85,7 @@ def shifted_ifft(arr, axes=None):
return fftshift(ifftn(ifftshift(arr, axes=axes), axes=axes), axes=axes)
def tile_array(a: ndarray, shape):
def tile_array(a: ndarray, shape, normalize: bool = False):
"""
Upsample an array by nearest-neighbour interpolation, i.e. [1,2] -> [1,1,2,2]
:param a: numpy array, ndim = [2,3]
......@@ -98,10 +98,14 @@ def tile_array(a: ndarray, shape):
except TypeError:
b0 = shape
b1 = shape
if normalize:
norm = (b0 * b1)
else:
norm = 1
r, c = a.shape # number of rows/columns
rs, cs = a.strides # row/column strides
x = as_strided(a, (r, b0, c, b1), (rs, 0, cs, 0)) # view a as larger 4D array
return x.reshape(r * b0, c * b1) # create new 2D array
return x.reshape(r * b0, c * b1)/norm # create new 2D array
elif a.ndim == 3:
try:
b0, b1, b2 = shape
......@@ -109,9 +113,13 @@ def tile_array(a: ndarray, shape):
b0 = shape
b1 = shape
b2 = shape
if normalize:
norm = (b0*b1*b2)
else:
norm = 1
x, y, z = a.shape
xs, ys, zs = a.strides
temp = as_strided(a, (x, b0, y, b1, z, b2), (xs, 0, ys, 0, zs, 0))
return temp.reshape((x * b0, y * b1, z * b2))
return temp.reshape((x * b0, y * b1, z * b2))/norm
else:
raise NotImplementedError("Arrays of dimensions other than 2 and 3 are not implemented yet")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment