Commit 4b71000c authored by markus.meier01's avatar markus.meier01
Browse files

Fixed dimension error in phase.py

parent e5112f74
...@@ -127,17 +127,19 @@ class P_diag(ProxOperator): ...@@ -127,17 +127,19 @@ class P_diag(ProxOperator):
m = self.m; m = self.m;
p = self.p; p = self.p;
K = self.K; K = self.K;
print(u.shape)
if m == 1: if m == 1:
tmp = sum(u, axis=0, dtype=u.dtype) tmp = sum(u, axis=0, dtype=u.dtype)
elif n == 1: elif n == 1:
tmp = sum(u, axis=1, dtype=u.dtype) tmp = sum(u, axis=1, dtype=u.dtype)
elif p == 1: elif p == 1:
tmp = zeros((n, m), dtype=u.dtype) tmp = zeros((m,n), dtype=u.dtype)
for k in range(K): for k in range(K):
tmp += u[:, :, k] tmp += u[:, :, k]
print(tmp)
else: else:
tmp = zeros((n, m, p), dtype=u.dtype) tmp = zeros((m, n, p), dtype=u.dtype)
for k in range(K): for k in range(K):
tmp += u[:, :, :, k] tmp += u[:, :, :, k]
...@@ -148,9 +150,10 @@ class P_diag(ProxOperator): ...@@ -148,9 +150,10 @@ class P_diag(ProxOperator):
elif n == 1: elif n == 1:
return tmp.reshape(tmp.size, 1) @ ones((1, K), dtype=u.dtype) return tmp.reshape(tmp.size, 1) @ ones((1, K), dtype=u.dtype)
elif p == 1: elif p == 1:
u_diag = empty((n, m, K), dtype=u.dtype) u_diag = empty((m,n, K), dtype=u.dtype)
for k in range(K): for k in range(K):
u_diag[:, :, k] = tmp u_diag[:, :, k] = tmp
print(u_diag.shape)
return u_diag return u_diag
else: else:
u_diag = empty((n, m, p, K), dtype=u.dtype) u_diag = empty((n, m, p, K), dtype=u.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment