diff --git a/ccflow/__init__.py b/ccflow/__init__.py index 10d73d7..1bb69fe 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -11,6 +11,7 @@ from .callable import * from .context import * from .enums import Enum +from .flow_model import * from .global_state import * from .local_persistence import * from .models import * diff --git a/ccflow/callable.py b/ccflow/callable.py index b09eaea..9ea7a2f 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -12,10 +12,11 @@ """ import abc +import inspect import logging from functools import lru_cache, wraps from inspect import Signature, isclass, signature -from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, cast, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator from typing_extensions import override @@ -27,8 +28,12 @@ ResultBase, ResultType, ) +from .local_persistence import create_ccflow_model from .validators import str_to_log_level +if TYPE_CHECKING: + from .flow_model import FlowAPI + __all__ = ( "GraphDepType", "GraphDepList", @@ -60,6 +65,25 @@ def _cached_signature(fn): return signature(fn) +def _callable_qualname(fn: Callable[..., Any]) -> str: + return getattr(fn, "__qualname__", type(fn).__qualname__) + + +def _declared_type_matches(actual: Any, expected: Any) -> bool: + if isinstance(expected, TypeVar): + return True + if get_origin(expected) is Union: + expected_args = tuple(arg for arg in get_args(expected) if isinstance(arg, type)) + if not expected_args: + return False + if get_origin(actual) is Union: + actual_args = tuple(arg for arg in get_args(actual) if isinstance(arg, type)) + return set(actual_args) == set(expected_args) + return isinstance(actual, type) and any(issubclass(actual, arg) for arg 
in expected_args) + + return isinstance(actual, type) and isinstance(expected, type) and issubclass(actual, expected) + + class MetaData(BaseModel): """Class to represent metadata for all callable models""" @@ -126,7 +150,7 @@ def _check_result_type(cls, result_type): @model_validator(mode="after") def _check_signature(self): sig_call = _cached_signature(self.__class__.__call__) - if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: # ("self", "context") + if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: raise ValueError("__call__ method must take a single argument, named 'context'") sig_deps = _cached_signature(self.__class__.__deps__) @@ -268,14 +292,31 @@ def get_evaluation_context(model: CallableModelType, context: ContextType, as_di def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = None, **kwargs): if not isinstance(model, CallableModel): raise TypeError(f"Can only decorate methods on CallableModels (not {type(model)}) with the flow decorator.") - if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not ( - get_origin(model.context_type) is Union and type(None) in get_args(model.context_type) - ): - raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase") - if (not isclass(model.result_type) or not issubclass(model.result_type, ResultBase)) and not ( - get_origin(model.result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(model.result_type)) + + # Check if this is an auto_context decorated method + has_auto_context = hasattr(fn, "__auto_context__") + if has_auto_context: + method_context_type = fn.__auto_context__ + else: + method_context_type = model.context_type + + # Validate context type (skip for auto contexts which are always valid ContextBase subclasses) + if not has_auto_context: + if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and 
not ( + get_origin(model.context_type) is Union and type(None) in get_args(model.context_type) + ): + raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase") + + # Validate result type - use __result_type__ for auto contexts if available + if has_auto_context and hasattr(fn, "__result_type__"): + method_result_type = fn.__result_type__ + else: + method_result_type = model.result_type + if (not isclass(method_result_type) or not issubclass(method_result_type, ResultBase)) and not ( + get_origin(method_result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(method_result_type)) ): - raise TypeError(f"Result type {model.result_type} must be a subclass of ResultBase") + raise TypeError(f"Result type {method_result_type} must be a subclass of ResultBase") + if self._deps and fn.__name__ != "__deps__": raise ValueError("Can only apply Flow.deps decorator to __deps__") if context is Signature.empty: @@ -285,18 +326,18 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = context = kwargs else: raise TypeError( - f"{fn.__name__}() missing 1 required positional argument: 'context' of type {model.context_type}, or kwargs to construct it" + f"{fn.__name__}() missing 1 required positional argument: 'context' of type {method_context_type}, or kwargs to construct it" ) elif kwargs: # Kwargs passed in as well as context. Not allowed raise TypeError(f"{fn.__name__}() was passed a context and got an unexpected keyword argument '{next(iter(kwargs.keys()))}'") # Type coercion on input. 
We do this here (rather than relying on ModelEvaluationContext) as it produces a nicer traceback/error message - if not isinstance(context, model.context_type): - if get_origin(model.context_type) is Union and type(None) in get_args(model.context_type): - model_context_type = [t for t in get_args(model.context_type) if t is not type(None)][0] + if not isinstance(context, method_context_type): + if get_origin(method_context_type) is Union and type(None) in get_args(method_context_type): + coerce_context_type = [t for t in get_args(method_context_type) if t is not type(None)][0] else: - model_context_type = model.context_type - context = model_context_type.model_validate(context) + coerce_context_type = method_context_type + context = coerce_context_type.model_validate(context) if fn != getattr(model.__class__, fn.__name__).__wrapped__: # This happens when super().__call__ is used when implementing a CallableModel that derives from another one. @@ -310,9 +351,17 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = return result wrap = wraps(fn)(wrapper) - wrap.get_evaluator = self.get_evaluator - wrap.get_options = self.get_options - wrap.get_evaluation_context = get_evaluation_context + wrap_any = cast(Any, wrap) + wrap_any.get_evaluator = self.get_evaluator + wrap_any.get_options = self.get_options + wrap_any.get_evaluation_context = get_evaluation_context + + # Preserve auto context attributes for introspection + if hasattr(fn, "__auto_context__"): + wrap_any.__auto_context__ = fn.__auto_context__ + if hasattr(fn, "__result_type__"): + wrap_any.__result_type__ = fn.__result_type__ + return wrap @@ -391,7 +440,59 @@ def __exit__(self, exc_type, exc_value, exc_tb): class Flow(PydanticBaseModel): @staticmethod def call(*args, **kwargs): - """Decorator for methods on callable models""" + """Decorator for methods on callable models. + + Args: + auto_context: Controls automatic context class generation from the function + signature. 
Accepts three types of values: + - False (default): No auto-generation, use traditional context parameter + - True: Auto-generate context class with no parent + - ContextBase subclass: Auto-generate context class inheriting from this parent + **kwargs: Additional FlowOptions parameters (log_level, verbose, validate_result, + cacheable, evaluator, volatile). + + Basic Example: + class MyModel(CallableModel): + @Flow.call + def __call__(self, context: MyContext) -> MyResult: + return MyResult(value=context.x) + + Auto Context Example: + class MyModel(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> MyResult: + return MyResult(value=f"{x}-{y}") + + model = MyModel() + model(x=42) # Call with kwargs directly + + With Parent Context: + class MyModel(CallableModel): + @Flow.call(auto_context=DateContext) + def __call__(self, *, date: date, extra: int = 0) -> MyResult: + return MyResult(value=date.day + extra) + + # The generated context inherits from DateContext, so it's compatible + # with infrastructure expecting DateContext instances. 
+ + """ + # Extract auto_context option (not part of FlowOptions) + # Can be: False, True, or a ContextBase subclass + auto_context = kwargs.pop("auto_context", False) + + # Determine if auto_context is enabled and extract parent class if provided + if auto_context is False: + auto_context_enabled = False + context_parent = None + elif auto_context is True: + auto_context_enabled = True + context_parent = None + elif isclass(auto_context) and issubclass(auto_context, ContextBase): + auto_context_enabled = True + context_parent = auto_context + else: + raise TypeError(f"auto_context must be False, True, or a ContextBase subclass, got {auto_context!r}") + if len(args) == 1 and callable(args[0]): # No arguments to decorator, this is the decorator fn = args[0] @@ -400,6 +501,14 @@ def call(*args, **kwargs): else: # Arguments to decorator, this is just returning the decorator # Note that the code below is executed only once + if auto_context_enabled: + # Return a decorator that first applies auto_context, then FlowOptions + def auto_context_decorator(fn: Callable[..., Any]) -> Callable[..., Any]: + wrapped = _apply_auto_context(fn, parent=context_parent) + # FlowOptions.__call__ already applies wraps, so we just return its result + return FlowOptions(**kwargs)(wrapped) + + return auto_context_decorator return FlowOptions(**kwargs) @staticmethod @@ -417,6 +526,69 @@ def deps(*args, **kwargs): # Note that the code below is executed only once return FlowOptionsDeps(**kwargs) + @staticmethod + def model(*args, **kwargs): + """Decorator that generates a CallableModel class from a plain Python function. + + This is syntactic sugar over CallableModel. The decorator generates a real + CallableModel class with proper __call__ and __deps__ methods, so all existing + features (caching, evaluation, registry, serialization) work unchanged. 
+ + Args: + context_type: Optional ContextBase subclass used only to validate/coerce + `FromContext[...]` inputs against an existing nominal context shape + auto_unwrap: When True, `.flow.compute(...)` unwraps auto-wrapped + `GenericResult(value=...)` outputs back to the annotated return type. + Explicit `ResultBase` returns are left unchanged. Default: False. + model_base: Optional custom `CallableModel` subclass to use as an + additional base for the generated model class. + cacheable: Enable caching of results (default: False) + volatile: Mark as volatile (default: False) + log_level: Logging verbosity (default: logging.DEBUG) + validate_result: Validate return type (default: True) + verbose: Verbose logging output (default: True) + evaluator: Custom evaluator (default: None) + + Primary authoring model: + Mark runtime/contextual inputs explicitly with `FromContext[...]`. + Ordinary unmarked parameters are regular bound inputs and are never + read implicitly from the runtime context. + + @Flow.model + def load_prices( + source: str, + start_date: FromContext[date], + end_date: FromContext[date], + ) -> GenericResult[pl.DataFrame]: + return GenericResult(value=query_db(source, start_date, end_date)) + + + Dependencies: + Any ordinary parameter can be bound either to a literal value or + to another CallableModel. When a CallableModel is supplied, the + generated model treats it as an upstream dependency and resolves it + with the current context before calling the underlying function. + + `FromContext[...]` parameters are different: they may be satisfied by + runtime context, construction-time contextual defaults, or function + defaults, but not by CallableModel values. 
+ + Usage: + # Create model instances + loader = load_prices(source="prod_db") + returns = compute_returns(prices=loader) + + # Execute + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = returns(ctx) + + Returns: + A factory function that creates CallableModel instances + """ + from .flow_model import flow_model + + return flow_model(*args, **kwargs) + # ***************************************************************************** # Define "Evaluators" and associated types @@ -451,13 +623,15 @@ class ModelEvaluationContext( # TODO: Make the instance check compatible with the generic types instead of the base type @model_validator(mode="wrap") - def _context_validator(cls, values, handler, info): + @classmethod + def _context_validator(cls, values: Any, handler: Any, info: Any): """Override _context_validator from parent""" # Validate the context with the model, if possible - model = values.get("model") - if model and isinstance(model, CallableModel) and not isinstance(values.get("context"), model.context_type): - values["context"] = model.context_type.model_validate(values.get("context")) + if isinstance(values, dict): + model = values.get("model") + if model and isinstance(model, CallableModel) and not isinstance(values.get("context"), model.context_type): + values["context"] = model.context_type.model_validate(values.get("context")) # Apply standard pydantic validation context = handler(values) @@ -485,9 +659,9 @@ def __call__(self) -> ResultType: raise TypeError(f"Model result_type {result_type} is not a subclass of ResultBase") result = result_type.model_validate(result) - return result + return cast(ResultType, result) else: - return fn(self.context) + return cast(ResultType, fn(self.context)) class EvaluatorBase(_CallableModel, abc.ABC): @@ -575,7 +749,7 @@ def context_type(self) -> Type[ContextType]: if not isclass(type_to_check) or not issubclass(type_to_check, ContextBase): raise TypeError(f"Context type declared 
in signature of __call__ must be a subclass of ContextBase. Received {type_to_check}.") - return typ + return cast(Type[ContextType], typ) @property def result_type(self) -> Type[ResultType]: @@ -618,7 +792,7 @@ def result_type(self) -> Type[ResultType]: # Ensure subclass of ResultBase if not isclass(typ) or not issubclass(typ, ResultBase): raise TypeError(f"Return type declared in signature of __call__ must be a subclass of ResultBase (i.e. GenericResult). Received {typ}.") - return typ + return cast(Type[ResultType], typ) @Flow.deps def __deps__( @@ -634,6 +808,13 @@ def __deps__( """ return [] + @property + def flow(self) -> "FlowAPI": + """Access flow helpers for execution, context transforms, and introspection.""" + from .flow_model import FlowAPI + + return FlowAPI(self) + class WrapperModel(CallableModel, Generic[CallableModelType], abc.ABC): """Abstract class that represents a wrapper around an underlying model, with the same context and return types. @@ -646,12 +827,12 @@ class WrapperModel(CallableModel, Generic[CallableModelType], abc.ABC): @property def context_type(self) -> Type[ContextType]: """Return the context type of the underlying model.""" - return self.model.context_type + return cast(CallableModel, self.model).context_type @property def result_type(self) -> Type[ResultType]: """Return the result type of the underlying model.""" - return self.model.result_type + return cast(CallableModel, self.model).result_type class CallableModelGeneric(CallableModel, Generic[ContextType, ResultType]): @@ -723,34 +904,110 @@ def _determine_context_result(cls): if new_context_type is not None: # Set on class - cls._context_generic_type = new_context_type + setattr(cls, "_context_generic_type", new_context_type) if new_result_type is not None: # Set on class - cls._result_generic_type = new_result_type + setattr(cls, "_result_generic_type", new_result_type) @model_validator(mode="wrap") - def _validate_callable_model_generic_type(cls, m, handler, info): + 
@classmethod + def _validate_callable_model_generic_type(cls, m: Any, handler: Any, info: Any): from ccflow.base import resolve_str if isinstance(m, str): m = resolve_str(m) - if isinstance(m, dict): - m = handler(m) - elif isinstance(m, cls): - m = handler(m) + validated_cls = cast(Any, cls) + if isinstance(m, (dict, CallableModel)): + if isinstance(m, dict): + m = handler(m) + elif isinstance(m, validated_cls): + m = handler(m) # Raise ValueError (not TypeError) as per https://docs.pydantic.dev/latest/errors/errors/ if not isinstance(m, CallableModel): raise ValueError(f"{m} is not a CallableModel: {type(m)}") subtypes = cls.__pydantic_generic_metadata__["args"] - if subtypes: - TypeAdapter(Type[subtypes[0]]).validate_python(m.context_type) - TypeAdapter(Type[subtypes[1]]).validate_python(m.result_type) + if len(subtypes) >= 1 and not _declared_type_matches(m.context_type, subtypes[0]): + raise ValueError(f"{m} context_type {m.context_type} does not match {subtypes[0]}") + if len(subtypes) >= 2 and not _declared_type_matches(m.result_type, subtypes[1]): + raise ValueError(f"{m} result_type {m.result_type} does not match {subtypes[1]}") return m CallableModelGenericType = CallableModelGeneric + + +# ***************************************************************************** +# Auto Context (internal helper for Flow.call(auto_context=True)) +# ***************************************************************************** + + +def _apply_auto_context(func: Callable[..., Any], *, parent: Optional[Type[ContextBase]] = None) -> Callable[..., Any]: + """Internal function that creates an auto context class from function parameters. + + This function extracts the parameters from a function signature and creates + a ContextBase subclass whose fields correspond to those parameters. + The decorated function is then wrapped to accept the context object and + unpack it into keyword arguments. + + Used internally by Flow.call(auto_context=...). 
+ + Example: + class MyCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = MyCallable() + model(x=42, y="hello") # Works with kwargs + """ + sig = signature(func) + base_class = parent or ContextBase + + if sig.return_annotation is inspect.Signature.empty: + raise TypeError(f"Function {_callable_qualname(func)} must have a return type annotation when auto_context=True") + + # Validate parent fields are in function signature + if parent is not None: + parent_fields = set(parent.model_fields.keys()) - set(ContextBase.model_fields.keys()) + sig_params = set(sig.parameters.keys()) - {"self"} + missing = parent_fields - sig_params + if missing: + raise TypeError(f"Parent context fields {missing} must be included in function signature") + + # Build fields from parameters (skip 'self'), pydantic validates types + fields = {} + for name, param in sig.parameters.items(): + if name == "self": + continue + if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + raise TypeError(f"Function {_callable_qualname(func)} does not support {param.kind.description} when auto_context=True") + if param.annotation is inspect.Parameter.empty: + raise TypeError(f"Parameter '{name}' must have a type annotation when auto_context=True") + default = ... 
if param.default is inspect.Parameter.empty else param.default + fields[name] = (param.annotation, default) + + # Create auto context class + auto_context_class = create_ccflow_model(f"{_callable_qualname(func)}_AutoContext", __base__=base_class, **fields) + + @wraps(func) + def wrapper(self, context): + fn_kwargs = {name: getattr(context, name) for name in fields} + return func(self, **fn_kwargs) + + # Must set __signature__ so CallableModel validation sees 'context' parameter + wrapper_any = cast(Any, wrapper) + wrapper_any.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=auto_context_class), + ], + return_annotation=sig.return_annotation, + ) + wrapper_any.__auto_context__ = auto_context_class + wrapper_any.__result_type__ = sig.return_annotation + return wrapper diff --git a/ccflow/context.py b/ccflow/context.py index 9a04fad..ae69e22 100644 --- a/ccflow/context.py +++ b/ccflow/context.py @@ -1,16 +1,18 @@ """This module defines re-usable contexts for the "Callable Model" framework defined in flow.callable.py.""" +from collections.abc import Mapping from datetime import date, datetime -from typing import Generic, Hashable, Optional, Sequence, Set, TypeVar +from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar from deprecated import deprecated -from pydantic import field_validator, model_validator +from pydantic import ConfigDict, field_validator, model_validator from .base import ContextBase from .exttypes import Frequency from .validators import normalize_date, normalize_datetime __all__ = ( + "FlowContext", "NullContext", "GenericContext", "DateContext", @@ -89,6 +91,49 @@ # Starting 0.8.0 Nullcontext is an alias to ContextBase NullContext = ContextBase + +class FlowContext(ContextBase): + """Universal context for @Flow.model functions. 
+ + Instead of generating a new ContextBase subclass for each @Flow.model, + this single class with extra="allow" serves as the universal carrier. + Validation happens via TypedDict + TypeAdapter at compute() time. + + This design avoids: + - Proliferation of dynamic _funcname_Context classes + - Class registration overhead for serialization + - Pickling issues with Ray/distributed computing + """ + + model_config = ConfigDict(extra="allow", frozen=True) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FlowContext): + return False + return self.model_dump(mode="python") == other.model_dump(mode="python") + + def __hash__(self) -> int: + return hash(_freeze_for_hash(self.model_dump(mode="python"))) + + +def _freeze_for_hash(value: Any) -> Hashable: + if isinstance(value, Mapping): + return tuple(sorted((key, _freeze_for_hash(item)) for key, item in value.items())) + if isinstance(value, (list, tuple)): + return tuple(_freeze_for_hash(item) for item in value) + if isinstance(value, (set, frozenset)): + return frozenset(_freeze_for_hash(item) for item in value) + if hasattr(value, "model_dump"): + return (type(value), _freeze_for_hash(value.model_dump(mode="python"))) + try: + hash(value) + except TypeError as exc: + if hasattr(value, "__dict__"): + return (type(value), _freeze_for_hash(vars(value))) + raise TypeError(f"FlowContext contains an unhashable value of type {type(value).__name__}") from exc + return value + + C = TypeVar("C", bound=Hashable) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py new file mode 100644 index 0000000..48c3da6 --- /dev/null +++ b/ccflow/flow_model.py @@ -0,0 +1,1061 @@ +"""Flow.model decorator implementation built around ``FromContext``.""" + +import hashlib +import inspect +import logging +import marshal +from dataclasses import dataclass +from functools import lru_cache, wraps +from types import UnionType +from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, 
cast, get_args, get_origin, get_type_hints + +from pydantic import Field, PrivateAttr, TypeAdapter, ValidationError, model_serializer, model_validator +from pydantic.errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation + +from .base import BaseModel, ContextBase, ResultBase +from .callable import CallableModel, Flow, GraphDepList, WrapperModel +from .context import FlowContext +from .local_persistence import register_ccflow_import_path +from .result import GenericResult + +__all__ = ("FlowAPI", "BoundModel", "FromContext", "Lazy") + +_AnyCallable = Callable[..., Any] +log = logging.getLogger(__name__) + + +class _UnsetFlowInput: + def __repr__(self) -> str: + return "" + + +class _InternalSentinel: + def __init__(self, name: str): + self._name = name + + def __repr__(self) -> str: + return self._name + + def __reduce__(self): + return (_get_internal_sentinel, (self._name,)) + + +_UNSET_FLOW_INPUT = _UnsetFlowInput() +_UNION_ORIGINS = (Union, UnionType) + + +def _get_internal_sentinel(name: str) -> _InternalSentinel: + return _INTERNAL_SENTINELS[name] + + +_INTERNAL_SENTINELS = { + "_UNSET": _InternalSentinel("_UNSET"), + "_REMOVED_CONTEXT_ARGS": _InternalSentinel("_REMOVED_CONTEXT_ARGS"), +} +_UNSET = _INTERNAL_SENTINELS["_UNSET"] +_REMOVED_CONTEXT_ARGS = _INTERNAL_SENTINELS["_REMOVED_CONTEXT_ARGS"] + + +def _unset_flow_input_factory() -> _UnsetFlowInput: + return _UNSET_FLOW_INPUT + + +def _is_unset_flow_input(value: Any) -> bool: + return value is _UNSET_FLOW_INPUT + + +class _LazyMarker: + pass + + +class _FromContextMarker: + pass + + +class FromContext: + """Marker used in ``@Flow.model`` signatures for runtime/contextual inputs.""" + + def __class_getitem__(cls, item): + return Annotated[item, _FromContextMarker()] + + +class Lazy: + """Lazy dependency marker used only as ``Lazy[T]`` in type annotations.""" + + def __new__(cls, *args, **kwargs): + raise TypeError("Lazy(model)(...) has been removed. Use model.flow.with_inputs(...) 
for contextual rewrites.") + + def __class_getitem__(cls, item): + return Annotated[item, _LazyMarker()] + + +@dataclass(frozen=True) +class _ParsedAnnotation: + base: Any + is_lazy: bool + is_from_context: bool + + +@dataclass(frozen=True) +class _FlowModelParam: + name: str + annotation: Any + kind: str + is_lazy: bool + has_function_default: bool + function_default: Any = _UNSET + context_validation_annotation: Any = _UNSET + + @property + def is_contextual(self) -> bool: + return self.kind == "contextual" + + @property + def validation_annotation(self) -> Any: + if self.context_validation_annotation is not _UNSET: + return self.context_validation_annotation + return self.annotation + + +@dataclass(frozen=True) +class _FlowModelConfig: + func: _AnyCallable + context_type: Type[ContextBase] + result_type: Type[ResultBase] + auto_wrap_result: bool + auto_unwrap: bool + parameters: Tuple[_FlowModelParam, ...] + context_input_types: Dict[str, Any] + context_required_names: Tuple[str, ...] 
+ declared_context_type: Optional[Type[ContextBase]] = None + + @property + def regular_params(self) -> Tuple[_FlowModelParam, ...]: + return tuple(param for param in self.parameters if not param.is_contextual) + + @property + def contextual_params(self) -> Tuple[_FlowModelParam, ...]: + return tuple(param for param in self.parameters if param.is_contextual) + + @property + def regular_param_names(self) -> Tuple[str, ...]: + return tuple(param.name for param in self.regular_params) + + @property + def contextual_param_names(self) -> Tuple[str, ...]: + return tuple(param.name for param in self.contextual_params) + + def param(self, name: str) -> _FlowModelParam: + for param in self.parameters: + if param.name == name: + return param + raise KeyError(name) + + +def _callable_name(func: _AnyCallable) -> str: + return getattr(func, "__name__", type(func).__name__) + + +def _callable_module(func: _AnyCallable) -> str: + return getattr(func, "__module__", __name__) + + +def _context_values(context: ContextBase) -> Dict[str, Any]: + return dict(context) + + +def _transform_repr(transform: Any) -> str: + if callable(transform): + name = _callable_name(transform) + if name.startswith("<") and name.endswith(">"): + return name + return f"<{name}>" + return repr(transform) + + +def _is_model_dependency(value: Any) -> bool: + return isinstance(value, CallableModel) + + +def _bound_field_names(model: Any) -> set[str]: + fields_set = getattr(model, "model_fields_set", None) + if fields_set is not None: + return set(fields_set) + return set() + + +def _concrete_context_type(context_type: Any) -> Optional[Type[ContextBase]]: + if isinstance(context_type, type) and issubclass(context_type, ContextBase): + return context_type + + if get_origin(context_type) in _UNION_ORIGINS: + for arg in get_args(context_type): + if arg is type(None): + continue + if isinstance(arg, type) and issubclass(arg, ContextBase): + return arg + + return None + + +@lru_cache(maxsize=None) +def 
_type_adapter(annotation: Any) -> TypeAdapter:
    return TypeAdapter(annotation)


def _can_validate_type(annotation: Any) -> bool:
    """Return True if pydantic can build a validator for ``annotation``."""
    try:
        _type_adapter(annotation)
    except (PydanticSchemaGenerationError, PydanticUndefinedAnnotation, TypeError, ValueError):
        # Annotation is not expressible as a pydantic schema (e.g. an opaque
        # callable type); callers fall back to passing values through as-is.
        return False
    return True


def _expected_type_repr(annotation: Any) -> str:
    """Human-readable name for an annotation, for use in error messages."""
    try:
        return annotation.__name__
    except AttributeError:
        # Generic aliases (e.g. List[int]) have no __name__; repr is close enough.
        return repr(annotation)


def _coerce_value(name: str, value: Any, annotation: Any, source: str) -> Any:
    """Validate/coerce ``value`` against ``annotation``, raising TypeError on mismatch.

    ``source`` labels where the value came from (e.g. "Field", "compute() input")
    so the error message points at the offending call site. Values whose
    annotation pydantic cannot validate are passed through unchanged.
    """
    if not _can_validate_type(annotation):
        return value
    try:
        return _type_adapter(annotation).validate_python(value)
    except Exception as exc:
        expected = _expected_type_repr(annotation)
        raise TypeError(f"{source} '{name}': expected {expected}, got {type(value).__name__} ({value!r})") from exc


def _unwrap_model_result(value: Any) -> Any:
    """Unwrap a GenericResult down to its payload; other values pass through."""
    if isinstance(value, GenericResult):
        return value.value
    return value


def _make_lazy_thunk(model: CallableModel, context: ContextBase) -> Callable[[], Any]:
    """Wrap a model dependency in a zero-arg thunk that evaluates (and unwraps) at most once."""
    # Memoize in a dict cell so the closure can write to it.
    cache: Dict[str, Any] = {}

    def thunk():
        if "result" not in cache:
            cache["result"] = _unwrap_model_result(model(context))
        return cache["result"]

    return thunk


def _maybe_auto_unwrap_external_result(target: CallableModel, result: Any) -> Any:
    """Unwrap GenericResult for generated models configured with auto_unwrap.

    Only applies when the generated model both auto-wrapped the raw return
    value and opted into auto_unwrap; plain CallableModels are untouched.
    """
    generated = _generated_model_instance(target)
    if generated is None:
        return result

    config = type(generated).__flow_model_config__
    if config.auto_wrap_result and config.auto_unwrap:
        return _unwrap_model_result(result)
    return result


def _parse_annotation(annotation: Any) -> _ParsedAnnotation:
    """Strip nested Annotated[...] layers, recording Lazy/FromContext markers."""
    is_lazy = False
    is_from_context = False

    # Annotated may be nested (Annotated[Annotated[T, ...], ...]); loop until
    # the bare base type is reached, collecting markers from every layer.
    while get_origin(annotation) is Annotated:
        args = get_args(annotation)
        annotation = args[0]
        for metadata in args[1:]:
            if isinstance(metadata, _LazyMarker):
                is_lazy = True
            elif isinstance(metadata, _FromContextMarker):
                is_from_context = True

    return _ParsedAnnotation(base=annotation, is_lazy=is_lazy, is_from_context=is_from_context)


def _type_accepts_str(annotation: Any) -> bool:
    """True if ``annotation`` can legitimately hold a plain str value.

    Used to decide whether a string should be treated as a literal value or as
    a registry reference to be resolved.
    """
    if annotation is str:
        return True
    origin = get_origin(annotation)
    if origin is Annotated:
        return _type_accepts_str(get_args(annotation)[0])
    if origin in _UNION_ORIGINS:
        return any(_type_accepts_str(arg) for arg in get_args(annotation) if arg is not type(None))
    return False


def _resolve_registry_candidate(value: str) -> Any:
    """Try to resolve ``value`` as a registry reference; None if it does not resolve."""
    try:
        candidate = BaseModel.model_validate(value)
    except ValidationError:
        return None
    return candidate if isinstance(candidate, BaseModel) else None


def _registry_candidate_allowed(expected_type: Any, candidate: Any) -> bool:
    """Check whether a resolved registry object is acceptable for ``expected_type``."""
    # Model dependencies are always allowed; they are evaluated later.
    if _is_model_dependency(candidate):
        return True
    # If the expected type cannot be validated, accept the candidate optimistically.
    if not _can_validate_type(expected_type):
        return True
    try:
        _type_adapter(expected_type).validate_python(candidate)
    except ValidationError:
        return False
    return True


def _callable_closure_repr(transform: Any) -> str:
    """Repr of a callable's closure cells, for fingerprinting; empty cells tolerated."""
    closure = getattr(transform, "__closure__", None)
    if not closure:
        return ""
    pieces = []
    for cell in closure:
        try:
            pieces.append(repr(cell.cell_contents))
        except Exception:
            # Unset cell or un-repr-able contents: contribute an empty piece.
            pieces.append("")
    return "|".join(pieces)


def _callable_fingerprint(transform: Any) -> str:
    """Stable string fingerprint of a callable (module, qualname, code hash, defaults, closure)."""
    module = getattr(transform, "__module__", type(transform).__module__)
    qualname = getattr(transform, "__qualname__", type(transform).__qualname__)
    code = getattr(transform, "__code__", None)
    if code is None:
        # Callable objects without __code__ (e.g. builtins, instances): fall back to repr.
        return f"callable:{module}:{qualname}:{repr(transform)}"

    payload = "|".join(
        [
            module,
            qualname,
            code.co_filename,
            str(code.co_firstlineno),
            # Hash the marshalled code object so semantically different bodies differ.
            hashlib.sha256(marshal.dumps(code)).hexdigest(),
            repr(getattr(transform, "__defaults__", None)),
            _callable_closure_repr(transform),
        ]
    )
    return f"callable:{payload}"


def _fingerprint_transforms(transforms: Dict[str, Any]) -> Tuple[Tuple[str, str], ...]:
    """Deterministic (name, fingerprint) tuple for a transform mapping, sorted by name."""
    items = []
    for name, transform in sorted(transforms.items()):
        if callable(transform):
            items.append((name, _callable_fingerprint(transform)))
        else:
            items.append((name, repr(transform)))
    return tuple(items)


def _generated_model_instance(stage: Any) -> Optional["_GeneratedFlowModelBase"]:
    """Return the generated flow model behind ``stage`` (unwrapping BoundModel), else None."""
    model = stage.model if isinstance(stage, BoundModel) else stage
    if isinstance(model, _GeneratedFlowModelBase):
        return model
    return None


def _generated_model_class(stage: Any) -> Optional[type["_GeneratedFlowModelBase"]]:
    """Return the generated model class for an instance or a flow_model factory, else None."""
    model = _generated_model_instance(stage)
    if model is not None:
        return type(model)

    # Factories produced by flow_model() carry the class on ``_generated_model``.
    generated_model = getattr(stage, "_generated_model", None)
    if isinstance(generated_model, type) and issubclass(generated_model, _GeneratedFlowModelBase):
        return generated_model
    return None


def _context_input_types_for_model(model: CallableModel) -> Optional[Dict[str, Any]]:
    """Map contextual field name -> annotation for ``model``; None if undeterminable.

    Generated models report their FromContext parameters; plain models report
    the fields of their concrete context type, unless that type is the open
    FlowContext (in which case the field set is unknown and None is returned).
    """
    generated = _generated_model_instance(model)
    if generated is not None:
        return dict(type(generated).__flow_model_config__.context_input_types)

    context_cls = _concrete_context_type(model.context_type)
    if context_cls is None or context_cls is FlowContext or not hasattr(context_cls, "model_fields"):
        return None
    return {name: info.annotation for name, info in context_cls.model_fields.items()}


def _context_required_names_for_model(model: CallableModel) -> Tuple[str, ...]:
    """Names of contextual inputs that have no default and must be supplied at runtime."""
    generated = _generated_model_instance(model)
    if generated is not None:
        return type(generated).__flow_model_config__.context_required_names

    context_cls = _concrete_context_type(model.context_type)
    if context_cls is None or not hasattr(context_cls, "model_fields"):
        return ()
    return tuple(name for name, info in context_cls.model_fields.items() if info.is_required())


def _missing_regular_param_names(model: "_GeneratedFlowModelBase", config: _FlowModelConfig) -> List[str]:
    """List the regular (construction-time) parameters still left unbound on ``model``."""
    missing = []
    for param in config.regular_params:
        value = getattr(model, param.name, _UNSET_FLOW_INPUT)
        if _is_unset_flow_input(value):
            missing.append(param.name)
    return missing
def _resolve_regular_param_value(model: "_GeneratedFlowModelBase", param: _FlowModelParam, context: ContextBase) -> Any:
    """Produce the value passed to the wrapped function for one regular parameter.

    Model dependencies are evaluated against ``context`` (eagerly, or via a
    memoized thunk for Lazy[...] parameters); literal values pass through, and
    Lazy literals are wrapped in a constant thunk so the callee always sees a
    zero-arg callable.
    """
    value = getattr(model, param.name, _UNSET_FLOW_INPUT)
    if _is_unset_flow_input(value):
        raise TypeError(
            f"Regular parameter '{param.name}' for {_callable_name(type(model).__flow_model_config__.func)} is still unbound. "
            "Bind it at construction time."
        )
    if param.is_lazy:
        if _is_model_dependency(value):
            return _make_lazy_thunk(value, context)
        # Bind the literal as a default so each call returns the same value.
        return lambda v=value: v
    if _is_model_dependency(value):
        return _unwrap_model_result(value(context))
    return value


def _resolve_contextual_param_value(
    model: "_GeneratedFlowModelBase",
    param: _FlowModelParam,
    context_values: Dict[str, Any],
) -> Tuple[Any, bool]:
    """Resolve one contextual parameter, returning (value, found).

    Resolution priority: runtime context value, then construction-time bound
    default on the model, then the function's declared default. ``found`` is
    False when none of the three sources provides a value.
    """
    if param.name in context_values:
        return context_values[param.name], True

    value = getattr(model, param.name, _UNSET_FLOW_INPUT)
    if not _is_unset_flow_input(value):
        return value, True

    if param.has_function_default:
        return param.function_default, True

    return _UNSET, False


def _resolved_contextual_inputs(model: "_GeneratedFlowModelBase", config: _FlowModelConfig, context: ContextBase) -> Dict[str, Any]:
    """Resolve and validate all contextual inputs for a call; raise if any are missing."""
    context_values = _context_values(context)
    resolved: Dict[str, Any] = {}
    missing_contextual = []

    for param in config.contextual_params:
        value, found = _resolve_contextual_param_value(model, param, context_values)
        if not found:
            missing_contextual.append(param.name)
            continue
        resolved[param.name] = value

    if missing_contextual:
        missing = ", ".join(sorted(missing_contextual))
        raise TypeError(
            f"Missing contextual input(s) for {_callable_name(config.func)}: {missing}. "
            "Supply them via the runtime context, compute(), with_inputs(), or construction-time contextual defaults."
+ ) + + if config.declared_context_type is not None: + validated = config.declared_context_type.model_validate(resolved) + return {param.name: getattr(validated, param.name) for param in config.contextual_params} + + return { + param.name: _coerce_value(param.name, resolved[param.name], param.validation_annotation, "Context field") + for param in config.contextual_params + } + + +def _validate_with_inputs_transforms(model: CallableModel, transforms: Dict[str, Any]) -> Dict[str, Any]: + context_input_types = _context_input_types_for_model(model) + validated = dict(transforms) + + if context_input_types is not None: + invalid = sorted(set(transforms) - set(context_input_types)) + if invalid: + names = ", ".join(invalid) + raise TypeError(f"with_inputs() only accepts contextual fields. Invalid field(s): {names}.") + + for name, transform in list(validated.items()): + if callable(transform): + continue + validated[name] = _coerce_value(name, transform, context_input_types[name], "with_inputs()") + + return validated + + +def _build_generated_compute_context(model: "_GeneratedFlowModelBase", context: Any, kwargs: Dict[str, Any]) -> ContextBase: + config = type(model).__flow_model_config__ + + if context is not _UNSET and kwargs: + raise TypeError("compute() accepts either one context object or contextual keyword arguments, but not both.") + + if context is not _UNSET: + if isinstance(context, FlowContext): + return context + if isinstance(context, ContextBase): + return FlowContext(**_context_values(context)) + return FlowContext.model_validate(context) + + unresolved_regular = sorted( + name for name in config.regular_param_names if name in kwargs and _is_unset_flow_input(getattr(model, name, _UNSET_FLOW_INPUT)) + ) + if unresolved_regular: + names = ", ".join(unresolved_regular) + raise TypeError( + f"compute() cannot satisfy unbound regular parameter(s): {names}. " + "Bind them at construction time; compute() only supplies runtime context." 
        )

    ambient = dict(kwargs)
    for param in config.contextual_params:
        if param.name not in kwargs:
            continue
        # Eagerly coerce supplied contextual kwargs so type errors surface at
        # compute() rather than deep inside the call.
        ambient[param.name] = _coerce_value(param.name, kwargs[param.name], param.validation_annotation, "compute() input")
    return FlowContext(**ambient)


class FlowAPI:
    """API namespace for contextual execution and rewrites.

    Exposed via ``model.flow``; wraps a CallableModel and provides compute(),
    introspection of contextual inputs, and with_inputs() rewrites.
    """

    def __init__(self, model: CallableModel):
        self._model = model

    @property
    def _compute_target(self) -> CallableModel:
        # Overridden by _BoundFlowAPI to target the bound wrapper instead.
        return self._model

    def compute(self, context: Any = _UNSET, /, **kwargs) -> Any:
        """Run the model with either one context object or contextual kwargs (not both)."""
        target = self._compute_target
        generated = _generated_model_instance(target)
        if generated is not None:
            built_context = _build_generated_compute_context(generated, context, kwargs)
            return _maybe_auto_unwrap_external_result(target, target(built_context))

        # Plain CallableModel: validate input against its declared context type.
        if context is not _UNSET and kwargs:
            raise TypeError("compute() accepts either one context object or contextual keyword arguments, but not both.")
        if context is _UNSET:
            built_context = target.context_type.model_validate(kwargs)
        else:
            built_context = context if isinstance(context, ContextBase) else target.context_type.model_validate(context)
        return _maybe_auto_unwrap_external_result(target, target(built_context))

    @property
    def context_inputs(self) -> Dict[str, Any]:
        """All contextual input names mapped to their annotations (empty if unknown)."""
        context_input_types = _context_input_types_for_model(self._model)
        return dict(context_input_types or {})

    @property
    def unbound_inputs(self) -> Dict[str, Any]:
        """Contextual inputs that still require a runtime value (no binding, no default)."""
        generated = _generated_model_instance(self._model)
        if generated is not None:
            config = type(generated).__flow_model_config__
            result = {}
            for param in config.contextual_params:
                # Bound at construction time -> not unbound.
                if not _is_unset_flow_input(getattr(generated, param.name, _UNSET_FLOW_INPUT)):
                    continue
                # Function default available -> not unbound.
                if param.has_function_default:
                    continue
                result[param.name] = config.context_input_types[param.name]
            return result

        required_names = _context_required_names_for_model(self._model)
        context_input_types = _context_input_types_for_model(self._model) or {}
        return {name: context_input_types[name] for name in required_names}

    @property
    def bound_inputs(self) -> Dict[str, Any]:
        """Inputs already fixed on the model (regular params plus explicitly set contextual fields)."""
        generated = _generated_model_instance(self._model)
        if generated is not None:
            config = type(generated).__flow_model_config__
            result: Dict[str, Any] = {}
            explicit_fields = _bound_field_names(generated)
            for param in config.regular_params:
                value = getattr(generated, param.name, _UNSET_FLOW_INPUT)
                if _is_unset_flow_input(value):
                    continue
                result[param.name] = value
            for param in config.contextual_params:
                # Only contextual fields the user explicitly set count as bound.
                if param.name not in explicit_fields:
                    continue
                value = getattr(generated, param.name, _UNSET_FLOW_INPUT)
                if _is_unset_flow_input(value):
                    continue
                result[param.name] = value
            return result

        # Plain CallableModel: every declared field except metadata is a binding.
        result: Dict[str, Any] = {}
        model_fields = getattr(self._model.__class__, "model_fields", {})
        for name in model_fields:
            if name == "meta":
                continue
            result[name] = getattr(self._model, name)
        return result

    def with_inputs(self, **transforms) -> "BoundModel":
        """Return a BoundModel that injects the given contextual values/transforms."""
        validated = _validate_with_inputs_transforms(self._model, transforms)
        return BoundModel(model=self._model, input_transforms=validated)


class BoundModel(WrapperModel):
    """A model with contextual input transforms applied locally.

    Wraps another CallableModel and rewrites the incoming context before
    delegating. Static (non-callable) transforms are also kept on a serialized
    pydantic field so they survive round-trips; callable transforms are
    runtime-only and are represented in serialized output by a fingerprint.
    """

    # Runtime transform mapping (may contain callables; not serialized directly).
    _input_transforms: Dict[str, Any] = PrivateAttr(default_factory=dict)
    # Static subset of the transforms, persisted under the "_static_transforms" alias.
    serialized_transforms: Dict[str, Any] = Field(default_factory=dict, alias="_static_transforms", repr=False, exclude=True)

    @model_validator(mode="before")
    @classmethod
    def _strip_runtime_serializer_fields(cls, values):
        # The fingerprint emitted by _serialize_with_transforms is output-only;
        # drop it on the way back in so validation does not choke on it.
        if isinstance(values, dict):
            cleaned = dict(values)
            cleaned.pop("_input_transforms_fingerprint", None)
            return cleaned
        return values

    def __init__(self, *, model: CallableModel, input_transforms: Optional[Dict[str, Any]] = None, **kwargs):
        if input_transforms is not None:
            # Only non-callable transforms can be serialized.
            static_transforms = {name: value for name, value in input_transforms.items() if not
callable(value)}
            kwargs["_static_transforms"] = static_transforms
        super().__init__(model=model, **kwargs)
        if input_transforms is not None:
            self._input_transforms = dict(input_transforms)
        else:
            # Deserialized instance: rebuild runtime transforms from the static subset.
            self._input_transforms = dict(self.serialized_transforms)

    def model_post_init(self, __context):
        # Safety net for construction paths that bypass __init__ (e.g. model_validate).
        if not self._input_transforms:
            self._input_transforms = dict(self.serialized_transforms)

    def _transform_context(self, context: ContextBase) -> ContextBase:
        """Apply all input transforms to ``context`` and rebuild the appropriate context type."""
        ctx_dict = _context_values(context)
        context_input_types = _context_input_types_for_model(self.model)

        for name, transform in self._input_transforms.items():
            # Callable transforms compute their value from the incoming context.
            value = transform(context) if callable(transform) else transform
            if context_input_types is not None and name in context_input_types:
                value = _coerce_value(name, value, context_input_types[name], "with_inputs()")
            ctx_dict[name] = value

        generated = _generated_model_instance(self.model)
        if generated is not None:
            # Generated models always run against the open FlowContext.
            return FlowContext(**ctx_dict)

        context_type = _concrete_context_type(self.model.context_type)
        if context_type is not None and context_type is not FlowContext:
            return context_type.model_validate(ctx_dict)
        return FlowContext(**ctx_dict)

    @Flow.call
    def __call__(self, context: ContextBase) -> ResultBase:
        return self.model(self._transform_context(context))

    @Flow.deps
    def __deps__(self, context: ContextBase) -> GraphDepList:
        # Single dependency: the wrapped model, evaluated on the rewritten context.
        return [(self.model, [self._transform_context(context)])]

    @model_serializer(mode="wrap")
    def _serialize_with_transforms(self, handler):
        # Emit the static transforms plus a fingerprint of the full transform
        # set (including callables) so serialized forms can be distinguished.
        data = handler(self)
        if self.serialized_transforms:
            data["_static_transforms"] = dict(self.serialized_transforms)
        data["_input_transforms_fingerprint"] = _fingerprint_transforms(self._input_transforms)
        return data

    def __repr__(self) -> str:
        transforms = ", ".join(f"{name}={_transform_repr(transform)}" for name, transform in self._input_transforms.items())
        return f"{self.model!r}.flow.with_inputs({transforms})"

    @property
    def flow(self) -> "FlowAPI":
        # Bound flavor of the API so chained with_inputs()/compute() target this wrapper.
        return _BoundFlowAPI(self)


class _BoundFlowAPI(FlowAPI):
    """FlowAPI variant whose compute() targets a BoundModel and whose with_inputs() merges."""

    def __init__(self, bound_model: BoundModel):
        self._bound = bound_model
        super().__init__(bound_model.model)

    @property
    def _compute_target(self) -> CallableModel:
        return self._bound

    def with_inputs(self, **transforms) -> BoundModel:
        # New transforms override existing ones of the same name.
        validated = _validate_with_inputs_transforms(self._bound.model, transforms)
        merged = {**self._bound._input_transforms, **validated}
        return BoundModel(model=self._bound.model, input_transforms=merged)


class _GeneratedFlowModelBase(CallableModel):
    """Common base for CallableModel classes generated by the @flow_model decorator."""

    # Per-class analysis of the wrapped function, attached by flow_model().
    __flow_model_config__: ClassVar[_FlowModelConfig]

    @model_validator(mode="before")
    @classmethod
    def _resolve_registry_refs(cls, values):
        """Resolve string values of regular params as registry references where applicable."""
        if not isinstance(values, dict):
            return values

        # Config is attached after class creation; tolerate its absence during rebuilds.
        config = getattr(cls, "__flow_model_config__", None)
        if config is None:
            return values

        resolved = dict(values)
        for param in config.regular_params:
            if param.name not in resolved:
                continue
            value = resolved[param.name]
            if not isinstance(value, str):
                continue
            # If the declared type accepts str, treat the value as a literal.
            if _type_accepts_str(param.annotation):
                continue
            candidate = _resolve_registry_candidate(value)
            if candidate is None:
                continue
            if _registry_candidate_allowed(param.annotation, candidate):
                resolved[param.name] = candidate
        return resolved

    @model_validator(mode="after")
    def _validate_flow_model_fields(self):
        """Coerce bound field values to their declared annotations; reject invalid bindings."""
        config = self.__class__.__flow_model_config__

        for param in config.parameters:
            value = getattr(self, param.name, _UNSET_FLOW_INPUT)
            if _is_unset_flow_input(value):
                continue

            if param.is_contextual:
                if _is_model_dependency(value):
                    raise TypeError(
                        f"Parameter '{param.name}' is marked FromContext[...] and cannot be bound to a CallableModel. "
                        "Bind a literal contextual default or supply it via compute()/with_inputs()."
                    )
                # object.__setattr__ because the model may be frozen/validated.
                object.__setattr__(
                    self,
                    param.name,
                    _coerce_value(param.name, value, param.validation_annotation, "Contextual default"),
                )
                continue

            # Model dependencies are evaluated at call time; do not coerce now.
            if _is_model_dependency(value):
                continue

            object.__setattr__(self, param.name, _coerce_value(param.name, value, param.annotation, "Field"))

        return self

    @property
    def context_type(self) -> Type[ContextBase]:
        return self.__class__.__flow_model_config__.context_type

    @property
    def result_type(self) -> Type[ResultBase]:
        return self.__class__.__flow_model_config__.result_type

    @property
    def flow(self) -> FlowAPI:
        return FlowAPI(self)


def _make_call_impl(config: _FlowModelConfig) -> _AnyCallable:
    """Build the generated __call__: resolve all params, invoke the wrapped function.

    A (self, context) signature is attached explicitly so downstream signature
    checks see the standard CallableModel.__call__ shape.
    """

    def __call__(self, context):
        missing_regular = _missing_regular_param_names(self, config)
        if missing_regular:
            missing = ", ".join(sorted(missing_regular))
            raise TypeError(
                f"Missing regular parameter(s) for {_callable_name(config.func)}: {missing}. "
                "Bind them at construction time; compute() only supplies contextual inputs."
            )

        fn_kwargs: Dict[str, Any] = {}
        for param in config.regular_params:
            fn_kwargs[param.name] = _resolve_regular_param_value(self, param, context)

        fn_kwargs.update(_resolved_contextual_inputs(self, config, context))

        raw_result = config.func(**fn_kwargs)
        # Non-ResultBase return types are wrapped so the flow layer always sees a result.
        if config.auto_wrap_result:
            return GenericResult(value=raw_result)
        return raw_result

    cast(Any, __call__).__signature__ = inspect.Signature(
        parameters=[
            inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD),
            inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=config.context_type),
        ],
        return_annotation=config.result_type,
    )
    return __call__


def _make_deps_impl(config: _FlowModelConfig) -> _AnyCallable:
    """Build the generated __deps__: report model-valued regular params as graph deps.

    Lazy parameters are excluded — they are evaluated on demand inside the
    call, not ahead of time by the graph.
    """

    def __deps__(self, context):
        missing_regular = _missing_regular_param_names(self, config)
        if missing_regular:
            missing = ", ".join(sorted(missing_regular))
            raise TypeError(f"Missing regular parameter(s) for {_callable_name(config.func)}: {missing}. Bind them before dependency evaluation.")

        deps = []
        for param in config.regular_params:
            if param.is_lazy:
                continue
            value = getattr(self, param.name, _UNSET_FLOW_INPUT)
            if isinstance(value, BoundModel):
                # Depend on the inner model with its rewritten context.
                deps.append((value.model, [value._transform_context(context)]))
            elif isinstance(value, CallableModel):
                deps.append((value, [context]))
        return deps

    cast(Any, __deps__).__signature__ = inspect.Signature(
        parameters=[
            inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD),
            inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=config.context_type),
        ],
        return_annotation=GraphDepList,
    )
    return __deps__


def _context_type_annotations_compatible(func_annotation: Any, context_annotation: Any) -> bool:
    """Loose compatibility check between a FromContext annotation and its context field.

    Only concrete classes are compared (either subclass direction is accepted);
    anything non-class (generics, unions) is assumed compatible.
    """
    if func_annotation is context_annotation:
        return True
    if isinstance(func_annotation, type) and isinstance(context_annotation, type):
        return issubclass(func_annotation, context_annotation) or issubclass(context_annotation, func_annotation)
    return True


def _validate_declared_context_type(context_type: Any, contextual_params: Tuple[_FlowModelParam, ...]) -> Type[ContextBase]:
    """Validate a user-declared context_type against the FromContext parameters.

    Ensures every FromContext parameter has a matching field, no extra
    required fields exist beyond ContextBase's own, and annotations agree.
    """
    if not isinstance(context_type, type) or not issubclass(context_type, ContextBase):
        raise TypeError(f"context_type must be a ContextBase subclass, got {context_type!r}")

    context_fields = getattr(context_type, "model_fields", {})
    contextual_names = {param.name for param in contextual_params}

    missing = sorted(name for name in contextual_names if name not in context_fields)
    if missing:
        raise TypeError(f"context_type {context_type.__name__} must define fields for all FromContext parameters: {', '.join(missing)}")

    required_extra_fields = sorted(
        name for name, info in context_fields.items() if name not in ContextBase.model_fields and name not in contextual_names and info.is_required()
    )
    if required_extra_fields:
        raise TypeError(
            f"context_type {context_type.__name__} has required fields that are not declared as FromContext parameters: "
            f"{', '.join(required_extra_fields)}"
        )

    for param in contextual_params:
        ctx_field = context_fields[param.name]
        if not _context_type_annotations_compatible(param.annotation, ctx_field.annotation):
            raise TypeError(
                f"FromContext parameter '{param.name}' annotates {param.annotation!r}, but "
                f"context_type {context_type.__name__} declares {ctx_field.annotation!r}."
            )

    return context_type


def _analyze_flow_model(
    fn: _AnyCallable,
    sig: inspect.Signature,
    resolved_hints: Dict[str, Any],
    *,
    context_type: Optional[Type[ContextBase]],
    auto_unwrap: bool,
) -> _FlowModelConfig:
    """Analyze a decorated function's signature into a _FlowModelConfig.

    Classifies each parameter as regular or contextual (FromContext), records
    Lazy markers and defaults, validates any declared context_type, and
    determines the result type (auto-wrapping non-ResultBase returns).
    """
    params = sig.parameters

    analyzed_params: List[_FlowModelParam] = []

    for name, param in params.items():
        if name == "self":
            continue

        # The generated model binds everything by keyword; reject shapes that cannot.
        if param.kind is inspect.Parameter.POSITIONAL_ONLY:
            raise TypeError(f"Function {_callable_name(fn)} does not support positional-only parameter '{name}'.")
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            raise TypeError(f"Function {_callable_name(fn)} does not support {param.kind.description} parameter '{name}'.")

        # Prefer fully resolved hints (handles forward refs) over raw annotations.
        annotation = resolved_hints.get(name, param.annotation)
        if annotation is inspect.Parameter.empty:
            raise TypeError(f"Parameter '{name}' must have a type annotation")

        parsed = _parse_annotation(annotation)
        if parsed.is_lazy and parsed.is_from_context:
            raise TypeError(f"Parameter '{name}' cannot combine Lazy[...] and FromContext[...].")

        has_function_default = param.default is not inspect.Parameter.empty
        function_default = param.default if has_function_default else _UNSET
        if parsed.is_from_context and has_function_default and _is_model_dependency(function_default):
            raise TypeError(f"Parameter '{name}' is marked FromContext[...] and cannot default to a CallableModel.")

        analyzed_params.append(
            _FlowModelParam(
                name=name,
                annotation=parsed.base,
                kind="contextual" if parsed.is_from_context else "regular",
                is_lazy=parsed.is_lazy,
                has_function_default=has_function_default,
                function_default=function_default,
            )
        )

    contextual_params = tuple(param for param in analyzed_params if param.is_contextual)
    declared_context_type = None
    if context_type is not None and not contextual_params:
        raise TypeError("context_type=... requires FromContext[...] parameters.")
    if context_type is not None:
        declared_context_type = _validate_declared_context_type(context_type, contextual_params)

    # Generated models always execute against the open FlowContext.
    call_context_type = FlowContext
    context_input_types = {param.name: param.annotation for param in contextual_params}
    context_required_names = tuple(param.name for param in contextual_params if not param.has_function_default)

    if declared_context_type is not None:
        # Re-create contextual params carrying the declared context field's
        # annotation, which is what runtime coercion validates against.
        updated_params = []
        context_fields = declared_context_type.model_fields
        for param in analyzed_params:
            if not param.is_contextual:
                updated_params.append(param)
                continue
            updated_params.append(
                _FlowModelParam(
                    name=param.name,
                    annotation=param.annotation,
                    kind=param.kind,
                    is_lazy=param.is_lazy,
                    has_function_default=param.has_function_default,
                    function_default=param.function_default,
                    context_validation_annotation=context_fields[param.name].annotation,
                )
            )
        analyzed_params = updated_params

    return_type = resolved_hints.get("return", sig.return_annotation)
    if return_type is inspect.Signature.empty:
        raise TypeError(f"Function {_callable_name(fn)} must have a return type annotation")

    # Non-ResultBase returns get auto-wrapped in GenericResult by the generated __call__.
    return_origin = get_origin(return_type) or return_type
    auto_wrap_result = not (isinstance(return_origin, type) and issubclass(return_origin, ResultBase))
    result_type = GenericResult if auto_wrap_result else return_type

    return _FlowModelConfig(
        func=fn,
        context_type=call_context_type,
        result_type=result_type,
        auto_wrap_result=auto_wrap_result,
        auto_unwrap=auto_unwrap,
        parameters=tuple(analyzed_params),
        context_input_types=context_input_types,
        context_required_names=context_required_names,
        declared_context_type=declared_context_type,
    )


def _validate_factory_kwargs(config: _FlowModelConfig, kwargs: Dict[str, Any]) -> None:
    """Pre-validate factory keyword arguments before model construction.

    Raises TypeError for contextual params bound to CallableModels and for
    values that fail coercion; strings that resolve to acceptable registry
    candidates are exempted from coercion (they are resolved later).
    """
    for param in config.parameters:
        if param.name not in kwargs:
            continue
        value = kwargs[param.name]
        if param.is_contextual:
            if _is_model_dependency(value):
                raise TypeError(
                    f"Parameter '{param.name}' is marked FromContext[...] and cannot be bound to a CallableModel. "
                    "Use a literal contextual default or supply it at runtime."
                )
            _coerce_value(param.name, value, param.validation_annotation, "Field")
            continue

        # Model dependencies (and None placeholders) are resolved at call time.
        if value is None or _is_model_dependency(value):
            continue
        if isinstance(value, str) and not _type_accepts_str(param.annotation):
            # A string that resolves to a valid registry object is acceptable as-is.
            candidate = _resolve_registry_candidate(value)
            if candidate is not None and _registry_candidate_allowed(param.annotation, candidate):
                continue
        _coerce_value(param.name, value, param.annotation, "Field")


def _resolve_generated_model_bases(model_base: Type[CallableModel]) -> Tuple[type, ...]:
    """Compute the base classes for a generated model, always including _GeneratedFlowModelBase."""
    if not isinstance(model_base, type) or not issubclass(model_base, CallableModel):
        raise TypeError(f"model_base must be a CallableModel subclass, got {model_base!r}")

    if issubclass(model_base, _GeneratedFlowModelBase):
        return (model_base,)
    if model_base is CallableModel:
        return (_GeneratedFlowModelBase,)
    return (_GeneratedFlowModelBase, model_base)


def flow_model(
    func: Optional[_AnyCallable] = None,
    *,
    context_args: Any = _REMOVED_CONTEXT_ARGS,  # removed option; kept to raise a helpful error
    context_type: Optional[Type[ContextBase]] = None,
    auto_unwrap: bool = False,
    model_base: Type[CallableModel] = CallableModel,
    cacheable: Any = _UNSET,
    volatile: Any = _UNSET,
    log_level: Any = _UNSET,
    validate_result: Any = _UNSET,
    verbose: Any = _UNSET,
    evaluator: Any = _UNSET,
) -> _AnyCallable:
    """Decorator that generates a CallableModel class from a plain Python function."""

    if context_args is not _REMOVED_CONTEXT_ARGS:
        raise TypeError("context_args=... has been removed. Mark runtime/contextual parameters with FromContext[...] instead.")

    def decorator(fn: _AnyCallable) -> _AnyCallable:
        sig = inspect.signature(fn)

        # include_extras preserves Annotated[...] markers (Lazy/FromContext);
        # unresolvable hints (e.g. missing forward refs) fall back to raw annotations.
        try:
            resolved_hints = get_type_hints(fn, include_extras=True)
        except (AttributeError, NameError, TypeError):
            resolved_hints = {}

        config = _analyze_flow_model(fn, sig, resolved_hints, context_type=context_type, auto_unwrap=auto_unwrap)

        annotations: Dict[str, Any] = {}
        namespace: Dict[str, Any] = {
            "__module__": _callable_module(fn),
            "__qualname__": f"_{_callable_name(fn)}_Model",
            # Forward only the Flow options that were explicitly provided.
            "__call__": Flow.call(
                **{
                    name: value
                    for name, value in [
                        ("cacheable", cacheable),
                        ("volatile", volatile),
                        ("log_level", log_level),
                        ("validate_result", validate_result),
                        ("verbose", verbose),
                        ("evaluator", evaluator),
                    ]
                    if value is not _UNSET
                }
            )(_make_call_impl(config)),
            "__deps__": Flow.deps(_make_deps_impl(config)),
        }

        for param in config.parameters:
            # Fields are typed Any at the pydantic level; real validation is
            # handled by the model validators against the analyzed annotations.
            annotations[param.name] = Any
            if param.is_contextual:
                namespace[param.name] = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input)
            elif param.has_function_default:
                namespace[param.name] = param.function_default
            else:
                namespace[param.name] = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input)

        namespace["__annotations__"] = annotations

        GeneratedModel = cast(
            type[_GeneratedFlowModelBase],
            type(f"_{_callable_name(fn)}_Model", _resolve_generated_model_bases(model_base), namespace),
        )
        GeneratedModel.__flow_model_config__ = config
        # Register so dynamically generated classes remain importable/serializable.
        register_ccflow_import_path(GeneratedModel)
        GeneratedModel.model_rebuild()

        @wraps(fn)
        def factory(**kwargs) -> _GeneratedFlowModelBase:
            _validate_factory_kwargs(config, kwargs)
            return
GeneratedModel(**kwargs) + + cast(Any, factory)._generated_model = GeneratedModel + factory.__doc__ = fn.__doc__ + return factory + + if func is not None: + return decorator(func) + return decorator diff --git a/ccflow/tests/config/conf_flow.yaml b/ccflow/tests/config/conf_flow.yaml new file mode 100644 index 0000000..c58d7b7 --- /dev/null +++ b/ccflow/tests/config/conf_flow.yaml @@ -0,0 +1,73 @@ +# Flow.model configurations for Hydra integration tests. + +flow_loader: + _target_: ccflow.tests.test_flow_model.basic_loader + source: test_source + multiplier: 5 + +flow_processor: + _target_: ccflow.tests.test_flow_model.string_processor + prefix: "value=" + suffix: "!" + +flow_source: + _target_: ccflow.tests.test_flow_model.data_source + base_value: 100 + +flow_transformer: + _target_: ccflow.tests.test_flow_model.data_transformer + source: flow_source + factor: 3 + +flow_stage1: + _target_: ccflow.tests.test_flow_model.pipeline_stage1 + initial: 10 + +flow_stage2: + _target_: ccflow.tests.test_flow_model.pipeline_stage2 + stage1_output: flow_stage1 + multiplier: 2 + +flow_stage3: + _target_: ccflow.tests.test_flow_model.pipeline_stage3 + stage2_output: flow_stage2 + offset: 50 + +diamond_source: + _target_: ccflow.tests.test_flow_model.data_source + base_value: 10 + +diamond_branch_a: + _target_: ccflow.tests.test_flow_model.data_transformer + source: diamond_source + factor: 2 + +diamond_branch_b: + _target_: ccflow.tests.test_flow_model.data_transformer + source: diamond_source + factor: 5 + +diamond_aggregator: + _target_: ccflow.tests.test_flow_model.data_aggregator + input_a: diamond_branch_a + input_b: diamond_branch_b + operation: add + +flow_date_loader: + _target_: ccflow.tests.test_flow_model.date_range_loader_previous_day + source: market_data + include_weekends: false + +flow_date_processor: + _target_: ccflow.tests.test_flow_model.date_range_processor + raw_data: flow_date_loader + normalize: true + +contextual_loader_model: + _target_: 
ccflow.tests.test_flow_model.contextual_loader + source: data_source + +contextual_processor_model: + _target_: ccflow.tests.test_flow_model.contextual_processor + data: contextual_loader_model + prefix: output diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 43f86b5..9b51592 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -462,6 +462,7 @@ def test_types(self): error = "__call__ method must take a single argument, named 'context'" self.assertRaisesRegex(ValueError, error, BadModelMissingContextArg) + # BadModelDoubleContextArg also fails with the same error since extra params aren't allowed error = "__call__ method must take a single argument, named 'context'" self.assertRaisesRegex(ValueError, error, BadModelDoubleContextArg) @@ -783,3 +784,373 @@ class MyCallableParent_bad_decorator(MyCallableParent): @Flow.deps def foo(self, context): return [] + + +# ============================================================================= +# Tests for Flow.call(auto_context=True) +# ============================================================================= + + +class TestAutoContext(TestCase): + """Tests for @Flow.call(auto_context=True).""" + + def test_basic_usage_with_kwargs(self): + """Test basic auto_context usage with keyword arguments.""" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = AutoContextCallable() + + # Call with kwargs + result = model(x=42, y="hello") + self.assertEqual(result.value, "42-hello") + + # Call with default + result = model(x=10) + self.assertEqual(result.value, "10-default") + + def test_auto_context_attribute(self): + """Test that __auto_context__ attribute is set.""" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, a: int, b: str) -> GenericResult: + return 
GenericResult(value=f"{a}-{b}") + + # The __call__ method should have __auto_context__ + call_method = AutoContextCallable.__call__ + self.assertTrue(hasattr(call_method, "__wrapped__")) + # Access the inner function's __auto_context__ + inner = call_method.__wrapped__ + self.assertTrue(hasattr(inner, "__auto_context__")) + + auto_ctx = inner.__auto_context__ + self.assertTrue(issubclass(auto_ctx, ContextBase)) + self.assertIn("a", auto_ctx.model_fields) + self.assertIn("b", auto_ctx.model_fields) + + def test_auto_context_is_registered(self): + """Test that the auto context is registered for serialization.""" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: int) -> GenericResult: + return GenericResult(value=value) + + inner = AutoContextCallable.__call__.__wrapped__ + auto_ctx = inner.__auto_context__ + + # Should have __ccflow_import_path__ set + self.assertTrue(hasattr(auto_ctx, "__ccflow_import_path__")) + self.assertTrue(auto_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + + def test_call_with_context_object(self): + """Test calling with a context object instead of kwargs.""" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = AutoContextCallable() + + # Get the auto context class + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + + # Create a context object + ctx = auto_ctx(x=99, y="context") + result = model(ctx) + self.assertEqual(result.value, "99-context") + + def test_with_parent_context(self): + """Test auto_context with a parent context class.""" + + class ParentContext(ContextBase): + base_value: str = "base" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) + def __call__(self, *, x: int, base_value: str) -> GenericResult: + return 
GenericResult(value=f"{x}-{base_value}") + + # Get auto context + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + + # Should inherit from ParentContext + self.assertTrue(issubclass(auto_ctx, ParentContext)) + + # Should have both fields + self.assertIn("base_value", auto_ctx.model_fields) + self.assertIn("x", auto_ctx.model_fields) + + # Create context with parent field + ctx = auto_ctx(x=42, base_value="custom") + self.assertEqual(ctx.base_value, "custom") + self.assertEqual(ctx.x, 42) + + def test_parent_fields_must_be_in_signature(self): + """Test that parent context fields must be included in function signature.""" + + class ParentContext(ContextBase): + required_field: str + + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + self.assertIn("required_field", str(cm.exception)) + + def test_cloudpickle_roundtrip(self): + """Test cloudpickle roundtrip for auto_context callable.""" + + class AutoContextCallable(CallableModel): + multiplier: int = 2 + + @Flow.call(auto_context=True) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x * self.multiplier) + + model = AutoContextCallable(multiplier=3) + + # Test roundtrip + restored = rcploads(rcpdumps(model)) + + result = restored(x=10) + self.assertEqual(result.value, 30) + + def test_ray_task_execution(self): + """Test auto_context callable in Ray task.""" + + class AutoContextCallable(CallableModel): + factor: int = 2 + + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: int = 1) -> GenericResult: + return GenericResult(value=(x + y) * self.factor) + + @ray.remote + def run_callable(model, **kwargs): + return model(**kwargs).value + + model = AutoContextCallable(factor=5) + + with ray.init(num_cpus=1): + result = ray.get(run_callable.remote(model, x=10, y=2)) + + self.assertEqual(result, 
60) # (10 + 2) * 5 + + def test_context_type_property_works(self): + """Test that type_ property works on the auto context.""" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + ctx = auto_ctx(x=42) + + # type_ should work and be importable + type_path = str(ctx.type_) + self.assertIn("_Local_", type_path) + self.assertEqual(ctx.type_.object, auto_ctx) + + def test_complex_field_types(self): + """Test auto_context with complex field types.""" + from typing import List, Optional + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__( + self, + *, + items: List[int], + name: Optional[str] = None, + count: int = 0, + ) -> GenericResult: + total = sum(items) + count + return GenericResult(value=f"{name}:{total}" if name else str(total)) + + model = AutoContextCallable() + + result = model(items=[1, 2, 3], name="test", count=10) + self.assertEqual(result.value, "test:16") + + result = model(items=[5, 5]) + self.assertEqual(result.value, "10") + + def test_with_flow_options(self): + """Test auto_context with FlowOptions parameters.""" + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True, validate_result=False) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + model = AutoContextCallable() + result = model(x=42) + self.assertEqual(result.value, 42) + + def test_error_without_auto_context(self): + """Test that using kwargs signature without auto_context raises an error.""" + + class BadCallable(CallableModel): + @Flow.call # Missing auto_context=True! 
+ def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + # Error happens at instantiation time when _check_signature validates + with self.assertRaises(ValueError) as cm: + BadCallable() + + # Should fail because __call__ must take a single argument named 'context' + error_msg = str(cm.exception) + self.assertIn("__call__", error_msg) + self.assertIn("context", error_msg) + + def test_invalid_auto_context_value(self): + """Test that invalid auto_context values raise TypeError with helpful message.""" + with self.assertRaises(TypeError) as cm: + + @Flow.call(auto_context="invalid") + def bad_func(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + error_msg = str(cm.exception) + self.assertIn("auto_context must be False, True, or a ContextBase subclass", error_msg) + + def test_auto_context_rejects_var_args(self): + """auto_context should reject *args early with a clear error.""" + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *args: int) -> GenericResult: + return GenericResult(value=len(args)) + + self.assertIn("variadic positional", str(cm.exception)) + + def test_auto_context_rejects_var_kwargs(self): + """auto_context should reject **kwargs early with a clear error.""" + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, **kwargs: int) -> GenericResult: + return GenericResult(value=len(kwargs)) + + self.assertIn("variadic keyword", str(cm.exception)) + + def test_auto_context_requires_return_annotation(self): + """auto_context should reject missing return annotations immediately.""" + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: int): + return GenericResult(value=value) + + self.assertIn("must have a 
return type annotation", str(cm.exception)) + + def test_auto_context_rejects_missing_annotation(self): + """auto_context should reject params without type annotations.""" + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value) -> GenericResult: + return GenericResult(value=value) + + self.assertIn("must have a type annotation", str(cm.exception)) + + +class TestDeclaredTypeMatches(TestCase): + """Tests for _declared_type_matches helper in callable.py.""" + + def test_typevar_always_matches(self): + from ccflow.callable import _declared_type_matches + + T = TypeVar("T") + self.assertTrue(_declared_type_matches(int, T)) + + def test_union_expected_no_type_args(self): + """Union with no concrete type args should return False.""" + from ccflow.callable import _declared_type_matches + + # Union[None] after filtering out NoneType has no concrete args + self.assertFalse(_declared_type_matches(int, Union[None])) + + def test_union_expected_with_actual_type(self): + """Concrete type matching Union expected.""" + from ccflow.callable import _declared_type_matches + + self.assertTrue(_declared_type_matches(int, Union[int, str])) + self.assertFalse(_declared_type_matches(float, Union[int, str])) + + def test_union_both_sides(self): + """Both actual and expected are Unions.""" + from ccflow.callable import _declared_type_matches + + self.assertTrue(_declared_type_matches(Union[int, str], Union[int, str])) + self.assertTrue(_declared_type_matches(Union[str, int], Union[int, str])) # order independent + self.assertFalse(_declared_type_matches(Union[int, float], Union[int, str])) + + def test_non_type_actual(self): + """Non-type actual should return False.""" + from ccflow.callable import _declared_type_matches + + self.assertFalse(_declared_type_matches("not_a_type", int)) + + def test_non_type_expected(self): + """Non-type expected should return False.""" + from ccflow.callable 
import _declared_type_matches + + self.assertFalse(_declared_type_matches(int, "not_a_type")) + + +class TestCallableModelGenericValidation(TestCase): + """Tests for CallableModelGeneric type validation paths.""" + + def test_context_type_mismatch_raises(self): + """Generic type validation should reject context type mismatch.""" + + class ContextA(ContextBase): + a: int + + class ContextB(ContextBase): + b: int + + class ModelA(CallableModel): + @Flow.call + def __call__(self, context: ContextA) -> GenericResult[int]: + return GenericResult(value=context.a) + + with self.assertRaises(ValidationError): + # Expect ContextB but model has ContextA + CallableModelGenericType[ContextB, GenericResult[int]].model_validate(ModelA()) + + def test_result_type_mismatch_raises(self): + """Generic type validation should reject result type mismatch.""" + + class MyContext(ContextBase): + x: int + + class ResultA(ResultBase): + a: int + + class ResultB(ResultBase): + b: int + + class ModelA(CallableModel): + @Flow.call + def __call__(self, context: MyContext) -> ResultA: + return ResultA(a=context.x) + + with self.assertRaises(ValidationError): + CallableModelGenericType[MyContext, ResultB].model_validate(ModelA()) diff --git a/ccflow/tests/test_context.py b/ccflow/tests/test_context.py index ad98bd9..64d71e8 100644 --- a/ccflow/tests/test_context.py +++ b/ccflow/tests/test_context.py @@ -275,8 +275,13 @@ def split_camel(name: str): def test_inheritance(self): """Test that if a context has a superset of fields of another context, it is a subclass of that context.""" - for parent_name, parent_class in self.classes.items(): - for child_name, child_class in self.classes.items(): + # Exclude FlowContext from this test - it's a special universal carrier with no + # declared fields (uses extra="allow"), so the "superset implies subclass" logic + # doesn't apply to it. 
+ classes_to_check = {name: cls for name, cls in self.classes.items() if name != "FlowContext"} + + for parent_name, parent_class in classes_to_check.items(): + for child_name, child_class in classes_to_check.items(): if parent_class is child_class: continue diff --git a/ccflow/tests/test_evaluator.py b/ccflow/tests/test_evaluator.py index cc34155..dabf815 100644 --- a/ccflow/tests/test_evaluator.py +++ b/ccflow/tests/test_evaluator.py @@ -1,9 +1,21 @@ from datetime import date from unittest import TestCase -from ccflow import DateContext, Evaluator, ModelEvaluationContext +import pytest -from .evaluators.util import MyDateCallable +from ccflow import CallableModel, DateContext, Evaluator, Flow, ModelEvaluationContext + +from .evaluators.util import MyDateCallable, MyResult + + +class MyAutoContextDateCallable(CallableModel): + """Auto context version of MyDateCallable for testing evaluators.""" + + offset: int + + @Flow.call(auto_context=DateContext) + def __call__(self, *, date: date) -> MyResult: + return MyResult(x=date.day + self.offset) class TestEvaluator(TestCase): @@ -32,3 +44,57 @@ def test_evaluator_deps(self): evaluator = Evaluator() out2 = evaluator.__deps__(model_evaluation_context) self.assertEqual(out2, out) + + +@pytest.mark.parametrize( + "callable_class", + [MyDateCallable, MyAutoContextDateCallable], + ids=["standard", "auto_context"], +) +class TestEvaluatorParametrized: + """Test evaluators work with both standard and auto_context callables.""" + + def test_evaluator_with_context_object(self, callable_class): + """Test evaluator with a context object.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + + out = model_evaluation_context() + assert out == MyResult(x=2) # day 1 + offset 1 + + evaluator = Evaluator() + out2 = evaluator(model_evaluation_context) + assert out2 == out + + def test_evaluator_with_fn_specified(self, 
callable_class): + """Test evaluator with fn='__call__' explicitly specified.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, fn="__call__") + + out = model_evaluation_context() + assert out == MyResult(x=2) + + def test_evaluator_direct_call_matches(self, callable_class): + """Test that evaluator result matches direct call.""" + m1 = callable_class(offset=5) + context = DateContext(date=date(2022, 1, 15)) + + # Direct call + direct_result = m1(context) + + # Via evaluator + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + evaluator_result = model_evaluation_context() + + assert direct_result == evaluator_result + assert direct_result == MyResult(x=20) # day 15 + offset 5 + + def test_evaluator_with_kwargs(self, callable_class): + """Test that evaluator works when callable is called with kwargs.""" + m1 = callable_class(offset=1) + + # Call with kwargs + result = m1(date=date(2022, 1, 10)) + assert result == MyResult(x=11) # day 10 + offset 1 diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py new file mode 100644 index 0000000..1480d3b --- /dev/null +++ b/ccflow/tests/test_flow_context.py @@ -0,0 +1,231 @@ +"""Tests for FlowContext, FlowAPI, and BoundModel under the FromContext design.""" + +import pickle +from concurrent.futures import ThreadPoolExecutor +from datetime import date, timedelta + +import cloudpickle +import pytest + +from ccflow import BoundModel, CallableModel, ContextBase, Flow, FlowContext, FromContext, GenericResult + + +class NumberContext(ContextBase): + x: int + + +class OffsetModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: NumberContext) -> GenericResult[int]: + return GenericResult(value=context.x + self.offset) + + +def test_flow_context_basic_properties(): + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31), 
label="x") + assert ctx.start_date == date(2024, 1, 1) + assert ctx.end_date == date(2024, 1, 31) + assert ctx.label == "x" + assert dict(ctx) == {"start_date": date(2024, 1, 1), "end_date": date(2024, 1, 31), "label": "x"} + + +def test_flow_context_value_semantics_and_hash(): + first = FlowContext(x=1, values=[1, 2]) + second = FlowContext(x=1, values=[1, 2]) + third = FlowContext(x=2, values=[1, 2]) + + assert first == second + assert first != third + assert len({first, second, third}) == 2 + + +def test_flow_context_pickle_and_cloudpickle_roundtrip(): + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31), tags=frozenset({"a", "b"})) + assert pickle.loads(pickle.dumps(ctx)) == ctx + assert cloudpickle.loads(cloudpickle.dumps(ctx)) == ctx + + +def test_flow_api_introspection_for_from_context_model(): + @Flow.model + def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: + return a + b + c + + model = add(a=10) + assert model.flow.context_inputs == {"b": int, "c": int} + assert model.flow.unbound_inputs == {"b": int} + assert model.flow.bound_inputs == {"a": 10} + assert model.flow.compute(b=2).value == 17 + + +def test_flow_api_compute_accepts_single_context_or_kwargs_but_not_both(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10) + assert model.flow.compute(b=5).value == 15 + assert model.flow.compute(FlowContext(b=6)).value == 16 + + with pytest.raises(TypeError, match="either one context object or contextual keyword arguments"): + model.flow.compute(FlowContext(b=5), b=6) + + +def test_bound_model_with_inputs_static_and_callable(): + @Flow.model + def load_window(start_date: FromContext[date], end_date: FromContext[date]) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_window() + shifted = model.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=7), + end_date=date(2024, 1, 31), + ) + + 
result = shifted(FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 30))) + assert result.value == {"start": date(2024, 1, 1), "end": date(2024, 1, 31)} + + +def test_bound_model_with_inputs_is_branch_local_and_chained(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def combine(left: int, right: int, value: FromContext[int]) -> int: + return left + right + value + + base = source() + left = base.flow.with_inputs(value=lambda ctx: ctx.value + 1) + right = base.flow.with_inputs(value=lambda ctx: ctx.value + 2).flow.with_inputs(value=lambda ctx: ctx.value + 10) + model = combine(left=left, right=right) + + assert model.flow.compute(value=5).value == (6 + 15 + 5) + + +def test_compute_kwargs_can_supply_ambient_context_for_upstream_transforms(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def combine(left: int, right: int, bonus: FromContext[int]) -> int: + return left + right + bonus + + base = source() + model = combine( + left=base.flow.with_inputs(value=lambda ctx: ctx.value + 1), + right=base.flow.with_inputs(value=lambda ctx: ctx.value + 10), + ) + + assert model.flow.context_inputs == {"bonus": int} + assert model.flow.compute(value=5, bonus=100).value == (6 + 15 + 100) + + +def test_bound_model_rejects_regular_field_rewrites(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + with pytest.raises(TypeError, match="only accepts contextual fields"): + add(a=1).flow.with_inputs(a=3) + + +def test_bound_model_repr_matches_user_facing_api(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=1) + bound = model.flow.with_inputs(b=lambda ctx: ctx.b + 1) + assert repr(bound) == f"{model!r}.flow.with_inputs(b=)" + + +def test_bound_model_serialization_roundtrip_preserves_static_transforms(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = 
add(a=10).flow.with_inputs(b=5) + dumped = bound.model_dump(mode="python") + restored = type(bound).model_validate(dumped) + + assert restored.flow.compute().value == 15 + assert restored.model.flow.bound_inputs == {"a": 10} + + +def test_bound_model_cloudpickle_roundtrip_preserves_callable_transforms(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_inputs(b=lambda ctx: ctx.b + 1) + restored = cloudpickle.loads(cloudpickle.dumps(bound)) + + assert restored.flow.compute(b=4).value == 15 + + +def test_transformed_dag_cloudpickle_roundtrip_preserves_callable_transforms(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def combine(left: int, right: int, value: FromContext[int]) -> int: + return left + right + value + + base = source() + model = combine( + left=base.flow.with_inputs(value=lambda ctx: ctx.value + 1), + right=base.flow.with_inputs(value=lambda ctx: ctx.value + 10), + ) + restored = cloudpickle.loads(cloudpickle.dumps(model)) + + assert restored.flow.compute(value=5).value == (6 + 15 + 5) + + +def test_regular_callable_models_still_support_with_inputs(): + model = OffsetModel(offset=10) + shifted = model.flow.with_inputs(x=lambda ctx: ctx.x * 2) + assert shifted(NumberContext(x=5)).value == 20 + + +def test_flow_api_for_regular_callable_model(): + model = OffsetModel(offset=10) + assert model.flow.compute(x=5).value == 15 + assert model.flow.context_inputs == {"x": int} + assert model.flow.unbound_inputs == {"x": int} + assert model.flow.bound_inputs == {"offset": 10} + + +def test_generated_flow_model_compute_is_thread_safe(): + @Flow.model + def add(a: int, b: FromContext[int], c: FromContext[int]) -> int: + return a + b + c + + model = add(a=10) + + def worker(n: int) -> int: + return model.flow.compute(b=n, c=n + 1).value + + with ThreadPoolExecutor(max_workers=8) as executor: + results = list(executor.map(worker, range(20))) + + assert results 
== [10 + n + n + 1 for n in range(20)] + + +def test_bound_model_restore_is_thread_safe(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + dumped = add(a=10).flow.with_inputs(b=5).model_dump(mode="python") + + def worker(_: int) -> int: + restored = BoundModel.model_validate(dumped) + return restored.flow.compute().value + + with ThreadPoolExecutor(max_workers=8) as executor: + results = list(executor.map(worker, range(20))) + + assert results == [15] * 20 diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py new file mode 100644 index 0000000..5dd4abd --- /dev/null +++ b/ccflow/tests/test_flow_model.py @@ -0,0 +1,699 @@ +"""Focused tests for the FromContext-based Flow.model API.""" + +import graphlib +from datetime import date, timedelta +from typing import Annotated + +import pytest +from pydantic import model_validator +from ray.cloudpickle import dumps as rcpdumps, loads as rcploads + +import ccflow.flow_model as flow_model_module +from ccflow import ( + CallableModel, + ContextBase, + DateRangeContext, + Flow, + FlowContext, + FlowOptionsOverride, + FromContext, + GenericResult, + Lazy, + ModelRegistry, +) +from ccflow.evaluators import GraphEvaluator + + +class SimpleContext(ContextBase): + value: int + + +class ParentRangeContext(ContextBase): + start_date: date + end_date: date + + +class RichRangeContext(ParentRangeContext): + label: str = "child" + + +class OrderedContext(ContextBase): + a: int + b: int + + @model_validator(mode="after") + def _validate_order(self): + if self.a > self.b: + raise ValueError("a must be <= b") + return self + + +@Flow.model +def basic_loader(source: str, multiplier: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value * multiplier) + + +@Flow.model +def string_processor(value: FromContext[int], prefix: str = "value=", suffix: str = "!") -> GenericResult[str]: + return GenericResult(value=f"{prefix}{value}{suffix}") + + +@Flow.model 
+def data_source(base_value: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + base_value) + + +@Flow.model +def data_transformer(source: int, factor: int) -> GenericResult[int]: + return GenericResult(value=source * factor) + + +@Flow.model +def data_aggregator(input_a: int, input_b: int, operation: str = "add") -> GenericResult[int]: + if operation == "add": + return GenericResult(value=input_a + input_b) + raise ValueError(f"unsupported operation: {operation}") + + +@Flow.model +def pipeline_stage1(initial: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + initial) + + +@Flow.model +def pipeline_stage2(stage1_output: int, multiplier: int) -> GenericResult[int]: + return GenericResult(value=stage1_output * multiplier) + + +@Flow.model +def pipeline_stage3(stage2_output: int, offset: int) -> GenericResult[int]: + return GenericResult(value=stage2_output + offset) + + +@Flow.model +def date_range_loader_previous_day( + source: str, + start_date: FromContext[date], + end_date: FromContext[date], + include_weekends: bool = False, +) -> GenericResult[dict]: + del include_weekends + return GenericResult( + value={ + "source": source, + "start_date": str(start_date - timedelta(days=1)), + "end_date": str(end_date), + } + ) + + +@Flow.model +def date_range_processor(raw_data: dict, normalize: bool = False) -> GenericResult[str]: + prefix = "normalized:" if normalize else "raw:" + return GenericResult(value=f"{prefix}{raw_data['source']}:{raw_data['start_date']} to {raw_data['end_date']}") + + +@Flow.model +def contextual_loader(source: str, start_date: FromContext[date], end_date: FromContext[date]) -> GenericResult[dict]: + return GenericResult( + value={ + "source": source, + "start_date": str(start_date), + "end_date": str(end_date), + } + ) + + +@Flow.model +def contextual_processor( + prefix: str, + data: dict, + start_date: FromContext[date], + end_date: FromContext[date], +) -> 
GenericResult[str]: + del start_date, end_date + return GenericResult(value=f"{prefix}:{data['source']}:{data['start_date']} to {data['end_date']}") + + +def test_from_context_anchor_behavior(): + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b + + assert foo(a=11).flow.compute(b=12).value == 23 + assert foo(a=11, b=12).flow.compute().value == 23 + + with pytest.raises(TypeError, match="compute\\(\\) cannot satisfy unbound regular parameter\\(s\\): a"): + foo().flow.compute(a=11, b=12) + + +def test_regular_param_accepts_upstream_model(): + @Flow.model + def source(value: FromContext[int], offset: int) -> int: + return value + offset + + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b + + model = foo(a=source(offset=5)) + assert model.flow.compute(value=7, b=12).value == 24 + assert model.flow.compute(FlowContext(value=7, b=12)).value == 24 + + +def test_bound_regular_param_name_can_collide_with_ambient_context(): + @Flow.model + def source(a: FromContext[int]) -> int: + return a + + @Flow.model + def combine(a: int, left: int, bonus: FromContext[int]) -> int: + return a + left + bonus + + model = combine(a=100, left=source()) + assert model.flow.compute(a=7, bonus=5).value == 112 + + +def test_contextual_param_rejects_callable_model(): + @Flow.model + def source(offset: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + offset) + + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b + + with pytest.raises(TypeError, match="cannot be bound to a CallableModel"): + foo(a=1, b=source(offset=2)) + + +def test_contextual_construction_defaults_and_bound_inputs(): + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b + + model = foo(a=11, b=12) + assert model.flow.bound_inputs == {"a": 11, "b": 12} + assert model.flow.context_inputs == {"b": int} + assert model.flow.unbound_inputs == {} + assert model.flow.compute().value == 23 + + +def 
test_contextual_function_defaults_remain_contextual(): + @Flow.model + def foo(a: int, b: FromContext[int] = 5) -> int: + return a + b + + model = foo(a=2) + assert model.flow.bound_inputs == {"a": 2} + assert model.flow.context_inputs == {"b": int} + assert model.flow.unbound_inputs == {} + assert model.flow.compute().value == 7 + assert model.flow.compute(b=10).value == 12 + + +def test_context_type_accepts_richer_subclass_for_from_context(): + @Flow.model(context_type=ParentRangeContext) + def span_days(multiplier: int, start_date: FromContext[date], end_date: FromContext[date]) -> int: + return multiplier * ((end_date - start_date).days + 1) + + model = span_days(multiplier=2) + assert model.flow.compute(start_date="2024-01-01", end_date="2024-01-03").value == 6 + assert model(RichRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 4), label="x")).value == 8 + + +def test_context_type_validation_applies_to_resolved_contextual_values(): + @Flow.model(context_type=OrderedContext) + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b + + with pytest.raises(ValueError, match="a must be <= b"): + add().flow.compute(a=2, b=1) + + with pytest.raises(ValueError, match="a must be <= b"): + add(a=2, b=1).flow.compute() + + +def test_context_named_parameters_are_just_regular_parameters(): + @Flow.model + def loader(context: DateRangeContext, source: str = "db") -> GenericResult[str]: + return GenericResult(value=f"{source}:{context.start_date}:{context.end_date}") + + model = loader(context=DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)), source="api") + assert model.flow.bound_inputs["context"] == DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)) + assert model.flow.context_inputs == {} + assert model.flow.compute().value == "api:2024-01-01:2024-01-02" + + with pytest.raises(TypeError, match="Missing regular parameter\\(s\\) for loader: context"): + 
loader(source="api").flow.compute(start_date="2024-01-01", end_date="2024-01-02") + + +def test_auto_unwrap_defaults_to_false_for_auto_wrapped_results(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + result = add(a=10).flow.compute(b=5) + assert isinstance(result, GenericResult) + assert result.value == 15 + + +def test_compute_does_not_unwrap_explicit_generic_result_returns(): + @Flow.model + def load(value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value * 2) + + result = load().flow.compute(value=3) + assert isinstance(result, GenericResult) + assert result.value == 6 + + +def test_auto_unwrap_can_be_enabled_for_auto_wrapped_results(): + @Flow.model(auto_unwrap=True) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + assert add(a=10).flow.compute(b=5) == 15 + + +def test_auto_unwrap_only_affects_external_compute_results(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value * 2 + + @Flow.model(auto_unwrap=True) + def add(left: int, bonus: FromContext[int]) -> int: + return left + bonus + + model = add(left=source()) + assert model.flow.compute(FlowContext(value=4, bonus=3)) == 11 + + +def test_model_base_allows_custom_callable_model_subclass(): + class CustomFlowBase(CallableModel): + multiplier: int = 1 + + @model_validator(mode="after") + def _validate_multiplier(self): + if self.multiplier <= 0: + raise ValueError("multiplier must be positive") + return self + + def scaled(self, value: int) -> int: + return value * self.multiplier + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model(model_base=CustomFlowBase) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10, multiplier=3) + assert isinstance(model, CustomFlowBase) + assert model.multiplier == 3 + assert model.scaled(4) == 12 + assert model.flow.compute(b=5).value == 15 + + with 
pytest.raises(ValueError, match="multiplier must be positive"): + add(a=10, multiplier=0) + + +def test_model_base_must_be_callable_model_subclass(): + with pytest.raises(TypeError, match="model_base must be a CallableModel subclass"): + + @Flow.model(model_base=int) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + +def test_context_named_regular_parameter_can_coexist_with_from_context(): + @Flow.model + def mixed(context: SimpleContext, y: FromContext[int]) -> int: + return context.value + y + + model = mixed(context=SimpleContext(value=10)) + assert model.flow.bound_inputs == {"context": SimpleContext(value=10)} + assert model.flow.context_inputs == {"y": int} + assert model.flow.compute(y=5).value == 15 + + +def test_context_args_keyword_is_removed(): + with pytest.raises(TypeError, match="context_args=... has been removed"): + + @Flow.model(context_args=["x"]) + def bad(x: int) -> int: + return x + + +def test_context_type_requires_from_context(): + with pytest.raises(TypeError, match="context_type=... 
requires FromContext"): + + @Flow.model(context_type=DateRangeContext) + def bad(x: int) -> int: + return x + + +def test_lazy_dependency_remains_lazy(): + calls = {"source": 0} + + @Flow.model + def source(value: FromContext[int]) -> int: + calls["source"] += 1 + return value * 10 + + @Flow.model + def choose(value: int, lazy_value: Lazy[int], threshold: FromContext[int]) -> int: + if value > threshold: + return value + return lazy_value() + + eager = choose(value=50, lazy_value=source()) + assert eager.flow.compute(FlowContext(value=3, threshold=10)).value == 50 + assert calls["source"] == 0 + + deferred = choose(value=5, lazy_value=source()) + assert deferred.flow.compute(FlowContext(value=3, threshold=10)).value == 30 + assert calls["source"] == 1 + + +def test_lazy_runtime_helper_is_removed(): + @Flow.model + def source(value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value) + + with pytest.raises(TypeError, match="Lazy\\(model\\)\\(\\.\\.\\.\\) has been removed"): + Lazy(source()) + + +def test_lazy_and_from_context_combination_is_rejected(): + with pytest.raises(TypeError, match="cannot combine Lazy"): + + @Flow.model + def bad(x: Lazy[FromContext[int]]) -> int: + return x() + + +def test_auto_wrap_and_serialization_roundtrip(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10) + dumped = model.model_dump(mode="python") + restored = type(model).model_validate(dumped) + + assert restored.flow.bound_inputs == {"a": 10} + assert restored.flow.unbound_inputs == {"b": int} + assert restored.flow.compute(b=5).value == 15 + + +def test_generated_models_cloudpickle_roundtrip(): + @Flow.model + def multiply(a: int, b: FromContext[int]) -> int: + return a * b + + model = multiply(a=6) + restored = rcploads(rcpdumps(model, protocol=5)) + assert restored.flow.compute(b=7).value == 42 + + +def test_generated_models_cloudpickle_preserves_unset_validation_sentinel(): + @Flow.model + def multiply(a: 
int, b: FromContext[int]) -> int: + return a * b + + model = multiply(a=6) + restored = rcploads(rcpdumps(model, protocol=5)) + param = type(restored).__flow_model_config__.contextual_params[0] + + assert param.context_validation_annotation is flow_model_module._UNSET + assert param.validation_annotation is int + + +def test_graph_integration_fanout_fanin(): + @Flow.model + def source(base: int, value: FromContext[int]) -> int: + return value + base + + @Flow.model + def scale(data: int, factor: int) -> int: + return data * factor + + @Flow.model + def merge(left: int, right: int, bonus: FromContext[int]) -> int: + return left + right + bonus + + src = source(base=10) + left = scale(data=src, factor=2) + right = scale(data=src, factor=5) + model = merge(left=left, right=right) + + assert model.flow.compute(FlowContext(value=3, bonus=7)).value == ((3 + 10) * 2) + ((3 + 10) * 5) + 7 + + +def test_graph_integration_cycle_raises_cleanly(): + @Flow.model + def increment(x: int, n: FromContext[int]) -> int: + return x + n + + root = increment() + branch = increment(x=root) + object.__setattr__(root, "x", branch) + + with FlowOptionsOverride(options={"evaluator": GraphEvaluator()}): + with pytest.raises(graphlib.CycleError): + root.flow.compute(n=1) + + +def test_large_contextual_contract_stress(): + @Flow.model + def total( + base: int, + x1: FromContext[int], + x2: FromContext[int], + x3: FromContext[int], + x4: FromContext[int], + x5: FromContext[int], + x6: FromContext[int], + ) -> int: + return base + x1 + x2 + x3 + x4 + x5 + x6 + + model = total(base=10) + assert model.flow.context_inputs == {"x1": int, "x2": int, "x3": int, "x4": int, "x5": int, "x6": int} + assert model.flow.compute(x1=1, x2=2, x3=3, x4=4, x5=5, x6=6).value == 31 + + +def test_registry_integration_for_generated_models(): + registry = ModelRegistry.root().clear() + model = basic_loader(source="warehouse", multiplier=3) + registry.add("loader", model) + + retrieved = registry["loader"] + assert 
isinstance(retrieved, CallableModel) + assert retrieved(SimpleContext(value=4)).value == 12 + + +def test_unexpected_type_adapter_errors_are_not_silently_swallowed(): + class BrokenSchema: + @classmethod + def __get_pydantic_core_schema__(cls, source, handler): + raise RuntimeError("boom") + + @Flow.model + def bad(x: BrokenSchema, y: FromContext[int]) -> int: + del x, y + return 0 + + with pytest.raises(RuntimeError, match="boom"): + bad(x=object()) + + +def test_unexpected_type_hint_resolution_errors_propagate(monkeypatch): + def broken_get_type_hints(*args, **kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr(flow_model_module, "get_type_hints", broken_get_type_hints) + + def add(x: int) -> int: + return x + + with pytest.raises(RuntimeError, match="boom"): + Flow.model(add) + + +def test_internal_generated_model_helpers_and_config_properties(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + generated_cls = flow_model_module._generated_model_class(add) + assert generated_cls is not None + + config = generated_cls.__flow_model_config__ + assert config.regular_param_names == ("a",) + assert config.contextual_param_names == ("b",) + assert config.param("a").name == "a" + assert config.param("b").name == "b" + + with pytest.raises(KeyError): + config.param("missing") + + model = add(a=10) + assert flow_model_module._generated_model_class(model) is generated_cls + + class DerivedGeneratedBase(flow_model_module._GeneratedFlowModelBase): + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + del context + return [] + + assert flow_model_module._resolve_generated_model_bases(DerivedGeneratedBase) == (DerivedGeneratedBase,) + + +def test_internal_type_helpers_and_plain_callable_flow_api_paths(): + assert flow_model_module._concrete_context_type(SimpleContext | None) is SimpleContext + assert 
flow_model_module._concrete_context_type(int) is None + assert flow_model_module._type_accepts_str(Annotated[str, flow_model_module._FromContextMarker()]) is True + assert flow_model_module._type_accepts_str(int | None) is False + assert flow_model_module._transform_repr(lambda value: value) == "" + assert flow_model_module._bound_field_names(object()) == set() + + class PlainModel(CallableModel): + @property + def context_type(self): + return SimpleContext + + @property + def result_type(self): + return GenericResult[int] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + del context + return [] + + model = PlainModel() + + assert model.flow.context_inputs == {"value": int} + assert model.flow.unbound_inputs == {"value": int} + assert model.flow.bound_inputs == {} + assert model.flow.compute({"value": 3}).value == 3 + + with pytest.raises(TypeError, match="either one context object or contextual keyword arguments"): + model.flow.compute(SimpleContext(value=1), value=2) + + +def test_compute_accepts_context_object_for_from_context_models(): + model = basic_loader(source="warehouse", multiplier=3) + + assert model.flow.context_inputs == {"value": int} + assert model.flow.unbound_inputs == {"value": int} + assert model.flow.compute({"value": 4}).value == 12 + assert model.flow.compute(SimpleContext(value=5)).value == 15 + + with pytest.raises(TypeError, match="either one context object or contextual keyword arguments"): + model.flow.compute(SimpleContext(value=1), value=2) + + +def test_additional_validation_and_hint_fallback_paths(monkeypatch): + class MissingFieldContext(ContextBase): + start_date: date + + with pytest.raises(TypeError, match="must define fields for all FromContext parameters"): + + @Flow.model(context_type=MissingFieldContext) + def bad_missing(start_date: FromContext[date], end_date: FromContext[date]) -> int: 
+ return 0 + + class ExtraRequiredContext(ContextBase): + start_date: date + end_date: date + label: str + + with pytest.raises(TypeError, match="has required fields that are not declared as FromContext parameters"): + + @Flow.model(context_type=ExtraRequiredContext) + def bad_extra(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return 0 + + class BadAnnotationContext(ContextBase): + value: str + + with pytest.raises(TypeError, match="annotates"): + + @Flow.model(context_type=BadAnnotationContext) + def bad_annotation(value: FromContext[int]) -> int: + return value + + with pytest.raises(TypeError, match="context_type must be a ContextBase subclass"): + + @Flow.model(context_type=int) + def bad_context_type(value: FromContext[int]) -> int: + return value + + @Flow.model + def source(value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value) + + with pytest.raises(TypeError, match="cannot default to a CallableModel"): + + @Flow.model + def bad_default(value: FromContext[int] = source()) -> int: + return value + + with pytest.raises(TypeError, match="return type annotation"): + + @Flow.model + def missing_return(value: int): + return value + + with pytest.raises(TypeError, match="does not support positional-only parameter 'value'"): + + @Flow.model + def bad_positional_only(value: int, /, bonus: FromContext[int]) -> int: + return value + bonus + + with pytest.raises(TypeError, match="does not support variadic positional parameter 'values'"): + + @Flow.model + def bad_varargs(*values: int) -> int: + return sum(values) + + with pytest.raises(TypeError, match="does not support variadic keyword parameter 'values'"): + + @Flow.model + def bad_varkw(**values: int) -> int: + return sum(values.values()) + + @Flow.model + def keyword_only(value: int, *, bonus: FromContext[int]) -> int: + return value + bonus + + assert keyword_only(value=2).flow.compute(bonus=3).value == 5 + + @Flow.model + def keyword_only_context(*, context: 
SimpleContext, offset: int) -> int: + return context.value + offset + + assert keyword_only_context(context=SimpleContext(value=3), offset=4).flow.compute().value == 7 + + def missing_hints(*args, **kwargs): + raise AttributeError("missing hints") + + monkeypatch.setattr(flow_model_module, "get_type_hints", missing_hints) + + @Flow.model + def add(x: int, y: FromContext[int]) -> int: + return x + y + + assert add(x=1).flow.compute(y=2).value == 3 diff --git a/ccflow/tests/test_flow_model_hydra.py b/ccflow/tests/test_flow_model_hydra.py new file mode 100644 index 0000000..981c511 --- /dev/null +++ b/ccflow/tests/test_flow_model_hydra.py @@ -0,0 +1,135 @@ +"""Hydra integration tests for the FromContext-based Flow.model API.""" + +from datetime import date +from pathlib import Path + +from omegaconf import OmegaConf + +from ccflow import CallableModel, DateRangeContext, FlowContext, ModelRegistry + +from .test_flow_model import SimpleContext + +CONFIG_PATH = str(Path(__file__).parent / "config" / "conf_flow.yaml") + + +def setup_function(): + ModelRegistry.root().clear() + + +def teardown_function(): + ModelRegistry.root().clear() + + +def test_basic_loader_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + loader = registry["flow_loader"] + assert isinstance(loader, CallableModel) + assert loader(SimpleContext(value=10)).value == 50 + + +def test_basic_processor_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + processor = registry["flow_processor"] + assert processor(SimpleContext(value=42)).value == "value=42!" 
+ + +def test_two_stage_pipeline_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + transformer = registry["flow_transformer"] + assert transformer(SimpleContext(value=5)).value == 315 + + +def test_three_stage_pipeline_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + stage3 = registry["flow_stage3"] + assert stage3(SimpleContext(value=10)).value == 90 + + +def test_diamond_dependency_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + aggregator = registry["diamond_aggregator"] + assert aggregator(SimpleContext(value=10)).value == 140 + + +def test_date_range_pipeline_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + processor = registry["flow_date_processor"] + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + result = processor(ctx) + + assert "normalized:" in result.value + assert "2024-01-09" in result.value + + +def test_from_context_pipeline_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + loader = registry["contextual_loader_model"] + processor = registry["contextual_processor_model"] + + assert loader.flow.context_inputs == {"start_date": date, "end_date": date} + result = processor.flow.compute(start_date=date(2024, 3, 1), end_date=date(2024, 3, 31)) + assert result.value == "output:data_source:2024-03-01 to 2024-03-31" + assert processor.data is loader + + +def test_registry_name_references_share_instances(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + transformer = registry["flow_transformer"] + source = registry["flow_source"] + assert transformer.source is source + + stage2 = registry["flow_stage2"] + stage3 = registry["flow_stage3"] + assert stage2.stage1_output is registry["flow_stage1"] + assert stage3.stage2_output is stage2 + + +def 
test_instantiate_with_omegaconf(): + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "dynamic_source", + "multiplier": 7, + }, + "contextual": { + "_target_": "ccflow.tests.test_flow_model.contextual_loader", + "source": "warehouse", + }, + } + ) + + registry = ModelRegistry.root() + registry.load_config(cfg) + + assert registry["loader"](SimpleContext(value=3)).value == 21 + assert registry["contextual"].flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)).value == { + "source": "warehouse", + "start_date": "2024-01-01", + "end_date": "2024-01-02", + } + + +def test_flow_context_execution_with_yaml_models(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) + + processor = registry["contextual_processor_model"] + result = processor.flow.compute(FlowContext(start_date=date(2024, 4, 1), end_date=date(2024, 4, 30))) + assert result.value == "output:data_source:2024-04-01 to 2024-04-30" diff --git a/ccflow/tests/test_local_persistence.py b/ccflow/tests/test_local_persistence.py index 4f50b95..3c59ae5 100644 --- a/ccflow/tests/test_local_persistence.py +++ b/ccflow/tests/test_local_persistence.py @@ -1211,7 +1211,7 @@ class TestCreateCcflowModelCloudpickleCrossProcess: id="context_only", ), pytest.param( - # Dynamic context with CallableModel + # Runtime-created context with CallableModel """ from ray.cloudpickle import dump from ccflow import CallableModel, ContextBase, GenericResult, Flow diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md new file mode 100644 index 0000000..7aa0a6c --- /dev/null +++ b/docs/design/flow_model_design.md @@ -0,0 +1,399 @@ +# Flow.model Design + +## Overview + +`@Flow.model` turns a plain Python function into a real `CallableModel`. 
+ +The design is intentionally narrow: + +- ordinary unmarked parameters are regular bound inputs, +- `FromContext[T]` marks the only runtime/contextual inputs, +- `.flow.compute(...)` is the execution entry point for the full DAG, +- `.flow.with_inputs(...)` rewires contextual inputs on one dependency edge, +- upstream `CallableModel`s can still be passed as ordinary arguments. + +The goal is that a reader can look at one function signature and immediately +see: + +1. which values come from runtime context, +2. which values must be bound as regular configuration or dependencies, +3. how to rewrite contextual inputs for one branch of the graph. + +## Core Example + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def foo(a: int, b: FromContext[int]) -> int: + return a + b + + +# Build an instance with a=11 bound, then supply b=12 at runtime: +configured = foo(a=11) +result = configured.flow.compute(b=12) +assert result.value == 23 # .value unwraps the GenericResult wrapper + +# Or create a different instance that stores b=12 as its contextual default: +prefilled = foo(a=11, b=12) +result = prefilled.flow.compute() +assert result.value == 23 +``` + +> **Note:** When the function returns a plain value (like `int` above) instead +> of a `ResultBase` subclass, `@Flow.model` automatically wraps it in +> `GenericResult`. Access the inner value with `.value`. + +This is the core contract: + +- `a` is a regular parameter — it must be bound at construction time, +- `b` is contextual because it is marked with `FromContext[int]` — it can come + from runtime context, a contextual default stored on the model instance, or a + function default, +- `.flow.compute(...)` may carry extra ambient context for upstream graph + branches, but it never binds regular parameters. + +Nothing is being mutated at execution time in the second example. +`prefilled = foo(a=11, b=12)` constructs a different model instance whose +contextual default for `b` is already `12`. 
Because `b` is still contextual, +incoming runtime context can still override that default. + +This means the following is **invalid**: + +```python +foo().flow.compute(a=11, b=12) +# TypeError: compute() cannot bind regular parameter(s): a. +# Bind them at construction time. +``` + +`a` is not contextual, so it must be bound at construction time (`foo(a=11)`). +By contrast, extra ambient fields that are only needed by upstream +`with_inputs(...)` rewrites are allowed on the kwargs entrypoint for +implicit-`FlowContext` graphs. + +## Regular Parameters vs Contextual Parameters + +### Regular Parameters + +Regular parameters are the unmarked ones. + +They can be satisfied by: + +- a literal value, +- a default value from the function signature, +- an upstream `CallableModel`. + +When an upstream model is supplied, `@Flow.model` evaluates it with the current +context and passes the resolved value into the function. This is how you wire +stages together — just pass one model as an argument to another: + +```python +from ccflow import Flow, FlowContext, FromContext + + +@Flow.model +def load_value(value: FromContext[int], offset: int) -> int: + return value + offset + + +@Flow.model +def add(a: int, b: FromContext[int]) -> int: + return a + b + + +# Wire load_value into add's 'a' parameter: +model = add(a=load_value(offset=5)) + +# At runtime, load_value runs first (value=7 + offset=5 = 12), +# then add runs (a=12 + b=12 = 24): +assert model.flow.compute(value=7, b=12).value == 24 +``` + +### Contextual Parameters + +Contextual parameters are the ones marked with `FromContext[...]`. + +They can be satisfied by: + +- runtime context, +- contextual defaults stored on the model instance, +- function defaults. + +They cannot be satisfied by `CallableModel` values. + +A construction-time value for a contextual parameter is still a default, not a +conversion into a regular bound parameter. + +Contextual precedence is: + +1. branch-local `.flow.with_inputs(...)` rewrites, +2. 
incoming runtime context, +3. contextual defaults stored on the model instance, +4. function defaults. + +## `.flow.compute(...)` + +`.flow.compute(...)` is the ergonomic execution entry point for contextual +execution of the whole DAG. + +For generated `@Flow.model` stages it accepts either: + +- keyword arguments that become the ambient runtime context bag, or +- one context object. + +It does not accept both at the same time. + +```python +from ccflow import Flow, FlowContext, FromContext + + +@Flow.model +def add(a: int, b: FromContext[int]) -> int: + return a + b + + +model = add(a=10) +assert model.flow.compute(b=5).value == 15 +assert model.flow.compute(FlowContext(b=6)).value == 16 +``` + +For implicit-`FlowContext` models, the kwargs form is intentionally a DAG +entrypoint: it can include extra fields needed only by upstream transformed +dependencies. Regular parameters are still never read from runtime context. If +the root model has an unbound regular parameter whose name appears in +`compute(**kwargs)`, `compute()` raises early instead of silently treating that +value as configuration. 
+ +```python +from ccflow import Flow, FromContext + + +@Flow.model +def source(value: FromContext[int]) -> int: + return value + + +@Flow.model +def add(left: int, right: int, bonus: FromContext[int]) -> int: + return left + right + bonus + + +base = source() +model = add( + left=base.flow.with_inputs(value=lambda ctx: ctx.value + 1), + right=base.flow.with_inputs(value=lambda ctx: ctx.value + 10), +) + +assert model.flow.context_inputs == {"bonus": int} +assert model.flow.compute(value=5, bonus=100).value == 121 +``` + +If a regular parameter is already bound on the root model, a same-named key in +`compute(**kwargs)` is treated as ambient context for the graph rather than a +rebind of the root parameter: + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def source(a: FromContext[int]) -> int: + return a + + +@Flow.model +def combine(a: int, left: int, bonus: FromContext[int]) -> int: + return a + left + bonus + + +model = combine(a=100, left=source()) + +# Root 'a' stays bound to 100. The runtime 'a=7' is still available to +# upstream graph nodes that read it from context. +assert model.flow.compute(a=7, bonus=5).value == 112 +``` + +`compute()` returns the same result object you would get from `model(context)`, +unless `auto_unwrap=True` is enabled for an auto-wrapped plain return type: + +```python +from ccflow import Flow, FromContext + + +@Flow.model(auto_unwrap=True) +def add(a: int, b: FromContext[int]) -> int: + return a + b + + +result = add(a=10).flow.compute(b=5) +assert result == 15 +``` + +## `.flow.with_inputs(...)` + +`.flow.with_inputs(...)` rewrites contextual inputs locally for one wrapped +dependency. 
+ +```python +from datetime import date, timedelta + +from ccflow import DateRangeContext, Flow, FromContext + + +@Flow.model +def load_revenue(region: str, start_date: FromContext[date], end_date: FromContext[date]) -> float: + days = (end_date - start_date).days + 1 + return 1000.0 + days * 10.0 + + +@Flow.model +def revenue_growth(current: float, previous: float, start_date: FromContext[date], end_date: FromContext[date]) -> dict: + return {"window_end": end_date, "growth_pct": round((current - previous) / previous * 100, 2)} + + +current = load_revenue(region="us") +previous = current.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=30), + end_date=lambda ctx: ctx.end_date - timedelta(days=30), +) + +growth = revenue_growth(current=current, previous=previous) +result = growth(DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31))) +``` + +In this example, `current` and `previous` share the same `load_revenue` model +but see different date windows at runtime. The `with_inputs()` call on +`previous` shifts the dates back 30 days without affecting `current`. + +Key rules: + +- `with_inputs()` only targets contextual fields, +- transforms are branch-local — they only affect the wrapped dependency, not + the entire pipeline, +- chained `with_inputs()` calls merge, with the newest transform winning for a + repeated field. + +## `context_type=...` + +When you want the `FromContext[...]` fields to match an existing nominal +context shape, use `context_type=...`: + +```python +from ccflow import DateRangeContext, Flow, FromContext + + +@Flow.model(context_type=DateRangeContext) +def load_revenue(region: str, start_date: FromContext[date], end_date: FromContext[date]) -> float: + return 125.0 +``` + +That preserves the primary `FromContext[...]` authoring model while letting +callers pass richer context objects whose relevant fields satisfy the declared +`context_type`. 
+ +If the function genuinely needs the runtime context object itself inside the +function body on each call, use a normal `CallableModel` subclass instead of +`@Flow.model`. + +## Introspection APIs + +Generated models expose three useful introspection helpers: + +- `model.flow.context_inputs`: the full contextual contract, +- `model.flow.unbound_inputs`: the contextual fields still required at runtime, +- `model.flow.bound_inputs`: regular bound inputs plus any construction-time + contextual defaults. + +Example: + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: + return a + b + c + + +model = add(a=10) +assert model.flow.context_inputs == {"b": int, "c": int} +assert model.flow.unbound_inputs == {"b": int} +assert model.flow.bound_inputs == {"a": 10} +``` + +## Lazy Dependencies + +`Lazy[T]` defers evaluation of an upstream dependency until the function body +explicitly calls it. This is useful when a dependency is expensive and only +needed conditionally: + +```python +from ccflow import Flow, FlowContext, FromContext, Lazy + + +@Flow.model +def load_value(value: FromContext[int]) -> int: + return value * 10 + + +@Flow.model +def maybe_use(current: int, fallback: Lazy[int], threshold: FromContext[int]) -> int: + if current > threshold: + return current # fallback is never evaluated + return fallback() # evaluate only when needed + + +model = maybe_use(current=50, fallback=load_value()) + +# current (50) > threshold (10), so load_value never runs: +assert model.flow.compute(value=3, threshold=10).value == 50 + +# current (5) <= threshold (10), so load_value runs (3 * 10 = 30): +model2 = maybe_use(current=5, fallback=load_value()) +assert model2.flow.compute(value=3, threshold=10).value == 30 +``` + +Without `Lazy[T]`, the upstream model would always run. With it, the function +controls exactly when (and whether) the dependency executes. 
+ +## When To Use `@Flow.model` + +Use `@Flow.model` when: + +- the stage logic is naturally a plain function, +- you want ordinary arguments to look like ordinary Python function parameters, +- the contextual contract is small and explicit, +- the main goal is easy graph authoring on top of existing ccflow machinery. + +Use a hand-written class-based `CallableModel` when: + +- the model needs custom methods or substantial internal state, +- the full context object is the natural primary interface, +- the stage is no longer best expressed as one function and a small amount of + wiring. + +## Troubleshooting + +**`compute()` says a field is not contextual** + +That field is a regular parameter. Bind it at construction time. Only +`FromContext[...]` fields belong in `compute()`. + +**`with_inputs()` rejects a field** + +`with_inputs()` only rewrites contextual inputs. If you are trying to attach one +stage to another, pass the upstream model as a regular argument at construction +time. + +**A contextual parameter still shows up in `context_inputs` after I bound it** + +That is expected. `context_inputs` reports the full contextual contract. +`unbound_inputs` reports only the contextual values still needed at runtime. + +**A shared dependency runs more than once** + +`@Flow.model` authors the graph cleanly, but execution still follows the normal +ccflow evaluator path. If you need deduplication or graph scheduling, use the +appropriate evaluators and cache settings just as you would for class-based +`CallableModel`s. diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 616e3d8..0da832c 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -6,121 +6,211 @@ - [Evaluators](#evaluators) - [Results](#results) -`ccflow` (Composable Configuration Flow) is a collection of tools for workflow configuration, orchestration, and dependency injection. 
-It is intended to be flexible enough to handle diverse use cases, including data retrieval, validation, transformation, and loading (i.e. ETL workflows), model training, microservice configuration, and automated report generation. +`ccflow` (Composable Configuration Flow) is a collection of tools for workflow +configuration, orchestration, and dependency injection. It is intended to stay +flexible across ETL workflows, model training, report generation, and service +configuration. ## Base Model -Central to `ccflow` is the `BaseModel` class. -`BaseModel` is the base class for models in the `ccflow` framework. -A model is basically a data class (class with attributes). -The naming was inspired by the open source library [Pydantic](https://docs.pydantic.dev/latest/) (`BaseModel` actually inherits from the Pydantic base model class). +`BaseModel` is the core pydantic-based model class used throughout `ccflow`. +Models are regular data objects with validation, serialization, and registry +support. ## Callable Model -`CallableModel` is the base class for a special type of `BaseModel` which can be called. -`CallableModel`'s are called with a context (something that derives from `ContextBase`) and returns a result (something that derives from `ResultBase`). -As an example, you may have a `SQLReader` callable model that when called with a `DateRangeContext` returns a `ArrowResult` (wrapper around a Arrow table) with data in the date range defined by the context by querying some SQL database. +`CallableModel` is a `BaseModel` that can be executed with a context +(`ContextBase`) and produces a result (`ResultBase`). + +### Flow.model Decorator + +`@Flow.model` is the plain-function front door to `CallableModel`. + +It generates a real `CallableModel` class with proper `__call__` and `__deps__` +methods, so it still plugs into the normal evaluator, registry, cache, Hydra, +and serialization machinery. 
+ +If the function returns a plain value instead of a `ResultBase`, the generated +model wraps it in `GenericResult`. + +#### Primary Authoring Model + +`FromContext[T]` is the only marker for runtime/contextual inputs. + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def add(a: int, b: FromContext[int]) -> int: + return a + b + + +model = add(a=10) +assert model.flow.compute(b=5).value == 15 + +prefilled = add(a=10, b=7) +assert prefilled.flow.compute().value == 17 +``` + +That means: + +- `a` is a regular parameter, +- `b` is contextual, +- `.flow.compute(...)` only accepts contextual inputs. + +`prefilled = add(a=10, b=7)` creates a different model instance with a stored +contextual default for `b`. `compute()` does not mutate the model; it resolves +the remaining contextual inputs for that execution. + +Regular parameters can be satisfied by: + +- literal values, +- function defaults, +- upstream `CallableModel`s. + +Contextual parameters can be satisfied by: + +- runtime context, +- contextual defaults stored on the model instance, +- function defaults. + +Contextual parameters cannot be bound to `CallableModel` values. + +#### Nominal Context Validation + +You can keep the `FromContext[...]` style while validating those fields against +an existing context type: + +```python +from datetime import date +from ccflow import DateRangeContext, Flow, FromContext + + +@Flow.model(context_type=DateRangeContext) +def load_data(source: str, start_date: FromContext[date], end_date: FromContext[date]) -> float: + return 125.0 +``` + +If the function genuinely needs the runtime context object itself inside the +function body on each call, write a normal `CallableModel` subclass instead of +using `@Flow.model`. + +#### Composing Dependencies + +Passing an upstream model as an ordinary argument is the main composition story. 
+ +```python +from datetime import date, timedelta +from ccflow import DateRangeContext, Flow, FromContext + + +@Flow.model +def load_revenue(region: str, start_date: FromContext[date], end_date: FromContext[date]) -> float: + days = (end_date - start_date).days + 1 + return 1000.0 + days * 10.0 + + +@Flow.model +def revenue_growth(current: float, previous: float, start_date: FromContext[date], end_date: FromContext[date]) -> dict: + return {"window_end": end_date, "growth_pct": round((current - previous) / previous * 100, 2)} + + +current = load_revenue(region="us") + +# Reuse the same model with a shifted date window for "previous": +previous = current.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=30), + end_date=lambda ctx: ctx.end_date - timedelta(days=30), +) + +growth = revenue_growth(current=current, previous=previous) + +# Execute — current sees Jan 2024, previous sees Dec 2023: +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = growth(ctx) +``` + +#### Deferred Execution Helpers + +`model.flow.compute(...)` accepts either contextual keyword arguments or one +context object and returns the same result as `model(context)`. + +`model.flow.with_inputs(...)` rewrites contextual inputs on one dependency edge. +It only accepts contextual fields and remains branch-local. + +Generated models also expose introspection helpers: + +```python +model = add(a=10) +model.flow.context_inputs # {"b": int} — the full contextual contract +model.flow.unbound_inputs # {"b": int} — contextual fields still needed at runtime +model.flow.bound_inputs # {"a": 10} — all construction-time values +``` + +#### Lazy Dependencies + +`Lazy[T]` is the lazy type-level marker for dependency parameters. 
+ +```python +from ccflow import Flow, FromContext, Lazy + + +@Flow.model +def load_value(value: FromContext[int]) -> int: + return value * 10 + + +@Flow.model +def choose(current: int, deferred: Lazy[int], threshold: FromContext[int]) -> int: + if current > threshold: + return current + return deferred() +``` + +Use `Lazy[T]` when a dependency is expensive and the function should decide +whether to execute it. ## Model Registry -A `ModelRegistry` is a named collection of models. -A `ModelRegistry` can be loaded from YAML configuration, which means you can define a collection of models using YAML. -This is really powerful because this gives you a easy way to define a collection of Python objects via configuration. +The model registry lets you register models by name and resolve them later, +including from config-driven workflows. -## Models +- root registry access: `ModelRegistry.root()` +- add and remove models by name +- reuse shared instances through registry references -Although you are free to define your own models (`BaseModel` implementations) to use in your flow graph, -`ccflow` comes with some models that you can use off the shelf to solve common problems. `ccflow` comes with a range of models for reading data. +## Models -The following table summarizes the available models. +The `ccflow.models` package contains concrete model implementations that build +on the framework primitives. -> [!NOTE] -> -> Some models are still in the process of being open sourced. +Use these when you want reusable, prebuilt model classes instead of authoring +your own `CallableModel` or `@Flow.model` stage. ## Publishers -`ccflow` also comes with a range of models for writing data. -These are referred to as publishers. -You can "chain" publishers and callable models using `PublisherModel` to call a `CallableModel` and publish the results in one step. -In fact, `ccflow` comes with several implementations of `PublisherModel` for common publishing use cases. 
- -The following table summarizes the "publisher" models. - -> [!NOTE] -> -> Some models are still in the process of being open sourced. - -| Name | Path | Description | -| :--------------------------- | :------------------ | :---------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `DictTemplateFilePublisher` | `ccflow.publishers` | Publish data to a file after populating a Jinja template. | -| `GenericFilePublisher` | `ccflow.publishers` | Publish data using a generic "dump" Callable. Uses `smart_open` under the hood so that local and cloud paths are supported. | -| `JSONPublisher` | `ccflow.publishers` | Publish data to file in JSON format. | -| `PandasFilePublisher` | `ccflow.publishers` | Publish a pandas data frame to a file using an appropriate method on pd.DataFrame. For large-scale exporting (using parquet), see `PandasParquetPublisher`. | -| `NarwhalsFilePublisher` | `ccflow.publishers` | Publish a narwhals data frame to a file using an appropriate method on nw.DataFrame. | -| `PicklePublisher` | `ccflow.publishers` | Publish data to a pickle file. | -| `PydanticJSONPublisher` | `ccflow.publishers` | Publish a pydantic model to a json file. See [Pydantic modeljson](https://docs.pydantic.dev/latest/concepts/serialization/#modelmodel_dump) | -| `YAMLPublisher` | `ccflow.publishers` | Publish data to file in YAML format. | -| `CompositePublisher` | `ccflow.publishers` | Highly configurable, publisher that decomposes a pydantic BaseModel or a dictionary into pieces and publishes each piece separately. | -| `PrintPublisher` | `ccflow.publishers` | Print data using python standard print. | -| `LogPublisher` | `ccflow.publishers` | Print data using python standard logging. | -| `PrintJSONPublisher` | `ccflow.publishers` | Print data in JSON format. | -| `PrintYAMLPublisher` | `ccflow.publishers` | Print data in YAML format. 
| -| `PrintPydanticJSONPublisher` | `ccflow.publishers` | Print pydantic model as json. See https://docs.pydantic.dev/latest/concepts/serialization/#modelmodel_dump_json | -| `ArrowDatasetPublisher` | *Coming Soon!* | | -| `PandasDeltaPublisher` | *Coming Soon!* | | -| `EmailPublisher` | *Coming Soon!* | | -| `MatplotlibFilePublisher` | *Coming Soon!* | | -| `MLFlowArtifactPublisher` | *Coming Soon!* | | -| `MLFlowPublisher` | *Coming Soon!* | | -| `PandasParquetPublisher` | *Coming Soon!* | | -| `PlotlyFilePublisher` | *Coming Soon!* | | -| `XArrayPublisher` | *Coming Soon!* | | +Publishers handle result publication and side-effectful output sinks. + +They are useful when a workflow result needs to be written to an external +system rather than only returned to the caller. ## Evaluators -`ccflow` comes with "evaluators" that allows you to evaluate (i.e. run) `CallableModel` s in different ways. - -The following table summarizes the "evaluator" models. - -> [!NOTE] -> -> Some models are still in the process of being open sourced. - -| Name | Path | Description | -| :---------------------------------- | :------------------ | :----------------------------------------------------------------------------------------------------------------------------- | -| `LazyEvaluator` | `ccflow.evaluators` | Evaluator that only actually runs the callable once an attribute of the result is queried (by hooking into `__getattribute__`) | -| `LoggingEvaluator` | `ccflow.evaluators` | Evaluator that logs information about evaluating the callable. | -| `MemoryCacheEvaluator` | `ccflow.evaluators` | Evaluator that caches results in memory. | -| `MultiEvaluator` | `ccflow.evaluators` | An evaluator that combines multiple evaluators. | -| `GraphEvaluator` | `ccflow.evaluators` | Evaluator that evaluates the dependency graph of callable models in topologically sorted order. 
| -| `ChunkedDateRangeEvaluator` | *Coming Soon!* | | -| `ChunkedDateRangeResultsAggregator` | *Coming Soon!* | | -| `RayChunkedDateRangeEvaluator` | *Coming Soon!* | | -| `DependencyTrackingEvaluator` | *Coming Soon!* | | -| `DiskCacheEvaluator` | *Coming Soon!* | | -| `ParquetCacheEvaluator` | *Coming Soon!* | | -| `RayCacheEvaluator` | *Coming Soon!* | | -| `RayGraphEvaluator` | *Coming Soon!* | | -| `RayDelayedDistributedEvaluator` | *Coming Soon!* | | -| `ParquetCacheEvaluator` | *Coming Soon!* | | -| `RetryEvaluator` | *Coming Soon!* | | +Evaluators control how `CallableModel`s execute. + +Key point for `@Flow.model`: it does not create a new execution engine. It +authors models that still run through the existing evaluator stack. + +Depending on your evaluator setup, you can add logging, caching, graph-aware +execution, or custom execution policies. ## Results -A Result is an object that holds the results from a callable model. It provides the equivalent of a strongly typed dictionary where the keys and schema are known upfront. - -The following table summarizes the "result" models. - -| Name | Path | Description | -| :------------------------ | :----------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `GenericResult` | `ccflow.result` | A generic result (holds anything). | -| `DictResult` | `ccflow.result` | A generic dict (key/value) result. | -| `ArrowResult` | `ccflow.result.pyarrow` | Holds an arrow table. | -| `ArrowDateRangeResult` | `ccflow.result.pyarrow` | Extension of `ArrowResult` for representing a table over a date range that can be divided by date, such that generation of any sub-range of dates gives the same results as the original table filtered for those dates. 
| -| `NarwhalsResult` | `ccflow.result.narwhals` | Holds a narwhals `DataFrame` or `LazyFrame`. | -| `NarwhalsDataFrameResult` | `ccflow.result.narwhals` | Holds a narwhals eager `DataFrame`. | -| `NumpyResult` | `ccflow.result.numpy` | Holds a numpy array. | -| `PandasResult` | `ccflow.result.pandas` | Holds a pandas dataframe. | -| `XArrayResult` | `ccflow.result.xarray` | Holds an xarray. | +`ResultBase` is the common base class for workflow results. + +`GenericResult[T]` is the default wrapper used when: + +- a model naturally wants one value payload, or +- a `@Flow.model` function returns a plain Python value instead of a concrete + `ResultBase` subclass. diff --git a/examples/config/flow_model_hydra_builder_demo.yaml b/examples/config/flow_model_hydra_builder_demo.yaml new file mode 100644 index 0000000..5579a5c --- /dev/null +++ b/examples/config/flow_model_hydra_builder_demo.yaml @@ -0,0 +1,24 @@ +# Hydra config for examples/flow_model_hydra_builder_demo.py +# +# Pattern: +# - configure static pipeline specs in YAML +# - use model_alias to pass already-registered models into a plain Python builder +# - keep runtime context as runtime inputs, supplied later at execution time + +current_revenue: + _target_: examples.flow_model_hydra_builder_demo.load_revenue + region: us + +week_over_week: + _target_: examples.flow_model_hydra_builder_demo.build_comparison + current: + _target_: ccflow.compose.model_alias + model_name: current_revenue + comparison: week_over_week + +month_over_month: + _target_: examples.flow_model_hydra_builder_demo.build_comparison + current: + _target_: ccflow.compose.model_alias + model_name: current_revenue + comparison: month_over_month diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py new file mode 100644 index 0000000..d4dd633 --- /dev/null +++ b/examples/flow_model_example.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python +"""Main `@Flow.model` example. + +Shows how to: + +1. 
define stages as plain Python functions, +2. compose stages by passing upstream models as ordinary arguments, +3. rewrite contextual inputs on one dependency edge with `.flow.with_inputs(...)`, +4. execute either as `model(context)` or `model.flow.compute(...)`. + +Run with: + python examples/flow_model_example.py +""" + +from datetime import date, timedelta + +from ccflow import DateRangeContext, Flow, FromContext + + +@Flow.model(context_type=DateRangeContext) +def load_revenue(region: str, start_date: FromContext[date], end_date: FromContext[date]) -> float: + """Return synthetic revenue for one reporting window.""" + days = (end_date - start_date).days + 1 + region_base = {"us": 1000.0, "eu": 850.0}.get(region, 900.0) + days_since_2024 = (end_date - date(2024, 1, 1)).days + trend = days_since_2024 * 2.5 + return round(region_base + days * 8.0 + trend, 2) + + +@Flow.model(context_type=DateRangeContext) +def revenue_change( + current: float, + previous: float, + label: str, + days_back: int, + start_date: FromContext[date], + end_date: FromContext[date], +) -> dict: + """Compare the current window against a shifted previous window.""" + previous_start = start_date - timedelta(days=days_back) + previous_end = end_date - timedelta(days=days_back) + growth_pct = round((current - previous) / previous * 100, 2) + return { + "comparison": label, + "current_window": f"{start_date} -> {end_date}", + "previous_window": f"{previous_start} -> {previous_end}", + "current": current, + "previous": previous, + "growth_pct": growth_pct, + } + + +def shifted_window(model, *, days_back: int): + """Reuse one upstream model with a shifted runtime window.""" + return model.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=days_back), + end_date=lambda ctx: ctx.end_date - timedelta(days=days_back), + ) + + +def build_week_over_week_pipeline(region: str): + """Build one reusable comparison pipeline.""" + current = load_revenue(region=region) + previous = 
shifted_window(current, days_back=7) + return revenue_change( + current=current, + previous=previous, + label="week_over_week", + days_back=7, + ) + + +def main() -> None: + print("=" * 64) + print("Flow.model Example") + print("=" * 64) + + pipeline = build_week_over_week_pipeline(region="us") + ctx = DateRangeContext( + start_date=date(2024, 3, 1), + end_date=date(2024, 3, 31), + ) + + direct = pipeline(ctx) + computed = pipeline.flow.compute( + start_date=ctx.start_date, + end_date=ctx.end_date, + ) + + print("\nPipeline:") + print(" current input:", pipeline.current) + print(" previous input:", pipeline.previous) + + print("\nExecution:") + print(f" direct == computed: {direct == computed}") + + print("\nResult:") + for key, value in computed.value.items(): + print(f" {key}: {value}") + + +if __name__ == "__main__": + main() diff --git a/examples/flow_model_hydra_builder_demo.py b/examples/flow_model_hydra_builder_demo.py new file mode 100644 index 0000000..0e2aaa1 --- /dev/null +++ b/examples/flow_model_hydra_builder_demo.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python +"""Hydra + Flow.model builder demo. + +This example shows a clean way to mix: + +1. ergonomic `@Flow.model` pipeline wiring in Python, and +2. Hydra / ModelRegistry configuration for static pipeline specs. + +The pattern is: + +- keep runtime context (`start_date`, `end_date`) as runtime inputs, +- use a plain Python builder function for graph construction, +- let Hydra instantiate that builder and register the returned model. 
+ +Run with: + python examples/flow_model_hydra_builder_demo.py +""" + +from calendar import monthrange +from datetime import date, timedelta +from pathlib import Path +from typing import Literal + +from ccflow import BoundModel, CallableModel, DateRangeContext, Flow, FromContext, ModelRegistry + +CONFIG_PATH = Path(__file__).with_name("config") / "flow_model_hydra_builder_demo.yaml" +ComparisonName = Literal["week_over_week", "month_over_month"] + + +@Flow.model(context_type=DateRangeContext) +def load_revenue(region: str, start_date: FromContext[date], end_date: FromContext[date]) -> float: + """Return synthetic revenue for a date window.""" + days = (end_date - start_date).days + 1 + region_base = {"us": 1000.0, "eu": 850.0, "apac": 920.0}.get(region, 900.0) + days_since_2024 = (end_date - date(2024, 1, 1)).days + trend = days_since_2024 * 2.5 + return round(region_base + days * 8.0 + trend, 2) + + +@Flow.model(context_type=DateRangeContext) +def revenue_change( + current: float, + previous: float, + comparison: ComparisonName, + start_date: FromContext[date], + end_date: FromContext[date], +) -> dict: + """Compare the current window against a shifted previous window.""" + growth = (current - previous) / previous + previous_start, previous_end = comparison_window(start_date, end_date, comparison) + return { + "comparison": comparison, + "current_window": f"{start_date} -> {end_date}", + "previous_window": f"{previous_start} -> {previous_end}", + "current": current, + "previous": previous, + "delta": round(current - previous, 2), + "growth_pct": round(growth * 100, 2), + } + + +def comparison_window(start_date: date, end_date: date, comparison: ComparisonName) -> tuple[date, date]: + """Return the previous window for a named comparison policy.""" + if comparison == "week_over_week": + return start_date - timedelta(days=7), end_date - timedelta(days=7) + + if start_date.day != 1: + raise ValueError("month_over_month requires start_date to be the first day of a 
month") + if start_date.year != end_date.year or start_date.month != end_date.month: + raise ValueError("month_over_month requires the current window to stay within one calendar month") + expected_end = date(end_date.year, end_date.month, monthrange(end_date.year, end_date.month)[1]) + if end_date != expected_end: + raise ValueError("month_over_month requires end_date to be the last day of that month") + + previous_year = start_date.year if start_date.month > 1 else start_date.year - 1 + previous_month = start_date.month - 1 if start_date.month > 1 else 12 + previous_start = date(previous_year, previous_month, 1) + previous_end = date(previous_year, previous_month, monthrange(previous_year, previous_month)[1]) + return previous_start, previous_end + + +def comparison_input(model: CallableModel, comparison: ComparisonName) -> BoundModel: + """Apply a named comparison policy to one dependency.""" + return model.flow.with_inputs( + start_date=lambda ctx: comparison_window(ctx.start_date, ctx.end_date, comparison)[0], + end_date=lambda ctx: comparison_window(ctx.start_date, ctx.end_date, comparison)[1], + ) + + +def build_comparison(current: CallableModel, *, comparison: ComparisonName): + """Hydra-friendly builder that returns a configured comparison model.""" + previous = comparison_input(current, comparison) + return revenue_change( + current=current, + previous=previous, + comparison=comparison, + ) + + +def main() -> None: + registry = ModelRegistry.root() + registry.clear() + try: + registry.load_config_from_path(str(CONFIG_PATH), overwrite=True) + + week_over_week = registry["week_over_week"] + month_over_month = registry["month_over_month"] + + ctx = DateRangeContext( + start_date=date(2024, 3, 1), + end_date=date(2024, 3, 31), + ) + + print("=" * 68) + print("Hydra + Flow.model Builder Demo") + print("=" * 68) + print("\nLoaded from config:") + print(" current_revenue:", registry["current_revenue"]) + print(" week_over_week:", week_over_week) + print(" 
month_over_month:", month_over_month) + + week_over_week_result = week_over_week.flow.compute( + start_date=ctx.start_date, + end_date=ctx.end_date, + ).value + month_over_month_result = month_over_month(ctx).value + + print("\nWeek-over-week:") + for key, value in week_over_week_result.items(): + print(f" {key}: {value}") + + print("\nMonth-over-month:") + for key, value in month_over_month_result.items(): + print(f" {key}: {value}") + finally: + registry.clear() + + +if __name__ == "__main__": + main()