diff --git a/docs/user/training.rst b/docs/user/training.rst
index a2741424..28a6dc42 100644
--- a/docs/user/training.rst
+++ b/docs/user/training.rst
@@ -11,6 +11,7 @@ Training
     training.LoggingCallback
     training.Metrics
     training.PCADecodedMetrics
+    training.PCADecodedMetrics2
     training.VAEDecodedMetrics
     training.WandbLogger
     training.CellFlowTrainer
diff --git a/src/cellflow/training/__init__.py b/src/cellflow/training/__init__.py
index 387411d2..d6b6f176 100644
--- a/src/cellflow/training/__init__.py
+++ b/src/cellflow/training/__init__.py
@@ -5,6 +5,7 @@
     LoggingCallback,
     Metrics,
     PCADecodedMetrics,
+    PCADecodedMetrics2,
     VAEDecodedMetrics,
     WandbLogger,
 )
@@ -19,6 +20,6 @@
     "WandbLogger",
     "CallbackRunner",
     "PCADecodedMetrics",
-    "PCADecoder",
+    "PCADecodedMetrics2",
     "VAEDecodedMetrics",
 ]
diff --git a/src/cellflow/training/_callbacks.py b/src/cellflow/training/_callbacks.py
index f1539b79..ae63a863 100644
--- a/src/cellflow/training/_callbacks.py
+++ b/src/cellflow/training/_callbacks.py
@@ -6,6 +6,7 @@
 import jax.tree as jt
 import jax.tree_util as jtu
 import numpy as np
+import scipy.sparse
 
 from cellflow._types import ArrayLike
 from cellflow.metrics._metrics import (
@@ -24,6 +25,7 @@
     "WandbLogger",
     "CallbackRunner",
     "PCADecodedMetrics",
+    "PCADecodedMetrics2",
     "VAEDecodedMetrics",
 ]
 
@@ -321,6 +323,103 @@ def on_log_iteration(
         return metrics
 
+
+class PCADecodedMetrics2(Metrics):
+    """Callback to compute metrics between PCA-decoded predictions and the true validation data during training
+
+    Parameters
+    ----------
+    ref_adata
+        An :class:`~anndata.AnnData` object with the reference data containing
+        ``adata.varm["X_mean"]`` and ``adata.varm["PCs"]``.
+    validation_adata
+        Dictionary where the keys are the names of the datasets given in
+        :func:`~cellflow.model.CellFlow.prepare_validation_data` and the values are the corresponding
+        :class:`~anndata.AnnData` objects.
+    metrics
+        List of metrics to compute. Supported metrics are ``"r_squared"``, ``"mmd"``,
+        ``"sinkhorn_div"``, and ``"e_distance"``.
+    metric_aggregations
+        List of aggregation functions to use for each metric. Supported aggregations are ``"mean"``
+        and ``"median"``.
+    condition_id_key
+        Key in :attr:`~anndata.AnnData.obs` that defines the condition id.
+    layers
+        Key in :attr:`~anndata.AnnData.layers` from which to get the counts.
+        If :obj:`None`, use :attr:`~anndata.AnnData.X`.
+    log_prefix
+        Prefix to add to the log keys.
+    """
+
+    def __init__(
+        self,
+        ref_adata: ad.AnnData,
+        validation_adata: dict[str, ad.AnnData],
+        metrics: list[Literal["r_squared", "mmd", "sinkhorn_div", "e_distance"]],
+        metric_aggregations: list[Literal["mean", "median"]] | None = None,
+        condition_id_key: str = "condition",
+        layers: str | None = None,
+        log_prefix: str = "pca_decoded_2_",
+    ):
+        super().__init__(metrics, metric_aggregations)
+        self.pcs = ref_adata.varm["PCs"]
+        self.means = ref_adata.varm["X_mean"]
+        self.reconstruct_data = lambda x: x @ np.transpose(self.pcs) + np.transpose(self.means)
+        self.validation_adata = validation_adata
+        self.condition_id_key = condition_id_key
+        self.layers = layers
+        self.log_prefix = log_prefix
+
+    def on_log_iteration(
+        self,
+        valid_source_data: dict[str, dict[str, ArrayLike]],
+        valid_true_data: dict[str, dict[str, ArrayLike]],
+        valid_pred_data: dict[str, dict[str, ArrayLike]],
+        solver: _genot.GENOT | _otfm.OTFlowMatching,
+    ) -> dict[str, float]:
+        """Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction.
+
+        Parameters
+        ----------
+        valid_source_data
+            Source data in nested dictionary format with same keys as ``valid_true_data``.
+        valid_true_data
+            Validation data in nested dictionary format with same keys as ``valid_pred_data``; not used here, the ground truth is read from ``validation_adata``.
+        valid_pred_data
+            Predicted data in nested dictionary format with same keys as ``valid_true_data``.
+        solver
+            :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
+            solver with a conditional velocity field.
+        """
+        true_counts = {}
+        for name in self.validation_adata.keys():
+            true_counts[name] = {}
+            conditions_adata = set(self.validation_adata[name].obs[self.condition_id_key].unique())
+            conditions_pred = valid_pred_data[name].keys()
+            for cond in conditions_adata & conditions_pred:
+                condition_mask = self.validation_adata[name].obs[self.condition_id_key] == cond
+                counts = (
+                    self.validation_adata[name][condition_mask].X
+                    if self.layers is None
+                    else self.validation_adata[name][condition_mask].layers[self.layers]
+                )
+                true_counts[name][cond] = counts.toarray() if scipy.sparse.issparse(counts) else counts
+
+        predicted_data_decoded = jtu.tree_map(self.reconstruct_data, valid_pred_data)
+
+        metrics = super().on_log_iteration(valid_source_data, true_counts, predicted_data_decoded, solver)
+        metrics = {f"{self.log_prefix}{k}": v for k, v in metrics.items()}
+        return metrics
+
+    def on_train_end(
+        self,
+        valid_source_data: dict[str, dict[str, ArrayLike]],
+        valid_true_data: dict[str, dict[str, ArrayLike]],
+        valid_pred_data: dict[str, dict[str, ArrayLike]],
+        solver: _genot.GENOT | _otfm.OTFlowMatching,
+    ) -> dict[str, float]:
+        return self.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, solver)
+
+
 class VAEDecodedMetrics(Metrics):
     """Callback to compute metrics on decoded validation data during training
 
diff --git a/src/cellflow/training/_trainer.py b/src/cellflow/training/_trainer.py
index 98a594c1..76f2c332 100644
--- a/src/cellflow/training/_trainer.py
+++ b/src/cellflow/training/_trainer.py
@@ -16,8 +16,6 @@ class CellFlowTrainer:
 
     Parameters
    ----------
-    dataloader
-        Data sampler.
     solver
         :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
         solver with a conditional velocity field.
diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py
index f1346ce6..4459362d 100644
--- a/tests/trainer/test_callbacks.py
+++ b/tests/trainer/test_callbacks.py
@@ -1,6 +1,7 @@
 import anndata as ad
 import jax.numpy as jnp
 import jax.tree_util as jtu
+import numpy as np
 import pytest
 
 
@@ -18,6 +19,33 @@ def test_pca_reconstruction(self, adata_pca: ad.AnnData, metrics):
         assert reconstruction.shape == adata_pca.X.shape
         assert jnp.allclose(reconstruction, adata_pca.layers["counts"])
 
+    @pytest.mark.parametrize("sparse_matrix", [True, False])
+    @pytest.mark.parametrize("layers", [None, "test"])
+    def test_pca_decoded_2(self, adata_pca: ad.AnnData, sparse_matrix, layers):
+        from cellflow.solvers import OTFlowMatching
+        from cellflow.training import PCADecodedMetrics2
+
+        adata_gt = adata_pca.copy()
+        adata_gt.obs["condition"] = np.random.choice(["A", "B"], size=adata_pca.shape[0])
+        if not sparse_matrix:
+            adata_gt.X = adata_gt.X.toarray()
+        if layers is not None:
+            adata_gt.layers[layers] = adata_gt.X.copy()
+
+        decoded_metrics_callback = PCADecodedMetrics2(
+            ref_adata=adata_pca,
+            validation_adata={"test": adata_gt},
+            metrics=["r_squared"],
+            condition_id_key="condition",
+            layers=layers,
+        )
+
+        valid_pred_data = {"test": {"A": np.random.random((2, 10)), "B": np.random.random((2, 10))}}
+
+        res = decoded_metrics_callback.on_log_iteration({}, {}, valid_pred_data, OTFlowMatching)
+        assert "pca_decoded_2_test_r_squared_mean" in res
+        assert isinstance(res["pca_decoded_2_test_r_squared_mean"], float)
+
     @pytest.mark.parametrize("metrics", [["r_squared"]])
     def test_vae_reconstruction(self, metrics):
         from scvi.data import synthetic_iid
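For reference, a minimal usage sketch of the new callback outside of CellFlowTrainer, mirroring ``test_pca_decoded_2`` above. The synthetic shapes, the random values, and the dataset name ``"test"`` are assumptions for illustration only; in practice ``varm["PCs"]`` and ``varm["X_mean"]`` would come from the PCA preprocessing of the reference data.

# Minimal usage sketch of PCADecodedMetrics2 (synthetic data; shapes and names
# are illustrative assumptions, not part of the change set).
import anndata as ad
import numpy as np

from cellflow.solvers import OTFlowMatching
from cellflow.training import PCADecodedMetrics2

rng = np.random.default_rng(0)
n_cells, n_genes, n_pcs = 100, 50, 10

# Reference AnnData carrying the PCA loadings and per-gene means used for decoding.
ref_adata = ad.AnnData(X=rng.random((n_cells, n_genes)))
ref_adata.varm["PCs"] = rng.random((n_genes, n_pcs))
ref_adata.varm["X_mean"] = ref_adata.X.mean(axis=0, keepdims=True).T  # shape (n_genes, 1)

# Validation AnnData holding the true counts and a condition id column.
valid_adata = ref_adata.copy()
valid_adata.obs["condition"] = rng.choice(["A", "B"], size=n_cells)

callback = PCADecodedMetrics2(
    ref_adata=ref_adata,
    validation_adata={"test": valid_adata},
    metrics=["r_squared"],
    condition_id_key="condition",
)

# Predictions live in PCA space; the callback decodes them back to gene space
# and compares them against the counts stored in `validation_adata`.
valid_pred_data = {"test": {"A": rng.random((20, n_pcs)), "B": rng.random((20, n_pcs))}}

# The solver argument is not used by this callback, so the class is passed
# directly, exactly as in test_pca_decoded_2.
metrics = callback.on_log_iteration({}, {}, valid_pred_data, OTFlowMatching)
print(metrics["pca_decoded_2_test_r_squared_mean"])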