diff --git a/src/jobflow/__init__.py b/src/jobflow/__init__.py index 34225c3e..8dd4f8c2 100644 --- a/src/jobflow/__init__.py +++ b/src/jobflow/__init__.py @@ -4,7 +4,7 @@ from jobflow.core.flow import Flow, JobOrder, flow from jobflow.core.job import Job, JobConfig, Response, job from jobflow.core.maker import Maker -from jobflow.core.reference import OnMissing, OutputReference +from jobflow.core.reference import OnMissing, OutputReference, ResolvedReference from jobflow.core.state import CURRENT_JOB from jobflow.core.store import JobStore from jobflow.managers.local import run_locally diff --git a/src/jobflow/core/job.py b/src/jobflow/core/job.py index d00501b2..9c6dfb0a 100644 --- a/src/jobflow/core/job.py +++ b/src/jobflow/core/job.py @@ -36,9 +36,14 @@ class JobConfig(MSONable): Parameters ---------- resolve_references - Whether to resolve any references before the job function is executed. - If ``False`` the unresolved reference objects will be passed into the function - call. + Controls how :obj:`.OutputReference` inputs are handled before the job + function is executed. Accepted values: + + - ``True`` (default): resolve references and pass the resolved values. + - ``False``: pass the unresolved :obj:`.OutputReference` objects to the + function. + - ``"both"``: pass a :obj:`.ResolvedReference` wrapper containing both + the original reference and its resolved value. on_missing_references What to do if the references cannot be resolved. The default is to throw an error. @@ -61,13 +66,23 @@ class JobConfig(MSONable): A :obj:`JobConfig` object. """ - resolve_references: bool = True + resolve_references: bool | str = True on_missing_references: OnMissing = OnMissing.ERROR manager_config: dict = field(default_factory=dict) expose_store: bool = False pass_manager_config: bool = True response_manager_config: dict = field(default_factory=dict) + def __post_init__(self) -> None: + """Verify that resolve_references contains an acceptable value.""" + if not isinstance(self.resolve_references, bool) and ( + self.resolve_references != "both" + ): + raise ValueError( + "resolve_references must be a bool or the string 'both', got " + f"{self.resolve_references!r}" + ) + @overload def job(method: Callable | None = None) -> Callable[..., Job]: @@ -615,7 +630,10 @@ def run(self, store: jobflow.JobStore, job_dir: Path = None) -> Response: CURRENT_JOB.store = store if self.config.resolve_references: - self.resolve_args(store=store) + self.resolve_args( + store=store, + wrap_resolved=self.config.resolve_references == "both", + ) # if Job was created using the job decorator, then access the original function function = getattr(self.function, "original", self.function) @@ -702,6 +720,7 @@ def resolve_args( self, store: jobflow.JobStore, inplace: bool = True, + wrap_resolved: bool = False, ) -> Job: """ Resolve any :obj:`.OutputReference` objects in the input arguments. @@ -714,6 +733,9 @@ def resolve_args( A maggma store to use for resolving references. inplace Update the arguments of the current job or return a new job object. + wrap_resolved + If True, wrap each resolved reference in a :obj:`.ResolvedReference` + exposing both the original reference and its resolved value. Returns ------- @@ -730,12 +752,14 @@ def resolve_args( store, cache=cache, on_missing=self.config.on_missing_references, + wrap_resolved=wrap_resolved, ) resolved_kwargs = find_and_resolve_references( self.function_kwargs, store, cache=cache, on_missing=self.config.on_missing_references, + wrap_resolved=wrap_resolved, ) resolved_args = tuple(resolved_args) diff --git a/src/jobflow/core/reference.py b/src/jobflow/core/reference.py index ab17b3df..84646f07 100644 --- a/src/jobflow/core/reference.py +++ b/src/jobflow/core/reference.py @@ -4,6 +4,7 @@ import contextlib import typing +from dataclasses import dataclass from typing import Any from monty.json import MontyDecoder, MontyEncoder, MSONable, jsanitize @@ -305,6 +306,36 @@ def as_dict(self): } +@dataclass +class ResolvedReference(MSONable): + """ + A wrapper pairing an :obj:`OutputReference` with its resolved value. + + The reference identity is stored as plain ``uuid`` and ``attributes`` fields + rather than as a nested :obj:`OutputReference` object. This avoids the + serialized form containing an ``OutputReference`` that in some cases may be + resolved again on a subsequent pass. + + Parameters + ---------- + uuid + The job uuid to which the output belongs. + attributes + The chain of attribute/index accesses stored on the original reference. + value + The resolved value the reference points to. + """ + + uuid: str + attributes: tuple[tuple[str, Any], ...] + value: Any + + @property + def reference(self) -> OutputReference: + """Reconstruct the original :obj:`OutputReference` on demand.""" + return OutputReference(self.uuid, self.attributes) + + def resolve_references( references: Sequence[OutputReference], store: jobflow.JobStore, @@ -409,6 +440,7 @@ def find_and_resolve_references( cache: dict[str, Any] = None, on_missing: OnMissing = OnMissing.ERROR, deserialize: bool = True, + wrap_resolved: bool = False, ) -> Any: """ Return the input but with all output references replaced with their resolved values. @@ -431,6 +463,9 @@ def find_and_resolve_references( If False, the data extracted from the store will not be deserialized. Note that in this case, if a reference contains a derived property, it cannot be resolved. + wrap_resolved + If True, each resolved reference is wrapped in a :obj:`ResolvedReference` + exposing both the original reference and its resolved value. Returns ------- @@ -449,9 +484,14 @@ def find_and_resolve_references( if isinstance(arg, OutputReference): # if the argument is a reference then stop there - return arg.resolve( + resolved = arg.resolve( store, cache=cache, on_missing=on_missing, deserialize=deserialize ) + if wrap_resolved: + return ResolvedReference( + uuid=arg.uuid, attributes=arg.attributes, value=resolved + ) + return resolved if isinstance(arg, (float, int, str, bool)): # argument is a primitive, we won't find a reference here @@ -476,11 +516,18 @@ def find_and_resolve_references( # replace the references in the arg dict for location, reference in zip(locations, references): - # skip references that have not been resolved, e.g., on missing is PASS - if reference == resolved_references[reference]: + # skip references that have not been resolved. + # If wrap_resolved always wrap in ResolvedReference. + if reference == resolved_references[reference] and not wrap_resolved: continue resolved_reference = resolved_references[reference] + if wrap_resolved: + resolved_reference = ResolvedReference( + uuid=reference.uuid, + attributes=reference.attributes, + value=resolved_reference, + ) set_(encoded_arg, list(location), resolved_reference) # deserialize dict array diff --git a/tests/core/test_job.py b/tests/core/test_job.py index f9416221..e90045d5 100644 --- a/tests/core/test_job.py +++ b/tests/core/test_job.py @@ -237,6 +237,7 @@ def test_job_config(memory_jobstore): JobConfig, OnMissing, OutputReference, + ResolvedReference, Response, ) @@ -272,6 +273,18 @@ def return_arg(arg): response = test_job.run(memory_jobstore) assert response.output is True + # test resolve_references="both": job receives a ResolvedReference + config = JobConfig(resolve_references="both") + test_job = Job(return_arg, function_args=(ref,), config=config) + response = test_job.run(memory_jobstore) + assert isinstance(response.output, ResolvedReference) + assert response.output.reference == ref + assert response.output.value == 5 + + # validate invalid values are rejected + with pytest.raises(ValueError, match="resolve_references must be"): + JobConfig(resolve_references="wrong_value") + ref = OutputReference("xyz") config = JobConfig(on_missing_references=OnMissing.ERROR) test_job = Job(return_arg, function_args=(ref,), config=config) diff --git a/tests/core/test_reference.py b/tests/core/test_reference.py index b91b6fc7..5736ec38 100644 --- a/tests/core/test_reference.py +++ b/tests/core/test_reference.py @@ -391,6 +391,7 @@ def test_find_and_resolve_references(memory_jobstore): from jobflow.core.reference import ( OnMissing, OutputReference, + ResolvedReference, find_and_resolve_references, ) @@ -496,6 +497,58 @@ def plus(self): [ref1, ref3], memory_jobstore, on_missing=OnMissing.ERROR, deserialize=False ) + # test wrap_resolved: single reference + wrapped = find_and_resolve_references(ref1, memory_jobstore, wrap_resolved=True) + assert isinstance(wrapped, ResolvedReference) + assert wrapped.reference == ref1 + assert wrapped.value == 101 + + # test wrap_resolved: nested + wrapped_nested = find_and_resolve_references( + {"a": [ref1, ref2]}, memory_jobstore, wrap_resolved=True + ) + inner = wrapped_nested["a"] + assert isinstance(inner[0], ResolvedReference) + assert inner[0].reference == ref1 + assert inner[0].value == 101 + assert isinstance(inner[1], ResolvedReference) + assert inner[1].reference == ref2 + assert inner[1].value == "xyz" + + # test wrap_resolved with on_missing=PASS + wrapped_pass = find_and_resolve_references( + [ref1, ref3], + memory_jobstore, + on_missing=OnMissing.PASS, + wrap_resolved=True, + ) + assert isinstance(wrapped_pass[0], ResolvedReference) + assert isinstance(wrapped_pass[1], ResolvedReference) + assert wrapped_pass[0].value == 101 + assert wrapped_pass[1].value == ref3 + + # test wrap_resolved with on_missing=NONE + wrapped_pass = find_and_resolve_references( + [ref1, ref3], + memory_jobstore, + on_missing=OnMissing.NONE, + wrap_resolved=True, + ) + assert isinstance(wrapped_pass[0], ResolvedReference) + assert isinstance(wrapped_pass[1], ResolvedReference) + assert wrapped_pass[0].value == 101 + assert wrapped_pass[1].value is None + + # test wrap_resolved with on_missing=ERROR + with pytest.raises(ValueError, match="Could not resolve reference"): + find_and_resolve_references( + [ref1, ref3], + memory_jobstore, + on_missing=OnMissing.ERROR, + deserialize=False, + wrap_resolved=True, + ) + def test_circular_resolve(memory_jobstore): from jobflow.core.reference import OutputReference diff --git a/tests/managers/test_local.py b/tests/managers/test_local.py index e7c9e3c9..661d4ac1 100644 --- a/tests/managers/test_local.py +++ b/tests/managers/test_local.py @@ -469,3 +469,26 @@ def test_external_reference(memory_jobstore, clean_dir, simple_job): assert responses[uuid2][1].output == "12345_end_end" assert isinstance(responses[uuid2][1].job_dir, Path) assert os.path.isdir(responses[uuid2][1].job_dir) + + +def test_resolve_references_both(memory_jobstore, clean_dir): + from jobflow import Flow, JobConfig, job, run_locally + from jobflow.core.reference import OutputReference, ResolvedReference + + @job + def initial(): + return 2 + + @job + def verify(arg): + assert isinstance(arg, ResolvedReference) + assert isinstance(arg.reference, OutputReference) + assert arg.value == 2 + return arg.value * 2 + + job1 = initial() + job2 = verify(job1.output) + job2.config = JobConfig(resolve_references="both") + + responses = run_locally(Flow([job1, job2]), store=memory_jobstore) + assert responses[job2.uuid][1].output == 4