Commit 6606b038 by jansen31

fixed orthonormality constraint

parent 59bfba7a
 ... ... @@ -21,17 +21,24 @@ class P_orthonorm(ProxOperator): norms = [sqrt(sum(abs(u) ** 2)) for u in u_norm] # determine angle _a_ between u[0] and u[1] a = sum(u_norm[0] * u_norm[1]) / (norms[0] * norms[1]) if a != 0: # for non-orthogonal iterates, apply change # determine root of y^3 - 3/2 a y^2 + 1/2 a = 0 y_part = cbrt(2 * sqrt(a ** 4 + a ** 2) - 2 * a - a ** 3) y = 0.5 * (a ** 2 / y_part + y_part - a) # apply projection u_new = zeros_like(u_norm) u_new[0] = u_norm[0] - (y / (y ** 2 - 1)) * (u_norm[1] - y * u_norm[0]) u_new[1] = (1 / (y ** 2 - 1)) * (u_norm[1] - y * u_norm[0]) return u_new else: if a == 0: # for already orthogonal iterates, apply no change return u_norm elif a == -1 or a == 1: # Cannot do anything for parallel vectors return u_norm # determine roots of y^2 - 2/a y + 1 = 0 elif a > 0: y = 1 / a + sqrt(1 / a ** 2 - 1) elif a < 0: y = 1 / a - sqrt(1 / a ** 2 - 1) else: raise Exception("This should never rise, check calculation of a") # apply projection u_new = zeros_like(u_norm) u_new[0] = u_norm[0] - (y / (y ** 2 - 1)) * (y * u_norm[0] - u_norm[1]) u_new[1] = (1 / (y ** 2 - 1)) * (y * u_norm[0] - u_norm[1]) return u_new class P_norm(ProxOperator): ... ... @@ -63,9 +70,11 @@ if __name__ == "__main__": portho = P_orthonorm(exp) pnorm = P_norm(exp) th = 0.55 inp = np.array([[1, 0], [np.cos(th * np.pi), np.sin(th * np.pi)]]) out = portho.eval(inp) print("Input:", inp) print("Output:", out) print("Inner product of the output: ", np.sum(out[0]*out[1]) ) \ No newline at end of file checksum = [] thlist = np.arange(0.11, 1.51, 0.1) for th in thlist: inp = np.array([[1, 0], [np.cos(th * np.pi), np.sin(th * np.pi)]]) out = portho.eval(inp) checksum += [np.sum(out[0] * out[1])] assert np.sum(checksum) < 1e-14, "Test failed" print("Test passed")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment