diff --git a/docs/api.md b/docs/api.md index 279070e50..14883f1e7 100644 --- a/docs/api.md +++ b/docs/api.md @@ -92,6 +92,7 @@ Writing a complete {class}`AnnData` object to disk in anndata’s native formats AnnData.write_h5ad AnnData.write_zarr + AnnData.unwriteable .. diff --git a/docs/concatenation.rst b/docs/concatenation.rst index fe4046d75..b22b68d7a 100644 --- a/docs/concatenation.rst +++ b/docs/concatenation.rst @@ -26,7 +26,6 @@ Let's start off with an example: AnnData object with n_obs × n_vars = 700 × 765 obs: 'bulk_labels', 'n_genes', 'percent_mito', 'n_counts', 'S_score', 'G2M_score', 'phase', 'louvain' var: 'n_counts', 'means', 'dispersions', 'dispersions_norm', 'highly_variable' - uns: 'bulk_labels_colors', 'louvain', 'louvain_colors', 'neighbors', 'pca', 'rank_genes_groups' obsm: 'X_pca', 'X_umap' varm: 'PCs' obsp: ... @@ -165,9 +164,9 @@ First, our example case: >>> blobs AnnData object with n_obs × n_vars = 640 × 30 obs: 'blobs' - uns: 'pca' obsm: 'X_pca' varm: 'PCs' + uns: 'pca' Now we will split this object by the categorical `"blobs"` and recombine it to illustrate different merge strategies. @@ -181,9 +180,9 @@ Now we will split this object by the categorical `"blobs"` and recombine it to i >>> adatas[0] AnnData object with n_obs × n_vars = 128 × 30 obs: 'blobs' - uns: 'pca' obsm: 'X_pca', 'qc' varm: 'PCs', '0_qc' + uns: 'pca' `adatas` is now a list of datasets with disjoint sets of observations and a common set of variables. Each object has had QC metrics computed, with observation-wise metrics stored under `"qc"` in `.obsm`, and variable-wise metrics stored with a unique key for each subset. 
diff --git a/docs/release-notes/2372.feat.md b/docs/release-notes/2372.feat.md new file mode 100644 index 000000000..61e9fa0bc --- /dev/null +++ b/docs/release-notes/2372.feat.md @@ -0,0 +1 @@ +New {meth}`anndata.AnnData.unwriteable` for checking if an `AnnData` can be written {user}`ilan-gold` diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index 95dd75c64..3c8574436 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -4,7 +4,7 @@ from __future__ import annotations -from collections import OrderedDict +from collections import OrderedDict, defaultdict from collections.abc import Mapping, MutableMapping, Sequence from copy import copy, deepcopy from functools import singledispatchmethod @@ -26,8 +26,10 @@ from .. import utils from .._settings import settings from ..compat import ( + AwkArray, DaskArray, IndexManager, + XDataset, ZarrArray, _move_adj_mtx, has_xp, @@ -39,6 +41,7 @@ axis_len, deprecation_msg, ensure_df_homogeneous, + iter_outer, raise_value_error_if_multiindex_columns, set_module, warn, @@ -62,9 +65,12 @@ from scipy import sparse from zarr.storage import StoreLike + from anndata.typing import RWAble + + from .._types import ReduceFunc from ..acc import AdRef, Array, MapAcc, RefAcc - from ..compat import XDataset - from ..typing import Index, Index1D, _Index1DNorm, _XDataType + from ..compat import CSArray, CSMatrix + from ..typing import AxisStorable, Index, Index1D, _Index1DNorm, _XDataType from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView @@ -512,53 +518,54 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 def __sizeof__( self, *, show_stratified: bool = False, with_disk: bool = False ) -> int: - def get_size(X) -> int: - def cs_to_bytes(X) -> int: - return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) + def cs_to_bytes(X: CSArray | CSMatrix) -> int: + return int(X.data.nbytes + X.indptr.nbytes + X.indices.nbytes) + def get_size(X: RWAble) -> int: if 
isinstance(X, h5py.Dataset) and with_disk: return int(np.array(X.shape).prod() * X.dtype.itemsize) elif isinstance(X, BaseCompressedSparseDataset) and with_disk: return cs_to_bytes(X._to_backed()) elif issparse(X): return cs_to_bytes(X) + elif isinstance(X, dict | MutableMapping): + return sum(get_size(v) for v in X.values()) else: return X.__sizeof__() - sizes = {} - attrs = ["X", "_obs", "_var"] - attrs_multi = ["_uns", "_obsm", "_varm", "varp", "_obsp", "_layers"] - for attr in attrs + attrs_multi: - if attr in attrs_multi: - keys = getattr(self, attr).keys() - s = sum(get_size(getattr(self, attr)[k]) for k in keys) + def fold_size( + elem: _XDataType | AxisStorable | pd.DataFrame | XDataset, + *, + accumulate: dict[str, int], + attr_name: str | None, # TODO: type + ): + if elem is None: + size = 0 + elif elem is self.raw: + size = ( + get_size(elem.X) + + get_size(elem.var) + + sum(get_size(v) for v in elem.varm.values()) + ) else: - s = get_size(getattr(self, attr)) - if s > 0 and show_stratified: + size = get_size(elem) + accumulate[attr_name] = size + if size > 0 and show_stratified: from tqdm import tqdm - print( - f"Size of {attr.replace('_', '.'):<7}: {tqdm.format_sizeof(s, 'B')}" - ) - sizes[attr] = s - return sum(sizes.values()) + print(f"Size of {attr_name}: {tqdm.format_sizeof(size, 'B')}") + return accumulate + + return sum(self._reduce(fold_size, init=defaultdict(int)).values()) def _gen_repr(self, n_obs, n_vars) -> str: backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else "" descr = f"AnnData object with n_obs × n_vars = {n_obs} × {n_vars}{backed_at}" - for attr in [ - "obs", - "var", - "uns", - "obsm", - "varm", - "layers", - "obsp", - "varp", - ]: - keys = getattr(self, attr).keys() - if len(keys) > 0: - descr += f"\n {attr}: {str(list(keys))[1:-1]}" + for attr_name, elem in iter_outer(self): + if attr_name not in {"raw", "X"}: + keys = elem.keys() + if len(keys) > 0: + descr += f"\n {attr_name}: {str(list(keys))[1:-1]}" return 
descr def __repr__(self) -> str: @@ -1383,27 +1390,16 @@ def to_memory(self, *, copy: bool = False) -> AnnData: mem = backed[backed.obs["cluster"] == "a", :].to_memory() """ new = {} - for attr_name in [ - "X", - "obs", - "var", - "obsm", - "varm", - "obsp", - "varp", - "layers", - "uns", - ]: - attr = getattr(self, attr_name, None) + for attr_name, attr in iter_outer(self): if attr is not None: - new[attr_name] = to_memory(attr, copy=copy) - - if self.raw is not None: - new["raw"] = { - "X": to_memory(self.raw.X, copy=copy), - "var": to_memory(self.raw.var, copy=copy), - "varm": to_memory(self.raw.varm, copy=copy), - } + if attr is self.raw: + new["raw"] = { + "X": to_memory(self.raw.X, copy=copy), + "var": to_memory(self.raw.var, copy=copy), + "varm": to_memory(self.raw.varm, copy=copy), + } + else: + new[attr_name] = to_memory(attr, copy=copy) if self.isbacked: self.file.close() @@ -1436,6 +1432,100 @@ def copy(self, filename: PathLike[str] | str | None = None) -> AnnData: write_h5ad(filename, self) return read_h5ad(filename, backed=mode) + def _reduce[T]( + self, + func: ReduceFunc[T], + *, + init: T, + ) -> T: + """Accumulate a value starting from init by iterating over the parent "elems" of the AnnData object i.e., raw, obs, varp etc. + + Parameters + ---------- + func + The function that performs the accumulation. + init + The starting value + + Returns + ------- + An accumulated value + """ + accumulate = init + for attr_name, attr in iter_outer(self): + accumulate = func(attr, accumulate=accumulate, attr_name=attr_name) + return accumulate + + def unwriteable(self, *, store_type: Literal["h5", "zarr"] | None) -> bool: + """Whether or not an `AnnData` object can be written to disk for a given store type. + + Parameters + ---------- + store_type + Which backing store - `None` indicates that it can be written to either. + + Returns + ------- + Whether or not this object is writeable.
+ While the return type may change to include richer output about which elements cannot be written, + this new type's evaluation as a boolean will not change from the current behavior i.e., + `bool(adata.unwriteable())` will always evaluate the same. + """ + + from anndata._io.specs.registry import _REGISTRY + + writeable_elems = { + src_type + for (dest_type, src_type, __) in _REGISTRY.write + if store_type is None or store_type in dest_type.__module__ + } + + def predicate( # noqa: PLR0911 + elem: RWAble, + *, + accumulate: bool, + attr_name: str | None = None, # TODO: type + ): + if elem is None: + return accumulate + if isinstance(elem, AnnData): + return accumulate and elem.unwriteable(store_type=store_type) + if isinstance(elem, pd.Categorical): + return accumulate and predicate(elem.categories, accumulate=accumulate) + if isinstance(elem, pd.Series | pd.Index): + # matches behavior in methods.py + return accumulate and predicate(elem._values, accumulate=accumulate) + if isinstance(elem, AwkArray): + import awkward as ak + + container = ak.to_buffers(ak.to_packed(elem)) + return accumulate and all( + predicate(v, accumulate=accumulate) for v in container[2].values() + ) + if attr_name == "raw": + accumulate = accumulate and type(elem.X) in writeable_elems + return accumulate and all( + predicate(e[attr], accumulate=accumulate) + for e in [elem.var, elem.varm] + for attr in e + ) + if attr_name in { + "obs", + "obsm", + "varm", + "var", + "layers", + "varp", + "obsp", + "uns", + } or isinstance(elem, pd.DataFrame | XDataset | MutableMapping): + return accumulate and all( + predicate(elem[k], accumulate=accumulate) for k in elem + ) + return accumulate and type(elem) in writeable_elems + + return self._reduce(predicate, init=True) + def var_names_make_unique(self, join: str = "-") -> None: # Important to go through the setter so obsm dataframes are updated too self.var_names = utils.make_index_unique(self.var.index, join) diff --git a/src/anndata/_io/h5ad.py 
b/src/anndata/_io/h5ad.py index d0540de1c..df6272567 100644 --- a/src/anndata/_io/h5ad.py +++ b/src/anndata/_io/h5ad.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +from collections.abc import MutableMapping from functools import partial from pathlib import Path from types import MappingProxyType @@ -23,7 +24,7 @@ _from_fixed_length_strings, ) from ..experimental import read_dispatched -from ..utils import warn +from ..utils import iter_outer, warn from .specs import read_elem, write_elem from .specs.registry import IOSpec, write_spec from .utils import ( @@ -84,23 +85,26 @@ def write_h5ad( f = cast("h5py.Group", f["/"]) f.attrs.setdefault("encoding-type", "anndata") f.attrs.setdefault("encoding-version", "0.1.0") - - _write_x( - f, - adata, # accessing adata.X reopens adata.file if it’s backed - is_backed=adata.isbacked and adata.filename == filepath, - as_dense=as_dense, - dataset_kwargs=dataset_kwargs, - ) - _write_raw(f, adata.raw, as_dense=as_dense, dataset_kwargs=dataset_kwargs) - write_elem(f, "obs", adata.obs, dataset_kwargs=dataset_kwargs) - write_elem(f, "var", adata.var, dataset_kwargs=dataset_kwargs) - write_elem(f, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs) - write_elem(f, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs) - write_elem(f, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs) - write_elem(f, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs) - write_elem(f, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) - write_elem(f, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) + for k, elem in iter_outer(adata): + if k == "X": + _write_x( + f, + adata, # accessing adata.X reopens adata.file if it’s backed + is_backed=adata.isbacked and adata.filename == filepath, + as_dense=as_dense, + dataset_kwargs=dataset_kwargs, + ) + elif k == "raw": + _write_raw( + f, adata.raw, as_dense=as_dense, dataset_kwargs=dataset_kwargs + ) + else: + write_elem( + f, + k, + dict(elem) if 
isinstance(elem, MutableMapping) else elem, + dataset_kwargs=dataset_kwargs, + ) def _write_x( diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index 43b084a00..1dee0e3e0 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Mapping, MutableMapping from copy import copy from functools import partial from itertools import product @@ -41,7 +41,7 @@ from ..._settings import settings from ...compat import PANDAS_STRING_ARRAY_TYPES, PANDAS_SUPPORTS_NA_VALUE -from ...utils import warn +from ...utils import iter_outer, warn from .registry import _REGISTRY, IOSpec, read_elem, read_elem_partial if TYPE_CHECKING: @@ -286,17 +286,14 @@ def write_anndata( dataset_kwargs: Mapping[str, Any] = MappingProxyType({}), ): g = f.require_group(k) - if adata.X is not None: - _writer.write_elem(g, "X", adata.X, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obs", adata.obs, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "var", adata.var, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "raw", adata.raw, dataset_kwargs=dataset_kwargs) + for sub_key, elem in iter_outer(adata): + if not (sub_key == "X" and elem is None): + _writer.write_elem( + g, + sub_key, + dict(elem) if isinstance(elem, MutableMapping) else elem, + dataset_kwargs=dataset_kwargs, + ) @_REGISTRY.register_read(H5Group, IOSpec("anndata", 
"0.1.0")) diff --git a/src/anndata/_types.py b/src/anndata/_types.py index 6006b31c3..514b8b1e1 100644 --- a/src/anndata/_types.py +++ b/src/anndata/_types.py @@ -14,7 +14,10 @@ from collections.abc import Mapping from typing import Any, TypeAlias + from pandas import DataFrame + from anndata._core.xarray import Dataset2D + from anndata.typing import AxisStorable, _XDataType from ._io.specs.registry import ( IOSpec, @@ -23,6 +26,9 @@ Reader, Writer, ) + from ._types import AnnDataElem + from .compat import XDataset + else: # https://github.com/tox-dev/sphinx-autodoc-typehints/issues/580 type S = StorageType type RWAble = typing.RWAble @@ -216,3 +222,29 @@ def __call__( ] type Join_T = Literal["inner", "outer"] + + +class ReduceFunc[T](Protocol): + def __call__( + self, + elem: _XDataType | AxisStorable | DataFrame | XDataset, + *, + accumulate: T, + attr_name: AnnDataElem | None, + ) -> T: + """Function to be called on each visit within `anndata.AnnData._reduce`. + + Parameters + ---------- + elem + The current element. + accumulate + The value being accumulated. + attr_name + The attribute name, to help users distinguish where they are in the `AnnData` object. + + Returns + ------- + An accumulated value + """ + ... 
diff --git a/src/anndata/utils.py b/src/anndata/utils.py index 31b6f1bee..e6758ae3d 100644 --- a/src/anndata/utils.py +++ b/src/anndata/utils.py @@ -19,9 +19,13 @@ from .logging import get_logger if TYPE_CHECKING: - from collections.abc import Callable, Iterable, Mapping, Sequence + from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from typing import Any, LiteralString + from ._core.xarray import Dataset2D + from ._types import AnnDataElem + from .typing import AxisStorable, _XDataType + logger = get_logger(__name__) @@ -435,3 +439,27 @@ def module_get_attr_redirect( return getattr(mod, new_path) msg = f"module {full_old_module_path} has no attribute {attr_name!r}" raise AttributeError(msg) + + +def iter_outer( + adata, +) -> Generator[ + tuple[AnnDataElem, AxisStorable | _XDataType | Dataset2D | pd.DataFrame] +]: + """Iterate over key-value pairs of the parent "elems" like raw, obs, varp, etc.""" + for attr_name in [ + "X", + "obs", + "var", + "obsm", + "varm", + "obsp", + "varp", + "layers", + "uns", + "raw", + ]: + was_closed = adata.isbacked and not adata.file.is_open + yield (attr_name, getattr(adata, attr_name)) + if was_closed: + adata.file.close() diff --git a/tests/test_base.py b/tests/test_base.py index 363bf5aba..2c3ec4d11 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -186,6 +186,31 @@ def test_df_warnings(): adata.X = df +@pytest.mark.parametrize("use_raw", [True, False], ids=["raw", "no_raw"]) +@pytest.mark.parametrize("use_uns", [True, False], ids=["uns", "no_uns"]) +def test_sizeof_print_stratified(capsys, *, use_raw: bool, use_uns: bool): + adata = gen_adata((10, 20)) + if use_uns: + adata.uns = {"foo": np.arange(10), "nested": {"here": np.arange(10)}} + else: + adata.uns = {} + if use_raw: + adata.raw = adata.copy() + adata.__sizeof__(show_stratified=True) + captured = capsys.readouterr() + for attr in [ + "X", + "layers", + "obsm", + "varm", + "obsp", + "varp", + ]: + assert attr in captured.out + assert 
use_uns == ("uns" in captured.out) + assert use_raw == ("raw" in captured.out) + + @pytest.mark.parametrize("attr", ["X", "layers", "obsm", "varm", "obsp", "varp"]) @pytest.mark.parametrize("when", ["init", "assign"]) def test_convert_matrix(attr, when): diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 9926c7cb3..d9ea1e595 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -98,7 +98,7 @@ def dataset_kwargs(request): @pytest.fixture -def rw(backing_h5ad): +def rw(backing_h5ad) -> tuple[ad.AnnData, ad.AnnData]: M, N = 100, 101 orig = gen_adata((M, N), **GEN_ADATA_NO_XARRAY_ARGS) orig.write(backing_h5ad) @@ -125,6 +125,53 @@ def dtype(request): # ------------------------------------------------------------------------------ +@pytest.mark.parametrize("store_type", ["h5", "zarr", None]) +def test_can_write( + rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None +): + adata, _ = rw + assert adata.unwriteable(store_type=store_type) + + +@pytest.mark.parametrize("store_type", ["h5", "zarr", None]) +def test_can_not_write_bad_categorical( + rw: tuple[ad.AnnData, ad.AnnData], store_type: Literal["h5", "zarr"] | None +): + + adata, _ = rw + adata.var["arrow_categorical_array"] = pd.Categorical.from_codes( + [i % 2 for i in range(adata.shape[1])], + categories=pd.arrays.IntervalArray.from_tuples([(0, 10), (20, 30)]), + ) + assert not adata.unwriteable(store_type=store_type) + + +@pytest.mark.parametrize("store_type", ["h5", "zarr", None]) +@pytest.mark.parametrize("should_nest", [True, False], ids=["nest", "no_nest"]) +@pytest.mark.parametrize("parent_elem", ["var", "uns", "raw"]) +def test_can_not_write_with_custom_array( + rw: tuple[ad.AnnData, ad.AnnData], + store_type: Literal["h5", "zarr"] | None, + parent_elem: Literal["obs", "uns", "raw"], + *, + should_nest: bool, +): + import pyarrow as pa + + adata, _ = rw + if parent_elem == "raw": + adata.raw = adata.copy() + getter = lambda adata: getattr(adata, 
parent_elem).var + else: + getter = lambda adata: getattr(adata, parent_elem) + if should_nest: + adata.uns["adata"] = adata.copy() + getter(adata.uns["adata"] if should_nest else adata)["arrow_array"] = ( + pd.arrays.ArrowExtensionArray(pa.array([{"x": 1, "y": True}] * adata.shape[1])) + ) + assert not adata.unwriteable(store_type=store_type) + + @pytest.mark.parametrize("typ", ARRAY_TYPES) def test_readwrite_roundtrip(typ, tmp_path, diskfmt, diskfmt2): pth1 = tmp_path / f"first.{diskfmt}"