Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions src/jobflow/core/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,17 @@

from jobflow.core.schemas import JobStoreDocument

obj_type = Union[str, Enum, type[MSONable], list[Union[Enum, str, type[MSONable]]]]
save_type = Optional[dict[str, obj_type]]
load_type = Union[bool, dict[str, Union[bool, obj_type]]]
obj_save_type = Union[
str,
Enum,
type[MSONable],
list[Union[Enum, str, type[MSONable], list[Union[str, type[MSONable]]]]],
]
save_type = Optional[dict[str, obj_save_type]]
obj_load_type = Union[
str, Enum, type[MSONable], list[Union[Enum, str, type[MSONable]]]
]
load_type = Union[bool, dict[str, Union[bool, obj_load_type]]]


T = typing.TypeVar("T", bound="JobStore")
Expand All @@ -43,6 +51,9 @@ class JobStore(Store):
Which items to save in additional stores when uploading documents. Given as a
mapping of ``{store name: store_type}`` where ``store_type`` can be a dictionary
key (string or enum), an :obj:`.MSONable` class, or a list of keys/classes.
If the list of keys/classes itself contains a list, this will be treated
as the path to the save key. Can be used if a key is duplicate in the output
and only a single occurrence shall be put in the additional store.
load
Which items to load from additional stores when querying documents. Given as a
mapping of ``{store name: store_type}`` where ``store_type`` can be `True``, in
Expand Down Expand Up @@ -304,7 +315,10 @@ def update(
locations = []
for store_name, store_save in save_keys.items():
for save_key in store_save:
locations.extend(find_key(doc, save_key, include_end=True))
if isinstance(save_key, list):
locations.append(save_key)
else:
locations.extend(find_key(doc, save_key, include_end=True))

locations = get_root_locations(locations)
objects = [get(doc, list(loc)) for loc in locations]
Expand Down Expand Up @@ -726,7 +740,7 @@ def _prepare_load(

def _prepare_save(
save: bool | save_type,
) -> dict[str, list[str | type[MSONable]]]:
) -> dict[str, list[str | type[MSONable] | list[str | type[MSONable]]]]:
"""Standardize save type."""
from enum import Enum

Expand Down
Loading