diff --git a/src/spatialdata_plot/pl/basic.py b/src/spatialdata_plot/pl/basic.py index 8af4e177..3d486e57 100644 --- a/src/spatialdata_plot/pl/basic.py +++ b/src/spatialdata_plot/pl/basic.py @@ -505,14 +505,14 @@ def render_points( return sdata - @_deprecation_alias(elements="element", quantiles_for_norm="percentiles_for_norm", version="version 0.3.0") + @_deprecation_alias(elements="element", version="version 0.3.0") def render_images( self, element: str | None = None, *, channel: list[str] | list[int] | str | int | None = None, cmap: list[Colormap | str] | Colormap | str | None = None, - norm: Normalize | None = None, + norm: list[Normalize] | Normalize | None = None, na_color: ColorLike | None = "default", palette: list[str] | str | None = None, alpha: float | int = 1.0, @@ -541,9 +541,10 @@ def render_images( cmap : list[Colormap | str] | Colormap | str | None Colormap or list of colormaps for continuous annotations, see :class:`matplotlib.colors.Colormap`. Each colormap applies to a corresponding channel. - norm : Normalize | None, optional + norm : list[Normalize] | Normalize | None, optional Colormap normalization for continuous annotations, see :class:`matplotlib.colors.Normalize`. - Applies to all channels if set. + Can be a single :class:`~matplotlib.colors.Normalize` (applied to all channels) or a list + of :class:`~matplotlib.colors.Normalize` objects (one per channel) for per-channel control. na_color : ColorLike | None, default "default" (gets set to "lightgray") Color to be used for NAs values, if present. Can either be a named color ("red"), a hex representation ("#000000ff") or a list of floats that represent RGB/RGBA values (1.0, 0.0, 0.0, 1.0). When None, the values @@ -596,12 +597,16 @@ def render_images( n_steps = len(sdata.plotting_tree.keys()) for element, param_values in params_dict.items(): + # Per-channel norms are stored in ImageRenderParams.norms. In that case + # _prepare_cmap_norm gets None and creates a default norm used as fallback + # for the single-channel rendering path. + scalar_norm = None if isinstance(norm, list) else norm cmap_params: list[CmapParams] | CmapParams if isinstance(cmap, list): cmap_params = [ _prepare_cmap_norm( cmap=c, - norm=norm, + norm=scalar_norm, na_color=param_values["na_color"], ) for c in cmap @@ -610,7 +615,7 @@ def render_images( else: cmap_params = _prepare_cmap_norm( cmap=cmap, - norm=norm, + norm=scalar_norm, na_color=param_values["na_color"], **kwargs, ) @@ -619,6 +624,7 @@ def render_images( channel=param_values["channel"], cmap_params=cmap_params, palette=param_values["palette"], + norms=norm if isinstance(norm, list) else None, alpha=param_values["alpha"], scale=param_values["scale"], zorder=n_steps, diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 9cddb499..976441df 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -17,7 +17,7 @@ import spatialdata as sd from anndata import AnnData from matplotlib.cm import ScalarMappable -from matplotlib.colors import ListedColormap, Normalize +from matplotlib.colors import Colormap, ListedColormap, Normalize from scanpy._settings import settings as sc_settings from spatialdata import get_extent, get_values, join_spatialelement_table from spatialdata._core.query.relational_query import match_table_to_element @@ -275,7 +275,8 @@ def _render_shapes( # When groups are specified, filter out non-matching elements by default. # Only show non-matching elements if the user explicitly sets na_color. _na = render_params.cmap_params.na_color - if groups is not None and values_are_categorical and (_na.default_color_set or _na.alpha == "00"): + na_is_transparent = _na.default_color_set or _na.get_alpha_as_float() == 0.0 + if groups is not None and values_are_categorical and na_is_transparent: keep, color_source_vector, color_vector = _filter_groups_transparent_na( groups, color_source_vector, color_vector ) @@ -887,7 +888,8 @@ def _render_points( # When groups are specified, filter out non-matching elements by default. # Only show non-matching elements if the user explicitly sets na_color. _na = render_params.cmap_params.na_color - if groups is not None and color_source_vector is not None and (_na.default_color_set or _na.alpha == "00"): + na_is_transparent = _na.default_color_set or _na.get_alpha_as_float() == 0.0 + if groups is not None and color_source_vector is not None and na_is_transparent: keep, color_source_vector, color_vector = _filter_groups_transparent_na( groups, color_source_vector, color_vector ) @@ -1175,6 +1177,42 @@ def _render_points( ) +def _additive_blend( + layers: dict[str | int, np.ndarray], + channels: list[Any], + channel_cmaps: list[Colormap], +) -> np.ndarray: + """Composite colormapped channels via signal-based additive blending. + + For each channel the "signal" is the deviation of the mapped color from + that colormap's zero-value color (cmap(0)). Signals are summed and placed + on a canvas whose color is the mean zero-value across all cmaps. + + This handles two common compositing strategies: + + * Black-to-color LUTs (Napari/ImageJ additive): cmap(0) is black, so + the signal equals cmap(val) and the canvas is black. This is identical + to a naive RGB sum. + * White-to-color LUTs (ImageJ Composite Invert): cmap(0) is white, so + the signal is cmap(val) minus white and the canvas is white. This is + equivalent to inverting, adding, then inverting again. + + For arbitrary colormaps (e.g. diverging) the same formula produces a + reasonable result by extracting each cmap's contribution relative to its + own background. + """ + if not layers: + raise ValueError("Cannot blend an empty set of layers.") + height, width = next(iter(layers.values())).shape + zero_colors = np.array([cm(0.0)[:3] for cm in channel_cmaps]) + canvas = np.mean(zero_colors, axis=0) + composite = np.full((height, width, 3), canvas, dtype=float) + for idx, (ch, cmap) in enumerate(zip(channels, channel_cmaps, strict=True)): + rgba = cmap(np.asarray(layers[ch])) + composite += rgba[..., :3] - zero_colors[idx] + return np.clip(composite, 0, 1, out=composite) + + def _render_images( sdata: sd.SpatialData, render_params: ImageRenderParams, @@ -1225,22 +1263,19 @@ def _render_images( # True if user gave n cmaps for n channels got_multiple_cmaps = isinstance(render_params.cmap_params, list) - if got_multiple_cmaps: - logger.warning( - "You're blending multiple cmaps. " - "If the plot doesn't look like you expect, it might be because your " - "cmaps go from a given color to 'white', and not to 'transparent'. " - "Therefore, the 'white' of higher layers will overlay the lower layers. " - "Consider using 'palette' instead." - ) - # not using got_multiple_cmaps here because of ruff :( + # ruff needs the isinstance check here for type narrowing if isinstance(render_params.cmap_params, list) and len(render_params.cmap_params) != n_channels: raise ValueError("If 'cmap' is provided, its length must match the number of channels.") + if render_params.norms is not None and len(render_params.norms) != n_channels: + raise ValueError( + f"Length of 'norm' list ({len(render_params.norms)}) must match the number of channels ({n_channels})." + ) + _, trans_data = _prepare_transformation(img, coordinate_system, ax) - # 1) Image has only 1 channel + # Single channel if n_channels == 1 and not isinstance(render_params.cmap_params, list): layer = img.sel(c=channels[0]).squeeze() if isinstance(channels[0], str) else img.isel(c=channels[0]).squeeze() @@ -1255,13 +1290,14 @@ def _render_images( cmap._lut[:, -1] = render_params.alpha # norm needs to be passed directly to ax.imshow(). If we normalize before, that method would always clip. + single_norm = render_params.norms[0] if render_params.norms else render_params.cmap_params.norm _ax_show_and_transform( layer, trans_data, ax, cmap=cmap, zorder=render_params.zorder, - norm=render_params.cmap_params.norm, + norm=single_norm, ) wants_colorbar = _should_request_colorbar( @@ -1271,7 +1307,7 @@ def _render_images( auto_condition=n_channels == 1, ) if wants_colorbar and legend_params.colorbar and colorbar_requests is not None: - sm = plt.cm.ScalarMappable(cmap=cmap, norm=render_params.cmap_params.norm) + sm = plt.cm.ScalarMappable(cmap=cmap, norm=single_norm) colorbar_requests.append( ColorbarSpec( ax=ax, @@ -1285,12 +1321,14 @@ def _render_images( ) ) - # 2) Image has any number of channels but 1 + # Multiple channels else: layers = {} for ch_idx, ch in enumerate(channels): layers[ch] = img.sel(c=ch).copy(deep=True).squeeze() - if isinstance(render_params.cmap_params, list): + if render_params.norms is not None: + ch_norm = render_params.norms[ch_idx] + elif isinstance(render_params.cmap_params, list): ch_norm = render_params.cmap_params[ch_idx].norm else: ch_norm = render_params.cmap_params.norm @@ -1298,27 +1336,16 @@ def _render_images( if ch_norm is not None: layers[ch] = ch_norm(layers[ch]) - # 2A) Image has 3 channels, no palette info, and no/only one cmap was given + # Image has 3 channels, no palette, and at most one cmap if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list): - if render_params.cmap_params.cmap_is_default: # -> use RGB + if render_params.cmap_params.cmap_is_default: # treat as RGB stacked = np.stack([layers[ch] for ch in layers], axis=-1) - else: # -> use given cmap for each channel + else: # apply the given cmap to each channel channel_cmaps = [render_params.cmap_params.cmap] * n_channels - stacked = ( - np.stack( - [channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], - 0, - ).sum(0) - / n_channels - ) - stacked = stacked[:, :, :3] + stacked = _additive_blend(layers, channels, channel_cmaps) logger.warning( "One cmap was given for multiple channels and is now used for each channel. " - "You're blending multiple cmaps. " - "If the plot doesn't look like you expect, it might be because your " - "cmaps go from a given color to 'white', and not to 'transparent'. " - "Therefore, the 'white' of higher layers will overlay the lower layers. " - "Consider using 'palette' instead." + "Consider using 'palette' for black-to-color compositing instead." ) _ax_show_and_transform( @@ -1329,25 +1356,11 @@ def _render_images( zorder=render_params.zorder, ) - # 2B) Image has n channels, no palette/cmap info -> sample n categorical colors + # n channels, no palette/cmap: sample n categorical colors elif palette is None and not got_multiple_cmaps: - # overwrite if n_channels == 2 for intuitive result + # For 2 channels default to red/green for an intuitive result if n_channels == 2: seed_colors = ["#ff0000ff", "#00ff00ff"] - channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] - colored = np.stack( - [channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)], - 0, - ).sum(0) - colored = colored[:, :, :3] - elif n_channels == 3: - seed_colors = _get_colors_for_categorical_obs(list(range(n_channels))) - channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] - colored = np.stack( - [channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], - 0, - ).sum(0) - colored = colored[:, :, :3] else: if isinstance(render_params.cmap_params, list): cmap_is_default = render_params.cmap_params[0].cmap_is_default @@ -1357,31 +1370,16 @@ def _render_images( if cmap_is_default: seed_colors = _get_colors_for_categorical_obs(list(range(n_channels))) else: - # Sample n_channels colors evenly from the colormap + # Sample n_channels evenly spaced colors from the colormap if isinstance(render_params.cmap_params, list): seed_colors = [ render_params.cmap_params[i].cmap(i / (n_channels - 1)) for i in range(n_channels) ] else: seed_colors = [render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels)] - channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] - - # Stack (n_channels, height, width) → (height*width, n_channels) - H, W = next(iter(layers.values())).shape - comp_rgb = np.zeros((H, W, 3), dtype=float) - # For each channel: map to RGBA, apply constant alpha, then add - for ch_idx, ch in enumerate(channels): - layer_arr = layers[ch] - rgba = channel_cmaps[ch_idx](layer_arr) - rgba[..., 3] = render_params.alpha - comp_rgb += rgba[..., :3] * rgba[..., 3][..., None] - - colored = np.clip(comp_rgb, 0, 1) - logger.info( - f"Your image has {n_channels} channels. Sampling categorical colors and using " - f"multichannel strategy 'stack' to render." - ) # TODO: update when pca is added as strategy + channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] + colored = _additive_blend(layers, channels, channel_cmaps) _ax_show_and_transform( colored, @@ -1391,14 +1389,13 @@ def _render_images( zorder=render_params.zorder, ) - # 2C) Image has n channels and palette info + # n channels with palette: build black-to-color LUT per channel elif palette is not None and not got_multiple_cmaps: if len(palette) != n_channels: raise ValueError("If 'palette' is provided, its length must match the number of channels.") - channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette if isinstance(c, str)] - colored = np.stack([channel_cmaps[i](layers[c]) for i, c in enumerate(channels)], 0).sum(0) - colored = colored[:, :, :3] + channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in palette] + colored = _additive_blend(layers, channels, channel_cmaps) _ax_show_and_transform( colored, @@ -1410,14 +1407,7 @@ def _render_images( elif palette is None and got_multiple_cmaps: channel_cmaps = [cp.cmap for cp in render_params.cmap_params] # type: ignore[union-attr] - colored = ( - np.stack( - [channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], - 0, - ).sum(0) - / n_channels - ) - colored = colored[:, :, :3] + colored = _additive_blend(layers, channels, channel_cmaps) _ax_show_and_transform( colored, @@ -1427,7 +1417,7 @@ def _render_images( zorder=render_params.zorder, ) - # 2D) Image has n channels, no palette but cmap info + # n channels with both palette and cmap is not allowed elif palette is not None and got_multiple_cmaps: raise ValueError("If 'palette' is provided, 'cmap' must be None.") @@ -1534,12 +1524,8 @@ def _render_labels( # When groups are specified, zero out non-matching label IDs so they render as background. # Only show non-matching labels if the user explicitly sets na_color. _na = render_params.cmap_params.na_color - if ( - groups is not None - and categorical - and color_source_vector is not None - and (_na.default_color_set or _na.alpha == "00") - ): + na_is_transparent = _na.default_color_set or _na.get_alpha_as_float() == 0.0 + if groups is not None and categorical and color_source_vector is not None and na_is_transparent: keep_vec = color_source_vector.isin(groups) matching_ids = instance_id[keep_vec] keep_mask = np.isin(label.values, matching_ids) diff --git a/src/spatialdata_plot/pl/render_params.py b/src/spatialdata_plot/pl/render_params.py index a108e131..1c0d94cc 100644 --- a/src/spatialdata_plot/pl/render_params.py +++ b/src/spatialdata_plot/pl/render_params.py @@ -264,8 +264,8 @@ class ImageRenderParams: element: str channel: list[str] | list[int] | int | str | None = None palette: ListedColormap | list[str] | None = None + norms: list[Normalize] | None = None # per-channel norms, separate from the scalar norm in CmapParams alpha: float = 1.0 - percentiles_for_norm: tuple[float | None, float | None] = (None, None) scale: str | None = None zorder: int = 0 colorbar: bool | str | None = "auto" diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 402a2191..a7f56ae9 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -2394,7 +2394,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st norm = param_dict.get("norm") if norm is not None: - if element_type in {"images", "labels"} and not isinstance(norm, Normalize): + if element_type == "images" and isinstance(norm, Sequence) and not isinstance(norm, list): + raise TypeError("'norm' must be a list of Normalize objects, not a tuple or other sequence.") + if element_type == "images" and isinstance(norm, list): + if len(norm) == 0: + raise ValueError("'norm' list must not be empty.") + if not all(isinstance(n, Normalize) for n in norm): + raise TypeError("All elements of 'norm' list must be of type Normalize.") + elif element_type in {"images", "labels"} and not isinstance(norm, Normalize): raise TypeError("Parameter 'norm' must be of type Normalize.") if element_type in {"shapes", "points"} and not isinstance(norm, bool | Normalize): raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.") diff --git a/tests/_images/Images_can_composite_five_channels.png b/tests/_images/Images_can_composite_five_channels.png new file mode 100644 index 00000000..57642190 Binary files /dev/null and b/tests/_images/Images_can_composite_five_channels.png differ diff --git a/tests/_images/Images_can_composite_five_channels_with_palette.png b/tests/_images/Images_can_composite_five_channels_with_palette.png new file mode 100644 index 00000000..603c5860 Binary files /dev/null and b/tests/_images/Images_can_composite_five_channels_with_palette.png differ diff --git a/tests/_images/Images_can_normalize_channels_independently.png b/tests/_images/Images_can_normalize_channels_independently.png new file mode 100644 index 00000000..c3a3be5a Binary files /dev/null and b/tests/_images/Images_can_normalize_channels_independently.png differ diff --git a/tests/_images/Images_can_normalize_single_channel_via_list.png b/tests/_images/Images_can_normalize_single_channel_via_list.png new file mode 100644 index 00000000..e163bda2 Binary files /dev/null and b/tests/_images/Images_can_normalize_single_channel_via_list.png differ diff --git a/tests/_images/Images_can_pass_cmap.png b/tests/_images/Images_can_pass_cmap.png index eb13f61e..eb0a35a2 100644 Binary files a/tests/_images/Images_can_pass_cmap.png and b/tests/_images/Images_can_pass_cmap.png differ diff --git a/tests/_images/Images_can_pass_cmap_list.png b/tests/_images/Images_can_pass_cmap_list.png index 3f887d9b..00ca2d0f 100644 Binary files a/tests/_images/Images_can_pass_cmap_list.png and b/tests/_images/Images_can_pass_cmap_list.png differ diff --git a/tests/_images/Images_can_pass_cmap_to_each_channel.png b/tests/_images/Images_can_pass_cmap_to_each_channel.png index 94cfcfcd..424358df 100644 Binary files a/tests/_images/Images_can_pass_cmap_to_each_channel.png and b/tests/_images/Images_can_pass_cmap_to_each_channel.png differ diff --git a/tests/_images/Images_can_pass_str_cmap.png b/tests/_images/Images_can_pass_str_cmap.png index eb13f61e..eb0a35a2 100644 Binary files a/tests/_images/Images_can_pass_str_cmap.png and b/tests/_images/Images_can_pass_str_cmap.png differ diff --git a/tests/_images/Images_can_pass_str_cmap_list.png b/tests/_images/Images_can_pass_str_cmap_list.png index 3f887d9b..00ca2d0f 100644 Binary files a/tests/_images/Images_can_pass_str_cmap_list.png and b/tests/_images/Images_can_pass_str_cmap_list.png differ diff --git a/tests/pl/test_render_images.py b/tests/pl/test_render_images.py index 3e28799a..cec149c7 100644 --- a/tests/pl/test_render_images.py +++ b/tests/pl/test_render_images.py @@ -2,6 +2,7 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np +import pytest import scanpy as sc from matplotlib.colors import Normalize from spatial_image import to_spatial_image @@ -134,6 +135,71 @@ def test_plot_can_stick_to_zorder(self, sdata_blobs: SpatialData): def test_plot_can_render_multiscale_image_with_custom_cmap(self, sdata_blobs: SpatialData): sdata_blobs.pl.render_images("blobs_multiscale_image", channel=0, scale="scale2", cmap="Greys").pl.show() + # Multichannel compositing with 4+ channels. + # Uses sdata_blobs_str which has 5 named channels, similar to a + # multiplexed immunofluorescence panel (DAPI + 4 markers). + + def test_plot_can_composite_five_channels(self, sdata_blobs_str: SpatialData): + """5-channel additive blend with auto-assigned colors.""" + sdata_blobs_str.pl.render_images(element="blobs_image").pl.show() + + def test_plot_can_composite_five_channels_with_palette(self, sdata_blobs_str: SpatialData): + """5-channel fluorescence-style overlay with explicit palette.""" + sdata_blobs_str.pl.render_images( + element="blobs_image", + palette=["blue", "green", "red", "magenta", "cyan"], + ).pl.show() + + # Per-channel normalization. + # Common in multiplexed imaging where markers have very different intensity + # ranges, e.g. bright DAPI vs dim cytokine stain. + + def test_plot_can_normalize_channels_independently(self, sdata_blobs: SpatialData): + """Per-channel norm: each channel gets its own contrast window.""" + sdata_blobs.pl.render_images( + element="blobs_image", + channel=[0, 1, 2], + palette=["red", "green", "blue"], + norm=[Normalize(vmin=0, vmax=0.5), Normalize(vmin=0.2, vmax=0.8), Normalize(vmin=0, vmax=1)], + ).pl.show() + + def test_plot_can_normalize_single_channel_via_list(self, sdata_blobs: SpatialData): + """Single-element norm list works the same as a scalar norm.""" + sdata_blobs.pl.render_images( + element="blobs_image", + channel=0, + norm=[Normalize(vmin=0.1, vmax=0.5)], + cmap="Greys", + ).pl.show() + + # Validation errors + + def test_norm_list_length_mismatch_raises(self, sdata_blobs: SpatialData): + """norm list length must match number of channels.""" + with pytest.raises(ValueError, match="must match the number of channels"): + sdata_blobs.pl.render_images( + element="blobs_image", + channel=[0, 1, 2], + norm=[Normalize(), Normalize()], # 2 norms for 3 channels + ).pl.show() + + def test_norm_list_empty_raises(self, sdata_blobs: SpatialData): + """Empty norm list is rejected at validation time.""" + with pytest.raises(ValueError, match="must not be empty"): + sdata_blobs.pl.render_images( + element="blobs_image", + norm=[], + ).pl.show() + + def test_norm_tuple_raises(self, sdata_blobs: SpatialData): + """Tuple of norms is rejected, only list is accepted.""" + with pytest.raises(TypeError, match="not a tuple"): + sdata_blobs.pl.render_images( + element="blobs_image", + channel=[0, 1, 2], + norm=(Normalize(), Normalize(), Normalize()), + ).pl.show() + def test_plot_correctly_normalizes_multichannel_images(self, sdata_raccoon: SpatialData): sdata_raccoon["raccoon_int16"] = Image2DModel.parse( sdata_raccoon["raccoon"].data.astype(np.uint16) * 257, # 255 * 257 = 65535,