Commit cf48ce2d authored by s.gretchko's avatar s.gretchko
Browse files

Fixed magproj calls (switched argument order) in P_M.py

parent 2679ce47
...@@ -37,7 +37,7 @@ class P_M(ProxOperator): ...@@ -37,7 +37,7 @@ class P_M(ProxOperator):
""" """
m = self.data m = self.data
a = self.prop.eval(u) a = self.prop.eval(u)
b = magproj(a, m) b = magproj(m, a)
return self.invprop.eval(b) return self.invprop.eval(b)
...@@ -70,7 +70,7 @@ class P_M_masked(P_M): ...@@ -70,7 +70,7 @@ class P_M_masked(P_M):
array_like - p_M: the projection IN THE PHYSICAL (time) DOMAIN array_like - p_M: the projection IN THE PHYSICAL (time) DOMAIN
""" """
fourier_space_iterate = self.prop.eval(u) fourier_space_iterate = self.prop.eval(u)
constrained = magproj(fourier_space_iterate.copy(), self.data) constrained = magproj(self.data, fourier_space_iterate.copy())
update = where(self.mask, fourier_space_iterate, constrained) update = where(self.mask, fourier_space_iterate, constrained)
return self.invprop(update) return self.invprop(update)
...@@ -97,7 +97,7 @@ class Approx_P_M(P_M): ...@@ -97,7 +97,7 @@ class Approx_P_M(P_M):
# Now see that the propagated field is within the ball around the data (if any). # Now see that the propagated field is within the ball around the data (if any).
# If not, the projection is calculated, otherwise we do nothing. # If not, the projection is calculated, otherwise we do nothing.
if h_u_hat >= self.data_ball + self.TOL2: if h_u_hat >= self.data_ball + self.TOL2:
b = magproj(u_hat, self.data) b = magproj(self.data, u_hat)
return self.invprop.eval(b) return self.invprop.eval(b)
else: else:
return u return u
...@@ -124,7 +124,7 @@ class Approx_P_M_masked(P_M_masked): ...@@ -124,7 +124,7 @@ class Approx_P_M_masked(P_M_masked):
# Now see that the propagated field is within the ball around the data (if any). # Now see that the propagated field is within the ball around the data (if any).
# If not, the projection is calculated, otherwise we do nothing. # If not, the projection is calculated, otherwise we do nothing.
if h_u_hat >= self.data_ball + self.TOL2: if h_u_hat >= self.data_ball + self.TOL2:
constrained = magproj(u_hat.copy(), self.data) # Apply constraint constrained = magproj(self.data, u_hat.copy()) # Apply constraint
update = where(self.mask, u_hat, constrained) # Masking operation update = where(self.mask, u_hat, constrained) # Masking operation
return self.invprop(update) # Propagate back return self.invprop(update) # Propagate back
else: else:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment