Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/jobflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jobflow.core.flow import Flow, JobOrder, flow
from jobflow.core.job import Job, JobConfig, Response, job
from jobflow.core.maker import Maker
from jobflow.core.reference import OnMissing, OutputReference
from jobflow.core.reference import OnMissing, OutputReference, ResolvedReference
from jobflow.core.state import CURRENT_JOB
from jobflow.core.store import JobStore
from jobflow.managers.local import run_locally
Expand Down
34 changes: 29 additions & 5 deletions src/jobflow/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,14 @@ class JobConfig(MSONable):
Parameters
----------
resolve_references
Whether to resolve any references before the job function is executed.
If ``False`` the unresolved reference objects will be passed into the function
call.
Controls how :obj:`.OutputReference` inputs are handled before the job
function is executed. Accepted values:

- ``True`` (default): resolve references and pass the resolved values.
- ``False``: pass the unresolved :obj:`.OutputReference` objects to the
function.
- ``"both"``: pass a :obj:`.ResolvedReference` wrapper containing both
the original reference and its resolved value.
on_missing_references
What to do if the references cannot be resolved. The default is to throw an
error.
Expand All @@ -61,13 +66,23 @@ class JobConfig(MSONable):
A :obj:`JobConfig` object.
"""

resolve_references: bool = True
resolve_references: bool | str = True
on_missing_references: OnMissing = OnMissing.ERROR
manager_config: dict = field(default_factory=dict)
expose_store: bool = False
pass_manager_config: bool = True
response_manager_config: dict = field(default_factory=dict)

def __post_init__(self) -> None:
    """
    Check that ``resolve_references`` holds a supported value.

    Raises
    ------
    ValueError
        If ``resolve_references`` is neither a bool nor the string ``"both"``.
    """
    setting = self.resolve_references

    # any bool is valid; nothing more to check
    if isinstance(setting, bool):
        return

    # the only accepted non-bool value is the literal string "both"
    if setting != "both":
        raise ValueError(
            "resolve_references must be a bool or the string 'both', got "
            f"{setting!r}"
        )


@overload
def job(method: Callable | None = None) -> Callable[..., Job]:
Expand Down Expand Up @@ -615,7 +630,10 @@ def run(self, store: jobflow.JobStore, job_dir: Path = None) -> Response:
CURRENT_JOB.store = store

if self.config.resolve_references:
self.resolve_args(store=store)
self.resolve_args(
store=store,
wrap_resolved=self.config.resolve_references == "both",
)

# if Job was created using the job decorator, then access the original function
function = getattr(self.function, "original", self.function)
Expand Down Expand Up @@ -702,6 +720,7 @@ def resolve_args(
self,
store: jobflow.JobStore,
inplace: bool = True,
wrap_resolved: bool = False,
) -> Job:
"""
Resolve any :obj:`.OutputReference` objects in the input arguments.
Expand All @@ -714,6 +733,9 @@ def resolve_args(
A maggma store to use for resolving references.
inplace
Update the arguments of the current job or return a new job object.
wrap_resolved
If True, wrap each resolved reference in a :obj:`.ResolvedReference`
exposing both the original reference and its resolved value.

Returns
-------
Expand All @@ -730,12 +752,14 @@ def resolve_args(
store,
cache=cache,
on_missing=self.config.on_missing_references,
wrap_resolved=wrap_resolved,
)
resolved_kwargs = find_and_resolve_references(
self.function_kwargs,
store,
cache=cache,
on_missing=self.config.on_missing_references,
wrap_resolved=wrap_resolved,
)
resolved_args = tuple(resolved_args)

Expand Down
53 changes: 50 additions & 3 deletions src/jobflow/core/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import contextlib
import typing
from dataclasses import dataclass
from typing import Any

from monty.json import MontyDecoder, MontyEncoder, MSONable, jsanitize
Expand Down Expand Up @@ -305,6 +306,36 @@ def as_dict(self):
}


@dataclass
class ResolvedReference(MSONable):
    """
    Pair the identity of an :obj:`OutputReference` with its resolved value.

    Only the ``uuid`` and ``attributes`` of the reference are kept, as plain
    fields, instead of a nested :obj:`OutputReference` object. This keeps an
    ``OutputReference`` out of the serialized form, where in some cases it
    may be resolved again on a subsequent pass.

    Parameters
    ----------
    uuid
        The job uuid to which the output belongs.
    attributes
        The chain of attribute/index accesses stored on the original reference.
    value
        The resolved value the reference points to.
    """

    uuid: str
    attributes: tuple[tuple[str, Any], ...]
    value: Any

    @property
    def reference(self) -> OutputReference:
        """Build an :obj:`OutputReference` from the stored uuid and attributes."""
        job_uuid, attribute_chain = self.uuid, self.attributes
        return OutputReference(job_uuid, attribute_chain)


def resolve_references(
references: Sequence[OutputReference],
store: jobflow.JobStore,
Expand Down Expand Up @@ -409,6 +440,7 @@ def find_and_resolve_references(
cache: dict[str, Any] = None,
on_missing: OnMissing = OnMissing.ERROR,
deserialize: bool = True,
wrap_resolved: bool = False,
) -> Any:
"""
Return the input but with all output references replaced with their resolved values.
Expand All @@ -431,6 +463,9 @@ def find_and_resolve_references(
If False, the data extracted from the store will not be deserialized.
Note that in this case, if a reference contains a derived property,
it cannot be resolved.
wrap_resolved
If True, each resolved reference is wrapped in a :obj:`ResolvedReference`
exposing both the original reference and its resolved value.

Returns
-------
Expand All @@ -449,9 +484,14 @@ def find_and_resolve_references(

if isinstance(arg, OutputReference):
# if the argument is a reference then stop there
return arg.resolve(
resolved = arg.resolve(
store, cache=cache, on_missing=on_missing, deserialize=deserialize
)
if wrap_resolved:
return ResolvedReference(
uuid=arg.uuid, attributes=arg.attributes, value=resolved
)
return resolved

if isinstance(arg, (float, int, str, bool)):
# argument is a primitive, we won't find a reference here
Expand All @@ -476,11 +516,18 @@ def find_and_resolve_references(

# replace the references in the arg dict
for location, reference in zip(locations, references):
# skip references that have not been resolved, e.g., on missing is PASS
if reference == resolved_references[reference]:
# skip references that have not been resolved (e.g., on missing is PASS),
# unless wrap_resolved is set, in which case every reference is wrapped in a
# ResolvedReference.
if reference == resolved_references[reference] and not wrap_resolved:
continue

resolved_reference = resolved_references[reference]
if wrap_resolved:
resolved_reference = ResolvedReference(
uuid=reference.uuid,
attributes=reference.attributes,
value=resolved_reference,
)
set_(encoded_arg, list(location), resolved_reference)

# deserialize dict array
Expand Down
13 changes: 13 additions & 0 deletions tests/core/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def test_job_config(memory_jobstore):
JobConfig,
OnMissing,
OutputReference,
ResolvedReference,
Response,
)

Expand Down Expand Up @@ -272,6 +273,18 @@ def return_arg(arg):
response = test_job.run(memory_jobstore)
assert response.output is True

# test resolve_references="both": job receives a ResolvedReference
config = JobConfig(resolve_references="both")
test_job = Job(return_arg, function_args=(ref,), config=config)
response = test_job.run(memory_jobstore)
assert isinstance(response.output, ResolvedReference)
assert response.output.reference == ref
assert response.output.value == 5

# validate invalid values are rejected
with pytest.raises(ValueError, match="resolve_references must be"):
JobConfig(resolve_references="wrong_value")

ref = OutputReference("xyz")
config = JobConfig(on_missing_references=OnMissing.ERROR)
test_job = Job(return_arg, function_args=(ref,), config=config)
Expand Down
53 changes: 53 additions & 0 deletions tests/core/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def test_find_and_resolve_references(memory_jobstore):
from jobflow.core.reference import (
OnMissing,
OutputReference,
ResolvedReference,
find_and_resolve_references,
)

Expand Down Expand Up @@ -496,6 +497,58 @@ def plus(self):
[ref1, ref3], memory_jobstore, on_missing=OnMissing.ERROR, deserialize=False
)

# test wrap_resolved: single reference
wrapped = find_and_resolve_references(ref1, memory_jobstore, wrap_resolved=True)
assert isinstance(wrapped, ResolvedReference)
assert wrapped.reference == ref1
assert wrapped.value == 101

# test wrap_resolved: nested
wrapped_nested = find_and_resolve_references(
{"a": [ref1, ref2]}, memory_jobstore, wrap_resolved=True
)
inner = wrapped_nested["a"]
assert isinstance(inner[0], ResolvedReference)
assert inner[0].reference == ref1
assert inner[0].value == 101
assert isinstance(inner[1], ResolvedReference)
assert inner[1].reference == ref2
assert inner[1].value == "xyz"

# test wrap_resolved with on_missing=PASS
wrapped_pass = find_and_resolve_references(
[ref1, ref3],
memory_jobstore,
on_missing=OnMissing.PASS,
wrap_resolved=True,
)
assert isinstance(wrapped_pass[0], ResolvedReference)
assert isinstance(wrapped_pass[1], ResolvedReference)
assert wrapped_pass[0].value == 101
assert wrapped_pass[1].value == ref3

# test wrap_resolved with on_missing=NONE
wrapped_pass = find_and_resolve_references(
[ref1, ref3],
memory_jobstore,
on_missing=OnMissing.NONE,
wrap_resolved=True,
)
assert isinstance(wrapped_pass[0], ResolvedReference)
assert isinstance(wrapped_pass[1], ResolvedReference)
assert wrapped_pass[0].value == 101
assert wrapped_pass[1].value is None

# test wrap_resolved with on_missing=ERROR
with pytest.raises(ValueError, match="Could not resolve reference"):
find_and_resolve_references(
[ref1, ref3],
memory_jobstore,
on_missing=OnMissing.ERROR,
deserialize=False,
wrap_resolved=True,
)


def test_circular_resolve(memory_jobstore):
from jobflow.core.reference import OutputReference
Expand Down
23 changes: 23 additions & 0 deletions tests/managers/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,26 @@ def test_external_reference(memory_jobstore, clean_dir, simple_job):
assert responses[uuid2][1].output == "12345_end_end"
assert isinstance(responses[uuid2][1].job_dir, Path)
assert os.path.isdir(responses[uuid2][1].job_dir)


def test_resolve_references_both(memory_jobstore, clean_dir):
    """Run a two-job flow whose second job uses ``resolve_references="both"``."""
    from jobflow import Flow, JobConfig, job, run_locally
    from jobflow.core.reference import OutputReference, ResolvedReference

    @job
    def initial():
        return 2

    @job
    def verify(arg):
        # the checks run inside the job, on the argument it actually receives
        assert isinstance(arg, ResolvedReference)
        assert isinstance(arg.reference, OutputReference)
        assert arg.value == 2
        return arg.value * 2

    producer = initial()
    consumer = verify(producer.output)
    consumer.config = JobConfig(resolve_references="both")

    flow = Flow([producer, consumer])
    results = run_locally(flow, store=memory_jobstore)
    assert results[consumer.uuid][1].output == 4
Loading