### fixed orthonormality constraint

parent 59bfba7a
 ... ... @@ -21,17 +21,24 @@ class P_orthonorm(ProxOperator): norms = [sqrt(sum(abs(u) ** 2)) for u in u_norm] # determine angle _a_ between u and u a = sum(u_norm * u_norm) / (norms * norms) if a != 0: # for non-orthogonal iterates, apply change # determine root of y^3 - 3/2 a y^2 + 1/2 a = 0 y_part = cbrt(2 * sqrt(a ** 4 + a ** 2) - 2 * a - a ** 3) y = 0.5 * (a ** 2 / y_part + y_part - a) # apply projection u_new = zeros_like(u_norm) u_new = u_norm - (y / (y ** 2 - 1)) * (u_norm - y * u_norm) u_new = (1 / (y ** 2 - 1)) * (u_norm - y * u_norm) return u_new else: if a == 0: # for already orthogonal iterates, apply no change return u_norm elif a == -1 or a == 1: # Cannot do anything for parallel vectors return u_norm # determine roots of y^2 - 2/a y + 1 = 0 elif a > 0: y = 1 / a + sqrt(1 / a ** 2 - 1) elif a < 0: y = 1 / a - sqrt(1 / a ** 2 - 1) else: raise Exception("This should never rise, check calculation of a") # apply projection u_new = zeros_like(u_norm) u_new = u_norm - (y / (y ** 2 - 1)) * (y * u_norm - u_norm) u_new = (1 / (y ** 2 - 1)) * (y * u_norm - u_norm) return u_new class P_norm(ProxOperator): ... ... @@ -63,9 +70,11 @@ if __name__ == "__main__": portho = P_orthonorm(exp) pnorm = P_norm(exp) th = 0.55 inp = np.array([[1, 0], [np.cos(th * np.pi), np.sin(th * np.pi)]]) out = portho.eval(inp) print("Input:", inp) print("Output:", out) print("Inner product of the output: ", np.sum(out*out) ) \ No newline at end of file checksum = [] thlist = np.arange(0.11, 1.51, 0.1) for th in thlist: inp = np.array([[1, 0], [np.cos(th * np.pi), np.sin(th * np.pi)]]) out = portho.eval(inp) checksum += [np.sum(out * out)] assert np.sum(checksum) < 1e-14, "Test failed" print("Test passed")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!