Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/user/training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Training
training.LoggingCallback
training.Metrics
training.PCADecodedMetrics
training.PCADecodedMetrics2
training.VAEDecodedMetrics
training.WandbLogger
training.CellFlowTrainer
3 changes: 2 additions & 1 deletion src/cellflow/training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
LoggingCallback,
Metrics,
PCADecodedMetrics,
PCADecodedMetrics2,
VAEDecodedMetrics,
WandbLogger,
)
Expand All @@ -19,6 +20,6 @@
"WandbLogger",
"CallbackRunner",
"PCADecodedMetrics",
"PCADecoder",
"PCADecodedMetrics2",
"VAEDecodedMetrics",
]
99 changes: 99 additions & 0 deletions src/cellflow/training/_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.tree as jt
import jax.tree_util as jtu
import numpy as np
import scipy

from cellflow._types import ArrayLike
from cellflow.metrics._metrics import (
Expand All @@ -24,6 +25,7 @@
"WandbLogger",
"CallbackRunner",
"PCADecodedMetrics",
"PCADecodedMetrics2",
"VAEDecodedMetrics",
]

Expand Down Expand Up @@ -321,6 +323,103 @@ def on_log_iteration(
return metrics


class PCADecodedMetrics2(Metrics):
    """Callback to compute metrics on true validation data during training.

    Predictions are mapped back to feature space with the PCA loadings and means
    stored in ``ref_adata`` and compared against counts read directly from
    ``validation_adata`` (rather than against the ``valid_true_data`` passed to the
    callback hooks).

    Parameters
    ----------
    ref_adata
        An :class:`~anndata.AnnData` object with the reference data containing
        ``adata.varm["X_mean"]`` and ``adata.varm["PCs"]``.
    validation_adata
        Dictionary where the keys are the names of the datasets given in
        :func:`~cellflow.model.CellFlow.prepare_validation_data` and the values are the corresponding
        :class:`~anndata.AnnData` objects.
    metrics
        List of metrics to compute. Supported metrics are ``"r_squared"``, ``"mmd"``,
        ``"sinkhorn_div"``, and ``"e_distance"``.
    metric_aggregations
        List of aggregation functions to use for each metric. Supported aggregations are ``"mean"``
        and ``"median"``. Passed through unchanged to the parent class.
    condition_id_key
        Key in :attr:`~anndata.AnnData.obs` that defines the condition id.
    layers
        Key in :attr:`~anndata.AnnData.layers` from which to get the counts.
        If :obj:`None`, use :attr:`~anndata.AnnData.X`.
    log_prefix
        Prefix to add to the log keys.
    """

    def __init__(
        self,
        ref_adata: ad.AnnData,
        validation_adata: dict[str, ad.AnnData],
        metrics: list[Literal["r_squared", "mmd", "sinkhorn_div", "e_distance"]],
        metric_aggregations: list[Literal["mean", "median"]] | None = None,
        condition_id_key: str = "condition",
        layers: str | None = None,
        log_prefix: str = "pca_decoded_2_",
    ):
        super().__init__(metrics, metric_aggregations)
        self.pcs = ref_adata.varm["PCs"]
        self.means = ref_adata.varm["X_mean"]
        # Maps PCA-space data back to feature space: x @ PCs^T + mean^T.
        self.reconstruct_data = lambda x: x @ np.transpose(self.pcs) + np.transpose(self.means)
        self.validation_adata = validation_adata
        self.condition_id_key = condition_id_key
        self.layers = layers
        self.log_prefix = log_prefix

    def on_log_iteration(
        self,
        valid_source_data: dict[str, dict[str, ArrayLike]],
        valid_true_data: dict[str, dict[str, ArrayLike]],
        valid_pred_data: dict[str, dict[str, ArrayLike]],
        solver: _genot.GENOT | _otfm.OTFlowMatching,
    ) -> dict[str, float]:
        """Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction

        Parameters
        ----------
        valid_source_data
            Source data in nested dictionary format with same keys as ``valid_true_data``
        valid_true_data
            Validation data in nested dictionary format with same keys as ``valid_pred_data``.
            Not read by this callback; the ground truth is taken from ``validation_adata``.
        valid_pred_data
            Predicted data in nested dictionary format with same keys as ``valid_true_data``.
            Must contain an entry for every dataset name in ``validation_adata``.
        solver
            :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
            solver with a conditional velocity field.

        Returns
        -------
        The metrics computed by the parent class, with every key prefixed by ``log_prefix``.
        """
        # Import the submodule explicitly: a bare ``import scipy`` does not guarantee
        # that ``scipy.sparse`` has been loaded on every SciPy version.
        import scipy.sparse

        true_counts: dict[str, dict[str, ArrayLike]] = {}
        for name, adata in self.validation_adata.items():
            condition_ids = adata.obs[self.condition_id_key]
            # Ground truth is only collected for conditions that were also predicted.
            shared_conditions = set(condition_ids.unique()) & valid_pred_data[name].keys()
            true_counts[name] = {}
            for cond in shared_conditions:
                subset = adata[condition_ids == cond]
                counts = subset.X if self.layers is None else subset.layers[self.layers]
                # Sparse matrices are densified; anything else is passed on as is.
                true_counts[name][cond] = counts.toarray() if scipy.sparse.issparse(counts) else counts

        predicted_data_decoded = jtu.tree_map(self.reconstruct_data, valid_pred_data)

        metrics = super().on_log_iteration(valid_source_data, true_counts, predicted_data_decoded, solver)
        return {f"{self.log_prefix}{key}": value for key, value in metrics.items()}

    def on_train_end(
        self,
        valid_source_data: dict[str, dict[str, ArrayLike]],
        valid_true_data: dict[str, dict[str, ArrayLike]],
        valid_pred_data: dict[str, dict[str, ArrayLike]],
        solver: _genot.GENOT | _otfm.OTFlowMatching,
    ) -> dict[str, float]:
        """Called at the end of training; delegates to :meth:`on_log_iteration` with the same arguments."""
        return self.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, solver)


class VAEDecodedMetrics(Metrics):
"""Callback to compute metrics on decoded validation data during training

Expand Down
2 changes: 0 additions & 2 deletions src/cellflow/training/_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class CellFlowTrainer:

Parameters
----------
dataloader
Data sampler.
solver
:class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
solver with a conditional velocity field.
Expand Down
28 changes: 28 additions & 0 deletions tests/trainer/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import anndata as ad
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import pytest


Expand All @@ -18,6 +19,33 @@ def test_pca_reconstruction(self, adata_pca: ad.AnnData, metrics):
assert reconstruction.shape == adata_pca.X.shape
assert jnp.allclose(reconstruction, adata_pca.layers["counts"])

@pytest.mark.parametrize("sparse_matrix", [True, False])
@pytest.mark.parametrize("layers", [None, "test"])
def test_pca_decoded_2(self, adata_pca: ad.AnnData, sparse_matrix, layers):
    """Check that ``PCADecodedMetrics2`` logs a float R^2 for counts read from ``X`` or a layer."""
    from cellflow.solvers import OTFlowMatching
    from cellflow.training import PCADecodedMetrics2

    # Seeded generator keeps the test reproducible across runs.
    rng = np.random.default_rng(0)

    adata_gt = adata_pca.copy()
    # Alternate the labels deterministically so that both predicted conditions
    # are guaranteed to have ground-truth cells.
    adata_gt.obs["condition"] = np.array(["A", "B"])[np.arange(adata_pca.shape[0]) % 2]
    if not sparse_matrix:
        # NOTE(review): assumes the fixture stores ``X`` as a sparse matrix - confirm.
        adata_gt.X = adata_gt.X.toarray()
    if layers is not None:
        adata_gt.layers[layers] = adata_gt.X.copy()

    decoded_metrics_callback = PCADecodedMetrics2(
        ref_adata=adata_pca,
        validation_adata={"test": adata_gt},
        metrics=["r_squared"],
        condition_id_key="condition",
        layers=layers,
    )

    # The callback computes ``x @ PCs.T``, so predictions need one column per PC.
    n_pcs = adata_pca.varm["PCs"].shape[1]
    valid_pred_data = {"test": {"A": rng.random((2, n_pcs)), "B": rng.random((2, n_pcs))}}

    res = decoded_metrics_callback.on_log_iteration({}, {}, valid_pred_data, OTFlowMatching)
    assert "pca_decoded_2_test_r_squared_mean" in res
    assert isinstance(res["pca_decoded_2_test_r_squared_mean"], float)

@pytest.mark.parametrize("metrics", [["r_squared"]])
def test_vae_reconstruction(self, metrics):
from scvi.data import synthetic_iid
Expand Down
Loading