Commit 75209682 authored by Matthijs's avatar Matthijs
Browse files

compatibility of sparsity with non-2d arrays

parent df972ad2
from numpy import zeros_like from numpy import zeros_like, unravel_index
import numpy as np import numpy as np
from .proxoperators import ProxOperator from .proxoperators import ProxOperator
...@@ -29,14 +29,14 @@ class P_Sparsity(ProxOperator): ...@@ -29,14 +29,14 @@ class P_Sparsity(ProxOperator):
if self.sparsity_parameter > 30: if self.sparsity_parameter > 30:
def value_selection(original, indices, sparsity_parameter): def value_selection(original, indices, sparsity_parameter):
idx_for_threshold = divmod(indices[-sparsity_parameter], self.ny) idx_for_threshold = unravel_index(indices[-sparsity_parameter], original.shape)
threshold_val = abs(original[idx_for_threshold].get()) threshold_val = abs(original[idx_for_threshold].get())
return (abs(original) >= threshold_val) * original return (abs(original) >= threshold_val) * original
else: else:
def value_selection(original, indices, sparsity_parameter): def value_selection(original, indices, sparsity_parameter):
out = zeros_like(original) out = zeros_like(original)
hits = indices[-sparsity_parameter:].get() hits = indices[-sparsity_parameter:].get()
hit_idx = [divmod(hit, self.ny) for hit in hits] hit_idx = [unravel_index(hit, original.shape) for hit in hits]
for _idx in hit_idx: for _idx in hit_idx:
out[_idx[0], _idx[1]] = original[_idx[0], _idx[1]] out[_idx[0], _idx[1]] = original[_idx[0], _idx[1]]
return out return out
...@@ -56,7 +56,7 @@ class P_Sparsity(ProxOperator): ...@@ -56,7 +56,7 @@ class P_Sparsity(ProxOperator):
u *= self.support # apply support u *= self.support # apply support
p_Sparsity = 0 * u p_Sparsity = 0 * u
sorting = np.argsort(abs(u), axis=None) # gives indices of sorted array in ascending order sorting = np.argsort(abs(u), axis=None) # gives indices of sorted array in ascending order
indices = np.asarray([divmod(sorting[i], self.ny) for i in range(-1 * self.sparsity_parameter, 0)]) indices = np.asarray([unravel_index(sorting[i], u) for i in range(-1 * self.sparsity_parameter, 0)])
p_Sparsity[indices[:, 0], indices[:, 1]] = u[indices[:, 0], indices[:, 1]] p_Sparsity[indices[:, 0], indices[:, 1]] = u[indices[:, 0], indices[:, 1]]
return p_Sparsity return p_Sparsity
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment