Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,14 +505,14 @@ def render_points(

return sdata

@_deprecation_alias(elements="element", quantiles_for_norm="percentiles_for_norm", version="version 0.3.0")
@_deprecation_alias(elements="element", version="version 0.3.0")
def render_images(
self,
element: str | None = None,
*,
channel: list[str] | list[int] | str | int | None = None,
cmap: list[Colormap | str] | Colormap | str | None = None,
norm: Normalize | None = None,
norm: list[Normalize] | Normalize | None = None,
na_color: ColorLike | None = "default",
palette: list[str] | str | None = None,
alpha: float | int = 1.0,
Expand Down Expand Up @@ -541,9 +541,10 @@ def render_images(
cmap : list[Colormap | str] | Colormap | str | None
Colormap or list of colormaps for continuous annotations, see :class:`matplotlib.colors.Colormap`.
Each colormap applies to a corresponding channel.
norm : Normalize | None, optional
norm : list[Normalize] | Normalize | None, optional
Colormap normalization for continuous annotations, see :class:`matplotlib.colors.Normalize`.
Applies to all channels if set.
Can be a single :class:`~matplotlib.colors.Normalize` (applied to all channels) or a list
of :class:`~matplotlib.colors.Normalize` objects (one per channel) for per-channel control.
na_color : ColorLike | None, default "default" (gets set to "lightgray")
Color to be used for NAs values, if present. Can either be a named color ("red"), a hex representation
("#000000ff") or a list of floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). When None, the values
Expand Down Expand Up @@ -596,12 +597,16 @@ def render_images(
n_steps = len(sdata.plotting_tree.keys())

for element, param_values in params_dict.items():
# Per-channel norms are stored in ImageRenderParams.norms. In that case
# _prepare_cmap_norm gets None and creates a default norm used as fallback
# for the single-channel rendering path.
scalar_norm = None if isinstance(norm, list) else norm
cmap_params: list[CmapParams] | CmapParams
if isinstance(cmap, list):
cmap_params = [
_prepare_cmap_norm(
cmap=c,
norm=norm,
norm=scalar_norm,
na_color=param_values["na_color"],
)
for c in cmap
Expand All @@ -610,7 +615,7 @@ def render_images(
else:
cmap_params = _prepare_cmap_norm(
cmap=cmap,
norm=norm,
norm=scalar_norm,
na_color=param_values["na_color"],
**kwargs,
)
Expand All @@ -619,6 +624,7 @@ def render_images(
channel=param_values["channel"],
cmap_params=cmap_params,
palette=param_values["palette"],
norms=norm if isinstance(norm, list) else None,
alpha=param_values["alpha"],
scale=param_values["scale"],
zorder=n_steps,
Expand Down
158 changes: 72 additions & 86 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import spatialdata as sd
from anndata import AnnData
from matplotlib.cm import ScalarMappable
from matplotlib.colors import ListedColormap, Normalize
from matplotlib.colors import Colormap, ListedColormap, Normalize
from scanpy._settings import settings as sc_settings
from spatialdata import get_extent, get_values, join_spatialelement_table
from spatialdata._core.query.relational_query import match_table_to_element
Expand Down Expand Up @@ -275,7 +275,8 @@ def _render_shapes(
# When groups are specified, filter out non-matching elements by default.
# Only show non-matching elements if the user explicitly sets na_color.
_na = render_params.cmap_params.na_color
if groups is not None and values_are_categorical and (_na.default_color_set or _na.alpha == "00"):
na_is_transparent = _na.default_color_set or _na.get_alpha_as_float() == 0.0
if groups is not None and values_are_categorical and na_is_transparent:
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
groups, color_source_vector, color_vector
)
Expand Down Expand Up @@ -887,7 +888,8 @@ def _render_points(
# When groups are specified, filter out non-matching elements by default.
# Only show non-matching elements if the user explicitly sets na_color.
_na = render_params.cmap_params.na_color
if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"):
na_is_transparent = _na.default_color_set or _na.get_alpha_as_float() == 0.0
if groups is not None and color_source_vector is not None and na_is_transparent:
keep, color_source_vector, color_vector = _filter_groups_transparent_na(
groups, color_source_vector, color_vector
)
Expand Down Expand Up @@ -1175,6 +1177,42 @@ def _render_points(
)


def _additive_blend(
layers: dict[str | int, np.ndarray],
channels: list[Any],
channel_cmaps: list[Colormap],
) -> np.ndarray:
"""Composite colormapped channels via signal-based additive blending.

For each channel the "signal" is the deviation of the mapped color from
that colormap's zero-value color (cmap(0)). Signals are summed and placed
on a canvas whose color is the mean zero-value across all cmaps.

This handles two common compositing strategies:

* Black-to-color LUTs (Napari/ImageJ additive): cmap(0) is black, so
the signal equals cmap(val) and the canvas is black. This is identical
to a naive RGB sum.
* White-to-color LUTs (ImageJ Composite Invert): cmap(0) is white, so
the signal is cmap(val) minus white and the canvas is white. This is
equivalent to inverting, adding, then inverting again.

For arbitrary colormaps (e.g. diverging) the same formula produces a
reasonable result by extracting each cmap's contribution relative to its
own background.
"""
if not layers:
raise ValueError("Cannot blend an empty set of layers.")
height, width = next(iter(layers.values())).shape
zero_colors = np.array([cm(0.0)[:3] for cm in channel_cmaps])
canvas = np.mean(zero_colors, axis=0)
composite = np.full((height, width, 3), canvas, dtype=float)
for idx, (ch, cmap) in enumerate(zip(channels, channel_cmaps, strict=True)):
rgba = cmap(np.asarray(layers[ch]))
composite += rgba[..., :3] - zero_colors[idx]
return np.clip(composite, 0, 1, out=composite)


def _render_images(
sdata: sd.SpatialData,
render_params: ImageRenderParams,
Expand Down Expand Up @@ -1225,22 +1263,19 @@ def _render_images(

# True if user gave n cmaps for n channels
got_multiple_cmaps = isinstance(render_params.cmap_params, list)
if got_multiple_cmaps:
logger.warning(
"You're blending multiple cmaps. "
"If the plot doesn't look like you expect, it might be because your "
"cmaps go from a given color to 'white', and not to 'transparent'. "
"Therefore, the 'white' of higher layers will overlay the lower layers. "
"Consider using 'palette' instead."
)

# not using got_multiple_cmaps here because of ruff :(
# ruff needs the isinstance check here for type narrowing
if isinstance(render_params.cmap_params, list) and len(render_params.cmap_params) != n_channels:
raise ValueError("If 'cmap' is provided, its length must match the number of channels.")

if render_params.norms is not None and len(render_params.norms) != n_channels:
raise ValueError(
f"Length of 'norm' list ({len(render_params.norms)}) must match the number of channels ({n_channels})."
)

_, trans_data = _prepare_transformation(img, coordinate_system, ax)

# 1) Image has only 1 channel
# Single channel
if n_channels == 1 and not isinstance(render_params.cmap_params, list):
layer = img.sel(c=channels[0]).squeeze() if isinstance(channels[0], str) else img.isel(c=channels[0]).squeeze()

Expand All @@ -1255,13 +1290,14 @@ def _render_images(
cmap._lut[:, -1] = render_params.alpha

# norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip.
single_norm = render_params.norms[0] if render_params.norms else render_params.cmap_params.norm
_ax_show_and_transform(
layer,
trans_data,
ax,
cmap=cmap,
zorder=render_params.zorder,
norm=render_params.cmap_params.norm,
norm=single_norm,
)

wants_colorbar = _should_request_colorbar(
Expand All @@ -1271,7 +1307,7 @@ def _render_images(
auto_condition=n_channels == 1,
)
if wants_colorbar and legend_params.colorbar and colorbar_requests is not None:
sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm)
sm = plt.cm.ScalarMappable(cmap=cmap, norm=single_norm)
colorbar_requests.append(
ColorbarSpec(
ax=ax,
Expand All @@ -1285,40 +1321,31 @@ def _render_images(
)
)

# 2) Image has any number of channels but 1
# Multiple channels
else:
layers = {}
for ch_idx, ch in enumerate(channels):
layers[ch] = img.sel(c=ch).copy(deep=True).squeeze()
if isinstance(render_params.cmap_params, list):
if render_params.norms is not None:
ch_norm = render_params.norms[ch_idx]
elif isinstance(render_params.cmap_params, list):
ch_norm = render_params.cmap_params[ch_idx].norm
else:
ch_norm = render_params.cmap_params.norm

if ch_norm is not None:
layers[ch] = ch_norm(layers[ch])

# 2A) Image has 3 channels, no palette info, and no/only one cmap was given
# Image has 3 channels, no palette, and at most one cmap
if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list):
if render_params.cmap_params.cmap_is_default: # -> use RGB
if render_params.cmap_params.cmap_is_default: # treat as RGB
stacked = np.stack([layers[ch] for ch in layers], axis=-1)
else: # -> use given cmap for each channel
else: # apply the given cmap to each channel
channel_cmaps = [render_params.cmap_params.cmap] * n_channels
stacked = (
np.stack(
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
0,
).sum(0)
/ n_channels
)
stacked = stacked[:, :, :3]
stacked = _additive_blend(layers, channels, channel_cmaps)
logger.warning(
"One cmap was given for multiple channels and is now used for each channel. "
"You're blending multiple cmaps. "
"If the plot doesn't look like you expect, it might be because your "
"cmaps go from a given color to 'white', and not to 'transparent'. "
"Therefore, the 'white' of higher layers will overlay the lower layers. "
"Consider using 'palette' instead."
"Consider using 'palette' for black-to-color compositing instead."
)

_ax_show_and_transform(
Expand All @@ -1329,25 +1356,11 @@ def _render_images(
zorder=render_params.zorder,
)

# 2B) Image has n channels, no palette/cmap info -> sample n categorical colors
# n channels, no palette/cmap: sample n categorical colors
elif palette is None and not got_multiple_cmaps:
# overwrite if n_channels == 2 for intuitive result
# For 2 channels default to red/green for an intuitive result
if n_channels == 2:
seed_colors = ["#ff0000ff", "#00ff00ff"]
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
colored = np.stack(
[channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)],
0,
).sum(0)
colored = colored[:, :, :3]
elif n_channels == 3:
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
colored = np.stack(
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
0,
).sum(0)
colored = colored[:, :, :3]
else:
if isinstance(render_params.cmap_params, list):
cmap_is_default = render_params.cmap_params[0].cmap_is_default
Expand All @@ -1357,31 +1370,16 @@ def _render_images(
if cmap_is_default:
seed_colors = _get_colors_for_categorical_obs(list(range(n_channels)))
else:
# Sample n_channels colors evenly from the colormap
# Sample n_channels evenly spaced colors from the colormap
if isinstance(render_params.cmap_params, list):
seed_colors = [
render_params.cmap_params[i].cmap(i / (n_channels - 1)) for i in range(n_channels)
]
else:
seed_colors = [render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels)]
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]

# Stack (n_channels, height, width) → (height*width, n_channels)
H, W = next(iter(layers.values())).shape
comp_rgb = np.zeros((H, W, 3), dtype=float)

# For each channel: map to RGBA, apply constant alpha, then add
for ch_idx, ch in enumerate(channels):
layer_arr = layers[ch]
rgba = channel_cmaps[ch_idx](layer_arr)
rgba[..., 3] = render_params.alpha
comp_rgb += rgba[..., :3] * rgba[..., 3][..., None]

colored = np.clip(comp_rgb, 0, 1)
logger.info(
f"Your image has {n_channels} channels. Sampling categorical colors and using "
f"multichannel strategy 'stack' to render."
) # TODO: update when pca is added as strategy
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors]
colored = _additive_blend(layers, channels, channel_cmaps)

_ax_show_and_transform(
colored,
Expand All @@ -1391,14 +1389,13 @@ def _render_images(
zorder=render_params.zorder,
)

# 2C) Image has n channels and palette info
# n channels with palette: build black-to-color LUT per channel
elif palette is not None and not got_multiple_cmaps:
if len(palette) != n_channels:
raise ValueError("If 'palette' is provided, its length must match the number of channels.")

channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette if isinstance(c, str)]
colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0)
colored = colored[:, :, :3]
channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette]
colored = _additive_blend(layers, channels, channel_cmaps)

_ax_show_and_transform(
colored,
Expand All @@ -1410,14 +1407,7 @@ def _render_images(

elif palette is None and got_multiple_cmaps:
channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr]
colored = (
np.stack(
[channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)],
0,
).sum(0)
/ n_channels
)
colored = colored[:, :, :3]
colored = _additive_blend(layers, channels, channel_cmaps)

_ax_show_and_transform(
colored,
Expand All @@ -1427,7 +1417,7 @@ def _render_images(
zorder=render_params.zorder,
)

# 2D) Image has n channels, no palette but cmap info
# n channels with both palette and cmap is not allowed
elif palette is not None and got_multiple_cmaps:
raise ValueError("If 'palette' is provided, 'cmap' must be None.")

Expand Down Expand Up @@ -1534,12 +1524,8 @@ def _render_labels(
# When groups are specified, zero out non-matching label IDs so they render as background.
# Only show non-matching labels if the user explicitly sets na_color.
_na = render_params.cmap_params.na_color
if (
groups is not None
and categorical
and color_source_vector is not None
and (_na.default_color_set or _na.alpha == "00")
):
na_is_transparent = _na.default_color_set or _na.get_alpha_as_float() == 0.0
if groups is not None and categorical and color_source_vector is not None and na_is_transparent:
keep_vec = color_source_vector.isin(groups)
matching_ids = instance_id[keep_vec]
keep_mask = np.isin(label.values, matching_ids)
Expand Down
2 changes: 1 addition & 1 deletion src/spatialdata_plot/pl/render_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ class ImageRenderParams:
element: str
channel: list[str] | list[int] | int | str | None = None
palette: ListedColormap | list[str] | None = None
norms: list[Normalize] | None = None # per-channel norms, separate from the scalar norm in CmapParams
alpha: float = 1.0
percentiles_for_norm: tuple[float | None, float | None] = (None, None)
scale: str | None = None
zorder: int = 0
colorbar: bool | str | None = "auto"
Expand Down
9 changes: 8 additions & 1 deletion src/spatialdata_plot/pl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2394,7 +2394,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st

norm = param_dict.get("norm")
if norm is not None:
if element_type in {"images", "labels"} and not isinstance(norm, Normalize):
if element_type == "images" and isinstance(norm, Sequence) and not isinstance(norm, list):
raise TypeError("'norm' must be a list of Normalize objects, not a tuple or other sequence.")
if element_type == "images" and isinstance(norm, list):
if len(norm) == 0:
raise ValueError("'norm' list must not be empty.")
if not all(isinstance(n, Normalize) for n in norm):
raise TypeError("All elements of 'norm' list must be of type Normalize.")
elif element_type in {"images", "labels"} and not isinstance(norm, Normalize):
raise TypeError("Parameter 'norm' must be of type Normalize.")
if element_type in {"shapes", "points"} and not isinstance(norm, bool | Normalize):
raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.")
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_pass_cmap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_pass_cmap_list.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_pass_cmap_to_each_channel.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_pass_str_cmap.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/_images/Images_can_pass_str_cmap_list.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading