diff --git a/src/spatialdata_plot/pl/_datashader.py b/src/spatialdata_plot/pl/_datashader.py new file mode 100644 index 00000000..3b50dd52 --- /dev/null +++ b/src/spatialdata_plot/pl/_datashader.py @@ -0,0 +1,342 @@ +"""Datashader aggregation, shading, and rendering helpers. + +Shared by ``_render_shapes`` and ``_render_points`` in ``render.py``. +""" + +from __future__ import annotations + +from typing import Any, Literal + +import dask.dataframe as dd +import datashader as ds +import matplotlib +import matplotlib.colors +import numpy as np +import pandas as pd +from matplotlib.cm import ScalarMappable +from matplotlib.colors import Normalize + +from spatialdata_plot._logging import logger +from spatialdata_plot.pl.render_params import Color, FigParams, ShapesRenderParams +from spatialdata_plot.pl.utils import ( + _ax_show_and_transform, + _convert_alpha_to_datashader_range, + _create_image_from_datashader_result, + _datashader_aggregate_with_function, + _datashader_map_aggregate_to_color, + _datshader_get_how_kw_for_spread, + _hex_no_alpha, +) + +# --------------------------------------------------------------------------- +# Type aliases and constants +# --------------------------------------------------------------------------- + +_DsReduction = Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] + +# Sentinel category name used in datashader categorical paths to represent +# missing (NaN) values. Must not collide with realistic user category names. +_DS_NAN_CATEGORY = "ds_nan" + +# --------------------------------------------------------------------------- +# Low-level helpers +# --------------------------------------------------------------------------- + + +def _coerce_categorical_source(series: pd.Series | dd.Series) -> pd.Categorical: + """Return a ``pd.Categorical`` from a pandas or dask Series.""" + if isinstance(series, dd.Series): + if isinstance(series.dtype, pd.CategoricalDtype) and getattr(series.cat, "known", True) is False: + series = series.cat.as_known() + series = series.compute() + if isinstance(series.dtype, pd.CategoricalDtype): + return series.array + return pd.Categorical(series) + + +def _build_datashader_color_key( + cat_series: pd.Categorical, + color_vector: Any, + na_color_hex: str, +) -> dict[str, str]: + """Build a datashader ``color_key`` dict from a categorical series and its color vector.""" + na_hex = _hex_no_alpha(na_color_hex) if na_color_hex.startswith("#") else na_color_hex + colors_arr = np.asarray(color_vector, dtype=object) + first_color: dict[str, str] = {} + for code, color in zip(cat_series.codes, colors_arr, strict=False): + if code < 0: + continue + cat_name = str(cat_series.categories[code]) + if cat_name not in first_color: + first_color[cat_name] = _hex_no_alpha(color) if isinstance(color, str) and color.startswith("#") else color + return {str(c): first_color.get(str(c), na_hex) for c in cat_series.categories} + + +def _inject_ds_nan_sentinel(series: pd.Series, sentinel: str = _DS_NAN_CATEGORY) -> pd.Series: + """Add a sentinel category for NaN values in a categorical series. + + Safely handles series that are not yet categorical, dask-backed + categoricals that need ``as_known()``, and series that already + contain the sentinel. + """ + if not isinstance(series.dtype, pd.CategoricalDtype): + series = series.astype("category") + if hasattr(series.cat, "as_known"): + series = series.cat.as_known() + if sentinel not in series.cat.categories: + series = series.cat.add_categories(sentinel) + return series.fillna(sentinel) + + +# --------------------------------------------------------------------------- +# Pipeline helpers (aggregate -> norm -> shade -> render) +# --------------------------------------------------------------------------- + + +def _ds_aggregate( + cvs: Any, + transformed_element: Any, + col_for_color: str | None, + color_by_categorical: bool, + ds_reduction: _DsReduction | None, + default_reduction: _DsReduction, + geom_type: Literal["points", "shapes"], +) -> tuple[Any, tuple[Any, Any] | None, Any | None]: + """Aggregate spatial elements with datashader. + + Dispatches between categorical (ds.by), continuous (reduction function), + and no-color (ds.count) aggregation modes. + + Returns (agg, reduction_bounds, nan_agg). + """ + reduction_bounds = None + nan_agg = None + + def _agg_call(element: Any, agg_func: Any) -> Any: + if geom_type == "shapes": + return cvs.polygons(element, geometry="geometry", agg=agg_func) + return cvs.points(element, "x", "y", agg=agg_func) + + if col_for_color is not None: + if color_by_categorical: + transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color]) + agg = _agg_call(transformed_element, ds.by(col_for_color, ds.count())) + else: + reduction_name = ds_reduction if ds_reduction is not None else default_reduction + logger.info( + f'Using the datashader reduction "{reduction_name}". "max" will give an output ' + "very close to the matplotlib result." + ) + agg = _datashader_aggregate_with_function(ds_reduction, cvs, transformed_element, col_for_color, geom_type) + reduction_bounds = (agg.min(), agg.max()) + + nan_elements = transformed_element[transformed_element[col_for_color].isnull()] + if len(nan_elements) > 0: + nan_agg = _datashader_aggregate_with_function("any", cvs, nan_elements, None, geom_type) + else: + agg = _agg_call(transformed_element, ds.count()) + + return agg, reduction_bounds, nan_agg + + +def _apply_ds_norm( + agg: Any, + norm: Normalize, +) -> tuple[Any, list[float] | None]: + """Apply norm vmin/vmax to a datashader aggregate. + + When vmin == vmax, maps the value to 0.5 using an artificial [0, 1] span. + Returns (agg, color_span) where color_span is None if no norm was set. + """ + if norm.vmin is None and norm.vmax is None: + return agg, None + norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin + norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax + color_span: list[float] = [norm.vmin, norm.vmax] + if norm.vmin == norm.vmax: + color_span = [0, 1] + if norm.clip: + agg = (agg - agg) + 0.5 + else: + agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1) + agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) + agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) + return agg, color_span + + +def _build_color_key( + transformed_element: Any, + col_for_color: str | None, + color_by_categorical: bool, + color_vector: Any, + na_color_hex: str, +) -> dict[str, str] | None: + """Build a datashader color key mapping categories to hex colors. + + Returns None when not coloring by a categorical column. + """ + if not color_by_categorical or col_for_color is None: + return None + cat_series = _coerce_categorical_source(transformed_element[col_for_color]) + return _build_datashader_color_key(cat_series, color_vector, na_color_hex) + + +def _ds_shade_continuous( + agg: Any, + color_span: list[float] | None, + norm: Normalize, + cmap: Any, + alpha: float, + reduction_bounds: tuple[Any, Any] | None, + nan_agg: Any | None, + na_color_hex: str, + spread_px: int | None = None, + ds_reduction: _DsReduction | None = None, +) -> tuple[Any, Any | None, tuple[Any, Any] | None]: + """Shade a continuous datashader aggregate, optionally applying spread and NaN coloring. + + Returns (shaded, nan_shaded, reduction_bounds). + """ + if spread_px is not None: + spread_how = _datshader_get_how_kw_for_spread(ds_reduction) + agg = ds.tf.spread(agg, px=spread_px, how=spread_how) + reduction_bounds = (agg.min(), agg.max()) + + ds_cmap = cmap + if ( + reduction_bounds is not None + and reduction_bounds[0] == reduction_bounds[1] + and (color_span is None or color_span != [0, 1]) + ): + ds_cmap = matplotlib.colors.to_hex(cmap(0.0), keep_alpha=False) + reduction_bounds = ( + reduction_bounds[0], + reduction_bounds[0] + 1, + ) + + shaded = _datashader_map_aggregate_to_color( + agg, + cmap=ds_cmap, + min_alpha=_convert_alpha_to_datashader_range(alpha), + span=color_span, + clip=norm.clip, + ) + + nan_shaded = None + if nan_agg is not None: + shade_kwargs: dict[str, Any] = {"cmap": na_color_hex, "how": "linear"} + if spread_px is not None: + nan_agg = ds.tf.spread(nan_agg, px=spread_px, how="max") + else: + # only shapes (no spread) pass min_alpha for NaN shading + shade_kwargs["min_alpha"] = _convert_alpha_to_datashader_range(alpha) + nan_shaded = ds.tf.shade(nan_agg, **shade_kwargs) + + return shaded, nan_shaded, reduction_bounds + + +def _ds_shade_categorical( + agg: Any, + color_key: dict[str, str] | None, + color_vector: Any, + alpha: float, + spread_px: int | None = None, +) -> Any: + """Shade a categorical or no-color datashader aggregate.""" + ds_cmap = None + if color_vector is not None: + ds_cmap = color_vector[0] + if isinstance(ds_cmap, str) and ds_cmap[0] == "#": + ds_cmap = _hex_no_alpha(ds_cmap) + + agg_to_shade = ds.tf.spread(agg, px=spread_px) if spread_px is not None else agg + return _datashader_map_aggregate_to_color( + agg_to_shade, + cmap=ds_cmap, + color_key=color_key, + min_alpha=_convert_alpha_to_datashader_range(alpha), + ) + + +# --------------------------------------------------------------------------- +# Image rendering +# --------------------------------------------------------------------------- + + +def _render_ds_image( + ax: matplotlib.axes.SubplotBase, + shaded: Any, + factor: float, + zorder: int, + alpha: float, + extent: list[float] | None, + nan_result: Any | None = None, +) -> Any: + """Render a shaded datashader image onto matplotlib axes, with optional NaN overlay.""" + if nan_result is not None: + rgba_nan, trans_nan = _create_image_from_datashader_result(nan_result, factor, ax) + _ax_show_and_transform(rgba_nan, trans_nan, ax, zorder=zorder, alpha=alpha, extent=extent) + rgba_image, trans_data = _create_image_from_datashader_result(shaded, factor, ax) + return _ax_show_and_transform(rgba_image, trans_data, ax, zorder=zorder, alpha=alpha, extent=extent) + + +def _render_ds_outlines( + cvs: Any, + transformed_element: Any, + render_params: ShapesRenderParams, + fig_params: FigParams, + ax: matplotlib.axes.SubplotBase, + factor: float, + extent: list[float], +) -> None: + """Aggregate, shade, and render shape outlines (outer and inner) with datashader.""" + ds_lw_factor = fig_params.fig.dpi / 72 + assert len(render_params.outline_alpha) == 2 # noqa: S101 + + for idx, (outline_color_obj, linewidth) in enumerate( + [ + (render_params.outline_params.outer_outline_color, render_params.outline_params.outer_outline_linewidth), + (render_params.outline_params.inner_outline_color, render_params.outline_params.inner_outline_linewidth), + ] + ): + alpha = render_params.outline_alpha[idx] + if alpha <= 0: + continue + agg_outline = cvs.line( + transformed_element, + geometry="geometry", + line_width=linewidth * ds_lw_factor, + ) + if isinstance(outline_color_obj, Color): + shaded = ds.tf.shade( + agg_outline, + cmap=outline_color_obj.get_hex(), + min_alpha=_convert_alpha_to_datashader_range(alpha), + how="linear", + ) + rgba, trans = _create_image_from_datashader_result(shaded, factor, ax) + _ax_show_and_transform(rgba, trans, ax, zorder=render_params.zorder, alpha=alpha, extent=extent) + + +def _build_ds_colorbar( + reduction_bounds: tuple[Any, Any] | None, + norm: Normalize, + cmap: Any, +) -> ScalarMappable | None: + """Create a ScalarMappable for the colorbar from datashader reduction bounds. + + Returns None if there is no continuous reduction. + """ + if reduction_bounds is None: + return None + vmin = reduction_bounds[0].values if norm.vmin is None else norm.vmin + vmax = reduction_bounds[1].values if norm.vmax is None else norm.vmax + if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: + assert norm.vmin is not None + assert norm.vmax is not None + vmin = norm.vmin - 0.5 + vmax = norm.vmin + 0.5 + return ScalarMappable( + norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), + cmap=cmap, + ) diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 9cddb499..488db5c9 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -27,6 +27,16 @@ from xarray import DataTree from spatialdata_plot._logging import _log_context, logger +from spatialdata_plot.pl._datashader import ( + _apply_ds_norm, + _build_color_key, + _build_ds_colorbar, + _ds_aggregate, + _ds_shade_categorical, + _ds_shade_continuous, + _render_ds_image, + _render_ds_outlines, +) from spatialdata_plot.pl.render_params import ( Color, ColorbarSpec, @@ -40,12 +50,7 @@ ) from spatialdata_plot.pl.utils import ( _ax_show_and_transform, - _convert_alpha_to_datashader_range, _convert_shapes, - _create_image_from_datashader_result, - _datashader_aggregate_with_function, - _datashader_map_aggregate_to_color, - _datshader_get_how_kw_for_spread, _decorate_axs, _get_collection_shape, _get_colors_for_categorical_obs, @@ -64,56 +69,6 @@ _Normalize = Normalize | abc.Sequence[Normalize] -# Sentinel category name used in datashader categorical paths to represent -# missing (NaN) values. Must not collide with realistic user category names. -_DS_NAN_CATEGORY = "ds_nan" - - -def _coerce_categorical_source(series: pd.Series | dd.Series) -> pd.Categorical: - """Return a ``pd.Categorical`` from a pandas or dask Series.""" - if isinstance(series, dd.Series): - if isinstance(series.dtype, pd.CategoricalDtype) and getattr(series.cat, "known", True) is False: - series = series.cat.as_known() - series = series.compute() - if isinstance(series.dtype, pd.CategoricalDtype): - return series.array - return pd.Categorical(series) - - -def _build_datashader_color_key( - cat_series: pd.Categorical, - color_vector: Any, - na_color_hex: str, -) -> dict[str, str]: - """Build a datashader ``color_key`` dict from a categorical series and its color vector.""" - na_hex = _hex_no_alpha(na_color_hex) if na_color_hex.startswith("#") else na_color_hex - # Map each category to its first-occurrence color via codes - colors_arr = np.asarray(color_vector, dtype=object) - first_color: dict[str, str] = {} - for code, color in zip(cat_series.codes, colors_arr, strict=False): - if code < 0: - continue - cat_name = str(cat_series.categories[code]) - if cat_name not in first_color: - first_color[cat_name] = _hex_no_alpha(color) if isinstance(color, str) and color.startswith("#") else color - return {str(c): first_color.get(str(c), na_hex) for c in cat_series.categories} - - -def _inject_ds_nan_sentinel(series: pd.Series, sentinel: str = _DS_NAN_CATEGORY) -> pd.Series: - """Add a sentinel category for NaN values in a categorical series. - - Safely handles series that are not yet categorical, dask-backed - categoricals that need ``as_known()``, and series that already - contain the sentinel. - """ - if not isinstance(series.dtype, pd.CategoricalDtype): - series = series.astype("category") - if hasattr(series.cat, "as_known"): - series = series.cat.as_known() - if sentinel not in series.cat.categories: - series = series.cat.add_categories(sentinel) - return series.fillna(sentinel) - def _want_decorations(color_vector: Any, na_color: Color) -> bool: """Return whether legend/colorbar decorations should be shown. @@ -218,6 +173,77 @@ def _should_request_colorbar( return bool(auto_condition) +def _make_palette( + color_source_vector: pd.Series | None, + color_vector: Any, +) -> ListedColormap: + """Build a ListedColormap from a color vector, filtering out NaN entries when categorical.""" + if color_source_vector is None: + return ListedColormap(dict.fromkeys(color_vector)) + return ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()])) + + +def _add_legend_and_colorbar( + ax: matplotlib.axes.SubplotBase, + cax: ScalarMappable | None, + fig_params: FigParams, + adata: AnnData | None, + col_for_color: str | None, + color_source_vector: pd.Series | None, + color_vector: Any, + palette: ListedColormap | list[str] | None, + alpha: float, + na_color: Color, + legend_params: LegendParams, + colorbar: bool | str | None, + colorbar_params: dict[str, object] | None, + colorbar_requests: list[ColorbarSpec] | None, + scalebar_params: ScalebarParams, +) -> None: + """Add legend, colorbar, and scalebar decorations if the color vector warrants them.""" + if not _want_decorations(color_vector, na_color): + return + + if palette is None: + palette = _make_palette(color_source_vector, color_vector) + + if color_source_vector is not None and hasattr(color_source_vector, "remove_unused_categories"): + color_source_vector = color_source_vector.remove_unused_categories() + + wants_colorbar = _should_request_colorbar( + colorbar, + has_mappable=cax is not None, + is_continuous=col_for_color is not None and color_source_vector is None, + ) + + _decorate_axs( + ax=ax, + cax=cax, + fig_params=fig_params, + adata=adata, + value_to_plot=col_for_color, + color_source_vector=color_source_vector, + color_vector=color_vector, + palette=palette, + alpha=alpha, + na_color=na_color, + legend_fontsize=legend_params.legend_fontsize, + legend_fontweight=legend_params.legend_fontweight, + legend_loc=legend_params.legend_loc, + legend_fontoutline=legend_params.legend_fontoutline, + na_in_legend=legend_params.na_in_legend, + colorbar=wants_colorbar and legend_params.colorbar, + colorbar_params=colorbar_params, + colorbar_requests=colorbar_requests, + colorbar_label=_resolve_colorbar_label( + colorbar_params, + col_for_color if isinstance(col_for_color, str) else None, + ), + scalebar_dx=scalebar_params.scalebar_dx, + scalebar_units=scalebar_params.scalebar_units, + ) + + def _render_shapes( sdata: sd.SpatialData, render_params: ShapesRenderParams, @@ -315,13 +341,7 @@ def _render_shapes( "These observations will be colored with the 'na_color'." ) - # Using dict.fromkeys here since set returns in arbitrary order - # remove the color of NaN values, else it might be assigned to a category - # order of color in the palette should agree to order of occurence - if color_source_vector is None: - palette = ListedColormap(dict.fromkeys(color_vector)) - else: - palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()])) + palette = _make_palette(color_source_vector, color_vector) has_valid_color = ( len(set(color_vector)) != 1 @@ -419,201 +439,59 @@ def _render_shapes( cat_series = cat_series.astype("category") transformed_element[col_for_color] = cat_series - aggregate_with_reduction = None - continuous_nan_shapes = None - if col_for_color is not None: - if color_by_categorical: - # add a sentinel category so that shapes with NaN value are colored in the na_color - transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color]) - agg = cvs.polygons( - transformed_element, - geometry="geometry", - agg=ds.by(col_for_color, ds.count()), - ) - else: - reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "mean" - logger.info( - f'Using the datashader reduction "{reduction_name}". "max" will give an output very close ' - "to the matplotlib result." - ) - agg = _datashader_aggregate_with_function( - render_params.ds_reduction, - cvs, - transformed_element, - col_for_color, - "shapes", - ) - # save min and max values for drawing the colorbar - aggregate_with_reduction = (agg.min(), agg.max()) - - # nan shapes need to be rendered separately (else: invisible, bc nan is skipped by aggregation methods) - transformed_element_nan_color = transformed_element[transformed_element[col_for_color].isnull()] - if len(transformed_element_nan_color) > 0: - continuous_nan_shapes = _datashader_aggregate_with_function( - "any", cvs, transformed_element_nan_color, None, "shapes" - ) - else: - agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.count()) - - # render outlines if needed - # outline_linewidth is in points (1pt = 1/72 inch); datashader line_width is in canvas pixels - ds_lw_factor = fig_params.fig.dpi / 72 - assert len(render_params.outline_alpha) == 2 # shut up mypy - if render_params.outline_alpha[0] > 0: - agg_outlines = cvs.line( - transformed_element, - geometry="geometry", - line_width=render_params.outline_params.outer_outline_linewidth * ds_lw_factor, - ) - if render_params.outline_alpha[1] > 0: - agg_inner_outlines = cvs.line( - transformed_element, - geometry="geometry", - line_width=render_params.outline_params.inner_outline_linewidth * ds_lw_factor, - ) + agg, reduction_bounds, nan_agg = _ds_aggregate( + cvs, + transformed_element, + col_for_color, + color_by_categorical, + render_params.ds_reduction, + "mean", + "shapes", + ) - ds_span = None - if norm.vmin is not None or norm.vmax is not None: - norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin - norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax - ds_span = [norm.vmin, norm.vmax] - if norm.vmin == norm.vmax: - # edge case, value vmin is rendered as the middle of the cmap - ds_span = [0, 1] - if norm.clip: - agg = (agg - agg) + 0.5 - else: - agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1) - agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) - agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) - - color_key: dict[str, str] | None = None - if color_by_categorical and col_for_color is not None: - cat_series = _coerce_categorical_source(transformed_element[col_for_color]) - color_key = _build_datashader_color_key( - cat_series, color_vector, render_params.cmap_params.na_color.get_hex() - ) + agg, color_span = _apply_ds_norm(agg, norm) + na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) + color_key = _build_color_key( + transformed_element, + col_for_color, + color_by_categorical, + color_vector, + na_color_hex, + ) + nan_shaded = None if color_by_categorical or col_for_color is None: - ds_cmap = None - if color_vector is not None: - ds_cmap = color_vector[0] - if isinstance(ds_cmap, str) and ds_cmap[0] == "#": - ds_cmap = _hex_no_alpha(ds_cmap) - - ds_result = _datashader_map_aggregate_to_color( + shaded = _ds_shade_categorical( agg, - cmap=ds_cmap, - color_key=color_key, - min_alpha=_convert_alpha_to_datashader_range(render_params.fill_alpha), + color_key, + color_vector, + render_params.fill_alpha, ) - elif aggregate_with_reduction is not None: # to shut up mypy - ds_cmap = render_params.cmap_params.cmap - # in case all elements have the same value X: we render them using cmap(0.0), - # using an artificial "span" of [X, X + 1] for the color bar - # else: all elements would get alpha=0 and the color bar would have a weird range - if aggregate_with_reduction[0] == aggregate_with_reduction[1]: - ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False) - aggregate_with_reduction = ( - aggregate_with_reduction[0], - aggregate_with_reduction[0] + 1, - ) - - ds_result = _datashader_map_aggregate_to_color( + else: + shaded, nan_shaded, reduction_bounds = _ds_shade_continuous( agg, - cmap=ds_cmap, - min_alpha=_convert_alpha_to_datashader_range(render_params.fill_alpha), - span=ds_span, - clip=norm.clip, + color_span, + norm, + render_params.cmap_params.cmap, + render_params.fill_alpha, + reduction_bounds, + nan_agg, + na_color_hex, ) - if continuous_nan_shapes is not None: - # for coloring by continuous variable: render nan shapes separately - nan_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) - continuous_nan_shapes = ds.tf.shade( - continuous_nan_shapes, - cmap=nan_color_hex, - how="linear", - min_alpha=_convert_alpha_to_datashader_range(render_params.fill_alpha), - ) + _render_ds_outlines(cvs, transformed_element, render_params, fig_params, ax, factor, x_ext + y_ext) - # shade outlines if needed - if render_params.outline_alpha[0] > 0 and isinstance(render_params.outline_params.outer_outline_color, Color): - outline_color = render_params.outline_params.outer_outline_color.get_hex() - ds_outlines = ds.tf.shade( - agg_outlines, - cmap=outline_color, - min_alpha=_convert_alpha_to_datashader_range(render_params.outline_alpha[0]), - how="linear", - ) - # inner outlines - if render_params.outline_alpha[1] > 0 and isinstance(render_params.outline_params.inner_outline_color, Color): - outline_color = render_params.outline_params.inner_outline_color.get_hex() - ds_inner_outlines = ds.tf.shade( - agg_inner_outlines, - cmap=outline_color, - min_alpha=_convert_alpha_to_datashader_range(render_params.outline_alpha[1]), - how="linear", - ) - - # render outline image(s) - if render_params.outline_alpha[0] > 0: - rgba_image, trans_data = _create_image_from_datashader_result(ds_outlines, factor, ax) - _ax_show_and_transform( - rgba_image, - trans_data, - ax, - zorder=render_params.zorder, - alpha=render_params.outline_alpha[0], - extent=x_ext + y_ext, - ) - if render_params.outline_alpha[1] > 0: - rgba_image, trans_data = _create_image_from_datashader_result(ds_inner_outlines, factor, ax) - _ax_show_and_transform( - rgba_image, - trans_data, - ax, - zorder=render_params.zorder, - alpha=render_params.outline_alpha[1], - extent=x_ext + y_ext, - ) - - if continuous_nan_shapes is not None: - # for coloring by continuous variable: render nan points separately - rgba_image_nan, trans_data_nan = _create_image_from_datashader_result(continuous_nan_shapes, factor, ax) - _ax_show_and_transform( - rgba_image_nan, - trans_data_nan, - ax, - zorder=render_params.zorder, - alpha=render_params.fill_alpha, - extent=x_ext + y_ext, - ) - rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax) - _cax = _ax_show_and_transform( - rgba_image, - trans_data, + _cax = _render_ds_image( ax, - zorder=render_params.zorder, - alpha=render_params.fill_alpha, - extent=x_ext + y_ext, + shaded, + factor, + render_params.zorder, + render_params.fill_alpha, + x_ext + y_ext, + nan_result=nan_shaded, ) - cax = None - if aggregate_with_reduction is not None: - vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin - vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax - if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: - assert norm.vmin is not None - assert norm.vmax is not None - # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and - # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1) - vmin = norm.vmin - 0.5 - vmax = norm.vmin + 0.5 - cax = ScalarMappable( - norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), - cmap=render_params.cmap_params.cmap, - ) + cax = _build_ds_colorbar(reduction_bounds, norm, render_params.cmap_params.cmap) elif method == "matplotlib": # render outlines separately to ensure they are always underneath the shape @@ -698,43 +576,23 @@ def _render_shapes( vmax = 1.0 _cax.set_clim(vmin=vmin, vmax=vmax) - if _want_decorations(color_vector, render_params.cmap_params.na_color): - # necessary in case different shapes elements are annotated with one table - if color_source_vector is not None and render_params.col_for_color is not None: - color_source_vector = color_source_vector.remove_unused_categories() - - wants_colorbar = _should_request_colorbar( - render_params.colorbar, - has_mappable=cax is not None, - is_continuous=render_params.col_for_color is not None and color_source_vector is None, - ) - - _ = _decorate_axs( - ax=ax, - cax=cax, - fig_params=fig_params, - adata=table, - value_to_plot=col_for_color, - color_source_vector=color_source_vector, - color_vector=color_vector, - palette=palette, - alpha=render_params.fill_alpha, - na_color=render_params.cmap_params.na_color, - legend_fontsize=legend_params.legend_fontsize, - legend_fontweight=legend_params.legend_fontweight, - legend_loc=legend_params.legend_loc, - legend_fontoutline=legend_params.legend_fontoutline, - na_in_legend=legend_params.na_in_legend, - colorbar=wants_colorbar and legend_params.colorbar, - colorbar_params=render_params.colorbar_params, - colorbar_requests=colorbar_requests, - colorbar_label=_resolve_colorbar_label( - render_params.colorbar_params, - col_for_color if isinstance(col_for_color, str) else None, - ), - scalebar_dx=scalebar_params.scalebar_dx, - scalebar_units=scalebar_params.scalebar_units, - ) + _add_legend_and_colorbar( + ax=ax, + cax=cax, + fig_params=fig_params, + adata=table, + col_for_color=col_for_color, + color_source_vector=color_source_vector, + color_vector=color_vector, + palette=palette, + alpha=render_params.fill_alpha, + na_color=render_params.cmap_params.na_color, + legend_params=legend_params, + colorbar=render_params.colorbar, + colorbar_params=render_params.colorbar_params, + colorbar_requests=colorbar_requests, + scalebar_params=scalebar_params, + ) def _render_points( @@ -972,59 +830,25 @@ def _render_points( if color_by_categorical and not isinstance(color_dtype, pd.CategoricalDtype): transformed_element[col_for_color] = transformed_element[col_for_color].astype("category") - aggregate_with_reduction = None - continuous_nan_points = None - if col_for_color is not None: - if color_by_categorical: - # add nan as category so that nan points are shown in the nan color - transformed_element[col_for_color] = _inject_ds_nan_sentinel(transformed_element[col_for_color]) - agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count())) - else: - reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum" - logger.info( - f'Using the datashader reduction "{reduction_name}". "max" will give an output very close ' - "to the matplotlib result." - ) - agg = _datashader_aggregate_with_function( - render_params.ds_reduction, - cvs, - transformed_element, - col_for_color, - "points", - ) - # save min and max values for drawing the colorbar - aggregate_with_reduction = (agg.min(), agg.max()) - # nan points need to be rendered separately (else: invisible, bc nan is skipped by aggregation methods) - transformed_element_nan_color = transformed_element[transformed_element[col_for_color].isnull()] - if len(transformed_element_nan_color) > 0: - continuous_nan_points = _datashader_aggregate_with_function( - "any", cvs, transformed_element_nan_color, None, "points" - ) - else: - agg = cvs.points(transformed_element, "x", "y", agg=ds.count()) - - ds_span = None - if norm.vmin is not None or norm.vmax is not None: - norm.vmin = np.min(agg) if norm.vmin is None else norm.vmin - norm.vmax = np.max(agg) if norm.vmax is None else norm.vmax - ds_span = [norm.vmin, norm.vmax] - if norm.vmin == norm.vmax: - ds_span = [0, 1] - if norm.clip: - # all data is mapped to 0.5 - agg = (agg - agg) + 0.5 - else: - # values equal to norm.vmin are mapped to 0.5, the rest to -1 or 2 - agg = agg.where((agg >= norm.vmin) | (np.isnan(agg)), other=-1) - agg = agg.where((agg <= norm.vmin) | (np.isnan(agg)), other=2) - agg = agg.where((agg != norm.vmin) | (np.isnan(agg)), other=0.5) - - color_key: dict[str, str] | None = None - if color_by_categorical and col_for_color is not None: - cat_series = _coerce_categorical_source(transformed_element[col_for_color]) - color_key = _build_datashader_color_key( - cat_series, color_vector, render_params.cmap_params.na_color.get_hex() - ) + agg, reduction_bounds, nan_agg = _ds_aggregate( + cvs, + transformed_element, + col_for_color, + color_by_categorical, + render_params.ds_reduction, + "sum", + "points", + ) + + agg, color_span = _apply_ds_norm(agg, norm) + na_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) + color_key = _build_color_key( + transformed_element, + col_for_color, + color_by_categorical, + color_vector, + na_color_hex, + ) if ( color_vector is not None @@ -1034,83 +858,40 @@ def _render_points( ): color_vector = np.asarray([_hex_no_alpha(x) for x in color_vector]) + nan_shaded = None if color_by_categorical or col_for_color is None: - ds_result = _datashader_map_aggregate_to_color( - ds.tf.spread(agg, px=px), - cmap=color_vector[0], - color_key=color_key, - min_alpha=_convert_alpha_to_datashader_range(render_params.alpha), + shaded = _ds_shade_categorical( + agg, + color_key, + color_vector, + render_params.alpha, + spread_px=px, ) else: - spread_how = _datshader_get_how_kw_for_spread(render_params.ds_reduction) - agg = ds.tf.spread(agg, px=px, how=spread_how) - aggregate_with_reduction = (agg.min(), agg.max()) - - ds_cmap = render_params.cmap_params.cmap - # in case all elements have the same value X: we render them using cmap(0.0), - # using an artificial "span" of [X, X + 1] for the color bar - # else: all elements would get alpha=0 and the color bar would have a weird range - if aggregate_with_reduction[0] == aggregate_with_reduction[1] and (ds_span is None or ds_span != [0, 1]): - ds_cmap = matplotlib.colors.to_hex(render_params.cmap_params.cmap(0.0), keep_alpha=False) - aggregate_with_reduction = ( - aggregate_with_reduction[0], - aggregate_with_reduction[0] + 1, - ) - - ds_result = _datashader_map_aggregate_to_color( + shaded, nan_shaded, reduction_bounds = _ds_shade_continuous( agg, - cmap=ds_cmap, - span=ds_span, - clip=norm.clip, - min_alpha=_convert_alpha_to_datashader_range(render_params.alpha), + color_span, + norm, + render_params.cmap_params.cmap, + render_params.alpha, + reduction_bounds, + nan_agg, + na_color_hex, + spread_px=px, + ds_reduction=render_params.ds_reduction, ) - if continuous_nan_points is not None: - # for coloring by continuous variable: render nan points separately - nan_color_hex = _hex_no_alpha(render_params.cmap_params.na_color.get_hex()) - continuous_nan_points = ds.tf.spread(continuous_nan_points, px=px, how="max") - continuous_nan_points = ds.tf.shade( - continuous_nan_points, - cmap=nan_color_hex, - how="linear", - ) - - if continuous_nan_points is not None: - # for coloring by continuous variable: render nan points separately - rgba_image_nan, trans_data_nan = _create_image_from_datashader_result(continuous_nan_points, factor, ax) - _ax_show_and_transform( - rgba_image_nan, - trans_data_nan, - ax, - zorder=render_params.zorder, - alpha=render_params.alpha, - extent=x_ext + y_ext, - ) - rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax) - _ax_show_and_transform( - rgba_image, - trans_data, + _render_ds_image( ax, - zorder=render_params.zorder, - alpha=render_params.alpha, - extent=x_ext + y_ext, + shaded, + factor, + render_params.zorder, + render_params.alpha, + x_ext + y_ext, + nan_result=nan_shaded, ) - cax = None - if aggregate_with_reduction is not None: - vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin - vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax - if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: - assert norm.vmin is not None - assert norm.vmax is not None - # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and - # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1) - vmin = norm.vmin - 0.5 - vmax = norm.vmin + 0.5 - cax = ScalarMappable( - norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax), - cmap=render_params.cmap_params.cmap, - ) + cax = _build_ds_colorbar(reduction_bounds, norm, render_params.cmap_params.cmap) elif method == "matplotlib": # update axis limits if plot was empty before (necessary if datashader comes after) @@ -1135,44 +916,23 @@ def _render_points( ax.set_xbound(extent["x"]) ax.set_ybound(extent["y"]) - if _want_decorations(color_vector, render_params.cmap_params.na_color): - if color_source_vector is None: - palette = ListedColormap(dict.fromkeys(color_vector)) - else: - palette = ListedColormap(dict.fromkeys(color_vector[~pd.Categorical(color_source_vector).isnull()])) - - wants_colorbar = _should_request_colorbar( - render_params.colorbar, - has_mappable=cax is not None, - is_continuous=col_for_color is not None and color_source_vector is None, - ) - - _ = _decorate_axs( - ax=ax, - cax=cax, - fig_params=fig_params, - adata=adata, - value_to_plot=col_for_color, - color_source_vector=color_source_vector, - color_vector=color_vector, - palette=palette, - alpha=render_params.alpha, - na_color=render_params.cmap_params.na_color, - legend_fontsize=legend_params.legend_fontsize, - legend_fontweight=legend_params.legend_fontweight, - legend_loc=legend_params.legend_loc, - legend_fontoutline=legend_params.legend_fontoutline, - na_in_legend=legend_params.na_in_legend, - colorbar=wants_colorbar and legend_params.colorbar, - colorbar_params=render_params.colorbar_params, - colorbar_requests=colorbar_requests, - colorbar_label=_resolve_colorbar_label( - render_params.colorbar_params, - col_for_color if isinstance(col_for_color, str) else None, - ), - scalebar_dx=scalebar_params.scalebar_dx, - scalebar_units=scalebar_params.scalebar_units, - ) + _add_legend_and_colorbar( + ax=ax, + cax=cax, + fig_params=fig_params, + adata=adata, + col_for_color=col_for_color, + color_source_vector=color_source_vector, + color_vector=color_vector, + palette=None, + alpha=render_params.alpha, + na_color=render_params.cmap_params.na_color, + legend_params=legend_params, + colorbar=render_params.colorbar, + colorbar_params=render_params.colorbar_params, + colorbar_requests=colorbar_requests, + scalebar_params=scalebar_params, + ) def _render_images( @@ -1366,7 +1126,7 @@ def _render_images( seed_colors = [render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels)] channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] - # Stack (n_channels, height, width) → (height*width, n_channels) + # Stack (n_channels, height, width) -> (height*width, n_channels) H, W = next(iter(layers.values())).shape comp_rgb = np.zeros((H, W, 3), dtype=float)