Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
nam
ProxPython
Commits
9c4bc48f
Commit
9c4bc48f
authored
Nov 04, 2020
by
jansen31
Browse files
reconstruct coupled orbitals
parent
ff4aca49
Changes
3
Hide whitespace changes
Inline
Side-by-side
proxtoolbox/experiments/orbitaltomography/degenerate_orbits.py
View file @
9c4bc48f
...
...
@@ -12,13 +12,13 @@ class DegenerateOrbital(PlanarMolecule):
defaultParams
=
{
'experiment_name'
:
'2D ARPES'
,
'data_filename'
:
None
,
'from_intensity_data'
:
Fals
e
,
'from_intensity_data'
:
Tru
e
,
'object'
:
'real'
,
'degeneracy'
:
2
,
# Number of degenerate states to reconstruct
'constraint'
:
'sparse real'
,
'sparsity_parameter'
:
40
,
'use_sparsity_with_support'
:
True
,
'threshold_for_support'
:
0.
1
,
'threshold_for_support'
:
0.
05
,
'support_filename'
:
None
,
'Nx'
:
None
,
'Ny'
:
None
,
...
...
@@ -156,10 +156,10 @@ class DegenerateOrbital(PlanarMolecule):
ax
[
ii
].
set_title
(
"Degenerate orbit %d"
%
ii
)
im
=
ax
[
-
2
].
imshow
(
np
.
sum
(
abs
(
u_show
)
**
2
,
axis
=
0
))
ax
[
-
2
].
set_title
(
"Local density of states"
)
plt
.
colorbar
(
im
,
ax
=
ax
[
-
2
])
#
plt.colorbar(im, ax=ax[-2]
, shrink=0.7
)
im
=
ax
[
-
1
].
imshow
(
fourier_intensity
)
ax
[
-
1
].
set_title
(
"Fourier domain intensity"
)
plt
.
colorbar
(
im
,
ax
=
ax
[
-
1
])
#
plt.colorbar(im, ax=ax[-1]
, shrink=0.7
)
plt
.
tight_layout
()
if
show
:
plt
.
show
()
...
...
proxtoolbox/experiments/orbitaltomography/orthogonal_orbits.py
0 → 100644
View file @
9c4bc48f
import
matplotlib.pyplot
as
plt
import
numpy
as
np
from
skimage.io
import
imread
from
scipy.ndimage
import
binary_dilation
,
shift
,
center_of_mass
from
proxtoolbox.experiments.orbitaltomography.planar_molecule
import
PlanarMolecule
from
proxtoolbox.utils.orbitaltomog
import
shifted_fft
,
fourier_interpolate
,
bin_array
,
shifted_ifft
from
proxtoolbox.utils.visualization.complex_field_visualization
import
complex_to_rgb
class
OrthogonalOrbitals
(
PlanarMolecule
):
@
staticmethod
def
getDefaultParameters
():
defaultParams
=
{
'experiment_name'
:
'2D ARPES'
,
'data_filename'
:
None
,
'from_intensity_data'
:
True
,
'object'
:
'real'
,
'degeneracy'
:
2
,
# Number of degenerate states to reconstruct
'constraint'
:
'sparse real'
,
'sparsity_parameter'
:
40
,
'use_sparsity_with_support'
:
True
,
'threshold_for_support'
:
0.01
,
'support_filename'
:
None
,
'Nx'
:
None
,
'Ny'
:
None
,
'Nz'
:
None
,
'MAXIT'
:
500
,
'TOL'
:
1e-10
,
'lambda_0'
:
0.85
,
'lambda_max'
:
0.50
,
'lambda_switch'
:
50
,
'data_ball'
:
.
999826
,
'TOL2'
:
1e-15
,
'diagnostic'
:
True
,
'algorithm'
:
'CP'
,
'iterate_monitor_name'
:
'FeasibilityIterateMonitor'
,
# 'IterateMonitor', #
'rotate'
:
False
,
'verbose'
:
1
,
'graphics'
:
1
,
'interpolate_and_zoom'
:
True
,
'debug'
:
True
,
'progressbar'
:
None
}
return
defaultParams
def
__init__
(
self
,
**
kwargs
):
super
(
OrthogonalOrbitals
,
self
).
__init__
(
**
kwargs
)
def
loadData
(
self
):
"""
Load data and set in the correct format for reconstruction
Parameters are taken from experiment class (self) properties, which must include::
- data_filename: str, path to the data file, or list of file names
- from_intensity_data: bool, if the data file gives intensities rather than field amplitude
- support_filename: str, optional path to file with object support
- use_sparsity_with_support: bool, if true, use a support before the sparsity constraint.
The support is calculated by thresholding the object autocorrelation, and dilate the result
- threshold_for_support: float, in range [0,1], fraction of the maximum at which to threshold when
determining support or support for sparsity
"""
# load data
if
self
.
data_filename
is
None
:
self
.
data_filename
=
input
(
'Please enter the path to the datafile: '
)
try
:
if
isinstance
(
self
.
data_filename
,
str
):
self
.
data
=
imread
(
self
.
data_filename
)
else
:
self
.
data
=
np
.
array
([
imread
(
fname
)
for
fname
in
self
.
data_filename
])
except
FileNotFoundError
:
print
(
"Tried path %s, found nothing. "
%
self
.
data_filename
)
self
.
data_filename
=
input
(
'Please enter a valid path to the datafile: '
)
self
.
data
=
imread
(
self
.
data_filename
)
# If data is corrected for A.K, then it should be well centered. we can check that here
for
i
in
range
(
len
(
self
.
data
)):
cm
=
center_of_mass
(
self
.
data
[
i
]
**
2
)
to_shift
=
tuple
([
s
//
2
-
cm
[
i
]
for
i
,
s
in
enumerate
(
self
.
data
[
i
].
shape
)])
self
.
data
[
i
]
=
shift
(
self
.
data
[
i
],
to_shift
,
mode
=
'nearest'
,
order
=
1
)
# Keep the same resolution?
self
.
Nz
,
ny
,
nx
=
self
.
data
.
shape
if
self
.
Ny
is
None
:
self
.
Ny
=
ny
if
self
.
Nx
is
None
:
self
.
Nx
=
nx
if
ny
!=
self
.
Ny
or
nx
!=
self
.
Nx
:
# binning must be done for the intensity-data, as that preserves the normalization
if
self
.
from_intensity_data
:
self
.
data
=
bin_array
(
self
.
data
,
(
self
.
Nz
,
self
.
Ny
,
self
.
Nx
))
else
:
self
.
data
=
np
.
sqrt
(
bin_array
(
self
.
data
**
2
,
(
self
.
Nz
,
self
.
Ny
,
self
.
Nx
)))
self
.
Nz
,
self
.
Ny
,
self
.
Nx
=
self
.
data
.
shape
# Calculate electric field and norm of the data
if
self
.
from_intensity_data
:
# avoid sqrt of negative numbers (due to background subtraction)
self
.
data
=
np
.
where
(
self
.
data
>
0
,
np
.
sqrt
(
abs
(
self
.
data
)),
np
.
zeros_like
(
self
.
data
))
self
.
norm_data
=
np
.
sqrt
(
np
.
sum
(
self
.
data
**
2
))
# Object support determination
if
self
.
support
is
not
None
:
self
.
support
=
imread
(
self
.
support_filename
)
else
:
self
.
support
=
support_from_stack
(
self
.
data
,
threshold
=
self
.
threshold_for_support
,
absolute_autocorrelation
=
True
,
binary_dilate_support
=
1
)
if
self
.
use_sparsity_with_support
:
self
.
sparsity_support
=
support_from_stack
(
self
.
data
,
threshold
=
self
.
threshold_for_support
,
binary_dilate_support
=
1
)
self
.
createRandomGuess
()
# some variables wich are necessary for the algorithm:
self
.
data_sq
=
self
.
data
**
2
self
.
data_zeros
=
np
.
where
(
self
.
data
==
0
)
def
createRandomGuess
(
self
):
"""
Taking the measured data, add a random phase and calculate the resulting iterate guess
"""
self
.
u0
=
self
.
data
*
np
.
exp
(
1j
*
2
*
np
.
pi
*
np
.
random
.
random_sample
(
self
.
data
.
shape
))
self
.
u0
=
shifted_fft
(
self
.
u0
,
axes
=
(
-
1
,
-
2
))
def
setupProxOperators
(
self
):
"""
Determine the prox operators to be used based on the given constraint.
This method is called during the initialization process.
sets the parameters:
- self.proxOperators
- self.propagator and self.inverse_propagator
"""
# Select the right real space operator sparsity-based proxoperators
self
.
proxOperators
.
append
(
'P_Sparsity_real_incoherent'
)
# Apply orthonormality constraint
self
.
proxOperators
.
append
(
'P_orthonorm'
)
# Modulus proxoperator (normally the second operator)
self
.
proxOperators
.
append
(
'P_M'
)
self
.
propagator
=
'PropagatorFFT2'
self
.
inverse_propagator
=
'InvPropagatorFFT2'
self
.
nProx
=
len
(
self
.
proxOperators
)
def
plotInputData
(
self
):
"""Quick plotting routine to show the data, initial guess and the sparsity support"""
fig
,
ax
=
plt
.
subplots
(
2
,
self
.
Nz
+
1
,
figsize
=
(
12
,
7
))
for
ii
in
range
(
self
.
Nz
):
im
=
ax
[
0
][
ii
].
imshow
(
self
.
data
[
ii
])
plt
.
colorbar
(
im
,
ax
=
ax
[
0
][
ii
])
ax
[
0
][
ii
].
set_title
(
"Photoelectron spectrum %d"
%
ii
)
if
self
.
sparsity_support
is
not
None
:
im
=
ax
[
0
][
-
1
].
imshow
(
self
.
sparsity_support
,
cmap
=
'gray'
)
# plt.colorbar(im, ax=ax[2])
ax
[
0
][
-
1
].
set_title
(
"Sparsity support"
)
for
ii
in
range
(
self
.
Nz
):
im
=
ax
[
1
][
ii
].
imshow
(
complex_to_rgb
(
self
.
u0
[
ii
]))
plt
.
colorbar
(
im
,
ax
=
ax
[
1
][
ii
])
ax
[
1
][
ii
].
set_title
(
"Degenerate orbit %d"
%
ii
)
ax
[
1
][
-
1
].
imshow
(
np
.
sum
(
abs
(
self
.
u0
)
**
2
,
axis
=
0
))
ax
[
1
][
-
1
].
set_title
(
"Integrated density of states"
)
plt
.
show
()
def
show
(
self
,
**
kwargs
):
"""
Create basic result plots of the phase retrieval procedure
"""
super
(
PlanarMolecule
,
self
).
show
()
self
.
output
[
'u1'
]
=
self
.
algorithm
.
prox1
.
eval
(
self
.
algorithm
.
u
)
self
.
output
[
'u2'
]
=
self
.
algorithm
.
prox2
.
eval
(
self
.
algorithm
.
u
)
figsize
=
kwargs
.
pop
(
"figsize"
,
(
12
,
3
))
for
i
,
operator
in
enumerate
(
self
.
algorithm
.
proxOperators
):
operator_name
=
self
.
proxOperators
[
i
].
__name__
f
=
self
.
plot_guess
(
operator
.
eval
(
self
.
algorithm
.
u
),
name
=
"%s satisfied"
%
operator_name
,
show
=
False
,
interpolate_and_zoom
=
self
.
interpolate_and_zoom
,
figsize
=
figsize
)
self
.
output
[
'plots'
].
append
(
f
)
# f1 = self.plot_guess(self.output['u1'], name='Best approximation: physical constraint satisfied', show=False)
# f2 = self.plot_guess(self.output['u2'], name='Best approximation: Fourier constraint satisfied', show=False)
# prop = self.propagator(self)
# u_hat = prop.eval(self.algorithm.prox1.eval(self.algorithm.u))
# h = self.plot_guess(u_hat, show=False, name="Fourier domain measurement projection")
# self.output['plots'].append(f1)
# self.output['plots'].append(f2)
# self.output['plots'].append(h)
plt
.
show
()
# def saveOutput(self, **kwargs):
# super(PlanarMolecule, self).saveOutput(**kwargs)
def
plot_guess
(
self
,
u
,
name
=
None
,
show
=
True
,
interpolate_and_zoom
=
False
,
figsize
=
(
12
,
6
)):
""""Given a list of fields, plot the individual fields and the combined intensity"""
prop
=
self
.
propagator
(
self
)
# This is not a string but the indicated class itself, to be instantiated
u_hat
=
prop
.
eval
(
u
)
fourier_intensity
=
np
.
sqrt
(
np
.
sum
(
abs
(
u_hat
)
**
2
,
axis
=
0
))
if
interpolate_and_zoom
:
u_show
=
self
.
interp_zoom_field
(
u
)
else
:
u_show
=
u
fig
,
ax
=
plt
.
subplots
(
2
,
len
(
u
)
+
1
,
figsize
=
figsize
,
num
=
name
)
for
ii
in
range
(
self
.
Nz
):
im
=
ax
[
0
][
ii
].
imshow
(
complex_to_rgb
(
u_show
[
ii
]))
ax
[
0
][
ii
].
set_title
(
"Degenerate orbit %d"
%
ii
)
im
=
ax
[
0
][
-
1
].
imshow
(
np
.
sum
(
abs
(
u_show
)
**
2
,
axis
=
0
))
ax
[
0
][
-
1
].
set_title
(
"Local density of states"
)
for
ii
in
range
(
self
.
Nz
):
im
=
ax
[
1
][
ii
].
imshow
(
complex_to_rgb
(
u_hat
[
ii
]))
ax
[
1
][
ii
].
set_title
(
"Fourier domain %d"
%
ii
)
# plt.colorbar(im, ax=ax[-2], shrink=0.7)
im
=
ax
[
1
][
-
1
].
imshow
(
fourier_intensity
)
ax
[
1
][
-
1
].
set_title
(
"Total Fourier domain intensity"
)
# plt.colorbar(im, ax=ax[-1], shrink=0.7)
plt
.
tight_layout
()
if
show
:
plt
.
show
()
return
fig
def
interp_zoom_field
(
self
,
u
,
interpolation
=
2
,
zoom
=
0.5
):
"""
interpolate a field and zoom in to the center
"""
nt
,
ny
,
nx
=
u
.
shape
cm
=
center_of_mass
(
np
.
sum
(
abs
(
u
)
**
2
,
axis
=
0
))
to_shift
=
(
0
,
-
1
*
int
(
np
.
round
(
cm
[
0
]
-
ny
/
2
)),
-
1
*
int
(
np
.
round
(
cm
[
1
]
-
nx
/
2
)))
centered
=
np
.
roll
(
u
,
to_shift
,
axis
=
(
0
,
1
,
2
))
zmy
=
int
(
ny
*
zoom
)
//
2
zmx
=
int
(
nx
*
zoom
)
//
2
zoomed
=
centered
[:,
zmy
:
ny
-
zmy
,
zmx
:
nx
-
zmx
]
interpolated
=
np
.
array
([
fourier_interpolate
(
u_i
,
factor
=
interpolation
)
for
u_i
in
zoomed
])
return
interpolated
def
support_from_stack
(
input_array
:
np
.
ndarray
,
threshold
:
float
=
0.1
,
relative_threshold
:
bool
=
True
,
input_in_fourier_domain
:
bool
=
True
,
absolute_autocorrelation
:
bool
=
True
,
binary_dilate_support
:
int
=
0
)
->
np
.
ndarray
:
"""
Determine an initial support from a list of autocorrelations.
Args:
input_array: either the measured diffraction patterns (arpes patterns) or guesses of the objects
threshold: support is everywhere where the autocorrelation is higher than the threshold
relative_threshold: If true, threshold at threshold*np.amax(autocorrelation)
input_in_fourier_domain: False if a guess of the object is given in input_array
absolute_autocorrelation: Take the absolute value of the autocorrelation? (Generally a
good idea for objects which are not non-negative)
binary_dilate_support: number of dilation operations to apply to the support.
Returns:
support array (dimensions: input_array.shape[1:], dtype=np.int)
"""
_axes
=
tuple
(
range
(
-
1
*
input_array
.
ndim
+
1
,
0
))
if
not
input_in_fourier_domain
:
kspace
=
shifted_fft
(
input_array
,
axes
=
_axes
)
else
:
kspace
=
input_array
# Taking absolute value of the Fourier transform yields autocorrelation by conv. theorem)
autocorrelation
=
shifted_ifft
(
abs
(
kspace
),
axes
=
_axes
)
if
absolute_autocorrelation
:
autocorrelation
=
abs
(
autocorrelation
)
# Take the sum along the first axis to get the average of the autocorrelations
autocorrelation
=
np
.
sum
(
autocorrelation
,
axis
=
0
)
# Detetmine thresholding
maxval
=
np
.
amax
(
autocorrelation
)
if
relative_threshold
:
threshold_val
=
threshold
*
maxval
else
:
threshold_val
=
threshold
support
=
(
autocorrelation
>
threshold_val
).
astype
(
np
.
uint
)
# Dilate support to make it a bit too big (also fills small gaps)
if
binary_dilate_support
>
0
:
support
=
binary_dilation
(
support
,
iterations
=
binary_dilate_support
).
astype
(
np
.
uint
)
return
support
proxtoolbox/proxoperators/P_orthonorm.py
View file @
9c4bc48f
...
...
@@ -37,8 +37,8 @@ class P_orthonorm(ProxOperator):
raise
Exception
(
"This should never rise, check calculation of a"
)
# apply projection
u_new
=
zeros_like
(
u_norm
)
u_new
[
0
]
=
u_norm
[
0
]
-
(
y
/
(
y
**
2
-
1
))
*
(
y
*
u_norm
[
0
]
-
u_norm
[
1
])
u_new
[
1
]
=
(
1
/
(
y
**
2
-
1
))
*
(
y
*
u_norm
[
0
]
-
u_norm
[
1
])
u_new
[
1
]
=
u_norm
[
0
]
-
(
y
/
(
y
**
2
-
1
))
*
(
y
*
u_norm
[
0
]
-
u_norm
[
1
])
u_new
[
0
]
=
(
1
/
(
y
**
2
-
1
))
*
(
y
*
u_norm
[
0
]
-
u_norm
[
1
])
return
u_new
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment