Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in
Toggle navigation
Menu
Open sidebar
nam
ProxPython
Commits
d7885374
Commit
d7885374
authored
Dec 17, 2019
by
markus.meier01
Browse files
fixed some things in Wirtinger algorithm
parent
e0f42d12
Changes
1
Hide whitespace changes
Inline
Side-by-side
proxtoolbox/Algorithms/Wirtinger.py
View file @
d7885374
...
...
@@ -4,50 +4,51 @@ from numpy.fft import fft, ifft, fft2, ifft2
from
scipy.linalg
import
norm
from
.algorithms
import
Algorithm
import
numpy
as
np
## Doesn't work yet
import
cmath
## Doesn't work yet
, relative error increases over the iterations instead of decreasing.
class
Wirtinger
(
Algorithm
):
def
__init__
(
self
,
config
):
self
.
Prox1
=
config
[
'proxoperators'
][
0
](
config
)
self
.
Prox2
=
config
[
'proxoperators'
][
1
](
config
)
normest
=
config
[
'norm_data'
]
self
.
norm_data
=
config
[
'norm_data'
]
self
.
Ny
=
config
[
'Ny'
]
self
.
Nx
=
config
[
'Nx'
]
self
.
Nz
=
config
[
'Nz'
]
L
=
config
[
'product_space_dimension'
]
u
=
config
[
'u_0'
]
self
.
product_space_dimension
=
config
[
'product_space_dimension'
]
self
.
u
=
config
[
'u_0'
]
if
'truth'
in
config
:
x
=
config
[
'truth'
]
x
_dim
=
config
[
'truth_dim'
]
normM
=
config
[
'norm_truth'
]
self
.
truth
=
config
[
'truth'
]
self
.
truth
_dim
=
config
[
'truth_dim'
]
self
.
norm_truth
=
config
[
'norm_truth'
]
MAXIT
=
config
[
'MAXIT'
]
TOL
=
config
[
'TOL'
]
Masks
=
np
.
zeros
((
self
.
Ny
,
self
.
Nx
,
L
))
if
'JWST'
in
config
:
Masks
=
np
.
zeros
(
Ny
,
Nx
,
L
)
indicator_ampl
=
config
[
'indicator_ampl'
]
illumination_phase
=
config
[
'illumination_phase'
]
for
j
in
range
(
L
):
Masks
[:,:,
j
]
=
indicator_ampl
*
exp
(
1j
*
illumination_phase
[:,:,
j
])
Y
=
config
[
'data_sq'
]
elif
'CDP'
in
config
:
Masks
=
config
[
'Masks'
]
Y
=
config
[
'data_sq'
]
else
:
Y
=
config
[
'data_sq'
]
normM
=
config
[
'norm_rt_data'
]
Masks
[
0
,
0
,
0
:
L
]
=
1
/
normM
## def run(self, u, TOL, MAXIT):
if
(
'experiment'
in
config
):
self
.
experiment
=
config
[
'experiment'
]
if
self
.
experiment
==
'CDP'
:
self
.
Masks
=
config
[
'Masks'
]
self
.
data_sq
=
config
[
'data_sq'
]
elif
self
.
experiment
==
'JWST'
:
self
.
Masks
=
np
.
zeros
(
self
.
Ny
,
self
.
Nx
,
self
.
product_space_dimension
)
self
.
indicator_ampl
=
config
[
'indicator_ampl'
]
self
.
illumination_phase
=
config
[
'illumination_phase'
]
self
.
data_sq
=
config
[
'data_sq'
]
for
j
in
range
(
product_space_dimension
):
Masks
[:,:,
j
]
=
indicator_ampl
*
exp
(
1j
*
illumination_phase
[:,:,
j
])
else
:
self
.
data_sq
=
config
[
'data_sq'
]
self
.
norm_rt_data
=
config
[
'norm_rt_data'
]
self
.
Masks
[
0
,
0
,
0
:
self
.
product_space_dimension
]
=
1
/
self
.
norm_truth
def
run
(
self
,
u
,
TOL
,
MAXIT
):
Masks
=
self
.
Masks
Prox1
=
self
.
Prox1
Prox2
=
self
.
Prox2
...
...
@@ -62,77 +63,90 @@ class Wirtinger(Algorithm):
q
=
u
.
shape
[
3
]
iter
=
0
change
=
zeros
(
MAXIT
+
1
,
dtype
=
u
.
dtype
)
change
=
zeros
(
MAXIT
,
dtype
=
u
.
dtype
)
change
[
0
]
=
999
gap
=
change
.
copy
()
if
self
.
Nx
==
1
:
A
=
lambda
I
:
fft
(
Masks
*
I
)
At
=
lambda
Y
:
mean
(
conj
(
Masks
)
*
ifft
(
Y
),
1
)
At
=
lambda
Y
:
np
.
repeat
(
mean
(
conj
(
Masks
)
*
ifft
(
Y
),
1
)
[:,
np
.
newaxis
],
self
.
product_space_dimension
,
axis
=
1
)
elif
self
.
Ny
==
1
:
A
=
lambda
I
:
fft
(
Masks
*
I
)
At
=
lambda
Y
:
mean
(
conj
(
Masks
)
*
ifft
(
Y
),
0
)
At
=
lambda
Y
:
np
.
repeat
(
mean
(
conj
(
Masks
)
*
ifft
(
Y
),
0
)
[
np
.
newaxis
,:],
self
.
product_space_dimension
,
axis
=
0
)
else
:
A
=
lambda
I
:
fft2
(
Masks
*
I
)
h
=
(
mean
(
conj
(
Masks
)
*
ifft2
(
Y
),
2
))
At
=
lambda
Y
:
np
.
repeat
(
h
[:,:,
np
.
newaxis
],
L
,
axis
=
2
)
At
=
lambda
Y
:
np
.
repeat
(
mean
(
conj
(
Masks
)
*
ifft2
(
Y
),
2
)[:,:,
np
.
newaxis
],
self
.
product_space_dimension
,
axis
=
2
)
if
hasattr
(
self
,
'truth'
):
Relerrs
=
change
.
copy
()
tau0
=
330
mu
=
lambda
t
:
min
(
1
-
exp
(
-
t
/
tau0
),
0.4
)
while
iter
<
MAXIT
and
change
[
iter
]
>=
TOL
:
while
iter
<
MAXIT
-
1
and
change
[
iter
]
>=
TOL
:
iter
+=
1
Bz
=
A
(
u
)
C
=
(
abs
(
Bz
)
**
2
-
Y
)
*
Bz
C
=
(
abs
(
Bz
)
**
2
-
self
.
data_sq
)
*
Bz
grad
=
At
(
C
)
step
=
mu
(
iter
)
/
normest
**
2
*
grad
u
=
u
-
step
tmp_change
=
0
step
=
mu
(
iter
)
/
self
.
norm_data
**
2
*
grad
u
=
u
-
step
#print(norm(u,'fro'))
#print("u")
#print(norm(Bz,'fro'))
#print("Bz")
#print(norm(C,'fro'))
#print("C")
#print(norm(grad,'fro'))
#print("grad")
#print(norm(step,'fro'))
#print("step")
#print(norm(u,'fro'))
#print("newu")
tmp_change
=
0
tmp_gap
=
0
if
p
==
1
and
q
==
1
:
tmp_change
=
(
norm
(
step
,
'fro'
)
/
normM
)
**
2
tmp_change
=
(
norm
(
step
,
'fro'
)
/
self
.
norm_truth
)
**
2
tmp_u
=
Prox1
.
work
(
u
)
tmp1
=
Prox2
.
work
(
tmp_u
)
tmp_gap
=
(
norm
(
tmp1
-
tmp_u
,
'fro'
)
/
normM
)
**
2
if
x_dim
[
0
]
==
1
:
tmp_gap
=
(
norm
(
tmp1
-
tmp_u
,
'fro'
)
/
self
.
norm_truth
)
**
2
if
self
.
truth_dim
[
0
]
==
1
:
z
=
u
[
0
,:]
elif
x
_dim
[
1
]
==
1
:
elif
self
.
truth
_dim
[
1
]
==
1
:
z
=
u
[:,
0
]
else
:
z
=
u
[:,:,
0
]
Relerrs
[
iter
]
=
norm
(
x
-
exp
(
-
1j
*
angle
(
trace
(
x
.
transpose
()
*
z
)))
*
z
,
'fro'
)
/
normM
Relerrs
[
iter
]
=
norm
(
self
.
truth
-
cmath
.
exp
(
-
1j
*
angle
(
trace
(
self
.
truth
*
z
)))
*
z
,
'fro'
)
/
self
.
norm_truth
elif
q
==
1
:
tmp_u
=
Prox1
.
work
(
u
)
tmp1
=
Prox2
.
work
(
tmp_u
)
tmp_change
=
L
*
(
norm
(
step
[:,:,
1
],
'fro'
)
/
normM
)
**
2
for
j
in
range
(
L
):
tmp_gap
=
tmp_gap
+
(
norm
(
tmp1
[:,:,
j
]
-
tmp_u
[:,:,
j
],
'fro'
)
/
normM
)
**
2
Relerrs
[
iter
]
=
norm
(
x
-
exp
(
-
1j
*
angle
(
trace
(
x
*
tmp_u
[:,:,
0
])))
*
tmp_u
[:,:,
0
],
'fro'
)
/
normM
tmp_change
=
self
.
product_space_dimension
*
(
norm
(
step
[:,:,
0
],
'fro'
)
/
self
.
norm_truth
)
**
2
for
j
in
range
(
self
.
product_space_dimension
):
tmp_gap
=
tmp_gap
+
(
norm
(
tmp1
[:,:,
j
]
-
tmp_u
[:,:,
j
],
'fro'
)
/
self
.
norm_truth
)
**
2
Relerrs
[
iter
]
=
norm
(
self
.
truth
-
cmath
.
exp
(
-
1j
*
angle
(
trace
(
self
.
truth
*
tmp_u
[:,:,
0
])))
*
tmp_u
[:,:,
0
],
'fro'
)
/
self
.
norm_truth
else
:
Times
[
iter
]
=
toc
tmp_u
=
Prox1
(
u
)
tmp1
=
Prox2
(
tmp_u
)
for
j
in
range
(
L
):
for
k
in
range
(
Nz
):
tmp_change
=
tmp_change
+
(
norm
(
step
[:,:,
k
,
j
],
'fro'
)
/
normM
)
**
2
tmp_gap
=
tmp_gap
+
(
norm
(
step
[:,:,
k
,
j
],
'fro'
)
/
normM
)
**
2
Relerrs
[
iter
]
=
norm
(
x
-
exp
(
-
1j
*
angle
(
trace
(
x
.
transpose
()
*
tmp_u
[:,:,:,
0
])))
*
tmp_u
[:,:,:,
0
],
'fro'
)
/
normM
tmp_u
=
Prox1
.
work
(
u
)
tmp1
=
Prox2
.
work
(
tmp_u
)
for
j
in
range
(
self
.
product_space_dimension
):
for
k
in
range
(
self
.
Nz
):
tmp_change
=
tmp_change
+
(
norm
(
step
[:,:,
k
,
j
],
'fro'
)
/
self
.
norm_truth
)
**
2
tmp_gap
=
tmp_gap
+
(
norm
(
step
[:,:,
k
,
j
],
'fro'
)
/
self
.
norm_truth
)
**
2
Relerrs
[
iter
]
=
norm
(
self
.
truth
-
cmath
.
exp
(
-
1j
*
angle
(
trace
(
self
.
truth
*
tmp_u
[:,:,:,
0
])))
*
tmp_u
[:,:,:,
0
],
'fro'
)
/
self
.
norm_truth
gap
[
iter
]
=
sqrt
(
tmp_gap
)
change
[
iter
]
=
sqrt
(
tmp_change
)
print
(
14
)
tmp
=
Prox1
(
u
)
tmp2
=
Prox2
(
u
)
tmp
=
Prox1
.
work
(
u
)
tmp2
=
Prox2
.
work
(
u
)
if
self
.
Nx
==
1
:
u1
=
tmp
[:,
0
]
u2
=
tmp2
[:,
0
]
...
...
@@ -146,7 +160,9 @@ class Wirtinger(Algorithm):
u1
=
tmp
u2
=
tmp2
change
=
change
[
1
:
iter
]
change
=
change
[
0
:
iter
]
return
{
'u1'
:
u1
,
'u2'
:
u2
,
'iter'
:
iter
,
'change'
:
change
,
'gap'
:
gap
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment