Commit 8537d15f authored by Matthijs's avatar Matthijs
Browse files

compatibility with complex 3d data

parent ce126a43
import matplotlib.pyplot as plt
from matplotlib import get_backend
from numpy import max, min
from numpy import max, min, iscomplexobj
from matplotlib.widgets import Slider
from warnings import warn
from .complex_field_visualization import complex_to_rgb
good_backends = ['Qt5Agg', 'TkAgg', 'Qt4Agg']
......@@ -50,18 +51,34 @@ class SingleStackViewer:
class XYZStackViewer:
def __init__(self, volume, limit_sliders=(0, 0, 0), cmap: str = 'seismic', clim: tuple = (None, None)):
def __init__(self, volume, limit_sliders=(0, 0, 0), cmap: str = 'seismic', clim: tuple = (None, None),
data_transform: callable = None):
"""
Dynamic matplotlib plot showing slices out of a 3d dataset
:param volume: real-valued 3d numpy array
:param volume: 3d numpy array
:param limit_sliders: sliders allow range given by [limit, N-limit], where N is the maximal length
:param cmap: color map to use (e.g. 'seismic' or 'viridis'). For 'seismic, will center on 0, else'
:param cmap: color map to use (e.g. 'seismic' or 'viridis'). For 'seismic', will center on 0
:param clim: tuple with the min-max values of the colorscale
:param data_transform: simple transformation function to apply to data before plot.
defaults to abs() for complex data
"""
if get_backend() not in good_backends:
warn(Warning('Current matplotlib backend may not allow for optimal funcionality! Use, e.g., Qt'))
if cmap is not None:
show_colorbar = True
if data_transform is not None:
self.cast_fn = data_transform
elif iscomplexobj(volume):
self.cast_fn = complex_to_rgb
show_colorbar = False
else:
def dummy(array):
return array
self.cast_fn = dummy
self.volume = volume
self.indices = [s // 2 for s in volume.shape]
self.shape = self.volume.shape
......@@ -75,13 +92,17 @@ class XYZStackViewer:
self.minval = min([0, min(volume)])
self.fig, self.ax = plt.subplots(1, 3, figsize=(10, 4))
im = self.ax[0].imshow(volume[self.indices[0], :, :], cmap=cmap, vmin=self.minval, vmax=self.maxval)
im = self.ax[1].imshow(volume[:, self.indices[0], :], cmap=cmap, vmin=self.minval, vmax=self.maxval)
im = self.ax[2].imshow(volume[:, :, self.indices[0]], cmap=cmap, vmin=self.minval, vmax=self.maxval)
im = self.ax[0].imshow(self.cast_fn(volume[self.indices[0], :, :]), cmap=cmap, vmin=self.minval,
vmax=self.maxval)
im = self.ax[1].imshow(self.cast_fn(volume[:, self.indices[0], :]), cmap=cmap, vmin=self.minval,
vmax=self.maxval)
im = self.ax[2].imshow(self.cast_fn(volume[:, :, self.indices[0]]), cmap=cmap, vmin=self.minval,
vmax=self.maxval)
self.ax[1].set_yticks([])
self.ax[2].set_yticks([])
self.fig.subplots_adjust(bottom=0.1, right=0.85)
if show_colorbar:
cbar_ax = self.fig.add_axes([0.87, 0.2, 0.01, 0.6])
plt.colorbar(im, cax=cbar_ax)
......@@ -108,19 +129,19 @@ class XYZStackViewer:
def update_1(self, n):
"""Update subfigure 1 on change of the slider """
self.indices[0] = int(n)
self.ax[0].images[0].set_array(self.volume[int(n), :, :])
self.ax[0].images[0].set_array(self.cast_fn(self.volume[int(n), :, :]))
self.fig.canvas.draw_idle()
def update_2(self, n):
"""Update subfigure 2 on change of the slider """
self.indices[1] = int(n)
self.ax[1].images[0].set_array(self.volume[:, int(n), :])
self.ax[1].images[0].set_array(self.cast_fn(self.volume[:, int(n), :]))
self.fig.canvas.draw_idle()
def update_3(self, n):
"""Update subfigure 3 on change of the slider """
self.indices[2] = int(n)
self.ax[2].images[0].set_array(self.volume[:, :, int(n)])
self.ax[2].images[0].set_array(self.cast_fn(self.volume[:, :, int(n)]))
self.fig.canvas.draw_idle()
# def process_key(self, event):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment