Commit cf30346c authored by Jens Lucht

Remove Tikhonov/PGD PoC code from repo.

parent 25f26b2d
1 merge request: !27 Nonlinear Tikhonov-regularized Fresnel reconstruction algorithm with (projected) gradient descent
Pipeline #269372 passed
toolboxPath = '/home/jlucht/workspace/holotomotoolbox';
addpath(genpath(toolboxPath));
getDataset = @(fileName, H5path) double(flipud(rot90(h5read(fileName, H5path))));
%% load sample
fname = '/home/jlucht/workspace/fresnel/Colloids_SiO2_nonabs.h5';
holos = getDataset(fname, '/holograms');
fresnel_nums = getDataset(fname, '/fresnel_numbers');
%% toolbox samples
holodata = getHologram('beads');
holos = holodata.holograms(:,:,:);
fresnel_nums = holodata.fresnelNumbers;
%%
showImage(holos(:,:,1));
size(holos)
%% Settings
s = phaserec_ctf;
% s.lim1 = 1e-3; s.lim2 = 5e-1;
s.lim1 = 5e-4; s.lim2 = 1e-4;
% s.optimization.stepsizeRule = 'constantStepsize'; % be consistent with Torch
s.optimization.stepsizeRule = 'bbStepsize';
% s.optimization.initialStepsize = 6.2e-2;
% s.optimization.linesearchMaxTries = 1; % disable non-monotone line search (not in py)
s.optimization.verbose = true;
s.optimization.tolerance = 1e-2; % be faster
s.optimization.useGPU = false;
%s.optimization.maxIterations = 42;
%% CTF
rec_ctf = phaserec_ctf(holos, fresnel_nums, s);
%showImage(rec_ctf);
%%
[rec, stats] = phaserec_nonlinTikhonov(holos, fresnel_nums, s);
%%
showImage(rec, 1);
showImage(rec_ctf, 2);
%%
ny = size(rec,2);
nx = size(rec,1);
nx2 = floor(nx/2);
xx = 1:ny;
plot(xx, rec_ctf(nx2,:), 'g-', xx, rec(nx2,:), 'r-');
% plot(rec_ctf(nx2,:));
%%
alpha = [s.lim1, s.lim2];
if true
targetFile = './reconstructionBeads.h5';
% h5create errors if a dataset already exists, so create the datasets only once
if ~isfile(targetFile)
h5create(targetFile, '/holograms', size(holos));
h5create(targetFile, '/fresnelnumbers', size(fresnel_nums));
h5create(targetFile, '/alpha', size(alpha));
h5create(targetFile, '/tikhonov', size(rec));
h5create(targetFile, '/ctf', size(rec_ctf));
h5create(targetFile, '/tikhonov_stats/values', size(stats.functionalValues));
h5create(targetFile, '/tikhonov_stats/resiGrad', size(stats.resiGrad));
h5create(targetFile, '/tikhonov_stats/resiGradRel', size(stats.resiGradRel));
h5create(targetFile, '/tikhonov_stats/gradNorm', size(stats.gradNorm));
h5create(targetFile, '/tikhonov_stats/lr', size(stats.stepsize));
end
h5write(targetFile, '/holograms', holos);
h5write(targetFile, '/fresnelnumbers', fresnel_nums);
h5write(targetFile, '/alpha', alpha);
h5write(targetFile, '/tikhonov', rec);
h5write(targetFile, '/ctf', rec_ctf);
h5write(targetFile, '/tikhonov_stats/values', stats.functionalValues);
h5write(targetFile, '/tikhonov_stats/resiGrad', stats.resiGrad);
h5write(targetFile, '/tikhonov_stats/resiGradRel', stats.resiGradRel);
h5write(targetFile, '/tikhonov_stats/gradNorm', stats.gradNorm);
h5write(targetFile, '/tikhonov_stats/lr', stats.stepsize);
end
\ No newline at end of file
"""
Playing with Fresnel propagation and reconstruction in the Tikhonov scheme in PyTorch.
Author: Jens Lucht
"""
from typing import List, Callable
from functools import partial
import time
import torch
from torch import Tensor
from torch.nn import Module, Parameter, MSELoss, functional
from torch.fft import fftfreq, fft2, ifft2
from torch.optim import SGD
from torch.linalg import vector_norm
from hotopy.phase import cloetens_regularization
from hotopy.fourier import fftfreqn
from hotopy.phase.util import expand_fresnel_numbers, symmetric_padding
# new
from hotopy.phase.pgd import PGD, stepsize_barlizai_borwein
# lazy aliases
# vectorial 2-norm
v2norm = partial(vector_norm, ord=2)
iprint = partial(print, end='')
## beginregion matlab import
from scipy.io import loadmat
from pathlib import Path
import numpy as np
toolbox_path = Path("../../holotomotoolbox")
sample_path = Path("functions/generators/")
def matswap(inp):
return np.moveaxis(inp, -1, 0)
def get_holograms(name, unpack=True):
fname = f"holograms{name.capitalize()}.mat"
hpath = toolbox_path / sample_path / "holograms" / fname
mat = loadmat(hpath)
if unpack:
return mat['holograms'], mat['fresnelNumbers'][0]
return mat
## endregion
def fftfreqn_torch(shape, dx=1.0, device=None):
xi = [fftfreq(dim, dx, device=device) for dim in shape]
return torch.meshgrid(*xi, indexing="ij")
def fresnel_kernel(shape: tuple, nf: List[float], device=None):
kernel = torch.zeros((len(nf), *shape), dtype=torch.float64, device=device)
freqs = fftfreqn_torch(shape, device=device)
for i in range(len(shape)):
kernel.add_(torch.square(freqs[i]) / nf[:, i, None, None])
return torch.exp(-1j * torch.pi * kernel)
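# Sanity-check sketch (not part of the original PoC): the Fresnel transfer function
# exp(-i*pi*sum_i xi_i^2 / F_i) built above is a pure phase factor, so its modulus must
# be 1 for every frequency and every Fresnel number. Shape and Fresnel numbers below
# are arbitrary example values.
def _check_fresnel_kernel_unit_modulus(shape=(8, 8)):
    nf = torch.tensor([[1e-3, 1e-3]], dtype=torch.float64)  # one distance, 2d Fresnel numbers
    kernel = fresnel_kernel(shape, nf)
    # |exp(-i*phi)| == 1 up to floating-point rounding
    assert torch.allclose(kernel.abs(), torch.ones_like(kernel.abs()))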
def regularization_kernel(shape, nf, alpha=None):
# note: this function runs on the CPU only and then wraps the result as a tensor
freqs = fftfreqn(shape)
reg = cloetens_regularization(freqs, nf.cpu().numpy(), alpha=alpha)
return torch.from_numpy(reg)
def _torch_pad(pad):
"""
Converts per-axis padding (in numpy.pad order) to torch padding (reversed order, two values per dimension).
"""
return tuple([p for dim in reversed(pad) for p in dim])
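# Example (illustrative): numpy-style [(1, 2), (3, 4)] for axes (0, 1) becomes the
# torch-style flat tuple (3, 4, 1, 2), i.e. last dimension first, (before, after) pairs:
#     _torch_pad([(1, 2), (3, 4)])  ->  (3, 4, 1, 2)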
def _n_slice(*args, **kwargs):
args = [None if arg == 0 else arg for arg in args]
return slice(*args, **kwargs)
def sympad(shape: tuple, factor=2, dims: int = 2, mode: str = 'replicate') -> (Callable, Callable, list):
"""
Symmetric padding of the input tensor by the given factor.
"""
factor -= 1
newshape = [int(factor * s) for s in shape[-dims:]]
newshape = [0] * (len(shape) - dims) + newshape
padding = symmetric_padding(shape, newshape)
print(padding)
torch_padding = _torch_pad(padding[-dims:])
crops = tuple([_n_slice(pre, -post) for (pre, post) in padding])
def pad(input: Tensor):
# little hack for 2d padding
if input.ndim == 2:
input = input[None, ...]
out = 0
else:
out = slice(None)
return functional.pad(input, torch_padding, mode=mode)[out]
def crop(input: Tensor):
return input[crops[-input.ndim:]]
return pad, crop, newshape
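# Usage sketch (assumption: symmetric_padding splits the extra size evenly around each
# axis, as the crop slices above expect): pad a hologram stack by a factor of 2 in the
# two image dimensions, then crop back to the original field of view.
#     holos = torch.rand(5, 64, 64)
#     pad, crop, padshape = sympad(holos.shape, factor=2)
#     padded = pad(holos)                    # image dims grow to 128 x 128
#     assert crop(padded).shape == holos.shape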
class FresnelProp(Module):
"""
Fresnel (forward) propagation defined by the given Fresnel numbers nf.
"""
def __init__(self, shape, nf, device=None):
super(FresnelProp, self).__init__()
self.kernel = fresnel_kernel(shape, nf, device=device).to(memory_format=torch.contiguous_format)
self._dim = (-2, -1)
self.shape = shape
def forward(self, x: Tensor):
return ifft2(self.kernel * fft2(x, s=self.shape), s=self.shape)
class ImageModel(Module):
def __init__(self, x0, nf, device=None):
super(ImageModel, self).__init__()
self.x = Parameter(x0, requires_grad=True)
self.propagator = FresnelProp(x0.shape, nf, device=device)
def forward(self):
return self.propagator(torch.exp(self.x.mul(1j))).abs().square()
class TikhonovLoss:
def __init__(self, regularization: Tensor, device=None, reduction='sum') -> None:
self.regularization = torch.as_tensor(regularization.sqrt(), device=device).to(memory_format=torch.contiguous_format)
self.data_loss = MSELoss(reduction=reduction)
def __call__(self, input: Tensor, target: Tensor, recon: Tensor) -> Tensor:
return 0.5 * self.data_loss(input, target) + \
v2norm(fft2(recon, norm="ortho").mul(self.regularization)).square() # FIXME: raises warning about complex to real
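# Restated (no new behavior): the call above evaluates the discretized Tikhonov functional
#   L(phi) = 0.5 * || |D(exp(i*phi))|^2 - g ||_2^2 + sum_xi R(xi) * |FFT{phi}(xi)|^2
# with D the Fresnel propagator, g the measured holograms and R the Cloetens-type
# regularization weights; storing sqrt(R) in __init__ and squaring the weighted norm here
# reproduces the linear weighting by R.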
def tikhonov(holos, fresnel_nums):
holoshape = N, *imshape = holos.shape
ndim = len(imshape)
alpha = [5e-4, 1e-4]
# alpha = [1e-3, 5e-1] # does not work, sensitive
device = 'cpu'
max_iter = 22
tolerance = 1e-2
precision = torch.double
"""
Without any constraint, projected gradient descent reduces to plain gradient descent. It is an instance of
forward-backward splitting in which one function is the indicator function of a convex set, so its proximal
operator is the projection onto that set. In PyTorch, gradient descent is implemented by the SGD (stochastic
gradient descent) optimizer. Extending it to projected gradient descent is straightforward: apply the
projection after the 'inner' forward (gradient) step, which can be wrapped in a subclass of SGD.
"""
lr = 1 / (4 * N + max(alpha))
# initial guess with CTF
ctf = CTF(imshape, fresnel_nums.numpy(), alpha=alpha, device=device)
initial = ctf(holos.numpy())
fresnel_nums = torch.tensor(expand_fresnel_numbers(fresnel_nums, ndim), device=device)
rec = initial.to(dtype=precision, memory_format=torch.contiguous_format, device=device) # variable
holos = holos.to(dtype=precision, memory_format=torch.contiguous_format, device=device) # target
# setup "Tikhonov functional"
fresnel = ImageModel(rec, fresnel_nums, device=device)
loss_fn = TikhonovLoss(regularization_kernel(imshape, fresnel_nums, alpha=alpha), device=device, reduction='sum')
def error_functional(): # aka closure
y = fresnel()
return loss_fn(y, holos, fresnel.x)
# (projected) gradient descent
IdProj = lambda x: x
solver = PGD(
list(fresnel.parameters()),
[IdProj],
error_functional,
lr,
stepsize_barlizai_borwein,
ls_backtracking=2,
ls_decrease_factor=4,
)
print(f"initial ss {lr:.3e}")
## solving
def iterate():
# initialize
value = error_functional()
value.backward()
solver.state["ls_values"][0] = value
ref_grad = vector_norm(solver.param.grad)
# iterate
for i in range(max_iter):
# check convergence
rel_grad = vector_norm(solver.param.grad) / ref_grad
print(f"PGD {i}: value = {value:.3e} rel_grad = {rel_grad:.3e}")
if rel_grad < tolerance:
break
# do projected gradient step
value = solver.step()
# reset autograd and populate gradient for next iteration
solver.param.grad = None # clean gradient graphs for backward pass
value.backward()
return fresnel
return iterate, fresnel, solver
################################
# Proof of Concept
################################
if __name__ == '__main__':
from hotopy.phase import CTF
from matplotlib import pyplot as plt
import numpy as np
import h5py
# load example data (from matlab holotomotoolbox)
holos, fresnel_nums = get_holograms('beads')
holos = matswap(holos).astype(np.float_) # [:, 200:1000, 1000:1800]
holos, fresnel_nums = torch.from_numpy(holos, ), torch.from_numpy(fresnel_nums)
npad = 1
# apply padding first
pad, crop, padshape = sympad(holos.shape, npad)
holos = pad(holos)
holoshape = N, *imshape = holos.shape
tikhonov_optimization, fresnel, solver = tikhonov(holos, fresnel_nums)
t_start = time.time()
with torch.profiler.profile(
activities=[torch.profiler.profiler.ProfilerActivity.CUDA,
torch.profiler.profiler.ProfilerActivity.CPU],
with_stack=True,
) as profiler:
tikhonov_optimization()
t_end = time.time()
print(profiler.key_averages(group_by_stack_n=5).table(sort_by="self_cpu_time_total", row_limit=10))
print(f'Computation time: {t_end - t_start:.2e}')
"""
print(profiler.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
"""
# save results from the profiler
profiler.export_stacks('tikhonov-torch-profilter.stk')
profiler.export_chrome_trace('tikhonov-torch-trace.json')
# prepare for plotting, remove padding
recon = crop(fresnel.x).detach().cpu()
recon_np = recon.numpy()
# inital_np = crop(stats["p0"]).detach().cpu().numpy()
n = recon.shape[0]
k = solver.state["niter"]
# save for later
# hf = h5py.File('pyreconBeads.h5', 'w')
# hf.create_dataset('tikhonov', data=recon_np)
# hf.create_dataset('ctf', data=inital_np)
# for var, s in stats.items():
# if isinstance(s, torch.Tensor):
# hf.create_dataset(f'stats/{var}', data=s[:k].cpu().detach().numpy())
# hf.close()
# plotting
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(3*5., 5.))
fig.suptitle('Compare reconstructions')
ax1.set_title(r'Tikhonov')
im1 = ax1.imshow(recon_np)
fig.colorbar(im1, ax=ax1)
ax2.set_title('CTF (initial)')
# im2 = ax2.imshow(inital_np)
# fig.colorbar(im2, ax=ax2)
ax3.set_title('Line profile')
# ax3.plot(inital_np[n // 2], '--', c='b', label='CTF')
ax3.plot(recon_np[n // 2], c='r', label='Tikhonov')
ax3.legend()
plt.show()
\ No newline at end of file
"""
Playing with Fresnel propagation and reconstruction in the Tikhonov scheme in PyTorch.
Author: Jens Lucht
"""
from typing import List, Callable
from functools import partial
import time
import torch
from torch import Tensor
from torch.nn import Module, Parameter, MSELoss, functional
from torch.fft import fftfreq, fft2, ifft2
from torch.optim import SGD
from torch.linalg import vector_norm
from hotopy.phase import cloetens_regularization
from hotopy.fourier import fftfreqn
from hotopy.phase.util import expand_fresnel_numbers, symmetric_padding
# lazy aliases
# vectorial 2-norm
v2norm = partial(vector_norm, ord=2)
iprint = partial(print, end='')
## beginregion matlab import
from scipy.io import loadmat
from pathlib import Path
import numpy as np
toolbox_path = Path("../../holotomotoolbox")
sample_path = Path("functions/generators/")
def matswap(inp):
return np.moveaxis(inp, -1, 0)
def get_holograms(name, unpack=True):
fname = f"holograms{name.capitalize()}.mat"
hpath = toolbox_path / sample_path / "holograms" / fname
mat = loadmat(hpath)
if unpack:
return mat['holograms'], mat['fresnelNumbers'][0]
return mat
## endregion
def fftfreqn_torch(shape, dx=1.0, device=None):
xi = [fftfreq(dim, dx, device=device) for dim in shape]
return torch.meshgrid(*xi, indexing="ij")
def fresnel_kernel(shape: tuple, nf: List[float], device=None):
kernel = torch.zeros((len(nf), *shape), dtype=torch.float64, device=device)
freqs = fftfreqn_torch(shape, device=device)
for i in range(len(shape)):
kernel.add_(torch.square(freqs[i]) / nf[:, i, None, None])
return torch.exp(-1j * torch.pi * kernel)
def regularization_kernel(shape, nf, alpha=None):
# note: this function runs on the CPU only and then wraps the result as a tensor
freqs = fftfreqn(shape)
reg = cloetens_regularization(freqs, nf.cpu().numpy(), alpha=alpha)
return torch.from_numpy(reg)
def _torch_pad(pad):
"""
Converts per-axis padding (in numpy.pad order) to torch padding (reversed order, two values per dimension).
"""
return tuple([p for dim in reversed(pad) for p in dim])
def _n_slice(*args, **kwargs):
args = [None if arg == 0 else arg for arg in args]
return slice(*args, **kwargs)
def sympad(shape: tuple, factor=2, dims: int = 2, mode: str = 'replicate') -> (Callable, Callable, list):
"""
Symmetric padding of the input tensor by the given factor.
"""
factor -= 1
newshape = [int(factor * s) for s in shape[-dims:]]
newshape = [0] * (len(shape) - dims) + newshape
padding = symmetric_padding(shape, newshape)
print(padding)
torch_padding = _torch_pad(padding[-dims:])
crops = tuple([_n_slice(pre, -post) for (pre, post) in padding])
def pad(input: Tensor):
# little hack for 2d padding
if input.ndim == 2:
input = input[None, ...]
out = 0
else:
out = slice(None)
return functional.pad(input, torch_padding, mode=mode)[out]
def crop(input: Tensor):
return input[crops[-input.ndim:]]
return pad, crop, newshape
class FresnelProp(Module):
"""
Fresnel (forward) propagation defined by the given Fresnel numbers nf.
"""
def __init__(self, shape, nf, device=None):
super(FresnelProp, self).__init__()
self.kernel = fresnel_kernel(shape, nf, device=device).to(memory_format=torch.contiguous_format)
self._dim = (-2, -1)
def forward(self, x: Tensor):
return ifft2(self.kernel * fft2(x))
class ImageModel(Module):
def __init__(self, x0, nf, device=None):
super(ImageModel, self).__init__()
self.x = Parameter(x0, requires_grad=True)
self.propagator = FresnelProp(x0.shape, nf, device=device)
def forward(self):
return self.propagator(torch.exp(1j * self.x)).abs().square()
class TikhonovLoss:
def __init__(self, regularization: Tensor, device=None, *args, **kwargs) -> None:
self.regularization = torch.as_tensor(regularization.sqrt(), device=device).to(memory_format=torch.contiguous_format)
self.data_loss = MSELoss(reduction='sum')
def __call__(self, input: Tensor, target: Tensor, recon: Tensor) -> Tensor:
return 0.5 * self.data_loss(input, target) + \
v2norm(fft2(recon, norm="ortho").mul(self.regularization)).square() # FIXME: raises warning about complex to real
def tikhonov(holos, fresnel_nums):
holoshape = N, *imshape = holos.shape
ndim = len(imshape)
alpha = [5e-4, 1e-4]
# alpha = [1e-3, 5e-1] # does not work, sensitive
device = 'cuda'
max_iter = 30
npad = 1 # padding factor: 1 = no padding
use_bb = True # use Barzilai-Borwein adaptive stepsizes
use_ls = True # non-monotone line search to stabilize the stepsizes when the functional values oscillate
ls_opts = {
'last_steps': 2, # > 0, int, TODO: if 0, no linesearch is performed
'ls_max_evals': 10,
'decrease_factor': 4, # >= 2, float
}
tolerance = 1e-2
precision = torch.double
"""
Without any constraint, projected gradient descent reduces to plain gradient descent. It is an instance of
forward-backward splitting in which one function is the indicator function of a convex set, so its proximal
operator is the projection onto that set. In PyTorch, gradient descent is implemented by the SGD (stochastic
gradient descent) optimizer. Extending it to projected gradient descent is straightforward: apply the
projection after the 'inner' forward (gradient) step, which can be wrapped in a subclass of SGD.
"""
lr = 1 / (4 * N + max(alpha))
# initial guess with CTF
ctf = CTF(imshape, fresnel_nums.numpy(), alpha=alpha, device=device)
initial = ctf(holos.numpy())
fresnel_nums = torch.tensor(expand_fresnel_numbers(fresnel_nums, ndim), device=device)
rec = initial.to(dtype=precision, memory_format=torch.contiguous_format, device=device).detach() # variable
holos = holos.to(dtype=precision, memory_format=torch.contiguous_format, device=device).detach() # target
# setup "Tikhonov functional"
fresnel = ImageModel(rec, fresnel_nums, device=device)
loss_fn = TikhonovLoss(regularization_kernel(imshape, fresnel_nums, alpha=alpha), device=device, reduction='sum')
# (stochastic) gradient descent
solver = SGD(fresnel.parameters(), lr=lr)
# saving some stats. Note: these are explicitly intended to stay on the CPU only.
with torch.no_grad():
stats = {
"values": torch.empty(max_iter, dtype=torch.double),
"gradNorm": torch.empty(max_iter, dtype=torch.double),
"p": torch.empty((max_iter, *imshape), dtype=rec.dtype),
"grad_p": torch.empty((max_iter, *imshape), dtype=rec.dtype),
"lr": torch.empty(max_iter, dtype=torch.double),
"p0": initial.clone().detach(),
}
# initialize group params
for group in solver.param_groups:
group["ls_max_evals"] = ls_opts["ls_max_evals"]
group["last_steps"] = ls_opts["last_steps"]
group["decrease_factor"] = ls_opts["decrease_factor"]
for p in group["params"]:
state = solver.state[p]
state["last_value"] = None
state["last_grad"] = None
def iterate():
def closure(backwards=True):
solver.zero_grad(True) # set grads to None (minor performance improvement)
# forward = fresnel propagation
y = fresnel()
loss = loss_fn(y, holos, fresnel.x)
# calculate gradient values
if backwards:
loss.backward()
return loss
# initialize model and populate gradients
loss = closure()
# iteration loop: gradient descent
for k in range(max_iter):
if use_bb and k > 0:
# hacky access to optimizer state and parameter
# NOTE: only for single-parametric groups!
group = solver.param_groups[0]
param = group["params"][0]
state = solver.state[param]
# last used learning rate (lr) aka step size
lr = group["lr"]
# Barzilai-Borwein adaptive step sizes
# in alternating manner, see e.g. Dai et al. (2003): Alternate minimization gradient method
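# With s_k = x_k - x_{k-1} and y_k = grad_k - grad_{k-1}, the two Barzilai-Borwein
# candidates implemented below are
#   BB1: lr = <s_k, s_k> / <s_k, y_k>   (even iterations)
#   BB2: lr = <s_k, y_k> / <y_k, y_k>   (odd iterations)
# and alternating between them is the scheme referenced above.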
with torch.no_grad():
k1 = k - 1
dp = param - state["last_value"]
ds = param.grad - state["last_grad"]
# TODO: implementation for complex parameters
if k % 2:
lr_new = ds.conj().mul(dp).sum() / ds.square().sum()
else:
lr_new = dp.square().sum() / dp.conj().mul(ds).sum()
if not torch.isreal(lr_new):
raise NotImplementedError('lr is complex!')
# ensure non-negativity
if lr_new < 0:
# if negative, keep old value
lr_new = lr
# save for iteration
group["lr"] = lr_new
# stats
lr_k = solver.param_groups[0]["lr"]
stats["values"][k] = loss
stats["gradNorm"][k] = v2norm(fresnel.x.grad)
#stats["p"][k] = fresnel.x.clone()
#stats["grad_p"][k] = fresnel.x.grad.clone()
# update last state
for group in solver.param_groups:
for p in group["params"]:
with torch.no_grad():
state = solver.state[p]
# FIXME merge last_value last_values?
state["last_value"] = p.clone(memory_format=torch.contiguous_format)
state["last_grad"] = p.grad.clone(memory_format=torch.contiguous_format)
if k == 0:
state["refgrad_norm"] = v2norm(p.grad)
# non-monotone linesearch or simple gradient step
# NOTE: only for a single, single-parameter group!
# later: in PGD per parameter with projector
# remark: we can borrow some code from LBFGS, as it also does a line search.
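# The line search below accepts a trial point x(k+1) = x(k) - lr * grad(k) if
#   f(x(k+1)) < max_{j < M} f(x(k-j)) + <x(k+1) - x(k), grad(k)> + ||x(k+1) - x(k)||^2 / (2 * lr),
# i.e. a non-monotone sufficient-decrease condition over the last M functional values;
# otherwise lr is divided by decrease_factor and the step is retried.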
if use_ls:
with torch.no_grad():
group = solver.param_groups[0]
param = group["params"][0]
state = solver.state[param]
lr = group["lr"]
ls_max_evals = group["ls_max_evals"]
decrease_factor = group["decrease_factor"]
M = group["last_steps"]
# Make sure the closure is always called with grad enabled
closure = torch.enable_grad()(closure)
if k == 0:
# lazy initialization
# NOTE: LBFGS has only global state, but we register it as state for
# the first param, because this helps with casting in load_state_dict
# (copied from LBFGS optimizer)
state["last_losses"] = torch.full((M,), -torch.inf, dtype=torch.double, device=loss.device)
state["last_losses"][0] = loss
f_max = loss
else:
# update last functional values
lvs = state["last_losses"]
lvs[k % M] = loss
# state["last_losses"] = lvs
f_max = lvs.max()
n_iter = 0
backtracking_cond = False
lr_new = lr
# param_k = param.clone(memory_format=torch.contiguous_format)
# param_grad_k = param.grad.clone(memory_format=torch.contiguous_format)
# note: last_* has already been updated to the current values in iteration k
param_k = state["last_value"]
param_grad_k = state["last_grad"]
while True:
n_iter += 1
# x(k) |-> x(k+1)
# add projection op here
param.add_(param_grad_k, alpha=-lr_new)
# determine loss at x(k+1) but defer gradient calculations
loss = closure(backwards=False)
# check backtracking condition
dp = param.sub(param_k)
backtracking_cond = \
loss < f_max + dp.mul(param_grad_k).sum() + v2norm(dp).square().mul(1. / (2 * lr_new))
# check conditions
if backtracking_cond:
break
if n_iter >= ls_max_evals:
raise RuntimeWarning(f"Maximum function evaluations in line search exceeded and backtracking"
f" condition not satisfied.")
# reset parameter, lower stepsize and restart
# param.add_(param_grad_k, alpha=lr_new) # TODO can we do this somehow smarter?
param.set_(param_k.clone())
lr_new /= decrease_factor
# populate deferred gradients for iteration k+1
loss.backward()
# debugging for stats
lr_k = lr_new
print(f'LS {n_iter} tries, fmax = {f_max:.3e}')
else:
# step x(k) -> x(k+1)
# (projected) gradient descent (NOTE: currently plain gradient descent only)
# warning: parameters are updated, grads are now invalid!
solver.step() # TODO replace with _step method within PGD, and apply projection
# evaluate the model at x(k+1) and populate gradients
loss = closure()
# add lr to stats since it may change within linesearch
stats["lr"][k] = lr_k
print(f'iteration {k + 1}, loss = {loss:.3e}, grad norm {stats["gradNorm"][k]:.3e}, lr {lr_k:.3e}')
# check convergence
if k > 0:
rel_grad = v2norm(fresnel.x.grad) / solver.state[fresnel.x]["refgrad_norm"]
print(f'rel grad norm {rel_grad:.3e}')
if rel_grad < tolerance:
print(f"Converged in iteration {k}")
stats["n_iter"] = k
break
return fresnel, stats
return iterate
################################
# Proof of Concept
################################
if __name__ == '__main__':
from hotopy.phase import CTF
from matplotlib import pyplot as plt
import numpy as np
import h5py
# load example data (from matlab holotomotoolbox)
holos, fresnel_nums = get_holograms('beads')
holos = matswap(holos).astype(np.float_) # [:, 200:1000, 1000:1800]
holos, fresnel_nums = torch.from_numpy(holos, ), torch.from_numpy(fresnel_nums)
npad = 1
# apply padding first
pad, crop, padshape = sympad(holos.shape, npad)
holos = pad(holos)
holoshape = N, *imshape = holos.shape
tikhonov_optimization = tikhonov(holos, fresnel_nums)
t_start = time.time()
with torch.profiler.profile(
activities=[torch.profiler.profiler.ProfilerActivity.CUDA,
torch.profiler.profiler.ProfilerActivity.CPU],
with_stack=True,
) as profiler:
fresnel, stats = tikhonov_optimization()
t_end = time.time()
print(f'Computation time: {t_end - t_start:.2e}')
print(profiler.key_averages().table(
sort_by="self_cuda_time_total", row_limit=-1))
# save results from the profiler
profiler.export_stacks('tikhonov-torch-profilter.stk')
profiler.export_chrome_trace('tikhonov-torch-trace.json')
# prepare for plotting, remove padding
recon = crop(fresnel.x).detach().cpu()
recon_np = recon.numpy()
inital_np = crop(stats["p0"]).detach().cpu().numpy()
n = recon.shape[0]
k = stats["n_iter"]
# save for later
hf = h5py.File('pyreconBeads.h5', 'w')
hf.create_dataset('tikhonov', data=recon_np)
hf.create_dataset('ctf', data=inital_np)
for var, s in stats.items():
if isinstance(s, torch.Tensor):
hf.create_dataset(f'stats/{var}', data=s[:k].cpu().detach().numpy())
hf.close()
# plotting
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(3*5., 5.))
fig.suptitle('Compare reconstructions')
ax1.set_title(r'Tikhonov')
im1 = ax1.imshow(recon_np)
fig.colorbar(im1, ax=ax1)
ax2.set_title('CTF (initial)')
im2 = ax2.imshow(inital_np)
fig.colorbar(im2, ax=ax2)
ax3.set_title('Line profile')
ax3.plot(inital_np[n // 2], '--', c='b', label='CTF')
ax3.plot(recon_np[n // 2], c='r', label='Tikhonov')
ax3.legend()
plt.show()
\ No newline at end of file
import torch
import timeit
import time
from torch.fft import fftn, ifftn
from torch.cuda import empty_cache
from torch.profiler import profiler
device = 'cuda'
shape = (1920, 512)
N = 10
n_evals = 200
x = torch.rand((N, *shape), device=device, dtype=torch.double)
m = torch.rand_like(x) * 0.5 * (1 + 1j)
print(f'm-dtype {m.dtype}')
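# Benchmark sketch: time the filtered FFT round trip ifftn(m * fftn(x)) -- the core
# operation of the Fresnel propagator -- in three variants: inside torch.no_grad(),
# inside torch.enable_grad(), and as a torch.jit.script-compiled function.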
def torch_fft_nograd():
#x_ = x
#x_.requires_grad_(False)
with torch.no_grad():
ifftn(m * fftn(x))
def torch_fft_grad():
#x_ = x
#x_.requires_grad_(True)
with torch.enable_grad():
ifftn(m * fftn(x))
@torch.jit.script
def _torch_fft_jit(xx, mm):
ifftn(mm * fftn(xx))
def torch_fft_jit():
_torch_fft_jit(x, m)
empty_cache()
t_nograd = timeit.timeit(stmt=torch_fft_nograd, number=n_evals)
time.sleep(.5)
empty_cache()
time.sleep(.1)
t_grad = timeit.timeit(stmt=torch_fft_grad, number=n_evals)
time.sleep(.5)
empty_cache()
time.sleep(.1)
t_jit = timeit.timeit(stmt=torch_fft_jit, number=n_evals)
print(f"Time FFT2: no-grad {t_nograd:.3e} grad: {t_grad:.3e} grad: {t_jit:.3e}")
print(f"AvgTime FFT2: no-grad {t_nograd/n_evals:.3e} grad: {t_grad/n_evals:.3e} {t_jit/n_evals:.3e}")
\ No newline at end of file