From db51af4ca9d6af8e6ee4047854440f9c10f0e6fb Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 19 Mar 2026 21:02:55 +0100 Subject: [PATCH 1/6] Deduplicate shared logic between _render_shapes and _render_points Extract 8 helper functions from the near-identical datashader rendering paths in _render_shapes() and _render_points(): - _apply_datashader_norm: norm vmin/vmax edge-case handling - _build_datashader_colorbar_mappable: ScalarMappable construction - _datashader_aggregate: categorical/continuous/no-color aggregation - _datashader_shade_continuous: continuous color mapping + spread + NaN - _datashader_shade_categorical: categorical/no-color color mapping - _render_datashader_result: RGBA image rendering + NaN overlay - _make_palette: ListedColormap construction - _decorate_render: legend/colorbar/scalebar decoration Also refactor the show() dispatch loop in basic.py from 4 if/elif branches to a table-driven pattern. No public API changes. No behavioral changes. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/spatialdata_plot/pl/render.py | 686 ++++++++++++++++-------------- 1 file changed, 369 insertions(+), 317 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 9cddb499..5cedb650 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -2,7 +2,7 @@ from collections import abc from copy import copy -from typing import Any +from typing import Any, Literal import dask import dask.dataframe as dd @@ -68,6 +68,8 @@ # missing (NaN) values. Must not collide with realistic user category names. _DS_NAN_CATEGORY = "ds_nan" +_DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] + def _coerce_categorical_source(series: pd.Series | dd.Series) -> pd.Categorical: """Return a ``pd.Categorical`` from a pandas or dask Series.""" @@ -218,6 +220,266 @@ def _should_request_colorbar( return bool(auto_condition) +def _apply_datashader_norm( + agg: Any, + norm: Normalize, +) -> tuple[Any, list[float] | None]: + """Apply norm vmin/vmax to a datashader aggregate. + + When vmin == vmax, maps the value to 0.5 using an artificial [0, 1] span. + Returns (agg, ds_span) where ds_span is None if no norm was set. + """ + if norm.vmin is None and norm.vmax is None: + return agg, None + norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin + norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax + ds_span: list[float] = [norm.vmin, norm.vmax] + if norm.vmin == norm.vmax: + ds_span = [0, 1] + if norm.clip: + agg = (agg - agg) + 0.5 + else: + agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1) + agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) + agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) + return agg, ds_span + + +def _build_datashader_colorbar_mappable( + aggregate_with_reduction: tuple[Any, Any] | None, + norm: Normalize, + cmap: Any, +) -> ScalarMappable | None: + """Create a ScalarMappable for the colorbar from datashader reduction bounds. + + Returns None if there is no continuous reduction. + """ + if aggregate_with_reduction is None: + return None + vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin + vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax + if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: + assert norm.vmin is not None + assert norm.vmax is not None + vmin = norm.vmin - 0.5 + vmax = norm.vmin + 0.5 + return ScalarMappable( + norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), + cmap=cmap, + ) + + +def _datashader_aggregate( + cvs: Any, + transformed_element: Any, + col_for_color: str | None, + color_by_categorical: bool, + ds_reduction: _DsReduction | None, + default_reduction: _DsReduction, + geom_type: Literal["points", "shapes"], +) -> tuple[Any, tuple[Any, Any] | None, Any | None]: + """Aggregate spatial elements with datashader. + + Dispatches between categorical (ds.by), continuous (reduction function), + and no-color (ds.count) aggregation modes. + + Returns (agg, aggregate_with_reduction, continuous_nan_agg). + """ + aggregate_with_reduction = None + continuous_nan_agg = None + + def _agg_call(element: Any, agg_func: Any) -> Any: + if geom_type == "shapes": + return cvs.polygons(element, geometry="geometry", agg=agg_func) + return cvs.points(element, "x", "y", agg=agg_func) + + if col_for_color is not None: + if color_by_categorical: + transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color]) + agg = _agg_call(transformed_element, ds.by(col_for_color, ds.count())) + else: + reduction_name = ds_reduction if ds_reduction is not None else default_reduction + logger.info( + f'Using the datashader reduction "{reduction_name}". "max" will give an output ' + "very close to the matplotlib result." + ) + agg = _datashader_aggregate_with_function(ds_reduction, cvs, transformed_element, col_for_color, geom_type) + aggregate_with_reduction = (agg.min(), agg.max()) + + nan_elements = transformed_element[transformed_element[col_for_color].isnull()] + if len(nan_elements) > 0: + continuous_nan_agg = _datashader_aggregate_with_function("any", cvs, nan_elements, None, geom_type) + else: + agg = _agg_call(transformed_element, ds.count()) + + return agg, aggregate_with_reduction, continuous_nan_agg + + +def _datashader_shade_continuous( + agg: Any, + ds_span: list[float] | None, + norm: Normalize, + cmap: Any, + alpha: float, + aggregate_with_reduction: tuple[Any, Any] | None, + continuous_nan_agg: Any | None, + na_color_hex: str, + spread_px: int | None = None, + ds_reduction: _DsReduction | None = None, +) -> tuple[Any, Any | None, tuple[Any, Any] | None]: + """Shade a continuous datashader aggregate, optionally applying spread and NaN coloring. + + Returns (ds_result, continuous_nan_shaded, aggregate_with_reduction). + """ + if spread_px is not None: + spread_how = _datshader_get_how_kw_for_spread(ds_reduction) + agg = ds.tf.spread(agg, px=spread_px, how=spread_how) + aggregate_with_reduction = (agg.min(), agg.max()) + + ds_cmap = cmap + if ( + aggregate_with_reduction is not None + and aggregate_with_reduction[0] == aggregate_with_reduction[1] + and (ds_span is None or ds_span != [0, 1]) + ): + ds_cmap = matplotlib.colors.to_hex(cmap(0.0), keep_alpha=False) + aggregate_with_reduction = ( + aggregate_with_reduction[0], + aggregate_with_reduction[0] + 1, + ) + + ds_result = _datashader_map_aggregate_to_color( + agg, + cmap=ds_cmap, + min_alpha=_convert_alpha_to_datashader_range(alpha), + span=ds_span, + clip=norm.clip, + ) + + continuous_nan_shaded = None + if continuous_nan_agg is not None: + shade_kwargs: dict[str, Any] = {"cmap": na_color_hex, "how": "linear"} + if spread_px is not None: + continuous_nan_agg = ds.tf.spread(continuous_nan_agg, px=spread_px, how="max") + else: + # only shapes (no spread) pass min_alpha for NaN shading + shade_kwargs["min_alpha"] = _convert_alpha_to_datashader_range(alpha) + continuous_nan_shaded = ds.tf.shade(continuous_nan_agg, **shade_kwargs) + + return ds_result, continuous_nan_shaded, aggregate_with_reduction + + +def _datashader_shade_categorical( + agg: Any, + color_key: dict[str, str] | None, + color_vector: Any, + alpha: float, + spread_px: int | None = None, +) -> Any: + """Shade a categorical or no-color datashader aggregate.""" + ds_cmap = None + if color_vector is not None: + ds_cmap = color_vector[0] + if isinstance(ds_cmap, str) and ds_cmap[0] == "#": + ds_cmap = _hex_no_alpha(ds_cmap) + + agg_to_shade = ds.tf.spread(agg, px=spread_px) if spread_px is not None else agg + return _datashader_map_aggregate_to_color( + agg_to_shade, + cmap=ds_cmap, + color_key=color_key, + min_alpha=_convert_alpha_to_datashader_range(alpha), + ) + + +def _render_datashader_result( + ax: matplotlib.axes.SubplotBase, + ds_result: Any, + factor: float, + zorder: int, + alpha: float, + extent: list[float] | None, + continuous_nan_result: Any | None = None, +) -> Any: + """Render a shaded datashader result onto matplotlib axes, with optional NaN overlay.""" + if continuous_nan_result is not None: + rgba_nan, trans_nan = _create_image_from_datashader_result(continuous_nan_result, factor, ax) + _ax_show_and_transform(rgba_nan, trans_nan, ax, zorder=zorder, alpha=alpha, extent=extent) + rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax) + return _ax_show_and_transform(rgba_image, trans_data, ax, zorder=zorder, alpha=alpha, extent=extent) + + +def _make_palette( + color_source_vector: pd.Series | None, + color_vector: Any, +) -> ListedColormap: + """Build a ListedColormap from a color vector, filtering out NaN entries when categorical.""" + if color_source_vector is None: + return ListedColormap(dict.fromkeys(color_vector)) + return ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()])) + + +def _decorate_render( + ax: matplotlib.axes.SubplotBase, + cax: ScalarMappable | None, + fig_params: FigParams, + adata: AnnData | None, + col_for_color: str | None, + color_source_vector: pd.Series | None, + color_vector: Any, + palette: ListedColormap | list[str] | None, + alpha: float, + na_color: Color, + legend_params: LegendParams, + colorbar: bool | str | None, + colorbar_params: dict[str, object] | None, + colorbar_requests: list[ColorbarSpec] | None, + scalebar_params: ScalebarParams, +) -> None: + """Add legend, colorbar, and scalebar decorations if the color vector warrants them.""" + if not _want_decorations(color_vector, na_color): + return + + if palette is None: + palette = _make_palette(color_source_vector, color_vector) + + if color_source_vector is not None and hasattr(color_source_vector, "remove_unused_categories"): + color_source_vector = color_source_vector.remove_unused_categories() + + wants_colorbar = _should_request_colorbar( + colorbar, + has_mappable=cax is not None, + is_continuous=col_for_color is not None and color_source_vector is None, + ) + + _decorate_axs( + ax=ax, + cax=cax, + fig_params=fig_params, + adata=adata, + value_to_plot=col_for_color, + color_source_vector=color_source_vector, + color_vector=color_vector, + palette=palette, + alpha=alpha, + na_color=na_color, + legend_fontsize=legend_params.legend_fontsize, + legend_fontweight=legend_params.legend_fontweight, + legend_loc=legend_params.legend_loc, + legend_fontoutline=legend_params.legend_fontoutline, + na_in_legend=legend_params.na_in_legend, + colorbar=wants_colorbar and legend_params.colorbar, + colorbar_params=colorbar_params, + colorbar_requests=colorbar_requests, + colorbar_label=_resolve_colorbar_label( + colorbar_params, + col_for_color if isinstance(col_for_color, str) else None, + ), + scalebar_dx=scalebar_params.scalebar_dx, + scalebar_units=scalebar_params.scalebar_units, + ) + + def _render_shapes( sdata: sd.SpatialData, render_params: ShapesRenderParams, @@ -315,13 +577,7 @@ def _render_shapes( "These observations will be colored with the 'na_color'." ) - # Using dict.fromkeys here since set returns in arbitrary order - # remove the color of NaN values, else it might be assigned to a category - # order of color in the palette should agree to order of occurence - if color_source_vector is None: - palette = ListedColormap(dict.fromkeys(color_vector)) - else: - palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()])) + palette = _make_palette(color_source_vector, color_vector) has_valid_color = ( len(set(color_vector)) != 1 @@ -419,41 +675,15 @@ def _render_shapes( cat_series = cat_series.astype("category") transformed_element[col_for_color] = cat_series - aggregate_with_reduction = None - continuous_nan_shapes = None - if col_for_color is not None: - if color_by_categorical: - # add a sentinel category so that shapes with NaN value are colored in the na_color - transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color]) - agg = cvs.polygons( - transformed_element, - geometry="geometry", - agg=ds.by(col_for_color, ds.count()), - ) - else: - reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "mean" - logger.info( - f'Using the datashader reduction "{reduction_name}". "max" will give an output very close ' - "to the matplotlib result." - ) - agg = _datashader_aggregate_with_function( - render_params.ds_reduction, - cvs, - transformed_element, - col_for_color, - "shapes", - ) - # save min and max values for drawing the colorbar - aggregate_with_reduction = (agg.min(), agg.max()) - - # nan shapes need to be rendered separately (else: invisible, bc nan is skipped by aggregation methods) - transformed_element_nan_color = transformed_element[transformed_element[col_for_color].isnull()] - if len(transformed_element_nan_color) > 0: - continuous_nan_shapes = _datashader_aggregate_with_function( - "any", cvs, transformed_element_nan_color, None, "shapes" - ) - else: - agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.count()) + agg, aggregate_with_reduction, continuous_nan_agg = _datashader_aggregate( + cvs, + transformed_element, + col_for_color, + color_by_categorical, + render_params.ds_reduction, + "mean", + "shapes", + ) # render outlines if needed # outline_linewidth is in points (1pt = 1/72 inch); datashader line_width is in canvas pixels @@ -472,20 +702,7 @@ def _render_shapes( line_width=render_params.outline_params.inner_outline_linewidth * ds_lw_factor, ) - ds_span = None - if norm.vmin is not None or norm.vmax is not None: - norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin - norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax - ds_span = [norm.vmin, norm.vmax] - if norm.vmin == norm.vmax: - # edge case, value vmin is rendered as the middle of the cmap - ds_span = [0, 1] - if norm.clip: - agg = (agg - agg) + 0.5 - else: - agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1) - agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) - agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) + agg, ds_span = _apply_datashader_norm(agg, norm) color_key: dict[str, str] | None = None if color_by_categorical and col_for_color is not None: @@ -494,49 +711,27 @@ def _render_shapes( cat_series, color_vector, render_params.cmap_params.na_color.get_hex() ) + continuous_nan_shaded = None if color_by_categorical or col_for_color is None: - ds_cmap = None - if color_vector is not None: - ds_cmap = color_vector[0] - if isinstance(ds_cmap, str) and ds_cmap[0] == "#": - ds_cmap = _hex_no_alpha(ds_cmap) - - ds_result = _datashader_map_aggregate_to_color( + ds_result = _datashader_shade_categorical( agg, - cmap=ds_cmap, - color_key=color_key, - min_alpha=_convert_alpha_to_datashader_range(render_params.fill_alpha), + color_key, + color_vector, + render_params.fill_alpha, ) - elif aggregate_with_reduction is not None: # to shut up mypy - ds_cmap = render_params.cmap_params.cmap - # in case all elements have the same value X: we render them using cmap(0.0), - # using an artificial "span" of [X, X + 1] for the color bar - # else: all elements would get alpha=0 and the color bar would have a weird range - if aggregate_with_reduction[0] == aggregate_with_reduction[1]: - ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False) - aggregate_with_reduction = ( - aggregate_with_reduction[0], - aggregate_with_reduction[0] + 1, - ) - - ds_result = _datashader_map_aggregate_to_color( + else: + na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) + ds_result, continuous_nan_shaded, aggregate_with_reduction = _datashader_shade_continuous( agg, - cmap=ds_cmap, - min_alpha=_convert_alpha_to_datashader_range(render_params.fill_alpha), - span=ds_span, - clip=norm.clip, + ds_span, + norm, + render_params.cmap_params.cmap, + render_params.fill_alpha, + aggregate_with_reduction, + continuous_nan_agg, + na_color_hex, ) - if continuous_nan_shapes is not None: - # for coloring by continuous variable: render nan shapes separately - nan_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) - continuous_nan_shapes = ds.tf.shade( - continuous_nan_shapes, - cmap=nan_color_hex, - how="linear", - min_alpha=_convert_alpha_to_datashader_range(render_params.fill_alpha), - ) - # shade outlines if needed if render_params.outline_alpha[0] > 0 and isinstance(render_params.outline_params.outer_outline_color, Color): outline_color = render_params.outline_params.outer_outline_color.get_hex() @@ -578,42 +773,17 @@ def _render_shapes( extent=x_ext + y_ext, ) - if continuous_nan_shapes is not None: - # for coloring by continuous variable: render nan points separately - rgba_image_nan, trans_data_nan = _create_image_from_datashader_result(continuous_nan_shapes, factor, ax) - _ax_show_and_transform( - rgba_image_nan, - trans_data_nan, - ax, - zorder=render_params.zorder, - alpha=render_params.fill_alpha, - extent=x_ext + y_ext, - ) - rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax) - _cax = _ax_show_and_transform( - rgba_image, - trans_data, + _cax = _render_datashader_result( ax, - zorder=render_params.zorder, - alpha=render_params.fill_alpha, - extent=x_ext + y_ext, + ds_result, + factor, + render_params.zorder, + render_params.fill_alpha, + x_ext + y_ext, + continuous_nan_result=continuous_nan_shaded, ) - cax = None - if aggregate_with_reduction is not None: - vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin - vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax - if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: - assert norm.vmin is not None - assert norm.vmax is not None - # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and - # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1) - vmin = norm.vmin - 0.5 - vmax = norm.vmin + 0.5 - cax = ScalarMappable( - norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), - cmap=render_params.cmap_params.cmap, - ) + cax = _build_datashader_colorbar_mappable(aggregate_with_reduction, norm, render_params.cmap_params.cmap) elif method == "matplotlib": # render outlines separately to ensure they are always underneath the shape @@ -698,43 +868,23 @@ def _render_shapes( vmax = 1.0 _cax.set_clim(vmin=vmin, vmax=vmax) - if _want_decorations(color_vector, render_params.cmap_params.na_color): - # necessary in case different shapes elements are annotated with one table - if color_source_vector is not None and render_params.col_for_color is not None: - color_source_vector = color_source_vector.remove_unused_categories() - - wants_colorbar = _should_request_colorbar( - render_params.colorbar, - has_mappable=cax is not None, - is_continuous=render_params.col_for_color is not None and color_source_vector is None, - ) - - _ = _decorate_axs( - ax=ax, - cax=cax, - fig_params=fig_params, - adata=table, - value_to_plot=col_for_color, - color_source_vector=color_source_vector, - color_vector=color_vector, - palette=palette, - alpha=render_params.fill_alpha, - na_color=render_params.cmap_params.na_color, - legend_fontsize=legend_params.legend_fontsize, - legend_fontweight=legend_params.legend_fontweight, - legend_loc=legend_params.legend_loc, - legend_fontoutline=legend_params.legend_fontoutline, - na_in_legend=legend_params.na_in_legend, - colorbar=wants_colorbar and legend_params.colorbar, - colorbar_params=render_params.colorbar_params, - colorbar_requests=colorbar_requests, - colorbar_label=_resolve_colorbar_label( - render_params.colorbar_params, - col_for_color if isinstance(col_for_color, str) else None, - ), - scalebar_dx=scalebar_params.scalebar_dx, - scalebar_units=scalebar_params.scalebar_units, - ) + _decorate_render( + ax=ax, + cax=cax, + fig_params=fig_params, + adata=table, + col_for_color=col_for_color, + color_source_vector=color_source_vector, + color_vector=color_vector, + palette=palette, + alpha=render_params.fill_alpha, + na_color=render_params.cmap_params.na_color, + legend_params=legend_params, + colorbar=render_params.colorbar, + colorbar_params=render_params.colorbar_params, + colorbar_requests=colorbar_requests, + scalebar_params=scalebar_params, + ) def _render_points( @@ -972,52 +1122,17 @@ def _render_points( if color_by_categorical and not isinstance(color_dtype, pd.CategoricalDtype): transformed_element[col_for_color] = transformed_element[col_for_color].astype("category") - aggregate_with_reduction = None - continuous_nan_points = None - if col_for_color is not None: - if color_by_categorical: - # add nan as category so that nan points are shown in the nan color - transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color]) - agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count())) - else: - reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum" - logger.info( - f'Using the datashader reduction "{reduction_name}". "max" will give an output very close ' - "to the matplotlib result." - ) - agg = _datashader_aggregate_with_function( - render_params.ds_reduction, - cvs, - transformed_element, - col_for_color, - "points", - ) - # save min and max values for drawing the colorbar - aggregate_with_reduction = (agg.min(), agg.max()) - # nan points need to be rendered separately (else: invisible, bc nan is skipped by aggregation methods) - transformed_element_nan_color = transformed_element[transformed_element[col_for_color].isnull()] - if len(transformed_element_nan_color) > 0: - continuous_nan_points = _datashader_aggregate_with_function( - "any", cvs, transformed_element_nan_color, None, "points" - ) - else: - agg = cvs.points(transformed_element, "x", "y", agg=ds.count()) - - ds_span = None - if norm.vmin is not None or norm.vmax is not None: - norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin - norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax - ds_span = [norm.vmin, norm.vmax] - if norm.vmin == norm.vmax: - ds_span = [0, 1] - if norm.clip: - # all data is mapped to 0.5 - agg = (agg - agg) + 0.5 - else: - # values equal to norm.vmin are mapped to 0.5, the rest to -1 or 2 - agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1) - agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) - agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) + agg, aggregate_with_reduction, continuous_nan_agg = _datashader_aggregate( + cvs, + transformed_element, + col_for_color, + color_by_categorical, + render_params.ds_reduction, + "sum", + "points", + ) + + agg, ds_span = _apply_datashader_norm(agg, norm) color_key: dict[str, str] | None = None if color_by_categorical and col_for_color is not None: @@ -1034,83 +1149,41 @@ def _render_points( ): color_vector = np.asarray([_hex_no_alpha(x) for x in color_vector]) + continuous_nan_shaded = None if color_by_categorical or col_for_color is None: - ds_result = _datashader_map_aggregate_to_color( - ds.tf.spread(agg, px=px), - cmap=color_vector[0], - color_key=color_key, - min_alpha=_convert_alpha_to_datashader_range(render_params.alpha), + ds_result = _datashader_shade_categorical( + agg, + color_key, + color_vector, + render_params.alpha, + spread_px=px, ) else: - spread_how = _datshader_get_how_kw_for_spread(render_params.ds_reduction) - agg = ds.tf.spread(agg, px=px, how=spread_how) - aggregate_with_reduction = (agg.min(), agg.max()) - - ds_cmap = render_params.cmap_params.cmap - # in case all elements have the same value X: we render them using cmap(0.0), - # using an artificial "span" of [X, X + 1] for the color bar - # else: all elements would get alpha=0 and the color bar would have a weird range - if aggregate_with_reduction[0] == aggregate_with_reduction[1] and (ds_span is None or ds_span != [0, 1]): - ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False) - aggregate_with_reduction = ( - aggregate_with_reduction[0], - aggregate_with_reduction[0] + 1, - ) - - ds_result = _datashader_map_aggregate_to_color( + na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) + ds_result, continuous_nan_shaded, aggregate_with_reduction = _datashader_shade_continuous( agg, - cmap=ds_cmap, - span=ds_span, - clip=norm.clip, - min_alpha=_convert_alpha_to_datashader_range(render_params.alpha), + ds_span, + norm, + render_params.cmap_params.cmap, + render_params.alpha, + aggregate_with_reduction, + continuous_nan_agg, + na_color_hex, + spread_px=px, + ds_reduction=render_params.ds_reduction, ) - if continuous_nan_points is not None: - # for coloring by continuous variable: render nan points separately - nan_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) - continuous_nan_points = ds.tf.spread(continuous_nan_points, px=px, how="max") - continuous_nan_points = ds.tf.shade( - continuous_nan_points, - cmap=nan_color_hex, - how="linear", - ) - - if continuous_nan_points is not None: - # for coloring by continuous variable: render nan points separately - rgba_image_nan, trans_data_nan = _create_image_from_datashader_result(continuous_nan_points, factor, ax) - _ax_show_and_transform( - rgba_image_nan, - trans_data_nan, - ax, - zorder=render_params.zorder, - alpha=render_params.alpha, - extent=x_ext + y_ext, - ) - rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax) - _ax_show_and_transform( - rgba_image, - trans_data, + _render_datashader_result( ax, - zorder=render_params.zorder, - alpha=render_params.alpha, - extent=x_ext + y_ext, + ds_result, + factor, + render_params.zorder, + render_params.alpha, + x_ext + y_ext, + continuous_nan_result=continuous_nan_shaded, ) - cax = None - if aggregate_with_reduction is not None: - vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin - vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax - if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: - assert norm.vmin is not None - assert norm.vmax is not None - # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and - # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1) - vmin = norm.vmin - 0.5 - vmax = norm.vmin + 0.5 - cax = ScalarMappable( - norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), - cmap=render_params.cmap_params.cmap, - ) + cax = _build_datashader_colorbar_mappable(aggregate_with_reduction, norm, render_params.cmap_params.cmap) elif method == "matplotlib": # update axis limits if plot was empty before (necessary if datashader comes after) @@ -1135,44 +1208,23 @@ def _render_points( ax.set_xbound(extent["x"]) ax.set_ybound(extent["y"]) - if _want_decorations(color_vector, render_params.cmap_params.na_color): - if color_source_vector is None: - palette = ListedColormap(dict.fromkeys(color_vector)) - else: - palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()])) - - wants_colorbar = _should_request_colorbar( - render_params.colorbar, - has_mappable=cax is not None, - is_continuous=col_for_color is not None and color_source_vector is None, - ) - - _ = _decorate_axs( - ax=ax, - cax=cax, - fig_params=fig_params, - adata=adata, - value_to_plot=col_for_color, - color_source_vector=color_source_vector, - color_vector=color_vector, - palette=palette, - alpha=render_params.alpha, - na_color=render_params.cmap_params.na_color, - legend_fontsize=legend_params.legend_fontsize, - legend_fontweight=legend_params.legend_fontweight, - legend_loc=legend_params.legend_loc, - legend_fontoutline=legend_params.legend_fontoutline, - na_in_legend=legend_params.na_in_legend, - colorbar=wants_colorbar and legend_params.colorbar, - colorbar_params=render_params.colorbar_params, - colorbar_requests=colorbar_requests, - colorbar_label=_resolve_colorbar_label( - render_params.colorbar_params, - col_for_color if isinstance(col_for_color, str) else None, - ), - scalebar_dx=scalebar_params.scalebar_dx, - scalebar_units=scalebar_params.scalebar_units, - ) + _decorate_render( + ax=ax, + cax=cax, + fig_params=fig_params, + adata=adata, + col_for_color=col_for_color, + color_source_vector=color_source_vector, + color_vector=color_vector, + palette=None, + alpha=render_params.alpha, + na_color=render_params.cmap_params.na_color, + legend_params=legend_params, + colorbar=render_params.colorbar, + colorbar_params=render_params.colorbar_params, + colorbar_requests=colorbar_requests, + scalebar_params=scalebar_params, + ) def _render_images( @@ -1366,7 +1418,7 @@ def _render_images( seed_colors = [render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels)] channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] - # Stack (n_channels, height, width) → (height*width, n_channels) + # Stack (n_channels, height, width) -> (height*width, n_channels) H, W = next(iter(layers.values())).shape comp_rgb = np.zeros((H, W, 3), dtype=float) From 8c0d2959d68fef408900023b1dbfa5bd2e00c755 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 19 Mar 2026 21:30:08 +0100 Subject: [PATCH 2/6] Extract outline rendering and color key building into shared helpers - _render_ds_outlines: consolidates the 40-line outline aggregation + shading + rendering block (outer + inner) into a single loop - _build_color_key: extracts the identical color key construction from both shapes and points datashader paths The shapes and points datashader pipelines now read nearly identically, differing only in parameter values. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/spatialdata_plot/pl/render.py | 145 +++++++++++++++--------------- 1 file changed, 72 insertions(+), 73 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 5cedb650..d92fa4b9 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -369,6 +369,23 @@ def _datashader_shade_continuous( return ds_result, continuous_nan_shaded, aggregate_with_reduction +def _build_color_key( + transformed_element: Any, + col_for_color: str | None, + color_by_categorical: bool, + color_vector: Any, + na_color_hex: str, +) -> dict[str, str] | None: + """Build a datashader color key mapping categories to hex colors. + + Returns None when not coloring by a categorical column. + """ + if not color_by_categorical or col_for_color is None: + return None + cat_series = _coerce_categorical_source(transformed_element[col_for_color]) + return _build_datashader_color_key(cat_series, color_vector, na_color_hex) + + def _datashader_shade_categorical( agg: Any, color_key: dict[str, str] | None, @@ -392,6 +409,44 @@ def _datashader_shade_categorical( ) +def _render_ds_outlines( + cvs: Any, + transformed_element: Any, + render_params: ShapesRenderParams, + fig_params: FigParams, + ax: matplotlib.axes.SubplotBase, + factor: float, + extent: list[float], +) -> None: + """Aggregate, shade, and render shape outlines (outer and inner) with datashader.""" + ds_lw_factor = fig_params.fig.dpi / 72 + assert len(render_params.outline_alpha) == 2 # noqa: S101 + + for idx, (outline_color_obj, linewidth) in enumerate( + [ + (render_params.outline_params.outer_outline_color, render_params.outline_params.outer_outline_linewidth), + (render_params.outline_params.inner_outline_color, render_params.outline_params.inner_outline_linewidth), + ] + ): + alpha = render_params.outline_alpha[idx] + if alpha <= 0: + continue + agg_outline = cvs.line( + transformed_element, + geometry="geometry", + line_width=linewidth * ds_lw_factor, + ) + if isinstance(outline_color_obj, Color): + shaded = ds.tf.shade( + agg_outline, + cmap=outline_color_obj.get_hex(), + min_alpha=_convert_alpha_to_datashader_range(alpha), + how="linear", + ) + rgba, trans = _create_image_from_datashader_result(shaded, factor, ax) + _ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder, alpha=alpha, extent=extent) + + def _render_datashader_result( ax: matplotlib.axes.SubplotBase, ds_result: Any, @@ -685,31 +740,15 @@ def _render_shapes( "shapes", ) - # render outlines if needed - # outline_linewidth is in points (1pt = 1/72 inch); datashader line_width is in canvas pixels - ds_lw_factor = fig_params.fig.dpi / 72 - assert len(render_params.outline_alpha) == 2 # shut up mypy - if render_params.outline_alpha[0] > 0: - agg_outlines = cvs.line( - transformed_element, - geometry="geometry", - line_width=render_params.outline_params.outer_outline_linewidth * ds_lw_factor, - ) - if render_params.outline_alpha[1] > 0: - agg_inner_outlines = cvs.line( - transformed_element, - geometry="geometry", - line_width=render_params.outline_params.inner_outline_linewidth * ds_lw_factor, - ) - agg, ds_span = _apply_datashader_norm(agg, norm) - - color_key: dict[str, str] | None = None - if color_by_categorical and col_for_color is not None: - cat_series = _coerce_categorical_source(transformed_element[col_for_color]) - color_key = _build_datashader_color_key( - cat_series, color_vector, render_params.cmap_params.na_color.get_hex() - ) + na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) + color_key = _build_color_key( + transformed_element, + col_for_color, + color_by_categorical, + color_vector, + na_color_hex, + ) continuous_nan_shaded = None if color_by_categorical or col_for_color is None: @@ -720,7 +759,6 @@ def _render_shapes( render_params.fill_alpha, ) else: - na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) ds_result, continuous_nan_shaded, aggregate_with_reduction = _datashader_shade_continuous( agg, ds_span, @@ -732,46 +770,7 @@ def _render_shapes( na_color_hex, ) - # shade outlines if needed - if render_params.outline_alpha[0] > 0 and isinstance(render_params.outline_params.outer_outline_color, Color): - outline_color = render_params.outline_params.outer_outline_color.get_hex() - ds_outlines = ds.tf.shade( - agg_outlines, - cmap=outline_color, - min_alpha=_convert_alpha_to_datashader_range(render_params.outline_alpha[0]), - how="linear", - ) - # inner outlines - if render_params.outline_alpha[1] > 0 and isinstance(render_params.outline_params.inner_outline_color, Color): - outline_color = render_params.outline_params.inner_outline_color.get_hex() - ds_inner_outlines = ds.tf.shade( - agg_inner_outlines, - cmap=outline_color, - min_alpha=_convert_alpha_to_datashader_range(render_params.outline_alpha[1]), - how="linear", - ) - - # render outline image(s) - if render_params.outline_alpha[0] > 0: - rgba_image, trans_data = _create_image_from_datashader_result(ds_outlines, factor, ax) - _ax_show_and_transform( - rgba_image, - trans_data, - ax, - zorder=render_params.zorder, - alpha=render_params.outline_alpha[0], - extent=x_ext + y_ext, - ) - if render_params.outline_alpha[1] > 0: - rgba_image, trans_data = _create_image_from_datashader_result(ds_inner_outlines, factor, ax) - _ax_show_and_transform( - rgba_image, - trans_data, - ax, - zorder=render_params.zorder, - alpha=render_params.outline_alpha[1], - extent=x_ext + y_ext, - ) + _render_ds_outlines(cvs, transformed_element, render_params, fig_params, ax, factor, x_ext + y_ext) _cax = _render_datashader_result( ax, @@ -1133,13 +1132,14 @@ def _render_points( ) agg, ds_span = _apply_datashader_norm(agg, norm) - - color_key: dict[str, str] | None = None - if color_by_categorical and col_for_color is not None: - cat_series = _coerce_categorical_source(transformed_element[col_for_color]) - color_key = _build_datashader_color_key( - cat_series, color_vector, render_params.cmap_params.na_color.get_hex() - ) + na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) + color_key = _build_color_key( + transformed_element, + col_for_color, + color_by_categorical, + color_vector, + na_color_hex, + ) if ( color_vector is not None @@ -1159,7 +1159,6 @@ def _render_points( spread_px=px, ) else: - na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) ds_result, continuous_nan_shaded, aggregate_with_reduction = _datashader_shade_continuous( agg, ds_span, From feb25214d5fb298e04a7d1cfe887bef9ea55226d Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 19 Mar 2026 21:41:12 +0100 Subject: [PATCH 3/6] Rename helpers and variables for clarity Functions: - _datashader_aggregate -> _ds_aggregate (consistent prefix) - _datashader_shade_continuous -> _ds_shade_continuous - _datashader_shade_categorical -> _ds_shade_categorical - _apply_datashader_norm -> _apply_ds_norm - _build_datashader_colorbar_mappable -> _build_ds_colorbar - _render_datashader_result -> _render_ds_image - _decorate_render -> _add_legend_and_colorbar Variables: - aggregate_with_reduction -> reduction_bounds - continuous_nan_agg -> nan_agg - continuous_nan_shaded -> nan_shaded - ds_result -> shaded - ds_span -> color_span Co-Authored-By: Claude Opus 4.6 (1M context) --- src/spatialdata_plot/pl/render.py | 142 +++++++++++++++--------------- 1 file changed, 71 insertions(+), 71 deletions(-) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index d92fa4b9..17e3fd67 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -220,33 +220,33 @@ def _should_request_colorbar( return bool(auto_condition) -def _apply_datashader_norm( +def _apply_ds_norm( agg: Any, norm: Normalize, ) -> tuple[Any, list[float] | None]: """Apply norm vmin/vmax to a datashader aggregate. When vmin == vmax, maps the value to 0.5 using an artificial [0, 1] span. - Returns (agg, ds_span) where ds_span is None if no norm was set. + Returns (agg, color_span) where color_span is None if no norm was set. """ if norm.vmin is None and norm.vmax is None: return agg, None norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax - ds_span: list[float] = [norm.vmin, norm.vmax] + color_span: list[float] = [norm.vmin, norm.vmax] if norm.vmin == norm.vmax: - ds_span = [0, 1] + color_span = [0, 1] if norm.clip: agg = (agg - agg) + 0.5 else: agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1) agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) - return agg, ds_span + return agg, color_span -def _build_datashader_colorbar_mappable( - aggregate_with_reduction: tuple[Any, Any] | None, +def _build_ds_colorbar( + reduction_bounds: tuple[Any, Any] | None, norm: Normalize, cmap: Any, ) -> ScalarMappable | None: @@ -254,10 +254,10 @@ def _build_datashader_colorbar_mappable( Returns None if there is no continuous reduction. """ - if aggregate_with_reduction is None: + if reduction_bounds is None: return None - vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin - vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax + vmin = reduction_bounds[0].values if norm.vmin is None else norm.vmin + vmax = reduction_bounds[1].values if norm.vmax is None else norm.vmax if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: assert norm.vmin is not None assert norm.vmax is not None @@ -269,7 +269,7 @@ def _build_datashader_colorbar_mappable( ) -def _datashader_aggregate( +def _ds_aggregate( cvs: Any, transformed_element: Any, col_for_color: str | None, @@ -283,10 +283,10 @@ def _datashader_aggregate( Dispatches between categorical (ds.by), continuous (reduction function), and no-color (ds.count) aggregation modes. - Returns (agg, aggregate_with_reduction, continuous_nan_agg). + Returns (agg, reduction_bounds, nan_agg). """ - aggregate_with_reduction = None - continuous_nan_agg = None + reduction_bounds = None + nan_agg = None def _agg_call(element: Any, agg_func: Any) -> Any: if geom_type == "shapes": @@ -304,69 +304,69 @@ def _agg_call(element: Any, agg_func: Any) -> Any: "very close to the matplotlib result." ) agg = _datashader_aggregate_with_function(ds_reduction, cvs, transformed_element, col_for_color, geom_type) - aggregate_with_reduction = (agg.min(), agg.max()) + reduction_bounds = (agg.min(), agg.max()) nan_elements = transformed_element[transformed_element[col_for_color].isnull()] if len(nan_elements) > 0: - continuous_nan_agg = _datashader_aggregate_with_function("any", cvs, nan_elements, None, geom_type) + nan_agg = _datashader_aggregate_with_function("any", cvs, nan_elements, None, geom_type) else: agg = _agg_call(transformed_element, ds.count()) - return agg, aggregate_with_reduction, continuous_nan_agg + return agg, reduction_bounds, nan_agg -def _datashader_shade_continuous( +def _ds_shade_continuous( agg: Any, - ds_span: list[float] | None, + color_span: list[float] | None, norm: Normalize, cmap: Any, alpha: float, - aggregate_with_reduction: tuple[Any, Any] | None, - continuous_nan_agg: Any | None, + reduction_bounds: tuple[Any, Any] | None, + nan_agg: Any | None, na_color_hex: str, spread_px: int | None = None, ds_reduction: _DsReduction | None = None, ) -> tuple[Any, Any | None, tuple[Any, Any] | None]: """Shade a continuous datashader aggregate, optionally applying spread and NaN coloring. - Returns (ds_result, continuous_nan_shaded, aggregate_with_reduction). + Returns (shaded, nan_shaded, reduction_bounds). """ if spread_px is not None: spread_how = _datshader_get_how_kw_for_spread(ds_reduction) agg = ds.tf.spread(agg, px=spread_px, how=spread_how) - aggregate_with_reduction = (agg.min(), agg.max()) + reduction_bounds = (agg.min(), agg.max()) ds_cmap = cmap if ( - aggregate_with_reduction is not None - and aggregate_with_reduction[0] == aggregate_with_reduction[1] - and (ds_span is None or ds_span != [0, 1]) + reduction_bounds is not None + and reduction_bounds[0] == reduction_bounds[1] + and (color_span is None or color_span != [0, 1]) ): ds_cmap = matplotlib.colors.to_hex(cmap(0.0), keep_alpha=False) - aggregate_with_reduction = ( - aggregate_with_reduction[0], - aggregate_with_reduction[0] + 1, + reduction_bounds = ( + reduction_bounds[0], + reduction_bounds[0] + 1, ) - ds_result = _datashader_map_aggregate_to_color( + shaded = _datashader_map_aggregate_to_color( agg, cmap=ds_cmap, min_alpha=_convert_alpha_to_datashader_range(alpha), - span=ds_span, + span=color_span, clip=norm.clip, ) - continuous_nan_shaded = None - if continuous_nan_agg is not None: + nan_shaded = None + if nan_agg is not None: shade_kwargs: dict[str, Any] = {"cmap": na_color_hex, "how": "linear"} if spread_px is not None: - continuous_nan_agg = ds.tf.spread(continuous_nan_agg, px=spread_px, how="max") + nan_agg = ds.tf.spread(nan_agg, px=spread_px, how="max") else: # only shapes (no spread) pass min_alpha for NaN shading shade_kwargs["min_alpha"] = _convert_alpha_to_datashader_range(alpha) - continuous_nan_shaded = ds.tf.shade(continuous_nan_agg, **shade_kwargs) + nan_shaded = ds.tf.shade(nan_agg, **shade_kwargs) - return ds_result, continuous_nan_shaded, aggregate_with_reduction + return shaded, nan_shaded, reduction_bounds def _build_color_key( @@ -386,7 +386,7 @@ def _build_color_key( return _build_datashader_color_key(cat_series, color_vector, na_color_hex) -def _datashader_shade_categorical( +def _ds_shade_categorical( agg: Any, color_key: dict[str, str] | None, color_vector: Any, @@ -447,20 +447,20 @@ def _render_ds_outlines( _ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder, alpha=alpha, extent=extent) -def _render_datashader_result( +def _render_ds_image( ax: matplotlib.axes.SubplotBase, - ds_result: Any, + shaded: Any, factor: float, zorder: int, alpha: float, extent: list[float] | None, - continuous_nan_result: Any | None = None, + nan_result: Any | None = None, ) -> Any: - """Render a shaded datashader result onto matplotlib axes, with optional NaN overlay.""" - if continuous_nan_result is not None: - rgba_nan, trans_nan = _create_image_from_datashader_result(continuous_nan_result, factor, ax) + """Render a shaded datashader image onto matplotlib axes, with optional NaN overlay.""" + if nan_result is not None: + rgba_nan, trans_nan = _create_image_from_datashader_result(nan_result, factor, ax) _ax_show_and_transform(rgba_nan, trans_nan, ax, zorder=zorder, alpha=alpha, extent=extent) - rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax) + rgba_image, trans_data = _create_image_from_datashader_result(shaded, factor, ax) return _ax_show_and_transform(rgba_image, trans_data, ax, zorder=zorder, alpha=alpha, extent=extent) @@ -474,7 +474,7 @@ def _make_palette( return ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()])) -def _decorate_render( +def _add_legend_and_colorbar( ax: matplotlib.axes.SubplotBase, cax: ScalarMappable | None, fig_params: FigParams, @@ -730,7 +730,7 @@ def _render_shapes( cat_series = cat_series.astype("category") transformed_element[col_for_color] = cat_series - agg, aggregate_with_reduction, continuous_nan_agg = _datashader_aggregate( + agg, reduction_bounds, nan_agg = _ds_aggregate( cvs, transformed_element, col_for_color, @@ -740,7 +740,7 @@ def _render_shapes( "shapes", ) - agg, ds_span = _apply_datashader_norm(agg, norm) + agg, color_span = _apply_ds_norm(agg, norm) na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) color_key = _build_color_key( transformed_element, @@ -750,39 +750,39 @@ def _render_shapes( na_color_hex, ) - continuous_nan_shaded = None + nan_shaded = None if color_by_categorical or col_for_color is None: - ds_result = _datashader_shade_categorical( + shaded = _ds_shade_categorical( agg, color_key, color_vector, render_params.fill_alpha, ) else: - ds_result, continuous_nan_shaded, aggregate_with_reduction = _datashader_shade_continuous( + shaded, nan_shaded, reduction_bounds = _ds_shade_continuous( agg, - ds_span, + color_span, norm, render_params.cmap_params.cmap, render_params.fill_alpha, - aggregate_with_reduction, - continuous_nan_agg, + reduction_bounds, + nan_agg, na_color_hex, ) _render_ds_outlines(cvs, transformed_element, render_params, fig_params, ax, factor, x_ext + y_ext) - _cax = _render_datashader_result( + _cax = _render_ds_image( ax, - ds_result, + shaded, factor, render_params.zorder, render_params.fill_alpha, x_ext + y_ext, - continuous_nan_result=continuous_nan_shaded, + nan_result=nan_shaded, ) - cax = _build_datashader_colorbar_mappable(aggregate_with_reduction, norm, render_params.cmap_params.cmap) + cax = _build_ds_colorbar(reduction_bounds, norm, render_params.cmap_params.cmap) elif method == "matplotlib": # render outlines separately to ensure they are always underneath the shape @@ -867,7 +867,7 @@ def _render_shapes( vmax = 1.0 _cax.set_clim(vmin=vmin, vmax=vmax) - _decorate_render( + _add_legend_and_colorbar( ax=ax, cax=cax, fig_params=fig_params, @@ -1121,7 +1121,7 @@ def _render_points( if color_by_categorical and not isinstance(color_dtype, pd.CategoricalDtype): transformed_element[col_for_color] = transformed_element[col_for_color].astype("category") - agg, aggregate_with_reduction, continuous_nan_agg = _datashader_aggregate( + agg, reduction_bounds, nan_agg = _ds_aggregate( cvs, transformed_element, col_for_color, @@ -1131,7 +1131,7 @@ def _render_points( "points", ) - agg, ds_span = _apply_datashader_norm(agg, norm) + agg, color_span = _apply_ds_norm(agg, norm) na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) color_key = _build_color_key( transformed_element, @@ -1149,9 +1149,9 @@ def _render_points( ): color_vector = np.asarray([_hex_no_alpha(x) for x in color_vector]) - continuous_nan_shaded = None + nan_shaded = None if color_by_categorical or col_for_color is None: - ds_result = _datashader_shade_categorical( + shaded = _ds_shade_categorical( agg, color_key, color_vector, @@ -1159,30 +1159,30 @@ def _render_points( spread_px=px, ) else: - ds_result, continuous_nan_shaded, aggregate_with_reduction = _datashader_shade_continuous( + shaded, nan_shaded, reduction_bounds = _ds_shade_continuous( agg, - ds_span, + color_span, norm, render_params.cmap_params.cmap, render_params.alpha, - aggregate_with_reduction, - continuous_nan_agg, + reduction_bounds, + nan_agg, na_color_hex, spread_px=px, ds_reduction=render_params.ds_reduction, ) - _render_datashader_result( + _render_ds_image( ax, - ds_result, + shaded, factor, render_params.zorder, render_params.alpha, x_ext + y_ext, - continuous_nan_result=continuous_nan_shaded, + nan_result=nan_shaded, ) - cax = _build_datashader_colorbar_mappable(aggregate_with_reduction, norm, render_params.cmap_params.cmap) + cax = _build_ds_colorbar(reduction_bounds, norm, render_params.cmap_params.cmap) elif method == "matplotlib": # update axis limits if plot was empty before (necessary if datashader comes after) @@ -1207,7 +1207,7 @@ def _render_points( ax.set_xbound(extent["x"]) ax.set_ybound(extent["y"]) - _decorate_render( + _add_legend_and_colorbar( ax=ax, cax=cax, fig_params=fig_params, From 4a7578b9b0de22e0b4c106853c9ec86029c35adf Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 19 Mar 2026 22:33:06 +0100 Subject: [PATCH 4/6] Move datashader helpers to _datashader.py All datashader-specific aggregation, shading, and rendering helpers now live in pl/_datashader.py. render.py imports them and focuses on the element-type-specific orchestration logic. Moved: - _ds_aggregate, _apply_ds_norm, _build_color_key - _ds_shade_continuous, _ds_shade_categorical - _render_ds_image, _render_ds_outlines, _build_ds_colorbar - _coerce_categorical_source, _build_datashader_color_key - _inject_ds_nan_sentinel, _DS_NAN_CATEGORY, _DsReduction Co-Authored-By: Claude Opus 4.6 (1M context) --- src/spatialdata_plot/pl/_datashader.py | 342 +++++++++++++++++++++++++ src/spatialdata_plot/pl/render.py | 313 +--------------------- 2 files changed, 353 insertions(+), 302 deletions(-) create mode 100644 src/spatialdata_plot/pl/_datashader.py diff --git a/src/spatialdata_plot/pl/_datashader.py b/src/spatialdata_plot/pl/_datashader.py new file mode 100644 index 00000000..3b50dd52 --- /dev/null +++ b/src/spatialdata_plot/pl/_datashader.py @@ -0,0 +1,342 @@ +"""Datashader aggregation, shading, and rendering helpers. + +Shared by ``_render_shapes`` and ``_render_points`` in ``render.py``. +""" + +from __future__ import annotations + +from typing import Any, Literal + +import dask.dataframe as dd +import datashader as ds +import matplotlib +import matplotlib.colors +import numpy as np +import pandas as pd +from matplotlib.cm import ScalarMappable +from matplotlib.colors import Normalize + +from spatialdata_plot._logging import logger +from spatialdata_plot.pl.render_params import Color, FigParams, ShapesRenderParams +from spatialdata_plot.pl.utils import ( + _ax_show_and_transform, + _convert_alpha_to_datashader_range, + _create_image_from_datashader_result, + _datashader_aggregate_with_function, + _datashader_map_aggregate_to_color, + _datshader_get_how_kw_for_spread, + _hex_no_alpha, +) + +# --------------------------------------------------------------------------- +# Type aliases and constants +# --------------------------------------------------------------------------- + +_DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] + +# Sentinel category name used in datashader categorical paths to represent +# missing (NaN) values. Must not collide with realistic user category names. +_DS_NAN_CATEGORY = "ds_nan" + +# --------------------------------------------------------------------------- +# Low-level helpers +# --------------------------------------------------------------------------- + + +def _coerce_categorical_source(series: pd.Series | dd.Series) -> pd.Categorical: + """Return a ``pd.Categorical`` from a pandas or dask Series.""" + if isinstance(series, dd.Series): + if isinstance(series.dtype, pd.CategoricalDtype) and getattr(series.cat, "known", True) is False: + series = series.cat.as_known() + series = series.compute() + if isinstance(series.dtype, pd.CategoricalDtype): + return series.array + return pd.Categorical(series) + + +def _build_datashader_color_key( + cat_series: pd.Categorical, + color_vector: Any, + na_color_hex: str, +) -> dict[str, str]: + """Build a datashader ``color_key`` dict from a categorical series and its color vector.""" + na_hex = _hex_no_alpha(na_color_hex) if na_color_hex.startswith("#") else na_color_hex + colors_arr = np.asarray(color_vector, dtype=object) + first_color: dict[str, str] = {} + for code, color in zip(cat_series.codes, colors_arr, strict=False): + if code < 0: + continue + cat_name = str(cat_series.categories[code]) + if cat_name not in first_color: + first_color[cat_name] = _hex_no_alpha(color) if isinstance(color, str) and color.startswith("#") else color + return {str(c): first_color.get(str(c), na_hex) for c in cat_series.categories} + + +def _inject_ds_nan_sentinel(series: pd.Series, sentinel: str = _DS_NAN_CATEGORY) -> pd.Series: + """Add a sentinel category for NaN values in a categorical series. + + Safely handles series that are not yet categorical, dask-backed + categoricals that need ``as_known()``, and series that already + contain the sentinel. + """ + if not isinstance(series.dtype, pd.CategoricalDtype): + series = series.astype("category") + if hasattr(series.cat, "as_known"): + series = series.cat.as_known() + if sentinel not in series.cat.categories: + series = series.cat.add_categories(sentinel) + return series.fillna(sentinel) + + +# --------------------------------------------------------------------------- +# Pipeline helpers (aggregate -> norm -> shade -> render) +# --------------------------------------------------------------------------- + + +def _ds_aggregate( + cvs: Any, + transformed_element: Any, + col_for_color: str | None, + color_by_categorical: bool, + ds_reduction: _DsReduction | None, + default_reduction: _DsReduction, + geom_type: Literal["points", "shapes"], +) -> tuple[Any, tuple[Any, Any] | None, Any | None]: + """Aggregate spatial elements with datashader. + + Dispatches between categorical (ds.by), continuous (reduction function), + and no-color (ds.count) aggregation modes. + + Returns (agg, reduction_bounds, nan_agg). + """ + reduction_bounds = None + nan_agg = None + + def _agg_call(element: Any, agg_func: Any) -> Any: + if geom_type == "shapes": + return cvs.polygons(element, geometry="geometry", agg=agg_func) + return cvs.points(element, "x", "y", agg=agg_func) + + if col_for_color is not None: + if color_by_categorical: + transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color]) + agg = _agg_call(transformed_element, ds.by(col_for_color, ds.count())) + else: + reduction_name = ds_reduction if ds_reduction is not None else default_reduction + logger.info( + f'Using the datashader reduction "{reduction_name}". "max" will give an output ' + "very close to the matplotlib result." + ) + agg = _datashader_aggregate_with_function(ds_reduction, cvs, transformed_element, col_for_color, geom_type) + reduction_bounds = (agg.min(), agg.max()) + + nan_elements = transformed_element[transformed_element[col_for_color].isnull()] + if len(nan_elements) > 0: + nan_agg = _datashader_aggregate_with_function("any", cvs, nan_elements, None, geom_type) + else: + agg = _agg_call(transformed_element, ds.count()) + + return agg, reduction_bounds, nan_agg + + +def _apply_ds_norm( + agg: Any, + norm: Normalize, +) -> tuple[Any, list[float] | None]: + """Apply norm vmin/vmax to a datashader aggregate. + + When vmin == vmax, maps the value to 0.5 using an artificial [0, 1] span. + Returns (agg, color_span) where color_span is None if no norm was set. + """ + if norm.vmin is None and norm.vmax is None: + return agg, None + norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin + norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax + color_span: list[float] = [norm.vmin, norm.vmax] + if norm.vmin == norm.vmax: + color_span = [0, 1] + if norm.clip: + agg = (agg - agg) + 0.5 + else: + agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1) + agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) + agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) + return agg, color_span + + +def _build_color_key( + transformed_element: Any, + col_for_color: str | None, + color_by_categorical: bool, + color_vector: Any, + na_color_hex: str, +) -> dict[str, str] | None: + """Build a datashader color key mapping categories to hex colors. + + Returns None when not coloring by a categorical column. + """ + if not color_by_categorical or col_for_color is None: + return None + cat_series = _coerce_categorical_source(transformed_element[col_for_color]) + return _build_datashader_color_key(cat_series, color_vector, na_color_hex) + + +def _ds_shade_continuous( + agg: Any, + color_span: list[float] | None, + norm: Normalize, + cmap: Any, + alpha: float, + reduction_bounds: tuple[Any, Any] | None, + nan_agg: Any | None, + na_color_hex: str, + spread_px: int | None = None, + ds_reduction: _DsReduction | None = None, +) -> tuple[Any, Any | None, tuple[Any, Any] | None]: + """Shade a continuous datashader aggregate, optionally applying spread and NaN coloring. + + Returns (shaded, nan_shaded, reduction_bounds). + """ + if spread_px is not None: + spread_how = _datshader_get_how_kw_for_spread(ds_reduction) + agg = ds.tf.spread(agg, px=spread_px, how=spread_how) + reduction_bounds = (agg.min(), agg.max()) + + ds_cmap = cmap + if ( + reduction_bounds is not None + and reduction_bounds[0] == reduction_bounds[1] + and (color_span is None or color_span != [0, 1]) + ): + ds_cmap = matplotlib.colors.to_hex(cmap(0.0), keep_alpha=False) + reduction_bounds = ( + reduction_bounds[0], + reduction_bounds[0] + 1, + ) + + shaded = _datashader_map_aggregate_to_color( + agg, + cmap=ds_cmap, + min_alpha=_convert_alpha_to_datashader_range(alpha), + span=color_span, + clip=norm.clip, + ) + + nan_shaded = None + if nan_agg is not None: + shade_kwargs: dict[str, Any] = {"cmap": na_color_hex, "how": "linear"} + if spread_px is not None: + nan_agg = ds.tf.spread(nan_agg, px=spread_px, how="max") + else: + # only shapes (no spread) pass min_alpha for NaN shading + shade_kwargs["min_alpha"] = _convert_alpha_to_datashader_range(alpha) + nan_shaded = ds.tf.shade(nan_agg, **shade_kwargs) + + return shaded, nan_shaded, reduction_bounds + + +def _ds_shade_categorical( + agg: Any, + color_key: dict[str, str] | None, + color_vector: Any, + alpha: float, + spread_px: int | None = None, +) -> Any: + """Shade a categorical or no-color datashader aggregate.""" + ds_cmap = None + if color_vector is not None: + ds_cmap = color_vector[0] + if isinstance(ds_cmap, str) and ds_cmap[0] == "#": + ds_cmap = _hex_no_alpha(ds_cmap) + + agg_to_shade = ds.tf.spread(agg, px=spread_px) if spread_px is not None else agg + return _datashader_map_aggregate_to_color( + agg_to_shade, + cmap=ds_cmap, + color_key=color_key, + min_alpha=_convert_alpha_to_datashader_range(alpha), + ) + + +# --------------------------------------------------------------------------- +# Image rendering +# --------------------------------------------------------------------------- + + +def _render_ds_image( + ax: matplotlib.axes.SubplotBase, + shaded: Any, + factor: float, + zorder: int, + alpha: float, + extent: list[float] | None, + nan_result: Any | None = None, +) -> Any: + """Render a shaded datashader image onto matplotlib axes, with optional NaN overlay.""" + if nan_result is not None: + rgba_nan, trans_nan = _create_image_from_datashader_result(nan_result, factor, ax) + _ax_show_and_transform(rgba_nan, trans_nan, ax, zorder=zorder, alpha=alpha, extent=extent) + rgba_image, trans_data = _create_image_from_datashader_result(shaded, factor, ax) + return _ax_show_and_transform(rgba_image, trans_data, ax, zorder=zorder, alpha=alpha, extent=extent) + + +def _render_ds_outlines( + cvs: Any, + transformed_element: Any, + render_params: ShapesRenderParams, + fig_params: FigParams, + ax: matplotlib.axes.SubplotBase, + factor: float, + extent: list[float], +) -> None: + """Aggregate, shade, and render shape outlines (outer and inner) with datashader.""" + ds_lw_factor = fig_params.fig.dpi / 72 + assert len(render_params.outline_alpha) == 2 # noqa: S101 + + for idx, (outline_color_obj, linewidth) in enumerate( + [ + (render_params.outline_params.outer_outline_color, render_params.outline_params.outer_outline_linewidth), + (render_params.outline_params.inner_outline_color, render_params.outline_params.inner_outline_linewidth), + ] + ): + alpha = render_params.outline_alpha[idx] + if alpha <= 0: + continue + agg_outline = cvs.line( + transformed_element, + geometry="geometry", + line_width=linewidth * ds_lw_factor, + ) + if isinstance(outline_color_obj, Color): + shaded = ds.tf.shade( + agg_outline, + cmap=outline_color_obj.get_hex(), + min_alpha=_convert_alpha_to_datashader_range(alpha), + how="linear", + ) + rgba, trans = _create_image_from_datashader_result(shaded, factor, ax) + _ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder, alpha=alpha, extent=extent) + + +def _build_ds_colorbar( + reduction_bounds: tuple[Any, Any] | None, + norm: Normalize, + cmap: Any, +) -> ScalarMappable | None: + """Create a ScalarMappable for the colorbar from datashader reduction bounds. + + Returns None if there is no continuous reduction. + """ + if reduction_bounds is None: + return None + vmin = reduction_bounds[0].values if norm.vmin is None else norm.vmin + vmax = reduction_bounds[1].values if norm.vmax is None else norm.vmax + if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: + assert norm.vmin is not None + assert norm.vmax is not None + vmin = norm.vmin - 0.5 + vmax = norm.vmin + 0.5 + return ScalarMappable( + norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), + cmap=cmap, + ) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 17e3fd67..488db5c9 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -2,7 +2,7 @@ from collections import abc from copy import copy -from typing import Any, Literal +from typing import Any import dask import dask.dataframe as dd @@ -27,6 +27,16 @@ from xarray import DataTree from spatialdata_plot._logging import _log_context, logger +from spatialdata_plot.pl._datashader import ( + _apply_ds_norm, + _build_color_key, + _build_ds_colorbar, + _ds_aggregate, + _ds_shade_categorical, + _ds_shade_continuous, + _render_ds_image, + _render_ds_outlines, +) from spatialdata_plot.pl.render_params import ( Color, ColorbarSpec, @@ -40,12 +50,7 @@ ) from spatialdata_plot.pl.utils import ( _ax_show_and_transform, - _convert_alpha_to_datashader_range, _convert_shapes, - _create_image_from_datashader_result, - _datashader_aggregate_with_function, - _datashader_map_aggregate_to_color, - _datshader_get_how_kw_for_spread, _decorate_axs, _get_collection_shape, _get_colors_for_categorical_obs, @@ -64,58 +69,6 @@ _Normalize = Normalize | abc.Sequence[Normalize] -# Sentinel category name used in datashader categorical paths to represent -# missing (NaN) values. Must not collide with realistic user category names. -_DS_NAN_CATEGORY = "ds_nan" - -_DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] - - -def _coerce_categorical_source(series: pd.Series | dd.Series) -> pd.Categorical: - """Return a ``pd.Categorical`` from a pandas or dask Series.""" - if isinstance(series, dd.Series): - if isinstance(series.dtype, pd.CategoricalDtype) and getattr(series.cat, "known", True) is False: - series = series.cat.as_known() - series = series.compute() - if isinstance(series.dtype, pd.CategoricalDtype): - return series.array - return pd.Categorical(series) - - -def _build_datashader_color_key( - cat_series: pd.Categorical, - color_vector: Any, - na_color_hex: str, -) -> dict[str, str]: - """Build a datashader ``color_key`` dict from a categorical series and its color vector.""" - na_hex = _hex_no_alpha(na_color_hex) if na_color_hex.startswith("#") else na_color_hex - # Map each category to its first-occurrence color via codes - colors_arr = np.asarray(color_vector, dtype=object) - first_color: dict[str, str] = {} - for code, color in zip(cat_series.codes, colors_arr, strict=False): - if code < 0: - continue - cat_name = str(cat_series.categories[code]) - if cat_name not in first_color: - first_color[cat_name] = _hex_no_alpha(color) if isinstance(color, str) and color.startswith("#") else color - return {str(c): first_color.get(str(c), na_hex) for c in cat_series.categories} - - -def _inject_ds_nan_sentinel(series: pd.Series, sentinel: str = _DS_NAN_CATEGORY) -> pd.Series: - """Add a sentinel category for NaN values in a categorical series. - - Safely handles series that are not yet categorical, dask-backed - categoricals that need ``as_known()``, and series that already - contain the sentinel. - """ - if not isinstance(series.dtype, pd.CategoricalDtype): - series = series.astype("category") - if hasattr(series.cat, "as_known"): - series = series.cat.as_known() - if sentinel not in series.cat.categories: - series = series.cat.add_categories(sentinel) - return series.fillna(sentinel) - def _want_decorations(color_vector: Any, na_color: Color) -> bool: """Return whether legend/colorbar decorations should be shown. @@ -220,250 +173,6 @@ def _should_request_colorbar( return bool(auto_condition) -def _apply_ds_norm( - agg: Any, - norm: Normalize, -) -> tuple[Any, list[float] | None]: - """Apply norm vmin/vmax to a datashader aggregate. - - When vmin == vmax, maps the value to 0.5 using an artificial [0, 1] span. - Returns (agg, color_span) where color_span is None if no norm was set. - """ - if norm.vmin is None and norm.vmax is None: - return agg, None - norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin - norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax - color_span: list[float] = [norm.vmin, norm.vmax] - if norm.vmin == norm.vmax: - color_span = [0, 1] - if norm.clip: - agg = (agg - agg) + 0.5 - else: - agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1) - agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) - agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) - return agg, color_span - - -def _build_ds_colorbar( - reduction_bounds: tuple[Any, Any] | None, - norm: Normalize, - cmap: Any, -) -> ScalarMappable | None: - """Create a ScalarMappable for the colorbar from datashader reduction bounds. - - Returns None if there is no continuous reduction. - """ - if reduction_bounds is None: - return None - vmin = reduction_bounds[0].values if norm.vmin is None else norm.vmin - vmax = reduction_bounds[1].values if norm.vmax is None else norm.vmax - if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: - assert norm.vmin is not None - assert norm.vmax is not None - vmin = norm.vmin - 0.5 - vmax = norm.vmin + 0.5 - return ScalarMappable( - norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), - cmap=cmap, - ) - - -def _ds_aggregate( - cvs: Any, - transformed_element: Any, - col_for_color: str | None, - color_by_categorical: bool, - ds_reduction: _DsReduction | None, - default_reduction: _DsReduction, - geom_type: Literal["points", "shapes"], -) -> tuple[Any, tuple[Any, Any] | None, Any | None]: - """Aggregate spatial elements with datashader. - - Dispatches between categorical (ds.by), continuous (reduction function), - and no-color (ds.count) aggregation modes. - - Returns (agg, reduction_bounds, nan_agg). - """ - reduction_bounds = None - nan_agg = None - - def _agg_call(element: Any, agg_func: Any) -> Any: - if geom_type == "shapes": - return cvs.polygons(element, geometry="geometry", agg=agg_func) - return cvs.points(element, "x", "y", agg=agg_func) - - if col_for_color is not None: - if color_by_categorical: - transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color]) - agg = _agg_call(transformed_element, ds.by(col_for_color, ds.count())) - else: - reduction_name = ds_reduction if ds_reduction is not None else default_reduction - logger.info( - f'Using the datashader reduction "{reduction_name}". "max" will give an output ' - "very close to the matplotlib result." - ) - agg = _datashader_aggregate_with_function(ds_reduction, cvs, transformed_element, col_for_color, geom_type) - reduction_bounds = (agg.min(), agg.max()) - - nan_elements = transformed_element[transformed_element[col_for_color].isnull()] - if len(nan_elements) > 0: - nan_agg = _datashader_aggregate_with_function("any", cvs, nan_elements, None, geom_type) - else: - agg = _agg_call(transformed_element, ds.count()) - - return agg, reduction_bounds, nan_agg - - -def _ds_shade_continuous( - agg: Any, - color_span: list[float] | None, - norm: Normalize, - cmap: Any, - alpha: float, - reduction_bounds: tuple[Any, Any] | None, - nan_agg: Any | None, - na_color_hex: str, - spread_px: int | None = None, - ds_reduction: _DsReduction | None = None, -) -> tuple[Any, Any | None, tuple[Any, Any] | None]: - """Shade a continuous datashader aggregate, optionally applying spread and NaN coloring. - - Returns (shaded, nan_shaded, reduction_bounds). - """ - if spread_px is not None: - spread_how = _datshader_get_how_kw_for_spread(ds_reduction) - agg = ds.tf.spread(agg, px=spread_px, how=spread_how) - reduction_bounds = (agg.min(), agg.max()) - - ds_cmap = cmap - if ( - reduction_bounds is not None - and reduction_bounds[0] == reduction_bounds[1] - and (color_span is None or color_span != [0, 1]) - ): - ds_cmap = matplotlib.colors.to_hex(cmap(0.0), keep_alpha=False) - reduction_bounds = ( - reduction_bounds[0], - reduction_bounds[0] + 1, - ) - - shaded = _datashader_map_aggregate_to_color( - agg, - cmap=ds_cmap, - min_alpha=_convert_alpha_to_datashader_range(alpha), - span=color_span, - clip=norm.clip, - ) - - nan_shaded = None - if nan_agg is not None: - shade_kwargs: dict[str, Any] = {"cmap": na_color_hex, "how": "linear"} - if spread_px is not None: - nan_agg = ds.tf.spread(nan_agg, px=spread_px, how="max") - else: - # only shapes (no spread) pass min_alpha for NaN shading - shade_kwargs["min_alpha"] = _convert_alpha_to_datashader_range(alpha) - nan_shaded = ds.tf.shade(nan_agg, **shade_kwargs) - - return shaded, nan_shaded, reduction_bounds - - -def _build_color_key( - transformed_element: Any, - col_for_color: str | None, - color_by_categorical: bool, - color_vector: Any, - na_color_hex: str, -) -> dict[str, str] | None: - """Build a datashader color key mapping categories to hex colors. - - Returns None when not coloring by a categorical column. - """ - if not color_by_categorical or col_for_color is None: - return None - cat_series = _coerce_categorical_source(transformed_element[col_for_color]) - return _build_datashader_color_key(cat_series, color_vector, na_color_hex) - - -def _ds_shade_categorical( - agg: Any, - color_key: dict[str, str] | None, - color_vector: Any, - alpha: float, - spread_px: int | None = None, -) -> Any: - """Shade a categorical or no-color datashader aggregate.""" - ds_cmap = None - if color_vector is not None: - ds_cmap = color_vector[0] - if isinstance(ds_cmap, str) and ds_cmap[0] == "#": - ds_cmap = _hex_no_alpha(ds_cmap) - - agg_to_shade = ds.tf.spread(agg, px=spread_px) if spread_px is not None else agg - return _datashader_map_aggregate_to_color( - agg_to_shade, - cmap=ds_cmap, - color_key=color_key, - min_alpha=_convert_alpha_to_datashader_range(alpha), - ) - - -def _render_ds_outlines( - cvs: Any, - transformed_element: Any, - render_params: ShapesRenderParams, - fig_params: FigParams, - ax: matplotlib.axes.SubplotBase, - factor: float, - extent: list[float], -) -> None: - """Aggregate, shade, and render shape outlines (outer and inner) with datashader.""" - ds_lw_factor = fig_params.fig.dpi / 72 - assert len(render_params.outline_alpha) == 2 # noqa: S101 - - for idx, (outline_color_obj, linewidth) in enumerate( - [ - (render_params.outline_params.outer_outline_color, render_params.outline_params.outer_outline_linewidth), - (render_params.outline_params.inner_outline_color, render_params.outline_params.inner_outline_linewidth), - ] - ): - alpha = render_params.outline_alpha[idx] - if alpha <= 0: - continue - agg_outline = cvs.line( - transformed_element, - geometry="geometry", - line_width=linewidth * ds_lw_factor, - ) - if isinstance(outline_color_obj, Color): - shaded = ds.tf.shade( - agg_outline, - cmap=outline_color_obj.get_hex(), - min_alpha=_convert_alpha_to_datashader_range(alpha), - how="linear", - ) - rgba, trans = _create_image_from_datashader_result(shaded, factor, ax) - _ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder, alpha=alpha, extent=extent) - - -def _render_ds_image( - ax: matplotlib.axes.SubplotBase, - shaded: Any, - factor: float, - zorder: int, - alpha: float, - extent: list[float] | None, - nan_result: Any | None = None, -) -> Any: - """Render a shaded datashader image onto matplotlib axes, with optional NaN overlay.""" - if nan_result is not None: - rgba_nan, trans_nan = _create_image_from_datashader_result(nan_result, factor, ax) - _ax_show_and_transform(rgba_nan, trans_nan, ax, zorder=zorder, alpha=alpha, extent=extent) - rgba_image, trans_data = _create_image_from_datashader_result(shaded, factor, ax) - return _ax_show_and_transform(rgba_image, trans_data, ax, zorder=zorder, alpha=alpha, extent=extent) - - def _make_palette( color_source_vector: pd.Series | None, color_vector: Any, From 6985eff8add01d1bb87efeaf6d0504f4dacdbffc Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Fri, 20 Mar 2026 01:25:04 +0100 Subject: [PATCH 5/6] Add compositing plan documenting investigation and design decisions Co-Authored-By: Claude Opus 4.6 --- plans/unify-additive-blending.md | 131 +++++++++++++++++++++++++++++++ 1 file changed, 131 insertions(+) create mode 100644 plans/unify-additive-blending.md diff --git a/plans/unify-additive-blending.md b/plans/unify-additive-blending.md new file mode 100644 index 00000000..2ed12f0a --- /dev/null +++ b/plans/unify-additive-blending.md @@ -0,0 +1,131 @@ +# Plan: Unify Multi-Channel Image Compositing to Additive Blending + +## Goal + +Replace the inconsistent compositing formulas in `_render_images` with a single shared helper that implements standard additive blending with clamping (matching Napari/ImageJ/FIJI). + +## Current State (investigation findings) + +### Compositing paths in `render.py` `_render_images` + +| Path | Lines | Condition | Formula | Alpha | Bug | +|------|-------|-----------|---------|-------|-----| +| **1** | 1232-1254 | 1 channel | Direct (no compositing) | Baked in cmap | None | +| **2A-RGB** | 1282-1283 | 3ch, default cmap | `np.stack(..., axis=-1)` | Via imshow | None | +| **2A-cmap** | 1284-1293 | 3ch, user cmap | `.sum(0) / n_channels` | Via imshow | **Averaging** | +| **2B** 2ch | 1314-1321 | 2ch, no palette | `.sum(0)`, no clip | Via imshow | **No clip** | +| **2B** 3ch | 1322-1329 | 3ch, no palette | `.sum(0)`, no clip | Via imshow | **No clip** | +| **2B** 4+ch | 1330-1363 | 4+ch, no palette | additive + `np.clip` | **Baked + imshow** | **Double alpha** | +| **2C** | 1374-1380 | palette | `.sum(0)`, no clip | Via imshow | **No clip + filter bug** | +| **2D** | 1390-1398 | multiple cmaps | `.sum(0) / n_channels` | Via imshow | **Averaging** | + +### Key infrastructure details + +- **`layers` dict** (L1269-1288): Keys are channel identifiers (str or int). Values are normalized arrays (norm applied before compositing). Constructed correctly. +- **`_get_linear_colormap`** (utils.py L1766-1767): `LinearSegmentedColormap.from_list(c, ["k", c], N=256)` — black-to-color LUTs. Returns RGBA (4 channels) when called. Correct for additive blending. +- **`_ax_show_and_transform`** (utils.py L2893-2936): When `cmap is None` and `alpha is not None`, passes `alpha` to `ax.imshow()`. When cmap is present, does NOT pass alpha. This means for composite RGB arrays, alpha is applied by imshow. +- **Palette validation**: `_type_check_params` already ensures palette contains only strings. The `if isinstance(c, str)` filter at L1378 is redundant and risks index misalignment. + +### Alpha flow summary + +All paths except 2B-4+ch correctly apply alpha once (via `_ax_show_and_transform` → `ax.imshow(alpha=...)`). + +Path 2B-4+ch double-applies alpha: +1. L1356-1357: `rgba[..., 3] = render_params.alpha` then `comp_rgb += rgba[..., :3] * rgba[..., 3][..., None]` +2. L1369: passes `render_params.alpha` to `_ax_show_and_transform` → `ax.imshow(alpha=...)` + +### Tests affected + +Tests that will produce different (brighter) output when switching from averaging to additive: + +| Baseline | Test | Path | +|----------|------|------| +| `Images_can_pass_str_cmap.png` | `test_plot_can_pass_str_cmap` | 2A-cmap | +| `Images_can_pass_cmap.png` | `test_plot_can_pass_cmap` | 2A-cmap | +| `Images_can_render_multiscale_image_with_custom_cmap.png` | `test_plot_can_render_multiscale_image_with_custom_cmap` | 2A-cmap | +| `Images_can_pass_str_cmap_list.png` | `test_plot_can_pass_str_cmap_list` | 2D | +| `Images_can_pass_cmap_list.png` | `test_plot_can_pass_cmap_list` | 2D | +| `Images_can_pass_cmap_to_each_channel.png` | `test_plot_can_pass_cmap_to_each_channel` | 2D | + +Tests using 2B (2-3ch) or 2C may also shift slightly due to clip being added, but only if values currently exceed [0,1]. + +### Baseline regeneration + +Per `docs/contributing.md`: baselines must be generated on Ubuntu in GitHub Actions (not locally). Push the change, let CI run, download `visual_test_results_*` artifact, review manually, copy to `tests/_images/`. + +## Steps + +### Step 1: Add `_additive_blend` helper + +**File**: `src/spatialdata_plot/pl/render.py` + +Place as a module-level function before `_render_images` (near other helpers). + +```python +def _additive_blend( + layers: dict, + channels: list, + channel_cmaps: list, +) -> np.ndarray: + """Additive blend of colormapped channels, matching Napari's additive mode. + + Each channel is mapped through its colormap, the RGB components are summed, + and the result is clamped to [0, 1]. + """ + H, W = next(iter(layers.values())).shape + composite = np.zeros((H, W, 3), dtype=float) + for ch_idx, ch in enumerate(channels): + rgba = channel_cmaps[ch_idx](np.asarray(layers[ch])) + composite += rgba[..., :3] + return np.clip(composite, 0, 1) +``` + +No alpha parameter — alpha is handled uniformly by `_ax_show_and_transform` → `ax.imshow()`. + +### Step 2: Update path 2A-cmap (L1284-1293) + +Replace averaging with `_additive_blend`. Keep the warning about white cmaps. + +### Step 3: Update path 2B 2ch (L1314-1321) + +Replace `.sum(0)[:, :, :3]` with `_additive_blend(layers, channels, channel_cmaps)`. + +### Step 4: Update path 2B 3ch (L1322-1329) + +Same as step 3. + +### Step 5: Simplify path 2B 4+ch (L1330-1363) + +Replace inline loop with `_additive_blend(layers, channels, channel_cmaps)`. This removes the double-alpha bug (no more alpha bake-in; alpha is only applied via imshow). + +### Step 6: Update path 2C palette (L1374-1380) + +Replace `.sum(0)[:, :, :3]` with `_additive_blend`. Remove the `if isinstance(c, str)` filter. + +### Step 7: Update path 2D multiple cmaps (L1390-1398) + +Replace averaging with `_additive_blend`. + +### Step 8: Regenerate baseline images + +Push to PR branch, let CI run, download artifacts, review, commit new baselines. + +## Not in scope + +- PCA multichannel strategy / new parameters +- Changes to colormap selection logic +- Channel validation improvements (separate PR) +- The `norm.vmin` bug at L241 (functionally neutral when `vmin == vmax`) +- The `norm.vmax` bug from PR #451 at L258 (already fixed on main) + +## Risks + +1. **Visual change**: Averaging → additive makes composites brighter. Intentional and correct but requires baseline regeneration. +2. **Alpha semantics for 4+ch**: Removing double-alpha changes how `alpha < 1` looks. Since the current behavior is a bug, this is a fix. +3. **Edge case — saturated composites**: Additive sum of many channels can saturate to white. This matches Napari behavior and is expected. + +## Test strategy + +- Existing image comparison tests cover all paths. +- Add one direct unit test for `_additive_blend` with known inputs/outputs. +- Baselines regenerated via CI on Ubuntu. From 2d8f0da4ea25958770c75c056d467137e00fd83c Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Fri, 20 Mar 2026 01:26:19 +0100 Subject: [PATCH 6/6] Revert "Add compositing plan documenting investigation and design decisions" This reverts commit 6985eff8add01d1bb87efeaf6d0504f4dacdbffc. --- plans/unify-additive-blending.md | 131 ------------------------------- 1 file changed, 131 deletions(-) delete mode 100644 plans/unify-additive-blending.md diff --git a/plans/unify-additive-blending.md b/plans/unify-additive-blending.md deleted file mode 100644 index 2ed12f0a..00000000 --- a/plans/unify-additive-blending.md +++ /dev/null @@ -1,131 +0,0 @@ -# Plan: Unify Multi-Channel Image Compositing to Additive Blending - -## Goal - -Replace the inconsistent compositing formulas in `_render_images` with a single shared helper that implements standard additive blending with clamping (matching Napari/ImageJ/FIJI). - -## Current State (investigation findings) - -### Compositing paths in `render.py` `_render_images` - -| Path | Lines | Condition | Formula | Alpha | Bug | -|------|-------|-----------|---------|-------|-----| -| **1** | 1232-1254 | 1 channel | Direct (no compositing) | Baked in cmap | None | -| **2A-RGB** | 1282-1283 | 3ch, default cmap | `np.stack(..., axis=-1)` | Via imshow | None | -| **2A-cmap** | 1284-1293 | 3ch, user cmap | `.sum(0) / n_channels` | Via imshow | **Averaging** | -| **2B** 2ch | 1314-1321 | 2ch, no palette | `.sum(0)`, no clip | Via imshow | **No clip** | -| **2B** 3ch | 1322-1329 | 3ch, no palette | `.sum(0)`, no clip | Via imshow | **No clip** | -| **2B** 4+ch | 1330-1363 | 4+ch, no palette | additive + `np.clip` | **Baked + imshow** | **Double alpha** | -| **2C** | 1374-1380 | palette | `.sum(0)`, no clip | Via imshow | **No clip + filter bug** | -| **2D** | 1390-1398 | multiple cmaps | `.sum(0) / n_channels` | Via imshow | **Averaging** | - -### Key infrastructure details - -- **`layers` dict** (L1269-1288): Keys are channel identifiers (str or int). Values are normalized arrays (norm applied before compositing). Constructed correctly. -- **`_get_linear_colormap`** (utils.py L1766-1767): `LinearSegmentedColormap.from_list(c, ["k", c], N=256)` — black-to-color LUTs. Returns RGBA (4 channels) when called. Correct for additive blending. -- **`_ax_show_and_transform`** (utils.py L2893-2936): When `cmap is None` and `alpha is not None`, passes `alpha` to `ax.imshow()`. When cmap is present, does NOT pass alpha. This means for composite RGB arrays, alpha is applied by imshow. -- **Palette validation**: `_type_check_params` already ensures palette contains only strings. The `if isinstance(c, str)` filter at L1378 is redundant and risks index misalignment. - -### Alpha flow summary - -All paths except 2B-4+ch correctly apply alpha once (via `_ax_show_and_transform` → `ax.imshow(alpha=...)`). - -Path 2B-4+ch double-applies alpha: -1. L1356-1357: `rgba[..., 3] = render_params.alpha` then `comp_rgb += rgba[..., :3] * rgba[..., 3][..., None]` -2. L1369: passes `render_params.alpha` to `_ax_show_and_transform` → `ax.imshow(alpha=...)` - -### Tests affected - -Tests that will produce different (brighter) output when switching from averaging to additive: - -| Baseline | Test | Path | -|----------|------|------| -| `Images_can_pass_str_cmap.png` | `test_plot_can_pass_str_cmap` | 2A-cmap | -| `Images_can_pass_cmap.png` | `test_plot_can_pass_cmap` | 2A-cmap | -| `Images_can_render_multiscale_image_with_custom_cmap.png` | `test_plot_can_render_multiscale_image_with_custom_cmap` | 2A-cmap | -| `Images_can_pass_str_cmap_list.png` | `test_plot_can_pass_str_cmap_list` | 2D | -| `Images_can_pass_cmap_list.png` | `test_plot_can_pass_cmap_list` | 2D | -| `Images_can_pass_cmap_to_each_channel.png` | `test_plot_can_pass_cmap_to_each_channel` | 2D | - -Tests using 2B (2-3ch) or 2C may also shift slightly due to clip being added, but only if values currently exceed [0,1]. - -### Baseline regeneration - -Per `docs/contributing.md`: baselines must be generated on Ubuntu in GitHub Actions (not locally). Push the change, let CI run, download `visual_test_results_*` artifact, review manually, copy to `tests/_images/`. - -## Steps - -### Step 1: Add `_additive_blend` helper - -**File**: `src/spatialdata_plot/pl/render.py` - -Place as a module-level function before `_render_images` (near other helpers). - -```python -def _additive_blend( - layers: dict, - channels: list, - channel_cmaps: list, -) -> np.ndarray: - """Additive blend of colormapped channels, matching Napari's additive mode. - - Each channel is mapped through its colormap, the RGB components are summed, - and the result is clamped to [0, 1]. - """ - H, W = next(iter(layers.values())).shape - composite = np.zeros((H, W, 3), dtype=float) - for ch_idx, ch in enumerate(channels): - rgba = channel_cmaps[ch_idx](np.asarray(layers[ch])) - composite += rgba[..., :3] - return np.clip(composite, 0, 1) -``` - -No alpha parameter — alpha is handled uniformly by `_ax_show_and_transform` → `ax.imshow()`. - -### Step 2: Update path 2A-cmap (L1284-1293) - -Replace averaging with `_additive_blend`. Keep the warning about white cmaps. - -### Step 3: Update path 2B 2ch (L1314-1321) - -Replace `.sum(0)[:, :, :3]` with `_additive_blend(layers, channels, channel_cmaps)`. - -### Step 4: Update path 2B 3ch (L1322-1329) - -Same as step 3. - -### Step 5: Simplify path 2B 4+ch (L1330-1363) - -Replace inline loop with `_additive_blend(layers, channels, channel_cmaps)`. This removes the double-alpha bug (no more alpha bake-in; alpha is only applied via imshow). - -### Step 6: Update path 2C palette (L1374-1380) - -Replace `.sum(0)[:, :, :3]` with `_additive_blend`. Remove the `if isinstance(c, str)` filter. - -### Step 7: Update path 2D multiple cmaps (L1390-1398) - -Replace averaging with `_additive_blend`. - -### Step 8: Regenerate baseline images - -Push to PR branch, let CI run, download artifacts, review, commit new baselines. - -## Not in scope - -- PCA multichannel strategy / new parameters -- Changes to colormap selection logic -- Channel validation improvements (separate PR) -- The `norm.vmin` bug at L241 (functionally neutral when `vmin == vmax`) -- The `norm.vmax` bug from PR #451 at L258 (already fixed on main) - -## Risks - -1. **Visual change**: Averaging → additive makes composites brighter. Intentional and correct but requires baseline regeneration. -2. **Alpha semantics for 4+ch**: Removing double-alpha changes how `alpha < 1` looks. Since the current behavior is a bug, this is a fix. -3. **Edge case — saturated composites**: Additive sum of many channels can saturate to white. This matches Napari behavior and is expected. - -## Test strategy - -- Existing image comparison tests cover all paths. -- Add one direct unit test for `_additive_blend` with known inputs/outputs. -- Baselines regenerated via CI on Ubuntu.