From 5d3d815ff6e4aad649851e0917bc436e815df36a Mon Sep 17 00:00:00 2001 From: Guido Petretto Date: Tue, 10 Feb 2026 12:53:33 +0100 Subject: [PATCH 1/2] introducing flow output --- src/jobflow/core/flow.py | 196 ++++++++++++++++++++++++++++++---- src/jobflow/core/job.py | 71 ++++++++---- src/jobflow/core/maker.py | 26 +++++ src/jobflow/core/schemas.py | 45 +++++++- src/jobflow/core/store.py | 52 +++++++++ src/jobflow/managers/local.py | 70 ++++++++++-- src/jobflow/utils/graph.py | 35 ++++++ src/jobflow/utils/hosts.py | 68 ++++++++++++ 8 files changed, 518 insertions(+), 45 deletions(-) create mode 100644 src/jobflow/utils/hosts.py diff --git a/src/jobflow/core/flow.py b/src/jobflow/core/flow.py index 17fb2fd7..1f875515 100644 --- a/src/jobflow/core/flow.py +++ b/src/jobflow/core/flow.py @@ -7,13 +7,15 @@ from contextlib import contextmanager from contextvars import ContextVar from copy import deepcopy +from datetime import datetime from typing import TYPE_CHECKING from monty.json import MSONable import jobflow -from jobflow.core.reference import find_and_get_references +from jobflow.core.reference import OutputReference, find_and_get_references from jobflow.utils import ValueEnum, contains_flow_or_job, suid +from jobflow.utils.hosts import normalize_hosts if TYPE_CHECKING: from collections.abc import Iterator, Sequence @@ -21,7 +23,7 @@ from networkx import DiGraph - from jobflow import Job + from jobflow import Job, JobStore, Maker logger = logging.getLogger(__name__) @@ -134,9 +136,13 @@ def __init__( name: str = "Flow", order: JobOrder = JobOrder.AUTO, uuid: str = None, - hosts: list[str] = None, + hosts: list[tuple[str, int]] = None, metadata: dict[str, Any] = None, metadata_updates: list[dict[str, Any]] = None, + maker: Maker | None = None, + make_args: list | None = None, + make_kwargs: dict | None = None, + index: int = 1, ): from jobflow.core.job import Job @@ -149,13 +155,22 @@ def __init__( self.name = name self.order = order self.uuid = uuid - self.hosts = hosts or [] + self.index = index + self.hosts = normalize_hosts(hosts) self.metadata = metadata or {} self.metadata_updates = metadata_updates or [] self._jobs: tuple[Flow | Job, ...] = () self.add_jobs(jobs) + # keep track of the references to the real output and + # prepare an OutputReference for the Flow. self.output = output + # TODO output schema? + self._output_reference = OutputReference(self.uuid) + + self.maker = maker + self.make_args = make_args + self.make_kwargs = make_kwargs # If we're running inside a `DecoratedFlow`, add *this* Flow to the # context. @@ -232,6 +247,21 @@ def __hash__(self) -> int: """Get the hash of the flow.""" return hash(self.uuid) + def as_dict(self) -> dict: + """ + Create a JSON serializable dict representation of an object. + + Returns + ------- + dict + The serialized version of the object + """ + d = super().as_dict() + # replace the output with the dereferenced, otherwise it will store the + # reference to itself + d["output"] = self.output_dereferenced + return d + @property def jobs(self) -> tuple[Flow | Job, ...]: """ @@ -268,7 +298,7 @@ def output(self) -> Any: Any The output of the flow. """ - return self._output + return self._output_reference @output.setter def output(self, output: Any): @@ -304,6 +334,45 @@ def output(self, output: Any): ) self._output = output + @property + def output_reference(self) -> OutputReference: + """ + The Flow output reference. + + Returns + ------- + OutputReference + The Flow output reference. + """ + return self._output_reference + + @property + def output_dereferenced(self) -> Any: + """ + The references to real output of the Flow. + + Returns + ------- + Any + The references to real output of the Flow. + """ + return self._output + + def set_uuid_index(self, uuid: str, index: int) -> None: + """ + Set the UUID of the job. + + Parameters + ---------- + uuid + A UUID. + """ + for job_or_flow in self.jobs: + job_or_flow.replace_host((self.uuid, self.index), (uuid, index)) + self.uuid = uuid + self.index = index + self.output.set_uuid(uuid) + @property def job_uuids(self) -> tuple[str, ...]: """ @@ -322,6 +391,19 @@ def job_uuids(self) -> tuple[str, ...]: uuids.append(job.uuid) return tuple(uuids) + @property + def job_flow_uuids(self) -> tuple[str, ...]: + """ + Uuids of every Job contained in the Flow (including nested Flows). + + Returns + ------- + tuple[str] + The uuids of all Jobs in the Flow (including nested Flows). + """ + uuids: list[str] = [job.uuid for job in self] + return tuple(uuids) + @property def all_uuids(self) -> tuple[str, ...]: """ @@ -380,7 +462,30 @@ def graph(self) -> DiGraph: return graph @property - def host(self) -> str | None: + def hierarchy_tree(self) -> DiGraph: + """ + Generate a hierarchy tree of the elements in the Flow. + + Returns + ------- + DiGraph + The graph with the hierarchy tree. + """ + import networkx as nx + + sub_trees = [job.hierarchy_tree for job in self.jobs] + + tree = nx.compose_all(sub_trees) if sub_trees else nx.DiGraph() + + tree.add_node(self) + + for job in self.jobs: + tree.add_edge(self, job) + + return tree + + @property + def host(self) -> tuple[str, int] | None: """ UUID of the first Flow that contains this Flow. @@ -391,6 +496,29 @@ def host(self) -> str | None: """ return self.hosts[0] if self.hosts else None + def replace_host(self, old_host: tuple[str, int], new_host: tuple[str, int]): + """ + Replace the uuid of an host if present. + + Applied also to all the inner Jobs/Flows. + + Parameters + ---------- + old_host + The host to be replaced, + new_host + The new host. + """ + old_host = tuple(old_host) # type: ignore + new_host = tuple(new_host) # type: ignore + try: + i = self.hosts.index(old_host) + self.hosts[i] = new_host + for job in self.jobs: + job.replace_host(old_host, new_host) + except ValueError: + pass + def draw_graph(self, **kwargs): """ Draw the flow graph using matplotlib. @@ -783,7 +911,9 @@ def update_config( ) def add_hosts_uuids( - self, hosts_uuids: str | list[str] = None, prepend: bool = False + self, + hosts: tuple[str, int] | list[tuple[str, int]] = None, + prepend: bool = False, ): """ Add a list of UUIDs to the internal list of hosts. @@ -797,23 +927,22 @@ def add_hosts_uuids( Parameters ---------- - hosts_uuids + hosts A list of UUIDs to add. If None the current uuid of the flow will be added to the inner Flows and Jobs. prepend Insert the UUIDs at the beginning of the list rather than extending it. """ - if hosts_uuids is not None: - if not isinstance(hosts_uuids, (list, tuple)): - hosts_uuids = [hosts_uuids] + hosts = normalize_hosts(hosts) + if hosts: if prepend: - self.hosts[0:0] = hosts_uuids + self.hosts[0:0] = hosts else: - self.hosts.extend(hosts_uuids) + self.hosts.extend(hosts) else: - hosts_uuids = [self.uuid] + hosts = [(self.uuid, self.index)] for job in self: - job.add_hosts_uuids(hosts_uuids, prepend=prepend) + job.add_hosts_uuids(hosts, prepend=prepend) def add_jobs(self, jobs: Job | Flow | Sequence[Flow | Job]) -> None: """ @@ -831,9 +960,9 @@ def add_jobs(self, jobs: Job | Flow | Sequence[Flow | Job]) -> None: jobs = [jobs] # type: ignore[list-item] job_ids = set(self.all_uuids) - hosts = [self.uuid, *self.hosts] + hosts = [(self.uuid, self.index), *self.hosts] for job in jobs: - if job.host is not None and job.host != self.uuid: + if job.host is not None and tuple(job.host) != (self.uuid, self.index): raise ValueError( f"{type(job).__name__} {job.name} ({job.uuid}) already belongs " f"to another flow: {job.host}." @@ -850,7 +979,7 @@ def add_jobs(self, jobs: Job | Flow | Sequence[Flow | Job]) -> None: f"current Flow ({self.uuid})" ) job_ids.add(job.uuid) - if job.host != self.uuid: + if not job.host or tuple(job.host) != (self.uuid, self.index): job.add_hosts_uuids(hosts) self._jobs += tuple(jobs) @@ -919,7 +1048,7 @@ def get_flow( # ensure that we have all the jobs needed to resolve the reference connections job_references = find_and_get_references(flow.jobs) job_reference_uuids = {ref.uuid for ref in job_references} - missing_jobs = job_reference_uuids.difference(set(flow.job_uuids)) + missing_jobs = job_reference_uuids.difference(set(flow.all_uuids)) if len(missing_jobs) > 0: raise ValueError( "The following jobs were not found in the jobs array and are needed to " @@ -1030,3 +1159,32 @@ def flow_build_context(children_list): _current_flow_context = ContextVar("current_flow_context", default=None) + + +def store_flow_output(store: JobStore, flow: Flow): + """ + Add the output of a Flow to the Store. + + Parameters + ---------- + store + flow + """ + from jobflow.core.schemas import JobStoreDocument, MakerData + + maker = None + if flow.maker: + maker = MakerData( + maker=flow.maker, args=flow.make_args, kwargs=flow.make_kwargs + ) + data: JobStoreDocument = JobStoreDocument( + uuid=flow.uuid, + index=flow.index, + output=flow.output_dereferenced, + completed_at=datetime.now().isoformat(), + metadata=flow.metadata, + hosts=flow.hosts, + name=flow.name, + maker=maker, + ) + store.update(data, key=["uuid", "index"]) diff --git a/src/jobflow/core/job.py b/src/jobflow/core/job.py index aa084923..72e65d1a 100644 --- a/src/jobflow/core/job.py +++ b/src/jobflow/core/job.py @@ -13,6 +13,7 @@ from jobflow.core.flow import _current_flow_context from jobflow.core.reference import OnMissing, OutputReference +from jobflow.utils.hosts import normalize_hosts from jobflow.utils.uid import suid if typing.TYPE_CHECKING: @@ -331,7 +332,7 @@ def __init__( name: str = None, metadata: dict[str, Any] = None, config: JobConfig = None, - hosts: list[str] = None, + hosts: list[tuple[str, int]] = None, metadata_updates: list[dict[str, Any]] = None, config_updates: list[dict[str, Any]] = None, name_updates: list[dict[str, Any]] = None, @@ -356,7 +357,7 @@ def __init__( self.name = name self.metadata = metadata or {} self.config = config - self.hosts = hosts or [] + self.hosts = normalize_hosts(hosts) self.metadata_updates = metadata_updates or [] self.name_updates = name_updates or [] self.config_updates = config_updates or [] @@ -548,6 +549,22 @@ def graph(self) -> DiGraph: graph.add_edges_from(edges) return graph + @property + def hierarchy_tree(self) -> DiGraph: + """ + Generate the Job node of the hierarchy tree. + + Returns + ------- + DiGraph + The graph with the job node. + """ + from networkx import DiGraph + + tree = DiGraph() + tree.add_node(self) + return tree + @property def host(self): """ @@ -560,6 +577,25 @@ def host(self): """ return self.hosts[0] if self.hosts else None + def replace_host(self, old_host: tuple[str, int], new_host: tuple[str, int]): + """ + Replace the uuid of an host if present. + + Parameters + ---------- + old_host + The host to be replaced, + new_host + The new host. + """ + old_host = tuple(old_host) # type: ignore + new_host = tuple(new_host) # type: ignore + try: + i = self.hosts.index(old_host) + self.hosts[i] = new_host + except ValueError: + pass + def set_uuid(self, uuid: str) -> None: """ Set the UUID of the job. @@ -1207,7 +1243,11 @@ def __setattr__(self, key, value): else: super().__setattr__(key, value) - def add_hosts_uuids(self, hosts_uuids: str | Sequence[str], prepend: bool = False): + def add_hosts_uuids( + self, + hosts: tuple[str, int] | Sequence[tuple[str, int]], + prepend: bool = False, + ): """ Add a list of UUIDs to the internal list of hosts. @@ -1217,17 +1257,16 @@ def add_hosts_uuids(self, hosts_uuids: str | Sequence[str], prepend: bool = Fals Parameters ---------- - hosts_uuids + hosts A list of UUIDs to add. prepend Insert the UUIDs at the beginning of the list rather than extending it. """ - if isinstance(hosts_uuids, str): - hosts_uuids = [hosts_uuids] + hosts = normalize_hosts(hosts) if prepend: - self.hosts[0:0] = hosts_uuids + self.hosts[0:0] = hosts else: - self.hosts.extend(hosts_uuids) + self.hosts.extend(hosts) # For type checking, the Response output type can be specified @@ -1423,16 +1462,12 @@ def prepare_replace( replace = Flow(jobs=replace) if isinstance(replace, Flow) and replace.output is not None: - # add a job with same UUID as the current job to store the outputs of the - # flow; this job will inherit the metadata and output schema of the current - # job - store_output_job = store_inputs(replace.output) - store_output_job.set_uuid(current_job.uuid) - store_output_job.index = current_job.index + 1 - store_output_job.metadata = current_job.metadata - store_output_job.output_schema = current_job.output_schema - store_output_job._kwargs = current_job._kwargs - replace.add_jobs(store_output_job) + replace.set_uuid_index(current_job.uuid, current_job.index + 1) + # replace.index = current_job.index + 1 + + metadata = replace.metadata + metadata.update(current_job.metadata) + replace.metadata = metadata elif isinstance(replace, Job): # replace is a single Job diff --git a/src/jobflow/core/maker.py b/src/jobflow/core/maker.py index f45935d3..fde8f9bd 100644 --- a/src/jobflow/core/maker.py +++ b/src/jobflow/core/maker.py @@ -13,6 +13,25 @@ import jobflow +from functools import wraps + + +def _make(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + from jobflow.core.flow import Flow + + result = func(self, *args, **kwargs) + if isinstance(result, Flow): + result.maker = self + result.make_args = args + result.make_kwargs = kwargs + + return result + + return wrapper + + @dataclass class Maker(MSONable): """ @@ -118,6 +137,13 @@ class Maker(MSONable): >>> double_add_job = maker.make(1, 2) """ + def __init_subclass__(cls, **kwargs): + """Init subclass.""" + super().__init_subclass__(**kwargs) + + if hasattr(cls, "make") and callable(cls.make): + cls.make = _make(cls.make) + def make(self, *args, **kwargs) -> jobflow.Flow | jobflow.Job: """Make a job or a flow - must be overridden with a concrete implementation.""" raise NotImplementedError diff --git a/src/jobflow/core/schemas.py b/src/jobflow/core/schemas.py index 09c84107..46ef8463 100644 --- a/src/jobflow/core/schemas.py +++ b/src/jobflow/core/schemas.py @@ -1,8 +1,46 @@ """A Pydantic model for Jobstore document.""" +from __future__ import annotations + from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_serializer + +from jobflow import Maker + + +class MakerData(BaseModel): + """A Pydantic model for the Maker data.""" + + maker: Maker = Field( + description="The instance of the Maker used to generate the Job/Flow" + ) + args: list = Field(description="The args passed to the make method of the Maker") + kwargs: dict = Field( + description="The kwargs passed to the make method of the Maker" + ) + + @field_serializer("maker", mode="plain") + def ser_maker(self, value: Any) -> Any: + """Serialize the Maker object to prevent pydantic serialization.""" + # serialize the object manually, otherwise pydantic always converts it to + # the standard dataclass serialization. + if isinstance(value, Maker): + return value.as_dict() + return value + + def make(self): + """ + Generate the object from the Maker using the arguments. + + Returns + ------- + Flow or Job + The generated Flow or Job. + """ + args = self.args or [] + kwargs = self.kwargs or {} + return self.maker.make(*args, **kwargs) class JobStoreDocument(BaseModel): @@ -24,7 +62,7 @@ class JobStoreDocument(BaseModel): None, description="Metadata information supplied by the user.", ) - hosts: list[str] = Field( + hosts: list[tuple[str, int]] = Field( None, description="The list of UUIDs of the hosts containing the job.", ) @@ -32,3 +70,6 @@ class JobStoreDocument(BaseModel): None, description="The name of the job.", ) + maker: MakerData | None = Field( + None, description="The information of the Maker used to generate the Job/Flow" + ) diff --git a/src/jobflow/core/store.py b/src/jobflow/core/store.py index e2b5ddb1..444df547 100644 --- a/src/jobflow/core/store.py +++ b/src/jobflow/core/store.py @@ -559,6 +559,58 @@ def get_output( results, self, cache=cache, on_missing=on_missing ) + def get_output_from_criteria( + self, + criteria: dict | None = None, + sort: dict[str, Sort | int] = None, + load: load_type = False, + ): + """ + Get the output of a job based on a search criteria. + + Note that, unlike :obj:`JobStore.query`, this function will automatically + try to resolve any output references in the job outputs. + + Parameters + ---------- + criteria + PyMongo filter for documents to search. + load + Which items to load from additional stores. Setting to ``True`` will load + all items stored in additional stores. See the ``JobStore`` constructor for + more details. + sort + Dictionary of sort order for fields. Keys are field names and values are 1 + for ascending or -1 for descending. + + Returns + ------- + Any + The output for the selected job. + """ + from jobflow.core.reference import ( + find_and_get_references, + find_and_resolve_references, + ) + + result = self.query_one( + criteria=criteria, + properties=["output", "uuid"], + sort=sort, + load=load, + ) + + if result is None: + raise ValueError(f"No result from criteria {criteria}") + + refs = find_and_get_references(result["output"]) + if any(ref.uuid == result["uuid"] for ref in refs): + raise RuntimeError("Reference cycle detected - aborting.") + + return find_and_resolve_references( + result["output"], self, on_missing=OnMissing.ERROR + ) + @classmethod def from_file(cls, db_file: str | Path, **kwargs) -> Self: """ diff --git a/src/jobflow/managers/local.py b/src/jobflow/managers/local.py index d64b59cd..f7515015 100644 --- a/src/jobflow/managers/local.py +++ b/src/jobflow/managers/local.py @@ -5,6 +5,9 @@ import logging import typing +from jobflow import Response +from jobflow.core.flow import store_flow_output + if typing.TYPE_CHECKING: from pathlib import Path @@ -65,11 +68,13 @@ def run_locally( from pathlib import Path from random import randint + import networkx as nx from monty.os import cd from jobflow import SETTINGS, initialize_logger from jobflow.core.flow import get_flow from jobflow.core.reference import OnMissing + from jobflow.utils.graph import build_hierarchy_graph if store is None: store = SETTINGS.JOB_STORE @@ -89,8 +94,13 @@ def run_locally( responses: dict[str, dict[int, jobflow.Response]] = defaultdict(dict) stop_jobflow = False + processed: set[tuple[str, int]] = set() + full_tree = build_hierarchy_graph(flow) + flow_job_refs: dict[tuple[str, int], jobflow.Flow | jobflow.Job] = {} + def _run_job(job: jobflow.Job, parents): nonlocal stop_jobflow + nonlocal full_tree if stop_jobflow: return None, True @@ -103,12 +113,24 @@ def _run_job(job: jobflow.Job, parents): stopped_parents.add(job.uuid) return None, False - if ( - len(set(parents).intersection(errored)) > 0 - and job.config.on_missing_references == OnMissing.ERROR - ): - errored.add(job.uuid) - return None, False + # handle the case where a job should not be executed if not all + # the references are available. + if job.config.on_missing_references == OnMissing.ERROR: + # avoid further checks if can it is already know that references will + # not be available + if len(set(parents).intersection(errored)) > 0: + errored.add(job.uuid) + return None, False + try: + # Try to explicitly resolve references to check if possible. + # This prevents failures due to previous jobs containing further + # references. + # References are set inplace, so this does not require + # fetching the references more than once. + job.resolve_args(store=store) + except ValueError: + errored.add(job.uuid) + return None, False if raise_immediately: response = job.run(store=store) @@ -138,14 +160,21 @@ def _run_job(job: jobflow.Job, parents): diversion_responses = [] if response.replace is not None: + full_tree = build_hierarchy_graph( + response.replace, hierarchy_tree=full_tree + ) # first run any restarts diversion_responses.append(_run(response.replace)) if response.detour is not None: + full_tree = build_hierarchy_graph(response.detour, hierarchy_tree=full_tree) # next any detours diversion_responses.append(_run(response.detour)) if response.addition is not None: + full_tree = build_hierarchy_graph( + response.addition, hierarchy_tree=full_tree + ) # finally any additions diversion_responses.append(_run(response.addition)) @@ -161,8 +190,31 @@ def _get_job_dir(): return job_dir return root_dir + def _check_complete_flows(job): + # iterate over the hosts of a Job and complete the Flow if + # all its children have been processed. + for host in job.hosts: + host = tuple(host) # noqa: PLW2901 + descendants_ids = nx.descendants(full_tree, host) + if descendants_ids.issubset(processed): + host_flow = flow_job_refs[host] + store_flow_output(store, host_flow) + processed.add(host) + responses[host[0]][host[1]] = Response( + output=host_flow.output_dereferenced + ) + logger.info(f"Completing Flow - {host_flow.name} ({host[0]} {host[1]})") + else: + # if the current flow is not completed do not go up in the hosts + break + def _run(root_flow): encountered_bad_response = False + + # build a lookup map matching the Jobs/Flows to their uuid/index + for n in root_flow.hierarchy_tree.nodes: + flow_job_refs[(n.uuid, n.index)] = n + for job, parents in root_flow.iterflow(): job_dir = _get_job_dir() with cd(job_dir): @@ -174,6 +226,12 @@ def _run(root_flow): if jobflow_stopped: return False + # Always set a Job as processes, even if an error happened. + # The containing Flow will be completed once all the Jobs are processed. + # If not, in case of replace references to a specific uuid may fetch the + # already existing output with the wrong index. + processed.add((job.uuid, job.index)) + _check_complete_flows(job) return not encountered_bad_response logger.info("Started executing jobs locally") diff --git a/src/jobflow/utils/graph.py b/src/jobflow/utils/graph.py index 8c05f5b2..f5c9a4ff 100644 --- a/src/jobflow/utils/graph.py +++ b/src/jobflow/utils/graph.py @@ -249,3 +249,38 @@ def add_subgraph(nested_flow, indent_level=1): add_subgraph(flow) return "\n".join(lines) + + +def build_hierarchy_graph(flow_or_job, hierarchy_tree=None) -> nx.DiGraph: + """ + Build a hierarchy graph with (uuid, index) of Flow or Job. + + Optionally add it to an already existing tree, + + Parameters + ---------- + flow_or_job + A Flow or Job to build the hierarchy tree. + hierarchy_tree + A hierarchy tree to which the generated one will be added. + + Returns + ------- + DiGraph + The graph of the hierarchy tree. + """ + from jobflow import Flow + + if hierarchy_tree is None: + hierarchy_tree = nx.DiGraph() + if isinstance(flow_or_job, Flow): + iterator = flow_or_job.iterflow() + else: + iterator = [flow_or_job] + for job, _ in iterator: + hierarchy_tree.add_nodes_from(tuple(uuid_index) for uuid_index in job.hosts) + hierarchy_tree.add_node((job.uuid, job.index)) + for n1, n2 in zip([(job.uuid, job.index), *job.hosts], job.hosts): + hierarchy_tree.add_edge(n2, n1) + + return hierarchy_tree diff --git a/src/jobflow/utils/hosts.py b/src/jobflow/utils/hosts.py new file mode 100644 index 00000000..1a5ed8d8 --- /dev/null +++ b/src/jobflow/utils/hosts.py @@ -0,0 +1,68 @@ +"""Tools for managing hosts.""" + +from collections.abc import Sequence +from typing import TypeAlias + +HostPairInput: TypeAlias = Sequence[str | int] + +HostsInput: TypeAlias = HostPairInput | Sequence[HostPairInput] +# +# +# def normalize_hosts(hosts: HostsInput) -> list[tuple[str, int]]: +# def to_pair(item: HostPairInput) -> tuple[str, int]: +# a, b = item +# if not isinstance(a, str) or not isinstance(b, int): +# raise TypeError(f"Invalid host pair: {item}") +# return (a, b) +# +# # single pair: ("x",1) or ["x",1] +# if ( +# isinstance(hosts, (list, tuple)) +# and len(hosts) == 2 +# and isinstance(hosts[0], str) +# and isinstance(hosts[1], int) +# ): +# return [to_pair(hosts)] # mypy is happy +# +# # list of pairs +# if isinstance(hosts, list): +# return [to_pair(x) for x in hosts] +# +# # fallback (shouldn't happen with declared types) +# raise TypeError("Invalid hosts input") + + +def normalize_hosts( + hosts: HostsInput, +) -> list[tuple[str, int]]: + """ + Normalize various host formats into a list of (str, int) tuples. + + Parameters + ---------- + hosts + The hosts to be normalized. + + Returns + ------- + A list of (str, int) tuples + """ + if not hosts: + return [] + + # If it's a single tuple, wrap it in a list + if isinstance(hosts, tuple): + return [hosts] + + if isinstance(hosts, list): + first_item = hosts[0] + + # Single host as flat list: [str, int] + if isinstance(first_item, str): + return [(hosts[0], hosts[1])] # type: ignore + + # List of tuples or lists: [(str, int), ...] or [[str, int], ...] + if isinstance(first_item, (tuple, list)): + return [tuple(item) if isinstance(item, list) else item for item in hosts] # type: ignore + + raise TypeError(f"Unsupported type: {type(hosts)}") From e5c4e7e51f69eeec75e5bfe94e41a5d374f8842c Mon Sep 17 00:00:00 2001 From: Guido Petretto Date: Mon, 23 Feb 2026 18:29:19 +0100 Subject: [PATCH 2/2] graph for flow output --- src/jobflow/core/flow.py | 100 +++++++++++++++++++++++++++++++++++-- src/jobflow/core/job.py | 15 +++++- src/jobflow/utils/graph.py | 5 +- 3 files changed, 112 insertions(+), 8 deletions(-) diff --git a/src/jobflow/core/flow.py b/src/jobflow/core/flow.py index 1f875515..88093d69 100644 --- a/src/jobflow/core/flow.py +++ b/src/jobflow/core/flow.py @@ -328,7 +328,7 @@ def output(self, output: Any): references = find_and_get_references(output) reference_uuids = {ref.uuid for ref in references} - if not reference_uuids.issubset(set(self.job_uuids)): + if not reference_uuids.issubset(set(self.all_uuids)): raise ValueError( "jobs array does not contain all jobs needed for flow output" ) @@ -358,6 +358,38 @@ def output_dereferenced(self) -> Any: """ return self._output + @property + def output_references(self) -> tuple[jobflow.OutputReference, ...]: + """ + Find :obj:`.OutputReference` objects in the flow output. + + Returns + ------- + tuple(OutputReference, ...) + The references in the flow output. + """ + from jobflow.core.reference import find_and_get_references + + return find_and_get_references(self.output_dereferenced) + + @property + def output_references_grouped(self) -> dict[str, tuple[OutputReference, ...]]: + """ + Group any :obj:`.OutputReference` objects in the flow outputs by their UUIDs. + + Returns + ------- + dict[str, tuple(OutputReference, ...)] + The references grouped by their UUIDs. + """ + from collections import defaultdict + + groups = defaultdict(set) + for ref in self.output_references: + groups[ref.uuid].add(ref) + + return {k: tuple(v) for k, v in groups.items()} + def set_uuid_index(self, uuid: str, index: int) -> None: """ Set the UUID of the job. @@ -461,6 +493,62 @@ def graph(self) -> DiGraph: graph.add_edges_from(edges) return graph + @property + def full_graph(self) -> DiGraph: + """ + Get a graph indicating the connectivity of jobs and subflows in the flow. + + Returns + ------- + DiGraph + The graph showing the connectivity of the jobs. + """ + from itertools import product + + import networkx as nx + + graph = nx.compose_all([job.full_graph for job in self]) + + for node in graph: + node_props = graph.nodes[node] + if all(k not in node_props for k in ("job", "label")): + nx.set_node_attributes(graph, {node: {"label": "external"}}) + + graph.add_node(self.uuid, flow=self, label=self.name) + edges = [] + for uuid, refs in self.output_references_grouped.items(): + properties: list[str] | str = [ + ref.attributes_formatted[-1] + .replace("[", "") + .replace("]", "") + .replace(".", "") + for ref in refs + if ref.attributes + ] + properties = properties[0] if len(properties) == 1 else properties + properties = properties if len(properties) > 0 else "output" + edges.append((uuid, self.uuid, {"properties": properties})) + graph.add_edges_from(edges) + + if self.order == JobOrder.LINEAR: + # add fake edges between jobs to force linear order + edges = [] + for job_a, job_b in nx.utils.pairwise(self): + if isinstance(job_a, Flow): + leaves = [v for v, d in job_a.graph.out_degree() if d == 0] + else: + leaves = [job_a.uuid] + + if isinstance(job_b, Flow): + roots = [v for v, d in job_b.graph.in_degree() if d == 0] + else: + roots = [job_b.uuid] + + for leaf, root in product(leaves, roots): + edges.append((leaf, root, {"properties": ""})) + graph.add_edges_from(edges) + return graph + @property def hierarchy_tree(self) -> DiGraph: """ @@ -539,7 +627,7 @@ def draw_graph(self, **kwargs): return draw_graph(self.graph, **kwargs) - def iterflow(self): + def iterflow(self, include_flow_outputs: bool = False): """ Iterate through the jobs of the flow. @@ -548,7 +636,7 @@ def iterflow(self): Yields ------ - Job, list[str] + Job | Flow, list[str] The Job and the uuids of any parent jobs (not to be confused with the host flow). """ @@ -556,7 +644,7 @@ def iterflow(self): from jobflow.utils.graph import itergraph - graph = self.graph + graph = self.full_graph if not is_directed_acyclic_graph(graph): raise ValueError( @@ -566,7 +654,9 @@ def iterflow(self): for node in itergraph(graph): parents = [u for u, v in graph.in_edges(node) if "job" in graph.nodes[u]] - if "job" not in graph.nodes[node]: + if "job" not in graph.nodes[node] and ( + ("flow" not in graph.nodes[node]) or (not include_flow_outputs) + ): continue job = graph.nodes[node]["job"] yield job, parents diff --git a/src/jobflow/core/job.py b/src/jobflow/core/job.py index 72e65d1a..c631a244 100644 --- a/src/jobflow/core/job.py +++ b/src/jobflow/core/job.py @@ -549,6 +549,18 @@ def graph(self) -> DiGraph: graph.add_edges_from(edges) return graph + @property + def full_graph(self) -> DiGraph: + """ + Get a graph of the job indicating the inputs to the job. + + Returns + ------- + DiGraph + The graph showing the connectivity of the jobs. + """ + return self.graph + @property def hierarchy_tree(self) -> DiGraph: """ @@ -1463,11 +1475,12 @@ def prepare_replace( if isinstance(replace, Flow) and replace.output is not None: replace.set_uuid_index(current_job.uuid, current_job.index + 1) - # replace.index = current_job.index + 1 metadata = replace.metadata metadata.update(current_job.metadata) replace.metadata = metadata + if replace.name == "Flow": + replace.name = current_job.name elif isinstance(replace, Job): # replace is a single Job diff --git a/src/jobflow/utils/graph.py b/src/jobflow/utils/graph.py index f5c9a4ff..58402c6a 100644 --- a/src/jobflow/utils/graph.py +++ b/src/jobflow/utils/graph.py @@ -219,10 +219,11 @@ def to_mermaid(flow: jobflow.Flow | jobflow.Job, show_flow_boxes: bool = False) flow = Flow(jobs=[flow]) lines = ["flowchart TD"] - nodes = flow.graph.nodes(data=True) + graph = flow.full_graph + nodes = graph.nodes(data=True) # add edges - for u, v, d in flow.graph.edges(data=True): + for u, v, d in graph.edges(data=True): if isinstance(d["properties"], list): props = ", ".join(d["properties"]) else: