diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 52297005..8fc78f00 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -338,7 +338,7 @@ def _render_shapes( cax = None if aggregate_with_reduction is not None: vmin = aggregate_with_reduction[0].values if norm.vmin is None else norm.vmin - vmax = aggregate_with_reduction[1].values if norm.vmin is None else norm.vmax + vmax = aggregate_with_reduction[1].values if norm.vmax is None else norm.vmax if (norm.vmin is not None or norm.vmax is not None) and norm.vmin == norm.vmax: # value (vmin=vmax) is placed in the middle of the colorbar so that we can distinguish it from over and # under values in case clip=True or clip=False with cmap(under)=cmap(0) & cmap(over)=cmap(1) @@ -846,20 +846,22 @@ def _render_images( # 2) Image has any number of channels but 1 else: layers = {} - for ch_index, c in enumerate(channels): - layers[c] = img.sel(c=c).copy(deep=True).squeeze() - - if not isinstance(render_params.cmap_params, list): - if render_params.cmap_params.norm is not None: - layers[c] = render_params.cmap_params.norm(layers[c]) + for ch_idx, ch in enumerate(channels): + layers[ch] = img.sel(c=ch).copy(deep=True).squeeze() + if isinstance(render_params.cmap_params, list): + ch_norm = render_params.cmap_params[ch_idx].norm + ch_cmap_is_default = render_params.cmap_params[ch_idx].cmap_is_default else: - if render_params.cmap_params[ch_index].norm is not None: - layers[c] = render_params.cmap_params[ch_index].norm(layers[c]) + ch_norm = render_params.cmap_params.norm + ch_cmap_is_default = render_params.cmap_params.cmap_is_default + + if not ch_cmap_is_default and ch_norm is not None: + layers[ch_idx] = ch_norm(layers[ch_idx]) # 2A) Image has 3 channels, no palette info, and no/only one cmap was given if palette is None and n_channels == 3 and not isinstance(render_params.cmap_params, list): if render_params.cmap_params.cmap_is_default: # -> use RGB - stacked = np.stack([layers[c] for c in channels], axis=-1) + stacked = np.stack([layers[ch] for ch in layers], axis=-1) else: # -> use given cmap for each channel channel_cmaps = [render_params.cmap_params.cmap] * n_channels stacked = ( @@ -892,12 +894,54 @@ def _render_images( # overwrite if n_channels == 2 for intuitive result if n_channels == 2: seed_colors = ["#ff0000ff", "#00ff00ff"] - else: + channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] + colored = np.stack( + [channel_cmaps[ch_ind](layers[ch]) for ch_ind, ch in enumerate(channels)], + 0, + ).sum(0) + colored = colored[:, :, :3] + elif n_channels == 3: seed_colors = _get_colors_for_categorical_obs(list(range(n_channels))) + channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] + colored = np.stack( + [channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], + 0, + ).sum(0) + colored = colored[:, :, :3] + else: + if isinstance(render_params.cmap_params, list): + cmap_is_default = render_params.cmap_params[0].cmap_is_default + else: + cmap_is_default = render_params.cmap_params.cmap_is_default - channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] - colored = np.stack([channel_cmaps[ind](layers[ch]) for ind, ch in enumerate(channels)], 0).sum(0) - colored = colored[:, :, :3] + if cmap_is_default: + seed_colors = _get_colors_for_categorical_obs(list(range(n_channels))) + else: + # Sample n_channels colors evenly from the colormap + if isinstance(render_params.cmap_params, list): + seed_colors = [ + render_params.cmap_params[i].cmap(i / (n_channels - 1)) for i in range(n_channels) + ] + else: + seed_colors = [render_params.cmap_params.cmap(i / (n_channels - 1)) for i in range(n_channels)] + channel_cmaps = [_get_linear_colormap([c], "k")[0] for c in seed_colors] + + # Stack (n_channels, height, width) → (height*width, n_channels) + H, W = next(iter(layers.values())).shape + comp_rgb = np.zeros((H, W, 3), dtype=float) + + # For each channel: map to RGBA, apply constant alpha, then add + for ch_idx, ch in enumerate(channels): + layer_arr = layers[ch] + rgba = channel_cmaps[ch_idx](layer_arr) + rgba[..., 3] = render_params.alpha + comp_rgb += rgba[..., :3] * rgba[..., 3][..., None] + + colored = np.clip(comp_rgb, 0, 1) + logger.info( + f"Your image has {n_channels} channels. Sampling categorical colors and using " + f"multichannel strategy 'stack' to render." + ) # TODO: update when pca is added as strategy _ax_show_and_transform( colored, @@ -943,6 +987,7 @@ def _render_images( zorder=render_params.zorder, ) + # 2D) Image has n channels, no palette but cmap info elif palette is not None and got_multiple_cmaps: raise ValueError("If 'palette' is provided, 'cmap' must be None.") diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index a2e8f767..f3f47b58 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -2006,7 +2006,7 @@ def _validate_col_for_column_table( table_name = next(iter(tables)) if len(tables) > 1: warnings.warn( - f"Multiple tables contain color column, using {table_name}", + f"Multiple tables contain column '{col_for_color}', using table '{table_name}'.", UserWarning, stacklevel=2, ) @@ -2042,25 +2042,49 @@ def _validate_image_render_params( element_params[el] = {} spatial_element = param_dict["sdata"][el] + # robustly get channel names from image or multiscale image spatial_element_ch = ( - spatial_element.c if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c + spatial_element.c.values if isinstance(spatial_element, DataArray) else spatial_element["scale0"].c.values ) - if (channel := param_dict["channel"]) is not None and ( - (isinstance(channel[0], int) and max([abs(ch) for ch in channel]) <= len(spatial_element_ch)) - or all(ch in spatial_element_ch for ch in channel) - ): + channel = param_dict["channel"] + if channel is not None: + # Normalize channel to always be a list of str or a list of int + if isinstance(channel, str): + channel = [channel] + + if isinstance(channel, int): + channel = [channel] + + # If channel is a list, ensure all elements are the same type + if not (isinstance(channel, list) and channel and all(isinstance(c, type(channel[0])) for c in channel)): + raise TypeError("Each item in 'channel' list must be of the same type, either string or integer.") + + invalid = [c for c in channel if c not in spatial_element_ch] + if invalid: + raise ValueError( + f"Invalid channel(s): {', '.join(str(c) for c in invalid)}. Valid choices are: {spatial_element_ch}" + ) element_params[el]["channel"] = channel else: element_params[el]["channel"] = None element_params[el]["alpha"] = param_dict["alpha"] - if isinstance(palette := param_dict["palette"], list): + palette = param_dict["palette"] + assert isinstance(palette, list | type(None)) # if present, was converted to list, just to make sure + + if isinstance(palette, list): + # case A: single palette for all channels if len(palette) == 1: palette_length = len(channel) if channel is not None else len(spatial_element_ch) palette = palette * palette_length - if (channel is not None and len(palette) != len(channel)) and len(palette) != len(spatial_element_ch): - palette = None + # case B: one palette per channel (either given or derived from channel length) + channels_to_use = spatial_element_ch if element_params[el]["channel"] is None else channel + if channels_to_use is not None and len(palette) != len(channels_to_use): + raise ValueError( + f"Palette length ({len(palette)}) does not match channel length " + f"({', '.join(str(c) for c in channels_to_use)})." + ) element_params[el]["palette"] = palette element_params[el]["na_color"] = param_dict["na_color"] @@ -2086,7 +2110,7 @@ def _validate_image_render_params( def _get_wanted_render_elements( sdata: SpatialData, sdata_wanted_elements: list[str], - params: (ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams), + params: ImageRenderParams | LabelsRenderParams | PointsRenderParams | ShapesRenderParams, cs: str, element_type: Literal["images", "labels", "points", "shapes"], ) -> tuple[list[str], list[str], bool]: @@ -2243,7 +2267,7 @@ def _create_image_from_datashader_result( def _datashader_aggregate_with_function( - reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None), + reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, cvs: Canvas, spatial_element: GeoDataFrame | dask.dataframe.core.DataFrame, col_for_color: str | None, @@ -2307,7 +2331,7 @@ def _datashader_aggregate_with_function( def _datshader_get_how_kw_for_spread( - reduction: (Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None), + reduction: Literal["sum", "mean", "any", "count", "std", "var", "max", "min"] | None, ) -> str: # Get the best input for the how argument of ds.tf.spread(), needed for numerical values reduction = reduction or "sum" diff --git a/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png b/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png index c22b9f2b..16bedd33 100644 Binary files a/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png and b/tests/_images/Extent_extent_of_img_is_correct_after_spatial_query.png differ