Commit d62b7fc5 authored by Simon Maretzke's avatar Simon Maretzke
Browse files

Extended and improved GPU-capabilities

parent 5ef73908
......@@ -50,7 +50,7 @@ function result = phaserec_ctf(holograms, fresnelNumbers, settings)
% in result are <= maxPhase. If maxPhase == inf, no maximum-constraint is imposed.
% useGPUIfPossible : Default = true
% Run the iterative optimization on GPU if GPU-computations are supported (only relevant
% if support and/or min/max-constraints are imposed.
% if support and/or min/max-constraints are imposed)
%
% Returns
% -------
......@@ -181,10 +181,27 @@ end
% Determine whether iterative optimization has to be performed: this is the case if min-, max-
% or support-constraints are to be imposed
doIterativeOptimization = (numel(settings.support) > 0 || settings.minPhase > -inf || settings.maxPhase < inf);
% Determine whether the GPU is to be used in the reconstruction: this is the case if either the
% holograms have already been assigned as gpuArray or optionally in the case of iterative optimization
useGPU = isOnGPU(holograms) || (settings.useGPUIfPossible && doIterativeOptimization && checkGPUSupport());
if settings.useGPUIfPossible && doIterativeOptimization && ~useGPU
warning(['Unable to initialize GPU. Running iterative optimization on CPU instead. ', ...
'Set settings.useGPUIfPossible = false to suppress this warning.']);
end
% Optional padding of holograms
holograms = padarray(holograms, [settings.pady settings.padx], 'replicate');
N = size(holograms(:,:,1));
% Assemble CTF-inversion formula:
%
% result = argmin_f { sum_j ||2*ctf{j}.* fft2(f) - fft2(holograms{j}-1)||_2^2
......@@ -198,25 +215,28 @@ N = size(holograms(:,:,1));
sumCTFHolograms = 0;
sumCTFSq = 0;
for holo_idx = 1:numHolos
xiNormSqFresnel = xi_y.^2/(4*pi*fresnelNumbers(1,holo_idx)) + xi_x.^2/(4*pi*fresnelNumbers(2,holo_idx));
xiNormSqFresnel = gpuArrayIf(xi_y, useGPU).^2/(4*pi*fresnelNumbers(1,holo_idx)) ...
+ gpuArrayIf(xi_x, useGPU).^2/(4*pi*fresnelNumbers(2,holo_idx));
ctf = sin(xiNormSqFresnel);
if settings.betaDeltaRatio > 0
ctf = ctf + settings.betaDeltaRatio * cos(xiNormSqFresnel);
end
sumCTFHolograms = sumCTFHolograms + ctf .* fft2(holograms(:,:,holo_idx));
sumCTFHolograms = sumCTFHolograms + ctf .* fft2(gpuArrayIf(holograms(:,:,holo_idx), useGPU));
sumCTFSq = sumCTFSq + ctf.^2;
end
sumCTFSq = 2*sumCTFSq;
clear 'ctf' 'xiNormSqFresnel';
% Correction for zero-frequency of Fourier transformed data:
% equivalent to subtracting a constant 1 from all input holograms
sumCTFHolograms(1,1) = sumCTFHolograms(1,1) - prod(N) * numHolos * settings.betaDeltaRatio;
% Assemble regularization weights in Fourier-space that smoothly transitions from the value
% Add regularization weights in Fourier-space that smoothly transition from the value
% lim1 in the low-frequency regime (around the central CTF-minimum) to lim2 at larger
% larger spatial frequencies beyond the first CTF-maximum
regWeights = ctfRegWeights(N, mean(fresnelNumbers,2), settings.lim1, settings.lim2);
sumCTFSqReg = sumCTFSq + ctfRegWeights(N, mean(fresnelNumbers,2), settings.lim1, settings.lim2, useGPU);
clear 'sumCTFSq';
......@@ -225,22 +245,13 @@ regWeights = ctfRegWeights(N, mean(fresnelNumbers,2), settings.lim1, settings.li
% ----------------------------------------------------------------------------------------------- %
% ----- Special case: Minimization with constraints. Required iterative solver (fast ADMM) ------ %
% ----------------------------------------------------------------------------------------------- %
if numel(settings.support) || settings.minPhase > -inf || settings.maxPhase < inf
% Run iterative algorithm on GPU if desired and possible
useGPU = settings.useGPUIfPossible && checkGPUSupport();
if settings.useGPUIfPossible && ~useGPU
warning(['Unable to initialize GPU. Running iterative optimization on CPU instead. ', ...
'Set settings.useGPUIfPossible = false to suppress this warning.']);
end
if doIterativeOptimization
% Assemble proximal map of the regularized CTF-functional
%
% F(f) := ||2*ctf{j}.*fft2(f) - fft2(holograms{j}-1)||_2^2
% + 2*||sqrt(regWeights) .* fft2(f)||_2^2
%
sumCTFSqReg = gpuArrayIf(sumCTFSq + regWeights, useGPU);
sumCTFHolograms = gpuArrayIf(sumCTFHolograms, useGPU);
proxCTF = @(f, sigma) real(ifft2( (sumCTFHolograms + fft2((1./sigma)*f)) ...
./ (sumCTFSqReg + 1./sigma) ));
......@@ -277,13 +288,16 @@ if numel(settings.support) || settings.minPhase > -inf || settings.maxPhase < i
% ---------- Default case: Minimization without constraints. Can be performed directly ---------- %
% ----------------------------------------------------------------------------------------------- %
else
result = real(ifft2(sumCTFHolograms ./ (sumCTFSq + regWeights)));
result = real(ifft2(sumCTFHolograms ./ sumCTFSqReg));
end
% Undo optional padding
result = gather(croparray(result, [settings.pady, settings.padx]));
result = croparray(result, [settings.pady, settings.padx]);
% Ensure that result is copied back from GPU if holograms were not already assigned as gpuArray
result = gatherIf(result, ~isOnGPU(holograms));
end
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment