From de87bb26551e8e4feba8c2ad3dbff99f701446e4 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Wed, 31 Dec 2025 03:56:58 -0500 Subject: [PATCH 01/26] Add ability to define dynamic context from kwargs Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 165 ++++++++++++-- ccflow/tests/test_callable.py | 381 +++++++++++++++++++++++++++++++++ ccflow/tests/test_evaluator.py | 70 +++++- 3 files changed, 599 insertions(+), 17 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index b09eaea..b6580c9 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -12,10 +12,11 @@ """ import abc +import inspect import logging -from functools import lru_cache, wraps +from functools import lru_cache, partial, wraps from inspect import Signature, isclass, signature -from typing import Any, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin +from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator from typing_extensions import override @@ -27,6 +28,7 @@ ResultBase, ResultType, ) +from .local_persistence import create_ccflow_model from .validators import str_to_log_level __all__ = ( @@ -44,6 +46,7 @@ "EvaluatorBase", "Evaluator", "WrapperModel", + "dynamic_context", ) log = logging.getLogger(__name__) @@ -268,14 +271,31 @@ def get_evaluation_context(model: CallableModelType, context: ContextType, as_di def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = None, **kwargs): if not isinstance(model, CallableModel): raise TypeError(f"Can only decorate methods on CallableModels (not {type(model)}) with the flow decorator.") - if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not ( - get_origin(model.context_type) is Union and type(None) in 
get_args(model.context_type) - ): - raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase") - if (not isclass(model.result_type) or not issubclass(model.result_type, ResultBase)) and not ( - get_origin(model.result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(model.result_type)) + + # Check if this is a dynamic_context decorated method + has_dynamic_context = hasattr(fn, "__dynamic_context__") + if has_dynamic_context: + method_context_type = fn.__dynamic_context__ + else: + method_context_type = model.context_type + + # Validate context type (skip for dynamic contexts which are always valid ContextBase subclasses) + if not has_dynamic_context: + if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not ( + get_origin(model.context_type) is Union and type(None) in get_args(model.context_type) + ): + raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase") + + # Validate result type - use __result_type__ for dynamic contexts if available + if has_dynamic_context and hasattr(fn, "__result_type__"): + method_result_type = fn.__result_type__ + else: + method_result_type = model.result_type + if (not isclass(method_result_type) or not issubclass(method_result_type, ResultBase)) and not ( + get_origin(method_result_type) is Union and all(isclass(t) and issubclass(t, ResultBase) for t in get_args(method_result_type)) ): - raise TypeError(f"Result type {model.result_type} must be a subclass of ResultBase") + raise TypeError(f"Result type {method_result_type} must be a subclass of ResultBase") + if self._deps and fn.__name__ != "__deps__": raise ValueError("Can only apply Flow.deps decorator to __deps__") if context is Signature.empty: @@ -285,18 +305,18 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = context = kwargs else: raise TypeError( - f"{fn.__name__}() missing 1 required positional argument: 
'context' of type {model.context_type}, or kwargs to construct it" + f"{fn.__name__}() missing 1 required positional argument: 'context' of type {method_context_type}, or kwargs to construct it" ) elif kwargs: # Kwargs passed in as well as context. Not allowed raise TypeError(f"{fn.__name__}() was passed a context and got an unexpected keyword argument '{next(iter(kwargs.keys()))}'") # Type coercion on input. We do this here (rather than relying on ModelEvaluationContext) as it produces a nicer traceback/error message - if not isinstance(context, model.context_type): - if get_origin(model.context_type) is Union and type(None) in get_args(model.context_type): - model_context_type = [t for t in get_args(model.context_type) if t is not type(None)][0] + if not isinstance(context, method_context_type): + if get_origin(method_context_type) is Union and type(None) in get_args(method_context_type): + coerce_context_type = [t for t in get_args(method_context_type) if t is not type(None)][0] else: - model_context_type = model.context_type - context = model_context_type.model_validate(context) + coerce_context_type = method_context_type + context = coerce_context_type.model_validate(context) if fn != getattr(model.__class__, fn.__name__).__wrapped__: # This happens when super().__call__ is used when implementing a CallableModel that derives from another one. 
@@ -313,6 +333,13 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = wrap.get_evaluator = self.get_evaluator wrap.get_options = self.get_options wrap.get_evaluation_context = get_evaluation_context + + # Preserve dynamic context attributes for introspection + if hasattr(fn, "__dynamic_context__"): + wrap.__dynamic_context__ = fn.__dynamic_context__ + if hasattr(fn, "__result_type__"): + wrap.__result_type__ = fn.__result_type__ + return wrap @@ -417,6 +444,49 @@ def deps(*args, **kwargs): # Note that the code below is executed only once return FlowOptionsDeps(**kwargs) + @staticmethod + def dynamic_call(*args, **kwargs): + """Decorator for methods that creates a dynamic context from the function signature. + + This combines @Flow.call and @dynamic_context into a single decorator, allowing + you to define the context inline in the function signature instead of creating + a separate context class. + + Example: + class MyModel(CallableModel): + @Flow.dynamic_call + def __call__(self, *, a: int, b: str = "default") -> GenericResult: + return GenericResult(value=f"{a}-{b}") + + model = MyModel() + model(a=42) # Works with kwargs + model(a=42, b="test") # Also works + + Args: + *args: When used without arguments, the decorated function + **kwargs: FlowOptions parameters (log_level, verbose, validate_result, etc.) 
+ plus dynamic_context options: + - parent: Optional parent context class to inherit from + """ + # Import here to avoid circular import at module level + from ccflow.callable import dynamic_context + + # Extract dynamic_context-specific options + parent = kwargs.pop("parent", None) + + if len(args) == 1 and callable(args[0]): + # No arguments to decorator (@Flow.dynamic_call) + fn = args[0] + wrapped = dynamic_context(fn, parent=parent) + return Flow.call(wrapped) + else: + # Arguments to decorator (@Flow.dynamic_call(...)) + def decorator(fn): + wrapped = dynamic_context(fn, parent=parent) + return Flow.call(**kwargs)(wrapped) + + return decorator + # ***************************************************************************** # Define "Evaluators" and associated types @@ -754,3 +824,68 @@ def _validate_callable_model_generic_type(cls, m, handler, info): CallableModelGenericType = CallableModelGeneric + + +# ***************************************************************************** +# Dynamic Context Decorator +# ***************************************************************************** + + +def dynamic_context(func: Callable = None, *, parent: Type[ContextBase] = None) -> Callable: + """Decorator that creates a dynamic context class from function parameters. + + This decorator extracts the parameters from a function signature and creates + a dynamic ContextBase subclass whose fields correspond to those parameters. + The decorated function is then wrapped to accept the context object and + unpack it into keyword arguments. 
+ + Example: + class MyCallable(CallableModel): + @Flow.dynamic_call # or @Flow.call @dynamic_context + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = MyCallable() + model(x=42, y="hello") # Works with kwargs + """ + if func is None: + return partial(dynamic_context, parent=parent) + + sig = signature(func) + base_class = parent or ContextBase + + # Validate parent fields are in function signature + if parent is not None: + parent_fields = set(parent.model_fields.keys()) - set(ContextBase.model_fields.keys()) + sig_params = set(sig.parameters.keys()) - {"self"} + missing = parent_fields - sig_params + if missing: + raise TypeError(f"Parent context fields {missing} must be included in function signature") + + # Build fields from parameters (skip 'self'), pydantic validates types + fields = {} + for name, param in sig.parameters.items(): + if name == "self": + continue + default = ... if param.default is inspect.Parameter.empty else param.default + fields[name] = (param.annotation, default) + + # Create dynamic context class + dyn_context = create_ccflow_model(f"{func.__qualname__}_DynamicContext", __base__=base_class, **fields) + + @wraps(func) + def wrapper(self, context): + fn_kwargs = {name: getattr(context, name) for name in fields} + return func(self, **fn_kwargs) + + # Must set __signature__ so CallableModel validation sees 'context' parameter + wrapper.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=dyn_context), + ], + return_annotation=sig.return_annotation, + ) + wrapper.__dynamic_context__ = dyn_context + wrapper.__result_type__ = sig.return_annotation + return wrapper diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 43f86b5..444d496 100644 --- a/ccflow/tests/test_callable.py +++ 
b/ccflow/tests/test_callable.py @@ -20,6 +20,7 @@ ResultBase, ResultType, WrapperModel, + dynamic_context, ) from ccflow.local_persistence import LOCAL_ARTIFACTS_MODULE_NAME @@ -783,3 +784,383 @@ class MyCallableParent_bad_decorator(MyCallableParent): @Flow.deps def foo(self, context): return [] + + +# ============================================================================= +# Tests for dynamic_context decorator +# ============================================================================= + + +class TestDynamicContext(TestCase): + """Tests for the @dynamic_context decorator.""" + + def test_basic_usage_with_kwargs(self): + """Test basic dynamic_context usage with keyword arguments.""" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = DynamicCallable() + + # Call with kwargs + result = model(x=42, y="hello") + self.assertEqual(result.value, "42-hello") + + # Call with default + result = model(x=10) + self.assertEqual(result.value, "10-default") + + def test_dynamic_context_attribute(self): + """Test that __dynamic_context__ attribute is set.""" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, a: int, b: str) -> GenericResult: + return GenericResult(value=f"{a}-{b}") + + # The __call__ method should have __dynamic_context__ + call_method = DynamicCallable.__call__ + self.assertTrue(hasattr(call_method, "__wrapped__")) + # Access the inner function's __dynamic_context__ + inner = call_method.__wrapped__ + self.assertTrue(hasattr(inner, "__dynamic_context__")) + + dyn_ctx = inner.__dynamic_context__ + self.assertTrue(issubclass(dyn_ctx, ContextBase)) + self.assertIn("a", dyn_ctx.model_fields) + self.assertIn("b", dyn_ctx.model_fields) + + def test_dynamic_context_is_registered(self): + """Test that the dynamic context is registered for serialization.""" + + class 
DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, value: int) -> GenericResult: + return GenericResult(value=value) + + inner = DynamicCallable.__call__.__wrapped__ + dyn_ctx = inner.__dynamic_context__ + + # Should have __ccflow_import_path__ set + self.assertTrue(hasattr(dyn_ctx, "__ccflow_import_path__")) + self.assertTrue(dyn_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + + def test_call_with_context_object(self): + """Test calling with a context object instead of kwargs.""" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = DynamicCallable() + + # Get the dynamic context class + dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + + # Create a context object + ctx = dyn_ctx(x=99, y="context") + result = model(ctx) + self.assertEqual(result.value, "99-context") + + def test_with_parent_context(self): + """Test dynamic_context with parent context class.""" + + class ParentContext(ContextBase): + base_value: str = "base" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context(parent=ParentContext) + def __call__(self, *, x: int, base_value: str) -> GenericResult: + return GenericResult(value=f"{x}-{base_value}") + + # Get dynamic context + dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + + # Should inherit from ParentContext + self.assertTrue(issubclass(dyn_ctx, ParentContext)) + + # Should have both fields + self.assertIn("base_value", dyn_ctx.model_fields) + self.assertIn("x", dyn_ctx.model_fields) + + # Create context with parent field + ctx = dyn_ctx(x=42, base_value="custom") + self.assertEqual(ctx.base_value, "custom") + self.assertEqual(ctx.x, 42) + + def test_parent_fields_must_be_in_signature(self): + """Test that parent fields must be included in function signature.""" + + class 
ParentContext(ContextBase): + required_field: str + + with self.assertRaises(TypeError) as cm: + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context(parent=ParentContext) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + self.assertIn("required_field", str(cm.exception)) + + def test_cloudpickle_roundtrip(self): + """Test cloudpickle roundtrip for dynamic context callable.""" + + class DynamicCallable(CallableModel): + multiplier: int = 2 + + @Flow.call + @dynamic_context + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x * self.multiplier) + + model = DynamicCallable(multiplier=3) + + # Test roundtrip + restored = rcploads(rcpdumps(model)) + + result = restored(x=10) + self.assertEqual(result.value, 30) + + def test_ray_task_execution(self): + """Test dynamic context callable in Ray task.""" + + class DynamicCallable(CallableModel): + factor: int = 2 + + @Flow.call + @dynamic_context + def __call__(self, *, x: int, y: int = 1) -> GenericResult: + return GenericResult(value=(x + y) * self.factor) + + @ray.remote + def run_callable(model, **kwargs): + return model(**kwargs).value + + model = DynamicCallable(factor=5) + + with ray.init(num_cpus=1): + result = ray.get(run_callable.remote(model, x=10, y=2)) + + self.assertEqual(result, 60) # (10 + 2) * 5 + + def test_multiple_dynamic_context_methods(self): + """Test callable with multiple dynamic_context decorated methods.""" + + class MultiMethodCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, a: int) -> GenericResult: + return GenericResult(value=a) + + @dynamic_context + def other_method(self, *, b: str, c: float = 1.0) -> GenericResult: + return GenericResult(value=f"{b}-{c}") + + model = MultiMethodCallable() + + # Test __call__ + result1 = model(a=42) + self.assertEqual(result1.value, 42) + + # Test other_method (without Flow.call, just the dynamic_context wrapper) + # Need to create the context 
manually + other_ctx = model.other_method.__dynamic_context__ + ctx = other_ctx(b="hello", c=2.5) + result2 = model.other_method(ctx) + self.assertEqual(result2.value, "hello-2.5") + + def test_context_type_property_works(self): + """Test that type_ property works on the dynamic context.""" + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + ctx = dyn_ctx(x=42) + + # type_ should work and be importable + type_path = str(ctx.type_) + self.assertIn("_Local_", type_path) + self.assertEqual(ctx.type_.object, dyn_ctx) + + def test_complex_field_types(self): + """Test dynamic_context with complex field types.""" + from typing import List, Optional + + class DynamicCallable(CallableModel): + @Flow.call + @dynamic_context + def __call__( + self, + *, + items: List[int], + name: Optional[str] = None, + count: int = 0, + ) -> GenericResult: + total = sum(items) + count + return GenericResult(value=f"{name}:{total}" if name else str(total)) + + model = DynamicCallable() + + result = model(items=[1, 2, 3], name="test", count=10) + self.assertEqual(result.value, "test:16") + + result = model(items=[5, 5]) + self.assertEqual(result.value, "10") + + +class TestFlowDynamicCall(TestCase): + """Tests for @Flow.dynamic_call decorator.""" + + def test_basic_usage(self): + """Test basic @Flow.dynamic_call usage.""" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + + model = DynamicCallable() + + result = model(x=42, y="hello") + self.assertEqual(result.value, "42-hello") + + result = model(x=10) + self.assertEqual(result.value, "10-default") + + def test_dynamic_context_attributes_preserved(self): + """Test that __dynamic_context__ and __result_type__ are directly accessible.""" + + class 
DynamicCallable(CallableModel): + @Flow.dynamic_call + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + # Should be directly accessible without traversing __wrapped__ chain + method = DynamicCallable.__call__ + self.assertTrue(hasattr(method, "__dynamic_context__")) + self.assertTrue(hasattr(method, "__result_type__")) + self.assertTrue(issubclass(method.__dynamic_context__, ContextBase)) + self.assertEqual(method.__result_type__, GenericResult) + + def test_model_result_type_property(self): + """Test that model.result_type returns correct type for dynamic contexts.""" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + model = DynamicCallable() + self.assertEqual(model.result_type, GenericResult) + + def test_with_parent_context(self): + """Test @Flow.dynamic_call with parent context.""" + + class ParentContext(ContextBase): + base_value: str = "base" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call(parent=ParentContext) + def __call__(self, *, x: int, base_value: str) -> GenericResult: + return GenericResult(value=f"{x}-{base_value}") + + model = DynamicCallable() + + # Get dynamic context by traversing __wrapped__ chain + dyn_ctx = _find_dynamic_context(DynamicCallable.__call__) + + # Should inherit from ParentContext + self.assertTrue(issubclass(dyn_ctx, ParentContext)) + + # Call should work, uses parent default + result = model(x=42, base_value="custom") + self.assertEqual(result.value, "42-custom") + + def test_with_flow_options(self): + """Test @Flow.dynamic_call with FlowOptions parameters.""" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call(validate_result=False) + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x) + + model = DynamicCallable() + result = model(x=42) + self.assertEqual(result.value, 42) + + def test_cloudpickle_roundtrip(self): + """Test cloudpickle 
roundtrip with @Flow.dynamic_call.""" + + class DynamicCallable(CallableModel): + multiplier: int = 2 + + @Flow.dynamic_call + def __call__(self, *, x: int) -> GenericResult: + return GenericResult(value=x * self.multiplier) + + model = DynamicCallable(multiplier=3) + restored = rcploads(rcpdumps(model)) + + result = restored(x=10) + self.assertEqual(result.value, 30) + + def test_ray_task(self): + """Test @Flow.dynamic_call in Ray task.""" + + class DynamicCallable(CallableModel): + factor: int = 2 + + @Flow.dynamic_call + def __call__(self, *, x: int, y: int = 1) -> GenericResult: + return GenericResult(value=(x + y) * self.factor) + + @ray.remote + def run_callable(model, **kwargs): + return model(**kwargs).value + + model = DynamicCallable(factor=5) + + with ray.init(num_cpus=1): + result = ray.get(run_callable.remote(model, x=10, y=2)) + + self.assertEqual(result, 60) + + def test_dynamic_context_is_registered(self): + """Test that the dynamic context from @Flow.dynamic_call is registered.""" + + class DynamicCallable(CallableModel): + @Flow.dynamic_call + def __call__(self, *, value: int) -> GenericResult: + return GenericResult(value=value) + + # Find dynamic context by traversing __wrapped__ chain + dyn_ctx = _find_dynamic_context(DynamicCallable.__call__) + + self.assertTrue(hasattr(dyn_ctx, "__ccflow_import_path__")) + self.assertTrue(dyn_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + + +def _find_dynamic_context(func): + """Helper to find __dynamic_context__ by traversing the __wrapped__ chain.""" + visited = set() + current = func + while current is not None and id(current) not in visited: + visited.add(id(current)) + if hasattr(current, "__dynamic_context__"): + return current.__dynamic_context__ + current = getattr(current, "__wrapped__", None) + return None diff --git a/ccflow/tests/test_evaluator.py b/ccflow/tests/test_evaluator.py index cc34155..34f3f7e 100644 --- a/ccflow/tests/test_evaluator.py +++ 
b/ccflow/tests/test_evaluator.py @@ -1,9 +1,21 @@ from datetime import date from unittest import TestCase -from ccflow import DateContext, Evaluator, ModelEvaluationContext +import pytest -from .evaluators.util import MyDateCallable +from ccflow import CallableModel, DateContext, Evaluator, Flow, ModelEvaluationContext + +from .evaluators.util import MyDateCallable, MyResult + + +class MyDynamicDateCallable(CallableModel): + """Dynamic context version of MyDateCallable for testing evaluators.""" + + offset: int + + @Flow.dynamic_call(parent=DateContext) + def __call__(self, *, date: date) -> MyResult: + return MyResult(x=date.day + self.offset) class TestEvaluator(TestCase): @@ -32,3 +44,57 @@ def test_evaluator_deps(self): evaluator = Evaluator() out2 = evaluator.__deps__(model_evaluation_context) self.assertEqual(out2, out) + + +@pytest.mark.parametrize( + "callable_class", + [MyDateCallable, MyDynamicDateCallable], + ids=["standard", "dynamic"], +) +class TestEvaluatorParametrized: + """Test evaluators work with both standard and dynamic context callables.""" + + def test_evaluator_with_context_object(self, callable_class): + """Test evaluator with a context object.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + + out = model_evaluation_context() + assert out == MyResult(x=2) # day 1 + offset 1 + + evaluator = Evaluator() + out2 = evaluator(model_evaluation_context) + assert out2 == out + + def test_evaluator_with_fn_specified(self, callable_class): + """Test evaluator with fn='__call__' explicitly specified.""" + m1 = callable_class(offset=1) + context = DateContext(date=date(2022, 1, 1)) + model_evaluation_context = ModelEvaluationContext(model=m1, context=context, fn="__call__") + + out = model_evaluation_context() + assert out == MyResult(x=2) + + def test_evaluator_direct_call_matches(self, callable_class): + """Test that evaluator result 
matches direct call.""" + m1 = callable_class(offset=5) + context = DateContext(date=date(2022, 1, 15)) + + # Direct call + direct_result = m1(context) + + # Via evaluator + model_evaluation_context = ModelEvaluationContext(model=m1, context=context) + evaluator_result = model_evaluation_context() + + assert direct_result == evaluator_result + assert direct_result == MyResult(x=20) # day 15 + offset 5 + + def test_evaluator_with_kwargs(self, callable_class): + """Test that evaluator works when callable is called with kwargs.""" + m1 = callable_class(offset=1) + + # Call with kwargs + result = m1(date=date(2022, 1, 10)) + assert result == MyResult(x=11) # day 10 + offset 1 From 95119e335d80a3edb4c2e255eef21adb4331be7a Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Wed, 31 Dec 2025 23:19:39 -0500 Subject: [PATCH 02/26] Remove dynamic context, add option to Flow.call Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 186 +++++++-------- ccflow/tests/test_callable.py | 318 +++++++------------------ ccflow/tests/test_evaluator.py | 12 +- ccflow/tests/test_local_persistence.py | 2 +- 4 files changed, 180 insertions(+), 338 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index 9b971c7..748759c 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -14,7 +14,7 @@ import abc import inspect import logging -from functools import lru_cache, partial, wraps +from functools import lru_cache, wraps from inspect import Signature, isclass, signature from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin @@ -46,7 +46,6 @@ "EvaluatorBase", "Evaluator", "WrapperModel", - "dynamic_context", ) log = logging.getLogger(__name__) @@ -272,22 +271,22 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = if not isinstance(model, CallableModel): raise TypeError(f"Can only decorate methods on CallableModels (not {type(model)}) with the flow decorator.") - # 
Check if this is a dynamic_context decorated method - has_dynamic_context = hasattr(fn, "__dynamic_context__") - if has_dynamic_context: - method_context_type = fn.__dynamic_context__ + # Check if this is an auto_context decorated method + has_auto_context = hasattr(fn, "__auto_context__") + if has_auto_context: + method_context_type = fn.__auto_context__ else: method_context_type = model.context_type - # Validate context type (skip for dynamic contexts which are always valid ContextBase subclasses) - if not has_dynamic_context: + # Validate context type (skip for auto contexts which are always valid ContextBase subclasses) + if not has_auto_context: if (not isclass(model.context_type) or not issubclass(model.context_type, ContextBase)) and not ( get_origin(model.context_type) is Union and type(None) in get_args(model.context_type) ): raise TypeError(f"Context type {model.context_type} must be a subclass of ContextBase") - # Validate result type - use __result_type__ for dynamic contexts if available - if has_dynamic_context and hasattr(fn, "__result_type__"): + # Validate result type - use __result_type__ for auto contexts if available + if has_auto_context and hasattr(fn, "__result_type__"): method_result_type = fn.__result_type__ else: method_result_type = model.result_type @@ -334,9 +333,9 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = wrap.get_options = self.get_options wrap.get_evaluation_context = get_evaluation_context - # Preserve dynamic context attributes for introspection - if hasattr(fn, "__dynamic_context__"): - wrap.__dynamic_context__ = fn.__dynamic_context__ + # Preserve auto context attributes for introspection + if hasattr(fn, "__auto_context__"): + wrap.__auto_context__ = fn.__auto_context__ if hasattr(fn, "__result_type__"): wrap.__result_type__ = fn.__result_type__ @@ -418,7 +417,58 @@ def __exit__(self, exc_type, exc_value, exc_tb): class Flow(PydanticBaseModel): @staticmethod def call(*args, **kwargs): - 
"""Decorator for methods on callable models""" + """Decorator for methods on callable models. + + Args: + auto_context: Controls automatic context class generation from the function + signature. Accepts three types of values: + - False (default): No auto-generation, use traditional context parameter + - True: Auto-generate context class with no parent + - ContextBase subclass: Auto-generate context class inheriting from this parent + **kwargs: Additional FlowOptions parameters (log_level, verbose, validate_result, + cacheable, evaluator, volatile). + + Basic Example: + class MyModel(CallableModel): + @Flow.call + def __call__(self, context: MyContext) -> MyResult: + return MyResult(value=context.x) + + Auto Context Example: + class MyModel(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> MyResult: + return MyResult(value=f"{x}-{y}") + + model = MyModel() + model(x=42) # Call with kwargs directly + + With Parent Context: + class MyModel(CallableModel): + @Flow.call(auto_context=DateContext) + def __call__(self, *, date: date, extra: int = 0) -> MyResult: + return MyResult(value=date.day + extra) + + # The generated context inherits from DateContext, so it's compatible + # with infrastructure expecting DateContext instances. 
+ """ + # Extract auto_context option (not part of FlowOptions) + # Can be: False, True, or a ContextBase subclass + auto_context = kwargs.pop("auto_context", False) + + # Determine if auto_context is enabled and extract parent class if provided + if auto_context is False: + auto_context_enabled = False + context_parent = None + elif auto_context is True: + auto_context_enabled = True + context_parent = None + elif isclass(auto_context) and issubclass(auto_context, ContextBase): + auto_context_enabled = True + context_parent = auto_context + else: + raise TypeError(f"auto_context must be False, True, or a ContextBase subclass, got {auto_context!r}") + if len(args) == 1 and callable(args[0]): # No arguments to decorator, this is the decorator fn = args[0] @@ -427,6 +477,14 @@ def call(*args, **kwargs): else: # Arguments to decorator, this is just returning the decorator # Note that the code below is executed only once + if auto_context_enabled: + # Return a decorator that first applies auto_context, then FlowOptions + def auto_context_decorator(fn): + wrapped = _apply_auto_context(fn, parent=context_parent) + # FlowOptions.__call__ already applies wraps, so we just return its result + return FlowOptions(**kwargs)(wrapped) + + return auto_context_decorator return FlowOptions(**kwargs) @staticmethod @@ -444,81 +502,6 @@ def deps(*args, **kwargs): # Note that the code below is executed only once return FlowOptionsDeps(**kwargs) - @staticmethod - def dynamic_call(*args, **kwargs): - """Decorator that combines @Flow.call with dynamic context creation. - - Instead of defining a separate context class, this decorator creates one - automatically from the function signature. The method can then be called - with keyword arguments directly. 
- - Basic Example: - class MyModel(CallableModel): - @Flow.dynamic_call - def __call__(self, *, date: date, region: str = "US") -> MyResult: - return MyResult(value=f"{date}-{region}") - - model = MyModel() - model(date=date.today()) # Uses default region="US" - model(date=date.today(), region="EU") # Override default - - With Parent Context: - class MyModel(CallableModel): - @Flow.dynamic_call(parent=DateContext) - def __call__(self, *, date: date, extra: int = 0) -> MyResult: - return MyResult(value=date.day + extra) - - # Parent fields (date) must be included in the function signature. - # This is useful for integrating with existing infrastructure that - # expects specific context types. - - Args: - *args: The decorated function when used without parentheses - **kwargs: Combined options for FlowOptions and dynamic_context: - - Dynamic context options: - parent: Parent context class to inherit from. All parent fields - must appear in the function signature. - - FlowOptions (passed through to @Flow.call): - log_level: Logging level for evaluation (default: DEBUG) - verbose: Use verbose logging (default: True) - validate_result: Validate return against result_type (default: True) - cacheable: Allow result caching (default: False) - evaluator: Custom evaluator instance - - Returns: - A decorated method that accepts keyword arguments matching the signature. 
- - Notes: - - All parameters (except 'self') must have type annotations - - Use keyword-only parameters (after *) for cleaner signatures - - The generated context class is accessible via method.__dynamic_context__ - - The return type is accessible via method.__result_type__ - - See Also: - dynamic_context: The underlying decorator for context creation - Flow.call: The underlying decorator for flow evaluation - """ - # Import here to avoid circular import at module level - from ccflow.callable import dynamic_context - - # Extract dynamic_context-specific options - parent = kwargs.pop("parent", None) - - if len(args) == 1 and callable(args[0]): - # No arguments to decorator (@Flow.dynamic_call) - fn = args[0] - wrapped = dynamic_context(fn, parent=parent) - return Flow.call(wrapped) - else: - # Arguments to decorator (@Flow.dynamic_call(...)) - def decorator(fn): - wrapped = dynamic_context(fn, parent=parent) - return Flow.call(**kwargs)(wrapped) - - return decorator - # ***************************************************************************** # Define "Evaluators" and associated types @@ -859,30 +842,29 @@ def _validate_callable_model_generic_type(cls, m, handler, info): # ***************************************************************************** -# Dynamic Context Decorator +# Auto Context (internal helper for Flow.call(auto_context=True)) # ***************************************************************************** -def dynamic_context(func: Callable = None, *, parent: Type[ContextBase] = None) -> Callable: - """Decorator that creates a dynamic context class from function parameters. +def _apply_auto_context(func: Callable, *, parent: Type[ContextBase] = None) -> Callable: + """Internal function that creates an auto context class from function parameters. - This decorator extracts the parameters from a function signature and creates - a dynamic ContextBase subclass whose fields correspond to those parameters. 
+ This function extracts the parameters from a function signature and creates + a ContextBase subclass whose fields correspond to those parameters. The decorated function is then wrapped to accept the context object and unpack it into keyword arguments. + Used internally by Flow.call(auto_context=...). + Example: class MyCallable(CallableModel): - @Flow.dynamic_call # or @Flow.call @dynamic_context + @Flow.call(auto_context=True) def __call__(self, *, x: int, y: str = "default") -> GenericResult: return GenericResult(value=f"{x}-{y}") model = MyCallable() model(x=42, y="hello") # Works with kwargs """ - if func is None: - return partial(dynamic_context, parent=parent) - sig = signature(func) base_class = parent or ContextBase @@ -902,8 +884,8 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult: default = ... if param.default is inspect.Parameter.empty else param.default fields[name] = (param.annotation, default) - # Create dynamic context class - dyn_context = create_ccflow_model(f"{func.__qualname__}_DynamicContext", __base__=base_class, **fields) + # Create auto context class + auto_context_class = create_ccflow_model(f"{func.__qualname__}_AutoContext", __base__=base_class, **fields) @wraps(func) def wrapper(self, context): @@ -914,10 +896,10 @@ def wrapper(self, context): wrapper.__signature__ = inspect.Signature( parameters=[ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), - inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=dyn_context), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=auto_context_class), ], return_annotation=sig.return_annotation, ) - wrapper.__dynamic_context__ = dyn_context + wrapper.__auto_context__ = auto_context_class wrapper.__result_type__ = sig.return_annotation return wrapper diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 444d496..a748765 100644 --- a/ccflow/tests/test_callable.py +++ 
b/ccflow/tests/test_callable.py @@ -20,7 +20,6 @@ ResultBase, ResultType, WrapperModel, - dynamic_context, ) from ccflow.local_persistence import LOCAL_ARTIFACTS_MODULE_NAME @@ -787,23 +786,22 @@ def foo(self, context): # ============================================================================= -# Tests for dynamic_context decorator +# Tests for Flow.call(auto_context=True) # ============================================================================= -class TestDynamicContext(TestCase): - """Tests for the @dynamic_context decorator.""" +class TestAutoContext(TestCase): + """Tests for @Flow.call(auto_context=True).""" def test_basic_usage_with_kwargs(self): - """Test basic dynamic_context usage with keyword arguments.""" + """Test basic auto_context usage with keyword arguments.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, x: int, y: str = "default") -> GenericResult: return GenericResult(value=f"{x}-{y}") - model = DynamicCallable() + model = AutoContextCallable() # Call with kwargs result = model(x=42, y="hello") @@ -813,117 +811,111 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult: result = model(x=10) self.assertEqual(result.value, "10-default") - def test_dynamic_context_attribute(self): - """Test that __dynamic_context__ attribute is set.""" + def test_auto_context_attribute(self): + """Test that __auto_context__ attribute is set.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, a: int, b: str) -> GenericResult: return GenericResult(value=f"{a}-{b}") - # The __call__ method should have __dynamic_context__ - call_method = DynamicCallable.__call__ + # The __call__ method should have __auto_context__ + call_method = AutoContextCallable.__call__ self.assertTrue(hasattr(call_method, 
"__wrapped__")) - # Access the inner function's __dynamic_context__ + # Access the inner function's __auto_context__ inner = call_method.__wrapped__ - self.assertTrue(hasattr(inner, "__dynamic_context__")) + self.assertTrue(hasattr(inner, "__auto_context__")) - dyn_ctx = inner.__dynamic_context__ - self.assertTrue(issubclass(dyn_ctx, ContextBase)) - self.assertIn("a", dyn_ctx.model_fields) - self.assertIn("b", dyn_ctx.model_fields) + auto_ctx = inner.__auto_context__ + self.assertTrue(issubclass(auto_ctx, ContextBase)) + self.assertIn("a", auto_ctx.model_fields) + self.assertIn("b", auto_ctx.model_fields) - def test_dynamic_context_is_registered(self): - """Test that the dynamic context is registered for serialization.""" + def test_auto_context_is_registered(self): + """Test that the auto context is registered for serialization.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, value: int) -> GenericResult: return GenericResult(value=value) - inner = DynamicCallable.__call__.__wrapped__ - dyn_ctx = inner.__dynamic_context__ + inner = AutoContextCallable.__call__.__wrapped__ + auto_ctx = inner.__auto_context__ # Should have __ccflow_import_path__ set - self.assertTrue(hasattr(dyn_ctx, "__ccflow_import_path__")) - self.assertTrue(dyn_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + self.assertTrue(hasattr(auto_ctx, "__ccflow_import_path__")) + self.assertTrue(auto_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) def test_call_with_context_object(self): """Test calling with a context object instead of kwargs.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, x: int, y: str = "default") -> GenericResult: return GenericResult(value=f"{x}-{y}") - model = DynamicCallable() + model = 
AutoContextCallable() - # Get the dynamic context class - dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + # Get the auto context class + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ # Create a context object - ctx = dyn_ctx(x=99, y="context") + ctx = auto_ctx(x=99, y="context") result = model(ctx) self.assertEqual(result.value, "99-context") def test_with_parent_context(self): - """Test dynamic_context with parent context class.""" + """Test auto_context with a parent context class.""" class ParentContext(ContextBase): base_value: str = "base" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context(parent=ParentContext) + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) def __call__(self, *, x: int, base_value: str) -> GenericResult: return GenericResult(value=f"{x}-{base_value}") - # Get dynamic context - dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ + # Get auto context + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ # Should inherit from ParentContext - self.assertTrue(issubclass(dyn_ctx, ParentContext)) + self.assertTrue(issubclass(auto_ctx, ParentContext)) # Should have both fields - self.assertIn("base_value", dyn_ctx.model_fields) - self.assertIn("x", dyn_ctx.model_fields) + self.assertIn("base_value", auto_ctx.model_fields) + self.assertIn("x", auto_ctx.model_fields) # Create context with parent field - ctx = dyn_ctx(x=42, base_value="custom") + ctx = auto_ctx(x=42, base_value="custom") self.assertEqual(ctx.base_value, "custom") self.assertEqual(ctx.x, 42) def test_parent_fields_must_be_in_signature(self): - """Test that parent fields must be included in function signature.""" + """Test that parent context fields must be included in function signature.""" class ParentContext(ContextBase): required_field: str with self.assertRaises(TypeError) as cm: - class DynamicCallable(CallableModel): - @Flow.call - 
@dynamic_context(parent=ParentContext) + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=ParentContext) def __call__(self, *, x: int) -> GenericResult: return GenericResult(value=x) self.assertIn("required_field", str(cm.exception)) def test_cloudpickle_roundtrip(self): - """Test cloudpickle roundtrip for dynamic context callable.""" + """Test cloudpickle roundtrip for auto_context callable.""" - class DynamicCallable(CallableModel): + class AutoContextCallable(CallableModel): multiplier: int = 2 - @Flow.call - @dynamic_context + @Flow.call(auto_context=True) def __call__(self, *, x: int) -> GenericResult: return GenericResult(value=x * self.multiplier) - model = DynamicCallable(multiplier=3) + model = AutoContextCallable(multiplier=3) # Test roundtrip restored = rcploads(rcpdumps(model)) @@ -932,13 +924,12 @@ def __call__(self, *, x: int) -> GenericResult: self.assertEqual(result.value, 30) def test_ray_task_execution(self): - """Test dynamic context callable in Ray task.""" + """Test auto_context callable in Ray task.""" - class DynamicCallable(CallableModel): + class AutoContextCallable(CallableModel): factor: int = 2 - @Flow.call - @dynamic_context + @Flow.call(auto_context=True) def __call__(self, *, x: int, y: int = 1) -> GenericResult: return GenericResult(value=(x + y) * self.factor) @@ -946,63 +937,35 @@ def __call__(self, *, x: int, y: int = 1) -> GenericResult: def run_callable(model, **kwargs): return model(**kwargs).value - model = DynamicCallable(factor=5) + model = AutoContextCallable(factor=5) with ray.init(num_cpus=1): result = ray.get(run_callable.remote(model, x=10, y=2)) self.assertEqual(result, 60) # (10 + 2) * 5 - def test_multiple_dynamic_context_methods(self): - """Test callable with multiple dynamic_context decorated methods.""" - - class MultiMethodCallable(CallableModel): - @Flow.call - @dynamic_context - def __call__(self, *, a: int) -> GenericResult: - return GenericResult(value=a) - - @dynamic_context - def 
other_method(self, *, b: str, c: float = 1.0) -> GenericResult: - return GenericResult(value=f"{b}-{c}") - - model = MultiMethodCallable() - - # Test __call__ - result1 = model(a=42) - self.assertEqual(result1.value, 42) - - # Test other_method (without Flow.call, just the dynamic_context wrapper) - # Need to create the context manually - other_ctx = model.other_method.__dynamic_context__ - ctx = other_ctx(b="hello", c=2.5) - result2 = model.other_method(ctx) - self.assertEqual(result2.value, "hello-2.5") - def test_context_type_property_works(self): - """Test that type_ property works on the dynamic context.""" + """Test that type_ property works on the auto context.""" - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__(self, *, x: int) -> GenericResult: return GenericResult(value=x) - dyn_ctx = DynamicCallable.__call__.__wrapped__.__dynamic_context__ - ctx = dyn_ctx(x=42) + auto_ctx = AutoContextCallable.__call__.__wrapped__.__auto_context__ + ctx = auto_ctx(x=42) # type_ should work and be importable type_path = str(ctx.type_) self.assertIn("_Local_", type_path) - self.assertEqual(ctx.type_.object, dyn_ctx) + self.assertEqual(ctx.type_.object, auto_ctx) def test_complex_field_types(self): - """Test dynamic_context with complex field types.""" + """Test auto_context with complex field types.""" from typing import List, Optional - class DynamicCallable(CallableModel): - @Flow.call - @dynamic_context + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) def __call__( self, *, @@ -1013,7 +976,7 @@ def __call__( total = sum(items) + count return GenericResult(value=f"{name}:{total}" if name else str(total)) - model = DynamicCallable() + model = AutoContextCallable() result = model(items=[1, 2, 3], name="test", count=10) self.assertEqual(result.value, "test:16") @@ -1021,146 +984,43 @@ def __call__( result = model(items=[5, 5]) 
self.assertEqual(result.value, "10") - -class TestFlowDynamicCall(TestCase): - """Tests for @Flow.dynamic_call decorator.""" - - def test_basic_usage(self): - """Test basic @Flow.dynamic_call usage.""" - - class DynamicCallable(CallableModel): - @Flow.dynamic_call - def __call__(self, *, x: int, y: str = "default") -> GenericResult: - return GenericResult(value=f"{x}-{y}") - - model = DynamicCallable() - - result = model(x=42, y="hello") - self.assertEqual(result.value, "42-hello") - - result = model(x=10) - self.assertEqual(result.value, "10-default") - - def test_dynamic_context_attributes_preserved(self): - """Test that __dynamic_context__ and __result_type__ are directly accessible.""" - - class DynamicCallable(CallableModel): - @Flow.dynamic_call - def __call__(self, *, x: int) -> GenericResult: - return GenericResult(value=x) - - # Should be directly accessible without traversing __wrapped__ chain - method = DynamicCallable.__call__ - self.assertTrue(hasattr(method, "__dynamic_context__")) - self.assertTrue(hasattr(method, "__result_type__")) - self.assertTrue(issubclass(method.__dynamic_context__, ContextBase)) - self.assertEqual(method.__result_type__, GenericResult) - - def test_model_result_type_property(self): - """Test that model.result_type returns correct type for dynamic contexts.""" - - class DynamicCallable(CallableModel): - @Flow.dynamic_call - def __call__(self, *, x: int) -> GenericResult: - return GenericResult(value=x) - - model = DynamicCallable() - self.assertEqual(model.result_type, GenericResult) - - def test_with_parent_context(self): - """Test @Flow.dynamic_call with parent context.""" - - class ParentContext(ContextBase): - base_value: str = "base" - - class DynamicCallable(CallableModel): - @Flow.dynamic_call(parent=ParentContext) - def __call__(self, *, x: int, base_value: str) -> GenericResult: - return GenericResult(value=f"{x}-{base_value}") - - model = DynamicCallable() - - # Get dynamic context by traversing __wrapped__ chain - 
dyn_ctx = _find_dynamic_context(DynamicCallable.__call__) - - # Should inherit from ParentContext - self.assertTrue(issubclass(dyn_ctx, ParentContext)) - - # Call should work, uses parent default - result = model(x=42, base_value="custom") - self.assertEqual(result.value, "42-custom") - def test_with_flow_options(self): - """Test @Flow.dynamic_call with FlowOptions parameters.""" + """Test auto_context with FlowOptions parameters.""" - class DynamicCallable(CallableModel): - @Flow.dynamic_call(validate_result=False) + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True, validate_result=False) def __call__(self, *, x: int) -> GenericResult: return GenericResult(value=x) - model = DynamicCallable() + model = AutoContextCallable() result = model(x=42) self.assertEqual(result.value, 42) - def test_cloudpickle_roundtrip(self): - """Test cloudpickle roundtrip with @Flow.dynamic_call.""" - - class DynamicCallable(CallableModel): - multiplier: int = 2 - - @Flow.dynamic_call - def __call__(self, *, x: int) -> GenericResult: - return GenericResult(value=x * self.multiplier) - - model = DynamicCallable(multiplier=3) - restored = rcploads(rcpdumps(model)) - - result = restored(x=10) - self.assertEqual(result.value, 30) - - def test_ray_task(self): - """Test @Flow.dynamic_call in Ray task.""" - - class DynamicCallable(CallableModel): - factor: int = 2 - - @Flow.dynamic_call - def __call__(self, *, x: int, y: int = 1) -> GenericResult: - return GenericResult(value=(x + y) * self.factor) - - @ray.remote - def run_callable(model, **kwargs): - return model(**kwargs).value - - model = DynamicCallable(factor=5) + def test_error_without_auto_context(self): + """Test that using kwargs signature without auto_context raises an error.""" - with ray.init(num_cpus=1): - result = ray.get(run_callable.remote(model, x=10, y=2)) - - self.assertEqual(result, 60) - - def test_dynamic_context_is_registered(self): - """Test that the dynamic context from @Flow.dynamic_call is 
registered.""" + class BadCallable(CallableModel): + @Flow.call # Missing auto_context=True! + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") - class DynamicCallable(CallableModel): - @Flow.dynamic_call - def __call__(self, *, value: int) -> GenericResult: - return GenericResult(value=value) + # Error happens at instantiation time when _check_signature validates + with self.assertRaises(ValueError) as cm: + BadCallable() - # Find dynamic context by traversing __wrapped__ chain - dyn_ctx = _find_dynamic_context(DynamicCallable.__call__) + # Should fail because __call__ must take a single argument named 'context' + error_msg = str(cm.exception) + self.assertIn("__call__", error_msg) + self.assertIn("context", error_msg) - self.assertTrue(hasattr(dyn_ctx, "__ccflow_import_path__")) - self.assertTrue(dyn_ctx.__ccflow_import_path__.startswith(LOCAL_ARTIFACTS_MODULE_NAME)) + def test_invalid_auto_context_value(self): + """Test that invalid auto_context values raise TypeError with helpful message.""" + with self.assertRaises(TypeError) as cm: + @Flow.call(auto_context="invalid") + def bad_func(self, *, x: int) -> GenericResult: + return GenericResult(value=x) -def _find_dynamic_context(func): - """Helper to find __dynamic_context__ by traversing the __wrapped__ chain.""" - visited = set() - current = func - while current is not None and id(current) not in visited: - visited.add(id(current)) - if hasattr(current, "__dynamic_context__"): - return current.__dynamic_context__ - current = getattr(current, "__wrapped__", None) - return None + error_msg = str(cm.exception) + self.assertIn("auto_context must be False, True, or a ContextBase subclass", error_msg) + self.assertIn("invalid", error_msg) diff --git a/ccflow/tests/test_evaluator.py b/ccflow/tests/test_evaluator.py index 34f3f7e..dabf815 100644 --- a/ccflow/tests/test_evaluator.py +++ b/ccflow/tests/test_evaluator.py @@ -8,12 +8,12 @@ from .evaluators.util 
import MyDateCallable, MyResult -class MyDynamicDateCallable(CallableModel): - """Dynamic context version of MyDateCallable for testing evaluators.""" +class MyAutoContextDateCallable(CallableModel): + """Auto context version of MyDateCallable for testing evaluators.""" offset: int - @Flow.dynamic_call(parent=DateContext) + @Flow.call(auto_context=DateContext) def __call__(self, *, date: date) -> MyResult: return MyResult(x=date.day + self.offset) @@ -48,11 +48,11 @@ def test_evaluator_deps(self): @pytest.mark.parametrize( "callable_class", - [MyDateCallable, MyDynamicDateCallable], - ids=["standard", "dynamic"], + [MyDateCallable, MyAutoContextDateCallable], + ids=["standard", "auto_context"], ) class TestEvaluatorParametrized: - """Test evaluators work with both standard and dynamic context callables.""" + """Test evaluators work with both standard and auto_context callables.""" def test_evaluator_with_context_object(self, callable_class): """Test evaluator with a context object.""" diff --git a/ccflow/tests/test_local_persistence.py b/ccflow/tests/test_local_persistence.py index 586b03f..dc2db55 100644 --- a/ccflow/tests/test_local_persistence.py +++ b/ccflow/tests/test_local_persistence.py @@ -1235,7 +1235,7 @@ class TestCreateCcflowModelCloudpickleCrossProcess: id="context_only", ), pytest.param( - # Dynamic context with CallableModel + # Runtime-created context with CallableModel """ from ray.cloudpickle import dump from ccflow import CallableModel, ContextBase, GenericResult, Flow From 989e278eccc1a1dec446eb26fb53febad1d3449e Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Sun, 4 Jan 2026 19:22:25 -0500 Subject: [PATCH 03/26] Add @Flow.model decorator, new annotation that pulls from deps Signed-off-by: Nijat Khanbabayev --- ccflow/__init__.py | 1 + ccflow/callable.py | 216 +++- ccflow/dep.py | 278 +++++ ccflow/flow_model.py | 341 ++++++ ccflow/tests/config/conf_flow.yaml | 80 ++ ccflow/tests/test_callable.py | 1 + ccflow/tests/test_flow_model.py | 
1477 +++++++++++++++++++++++++ ccflow/tests/test_flow_model_hydra.py | 437 ++++++++ docs/wiki/Key-Features.md | 115 ++ examples/flow_model_example.py | 219 ++++ 10 files changed, 3163 insertions(+), 2 deletions(-) create mode 100644 ccflow/dep.py create mode 100644 ccflow/flow_model.py create mode 100644 ccflow/tests/config/conf_flow.yaml create mode 100644 ccflow/tests/test_flow_model.py create mode 100644 ccflow/tests/test_flow_model_hydra.py create mode 100644 examples/flow_model_example.py diff --git a/ccflow/__init__.py b/ccflow/__init__.py index 163f275..9916168 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -10,6 +10,7 @@ from .compose import * from .callable import * from .context import * +from .dep import * from .enums import Enum from .global_state import * from .local_persistence import * diff --git a/ccflow/callable.py b/ccflow/callable.py index 748759c..5296bfe 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -28,6 +28,7 @@ ResultBase, ResultType, ) +from .dep import Dep, extract_dep from .local_persistence import create_ccflow_model from .validators import str_to_log_level @@ -128,7 +129,7 @@ def _check_result_type(cls, result_type): @model_validator(mode="after") def _check_signature(self): sig_call = _cached_signature(self.__class__.__call__) - if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: # ("self", "context") + if len(sig_call.parameters) != 2 or "context" not in sig_call.parameters: raise ValueError("__call__ method must take a single argument, named 'context'") sig_deps = _cached_signature(self.__class__.__deps__) @@ -195,6 +196,114 @@ def _get_logging_evaluator(log_level): return LoggingEvaluator(log_level=log_level) +def _get_dep_fields(model_class) -> Dict[str, Dep]: + """Analyze class fields to find Dep-annotated fields. + + Returns a dict mapping field name to Dep instance for fields that need resolution. 
+ """ + dep_fields = {} + + # Get type hints from the class + hints = {} + for cls in model_class.__mro__: + if hasattr(cls, "__annotations__"): + for name, annotation in cls.__annotations__.items(): + if name not in hints: # Don't override child class annotations + hints[name] = annotation + + for name, annotation in hints.items(): + base_type, dep = extract_dep(annotation) + if dep is not None: + dep_fields[name] = dep + + return dep_fields + + +def _wrap_with_dep_resolution(fn): + """Wrap a function to auto-resolve DepOf fields before calling. + + For each Dep-annotated field on the model that contains a CallableModel, + resolves it using __deps__ and temporarily sets the resolved value on self. + + Note: This wrapper is only applied at runtime when the function is called, + not during decoration. This avoids issues with functools.wraps flattening + the __wrapped__ chain. + + Args: + fn: The original function + + Returns: + The original function unchanged - dep resolution happens at the call site + """ + # Don't modify the function - dep resolution is handled in ModelEvaluationContext + return fn + + +def _resolve_deps_and_call(model, context, fn): + """Resolve DepOf fields and call the function. + + This is called from ModelEvaluationContext.__call__ to handle dep resolution. 
+ + Args: + model: The CallableModel instance + context: The context to pass to the function + fn: The function to call + + Returns: + The result of calling fn(model, context) + """ + # Don't resolve deps for __deps__ method + if fn.__name__ == "__deps__": + return fn(model, context) + + # Get Dep-annotated fields for this model class + dep_fields = _get_dep_fields(model.__class__) + + if not dep_fields: + return fn(model, context) + + # Get dependencies from __deps__ + deps_result = model.__deps__(context) + # Build a map from model instance id to (model, contexts) for lookup + dep_map = {} + for dep_model, contexts in deps_result: + dep_map[id(dep_model)] = (dep_model, contexts) + + # Store original values and resolve + originals = {} + for field_name, dep in dep_fields.items(): + field_value = getattr(model, field_name, None) + if field_value is None: + continue + + # Check if field is a CallableModel that needs resolution + if not isinstance(field_value, _CallableModel): + continue # Already a resolved value, skip + + originals[field_name] = field_value + + # Check if this field is in __deps__ (for custom transforms) + if id(field_value) in dep_map: + dep_model, contexts = dep_map[id(field_value)] + # Call dependency with the (transformed) context + resolved = dep_model(contexts[0]) if contexts else dep_model(context) + else: + # Not in __deps__, use Dep annotation transform directly + transformed_ctx = dep.apply(context) + resolved = field_value(transformed_ctx) + + # Temporarily set resolved value on model + object.__setattr__(model, field_name, resolved) + + try: + # Call original function + return fn(model, context) + finally: + # Restore original CallableModel values + for field_name, original_value in originals.items(): + object.__setattr__(model, field_name, original_value) + + class FlowOptions(BaseModel): """Options for Flow evaluation. 
@@ -246,6 +355,9 @@ def get_evaluator(self, model: CallableModelType) -> "EvaluatorBase": return self._get_evaluator_from_options(options) def __call__(self, fn): + # Wrap function with dependency resolution for DepOf fields + fn = _wrap_with_dep_resolution(fn) + # Used for building a graph of model evaluation contexts without evaluating def get_evaluation_context(model: CallableModelType, context: ContextType, as_dict: bool = False, *, _options: Optional[FlowOptions] = None): # Create the evaluation context. @@ -451,6 +563,33 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult: # The generated context inherits from DateContext, so it's compatible # with infrastructure expecting DateContext instances. + + Auto-Resolve Dependencies Example: + When __call__ has parameters beyond 'self' and 'context' that match field + names annotated with DepOf/Dep, those dependencies are automatically resolved + using __deps__ (if defined) or auto-generated from Dep annotations. + + class MyModel(CallableModel): + data: Annotated[GenericResult[dict], Dep(transform=my_transform)] + + @Flow.call + def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]: + # data is automatically resolved - no manual calling needed + return GenericResult(value=process(data.value)) + + For transforms that need access to instance fields, define __deps__ manually: + + class MyModel(CallableModel): + data: DepOf[..., GenericResult[dict]] + window: int = 7 + + def __deps__(self, context): + # Can access self.window here + return [(self.data, [context.with_lookback(self.window)])] + + @Flow.call + def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]: + return GenericResult(value=process(data.value)) """ # Extract auto_context option (not part of FlowOptions) # Can be: False, True, or a ContextBase subclass @@ -502,6 +641,78 @@ def deps(*args, **kwargs): # Note that the code below is executed only once return FlowOptionsDeps(**kwargs) + @staticmethod 
+ def model(*args, **kwargs): + """Decorator that generates a CallableModel class from a plain Python function. + + This is syntactic sugar over CallableModel. The decorator generates a real + CallableModel class with proper __call__ and __deps__ methods, so all existing + features (caching, evaluation, registry, serialization) work unchanged. + + Args: + context_args: List of parameter names that come from context (for unpacked mode) + cacheable: Enable caching of results (default: False) + volatile: Mark as volatile (default: False) + log_level: Logging verbosity (default: logging.DEBUG) + validate_result: Validate return type (default: True) + verbose: Verbose logging output (default: True) + evaluator: Custom evaluator (default: None) + + Two Context Modes: + + Mode 1 - Explicit context parameter: + Function has a 'context' parameter annotated with a ContextBase subclass. + + @Flow.model + def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]: + return GenericResult(value=query_db(source, context.start_date, context.end_date)) + + Mode 2 - Unpacked context_args: + Context fields are unpacked into function parameters. 
+ + @Flow.model(context_args=["start_date", "end_date"]) + def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: + return GenericResult(value=query_db(source, start_date, end_date)) + + Dependencies: + Use Dep() or DepOf to mark parameters that can accept CallableModel dependencies: + + from ccflow import Dep, DepOf + from typing import Annotated + + @Flow.model + def compute_returns( + context: DateRangeContext, + prices: Annotated[GenericResult[pl.DataFrame], Dep( + transform=lambda ctx: ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) + )] + ) -> GenericResult[pl.DataFrame]: + return GenericResult(value=prices.value.pct_change()) + + # Or use DepOf shorthand for no transform: + @Flow.model + def compute_stats( + context: DateRangeContext, + data: DepOf[..., GenericResult[pl.DataFrame]] + ) -> GenericResult[pl.DataFrame]: + return GenericResult(value=data.value.describe()) + + Usage: + # Create model instances + loader = load_prices(source="prod_db") + returns = compute_returns(prices=loader) + + # Execute + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = returns(ctx) + + Returns: + A factory function that creates CallableModel instances + """ + from .flow_model import flow_model + + return flow_model(*args, **kwargs) + # ***************************************************************************** # Define "Evaluators" and associated types @@ -555,7 +766,8 @@ def _context_validator(cls, values, handler, info): def __call__(self) -> ResultType: fn = getattr(self.model, self.fn) if hasattr(fn, "__wrapped__"): - result = fn.__wrapped__(self.model, self.context) + # Call through _resolve_deps_and_call to handle DepOf field resolution + result = _resolve_deps_and_call(self.model, self.context, fn.__wrapped__) # If it's a callable model, then we can validate the result if self.options.get("validate_result", True): if fn.__name__ == "__deps__": diff --git 
a/ccflow/dep.py b/ccflow/dep.py new file mode 100644 index 0000000..a7e0121 --- /dev/null +++ b/ccflow/dep.py @@ -0,0 +1,278 @@ +"""Dependency annotation markers for Flow.model. + +This module provides: +- Dep: Annotation marker for dependency parameters that can accept CallableModel +- DepOf: Shorthand for Annotated[Union[T, CallableModel], Dep()] +""" + +from typing import TYPE_CHECKING, Annotated, Callable, Optional, Type, TypeVar, Union, get_args, get_origin + +from .base import ContextBase + +if TYPE_CHECKING: + from .callable import CallableModel + +__all__ = ("Dep", "DepOf") + +T = TypeVar("T") + +# Lazy reference to CallableModel to avoid circular import +_CallableModel = None + + +def _get_callable_model(): + """Lazily import CallableModel to avoid circular imports.""" + global _CallableModel + if _CallableModel is None: + from .callable import CallableModel + + _CallableModel = CallableModel + return _CallableModel + + +class _DepOfMeta(type): + """Metaclass that makes DepOf[ContextType, ResultType] work.""" + + def __getitem__(cls, item): + if not isinstance(item, tuple) or len(item) != 2: + raise TypeError( + "DepOf requires 2 type arguments: DepOf[ContextType, ResultType]. " + "Use ... for ContextType to inherit from parent: DepOf[..., ResultType]" + ) + context_type, result_type = item + CallableModel = _get_callable_model() + + if context_type is ...: + # DepOf[..., ResultType] - inherit context from parent + return Annotated[Union[result_type, CallableModel], Dep()] + else: + # DepOf[ContextType, ResultType] - explicit context type + return Annotated[Union[result_type, CallableModel], Dep(context_type=context_type)] + + +class DepOf(metaclass=_DepOfMeta): + """ + Shorthand for Annotated[Union[ResultType, CallableModel], Dep(context_type=...)]. 
class DepOf(metaclass=_DepOfMeta):
    """Shorthand annotation: ``DepOf[ContextType, ResultType]``.

    Expands to ``Annotated[Union[ResultType, CallableModel], Dep(context_type=...)]``,
    following the ``Callable`` convention of input type first, output type second.

    A class field declared this way accepts either:

    - a pre-computed value of the result type, or
    - a CallableModel producing that result type (resolved at call time).

    Usage::

        # Inherit context type from parent model (most common)
        data: DepOf[..., GenericResult[dict]]

        # Explicit context type validation
        data: DepOf[DateRangeContext, GenericResult[dict]]

    At call time, a field holding a CallableModel is resolved automatically
    via ``__deps__`` and the resolved value is accessible as ``self.field_name``.

    Dependencies needing a context transform should declare it in ``__deps__``::

        def __deps__(self, context):
            transformed_ctx = context.model_copy(update={...})
            return [(self.data, [transformed_ctx])]
    """

    pass
+ + Args: + actual: The actual type to check + expected: The expected type to match against + + Returns: + True if actual is compatible with expected + """ + # Handle None/empty types + if actual is None or expected is None: + return actual is expected + + # Get origins for generic types + actual_origin = get_origin(actual) or actual + expected_origin = get_origin(expected) or expected + + # Check if origins are compatible + try: + if not (isinstance(actual_origin, type) and isinstance(expected_origin, type)): + return False + if not issubclass(actual_origin, expected_origin): + return False + except TypeError: + # issubclass can fail for certain types + return False + + # Check generic args if present + actual_args = get_args(actual) + expected_args = get_args(expected) + + if expected_args and actual_args: + if len(actual_args) != len(expected_args): + return False + return all(_is_compatible_type(a, e) for a, e in zip(actual_args, expected_args)) + + return True + + +class Dep: + """ + Annotation marker for dependency parameters. + + Marks a parameter as accepting either the declared type or a CallableModel + that produces that type. Supports optional context transform and + construction-time type validation. 
+ + Usage: + # No transform, no explicit validation (uses parent's context_type) + prices: Annotated[GenericResult[pl.DataFrame], Dep()] + + # With transform + prices: Annotated[GenericResult[pl.DataFrame], Dep( + transform=lambda ctx: ctx.model_copy(update={"start": ctx.start - timedelta(days=1)}) + )] + + # With explicit context_type validation + prices: Annotated[GenericResult[pl.DataFrame], Dep( + context_type=DateRangeContext, + transform=lambda ctx: ctx.model_copy(update={"start": ctx.start - timedelta(days=1)}) + )] + + # Cross-context dependency (transform changes context type) + sim_data: Annotated[GenericResult[pl.DataFrame], Dep( + context_type=SimulationContext, + transform=date_to_simulation_context + )] + """ + + def __init__( + self, + transform: Optional[Callable[[ContextBase], ContextBase]] = None, + context_type: Optional[Type[ContextBase]] = None, + ): + """ + Args: + transform: Optional function to transform context before calling dependency. + Signature: (context) -> transformed_context + context_type: Expected context_type of the dependency CallableModel. + If None, defaults to the parent model's context_type. + Validated at construction time when a CallableModel is passed. + """ + self.transform = transform + self.context_type = context_type + + def apply(self, context: ContextBase) -> ContextBase: + """Apply the transform to a context, or return unchanged if no transform.""" + if self.transform is not None: + return self.transform(context) + return context + + def validate_dependency( + self, + value: "CallableModel", # noqa: F821 + expected_result_type: Type, + parent_context_type: Type[ContextBase], + param_name: str, + ) -> None: + """ + Validate a CallableModel dependency at construction time. 
+ + Args: + value: The CallableModel being passed as a dependency + expected_result_type: The result type from the Annotated type hint + parent_context_type: The context_type of the parent model + param_name: Name of the parameter (for error messages) + + Raises: + TypeError: If context_type or result_type don't match + """ + # Import here to avoid circular import + from .callable import CallableModel + + if not isinstance(value, CallableModel): + return # Not a CallableModel, skip validation + + # Determine expected context type + expected_ctx = self.context_type if self.context_type is not None else parent_context_type + + # Validate context_type - the dependency's context_type should be compatible + # with what we'll pass to it (expected_ctx) + dep_context_type = value.context_type + try: + if not issubclass(expected_ctx, dep_context_type): + raise TypeError( + f"Dependency '{param_name}': expected context_type compatible with " + f"{dep_context_type.__name__}, but will pass {expected_ctx.__name__}" + ) + except TypeError: + # issubclass can fail for certain types, try alternate check + if expected_ctx != dep_context_type: + raise TypeError(f"Dependency '{param_name}': context_type mismatch - expected {dep_context_type}, got {expected_ctx}") + + # Validate result_type using the generic-safe comparison + # If expected_result_type is Union[T, CallableModel], extract T for validation + dep_result_type = value.result_type + actual_expected_type = expected_result_type + + # Handle Union[T, CallableModel] from DepOf expansion + if get_origin(expected_result_type) is Union: + union_args = get_args(expected_result_type) + # Filter out CallableModel from the union + non_callable_types = [t for t in union_args if t is not CallableModel] + if non_callable_types: + actual_expected_type = non_callable_types[0] + + if not _is_compatible_type(dep_result_type, actual_expected_type): + raise TypeError( + f"Dependency '{param_name}': expected result_type compatible with " + 
def extract_dep(annotation) -> tuple:
    """Pull the Dep marker out of ``Annotated[T, Dep(...)]`` / ``DepOf[ContextType, T]``.

    Nested ``Annotated`` layers are flattened by typing, so when several Dep
    markers are present the LAST one — the outermost user annotation — wins.

    Args:
        annotation: A type annotation, possibly ``Annotated`` with Dep metadata

    Returns:
        Tuple of (base_type, Dep instance or None)
    """
    if get_origin(annotation) is not Annotated:
        return annotation, None
    inner, *metadata = get_args(annotation)
    # Keep only Dep markers; the last one is the outermost user annotation.
    dep_markers = [m for m in metadata if isinstance(m, Dep)]
    if dep_markers:
        return inner, dep_markers[-1]
    return annotation, None
def _infer_context_type_from_args(context_args: List[str], func: Callable, sig: inspect.Signature) -> Type[ContextBase]:
    """Derive a context type for the unpacked-context mode.

    Reuses a well-known context type when the argument names match one
    (currently only DateRangeContext); otherwise a fresh ContextBase subclass
    is created whose fields mirror the named parameters.

    Args:
        context_args: List of parameter names that come from context
        func: The decorated function (used to name any generated type)
        sig: The function signature

    Returns:
        A ContextBase subclass

    Raises:
        ValueError: If a context_arg is missing from the signature or lacks
            a type annotation.
    """
    from .local_persistence import create_ccflow_model

    # Collect (annotation, default) pairs for each context-provided parameter.
    field_defs = {}
    for arg_name in context_args:
        parameter = sig.parameters.get(arg_name)
        if parameter is None:
            raise ValueError(f"context_arg '{arg_name}' not found in function parameters")
        if parameter.annotation is inspect.Parameter.empty:
            raise ValueError(f"context_arg '{arg_name}' must have a type annotation")
        has_default = parameter.default is not inspect.Parameter.empty
        field_defs[arg_name] = (parameter.annotation, parameter.default if has_default else ...)

    from .context import DateRangeContext

    # Recognize the DateRangeContext pattern: exactly start_date/end_date, both dates.
    if set(context_args) == {"start_date", "end_date"}:
        from datetime import date

        if all(sig.parameters[arg_name].annotation in (date, "date") for arg_name in context_args):
            return DateRangeContext

    # No known match: synthesize a dedicated context type for this function.
    return create_ccflow_model(
        f"_{func.__name__}_Context",
        __base__=ContextBase,
        **field_defs,
    )


def _get_dep_info(annotation) -> Tuple[Type, Optional[Dep]]:
    """Thin alias for :func:`extract_dep`.

    Returns:
        Tuple of (base_type, Dep instance or None)
    """
    return extract_dep(annotation)
def flow_model(
    func: Callable = None,
    *,
    # Context handling
    context_args: Optional[List[str]] = None,
    # Flow.call options (passed to generated __call__)
    cacheable: bool = False,
    volatile: bool = False,
    log_level: int = logging.DEBUG,
    validate_result: bool = True,
    verbose: bool = True,
    evaluator: Optional[Any] = None,
) -> Callable:
    """Decorator that generates a CallableModel class from a plain Python function.

    This is syntactic sugar over CallableModel. The decorator generates a real
    CallableModel class with proper __call__ and __deps__ methods, so all existing
    features (caching, evaluation, registry, serialization) work unchanged.

    Args:
        func: The function to decorate
        context_args: List of parameter names that come from context (for unpacked mode)
        cacheable: Enable caching of results
        volatile: Mark as volatile (always re-execute)
        log_level: Logging verbosity
        validate_result: Validate return type
        verbose: Verbose logging output
        evaluator: Custom evaluator

    Two Context Modes:
        1. Explicit context parameter: Function has a 'context' parameter annotated
           with a ContextBase subclass.

           @Flow.model
           def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]:
               ...

        2. Unpacked context_args: Context fields are unpacked into function parameters.

           @Flow.model(context_args=["start_date", "end_date"])
           def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]:
               ...

    Returns:
        A factory function that creates CallableModel instances
    """

    def decorator(fn: Callable) -> Callable:
        # Import here to avoid circular imports
        from .callable import CallableModel, Flow, GraphDepList

        sig = inspect.signature(fn)
        params = sig.parameters

        # Validate return type
        return_type = sig.return_annotation
        if return_type is inspect.Signature.empty:
            raise TypeError(f"Function {fn.__name__} must have a return type annotation")
        # Check that return type is a ResultBase subclass (compare the
        # unparameterized origin so GenericResult[T] is accepted)
        return_origin = get_origin(return_type) or return_type
        if not (isinstance(return_origin, type) and issubclass(return_origin, ResultBase)):
            raise TypeError(f"Function {fn.__name__} must return a ResultBase subclass, got {return_type}")

        # Determine context mode and extract info
        if context_args is not None:
            # Mode 2: Unpacked context args
            context_type = _infer_context_type_from_args(context_args, fn, sig)
            model_field_params = {name: param for name, param in params.items() if name not in context_args and name != "self"}
            use_context_args = True
        elif "context" in params or "_" in params:
            # Mode 1: Explicit context parameter (named 'context' or '_' for unused)
            context_param_name = "context" if "context" in params else "_"
            context_param = params[context_param_name]
            if context_param.annotation is inspect.Parameter.empty:
                raise TypeError(f"Function {fn.__name__}: '{context_param_name}' parameter must have a type annotation")
            context_type = context_param.annotation
            if not (isinstance(context_type, type) and issubclass(context_type, ContextBase)):
                raise TypeError(f"Function {fn.__name__}: '{context_param_name}' must be annotated with a ContextBase subclass")
            model_field_params = {name: param for name, param in params.items() if name not in (context_param_name, "self")}
            use_context_args = False
        else:
            raise TypeError(f"Function {fn.__name__} must either have a 'context' (or '_') parameter or specify context_args in the decorator")

        # Analyze parameters to find dependencies and regular fields
        dep_fields: Dict[str, Tuple[Type, Dep]] = {}  # name -> (base_type, Dep)
        model_fields: Dict[str, Tuple[Type, Any]] = {}  # name -> (type, default)

        for name, param in model_field_params.items():
            if param.annotation is inspect.Parameter.empty:
                raise TypeError(f"Parameter '{name}' must have a type annotation")

            base_type, dep = _get_dep_info(param.annotation)
            default = ... if param.default is inspect.Parameter.empty else param.default

            if dep is not None:
                # This is a dependency parameter
                dep_fields[name] = (base_type, dep)
                # Use Annotated so _resolve_deps_and_call in callable.py can find the Dep
                # This consolidates resolution logic into one place
                model_fields[name] = (Annotated[Union[base_type, CallableModel], dep], default)
            else:
                # Regular model field
                model_fields[name] = (param.annotation, default)

        # Capture context_args in local variable for closures
        ctx_args_list = context_args or []
        # Capture context parameter name for closures (only used in mode 1;
        # context_param_name is not bound in mode 2, but the conditional
        # short-circuits to "context" before evaluating it there)
        ctx_param_name = context_param_name if not use_context_args else "context"

        # Create the __call__ method
        def make_call_impl():
            def __call__(self, context):
                # Build kwargs for the original function
                if use_context_args:
                    # Unpack context into args
                    fn_kwargs = {name: getattr(context, name) for name in ctx_args_list}
                else:
                    # Pass context directly (using actual parameter name: 'context' or '_')
                    fn_kwargs = {ctx_param_name: context}

                # Add model fields (deps are resolved by _resolve_deps_and_call in callable.py)
                for name in model_fields:
                    fn_kwargs[name] = getattr(self, name)

                return fn(**fn_kwargs)

            # Set proper signature for CallableModel validation
            __call__.__signature__ = inspect.Signature(
                parameters=[
                    inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD),
                    inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type),
                ],
                return_annotation=return_type,
            )
            return __call__

        call_impl = make_call_impl()

        # Apply Flow.call decorator
        flow_options = {
            "cacheable": cacheable,
            "volatile": volatile,
            "log_level": log_level,
            "validate_result": validate_result,
            "verbose": verbose,
        }
        if evaluator is not None:
            flow_options["evaluator"] = evaluator

        decorated_call = Flow.call(**flow_options)(call_impl)

        # Create the __deps__ method
        def make_deps_impl():
            def __deps__(self, context) -> GraphDepList:
                deps = []
                for dep_name, (base_type, dep_obj) in dep_fields.items():
                    value = getattr(self, dep_name)
                    # Only fields still holding a CallableModel are dependencies;
                    # pre-computed values need no resolution
                    if isinstance(value, CallableModel):
                        transformed_ctx = dep_obj.apply(context)
                        deps.append((value, [transformed_ctx]))
                return deps

            # Set proper signature
            __deps__.__signature__ = inspect.Signature(
                parameters=[
                    inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD),
                    inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type),
                ],
                return_annotation=GraphDepList,
            )
            return __deps__

        deps_impl = make_deps_impl()
        decorated_deps = Flow.deps(deps_impl)

        # Build pydantic field annotations for the class
        annotations = {}

        namespace = {
            "__module__": fn.__module__,
            "__qualname__": f"_{fn.__name__}_Model",
            "__call__": decorated_call,
            "__deps__": decorated_deps,
        }

        for name, (typ, default) in model_fields.items():
            annotations[name] = typ
            if default is not ...:
                namespace[name] = default
            else:
                # For required fields, use Field(...)
                namespace[name] = Field(...)

        namespace["__annotations__"] = annotations

        # Add model validator for dependency validation if we have dep fields
        if dep_fields:
            from pydantic import model_validator

            # Create validator function that captures dep_fields and context_type
            def make_dep_validator(d_fields, ctx_type):
                @model_validator(mode="after")
                def __validate_deps__(self):
                    from .callable import CallableModel

                    for dep_name, (base_type, dep_obj) in d_fields.items():
                        value = getattr(self, dep_name)
                        if isinstance(value, CallableModel):
                            dep_obj.validate_dependency(value, base_type, ctx_type, dep_name)
                    return self

                return __validate_deps__

            namespace["__validate_deps__"] = make_dep_validator(dep_fields, context_type)

        # Create the class using type()
        GeneratedModel = type(f"_{fn.__name__}_Model", (CallableModel,), namespace)

        # Set class-level attributes after class creation (to avoid pydantic processing)
        GeneratedModel.__flow_model_context_type__ = context_type
        GeneratedModel.__flow_model_return_type__ = return_type
        GeneratedModel.__flow_model_func__ = fn
        GeneratedModel.__flow_model_dep_fields__ = dep_fields
        GeneratedModel.__flow_model_use_context_args__ = use_context_args
        GeneratedModel.__flow_model_context_args__ = ctx_args_list

        # Override context_type property after class creation
        @property
        def context_type_getter(self) -> Type[ContextBase]:
            return self.__class__.__flow_model_context_type__

        # Override result_type property after class creation
        @property
        def result_type_getter(self) -> Type[ResultBase]:
            return self.__class__.__flow_model_return_type__

        GeneratedModel.context_type = context_type_getter
        GeneratedModel.result_type = result_type_getter

        # Register for serialization (local classes need this)
        register_ccflow_import_path(GeneratedModel)

        # Rebuild the model to process annotations properly
        GeneratedModel.model_rebuild()

        # Create factory function that returns model instances
        @wraps(fn)
        def factory(**kwargs) -> GeneratedModel:
            return GeneratedModel(**kwargs)

        # Preserve useful attributes on factory
        factory._generated_model = GeneratedModel
        factory.__doc__ = fn.__doc__

        return factory

    # Handle both @Flow.model and @Flow.model(...) syntax
    if func is not None:
        return decorator(func)
    return decorator
ccflow.tests.test_flow_model.data_aggregator + input_a: diamond_branch_a + input_b: diamond_branch_b + operation: add + +# DateRangeContext with transform +flow_date_loader: + _target_: ccflow.tests.test_flow_model.date_range_loader + source: market_data + include_weekends: false + +flow_date_processor: + _target_: ccflow.tests.test_flow_model.date_range_processor + raw_data: flow_date_loader + normalize: true + +# context_args models (auto-unpacked context parameters) +ctx_args_loader: + _target_: ccflow.tests.test_flow_model.context_args_loader + source: data_source + +ctx_args_processor: + _target_: ccflow.tests.test_flow_model.context_args_processor + data: ctx_args_loader + prefix: "output" diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index a748765..29f4524 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -462,6 +462,7 @@ def test_types(self): error = "__call__ method must take a single argument, named 'context'" self.assertRaisesRegex(ValueError, error, BadModelMissingContextArg) + # BadModelDoubleContextArg also fails with the same error since extra params aren't allowed error = "__call__ method must take a single argument, named 'context'" self.assertRaisesRegex(ValueError, error, BadModelDoubleContextArg) diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py new file mode 100644 index 0000000..75e0899 --- /dev/null +++ b/ccflow/tests/test_flow_model.py @@ -0,0 +1,1477 @@ +"""Tests for Flow.model decorator.""" + +from datetime import date, timedelta +from typing import Annotated +from unittest import TestCase + +from pydantic import ValidationError +from ray.cloudpickle import dumps as rcpdumps, loads as rcploads + +from ccflow import ( + CallableModel, + ContextBase, + DateRangeContext, + Dep, + DepOf, + Flow, + GenericResult, + ModelRegistry, + ResultBase, +) + + +class SimpleContext(ContextBase): + """Simple context for testing.""" + + value: int + + +class 
class ExtendedContext(ContextBase):
    """Extended context with multiple fields."""

    # x is required; y has a default
    x: int
    y: str = "default"


class MyResult(ResultBase):
    """Custom result type for testing."""

    data: str


# =============================================================================
# Basic Flow.model Tests
# =============================================================================


class TestFlowModelBasic(TestCase):
    """Basic Flow.model functionality tests.

    Covers both decorator spellings (@Flow.model and @Flow.model(...)),
    generated context_type/result_type properties, defaults, and the '_'
    unused-context convention.
    """

    def test_simple_model_explicit_context(self):
        """Test Flow.model with explicit context parameter."""

        @Flow.model
        def simple_loader(context: SimpleContext, multiplier: int) -> GenericResult[int]:
            return GenericResult(value=context.value * multiplier)

        # Create model instance
        loader = simple_loader(multiplier=3)

        # Should be a CallableModel
        self.assertIsInstance(loader, CallableModel)

        # Execute
        ctx = SimpleContext(value=10)
        result = loader(ctx)

        self.assertIsInstance(result, GenericResult)
        self.assertEqual(result.value, 30)

    def test_model_with_default_params(self):
        """Test Flow.model with default parameter values."""

        @Flow.model
        def loader_with_defaults(context: SimpleContext, multiplier: int = 2, prefix: str = "result") -> GenericResult[str]:
            return GenericResult(value=f"{prefix}:{context.value * multiplier}")

        # Create with defaults
        loader = loader_with_defaults()
        result = loader(SimpleContext(value=5))
        self.assertEqual(result.value, "result:10")

        # Create with custom values
        loader2 = loader_with_defaults(multiplier=3, prefix="custom")
        result2 = loader2(SimpleContext(value=5))
        self.assertEqual(result2.value, "custom:15")

    def test_model_context_type_property(self):
        """Test that generated model has correct context_type."""

        @Flow.model
        def typed_model(context: ExtendedContext, factor: int) -> GenericResult[int]:
            return GenericResult(value=context.x * factor)

        model = typed_model(factor=2)
        self.assertEqual(model.context_type, ExtendedContext)

    def test_model_result_type_property(self):
        """Test that generated model has correct result_type."""

        @Flow.model
        def custom_result_model(context: SimpleContext) -> MyResult:
            return MyResult(data=f"value={context.value}")

        model = custom_result_model()
        self.assertEqual(model.result_type, MyResult)

    def test_model_with_no_extra_params(self):
        """Test Flow.model with only context parameter."""

        @Flow.model
        def identity_model(context: SimpleContext) -> GenericResult[int]:
            return GenericResult(value=context.value)

        model = identity_model()
        result = model(SimpleContext(value=42))
        self.assertEqual(result.value, 42)

    def test_model_with_flow_options(self):
        """Test Flow.model with Flow.call options."""

        @Flow.model(cacheable=True, validate_result=True)
        def cached_model(context: SimpleContext, value: int) -> GenericResult[int]:
            return GenericResult(value=value + context.value)

        model = cached_model(value=10)
        result = model(SimpleContext(value=5))
        self.assertEqual(result.value, 15)

    def test_model_with_underscore_context(self):
        """Test Flow.model with '_' as context parameter (unused context convention)."""

        @Flow.model
        def loader(context: SimpleContext, base: int) -> GenericResult[int]:
            return GenericResult(value=context.value + base)

        @Flow.model
        def consumer(_: SimpleContext, data: DepOf[..., GenericResult[int]]) -> GenericResult[int]:
            # Context not used directly, just passed to dependency
            return GenericResult(value=data.value * 2)

        load = loader(base=100)
        consume = consumer(data=load)

        result = consume(SimpleContext(value=10))
        # loader: 10 + 100 = 110, consumer: 110 * 2 = 220
        self.assertEqual(result.value, 220)

        # Verify context_type is still correct
        self.assertEqual(consume.context_type, SimpleContext)
============================================================================= + + +class TestFlowModelContextArgs(TestCase): + """Tests for Flow.model with context_args (unpacked context).""" + + def test_context_args_basic(self): + """Test basic context_args usage.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def date_range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date} to {end_date}") + + loader = date_range_loader(source="db") + + # Should use DateRangeContext + self.assertEqual(loader.context_type, DateRangeContext) + + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = loader(ctx) + self.assertEqual(result.value, "db:2024-01-01 to 2024-01-31") + + def test_context_args_custom_context(self): + """Test context_args with custom context type.""" + + @Flow.model(context_args=["x", "y"]) + def unpacked_model(x: int, y: str, multiplier: int = 1) -> GenericResult[str]: + return GenericResult(value=f"{y}:{x * multiplier}") + + model = unpacked_model(multiplier=2) + + # Create context with generated type + ctx_type = model.context_type + ctx = ctx_type(x=5, y="test") + + result = model(ctx) + self.assertEqual(result.value, "test:10") + + def test_context_args_with_defaults(self): + """Test context_args where context fields have defaults.""" + + @Flow.model(context_args=["value"]) + def model_with_ctx_default(value: int = 42, extra: str = "foo") -> GenericResult[str]: + return GenericResult(value=f"{extra}:{value}") + + model = model_with_ctx_default() + + # Create context - the generated context should allow default + ctx_type = model.context_type + ctx = ctx_type(value=100) + + result = model(ctx) + self.assertEqual(result.value, "foo:100") + + +# ============================================================================= +# Dependency Tests +# ============================================================================= + + 
+class TestFlowModelDependencies(TestCase): + """Tests for Flow.model with dependencies.""" + + def test_simple_dependency_with_depof(self): + """Test simple dependency using DepOf shorthand.""" + + @Flow.model + def loader(context: SimpleContext, value: int) -> GenericResult[int]: + return GenericResult(value=value + context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + multiplier: int = 1, + ) -> GenericResult[int]: + return GenericResult(value=data.value * multiplier) + + # Create pipeline + load = loader(value=10) + consume = consumer(data=load, multiplier=2) + + ctx = SimpleContext(value=5) + result = consume(ctx) + + # loader returns 10 + 5 = 15, consumer multiplies by 2 = 30 + self.assertEqual(result.value, 30) + + def test_dependency_with_explicit_dep(self): + """Test dependency using explicit Dep() annotation.""" + + @Flow.model + def loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 2) + + @Flow.model + def consumer( + context: SimpleContext, + data: Annotated[GenericResult[int], Dep()], + ) -> GenericResult[int]: + return GenericResult(value=data.value + 100) + + load = loader() + consume = consumer(data=load) + + result = consume(SimpleContext(value=10)) + # loader: 10 * 2 = 20, consumer: 20 + 100 = 120 + self.assertEqual(result.value, 120) + + def test_dependency_with_direct_value(self): + """Test that Dep fields can accept direct values (not CallableModel).""" + + @Flow.model + def consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value + context.value) + + # Pass direct value instead of CallableModel + consume = consumer(data=GenericResult(value=100)) + + result = consume(SimpleContext(value=5)) + self.assertEqual(result.value, 105) + + def test_deps_method_generation(self): + """Test that __deps__ method is correctly generated.""" + + @Flow.model + def 
loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value) + + load = loader() + consume = consumer(data=load) + + ctx = SimpleContext(value=10) + deps = consume.__deps__(ctx) + + # Should have one dependency + self.assertEqual(len(deps), 1) + self.assertEqual(deps[0][0], load) + self.assertEqual(deps[0][1], [ctx]) + + def test_no_deps_when_direct_value(self): + """Test that __deps__ returns empty when direct values used.""" + + @Flow.model + def consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value) + + consume = consumer(data=GenericResult(value=100)) + + deps = consume.__deps__(SimpleContext(value=10)) + self.assertEqual(len(deps), 0) + + +# ============================================================================= +# Transform Tests +# ============================================================================= + + +class TestFlowModelTransforms(TestCase): + """Tests for Flow.model with context transforms.""" + + def test_transform_in_dep(self): + """Test dependency with context transform.""" + + @Flow.model + def loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: Annotated[ + GenericResult[int], + Dep(transform=lambda ctx: ctx.model_copy(update={"value": ctx.value + 10})), + ], + ) -> GenericResult[int]: + return GenericResult(value=data.value * 2) + + load = loader() + consume = consumer(data=load) + + ctx = SimpleContext(value=5) + result = consume(ctx) + + # Transform adds 10 to context.value: 5 + 10 = 15 + # Loader returns that: 15 + # Consumer multiplies by 2: 30 + self.assertEqual(result.value, 30) + + def test_transform_in_deps_method(self): + 
"""Test that transform is applied in __deps__ method.""" + + def transform_fn(ctx): + return ctx.model_copy(update={"value": ctx.value * 3}) + + @Flow.model + def loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: Annotated[GenericResult[int], Dep(transform=transform_fn)], + ) -> GenericResult[int]: + return GenericResult(value=data.value) + + load = loader() + consume = consumer(data=load) + + ctx = SimpleContext(value=7) + deps = consume.__deps__(ctx) + + # Transform should be applied + self.assertEqual(len(deps), 1) + transformed_ctx = deps[0][1][0] + self.assertEqual(transformed_ctx.value, 21) # 7 * 3 + + def test_date_range_transform(self): + """Test transform pattern with date ranges using context_args.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date}") + + def lookback_transform(ctx: DateRangeContext) -> DateRangeContext: + return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) + + @Flow.model(context_args=["start_date", "end_date"]) + def range_processor( + start_date: date, + end_date: date, + data: Annotated[GenericResult[str], Dep(transform=lookback_transform)], + ) -> GenericResult[str]: + return GenericResult(value=f"processed:{data.value}") + + loader = range_loader(source="db") + processor = range_processor(data=loader) + + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + result = processor(ctx) + + # Transform should shift start_date back by 1 day + self.assertEqual(result.value, "processed:db:2024-01-09") + + +# ============================================================================= +# Pipeline Tests +# ============================================================================= + + +class TestFlowModelPipeline(TestCase): 
+ """Tests for multi-stage pipelines with Flow.model.""" + + def test_three_stage_pipeline(self): + """Test a three-stage computation pipeline.""" + + @Flow.model + def stage1(context: SimpleContext, base: int) -> GenericResult[int]: + return GenericResult(value=context.value + base) + + @Flow.model + def stage2( + context: SimpleContext, + input_data: DepOf[..., GenericResult[int]], + multiplier: int, + ) -> GenericResult[int]: + return GenericResult(value=input_data.value * multiplier) + + @Flow.model + def stage3( + context: SimpleContext, + input_data: DepOf[..., GenericResult[int]], + offset: int = 0, + ) -> GenericResult[int]: + return GenericResult(value=input_data.value + offset) + + # Build pipeline + s1 = stage1(base=100) + s2 = stage2(input_data=s1, multiplier=2) + s3 = stage3(input_data=s2, offset=50) + + ctx = SimpleContext(value=10) + result = s3(ctx) + + # s1: 10 + 100 = 110 + # s2: 110 * 2 = 220 + # s3: 220 + 50 = 270 + self.assertEqual(result.value, 270) + + def test_diamond_dependency_pattern(self): + """Test diamond-shaped dependency pattern.""" + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def branch_a( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value * 2) + + @Flow.model + def branch_b( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=data.value + 100) + + @Flow.model + def merger( + context: SimpleContext, + a: DepOf[..., GenericResult[int]], + b: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=a.value + b.value) + + src = source() + a = branch_a(data=src) + b = branch_b(data=src) + merge = merger(a=a, b=b) + + ctx = SimpleContext(value=10) + result = merge(ctx) + + # source: 10 + # branch_a: 10 * 2 = 20 + # branch_b: 10 + 100 = 110 + # merger: 20 + 110 = 130 
+ self.assertEqual(result.value, 130) + + +# ============================================================================= +# Integration Tests +# ============================================================================= + + +class TestFlowModelIntegration(TestCase): + """Integration tests for Flow.model with ccflow infrastructure.""" + + def test_registry_integration(self): + """Test that Flow.model models work with ModelRegistry.""" + + @Flow.model + def registrable_model(context: SimpleContext, value: int) -> GenericResult[int]: + return GenericResult(value=context.value + value) + + model = registrable_model(value=100) + + registry = ModelRegistry.root().clear() + registry.add("test_model", model) + + retrieved = registry["test_model"] + self.assertEqual(retrieved, model) + + result = retrieved(SimpleContext(value=10)) + self.assertEqual(result.value, 110) + + def test_serialization_dump(self): + """Test that generated models can be serialized.""" + + @Flow.model + def serializable_model(context: SimpleContext, value: int = 42) -> GenericResult[int]: + return GenericResult(value=value) + + model = serializable_model(value=100) + dumped = model.model_dump(mode="python") + + self.assertIn("value", dumped) + self.assertEqual(dumped["value"], 100) + self.assertIn("type_", dumped) + + def test_pickle_roundtrip(self): + """Test cloudpickle serialization of generated models.""" + + @Flow.model + def pickleable_model(context: SimpleContext, factor: int) -> GenericResult[int]: + return GenericResult(value=context.value * factor) + + model = pickleable_model(factor=3) + + # Cloudpickle roundtrip (standard pickle won't work for local classes) + pickled = rcpdumps(model, protocol=5) + restored = rcploads(pickled) + + result = restored(SimpleContext(value=10)) + self.assertEqual(result.value, 30) + + def test_mix_with_manual_callable_model(self): + """Test mixing Flow.model with manually defined CallableModel.""" + + class ManualModel(CallableModel): + offset: int + + 
@Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value + self.offset) + + @Flow.model + def generated_consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + multiplier: int, + ) -> GenericResult[int]: + return GenericResult(value=data.value * multiplier) + + manual = ManualModel(offset=50) + generated = generated_consumer(data=manual, multiplier=2) + + result = generated(SimpleContext(value=10)) + # manual: 10 + 50 = 60 + # generated: 60 * 2 = 120 + self.assertEqual(result.value, 120) + + +# ============================================================================= +# Error Case Tests +# ============================================================================= + + +class TestFlowModelErrors(TestCase): + """Error case tests for Flow.model.""" + + def test_missing_return_type(self): + """Test error when return type annotation is missing.""" + with self.assertRaises(TypeError) as cm: + + @Flow.model + def no_return(context: SimpleContext): + return GenericResult(value=1) + + self.assertIn("return type annotation", str(cm.exception)) + + def test_non_result_return_type(self): + """Test error when return type is not ResultBase subclass.""" + with self.assertRaises(TypeError) as cm: + + @Flow.model + def bad_return(context: SimpleContext) -> int: + return 42 + + self.assertIn("ResultBase", str(cm.exception)) + + def test_missing_context_and_context_args(self): + """Test error when neither context param nor context_args provided.""" + with self.assertRaises(TypeError) as cm: + + @Flow.model + def no_context(value: int) -> GenericResult[int]: + return GenericResult(value=value) + + self.assertIn("context", str(cm.exception)) + + def test_invalid_context_arg(self): + """Test error when context_args refers to non-existent parameter.""" + with self.assertRaises(ValueError) as cm: + + @Flow.model(context_args=["nonexistent"]) + def bad_context_args(x: int) -> GenericResult[int]: + 
return GenericResult(value=x) + + self.assertIn("nonexistent", str(cm.exception)) + + def test_context_arg_without_annotation(self): + """Test error when context_arg parameter lacks type annotation.""" + with self.assertRaises(ValueError) as cm: + + @Flow.model(context_args=["x"]) + def untyped_context_arg(x) -> GenericResult[int]: + return GenericResult(value=x) + + self.assertIn("type annotation", str(cm.exception)) + + +# ============================================================================= +# Dep and DepOf Tests +# ============================================================================= + + +class TestDepAndDepOf(TestCase): + """Tests for Dep and DepOf classes.""" + + def test_depof_creates_annotated(self): + """Test that DepOf[..., T] creates Annotated[Union[T, CallableModel], Dep()].""" + from typing import Union as TypingUnion, get_args, get_origin + + annotation = DepOf[..., GenericResult[int]] + self.assertEqual(get_origin(annotation), Annotated) + + args = get_args(annotation) + # First arg is Union[ResultType, CallableModel] + self.assertEqual(get_origin(args[0]), TypingUnion) + union_args = get_args(args[0]) + self.assertIn(GenericResult[int], union_args) + self.assertIn(CallableModel, union_args) + # Second arg is Dep() + self.assertIsInstance(args[1], Dep) + self.assertIsNone(args[1].context_type) # ... 
means inherit from parent + + def test_depof_with_generic_type(self): + """Test DepOf with nested generic types.""" + from typing import List as TypingList, Union as TypingUnion, get_args, get_origin + + annotation = DepOf[..., GenericResult[TypingList[str]]] + self.assertEqual(get_origin(annotation), Annotated) + + args = get_args(annotation) + # First arg is Union[ResultType, CallableModel] + self.assertEqual(get_origin(args[0]), TypingUnion) + union_args = get_args(args[0]) + self.assertIn(GenericResult[TypingList[str]], union_args) + self.assertIn(CallableModel, union_args) + + def test_depof_with_context_type(self): + """Test DepOf[ContextType, ResultType] syntax.""" + from typing import Union as TypingUnion, get_args, get_origin + + annotation = DepOf[SimpleContext, GenericResult[int]] + self.assertEqual(get_origin(annotation), Annotated) + + args = get_args(annotation) + # First arg is Union[ResultType, CallableModel] + self.assertEqual(get_origin(args[0]), TypingUnion) + union_args = get_args(args[0]) + self.assertIn(GenericResult[int], union_args) + self.assertIn(CallableModel, union_args) + # Second arg is Dep with context_type + self.assertIsInstance(args[1], Dep) + self.assertEqual(args[1].context_type, SimpleContext) + + def test_extract_dep_with_annotated(self): + """Test extract_dep with Annotated type.""" + from ccflow.dep import extract_dep + + dep = Dep(context_type=SimpleContext) + annotation = Annotated[GenericResult[int], dep] + + base_type, extracted_dep = extract_dep(annotation) + self.assertEqual(base_type, GenericResult[int]) + self.assertEqual(extracted_dep, dep) + + def test_extract_dep_with_depof(self): + """Test extract_dep with DepOf type.""" + from typing import Union as TypingUnion, get_args, get_origin + + from ccflow.dep import extract_dep + + annotation = DepOf[..., GenericResult[str]] + base_type, extracted_dep = extract_dep(annotation) + + # base_type is Union[ResultType, CallableModel] + self.assertEqual(get_origin(base_type), 
TypingUnion) + union_args = get_args(base_type) + self.assertIn(GenericResult[str], union_args) + self.assertIn(CallableModel, union_args) + self.assertIsInstance(extracted_dep, Dep) + + def test_extract_dep_without_dep(self): + """Test extract_dep with regular type (no Dep).""" + from ccflow.dep import extract_dep + + base_type, extracted_dep = extract_dep(int) + self.assertEqual(base_type, int) + self.assertIsNone(extracted_dep) + + def test_extract_dep_annotated_without_dep(self): + """Test extract_dep with Annotated but no Dep marker.""" + from ccflow.dep import extract_dep + + annotation = Annotated[int, "some metadata"] + base_type, extracted_dep = extract_dep(annotation) + + # When no Dep marker is found, returns original annotation unchanged + self.assertEqual(base_type, annotation) + self.assertIsNone(extracted_dep) + + def test_is_compatible_type_simple(self): + """Test _is_compatible_type with simple types.""" + from ccflow.dep import _is_compatible_type + + self.assertTrue(_is_compatible_type(int, int)) + self.assertFalse(_is_compatible_type(int, str)) + self.assertTrue(_is_compatible_type(bool, int)) # bool is subclass of int + + def test_is_compatible_type_generic(self): + """Test _is_compatible_type with generic types.""" + from ccflow.dep import _is_compatible_type + + self.assertTrue(_is_compatible_type(GenericResult[int], GenericResult[int])) + self.assertFalse(_is_compatible_type(GenericResult[int], GenericResult[str])) + self.assertTrue(_is_compatible_type(GenericResult, GenericResult)) + + def test_is_compatible_type_none(self): + """Test _is_compatible_type with None.""" + from ccflow.dep import _is_compatible_type + + self.assertTrue(_is_compatible_type(None, None)) + self.assertFalse(_is_compatible_type(None, int)) + self.assertFalse(_is_compatible_type(int, None)) + + def test_is_compatible_type_subclass(self): + """Test _is_compatible_type with subclasses.""" + from ccflow.dep import _is_compatible_type + + 
self.assertTrue(_is_compatible_type(MyResult, ResultBase)) + self.assertFalse(_is_compatible_type(ResultBase, MyResult)) + + def test_dep_validate_dependency_success(self): + """Test Dep.validate_dependency with valid dependency.""" + + @Flow.model + def valid_dep(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + dep = Dep() + model = valid_dep() + + # Should not raise + dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") + + def test_dep_validate_dependency_context_mismatch(self): + """Test Dep.validate_dependency with context type mismatch.""" + + class OtherContext(ContextBase): + other: str + + @Flow.model + def other_dep(context: OtherContext) -> GenericResult[int]: + return GenericResult(value=42) + + dep = Dep(context_type=SimpleContext) + model = other_dep() + + with self.assertRaises(TypeError) as cm: + dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") + + self.assertIn("context_type", str(cm.exception)) + + def test_dep_validate_dependency_result_mismatch(self): + """Test Dep.validate_dependency with result type mismatch.""" + + @Flow.model + def wrong_result(context: SimpleContext) -> MyResult: + return MyResult(data="test") + + dep = Dep() + model = wrong_result() + + with self.assertRaises(TypeError) as cm: + dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") + + self.assertIn("result_type", str(cm.exception)) + + def test_dep_validate_dependency_non_callable(self): + """Test Dep.validate_dependency with non-CallableModel value.""" + dep = Dep() + # Should not raise for non-CallableModel values + dep.validate_dependency(GenericResult(value=42), GenericResult[int], SimpleContext, "data") + dep.validate_dependency("string", GenericResult[int], SimpleContext, "data") + dep.validate_dependency(123, GenericResult[int], SimpleContext, "data") + + def test_dep_hash(self): + """Test Dep is hashable for use in sets/dicts.""" + dep1 = Dep() + dep2 = 
Dep(context_type=SimpleContext) + + # Should be hashable + dep_set = {dep1, dep2} + self.assertEqual(len(dep_set), 2) + + dep_dict = {dep1: "value1", dep2: "value2"} + self.assertEqual(dep_dict[dep1], "value1") + self.assertEqual(dep_dict[dep2], "value2") + + def test_dep_apply_with_transform(self): + """Test Dep.apply with transform function.""" + + def transform(ctx): + return ctx.model_copy(update={"value": ctx.value * 2}) + + dep = Dep(transform=transform) + + ctx = SimpleContext(value=10) + result = dep.apply(ctx) + + self.assertEqual(result.value, 20) + + def test_dep_apply_without_transform(self): + """Test Dep.apply without transform (identity).""" + dep = Dep() + + ctx = SimpleContext(value=10) + result = dep.apply(ctx) + + self.assertIs(result, ctx) + + def test_dep_repr(self): + """Test Dep string representation.""" + dep1 = Dep() + self.assertEqual(repr(dep1), "Dep()") + + dep2 = Dep(context_type=SimpleContext) + self.assertIn("SimpleContext", repr(dep2)) + + dep3 = Dep(transform=lambda x: x) + self.assertIn("transform=", repr(dep3)) + + def test_dep_equality(self): + """Test Dep equality comparison.""" + dep1 = Dep() + dep2 = Dep() + dep3 = Dep(context_type=SimpleContext) + + # Note: Two Dep() instances with no arguments are equal + self.assertEqual(dep1, dep2) + self.assertNotEqual(dep1, dep3) + + +# ============================================================================= +# Validation Tests +# ============================================================================= + + +class TestFlowModelValidation(TestCase): + """Tests for dependency validation in Flow.model.""" + + def test_context_type_validation(self): + """Test that context_type mismatch is detected.""" + + class OtherContext(ContextBase): + other: str + + @Flow.model + def simple_loader(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def other_loader(context: OtherContext) -> GenericResult[int]: + return 
GenericResult(value=42) + + @Flow.model + def consumer( + context: SimpleContext, + data: Annotated[GenericResult[int], Dep(context_type=SimpleContext)], + ) -> GenericResult[int]: + return GenericResult(value=data.value) + + # Should work with matching context + load1 = simple_loader() + consume1 = consumer(data=load1) + self.assertIsNotNone(consume1) + + # Should fail with mismatched context + load2 = other_loader() + with self.assertRaises((TypeError, ValidationError)): + consumer(data=load2) + + +# ============================================================================= +# Hydra Integration Tests +# ============================================================================= + + +# Define Flow.model functions at module level for Hydra to find them +@Flow.model +def hydra_basic_model(context: SimpleContext, value: int, name: str = "default") -> GenericResult[str]: + """Module-level model for Hydra testing.""" + return GenericResult(value=f"{name}:{context.value + value}") + + +# --- Additional module-level fixtures for Hydra YAML tests --- + + +@Flow.model +def basic_loader(context: SimpleContext, source: str, multiplier: int = 1) -> GenericResult[int]: + """Basic loader that multiplies context value by multiplier.""" + return GenericResult(value=context.value * multiplier) + + +@Flow.model +def string_processor(context: SimpleContext, prefix: str, suffix: str = "") -> GenericResult[str]: + """Process context value into a string with prefix and suffix.""" + return GenericResult(value=f"{prefix}{context.value}{suffix}") + + +@Flow.model +def data_source(context: SimpleContext, base_value: int) -> GenericResult[int]: + """Source that provides base data.""" + return GenericResult(value=context.value + base_value) + + +@Flow.model +def data_transformer( + context: SimpleContext, + source: DepOf[..., GenericResult[int]], + factor: int = 2, +) -> GenericResult[int]: + """Transform data by multiplying with factor.""" + return GenericResult(value=source.value * 
factor) + + +@Flow.model +def data_aggregator( + context: SimpleContext, + input_a: DepOf[..., GenericResult[int]], + input_b: DepOf[..., GenericResult[int]], + operation: str = "add", +) -> GenericResult[int]: + """Aggregate two inputs.""" + if operation == "add": + return GenericResult(value=input_a.value + input_b.value) + elif operation == "multiply": + return GenericResult(value=input_a.value * input_b.value) + else: + return GenericResult(value=input_a.value - input_b.value) + + +@Flow.model +def pipeline_stage1(context: SimpleContext, initial: int) -> GenericResult[int]: + """First stage of pipeline.""" + return GenericResult(value=context.value + initial) + + +@Flow.model +def pipeline_stage2( + context: SimpleContext, + stage1_output: DepOf[..., GenericResult[int]], + multiplier: int = 2, +) -> GenericResult[int]: + """Second stage of pipeline.""" + return GenericResult(value=stage1_output.value * multiplier) + + +@Flow.model +def pipeline_stage3( + context: SimpleContext, + stage2_output: DepOf[..., GenericResult[int]], + offset: int = 0, +) -> GenericResult[int]: + """Third stage of pipeline.""" + return GenericResult(value=stage2_output.value + offset) + + +def lookback_one_day(ctx: DateRangeContext) -> DateRangeContext: + """Transform that extends start_date back by one day.""" + return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) + + +@Flow.model +def date_range_loader( + context: DateRangeContext, + source: str, + include_weekends: bool = True, +) -> GenericResult[str]: + """Load data for a date range.""" + return GenericResult(value=f"{source}:{context.start_date} to {context.end_date}") + + +@Flow.model +def date_range_processor( + context: DateRangeContext, + raw_data: Annotated[GenericResult[str], Dep(transform=lookback_one_day)], + normalize: bool = False, +) -> GenericResult[str]: + """Process date range data with lookback.""" + prefix = "normalized:" if normalize else "raw:" + return 
GenericResult(value=f"{prefix}{raw_data.value}") + + +@Flow.model +def hydra_default_model(context: SimpleContext, value: int = 42) -> GenericResult[int]: + """Module-level model with defaults for Hydra testing.""" + return GenericResult(value=context.value + value) + + +@Flow.model +def hydra_source_model(context: SimpleContext, base: int) -> GenericResult[int]: + """Source model for dependency testing.""" + return GenericResult(value=context.value * base) + + +@Flow.model +def hydra_consumer_model( + context: SimpleContext, + source: DepOf[..., GenericResult[int]], + factor: int = 1, +) -> GenericResult[int]: + """Consumer model for dependency testing.""" + return GenericResult(value=source.value * factor) + + +# --- context_args fixtures for Hydra testing --- + + +@Flow.model(context_args=["start_date", "end_date"]) +def context_args_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: + """Loader using context_args with DateRangeContext.""" + return GenericResult(value=f"{source}:{start_date} to {end_date}") + + +@Flow.model(context_args=["start_date", "end_date"]) +def context_args_processor( + start_date: date, + end_date: date, + data: DepOf[..., GenericResult[str]], + prefix: str = "processed", +) -> GenericResult[str]: + """Processor using context_args with dependency.""" + return GenericResult(value=f"{prefix}:{data.value}") + + +class TestFlowModelHydra(TestCase): + """Tests for Flow.model with Hydra configuration.""" + + def test_hydra_instantiate_basic(self): + """Test that Flow.model factory can be instantiated via Hydra.""" + from hydra.utils import instantiate + from omegaconf import OmegaConf + + # Create config that references the factory function by module path + cfg = OmegaConf.create( + { + "_target_": "ccflow.tests.test_flow_model.hydra_basic_model", + "value": 100, + "name": "test", + } + ) + + # Instantiate via Hydra + model = instantiate(cfg) + + self.assertIsInstance(model, CallableModel) + result = 
model(SimpleContext(value=10)) + self.assertEqual(result.value, "test:110") + + def test_hydra_instantiate_with_defaults(self): + """Test Hydra instantiation using default parameter values.""" + from hydra.utils import instantiate + from omegaconf import OmegaConf + + cfg = OmegaConf.create( + { + "_target_": "ccflow.tests.test_flow_model.hydra_default_model", + # Not specifying value, should use default + } + ) + + model = instantiate(cfg) + result = model(SimpleContext(value=8)) + self.assertEqual(result.value, 50) + + def test_hydra_instantiate_with_dependency(self): + """Test Hydra instantiation with dependencies.""" + from hydra.utils import instantiate + from omegaconf import OmegaConf + + # Create nested config + cfg = OmegaConf.create( + { + "_target_": "ccflow.tests.test_flow_model.hydra_consumer_model", + "source": { + "_target_": "ccflow.tests.test_flow_model.hydra_source_model", + "base": 10, + }, + "factor": 2, + } + ) + + model = instantiate(cfg) + + result = model(SimpleContext(value=5)) + # source: 5 * 10 = 50, consumer: 50 * 2 = 100 + self.assertEqual(result.value, 100) + + +# ============================================================================= +# Class-based CallableModel with Auto-Resolution Tests +# ============================================================================= + + +class TestClassBasedDepResolution(TestCase): + """Tests for auto-resolution of DepOf fields in class-based CallableModels. + + Key pattern: Fields use DepOf annotation, __call__ only takes context, + and resolved values are accessed via self.field_name during __call__. 
+ """ + + def test_class_based_auto_resolve_basic(self): + """Test that DepOf fields are auto-resolved and accessible via self.""" + + @Flow.model + def data_source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + class Consumer(CallableModel): + # DepOf expands to Annotated[Union[ResultType, CallableModel], Dep()] + source: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + # Access resolved value via self.source + return GenericResult(value=self.source.value + 1) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.source, [context])] + + src = data_source() + consumer = Consumer(source=src) + + result = consumer(SimpleContext(value=5)) + # source: 5 * 10 = 50, consumer: 50 + 1 = 51 + self.assertEqual(result.value, 51) + + def test_class_based_with_custom_transform(self): + """Test that custom __deps__ transform is used.""" + + @Flow.model + def data_source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + class Consumer(CallableModel): + source: DepOf[..., GenericResult[int]] + offset: int = 100 + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.source.value + self.offset) + + @Flow.deps + def __deps__(self, context: SimpleContext): + # Apply custom transform + transformed_ctx = SimpleContext(value=context.value + 5) + return [(self.source, [transformed_ctx])] + + src = data_source() + consumer = Consumer(source=src, offset=1) + + result = consumer(SimpleContext(value=5)) + # transformed context: 5 + 5 = 10 + # source: 10 * 10 = 100 + # consumer: 100 + 1 = 101 + self.assertEqual(result.value, 101) + + def test_class_based_with_annotated_transform(self): + """Test that Dep transform is used when field not in __deps__.""" + + @Flow.model + def data_source(context: SimpleContext) -> GenericResult[int]: + return 
GenericResult(value=context.value * 10) + + def double_value(ctx: SimpleContext) -> SimpleContext: + return SimpleContext(value=ctx.value * 2) + + class Consumer(CallableModel): + source: Annotated[DepOf[..., GenericResult[int]], Dep(transform=double_value)] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.source.value + 1) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [] # Empty - uses Dep annotation transform from field + + src = data_source() + consumer = Consumer(source=src) + + result = consumer(SimpleContext(value=5)) + # transform: 5 * 2 = 10 + # source: 10 * 10 = 100 + # consumer: 100 + 1 = 101 + self.assertEqual(result.value, 101) + + def test_class_based_multiple_deps(self): + """Test auto-resolution with multiple dependencies.""" + + @Flow.model + def source_a(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def source_b(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 2) + + class Aggregator(CallableModel): + a: DepOf[..., GenericResult[int]] + b: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.a.value + self.b.value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.a, [context]), (self.b, [context])] + + agg = Aggregator(a=source_a(), b=source_b()) + + result = agg(SimpleContext(value=10)) + # a: 10, b: 20, aggregator: 30 + self.assertEqual(result.value, 30) + + def test_class_based_deps_with_instance_field_access(self): + """Test that __deps__ can access instance fields for configurable transforms. + + This is the key advantage of class-based models over @Flow.model: + transforms can use instance fields like window size. 
+ """ + + @Flow.model + def data_source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + class Consumer(CallableModel): + source: DepOf[..., GenericResult[int]] + lookback: int = 5 # Configurable instance field + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.source.value * 2) + + @Flow.deps + def __deps__(self, context: SimpleContext): + # Access self.lookback in transform - this is why we use class-based! + transformed = SimpleContext(value=context.value + self.lookback) + return [(self.source, [transformed])] + + src = data_source() + consumer = Consumer(source=src, lookback=10) + + result = consumer(SimpleContext(value=5)) + # transformed: 5 + 10 = 15 + # source: 15 + # consumer: 15 * 2 = 30 + self.assertEqual(result.value, 30) + + def test_class_based_with_direct_value(self): + """Test that DepOf fields can accept pre-resolved values.""" + + class Consumer(CallableModel): + source: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.source.value + context.value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + # No deps when source is already resolved + return [] + + # Pass direct value instead of CallableModel + consumer = Consumer(source=GenericResult(value=100)) + + result = consumer(SimpleContext(value=5)) + self.assertEqual(result.value, 105) + + def test_class_based_no_double_call(self): + """Test that dependencies are not called twice during DepOf resolution. + + This verifies that the auto-resolution mechanism doesn't accidentally + evaluate the same dependency multiple times. 
+ """ + call_counts = {"source": 0} + + @Flow.model + def counting_source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + class Consumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=self.data.value + 1) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.data, [context])] + + src = counting_source() + consumer = Consumer(data=src) + + # Call consumer - source should only be called once + result = consumer(SimpleContext(value=5)) + + self.assertEqual(result.value, 51) # 5 * 10 + 1 + self.assertEqual(call_counts["source"], 1, "Source should only be called once") + + def test_class_based_nested_depof_no_double_call(self): + """Test nested DepOf chain (A -> B -> C) has no double-calls at any layer. + + This tests a 3-layer dependency chain where: + - layer_c is the leaf (no dependencies) + - layer_b depends on layer_c + - layer_a depends on layer_b + + Each layer should be called exactly once. 
+ """ + call_counts = {"layer_a": 0, "layer_b": 0, "layer_c": 0} + + # Layer C: leaf node (no dependencies) + @Flow.model + def layer_c(context: SimpleContext) -> GenericResult[int]: + call_counts["layer_c"] += 1 + return GenericResult(value=context.value) + + # Layer B: depends on layer_c + class LayerB(CallableModel): + source: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + call_counts["layer_b"] += 1 + return GenericResult(value=self.source.value * 10) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.source, [context])] + + # Layer A: depends on layer_b + class LayerA(CallableModel): + source: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + call_counts["layer_a"] += 1 + return GenericResult(value=self.source.value + 1) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.source, [context])] + + # Build the chain: A -> B -> C + c = layer_c() + b = LayerB(source=c) + a = LayerA(source=b) + + # Call layer_a - each layer should be called exactly once + result = a(SimpleContext(value=5)) + + # Verify result: C returns 5, B returns 5*10=50, A returns 50+1=51 + self.assertEqual(result.value, 51) + + # Verify each layer called exactly once + self.assertEqual(call_counts["layer_c"], 1, "layer_c should be called exactly once") + self.assertEqual(call_counts["layer_b"], 1, "layer_b should be called exactly once") + self.assertEqual(call_counts["layer_a"], 1, "layer_a should be called exactly once") + + def test_flow_model_uses_unified_resolution_path(self): + """Test that @Flow.model uses the same resolution path as class-based CallableModel. + + This verifies the consolidation of resolution logic - both @Flow.model and + class-based models should use _resolve_deps_and_call in callable.py. 
+ """ + call_counts = {"source": 0, "decorator_model": 0, "class_model": 0} + + @Flow.model + def shared_source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 2) + + # @Flow.model consumer + @Flow.model + def decorator_consumer( + context: SimpleContext, + data: DepOf[..., GenericResult[int]], + ) -> GenericResult[int]: + call_counts["decorator_model"] += 1 + return GenericResult(value=data.value + 100) + + # Class-based consumer (same logic) + class ClassConsumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + call_counts["class_model"] += 1 + return GenericResult(value=self.data.value + 100) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.data, [context])] + + # Test both consumers with the same source + src = shared_source() + dec_consumer = decorator_consumer(data=src) + cls_consumer = ClassConsumer(data=src) + + ctx = SimpleContext(value=10) + + # Both should produce the same result + dec_result = dec_consumer(ctx) + cls_result = cls_consumer(ctx) + + self.assertEqual(dec_result.value, cls_result.value) + self.assertEqual(dec_result.value, 120) # 10 * 2 + 100 + + # Source should be called exactly twice (once per consumer) + self.assertEqual(call_counts["source"], 2) + self.assertEqual(call_counts["decorator_model"], 1) + self.assertEqual(call_counts["class_model"], 1) + + +if __name__ == "__main__": + import unittest + + unittest.main() diff --git a/ccflow/tests/test_flow_model_hydra.py b/ccflow/tests/test_flow_model_hydra.py new file mode 100644 index 0000000..661ac4f --- /dev/null +++ b/ccflow/tests/test_flow_model_hydra.py @@ -0,0 +1,437 @@ +"""Hydra integration tests for Flow.model. + +These tests verify that Flow.model decorated functions work correctly when +loaded from YAML configuration files using ModelRegistry.load_config_from_path(). 
+ +Key feature: Registry name references (e.g., `source: flow_source`) ensure the same +object instance is shared across all consumers. +""" + +from datetime import date +from pathlib import Path +from unittest import TestCase + +from omegaconf import OmegaConf + +from ccflow import CallableModel, DateRangeContext, GenericResult, ModelRegistry + +from .test_flow_model import SimpleContext + +CONFIG_PATH = str(Path(__file__).parent / "config" / "conf_flow.yaml") + + +class TestFlowModelHydraYAML(TestCase): + """Tests loading Flow.model from YAML config files using ModelRegistry.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_basic_loader_from_yaml(self): + """Test basic model instantiation from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_loader"] + + self.assertIsInstance(loader, CallableModel) + + ctx = SimpleContext(value=10) + result = loader(ctx) + self.assertEqual(result.value, 50) # 10 * 5 + + def test_string_processor_from_yaml(self): + """Test string processor model from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["flow_processor"] + + ctx = SimpleContext(value=42) + result = processor(ctx) + self.assertEqual(result.value, "value=42!") + + def test_two_stage_pipeline_from_yaml(self): + """Test two-stage pipeline from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + transformer = r["flow_transformer"] + + self.assertIsInstance(transformer, CallableModel) + + ctx = SimpleContext(value=5) + result = transformer(ctx) + # flow_source: 5 + 100 = 105 + # flow_transformer: 105 * 3 = 315 + self.assertEqual(result.value, 315) + + def test_three_stage_pipeline_from_yaml(self): + """Test three-stage pipeline from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + stage3 = r["flow_stage3"] + + ctx = 
SimpleContext(value=10) + result = stage3(ctx) + # stage1: 10 + 10 = 20 + # stage2: 20 * 2 = 40 + # stage3: 40 + 50 = 90 + self.assertEqual(result.value, 90) + + def test_diamond_dependency_from_yaml(self): + """Test diamond dependency pattern from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + aggregator = r["diamond_aggregator"] + + ctx = SimpleContext(value=10) + result = aggregator(ctx) + # source: 10 + 10 = 20 + # branch_a: 20 * 2 = 40 + # branch_b: 20 * 5 = 100 + # aggregator: 40 + 100 = 140 + self.assertEqual(result.value, 140) + + def test_date_range_pipeline_from_yaml(self): + """Test DateRangeContext pipeline with transforms from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["flow_date_processor"] + + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + result = processor(ctx) + + # The transform extends start_date back by one day + self.assertIn("2024-01-09", result.value) + self.assertIn("normalized:", result.value) + + def test_context_args_from_yaml(self): + """Test context_args model from YAML config.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["ctx_args_loader"] + + self.assertIsInstance(loader, CallableModel) + # context_args models use DateRangeContext + self.assertEqual(loader.context_type, DateRangeContext) + + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = loader(ctx) + self.assertEqual(result.value, "data_source:2024-01-01 to 2024-01-31") + + def test_context_args_pipeline_from_yaml(self): + """Test context_args pipeline with dependencies from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["ctx_args_processor"] + + ctx = DateRangeContext(start_date=date(2024, 3, 1), end_date=date(2024, 3, 31)) + result = processor(ctx) + # loader: "data_source:2024-03-01 to 2024-03-31" + # processor: 
"output:data_source:2024-03-01 to 2024-03-31" + self.assertEqual(result.value, "output:data_source:2024-03-01 to 2024-03-31") + + def test_context_args_shares_instance(self): + """Test that context_args pipeline shares dependency instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["ctx_args_loader"] + processor = r["ctx_args_processor"] + + self.assertIs(processor.data, loader) + + +class TestFlowModelHydraInstanceSharing(TestCase): + """Tests that registry name references share the same object instance.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_pipeline_shares_instance(self): + """Test that pipeline stages share the same dependency instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + transformer = r["flow_transformer"] + source = r["flow_source"] + + self.assertIs(transformer.source, source) + + def test_three_stage_pipeline_shares_instances(self): + """Test that three-stage pipeline shares instances correctly.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + stage1 = r["flow_stage1"] + stage2 = r["flow_stage2"] + stage3 = r["flow_stage3"] + + self.assertIs(stage2.stage1_output, stage1) + self.assertIs(stage3.stage2_output, stage2) + + def test_diamond_pattern_shares_source_instance(self): + """Test that diamond pattern branches share the same source instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + source = r["diamond_source"] + branch_a = r["diamond_branch_a"] + branch_b = r["diamond_branch_b"] + aggregator = r["diamond_aggregator"] + + # Both branches should share the SAME source instance + self.assertIs(branch_a.source, source) + self.assertIs(branch_b.source, source) + self.assertIs(branch_a.source, branch_b.source) + + self.assertIs(aggregator.input_a, branch_a) + self.assertIs(aggregator.input_b, branch_b) + + def 
test_date_range_shares_instance(self): + """Test that date range pipeline shares dependency instance.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_date_loader"] + processor = r["flow_date_processor"] + + self.assertIs(processor.raw_data, loader) + + +class TestFlowModelHydraOmegaConf(TestCase): + """Tests using OmegaConf.create for dynamic config creation.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_instantiate_with_omegaconf(self): + """Test instantiation using OmegaConf.create via ModelRegistry.""" + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "dynamic_source", + "multiplier": 7, + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + loader = r["loader"] + + ctx = SimpleContext(value=3) + result = loader(ctx) + self.assertEqual(result.value, 21) # 3 * 7 + + def test_nested_deps_with_omegaconf(self): + """Test nested dependencies using OmegaConf with registry names.""" + cfg = OmegaConf.create( + { + "source": { + "_target_": "ccflow.tests.test_flow_model.data_source", + "base_value": 50, + }, + "transformer": { + "_target_": "ccflow.tests.test_flow_model.data_transformer", + "source": "source", + "factor": 4, + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + transformer = r["transformer"] + + ctx = SimpleContext(value=10) + result = transformer(ctx) + # source: 10 + 50 = 60 + # transformer: 60 * 4 = 240 + self.assertEqual(result.value, 240) + + self.assertIs(transformer.source, r["source"]) + + def test_diamond_with_omegaconf(self): + """Test diamond pattern with OmegaConf using registry names.""" + cfg = OmegaConf.create( + { + "source": { + "_target_": "ccflow.tests.test_flow_model.data_source", + "base_value": 10, + }, + "branch_a": { + "_target_": "ccflow.tests.test_flow_model.data_transformer", + "source": "source", + "factor": 
2, + }, + "branch_b": { + "_target_": "ccflow.tests.test_flow_model.data_transformer", + "source": "source", + "factor": 3, + }, + "aggregator": { + "_target_": "ccflow.tests.test_flow_model.data_aggregator", + "input_a": "branch_a", + "input_b": "branch_b", + "operation": "multiply", + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + aggregator = r["aggregator"] + + ctx = SimpleContext(value=5) + result = aggregator(ctx) + # source: 5 + 10 = 15 + # branch_a: 15 * 2 = 30 + # branch_b: 15 * 3 = 45 + # aggregator: 30 * 45 = 1350 + self.assertEqual(result.value, 1350) + + # Verify SAME source instance is shared + self.assertIs(r["branch_a"].source, r["source"]) + self.assertIs(r["branch_b"].source, r["source"]) + + +class TestFlowModelHydraDefaults(TestCase): + """Tests that default parameter values work with Hydra.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_defaults_used_when_not_specified(self): + """Test that default values are used when not in config.""" + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "test", + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + loader = r["loader"] + + ctx = SimpleContext(value=10) + result = loader(ctx) + self.assertEqual(result.value, 10) # 10 * 1 (default) + + def test_defaults_can_be_overridden(self): + """Test that defaults can be overridden in config.""" + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "test", + "multiplier": 100, + }, + } + ) + + r = ModelRegistry.root() + r.load_config(cfg) + loader = r["loader"] + + ctx = SimpleContext(value=10) + result = loader(ctx) + self.assertEqual(result.value, 1000) # 10 * 100 + + +class TestFlowModelHydraModelProperties(TestCase): + """Tests that model properties are correct after Hydra instantiation.""" + + def setUp(self) -> 
None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_context_type_property(self): + """Test that context_type is correct.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_loader"] + self.assertEqual(loader.context_type, SimpleContext) + + def test_result_type_property(self): + """Test that result_type is correct.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + loader = r["flow_loader"] + self.assertEqual(loader.result_type, GenericResult[int]) + + def test_deps_method_works(self): + """Test that __deps__ method works after Hydra instantiation.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + transformer = r["flow_transformer"] + + ctx = SimpleContext(value=5) + deps = transformer.__deps__(ctx) + + self.assertEqual(len(deps), 1) + self.assertIsInstance(deps[0][0], CallableModel) + self.assertEqual(deps[0][1], [ctx]) + self.assertIs(deps[0][0], r["flow_source"]) + + +class TestFlowModelHydraDateRangeTransforms(TestCase): + """Tests transforms with DateRangeContext from Hydra config.""" + + def setUp(self) -> None: + ModelRegistry.root().clear() + + def tearDown(self) -> None: + ModelRegistry.root().clear() + + def test_transform_applied_from_yaml(self): + """Test that transform is applied when loaded from YAML.""" + r = ModelRegistry.root() + r.load_config_from_path(CONFIG_PATH) + + processor = r["flow_date_processor"] + + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + deps = processor.__deps__(ctx) + + self.assertEqual(len(deps), 1) + dep_model, dep_contexts = deps[0] + + # The transform should extend start_date back by one day + transformed_ctx = dep_contexts[0] + self.assertEqual(transformed_ctx.start_date, date(2024, 1, 9)) + self.assertEqual(transformed_ctx.end_date, date(2024, 1, 31)) + + self.assertIs(dep_model, r["flow_date_loader"]) + + +if __name__ == "__main__": + import 
unittest + + unittest.main() diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 616e3d8..a89d8f8 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -22,6 +22,121 @@ The naming was inspired by the open source library [Pydantic](https://docs.pydan `CallableModel`'s are called with a context (something that derives from `ContextBase`) and returns a result (something that derives from `ResultBase`). As an example, you may have a `SQLReader` callable model that when called with a `DateRangeContext` returns a `ArrowResult` (wrapper around a Arrow table) with data in the date range defined by the context by querying some SQL database. +### Flow.model Decorator + +The `@Flow.model` decorator provides a simpler way to define `CallableModel`s using plain Python functions instead of classes. It automatically generates a `CallableModel` class with proper `__call__` and `__deps__` methods. + +**Basic Example:** + +```python +from datetime import date +from ccflow import Flow, GenericResult, DateRangeContext + +@Flow.model +def load_data(context: DateRangeContext, source: str) -> GenericResult[dict]: + # Your data loading logic here + return GenericResult(value=query_db(source, context.start_date, context.end_date)) + +# Create model instance +loader = load_data(source="my_database") + +# Execute with context +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = loader(ctx) +``` + +**Composing Dependencies with `Dep` and `DepOf`:** + +Use `Dep()` or `DepOf` to mark parameters that accept other `CallableModel`s as dependencies. The framework automatically resolves the dependency graph. + +> **Tip:** If your function doesn't use the context directly (only passes it to dependencies), use `_` as the parameter name to signal this: `def my_func(_: DateRangeContext, data: DepOf[..., ResultType])`. This is a Python convention for intentionally unused parameters. 
+ +```python +from datetime import date, timedelta +from typing import Annotated +from ccflow import Flow, GenericResult, DateRangeContext, Dep, DepOf + +@Flow.model +def load_data(context: DateRangeContext, source: str) -> GenericResult[dict]: + return GenericResult(value={"records": [1, 2, 3]}) + +@Flow.model +def transform_data( + _: DateRangeContext, # Context passed to dependency, not used directly + raw_data: Annotated[GenericResult[dict], Dep( + # Transform context to fetch one extra day for lookback + transform=lambda ctx: ctx.model_copy(update={ + "start_date": ctx.start_date - timedelta(days=1) + }) + )] +) -> GenericResult[dict]: + # raw_data.value contains the resolved result from load_data + return GenericResult(value={"transformed": raw_data.value["records"]}) + +# Or use DepOf shorthand (no transform needed): +@Flow.model +def aggregate_data( + _: DateRangeContext, # Context passed to dependency, not used directly + transformed: DepOf[..., GenericResult[dict]] # Shorthand for Annotated[T, Dep()] +) -> GenericResult[dict]: + return GenericResult(value={"count": len(transformed.value["transformed"])}) + +# Build the pipeline +data = load_data(source="my_database") +transformed = transform_data(raw_data=data) +aggregated = aggregate_data(transformed=transformed) + +# Execute - dependencies are automatically resolved +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = aggregated(ctx) +``` + +**Hydra/YAML Configuration:** + +`Flow.model` decorated functions work seamlessly with Hydra configuration and the `ModelRegistry`: + +```yaml +# config.yaml +data: + _target_: mymodule.load_data + source: my_database + +transformed: + _target_: mymodule.transform_data + raw_data: data # Reference by registry name (same instance is shared) + +aggregated: + _target_: mymodule.aggregate_data + transformed: transformed # Reference by registry name +``` + +When loaded via `ModelRegistry.load_config()`, references by name ensure the 
same object instance is shared across all consumers. + +**Auto-Unpacked Context with `context_args`:** + +Instead of taking an explicit `context` parameter, you can use `context_args` to automatically unpack context fields as function parameters. This is useful when you want cleaner function signatures: + +```python +from datetime import date +from ccflow import Flow, GenericResult, DateRangeContext + +# Instead of: def load_data(context: DateRangeContext, source: str) +# Use context_args to unpack the context fields directly: +@Flow.model(context_args=["start_date", "end_date"]) +def load_data(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date} to {end_date}") + +# The decorator infers DateRangeContext from the parameter types +loader = load_data(source="my_database") +assert loader.context_type == DateRangeContext + +# Execute with context as usual +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = loader(ctx) # "my_database:2024-01-01 to 2024-01-31" +``` + +The `context_args` parameter specifies which function parameters should be extracted from the context. The framework automatically determines the context type based on the parameter type annotations. + ## Model Registry A `ModelRegistry` is a named collection of models. diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py new file mode 100644 index 0000000..c3d12d1 --- /dev/null +++ b/examples/flow_model_example.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python +"""Example demonstrating Flow.model decorator and class-based CallableModel. 
+ +This example shows: +- Flow.model for simple functions with minimal boilerplate +- Context transforms with Dep annotations +- Class-based CallableModel for complex cases needing instance field access +""" + +from datetime import date, timedelta +from typing import Annotated + +from ccflow import CallableModel, DateRangeContext, Dep, DepOf, Flow, GenericResult + + +# ============================================================================= +# Example 1: Basic Flow.model - No more boilerplate classes! +# ============================================================================= + +@Flow.model +def load_records(context: DateRangeContext, source: str, limit: int = 100) -> GenericResult[list]: + """Load records from a data source for the given date range.""" + print(f" Loading from '{source}' for {context.start_date} to {context.end_date} (limit={limit})") + return GenericResult(value=[ + {"id": i, "date": str(context.start_date), "value": i * 10} + for i in range(min(limit, 5)) + ]) + + +# ============================================================================= +# Example 2: Dependencies with DepOf - Automatic dependency resolution +# ============================================================================= + +@Flow.model +def compute_totals( + _: DateRangeContext, # Context passed to dependency, not used directly here + records: DepOf[..., GenericResult[list]], +) -> GenericResult[dict]: + """Compute totals from loaded records.""" + total = sum(r["value"] for r in records.value) + count = len(records.value) + print(f" Computing totals: {count} records, total={total}") + return GenericResult(value={"total": total, "count": count}) + + +# ============================================================================= +# Example 3: Simple Transform with Flow.model +# When the transform is a fixed function, Flow.model works great +# ============================================================================= + +def lookback_7_days(ctx: DateRangeContext) 
-> DateRangeContext: + """Fixed transform that extends the date range back by 7 days.""" + return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=7)}) + + +@Flow.model +def compute_weekly_average( + _: DateRangeContext, + records: Annotated[GenericResult[list], Dep(transform=lookback_7_days)], +) -> GenericResult[float]: + """Compute average using fixed 7-day lookback.""" + values = [r["value"] for r in records.value] + avg = sum(values) / len(values) if values else 0 + print(f" Computing weekly average: {avg:.2f} (from {len(values)} records)") + return GenericResult(value=avg) + + +# ============================================================================= +# Example 4: Class-based CallableModel with Configurable Transform +# When the transform needs access to instance fields (like window size), +# use a class-based approach with auto-resolution +# ============================================================================= + +class ComputeMovingAverage(CallableModel): + """Compute moving average with configurable lookback window. + + This demonstrates: + - Field uses DepOf annotation: accepts either result or CallableModel + - Instance field (window) accessible in __deps__ for custom transforms + - Auto-resolution: self.records returns resolved value during __call__ + """ + + records: DepOf[..., GenericResult[list]] + window: int = 7 # Configurable lookback window + + @Flow.call + def __call__(self, context: DateRangeContext) -> GenericResult[float]: + """Compute the moving average - self.records is already resolved.""" + values = [r["value"] for r in self.records.value] + avg = sum(values) / len(values) if values else 0 + print(f" Computing {self.window}-day moving average: {avg:.2f} (from {len(values)} records)") + return GenericResult(value=avg) + + @Flow.deps + def __deps__(self, context: DateRangeContext): + """Define dependencies with transform that uses self.window.""" + # This is where we can access instance fields! 
+ lookback_ctx = context.model_copy( + update={"start_date": context.start_date - timedelta(days=self.window)} + ) + return [(self.records, [lookback_ctx])] + + +# ============================================================================= +# Example 5: Multi-stage pipeline - Composing models together +# ============================================================================= + +@Flow.model +def generate_report( + context: DateRangeContext, + totals: DepOf[..., GenericResult[dict]], + moving_avg: DepOf[..., GenericResult[float]], + report_name: str = "Daily Report", +) -> GenericResult[str]: + """Generate a report combining multiple data sources.""" + report = f""" +{report_name} +{'=' * len(report_name)} +Date Range: {context.start_date} to {context.end_date} +Total Value: {totals.value['total']} +Record Count: {totals.value['count']} +Moving Avg: {moving_avg.value:.2f} +""" + return GenericResult(value=report.strip()) + + +# ============================================================================= +# Example 6: Using context_args for cleaner signatures +# ============================================================================= + +@Flow.model(context_args=["start_date", "end_date"]) +def fetch_metadata(start_date: date, end_date: date, category: str) -> GenericResult[dict]: + """Fetch metadata - note how start_date/end_date are direct parameters.""" + print(f" Fetching metadata for '{category}' from {start_date} to {end_date}") + return GenericResult(value={ + "category": category, + "days": (end_date - start_date).days, + "generated_at": str(date.today()), + }) + + +# ============================================================================= +# Main: Build and execute the pipeline +# ============================================================================= + +def main(): + print("=" * 60) + print("Flow.model Example - Simplified CallableModel Creation") + print("=" * 60) + + ctx = DateRangeContext( + start_date=date(2024, 1, 15), + 
end_date=date(2024, 1, 31) + ) + + # --- Example 1: Basic model --- + print("\n[1] Basic Flow.model:") + loader = load_records(source="main_db", limit=5) + result = loader(ctx) + print(f" Result: {result.value}") + + # --- Example 2: Simple dependency chain --- + print("\n[2] Dependency chain (loader -> totals):") + loader = load_records(source="main_db") + totals = compute_totals(records=loader) + result = totals(ctx) + print(f" Result: {result.value}") + + # --- Example 3: Fixed transform with Flow.model --- + print("\n[3] Fixed transform (7-day lookback with Flow.model):") + loader = load_records(source="main_db") + weekly_avg = compute_weekly_average(records=loader) + result = weekly_avg(ctx) + print(f" Result: {result.value}") + + # --- Example 4: Configurable transform with class-based model --- + print("\n[4] Configurable transform (class-based with auto-resolution):") + loader = load_records(source="main_db") + + # 14-day window + moving_avg_14 = ComputeMovingAverage(records=loader, window=14) + result = moving_avg_14(ctx) + print(f" 14-day result: {result.value}") + + # 30-day window - same loader, different window + moving_avg_30 = ComputeMovingAverage(records=loader, window=30) + result = moving_avg_30(ctx) + print(f" 30-day result: {result.value}") + + # --- Example 5: Full pipeline --- + print("\n[5] Full pipeline (mixing Flow.model and class-based):") + loader = load_records(source="analytics_db") + totals = compute_totals(records=loader) + moving_avg = ComputeMovingAverage(records=loader, window=7) + report = generate_report( + totals=totals, + moving_avg=moving_avg, + report_name="Analytics Summary" + ) + result = report(ctx) + print(result.value) + + # --- Example 6: context_args --- + print("\n[6] Using context_args (auto-unpacked context):") + metadata = fetch_metadata(category="sales") + result = metadata(ctx) + print(f" Result: {result.value}") + + # --- Bonus: Inspecting models --- + print("\n[Bonus] Inspecting models:") + print(f" 
load_records.context_type = {loader.context_type.__name__}") + print(f" ComputeMovingAverage uses __deps__ for custom transforms") + deps = moving_avg.__deps__(ctx) + for dep_model, dep_contexts in deps: + print(f" - Dependency context start: {dep_contexts[0].start_date} (lookback applied)") + + +if __name__ == "__main__": + main() From 970b4ab24f0f1d2c86ad1a0d3aa652bece4c9187 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Sun, 4 Jan 2026 21:09:54 -0500 Subject: [PATCH 04/26] Inside CallableModel, force calling resolve on DepOf to not do hacky switching out attributes at runtime Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 72 +++++++++++++++++++++++++----- ccflow/dep.py | 2 +- ccflow/flow_model.py | 12 ++++- ccflow/tests/test_flow_model.py | 78 ++++++++++++++++++++++++++++----- examples/flow_model_example.py | 8 ++-- 5 files changed, 144 insertions(+), 28 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index 5296bfe..8f6adfe 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -14,6 +14,7 @@ import abc import inspect import logging +from contextvars import ContextVar from functools import lru_cache, wraps from inspect import Signature, isclass, signature from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin @@ -47,6 +48,8 @@ "EvaluatorBase", "Evaluator", "WrapperModel", + # Note: resolve() is intentionally not in __all__ to avoid namespace pollution. 
+ # Users who need it can import explicitly: from ccflow.callable import resolve ) log = logging.getLogger(__name__) @@ -239,10 +242,60 @@ def _wrap_with_dep_resolution(fn): return fn +# Context variable for storing resolved dependency values during __call__ +# Maps id(callable_model) -> resolved_value +_resolved_deps: ContextVar[Dict[int, Any]] = ContextVar("resolved_deps", default={}) + +# TypeVar for resolve() function to enable proper type inference +_T = TypeVar("_T") + + +def resolve(dep: Union[_T, "_CallableModel"]) -> _T: + """Access the resolved value of a DepOf dependency during __call__. + + This function is used inside a CallableModel's __call__ method to get + the resolved value of a dependency field. It provides proper type inference - + if the field is `DepOf[..., GenericResult[int]]`, this returns `GenericResult[int]`. + + Args: + dep: The dependency field value (either a CallableModel or already-resolved value) + + Returns: + The resolved value. If dep is already a resolved value (not a CallableModel), + returns it unchanged. + + Raises: + RuntimeError: If called outside of __call__ or if the dependency wasn't resolved. + + Example: + class MyModel(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: MyContext) -> GenericResult[int]: + # resolve() provides proper type inference + data = resolve(self.data) # type: GenericResult[int] + return GenericResult(value=data.value + 1) + """ + # If it's not a CallableModel, it's already a resolved value - pass through + if not isinstance(dep, _CallableModel): + return dep # type: ignore[return-value] + + # Look up in context var + store = _resolved_deps.get() + dep_id = id(dep) + if dep_id not in store: + raise RuntimeError( + "resolve() can only be used inside __call__ for DepOf fields. Make sure the field is annotated with DepOf and contains a CallableModel." 
+ ) + return store[dep_id] + + def _resolve_deps_and_call(model, context, fn): """Resolve DepOf fields and call the function. This is called from ModelEvaluationContext.__call__ to handle dep resolution. + Resolved values are stored in a context variable and accessed via resolve(). Args: model: The CallableModel instance @@ -269,8 +322,8 @@ def _resolve_deps_and_call(model, context, fn): for dep_model, contexts in deps_result: dep_map[id(dep_model)] = (dep_model, contexts) - # Store original values and resolve - originals = {} + # Resolve dependencies and store in context var + resolved_values = {} for field_name, dep in dep_fields.items(): field_value = getattr(model, field_name, None) if field_value is None: @@ -280,8 +333,6 @@ def _resolve_deps_and_call(model, context, fn): if not isinstance(field_value, _CallableModel): continue # Already a resolved value, skip - originals[field_name] = field_value - # Check if this field is in __deps__ (for custom transforms) if id(field_value) in dep_map: dep_model, contexts = dep_map[id(field_value)] @@ -292,16 +343,17 @@ def _resolve_deps_and_call(model, context, fn): transformed_ctx = dep.apply(context) resolved = field_value(transformed_ctx) - # Temporarily set resolved value on model - object.__setattr__(model, field_name, resolved) + # Store resolved value keyed by the CallableModel's id + resolved_values[id(field_value)] = resolved + # Store in context var and call function + current_store = _resolved_deps.get() + new_store = {**current_store, **resolved_values} + token = _resolved_deps.set(new_store) try: - # Call original function return fn(model, context) finally: - # Restore original CallableModel values - for field_name, original_value in originals.items(): - object.__setattr__(model, field_name, original_value) + _resolved_deps.reset(token) class FlowOptions(BaseModel): diff --git a/ccflow/dep.py b/ccflow/dep.py index a7e0121..b57261e 100644 --- a/ccflow/dep.py +++ b/ccflow/dep.py @@ -154,7 +154,7 @@ class Dep: 
def __init__( self, - transform: Optional[Callable[[ContextBase], ContextBase]] = None, + transform: Optional[Callable[..., ContextBase]] = None, context_type: Optional[Type[ContextBase]] = None, ): """ diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 3a96886..25db69c 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -190,6 +190,9 @@ def decorator(fn: Callable) -> Callable: # Create the __call__ method def make_call_impl(): + # Import resolve here to avoid circular import at module level + from .callable import resolve + def __call__(self, context): # Build kwargs for the original function if use_context_args: @@ -199,9 +202,14 @@ def __call__(self, context): # Pass context directly (using actual parameter name: 'context' or '_') fn_kwargs = {ctx_param_name: context} - # Add model fields (deps are resolved by _resolve_deps_and_call in callable.py) + # Add model fields - use resolve() for dep fields to get resolved values for name in model_fields: - fn_kwargs[name] = getattr(self, name) + value = getattr(self, name) + if name in dep_fields: + # Use resolve() to get the resolved value from context var + fn_kwargs[name] = resolve(value) + else: + fn_kwargs[name] = value return fn(**fn_kwargs) diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 75e0899..fe7e32e 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -18,6 +18,7 @@ ModelRegistry, ResultBase, ) +from ccflow.callable import resolve class SimpleContext(ContextBase): @@ -1153,7 +1154,7 @@ class TestClassBasedDepResolution(TestCase): """ def test_class_based_auto_resolve_basic(self): - """Test that DepOf fields are auto-resolved and accessible via self.""" + """Test that DepOf fields are auto-resolved and accessible via resolve().""" @Flow.model def data_source(context: SimpleContext) -> GenericResult[int]: @@ -1165,8 +1166,8 @@ class Consumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> 
GenericResult[int]: - # Access resolved value via self.source - return GenericResult(value=self.source.value + 1) + # Access resolved value via resolve() + return GenericResult(value=resolve(self.source).value + 1) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1192,7 +1193,7 @@ class Consumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=self.source.value + self.offset) + return GenericResult(value=resolve(self.source).value + self.offset) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1224,7 +1225,7 @@ class Consumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=self.source.value + 1) + return GenericResult(value=resolve(self.source).value + 1) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1256,7 +1257,7 @@ class Aggregator(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=self.a.value + self.b.value) + return GenericResult(value=resolve(self.a).value + resolve(self.b).value) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1285,7 +1286,7 @@ class Consumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=self.source.value * 2) + return GenericResult(value=resolve(self.source).value * 2) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1310,7 +1311,8 @@ class Consumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=self.source.value + context.value) + # resolve() passes through non-CallableModel values unchanged + return GenericResult(value=resolve(self.source).value + context.value) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1341,7 +1343,7 @@ class Consumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) 
-> GenericResult[int]: - return GenericResult(value=self.data.value + 1) + return GenericResult(value=resolve(self.data).value + 1) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1381,7 +1383,7 @@ class LayerB(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: call_counts["layer_b"] += 1 - return GenericResult(value=self.source.value * 10) + return GenericResult(value=resolve(self.source).value * 10) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1394,7 +1396,7 @@ class LayerA(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: call_counts["layer_a"] += 1 - return GenericResult(value=self.source.value + 1) + return GenericResult(value=resolve(self.source).value + 1) @Flow.deps def __deps__(self, context: SimpleContext): @@ -1416,6 +1418,58 @@ def __deps__(self, context: SimpleContext): self.assertEqual(call_counts["layer_b"], 1, "layer_b should be called exactly once") self.assertEqual(call_counts["layer_a"], 1, "layer_a should be called exactly once") + def test_resolve_direct_value_passthrough(self): + """Test that resolve() passes through non-CallableModel values unchanged.""" + + class Consumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + # resolve() should return the GenericResult directly (pass-through) + resolved = resolve(self.data) + # Verify it's the actual GenericResult, not a CallableModel + assert isinstance(resolved, GenericResult) + return GenericResult(value=resolved.value * 2) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [] + + # Pass a direct value, not a CallableModel + direct_result = GenericResult(value=42) + consumer = Consumer(data=direct_result) + + result = consumer(SimpleContext(value=5)) + self.assertEqual(result.value, 84) # 42 * 2 + + def test_resolve_outside_call_raises_error(self): + """Test that resolve() 
raises RuntimeError when called outside __call__.""" + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + class Consumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=resolve(self.data).value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + return [(self.data, [context])] + + src = source() + consumer = Consumer(data=src) + + # Calling resolve() outside of __call__ should raise RuntimeError + with self.assertRaises(RuntimeError) as cm: + resolve(consumer.data) + + self.assertIn("resolve() can only be used inside __call__", str(cm.exception)) + def test_flow_model_uses_unified_resolution_path(self): """Test that @Flow.model uses the same resolution path as class-based CallableModel. @@ -1445,7 +1499,7 @@ class ClassConsumer(CallableModel): @Flow.call def __call__(self, context: SimpleContext) -> GenericResult[int]: call_counts["class_model"] += 1 - return GenericResult(value=self.data.value + 100) + return GenericResult(value=resolve(self.data).value + 100) @Flow.deps def __deps__(self, context: SimpleContext): diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py index c3d12d1..e93d452 100644 --- a/examples/flow_model_example.py +++ b/examples/flow_model_example.py @@ -11,6 +11,7 @@ from typing import Annotated from ccflow import CallableModel, DateRangeContext, Dep, DepOf, Flow, GenericResult +from ccflow.callable import resolve # ============================================================================= @@ -77,7 +78,7 @@ class ComputeMovingAverage(CallableModel): This demonstrates: - Field uses DepOf annotation: accepts either result or CallableModel - Instance field (window) accessible in __deps__ for custom transforms - - Auto-resolution: self.records returns resolved value during __call__ + - resolve() to access resolved dependency 
values during __call__ """ records: DepOf[..., GenericResult[list]] @@ -85,8 +86,9 @@ class ComputeMovingAverage(CallableModel): @Flow.call def __call__(self, context: DateRangeContext) -> GenericResult[float]: - """Compute the moving average - self.records is already resolved.""" - values = [r["value"] for r in self.records.value] + """Compute the moving average - use resolve() to get resolved value.""" + records = resolve(self.records) # Get the resolved GenericResult + values = [r["value"] for r in records.value] avg = sum(values) / len(values) if values else 0 print(f" Computing {self.window}-day moving average: {avg:.2f} (from {len(values)} records)") return GenericResult(value=avg) From 2696fcea9993be36470b026e6d70935dc1b8f550 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Sun, 4 Jan 2026 22:00:04 -0500 Subject: [PATCH 05/26] High level design doc Signed-off-by: Nijat Khanbabayev --- docs/design/flow_model_design.md | 440 +++++++++++++++++++++++++++++++ 1 file changed, 440 insertions(+) create mode 100644 docs/design/flow_model_design.md diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md new file mode 100644 index 0000000..76d0eb7 --- /dev/null +++ b/docs/design/flow_model_design.md @@ -0,0 +1,440 @@ +# Flow.model and DepOf: Dependency Injection for CallableModel + +## Overview + +This document describes the `@Flow.model` decorator and `DepOf` annotation system for reducing boilerplate when creating `CallableModel` pipelines with dependencies. 
+ +**Key features:** +- `@Flow.model` - Decorator that generates `CallableModel` classes from plain functions +- `DepOf[ContextType, ResultType]` - Type annotation for dependency fields +- `resolve()` - Function to access resolved dependency values in class-based models + +## Quick Start + +### Pattern 1: `@Flow.model` (Recommended for Simple Cases) + +```python +from datetime import date, timedelta +from typing import Annotated + +from ccflow import Flow, DateRangeContext, GenericResult, DepOf + +@Flow.model +def load_records(context: DateRangeContext, source: str) -> GenericResult[dict]: + return GenericResult(value={"count": 100, "date": str(context.start_date)}) + +@Flow.model +def compute_stats( + context: DateRangeContext, + records: DepOf[..., GenericResult[dict]], # Dependency field +) -> GenericResult[float]: + # records is already resolved - just use it directly + return GenericResult(value=records.value["count"] * 0.05) + +# Build pipeline +loader = load_records(source="main_db") +stats = compute_stats(records=loader) + +# Execute +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = stats(ctx) +``` + +### Pattern 2: Class-Based (For Complex Cases) + +Use class-based when you need **configurable transforms** that depend on instance fields: + +```python +from datetime import timedelta + +from ccflow import CallableModel, DateRangeContext, Flow, GenericResult, DepOf +from ccflow.callable import resolve # Import resolve for class-based models + +class AggregateWithWindow(CallableModel): + """Aggregate records with configurable lookback window.""" + + records: DepOf[..., GenericResult[dict]] + window: int = 7 # Configurable instance field + + @Flow.call + def __call__(self, context: DateRangeContext) -> GenericResult[float]: + # Use resolve() to get the resolved value + records = resolve(self.records) + return GenericResult(value=records.value["count"] / self.window) + + @Flow.deps + def __deps__(self, context: 
DateRangeContext): + # Transform uses self.window - this is why we need class-based! + lookback_ctx = context.model_copy( + update={"start_date": context.start_date - timedelta(days=self.window)} + ) + return [(self.records, [lookback_ctx])] + +# Usage - different window sizes, same source +loader = load_records(source="main_db") +agg_7 = AggregateWithWindow(records=loader, window=7) +agg_30 = AggregateWithWindow(records=loader, window=30) +``` + +## When to Use Which Pattern + +| Use `@Flow.model` when... | Use Class-Based when... | +|--------------------------------|--------------------------------------| +| Simple transformations | Transforms depend on instance fields | +| Fixed context transforms | Need `self.field` in `__deps__` | +| Less boilerplate is priority | Full control over resolution | +| No custom `__deps__` logic | Complex dependency patterns | + +## Core Concepts + +### `DepOf[ContextType, ResultType]` + +Shorthand for declaring dependency fields that can accept either: +- A pre-computed value of `ResultType` +- A `CallableModel` that produces `ResultType` + +```python +# Inherit context type from parent model +data: DepOf[..., GenericResult[dict]] + +# Explicit context type +data: DepOf[DateRangeContext, GenericResult[dict]] + +# Equivalent to: +data: Annotated[Union[GenericResult[dict], CallableModel], Dep()] +``` + +### `Dep(transform=..., context_type=...)` + +For transforms, use the full `Annotated` form: + +```python +from ccflow import Dep + +@Flow.model +def compute_stats( + context: DateRangeContext, + records: Annotated[GenericResult[dict], Dep( + transform=lambda ctx: ctx.model_copy( + update={"start_date": ctx.start_date - timedelta(days=1)} + ) + )], +) -> GenericResult[float]: + return GenericResult(value=records.value["count"] * 0.05) +``` + +### `resolve()` Function + +**Only needed for class-based models.** Accesses the resolved value of a `DepOf` field during `__call__`. 
+ +```python +from ccflow.callable import resolve + +class MyModel(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context: MyContext) -> GenericResult[int]: + # resolve() returns the GenericResult, not the CallableModel + result = resolve(self.data) + return GenericResult(value=result.value + 1) +``` + +**Behavior:** +- Inside `__call__`: Returns the resolved value +- With direct values (not CallableModel): Returns unchanged (no-op) +- Outside `__call__`: Raises `RuntimeError` +- In `@Flow.model`: Not needed - values are passed as function arguments + +**Type inference:** +```python +data: DepOf[..., GenericResult[int]] +resolved = resolve(self.data) # Type: GenericResult[int] +``` + +## How Resolution Works + +### `@Flow.model` Resolution Flow + +1. User calls `model(context)` +2. Generated `__call__` invokes `_resolve_deps_and_call()` +3. For each `DepOf` field containing a `CallableModel`: + - Apply transform (if any) + - Call the dependency + - Store resolved value in context variable +4. Generated `__call__` retrieves resolved values via `resolve()` +5. Original function receives resolved values as arguments + +### Class-Based Resolution Flow + +1. User calls `model(context)` +2. `_resolve_deps_and_call()` runs +3. For each `DepOf` field containing a `CallableModel`: + - Check `__deps__` for custom transforms + - Call the dependency + - Store resolved value in context variable +4. User's `__call__` accesses values via `resolve(self.field)` + +**Important:** Resolution uses a context variable (`contextvars.ContextVar`), making it thread-safe and async-safe. + +## Design Decisions + +### Decision 1: `resolve()` Instead of Temporary Mutation + +**What we chose:** Explicit `resolve()` function with context variables. + +**Alternative considered:** Temporarily mutate `self.field` during `__call__` to hold the resolved value, then restore after. 
+ +**Why we chose this:** +- No mutation of model state +- Thread/async-safe via contextvars +- Explicit about what's happening +- Easier to debug - `self.field` always shows the original value + +**Trade-off:** Slightly more verbose (`resolve(self.data).value` vs `self.data.value`). + +### Decision 2: Unified Resolution Path + +**What we chose:** Both `@Flow.model` and class-based use the same `_resolve_deps_and_call()` function. + +**Why:** +- Single source of truth for resolution logic +- Easier to maintain +- Consistent behavior across patterns + +### Decision 3: `resolve()` Not in Top-Level `__all__` + +**What we chose:** `resolve` must be imported explicitly: `from ccflow.callable import resolve` + +**Why:** +- Only needed for class-based models with `DepOf` +- Keeps top-level namespace clean +- Users who need it can find it easily + +### Decision 4: No Auto-Wrapping Return Values + +**What we chose:** Functions must explicitly return `ResultBase` subclass. + +**Why:** +- Type annotations remain honest +- Consistent with existing `CallableModel` contract +- `GenericResult(value=x)` is minimal overhead + +### Decision 5: Generated Classes Are Real CallableModels + +**What we chose:** Generate actual `CallableModel` subclasses using `type()`. + +**Why:** +- Full compatibility with existing infrastructure +- Caching, registry, serialization work unchanged +- Can mix with hand-written classes + +## Pitfalls and Limitations + +### Pitfall 1: Forgetting `resolve()` in Class-Based Models + +```python +class MyModel(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context): + # WRONG - self.data is still the CallableModel! 
+ return GenericResult(value=self.data.value + 1) + + # CORRECT + return GenericResult(value=resolve(self.data).value + 1) +``` + +**Error you'll see:** `AttributeError: '_SomeModel' object has no attribute 'value'` + +### Pitfall 2: Calling `resolve()` Outside `__call__` + +```python +model = MyModel(data=some_source()) +resolve(model.data) # RuntimeError! +``` + +`resolve()` only works during `__call__` execution. + +### Pitfall 3: Lambda Transforms Don't Serialize + +```python +# Won't serialize - lambdas can't be pickled +Dep(transform=lambda ctx: ctx.model_copy(...)) + +# Will serialize - use named functions +def shift_start(ctx): + return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) + +Dep(transform=shift_start) +``` + +### Pitfall 4: GraphEvaluator Requires Caching + +When using `GraphEvaluator` with `DepOf`, dependencies may be called twice (once by GraphEvaluator, once by resolution) unless caching is enabled. + +```python +# Use with caching +from ccflow.evaluators import GraphEvaluator, CachingEvaluator, MultiEvaluator + +evaluator = MultiEvaluator(evaluators=[ + CachingEvaluator(), + GraphEvaluator(), +]) +``` + +### Pitfall 5: Two Mental Models + +Users need to remember: +- `@Flow.model`: Use dependency values directly as function arguments +- Class-based: Use `resolve(self.field)` to access values + +### Limitation: `__deps__` Still Required for Class-Based + +Even without transforms, class-based models need `__deps__`: + +```python +class Consumer(CallableModel): + data: DepOf[..., GenericResult[int]] + + @Flow.call + def __call__(self, context): + return GenericResult(value=resolve(self.data).value) + + @Flow.deps + def __deps__(self, context): + return [(self.data, [context])] # Boilerplate, but required +``` + +## Complete Example: Multi-Stage Pipeline + +```python +from datetime import date, timedelta +from typing import Annotated + +from ccflow import ( + CallableModel, DateRangeContext, Dep, DepOf, + Flow, 
GenericResult +) +from ccflow.callable import resolve + + +# Stage 1: Data loader (simple, use @Flow.model) +@Flow.model +def load_events(context: DateRangeContext, source: str) -> GenericResult[list]: + print(f"Loading from {source} for {context.start_date} to {context.end_date}") + return GenericResult(value=[ + {"date": str(context.start_date), "count": 100 + i} + for i in range(5) + ]) + + +# Stage 2: Transform with fixed lookback (use @Flow.model with Dep transform) +@Flow.model +def compute_daily_totals( + context: DateRangeContext, + events: Annotated[GenericResult[list], Dep( + transform=lambda ctx: ctx.model_copy( + update={"start_date": ctx.start_date - timedelta(days=1)} + ) + )], +) -> GenericResult[float]: + values = [e["count"] for e in events.value] + total = sum(values) / len(values) if values else 0 + return GenericResult(value=total) + + +# Stage 3: Configurable window (use class-based) +class ComputeRollingSummary(CallableModel): + """Summary with configurable lookback window.""" + + totals: DepOf[..., GenericResult[float]] + window: int = 20 + + @Flow.call + def __call__(self, context: DateRangeContext) -> GenericResult[float]: + totals = resolve(self.totals) + # Scale by window size + summary = totals.value * (self.window ** 0.5) + return GenericResult(value=summary) + + @Flow.deps + def __deps__(self, context: DateRangeContext): + lookback = context.model_copy( + update={"start_date": context.start_date - timedelta(days=self.window)} + ) + return [(self.totals, [lookback])] + + +# Build pipeline +events = load_events(source="main_db") +totals = compute_daily_totals(events=events) +summary_20 = ComputeRollingSummary(totals=totals, window=20) +summary_60 = ComputeRollingSummary(totals=totals, window=60) + +# Execute +ctx = DateRangeContext(start_date=date(2024, 1, 15), end_date=date(2024, 1, 31)) +print(f"20-day summary: {summary_20(ctx).value}") +print(f"60-day summary: {summary_60(ctx).value}") +``` + +## API Reference + +### `@Flow.model` + 
+```python +@Flow.model( + context_args: list[str] = None, # Unpack context fields as function args + cacheable: bool = False, + volatile: bool = False, + log_level: int = logging.DEBUG, + validate_result: bool = True, + verbose: bool = True, + evaluator: EvaluatorBase = None, +) +def my_function(context: ContextType, ...) -> ResultType: + ... +``` + +### `DepOf[ContextType, ResultType]` + +```python +# Inherit context from parent +field: DepOf[..., GenericResult[int]] + +# Explicit context type +field: DepOf[DateRangeContext, GenericResult[int]] +``` + +### `Dep(transform=..., context_type=...)` + +```python +field: Annotated[GenericResult[int], Dep( + transform=my_transform_func, # Optional: (context) -> transformed_context + context_type=DateRangeContext, # Optional: Expected context type +)] +``` + +### `resolve(dep)` + +```python +from ccflow.callable import resolve + +# Inside __call__ of class-based CallableModel: +resolved_value = resolve(self.dep_field) + +# Type signature: +def resolve(dep: Union[T, CallableModel]) -> T: ... 
+``` + +## File Structure + +``` +ccflow/ +├── callable.py # CallableModel, Flow, resolve(), _resolve_deps_and_call() +├── dep.py # Dep, DepOf, extract_dep() +├── flow_model.py # @Flow.model implementation +└── tests/ + └── test_flow_model.py # Comprehensive tests +``` From d180f3566d53f7d5989f4bf005ea4ec073c73210 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Tue, 13 Jan 2026 11:03:22 -0500 Subject: [PATCH 06/26] Add extra stuff, need clean-up Signed-off-by: Nijat Khanbabayev --- ccflow/__init__.py | 1 + ccflow/callable.py | 51 +-- ccflow/context.py | 41 ++- ccflow/flow_model.py | 522 ++++++++++++++++++++++++++---- ccflow/tests/test_context.py | 9 +- ccflow/tests/test_flow_context.py | 467 ++++++++++++++++++++++++++ ccflow/tests/test_flow_model.py | 44 ++- 7 files changed, 1042 insertions(+), 93 deletions(-) create mode 100644 ccflow/tests/test_flow_context.py diff --git a/ccflow/__init__.py b/ccflow/__init__.py index 9916168..c8d2259 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -12,6 +12,7 @@ from .context import * from .dep import * from .enums import Enum +from .flow_model import FlowAPI, BoundModel, Lazy from .global_state import * from .local_persistence import * from .models import * diff --git a/ccflow/callable.py b/ccflow/callable.py index 8f6adfe..1aa7189 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -312,7 +312,10 @@ def _resolve_deps_and_call(model, context, fn): # Get Dep-annotated fields for this model class dep_fields = _get_dep_fields(model.__class__) - if not dep_fields: + # Check if model has custom deps (from @func.deps decorator) + has_custom_deps = getattr(model.__class__, "__has_custom_deps__", False) + + if not dep_fields and not has_custom_deps: return fn(model, context) # Get dependencies from __deps__ @@ -324,27 +327,37 @@ def _resolve_deps_and_call(model, context, fn): # Resolve dependencies and store in context var resolved_values = {} - for field_name, dep in dep_fields.items(): - field_value = 
getattr(model, field_name, None) - if field_value is None: - continue - - # Check if field is a CallableModel that needs resolution - if not isinstance(field_value, _CallableModel): - continue # Already a resolved value, skip - # Check if this field is in __deps__ (for custom transforms) - if id(field_value) in dep_map: - dep_model, contexts = dep_map[id(field_value)] - # Call dependency with the (transformed) context + # If custom deps, resolve ALL CallableModel fields from dep_map + if has_custom_deps: + for dep_model, contexts in deps_result: resolved = dep_model(contexts[0]) if contexts else dep_model(context) - else: - # Not in __deps__, use Dep annotation transform directly - transformed_ctx = dep.apply(context) - resolved = field_value(transformed_ctx) + # Unwrap GenericResult if present (consistent with auto-detected deps) + if hasattr(resolved, 'value'): + resolved = resolved.value + resolved_values[id(dep_model)] = resolved + else: + # Standard path: iterate over Dep-annotated fields + for field_name, dep in dep_fields.items(): + field_value = getattr(model, field_name, None) + if field_value is None: + continue + + # Check if field is a CallableModel that needs resolution + if not isinstance(field_value, _CallableModel): + continue # Already a resolved value, skip + + # Check if this field is in __deps__ (for custom transforms) + if id(field_value) in dep_map: + dep_model, contexts = dep_map[id(field_value)] + # Call dependency with the (transformed) context + resolved = dep_model(contexts[0]) if contexts else dep_model(context) + else: + # Not in __deps__, use Dep annotation transform directly + transformed_ctx = dep.apply(context) + resolved = field_value(transformed_ctx) - # Store resolved value keyed by the CallableModel's id - resolved_values[id(field_value)] = resolved + resolved_values[id(field_value)] = resolved # Store in context var and call function current_store = _resolved_deps.get() diff --git a/ccflow/context.py b/ccflow/context.py index 
cf17d24..62ce0f7 100644 --- a/ccflow/context.py +++ b/ccflow/context.py @@ -2,10 +2,10 @@ import warnings from datetime import date, datetime -from typing import Generic, Hashable, Optional, Sequence, Set, TypeVar +from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar from deprecated import deprecated -from pydantic import field_validator, model_validator +from pydantic import ConfigDict, field_validator, model_validator from .base import ContextBase from .exttypes import Frequency @@ -15,6 +15,7 @@ __all__ = ( + "FlowContext", "NullContext", "GenericContext", "DateContext", @@ -93,6 +94,42 @@ # Starting 0.8.0 Nullcontext is an alias to ContextBase NullContext = ContextBase + +class FlowContext(ContextBase): + """Universal context for @Flow.model functions. + + Instead of generating a new ContextBase subclass for each @Flow.model, + this single class with extra="allow" serves as the universal carrier. + Validation happens via TypedDict + TypeAdapter at compute() time. + + This design avoids: + - Proliferation of dynamic _funcname_Context classes + - Class registration overhead for serialization + - Pickling issues with Ray/distributed computing + + Fields are stored in __pydantic_extra__ and accessed via __getattr__. 
+ """ + + model_config = ConfigDict(extra="allow", frozen=True) + + def __getattr__(self, name: str) -> Any: + """Access fields stored in __pydantic_extra__.""" + # Use object.__getattribute__ to avoid infinite recursion + try: + extra = object.__getattribute__(self, "__pydantic_extra__") + if extra is not None and name in extra: + return extra[name] + except AttributeError: + pass + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __repr__(self) -> str: + """Show all fields including extra fields.""" + extra = object.__getattribute__(self, "__pydantic_extra__") or {} + fields = ", ".join(f"{k}={v!r}" for k, v in extra.items()) + return f"FlowContext({fields})" + + C = TypeVar("C", bound=Hashable) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 25db69c..2d3ab3a 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -3,6 +3,10 @@ This module provides the Flow.model decorator that generates CallableModel classes from plain Python functions, reducing boilerplate while maintaining full compatibility with existing ccflow infrastructure. + +Key design: Uses TypedDict + TypeAdapter for context schema validation instead of +generating dynamic ContextBase subclasses. This avoids class registration overhead +and enables clean pickling for distributed computing (e.g., Ray). 
""" import inspect @@ -10,21 +14,219 @@ from functools import wraps from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Type, Union, get_origin -from pydantic import Field +from pydantic import Field, TypeAdapter +from typing_extensions import TypedDict from .base import ContextBase, ResultBase +from .context import FlowContext from .dep import Dep, extract_dep -from .local_persistence import register_ccflow_import_path -__all__ = ("flow_model",) +__all__ = ("flow_model", "FlowAPI", "BoundModel", "Lazy") log = logging.getLogger(__name__) -def _infer_context_type_from_args(context_args: List[str], func: Callable, sig: inspect.Signature) -> Type[ContextBase]: - """Infer or create a context type from context_args parameter names. +class FlowAPI: + """API namespace for deferred computation operations. + + Provides methods for executing models and transforming contexts. + Accessed via model.flow property. + """ + + def __init__(self, model: "CallableModel"): # noqa: F821 + self._model = model + + def compute(self, **kwargs) -> Any: + """Execute the model with the provided context arguments. + + Validates kwargs against the model's context schema using TypeAdapter, + then wraps in FlowContext and calls the model. + + Args: + **kwargs: Context arguments (e.g., start_date, end_date) + + Returns: + The model's result, unwrapped from GenericResult if applicable. + """ + # Get validator from model (lazily created if needed after unpickling) + validator = self._model._get_context_validator() + + # Validate and coerce kwargs via TypeAdapter + validated = validator.validate_python(kwargs) + + # Wrap in FlowContext (single class, always) + ctx = FlowContext(**validated) + + # Call the model + result = self._model(ctx) + + # Unwrap GenericResult if present + if hasattr(result, "value"): + return result.value + return result + + @property + def unbound_inputs(self) -> Dict[str, Type]: + """Return the context schema (field name -> type). 
+ + In deferred mode, this is everything NOT provided at construction. + """ + all_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) + bound_fields = getattr(self._model, "_bound_fields", set()) + + # If explicit context_args was provided, use _context_schema + explicit_args = getattr(self._model.__class__, "__flow_model_explicit_context_args__", None) + if explicit_args is not None: + return self._model._context_schema.copy() + + # Otherwise, unbound = all params - bound + return {name: typ for name, typ in all_param_types.items() if name not in bound_fields} + + @property + def bound_inputs(self) -> Dict[str, Any]: + """Return the config values bound at construction time.""" + bound_fields = getattr(self._model, "_bound_fields", set()) + result = {} + for name in bound_fields: + if hasattr(self._model, name): + result[name] = getattr(self._model, name) + return result + + def with_inputs(self, **transforms) -> "BoundModel": + """Create a version of this model with transformed context inputs. + + Args: + **transforms: Mapping of field name to either: + - A callable (ctx) -> value for dynamic transforms + - A static value to bind + + Returns: + A BoundModel that applies the transforms before calling. + """ + return BoundModel(model=self._model, input_transforms=transforms) + + +class BoundModel: + """A model with context transforms applied. + + Created by model.flow.with_inputs(). Applies transforms to context + before delegating to the underlying model. 
+ """ + + def __init__(self, model: "CallableModel", input_transforms: Dict[str, Any]): # noqa: F821 + self._model = model + self._input_transforms = input_transforms + + def __call__(self, context: ContextBase) -> Any: + """Call the model with transformed context.""" + # Build new context dict with transforms applied + ctx_dict = {} + + # Get fields from context + if hasattr(context, "__pydantic_extra__") and context.__pydantic_extra__: + ctx_dict.update(context.__pydantic_extra__) + for field in context.__class__.model_fields: + ctx_dict[field] = getattr(context, field) + + # Apply transforms + for name, transform in self._input_transforms.items(): + if callable(transform): + ctx_dict[name] = transform(context) + else: + ctx_dict[name] = transform + + # Create new context and call model + new_ctx = FlowContext(**ctx_dict) + return self._model(new_ctx) + + @property + def flow(self) -> FlowAPI: + """Access the flow API.""" + return FlowAPI(self._model) + + +class Lazy: + """Deferred model execution with runtime context overrides. + + Wraps a CallableModel to allow context fields to be determined at + runtime rather than at construction time. Use in with_inputs() when + you need values that aren't available until execution. + + Example: + # Create a model that needs runtime-determined context + market_data = load_market_data(symbols=["AAPL"]) + + # Use Lazy to defer the start_date calculation + lookback_data = market_data.flow.with_inputs( + start_date=Lazy(market_data)(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) + ) + + # More commonly, use Lazy for self-referential transforms: + adjusted_model = model.flow.with_inputs( + value=Lazy(other_model)(multiplier=2) # Call other_model with multiplier=2 + ) + + The __call__ method returns a callable that, when invoked with a context, + calls the wrapped model with the specified overrides applied. + """ + + def __init__(self, model: "CallableModel"): # noqa: F821 + """Wrap a model for deferred execution. 
+ + Args: + model: The CallableModel to wrap + """ + self._model = model + + def __call__(self, **overrides) -> Callable[[ContextBase], Any]: + """Create a callable that applies overrides to context before execution. + + Args: + **overrides: Context field overrides. Values can be: + - Static values (applied directly) + - Callables (ctx) -> value (called with context at runtime) + + Returns: + A callable (context) -> result that applies overrides and calls the model + """ + model = self._model + + def execute_with_overrides(context: ContextBase) -> Any: + # Build context dict from incoming context + ctx_dict = {} + if hasattr(context, "__pydantic_extra__") and context.__pydantic_extra__: + ctx_dict.update(context.__pydantic_extra__) + for field in context.__class__.model_fields: + ctx_dict[field] = getattr(context, field) + + # Apply overrides + for name, value in overrides.items(): + if callable(value): + ctx_dict[name] = value(context) + else: + ctx_dict[name] = value + + # Call model with modified context + new_ctx = FlowContext(**ctx_dict) + return model(new_ctx) + + return execute_with_overrides + + @property + def model(self) -> "CallableModel": # noqa: F821 + """Access the wrapped model.""" + return self._model + + +def _build_context_schema( + context_args: List[str], func: Callable, sig: inspect.Signature +) -> Tuple[Dict[str, Type], Type, Optional[Type[ContextBase]]]: + """Build context schema from context_args parameter names. - This attempts to match existing context types or creates a new one. 
+ Instead of creating a dynamic ContextBase subclass, this builds: + - A schema dict mapping field names to types + - A TypedDict for Pydantic TypeAdapter validation + - Optionally, a matched existing ContextBase type for compatibility Args: context_args: List of parameter names that come from context @@ -32,25 +234,22 @@ def _infer_context_type_from_args(context_args: List[str], func: Callable, sig: sig: The function signature Returns: - A ContextBase subclass + Tuple of (schema_dict, TypedDict type, optional matched ContextBase type) """ - from .local_persistence import create_ccflow_model - - # Build field definitions for the context from parameter annotations - fields = {} + # Build schema dict from parameter annotations + schema = {} for name in context_args: if name not in sig.parameters: raise ValueError(f"context_arg '{name}' not found in function parameters") param = sig.parameters[name] if param.annotation is inspect.Parameter.empty: raise ValueError(f"context_arg '{name}' must have a type annotation") - default = ... if param.default is inspect.Parameter.empty else param.default - fields[name] = (param.annotation, default) + schema[name] = param.annotation - # Try to match common context types + # Try to match common context types for compatibility + matched_context_type = None from .context import DateRangeContext - # Check for DateRangeContext pattern if set(context_args) == {"start_date", "end_date"}: from datetime import date @@ -59,15 +258,12 @@ def _infer_context_type_from_args(context_args: List[str], func: Callable, sig: or (isinstance(sig.parameters[name].annotation, type) and sig.parameters[name].annotation is date) for name in context_args ): - return DateRangeContext + matched_context_type = DateRangeContext + + # Create TypedDict for validation (not registered anywhere!) 
+ context_td = TypedDict(f"{func.__name__}Inputs", schema) - # Create a new context type dynamically - context_class = create_ccflow_model( - f"_{func.__name__}_Context", - __base__=ContextBase, - **fields, - ) - return context_class + return schema, context_td, matched_context_type def _get_dep_info(annotation) -> Tuple[Type, Optional[Dep]]: @@ -142,13 +338,8 @@ def decorator(fn: Callable) -> Callable: if not (isinstance(return_origin, type) and issubclass(return_origin, ResultBase)): raise TypeError(f"Function {fn.__name__} must return a ResultBase subclass, got {return_type}") - # Determine context mode and extract info - if context_args is not None: - # Mode 2: Unpacked context args - context_type = _infer_context_type_from_args(context_args, fn, sig) - model_field_params = {name: param for name, param in params.items() if name not in context_args and name != "self"} - use_context_args = True - elif "context" in params or "_" in params: + # Determine context mode + if "context" in params or "_" in params: # Mode 1: Explicit context parameter (named 'context' or '_' for unused) context_param_name = "context" if "context" in params else "_" context_param = params[context_param_name] @@ -159,57 +350,139 @@ def decorator(fn: Callable) -> Callable: raise TypeError(f"Function {fn.__name__}: '{context_param_name}' must be annotated with a ContextBase subclass") model_field_params = {name: param for name, param in params.items() if name not in (context_param_name, "self")} use_context_args = False + explicit_context_args = None + elif context_args is not None: + # Mode 2: Explicit context_args - specified params come from context + context_param_name = "context" + # Build context schema early to determine matched_context_type + context_schema_early, _, matched_type = _build_context_schema(context_args, fn, sig) + # Use matched type if available (e.g., DateRangeContext), else FlowContext + context_type = matched_type if matched_type is not None else FlowContext + # 
Exclude context_args from model fields + model_field_params = {name: param for name, param in params.items() if name not in context_args and name != "self"} + use_context_args = True + explicit_context_args = context_args else: - raise TypeError(f"Function {fn.__name__} must either have a 'context' (or '_') parameter or specify context_args in the decorator") + # Mode 3: Dynamic deferred mode - ALL params are potential context or config + # What's provided at construction = config/deps + # What's NOT provided = comes from context at runtime + context_param_name = "context" + context_type = FlowContext + model_field_params = {name: param for name, param in params.items() if name != "self"} + use_context_args = True + explicit_context_args = None # Dynamic - determined at construction # Analyze parameters to find dependencies and regular fields dep_fields: Dict[str, Tuple[Type, Dep]] = {} # name -> (base_type, Dep) model_fields: Dict[str, Tuple[Type, Any]] = {} # name -> (type, default) + # In dynamic deferred mode (no explicit context_args), all fields are optional + # because values not provided at construction come from context at runtime + dynamic_deferred_mode = use_context_args and explicit_context_args is None + for name, param in model_field_params.items(): if param.annotation is inspect.Parameter.empty: raise TypeError(f"Parameter '{name}' must have a type annotation") base_type, dep = _get_dep_info(param.annotation) - default = ... if param.default is inspect.Parameter.empty else param.default + if param.default is not inspect.Parameter.empty: + default = param.default + elif dynamic_deferred_mode: + # In dynamic mode, params without defaults are optional (come from context) + default = None + else: + # In explicit mode, params without defaults are required + default = ... 
if dep is not None: - # This is a dependency parameter + # This is an explicit dependency parameter (DepOf annotation) dep_fields[name] = (base_type, dep) # Use Annotated so _resolve_deps_and_call in callable.py can find the Dep - # This consolidates resolution logic into one place model_fields[name] = (Annotated[Union[base_type, CallableModel], dep], default) else: - # Regular model field - model_fields[name] = (param.annotation, default) + # Regular model field - use Any for auto-detection of CallableModels. + # We can't use Union[T, CallableModel] because Pydantic tries to generate + # schema for T, which fails for arbitrary types like pl.DataFrame. + # Using Any allows any value; we do runtime isinstance checks in __call__. + model_fields[name] = (Any, default) - # Capture context_args in local variable for closures - ctx_args_list = context_args or [] - # Capture context parameter name for closures (only used in mode 1) + # Capture variables for closures ctx_param_name = context_param_name if not use_context_args else "context" + all_param_names = list(model_fields.keys()) # All non-context params (model fields) + all_param_types = {name: param.annotation for name, param in model_field_params.items()} + # For explicit context_args mode, we also need the list of context arg names + ctx_args_for_closure = context_args if context_args is not None else [] + is_dynamic_mode = use_context_args and explicit_context_args is None # Create the __call__ method def make_call_impl(): - # Import resolve here to avoid circular import at module level - from .callable import resolve - def __call__(self, context): + # Import here (inside function) to avoid pickling issues with ContextVar + from .callable import _resolved_deps + + # Check if this model has custom deps (from @func.deps decorator) + has_custom_deps = getattr(self.__class__, "__has_custom_deps__", False) + + def resolve_callable_model(name, value, store): + """Resolve a CallableModel field. 
+ + When has_custom_deps is True and the value is NOT in the store, + it means the custom deps function chose not to include this dep. + In that case, we return None (the field's default) instead of + calling the CallableModel directly. + """ + if id(value) in store: + return store[id(value)] + elif has_custom_deps: + # Custom deps excluded this field - use None + return None + else: + # Auto-detection fallback: call directly + resolved = value(context) + if hasattr(resolved, 'value'): + return resolved.value + return resolved + # Build kwargs for the original function - if use_context_args: - # Unpack context into args - fn_kwargs = {name: getattr(context, name) for name in ctx_args_list} + fn_kwargs = {} + store = _resolved_deps.get() + + if not use_context_args: + # Mode 1: Explicit context param - pass context directly + fn_kwargs[ctx_param_name] = context + # Add model fields + for name in all_param_names: + value = getattr(self, name) + if isinstance(value, CallableModel): + fn_kwargs[name] = resolve_callable_model(name, value, store) + else: + fn_kwargs[name] = value + elif not is_dynamic_mode: + # Mode 2: Explicit context_args - get those from context, rest from self + for name in ctx_args_for_closure: + fn_kwargs[name] = getattr(context, name) + # Add model fields + for name in all_param_names: + value = getattr(self, name) + if isinstance(value, CallableModel): + fn_kwargs[name] = resolve_callable_model(name, value, store) + else: + fn_kwargs[name] = value else: - # Pass context directly (using actual parameter name: 'context' or '_') - fn_kwargs = {ctx_param_name: context} - - # Add model fields - use resolve() for dep fields to get resolved values - for name in model_fields: - value = getattr(self, name) - if name in dep_fields: - # Use resolve() to get the resolved value from context var - fn_kwargs[name] = resolve(value) - else: - fn_kwargs[name] = value + # Mode 3: Dynamic deferred mode - unbound from context, bound from self + bound_fields = 
getattr(self, "_bound_fields", set()) + + for name in all_param_names: + if name in bound_fields: + # Bound at construction - get from self + value = getattr(self, name) + if isinstance(value, CallableModel): + fn_kwargs[name] = resolve_callable_model(name, value, store) + else: + fn_kwargs[name] = value + else: + # Unbound - get from context + fn_kwargs[name] = getattr(context, name) return fn(**fn_kwargs) @@ -242,11 +515,18 @@ def __call__(self, context): def make_deps_impl(): def __deps__(self, context) -> GraphDepList: deps = [] - for dep_name, (base_type, dep_obj) in dep_fields.items(): - value = getattr(self, dep_name) + # Check ALL fields for CallableModels (auto-detection) + for name in model_fields: + value = getattr(self, name) if isinstance(value, CallableModel): - transformed_ctx = dep_obj.apply(context) - deps.append((value, [transformed_ctx])) + if name in dep_fields: + # Explicit DepOf with transform (backwards compat) + _, dep_obj = dep_fields[name] + transformed_ctx = dep_obj.apply(context) + deps.append((value, [transformed_ctx])) + else: + # Auto-detected dependency - use context as-is + deps.append((value, [context])) return deps # Set proper signature @@ -311,7 +591,69 @@ def __validate_deps__(self): GeneratedModel.__flow_model_func__ = fn GeneratedModel.__flow_model_dep_fields__ = dep_fields GeneratedModel.__flow_model_use_context_args__ = use_context_args - GeneratedModel.__flow_model_context_args__ = ctx_args_list + GeneratedModel.__flow_model_explicit_context_args__ = explicit_context_args + GeneratedModel.__flow_model_all_param_types__ = all_param_types # All param name -> type + + # Build context_schema and matched_context_type + context_schema: Dict[str, Type] = {} + context_td = None + matched_context_type: Optional[Type[ContextBase]] = None + + if explicit_context_args is not None: + # Explicit context_args provided - use early-computed schema + # (matched_context_type was already used to set context_type above) + context_schema, 
context_td, matched_context_type = _build_context_schema(explicit_context_args, fn, sig) + elif not use_context_args: + # Explicit context mode - schema comes from the context type's fields + if hasattr(context_type, "model_fields"): + context_schema = {name: info.annotation for name, info in context_type.model_fields.items()} + # For dynamic mode (is_dynamic_mode), _context_schema remains empty + # and schema is built dynamically from _bound_fields at runtime + + # Store context schema for TypedDict-based validation (picklable!) + GeneratedModel._context_schema = context_schema + GeneratedModel._context_td = context_td + GeneratedModel._matched_context_type = matched_context_type + # Validator is created lazily to survive pickling + GeneratedModel._cached_context_validator = None + + # Method to get/create context validator (lazy for pickling support) + def _get_context_validator(self) -> TypeAdapter: + """Get or create the context validator. + + For dynamic deferred mode, builds schema from unbound fields. + For explicit context_args or explicit context mode, uses cached schema. 
+ """ + cls = self.__class__ + explicit_args = getattr(cls, "__flow_model_explicit_context_args__", None) + + # For explicit context_args or explicit context mode, use cached validator + if explicit_args is not None or not getattr(cls, "__flow_model_use_context_args__", True): + if cls._cached_context_validator is None: + if cls._context_td is not None: + cls._cached_context_validator = TypeAdapter(cls._context_td) + elif cls._context_schema: + td = TypedDict(f"{cls.__name__}Inputs", cls._context_schema) + cls._cached_context_validator = TypeAdapter(td) + else: + cls._cached_context_validator = TypeAdapter(cls.__flow_model_context_type__) + return cls._cached_context_validator + + # Dynamic mode: build schema from unbound fields (instance-specific) + # Cache on instance since bound_fields varies per instance + if not hasattr(self, "_instance_context_validator"): + all_param_types = getattr(cls, "__flow_model_all_param_types__", {}) + bound_fields = getattr(self, "_bound_fields", set()) + unbound_schema = {name: typ for name, typ in all_param_types.items() if name not in bound_fields} + if unbound_schema: + td = TypedDict(f"{cls.__name__}Inputs", unbound_schema) + object.__setattr__(self, "_instance_context_validator", TypeAdapter(td)) + else: + # No unbound fields - empty validator + object.__setattr__(self, "_instance_context_validator", TypeAdapter(dict)) + return self._instance_context_validator + + GeneratedModel._get_context_validator = _get_context_validator # Override context_type property after class creation @property @@ -323,10 +665,20 @@ def context_type_getter(self) -> Type[ContextBase]: def result_type_getter(self) -> Type[ResultBase]: return self.__class__.__flow_model_return_type__ + # Add .flow property for the new API + @property + def flow_getter(self) -> FlowAPI: + return FlowAPI(self) + GeneratedModel.context_type = context_type_getter GeneratedModel.result_type = result_type_getter + GeneratedModel.flow = flow_getter + + # Register the MODEL 
class for serialization (needed for model_dump/_target_). + # Note: We do NOT register dynamic context classes anymore - context handling + # uses FlowContext + TypedDict instead, which don't need registration. + from .local_persistence import register_ccflow_import_path - # Register for serialization (local classes need this) register_ccflow_import_path(GeneratedModel) # Rebuild the model to process annotations properly @@ -335,12 +687,56 @@ def result_type_getter(self) -> Type[ResultBase]: # Create factory function that returns model instances @wraps(fn) def factory(**kwargs) -> GeneratedModel: - return GeneratedModel(**kwargs) + instance = GeneratedModel(**kwargs) + # Track which fields were explicitly provided at construction + # These are "bound" - everything else comes from context at runtime + object.__setattr__(instance, "_bound_fields", set(kwargs.keys())) + return instance # Preserve useful attributes on factory factory._generated_model = GeneratedModel factory.__doc__ = fn.__doc__ + # Add .deps decorator for customizing __deps__ + def deps_decorator(deps_fn): + """Decorator to customize the __deps__ method. + + Usage: + @Flow.model + def my_func(start_date: date, prices: dict) -> GenericResult[...]: + ... 
+ + @my_func.deps + def _(self, context): + # Custom context transform + lookback_ctx = FlowContext( + start_date=context.start_date - timedelta(days=30), + end_date=context.end_date, + ) + return [(self.prices, [lookback_ctx])] + """ + from .callable import GraphDepList + + # Rename the function to __deps__ so Flow.deps accepts it + deps_fn.__name__ = "__deps__" + deps_fn.__qualname__ = f"{GeneratedModel.__qualname__}.__deps__" + # Set proper signature to match __call__'s context type + deps_fn.__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), + ], + return_annotation=GraphDepList, + ) + # Wrap with Flow.deps and replace on the class + decorated = Flow.deps(deps_fn) + GeneratedModel.__deps__ = decorated + # Mark that this model has custom deps (so _resolve_deps_and_call will call it) + GeneratedModel.__has_custom_deps__ = True + return factory # Return factory for chaining + + factory.deps = deps_decorator + return factory # Handle both @Flow.model and @Flow.model(...) syntax diff --git a/ccflow/tests/test_context.py b/ccflow/tests/test_context.py index ad98bd9..64d71e8 100644 --- a/ccflow/tests/test_context.py +++ b/ccflow/tests/test_context.py @@ -275,8 +275,13 @@ def split_camel(name: str): def test_inheritance(self): """Test that if a context has a superset of fields of another context, it is a subclass of that context.""" - for parent_name, parent_class in self.classes.items(): - for child_name, child_class in self.classes.items(): + # Exclude FlowContext from this test - it's a special universal carrier with no + # declared fields (uses extra="allow"), so the "superset implies subclass" logic + # doesn't apply to it. 
+ classes_to_check = {name: cls for name, cls in self.classes.items() if name != "FlowContext"} + + for parent_name, parent_class in classes_to_check.items(): + for child_name, child_class in classes_to_check.items(): if parent_class is child_class: continue diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py new file mode 100644 index 0000000..70af8b2 --- /dev/null +++ b/ccflow/tests/test_flow_context.py @@ -0,0 +1,467 @@ +"""Tests for FlowContext, FlowAPI, and TypedDict-based context validation. + +These tests verify the new deferred computation API that uses: +- FlowContext: Universal context carrier with extra="allow" +- TypedDict + TypeAdapter: Schema validation without dynamic class registration +- FlowAPI: The .flow namespace for compute/with_inputs/etc. +""" + +import pickle +from datetime import date, timedelta + +import cloudpickle +import pytest + +from ccflow import Flow, FlowAPI, FlowContext, GenericResult +from ccflow.context import DateRangeContext + + +class TestFlowContext: + """Tests for the FlowContext universal carrier.""" + + def test_flow_context_basic(self): + """FlowContext accepts arbitrary fields.""" + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + assert ctx.start_date == date(2024, 1, 1) + assert ctx.end_date == date(2024, 1, 31) + + def test_flow_context_extra_fields(self): + """FlowContext stores fields in __pydantic_extra__.""" + ctx = FlowContext(x=1, y="hello", z=[1, 2, 3]) + assert ctx.x == 1 + assert ctx.y == "hello" + assert ctx.z == [1, 2, 3] + assert ctx.__pydantic_extra__ == {"x": 1, "y": "hello", "z": [1, 2, 3]} + + def test_flow_context_frozen(self): + """FlowContext is immutable (frozen).""" + ctx = FlowContext(value=42) + with pytest.raises(Exception): # ValidationError for frozen model + ctx.value = 100 + + def test_flow_context_repr(self): + """FlowContext has a useful repr.""" + ctx = FlowContext(a=1, b=2) + repr_str = repr(ctx) + assert "FlowContext" in 
repr_str + assert "a=1" in repr_str + assert "b=2" in repr_str + + def test_flow_context_attribute_error(self): + """FlowContext raises AttributeError for missing fields.""" + ctx = FlowContext(x=1) + with pytest.raises(AttributeError, match="no attribute 'missing'"): + _ = ctx.missing + + def test_flow_context_model_dump(self): + """FlowContext can be dumped (includes extra fields).""" + ctx = FlowContext(start_date=date(2024, 1, 1), value=42) + dumped = ctx.model_dump() + assert dumped["start_date"] == date(2024, 1, 1) + assert dumped["value"] == 42 + + def test_flow_context_pickle(self): + """FlowContext pickles cleanly.""" + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + pickled = pickle.dumps(ctx) + unpickled = pickle.loads(pickled) + assert unpickled.start_date == date(2024, 1, 1) + assert unpickled.end_date == date(2024, 1, 31) + + def test_flow_context_cloudpickle(self): + """FlowContext works with cloudpickle (for Ray).""" + ctx = FlowContext(data=[1, 2, 3], name="test") + pickled = cloudpickle.dumps(ctx) + unpickled = cloudpickle.loads(pickled) + assert unpickled.data == [1, 2, 3] + assert unpickled.name == "test" + + +class TestFlowAPI: + """Tests for the FlowAPI (.flow namespace).""" + + def test_flow_compute_basic(self): + """FlowAPI.compute() validates and executes.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date, "source": source}) + + model = load_data(source="api") + result = model.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + + assert result["start"] == date(2024, 1, 1) + assert result["end"] == date(2024, 1, 31) + assert result["source"] == "api" + + def test_flow_compute_type_coercion(self): + """FlowAPI.compute() coerces types via TypeAdapter.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def 
load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + # Pass strings - should be coerced to dates + result = model.flow.compute(start_date="2024-01-01", end_date="2024-01-31") + + assert result["start"] == date(2024, 1, 1) + assert result["end"] == date(2024, 1, 31) + + def test_flow_compute_validation_error(self): + """FlowAPI.compute() raises on missing required args.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data() + with pytest.raises(Exception): # ValidationError + model.flow.compute(start_date=date(2024, 1, 1)) # Missing end_date + + def test_flow_unbound_inputs(self): + """FlowAPI.unbound_inputs returns the context schema.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data(source="api") + unbound = model.flow.unbound_inputs + + assert "start_date" in unbound + assert "end_date" in unbound + assert unbound["start_date"] == date + assert unbound["end_date"] == date + # source is not unbound (it has a default/is bound) + assert "source" not in unbound + + def test_flow_bound_inputs(self): + """FlowAPI.bound_inputs returns config values.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data(source="api") + bound = model.flow.bound_inputs + + assert "source" in bound + assert bound["source"] == "api" + # Context args are not in bound_inputs + assert "start_date" not in bound + assert "end_date" not in bound + + +class TestBoundModel: + """Tests for BoundModel (created via .flow.with_inputs()).""" + + def 
test_with_inputs_static_value(self): + """with_inputs can bind static values.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + bound = model.flow.with_inputs(start_date=date(2024, 1, 1)) + + # Call with just end_date (start_date is bound) + ctx = FlowContext(end_date=date(2024, 1, 31)) + result = bound(ctx) + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + + def test_with_inputs_transform_function(self): + """with_inputs can use transform functions.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + # Lookback: start_date is 7 days before the context's start_date + bound = model.flow.with_inputs(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) + + ctx = FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 31)) + result = bound(ctx) + assert result.value["start"] == date(2024, 1, 1) # 7 days before + assert result.value["end"] == date(2024, 1, 31) + + def test_with_inputs_multiple_transforms(self): + """with_inputs can apply multiple transforms.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + bound = model.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=7), + end_date=lambda ctx: ctx.end_date + timedelta(days=1), + ) + + ctx = FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 30)) + result = bound(ctx) + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + + def 
test_bound_model_has_flow_property(self): + """BoundModel has a .flow property.""" + + @Flow.model(context_args=["x"]) + def compute(x: int) -> GenericResult[int]: + return GenericResult(value=x * 2) + + model = compute() + bound = model.flow.with_inputs(x=42) + assert isinstance(bound.flow, FlowAPI) + + +class TestTypedDictValidation: + """Tests for TypedDict-based context validation.""" + + def test_schema_stored_on_model(self): + """Model stores _context_schema for validation.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data() + assert hasattr(model, "_context_schema") + assert model._context_schema == {"start_date": date, "end_date": date} + + def test_validator_created_lazily(self): + """TypeAdapter validator is created lazily.""" + + @Flow.model(context_args=["x"]) + def compute(x: int) -> GenericResult[int]: + return GenericResult(value=x) + + model = compute() + # Initially None + assert model.__class__._cached_context_validator is None + + # After getting validator, it's cached + validator = model._get_context_validator() + assert validator is not None + assert model.__class__._cached_context_validator is validator + + def test_matched_context_type(self): + """DateRangeContext pattern is matched for compatibility.""" + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={}) + + model = load_data() + # Should match DateRangeContext + assert model.context_type == DateRangeContext + + +class TestPicklingSupport: + """Tests for pickling support (important for Ray). + + Note: Regular pickle cannot pickle locally-defined classes (functions decorated + inside test methods). cloudpickle CAN handle this, which is why Ray uses it. + All tests here use cloudpickle to match Ray's behavior. 
+ """ + + def test_model_cloudpickle_roundtrip(self): + """Model works with cloudpickle (for Ray).""" + + @Flow.model(context_args=["x", "y"]) + def compute(x: int, y: int, multiplier: int = 2) -> GenericResult[int]: + return GenericResult(value=(x + y) * multiplier) + + model = compute(multiplier=3) + + # cloudpickle roundtrip (what Ray uses) + pickled = cloudpickle.dumps(model) + unpickled = cloudpickle.loads(pickled) + + # Should work after unpickling + result = unpickled.flow.compute(x=1, y=2) + assert result == 9 # (1 + 2) * 3 + + def test_model_cloudpickle_simple(self): + """Simple model cloudpickle test.""" + + @Flow.model(context_args=["value"]) + def double(value: int) -> GenericResult[int]: + return GenericResult(value=value * 2) + + model = double() + + pickled = cloudpickle.dumps(model) + unpickled = cloudpickle.loads(pickled) + + result = unpickled.flow.compute(value=21) + assert result == 42 + + def test_validator_recreated_after_cloudpickle(self): + """TypeAdapter validator is recreated after cloudpickling.""" + + @Flow.model(context_args=["x"]) + def compute(x: int) -> GenericResult[int]: + return GenericResult(value=x) + + model = compute() + # Warm up the validator cache + _ = model._get_context_validator() + assert model.__class__._cached_context_validator is not None + + # cloudpickle and unpickle + pickled = cloudpickle.dumps(model) + unpickled = cloudpickle.loads(pickled) + + # Validator should still work (may be lazily recreated) + result = unpickled.flow.compute(x=42) + assert result == 42 + + def test_flow_context_pickle_standard(self): + """FlowContext works with standard pickle.""" + ctx = FlowContext(x=1, y=2, z="test") + + pickled = pickle.dumps(ctx) + unpickled = pickle.loads(pickled) + + assert unpickled.x == 1 + assert unpickled.y == 2 + assert unpickled.z == "test" + + +class TestIntegrationWithExistingContextTypes: + """Tests for integration with existing ContextBase subclasses.""" + + def test_explicit_context_still_works(self): + 
"""Explicit context parameter mode still works.""" + + @Flow.model + def load_data(context: DateRangeContext, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={"start": context.start_date, "end": context.end_date, "source": source}) + + model = load_data(source="api") + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = model(ctx) + + assert result.value["start"] == date(2024, 1, 1) + assert result.value["source"] == "api" + + def test_flow_context_coerces_to_date_range(self): + """FlowContext can be used with models expecting DateRangeContext.""" + + @Flow.model + def load_data(context: DateRangeContext) -> GenericResult[dict]: + return GenericResult(value={"start": context.start_date, "end": context.end_date}) + + model = load_data() + # Use FlowContext - should coerce to DateRangeContext + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + result = model(ctx) + + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + + def test_flow_api_with_explicit_context(self): + """FlowAPI.compute works with explicit context mode.""" + + @Flow.model + def load_data(context: DateRangeContext, source: str = "db") -> GenericResult[dict]: + return GenericResult(value={"start": context.start_date, "end": context.end_date}) + + model = load_data(source="api") + result = model.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + + assert result["start"] == date(2024, 1, 1) + assert result["end"] == date(2024, 1, 31) + + +class TestLazy: + """Tests for Lazy (deferred execution with context overrides).""" + + def test_lazy_basic(self): + """Lazy wraps a model for deferred execution.""" + from ccflow import Lazy + + @Flow.model(context_args=["value"]) + def compute(value: int, multiplier: int = 2) -> GenericResult[int]: + return GenericResult(value=value * multiplier) + + model = compute(multiplier=3) + lazy = Lazy(model) + + 
assert lazy.model is model + + def test_lazy_call_with_static_override(self): + """Lazy.__call__ with static override values.""" + from ccflow import Lazy + + @Flow.model(context_args=["x", "y"]) + def add(x: int, y: int) -> GenericResult[int]: + return GenericResult(value=x + y) + + model = add() + lazy_fn = Lazy(model)(y=100) # Override y to 100 + + ctx = FlowContext(x=5, y=10) # Original y=10 + result = lazy_fn(ctx) + assert result.value == 105 # x=5 + y=100 (overridden) + + def test_lazy_call_with_callable_override(self): + """Lazy.__call__ with callable override (computed at runtime).""" + from ccflow import Lazy + + @Flow.model(context_args=["value"]) + def double(value: int) -> GenericResult[int]: + return GenericResult(value=value * 2) + + model = double() + # Override value to be original value + 10 + lazy_fn = Lazy(model)(value=lambda ctx: ctx.value + 10) + + ctx = FlowContext(value=5) + result = lazy_fn(ctx) + assert result.value == 30 # (5 + 10) * 2 = 30 + + def test_lazy_with_date_transforms(self): + """Lazy works with date transforms.""" + from ccflow import Lazy + + @Flow.model(context_args=["start_date", "end_date"]) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) + + model = load_data() + + # Use Lazy to create a transform that shifts dates + lazy_fn = Lazy(model)( + start_date=lambda ctx: ctx.start_date - timedelta(days=7), + end_date=lambda ctx: ctx.end_date + ) + + ctx = FlowContext(start_date=date(2024, 1, 15), end_date=date(2024, 1, 31)) + result = lazy_fn(ctx) + + assert result.value["start"] == date(2024, 1, 8) # 7 days before + assert result.value["end"] == date(2024, 1, 31) + + def test_lazy_multiple_overrides(self): + """Lazy supports multiple overrides at once.""" + from ccflow import Lazy + + @Flow.model(context_args=["a", "b", "c"]) + def compute(a: int, b: int, c: int) -> GenericResult[int]: + return GenericResult(value=a + b + c) + + model 
= compute() + lazy_fn = Lazy(model)( + a=10, # Static + b=lambda ctx: ctx.b * 2, # Transform + # c not overridden, uses context value + ) + + ctx = FlowContext(a=1, b=5, c=100) + result = lazy_fn(ctx) + assert result.value == 10 + 10 + 100 # a=10, b=5*2=10, c=100 diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index fe7e32e..b283a2b 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -609,15 +609,45 @@ def bad_return(context: SimpleContext) -> int: self.assertIn("ResultBase", str(cm.exception)) - def test_missing_context_and_context_args(self): - """Test error when neither context param nor context_args provided.""" - with self.assertRaises(TypeError) as cm: + def test_dynamic_deferred_mode(self): + """Test dynamic deferred mode where what you provide at construction = bound.""" + from ccflow import FlowContext - @Flow.model - def no_context(value: int) -> GenericResult[int]: - return GenericResult(value=value) + @Flow.model + def dynamic_model(value: int, multiplier: int) -> GenericResult[int]: + return GenericResult(value=value * multiplier) + + # Provide 'multiplier' at construction -> it's bound + # Don't provide 'value' -> comes from context + model = dynamic_model(multiplier=3) + + # Check bound vs unbound + self.assertEqual(model.flow.bound_inputs, {"multiplier": 3}) + self.assertEqual(model.flow.unbound_inputs, {"value": int}) + + # Call with context providing 'value' + ctx = FlowContext(value=10) + result = model(ctx) + self.assertEqual(result.value, 30) # 10 * 3 - self.assertIn("context", str(cm.exception)) + def test_all_defaults_is_valid(self): + """Test that all-defaults function is valid (everything can be pre-bound).""" + from ccflow import FlowContext + + @Flow.model + def all_defaults(value: int = 1, other: str = "x") -> GenericResult[str]: + return GenericResult(value=f"{value}-{other}") + + # No args provided -> everything comes from defaults or context + model = all_defaults() + + 
# All params are unbound (not provided at construction) + self.assertEqual(model.flow.unbound_inputs, {"value": int, "other": str}) + + # Call with context - context values override defaults + ctx = FlowContext(value=5, other="y") + result = model(ctx) + self.assertEqual(result.value, "5-y") def test_invalid_context_arg(self): """Test error when context_args refers to non-existent parameter.""" From 94959158d17d4c3791d66ae0ad8e6fec7b00ac10 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Mon, 16 Mar 2026 18:30:24 -0400 Subject: [PATCH 07/26] Lint fixes Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 2 +- ccflow/flow_model.py | 2 +- ccflow/tests/test_flow_context.py | 5 +---- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index 1aa7189..aa92ae1 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -333,7 +333,7 @@ def _resolve_deps_and_call(model, context, fn): for dep_model, contexts in deps_result: resolved = dep_model(contexts[0]) if contexts else dep_model(context) # Unwrap GenericResult if present (consistent with auto-detected deps) - if hasattr(resolved, 'value'): + if hasattr(resolved, "value"): resolved = resolved.value resolved_values[id(dep_model)] = resolved else: diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 2d3ab3a..414202e 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -439,7 +439,7 @@ def resolve_callable_model(name, value, store): else: # Auto-detection fallback: call directly resolved = value(context) - if hasattr(resolved, 'value'): + if hasattr(resolved, "value"): return resolved.value return resolved diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index 70af8b2..3f613ab 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -436,10 +436,7 @@ def load_data(start_date: date, end_date: date) -> GenericResult[dict]: model = load_data() # Use Lazy to create a transform 
that shifts dates - lazy_fn = Lazy(model)( - start_date=lambda ctx: ctx.start_date - timedelta(days=7), - end_date=lambda ctx: ctx.end_date - ) + lazy_fn = Lazy(model)(start_date=lambda ctx: ctx.start_date - timedelta(days=7), end_date=lambda ctx: ctx.end_date) ctx = FlowContext(start_date=date(2024, 1, 15), end_date=date(2024, 1, 31)) result = lazy_fn(ctx) From 765299d9612960bb05c3ba49c5c11e55ab3c66bc Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Tue, 17 Mar 2026 05:36:22 -0400 Subject: [PATCH 08/26] Flow.model cleanup Signed-off-by: Nijat Khanbabayev --- ccflow/__init__.py | 2 +- ccflow/callable.py | 52 +- ccflow/context.py | 21 +- ccflow/flow_model.py | 582 +++++++++++++-------- ccflow/tests/test_flow_context.py | 4 +- ccflow/tests/test_flow_model.py | 810 +++++++++++++++++++++++++++++- 6 files changed, 1211 insertions(+), 260 deletions(-) diff --git a/ccflow/__init__.py b/ccflow/__init__.py index c703a1c..4dbe143 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -12,7 +12,7 @@ from .context import * from .dep import * from .enums import Enum -from .flow_model import FlowAPI, BoundModel, Lazy +from .flow_model import * from .global_state import * from .local_persistence import * from .models import * diff --git a/ccflow/callable.py b/ccflow/callable.py index aa92ae1..d3b22e4 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -312,10 +312,7 @@ def _resolve_deps_and_call(model, context, fn): # Get Dep-annotated fields for this model class dep_fields = _get_dep_fields(model.__class__) - # Check if model has custom deps (from @func.deps decorator) - has_custom_deps = getattr(model.__class__, "__has_custom_deps__", False) - - if not dep_fields and not has_custom_deps: + if not dep_fields: return fn(model, context) # Get dependencies from __deps__ @@ -328,36 +325,27 @@ def _resolve_deps_and_call(model, context, fn): # Resolve dependencies and store in context var resolved_values = {} - # If custom deps, resolve ALL CallableModel fields from 
dep_map - if has_custom_deps: - for dep_model, contexts in deps_result: + # Standard path: iterate over Dep-annotated fields + for field_name, dep in dep_fields.items(): + field_value = getattr(model, field_name, None) + if field_value is None: + continue + + # Check if field is a CallableModel that needs resolution + if not isinstance(field_value, _CallableModel): + continue # Already a resolved value, skip + + # Check if this field is in __deps__ (for custom transforms) + if id(field_value) in dep_map: + dep_model, contexts = dep_map[id(field_value)] + # Call dependency with the (transformed) context resolved = dep_model(contexts[0]) if contexts else dep_model(context) - # Unwrap GenericResult if present (consistent with auto-detected deps) - if hasattr(resolved, "value"): - resolved = resolved.value - resolved_values[id(dep_model)] = resolved - else: - # Standard path: iterate over Dep-annotated fields - for field_name, dep in dep_fields.items(): - field_value = getattr(model, field_name, None) - if field_value is None: - continue - - # Check if field is a CallableModel that needs resolution - if not isinstance(field_value, _CallableModel): - continue # Already a resolved value, skip - - # Check if this field is in __deps__ (for custom transforms) - if id(field_value) in dep_map: - dep_model, contexts = dep_map[id(field_value)] - # Call dependency with the (transformed) context - resolved = dep_model(contexts[0]) if contexts else dep_model(context) - else: - # Not in __deps__, use Dep annotation transform directly - transformed_ctx = dep.apply(context) - resolved = field_value(transformed_ctx) + else: + # Not in __deps__, use Dep annotation transform directly + transformed_ctx = dep.apply(context) + resolved = field_value(transformed_ctx) - resolved_values[id(field_value)] = resolved + resolved_values[id(field_value)] = resolved # Store in context var and call function current_store = _resolved_deps.get() diff --git a/ccflow/context.py b/ccflow/context.py index 
50ff6dc..0d00d2e 100644 --- a/ccflow/context.py +++ b/ccflow/context.py @@ -1,7 +1,7 @@ """This module defines re-usable contexts for the "Callable Model" framework defined in flow.callable.py.""" from datetime import date, datetime -from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar +from typing import Generic, Hashable, Optional, Sequence, Set, TypeVar from deprecated import deprecated from pydantic import ConfigDict, field_validator, model_validator @@ -102,29 +102,10 @@ class FlowContext(ContextBase): - Proliferation of dynamic _funcname_Context classes - Class registration overhead for serialization - Pickling issues with Ray/distributed computing - - Fields are stored in __pydantic_extra__ and accessed via __getattr__. """ model_config = ConfigDict(extra="allow", frozen=True) - def __getattr__(self, name: str) -> Any: - """Access fields stored in __pydantic_extra__.""" - # Use object.__getattribute__ to avoid infinite recursion - try: - extra = object.__getattribute__(self, "__pydantic_extra__") - if extra is not None and name in extra: - return extra[name] - except AttributeError: - pass - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") - - def __repr__(self) -> str: - """Show all fields including extra fields.""" - extra = object.__getattribute__(self, "__pydantic_extra__") or {} - fields = ", ".join(f"{k}={v!r}" for k, v in extra.items()) - return f"FlowContext({fields})" - C = TypeVar("C", bound=Hashable) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 414202e..e9f2704 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -12,20 +12,121 @@ import inspect import logging from functools import wraps -from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Type, Union, get_origin +from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin from pydantic import Field, TypeAdapter from typing_extensions import 
TypedDict from .base import ContextBase, ResultBase +from .callable import CallableModel, Flow, GraphDepList, _CallableModel from .context import FlowContext from .dep import Dep, extract_dep +from .local_persistence import register_ccflow_import_path +from .result import GenericResult + +__all__ = ("flow_model", "FlowAPI", "BoundModel", "Lazy", "FieldExtractor") + + +class _LazyMarker: + """Sentinel that marks a parameter as lazily evaluated via Lazy[T].""" + + pass + + +def _extract_lazy(annotation) -> Tuple[Any, bool]: + """Check if annotation is Lazy[T]. Returns (base_type, is_lazy). + + Handles nested Annotated types — e.g. Lazy[Annotated[T, Dep(...)]] produces + Annotated[Annotated[T, Dep(...)], _LazyMarker()], so we need to check the + outermost Annotated layer for _LazyMarker. + """ + if get_origin(annotation) is Annotated: + args = get_args(annotation) + for metadata in args[1:]: + if isinstance(metadata, _LazyMarker): + return args[0], True + return annotation, False + + +def _make_lazy_thunk(model, context): + """Create a zero-arg callable that evaluates model(context) on demand. + + The thunk caches its result so repeated calls don't re-evaluate. + """ + _cache = {} + + def thunk(): + if "result" not in _cache: + result = model(context) + if isinstance(result, GenericResult): + result = result.value + _cache["result"] = result + return _cache["result"] + + return thunk -__all__ = ("flow_model", "FlowAPI", "BoundModel", "Lazy") log = logging.getLogger(__name__) +def _context_values(context: ContextBase) -> Dict[str, Any]: + """Return a plain mapping of all context values. + + `dict(context)` uses pydantic's public iteration behavior, which includes + both declared fields and any allowed extra fields. 
+ """ + + return dict(context) + + +def _build_typed_dict_adapter(name: str, schema: Dict[str, Type]) -> TypeAdapter: + """Build a TypeAdapter for a runtime TypedDict schema.""" + + if not schema: + return TypeAdapter(dict) + return TypeAdapter(TypedDict(name, schema)) + + +def _build_config_validators( + all_param_types: Dict[str, Type], dep_fields: Dict[str, Tuple[Type, Dep]] +) -> Tuple[Dict[str, Type], Dict[str, TypeAdapter]]: + """Precompute validators for non-dependency config fields.""" + + validatable_types: Dict[str, Type] = {} + for name, typ in all_param_types.items(): + if name in dep_fields: + continue + try: + TypeAdapter(typ) + validatable_types[name] = typ + except Exception: + pass + + validators = {name: TypeAdapter(typ) for name, typ in validatable_types.items()} + return validatable_types, validators + + +def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, Type], validators: Dict[str, TypeAdapter]) -> None: + """Validate plain config inputs while still allowing dependency objects.""" + + if not validators: + return + + from .callable import CallableModel as _CM + + for field_name, validator in validators.items(): + if field_name not in kwargs: + continue + value = kwargs[field_name] + if value is None or isinstance(value, (_CM, BoundModel)): + continue + try: + validator.validate_python(value) + except Exception: + expected_type = validatable_types[field_name] + raise TypeError(f"Field '{field_name}': expected {expected_type}, got {type(value).__name__} ({value!r})") + + class FlowAPI: """API namespace for deferred computation operations. @@ -61,7 +162,7 @@ def compute(self, **kwargs) -> Any: result = self._model(ctx) # Unwrap GenericResult if present - if hasattr(result, "value"): + if isinstance(result, GenericResult): return result.value return result @@ -111,6 +212,20 @@ class BoundModel: Created by model.flow.with_inputs(). Applies transforms to context before delegating to the underlying model. 
+ + Context propagation across dependencies: + Each BoundModel transforms the context locally — only for the model it + wraps. When used as a dependency inside another model, the FlowContext + flows through the chain unchanged until it reaches this BoundModel, + which intercepts it, applies its transforms, and passes the modified + context to the wrapped model. Upstream models never see the transform. + + Chaining with_inputs: + Calling ``bound.flow.with_inputs(...)`` merges the new transforms with + the existing ones (new overrides old for the same key). All transforms + are applied to the incoming context in one pass — they don't compose + sequentially (each transform sees the original context, not the output + of a previous transform). """ def __init__(self, model: "CallableModel", input_transforms: Dict[str, Any]): # noqa: F821 @@ -120,13 +235,7 @@ def __init__(self, model: "CallableModel", input_transforms: Dict[str, Any]): # def __call__(self, context: ContextBase) -> Any: """Call the model with transformed context.""" # Build new context dict with transforms applied - ctx_dict = {} - - # Get fields from context - if hasattr(context, "__pydantic_extra__") and context.__pydantic_extra__: - ctx_dict.update(context.__pydantic_extra__) - for field in context.__class__.model_fields: - ctx_dict[field] = getattr(context, field) + ctx_dict = _context_values(context) # Apply transforms for name, transform in self._input_transforms.items(): @@ -140,35 +249,125 @@ def __call__(self, context: ContextBase) -> Any: return self._model(new_ctx) @property - def flow(self) -> FlowAPI: + def flow(self) -> "FlowAPI": """Access the flow API.""" - return FlowAPI(self._model) + return _BoundFlowAPI(self) + + +class _BoundFlowAPI(FlowAPI): + """FlowAPI that delegates to a BoundModel, honoring transforms.""" + + def __init__(self, bound_model: BoundModel): + self._bound = bound_model + super().__init__(bound_model._model) + + def compute(self, **kwargs) -> Any: + validator = 
self._model._get_context_validator() + validated = validator.validate_python(kwargs) + ctx = FlowContext(**validated) + result = self._bound(ctx) # Call through BoundModel, not _model + if isinstance(result, GenericResult): + return result.value + return result + + def with_inputs(self, **transforms) -> "BoundModel": + """Chain transforms: merge new transforms with existing ones. + + New transforms override existing ones for the same key. + """ + merged = {**self._bound._input_transforms, **transforms} + return BoundModel(model=self._bound._model, input_transforms=merged) + + +class _FieldExtractorMixin: + """Turn unknown public attributes into FieldExtractors. + + Real model attributes are still resolved by the normal pydantic/base-model + attribute path via ``super().__getattr__``. + """ + + def __getattr__(self, name): + try: + return super().__getattr__(name) + except AttributeError: + if name.startswith("_"): + raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") from None + return FieldExtractor(source=self, field_name=name) + + +class _GeneratedFlowModelBase(_FieldExtractorMixin, CallableModel): + """Shared behavior for models generated by ``@Flow.model``.""" + + @property + def context_type(self) -> Type[ContextBase]: + return self.__class__.__flow_model_context_type__ + + @property + def result_type(self) -> Type[ResultBase]: + return self.__class__.__flow_model_return_type__ + + @property + def flow(self) -> FlowAPI: + return FlowAPI(self) + + def _get_context_validator(self) -> TypeAdapter: + """Get or create the context validator for this generated model.""" + + cls = self.__class__ + explicit_args = getattr(cls, "__flow_model_explicit_context_args__", None) + + if explicit_args is not None or not getattr(cls, "__flow_model_use_context_args__", True): + if cls._cached_context_validator is None: + if cls._context_td is not None: + cls._cached_context_validator = TypeAdapter(cls._context_td) + elif cls._context_schema: + 
cls._cached_context_validator = _build_typed_dict_adapter(f"{cls.__name__}Inputs", cls._context_schema) + else: + cls._cached_context_validator = TypeAdapter(cls.__flow_model_context_type__) + return cls._cached_context_validator + + if not hasattr(self, "_instance_context_validator"): + all_param_types = getattr(cls, "__flow_model_all_param_types__", {}) + bound_fields = getattr(self, "_bound_fields", set()) + unbound_schema = {name: typ for name, typ in all_param_types.items() if name not in bound_fields} + object.__setattr__(self, "_instance_context_validator", _build_typed_dict_adapter(f"{cls.__name__}Inputs", unbound_schema)) + return self._instance_context_validator class Lazy: """Deferred model execution with runtime context overrides. - Wraps a CallableModel to allow context fields to be determined at - runtime rather than at construction time. Use in with_inputs() when - you need values that aren't available until execution. + Has two distinct uses: - Example: - # Create a model that needs runtime-determined context - market_data = load_market_data(symbols=["AAPL"]) + 1. **Type annotation** — ``Lazy[T]`` marks a parameter as lazily evaluated. + The framework will NOT pre-evaluate the dependency; instead the function + receives a zero-arg thunk that triggers evaluation on demand:: - # Use Lazy to defer the start_date calculation - lookback_data = market_data.flow.with_inputs( - start_date=Lazy(market_data)(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) - ) + @Flow.model + def smart_training( + data: PreparedData, + fast_metrics: Metrics, + slow_metrics: Lazy[Metrics], # NOT eagerly evaluated + threshold: float = 0.9, + ) -> Metrics: + if fast_metrics.r2 > threshold: + return fast_metrics + return slow_metrics() # Evaluated on demand + + 2. **Runtime helper** — ``Lazy(model)(overrides)`` creates a callable that + applies context overrides before calling the model. 
Used with + ``with_inputs()`` for deferred execution:: + + lookback = Lazy(model)(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) + """ - # More commonly, use Lazy for self-referential transforms: - adjusted_model = model.flow.with_inputs( - value=Lazy(other_model)(multiplier=2) # Call other_model with multiplier=2 - ) + def __class_getitem__(cls, item): + """Support Lazy[T] syntax as a type annotation marker. - The __call__ method returns a callable that, when invoked with a context, - calls the wrapped model with the specified overrides applied. - """ + Returns Annotated[T, _LazyMarker()] so the framework can detect + lazy parameters during signature analysis. + """ + return Annotated[item, _LazyMarker()] def __init__(self, model: "CallableModel"): # noqa: F821 """Wrap a model for deferred execution. @@ -193,11 +392,7 @@ def __call__(self, **overrides) -> Callable[[ContextBase], Any]: def execute_with_overrides(context: ContextBase) -> Any: # Build context dict from incoming context - ctx_dict = {} - if hasattr(context, "__pydantic_extra__") and context.__pydantic_extra__: - ctx_dict.update(context.__pydantic_extra__) - for field in context.__class__.model_fields: - ctx_dict[field] = getattr(context, field) + ctx_dict = _context_values(context) # Apply overrides for name, value in overrides.items(): @@ -275,18 +470,23 @@ def _get_dep_info(annotation) -> Tuple[Type, Optional[Dep]]: return extract_dep(annotation) +_UNSET = object() + + def flow_model( func: Callable = None, *, # Context handling context_args: Optional[List[str]] = None, # Flow.call options (passed to generated __call__) - cacheable: bool = False, - volatile: bool = False, - log_level: int = logging.DEBUG, - validate_result: bool = True, - verbose: bool = True, - evaluator: Optional[Any] = None, + # Default to _UNSET so FlowOptionsOverride can control these globally. + # Only explicitly user-provided values are passed to Flow.call. 
+ cacheable: Any = _UNSET, + volatile: Any = _UNSET, + log_level: Any = _UNSET, + validate_result: Any = _UNSET, + verbose: Any = _UNSET, + evaluator: Any = _UNSET, ) -> Callable: """Decorator that generates a CallableModel class from a plain Python function. @@ -297,12 +497,12 @@ def flow_model( Args: func: The function to decorate context_args: List of parameter names that come from context (for unpacked mode) - cacheable: Enable caching of results - volatile: Mark as volatile (always re-execute) - log_level: Logging verbosity - validate_result: Validate return type - verbose: Verbose logging output - evaluator: Custom evaluator + cacheable: Enable caching of results (default: unset, inherits from FlowOptionsOverride) + volatile: Mark as volatile (always re-execute) (default: unset, inherits from FlowOptionsOverride) + log_level: Logging verbosity (default: unset, inherits from FlowOptionsOverride) + validate_result: Validate return type (default: unset, inherits from FlowOptionsOverride) + verbose: Verbose logging output (default: unset, inherits from FlowOptionsOverride) + evaluator: Custom evaluator (default: unset, inherits from FlowOptionsOverride) Two Context Modes: 1. Explicit context parameter: Function has a 'context' parameter annotated @@ -323,29 +523,40 @@ def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[ """ def decorator(fn: Callable) -> Callable: - # Import here to avoid circular imports - from .callable import CallableModel, Flow, GraphDepList + import typing as _typing sig = inspect.signature(fn) params = sig.parameters + # Resolve string annotations (PEP 563 / from __future__ import annotations) + # into real type objects. include_extras=True preserves Annotated metadata. 
+ try: + _resolved_hints = _typing.get_type_hints(fn, include_extras=True) + except Exception: + _resolved_hints = {} + # Validate return type - return_type = sig.return_annotation + return_type = _resolved_hints.get("return", sig.return_annotation) if return_type is inspect.Signature.empty: raise TypeError(f"Function {fn.__name__} must have a return type annotation") - # Check that return type is a ResultBase subclass + # Check if return type is a ResultBase subclass; if not, auto-wrap in GenericResult return_origin = get_origin(return_type) or return_type + auto_wrap_result = False if not (isinstance(return_origin, type) and issubclass(return_origin, ResultBase)): - raise TypeError(f"Function {fn.__name__} must return a ResultBase subclass, got {return_type}") + auto_wrap_result = True + internal_return_type = GenericResult # unparameterized for safety + else: + internal_return_type = return_type # Determine context mode if "context" in params or "_" in params: # Mode 1: Explicit context parameter (named 'context' or '_' for unused) context_param_name = "context" if "context" in params else "_" context_param = params[context_param_name] - if context_param.annotation is inspect.Parameter.empty: + context_annotation = _resolved_hints.get(context_param_name, context_param.annotation) + if context_annotation is inspect.Parameter.empty: raise TypeError(f"Function {fn.__name__}: '{context_param_name}' parameter must have a type annotation") - context_type = context_param.annotation + context_type = context_annotation if not (isinstance(context_type, type) and issubclass(context_type, ContextBase)): raise TypeError(f"Function {fn.__name__}: '{context_param_name}' must be annotated with a ContextBase subclass") model_field_params = {name: param for name, param in params.items() if name not in (context_param_name, "self")} @@ -372,19 +583,28 @@ def decorator(fn: Callable) -> Callable: use_context_args = True explicit_context_args = None # Dynamic - determined at 
construction - # Analyze parameters to find dependencies and regular fields + # Analyze parameters to find dependencies, lazy fields, and regular fields dep_fields: Dict[str, Tuple[Type, Dep]] = {} # name -> (base_type, Dep) model_fields: Dict[str, Tuple[Type, Any]] = {} # name -> (type, default) + lazy_fields: set = set() # Names of parameters marked with Lazy[T] # In dynamic deferred mode (no explicit context_args), all fields are optional # because values not provided at construction come from context at runtime dynamic_deferred_mode = use_context_args and explicit_context_args is None for name, param in model_field_params.items(): - if param.annotation is inspect.Parameter.empty: + # Use resolved hint (handles PEP 563 string annotations) + annotation = _resolved_hints.get(name, param.annotation) + if annotation is inspect.Parameter.empty: raise TypeError(f"Parameter '{name}' must have a type annotation") - base_type, dep = _get_dep_info(param.annotation) + # Check for Lazy[T] annotation first + unwrapped_annotation, is_lazy = _extract_lazy(annotation) + if is_lazy: + lazy_fields.add(name) + + # Extract Dep info from the (possibly unwrapped) annotation + base_type, dep = _get_dep_info(unwrapped_annotation) if param.default is not inspect.Parameter.empty: default = param.default elif dynamic_deferred_mode: @@ -409,7 +629,7 @@ def decorator(fn: Callable) -> Callable: # Capture variables for closures ctx_param_name = context_param_name if not use_context_args else "context" all_param_names = list(model_fields.keys()) # All non-context params (model fields) - all_param_types = {name: param.annotation for name, param in model_field_params.items()} + all_param_types = {name: _resolved_hints.get(name, param.annotation) for name, param in model_field_params.items()} # For explicit context_args mode, we also need the list of context arg names ctx_args_for_closure = context_args if context_args is not None else [] is_dynamic_mode = use_context_args and 
explicit_context_args is None @@ -420,26 +640,14 @@ def __call__(self, context): # Import here (inside function) to avoid pickling issues with ContextVar from .callable import _resolved_deps - # Check if this model has custom deps (from @func.deps decorator) - has_custom_deps = getattr(self.__class__, "__has_custom_deps__", False) - def resolve_callable_model(name, value, store): - """Resolve a CallableModel field. - - When has_custom_deps is True and the value is NOT in the store, - it means the custom deps function chose not to include this dep. - In that case, we return None (the field's default) instead of - calling the CallableModel directly. - """ + """Resolve a CallableModel field.""" if id(value) in store: return store[id(value)] - elif has_custom_deps: - # Custom deps excluded this field - use None - return None else: # Auto-detection fallback: call directly resolved = value(context) - if hasattr(resolved, "value"): + if isinstance(resolved, GenericResult): return resolved.value return resolved @@ -447,16 +655,28 @@ def resolve_callable_model(name, value, store): fn_kwargs = {} store = _resolved_deps.get() + def _resolve_field(name, value): + """Resolve a single field value, handling lazy wrapping.""" + is_dep = isinstance(value, (CallableModel, BoundModel)) + if name in lazy_fields: + # Lazy field: wrap in a thunk regardless of type + if is_dep: + return _make_lazy_thunk(value, context) + else: + # Non-dep value: wrap in trivial thunk + return lambda v=value: v + elif is_dep: + return resolve_callable_model(name, value, store) + else: + return value + if not use_context_args: # Mode 1: Explicit context param - pass context directly fn_kwargs[ctx_param_name] = context # Add model fields for name in all_param_names: value = getattr(self, name) - if isinstance(value, CallableModel): - fn_kwargs[name] = resolve_callable_model(name, value, store) - else: - fn_kwargs[name] = value + fn_kwargs[name] = _resolve_field(name, value) elif not is_dynamic_mode: # Mode 
2: Explicit context_args - get those from context, rest from self for name in ctx_args_for_closure: @@ -464,10 +684,7 @@ def resolve_callable_model(name, value, store): # Add model fields for name in all_param_names: value = getattr(self, name) - if isinstance(value, CallableModel): - fn_kwargs[name] = resolve_callable_model(name, value, store) - else: - fn_kwargs[name] = value + fn_kwargs[name] = _resolve_field(name, value) else: # Mode 3: Dynamic deferred mode - unbound from context, bound from self bound_fields = getattr(self, "_bound_fields", set()) @@ -476,15 +693,15 @@ def resolve_callable_model(name, value, store): if name in bound_fields: # Bound at construction - get from self value = getattr(self, name) - if isinstance(value, CallableModel): - fn_kwargs[name] = resolve_callable_model(name, value, store) - else: - fn_kwargs[name] = value + fn_kwargs[name] = _resolve_field(name, value) else: # Unbound - get from context fn_kwargs[name] = getattr(context, name) - return fn(**fn_kwargs) + raw_result = fn(**fn_kwargs) + if auto_wrap_result: + return GenericResult(value=raw_result) + return raw_result # Set proper signature for CallableModel validation __call__.__signature__ = inspect.Signature( @@ -492,22 +709,24 @@ def resolve_callable_model(name, value, store): inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), ], - return_annotation=return_type, + return_annotation=internal_return_type, ) return __call__ call_impl = make_call_impl() - # Apply Flow.call decorator - flow_options = { - "cacheable": cacheable, - "volatile": volatile, - "log_level": log_level, - "validate_result": validate_result, - "verbose": verbose, - } - if evaluator is not None: - flow_options["evaluator"] = evaluator + # Apply Flow.call decorator — only include options the user explicitly set + flow_options = {} + for opt_name, opt_val in [ + ("cacheable", cacheable), + 
("volatile", volatile), + ("log_level", log_level), + ("validate_result", validate_result), + ("verbose", verbose), + ("evaluator", evaluator), + ]: + if opt_val is not _UNSET: + flow_options[opt_name] = opt_val decorated_call = Flow.call(**flow_options)(call_impl) @@ -515,10 +734,12 @@ def resolve_callable_model(name, value, store): def make_deps_impl(): def __deps__(self, context) -> GraphDepList: deps = [] - # Check ALL fields for CallableModels (auto-detection) + # Check ALL fields for CallableModels/BoundModels (auto-detection) for name in model_fields: + if name in lazy_fields: + continue # Lazy deps are NOT pre-evaluated value = getattr(self, name) - if isinstance(value, CallableModel): + if isinstance(value, (CallableModel, BoundModel)): if name in dep_fields: # Explicit DepOf with transform (backwards compat) _, dep_obj = dep_fields[name] @@ -582,17 +803,20 @@ def __validate_deps__(self): namespace["__validate_deps__"] = make_dep_validator(dep_fields, context_type) + _validatable_types, _config_validators = _build_config_validators(all_param_types, dep_fields) + # Create the class using type() - GeneratedModel = type(f"_{fn.__name__}_Model", (CallableModel,), namespace) + GeneratedModel = type(f"_{fn.__name__}_Model", (_GeneratedFlowModelBase,), namespace) # Set class-level attributes after class creation (to avoid pydantic processing) GeneratedModel.__flow_model_context_type__ = context_type - GeneratedModel.__flow_model_return_type__ = return_type + GeneratedModel.__flow_model_return_type__ = internal_return_type GeneratedModel.__flow_model_func__ = fn GeneratedModel.__flow_model_dep_fields__ = dep_fields GeneratedModel.__flow_model_use_context_args__ = use_context_args GeneratedModel.__flow_model_explicit_context_args__ = explicit_context_args GeneratedModel.__flow_model_all_param_types__ = all_param_types # All param name -> type + GeneratedModel.__flow_model_auto_wrap__ = auto_wrap_result # Build context_schema and matched_context_type context_schema: 
Dict[str, Type] = {} @@ -617,68 +841,9 @@ def __validate_deps__(self): # Validator is created lazily to survive pickling GeneratedModel._cached_context_validator = None - # Method to get/create context validator (lazy for pickling support) - def _get_context_validator(self) -> TypeAdapter: - """Get or create the context validator. - - For dynamic deferred mode, builds schema from unbound fields. - For explicit context_args or explicit context mode, uses cached schema. - """ - cls = self.__class__ - explicit_args = getattr(cls, "__flow_model_explicit_context_args__", None) - - # For explicit context_args or explicit context mode, use cached validator - if explicit_args is not None or not getattr(cls, "__flow_model_use_context_args__", True): - if cls._cached_context_validator is None: - if cls._context_td is not None: - cls._cached_context_validator = TypeAdapter(cls._context_td) - elif cls._context_schema: - td = TypedDict(f"{cls.__name__}Inputs", cls._context_schema) - cls._cached_context_validator = TypeAdapter(td) - else: - cls._cached_context_validator = TypeAdapter(cls.__flow_model_context_type__) - return cls._cached_context_validator - - # Dynamic mode: build schema from unbound fields (instance-specific) - # Cache on instance since bound_fields varies per instance - if not hasattr(self, "_instance_context_validator"): - all_param_types = getattr(cls, "__flow_model_all_param_types__", {}) - bound_fields = getattr(self, "_bound_fields", set()) - unbound_schema = {name: typ for name, typ in all_param_types.items() if name not in bound_fields} - if unbound_schema: - td = TypedDict(f"{cls.__name__}Inputs", unbound_schema) - object.__setattr__(self, "_instance_context_validator", TypeAdapter(td)) - else: - # No unbound fields - empty validator - object.__setattr__(self, "_instance_context_validator", TypeAdapter(dict)) - return self._instance_context_validator - - GeneratedModel._get_context_validator = _get_context_validator - - # Override context_type property 
after class creation - @property - def context_type_getter(self) -> Type[ContextBase]: - return self.__class__.__flow_model_context_type__ - - # Override result_type property after class creation - @property - def result_type_getter(self) -> Type[ResultBase]: - return self.__class__.__flow_model_return_type__ - - # Add .flow property for the new API - @property - def flow_getter(self) -> FlowAPI: - return FlowAPI(self) - - GeneratedModel.context_type = context_type_getter - GeneratedModel.result_type = result_type_getter - GeneratedModel.flow = flow_getter - # Register the MODEL class for serialization (needed for model_dump/_target_). # Note: We do NOT register dynamic context classes anymore - context handling # uses FlowContext + TypedDict instead, which don't need registration. - from .local_persistence import register_ccflow_import_path - register_ccflow_import_path(GeneratedModel) # Rebuild the model to process annotations properly @@ -687,6 +852,8 @@ def flow_getter(self) -> FlowAPI: # Create factory function that returns model instances @wraps(fn) def factory(**kwargs) -> GeneratedModel: + _validate_config_kwargs(kwargs, _validatable_types, _config_validators) + instance = GeneratedModel(**kwargs) # Track which fields were explicitly provided at construction # These are "bound" - everything else comes from context at runtime @@ -697,49 +864,68 @@ def factory(**kwargs) -> GeneratedModel: factory._generated_model = GeneratedModel factory.__doc__ = fn.__doc__ - # Add .deps decorator for customizing __deps__ - def deps_decorator(deps_fn): - """Decorator to customize the __deps__ method. - - Usage: - @Flow.model - def my_func(start_date: date, prices: dict) -> GenericResult[...]: - ... 
- - @my_func.deps - def _(self, context): - # Custom context transform - lookback_ctx = FlowContext( - start_date=context.start_date - timedelta(days=30), - end_date=context.end_date, - ) - return [(self.prices, [lookback_ctx])] - """ - from .callable import GraphDepList - - # Rename the function to __deps__ so Flow.deps accepts it - deps_fn.__name__ = "__deps__" - deps_fn.__qualname__ = f"{GeneratedModel.__qualname__}.__deps__" - # Set proper signature to match __call__'s context type - deps_fn.__signature__ = inspect.Signature( - parameters=[ - inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), - inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), - ], - return_annotation=GraphDepList, - ) - # Wrap with Flow.deps and replace on the class - decorated = Flow.deps(deps_fn) - GeneratedModel.__deps__ = decorated - # Mark that this model has custom deps (so _resolve_deps_and_call will call it) - GeneratedModel.__has_custom_deps__ = True - return factory # Return factory for chaining - - factory.deps = deps_decorator - return factory # Handle both @Flow.model and @Flow.model(...) syntax if func is not None: return decorator(func) return decorator + + +# ============================================================================= +# FieldExtractor — structured output field access +# ============================================================================= + + +class FieldExtractor(_FieldExtractorMixin, CallableModel): + """Extracts a named field from a source model's result. + + Created automatically by accessing an unknown attribute on a @Flow.model + instance (e.g., ``prepared.X_train``). The extractor is itself a + CallableModel, so it can be wired as a dependency to downstream models. + + When evaluated, it runs the source model and returns + ``GenericResult(value=getattr(source_result, field_name))``. + + Multiple extractors from the same source share the source model instance. 
+ If caching is enabled on the evaluator, the source is evaluated only once. + """ + + source: Any # The source CallableModel + field_name: str # The attribute to extract + + @property + def context_type(self): + if isinstance(self.source, _CallableModel): + return self.source.context_type + return ContextBase + + @property + def result_type(self): + return GenericResult + + @Flow.call + def __call__(self, context: ContextBase) -> GenericResult: + # Lazy import: _resolved_deps is a ContextVar that can't be pickled + from .callable import _resolved_deps + + store = _resolved_deps.get() + if id(self.source) in store: + result = store[id(self.source)] + else: + result = self.source(context) + if isinstance(result, GenericResult): + result = result.value + # Support both attribute access and dict key access + if isinstance(result, dict): + return GenericResult(value=result[self.field_name]) + return GenericResult(value=getattr(result, self.field_name)) + + @Flow.deps + def __deps__(self, context: ContextBase) -> GraphDepList: + if isinstance(self.source, _CallableModel): + return [(self.source, [context])] + return [] + + +register_ccflow_import_path(FieldExtractor) diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index 3f613ab..61869f9 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -26,12 +26,12 @@ def test_flow_context_basic(self): assert ctx.end_date == date(2024, 1, 31) def test_flow_context_extra_fields(self): - """FlowContext stores fields in __pydantic_extra__.""" + """FlowContext exposes arbitrary fields through normal model APIs.""" ctx = FlowContext(x=1, y="hello", z=[1, 2, 3]) assert ctx.x == 1 assert ctx.y == "hello" assert ctx.z == [1, 2, 3] - assert ctx.__pydantic_extra__ == {"x": 1, "y": "hello", "z": [1, 2, 3]} + assert dict(ctx) == {"x": 1, "y": "hello", "z": [1, 2, 3]} def test_flow_context_frozen(self): """FlowContext is immutable (frozen).""" diff --git 
a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index b283a2b..b547aee 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -15,6 +15,7 @@ DepOf, Flow, GenericResult, + Lazy, ModelRegistry, ResultBase, ) @@ -599,15 +600,55 @@ def no_return(context: SimpleContext): self.assertIn("return type annotation", str(cm.exception)) - def test_non_result_return_type(self): - """Test error when return type is not ResultBase subclass.""" - with self.assertRaises(TypeError) as cm: + def test_auto_wrap_plain_return_type(self): + """Test that non-ResultBase return types are auto-wrapped in GenericResult.""" - @Flow.model - def bad_return(context: SimpleContext) -> int: - return 42 + @Flow.model + def plain_return(context: SimpleContext) -> int: + return context.value * 2 + + model = plain_return() + result = model(SimpleContext(value=5)) + self.assertIsInstance(result, GenericResult) + self.assertEqual(result.value, 10) + + def test_auto_wrap_unwrap_as_dependency(self): + """Test that auto-wrapped model used as dep delivers unwrapped value downstream. + + Auto-wrapped models have result_type=GenericResult (unparameterized). + When used as an auto-detected dep (no DepOf), the framework resolves + the GenericResult and unwraps .value for the downstream function. 
+ """ + + @Flow.model + def plain_source(context: SimpleContext) -> int: + return context.value * 3 - self.assertIn("ResultBase", str(cm.exception)) + @Flow.model + def consumer( + context: SimpleContext, + data: GenericResult[int], # Auto-detected dep, not DepOf + ) -> GenericResult[int]: + # data is auto-unwrapped to the int value by the framework + return GenericResult(value=data + 1) + + src = plain_source() + model = consumer(data=src) + result = model(SimpleContext(value=10)) + # plain_source: 10 * 3 = 30, auto-wrapped to GenericResult(value=30) + # resolve_callable_model unwraps GenericResult -> 30 + # consumer: 30 + 1 = 31 + self.assertEqual(result.value, 31) + + def test_auto_wrap_result_type_property(self): + """Test that auto-wrapped model has GenericResult as result_type.""" + + @Flow.model + def plain_return(context: SimpleContext) -> int: + return context.value + + model = plain_return() + self.assertEqual(model.result_type, GenericResult) def test_dynamic_deferred_mode(self): """Test dynamic deferred mode where what you provide at construction = bound.""" @@ -953,6 +994,271 @@ def consumer( with self.assertRaises((TypeError, ValidationError)): consumer(data=load2) + def test_config_validation_rejects_bad_type(self): + """Test that config validator rejects wrong types at construction.""" + + @Flow.model + def typed_config(context: SimpleContext, n_estimators: int = 10) -> GenericResult[int]: + return GenericResult(value=n_estimators) + + with self.assertRaises(TypeError) as cm: + typed_config(n_estimators="banana") + + self.assertIn("n_estimators", str(cm.exception)) + + def test_config_validation_accepts_callable_model(self): + """Test that config validator allows CallableModel values for any field.""" + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def consumer(context: SimpleContext, data: int = 0) -> GenericResult[int]: + return GenericResult(value=data) + + 
# Passing a CallableModel for an int field should not raise + src = source() + model = consumer(data=src) + self.assertIsNotNone(model) + + def test_config_validation_accepts_correct_types(self): + """Test that config validator accepts correct types.""" + + @Flow.model + def typed_config(context: SimpleContext, n: int = 10, name: str = "x") -> GenericResult[str]: + return GenericResult(value=f"{name}:{n}") + + # Should not raise + model = typed_config(n=42, name="test") + result = model(SimpleContext(value=1)) + self.assertEqual(result.value, "test:42") + + +# ============================================================================= +# BoundModel Tests +# ============================================================================= + + +class TestBoundModel(TestCase): + """Tests for BoundModel and BoundModel.flow.""" + + def test_bound_model_flow_compute(self): + """Test that bound.flow.compute() honors transforms.""" + + @Flow.model + def my_model(x: int, y: int) -> GenericResult[int]: + return GenericResult(value=x + y) + + model = my_model(x=10) + + # Create bound model with y transform + bound = model.flow.with_inputs(y=lambda ctx: getattr(ctx, "y", 0) * 2) + + # flow.compute() should go through BoundModel, applying transform + result = bound.flow.compute(y=5) + # y transform: 5 * 2 = 10, x is bound to 10 + # model: 10 + 10 = 20 + self.assertEqual(result, 20) + + def test_bound_model_flow_compute_static_transform(self): + """Test BoundModel.flow.compute() with static value transform.""" + + @Flow.model + def my_model(x: int, y: int) -> GenericResult[int]: + return GenericResult(value=x * y) + + model = my_model(x=7) + bound = model.flow.with_inputs(y=3) + + result = bound.flow.compute(y=999) # y should be overridden by transform + # y is statically bound to 3, x=7 + # 7 * 3 = 21 + self.assertEqual(result, 21) + + def test_bound_model_as_dependency(self): + """Test that BoundModel can be passed as a dependency to another model.""" + + @Flow.model + def 
source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 10) + + @Flow.model + def consumer(data: GenericResult[int]) -> GenericResult[int]: + return GenericResult(value=data + 1) + + src = source() + bound_src = src.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 2) + + # Pass BoundModel as a dependency + model = consumer(data=bound_src) + result = model.flow.compute(x=5) + # x transform: 5 * 2 = 10 + # source: 10 * 10 = 100 + # consumer: 100 + 1 = 101 + self.assertEqual(result, 101) + + def test_bound_model_chained_with_inputs(self): + """Test that chaining with_inputs merges transforms correctly.""" + + @Flow.model + def my_model(x: int, y: int, z: int) -> int: + return x + y + z + + model = my_model() + bound1 = model.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 2) + bound2 = bound1.flow.with_inputs(y=lambda ctx: getattr(ctx, "y", 0) * 3) + + # Both transforms should be active + result = bound2.flow.compute(x=5, y=10, z=1) + # x transform: 5 * 2 = 10 + # y transform: 10 * 3 = 30 + # z from context: 1 + # 10 + 30 + 1 = 41 + self.assertEqual(result, 41) + + def test_bound_model_chained_with_inputs_override(self): + """Test that chaining with_inputs allows overriding transforms.""" + + @Flow.model + def my_model(x: int) -> int: + return x + + model = my_model() + bound1 = model.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 2) + bound2 = bound1.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 10) + + # Second transform should override the first for 'x' + result = bound2.flow.compute(x=5) + self.assertEqual(result, 50) # 5 * 10, not 5 * 2 + + def test_bound_model_with_default_args(self): + """with_inputs works when the model has parameters with default values.""" + + @Flow.model + def load(start_date: str, end_date: str, source: str = "warehouse") -> str: + return f"{source}:{start_date}-{end_date}" + + # Bind source at construction, leave dates for context + model = load(source="prod_db") + + # with_inputs transforms a 
context param; default-valued 'source' stays bound + lookback = model.flow.with_inputs(start_date=lambda ctx: "shifted_" + ctx.start_date) + + result = lookback.flow.compute(start_date="2024-01-01", end_date="2024-06-30") + self.assertEqual(result, "prod_db:shifted_2024-01-01-2024-06-30") + + def test_bound_model_with_default_arg_unbound(self): + """with_inputs works when defaulted parameter is left unbound (comes from context).""" + + @Flow.model + def load(start_date: str, source: str = "warehouse") -> str: + return f"{source}:{start_date}" + + # Don't bind 'source' — it keeps its default in the model, + # but in dynamic deferred mode, unbound params come from context + model = load() + + # Transform start_date; source comes from context (overriding the default) + bound = model.flow.with_inputs(start_date=lambda ctx: "shifted_" + ctx.start_date) + + result = bound.flow.compute(start_date="2024-01-01", source="s3_bucket") + self.assertEqual(result, "s3_bucket:shifted_2024-01-01") + + def test_bound_model_default_arg_as_dependency(self): + """BoundModel with default args works correctly as a dependency.""" + + @Flow.model + def source(x: int, multiplier: int = 2) -> int: + return x * multiplier + + @Flow.model + def consumer(data: int) -> int: + return data + 1 + + src = source(multiplier=5) + bound_src = src.flow.with_inputs(x=lambda ctx: ctx.x * 10) + model = consumer(data=bound_src) + + result = model.flow.compute(x=3) + # x transform: 3 * 10 = 30 + # source: 30 * 5 (multiplier) = 150 + # consumer: 150 + 1 = 151 + self.assertEqual(result, 151) + + def test_bound_model_as_lazy_dependency(self): + """Test that BoundModel works as a Lazy dependency.""" + + @Flow.model + def source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 3) + + @Flow.model + def consumer(data: int, slow: Lazy[GenericResult[int]]) -> GenericResult[int]: + if data > 100: + return GenericResult(value=data) + return GenericResult(value=slow()) + + src = source() + bound_src = 
src.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) + 10) + + # Use BoundModel as lazy dependency + model = consumer(data=5, slow=bound_src) + result = model.flow.compute(x=7) + # data=5 < 100, so slow path: x transform: 7+10=17, source: 17*3=51 + self.assertEqual(result, 51) + + +# ============================================================================= +# PEP 563 (from __future__ import annotations) Compatibility Tests +# ============================================================================= + +# These functions are defined at module level to simulate realistic usage. +# Note: We can't use `from __future__ import annotations` at module level +# since it would affect ALL annotations in this file. Instead, we test +# that the annotation resolution code handles string annotations. + + +class TestPEP563Annotations(TestCase): + """Test that Flow.model handles string annotations (PEP 563).""" + + def test_string_annotation_lazy_resolved(self): + """Test that Lazy annotations work even when passed through get_type_hints. + + This verifies the fix for from __future__ import annotations by + confirming the annotation resolution pipeline processes Lazy correctly. 
+ """ + # Verify _extract_lazy handles real type objects (resolved by get_type_hints) + from ccflow.flow_model import _extract_lazy + + lazy_int = Lazy[int] + unwrapped, is_lazy = _extract_lazy(lazy_int) + self.assertTrue(is_lazy) + self.assertEqual(unwrapped, int) + + def test_string_annotation_return_type_resolved(self): + """Test that string return type annotations are resolved correctly.""" + + @Flow.model + def model_func(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=42) + + # If annotation resolution works, this should create successfully + model = model_func() + self.assertEqual(model.result_type, GenericResult[int]) + + def test_auto_wrap_with_resolved_annotations(self): + """Test that auto-wrap works with properly resolved type annotations.""" + + @Flow.model + def plain_model(value: int) -> int: + return value * 2 + + model = plain_model() + result = model.flow.compute(value=5) + self.assertEqual(result, 10) + self.assertEqual(model.result_type, GenericResult) + # ============================================================================= # Hydra Integration Tests @@ -1555,6 +1861,496 @@ def __deps__(self, context: SimpleContext): self.assertEqual(call_counts["class_model"], 1) +# ============================================================================= +# Lazy[T] Type Annotation Tests +# ============================================================================= + + +class TestLazyTypeAnnotation(TestCase): + """Tests for Lazy[T] type annotation (deferred/conditional evaluation).""" + + def test_lazy_type_annotation_basic(self): + """Lazy[T] param receives a thunk (zero-arg callable). + + The thunk unwraps GenericResult.value, so calling thunk() returns + the inner value (e.g., int), not the GenericResult wrapper. 
+ """ + from ccflow import Lazy + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + # data() returns the unwrapped value (int) + resolved = data() + return GenericResult(value=resolved + 1) + + src = source() + model = consumer(data=src) + result = model(SimpleContext(value=5)) + + # source: 5 * 10 = 50, consumer: 50 + 1 = 51 + self.assertEqual(result.value, 51) + + def test_lazy_conditional_evaluation(self): + """Mirror the smart_training example: lazy dep only evaluated if needed. + + Note: Non-lazy CallableModel deps are auto-resolved and their .value is + unwrapped by the framework (auto-detected dep resolution). So 'fast' + receives the unwrapped int, while 'slow' receives a thunk that returns + the unwrapped value (GenericResult.value) when called. + """ + from ccflow import Lazy + + call_counts = {"fast": 0, "slow": 0} + + @Flow.model + def fast_path(context: SimpleContext) -> GenericResult[int]: + call_counts["fast"] += 1 + return GenericResult(value=context.value) + + @Flow.model + def slow_path(context: SimpleContext) -> GenericResult[int]: + call_counts["slow"] += 1 + return GenericResult(value=context.value * 100) + + @Flow.model + def smart_selector( + context: SimpleContext, + fast: GenericResult[int], # Auto-resolved: receives unwrapped int + slow: Lazy[GenericResult[int]], # Lazy: receives thunk returning unwrapped value + threshold: int = 10, + ) -> GenericResult[int]: + # fast is auto-unwrapped to the int value by the framework + if fast > threshold: + return GenericResult(value=fast) + else: + return GenericResult(value=slow()) + + fast = fast_path() + slow = slow_path() + + # Case 1: fast path sufficient (value > threshold) + model = smart_selector(fast=fast, slow=slow, threshold=10) + result = model(SimpleContext(value=20)) + 
self.assertEqual(result.value, 20) + self.assertEqual(call_counts["fast"], 1) + self.assertEqual(call_counts["slow"], 0) # Never called! + + # Case 2: fast path insufficient (value <= threshold), slow triggered + call_counts["fast"] = 0 + model2 = smart_selector(fast=fast, slow=slow, threshold=100) + result2 = model2(SimpleContext(value=5)) + self.assertEqual(result2.value, 500) # 5 * 100 + self.assertEqual(call_counts["fast"], 1) + self.assertEqual(call_counts["slow"], 1) + + def test_lazy_thunk_caches_result(self): + """Repeated calls to a thunk return the same value without re-evaluation.""" + from ccflow import Lazy + + call_counts = {"source": 0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + # Call thunk multiple times — returns the unwrapped int + val1 = data() + val2 = data() + val3 = data() + self.assertEqual(val1, val2) + self.assertEqual(val2, val3) + return GenericResult(value=val1) + + src = source() + model = consumer(data=src) + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 50) + self.assertEqual(call_counts["source"], 1) # Called only once despite 3 thunk() calls + + def test_lazy_with_direct_value(self): + """Pre-computed (non-CallableModel) value wrapped in trivial thunk.""" + from ccflow import Lazy + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[int], + ) -> GenericResult[int]: + # data is a thunk even though the underlying value is a plain int + return GenericResult(value=data() * 2) + + model = consumer(data=42) + result = model(SimpleContext(value=0)) + self.assertEqual(result.value, 84) + + def test_lazy_dep_excluded_from_deps(self): + """__deps__ does NOT include lazy dependencies.""" + from ccflow import Lazy + + @Flow.model + def eager_source(context: SimpleContext) 
-> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model + def lazy_source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + eager: GenericResult[int], # Auto-resolved, unwrapped to int + lazy_dep: Lazy[GenericResult[int]], # Thunk, returns unwrapped value + ) -> GenericResult[int]: + return GenericResult(value=eager + lazy_dep()) + + eager = eager_source() + lazy = lazy_source() + model = consumer(eager=eager, lazy_dep=lazy) + + ctx = SimpleContext(value=5) + deps = model.__deps__(ctx) + + # Only eager dep should be in __deps__ + self.assertEqual(len(deps), 1) + self.assertIs(deps[0][0], eager) + + def test_lazy_eager_dep_still_pre_evaluated(self): + """Non-lazy deps are still eagerly resolved via __deps__.""" + from ccflow import Lazy + + call_counts = {"eager": 0, "lazy": 0} + + @Flow.model + def eager_source(context: SimpleContext) -> GenericResult[int]: + call_counts["eager"] += 1 + return GenericResult(value=context.value) + + @Flow.model + def lazy_source(context: SimpleContext) -> GenericResult[int]: + call_counts["lazy"] += 1 + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + eager: GenericResult[int], # Auto-resolved, unwrapped to int + lazy_dep: Lazy[GenericResult[int]], # Thunk, returns unwrapped value + ) -> GenericResult[int]: + # eager is auto-unwrapped to int, lazy_dep() returns unwrapped value + return GenericResult(value=eager + lazy_dep()) + + model = consumer(eager=eager_source(), lazy_dep=lazy_source()) + result = model(SimpleContext(value=5)) + + self.assertEqual(result.value, 55) # 5 + 50 + self.assertEqual(call_counts["eager"], 1) + self.assertEqual(call_counts["lazy"], 1) + + def test_lazy_in_dynamic_deferred_mode(self): + """Lazy[T] works in dynamic deferred mode (no context_args).""" + from ccflow import FlowContext, Lazy + + call_counts = {"source": 
0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + value: int, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + if value > 10: + return GenericResult(value=value) + return GenericResult(value=data()) # data() returns unwrapped int + + # value comes from context, data is bound at construction + model = consumer(data=source()) + result = model(FlowContext(value=20)) # value > 10, lazy not called + self.assertEqual(result.value, 20) + self.assertEqual(call_counts["source"], 0) + + def test_lazy_in_context_args_mode(self): + """Lazy[T] works with explicit context_args.""" + from ccflow import FlowContext, Lazy + + @Flow.model(context_args=["x"]) + def source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 10) + + @Flow.model(context_args=["x"]) + def consumer( + x: int, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + return GenericResult(value=x + data()) # data() returns unwrapped int + + model = consumer(data=source()) + result = model(FlowContext(x=5)) + self.assertEqual(result.value, 55) # 5 + 50 + + def test_lazy_never_evaluated_if_not_called(self): + """If thunk is never called, the dependency is never evaluated.""" + from ccflow import Lazy + + call_counts = {"source": 0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value) + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[GenericResult[int]], + ) -> GenericResult[int]: + # Never call data() + return GenericResult(value=42) + + model = consumer(data=source()) + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 42) + self.assertEqual(call_counts["source"], 0) + + def test_lazy_with_depof(self): + """Lazy[DepOf[...]] works: lazy dep with explicit DepOf annotation.""" + from ccflow import Lazy + + 
@Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + @Flow.model + def consumer( + context: SimpleContext, + data: Lazy[DepOf[..., GenericResult[int]]], + ) -> GenericResult[int]: + return GenericResult(value=data() + 1) # data() returns unwrapped int + + src = source() + model = consumer(data=src) + + # Lazy dep should NOT be in __deps__ + deps = model.__deps__(SimpleContext(value=5)) + self.assertEqual(len(deps), 0) + + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 51) # 50 + 1 + + +# ============================================================================= +# FieldExtractor Tests (Structured Output Field Access) +# ============================================================================= + + +class TestFieldExtractor(TestCase): + """Tests for structured output field access (prepared.X_train pattern).""" + + def test_field_extraction_basic(self): + """Accessing unknown attr on @Flow.model instance returns FieldExtractor.""" + from ccflow.flow_model import FieldExtractor + + @Flow.model + def prepare(context: SimpleContext, factor: int = 2) -> GenericResult[dict]: + return GenericResult(value={"X_train": context.value * factor, "X_test": context.value}) + + model = prepare(factor=3) + extractor = model.X_train + + self.assertIsInstance(extractor, FieldExtractor) + self.assertIs(extractor.source, model) + self.assertEqual(extractor.field_name, "X_train") + + def test_field_extraction_evaluates_correctly(self): + """FieldExtractor runs source and extracts the named field.""" + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + return GenericResult(value={"X_train": [1, 2, 3], "y_train": [4, 5, 6]}) + + model = prepare() + x_train = model.X_train + + result = x_train(SimpleContext(value=0)) + self.assertEqual(result.value, [1, 2, 3]) + + def test_field_extraction_as_dependency(self): + """FieldExtractor wired as a dep to a downstream 
model. + + Note: FieldExtractors are CallableModels, so they're auto-detected as deps + and auto-unwrapped (GenericResult.value). The downstream function receives + the raw extracted value, not a GenericResult wrapper. + """ + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + v = context.value + return GenericResult(value={"X_train": [v, v * 2], "y_train": [v * 10]}) + + @Flow.model + def train(context: SimpleContext, X: list, y: list) -> GenericResult[int]: + # X and y are auto-unwrapped to the raw list values + return GenericResult(value=sum(X) + sum(y)) + + prepared = prepare() + model = train(X=prepared.X_train, y=prepared.y_train) + + result = model(SimpleContext(value=5)) + # X_train = [5, 10], y_train = [50] + # sum(X) + sum(y) = 15 + 50 = 65 + self.assertEqual(result.value, 65) + + def test_field_extraction_multiple_from_same_source(self): + """Multiple extractors from same source share the source instance.""" + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + return GenericResult(value={"a": 1, "b": 2, "c": 3}) + + model = prepare() + ext_a = model.a + ext_b = model.b + ext_c = model.c + + # All should reference the same source + self.assertIs(ext_a.source, model) + self.assertIs(ext_b.source, model) + self.assertIs(ext_c.source, model) + + # All should evaluate correctly + ctx = SimpleContext(value=0) + self.assertEqual(ext_a(ctx).value, 1) + self.assertEqual(ext_b(ctx).value, 2) + self.assertEqual(ext_c(ctx).value, 3) + + def test_field_extraction_nested(self): + """Chained extraction (result.a.b) creates nested FieldExtractors.""" + from ccflow.flow_model import FieldExtractor + + class Nested: + def __init__(self): + self.inner_val = 42 + + @Flow.model + def produce(context: SimpleContext) -> GenericResult: + return GenericResult(value={"nested": Nested()}) + + model = produce() + nested_extractor = model.nested + inner_extractor = nested_extractor.inner_val + + 
self.assertIsInstance(nested_extractor, FieldExtractor) + self.assertIsInstance(inner_extractor, FieldExtractor) + + result = inner_extractor(SimpleContext(value=0)) + self.assertEqual(result.value, 42) + + def test_field_extraction_context_type_inherited(self): + """FieldExtractor inherits context_type from source.""" + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + return GenericResult(value={"x": 1}) + + model = prepare() + extractor = model.x + + self.assertEqual(extractor.context_type, SimpleContext) + + def test_field_extraction_nonexistent_field_runtime_error(self): + """Non-existent field raises error at evaluation time, not construction. + + For dict results, raises KeyError. For object results, raises AttributeError. + """ + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + return GenericResult(value={"x": 1}) + + model = prepare() + extractor = model.nonexistent # No error at construction + + # Error at evaluation time (KeyError for dicts, AttributeError for objects) + with self.assertRaises((KeyError, AttributeError)): + extractor(SimpleContext(value=0)) + + def test_field_extraction_pydantic_fields_not_intercepted(self): + """Accessing real pydantic fields returns the field value, NOT an extractor.""" + from ccflow.flow_model import FieldExtractor + + @Flow.model + def model_with_fields(context: SimpleContext, multiplier: int = 5) -> GenericResult[int]: + return GenericResult(value=context.value * multiplier) + + model = model_with_fields(multiplier=10) + + # 'multiplier' is a real pydantic field — should return the value, not a FieldExtractor + self.assertEqual(model.multiplier, 10) + self.assertNotIsInstance(model.multiplier, FieldExtractor) + + # 'meta' is inherited from CallableModel — should also not be intercepted + self.assertNotIsInstance(model.meta, FieldExtractor) + + def test_field_extraction_with_context_args(self): + """FieldExtractor works with context_args mode models.""" + from 
ccflow import FlowContext + + @Flow.model(context_args=["x"]) + def prepare(x: int) -> GenericResult[dict]: + return GenericResult(value={"doubled": x * 2, "tripled": x * 3}) + + model = prepare() + doubled = model.doubled + + result = doubled(FlowContext(x=5)) + self.assertEqual(result.value, 10) + + def test_field_extraction_has_flow_property(self): + """FieldExtractor has .flow property (inherits from CallableModel).""" + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + return GenericResult(value={"x": 1}) + + model = prepare() + extractor = model.x + + self.assertTrue(hasattr(extractor, "flow")) + + def test_field_extraction_deps(self): + """FieldExtractor.__deps__ returns the source as a dependency.""" + + @Flow.model + def prepare(context: SimpleContext) -> GenericResult[dict]: + return GenericResult(value={"x": 1}) + + model = prepare() + extractor = model.x + + ctx = SimpleContext(value=0) + deps = extractor.__deps__(ctx) + + self.assertEqual(len(deps), 1) + self.assertIs(deps[0][0], model) + self.assertEqual(deps[0][1], [ctx]) + + if __name__ == "__main__": import unittest From 097ae6220cb7b2c70d897c806ab5caf885015199 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Tue, 17 Mar 2026 05:52:31 -0400 Subject: [PATCH 09/26] Update docs for @Flow.model Signed-off-by: Nijat Khanbabayev --- docs/design/flow_model_design.md | 141 ++++++++++++++++++++++--------- docs/wiki/Key-Features.md | 110 ++++++++++++++++++------ 2 files changed, 183 insertions(+), 68 deletions(-) diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md index 76d0eb7..909b597 100644 --- a/docs/design/flow_model_design.md +++ b/docs/design/flow_model_design.md @@ -6,38 +6,55 @@ This document describes the `@Flow.model` decorator and `DepOf` annotation syste **Key features:** - `@Flow.model` - Decorator that generates `CallableModel` classes from plain functions +- `FlowContext` - Universal context carrier for unpacked/deferred execution 
+- `model.flow.compute(...)` / `model.flow.with_inputs(...)` - Deferred execution helpers - `DepOf[ContextType, ResultType]` - Type annotation for dependency fields +- `Lazy[T]` - Mark a dependency for lazy, on-demand evaluation +- `FieldExtractor` - Access structured outputs via attribute access on generated models - `resolve()` - Function to access resolved dependency values in class-based models ## Quick Start -### Pattern 1: `@Flow.model` (Recommended for Simple Cases) +### Pattern 1: `@Flow.model` (Recommended for Declarative Cases) ```python from datetime import date, timedelta from typing import Annotated -from ccflow import Flow, DateRangeContext, GenericResult, DepOf +from ccflow import Flow, DateRangeContext, GenericResult, Dep, DepOf + + +def previous_window(ctx: DateRangeContext) -> DateRangeContext: + window = ctx.end_date - ctx.start_date + return ctx.model_copy( + update={ + "start_date": ctx.start_date - window - timedelta(days=1), + "end_date": ctx.start_date - timedelta(days=1), + } + ) @Flow.model -def load_records(context: DateRangeContext, source: str) -> GenericResult[dict]: - return GenericResult(value={"count": 100, "date": str(context.start_date)}) +def load_revenue(context: DateRangeContext, region: str) -> GenericResult[float]: + return GenericResult(value=125.0) @Flow.model -def compute_stats( +def revenue_growth( context: DateRangeContext, - records: DepOf[..., GenericResult[dict]], # Dependency field -) -> GenericResult[float]: - # records is already resolved - just use it directly - return GenericResult(value=records.value["count"] * 0.05) - -# Build pipeline -loader = load_records(source="main_db") -stats = compute_stats(records=loader) + current: DepOf[..., GenericResult[float]], + previous: Annotated[GenericResult[float], Dep(transform=previous_window)], +) -> GenericResult[dict]: + growth = (current.value - previous.value) / previous.value + return GenericResult(value={"as_of": context.end_date, "growth": growth}) + +# Build 
pipeline. The same upstream model is reused twice: +# - once with the original context +# - once with a fixed lookback transform +revenue = load_revenue(region="us") +growth = revenue_growth(current=revenue, previous=revenue) # Execute ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) -result = stats(ctx) +result = growth(ctx) ``` ### Pattern 2: Class-Based (For Complex Cases) @@ -50,17 +67,17 @@ from datetime import timedelta from ccflow import CallableModel, DateRangeContext, Flow, GenericResult, DepOf from ccflow.callable import resolve # Import resolve for class-based models -class AggregateWithWindow(CallableModel): - """Aggregate records with configurable lookback window.""" +class RevenueAverageWithWindow(CallableModel): + """Aggregate revenue with a configurable lookback window.""" - records: DepOf[..., GenericResult[dict]] + revenue: DepOf[..., GenericResult[float]] window: int = 7 # Configurable instance field @Flow.call def __call__(self, context: DateRangeContext) -> GenericResult[float]: # Use resolve() to get the resolved value - records = resolve(self.records) - return GenericResult(value=records.value["count"] / self.window) + revenue = resolve(self.revenue) + return GenericResult(value=revenue.value / self.window) @Flow.deps def __deps__(self, context: DateRangeContext): @@ -68,22 +85,22 @@ class AggregateWithWindow(CallableModel): lookback_ctx = context.model_copy( update={"start_date": context.start_date - timedelta(days=self.window)} ) - return [(self.records, [lookback_ctx])] + return [(self.revenue, [lookback_ctx])] # Usage - different window sizes, same source -loader = load_records(source="main_db") -agg_7 = AggregateWithWindow(records=loader, window=7) -agg_30 = AggregateWithWindow(records=loader, window=30) +loader = load_revenue(region="us") +avg_7 = RevenueAverageWithWindow(revenue=loader, window=7) +avg_30 = RevenueAverageWithWindow(revenue=loader, window=30) ``` ## When to Use Which Pattern -| Use 
`@Flow.model` when... | Use Class-Based when... | -|--------------------------------|--------------------------------------| -| Simple transformations | Transforms depend on instance fields | -| Fixed context transforms | Need `self.field` in `__deps__` | -| Less boilerplate is priority | Full control over resolution | -| No custom `__deps__` logic | Complex dependency patterns | +| Use `@Flow.model` when... | Use Class-Based when... | +|--------------------------------|---------------------------------------| +| The node still reads like a normal function | The main value is custom graph logic | +| Transforms are fixed/declarative | Transforms depend on instance fields | +| Less boilerplate is priority | You need full control over `__deps__` | +| Dependency wiring fits in the signature | Dependency behavior deserves its own class | ## Core Concepts @@ -104,6 +121,17 @@ data: DepOf[DateRangeContext, GenericResult[dict]] data: Annotated[Union[GenericResult[dict], CallableModel], Dep()] ``` +For `@Flow.model`, plain non-`DepOf` parameters can also be populated with a +`CallableModel` instance. That lets callers either inject a concrete value or +splice in an upstream computation for the same parameter. Use `Dep`/`DepOf` +when you need explicit dependency metadata such as context transforms or +context-type validation. + +That means `DepOf` inside `@Flow.model` is most compelling when the function is +still doing real work and the dependency relationship is simple. If the node is +mostly a vessel for custom dependency graph wiring, a hand-written +`CallableModel` is usually clearer. + ### `Dep(transform=..., context_type=...)` For transforms, use the full `Annotated` form: @@ -158,12 +186,12 @@ resolved = resolve(self.data) # Type: GenericResult[int] 1. User calls `model(context)` 2. Generated `__call__` invokes `_resolve_deps_and_call()` -3. For each `DepOf` field containing a `CallableModel`: +3. 
For each dependency-bearing field containing a `CallableModel`: - Apply transform (if any) - Call the dependency - Store resolved value in context variable -4. Generated `__call__` retrieves resolved values via `resolve()` -5. Original function receives resolved values as arguments +4. Generated `__call__` reads the resolved values from the dependency store +5. Original function receives resolved values directly as normal function arguments ### Class-Based Resolution Flow @@ -171,6 +199,7 @@ resolved = resolve(self.data) # Type: GenericResult[int] 2. `_resolve_deps_and_call()` runs 3. For each `DepOf` field containing a `CallableModel`: - Check `__deps__` for custom transforms + - If not listed in `__deps__`, fall back to the field's `Dep(...)` transform (or the original context) - Call the dependency - Store resolved value in context variable 4. User's `__call__` accesses values via `resolve(self.field)` @@ -211,14 +240,18 @@ resolved = resolve(self.data) # Type: GenericResult[int] - Keeps top-level namespace clean - Users who need it can find it easily -### Decision 4: No Auto-Wrapping Return Values +### Decision 4: Auto-Wrap Plain Return Values -**What we chose:** Functions must explicitly return `ResultBase` subclass. +**What we chose:** If the function's declared return type is not a `ResultBase` +subclass, the generated model wraps the returned value in `GenericResult`. **Why:** -- Type annotations remain honest -- Consistent with existing `CallableModel` contract -- `GenericResult(value=x)` is minimal overhead +- Reduces boilerplate for simple scalar / container-returning functions +- Preserves the `CallableModel` contract that runtime results are `ResultBase` +- Still allows explicit `ResultBase` subclasses when you want a precise result type + +**Trade-off:** The original Python function may be annotated with a plain value +type while the generated model's runtime `result_type` is `GenericResult`. 
### Decision 5: Generated Classes Are Real CallableModels @@ -290,23 +323,47 @@ Users need to remember: - `@Flow.model`: Use dependency values directly as function arguments - Class-based: Use `resolve(self.field)` to access values -### Limitation: `__deps__` Still Required for Class-Based +### Limitation: Custom `__deps__` Is Only Needed for Custom Graph Logic -Even without transforms, class-based models need `__deps__`: +Class-based models do not need a custom `__deps__` override when the default +field-level `Dep(...)` behavior is sufficient. Override `__deps__` only when +you need instance-dependent transforms or a custom dependency graph: ```python class Consumer(CallableModel): data: DepOf[..., GenericResult[int]] + @Flow.call + def __call__(self, context): + return GenericResult(value=resolve(self.data).value) +``` + +If you do need to use instance fields in the transform, then `__deps__` is the +right place to do it: + +```python +class WindowedConsumer(CallableModel): + data: DepOf[..., GenericResult[int]] + window: int = 7 + @Flow.call def __call__(self, context): return GenericResult(value=resolve(self.data).value) @Flow.deps def __deps__(self, context): - return [(self.data, [context])] # Boilerplate, but required + shifted = context.model_copy(update={"value": context.value + self.window}) + return [(self.data, [shifted])] ``` +### Limitation: `context_args` Type Matching Is Best-Effort + +When you use `context_args=[...]`, the framework validates those fields via a +runtime `TypedDict` schema. It only maps to a concrete built-in context type in +special cases such as `DateRangeContext`. Otherwise the generated model's +`context_type` is `FlowContext`, a universal frozen carrier for the validated +context values. + ## Complete Example: Multi-Stage Pipeline ```python @@ -397,6 +454,10 @@ def my_function(context: ContextType, ...) -> ResultType: ... 
``` +If the function is annotated with a plain value type instead of a `ResultBase` +subclass, the generated model will wrap the returned value in `GenericResult` +at runtime. + ### `DepOf[ContextType, ResultType]` ```python diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index a89d8f8..f73ac6b 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -49,46 +49,75 @@ result = loader(ctx) Use `Dep()` or `DepOf` to mark parameters that accept other `CallableModel`s as dependencies. The framework automatically resolves the dependency graph. -> **Tip:** If your function doesn't use the context directly (only passes it to dependencies), use `_` as the parameter name to signal this: `def my_func(_: DateRangeContext, data: DepOf[..., ResultType])`. This is a Python convention for intentionally unused parameters. +For `@Flow.model`, regular parameters can also accept a `CallableModel` value at +construction time. This lets you either inject a literal value or splice in an +upstream computation for the same parameter. Use `Dep`/`DepOf` when you need +context transforms or explicit dependency metadata. + +> **Rule of thumb:** `@Flow.model` works best when the dependency wiring is declarative and local to the signature. If the main point of the node is custom graph logic or transforms that depend on instance fields, use a class-based `CallableModel` instead. 
```python from datetime import date, timedelta from typing import Annotated from ccflow import Flow, GenericResult, DateRangeContext, Dep, DepOf -@Flow.model -def load_data(context: DateRangeContext, source: str) -> GenericResult[dict]: - return GenericResult(value={"records": [1, 2, 3]}) +def previous_window(ctx: DateRangeContext) -> DateRangeContext: + window = ctx.end_date - ctx.start_date + return ctx.model_copy( + update={ + "start_date": ctx.start_date - window - timedelta(days=1), + "end_date": ctx.start_date - timedelta(days=1), + } + ) @Flow.model -def transform_data( - _: DateRangeContext, # Context passed to dependency, not used directly - raw_data: Annotated[GenericResult[dict], Dep( - # Transform context to fetch one extra day for lookback - transform=lambda ctx: ctx.model_copy(update={ - "start_date": ctx.start_date - timedelta(days=1) - }) - )] -) -> GenericResult[dict]: - # raw_data.value contains the resolved result from load_data - return GenericResult(value={"transformed": raw_data.value["records"]}) +def load_revenue(context: DateRangeContext, region: str) -> GenericResult[float]: + # Pretend this queries a warehouse + return GenericResult(value=125.0) -# Or use DepOf shorthand (no transform needed): @Flow.model -def aggregate_data( - _: DateRangeContext, # Context passed to dependency, not used directly - transformed: DepOf[..., GenericResult[dict]] # Shorthand for Annotated[T, Dep()] +def revenue_growth( + context: DateRangeContext, + current: DepOf[..., GenericResult[float]], + previous: Annotated[GenericResult[float], Dep(transform=previous_window)], ) -> GenericResult[dict]: - return GenericResult(value={"count": len(transformed.value["transformed"])}) + growth = (current.value - previous.value) / previous.value + return GenericResult(value={"as_of": context.end_date, "growth": growth}) -# Build the pipeline -data = load_data(source="my_database") -transformed = transform_data(raw_data=data) -aggregated = 
aggregate_data(transformed=transformed) +# Build the pipeline. The same loader is reused with two contexts: +# - current window: original context +# - previous window: transformed via Dep(transform=...) +revenue = load_revenue(region="us") +growth = revenue_growth(current=revenue, previous=revenue) -# Execute - dependencies are automatically resolved ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) -result = aggregated(ctx) +result = growth(ctx) +``` + +`DepOf` is also useful when you want the same parameter to accept either an +upstream model or a precomputed value: + +```python +from ccflow import DateRangeContext, DepOf, Flow, GenericResult + +@Flow.model +def load_signal(context: DateRangeContext, source: str) -> GenericResult[float]: + return GenericResult(value=0.87) + +@Flow.model +def publish_signal( + context: DateRangeContext, + signal: DepOf[..., GenericResult[float]], + threshold: float = 0.8, +) -> GenericResult[dict]: + return GenericResult(value={ + "as_of": context.end_date, + "signal": signal.value, + "go_live": signal.value >= threshold, + }) + +live = publish_signal(signal=load_signal(source="prod")) +override = publish_signal(signal=GenericResult(value=0.95), threshold=0.9) ``` **Hydra/YAML Configuration:** @@ -126,7 +155,7 @@ from ccflow import Flow, GenericResult, DateRangeContext def load_data(start_date: date, end_date: date, source: str) -> GenericResult[str]: return GenericResult(value=f"{source}:{start_date} to {end_date}") -# The decorator infers DateRangeContext from the parameter types +# The decorator matches common built-in context types when possible loader = load_data(source="my_database") assert loader.context_type == DateRangeContext @@ -135,7 +164,32 @@ ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) result = loader(ctx) # "my_database:2024-01-01 to 2024-01-31" ``` -The `context_args` parameter specifies which function parameters should be extracted from the context. 
The framework automatically determines the context type based on the parameter type annotations. +The `context_args` parameter specifies which function parameters should be extracted from the context. Those fields are validated through a runtime schema built from the parameter annotations. For well-known shapes such as `start_date` / `end_date`, the generated model uses a concrete built-in context type like `DateRangeContext`; otherwise it uses `FlowContext`, a universal frozen carrier for the validated fields. + +**Deferred Execution Helpers:** + +Generated models also expose a `.flow` helper namespace: + +```python +from ccflow import Flow, GenericResult + +@Flow.model +def add(x: int, y: int) -> GenericResult[int]: + return GenericResult(value=x + y) + +model = add(x=10) + +# Validate and execute by passing context fields as kwargs +assert model.flow.compute(y=5) == 15 + +# Derive a new model by transforming context inputs +shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) +assert shifted.flow.compute(y=5) == 20 +``` + +If a `@Flow.model` function returns a plain value instead of a `ResultBase` +subclass, the generated model automatically wraps that value in `GenericResult` +at runtime so it still behaves like a normal `CallableModel`. 
## Model Registry From 3d26896c245f2568b3cbd1aa8f3b7da74699218c Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Tue, 17 Mar 2026 15:36:19 -0400 Subject: [PATCH 10/26] Clean up, ty check @Flow.model, add test Signed-off-by: Nijat Khanbabayev --- ccflow/__init__.py | 1 - ccflow/callable.py | 219 +----- ccflow/dep.py | 278 -------- ccflow/flow_model.py | 222 +++--- ccflow/tests/config/conf_flow.yaml | 2 +- ccflow/tests/test_flow_model.py | 927 ++++---------------------- ccflow/tests/test_flow_model_hydra.py | 16 +- docs/design/flow_model_design.md | 552 ++++----------- docs/wiki/Key-Features.md | 200 +++--- examples/flow_model_example.py | 246 ++----- 10 files changed, 565 insertions(+), 2098 deletions(-) delete mode 100644 ccflow/dep.py diff --git a/ccflow/__init__.py b/ccflow/__init__.py index 4dbe143..1bb69fe 100644 --- a/ccflow/__init__.py +++ b/ccflow/__init__.py @@ -10,7 +10,6 @@ from .compose import * from .callable import * from .context import * -from .dep import * from .enums import Enum from .flow_model import * from .global_state import * diff --git a/ccflow/callable.py b/ccflow/callable.py index d3b22e4..fd849c5 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -14,7 +14,6 @@ import abc import inspect import logging -from contextvars import ContextVar from functools import lru_cache, wraps from inspect import Signature, isclass, signature from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin @@ -29,7 +28,6 @@ ResultBase, ResultType, ) -from .dep import Dep, extract_dep from .local_persistence import create_ccflow_model from .validators import str_to_log_level @@ -48,8 +46,6 @@ "EvaluatorBase", "Evaluator", "WrapperModel", - # Note: resolve() is intentionally not in __all__ to avoid namespace pollution. 
- # Users who need it can import explicitly: from ccflow.callable import resolve ) log = logging.getLogger(__name__) @@ -199,164 +195,6 @@ def _get_logging_evaluator(log_level): return LoggingEvaluator(log_level=log_level) -def _get_dep_fields(model_class) -> Dict[str, Dep]: - """Analyze class fields to find Dep-annotated fields. - - Returns a dict mapping field name to Dep instance for fields that need resolution. - """ - dep_fields = {} - - # Get type hints from the class - hints = {} - for cls in model_class.__mro__: - if hasattr(cls, "__annotations__"): - for name, annotation in cls.__annotations__.items(): - if name not in hints: # Don't override child class annotations - hints[name] = annotation - - for name, annotation in hints.items(): - base_type, dep = extract_dep(annotation) - if dep is not None: - dep_fields[name] = dep - - return dep_fields - - -def _wrap_with_dep_resolution(fn): - """Wrap a function to auto-resolve DepOf fields before calling. - - For each Dep-annotated field on the model that contains a CallableModel, - resolves it using __deps__ and temporarily sets the resolved value on self. - - Note: This wrapper is only applied at runtime when the function is called, - not during decoration. This avoids issues with functools.wraps flattening - the __wrapped__ chain. - - Args: - fn: The original function - - Returns: - The original function unchanged - dep resolution happens at the call site - """ - # Don't modify the function - dep resolution is handled in ModelEvaluationContext - return fn - - -# Context variable for storing resolved dependency values during __call__ -# Maps id(callable_model) -> resolved_value -_resolved_deps: ContextVar[Dict[int, Any]] = ContextVar("resolved_deps", default={}) - -# TypeVar for resolve() function to enable proper type inference -_T = TypeVar("_T") - - -def resolve(dep: Union[_T, "_CallableModel"]) -> _T: - """Access the resolved value of a DepOf dependency during __call__. 
- - This function is used inside a CallableModel's __call__ method to get - the resolved value of a dependency field. It provides proper type inference - - if the field is `DepOf[..., GenericResult[int]]`, this returns `GenericResult[int]`. - - Args: - dep: The dependency field value (either a CallableModel or already-resolved value) - - Returns: - The resolved value. If dep is already a resolved value (not a CallableModel), - returns it unchanged. - - Raises: - RuntimeError: If called outside of __call__ or if the dependency wasn't resolved. - - Example: - class MyModel(CallableModel): - data: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: MyContext) -> GenericResult[int]: - # resolve() provides proper type inference - data = resolve(self.data) # type: GenericResult[int] - return GenericResult(value=data.value + 1) - """ - # If it's not a CallableModel, it's already a resolved value - pass through - if not isinstance(dep, _CallableModel): - return dep # type: ignore[return-value] - - # Look up in context var - store = _resolved_deps.get() - dep_id = id(dep) - if dep_id not in store: - raise RuntimeError( - "resolve() can only be used inside __call__ for DepOf fields. Make sure the field is annotated with DepOf and contains a CallableModel." - ) - return store[dep_id] - - -def _resolve_deps_and_call(model, context, fn): - """Resolve DepOf fields and call the function. - - This is called from ModelEvaluationContext.__call__ to handle dep resolution. - Resolved values are stored in a context variable and accessed via resolve(). 
- - Args: - model: The CallableModel instance - context: The context to pass to the function - fn: The function to call - - Returns: - The result of calling fn(model, context) - """ - # Don't resolve deps for __deps__ method - if fn.__name__ == "__deps__": - return fn(model, context) - - # Get Dep-annotated fields for this model class - dep_fields = _get_dep_fields(model.__class__) - - if not dep_fields: - return fn(model, context) - - # Get dependencies from __deps__ - deps_result = model.__deps__(context) - # Build a map from model instance id to (model, contexts) for lookup - dep_map = {} - for dep_model, contexts in deps_result: - dep_map[id(dep_model)] = (dep_model, contexts) - - # Resolve dependencies and store in context var - resolved_values = {} - - # Standard path: iterate over Dep-annotated fields - for field_name, dep in dep_fields.items(): - field_value = getattr(model, field_name, None) - if field_value is None: - continue - - # Check if field is a CallableModel that needs resolution - if not isinstance(field_value, _CallableModel): - continue # Already a resolved value, skip - - # Check if this field is in __deps__ (for custom transforms) - if id(field_value) in dep_map: - dep_model, contexts = dep_map[id(field_value)] - # Call dependency with the (transformed) context - resolved = dep_model(contexts[0]) if contexts else dep_model(context) - else: - # Not in __deps__, use Dep annotation transform directly - transformed_ctx = dep.apply(context) - resolved = field_value(transformed_ctx) - - resolved_values[id(field_value)] = resolved - - # Store in context var and call function - current_store = _resolved_deps.get() - new_store = {**current_store, **resolved_values} - token = _resolved_deps.set(new_store) - try: - return fn(model, context) - finally: - _resolved_deps.reset(token) - - class FlowOptions(BaseModel): """Options for Flow evaluation. 
@@ -408,9 +246,6 @@ def get_evaluator(self, model: CallableModelType) -> "EvaluatorBase": return self._get_evaluator_from_options(options) def __call__(self, fn): - # Wrap function with dependency resolution for DepOf fields - fn = _wrap_with_dep_resolution(fn) - # Used for building a graph of model evaluation contexts without evaluating def get_evaluation_context(model: CallableModelType, context: ContextType, as_dict: bool = False, *, _options: Optional[FlowOptions] = None): # Create the evaluation context. @@ -617,32 +452,6 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult: # The generated context inherits from DateContext, so it's compatible # with infrastructure expecting DateContext instances. - Auto-Resolve Dependencies Example: - When __call__ has parameters beyond 'self' and 'context' that match field - names annotated with DepOf/Dep, those dependencies are automatically resolved - using __deps__ (if defined) or auto-generated from Dep annotations. - - class MyModel(CallableModel): - data: Annotated[GenericResult[dict], Dep(transform=my_transform)] - - @Flow.call - def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]: - # data is automatically resolved - no manual calling needed - return GenericResult(value=process(data.value)) - - For transforms that need access to instance fields, define __deps__ manually: - - class MyModel(CallableModel): - data: DepOf[..., GenericResult[dict]] - window: int = 7 - - def __deps__(self, context): - # Can access self.window here - return [(self.data, [context.with_lookback(self.window)])] - - @Flow.call - def __call__(self, context, data: GenericResult[dict]) -> GenericResult[dict]: - return GenericResult(value=process(data.value)) """ # Extract auto_context option (not part of FlowOptions) # Can be: False, True, or a ContextBase subclass @@ -728,27 +537,10 @@ def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[ return GenericResult(value=query_db(source, 
start_date, end_date)) Dependencies: - Use Dep() or DepOf to mark parameters that can accept CallableModel dependencies: - - from ccflow import Dep, DepOf - from typing import Annotated - - @Flow.model - def compute_returns( - context: DateRangeContext, - prices: Annotated[GenericResult[pl.DataFrame], Dep( - transform=lambda ctx: ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) - )] - ) -> GenericResult[pl.DataFrame]: - return GenericResult(value=prices.value.pct_change()) - - # Or use DepOf shorthand for no transform: - @Flow.model - def compute_stats( - context: DateRangeContext, - data: DepOf[..., GenericResult[pl.DataFrame]] - ) -> GenericResult[pl.DataFrame]: - return GenericResult(value=data.value.describe()) + Any non-context parameter can be bound either to a literal value or + to another CallableModel. When a CallableModel is supplied, the + generated model treats it as an upstream dependency and resolves it + with the current context before calling the underlying function. Usage: # Create model instances @@ -819,8 +611,7 @@ def _context_validator(cls, values, handler, info): def __call__(self) -> ResultType: fn = getattr(self.model, self.fn) if hasattr(fn, "__wrapped__"): - # Call through _resolve_deps_and_call to handle DepOf field resolution - result = _resolve_deps_and_call(self.model, self.context, fn.__wrapped__) + result = fn.__wrapped__(self.model, self.context) # If it's a callable model, then we can validate the result if self.options.get("validate_result", True): if fn.__name__ == "__deps__": diff --git a/ccflow/dep.py b/ccflow/dep.py deleted file mode 100644 index b57261e..0000000 --- a/ccflow/dep.py +++ /dev/null @@ -1,278 +0,0 @@ -"""Dependency annotation markers for Flow.model. 
- -This module provides: -- Dep: Annotation marker for dependency parameters that can accept CallableModel -- DepOf: Shorthand for Annotated[Union[T, CallableModel], Dep()] -""" - -from typing import TYPE_CHECKING, Annotated, Callable, Optional, Type, TypeVar, Union, get_args, get_origin - -from .base import ContextBase - -if TYPE_CHECKING: - from .callable import CallableModel - -__all__ = ("Dep", "DepOf") - -T = TypeVar("T") - -# Lazy reference to CallableModel to avoid circular import -_CallableModel = None - - -def _get_callable_model(): - """Lazily import CallableModel to avoid circular imports.""" - global _CallableModel - if _CallableModel is None: - from .callable import CallableModel - - _CallableModel = CallableModel - return _CallableModel - - -class _DepOfMeta(type): - """Metaclass that makes DepOf[ContextType, ResultType] work.""" - - def __getitem__(cls, item): - if not isinstance(item, tuple) or len(item) != 2: - raise TypeError( - "DepOf requires 2 type arguments: DepOf[ContextType, ResultType]. " - "Use ... for ContextType to inherit from parent: DepOf[..., ResultType]" - ) - context_type, result_type = item - CallableModel = _get_callable_model() - - if context_type is ...: - # DepOf[..., ResultType] - inherit context from parent - return Annotated[Union[result_type, CallableModel], Dep()] - else: - # DepOf[ContextType, ResultType] - explicit context type - return Annotated[Union[result_type, CallableModel], Dep(context_type=context_type)] - - -class DepOf(metaclass=_DepOfMeta): - """ - Shorthand for Annotated[Union[ResultType, CallableModel], Dep(context_type=...)]. 
- - Follows Callable convention: DepOf[InputContext, OutputResult] - - For class fields, accepts either: - - The result type directly (pre-computed value) - - A CallableModel that produces the result type (resolved at call time) - - Usage: - # Inherit context type from parent model (most common) - data: DepOf[..., GenericResult[dict]] - - # Explicit context type validation - data: DepOf[DateRangeContext, GenericResult[dict]] - - At call time, if the field contains a CallableModel, it will be automatically - resolved using __deps__ and the resolved value will be accessible via self.field_name. - - For dependencies with transforms, define them in __deps__: - def __deps__(self, context): - transformed_ctx = context.model_copy(update={...}) - return [(self.data, [transformed_ctx])] - """ - - pass - - -def _is_compatible_type(actual: Type, expected: Type) -> bool: - """Check if actual type is compatible with expected type. - - Handles generic types like GenericResult[pl.DataFrame] where issubclass - would raise TypeError. 
- - Args: - actual: The actual type to check - expected: The expected type to match against - - Returns: - True if actual is compatible with expected - """ - # Handle None/empty types - if actual is None or expected is None: - return actual is expected - - # Get origins for generic types - actual_origin = get_origin(actual) or actual - expected_origin = get_origin(expected) or expected - - # Check if origins are compatible - try: - if not (isinstance(actual_origin, type) and isinstance(expected_origin, type)): - return False - if not issubclass(actual_origin, expected_origin): - return False - except TypeError: - # issubclass can fail for certain types - return False - - # Check generic args if present - actual_args = get_args(actual) - expected_args = get_args(expected) - - if expected_args and actual_args: - if len(actual_args) != len(expected_args): - return False - return all(_is_compatible_type(a, e) for a, e in zip(actual_args, expected_args)) - - return True - - -class Dep: - """ - Annotation marker for dependency parameters. - - Marks a parameter as accepting either the declared type or a CallableModel - that produces that type. Supports optional context transform and - construction-time type validation. 
- - Usage: - # No transform, no explicit validation (uses parent's context_type) - prices: Annotated[GenericResult[pl.DataFrame], Dep()] - - # With transform - prices: Annotated[GenericResult[pl.DataFrame], Dep( - transform=lambda ctx: ctx.model_copy(update={"start": ctx.start - timedelta(days=1)}) - )] - - # With explicit context_type validation - prices: Annotated[GenericResult[pl.DataFrame], Dep( - context_type=DateRangeContext, - transform=lambda ctx: ctx.model_copy(update={"start": ctx.start - timedelta(days=1)}) - )] - - # Cross-context dependency (transform changes context type) - sim_data: Annotated[GenericResult[pl.DataFrame], Dep( - context_type=SimulationContext, - transform=date_to_simulation_context - )] - """ - - def __init__( - self, - transform: Optional[Callable[..., ContextBase]] = None, - context_type: Optional[Type[ContextBase]] = None, - ): - """ - Args: - transform: Optional function to transform context before calling dependency. - Signature: (context) -> transformed_context - context_type: Expected context_type of the dependency CallableModel. - If None, defaults to the parent model's context_type. - Validated at construction time when a CallableModel is passed. - """ - self.transform = transform - self.context_type = context_type - - def apply(self, context: ContextBase) -> ContextBase: - """Apply the transform to a context, or return unchanged if no transform.""" - if self.transform is not None: - return self.transform(context) - return context - - def validate_dependency( - self, - value: "CallableModel", # noqa: F821 - expected_result_type: Type, - parent_context_type: Type[ContextBase], - param_name: str, - ) -> None: - """ - Validate a CallableModel dependency at construction time. 
- - Args: - value: The CallableModel being passed as a dependency - expected_result_type: The result type from the Annotated type hint - parent_context_type: The context_type of the parent model - param_name: Name of the parameter (for error messages) - - Raises: - TypeError: If context_type or result_type don't match - """ - # Import here to avoid circular import - from .callable import CallableModel - - if not isinstance(value, CallableModel): - return # Not a CallableModel, skip validation - - # Determine expected context type - expected_ctx = self.context_type if self.context_type is not None else parent_context_type - - # Validate context_type - the dependency's context_type should be compatible - # with what we'll pass to it (expected_ctx) - dep_context_type = value.context_type - try: - if not issubclass(expected_ctx, dep_context_type): - raise TypeError( - f"Dependency '{param_name}': expected context_type compatible with " - f"{dep_context_type.__name__}, but will pass {expected_ctx.__name__}" - ) - except TypeError: - # issubclass can fail for certain types, try alternate check - if expected_ctx != dep_context_type: - raise TypeError(f"Dependency '{param_name}': context_type mismatch - expected {dep_context_type}, got {expected_ctx}") - - # Validate result_type using the generic-safe comparison - # If expected_result_type is Union[T, CallableModel], extract T for validation - dep_result_type = value.result_type - actual_expected_type = expected_result_type - - # Handle Union[T, CallableModel] from DepOf expansion - if get_origin(expected_result_type) is Union: - union_args = get_args(expected_result_type) - # Filter out CallableModel from the union - non_callable_types = [t for t in union_args if t is not CallableModel] - if non_callable_types: - actual_expected_type = non_callable_types[0] - - if not _is_compatible_type(dep_result_type, actual_expected_type): - raise TypeError( - f"Dependency '{param_name}': expected result_type compatible with " - 
f"{actual_expected_type}, but got CallableModel with result_type {dep_result_type}" - ) - - def __repr__(self): - parts = [] - if self.transform is not None: - parts.append(f"transform={self.transform}") - if self.context_type is not None: - parts.append(f"context_type={self.context_type.__name__}") - return f"Dep({', '.join(parts)})" if parts else "Dep()" - - def __eq__(self, other): - if not isinstance(other, Dep): - return False - return self.transform == other.transform and self.context_type == other.context_type - - def __hash__(self): - # Make Dep hashable for use in sets/dicts - return hash((id(self.transform), self.context_type)) - - -def extract_dep(annotation) -> tuple: - """Extract Dep from Annotated[T, Dep(...)] or DepOf[ContextType, T]. - - When multiple Dep annotations exist (e.g., from nested Annotated that flattens), - returns the LAST one, which represents the outermost user annotation. - - Args: - annotation: A type annotation, possibly Annotated with Dep - - Returns: - Tuple of (base_type, Dep instance or None) - """ - if get_origin(annotation) is Annotated: - args = get_args(annotation) - base_type = args[0] - # Find the LAST Dep - nested Annotated flattens, so outer annotation comes last - last_dep = None - for metadata in args[1:]: - if isinstance(metadata, Dep): - last_dep = metadata - if last_dep is not None: - return base_type, last_dep - return annotation, None diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index e9f2704..44d9cfa 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -12,20 +12,29 @@ import inspect import logging from functools import wraps -from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin +from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, cast, get_args, get_origin -from pydantic import Field, TypeAdapter +from pydantic import Field, TypeAdapter, model_validator from typing_extensions import TypedDict from 
.base import ContextBase, ResultBase from .callable import CallableModel, Flow, GraphDepList, _CallableModel from .context import FlowContext -from .dep import Dep, extract_dep from .local_persistence import register_ccflow_import_path from .result import GenericResult __all__ = ("flow_model", "FlowAPI", "BoundModel", "Lazy", "FieldExtractor") +_AnyCallable = Callable[..., Any] + + +def _callable_name(func: _AnyCallable) -> str: + return getattr(func, "__name__", type(func).__name__) + + +def _callable_module(func: _AnyCallable) -> str: + return getattr(func, "__module__", __name__) + class _LazyMarker: """Sentinel that marks a parameter as lazily evaluated via Lazy[T].""" @@ -36,9 +45,8 @@ class _LazyMarker: def _extract_lazy(annotation) -> Tuple[Any, bool]: """Check if annotation is Lazy[T]. Returns (base_type, is_lazy). - Handles nested Annotated types — e.g. Lazy[Annotated[T, Dep(...)]] produces - Annotated[Annotated[T, Dep(...)], _LazyMarker()], so we need to check the - outermost Annotated layer for _LazyMarker. + Handles nested Annotated types, so we need to check the outermost + Annotated layer for _LazyMarker. 
""" if get_origin(annotation) is Annotated: args = get_args(annotation) @@ -87,15 +95,11 @@ def _build_typed_dict_adapter(name: str, schema: Dict[str, Type]) -> TypeAdapter return TypeAdapter(TypedDict(name, schema)) -def _build_config_validators( - all_param_types: Dict[str, Type], dep_fields: Dict[str, Tuple[Type, Dep]] -) -> Tuple[Dict[str, Type], Dict[str, TypeAdapter]]: - """Precompute validators for non-dependency config fields.""" +def _build_config_validators(all_param_types: Dict[str, Type]) -> Tuple[Dict[str, Type], Dict[str, TypeAdapter]]: + """Precompute validators for constructor fields.""" validatable_types: Dict[str, Type] = {} for name, typ in all_param_types.items(): - if name in dep_fields: - continue try: TypeAdapter(typ) validatable_types[name] = typ @@ -112,6 +116,7 @@ def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, if not validators: return + from .base import ModelRegistry as _MR from .callable import CallableModel as _CM for field_name, validator in validators.items(): @@ -120,6 +125,8 @@ def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, value = kwargs[field_name] if value is None or isinstance(value, (_CM, BoundModel)): continue + if isinstance(value, str) and value in _MR.root(): + continue try: validator.validate_python(value) except Exception: @@ -134,7 +141,7 @@ class FlowAPI: Accessed via model.flow property. """ - def __init__(self, model: "CallableModel"): # noqa: F821 + def __init__(self, model: "_GeneratedFlowModelBase"): self._model = model def compute(self, **kwargs) -> Any: @@ -228,25 +235,23 @@ class BoundModel: of a previous transform). 
""" - def __init__(self, model: "CallableModel", input_transforms: Dict[str, Any]): # noqa: F821 + def __init__(self, model: "_GeneratedFlowModelBase", input_transforms: Dict[str, Any]): self._model = model self._input_transforms = input_transforms - def __call__(self, context: ContextBase) -> Any: - """Call the model with transformed context.""" - # Build new context dict with transforms applied + def _transform_context(self, context: ContextBase) -> FlowContext: + """Return a FlowContext with this model's input transforms applied.""" ctx_dict = _context_values(context) - - # Apply transforms for name, transform in self._input_transforms.items(): if callable(transform): ctx_dict[name] = transform(context) else: ctx_dict[name] = transform + return FlowContext(**ctx_dict) - # Create new context and call model - new_ctx = FlowContext(**ctx_dict) - return self._model(new_ctx) + def __call__(self, context: ContextBase) -> Any: + """Call the model with transformed context.""" + return self._model(self._transform_context(context)) @property def flow(self) -> "FlowAPI": @@ -288,7 +293,10 @@ class _FieldExtractorMixin: def __getattr__(self, name): try: - return super().__getattr__(name) + super_getattr = getattr(super(), "__getattr__", None) + if super_getattr is None: + raise AttributeError(name) + return super_getattr(name) except AttributeError: if name.startswith("_"): raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") from None @@ -298,6 +306,43 @@ def __getattr__(self, name): class _GeneratedFlowModelBase(_FieldExtractorMixin, CallableModel): """Shared behavior for models generated by ``@Flow.model``.""" + __flow_model_context_type__: ClassVar[Type[ContextBase]] = FlowContext + __flow_model_return_type__: ClassVar[Type[ResultBase]] = GenericResult + __flow_model_func__: ClassVar[_AnyCallable | None] = None + __flow_model_use_context_args__: ClassVar[bool] = True + __flow_model_explicit_context_args__: ClassVar[Optional[List[str]]] = None + 
__flow_model_all_param_types__: ClassVar[Dict[str, Type]] = {} + __flow_model_auto_wrap__: ClassVar[bool] = False + _context_schema: ClassVar[Dict[str, Type]] = {} + _context_td: ClassVar[Any | None] = None + _matched_context_type: ClassVar[Optional[Type[ContextBase]]] = None + _cached_context_validator: ClassVar[TypeAdapter | None] = None + + @model_validator(mode="before") + def _resolve_registry_refs(cls, values, info): + if not isinstance(values, dict): + return values + + from .base import BaseModel as _BM + + param_types = getattr(cls, "__flow_model_all_param_types__", {}) + resolved = dict(values) + for field_name, expected_type in param_types.items(): + if field_name not in resolved: + continue + value = resolved[field_name] + if not isinstance(value, str): + continue + if expected_type is str: + continue + try: + candidate = _BM.model_validate(value) + except Exception: + continue + if isinstance(candidate, _BM): + resolved[field_name] = candidate + return resolved + @property def context_type(self) -> Type[ContextBase]: return self.__class__.__flow_model_context_type__ @@ -414,8 +459,8 @@ def model(self) -> "CallableModel": # noqa: F821 def _build_context_schema( - context_args: List[str], func: Callable, sig: inspect.Signature -) -> Tuple[Dict[str, Type], Type, Optional[Type[ContextBase]]]: + context_args: List[str], func: _AnyCallable, sig: inspect.Signature +) -> Tuple[Dict[str, Type], Any, Optional[Type[ContextBase]]]: """Build context schema from context_args parameter names. Instead of creating a dynamic ContextBase subclass, this builds: @@ -456,25 +501,16 @@ def _build_context_schema( matched_context_type = DateRangeContext # Create TypedDict for validation (not registered anywhere!) 
- context_td = TypedDict(f"{func.__name__}Inputs", schema) + context_td = TypedDict(f"{_callable_name(func)}Inputs", schema) return schema, context_td, matched_context_type -def _get_dep_info(annotation) -> Tuple[Type, Optional[Dep]]: - """Extract dependency info from an annotation. - - Returns: - Tuple of (base_type, Dep instance or None) - """ - return extract_dep(annotation) - - _UNSET = object() def flow_model( - func: Callable = None, + func: Optional[_AnyCallable] = None, *, # Context handling context_args: Optional[List[str]] = None, @@ -487,7 +523,7 @@ def flow_model( validate_result: Any = _UNSET, verbose: Any = _UNSET, evaluator: Any = _UNSET, -) -> Callable: +) -> _AnyCallable: """Decorator that generates a CallableModel class from a plain Python function. This is syntactic sugar over CallableModel. The decorator generates a real @@ -522,7 +558,7 @@ def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[ A factory function that creates CallableModel instances """ - def decorator(fn: Callable) -> Callable: + def decorator(fn: _AnyCallable) -> _AnyCallable: import typing as _typing sig = inspect.signature(fn) @@ -538,7 +574,7 @@ def decorator(fn: Callable) -> Callable: # Validate return type return_type = _resolved_hints.get("return", sig.return_annotation) if return_type is inspect.Signature.empty: - raise TypeError(f"Function {fn.__name__} must have a return type annotation") + raise TypeError(f"Function {_callable_name(fn)} must have a return type annotation") # Check if return type is a ResultBase subclass; if not, auto-wrap in GenericResult return_origin = get_origin(return_type) or return_type auto_wrap_result = False @@ -555,10 +591,10 @@ def decorator(fn: Callable) -> Callable: context_param = params[context_param_name] context_annotation = _resolved_hints.get(context_param_name, context_param.annotation) if context_annotation is inspect.Parameter.empty: - raise TypeError(f"Function {fn.__name__}: '{context_param_name}' 
parameter must have a type annotation") + raise TypeError(f"Function {_callable_name(fn)}: '{context_param_name}' parameter must have a type annotation") context_type = context_annotation if not (isinstance(context_type, type) and issubclass(context_type, ContextBase)): - raise TypeError(f"Function {fn.__name__}: '{context_param_name}' must be annotated with a ContextBase subclass") + raise TypeError(f"Function {_callable_name(fn)}: '{context_param_name}' must be annotated with a ContextBase subclass") model_field_params = {name: param for name, param in params.items() if name not in (context_param_name, "self")} use_context_args = False explicit_context_args = None @@ -583,10 +619,9 @@ def decorator(fn: Callable) -> Callable: use_context_args = True explicit_context_args = None # Dynamic - determined at construction - # Analyze parameters to find dependencies, lazy fields, and regular fields - dep_fields: Dict[str, Tuple[Type, Dep]] = {} # name -> (base_type, Dep) + # Analyze parameters to find lazy fields and regular fields. model_fields: Dict[str, Tuple[Type, Any]] = {} # name -> (type, default) - lazy_fields: set = set() # Names of parameters marked with Lazy[T] + lazy_fields: set[str] = set() # Names of parameters marked with Lazy[T] # In dynamic deferred mode (no explicit context_args), all fields are optional # because values not provided at construction come from context at runtime @@ -603,8 +638,6 @@ def decorator(fn: Callable) -> Callable: if is_lazy: lazy_fields.add(name) - # Extract Dep info from the (possibly unwrapped) annotation - base_type, dep = _get_dep_info(unwrapped_annotation) if param.default is not inspect.Parameter.empty: default = param.default elif dynamic_deferred_mode: @@ -614,17 +647,7 @@ def decorator(fn: Callable) -> Callable: # In explicit mode, params without defaults are required default = ... 
- if dep is not None: - # This is an explicit dependency parameter (DepOf annotation) - dep_fields[name] = (base_type, dep) - # Use Annotated so _resolve_deps_and_call in callable.py can find the Dep - model_fields[name] = (Annotated[Union[base_type, CallableModel], dep], default) - else: - # Regular model field - use Any for auto-detection of CallableModels. - # We can't use Union[T, CallableModel] because Pydantic tries to generate - # schema for T, which fails for arbitrary types like pl.DataFrame. - # Using Any allows any value; we do runtime isinstance checks in __call__. - model_fields[name] = (Any, default) + model_fields[name] = (Any, default) # Capture variables for closures ctx_param_name = context_param_name if not use_context_args else "context" @@ -637,23 +660,15 @@ def decorator(fn: Callable) -> Callable: # Create the __call__ method def make_call_impl(): def __call__(self, context): - # Import here (inside function) to avoid pickling issues with ContextVar - from .callable import _resolved_deps - - def resolve_callable_model(name, value, store): + def resolve_callable_model(value): """Resolve a CallableModel field.""" - if id(value) in store: - return store[id(value)] - else: - # Auto-detection fallback: call directly - resolved = value(context) - if isinstance(resolved, GenericResult): - return resolved.value - return resolved + resolved = value(context) + if isinstance(resolved, GenericResult): + return resolved.value + return resolved # Build kwargs for the original function fn_kwargs = {} - store = _resolved_deps.get() def _resolve_field(name, value): """Resolve a single field value, handling lazy wrapping.""" @@ -666,7 +681,7 @@ def _resolve_field(name, value): # Non-dep value: wrap in trivial thunk return lambda v=value: v elif is_dep: - return resolve_callable_model(name, value, store) + return resolve_callable_model(value) else: return value @@ -704,7 +719,7 @@ def _resolve_field(name, value): return raw_result # Set proper signature for 
CallableModel validation - __call__.__signature__ = inspect.Signature( + cast(Any, __call__).__signature__ = inspect.Signature( parameters=[ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), @@ -739,19 +754,14 @@ def __deps__(self, context) -> GraphDepList: if name in lazy_fields: continue # Lazy deps are NOT pre-evaluated value = getattr(self, name) - if isinstance(value, (CallableModel, BoundModel)): - if name in dep_fields: - # Explicit DepOf with transform (backwards compat) - _, dep_obj = dep_fields[name] - transformed_ctx = dep_obj.apply(context) - deps.append((value, [transformed_ctx])) - else: - # Auto-detected dependency - use context as-is - deps.append((value, [context])) + if isinstance(value, BoundModel): + deps.append((value._model, [value._transform_context(context)])) + elif isinstance(value, CallableModel): + deps.append((value, [context])) return deps # Set proper signature - __deps__.__signature__ = inspect.Signature( + cast(Any, __deps__).__signature__ = inspect.Signature( parameters=[ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), @@ -767,8 +777,8 @@ def __deps__(self, context) -> GraphDepList: annotations = {} namespace = { - "__module__": fn.__module__, - "__qualname__": f"_{fn.__name__}_Model", + "__module__": _callable_module(fn), + "__qualname__": f"_{_callable_name(fn)}_Model", "__call__": decorated_call, "__deps__": decorated_deps, } @@ -783,36 +793,15 @@ def __deps__(self, context) -> GraphDepList: namespace["__annotations__"] = annotations - # Add model validator for dependency validation if we have dep fields - if dep_fields: - from pydantic import model_validator - - # Create validator function that captures dep_fields and context_type - def make_dep_validator(d_fields, ctx_type): - 
@model_validator(mode="after") - def __validate_deps__(self): - from .callable import CallableModel - - for dep_name, (base_type, dep_obj) in d_fields.items(): - value = getattr(self, dep_name) - if isinstance(value, CallableModel): - dep_obj.validate_dependency(value, base_type, ctx_type, dep_name) - return self - - return __validate_deps__ - - namespace["__validate_deps__"] = make_dep_validator(dep_fields, context_type) - - _validatable_types, _config_validators = _build_config_validators(all_param_types, dep_fields) + _validatable_types, _config_validators = _build_config_validators(all_param_types) # Create the class using type() - GeneratedModel = type(f"_{fn.__name__}_Model", (_GeneratedFlowModelBase,), namespace) + GeneratedModel = cast(type[_GeneratedFlowModelBase], type(f"_{_callable_name(fn)}_Model", (_GeneratedFlowModelBase,), namespace)) # Set class-level attributes after class creation (to avoid pydantic processing) GeneratedModel.__flow_model_context_type__ = context_type GeneratedModel.__flow_model_return_type__ = internal_return_type - GeneratedModel.__flow_model_func__ = fn - GeneratedModel.__flow_model_dep_fields__ = dep_fields + setattr(GeneratedModel, "__flow_model_func__", fn) GeneratedModel.__flow_model_use_context_args__ = use_context_args GeneratedModel.__flow_model_explicit_context_args__ = explicit_context_args GeneratedModel.__flow_model_all_param_types__ = all_param_types # All param name -> type @@ -851,7 +840,7 @@ def __validate_deps__(self): # Create factory function that returns model instances @wraps(fn) - def factory(**kwargs) -> GeneratedModel: + def factory(**kwargs) -> _GeneratedFlowModelBase: _validate_config_kwargs(kwargs, _validatable_types, _config_validators) instance = GeneratedModel(**kwargs) @@ -861,7 +850,7 @@ def factory(**kwargs) -> GeneratedModel: return instance # Preserve useful attributes on factory - factory._generated_model = GeneratedModel + cast(Any, factory)._generated_model = GeneratedModel factory.__doc__ = 
fn.__doc__ return factory @@ -906,16 +895,9 @@ def result_type(self): @Flow.call def __call__(self, context: ContextBase) -> GenericResult: - # Lazy import: _resolved_deps is a ContextVar that can't be pickled - from .callable import _resolved_deps - - store = _resolved_deps.get() - if id(self.source) in store: - result = store[id(self.source)] - else: - result = self.source(context) - if isinstance(result, GenericResult): - result = result.value + result = self.source(context) + if isinstance(result, GenericResult): + result = result.value # Support both attribute access and dict key access if isinstance(result, dict): return GenericResult(value=result[self.field_name]) diff --git a/ccflow/tests/config/conf_flow.yaml b/ccflow/tests/config/conf_flow.yaml index 781bd24..41acfaf 100644 --- a/ccflow/tests/config/conf_flow.yaml +++ b/ccflow/tests/config/conf_flow.yaml @@ -60,7 +60,7 @@ diamond_aggregator: # DateRangeContext with transform flow_date_loader: - _target_: ccflow.tests.test_flow_model.date_range_loader + _target_: ccflow.tests.test_flow_model.date_range_loader_previous_day source: market_data include_weekends: false diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index b547aee..458569d 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -1,25 +1,22 @@ """Tests for Flow.model decorator.""" from datetime import date, timedelta -from typing import Annotated from unittest import TestCase -from pydantic import ValidationError from ray.cloudpickle import dumps as rcpdumps, loads as rcploads from ccflow import ( CallableModel, ContextBase, DateRangeContext, - Dep, - DepOf, Flow, + FlowOptionsOverride, GenericResult, Lazy, ModelRegistry, ResultBase, ) -from ccflow.callable import resolve +from ccflow.evaluators.common import MemoryCacheEvaluator class SimpleContext(ContextBase): @@ -136,9 +133,9 @@ def loader(context: SimpleContext, base: int) -> GenericResult[int]: return 
GenericResult(value=context.value + base) @Flow.model - def consumer(_: SimpleContext, data: DepOf[..., GenericResult[int]]) -> GenericResult[int]: + def consumer(_: SimpleContext, data: int) -> GenericResult[int]: # Context not used directly, just passed to dependency - return GenericResult(value=data.value * 2) + return GenericResult(value=data * 2) load = loader(base=100) consume = consumer(data=load) @@ -214,10 +211,10 @@ def model_with_ctx_default(value: int = 42, extra: str = "foo") -> GenericResult class TestFlowModelDependencies(TestCase): - """Tests for Flow.model with dependencies.""" + """Tests for Flow.model with upstream CallableModel inputs.""" - def test_simple_dependency_with_depof(self): - """Test simple dependency using DepOf shorthand.""" + def test_simple_dependency(self): + """Test passing an upstream model as a normal parameter.""" @Flow.model def loader(context: SimpleContext, value: int) -> GenericResult[int]: @@ -226,10 +223,10 @@ def loader(context: SimpleContext, value: int) -> GenericResult[int]: @Flow.model def consumer( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, multiplier: int = 1, ) -> GenericResult[int]: - return GenericResult(value=data.value * multiplier) + return GenericResult(value=data * multiplier) # Create pipeline load = loader(value=10) @@ -241,39 +238,17 @@ def consumer( # loader returns 10 + 5 = 15, consumer multiplies by 2 = 30 self.assertEqual(result.value, 30) - def test_dependency_with_explicit_dep(self): - """Test dependency using explicit Dep() annotation.""" - - @Flow.model - def loader(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 2) - - @Flow.model - def consumer( - context: SimpleContext, - data: Annotated[GenericResult[int], Dep()], - ) -> GenericResult[int]: - return GenericResult(value=data.value + 100) - - load = loader() - consume = consumer(data=load) - - result = consume(SimpleContext(value=10)) - # loader: 10 * 2 = 20, 
consumer: 20 + 100 = 120 - self.assertEqual(result.value, 120) - def test_dependency_with_direct_value(self): - """Test that Dep fields can accept direct values (not CallableModel).""" + """Test that dependency-shaped parameters can also take direct values.""" @Flow.model def consumer( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, ) -> GenericResult[int]: - return GenericResult(value=data.value + context.value) + return GenericResult(value=data + context.value) - # Pass direct value instead of CallableModel - consume = consumer(data=GenericResult(value=100)) + consume = consumer(data=100) result = consume(SimpleContext(value=5)) self.assertEqual(result.value, 105) @@ -288,9 +263,9 @@ def loader(context: SimpleContext) -> GenericResult[int]: @Flow.model def consumer( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, ) -> GenericResult[int]: - return GenericResult(value=data.value) + return GenericResult(value=data) load = loader() consume = consumer(data=load) @@ -309,105 +284,80 @@ def test_no_deps_when_direct_value(self): @Flow.model def consumer( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, ) -> GenericResult[int]: - return GenericResult(value=data.value) + return GenericResult(value=data) - consume = consumer(data=GenericResult(value=100)) + consume = consumer(data=100) deps = consume.__deps__(SimpleContext(value=10)) self.assertEqual(len(deps), 0) # ============================================================================= -# Transform Tests +# with_inputs Tests # ============================================================================= -class TestFlowModelTransforms(TestCase): - """Tests for Flow.model with context transforms.""" +class TestFlowModelWithInputs(TestCase): + """Tests for Flow.model with .flow.with_inputs().""" - def test_transform_in_dep(self): - """Test dependency with context transform.""" + def test_transformed_dependency_with_inputs(self): + 
"""Test dependency context transformation via .flow.with_inputs().""" @Flow.model def loader(context: SimpleContext) -> GenericResult[int]: return GenericResult(value=context.value) @Flow.model - def consumer( - context: SimpleContext, - data: Annotated[ - GenericResult[int], - Dep(transform=lambda ctx: ctx.model_copy(update={"value": ctx.value + 10})), - ], - ) -> GenericResult[int]: - return GenericResult(value=data.value * 2) + def consumer(context: SimpleContext, data: int) -> GenericResult[int]: + return GenericResult(value=data * 2) - load = loader() + load = loader().flow.with_inputs(value=lambda ctx: ctx.value + 10) consume = consumer(data=load) - ctx = SimpleContext(value=5) - result = consume(ctx) - - # Transform adds 10 to context.value: 5 + 10 = 15 - # Loader returns that: 15 - # Consumer multiplies by 2: 30 + result = consume(SimpleContext(value=5)) self.assertEqual(result.value, 30) - def test_transform_in_deps_method(self): - """Test that transform is applied in __deps__ method.""" - - def transform_fn(ctx): - return ctx.model_copy(update={"value": ctx.value * 3}) + def test_with_inputs_changes_dependency_context_in_deps(self): + """Test that BoundModel contributes transformed dependency contexts.""" @Flow.model def loader(context: SimpleContext) -> GenericResult[int]: return GenericResult(value=context.value) @Flow.model - def consumer( - context: SimpleContext, - data: Annotated[GenericResult[int], Dep(transform=transform_fn)], - ) -> GenericResult[int]: - return GenericResult(value=data.value) + def consumer(context: SimpleContext, data: int) -> GenericResult[int]: + return GenericResult(value=data) - load = loader() + load = loader().flow.with_inputs(value=lambda ctx: ctx.value * 3) consume = consumer(data=load) - ctx = SimpleContext(value=7) - deps = consume.__deps__(ctx) - - # Transform should be applied + deps = consume.__deps__(SimpleContext(value=7)) self.assertEqual(len(deps), 1) transformed_ctx = deps[0][1][0] - 
self.assertEqual(transformed_ctx.value, 21) # 7 * 3 + self.assertEqual(transformed_ctx.value, 21) - def test_date_range_transform(self): - """Test transform pattern with date ranges using context_args.""" + def test_date_range_transform_with_inputs(self): + """Test date-range lookback wiring via .flow.with_inputs().""" @Flow.model(context_args=["start_date", "end_date"]) def range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: return GenericResult(value=f"{source}:{start_date}") - def lookback_transform(ctx: DateRangeContext) -> DateRangeContext: - return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) - @Flow.model(context_args=["start_date", "end_date"]) def range_processor( start_date: date, end_date: date, - data: Annotated[GenericResult[str], Dep(transform=lookback_transform)], + data: str, ) -> GenericResult[str]: - return GenericResult(value=f"processed:{data.value}") + return GenericResult(value=f"processed:{data}") - loader = range_loader(source="db") + loader = range_loader(source="db").flow.with_inputs(start_date=lambda ctx: ctx.start_date - timedelta(days=1)) processor = range_processor(data=loader) ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) result = processor(ctx) - - # Transform should shift start_date back by 1 day self.assertEqual(result.value, "processed:db:2024-01-09") @@ -429,18 +379,18 @@ def stage1(context: SimpleContext, base: int) -> GenericResult[int]: @Flow.model def stage2( context: SimpleContext, - input_data: DepOf[..., GenericResult[int]], + input_data: int, multiplier: int, ) -> GenericResult[int]: - return GenericResult(value=input_data.value * multiplier) + return GenericResult(value=input_data * multiplier) @Flow.model def stage3( context: SimpleContext, - input_data: DepOf[..., GenericResult[int]], + input_data: int, offset: int = 0, ) -> GenericResult[int]: - return GenericResult(value=input_data.value + offset) + return 
GenericResult(value=input_data + offset) # Build pipeline s1 = stage1(base=100) @@ -465,24 +415,24 @@ def source(context: SimpleContext) -> GenericResult[int]: @Flow.model def branch_a( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, ) -> GenericResult[int]: - return GenericResult(value=data.value * 2) + return GenericResult(value=data * 2) @Flow.model def branch_b( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, ) -> GenericResult[int]: - return GenericResult(value=data.value + 100) + return GenericResult(value=data + 100) @Flow.model def merger( context: SimpleContext, - a: DepOf[..., GenericResult[int]], - b: DepOf[..., GenericResult[int]], + a: int, + b: int, ) -> GenericResult[int]: - return GenericResult(value=a.value + b.value) + return GenericResult(value=a + b) src = source() a = branch_a(data=src) @@ -568,10 +518,10 @@ def __call__(self, context: SimpleContext) -> GenericResult[int]: @Flow.model def generated_consumer( context: SimpleContext, - data: DepOf[..., GenericResult[int]], + data: int, multiplier: int, ) -> GenericResult[int]: - return GenericResult(value=data.value * multiplier) + return GenericResult(value=data * multiplier) manual = ManualModel(offset=50) generated = generated_consumer(data=manual, multiplier=2) @@ -616,7 +566,7 @@ def test_auto_wrap_unwrap_as_dependency(self): """Test that auto-wrapped model used as dep delivers unwrapped value downstream. Auto-wrapped models have result_type=GenericResult (unparameterized). - When used as an auto-detected dep (no DepOf), the framework resolves + When used as an auto-detected dep, the framework resolves the GenericResult and unwraps .value for the downstream function. 
""" @@ -627,7 +577,7 @@ def plain_source(context: SimpleContext) -> int: @Flow.model def consumer( context: SimpleContext, - data: GenericResult[int], # Auto-detected dep, not DepOf + data: GenericResult[int], # Auto-detected dep ) -> GenericResult[int]: # data is auto-unwrapped to the int value by the framework return GenericResult(value=data + 1) @@ -711,288 +661,13 @@ def untyped_context_arg(x) -> GenericResult[int]: self.assertIn("type annotation", str(cm.exception)) -# ============================================================================= -# Dep and DepOf Tests -# ============================================================================= - - -class TestDepAndDepOf(TestCase): - """Tests for Dep and DepOf classes.""" - - def test_depof_creates_annotated(self): - """Test that DepOf[..., T] creates Annotated[Union[T, CallableModel], Dep()].""" - from typing import Union as TypingUnion, get_args, get_origin - - annotation = DepOf[..., GenericResult[int]] - self.assertEqual(get_origin(annotation), Annotated) - - args = get_args(annotation) - # First arg is Union[ResultType, CallableModel] - self.assertEqual(get_origin(args[0]), TypingUnion) - union_args = get_args(args[0]) - self.assertIn(GenericResult[int], union_args) - self.assertIn(CallableModel, union_args) - # Second arg is Dep() - self.assertIsInstance(args[1], Dep) - self.assertIsNone(args[1].context_type) # ... 
means inherit from parent - - def test_depof_with_generic_type(self): - """Test DepOf with nested generic types.""" - from typing import List as TypingList, Union as TypingUnion, get_args, get_origin - - annotation = DepOf[..., GenericResult[TypingList[str]]] - self.assertEqual(get_origin(annotation), Annotated) - - args = get_args(annotation) - # First arg is Union[ResultType, CallableModel] - self.assertEqual(get_origin(args[0]), TypingUnion) - union_args = get_args(args[0]) - self.assertIn(GenericResult[TypingList[str]], union_args) - self.assertIn(CallableModel, union_args) - - def test_depof_with_context_type(self): - """Test DepOf[ContextType, ResultType] syntax.""" - from typing import Union as TypingUnion, get_args, get_origin - - annotation = DepOf[SimpleContext, GenericResult[int]] - self.assertEqual(get_origin(annotation), Annotated) - - args = get_args(annotation) - # First arg is Union[ResultType, CallableModel] - self.assertEqual(get_origin(args[0]), TypingUnion) - union_args = get_args(args[0]) - self.assertIn(GenericResult[int], union_args) - self.assertIn(CallableModel, union_args) - # Second arg is Dep with context_type - self.assertIsInstance(args[1], Dep) - self.assertEqual(args[1].context_type, SimpleContext) - - def test_extract_dep_with_annotated(self): - """Test extract_dep with Annotated type.""" - from ccflow.dep import extract_dep - - dep = Dep(context_type=SimpleContext) - annotation = Annotated[GenericResult[int], dep] - - base_type, extracted_dep = extract_dep(annotation) - self.assertEqual(base_type, GenericResult[int]) - self.assertEqual(extracted_dep, dep) - - def test_extract_dep_with_depof(self): - """Test extract_dep with DepOf type.""" - from typing import Union as TypingUnion, get_args, get_origin - - from ccflow.dep import extract_dep - - annotation = DepOf[..., GenericResult[str]] - base_type, extracted_dep = extract_dep(annotation) - - # base_type is Union[ResultType, CallableModel] - self.assertEqual(get_origin(base_type), 
TypingUnion) - union_args = get_args(base_type) - self.assertIn(GenericResult[str], union_args) - self.assertIn(CallableModel, union_args) - self.assertIsInstance(extracted_dep, Dep) - - def test_extract_dep_without_dep(self): - """Test extract_dep with regular type (no Dep).""" - from ccflow.dep import extract_dep - - base_type, extracted_dep = extract_dep(int) - self.assertEqual(base_type, int) - self.assertIsNone(extracted_dep) - - def test_extract_dep_annotated_without_dep(self): - """Test extract_dep with Annotated but no Dep marker.""" - from ccflow.dep import extract_dep - - annotation = Annotated[int, "some metadata"] - base_type, extracted_dep = extract_dep(annotation) - - # When no Dep marker is found, returns original annotation unchanged - self.assertEqual(base_type, annotation) - self.assertIsNone(extracted_dep) - - def test_is_compatible_type_simple(self): - """Test _is_compatible_type with simple types.""" - from ccflow.dep import _is_compatible_type - - self.assertTrue(_is_compatible_type(int, int)) - self.assertFalse(_is_compatible_type(int, str)) - self.assertTrue(_is_compatible_type(bool, int)) # bool is subclass of int - - def test_is_compatible_type_generic(self): - """Test _is_compatible_type with generic types.""" - from ccflow.dep import _is_compatible_type - - self.assertTrue(_is_compatible_type(GenericResult[int], GenericResult[int])) - self.assertFalse(_is_compatible_type(GenericResult[int], GenericResult[str])) - self.assertTrue(_is_compatible_type(GenericResult, GenericResult)) - - def test_is_compatible_type_none(self): - """Test _is_compatible_type with None.""" - from ccflow.dep import _is_compatible_type - - self.assertTrue(_is_compatible_type(None, None)) - self.assertFalse(_is_compatible_type(None, int)) - self.assertFalse(_is_compatible_type(int, None)) - - def test_is_compatible_type_subclass(self): - """Test _is_compatible_type with subclasses.""" - from ccflow.dep import _is_compatible_type - - 
self.assertTrue(_is_compatible_type(MyResult, ResultBase)) - self.assertFalse(_is_compatible_type(ResultBase, MyResult)) - - def test_dep_validate_dependency_success(self): - """Test Dep.validate_dependency with valid dependency.""" - - @Flow.model - def valid_dep(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - dep = Dep() - model = valid_dep() - - # Should not raise - dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") - - def test_dep_validate_dependency_context_mismatch(self): - """Test Dep.validate_dependency with context type mismatch.""" - - class OtherContext(ContextBase): - other: str - - @Flow.model - def other_dep(context: OtherContext) -> GenericResult[int]: - return GenericResult(value=42) - - dep = Dep(context_type=SimpleContext) - model = other_dep() - - with self.assertRaises(TypeError) as cm: - dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") - - self.assertIn("context_type", str(cm.exception)) - - def test_dep_validate_dependency_result_mismatch(self): - """Test Dep.validate_dependency with result type mismatch.""" - - @Flow.model - def wrong_result(context: SimpleContext) -> MyResult: - return MyResult(data="test") - - dep = Dep() - model = wrong_result() - - with self.assertRaises(TypeError) as cm: - dep.validate_dependency(model, GenericResult[int], SimpleContext, "data") - - self.assertIn("result_type", str(cm.exception)) - - def test_dep_validate_dependency_non_callable(self): - """Test Dep.validate_dependency with non-CallableModel value.""" - dep = Dep() - # Should not raise for non-CallableModel values - dep.validate_dependency(GenericResult(value=42), GenericResult[int], SimpleContext, "data") - dep.validate_dependency("string", GenericResult[int], SimpleContext, "data") - dep.validate_dependency(123, GenericResult[int], SimpleContext, "data") - - def test_dep_hash(self): - """Test Dep is hashable for use in sets/dicts.""" - dep1 = Dep() - dep2 = 
Dep(context_type=SimpleContext) - - # Should be hashable - dep_set = {dep1, dep2} - self.assertEqual(len(dep_set), 2) - - dep_dict = {dep1: "value1", dep2: "value2"} - self.assertEqual(dep_dict[dep1], "value1") - self.assertEqual(dep_dict[dep2], "value2") - - def test_dep_apply_with_transform(self): - """Test Dep.apply with transform function.""" - - def transform(ctx): - return ctx.model_copy(update={"value": ctx.value * 2}) - - dep = Dep(transform=transform) - - ctx = SimpleContext(value=10) - result = dep.apply(ctx) - - self.assertEqual(result.value, 20) - - def test_dep_apply_without_transform(self): - """Test Dep.apply without transform (identity).""" - dep = Dep() - - ctx = SimpleContext(value=10) - result = dep.apply(ctx) - - self.assertIs(result, ctx) - - def test_dep_repr(self): - """Test Dep string representation.""" - dep1 = Dep() - self.assertEqual(repr(dep1), "Dep()") - - dep2 = Dep(context_type=SimpleContext) - self.assertIn("SimpleContext", repr(dep2)) - - dep3 = Dep(transform=lambda x: x) - self.assertIn("transform=", repr(dep3)) - - def test_dep_equality(self): - """Test Dep equality comparison.""" - dep1 = Dep() - dep2 = Dep() - dep3 = Dep(context_type=SimpleContext) - - # Note: Two Dep() instances with no arguments are equal - self.assertEqual(dep1, dep2) - self.assertNotEqual(dep1, dep3) - - # ============================================================================= # Validation Tests # ============================================================================= class TestFlowModelValidation(TestCase): - """Tests for dependency validation in Flow.model.""" - - def test_context_type_validation(self): - """Test that context_type mismatch is detected.""" - - class OtherContext(ContextBase): - other: str - - @Flow.model - def simple_loader(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - @Flow.model - def other_loader(context: OtherContext) -> GenericResult[int]: - return GenericResult(value=42) - - 
@Flow.model - def consumer( - context: SimpleContext, - data: Annotated[GenericResult[int], Dep(context_type=SimpleContext)], - ) -> GenericResult[int]: - return GenericResult(value=data.value) - - # Should work with matching context - load1 = simple_loader() - consume1 = consumer(data=load1) - self.assertIsNotNone(consume1) - - # Should fail with mismatched context - load2 = other_loader() - with self.assertRaises((TypeError, ValidationError)): - consumer(data=load2) + """Tests for Flow.model validation behavior.""" def test_config_validation_rejects_bad_type(self): """Test that config validator rejects wrong types at construction.""" @@ -1208,6 +883,36 @@ def consumer(data: int, slow: Lazy[GenericResult[int]]) -> GenericResult[int]: # data=5 < 100, so slow path: x transform: 7+10=17, source: 17*3=51 self.assertEqual(result, 51) + def test_bound_and_unbound_models_share_memory_cache(self): + """Shifted and unshifted models should share one evaluator cache. + + They should not share the same cache key when the effective contexts + differ, but repeated evaluations of either model should still hit the + same underlying MemoryCacheEvaluator instance rather than re-executing. + """ + + call_counts = {"source": 0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + base = source() + shifted = base.flow.with_inputs(value=lambda ctx: ctx.value + 1) + evaluator = MemoryCacheEvaluator() + ctx = SimpleContext(value=5) + + with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + self.assertEqual(base(ctx).value, 50) + self.assertEqual(shifted(ctx).value, 60) + self.assertEqual(base(ctx).value, 50) + self.assertEqual(shifted(ctx).value, 60) + + # One execution for the unshifted context and one for the shifted context. 
+ self.assertEqual(call_counts["source"], 2) + self.assertEqual(len(evaluator.cache), 2) + # ============================================================================= # PEP 563 (from __future__ import annotations) Compatibility Tests @@ -1296,27 +1001,27 @@ def data_source(context: SimpleContext, base_value: int) -> GenericResult[int]: @Flow.model def data_transformer( context: SimpleContext, - source: DepOf[..., GenericResult[int]], + source: int, factor: int = 2, ) -> GenericResult[int]: """Transform data by multiplying with factor.""" - return GenericResult(value=source.value * factor) + return GenericResult(value=source * factor) @Flow.model def data_aggregator( context: SimpleContext, - input_a: DepOf[..., GenericResult[int]], - input_b: DepOf[..., GenericResult[int]], + input_a: int, + input_b: int, operation: str = "add", ) -> GenericResult[int]: """Aggregate two inputs.""" if operation == "add": - return GenericResult(value=input_a.value + input_b.value) + return GenericResult(value=input_a + input_b) elif operation == "multiply": - return GenericResult(value=input_a.value * input_b.value) + return GenericResult(value=input_a * input_b) else: - return GenericResult(value=input_a.value - input_b.value) + return GenericResult(value=input_a - input_b) @Flow.model @@ -1328,26 +1033,21 @@ def pipeline_stage1(context: SimpleContext, initial: int) -> GenericResult[int]: @Flow.model def pipeline_stage2( context: SimpleContext, - stage1_output: DepOf[..., GenericResult[int]], + stage1_output: int, multiplier: int = 2, ) -> GenericResult[int]: """Second stage of pipeline.""" - return GenericResult(value=stage1_output.value * multiplier) + return GenericResult(value=stage1_output * multiplier) @Flow.model def pipeline_stage3( context: SimpleContext, - stage2_output: DepOf[..., GenericResult[int]], + stage2_output: int, offset: int = 0, ) -> GenericResult[int]: """Third stage of pipeline.""" - return GenericResult(value=stage2_output.value + offset) - - -def 
lookback_one_day(ctx: DateRangeContext) -> DateRangeContext: - """Transform that extends start_date back by one day.""" - return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) + return GenericResult(value=stage2_output + offset) @Flow.model @@ -1355,20 +1055,37 @@ def date_range_loader( context: DateRangeContext, source: str, include_weekends: bool = True, -) -> GenericResult[str]: +) -> GenericResult[dict]: """Load data for a date range.""" - return GenericResult(value=f"{source}:{context.start_date} to {context.end_date}") + return GenericResult( + value={ + "source": source, + "start_date": str(context.start_date), + "end_date": str(context.end_date), + } + ) + + +@Flow.model +def date_range_loader_previous_day( + context: DateRangeContext, + source: str, + include_weekends: bool = True, +) -> dict: + """Hydra helper that applies a one-day lookback before delegating.""" + shifted = context.model_copy(update={"start_date": context.start_date - timedelta(days=1)}) + return date_range_loader(source=source, include_weekends=include_weekends)(shifted).value @Flow.model def date_range_processor( context: DateRangeContext, - raw_data: Annotated[GenericResult[str], Dep(transform=lookback_one_day)], + raw_data: dict, normalize: bool = False, ) -> GenericResult[str]: - """Process date range data with lookback.""" + """Process date range data.""" prefix = "normalized:" if normalize else "raw:" - return GenericResult(value=f"{prefix}{raw_data.value}") + return GenericResult(value=f"{prefix}{raw_data['source']}:{raw_data['start_date']} to {raw_data['end_date']}") @Flow.model @@ -1386,31 +1103,37 @@ def hydra_source_model(context: SimpleContext, base: int) -> GenericResult[int]: @Flow.model def hydra_consumer_model( context: SimpleContext, - source: DepOf[..., GenericResult[int]], + source: int, factor: int = 1, ) -> GenericResult[int]: """Consumer model for dependency testing.""" - return GenericResult(value=source.value * factor) + return 
GenericResult(value=source * factor) # --- context_args fixtures for Hydra testing --- @Flow.model(context_args=["start_date", "end_date"]) -def context_args_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: +def context_args_loader(start_date: date, end_date: date, source: str) -> GenericResult[dict]: """Loader using context_args with DateRangeContext.""" - return GenericResult(value=f"{source}:{start_date} to {end_date}") + return GenericResult( + value={ + "source": source, + "start_date": str(start_date), + "end_date": str(end_date), + } + ) @Flow.model(context_args=["start_date", "end_date"]) def context_args_processor( start_date: date, end_date: date, - data: DepOf[..., GenericResult[str]], + data: dict, prefix: str = "processed", ) -> GenericResult[str]: """Processor using context_args with dependency.""" - return GenericResult(value=f"{prefix}:{data.value}") + return GenericResult(value=f"{prefix}:{data['source']}:{data['start_date']} to {data['end_date']}") class TestFlowModelHydra(TestCase): @@ -1477,390 +1200,6 @@ def test_hydra_instantiate_with_dependency(self): self.assertEqual(result.value, 100) -# ============================================================================= -# Class-based CallableModel with Auto-Resolution Tests -# ============================================================================= - - -class TestClassBasedDepResolution(TestCase): - """Tests for auto-resolution of DepOf fields in class-based CallableModels. - - Key pattern: Fields use DepOf annotation, __call__ only takes context, - and resolved values are accessed via self.field_name during __call__. 
- """ - - def test_class_based_auto_resolve_basic(self): - """Test that DepOf fields are auto-resolved and accessible via resolve().""" - - @Flow.model - def data_source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 10) - - class Consumer(CallableModel): - # DepOf expands to Annotated[Union[ResultType, CallableModel], Dep()] - source: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - # Access resolved value via resolve() - return GenericResult(value=resolve(self.source).value + 1) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.source, [context])] - - src = data_source() - consumer = Consumer(source=src) - - result = consumer(SimpleContext(value=5)) - # source: 5 * 10 = 50, consumer: 50 + 1 = 51 - self.assertEqual(result.value, 51) - - def test_class_based_with_custom_transform(self): - """Test that custom __deps__ transform is used.""" - - @Flow.model - def data_source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 10) - - class Consumer(CallableModel): - source: DepOf[..., GenericResult[int]] - offset: int = 100 - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=resolve(self.source).value + self.offset) - - @Flow.deps - def __deps__(self, context: SimpleContext): - # Apply custom transform - transformed_ctx = SimpleContext(value=context.value + 5) - return [(self.source, [transformed_ctx])] - - src = data_source() - consumer = Consumer(source=src, offset=1) - - result = consumer(SimpleContext(value=5)) - # transformed context: 5 + 5 = 10 - # source: 10 * 10 = 100 - # consumer: 100 + 1 = 101 - self.assertEqual(result.value, 101) - - def test_class_based_with_annotated_transform(self): - """Test that Dep transform is used when field not in __deps__.""" - - @Flow.model - def data_source(context: SimpleContext) -> 
GenericResult[int]: - return GenericResult(value=context.value * 10) - - def double_value(ctx: SimpleContext) -> SimpleContext: - return SimpleContext(value=ctx.value * 2) - - class Consumer(CallableModel): - source: Annotated[DepOf[..., GenericResult[int]], Dep(transform=double_value)] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=resolve(self.source).value + 1) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [] # Empty - uses Dep annotation transform from field - - src = data_source() - consumer = Consumer(source=src) - - result = consumer(SimpleContext(value=5)) - # transform: 5 * 2 = 10 - # source: 10 * 10 = 100 - # consumer: 100 + 1 = 101 - self.assertEqual(result.value, 101) - - def test_class_based_multiple_deps(self): - """Test auto-resolution with multiple dependencies.""" - - @Flow.model - def source_a(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - @Flow.model - def source_b(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 2) - - class Aggregator(CallableModel): - a: DepOf[..., GenericResult[int]] - b: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=resolve(self.a).value + resolve(self.b).value) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.a, [context]), (self.b, [context])] - - agg = Aggregator(a=source_a(), b=source_b()) - - result = agg(SimpleContext(value=10)) - # a: 10, b: 20, aggregator: 30 - self.assertEqual(result.value, 30) - - def test_class_based_deps_with_instance_field_access(self): - """Test that __deps__ can access instance fields for configurable transforms. - - This is the key advantage of class-based models over @Flow.model: - transforms can use instance fields like window size. 
- """ - - @Flow.model - def data_source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - class Consumer(CallableModel): - source: DepOf[..., GenericResult[int]] - lookback: int = 5 # Configurable instance field - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=resolve(self.source).value * 2) - - @Flow.deps - def __deps__(self, context: SimpleContext): - # Access self.lookback in transform - this is why we use class-based! - transformed = SimpleContext(value=context.value + self.lookback) - return [(self.source, [transformed])] - - src = data_source() - consumer = Consumer(source=src, lookback=10) - - result = consumer(SimpleContext(value=5)) - # transformed: 5 + 10 = 15 - # source: 15 - # consumer: 15 * 2 = 30 - self.assertEqual(result.value, 30) - - def test_class_based_with_direct_value(self): - """Test that DepOf fields can accept pre-resolved values.""" - - class Consumer(CallableModel): - source: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - # resolve() passes through non-CallableModel values unchanged - return GenericResult(value=resolve(self.source).value + context.value) - - @Flow.deps - def __deps__(self, context: SimpleContext): - # No deps when source is already resolved - return [] - - # Pass direct value instead of CallableModel - consumer = Consumer(source=GenericResult(value=100)) - - result = consumer(SimpleContext(value=5)) - self.assertEqual(result.value, 105) - - def test_class_based_no_double_call(self): - """Test that dependencies are not called twice during DepOf resolution. - - This verifies that the auto-resolution mechanism doesn't accidentally - evaluate the same dependency multiple times. 
- """ - call_counts = {"source": 0} - - @Flow.model - def counting_source(context: SimpleContext) -> GenericResult[int]: - call_counts["source"] += 1 - return GenericResult(value=context.value * 10) - - class Consumer(CallableModel): - data: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=resolve(self.data).value + 1) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.data, [context])] - - src = counting_source() - consumer = Consumer(data=src) - - # Call consumer - source should only be called once - result = consumer(SimpleContext(value=5)) - - self.assertEqual(result.value, 51) # 5 * 10 + 1 - self.assertEqual(call_counts["source"], 1, "Source should only be called once") - - def test_class_based_nested_depof_no_double_call(self): - """Test nested DepOf chain (A -> B -> C) has no double-calls at any layer. - - This tests a 3-layer dependency chain where: - - layer_c is the leaf (no dependencies) - - layer_b depends on layer_c - - layer_a depends on layer_b - - Each layer should be called exactly once. 
- """ - call_counts = {"layer_a": 0, "layer_b": 0, "layer_c": 0} - - # Layer C: leaf node (no dependencies) - @Flow.model - def layer_c(context: SimpleContext) -> GenericResult[int]: - call_counts["layer_c"] += 1 - return GenericResult(value=context.value) - - # Layer B: depends on layer_c - class LayerB(CallableModel): - source: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - call_counts["layer_b"] += 1 - return GenericResult(value=resolve(self.source).value * 10) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.source, [context])] - - # Layer A: depends on layer_b - class LayerA(CallableModel): - source: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - call_counts["layer_a"] += 1 - return GenericResult(value=resolve(self.source).value + 1) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.source, [context])] - - # Build the chain: A -> B -> C - c = layer_c() - b = LayerB(source=c) - a = LayerA(source=b) - - # Call layer_a - each layer should be called exactly once - result = a(SimpleContext(value=5)) - - # Verify result: C returns 5, B returns 5*10=50, A returns 50+1=51 - self.assertEqual(result.value, 51) - - # Verify each layer called exactly once - self.assertEqual(call_counts["layer_c"], 1, "layer_c should be called exactly once") - self.assertEqual(call_counts["layer_b"], 1, "layer_b should be called exactly once") - self.assertEqual(call_counts["layer_a"], 1, "layer_a should be called exactly once") - - def test_resolve_direct_value_passthrough(self): - """Test that resolve() passes through non-CallableModel values unchanged.""" - - class Consumer(CallableModel): - data: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - # resolve() should return the GenericResult directly (pass-through) - resolved = 
resolve(self.data) - # Verify it's the actual GenericResult, not a CallableModel - assert isinstance(resolved, GenericResult) - return GenericResult(value=resolved.value * 2) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [] - - # Pass a direct value, not a CallableModel - direct_result = GenericResult(value=42) - consumer = Consumer(data=direct_result) - - result = consumer(SimpleContext(value=5)) - self.assertEqual(result.value, 84) # 42 * 2 - - def test_resolve_outside_call_raises_error(self): - """Test that resolve() raises RuntimeError when called outside __call__.""" - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - class Consumer(CallableModel): - data: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=resolve(self.data).value) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.data, [context])] - - src = source() - consumer = Consumer(data=src) - - # Calling resolve() outside of __call__ should raise RuntimeError - with self.assertRaises(RuntimeError) as cm: - resolve(consumer.data) - - self.assertIn("resolve() can only be used inside __call__", str(cm.exception)) - - def test_flow_model_uses_unified_resolution_path(self): - """Test that @Flow.model uses the same resolution path as class-based CallableModel. - - This verifies the consolidation of resolution logic - both @Flow.model and - class-based models should use _resolve_deps_and_call in callable.py. 
- """ - call_counts = {"source": 0, "decorator_model": 0, "class_model": 0} - - @Flow.model - def shared_source(context: SimpleContext) -> GenericResult[int]: - call_counts["source"] += 1 - return GenericResult(value=context.value * 2) - - # @Flow.model consumer - @Flow.model - def decorator_consumer( - context: SimpleContext, - data: DepOf[..., GenericResult[int]], - ) -> GenericResult[int]: - call_counts["decorator_model"] += 1 - return GenericResult(value=data.value + 100) - - # Class-based consumer (same logic) - class ClassConsumer(CallableModel): - data: DepOf[..., GenericResult[int]] - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - call_counts["class_model"] += 1 - return GenericResult(value=resolve(self.data).value + 100) - - @Flow.deps - def __deps__(self, context: SimpleContext): - return [(self.data, [context])] - - # Test both consumers with the same source - src = shared_source() - dec_consumer = decorator_consumer(data=src) - cls_consumer = ClassConsumer(data=src) - - ctx = SimpleContext(value=10) - - # Both should produce the same result - dec_result = dec_consumer(ctx) - cls_result = cls_consumer(ctx) - - self.assertEqual(dec_result.value, cls_result.value) - self.assertEqual(dec_result.value, 120) # 10 * 2 + 100 - - # Source should be called exactly twice (once per consumer) - self.assertEqual(call_counts["source"], 2) - self.assertEqual(call_counts["decorator_model"], 1) - self.assertEqual(call_counts["class_model"], 1) - - # ============================================================================= # Lazy[T] Type Annotation Tests # ============================================================================= @@ -2128,8 +1467,8 @@ def consumer( self.assertEqual(result.value, 42) self.assertEqual(call_counts["source"], 0) - def test_lazy_with_depof(self): - """Lazy[DepOf[...]] works: lazy dep with explicit DepOf annotation.""" + def test_lazy_with_upstream_model(self): + """Lazy[T] works when bound to an 
upstream model.""" from ccflow import Lazy @Flow.model @@ -2139,7 +1478,7 @@ def source(context: SimpleContext) -> GenericResult[int]: @Flow.model def consumer( context: SimpleContext, - data: Lazy[DepOf[..., GenericResult[int]]], + data: Lazy[GenericResult[int]], ) -> GenericResult[int]: return GenericResult(value=data() + 1) # data() returns unwrapped int diff --git a/ccflow/tests/test_flow_model_hydra.py b/ccflow/tests/test_flow_model_hydra.py index 661ac4f..28f2883 100644 --- a/ccflow/tests/test_flow_model_hydra.py +++ b/ccflow/tests/test_flow_model_hydra.py @@ -124,7 +124,14 @@ def test_context_args_from_yaml(self): ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) result = loader(ctx) - self.assertEqual(result.value, "data_source:2024-01-01 to 2024-01-31") + self.assertEqual( + result.value, + { + "source": "data_source", + "start_date": "2024-01-01", + "end_date": "2024-01-31", + }, + ) def test_context_args_pipeline_from_yaml(self): """Test context_args pipeline with dependencies from YAML.""" @@ -423,12 +430,9 @@ def test_transform_applied_from_yaml(self): self.assertEqual(len(deps), 1) dep_model, dep_contexts = deps[0] - # The transform should extend start_date back by one day - transformed_ctx = dep_contexts[0] - self.assertEqual(transformed_ctx.start_date, date(2024, 1, 9)) - self.assertEqual(transformed_ctx.end_date, date(2024, 1, 31)) - self.assertIs(dep_model, r["flow_date_loader"]) + self.assertEqual(dep_contexts[0], ctx) + self.assertEqual(dep_model(ctx).value["start_date"], "2024-01-09") if __name__ == "__main__": diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md index 909b597..dca3fbe 100644 --- a/docs/design/flow_model_design.md +++ b/docs/design/flow_model_design.md @@ -1,501 +1,245 @@ -# Flow.model and DepOf: Dependency Injection for CallableModel +# Flow.model Design ## Overview -This document describes the `@Flow.model` decorator and `DepOf` annotation system for reducing 
boilerplate when creating `CallableModel` pipelines with dependencies. +`@Flow.model` turns a plain Python function into a real `CallableModel`. -**Key features:** -- `@Flow.model` - Decorator that generates `CallableModel` classes from plain functions -- `FlowContext` - Universal context carrier for unpacked/deferred execution -- `model.flow.compute(...)` / `model.flow.with_inputs(...)` - Deferred execution helpers -- `DepOf[ContextType, ResultType]` - Type annotation for dependency fields -- `Lazy[T]` - Mark a dependency for lazy, on-demand evaluation -- `FieldExtractor` - Access structured outputs via attribute access on generated models -- `resolve()` - Function to access resolved dependency values in class-based models +The core goals are: -## Quick Start +- keep the authoring model close to an ordinary function, +- preserve the existing evaluator / registry / serialization machinery, +- make deferred execution explicit with `.flow.compute(...)` and `.flow.with_inputs(...)`, +- allow callers to pass either literal values or upstream models for ordinary parameters. -### Pattern 1: `@Flow.model` (Recommended for Declarative Cases) +`@Flow.model` is syntactic sugar over the existing ccflow framework. The +generated object is still a standard `CallableModel`, so you can execute it the +same way as any other model by calling it with a context object. The +`.flow.compute(...)` helper is an explicit, ergonomic way to mark the deferred +execution boundary when supplying runtime inputs as keyword arguments. 
-```python -from datetime import date, timedelta -from typing import Annotated - -from ccflow import Flow, DateRangeContext, GenericResult, Dep, DepOf - - -def previous_window(ctx: DateRangeContext) -> DateRangeContext: - window = ctx.end_date - ctx.start_date - return ctx.model_copy( - update={ - "start_date": ctx.start_date - window - timedelta(days=1), - "end_date": ctx.start_date - timedelta(days=1), - } - ) +## Core Patterns -@Flow.model -def load_revenue(context: DateRangeContext, region: str) -> GenericResult[float]: - return GenericResult(value=125.0) +### Default Deferred Style -@Flow.model -def revenue_growth( - context: DateRangeContext, - current: DepOf[..., GenericResult[float]], - previous: Annotated[GenericResult[float], Dep(transform=previous_window)], -) -> GenericResult[dict]: - growth = (current.value - previous.value) / previous.value - return GenericResult(value={"as_of": context.end_date, "growth": growth}) - -# Build pipeline. The same upstream model is reused twice: -# - once with the original context -# - once with a fixed lookback transform -revenue = load_revenue(region="us") -growth = revenue_growth(current=revenue, previous=revenue) - -# Execute -ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) -result = growth(ctx) -``` - -### Pattern 2: Class-Based (For Complex Cases) - -Use class-based when you need **configurable transforms** that depend on instance fields: +This is the most ergonomic mode. Bind some parameters up front, then provide +the remaining runtime inputs later. 
```python -from datetime import timedelta - -from ccflow import CallableModel, DateRangeContext, Flow, GenericResult, DepOf -from ccflow.callable import resolve # Import resolve for class-based models - -class RevenueAverageWithWindow(CallableModel): - """Aggregate revenue with a configurable lookback window.""" - - revenue: DepOf[..., GenericResult[float]] - window: int = 7 # Configurable instance field - - @Flow.call - def __call__(self, context: DateRangeContext) -> GenericResult[float]: - # Use resolve() to get the resolved value - revenue = resolve(self.revenue) - return GenericResult(value=revenue.value / self.window) - - @Flow.deps - def __deps__(self, context: DateRangeContext): - # Transform uses self.window - this is why we need class-based! - lookback_ctx = context.model_copy( - update={"start_date": context.start_date - timedelta(days=self.window)} - ) - return [(self.revenue, [lookback_ctx])] - -# Usage - different window sizes, same source -loader = load_revenue(region="us") -avg_7 = RevenueAverageWithWindow(revenue=loader, window=7) -avg_30 = RevenueAverageWithWindow(revenue=loader, window=30) -``` +from ccflow import Flow, FlowContext -## When to Use Which Pattern -| Use `@Flow.model` when... | Use Class-Based when... 
| -|--------------------------------|---------------------------------------| -| The node still reads like a normal function | The main value is custom graph logic | -| Transforms are fixed/declarative | Transforms depend on instance fields | -| Less boilerplate is priority | You need full control over `__deps__` | -| Dependency wiring fits in the signature | Dependency behavior deserves its own class | +@Flow.model +def add(x: int, y: int) -> int: + return x + y -## Core Concepts -### `DepOf[ContextType, ResultType]` +model = add(x=10) -Shorthand for declaring dependency fields that can accept either: -- A pre-computed value of `ResultType` -- A `CallableModel` that produces `ResultType` +# Explicit deferred entry point +assert model.flow.compute(y=5) == 15 -```python -# Inherit context type from parent model -data: DepOf[..., GenericResult[dict]] +# Standard CallableModel call path +assert model(FlowContext(y=5)).value == 15 -# Explicit context type -data: DepOf[DateRangeContext, GenericResult[dict]] - -# Equivalent to: -data: Annotated[Union[GenericResult[dict], CallableModel], Dep()] +shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) +assert shifted.flow.compute(y=5) == 20 ``` -For `@Flow.model`, plain non-`DepOf` parameters can also be populated with a -`CallableModel` instance. That lets callers either inject a concrete value or -splice in an upstream computation for the same parameter. Use `Dep`/`DepOf` -when you need explicit dependency metadata such as context transforms or -context-type validation. - -That means `DepOf` inside `@Flow.model` is most compelling when the function is -still doing real work and the dependency relationship is simple. If the node is -mostly a vessel for custom dependency graph wiring, a hand-written -`CallableModel` is usually clearer. +In this mode: -### `Dep(transform=..., context_type=...)` +- bound parameters are model configuration, +- unbound parameters become runtime inputs for that model instance. 
-For transforms, use the full `Annotated` form: +### Explicit Context Parameter ```python -from ccflow import Dep +from ccflow import DateRangeContext, Flow + @Flow.model -def compute_stats( - context: DateRangeContext, - records: Annotated[GenericResult[dict], Dep( - transform=lambda ctx: ctx.model_copy( - update={"start_date": ctx.start_date - timedelta(days=1)} - ) - )], -) -> GenericResult[float]: - return GenericResult(value=records.value["count"] * 0.05) +def load_revenue(context: DateRangeContext, region: str) -> float: + return 125.0 ``` -### `resolve()` Function +This is the most direct mode. The function receives a normal context object and +returns either a `ResultBase` subclass or a plain value. Plain values are +wrapped into `GenericResult` automatically by the generated model. -**Only needed for class-based models.** Accesses the resolved value of a `DepOf` field during `__call__`. +### `context_args` ```python -from ccflow.callable import resolve +from datetime import date -class MyModel(CallableModel): - data: DepOf[..., GenericResult[int]] +from ccflow import Flow - @Flow.call - def __call__(self, context: MyContext) -> GenericResult[int]: - # resolve() returns the GenericResult, not the CallableModel - result = resolve(self.data) - return GenericResult(value=result.value + 1) -``` -**Behavior:** -- Inside `__call__`: Returns the resolved value -- With direct values (not CallableModel): Returns unchanged (no-op) -- Outside `__call__`: Raises `RuntimeError` -- In `@Flow.model`: Not needed - values are passed as function arguments - -**Type inference:** -```python -data: DepOf[..., GenericResult[int]] -resolved = resolve(self.data) # Type: GenericResult[int] +@Flow.model(context_args=["start_date", "end_date"]) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + return 125.0 ``` -## How Resolution Works - -### `@Flow.model` Resolution Flow - -1. User calls `model(context)` -2. 
Generated `__call__` invokes `_resolve_deps_and_call()` -3. For each dependency-bearing field containing a `CallableModel`: - - Apply transform (if any) - - Call the dependency - - Store resolved value in context variable -4. Generated `__call__` reads the resolved values from the dependency store -5. Original function receives resolved values directly as normal function arguments - -### Class-Based Resolution Flow - -1. User calls `model(context)` -2. `_resolve_deps_and_call()` runs -3. For each `DepOf` field containing a `CallableModel`: - - Check `__deps__` for custom transforms - - If not listed in `__deps__`, fall back to the field's `Dep(...)` transform (or the original context) - - Call the dependency - - Store resolved value in context variable -4. User's `__call__` accesses values via `resolve(self.field)` - -**Important:** Resolution uses a context variable (`contextvars.ContextVar`), making it thread-safe and async-safe. - -## Design Decisions - -### Decision 1: `resolve()` Instead of Temporary Mutation +This keeps the function signature focused on the inputs it actually uses while +still producing a `CallableModel` that accepts a context at runtime. -**What we chose:** Explicit `resolve()` function with context variables. +Use `context_args` when certain parameters are semantically the execution +context and you want that split to be explicit and stable across model +instances. -**Alternative considered:** Temporarily mutate `self.field` during `__call__` to hold the resolved value, then restore after. +When the requested shape matches a built-in context like +`DateRangeContext(start_date, end_date)`, the generated model uses that type. +Otherwise it falls back to `FlowContext`. 
-**Why we chose this:** -- No mutation of model state -- Thread/async-safe via contextvars -- Explicit about what's happening -- Easier to debug - `self.field` always shows the original value +### Upstream Models as Normal Arguments -**Trade-off:** Slightly more verbose (`resolve(self.data).value` vs `self.data.value`). +Any non-context parameter can be given either: -### Decision 2: Unified Resolution Path +- a literal value, or +- another `CallableModel` / `BoundModel`. -**What we chose:** Both `@Flow.model` and class-based use the same `_resolve_deps_and_call()` function. +If a model is passed, it is evaluated with the current context and its result is +unwrapped before the function is called. -**Why:** -- Single source of truth for resolution logic -- Easier to maintain -- Consistent behavior across patterns - -### Decision 3: `resolve()` Not in Top-Level `__all__` - -**What we chose:** `resolve` must be imported explicitly: `from ccflow.callable import resolve` - -**Why:** -- Only needed for class-based models with `DepOf` -- Keeps top-level namespace clean -- Users who need it can find it easily +```python +from ccflow import DateRangeContext, Flow -### Decision 4: Auto-Wrap Plain Return Values -**What we chose:** If the function's declared return type is not a `ResultBase` -subclass, the generated model wraps the returned value in `GenericResult`. +@Flow.model +def load_revenue(context: DateRangeContext, region: str) -> float: + return 125.0 -**Why:** -- Reduces boilerplate for simple scalar / container-returning functions -- Preserves the `CallableModel` contract that runtime results are `ResultBase` -- Still allows explicit `ResultBase` subclasses when you want a precise result type -**Trade-off:** The original Python function may be annotated with a plain value -type while the generated model's runtime `result_type` is `GenericResult`. 
+@Flow.model +def double_revenue(_: DateRangeContext, revenue: float) -> float: + return revenue * 2 -### Decision 5: Generated Classes Are Real CallableModels -**What we chose:** Generate actual `CallableModel` subclasses using `type()`. +revenue = load_revenue(region="us") +model = double_revenue(revenue=revenue) +result = model.flow.compute(start_date="2024-01-01", end_date="2024-01-31") +``` -**Why:** -- Full compatibility with existing infrastructure -- Caching, registry, serialization work unchanged -- Can mix with hand-written classes +This is the main composition story for the core API. -## Pitfalls and Limitations +### `.flow.with_inputs(...)` -### Pitfall 1: Forgetting `resolve()` in Class-Based Models +`with_inputs` is how a caller rewires context locally for one upstream model. ```python -class MyModel(CallableModel): - data: DepOf[..., GenericResult[int]] +from datetime import date, timedelta - @Flow.call - def __call__(self, context): - # WRONG - self.data is still the CallableModel! - return GenericResult(value=self.data.value + 1) +from ccflow import DateRangeContext, Flow - # CORRECT - return GenericResult(value=resolve(self.data).value + 1) -``` -**Error you'll see:** `AttributeError: '_SomeModel' object has no attribute 'value'` +@Flow.model(context_args=["start_date", "end_date"]) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + days = (end_date - start_date).days + 1 + return 1000.0 + days * 10.0 -### Pitfall 2: Calling `resolve()` Outside `__call__` -```python -model = MyModel(data=some_source()) -resolve(model.data) # RuntimeError! -``` +@Flow.model(context_args=["start_date", "end_date"]) +def revenue_growth(start_date: date, end_date: date, current: float, previous: float) -> dict: + return { + "window_end": end_date, + "growth_pct": round((current - previous) / previous * 100, 2), + } -`resolve()` only works during `__call__` execution. 
-### Pitfall 3: Lambda Transforms Don't Serialize +current = load_revenue(region="us") +previous = current.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=30), + end_date=lambda ctx: ctx.end_date - timedelta(days=30), +) -```python -# Won't serialize - lambdas can't be pickled -Dep(transform=lambda ctx: ctx.model_copy(...)) +model = revenue_growth(current=current, previous=previous) +ctx = DateRangeContext( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), +) -# Will serialize - use named functions -def shift_start(ctx): - return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=1)}) +direct = model(ctx).value +computed = model.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), +) -Dep(transform=shift_start) +assert direct == computed ``` -### Pitfall 4: GraphEvaluator Requires Caching - -When using `GraphEvaluator` with `DepOf`, dependencies may be called twice (once by GraphEvaluator, once by resolution) unless caching is enabled. - -```python -# Use with caching -from ccflow.evaluators import GraphEvaluator, CachingEvaluator, MultiEvaluator +The transform is local to the bound upstream model. The parent model continues +to receive the original context. -evaluator = MultiEvaluator(evaluators=[ - CachingEvaluator(), - GraphEvaluator(), -]) -``` +### `.flow.compute(...)` -### Pitfall 5: Two Mental Models +`compute` is the ergonomic entry point for deferred execution: -Users need to remember: -- `@Flow.model`: Use dependency values directly as function arguments -- Class-based: Use `resolve(self.field)` to access values +```python +from ccflow import Flow -### Limitation: Custom `__deps__` Is Only Needed for Custom Graph Logic -Class-based models do not need a custom `__deps__` override when the default -field-level `Dep(...)` behavior is sufficient. 
Override `__deps__` only when -you need instance-dependent transforms or a custom dependency graph: +@Flow.model +def add(x: int, y: int) -> int: + return x + y -```python -class Consumer(CallableModel): - data: DepOf[..., GenericResult[int]] - @Flow.call - def __call__(self, context): - return GenericResult(value=resolve(self.data).value) +model = add(x=10) +assert model.flow.compute(y=5) == 15 ``` -If you do need to use instance fields in the transform, then `__deps__` is the -right place to do it: +It validates the supplied keyword arguments against the generated context +schema, creates a `FlowContext`, executes the model, and unwraps +`GenericResult.value` if needed. -```python -class WindowedConsumer(CallableModel): - data: DepOf[..., GenericResult[int]] - window: int = 7 - - @Flow.call - def __call__(self, context): - return GenericResult(value=resolve(self.data).value) - - @Flow.deps - def __deps__(self, context): - shifted = context.model_copy(update={"value": context.value + self.window}) - return [(self.data, [shifted])] -``` +It is not the only execution path. Because the generated object is still a +standard `CallableModel`, calling `model(context)` remains fully supported. -### Limitation: `context_args` Type Matching Is Best-Effort +## Lazy Inputs -When you use `context_args=[...]`, the framework validates those fields via a -runtime `TypedDict` schema. It only maps to a concrete built-in context type in -special cases such as `DateRangeContext`. Otherwise the generated model's -`context_type` is `FlowContext`, a universal frozen carrier for the validated -context values. - -## Complete Example: Multi-Stage Pipeline +`Lazy[T]` marks a parameter as on-demand. Instead of eagerly resolving an +upstream model, the generated model passes a zero-argument thunk. The thunk +caches its first result. 
```python -from datetime import date, timedelta -from typing import Annotated - -from ccflow import ( - CallableModel, DateRangeContext, Dep, DepOf, - Flow, GenericResult -) -from ccflow.callable import resolve +from ccflow import Flow, Lazy -# Stage 1: Data loader (simple, use @Flow.model) @Flow.model -def load_events(context: DateRangeContext, source: str) -> GenericResult[list]: - print(f"Loading from {source} for {context.start_date} to {context.end_date}") - return GenericResult(value=[ - {"date": str(context.start_date), "count": 100 + i} - for i in range(5) - ]) +def source(value: int) -> int: + return value * 10 -# Stage 2: Transform with fixed lookback (use @Flow.model with Dep transform) @Flow.model -def compute_daily_totals( - context: DateRangeContext, - events: Annotated[GenericResult[list], Dep( - transform=lambda ctx: ctx.model_copy( - update={"start_date": ctx.start_date - timedelta(days=1)} - ) - )], -) -> GenericResult[float]: - values = [e["count"] for e in events.value] - total = sum(values) / len(values) if values else 0 - return GenericResult(value=total) - - -# Stage 3: Configurable window (use class-based) -class ComputeRollingSummary(CallableModel): - """Summary with configurable lookback window.""" - - totals: DepOf[..., GenericResult[float]] - window: int = 20 - - @Flow.call - def __call__(self, context: DateRangeContext) -> GenericResult[float]: - totals = resolve(self.totals) - # Scale by window size - summary = totals.value * (self.window ** 0.5) - return GenericResult(value=summary) - - @Flow.deps - def __deps__(self, context: DateRangeContext): - lookback = context.model_copy( - update={"start_date": context.start_date - timedelta(days=self.window)} - ) - return [(self.totals, [lookback])] - - -# Build pipeline -events = load_events(source="main_db") -totals = compute_daily_totals(events=events) -summary_20 = ComputeRollingSummary(totals=totals, window=20) -summary_60 = ComputeRollingSummary(totals=totals, window=60) - -# Execute 
-ctx = DateRangeContext(start_date=date(2024, 1, 15), end_date=date(2024, 1, 31)) -print(f"20-day summary: {summary_20(ctx).value}") -print(f"60-day summary: {summary_60(ctx).value}") +def maybe_use_source(value: int, data: Lazy[int]) -> int: + if value > 10: + return value + return data() ``` -## API Reference - -### `@Flow.model` - -```python -@Flow.model( - context_args: list[str] = None, # Unpack context fields as function args - cacheable: bool = False, - volatile: bool = False, - log_level: int = logging.DEBUG, - validate_result: bool = True, - verbose: bool = True, - evaluator: EvaluatorBase = None, -) -def my_function(context: ContextType, ...) -> ResultType: - ... -``` - -If the function is annotated with a plain value type instead of a `ResultBase` -subclass, the generated model will wrap the returned value in `GenericResult` -at runtime. - -### `DepOf[ContextType, ResultType]` +## FlowContext -```python -# Inherit context from parent -field: DepOf[..., GenericResult[int]] +`FlowContext` is the universal frozen carrier for generated contexts that do +not map to a dedicated built-in context type. -# Explicit context type -field: DepOf[DateRangeContext, GenericResult[int]] -``` +The implementation stays intentionally small: -### `Dep(transform=..., context_type=...)` +- context validation is driven by `TypedDict` + `TypeAdapter`, +- runtime execution uses one reusable `FlowContext` type, +- public pydantic iteration (`dict(context)`) is used instead of pydantic + internals. -```python -field: Annotated[GenericResult[int], Dep( - transform=my_transform_func, # Optional: (context) -> transformed_context - context_type=DateRangeContext, # Optional: Expected context type -)] -``` +## BoundModel -### `resolve(dep)` +`.flow.with_inputs(...)` returns a `BoundModel`, which is just a thin wrapper +around: -```python -from ccflow.callable import resolve +- the original model, and +- a mapping of input transforms. 
-# Inside __call__ of class-based CallableModel: -resolved_value = resolve(self.dep_field) +At call time it: -# Type signature: -def resolve(dep: Union[T, CallableModel]) -> T: ... -``` - -## File Structure +1. converts the incoming context into a plain dictionary, +1. applies the configured transforms, +1. rebuilds a `FlowContext`, +1. delegates to the wrapped model. -``` -ccflow/ -├── callable.py # CallableModel, Flow, resolve(), _resolve_deps_and_call() -├── dep.py # Dep, DepOf, extract_dep() -├── flow_model.py # @Flow.model implementation -└── tests/ - └── test_flow_model.py # Comprehensive tests -``` +That keeps transformed dependency wiring explicit without adding special +annotation machinery to the core API. diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index f73ac6b..5f1b502 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -24,7 +24,19 @@ As an example, you may have a `SQLReader` callable model that when called with a ### Flow.model Decorator -The `@Flow.model` decorator provides a simpler way to define `CallableModel`s using plain Python functions instead of classes. It automatically generates a `CallableModel` class with proper `__call__` and `__deps__` methods. +The `@Flow.model` decorator provides a simpler way to define `CallableModel`s +using plain Python functions instead of classes. It automatically generates a +standard `CallableModel` class with proper `__call__` and `__deps__` methods, +so it still uses the normal ccflow framework for evaluation, caching, +serialization, and registry loading. + +You can execute a generated model in two equivalent ways: + +- call it directly with a context object: `model(ctx)` +- use `.flow.compute(...)` to supply runtime inputs as keyword arguments + +`.flow.compute(...)` is mainly an explicit, ergonomic way to mark the deferred +execution point. 
**Basic Example:** @@ -45,79 +57,130 @@ ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) result = loader(ctx) ``` -**Composing Dependencies with `Dep` and `DepOf`:** +**Default `@Flow.model` Style:** + +Use this when you want the simplest API and do not need to declare a formal +context shape up front. + +```python +from ccflow import Flow + +@Flow.model +def add(x: int, y: int) -> int: + return x + y + +model = add(x=10) + +# `x` is bound when the model is created. +# `y` is supplied later at execution time. +assert model.flow.compute(y=5) == 15 + +# `.flow.with_inputs(...)` rewrites runtime inputs for this call path. +doubled_y = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) +assert doubled_y.flow.compute(y=5) == 20 +``` -Use `Dep()` or `DepOf` to mark parameters that accept other `CallableModel`s as dependencies. The framework automatically resolves the dependency graph. +In this mode: -For `@Flow.model`, regular parameters can also accept a `CallableModel` value at -construction time. This lets you either inject a literal value or splice in an -upstream computation for the same parameter. Use `Dep`/`DepOf` when you need -context transforms or explicit dependency metadata. +- bound parameters are model configuration +- unbound parameters are runtime inputs for that model instance -> **Rule of thumb:** `@Flow.model` works best when the dependency wiring is declarative and local to the signature. If the main point of the node is custom graph logic or transforms that depend on instance fields, use a class-based `CallableModel` instead. +**Composing Dependencies with Normal Parameters:** + +Any non-context parameter can be bound either to a literal value or to another +`CallableModel`. If you pass an upstream model, `@Flow.model` evaluates it with +the current context and passes the resolved value into your function. 
```python from datetime import date, timedelta -from typing import Annotated -from ccflow import Flow, GenericResult, DateRangeContext, Dep, DepOf - -def previous_window(ctx: DateRangeContext) -> DateRangeContext: - window = ctx.end_date - ctx.start_date - return ctx.model_copy( - update={ - "start_date": ctx.start_date - window - timedelta(days=1), - "end_date": ctx.start_date - timedelta(days=1), - } - ) +from ccflow import DateRangeContext, Flow -@Flow.model -def load_revenue(context: DateRangeContext, region: str) -> GenericResult[float]: - # Pretend this queries a warehouse - return GenericResult(value=125.0) +@Flow.model(context_args=["start_date", "end_date"]) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + days = (end_date - start_date).days + 1 + return 1000.0 + days * 10.0 -@Flow.model +@Flow.model(context_args=["start_date", "end_date"]) def revenue_growth( - context: DateRangeContext, - current: DepOf[..., GenericResult[float]], - previous: Annotated[GenericResult[float], Dep(transform=previous_window)], -) -> GenericResult[dict]: - growth = (current.value - previous.value) / previous.value - return GenericResult(value={"as_of": context.end_date, "growth": growth}) - -# Build the pipeline. The same loader is reused with two contexts: -# - current window: original context -# - previous window: transformed via Dep(transform=...) 
-revenue = load_revenue(region="us") -growth = revenue_growth(current=revenue, previous=revenue) - -ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) -result = growth(ctx) + start_date: date, + end_date: date, + current: float, + previous: float, +) -> dict: + return { + "window_end": end_date, + "growth_pct": round((current - previous) / previous * 100, 2), + } + +current = load_revenue(region="us") +previous = current.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=30), + end_date=lambda ctx: ctx.end_date - timedelta(days=30), +) +growth = revenue_growth(current=current, previous=previous) + +ctx = DateRangeContext( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), +) + +# Standard ccflow execution +direct = growth(ctx).value + +# Equivalent explicit deferred entry point +computed = growth.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), +) + +assert direct == computed ``` -`DepOf` is also useful when you want the same parameter to accept either an -upstream model or a precomputed value: +This pattern is the main story for transformed dependencies. `@Flow.model` +still produces an ordinary `CallableModel`; `.flow.compute(...)` is just a +clearer way to say "supply the runtime inputs here." + +**Why `context_args` Exists:** + +Without `context_args`, runtime inputs are inferred from whichever parameters +are still unbound on a particular model instance. That is flexible and +ergonomic. 
+ +Use `context_args` when some parameters are semantically "the execution +context" and you want that split to stay stable and explicit: + +- the runtime context should be stable across instances +- the split between config and runtime inputs matters semantically +- the model is naturally "run over a context" such as date windows, + partitions, or scenarios +- you want the generated model to match a built-in context type like + `DateRangeContext` when possible + +**Deferred Execution Helpers:** ```python -from ccflow import DateRangeContext, DepOf, Flow, GenericResult +from ccflow import Flow @Flow.model -def load_signal(context: DateRangeContext, source: str) -> GenericResult[float]: - return GenericResult(value=0.87) +def add(x: int, y: int) -> int: + return x + y -@Flow.model -def publish_signal( - context: DateRangeContext, - signal: DepOf[..., GenericResult[float]], - threshold: float = 0.8, -) -> GenericResult[dict]: - return GenericResult(value={ - "as_of": context.end_date, - "signal": signal.value, - "go_live": signal.value >= threshold, - }) - -live = publish_signal(signal=load_signal(source="prod")) -override = publish_signal(signal=GenericResult(value=0.95), threshold=0.9) +model = add(x=10) +assert model.flow.compute(y=5) == 15 + +shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) +assert shifted.flow.compute(y=5) == 20 +``` + +If you already have a real context object, you can call the model directly +instead: + +```python +from ccflow import FlowContext + +ctx = FlowContext(y=5) +assert model(ctx).value == 15 +assert shifted(ctx).value == 20 ``` **Hydra/YAML Configuration:** @@ -166,27 +229,6 @@ result = loader(ctx) # "my_database:2024-01-01 to 2024-01-31" The `context_args` parameter specifies which function parameters should be extracted from the context. Those fields are validated through a runtime schema built from the parameter annotations. 
For well-known shapes such as `start_date` / `end_date`, the generated model uses a concrete built-in context type like `DateRangeContext`; otherwise it uses `FlowContext`, a universal frozen carrier for the validated fields. -**Deferred Execution Helpers:** - -Generated models also expose a `.flow` helper namespace: - -```python -from ccflow import Flow, GenericResult - -@Flow.model -def add(x: int, y: int) -> GenericResult[int]: - return GenericResult(value=x + y) - -model = add(x=10) - -# Validate and execute by passing context fields as kwargs -assert model.flow.compute(y=5) == 15 - -# Derive a new model by transforming context inputs -shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) -assert shifted.flow.compute(y=5) == 20 -``` - If a `@Flow.model` function returns a plain value instead of a `ResultBase` subclass, the generated model automatically wraps that value in `GenericResult` at runtime so it still behaves like a normal `CallableModel`. diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py index e93d452..f2616dc 100644 --- a/examples/flow_model_example.py +++ b/examples/flow_model_example.py @@ -1,220 +1,64 @@ #!/usr/bin/env python -"""Example demonstrating Flow.model decorator and class-based CallableModel. - -This example shows: -- Flow.model for simple functions with minimal boilerplate -- Context transforms with Dep annotations -- Class-based CallableModel for complex cases needing instance field access -""" +"""Example demonstrating the core Flow.model workflow.""" from datetime import date, timedelta -from typing import Annotated - -from ccflow import CallableModel, DateRangeContext, Dep, DepOf, Flow, GenericResult -from ccflow.callable import resolve - - -# ============================================================================= -# Example 1: Basic Flow.model - No more boilerplate classes! 
-# ============================================================================= - -@Flow.model -def load_records(context: DateRangeContext, source: str, limit: int = 100) -> GenericResult[list]: - """Load records from a data source for the given date range.""" - print(f" Loading from '{source}' for {context.start_date} to {context.end_date} (limit={limit})") - return GenericResult(value=[ - {"id": i, "date": str(context.start_date), "value": i * 10} - for i in range(min(limit, 5)) - ]) - - -# ============================================================================= -# Example 2: Dependencies with DepOf - Automatic dependency resolution -# ============================================================================= - -@Flow.model -def compute_totals( - _: DateRangeContext, # Context passed to dependency, not used directly here - records: DepOf[..., GenericResult[list]], -) -> GenericResult[dict]: - """Compute totals from loaded records.""" - total = sum(r["value"] for r in records.value) - count = len(records.value) - print(f" Computing totals: {count} records, total={total}") - return GenericResult(value={"total": total, "count": count}) - - -# ============================================================================= -# Example 3: Simple Transform with Flow.model -# When the transform is a fixed function, Flow.model works great -# ============================================================================= - -def lookback_7_days(ctx: DateRangeContext) -> DateRangeContext: - """Fixed transform that extends the date range back by 7 days.""" - return ctx.model_copy(update={"start_date": ctx.start_date - timedelta(days=7)}) - - -@Flow.model -def compute_weekly_average( - _: DateRangeContext, - records: Annotated[GenericResult[list], Dep(transform=lookback_7_days)], -) -> GenericResult[float]: - """Compute average using fixed 7-day lookback.""" - values = [r["value"] for r in records.value] - avg = sum(values) / len(values) if values else 0 - print(f" 
Computing weekly average: {avg:.2f} (from {len(values)} records)") - return GenericResult(value=avg) - - -# ============================================================================= -# Example 4: Class-based CallableModel with Configurable Transform -# When the transform needs access to instance fields (like window size), -# use a class-based approach with auto-resolution -# ============================================================================= - -class ComputeMovingAverage(CallableModel): - """Compute moving average with configurable lookback window. - - This demonstrates: - - Field uses DepOf annotation: accepts either result or CallableModel - - Instance field (window) accessible in __deps__ for custom transforms - - resolve() to access resolved dependency values during __call__ - """ - - records: DepOf[..., GenericResult[list]] - window: int = 7 # Configurable lookback window - - @Flow.call - def __call__(self, context: DateRangeContext) -> GenericResult[float]: - """Compute the moving average - use resolve() to get resolved value.""" - records = resolve(self.records) # Get the resolved GenericResult - values = [r["value"] for r in records.value] - avg = sum(values) / len(values) if values else 0 - print(f" Computing {self.window}-day moving average: {avg:.2f} (from {len(values)} records)") - return GenericResult(value=avg) - - @Flow.deps - def __deps__(self, context: DateRangeContext): - """Define dependencies with transform that uses self.window.""" - # This is where we can access instance fields! 
- lookback_ctx = context.model_copy( - update={"start_date": context.start_date - timedelta(days=self.window)} - ) - return [(self.records, [lookback_ctx])] - - -# ============================================================================= -# Example 5: Multi-stage pipeline - Composing models together -# ============================================================================= - -@Flow.model -def generate_report( - context: DateRangeContext, - totals: DepOf[..., GenericResult[dict]], - moving_avg: DepOf[..., GenericResult[float]], - report_name: str = "Daily Report", -) -> GenericResult[str]: - """Generate a report combining multiple data sources.""" - report = f""" -{report_name} -{'=' * len(report_name)} -Date Range: {context.start_date} to {context.end_date} -Total Value: {totals.value['total']} -Record Count: {totals.value['count']} -Moving Avg: {moving_avg.value:.2f} -""" - return GenericResult(value=report.strip()) - - -# ============================================================================= -# Example 6: Using context_args for cleaner signatures -# ============================================================================= + +from ccflow import DateRangeContext, Flow + @Flow.model(context_args=["start_date", "end_date"]) -def fetch_metadata(start_date: date, end_date: date, category: str) -> GenericResult[dict]: - """Fetch metadata - note how start_date/end_date are direct parameters.""" - print(f" Fetching metadata for '{category}' from {start_date} to {end_date}") - return GenericResult(value={ - "category": category, - "days": (end_date - start_date).days, - "generated_at": str(date.today()), - }) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + """Pretend to load revenue for a date window.""" + days = (end_date - start_date).days + 1 + baseline = 1000.0 if region == "us" else 800.0 + return baseline + days * 10.0 + +@Flow.model(context_args=["start_date", "end_date"]) +def summarize_growth(start_date: date, 
end_date: date, current: float, previous: float) -> dict: + """Compare the current and previous windows.""" + growth_pct = round((current - previous) / previous * 100, 2) + return { + "start_date": start_date, + "end_date": end_date, + "current": current, + "previous": previous, + "growth_pct": growth_pct, + } -# ============================================================================= -# Main: Build and execute the pipeline -# ============================================================================= def main(): print("=" * 60) - print("Flow.model Example - Simplified CallableModel Creation") + print("Flow.model Example") print("=" * 60) + current_window = load_revenue(region="us") + previous_window = current_window.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=30), + end_date=lambda ctx: ctx.end_date - timedelta(days=30), + ) + + growth = summarize_growth(current=current_window, previous=previous_window) + ctx = DateRangeContext( - start_date=date(2024, 1, 15), - end_date=date(2024, 1, 31) + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), ) - # --- Example 1: Basic model --- - print("\n[1] Basic Flow.model:") - loader = load_records(source="main_db", limit=5) - result = loader(ctx) - print(f" Result: {result.value}") - - # --- Example 2: Simple dependency chain --- - print("\n[2] Dependency chain (loader -> totals):") - loader = load_records(source="main_db") - totals = compute_totals(records=loader) - result = totals(ctx) - print(f" Result: {result.value}") - - # --- Example 3: Fixed transform with Flow.model --- - print("\n[3] Fixed transform (7-day lookback with Flow.model):") - loader = load_records(source="main_db") - weekly_avg = compute_weekly_average(records=loader) - result = weekly_avg(ctx) - print(f" Result: {result.value}") - - # --- Example 4: Configurable transform with class-based model --- - print("\n[4] Configurable transform (class-based with auto-resolution):") - loader = 
load_records(source="main_db") - - # 14-day window - moving_avg_14 = ComputeMovingAverage(records=loader, window=14) - result = moving_avg_14(ctx) - print(f" 14-day result: {result.value}") - - # 30-day window - same loader, different window - moving_avg_30 = ComputeMovingAverage(records=loader, window=30) - result = moving_avg_30(ctx) - print(f" 30-day result: {result.value}") - - # --- Example 5: Full pipeline --- - print("\n[5] Full pipeline (mixing Flow.model and class-based):") - loader = load_records(source="analytics_db") - totals = compute_totals(records=loader) - moving_avg = ComputeMovingAverage(records=loader, window=7) - report = generate_report( - totals=totals, - moving_avg=moving_avg, - report_name="Analytics Summary" + print("\n[1] Execute as a normal CallableModel:") + print(growth(ctx).value) + + print("\n[2] Execute via .flow.compute(...):") + print( + growth.flow.compute( + start_date=date(2024, 1, 1), + end_date=date(2024, 1, 31), + ) ) - result = report(ctx) - print(result.value) - - # --- Example 6: context_args --- - print("\n[6] Using context_args (auto-unpacked context):") - metadata = fetch_metadata(category="sales") - result = metadata(ctx) - print(f" Result: {result.value}") - - # --- Bonus: Inspecting models --- - print("\n[Bonus] Inspecting models:") - print(f" load_records.context_type = {loader.context_type.__name__}") - print(f" ComputeMovingAverage uses __deps__ for custom transforms") - deps = moving_avg.__deps__(ctx) - for dep_model, dep_contexts in deps: - print(f" - Dependency context start: {dep_contexts[0].start_date} (lookback applied)") + + print("\n[3] Inspect bound and unbound inputs:") + print(" bound_inputs:", growth.flow.bound_inputs) + print(" unbound_inputs:", growth.flow.unbound_inputs) if __name__ == "__main__": From 0c274a1f2d564ba856fcb9981a9c99c75649262f Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Tue, 17 Mar 2026 17:02:40 -0400 Subject: [PATCH 11/26] Clean up more, make repr nicer for BoundModel 
Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 101 ++++++++---- ccflow/flow_model.py | 89 ++++++++--- ccflow/tests/test_flow_context.py | 59 ++++++- docs/design/flow_model_design.md | 60 ++++++- docs/wiki/Key-Features.md | 253 ++++++++++++++++++++++-------- examples/flow_model_example.py | 100 ++++++++---- 6 files changed, 516 insertions(+), 146 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index fd849c5..185c229 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -16,7 +16,7 @@ import logging from functools import lru_cache, wraps from inspect import Signature, isclass, signature -from typing import Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, get_args, get_origin +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generic, List, Optional, Tuple, Type, TypeVar, Union, cast, get_args, get_origin from pydantic import BaseModel as PydanticBaseModel, ConfigDict, Field, InstanceOf, PrivateAttr, TypeAdapter, field_validator, model_validator from typing_extensions import override @@ -31,6 +31,9 @@ from .local_persistence import create_ccflow_model from .validators import str_to_log_level +if TYPE_CHECKING: + from .flow_model import FlowAPI + __all__ = ( "GraphDepType", "GraphDepList", @@ -62,6 +65,25 @@ def _cached_signature(fn): return signature(fn) +def _callable_qualname(fn: Callable[..., Any]) -> str: + return getattr(fn, "__qualname__", type(fn).__qualname__) + + +def _declared_type_matches(actual: Any, expected: Any) -> bool: + if isinstance(expected, TypeVar): + return True + if get_origin(expected) is Union: + expected_args = tuple(arg for arg in get_args(expected) if isinstance(arg, type)) + if not expected_args: + return False + if get_origin(actual) is Union: + actual_args = tuple(arg for arg in get_args(actual) if isinstance(arg, type)) + return set(actual_args) == set(expected_args) + return isinstance(actual, type) and any(issubclass(actual, arg) for arg in 
expected_args) + + return isinstance(actual, type) and isinstance(expected, type) and issubclass(actual, expected) + + class MetaData(BaseModel): """Class to represent metadata for all callable models""" @@ -329,15 +351,16 @@ def wrapper(model, context=Signature.empty, *, _options: Optional[FlowOptions] = return result wrap = wraps(fn)(wrapper) - wrap.get_evaluator = self.get_evaluator - wrap.get_options = self.get_options - wrap.get_evaluation_context = get_evaluation_context + wrap_any = cast(Any, wrap) + wrap_any.get_evaluator = self.get_evaluator + wrap_any.get_options = self.get_options + wrap_any.get_evaluation_context = get_evaluation_context # Preserve auto context attributes for introspection if hasattr(fn, "__auto_context__"): - wrap.__auto_context__ = fn.__auto_context__ + wrap_any.__auto_context__ = fn.__auto_context__ if hasattr(fn, "__result_type__"): - wrap.__result_type__ = fn.__result_type__ + wrap_any.__result_type__ = fn.__result_type__ return wrap @@ -480,7 +503,7 @@ def __call__(self, *, date: date, extra: int = 0) -> MyResult: # Note that the code below is executed only once if auto_context_enabled: # Return a decorator that first applies auto_context, then FlowOptions - def auto_context_decorator(fn): + def auto_context_decorator(fn: Callable[..., Any]) -> Callable[..., Any]: wrapped = _apply_auto_context(fn, parent=context_parent) # FlowOptions.__call__ already applies wraps, so we just return its result return FlowOptions(**kwargs)(wrapped) @@ -592,13 +615,15 @@ class ModelEvaluationContext( # TODO: Make the instance check compatible with the generic types instead of the base type @model_validator(mode="wrap") - def _context_validator(cls, values, handler, info): + @classmethod + def _context_validator(cls, values: Any, handler: Any, info: Any): """Override _context_validator from parent""" # Validate the context with the model, if possible - model = values.get("model") - if model and isinstance(model, CallableModel) and not 
isinstance(values.get("context"), model.context_type): - values["context"] = model.context_type.model_validate(values.get("context")) + if isinstance(values, dict): + model = values.get("model") + if model and isinstance(model, CallableModel) and not isinstance(values.get("context"), model.context_type): + values["context"] = model.context_type.model_validate(values.get("context")) # Apply standard pydantic validation context = handler(values) @@ -626,9 +651,9 @@ def __call__(self) -> ResultType: raise TypeError(f"Model result_type {result_type} is not a subclass of ResultBase") result = result_type.model_validate(result) - return result + return cast(ResultType, result) else: - return fn(self.context) + return cast(ResultType, fn(self.context)) class EvaluatorBase(_CallableModel, abc.ABC): @@ -716,7 +741,7 @@ def context_type(self) -> Type[ContextType]: if not isclass(type_to_check) or not issubclass(type_to_check, ContextBase): raise TypeError(f"Context type declared in signature of __call__ must be a subclass of ContextBase. Received {type_to_check}.") - return typ + return cast(Type[ContextType], typ) @property def result_type(self) -> Type[ResultType]: @@ -759,7 +784,7 @@ def result_type(self) -> Type[ResultType]: # Ensure subclass of ResultBase if not isclass(typ) or not issubclass(typ, ResultBase): raise TypeError(f"Return type declared in signature of __call__ must be a subclass of ResultBase (i.e. GenericResult). Received {typ}.") - return typ + return cast(Type[ResultType], typ) @Flow.deps def __deps__( @@ -775,6 +800,13 @@ def __deps__( """ return [] + @property + def flow(self) -> "FlowAPI": + """Access flow helpers for execution, context transforms, and introspection.""" + from .flow_model import FlowAPI + + return FlowAPI(self) + class WrapperModel(CallableModel, Generic[CallableModelType], abc.ABC): """Abstract class that represents a wrapper around an underlying model, with the same context and return types. 
@@ -787,12 +819,12 @@ class WrapperModel(CallableModel, Generic[CallableModelType], abc.ABC): @property def context_type(self) -> Type[ContextType]: """Return the context type of the underlying model.""" - return self.model.context_type + return cast(CallableModel, self.model).context_type @property def result_type(self) -> Type[ResultType]: """Return the result type of the underlying model.""" - return self.model.result_type + return cast(CallableModel, self.model).result_type class CallableModelGeneric(CallableModel, Generic[ContextType, ResultType]): @@ -864,32 +896,36 @@ def _determine_context_result(cls): if new_context_type is not None: # Set on class - cls._context_generic_type = new_context_type + setattr(cls, "_context_generic_type", new_context_type) if new_result_type is not None: # Set on class - cls._result_generic_type = new_result_type + setattr(cls, "_result_generic_type", new_result_type) @model_validator(mode="wrap") - def _validate_callable_model_generic_type(cls, m, handler, info): + @classmethod + def _validate_callable_model_generic_type(cls, m: Any, handler: Any, info: Any): from ccflow.base import resolve_str if isinstance(m, str): m = resolve_str(m) - if isinstance(m, dict): - m = handler(m) - elif isinstance(m, cls): - m = handler(m) + validated_cls = cast(Any, cls) + if isinstance(m, (dict, CallableModel)): + if isinstance(m, dict): + m = handler(m) + elif isinstance(m, validated_cls): + m = handler(m) # Raise ValueError (not TypeError) as per https://docs.pydantic.dev/latest/errors/errors/ if not isinstance(m, CallableModel): raise ValueError(f"{m} is not a CallableModel: {type(m)}") subtypes = cls.__pydantic_generic_metadata__["args"] - if subtypes: - TypeAdapter(Type[subtypes[0]]).validate_python(m.context_type) - TypeAdapter(Type[subtypes[1]]).validate_python(m.result_type) + if len(subtypes) >= 1 and not _declared_type_matches(m.context_type, subtypes[0]): + raise ValueError(f"{m} context_type {m.context_type} does not match 
{subtypes[0]}") + if len(subtypes) >= 2 and not _declared_type_matches(m.result_type, subtypes[1]): + raise ValueError(f"{m} result_type {m.result_type} does not match {subtypes[1]}") return m @@ -902,7 +938,7 @@ def _validate_callable_model_generic_type(cls, m, handler, info): # ***************************************************************************** -def _apply_auto_context(func: Callable, *, parent: Type[ContextBase] = None) -> Callable: +def _apply_auto_context(func: Callable[..., Any], *, parent: Optional[Type[ContextBase]] = None) -> Callable[..., Any]: """Internal function that creates an auto context class from function parameters. This function extracts the parameters from a function signature and creates @@ -941,7 +977,7 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult: fields[name] = (param.annotation, default) # Create auto context class - auto_context_class = create_ccflow_model(f"{func.__qualname__}_AutoContext", __base__=base_class, **fields) + auto_context_class = create_ccflow_model(f"{_callable_qualname(func)}_AutoContext", __base__=base_class, **fields) @wraps(func) def wrapper(self, context): @@ -949,13 +985,14 @@ def wrapper(self, context): return func(self, **fn_kwargs) # Must set __signature__ so CallableModel validation sees 'context' parameter - wrapper.__signature__ = inspect.Signature( + wrapper_any = cast(Any, wrapper) + wrapper_any.__signature__ = inspect.Signature( parameters=[ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=auto_context_class), ], return_annotation=sig.return_annotation, ) - wrapper.__auto_context__ = auto_context_class - wrapper.__result_type__ = sig.return_annotation + wrapper_any.__auto_context__ = auto_context_class + wrapper_any.__result_type__ = sig.return_annotation return wrapper diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 44d9cfa..e2496ea 100644 --- 
a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -12,7 +12,7 @@ import inspect import logging from functools import wraps -from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, cast, get_args, get_origin +from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin from pydantic import Field, TypeAdapter, model_validator from typing_extensions import TypedDict @@ -87,6 +87,17 @@ def _context_values(context: ContextBase) -> Dict[str, Any]: return dict(context) +def _transform_repr(transform: Any) -> str: + """Render an input transform without noisy object addresses.""" + + if callable(transform): + name = _callable_name(transform) + if name.startswith("<") and name.endswith(">"): + return name + return f"<{name}>" + return repr(transform) + + def _build_typed_dict_adapter(name: str, schema: Dict[str, Type]) -> TypeAdapter: """Build a TypeAdapter for a runtime TypedDict schema.""" @@ -95,6 +106,22 @@ def _build_typed_dict_adapter(name: str, schema: Dict[str, Type]) -> TypeAdapter return TypeAdapter(TypedDict(name, schema)) +def _concrete_context_type(context_type: Any) -> Optional[Type[ContextBase]]: + """Extract a concrete ContextBase subclass from a context annotation.""" + + if isinstance(context_type, type) and issubclass(context_type, ContextBase): + return context_type + + if get_origin(context_type) in (Optional, Union): + for arg in get_args(context_type): + if arg is type(None): + continue + if isinstance(arg, type) and issubclass(arg, ContextBase): + return arg + + return None + + def _build_config_validators(all_param_types: Dict[str, Type]) -> Tuple[Dict[str, Type], Dict[str, TypeAdapter]]: """Precompute validators for constructor fields.""" @@ -141,9 +168,20 @@ class FlowAPI: Accessed via model.flow property. 
""" - def __init__(self, model: "_GeneratedFlowModelBase"): + def __init__(self, model: CallableModel): self._model = model + def _build_context(self, kwargs: Dict[str, Any]) -> ContextBase: + """Construct a runtime context for either generated or hand-written models.""" + get_validator = getattr(self._model, "_get_context_validator", None) + if get_validator is not None: + validator = get_validator() + validated = validator.validate_python(kwargs) + return FlowContext(**validated) + + validator = TypeAdapter(self._model.context_type) + return validator.validate_python(kwargs) + def compute(self, **kwargs) -> Any: """Execute the model with the provided context arguments. @@ -156,14 +194,7 @@ def compute(self, **kwargs) -> Any: Returns: The model's result, unwrapped from GenericResult if applicable. """ - # Get validator from model (lazily created if needed after unpickling) - validator = self._model._get_context_validator() - - # Validate and coerce kwargs via TypeAdapter - validated = validator.validate_python(kwargs) - - # Wrap in FlowContext (single class, always) - ctx = FlowContext(**validated) + ctx = self._build_context(kwargs) # Call the model result = self._model(ctx) @@ -181,23 +212,41 @@ def unbound_inputs(self) -> Dict[str, Type]: """ all_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) bound_fields = getattr(self._model, "_bound_fields", set()) + model_cls = self._model.__class__ # If explicit context_args was provided, use _context_schema - explicit_args = getattr(self._model.__class__, "__flow_model_explicit_context_args__", None) + explicit_args = getattr(model_cls, "__flow_model_explicit_context_args__", None) if explicit_args is not None: - return self._model._context_schema.copy() + context_schema = getattr(model_cls, "_context_schema", None) + return context_schema.copy() if context_schema is not None else {} - # Otherwise, unbound = all params - bound - return {name: typ for name, typ in all_param_types.items() 
if name not in bound_fields} + # Dynamic @Flow.model: unbound = all params - bound + if all_param_types: + return {name: typ for name, typ in all_param_types.items() if name not in bound_fields} + + # Generic CallableModel: runtime inputs are the context schema. + context_cls = _concrete_context_type(self._model.context_type) + if context_cls is None or not hasattr(context_cls, "model_fields"): + return {} + return {name: info.annotation for name, info in context_cls.model_fields.items()} @property def bound_inputs(self) -> Dict[str, Any]: """Return the config values bound at construction time.""" bound_fields = getattr(self._model, "_bound_fields", set()) - result = {} + result: Dict[str, Any] = {} for name in bound_fields: if hasattr(self._model, name): result[name] = getattr(self._model, name) + if result: + return result + + # Generic CallableModel: configured model fields are the bound inputs. + model_fields = getattr(self._model.__class__, "model_fields", {}) + for name in model_fields: + if name == "meta": + continue + result[name] = getattr(self._model, name) return result def with_inputs(self, **transforms) -> "BoundModel": @@ -235,7 +284,7 @@ class BoundModel: of a previous transform). 
""" - def __init__(self, model: "_GeneratedFlowModelBase", input_transforms: Dict[str, Any]): + def __init__(self, model: CallableModel, input_transforms: Dict[str, Any]): self._model = model self._input_transforms = input_transforms @@ -253,6 +302,10 @@ def __call__(self, context: ContextBase) -> Any: """Call the model with transformed context.""" return self._model(self._transform_context(context)) + def __repr__(self) -> str: + transforms = ", ".join(f"{name}={_transform_repr(transform)}" for name, transform in self._input_transforms.items()) + return f"{self._model!r}.flow.with_inputs({transforms})" + @property def flow(self) -> "FlowAPI": """Access the flow API.""" @@ -267,9 +320,7 @@ def __init__(self, bound_model: BoundModel): super().__init__(bound_model._model) def compute(self, **kwargs) -> Any: - validator = self._model._get_context_validator() - validated = validator.validate_python(kwargs) - ctx = FlowContext(**validated) + ctx = self._build_context(kwargs) result = self._bound(ctx) # Call through BoundModel, not _model if isinstance(result, GenericResult): return result.value diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index 61869f9..bd526b1 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -12,10 +12,22 @@ import cloudpickle import pytest -from ccflow import Flow, FlowAPI, FlowContext, GenericResult +from ccflow import CallableModel, ContextBase, Flow, FlowAPI, FlowContext, GenericResult from ccflow.context import DateRangeContext +class NumberContext(ContextBase): + x: int + + +class OffsetModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: NumberContext) -> GenericResult[int]: + return GenericResult(value=context.x + self.offset) + + class TestFlowContext: """Tests for the FlowContext universal carrier.""" @@ -152,6 +164,30 @@ def load_data(start_date: date, end_date: date, source: str = "db") -> GenericRe assert "start_date" not in bound 
assert "end_date" not in bound + def test_flow_compute_regular_callable_model(self): + """Regular CallableModels also expose .flow.compute().""" + + model = OffsetModel(offset=10) + result = model.flow.compute(x=5) + + assert result == 15 + + def test_flow_unbound_inputs_regular_callable_model(self): + """Regular CallableModels expose their context schema as unbound inputs.""" + + model = OffsetModel(offset=10) + unbound = model.flow.unbound_inputs + + assert unbound == {"x": int} + + def test_flow_bound_inputs_regular_callable_model(self): + """Regular CallableModels expose their configured fields as bound inputs.""" + + model = OffsetModel(offset=10) + bound = model.flow.bound_inputs + + assert bound["offset"] == 10 + class TestBoundModel: """Tests for BoundModel (created via .flow.with_inputs()).""" @@ -217,6 +253,27 @@ def compute(x: int) -> GenericResult[int]: bound = model.flow.with_inputs(x=42) assert isinstance(bound.flow, FlowAPI) + def test_bound_model_repr_looks_like_with_inputs_call(self): + """BoundModel repr should mirror the API users wrote.""" + + @Flow.model(context_args=["x"]) + def compute(x: int) -> GenericResult[int]: + return GenericResult(value=x * 2) + + model = compute() + bound = model.flow.with_inputs(x=lambda ctx: ctx.x + 1) + + assert repr(bound) == f"{model!r}.flow.with_inputs(x=)" + + def test_with_inputs_regular_callable_model(self): + """Regular CallableModels support .flow.with_inputs().""" + + model = OffsetModel(offset=1) + shifted = model.flow.with_inputs(x=lambda ctx: ctx.x * 2) + + result = shifted(NumberContext(x=5)) + assert result.value == 11 + class TestTypedDictValidation: """Tests for TypedDict-based context validation.""" diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md index dca3fbe..76adbea 100644 --- a/docs/design/flow_model_design.md +++ b/docs/design/flow_model_design.md @@ -192,11 +192,45 @@ schema, creates a `FlowContext`, executes the model, and unwraps It is not the only 
execution path. Because the generated object is still a standard `CallableModel`, calling `model(context)` remains fully supported. +## FieldExtractor + +Accessing an unknown public attribute on a `@Flow.model` instance returns a +`FieldExtractor`. It is itself a `CallableModel` that runs the source model, +then extracts the named field from the result (via `getattr` or dict key +access). + +```python +from ccflow import ContextBase, Flow, GenericResult + + +class TrainingContext(ContextBase): + seed: int + + +@Flow.model +def prepare(context: TrainingContext) -> GenericResult[dict]: + s = context.seed + return GenericResult(value={"X_train": [s, s * 2], "y_train": [s * 10]}) + + +@Flow.model +def train(context: TrainingContext, X: list, y: list) -> GenericResult[int]: + return GenericResult(value=sum(X) + sum(y)) + + +prepared = prepare() +model = train(X=prepared.X_train, y=prepared.y_train) +``` + +Multiple extractors from the same source share the source model instance. If +caching is enabled the source is evaluated only once. + ## Lazy Inputs `Lazy[T]` marks a parameter as on-demand. Instead of eagerly resolving an upstream model, the generated model passes a zero-argument thunk. The thunk -caches its first result. +caches its first result. Lazy dependencies are excluded from the `__deps__` +graph, so they are not pre-evaluated by the evaluator infrastructure. ```python from ccflow import Flow, Lazy @@ -243,3 +277,27 @@ At call time it: That keeps transformed dependency wiring explicit without adding special annotation machinery to the core API. + +## Flow.call with `auto_context` + +Separately from `@Flow.model`, `Flow.call(auto_context=...)` provides a similar +convenience for class-based `CallableModel`s. Instead of defining a separate +`ContextBase` subclass, the decorator generates one from the function's +keyword-only parameters. 
+ +```python +from ccflow import CallableModel, Flow, GenericResult + + +class MyModel(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") +``` + +Passing a `ContextBase` subclass (e.g., `auto_context=DateContext`) makes the +generated context inherit from that class, so it remains compatible with +infrastructure that expects the parent type. + +The generated class is registered via `create_ccflow_model` for serialization +support. diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 5f1b502..4e34fc3 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -30,6 +30,10 @@ standard `CallableModel` class with proper `__call__` and `__deps__` methods, so it still uses the normal ccflow framework for evaluation, caching, serialization, and registry loading. +If a `@Flow.model` function returns a plain value instead of a `ResultBase` +subclass, the generated model automatically wraps it in `GenericResult` at +runtime so it still behaves like a normal `CallableModel`. + You can execute a generated model in two equivalent ways: - call it directly with a context object: `model(ctx)` @@ -38,7 +42,16 @@ You can execute a generated model in two equivalent ways: `.flow.compute(...)` is mainly an explicit, ergonomic way to mark the deferred execution point. -**Basic Example:** +#### Context Modes + +There are three ways to define how a `@Flow.model` function receives its +runtime context. + +**Mode 1 — Explicit context parameter:** + +The function takes a `context` parameter (or `_` if unused) annotated with a +`ContextBase` subclass. This is the most direct mode and behaves like a +traditional `CallableModel.__call__`. 
```python from datetime import date @@ -46,21 +59,56 @@ from ccflow import Flow, GenericResult, DateRangeContext @Flow.model def load_data(context: DateRangeContext, source: str) -> GenericResult[dict]: - # Your data loading logic here return GenericResult(value=query_db(source, context.start_date, context.end_date)) -# Create model instance loader = load_data(source="my_database") -# Execute with context ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) result = loader(ctx) ``` -**Default `@Flow.model` Style:** +**Mode 2 — Unpacked context with `context_args`:** -Use this when you want the simplest API and do not need to declare a formal -context shape up front. +Instead of receiving a context object, you list which parameters should come +from the context at runtime. The remaining parameters are model configuration. + +```python +from datetime import date +from ccflow import Flow, GenericResult, DateRangeContext + +@Flow.model(context_args=["start_date", "end_date"]) +def load_data(start_date: date, end_date: date, source: str) -> GenericResult[str]: + return GenericResult(value=f"{source}:{start_date} to {end_date}") + +loader = load_data(source="my_database") + +# For well-known field sets the decorator matches a built-in context type +assert loader.context_type == DateRangeContext + +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = loader(ctx) +``` + +For well-known shapes such as `start_date` / `end_date` with `date` +annotations, the generated model uses a concrete built-in context type like +`DateRangeContext`. Otherwise it falls back to `FlowContext`, a universal +frozen carrier for the validated fields. 
+ +Use `context_args` when some parameters are semantically "the execution +context" and you want that split to stay stable and explicit: + +- the runtime context should be stable across instances +- the split between config and runtime inputs matters semantically +- the model is naturally "run over a context" such as date windows, + partitions, or scenarios +- you want the generated model to match a built-in context type like + `DateRangeContext` when possible + +**Mode 3 — Default deferred style (no explicit context):** + +When there is no `context` parameter and no `context_args`, all parameters are +potential configuration or runtime inputs. Parameters provided at construction +are bound (configuration); everything else comes from the context at runtime. ```python from ccflow import Flow @@ -80,12 +128,7 @@ doubled_y = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) assert doubled_y.flow.compute(y=5) == 20 ``` -In this mode: - -- bound parameters are model configuration -- unbound parameters are runtime inputs for that model instance - -**Composing Dependencies with Normal Parameters:** +#### Composing Dependencies Any non-context parameter can be bound either to a literal value or to another `CallableModel`. If you pass an upstream model, `@Flow.model` evaluates it with @@ -136,30 +179,19 @@ computed = growth.flow.compute( assert direct == computed ``` -This pattern is the main story for transformed dependencies. `@Flow.model` -still produces an ordinary `CallableModel`; `.flow.compute(...)` is just a -clearer way to say "supply the runtime inputs here." - -**Why `context_args` Exists:** +#### Deferred Execution Helpers -Without `context_args`, runtime inputs are inferred from whichever parameters -are still unbound on a particular model instance. That is flexible and -ergonomic. 
+**`.flow.compute(**kwargs)`** validates the keyword arguments against the +generated context schema, wraps them in a `FlowContext`, calls the model, and +unwraps `GenericResult.value` if present. -Use `context_args` when some parameters are semantically "the execution -context" and you want that split to stay stable and explicit: - -- the runtime context should be stable across instances -- the split between config and runtime inputs matters semantically -- the model is naturally "run over a context" such as date windows, - partitions, or scenarios -- you want the generated model to match a built-in context type like - `DateRangeContext` when possible - -**Deferred Execution Helpers:** +**`.flow.with_inputs(**transforms)`** returns a `BoundModel` that applies +context transforms before delegating to the underlying model. Each transform +is either a static value or a `(ctx) -> value` callable. Transforms are local +to the wrapped model — upstream models never see them. ```python -from ccflow import Flow +from ccflow import Flow, FlowContext @Flow.model def add(x: int, y: int) -> int: @@ -170,22 +202,105 @@ assert model.flow.compute(y=5) == 15 shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) assert shifted.flow.compute(y=5) == 20 + +# You can also call with a context object directly +ctx = FlowContext(y=5) +assert model(ctx).value == 15 +assert shifted(ctx).value == 20 ``` -If you already have a real context object, you can call the model directly -instead: +#### Field Extraction + +Accessing an unknown attribute on a `@Flow.model` instance returns a +`FieldExtractor` — a `CallableModel` that runs the source model and extracts +the named field from its result. This makes it easy to wire individual output +fields into downstream models. 
```python -from ccflow import FlowContext +from ccflow import ContextBase, Flow, GenericResult -ctx = FlowContext(y=5) -assert model(ctx).value == 15 -assert shifted(ctx).value == 20 +class TrainingContext(ContextBase): + seed: int + +@Flow.model +def prepare(context: TrainingContext) -> GenericResult[dict]: + s = context.seed + return GenericResult(value={"X_train": [s, s * 2], "y_train": [s * 10]}) + +@Flow.model +def train(context: TrainingContext, X: list, y: list) -> GenericResult[int]: + return GenericResult(value=sum(X) + sum(y)) + +prepared = prepare() +model = train(X=prepared.X_train, y=prepared.y_train) +result = model(TrainingContext(seed=5)) +# X_train = [5, 10], y_train = [50] -> 15 + 50 = 65 +assert result.value == 65 +``` + +Multiple extractors from the same source share the source model instance, so +with caching enabled the source is only evaluated once. + +#### Lazy Dependencies with `Lazy[T]` + +Mark a parameter with `Lazy[T]` to defer its evaluation. Instead of eagerly +resolving the upstream model, the generated model passes a zero-argument thunk +that evaluates on first call and caches the result. The thunk unwraps +`GenericResult` automatically, so `T` should be the inner value type. 
+ +```python +from ccflow import ContextBase, Flow, GenericResult, Lazy + +class SimpleContext(ContextBase): + value: int + +@Flow.model +def fast_path(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + +@Flow.model +def slow_path(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 100) + +@Flow.model +def smart_selector( + context: SimpleContext, + fast: int, # Eagerly resolved and unwrapped + slow: Lazy[int], # Deferred — receives a thunk returning unwrapped int + threshold: int = 10, +) -> GenericResult[int]: + if fast > threshold: + return GenericResult(value=fast) + return GenericResult(value=slow()) # Evaluated only when called + +model = smart_selector( + fast=fast_path(), + slow=slow_path(), + threshold=10, +) ``` -**Hydra/YAML Configuration:** +`Lazy` dependencies are excluded from the model's `__deps__` graph, so they +are not pre-evaluated by the evaluator infrastructure. -`Flow.model` decorated functions work seamlessly with Hydra configuration and the `ModelRegistry`: +#### Decorator Options + +`@Flow.model(...)` accepts the same options as `Flow.call` to control execution +behavior: + +- `cacheable` — enable caching of results +- `volatile` — mark as volatile (always re-execute) +- `log_level` — logging verbosity +- `validate_result` — validate return type +- `verbose` — verbose logging output +- `evaluator` — custom evaluator + +When not explicitly set, these inherit from any active `FlowOptionsOverride`. + +#### Hydra / YAML Configuration + +`@Flow.model` decorated functions work seamlessly with Hydra configuration and +the `ModelRegistry`: ```yaml # config.yaml @@ -202,36 +317,50 @@ aggregated: transformed: transformed # Reference by registry name ``` -When loaded via `ModelRegistry.load_config()`, references by name ensure the same object instance is shared across all consumers. 
+```python +from ccflow import ModelRegistry -**Auto-Unpacked Context with `context_args`:** +registry = ModelRegistry.root() +registry.load_config_from_path("config.yaml") -Instead of taking an explicit `context` parameter, you can use `context_args` to automatically unpack context fields as function parameters. This is useful when you want cleaner function signatures: +# References by name ensure the same object instance is shared +model = registry["aggregated"] +``` -```python -from datetime import date -from ccflow import Flow, GenericResult, DateRangeContext +### Flow.call with `auto_context` -# Instead of: def load_data(context: DateRangeContext, source: str) -# Use context_args to unpack the context fields directly: -@Flow.model(context_args=["start_date", "end_date"]) -def load_data(start_date: date, end_date: date, source: str) -> GenericResult[str]: - return GenericResult(value=f"{source}:{start_date} to {end_date}") +For class-based `CallableModel`s, `Flow.call(auto_context=...)` provides a +similar convenience. Instead of defining a separate `ContextBase` subclass, the +decorator generates one from the function's keyword-only parameters. -# The decorator matches common built-in context types when possible -loader = load_data(source="my_database") -assert loader.context_type == DateRangeContext +```python +from ccflow import CallableModel, Flow, GenericResult -# Execute with context as usual -ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) -result = loader(ctx) # "my_database:2024-01-01 to 2024-01-31" +class MyModel(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, x: int, y: str = "default") -> GenericResult: + return GenericResult(value=f"{x}-{y}") + +model = MyModel() +result = model(x=42, y="hello") +assert result.value == "42-hello" ``` -The `context_args` parameter specifies which function parameters should be extracted from the context. 
Those fields are validated through a runtime schema built from the parameter annotations. For well-known shapes such as `start_date` / `end_date`, the generated model uses a concrete built-in context type like `DateRangeContext`; otherwise it uses `FlowContext`, a universal frozen carrier for the validated fields. +You can also pass a parent context class so the generated context inherits +from it: -If a `@Flow.model` function returns a plain value instead of a `ResultBase` -subclass, the generated model automatically wraps that value in `GenericResult` -at runtime so it still behaves like a normal `CallableModel`. +```python +from datetime import date +from ccflow import CallableModel, DateContext, Flow, GenericResult + +class MyModel(CallableModel): + @Flow.call(auto_context=DateContext) + def __call__(self, *, date: date, extra: int = 0) -> GenericResult: + return GenericResult(value=date.day + extra) +``` + +The generated context class is a proper `ContextBase` subclass, so it works +with all existing evaluator and registry infrastructure. ## Model Registry diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py index f2616dc..27e31bb 100644 --- a/examples/flow_model_example.py +++ b/examples/flow_model_example.py @@ -1,5 +1,16 @@ #!/usr/bin/env python -"""Example demonstrating the core Flow.model workflow.""" +"""Canonical Flow.model example. + +This is the main `@Flow.model` story: + +1. define workflow steps as plain Python functions, +2. wire them together by passing upstream models as normal arguments, +3. use a small Python builder for reusable composition, +4. execute either as a normal CallableModel or via `.flow.compute(...)`. 
+ +Run with: + python examples/flow_model_example.py +""" from datetime import date, timedelta @@ -8,57 +19,84 @@ @Flow.model(context_args=["start_date", "end_date"]) def load_revenue(start_date: date, end_date: date, region: str) -> float: - """Pretend to load revenue for a date window.""" + """Return synthetic revenue for one reporting window.""" days = (end_date - start_date).days + 1 - baseline = 1000.0 if region == "us" else 800.0 - return baseline + days * 10.0 + region_base = {"us": 1000.0, "eu": 850.0}.get(region, 900.0) + days_since_2024 = (end_date - date(2024, 1, 1)).days + trend = days_since_2024 * 2.5 + return round(region_base + days * 8.0 + trend, 2) @Flow.model(context_args=["start_date", "end_date"]) -def summarize_growth(start_date: date, end_date: date, current: float, previous: float) -> dict: - """Compare the current and previous windows.""" +def revenue_change( + start_date: date, + end_date: date, + current: float, + previous: float, + label: str, + days_back: int, +) -> dict: + """Compare the current window against a shifted previous window.""" + previous_start = start_date - timedelta(days=days_back) + previous_end = end_date - timedelta(days=days_back) growth_pct = round((current - previous) / previous * 100, 2) return { - "start_date": start_date, - "end_date": end_date, + "comparison": label, + "current_window": f"{start_date} -> {end_date}", + "previous_window": f"{previous_start} -> {previous_end}", "current": current, "previous": previous, "growth_pct": growth_pct, } -def main(): - print("=" * 60) - print("Flow.model Example") - print("=" * 60) +def shifted_window(model, *, days_back: int): + """Reuse one upstream model with a shifted runtime window.""" + return model.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=days_back), + end_date=lambda ctx: ctx.end_date - timedelta(days=days_back), + ) + - current_window = load_revenue(region="us") - previous_window = current_window.flow.with_inputs( - 
start_date=lambda ctx: ctx.start_date - timedelta(days=30), - end_date=lambda ctx: ctx.end_date - timedelta(days=30), +def build_week_over_week_pipeline(region: str): + """Build one reusable pipeline from plain Flow.model functions.""" + current = load_revenue(region=region) + previous = shifted_window(current, days_back=7) + return revenue_change( + current=current, + previous=previous, + label="week_over_week", + days_back=7, ) - growth = summarize_growth(current=current_window, previous=previous_window) +def main() -> None: + print("=" * 64) + print("Flow.model Example") + print("=" * 64) + + pipeline = build_week_over_week_pipeline(region="us") ctx = DateRangeContext( - start_date=date(2024, 1, 1), - end_date=date(2024, 1, 31), + start_date=date(2024, 3, 1), + end_date=date(2024, 3, 31), ) - print("\n[1] Execute as a normal CallableModel:") - print(growth(ctx).value) - - print("\n[2] Execute via .flow.compute(...):") - print( - growth.flow.compute( - start_date=date(2024, 1, 1), - end_date=date(2024, 1, 31), - ) + direct = pipeline(ctx).value + computed = pipeline.flow.compute( + start_date=ctx.start_date, + end_date=ctx.end_date, ) - print("\n[3] Inspect bound and unbound inputs:") - print(" bound_inputs:", growth.flow.bound_inputs) - print(" unbound_inputs:", growth.flow.unbound_inputs) + print("\nPipeline wired from plain functions:") + print(" current input:", pipeline.current) + print(" previous input:", pipeline.previous) + + print("\nDirect call and .flow.compute(...) 
are equivalent:") + print(f" direct == computed: {direct == computed}") + + print("\nResult:") + for key, value in computed.items(): + print(f" {key}: {value}") if __name__ == "__main__": From 9c7bc7debb94728ca80c8b5cebcbf99645b32e89 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Tue, 17 Mar 2026 17:30:01 -0400 Subject: [PATCH 12/26] Small bug fixes for @Flow.model Signed-off-by: Nijat Khanbabayev --- ccflow/flow_model.py | 62 +++++++++++++++++++++----- ccflow/tests/test_flow_model.py | 78 +++++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 12 deletions(-) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index e2496ea..9d426fa 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -98,6 +98,36 @@ def _transform_repr(transform: Any) -> str: return repr(transform) +def _is_model_dependency(value: Any) -> bool: + return isinstance(value, (CallableModel, BoundModel)) + + +def _resolve_registry_candidate(value: str) -> Any: + from .base import BaseModel as _BM + + try: + candidate = _BM.model_validate(value) + except Exception: + return None + return candidate if isinstance(candidate, _BM) else None + + +def _registry_candidate_allowed(expected_type: Type, candidate: Any) -> bool: + if _is_model_dependency(candidate): + return True + try: + TypeAdapter(expected_type).validate_python(candidate) + except Exception: + return False + return True + + +def _make_field_extractor(source: Any, name: str) -> "FieldExtractor": + if name.startswith("_"): + raise AttributeError(f"'{type(source).__name__}' has no attribute '{name}'") + return FieldExtractor(source=source, field_name=name) + + def _build_typed_dict_adapter(name: str, schema: Dict[str, Type]) -> TypeAdapter: """Build a TypeAdapter for a runtime TypedDict schema.""" @@ -144,16 +174,18 @@ def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, return from .base import ModelRegistry as _MR - from .callable import CallableModel as _CM for field_name, 
validator in validators.items(): if field_name not in kwargs: continue value = kwargs[field_name] - if value is None or isinstance(value, (_CM, BoundModel)): + if value is None or _is_model_dependency(value): continue if isinstance(value, str) and value in _MR.root(): - continue + candidate = _resolve_registry_candidate(value) + expected_type = validatable_types[field_name] + if candidate is not None and _registry_candidate_allowed(expected_type, candidate): + continue try: validator.validate_python(value) except Exception: @@ -302,6 +334,9 @@ def __call__(self, context: ContextBase) -> Any: """Call the model with transformed context.""" return self._model(self._transform_context(context)) + def __getattr__(self, name): + return _make_field_extractor(self, name) + def __repr__(self) -> str: transforms = ", ".join(f"{name}={_transform_repr(transform)}" for name, transform in self._input_transforms.items()) return f"{self._model!r}.flow.with_inputs({transforms})" @@ -311,6 +346,10 @@ def flow(self) -> "FlowAPI": """Access the flow API.""" return _BoundFlowAPI(self) + @property + def context_type(self) -> Type[ContextBase]: + return self._model.context_type + class _BoundFlowAPI(FlowAPI): """FlowAPI that delegates to a BoundModel, honoring transforms.""" @@ -349,9 +388,7 @@ def __getattr__(self, name): raise AttributeError(name) return super_getattr(name) except AttributeError: - if name.startswith("_"): - raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'") from None - return FieldExtractor(source=self, field_name=name) + return _make_field_extractor(self, name) class _GeneratedFlowModelBase(_FieldExtractorMixin, CallableModel): @@ -374,8 +411,6 @@ def _resolve_registry_refs(cls, values, info): if not isinstance(values, dict): return values - from .base import BaseModel as _BM - param_types = getattr(cls, "__flow_model_all_param_types__", {}) resolved = dict(values) for field_name, expected_type in param_types.items(): @@ -386,11 +421,10 @@ def 
_resolve_registry_refs(cls, values, info): continue if expected_type is str: continue - try: - candidate = _BM.model_validate(value) - except Exception: + candidate = _resolve_registry_candidate(value) + if candidate is None: continue - if isinstance(candidate, _BM): + if _registry_candidate_allowed(expected_type, candidate): resolved[field_name] = candidate return resolved @@ -936,6 +970,8 @@ class FieldExtractor(_FieldExtractorMixin, CallableModel): @property def context_type(self): + if isinstance(self.source, BoundModel): + return self.source.context_type if isinstance(self.source, _CallableModel): return self.source.context_type return ContextBase @@ -956,6 +992,8 @@ def __call__(self, context: ContextBase) -> GenericResult: @Flow.deps def __deps__(self, context: ContextBase) -> GraphDepList: + if isinstance(self.source, BoundModel): + return [(self.source._model, [self.source._transform_context(context)])] if isinstance(self.source, _CallableModel): return [(self.source, [context])] return [] diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 458569d..018052a 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -6,6 +6,7 @@ from ray.cloudpickle import dumps as rcpdumps, loads as rcploads from ccflow import ( + BaseModel, CallableModel, ContextBase, DateRangeContext, @@ -709,6 +710,51 @@ def typed_config(context: SimpleContext, n: int = 10, name: str = "x") -> Generi result = model(SimpleContext(value=1)) self.assertEqual(result.value, "test:42") + def test_config_validation_rejects_registry_alias_for_incompatible_type(self): + """Registry aliases should not silently bypass scalar type validation.""" + + class DummyConfig(BaseModel): + x: int = 1 + + registry = ModelRegistry.root() + registry.clear() + try: + registry.add("dummy_config", DummyConfig()) + + @Flow.model + def typed_config(context: SimpleContext, n: int = 10) -> GenericResult[int]: + return GenericResult(value=n) + + with 
self.assertRaises(TypeError) as cm: + typed_config(n="dummy_config") + + self.assertIn("n", str(cm.exception)) + finally: + registry.clear() + + def test_config_validation_accepts_registry_alias_for_callable_dependency(self): + """Registry aliases still work for CallableModel dependencies.""" + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 2) + + @Flow.model + def consumer(context: SimpleContext, data: int = 0) -> GenericResult[int]: + return GenericResult(value=data + 1) + + registry.add("source_model", source()) + model = consumer(data="source_model") + + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 11) + finally: + registry.clear() + # ============================================================================= # BoundModel Tests @@ -1689,6 +1735,38 @@ def prepare(context: SimpleContext) -> GenericResult[dict]: self.assertIs(deps[0][0], model) self.assertEqual(deps[0][1], [ctx]) + def test_field_extraction_from_bound_model(self): + """Field extraction should still work after .flow.with_inputs().""" + + @Flow.model + def prepare(x: int) -> GenericResult[dict]: + return GenericResult(value={"doubled": x * 2}) + + bound = prepare().flow.with_inputs(x=lambda ctx: ctx.x + 1) + extractor = bound.doubled + + result = extractor.flow.compute(x=5) + self.assertEqual(result, 12) + + def test_field_extraction_deps_from_bound_model(self): + """Bound-model extractors should preserve transformed dependency contexts.""" + from ccflow import FlowContext + + @Flow.model + def prepare(x: int) -> GenericResult[dict]: + return GenericResult(value={"doubled": x * 2}) + + model = prepare() + bound = model.flow.with_inputs(x=lambda ctx: ctx.x + 1) + extractor = bound.doubled + + ctx = FlowContext(x=5) + deps = extractor.__deps__(ctx) + + self.assertEqual(len(deps), 1) + self.assertIs(deps[0][0], model) + 
self.assertEqual(deps[0][1][0].x, 6) + if __name__ == "__main__": import unittest From 5569ed09d63f2036994ab895a39679413a7399cb Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Wed, 18 Mar 2026 14:20:35 -0400 Subject: [PATCH 13/26] Temp progress for cleaning up Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 16 +- ccflow/context.py | 29 +- ccflow/flow_model.py | 480 ++++++++++++++++---------- ccflow/tests/test_callable.py | 34 +- ccflow/tests/test_flow_context.py | 43 ++- ccflow/tests/test_flow_model.py | 538 ++++++++++++++++-------------- docs/design/flow_model_design.md | 57 +--- docs/wiki/Key-Features.md | 65 +--- examples/flow_model_example.py | 8 +- 9 files changed, 720 insertions(+), 550 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index 185c229..54f4a9d 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -536,6 +536,7 @@ def model(*args, **kwargs): Args: context_args: List of parameter names that come from context (for unpacked mode) + context_type: Explicit ContextBase subclass to use with context_args mode cacheable: Enable caching of results (default: False) volatile: Mark as volatile (default: False) log_level: Logging verbosity (default: logging.DEBUG) @@ -555,7 +556,7 @@ def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.Data Mode 2 - Unpacked context_args: Context fields are unpacked into function parameters. 
- @Flow.model(context_args=["start_date", "end_date"]) + @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: return GenericResult(value=query_db(source, start_date, end_date)) @@ -807,6 +808,12 @@ def flow(self) -> "FlowAPI": return FlowAPI(self) + def pipe(self, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: + """Wire this model into a downstream generated ``@Flow.model`` stage.""" + from .flow_model import pipe_model + + return pipe_model(self, stage, param=param, **bindings) + class WrapperModel(CallableModel, Generic[CallableModelType], abc.ABC): """Abstract class that represents a wrapper around an underlying model, with the same context and return types. @@ -960,6 +967,9 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult: sig = signature(func) base_class = parent or ContextBase + if sig.return_annotation is inspect.Signature.empty: + raise TypeError(f"Function {_callable_qualname(func)} must have a return type annotation when auto_context=True") + # Validate parent fields are in function signature if parent is not None: parent_fields = set(parent.model_fields.keys()) - set(ContextBase.model_fields.keys()) @@ -973,6 +983,10 @@ def __call__(self, *, x: int, y: str = "default") -> GenericResult: for name, param in sig.parameters.items(): if name == "self": continue + if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + raise TypeError(f"Function {_callable_qualname(func)} does not support {param.kind.description} when auto_context=True") + if param.annotation is inspect.Parameter.empty: + raise TypeError(f"Parameter '{name}' must have a type annotation when auto_context=True") default = ... 
if param.default is inspect.Parameter.empty else param.default fields[name] = (param.annotation, default) diff --git a/ccflow/context.py b/ccflow/context.py index 0d00d2e..ae69e22 100644 --- a/ccflow/context.py +++ b/ccflow/context.py @@ -1,7 +1,8 @@ """This module defines re-usable contexts for the "Callable Model" framework defined in flow.callable.py.""" +from collections.abc import Mapping from datetime import date, datetime -from typing import Generic, Hashable, Optional, Sequence, Set, TypeVar +from typing import Any, Generic, Hashable, Optional, Sequence, Set, TypeVar from deprecated import deprecated from pydantic import ConfigDict, field_validator, model_validator @@ -106,6 +107,32 @@ class FlowContext(ContextBase): model_config = ConfigDict(extra="allow", frozen=True) + def __eq__(self, other: Any) -> bool: + if not isinstance(other, FlowContext): + return False + return self.model_dump(mode="python") == other.model_dump(mode="python") + + def __hash__(self) -> int: + return hash(_freeze_for_hash(self.model_dump(mode="python"))) + + +def _freeze_for_hash(value: Any) -> Hashable: + if isinstance(value, Mapping): + return tuple(sorted((key, _freeze_for_hash(item)) for key, item in value.items())) + if isinstance(value, (list, tuple)): + return tuple(_freeze_for_hash(item) for item in value) + if isinstance(value, (set, frozenset)): + return frozenset(_freeze_for_hash(item) for item in value) + if hasattr(value, "model_dump"): + return (type(value), _freeze_for_hash(value.model_dump(mode="python"))) + try: + hash(value) + except TypeError as exc: + if hasattr(value, "__dict__"): + return (type(value), _freeze_for_hash(vars(value))) + raise TypeError(f"FlowContext contains an unhashable value of type {type(value).__name__}") from exc + return value + C = TypeVar("C", bound=Hashable) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 9d426fa..da9d1bb 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -12,22 +12,32 @@ import inspect 
import logging from functools import wraps -from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin +from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin, get_type_hints from pydantic import Field, TypeAdapter, model_validator from typing_extensions import TypedDict from .base import ContextBase, ResultBase -from .callable import CallableModel, Flow, GraphDepList, _CallableModel +from .callable import CallableModel, Flow, GraphDepList from .context import FlowContext from .local_persistence import register_ccflow_import_path from .result import GenericResult -__all__ = ("flow_model", "FlowAPI", "BoundModel", "Lazy", "FieldExtractor") +__all__ = ("FlowAPI", "BoundModel", "Lazy") _AnyCallable = Callable[..., Any] +class _DeferredInput: + """Sentinel for dynamic @Flow.model inputs left for runtime context.""" + + def __repr__(self) -> str: + return "" + + +_DEFERRED_INPUT = _DeferredInput() + + def _callable_name(func: _AnyCallable) -> str: return getattr(func, "__name__", type(func).__name__) @@ -102,6 +112,34 @@ def _is_model_dependency(value: Any) -> bool: return isinstance(value, (CallableModel, BoundModel)) +def _bound_field_names(model: Any) -> set[str]: + fields_set = getattr(model, "model_fields_set", None) + if fields_set is not None: + return set(fields_set) + return set(getattr(model, "_bound_fields", set())) + + +def _has_deferred_input(value: Any) -> bool: + return isinstance(value, _DeferredInput) + + +def _deferred_input_factory() -> _DeferredInput: + return _DEFERRED_INPUT + + +def _effective_bound_field_names(model: Any) -> set[str]: + fields = _bound_field_names(model) + defaults = getattr(model.__class__, "__flow_model_default_param_names__", set()) + return fields | set(defaults) + + +def _runtime_input_names(model: Any) -> set[str]: + all_param_names = set(getattr(model.__class__, 
"__flow_model_all_param_types__", {})) + if not all_param_names: + return set() + return all_param_names - _effective_bound_field_names(model) + + def _resolve_registry_candidate(value: str) -> Any: from .base import BaseModel as _BM @@ -122,18 +160,12 @@ def _registry_candidate_allowed(expected_type: Type, candidate: Any) -> bool: return True -def _make_field_extractor(source: Any, name: str) -> "FieldExtractor": - if name.startswith("_"): - raise AttributeError(f"'{type(source).__name__}' has no attribute '{name}'") - return FieldExtractor(source=source, field_name=name) - - -def _build_typed_dict_adapter(name: str, schema: Dict[str, Type]) -> TypeAdapter: +def _build_typed_dict_adapter(name: str, schema: Dict[str, Type], *, total: bool = True) -> TypeAdapter: """Build a TypeAdapter for a runtime TypedDict schema.""" if not schema: return TypeAdapter(dict) - return TypeAdapter(TypedDict(name, schema)) + return TypeAdapter(TypedDict(name, schema, total=total)) def _concrete_context_type(context_type: Any) -> Optional[Type[ContextBase]]: @@ -193,6 +225,129 @@ def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, raise TypeError(f"Field '{field_name}': expected {expected_type}, got {type(value).__name__} ({value!r})") +def _generated_model_instance(stage: Any) -> Optional["_GeneratedFlowModelBase"]: + if isinstance(stage, BoundModel): + model = stage._model + else: + model = stage + if isinstance(model, _GeneratedFlowModelBase): + return model + return None + + +def _generated_model_class(stage: Any) -> Optional[type["_GeneratedFlowModelBase"]]: + model = _generated_model_instance(stage) + if model is not None: + return type(model) + + generated_model = getattr(stage, "_generated_model", None) + if isinstance(generated_model, type) and issubclass(generated_model, _GeneratedFlowModelBase): + return generated_model + return None + + +def _describe_pipe_stage(stage: Any) -> str: + if isinstance(stage, BoundModel): + return repr(stage) + if 
isinstance(stage, _GeneratedFlowModelBase): + return repr(stage) + if callable(stage): + return _callable_name(stage) + return repr(stage) + + +def _generated_model_explicit_kwargs(model: "_GeneratedFlowModelBase") -> Dict[str, Any]: + return cast(Dict[str, Any], model.model_dump(mode="python", exclude_unset=True)) + + +def _infer_pipe_param( + stage_name: str, + param_names: List[str], + default_param_names: set[str], + occupied_names: set[str], +) -> str: + required_candidates = [name for name in param_names if name not in occupied_names and name not in default_param_names] + if len(required_candidates) == 1: + return required_candidates[0] + if len(required_candidates) > 1: + candidates = ", ".join(required_candidates) + raise TypeError( + f"pipe() could not infer a target parameter for {stage_name}; unbound candidates are: {candidates}. Pass param='...' explicitly." + ) + + fallback_candidates = [name for name in param_names if name not in occupied_names] + if len(fallback_candidates) == 1: + return fallback_candidates[0] + if len(fallback_candidates) > 1: + candidates = ", ".join(fallback_candidates) + raise TypeError( + f"pipe() could not infer a target parameter for {stage_name}; unbound candidates are: {candidates}. Pass param='...' explicitly." + ) + + raise TypeError(f"pipe() could not find an available target parameter for {stage_name}.") + + +def _resolve_pipe_param(source: Any, stage: Any, param: Optional[str], bindings: Dict[str, Any]) -> Tuple[str, type["_GeneratedFlowModelBase"]]: + del source # Source only matters when binding, not during target resolution. 
+ + generated_model_cls = _generated_model_class(stage) + if generated_model_cls is None: + raise TypeError("pipe() only supports downstream stages created by @Flow.model or bound versions of those stages.") + + stage_name = _describe_pipe_stage(stage) + all_param_types = getattr(generated_model_cls, "__flow_model_all_param_types__", {}) + if not all_param_types: + raise TypeError(f"pipe() could not determine bindable parameters for {stage_name}.") + + param_names = list(all_param_types.keys()) + default_param_names = set(getattr(generated_model_cls, "__flow_model_default_param_names__", set())) + + generated_model = _generated_model_instance(stage) + occupied_names = set(bindings) + if generated_model is not None: + occupied_names |= _bound_field_names(generated_model) + if isinstance(stage, BoundModel): + occupied_names |= set(stage._input_transforms) + + if param is not None: + if param not in all_param_types: + valid = ", ".join(param_names) + raise TypeError(f"pipe() target parameter '{param}' is not valid for {stage_name}. 
Available parameters: {valid}.") + if param in occupied_names: + raise TypeError(f"pipe() target parameter '{param}' is already bound for {stage_name}.") + return param, generated_model_cls + + return _infer_pipe_param(stage_name, param_names, default_param_names, occupied_names), generated_model_cls + + +def pipe_model(source: Any, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: + """Wire ``source`` into a downstream generated ``@Flow.model`` stage.""" + + if not _is_model_dependency(source): + raise TypeError(f"pipe() source must be a CallableModel or BoundModel, got {type(source).__name__}.") + + target_param, generated_model_cls = _resolve_pipe_param(source, stage, param, bindings) + build_kwargs = dict(bindings) + build_kwargs[target_param] = source + + if isinstance(stage, BoundModel): + generated_model = _generated_model_instance(stage) + if generated_model is None: + raise TypeError("pipe() only supports downstream BoundModel stages created from @Flow.model.") + explicit_kwargs = _generated_model_explicit_kwargs(generated_model) + explicit_kwargs.update(build_kwargs) + rebound_model = generated_model_cls(**explicit_kwargs) + return BoundModel(model=rebound_model, input_transforms=dict(stage._input_transforms)) + + generated_model = _generated_model_instance(stage) + if generated_model is not None: + explicit_kwargs = _generated_model_explicit_kwargs(generated_model) + explicit_kwargs.update(build_kwargs) + return generated_model_cls(**explicit_kwargs) + + return stage(**build_kwargs) + + class FlowAPI: """API namespace for deferred computation operations. @@ -224,26 +379,18 @@ def compute(self, **kwargs) -> Any: **kwargs: Context arguments (e.g., start_date, end_date) Returns: - The model's result, unwrapped from GenericResult if applicable. + The model's result, using the same return contract as ``model(context)``. 
""" ctx = self._build_context(kwargs) - - # Call the model - result = self._model(ctx) - - # Unwrap GenericResult if present - if isinstance(result, GenericResult): - return result.value - return result + return self._model(ctx) @property def unbound_inputs(self) -> Dict[str, Type]: """Return the context schema (field name -> type). - In deferred mode, this is everything NOT provided at construction. + In deferred mode, this is everything that must still come from runtime context. """ all_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) - bound_fields = getattr(self._model, "_bound_fields", set()) model_cls = self._model.__class__ # If explicit context_args was provided, use _context_schema @@ -252,9 +399,10 @@ def unbound_inputs(self) -> Dict[str, Type]: context_schema = getattr(model_cls, "_context_schema", None) return context_schema.copy() if context_schema is not None else {} - # Dynamic @Flow.model: unbound = all params - bound + # Dynamic @Flow.model: unbound = params with no explicit value and no declared default if all_param_types: - return {name: typ for name, typ in all_param_types.items() if name not in bound_fields} + runtime_inputs = _runtime_input_names(self._model) + return {name: typ for name, typ in all_param_types.items() if name in runtime_inputs} # Generic CallableModel: runtime inputs are the context schema. 
context_cls = _concrete_context_type(self._model.context_type) @@ -264,13 +412,15 @@ def unbound_inputs(self) -> Dict[str, Type]: @property def bound_inputs(self) -> Dict[str, Any]: - """Return the config values bound at construction time.""" - bound_fields = getattr(self._model, "_bound_fields", set()) + """Return the effective config values for this model.""" result: Dict[str, Any] = {} - for name in bound_fields: - if hasattr(self._model, name): - result[name] = getattr(self._model, name) - if result: + flow_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) + if flow_param_types: + for name in flow_param_types: + value = getattr(self._model, name, _DEFERRED_INPUT) + if _has_deferred_input(value): + continue + result[name] = value return result # Generic CallableModel: configured model fields are the bound inputs. @@ -320,22 +470,26 @@ def __init__(self, model: CallableModel, input_transforms: Dict[str, Any]): self._model = model self._input_transforms = input_transforms - def _transform_context(self, context: ContextBase) -> FlowContext: - """Return a FlowContext with this model's input transforms applied.""" + def _transform_context(self, context: ContextBase) -> ContextBase: + """Return this model's preferred context type with input transforms applied.""" ctx_dict = _context_values(context) for name, transform in self._input_transforms.items(): if callable(transform): ctx_dict[name] = transform(context) else: ctx_dict[name] = transform + context_type = _concrete_context_type(self._model.context_type) + if context_type is not None and context_type is not FlowContext: + return context_type.model_validate(ctx_dict) return FlowContext(**ctx_dict) def __call__(self, context: ContextBase) -> Any: """Call the model with transformed context.""" return self._model(self._transform_context(context)) - def __getattr__(self, name): - return _make_field_extractor(self, name) + def pipe(self, stage: Any, /, *, param: Optional[str] = None, 
**bindings: Any) -> Any: + """Wire this bound model into a downstream generated ``@Flow.model`` stage.""" + return pipe_model(self, stage, param=param, **bindings) def __repr__(self) -> str: transforms = ", ".join(f"{name}={_transform_repr(transform)}" for name, transform in self._input_transforms.items()) @@ -360,10 +514,7 @@ def __init__(self, bound_model: BoundModel): def compute(self, **kwargs) -> Any: ctx = self._build_context(kwargs) - result = self._bound(ctx) # Call through BoundModel, not _model - if isinstance(result, GenericResult): - return result.value - return result + return self._bound(ctx) # Call through BoundModel, not _model def with_inputs(self, **transforms) -> "BoundModel": """Chain transforms: merge new transforms with existing ones. @@ -374,24 +525,7 @@ def with_inputs(self, **transforms) -> "BoundModel": return BoundModel(model=self._bound._model, input_transforms=merged) -class _FieldExtractorMixin: - """Turn unknown public attributes into FieldExtractors. - - Real model attributes are still resolved by the normal pydantic/base-model - attribute path via ``super().__getattr__``. 
- """ - - def __getattr__(self, name): - try: - super_getattr = getattr(super(), "__getattr__", None) - if super_getattr is None: - raise AttributeError(name) - return super_getattr(name) - except AttributeError: - return _make_field_extractor(self, name) - - -class _GeneratedFlowModelBase(_FieldExtractorMixin, CallableModel): +class _GeneratedFlowModelBase(CallableModel): """Shared behavior for models generated by ``@Flow.model``.""" __flow_model_context_type__: ClassVar[Type[ContextBase]] = FlowContext @@ -400,10 +534,10 @@ class _GeneratedFlowModelBase(_FieldExtractorMixin, CallableModel): __flow_model_use_context_args__: ClassVar[bool] = True __flow_model_explicit_context_args__: ClassVar[Optional[List[str]]] = None __flow_model_all_param_types__: ClassVar[Dict[str, Type]] = {} + __flow_model_default_param_names__: ClassVar[set[str]] = set() __flow_model_auto_wrap__: ClassVar[bool] = False _context_schema: ClassVar[Dict[str, Type]] = {} _context_td: ClassVar[Any | None] = None - _matched_context_type: ClassVar[Optional[Type[ContextBase]]] = None _cached_context_validator: ClassVar[TypeAdapter | None] = None @model_validator(mode="before") @@ -458,10 +592,14 @@ def _get_context_validator(self) -> TypeAdapter: if not hasattr(self, "_instance_context_validator"): all_param_types = getattr(cls, "__flow_model_all_param_types__", {}) - bound_fields = getattr(self, "_bound_fields", set()) - unbound_schema = {name: typ for name, typ in all_param_types.items() if name not in bound_fields} - object.__setattr__(self, "_instance_context_validator", _build_typed_dict_adapter(f"{cls.__name__}Inputs", unbound_schema)) - return self._instance_context_validator + runtime_inputs = _runtime_input_names(self) + unbound_schema = {name: typ for name, typ in all_param_types.items() if name in runtime_inputs} + object.__setattr__( + self, + "_instance_context_validator", + _build_typed_dict_adapter(f"{cls.__name__}Inputs", unbound_schema, total=False), + ) + return cast(TypeAdapter, 
getattr(self, "_instance_context_validator")) class Lazy: @@ -544,8 +682,8 @@ def model(self) -> "CallableModel": # noqa: F821 def _build_context_schema( - context_args: List[str], func: _AnyCallable, sig: inspect.Signature -) -> Tuple[Dict[str, Type], Any, Optional[Type[ContextBase]]]: + context_args: List[str], func: _AnyCallable, sig: inspect.Signature, resolved_hints: Dict[str, Any] +) -> Tuple[Dict[str, Type], Any]: """Build context schema from context_args parameter names. Instead of creating a dynamic ContextBase subclass, this builds: @@ -559,7 +697,7 @@ def _build_context_schema( sig: The function signature Returns: - Tuple of (schema_dict, TypedDict type, optional matched ContextBase type) + Tuple of (schema_dict, TypedDict type) """ # Build schema dict from parameter annotations schema = {} @@ -567,28 +705,54 @@ def _build_context_schema( if name not in sig.parameters: raise ValueError(f"context_arg '{name}' not found in function parameters") param = sig.parameters[name] - if param.annotation is inspect.Parameter.empty: + annotation = resolved_hints.get(name, param.annotation) + if annotation is inspect.Parameter.empty: raise ValueError(f"context_arg '{name}' must have a type annotation") - schema[name] = param.annotation + schema[name] = annotation - # Try to match common context types for compatibility - matched_context_type = None - from .context import DateRangeContext + # Create TypedDict for validation (not registered anywhere!) + context_td = TypedDict(f"{_callable_name(func)}Inputs", schema) - if set(context_args) == {"start_date", "end_date"}: - from datetime import date + return schema, context_td - if all( - sig.parameters[name].annotation in (date, "date") - or (isinstance(sig.parameters[name].annotation, type) and sig.parameters[name].annotation is date) - for name in context_args - ): - matched_context_type = DateRangeContext - # Create TypedDict for validation (not registered anywhere!) 
- context_td = TypedDict(f"{_callable_name(func)}Inputs", schema) +def _validate_context_type_override(context_type: Any, context_args: List[str], func_schema: Dict[str, Type]) -> Type[ContextBase]: + """Validate an explicit ``context_type`` override for ``context_args`` mode.""" - return schema, context_td, matched_context_type + if not isinstance(context_type, type) or not issubclass(context_type, ContextBase): + raise TypeError(f"context_type must be a ContextBase subclass, got {context_type!r}") + + context_fields = getattr(context_type, "model_fields", {}) + missing = sorted(name for name in context_args if name not in context_fields) + if missing: + raise TypeError(f"context_type {context_type.__name__} must define fields for context_args: {', '.join(missing)}") + + required_extra_fields = sorted( + name for name, info in context_fields.items() if name not in ContextBase.model_fields and name not in context_args and info.is_required() + ) + if required_extra_fields: + raise TypeError(f"context_type {context_type.__name__} has required fields not listed in context_args: {', '.join(required_extra_fields)}") + + # Warn when the function's annotation for a context_arg doesn't match the + # context_type's field annotation. A mismatch means the function declares + # one type but will silently receive whatever Pydantic coerces to. 
+ for name in context_args: + func_ann = func_schema.get(name) + ctx_field = context_fields.get(name) + if func_ann is None or ctx_field is None: + continue + ctx_ann = ctx_field.annotation + if func_ann is ctx_ann: + continue + # Both are concrete types — check subclass relationship + if isinstance(func_ann, type) and isinstance(ctx_ann, type): + if not (issubclass(func_ann, ctx_ann) or issubclass(ctx_ann, func_ann)): + raise TypeError( + f"context_arg '{name}': function annotates {func_ann.__name__} " + f"but context_type {context_type.__name__} declares {ctx_ann.__name__}" + ) + + return context_type _UNSET = object() @@ -599,6 +763,7 @@ def flow_model( *, # Context handling context_args: Optional[List[str]] = None, + context_type: Optional[Type[ContextBase]] = None, # Flow.call options (passed to generated __call__) # Default to _UNSET so FlowOptionsOverride can control these globally. # Only explicitly user-provided values are passed to Flow.call. @@ -618,6 +783,7 @@ def flow_model( Args: func: The function to decorate context_args: List of parameter names that come from context (for unpacked mode) + context_type: Explicit ContextBase subclass to use with ``context_args`` mode. cacheable: Enable caching of results (default: unset, inherits from FlowOptionsOverride) volatile: Mark as volatile (always re-execute) (default: unset, inherits from FlowOptionsOverride) log_level: Logging verbosity (default: unset, inherits from FlowOptionsOverride) @@ -635,7 +801,7 @@ def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.Data 2. Unpacked context_args: Context fields are unpacked into function parameters. - @Flow.model(context_args=["start_date", "end_date"]) + @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: ... 
@@ -644,15 +810,13 @@ def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[ """ def decorator(fn: _AnyCallable) -> _AnyCallable: - import typing as _typing - sig = inspect.signature(fn) params = sig.parameters # Resolve string annotations (PEP 563 / from __future__ import annotations) # into real type objects. include_extras=True preserves Annotated metadata. try: - _resolved_hints = _typing.get_type_hints(fn, include_extras=True) + _resolved_hints = get_type_hints(fn, include_extras=True) except Exception: _resolved_hints = {} @@ -670,15 +834,19 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: internal_return_type = return_type # Determine context mode + context_schema_early: Dict[str, Type] = {} + context_td_early = None if "context" in params or "_" in params: # Mode 1: Explicit context parameter (named 'context' or '_' for unused) + if context_type is not None: + raise TypeError("context_type=... is only supported when using context_args=[...]") context_param_name = "context" if "context" in params else "_" context_param = params[context_param_name] context_annotation = _resolved_hints.get(context_param_name, context_param.annotation) if context_annotation is inspect.Parameter.empty: raise TypeError(f"Function {_callable_name(fn)}: '{context_param_name}' parameter must have a type annotation") - context_type = context_annotation - if not (isinstance(context_type, type) and issubclass(context_type, ContextBase)): + resolved_context_type = context_annotation + if not (isinstance(resolved_context_type, type) and issubclass(resolved_context_type, ContextBase)): raise TypeError(f"Function {_callable_name(fn)}: '{context_param_name}' must be annotated with a ContextBase subclass") model_field_params = {name: param for name, param in params.items() if name not in (context_param_name, "self")} use_context_args = False @@ -686,20 +854,22 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: elif context_args is not None: # Mode 2: Explicit 
context_args - specified params come from context context_param_name = "context" - # Build context schema early to determine matched_context_type - context_schema_early, _, matched_type = _build_context_schema(context_args, fn, sig) - # Use matched type if available (e.g., DateRangeContext), else FlowContext - context_type = matched_type if matched_type is not None else FlowContext + context_schema_early, context_td_early = _build_context_schema(context_args, fn, sig, _resolved_hints) + explicit_context_type = ( + _validate_context_type_override(context_type, context_args, context_schema_early) if context_type is not None else None + ) + resolved_context_type = explicit_context_type if explicit_context_type is not None else FlowContext # Exclude context_args from model fields model_field_params = {name: param for name, param in params.items() if name not in context_args and name != "self"} use_context_args = True explicit_context_args = context_args else: - # Mode 3: Dynamic deferred mode - ALL params are potential context or config - # What's provided at construction = config/deps - # What's NOT provided = comes from context at runtime + # Mode 3: Dynamic deferred mode - every param can be configured on the model, + # but only params without Python defaults remain runtime inputs when omitted. + if context_type is not None: + raise TypeError("context_type=... is only supported when using context_args=[...]") context_param_name = "context" - context_type = FlowContext + resolved_context_type = FlowContext model_field_params = {name: param for name, param in params.items() if name != "self"} use_context_args = True explicit_context_args = None # Dynamic - determined at construction @@ -707,9 +877,10 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: # Analyze parameters to find lazy fields and regular fields. 
model_fields: Dict[str, Tuple[Type, Any]] = {} # name -> (type, default) lazy_fields: set[str] = set() # Names of parameters marked with Lazy[T] + default_param_names: set[str] = set() - # In dynamic deferred mode (no explicit context_args), all fields are optional - # because values not provided at construction come from context at runtime + # In dynamic deferred mode (no explicit context_args), fields without Python defaults + # are internally represented by a deferred sentinel until runtime context supplies them. dynamic_deferred_mode = use_context_args and explicit_context_args is None for name, param in model_field_params.items(): @@ -724,10 +895,11 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: lazy_fields.add(name) if param.default is not inspect.Parameter.empty: + default_param_names.add(name) default = param.default elif dynamic_deferred_mode: - # In dynamic mode, params without defaults are optional (come from context) - default = None + # In dynamic mode, params without defaults remain deferred to runtime context. + default = Field(default_factory=_deferred_input_factory, exclude_if=_has_deferred_input) else: # In explicit mode, params without defaults are required default = ... @@ -786,17 +958,32 @@ def _resolve_field(name, value): value = getattr(self, name) fn_kwargs[name] = _resolve_field(name, value) else: - # Mode 3: Dynamic deferred mode - unbound from context, bound from self - bound_fields = getattr(self, "_bound_fields", set()) + # Mode 3: Dynamic deferred mode - explicit values or Python defaults from self, + # otherwise values come from runtime context. + explicit_fields = _bound_field_names(self) + missing_fields = [] for name in all_param_names: - if name in bound_fields: - # Bound at construction - get from self + value = getattr(self, name, _DEFERRED_INPUT) + if name in explicit_fields or name in default_param_names: + # Explicitly provided or implicitly bound via Python default. 
value = getattr(self, name) fn_kwargs[name] = _resolve_field(name, value) - else: - # Unbound - get from context - fn_kwargs[name] = getattr(context, name) + continue + + if _has_deferred_input(value): + value = getattr(context, name, _UNSET) + if value is _UNSET: + missing_fields.append(name) + continue + fn_kwargs[name] = value + + if missing_fields: + missing = ", ".join(sorted(missing_fields)) + raise TypeError( + f"Missing runtime input(s) for {_callable_name(fn)}: {missing}. " + "Provide them in the call context or bind them at construction time." + ) raw_result = fn(**fn_kwargs) if auto_wrap_result: @@ -807,7 +994,7 @@ def _resolve_field(name, value): cast(Any, __call__).__signature__ = inspect.Signature( parameters=[ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), - inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=resolved_context_type), ], return_annotation=internal_return_type, ) @@ -849,7 +1036,7 @@ def __deps__(self, context) -> GraphDepList: cast(Any, __deps__).__signature__ = inspect.Signature( parameters=[ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), - inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=context_type), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=resolved_context_type), ], return_annotation=GraphDepList, ) @@ -884,34 +1071,32 @@ def __deps__(self, context) -> GraphDepList: GeneratedModel = cast(type[_GeneratedFlowModelBase], type(f"_{_callable_name(fn)}_Model", (_GeneratedFlowModelBase,), namespace)) # Set class-level attributes after class creation (to avoid pydantic processing) - GeneratedModel.__flow_model_context_type__ = context_type + GeneratedModel.__flow_model_context_type__ = resolved_context_type GeneratedModel.__flow_model_return_type__ = internal_return_type setattr(GeneratedModel, 
"__flow_model_func__", fn) GeneratedModel.__flow_model_use_context_args__ = use_context_args GeneratedModel.__flow_model_explicit_context_args__ = explicit_context_args GeneratedModel.__flow_model_all_param_types__ = all_param_types # All param name -> type + GeneratedModel.__flow_model_default_param_names__ = default_param_names GeneratedModel.__flow_model_auto_wrap__ = auto_wrap_result - # Build context_schema and matched_context_type + # Build context_schema context_schema: Dict[str, Type] = {} context_td = None - matched_context_type: Optional[Type[ContextBase]] = None if explicit_context_args is not None: # Explicit context_args provided - use early-computed schema - # (matched_context_type was already used to set context_type above) - context_schema, context_td, matched_context_type = _build_context_schema(explicit_context_args, fn, sig) + context_schema, context_td = context_schema_early, context_td_early elif not use_context_args: # Explicit context mode - schema comes from the context type's fields - if hasattr(context_type, "model_fields"): - context_schema = {name: info.annotation for name, info in context_type.model_fields.items()} + if hasattr(resolved_context_type, "model_fields"): + context_schema = {name: info.annotation for name, info in resolved_context_type.model_fields.items()} # For dynamic mode (is_dynamic_mode), _context_schema remains empty - # and schema is built dynamically from _bound_fields at runtime + # and schema is built dynamically from the instance's unresolved runtime inputs. # Store context schema for TypedDict-based validation (picklable!) 
GeneratedModel._context_schema = context_schema GeneratedModel._context_td = context_td - GeneratedModel._matched_context_type = matched_context_type # Validator is created lazily to survive pickling GeneratedModel._cached_context_validator = None @@ -927,12 +1112,7 @@ def __deps__(self, context) -> GraphDepList: @wraps(fn) def factory(**kwargs) -> _GeneratedFlowModelBase: _validate_config_kwargs(kwargs, _validatable_types, _config_validators) - - instance = GeneratedModel(**kwargs) - # Track which fields were explicitly provided at construction - # These are "bound" - everything else comes from context at runtime - object.__setattr__(instance, "_bound_fields", set(kwargs.keys())) - return instance + return GeneratedModel(**kwargs) # Preserve useful attributes on factory cast(Any, factory)._generated_model = GeneratedModel @@ -944,59 +1124,3 @@ def factory(**kwargs) -> _GeneratedFlowModelBase: if func is not None: return decorator(func) return decorator - - -# ============================================================================= -# FieldExtractor — structured output field access -# ============================================================================= - - -class FieldExtractor(_FieldExtractorMixin, CallableModel): - """Extracts a named field from a source model's result. - - Created automatically by accessing an unknown attribute on a @Flow.model - instance (e.g., ``prepared.X_train``). The extractor is itself a - CallableModel, so it can be wired as a dependency to downstream models. - - When evaluated, it runs the source model and returns - ``GenericResult(value=getattr(source_result, field_name))``. - - Multiple extractors from the same source share the source model instance. - If caching is enabled on the evaluator, the source is evaluated only once. 
- """ - - source: Any # The source CallableModel - field_name: str # The attribute to extract - - @property - def context_type(self): - if isinstance(self.source, BoundModel): - return self.source.context_type - if isinstance(self.source, _CallableModel): - return self.source.context_type - return ContextBase - - @property - def result_type(self): - return GenericResult - - @Flow.call - def __call__(self, context: ContextBase) -> GenericResult: - result = self.source(context) - if isinstance(result, GenericResult): - result = result.value - # Support both attribute access and dict key access - if isinstance(result, dict): - return GenericResult(value=result[self.field_name]) - return GenericResult(value=getattr(result, self.field_name)) - - @Flow.deps - def __deps__(self, context: ContextBase) -> GraphDepList: - if isinstance(self.source, BoundModel): - return [(self.source._model, [self.source._transform_context(context)])] - if isinstance(self.source, _CallableModel): - return [(self.source, [context])] - return [] - - -register_ccflow_import_path(FieldExtractor) diff --git a/ccflow/tests/test_callable.py b/ccflow/tests/test_callable.py index 29f4524..6d8f53e 100644 --- a/ccflow/tests/test_callable.py +++ b/ccflow/tests/test_callable.py @@ -1024,4 +1024,36 @@ def bad_func(self, *, x: int) -> GenericResult: error_msg = str(cm.exception) self.assertIn("auto_context must be False, True, or a ContextBase subclass", error_msg) - self.assertIn("invalid", error_msg) + + def test_auto_context_rejects_var_args(self): + """auto_context should reject *args early with a clear error.""" + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *args: int) -> GenericResult: + return GenericResult(value=len(args)) + + self.assertIn("variadic positional", str(cm.exception)) + + def test_auto_context_rejects_var_kwargs(self): + """auto_context should reject **kwargs early with a clear error.""" + 
with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, **kwargs: int) -> GenericResult: + return GenericResult(value=len(kwargs)) + + self.assertIn("variadic keyword", str(cm.exception)) + + def test_auto_context_requires_return_annotation(self): + """auto_context should reject missing return annotations immediately.""" + with self.assertRaises(TypeError) as cm: + + class AutoContextCallable(CallableModel): + @Flow.call(auto_context=True) + def __call__(self, *, value: int): + return GenericResult(value=value) + + self.assertIn("must have a return type annotation", str(cm.exception)) diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index bd526b1..718f8de 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -72,6 +72,20 @@ def test_flow_context_model_dump(self): assert dumped["start_date"] == date(2024, 1, 1) assert dumped["value"] == 42 + def test_flow_context_value_semantics_include_extra_fields(self): + """Equality should reflect the actual extra payload.""" + assert FlowContext(x=1) == FlowContext(x=1) + assert FlowContext(x=1) != FlowContext(x=2) + assert FlowContext(x=1) != FlowContext(y=1) + + def test_flow_context_hash_uses_extra_fields(self): + """Distinct extra payloads should remain distinct in hashed collections.""" + first = FlowContext(values=[1, 2], label="a") + second = FlowContext(values=[1, 3], label="a") + third = FlowContext(values=[1, 2], label="b") + + assert len({first, second, third}) == 3 + def test_flow_context_pickle(self): """FlowContext pickles cleanly.""" ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) @@ -102,9 +116,9 @@ def load_data(start_date: date, end_date: date, source: str = "db") -> GenericRe model = load_data(source="api") result = model.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - assert result["start"] == date(2024, 1, 
1) - assert result["end"] == date(2024, 1, 31) - assert result["source"] == "api" + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) + assert result.value["source"] == "api" def test_flow_compute_type_coercion(self): """FlowAPI.compute() coerces types via TypeAdapter.""" @@ -117,8 +131,8 @@ def load_data(start_date: date, end_date: date) -> GenericResult[dict]: # Pass strings - should be coerced to dates result = model.flow.compute(start_date="2024-01-01", end_date="2024-01-31") - assert result["start"] == date(2024, 1, 1) - assert result["end"] == date(2024, 1, 31) + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) def test_flow_compute_validation_error(self): """FlowAPI.compute() raises on missing required args.""" @@ -170,7 +184,7 @@ def test_flow_compute_regular_callable_model(self): model = OffsetModel(offset=10) result = model.flow.compute(x=5) - assert result == 15 + assert result.value == 15 def test_flow_unbound_inputs_regular_callable_model(self): """Regular CallableModels expose their context schema as unbound inputs.""" @@ -305,15 +319,14 @@ def compute(x: int) -> GenericResult[int]: assert validator is not None assert model.__class__._cached_context_validator is validator - def test_matched_context_type(self): - """DateRangeContext pattern is matched for compatibility.""" + def test_explicit_context_type_override(self): + """context_type can opt into an existing ContextBase subclass.""" - @Flow.model(context_args=["start_date", "end_date"]) + @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_data(start_date: date, end_date: date) -> GenericResult[dict]: return GenericResult(value={}) model = load_data() - # Should match DateRangeContext assert model.context_type == DateRangeContext @@ -340,7 +353,7 @@ def compute(x: int, y: int, multiplier: int = 2) -> GenericResult[int]: # Should work after unpickling result = 
unpickled.flow.compute(x=1, y=2) - assert result == 9 # (1 + 2) * 3 + assert result.value == 9 # (1 + 2) * 3 def test_model_cloudpickle_simple(self): """Simple model cloudpickle test.""" @@ -355,7 +368,7 @@ def double(value: int) -> GenericResult[int]: unpickled = cloudpickle.loads(pickled) result = unpickled.flow.compute(value=21) - assert result == 42 + assert result.value == 42 def test_validator_recreated_after_cloudpickle(self): """TypeAdapter validator is recreated after cloudpickling.""" @@ -375,7 +388,7 @@ def compute(x: int) -> GenericResult[int]: # Validator should still work (may be lazily recreated) result = unpickled.flow.compute(x=42) - assert result == 42 + assert result.value == 42 def test_flow_context_pickle_standard(self): """FlowContext works with standard pickle.""" @@ -431,8 +444,8 @@ def load_data(context: DateRangeContext, source: str = "db") -> GenericResult[di model = load_data(source="api") result = model.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - assert result["start"] == date(2024, 1, 1) - assert result["end"] == date(2024, 1, 31) + assert result.value["start"] == date(2024, 1, 1) + assert result.value["end"] == date(2024, 1, 31) class TestLazy: diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 018052a..a2e788e 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -11,6 +11,7 @@ ContextBase, DateRangeContext, Flow, + FlowContext, FlowOptionsOverride, GenericResult, Lazy, @@ -160,13 +161,13 @@ class TestFlowModelContextArgs(TestCase): def test_context_args_basic(self): """Test basic context_args usage.""" - @Flow.model(context_args=["start_date", "end_date"]) + @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def date_range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: return GenericResult(value=f"{source}:{start_date} to {end_date}") loader = date_range_loader(source="db") - # 
Should use DateRangeContext + # Explicit context_type keeps compatibility with existing contexts. self.assertEqual(loader.context_type, DateRangeContext) ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) @@ -182,6 +183,9 @@ def unpacked_model(x: int, y: str, multiplier: int = 1) -> GenericResult[str]: model = unpacked_model(multiplier=2) + # Default context_args mode uses FlowContext unless overridden explicitly. + self.assertEqual(model.context_type, FlowContext) + # Create context with generated type ctx_type = model.context_type ctx = ctx_type(x=5, y="test") @@ -490,6 +494,40 @@ def serializable_model(context: SimpleContext, value: int = 42) -> GenericResult self.assertEqual(dumped["value"], 100) self.assertIn("type_", dumped) + def test_serialization_roundtrip_preserves_bound_inputs(self): + """Round-tripping should preserve which inputs were bound at construction.""" + + @Flow.model + def add(x: int, y: int) -> int: + return x + y + + model = add(x=10) + dumped = model.model_dump(mode="python") + restored = type(model).model_validate(dumped) + + self.assertEqual(dumped["x"], 10) + self.assertNotIn("y", dumped) + self.assertEqual(restored.flow.bound_inputs, {"x": 10}) + self.assertEqual(restored.flow.unbound_inputs, {"y": int}) + self.assertEqual(restored.flow.compute(y=5).value, 15) + + def test_serialization_roundtrip_preserves_defaults_and_deferred_inputs(self): + """Default-valued params should serialize normally without binding runtime-only inputs.""" + + @Flow.model + def load(start_date: str, source: str = "warehouse") -> str: + return f"{source}:{start_date}" + + model = load() + dumped = model.model_dump(mode="python") + restored = type(model).model_validate(dumped) + + self.assertEqual(dumped["source"], "warehouse") + self.assertNotIn("start_date", dumped) + self.assertEqual(restored.flow.bound_inputs, {"source": "warehouse"}) + self.assertEqual(restored.flow.unbound_inputs, {"start_date": str}) + 
self.assertEqual(restored.flow.compute(start_date="2024-01-01").value, "warehouse:2024-01-01") + def test_pickle_roundtrip(self): """Test cloudpickle serialization of generated models.""" @@ -568,7 +606,7 @@ def test_auto_wrap_unwrap_as_dependency(self): Auto-wrapped models have result_type=GenericResult (unparameterized). When used as an auto-detected dep, the framework resolves - the GenericResult and unwraps .value for the downstream function. + the GenericResult to its inner value for the downstream function. """ @Flow.model @@ -622,24 +660,36 @@ def dynamic_model(value: int, multiplier: int) -> GenericResult[int]: result = model(ctx) self.assertEqual(result.value, 30) # 10 * 3 + def test_dynamic_deferred_mode_missing_runtime_inputs_is_clear(self): + """Missing deferred inputs should fail at the framework boundary.""" + + @Flow.model + def dynamic_model(value: int, multiplier: int) -> int: + return value * multiplier + + model = dynamic_model() + + with self.assertRaises(TypeError) as cm: + model.flow.compute() + + self.assertIn("Missing runtime input(s) for dynamic_model: multiplier, value", str(cm.exception)) + def test_all_defaults_is_valid(self): - """Test that all-defaults function is valid (everything can be pre-bound).""" + """All-default functions should treat those defaults as bound config.""" from ccflow import FlowContext @Flow.model def all_defaults(value: int = 1, other: str = "x") -> GenericResult[str]: return GenericResult(value=f"{value}-{other}") - # No args provided -> everything comes from defaults or context model = all_defaults() - # All params are unbound (not provided at construction) - self.assertEqual(model.flow.unbound_inputs, {"value": int, "other": str}) + self.assertEqual(model.flow.bound_inputs, {"value": 1, "other": "x"}) + self.assertEqual(model.flow.unbound_inputs, {}) - # Call with context - context values override defaults ctx = FlowContext(value=5, other="y") result = model(ctx) - self.assertEqual(result.value, "5-y") + 
self.assertEqual(result.value, "1-x") def test_invalid_context_arg(self): """Test error when context_args refers to non-existent parameter.""" @@ -661,6 +711,30 @@ def untyped_context_arg(x) -> GenericResult[int]: self.assertIn("type annotation", str(cm.exception)) + def test_context_type_requires_context_args_mode(self): + """context_type is only valid alongside context_args.""" + with self.assertRaises(TypeError) as cm: + + @Flow.model(context_type=DateRangeContext) + def dynamic_model(value: int) -> GenericResult[int]: + return GenericResult(value=value) + + self.assertIn("context_args", str(cm.exception)) + + def test_context_type_must_cover_context_args(self): + """context_type must expose all named context_args fields.""" + + class StartOnlyContext(ContextBase): + start_date: date + + with self.assertRaises(TypeError) as cm: + + @Flow.model(context_args=["start_date", "end_date"], context_type=StartOnlyContext) + def load_data(start_date: date, end_date: date) -> GenericResult[dict]: + return GenericResult(value={}) + + self.assertIn("end_date", str(cm.exception)) + # ============================================================================= # Validation Tests @@ -755,6 +829,32 @@ def consumer(context: SimpleContext, data: int = 0) -> GenericResult[int]: finally: registry.clear() + def test_context_type_annotation_mismatch_raises(self): + """context_type validation should reject incompatible field annotations.""" + + class StringIdContext(ContextBase): + item_id: str + + with self.assertRaises(TypeError) as cm: + + @Flow.model(context_args=["item_id"], context_type=StringIdContext) + def load(item_id: int) -> int: + return item_id + + self.assertIn("item_id", str(cm.exception)) + self.assertIn("int", str(cm.exception)) + self.assertIn("str", str(cm.exception)) + + def test_context_type_compatible_annotations_accepted(self): + """context_type validation should accept matching or subclass annotations.""" + + # Exact match should work + 
@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) + def load_exact(start_date: date, end_date: date) -> str: + return f"{start_date}" + + self.assertIsNotNone(load_exact) + # ============================================================================= # BoundModel Tests @@ -780,7 +880,7 @@ def my_model(x: int, y: int) -> GenericResult[int]: result = bound.flow.compute(y=5) # y transform: 5 * 2 = 10, x is bound to 10 # model: 10 + 10 = 20 - self.assertEqual(result, 20) + self.assertEqual(result.value, 20) def test_bound_model_flow_compute_static_transform(self): """Test BoundModel.flow.compute() with static value transform.""" @@ -795,7 +895,19 @@ def my_model(x: int, y: int) -> GenericResult[int]: result = bound.flow.compute(y=999) # y should be overridden by transform # y is statically bound to 3, x=7 # 7 * 3 = 21 - self.assertEqual(result, 21) + self.assertEqual(result.value, 21) + + def test_bound_model_cloudpickle_with_lambda_transform(self): + """BoundModel with lambda transforms should survive cloudpickle round-trip.""" + + @Flow.model + def my_model(x: int, y: int) -> int: + return x + y + + bound = my_model(x=10).flow.with_inputs(y=lambda ctx: ctx.y * 2) + restored = rcploads(rcpdumps(bound, protocol=5)) + + self.assertEqual(restored.flow.compute(y=6).value, 22) def test_bound_model_as_dependency(self): """Test that BoundModel can be passed as a dependency to another model.""" @@ -817,7 +929,21 @@ def consumer(data: GenericResult[int]) -> GenericResult[int]: # x transform: 5 * 2 = 10 # source: 10 * 10 = 100 # consumer: 100 + 1 = 101 - self.assertEqual(result, 101) + self.assertEqual(result.value, 101) + + def test_flow_compute_with_upstream_callable_model_dependency(self): + """flow.compute() should resolve upstream generated-model dependencies.""" + + @Flow.model + def source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 10) + + @Flow.model + def consumer(data: GenericResult[int], offset: int = 1) -> 
int: + return data + offset + + model = consumer(data=source(), offset=3) + self.assertEqual(model.flow.compute(x=5).value, 53) def test_bound_model_chained_with_inputs(self): """Test that chaining with_inputs merges transforms correctly.""" @@ -836,7 +962,7 @@ def my_model(x: int, y: int, z: int) -> int: # y transform: 10 * 3 = 30 # z from context: 1 # 10 + 30 + 1 = 41 - self.assertEqual(result, 41) + self.assertEqual(result.value, 41) def test_bound_model_chained_with_inputs_override(self): """Test that chaining with_inputs allows overriding transforms.""" @@ -851,7 +977,7 @@ def my_model(x: int) -> int: # Second transform should override the first for 'x' result = bound2.flow.compute(x=5) - self.assertEqual(result, 50) # 5 * 10, not 5 * 2 + self.assertEqual(result.value, 50) # 5 * 10, not 5 * 2 def test_bound_model_with_default_args(self): """with_inputs works when the model has parameters with default values.""" @@ -867,24 +993,24 @@ def load(start_date: str, end_date: str, source: str = "warehouse") -> str: lookback = model.flow.with_inputs(start_date=lambda ctx: "shifted_" + ctx.start_date) result = lookback.flow.compute(start_date="2024-01-01", end_date="2024-06-30") - self.assertEqual(result, "prod_db:shifted_2024-01-01-2024-06-30") + self.assertEqual(result.value, "prod_db:shifted_2024-01-01-2024-06-30") - def test_bound_model_with_default_arg_unbound(self): - """with_inputs works when defaulted parameter is left unbound (comes from context).""" + def test_bound_model_with_default_arg_uses_default(self): + """with_inputs should preserve omitted Python defaults as bound config.""" @Flow.model def load(start_date: str, source: str = "warehouse") -> str: return f"{source}:{start_date}" - # Don't bind 'source' — it keeps its default in the model, - # but in dynamic deferred mode, unbound params come from context model = load() - # Transform start_date; source comes from context (overriding the default) bound = model.flow.with_inputs(start_date=lambda ctx: 
"shifted_" + ctx.start_date) - result = bound.flow.compute(start_date="2024-01-01", source="s3_bucket") - self.assertEqual(result, "s3_bucket:shifted_2024-01-01") + self.assertEqual(model.flow.bound_inputs, {"source": "warehouse"}) + self.assertEqual(model.flow.unbound_inputs, {"start_date": str}) + + result = bound.flow.compute(start_date="2024-01-01") + self.assertEqual(result.value, "warehouse:shifted_2024-01-01") def test_bound_model_default_arg_as_dependency(self): """BoundModel with default args works correctly as a dependency.""" @@ -905,7 +1031,7 @@ def consumer(data: int) -> int: # x transform: 3 * 10 = 30 # source: 30 * 5 (multiplier) = 150 # consumer: 150 + 1 = 151 - self.assertEqual(result, 151) + self.assertEqual(result.value, 151) def test_bound_model_as_lazy_dependency(self): """Test that BoundModel works as a Lazy dependency.""" @@ -927,7 +1053,7 @@ def consumer(data: int, slow: Lazy[GenericResult[int]]) -> GenericResult[int]: model = consumer(data=5, slow=bound_src) result = model.flow.compute(x=7) # data=5 < 100, so slow path: x transform: 7+10=17, source: 17*3=51 - self.assertEqual(result, 51) + self.assertEqual(result.value, 51) def test_bound_and_unbound_models_share_memory_cache(self): """Shifted and unshifted models should share one evaluator cache. 
@@ -959,6 +1085,135 @@ def source(context: SimpleContext) -> GenericResult[int]: self.assertEqual(call_counts["source"], 2) self.assertEqual(len(evaluator.cache), 2) + def test_transform_error_propagates(self): + """A buggy transform should raise, not silently fall back to FlowContext.""" + + @Flow.model + def load(context: DateRangeContext, source: str = "db") -> str: + return f"{source}:{context.start_date}" + + model = load() + # Transform has a typo — ctx.sart_date instead of ctx.start_date + bound = model.flow.with_inputs(start_date=lambda ctx: ctx.sart_date) + + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + with self.assertRaises(AttributeError): + bound(ctx) + + def test_transform_validation_error_propagates(self): + """If transforms produce invalid context data, the error should surface.""" + from pydantic import ValidationError + + @Flow.model + def load(context: DateRangeContext, source: str = "db") -> str: + return f"{source}:{context.start_date}" + + model = load() + # Transform returns a string where a date is expected + bound = model.flow.with_inputs(start_date="not-a-date") + + ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) + # Pydantic validation should raise, not silently fall back to FlowContext + with self.assertRaises(ValidationError): + bound(ctx) + + +class TestFlowModelPipe(TestCase): + """Tests for the ``.pipe(..., param=...)`` convenience API.""" + + def test_pipe_infers_single_required_parameter(self): + """pipe() should infer the only required downstream parameter.""" + + @Flow.model + def source(x: int) -> GenericResult[int]: + return GenericResult(value=x * 10) + + @Flow.model + def consumer(data: int, offset: int = 1) -> int: + return data + offset + + pipeline = source().pipe(consumer, offset=3) + self.assertEqual(pipeline.flow.compute(x=5).value, 53) + + def test_pipe_infers_single_defaulted_parameter(self): + """pipe() should fall back to a single defaulted 
downstream parameter.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + @Flow.model + def consumer(data: int = 0) -> int: + return data + 1 + + pipeline = source().pipe(consumer) + self.assertEqual(pipeline.flow.compute(x=5).value, 51) + + def test_pipe_param_disambiguates_multiple_parameters(self): + """param= should identify the downstream argument to bind.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + @Flow.model + def combine(left: int, right: int) -> int: + return left + right + + pipeline = source().pipe(combine, param="right", left=7) + self.assertEqual(pipeline.flow.compute(x=5).value, 57) + + def test_pipe_rejects_ambiguous_downstream_stage(self): + """pipe() should require param= when multiple targets are available.""" + + @Flow.model + def source(x: int) -> int: + return x + + @Flow.model + def combine(left: int, right: int) -> int: + return left + right + + with self.assertRaisesRegex( + TypeError, + r"pipe\(\) could not infer a target parameter for combine; unbound candidates are: left, right", + ): + source().pipe(combine) + + def test_manual_callable_model_can_pipe_into_generated_stage(self): + """Hand-written CallableModels should be usable as pipe sources.""" + + class ManualModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value + self.offset) + + @Flow.model + def consumer(data: int, multiplier: int) -> int: + return data * multiplier + + pipeline = ManualModel(offset=5).pipe(consumer, multiplier=2) + self.assertEqual(pipeline.flow.compute(value=10).value, 30) + + def test_bound_model_pipe_preserves_downstream_transforms(self): + """pipe() should keep downstream with_inputs transforms intact.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + @Flow.model + def consumer(data: int, scale: int) -> int: + return data + scale + + shifted_source = source().flow.with_inputs(x=lambda ctx: 
ctx.scale + 1) + scaled_consumer = consumer().flow.with_inputs(scale=lambda ctx: ctx.scale * 3) + + pipeline = shifted_source.pipe(scaled_consumer) + self.assertEqual(pipeline.flow.compute(scale=2).value, 76) + # ============================================================================= # PEP 563 (from __future__ import annotations) Compatibility Tests @@ -1007,7 +1262,7 @@ def plain_model(value: int) -> int: model = plain_model() result = model.flow.compute(value=5) - self.assertEqual(result, 10) + self.assertEqual(result.value, 10) self.assertEqual(model.result_type, GenericResult) @@ -1159,7 +1414,7 @@ def hydra_consumer_model( # --- context_args fixtures for Hydra testing --- -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def context_args_loader(start_date: date, end_date: date, source: str) -> GenericResult[dict]: """Loader using context_args with DateRangeContext.""" return GenericResult( @@ -1171,7 +1426,7 @@ def context_args_loader(start_date: date, end_date: date, source: str) -> Generi ) -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def context_args_processor( start_date: date, end_date: date, @@ -1539,235 +1794,6 @@ def consumer( self.assertEqual(result.value, 51) # 50 + 1 -# ============================================================================= -# FieldExtractor Tests (Structured Output Field Access) -# ============================================================================= - - -class TestFieldExtractor(TestCase): - """Tests for structured output field access (prepared.X_train pattern).""" - - def test_field_extraction_basic(self): - """Accessing unknown attr on @Flow.model instance returns FieldExtractor.""" - from ccflow.flow_model import FieldExtractor - - @Flow.model - def prepare(context: SimpleContext, factor: int = 2) -> GenericResult[dict]: - 
return GenericResult(value={"X_train": context.value * factor, "X_test": context.value}) - - model = prepare(factor=3) - extractor = model.X_train - - self.assertIsInstance(extractor, FieldExtractor) - self.assertIs(extractor.source, model) - self.assertEqual(extractor.field_name, "X_train") - - def test_field_extraction_evaluates_correctly(self): - """FieldExtractor runs source and extracts the named field.""" - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - return GenericResult(value={"X_train": [1, 2, 3], "y_train": [4, 5, 6]}) - - model = prepare() - x_train = model.X_train - - result = x_train(SimpleContext(value=0)) - self.assertEqual(result.value, [1, 2, 3]) - - def test_field_extraction_as_dependency(self): - """FieldExtractor wired as a dep to a downstream model. - - Note: FieldExtractors are CallableModels, so they're auto-detected as deps - and auto-unwrapped (GenericResult.value). The downstream function receives - the raw extracted value, not a GenericResult wrapper. 
- """ - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - v = context.value - return GenericResult(value={"X_train": [v, v * 2], "y_train": [v * 10]}) - - @Flow.model - def train(context: SimpleContext, X: list, y: list) -> GenericResult[int]: - # X and y are auto-unwrapped to the raw list values - return GenericResult(value=sum(X) + sum(y)) - - prepared = prepare() - model = train(X=prepared.X_train, y=prepared.y_train) - - result = model(SimpleContext(value=5)) - # X_train = [5, 10], y_train = [50] - # sum(X) + sum(y) = 15 + 50 = 65 - self.assertEqual(result.value, 65) - - def test_field_extraction_multiple_from_same_source(self): - """Multiple extractors from same source share the source instance.""" - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - return GenericResult(value={"a": 1, "b": 2, "c": 3}) - - model = prepare() - ext_a = model.a - ext_b = model.b - ext_c = model.c - - # All should reference the same source - self.assertIs(ext_a.source, model) - self.assertIs(ext_b.source, model) - self.assertIs(ext_c.source, model) - - # All should evaluate correctly - ctx = SimpleContext(value=0) - self.assertEqual(ext_a(ctx).value, 1) - self.assertEqual(ext_b(ctx).value, 2) - self.assertEqual(ext_c(ctx).value, 3) - - def test_field_extraction_nested(self): - """Chained extraction (result.a.b) creates nested FieldExtractors.""" - from ccflow.flow_model import FieldExtractor - - class Nested: - def __init__(self): - self.inner_val = 42 - - @Flow.model - def produce(context: SimpleContext) -> GenericResult: - return GenericResult(value={"nested": Nested()}) - - model = produce() - nested_extractor = model.nested - inner_extractor = nested_extractor.inner_val - - self.assertIsInstance(nested_extractor, FieldExtractor) - self.assertIsInstance(inner_extractor, FieldExtractor) - - result = inner_extractor(SimpleContext(value=0)) - self.assertEqual(result.value, 42) - - def 
test_field_extraction_context_type_inherited(self): - """FieldExtractor inherits context_type from source.""" - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - return GenericResult(value={"x": 1}) - - model = prepare() - extractor = model.x - - self.assertEqual(extractor.context_type, SimpleContext) - - def test_field_extraction_nonexistent_field_runtime_error(self): - """Non-existent field raises error at evaluation time, not construction. - - For dict results, raises KeyError. For object results, raises AttributeError. - """ - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - return GenericResult(value={"x": 1}) - - model = prepare() - extractor = model.nonexistent # No error at construction - - # Error at evaluation time (KeyError for dicts, AttributeError for objects) - with self.assertRaises((KeyError, AttributeError)): - extractor(SimpleContext(value=0)) - - def test_field_extraction_pydantic_fields_not_intercepted(self): - """Accessing real pydantic fields returns the field value, NOT an extractor.""" - from ccflow.flow_model import FieldExtractor - - @Flow.model - def model_with_fields(context: SimpleContext, multiplier: int = 5) -> GenericResult[int]: - return GenericResult(value=context.value * multiplier) - - model = model_with_fields(multiplier=10) - - # 'multiplier' is a real pydantic field — should return the value, not a FieldExtractor - self.assertEqual(model.multiplier, 10) - self.assertNotIsInstance(model.multiplier, FieldExtractor) - - # 'meta' is inherited from CallableModel — should also not be intercepted - self.assertNotIsInstance(model.meta, FieldExtractor) - - def test_field_extraction_with_context_args(self): - """FieldExtractor works with context_args mode models.""" - from ccflow import FlowContext - - @Flow.model(context_args=["x"]) - def prepare(x: int) -> GenericResult[dict]: - return GenericResult(value={"doubled": x * 2, "tripled": x * 3}) - - model = prepare() - doubled = 
model.doubled - - result = doubled(FlowContext(x=5)) - self.assertEqual(result.value, 10) - - def test_field_extraction_has_flow_property(self): - """FieldExtractor has .flow property (inherits from CallableModel).""" - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - return GenericResult(value={"x": 1}) - - model = prepare() - extractor = model.x - - self.assertTrue(hasattr(extractor, "flow")) - - def test_field_extraction_deps(self): - """FieldExtractor.__deps__ returns the source as a dependency.""" - - @Flow.model - def prepare(context: SimpleContext) -> GenericResult[dict]: - return GenericResult(value={"x": 1}) - - model = prepare() - extractor = model.x - - ctx = SimpleContext(value=0) - deps = extractor.__deps__(ctx) - - self.assertEqual(len(deps), 1) - self.assertIs(deps[0][0], model) - self.assertEqual(deps[0][1], [ctx]) - - def test_field_extraction_from_bound_model(self): - """Field extraction should still work after .flow.with_inputs().""" - - @Flow.model - def prepare(x: int) -> GenericResult[dict]: - return GenericResult(value={"doubled": x * 2}) - - bound = prepare().flow.with_inputs(x=lambda ctx: ctx.x + 1) - extractor = bound.doubled - - result = extractor.flow.compute(x=5) - self.assertEqual(result, 12) - - def test_field_extraction_deps_from_bound_model(self): - """Bound-model extractors should preserve transformed dependency contexts.""" - from ccflow import FlowContext - - @Flow.model - def prepare(x: int) -> GenericResult[dict]: - return GenericResult(value={"doubled": x * 2}) - - model = prepare() - bound = model.flow.with_inputs(x=lambda ctx: ctx.x + 1) - extractor = bound.doubled - - ctx = FlowContext(x=5) - deps = extractor.__deps__(ctx) - - self.assertEqual(len(deps), 1) - self.assertIs(deps[0][0], model) - self.assertEqual(deps[0][1][0].x, 6) - - if __name__ == "__main__": import unittest diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md index 76adbea..7b6ac9f 100644 --- 
a/docs/design/flow_model_design.md +++ b/docs/design/flow_model_design.md @@ -36,13 +36,13 @@ def add(x: int, y: int) -> int: model = add(x=10) # Explicit deferred entry point -assert model.flow.compute(y=5) == 15 +assert model.flow.compute(y=5).value == 15 # Standard CallableModel call path assert model(FlowContext(y=5)).value == 15 shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) -assert shifted.flow.compute(y=5) == 20 +assert shifted.flow.compute(y=5).value == 20 ``` In this mode: @@ -73,7 +73,7 @@ from datetime import date from ccflow import Flow -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_revenue(start_date: date, end_date: date, region: str) -> float: return 125.0 ``` @@ -85,9 +85,8 @@ Use `context_args` when certain parameters are semantically the execution context and you want that split to be explicit and stable across model instances. -When the requested shape matches a built-in context like -`DateRangeContext(start_date, end_date)`, the generated model uses that type. -Otherwise it falls back to `FlowContext`. +By default, `context_args` models use `FlowContext`. If you want compatibility +with an existing context class, pass `context_type=...` explicitly. 
### Upstream Models as Normal Arguments @@ -130,13 +129,13 @@ from datetime import date, timedelta from ccflow import DateRangeContext, Flow -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_revenue(start_date: date, end_date: date, region: str) -> float: days = (end_date - start_date).days + 1 return 1000.0 + days * 10.0 -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def revenue_growth(start_date: date, end_date: date, current: float, previous: float) -> dict: return { "window_end": end_date, @@ -156,7 +155,7 @@ ctx = DateRangeContext( end_date=date(2024, 1, 31), ) -direct = model(ctx).value +direct = model(ctx) computed = model.flow.compute( start_date=date(2024, 1, 1), end_date=date(2024, 1, 31), @@ -182,49 +181,17 @@ def add(x: int, y: int) -> int: model = add(x=10) -assert model.flow.compute(y=5) == 15 +assert model.flow.compute(y=5).value == 15 ``` It validates the supplied keyword arguments against the generated context -schema, creates a `FlowContext`, executes the model, and unwraps -`GenericResult.value` if needed. +schema, creates a `FlowContext`, and executes the model. + +It returns the same result object you would get from calling `model(context)`. It is not the only execution path. Because the generated object is still a standard `CallableModel`, calling `model(context)` remains fully supported. -## FieldExtractor - -Accessing an unknown public attribute on a `@Flow.model` instance returns a -`FieldExtractor`. It is itself a `CallableModel` that runs the source model, -then extracts the named field from the result (via `getattr` or dict key -access). 
- -```python -from ccflow import ContextBase, Flow, GenericResult - - -class TrainingContext(ContextBase): - seed: int - - -@Flow.model -def prepare(context: TrainingContext) -> GenericResult[dict]: - s = context.seed - return GenericResult(value={"X_train": [s, s * 2], "y_train": [s * 10]}) - - -@Flow.model -def train(context: TrainingContext, X: list, y: list) -> GenericResult[int]: - return GenericResult(value=sum(X) + sum(y)) - - -prepared = prepare() -model = train(X=prepared.X_train, y=prepared.y_train) -``` - -Multiple extractors from the same source share the source model instance. If -caching is enabled the source is evaluated only once. - ## Lazy Inputs `Lazy[T]` marks a parameter as on-demand. Instead of eagerly resolving an diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 4e34fc3..e0aac45 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -76,23 +76,22 @@ from the context at runtime. The remaining parameters are model configuration. from datetime import date from ccflow import Flow, GenericResult, DateRangeContext -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_data(start_date: date, end_date: date, source: str) -> GenericResult[str]: return GenericResult(value=f"{source}:{start_date} to {end_date}") loader = load_data(source="my_database") -# For well-known field sets the decorator matches a built-in context type +# Opt in explicitly when you want compatibility with an existing context type assert loader.context_type == DateRangeContext ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) result = loader(ctx) ``` -For well-known shapes such as `start_date` / `end_date` with `date` -annotations, the generated model uses a concrete built-in context type like -`DateRangeContext`. Otherwise it falls back to `FlowContext`, a universal -frozen carrier for the validated fields. 
+By default, `context_args` models use `FlowContext`, a universal frozen carrier +for the validated fields. If you want the generated model to advertise and +accept an existing `ContextBase` subclass, pass `context_type=...` explicitly. Use `context_args` when some parameters are semantically "the execution context" and you want that split to stay stable and explicit: @@ -101,8 +100,8 @@ context" and you want that split to stay stable and explicit: - the split between config and runtime inputs matters semantically - the model is naturally "run over a context" such as date windows, partitions, or scenarios -- you want the generated model to match a built-in context type like - `DateRangeContext` when possible +- you want the generated model to accept a specific existing context type + such as `DateRangeContext` **Mode 3 — Default deferred style (no explicit context):** @@ -121,11 +120,11 @@ model = add(x=10) # `x` is bound when the model is created. # `y` is supplied later at execution time. -assert model.flow.compute(y=5) == 15 +assert model.flow.compute(y=5).value == 15 # `.flow.with_inputs(...)` rewrites runtime inputs for this call path. doubled_y = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) -assert doubled_y.flow.compute(y=5) == 20 +assert doubled_y.flow.compute(y=5).value == 20 ``` #### Composing Dependencies @@ -138,12 +137,12 @@ the current context and passes the resolved value into your function. 
from datetime import date, timedelta from ccflow import DateRangeContext, Flow -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_revenue(start_date: date, end_date: date, region: str) -> float: days = (end_date - start_date).days + 1 return 1000.0 + days * 10.0 -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def revenue_growth( start_date: date, end_date: date, @@ -168,7 +167,7 @@ ctx = DateRangeContext( ) # Standard ccflow execution -direct = growth(ctx).value +direct = growth(ctx) # Equivalent explicit deferred entry point computed = growth.flow.compute( @@ -182,8 +181,8 @@ assert direct == computed #### Deferred Execution Helpers **`.flow.compute(**kwargs)`** validates the keyword arguments against the -generated context schema, wraps them in a `FlowContext`, calls the model, and -unwraps `GenericResult.value` if present. +generated context schema, wraps them in a `FlowContext`, and calls the model. +It returns the same result object you would get from `model(context)`. **`.flow.with_inputs(**transforms)`** returns a `BoundModel` that applies context transforms before delegating to the underlying model. 
Each transform @@ -198,10 +197,10 @@ def add(x: int, y: int) -> int: return x + y model = add(x=10) -assert model.flow.compute(y=5) == 15 +assert model.flow.compute(y=5).value == 15 shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) -assert shifted.flow.compute(y=5) == 20 +assert shifted.flow.compute(y=5).value == 20 # You can also call with a context object directly ctx = FlowContext(y=5) @@ -209,38 +208,6 @@ assert model(ctx).value == 15 assert shifted(ctx).value == 20 ``` -#### Field Extraction - -Accessing an unknown attribute on a `@Flow.model` instance returns a -`FieldExtractor` — a `CallableModel` that runs the source model and extracts -the named field from its result. This makes it easy to wire individual output -fields into downstream models. - -```python -from ccflow import ContextBase, Flow, GenericResult - -class TrainingContext(ContextBase): - seed: int - -@Flow.model -def prepare(context: TrainingContext) -> GenericResult[dict]: - s = context.seed - return GenericResult(value={"X_train": [s, s * 2], "y_train": [s * 10]}) - -@Flow.model -def train(context: TrainingContext, X: list, y: list) -> GenericResult[int]: - return GenericResult(value=sum(X) + sum(y)) - -prepared = prepare() -model = train(X=prepared.X_train, y=prepared.y_train) -result = model(TrainingContext(seed=5)) -# X_train = [5, 10], y_train = [50] -> 15 + 50 = 65 -assert result.value == 65 -``` - -Multiple extractors from the same source share the source model instance, so -with caching enabled the source is only evaluated once. - #### Lazy Dependencies with `Lazy[T]` Mark a parameter with `Lazy[T]` to defer its evaluation. 
Instead of eagerly diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py index 27e31bb..27d5d0e 100644 --- a/examples/flow_model_example.py +++ b/examples/flow_model_example.py @@ -17,7 +17,7 @@ from ccflow import DateRangeContext, Flow -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def load_revenue(start_date: date, end_date: date, region: str) -> float: """Return synthetic revenue for one reporting window.""" days = (end_date - start_date).days + 1 @@ -27,7 +27,7 @@ def load_revenue(start_date: date, end_date: date, region: str) -> float: return round(region_base + days * 8.0 + trend, 2) -@Flow.model(context_args=["start_date", "end_date"]) +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) def revenue_change( start_date: date, end_date: date, @@ -81,7 +81,7 @@ def main() -> None: end_date=date(2024, 3, 31), ) - direct = pipeline(ctx).value + direct = pipeline(ctx) computed = pipeline.flow.compute( start_date=ctx.start_date, end_date=ctx.end_date, @@ -95,7 +95,7 @@ def main() -> None: print(f" direct == computed: {direct == computed}") print("\nResult:") - for key, value in computed.items(): + for key, value in computed.value.items(): print(f" {key}: {value}") From 20612e27ed5ca40a12907cf15b1d7c77567ccd2d Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Thu, 19 Mar 2026 14:19:14 -0400 Subject: [PATCH 14/26] Further clean-up Signed-off-by: Nijat Khanbabayev --- ccflow/flow_model.py | 259 ++++++++++++++++++--- ccflow/tests/test_flow_model.py | 389 +++++++++++++++++++++++++++++++- 2 files changed, 611 insertions(+), 37 deletions(-) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index da9d1bb..ab0b3f4 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -11,14 +11,15 @@ import inspect import logging +import threading from functools import wraps from typing import Annotated, Any, Callable, 
ClassVar, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin, get_type_hints -from pydantic import Field, TypeAdapter, model_validator -from typing_extensions import TypedDict +from pydantic import Field, PrivateAttr, TypeAdapter, model_serializer, model_validator +from typing_extensions import NotRequired, TypedDict from .base import ContextBase, ResultBase -from .callable import CallableModel, Flow, GraphDepList +from .callable import CallableModel, Flow, GraphDepList, WrapperModel from .context import FlowContext from .local_persistence import register_ccflow_import_path from .result import GenericResult @@ -109,7 +110,7 @@ def _transform_repr(transform: Any) -> str: def _is_model_dependency(value: Any) -> bool: - return isinstance(value, (CallableModel, BoundModel)) + return isinstance(value, CallableModel) def _bound_field_names(model: Any) -> set[str]: @@ -160,6 +161,22 @@ def _registry_candidate_allowed(expected_type: Type, candidate: Any) -> bool: return True +def _type_accepts_str(annotation) -> bool: + """Return True when ``str`` is a valid type for *annotation*. + + Handles ``str``, ``Union[str, ...]``, ``Optional[str]``, and + ``Annotated[str, ...]``. + """ + if annotation is str: + return True + origin = get_origin(annotation) + if origin is Annotated: + return _type_accepts_str(get_args(annotation)[0]) + if origin is Union: + return any(_type_accepts_str(arg) for arg in get_args(annotation) if arg is not type(None)) + return False + + def _build_typed_dict_adapter(name: str, schema: Dict[str, Type], *, total: bool = True) -> TypeAdapter: """Build a TypeAdapter for a runtime TypedDict schema.""" @@ -199,6 +216,17 @@ def _build_config_validators(all_param_types: Dict[str, Type]) -> Tuple[Dict[str return validatable_types, validators +def _coerce_context_value(name: str, value: Any, validators: Dict[str, TypeAdapter], validatable_types: Dict[str, Type]) -> Any: + """Validate/coerce a single context-sourced value. 
Returns coerced value or raises TypeError.""" + if name not in validators: + return value + try: + return validators[name].validate_python(value) + except Exception: + expected = validatable_types.get(name, "unknown") + raise TypeError(f"Context field '{name}': expected {expected}, got {type(value).__name__} ({value!r})") + + def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, Type], validators: Dict[str, TypeAdapter]) -> None: """Validate plain config inputs while still allowing dependency objects.""" @@ -214,8 +242,10 @@ def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, if value is None or _is_model_dependency(value): continue if isinstance(value, str) and value in _MR.root(): - candidate = _resolve_registry_candidate(value) expected_type = validatable_types[field_name] + if _type_accepts_str(expected_type): + continue + candidate = _resolve_registry_candidate(value) if candidate is not None and _registry_candidate_allowed(expected_type, candidate): continue try: @@ -227,7 +257,7 @@ def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, def _generated_model_instance(stage: Any) -> Optional["_GeneratedFlowModelBase"]: if isinstance(stage, BoundModel): - model = stage._model + model = stage.model else: model = stage if isinstance(model, _GeneratedFlowModelBase): @@ -324,7 +354,7 @@ def pipe_model(source: Any, stage: Any, /, *, param: Optional[str] = None, **bin """Wire ``source`` into a downstream generated ``@Flow.model`` stage.""" if not _is_model_dependency(source): - raise TypeError(f"pipe() source must be a CallableModel or BoundModel, got {type(source).__name__}.") + raise TypeError(f"pipe() source must be a CallableModel, got {type(source).__name__}.") target_param, generated_model_cls = _resolve_pipe_param(source, stage, param, bindings) build_kwargs = dict(bindings) @@ -364,6 +394,8 @@ def _build_context(self, kwargs: Dict[str, Any]) -> ContextBase: if get_validator is 
not None: validator = get_validator() validated = validator.validate_python(kwargs) + if isinstance(validated, ContextBase): + return validated return FlowContext(**validated) validator = TypeAdapter(self._model.context_type) @@ -393,22 +425,27 @@ def unbound_inputs(self) -> Dict[str, Type]: all_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) model_cls = self._model.__class__ - # If explicit context_args was provided, use _context_schema + # If explicit context_args was provided, use _context_schema minus + # fields that have function defaults (they aren't truly required). explicit_args = getattr(model_cls, "__flow_model_explicit_context_args__", None) if explicit_args is not None: context_schema = getattr(model_cls, "_context_schema", None) - return context_schema.copy() if context_schema is not None else {} + if context_schema is None: + return {} + ctx_arg_defaults = getattr(model_cls, "__flow_model_context_arg_defaults__", {}) + return {name: typ for name, typ in context_schema.items() if name not in ctx_arg_defaults} # Dynamic @Flow.model: unbound = params with no explicit value and no declared default if all_param_types: runtime_inputs = _runtime_input_names(self._model) return {name: typ for name, typ in all_param_types.items() if name in runtime_inputs} - # Generic CallableModel: runtime inputs are the context schema. + # Generic CallableModel / Mode 1: runtime inputs are the required + # context fields (fields with defaults are not required). 
context_cls = _concrete_context_type(self._model.context_type) if context_cls is None or not hasattr(context_cls, "model_fields"): return {} - return {name: info.annotation for name, info in context_cls.model_fields.items()} + return {name: info.annotation for name, info in context_cls.model_fields.items() if info.is_required()} @property def bound_inputs(self) -> Dict[str, Any]: @@ -445,7 +482,25 @@ def with_inputs(self, **transforms) -> "BoundModel": return BoundModel(model=self._model, input_transforms=transforms) -class BoundModel: +_bound_model_restore = threading.local() + + +def _fingerprint_transforms(transforms: Dict[str, Any]) -> Dict[str, str]: + """Create a stable, hashable fingerprint of input transforms for cache key differentiation. + + Callable transforms are identified by their id() (unique per object), which is + stable within a process lifetime. Static values are repr'd directly. + """ + result = {} + for name, transform in sorted(transforms.items()): + if callable(transform): + result[name] = f"callable:{id(transform)}" + else: + result[name] = repr(transform) + return result + + +class BoundModel(WrapperModel): """A model with context transforms applied. Created by model.flow.with_inputs(). Applies transforms to context @@ -466,9 +521,44 @@ class BoundModel: of a previous transform). """ - def __init__(self, model: CallableModel, input_transforms: Dict[str, Any]): - self._model = model - self._input_transforms = input_transforms + _input_transforms: Dict[str, Any] = PrivateAttr(default_factory=dict) + + @model_validator(mode="wrap") + @classmethod + def _restore_serialized_transforms(cls, values, handler): + """Strip serialization-injected keys, restore static transforms, guarantee cleanup. + + Uses thread-local storage to pass static transforms to __init__ because + pydantic rejects unknown keys in the input dict. The wrap validator's + try/finally ensures the thread-local is always cleaned up, even if + validation fails before __init__ runs. 
+ """ + if isinstance(values, dict): + values = dict(values) # Don't mutate the caller's dict + values.pop("_input_transforms_token", None) + static = values.pop("_static_transforms", None) + else: + static = None + + if static is not None: + _bound_model_restore.pending = static + try: + return handler(values) + except Exception: + _bound_model_restore.pending = None + raise + + def __init__(self, *, model: CallableModel, input_transforms: Optional[Dict[str, Any]] = None, **kwargs): + super().__init__(model=model, **kwargs) + restore = getattr(_bound_model_restore, "pending", None) + if restore is not None: + _bound_model_restore.pending = None + if input_transforms is not None: + self._input_transforms = input_transforms + elif restore is not None: + self._input_transforms = restore + else: + self._input_transforms = {} def _transform_context(self, context: ContextBase) -> ContextBase: """Return this model's preferred context type with input transforms applied.""" @@ -478,14 +568,35 @@ def _transform_context(self, context: ContextBase) -> ContextBase: ctx_dict[name] = transform(context) else: ctx_dict[name] = transform - context_type = _concrete_context_type(self._model.context_type) + context_type = _concrete_context_type(self.model.context_type) if context_type is not None and context_type is not FlowContext: return context_type.model_validate(ctx_dict) return FlowContext(**ctx_dict) - def __call__(self, context: ContextBase) -> Any: + @Flow.call + def __call__(self, context: ContextBase) -> ResultBase: """Call the model with transformed context.""" - return self._model(self._transform_context(context)) + return self.model(self._transform_context(context)) + + @Flow.deps + def __deps__(self, context: ContextBase) -> GraphDepList: + """Declare the wrapped model as an upstream dependency with transformed context.""" + return [(self.model, [self._transform_context(context)])] + + @model_serializer(mode="wrap") + def _serialize_with_transforms(self, handler): + 
"""Include transforms in serialization for cache keys and faithful roundtrips. + + Static (non-callable) transforms are serialized in _static_transforms for + faithful restoration. A fingerprint token covers all transforms (including + callables) for cache key differentiation. + """ + data = handler(self) + static = {k: v for k, v in self._input_transforms.items() if not callable(v)} + if static: + data["_static_transforms"] = static + data["_input_transforms_token"] = _fingerprint_transforms(self._input_transforms) + return data def pipe(self, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: """Wire this bound model into a downstream generated ``@Flow.model`` stage.""" @@ -493,28 +604,24 @@ def pipe(self, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) - def __repr__(self) -> str: transforms = ", ".join(f"{name}={_transform_repr(transform)}" for name, transform in self._input_transforms.items()) - return f"{self._model!r}.flow.with_inputs({transforms})" + return f"{self.model!r}.flow.with_inputs({transforms})" @property def flow(self) -> "FlowAPI": """Access the flow API.""" return _BoundFlowAPI(self) - @property - def context_type(self) -> Type[ContextBase]: - return self._model.context_type - class _BoundFlowAPI(FlowAPI): """FlowAPI that delegates to a BoundModel, honoring transforms.""" def __init__(self, bound_model: BoundModel): self._bound = bound_model - super().__init__(bound_model._model) + super().__init__(bound_model.model) def compute(self, **kwargs) -> Any: ctx = self._build_context(kwargs) - return self._bound(ctx) # Call through BoundModel, not _model + return self._bound(ctx) # Call through BoundModel, not inner model def with_inputs(self, **transforms) -> "BoundModel": """Chain transforms: merge new transforms with existing ones. @@ -522,7 +629,7 @@ def with_inputs(self, **transforms) -> "BoundModel": New transforms override existing ones for the same key. 
""" merged = {**self._bound._input_transforms, **transforms} - return BoundModel(model=self._bound._model, input_transforms=merged) + return BoundModel(model=self._bound.model, input_transforms=merged) class _GeneratedFlowModelBase(CallableModel): @@ -535,7 +642,10 @@ class _GeneratedFlowModelBase(CallableModel): __flow_model_explicit_context_args__: ClassVar[Optional[List[str]]] = None __flow_model_all_param_types__: ClassVar[Dict[str, Type]] = {} __flow_model_default_param_names__: ClassVar[set[str]] = set() + __flow_model_context_arg_defaults__: ClassVar[Dict[str, Any]] = {} __flow_model_auto_wrap__: ClassVar[bool] = False + __flow_model_validatable_types__: ClassVar[Dict[str, Type]] = {} + __flow_model_config_validators__: ClassVar[Dict[str, TypeAdapter]] = {} _context_schema: ClassVar[Dict[str, Type]] = {} _context_td: ClassVar[Any | None] = None _cached_context_validator: ClassVar[TypeAdapter | None] = None @@ -553,7 +663,7 @@ def _resolve_registry_refs(cls, values, info): value = resolved[field_name] if not isinstance(value, str): continue - if expected_type is str: + if _type_accepts_str(expected_type): continue candidate = _resolve_registry_candidate(value) if candidate is None: @@ -562,6 +672,30 @@ def _resolve_registry_refs(cls, values, info): resolved[field_name] = candidate return resolved + @model_validator(mode="after") + def _validate_field_types(self): + """Validate field values against their declared types. + + This catches type mismatches in the model_validate/deserialization path, + where fields are typed as Any and pydantic won't reject wrong types. 
+ """ + cls = self.__class__ + config_validators = getattr(cls, "__flow_model_config_validators__", {}) + validatable_types = getattr(cls, "__flow_model_validatable_types__", {}) + if not config_validators: + return self + + for field_name, validator in config_validators.items(): + value = getattr(self, field_name, _DEFERRED_INPUT) + if _has_deferred_input(value) or value is None or _is_model_dependency(value): + continue + try: + validator.validate_python(value) + except Exception: + expected_type = validatable_types[field_name] + raise TypeError(f"Field '{field_name}': expected {expected_type}, got {type(value).__name__} ({value!r})") + return self + @property def context_type(self) -> Type[ContextBase]: return self.__class__.__flow_model_context_type__ @@ -582,7 +716,13 @@ def _get_context_validator(self) -> TypeAdapter: if explicit_args is not None or not getattr(cls, "__flow_model_use_context_args__", True): if cls._cached_context_validator is None: - if cls._context_td is not None: + use_ctx_args = getattr(cls, "__flow_model_use_context_args__", True) + ctx_type = cls.__flow_model_context_type__ + if not use_ctx_args and isinstance(ctx_type, type) and issubclass(ctx_type, ContextBase) and ctx_type is not FlowContext: + # Mode 1 with concrete context type — use TypeAdapter(context_type) + # directly so defaults on the context type are respected. 
+ cls._cached_context_validator = TypeAdapter(ctx_type) + elif cls._context_td is not None: cls._cached_context_validator = TypeAdapter(cls._context_td) elif cls._context_schema: cls._cached_context_validator = _build_typed_dict_adapter(f"{cls.__name__}Inputs", cls._context_schema) @@ -701,6 +841,7 @@ def _build_context_schema( """ # Build schema dict from parameter annotations schema = {} + td_schema = {} for name in context_args: if name not in sig.parameters: raise ValueError(f"context_arg '{name}' not found in function parameters") @@ -709,14 +850,25 @@ def _build_context_schema( if annotation is inspect.Parameter.empty: raise ValueError(f"context_arg '{name}' must have a type annotation") schema[name] = annotation + # Use NotRequired in the TypedDict for params that have a default in the + # function signature, so compute() doesn't require them. + if param.default is not inspect.Parameter.empty: + td_schema[name] = NotRequired[annotation] + else: + td_schema[name] = annotation # Create TypedDict for validation (not registered anywhere!) 
- context_td = TypedDict(f"{_callable_name(func)}Inputs", schema) + context_td = TypedDict(f"{_callable_name(func)}Inputs", td_schema) return schema, context_td -def _validate_context_type_override(context_type: Any, context_args: List[str], func_schema: Dict[str, Type]) -> Type[ContextBase]: +def _validate_context_type_override( + context_type: Any, + context_args: List[str], + func_schema: Dict[str, Type], + func_defaults: set[str] = frozenset(), +) -> Type[ContextBase]: """Validate an explicit ``context_type`` override for ``context_args`` mode.""" if not isinstance(context_type, type) or not issubclass(context_type, ContextBase): @@ -752,6 +904,14 @@ def _validate_context_type_override(context_type: Any, context_args: List[str], f"but context_type {context_type.__name__} declares {ctx_ann.__name__}" ) + # Reject if the function has a default for a context_arg but the + # context_type declares that field as required — this is contradictory. + for name in context_args: + if name in func_defaults: + ctx_field = context_fields.get(name) + if ctx_field is not None and ctx_field.is_required(): + raise TypeError(f"context_arg '{name}': function has a default but context_type {context_type.__name__} requires this field") + return context_type @@ -855,8 +1015,11 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: # Mode 2: Explicit context_args - specified params come from context context_param_name = "context" context_schema_early, context_td_early = _build_context_schema(context_args, fn, sig, _resolved_hints) + _func_defaults_set = {name for name in context_args if sig.parameters[name].default is not inspect.Parameter.empty} explicit_context_type = ( - _validate_context_type_override(context_type, context_args, context_schema_early) if context_type is not None else None + _validate_context_type_override(context_type, context_args, context_schema_early, _func_defaults_set) + if context_type is not None + else None ) resolved_context_type = explicit_context_type if 
explicit_context_type is not None else FlowContext # Exclude context_args from model fields @@ -914,6 +1077,17 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: ctx_args_for_closure = context_args if context_args is not None else [] is_dynamic_mode = use_context_args and explicit_context_args is None + # Compute context_arg defaults and validators for Mode 2 (context_args) + context_arg_defaults: Dict[str, Any] = {} + _ctx_validatable_types: Dict[str, Type] = {} + _ctx_validators: Dict[str, TypeAdapter] = {} + if context_args is not None: + for name in context_args: + p = sig.parameters[name] + if p.default is not inspect.Parameter.empty: + context_arg_defaults[name] = p.default + _ctx_validatable_types, _ctx_validators = _build_config_validators(context_schema_early) + # Create the __call__ method def make_call_impl(): def __call__(self, context): @@ -929,7 +1103,7 @@ def resolve_callable_model(value): def _resolve_field(name, value): """Resolve a single field value, handling lazy wrapping.""" - is_dep = isinstance(value, (CallableModel, BoundModel)) + is_dep = isinstance(value, CallableModel) if name in lazy_fields: # Lazy field: wrap in a thunk regardless of type if is_dep: @@ -952,7 +1126,14 @@ def _resolve_field(name, value): elif not is_dynamic_mode: # Mode 2: Explicit context_args - get those from context, rest from self for name in ctx_args_for_closure: - fn_kwargs[name] = getattr(context, name) + value = getattr(context, name, _UNSET) + if value is _UNSET: + if name in context_arg_defaults: + fn_kwargs[name] = context_arg_defaults[name] + else: + raise TypeError(f"Missing context field '{name}'") + else: + fn_kwargs[name] = _coerce_context_value(name, value, _ctx_validators, _ctx_validatable_types) # Add model fields for name in all_param_names: value = getattr(self, name) @@ -976,7 +1157,10 @@ def _resolve_field(name, value): if value is _UNSET: missing_fields.append(name) continue - fn_kwargs[name] = value + # Validate/coerce context-sourced value, 
skip CallableModel deps + if not _is_model_dependency(value): + value = _coerce_context_value(name, value, _config_validators, _validatable_types) + fn_kwargs[name] = _resolve_field(name, value) if missing_fields: missing = ", ".join(sorted(missing_fields)) @@ -1021,13 +1205,13 @@ def _resolve_field(name, value): def make_deps_impl(): def __deps__(self, context) -> GraphDepList: deps = [] - # Check ALL fields for CallableModels/BoundModels (auto-detection) + # Check ALL fields for CallableModel dependencies (auto-detection) for name in model_fields: if name in lazy_fields: continue # Lazy deps are NOT pre-evaluated value = getattr(self, name) if isinstance(value, BoundModel): - deps.append((value._model, [value._transform_context(context)])) + deps.append((value.model, [value._transform_context(context)])) elif isinstance(value, CallableModel): deps.append((value, [context])) return deps @@ -1078,7 +1262,10 @@ def __deps__(self, context) -> GraphDepList: GeneratedModel.__flow_model_explicit_context_args__ = explicit_context_args GeneratedModel.__flow_model_all_param_types__ = all_param_types # All param name -> type GeneratedModel.__flow_model_default_param_names__ = default_param_names + GeneratedModel.__flow_model_context_arg_defaults__ = context_arg_defaults GeneratedModel.__flow_model_auto_wrap__ = auto_wrap_result + GeneratedModel.__flow_model_validatable_types__ = _validatable_types + GeneratedModel.__flow_model_config_validators__ = _config_validators # Build context_schema context_schema: Dict[str, Type] = {} diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index a2e788e..ad30824 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -845,6 +845,49 @@ def load(item_id: int) -> int: self.assertIn("int", str(cm.exception)) self.assertIn("str", str(cm.exception)) + def test_model_validate_rejects_bad_scalar_type(self): + """model_validate should reject wrong scalar types, not silently accept them.""" 
+ + @Flow.model + def source(context: SimpleContext, x: int) -> GenericResult[int]: + return GenericResult(value=x) + + cls = type(source(x=1)) + with self.assertRaises(TypeError) as cm: + cls.model_validate({"x": "abc"}) + + self.assertIn("x", str(cm.exception)) + + def test_model_validate_accepts_correct_type(self): + """model_validate should accept correct types.""" + + @Flow.model + def source(context: SimpleContext, x: int) -> GenericResult[int]: + return GenericResult(value=x) + + cls = type(source(x=1)) + restored = cls.model_validate({"x": 42}) + self.assertEqual(restored(SimpleContext(value=0)).value, 42) + + def test_model_validate_rejects_bad_registry_alias(self): + """Typoed registry aliases should not silently pass through model_validate.""" + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def consumer(context: SimpleContext, n: int = 10) -> GenericResult[int]: + return GenericResult(value=n) + + cls = type(consumer(n=1)) + # "not_in_registry" is not a valid int and not a valid registry key + with self.assertRaises(TypeError) as cm: + cls.model_validate({"n": "not_in_registry"}) + self.assertIn("n", str(cm.exception)) + finally: + registry.clear() + def test_context_type_compatible_annotations_accepted(self): """context_type validation should accept matching or subclass annotations.""" @@ -864,6 +907,16 @@ def load_exact(start_date: date, end_date: date) -> str: class TestBoundModel(TestCase): """Tests for BoundModel and BoundModel.flow.""" + def test_bound_model_is_callable_model(self): + """BoundModel should be a proper CallableModel subclass.""" + + @Flow.model + def source(x: int) -> int: + return x * 10 + + bound = source().flow.with_inputs(x=lambda ctx: ctx.x * 2) + self.assertIsInstance(bound, CallableModel) + def test_bound_model_flow_compute(self): """Test that bound.flow.compute() honors transforms.""" @@ -897,6 +950,66 @@ def my_model(x: int, y: int) -> GenericResult[int]: # 7 * 3 = 21 
self.assertEqual(result.value, 21) + def test_bound_model_dump_validate_roundtrip_static(self): + """Static transforms survive model_dump → model_validate roundtrip.""" + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + bound = source().flow.with_inputs(value=42) + dump = bound.model_dump(mode="python") + restored = type(bound).model_validate(dump) + + ctx = SimpleContext(value=1) + self.assertEqual(bound(ctx).value, 420) + self.assertEqual(restored(ctx).value, 420) + + def test_bound_model_validate_same_payload_twice(self): + """Validating the same serialized BoundModel payload twice should work both times.""" + from ccflow.flow_model import BoundModel + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + bound = source().flow.with_inputs(value=42) + dump = bound.model_dump(mode="python") + + r1 = BoundModel.model_validate(dump) + r2 = BoundModel.model_validate(dump) + + ctx = SimpleContext(value=1) + self.assertEqual(r1(ctx).value, 420) + self.assertEqual(r2(ctx).value, 420) + + def test_bound_model_failed_validate_does_not_poison_next_construction(self): + """A failed model_validate must not leak static transforms to subsequent constructions.""" + from ccflow.flow_model import BoundModel + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + base = source() + + # Attempt a model_validate that will fail (invalid model field) + try: + BoundModel.model_validate( + { + "model": "not-a-real-model", + "_static_transforms": {"value": 42}, + "_input_transforms_token": {"value": "42"}, + } + ) + except Exception: + pass # Expected to fail + + # Now construct a fresh BoundModel normally — must NOT inherit stale transforms + clean = BoundModel(model=base, input_transforms={}) + ctx = SimpleContext(value=1) + self.assertEqual(clean(ctx).value, 10) # 1 
* 10, no transform applied + def test_bound_model_cloudpickle_with_lambda_transform(self): """BoundModel with lambda transforms should survive cloudpickle round-trip.""" @@ -1055,6 +1168,33 @@ def consumer(data: int, slow: Lazy[GenericResult[int]]) -> GenericResult[int]: # data=5 < 100, so slow path: x transform: 7+10=17, source: 17*3=51 self.assertEqual(result.value, 51) + def test_differently_transformed_bound_models_have_distinct_cache_keys(self): + """Two BoundModels with different transforms must not collide under caching.""" + + call_counts = {"source": 0} + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + call_counts["source"] += 1 + return GenericResult(value=context.value * 10) + + base = source() + b1 = base.flow.with_inputs(value=lambda ctx: ctx.value + 1) + b2 = base.flow.with_inputs(value=lambda ctx: ctx.value + 2) + evaluator = MemoryCacheEvaluator() + ctx = SimpleContext(value=5) + + with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + r1 = b1(ctx) + r2 = b2(ctx) + + # b1 transforms value to 6, source: 6*10=60 + # b2 transforms value to 7, source: 7*10=70 + self.assertEqual(r1.value, 60) + self.assertEqual(r2.value, 70) + # Source called twice (once per distinct transformed context) + self.assertEqual(call_counts["source"], 2) + def test_bound_and_unbound_models_share_memory_cache(self): """Shifted and unshifted models should share one evaluator cache. @@ -1083,7 +1223,9 @@ def source(context: SimpleContext) -> GenericResult[int]: # One execution for the unshifted context and one for the shifted context. self.assertEqual(call_counts["source"], 2) - self.assertEqual(len(evaluator.cache), 2) + # Cache has 3 entries: base(ctx), BoundModel(ctx), and base(shifted_ctx). + # BoundModel is a proper CallableModel now, so it gets its own cache entry. 
+ self.assertEqual(len(evaluator.cache), 3) def test_transform_error_propagates(self): """A buggy transform should raise, not silently fall back to FlowContext.""" @@ -1794,6 +1936,251 @@ def consumer( self.assertEqual(result.value, 51) # 50 + 1 +# ============================================================================= +# Bug Fix Regression Tests +# ============================================================================= + + +class TestFlowModelBugFixes(TestCase): + """Regression tests for four bugs identified during code review.""" + + # ----- Issue 1: .flow.compute() drops context defaults ----- + + def test_compute_respects_explicit_context_defaults(self): + """Mode 1: compute(x=1) should use ExtendedContext's default y='default'.""" + + @Flow.model + def model_fn(context: ExtendedContext, factor: int = 1) -> str: + return f"{context.x}-{context.y}-{factor}" + + model = model_fn() + result = model.flow.compute(x=1) + self.assertEqual(result.value, "1-default-1") + + def test_compute_respects_context_args_defaults(self): + """Mode 2: compute(x=1) should use function default y=42.""" + + @Flow.model(context_args=["x", "y"]) + def model_fn(x: int, y: int = 42) -> int: + return x + y + + model = model_fn() + result = model.flow.compute(x=1) + self.assertEqual(result.value, 43) + + def test_unbound_inputs_excludes_context_args_with_defaults(self): + """Mode 2: unbound_inputs should not include context_args that have function defaults.""" + + @Flow.model(context_args=["x", "y"]) + def model_fn(x: int, y: int = 42) -> int: + return x + y + + model = model_fn() + self.assertEqual(model.flow.unbound_inputs, {"x": int}) + + def test_unbound_inputs_excludes_context_type_defaults(self): + """Mode 1: unbound_inputs should not include context fields that have defaults.""" + + @Flow.model + def model_fn(context: ExtendedContext) -> str: + return f"{context.x}-{context.y}" + + model = model_fn() + # ExtendedContext has x: int (required) and y: str = "default" + 
self.assertEqual(model.flow.unbound_inputs, {"x": int}) + + def test_context_type_rejects_required_field_with_function_default(self): + """Decoration should fail when function has default but context_type requires the field.""" + + class StrictContext(ContextBase): + x: int # required + + with self.assertRaises(TypeError) as cm: + + @Flow.model(context_args=["x"], context_type=StrictContext) + def model_fn(x: int = 5) -> int: + return x + + self.assertIn("x", str(cm.exception)) + self.assertIn("requires", str(cm.exception)) + + def test_context_type_accepts_optional_field_with_function_default(self): + """Both context_type and function have defaults — should work.""" + + class OptionalContext(ContextBase): + x: int = 10 + + @Flow.model(context_args=["x"], context_type=OptionalContext) + def model_fn(x: int = 5) -> int: + return x + + model = model_fn() + result = model(OptionalContext()) + self.assertEqual(result.value, 10) # context default wins + + # ----- Issue 2: Lazy[...] broken in dynamic deferred mode ----- + + def test_lazy_from_runtime_context_in_dynamic_mode(self): + """Lazy[int] provided via FlowContext should be wrapped in a thunk.""" + + @Flow.model + def model_fn(x: int, y: Lazy[int]) -> int: + return x + y() + + model = model_fn(x=10) + result = model(FlowContext(y=32)) + self.assertEqual(result.value, 42) + + def test_callable_model_from_runtime_context_in_dynamic_mode(self): + """CallableModel provided in FlowContext should be resolved.""" + + @Flow.model + def source(value: int) -> int: + return value * 10 + + @Flow.model + def consumer(x: int, data: int) -> int: + return x + data + + model = consumer(x=1) + src = source() + result = model(FlowContext(data=src, value=5)) + # source resolves with value=5 → 50, consumer: 1 + 50 = 51 + self.assertEqual(result.value, 51) + + # ----- Issue 3: FlowContext-backed models skip schema validation ----- + + def test_direct_call_validates_flowcontext_dynamic_mode(self): + """Dynamic mode: 
FlowContext(y='hello') for int param should raise TypeError.""" + + @Flow.model + def model_fn(x: int, y: int) -> int: + return x + y + + model = model_fn() + with self.assertRaises(TypeError) as cm: + model(FlowContext(x=1, y="hello")) + + self.assertIn("y", str(cm.exception)) + + def test_direct_call_validates_flowcontext_context_args_mode(self): + """context_args mode: FlowContext(x='hello') for int param should raise TypeError.""" + + @Flow.model(context_args=["x"]) + def model_fn(x: int) -> int: + return x + + model = model_fn() + with self.assertRaises(TypeError) as cm: + model(FlowContext(x="hello")) + + self.assertIn("x", str(cm.exception)) + + def test_with_inputs_validates_transformed_fields_dynamic(self): + """Dynamic mode: with_inputs(y='hello') for int param should raise TypeError.""" + + @Flow.model + def model_fn(x: int, y: int) -> int: + return x + y + + model = model_fn(x=1) + bound = model.flow.with_inputs(y="hello") + + with self.assertRaises(TypeError) as cm: + bound(FlowContext()) + + self.assertIn("y", str(cm.exception)) + + def test_with_inputs_validates_transformed_fields_context_args(self): + """context_args mode: with_inputs(x='hello') for int param should raise TypeError.""" + + @Flow.model(context_args=["x"]) + def model_fn(x: int) -> int: + return x + + model = model_fn() + bound = model.flow.with_inputs(x="hello") + + with self.assertRaises(TypeError) as cm: + bound(FlowContext()) + + self.assertIn("x", str(cm.exception)) + + # ----- Issue 4: Registry-name resolution too aggressive for union strings ----- + + def test_registry_resolution_skips_union_str_annotation(self): + """Union[str, int] field with a registry key string should keep the string.""" + from typing import Union + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def dummy(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=1) + + registry.add("my_key", dummy()) + + @Flow.model + def consumer(context: SimpleContext, 
tag: Union[str, int] = "none") -> str: + return f"tag={tag}" + + model = consumer(tag="my_key") + result = model(SimpleContext(value=0)) + self.assertEqual(result.value, "tag=my_key") + finally: + registry.clear() + + def test_registry_resolution_skips_optional_str_annotation(self): + """Optional[str] field with a registry key string should keep the string.""" + from typing import Optional + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def dummy(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=1) + + registry.add("my_key", dummy()) + + @Flow.model + def consumer(context: SimpleContext, label: Optional[str] = None) -> str: + return f"label={label}" + + model = consumer(label="my_key") + result = model(SimpleContext(value=0)) + self.assertEqual(result.value, "label=my_key") + finally: + registry.clear() + + def test_registry_resolution_skips_union_annotated_str(self): + """Union[Annotated[str, ...], int] field with a registry key should keep the string.""" + from typing import Annotated, Union + + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def dummy(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=1) + + registry.add("my_key", dummy()) + + @Flow.model + def consumer(context: SimpleContext, tag: Union[Annotated[str, "label"], int] = "none") -> str: + return f"tag={tag}" + + model = consumer(tag="my_key") + result = model(SimpleContext(value=0)) + self.assertEqual(result.value, "tag=my_key") + finally: + registry.clear() + + if __name__ == "__main__": import unittest From 587b26f68eb9af7ad5e042400957b168b6646fee Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Thu, 19 Mar 2026 14:57:25 -0400 Subject: [PATCH 15/26] Update docs and small flow_model changes Signed-off-by: Nijat Khanbabayev --- ccflow/flow_model.py | 27 ++++++++++++++++++++++++--- ccflow/tests/test_flow_context.py | 22 ++++++++++++++++++++++ docs/wiki/Key-Features.md | 2 +- 3 files 
changed, 47 insertions(+), 4 deletions(-) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index ab0b3f4..da05d8e 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -222,9 +222,9 @@ def _coerce_context_value(name: str, value: Any, validators: Dict[str, TypeAdapt return value try: return validators[name].validate_python(value) - except Exception: + except Exception as exc: expected = validatable_types.get(name, "unknown") - raise TypeError(f"Context field '{name}': expected {expected}, got {type(value).__name__} ({value!r})") + raise TypeError(f"Context field '{name}': expected {expected}, got {type(value).__name__} ({value!r})") from exc def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, Type], validators: Dict[str, TypeAdapter]) -> None: @@ -767,6 +767,13 @@ def smart_training( ``with_inputs()`` for deferred execution:: lookback = Lazy(model)(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) + + **Which to use:** + + - Use ``Lazy[T]`` in a ``@Flow.model`` signature when you want conditional/ + on-demand evaluation of an expensive upstream dependency. + - Use ``Lazy(model)(...)`` when you need to rewire context fields before + passing them to an existing model (e.g., shifting a date window). """ def __class_getitem__(cls, item): @@ -993,7 +1000,21 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: else: internal_return_type = return_type - # Determine context mode + # ── Context mode selection ── + # The decorator supports three mutually exclusive context modes: + # + # Mode 1 (explicit context): Function has a 'context' (or '_') parameter + # annotated with a ContextBase subclass. Behaves like a traditional + # CallableModel.__call__. Other params become model fields. + # + # Mode 2 (context_args): Decorator specifies context_args=[...] listing + # which params come from the context at runtime. Remaining params become + # model fields. Uses FlowContext unless context_type= overrides it. 
+ # + # Mode 3 (dynamic deferred): No 'context' param and no context_args. + # Every param is a potential model field. Params bound at construction + # are config; unbound params become runtime inputs from FlowContext. + # context_schema_early: Dict[str, Type] = {} context_td_early = None if "context" in params or "_" in params: diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index 718f8de..c9a9811 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -86,6 +86,28 @@ def test_flow_context_hash_uses_extra_fields(self): assert len({first, second, third}) == 3 + def test_flow_context_hash_raises_for_unhashable_values(self): + """FlowContext with truly unhashable values (no __dict__) should raise TypeError.""" + + class Unhashable: + __hash__ = None # type: ignore[assignment] + + def __init__(self): + pass + + # Deliberately no __dict__ suppression — but __hash__ is None, + # so the fallback path in _freeze_for_hash should use __dict__. + # To trigger the actual TypeError path, we need an object with + # no __dict__ and no __hash__. 
+ + class UnhashableSlots: + __slots__ = () + __hash__ = None # type: ignore[assignment] + + ctx = FlowContext(val=UnhashableSlots()) + with pytest.raises(TypeError, match="unhashable value"): + hash(ctx) + def test_flow_context_pickle(self): """FlowContext pickles cleanly.""" ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index e0aac45..5fb27de 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -103,7 +103,7 @@ context" and you want that split to stay stable and explicit: - you want the generated model to accept a specific existing context type such as `DateRangeContext` -**Mode 3 — Default deferred style (no explicit context):** +**Mode 3 — Dynamic deferred style (no explicit context):** When there is no `context` parameter and no `context_args`, all parameters are potential configuration or runtime inputs. Parameters provided at construction From 517f45ec39cc9eae76a1d79ef2fc702bdd76eb5f Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Thu, 19 Mar 2026 15:03:58 -0400 Subject: [PATCH 16/26] Add more examples Signed-off-by: Nijat Khanbabayev --- .../config/flow_model_hydra_builder_demo.yaml | 24 ++ examples/evaluator_demo.py | 186 ++++++++++ examples/flow_model_hydra_builder_demo.py | 160 ++++++++ examples/ml_pipeline_demo.py | 351 ++++++++++++++++++ 4 files changed, 721 insertions(+) create mode 100644 examples/config/flow_model_hydra_builder_demo.yaml create mode 100644 examples/evaluator_demo.py create mode 100644 examples/flow_model_hydra_builder_demo.py create mode 100644 examples/ml_pipeline_demo.py diff --git a/examples/config/flow_model_hydra_builder_demo.yaml b/examples/config/flow_model_hydra_builder_demo.yaml new file mode 100644 index 0000000..5579a5c --- /dev/null +++ b/examples/config/flow_model_hydra_builder_demo.yaml @@ -0,0 +1,24 @@ +# Hydra config for examples/flow_model_hydra_builder_demo.py +# +# Pattern: +# - configure 
static pipeline specs in YAML +# - use model_alias to pass already-registered models into a plain Python builder +# - keep runtime context as runtime inputs, supplied later at execution time + +current_revenue: + _target_: examples.flow_model_hydra_builder_demo.load_revenue + region: us + +week_over_week: + _target_: examples.flow_model_hydra_builder_demo.build_comparison + current: + _target_: ccflow.compose.model_alias + model_name: current_revenue + comparison: week_over_week + +month_over_month: + _target_: examples.flow_model_hydra_builder_demo.build_comparison + current: + _target_: ccflow.compose.model_alias + model_name: current_revenue + comparison: month_over_month diff --git a/examples/evaluator_demo.py b/examples/evaluator_demo.py new file mode 100644 index 0000000..a85b087 --- /dev/null +++ b/examples/evaluator_demo.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python +""" +Evaluator Demo: Caching & Execution Strategies +=============================================== + +Shows how to change execution behavior (caching, graph evaluation, logging) +WITHOUT changing user code. The same @Flow.model functions work with any +evaluator stack — you just configure it at the top level. + +Key insight: "default lazy" is an evaluator concern, not a wiring concern. +Users write plain functions and wire them by passing outputs as inputs. +The evaluator layer controls how they execute. + +Demonstrates: + 1. Default execution (eager, no caching) — diamond dep calls load twice + 2. MemoryCacheEvaluator — deduplicates shared deps in a diamond + 3. GraphEvaluator + Cache — topological evaluation + deduplication + 4. LoggingEvaluator — adds tracing around every model call + 5. 
Per-model opt-out — @Flow.model(cacheable=False) overrides global + +Run with: python examples/evaluator_demo.py +""" + +from __future__ import annotations + +import logging +import sys + +# Suppress default debug logging from ccflow evaluators for clean demo output +logging.disable(logging.DEBUG) + +from ccflow import Flow, FlowOptionsOverride # noqa: E402 +from ccflow.evaluators.common import ( # noqa: E402 + GraphEvaluator, + LoggingEvaluator, + MemoryCacheEvaluator, + MultiEvaluator, +) + +# ============================================================================= +# Plain @Flow.model functions — no evaluator concerns in the code +# ============================================================================= + +call_counts: dict[str, int] = {} + + +def _track(name: str) -> None: + call_counts[name] = call_counts.get(name, 0) + 1 + + +@Flow.model +def load_data(x: int, source: str = "warehouse") -> list: + """Load raw data. Expensive — we want to avoid calling this twice.""" + _track("load_data") + return [x, x * 2, x * 3] + + +@Flow.model +def compute_sum(data: list) -> int: + """Branch A: sum the data.""" + _track("compute_sum") + return sum(data) + + +@Flow.model +def compute_max(data: list) -> int: + """Branch B: max of the data.""" + _track("compute_max") + return max(data) + + +@Flow.model +def combine(sum_result: int, max_result: int) -> dict: + """Combine results from both branches.""" + _track("combine") + return {"sum": sum_result, "max": max_result, "total": sum_result + max_result} + + +@Flow.model(cacheable=False) +def volatile_timestamp(seed: int) -> str: + """Explicitly non-cacheable — always re-executes even with global caching.""" + _track("volatile_timestamp") + from datetime import datetime + + return datetime.now().isoformat() + + +# ============================================================================= +# Wire the pipeline — diamond dependency on load_data +# +# load_data ──┬── compute_sum ──┐ +# └── compute_max ──┴── combine +# 
============================================================================= + +shared = load_data(source="prod") +branch_a = compute_sum(data=shared) +branch_b = compute_max(data=shared) +pipeline = combine(sum_result=branch_a, max_result=branch_b) + + +def run() -> dict: + call_counts.clear() + result = pipeline.flow.compute(x=5) + loads = call_counts.get("load_data", 0) + print(f" Result: {result.value}") + print(f" load_data called: {loads}x | total model calls: {sum(call_counts.values())}") + return result.value + + +# ============================================================================= +# Demo 1: Default — no evaluator +# ============================================================================= + +print("=" * 70) +print("1. Default (eager, no caching)") +print(" load_data is called TWICE — once per branch") +print("=" * 70) +run() + +# ============================================================================= +# Demo 2: MemoryCacheEvaluator — deduplicates shared deps +# ============================================================================= + +print() +print("=" * 70) +print("2. MemoryCacheEvaluator (global override)") +print(" load_data is called ONCE — second branch hits cache") +print("=" * 70) +with FlowOptionsOverride(options={"evaluator": MemoryCacheEvaluator(), "cacheable": True}): + run() + +# ============================================================================= +# Demo 3: Cache + GraphEvaluator — topological order + deduplication +# ============================================================================= + +print() +print("=" * 70) +print("3. 
GraphEvaluator + MemoryCacheEvaluator") +print(" Evaluates in dependency order: load_data → branches → combine") +print("=" * 70) +evaluator = MultiEvaluator(evaluators=[MemoryCacheEvaluator(), GraphEvaluator()]) +with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + run() + +# ============================================================================= +# Demo 4: Logging — trace every model call +# ============================================================================= + +print() +print("=" * 70) +print("4. LoggingEvaluator + MemoryCacheEvaluator") +print(" Adds timing/tracing around every evaluation") +print("=" * 70) + +# Re-enable logging for this demo (use stdout so log lines interleave with print correctly) +logging.disable(logging.NOTSET) +logging.basicConfig(level=logging.INFO, format=" LOG: %(message)s", stream=sys.stdout) + +evaluator = MultiEvaluator(evaluators=[LoggingEvaluator(log_level=logging.INFO), MemoryCacheEvaluator()]) +with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): + run() + +# Suppress again for clean output +logging.disable(logging.DEBUG) +logging.getLogger().handlers.clear() + +# ============================================================================= +# Demo 5: Per-model opt-out — cacheable=False overrides global +# ============================================================================= + +print() +print("=" * 70) +print("5. Per-model opt-out: @Flow.model(cacheable=False)") +print(" volatile_timestamp always re-executes despite global cacheable=True") +print("=" * 70) + +ts = volatile_timestamp(seed=0) + +with FlowOptionsOverride(options={"evaluator": MemoryCacheEvaluator(), "cacheable": True}): + call_counts.clear() + r1 = ts.flow.compute(seed=0) + r2 = ts.flow.compute(seed=0) + print(f" Call 1: {r1.value}") + print(f" Call 2: {r2.value}") + print(f" volatile_timestamp called: {call_counts.get('volatile_timestamp', 0)}x") + print(f" (Same result? 
{r1.value == r2.value} — called twice, timestamps may differ)") diff --git a/examples/flow_model_hydra_builder_demo.py b/examples/flow_model_hydra_builder_demo.py new file mode 100644 index 0000000..00c9571 --- /dev/null +++ b/examples/flow_model_hydra_builder_demo.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python +"""Hydra + Flow.model builder demo. + +This example shows a clean way to mix: + +1. ergonomic `@Flow.model` pipeline wiring in Python, and +2. Hydra / ModelRegistry configuration for static pipeline specs. + +The pattern is: + +- keep runtime context (`start_date`, `end_date`) as runtime inputs, +- use a plain Python builder function for graph construction, +- let Hydra instantiate that builder and register the returned model. + +Run with: + python examples/flow_model_hydra_builder_demo.py +""" + +from calendar import monthrange +from datetime import date, timedelta +from pathlib import Path +from typing import Literal, Protocol, cast + +from ccflow import BoundModel, CallableModel, DateRangeContext, Flow, FlowAPI, GenericResult, ModelRegistry +from typing_extensions import TypedDict + +CONFIG_PATH = Path(__file__).with_name("config") / "flow_model_hydra_builder_demo.yaml" +ComparisonName = Literal["week_over_week", "month_over_month"] + + +class RevenueChangeResult(TypedDict): + comparison: ComparisonName + current_window: str + previous_window: str + current: float + previous: float + delta: float + growth_pct: float + + +class RevenueChangeModel(Protocol): + flow: FlowAPI + + def __call__(self, context: DateRangeContext) -> GenericResult[RevenueChangeResult]: ... 
+ + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def load_revenue(start_date: date, end_date: date, region: str) -> float: + """Return synthetic revenue for a date window.""" + days = (end_date - start_date).days + 1 + region_base = {"us": 1000.0, "eu": 850.0, "apac": 920.0}.get(region, 900.0) + days_since_2024 = (end_date - date(2024, 1, 1)).days + trend = days_since_2024 * 2.5 + return round(region_base + days * 8.0 + trend, 2) + + +@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +def revenue_change( + start_date: date, + end_date: date, + current: float, + previous: float, + comparison: ComparisonName, +) -> RevenueChangeResult: + """Compare the current window against a shifted previous window.""" + growth = (current - previous) / previous + previous_start, previous_end = comparison_window(start_date, end_date, comparison) + return { + "comparison": comparison, + "current_window": f"{start_date} -> {end_date}", + "previous_window": f"{previous_start} -> {previous_end}", + "current": current, + "previous": previous, + "delta": round(current - previous, 2), + "growth_pct": round(growth * 100, 2), + } + + +def comparison_window(start_date: date, end_date: date, comparison: ComparisonName) -> tuple[date, date]: + """Return the previous window for a named comparison policy.""" + if comparison == "week_over_week": + return start_date - timedelta(days=7), end_date - timedelta(days=7) + + if start_date.day != 1: + raise ValueError("month_over_month requires start_date to be the first day of a month") + if start_date.year != end_date.year or start_date.month != end_date.month: + raise ValueError("month_over_month requires the current window to stay within one calendar month") + expected_end = date(end_date.year, end_date.month, monthrange(end_date.year, end_date.month)[1]) + if end_date != expected_end: + raise ValueError("month_over_month requires end_date to be the last day of that month") + + 
previous_year = start_date.year if start_date.month > 1 else start_date.year - 1 + previous_month = start_date.month - 1 if start_date.month > 1 else 12 + previous_start = date(previous_year, previous_month, 1) + previous_end = date(previous_year, previous_month, monthrange(previous_year, previous_month)[1]) + return previous_start, previous_end + + +def comparison_input(model: CallableModel, comparison: ComparisonName) -> BoundModel: + """Apply a named comparison policy to one dependency.""" + return model.flow.with_inputs( + start_date=lambda ctx: comparison_window(ctx.start_date, ctx.end_date, comparison)[0], + end_date=lambda ctx: comparison_window(ctx.start_date, ctx.end_date, comparison)[1], + ) + + +def build_comparison(current: CallableModel, *, comparison: ComparisonName) -> RevenueChangeModel: + """Hydra-friendly builder that returns a configured comparison model.""" + previous = comparison_input(current, comparison) + return revenue_change( + current=current, + previous=previous, + comparison=comparison, + ) + + +def main() -> None: + registry = ModelRegistry.root() + registry.clear() + try: + registry.load_config_from_path(str(CONFIG_PATH), overwrite=True) + + week_over_week = cast(RevenueChangeModel, registry["week_over_week"]) + month_over_month = cast(RevenueChangeModel, registry["month_over_month"]) + + ctx = DateRangeContext( + start_date=date(2024, 3, 1), + end_date=date(2024, 3, 31), + ) + + print("=" * 68) + print("Hydra + Flow.model Builder Demo") + print("=" * 68) + print("\nLoaded from config:") + print(" current_revenue:", registry["current_revenue"]) + print(" week_over_week:", week_over_week) + print(" month_over_month:", month_over_month) + + week_over_week_result = cast( + RevenueChangeResult, + week_over_week.flow.compute( + start_date=ctx.start_date, + end_date=ctx.end_date, + ).value, + ) + month_over_month_result = month_over_month(ctx).value + + print("\nWeek-over-week:") + for key, value in week_over_week_result.items(): + print(f" 
{key}: {value}") + + print("\nMonth-over-month:") + for key, value in month_over_month_result.items(): + print(f" {key}: {value}") + finally: + registry.clear() + + +if __name__ == "__main__": + main() diff --git a/examples/ml_pipeline_demo.py b/examples/ml_pipeline_demo.py new file mode 100644 index 0000000..0c8920f --- /dev/null +++ b/examples/ml_pipeline_demo.py @@ -0,0 +1,351 @@ +#!/usr/bin/env python +""" +ML Pipeline Demo: Smart Model Selection +======================================== + +This is the example from the original design conversation — a realistic ML +pipeline that demonstrates how Flow.model lets you write plain functions, +wire them by passing outputs as inputs, and execute with .flow.compute(). + +Features demonstrated: + 1. @Flow.model with auto-wrap (plain return types, no GenericResult needed) + 2. Lazy[T] for conditional evaluation (skip slow model if fast is good enough) + 3. .flow.compute() for execution with automatic context propagation + 4. .flow.with_inputs() for context transforms (lookback windows) + 5. 
"""ML Pipeline Demo: Smart Model Selection (core stages).

NOTE(review): this span of the file is a newline-mangled diff payload; the
code below is the same logic re-flowed into valid Python.

Factored wiring — build_pipeline() shows how to reuse the same graph
structure with different data sources.

The pipeline:

    load_dataset ──> prepare_features ──> train_linear ──> evaluate ──> fast_metrics ──┐
                                     └──> train_forest ──> evaluate ──> slow_metrics ──┴──> smart_training

Run with: python examples/ml_pipeline_demo.py
"""

from __future__ import annotations

from dataclasses import dataclass
from datetime import date, timedelta
from math import sin

from ccflow import Flow, Lazy


# =============================================================================
# Domain types (stand-ins for real ML objects)
# =============================================================================


@dataclass
class PreparedData:
    """Container for train/test split data."""

    X_train: list  # list of feature vectors
    X_test: list
    y_train: list  # list of target values
    y_test: list


@dataclass
class TrainedModel:
    """A fitted model (placeholder)."""

    name: str
    coefficients: list
    intercept: float
    augment: bool  # Whether to add sin feature during prediction


@dataclass
class Metrics:
    """Evaluation metrics."""

    r2: float
    mse: float
    model_name: str


# =============================================================================
# Data Loading
# =============================================================================


@Flow.model
def load_dataset(start_date: date, end_date: date, source: str = "warehouse") -> list:
    """Load raw dataset for a date range.

    Returns a list of dicts (standing in for a DataFrame).
    Auto-wrapped: returns plain list, framework wraps in GenericResult.
    """
    n_days = (end_date - start_date).days + 1
    print(f" [load_dataset] Loading {n_days} days from '{source}' ({start_date} to {end_date})")
    # True relationship: target = 2.0 * x + 10.0 + 15.0 * sin(x * 0.2)
    # Linear model captures the trend (R^2 ~0.93), forest also captures the sin wave (~0.99)
    return [
        {
            "date": str(start_date + timedelta(days=i)),
            "x": float(i),
            "target": 2.0 * i + 10.0 + 15.0 * sin(i * 0.2),
        }
        for i in range(n_days)
    ]


# =============================================================================
# Feature Engineering
# =============================================================================


@Flow.model
def prepare_features(raw_data: list) -> PreparedData:
    """Split data into train/test.

    Returns a PreparedData dataclass — the framework auto-wraps it in GenericResult.
    Downstream models can request individual fields via prepared["X_train"] etc.
    """
    n = len(raw_data)
    split = int(n * 0.8)
    print(f" [prepare_features] {n} rows, split at {split}")

    X = [[r["x"]] for r in raw_data]
    y = [r["target"] for r in raw_data]

    return PreparedData(
        X_train=X[:split],
        X_test=X[split:],
        y_train=y[:split],
        y_test=y[split:],
    )


# =============================================================================
# Model Training
# =============================================================================


def _ols_fit(X, y):
    """Simple OLS: compute coefficients and intercept.

    Args:
        X: list of feature rows (each a list of floats), non-empty.
        y: list of targets, same length as ``X``.

    Returns:
        (coefficients, intercept) tuple.

    Raises:
        ValueError: if ``X`` or ``y`` is empty (previously this surfaced as a
            bare ``ZeroDivisionError``; ``evaluate_model`` already guards the
            analogous empty-test case).
    """
    if not X or not y:
        raise ValueError("_ols_fit requires non-empty X and y")
    n = len(X)
    n_feat = len(X[0])
    y_mean = sum(y) / n
    x_means = [sum(row[j] for row in X) / n for j in range(n_feat)]

    coefficients = []
    for j in range(n_feat):
        cov = sum((X[i][j] - x_means[j]) * (y[i] - y_mean) for i in range(n)) / n
        var = sum((X[i][j] - x_means[j]) ** 2 for i in range(n)) / n
        # Guard near-zero variance to avoid dividing by ~0 for constant features
        coefficients.append(cov / var if var > 1e-10 else 0.0)

    intercept = y_mean - sum(c * m for c, m in zip(coefficients, x_means))
    return coefficients, intercept


def _augment(X):
    """Add sin(x*0.2) feature to capture non-linearity."""
    return [row + [sin(row[0] * 0.2)] for row in X]


@Flow.model
def train_linear(prepared: PreparedData) -> TrainedModel:
    """Train a fast linear model (linear features only)."""
    print(f" [train_linear] Fitting on {len(prepared.X_train)} samples")
    coefficients, intercept = _ols_fit(prepared.X_train, prepared.y_train)
    return TrainedModel(name="LinearRegression", coefficients=coefficients, intercept=intercept, augment=False)


@Flow.model
def train_forest(prepared: PreparedData, n_estimators: int = 100) -> TrainedModel:
    """Train a model that also captures non-linear patterns (simulated)."""
    print(f" [train_forest] Fitting {n_estimators} trees on {len(prepared.X_train)} samples")
    # Augment with sin feature to capture non-linearity
    X_aug = _augment(prepared.X_train)
    coefficients, intercept = _ols_fit(X_aug, prepared.y_train)
    return TrainedModel(
        name=f"RandomForest(n={n_estimators})",
        coefficients=coefficients,
        intercept=intercept,
        augment=True,
    )


# =============================================================================
# Model Evaluation
# =============================================================================


@Flow.model
def evaluate_model(model: TrainedModel, prepared: PreparedData) -> Metrics:
    """Evaluate a trained model on test data."""
    X_test = prepared.X_test
    y_test = prepared.y_test
    # Apply the same feature augmentation the model was trained with
    X_eval = _augment(X_test) if model.augment else X_test

    y_pred = [
        model.intercept + sum(c * x for c, x in zip(model.coefficients, row))
        for row in X_eval
    ]

    y_mean = sum(y_test) / len(y_test) if y_test else 0
    ss_tot = sum((y - y_mean) ** 2 for y in y_test) or 1  # avoid div-by-zero on constant targets
    ss_res = sum((yt - yp) ** 2 for yt, yp in zip(y_test, y_pred))
    r2 = 1.0 - ss_res / ss_tot
    mse = ss_res / len(y_test) if y_test else 0

    print(f" [evaluate_model] {model.name}: R^2={r2:.4f}, MSE={mse:.2f}")
    return Metrics(r2=r2, mse=mse, model_name=model.name)
# =============================================================================
# Smart Pipeline with Conditional Execution
# =============================================================================


@Flow.model
def smart_training(
    # data: PreparedData,
    fast_metrics: Metrics,
    slow_metrics: Lazy[Metrics],  # Only evaluated if fast isn't good enough
    threshold: float = 0.9,
) -> Metrics:
    """Use fast model if good enough, else fall back to slow.

    The slow_metrics parameter is Lazy — it receives a zero-arg thunk.
    If the fast model exceeds the threshold, the slow model is never
    trained or evaluated at all.
    """
    print(f" [smart_training] Fast R^2={fast_metrics.r2:.4f}, threshold={threshold}")
    if fast_metrics.r2 >= threshold:
        print(" [smart_training] Fast model is good enough! Skipping slow model.")
        return fast_metrics
    else:
        print(" [smart_training] Fast model below threshold, evaluating slow model...")
        return slow_metrics()


# =============================================================================
# Pipeline Wiring Helper
# =============================================================================


def build_pipeline(raw, *, n_estimators=200, threshold=0.95):
    """Wire a complete train/evaluate/select pipeline from a data source.

    This function shows the flexibility of the approach: the same wiring
    logic can be applied to different data sources (raw, lookback_raw, etc.)
    without duplicating code. Everything here is just wiring — no computation
    happens until .flow.compute() is called.

    Args:
        raw: A CallableModel or BoundModel that produces raw data (list of dicts)
        n_estimators: Number of trees for the forest model
        threshold: R^2 threshold for the fast/slow model selection

    Returns:
        A smart_training model instance ready for .flow.compute()
    """
    # Feature engineering — returns a PreparedData with X_train, X_test, etc.
    prepared = prepare_features(raw_data=raw)

    # Train both models — each receives the whole PreparedData and extracts
    # the fields it needs internally.
    linear = train_linear(prepared=prepared)
    forest = train_forest(prepared=prepared, n_estimators=n_estimators)

    # Evaluate both
    linear_metrics = evaluate_model(model=linear, prepared=prepared)
    forest_metrics = evaluate_model(model=forest, prepared=prepared)

    # Smart selection with Lazy — forest is only evaluated if linear isn't good enough
    return smart_training(
        fast_metrics=linear_metrics,
        slow_metrics=forest_metrics,
        threshold=threshold,
    )


# =============================================================================
# Main: Wire and execute the pipeline
# =============================================================================


def main():
    print("=" * 70)
    print("ML Pipeline Demo: Smart Model Selection with Flow.model")
    print("=" * 70)

    # ------------------------------------------------------------------
    # Step 1: Wire the pipeline (no computation happens here)
    # ------------------------------------------------------------------
    print("\n--- Wiring the pipeline (lazy, no computation yet) ---\n")

    raw = load_dataset(source="prod_warehouse")

    # build_pipeline factors out the repeated wiring logic.
    # Linear R^2 ≈ 0.93. Threshold is 0.95 → falls through to forest.
    pipeline = build_pipeline(raw, n_estimators=200, threshold=0.95)

    print("Pipeline wired. No functions have been called yet.")

    # ------------------------------------------------------------------
    # Step 2: Execute — linear not good enough, falls back to forest
    # ------------------------------------------------------------------
    print("\n--- Executing pipeline (Jan-Jun 2024) ---\n")
    result = pipeline.flow.compute(
        start_date=date(2024, 1, 1),
        end_date=date(2024, 6, 30),
    )

    print(f"\n Best model: {result.value.model_name}")
    print(f" R^2: {result.value.r2:.4f}")
    print(f" MSE: {result.value.mse:.2f}")

    # ------------------------------------------------------------------
    # Step 3: Context transforms (lookback) — reuse build_pipeline
    # ------------------------------------------------------------------
    print("\n" + "=" * 70)
    print("With Lookback: Same pipeline structure, extra history for loading")
    print("=" * 70)

    # flow.with_inputs() creates a BoundModel that transforms the context
    # before calling the underlying model. start_date is shifted 30 days earlier.
    lookback_raw = raw.flow.with_inputs(
        start_date=lambda ctx: ctx.start_date - timedelta(days=30)
    )

    # Same wiring logic, different data source — no duplication.
    lookback_pipeline = build_pipeline(lookback_raw, n_estimators=200, threshold=0.95)

    print("\n--- Executing lookback pipeline ---\n")
    result2 = lookback_pipeline.flow.compute(
        start_date=date(2024, 1, 1),
        end_date=date(2024, 6, 30),
    )
    # Notice: load_dataset gets start_date=2023-12-02 (30 days earlier)

    print(f"\n Best model: {result2.value.model_name}")
    print(f" R^2: {result2.value.r2:.4f}")
    print(f" MSE: {result2.value.mse:.2f}")

    # ------------------------------------------------------------------
    # Step 4: Lower threshold — linear is good enough, skip forest
    # ------------------------------------------------------------------
    print("\n" + "=" * 70)
    print("Lazy Evaluation: Lower threshold so fast model is good enough")
    print("=" * 70)

    # With threshold=0.80, the linear model's R^2 (~0.93) passes.
    # The forest is NEVER trained or evaluated — Lazy skips it entirely.
    fast_pipeline = build_pipeline(raw, n_estimators=200, threshold=0.80)

    print("\n--- Executing (slow model should NOT be trained) ---\n")
    result3 = fast_pipeline.flow.compute(
        start_date=date(2024, 1, 1),
        end_date=date(2024, 6, 30),
    )
    print(f"\n Selected: {result3.value.model_name} (R^2={result3.value.r2:.4f})")
    print(" (Notice: train_forest and its evaluate_model were never called)")


if __name__ == "__main__":
    main()


# --- mangled patch residue (not code): mail header of commit 324fc4ec
# ("More test coverage", Signed-off-by: Nijat Khanbabayev) and the opening
# hunk header of the diff for ccflow/tests/test_callable.py ---
# NOTE(review): reconstructed from a newline-mangled diff payload. The first
# statements of this span (`return GenericResult(value=value)` /
# `self.assertIn("must have a return type annotation", ...)`) are the tail of
# a test whose header lies outside this chunk; they are omitted rather than
# guessed at. The first method below is appended, in the original diff, to an
# existing TestCase class whose header is also outside this chunk — a
# placeholder class is used here so the reconstruction parses.


class TestAutoContextAdditions(TestCase):  # placeholder for the unseen enclosing class

    def test_auto_context_rejects_missing_annotation(self):
        """auto_context should reject params without type annotations."""
        with self.assertRaises(TypeError) as cm:

            class AutoContextCallable(CallableModel):
                @Flow.call(auto_context=True)
                def __call__(self, *, value) -> GenericResult:
                    return GenericResult(value=value)

        self.assertIn("must have a type annotation", str(cm.exception))


class TestDeclaredTypeMatches(TestCase):
    """Tests for _declared_type_matches helper in callable.py."""

    def test_typevar_always_matches(self):
        from ccflow.callable import _declared_type_matches

        T = TypeVar("T")
        self.assertTrue(_declared_type_matches(int, T))

    def test_union_expected_no_type_args(self):
        """Union with no concrete type args should return False."""
        from ccflow.callable import _declared_type_matches

        # Union[None] after filtering out NoneType has no concrete args
        self.assertFalse(_declared_type_matches(int, Union[None]))

    def test_union_expected_with_actual_type(self):
        """Concrete type matching Union expected."""
        from ccflow.callable import _declared_type_matches

        self.assertTrue(_declared_type_matches(int, Union[int, str]))
        self.assertFalse(_declared_type_matches(float, Union[int, str]))

    def test_union_both_sides(self):
        """Both actual and expected are Unions."""
        from ccflow.callable import _declared_type_matches

        self.assertTrue(_declared_type_matches(Union[int, str], Union[int, str]))
        self.assertTrue(_declared_type_matches(Union[str, int], Union[int, str]))  # order independent
        self.assertFalse(_declared_type_matches(Union[int, float], Union[int, str]))

    def test_non_type_actual(self):
        """Non-type actual should return False."""
        from ccflow.callable import _declared_type_matches

        self.assertFalse(_declared_type_matches("not_a_type", int))

    def test_non_type_expected(self):
        """Non-type expected should return False."""
        from ccflow.callable import _declared_type_matches

        self.assertFalse(_declared_type_matches(int, "not_a_type"))


class TestCallableModelGenericValidation(TestCase):
    """Tests for CallableModelGeneric type validation paths."""

    def test_context_type_mismatch_raises(self):
        """Generic type validation should reject context type mismatch."""

        class ContextA(ContextBase):
            a: int

        class ContextB(ContextBase):
            b: int

        class ModelA(CallableModel):
            @Flow.call
            def __call__(self, context: ContextA) -> GenericResult[int]:
                return GenericResult(value=context.a)

        with self.assertRaises(ValidationError):
            # Expect ContextB but model has ContextA
            CallableModelGenericType[ContextB, GenericResult[int]].model_validate(ModelA())

    def test_result_type_mismatch_raises(self):
        """Generic type validation should reject result type mismatch."""

        class MyContext(ContextBase):
            x: int

        class ResultA(ResultBase):
            a: int

        class ResultB(ResultBase):
            b: int

        class ModelA(CallableModel):
            @Flow.call
            def __call__(self, context: MyContext) -> ResultA:
                return ResultA(a=context.x)

        with self.assertRaises(ValidationError):
            CallableModelGenericType[MyContext, ResultB].model_validate(ModelA())


# --- mangled patch residue (not code): diff header for
# ccflow/tests/test_flow_context.py; the methods below are appended, in the
# original diff, to a test class whose header is outside this chunk ---


class TestFlowContextAdditions(TestCase):  # placeholder for the unseen enclosing class

    def test_flow_context_eq_non_flow_context(self):
        """FlowContext.__eq__ returns False for non-FlowContext objects."""
        ctx = FlowContext(x=1)
        assert ctx != 42
        assert ctx != "hello"
        assert ctx != None  # noqa: E711
        assert ctx != NumberContext(x=1)

    def test_flow_context_hash_with_set_value(self):
        """FlowContext with set values should hash correctly via frozenset."""
        ctx = FlowContext(tags=frozenset({"a", "b"}))
        # Should not raise
        h = hash(ctx)
        assert isinstance(h, int)
FlowContext(tags=frozenset({"a", "b"})) + # Should not raise + h = hash(ctx) + assert isinstance(h, int) + + def test_flow_context_hash_with_model_dump_object(self): + """_freeze_for_hash should handle objects with model_dump attribute.""" + from ccflow.context import _freeze_for_hash + + # Directly test _freeze_for_hash with an object that has model_dump + # (FlowContext.__hash__ goes through model_dump first which serializes + # nested models, so we test the helper directly) + inner = NumberContext(x=42) + result = _freeze_for_hash(inner) + assert isinstance(result, tuple) + assert result[0] is NumberContext + + def test_flow_context_hash_unhashable_with_dict_fallback(self): + """Objects with __dict__ but no __hash__ should use __dict__ fallback.""" + + class UnhashableWithDict: + __hash__ = None # type: ignore[assignment] + + def __init__(self, val): + self.val = val + + ctx = FlowContext(obj=UnhashableWithDict(42)) + h = hash(ctx) + assert isinstance(h, int) + def test_flow_context_pickle(self): """FlowContext pickles cleanly.""" ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index ad30824..57a54df 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -2181,6 +2181,805 @@ def consumer(context: SimpleContext, tag: Union[Annotated[str, "label"], int] = registry.clear() +# ============================================================================= +# Coverage Gap Tests +# ============================================================================= + + +class TestExtractLazyLoopBody(TestCase): + """Group 1: _extract_lazy loop body with non-LazyMarker metadata.""" + + def test_annotated_with_extra_metadata_before_lazy_marker(self): + """Annotated type where _LazyMarker is NOT the first metadata element.""" + from typing import Annotated + + from ccflow.flow_model import _extract_lazy, _LazyMarker + + # _LazyMarker is the 
second metadata element — loop must iterate past "other" + ann = Annotated[int, "other_metadata", _LazyMarker()] + base_type, is_lazy = _extract_lazy(ann) + self.assertTrue(is_lazy) + self.assertIs(base_type, int) + + def test_annotated_without_lazy_marker(self): + """Annotated type with no _LazyMarker returns is_lazy=False.""" + from typing import Annotated + + from ccflow.flow_model import _extract_lazy + + ann = Annotated[int, "just_metadata"] + base_type, is_lazy = _extract_lazy(ann) + self.assertFalse(is_lazy) + + def test_lazy_type_annotation_with_extra_annotated(self): + """End-to-end: Lazy wrapping of an Annotated type.""" + + @Flow.model + def model_with_lazy( + x: int, + dep: Lazy[int], + ) -> int: + return x + dep() + + @Flow.model + def upstream(x: int) -> int: + return x * 10 + + model = model_with_lazy(x=1, dep=upstream()) + result = model.flow.compute(x=1) + self.assertEqual(result.value, 11) + + def test_lazy_dep_returning_custom_result(self): + """Lazy dep returning custom ResultBase (not GenericResult) should return raw result.""" + + @Flow.model + def upstream(context: SimpleContext) -> MyResult: + return MyResult(data=f"v={context.value}") + + @Flow.model + def consumer(context: SimpleContext, dep: Lazy[MyResult]) -> GenericResult[str]: + result = dep() + return GenericResult(value=result.data) + + model = consumer(dep=upstream()) + result = model(SimpleContext(value=42)) + self.assertEqual(result.value, "v=42") + + +class TestTransformReprNamedCallable(TestCase): + """Group 2: _transform_repr with a named callable.""" + + def test_named_function_transform_in_repr(self): + """Named functions should appear in BoundModel repr wrapped in angle brackets.""" + from ccflow.flow_model import _transform_repr + + def my_custom_transform(ctx): + return ctx.value + 1 + + result = _transform_repr(my_custom_transform) + self.assertIn("my_custom_transform", result) + self.assertTrue(result.startswith("<")) + self.assertTrue(result.endswith(">")) + + def 
test_static_value_repr(self): + """Static (non-callable) values should use repr().""" + from ccflow.flow_model import _transform_repr + + self.assertEqual(_transform_repr(42), "42") + self.assertEqual(_transform_repr("hello"), "'hello'") + + +class TestBoundFieldNamesFallback(TestCase): + """Group 3: _bound_field_names fallback for objects without model_fields_set.""" + + def test_fallback_to_bound_fields_attr(self): + from ccflow.flow_model import _bound_field_names + + class FakeModel: + _bound_fields = {"x", "y"} + + result = _bound_field_names(FakeModel()) + self.assertEqual(result, {"x", "y"}) + + def test_fallback_no_attrs(self): + from ccflow.flow_model import _bound_field_names + + class Empty: + pass + + result = _bound_field_names(Empty()) + self.assertEqual(result, set()) + + +class TestRuntimeInputNamesEmpty(TestCase): + """Group 4: _runtime_input_names when all_param_names is empty.""" + + def test_non_flow_model_returns_empty(self): + from ccflow.flow_model import _runtime_input_names + + class ManualModel(CallableModel): + offset: int + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value + self.offset) + + model = ManualModel(offset=5) + self.assertEqual(_runtime_input_names(model), set()) + + +class TestRegistryCandidateAllowed(TestCase): + """Group 5: _registry_candidate_allowed TypeAdapter success path.""" + + def test_non_callable_model_passes_type_check(self): + """Registry value that is not a CallableModel but passes TypeAdapter validation.""" + from ccflow.flow_model import _registry_candidate_allowed + + # int value passes TypeAdapter(int).validate_python + self.assertTrue(_registry_candidate_allowed(int, 42)) + + def test_non_callable_model_fails_type_check(self): + from ccflow.flow_model import _registry_candidate_allowed + + self.assertFalse(_registry_candidate_allowed(int, "not_an_int")) + + +class TestConcreteContextTypeOptional(TestCase): + """Group 6: 
_concrete_context_type with Optional/Union types.""" + + def test_optional_context_type(self): + """Optional[T] has NoneType that should be skipped to find T.""" + from typing import Optional + + from ccflow.flow_model import _concrete_context_type + + # Optional[SimpleContext] = Union[SimpleContext, None] + # The NoneType arg must be skipped (line 196-197) + result = _concrete_context_type(Optional[SimpleContext]) + self.assertIs(result, SimpleContext) + + def test_union_with_none_first(self): + """Union[None, T] should skip NoneType and find T.""" + from typing import Union + + from ccflow.flow_model import _concrete_context_type + + # NoneType comes first, must be skipped + result = _concrete_context_type(Union[None, SimpleContext]) + self.assertIs(result, SimpleContext) + + def test_union_context_type(self): + from typing import Union + + from ccflow.flow_model import _concrete_context_type + + result = _concrete_context_type(Union[SimpleContext, None]) + self.assertIs(result, SimpleContext) + + def test_union_no_context_base(self): + from typing import Union + + from ccflow.flow_model import _concrete_context_type + + result = _concrete_context_type(Union[int, str]) + self.assertIsNone(result) + + def test_returns_none_for_non_type(self): + from ccflow.flow_model import _concrete_context_type + + result = _concrete_context_type("not_a_type") + self.assertIsNone(result) + + +class TestBuildConfigValidatorsException(TestCase): + """Group 7: _build_config_validators when TypeAdapter fails.""" + + def test_unadaptable_type_skipped(self): + """Types that TypeAdapter can't handle should be silently skipped.""" + from ccflow.flow_model import _build_config_validators + + # type(...) 
(EllipsisType) makes TypeAdapter fail + validatable, validators = _build_config_validators({"x": int, "y": type(...)}) + self.assertIn("x", validatable) + self.assertNotIn("y", validatable) + self.assertIn("x", validators) + self.assertNotIn("y", validators) + + +class TestCoerceContextValueNoValidator(TestCase): + """Group 8: _coerce_context_value early return for fields without validators.""" + + def test_field_without_validator_passes_through(self): + from ccflow.flow_model import _coerce_context_value + + # When name is not in validators, value should pass through unchanged + result = _coerce_context_value("unknown_field", 42, {}, {}) + self.assertEqual(result, 42) + + +class TestGeneratedModelClassFactoryPath(TestCase): + """Group 9: _generated_model_class when stage has no generated model.""" + + def test_returns_none_for_plain_callable(self): + from ccflow.flow_model import _generated_model_class + + def plain_func(): + pass + + self.assertIsNone(_generated_model_class(plain_func)) + + +class TestDescribePipeStagePaths(TestCase): + """Group 10: _describe_pipe_stage for different stage types.""" + + def test_generated_model_instance(self): + from ccflow.flow_model import _describe_pipe_stage + + @Flow.model + def my_stage(x: int) -> int: + return x + + desc = _describe_pipe_stage(my_stage()) + self.assertIn("my_stage", desc) + + def test_callable_stage(self): + from ccflow.flow_model import _describe_pipe_stage + + @Flow.model + def factory_stage(x: int) -> int: + return x + + desc = _describe_pipe_stage(factory_stage) + self.assertIn("factory_stage", desc) + + def test_non_callable_stage(self): + from ccflow.flow_model import _describe_pipe_stage + + desc = _describe_pipe_stage(42) + self.assertEqual(desc, "42") + + +class TestInferPipeParamAmbiguousDefaults(TestCase): + """Cover _infer_pipe_param fallback path with multiple defaulted candidates.""" + + def test_ambiguous_defaulted_candidates(self): + """When all candidates have defaults but multiple are 
# NOTE(review): reconstructed from a newline-mangled diff payload. The leading
# fragment of this span (`unoccupied."""` plus the body of
# test_ambiguous_defaulted_candidates) completes a method whose header falls
# in the previous span and is reconstructed there; it is not repeated here.


class TestPipeErrorPaths(TestCase):
    """Group 11: pipe() error paths not covered by existing tests."""

    def test_pipe_non_callable_model_source(self):
        """pipe() should reject non-CallableModel source."""
        from ccflow.flow_model import pipe_model

        @Flow.model
        def consumer(data: int) -> int:
            return data

        with self.assertRaisesRegex(TypeError, "pipe\\(\\) source must be a CallableModel"):
            pipe_model("not_a_model", consumer)

    def test_pipe_non_flow_model_target(self):
        """pipe() should reject non-@Flow.model target."""
        from ccflow.flow_model import pipe_model

        @Flow.model
        def source(x: int) -> int:
            return x

        class ManualTarget(CallableModel):
            @Flow.call
            def __call__(self, context: SimpleContext) -> GenericResult[int]:
                return GenericResult(value=0)

        with self.assertRaisesRegex(TypeError, "pipe\\(\\) only supports downstream stages"):
            pipe_model(source(), ManualTarget())

    def test_pipe_invalid_param_name(self):
        """pipe() should reject invalid target parameter names."""

        @Flow.model
        def source(x: int) -> int:
            return x

        @Flow.model
        def consumer(data: int) -> int:
            return data

        with self.assertRaisesRegex(TypeError, "is not valid for"):
            source().pipe(consumer, param="nonexistent")

    def test_pipe_already_bound_param(self):
        """pipe() should reject already-bound parameters."""

        @Flow.model
        def source(x: int) -> int:
            return x

        @Flow.model
        def consumer(data: int) -> int:
            return data

        model = consumer(data=5)
        with self.assertRaisesRegex(TypeError, "is already bound"):
            source().pipe(model, param="data")

    def test_pipe_no_available_target_parameter(self):
        """pipe() should error when all downstream params are occupied."""

        @Flow.model
        def source(x: int) -> int:
            return x

        @Flow.model
        def consumer(data: int) -> int:
            return data

        model = consumer(data=5)
        with self.assertRaisesRegex(TypeError, "could not find an available target parameter"):
            source().pipe(model)

    def test_pipe_into_generated_instance_rebuilds(self):
        """pipe() into an existing generated model instance should rebuild."""

        @Flow.model
        def source(x: int) -> int:
            return x * 10

        @Flow.model
        def consumer(data: int, extra: int = 1) -> int:
            return data + extra

        instance = consumer(extra=5)
        pipeline = source().pipe(instance)
        result = pipeline.flow.compute(x=3)
        self.assertEqual(result.value, 35)  # 3*10 + 5

    def test_pipe_bound_model_wrapping_non_generated_rejects(self):
        """pipe() into BoundModel wrapping a non-generated model should fail."""
        from ccflow.flow_model import BoundModel, pipe_model

        @Flow.model
        def source(x: int) -> int:
            return x

        class ManualModel(CallableModel):
            @Flow.call
            def __call__(self, context: SimpleContext) -> GenericResult[int]:
                return GenericResult(value=context.value)

        bound = BoundModel(model=ManualModel(), input_transforms={"value": 42})
        with self.assertRaisesRegex(TypeError, "pipe\\(\\) only supports downstream"):
            pipe_model(source(), bound)


class TestFlowAPIBuildContextFallback(TestCase):
    """Group 12: FlowAPI._build_context when _context_schema is None/unset."""

    def test_unbound_inputs_on_manual_callable_model(self):
        """Manual CallableModel with context should show required fields."""

        class ManualModel(CallableModel):
            offset: int

            @Flow.call
            def __call__(self, context: SimpleContext) -> GenericResult[int]:
                return GenericResult(value=context.value + self.offset)

        model = ManualModel(offset=5)
        unbound = model.flow.unbound_inputs
        self.assertIn("value", unbound)


class TestBoundModelRestoreNonDict(TestCase):
    """Group 13: BoundModel._restore_serialized_transforms non-dict path."""

    def test_restore_from_model_instance(self):
        """model_validate from an existing BoundModel instance (non-dict)."""
        from ccflow.flow_model import BoundModel

        @Flow.model
        def source(context: SimpleContext) -> GenericResult[int]:
            return GenericResult(value=context.value * 10)

        bound = source().flow.with_inputs(value=42)
        # Pass existing instance through model_validate (non-dict path)
        restored = BoundModel.model_validate(bound)
        ctx = SimpleContext(value=1)
        self.assertEqual(restored(ctx).value, 420)


class TestBoundModelInitEmptyTransforms(TestCase):
    """Group 14: BoundModel.__init__ with no transforms."""

    def test_init_without_transforms(self):
        from ccflow.flow_model import BoundModel

        @Flow.model
        def source(context: SimpleContext) -> GenericResult[int]:
            return GenericResult(value=context.value)

        bound = BoundModel(model=source())
        self.assertEqual(bound._input_transforms, {})
        result = bound(SimpleContext(value=5))
        self.assertEqual(result.value, 5)


class TestBoundModelDeps(TestCase):
    """Group 15: BoundModel.__deps__."""

    def test_deps_returns_wrapped_model(self):
        @Flow.model
        def source(context: SimpleContext) -> GenericResult[int]:
            return GenericResult(value=context.value * 10)

        bound = source().flow.with_inputs(value=42)
        deps = bound.__deps__(SimpleContext(value=1))
        self.assertEqual(len(deps), 1)
        self.assertIs(deps[0][0], bound.model)


class TestValidateFieldTypesAfterValidator(TestCase):
    """Group 16: _validate_field_types in the model_validate path."""

    def test_model_validate_rejects_wrong_type(self):
        """model_validate should reject wrong scalar types."""

        @Flow.model
        def source(x: int) -> int:
            return x * 10

        cls = type(source(x=5))
        with self.assertRaisesRegex(TypeError, "Field 'x'"):
            cls.model_validate({"x": "not_an_int"})


class TestGetContextValidatorPaths(TestCase):
    """Group 17: _get_context_validator fallback paths."""

    def test_mode2_context_validator_from_schema(self):
        """Mode 2 model should build validator from _context_schema."""

        @Flow.model(context_args=["start_date"])
        def loader(start_date: str, source: str = "db") -> str:
            return f"{source}:{start_date}"

        model = loader()
        # Trigger validator creation by calling flow.compute
        result = model.flow.compute(start_date="2024-01-01")
        self.assertEqual(result.value, "db:2024-01-01")

    def test_mode1_context_validator_uses_context_type_directly(self):
        """Mode 1 should use TypeAdapter(context_type) directly."""

        @Flow.model
        def model_fn(context: SimpleContext, offset: int = 0) -> GenericResult[int]:
            return GenericResult(value=context.value + offset)

        model = model_fn()
        # compute with SimpleContext fields
        result = model.flow.compute(value=5)
        self.assertEqual(result.value, 5)


class TestValidateContextTypeOverrideErrors(TestCase):
    """Group 18: _validate_context_type_override error paths."""

    def test_non_context_base_raises(self):
        with self.assertRaisesRegex(TypeError, "context_type must be a ContextBase subclass"):

            @Flow.model(context_args=["x"], context_type=int)
            def bad_model(x: int) -> int:
                return x

    def test_context_type_missing_context_args_fields(self):
        """context_type missing required context_args fields."""

        class TinyContext(ContextBase):
            a: int

        with self.assertRaisesRegex(TypeError, "must define fields for context_args"):

            @Flow.model(context_args=["a", "b"], context_type=TinyContext)
            def bad_model(a: int, b: int) -> int:
                return a + b

    def test_context_type_extra_required_fields(self):
        """context_type has required fields not listed in context_args."""

        class BigContext(ContextBase):
            a: int
            b: int
            extra: str

        with self.assertRaisesRegex(TypeError, "has required fields not listed in context_args"):

            @Flow.model(context_args=["a"], context_type=BigContext)
            def bad_model(a: int) -> int:
                ...  # NOTE(review): body truncated at the end of this chunk;
                     # only the decorator needs to execute for the assertion.
+ return a + + def test_annotation_type_mismatch(self): + """Function and context_type disagree on annotation type.""" + + class TypedContext(ContextBase): + x: str + + with self.assertRaisesRegex(TypeError, "context_arg 'x'"): + + @Flow.model(context_args=["x"], context_type=TypedContext) + def bad_model(x: int) -> int: + return x + + def test_annotation_skip_when_func_ann_is_none(self): + """Annotation check should skip when function annotation is absent from schema.""" + from ccflow.flow_model import _validate_context_type_override + + class CompatContext(ContextBase): + a: int + + # context_args has 'a', schema has 'a': int. Compatible, no error. + result = _validate_context_type_override(CompatContext, ["a"], {"a": int}) + self.assertIs(result, CompatContext) + + def test_subclass_annotations_allowed(self): + """context_type with subclass-compatible annotations should pass.""" + from ccflow.flow_model import _validate_context_type_override + + class ContextWithBase(ContextBase): + ctx: ContextBase + + # Function declares SimpleContext which is a subclass of ContextBase — should pass + result = _validate_context_type_override(ContextWithBase, ["ctx"], {"ctx": SimpleContext}) + self.assertIs(result, ContextWithBase) + + def test_default_vs_required_field_conflict(self): + """Function has default for context_arg but context_type requires it.""" + + class StrictContext(ContextBase): + x: int + + with self.assertRaisesRegex(TypeError, "function has a default but context_type"): + + @Flow.model(context_args=["x"], context_type=StrictContext) + def bad_model(x: int = 5) -> int: + return x + + +class TestDecoratorErrorPaths(TestCase): + """Group 19: Decorator error paths.""" + + def test_context_type_with_explicit_context_param(self): + """context_type= with explicit context param should raise.""" + with self.assertRaisesRegex(TypeError, "context_type.*only supported"): + + @Flow.model(context_type=SimpleContext) + def bad_model(context: SimpleContext) -> 
GenericResult[int]: + return GenericResult(value=0) + + def test_context_type_without_context_args(self): + """context_type= without context_args should raise in dynamic mode.""" + with self.assertRaisesRegex(TypeError, "context_type.*only supported"): + + @Flow.model(context_type=SimpleContext) + def bad_model(x: int) -> int: + return x + + def test_missing_context_annotation(self): + """Missing type annotation on context param should raise.""" + with self.assertRaisesRegex(TypeError, "must have a type annotation"): + + @Flow.model + def bad_model(context) -> int: + return 0 + + def test_missing_param_annotation(self): + """Missing type annotation on a model field param should raise.""" + with self.assertRaisesRegex(TypeError, "must have a type annotation"): + + @Flow.model + def bad_model(context: SimpleContext, untyped_param) -> int: + return 0 + + def test_context_param_not_context_base(self): + """context param annotated with non-ContextBase type should raise.""" + with self.assertRaisesRegex(TypeError, "must be annotated with a ContextBase subclass"): + + @Flow.model + def bad_model(context: int) -> int: + return 0 + + def test_pep563_fallback_on_failed_get_type_hints(self): + """When get_type_hints fails, falls back to raw annotations.""" + + # This is hard to trigger directly, but we can test that string annotations work + @Flow.model + def model_with_string_return(x: int) -> "int": + return x * 2 + + result = model_with_string_return().flow.compute(x=5) + self.assertEqual(result.value, 10) + + +class TestMode1CallPath(TestCase): + """Group 20: Mode 1 explicit context pass-through in __call__.""" + + def test_mode1_resolve_callable_model_returns_non_generic_result(self): + """Mode 1 should handle deps that return raw ResultBase (not GenericResult).""" + + @Flow.model + def upstream(context: SimpleContext) -> MyResult: + return MyResult(data=f"value={context.value}") + + @Flow.model + def downstream(context: SimpleContext, dep: CallableModel) -> 
GenericResult[str]: + # dep is resolved to MyResult since it's not GenericResult + return GenericResult(value=f"got:{dep}") + + model = downstream(dep=upstream()) + result = model(SimpleContext(value=42)) + self.assertIn("value=42", result.value) + + +class TestDynamicModeContextLookup(TestCase): + """Group 21: Dynamic mode context lookup for deferred values.""" + + def test_deferred_value_from_context(self): + """Dynamic mode should pull deferred values from context.""" + + @Flow.model + def add(x: int, y: int) -> int: + return x + y + + model = add(x=10) + # y is deferred — pulled from context + result = model.flow.compute(y=5) + self.assertEqual(result.value, 15) + + def test_missing_deferred_value_raises(self): + """Dynamic mode should raise for missing deferred values.""" + + @Flow.model + def add(x: int, y: int) -> int: + return x + y + + model = add(x=10) + with self.assertRaisesRegex(TypeError, "Missing runtime input"): + model.flow.compute() # y not provided + + def test_context_sourced_value_coercion(self): + """Dynamic mode should coerce context-sourced values through validators.""" + + @Flow.model + def typed_model(x: int, y: int) -> int: + return x + y + + model = typed_model(x=10) + # y provided as a value that can be coerced to int + result = model.flow.compute(y=5) + self.assertEqual(result.value, 15) + + def test_deferred_value_from_context_object(self): + """Dynamic mode should look up deferred values from context attributes.""" + + @Flow.model + def multiply(x: int, y: int) -> int: + return x * y + + model = multiply(x=3) + # Call directly with a FlowContext — y must come from context + result = model(FlowContext(y=7)) + self.assertEqual(result.value, 21) + + +class TestGetContextValidatorFallbacks(TestCase): + """Group 17 additional: _get_context_validator edge cases.""" + + def test_mode2_with_context_type_override(self): + """Mode 2 with explicit context_type should use that type's validator.""" + + @Flow.model(context_args=["value"], 
context_type=SimpleContext) + def typed_model(value: int) -> int: + return value * 2 + + model = typed_model() + result = model(SimpleContext(value=5)) + self.assertEqual(result.value, 10) + + def test_dynamic_mode_instance_validator(self): + """Dynamic mode should create instance-specific validator.""" + + @Flow.model + def add(x: int, y: int, z: int = 0) -> int: + return x + y + z + + m1 = add(x=1) + m2 = add(x=1, y=2) + # Different bound fields => different runtime inputs + self.assertIn("y", m1.flow.unbound_inputs) + self.assertNotIn("y", m2.flow.unbound_inputs) + + +class TestRegistryResolutionInValidateFieldTypes(TestCase): + """Group 16: _resolve_registry_refs and _validate_field_types paths.""" + + def test_registry_string_not_resolving_passes_through(self): + """String value that doesn't resolve from registry should fail type validation.""" + + @Flow.model + def model_fn(x: int) -> int: + return x + + cls = type(model_fn(x=1)) + with self.assertRaisesRegex(TypeError, "Field 'x'"): + cls.model_validate({"x": "nonexistent_registry_key"}) + + def test_registry_ref_resolves_to_callable_model(self): + """String value resolving to a CallableModel should be substituted.""" + registry = ModelRegistry.root() + registry.clear() + try: + + @Flow.model + def upstream(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value * 10) + + @Flow.model + def downstream(context: SimpleContext, dep: CallableModel) -> GenericResult[int]: + return GenericResult(value=0) + + registry.add("my_upstream", upstream()) + cls = type(downstream(dep=upstream())) + restored = cls.model_validate({"dep": "my_upstream"}) + self.assertIsNotNone(restored) + finally: + registry.clear() + + +class TestMode2MissingContextField(TestCase): + """Line 1155: Mode 2 missing context field error.""" + + def test_mode2_missing_required_context_field(self): + """Mode 2 model called with context missing a required field should raise.""" + + 
@Flow.model(context_args=["start_date", "end_date"]) + def loader(start_date: str, end_date: str, source: str = "db") -> str: + return f"{source}:{start_date}-{end_date}" + + model = loader() + # Call with a FlowContext missing end_date + with self.assertRaisesRegex(TypeError, "Missing context field"): + model(FlowContext(start_date="2024-01-01")) + + +class TestDynamicModeContextObjectLookup(TestCase): + """Line 1155/1176: Dynamic mode pulling deferred values from context object.""" + + def test_deferred_value_coercion_through_context(self): + """Dynamic mode should coerce values from FlowContext through validators.""" + + @Flow.model + def typed_add(x: int, y: int) -> int: + return x + y + + model = typed_add(x=10) + # Calling with a FlowContext — y pulled from context and coerced + result = model(FlowContext(y=5)) + self.assertEqual(result.value, 15) + + if __name__ == "__main__": import unittest From 681a6bdcdcf5864efceca1a899cf696918a75b29 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Tue, 7 Apr 2026 19:09:53 -0400 Subject: [PATCH 18/26] Update to simplify code and logic Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 35 +- ccflow/exttypes/frequency.py | 46 +- ccflow/flow_model.py | 1755 +++++------- ccflow/tests/config/conf_flow.yaml | 21 +- ccflow/tests/test_flow_context.py | 642 +---- ccflow/tests/test_flow_model.py | 3134 ++------------------- ccflow/tests/test_flow_model_hydra.py | 480 +--- ccflow/utils/chunker.py | 9 +- ccflow/validators.py | 5 +- docs/design/flow_model_design.md | 385 ++- docs/wiki/Key-Features.md | 452 +-- examples/evaluator_demo.py | 186 -- examples/flow_model_example.py | 12 +- examples/flow_model_hydra_builder_demo.py | 50 +- examples/ml_pipeline_demo.py | 351 --- 15 files changed, 1726 insertions(+), 5837 deletions(-) delete mode 100644 examples/evaluator_demo.py delete mode 100644 examples/ml_pipeline_demo.py diff --git a/ccflow/callable.py b/ccflow/callable.py index 54f4a9d..b77c205 100644 --- 
a/ccflow/callable.py +++ b/ccflow/callable.py @@ -535,8 +535,8 @@ def model(*args, **kwargs): features (caching, evaluation, registry, serialization) work unchanged. Args: - context_args: List of parameter names that come from context (for unpacked mode) - context_type: Explicit ContextBase subclass to use with context_args mode + context_type: Optional ContextBase subclass used only to validate/coerce + `FromContext[...]` inputs against an existing nominal context shape cacheable: Enable caching of results (default: False) volatile: Mark as volatile (default: False) log_level: Logging verbosity (default: logging.DEBUG) @@ -544,28 +544,37 @@ def model(*args, **kwargs): verbose: Verbose logging output (default: True) evaluator: Custom evaluator (default: None) - Two Context Modes: + Primary authoring model: + Mark runtime/contextual inputs explicitly with `FromContext[...]`. + Ordinary unmarked parameters are regular bound inputs and are never + read implicitly from the runtime context. - Mode 1 - Explicit context parameter: - Function has a 'context' parameter annotated with a ContextBase subclass. + @Flow.model + def load_prices( + source: str, + start_date: FromContext[date], + end_date: FromContext[date], + ) -> GenericResult[pl.DataFrame]: + return GenericResult(value=query_db(source, start_date, end_date)) + + Advanced interop path: + Functions may still declare an explicit context parameter annotated + with a ContextBase subclass. @Flow.model def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]: return GenericResult(value=query_db(source, context.start_date, context.end_date)) - Mode 2 - Unpacked context_args: - Context fields are unpacked into function parameters. 
- - @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) - def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: - return GenericResult(value=query_db(source, start_date, end_date)) - Dependencies: - Any non-context parameter can be bound either to a literal value or + Any ordinary parameter can be bound either to a literal value or to another CallableModel. When a CallableModel is supplied, the generated model treats it as an upstream dependency and resolves it with the current context before calling the underlying function. + `FromContext[...]` parameters are different: they may be satisfied by + runtime context, construction-time contextual defaults, or function + defaults, but not by CallableModel values. + Usage: # Create model instances loader = load_prices(source="prod_db") diff --git a/ccflow/exttypes/frequency.py b/ccflow/exttypes/frequency.py index 33c16b5..afb772c 100644 --- a/ccflow/exttypes/frequency.py +++ b/ccflow/exttypes/frequency.py @@ -1,3 +1,4 @@ +import re import warnings from datetime import timedelta from functools import cached_property @@ -32,6 +33,13 @@ def _validate(cls, value) -> "Frequency": if isinstance(value, cls): return cls._validate(str(value)) + if isinstance(value, timedelta): + if value.total_seconds() % 86400 == 0: + return cls(f"{int(value.total_seconds() // 86400)}D") + + if isinstance(value, str): + value = _normalize_frequency_alias(value) + if isinstance(value, (timedelta, str)): try: with warnings.catch_warnings(): @@ -43,7 +51,7 @@ def _validate(cls, value) -> "Frequency": raise ValueError(f"ensure this value can be converted to a pandas offset: {e}") if isinstance(value, pd.offsets.DateOffset): - return cls(f"{value.n}{value.base.freqstr}") + return cls(_canonicalize_offset_string(value)) raise ValueError(f"ensure this value can be converted to a pandas offset: {value}") @@ -54,3 +62,39 @@ def validate(cls, value) -> "Frequency": _TYPE_ADAPTER = 
TypeAdapter(Frequency) + + +_LEGACY_FREQ_PATTERN = re.compile( + r"^(?P[+-]?\d+)?(?PT|M|A|Y)(?:-(?P[A-Za-z]{3}))?$", + re.IGNORECASE, +) + + +def _normalize_frequency_alias(value: str) -> str: + normalized = value.strip() + if not normalized: + return normalized + + match = _LEGACY_FREQ_PATTERN.fullmatch(normalized) + if not match: + day_match = re.fullmatch(r"(?P[+-]?\d+)?d", normalized, re.IGNORECASE) + if day_match: + return f"{day_match.group('count') or 1}D" + return normalized + + count = match.group("count") or "1" + unit = match.group("unit").upper() + suffix = (match.group("suffix") or "DEC").upper() + replacements = { + "T": f"{count}min", + "M": f"{count}ME", + "A": f"{count}YE-{suffix}", + "Y": f"{count}YE-{suffix}", + } + return replacements[unit] + + +def _canonicalize_offset_string(offset: pd.offsets.DateOffset) -> str: + if isinstance(offset, pd.offsets.Day): + return f"{offset.n}D" + return f"{offset.n}{offset.base.freqstr}" diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index da05d8e..fbd5ccb 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -1,106 +1,153 @@ -"""Flow.model decorator implementation. - -This module provides the Flow.model decorator that generates CallableModel classes -from plain Python functions, reducing boilerplate while maintaining full compatibility -with existing ccflow infrastructure. - -Key design: Uses TypedDict + TypeAdapter for context schema validation instead of -generating dynamic ContextBase subclasses. This avoids class registration overhead -and enables clean pickling for distributed computing (e.g., Ray). 
-""" +"""Flow.model decorator implementation built around ``FromContext``.""" +import hashlib import inspect import logging -import threading -from functools import wraps +import marshal +from dataclasses import dataclass +from functools import lru_cache, wraps +from types import UnionType from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin, get_type_hints -from pydantic import Field, PrivateAttr, TypeAdapter, model_serializer, model_validator -from typing_extensions import NotRequired, TypedDict +from pydantic import Field, PrivateAttr, TypeAdapter, ValidationError, model_serializer, model_validator +from pydantic.errors import PydanticUndefinedAnnotation -from .base import ContextBase, ResultBase +from .base import BaseModel, ContextBase, ResultBase from .callable import CallableModel, Flow, GraphDepList, WrapperModel from .context import FlowContext from .local_persistence import register_ccflow_import_path from .result import GenericResult -__all__ = ("FlowAPI", "BoundModel", "Lazy") +__all__ = ("FlowAPI", "BoundModel", "FromContext", "Lazy") _AnyCallable = Callable[..., Any] +log = logging.getLogger(__name__) -class _DeferredInput: - """Sentinel for dynamic @Flow.model inputs left for runtime context.""" - +class _UnsetFlowInput: def __repr__(self) -> str: - return "" + return "" -_DEFERRED_INPUT = _DeferredInput() +_UNSET_FLOW_INPUT = _UnsetFlowInput() +_UNSET = object() +_REMOVED_CONTEXT_ARGS = object() +_UNION_ORIGINS = (Union, UnionType) -def _callable_name(func: _AnyCallable) -> str: - return getattr(func, "__name__", type(func).__name__) +def _unset_flow_input_factory() -> _UnsetFlowInput: + return _UNSET_FLOW_INPUT -def _callable_module(func: _AnyCallable) -> str: - return getattr(func, "__module__", __name__) +def _is_unset_flow_input(value: Any) -> bool: + return value is _UNSET_FLOW_INPUT class _LazyMarker: - """Sentinel that marks a parameter as lazily evaluated via Lazy[T].""" 
+ pass + +class _FromContextMarker: pass -def _extract_lazy(annotation) -> Tuple[Any, bool]: - """Check if annotation is Lazy[T]. Returns (base_type, is_lazy). +class FromContext: + """Marker used in ``@Flow.model`` signatures for runtime/contextual inputs.""" - Handles nested Annotated types, so we need to check the outermost - Annotated layer for _LazyMarker. - """ - if get_origin(annotation) is Annotated: - args = get_args(annotation) - for metadata in args[1:]: - if isinstance(metadata, _LazyMarker): - return args[0], True - return annotation, False + def __class_getitem__(cls, item): + return Annotated[item, _FromContextMarker()] -def _make_lazy_thunk(model, context): - """Create a zero-arg callable that evaluates model(context) on demand. +class Lazy: + """Lazy dependency marker used only as ``Lazy[T]`` in type annotations.""" - The thunk caches its result so repeated calls don't re-evaluate. - """ - _cache = {} + def __new__(cls, *args, **kwargs): + raise TypeError("Lazy(model)(...) has been removed. Use model.flow.with_inputs(...) for contextual rewrites.") - def thunk(): - if "result" not in _cache: - result = model(context) - if isinstance(result, GenericResult): - result = result.value - _cache["result"] = result - return _cache["result"] + def __class_getitem__(cls, item): + return Annotated[item, _LazyMarker()] - return thunk +@dataclass(frozen=True) +class _ParsedAnnotation: + base: Any + is_lazy: bool + is_from_context: bool -log = logging.getLogger(__name__) +@dataclass(frozen=True) +class _FlowModelParam: + name: str + annotation: Any + kind: str + is_lazy: bool + has_function_default: bool + function_default: Any = _UNSET + context_validation_annotation: Any = _UNSET -def _context_values(context: ContextBase) -> Dict[str, Any]: - """Return a plain mapping of all context values. 
+ @property + def is_contextual(self) -> bool: + return self.kind == "contextual" + + @property + def validation_annotation(self) -> Any: + if self.context_validation_annotation is not _UNSET: + return self.context_validation_annotation + return self.annotation + + +@dataclass(frozen=True) +class _FlowModelConfig: + func: _AnyCallable + context_type: Type[ContextBase] + result_type: Type[ResultBase] + auto_wrap_result: bool + explicit_context_param: Optional[str] + parameters: Tuple[_FlowModelParam, ...] + context_input_types: Dict[str, Any] + context_required_names: Tuple[str, ...] + declared_context_type: Optional[Type[ContextBase]] = None + + @property + def regular_params(self) -> Tuple[_FlowModelParam, ...]: + return tuple(param for param in self.parameters if not param.is_contextual) + + @property + def contextual_params(self) -> Tuple[_FlowModelParam, ...]: + return tuple(param for param in self.parameters if param.is_contextual) + + @property + def regular_param_names(self) -> Tuple[str, ...]: + return tuple(param.name for param in self.regular_params) + + @property + def contextual_param_names(self) -> Tuple[str, ...]: + return tuple(param.name for param in self.contextual_params) + + @property + def uses_explicit_context(self) -> bool: + return self.explicit_context_param is not None + + def param(self, name: str) -> _FlowModelParam: + for param in self.parameters: + if param.name == name: + return param + raise KeyError(name) + + +def _callable_name(func: _AnyCallable) -> str: + return getattr(func, "__name__", type(func).__name__) - `dict(context)` uses pydantic's public iteration behavior, which includes - both declared fields and any allowed extra fields. 
- """ +def _callable_module(func: _AnyCallable) -> str: + return getattr(func, "__module__", __name__) + + +def _context_values(context: ContextBase) -> Dict[str, Any]: return dict(context) def _transform_repr(transform: Any) -> str: - """Render an input transform without noisy object addresses.""" - if callable(transform): name = _callable_name(transform) if name.startswith("<") and name.endswith(">"): @@ -117,149 +164,163 @@ def _bound_field_names(model: Any) -> set[str]: fields_set = getattr(model, "model_fields_set", None) if fields_set is not None: return set(fields_set) - return set(getattr(model, "_bound_fields", set())) - - -def _has_deferred_input(value: Any) -> bool: - return isinstance(value, _DeferredInput) - - -def _deferred_input_factory() -> _DeferredInput: - return _DEFERRED_INPUT + return set() -def _effective_bound_field_names(model: Any) -> set[str]: - fields = _bound_field_names(model) - defaults = getattr(model.__class__, "__flow_model_default_param_names__", set()) - return fields | set(defaults) - +def _concrete_context_type(context_type: Any) -> Optional[Type[ContextBase]]: + if isinstance(context_type, type) and issubclass(context_type, ContextBase): + return context_type -def _runtime_input_names(model: Any) -> set[str]: - all_param_names = set(getattr(model.__class__, "__flow_model_all_param_types__", {})) - if not all_param_names: - return set() - return all_param_names - _effective_bound_field_names(model) + if get_origin(context_type) in _UNION_ORIGINS: + for arg in get_args(context_type): + if arg is type(None): + continue + if isinstance(arg, type) and issubclass(arg, ContextBase): + return arg + return None -def _resolve_registry_candidate(value: str) -> Any: - from .base import BaseModel as _BM - try: - candidate = _BM.model_validate(value) - except Exception: - return None - return candidate if isinstance(candidate, _BM) else None +@lru_cache(maxsize=None) +def _type_adapter(annotation: Any) -> TypeAdapter: + return 
TypeAdapter(annotation) -def _registry_candidate_allowed(expected_type: Type, candidate: Any) -> bool: - if _is_model_dependency(candidate): - return True +def _can_validate_type(annotation: Any) -> bool: try: - TypeAdapter(expected_type).validate_python(candidate) - except Exception: + _type_adapter(annotation) + except (PydanticUndefinedAnnotation, TypeError, ValueError): return False return True -def _type_accepts_str(annotation) -> bool: - """Return True when ``str`` is a valid type for *annotation*. +def _expected_type_repr(annotation: Any) -> str: + try: + return annotation.__name__ + except AttributeError: + return repr(annotation) - Handles ``str``, ``Union[str, ...]``, ``Optional[str]``, and - ``Annotated[str, ...]``. - """ - if annotation is str: - return True - origin = get_origin(annotation) - if origin is Annotated: - return _type_accepts_str(get_args(annotation)[0]) - if origin is Union: - return any(_type_accepts_str(arg) for arg in get_args(annotation) if arg is not type(None)) - return False +def _coerce_value(name: str, value: Any, annotation: Any, source: str) -> Any: + if not _can_validate_type(annotation): + return value + try: + return _type_adapter(annotation).validate_python(value) + except Exception as exc: + expected = _expected_type_repr(annotation) + raise TypeError(f"{source} '{name}': expected {expected}, got {type(value).__name__} ({value!r})") from exc -def _build_typed_dict_adapter(name: str, schema: Dict[str, Type], *, total: bool = True) -> TypeAdapter: - """Build a TypeAdapter for a runtime TypedDict schema.""" - if not schema: - return TypeAdapter(dict) - return TypeAdapter(TypedDict(name, schema, total=total)) +def _unwrap_model_result(value: Any) -> Any: + if isinstance(value, GenericResult): + return value.value + return value -def _concrete_context_type(context_type: Any) -> Optional[Type[ContextBase]]: - """Extract a concrete ContextBase subclass from a context annotation.""" +def _make_lazy_thunk(model: CallableModel, 
context: ContextBase) -> Callable[[], Any]: + cache: Dict[str, Any] = {} - if isinstance(context_type, type) and issubclass(context_type, ContextBase): - return context_type + def thunk(): + if "result" not in cache: + cache["result"] = _unwrap_model_result(model(context)) + return cache["result"] - if get_origin(context_type) in (Optional, Union): - for arg in get_args(context_type): - if arg is type(None): - continue - if isinstance(arg, type) and issubclass(arg, ContextBase): - return arg + return thunk - return None +def _parse_annotation(annotation: Any) -> _ParsedAnnotation: + is_lazy = False + is_from_context = False -def _build_config_validators(all_param_types: Dict[str, Type]) -> Tuple[Dict[str, Type], Dict[str, TypeAdapter]]: - """Precompute validators for constructor fields.""" + while get_origin(annotation) is Annotated: + args = get_args(annotation) + annotation = args[0] + for metadata in args[1:]: + if isinstance(metadata, _LazyMarker): + is_lazy = True + elif isinstance(metadata, _FromContextMarker): + is_from_context = True - validatable_types: Dict[str, Type] = {} - for name, typ in all_param_types.items(): - try: - TypeAdapter(typ) - validatable_types[name] = typ - except Exception: - pass + return _ParsedAnnotation(base=annotation, is_lazy=is_lazy, is_from_context=is_from_context) - validators = {name: TypeAdapter(typ) for name, typ in validatable_types.items()} - return validatable_types, validators +def _type_accepts_str(annotation: Any) -> bool: + if annotation is str: + return True + origin = get_origin(annotation) + if origin is Annotated: + return _type_accepts_str(get_args(annotation)[0]) + if origin in _UNION_ORIGINS: + return any(_type_accepts_str(arg) for arg in get_args(annotation) if arg is not type(None)) + return False -def _coerce_context_value(name: str, value: Any, validators: Dict[str, TypeAdapter], validatable_types: Dict[str, Type]) -> Any: - """Validate/coerce a single context-sourced value. 
Returns coerced value or raises TypeError.""" - if name not in validators: - return value - try: - return validators[name].validate_python(value) - except Exception as exc: - expected = validatable_types.get(name, "unknown") - raise TypeError(f"Context field '{name}': expected {expected}, got {type(value).__name__} ({value!r})") from exc +def _resolve_registry_candidate(value: str) -> Any: + try: + candidate = BaseModel.model_validate(value) + except ValidationError: + return None + return candidate if isinstance(candidate, BaseModel) else None -def _validate_config_kwargs(kwargs: Dict[str, Any], validatable_types: Dict[str, Type], validators: Dict[str, TypeAdapter]) -> None: - """Validate plain config inputs while still allowing dependency objects.""" - if not validators: - return +def _registry_candidate_allowed(expected_type: Any, candidate: Any) -> bool: + if _is_model_dependency(candidate): + return True + if not _can_validate_type(expected_type): + return True + try: + _type_adapter(expected_type).validate_python(candidate) + except ValidationError: + return False + return True - from .base import ModelRegistry as _MR - for field_name, validator in validators.items(): - if field_name not in kwargs: - continue - value = kwargs[field_name] - if value is None or _is_model_dependency(value): - continue - if isinstance(value, str) and value in _MR.root(): - expected_type = validatable_types[field_name] - if _type_accepts_str(expected_type): - continue - candidate = _resolve_registry_candidate(value) - if candidate is not None and _registry_candidate_allowed(expected_type, candidate): - continue +def _callable_closure_repr(transform: Any) -> str: + closure = getattr(transform, "__closure__", None) + if not closure: + return "" + pieces = [] + for cell in closure: try: - validator.validate_python(value) + pieces.append(repr(cell.cell_contents)) except Exception: - expected_type = validatable_types[field_name] - raise TypeError(f"Field '{field_name}': expected 
{expected_type}, got {type(value).__name__} ({value!r})") + pieces.append("") + return "|".join(pieces) + + +def _callable_fingerprint(transform: Any) -> str: + module = getattr(transform, "__module__", type(transform).__module__) + qualname = getattr(transform, "__qualname__", type(transform).__qualname__) + code = getattr(transform, "__code__", None) + if code is None: + return f"callable:{module}:{qualname}:{repr(transform)}" + + payload = "|".join( + [ + module, + qualname, + code.co_filename, + str(code.co_firstlineno), + hashlib.sha256(marshal.dumps(code)).hexdigest(), + repr(getattr(transform, "__defaults__", None)), + _callable_closure_repr(transform), + ] + ) + return f"callable:{payload}" + + +def _fingerprint_transforms(transforms: Dict[str, Any]) -> Tuple[Tuple[str, str], ...]: + items = [] + for name, transform in sorted(transforms.items()): + if callable(transform): + items.append((name, _callable_fingerprint(transform))) + else: + items.append((name, repr(transform))) + return tuple(items) def _generated_model_instance(stage: Any) -> Optional["_GeneratedFlowModelBase"]: - if isinstance(stage, BoundModel): - model = stage.model - else: - model = stage + model = stage.model if isinstance(stage, BoundModel) else stage if isinstance(model, _GeneratedFlowModelBase): return model return None @@ -287,67 +348,58 @@ def _describe_pipe_stage(stage: Any) -> str: def _generated_model_explicit_kwargs(model: "_GeneratedFlowModelBase") -> Dict[str, Any]: - return cast(Dict[str, Any], model.model_dump(mode="python", exclude_unset=True)) - - -def _infer_pipe_param( - stage_name: str, - param_names: List[str], - default_param_names: set[str], - occupied_names: set[str], -) -> str: - required_candidates = [name for name in param_names if name not in occupied_names and name not in default_param_names] - if len(required_candidates) == 1: - return required_candidates[0] - if len(required_candidates) > 1: - candidates = ", ".join(required_candidates) - raise TypeError( - 
f"pipe() could not infer a target parameter for {stage_name}; unbound candidates are: {candidates}. Pass param='...' explicitly." - ) - - fallback_candidates = [name for name in param_names if name not in occupied_names] - if len(fallback_candidates) == 1: - return fallback_candidates[0] - if len(fallback_candidates) > 1: - candidates = ", ".join(fallback_candidates) - raise TypeError( - f"pipe() could not infer a target parameter for {stage_name}; unbound candidates are: {candidates}. Pass param='...' explicitly." - ) - - raise TypeError(f"pipe() could not find an available target parameter for {stage_name}.") + values = cast(Dict[str, Any], model.model_dump(mode="python", exclude_unset=True)) + values.pop("type_", None) + values.pop("_target_", None) + values.pop("meta", None) + return values def _resolve_pipe_param(source: Any, stage: Any, param: Optional[str], bindings: Dict[str, Any]) -> Tuple[str, type["_GeneratedFlowModelBase"]]: - del source # Source only matters when binding, not during target resolution. 
+ del source generated_model_cls = _generated_model_class(stage) if generated_model_cls is None: raise TypeError("pipe() only supports downstream stages created by @Flow.model or bound versions of those stages.") + config = generated_model_cls.__flow_model_config__ stage_name = _describe_pipe_stage(stage) - all_param_types = getattr(generated_model_cls, "__flow_model_all_param_types__", {}) - if not all_param_types: - raise TypeError(f"pipe() could not determine bindable parameters for {stage_name}.") - - param_names = list(all_param_types.keys()) - default_param_names = set(getattr(generated_model_cls, "__flow_model_default_param_names__", set())) + regular_names = list(config.regular_param_names) generated_model = _generated_model_instance(stage) occupied_names = set(bindings) if generated_model is not None: - occupied_names |= _bound_field_names(generated_model) - if isinstance(stage, BoundModel): - occupied_names |= set(stage._input_transforms) + occupied_names |= {name for name in _bound_field_names(generated_model) if name in regular_names} if param is not None: - if param not in all_param_types: - valid = ", ".join(param_names) - raise TypeError(f"pipe() target parameter '{param}' is not valid for {stage_name}. Available parameters: {valid}.") + if param in config.contextual_param_names: + raise TypeError(f"pipe() target parameter '{param}' on {stage_name} is contextual. Use .flow.with_inputs(...) instead.") + if param not in regular_names: + valid = ", ".join(regular_names) or "" + raise TypeError(f"pipe() target parameter '{param}' is not valid for {stage_name}. 
Available regular parameters: {valid}.") if param in occupied_names: raise TypeError(f"pipe() target parameter '{param}' is already bound for {stage_name}.") return param, generated_model_cls - return _infer_pipe_param(stage_name, param_names, default_param_names, occupied_names), generated_model_cls + required_candidates = [p.name for p in config.regular_params if not p.has_function_default and p.name not in occupied_names] + if len(required_candidates) == 1: + return required_candidates[0], generated_model_cls + if len(required_candidates) > 1: + candidates = ", ".join(required_candidates) + raise TypeError( + f"pipe() could not infer a target parameter for {stage_name}; unbound regular candidates are: {candidates}. Pass param='...'." + ) + + fallback_candidates = [name for name in regular_names if name not in occupied_names] + if len(fallback_candidates) == 1: + return fallback_candidates[0], generated_model_cls + if len(fallback_candidates) > 1: + candidates = ", ".join(fallback_candidates) + raise TypeError( + f"pipe() could not infer a target parameter for {stage_name}; unbound regular candidates are: {candidates}. Pass param='...'." + ) + + raise TypeError(f"pipe() could not find an available regular target parameter for {stage_name}.") def pipe_model(source: Any, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: @@ -378,89 +430,221 @@ def pipe_model(source: Any, stage: Any, /, *, param: Optional[str] = None, **bin return stage(**build_kwargs) -class FlowAPI: - """API namespace for deferred computation operations. +def _context_input_types_for_model(model: CallableModel) -> Optional[Dict[str, Any]]: + generated = _generated_model_instance(model) + if generated is not None: + return dict(type(generated).__flow_model_config__.context_input_types) - Provides methods for executing models and transforming contexts. - Accessed via model.flow property. 
- """ + context_cls = _concrete_context_type(model.context_type) + if context_cls is None or context_cls is FlowContext or not hasattr(context_cls, "model_fields"): + return None + return {name: info.annotation for name, info in context_cls.model_fields.items()} - def __init__(self, model: CallableModel): - self._model = model - def _build_context(self, kwargs: Dict[str, Any]) -> ContextBase: - """Construct a runtime context for either generated or hand-written models.""" - get_validator = getattr(self._model, "_get_context_validator", None) - if get_validator is not None: - validator = get_validator() - validated = validator.validate_python(kwargs) - if isinstance(validated, ContextBase): - return validated - return FlowContext(**validated) +def _context_required_names_for_model(model: CallableModel) -> Tuple[str, ...]: + generated = _generated_model_instance(model) + if generated is not None: + return type(generated).__flow_model_config__.context_required_names + + context_cls = _concrete_context_type(model.context_type) + if context_cls is None or not hasattr(context_cls, "model_fields"): + return () + return tuple(name for name, info in context_cls.model_fields.items() if info.is_required()) + - validator = TypeAdapter(self._model.context_type) - return validator.validate_python(kwargs) +def _missing_regular_param_names(model: "_GeneratedFlowModelBase", config: _FlowModelConfig) -> List[str]: + missing = [] + for param in config.regular_params: + value = getattr(model, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + missing.append(param.name) + return missing - def compute(self, **kwargs) -> Any: - """Execute the model with the provided context arguments. - Validates kwargs against the model's context schema using TypeAdapter, - then wraps in FlowContext and calls the model. 
+def _resolve_regular_param_value(model: "_GeneratedFlowModelBase", param: _FlowModelParam, context: ContextBase) -> Any: + value = getattr(model, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + raise TypeError( + f"Regular parameter '{param.name}' for {_callable_name(type(model).__flow_model_config__.func)} is still unbound. " + "Bind it at construction time or via pipe()." + ) + if param.is_lazy: + if _is_model_dependency(value): + return _make_lazy_thunk(value, context) + return lambda v=value: v + if _is_model_dependency(value): + return _unwrap_model_result(value(context)) + return value + + +def _resolve_contextual_param_value( + model: "_GeneratedFlowModelBase", + param: _FlowModelParam, + context_values: Dict[str, Any], +) -> Tuple[Any, bool]: + if param.name in context_values: + return context_values[param.name], True + + value = getattr(model, param.name, _UNSET_FLOW_INPUT) + if not _is_unset_flow_input(value): + return value, True + + if param.has_function_default: + return param.function_default, True + + return _UNSET, False + + +def _resolved_contextual_inputs(model: "_GeneratedFlowModelBase", config: _FlowModelConfig, context: ContextBase) -> Dict[str, Any]: + context_values = _context_values(context) + resolved: Dict[str, Any] = {} + missing_contextual = [] + + for param in config.contextual_params: + value, found = _resolve_contextual_param_value(model, param, context_values) + if not found: + missing_contextual.append(param.name) + continue + resolved[param.name] = value + + if missing_contextual: + missing = ", ".join(sorted(missing_contextual)) + raise TypeError( + f"Missing contextual input(s) for {_callable_name(config.func)}: {missing}. " + "Supply them via the runtime context, compute(), with_inputs(), or construction-time contextual defaults." 
+ ) + + if config.declared_context_type is not None: + validated = config.declared_context_type.model_validate(resolved) + return {param.name: getattr(validated, param.name) for param in config.contextual_params} + + return { + param.name: _coerce_value(param.name, resolved[param.name], param.validation_annotation, "Context field") + for param in config.contextual_params + } + + +def _validate_with_inputs_transforms(model: CallableModel, transforms: Dict[str, Any]) -> Dict[str, Any]: + context_input_types = _context_input_types_for_model(model) + validated = dict(transforms) + + if context_input_types is not None: + invalid = sorted(set(transforms) - set(context_input_types)) + if invalid: + names = ", ".join(invalid) + raise TypeError(f"with_inputs() only accepts contextual fields. Invalid field(s): {names}.") + + for name, transform in list(validated.items()): + if callable(transform): + continue + validated[name] = _coerce_value(name, transform, context_input_types[name], "with_inputs()") + + return validated + + +def _build_generated_compute_context(model: "_GeneratedFlowModelBase", context: Any, kwargs: Dict[str, Any]) -> ContextBase: + config = type(model).__flow_model_config__ + + if context is not _UNSET and kwargs: + raise TypeError("compute() accepts either one context object or contextual keyword arguments, but not both.") + + if config.uses_explicit_context: + if context is _UNSET: + return config.context_type.model_validate(kwargs) + return context if isinstance(context, ContextBase) else config.context_type.model_validate(context) + + if context is not _UNSET: + if isinstance(context, FlowContext): + return context + if isinstance(context, ContextBase): + return FlowContext(**_context_values(context)) + return FlowContext.model_validate(context) - Args: - **kwargs: Context arguments (e.g., start_date, end_date) + invalid = sorted(set(kwargs) - set(config.context_input_types)) + if invalid: + names = ", ".join(invalid) + raise TypeError(f"compute() 
only accepts contextual inputs. Bind regular parameter(s) separately: {names}.") - Returns: - The model's result, using the same return contract as ``model(context)``. - """ - ctx = self._build_context(kwargs) - return self._model(ctx) + coerced = {} + for param in config.contextual_params: + if param.name not in kwargs: + continue + coerced[param.name] = _coerce_value(param.name, kwargs[param.name], param.validation_annotation, "compute() input") + return FlowContext(**coerced) + + +class FlowAPI: + """API namespace for contextual execution and rewrites.""" + + def __init__(self, model: CallableModel): + self._model = model + + @property + def _compute_target(self) -> CallableModel: + return self._model + + def compute(self, context: Any = _UNSET, /, **kwargs) -> Any: + target = self._compute_target + generated = _generated_model_instance(target) + if generated is not None: + built_context = _build_generated_compute_context(generated, context, kwargs) + return target(built_context) + + if context is not _UNSET and kwargs: + raise TypeError("compute() accepts either one context object or contextual keyword arguments, but not both.") + if context is _UNSET: + built_context = target.context_type.model_validate(kwargs) + else: + built_context = context if isinstance(context, ContextBase) else target.context_type.model_validate(context) + return target(built_context) @property - def unbound_inputs(self) -> Dict[str, Type]: - """Return the context schema (field name -> type). - - In deferred mode, this is everything that must still come from runtime context. - """ - all_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) - model_cls = self._model.__class__ - - # If explicit context_args was provided, use _context_schema minus - # fields that have function defaults (they aren't truly required). 
- explicit_args = getattr(model_cls, "__flow_model_explicit_context_args__", None) - if explicit_args is not None: - context_schema = getattr(model_cls, "_context_schema", None) - if context_schema is None: - return {} - ctx_arg_defaults = getattr(model_cls, "__flow_model_context_arg_defaults__", {}) - return {name: typ for name, typ in context_schema.items() if name not in ctx_arg_defaults} - - # Dynamic @Flow.model: unbound = params with no explicit value and no declared default - if all_param_types: - runtime_inputs = _runtime_input_names(self._model) - return {name: typ for name, typ in all_param_types.items() if name in runtime_inputs} - - # Generic CallableModel / Mode 1: runtime inputs are the required - # context fields (fields with defaults are not required). - context_cls = _concrete_context_type(self._model.context_type) - if context_cls is None or not hasattr(context_cls, "model_fields"): - return {} - return {name: info.annotation for name, info in context_cls.model_fields.items() if info.is_required()} + def context_inputs(self) -> Dict[str, Any]: + context_input_types = _context_input_types_for_model(self._model) + return dict(context_input_types or {}) + + @property + def unbound_inputs(self) -> Dict[str, Any]: + generated = _generated_model_instance(self._model) + if generated is not None: + config = type(generated).__flow_model_config__ + if config.uses_explicit_context: + return {name: config.context_input_types[name] for name in config.context_required_names} + result = {} + for param in config.contextual_params: + if not _is_unset_flow_input(getattr(generated, param.name, _UNSET_FLOW_INPUT)): + continue + if param.has_function_default: + continue + result[param.name] = config.context_input_types[param.name] + return result + + required_names = _context_required_names_for_model(self._model) + context_input_types = _context_input_types_for_model(self._model) or {} + return {name: context_input_types[name] for name in required_names} @property def 
bound_inputs(self) -> Dict[str, Any]: - """Return the effective config values for this model.""" - result: Dict[str, Any] = {} - flow_param_types = getattr(self._model.__class__, "__flow_model_all_param_types__", {}) - if flow_param_types: - for name in flow_param_types: - value = getattr(self._model, name, _DEFERRED_INPUT) - if _has_deferred_input(value): + generated = _generated_model_instance(self._model) + if generated is not None: + config = type(generated).__flow_model_config__ + result: Dict[str, Any] = {} + explicit_fields = _bound_field_names(generated) + for param in config.regular_params: + value = getattr(generated, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + continue + result[param.name] = value + for param in config.contextual_params: + if param.name not in explicit_fields: continue - result[name] = value + value = getattr(generated, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + continue + result[param.name] = value return result - # Generic CallableModel: configured model fields are the bound inputs. + result: Dict[str, Any] = {} model_fields = getattr(self._model.__class__, "model_fields", {}) for name in model_fields: if name == "meta": @@ -469,105 +653,53 @@ def bound_inputs(self) -> Dict[str, Any]: return result def with_inputs(self, **transforms) -> "BoundModel": - """Create a version of this model with transformed context inputs. - - Args: - **transforms: Mapping of field name to either: - - A callable (ctx) -> value for dynamic transforms - - A static value to bind - - Returns: - A BoundModel that applies the transforms before calling. - """ - return BoundModel(model=self._model, input_transforms=transforms) - - -_bound_model_restore = threading.local() - - -def _fingerprint_transforms(transforms: Dict[str, Any]) -> Dict[str, str]: - """Create a stable, hashable fingerprint of input transforms for cache key differentiation. 
- - Callable transforms are identified by their id() (unique per object), which is - stable within a process lifetime. Static values are repr'd directly. - """ - result = {} - for name, transform in sorted(transforms.items()): - if callable(transform): - result[name] = f"callable:{id(transform)}" - else: - result[name] = repr(transform) - return result + validated = _validate_with_inputs_transforms(self._model, transforms) + return BoundModel(model=self._model, input_transforms=validated) class BoundModel(WrapperModel): - """A model with context transforms applied. - - Created by model.flow.with_inputs(). Applies transforms to context - before delegating to the underlying model. - - Context propagation across dependencies: - Each BoundModel transforms the context locally — only for the model it - wraps. When used as a dependency inside another model, the FlowContext - flows through the chain unchanged until it reaches this BoundModel, - which intercepts it, applies its transforms, and passes the modified - context to the wrapped model. Upstream models never see the transform. - - Chaining with_inputs: - Calling ``bound.flow.with_inputs(...)`` merges the new transforms with - the existing ones (new overrides old for the same key). All transforms - are applied to the incoming context in one pass — they don't compose - sequentially (each transform sees the original context, not the output - of a previous transform). - """ + """A model with contextual input transforms applied locally.""" _input_transforms: Dict[str, Any] = PrivateAttr(default_factory=dict) + serialized_transforms: Dict[str, Any] = Field(default_factory=dict, alias="_static_transforms", repr=False, exclude=True) - @model_validator(mode="wrap") + @model_validator(mode="before") @classmethod - def _restore_serialized_transforms(cls, values, handler): - """Strip serialization-injected keys, restore static transforms, guarantee cleanup. 
- - Uses thread-local storage to pass static transforms to __init__ because - pydantic rejects unknown keys in the input dict. The wrap validator's - try/finally ensures the thread-local is always cleaned up, even if - validation fails before __init__ runs. - """ + def _strip_runtime_serializer_fields(cls, values): if isinstance(values, dict): - values = dict(values) # Don't mutate the caller's dict - values.pop("_input_transforms_token", None) - static = values.pop("_static_transforms", None) - else: - static = None - - if static is not None: - _bound_model_restore.pending = static - try: - return handler(values) - except Exception: - _bound_model_restore.pending = None - raise + cleaned = dict(values) + cleaned.pop("_input_transforms_fingerprint", None) + return cleaned + return values def __init__(self, *, model: CallableModel, input_transforms: Optional[Dict[str, Any]] = None, **kwargs): + if input_transforms is not None: + static_transforms = {name: value for name, value in input_transforms.items() if not callable(value)} + kwargs["_static_transforms"] = static_transforms super().__init__(model=model, **kwargs) - restore = getattr(_bound_model_restore, "pending", None) - if restore is not None: - _bound_model_restore.pending = None if input_transforms is not None: - self._input_transforms = input_transforms - elif restore is not None: - self._input_transforms = restore + self._input_transforms = dict(input_transforms) else: - self._input_transforms = {} + self._input_transforms = dict(self.serialized_transforms) + + def model_post_init(self, __context): + if not self._input_transforms: + self._input_transforms = dict(self.serialized_transforms) def _transform_context(self, context: ContextBase) -> ContextBase: - """Return this model's preferred context type with input transforms applied.""" ctx_dict = _context_values(context) + context_input_types = _context_input_types_for_model(self.model) + for name, transform in self._input_transforms.items(): - if 
callable(transform): - ctx_dict[name] = transform(context) - else: - ctx_dict[name] = transform + value = transform(context) if callable(transform) else transform + if context_input_types is not None and name in context_input_types: + value = _coerce_value(name, value, context_input_types[name], "with_inputs()") + ctx_dict[name] = value + + generated = _generated_model_instance(self.model) + if generated is not None and not type(generated).__flow_model_config__.uses_explicit_context: + return FlowContext(**ctx_dict) + context_type = _concrete_context_type(self.model.context_type) if context_type is not None and context_type is not FlowContext: return context_type.model_validate(ctx_dict) @@ -575,31 +707,21 @@ def _transform_context(self, context: ContextBase) -> ContextBase: @Flow.call def __call__(self, context: ContextBase) -> ResultBase: - """Call the model with transformed context.""" return self.model(self._transform_context(context)) @Flow.deps def __deps__(self, context: ContextBase) -> GraphDepList: - """Declare the wrapped model as an upstream dependency with transformed context.""" return [(self.model, [self._transform_context(context)])] @model_serializer(mode="wrap") def _serialize_with_transforms(self, handler): - """Include transforms in serialization for cache keys and faithful roundtrips. - - Static (non-callable) transforms are serialized in _static_transforms for - faithful restoration. A fingerprint token covers all transforms (including - callables) for cache key differentiation. 
- """ data = handler(self) - static = {k: v for k, v in self._input_transforms.items() if not callable(v)} - if static: - data["_static_transforms"] = static - data["_input_transforms_token"] = _fingerprint_transforms(self._input_transforms) + if self.serialized_transforms: + data["_static_transforms"] = dict(self.serialized_transforms) + data["_input_transforms_fingerprint"] = _fingerprint_transforms(self._input_transforms) return data def pipe(self, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: - """Wire this bound model into a downstream generated ``@Flow.model`` stage.""" return pipe_model(self, stage, param=param, **bindings) def __repr__(self) -> str: @@ -608,332 +730,338 @@ def __repr__(self) -> str: @property def flow(self) -> "FlowAPI": - """Access the flow API.""" return _BoundFlowAPI(self) class _BoundFlowAPI(FlowAPI): - """FlowAPI that delegates to a BoundModel, honoring transforms.""" - def __init__(self, bound_model: BoundModel): self._bound = bound_model super().__init__(bound_model.model) - def compute(self, **kwargs) -> Any: - ctx = self._build_context(kwargs) - return self._bound(ctx) # Call through BoundModel, not inner model - - def with_inputs(self, **transforms) -> "BoundModel": - """Chain transforms: merge new transforms with existing ones. + @property + def _compute_target(self) -> CallableModel: + return self._bound - New transforms override existing ones for the same key. 
- """ - merged = {**self._bound._input_transforms, **transforms} + def with_inputs(self, **transforms) -> BoundModel: + validated = _validate_with_inputs_transforms(self._bound.model, transforms) + merged = {**self._bound._input_transforms, **validated} return BoundModel(model=self._bound.model, input_transforms=merged) class _GeneratedFlowModelBase(CallableModel): - """Shared behavior for models generated by ``@Flow.model``.""" - - __flow_model_context_type__: ClassVar[Type[ContextBase]] = FlowContext - __flow_model_return_type__: ClassVar[Type[ResultBase]] = GenericResult - __flow_model_func__: ClassVar[_AnyCallable | None] = None - __flow_model_use_context_args__: ClassVar[bool] = True - __flow_model_explicit_context_args__: ClassVar[Optional[List[str]]] = None - __flow_model_all_param_types__: ClassVar[Dict[str, Type]] = {} - __flow_model_default_param_names__: ClassVar[set[str]] = set() - __flow_model_context_arg_defaults__: ClassVar[Dict[str, Any]] = {} - __flow_model_auto_wrap__: ClassVar[bool] = False - __flow_model_validatable_types__: ClassVar[Dict[str, Type]] = {} - __flow_model_config_validators__: ClassVar[Dict[str, TypeAdapter]] = {} - _context_schema: ClassVar[Dict[str, Type]] = {} - _context_td: ClassVar[Any | None] = None - _cached_context_validator: ClassVar[TypeAdapter | None] = None + __flow_model_config__: ClassVar[_FlowModelConfig] @model_validator(mode="before") - def _resolve_registry_refs(cls, values, info): + @classmethod + def _resolve_registry_refs(cls, values): if not isinstance(values, dict): return values - param_types = getattr(cls, "__flow_model_all_param_types__", {}) + config = getattr(cls, "__flow_model_config__", None) + if config is None: + return values + resolved = dict(values) - for field_name, expected_type in param_types.items(): - if field_name not in resolved: + for param in config.regular_params: + if param.name not in resolved: continue - value = resolved[field_name] + value = resolved[param.name] if not 
isinstance(value, str): continue - if _type_accepts_str(expected_type): + if _type_accepts_str(param.annotation): continue candidate = _resolve_registry_candidate(value) if candidate is None: continue - if _registry_candidate_allowed(expected_type, candidate): - resolved[field_name] = candidate + if _registry_candidate_allowed(param.annotation, candidate): + resolved[param.name] = candidate return resolved @model_validator(mode="after") - def _validate_field_types(self): - """Validate field values against their declared types. - - This catches type mismatches in the model_validate/deserialization path, - where fields are typed as Any and pydantic won't reject wrong types. - """ - cls = self.__class__ - config_validators = getattr(cls, "__flow_model_config_validators__", {}) - validatable_types = getattr(cls, "__flow_model_validatable_types__", {}) - if not config_validators: - return self - - for field_name, validator in config_validators.items(): - value = getattr(self, field_name, _DEFERRED_INPUT) - if _has_deferred_input(value) or value is None or _is_model_dependency(value): + def _validate_flow_model_fields(self): + config = self.__class__.__flow_model_config__ + + for param in config.parameters: + value = getattr(self, param.name, _UNSET_FLOW_INPUT) + if _is_unset_flow_input(value): + continue + + if param.is_contextual: + if _is_model_dependency(value): + raise TypeError( + f"Parameter '{param.name}' is marked FromContext[...] and cannot be bound to a CallableModel. " + "Bind a literal contextual default or supply it via compute()/with_inputs()." 
+ ) + object.__setattr__( + self, + param.name, + _coerce_value(param.name, value, param.validation_annotation, "Contextual default"), + ) + continue + + if _is_model_dependency(value): continue - try: - validator.validate_python(value) - except Exception: - expected_type = validatable_types[field_name] - raise TypeError(f"Field '{field_name}': expected {expected_type}, got {type(value).__name__} ({value!r})") + + object.__setattr__(self, param.name, _coerce_value(param.name, value, param.annotation, "Field")) + return self @property def context_type(self) -> Type[ContextBase]: - return self.__class__.__flow_model_context_type__ + return self.__class__.__flow_model_config__.context_type @property def result_type(self) -> Type[ResultBase]: - return self.__class__.__flow_model_return_type__ + return self.__class__.__flow_model_config__.result_type @property def flow(self) -> FlowAPI: return FlowAPI(self) - def _get_context_validator(self) -> TypeAdapter: - """Get or create the context validator for this generated model.""" - - cls = self.__class__ - explicit_args = getattr(cls, "__flow_model_explicit_context_args__", None) - - if explicit_args is not None or not getattr(cls, "__flow_model_use_context_args__", True): - if cls._cached_context_validator is None: - use_ctx_args = getattr(cls, "__flow_model_use_context_args__", True) - ctx_type = cls.__flow_model_context_type__ - if not use_ctx_args and isinstance(ctx_type, type) and issubclass(ctx_type, ContextBase) and ctx_type is not FlowContext: - # Mode 1 with concrete context type — use TypeAdapter(context_type) - # directly so defaults on the context type are respected. 
- cls._cached_context_validator = TypeAdapter(ctx_type) - elif cls._context_td is not None: - cls._cached_context_validator = TypeAdapter(cls._context_td) - elif cls._context_schema: - cls._cached_context_validator = _build_typed_dict_adapter(f"{cls.__name__}Inputs", cls._context_schema) - else: - cls._cached_context_validator = TypeAdapter(cls.__flow_model_context_type__) - return cls._cached_context_validator - - if not hasattr(self, "_instance_context_validator"): - all_param_types = getattr(cls, "__flow_model_all_param_types__", {}) - runtime_inputs = _runtime_input_names(self) - unbound_schema = {name: typ for name, typ in all_param_types.items() if name in runtime_inputs} - object.__setattr__( - self, - "_instance_context_validator", - _build_typed_dict_adapter(f"{cls.__name__}Inputs", unbound_schema, total=False), - ) - return cast(TypeAdapter, getattr(self, "_instance_context_validator")) - - -class Lazy: - """Deferred model execution with runtime context overrides. - - Has two distinct uses: - 1. **Type annotation** — ``Lazy[T]`` marks a parameter as lazily evaluated. - The framework will NOT pre-evaluate the dependency; instead the function - receives a zero-arg thunk that triggers evaluation on demand:: - - @Flow.model - def smart_training( - data: PreparedData, - fast_metrics: Metrics, - slow_metrics: Lazy[Metrics], # NOT eagerly evaluated - threshold: float = 0.9, - ) -> Metrics: - if fast_metrics.r2 > threshold: - return fast_metrics - return slow_metrics() # Evaluated on demand - - 2. **Runtime helper** — ``Lazy(model)(overrides)`` creates a callable that - applies context overrides before calling the model. 
Used with - ``with_inputs()`` for deferred execution:: +def _make_call_impl(config: _FlowModelConfig) -> _AnyCallable: + def __call__(self, context): + missing_regular = _missing_regular_param_names(self, config) + if missing_regular: + missing = ", ".join(sorted(missing_regular)) + raise TypeError( + f"Missing regular parameter(s) for {_callable_name(config.func)}: {missing}. " + "Bind them at construction time or via pipe(); compute() only supplies contextual inputs." + ) - lookback = Lazy(model)(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) + fn_kwargs: Dict[str, Any] = {} + for param in config.regular_params: + fn_kwargs[param.name] = _resolve_regular_param_value(self, param, context) - **Which to use:** + if config.uses_explicit_context: + fn_kwargs[cast(str, config.explicit_context_param)] = context + else: + fn_kwargs.update(_resolved_contextual_inputs(self, config, context)) + + raw_result = config.func(**fn_kwargs) + if config.auto_wrap_result: + return GenericResult(value=raw_result) + return raw_result + + cast(Any, __call__).__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=config.context_type), + ], + return_annotation=config.result_type, + ) + return __call__ - - Use ``Lazy[T]`` in a ``@Flow.model`` signature when you want conditional/ - on-demand evaluation of an expensive upstream dependency. - - Use ``Lazy(model)(...)`` when you need to rewire context fields before - passing them to an existing model (e.g., shifting a date window). - """ - def __class_getitem__(cls, item): - """Support Lazy[T] syntax as a type annotation marker. 
+def _make_deps_impl(config: _FlowModelConfig) -> _AnyCallable: + def __deps__(self, context): + missing_regular = _missing_regular_param_names(self, config) + if missing_regular: + missing = ", ".join(sorted(missing_regular)) + raise TypeError(f"Missing regular parameter(s) for {_callable_name(config.func)}: {missing}. Bind them before dependency evaluation.") - Returns Annotated[T, _LazyMarker()] so the framework can detect - lazy parameters during signature analysis. - """ - return Annotated[item, _LazyMarker()] + deps = [] + for param in config.regular_params: + if param.is_lazy: + continue + value = getattr(self, param.name, _UNSET_FLOW_INPUT) + if isinstance(value, BoundModel): + deps.append((value.model, [value._transform_context(context)])) + elif isinstance(value, CallableModel): + deps.append((value, [context])) + return deps + + cast(Any, __deps__).__signature__ = inspect.Signature( + parameters=[ + inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=config.context_type), + ], + return_annotation=GraphDepList, + ) + return __deps__ - def __init__(self, model: "CallableModel"): # noqa: F821 - """Wrap a model for deferred execution. - Args: - model: The CallableModel to wrap - """ - self._model = model +def _context_type_annotations_compatible(func_annotation: Any, context_annotation: Any) -> bool: + if func_annotation is context_annotation: + return True + if isinstance(func_annotation, type) and isinstance(context_annotation, type): + return issubclass(func_annotation, context_annotation) or issubclass(context_annotation, func_annotation) + return True - def __call__(self, **overrides) -> Callable[[ContextBase], Any]: - """Create a callable that applies overrides to context before execution. - Args: - **overrides: Context field overrides. 
Values can be: - - Static values (applied directly) - - Callables (ctx) -> value (called with context at runtime) +def _validate_declared_context_type(context_type: Any, contextual_params: Tuple[_FlowModelParam, ...]) -> Type[ContextBase]: + if not isinstance(context_type, type) or not issubclass(context_type, ContextBase): + raise TypeError(f"context_type must be a ContextBase subclass, got {context_type!r}") - Returns: - A callable (context) -> result that applies overrides and calls the model - """ - model = self._model + context_fields = getattr(context_type, "model_fields", {}) + contextual_names = {param.name for param in contextual_params} - def execute_with_overrides(context: ContextBase) -> Any: - # Build context dict from incoming context - ctx_dict = _context_values(context) + missing = sorted(name for name in contextual_names if name not in context_fields) + if missing: + raise TypeError(f"context_type {context_type.__name__} must define fields for all FromContext parameters: {', '.join(missing)}") - # Apply overrides - for name, value in overrides.items(): - if callable(value): - ctx_dict[name] = value(context) - else: - ctx_dict[name] = value + required_extra_fields = sorted( + name for name, info in context_fields.items() if name not in ContextBase.model_fields and name not in contextual_names and info.is_required() + ) + if required_extra_fields: + raise TypeError( + f"context_type {context_type.__name__} has required fields that are not declared as FromContext parameters: " + f"{', '.join(required_extra_fields)}" + ) - # Call model with modified context - new_ctx = FlowContext(**ctx_dict) - return model(new_ctx) + for param in contextual_params: + ctx_field = context_fields[param.name] + if not _context_type_annotations_compatible(param.annotation, ctx_field.annotation): + raise TypeError( + f"FromContext parameter '{param.name}' annotates {param.annotation!r}, but " + f"context_type {context_type.__name__} declares {ctx_field.annotation!r}." 
+ ) - return execute_with_overrides + return context_type - @property - def model(self) -> "CallableModel": # noqa: F821 - """Access the wrapped model.""" - return self._model +def _analyze_flow_model( + fn: _AnyCallable, + sig: inspect.Signature, + resolved_hints: Dict[str, Any], + *, + context_type: Optional[Type[ContextBase]], +) -> _FlowModelConfig: + params = sig.parameters + + explicit_context_param = None + if "context" in params: + explicit_context_param = "context" + elif "_" in params: + explicit_context_param = "_" + + analyzed_params: List[_FlowModelParam] = [] + explicit_context_type = None + + if explicit_context_param is not None: + context_annotation = resolved_hints.get(explicit_context_param, params[explicit_context_param].annotation) + explicit_context_type = _concrete_context_type(context_annotation) + if explicit_context_type is None: + raise TypeError(f"Function {_callable_name(fn)}: '{explicit_context_param}' must be annotated with a ContextBase subclass.") + if context_type is not None: + raise TypeError("context_type=... is inferred from the explicit context parameter; remove the keyword argument.") + + for name, param in params.items(): + if name == "self" or name == explicit_context_param: + continue -def _build_context_schema( - context_args: List[str], func: _AnyCallable, sig: inspect.Signature, resolved_hints: Dict[str, Any] -) -> Tuple[Dict[str, Type], Any]: - """Build context schema from context_args parameter names. 
- - Instead of creating a dynamic ContextBase subclass, this builds: - - A schema dict mapping field names to types - - A TypedDict for Pydantic TypeAdapter validation - - Optionally, a matched existing ContextBase type for compatibility - - Args: - context_args: List of parameter names that come from context - func: The decorated function - sig: The function signature - - Returns: - Tuple of (schema_dict, TypedDict type) - """ - # Build schema dict from parameter annotations - schema = {} - td_schema = {} - for name in context_args: - if name not in sig.parameters: - raise ValueError(f"context_arg '{name}' not found in function parameters") - param = sig.parameters[name] annotation = resolved_hints.get(name, param.annotation) if annotation is inspect.Parameter.empty: - raise ValueError(f"context_arg '{name}' must have a type annotation") - schema[name] = annotation - # Use NotRequired in the TypedDict for params that have a default in the - # function signature, so compute() doesn't require them. - if param.default is not inspect.Parameter.empty: - td_schema[name] = NotRequired[annotation] - else: - td_schema[name] = annotation - - # Create TypedDict for validation (not registered anywhere!) - context_td = TypedDict(f"{_callable_name(func)}Inputs", td_schema) - - return schema, context_td - + raise TypeError(f"Parameter '{name}' must have a type annotation") + + parsed = _parse_annotation(annotation) + if parsed.is_lazy and parsed.is_from_context: + raise TypeError(f"Parameter '{name}' cannot combine Lazy[...] and FromContext[...].") + + has_function_default = param.default is not inspect.Parameter.empty + function_default = param.default if has_function_default else _UNSET + if parsed.is_from_context and has_function_default and _is_model_dependency(function_default): + raise TypeError(f"Parameter '{name}' is marked FromContext[...] 
and cannot default to a CallableModel.") + + analyzed_params.append( + _FlowModelParam( + name=name, + annotation=parsed.base, + kind="contextual" if parsed.is_from_context else "regular", + is_lazy=parsed.is_lazy, + has_function_default=has_function_default, + function_default=function_default, + ) + ) -def _validate_context_type_override( - context_type: Any, - context_args: List[str], - func_schema: Dict[str, Type], - func_defaults: set[str] = frozenset(), -) -> Type[ContextBase]: - """Validate an explicit ``context_type`` override for ``context_args`` mode.""" + contextual_params = tuple(param for param in analyzed_params if param.is_contextual) + if explicit_context_param is not None and contextual_params: + raise TypeError("Functions using an explicit context parameter cannot also declare FromContext[...] parameters.") - if not isinstance(context_type, type) or not issubclass(context_type, ContextBase): - raise TypeError(f"context_type must be a ContextBase subclass, got {context_type!r}") + declared_context_type = None + if explicit_context_type is not None: + call_context_type = explicit_context_type + context_input_types = {name: info.annotation for name, info in explicit_context_type.model_fields.items()} + context_required_names = tuple(name for name, info in explicit_context_type.model_fields.items() if info.is_required()) + else: + if context_type is not None and not contextual_params: + raise TypeError("context_type=... requires FromContext[...] 
parameters or an explicit context parameter.") + if context_type is not None: + declared_context_type = _validate_declared_context_type(context_type, contextual_params) + + call_context_type = FlowContext + context_input_types = {param.name: param.annotation for param in contextual_params} + context_required_names = tuple(param.name for param in contextual_params if not param.has_function_default) + + if declared_context_type is not None: + updated_params = [] + context_fields = declared_context_type.model_fields + for param in analyzed_params: + if not param.is_contextual: + updated_params.append(param) + continue + updated_params.append( + _FlowModelParam( + name=param.name, + annotation=param.annotation, + kind=param.kind, + is_lazy=param.is_lazy, + has_function_default=param.has_function_default, + function_default=param.function_default, + context_validation_annotation=context_fields[param.name].annotation, + ) + ) + analyzed_params = updated_params + + return_type = resolved_hints.get("return", sig.return_annotation) + if return_type is inspect.Signature.empty: + raise TypeError(f"Function {_callable_name(fn)} must have a return type annotation") + + return_origin = get_origin(return_type) or return_type + auto_wrap_result = not (isinstance(return_origin, type) and issubclass(return_origin, ResultBase)) + result_type = GenericResult if auto_wrap_result else return_type + + return _FlowModelConfig( + func=fn, + context_type=call_context_type, + result_type=result_type, + auto_wrap_result=auto_wrap_result, + explicit_context_param=explicit_context_param, + parameters=tuple(analyzed_params), + context_input_types=context_input_types, + context_required_names=context_required_names, + declared_context_type=declared_context_type, + ) - context_fields = getattr(context_type, "model_fields", {}) - missing = sorted(name for name in context_args if name not in context_fields) - if missing: - raise TypeError(f"context_type {context_type.__name__} must define fields for 
context_args: {', '.join(missing)}") - required_extra_fields = sorted( - name for name, info in context_fields.items() if name not in ContextBase.model_fields and name not in context_args and info.is_required() - ) - if required_extra_fields: - raise TypeError(f"context_type {context_type.__name__} has required fields not listed in context_args: {', '.join(required_extra_fields)}") - - # Warn when the function's annotation for a context_arg doesn't match the - # context_type's field annotation. A mismatch means the function declares - # one type but will silently receive whatever Pydantic coerces to. - for name in context_args: - func_ann = func_schema.get(name) - ctx_field = context_fields.get(name) - if func_ann is None or ctx_field is None: +def _validate_factory_kwargs(config: _FlowModelConfig, kwargs: Dict[str, Any]) -> None: + for param in config.parameters: + if param.name not in kwargs: continue - ctx_ann = ctx_field.annotation - if func_ann is ctx_ann: - continue - # Both are concrete types — check subclass relationship - if isinstance(func_ann, type) and isinstance(ctx_ann, type): - if not (issubclass(func_ann, ctx_ann) or issubclass(ctx_ann, func_ann)): + value = kwargs[param.name] + if param.is_contextual: + if _is_model_dependency(value): raise TypeError( - f"context_arg '{name}': function annotates {func_ann.__name__} " - f"but context_type {context_type.__name__} declares {ctx_ann.__name__}" + f"Parameter '{param.name}' is marked FromContext[...] and cannot be bound to a CallableModel. " + "Use a literal contextual default or supply it at runtime." ) + _coerce_value(param.name, value, param.validation_annotation, "Field") + continue - # Reject if the function has a default for a context_arg but the - # context_type declares that field as required — this is contradictory. 
- for name in context_args: - if name in func_defaults: - ctx_field = context_fields.get(name) - if ctx_field is not None and ctx_field.is_required(): - raise TypeError(f"context_arg '{name}': function has a default but context_type {context_type.__name__} requires this field") - - return context_type - - -_UNSET = object() + if value is None or _is_model_dependency(value): + continue + if isinstance(value, str) and not _type_accepts_str(param.annotation): + candidate = _resolve_registry_candidate(value) + if candidate is not None and _registry_candidate_allowed(param.annotation, candidate): + continue + _coerce_value(param.name, value, param.annotation, "Field") def flow_model( func: Optional[_AnyCallable] = None, *, - # Context handling - context_args: Optional[List[str]] = None, + context_args: Any = _REMOVED_CONTEXT_ARGS, context_type: Optional[Type[ContextBase]] = None, - # Flow.call options (passed to generated __call__) - # Default to _UNSET so FlowOptionsOverride can control these globally. - # Only explicitly user-provided values are passed to Flow.call. cacheable: Any = _UNSET, volatile: Any = _UNSET, log_level: Any = _UNSET, @@ -941,394 +1069,67 @@ def flow_model( verbose: Any = _UNSET, evaluator: Any = _UNSET, ) -> _AnyCallable: - """Decorator that generates a CallableModel class from a plain Python function. - - This is syntactic sugar over CallableModel. The decorator generates a real - CallableModel class with proper __call__ and __deps__ methods, so all existing - features (caching, evaluation, registry, serialization) work unchanged. - - Args: - func: The function to decorate - context_args: List of parameter names that come from context (for unpacked mode) - context_type: Explicit ContextBase subclass to use with ``context_args`` mode. 
- cacheable: Enable caching of results (default: unset, inherits from FlowOptionsOverride) - volatile: Mark as volatile (always re-execute) (default: unset, inherits from FlowOptionsOverride) - log_level: Logging verbosity (default: unset, inherits from FlowOptionsOverride) - validate_result: Validate return type (default: unset, inherits from FlowOptionsOverride) - verbose: Verbose logging output (default: unset, inherits from FlowOptionsOverride) - evaluator: Custom evaluator (default: unset, inherits from FlowOptionsOverride) - - Two Context Modes: - 1. Explicit context parameter: Function has a 'context' parameter annotated - with a ContextBase subclass. - - @Flow.model - def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]: - ... + """Decorator that generates a CallableModel class from a plain Python function.""" - 2. Unpacked context_args: Context fields are unpacked into function parameters. - - @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) - def load_prices(start_date: date, end_date: date, source: str) -> GenericResult[pl.DataFrame]: - ... - - Returns: - A factory function that creates CallableModel instances - """ + if context_args is not _REMOVED_CONTEXT_ARGS: + raise TypeError("context_args=... has been removed. Mark runtime/contextual parameters with FromContext[...] instead.") def decorator(fn: _AnyCallable) -> _AnyCallable: sig = inspect.signature(fn) - params = sig.parameters - # Resolve string annotations (PEP 563 / from __future__ import annotations) - # into real type objects. include_extras=True preserves Annotated metadata. 
try: - _resolved_hints = get_type_hints(fn, include_extras=True) - except Exception: - _resolved_hints = {} - - # Validate return type - return_type = _resolved_hints.get("return", sig.return_annotation) - if return_type is inspect.Signature.empty: - raise TypeError(f"Function {_callable_name(fn)} must have a return type annotation") - # Check if return type is a ResultBase subclass; if not, auto-wrap in GenericResult - return_origin = get_origin(return_type) or return_type - auto_wrap_result = False - if not (isinstance(return_origin, type) and issubclass(return_origin, ResultBase)): - auto_wrap_result = True - internal_return_type = GenericResult # unparameterized for safety - else: - internal_return_type = return_type - - # ── Context mode selection ── - # The decorator supports three mutually exclusive context modes: - # - # Mode 1 (explicit context): Function has a 'context' (or '_') parameter - # annotated with a ContextBase subclass. Behaves like a traditional - # CallableModel.__call__. Other params become model fields. - # - # Mode 2 (context_args): Decorator specifies context_args=[...] listing - # which params come from the context at runtime. Remaining params become - # model fields. Uses FlowContext unless context_type= overrides it. - # - # Mode 3 (dynamic deferred): No 'context' param and no context_args. - # Every param is a potential model field. Params bound at construction - # are config; unbound params become runtime inputs from FlowContext. - # - context_schema_early: Dict[str, Type] = {} - context_td_early = None - if "context" in params or "_" in params: - # Mode 1: Explicit context parameter (named 'context' or '_' for unused) - if context_type is not None: - raise TypeError("context_type=... 
is only supported when using context_args=[...]") - context_param_name = "context" if "context" in params else "_" - context_param = params[context_param_name] - context_annotation = _resolved_hints.get(context_param_name, context_param.annotation) - if context_annotation is inspect.Parameter.empty: - raise TypeError(f"Function {_callable_name(fn)}: '{context_param_name}' parameter must have a type annotation") - resolved_context_type = context_annotation - if not (isinstance(resolved_context_type, type) and issubclass(resolved_context_type, ContextBase)): - raise TypeError(f"Function {_callable_name(fn)}: '{context_param_name}' must be annotated with a ContextBase subclass") - model_field_params = {name: param for name, param in params.items() if name not in (context_param_name, "self")} - use_context_args = False - explicit_context_args = None - elif context_args is not None: - # Mode 2: Explicit context_args - specified params come from context - context_param_name = "context" - context_schema_early, context_td_early = _build_context_schema(context_args, fn, sig, _resolved_hints) - _func_defaults_set = {name for name in context_args if sig.parameters[name].default is not inspect.Parameter.empty} - explicit_context_type = ( - _validate_context_type_override(context_type, context_args, context_schema_early, _func_defaults_set) - if context_type is not None - else None - ) - resolved_context_type = explicit_context_type if explicit_context_type is not None else FlowContext - # Exclude context_args from model fields - model_field_params = {name: param for name, param in params.items() if name not in context_args and name != "self"} - use_context_args = True - explicit_context_args = context_args - else: - # Mode 3: Dynamic deferred mode - every param can be configured on the model, - # but only params without Python defaults remain runtime inputs when omitted. - if context_type is not None: - raise TypeError("context_type=... 
is only supported when using context_args=[...]") - context_param_name = "context" - resolved_context_type = FlowContext - model_field_params = {name: param for name, param in params.items() if name != "self"} - use_context_args = True - explicit_context_args = None # Dynamic - determined at construction - - # Analyze parameters to find lazy fields and regular fields. - model_fields: Dict[str, Tuple[Type, Any]] = {} # name -> (type, default) - lazy_fields: set[str] = set() # Names of parameters marked with Lazy[T] - default_param_names: set[str] = set() - - # In dynamic deferred mode (no explicit context_args), fields without Python defaults - # are internally represented by a deferred sentinel until runtime context supplies them. - dynamic_deferred_mode = use_context_args and explicit_context_args is None - - for name, param in model_field_params.items(): - # Use resolved hint (handles PEP 563 string annotations) - annotation = _resolved_hints.get(name, param.annotation) - if annotation is inspect.Parameter.empty: - raise TypeError(f"Parameter '{name}' must have a type annotation") - - # Check for Lazy[T] annotation first - unwrapped_annotation, is_lazy = _extract_lazy(annotation) - if is_lazy: - lazy_fields.add(name) - - if param.default is not inspect.Parameter.empty: - default_param_names.add(name) - default = param.default - elif dynamic_deferred_mode: - # In dynamic mode, params without defaults remain deferred to runtime context. - default = Field(default_factory=_deferred_input_factory, exclude_if=_has_deferred_input) - else: - # In explicit mode, params without defaults are required - default = ... 
- - model_fields[name] = (Any, default) - - # Capture variables for closures - ctx_param_name = context_param_name if not use_context_args else "context" - all_param_names = list(model_fields.keys()) # All non-context params (model fields) - all_param_types = {name: _resolved_hints.get(name, param.annotation) for name, param in model_field_params.items()} - # For explicit context_args mode, we also need the list of context arg names - ctx_args_for_closure = context_args if context_args is not None else [] - is_dynamic_mode = use_context_args and explicit_context_args is None - - # Compute context_arg defaults and validators for Mode 2 (context_args) - context_arg_defaults: Dict[str, Any] = {} - _ctx_validatable_types: Dict[str, Type] = {} - _ctx_validators: Dict[str, TypeAdapter] = {} - if context_args is not None: - for name in context_args: - p = sig.parameters[name] - if p.default is not inspect.Parameter.empty: - context_arg_defaults[name] = p.default - _ctx_validatable_types, _ctx_validators = _build_config_validators(context_schema_early) - - # Create the __call__ method - def make_call_impl(): - def __call__(self, context): - def resolve_callable_model(value): - """Resolve a CallableModel field.""" - resolved = value(context) - if isinstance(resolved, GenericResult): - return resolved.value - return resolved - - # Build kwargs for the original function - fn_kwargs = {} - - def _resolve_field(name, value): - """Resolve a single field value, handling lazy wrapping.""" - is_dep = isinstance(value, CallableModel) - if name in lazy_fields: - # Lazy field: wrap in a thunk regardless of type - if is_dep: - return _make_lazy_thunk(value, context) - else: - # Non-dep value: wrap in trivial thunk - return lambda v=value: v - elif is_dep: - return resolve_callable_model(value) - else: - return value - - if not use_context_args: - # Mode 1: Explicit context param - pass context directly - fn_kwargs[ctx_param_name] = context - # Add model fields - for name in 
all_param_names: - value = getattr(self, name) - fn_kwargs[name] = _resolve_field(name, value) - elif not is_dynamic_mode: - # Mode 2: Explicit context_args - get those from context, rest from self - for name in ctx_args_for_closure: - value = getattr(context, name, _UNSET) - if value is _UNSET: - if name in context_arg_defaults: - fn_kwargs[name] = context_arg_defaults[name] - else: - raise TypeError(f"Missing context field '{name}'") - else: - fn_kwargs[name] = _coerce_context_value(name, value, _ctx_validators, _ctx_validatable_types) - # Add model fields - for name in all_param_names: - value = getattr(self, name) - fn_kwargs[name] = _resolve_field(name, value) - else: - # Mode 3: Dynamic deferred mode - explicit values or Python defaults from self, - # otherwise values come from runtime context. - explicit_fields = _bound_field_names(self) - missing_fields = [] - - for name in all_param_names: - value = getattr(self, name, _DEFERRED_INPUT) - if name in explicit_fields or name in default_param_names: - # Explicitly provided or implicitly bound via Python default. - value = getattr(self, name) - fn_kwargs[name] = _resolve_field(name, value) - continue - - if _has_deferred_input(value): - value = getattr(context, name, _UNSET) - if value is _UNSET: - missing_fields.append(name) - continue - # Validate/coerce context-sourced value, skip CallableModel deps - if not _is_model_dependency(value): - value = _coerce_context_value(name, value, _config_validators, _validatable_types) - fn_kwargs[name] = _resolve_field(name, value) - - if missing_fields: - missing = ", ".join(sorted(missing_fields)) - raise TypeError( - f"Missing runtime input(s) for {_callable_name(fn)}: {missing}. " - "Provide them in the call context or bind them at construction time." 
- ) - - raw_result = fn(**fn_kwargs) - if auto_wrap_result: - return GenericResult(value=raw_result) - return raw_result - - # Set proper signature for CallableModel validation - cast(Any, __call__).__signature__ = inspect.Signature( - parameters=[ - inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), - inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=resolved_context_type), - ], - return_annotation=internal_return_type, - ) - return __call__ - - call_impl = make_call_impl() - - # Apply Flow.call decorator — only include options the user explicitly set - flow_options = {} - for opt_name, opt_val in [ - ("cacheable", cacheable), - ("volatile", volatile), - ("log_level", log_level), - ("validate_result", validate_result), - ("verbose", verbose), - ("evaluator", evaluator), - ]: - if opt_val is not _UNSET: - flow_options[opt_name] = opt_val - - decorated_call = Flow.call(**flow_options)(call_impl) - - # Create the __deps__ method - def make_deps_impl(): - def __deps__(self, context) -> GraphDepList: - deps = [] - # Check ALL fields for CallableModel dependencies (auto-detection) - for name in model_fields: - if name in lazy_fields: - continue # Lazy deps are NOT pre-evaluated - value = getattr(self, name) - if isinstance(value, BoundModel): - deps.append((value.model, [value._transform_context(context)])) - elif isinstance(value, CallableModel): - deps.append((value, [context])) - return deps - - # Set proper signature - cast(Any, __deps__).__signature__ = inspect.Signature( - parameters=[ - inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD), - inspect.Parameter("context", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=resolved_context_type), - ], - return_annotation=GraphDepList, - ) - return __deps__ - - deps_impl = make_deps_impl() - decorated_deps = Flow.deps(deps_impl) + resolved_hints = get_type_hints(fn, include_extras=True) + except (AttributeError, NameError, TypeError): + resolved_hints = 
{} - # Build pydantic field annotations for the class - annotations = {} + config = _analyze_flow_model(fn, sig, resolved_hints, context_type=context_type) - namespace = { + annotations: Dict[str, Any] = {} + namespace: Dict[str, Any] = { "__module__": _callable_module(fn), "__qualname__": f"_{_callable_name(fn)}_Model", - "__call__": decorated_call, - "__deps__": decorated_deps, + "__call__": Flow.call( + **{ + name: value + for name, value in [ + ("cacheable", cacheable), + ("volatile", volatile), + ("log_level", log_level), + ("validate_result", validate_result), + ("verbose", verbose), + ("evaluator", evaluator), + ] + if value is not _UNSET + } + )(_make_call_impl(config)), + "__deps__": Flow.deps(_make_deps_impl(config)), } - for name, (typ, default) in model_fields.items(): - annotations[name] = typ - if default is not ...: - namespace[name] = default + for param in config.parameters: + annotations[param.name] = Any + if param.is_contextual: + namespace[param.name] = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) + elif param.has_function_default: + namespace[param.name] = param.function_default else: - # For required fields, use Field(...) - namespace[name] = Field(...) 
+ namespace[param.name] = Field(default_factory=_unset_flow_input_factory, exclude_if=_is_unset_flow_input) namespace["__annotations__"] = annotations - _validatable_types, _config_validators = _build_config_validators(all_param_types) - - # Create the class using type() GeneratedModel = cast(type[_GeneratedFlowModelBase], type(f"_{_callable_name(fn)}_Model", (_GeneratedFlowModelBase,), namespace)) - - # Set class-level attributes after class creation (to avoid pydantic processing) - GeneratedModel.__flow_model_context_type__ = resolved_context_type - GeneratedModel.__flow_model_return_type__ = internal_return_type - setattr(GeneratedModel, "__flow_model_func__", fn) - GeneratedModel.__flow_model_use_context_args__ = use_context_args - GeneratedModel.__flow_model_explicit_context_args__ = explicit_context_args - GeneratedModel.__flow_model_all_param_types__ = all_param_types # All param name -> type - GeneratedModel.__flow_model_default_param_names__ = default_param_names - GeneratedModel.__flow_model_context_arg_defaults__ = context_arg_defaults - GeneratedModel.__flow_model_auto_wrap__ = auto_wrap_result - GeneratedModel.__flow_model_validatable_types__ = _validatable_types - GeneratedModel.__flow_model_config_validators__ = _config_validators - - # Build context_schema - context_schema: Dict[str, Type] = {} - context_td = None - - if explicit_context_args is not None: - # Explicit context_args provided - use early-computed schema - context_schema, context_td = context_schema_early, context_td_early - elif not use_context_args: - # Explicit context mode - schema comes from the context type's fields - if hasattr(resolved_context_type, "model_fields"): - context_schema = {name: info.annotation for name, info in resolved_context_type.model_fields.items()} - # For dynamic mode (is_dynamic_mode), _context_schema remains empty - # and schema is built dynamically from the instance's unresolved runtime inputs. 
- - # Store context schema for TypedDict-based validation (picklable!) - GeneratedModel._context_schema = context_schema - GeneratedModel._context_td = context_td - # Validator is created lazily to survive pickling - GeneratedModel._cached_context_validator = None - - # Register the MODEL class for serialization (needed for model_dump/_target_). - # Note: We do NOT register dynamic context classes anymore - context handling - # uses FlowContext + TypedDict instead, which don't need registration. + GeneratedModel.__flow_model_config__ = config register_ccflow_import_path(GeneratedModel) - - # Rebuild the model to process annotations properly GeneratedModel.model_rebuild() - # Create factory function that returns model instances @wraps(fn) def factory(**kwargs) -> _GeneratedFlowModelBase: - _validate_config_kwargs(kwargs, _validatable_types, _config_validators) + _validate_factory_kwargs(config, kwargs) return GeneratedModel(**kwargs) - # Preserve useful attributes on factory cast(Any, factory)._generated_model = GeneratedModel factory.__doc__ = fn.__doc__ - return factory - # Handle both @Flow.model and @Flow.model(...) syntax if func is not None: return decorator(func) return decorator diff --git a/ccflow/tests/config/conf_flow.yaml b/ccflow/tests/config/conf_flow.yaml index 41acfaf..c58d7b7 100644 --- a/ccflow/tests/config/conf_flow.yaml +++ b/ccflow/tests/config/conf_flow.yaml @@ -1,7 +1,5 @@ -# Flow.model configurations for Hydra integration tests -# This file is separate from conf.yaml to avoid affecting existing tests +# Flow.model configurations for Hydra integration tests. -# Basic Flow.model flow_loader: _target_: ccflow.tests.test_flow_model.basic_loader source: test_source @@ -12,7 +10,6 @@ flow_processor: prefix: "value=" suffix: "!" 
-# Pipeline with dependencies (uses registry name references for same instance) flow_source: _target_: ccflow.tests.test_flow_model.data_source base_value: 100 @@ -22,7 +19,6 @@ flow_transformer: source: flow_source factor: 3 -# Three-stage pipeline flow_stage1: _target_: ccflow.tests.test_flow_model.pipeline_stage1 initial: 10 @@ -37,7 +33,6 @@ flow_stage3: stage2_output: flow_stage2 offset: 50 -# Diamond dependency pattern diamond_source: _target_: ccflow.tests.test_flow_model.data_source base_value: 10 @@ -58,7 +53,6 @@ diamond_aggregator: input_b: diamond_branch_b operation: add -# DateRangeContext with transform flow_date_loader: _target_: ccflow.tests.test_flow_model.date_range_loader_previous_day source: market_data @@ -69,12 +63,11 @@ flow_date_processor: raw_data: flow_date_loader normalize: true -# context_args models (auto-unpacked context parameters) -ctx_args_loader: - _target_: ccflow.tests.test_flow_model.context_args_loader +contextual_loader_model: + _target_: ccflow.tests.test_flow_model.contextual_loader source: data_source -ctx_args_processor: - _target_: ccflow.tests.test_flow_model.context_args_processor - data: ctx_args_loader - prefix: "output" +contextual_processor_model: + _target_: ccflow.tests.test_flow_model.contextual_processor + data: contextual_loader_model + prefix: output diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index 970cc08..caf3991 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -1,19 +1,13 @@ -"""Tests for FlowContext, FlowAPI, and TypedDict-based context validation. - -These tests verify the new deferred computation API that uses: -- FlowContext: Universal context carrier with extra="allow" -- TypedDict + TypeAdapter: Schema validation without dynamic class registration -- FlowAPI: The .flow namespace for compute/with_inputs/etc. 
-""" +"""Tests for FlowContext, FlowAPI, and BoundModel under the FromContext design.""" import pickle +from concurrent.futures import ThreadPoolExecutor from datetime import date, timedelta import cloudpickle import pytest -from ccflow import CallableModel, ContextBase, Flow, FlowAPI, FlowContext, GenericResult -from ccflow.context import DateRangeContext +from ccflow import BoundModel, CallableModel, ContextBase, Flow, FlowContext, FromContext, GenericResult class NumberContext(ContextBase): @@ -28,569 +22,161 @@ def __call__(self, context: NumberContext) -> GenericResult[int]: return GenericResult(value=context.x + self.offset) -class TestFlowContext: - """Tests for the FlowContext universal carrier.""" - - def test_flow_context_basic(self): - """FlowContext accepts arbitrary fields.""" - ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - assert ctx.start_date == date(2024, 1, 1) - assert ctx.end_date == date(2024, 1, 31) - - def test_flow_context_extra_fields(self): - """FlowContext exposes arbitrary fields through normal model APIs.""" - ctx = FlowContext(x=1, y="hello", z=[1, 2, 3]) - assert ctx.x == 1 - assert ctx.y == "hello" - assert ctx.z == [1, 2, 3] - assert dict(ctx) == {"x": 1, "y": "hello", "z": [1, 2, 3]} - - def test_flow_context_frozen(self): - """FlowContext is immutable (frozen).""" - ctx = FlowContext(value=42) - with pytest.raises(Exception): # ValidationError for frozen model - ctx.value = 100 - - def test_flow_context_repr(self): - """FlowContext has a useful repr.""" - ctx = FlowContext(a=1, b=2) - repr_str = repr(ctx) - assert "FlowContext" in repr_str - assert "a=1" in repr_str - assert "b=2" in repr_str - - def test_flow_context_attribute_error(self): - """FlowContext raises AttributeError for missing fields.""" - ctx = FlowContext(x=1) - with pytest.raises(AttributeError, match="no attribute 'missing'"): - _ = ctx.missing - - def test_flow_context_model_dump(self): - """FlowContext can be dumped (includes extra 
fields).""" - ctx = FlowContext(start_date=date(2024, 1, 1), value=42) - dumped = ctx.model_dump() - assert dumped["start_date"] == date(2024, 1, 1) - assert dumped["value"] == 42 - - def test_flow_context_value_semantics_include_extra_fields(self): - """Equality should reflect the actual extra payload.""" - assert FlowContext(x=1) == FlowContext(x=1) - assert FlowContext(x=1) != FlowContext(x=2) - assert FlowContext(x=1) != FlowContext(y=1) - - def test_flow_context_hash_uses_extra_fields(self): - """Distinct extra payloads should remain distinct in hashed collections.""" - first = FlowContext(values=[1, 2], label="a") - second = FlowContext(values=[1, 3], label="a") - third = FlowContext(values=[1, 2], label="b") - - assert len({first, second, third}) == 3 - - def test_flow_context_hash_raises_for_unhashable_values(self): - """FlowContext with truly unhashable values (no __dict__) should raise TypeError.""" - - class Unhashable: - __hash__ = None # type: ignore[assignment] - - def __init__(self): - pass - - # Deliberately no __dict__ suppression — but __hash__ is None, - # so the fallback path in _freeze_for_hash should use __dict__. - # To trigger the actual TypeError path, we need an object with - # no __dict__ and no __hash__. 
- - class UnhashableSlots: - __slots__ = () - __hash__ = None # type: ignore[assignment] - - ctx = FlowContext(val=UnhashableSlots()) - with pytest.raises(TypeError, match="unhashable value"): - hash(ctx) - - def test_flow_context_eq_non_flow_context(self): - """FlowContext.__eq__ returns False for non-FlowContext objects.""" - ctx = FlowContext(x=1) - assert ctx != 42 - assert ctx != "hello" - assert ctx != None # noqa: E711 - assert ctx != NumberContext(x=1) - - def test_flow_context_hash_with_set_value(self): - """FlowContext with set values should hash correctly via frozenset.""" - ctx = FlowContext(tags=frozenset({"a", "b"})) - # Should not raise - h = hash(ctx) - assert isinstance(h, int) - - def test_flow_context_hash_with_model_dump_object(self): - """_freeze_for_hash should handle objects with model_dump attribute.""" - from ccflow.context import _freeze_for_hash - - # Directly test _freeze_for_hash with an object that has model_dump - # (FlowContext.__hash__ goes through model_dump first which serializes - # nested models, so we test the helper directly) - inner = NumberContext(x=42) - result = _freeze_for_hash(inner) - assert isinstance(result, tuple) - assert result[0] is NumberContext - - def test_flow_context_hash_unhashable_with_dict_fallback(self): - """Objects with __dict__ but no __hash__ should use __dict__ fallback.""" - - class UnhashableWithDict: - __hash__ = None # type: ignore[assignment] - - def __init__(self, val): - self.val = val - - ctx = FlowContext(obj=UnhashableWithDict(42)) - h = hash(ctx) - assert isinstance(h, int) - - def test_flow_context_pickle(self): - """FlowContext pickles cleanly.""" - ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - pickled = pickle.dumps(ctx) - unpickled = pickle.loads(pickled) - assert unpickled.start_date == date(2024, 1, 1) - assert unpickled.end_date == date(2024, 1, 31) - - def test_flow_context_cloudpickle(self): - """FlowContext works with cloudpickle (for Ray).""" - ctx 
= FlowContext(data=[1, 2, 3], name="test") - pickled = cloudpickle.dumps(ctx) - unpickled = cloudpickle.loads(pickled) - assert unpickled.data == [1, 2, 3] - assert unpickled.name == "test" - - -class TestFlowAPI: - """Tests for the FlowAPI (.flow namespace).""" - - def test_flow_compute_basic(self): - """FlowAPI.compute() validates and executes.""" - - @Flow.model(context_args=["start_date", "end_date"]) - def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]: - return GenericResult(value={"start": start_date, "end": end_date, "source": source}) - - model = load_data(source="api") - result = model.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - - assert result.value["start"] == date(2024, 1, 1) - assert result.value["end"] == date(2024, 1, 31) - assert result.value["source"] == "api" - - def test_flow_compute_type_coercion(self): - """FlowAPI.compute() coerces types via TypeAdapter.""" - - @Flow.model(context_args=["start_date", "end_date"]) - def load_data(start_date: date, end_date: date) -> GenericResult[dict]: - return GenericResult(value={"start": start_date, "end": end_date}) - - model = load_data() - # Pass strings - should be coerced to dates - result = model.flow.compute(start_date="2024-01-01", end_date="2024-01-31") - - assert result.value["start"] == date(2024, 1, 1) - assert result.value["end"] == date(2024, 1, 31) - - def test_flow_compute_validation_error(self): - """FlowAPI.compute() raises on missing required args.""" - - @Flow.model(context_args=["start_date", "end_date"]) - def load_data(start_date: date, end_date: date) -> GenericResult[dict]: - return GenericResult(value={}) - - model = load_data() - with pytest.raises(Exception): # ValidationError - model.flow.compute(start_date=date(2024, 1, 1)) # Missing end_date - - def test_flow_unbound_inputs(self): - """FlowAPI.unbound_inputs returns the context schema.""" - - @Flow.model(context_args=["start_date", "end_date"]) - def 
load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]: - return GenericResult(value={}) - - model = load_data(source="api") - unbound = model.flow.unbound_inputs - - assert "start_date" in unbound - assert "end_date" in unbound - assert unbound["start_date"] == date - assert unbound["end_date"] == date - # source is not unbound (it has a default/is bound) - assert "source" not in unbound - - def test_flow_bound_inputs(self): - """FlowAPI.bound_inputs returns config values.""" - - @Flow.model(context_args=["start_date", "end_date"]) - def load_data(start_date: date, end_date: date, source: str = "db") -> GenericResult[dict]: - return GenericResult(value={}) - - model = load_data(source="api") - bound = model.flow.bound_inputs - - assert "source" in bound - assert bound["source"] == "api" - # Context args are not in bound_inputs - assert "start_date" not in bound - assert "end_date" not in bound - - def test_flow_compute_regular_callable_model(self): - """Regular CallableModels also expose .flow.compute().""" - - model = OffsetModel(offset=10) - result = model.flow.compute(x=5) - - assert result.value == 15 - - def test_flow_unbound_inputs_regular_callable_model(self): - """Regular CallableModels expose their context schema as unbound inputs.""" - - model = OffsetModel(offset=10) - unbound = model.flow.unbound_inputs - - assert unbound == {"x": int} - - def test_flow_bound_inputs_regular_callable_model(self): - """Regular CallableModels expose their configured fields as bound inputs.""" - - model = OffsetModel(offset=10) - bound = model.flow.bound_inputs - - assert bound["offset"] == 10 - - -class TestBoundModel: - """Tests for BoundModel (created via .flow.with_inputs()).""" - - def test_with_inputs_static_value(self): - """with_inputs can bind static values.""" - - @Flow.model(context_args=["start_date", "end_date"]) - def load_data(start_date: date, end_date: date) -> GenericResult[dict]: - return GenericResult(value={"start": 
start_date, "end": end_date}) - - model = load_data() - bound = model.flow.with_inputs(start_date=date(2024, 1, 1)) - - # Call with just end_date (start_date is bound) - ctx = FlowContext(end_date=date(2024, 1, 31)) - result = bound(ctx) - assert result.value["start"] == date(2024, 1, 1) - assert result.value["end"] == date(2024, 1, 31) - - def test_with_inputs_transform_function(self): - """with_inputs can use transform functions.""" - - @Flow.model(context_args=["start_date", "end_date"]) - def load_data(start_date: date, end_date: date) -> GenericResult[dict]: - return GenericResult(value={"start": start_date, "end": end_date}) - - model = load_data() - # Lookback: start_date is 7 days before the context's start_date - bound = model.flow.with_inputs(start_date=lambda ctx: ctx.start_date - timedelta(days=7)) - - ctx = FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 31)) - result = bound(ctx) - assert result.value["start"] == date(2024, 1, 1) # 7 days before - assert result.value["end"] == date(2024, 1, 31) - - def test_with_inputs_multiple_transforms(self): - """with_inputs can apply multiple transforms.""" - - @Flow.model(context_args=["start_date", "end_date"]) - def load_data(start_date: date, end_date: date) -> GenericResult[dict]: - return GenericResult(value={"start": start_date, "end": end_date}) - - model = load_data() - bound = model.flow.with_inputs( - start_date=lambda ctx: ctx.start_date - timedelta(days=7), - end_date=lambda ctx: ctx.end_date + timedelta(days=1), - ) - - ctx = FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 30)) - result = bound(ctx) - assert result.value["start"] == date(2024, 1, 1) - assert result.value["end"] == date(2024, 1, 31) - - def test_bound_model_has_flow_property(self): - """BoundModel has a .flow property.""" - - @Flow.model(context_args=["x"]) - def compute(x: int) -> GenericResult[int]: - return GenericResult(value=x * 2) - - model = compute() - bound = model.flow.with_inputs(x=42) - 
assert isinstance(bound.flow, FlowAPI) - - def test_bound_model_repr_looks_like_with_inputs_call(self): - """BoundModel repr should mirror the API users wrote.""" - - @Flow.model(context_args=["x"]) - def compute(x: int) -> GenericResult[int]: - return GenericResult(value=x * 2) - - model = compute() - bound = model.flow.with_inputs(x=lambda ctx: ctx.x + 1) - - assert repr(bound) == f"{model!r}.flow.with_inputs(x=)" - - def test_with_inputs_regular_callable_model(self): - """Regular CallableModels support .flow.with_inputs().""" - - model = OffsetModel(offset=1) - shifted = model.flow.with_inputs(x=lambda ctx: ctx.x * 2) - - result = shifted(NumberContext(x=5)) - assert result.value == 11 - - -class TestTypedDictValidation: - """Tests for TypedDict-based context validation.""" - - def test_schema_stored_on_model(self): - """Model stores _context_schema for validation.""" - - @Flow.model(context_args=["start_date", "end_date"]) - def load_data(start_date: date, end_date: date) -> GenericResult[dict]: - return GenericResult(value={}) - - model = load_data() - assert hasattr(model, "_context_schema") - assert model._context_schema == {"start_date": date, "end_date": date} - - def test_validator_created_lazily(self): - """TypeAdapter validator is created lazily.""" - - @Flow.model(context_args=["x"]) - def compute(x: int) -> GenericResult[int]: - return GenericResult(value=x) - - model = compute() - # Initially None - assert model.__class__._cached_context_validator is None - - # After getting validator, it's cached - validator = model._get_context_validator() - assert validator is not None - assert model.__class__._cached_context_validator is validator - - def test_explicit_context_type_override(self): - """context_type can opt into an existing ContextBase subclass.""" - - @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) - def load_data(start_date: date, end_date: date) -> GenericResult[dict]: - return GenericResult(value={}) - - 
model = load_data() - assert model.context_type == DateRangeContext - - -class TestPicklingSupport: - """Tests for pickling support (important for Ray). - - Note: Regular pickle cannot pickle locally-defined classes (functions decorated - inside test methods). cloudpickle CAN handle this, which is why Ray uses it. - All tests here use cloudpickle to match Ray's behavior. - """ - - def test_model_cloudpickle_roundtrip(self): - """Model works with cloudpickle (for Ray).""" - - @Flow.model(context_args=["x", "y"]) - def compute(x: int, y: int, multiplier: int = 2) -> GenericResult[int]: - return GenericResult(value=(x + y) * multiplier) - - model = compute(multiplier=3) - - # cloudpickle roundtrip (what Ray uses) - pickled = cloudpickle.dumps(model) - unpickled = cloudpickle.loads(pickled) - - # Should work after unpickling - result = unpickled.flow.compute(x=1, y=2) - assert result.value == 9 # (1 + 2) * 3 - - def test_model_cloudpickle_simple(self): - """Simple model cloudpickle test.""" - - @Flow.model(context_args=["value"]) - def double(value: int) -> GenericResult[int]: - return GenericResult(value=value * 2) - - model = double() - - pickled = cloudpickle.dumps(model) - unpickled = cloudpickle.loads(pickled) +def test_flow_context_basic_properties(): + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31), label="x") + assert ctx.start_date == date(2024, 1, 1) + assert ctx.end_date == date(2024, 1, 31) + assert ctx.label == "x" + assert dict(ctx) == {"start_date": date(2024, 1, 1), "end_date": date(2024, 1, 31), "label": "x"} - result = unpickled.flow.compute(value=21) - assert result.value == 42 - def test_validator_recreated_after_cloudpickle(self): - """TypeAdapter validator is recreated after cloudpickling.""" +def test_flow_context_value_semantics_and_hash(): + first = FlowContext(x=1, values=[1, 2]) + second = FlowContext(x=1, values=[1, 2]) + third = FlowContext(x=2, values=[1, 2]) - @Flow.model(context_args=["x"]) - def compute(x: 
int) -> GenericResult[int]: - return GenericResult(value=x) + assert first == second + assert first != third + assert len({first, second, third}) == 2 - model = compute() - # Warm up the validator cache - _ = model._get_context_validator() - assert model.__class__._cached_context_validator is not None - # cloudpickle and unpickle - pickled = cloudpickle.dumps(model) - unpickled = cloudpickle.loads(pickled) +def test_flow_context_pickle_and_cloudpickle_roundtrip(): + ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31), tags=frozenset({"a", "b"})) + assert pickle.loads(pickle.dumps(ctx)) == ctx + assert cloudpickle.loads(cloudpickle.dumps(ctx)) == ctx - # Validator should still work (may be lazily recreated) - result = unpickled.flow.compute(x=42) - assert result.value == 42 - def test_flow_context_pickle_standard(self): - """FlowContext works with standard pickle.""" - ctx = FlowContext(x=1, y=2, z="test") +def test_flow_api_introspection_for_from_context_model(): + @Flow.model + def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: + return a + b + c - pickled = pickle.dumps(ctx) - unpickled = pickle.loads(pickled) + model = add(a=10) + assert model.flow.context_inputs == {"b": int, "c": int} + assert model.flow.unbound_inputs == {"b": int} + assert model.flow.bound_inputs == {"a": 10} + assert model.flow.compute(b=2).value == 17 - assert unpickled.x == 1 - assert unpickled.y == 2 - assert unpickled.z == "test" +def test_flow_api_compute_accepts_single_context_or_kwargs_but_not_both(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b -class TestIntegrationWithExistingContextTypes: - """Tests for integration with existing ContextBase subclasses.""" + model = add(a=10) + assert model.flow.compute(b=5).value == 15 + assert model.flow.compute(FlowContext(b=6)).value == 16 - def test_explicit_context_still_works(self): - """Explicit context parameter mode still works.""" + with pytest.raises(TypeError, 
match="either one context object or contextual keyword arguments"): + model.flow.compute(FlowContext(b=5), b=6) - @Flow.model - def load_data(context: DateRangeContext, source: str = "db") -> GenericResult[dict]: - return GenericResult(value={"start": context.start_date, "end": context.end_date, "source": source}) - model = load_data(source="api") - ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - result = model(ctx) +def test_bound_model_with_inputs_static_and_callable(): + @Flow.model + def load_window(start_date: FromContext[date], end_date: FromContext[date]) -> GenericResult[dict]: + return GenericResult(value={"start": start_date, "end": end_date}) - assert result.value["start"] == date(2024, 1, 1) - assert result.value["source"] == "api" + model = load_window() + shifted = model.flow.with_inputs( + start_date=lambda ctx: ctx.start_date - timedelta(days=7), + end_date=date(2024, 1, 31), + ) - def test_flow_context_coerces_to_date_range(self): - """FlowContext can be used with models expecting DateRangeContext.""" + result = shifted(FlowContext(start_date=date(2024, 1, 8), end_date=date(2024, 1, 30))) + assert result.value == {"start": date(2024, 1, 1), "end": date(2024, 1, 31)} - @Flow.model - def load_data(context: DateRangeContext) -> GenericResult[dict]: - return GenericResult(value={"start": context.start_date, "end": context.end_date}) - model = load_data() - # Use FlowContext - should coerce to DateRangeContext - ctx = FlowContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - result = model(ctx) +def test_bound_model_with_inputs_is_branch_local_and_chained(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value - assert result.value["start"] == date(2024, 1, 1) - assert result.value["end"] == date(2024, 1, 31) + @Flow.model + def combine(left: int, right: int, value: FromContext[int]) -> int: + return left + right + value - def test_flow_api_with_explicit_context(self): - 
"""FlowAPI.compute works with explicit context mode.""" + base = source() + left = base.flow.with_inputs(value=lambda ctx: ctx.value + 1) + right = base.flow.with_inputs(value=lambda ctx: ctx.value + 2).flow.with_inputs(value=lambda ctx: ctx.value + 10) + model = combine(left=left, right=right) - @Flow.model - def load_data(context: DateRangeContext, source: str = "db") -> GenericResult[dict]: - return GenericResult(value={"start": context.start_date, "end": context.end_date}) + assert model.flow.compute(value=5).value == (6 + 15 + 5) - model = load_data(source="api") - result = model.flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - assert result.value["start"] == date(2024, 1, 1) - assert result.value["end"] == date(2024, 1, 31) +def test_bound_model_rejects_regular_field_rewrites(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + with pytest.raises(TypeError, match="only accepts contextual fields"): + add(a=1).flow.with_inputs(a=3) -class TestLazy: - """Tests for Lazy (deferred execution with context overrides).""" - def test_lazy_basic(self): - """Lazy wraps a model for deferred execution.""" - from ccflow import Lazy +def test_bound_model_repr_matches_user_facing_api(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b - @Flow.model(context_args=["value"]) - def compute(value: int, multiplier: int = 2) -> GenericResult[int]: - return GenericResult(value=value * multiplier) + model = add(a=1) + bound = model.flow.with_inputs(b=lambda ctx: ctx.b + 1) + assert repr(bound) == f"{model!r}.flow.with_inputs(b=)" - model = compute(multiplier=3) - lazy = Lazy(model) - assert lazy.model is model +def test_bound_model_serialization_roundtrip_preserves_static_transforms(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b - def test_lazy_call_with_static_override(self): - """Lazy.__call__ with static override values.""" - from ccflow import Lazy + bound = 
add(a=10).flow.with_inputs(b=5) + dumped = bound.model_dump(mode="python") + restored = type(bound).model_validate(dumped) - @Flow.model(context_args=["x", "y"]) - def add(x: int, y: int) -> GenericResult[int]: - return GenericResult(value=x + y) + assert restored.flow.compute().value == 15 + assert restored.model.flow.bound_inputs == {"a": 10} - model = add() - lazy_fn = Lazy(model)(y=100) # Override y to 100 - ctx = FlowContext(x=5, y=10) # Original y=10 - result = lazy_fn(ctx) - assert result.value == 105 # x=5 + y=100 (overridden) +def test_regular_callable_models_still_support_with_inputs(): + model = OffsetModel(offset=10) + shifted = model.flow.with_inputs(x=lambda ctx: ctx.x * 2) + assert shifted(NumberContext(x=5)).value == 20 - def test_lazy_call_with_callable_override(self): - """Lazy.__call__ with callable override (computed at runtime).""" - from ccflow import Lazy - @Flow.model(context_args=["value"]) - def double(value: int) -> GenericResult[int]: - return GenericResult(value=value * 2) +def test_flow_api_for_regular_callable_model(): + model = OffsetModel(offset=10) + assert model.flow.compute(x=5).value == 15 + assert model.flow.context_inputs == {"x": int} + assert model.flow.unbound_inputs == {"x": int} + assert model.flow.bound_inputs == {"offset": 10} - model = double() - # Override value to be original value + 10 - lazy_fn = Lazy(model)(value=lambda ctx: ctx.value + 10) - ctx = FlowContext(value=5) - result = lazy_fn(ctx) - assert result.value == 30 # (5 + 10) * 2 = 30 +def test_generated_flow_model_compute_is_thread_safe(): + @Flow.model + def add(a: int, b: FromContext[int], c: FromContext[int]) -> int: + return a + b + c - def test_lazy_with_date_transforms(self): - """Lazy works with date transforms.""" - from ccflow import Lazy + model = add(a=10) - @Flow.model(context_args=["start_date", "end_date"]) - def load_data(start_date: date, end_date: date) -> GenericResult[dict]: - return GenericResult(value={"start": start_date, "end": 
end_date}) + def worker(n: int) -> int: + return model.flow.compute(b=n, c=n + 1).value - model = load_data() + with ThreadPoolExecutor(max_workers=8) as executor: + results = list(executor.map(worker, range(20))) - # Use Lazy to create a transform that shifts dates - lazy_fn = Lazy(model)(start_date=lambda ctx: ctx.start_date - timedelta(days=7), end_date=lambda ctx: ctx.end_date) + assert results == [10 + n + n + 1 for n in range(20)] - ctx = FlowContext(start_date=date(2024, 1, 15), end_date=date(2024, 1, 31)) - result = lazy_fn(ctx) - assert result.value["start"] == date(2024, 1, 8) # 7 days before - assert result.value["end"] == date(2024, 1, 31) +def test_bound_model_restore_is_thread_safe(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b - def test_lazy_multiple_overrides(self): - """Lazy supports multiple overrides at once.""" - from ccflow import Lazy + dumped = add(a=10).flow.with_inputs(b=5).model_dump(mode="python") - @Flow.model(context_args=["a", "b", "c"]) - def compute(a: int, b: int, c: int) -> GenericResult[int]: - return GenericResult(value=a + b + c) + def worker(_: int) -> int: + restored = BoundModel.model_validate(dumped) + return restored.flow.compute().value - model = compute() - lazy_fn = Lazy(model)( - a=10, # Static - b=lambda ctx: ctx.b * 2, # Transform - # c not overridden, uses context value - ) + with ThreadPoolExecutor(max_workers=8) as executor: + results = list(executor.map(worker, range(20))) - ctx = FlowContext(a=1, b=5, c=100) - result = lazy_fn(ctx) - assert result.value == 10 + 10 + 100 # a=10, b=5*2=10, c=100 + assert results == [15] * 20 diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 57a54df..ed18cd5 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -1,2986 +1,434 @@ -"""Tests for Flow.model decorator.""" +"""Focused tests for the FromContext-based Flow.model API.""" +import graphlib from datetime import date, timedelta 
-from unittest import TestCase +import pytest +from pydantic import model_validator from ray.cloudpickle import dumps as rcpdumps, loads as rcploads +import ccflow.flow_model as flow_model_module from ccflow import ( - BaseModel, CallableModel, ContextBase, DateRangeContext, Flow, FlowContext, FlowOptionsOverride, + FromContext, GenericResult, Lazy, ModelRegistry, - ResultBase, ) -from ccflow.evaluators.common import MemoryCacheEvaluator +from ccflow.evaluators import GraphEvaluator class SimpleContext(ContextBase): - """Simple context for testing.""" - - value: int - - -class ExtendedContext(ContextBase): - """Extended context with multiple fields.""" - - x: int - y: str = "default" - - -class MyResult(ResultBase): - """Custom result type for testing.""" - - data: str - - -# ============================================================================= -# Basic Flow.model Tests -# ============================================================================= - - -class TestFlowModelBasic(TestCase): - """Basic Flow.model functionality tests.""" - - def test_simple_model_explicit_context(self): - """Test Flow.model with explicit context parameter.""" - - @Flow.model - def simple_loader(context: SimpleContext, multiplier: int) -> GenericResult[int]: - return GenericResult(value=context.value * multiplier) - - # Create model instance - loader = simple_loader(multiplier=3) - - # Should be a CallableModel - self.assertIsInstance(loader, CallableModel) - - # Execute - ctx = SimpleContext(value=10) - result = loader(ctx) - - self.assertIsInstance(result, GenericResult) - self.assertEqual(result.value, 30) - - def test_model_with_default_params(self): - """Test Flow.model with default parameter values.""" - - @Flow.model - def loader_with_defaults(context: SimpleContext, multiplier: int = 2, prefix: str = "result") -> GenericResult[str]: - return GenericResult(value=f"{prefix}:{context.value * multiplier}") - - # Create with defaults - loader = loader_with_defaults() - 
result = loader(SimpleContext(value=5)) - self.assertEqual(result.value, "result:10") - - # Create with custom values - loader2 = loader_with_defaults(multiplier=3, prefix="custom") - result2 = loader2(SimpleContext(value=5)) - self.assertEqual(result2.value, "custom:15") - - def test_model_context_type_property(self): - """Test that generated model has correct context_type.""" - - @Flow.model - def typed_model(context: ExtendedContext, factor: int) -> GenericResult[int]: - return GenericResult(value=context.x * factor) - - model = typed_model(factor=2) - self.assertEqual(model.context_type, ExtendedContext) - - def test_model_result_type_property(self): - """Test that generated model has correct result_type.""" - - @Flow.model - def custom_result_model(context: SimpleContext) -> MyResult: - return MyResult(data=f"value={context.value}") - - model = custom_result_model() - self.assertEqual(model.result_type, MyResult) - - def test_model_with_no_extra_params(self): - """Test Flow.model with only context parameter.""" - - @Flow.model - def identity_model(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - model = identity_model() - result = model(SimpleContext(value=42)) - self.assertEqual(result.value, 42) - - def test_model_with_flow_options(self): - """Test Flow.model with Flow.call options.""" - - @Flow.model(cacheable=True, validate_result=True) - def cached_model(context: SimpleContext, value: int) -> GenericResult[int]: - return GenericResult(value=value + context.value) - - model = cached_model(value=10) - result = model(SimpleContext(value=5)) - self.assertEqual(result.value, 15) - - def test_model_with_underscore_context(self): - """Test Flow.model with '_' as context parameter (unused context convention).""" - - @Flow.model - def loader(context: SimpleContext, base: int) -> GenericResult[int]: - return GenericResult(value=context.value + base) - - @Flow.model - def consumer(_: SimpleContext, data: int) -> 
GenericResult[int]: - # Context not used directly, just passed to dependency - return GenericResult(value=data * 2) - - load = loader(base=100) - consume = consumer(data=load) - - result = consume(SimpleContext(value=10)) - # loader: 10 + 100 = 110, consumer: 110 * 2 = 220 - self.assertEqual(result.value, 220) - - # Verify context_type is still correct - self.assertEqual(consume.context_type, SimpleContext) - - -# ============================================================================= -# context_args Mode Tests -# ============================================================================= - - -class TestFlowModelContextArgs(TestCase): - """Tests for Flow.model with context_args (unpacked context).""" - - def test_context_args_basic(self): - """Test basic context_args usage.""" - - @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) - def date_range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: - return GenericResult(value=f"{source}:{start_date} to {end_date}") - - loader = date_range_loader(source="db") - - # Explicit context_type keeps compatibility with existing contexts. - self.assertEqual(loader.context_type, DateRangeContext) - - ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - result = loader(ctx) - self.assertEqual(result.value, "db:2024-01-01 to 2024-01-31") - - def test_context_args_custom_context(self): - """Test context_args with custom context type.""" - - @Flow.model(context_args=["x", "y"]) - def unpacked_model(x: int, y: str, multiplier: int = 1) -> GenericResult[str]: - return GenericResult(value=f"{y}:{x * multiplier}") - - model = unpacked_model(multiplier=2) - - # Default context_args mode uses FlowContext unless overridden explicitly. 
- self.assertEqual(model.context_type, FlowContext) - - # Create context with generated type - ctx_type = model.context_type - ctx = ctx_type(x=5, y="test") - - result = model(ctx) - self.assertEqual(result.value, "test:10") - - def test_context_args_with_defaults(self): - """Test context_args where context fields have defaults.""" - - @Flow.model(context_args=["value"]) - def model_with_ctx_default(value: int = 42, extra: str = "foo") -> GenericResult[str]: - return GenericResult(value=f"{extra}:{value}") - - model = model_with_ctx_default() - - # Create context - the generated context should allow default - ctx_type = model.context_type - ctx = ctx_type(value=100) - - result = model(ctx) - self.assertEqual(result.value, "foo:100") - - -# ============================================================================= -# Dependency Tests -# ============================================================================= - - -class TestFlowModelDependencies(TestCase): - """Tests for Flow.model with upstream CallableModel inputs.""" - - def test_simple_dependency(self): - """Test passing an upstream model as a normal parameter.""" - - @Flow.model - def loader(context: SimpleContext, value: int) -> GenericResult[int]: - return GenericResult(value=value + context.value) - - @Flow.model - def consumer( - context: SimpleContext, - data: int, - multiplier: int = 1, - ) -> GenericResult[int]: - return GenericResult(value=data * multiplier) - - # Create pipeline - load = loader(value=10) - consume = consumer(data=load, multiplier=2) - - ctx = SimpleContext(value=5) - result = consume(ctx) - - # loader returns 10 + 5 = 15, consumer multiplies by 2 = 30 - self.assertEqual(result.value, 30) - - def test_dependency_with_direct_value(self): - """Test that dependency-shaped parameters can also take direct values.""" - - @Flow.model - def consumer( - context: SimpleContext, - data: int, - ) -> GenericResult[int]: - return GenericResult(value=data + context.value) - - consume = 
consumer(data=100) - - result = consume(SimpleContext(value=5)) - self.assertEqual(result.value, 105) - - def test_deps_method_generation(self): - """Test that __deps__ method is correctly generated.""" - - @Flow.model - def loader(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - @Flow.model - def consumer( - context: SimpleContext, - data: int, - ) -> GenericResult[int]: - return GenericResult(value=data) - - load = loader() - consume = consumer(data=load) - - ctx = SimpleContext(value=10) - deps = consume.__deps__(ctx) - - # Should have one dependency - self.assertEqual(len(deps), 1) - self.assertEqual(deps[0][0], load) - self.assertEqual(deps[0][1], [ctx]) - - def test_no_deps_when_direct_value(self): - """Test that __deps__ returns empty when direct values used.""" - - @Flow.model - def consumer( - context: SimpleContext, - data: int, - ) -> GenericResult[int]: - return GenericResult(value=data) - - consume = consumer(data=100) - - deps = consume.__deps__(SimpleContext(value=10)) - self.assertEqual(len(deps), 0) - - -# ============================================================================= -# with_inputs Tests -# ============================================================================= - - -class TestFlowModelWithInputs(TestCase): - """Tests for Flow.model with .flow.with_inputs().""" - - def test_transformed_dependency_with_inputs(self): - """Test dependency context transformation via .flow.with_inputs().""" - - @Flow.model - def loader(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - @Flow.model - def consumer(context: SimpleContext, data: int) -> GenericResult[int]: - return GenericResult(value=data * 2) - - load = loader().flow.with_inputs(value=lambda ctx: ctx.value + 10) - consume = consumer(data=load) - - result = consume(SimpleContext(value=5)) - self.assertEqual(result.value, 30) - - def test_with_inputs_changes_dependency_context_in_deps(self): - 
"""Test that BoundModel contributes transformed dependency contexts.""" - - @Flow.model - def loader(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - @Flow.model - def consumer(context: SimpleContext, data: int) -> GenericResult[int]: - return GenericResult(value=data) - - load = loader().flow.with_inputs(value=lambda ctx: ctx.value * 3) - consume = consumer(data=load) - - deps = consume.__deps__(SimpleContext(value=7)) - self.assertEqual(len(deps), 1) - transformed_ctx = deps[0][1][0] - self.assertEqual(transformed_ctx.value, 21) - - def test_date_range_transform_with_inputs(self): - """Test date-range lookback wiring via .flow.with_inputs().""" - - @Flow.model(context_args=["start_date", "end_date"]) - def range_loader(start_date: date, end_date: date, source: str) -> GenericResult[str]: - return GenericResult(value=f"{source}:{start_date}") - - @Flow.model(context_args=["start_date", "end_date"]) - def range_processor( - start_date: date, - end_date: date, - data: str, - ) -> GenericResult[str]: - return GenericResult(value=f"processed:{data}") - - loader = range_loader(source="db").flow.with_inputs(start_date=lambda ctx: ctx.start_date - timedelta(days=1)) - processor = range_processor(data=loader) - - ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) - result = processor(ctx) - self.assertEqual(result.value, "processed:db:2024-01-09") - - -# ============================================================================= -# Pipeline Tests -# ============================================================================= - - -class TestFlowModelPipeline(TestCase): - """Tests for multi-stage pipelines with Flow.model.""" - - def test_three_stage_pipeline(self): - """Test a three-stage computation pipeline.""" - - @Flow.model - def stage1(context: SimpleContext, base: int) -> GenericResult[int]: - return GenericResult(value=context.value + base) - - @Flow.model - def stage2( - context: 
SimpleContext, - input_data: int, - multiplier: int, - ) -> GenericResult[int]: - return GenericResult(value=input_data * multiplier) - - @Flow.model - def stage3( - context: SimpleContext, - input_data: int, - offset: int = 0, - ) -> GenericResult[int]: - return GenericResult(value=input_data + offset) - - # Build pipeline - s1 = stage1(base=100) - s2 = stage2(input_data=s1, multiplier=2) - s3 = stage3(input_data=s2, offset=50) - - ctx = SimpleContext(value=10) - result = s3(ctx) - - # s1: 10 + 100 = 110 - # s2: 110 * 2 = 220 - # s3: 220 + 50 = 270 - self.assertEqual(result.value, 270) - - def test_diamond_dependency_pattern(self): - """Test diamond-shaped dependency pattern.""" - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - @Flow.model - def branch_a( - context: SimpleContext, - data: int, - ) -> GenericResult[int]: - return GenericResult(value=data * 2) - - @Flow.model - def branch_b( - context: SimpleContext, - data: int, - ) -> GenericResult[int]: - return GenericResult(value=data + 100) - - @Flow.model - def merger( - context: SimpleContext, - a: int, - b: int, - ) -> GenericResult[int]: - return GenericResult(value=a + b) - - src = source() - a = branch_a(data=src) - b = branch_b(data=src) - merge = merger(a=a, b=b) - - ctx = SimpleContext(value=10) - result = merge(ctx) - - # source: 10 - # branch_a: 10 * 2 = 20 - # branch_b: 10 + 100 = 110 - # merger: 20 + 110 = 130 - self.assertEqual(result.value, 130) - - -# ============================================================================= -# Integration Tests -# ============================================================================= - - -class TestFlowModelIntegration(TestCase): - """Integration tests for Flow.model with ccflow infrastructure.""" - - def test_registry_integration(self): - """Test that Flow.model models work with ModelRegistry.""" - - @Flow.model - def registrable_model(context: SimpleContext, value: int) -> 
GenericResult[int]: - return GenericResult(value=context.value + value) - - model = registrable_model(value=100) - - registry = ModelRegistry.root().clear() - registry.add("test_model", model) - - retrieved = registry["test_model"] - self.assertEqual(retrieved, model) - - result = retrieved(SimpleContext(value=10)) - self.assertEqual(result.value, 110) - - def test_serialization_dump(self): - """Test that generated models can be serialized.""" - - @Flow.model - def serializable_model(context: SimpleContext, value: int = 42) -> GenericResult[int]: - return GenericResult(value=value) - - model = serializable_model(value=100) - dumped = model.model_dump(mode="python") - - self.assertIn("value", dumped) - self.assertEqual(dumped["value"], 100) - self.assertIn("type_", dumped) - - def test_serialization_roundtrip_preserves_bound_inputs(self): - """Round-tripping should preserve which inputs were bound at construction.""" - - @Flow.model - def add(x: int, y: int) -> int: - return x + y - - model = add(x=10) - dumped = model.model_dump(mode="python") - restored = type(model).model_validate(dumped) - - self.assertEqual(dumped["x"], 10) - self.assertNotIn("y", dumped) - self.assertEqual(restored.flow.bound_inputs, {"x": 10}) - self.assertEqual(restored.flow.unbound_inputs, {"y": int}) - self.assertEqual(restored.flow.compute(y=5).value, 15) - - def test_serialization_roundtrip_preserves_defaults_and_deferred_inputs(self): - """Default-valued params should serialize normally without binding runtime-only inputs.""" - - @Flow.model - def load(start_date: str, source: str = "warehouse") -> str: - return f"{source}:{start_date}" - - model = load() - dumped = model.model_dump(mode="python") - restored = type(model).model_validate(dumped) - - self.assertEqual(dumped["source"], "warehouse") - self.assertNotIn("start_date", dumped) - self.assertEqual(restored.flow.bound_inputs, {"source": "warehouse"}) - self.assertEqual(restored.flow.unbound_inputs, {"start_date": str}) - 
self.assertEqual(restored.flow.compute(start_date="2024-01-01").value, "warehouse:2024-01-01") - - def test_pickle_roundtrip(self): - """Test cloudpickle serialization of generated models.""" - - @Flow.model - def pickleable_model(context: SimpleContext, factor: int) -> GenericResult[int]: - return GenericResult(value=context.value * factor) - - model = pickleable_model(factor=3) - - # Cloudpickle roundtrip (standard pickle won't work for local classes) - pickled = rcpdumps(model, protocol=5) - restored = rcploads(pickled) - - result = restored(SimpleContext(value=10)) - self.assertEqual(result.value, 30) - - def test_mix_with_manual_callable_model(self): - """Test mixing Flow.model with manually defined CallableModel.""" - - class ManualModel(CallableModel): - offset: int - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value + self.offset) - - @Flow.model - def generated_consumer( - context: SimpleContext, - data: int, - multiplier: int, - ) -> GenericResult[int]: - return GenericResult(value=data * multiplier) - - manual = ManualModel(offset=50) - generated = generated_consumer(data=manual, multiplier=2) - - result = generated(SimpleContext(value=10)) - # manual: 10 + 50 = 60 - # generated: 60 * 2 = 120 - self.assertEqual(result.value, 120) - - -# ============================================================================= -# Error Case Tests -# ============================================================================= - - -class TestFlowModelErrors(TestCase): - """Error case tests for Flow.model.""" - - def test_missing_return_type(self): - """Test error when return type annotation is missing.""" - with self.assertRaises(TypeError) as cm: - - @Flow.model - def no_return(context: SimpleContext): - return GenericResult(value=1) - - self.assertIn("return type annotation", str(cm.exception)) - - def test_auto_wrap_plain_return_type(self): - """Test that non-ResultBase return types are 
auto-wrapped in GenericResult.""" - - @Flow.model - def plain_return(context: SimpleContext) -> int: - return context.value * 2 - - model = plain_return() - result = model(SimpleContext(value=5)) - self.assertIsInstance(result, GenericResult) - self.assertEqual(result.value, 10) - - def test_auto_wrap_unwrap_as_dependency(self): - """Test that auto-wrapped model used as dep delivers unwrapped value downstream. - - Auto-wrapped models have result_type=GenericResult (unparameterized). - When used as an auto-detected dep, the framework resolves - the GenericResult to its inner value for the downstream function. - """ - - @Flow.model - def plain_source(context: SimpleContext) -> int: - return context.value * 3 - - @Flow.model - def consumer( - context: SimpleContext, - data: GenericResult[int], # Auto-detected dep - ) -> GenericResult[int]: - # data is auto-unwrapped to the int value by the framework - return GenericResult(value=data + 1) - - src = plain_source() - model = consumer(data=src) - result = model(SimpleContext(value=10)) - # plain_source: 10 * 3 = 30, auto-wrapped to GenericResult(value=30) - # resolve_callable_model unwraps GenericResult -> 30 - # consumer: 30 + 1 = 31 - self.assertEqual(result.value, 31) - - def test_auto_wrap_result_type_property(self): - """Test that auto-wrapped model has GenericResult as result_type.""" - - @Flow.model - def plain_return(context: SimpleContext) -> int: - return context.value - - model = plain_return() - self.assertEqual(model.result_type, GenericResult) - - def test_dynamic_deferred_mode(self): - """Test dynamic deferred mode where what you provide at construction = bound.""" - from ccflow import FlowContext - - @Flow.model - def dynamic_model(value: int, multiplier: int) -> GenericResult[int]: - return GenericResult(value=value * multiplier) - - # Provide 'multiplier' at construction -> it's bound - # Don't provide 'value' -> comes from context - model = dynamic_model(multiplier=3) - - # Check bound vs unbound - 
self.assertEqual(model.flow.bound_inputs, {"multiplier": 3}) - self.assertEqual(model.flow.unbound_inputs, {"value": int}) - - # Call with context providing 'value' - ctx = FlowContext(value=10) - result = model(ctx) - self.assertEqual(result.value, 30) # 10 * 3 - - def test_dynamic_deferred_mode_missing_runtime_inputs_is_clear(self): - """Missing deferred inputs should fail at the framework boundary.""" - - @Flow.model - def dynamic_model(value: int, multiplier: int) -> int: - return value * multiplier - - model = dynamic_model() - - with self.assertRaises(TypeError) as cm: - model.flow.compute() - - self.assertIn("Missing runtime input(s) for dynamic_model: multiplier, value", str(cm.exception)) - - def test_all_defaults_is_valid(self): - """All-default functions should treat those defaults as bound config.""" - from ccflow import FlowContext - - @Flow.model - def all_defaults(value: int = 1, other: str = "x") -> GenericResult[str]: - return GenericResult(value=f"{value}-{other}") - - model = all_defaults() - - self.assertEqual(model.flow.bound_inputs, {"value": 1, "other": "x"}) - self.assertEqual(model.flow.unbound_inputs, {}) - - ctx = FlowContext(value=5, other="y") - result = model(ctx) - self.assertEqual(result.value, "1-x") - - def test_invalid_context_arg(self): - """Test error when context_args refers to non-existent parameter.""" - with self.assertRaises(ValueError) as cm: - - @Flow.model(context_args=["nonexistent"]) - def bad_context_args(x: int) -> GenericResult[int]: - return GenericResult(value=x) - - self.assertIn("nonexistent", str(cm.exception)) - - def test_context_arg_without_annotation(self): - """Test error when context_arg parameter lacks type annotation.""" - with self.assertRaises(ValueError) as cm: - - @Flow.model(context_args=["x"]) - def untyped_context_arg(x) -> GenericResult[int]: - return GenericResult(value=x) - - self.assertIn("type annotation", str(cm.exception)) - - def test_context_type_requires_context_args_mode(self): - 
"""context_type is only valid alongside context_args.""" - with self.assertRaises(TypeError) as cm: - - @Flow.model(context_type=DateRangeContext) - def dynamic_model(value: int) -> GenericResult[int]: - return GenericResult(value=value) - - self.assertIn("context_args", str(cm.exception)) - - def test_context_type_must_cover_context_args(self): - """context_type must expose all named context_args fields.""" - - class StartOnlyContext(ContextBase): - start_date: date - - with self.assertRaises(TypeError) as cm: - - @Flow.model(context_args=["start_date", "end_date"], context_type=StartOnlyContext) - def load_data(start_date: date, end_date: date) -> GenericResult[dict]: - return GenericResult(value={}) - - self.assertIn("end_date", str(cm.exception)) - - -# ============================================================================= -# Validation Tests -# ============================================================================= - - -class TestFlowModelValidation(TestCase): - """Tests for Flow.model validation behavior.""" - - def test_config_validation_rejects_bad_type(self): - """Test that config validator rejects wrong types at construction.""" - - @Flow.model - def typed_config(context: SimpleContext, n_estimators: int = 10) -> GenericResult[int]: - return GenericResult(value=n_estimators) - - with self.assertRaises(TypeError) as cm: - typed_config(n_estimators="banana") - - self.assertIn("n_estimators", str(cm.exception)) - - def test_config_validation_accepts_callable_model(self): - """Test that config validator allows CallableModel values for any field.""" - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - @Flow.model - def consumer(context: SimpleContext, data: int = 0) -> GenericResult[int]: - return GenericResult(value=data) - - # Passing a CallableModel for an int field should not raise - src = source() - model = consumer(data=src) - self.assertIsNotNone(model) - - def 
test_config_validation_accepts_correct_types(self): - """Test that config validator accepts correct types.""" - - @Flow.model - def typed_config(context: SimpleContext, n: int = 10, name: str = "x") -> GenericResult[str]: - return GenericResult(value=f"{name}:{n}") - - # Should not raise - model = typed_config(n=42, name="test") - result = model(SimpleContext(value=1)) - self.assertEqual(result.value, "test:42") - - def test_config_validation_rejects_registry_alias_for_incompatible_type(self): - """Registry aliases should not silently bypass scalar type validation.""" - - class DummyConfig(BaseModel): - x: int = 1 - - registry = ModelRegistry.root() - registry.clear() - try: - registry.add("dummy_config", DummyConfig()) - - @Flow.model - def typed_config(context: SimpleContext, n: int = 10) -> GenericResult[int]: - return GenericResult(value=n) - - with self.assertRaises(TypeError) as cm: - typed_config(n="dummy_config") - - self.assertIn("n", str(cm.exception)) - finally: - registry.clear() - - def test_config_validation_accepts_registry_alias_for_callable_dependency(self): - """Registry aliases still work for CallableModel dependencies.""" - - registry = ModelRegistry.root() - registry.clear() - try: - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 2) - - @Flow.model - def consumer(context: SimpleContext, data: int = 0) -> GenericResult[int]: - return GenericResult(value=data + 1) - - registry.add("source_model", source()) - model = consumer(data="source_model") - - result = model(SimpleContext(value=5)) - self.assertEqual(result.value, 11) - finally: - registry.clear() - - def test_context_type_annotation_mismatch_raises(self): - """context_type validation should reject incompatible field annotations.""" - - class StringIdContext(ContextBase): - item_id: str - - with self.assertRaises(TypeError) as cm: - - @Flow.model(context_args=["item_id"], context_type=StringIdContext) - def 
load(item_id: int) -> int: - return item_id - - self.assertIn("item_id", str(cm.exception)) - self.assertIn("int", str(cm.exception)) - self.assertIn("str", str(cm.exception)) - - def test_model_validate_rejects_bad_scalar_type(self): - """model_validate should reject wrong scalar types, not silently accept them.""" - - @Flow.model - def source(context: SimpleContext, x: int) -> GenericResult[int]: - return GenericResult(value=x) - - cls = type(source(x=1)) - with self.assertRaises(TypeError) as cm: - cls.model_validate({"x": "abc"}) - - self.assertIn("x", str(cm.exception)) - - def test_model_validate_accepts_correct_type(self): - """model_validate should accept correct types.""" - - @Flow.model - def source(context: SimpleContext, x: int) -> GenericResult[int]: - return GenericResult(value=x) - - cls = type(source(x=1)) - restored = cls.model_validate({"x": 42}) - self.assertEqual(restored(SimpleContext(value=0)).value, 42) - - def test_model_validate_rejects_bad_registry_alias(self): - """Typoed registry aliases should not silently pass through model_validate.""" - - registry = ModelRegistry.root() - registry.clear() - try: - - @Flow.model - def consumer(context: SimpleContext, n: int = 10) -> GenericResult[int]: - return GenericResult(value=n) - - cls = type(consumer(n=1)) - # "not_in_registry" is not a valid int and not a valid registry key - with self.assertRaises(TypeError) as cm: - cls.model_validate({"n": "not_in_registry"}) - self.assertIn("n", str(cm.exception)) - finally: - registry.clear() - - def test_context_type_compatible_annotations_accepted(self): - """context_type validation should accept matching or subclass annotations.""" - - # Exact match should work - @Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) - def load_exact(start_date: date, end_date: date) -> str: - return f"{start_date}" - - self.assertIsNotNone(load_exact) - - -# ============================================================================= -# 
BoundModel Tests -# ============================================================================= - - -class TestBoundModel(TestCase): - """Tests for BoundModel and BoundModel.flow.""" - - def test_bound_model_is_callable_model(self): - """BoundModel should be a proper CallableModel subclass.""" - - @Flow.model - def source(x: int) -> int: - return x * 10 - - bound = source().flow.with_inputs(x=lambda ctx: ctx.x * 2) - self.assertIsInstance(bound, CallableModel) - - def test_bound_model_flow_compute(self): - """Test that bound.flow.compute() honors transforms.""" - - @Flow.model - def my_model(x: int, y: int) -> GenericResult[int]: - return GenericResult(value=x + y) - - model = my_model(x=10) - - # Create bound model with y transform - bound = model.flow.with_inputs(y=lambda ctx: getattr(ctx, "y", 0) * 2) - - # flow.compute() should go through BoundModel, applying transform - result = bound.flow.compute(y=5) - # y transform: 5 * 2 = 10, x is bound to 10 - # model: 10 + 10 = 20 - self.assertEqual(result.value, 20) - - def test_bound_model_flow_compute_static_transform(self): - """Test BoundModel.flow.compute() with static value transform.""" - - @Flow.model - def my_model(x: int, y: int) -> GenericResult[int]: - return GenericResult(value=x * y) - - model = my_model(x=7) - bound = model.flow.with_inputs(y=3) - - result = bound.flow.compute(y=999) # y should be overridden by transform - # y is statically bound to 3, x=7 - # 7 * 3 = 21 - self.assertEqual(result.value, 21) - - def test_bound_model_dump_validate_roundtrip_static(self): - """Static transforms survive model_dump → model_validate roundtrip.""" - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 10) - - bound = source().flow.with_inputs(value=42) - dump = bound.model_dump(mode="python") - restored = type(bound).model_validate(dump) - - ctx = SimpleContext(value=1) - self.assertEqual(bound(ctx).value, 420) - 
self.assertEqual(restored(ctx).value, 420) - - def test_bound_model_validate_same_payload_twice(self): - """Validating the same serialized BoundModel payload twice should work both times.""" - from ccflow.flow_model import BoundModel - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 10) - - bound = source().flow.with_inputs(value=42) - dump = bound.model_dump(mode="python") - - r1 = BoundModel.model_validate(dump) - r2 = BoundModel.model_validate(dump) - - ctx = SimpleContext(value=1) - self.assertEqual(r1(ctx).value, 420) - self.assertEqual(r2(ctx).value, 420) - - def test_bound_model_failed_validate_does_not_poison_next_construction(self): - """A failed model_validate must not leak static transforms to subsequent constructions.""" - from ccflow.flow_model import BoundModel - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 10) - - base = source() - - # Attempt a model_validate that will fail (invalid model field) - try: - BoundModel.model_validate( - { - "model": "not-a-real-model", - "_static_transforms": {"value": 42}, - "_input_transforms_token": {"value": "42"}, - } - ) - except Exception: - pass # Expected to fail - - # Now construct a fresh BoundModel normally — must NOT inherit stale transforms - clean = BoundModel(model=base, input_transforms={}) - ctx = SimpleContext(value=1) - self.assertEqual(clean(ctx).value, 10) # 1 * 10, no transform applied - - def test_bound_model_cloudpickle_with_lambda_transform(self): - """BoundModel with lambda transforms should survive cloudpickle round-trip.""" - - @Flow.model - def my_model(x: int, y: int) -> int: - return x + y - - bound = my_model(x=10).flow.with_inputs(y=lambda ctx: ctx.y * 2) - restored = rcploads(rcpdumps(bound, protocol=5)) - - self.assertEqual(restored.flow.compute(y=6).value, 22) - - def test_bound_model_as_dependency(self): - """Test that BoundModel can be 
passed as a dependency to another model.""" - - @Flow.model - def source(x: int) -> GenericResult[int]: - return GenericResult(value=x * 10) - - @Flow.model - def consumer(data: GenericResult[int]) -> GenericResult[int]: - return GenericResult(value=data + 1) - - src = source() - bound_src = src.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 2) - - # Pass BoundModel as a dependency - model = consumer(data=bound_src) - result = model.flow.compute(x=5) - # x transform: 5 * 2 = 10 - # source: 10 * 10 = 100 - # consumer: 100 + 1 = 101 - self.assertEqual(result.value, 101) - - def test_flow_compute_with_upstream_callable_model_dependency(self): - """flow.compute() should resolve upstream generated-model dependencies.""" - - @Flow.model - def source(x: int) -> GenericResult[int]: - return GenericResult(value=x * 10) - - @Flow.model - def consumer(data: GenericResult[int], offset: int = 1) -> int: - return data + offset - - model = consumer(data=source(), offset=3) - self.assertEqual(model.flow.compute(x=5).value, 53) - - def test_bound_model_chained_with_inputs(self): - """Test that chaining with_inputs merges transforms correctly.""" - - @Flow.model - def my_model(x: int, y: int, z: int) -> int: - return x + y + z - - model = my_model() - bound1 = model.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 2) - bound2 = bound1.flow.with_inputs(y=lambda ctx: getattr(ctx, "y", 0) * 3) - - # Both transforms should be active - result = bound2.flow.compute(x=5, y=10, z=1) - # x transform: 5 * 2 = 10 - # y transform: 10 * 3 = 30 - # z from context: 1 - # 10 + 30 + 1 = 41 - self.assertEqual(result.value, 41) - - def test_bound_model_chained_with_inputs_override(self): - """Test that chaining with_inputs allows overriding transforms.""" - - @Flow.model - def my_model(x: int) -> int: - return x - - model = my_model() - bound1 = model.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 2) - bound2 = bound1.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) * 10) - - # 
Second transform should override the first for 'x' - result = bound2.flow.compute(x=5) - self.assertEqual(result.value, 50) # 5 * 10, not 5 * 2 - - def test_bound_model_with_default_args(self): - """with_inputs works when the model has parameters with default values.""" - - @Flow.model - def load(start_date: str, end_date: str, source: str = "warehouse") -> str: - return f"{source}:{start_date}-{end_date}" - - # Bind source at construction, leave dates for context - model = load(source="prod_db") - - # with_inputs transforms a context param; default-valued 'source' stays bound - lookback = model.flow.with_inputs(start_date=lambda ctx: "shifted_" + ctx.start_date) - - result = lookback.flow.compute(start_date="2024-01-01", end_date="2024-06-30") - self.assertEqual(result.value, "prod_db:shifted_2024-01-01-2024-06-30") - - def test_bound_model_with_default_arg_uses_default(self): - """with_inputs should preserve omitted Python defaults as bound config.""" - - @Flow.model - def load(start_date: str, source: str = "warehouse") -> str: - return f"{source}:{start_date}" - - model = load() - - bound = model.flow.with_inputs(start_date=lambda ctx: "shifted_" + ctx.start_date) - - self.assertEqual(model.flow.bound_inputs, {"source": "warehouse"}) - self.assertEqual(model.flow.unbound_inputs, {"start_date": str}) - - result = bound.flow.compute(start_date="2024-01-01") - self.assertEqual(result.value, "warehouse:shifted_2024-01-01") - - def test_bound_model_default_arg_as_dependency(self): - """BoundModel with default args works correctly as a dependency.""" - - @Flow.model - def source(x: int, multiplier: int = 2) -> int: - return x * multiplier - - @Flow.model - def consumer(data: int) -> int: - return data + 1 - - src = source(multiplier=5) - bound_src = src.flow.with_inputs(x=lambda ctx: ctx.x * 10) - model = consumer(data=bound_src) - - result = model.flow.compute(x=3) - # x transform: 3 * 10 = 30 - # source: 30 * 5 (multiplier) = 150 - # consumer: 150 + 1 = 151 - 
self.assertEqual(result.value, 151) - - def test_bound_model_as_lazy_dependency(self): - """Test that BoundModel works as a Lazy dependency.""" - - @Flow.model - def source(x: int) -> GenericResult[int]: - return GenericResult(value=x * 3) - - @Flow.model - def consumer(data: int, slow: Lazy[GenericResult[int]]) -> GenericResult[int]: - if data > 100: - return GenericResult(value=data) - return GenericResult(value=slow()) - - src = source() - bound_src = src.flow.with_inputs(x=lambda ctx: getattr(ctx, "x", 0) + 10) - - # Use BoundModel as lazy dependency - model = consumer(data=5, slow=bound_src) - result = model.flow.compute(x=7) - # data=5 < 100, so slow path: x transform: 7+10=17, source: 17*3=51 - self.assertEqual(result.value, 51) - - def test_differently_transformed_bound_models_have_distinct_cache_keys(self): - """Two BoundModels with different transforms must not collide under caching.""" - - call_counts = {"source": 0} - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - call_counts["source"] += 1 - return GenericResult(value=context.value * 10) - - base = source() - b1 = base.flow.with_inputs(value=lambda ctx: ctx.value + 1) - b2 = base.flow.with_inputs(value=lambda ctx: ctx.value + 2) - evaluator = MemoryCacheEvaluator() - ctx = SimpleContext(value=5) - - with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): - r1 = b1(ctx) - r2 = b2(ctx) - - # b1 transforms value to 6, source: 6*10=60 - # b2 transforms value to 7, source: 7*10=70 - self.assertEqual(r1.value, 60) - self.assertEqual(r2.value, 70) - # Source called twice (once per distinct transformed context) - self.assertEqual(call_counts["source"], 2) - - def test_bound_and_unbound_models_share_memory_cache(self): - """Shifted and unshifted models should share one evaluator cache. 
- - They should not share the same cache key when the effective contexts - differ, but repeated evaluations of either model should still hit the - same underlying MemoryCacheEvaluator instance rather than re-executing. - """ - - call_counts = {"source": 0} - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - call_counts["source"] += 1 - return GenericResult(value=context.value * 10) - - base = source() - shifted = base.flow.with_inputs(value=lambda ctx: ctx.value + 1) - evaluator = MemoryCacheEvaluator() - ctx = SimpleContext(value=5) - - with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): - self.assertEqual(base(ctx).value, 50) - self.assertEqual(shifted(ctx).value, 60) - self.assertEqual(base(ctx).value, 50) - self.assertEqual(shifted(ctx).value, 60) - - # One execution for the unshifted context and one for the shifted context. - self.assertEqual(call_counts["source"], 2) - # Cache has 3 entries: base(ctx), BoundModel(ctx), and base(shifted_ctx). - # BoundModel is a proper CallableModel now, so it gets its own cache entry. 
- self.assertEqual(len(evaluator.cache), 3) - - def test_transform_error_propagates(self): - """A buggy transform should raise, not silently fall back to FlowContext.""" - - @Flow.model - def load(context: DateRangeContext, source: str = "db") -> str: - return f"{source}:{context.start_date}" - - model = load() - # Transform has a typo — ctx.sart_date instead of ctx.start_date - bound = model.flow.with_inputs(start_date=lambda ctx: ctx.sart_date) - - ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - with self.assertRaises(AttributeError): - bound(ctx) - - def test_transform_validation_error_propagates(self): - """If transforms produce invalid context data, the error should surface.""" - from pydantic import ValidationError - - @Flow.model - def load(context: DateRangeContext, source: str = "db") -> str: - return f"{source}:{context.start_date}" - - model = load() - # Transform returns a string where a date is expected - bound = model.flow.with_inputs(start_date="not-a-date") - - ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - # Pydantic validation should raise, not silently fall back to FlowContext - with self.assertRaises(ValidationError): - bound(ctx) - - -class TestFlowModelPipe(TestCase): - """Tests for the ``.pipe(..., param=...)`` convenience API.""" - - def test_pipe_infers_single_required_parameter(self): - """pipe() should infer the only required downstream parameter.""" - - @Flow.model - def source(x: int) -> GenericResult[int]: - return GenericResult(value=x * 10) - - @Flow.model - def consumer(data: int, offset: int = 1) -> int: - return data + offset - - pipeline = source().pipe(consumer, offset=3) - self.assertEqual(pipeline.flow.compute(x=5).value, 53) - - def test_pipe_infers_single_defaulted_parameter(self): - """pipe() should fall back to a single defaulted downstream parameter.""" - - @Flow.model - def source(x: int) -> int: - return x * 10 - - @Flow.model - def consumer(data: int 
= 0) -> int: - return data + 1 - - pipeline = source().pipe(consumer) - self.assertEqual(pipeline.flow.compute(x=5).value, 51) - - def test_pipe_param_disambiguates_multiple_parameters(self): - """param= should identify the downstream argument to bind.""" - - @Flow.model - def source(x: int) -> int: - return x * 10 - - @Flow.model - def combine(left: int, right: int) -> int: - return left + right - - pipeline = source().pipe(combine, param="right", left=7) - self.assertEqual(pipeline.flow.compute(x=5).value, 57) - - def test_pipe_rejects_ambiguous_downstream_stage(self): - """pipe() should require param= when multiple targets are available.""" - - @Flow.model - def source(x: int) -> int: - return x - - @Flow.model - def combine(left: int, right: int) -> int: - return left + right - - with self.assertRaisesRegex( - TypeError, - r"pipe\(\) could not infer a target parameter for combine; unbound candidates are: left, right", - ): - source().pipe(combine) - - def test_manual_callable_model_can_pipe_into_generated_stage(self): - """Hand-written CallableModels should be usable as pipe sources.""" - - class ManualModel(CallableModel): - offset: int - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value + self.offset) - - @Flow.model - def consumer(data: int, multiplier: int) -> int: - return data * multiplier - - pipeline = ManualModel(offset=5).pipe(consumer, multiplier=2) - self.assertEqual(pipeline.flow.compute(value=10).value, 30) - - def test_bound_model_pipe_preserves_downstream_transforms(self): - """pipe() should keep downstream with_inputs transforms intact.""" - - @Flow.model - def source(x: int) -> int: - return x * 10 - - @Flow.model - def consumer(data: int, scale: int) -> int: - return data + scale - - shifted_source = source().flow.with_inputs(x=lambda ctx: ctx.scale + 1) - scaled_consumer = consumer().flow.with_inputs(scale=lambda ctx: ctx.scale * 3) - - pipeline = 
shifted_source.pipe(scaled_consumer) - self.assertEqual(pipeline.flow.compute(scale=2).value, 76) - - -# ============================================================================= -# PEP 563 (from __future__ import annotations) Compatibility Tests -# ============================================================================= - -# These functions are defined at module level to simulate realistic usage. -# Note: We can't use `from __future__ import annotations` at module level -# since it would affect ALL annotations in this file. Instead, we test -# that the annotation resolution code handles string annotations. - - -class TestPEP563Annotations(TestCase): - """Test that Flow.model handles string annotations (PEP 563).""" - - def test_string_annotation_lazy_resolved(self): - """Test that Lazy annotations work even when passed through get_type_hints. - - This verifies the fix for from __future__ import annotations by - confirming the annotation resolution pipeline processes Lazy correctly. 
- """ - # Verify _extract_lazy handles real type objects (resolved by get_type_hints) - from ccflow.flow_model import _extract_lazy - - lazy_int = Lazy[int] - unwrapped, is_lazy = _extract_lazy(lazy_int) - self.assertTrue(is_lazy) - self.assertEqual(unwrapped, int) - - def test_string_annotation_return_type_resolved(self): - """Test that string return type annotations are resolved correctly.""" - - @Flow.model - def model_func(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=42) - - # If annotation resolution works, this should create successfully - model = model_func() - self.assertEqual(model.result_type, GenericResult[int]) - - def test_auto_wrap_with_resolved_annotations(self): - """Test that auto-wrap works with properly resolved type annotations.""" - - @Flow.model - def plain_model(value: int) -> int: - return value * 2 - - model = plain_model() - result = model.flow.compute(value=5) - self.assertEqual(result.value, 10) - self.assertEqual(model.result_type, GenericResult) - - -# ============================================================================= -# Hydra Integration Tests -# ============================================================================= - - -# Define Flow.model functions at module level for Hydra to find them -@Flow.model -def hydra_basic_model(context: SimpleContext, value: int, name: str = "default") -> GenericResult[str]: - """Module-level model for Hydra testing.""" - return GenericResult(value=f"{name}:{context.value + value}") - - -# --- Additional module-level fixtures for Hydra YAML tests --- - - -@Flow.model -def basic_loader(context: SimpleContext, source: str, multiplier: int = 1) -> GenericResult[int]: - """Basic loader that multiplies context value by multiplier.""" - return GenericResult(value=context.value * multiplier) - - -@Flow.model -def string_processor(context: SimpleContext, prefix: str, suffix: str = "") -> GenericResult[str]: - """Process context value into a string with prefix and 
suffix.""" - return GenericResult(value=f"{prefix}{context.value}{suffix}") - - -@Flow.model -def data_source(context: SimpleContext, base_value: int) -> GenericResult[int]: - """Source that provides base data.""" - return GenericResult(value=context.value + base_value) - - -@Flow.model -def data_transformer( - context: SimpleContext, - source: int, - factor: int = 2, -) -> GenericResult[int]: - """Transform data by multiplying with factor.""" - return GenericResult(value=source * factor) - - -@Flow.model -def data_aggregator( - context: SimpleContext, - input_a: int, - input_b: int, - operation: str = "add", -) -> GenericResult[int]: - """Aggregate two inputs.""" - if operation == "add": - return GenericResult(value=input_a + input_b) - elif operation == "multiply": - return GenericResult(value=input_a * input_b) - else: - return GenericResult(value=input_a - input_b) - - -@Flow.model -def pipeline_stage1(context: SimpleContext, initial: int) -> GenericResult[int]: - """First stage of pipeline.""" - return GenericResult(value=context.value + initial) - - -@Flow.model -def pipeline_stage2( - context: SimpleContext, - stage1_output: int, - multiplier: int = 2, -) -> GenericResult[int]: - """Second stage of pipeline.""" - return GenericResult(value=stage1_output * multiplier) - - -@Flow.model -def pipeline_stage3( - context: SimpleContext, - stage2_output: int, - offset: int = 0, -) -> GenericResult[int]: - """Third stage of pipeline.""" - return GenericResult(value=stage2_output + offset) - - -@Flow.model -def date_range_loader( - context: DateRangeContext, - source: str, - include_weekends: bool = True, -) -> GenericResult[dict]: - """Load data for a date range.""" - return GenericResult( - value={ - "source": source, - "start_date": str(context.start_date), - "end_date": str(context.end_date), - } - ) - - -@Flow.model -def date_range_loader_previous_day( - context: DateRangeContext, - source: str, - include_weekends: bool = True, -) -> dict: - """Hydra helper that 
applies a one-day lookback before delegating.""" - shifted = context.model_copy(update={"start_date": context.start_date - timedelta(days=1)}) - return date_range_loader(source=source, include_weekends=include_weekends)(shifted).value - - -@Flow.model -def date_range_processor( - context: DateRangeContext, - raw_data: dict, - normalize: bool = False, -) -> GenericResult[str]: - """Process date range data.""" - prefix = "normalized:" if normalize else "raw:" - return GenericResult(value=f"{prefix}{raw_data['source']}:{raw_data['start_date']} to {raw_data['end_date']}") - - -@Flow.model -def hydra_default_model(context: SimpleContext, value: int = 42) -> GenericResult[int]: - """Module-level model with defaults for Hydra testing.""" - return GenericResult(value=context.value + value) - - -@Flow.model -def hydra_source_model(context: SimpleContext, base: int) -> GenericResult[int]: - """Source model for dependency testing.""" - return GenericResult(value=context.value * base) - - -@Flow.model -def hydra_consumer_model( - context: SimpleContext, - source: int, - factor: int = 1, -) -> GenericResult[int]: - """Consumer model for dependency testing.""" - return GenericResult(value=source * factor) - - -# --- context_args fixtures for Hydra testing --- - - -@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) -def context_args_loader(start_date: date, end_date: date, source: str) -> GenericResult[dict]: - """Loader using context_args with DateRangeContext.""" - return GenericResult( - value={ - "source": source, - "start_date": str(start_date), - "end_date": str(end_date), - } - ) - - -@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) -def context_args_processor( - start_date: date, - end_date: date, - data: dict, - prefix: str = "processed", -) -> GenericResult[str]: - """Processor using context_args with dependency.""" - return GenericResult(value=f"{prefix}:{data['source']}:{data['start_date']} to 
{data['end_date']}") - - -class TestFlowModelHydra(TestCase): - """Tests for Flow.model with Hydra configuration.""" - - def test_hydra_instantiate_basic(self): - """Test that Flow.model factory can be instantiated via Hydra.""" - from hydra.utils import instantiate - from omegaconf import OmegaConf - - # Create config that references the factory function by module path - cfg = OmegaConf.create( - { - "_target_": "ccflow.tests.test_flow_model.hydra_basic_model", - "value": 100, - "name": "test", - } - ) - - # Instantiate via Hydra - model = instantiate(cfg) - - self.assertIsInstance(model, CallableModel) - result = model(SimpleContext(value=10)) - self.assertEqual(result.value, "test:110") - - def test_hydra_instantiate_with_defaults(self): - """Test Hydra instantiation using default parameter values.""" - from hydra.utils import instantiate - from omegaconf import OmegaConf - - cfg = OmegaConf.create( - { - "_target_": "ccflow.tests.test_flow_model.hydra_default_model", - # Not specifying value, should use default - } - ) - - model = instantiate(cfg) - result = model(SimpleContext(value=8)) - self.assertEqual(result.value, 50) - - def test_hydra_instantiate_with_dependency(self): - """Test Hydra instantiation with dependencies.""" - from hydra.utils import instantiate - from omegaconf import OmegaConf - - # Create nested config - cfg = OmegaConf.create( - { - "_target_": "ccflow.tests.test_flow_model.hydra_consumer_model", - "source": { - "_target_": "ccflow.tests.test_flow_model.hydra_source_model", - "base": 10, - }, - "factor": 2, - } - ) - - model = instantiate(cfg) - - result = model(SimpleContext(value=5)) - # source: 5 * 10 = 50, consumer: 50 * 2 = 100 - self.assertEqual(result.value, 100) - - -# ============================================================================= -# Lazy[T] Type Annotation Tests -# ============================================================================= - - -class TestLazyTypeAnnotation(TestCase): - """Tests for Lazy[T] type 
annotation (deferred/conditional evaluation).""" - - def test_lazy_type_annotation_basic(self): - """Lazy[T] param receives a thunk (zero-arg callable). - - The thunk unwraps GenericResult.value, so calling thunk() returns - the inner value (e.g., int), not the GenericResult wrapper. - """ - from ccflow import Lazy - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 10) - - @Flow.model - def consumer( - context: SimpleContext, - data: Lazy[GenericResult[int]], - ) -> GenericResult[int]: - # data() returns the unwrapped value (int) - resolved = data() - return GenericResult(value=resolved + 1) - - src = source() - model = consumer(data=src) - result = model(SimpleContext(value=5)) - - # source: 5 * 10 = 50, consumer: 50 + 1 = 51 - self.assertEqual(result.value, 51) - - def test_lazy_conditional_evaluation(self): - """Mirror the smart_training example: lazy dep only evaluated if needed. - - Note: Non-lazy CallableModel deps are auto-resolved and their .value is - unwrapped by the framework (auto-detected dep resolution). So 'fast' - receives the unwrapped int, while 'slow' receives a thunk that returns - the unwrapped value (GenericResult.value) when called. 
- """ - from ccflow import Lazy - - call_counts = {"fast": 0, "slow": 0} - - @Flow.model - def fast_path(context: SimpleContext) -> GenericResult[int]: - call_counts["fast"] += 1 - return GenericResult(value=context.value) - - @Flow.model - def slow_path(context: SimpleContext) -> GenericResult[int]: - call_counts["slow"] += 1 - return GenericResult(value=context.value * 100) - - @Flow.model - def smart_selector( - context: SimpleContext, - fast: GenericResult[int], # Auto-resolved: receives unwrapped int - slow: Lazy[GenericResult[int]], # Lazy: receives thunk returning unwrapped value - threshold: int = 10, - ) -> GenericResult[int]: - # fast is auto-unwrapped to the int value by the framework - if fast > threshold: - return GenericResult(value=fast) - else: - return GenericResult(value=slow()) - - fast = fast_path() - slow = slow_path() - - # Case 1: fast path sufficient (value > threshold) - model = smart_selector(fast=fast, slow=slow, threshold=10) - result = model(SimpleContext(value=20)) - self.assertEqual(result.value, 20) - self.assertEqual(call_counts["fast"], 1) - self.assertEqual(call_counts["slow"], 0) # Never called! 
- - # Case 2: fast path insufficient (value <= threshold), slow triggered - call_counts["fast"] = 0 - model2 = smart_selector(fast=fast, slow=slow, threshold=100) - result2 = model2(SimpleContext(value=5)) - self.assertEqual(result2.value, 500) # 5 * 100 - self.assertEqual(call_counts["fast"], 1) - self.assertEqual(call_counts["slow"], 1) - - def test_lazy_thunk_caches_result(self): - """Repeated calls to a thunk return the same value without re-evaluation.""" - from ccflow import Lazy - - call_counts = {"source": 0} - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - call_counts["source"] += 1 - return GenericResult(value=context.value * 10) - - @Flow.model - def consumer( - context: SimpleContext, - data: Lazy[GenericResult[int]], - ) -> GenericResult[int]: - # Call thunk multiple times — returns the unwrapped int - val1 = data() - val2 = data() - val3 = data() - self.assertEqual(val1, val2) - self.assertEqual(val2, val3) - return GenericResult(value=val1) - - src = source() - model = consumer(data=src) - result = model(SimpleContext(value=5)) - self.assertEqual(result.value, 50) - self.assertEqual(call_counts["source"], 1) # Called only once despite 3 thunk() calls - - def test_lazy_with_direct_value(self): - """Pre-computed (non-CallableModel) value wrapped in trivial thunk.""" - from ccflow import Lazy - - @Flow.model - def consumer( - context: SimpleContext, - data: Lazy[int], - ) -> GenericResult[int]: - # data is a thunk even though the underlying value is a plain int - return GenericResult(value=data() * 2) - - model = consumer(data=42) - result = model(SimpleContext(value=0)) - self.assertEqual(result.value, 84) - - def test_lazy_dep_excluded_from_deps(self): - """__deps__ does NOT include lazy dependencies.""" - from ccflow import Lazy - - @Flow.model - def eager_source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - @Flow.model - def lazy_source(context: SimpleContext) -> 
GenericResult[int]: - return GenericResult(value=context.value * 10) - - @Flow.model - def consumer( - context: SimpleContext, - eager: GenericResult[int], # Auto-resolved, unwrapped to int - lazy_dep: Lazy[GenericResult[int]], # Thunk, returns unwrapped value - ) -> GenericResult[int]: - return GenericResult(value=eager + lazy_dep()) - - eager = eager_source() - lazy = lazy_source() - model = consumer(eager=eager, lazy_dep=lazy) - - ctx = SimpleContext(value=5) - deps = model.__deps__(ctx) - - # Only eager dep should be in __deps__ - self.assertEqual(len(deps), 1) - self.assertIs(deps[0][0], eager) - - def test_lazy_eager_dep_still_pre_evaluated(self): - """Non-lazy deps are still eagerly resolved via __deps__.""" - from ccflow import Lazy - - call_counts = {"eager": 0, "lazy": 0} - - @Flow.model - def eager_source(context: SimpleContext) -> GenericResult[int]: - call_counts["eager"] += 1 - return GenericResult(value=context.value) - - @Flow.model - def lazy_source(context: SimpleContext) -> GenericResult[int]: - call_counts["lazy"] += 1 - return GenericResult(value=context.value * 10) - - @Flow.model - def consumer( - context: SimpleContext, - eager: GenericResult[int], # Auto-resolved, unwrapped to int - lazy_dep: Lazy[GenericResult[int]], # Thunk, returns unwrapped value - ) -> GenericResult[int]: - # eager is auto-unwrapped to int, lazy_dep() returns unwrapped value - return GenericResult(value=eager + lazy_dep()) - - model = consumer(eager=eager_source(), lazy_dep=lazy_source()) - result = model(SimpleContext(value=5)) - - self.assertEqual(result.value, 55) # 5 + 50 - self.assertEqual(call_counts["eager"], 1) - self.assertEqual(call_counts["lazy"], 1) - - def test_lazy_in_dynamic_deferred_mode(self): - """Lazy[T] works in dynamic deferred mode (no context_args).""" - from ccflow import FlowContext, Lazy - - call_counts = {"source": 0} - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - call_counts["source"] += 1 - return 
GenericResult(value=context.value * 10) - - @Flow.model - def consumer( - value: int, - data: Lazy[GenericResult[int]], - ) -> GenericResult[int]: - if value > 10: - return GenericResult(value=value) - return GenericResult(value=data()) # data() returns unwrapped int - - # value comes from context, data is bound at construction - model = consumer(data=source()) - result = model(FlowContext(value=20)) # value > 10, lazy not called - self.assertEqual(result.value, 20) - self.assertEqual(call_counts["source"], 0) - - def test_lazy_in_context_args_mode(self): - """Lazy[T] works with explicit context_args.""" - from ccflow import FlowContext, Lazy - - @Flow.model(context_args=["x"]) - def source(x: int) -> GenericResult[int]: - return GenericResult(value=x * 10) - - @Flow.model(context_args=["x"]) - def consumer( - x: int, - data: Lazy[GenericResult[int]], - ) -> GenericResult[int]: - return GenericResult(value=x + data()) # data() returns unwrapped int - - model = consumer(data=source()) - result = model(FlowContext(x=5)) - self.assertEqual(result.value, 55) # 5 + 50 - - def test_lazy_never_evaluated_if_not_called(self): - """If thunk is never called, the dependency is never evaluated.""" - from ccflow import Lazy - - call_counts = {"source": 0} - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - call_counts["source"] += 1 - return GenericResult(value=context.value) - - @Flow.model - def consumer( - context: SimpleContext, - data: Lazy[GenericResult[int]], - ) -> GenericResult[int]: - # Never call data() - return GenericResult(value=42) - - model = consumer(data=source()) - result = model(SimpleContext(value=5)) - self.assertEqual(result.value, 42) - self.assertEqual(call_counts["source"], 0) - - def test_lazy_with_upstream_model(self): - """Lazy[T] works when bound to an upstream model.""" - from ccflow import Lazy - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 10) - - 
@Flow.model - def consumer( - context: SimpleContext, - data: Lazy[GenericResult[int]], - ) -> GenericResult[int]: - return GenericResult(value=data() + 1) # data() returns unwrapped int - - src = source() - model = consumer(data=src) - - # Lazy dep should NOT be in __deps__ - deps = model.__deps__(SimpleContext(value=5)) - self.assertEqual(len(deps), 0) - - result = model(SimpleContext(value=5)) - self.assertEqual(result.value, 51) # 50 + 1 - - -# ============================================================================= -# Bug Fix Regression Tests -# ============================================================================= - - -class TestFlowModelBugFixes(TestCase): - """Regression tests for four bugs identified during code review.""" - - # ----- Issue 1: .flow.compute() drops context defaults ----- - - def test_compute_respects_explicit_context_defaults(self): - """Mode 1: compute(x=1) should use ExtendedContext's default y='default'.""" - - @Flow.model - def model_fn(context: ExtendedContext, factor: int = 1) -> str: - return f"{context.x}-{context.y}-{factor}" - - model = model_fn() - result = model.flow.compute(x=1) - self.assertEqual(result.value, "1-default-1") - - def test_compute_respects_context_args_defaults(self): - """Mode 2: compute(x=1) should use function default y=42.""" - - @Flow.model(context_args=["x", "y"]) - def model_fn(x: int, y: int = 42) -> int: - return x + y - - model = model_fn() - result = model.flow.compute(x=1) - self.assertEqual(result.value, 43) - - def test_unbound_inputs_excludes_context_args_with_defaults(self): - """Mode 2: unbound_inputs should not include context_args that have function defaults.""" - - @Flow.model(context_args=["x", "y"]) - def model_fn(x: int, y: int = 42) -> int: - return x + y - - model = model_fn() - self.assertEqual(model.flow.unbound_inputs, {"x": int}) - - def test_unbound_inputs_excludes_context_type_defaults(self): - """Mode 1: unbound_inputs should not include context fields that have 
defaults.""" - - @Flow.model - def model_fn(context: ExtendedContext) -> str: - return f"{context.x}-{context.y}" - - model = model_fn() - # ExtendedContext has x: int (required) and y: str = "default" - self.assertEqual(model.flow.unbound_inputs, {"x": int}) - - def test_context_type_rejects_required_field_with_function_default(self): - """Decoration should fail when function has default but context_type requires the field.""" - - class StrictContext(ContextBase): - x: int # required - - with self.assertRaises(TypeError) as cm: - - @Flow.model(context_args=["x"], context_type=StrictContext) - def model_fn(x: int = 5) -> int: - return x - - self.assertIn("x", str(cm.exception)) - self.assertIn("requires", str(cm.exception)) - - def test_context_type_accepts_optional_field_with_function_default(self): - """Both context_type and function have defaults — should work.""" - - class OptionalContext(ContextBase): - x: int = 10 - - @Flow.model(context_args=["x"], context_type=OptionalContext) - def model_fn(x: int = 5) -> int: - return x - - model = model_fn() - result = model(OptionalContext()) - self.assertEqual(result.value, 10) # context default wins - - # ----- Issue 2: Lazy[...] 
broken in dynamic deferred mode ----- - - def test_lazy_from_runtime_context_in_dynamic_mode(self): - """Lazy[int] provided via FlowContext should be wrapped in a thunk.""" - - @Flow.model - def model_fn(x: int, y: Lazy[int]) -> int: - return x + y() - - model = model_fn(x=10) - result = model(FlowContext(y=32)) - self.assertEqual(result.value, 42) - - def test_callable_model_from_runtime_context_in_dynamic_mode(self): - """CallableModel provided in FlowContext should be resolved.""" - - @Flow.model - def source(value: int) -> int: - return value * 10 - - @Flow.model - def consumer(x: int, data: int) -> int: - return x + data - - model = consumer(x=1) - src = source() - result = model(FlowContext(data=src, value=5)) - # source resolves with value=5 → 50, consumer: 1 + 50 = 51 - self.assertEqual(result.value, 51) - - # ----- Issue 3: FlowContext-backed models skip schema validation ----- - - def test_direct_call_validates_flowcontext_dynamic_mode(self): - """Dynamic mode: FlowContext(y='hello') for int param should raise TypeError.""" - - @Flow.model - def model_fn(x: int, y: int) -> int: - return x + y - - model = model_fn() - with self.assertRaises(TypeError) as cm: - model(FlowContext(x=1, y="hello")) - - self.assertIn("y", str(cm.exception)) - - def test_direct_call_validates_flowcontext_context_args_mode(self): - """context_args mode: FlowContext(x='hello') for int param should raise TypeError.""" - - @Flow.model(context_args=["x"]) - def model_fn(x: int) -> int: - return x - - model = model_fn() - with self.assertRaises(TypeError) as cm: - model(FlowContext(x="hello")) - - self.assertIn("x", str(cm.exception)) - - def test_with_inputs_validates_transformed_fields_dynamic(self): - """Dynamic mode: with_inputs(y='hello') for int param should raise TypeError.""" - - @Flow.model - def model_fn(x: int, y: int) -> int: - return x + y - - model = model_fn(x=1) - bound = model.flow.with_inputs(y="hello") - - with self.assertRaises(TypeError) as cm: - 
bound(FlowContext()) - - self.assertIn("y", str(cm.exception)) - - def test_with_inputs_validates_transformed_fields_context_args(self): - """context_args mode: with_inputs(x='hello') for int param should raise TypeError.""" - - @Flow.model(context_args=["x"]) - def model_fn(x: int) -> int: - return x - - model = model_fn() - bound = model.flow.with_inputs(x="hello") - - with self.assertRaises(TypeError) as cm: - bound(FlowContext()) - - self.assertIn("x", str(cm.exception)) - - # ----- Issue 4: Registry-name resolution too aggressive for union strings ----- - - def test_registry_resolution_skips_union_str_annotation(self): - """Union[str, int] field with a registry key string should keep the string.""" - from typing import Union - - registry = ModelRegistry.root() - registry.clear() - try: - - @Flow.model - def dummy(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=1) - - registry.add("my_key", dummy()) - - @Flow.model - def consumer(context: SimpleContext, tag: Union[str, int] = "none") -> str: - return f"tag={tag}" - - model = consumer(tag="my_key") - result = model(SimpleContext(value=0)) - self.assertEqual(result.value, "tag=my_key") - finally: - registry.clear() - - def test_registry_resolution_skips_optional_str_annotation(self): - """Optional[str] field with a registry key string should keep the string.""" - from typing import Optional - - registry = ModelRegistry.root() - registry.clear() - try: - - @Flow.model - def dummy(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=1) - - registry.add("my_key", dummy()) - - @Flow.model - def consumer(context: SimpleContext, label: Optional[str] = None) -> str: - return f"label={label}" - - model = consumer(label="my_key") - result = model(SimpleContext(value=0)) - self.assertEqual(result.value, "label=my_key") - finally: - registry.clear() - - def test_registry_resolution_skips_union_annotated_str(self): - """Union[Annotated[str, ...], int] field with a registry 
key should keep the string.""" - from typing import Annotated, Union - - registry = ModelRegistry.root() - registry.clear() - try: - - @Flow.model - def dummy(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=1) - - registry.add("my_key", dummy()) - - @Flow.model - def consumer(context: SimpleContext, tag: Union[Annotated[str, "label"], int] = "none") -> str: - return f"tag={tag}" - - model = consumer(tag="my_key") - result = model(SimpleContext(value=0)) - self.assertEqual(result.value, "tag=my_key") - finally: - registry.clear() - - -# ============================================================================= -# Coverage Gap Tests -# ============================================================================= - - -class TestExtractLazyLoopBody(TestCase): - """Group 1: _extract_lazy loop body with non-LazyMarker metadata.""" - - def test_annotated_with_extra_metadata_before_lazy_marker(self): - """Annotated type where _LazyMarker is NOT the first metadata element.""" - from typing import Annotated - - from ccflow.flow_model import _extract_lazy, _LazyMarker - - # _LazyMarker is the second metadata element — loop must iterate past "other" - ann = Annotated[int, "other_metadata", _LazyMarker()] - base_type, is_lazy = _extract_lazy(ann) - self.assertTrue(is_lazy) - self.assertIs(base_type, int) - - def test_annotated_without_lazy_marker(self): - """Annotated type with no _LazyMarker returns is_lazy=False.""" - from typing import Annotated - - from ccflow.flow_model import _extract_lazy - - ann = Annotated[int, "just_metadata"] - base_type, is_lazy = _extract_lazy(ann) - self.assertFalse(is_lazy) - - def test_lazy_type_annotation_with_extra_annotated(self): - """End-to-end: Lazy wrapping of an Annotated type.""" - - @Flow.model - def model_with_lazy( - x: int, - dep: Lazy[int], - ) -> int: - return x + dep() - - @Flow.model - def upstream(x: int) -> int: - return x * 10 - - model = model_with_lazy(x=1, dep=upstream()) - result = 
model.flow.compute(x=1) - self.assertEqual(result.value, 11) - - def test_lazy_dep_returning_custom_result(self): - """Lazy dep returning custom ResultBase (not GenericResult) should return raw result.""" - - @Flow.model - def upstream(context: SimpleContext) -> MyResult: - return MyResult(data=f"v={context.value}") - - @Flow.model - def consumer(context: SimpleContext, dep: Lazy[MyResult]) -> GenericResult[str]: - result = dep() - return GenericResult(value=result.data) - - model = consumer(dep=upstream()) - result = model(SimpleContext(value=42)) - self.assertEqual(result.value, "v=42") - - -class TestTransformReprNamedCallable(TestCase): - """Group 2: _transform_repr with a named callable.""" - - def test_named_function_transform_in_repr(self): - """Named functions should appear in BoundModel repr wrapped in angle brackets.""" - from ccflow.flow_model import _transform_repr - - def my_custom_transform(ctx): - return ctx.value + 1 - - result = _transform_repr(my_custom_transform) - self.assertIn("my_custom_transform", result) - self.assertTrue(result.startswith("<")) - self.assertTrue(result.endswith(">")) - - def test_static_value_repr(self): - """Static (non-callable) values should use repr().""" - from ccflow.flow_model import _transform_repr - - self.assertEqual(_transform_repr(42), "42") - self.assertEqual(_transform_repr("hello"), "'hello'") - - -class TestBoundFieldNamesFallback(TestCase): - """Group 3: _bound_field_names fallback for objects without model_fields_set.""" - - def test_fallback_to_bound_fields_attr(self): - from ccflow.flow_model import _bound_field_names - - class FakeModel: - _bound_fields = {"x", "y"} - - result = _bound_field_names(FakeModel()) - self.assertEqual(result, {"x", "y"}) - - def test_fallback_no_attrs(self): - from ccflow.flow_model import _bound_field_names - - class Empty: - pass - - result = _bound_field_names(Empty()) - self.assertEqual(result, set()) - - -class TestRuntimeInputNamesEmpty(TestCase): - """Group 4: 
_runtime_input_names when all_param_names is empty.""" - - def test_non_flow_model_returns_empty(self): - from ccflow.flow_model import _runtime_input_names - - class ManualModel(CallableModel): - offset: int - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value + self.offset) - - model = ManualModel(offset=5) - self.assertEqual(_runtime_input_names(model), set()) - - -class TestRegistryCandidateAllowed(TestCase): - """Group 5: _registry_candidate_allowed TypeAdapter success path.""" + value: int - def test_non_callable_model_passes_type_check(self): - """Registry value that is not a CallableModel but passes TypeAdapter validation.""" - from ccflow.flow_model import _registry_candidate_allowed - # int value passes TypeAdapter(int).validate_python - self.assertTrue(_registry_candidate_allowed(int, 42)) +class ParentRangeContext(ContextBase): + start_date: date + end_date: date - def test_non_callable_model_fails_type_check(self): - from ccflow.flow_model import _registry_candidate_allowed - self.assertFalse(_registry_candidate_allowed(int, "not_an_int")) +class RichRangeContext(ParentRangeContext): + label: str = "child" -class TestConcreteContextTypeOptional(TestCase): - """Group 6: _concrete_context_type with Optional/Union types.""" +class OrderedContext(ContextBase): + a: int + b: int - def test_optional_context_type(self): - """Optional[T] has NoneType that should be skipped to find T.""" - from typing import Optional + @model_validator(mode="after") + def _validate_order(self): + if self.a > self.b: + raise ValueError("a must be <= b") + return self - from ccflow.flow_model import _concrete_context_type - # Optional[SimpleContext] = Union[SimpleContext, None] - # The NoneType arg must be skipped (line 196-197) - result = _concrete_context_type(Optional[SimpleContext]) - self.assertIs(result, SimpleContext) +@Flow.model +def basic_loader(context: SimpleContext, source: str, multiplier: int) 
-> GenericResult[int]: + return GenericResult(value=context.value * multiplier) - def test_union_with_none_first(self): - """Union[None, T] should skip NoneType and find T.""" - from typing import Union - from ccflow.flow_model import _concrete_context_type +@Flow.model +def string_processor(context: SimpleContext, prefix: str = "value=", suffix: str = "!") -> GenericResult[str]: + return GenericResult(value=f"{prefix}{context.value}{suffix}") - # NoneType comes first, must be skipped - result = _concrete_context_type(Union[None, SimpleContext]) - self.assertIs(result, SimpleContext) - def test_union_context_type(self): - from typing import Union +@Flow.model +def data_source(context: SimpleContext, base_value: int) -> GenericResult[int]: + return GenericResult(value=context.value + base_value) - from ccflow.flow_model import _concrete_context_type - result = _concrete_context_type(Union[SimpleContext, None]) - self.assertIs(result, SimpleContext) +@Flow.model +def data_transformer(context: SimpleContext, source: int, factor: int) -> GenericResult[int]: + return GenericResult(value=source * factor) - def test_union_no_context_base(self): - from typing import Union - from ccflow.flow_model import _concrete_context_type +@Flow.model +def data_aggregator(context: SimpleContext, input_a: int, input_b: int, operation: str = "add") -> GenericResult[int]: + if operation == "add": + return GenericResult(value=input_a + input_b) + raise ValueError(f"unsupported operation: {operation}") - result = _concrete_context_type(Union[int, str]) - self.assertIsNone(result) - def test_returns_none_for_non_type(self): - from ccflow.flow_model import _concrete_context_type +@Flow.model +def pipeline_stage1(context: SimpleContext, initial: int) -> GenericResult[int]: + return GenericResult(value=context.value + initial) - result = _concrete_context_type("not_a_type") - self.assertIsNone(result) +@Flow.model +def pipeline_stage2(context: SimpleContext, stage1_output: int, multiplier: int) 
-> GenericResult[int]: + return GenericResult(value=stage1_output * multiplier) -class TestBuildConfigValidatorsException(TestCase): - """Group 7: _build_config_validators when TypeAdapter fails.""" - def test_unadaptable_type_skipped(self): - """Types that TypeAdapter can't handle should be silently skipped.""" - from ccflow.flow_model import _build_config_validators +@Flow.model +def pipeline_stage3(context: SimpleContext, stage2_output: int, offset: int) -> GenericResult[int]: + return GenericResult(value=stage2_output + offset) - # type(...) (EllipsisType) makes TypeAdapter fail - validatable, validators = _build_config_validators({"x": int, "y": type(...)}) - self.assertIn("x", validatable) - self.assertNotIn("y", validatable) - self.assertIn("x", validators) - self.assertNotIn("y", validators) +@Flow.model +def date_range_loader_previous_day( + source: str, + start_date: FromContext[date], + end_date: FromContext[date], + include_weekends: bool = False, +) -> GenericResult[dict]: + del include_weekends + return GenericResult( + value={ + "source": source, + "start_date": str(start_date - timedelta(days=1)), + "end_date": str(end_date), + } + ) -class TestCoerceContextValueNoValidator(TestCase): - """Group 8: _coerce_context_value early return for fields without validators.""" - def test_field_without_validator_passes_through(self): - from ccflow.flow_model import _coerce_context_value +@Flow.model +def date_range_processor(context: DateRangeContext, raw_data: dict, normalize: bool = False) -> GenericResult[str]: + prefix = "normalized:" if normalize else "raw:" + return GenericResult(value=f"{prefix}{raw_data['source']}:{raw_data['start_date']} to {raw_data['end_date']}") - # When name is not in validators, value should pass through unchanged - result = _coerce_context_value("unknown_field", 42, {}, {}) - self.assertEqual(result, 42) +@Flow.model +def contextual_loader(source: str, start_date: FromContext[date], end_date: FromContext[date]) -> 
GenericResult[dict]: + return GenericResult( + value={ + "source": source, + "start_date": str(start_date), + "end_date": str(end_date), + } + ) -class TestGeneratedModelClassFactoryPath(TestCase): - """Group 9: _generated_model_class when stage has no generated model.""" - def test_returns_none_for_plain_callable(self): - from ccflow.flow_model import _generated_model_class +@Flow.model +def contextual_processor( + prefix: str, + data: dict, + start_date: FromContext[date], + end_date: FromContext[date], +) -> GenericResult[str]: + del start_date, end_date + return GenericResult(value=f"{prefix}:{data['source']}:{data['start_date']} to {data['end_date']}") - def plain_func(): - pass - self.assertIsNone(_generated_model_class(plain_func)) +def test_from_context_anchor_behavior(): + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b + assert foo(a=11).flow.compute(b=12).value == 23 + assert foo(a=11, b=12).flow.compute().value == 23 -class TestDescribePipeStagePaths(TestCase): - """Group 10: _describe_pipe_stage for different stage types.""" + with pytest.raises(TypeError, match="compute\\(\\) only accepts contextual inputs"): + foo().flow.compute(a=11, b=12) - def test_generated_model_instance(self): - from ccflow.flow_model import _describe_pipe_stage - @Flow.model - def my_stage(x: int) -> int: - return x +def test_regular_param_accepts_upstream_model(): + @Flow.model + def source(value: FromContext[int], offset: int) -> int: + return value + offset - desc = _describe_pipe_stage(my_stage()) - self.assertIn("my_stage", desc) + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b - def test_callable_stage(self): - from ccflow.flow_model import _describe_pipe_stage + model = foo(a=source(offset=5)) + assert model.flow.compute(FlowContext(value=7, b=12)).value == 24 - @Flow.model - def factory_stage(x: int) -> int: - return x - desc = _describe_pipe_stage(factory_stage) - self.assertIn("factory_stage", desc) +def 
test_contextual_param_rejects_callable_model(): + @Flow.model + def source(context: SimpleContext, offset: int) -> GenericResult[int]: + return GenericResult(value=context.value + offset) - def test_non_callable_stage(self): - from ccflow.flow_model import _describe_pipe_stage + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b - desc = _describe_pipe_stage(42) - self.assertEqual(desc, "42") + with pytest.raises(TypeError, match="cannot be bound to a CallableModel"): + foo(a=1, b=source(offset=2)) -class TestInferPipeParamAmbiguousDefaults(TestCase): - """Cover _infer_pipe_param fallback path with multiple defaulted candidates.""" +def test_contextual_construction_defaults_and_bound_inputs(): + @Flow.model + def foo(a: int, b: FromContext[int]) -> int: + return a + b - def test_ambiguous_defaulted_candidates(self): - """When all candidates have defaults but multiple are unoccupied.""" + model = foo(a=11, b=12) + assert model.flow.bound_inputs == {"a": 11, "b": 12} + assert model.flow.context_inputs == {"b": int} + assert model.flow.unbound_inputs == {} + assert model.flow.compute().value == 23 - @Flow.model - def source(x: int) -> int: - return x - @Flow.model - def consumer(a: int = 1, b: int = 2) -> int: - return a + b +def test_contextual_function_defaults_remain_contextual(): + @Flow.model + def foo(a: int, b: FromContext[int] = 5) -> int: + return a + b - # Both a and b have defaults, both are unoccupied -> ambiguous - with self.assertRaisesRegex(TypeError, "could not infer a target parameter"): - source().pipe(consumer) + model = foo(a=2) + assert model.flow.bound_inputs == {"a": 2} + assert model.flow.context_inputs == {"b": int} + assert model.flow.unbound_inputs == {} + assert model.flow.compute().value == 7 + assert model.flow.compute(b=10).value == 12 -class TestPipeErrorPaths(TestCase): - """Group 11: pipe() error paths not covered by existing tests.""" +def test_context_type_accepts_richer_subclass_for_from_context(): + 
@Flow.model(context_type=ParentRangeContext) + def span_days(multiplier: int, start_date: FromContext[date], end_date: FromContext[date]) -> int: + return multiplier * ((end_date - start_date).days + 1) - def test_pipe_non_callable_model_source(self): - """pipe() should reject non-CallableModel source.""" - from ccflow.flow_model import pipe_model + model = span_days(multiplier=2) + assert model.flow.compute(start_date="2024-01-01", end_date="2024-01-03").value == 6 + assert model(RichRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 4), label="x")).value == 8 - @Flow.model - def consumer(data: int) -> int: - return data - with self.assertRaisesRegex(TypeError, "pipe\\(\\) source must be a CallableModel"): - pipe_model("not_a_model", consumer) +def test_context_type_validation_applies_to_resolved_contextual_values(): + @Flow.model(context_type=OrderedContext) + def add(a: FromContext[int], b: FromContext[int]) -> int: + return a + b - def test_pipe_non_flow_model_target(self): - """pipe() should reject non-@Flow.model target.""" - from ccflow.flow_model import pipe_model + with pytest.raises(ValueError, match="a must be <= b"): + add().flow.compute(a=2, b=1) - @Flow.model - def source(x: int) -> int: - return x + with pytest.raises(ValueError, match="a must be <= b"): + add(a=2, b=1).flow.compute() - class ManualTarget(CallableModel): - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=0) - with self.assertRaisesRegex(TypeError, "pipe\\(\\) only supports downstream stages"): - pipe_model(source(), ManualTarget()) +def test_explicit_context_interop_accepts_pep604_optional_annotation(): + @Flow.model + def loader(context: DateRangeContext | None, source: str = "db") -> GenericResult[str]: + assert context is not None + return GenericResult(value=f"{source}:{context.start_date}:{context.end_date}") - def test_pipe_invalid_param_name(self): - """pipe() should reject invalid target parameter 
names.""" + model = loader(source="api") + assert model.flow.compute(start_date="2024-01-01", end_date="2024-01-02").value == "api:2024-01-01:2024-01-02" - @Flow.model - def source(x: int) -> int: - return x - @Flow.model - def consumer(data: int) -> int: - return data +def test_explicit_context_interop_still_works(): + @Flow.model + def loader(context: DateRangeContext, source: str = "db") -> GenericResult[str]: + return GenericResult(value=f"{source}:{context.start_date}:{context.end_date}") - with self.assertRaisesRegex(TypeError, "is not valid for"): - source().pipe(consumer, param="nonexistent") + model = loader(source="api") + assert model(DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2))).value == "api:2024-01-01:2024-01-02" + assert model.flow.compute(start_date="2024-01-01", end_date="2024-01-02").value == "api:2024-01-01:2024-01-02" - def test_pipe_already_bound_param(self): - """pipe() should reject already-bound parameters.""" - @Flow.model - def source(x: int) -> int: - return x +def test_explicit_context_and_from_context_cannot_mix(): + with pytest.raises(TypeError, match="cannot also declare FromContext"): @Flow.model - def consumer(data: int) -> int: - return data + def bad(context: SimpleContext, y: FromContext[int]) -> int: + return context.value + y - model = consumer(data=5) - with self.assertRaisesRegex(TypeError, "is already bound"): - source().pipe(model, param="data") - def test_pipe_no_available_target_parameter(self): - """pipe() should error when all downstream params are occupied.""" +def test_context_args_keyword_is_removed(): + with pytest.raises(TypeError, match="context_args=... 
has been removed"): - @Flow.model - def source(x: int) -> int: + @Flow.model(context_args=["x"]) + def bad(x: int) -> int: return x - @Flow.model - def consumer(data: int) -> int: - return data - - model = consumer(data=5) - with self.assertRaisesRegex(TypeError, "could not find an available target parameter"): - source().pipe(model) - - def test_pipe_into_generated_instance_rebuilds(self): - """pipe() into an existing generated model instance should rebuild.""" - - @Flow.model - def source(x: int) -> int: - return x * 10 - - @Flow.model - def consumer(data: int, extra: int = 1) -> int: - return data + extra - - instance = consumer(extra=5) - pipeline = source().pipe(instance) - result = pipeline.flow.compute(x=3) - self.assertEqual(result.value, 35) # 3*10 + 5 - def test_pipe_bound_model_wrapping_non_generated_rejects(self): - """pipe() into BoundModel wrapping a non-generated model should fail.""" - from ccflow.flow_model import BoundModel, pipe_model +def test_context_type_requires_from_context_or_explicit_context(): + with pytest.raises(TypeError, match="context_type=... 
requires FromContext"): - @Flow.model - def source(x: int) -> int: + @Flow.model(context_type=DateRangeContext) + def bad(x: int) -> int: return x - class ManualModel(CallableModel): - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) - - bound = BoundModel(model=ManualModel(), input_transforms={"value": 42}) - with self.assertRaisesRegex(TypeError, "pipe\\(\\) only supports downstream"): - pipe_model(source(), bound) - - -class TestFlowAPIBuildContextFallback(TestCase): - """Group 12: FlowAPI._build_context when _context_schema is None/unset.""" - - def test_unbound_inputs_on_manual_callable_model(self): - """Manual CallableModel with context should show required fields.""" - - class ManualModel(CallableModel): - offset: int - - @Flow.call - def __call__(self, context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value + self.offset) - - model = ManualModel(offset=5) - unbound = model.flow.unbound_inputs - self.assertIn("value", unbound) - - -class TestBoundModelRestoreNonDict(TestCase): - """Group 13: BoundModel._restore_serialized_transforms non-dict path.""" - - def test_restore_from_model_instance(self): - """model_validate from an existing BoundModel instance (non-dict).""" - from ccflow.flow_model import BoundModel - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 10) - - bound = source().flow.with_inputs(value=42) - # Pass existing instance through model_validate (non-dict path) - restored = BoundModel.model_validate(bound) - ctx = SimpleContext(value=1) - self.assertEqual(restored(ctx).value, 420) - - -class TestBoundModelInitEmptyTransforms(TestCase): - """Group 14: BoundModel.__init__ with no transforms.""" - - def test_init_without_transforms(self): - from ccflow.flow_model import BoundModel - - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return 
GenericResult(value=context.value) - - bound = BoundModel(model=source()) - self.assertEqual(bound._input_transforms, {}) - result = bound(SimpleContext(value=5)) - self.assertEqual(result.value, 5) - - -class TestBoundModelDeps(TestCase): - """Group 15: BoundModel.__deps__.""" - - def test_deps_returns_wrapped_model(self): - @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 10) - - bound = source().flow.with_inputs(value=42) - deps = bound.__deps__(SimpleContext(value=1)) - self.assertEqual(len(deps), 1) - self.assertIs(deps[0][0], bound.model) - - -class TestValidateFieldTypesAfterValidator(TestCase): - """Group 16: _validate_field_types in the model_validate path.""" - - def test_model_validate_rejects_wrong_type(self): - """model_validate should reject wrong scalar types.""" - - @Flow.model - def source(x: int) -> int: - return x * 10 - - cls = type(source(x=5)) - with self.assertRaisesRegex(TypeError, "Field 'x'"): - cls.model_validate({"x": "not_an_int"}) - - -class TestGetContextValidatorPaths(TestCase): - """Group 17: _get_context_validator fallback paths.""" - - def test_mode2_context_validator_from_schema(self): - """Mode 2 model should build validator from _context_schema.""" - - @Flow.model(context_args=["start_date"]) - def loader(start_date: str, source: str = "db") -> str: - return f"{source}:{start_date}" - - model = loader() - # Trigger validator creation by calling flow.compute - result = model.flow.compute(start_date="2024-01-01") - self.assertEqual(result.value, "db:2024-01-01") - - def test_mode1_context_validator_uses_context_type_directly(self): - """Mode 1 should use TypeAdapter(context_type) directly.""" - - @Flow.model - def model_fn(context: SimpleContext, offset: int = 0) -> GenericResult[int]: - return GenericResult(value=context.value + offset) - - model = model_fn() - # compute with SimpleContext fields - result = model.flow.compute(value=5) - 
self.assertEqual(result.value, 5) - - -class TestValidateContextTypeOverrideErrors(TestCase): - """Group 18: _validate_context_type_override error paths.""" - - def test_non_context_base_raises(self): - with self.assertRaisesRegex(TypeError, "context_type must be a ContextBase subclass"): - - @Flow.model(context_args=["x"], context_type=int) - def bad_model(x: int) -> int: - return x - - def test_context_type_missing_context_args_fields(self): - """context_type missing required context_args fields.""" - - class TinyContext(ContextBase): - a: int - - with self.assertRaisesRegex(TypeError, "must define fields for context_args"): - - @Flow.model(context_args=["a", "b"], context_type=TinyContext) - def bad_model(a: int, b: int) -> int: - return a + b - - def test_context_type_extra_required_fields(self): - """context_type has required fields not listed in context_args.""" - - class BigContext(ContextBase): - a: int - b: int - extra: str - - with self.assertRaisesRegex(TypeError, "has required fields not listed in context_args"): - - @Flow.model(context_args=["a"], context_type=BigContext) - def bad_model(a: int) -> int: - return a - - def test_annotation_type_mismatch(self): - """Function and context_type disagree on annotation type.""" - - class TypedContext(ContextBase): - x: str - - with self.assertRaisesRegex(TypeError, "context_arg 'x'"): - - @Flow.model(context_args=["x"], context_type=TypedContext) - def bad_model(x: int) -> int: - return x - - def test_annotation_skip_when_func_ann_is_none(self): - """Annotation check should skip when function annotation is absent from schema.""" - from ccflow.flow_model import _validate_context_type_override - - class CompatContext(ContextBase): - a: int - - # context_args has 'a', schema has 'a': int. Compatible, no error. 
- result = _validate_context_type_override(CompatContext, ["a"], {"a": int}) - self.assertIs(result, CompatContext) - - def test_subclass_annotations_allowed(self): - """context_type with subclass-compatible annotations should pass.""" - from ccflow.flow_model import _validate_context_type_override - - class ContextWithBase(ContextBase): - ctx: ContextBase - # Function declares SimpleContext which is a subclass of ContextBase — should pass - result = _validate_context_type_override(ContextWithBase, ["ctx"], {"ctx": SimpleContext}) - self.assertIs(result, ContextWithBase) +def test_pipe_only_targets_regular_parameters(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value - def test_default_vs_required_field_conflict(self): - """Function has default for context_arg but context_type requires it.""" + @Flow.model + def consumer(a: int, b: FromContext[int]) -> int: + return a + b - class StrictContext(ContextBase): - x: int + piped = source().pipe(consumer()) + assert piped.flow.compute(FlowContext(value=10, b=5)).value == 15 - with self.assertRaisesRegex(TypeError, "function has a default but context_type"): + with pytest.raises(TypeError, match="is contextual"): + source().pipe(consumer(), param="b") - @Flow.model(context_args=["x"], context_type=StrictContext) - def bad_model(x: int = 5) -> int: - return x +def test_lazy_dependency_remains_lazy(): + calls = {"source": 0} -class TestDecoratorErrorPaths(TestCase): - """Group 19: Decorator error paths.""" + @Flow.model + def source(value: FromContext[int]) -> int: + calls["source"] += 1 + return value * 10 - def test_context_type_with_explicit_context_param(self): - """context_type= with explicit context param should raise.""" - with self.assertRaisesRegex(TypeError, "context_type.*only supported"): + @Flow.model + def choose(value: int, lazy_value: Lazy[int], threshold: FromContext[int]) -> int: + if value > threshold: + return value + return lazy_value() - 
@Flow.model(context_type=SimpleContext) - def bad_model(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=0) + eager = choose(value=50, lazy_value=source()) + assert eager.flow.compute(FlowContext(value=3, threshold=10)).value == 50 + assert calls["source"] == 0 - def test_context_type_without_context_args(self): - """context_type= without context_args should raise in dynamic mode.""" - with self.assertRaisesRegex(TypeError, "context_type.*only supported"): + deferred = choose(value=5, lazy_value=source()) + assert deferred.flow.compute(FlowContext(value=3, threshold=10)).value == 30 + assert calls["source"] == 1 - @Flow.model(context_type=SimpleContext) - def bad_model(x: int) -> int: - return x - def test_missing_context_annotation(self): - """Missing type annotation on context param should raise.""" - with self.assertRaisesRegex(TypeError, "must have a type annotation"): +def test_lazy_runtime_helper_is_removed(): + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) - @Flow.model - def bad_model(context) -> int: - return 0 + with pytest.raises(TypeError, match="Lazy\\(model\\)\\(\\.\\.\\.\\) has been removed"): + Lazy(source()) - def test_missing_param_annotation(self): - """Missing type annotation on a model field param should raise.""" - with self.assertRaisesRegex(TypeError, "must have a type annotation"): - @Flow.model - def bad_model(context: SimpleContext, untyped_param) -> int: - return 0 - - def test_context_param_not_context_base(self): - """context param annotated with non-ContextBase type should raise.""" - with self.assertRaisesRegex(TypeError, "must be annotated with a ContextBase subclass"): - - @Flow.model - def bad_model(context: int) -> int: - return 0 - - def test_pep563_fallback_on_failed_get_type_hints(self): - """When get_type_hints fails, falls back to raw annotations.""" - - # This is hard to trigger directly, but we can test that string annotations 
work - @Flow.model - def model_with_string_return(x: int) -> "int": - return x * 2 - - result = model_with_string_return().flow.compute(x=5) - self.assertEqual(result.value, 10) - - -class TestMode1CallPath(TestCase): - """Group 20: Mode 1 explicit context pass-through in __call__.""" - - def test_mode1_resolve_callable_model_returns_non_generic_result(self): - """Mode 1 should handle deps that return raw ResultBase (not GenericResult).""" - - @Flow.model - def upstream(context: SimpleContext) -> MyResult: - return MyResult(data=f"value={context.value}") - - @Flow.model - def downstream(context: SimpleContext, dep: CallableModel) -> GenericResult[str]: - # dep is resolved to MyResult since it's not GenericResult - return GenericResult(value=f"got:{dep}") - - model = downstream(dep=upstream()) - result = model(SimpleContext(value=42)) - self.assertIn("value=42", result.value) - - -class TestDynamicModeContextLookup(TestCase): - """Group 21: Dynamic mode context lookup for deferred values.""" - - def test_deferred_value_from_context(self): - """Dynamic mode should pull deferred values from context.""" - - @Flow.model - def add(x: int, y: int) -> int: - return x + y - - model = add(x=10) - # y is deferred — pulled from context - result = model.flow.compute(y=5) - self.assertEqual(result.value, 15) - - def test_missing_deferred_value_raises(self): - """Dynamic mode should raise for missing deferred values.""" - - @Flow.model - def add(x: int, y: int) -> int: - return x + y - - model = add(x=10) - with self.assertRaisesRegex(TypeError, "Missing runtime input"): - model.flow.compute() # y not provided - - def test_context_sourced_value_coercion(self): - """Dynamic mode should coerce context-sourced values through validators.""" +def test_lazy_and_from_context_combination_is_rejected(): + with pytest.raises(TypeError, match="cannot combine Lazy"): @Flow.model - def typed_model(x: int, y: int) -> int: - return x + y - - model = typed_model(x=10) - # y provided as a value 
that can be coerced to int - result = model.flow.compute(y=5) - self.assertEqual(result.value, 15) + def bad(x: Lazy[FromContext[int]]) -> int: + return x() - def test_deferred_value_from_context_object(self): - """Dynamic mode should look up deferred values from context attributes.""" - - @Flow.model - def multiply(x: int, y: int) -> int: - return x * y - model = multiply(x=3) - # Call directly with a FlowContext — y must come from context - result = model(FlowContext(y=7)) - self.assertEqual(result.value, 21) +def test_auto_wrap_and_serialization_roundtrip(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + model = add(a=10) + dumped = model.model_dump(mode="python") + restored = type(model).model_validate(dumped) -class TestGetContextValidatorFallbacks(TestCase): - """Group 17 additional: _get_context_validator edge cases.""" + assert restored.flow.bound_inputs == {"a": 10} + assert restored.flow.unbound_inputs == {"b": int} + assert restored.flow.compute(b=5).value == 15 - def test_mode2_with_context_type_override(self): - """Mode 2 with explicit context_type should use that type's validator.""" - @Flow.model(context_args=["value"], context_type=SimpleContext) - def typed_model(value: int) -> int: - return value * 2 +def test_generated_models_cloudpickle_roundtrip(): + @Flow.model + def multiply(a: int, b: FromContext[int]) -> int: + return a * b - model = typed_model() - result = model(SimpleContext(value=5)) - self.assertEqual(result.value, 10) + model = multiply(a=6) + restored = rcploads(rcpdumps(model, protocol=5)) + assert restored.flow.compute(b=7).value == 42 - def test_dynamic_mode_instance_validator(self): - """Dynamic mode should create instance-specific validator.""" - @Flow.model - def add(x: int, y: int, z: int = 0) -> int: - return x + y + z +def test_graph_integration_fanout_fanin(): + @Flow.model + def source(base: int, value: FromContext[int]) -> int: + return value + base - m1 = add(x=1) - m2 = add(x=1, y=2) - # 
Different bound fields => different runtime inputs - self.assertIn("y", m1.flow.unbound_inputs) - self.assertNotIn("y", m2.flow.unbound_inputs) + @Flow.model + def scale(data: int, factor: int) -> int: + return data * factor + @Flow.model + def merge(left: int, right: int, bonus: FromContext[int]) -> int: + return left + right + bonus -class TestRegistryResolutionInValidateFieldTypes(TestCase): - """Group 16: _resolve_registry_refs and _validate_field_types paths.""" + src = source(base=10) + left = scale(data=src, factor=2) + right = scale(data=src, factor=5) + model = merge(left=left, right=right) - def test_registry_string_not_resolving_passes_through(self): - """String value that doesn't resolve from registry should fail type validation.""" + assert model.flow.compute(FlowContext(value=3, bonus=7)).value == ((3 + 10) * 2) + ((3 + 10) * 5) + 7 - @Flow.model - def model_fn(x: int) -> int: - return x - cls = type(model_fn(x=1)) - with self.assertRaisesRegex(TypeError, "Field 'x'"): - cls.model_validate({"x": "nonexistent_registry_key"}) +def test_graph_integration_cycle_raises_cleanly(): + @Flow.model + def increment(x: int, n: FromContext[int]) -> int: + return x + n - def test_registry_ref_resolves_to_callable_model(self): - """String value resolving to a CallableModel should be substituted.""" - registry = ModelRegistry.root() - registry.clear() - try: + root = increment() + branch = increment(x=root) + object.__setattr__(root, "x", branch) - @Flow.model - def upstream(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 10) + with FlowOptionsOverride(options={"evaluator": GraphEvaluator()}): + with pytest.raises(graphlib.CycleError): + root.flow.compute(n=1) - @Flow.model - def downstream(context: SimpleContext, dep: CallableModel) -> GenericResult[int]: - return GenericResult(value=0) - registry.add("my_upstream", upstream()) - cls = type(downstream(dep=upstream())) - restored = cls.model_validate({"dep": "my_upstream"}) 
- self.assertIsNotNone(restored) - finally: - registry.clear() +def test_large_contextual_contract_stress(): + @Flow.model + def total( + base: int, + x1: FromContext[int], + x2: FromContext[int], + x3: FromContext[int], + x4: FromContext[int], + x5: FromContext[int], + x6: FromContext[int], + ) -> int: + return base + x1 + x2 + x3 + x4 + x5 + x6 + model = total(base=10) + assert model.flow.context_inputs == {"x1": int, "x2": int, "x3": int, "x4": int, "x5": int, "x6": int} + assert model.flow.compute(x1=1, x2=2, x3=3, x4=4, x5=5, x6=6).value == 31 -class TestMode2MissingContextField(TestCase): - """Line 1155: Mode 2 missing context field error.""" - def test_mode2_missing_required_context_field(self): - """Mode 2 model called with context missing a required field should raise.""" +def test_registry_integration_for_generated_models(): + registry = ModelRegistry.root().clear() + model = basic_loader(source="warehouse", multiplier=3) + registry.add("loader", model) - @Flow.model(context_args=["start_date", "end_date"]) - def loader(start_date: str, end_date: str, source: str = "db") -> str: - return f"{source}:{start_date}-{end_date}" + retrieved = registry["loader"] + assert isinstance(retrieved, CallableModel) + assert retrieved(SimpleContext(value=4)).value == 12 - model = loader() - # Call with a FlowContext missing end_date - with self.assertRaisesRegex(TypeError, "Missing context field"): - model(FlowContext(start_date="2024-01-01")) +def test_unexpected_type_adapter_errors_are_not_silently_swallowed(): + class BrokenSchema: + @classmethod + def __get_pydantic_core_schema__(cls, source, handler): + raise RuntimeError("boom") -class TestDynamicModeContextObjectLookup(TestCase): - """Line 1155/1176: Dynamic mode pulling deferred values from context object.""" + @Flow.model + def bad(x: BrokenSchema, y: FromContext[int]) -> int: + del x, y + return 0 - def test_deferred_value_coercion_through_context(self): - """Dynamic mode should coerce values from FlowContext 
through validators.""" + with pytest.raises(RuntimeError, match="boom"): + bad(x=object()) - @Flow.model - def typed_add(x: int, y: int) -> int: - return x + y - model = typed_add(x=10) - # Calling with a FlowContext — y pulled from context and coerced - result = model(FlowContext(y=5)) - self.assertEqual(result.value, 15) +def test_unexpected_type_hint_resolution_errors_propagate(monkeypatch): + def broken_get_type_hints(*args, **kwargs): + raise RuntimeError("boom") + monkeypatch.setattr(flow_model_module, "get_type_hints", broken_get_type_hints) -if __name__ == "__main__": - import unittest + def add(x: int) -> int: + return x - unittest.main() + with pytest.raises(RuntimeError, match="boom"): + Flow.model(add) diff --git a/ccflow/tests/test_flow_model_hydra.py b/ccflow/tests/test_flow_model_hydra.py index 28f2883..981c511 100644 --- a/ccflow/tests/test_flow_model_hydra.py +++ b/ccflow/tests/test_flow_model_hydra.py @@ -1,441 +1,135 @@ -"""Hydra integration tests for Flow.model. - -These tests verify that Flow.model decorated functions work correctly when -loaded from YAML configuration files using ModelRegistry.load_config_from_path(). - -Key feature: Registry name references (e.g., `source: flow_source`) ensure the same -object instance is shared across all consumers. 
-""" +"""Hydra integration tests for the FromContext-based Flow.model API.""" from datetime import date from pathlib import Path -from unittest import TestCase from omegaconf import OmegaConf -from ccflow import CallableModel, DateRangeContext, GenericResult, ModelRegistry +from ccflow import CallableModel, DateRangeContext, FlowContext, ModelRegistry from .test_flow_model import SimpleContext CONFIG_PATH = str(Path(__file__).parent / "config" / "conf_flow.yaml") -class TestFlowModelHydraYAML(TestCase): - """Tests loading Flow.model from YAML config files using ModelRegistry.""" - - def setUp(self) -> None: - ModelRegistry.root().clear() - - def tearDown(self) -> None: - ModelRegistry.root().clear() - - def test_basic_loader_from_yaml(self): - """Test basic model instantiation from YAML.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) - - loader = r["flow_loader"] +def setup_function(): + ModelRegistry.root().clear() - self.assertIsInstance(loader, CallableModel) - ctx = SimpleContext(value=10) - result = loader(ctx) - self.assertEqual(result.value, 50) # 10 * 5 +def teardown_function(): + ModelRegistry.root().clear() - def test_string_processor_from_yaml(self): - """Test string processor model from YAML.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) - processor = r["flow_processor"] +def test_basic_loader_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) - ctx = SimpleContext(value=42) - result = processor(ctx) - self.assertEqual(result.value, "value=42!") + loader = registry["flow_loader"] + assert isinstance(loader, CallableModel) + assert loader(SimpleContext(value=10)).value == 50 - def test_two_stage_pipeline_from_yaml(self): - """Test two-stage pipeline from YAML config.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) - transformer = r["flow_transformer"] +def test_basic_processor_from_yaml(): + registry = ModelRegistry.root() + 
registry.load_config_from_path(CONFIG_PATH) - self.assertIsInstance(transformer, CallableModel) + processor = registry["flow_processor"] + assert processor(SimpleContext(value=42)).value == "value=42!" - ctx = SimpleContext(value=5) - result = transformer(ctx) - # flow_source: 5 + 100 = 105 - # flow_transformer: 105 * 3 = 315 - self.assertEqual(result.value, 315) - def test_three_stage_pipeline_from_yaml(self): - """Test three-stage pipeline from YAML config.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) +def test_two_stage_pipeline_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) - stage3 = r["flow_stage3"] + transformer = registry["flow_transformer"] + assert transformer(SimpleContext(value=5)).value == 315 - ctx = SimpleContext(value=10) - result = stage3(ctx) - # stage1: 10 + 10 = 20 - # stage2: 20 * 2 = 40 - # stage3: 40 + 50 = 90 - self.assertEqual(result.value, 90) - def test_diamond_dependency_from_yaml(self): - """Test diamond dependency pattern from YAML config.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) +def test_three_stage_pipeline_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) - aggregator = r["diamond_aggregator"] + stage3 = registry["flow_stage3"] + assert stage3(SimpleContext(value=10)).value == 90 - ctx = SimpleContext(value=10) - result = aggregator(ctx) - # source: 10 + 10 = 20 - # branch_a: 20 * 2 = 40 - # branch_b: 20 * 5 = 100 - # aggregator: 40 + 100 = 140 - self.assertEqual(result.value, 140) - def test_date_range_pipeline_from_yaml(self): - """Test DateRangeContext pipeline with transforms from YAML.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) +def test_diamond_dependency_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) - processor = r["flow_date_processor"] + aggregator = registry["diamond_aggregator"] + assert 
aggregator(SimpleContext(value=10)).value == 140 - ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) - result = processor(ctx) - - # The transform extends start_date back by one day - self.assertIn("2024-01-09", result.value) - self.assertIn("normalized:", result.value) - - def test_context_args_from_yaml(self): - """Test context_args model from YAML config.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) - - loader = r["ctx_args_loader"] - - self.assertIsInstance(loader, CallableModel) - # context_args models use DateRangeContext - self.assertEqual(loader.context_type, DateRangeContext) - - ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) - result = loader(ctx) - self.assertEqual( - result.value, - { - "source": "data_source", - "start_date": "2024-01-01", - "end_date": "2024-01-31", - }, - ) - def test_context_args_pipeline_from_yaml(self): - """Test context_args pipeline with dependencies from YAML.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) +def test_date_range_pipeline_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) - processor = r["ctx_args_processor"] + processor = registry["flow_date_processor"] + ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) + result = processor(ctx) - ctx = DateRangeContext(start_date=date(2024, 3, 1), end_date=date(2024, 3, 31)) - result = processor(ctx) - # loader: "data_source:2024-03-01 to 2024-03-31" - # processor: "output:data_source:2024-03-01 to 2024-03-31" - self.assertEqual(result.value, "output:data_source:2024-03-01 to 2024-03-31") + assert "normalized:" in result.value + assert "2024-01-09" in result.value - def test_context_args_shares_instance(self): - """Test that context_args pipeline shares dependency instance.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) - loader = r["ctx_args_loader"] - processor = 
r["ctx_args_processor"] +def test_from_context_pipeline_from_yaml(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) - self.assertIs(processor.data, loader) + loader = registry["contextual_loader_model"] + processor = registry["contextual_processor_model"] + assert loader.flow.context_inputs == {"start_date": date, "end_date": date} + result = processor.flow.compute(start_date=date(2024, 3, 1), end_date=date(2024, 3, 31)) + assert result.value == "output:data_source:2024-03-01 to 2024-03-31" + assert processor.data is loader -class TestFlowModelHydraInstanceSharing(TestCase): - """Tests that registry name references share the same object instance.""" - - def setUp(self) -> None: - ModelRegistry.root().clear() - - def tearDown(self) -> None: - ModelRegistry.root().clear() - def test_pipeline_shares_instance(self): - """Test that pipeline stages share the same dependency instance.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) - - transformer = r["flow_transformer"] - source = r["flow_source"] +def test_registry_name_references_share_instances(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) - self.assertIs(transformer.source, source) - - def test_three_stage_pipeline_shares_instances(self): - """Test that three-stage pipeline shares instances correctly.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) - - stage1 = r["flow_stage1"] - stage2 = r["flow_stage2"] - stage3 = r["flow_stage3"] - - self.assertIs(stage2.stage1_output, stage1) - self.assertIs(stage3.stage2_output, stage2) - - def test_diamond_pattern_shares_source_instance(self): - """Test that diamond pattern branches share the same source instance.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) - - source = r["diamond_source"] - branch_a = r["diamond_branch_a"] - branch_b = r["diamond_branch_b"] - aggregator = r["diamond_aggregator"] - - # Both branches should share the SAME 
source instance - self.assertIs(branch_a.source, source) - self.assertIs(branch_b.source, source) - self.assertIs(branch_a.source, branch_b.source) - - self.assertIs(aggregator.input_a, branch_a) - self.assertIs(aggregator.input_b, branch_b) - - def test_date_range_shares_instance(self): - """Test that date range pipeline shares dependency instance.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) - - loader = r["flow_date_loader"] - processor = r["flow_date_processor"] - - self.assertIs(processor.raw_data, loader) - - -class TestFlowModelHydraOmegaConf(TestCase): - """Tests using OmegaConf.create for dynamic config creation.""" - - def setUp(self) -> None: - ModelRegistry.root().clear() - - def tearDown(self) -> None: - ModelRegistry.root().clear() - - def test_instantiate_with_omegaconf(self): - """Test instantiation using OmegaConf.create via ModelRegistry.""" - cfg = OmegaConf.create( - { - "loader": { - "_target_": "ccflow.tests.test_flow_model.basic_loader", - "source": "dynamic_source", - "multiplier": 7, - }, - } - ) - - r = ModelRegistry.root() - r.load_config(cfg) - loader = r["loader"] - - ctx = SimpleContext(value=3) - result = loader(ctx) - self.assertEqual(result.value, 21) # 3 * 7 - - def test_nested_deps_with_omegaconf(self): - """Test nested dependencies using OmegaConf with registry names.""" - cfg = OmegaConf.create( - { - "source": { - "_target_": "ccflow.tests.test_flow_model.data_source", - "base_value": 50, - }, - "transformer": { - "_target_": "ccflow.tests.test_flow_model.data_transformer", - "source": "source", - "factor": 4, - }, - } - ) - - r = ModelRegistry.root() - r.load_config(cfg) - transformer = r["transformer"] - - ctx = SimpleContext(value=10) - result = transformer(ctx) - # source: 10 + 50 = 60 - # transformer: 60 * 4 = 240 - self.assertEqual(result.value, 240) - - self.assertIs(transformer.source, r["source"]) - - def test_diamond_with_omegaconf(self): - """Test diamond pattern with OmegaConf using registry 
names.""" - cfg = OmegaConf.create( - { - "source": { - "_target_": "ccflow.tests.test_flow_model.data_source", - "base_value": 10, - }, - "branch_a": { - "_target_": "ccflow.tests.test_flow_model.data_transformer", - "source": "source", - "factor": 2, - }, - "branch_b": { - "_target_": "ccflow.tests.test_flow_model.data_transformer", - "source": "source", - "factor": 3, - }, - "aggregator": { - "_target_": "ccflow.tests.test_flow_model.data_aggregator", - "input_a": "branch_a", - "input_b": "branch_b", - "operation": "multiply", - }, - } - ) - - r = ModelRegistry.root() - r.load_config(cfg) - aggregator = r["aggregator"] - - ctx = SimpleContext(value=5) - result = aggregator(ctx) - # source: 5 + 10 = 15 - # branch_a: 15 * 2 = 30 - # branch_b: 15 * 3 = 45 - # aggregator: 30 * 45 = 1350 - self.assertEqual(result.value, 1350) - - # Verify SAME source instance is shared - self.assertIs(r["branch_a"].source, r["source"]) - self.assertIs(r["branch_b"].source, r["source"]) - - -class TestFlowModelHydraDefaults(TestCase): - """Tests that default parameter values work with Hydra.""" - - def setUp(self) -> None: - ModelRegistry.root().clear() - - def tearDown(self) -> None: - ModelRegistry.root().clear() - - def test_defaults_used_when_not_specified(self): - """Test that default values are used when not in config.""" - cfg = OmegaConf.create( - { - "loader": { - "_target_": "ccflow.tests.test_flow_model.basic_loader", - "source": "test", - }, - } - ) - - r = ModelRegistry.root() - r.load_config(cfg) - loader = r["loader"] - - ctx = SimpleContext(value=10) - result = loader(ctx) - self.assertEqual(result.value, 10) # 10 * 1 (default) - - def test_defaults_can_be_overridden(self): - """Test that defaults can be overridden in config.""" - cfg = OmegaConf.create( - { - "loader": { - "_target_": "ccflow.tests.test_flow_model.basic_loader", - "source": "test", - "multiplier": 100, - }, - } - ) - - r = ModelRegistry.root() - r.load_config(cfg) - loader = r["loader"] - - ctx = 
SimpleContext(value=10) - result = loader(ctx) - self.assertEqual(result.value, 1000) # 10 * 100 - - -class TestFlowModelHydraModelProperties(TestCase): - """Tests that model properties are correct after Hydra instantiation.""" - - def setUp(self) -> None: - ModelRegistry.root().clear() - - def tearDown(self) -> None: - ModelRegistry.root().clear() - - def test_context_type_property(self): - """Test that context_type is correct.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) - - loader = r["flow_loader"] - self.assertEqual(loader.context_type, SimpleContext) - - def test_result_type_property(self): - """Test that result_type is correct.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) + transformer = registry["flow_transformer"] + source = registry["flow_source"] + assert transformer.source is source - loader = r["flow_loader"] - self.assertEqual(loader.result_type, GenericResult[int]) + stage2 = registry["flow_stage2"] + stage3 = registry["flow_stage3"] + assert stage2.stage1_output is registry["flow_stage1"] + assert stage3.stage2_output is stage2 - def test_deps_method_works(self): - """Test that __deps__ method works after Hydra instantiation.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) - transformer = r["flow_transformer"] - - ctx = SimpleContext(value=5) - deps = transformer.__deps__(ctx) - - self.assertEqual(len(deps), 1) - self.assertIsInstance(deps[0][0], CallableModel) - self.assertEqual(deps[0][1], [ctx]) - self.assertIs(deps[0][0], r["flow_source"]) - - -class TestFlowModelHydraDateRangeTransforms(TestCase): - """Tests transforms with DateRangeContext from Hydra config.""" - - def setUp(self) -> None: - ModelRegistry.root().clear() - - def tearDown(self) -> None: - ModelRegistry.root().clear() - - def test_transform_applied_from_yaml(self): - """Test that transform is applied when loaded from YAML.""" - r = ModelRegistry.root() - r.load_config_from_path(CONFIG_PATH) - - processor = 
r["flow_date_processor"] - - ctx = DateRangeContext(start_date=date(2024, 1, 10), end_date=date(2024, 1, 31)) - deps = processor.__deps__(ctx) +def test_instantiate_with_omegaconf(): + cfg = OmegaConf.create( + { + "loader": { + "_target_": "ccflow.tests.test_flow_model.basic_loader", + "source": "dynamic_source", + "multiplier": 7, + }, + "contextual": { + "_target_": "ccflow.tests.test_flow_model.contextual_loader", + "source": "warehouse", + }, + } + ) - self.assertEqual(len(deps), 1) - dep_model, dep_contexts = deps[0] + registry = ModelRegistry.root() + registry.load_config(cfg) - self.assertIs(dep_model, r["flow_date_loader"]) - self.assertEqual(dep_contexts[0], ctx) - self.assertEqual(dep_model(ctx).value["start_date"], "2024-01-09") + assert registry["loader"](SimpleContext(value=3)).value == 21 + assert registry["contextual"].flow.compute(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)).value == { + "source": "warehouse", + "start_date": "2024-01-01", + "end_date": "2024-01-02", + } -if __name__ == "__main__": - import unittest +def test_flow_context_execution_with_yaml_models(): + registry = ModelRegistry.root() + registry.load_config_from_path(CONFIG_PATH) - unittest.main() + processor = registry["contextual_processor_model"] + result = processor.flow.compute(FlowContext(start_date=date(2024, 4, 1), end_date=date(2024, 4, 30))) + assert result.value == "output:data_source:2024-04-01 to 2024-04-30" diff --git a/ccflow/utils/chunker.py b/ccflow/utils/chunker.py index 605bfbd..fb32c70 100644 --- a/ccflow/utils/chunker.py +++ b/ccflow/utils/chunker.py @@ -12,6 +12,8 @@ import pandas as pd +from ccflow.exttypes.frequency import _normalize_frequency_alias + _MIN_END_DATE = date(1969, 12, 31) __all__ = ("dates_to_chunks",) @@ -31,19 +33,20 @@ def dates_to_chunks(start: date, end: date, chunk_size: str = "ME", trim: bool = Returns: List of tuples of (start date, end date) for each of the chunks """ + normalized_chunk_size = 
_normalize_frequency_alias(chunk_size) with warnings.catch_warnings(): # Because pandas 2.2 deprecated many frequency strings (i.e. "Y", "M", "T" still in common use) # We should consider switching away from pandas on this and supporting ISO warnings.simplefilter("ignore", category=FutureWarning) - offset = pd.tseries.frequencies.to_offset(chunk_size) + offset = pd.tseries.frequencies.to_offset(normalized_chunk_size) if offset.n == 1: - end_dates = pd.date_range(start - offset, end + offset, freq=chunk_size) + end_dates = pd.date_range(start - offset, end + offset, freq=normalized_chunk_size) else: # Need to anchor the timeline at some absolute date, because otherwise chunks might depend on the start date # and end up overlappig each other, i.e. with 2M, would end up with # i.e. (Jan-Feb) or (Feb,Mar) depending on whether start date was in Jan or Feb, # instead of always returning (Jan,Feb) for any start date in either of those two months. - end_dates = pd.date_range(_MIN_END_DATE, end + offset, freq=chunk_size) + end_dates = pd.date_range(_MIN_END_DATE, end + offset, freq=normalized_chunk_size) start_dates = end_dates + pd.DateOffset(1) chunks = [(s, e) for s, e in zip(start_dates[:-1].date, end_dates[1:].date) if e >= start and s <= end] if trim: diff --git a/ccflow/validators.py b/ccflow/validators.py index 720187c..fee7698 100644 --- a/ccflow/validators.py +++ b/ccflow/validators.py @@ -9,6 +9,7 @@ from pydantic import TypeAdapter, ValidationError from .exttypes import PyObjectPath +from .exttypes.frequency import _normalize_frequency_alias _DatetimeAdapter = TypeAdapter(datetime) @@ -25,7 +26,7 @@ def normalize_date(v: Any) -> Any: """Validator that will convert string offsets to date based on today, and convert datetime to date.""" if isinstance(v, str): # Check case where it's an offset try: - timestamp = pd.tseries.frequencies.to_offset(v) + date.today() + timestamp = pd.tseries.frequencies.to_offset(_normalize_frequency_alias(v)) + date.today() return 
timestamp.date() except ValueError: pass @@ -44,7 +45,7 @@ def normalize_datetime(v: Any) -> Any: """Validator that will convert string offsets to datetime based on today, and convert datetime to date.""" if isinstance(v, str): # Check case where it's an offset try: - return (pd.tseries.frequencies.to_offset(v) + date.today()).to_pydatetime() + return (pd.tseries.frequencies.to_offset(_normalize_frequency_alias(v)) + date.today()).to_pydatetime() except ValueError: pass if isinstance(v, dict): diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md index 7b6ac9f..6ea254c 100644 --- a/docs/design/flow_model_design.md +++ b/docs/design/flow_model_design.md @@ -4,143 +4,167 @@ `@Flow.model` turns a plain Python function into a real `CallableModel`. -The core goals are: +The design is intentionally narrow: -- keep the authoring model close to an ordinary function, -- preserve the existing evaluator / registry / serialization machinery, -- make deferred execution explicit with `.flow.compute(...)` and `.flow.with_inputs(...)`, -- allow callers to pass either literal values or upstream models for ordinary parameters. +- ordinary unmarked parameters are regular bound inputs, +- `FromContext[T]` marks the only runtime/contextual inputs, +- `.flow.compute(...)` supplies contextual inputs, +- `.flow.with_inputs(...)` rewires contextual inputs on one dependency edge, +- upstream `CallableModel`s can still be passed as ordinary arguments. -`@Flow.model` is syntactic sugar over the existing ccflow framework. The -generated object is still a standard `CallableModel`, so you can execute it the -same way as any other model by calling it with a context object. The -`.flow.compute(...)` helper is an explicit, ergonomic way to mark the deferred -execution boundary when supplying runtime inputs as keyword arguments. +The goal is that a reader can look at one function signature and immediately +answer: -## Core Patterns +1. 
which values come from runtime context, +2. which values must be bound as regular configuration or dependencies, +3. how to rewrite contextual inputs for one branch of the graph. -### Default Deferred Style - -This is the most ergonomic mode. Bind some parameters up front, then provide -the remaining runtime inputs later. +## Primary Story ```python -from ccflow import Flow, FlowContext +from ccflow import Flow, FromContext @Flow.model -def add(x: int, y: int) -> int: - return x + y +def foo(a: int, b: FromContext[int]) -> int: + return a + b + + +# Build an instance with a=11 bound, then supply b=12 at runtime: +result = foo(a=11).flow.compute(b=12) +assert result.value == 23 # .value unwraps the GenericResult wrapper + +# Or pre-fill both — b=12 becomes a contextual default: +result = foo(a=11, b=12).flow.compute() +assert result.value == 23 +``` +> **Note:** When the function returns a plain value (like `int` above) instead +> of a `ResultBase` subclass, `@Flow.model` automatically wraps it in +> `GenericResult`. Access the inner value with `.value`. -model = add(x=10) +This is the core contract: -# Explicit deferred entry point -assert model.flow.compute(y=5).value == 15 +- `a` is a regular parameter — it must be bound at construction time, +- `b` is contextual because it is marked with `FromContext[int]` — it can come + from runtime context, a construction-time default, or a function default, +- `.flow.compute(...)` only accepts contextual inputs. -# Standard CallableModel call path -assert model(FlowContext(y=5)).value == 15 +This means the following is **invalid**: -shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) -assert shifted.flow.compute(y=5).value == 20 +```python +foo().flow.compute(a=11, b=12) +# TypeError: compute() only accepts contextual inputs. +# Bind regular parameter(s) separately: a ``` -In this mode: +`a` is not contextual, so it must be bound at construction time (`foo(a=11)`) +or wired with `.pipe(...)`. 
+ +## Regular Parameters vs Contextual Parameters + +### Regular Parameters -- bound parameters are model configuration, -- unbound parameters become runtime inputs for that model instance. +Regular parameters are the unmarked ones. -### Explicit Context Parameter +They can be satisfied by: + +- a literal value, +- a default value from the function signature, +- an upstream `CallableModel`. + +When an upstream model is supplied, `@Flow.model` evaluates it with the current +context and passes the resolved value into the function. This is how you wire +stages together — just pass one model as an argument to another: ```python -from ccflow import DateRangeContext, Flow +from ccflow import Flow, FlowContext, FromContext @Flow.model -def load_revenue(context: DateRangeContext, region: str) -> float: - return 125.0 -``` +def load_value(value: FromContext[int], offset: int) -> int: + return value + offset -This is the most direct mode. The function receives a normal context object and -returns either a `ResultBase` subclass or a plain value. Plain values are -wrapped into `GenericResult` automatically by the generated model. -### `context_args` - -```python -from datetime import date +@Flow.model +def add(a: int, b: FromContext[int]) -> int: + return a + b -from ccflow import Flow +# Wire load_value into add's 'a' parameter: +model = add(a=load_value(offset=5)) -@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) -def load_revenue(start_date: date, end_date: date, region: str) -> float: - return 125.0 +# At runtime, load_value runs first (value=7 + offset=5 = 12), +# then add runs (a=12 + b=12 = 24): +assert model.flow.compute(value=7, b=12).value == 24 ``` -This keeps the function signature focused on the inputs it actually uses while -still producing a `CallableModel` that accepts a context at runtime. 
+### Contextual Parameters -Use `context_args` when certain parameters are semantically the execution -context and you want that split to be explicit and stable across model -instances. +Contextual parameters are the ones marked with `FromContext[...]`. -By default, `context_args` models use `FlowContext`. If you want compatibility -with an existing context class, pass `context_type=...` explicitly. +They can be satisfied by: -### Upstream Models as Normal Arguments +- runtime context, +- construction-time contextual defaults, +- function defaults. -Any non-context parameter can be given either: +They cannot be satisfied by `CallableModel` values. -- a literal value, or -- another `CallableModel` / `BoundModel`. +Contextual precedence is: -If a model is passed, it is evaluated with the current context and its result is -unwrapped before the function is called. +1. branch-local `.flow.with_inputs(...)` rewrites, +2. incoming runtime context, +3. construction-time contextual defaults, +4. function defaults. -```python -from ccflow import DateRangeContext, Flow +## `.flow.compute(...)` +`.flow.compute(...)` is the ergonomic entry point for contextual execution. -@Flow.model -def load_revenue(context: DateRangeContext, region: str) -> float: - return 125.0 +For generated `@Flow.model` stages it accepts either: + +- contextual keyword arguments, or +- one context object. + +It does not accept both at the same time. + +```python +from ccflow import Flow, FlowContext, FromContext @Flow.model -def double_revenue(_: DateRangeContext, revenue: float) -> float: - return revenue * 2 +def add(a: int, b: FromContext[int]) -> int: + return a + b -revenue = load_revenue(region="us") -model = double_revenue(revenue=revenue) -result = model.flow.compute(start_date="2024-01-01", end_date="2024-01-31") +model = add(a=10) +assert model.flow.compute(b=5).value == 15 +assert model.flow.compute(FlowContext(b=6)).value == 16 ``` -This is the main composition story for the core API. 
+`compute()` returns the same result object you would get from `model(context)`. -### `.flow.with_inputs(...)` +## `.flow.with_inputs(...)` -`with_inputs` is how a caller rewires context locally for one upstream model. +`.flow.with_inputs(...)` rewrites contextual inputs locally for one wrapped +dependency. ```python from datetime import date, timedelta -from ccflow import DateRangeContext, Flow +from ccflow import DateRangeContext, Flow, FromContext -@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) -def load_revenue(start_date: date, end_date: date, region: str) -> float: +@Flow.model +def load_revenue(region: str, start_date: FromContext[date], end_date: FromContext[date]) -> float: days = (end_date - start_date).days + 1 return 1000.0 + days * 10.0 -@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) -def revenue_growth(start_date: date, end_date: date, current: float, previous: float) -> dict: - return { - "window_end": end_date, - "growth_pct": round((current - previous) / previous * 100, 2), - } +@Flow.model +def revenue_growth(current: float, previous: float, start_date: FromContext[date], end_date: FromContext[date]) -> dict: + return {"window_end": end_date, "growth_pct": round((current - previous) / previous * 100, 2)} current = load_revenue(region="us") @@ -149,122 +173,177 @@ previous = current.flow.with_inputs( end_date=lambda ctx: ctx.end_date - timedelta(days=30), ) -model = revenue_growth(current=current, previous=previous) -ctx = DateRangeContext( - start_date=date(2024, 1, 1), - end_date=date(2024, 1, 31), -) +growth = revenue_growth(current=current, previous=previous) +result = growth(DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31))) +``` -direct = model(ctx) -computed = model.flow.compute( - start_date=date(2024, 1, 1), - end_date=date(2024, 1, 31), -) +In this example, `current` and `previous` share the same `load_revenue` model +but see different date 
windows at runtime. The `with_inputs()` call on +`previous` shifts the dates back 30 days without affecting `current`. + +Key rules: + +- `with_inputs()` only targets contextual fields, +- transforms are branch-local — they only affect the wrapped dependency, not + the entire pipeline, +- chained `with_inputs()` calls merge, with the newest transform winning for a + repeated field. + +## `.pipe(...)` -assert direct == computed +`.pipe(...)` is a convenience API for wiring one upstream model into a +downstream regular parameter. It is equivalent to passing the model directly: + +```python +source = load_value(offset=5) + +# These two are equivalent: +model_a = add(a=source) +model_b = source.pipe(add(), param="a") ``` -The transform is local to the bound upstream model. The parent model continues -to receive the original context. +`pipe()` is most useful when the downstream stage is already partially +configured and you want to wire in one more dependency, or when you are +building pipelines programmatically and the parameter name is determined at +runtime. It only targets regular parameters — use `.flow.with_inputs(...)` to +rewrite contextual inputs. -### `.flow.compute(...)` +## Explicit Context Interop -`compute` is the ergonomic entry point for deferred execution: +`@Flow.model` still supports an explicit context parameter for cases where the +function needs the whole context object: ```python -from ccflow import Flow +from ccflow import DateRangeContext, Flow @Flow.model -def add(x: int, y: int) -> int: - return x + y +def load_revenue(context: DateRangeContext, region: str) -> float: + days = (context.end_date - context.start_date).days + 1 + return days * 50.0 +``` + +This path is useful when interoperating with existing code that already uses +typed `ContextBase` subclasses, or when the function genuinely needs access to +the full context rather than individual fields. 
+ +You can also keep the `FromContext[...]` style while asking ccflow to validate +those contextual fields against an existing nominal context shape: + +```python +from ccflow import DateRangeContext, Flow, FromContext -model = add(x=10) -assert model.flow.compute(y=5).value == 15 +@Flow.model(context_type=DateRangeContext) +def load_revenue(region: str, start_date: FromContext[date], end_date: FromContext[date]) -> float: + return 125.0 ``` -It validates the supplied keyword arguments against the generated context -schema, creates a `FlowContext`, and executes the model. +That preserves the primary `FromContext[...]` authoring model while letting +callers pass richer context objects whose relevant fields satisfy the declared +`context_type`. + +Do not mix both systems in one function signature. A function with an explicit +`context: ContextBase` parameter cannot also declare `FromContext[...]` +parameters. -It returns the same result object you would get from calling `model(context)`. +## Introspection APIs -It is not the only execution path. Because the generated object is still a -standard `CallableModel`, calling `model(context)` remains fully supported. +Generated models expose three useful introspection helpers: -## Lazy Inputs +- `model.flow.context_inputs`: the full contextual contract, +- `model.flow.unbound_inputs`: the contextual fields still required at runtime, +- `model.flow.bound_inputs`: regular bound inputs plus any construction-time + contextual defaults. -`Lazy[T]` marks a parameter as on-demand. Instead of eagerly resolving an -upstream model, the generated model passes a zero-argument thunk. The thunk -caches its first result. Lazy dependencies are excluded from the `__deps__` -graph, so they are not pre-evaluated by the evaluator infrastructure. 
+Example: ```python -from ccflow import Flow, Lazy +from ccflow import Flow, FromContext @Flow.model -def source(value: int) -> int: +def add(a: int, b: FromContext[int], c: FromContext[int] = 5) -> int: + return a + b + c + + +model = add(a=10) +assert model.flow.context_inputs == {"b": int, "c": int} +assert model.flow.unbound_inputs == {"b": int} +assert model.flow.bound_inputs == {"a": 10} +``` + +## Lazy Dependencies + +`Lazy[T]` defers evaluation of an upstream dependency until the function body +explicitly calls it. This is useful when a dependency is expensive and only +needed conditionally: + +```python +from ccflow import Flow, FlowContext, FromContext, Lazy + + +@Flow.model +def load_value(value: FromContext[int]) -> int: return value * 10 @Flow.model -def maybe_use_source(value: int, data: Lazy[int]) -> int: - if value > 10: - return value - return data() -``` +def maybe_use(current: int, fallback: Lazy[int], threshold: FromContext[int]) -> int: + if current > threshold: + return current # fallback is never evaluated + return fallback() # evaluate only when needed + -## FlowContext +model = maybe_use(current=50, fallback=load_value()) -`FlowContext` is the universal frozen carrier for generated contexts that do -not map to a dedicated built-in context type. +# current (50) > threshold (10), so load_value never runs: +assert model.flow.compute(value=3, threshold=10).value == 50 -The implementation stays intentionally small: +# current (5) <= threshold (10), so load_value runs (3 * 10 = 30): +model2 = maybe_use(current=5, fallback=load_value()) +assert model2.flow.compute(value=3, threshold=10).value == 30 +``` -- context validation is driven by `TypedDict` + `TypeAdapter`, -- runtime execution uses one reusable `FlowContext` type, -- public pydantic iteration (`dict(context)`) is used instead of pydantic - internals. +Without `Lazy[T]`, the upstream model would always run. 
With it, the function +controls exactly when (and whether) the dependency executes. -## BoundModel +## When To Use `@Flow.model` -`.flow.with_inputs(...)` returns a `BoundModel`, which is just a thin wrapper -around: +Use `@Flow.model` when: -- the original model, and -- a mapping of input transforms. +- the stage logic is naturally a plain function, +- you want ordinary arguments to look like ordinary Python function parameters, +- the contextual contract is small and explicit, +- the main goal is easy graph authoring on top of existing ccflow machinery. -At call time it: +Use a hand-written class-based `CallableModel` when: -1. converts the incoming context into a plain dictionary, -1. applies the configured transforms, -1. rebuilds a `FlowContext`, -1. delegates to the wrapped model. +- the model needs custom methods or substantial internal state, +- the full context object is the natural primary interface, +- the stage is no longer best expressed as one function and a small amount of + wiring. -That keeps transformed dependency wiring explicit without adding special -annotation machinery to the core API. +## Troubleshooting -## Flow.call with `auto_context` +**`compute()` says a field is not contextual** -Separately from `@Flow.model`, `Flow.call(auto_context=...)` provides a similar -convenience for class-based `CallableModel`s. Instead of defining a separate -`ContextBase` subclass, the decorator generates one from the function's -keyword-only parameters. +That field is a regular parameter. Bind it at construction time or wire it with +`.pipe(...)`. Only `FromContext[...]` fields belong in `compute()`. -```python -from ccflow import CallableModel, Flow, GenericResult +**`with_inputs()` rejects a field** +`with_inputs()` only rewrites contextual inputs. If you are trying to attach one +stage to another, use regular argument binding or `.pipe(...)`. 
-class MyModel(CallableModel): - @Flow.call(auto_context=True) - def __call__(self, *, x: int, y: str = "default") -> GenericResult: - return GenericResult(value=f"{x}-{y}") -``` +**A contextual parameter still shows up in `context_inputs` after I bound it** + +That is expected. `context_inputs` reports the full contextual contract. +`unbound_inputs` reports only the contextual values still needed at runtime. -Passing a `ContextBase` subclass (e.g., `auto_context=DateContext`) makes the -generated context inherit from that class, so it remains compatible with -infrastructure that expects the parent type. +**A shared dependency runs more than once** -The generated class is registered via `create_ccflow_model` for serialization -support. +`@Flow.model` authors the graph cleanly, but execution still follows the normal +ccflow evaluator path. If you need deduplication or graph scheduling, use the +appropriate evaluators and cache settings just as you would for class-based +`CallableModel`s. diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 5fb27de..c1346ee 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -6,428 +6,216 @@ - [Evaluators](#evaluators) - [Results](#results) -`ccflow` (Composable Configuration Flow) is a collection of tools for workflow configuration, orchestration, and dependency injection. -It is intended to be flexible enough to handle diverse use cases, including data retrieval, validation, transformation, and loading (i.e. ETL workflows), model training, microservice configuration, and automated report generation. +`ccflow` (Composable Configuration Flow) is a collection of tools for workflow +configuration, orchestration, and dependency injection. It is intended to stay +flexible across ETL workflows, model training, report generation, and service +configuration. ## Base Model -Central to `ccflow` is the `BaseModel` class. -`BaseModel` is the base class for models in the `ccflow` framework. 
-A model is basically a data class (class with attributes). -The naming was inspired by the open source library [Pydantic](https://docs.pydantic.dev/latest/) (`BaseModel` actually inherits from the Pydantic base model class). +`BaseModel` is the core pydantic-based model class used throughout `ccflow`. +Models are regular data objects with validation, serialization, and registry +support. ## Callable Model -`CallableModel` is the base class for a special type of `BaseModel` which can be called. -`CallableModel`'s are called with a context (something that derives from `ContextBase`) and returns a result (something that derives from `ResultBase`). -As an example, you may have a `SQLReader` callable model that when called with a `DateRangeContext` returns a `ArrowResult` (wrapper around a Arrow table) with data in the date range defined by the context by querying some SQL database. +`CallableModel` is a `BaseModel` that can be executed with a context +(`ContextBase`) and produces a result (`ResultBase`). ### Flow.model Decorator -The `@Flow.model` decorator provides a simpler way to define `CallableModel`s -using plain Python functions instead of classes. It automatically generates a -standard `CallableModel` class with proper `__call__` and `__deps__` methods, -so it still uses the normal ccflow framework for evaluation, caching, -serialization, and registry loading. +`@Flow.model` is the plain-function front door to `CallableModel`. -If a `@Flow.model` function returns a plain value instead of a `ResultBase` -subclass, the generated model automatically wraps it in `GenericResult` at -runtime so it still behaves like a normal `CallableModel`. +It generates a real `CallableModel` class with proper `__call__` and `__deps__` +methods, so it still plugs into the normal evaluator, registry, cache, Hydra, +and serialization machinery. 
-You can execute a generated model in two equivalent ways: +If the function returns a plain value instead of a `ResultBase`, the generated +model wraps it in `GenericResult`. -- call it directly with a context object: `model(ctx)` -- use `.flow.compute(...)` to supply runtime inputs as keyword arguments +#### Primary Authoring Model -`.flow.compute(...)` is mainly an explicit, ergonomic way to mark the deferred -execution point. - -#### Context Modes - -There are three ways to define how a `@Flow.model` function receives its -runtime context. - -**Mode 1 — Explicit context parameter:** - -The function takes a `context` parameter (or `_` if unused) annotated with a -`ContextBase` subclass. This is the most direct mode and behaves like a -traditional `CallableModel.__call__`. +`FromContext[T]` is the only marker for runtime/contextual inputs. ```python -from datetime import date -from ccflow import Flow, GenericResult, DateRangeContext +from ccflow import Flow, FromContext + @Flow.model -def load_data(context: DateRangeContext, source: str) -> GenericResult[dict]: - return GenericResult(value=query_db(source, context.start_date, context.end_date)) +def add(a: int, b: FromContext[int]) -> int: + return a + b -loader = load_data(source="my_database") -ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) -result = loader(ctx) +model = add(a=10) +assert model.flow.compute(b=5).value == 15 ``` -**Mode 2 — Unpacked context with `context_args`:** +That means: -Instead of receiving a context object, you list which parameters should come -from the context at runtime. The remaining parameters are model configuration. +- `a` is a regular parameter, +- `b` is contextual, +- `.flow.compute(...)` only accepts contextual inputs. 
-```python -from datetime import date -from ccflow import Flow, GenericResult, DateRangeContext +Regular parameters can be satisfied by: -@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) -def load_data(start_date: date, end_date: date, source: str) -> GenericResult[str]: - return GenericResult(value=f"{source}:{start_date} to {end_date}") +- literal values, +- function defaults, +- upstream `CallableModel`s. -loader = load_data(source="my_database") +Contextual parameters can be satisfied by: -# Opt in explicitly when you want compatibility with an existing context type -assert loader.context_type == DateRangeContext +- runtime context, +- construction-time contextual defaults, +- function defaults. -ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) -result = loader(ctx) -``` - -By default, `context_args` models use `FlowContext`, a universal frozen carrier -for the validated fields. If you want the generated model to advertise and -accept an existing `ContextBase` subclass, pass `context_type=...` explicitly. - -Use `context_args` when some parameters are semantically "the execution -context" and you want that split to stay stable and explicit: +Contextual parameters cannot be bound to `CallableModel` values. -- the runtime context should be stable across instances -- the split between config and runtime inputs matters semantically -- the model is naturally "run over a context" such as date windows, - partitions, or scenarios -- you want the generated model to accept a specific existing context type - such as `DateRangeContext` +#### Explicit Context Interop -**Mode 3 — Dynamic deferred style (no explicit context):** - -When there is no `context` parameter and no `context_args`, all parameters are -potential configuration or runtime inputs. Parameters provided at construction -are bound (configuration); everything else comes from the context at runtime. 
+An explicit context parameter is still supported when that is the natural API: ```python -from ccflow import Flow +from ccflow import DateRangeContext, Flow + @Flow.model -def add(x: int, y: int) -> int: - return x + y +def load_data(context: DateRangeContext, source: str) -> float: + return 125.0 +``` -model = add(x=10) +You can also keep the `FromContext[...]` style while validating those fields +against an existing context type: + +```python +from datetime import date +from ccflow import DateRangeContext, Flow, FromContext -# `x` is bound when the model is created. -# `y` is supplied later at execution time. -assert model.flow.compute(y=5).value == 15 -# `.flow.with_inputs(...)` rewrites runtime inputs for this call path. -doubled_y = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) -assert doubled_y.flow.compute(y=5).value == 20 +@Flow.model(context_type=DateRangeContext) +def load_data(source: str, start_date: FromContext[date], end_date: FromContext[date]) -> float: + return 125.0 ``` +Do not mix both styles in one signature. + #### Composing Dependencies -Any non-context parameter can be bound either to a literal value or to another -`CallableModel`. If you pass an upstream model, `@Flow.model` evaluates it with -the current context and passes the resolved value into your function. +Passing an upstream model as an ordinary argument is the main composition story. 
```python from datetime import date, timedelta -from ccflow import DateRangeContext, Flow +from ccflow import DateRangeContext, Flow, FromContext -@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) -def load_revenue(start_date: date, end_date: date, region: str) -> float: + +@Flow.model +def load_revenue(region: str, start_date: FromContext[date], end_date: FromContext[date]) -> float: days = (end_date - start_date).days + 1 return 1000.0 + days * 10.0 -@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) -def revenue_growth( - start_date: date, - end_date: date, - current: float, - previous: float, -) -> dict: - return { - "window_end": end_date, - "growth_pct": round((current - previous) / previous * 100, 2), - } + +@Flow.model +def revenue_growth(current: float, previous: float, start_date: FromContext[date], end_date: FromContext[date]) -> dict: + return {"window_end": end_date, "growth_pct": round((current - previous) / previous * 100, 2)} + current = load_revenue(region="us") + +# Reuse the same model with a shifted date window for "previous": previous = current.flow.with_inputs( start_date=lambda ctx: ctx.start_date - timedelta(days=30), end_date=lambda ctx: ctx.end_date - timedelta(days=30), ) -growth = revenue_growth(current=current, previous=previous) -ctx = DateRangeContext( - start_date=date(2024, 1, 1), - end_date=date(2024, 1, 31), -) - -# Standard ccflow execution -direct = growth(ctx) - -# Equivalent explicit deferred entry point -computed = growth.flow.compute( - start_date=date(2024, 1, 1), - end_date=date(2024, 1, 31), -) +growth = revenue_growth(current=current, previous=previous) -assert direct == computed +# Execute — current sees Jan 2024, previous sees Dec 2023: +ctx = DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 31)) +result = growth(ctx) ``` #### Deferred Execution Helpers -**`.flow.compute(**kwargs)`** validates the keyword arguments against the 
-generated context schema, wraps them in a `FlowContext`, and calls the model. -It returns the same result object you would get from `model(context)`. - -**`.flow.with_inputs(**transforms)`** returns a `BoundModel` that applies -context transforms before delegating to the underlying model. Each transform -is either a static value or a `(ctx) -> value` callable. Transforms are local -to the wrapped model — upstream models never see them. - -```python -from ccflow import Flow, FlowContext +`model.flow.compute(...)` accepts either contextual keyword arguments or one +context object and returns the same result as `model(context)`. -@Flow.model -def add(x: int, y: int) -> int: - return x + y +`model.flow.with_inputs(...)` rewrites contextual inputs on one dependency edge. +It only accepts contextual fields and remains branch-local. -model = add(x=10) -assert model.flow.compute(y=5).value == 15 +`model.pipe(...)` is a secondary helper for wiring one upstream model into a +downstream regular parameter. It never targets `FromContext[...]` fields. -shifted = model.flow.with_inputs(y=lambda ctx: ctx.y * 2) -assert shifted.flow.compute(y=5).value == 20 +Generated models also expose introspection helpers: -# You can also call with a context object directly -ctx = FlowContext(y=5) -assert model(ctx).value == 15 -assert shifted(ctx).value == 20 +```python +model = add(a=10) +model.flow.context_inputs # {"b": int} — the full contextual contract +model.flow.unbound_inputs # {"b": int} — contextual fields still needed at runtime +model.flow.bound_inputs # {"a": 10} — all construction-time values ``` -#### Lazy Dependencies with `Lazy[T]` +#### Lazy Dependencies -Mark a parameter with `Lazy[T]` to defer its evaluation. Instead of eagerly -resolving the upstream model, the generated model passes a zero-argument thunk -that evaluates on first call and caches the result. The thunk unwraps -`GenericResult` automatically, so `T` should be the inner value type. 
+`Lazy[T]` is the lazy type-level marker for dependency parameters. ```python -from ccflow import ContextBase, Flow, GenericResult, Lazy +from ccflow import Flow, FromContext, Lazy -class SimpleContext(ContextBase): - value: int @Flow.model -def fast_path(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) +def load_value(value: FromContext[int]) -> int: + return value * 10 -@Flow.model -def slow_path(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value * 100) @Flow.model -def smart_selector( - context: SimpleContext, - fast: int, # Eagerly resolved and unwrapped - slow: Lazy[int], # Deferred — receives a thunk returning unwrapped int - threshold: int = 10, -) -> GenericResult[int]: - if fast > threshold: - return GenericResult(value=fast) - return GenericResult(value=slow()) # Evaluated only when called - -model = smart_selector( - fast=fast_path(), - slow=slow_path(), - threshold=10, -) -``` - -`Lazy` dependencies are excluded from the model's `__deps__` graph, so they -are not pre-evaluated by the evaluator infrastructure. - -#### Decorator Options - -`@Flow.model(...)` accepts the same options as `Flow.call` to control execution -behavior: - -- `cacheable` — enable caching of results -- `volatile` — mark as volatile (always re-execute) -- `log_level` — logging verbosity -- `validate_result` — validate return type -- `verbose` — verbose logging output -- `evaluator` — custom evaluator - -When not explicitly set, these inherit from any active `FlowOptionsOverride`. 
- -#### Hydra / YAML Configuration - -`@Flow.model` decorated functions work seamlessly with Hydra configuration and -the `ModelRegistry`: - -```yaml -# config.yaml -data: - _target_: mymodule.load_data - source: my_database - -transformed: - _target_: mymodule.transform_data - raw_data: data # Reference by registry name (same instance is shared) - -aggregated: - _target_: mymodule.aggregate_data - transformed: transformed # Reference by registry name -``` - -```python -from ccflow import ModelRegistry - -registry = ModelRegistry.root() -registry.load_config_from_path("config.yaml") - -# References by name ensure the same object instance is shared -model = registry["aggregated"] +def choose(current: int, deferred: Lazy[int], threshold: FromContext[int]) -> int: + if current > threshold: + return current + return deferred() ``` -### Flow.call with `auto_context` - -For class-based `CallableModel`s, `Flow.call(auto_context=...)` provides a -similar convenience. Instead of defining a separate `ContextBase` subclass, the -decorator generates one from the function's keyword-only parameters. - -```python -from ccflow import CallableModel, Flow, GenericResult - -class MyModel(CallableModel): - @Flow.call(auto_context=True) - def __call__(self, *, x: int, y: str = "default") -> GenericResult: - return GenericResult(value=f"{x}-{y}") +Use `Lazy[T]` when a dependency is expensive and the function should decide +whether to execute it. -model = MyModel() -result = model(x=42, y="hello") -assert result.value == "42-hello" -``` +## Model Registry -You can also pass a parent context class so the generated context inherits -from it: +The model registry lets you register models by name and resolve them later, +including from config-driven workflows. 
-```python -from datetime import date -from ccflow import CallableModel, DateContext, Flow, GenericResult +- root registry access: `ModelRegistry.root()` +- add and remove models by name +- reuse shared instances through registry references -class MyModel(CallableModel): - @Flow.call(auto_context=DateContext) - def __call__(self, *, date: date, extra: int = 0) -> GenericResult: - return GenericResult(value=date.day + extra) -``` +## Models -The generated context class is a proper `ContextBase` subclass, so it works -with all existing evaluator and registry infrastructure. +The `ccflow.models` package contains concrete model implementations that build +on the framework primitives. -## Model Registry +Use these when you want reusable, prebuilt model classes instead of authoring +your own `CallableModel` or `@Flow.model` stage. -A `ModelRegistry` is a named collection of models. -A `ModelRegistry` can be loaded from YAML configuration, which means you can define a collection of models using YAML. -This is really powerful because this gives you a easy way to define a collection of Python objects via configuration. +## Publishers -## Models +Publishers handle result publication and side-effectful output sinks. -Although you are free to define your own models (`BaseModel` implementations) to use in your flow graph, -`ccflow` comes with some models that you can use off the shelf to solve common problems. `ccflow` comes with a range of models for reading data. +They are useful when a workflow result needs to be written to an external +system rather than only returned to the caller. -The following table summarizes the available models. +## Evaluators -> [!NOTE] -> -> Some models are still in the process of being open sourced. +Evaluators control how `CallableModel`s execute. -## Publishers +Key point for `@Flow.model`: it does not create a new execution engine. It +authors models that still run through the existing evaluator stack. 
-`ccflow` also comes with a range of models for writing data. -These are referred to as publishers. -You can "chain" publishers and callable models using `PublisherModel` to call a `CallableModel` and publish the results in one step. -In fact, `ccflow` comes with several implementations of `PublisherModel` for common publishing use cases. - -The following table summarizes the "publisher" models. - -> [!NOTE] -> -> Some models are still in the process of being open sourced. - -| Name | Path | Description | -| :--------------------------- | :------------------ | :---------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `DictTemplateFilePublisher` | `ccflow.publishers` | Publish data to a file after populating a Jinja template. | -| `GenericFilePublisher` | `ccflow.publishers` | Publish data using a generic "dump" Callable. Uses `smart_open` under the hood so that local and cloud paths are supported. | -| `JSONPublisher` | `ccflow.publishers` | Publish data to file in JSON format. | -| `PandasFilePublisher` | `ccflow.publishers` | Publish a pandas data frame to a file using an appropriate method on pd.DataFrame. For large-scale exporting (using parquet), see `PandasParquetPublisher`. | -| `NarwhalsFilePublisher` | `ccflow.publishers` | Publish a narwhals data frame to a file using an appropriate method on nw.DataFrame. | -| `PicklePublisher` | `ccflow.publishers` | Publish data to a pickle file. | -| `PydanticJSONPublisher` | `ccflow.publishers` | Publish a pydantic model to a json file. See [Pydantic modeljson](https://docs.pydantic.dev/latest/concepts/serialization/#modelmodel_dump) | -| `YAMLPublisher` | `ccflow.publishers` | Publish data to file in YAML format. | -| `CompositePublisher` | `ccflow.publishers` | Highly configurable, publisher that decomposes a pydantic BaseModel or a dictionary into pieces and publishes each piece separately. 
| -| `PrintPublisher` | `ccflow.publishers` | Print data using python standard print. | -| `LogPublisher` | `ccflow.publishers` | Print data using python standard logging. | -| `PrintJSONPublisher` | `ccflow.publishers` | Print data in JSON format. | -| `PrintYAMLPublisher` | `ccflow.publishers` | Print data in YAML format. | -| `PrintPydanticJSONPublisher` | `ccflow.publishers` | Print pydantic model as json. See https://docs.pydantic.dev/latest/concepts/serialization/#modelmodel_dump_json | -| `ArrowDatasetPublisher` | *Coming Soon!* | | -| `PandasDeltaPublisher` | *Coming Soon!* | | -| `EmailPublisher` | *Coming Soon!* | | -| `MatplotlibFilePublisher` | *Coming Soon!* | | -| `MLFlowArtifactPublisher` | *Coming Soon!* | | -| `MLFlowPublisher` | *Coming Soon!* | | -| `PandasParquetPublisher` | *Coming Soon!* | | -| `PlotlyFilePublisher` | *Coming Soon!* | | -| `XArrayPublisher` | *Coming Soon!* | | +Depending on your evaluator setup, you can add logging, caching, graph-aware +execution, or custom execution policies. -## Evaluators +## Results -`ccflow` comes with "evaluators" that allows you to evaluate (i.e. run) `CallableModel` s in different ways. - -The following table summarizes the "evaluator" models. - -> [!NOTE] -> -> Some models are still in the process of being open sourced. - -| Name | Path | Description | -| :---------------------------------- | :------------------ | :----------------------------------------------------------------------------------------------------------------------------- | -| `LazyEvaluator` | `ccflow.evaluators` | Evaluator that only actually runs the callable once an attribute of the result is queried (by hooking into `__getattribute__`) | -| `LoggingEvaluator` | `ccflow.evaluators` | Evaluator that logs information about evaluating the callable. | -| `MemoryCacheEvaluator` | `ccflow.evaluators` | Evaluator that caches results in memory. | -| `MultiEvaluator` | `ccflow.evaluators` | An evaluator that combines multiple evaluators. 
| -| `GraphEvaluator` | `ccflow.evaluators` | Evaluator that evaluates the dependency graph of callable models in topologically sorted order. | -| `ChunkedDateRangeEvaluator` | *Coming Soon!* | | -| `ChunkedDateRangeResultsAggregator` | *Coming Soon!* | | -| `RayChunkedDateRangeEvaluator` | *Coming Soon!* | | -| `DependencyTrackingEvaluator` | *Coming Soon!* | | -| `DiskCacheEvaluator` | *Coming Soon!* | | -| `ParquetCacheEvaluator` | *Coming Soon!* | | -| `RayCacheEvaluator` | *Coming Soon!* | | -| `RayGraphEvaluator` | *Coming Soon!* | | -| `RayDelayedDistributedEvaluator` | *Coming Soon!* | | -| `ParquetCacheEvaluator` | *Coming Soon!* | | -| `RetryEvaluator` | *Coming Soon!* | | +`ResultBase` is the common base class for workflow results. -## Results +`GenericResult[T]` is the default wrapper used when: -A Result is an object that holds the results from a callable model. It provides the equivalent of a strongly typed dictionary where the keys and schema are known upfront. - -The following table summarizes the "result" models. - -| Name | Path | Description | -| :------------------------ | :----------------------- | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `GenericResult` | `ccflow.result` | A generic result (holds anything). | -| `DictResult` | `ccflow.result` | A generic dict (key/value) result. | -| `ArrowResult` | `ccflow.result.pyarrow` | Holds an arrow table. | -| `ArrowDateRangeResult` | `ccflow.result.pyarrow` | Extension of `ArrowResult` for representing a table over a date range that can be divided by date, such that generation of any sub-range of dates gives the same results as the original table filtered for those dates. | -| `NarwhalsResult` | `ccflow.result.narwhals` | Holds a narwhals `DataFrame` or `LazyFrame`. 
| -| `NarwhalsDataFrameResult` | `ccflow.result.narwhals` | Holds a narwhals eager `DataFrame`. | -| `NumpyResult` | `ccflow.result.numpy` | Holds a numpy array. | -| `PandasResult` | `ccflow.result.pandas` | Holds a pandas dataframe. | -| `XArrayResult` | `ccflow.result.xarray` | Holds an xarray. | +- a model naturally wants one value payload, or +- a `@Flow.model` function returns a plain Python value instead of a concrete + `ResultBase` subclass. diff --git a/examples/evaluator_demo.py b/examples/evaluator_demo.py deleted file mode 100644 index a85b087..0000000 --- a/examples/evaluator_demo.py +++ /dev/null @@ -1,186 +0,0 @@ -#!/usr/bin/env python -""" -Evaluator Demo: Caching & Execution Strategies -=============================================== - -Shows how to change execution behavior (caching, graph evaluation, logging) -WITHOUT changing user code. The same @Flow.model functions work with any -evaluator stack — you just configure it at the top level. - -Key insight: "default lazy" is an evaluator concern, not a wiring concern. -Users write plain functions and wire them by passing outputs as inputs. -The evaluator layer controls how they execute. - -Demonstrates: - 1. Default execution (eager, no caching) — diamond dep calls load twice - 2. MemoryCacheEvaluator — deduplicates shared deps in a diamond - 3. GraphEvaluator + Cache — topological evaluation + deduplication - 4. LoggingEvaluator — adds tracing around every model call - 5. 
Per-model opt-out — @Flow.model(cacheable=False) overrides global - -Run with: python examples/evaluator_demo.py -""" - -from __future__ import annotations - -import logging -import sys - -# Suppress default debug logging from ccflow evaluators for clean demo output -logging.disable(logging.DEBUG) - -from ccflow import Flow, FlowOptionsOverride # noqa: E402 -from ccflow.evaluators.common import ( # noqa: E402 - GraphEvaluator, - LoggingEvaluator, - MemoryCacheEvaluator, - MultiEvaluator, -) - -# ============================================================================= -# Plain @Flow.model functions — no evaluator concerns in the code -# ============================================================================= - -call_counts: dict[str, int] = {} - - -def _track(name: str) -> None: - call_counts[name] = call_counts.get(name, 0) + 1 - - -@Flow.model -def load_data(x: int, source: str = "warehouse") -> list: - """Load raw data. Expensive — we want to avoid calling this twice.""" - _track("load_data") - return [x, x * 2, x * 3] - - -@Flow.model -def compute_sum(data: list) -> int: - """Branch A: sum the data.""" - _track("compute_sum") - return sum(data) - - -@Flow.model -def compute_max(data: list) -> int: - """Branch B: max of the data.""" - _track("compute_max") - return max(data) - - -@Flow.model -def combine(sum_result: int, max_result: int) -> dict: - """Combine results from both branches.""" - _track("combine") - return {"sum": sum_result, "max": max_result, "total": sum_result + max_result} - - -@Flow.model(cacheable=False) -def volatile_timestamp(seed: int) -> str: - """Explicitly non-cacheable — always re-executes even with global caching.""" - _track("volatile_timestamp") - from datetime import datetime - - return datetime.now().isoformat() - - -# ============================================================================= -# Wire the pipeline — diamond dependency on load_data -# -# load_data ──┬── compute_sum ──┐ -# └── compute_max ──┴── combine -# 
============================================================================= - -shared = load_data(source="prod") -branch_a = compute_sum(data=shared) -branch_b = compute_max(data=shared) -pipeline = combine(sum_result=branch_a, max_result=branch_b) - - -def run() -> dict: - call_counts.clear() - result = pipeline.flow.compute(x=5) - loads = call_counts.get("load_data", 0) - print(f" Result: {result.value}") - print(f" load_data called: {loads}x | total model calls: {sum(call_counts.values())}") - return result.value - - -# ============================================================================= -# Demo 1: Default — no evaluator -# ============================================================================= - -print("=" * 70) -print("1. Default (eager, no caching)") -print(" load_data is called TWICE — once per branch") -print("=" * 70) -run() - -# ============================================================================= -# Demo 2: MemoryCacheEvaluator — deduplicates shared deps -# ============================================================================= - -print() -print("=" * 70) -print("2. MemoryCacheEvaluator (global override)") -print(" load_data is called ONCE — second branch hits cache") -print("=" * 70) -with FlowOptionsOverride(options={"evaluator": MemoryCacheEvaluator(), "cacheable": True}): - run() - -# ============================================================================= -# Demo 3: Cache + GraphEvaluator — topological order + deduplication -# ============================================================================= - -print() -print("=" * 70) -print("3. 
GraphEvaluator + MemoryCacheEvaluator") -print(" Evaluates in dependency order: load_data → branches → combine") -print("=" * 70) -evaluator = MultiEvaluator(evaluators=[MemoryCacheEvaluator(), GraphEvaluator()]) -with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): - run() - -# ============================================================================= -# Demo 4: Logging — trace every model call -# ============================================================================= - -print() -print("=" * 70) -print("4. LoggingEvaluator + MemoryCacheEvaluator") -print(" Adds timing/tracing around every evaluation") -print("=" * 70) - -# Re-enable logging for this demo (use stdout so log lines interleave with print correctly) -logging.disable(logging.NOTSET) -logging.basicConfig(level=logging.INFO, format=" LOG: %(message)s", stream=sys.stdout) - -evaluator = MultiEvaluator(evaluators=[LoggingEvaluator(log_level=logging.INFO), MemoryCacheEvaluator()]) -with FlowOptionsOverride(options={"evaluator": evaluator, "cacheable": True}): - run() - -# Suppress again for clean output -logging.disable(logging.DEBUG) -logging.getLogger().handlers.clear() - -# ============================================================================= -# Demo 5: Per-model opt-out — cacheable=False overrides global -# ============================================================================= - -print() -print("=" * 70) -print("5. Per-model opt-out: @Flow.model(cacheable=False)") -print(" volatile_timestamp always re-executes despite global cacheable=True") -print("=" * 70) - -ts = volatile_timestamp(seed=0) - -with FlowOptionsOverride(options={"evaluator": MemoryCacheEvaluator(), "cacheable": True}): - call_counts.clear() - r1 = ts.flow.compute(seed=0) - r2 = ts.flow.compute(seed=0) - print(f" Call 1: {r1.value}") - print(f" Call 2: {r2.value}") - print(f" volatile_timestamp called: {call_counts.get('volatile_timestamp', 0)}x") - print(f" (Same result? 
{r1.value == r2.value} — called twice, timestamps may differ)") diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py index 27d5d0e..c22e852 100644 --- a/examples/flow_model_example.py +++ b/examples/flow_model_example.py @@ -14,11 +14,11 @@ from datetime import date, timedelta -from ccflow import DateRangeContext, Flow +from ccflow import DateRangeContext, Flow, FromContext -@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) -def load_revenue(start_date: date, end_date: date, region: str) -> float: +@Flow.model(context_type=DateRangeContext) +def load_revenue(region: str, start_date: FromContext[date], end_date: FromContext[date]) -> float: """Return synthetic revenue for one reporting window.""" days = (end_date - start_date).days + 1 region_base = {"us": 1000.0, "eu": 850.0}.get(region, 900.0) @@ -27,14 +27,14 @@ def load_revenue(start_date: date, end_date: date, region: str) -> float: return round(region_base + days * 8.0 + trend, 2) -@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +@Flow.model(context_type=DateRangeContext) def revenue_change( - start_date: date, - end_date: date, current: float, previous: float, label: str, days_back: int, + start_date: FromContext[date], + end_date: FromContext[date], ) -> dict: """Compare the current window against a shifted previous window.""" previous_start = start_date - timedelta(days=days_back) diff --git a/examples/flow_model_hydra_builder_demo.py b/examples/flow_model_hydra_builder_demo.py index 00c9571..0e2aaa1 100644 --- a/examples/flow_model_hydra_builder_demo.py +++ b/examples/flow_model_hydra_builder_demo.py @@ -19,33 +19,16 @@ from calendar import monthrange from datetime import date, timedelta from pathlib import Path -from typing import Literal, Protocol, cast +from typing import Literal -from ccflow import BoundModel, CallableModel, DateRangeContext, Flow, FlowAPI, GenericResult, ModelRegistry -from typing_extensions 
import TypedDict +from ccflow import BoundModel, CallableModel, DateRangeContext, Flow, FromContext, ModelRegistry CONFIG_PATH = Path(__file__).with_name("config") / "flow_model_hydra_builder_demo.yaml" ComparisonName = Literal["week_over_week", "month_over_month"] -class RevenueChangeResult(TypedDict): - comparison: ComparisonName - current_window: str - previous_window: str - current: float - previous: float - delta: float - growth_pct: float - - -class RevenueChangeModel(Protocol): - flow: FlowAPI - - def __call__(self, context: DateRangeContext) -> GenericResult[RevenueChangeResult]: ... - - -@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) -def load_revenue(start_date: date, end_date: date, region: str) -> float: +@Flow.model(context_type=DateRangeContext) +def load_revenue(region: str, start_date: FromContext[date], end_date: FromContext[date]) -> float: """Return synthetic revenue for a date window.""" days = (end_date - start_date).days + 1 region_base = {"us": 1000.0, "eu": 850.0, "apac": 920.0}.get(region, 900.0) @@ -54,14 +37,14 @@ def load_revenue(start_date: date, end_date: date, region: str) -> float: return round(region_base + days * 8.0 + trend, 2) -@Flow.model(context_args=["start_date", "end_date"], context_type=DateRangeContext) +@Flow.model(context_type=DateRangeContext) def revenue_change( - start_date: date, - end_date: date, current: float, previous: float, comparison: ComparisonName, -) -> RevenueChangeResult: + start_date: FromContext[date], + end_date: FromContext[date], +) -> dict: """Compare the current window against a shifted previous window.""" growth = (current - previous) / previous previous_start, previous_end = comparison_window(start_date, end_date, comparison) @@ -104,7 +87,7 @@ def comparison_input(model: CallableModel, comparison: ComparisonName) -> BoundM ) -def build_comparison(current: CallableModel, *, comparison: ComparisonName) -> RevenueChangeModel: +def build_comparison(current: 
CallableModel, *, comparison: ComparisonName): """Hydra-friendly builder that returns a configured comparison model.""" previous = comparison_input(current, comparison) return revenue_change( @@ -120,8 +103,8 @@ def main() -> None: try: registry.load_config_from_path(str(CONFIG_PATH), overwrite=True) - week_over_week = cast(RevenueChangeModel, registry["week_over_week"]) - month_over_month = cast(RevenueChangeModel, registry["month_over_month"]) + week_over_week = registry["week_over_week"] + month_over_month = registry["month_over_month"] ctx = DateRangeContext( start_date=date(2024, 3, 1), @@ -136,13 +119,10 @@ def main() -> None: print(" week_over_week:", week_over_week) print(" month_over_month:", month_over_month) - week_over_week_result = cast( - RevenueChangeResult, - week_over_week.flow.compute( - start_date=ctx.start_date, - end_date=ctx.end_date, - ).value, - ) + week_over_week_result = week_over_week.flow.compute( + start_date=ctx.start_date, + end_date=ctx.end_date, + ).value month_over_month_result = month_over_month(ctx).value print("\nWeek-over-week:") diff --git a/examples/ml_pipeline_demo.py b/examples/ml_pipeline_demo.py deleted file mode 100644 index 0c8920f..0000000 --- a/examples/ml_pipeline_demo.py +++ /dev/null @@ -1,351 +0,0 @@ -#!/usr/bin/env python -""" -ML Pipeline Demo: Smart Model Selection -======================================== - -This is the example from the original design conversation — a realistic ML -pipeline that demonstrates how Flow.model lets you write plain functions, -wire them by passing outputs as inputs, and execute with .flow.compute(). - -Features demonstrated: - 1. @Flow.model with auto-wrap (plain return types, no GenericResult needed) - 2. Lazy[T] for conditional evaluation (skip slow model if fast is good enough) - 3. .flow.compute() for execution with automatic context propagation - 4. .flow.with_inputs() for context transforms (lookback windows) - 5. 
Factored wiring — build_pipeline() shows how to reuse the same graph - structure with different data sources - -The pipeline: - - load_dataset ──> prepare_features ──> train_linear ──> evaluate ──> fast_metrics ──┐ - └──> train_forest ──> evaluate ──> slow_metrics ──┴──> smart_training - -Run with: python examples/ml_pipeline_demo.py -""" - -from __future__ import annotations - -from dataclasses import dataclass -from datetime import date, timedelta -from math import sin - -from ccflow import Flow, Lazy - - -# ============================================================================= -# Domain types (stand-ins for real ML objects) -# ============================================================================= - - -@dataclass -class PreparedData: - """Container for train/test split data.""" - - X_train: list # list of feature vectors - X_test: list - y_train: list # list of target values - y_test: list - - -@dataclass -class TrainedModel: - """A fitted model (placeholder).""" - - name: str - coefficients: list - intercept: float - augment: bool # Whether to add sin feature during prediction - - -@dataclass -class Metrics: - """Evaluation metrics.""" - - r2: float - mse: float - model_name: str - - -# ============================================================================= -# Data Loading -# ============================================================================= - - -@Flow.model -def load_dataset(start_date: date, end_date: date, source: str = "warehouse") -> list: - """Load raw dataset for a date range. - - Returns a list of dicts (standing in for a DataFrame). - Auto-wrapped: returns plain list, framework wraps in GenericResult. 
- """ - n_days = (end_date - start_date).days + 1 - print(f" [load_dataset] Loading {n_days} days from '{source}' ({start_date} to {end_date})") - # True relationship: target = 2.0 * x + 10.0 + 15.0 * sin(x * 0.2) - # Linear model captures the trend (R^2 ~0.93), forest also captures the sin wave (~0.99) - return [ - { - "date": str(start_date + timedelta(days=i)), - "x": float(i), - "target": 2.0 * i + 10.0 + 15.0 * sin(i * 0.2), - } - for i in range(n_days) - ] - - -# ============================================================================= -# Feature Engineering -# ============================================================================= - - -@Flow.model -def prepare_features(raw_data: list) -> PreparedData: - """Split data into train/test. - - Returns a PreparedData dataclass — the framework auto-wraps it in GenericResult. - Downstream models can request individual fields via prepared["X_train"] etc. - """ - n = len(raw_data) - split = int(n * 0.8) - print(f" [prepare_features] {n} rows, split at {split}") - - X = [[r["x"]] for r in raw_data] - y = [r["target"] for r in raw_data] - - return PreparedData( - X_train=X[:split], - X_test=X[split:], - y_train=y[:split], - y_test=y[split:], - ) - - -# ============================================================================= -# Model Training -# ============================================================================= - - -def _ols_fit(X, y): - """Simple OLS: compute coefficients and intercept.""" - n = len(X) - n_feat = len(X[0]) - y_mean = sum(y) / n - x_means = [sum(row[j] for row in X) / n for j in range(n_feat)] - - coefficients = [] - for j in range(n_feat): - cov = sum((X[i][j] - x_means[j]) * (y[i] - y_mean) for i in range(n)) / n - var = sum((X[i][j] - x_means[j]) ** 2 for i in range(n)) / n - coefficients.append(cov / var if var > 1e-10 else 0.0) - - intercept = y_mean - sum(c * m for c, m in zip(coefficients, x_means)) - return coefficients, intercept - - -def _augment(X): - """Add sin(x*0.2) 
feature to capture non-linearity.""" - return [row + [sin(row[0] * 0.2)] for row in X] - - -@Flow.model -def train_linear(prepared: PreparedData) -> TrainedModel: - """Train a fast linear model (linear features only).""" - print(f" [train_linear] Fitting on {len(prepared.X_train)} samples") - coefficients, intercept = _ols_fit(prepared.X_train, prepared.y_train) - return TrainedModel(name="LinearRegression", coefficients=coefficients, intercept=intercept, augment=False) - - -@Flow.model -def train_forest(prepared: PreparedData, n_estimators: int = 100) -> TrainedModel: - """Train a model that also captures non-linear patterns (simulated).""" - print(f" [train_forest] Fitting {n_estimators} trees on {len(prepared.X_train)} samples") - # Augment with sin feature to capture non-linearity - X_aug = _augment(prepared.X_train) - coefficients, intercept = _ols_fit(X_aug, prepared.y_train) - return TrainedModel( - name=f"RandomForest(n={n_estimators})", - coefficients=coefficients, - intercept=intercept, - augment=True, - ) - - -# ============================================================================= -# Model Evaluation -# ============================================================================= - - -@Flow.model -def evaluate_model(model: TrainedModel, prepared: PreparedData) -> Metrics: - """Evaluate a trained model on test data.""" - X_test = prepared.X_test - y_test = prepared.y_test - X_eval = _augment(X_test) if model.augment else X_test - - y_pred = [ - model.intercept + sum(c * x for c, x in zip(model.coefficients, row)) - for row in X_eval - ] - - y_mean = sum(y_test) / len(y_test) if y_test else 0 - ss_tot = sum((y - y_mean) ** 2 for y in y_test) or 1 - ss_res = sum((yt - yp) ** 2 for yt, yp in zip(y_test, y_pred)) - r2 = 1.0 - ss_res / ss_tot - mse = ss_res / len(y_test) if y_test else 0 - - print(f" [evaluate_model] {model.name}: R^2={r2:.4f}, MSE={mse:.2f}") - return Metrics(r2=r2, mse=mse, model_name=model.name) - - -# 
============================================================================= -# Smart Pipeline with Conditional Execution -# ============================================================================= - - -@Flow.model -def smart_training( - # data: PreparedData, - fast_metrics: Metrics, - slow_metrics: Lazy[Metrics], # Only evaluated if fast isn't good enough - threshold: float = 0.9, -) -> Metrics: - """Use fast model if good enough, else fall back to slow. - - The slow_metrics parameter is Lazy — it receives a zero-arg thunk. - If the fast model exceeds the threshold, the slow model is never - trained or evaluated at all. - """ - print(f" [smart_training] Fast R^2={fast_metrics.r2:.4f}, threshold={threshold}") - if fast_metrics.r2 >= threshold: - print(" [smart_training] Fast model is good enough! Skipping slow model.") - return fast_metrics - else: - print(" [smart_training] Fast model below threshold, evaluating slow model...") - return slow_metrics() - - -# ============================================================================= -# Pipeline Wiring Helper -# ============================================================================= - - -def build_pipeline(raw, *, n_estimators=200, threshold=0.95): - """Wire a complete train/evaluate/select pipeline from a data source. - - This function shows the flexibility of the approach: the same wiring - logic can be applied to different data sources (raw, lookback_raw, etc.) - without duplicating code. Everything here is just wiring — no computation - happens until .flow.compute() is called. - - Args: - raw: A CallableModel or BoundModel that produces raw data (list of dicts) - n_estimators: Number of trees for the forest model - threshold: R^2 threshold for the fast/slow model selection - - Returns: - A smart_training model instance ready for .flow.compute() - """ - # Feature engineering — returns a PreparedData with X_train, X_test, etc. 
- prepared = prepare_features(raw_data=raw) - - # Train both models — each receives the whole PreparedData and extracts - # the fields it needs internally. - linear = train_linear(prepared=prepared) - forest = train_forest(prepared=prepared, n_estimators=n_estimators) - - # Evaluate both - linear_metrics = evaluate_model(model=linear, prepared=prepared) - forest_metrics = evaluate_model(model=forest, prepared=prepared) - - # Smart selection with Lazy — forest is only evaluated if linear isn't good enough - return smart_training( - fast_metrics=linear_metrics, - slow_metrics=forest_metrics, - threshold=threshold, - ) - - -# ============================================================================= -# Main: Wire and execute the pipeline -# ============================================================================= - - -def main(): - print("=" * 70) - print("ML Pipeline Demo: Smart Model Selection with Flow.model") - print("=" * 70) - - # ------------------------------------------------------------------ - # Step 1: Wire the pipeline (no computation happens here) - # ------------------------------------------------------------------ - print("\n--- Wiring the pipeline (lazy, no computation yet) ---\n") - - raw = load_dataset(source="prod_warehouse") - - # build_pipeline factors out the repeated wiring logic. - # Linear R^2 ≈ 0.93. Threshold is 0.95 → falls through to forest. - pipeline = build_pipeline(raw, n_estimators=200, threshold=0.95) - - print("Pipeline wired. 
No functions have been called yet.") - - # ------------------------------------------------------------------ - # Step 2: Execute — linear not good enough, falls back to forest - # ------------------------------------------------------------------ - print("\n--- Executing pipeline (Jan-Jun 2024) ---\n") - result = pipeline.flow.compute( - start_date=date(2024, 1, 1), - end_date=date(2024, 6, 30), - ) - - print(f"\n Best model: {result.value.model_name}") - print(f" R^2: {result.value.r2:.4f}") - print(f" MSE: {result.value.mse:.2f}") - - # ------------------------------------------------------------------ - # Step 3: Context transforms (lookback) — reuse build_pipeline - # ------------------------------------------------------------------ - print("\n" + "=" * 70) - print("With Lookback: Same pipeline structure, extra history for loading") - print("=" * 70) - - # flow.with_inputs() creates a BoundModel that transforms the context - # before calling the underlying model. start_date is shifted 30 days earlier. - lookback_raw = raw.flow.with_inputs( - start_date=lambda ctx: ctx.start_date - timedelta(days=30) - ) - - # Same wiring logic, different data source — no duplication. 
- lookback_pipeline = build_pipeline(lookback_raw, n_estimators=200, threshold=0.95) - - print("\n--- Executing lookback pipeline ---\n") - result2 = lookback_pipeline.flow.compute( - start_date=date(2024, 1, 1), - end_date=date(2024, 6, 30), - ) - # Notice: load_dataset gets start_date=2023-12-02 (30 days earlier) - - print(f"\n Best model: {result2.value.model_name}") - print(f" R^2: {result2.value.r2:.4f}") - print(f" MSE: {result2.value.mse:.2f}") - - # ------------------------------------------------------------------ - # Step 4: Lower threshold — linear is good enough, skip forest - # ------------------------------------------------------------------ - print("\n" + "=" * 70) - print("Lazy Evaluation: Lower threshold so fast model is good enough") - print("=" * 70) - - # With threshold=0.80, the linear model's R^2 (~0.93) passes. - # The forest is NEVER trained or evaluated — Lazy skips it entirely. - fast_pipeline = build_pipeline(raw, n_estimators=200, threshold=0.80) - - print("\n--- Executing (slow model should NOT be trained) ---\n") - result3 = fast_pipeline.flow.compute( - start_date=date(2024, 1, 1), - end_date=date(2024, 6, 30), - ) - print(f"\n Selected: {result3.value.model_name} (R^2={result3.value.r2:.4f})") - print(" (Notice: train_forest and its evaluate_model were never called)") - - -if __name__ == "__main__": - main() From d5ca79e65aab082a6b373134e1d38f05deae0c8d Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Wed, 8 Apr 2026 13:16:19 -0400 Subject: [PATCH 19/26] Clean-up Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 6 -- ccflow/flow_model.py | 100 +------------------------------ ccflow/tests/test_flow_model.py | 16 ----- docs/design/flow_model_design.md | 56 ++++++++--------- docs/wiki/Key-Features.md | 12 ++-- examples/flow_model_example.py | 18 +++--- 6 files changed, 43 insertions(+), 165 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index b77c205..a0c10fa 100644 --- a/ccflow/callable.py +++ 
b/ccflow/callable.py @@ -817,12 +817,6 @@ def flow(self) -> "FlowAPI": return FlowAPI(self) - def pipe(self, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: - """Wire this model into a downstream generated ``@Flow.model`` stage.""" - from .flow_model import pipe_model - - return pipe_model(self, stage, param=param, **bindings) - class WrapperModel(CallableModel, Generic[CallableModelType], abc.ABC): """Abstract class that represents a wrapper around an underlying model, with the same context and return types. diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index fbd5ccb..108c08f 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -337,99 +337,6 @@ def _generated_model_class(stage: Any) -> Optional[type["_GeneratedFlowModelBase return None -def _describe_pipe_stage(stage: Any) -> str: - if isinstance(stage, BoundModel): - return repr(stage) - if isinstance(stage, _GeneratedFlowModelBase): - return repr(stage) - if callable(stage): - return _callable_name(stage) - return repr(stage) - - -def _generated_model_explicit_kwargs(model: "_GeneratedFlowModelBase") -> Dict[str, Any]: - values = cast(Dict[str, Any], model.model_dump(mode="python", exclude_unset=True)) - values.pop("type_", None) - values.pop("_target_", None) - values.pop("meta", None) - return values - - -def _resolve_pipe_param(source: Any, stage: Any, param: Optional[str], bindings: Dict[str, Any]) -> Tuple[str, type["_GeneratedFlowModelBase"]]: - del source - - generated_model_cls = _generated_model_class(stage) - if generated_model_cls is None: - raise TypeError("pipe() only supports downstream stages created by @Flow.model or bound versions of those stages.") - - config = generated_model_cls.__flow_model_config__ - stage_name = _describe_pipe_stage(stage) - regular_names = list(config.regular_param_names) - - generated_model = _generated_model_instance(stage) - occupied_names = set(bindings) - if generated_model is not None: - occupied_names |= {name for name 
in _bound_field_names(generated_model) if name in regular_names} - - if param is not None: - if param in config.contextual_param_names: - raise TypeError(f"pipe() target parameter '{param}' on {stage_name} is contextual. Use .flow.with_inputs(...) instead.") - if param not in regular_names: - valid = ", ".join(regular_names) or "" - raise TypeError(f"pipe() target parameter '{param}' is not valid for {stage_name}. Available regular parameters: {valid}.") - if param in occupied_names: - raise TypeError(f"pipe() target parameter '{param}' is already bound for {stage_name}.") - return param, generated_model_cls - - required_candidates = [p.name for p in config.regular_params if not p.has_function_default and p.name not in occupied_names] - if len(required_candidates) == 1: - return required_candidates[0], generated_model_cls - if len(required_candidates) > 1: - candidates = ", ".join(required_candidates) - raise TypeError( - f"pipe() could not infer a target parameter for {stage_name}; unbound regular candidates are: {candidates}. Pass param='...'." - ) - - fallback_candidates = [name for name in regular_names if name not in occupied_names] - if len(fallback_candidates) == 1: - return fallback_candidates[0], generated_model_cls - if len(fallback_candidates) > 1: - candidates = ", ".join(fallback_candidates) - raise TypeError( - f"pipe() could not infer a target parameter for {stage_name}; unbound regular candidates are: {candidates}. Pass param='...'." 
- ) - - raise TypeError(f"pipe() could not find an available regular target parameter for {stage_name}.") - - -def pipe_model(source: Any, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: - """Wire ``source`` into a downstream generated ``@Flow.model`` stage.""" - - if not _is_model_dependency(source): - raise TypeError(f"pipe() source must be a CallableModel, got {type(source).__name__}.") - - target_param, generated_model_cls = _resolve_pipe_param(source, stage, param, bindings) - build_kwargs = dict(bindings) - build_kwargs[target_param] = source - - if isinstance(stage, BoundModel): - generated_model = _generated_model_instance(stage) - if generated_model is None: - raise TypeError("pipe() only supports downstream BoundModel stages created from @Flow.model.") - explicit_kwargs = _generated_model_explicit_kwargs(generated_model) - explicit_kwargs.update(build_kwargs) - rebound_model = generated_model_cls(**explicit_kwargs) - return BoundModel(model=rebound_model, input_transforms=dict(stage._input_transforms)) - - generated_model = _generated_model_instance(stage) - if generated_model is not None: - explicit_kwargs = _generated_model_explicit_kwargs(generated_model) - explicit_kwargs.update(build_kwargs) - return generated_model_cls(**explicit_kwargs) - - return stage(**build_kwargs) - - def _context_input_types_for_model(model: CallableModel) -> Optional[Dict[str, Any]]: generated = _generated_model_instance(model) if generated is not None: @@ -466,7 +373,7 @@ def _resolve_regular_param_value(model: "_GeneratedFlowModelBase", param: _FlowM if _is_unset_flow_input(value): raise TypeError( f"Regular parameter '{param.name}' for {_callable_name(type(model).__flow_model_config__.func)} is still unbound. " - "Bind it at construction time or via pipe()." + "Bind it at construction time." 
) if param.is_lazy: if _is_model_dependency(value): @@ -721,9 +628,6 @@ def _serialize_with_transforms(self, handler): data["_input_transforms_fingerprint"] = _fingerprint_transforms(self._input_transforms) return data - def pipe(self, stage: Any, /, *, param: Optional[str] = None, **bindings: Any) -> Any: - return pipe_model(self, stage, param=param, **bindings) - def __repr__(self) -> str: transforms = ", ".join(f"{name}={_transform_repr(transform)}" for name, transform in self._input_transforms.items()) return f"{self.model!r}.flow.with_inputs({transforms})" @@ -826,7 +730,7 @@ def __call__(self, context): missing = ", ".join(sorted(missing_regular)) raise TypeError( f"Missing regular parameter(s) for {_callable_name(config.func)}: {missing}. " - "Bind them at construction time or via pipe(); compute() only supplies contextual inputs." + "Bind them at construction time; compute() only supplies contextual inputs." ) fn_kwargs: Dict[str, Any] = {} diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index ed18cd5..9d3c2ed 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -263,22 +263,6 @@ def bad(x: int) -> int: return x -def test_pipe_only_targets_regular_parameters(): - @Flow.model - def source(value: FromContext[int]) -> int: - return value - - @Flow.model - def consumer(a: int, b: FromContext[int]) -> int: - return a + b - - piped = source().pipe(consumer()) - assert piped.flow.compute(FlowContext(value=10, b=5)).value == 15 - - with pytest.raises(TypeError, match="is contextual"): - source().pipe(consumer(), param="b") - - def test_lazy_dependency_remains_lazy(): calls = {"source": 0} diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md index 6ea254c..b6055ae 100644 --- a/docs/design/flow_model_design.md +++ b/docs/design/flow_model_design.md @@ -13,13 +13,13 @@ The design is intentionally narrow: - upstream `CallableModel`s can still be passed as ordinary arguments. 
The goal is that a reader can look at one function signature and immediately -answer: +see: 1. which values come from runtime context, 2. which values must be bound as regular configuration or dependencies, 3. how to rewrite contextual inputs for one branch of the graph. -## Primary Story +## Core Example ```python from ccflow import Flow, FromContext @@ -31,11 +31,13 @@ def foo(a: int, b: FromContext[int]) -> int: # Build an instance with a=11 bound, then supply b=12 at runtime: -result = foo(a=11).flow.compute(b=12) +configured = foo(a=11) +result = configured.flow.compute(b=12) assert result.value == 23 # .value unwraps the GenericResult wrapper -# Or pre-fill both — b=12 becomes a contextual default: -result = foo(a=11, b=12).flow.compute() +# Or create a different instance that stores b=12 as its contextual default: +prefilled = foo(a=11, b=12) +result = prefilled.flow.compute() assert result.value == 23 ``` @@ -47,9 +49,15 @@ This is the core contract: - `a` is a regular parameter — it must be bound at construction time, - `b` is contextual because it is marked with `FromContext[int]` — it can come - from runtime context, a construction-time default, or a function default, + from runtime context, a contextual default stored on the model instance, or a + function default, - `.flow.compute(...)` only accepts contextual inputs. +Nothing is being mutated at execution time in the second example. +`prefilled = foo(a=11, b=12)` constructs a different model instance whose +contextual default for `b` is already `12`. Because `b` is still contextual, +incoming runtime context can still override that default. + This means the following is **invalid**: ```python @@ -58,8 +66,7 @@ foo().flow.compute(a=11, b=12) # Bind regular parameter(s) separately: a ``` -`a` is not contextual, so it must be bound at construction time (`foo(a=11)`) -or wired with `.pipe(...)`. +`a` is not contextual, so it must be bound at construction time (`foo(a=11)`). 
## Regular Parameters vs Contextual Parameters @@ -106,16 +113,19 @@ Contextual parameters are the ones marked with `FromContext[...]`. They can be satisfied by: - runtime context, -- construction-time contextual defaults, +- contextual defaults stored on the model instance, - function defaults. They cannot be satisfied by `CallableModel` values. +A construction-time value for a contextual parameter is still a default, not a +conversion into a regular bound parameter. + Contextual precedence is: 1. branch-local `.flow.with_inputs(...)` rewrites, 2. incoming runtime context, -3. construction-time contextual defaults, +3. contextual defaults stored on the model instance, 4. function defaults. ## `.flow.compute(...)` @@ -189,25 +199,6 @@ Key rules: - chained `with_inputs()` calls merge, with the newest transform winning for a repeated field. -## `.pipe(...)` - -`.pipe(...)` is a convenience API for wiring one upstream model into a -downstream regular parameter. It is equivalent to passing the model directly: - -```python -source = load_value(offset=5) - -# These two are equivalent: -model_a = add(a=source) -model_b = source.pipe(add(), param="a") -``` - -`pipe()` is most useful when the downstream stage is already partially -configured and you want to wire in one more dependency, or when you are -building pipelines programmatically and the parameter name is determined at -runtime. It only targets regular parameters — use `.flow.with_inputs(...)` to -rewrite contextual inputs. - ## Explicit Context Interop `@Flow.model` still supports an explicit context parameter for cases where the @@ -328,13 +319,14 @@ Use a hand-written class-based `CallableModel` when: **`compute()` says a field is not contextual** -That field is a regular parameter. Bind it at construction time or wire it with -`.pipe(...)`. Only `FromContext[...]` fields belong in `compute()`. +That field is a regular parameter. Bind it at construction time. Only +`FromContext[...]` fields belong in `compute()`. 
**`with_inputs()` rejects a field** `with_inputs()` only rewrites contextual inputs. If you are trying to attach one -stage to another, use regular argument binding or `.pipe(...)`. +stage to another, pass the upstream model as a regular argument at construction +time. **A contextual parameter still shows up in `context_inputs` after I bound it** diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index c1346ee..1f8adcb 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -48,6 +48,9 @@ def add(a: int, b: FromContext[int]) -> int: model = add(a=10) assert model.flow.compute(b=5).value == 15 + +prefilled = add(a=10, b=7) +assert prefilled.flow.compute().value == 17 ``` That means: @@ -56,6 +59,10 @@ That means: - `b` is contextual, - `.flow.compute(...)` only accepts contextual inputs. +`prefilled = add(a=10, b=7)` creates a different model instance with a stored +contextual default for `b`. `compute()` does not mutate the model; it resolves +the remaining contextual inputs for that execution. + Regular parameters can be satisfied by: - literal values, @@ -65,7 +72,7 @@ Regular parameters can be satisfied by: Contextual parameters can be satisfied by: - runtime context, -- construction-time contextual defaults, +- contextual defaults stored on the model instance, - function defaults. Contextual parameters cannot be bound to `CallableModel` values. @@ -141,9 +148,6 @@ context object and returns the same result as `model(context)`. `model.flow.with_inputs(...)` rewrites contextual inputs on one dependency edge. It only accepts contextual fields and remains branch-local. -`model.pipe(...)` is a secondary helper for wiring one upstream model into a -downstream regular parameter. It never targets `FromContext[...]` fields. 
- Generated models also expose introspection helpers: ```python diff --git a/examples/flow_model_example.py b/examples/flow_model_example.py index c22e852..d4dd633 100644 --- a/examples/flow_model_example.py +++ b/examples/flow_model_example.py @@ -1,12 +1,12 @@ #!/usr/bin/env python -"""Canonical Flow.model example. +"""Main `@Flow.model` example. -This is the main `@Flow.model` story: +Shows how to: -1. define workflow steps as plain Python functions, -2. wire them together by passing upstream models as normal arguments, -3. use a small Python builder for reusable composition, -4. execute either as a normal CallableModel or via `.flow.compute(...)`. +1. define stages as plain Python functions, +2. compose stages by passing upstream models as ordinary arguments, +3. rewrite contextual inputs on one dependency edge with `.flow.with_inputs(...)`, +4. execute either as `model(context)` or `model.flow.compute(...)`. Run with: python examples/flow_model_example.py @@ -59,7 +59,7 @@ def shifted_window(model, *, days_back: int): def build_week_over_week_pipeline(region: str): - """Build one reusable pipeline from plain Flow.model functions.""" + """Build one reusable comparison pipeline.""" current = load_revenue(region=region) previous = shifted_window(current, days_back=7) return revenue_change( @@ -87,11 +87,11 @@ def main() -> None: end_date=ctx.end_date, ) - print("\nPipeline wired from plain functions:") + print("\nPipeline:") print(" current input:", pipeline.current) print(" previous input:", pipeline.previous) - print("\nDirect call and .flow.compute(...) 
are equivalent:") + print("\nExecution:") print(f" direct == computed: {direct == computed}") print("\nResult:") From 387d4bce6811d04dadcaa26390e38d777df1870b Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Thu, 9 Apr 2026 04:07:27 -0400 Subject: [PATCH 20/26] Add auto-unwrap parameter for @Flow.model decorator Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 5 ++ ccflow/flow_model.py | 38 +++++++++++++-- ccflow/tests/test_flow_context.py | 30 ++++++++++++ ccflow/tests/test_flow_model.py | 80 +++++++++++++++++++++++++++++++ 4 files changed, 149 insertions(+), 4 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index a0c10fa..9012bbb 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -537,6 +537,11 @@ def model(*args, **kwargs): Args: context_type: Optional ContextBase subclass used only to validate/coerce `FromContext[...]` inputs against an existing nominal context shape + auto_unwrap: When True, `.flow.compute(...)` unwraps auto-wrapped + `GenericResult(value=...)` outputs back to the annotated return type. + Explicit `ResultBase` returns are left unchanged. Default: False. + model_base: Optional custom `CallableModel` subclass to use as an + additional base for the generated model class. cacheable: Enable caching of results (default: False) volatile: Mark as volatile (default: False) log_level: Logging verbosity (default: logging.DEBUG) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 108c08f..d857445 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -102,6 +102,7 @@ class _FlowModelConfig: context_type: Type[ContextBase] result_type: Type[ResultBase] auto_wrap_result: bool + auto_unwrap: bool explicit_context_param: Optional[str] parameters: Tuple[_FlowModelParam, ...] 
context_input_types: Dict[str, Any] @@ -228,6 +229,17 @@ def thunk(): return thunk +def _maybe_auto_unwrap_external_result(target: CallableModel, result: Any) -> Any: + generated = _generated_model_instance(target) + if generated is None: + return result + + config = type(generated).__flow_model_config__ + if config.auto_wrap_result and config.auto_unwrap: + return _unwrap_model_result(result) + return result + + def _parse_annotation(annotation: Any) -> _ParsedAnnotation: is_lazy = False is_from_context = False @@ -495,7 +507,7 @@ def compute(self, context: Any = _UNSET, /, **kwargs) -> Any: generated = _generated_model_instance(target) if generated is not None: built_context = _build_generated_compute_context(generated, context, kwargs) - return target(built_context) + return _maybe_auto_unwrap_external_result(target, target(built_context)) if context is not _UNSET and kwargs: raise TypeError("compute() accepts either one context object or contextual keyword arguments, but not both.") @@ -503,7 +515,7 @@ def compute(self, context: Any = _UNSET, /, **kwargs) -> Any: built_context = target.context_type.model_validate(kwargs) else: built_context = context if isinstance(context, ContextBase) else target.context_type.model_validate(context) - return target(built_context) + return _maybe_auto_unwrap_external_result(target, target(built_context)) @property def context_inputs(self) -> Dict[str, Any]: @@ -830,6 +842,7 @@ def _analyze_flow_model( resolved_hints: Dict[str, Any], *, context_type: Optional[Type[ContextBase]], + auto_unwrap: bool, ) -> _FlowModelConfig: params = sig.parameters @@ -930,6 +943,7 @@ def _analyze_flow_model( context_type=call_context_type, result_type=result_type, auto_wrap_result=auto_wrap_result, + auto_unwrap=auto_unwrap, explicit_context_param=explicit_context_param, parameters=tuple(analyzed_params), context_input_types=context_input_types, @@ -961,11 +975,24 @@ def _validate_factory_kwargs(config: _FlowModelConfig, kwargs: Dict[str, Any]) - 
_coerce_value(param.name, value, param.annotation, "Field") +def _resolve_generated_model_bases(model_base: Type[CallableModel]) -> Tuple[type, ...]: + if not isinstance(model_base, type) or not issubclass(model_base, CallableModel): + raise TypeError(f"model_base must be a CallableModel subclass, got {model_base!r}") + + if issubclass(model_base, _GeneratedFlowModelBase): + return (model_base,) + if model_base is CallableModel: + return (_GeneratedFlowModelBase,) + return (_GeneratedFlowModelBase, model_base) + + def flow_model( func: Optional[_AnyCallable] = None, *, context_args: Any = _REMOVED_CONTEXT_ARGS, context_type: Optional[Type[ContextBase]] = None, + auto_unwrap: bool = False, + model_base: Type[CallableModel] = CallableModel, cacheable: Any = _UNSET, volatile: Any = _UNSET, log_level: Any = _UNSET, @@ -986,7 +1013,7 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: except (AttributeError, NameError, TypeError): resolved_hints = {} - config = _analyze_flow_model(fn, sig, resolved_hints, context_type=context_type) + config = _analyze_flow_model(fn, sig, resolved_hints, context_type=context_type, auto_unwrap=auto_unwrap) annotations: Dict[str, Any] = {} namespace: Dict[str, Any] = { @@ -1020,7 +1047,10 @@ def decorator(fn: _AnyCallable) -> _AnyCallable: namespace["__annotations__"] = annotations - GeneratedModel = cast(type[_GeneratedFlowModelBase], type(f"_{_callable_name(fn)}_Model", (_GeneratedFlowModelBase,), namespace)) + GeneratedModel = cast( + type[_GeneratedFlowModelBase], + type(f"_{_callable_name(fn)}_Model", _resolve_generated_model_bases(model_base), namespace), + ) GeneratedModel.__flow_model_config__ = config register_ccflow_import_path(GeneratedModel) GeneratedModel.model_rebuild() diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index caf3991..50a6754 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -135,6 +135,36 @@ def add(a: int, b: FromContext[int]) -> int: 
assert restored.model.flow.bound_inputs == {"a": 10} +def test_bound_model_cloudpickle_roundtrip_preserves_callable_transforms(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + bound = add(a=10).flow.with_inputs(b=lambda ctx: ctx.b + 1) + restored = cloudpickle.loads(cloudpickle.dumps(bound)) + + assert restored.flow.compute(b=4).value == 15 + + +def test_transformed_dag_cloudpickle_roundtrip_preserves_callable_transforms(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def combine(left: int, right: int, value: FromContext[int]) -> int: + return left + right + value + + base = source() + model = combine( + left=base.flow.with_inputs(value=lambda ctx: ctx.value + 1), + right=base.flow.with_inputs(value=lambda ctx: ctx.value + 10), + ) + restored = cloudpickle.loads(cloudpickle.dumps(model)) + + assert restored.flow.compute(value=5).value == (6 + 15 + 5) + + def test_regular_callable_models_still_support_with_inputs(): model = OffsetModel(offset=10) shifted = model.flow.with_inputs(x=lambda ctx: ctx.x * 2) diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 9d3c2ed..3c1a691 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -239,6 +239,86 @@ def loader(context: DateRangeContext, source: str = "db") -> GenericResult[str]: assert model.flow.compute(start_date="2024-01-01", end_date="2024-01-02").value == "api:2024-01-01:2024-01-02" +def test_auto_unwrap_defaults_to_false_for_auto_wrapped_results(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + result = add(a=10).flow.compute(b=5) + assert isinstance(result, GenericResult) + assert result.value == 15 + + +def test_compute_does_not_unwrap_explicit_generic_result_returns(): + @Flow.model + def load(value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value * 2) + + result = load().flow.compute(value=3) + assert 
isinstance(result, GenericResult) + assert result.value == 6 + + +def test_auto_unwrap_can_be_enabled_for_auto_wrapped_results(): + @Flow.model(auto_unwrap=True) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + assert add(a=10).flow.compute(b=5) == 15 + + +def test_auto_unwrap_only_affects_external_compute_results(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value * 2 + + @Flow.model(auto_unwrap=True) + def add(left: int, bonus: FromContext[int]) -> int: + return left + bonus + + model = add(left=source()) + assert model.flow.compute(FlowContext(value=4, bonus=3)) == 11 + + +def test_model_base_allows_custom_callable_model_subclass(): + class CustomFlowBase(CallableModel): + multiplier: int = 1 + + @model_validator(mode="after") + def _validate_multiplier(self): + if self.multiplier <= 0: + raise ValueError("multiplier must be positive") + return self + + def scaled(self, value: int) -> int: + return value * self.multiplier + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.model(model_base=CustomFlowBase) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + model = add(a=10, multiplier=3) + assert isinstance(model, CustomFlowBase) + assert model.multiplier == 3 + assert model.scaled(4) == 12 + assert model.flow.compute(b=5).value == 15 + + with pytest.raises(ValueError, match="multiplier must be positive"): + add(a=10, multiplier=0) + + +def test_model_base_must_be_callable_model_subclass(): + with pytest.raises(TypeError, match="model_base must be a CallableModel subclass"): + + @Flow.model(model_base=int) + def add(a: int, b: FromContext[int]) -> int: + return a + b + + def test_explicit_context_and_from_context_cannot_mix(): with pytest.raises(TypeError, match="cannot also declare FromContext"): From a038eb61ccca91d366fd180afa1d15513bec2893 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Thu, 9 Apr 2026 04:21:47 
-0400 Subject: [PATCH 21/26] More tests Signed-off-by: Nijat Khanbabayev --- ccflow/tests/test_flow_model.py | 153 ++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 3c1a691..7dc555c 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -2,6 +2,7 @@ import graphlib from datetime import date, timedelta +from typing import Annotated import pytest from pydantic import model_validator @@ -496,3 +497,155 @@ def add(x: int) -> int: with pytest.raises(RuntimeError, match="boom"): Flow.model(add) + + +def test_internal_generated_model_helpers_and_config_properties(): + @Flow.model + def add(a: int, b: FromContext[int]) -> int: + return a + b + + generated_cls = flow_model_module._generated_model_class(add) + assert generated_cls is not None + + config = generated_cls.__flow_model_config__ + assert config.regular_param_names == ("a",) + assert config.contextual_param_names == ("b",) + assert config.param("a").name == "a" + assert config.param("b").name == "b" + + with pytest.raises(KeyError): + config.param("missing") + + model = add(a=10) + assert flow_model_module._generated_model_class(model) is generated_cls + + class DerivedGeneratedBase(flow_model_module._GeneratedFlowModelBase): + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + del context + return [] + + assert flow_model_module._resolve_generated_model_bases(DerivedGeneratedBase) == (DerivedGeneratedBase,) + + +def test_internal_type_helpers_and_plain_callable_flow_api_paths(): + assert flow_model_module._concrete_context_type(SimpleContext | None) is SimpleContext + assert flow_model_module._concrete_context_type(int) is None + assert flow_model_module._type_accepts_str(Annotated[str, flow_model_module._FromContextMarker()]) is True + assert 
flow_model_module._type_accepts_str(int | None) is False + assert flow_model_module._transform_repr(lambda value: value) == "" + assert flow_model_module._bound_field_names(object()) == set() + + class PlainModel(CallableModel): + @property + def context_type(self): + return SimpleContext + + @property + def result_type(self): + return GenericResult[int] + + @Flow.call + def __call__(self, context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + @Flow.deps + def __deps__(self, context: SimpleContext): + del context + return [] + + model = PlainModel() + + assert model.flow.context_inputs == {"value": int} + assert model.flow.unbound_inputs == {"value": int} + assert model.flow.bound_inputs == {} + assert model.flow.compute({"value": 3}).value == 3 + + with pytest.raises(TypeError, match="either one context object or contextual keyword arguments"): + model.flow.compute(SimpleContext(value=1), value=2) + + +def test_explicit_context_paths_and_underbar_context_parameter(): + model = basic_loader(source="warehouse", multiplier=3) + + assert model.flow.context_inputs == {"value": int} + assert model.flow.unbound_inputs == {"value": int} + assert model.flow.compute({"value": 4}).value == 12 + assert model.flow.compute(SimpleContext(value=5)).value == 15 + + with pytest.raises(TypeError, match="either one context object or contextual keyword arguments"): + model.flow.compute(SimpleContext(value=1), value=2) + + @Flow.model + def underbar(_: SimpleContext, a: int) -> int: + return _.value + a + + assert underbar(a=10).flow.compute({"value": 2}).value == 12 + + +def test_additional_validation_and_hint_fallback_paths(monkeypatch): + class MissingFieldContext(ContextBase): + start_date: date + + with pytest.raises(TypeError, match="must define fields for all FromContext parameters"): + + @Flow.model(context_type=MissingFieldContext) + def bad_missing(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return 0 + + class 
ExtraRequiredContext(ContextBase): + start_date: date + end_date: date + label: str + + with pytest.raises(TypeError, match="has required fields that are not declared as FromContext parameters"): + + @Flow.model(context_type=ExtraRequiredContext) + def bad_extra(start_date: FromContext[date], end_date: FromContext[date]) -> int: + return 0 + + class BadAnnotationContext(ContextBase): + value: str + + with pytest.raises(TypeError, match="annotates"): + + @Flow.model(context_type=BadAnnotationContext) + def bad_annotation(value: FromContext[int]) -> int: + return value + + with pytest.raises(TypeError, match="context_type must be a ContextBase subclass"): + + @Flow.model(context_type=int) + def bad_context_type(value: FromContext[int]) -> int: + return value + + @Flow.model + def source(context: SimpleContext) -> GenericResult[int]: + return GenericResult(value=context.value) + + with pytest.raises(TypeError, match="cannot default to a CallableModel"): + + @Flow.model + def bad_default(value: FromContext[int] = source()) -> int: + return value + + with pytest.raises(TypeError, match="return type annotation"): + + @Flow.model + def missing_return(value: int): + return value + + def missing_hints(*args, **kwargs): + raise AttributeError("missing hints") + + monkeypatch.setattr(flow_model_module, "get_type_hints", missing_hints) + + @Flow.model + def add(x: int, y: FromContext[int]) -> int: + return x + y + + assert add(x=1).flow.compute(y=2).value == 3 From 9319fc28418c39cbf9a1e3576517cc67ae4227f1 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Thu, 9 Apr 2026 04:54:49 -0400 Subject: [PATCH 22/26] Remove pandas compat changes Signed-off-by: Nijat Khanbabayev --- ccflow/exttypes/frequency.py | 46 +----------------------------------- ccflow/utils/chunker.py | 9 +++---- ccflow/validators.py | 5 ++-- 3 files changed, 6 insertions(+), 54 deletions(-) diff --git a/ccflow/exttypes/frequency.py b/ccflow/exttypes/frequency.py index afb772c..33c16b5 100644 --- 
a/ccflow/exttypes/frequency.py +++ b/ccflow/exttypes/frequency.py @@ -1,4 +1,3 @@ -import re import warnings from datetime import timedelta from functools import cached_property @@ -33,13 +32,6 @@ def _validate(cls, value) -> "Frequency": if isinstance(value, cls): return cls._validate(str(value)) - if isinstance(value, timedelta): - if value.total_seconds() % 86400 == 0: - return cls(f"{int(value.total_seconds() // 86400)}D") - - if isinstance(value, str): - value = _normalize_frequency_alias(value) - if isinstance(value, (timedelta, str)): try: with warnings.catch_warnings(): @@ -51,7 +43,7 @@ def _validate(cls, value) -> "Frequency": raise ValueError(f"ensure this value can be converted to a pandas offset: {e}") if isinstance(value, pd.offsets.DateOffset): - return cls(_canonicalize_offset_string(value)) + return cls(f"{value.n}{value.base.freqstr}") raise ValueError(f"ensure this value can be converted to a pandas offset: {value}") @@ -62,39 +54,3 @@ def validate(cls, value) -> "Frequency": _TYPE_ADAPTER = TypeAdapter(Frequency) - - -_LEGACY_FREQ_PATTERN = re.compile( - r"^(?P[+-]?\d+)?(?PT|M|A|Y)(?:-(?P[A-Za-z]{3}))?$", - re.IGNORECASE, -) - - -def _normalize_frequency_alias(value: str) -> str: - normalized = value.strip() - if not normalized: - return normalized - - match = _LEGACY_FREQ_PATTERN.fullmatch(normalized) - if not match: - day_match = re.fullmatch(r"(?P[+-]?\d+)?d", normalized, re.IGNORECASE) - if day_match: - return f"{day_match.group('count') or 1}D" - return normalized - - count = match.group("count") or "1" - unit = match.group("unit").upper() - suffix = (match.group("suffix") or "DEC").upper() - replacements = { - "T": f"{count}min", - "M": f"{count}ME", - "A": f"{count}YE-{suffix}", - "Y": f"{count}YE-{suffix}", - } - return replacements[unit] - - -def _canonicalize_offset_string(offset: pd.offsets.DateOffset) -> str: - if isinstance(offset, pd.offsets.Day): - return f"{offset.n}D" - return f"{offset.n}{offset.base.freqstr}" diff --git 
a/ccflow/utils/chunker.py b/ccflow/utils/chunker.py index fb32c70..605bfbd 100644 --- a/ccflow/utils/chunker.py +++ b/ccflow/utils/chunker.py @@ -12,8 +12,6 @@ import pandas as pd -from ccflow.exttypes.frequency import _normalize_frequency_alias - _MIN_END_DATE = date(1969, 12, 31) __all__ = ("dates_to_chunks",) @@ -33,20 +31,19 @@ def dates_to_chunks(start: date, end: date, chunk_size: str = "ME", trim: bool = Returns: List of tuples of (start date, end date) for each of the chunks """ - normalized_chunk_size = _normalize_frequency_alias(chunk_size) with warnings.catch_warnings(): # Because pandas 2.2 deprecated many frequency strings (i.e. "Y", "M", "T" still in common use) # We should consider switching away from pandas on this and supporting ISO warnings.simplefilter("ignore", category=FutureWarning) - offset = pd.tseries.frequencies.to_offset(normalized_chunk_size) + offset = pd.tseries.frequencies.to_offset(chunk_size) if offset.n == 1: - end_dates = pd.date_range(start - offset, end + offset, freq=normalized_chunk_size) + end_dates = pd.date_range(start - offset, end + offset, freq=chunk_size) else: # Need to anchor the timeline at some absolute date, because otherwise chunks might depend on the start date # and end up overlappig each other, i.e. with 2M, would end up with # i.e. (Jan-Feb) or (Feb,Mar) depending on whether start date was in Jan or Feb, # instead of always returning (Jan,Feb) for any start date in either of those two months. 
- end_dates = pd.date_range(_MIN_END_DATE, end + offset, freq=normalized_chunk_size) + end_dates = pd.date_range(_MIN_END_DATE, end + offset, freq=chunk_size) start_dates = end_dates + pd.DateOffset(1) chunks = [(s, e) for s, e in zip(start_dates[:-1].date, end_dates[1:].date) if e >= start and s <= end] if trim: diff --git a/ccflow/validators.py b/ccflow/validators.py index fee7698..720187c 100644 --- a/ccflow/validators.py +++ b/ccflow/validators.py @@ -9,7 +9,6 @@ from pydantic import TypeAdapter, ValidationError from .exttypes import PyObjectPath -from .exttypes.frequency import _normalize_frequency_alias _DatetimeAdapter = TypeAdapter(datetime) @@ -26,7 +25,7 @@ def normalize_date(v: Any) -> Any: """Validator that will convert string offsets to date based on today, and convert datetime to date.""" if isinstance(v, str): # Check case where it's an offset try: - timestamp = pd.tseries.frequencies.to_offset(_normalize_frequency_alias(v)) + date.today() + timestamp = pd.tseries.frequencies.to_offset(v) + date.today() return timestamp.date() except ValueError: pass @@ -45,7 +44,7 @@ def normalize_datetime(v: Any) -> Any: """Validator that will convert string offsets to datetime based on today, and convert datetime to date.""" if isinstance(v, str): # Check case where it's an offset try: - return (pd.tseries.frequencies.to_offset(_normalize_frequency_alias(v)) + date.today()).to_pydatetime() + return (pd.tseries.frequencies.to_offset(v) + date.today()).to_pydatetime() except ValueError: pass if isinstance(v, dict): From 2b3e1c93e346e6d6299d6e551de0173558a4515c Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Thu, 9 Apr 2026 16:55:13 -0400 Subject: [PATCH 23/26] Add checks to not allow positional args, or **kwargs in @Flow.model decorated functions Signed-off-by: Nijat Khanbabayev --- ccflow/flow_model.py | 12 ++++++++++++ ccflow/tests/test_flow_model.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/ccflow/flow_model.py 
b/ccflow/flow_model.py index d857445..b102dea 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -856,6 +856,13 @@ def _analyze_flow_model( explicit_context_type = None if explicit_context_param is not None: + explicit_context = params[explicit_context_param] + if explicit_context.kind is inspect.Parameter.POSITIONAL_ONLY: + raise TypeError(f"Function {_callable_name(fn)} does not support positional-only parameter '{explicit_context_param}'.") + if explicit_context.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + raise TypeError( + f"Function {_callable_name(fn)} does not support {explicit_context.kind.description} parameter '{explicit_context_param}'." + ) context_annotation = resolved_hints.get(explicit_context_param, params[explicit_context_param].annotation) explicit_context_type = _concrete_context_type(context_annotation) if explicit_context_type is None: @@ -867,6 +874,11 @@ def _analyze_flow_model( if name == "self" or name == explicit_context_param: continue + if param.kind is inspect.Parameter.POSITIONAL_ONLY: + raise TypeError(f"Function {_callable_name(fn)} does not support positional-only parameter '{name}'.") + if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + raise TypeError(f"Function {_callable_name(fn)} does not support {param.kind.description} parameter '{name}'.") + annotation = resolved_hints.get(name, param.annotation) if annotation is inspect.Parameter.empty: raise TypeError(f"Parameter '{name}' must have a type annotation") diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 7dc555c..4b92c4a 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -639,6 +639,36 @@ def bad_default(value: FromContext[int] = source()) -> int: def missing_return(value: int): return value + with pytest.raises(TypeError, match="does not support positional-only parameter 'value'"): + + @Flow.model + def 
bad_positional_only(value: int, /, bonus: FromContext[int]) -> int: + return value + bonus + + with pytest.raises(TypeError, match="does not support variadic positional parameter 'values'"): + + @Flow.model + def bad_varargs(*values: int) -> int: + return sum(values) + + with pytest.raises(TypeError, match="does not support variadic keyword parameter 'values'"): + + @Flow.model + def bad_varkw(**values: int) -> int: + return sum(values.values()) + + @Flow.model + def keyword_only(value: int, *, bonus: FromContext[int]) -> int: + return value + bonus + + assert keyword_only(value=2).flow.compute(bonus=3).value == 5 + + @Flow.model + def keyword_only_context(*, context: SimpleContext, offset: int) -> int: + return context.value + offset + + assert keyword_only_context(offset=4).flow.compute({"value": 3}).value == 7 + def missing_hints(*args, **kwargs): raise AttributeError("missing hints") From c872d6ba2e7ff9eb37cc944afd3493229ecd1793 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Thu, 9 Apr 2026 17:17:58 -0400 Subject: [PATCH 24/26] Handle normal bound arguments and context conflict seamlessly Signed-off-by: Nijat Khanbabayev --- ccflow/flow_model.py | 19 ++++--- ccflow/tests/test_flow_context.py | 19 +++++++ ccflow/tests/test_flow_model.py | 16 +++++- docs/design/flow_model_design.md | 89 ++++++++++++++++++++++++++++--- 4 files changed, 128 insertions(+), 15 deletions(-) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index b102dea..3dc385e 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -479,17 +479,22 @@ def _build_generated_compute_context(model: "_GeneratedFlowModelBase", context: return FlowContext(**_context_values(context)) return FlowContext.model_validate(context) - invalid = sorted(set(kwargs) - set(config.context_input_types)) - if invalid: - names = ", ".join(invalid) - raise TypeError(f"compute() only accepts contextual inputs. 
Bind regular parameter(s) separately: {names}.") + unresolved_regular = sorted( + name for name in config.regular_param_names if name in kwargs and _is_unset_flow_input(getattr(model, name, _UNSET_FLOW_INPUT)) + ) + if unresolved_regular: + names = ", ".join(unresolved_regular) + raise TypeError( + f"compute() cannot satisfy unbound regular parameter(s): {names}. " + "Bind them at construction time; compute() only supplies runtime context." + ) - coerced = {} + ambient = dict(kwargs) for param in config.contextual_params: if param.name not in kwargs: continue - coerced[param.name] = _coerce_value(param.name, kwargs[param.name], param.validation_annotation, "compute() input") - return FlowContext(**coerced) + ambient[param.name] = _coerce_value(param.name, kwargs[param.name], param.validation_annotation, "compute() input") + return FlowContext(**ambient) class FlowAPI: diff --git a/ccflow/tests/test_flow_context.py b/ccflow/tests/test_flow_context.py index 50a6754..1480d3b 100644 --- a/ccflow/tests/test_flow_context.py +++ b/ccflow/tests/test_flow_context.py @@ -103,6 +103,25 @@ def combine(left: int, right: int, value: FromContext[int]) -> int: assert model.flow.compute(value=5).value == (6 + 15 + 5) +def test_compute_kwargs_can_supply_ambient_context_for_upstream_transforms(): + @Flow.model + def source(value: FromContext[int]) -> int: + return value + + @Flow.model + def combine(left: int, right: int, bonus: FromContext[int]) -> int: + return left + right + bonus + + base = source() + model = combine( + left=base.flow.with_inputs(value=lambda ctx: ctx.value + 1), + right=base.flow.with_inputs(value=lambda ctx: ctx.value + 10), + ) + + assert model.flow.context_inputs == {"bonus": int} + assert model.flow.compute(value=5, bonus=100).value == (6 + 15 + 100) + + def test_bound_model_rejects_regular_field_rewrites(): @Flow.model def add(a: int, b: FromContext[int]) -> int: diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 
4b92c4a..ade6408 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -143,7 +143,7 @@ def foo(a: int, b: FromContext[int]) -> int: assert foo(a=11).flow.compute(b=12).value == 23 assert foo(a=11, b=12).flow.compute().value == 23 - with pytest.raises(TypeError, match="compute\\(\\) only accepts contextual inputs"): + with pytest.raises(TypeError, match="compute\\(\\) cannot satisfy unbound regular parameter\\(s\\): a"): foo().flow.compute(a=11, b=12) @@ -157,9 +157,23 @@ def foo(a: int, b: FromContext[int]) -> int: return a + b model = foo(a=source(offset=5)) + assert model.flow.compute(value=7, b=12).value == 24 assert model.flow.compute(FlowContext(value=7, b=12)).value == 24 +def test_bound_regular_param_name_can_collide_with_ambient_context(): + @Flow.model + def source(a: FromContext[int]) -> int: + return a + + @Flow.model + def combine(a: int, left: int, bonus: FromContext[int]) -> int: + return a + left + bonus + + model = combine(a=100, left=source()) + assert model.flow.compute(a=7, bonus=5).value == 112 + + def test_contextual_param_rejects_callable_model(): @Flow.model def source(context: SimpleContext, offset: int) -> GenericResult[int]: diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md index b6055ae..7e20eab 100644 --- a/docs/design/flow_model_design.md +++ b/docs/design/flow_model_design.md @@ -8,7 +8,7 @@ The design is intentionally narrow: - ordinary unmarked parameters are regular bound inputs, - `FromContext[T]` marks the only runtime/contextual inputs, -- `.flow.compute(...)` supplies contextual inputs, +- `.flow.compute(...)` is the execution entry point for the full DAG, - `.flow.with_inputs(...)` rewires contextual inputs on one dependency edge, - upstream `CallableModel`s can still be passed as ordinary arguments. 
@@ -51,7 +51,8 @@ This is the core contract: - `b` is contextual because it is marked with `FromContext[int]` — it can come from runtime context, a contextual default stored on the model instance, or a function default, -- `.flow.compute(...)` only accepts contextual inputs. +- `.flow.compute(...)` may carry extra ambient context for upstream graph + branches, but it never binds regular parameters. Nothing is being mutated at execution time in the second example. `prefilled = foo(a=11, b=12)` constructs a different model instance whose @@ -62,11 +63,14 @@ This means the following is **invalid**: ```python foo().flow.compute(a=11, b=12) -# TypeError: compute() only accepts contextual inputs. -# Bind regular parameter(s) separately: a +# TypeError: compute() cannot bind regular parameter(s): a. +# Bind them at construction time. ``` `a` is not contextual, so it must be bound at construction time (`foo(a=11)`). +By contrast, extra ambient fields that are only needed by upstream +`with_inputs(...)` rewrites are allowed on the kwargs entrypoint for +implicit-`FlowContext` graphs. ## Regular Parameters vs Contextual Parameters @@ -130,11 +134,12 @@ Contextual precedence is: ## `.flow.compute(...)` -`.flow.compute(...)` is the ergonomic entry point for contextual execution. +`.flow.compute(...)` is the ergonomic execution entry point for contextual +execution of the whole DAG. For generated `@Flow.model` stages it accepts either: -- contextual keyword arguments, or +- keyword arguments that become the ambient runtime context bag, or - one context object. It does not accept both at the same time. @@ -153,7 +158,77 @@ assert model.flow.compute(b=5).value == 15 assert model.flow.compute(FlowContext(b=6)).value == 16 ``` -`compute()` returns the same result object you would get from `model(context)`. +For implicit-`FlowContext` models, the kwargs form is intentionally a DAG +entrypoint: it can include extra fields needed only by upstream transformed +dependencies. 
Regular parameters are still never read from runtime context. If +the root model has an unbound regular parameter whose name appears in +`compute(**kwargs)`, `compute()` raises early instead of silently treating that +value as configuration. + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def source(value: FromContext[int]) -> int: + return value + + +@Flow.model +def add(left: int, right: int, bonus: FromContext[int]) -> int: + return left + right + bonus + + +base = source() +model = add( + left=base.flow.with_inputs(value=lambda ctx: ctx.value + 1), + right=base.flow.with_inputs(value=lambda ctx: ctx.value + 10), +) + +assert model.flow.context_inputs == {"bonus": int} +assert model.flow.compute(value=5, bonus=100).value == 121 +``` + +If a regular parameter is already bound on the root model, a same-named key in +`compute(**kwargs)` is treated as ambient context for the graph rather than a +rebind of the root parameter: + +```python +from ccflow import Flow, FromContext + + +@Flow.model +def source(a: FromContext[int]) -> int: + return a + + +@Flow.model +def combine(a: int, left: int, bonus: FromContext[int]) -> int: + return a + left + bonus + + +model = combine(a=100, left=source()) + +# Root 'a' stays bound to 100. The runtime 'a=7' is still available to +# upstream graph nodes that read it from context. 
+assert model.flow.compute(a=7, bonus=5).value == 112 +``` + +`compute()` returns the same result object you would get from `model(context)`, +unless `auto_unwrap=True` is enabled for an auto-wrapped plain return type: + +```python +from ccflow import Flow, FromContext + + +@Flow.model(auto_unwrap=True) +def add(a: int, b: FromContext[int]) -> int: + return a + b + + +result = add(a=10).flow.compute(b=5) +assert result == 15 +``` ## `.flow.with_inputs(...)` From a3deba32780e6fd6aef4e92ff2e714ac77281765 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Thu, 9 Apr 2026 18:40:51 -0400 Subject: [PATCH 25/26] Simplify code Signed-off-by: Nijat Khanbabayev --- ccflow/callable.py | 7 --- ccflow/flow_model.py | 68 +++++-------------------- ccflow/tests/test_flow_model.py | 85 ++++++++++++++------------------ docs/design/flow_model_design.md | 29 +++-------- docs/wiki/Key-Features.md | 21 +++----- 5 files changed, 61 insertions(+), 149 deletions(-) diff --git a/ccflow/callable.py b/ccflow/callable.py index 9012bbb..9ea7a2f 100644 --- a/ccflow/callable.py +++ b/ccflow/callable.py @@ -562,13 +562,6 @@ def load_prices( ) -> GenericResult[pl.DataFrame]: return GenericResult(value=query_db(source, start_date, end_date)) - Advanced interop path: - Functions may still declare an explicit context parameter annotated - with a ContextBase subclass. - - @Flow.model - def load_prices(context: DateRangeContext, source: str) -> GenericResult[pl.DataFrame]: - return GenericResult(value=query_db(source, context.start_date, context.end_date)) Dependencies: Any ordinary parameter can be bound either to a literal value or diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index 3dc385e..bb32506 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -103,7 +103,6 @@ class _FlowModelConfig: result_type: Type[ResultBase] auto_wrap_result: bool auto_unwrap: bool - explicit_context_param: Optional[str] parameters: Tuple[_FlowModelParam, ...] 
context_input_types: Dict[str, Any] context_required_names: Tuple[str, ...] @@ -125,10 +124,6 @@ def regular_param_names(self) -> Tuple[str, ...]: def contextual_param_names(self) -> Tuple[str, ...]: return tuple(param.name for param in self.contextual_params) - @property - def uses_explicit_context(self) -> bool: - return self.explicit_context_param is not None - def param(self, name: str) -> _FlowModelParam: for param in self.parameters: if param.name == name: @@ -467,11 +462,6 @@ def _build_generated_compute_context(model: "_GeneratedFlowModelBase", context: if context is not _UNSET and kwargs: raise TypeError("compute() accepts either one context object or contextual keyword arguments, but not both.") - if config.uses_explicit_context: - if context is _UNSET: - return config.context_type.model_validate(kwargs) - return context if isinstance(context, ContextBase) else config.context_type.model_validate(context) - if context is not _UNSET: if isinstance(context, FlowContext): return context @@ -532,8 +522,6 @@ def unbound_inputs(self) -> Dict[str, Any]: generated = _generated_model_instance(self._model) if generated is not None: config = type(generated).__flow_model_config__ - if config.uses_explicit_context: - return {name: config.context_input_types[name] for name in config.context_required_names} result = {} for param in config.contextual_params: if not _is_unset_flow_input(getattr(generated, param.name, _UNSET_FLOW_INPUT)): @@ -621,7 +609,7 @@ def _transform_context(self, context: ContextBase) -> ContextBase: ctx_dict[name] = value generated = _generated_model_instance(self.model) - if generated is not None and not type(generated).__flow_model_config__.uses_explicit_context: + if generated is not None: return FlowContext(**ctx_dict) context_type = _concrete_context_type(self.model.context_type) @@ -754,10 +742,7 @@ def __call__(self, context): for param in config.regular_params: fn_kwargs[param.name] = _resolve_regular_param_value(self, param, context) - if 
config.uses_explicit_context: - fn_kwargs[cast(str, config.explicit_context_param)] = context - else: - fn_kwargs.update(_resolved_contextual_inputs(self, config, context)) + fn_kwargs.update(_resolved_contextual_inputs(self, config, context)) raw_result = config.func(**fn_kwargs) if config.auto_wrap_result: @@ -851,32 +836,10 @@ def _analyze_flow_model( ) -> _FlowModelConfig: params = sig.parameters - explicit_context_param = None - if "context" in params: - explicit_context_param = "context" - elif "_" in params: - explicit_context_param = "_" - analyzed_params: List[_FlowModelParam] = [] - explicit_context_type = None - - if explicit_context_param is not None: - explicit_context = params[explicit_context_param] - if explicit_context.kind is inspect.Parameter.POSITIONAL_ONLY: - raise TypeError(f"Function {_callable_name(fn)} does not support positional-only parameter '{explicit_context_param}'.") - if explicit_context.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): - raise TypeError( - f"Function {_callable_name(fn)} does not support {explicit_context.kind.description} parameter '{explicit_context_param}'." - ) - context_annotation = resolved_hints.get(explicit_context_param, params[explicit_context_param].annotation) - explicit_context_type = _concrete_context_type(context_annotation) - if explicit_context_type is None: - raise TypeError(f"Function {_callable_name(fn)}: '{explicit_context_param}' must be annotated with a ContextBase subclass.") - if context_type is not None: - raise TypeError("context_type=... 
is inferred from the explicit context parameter; remove the keyword argument.") for name, param in params.items(): - if name == "self" or name == explicit_context_param: + if name == "self": continue if param.kind is inspect.Parameter.POSITIONAL_ONLY: @@ -909,23 +872,15 @@ def _analyze_flow_model( ) contextual_params = tuple(param for param in analyzed_params if param.is_contextual) - if explicit_context_param is not None and contextual_params: - raise TypeError("Functions using an explicit context parameter cannot also declare FromContext[...] parameters.") - declared_context_type = None - if explicit_context_type is not None: - call_context_type = explicit_context_type - context_input_types = {name: info.annotation for name, info in explicit_context_type.model_fields.items()} - context_required_names = tuple(name for name, info in explicit_context_type.model_fields.items() if info.is_required()) - else: - if context_type is not None and not contextual_params: - raise TypeError("context_type=... requires FromContext[...] parameters or an explicit context parameter.") - if context_type is not None: - declared_context_type = _validate_declared_context_type(context_type, contextual_params) - - call_context_type = FlowContext - context_input_types = {param.name: param.annotation for param in contextual_params} - context_required_names = tuple(param.name for param in contextual_params if not param.has_function_default) + if context_type is not None and not contextual_params: + raise TypeError("context_type=... requires FromContext[...] 
parameters.") + if context_type is not None: + declared_context_type = _validate_declared_context_type(context_type, contextual_params) + + call_context_type = FlowContext + context_input_types = {param.name: param.annotation for param in contextual_params} + context_required_names = tuple(param.name for param in contextual_params if not param.has_function_default) if declared_context_type is not None: updated_params = [] @@ -961,7 +916,6 @@ def _analyze_flow_model( result_type=result_type, auto_wrap_result=auto_wrap_result, auto_unwrap=auto_unwrap, - explicit_context_param=explicit_context_param, parameters=tuple(analyzed_params), context_input_types=context_input_types, context_required_names=context_required_names, diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index ade6408..4be6043 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -49,44 +49,44 @@ def _validate_order(self): @Flow.model -def basic_loader(context: SimpleContext, source: str, multiplier: int) -> GenericResult[int]: - return GenericResult(value=context.value * multiplier) +def basic_loader(source: str, multiplier: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value * multiplier) @Flow.model -def string_processor(context: SimpleContext, prefix: str = "value=", suffix: str = "!") -> GenericResult[str]: - return GenericResult(value=f"{prefix}{context.value}{suffix}") +def string_processor(value: FromContext[int], prefix: str = "value=", suffix: str = "!") -> GenericResult[str]: + return GenericResult(value=f"{prefix}{value}{suffix}") @Flow.model -def data_source(context: SimpleContext, base_value: int) -> GenericResult[int]: - return GenericResult(value=context.value + base_value) +def data_source(base_value: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + base_value) @Flow.model -def data_transformer(context: SimpleContext, source: int, factor: int) -> 
GenericResult[int]: +def data_transformer(source: int, factor: int) -> GenericResult[int]: return GenericResult(value=source * factor) @Flow.model -def data_aggregator(context: SimpleContext, input_a: int, input_b: int, operation: str = "add") -> GenericResult[int]: +def data_aggregator(input_a: int, input_b: int, operation: str = "add") -> GenericResult[int]: if operation == "add": return GenericResult(value=input_a + input_b) raise ValueError(f"unsupported operation: {operation}") @Flow.model -def pipeline_stage1(context: SimpleContext, initial: int) -> GenericResult[int]: - return GenericResult(value=context.value + initial) +def pipeline_stage1(initial: int, value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + initial) @Flow.model -def pipeline_stage2(context: SimpleContext, stage1_output: int, multiplier: int) -> GenericResult[int]: +def pipeline_stage2(stage1_output: int, multiplier: int) -> GenericResult[int]: return GenericResult(value=stage1_output * multiplier) @Flow.model -def pipeline_stage3(context: SimpleContext, stage2_output: int, offset: int) -> GenericResult[int]: +def pipeline_stage3(stage2_output: int, offset: int) -> GenericResult[int]: return GenericResult(value=stage2_output + offset) @@ -108,7 +108,7 @@ def date_range_loader_previous_day( @Flow.model -def date_range_processor(context: DateRangeContext, raw_data: dict, normalize: bool = False) -> GenericResult[str]: +def date_range_processor(raw_data: dict, normalize: bool = False) -> GenericResult[str]: prefix = "normalized:" if normalize else "raw:" return GenericResult(value=f"{prefix}{raw_data['source']}:{raw_data['start_date']} to {raw_data['end_date']}") @@ -176,8 +176,8 @@ def combine(a: int, left: int, bonus: FromContext[int]) -> int: def test_contextual_param_rejects_callable_model(): @Flow.model - def source(context: SimpleContext, offset: int) -> GenericResult[int]: - return GenericResult(value=context.value + offset) + def source(offset: int, value: 
FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value + offset) @Flow.model def foo(a: int, b: FromContext[int]) -> int: @@ -234,24 +234,18 @@ def add(a: FromContext[int], b: FromContext[int]) -> int: add(a=2, b=1).flow.compute() -def test_explicit_context_interop_accepts_pep604_optional_annotation(): - @Flow.model - def loader(context: DateRangeContext | None, source: str = "db") -> GenericResult[str]: - assert context is not None - return GenericResult(value=f"{source}:{context.start_date}:{context.end_date}") - - model = loader(source="api") - assert model.flow.compute(start_date="2024-01-01", end_date="2024-01-02").value == "api:2024-01-01:2024-01-02" - - -def test_explicit_context_interop_still_works(): +def test_context_named_parameters_are_just_regular_parameters(): @Flow.model def loader(context: DateRangeContext, source: str = "db") -> GenericResult[str]: return GenericResult(value=f"{source}:{context.start_date}:{context.end_date}") - model = loader(source="api") - assert model(DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2))).value == "api:2024-01-01:2024-01-02" - assert model.flow.compute(start_date="2024-01-01", end_date="2024-01-02").value == "api:2024-01-01:2024-01-02" + model = loader(context=DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)), source="api") + assert model.flow.bound_inputs["context"] == DateRangeContext(start_date=date(2024, 1, 1), end_date=date(2024, 1, 2)) + assert model.flow.context_inputs == {} + assert model.flow.compute().value == "api:2024-01-01:2024-01-02" + + with pytest.raises(TypeError, match="Missing regular parameter\\(s\\) for loader: context"): + loader(source="api").flow.compute(start_date="2024-01-01", end_date="2024-01-02") def test_auto_unwrap_defaults_to_false_for_auto_wrapped_results(): @@ -334,12 +328,15 @@ def add(a: int, b: FromContext[int]) -> int: return a + b -def test_explicit_context_and_from_context_cannot_mix(): - with 
pytest.raises(TypeError, match="cannot also declare FromContext"): +def test_context_named_regular_parameter_can_coexist_with_from_context(): + @Flow.model + def mixed(context: SimpleContext, y: FromContext[int]) -> int: + return context.value + y - @Flow.model - def bad(context: SimpleContext, y: FromContext[int]) -> int: - return context.value + y + model = mixed(context=SimpleContext(value=10)) + assert model.flow.bound_inputs == {"context": SimpleContext(value=10)} + assert model.flow.context_inputs == {"y": int} + assert model.flow.compute(y=5).value == 15 def test_context_args_keyword_is_removed(): @@ -350,7 +347,7 @@ def bad(x: int) -> int: return x -def test_context_type_requires_from_context_or_explicit_context(): +def test_context_type_requires_from_context(): with pytest.raises(TypeError, match="context_type=... requires FromContext"): @Flow.model(context_type=DateRangeContext) @@ -383,8 +380,8 @@ def choose(value: int, lazy_value: Lazy[int], threshold: FromContext[int]) -> in def test_lazy_runtime_helper_is_removed(): @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) + def source(value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value) with pytest.raises(TypeError, match="Lazy\\(model\\)\\(\\.\\.\\.\\) has been removed"): Lazy(source()) @@ -583,7 +580,7 @@ def __deps__(self, context: SimpleContext): model.flow.compute(SimpleContext(value=1), value=2) -def test_explicit_context_paths_and_underbar_context_parameter(): +def test_compute_accepts_context_object_for_from_context_models(): model = basic_loader(source="warehouse", multiplier=3) assert model.flow.context_inputs == {"value": int} @@ -594,12 +591,6 @@ def test_explicit_context_paths_and_underbar_context_parameter(): with pytest.raises(TypeError, match="either one context object or contextual keyword arguments"): model.flow.compute(SimpleContext(value=1), value=2) - @Flow.model - def underbar(_: 
SimpleContext, a: int) -> int: - return _.value + a - - assert underbar(a=10).flow.compute({"value": 2}).value == 12 - def test_additional_validation_and_hint_fallback_paths(monkeypatch): class MissingFieldContext(ContextBase): @@ -638,8 +629,8 @@ def bad_context_type(value: FromContext[int]) -> int: return value @Flow.model - def source(context: SimpleContext) -> GenericResult[int]: - return GenericResult(value=context.value) + def source(value: FromContext[int]) -> GenericResult[int]: + return GenericResult(value=value) with pytest.raises(TypeError, match="cannot default to a CallableModel"): @@ -681,7 +672,7 @@ def keyword_only(value: int, *, bonus: FromContext[int]) -> int: def keyword_only_context(*, context: SimpleContext, offset: int) -> int: return context.value + offset - assert keyword_only_context(offset=4).flow.compute({"value": 3}).value == 7 + assert keyword_only_context(context=SimpleContext(value=3), offset=4).flow.compute().value == 7 def missing_hints(*args, **kwargs): raise AttributeError("missing hints") diff --git a/docs/design/flow_model_design.md b/docs/design/flow_model_design.md index 7e20eab..7aa0a6c 100644 --- a/docs/design/flow_model_design.md +++ b/docs/design/flow_model_design.md @@ -274,27 +274,10 @@ Key rules: - chained `with_inputs()` calls merge, with the newest transform winning for a repeated field. -## Explicit Context Interop +## `context_type=...` -`@Flow.model` still supports an explicit context parameter for cases where the -function needs the whole context object: - -```python -from ccflow import DateRangeContext, Flow - - -@Flow.model -def load_revenue(context: DateRangeContext, region: str) -> float: - days = (context.end_date - context.start_date).days + 1 - return days * 50.0 -``` - -This path is useful when interoperating with existing code that already uses -typed `ContextBase` subclasses, or when the function genuinely needs access to -the full context rather than individual fields. 
- -You can also keep the `FromContext[...]` style while asking ccflow to validate -those contextual fields against an existing nominal context shape: +When you want the `FromContext[...]` fields to match an existing nominal +context shape, use `context_type=...`: ```python from ccflow import DateRangeContext, Flow, FromContext @@ -309,9 +292,9 @@ That preserves the primary `FromContext[...]` authoring model while letting callers pass richer context objects whose relevant fields satisfy the declared `context_type`. -Do not mix both systems in one function signature. A function with an explicit -`context: ContextBase` parameter cannot also declare `FromContext[...]` -parameters. +If the function genuinely needs the runtime context object itself inside the +function body on each call, use a normal `CallableModel` subclass instead of +`@Flow.model`. ## Introspection APIs diff --git a/docs/wiki/Key-Features.md b/docs/wiki/Key-Features.md index 1f8adcb..0da832c 100644 --- a/docs/wiki/Key-Features.md +++ b/docs/wiki/Key-Features.md @@ -77,21 +77,10 @@ Contextual parameters can be satisfied by: Contextual parameters cannot be bound to `CallableModel` values. -#### Explicit Context Interop +#### Nominal Context Validation -An explicit context parameter is still supported when that is the natural API: - -```python -from ccflow import DateRangeContext, Flow - - -@Flow.model -def load_data(context: DateRangeContext, source: str) -> float: - return 125.0 -``` - -You can also keep the `FromContext[...]` style while validating those fields -against an existing context type: +You can keep the `FromContext[...]` style while validating those fields against +an existing context type: ```python from datetime import date @@ -103,7 +92,9 @@ def load_data(source: str, start_date: FromContext[date], end_date: FromContext[ return 125.0 ``` -Do not mix both styles in one signature. 
+If the function genuinely needs the runtime context object itself inside the +function body on each call, write a normal `CallableModel` subclass instead of +using `@Flow.model`. #### Composing Dependencies From c479bb50d81dc5d0ff8c774ec2f4515506594383 Mon Sep 17 00:00:00 2001 From: Nijat Khanbabayev Date: Mon, 13 Apr 2026 12:08:17 -0400 Subject: [PATCH 26/26] Handle pickling of unset singleton Signed-off-by: Nijat Khanbabayev --- ccflow/flow_model.py | 29 +++++++++++++++++++++++++---- ccflow/tests/test_flow_model.py | 13 +++++++++++++ 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/ccflow/flow_model.py b/ccflow/flow_model.py index bb32506..48c3da6 100644 --- a/ccflow/flow_model.py +++ b/ccflow/flow_model.py @@ -10,7 +10,7 @@ from typing import Annotated, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin, get_type_hints from pydantic import Field, PrivateAttr, TypeAdapter, ValidationError, model_serializer, model_validator -from pydantic.errors import PydanticUndefinedAnnotation +from pydantic.errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation from .base import BaseModel, ContextBase, ResultBase from .callable import CallableModel, Flow, GraphDepList, WrapperModel @@ -29,12 +29,33 @@ def __repr__(self) -> str: return "" +class _InternalSentinel: + def __init__(self, name: str): + self._name = name + + def __repr__(self) -> str: + return self._name + + def __reduce__(self): + return (_get_internal_sentinel, (self._name,)) + + _UNSET_FLOW_INPUT = _UnsetFlowInput() -_UNSET = object() -_REMOVED_CONTEXT_ARGS = object() _UNION_ORIGINS = (Union, UnionType) +def _get_internal_sentinel(name: str) -> _InternalSentinel: + return _INTERNAL_SENTINELS[name] + + +_INTERNAL_SENTINELS = { + "_UNSET": _InternalSentinel("_UNSET"), + "_REMOVED_CONTEXT_ARGS": _InternalSentinel("_REMOVED_CONTEXT_ARGS"), +} +_UNSET = _INTERNAL_SENTINELS["_UNSET"] +_REMOVED_CONTEXT_ARGS = 
_INTERNAL_SENTINELS["_REMOVED_CONTEXT_ARGS"] + + def _unset_flow_input_factory() -> _UnsetFlowInput: return _UNSET_FLOW_INPUT @@ -185,7 +206,7 @@ def _type_adapter(annotation: Any) -> TypeAdapter: def _can_validate_type(annotation: Any) -> bool: try: _type_adapter(annotation) - except (PydanticUndefinedAnnotation, TypeError, ValueError): + except (PydanticSchemaGenerationError, PydanticUndefinedAnnotation, TypeError, ValueError): return False return True diff --git a/ccflow/tests/test_flow_model.py b/ccflow/tests/test_flow_model.py index 4be6043..5dd4abd 100644 --- a/ccflow/tests/test_flow_model.py +++ b/ccflow/tests/test_flow_model.py @@ -419,6 +419,19 @@ def multiply(a: int, b: FromContext[int]) -> int: assert restored.flow.compute(b=7).value == 42 +def test_generated_models_cloudpickle_preserves_unset_validation_sentinel(): + @Flow.model + def multiply(a: int, b: FromContext[int]) -> int: + return a * b + + model = multiply(a=6) + restored = rcploads(rcpdumps(model, protocol=5)) + param = type(restored).__flow_model_config__.contextual_params[0] + + assert param.context_validation_annotation is flow_model_module._UNSET + assert param.validation_annotation is int + + def test_graph_integration_fanout_fanin(): @Flow.model def source(base: int, value: FromContext[int]) -> int: