diff --git a/py/packages/genkit/src/genkit/__init__.py b/py/packages/genkit/src/genkit/__init__.py index 896ea8e2b0..ebbaee9a1f 100644 --- a/py/packages/genkit/src/genkit/__init__.py +++ b/py/packages/genkit/src/genkit/__init__.py @@ -41,6 +41,7 @@ Media, MediaPart, Metadata, + MiddlewareRef, Part, ReasoningPart, Role, @@ -131,6 +132,8 @@ 'DocumentPart', # Plugin interface 'Plugin', + # Middleware references (wire form for use= parameter) + 'MiddlewareRef', # AI runtime 'ActionKind', 'ActionRunContext', diff --git a/py/packages/genkit/src/genkit/_ai/_aio.py b/py/packages/genkit/src/genkit/_ai/_aio.py index b2b4a1c148..8922feca6b 100644 --- a/py/packages/genkit/src/genkit/_ai/_aio.py +++ b/py/packages/genkit/src/genkit/_ai/_aio.py @@ -44,12 +44,16 @@ ) from genkit._ai._formats import built_in_formats from genkit._ai._formats._types import FormatDef -from genkit._ai._generate import define_generate_action, generate_action, registry_with_inline_tools +from genkit._ai._generate import ( + define_generate_action, + generate_action, + registry_with_inline_middleware, + registry_with_inline_tools, +) from genkit._ai._model import ( Message, ModelConfig, ModelFn, - ModelMiddleware, ModelResponse, ModelResponseChunk, define_model, @@ -90,6 +94,7 @@ from genkit._core._environment import is_dev_environment from genkit._core._error import GenkitError from genkit._core._logger import get_logger +from genkit._core._middleware import BaseMiddleware, MiddlewareDesc, new_middleware from genkit._core._model import Document from genkit._core._plugin import Plugin from genkit._core._reflection import ReflectionServer, ServerSpec, create_reflection_asgi_app @@ -102,6 +107,7 @@ EmbedRequest, EvalRequest, EvalResponse, + MiddlewareRef, ModelInfo, Operation, Part, @@ -157,6 +163,7 @@ def __init__( self._initialize_registry(model, plugins) # Ensure the default generate action is registered for async usage. 
define_generate_action(self.registry) + self._register_plugin_middleware(plugins) # In dev mode, start the reflection server immediately in a background # daemon thread so it's available regardless of which web framework (or # none) the user chooses. @@ -425,7 +432,7 @@ def define_prompt( metadata: dict[str, object] | None = None, tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware | MiddlewareRef] | None = None, docs: list[Document] | None = None, input_schema: type[InputT], output_schema: type[OutputT], @@ -453,7 +460,7 @@ def define_prompt( metadata: dict[str, object] | None = None, tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware | MiddlewareRef] | None = None, docs: list[Document] | None = None, input_schema: type[InputT], output_schema: dict[str, object] | str | None = None, @@ -481,7 +488,7 @@ def define_prompt( metadata: dict[str, object] | None = None, tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware | MiddlewareRef] | None = None, docs: list[Document] | None = None, input_schema: dict[str, object] | str | None = None, output_schema: type[OutputT], @@ -509,7 +516,7 @@ def define_prompt( metadata: dict[str, object] | None = None, tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware | MiddlewareRef] | None = None, docs: list[Document] | None = None, input_schema: dict[str, object] | str | None = None, output_schema: dict[str, object] | str | None = None, @@ -535,7 +542,7 @@ def define_prompt( metadata: dict[str, object] | None = None, tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, - use: list[ModelMiddleware] | None = 
None, + use: list[BaseMiddleware | MiddlewareRef] | None = None, docs: list[Document] | None = None, input_schema: type | dict[str, object] | str | None = None, output_schema: type | dict[str, object] | str | None = None, @@ -725,6 +732,45 @@ def _initialize_registry(self, model: str | None, plugins: list[Plugin] | None) else: raise ValueError(f'Invalid {plugin=} provided to Genkit: must be of type `genkit.ai.Plugin`') + def _register_plugin_middleware(self, plugins: list[Plugin] | None) -> None: + """Register middleware descriptors returned by ``Plugin.list_middleware``.""" + if not plugins: + return + for plugin in plugins: + for desc in plugin.list_middleware(): + self.registry.register_value('middleware', desc.name, desc) + + def new_middleware(self, middleware_cls: type[BaseMiddleware]) -> MiddlewareDesc: + """Build a ``MiddlewareDesc`` from a class.""" + return new_middleware(middleware_cls) + + def define_middleware(self, middleware_cls: type[BaseMiddleware]) -> MiddlewareDesc: + """Register a middleware class on this app and return the descriptor. + + Registering a class: + + * Makes it visible to the **Dev UI** through the reflection API. + * Allows it to be referenced by name via :class:`MiddlewareRef`. + + Equivalent to building the descriptor with ``new_middleware(cls)`` + and wiring it through ``middleware_plugin([...])`` at construction + time, but usable after ``Genkit`` has already been built. + + The factory instantiates ``middleware_cls(**config)`` each time a + request resolves the name via :class:`MiddlewareRef`, so the same + pydantic fields drive both: + + * the inline path: ``use=[cls(...)]`` + * the registered path: ``use=[MiddlewareRef(name=cls.name)]`` + + Returns: + The registered :class:`MiddlewareDesc`. Also available via + ``registry.lookup_value('middleware', cls.name)``. 
+ """ + desc = new_middleware(middleware_cls) + self.registry.register_value('middleware', desc.name, desc) + return desc + def run_main(self, coro: Coroutine[Any, Any, T]) -> T | None: """Run the user's main coroutine, blocking in dev mode for the reflection server.""" if not is_dev_environment(): @@ -799,7 +845,7 @@ async def generate( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware | MiddlewareRef] | None = None, docs: list[Document] | None = None, ) -> ModelResponse[OutputT]: ... @@ -826,7 +872,7 @@ async def generate( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware | MiddlewareRef] | None = None, docs: list[Document] | None = None, ) -> ModelResponse[Any]: ... @@ -851,7 +897,7 @@ async def generate( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware | MiddlewareRef] | None = None, docs: list[Document] | None = None, ) -> ModelResponse[Any]: """Generate text or structured data using a language model. @@ -860,6 +906,9 @@ async def generate( is covariant: ``list[Tool]`` or ``list[str]`` are both assignable to ``Sequence[str | Tool]``, but not to ``list[str | Tool]``. 
""" + registry = await registry_with_inline_tools(self.registry, tools) + child_registry = registry if registry.is_child else registry.new_child() + refs = registry_with_inline_middleware(child_registry, use) or None prompt_config = PromptConfig( model=model, prompt=prompt, @@ -879,13 +928,12 @@ async def generate( output_schema=output_schema, output_constrained=output_constrained, docs=docs, + use=refs, ) - registry = await registry_with_inline_tools(self.registry, prompt_config.tools) gen_options = await to_generate_action_options(registry, prompt_config) return await generate_action( - registry, + child_registry, gen_options, - middleware=use, context=context if context else ActionRunContext._current_context(), # pyright: ignore[reportPrivateUsage] ) @@ -912,7 +960,7 @@ def generate_stream( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware | MiddlewareRef] | None = None, docs: list[Document] | None = None, timeout: float | None = None, ) -> ModelStreamResponse[OutputT]: ... @@ -940,7 +988,7 @@ def generate_stream( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware | MiddlewareRef] | None = None, docs: list[Document] | None = None, timeout: float | None = None, ) -> ModelStreamResponse[Any]: ... 
@@ -966,7 +1014,7 @@ def generate_stream( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware | MiddlewareRef] | None = None, docs: list[Document] | None = None, timeout: float | None = None, ) -> ModelStreamResponse[Any]: @@ -974,6 +1022,9 @@ def generate_stream( channel: Channel[ModelResponseChunk, ModelResponse[Any]] = Channel(timeout=timeout) async def _run_generate() -> ModelResponse[Any]: + registry = await registry_with_inline_tools(self.registry, tools) + child_registry = registry if registry.is_child else registry.new_child() + refs = registry_with_inline_middleware(child_registry, use) or None prompt_config = PromptConfig( model=model, prompt=prompt, @@ -993,14 +1044,13 @@ async def _run_generate() -> ModelResponse[Any]: output_schema=output_schema, output_constrained=output_constrained, docs=docs, + use=refs, ) - registry = await registry_with_inline_tools(self.registry, prompt_config.tools) gen_options = await to_generate_action_options(registry, prompt_config) return await generate_action( - registry, + child_registry, gen_options, on_chunk=lambda c: channel.send(c), - middleware=use, context=context if context else ActionRunContext._current_context(), # pyright: ignore[reportPrivateUsage] ) @@ -1184,7 +1234,7 @@ async def generate_operation( output_content_type: str | None = None, output_instructions: str | None = None, output_constrained: bool | None = None, - use: list[ModelMiddleware] | None = None, + use: list[BaseMiddleware | MiddlewareRef] | None = None, docs: list[Document] | None = None, ) -> Operation: """Generate content using a long-running model, returning an Operation to poll.""" diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index d21bca090b..caefccf093 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ 
b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -19,22 +19,22 @@ import asyncio import contextlib import copy -import inspect import re -from collections.abc import Callable, Sequence +import secrets +from collections.abc import Awaitable, Callable, Sequence from typing import Any, cast from pydantic import BaseModel +from typing_extensions import Never from genkit._ai._formats._types import FormatDef, Formatter from genkit._ai._messages import inject_instructions -from genkit._ai._middleware import augment_with_context from genkit._ai._model import ( Message, - ModelMiddleware, ModelRequest, ModelResponse, ModelResponseChunk, + text_from_content, ) from genkit._ai._resource import ResourceArgument, ResourceInput, find_matching_resource, resolve_resources from genkit._ai._tools import Interrupt, Tool, run_tool_after_restart @@ -46,13 +46,26 @@ ) from genkit._core._error import GenkitError from genkit._core._logger import get_logger -from genkit._core._model import GenerateActionOptions +from genkit._core._middleware import ( + BaseMiddleware, + GenerateHookParams, + MiddlewareDesc, + ModelHookParams, + MultipartToolResponse, + ToolHookParams, +) +from genkit._core._model import ( + Document, + GenerateActionOptions, +) from genkit._core._registry import Registry from genkit._core._tracing import SpanMetadata, run_in_new_span from genkit._core._typing import ( FinishReason, + MiddlewareRef, Part, Role, + TextPart, ToolDefinition, ToolRequest, ToolRequestPart, @@ -65,6 +78,142 @@ logger = get_logger(__name__) +def registry_with_inline_middleware( + registry: Registry, + use: Sequence[BaseMiddleware | MiddlewareRef] | None, +) -> list[MiddlewareRef]: + """Normalize a ``use=[...]`` list into registry-backed ``MiddlewareRef``s. 
+ + Inline ``BaseMiddleware`` instances are registered into the (child) registry + under their class ``name`` — or a generated ``dynamic-middleware-{i}-{hex}`` + name when the class has no name set — so that everything in ``use=`` can be + resolved uniformly via the registry. A repeated name gets a ``__{n}`` suffix. + + The returned list of refs has the same ordering as the input and can be + stored on ``GenerateActionOptions.use`` for consistent tracing / Dev UI + representation. + + Args: + registry: Per-call child registry. Inline instances are registered + here so they are automatically scoped to this generate() call. + use: Mixed list of inline instances and/or ``MiddlewareRef`` entries. + + Returns: + A list of ``MiddlewareRef`` covering every entry in ``use``. + """ + if not use: + return [] + refs: list[MiddlewareRef] = [] + # Track how many times each name appears so duplicates get unique suffixes. + name_counts: dict[str, int] = {} + for i, entry in enumerate(use): + if isinstance(entry, BaseMiddleware): + cls_name = entry.__class__.name # type: ignore[attr-defined] + base_name = str(cls_name) if cls_name else f'dynamic-middleware-{i}-{secrets.token_hex(5)}' + count = name_counts.get(base_name, 0) + name_counts[base_name] = count + 1 + reg_name = base_name if count == 0 else f'{base_name}__{count}' + + def _make_factory( + _i: BaseMiddleware = entry, # capture for the closure; mypy needs a non-lambda factory + ) -> Callable[[dict[str, Any] | None], BaseMiddleware]: + def _factory(_cfg: dict[str, Any] | None) -> BaseMiddleware: + return _i + + return _factory + + desc = MiddlewareDesc( + name=reg_name, + factory=_make_factory(), + ) + registry.register_value('middleware', reg_name, desc) + refs.append(MiddlewareRef(name=reg_name)) + else: + refs.append(entry) + return refs + + +def resolve_middleware_from_use( + registry: Registry, + use: Sequence[MiddlewareRef] | None, +) -> list[BaseMiddleware]: + """Resolve a list of ``MiddlewareRef``s to concrete ``BaseMiddleware`` instances. 
+ + All entries must already be in the registry (inline instances were registered + there by :func:`registry_with_inline_middleware`). Order is preserved. + """ + if not use: + return [] + out: list[BaseMiddleware] = [] + for entry in use: + defn = registry.lookup_value('middleware', entry.name) + if isinstance(defn, MiddlewareDesc): + cfg = entry.config if isinstance(entry.config, dict) else None + out.append(defn(cfg)) + continue + raise GenkitError( + status='NOT_FOUND', + message=( + f'No middleware named "{entry.name}" is registered on this app. ' + 'Register descriptors with middleware_plugin([...]), Plugin.list_middleware(), ' + 'or ai.define_middleware(MyMiddleware); or pass the middleware instance directly ' + 'in use=[MyMiddleware(...)].' + ), + source='genkit.generate', + ) + return out + + +def _bind_call_state( + middleware: list[BaseMiddleware], + *, + registry: Registry, + enqueue_parts: Callable[[list[Part]], None], +) -> list[BaseMiddleware]: + """Return per-call copies of each middleware with framework attrs bound. + + The same ``BaseMiddleware`` instance can show up in concurrent ``generate()`` + calls (a user reusing ``Retry(max_retries=3)``, or a plugin descriptor whose + factory returns the same instance every time). Each call needs its own + ``self.registry`` and ``self.enqueue_parts`` — so we shallow-copy here and + set the attrs on the copy, leaving the caller's instance untouched. + """ + bound: list[BaseMiddleware] = [] + for mw in middleware: + copy = mw.model_copy() + copy.registry = registry + copy.enqueue_parts = enqueue_parts + bound.append(copy) + return bound + + +async def _chain_tool_middleware( + middleware: list[BaseMiddleware], + params: ToolHookParams, + next_fn: Callable[[ToolHookParams], Awaitable[MultipartToolResponse]], +) -> MultipartToolResponse: + """Run the tool middleware chain and return the tool response. 
+ + Interrupts propagate as ``Interrupt`` exceptions (or ``GenkitError`` wrapping + one) for the caller to catch and convert to wire-shape ``ToolRequestPart``. + """ + runner: Callable[[ToolHookParams], Awaitable[MultipartToolResponse]] = next_fn + for mw in reversed(middleware): + _mw = mw + _inner = runner + + async def run_next( + p: ToolHookParams, + *, + _m: BaseMiddleware = _mw, + _i: Callable[[ToolHookParams], Awaitable[MultipartToolResponse]] = _inner, + ) -> MultipartToolResponse: + return await _m.wrap_tool(p, _i) + + runner = run_next + return await runner(params) + + async def expand_wildcard_tools(registry: Registry, tool_names: list[str]) -> list[str]: """Expand DAP wildcard tool names into individual registry keys. @@ -152,6 +301,78 @@ async def registry_with_inline_tools(registry: Registry, tools: Sequence[str | T return child if child is not None else registry +_CONTEXT_PREFACE = '\n\nUse the following information to complete your task:\n\n' + + +def _last_user_message(messages: list[Message]) -> Message | None: + """Find the last user message in a list.""" + for i in range(len(messages) - 1, -1, -1): + if messages[i].role == 'user': + return messages[i] + return None + + +def _context_item_template(d: Document, index: int) -> str: + """Render a document as a citation line for context injection.""" + out = '- ' + ref = (d.metadata and (d.metadata.get('ref') or d.metadata.get('id'))) or index + out += f'[{ref}]: ' + out += text_from_content(d.content) + '\n' + return out + + +def _augment_with_context( + request: ModelRequest, + *, + preface: str | None = _CONTEXT_PREFACE, + item_template: Callable[[Document, int], str] | None = None, + citation_key: str | None = None, +) -> ModelRequest: + """Return a deepcopy of ``request`` with ``request.docs`` injected as a context part on the last user message. 
+ + No-op (returns ``request`` unchanged) when there are no docs, no user message, or the last user message + already has a non-pending ``purpose: 'context'`` part. + """ + if not request.docs: + return request + + user_message = _last_user_message(request.messages) + if user_message is None: + return request + + context_part_index = -1 + for i, part in enumerate(user_message.content): + part_metadata = part.root.metadata if hasattr(part.root, 'metadata') else None + if isinstance(part_metadata, dict) and part_metadata.get('purpose') == 'context': + context_part_index = i + break + + if context_part_index >= 0: + existing_meta = user_message.content[context_part_index].root.metadata + if not (isinstance(existing_meta, dict) and existing_meta.get('pending')): + return request + + template = item_template or _context_item_template + out = preface or '' + for i, doc_data in enumerate(request.docs): + doc = Document(content=doc_data.content, metadata=doc_data.metadata) + if citation_key and doc.metadata: + doc.metadata['ref'] = doc.metadata.get(citation_key, i) + out += template(doc, i) + out += '\n' + + text_part = Part(root=TextPart(text=out, metadata={'purpose': 'context'})) + + new_req = copy.deepcopy(request) + new_user = _last_user_message(new_req.messages) + assert new_user is not None # mirrors the guard above; deepcopy preserves structure + if context_part_index >= 0: + new_user.content[context_part_index] = text_part + else: + new_user.content.append(text_part) + return new_req + + # Matches data URIs: everything up to the first comma is the media-type + # parameters (e.g. "data:audio/L16;codec=pcm;rate=24000;base64,"). 
_DATA_URI_RE = re.compile(r'data:[^,]{0,200},(?=.{100})', re.ASCII) @@ -204,18 +425,60 @@ async def generate_action( on_chunk: Callable[[ModelResponseChunk], None] | None = None, message_index: int = 0, current_turn: int = 0, - middleware: list[ModelMiddleware] | None = None, context: dict[str, Any] | None = None, ) -> ModelResponse: - """Run generation with a util ``generate`` span. + """Execute a generation request with tool calling and middleware support, wrapped in a util ``generate`` span. - The registered ``/util/generate`` action calls `_generate_action` directly + The registered ``/util/generate`` action calls :func:`_generate_action` directly, so reflection runs do not stack another util span on the action span. """ span_name = 'generate' with run_in_new_span(SpanMetadata(name=span_name, type='util', input=raw_request)) as span: + call_registry = registry if registry.is_child else registry.new_child() + refs = registry_with_inline_middleware(call_registry, raw_request.use) + if refs: + raw_request = raw_request.model_copy(update={'use': refs}) + middleware = resolve_middleware_from_use(call_registry, refs) + _queue: list[Message] = [] + + def _enqueue_parts(parts: list[Part]) -> None: + if _queue and _queue[-1].role == Role.USER: + _queue[-1] = Message(role=Role.USER, content=list(_queue[-1].content) + list(parts)) + else: + _queue.append(Message(role=Role.USER, content=list(parts))) + + # Bind per-call framework attrs onto each middleware before any hook or + # tools() runs. After this, hooks can read ``self.registry`` and + # ``self.enqueue_parts`` directly without the engine threading them + # through every params object. 
+ if middleware: + middleware = _bind_call_state( + middleware, + registry=call_registry, + enqueue_parts=_enqueue_parts, + ) + mw_tools: list[Action[Any, Any, Never]] = [] + for mw in middleware: + contributed = mw.tools() + mw_tools.extend(contributed) + + if mw_tools: + mw_tool_names: list[str] = [] + for t in mw_tools: + call_registry.register_action_from_instance(t) + mw_tool_names.append(t.name) + existing = list(raw_request.tools) if raw_request.tools else [] + raw_request = raw_request.model_copy(update={'tools': existing + mw_tool_names}) result = await _generate_action( - registry, raw_request, on_chunk, message_index, current_turn, middleware, context + registry=call_registry, + raw_request=raw_request, + on_chunk=on_chunk, + message_index=message_index, + current_turn=current_turn, + middleware=middleware, + context=context, + _enqueue_parts=_enqueue_parts, + _queue=_queue, ) with contextlib.suppress(Exception): span.set_attribute('genkit:output', result.model_dump_json(by_alias=True, exclude_none=True)) @@ -228,8 +491,10 @@ async def _generate_action( on_chunk: Callable[[ModelResponseChunk], None] | None = None, message_index: int = 0, current_turn: int = 0, - middleware: list[ModelMiddleware] | None = None, + middleware: list[BaseMiddleware] | None = None, context: dict[str, Any] | None = None, + _enqueue_parts: Callable[[list[Part]], None] | None = None, + _queue: list[Message] | None = None, ) -> ModelResponse: """Execute a generation request with tool calling and middleware support.""" tools_in = raw_request.tools @@ -250,7 +515,11 @@ async def _generate_action( revised_request, interrupted_response, resumed_tool_message, - ) = await _resolve_resume_options(registry, raw_request) + ) = await _resolve_resume_options( + registry, + raw_request, + middleware=middleware, + ) # NOTE: in the future we should make it possible to interrupt a restart, but # at the moment it's too complicated because it's not clear how to return a @@ -309,68 +578,68 @@ def 
wrapper(chunk: ModelResponseChunk) -> None: if not middleware: middleware = [] - supports_context = False - if model.metadata: - model_info = model.metadata.get('model') - if model_info and isinstance(model_info, dict): - model_info_dict = cast(dict[str, object], model_info) - supports_info = model_info_dict.get('supports') - if supports_info and isinstance(supports_info, dict): - supports_dict = cast(dict[str, object], supports_info) - supports_context = bool(supports_dict.get('context')) - # if it doesn't support contextm inject context middleware - if raw_request.docs and not supports_context: - middleware.append(augment_with_context()) - - async def dispatch( - index: int, + # Inject ``request.docs`` as a context part on the last user message. + if request.docs: + request = _augment_with_context(request) + + normalized_mw: list[BaseMiddleware] = list(middleware) + + async def dispatch_generate( + params: GenerateHookParams, + next_fn: Callable[[GenerateHookParams], Awaitable[ModelResponse]], + ) -> ModelResponse: + """Chain wrap_generate middleware and call next_fn.""" + runner: Callable[[GenerateHookParams], Awaitable[ModelResponse]] = next_fn + for mw in reversed(normalized_mw): + _mw = mw + _inner = runner + + async def run_next( + p: GenerateHookParams, + *, + _m: BaseMiddleware = _mw, + _i: Callable[[GenerateHookParams], Awaitable[ModelResponse]] = _inner, + ) -> ModelResponse: + return await _m.wrap_generate(p, _i) + + runner = run_next + return await runner(params) + + async def dispatch_model( req: ModelRequest, - ctx: ActionRunContext, chunk_callback: Callable[[ModelResponseChunk], None] | None, ) -> ModelResponse: - """Dispatch request through middleware chain to the model.""" - if not middleware or index == len(middleware): - # End of the chain, call the original model action + async def run_model(params: ModelHookParams) -> ModelResponse: return ( await model.run( - input=req, - context=ctx.context, - on_chunk=cast(Callable[[object], None], 
chunk_callback) if chunk_callback else None, + input=params.request, + context=params.context, + on_chunk=params.on_chunk, ) ).response - current_middleware = middleware[index] - n_params = len(inspect.signature(current_middleware).parameters) + runner: Callable[[ModelHookParams], Awaitable[ModelResponse]] = run_model + for mw in reversed(normalized_mw): + _mw = mw + _inner = runner - if n_params == 4: - # Streaming middleware: (req, ctx, on_chunk, next) -> response - async def next_fn_streaming( - modified_req: ModelRequest | None = None, - modified_ctx: ActionRunContext | None = None, - modified_on_chunk: Callable[[ModelResponseChunk], None] | None = None, + async def run_next( + params: ModelHookParams, + *, + _mw: BaseMiddleware = _mw, + _inner: Callable[[ModelHookParams], Awaitable[ModelResponse]] = _inner, ) -> ModelResponse: - return await dispatch( - index + 1, - modified_req if modified_req else req, - modified_ctx if modified_ctx else ctx, - modified_on_chunk if modified_on_chunk is not None else chunk_callback, - ) + return await _mw.wrap_model(params, _inner) - return await current_middleware(req, ctx, chunk_callback, next_fn_streaming) - else: - # Simple middleware: (req, ctx, next) -> response - async def next_fn_simple( - modified_req: ModelRequest | None = None, - modified_ctx: ActionRunContext | None = None, - ) -> ModelResponse: - return await dispatch( - index + 1, - modified_req if modified_req else req, - modified_ctx if modified_ctx else ctx, - chunk_callback, - ) + runner = cast(Callable[[ModelHookParams], Awaitable[ModelResponse]], run_next) - return await current_middleware(req, ctx, next_fn_simple) + return await runner( + ModelHookParams( + request=req, + on_chunk=chunk_callback, + context=context or {}, + ) + ) # if resolving the 'resume' option above generated a tool message, stream it. 
if resumed_tool_message and on_chunk: @@ -381,117 +650,169 @@ async def next_fn_simple( ) ) - model_response = await dispatch( - 0, - request, - ActionRunContext(context=context), - wrap_chunks() if on_chunk else None, - ) + async def run_one_iteration(_params: GenerateHookParams) -> ModelResponse: + """Execute one turn of the generate loop (model call + optional tool resolution).""" + nonlocal request, message_index, chunk_role + # Sync from params so wrap_generate middleware can reshape the request + # by returning a model_copy(update={'request': ...}) to next_fn. + # Without this, a middleware-modified params.request would be silently ignored. + request = _params.request + # Drain anything middleware queued during the previous turn's tool + # calls and inject it as additional USER messages before the model + # runs. This is how a tool-side middleware (e.g. Filesystem read_file) + # can make extra context — file contents, error notes, etc. — visible + # to the model on the very next turn without forging a tool response. + if _queue: + queued = list(_queue) + _queue.clear() + if on_chunk: + # Emit each queued message at the current index and advance once + # per message. We bypass `make_chunk` here because its role + # tracker treats every USER chunk as a new message and would + # double-count the role flip from MODEL to USER. 
+ for msg in queued: + msg_role = cast(Role, msg.role) + chunk = ModelResponseChunk( + role=msg_role, + content=msg.content, + index=message_index, + previous_chunks=list(prev_chunks), + ) + prev_chunks.append(chunk) + on_chunk(chunk) + message_index += 1 + chunk_role = msg_role + request = request.model_copy(update={'messages': list(request.messages) + queued}) + + model_response = await dispatch_model( + request, + wrap_chunks() if on_chunk else None, + ) - def message_parser(msg: Message) -> Any: # noqa: ANN401 - if formatter is None: - return None - return formatter.parse_message(msg) - - # Extract schema_type for runtime Pydantic validation - schema_type = raw_request.output.schema_type if raw_request.output else None - - # Plugin returns ModelResponse directly. Framework sets request and - # any output format context (message_parser, schema_type) as private attrs. - response = model_response - response.request = request - if formatter: - response._message_parser = message_parser - if schema_type: - response._schema_type = schema_type - - logger.debug('generate response', response=_redact_data_uris(response.model_dump())) - - response.assert_valid() - generated_msg = response.message - - if generated_msg is None: - # No message in response, return as-is - return response - - # Stamp output format metadata on message so the Dev UI can render formatted JSON vs plain text. 
- out = raw_request.output - if out and (out.content_type or out.format): - generate_output: dict[str, str] = {} - if out.content_type: - generate_output['contentType'] = out.content_type - if out.format: - generate_output['format'] = out.format - existing_meta = dict(generated_msg.metadata) if isinstance(generated_msg.metadata, dict) else {} - generate_meta = existing_meta.get('generate') - if not isinstance(generate_meta, dict): - generate_meta = {} - generate_meta['output'] = generate_output - existing_meta['generate'] = generate_meta - generated_msg.metadata = existing_meta - - tool_requests = [x for x in generated_msg.content if x.root.tool_request] - - if raw_request.return_tool_requests or len(tool_requests) == 0: - if len(tool_requests) == 0: - response.assert_valid_schema() - return response - - max_iters = raw_request.max_turns if raw_request.max_turns else DEFAULT_MAX_TURNS - - if current_turn + 1 > max_iters: - raise GenerationResponseError( - response=response, - message=f'Exceeded maximum tool call iterations ({max_iters})', - status='ABORTED', - details={'request': request}, + def message_parser(msg: Message) -> Any: # noqa: ANN401 + if formatter is None: + return None + return formatter.parse_message(msg) + + # Extract schema_type for runtime Pydantic validation + schema_type = raw_request.output.schema_type if raw_request.output else None + + # Plugin returns ModelResponse directly. Framework sets request and + # any output format context (message_parser, schema_type) as private attrs. 
+ response = model_response + response.request = request + if formatter: + response._message_parser = message_parser + if schema_type: + response._schema_type = schema_type + + logger.debug( + 'generate response', + response=_redact_data_uris(response.model_dump()), ) - ( - revised_model_msg, - tool_msg, - transfer_preamble, - ) = await resolve_tool_requests(registry, raw_request, generated_msg) - - # if an interrupt message is returned, stop the tool loop and return a - # response. - if revised_model_msg: - interrupted_resp = response.model_copy(deep=False) - interrupted_resp.finish_reason = FinishReason.INTERRUPTED - interrupted_resp.finish_message = 'One or more tool calls resulted in interrupts.' - interrupted_resp.message = Message(revised_model_msg) - return interrupted_resp - - # If the loop will continue, stream out the tool response message... - if on_chunk and tool_msg: - on_chunk( - make_chunk( - Role.TOOL, - ModelResponseChunk( - role=tool_msg.role, - content=tool_msg.content, - ), + response.assert_valid() + generated_msg = response.message + + if generated_msg is None: + # No message in response, return as-is + return response + + # Stamp output format metadata on message so the Dev UI can render formatted JSON vs plain text. 
+ out = raw_request.output + if out and (out.content_type or out.format): + generate_output: dict[str, str] = {} + if out.content_type: + generate_output['contentType'] = out.content_type + if out.format: + generate_output['format'] = out.format + existing_meta = dict(generated_msg.metadata) if isinstance(generated_msg.metadata, dict) else {} + generate_meta = existing_meta.get('generate') + if not isinstance(generate_meta, dict): + generate_meta = {} + generate_meta['output'] = generate_output + existing_meta['generate'] = generate_meta + generated_msg.metadata = existing_meta + + tool_requests = [x for x in generated_msg.content if x.root.tool_request] + + if raw_request.return_tool_requests or len(tool_requests) == 0: + if len(tool_requests) == 0: + response.assert_valid_schema() + return response + + max_iters = raw_request.max_turns if raw_request.max_turns else DEFAULT_MAX_TURNS + + if current_turn + 1 > max_iters: + raise GenerationResponseError( + response=response, + message=f'Exceeded maximum tool call iterations ({max_iters})', + status='ABORTED', + details={'request': request}, ) + + ( + revised_model_msg, + tool_msg, + transfer_preamble, + ) = await resolve_tool_requests( + registry, + raw_request, + generated_msg, + middleware=normalized_mw, ) - next_request = copy.copy(raw_request) - next_messages = copy.copy(raw_request.messages) - next_messages.append(generated_msg) - if tool_msg: - next_messages.append(tool_msg) - next_request.messages = next_messages - if transfer_preamble: - next_request = apply_transfer_preamble(next_request, transfer_preamble) - - # then recursively call for another loop - return await _generate_action( - registry, - raw_request=next_request, - # middleware: middleware, - current_turn=current_turn + 1, - message_index=message_index + 1, + # if an interrupt message is returned, stop the tool loop and return a + # response. 
+ if revised_model_msg: + interrupted_resp = response.model_copy(deep=False) + interrupted_resp.finish_reason = FinishReason.INTERRUPTED + interrupted_resp.finish_message = 'One or more tool calls resulted in interrupts.' + interrupted_resp.message = Message(revised_model_msg) + return interrupted_resp + + # If the loop will continue, stream out the tool response message... + if on_chunk and tool_msg: + on_chunk( + make_chunk( + Role.TOOL, + ModelResponseChunk( + role=tool_msg.role, + content=tool_msg.content, + ), + ) + ) + + next_request = copy.copy(raw_request) + next_messages = copy.copy(raw_request.messages) + next_messages.append(generated_msg) + if tool_msg: + next_messages.append(tool_msg) + next_request.messages = next_messages + if transfer_preamble: + next_request = apply_transfer_preamble(next_request, transfer_preamble) + + # then recursively call for another loop. + return await _generate_action( + registry=registry, + raw_request=next_request, + middleware=middleware, + current_turn=current_turn + 1, + message_index=message_index + 1, + on_chunk=on_chunk, + context=context, + _enqueue_parts=_enqueue_parts, + _queue=_queue, + ) + + generate_params = GenerateHookParams( + options=raw_request, + request=request, + iteration=current_turn, + message_index=message_index, on_chunk=on_chunk, ) + return await dispatch_generate(generate_params, run_one_iteration) def apply_format( @@ -760,7 +1081,11 @@ def to_tool_definition(tool: Action) -> ToolDefinition: async def resolve_tool_requests( - registry: Registry, request: GenerateActionOptions, message: Message + registry: Registry, + request: GenerateActionOptions, + message: Message, + *, + middleware: list[BaseMiddleware] | None = None, ) -> tuple[Message | None, Message | None, GenerateActionOptions | None]: """Execute tool requests in a message, returning responses or interrupt info.""" # TODO(#4342): prompt transfer @@ -775,6 +1100,7 @@ async def resolve_tool_requests( tool_dict[short] = tool_action 
revised_model_message = message.model_copy(deep=True) + mw_list = middleware or [] work: list[tuple[int, Action, ToolRequestPart]] = [] for i, tool_request_part in enumerate(message.content): @@ -792,12 +1118,45 @@ async def resolve_tool_requests( if not work: return (None, Message(role=Role.TOOL, content=[]), None) - outs = await asyncio.gather(*[_resolve_tool_request(tool, trp) for _, tool, trp in work]) + async def _resolve_one_tool( + tool: Action, trp: ToolRequestPart + ) -> tuple[MultipartToolResponse | None, ToolRequestPart | None]: + params = ToolHookParams(tool_request_part=trp, tool=tool) + + async def base(p: ToolHookParams) -> MultipartToolResponse: + return await _resolve_tool_request(p.tool, p.tool_request_part) + + try: + if mw_list: + multipart = await _chain_tool_middleware(mw_list, params, base) + else: + multipart = await base(params) + return (multipart, None) + except Exception as e: + # Interrupts (raised by the tool body or by middleware) become a + # wire-shape interrupt ``ToolRequestPart``. Any tracing span is the + # middleware's responsibility (e.g. ToolApproval wraps its raise in + # ``run_in_new_span`` explicitly). Non-Interrupt exceptions are real + # failures and propagate to ``asyncio.gather``. 
+ intr = _interrupt_from_tool_exc(e) + if intr is None: + raise + return (None, _interrupt_request_part(trp, intr)) + + outs = await asyncio.gather(*[_resolve_one_tool(tool, trp) for _, tool, trp in work]) has_interrupts = False response_parts: list[Part] = [] - for (idx, _tool, tool_req_root), (tool_response_part, interrupt_part) in zip(work, outs, strict=True): - if tool_response_part: + for (idx, _tool, tool_req_root), (multipart_resp, interrupt_part) in zip(work, outs, strict=True): + if multipart_resp is not None: + tool_response_part = ToolResponsePart( + tool_response=ToolResponse( + name=tool_req_root.tool_request.name, + ref=tool_req_root.tool_request.ref, + output=multipart_resp.output, + content=[p.model_dump() for p in multipart_resp.content] if multipart_resp.content else None, + ) + ) revised_model_message.content[idx] = _to_pending_response(tool_req_root, tool_response_part) response_parts.append(Part(root=tool_response_part)) @@ -833,40 +1192,29 @@ def _interrupt_from_tool_exc(exc: BaseException) -> Interrupt | None: return None -async def _resolve_tool_request( - tool: Action, tool_request_part: ToolRequestPart -) -> tuple[ToolResponsePart | None, ToolRequestPart | None]: - """Execute a tool. +async def _resolve_tool_request(tool: Action, tool_request_part: ToolRequestPart) -> MultipartToolResponse: + """Execute a tool and return its response. - Returns ``(ToolResponsePart, None)`` on success or ``(None, ToolRequestPart)`` when interrupted. + Interrupts from the tool body propagate to the caller (the engine + converts them to a wire ``ToolRequestPart`` at the top of + ``_resolve_one_tool``). This keeps the contract symmetric with + ``BaseMiddleware.wrap_tool``: responses are return values, interrupts + are exceptions. 
""" - try: - tool_response = (await tool.run(tool_request_part.tool_request.input)).response - return ( - ToolResponsePart( - tool_response=ToolResponse( - name=tool_request_part.tool_request.name, - ref=tool_request_part.tool_request.ref, - output=tool_response.model_dump() if isinstance(tool_response, BaseModel) else tool_response, - ) - ), - None, - ) - except Exception as e: - intr = _interrupt_from_tool_exc(e) - if intr is not None: - payload: dict[str, Any] | bool = intr.metadata if intr.metadata else True - tool_meta = tool_request_part.metadata or {} - if not isinstance(tool_meta, dict): - tool_meta = dict(tool_meta) - return ( - None, - ToolRequestPart( - tool_request=tool_request_part.tool_request, - metadata={**tool_meta, 'interrupt': payload}, - ), - ) - raise + tool_response = (await tool.run(tool_request_part.tool_request.input)).response + return MultipartToolResponse( + output=tool_response.model_dump() if isinstance(tool_response, BaseModel) else tool_response, + ) + + +def _interrupt_request_part(trp: ToolRequestPart, intr: Interrupt) -> ToolRequestPart: + """Convert an Interrupt exception into the wire-shape interrupt ToolRequestPart.""" + payload: dict[str, Any] | bool = intr.metadata if intr.metadata else True + tool_meta = trp.metadata or {} + return ToolRequestPart( + tool_request=trp.tool_request, + metadata={**tool_meta, 'interrupt': payload}, + ) async def resolve_tool(registry: Registry, tool_ref: str | Tool) -> Action: @@ -892,7 +1240,10 @@ async def resolve_tool(registry: Registry, tool_ref: str | Tool) -> Action: async def _resolve_resume_options( - _registry: Registry, raw_request: GenerateActionOptions + _registry: Registry, + raw_request: GenerateActionOptions, + *, + middleware: list[BaseMiddleware] | None = None, ) -> tuple[GenerateActionOptions, ModelResponse | None, Message | None]: """Handle resume options by resolving pending tool calls from a previous turn.""" if not raw_request.resume: @@ -912,18 +1263,23 @@ async def 
_resolve_resume_options( i = 0 tool_responses = [] - # Create a new list for content to avoid mutation during iteration + # Build updated_content in a new list — do NOT mutate last_message.content + # directly; the caller's raw_request object must remain unchanged. updated_content = list(last_message.content) for part in last_message.content: if not isinstance(part.root, ToolRequestPart): i += 1 continue - resumed_request, resumed_response = await _resolve_resumed_tool_request(_registry, raw_request, part) + resumed_request, resumed_response = await _resolve_resumed_tool_request( + _registry, + raw_request, + part, + middleware=middleware, + ) tool_responses.append(Part(root=resumed_response)) updated_content[i] = Part(root=resumed_request) i += 1 - last_message.content = updated_content if len(tool_responses) != len(tool_requests): raise GenkitError( @@ -939,13 +1295,24 @@ async def _resolve_resume_options( revised_request = raw_request.model_copy(deep=True) revised_request.resume = None + # Replace the last message in the deep copy with the resolved version + # (pending TRPs swapped for resolved ones) without touching raw_request. 
+ revised_request.messages[-1] = Message( + role=last_message.role, + content=updated_content, + metadata=last_message.metadata, + ) revised_request.messages.append(tool_message) return (revised_request, None, tool_message) async def _resolve_resumed_tool_request( - registry: Registry, raw_request: GenerateActionOptions, tool_request_part: Part + registry: Registry, + raw_request: GenerateActionOptions, + tool_request_part: Part, + *, + middleware: list[BaseMiddleware] | None = None, ) -> tuple[ToolRequestPart, ToolResponsePart]: """Resolve a single tool request from pending output, resume.respond, or resume.restart.""" # Type narrowing: ensure we're working with a ToolRequestPart @@ -1008,7 +1375,7 @@ async def _resolve_resumed_tool_request( ) if restart_trp: tool = await resolve_tool(registry, tool_req_root.tool_request.name) - executed = await run_tool_after_restart(tool, restart_trp) + executed = await _run_restart_through_middleware(tool, restart_trp, middleware=middleware) metadata = dict(tool_req_root.metadata) if tool_req_root.metadata else {} interrupt = metadata.get('interrupt') if interrupt: @@ -1033,6 +1400,58 @@ async def _resolve_resumed_tool_request( ) +async def _run_restart_through_middleware( + tool: Action, + restart_trp: ToolRequestPart, + *, + middleware: list[BaseMiddleware] | None, +) -> ToolResponsePart: + """Run a restarted tool through the wrap_tool middleware chain. + + Restart paths reuse the same dispatch as fresh tool calls so middleware + (ToolApproval, Filesystem error queueing, etc.) sees every tool execution + regardless of whether it was triggered by the model or by a resumed + interrupt. Without this, a restart would silently bypass approval checks. 
+ """ + mw_list = middleware or [] + if not mw_list: + return await run_tool_after_restart(tool, restart_trp) + + params = ToolHookParams( + tool_request_part=restart_trp, + tool=tool, + ) + + async def next_fn(p: ToolHookParams) -> MultipartToolResponse: + executed = await run_tool_after_restart(p.tool, restart_trp) + return MultipartToolResponse( + output=executed.tool_response.output, + content=[Part.model_validate(c) for c in (executed.tool_response.content or [])], + ) + + try: + multipart = await _chain_tool_middleware(mw_list, params, next_fn) + except Exception as e: + if _interrupt_from_tool_exc(e) is not None: + # Re-interrupting during restart is a hard error — same as the legacy + # run_tool_after_restart path, which raises FAILED_PRECONDITION when + # the inner tool throws an Interrupt during restart. + raise GenkitError( + status='FAILED_PRECONDITION', + message='Tool interrupted again during a restart execution; not supported yet.', + ) from e + raise + + return ToolResponsePart( + tool_response=ToolResponse( + name=restart_trp.tool_request.name, + ref=restart_trp.tool_request.ref, + output=multipart.output, + content=[p.model_dump() for p in multipart.content] if multipart.content else None, + ) + ) + + def _find_corresponding_restart( restarts: list[ToolRequestPart] | None, request: ToolRequestPart, diff --git a/py/packages/genkit/src/genkit/_ai/_middleware.py b/py/packages/genkit/src/genkit/_ai/_middleware.py deleted file mode 100644 index 54931710ea..0000000000 --- a/py/packages/genkit/src/genkit/_ai/_middleware.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-License-Identifier: Apache-2.0 - -"""Middleware for the Genkit framework.""" - -from collections.abc import Awaitable, Callable - -from genkit._ai._model import ( - Message, - ModelMiddleware, - ModelRequest, - ModelResponse, - text_from_content, -) -from genkit._core._action import ActionRunContext -from genkit._core._model import Document -from genkit._core._typing import ( - Part, - TextPart, -) - -CONTEXT_PREFACE = '\n\nUse the following information to complete your task:\n\n' - - -def context_item_template(d: Document, index: int) -> str: - """Render a document as a citation line for context injection.""" - out = '- ' - ref = (d.metadata and (d.metadata.get('ref') or d.metadata.get('id'))) or index - out += f'[{ref}]: ' - out += text_from_content(d.content) + '\n' - return out - - -def augment_with_context() -> ModelMiddleware: - """Middleware that injects document context into the last user message.""" - - async def middleware( - req: ModelRequest, - ctx: ActionRunContext, - next_middleware: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - if not req.docs: - return await next_middleware(req, ctx) - - user_message = last_user_message(req.messages) - if not user_message: - return await next_middleware(req, ctx) - - context_part_index = -1 - for i, part in enumerate(user_message.content): - part_metadata = part.root.metadata - if isinstance(part_metadata, dict) and part_metadata.get('purpose') == 'context': - context_part_index = i - break - - context_part = user_message.content[context_part_index] if context_part_index >= 0 
else None - - if context_part: - metadata = context_part.root.metadata - if not (isinstance(metadata, dict) and metadata.get('pending')): - return await next_middleware(req, ctx) - - out = CONTEXT_PREFACE - for i, doc_data in enumerate(req.docs): - doc = Document(content=doc_data.content, metadata=doc_data.metadata) - out += context_item_template(doc, i) - out += '\n' - - text_part = Part(root=TextPart(text=out, metadata={'purpose': 'context'})) - if context_part_index >= 0: - user_message.content[context_part_index] = text_part - else: - if not user_message.content: - user_message.content = [] - user_message.content.append(text_part) - - return await next_middleware(req, ctx) - - return middleware - - -def last_user_message(messages: list[Message]) -> Message | None: - """Find the last user message in a list.""" - for i in range(len(messages) - 1, -1, -1): - if messages[i].role == 'user': - return messages[i] - return None diff --git a/py/packages/genkit/src/genkit/_ai/_model.py b/py/packages/genkit/src/genkit/_ai/_model.py index 26baf4e513..4d032b22fe 100644 --- a/py/packages/genkit/src/genkit/_ai/_model.py +++ b/py/packages/genkit/src/genkit/_ai/_model.py @@ -32,7 +32,6 @@ from genkit._core._model import ( Message, ModelConfig, - ModelMiddleware, ModelRef, ModelRequest, ModelResponse, diff --git a/py/packages/genkit/src/genkit/_ai/_prompt.py b/py/packages/genkit/src/genkit/_ai/_prompt.py index b9dfcd8c5c..71d8c862ef 100644 --- a/py/packages/genkit/src/genkit/_ai/_prompt.py +++ b/py/packages/genkit/src/genkit/_ai/_prompt.py @@ -36,13 +36,13 @@ from genkit._ai._generate import ( generate_action, + registry_with_inline_middleware, registry_with_inline_tools, resolve_tool, to_tool_definition, tools_to_action_names, ) from genkit._ai._model import ( - ModelMiddleware, ModelRequest, ModelResponse, ModelResponseChunk, @@ -52,11 +52,13 @@ from genkit._core._channel import Channel from genkit._core._error import GenkitError from genkit._core._logger import get_logger 
+from genkit._core._middleware import BaseMiddleware from genkit._core._model import Document, GenerateActionOptions, Message, ModelConfig from genkit._core._registry import Registry from genkit._core._schema import to_json_schema from genkit._core._typing import ( GenerateActionOutputConfig, + MiddlewareRef, OutputConfig, Part, Resume, @@ -134,9 +136,8 @@ class PromptGenerateOptions(TypedDict, total=False): return_tool_requests: bool | None max_turns: int | None on_chunk: ModelStreamingCallback | None - use: list[ModelMiddleware] | None + use: Sequence[BaseMiddleware | MiddlewareRef] | None context: dict[str, Any] | None - step_name: str | None metadata: dict[str, Any] | None @@ -213,7 +214,7 @@ class PromptConfig(BaseModel): metadata: dict[str, Any] | None = None tools: Sequence[str | Tool] | None = None tool_choice: ToolChoice | None = None - use: list[ModelMiddleware] | None = None + use: Sequence[BaseMiddleware | MiddlewareRef] | None = None docs: list[Document] | None = None resume_respond: ToolResponsePart | list[ToolResponsePart] | None = None resume_restart: ToolRequestPart | list[ToolRequestPart] | None = None @@ -245,7 +246,7 @@ def __init__( metadata: dict[str, Any] | None = None, tools: Sequence[str | Tool] | None = None, tool_choice: ToolChoice | None = None, - use: list[ModelMiddleware] | None = None, + use: Sequence[BaseMiddleware | MiddlewareRef] | None = None, docs: list[Document] | None = None, resources: list[str] | None = None, name: str | None = None, @@ -343,16 +344,22 @@ async def _call_impl( """Execute the prompt with resolved opts. 
Used by __call__ and stream.""" await self._ensure_resolved() on_chunk = opts.get('on_chunk') - middleware = opts.get('use') or self._use context = opts.get('context') + prompt_config = self._prompt_config_for_call(opts) - registry = await registry_with_inline_tools(self._registry, prompt_config.tools) - gen_options = await executable_prompt_call_to_generate_options(self, registry, prompt_config, input, opts) + # `exec_registry` carries inline `use=[Logger()]` middleware under + # synthetic ref names so the generate action can resolve them; pass it + # (not `render_registry`) into `generate_action` for that reason. + render_registry, exec_registry, prompt_config = await prepare_prompt_call_registry( + self._registry, prompt_config + ) + gen_options = await executable_prompt_call_to_generate_options( + self, render_registry, prompt_config, input, opts + ) result = await generate_action( - registry, + exec_registry, gen_options, on_chunk=on_chunk, - middleware=middleware, context=context if context else ActionRunContext._current_context(), # pyright: ignore[reportPrivateUsage] ) return cast(ModelResponse[OutputT], result) @@ -401,6 +408,7 @@ def _or(opt_val: Any, default: Any) -> Any: # noqa: ANN401 metadata=merged_metadata, docs=self._docs, resources=opts.get('resources') or self._resources, + use=opts.get('use') or self._use, resume_respond=opts.get('resume_respond'), resume_restart=opts.get('resume_restart'), resume_metadata=opts.get('resume_metadata'), @@ -437,8 +445,10 @@ async def render( call_opts: PromptGenerateOptions = opts # ty: ignore[invalid-assignment] # ty treats **opts as a plain dict here; callers are still validated against PromptGenerateOptions. 
await self._ensure_resolved() prompt_config = self._prompt_config_for_call(call_opts) - registry = await registry_with_inline_tools(self._registry, prompt_config.tools) - return await executable_prompt_call_to_generate_options(self, registry, prompt_config, input, call_opts) + render_registry, _exec_registry, prompt_config = await prepare_prompt_call_registry( + self._registry, prompt_config + ) + return await executable_prompt_call_to_generate_options(self, render_registry, prompt_config, input, call_opts) def register_prompt_actions( @@ -465,19 +475,23 @@ async def prompt_action_fn(input: Any = None) -> ModelRequest: # noqa: ANN401 await executable_prompt._ensure_resolved() call_opts: PromptGenerateOptions = {} prompt_config = executable_prompt._prompt_config_for_call(call_opts) - registry = await registry_with_inline_tools(executable_prompt._registry, prompt_config.tools) + render_registry, _exec_registry, prompt_config = await prepare_prompt_call_registry( + executable_prompt._registry, prompt_config + ) gen_options = await executable_prompt_call_to_generate_options( - executable_prompt, registry, prompt_config, input, call_opts + executable_prompt, render_registry, prompt_config, input, call_opts ) - return await to_generate_request(registry, gen_options) + return await to_generate_request(render_registry, gen_options) async def executable_prompt_action_fn(input: Any = None) -> GenerateActionOptions: # noqa: ANN401 await executable_prompt._ensure_resolved() call_opts: PromptGenerateOptions = {} prompt_config = executable_prompt._prompt_config_for_call(call_opts) - registry = await registry_with_inline_tools(executable_prompt._registry, prompt_config.tools) + render_registry, _exec_registry, prompt_config = await prepare_prompt_call_registry( + executable_prompt._registry, prompt_config + ) return await executable_prompt_call_to_generate_options( - executable_prompt, registry, prompt_config, input, call_opts + executable_prompt, render_registry, prompt_config, 
input, call_opts ) action_name = registry_definition_key(name, variant) @@ -539,6 +553,39 @@ def _resolve_output_schema( output.json_schema = to_json_schema(output_schema) +async def prepare_prompt_call_registry( + base_registry: Registry, + prompt_config: PromptConfig, +) -> tuple[Registry, Registry, PromptConfig]: + """Build per-call registries for a prompt and finalize ``use`` to refs. + + Returns ``(render_registry, exec_registry, prompt_config)``: + + * ``render_registry`` is what template/schema/tool resolution should run + against (it has any inline ``Tool`` instances registered on a child of + ``base_registry`` if there were any, otherwise it's ``base_registry`` + itself). + * ``exec_registry`` is the child registry where inline ``BaseMiddleware`` + instances from ``prompt_config.use`` were registered, and is what + :func:`generate_action` should be called with so those middleware + resolve by name during execution. + * ``prompt_config`` is returned with ``use`` rewritten to a list of + ``MiddlewareRef`` (the wire shape that flows into + :class:`GenerateActionOptions`). Inline instances now live on + ``exec_registry`` under stable synthetic names so the refs resolve at + call time. + + Mirrors the pattern in :meth:`Genkit.generate` so prompt execution sees + the same middleware semantics as direct ``ai.generate`` calls. 
+ """ + render_registry = await registry_with_inline_tools(base_registry, prompt_config.tools) + exec_registry = render_registry if render_registry.is_child else render_registry.new_child() + refs = registry_with_inline_middleware(exec_registry, prompt_config.use) or None + if refs is not None or prompt_config.use is not None: + prompt_config = prompt_config.model_copy(update={'use': refs}) + return render_registry, exec_registry, prompt_config + + async def to_generate_action_options( registry: Registry, options: PromptConfig, @@ -597,6 +644,7 @@ async def to_generate_action_options( max_turns=options.max_turns, docs=merged_docs, # type: ignore[arg-type] resume=resume, + use=options.use, # type: ignore[arg-type] ) @@ -1142,18 +1190,24 @@ async def prompt_action_fn(input: Any = None) -> ModelRequest: # noqa: ANN401 prompt = await create_prompt_from_file() call_opts: PromptGenerateOptions = {} prompt_config = prompt._prompt_config_for_call(call_opts) - registry = await registry_with_inline_tools(prompt._registry, prompt_config.tools) + render_registry, _exec_registry, prompt_config = await prepare_prompt_call_registry( + prompt._registry, prompt_config + ) gen_options = await executable_prompt_call_to_generate_options( - prompt, registry, prompt_config, input, call_opts + prompt, render_registry, prompt_config, input, call_opts ) - return await to_generate_request(registry, gen_options) + return await to_generate_request(render_registry, gen_options) async def executable_prompt_action_fn(input: Any = None) -> GenerateActionOptions: # noqa: ANN401 prompt = await create_prompt_from_file() call_opts: PromptGenerateOptions = {} prompt_config = prompt._prompt_config_for_call(call_opts) - registry = await registry_with_inline_tools(prompt._registry, prompt_config.tools) - return await executable_prompt_call_to_generate_options(prompt, registry, prompt_config, input, call_opts) + render_registry, _exec_registry, prompt_config = await prepare_prompt_call_registry( + 
prompt._registry, prompt_config + ) + return await executable_prompt_call_to_generate_options( + prompt, render_registry, prompt_config, input, call_opts + ) action_name = registry_definition_key(name, variant, ns) prompt_action = registry.register_action( diff --git a/py/packages/genkit/src/genkit/_core/_middleware.py b/py/packages/genkit/src/genkit/_core/_middleware.py new file mode 100644 index 0000000000..03e55645a8 --- /dev/null +++ b/py/packages/genkit/src/genkit/_core/_middleware.py @@ -0,0 +1,474 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Core middleware abstractions for the Genkit generate pipeline. + +Defines :class:`BaseMiddleware` (the class authors subclass to add config fields +and hook overrides), :class:`MiddlewareDesc` (the registry descriptor used for +Dev UI name-based dispatch), plus the :func:`middleware` decorator and +:func:`new_middleware` factory for registration. + +Also contains the hook parameter types (:class:`GenerateHookParams`, +:class:`ModelHookParams`, :class:`ToolHookParams`, :class:`MultipartToolResponse`) +that are passed into each hook by the engine. These live here rather than in +``_model.py`` because middleware is a concept built on top of the model layer. 
+""" + +from __future__ import annotations + +import re +from collections.abc import Awaitable, Callable +from typing import Any, ClassVar, TypeVar + +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, create_model + +from genkit._core._action import Action +from genkit._core._model import ( + GenerateActionOptions, + ModelRequest, + ModelResponse, + ModelResponseChunk, +) +from genkit._core._protocols import RegistryLike +from genkit._core._typing import MiddlewareDescData, Part, ToolRequestPart + +_M = TypeVar('_M', bound='type[BaseMiddleware]') + +# Disallowed in middleware definition names and in ``middleware_plugin(..., namespace=...)``. +# Model/action keys use ``provider/name``; middleware stays one path-free token for the registry. +_FORBIDDEN_IN_MIDDLEWARE_KEY_SEGMENT = re.compile(r'[\x00-\x1f/\\:]|\s') + + +def _validate_middleware_key_segment(name: str, *, label: str) -> None: + """Raise if ``name`` is not usable as a middleware registry key or namespace. + + Middleware definitions are stored under + ``register_value(kind='middleware', name=...)``. The optional + ``middleware_plugin(..., namespace='acme')`` builds keys of the form + ``acme_logging``. The string must therefore be one segment: + + * no ``/`` (that shape is reserved for models and other actions); + * no whitespace, ``:``, backslashes, or control characters that + would break registry keys or the Dev UI. + + Args: + name: Proposed name or namespace segment. + label: Field name for error messages (e.g. ``MiddlewareDesc name``). + """ + if not name or not name.strip(): + raise ValueError(f'{label} must be a non-empty string (not whitespace-only).') + if name != name.strip(): + raise ValueError(f'{label} must not have leading or trailing whitespace.') + if _FORBIDDEN_IN_MIDDLEWARE_KEY_SEGMENT.search(name): + raise ValueError( + f'{label} must be one path-free token: no whitespace, "/", ":", ' + r'backslashes, or control characters (for example "myorg_logging_mw").' 
+ ) + + +class MultipartToolResponse(BaseModel): + """A tool result with optional rich content attachments. + + Return from ``wrap_tool`` to send structured output alongside extra + parts — images, file contents, error details — that the model can + reason about. + + The engine serializes both fields into a single ``ToolResponsePart`` on + the wire: ``output`` becomes ``ToolResponse.output`` and ``content`` + becomes ``ToolResponse.content``. Packing them together preserves the + LLM's one-response-per-call contract while still letting middleware + attach rich context. + + Fields: + output: Structured result returned to the model. May be ``None`` + when the tool only produces rich content parts. + content: Extra ``Part`` objects (images, files, metadata) bundled + alongside ``output`` in the same ``ToolResponsePart``. + """ + + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) + + output: Any = None + content: list[Part] = Field(default_factory=list) + + +class GenerateHookParams(BaseModel): + """Params passed to the ``wrap_generate`` hook. + + Covers one full iteration of the tool loop: a model call plus optional tool + resolution. ``message_index`` and ``on_chunk`` support streaming. 
+ """ + + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) + + options: GenerateActionOptions + request: ModelRequest + iteration: int + message_index: int = 0 + on_chunk: Callable[[ModelResponseChunk], None] | None = None + + +class ModelHookParams(BaseModel): + """Params passed to the ``wrap_model`` hook (each raw model API call).""" + + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) + + request: ModelRequest + on_chunk: Callable[[ModelResponseChunk], None] | None = None + context: dict[str, object] = Field(default_factory=dict) + + +class ToolHookParams(BaseModel): + """Params passed to the ``wrap_tool`` hook (each individual tool execution).""" + + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) + + tool_request_part: ToolRequestPart + tool: Action + + +class BaseMiddleware(BaseModel): + """Pydantic-backed middleware: config fields + hook overrides in one class. + + The config struct *is* the middleware — there is no separate + "factory args" type. To author one: + + * Subclass and add pydantic fields for config. + * Override the ``wrap_generate`` / ``wrap_model`` / ``wrap_tool`` hooks. + * Either pass instances inline in ``use=[...]``, or register the class + with ``Genkit.define_middleware`` (or ``middleware_plugin`` / + ``Plugin.list_middleware``) and reference it by name with + :class:`MiddlewareRef`. + + Inside any hook, two framework-injected attributes are guaranteed to be set: + + * ``self.registry`` — the per-call child registry. Use it to resolve actions + and to inspect what else is in scope for this call. Anything you register + through it is automatically scoped to the call and torn down at the end. + * ``self.enqueue_parts(parts)`` — queue an extra user message to be injected + into the conversation at the start of the next generate iteration. 
Use it + from a tool closure or from ``wrap_tool`` to surface error details, file + contents, or other rich context to the model without forging a tool + response. + + Outside a ``generate()`` call these attributes are ``None`` — they only + become valid once the engine binds the instance to a specific call. + + Example: + @middleware(name='logger') + class Logger(BaseMiddleware): + prefix: str = '[trace]' + + async def wrap_model(self, params, next_fn): + t = time.monotonic() + resp = await next_fn(params) + log(f'{self.prefix} {time.monotonic() - t:.3f}s') + return resp + + # Inline (fast path, no registration): + await ai.generate(prompt='...', use=[Logger(prefix='[span]')]) + + # Registered (visible in the Dev UI, dispatched by name): + ai.define_middleware(Logger) + await ai.generate( + prompt='...', + use=[MiddlewareRef(name='logger', config={'prefix': '[span]'})], + ) + + Concurrency: + Each ``generate()`` call works on its own shallow copy of the + middleware instance with a freshly bound ``self.registry`` and + ``self.enqueue_parts``, so those framework attributes are safe + even when the same instance is reused across concurrent calls. + Author-added state on ``self`` is *not* deep-copied — keep + per-call state in method locals, or override ``model_copy`` if + you need stronger isolation (same convention as Django / + Starlette middleware). + """ + + # ``arbitrary_types_allowed`` lets subclasses keep non-pydantic fields like + # ``Callable`` or opaque resources without opting in per-subclass. + model_config = ConfigDict(arbitrary_types_allowed=True) + + # Class-level metadata used by ``new_middleware(MyClass)`` and the Dev UI. + # These are ClassVars, not fields, so they do not appear in ``model_dump()`` or + # ``config`` dicts passed to factories. 
+ name: ClassVar[str] = '' + description: ClassVar[str | None] = None + middleware_config_schema: ClassVar[dict[str, Any] | None] = None + middleware_metadata: ClassVar[dict[str, object] | None] = None + + # Framework-injected at the start of each generate() call (see the class + # docstring). They are public fields, not PrivateAttrs, so a middleware + # author writing ``self.`` in their IDE sees them in autocomplete and knows + # they exist. Annotated as required so hooks can write + # ``self.registry.lookup_action(...)`` without a None-narrow; the runtime + # default of ``None`` lets bare constructors like ``Retry(max_retries=3)`` + # work, with the engine rebinding before any hook fires. + registry: RegistryLike = Field(default=None, exclude=True, repr=False) # type: ignore[assignment] + enqueue_parts: Callable[[list[Part]], None] = Field(default=None, exclude=True, repr=False) # type: ignore[assignment] + + def tools(self) -> list[Action]: + """Return additional tools to expose to the model for this generate call. + + Called once per ``generate()`` call after the engine has bound + ``self.registry`` and ``self.enqueue_parts``. Tool closures may + capture ``self.enqueue_parts`` to queue extra user messages + alongside the normal ``ToolResponsePart`` (e.g. filesystem + error details for the next turn). + + Tools are registered on a call-scoped child registry, so they + do not pollute the root registry and are invisible to other + concurrent ``generate()`` calls. + + Override to contribute tools dynamically. The default returns + ``[]``. 
+ """ + return [] + + async def wrap_generate( + self, + params: GenerateHookParams, + next_fn: Callable[[GenerateHookParams], Awaitable[ModelResponse]], + ) -> ModelResponse: + """Wrap each iteration of the tool loop (model call + optional tool resolution).""" + return await next_fn(params) + + async def wrap_model( + self, + params: ModelHookParams, + next_fn: Callable[[ModelHookParams], Awaitable[ModelResponse]], + ) -> ModelResponse: + """Wrap each model API call.""" + return await next_fn(params) + + async def wrap_tool( + self, + params: ToolHookParams, + next_fn: Callable[[ToolHookParams], Awaitable[MultipartToolResponse]], + ) -> MultipartToolResponse: + """Wrap each tool execution. + + Return a ``MultipartToolResponse`` to forward (or substitute) the + tool's result. Raise ``Interrupt(metadata)`` to halt this tool call + and surface an interrupt to the caller — the engine attaches + ``metadata`` to the pending ``ToolRequestPart`` exactly like an + interrupt raised by the tool body itself. Mirroring the tool-side + convention means authors learn one rule: **responses are return + values, interrupts are exceptions, everywhere**. + + Example (tool approval gate):: + + class Approval(BaseMiddleware): + async def wrap_tool(self, params, next_fn): + if params.tool.name == 'transfer_money' and not approved(): + raise Interrupt({'reason': 'requires_approval'}) + return await next_fn(params) + """ + return await next_fn(params) + + +class MiddlewareDesc(MiddlewareDescData): + """Registered middleware descriptor: wire shape + per-process factory closure. + + Inherits the wire fields (``name``, ``description``, ``config_schema``, + ``metadata``) from the auto-generated + :class:`genkit._core._typing.MiddlewareDescData` schema, and adds a + ``PrivateAttr`` factory used to mint a fresh :class:`BaseMiddleware` per + ``generate()`` call. 
``PrivateAttr`` is excluded from serialization, so + ``model_dump(by_alias=True, exclude_none=True)`` produces the wire shape + directly. + + Stored under ``register_value('middleware', name, desc)`` and resolved + when ``generate()`` runs with a ``use=`` entry that references the + descriptor by name. This follows the same hand-authored runtime-subclass + convention as ``Message`` / ``MessageData`` and ``GenerateActionOptions`` / + ``GenerateActionOptionsData``: the runtime class adds non-serializable + behavior (here: the factory) on top of the pure wire schema. + """ + + # ``arbitrary_types_allowed`` lets the ``PrivateAttr`` carry an opaque ``Callable``; + # parent's ``alias_generator`` and ``extra='forbid'`` settings are inherited. + model_config = ConfigDict(arbitrary_types_allowed=True) + + # Factory takes ``config`` and mints a fresh BaseMiddleware instance per + # generate() call. The engine binds per-call attrs (``self.registry``, + # ``self.enqueue_parts``) onto the result before any hook fires. 
+ _factory: Callable[[dict[str, Any] | None], BaseMiddleware] = PrivateAttr() + + def __init__( + self, + *, + factory: Callable[[dict[str, Any] | None], BaseMiddleware], + name: str, + description: str | None = None, + config_schema: object | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + _validate_middleware_key_segment(name, label='MiddlewareDesc name') + super().__init__( + name=name, + description=description, + config_schema=config_schema, + metadata=metadata, + ) + self._factory = factory + + def __call__(self, config: dict[str, Any] | None = None) -> BaseMiddleware: + """Return a fresh BaseMiddleware instance for this generate() call.""" + return self._factory(config) + + def with_name(self, name: str) -> MiddlewareDesc: + """Return a copy with the same factory and metadata but a different registry name.""" + return MiddlewareDesc( + factory=self._factory, + name=name, + description=self.description, + config_schema=self.config_schema, + metadata=self.metadata, + ) + + +def middleware( + name: str, + *, + description: str | None = None, + config_schema: dict[str, Any] | None = None, + metadata: dict[str, object] | None = None, +) -> Callable[[_M], _M]: + """Class decorator that sets registry metadata on a ``BaseMiddleware`` subclass. + + Required when registering middleware via ``new_middleware``, + ``define_middleware``, or ``middleware_plugin``. Optional for inline-only + use (``use=[MyClass()]``). + + Example: + @middleware(name='latency_logger', description='Logs model call latency') + class LatencyLogger(BaseMiddleware): + prefix: str = '[trace]' + + async def wrap_model(self, params, next_fn): ... + + Args: + name: Registry key. Must be a single path-free token (no ``/``, + whitespace, ``:``, backslashes, or control characters). + description: Human-readable description shown in the Dev UI. + config_schema: JSON Schema for the config. Inferred from pydantic + fields when omitted. 
+ metadata: Arbitrary metadata passed through to the Dev UI wire format. + """ + _validate_middleware_key_segment(name, label='middleware name') + + def decorator(cls: _M) -> _M: + cls.name = name # type: ignore[attr-defined] + cls.description = description # type: ignore[attr-defined] + cls.middleware_config_schema = config_schema # type: ignore[attr-defined] + cls.middleware_metadata = metadata # type: ignore[attr-defined] + return cls + + return decorator + + +def _derive_config_schema(cls: type[BaseMiddleware]) -> dict[str, Any]: + """Build a JSON Schema describing a middleware's user-facing config fields. + + The Dev UI renders a config form for each registered middleware from this + schema. Without it the form has nothing to draw and falls back to a free-text + JSON box, so every middleware should expose one even when it has no knobs. + + Pydantic's full ``cls.model_json_schema()`` would also include the + framework-injected ``registry`` and ``enqueue_parts`` attributes — those + aren't config the user sets, and their types (a registry protocol, a + callable) aren't always representable in JSON Schema. Build the schema from + a stripped pydantic model containing only the subclass-added fields so the + Dev UI sees just the knobs the author meant to expose. + """ + base_fields = set(BaseMiddleware.model_fields) + new_fields: dict[str, Any] = { + field_name: (info.annotation, info) + for field_name, info in cls.model_fields.items() + if field_name not in base_fields + } + if not new_fields: + # Empty object schema still tells the Dev UI "this middleware exists + # and has no knobs", which renders as a no-input form rather than a + # raw JSON editor. 
+ return { + 'type': 'object', + 'properties': {}, + 'additionalProperties': True, + } + try: + stripped = create_model( # type: ignore[call-overload] + f'{cls.__name__}Config', + __config__=ConfigDict(arbitrary_types_allowed=True), + **new_fields, + ) + return stripped.model_json_schema() + except Exception: + # If a config field carries a type pydantic can't translate, prefer a + # permissive empty schema over crashing the whole registration — + # the middleware itself still works, just without form generation. + return { + 'type': 'object', + 'properties': {}, + 'additionalProperties': True, + } + + +def new_middleware(middleware_cls: type[BaseMiddleware]) -> MiddlewareDesc: + """Create a ``MiddlewareDesc`` from a ``BaseMiddleware`` subclass. + + Set ``name``, and optionally ``description``, ``middleware_config_schema``, and + ``middleware_metadata`` on the class. The resulting factory instantiates the class + with ``**(config or {})`` when a request resolves the descriptor, so the same + pydantic fields on the class drive both the inline (``use=[Cls(...)]``) and the + registered (``MiddlewareRef(name=..., config=...)``) paths. + + When ``middleware_config_schema`` is not set explicitly, a JSON Schema is + derived from the subclass's pydantic fields so the Dev UI can render a + config form without the author having to hand-write one. + + Does not register; pass the result to ``middleware_plugin([...])`` or return from + a custom ``Plugin.list_middleware``. + + Args: + middleware_cls: A ``BaseMiddleware`` subclass with a non-empty ``name``. + + Returns: + A descriptor suitable for ``registry.register_value`` or ``middleware_plugin``. 
+ """ + reg_name = middleware_cls.name + if not reg_name: + raise ValueError(f'{middleware_cls.__qualname__}.name must be set for new_middleware(MyClass).') + _validate_middleware_key_segment(str(reg_name), label=f'{middleware_cls.__qualname__}.name') + + def _factory(config: dict[str, Any] | None) -> BaseMiddleware: + # Instantiate with the incoming config so registered use is equivalent to + # ``use=[middleware_cls(**config)]``; empty/None config uses class defaults. + return middleware_cls(**(config or {})) + + config_schema = middleware_cls.middleware_config_schema + if config_schema is None: + config_schema = _derive_config_schema(middleware_cls) + + return MiddlewareDesc( + name=reg_name, + factory=_factory, + description=middleware_cls.description, + config_schema=config_schema, + metadata=middleware_cls.middleware_metadata, + ) diff --git a/py/packages/genkit/src/genkit/_core/_model.py b/py/packages/genkit/src/genkit/_core/_model.py index 3418f08c93..394f268b86 100644 --- a/py/packages/genkit/src/genkit/_core/_model.py +++ b/py/packages/genkit/src/genkit/_core/_model.py @@ -22,7 +22,7 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import Callable, Sequence from copy import deepcopy from functools import cached_property from typing import Any, ClassVar, Generic, cast @@ -557,10 +557,3 @@ def count_parts(parts: list[Part]) -> tuple[int, int, int, int]: 'Role': Role, } ) - -# Type aliases for model middleware (Any is intentional - middleware is type-agnostic) -# Middleware can have two signatures: -# Simple (3 params): (req, ctx, next) -> response -# Streaming (4 params): (req, ctx, on_chunk, next) -> response -# The framework detects which signature is used based on parameter count. 
-ModelMiddleware = Callable[..., Awaitable[ModelResponse[Any]]] diff --git a/py/packages/genkit/src/genkit/_core/_plugin.py b/py/packages/genkit/src/genkit/_core/_plugin.py index 28afbce030..d104b868dc 100644 --- a/py/packages/genkit/src/genkit/_core/_plugin.py +++ b/py/packages/genkit/src/genkit/_core/_plugin.py @@ -14,11 +14,15 @@ # # SPDX-License-Identifier: Apache-2.0 -"""Abstract base class for Genkit plugins.""" +"""Abstract base class for Genkit plugins and middleware_plugin helper.""" + +from __future__ import annotations import abc +from collections.abc import Sequence from genkit._core._action import Action, ActionKind +from genkit._core._middleware import MiddlewareDesc, _validate_middleware_key_segment from genkit._core._typing import ActionMetadata @@ -47,6 +51,20 @@ async def list_actions(self) -> list[ActionMetadata]: """ ... + def list_middleware(self) -> list[MiddlewareDesc]: + """Return middleware descriptors for this plugin to register on the app. + + This runs while :class:`Genkit` is being constructed, after + built-in middleware is registered. Use unique flat names without + slash characters so they do not collide with built-ins or other + plugins. + + Returns: + Descriptors to list in the Dev UI and to resolve by name from + ``generate(use=...)``. 
+ """ + return [] + async def model(self, name: str) -> Action | None: """Resolve a model action by name (local or namespaced).""" target = name if '/' in name else f'{self.name}/{name}' @@ -56,3 +74,96 @@ async def embedder(self, name: str) -> Action | None: """Resolve an embedder action by name (local or namespaced).""" target = name if '/' in name else f'{self.name}/{name}' return await self.resolve(ActionKind.EMBEDDER, target) + + +class _MiddlewareDescsPlugin(Plugin): + """Plugin implementation that contributes only middleware descriptors.""" + + def __init__(self, plugin_name: str, descs: list[MiddlewareDesc]) -> None: + self.name = plugin_name + self._descs = descs + + async def init(self) -> list[Action]: + return [] + + async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + return None + + async def list_actions(self) -> list[ActionMetadata]: + return [] + + def list_middleware(self) -> list[MiddlewareDesc]: + return list(self._descs) + + +def _middleware_registry_name(namespace: str | None, desc_name: str) -> str: + """Registry key for a descriptor under an optional namespace prefix.""" + if not namespace: + return desc_name + return f'{namespace}/{desc_name}' + + +def middleware_plugin( + descs: Sequence[MiddlewareDesc], + *, + namespace: str | None = None, +) -> Plugin: + """Wrap a list of middleware descriptors as a single plugin (for ``plugins=[...]``). + + Pass all descriptors for this plugin in one list so one plugin can + register several middlewares together. + + Example: + Genkit(plugins=[ + middleware_plugin( + [ + new_middleware(PrefixPromptMiddleware), + new_middleware(OtherMiddleware), + ], + namespace='myapp', + ), + ]) + + Build each item with ``new_middleware`` from ``genkit.middleware`` or + the same API on your :class:`Genkit` instance — neither registers by + itself. Registration happens when this plugin is passed in + ``plugins=[...]``. + + Args: + descs: Non-empty sequence of middleware descriptors. 
+ namespace: Optional plugin namespace. + + * If set, it becomes the plugin name and each descriptor is + registered as ``{namespace}_{desc.name}`` (e.g. ``acme`` + + ``logging`` → ``acme_logging``). + * If omitted, the plugin name is ``extension-middleware`` and + registry keys stay the descriptors' own names. + + Same flat-segment rules as middleware descriptor names: no + ``/``, whitespace, ``:``, backslashes, or control characters. + + Returns: + A plugin whose ``list_middleware`` returns the descriptors (renamed + when ``namespace`` is set). + """ + built = list(descs) + if not built: + raise ValueError( + 'middleware_plugin() needs a non-empty list of MiddlewareDesc instances. ' + + 'Build each with new_middleware(...) from genkit.middleware or ai.new_middleware(...).' + ) + if not namespace: + ns = None + else: + ns = namespace.strip() or None + if ns is not None: + _validate_middleware_key_segment(ns, label='middleware_plugin namespace') + + if ns is None: + registered = built + else: + registered = [d.with_name(_middleware_registry_name(ns, d.name)) for d in built] + + plugin_name = ns if ns is not None else 'extension-middleware' + + return _MiddlewareDescsPlugin(plugin_name, registered) diff --git a/py/packages/genkit/src/genkit/_core/_protocols.py b/py/packages/genkit/src/genkit/_core/_protocols.py new file mode 100644 index 0000000000..bc9eb92103 --- /dev/null +++ b/py/packages/genkit/src/genkit/_core/_protocols.py @@ -0,0 +1,71 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Leaf module of structural interfaces (Protocols) for core Genkit types. + +Keeping interfaces here instead of in their implementation modules breaks +circular-import cycles. The pattern is: + + Module A needs B's type in annotations, but B depends on A → cycle. + Solution: extract B's interface here; A imports the interface, not B. + +Currently defined: + +- ``RegistryLike`` — structural Protocol covering the registry methods that + middleware and the generate engine actually call. Use instead + of the concrete ``Registry`` whenever a cycle would result. + The real ``Registry`` satisfies it structurally; no + ``register`` call or inheritance is needed. +""" + +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + +from genkit._core._action import Action, ActionKind + + +@runtime_checkable +class RegistryLike(Protocol): + """Structural interface for the subset of Registry used by middleware and the generate engine. + + Middleware plugins (e.g. ``Fallback``) and ``generate_action`` depend only + on this interface, not the full concrete ``Registry``. This avoids pulling + in ``_registry.py`` from modules that ``_registry.py`` itself depends on. + + The concrete ``Registry`` satisfies this protocol structurally — no + subclassing or registration is required. + """ + + def new_child(self) -> RegistryLike: + """Return a scoped child registry that delegates misses to this one.""" + ... + + def lookup_value(self, kind: str, name: str) -> Any: # noqa: ANN401 + """Look up a registered value by kind and name.""" + ... + + def register_value(self, kind: str, name: str, value: object) -> None: + """Register an arbitrary value under kind/name.""" + ... + + def register_action_from_instance(self, action: Action) -> None: + """Register a pre-built Action instance.""" + ... 
+ + async def resolve_action(self, kind: ActionKind, name: str) -> Action | None: + """Resolve an action by kind and name, initialising plugins as needed.""" + ... diff --git a/py/packages/genkit/src/genkit/_core/_reflection.py b/py/packages/genkit/src/genkit/_core/_reflection.py index eaf4d6edcf..94338b1036 100644 --- a/py/packages/genkit/src/genkit/_core/_reflection.py +++ b/py/packages/genkit/src/genkit/_core/_reflection.py @@ -41,6 +41,7 @@ from genkit._core._constants import GENKIT_VERSION from genkit._core._error import get_reflection_json from genkit._core._logger import get_logger +from genkit._core._middleware import MiddlewareDesc from genkit._core._registry import Registry logger = get_logger(__name__) @@ -176,10 +177,33 @@ def omit_none(payload: dict[str, Any]) -> dict[str, Any]: return JSONResponse(response, headers={'x-genkit-version': version}) - async def values(req: Request) -> JSONResponse: - if req.query_params.get('type') != 'defaultModel': - return JSONResponse({'error': 'Only type=defaultModel supported'}, status_code=400) - return JSONResponse(registry.list_values('defaultModel')) + async def values(req: Request) -> Response: + raw = req.query_params.get('type') + if not raw or not raw.strip(): + return JSONResponse( + {'error': 'Query parameter "type" is required.'}, + status_code=400, + headers={'x-genkit-version': version}, + ) + type_param = raw.strip() + try: + raw_values = registry.list_values(type_param) + if type_param == 'middleware': + serialized: dict[str, Any] = {} + for key, val in raw_values.items(): + if isinstance(val, MiddlewareDesc): + serialized[key] = val.model_dump(by_alias=True, exclude_none=True, mode='json') + else: + serialized[key] = val + raw_values = serialized + return JSONResponse(raw_values, headers={'x-genkit-version': version}) + except Exception: + logger.exception('Reflection /api/values failed') + return JSONResponse( + {'error': 'Failed to list values', 'detail': 'See Python process logs for the 
traceback.'}, + status_code=500, + headers={'x-genkit-version': version}, + ) async def envs(_: Request) -> JSONResponse: return JSONResponse(['dev']) diff --git a/py/packages/genkit/src/genkit/_core/_reflection_v2.py b/py/packages/genkit/src/genkit/_core/_reflection_v2.py index f6310c341f..17d722b7c5 100644 --- a/py/packages/genkit/src/genkit/_core/_reflection_v2.py +++ b/py/packages/genkit/src/genkit/_core/_reflection_v2.py @@ -380,6 +380,11 @@ async def _handle_list_values(self, req_id: str | int | None, params: dict[str, to_json_fn = getattr(value, 'to_json', None) if value is not None else None if callable(to_json_fn): mapped[name] = to_json_fn() + elif isinstance(value, BaseModel): + # Without this, ``json.dumps(default=str)`` in ``_send_message`` + # would fall back to ``str(value)`` and the dev-ui would receive + # pydantic's ``__repr__`` text instead of an object. + mapped[name] = value.model_dump(by_alias=True, exclude_none=True, mode='json') else: mapped[name] = value await self._send_response(sid, {'values': mapped}) diff --git a/py/packages/genkit/src/genkit/_core/_typing.py b/py/packages/genkit/src/genkit/_core/_typing.py index 9fc11dbed2..cf01105306 100644 --- a/py/packages/genkit/src/genkit/_core/_typing.py +++ b/py/packages/genkit/src/genkit/_core/_typing.py @@ -175,8 +175,8 @@ class GenkitError(GenkitModel): data: Data | None = None -class MiddlewareDesc(GenkitModel): - """Model for middlewaredesc data.""" +class MiddlewareDescData(GenkitModel): + """Model for middlewaredescdata data.""" model_config: ClassVar[ConfigDict] = ConfigDict(alias_generator=to_camel, extra='forbid', populate_by_name=True) name: str = Field(...) 
diff --git a/py/packages/genkit/src/genkit/middleware/__init__.py b/py/packages/genkit/src/genkit/middleware/__init__.py new file mode 100644 index 0000000000..8e2618dcdd --- /dev/null +++ b/py/packages/genkit/src/genkit/middleware/__init__.py @@ -0,0 +1,72 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Middleware for Genkit model calls. + +This module provides types and helpers to define custom middleware. + +Chain ordering: middleware is applied first-in, outermost. 
+ +Define a middleware class with ``@middleware`` and pass instances inline +via ``use=``: + + from genkit import Genkit + from genkit.middleware import BaseMiddleware, middleware + + @middleware(name='logging') + class LoggingMiddleware(BaseMiddleware): + async def wrap_generate(self, params, next_fn): + print('before') + result = await next_fn(params) + print('after') + return result + + ai = Genkit() + + response = await ai.generate( + model='your-model-here', + prompt='Hello', + use=[LoggingMiddleware()], + ) + +To make middleware available to the **Dev UI** and referenceable by name, +register it on the app via ``ai.define_middleware`` or declare it in a +plugin: + + ai.define_middleware(LoggingMiddleware) +""" + +from genkit._core._middleware import ( + BaseMiddleware, + GenerateHookParams, + MiddlewareDesc, + ModelHookParams, + MultipartToolResponse, + ToolHookParams, + middleware, +) +from genkit._core._plugin import middleware_plugin + +__all__ = [ + 'BaseMiddleware', + 'GenerateHookParams', + 'MiddlewareDesc', + 'ModelHookParams', + 'MultipartToolResponse', + 'ToolHookParams', + 'middleware', + 'middleware_plugin', +] diff --git a/py/packages/genkit/src/genkit/plugin_api/__init__.py b/py/packages/genkit/src/genkit/plugin_api/__init__.py index f5103f0438..94bf029de3 100644 --- a/py/packages/genkit/src/genkit/plugin_api/__init__.py +++ b/py/packages/genkit/src/genkit/plugin_api/__init__.py @@ -24,7 +24,8 @@ from genkit._core._error import GenkitError, StatusCodes, StatusName, get_callable_json from genkit._core._http_client import get_cached_client from genkit._core._loop_cache import _loop_local_client as loop_local_client -from genkit._core._plugin import Plugin +from genkit._core._middleware import new_middleware +from genkit._core._plugin import Plugin, middleware_plugin from genkit._core._schema import to_json_schema from genkit._core._trace._adjusting_exporter import AdjustingTraceExporter, RedactedSpan from genkit._core._trace._path import 
to_display_path @@ -55,6 +56,8 @@ __all__ = [ # Base class and framework primitives 'Plugin', + 'new_middleware', + 'middleware_plugin', 'Action', 'ActionMetadata', 'ActionKind', diff --git a/py/packages/genkit/tests/genkit/ai/ai_plugin_test.py b/py/packages/genkit/tests/genkit/ai/ai_plugin_test.py index f6a18fa246..ff5c69cfce 100644 --- a/py/packages/genkit/tests/genkit/ai/ai_plugin_test.py +++ b/py/packages/genkit/tests/genkit/ai/ai_plugin_test.py @@ -27,6 +27,8 @@ from genkit._core._model import ModelRequest from genkit._core._registry import ActionKind from genkit._core._typing import ActionMetadata, FinishReason +from genkit.middleware import BaseMiddleware, MiddlewareDesc, middleware +from genkit.plugin_api import new_middleware class AsyncResolveOnlyPlugin(Plugin): @@ -107,6 +109,40 @@ async def list_actions(self) -> list[ActionMetadata]: ] +@middleware(name='ai_plugin_test_mw') +class _RegistryMw(BaseMiddleware): + pass + + +class MiddlewareListingPlugin(Plugin): + """Plugin that contributes middleware via list_middleware.""" + + name = 'mw-list-plugin' + + async def init(self) -> list[Action]: + return [] + + async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + return None + + async def list_actions(self) -> list[ActionMetadata]: + return [] + + def list_middleware(self) -> list[MiddlewareDesc]: + return [new_middleware(_RegistryMw)] + + +@pytest.mark.asyncio +async def test_plugin_list_middleware_registers_on_registry() -> None: + """Descriptors from Plugin.list_middleware appear under list_values('middleware').""" + ai = Genkit(plugins=[MiddlewareListingPlugin()]) + names = ai.registry.list_values('middleware') + assert 'ai_plugin_test_mw' in names + desc = ai.registry.lookup_value('middleware', 'ai_plugin_test_mw') + assert desc is not None + assert isinstance(desc, MiddlewareDesc) + + @pytest.mark.asyncio async def test_async_resolve_is_awaited_via_generate() -> None: """Test that async resolve is awaited when calling 
generate.""" diff --git a/py/packages/genkit/tests/genkit/ai/generate_test.py b/py/packages/genkit/tests/genkit/ai/generate_test.py index 97e999c6e9..61c38d7631 100644 --- a/py/packages/genkit/tests/genkit/ai/generate_test.py +++ b/py/packages/genkit/tests/genkit/ai/generate_test.py @@ -12,10 +12,10 @@ import pytest import yaml -from pydantic import BaseModel, TypeAdapter +from pydantic import BaseModel, Field, TypeAdapter -from genkit import ActionKind, Document, Genkit, Message, ModelResponse, ModelResponseChunk -from genkit._ai._generate import generate_action +from genkit import ActionKind, Document, Genkit, Message, MiddlewareRef, ModelResponse, ModelResponseChunk +from genkit._ai._generate import _augment_with_context, generate_action from genkit._ai._model import text_from_content, text_from_message from genkit._ai._testing import ( ProgrammableModel, @@ -23,18 +23,29 @@ define_programmable_model, ) from genkit._ai._tools import Interrupt, define_tool -from genkit._core._action import ActionRunContext from genkit._core._model import GenerateActionOptions, ModelRequest from genkit._core._registry import Registry from genkit._core._typing import ( DocumentPart, FinishReason, Part, + Resume, Role, TextPart, ToolRequest, ToolRequestPart, ) +from genkit.middleware import ( + BaseMiddleware, + GenerateHookParams, + MiddlewareDesc, + ModelHookParams, + MultipartToolResponse, + ToolHookParams, + middleware, + middleware_plugin, +) +from genkit.plugin_api import new_middleware def _to_dict(obj: object) -> object: @@ -147,35 +158,190 @@ async def test_simulates_doc_grounding( ) -@pytest.mark.asyncio -async def test_generate_applies_middleware( - setup_test: tuple[Genkit, ProgrammableModel], -) -> None: - """When middleware is provided, apply it.""" - ai, *_ = setup_test - define_echo_model(ai) +# --------------------------------------------------------------------------- # +# Unit tests for the private _augment_with_context helper # +# 
--------------------------------------------------------------------------- # - async def pre_middle( - req: ModelRequest, - ctx: ActionRunContext, - next: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - txt = ''.join(text_from_message(m) for m in req.messages) - return await next( - ModelRequest( - messages=[ - Message(role=Role.USER, content=[Part(TextPart(text=f'PRE {txt}'))]), + +def test_augment_with_context_ignores_no_docs() -> None: + """No docs -> request returned unchanged (same object identity).""" + req = ModelRequest( + messages=[ + Message(role=Role.USER, content=[Part(root=TextPart(text='hi'))]), + ], + ) + + transformed_req = _augment_with_context(req) + + assert transformed_req is req + + +def test_augment_with_context_adds_docs_as_context() -> None: + """Docs are injected as a context-purpose part appended to the last user message.""" + req = ModelRequest( + messages=[ + Message(role=Role.USER, content=[Part(root=TextPart(text='hi'))]), + ], + docs=[ + Document(content=[DocumentPart(root=TextPart(text='doc content 1'))]), + Document(content=[DocumentPart(root=TextPart(text='doc content 2'))]), + ], + ) + + transformed_req = _augment_with_context(req) + + assert transformed_req == ModelRequest( + messages=[ + Message( + role=Role.USER, + content=[ + Part(root=TextPart(text='hi')), + Part( + root=TextPart( + text='\n\nUse the following information to complete ' + + 'your task:\n\n' + + '- [0]: doc content 1\n' + + '- [1]: doc content 2\n\n', + metadata={'purpose': 'context'}, + ) + ), + ], + ) + ], + docs=[ + Document(content=[DocumentPart(root=TextPart(text='doc content 1'))]), + Document(content=[DocumentPart(root=TextPart(text='doc content 2'))]), + ], + ) + + +def test_augment_with_context_does_not_mutate_input() -> None: + """Input request and its messages are not mutated; helper returns a deepcopy.""" + original_user_msg = Message(role=Role.USER, content=[Part(root=TextPart(text='hi'))]) + req = ModelRequest( + 
messages=[original_user_msg], + docs=[Document(content=[DocumentPart(root=TextPart(text='doc content 1'))])], + ) + original_content_len = len(original_user_msg.content) + + transformed_req = _augment_with_context(req) + + assert transformed_req is not req + assert transformed_req.messages[0] is not original_user_msg + assert len(original_user_msg.content) == original_content_len + assert len(transformed_req.messages[0].content) == original_content_len + 1 + + +def test_augment_with_context_skips_when_context_already_rendered() -> None: + """Already-rendered context (purpose=context, no pending flag) is left untouched. + + If a message already contains a context part that was previously rendered + (non-pending), _augment_with_context should return the original request + unchanged rather than injecting the docs again. + """ + req = ModelRequest( + messages=[ + Message( + role=Role.USER, + content=[ + Part( + root=TextPart( + text='this is already context', + metadata={'purpose': 'context'}, + ) + ), + Part(root=TextPart(text='hi')), ], ), - ctx, + ], + docs=[ + Document(content=[DocumentPart(root=TextPart(text='doc content 1'))]), + ], + ) + + transformed_req = _augment_with_context(req) + + assert transformed_req is req + + +def test_augment_with_context_with_purpose_part() -> None: + """A pending context placeholder is replaced in-place with the rendered docs. + + Prompts can include a Part with metadata={'purpose': 'context', 'pending': True} + as a placeholder. _augment_with_context locates it and swaps it out for the + actual rendered document context, preserving the surrounding parts. 
+ """ + req = ModelRequest( + messages=[ + Message( + role=Role.USER, + content=[ + Part( + root=TextPart( + text='insert context here', + metadata={'purpose': 'context', 'pending': True}, + ) + ), + Part(root=TextPart(text='hi')), + ], + ), + ], + docs=[ + Document(content=[DocumentPart(root=TextPart(text='doc content 1'))]), + ], + ) + + transformed_req = _augment_with_context(req) + + assert transformed_req == ModelRequest( + messages=[ + Message( + role=Role.USER, + content=[ + Part( + root=TextPart( + text='\n\nUse the following information to complete ' + + 'your task:\n\n' + + '- [0]: doc content 1\n\n', + metadata={'purpose': 'context'}, + ) + ), + Part(root=TextPart(text='hi')), + ], + ) + ], + docs=[ + Document(content=[DocumentPart(root=TextPart(text='doc content 1'))]), + ], + ) + + +# --------------------------------------------------------------------------- # +# Middleware class definitions shared by tests below # +# --------------------------------------------------------------------------- # + + +@middleware(name='pre_mw') +class PreMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + txt = ''.join(text_from_message(m) for m in params.request.messages) + return await next_fn( + ModelHookParams( + request=ModelRequest( + messages=[ + Message(role=Role.USER, content=[Part(TextPart(text=f'PRE {txt}'))]), + ], + ), + on_chunk=params.on_chunk, + context=params.context, + ) ) - async def post_middle( - req: ModelRequest, - ctx: ActionRunContext, - next: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - resp: ModelResponse = await next(req, ctx) + +@middleware(name='post_mw') +class PostMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + resp: ModelResponse = await next_fn(params) assert resp.message is not None txt = text_from_message(resp.message) return ModelResponse( @@ -183,6 +349,199 @@ async def 
post_middle( message=Message(role=Role.USER, content=[Part(TextPart(text=f'{txt} POST'))]), ) + +@pytest.mark.asyncio +async def test_generate_accepts_inline_base_middleware_instance() -> None: + """Inline ``BaseMiddleware`` instances in ``use=`` run without registration.""" + ai = Genkit() + define_echo_model(ai) + + response = await ai.generate( + model='echoModel', + prompt='hi', + use=[PreMiddleware(), PostMiddleware()], + ) + + assert response.text == '[ECHO] user: "PRE hi" POST' + + +@pytest.mark.asyncio +async def test_generate_interleaves_inline_instances_and_middleware_refs() -> None: + """Inline instances and ``MiddlewareRef`` entries preserve ``use=`` ordering together.""" + ai = Genkit(plugins=[middleware_plugin([new_middleware(PostMiddleware)])]) + define_echo_model(ai) + + response = await ai.generate( + model='echoModel', + prompt='hi', + use=[PreMiddleware(), MiddlewareRef(name='post_mw')], + ) + + assert response.text == '[ECHO] user: "PRE hi" POST' + + +@middleware(name='configured_prefix_mw') +class ConfiguredPrefixMiddleware(BaseMiddleware): + """Inline middleware driven purely by a pydantic config field.""" + + prefix: str = 'DEFAULT' + + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + txt = ''.join(text_from_message(m) for m in params.request.messages) + return await next_fn( + ModelHookParams( + request=ModelRequest( + messages=[ + Message(role=Role.USER, content=[Part(TextPart(text=f'{self.prefix} {txt}'))]), + ], + ), + on_chunk=params.on_chunk, + context=params.context, + ) + ) + + +@pytest.mark.asyncio +async def test_generate_inline_instance_uses_pydantic_fields() -> None: + """Config fields passed at construction time drive inline behavior.""" + ai = Genkit() + define_echo_model(ai) + + response = await ai.generate( + model='echoModel', + prompt='hi', + use=[ConfiguredPrefixMiddleware(prefix='[TRACE]')], + ) + + assert response.text == '[ECHO] user: "[TRACE] hi"' + + +@pytest.mark.asyncio 
+async def test_generate_middleware_ref_config_instantiates_class() -> None: + """``MiddlewareRef(config=...)`` feeds ``**config`` into the class constructor.""" + ai = Genkit(plugins=[middleware_plugin([new_middleware(ConfiguredPrefixMiddleware)])]) + define_echo_model(ai) + + response = await ai.generate( + model='echoModel', + prompt='hi', + use=[MiddlewareRef(name='configured_prefix_mw', config={'prefix': '[SPAN]'})], + ) + + assert response.text == '[ECHO] user: "[SPAN] hi"' + + +@pytest.mark.asyncio +async def test_define_middleware_registers_on_the_fly() -> None: + """``ai.define_middleware(cls)`` makes the definition resolvable by name.""" + ai = Genkit() + define_echo_model(ai) + ai.define_middleware(ConfiguredPrefixMiddleware) + + response = await ai.generate( + model='echoModel', + prompt='hi', + use=[MiddlewareRef(name='configured_prefix_mw', config={'prefix': '[LIVE]'})], + ) + + assert response.text == '[ECHO] user: "[LIVE] hi"' + + +@pytest.mark.asyncio +async def test_prompt_call_runs_middleware_declared_on_prompt() -> None: + """``ai.define_prompt(use=[...])`` actually runs those middleware on call.""" + ai = Genkit() + define_echo_model(ai) + + my_prompt = ai.define_prompt( + model='echoModel', + prompt='hi', + use=[PreMiddleware(), PostMiddleware()], + ) + + response = await my_prompt() + + assert response.text == '[ECHO] user: "PRE hi" POST' + + +@pytest.mark.asyncio +async def test_prompt_call_runs_per_call_middleware() -> None: + """``my_prompt(use=[...])`` per-call middleware run too.""" + ai = Genkit() + define_echo_model(ai) + + my_prompt = ai.define_prompt(model='echoModel', prompt='hi') + + response = await my_prompt(use=[PreMiddleware(), PostMiddleware()]) + + assert response.text == '[ECHO] user: "PRE hi" POST' + + +@pytest.mark.asyncio +async def test_prompt_call_use_interleaves_inline_and_refs() -> None: + """Prompts mix inline ``BaseMiddleware`` and ``MiddlewareRef`` like ``generate``.""" + ai = 
Genkit(plugins=[middleware_plugin([new_middleware(PostMiddleware)])]) + define_echo_model(ai) + + my_prompt = ai.define_prompt( + model='echoModel', + prompt='hi', + use=[PreMiddleware(), MiddlewareRef(name='post_mw')], + ) + + response = await my_prompt() + + assert response.text == '[ECHO] user: "PRE hi" POST' + + +@pytest.mark.asyncio +async def test_prompt_per_call_use_overrides_prompt_use() -> None: + """Per-call ``use=`` replaces the prompt's declared ``use``, matching ``opts.tools`` semantics.""" + ai = Genkit() + define_echo_model(ai) + + my_prompt = ai.define_prompt( + model='echoModel', + prompt='hi', + use=[PreMiddleware()], + ) + + response = await my_prompt(use=[PostMiddleware()]) + + assert response.text == '[ECHO] user: "hi" POST' + + +@pytest.mark.asyncio +async def test_prompt_stream_runs_middleware() -> None: + """``.stream()`` shares the middleware path with ``__call__``.""" + ai = Genkit() + define_echo_model(ai) + + my_prompt = ai.define_prompt( + model='echoModel', + prompt='hi', + use=[PreMiddleware(), PostMiddleware()], + ) + + streamed = my_prompt.stream() + response = await streamed.response + + assert response.text == '[ECHO] user: "PRE hi" POST' + + +@pytest.mark.asyncio +async def test_generate_applies_middleware() -> None: + """When middleware is provided, apply it via MiddlewareRef resolution.""" + ai = Genkit( + plugins=[ + middleware_plugin([ + new_middleware(PreMiddleware), + new_middleware(PostMiddleware), + ]) + ], + ) + define_echo_model(ai) + response = await generate_action( ai.registry, GenerateActionOptions( @@ -193,34 +552,19 @@ async def post_middle( content=[Part(TextPart(text='hi'))], ), ], + use=[MiddlewareRef(name='pre_mw'), MiddlewareRef(name='post_mw')], ), - middleware=[pre_middle, post_middle], ) assert response.text == '[ECHO] user: "PRE hi" POST' @pytest.mark.asyncio -async def test_generate_middleware_next_fn_args_optional( - setup_test: tuple[Genkit, ProgrammableModel], -) -> None: - """Can call next function 
without args (convenience).""" - ai, *_ = setup_test +async def test_generate_middleware_next_fn_args_optional() -> None: + """Can call next function without modifying params (pass params through).""" + ai = Genkit(plugins=[middleware_plugin([new_middleware(PostMiddleware)])]) define_echo_model(ai) - async def post_middle( - req: ModelRequest, - ctx: ActionRunContext, - next: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - resp: ModelResponse = await next(req, ctx) - assert resp.message is not None - txt = text_from_message(resp.message) - return ModelResponse( - finish_reason=resp.finish_reason, - message=Message(role=Role.USER, content=[Part(TextPart(text=f'{txt} POST'))]), - ) - response = await generate_action( ai.registry, GenerateActionOptions( @@ -231,46 +575,58 @@ async def post_middle( content=[Part(TextPart(text='hi'))], ), ], + use=[MiddlewareRef(name='post_mw')], ), - middleware=[post_middle], ) assert response.text == '[ECHO] user: "hi" POST' -@pytest.mark.asyncio -async def test_generate_middleware_can_modify_context( - setup_test: tuple[Genkit, ProgrammableModel], -) -> None: - """Test that middleware can modify context.""" - ai, *_ = setup_test - define_echo_model(ai) +@middleware(name='add_ctx') +class AddContextMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + return await next_fn( + ModelHookParams( + request=params.request, + on_chunk=params.on_chunk, + context={**params.context, 'banana': True}, + ) + ) - async def add_context( - req: ModelRequest, - ctx: ActionRunContext, - next: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - return await next(req, ActionRunContext(context={**ctx.context, 'banana': True})) - async def inject_context( - req: ModelRequest, - ctx: ActionRunContext, - next: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - txt = ''.join(text_from_message(m) for m in req.messages) - return await next( - 
ModelRequest( - messages=[ - Message( - role=Role.USER, - content=[Part(TextPart(text=f'{txt} {ctx.context}'))], - ), - ], - ), - ctx, +@middleware(name='inject_ctx') +class InjectContextMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + txt = ''.join(text_from_message(m) for m in params.request.messages) + return await next_fn( + ModelHookParams( + request=ModelRequest( + messages=[ + Message( + role=Role.USER, + content=[Part(TextPart(text=f'{txt} {params.context}'))], + ), + ], + ), + on_chunk=params.on_chunk, + context=params.context, + ) ) + +@pytest.mark.asyncio +async def test_generate_middleware_can_modify_context() -> None: + """Test that middleware can modify context via ModelHookParams.context.""" + ai = Genkit( + plugins=[ + middleware_plugin([ + new_middleware(AddContextMiddleware), + new_middleware(InjectContextMiddleware), + ]) + ], + ) + define_echo_model(ai) + response = await generate_action( ai.registry, GenerateActionOptions( @@ -281,8 +637,8 @@ async def inject_context( content=[Part(TextPart(text='hi'))], ), ], + use=[MiddlewareRef(name='add_ctx'), MiddlewareRef(name='inject_ctx')], ), - middleware=[add_context, inject_context], context={'foo': 'bar'}, ) @@ -290,11 +646,46 @@ async def inject_context( @pytest.mark.asyncio -async def test_generate_middleware_can_modify_stream( - setup_test: tuple[Genkit, ProgrammableModel], -) -> None: - """Test that middleware can modify streams.""" - ai, pm = setup_test +async def test_generate_middleware_can_modify_stream() -> None: + """Test that middleware can intercept and modify streaming chunks.""" + + @middleware(name='mod_stream_mw') + class ModifyStreamMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + if params.on_chunk: + params.on_chunk( + ModelResponseChunk( + role=Role.MODEL, + content=[Part(TextPart(text='something extra before'))], + ) + ) + + def 
chunk_handler(chunk: ModelResponseChunk) -> None: + if params.on_chunk: + params.on_chunk( + ModelResponseChunk( + role=Role.MODEL, + content=[Part(TextPart(text=f'intercepted: {text_from_content(chunk.content)}'))], + ) + ) + + new_params = ModelHookParams( + request=params.request, + on_chunk=chunk_handler, + context=params.context, + ) + resp = await next_fn(new_params) + if params.on_chunk: + params.on_chunk( + ModelResponseChunk( + role=Role.MODEL, + content=[Part(TextPart(text='something extra after'))], + ) + ) + return resp + + ai = Genkit(plugins=[middleware_plugin([new_middleware(ModifyStreamMiddleware)])]) + pm, _ = define_programmable_model(ai) pm.responses.append( ModelResponse( @@ -310,40 +701,6 @@ async def test_generate_middleware_can_modify_stream( ] ] - async def modify_stream( - req: ModelRequest, - ctx: ActionRunContext, - on_chunk: Callable[[ModelResponseChunk], None] | None, - next: Callable[..., Awaitable[ModelResponse]], - ) -> ModelResponse: - # 4-param streaming middleware signature - if on_chunk: - on_chunk( - ModelResponseChunk( - role=Role.MODEL, - content=[Part(TextPart(text='something extra before'))], - ) - ) - - def chunk_handler(chunk: ModelResponseChunk) -> None: - if on_chunk: - on_chunk( - ModelResponseChunk( - role=Role.MODEL, - content=[Part(TextPart(text=f'intercepted: {text_from_content(chunk.content)}'))], - ) - ) - - resp = await next(req, ctx, chunk_handler) - if on_chunk: - on_chunk( - ModelResponseChunk( - role=Role.MODEL, - content=[Part(TextPart(text='something extra after'))], - ) - ) - return resp - got_chunks = [] def collect_chunks(c: ModelResponseChunk) -> None: @@ -359,8 +716,8 @@ def collect_chunks(c: ModelResponseChunk) -> None: content=[Part(TextPart(text='hi'))], ), ], + use=[MiddlewareRef(name='mod_stream_mw')], ), - middleware=[modify_stream], on_chunk=collect_chunks, ) @@ -374,6 +731,502 @@ def collect_chunks(c: ModelResponseChunk) -> None: ] +class TrackGenerateMiddleware(BaseMiddleware): + """Middleware 
that records wrap_generate calls per turn.""" + + iterations: list[int] = Field(default_factory=list) + + async def wrap_generate( + self, + params: GenerateHookParams, + next_fn: Callable[[GenerateHookParams], Awaitable[ModelResponse]], + ) -> ModelResponse: + self.iterations.append(params.iteration) + return await next_fn(params) + + +@pytest.mark.asyncio +async def test_wrap_generate_called_per_turn() -> None: + """wrap_generate is invoked for each turn of the generate loop. + + This is the two-turn regression test: verifies middleware runs on *every* + recursive _generate_action call (turn 0 + turn 1 after tool response). + """ + track_mw = TrackGenerateMiddleware() + track_mw2 = TrackGenerateMiddleware() + ai = Genkit( + plugins=[ + middleware_plugin([ + MiddlewareDesc( + name='track_gen', + description='track generate', + factory=lambda _opts: track_mw, + ), + MiddlewareDesc( + name='track_gen2', + description='track generate 2', + factory=lambda _opts: track_mw2, + ), + ]) + ], + ) + pm, _ = define_programmable_model(ai) + + @ai.tool(name='testTool') + async def _test_tool() -> object: + return 'tool called' + + # No tools: single turn → wrap_generate called once with iteration=0 + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message(role=Role.MODEL, content=[Part(TextPart(text='done'))]), + ) + ) + response = await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[Message(role=Role.USER, content=[Part(TextPart(text='hi'))])], + use=[MiddlewareRef(name='track_gen')], + ), + ) + assert response.text == 'done' + assert track_mw.iterations == [0] + + # With tools: two turns (model→tool→model) → wrap_generate called for each + pm.responses.append( + ModelResponse( + message=Message( + role=Role.MODEL, + content=[Part(root=ToolRequestPart(tool_request=ToolRequest(name='testTool', input={}, ref='r1')))], + ), + ) + ) + pm.responses.append( + ModelResponse( + 
finish_reason=FinishReason.STOP, + message=Message(role=Role.MODEL, content=[Part(TextPart(text='final'))]), + ) + ) + response2 = await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[Message(role=Role.USER, content=[Part(TextPart(text='hi'))])], + tools=['testTool'], + use=[MiddlewareRef(name='track_gen2')], + ), + ) + assert response2.text == 'final' + assert track_mw2.iterations == [0, 1] + + +class TrackToolMiddleware(BaseMiddleware): + """Middleware that records wrap_tool calls.""" + + tool_names: list[str] = Field(default_factory=list) + + async def wrap_tool( + self, + params: ToolHookParams, + next_fn: Callable[[ToolHookParams], Awaitable[MultipartToolResponse]], + ) -> MultipartToolResponse: + self.tool_names.append(params.tool_request_part.tool_request.name) + return await next_fn(params) + + +@pytest.mark.asyncio +async def test_wrap_tool_called_on_tool_execution() -> None: + """wrap_tool is invoked for each tool execution.""" + track_mw = TrackToolMiddleware() + ai = Genkit( + plugins=[ + middleware_plugin([ + MiddlewareDesc( + name='track_tool', + description='track tool', + factory=lambda _opts: track_mw, + ), + ]) + ], + ) + pm, _ = define_programmable_model(ai) + + @ai.tool(name='myTool') + async def my_tool() -> object: + return 'result' + + pm.responses.append( + ModelResponse( + message=Message( + role=Role.MODEL, + content=[Part(root=ToolRequestPart(tool_request=ToolRequest(name='myTool', input={}, ref='r1')))], + ), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message(role=Role.MODEL, content=[Part(TextPart(text='done'))]), + ) + ) + + response = await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[Message(role=Role.USER, content=[Part(TextPart(text='hi'))])], + tools=['myTool'], + use=[MiddlewareRef(name='track_tool')], + ), + ) + assert response.text == 'done' + assert track_mw.tool_names == 
['myTool'] + + +@pytest.mark.asyncio +async def test_middleware_wrap_tool_interrupt_handled_as_interrupt_not_crash() -> None: + """Interrupt raised by wrap_tool middleware is converted to an interrupt part. + + This is a regression test: before the fix, a middleware-raised Interrupt + bypassed _resolve_tool_request's except block and propagated uncaught through + asyncio.gather, crashing generation instead of surfacing as a tool interrupt. + """ + from genkit._ai._tools import Interrupt + + @middleware(name='interrupt_all') + class InterruptingMiddleware(BaseMiddleware): + async def wrap_tool( + self, + params: ToolHookParams, + next_fn: Callable[[ToolHookParams], Awaitable[MultipartToolResponse]], + ) -> MultipartToolResponse: + raise Interrupt({'blocked': True}) + + ai = Genkit( + plugins=[ + middleware_plugin([ + MiddlewareDesc( + name='interrupt_all', + description='interrupt all tools', + factory=lambda _opts: InterruptingMiddleware(), + ), + ]) + ], + ) + pm, _ = define_programmable_model(ai) + + @ai.tool(name='blockedTool') + async def blocked_tool() -> str: + return 'should not run' + + pm.responses.append( + ModelResponse( + message=Message( + role=Role.MODEL, + content=[Part(root=ToolRequestPart(tool_request=ToolRequest(name='blockedTool', input={}, ref='r1')))], + ), + ) + ) + + response = await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[Message(role=Role.USER, content=[Part(TextPart(text='do it'))])], + tools=['blockedTool'], + use=[MiddlewareRef(name='interrupt_all')], + ), + ) + assert response.finish_reason == FinishReason.INTERRUPTED + assert response.message is not None + interrupt_parts = [ + p + for p in response.message.content + if isinstance(p.root, ToolRequestPart) and p.root.metadata and 'interrupt' in p.root.metadata + ] + assert len(interrupt_parts) == 1 + assert interrupt_parts[0].root.metadata is not None + assert interrupt_parts[0].root.metadata['interrupt'] == {'blocked': True} + + 
+@pytest.mark.asyncio +async def test_middleware_contributed_tools_available_to_model() -> None: + """Middleware.tools() contributes actions scoped to the generate call (child registry). + + The contributed tool is resolvable by the model during the call but must not + appear in the root registry afterward — mirroring Go's Hooks.Tools + NewChild. + """ + + @middleware(name='tool_provider_mw') + class ToolProviderMiddleware(BaseMiddleware): + """Middleware that contributes a tool dynamically per generate() call.""" + + def tools(self) -> list: + # Build a tool action on a throw-away registry; the generate engine + # will adopt it into a call-scoped child registry. + scratch = Registry() + + async def provided_tool() -> str: + """A tool injected by middleware.""" + return 'from_middleware_tool' + + t = define_tool(scratch, provided_tool, name='middleware_tool') + return [t.action()] + + ai = Genkit(plugins=[middleware_plugin([new_middleware(ToolProviderMiddleware)])]) + pm, _ = define_programmable_model(ai) + + # Turn 1: model calls the middleware-contributed tool + pm.responses.append( + ModelResponse( + message=Message( + role=Role.MODEL, + content=[ + Part(root=ToolRequestPart(tool_request=ToolRequest(name='middleware_tool', input={}, ref='r1'))) + ], + ), + ) + ) + # Turn 2: model returns final answer after tool result + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message(role=Role.MODEL, content=[Part(TextPart(text='done'))]), + ) + ) + + response = await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[Message(role=Role.USER, content=[Part(TextPart(text='hi'))])], + use=[MiddlewareRef(name='tool_provider_mw')], + ), + ) + assert response.text == 'done' + + # The contributed tool must NOT be visible in the root registry after the call. 
+ assert await ai.registry.resolve_action(ActionKind.TOOL, 'middleware_tool') is None + + +@pytest.mark.asyncio +async def test_middleware_in_one_call_share_an_isolated_registry() -> None: + """Middleware in the same generate() call share an isolated registry. + + This verifies: + + - **Cooperation:** Middleware A contributes a tool via ``tools()`` and + middleware B resolves it through ``self.registry`` in the same call + (proves both middleware see the same per-call child registry, so they + can pass tools and other actions to one another). + - **Isolation:** Anything middleware writes via ``self.registry`` does NOT + survive the call (proves writes are auto-cleaned and cannot leak into the + root registry or across concurrent generate() calls). + """ + seen_by_b: list[str] = [] + + @middleware(name='provider_mw') + class ProviderMW(BaseMiddleware): + def tools(self) -> list: + scratch = Registry() + + async def shared_tool() -> str: + """Shared by all middleware in the call.""" + return 'shared_ok' + + return [define_tool(scratch, shared_tool, name='shared_tool').action()] + + @middleware(name='looker_mw') + class LookerMW(BaseMiddleware): + async def wrap_generate( + self, + params: GenerateHookParams, + next_fn: Callable[[GenerateHookParams], Awaitable[ModelResponse]], + ) -> ModelResponse: + # Resolve the tool ProviderMW just contributed — only works if + # both middleware share the same per-call registry scope. + tool = await self.registry.resolve_action(ActionKind.TOOL, 'shared_tool') + if tool is not None: + seen_by_b.append(tool.name) + # Also exercise the write path: anything we register through + # self.registry must not survive the call. 
+ scratch = Registry() + + async def leaky_tool() -> str: + """Should not survive the call.""" + return 'nope' + + leak = define_tool(scratch, leaky_tool, name='leaky_tool').action() + self.registry.register_action_from_instance(leak) + return await next_fn(params) + + ai = Genkit( + plugins=[ + middleware_plugin([ + new_middleware(ProviderMW), + new_middleware(LookerMW), + ]) + ], + ) + pm, _ = define_programmable_model(ai) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message(role=Role.MODEL, content=[Part(TextPart(text='ok'))]), + ) + ) + + response = await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[Message(role=Role.USER, content=[Part(TextPart(text='hi'))])], + use=[ + MiddlewareRef(name='provider_mw'), + MiddlewareRef(name='looker_mw'), + ], + ), + ) + assert response.text == 'ok' + assert seen_by_b == ['shared_tool'], f'looker middleware should have resolved shared_tool, saw: {seen_by_b}' + # Neither tool may leak into the root registry after the call ends. + assert await ai.registry.resolve_action(ActionKind.TOOL, 'shared_tool') is None + assert await ai.registry.resolve_action(ActionKind.TOOL, 'leaky_tool') is None + + +@pytest.mark.asyncio +async def test_queue_drain_streams_each_message_at_one_index() -> None: + """Queued tool middleware messages stream as exactly one chunk per message. + + Regression: the old queue-drain path called ``make_chunk(USER, ...)`` for + each queued message AND then did ``message_index += 1``. ``make_chunk`` + *also* advanced the index (role flip from MODEL to USER), so each queued + message bumped the counter twice — leaving a hole in the stream sequence. + The fix emits queued chunks directly and increments once per message. 
+ """ + + @middleware(name='enqueuing_mw') + class EnqueuingMW(BaseMiddleware): + """After each tool call, enqueue an extra USER part for the next turn.""" + + async def wrap_tool( + self, + params: ToolHookParams, + next_fn: Callable[[ToolHookParams], Awaitable[MultipartToolResponse]], + ) -> MultipartToolResponse: + result = await next_fn(params) + self.enqueue_parts([Part(TextPart(text='extra-context'))]) + return result + + ai = Genkit(plugins=[middleware_plugin([new_middleware(EnqueuingMW)])]) + pm, _ = define_programmable_model(ai) + + @ai.tool(name='trigger') + async def trigger() -> str: + return 'triggered' + + pm.responses.append( + ModelResponse( + message=Message( + role=Role.MODEL, + content=[Part(root=ToolRequestPart(tool_request=ToolRequest(name='trigger', input={}, ref='r1')))], + ), + ) + ) + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message(role=Role.MODEL, content=[Part(TextPart(text='final'))]), + ) + ) + + streamed: list[ModelResponseChunk] = [] + response = await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[Message(role=Role.USER, content=[Part(TextPart(text='go'))])], + tools=['trigger'], + use=[MiddlewareRef(name='enqueuing_mw')], + ), + on_chunk=streamed.append, + ) + assert response.text == 'final' + + user_chunks = [c for c in streamed if c.role == Role.USER] + assert len(user_chunks) == 1, ( + f'expected exactly one streamed user chunk for the queued message, saw ' + f'{[(c.role, c.index) for c in user_chunks]}' + ) + indices = [c.index or 0 for c in streamed] + assert indices == sorted(indices), f'indices not monotonic: {indices}' + + +@pytest.mark.asyncio +async def test_restart_path_routes_through_wrap_tool_middleware() -> None: + """Restarting a tool via ``resume_restart`` must invoke ``wrap_tool`` middleware. + + Regression: ``_resolve_resumed_tool_request`` used to call + ``run_tool_after_restart`` directly, skipping the middleware chain. 
That + silently bypassed ToolApproval / Filesystem / etc. on every restart. + """ + invocations: list[str] = [] + + @middleware(name='recording_mw') + class RecordingMW(BaseMiddleware): + async def wrap_tool( + self, + params: ToolHookParams, + next_fn: Callable[[ToolHookParams], Awaitable[MultipartToolResponse]], + ) -> MultipartToolResponse: + invocations.append(params.tool.name) + return await next_fn(params) + + ai = Genkit(plugins=[middleware_plugin([new_middleware(RecordingMW)])]) + pm, _ = define_programmable_model(ai) + + @ai.tool(name='approveMe') + async def approve_me() -> str: + return 'approved' + + pm.responses.append( + ModelResponse( + finish_reason=FinishReason.STOP, + message=Message(role=Role.MODEL, content=[Part(TextPart(text='final'))]), + ) + ) + + interrupt_part = ToolRequestPart( + tool_request=ToolRequest(name='approveMe', input={}, ref='r1'), + metadata={'interrupt': True}, + ) + + response = await generate_action( + ai.registry, + GenerateActionOptions( + model='programmableModel', + messages=[ + Message(role=Role.USER, content=[Part(TextPart(text='do it'))]), + Message(role=Role.MODEL, content=[Part(root=interrupt_part)]), + ], + tools=['approveMe'], + use=[MiddlewareRef(name='recording_mw')], + resume=Resume( + restart=[ + ToolRequestPart( + tool_request=ToolRequest(name='approveMe', input={}, ref='r1'), + metadata={'resumed': {'toolApproved': True}}, + ) + ], + ), + ), + ) + assert response.text == 'final' + assert invocations == ['approveMe'], f'expected wrap_tool to fire once on restart, saw: {invocations}' + + @pytest.mark.asyncio async def test_parallel_tool_requests_all_complete() -> None: """Multiple tool requests in one model turn are resolved together (asyncio.gather); all succeed.""" diff --git a/py/packages/genkit/tests/genkit/ai/middleware_test.py b/py/packages/genkit/tests/genkit/ai/middleware_test.py deleted file mode 100644 index fae6fe1c09..0000000000 --- a/py/packages/genkit/tests/genkit/ai/middleware_test.py +++ 
/dev/null @@ -1,177 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# SPDX-License-Identifier: Apache-2.0 - - -"""Tests for the action module.""" - -import asyncio - -import pytest - -from genkit import Document, Message, ModelResponse -from genkit._ai._middleware import augment_with_context -from genkit._core._action import ActionRunContext -from genkit._core._model import ModelRequest -from genkit._core._typing import ( - DocumentPart, - Part, - Role, - TextPart, -) - - -async def run_augmenter(req: ModelRequest) -> ModelRequest: - """Helper to run the augment_with_context middleware.""" - augmenter = augment_with_context() - req_future = asyncio.Future() - - async def next(req: ModelRequest, _: ActionRunContext) -> ModelResponse: - req_future.set_result(req) - return ModelResponse(message=Message(role=Role.USER, content=[Part(root=TextPart(text='hi'))])) - - await augmenter(req, ActionRunContext(), next) - - return req_future.result() - - -@pytest.mark.asyncio -async def test_augment_with_context_ignores_no_docs() -> None: - """Test simple prompt rendering.""" - req = ModelRequest( - messages=[ - Message(role=Role.USER, content=[Part(root=TextPart(text='hi'))]), - ], - ) - - transformed_req = await run_augmenter(req) - - assert transformed_req == req - - -@pytest.mark.asyncio -async def test_augment_with_context_adds_docs_as_context() -> None: - """Test simple prompt rendering.""" - req = ModelRequest( - messages=[ - 
Message(role=Role.USER, content=[Part(root=TextPart(text='hi'))]), - ], - docs=[ - Document(content=[DocumentPart(root=TextPart(text='doc content 1'))]), - Document(content=[DocumentPart(root=TextPart(text='doc content 2'))]), - ], - ) - - transformed_req = await run_augmenter(req) - - assert transformed_req == ModelRequest( - messages=[ - Message( - role=Role.USER, - content=[ - Part(root=TextPart(text='hi')), - Part( - root=TextPart( - text='\n\nUse the following information to complete ' - + 'your task:\n\n' - + '- [0]: doc content 1\n' - + '- [1]: doc content 2\n\n', - metadata={'purpose': 'context'}, - ) - ), - ], - ) - ], - docs=[ - Document(content=[DocumentPart(root=TextPart(text='doc content 1'))]), - Document(content=[DocumentPart(root=TextPart(text='doc content 2'))]), - ], - ) - - -@pytest.mark.asyncio -async def test_augment_with_context_should_not_modify_non_pending_part() -> None: - """Test simple prompt rendering.""" - req = ModelRequest( - messages=[ - Message( - role=Role.USER, - content=[ - Part( - root=TextPart( - text='this is already context', - metadata={'purpose': 'context'}, - ) - ), - Part(root=TextPart(text='hi')), - ], - ), - ], - docs=[ - Document(content=[DocumentPart(root=TextPart(text='doc content 1'))]), - ], - ) - - transformed_req = await run_augmenter(req) - - assert transformed_req == req - - -@pytest.mark.asyncio -async def test_augment_with_context_with_purpose_part() -> None: - """Test simple prompt rendering.""" - req = ModelRequest( - messages=[ - Message( - role=Role.USER, - content=[ - Part( - root=TextPart( - text='insert context here', - metadata={'purpose': 'context', 'pending': True}, - ) - ), - Part(root=TextPart(text='hi')), - ], - ), - ], - docs=[ - Document(content=[DocumentPart(root=TextPart(text='doc content 1'))]), - ], - ) - - transformed_req = await run_augmenter(req) - - assert transformed_req == ModelRequest( - messages=[ - Message( - role=Role.USER, - content=[ - Part( - root=TextPart( - text='\n\nUse the 
following information to complete ' - + 'your task:\n\n' - + '- [0]: doc content 1\n\n', - metadata={'purpose': 'context'}, - ) - ), - Part(root=TextPart(text='hi')), - ], - ) - ], - docs=[ - Document(content=[DocumentPart(root=TextPart(text='doc content 1'))]), - ], - ) diff --git a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py index 6131c7d3ab..a212d51f4d 100644 --- a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py +++ b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py @@ -44,8 +44,10 @@ import pytest import pytest_asyncio from httpx import ASGITransport, AsyncClient +from pydantic import Field from genkit._core._action import ActionKind +from genkit._core._middleware import BaseMiddleware, middleware, new_middleware from genkit._core._reflection import create_reflection_asgi_app from genkit._core._registry import Registry from genkit._core._typing import ActionMetadata @@ -309,3 +311,112 @@ async def mock_streaming( final_result = json.loads(lines[-1]) assert final_result['result'] == {'final': 'result'} + + +# Real-registry tests for the /api/values?type=middleware endpoint. The other +# endpoint tests above use a MagicMock registry, but here we want to exercise +# the actual MiddlewareDesc serialization path the Dev UI consumes — mocking +# would defeat the point. + + +async def _registry_asgi_client(registry: Registry) -> AsyncClient: + """Build an ASGI client wired to a real Registry instance.""" + app = create_reflection_asgi_app(registry) + transport = ASGITransport(app=app) + return AsyncClient(transport=transport, base_url='http://test') + + +@pytest.mark.asyncio +async def test_values_middleware_includes_derived_config_schema() -> None: + """The Dev UI's /api/values?type=middleware response carries each middleware's configSchema. 
+ + Without this the Dev UI has nothing to render a config form from and falls + back to a free-text JSON box. The schema is derived from the middleware + class's pydantic fields by ``new_middleware``. + """ + + @middleware(name='fallback', description='Falls back to alternative models on failure') + class _Fallback(BaseMiddleware): + models: list[str] = Field(default_factory=list) + statuses: list[str] = Field(default_factory=list) + isolate_config: bool = False + + registry = Registry() + registry.register_value('middleware', 'fallback', new_middleware(_Fallback)) + + client = await _registry_asgi_client(registry) + try: + response = await client.get('/api/values?type=middleware') + assert response.status_code == 200 + body = response.json() + entry = body['fallback'] + assert entry['name'] == 'fallback' + assert entry['description'] == 'Falls back to alternative models on failure' + config_schema = entry['configSchema'] + assert config_schema['type'] == 'object' + # Author-defined fields show up; framework-injected ones (registry, + # enqueue_parts) must not leak into the form. + assert set(config_schema['properties'].keys()) == {'models', 'statuses', 'isolate_config'} + finally: + await client.aclose() + + +@pytest.mark.asyncio +async def test_values_middleware_empty_config_schema_for_no_op() -> None: + """A middleware with no config knobs still gets an (empty) object schema. + + The Dev UI renders an empty config form, signalling "registered, nothing to + configure" rather than dropping the user into a raw JSON editor. 
+ """ + + @middleware(name='no_op') + class _NoOp(BaseMiddleware): + pass + + registry = Registry() + registry.register_value('middleware', 'no_op', new_middleware(_NoOp)) + + client = await _registry_asgi_client(registry) + try: + response = await client.get('/api/values?type=middleware') + assert response.status_code == 200 + entry = response.json()['no_op'] + assert entry['configSchema'] == { + 'type': 'object', + 'properties': {}, + 'additionalProperties': True, + } + finally: + await client.aclose() + + +@pytest.mark.asyncio +async def test_values_middleware_explicit_config_schema_wins() -> None: + """Explicit ``@middleware(config_schema=...)`` overrides the derived schema. + + Authors who hand-wrote a schema (often to add titles, descriptions, or + enum constraints the Dev UI uses for nicer form widgets) should keep it. + """ + explicit = { + 'type': 'object', + 'properties': {'mode': {'type': 'string', 'enum': ['fast', 'careful']}}, + 'required': ['mode'], + } + + @middleware(name='explicit_schema', config_schema=explicit) + class _Explicit(BaseMiddleware): + # Field exists on the class but the explicit schema wins; the Dev UI + # only sees what the author chose to expose. 
+ ignored_field: int = 0 + + registry = Registry() + registry.register_value('middleware', 'explicit_schema', new_middleware(_Explicit)) + + client = await _registry_asgi_client(registry) + try: + response = await client.get('/api/values?type=middleware') + assert response.status_code == 200 + entry = response.json()['explicit_schema'] + assert entry['configSchema'] == explicit + finally: + await client.aclose() diff --git a/py/packages/genkit/tests/genkit/core/reflection_v2_test.py b/py/packages/genkit/tests/genkit/core/reflection_v2_test.py index 68e2cf4184..e11cb3843e 100644 --- a/py/packages/genkit/tests/genkit/core/reflection_v2_test.py +++ b/py/packages/genkit/tests/genkit/core/reflection_v2_test.py @@ -41,9 +41,11 @@ import pytest import pytest_asyncio +from pydantic import Field from websockets.asyncio.server import serve from genkit._core._action import Action, ActionKind, ActionRunContext +from genkit._core._middleware import BaseMiddleware, MiddlewareDesc, middleware, new_middleware from genkit._core._reflection_v2 import ( JSON_RPC_INVALID_PARAMS, JSON_RPC_METHOD_NOT_FOUND, @@ -248,6 +250,169 @@ async def test_reflection_server_v2_list_values(fake_manager: FakeReflectionMana await _stop_client(client, task) +@pytest.mark.asyncio +async def test_reflection_server_v2_list_values_serializes_middleware_as_object( + fake_manager: FakeReflectionManager, +) -> None: + """Registered middleware comes back as a JSON object, not pydantic's repr. + + Without explicit serialization the response would fall through to + ``json.dumps(default=str)`` and the dev-ui would receive the string + ``"name='concise_reply_mw' description=None ..."`` instead of the + ``MiddlewareDesc`` wire shape. 
+ """ + + class _NoOpMiddleware(BaseMiddleware): + pass + + registry = Registry() + registry.register_value( + 'middleware', + 'concise_reply_mw', + MiddlewareDesc(factory=lambda _cfg: _NoOpMiddleware(), name='concise_reply_mw'), + ) + + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'listValues', + 'params': {'type': 'middleware'}, + 'id': '2b', + }) + resp = await fake_manager.read_rpc() + assert resp.get('id') == '2b' + values = resp['result']['values'] + assert values == {'concise_reply_mw': {'name': 'concise_reply_mw'}} + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +async def test_reflection_server_v2_list_values_includes_derived_config_schema( + fake_manager: FakeReflectionManager, +) -> None: + """Middleware registered via ``new_middleware`` exposes a derived configSchema. + + The Dev UI uses this schema to render a config form for each registered + middleware. Without it the form has nothing to draw and the user is dumped + into a free-text JSON editor. 
+ """ + + @middleware(name='fallback', description='Falls back to alternative models on failure') + class _Fallback(BaseMiddleware): + models: list[str] = Field(default_factory=list) + statuses: list[str] = Field(default_factory=list) + isolate_config: bool = False + + registry = Registry() + registry.register_value('middleware', 'fallback', new_middleware(_Fallback)) + + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'listValues', + 'params': {'type': 'middleware'}, + 'id': '2c', + }) + resp = await fake_manager.read_rpc() + assert resp.get('id') == '2c' + entry = resp['result']['values']['fallback'] + assert entry['name'] == 'fallback' + assert entry['description'] == 'Falls back to alternative models on failure' + config_schema = entry['configSchema'] + assert config_schema['type'] == 'object' + # Author-defined fields show up; framework-injected ones (registry, + # enqueue_parts) must not leak into the form. + props = config_schema['properties'] + assert set(props.keys()) == {'models', 'statuses', 'isolate_config'} + assert props['models']['type'] == 'array' + assert props['statuses']['type'] == 'array' + assert props['isolate_config']['type'] == 'boolean' + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +async def test_reflection_server_v2_list_values_empty_config_schema_for_no_op( + fake_manager: FakeReflectionManager, +) -> None: + """A middleware with no config knobs still gets an (empty) object schema. + + The Dev UI renders an empty config form, signalling "registered, nothing to + configure" rather than dropping the user into a raw JSON editor. 
+ """ + + @middleware(name='no_op') + class _NoOp(BaseMiddleware): + pass + + registry = Registry() + registry.register_value('middleware', 'no_op', new_middleware(_NoOp)) + + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'listValues', + 'params': {'type': 'middleware'}, + 'id': '2d', + }) + resp = await fake_manager.read_rpc() + entry = resp['result']['values']['no_op'] + assert entry['configSchema'] == { + 'type': 'object', + 'properties': {}, + 'additionalProperties': True, + } + finally: + await _stop_client(client, task) + + +@pytest.mark.asyncio +async def test_reflection_server_v2_list_values_explicit_config_schema_wins( + fake_manager: FakeReflectionManager, +) -> None: + """Explicit ``@middleware(config_schema=...)`` overrides the derived schema. + + Authors who hand-wrote a schema (often to add titles, descriptions, or + enum constraints the Dev UI uses for nicer form widgets) should keep it. + """ + explicit = { + 'type': 'object', + 'properties': {'mode': {'type': 'string', 'enum': ['fast', 'careful']}}, + 'required': ['mode'], + } + + @middleware(name='explicit_schema', config_schema=explicit) + class _Explicit(BaseMiddleware): + # Field exists on the class but the explicit schema wins; the Dev UI + # only sees what the author chose to expose.
+ ignored_field: int = 0 + + registry = Registry() + registry.register_value('middleware', 'explicit_schema', new_middleware(_Explicit)) + + client, task = await _run_client_lifecycle(registry, fake_manager) + try: + await ack_register(fake_manager) + await fake_manager.write_rpc({ + 'jsonrpc': '2.0', + 'method': 'listValues', + 'params': {'type': 'middleware'}, + 'id': '2e', + }) + resp = await fake_manager.read_rpc() + entry = resp['result']['values']['explicit_schema'] + assert entry['configSchema'] == explicit + finally: + await _stop_client(client, task) + + @pytest.mark.asyncio async def test_reflection_server_v2_list_values_rejects_unsupported_type( fake_manager: FakeReflectionManager, diff --git a/py/packages/genkit/tests/genkit/core/registry_test.py b/py/packages/genkit/tests/genkit/core/registry_test.py index d77523fdc1..628a237d82 100644 --- a/py/packages/genkit/tests/genkit/core/registry_test.py +++ b/py/packages/genkit/tests/genkit/core/registry_test.py @@ -430,3 +430,11 @@ async def dap_fn() -> DapValue: assert catalog[qualified].key == qualified assert provider_key in catalog + + +def test_registry_satisfies_registry_like() -> None: + """Registry must structurally satisfy RegistryLike so middleware can use it as such.""" + from genkit._core._protocols import RegistryLike + from genkit._core._registry import Registry + + assert isinstance(Registry(None), RegistryLike) diff --git a/py/packages/genkit/tests/genkit/veneer/veneer_test.py b/py/packages/genkit/tests/genkit/veneer/veneer_test.py index c6918d9268..1287898cc3 100644 --- a/py/packages/genkit/tests/genkit/veneer/veneer_test.py +++ b/py/packages/genkit/tests/genkit/veneer/veneer_test.py @@ -6,7 +6,7 @@ """Tests for the action module.""" import json -from collections.abc import Awaitable, Callable +from collections.abc import Callable from typing import Any import pytest @@ -17,6 +17,7 @@ Genkit, Interrupt, Message, + MiddlewareRef, ModelResponse, ModelResponseChunk, respond_to_interrupt, @@ -52,6 
+53,8 @@ ToolResponse, ToolResponsePart, ) +from genkit.middleware import BaseMiddleware, ModelHookParams, middleware +from genkit.plugin_api import middleware_plugin, new_middleware # type SetupFixture = tuple[Genkit, EchoModel, ProgrammableModel] SetupFixture = tuple[Genkit, EchoModel, ProgrammableModel] @@ -1000,30 +1003,27 @@ class TestSchema(BaseModel): assert (await stream_result.response).request == want -@pytest.mark.asyncio -async def test_generate_with_middleware( - setup_test: SetupFixture, -) -> None: - """When middleware is provided, applies it.""" - ai, *_ = setup_test - - async def pre_middle( - req: ModelRequest, ctx: ActionRunContext, next: Callable[..., Awaitable[ModelResponse]] - ) -> ModelResponse: - txt = ''.join(text_from_message(m) for m in req.messages) - return await next( - ModelRequest( - messages=[ - Message(role=Role.USER, content=[Part(root=TextPart(text=f'PRE {txt}'))]), - ], - ), - ctx, +@middleware(name='pre_mw') +class PreMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + txt = ''.join(text_from_message(m) for m in params.request.messages) + return await next_fn( + ModelHookParams( + request=ModelRequest( + messages=[ + Message(role=Role.USER, content=[Part(root=TextPart(text=f'PRE {txt}'))]), + ], + ), + on_chunk=params.on_chunk, + context=params.context, + ) ) - async def post_middle( - req: ModelRequest, ctx: ActionRunContext, next: Callable[..., Awaitable[ModelResponse]] - ) -> ModelResponse: - resp: ModelResponse = await next(req, ctx) + +@middleware(name='post_mw') +class PostMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + resp: ModelResponse = await next_fn(params) assert resp.message is not None txt = text_from_message(resp.message) return ModelResponse( @@ -1031,42 +1031,77 @@ async def post_middle( message=Message(role=Role.USER, content=[Part(root=TextPart(text=f'{txt} POST'))]), 
) + +@pytest.mark.asyncio +async def test_generate_with_middleware() -> None: + """When middleware is provided, applies it.""" + ai = Genkit( + model='echoModel', + plugins=[ + middleware_plugin([ + new_middleware(PreMiddleware), + new_middleware(PostMiddleware), + ]) + ], + ) + define_programmable_model(ai) + define_echo_model(ai) + want = '[ECHO] user: "PRE hi" POST' - response = await ai.generate(model='echoModel', prompt='hi', use=[pre_middle, post_middle]) + response = await ai.generate( + model='echoModel', + prompt='hi', + use=[MiddlewareRef(name='pre_mw'), MiddlewareRef(name='post_mw')], + ) assert response.text == want - stream_result = ai.generate_stream(model='echoModel', prompt='hi', use=[pre_middle, post_middle]) + stream_result = ai.generate_stream( + model='echoModel', + prompt='hi', + use=[MiddlewareRef(name='pre_mw'), MiddlewareRef(name='post_mw')], + ) assert (await stream_result.response).text == want +@middleware(name='inject_ctx') +class InjectContextMiddleware(BaseMiddleware): + async def wrap_model(self, params: ModelHookParams, next_fn: Callable) -> ModelResponse: + txt = ''.join(text_from_message(m) for m in params.request.messages) + return await next_fn( + ModelHookParams( + request=ModelRequest( + messages=[ + Message( + role=Role.USER, + content=[Part(root=TextPart(text=f'{txt} {params.context}'))], + ), + ], + ), + on_chunk=params.on_chunk, + context=params.context, + ) + ) + + @pytest.mark.asyncio -async def test_generate_passes_through_current_action_context( - setup_test: SetupFixture, -) -> None: +async def test_generate_passes_through_current_action_context() -> None: """Test that generate uses current action context by default.""" - ai, *_ = setup_test - - async def inject_context( - req: ModelRequest, ctx: ActionRunContext, next: Callable[..., Awaitable[ModelResponse]] - ) -> ModelResponse: - txt = ''.join(text_from_message(m) for m in req.messages) - return await next( - ModelRequest( - messages=[ - Message( - role=Role.USER, - 
content=[Part(root=TextPart(text=f'{txt} {ctx.context}'))], - ), - ], - ), - ctx, - ) + ai = Genkit( + model='echoModel', + plugins=[middleware_plugin([new_middleware(InjectContextMiddleware)])], + ) + define_programmable_model(ai) + define_echo_model(ai) async def action_fn() -> ModelResponse: - return await ai.generate(model='echoModel', prompt='hi', use=[inject_context]) + return await ai.generate( + model='echoModel', + prompt='hi', + use=[MiddlewareRef(name='inject_ctx')], + ) action = ai.registry.register_action(name='test_action', kind=ActionKind.CUSTOM, fn=action_fn) action_response = await action.run(context={'foo': 'bar'}) @@ -1075,33 +1110,20 @@ async def action_fn() -> ModelResponse: @pytest.mark.asyncio -async def test_generate_uses_explicitly_passed_in_context( - setup_test: SetupFixture, -) -> None: +async def test_generate_uses_explicitly_passed_in_context() -> None: """Generate uses specific context instead of current action context.""" - ai, *_ = setup_test - - async def inject_context( - req: ModelRequest, ctx: ActionRunContext, next: Callable[..., Awaitable[ModelResponse]] - ) -> ModelResponse: - txt = ''.join(text_from_message(m) for m in req.messages) - return await next( - ModelRequest( - messages=[ - Message( - role=Role.USER, - content=[Part(root=TextPart(text=f'{txt} {ctx.context}'))], - ), - ], - ), - ctx, - ) + ai = Genkit( + model='echoModel', + plugins=[middleware_plugin([new_middleware(InjectContextMiddleware)])], + ) + define_programmable_model(ai) + define_echo_model(ai) async def action_fn() -> ModelResponse: return await ai.generate( model='echoModel', prompt='hi', - use=[inject_context], + use=[MiddlewareRef(name='inject_ctx')], context={'bar': 'baz'}, ) diff --git a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py index 464216c4dd..ef6e50cc83 100644 --- a/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py +++ 
b/py/plugins/google-genai/src/genkit/plugins/google_genai/models/utils.py @@ -47,6 +47,7 @@ """ import base64 +import logging from typing import cast from urllib.parse import urlparse @@ -68,6 +69,8 @@ ) from genkit.plugin_api import get_cached_client +logger = logging.getLogger(__name__) + class PartConverter: """Converts content parts between Genkit's internal representation and Gemini's API format. @@ -137,61 +140,58 @@ async def to_gemini(cls, part: Part | DocumentPart) -> genai.types.Part | list[g thought_signature=cls._extract_thought_signature(part.root.metadata), ) if isinstance(part.root, ToolResponsePart): - tool_output = part.root.tool_response.output - parts_to_return = [] - - # Check for multimodal content structure {content: [{media: ...}]} - if isinstance(tool_output, dict) and 'content' in tool_output: + tool_response = part.root.tool_response + tool_output = tool_response.output + + # FunctionResponse.response must be a dict, not a raw value. + output = tool_output if isinstance(tool_output, dict) else {'result': tool_output} + + # --- Primary path: ToolResponse.content (set by MultipartToolResponse.content) --- + # Mirrors JS: functionResponse.parts = content.map(toGeminiPart) + # Mirrors Go: genai.NewPartFromFunctionResponseWithParts(name, output, parts) + extra_parts: list[genai.types.Part] = [] + if tool_response.content: + for item in tool_response.content: + try: + genkit_part = Part.model_validate(item) + converted = await cls.to_gemini(genkit_part) + if isinstance(converted, list): + extra_parts.extend(converted) + else: + extra_parts.append(converted) + except Exception as exc: + logger.debug('Skipping unrecognised tool-response content part: %s', exc) + + # --- Legacy fallback: tools that embed media inside output['content'] --- + # Kept for backward compat; new middleware should use MultipartToolResponse.content. 
+ if not extra_parts and isinstance(tool_output, dict) and 'content' in tool_output: content_list = tool_output['content'] if isinstance(content_list, list): - # Create a copy to avoid mutating original if that matters, - # but here we just want to separate content from other fields. - clean_output = tool_output.copy() - clean_output.pop('content') - - # Heuristic: if media found, extract it to separate parts. - has_media = False + clean_output = {k: v for k, v in tool_output.items() if k != 'content'} for item in content_list: if isinstance(item, dict) and 'media' in item: - has_media = True media_info = item['media'] - url = media_info.get('url') + url = media_info.get('url') or '' content_type = media_info.get('contentType') or media_info.get('content_type') - - if url and url.startswith(cls.DATA): + if url.startswith(cls.DATA): _, data_str = url.split(',', 1) data = base64.b64decode(data_str) - parts_to_return.append( + extra_parts.append( genai.types.Part(inline_data=genai.types.Blob(mime_type=content_type, data=data)) ) + if extra_parts: + output = clean_output - if has_media: - # Append the function response part FIRST (contextually correct) - parts_to_return.insert( - 0, - genai.types.Part( - function_response=genai.types.FunctionResponse( - id=part.root.tool_response.ref, - name=part.root.tool_response.name.replace('/', '__'), - response=clean_output, - ) - ), - ) - return parts_to_return - - # Default behavior for standard tool responses - # FunctionResponse.response must be a dict, not a raw value - output = tool_output - if not isinstance(output, dict): - output = {'result': output} - - return genai.types.Part( + fn_part = genai.types.Part( function_response=genai.types.FunctionResponse( - id=part.root.tool_response.ref, - name=part.root.tool_response.name.replace('/', '__'), + id=tool_response.ref, + name=tool_response.name.replace('/', '__'), response=output, ) ) + if extra_parts: + return [fn_part, *extra_parts] + return fn_part if 
isinstance(part.root, MediaPart): url = part.root.media.url if url.startswith(cls.DATA): diff --git a/py/pyproject.toml b/py/pyproject.toml index b30a18f561..abd97f3a31 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -154,7 +154,6 @@ flask-hello = { workspace = true } gemini-code-execution = { workspace = true } gemini-context-caching = { workspace = true } google-genai-media = { workspace = true } -middleware = { workspace = true } output-formats = { workspace = true } prompts = { workspace = true } tool-interrupts = { workspace = true } diff --git a/py/samples/middleware/README.md b/py/samples/middleware/README.md deleted file mode 100644 index 083fd9a815..0000000000 --- a/py/samples/middleware/README.md +++ /dev/null @@ -1,17 +0,0 @@ -# Middleware - -Intercept or modify model requests with `use=` on `ai.generate()`. - -```bash -export GEMINI_API_KEY=your-api-key -uv sync -uv run src/main.py -``` - -To inspect the flows in Dev UI instead: - -```bash -genkit start -- uv run src/main.py -``` - -Try `logging_demo` and `request_modifier_demo`. 
diff --git a/py/samples/middleware/pyproject.toml b/py/samples/middleware/pyproject.toml deleted file mode 100644 index 3ae1c686a4..0000000000 --- a/py/samples/middleware/pyproject.toml +++ /dev/null @@ -1,18 +0,0 @@ -[project] -name = "middleware" -version = "0.2.0" -requires-python = ">=3.10" -dependencies = [ - "genkit", - "genkit-plugin-google-genai", - "pydantic>=2.0.0", - "structlog>=24.0.0", - "uvloop>=0.21.0", -] - -[build-system] -build-backend = "hatchling.build" -requires = ["hatchling"] - -[tool.hatch.build.targets.wheel] -packages = ["src"] diff --git a/py/samples/middleware/src/main.py b/py/samples/middleware/src/main.py deleted file mode 100644 index fe4496bbe3..0000000000 --- a/py/samples/middleware/src/main.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -# SPDX-License-Identifier: Apache-2.0 - -"""Middleware - inspect or modify requests before they reach the model.""" - -from collections.abc import Awaitable, Callable - -import structlog -from pydantic import BaseModel, Field - -from genkit import Genkit, Message, ModelRequest, ModelResponse, Part, Role, TextPart -from genkit._core._action import ActionRunContext -from genkit.plugins.google_genai import GoogleAI - -logger = structlog.get_logger(__name__) -ai = Genkit(plugins=[GoogleAI()], model='googleai/gemini-2.5-flash') - - -class PromptInput(BaseModel): - """Input shared by middleware flows.""" - - prompt: str = Field(default='Explain recursion simply.', description='Prompt to send to the model') - - -async def logging_middleware( - req: ModelRequest, - ctx: ActionRunContext, - next_handler: Callable[[ModelRequest, ActionRunContext], Awaitable[ModelResponse]], -) -> ModelResponse: - """Log request/response details without changing behavior.""" - - await logger.ainfo('middleware saw request', message_count=len(req.messages)) - response = await next_handler(req, ctx) - await logger.ainfo('middleware saw response', finish_reason=response.finish_reason) - return response - - -async def concise_reply_middleware( - req: ModelRequest, - ctx: ActionRunContext, - next_handler: Callable[[ModelRequest, ActionRunContext], Awaitable[ModelResponse]], -) -> ModelResponse: - """Add a short system instruction before the model call.""" - - system_message = Message( - role=Role.SYSTEM, - content=[Part(root=TextPart(text='Answer in one short paragraph.'))], - ) - return await next_handler(req.model_copy(update={'messages': [system_message, *req.messages]}), ctx) - - -@ai.flow() -async def logging_demo(input: PromptInput) -> str: - """Run a prompt through a read-only middleware.""" - - response = await ai.generate(prompt=input.prompt, use=[logging_middleware]) - return response.text - - -@ai.flow() -async def request_modifier_demo(input: PromptInput) -> str: - """Run a prompt 
through a request-modifying middleware.""" - - response = await ai.generate(prompt=input.prompt, use=[concise_reply_middleware]) - return response.text - - -async def main() -> None: - """Run both middleware demos once.""" - try: - print(await logging_demo(PromptInput())) # noqa: T201 - print(await request_modifier_demo(PromptInput(prompt='Write a haiku about recursion.'))) # noqa: T201 - except Exception as error: - print(f'Set GEMINI_API_KEY to a valid value before running this sample directly.\n{error}') # noqa: T201 - - -if __name__ == '__main__': - ai.run_main(main()) diff --git a/py/tools/schema_to_typing/schema_to_typing.py b/py/tools/schema_to_typing/schema_to_typing.py index 4e38b6e0df..337e95f68d 100644 --- a/py/tools/schema_to_typing/schema_to_typing.py +++ b/py/tools/schema_to_typing/schema_to_typing.py @@ -41,6 +41,7 @@ TRANSFORMATIONS = { 'Message': {'output_name': 'MessageData'}, 'GenerateActionOptions': {'suffix': 'Data', 'omit': ['messages']}, + 'MiddlewareDesc': {'output_name': 'MiddlewareDescData'}, } diff --git a/py/uv.lock b/py/uv.lock index d90bde07f4..f95e41a86f 100644 --- a/py/uv.lock +++ b/py/uv.lock @@ -29,7 +29,6 @@ members = [ "genkit-plugin-vertex-ai", "genkit-workspace", "google-genai-media", - "middleware", "output-formats", "prompts", "tool-interrupts", @@ -3462,27 +3461,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] -[[package]] -name = "middleware" -version = "0.2.0" -source = { editable = "samples/middleware" } -dependencies = [ - { name = "genkit" }, - { name = "genkit-plugin-google-genai" }, - { name = "pydantic" }, - { name = "structlog" }, - { name = "uvloop" }, -] - -[package.metadata] -requires-dist = [ - { name = "genkit", editable = "packages/genkit" }, - { name = 
"genkit-plugin-google-genai", editable = "plugins/google-genai" }, - { name = "pydantic", specifier = ">=2.0.0" }, - { name = "structlog", specifier = ">=24.0.0" }, - { name = "uvloop", specifier = ">=0.21.0" }, -] - [[package]] name = "mistune" version = "3.2.1"