Skip to content

First poc #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Oct 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
__pycache__
*.pyc
*.pyo
*.egg-info
dist/
build/
33 changes: 32 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1 +1,32 @@
# dash-3d-viewer
# dash-3d-viewer

A tool to make it easy to build slice-views on 3D image data, in Dash apps.

The API is currently a WIP.


## Installation

Eventually, this would be pip-installable. For now, use the developer workflow.


## Usage

TODO, see the examples.


## License

This code is distributed under MIT license.


## Developers


* Make sure that you have Python with the appropriate dependencies installed, e.g. via `venv`.
* Run `pip install -e .` to do an in-place install of the package.
* Run the examples using e.g. `python examples/slicer_with_1_view.py`

* Use `black .` to autoformat.
* Use `flake8 .` to lint.
* Use `pytest .` to run the tests.
10 changes: 10 additions & 0 deletions dash_3d_viewer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
Dash 3d viewer - a tool to make it easy to build slice-views on 3D image data.
"""


from .slicer import DashVolumeSlicer # noqa: F401


__version__ = "0.0.1"
version_info = tuple(map(int, __version__.split(".")))
192 changes: 192 additions & 0 deletions dash_3d_viewer/slicer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import numpy as np
from plotly.graph_objects import Figure
from dash import Dash
from dash.dependencies import Input, Output, State
from dash_core_components import Graph, Slider, Store

from .utils import gen_random_id, img_array_to_uri


class DashVolumeSlicer:
"""A slicer to show 3D image data in Dash."""

def __init__(self, app, volume, axis=0, id=None):
if not isinstance(app, Dash):
raise TypeError("Expect first arg to be a Dash app.")
# Check and store volume
if not (isinstance(volume, np.ndarray) and volume.ndim == 3):
raise TypeError("Expected volume to be a 3D numpy array")
self._volume = volume
# Check and store axis
if not (isinstance(axis, int) and 0 <= axis <= 2):
raise ValueError("The given axis must be 0, 1, or 2.")
self._axis = int(axis)
# Check and store id
if id is None:
id = gen_random_id()
elif not isinstance(id, str):
raise TypeError("Id must be a string")
self._id = id

# Get the slice size (width, height), and max index
arr_shape = list(volume.shape)
arr_shape.pop(self._axis)
slice_size = list(reversed(arr_shape))
self._max_index = self._volume.shape[self._axis] - 1

# Create the figure object
fig = Figure()
fig.update_layout(
template=None,
margin=dict(l=0, r=0, b=0, t=0, pad=4),
)
fig.update_xaxes(
showgrid=False,
range=(0, slice_size[0]),
showticklabels=False,
zeroline=False,
)
fig.update_yaxes(
showgrid=False,
scaleanchor="x",
range=(slice_size[1], 0), # todo: allow flipping x or y
showticklabels=False,
zeroline=False,
)
# Add an empty layout image that we can populate from JS.
fig.add_layout_image(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know we sent you code where the image was added as a layout image but the more "modern" way of doing this is to use an Image trace like in https://github.com/plotly/dash-sample-apps/blob/master/apps/dash-covid-xray/app.py#L68 and https://github.com/plotly/dash-sample-apps/blob/master/apps/dash-covid-xray/app.py#L418. This way you can get hover and click events on image pixels, which is not the case with a layout image

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw that the two examples used different methods, and I assumed that layout images where easier, but will change this :) I think I'll do that in a new PR. Are there any other differences? E.g. I saw that with layout images one can create a stack of them and use alpha blending to overlay e.g. segmentation results. Is that possible with image traces as well?

dict(
source="",
xref="x",
yref="y",
x=0,
y=0,
sizex=slice_size[0],
sizey=slice_size[1],
sizing="contain",
layer="below",
)
)
# Wrap the figure in a graph
# todo: or should the user provide this?
self.graph = Graph(
id=self._subid("graph"),
figure=fig,
config={"scrollZoom": True},
)
# Create a slider object that the user can put in the layout (or not)
self.slider = Slider(
id=self._subid("slider"),
min=0,
max=self._max_index,
step=1,
value=self._max_index // 2,
updatemode="drag",
)
# Create the stores that we need (these must be present in the layout)
self.stores = [
Store(id=self._subid("slice-index"), data=volume.shape[self._axis] // 2),
Store(id=self._subid("_requested-slice-index"), data=0),
Store(id=self._subid("_slice-data"), data=""),
]

self._create_server_callbacks(app)
self._create_client_callbacks(app)

def _subid(self, subid):
"""Given a subid, get the full id including the slicer's prefix."""
return self._id + "-" + subid

def _slice(self, index):
"""Sample a slice from the volume."""
indices = [slice(None), slice(None), slice(None)]
indices[self._axis] = index
return self._volume[tuple(indices)]

def _create_server_callbacks(self, app):
"""Create the callbacks that run server-side."""

@app.callback(
Output(self._subid("_slice-data"), "data"),
[Input(self._subid("_requested-slice-index"), "data")],
)
def upload_requested_slice(slice_index):
slice = self._slice(slice_index)
slice = (slice.astype(np.float32) * (255 / slice.max())).astype(np.uint8)
return [slice_index, img_array_to_uri(slice)]

def _create_client_callbacks(self, app):
"""Create the callbacks that run client-side."""

app.clientside_callback(
"""
function handle_slider_move(index) {
return index;
}
""",
Output(self._subid("slice-index"), "data"),
[Input(self._subid("slider"), "value")],
)

app.clientside_callback(
"""
function handle_slice_index(index) {
if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; }
let slice_cache = window.slicecache_for_{{ID}};
if (slice_cache[index]) {
return window.dash_clientside.no_update;
} else {
console.log('requesting slice ' + index)
return index;
}
}
""".replace(
"{{ID}}", self._id
),
Output(self._subid("_requested-slice-index"), "data"),
[Input(self._subid("slice-index"), "data")],
)

# app.clientside_callback("""
# function update_slider_pos(index) {
# return index;
# }
# """,
# [Output("slice-index", "data")],
# [State("slider", "value")],
# )

app.clientside_callback(
"""
function handle_incoming_slice(index, index_and_data, ori_figure) {
let new_index = index_and_data[0];
let new_data = index_and_data[1];
// Store data in cache
if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; }
let slice_cache = window.slicecache_for_{{ID}};
slice_cache[new_index] = new_data;
// Get the data we need *now*
let data = slice_cache[index];
// Maybe we do not need an update
if (!data) {
return window.dash_clientside.no_update;
}
if (data == ori_figure.layout.images[0].source) {
return window.dash_clientside.no_update;
}
// Otherwise, perform update
console.log("updating figure");
let figure = {...ori_figure};
figure.layout.images[0].source = data;
return figure;
}
""".replace(
"{{ID}}", self._id
),
Output(self._subid("graph"), "figure"),
[
Input(self._subid("slice-index"), "data"),
Input(self._subid("_slice-data"), "data"),
],
[State(self._subid("graph"), "figure")],
)
19 changes: 19 additions & 0 deletions dash_3d_viewer/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import random

import PIL.Image
import skimage
from plotly.utils import ImageUriValidator


def gen_random_id(n=6):
return "".join(random.choice("abcdefghijklmnopqrtsuvwxyz") for i in range(n))


def img_array_to_uri(img_array):
img_array = skimage.util.img_as_ubyte(img_array)
# todo: leverage this Plotly util once it becomes part of the public API (also drops the Pillow dependency)
# from plotly.express._imshow import _array_to_b64str
# return _array_to_b64str(img_array)
img_pil = PIL.Image.fromarray(img_array)
uri = ImageUriValidator.pil_image_to_uri(img_pil)
return uri
20 changes: 20 additions & 0 deletions examples/slicer_with_1_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
A truly minimal example.
"""

import dash
import dash_html_components as html
from dash_3d_viewer import DashVolumeSlicer
import imageio


app = dash.Dash(__name__)

vol = imageio.volread("imageio:stent.npz")
slicer = DashVolumeSlicer(app, vol)

app.layout = html.Div([slicer.graph, slicer.slider, *slicer.stores])


if __name__ == "__main__":
app.run_server(debug=False)
46 changes: 46 additions & 0 deletions examples/slicer_with_2_views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
An example with two slicers on the same volume.
"""

import dash
import dash_html_components as html
from dash_3d_viewer import DashVolumeSlicer
import imageio


app = dash.Dash(__name__)

vol = imageio.volread("imageio:stent.npz")
slicer1 = DashVolumeSlicer(app, vol, axis=1, id="slicer1")
slicer2 = DashVolumeSlicer(app, vol, axis=2, id="slicer2")

app.layout = html.Div(
style={
"display": "grid",
"grid-template-columns": "40% 40%",
},
children=[
html.Div(
[
html.H1("Coronal"),
slicer1.graph,
html.Br(),
slicer1.slider,
*slicer1.stores,
]
),
html.Div(
[
html.H1("Sagittal"),
slicer2.graph,
html.Br(),
slicer2.slider,
*slicer2.stores,
]
),
],
)


if __name__ == "__main__":
app.run_server(debug=True)
Loading