diff --git a/CHANGELOG.md b/CHANGELOG.md index 431be5c..0a01d1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ All notable changes to this project will be documented in this file. - **`--dry-run` flag for `kompot de` CLI**: estimates memory, disk, and output field requirements without running the analysis. Outputs machine-parseable JSON to stdout and a human-readable report to stderr. Exit code reflects feasibility. - **`kompot.configure_logging(stream)`**: reconfigure the kompot logger output stream. The CLI now logs to stderr by default, keeping stdout clean for machine-parseable output (dry-run JSON, table output). + - **`kompot.plot.dotplot`**: ax-embeddable fold-change-per-group dotplot. Color = mean of a per-cell LFC layer within each `groupby` category; size = fraction of cells expressing. Gene selection is either an explicit list or auto-picked top-N by Mahalanobis from run history (with optional `filter_key`, e.g. restricting to `is_de=True`). Pass `axes=(main, cbar, size_legend)` to compose into a larger figure, or leave `axes=None` for a standalone figure. Unlike `scanpy.pl.DotPlot`, this function does not build its own `GridSpec` and does not fight externally-provided axes, which is the whole reason it exists. Shares gene-selection, layer-fetch, and colormap-normalization primitives with `kompot.plot.heatmap` via the existing `heatmap.utils` helpers. ### Improvements @@ -18,7 +19,6 @@ All notable changes to this project will be documented in this file. - Add `smooth_expression()` module to Sphinx API docs. - Add `RunInfo.to_settings()` and `call_args()` to documented members. - Fix "Gene Expression Imputation" → "Gene Expression Smoothing" in docs toctree. - ## [0.7.0] - 2026-04-13 ### Breaking changes diff --git a/examples/02_differential_expression_detailed.ipynb b/examples/02_differential_expression_detailed.ipynb index bf3e399..aa90e30 100644 --- a/examples/02_differential_expression_detailed.ipynb +++ b/examples/02_differential_expression_detailed.ipynb @@ -106,7 +106,7 @@ { "data": { "text/plain": [ - "AnnData object with n_obs × n_vars = 8090 × 16285\n", + "AnnData object with n_obs \u00d7 n_vars = 8090 \u00d7 16285\n", " obs: 'Compartment', 'Replicate', 'Age', 'Sample', 'Info', 'batch', 'doublet_score', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_hb', 'pct_counts_hb', 'S_score', 'G2M_score', 'phase', 'leiden', 'phenograph', 'highres_celltype', 'midres_celltype'\n", " var: 'gene_ids', 'feature_types', 'genome', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'hb', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'\n", " uns: 'Age_colors', 'Compartment_colors', 'DMEigenValues', 'Info_colors', 'README', 'Replicate_colors', 'Sample_colors', 'batch_colors', 'draw_graph', 'highres_celltype_colors', 'hvg', 'leiden', 'leiden_colors', 'midres_celltype_colors', 'neighbors', 'pca', 'phase_colors', 'umap', 'DM_EigenValues'\n", @@ -148,7 +148,7 @@ "\n", "### FDR Settings (`kompot.FDRSettings`)\n", "\n", - "- **`null_genes`**: Number of permuted genes for FDR estimation (default: \"auto\" → 2000)\n", + "- **`null_genes`**: Number of permuted genes for FDR estimation (default: \"auto\" \u2192 2000)\n", " - Higher values give better FDR estimates but increase computation time\n", " - Set to 0 to disable FDR computation\n", "\n", @@ -669,6 +669,90 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dotplot Customization\n", + "\n", + "The [dotplot](https://kompot.readthedocs.io/en/latest/plotting.html#kompot.plot.dotplot) function renders the same per-group fold-change summary as `kompot.plot.heatmap` but adds a second encoding dimension: dot **color** is the mean per-cell LFC (like `heatmap(fold_change_mode=True)`), and dot **size** is the fraction of cells in each category whose expression exceeds a threshold. Both encodings come from the same kompot DE run.\n", + "\n", + "It also accepts externally-provided axes (`axes=(main, cbar, size_legend)`), so it composes cleanly into figure-level layouts where `scanpy.pl.DotPlot`'s built-in GridSpec would fight back.\n", + "\n", + "### Auto-pick top genes by Mahalanobis\n", + "\n", + "With `genes=None`, the top `n_top` genes are picked by the Mahalanobis column inferred from the latest kompot DE run \u2014 same default as `heatmap`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "categories = [\n", + " c for c in adata.obs[CELL_TYPE_COLUMN].cat.categories\n", + " if c != \"Plasma cell\"\n", + "]\n", + "\n", + "kompot.plot.dotplot(\n", + " adata,\n", + " genes=None,\n", + " groupby=CELL_TYPE_COLUMN,\n", + " categories_order=categories,\n", + " n_top=15,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Custom gene list\n", + "\n", + "Pass an explicit `genes` list when you want to match a specific comparison or reproduce a figure \u2014 same shape as `heatmap`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "kompot.plot.dotplot(\n", + " adata,\n", + " genes=custom_genes,\n", + " groupby=CELL_TYPE_COLUMN,\n", + " categories_order=categories,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Adjusting the color scale and size encoding\n", + "\n", + "The color scale is symmetric around 0 keyed on the `vabs_pct`-th percentile of `|LFC|` by default. Tighten the scale when a few outlier genes compress the visible dynamic range, and tune `size_exponent` / `dot_max` to emphasise fraction-expressing differences:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "kompot.plot.dotplot(\n", + " adata,\n", + " genes=custom_genes,\n", + " groupby=CELL_TYPE_COLUMN,\n", + " categories_order=categories,\n", + " vabs_pct=90, # tighter color scale\n", + " size_exponent=2.0, # stronger fraction-expressing emphasis\n", + " dot_max=80,\n", + ")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -973,14 +1057,14 @@ " display: none;\n", " }\n", " .kompot-runinfo summary::before {\n", - " content: \"▶ \";\n", + " content: \"\u25b6 \";\n", " display: inline-block;\n", " font-size: 0.8em;\n", " color: #888;\n", " margin-right: 5px;\n", " }\n", " .kompot-runinfo details[open] > summary::before {\n", - " content: \"▼ \";\n", + " content: \"\u25bc \";\n", " }\n", " .kompot-runinfo table {\n", " width: 100%;\n", @@ -1168,14 +1252,14 @@ " display: none;\n", " }\n", " .kompot-comparison summary::before {\n", - " content: \"▶ \";\n", + " content: \"\u25b6 \";\n", " display: inline-block;\n", " font-size: 0.8em;\n", " color: #888;\n", " margin-right: 5px;\n", " }\n", " .kompot-comparison details[open] > summary::before {\n", - " content: \"▼ \";\n", + " content: \"\u25bc \";\n", " }\n", " .kompot-comparison table {\n", " width: 100%;\n", @@ -1433,17 +1517,17 @@ " Memory: 31.76 GB (10% of available)\n", "\n", "Memory Allocations:\n", - " • Mellon precision matrix L (condition 1, 2,917/3,116 cells) (np.int64(2917), np.int64(2917)): 64.92 MB\n", - " • Mellon precision matrix L (condition 2, 2,917/3,116 cells) (np.int64(3116), np.int64(3116)): 74.08 MB\n", - " • Imputed expression (condition 1) (8090, 18285): 1.10 GB → adata.layers['kompot_de_Young_imputed']\n", - " • Imputed expression (condition 2) (8090, 18285): 1.10 GB → adata.layers['kompot_de_Old_imputed']\n", - " • Fold change (8090, 18285): 1.10 GB → adata.layers['kompot_de_Young_to_Old_fold_change']\n", - " • Temporary matrices during predictions (batch_size=100) (100, 5000) + (100, 18285): 17.77 MB\n", - " • Peak intermediate arrays during predictions (~25 arrays) 25×(8090, 18285): 27.55 GB\n", - " • Function predictor covariances (per condition) (5000, 5000): 381.47 MB\n", - " • Combined covariance matrix (5000, 5000): 190.73 MB\n", - " • Cholesky decomposition (for Mahalanobis) (5000, 5000): 190.73 MB\n", - " • Mahalanobis batch processing (batch_size=100) (100, 5000): 3.81 MB\n", + " \u2022 Mellon precision matrix L (condition 1, 2,917/3,116 cells) (np.int64(2917), np.int64(2917)): 64.92 MB\n", + " \u2022 Mellon precision matrix L (condition 2, 2,917/3,116 cells) (np.int64(3116), np.int64(3116)): 74.08 MB\n", + " \u2022 Imputed expression (condition 1) (8090, 18285): 1.10 GB \u2192 adata.layers['kompot_de_Young_imputed']\n", + " \u2022 Imputed expression (condition 2) (8090, 18285): 1.10 GB \u2192 adata.layers['kompot_de_Old_imputed']\n", + " \u2022 Fold change (8090, 18285): 1.10 GB \u2192 adata.layers['kompot_de_Young_to_Old_fold_change']\n", + " \u2022 Temporary matrices during predictions (batch_size=100) (100, 5000) + (100, 18285): 17.77 MB\n", + " \u2022 Peak intermediate arrays during predictions (~25 arrays) 25\u00d7(8090, 18285): 27.55 GB\n", + " \u2022 Function predictor covariances (per condition) (5000, 5000): 381.47 MB\n", + " \u2022 Combined covariance matrix (5000, 5000): 190.73 MB\n", + " \u2022 Cholesky decomposition (for Mahalanobis) (5000, 5000): 190.73 MB\n", + " \u2022 Mahalanobis batch processing (batch_size=100) (100, 5000): 3.81 MB\n", "\n", "Output Fields:\n", " adata.layers:\n", @@ -1457,16 +1541,16 @@ " - kompot_de_Young_to_Old_is_de\n", "\n", "Info:\n", - " ℹ Null distribution will use 2000 additional genes (total: 18285 genes processed)\n", - " ℹ Cell batching reduces memory: Each of 4 prediction operations uses ~17.77 MB temporary arrays instead of 1.40 GB (saving 1.39 GB).\n", - " ℹ Prediction creates ~25 intermediate arrays of shape (8,090, 18285). These coexist at peak memory (27.55 GB) but are freed before completion.\n", - " ℹ Mahalanobis computation processes 100 genes per batch. Reduce via gp=GPSettings(batch_size=...) to lower peak memory (currently 3.81 MB for batch arrays).\n", + " \u2139 Null distribution will use 2000 additional genes (total: 18285 genes processed)\n", + " \u2139 Cell batching reduces memory: Each of 4 prediction operations uses ~17.77 MB temporary arrays instead of 1.40 GB (saving 1.39 GB).\n", + " \u2139 Prediction creates ~25 intermediate arrays of shape (8,090, 18285). These coexist at peak memory (27.55 GB) but are freed before completion.\n", + " \u2139 Mahalanobis computation processes 100 genes per batch. Reduce via gp=GPSettings(batch_size=...) to lower peak memory (currently 3.81 MB for batch arrays).\n", "\n", "Warnings:\n", - " ⚠ Results with result_key='kompot_de' already exist (run_id=1). Previous run: 2026-03-26T05:39:38.441389 comparing Young to Mid (null_genes=2000). Fields that will be overwritten: var.kompot_de_Young_to_Old_mahalanobis, var.kompot_de_Young_to_Old_mean_lfc, layers.kompot_de_Young_imputed, layers.kompot_de_Old_imputed, layers.kompot_de_Young_to_Old_fold_change and 2 more\n", + " \u26a0 Results with result_key='kompot_de' already exist (run_id=1). Previous run: 2026-03-26T05:39:38.441389 comparing Young to Mid (null_genes=2000). Fields that will be overwritten: var.kompot_de_Young_to_Old_mahalanobis, var.kompot_de_Young_to_Old_mean_lfc, layers.kompot_de_Young_imputed, layers.kompot_de_Old_imputed, layers.kompot_de_Young_to_Old_fold_change and 2 more\n", "\n", "================================================================================\n", - "STATUS: ⚠ FEASIBLE WITH WARNINGS - Proceed with caution\n", + "STATUS: \u26a0 FEASIBLE WITH WARNINGS - Proceed with caution\n", "================================================================================\n" ] } @@ -1562,12 +1646,12 @@ "\n", "This tutorial covered:\n", "\n", - "✓ Customizing DE parameters (`null_genes`, `sigma`, `batch_size`) \n", - "✓ Advanced volcano plot options \n", - "✓ Expression visualization techniques \n", - "✓ Heatmap customization \n", - "✓ Managing multiple comparisons with `run_id` \n", - "✓ Resource planning with dry runs \n", + "\u2713 Customizing DE parameters (`null_genes`, `sigma`, `batch_size`) \n", + "\u2713 Advanced volcano plot options \n", + "\u2713 Expression visualization techniques \n", + "\u2713 Heatmap customization \n", + "\u2713 Managing multiple comparisons with `run_id` \n", + "\u2713 Resource planning with dry runs \n", "\n", "### Next Steps\n", "\n", diff --git a/kompot/plot/__init__.py b/kompot/plot/__init__.py index 7c46689..d099c66 100644 --- a/kompot/plot/__init__.py +++ b/kompot/plot/__init__.py @@ -152,6 +152,20 @@ def plot_smoothing(*args, **kwargs): raise ImportError("Smoothing plot unavailable due to missing dependencies.") +try: + from .dotplot import dotplot + + __all__.append("dotplot") +except ImportError as e: + logger.warning(f"Could not import dotplot function due to: {e}") + + def dotplot(*args, **kwargs): + raise ImportError( + "Dotplot unavailable due to missing dependencies. " + "matplotlib is required." + ) + + # Import StringDB report class try: from .stringdb import StringDBReport diff --git a/kompot/plot/dotplot.py b/kompot/plot/dotplot.py new file mode 100644 index 0000000..4a2f0bb --- /dev/null +++ b/kompot/plot/dotplot.py @@ -0,0 +1,509 @@ +"""Dotplot of kompot differential expression results. + +Ax-embeddable fold-change-per-group dotplot. Color encodes the mean of a +per-cell log fold-change layer within each ``groupby`` category; size +encodes the fraction of cells in each category whose expression exceeds a +threshold. Gene selection is either an explicit list or auto-picked by +top-N Mahalanobis from a kompot DE run. + +Unlike :func:`scanpy.pl.DotPlot`, this function does not build its own +``GridSpec`` and composes cleanly into externally-provided axes. Pass +``axes=(main, cbar, size_legend)`` to embed the dotplot into a composite +figure, or leave ``axes=None`` for a standalone figure. + +This module reuses three primitives from :mod:`kompot.plot.heatmap.utils` +so dotplot and heatmap share the same gene-selection and colorbar +semantics: + +* ``_prepare_gene_list`` — explicit list or top-N by inferred Mahalanobis + (with a small local addition for the ``filter_key`` path). +* ``_get_expression_matrix`` — dense layer / ``X`` fetch with sparse and + missing-layer handling. +* ``_setup_colormap_normalization`` — ``TwoSlopeNorm`` + colormap object. +""" + +from __future__ import annotations + +import logging +from typing import Optional, Sequence, Tuple + +import numpy as np +import pandas as pd +from anndata import AnnData + +# Reuse the heatmap primitives — keeps gene-list inference, layer +# fetching, and colormap normalization in one place. The per-group +# aggregation below mirrors heatmap's `groupby().mean()` idiom; a +# TODO notes this could still be lifted into a shared helper once a +# third consumer shows up (see below). +from .heatmap.utils import ( + _get_expression_matrix, + _infer_score_key, + _prepare_gene_list, + _setup_colormap_normalization, +) + +logger = logging.getLogger("kompot") + +try: + import matplotlib.pyplot as plt + from matplotlib.axes import Axes + from matplotlib.cm import ScalarMappable + from matplotlib.figure import Figure + from matplotlib.ticker import MaxNLocator +except ImportError as e: + raise ImportError( + "matplotlib is required for plotting: pip install matplotlib" + ) from e + + +def _infer_lfc_layer(adata: AnnData, run_id: int) -> Optional[str]: + """Fold-change layer name from kompot DE ``run_info``. + + No analog exists in ``heatmap.utils`` — heatmap colors by + per-condition *expression* means, while the dotplot colors by the + per-cell *LFC* layer kompot writes during DE. Kept local because of + that specificity. + """ + try: + from ..anndata.utils import get_run_from_history + except ImportError: + return None + try: + run_info = get_run_from_history(adata, run_id, analysis_type="de") + except Exception: + return None + if run_info is None: + return None + layer_keys = ( + run_info.get("smoothed_layer_keys") + or run_info.get("imputed_layer_keys") + or {} + ) + fc = layer_keys.get("fold_change") + if fc is None and "field_names" in run_info: + fc = run_info["field_names"].get("fold_change_key") + return fc + + +def _group_aggregate( + values: np.ndarray, + obs_labels: np.ndarray, + gene_names: Sequence[str], + categories: Sequence[str], + reducer: str, +) -> np.ndarray: + """Per-category reduction of ``(n_obs, n_genes)`` into ``(n_cats, n_genes)``. + + Mirrors the ``df.groupby(col, observed=True).mean()`` idiom heatmap + uses inline. Kept here (rather than in ``heatmap.utils``) until a + third consumer appears — the two call-sites differ enough in their + surrounding data plumbing that a premature abstraction would be + noise. TODO: lift into ``plot/_utils.py`` on third user. + """ + if reducer not in {"mean", "fraction"}: + raise ValueError(f"unsupported reducer: {reducer}") + frame = pd.DataFrame( + np.asarray(values, dtype=float), + columns=list(gene_names), + ) + frame["_group_"] = np.asarray(obs_labels) + grouped = frame.groupby("_group_", observed=True)[list(gene_names)] + reduced = grouped.mean() + return reduced.reindex(list(categories)).to_numpy(dtype=float) + + +def dotplot( + adata: AnnData, + genes: Optional[Sequence[str]], + groupby: str, + *, + lfc_layer: Optional[str] = None, + expr_layer: Optional[str] = None, + score_key: Optional[str] = None, + filter_key: Optional[str] = None, + n_top: int = 15, + categories_order: Optional[Sequence[str]] = None, + min_cells: int = 0, + expr_threshold: float = 0.0, + vabs_pct: float = 98.0, + vabs_min: float = 0.0, + vmax: Optional[float] = None, + cmap: str = "RdBu_r", + size_exponent: float = 1.5, + dot_max: float = 60.0, + dot_edge_color: str = "white", + dot_edge_lw: float = 0.2, + axes: Optional[Tuple[Axes, Axes, Axes]] = None, + figsize: Tuple[float, float] = (7.5, 3.2), + cbar_title: str = "mean LFC", + size_title: str = "fraction\nexpressing", + title: Optional[str] = None, + xlabel: Optional[str] = None, + ylabel: Optional[str] = None, + gene_label_fontsize: float = 6.5, + category_label_fontsize: float = 5.5, + italic_genes: bool = True, + run_id: int = -1, + return_fig: bool = False, + save: Optional[str] = None, +) -> Optional[Figure]: + """Kompot fold-change dotplot across groups. + + Each tile encodes two quantities for a (gene, group) pair: + + * **color** — mean of the per-cell LFC layer ``lfc_layer`` over cells + in the group (i.e. the per-cell kompot fold-change averaged within + the category). Symmetric diverging scale keyed on the + ``vabs_pct``-th percentile of ``|LFC|`` by default. + * **size** — fraction of cells in the group whose + ``expr_layer`` value exceeds ``expr_threshold`` (default 0). + + Parameters + ---------- + adata : AnnData + AnnData with kompot DE results. + genes : sequence of str or None + Explicit gene list (in display order, top-first). If ``None``, + the top ``n_top`` genes are picked by the Mahalanobis column + inferred from run history (or ``score_key`` if given). + groupby : str + Column in ``adata.obs`` used for the column axis of the dotplot. + lfc_layer : str, optional + Layer in ``adata.layers`` holding per-cell LFC (e.g. + ``"kompot_de__to__fold_change"``). If ``None``, inferred + from the latest kompot DE run. + expr_layer : str, optional + Layer used to compute fraction-expressed. Defaults to ``adata.X``. + score_key : str, optional + Column in ``adata.var`` used to rank genes when ``genes is None``. + Defaults to the Mahalanobis column inferred from run history. + filter_key : str, optional + Boolean column in ``adata.var`` restricting auto-pick candidates + (e.g. ``"kompot_de__to__is_de"``). + n_top : int, default 15 + Number of genes selected when ``genes is None``. + categories_order : sequence of str, optional + Subset/order of ``groupby`` categories to display. Categories not + present in ``adata.obs[groupby]`` are dropped. + min_cells : int, default 0 + Drop categories with fewer than this many cells. + expr_threshold : float, default 0.0 + Threshold used for the fraction-expressed calculation. + vabs_pct : float, default 98.0 + Percentile of ``|LFC|`` setting the symmetric color limits. + vabs_min : float, default 0.0 + Floor for the symmetric color limit (``max(pct, floor)``). Useful + when most tiles are near-zero and the percentile rounds down. + vmax : float, optional + Override the color limit. If provided, ``vabs_pct`` and + ``vabs_min`` are ignored and the scale spans ``[-vmax, vmax]``. + cmap : str, default ``"RdBu_r"`` + Diverging colormap name. + size_exponent : float, default 1.5 + Exponent applied to the fraction-expressed before scaling to a + scatter ``s`` value. Higher compresses low fractions. + dot_max : float, default 60.0 + Target scatter ``s`` (area in pt²) for a ``frac=1.0`` dot. + dot_edge_color : str, default ``"white"`` + dot_edge_lw : float, default 0.2 + axes : 3-tuple of :class:`matplotlib.axes.Axes`, optional + ``(main, cbar, size_legend)``. If ``None`` a standalone figure is + built with the three axes laid out in a constrained grid. + figsize : tuple, default ``(7.5, 3.2)`` + Size of the standalone figure; ignored when ``axes`` is given. + cbar_title, size_title : str + Titles for the colorbar and size legend. + title, xlabel, ylabel : str, optional + Main-axis annotations. Defaults leave these empty. + gene_label_fontsize, category_label_fontsize : float + Font sizes for the y-axis (genes) and x-axis (categories). + italic_genes : bool, default True + Italicize gene y-labels. + run_id : int, default -1 + Kompot DE run index used when inferring ``lfc_layer`` / + ``score_key``. + return_fig : bool, default False + If ``True`` and ``axes is None``, return the created ``Figure``. + save : str, optional + If given, ``fig.savefig(save, bbox_inches="tight")`` is called. + + Returns + ------- + matplotlib.figure.Figure or None + Figure if ``axes is None`` and ``return_fig`` is ``True``, + otherwise ``None``. + + Examples + -------- + Standalone figure, explicit genes:: + + import kompot + kompot.plot.dotplot( + adata, + genes=["Hbb-bh1", "Hba-x", "Tal1"], + groupby="celltype.mapped", + lfc_layer="kompot_de_WT_to_Tal1_fold_change", + expr_layer="logcounts", + return_fig=True, + ) + + Auto-pick top-20 Mahalanobis hits restricted to ``is_de=True``:: + + kompot.plot.dotplot( + adata, genes=None, groupby="celltype.mapped", + filter_key="kompot_de_WT_to_Tal1_is_de", + n_top=20, return_fig=True, + ) + + Embed into a composite figure:: + + fig = plt.figure(figsize=(8, 4), layout="constrained") + gs = fig.add_gridspec(1, 2, width_ratios=[1.0, 0.14]) + ax_main = fig.add_subplot(gs[0, 0]) + inner = gs[0, 1].subgridspec(2, 1) + ax_size = fig.add_subplot(inner[0, 0]) + ax_cbar = fig.add_subplot(inner[1, 0]) + kompot.plot.dotplot( + adata, genes=top_genes, groupby="celltype.mapped", + lfc_layer="kompot_de_WT_to_Tal1_fold_change", + axes=(ax_main, ax_cbar, ax_size), + ) + """ + # ---- resolve gene list -------------------------------------- + # Reuses `_infer_score_key` + `_prepare_gene_list` from + # heatmap.utils for the common path so score inference, strict/ + # non-strict fallback, and run-info logging stay identical across + # dotplot and heatmap. The `filter_key` branch stays local since + # heatmap has no equivalent (and adding one there is out of scope + # for this PR). + if genes is None: + score_key = _infer_score_key(adata, run_id=run_id, score_key=score_key) + if score_key is None: + raise ValueError( + "genes=None requires a Mahalanobis ranking column, but no " + "`score_key` was provided and none could be inferred from " + "run history. Pass `genes=[...]` or `score_key=...`." + ) + if score_key not in adata.var.columns: + raise KeyError( + f"score_key '{score_key}' not found in adata.var.columns" + ) + if filter_key is not None: + if filter_key not in adata.var.columns: + raise KeyError( + f"filter_key '{filter_key}' not found in adata.var.columns" + ) + mask = adata.var[filter_key].astype(bool).values + if not mask.any(): + raise ValueError( + f"no genes remain after applying filter_key='{filter_key}'" + ) + candidates = adata.var.index[mask] + ranked = ( + adata.var.loc[candidates, score_key] + .astype(float) + .dropna() + .sort_values(ascending=False) + ) + genes = list(ranked.head(n_top).index) + else: + genes, _, _ = _prepare_gene_list( + adata, + var_names=None, + n_top_genes=n_top, + score_key=score_key, + sort_genes=True, + run_id=run_id, + ) + else: + genes = list(genes) + + if not genes: + raise ValueError("no genes to plot") + missing = [g for g in genes if g not in adata.var_names] + if missing: + raise KeyError( + f"{len(missing)} gene(s) not in adata.var_names " + f"(first few: {missing[:5]})" + ) + + # ---- resolve layers ----------------------------------------- + # `_get_expression_matrix` handles the sparse→dense conversion and + # falls back to `adata.X` if the layer is absent (with a warning). + # For `lfc_layer` that fallback is wrong — a fold-change plot with + # raw expression would silently mislead — so we pre-validate. + if lfc_layer is None: + lfc_layer = _infer_lfc_layer(adata, run_id) + if lfc_layer is None: + raise ValueError( + "Could not infer a per-cell fold-change layer from run history. " + "Pass `lfc_layer=` explicitly (e.g. " + "'kompot_de__to__fold_change')." + ) + if lfc_layer not in adata.layers: + raise KeyError( + f"lfc_layer '{lfc_layer}' not found in adata.layers " + f"(available: {list(adata.layers)})" + ) + lfc_values = np.asarray(_get_expression_matrix(adata, genes, layer=lfc_layer)) + expr_values = np.asarray(_get_expression_matrix(adata, genes, layer=expr_layer)) + + # ---- resolve groups ----------------------------------------- + if groupby not in adata.obs.columns: + raise KeyError(f"groupby '{groupby}' not in adata.obs.columns") + obs_series = adata.obs[groupby] + all_cats = list(obs_series.astype("category").cat.categories) + if categories_order is None: + cats = all_cats + else: + cats = [c for c in categories_order if c in all_cats] + obs_labels = obs_series.astype(str).values + cats_str = [str(c) for c in cats] + if min_cells and min_cells > 0: + cats_str = [c for c in cats_str if int((obs_labels == c).sum()) >= min_cells] + if not cats_str: + raise ValueError( + "no categories remain after applying categories_order / min_cells" + ) + + # Per-category aggregation via the shared groupby idiom. + gene_names = [str(g) for g in genes] + lfc_mat = _group_aggregate( + lfc_values, obs_labels, gene_names, cats_str, reducer="mean", + ) + frac_mat = _group_aggregate( + (expr_values > expr_threshold).astype(float), + obs_labels, gene_names, cats_str, reducer="mean", + ) + + # rows = genes, cols = categories (fig-3/fig-4 swap_axes style) + lfc = lfc_mat.T + frac = frac_mat.T + + # ---- color scale -------------------------------------------- + if vmax is None: + finite = lfc[np.isfinite(lfc)] + if finite.size: + vabs = float(np.nanpercentile(np.abs(finite), vabs_pct)) + else: + vabs = 0.0 + vabs = max(vabs, float(vabs_min)) + else: + vabs = float(vmax) + if vabs <= 0: + vabs = 1e-12 # avoid Normalize(vmin=vmax) error; caller warned below + logger.warning( + "dotplot: degenerate color scale (|LFC| ≈ 0 everywhere). " + "Set vabs_min or vmax to force a meaningful range." + ) + + # Delegate norm + cmap resolution to the heatmap helper. This + # returns a `TwoSlopeNorm(vcenter=0, vmin=-vabs, vmax=vabs)` and + # resolves the colormap-string-to-object dance in one place. + norm, cmap_obj, vmin_eff, vmax_eff = _setup_colormap_normalization( + lfc, vcenter=0.0, vmin=-vabs, vmax=vabs, cmap=cmap, + ) + + # ---- dot sizes ---------------------------------------------- + frac_clip = np.clip(frac, 0.0, 1.0) + sizes = np.clip((frac_clip ** size_exponent) * dot_max, 0.0, dot_max) + + # ---- axes --------------------------------------------------- + standalone = axes is None + if standalone: + fig = plt.figure(figsize=figsize, layout="constrained") + gs = fig.add_gridspec(1, 2, width_ratios=[1.0, 0.14]) + ax_main = fig.add_subplot(gs[0, 0]) + inner = gs[0, 1].subgridspec(2, 1, height_ratios=[1.0, 1.0]) + ax_size = fig.add_subplot(inner[0, 0]) + ax_cbar = fig.add_subplot(inner[1, 0]) + else: + if len(axes) != 3: + raise ValueError( + "`axes` must be a 3-tuple (main, cbar, size_legend)" + ) + ax_main, ax_cbar, ax_size = axes + fig = ax_main.figure + + # ---- main --------------------------------------------------- + n_genes, n_cats = lfc.shape + xs, ys = np.meshgrid(np.arange(n_cats), np.arange(n_genes)) + ax_main.scatter( + xs.ravel(), ys.ravel(), + s=sizes.ravel(), c=lfc.ravel(), + cmap=cmap_obj, norm=norm, + edgecolors=dot_edge_color, linewidths=dot_edge_lw, + ) + ax_main.set_xticks(np.arange(n_cats)) + ax_main.set_xticklabels( + cats_str, + rotation=90, fontsize=category_label_fontsize, ha="center", + ) + ax_main.set_yticks(np.arange(n_genes)) + gene_label_kwargs = {"fontsize": gene_label_fontsize} + if italic_genes: + gene_label_kwargs["fontstyle"] = "italic" + ax_main.set_yticklabels(genes, **gene_label_kwargs) + ax_main.set_xlim(-0.5, n_cats - 0.5) + ax_main.set_ylim(n_genes - 0.5, -0.5) + ax_main.tick_params(axis="both", length=0, pad=1) + for spine in ("top", "right"): + ax_main.spines[spine].set_visible(False) + ax_main.grid(False) + if title is not None: + ax_main.set_title(title, pad=2, fontsize=7.5) + if xlabel is not None: + ax_main.set_xlabel(xlabel) + if ylabel is not None: + ax_main.set_ylabel(ylabel) + + # ---- colorbar ----------------------------------------------- + sm = ScalarMappable(norm=norm, cmap=cmap_obj) + cb = fig.colorbar(sm, cax=ax_cbar, orientation="vertical") + cb.locator = MaxNLocator(nbins=3) + cb.update_ticks() + cb.ax.tick_params(labelsize=category_label_fontsize, length=2) + for spine in cb.ax.spines.values(): + spine.set_visible(False) + if cbar_title: + ax_cbar.set_title( + cbar_title, fontsize=category_label_fontsize, pad=3, loc="center", + ) + + # ---- size legend -------------------------------------------- + size_fractions = (0.2, 0.4, 0.6, 0.8, 1.0) + n_leg = len(size_fractions) + ax_size.set_xlim(0, 1) + ax_size.set_ylim(-0.5, n_leg - 0.5) + ax_size.set_xticks([]) + ax_size.set_yticks([]) + for spine in ax_size.spines.values(): + spine.set_visible(False) + # Boost legend dot sizes relative to main so a narrow size-legend axis + # still renders readable dots; matches the fig-4 panel-N rendering. + legend_scale = 3.5 + for i, frac_ in enumerate(size_fractions): + y = n_leg - 1 - i + s_val = (float(frac_) ** size_exponent) * dot_max * legend_scale + ax_size.scatter( + [0.3], [y], s=s_val, c="0.5", + linewidths=0.3, edgecolors=dot_edge_color, zorder=2, + ) + ax_size.text( + 0.55, y, f"{int(round(frac_ * 100))}%", + ha="left", va="center", + fontsize=category_label_fontsize, color="0.2", + ) + if size_title: + ax_size.set_title( + size_title, fontsize=category_label_fontsize, pad=3, loc="center", + ) + + # ---- save / return ------------------------------------------ + if save is not None: + fig.savefig(save, bbox_inches="tight") + + if standalone and return_fig: + return fig + return None diff --git a/tests/test_plot_dotplot.py b/tests/test_plot_dotplot.py new file mode 100644 index 0000000..6be15b1 --- /dev/null +++ b/tests/test_plot_dotplot.py @@ -0,0 +1,382 @@ +"""Tests for kompot.plot.dotplot.""" + +from __future__ import annotations + +import matplotlib + +matplotlib.use("Agg") + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pytest + +try: + from anndata import AnnData + + _has_anndata = True +except ImportError: + _has_anndata = False + +pytestmark = pytest.mark.skipif( + not _has_anndata, reason="anndata is required for these tests" +) + + +def _make_adata(n_obs: int = 90, n_vars: int = 12, seed: int = 0) -> "AnnData": + """Minimal AnnData with an LFC layer, an expression layer, and a + Mahalanobis-like ranking column in var.""" + rng = np.random.default_rng(seed) + expr = np.clip(rng.normal(size=(n_obs, n_vars)) + 0.5, 0.0, None) + var_names = [f"G{i}" for i in range(n_vars)] + obs_names = [f"C{i}" for i in range(n_obs)] + var = pd.DataFrame( + { + "kompot_de_cond_mahalanobis": rng.uniform(0.0, 10.0, size=n_vars), + "kompot_de_cond_is_de": rng.uniform(size=n_vars) > 0.5, + }, + index=var_names, + ) + obs = pd.DataFrame( + {"celltype": pd.Categorical(rng.choice(list("ABC"), size=n_obs))}, + index=obs_names, + ) + adata = AnnData(X=expr, obs=obs, var=var) + adata.layers["lfc"] = rng.normal(size=(n_obs, n_vars)).astype(float) + adata.layers["logcounts"] = expr.astype(float) + return adata + + +def test_dotplot_import_and_exported(): + import kompot.plot as kp + + assert hasattr(kp, "dotplot") + assert "dotplot" in kp.__all__ + + +def test_dotplot_reuses_heatmap_primitives(): + """Sanity-check that the heatmap.utils reuse surface is wired up. + + This is load-bearing: if someone renames the helpers without updating + dotplot, the import will fail here long before the plot does. + """ + import sys + import kompot.plot # noqa: F401 — force module registration + + dp_mod = sys.modules["kompot.plot.dotplot"] + + from kompot.plot.heatmap.utils import ( + _get_expression_matrix, + _infer_score_key, + _prepare_gene_list, + _setup_colormap_normalization, + ) + + assert dp_mod._get_expression_matrix is _get_expression_matrix + assert dp_mod._infer_score_key is _infer_score_key + assert dp_mod._prepare_gene_list is _prepare_gene_list + assert dp_mod._setup_colormap_normalization is _setup_colormap_normalization + + +def test_dotplot_standalone_returns_figure(): + from kompot.plot import dotplot + + adata = _make_adata() + genes = list(adata.var_names[:5]) + fig = dotplot( + adata, + genes=genes, + groupby="celltype", + lfc_layer="lfc", + expr_layer="logcounts", + return_fig=True, + ) + assert isinstance(fig, plt.Figure) + # main + colorbar + size legend (>= 3; matplotlib may add a colorbar ax) + assert len(fig.axes) >= 3 + ax_main = fig.axes[0] + # main axis renders one scatter collection with 5 * 3 = 15 offsets + assert len(ax_main.collections) >= 1 + assert ax_main.collections[0].get_offsets().shape[0] == len(genes) * 3 + plt.close(fig) + + +def test_dotplot_standalone_without_return_fig_returns_none(): + from kompot.plot import dotplot + + adata = _make_adata() + out = dotplot( + adata, + genes=list(adata.var_names[:4]), + groupby="celltype", + lfc_layer="lfc", + expr_layer="logcounts", + return_fig=False, + ) + assert out is None + plt.close("all") + + +def test_dotplot_embeds_into_provided_axes(): + from kompot.plot import dotplot + + adata = _make_adata() + fig = plt.figure(figsize=(6, 3), layout="constrained") + gs = fig.add_gridspec(1, 2, width_ratios=[1.0, 0.14]) + ax_main = fig.add_subplot(gs[0, 0]) + inner = gs[0, 1].subgridspec(2, 1, height_ratios=[1.0, 1.0]) + ax_size = fig.add_subplot(inner[0, 0]) + ax_cbar = fig.add_subplot(inner[1, 0]) + + figs_before = list(plt.get_fignums()) + axes_before = list(fig.axes) + + out = dotplot( + adata, + genes=list(adata.var_names[:4]), + groupby="celltype", + lfc_layer="lfc", + expr_layer="logcounts", + axes=(ax_main, ax_cbar, ax_size), + ) + # caller owns the figure; nothing returned + assert out is None + # no new Figure was created by the function + assert plt.get_fignums() == figs_before + # provided axes are still present and the main axis now has a scatter + for ax in axes_before: + assert ax in fig.axes + assert len(ax_main.collections) >= 1 + assert ax_main.collections[0].get_offsets().shape[0] == 4 * 3 + plt.close(fig) + + +def test_dotplot_rejects_bad_axes_length(): + from kompot.plot import dotplot + + adata = _make_adata() + fig, (ax1, ax2) = plt.subplots(1, 2) + with pytest.raises(ValueError, match="3-tuple"): + dotplot( + adata, + genes=list(adata.var_names[:3]), + groupby="celltype", + lfc_layer="lfc", + expr_layer="logcounts", + axes=(ax1, ax2), + ) + plt.close(fig) + + +def test_dotplot_auto_picks_top_n_by_score(): + from kompot.plot import dotplot + + adata = _make_adata() + n = 4 + fig = dotplot( + adata, + genes=None, + groupby="celltype", + lfc_layer="lfc", + expr_layer="logcounts", + score_key="kompot_de_cond_mahalanobis", + n_top=n, + return_fig=True, + ) + assert isinstance(fig, plt.Figure) + ax_main = fig.axes[0] + ytl = [t.get_text() for t in ax_main.get_yticklabels()] + assert len(ytl) == n + expected = list( + adata.var["kompot_de_cond_mahalanobis"] + .sort_values(ascending=False) + .head(n) + .index + ) + # order matters — top-ranked gene sits at the top of the y-axis + assert ytl == expected + plt.close(fig) + + +def test_dotplot_auto_pick_honors_filter_key(): + from kompot.plot import dotplot + + adata = _make_adata() + eligible = adata.var.index[adata.var["kompot_de_cond_is_de"].astype(bool)] + assert len(eligible) >= 3 + fig = dotplot( + adata, + genes=None, + groupby="celltype", + lfc_layer="lfc", + expr_layer="logcounts", + score_key="kompot_de_cond_mahalanobis", + filter_key="kompot_de_cond_is_de", + n_top=min(3, len(eligible)), + return_fig=True, + ) + ax_main = fig.axes[0] + ytl = [t.get_text() for t in ax_main.get_yticklabels()] + assert set(ytl).issubset(set(eligible)) + plt.close(fig) + + +def test_dotplot_auto_pick_without_score_raises(): + from kompot.plot import dotplot + + # Strip Mahalanobis-like columns so field inference has nothing to latch + # onto — confirms we raise a helpful error rather than silently picking + # random columns. + adata = _make_adata() + adata.var = adata.var.drop( + columns=[c for c in adata.var.columns if "mahalanobis" in c] + ) + with pytest.raises(ValueError, match="Mahalanobis"): + dotplot( + adata, + genes=None, + groupby="celltype", + lfc_layer="lfc", + expr_layer="logcounts", + ) + plt.close("all") + + +def test_dotplot_missing_lfc_layer_raises(): + from kompot.plot import dotplot + + adata = _make_adata() + with pytest.raises(ValueError, match="fold-change layer"): + dotplot( + adata, + genes=list(adata.var_names[:3]), + groupby="celltype", + expr_layer="logcounts", + ) + plt.close("all") + + +def test_dotplot_symmetric_color_scale_about_zero(): + from kompot.plot import dotplot + + adata = _make_adata() + # skew the LFC layer strongly positive and stick a few extreme outliers + # in so the 98th-percentile cap matters. + rng = np.random.default_rng(1) + lfc = rng.normal(loc=2.0, scale=0.5, size=adata.shape) + lfc[:4, 0] = 100.0 + adata.layers["lfc"] = lfc + + fig = dotplot( + adata, + genes=list(adata.var_names[:5]), + groupby="celltype", + lfc_layer="lfc", + expr_layer="logcounts", + vabs_pct=98.0, + return_fig=True, + ) + vmin, vmax = fig.axes[0].collections[0].get_clim() + assert vmax > 0 + assert vmin == pytest.approx(-vmax) + # clip should be below the 100.0 outliers (98th-pct cap) + assert vmax < 100.0 + plt.close(fig) + + +def test_dotplot_vabs_min_floors_scale(): + from kompot.plot import dotplot + + adata = _make_adata() + adata.layers["lfc"] = np.zeros(adata.shape, dtype=float) + fig = dotplot( + adata, + genes=list(adata.var_names[:3]), + groupby="celltype", + lfc_layer="lfc", + expr_layer="logcounts", + vabs_min=0.5, + return_fig=True, + ) + vmin, vmax = fig.axes[0].collections[0].get_clim() + assert vmax == pytest.approx(0.5) + assert vmin == pytest.approx(-0.5) + plt.close(fig) + + +def test_dotplot_vmax_overrides_percentile(): + from kompot.plot import dotplot + + adata = _make_adata() + fig = dotplot( + adata, + genes=list(adata.var_names[:3]), + groupby="celltype", + lfc_layer="lfc", + expr_layer="logcounts", + vmax=2.5, + return_fig=True, + ) + vmin, vmax = fig.axes[0].collections[0].get_clim() + assert vmax == pytest.approx(2.5) + assert vmin == pytest.approx(-2.5) + plt.close(fig) + + +def test_dotplot_min_cells_drops_rare_categories(): + from kompot.plot import dotplot + + adata = _make_adata() + # Force exactly one 'C' cell so min_cells=5 drops the category. + ct = adata.obs["celltype"].astype(str).copy() + ct.iloc[:] = "A" + ct.iloc[:30] = "B" + ct.iloc[0] = "C" + adata.obs["celltype"] = pd.Categorical(ct, categories=["A", "B", "C"]) + + fig = dotplot( + adata, + genes=list(adata.var_names[:3]), + groupby="celltype", + lfc_layer="lfc", + expr_layer="logcounts", + min_cells=5, + return_fig=True, + ) + ax_main = fig.axes[0] + xtl = [t.get_text() for t in ax_main.get_xticklabels()] + assert xtl == ["A", "B"] + plt.close(fig) + + +def test_dotplot_categories_order_filters_and_orders(): + from kompot.plot import dotplot + + adata = _make_adata() + fig = dotplot( + adata, + genes=list(adata.var_names[:3]), + groupby="celltype", + lfc_layer="lfc", + expr_layer="logcounts", + categories_order=["C", "A"], + return_fig=True, + ) + xtl = [t.get_text() for t in fig.axes[0].get_xticklabels()] + assert xtl == ["C", "A"] + plt.close(fig) + + +def test_dotplot_missing_gene_raises(): + from kompot.plot import dotplot + + adata = _make_adata() + with pytest.raises(KeyError): + dotplot( + adata, + genes=["G0", "NOT_A_GENE"], + groupby="celltype", + lfc_layer="lfc", + expr_layer="logcounts", + ) + plt.close("all")