diff --git a/py/packages/genkit/src/genkit/_ai/_generate.py b/py/packages/genkit/src/genkit/_ai/_generate.py index 98a84752ae..8704ef361b 100644 --- a/py/packages/genkit/src/genkit/_ai/_generate.py +++ b/py/packages/genkit/src/genkit/_ai/_generate.py @@ -38,7 +38,12 @@ ) from genkit._ai._resource import ResourceArgument, ResourceInput, find_matching_resource, resolve_resources from genkit._ai._tools import Tool, ToolInterruptError -from genkit._core._action import Action, ActionKind, ActionRunContext +from genkit._core._action import ( + GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, + Action, + ActionKind, + ActionRunContext, +) from genkit._core._error import GenkitError from genkit._core._logger import get_logger from genkit._core._model import GenerateActionOptions @@ -61,6 +66,51 @@ logger = get_logger(__name__) +async def expand_wildcard_tools(registry: Registry, tool_names: list[str]) -> list[str]: + """Expand DAP wildcard tool names into individual registry keys. + + A wildcard has the form ``:tool/*`` (or ``:tool/*``). + Each match becomes a full DAP key + ``/dynamic-action-provider/:/`` so later resolution + stays bound to that provider (no ambiguous bare-name lookup across DAPs). + + Non-wildcard names are passed through unchanged. + """ + expanded: list[str] = [] + for name in tool_names: + if not name.endswith('*') or ':' not in name: + expanded.append(name) + continue + + colon = name.index(':') + provider_name = name[:colon] + rest = name[colon + 1 :] # e.g. "tool/*" or "tool/prefix*" + + provider_action = await registry.resolve_action(ActionKind.DYNAMIC_ACTION_PROVIDER, provider_name) + if provider_action is None: + expanded.append(name) + continue + + dap = getattr(provider_action, GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, None) + if dap is None: + expanded.append(name) + continue + + if '/' not in rest: + expanded.append(name) + continue + + action_type, action_pattern = rest.split('/', 1) + metas = await dap.list_action_metadata(action_type, action_pattern) + for meta in metas: + tool_name = meta.get('name') + if tool_name: + tn = str(tool_name) + expanded.append(f'/dynamic-action-provider/{provider_name}:{action_type}/{tn}') + + return expanded + + def tools_to_action_names( tools: Sequence[str | Tool] | None, ) -> list[str] | None: @@ -158,20 +208,27 @@ async def _generate_action( context: dict[str, Any] | None = None, ) -> ModelResponse: """Execute a generation request with tool calling and middleware support.""" - model, tools, format_def = await resolve_parameters(registry, raw_request) + effective_registry = registry if registry.is_child else registry.new_child() + + tools_in = raw_request.tools + if tools_in: + raw_request = raw_request.model_copy() + raw_request.tools = await expand_wildcard_tools(effective_registry, tools_in) + + model, tools, format_def = await resolve_parameters(effective_registry, raw_request) raw_request, formatter = apply_format(raw_request, format_def) if raw_request.resources: - raw_request = await apply_resources(registry, raw_request) + raw_request = await apply_resources(effective_registry, raw_request) - assert_valid_tool_names(raw_request) + assert_valid_tool_names(tools) ( revised_request, interrupted_response, resumed_tool_message, - ) = await _resolve_resume_options(registry, raw_request) + ) = await _resolve_resume_options(effective_registry, raw_request) # NOTE: in the future we should make it possible to interrupt a restart, but # at the moment it's too complicated because it's not clear how to return a @@ -374,7 +431,7 @@ def message_parser(msg: Message) -> Any: # noqa: ANN401 revised_model_msg, tool_msg, transfer_preamble, - ) = await resolve_tool_requests(registry, raw_request, generated_msg) + ) = await resolve_tool_requests(effective_registry, raw_request, generated_msg) # if an interrupt message is returned, stop the tool loop and return a # response. @@ -408,7 +465,7 @@ def message_parser(msg: Message) -> Any: # noqa: ANN401 # then recursively call for another loop return await _generate_action( - registry, + effective_registry, raw_request=next_request, # middleware: middleware, current_turn=current_turn + 1, @@ -584,10 +641,30 @@ async def apply_resources(registry: Registry, raw_request: GenerateActionOptions return new_request -def assert_valid_tool_names(_raw_request: GenerateActionOptions) -> None: - """Validate tool names in the request. (TODO: not yet implemented).""" - # TODO(#4338): implement me - pass +def _tool_short_name_for_model(name: str) -> str: + """Return the last path segment of a tool name.""" + if '/' not in name: + return name + return name[name.rfind('/') + 1 :] + + +def assert_valid_tool_names(tools: list[Action[Any, Any, Any]]) -> None: + """Reject overlapping model-facing tool names before the model is called. + + Two resolved tools that share the same short name (segment after the last ``/``) + cannot both appear in one generate request. + """ + if not tools: + return + seen: dict[str, str] = {} + for tool in tools: + short = _tool_short_name_for_model(tool.name) + if short in seen: + raise GenkitError( + status='INVALID_ARGUMENT', + message=(f"Cannot provide two tools with the same name: '{tool.name}' and '{seen[short]}'"), + ) + seen[short] = tool.name async def resolve_parameters( @@ -605,9 +682,10 @@ async def resolve_parameters( tools: list[Action[Any, Any, Any]] = [] if request.tools: for tool_name in request.tools: - tool_action = await registry.resolve_action(ActionKind.TOOL, tool_name) - if tool_action is None: - raise Exception(f'Unable to resolve tool {tool_name}') + try: + tool_action = await resolve_tool(registry, tool_name) + except GenkitError as e: + raise Exception(f'Unable to resolve tool {tool_name}') from e tools.append(tool_action) format_def: FormatDef | None = None @@ -665,7 +743,12 @@ async def resolve_tool_requests( tool_dict: dict[str, Action] = {} if request.tools: for tool_name in request.tools: - tool_dict[tool_name] = await resolve_tool(registry, tool_name) + tool_action = await resolve_tool(registry, tool_name) + tool_dict[tool_name] = tool_action + # Model tool calls use ToolDefinition.name (short); wildcard expansion uses full DAP keys. + short = tool_action.name + if short not in tool_dict: + tool_dict[short] = tool_action revised_model_message = message.model_copy(deep=True) @@ -762,11 +845,19 @@ async def _resolve_tool_request(tool: Action, tool_request_part: ToolRequestPart async def resolve_tool(registry: Registry, tool_ref: str | Tool) -> Action: """Resolve a tool from a registry name or a Tool instance. + Accepts full action keys (``/dynamic-action-provider/...``), DAP-qualified + names (``provider:tool/name``), or plain registered tool names. + Used when building ModelRequest (for example from to_generate_request). """ if isinstance(tool_ref, Tool): return tool_ref.action + if tool_ref.startswith('/'): + tool = await registry.resolve_action_by_key(tool_ref) + if tool is not None: + return tool + tool = await registry.resolve_action(kind=ActionKind.TOOL, name=tool_ref) if tool is None: raise GenkitError(status='NOT_FOUND', message=f'Unable to resolve tool {tool_ref}') diff --git a/py/packages/genkit/src/genkit/_core/_action.py b/py/packages/genkit/src/genkit/_core/_action.py index 23057c2f01..b352d09a55 100644 --- a/py/packages/genkit/src/genkit/_core/_action.py +++ b/py/packages/genkit/src/genkit/_core/_action.py @@ -19,11 +19,12 @@ import asyncio import inspect import json +import re import time from collections.abc import AsyncIterator, Awaitable, Callable, Generator, Mapping from contextlib import contextmanager from contextvars import ContextVar -from typing import Any, ClassVar, Generic, cast, get_type_hints +from typing import Any, ClassVar, Generic, NamedTuple, cast, get_type_hints from opentelemetry import trace as trace_api from opentelemetry.trace import Span @@ -228,6 +229,43 @@ def extract_action_args_and_types( # ============================================================================= +GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR = '_genkit_dynamic_action_provider' +# Nested actions in a DAP cache store their qualified reflection key via this attribute name: +# ``/dynamic-action-provider/:/``. +GENKIT_DAP_QUALIFIED_KEY_ATTR = '_genkit_dap_qualified_key' + + +class DapQualifiedName(NamedTuple): + """Segments of a DAP-qualified name ``provider:innerKind/innerName``.""" + + provider: str + inner_kind: str + inner_name: str + + +def parse_dap_qualified_name(name: str) -> DapQualifiedName | None: + """Parse DAP-qualified segment ``provider:innerKind/innerName``. + + Used when the action key kind is ``dynamic-action-provider`` and the name + references a nested action exposed by a provider (e.g. MCP tools). + + Pattern: ``[provider]:[inner_kind]/[inner_name]`` — no slashes in the + provider segment (``plugin/foo`` is not a valid provider host). + + Returns: + A :class:`DapQualifiedName` if the string matches; otherwise ``None`` so + callers can treat the name as a plain dynamic-action-provider id. + """ + # Pattern: [provider]:[inner_kind]/[inner_name]; no '/' or ':' in provider. + match = re.match(r'^([^/:]+):([^/:]+)/(.+)$', name) + if not match: + return None + provider, inner_kind, inner_name = match.groups() + if not provider or not inner_kind or not inner_name: + return None + return DapQualifiedName(provider, inner_kind, inner_name) + + def parse_action_key(key: str) -> tuple[ActionKind, str]: """Parse '//' key into (ActionKind, name).""" tokens = key.split('/') @@ -246,11 +284,25 @@ def parse_action_key(key: str) -> tuple[ActionKind, str]: return kind, name -def create_action_key(kind: ActionKind, name: str) -> str: +def create_action_key(kind: ActionKind | str, name: str) -> str: """Create '//' key.""" return f'/{kind}/{name}' +def parse_dap_provider_host(name: str) -> str | None: + """Return the segment before the first ``:`` when the name has multiple ``:``-split parts. + + If there is no ``:``, or the first segment is empty, returns ``None``. + """ + parts = name.split(':') + if len(parts) < 2: + return None + host = parts[0] + if not host: + return None + return host + + # ============================================================================= # Action core # ============================================================================= diff --git a/py/packages/genkit/src/genkit/_core/_dap.py b/py/packages/genkit/src/genkit/_core/_dap.py index 5d509786e5..d96c9b2436 100644 --- a/py/packages/genkit/src/genkit/_core/_dap.py +++ b/py/packages/genkit/src/genkit/_core/_dap.py @@ -21,13 +21,43 @@ from collections.abc import Awaitable, Callable, Mapping from typing import Any -from genkit._core._action import Action, ActionKind +from genkit._core._action import ( + GENKIT_DAP_QUALIFIED_KEY_ATTR, + GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, + Action, + ActionKind, + create_action_key, +) +from genkit._core._error import GenkitError from genkit._core._registry import Registry ActionMetadataLike = Mapping[str, object] DapValue = dict[str, list[Action[Any, Any]]] DapFn = Callable[[], Awaitable[DapValue]] -DapMetadata = dict[str, list[ActionMetadataLike]] + + +def _qualified_dap_key(dap_id: str, action_type: str, child_name: str) -> str: + return create_action_key(ActionKind.DYNAMIC_ACTION_PROVIDER, f'{dap_id}:{action_type}/{child_name}') + + +def _transform_dap_value(value: DapValue, dap_id: str) -> list[dict[str, Any]]: + """Flatten child actions into reflection-style metadata rows for the DAP action result.""" + rows: list[dict[str, Any]] = [] + for action_type, actions in value.items(): + for child in actions: + rows.append( + { + 'key': _qualified_dap_key(dap_id, action_type, child.name), + 'name': child.name, + 'type': action_type, + 'description': child.description, + 'inputSchema': child.input_schema, + 'outputSchema': child.output_schema, + 'metadata': dict(child.metadata) if child.metadata else None, + } + ) + return rows + # Default cache TTL in milliseconds _DEFAULT_CACHE_TTL_MS = 3000 @@ -51,6 +81,19 @@ def __init__( _DEFAULT_CACHE_TTL_MS if cache_ttl_millis is None or cache_ttl_millis == 0 else cache_ttl_millis ) + def set_value(self, value: DapValue) -> None: + """Update the cached value and reset expiry; assign each child its DAP qualified reflection key.""" + dap_id = self.action.name + for action_type, actions in value.items(): + for child in actions: + setattr( + child, + GENKIT_DAP_QUALIFIED_KEY_ATTR, + _qualified_dap_key(dap_id, action_type, child.name), + ) + self._value = value + self._expires_at = time.time() * 1000 + self._ttl_millis + def invalidate_cache(self) -> None: self._value = None self._expires_at = None @@ -77,12 +120,22 @@ async def _get_or_fetch(self, skip_trace: bool = False) -> DapValue: async def _do_fetch(self, skip_trace: bool) -> DapValue: try: - self._value = await self._dap_fn() - self._expires_at = time.time() * 1000 + self._ttl_millis - if not skip_trace: - metadata = {k: [a.metadata or {} for a in v] for k, v in self._value.items()} - await self.action.run(metadata) - return self._value + if skip_trace: + # Bypass the action (and its trace span) for reflection/devtools + # listing, which would otherwise flood traces. + value = await self._dap_fn() + self.set_value(value) + return value + else: + # Run through the action so the fetch is wrapped in a trace span. + # The action body calls _dap_fn() and set_value(). + try: + await self.action.run() + except GenkitError as e: + raise e.__cause__ or e + if self._value is None: + raise ValueError('DAP value undefined after action run') + return self._value except Exception: self.invalidate_cache() raise @@ -140,8 +193,14 @@ def define_dynamic_action_provider( ) -> DynamicActionProvider: """Define and register a Dynamic Action Provider for lazy action resolution.""" - async def dap_action(input: DapMetadata) -> DapMetadata: - return input + # Filled in immediately after construction; the closure reads it when run() fires. + _dap: DynamicActionProvider | None = None + + async def dap_action(input: None) -> list[dict[str, Any]]: + value = await fn() + assert _dap is not None + _dap.set_value(value) + return _transform_dap_value(value, _dap.action.name) action = registry.register_action( name=name, @@ -151,4 +210,7 @@ async def dap_action(input: DapMetadata) -> DapMetadata: metadata={**(metadata or {}), 'type': 'dynamic-action-provider'}, ) - return DynamicActionProvider(action, fn, cache_ttl_millis) + dap = DynamicActionProvider(action, fn, cache_ttl_millis) + _dap = dap + setattr(action, GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, dap) + return dap diff --git a/py/packages/genkit/src/genkit/_core/_reflection.py b/py/packages/genkit/src/genkit/_core/_reflection.py index 130f3ad8f8..07cb07580f 100644 --- a/py/packages/genkit/src/genkit/_core/_reflection.py +++ b/py/packages/genkit/src/genkit/_core/_reflection.py @@ -36,7 +36,7 @@ from starlette.responses import JSONResponse, Response, StreamingResponse from starlette.routing import Route -from genkit._core._action import Action, ActionKind +from genkit._core._action import Action from genkit._core._constants import GENKIT_VERSION from genkit._core._error import get_reflection_json from genkit._core._logger import get_logger @@ -138,53 +138,6 @@ async def gen() -> AsyncGenerator[str, None]: return StreamingResponse(gen(), media_type='text/plain' if self.stream else 'application/json', headers=headers) -async def _get_actions_payload(registry: Registry) -> dict[str, dict[str, Any]]: - actions: dict[str, dict[str, Any]] = {} - - for kind in ActionKind.__members__.values(): - for name, action in (await registry.resolve_actions_by_kind(kind)).items(): - key = f'/{kind}/{name}' - actions[key] = { - 'key': key, - 'name': action.name, - 'type': action.kind, - 'description': action.description, - 'inputSchema': action.input_schema, - 'outputSchema': action.output_schema, - 'metadata': action.metadata, - } - - for meta in await registry.list_actions() or []: - try: - key = f'/{meta.kind}/{meta.name}' - except Exception as exc: - logger.warning('Skipping invalid plugin metadata: %s', exc) - continue - - advertised = { - 'key': key, - 'name': meta.name, - 'type': meta.kind, - 'description': getattr(meta, 'description', None), - 'inputSchema': getattr(meta, 'input_json_schema', None), - 'outputSchema': getattr(meta, 'output_json_schema', None), - 'metadata': getattr(meta, 'metadata', None), - } - - if key not in actions: - actions[key] = advertised - else: - existing = actions[key] - for f in ('description', 'inputSchema', 'outputSchema'): - if not existing.get(f) and advertised.get(f): - existing[f] = advertised[f] - if isinstance(existing.get('metadata'), dict) and isinstance(advertised.get('metadata'), dict): - # isinstance checks above guarantee both are dicts, but ty can't narrow .get() results - existing['metadata'] = {**advertised['metadata'], **existing['metadata']} # ty: ignore[invalid-argument-type] - - return actions - - def create_reflection_asgi_app( registry: Registry, on_startup: LifecycleHook | None = None, @@ -194,6 +147,7 @@ def create_reflection_asgi_app( active_actions: dict[str, asyncio.Task[Any]] = {} async def health(_: Request) -> JSONResponse: + await registry.list_actions() return JSONResponse({'status': 'OK'}) async def terminate(_: Request) -> JSONResponse: @@ -202,7 +156,8 @@ async def terminate(_: Request) -> JSONResponse: return JSONResponse({'status': 'OK'}) async def actions(_: Request) -> JSONResponse: - return JSONResponse(await _get_actions_payload(registry), headers={'x-genkit-version': version}) + # Full catalog: list_resolvable_actions (plugins, registered, DAP; merged with parent). + return JSONResponse(await registry.list_resolvable_actions(), headers={'x-genkit-version': version}) async def values(req: Request) -> JSONResponse: if req.query_params.get('type') != 'defaultModel': @@ -238,6 +193,14 @@ async def run(req: Request) -> Response: ) return await runner.stream_response(version) + async def reflection_startup() -> None: + # Eagerly initialize plugins and enumerate concrete actions before handling traffic. + await registry.list_actions() + + startup_hooks: list[LifecycleHook] = [reflection_startup] + if on_startup is not None: + startup_hooks.append(on_startup) + app = Starlette( routes=[ Route('/api/__health', health, methods=['GET']), @@ -258,7 +221,7 @@ async def run(req: Request) -> Response: expose_headers=['X-Genkit-Trace-Id', 'X-Genkit-Span-Id', 'x-genkit-version'], ) ], - on_startup=[on_startup] if on_startup else [], + on_startup=startup_hooks, on_shutdown=[on_shutdown] if on_shutdown else [], ) app.active_actions = active_actions # type: ignore[attr-defined] diff --git a/py/packages/genkit/src/genkit/_core/_registry.py b/py/packages/genkit/src/genkit/_core/_registry.py index b4731900be..428d3408a8 100644 --- a/py/packages/genkit/src/genkit/_core/_registry.py +++ b/py/packages/genkit/src/genkit/_core/_registry.py @@ -16,23 +16,29 @@ """Registry for managing Genkit resources and actions.""" +from __future__ import annotations + import asyncio import threading -from collections.abc import Awaitable, Callable -from typing import cast +from collections.abc import Awaitable, Callable, Mapping +from typing import Any, cast from dotpromptz.dotprompt import Dotprompt from pydantic import BaseModel from typing_extensions import Never, TypeVar from genkit._core._action import ( + GENKIT_DAP_QUALIFIED_KEY_ATTR, + GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, Action, ActionKind, ActionMetadata, ActionName, ActionRunContext, SpanAttributeValue, + create_action_key, parse_action_key, + parse_dap_qualified_name, ) from genkit._core._logger import get_logger from genkit._core._model import ( @@ -76,6 +82,45 @@ ) +def _reflection_payload_for_registered_action(action: Action) -> dict[str, Any]: + dap_key = getattr(action, GENKIT_DAP_QUALIFIED_KEY_ATTR, None) + key = dap_key if isinstance(dap_key, str) else create_action_key(action.kind, action.name) + return { + 'key': key, + 'name': action.name, + 'type': action.kind, + 'description': action.description, + 'inputSchema': action.input_schema, + 'outputSchema': action.output_schema, + 'metadata': action.metadata, + } + + +def _reflection_payload_for_plugin_metadata(meta: ActionMetadata) -> dict[str, Any]: + key = f'/{meta.kind}/{meta.name}' + return { + 'key': key, + 'name': meta.name, + 'type': meta.kind, + 'description': meta.description, + 'inputSchema': meta.input_json_schema, + 'outputSchema': meta.output_json_schema, + 'metadata': meta.metadata, + } + + +def _reflection_payload_for_dap_metadata(full_key: str, meta: Mapping[str, object]) -> dict[str, Any]: + return { + 'key': full_key, + 'name': meta.get('name'), + 'type': meta.get('type'), + 'description': meta.get('description'), + 'inputSchema': meta.get('inputSchema') or meta.get('input_json_schema'), + 'outputSchema': meta.get('outputSchema') or meta.get('output_json_schema'), + 'metadata': dict(meta), + } + + class Registry: """Central repository for Genkit resources. @@ -84,6 +129,11 @@ class Registry: plugins, and schemas. It provides methods for registering new resources and looking them up at runtime. + Supports a **child registry** pattern (see ``new_child``): a child registry + delegates lookups to its parent when a name is not found locally. This is + used to create cheap, ephemeral registries scoped to a single generate call + (for DAP-resolved tools) without polluting the root registry. + This class is thread-safe and can be safely shared between multiple threads. Attributes: @@ -91,10 +141,17 @@ class Registry: action names and their corresponding Action instances. """ - default_model: str | None = None + def __init__(self, parent: Registry | None = None) -> None: + """Initialize a Registry instance. - def __init__(self) -> None: - """Initialize an empty Registry instance.""" + Args: + parent: Optional parent registry. When provided this is a *child* + registry that falls back to the parent for any lookup that + returns ``None`` locally. Use ``new_child()`` as the + preferred factory rather than passing ``parent`` directly. + """ + self._parent: Registry | None = parent + self._default_model: str | None = None self._entries: ActionStore = {} self._value_by_kind_and_name: dict[str, dict[str, object]] = {} self._schemas_by_name: dict[str, dict[str, object]] = {} @@ -106,12 +163,15 @@ def __init__(self) -> None: # https://github.com/firebase/genkit/issues/4491). self._loading_actions: set[str] = set() - # Initialize Dotprompt with schema_resolver to match JS SDK pattern - # Use async function to avoid thread pool deadlock in resolve_json_schema + # Dotprompt resolves ``output.schema`` names via the registry's stored schemas. + # Async resolver avoids thread-pool deadlock in ``resolve_json_schema``. async def async_schema_resolver(name: str) -> dict[str, object] | str: return self.lookup_schema(name) or name - self.dotprompt: Dotprompt = Dotprompt(schema_resolver=async_schema_resolver) + # Children share the parent's Dotprompt instance (prompts are global). + self.dotprompt: Dotprompt = ( + parent.dotprompt if parent is not None else Dotprompt(schema_resolver=async_schema_resolver) + ) # TODO(#4352): Figure out how to set this. self.api_stability: str = 'stable' @@ -127,6 +187,45 @@ async def async_schema_resolver(name: str) -> dict[str, object] | str: # reflection server schedules coroutines onto that loop.) self._plugins: dict[str, Plugin] = {} self._plugin_init_tasks: dict[str, asyncio.Task[None]] = {} + self._all_plugins_initialized: bool = False + + # ------------------------------------------------------------------------- + # Child registry support + # ------------------------------------------------------------------------- + + def new_child(self) -> Registry: + """Create a cheap child registry that inherits from this registry. + + Child registries are used to create short-lived, ephemeral scopes (e.g. + per-generate-call tool registrations from a DAP) without polluting the + root registry. Any lookup that fails locally falls back to this parent. + Writes on the child never propagate back to the parent. + + Returns: + A new ``Registry`` whose parent is ``self``. + """ + return Registry(parent=self) + + @property + def parent(self) -> Registry | None: + """The parent registry, or ``None`` if this is a root registry.""" + return self._parent + + @property + def is_child(self) -> bool: + """``True`` if this registry has a parent.""" + return self._parent is not None + + @property + def default_model(self) -> str | None: + """The default model name, falling back to parent if not set locally.""" + if self._default_model is not None: + return self._default_model + return self._parent.default_model if self._parent is not None else None + + @default_model.setter + def default_model(self, value: str | None) -> None: + self._default_model = value def register_action( self, @@ -204,6 +303,118 @@ async def resolve_actions_by_kind(self, kind: ActionKind) -> dict[str, Action]: await self._trigger_lazy_loading(action) return actions + async def list_actions(self) -> dict[str, Action]: + """Return every concrete :class:`Action` in ``_entries``, keyed by ``//``. + + Ensures plugins are initialized first so ``init()``-registered actions appear. + Merges with the parent registry when present; on duplicate keys the child wins. + For advertised-only and DAP-expanded metadata (reflection catalog), use + :meth:`list_resolvable_actions`. + + Returns: + Map of action key string to :class:`Action` instance. + """ + await self.initialize_all_plugins() + local: dict[str, Action] = {} + for kind in ActionKind.__members__.values(): + for name, action in (await self.resolve_actions_by_kind(kind)).items(): + local[create_action_key(kind, name)] = action + if self._parent is None: + return local + parent_actions = await self._parent.list_actions() + return {**parent_actions, **local} + + async def list_resolvable_actions(self) -> dict[str, dict[str, Any]]: + """Return reflection metadata for plugins, registered actions, and DAP-expanded tools. + + Builds plugin rows from each plugin's ``list_actions()``, then fills registered + actions and DAP expansions via :meth:`list_actions` (which initializes plugins). + Merges with parent's list_resolvable_actions() output (prefer child entries on duplicate keys). + + Returns: + Map of action key to reflection-style payload dicts (``key``, ``name``, ``type``, etc.). + """ + local: dict[str, dict[str, Any]] = {} + + with self._lock: + plugins = list(self._plugins.items()) + for plugin_name, plugin in plugins: + try: + plugin_metas = await plugin.list_actions() + except Exception: + logger.exception('Error listing actions for plugin %s', plugin_name) + continue + for meta in plugin_metas or []: + if not meta.name: + raise ValueError(f'Invalid ActionMetadata from {plugin_name}: name required') + if '/' not in meta.name: + meta = meta.model_copy(update={'name': f'{plugin_name}/{meta.name}'}) + key = f'/{meta.kind}/{meta.name}' + local[key] = _reflection_payload_for_plugin_metadata(meta) + + actions_dict = await self.list_actions() + actions = actions_dict.items() + for key, action in actions: + local[key] = _reflection_payload_for_registered_action(action) + dap = getattr(action, GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, None) + if dap is None: + continue + try: + # Record keys use the provider action ``name``; see + # :meth:`DynamicActionProvider.get_action_metadata_record`. + record = await dap.get_action_metadata_record(action.name) + except Exception: + logger.exception( + 'Error listing actions for Dynamic Action Provider %s', + action.name, + ) + continue + for record_key, meta in record.items(): + full_key = create_action_key(ActionKind.DYNAMIC_ACTION_PROVIDER, record_key) + local[full_key] = _reflection_payload_for_dap_metadata(full_key, meta) + parts = parse_dap_qualified_name(record_key) + if parts is None: + continue + _provider, inner_kind_str, inner_name = parts + try: + inner_kind = ActionKind(inner_kind_str) + except ValueError: + logger.debug( + "Unrecognized ActionKind '%s' in DAP record key '%s' from provider '%s'", + inner_kind_str, + record_key, + action.name, + ) + continue + + canonical = create_action_key(inner_kind_str, inner_name) + if canonical in local: + continue + try: + nested = await dap.get_action(inner_kind_str, inner_name) + except Exception as e: + logger.debug( + 'DAP %s failed resolving nested action %s/%s for canonical catalog row', + action.name, + inner_kind_str, + inner_name, + exc_info=e, + ) + nested = None + if nested is not None: + local[canonical] = _reflection_payload_for_registered_action(nested) + else: + canon_payload = dict(_reflection_payload_for_dap_metadata(full_key, meta)) + canon_payload['key'] = canonical + canon_payload['name'] = inner_name + canon_payload['type'] = inner_kind + local[canonical] = canon_payload + + if self._parent is None: + return local + parent_resolvable = await self._parent.list_resolvable_actions() + return {**parent_resolvable, **local} + def register_value(self, kind: str, name: str, value: object) -> None: """Registers a value with a given kind and name. @@ -238,10 +449,13 @@ def lookup_value(self, kind: str, name: str) -> object | None: name: The name of the value (e.g., "json", "text"). Returns: - The value or None if not found. + The value or None if not found. Falls back to parent registry. """ with self._lock: - return self._value_by_kind_and_name.get(kind, {}).get(name) + local = self._value_by_kind_and_name.get(kind, {}).get(name) + if local is not None: + return local + return self._parent.lookup_value(kind, name) if self._parent is not None else None def list_values(self, kind: str) -> list[str]: """List all values registered for a specific kind. @@ -270,6 +484,20 @@ def register_plugin(self, plugin: Plugin) -> None: if plugin.name in self._plugins: raise ValueError(f'Plugin {plugin.name} already registered') self._plugins[plugin.name] = plugin + self._all_plugins_initialized = False + + async def initialize_all_plugins(self) -> None: + """Run ``init()`` for every plugin on this registry exactly once (until a new plugin is registered). + + Used before enumerating registered actions so plugin-registered entries exist in ``_entries``. + """ + if self._all_plugins_initialized: + return + with self._lock: + plugin_names = list(self._plugins.keys()) + for name in plugin_names: + await self._ensure_plugin_initialized(name) + self._all_plugins_initialized = True async def _ensure_plugin_initialized(self, plugin_name: str) -> None: """Ensure a plugin is initialized exactly once. @@ -359,163 +587,161 @@ async def _trigger_lazy_loading(self, action: Action | None) -> Action | None: self._loading_actions.discard(action_id) return action + async def _resolve_dap_qualified_action(self, kind: ActionKind, name: str) -> Action | None: + """Resolve through the one registered DAP for ``provider:innerKind/innerName`` names. + + Caller must ensure :func:`parse_dap_qualified_name` accepts ``name``. Does not consult + plugins. Returns ``None`` if the provider is not registered here (caller may delegate + to a parent registry). + """ + qualified = parse_dap_qualified_name(name) + if qualified is None: + return None + dap_host = qualified.provider + with self._lock: + provider = self._entries.get(ActionKind.DYNAMIC_ACTION_PROVIDER, {}).get(dap_host) + if provider is None: + return None + dap_action = await self._trigger_lazy_loading(provider) + if dap_action is None: + raise RuntimeError( + f'Dynamic action provider {dap_host!r} is not registered. ' + 'DAPs must be registered using define_dynamic_action_provider ' + 'before referencing qualified action names.' + ) + dap = getattr(dap_action, GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, None) + if dap is not None: + try: + resolved = await dap.get_action(qualified.inner_kind, qualified.inner_name) + except Exception as e: + raise ValueError(f'Dynamic action provider {dap_host!r} get_action failed for {kind} {name!r}') from e + if resolved is not None and resolved.kind == kind: + return resolved + if resolved is None: + raise ValueError( + f'Dynamic action provider {dap_host!r} has no action ' + f'{qualified.inner_kind!r}/{qualified.inner_name!r} for {name!r}' + ) + raise ValueError( + f'Dynamic action provider {dap_host!r} returned {resolved.kind!r} for {name!r}, expected {kind!r}' + ) + try: + response = await dap_action.run({'kind': kind, 'name': name}) + if response.response: + self.register_action_instance(response.response) + return await self._trigger_lazy_loading(response.response) + except Exception as e: + logger.debug( + f'Dynamic action provider {dap_host} failed for {kind}/{name}', + exc_info=e, + ) + return None + async def resolve_action(self, kind: ActionKind, name: str) -> Action | None: - """Resolve an action by kind and name, supporting both prefixed and unprefixed names. + """Resolve an action by kind and name. - This method supports: - 1. Cache hit: Returns immediately if action is already registered - 2. Namespaced request (e.g., "plugin/model"): Resolves via specific plugin - 3. Unprefixed request (e.g., "model"): Tries all plugins, errors on ambiguity - 4. Dynamic action providers: Last-resort fallback for dynamic action creation + Tries an exact (kind, name) cache hit first. DAP-qualified names + (provider:innerKind/innerName) go through that provider only. If the name contains a + slash, the first segment is treated as a plugin id: that plugin is initialized and + plugin.resolve is used. Falls back to parent registry if nothing found. Args: kind: The type of action to resolve. - name: The name of the action (may be prefixed with "plugin/" or unprefixed). + name: Action name, optionally plugin/... for a specific plugin. Returns: The Action instance if found, None otherwise. - - Raises: - ValueError: If an unprefixed name matches multiple plugins (ambiguous). """ - # Cache hit with self._lock: if kind in self._entries and name in self._entries[kind]: return await self._trigger_lazy_loading(self._entries[kind][name]) - action: Action | None = None + # DAP-qualified names: resolve via that provider only (not plugin slash splitting). + if kind != ActionKind.DYNAMIC_ACTION_PROVIDER and parse_dap_qualified_name(name) is not None: + action = await self._resolve_dap_qualified_action(kind, name) + if action is not None: + return action + if self._parent is not None: + return await self._parent.resolve_action(kind, name) + return None - # Namespaced request + # /: resolve that plugin. if '/' in name: - plugin_name, local = name.split('/', 1) + plugin_name, action_name = name.split('/', 1) with self._lock: plugin = self._plugins.get(plugin_name) if plugin is not None: await self._ensure_plugin_initialized(plugin_name) - target = f'{plugin_name}/{local}' # normalized + target = f'{plugin_name}/{action_name}' # normalized # Check cache again after init - init() might have registered this action with self._lock: if kind in self._entries and target in self._entries[kind]: return await self._trigger_lazy_loading(self._entries[kind][target]) + # On-demand resolution: target may not have been in init()'s registered set. action = await plugin.resolve(kind, target) if action is not None: self.register_action_instance(action, namespace=plugin_name) with self._lock: return await self._trigger_lazy_loading(self._entries.get(kind, {}).get(target)) - else: - # Unprefixed request: try all plugins - successes: list[tuple[str, Action]] = [] - with self._lock: - plugins = list(self._plugins.items()) - for plugin_name, plugin in plugins: - await self._ensure_plugin_initialized(plugin_name) - target = f'{plugin_name}/{name}' - - # Check cache first - init() might have registered this action - with self._lock: - cached_action = self._entries.get(kind, {}).get(target) - if cached_action is not None: - successes.append((plugin_name, cached_action)) - continue - - action = await plugin.resolve(kind, target) - if action is not None: - successes.append((plugin_name, action)) - if len(successes) > 1: - plugin_names = [p for p, _ in successes] - raise ValueError( - f"Ambiguous {kind} action name '{name}'. " - + f"Matches plugins: {plugin_names}. Use 'plugin/{name}'." - ) - - if len(successes) == 1: - plugin_name, action = successes[0] - self.register_action_instance(action, namespace=plugin_name) - with self._lock: - return await self._trigger_lazy_loading(self._entries.get(kind, {}).get(f'{plugin_name}/{name}')) - - # Fallback: try dynamic action providers (for MCP, dynamic resources, etc.) - # Skip if we're looking up a dynamic action provider itself to avoid recursion - if kind != ActionKind.DYNAMIC_ACTION_PROVIDER: - with self._lock: - if ActionKind.DYNAMIC_ACTION_PROVIDER in self._entries: - providers_dict = self._entries[ActionKind.DYNAMIC_ACTION_PROVIDER] - else: - providers_dict = {} - providers = list(providers_dict.values()) - for provider in providers: - try: - response = await provider.run({'kind': kind, 'name': name}) - if response.response: - self.register_action_instance(response.response) - return await self._trigger_lazy_loading(response.response) - except Exception as e: - logger.debug( - f'Dynamic action provider {provider.name} failed for {kind}/{name}', - exc_info=e, - ) - continue + # Final fallback: delegate to parent registry. + if self._parent is not None: + return await self._parent.resolve_action(kind, name) return None async def resolve_action_by_key(self, key: str) -> Action | None: """Resolve an action using its combined key string. - The key format is `/`, where kind must be a valid - `ActionKind` and name may be prefixed with plugin namespace or unprefixed. + The key format is ``//``, where kind must be a valid + ``ActionKind`` and name may be prefixed with plugin namespace or + unprefixed. + + For nested actions exposed by a dynamic action provider, use + ``/dynamic-action-provider/:/`` (for + example ``/dynamic-action-provider/my-mcp:tool/echo``). Args: - key: The action key in the format `/`. + key: The action key in the format ``//``. Returns: - The `Action` instance if found, None otherwise. + The ``Action`` instance if found, None otherwise. Raises: - ValueError: If the key format is invalid, the kind is not a valid - `ActionKind`, or an unprefixed name is ambiguous. + ValueError: If the key format is invalid or the kind is not a valid + ``ActionKind``. """ kind, name = parse_action_key(key) + if kind == ActionKind.DYNAMIC_ACTION_PROVIDER: + dap_parts = parse_dap_qualified_name(name) + if dap_parts is not None: + provider_action = await self.resolve_action( + ActionKind.DYNAMIC_ACTION_PROVIDER, + dap_parts.provider, + ) + if provider_action is None: + return None + dap = getattr(provider_action, GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, None) + if dap is None: + return None + try: + resolved = await dap.get_action(dap_parts.inner_kind, dap_parts.inner_name) + except Exception as e: + logger.debug( + f'Dynamic action provider {dap_parts.provider} failed for ' + f'qualified key {dap_parts.inner_kind}/{dap_parts.inner_name}', + exc_info=e, + ) + return None + if resolved is None: + return None + return resolved return await self.resolve_action(kind, name) - async def list_actions(self, allowed_kinds: list[ActionKind] | None = None) -> list[ActionMetadata]: - """List all actions advertised by plugins. - - This method returns the advertised set of actions from all registered - plugins. It does NOT trigger plugin initialization and does NOT consult - the registry's internal action store. - - Args: - allowed_kinds: Optional list of action kinds to filter by. - - Returns: - A list of ActionMetadata objects describing available actions. - - Raises: - ValueError: If a plugin returns invalid ActionMetadata. - """ - metas: list[ActionMetadata] = [] - with self._lock: - plugins = list(self._plugins.items()) - for plugin_name, plugin in plugins: - plugin_metas = await plugin.list_actions() - for meta in plugin_metas or []: - if not meta.name: - raise ValueError(f'Invalid ActionMetadata from {plugin_name}: name required') - - # Normalize metadata name - if '/' not in meta.name: - meta = meta.model_copy(update={'name': f'{plugin_name}/{meta.name}'}) - - if allowed_kinds and meta.kind not in allowed_kinds: - continue - metas.append(meta) - return metas - def register_schema(self, name: str, schema: dict[str, object], schema_type: type[BaseModel] | None = None) -> None: """Registers a schema by name. @@ -545,10 +771,13 @@ def lookup_schema(self, name: str) -> dict[str, object] | None: name: The name of the schema to look up. Returns: - The schema data if found, None otherwise. + The schema data if found, None otherwise. Falls back to parent. """ with self._lock: - return self._schemas_by_name.get(name) + local = self._schemas_by_name.get(name) + if local is not None: + return local + return self._parent.lookup_schema(name) if self._parent is not None else None def lookup_schema_type(self, name: str) -> type[BaseModel] | None: """Looks up a schema's Pydantic type by name. @@ -557,10 +786,13 @@ def lookup_schema_type(self, name: str) -> type[BaseModel] | None: name: The name of the schema to look up. Returns: - The Pydantic model class if found, None otherwise. + The Pydantic model class if found, None otherwise. Falls back to parent. """ with self._lock: - return self._schema_types_by_name.get(name) + local = self._schema_types_by_name.get(name) + if local is not None: + return local + return self._parent.lookup_schema_type(name) if self._parent is not None else None # ===== Typed Action Lookups ===== # diff --git a/py/packages/genkit/tests/genkit/ai/dap_test.py b/py/packages/genkit/tests/genkit/ai/dap_test.py index 02b650cd19..274855b5f7 100644 --- a/py/packages/genkit/tests/genkit/ai/dap_test.py +++ b/py/packages/genkit/tests/genkit/ai/dap_test.py @@ -361,6 +361,34 @@ async def dap_fn() -> DapValue: assert call_count == 1 +@pytest.mark.asyncio +async def test_dap_run_returns_child_metadata_rows(registry: Registry, tool1: Action) -> None: + async def dap_fn() -> DapValue: + return {'tool': [tool1]} + + dap = define_dynamic_action_provider(registry, 'my-dap', dap_fn) + ar = await dap.action.run() + rows = ar.response + assert isinstance(rows, list) + assert len(rows) == 1 + assert rows[0]['name'] == 'tool1' + assert rows[0]['type'] == 'tool' + assert rows[0]['key'] == '/dynamic-action-provider/my-dap:tool/tool1' + + +@pytest.mark.asyncio +async def test_child_actions_get_dap_qualified_key_attr(registry: Registry, tool1: Action) -> None: + from genkit._core._action import GENKIT_DAP_QUALIFIED_KEY_ATTR, ActionKind, create_action_key + + async def dap_fn() -> DapValue: + return {'tool': [tool1]} + + dap = define_dynamic_action_provider(registry, 'my-dap', dap_fn) + await dap.get_action('tool', 'tool1') + expected = create_action_key(ActionKind.DYNAMIC_ACTION_PROVIDER, 'my-dap:tool/tool1') + assert getattr(tool1, GENKIT_DAP_QUALIFIED_KEY_ATTR) == expected + + @pytest.mark.asyncio async def test_get_action_metadata_record_raises_on_missing_name(registry: Registry) -> None: async def nameless_fn(input: str) -> str: diff --git a/py/packages/genkit/tests/genkit/ai/dynamic_tools_generate_test.py b/py/packages/genkit/tests/genkit/ai/dynamic_tools_generate_test.py new file mode 100644 index 0000000000..0c9085ed7e --- /dev/null +++ b/py/packages/genkit/tests/genkit/ai/dynamic_tools_generate_test.py @@ -0,0 +1,303 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for DAP-backed tool resolution in the generate loop.""" + +import pytest +from pydantic import BaseModel + +from genkit import Genkit, Message, ModelResponse +from genkit._ai._generate import expand_wildcard_tools +from genkit._ai._testing import define_programmable_model +from genkit._core._action import Action, ActionKind +from genkit._core._dap import DapValue, define_dynamic_action_provider +from genkit._core._registry import Registry +from genkit._core._typing import ( + FinishReason, + Part, + Role, + TextPart, + ToolRequest, + ToolRequestPart, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _text_response(text: str) -> ModelResponse: + return ModelResponse( + message=Message(role=Role.MODEL, content=[Part(root=TextPart(text=text))]), + finish_reason=FinishReason.STOP, + ) + + +def _tool_call_response(tool_name: str, input: dict) -> ModelResponse: + return ModelResponse( + message=Message( + role=Role.MODEL, + content=[Part(root=ToolRequestPart(tool_request=ToolRequest(name=tool_name, input=input, ref=tool_name)))], + ), + finish_reason=FinishReason.STOP, + ) + + +# --------------------------------------------------------------------------- +# expand_wildcard_tools +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_expand_wildcard_all() -> None: + """'provider:tool/*' expands to all tools from the DAP.""" + registry = Registry() + + async def tool_fn(x: str) -> str: + return x + + t1 = registry.register_action(name='echo', kind=ActionKind.TOOL, fn=tool_fn, metadata={'name': 'echo'}) + t2 = registry.register_action(name='ping', kind=ActionKind.TOOL, fn=tool_fn, metadata={'name': 'ping'}) + + async def dap_fn() -> DapValue: + return {'tool': [t1, t2]} + + define_dynamic_action_provider(registry, 'mcp', dap_fn) + + result = await expand_wildcard_tools(registry, ['mcp:tool/*']) + assert sorted(result) == [ + '/dynamic-action-provider/mcp:tool/echo', + '/dynamic-action-provider/mcp:tool/ping', + ] + + +@pytest.mark.asyncio +async def test_expand_wildcard_prefix() -> None: + """'provider:tool/prefix*' expands only matching tools.""" + registry = Registry() + + async def tool_fn(x: str) -> str: + return x + + t1 = registry.register_action( + name='get_weather', kind=ActionKind.TOOL, fn=tool_fn, metadata={'name': 'get_weather'} + ) + t2 = registry.register_action(name='get_time', kind=ActionKind.TOOL, fn=tool_fn, metadata={'name': 'get_time'}) + t3 = registry.register_action(name='set_alarm', kind=ActionKind.TOOL, fn=tool_fn, metadata={'name': 'set_alarm'}) + + async def dap_fn() -> DapValue: + return {'tool': [t1, t2, t3]} + + define_dynamic_action_provider(registry, 'mcp', dap_fn) + + result = await expand_wildcard_tools(registry, ['mcp:tool/get_*']) + assert sorted(result) == [ + '/dynamic-action-provider/mcp:tool/get_time', + '/dynamic-action-provider/mcp:tool/get_weather', + ] + + +@pytest.mark.asyncio +async def test_non_wildcard_names_pass_through() -> None: + """Non-wildcard names are returned unchanged.""" + registry = Registry() + result = await expand_wildcard_tools(registry, ['my_tool', 'other_tool']) + assert result == ['my_tool', 'other_tool'] + + +# --------------------------------------------------------------------------- +# DAP tools resolved inside generate loop +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_dap_tool_resolved_in_generate() -> None: + """generate resolves and runs a tool that is only advertised via a DAP (never register_action).""" + ai = Genkit() + pm, _ = define_programmable_model(ai) + + call_log: list[str] = [] + + class EchoInput(BaseModel): + text: str + + async def echo_fn(inp: EchoInput) -> str: + call_log.append(inp.text) + return f'echoed:{inp.text}' + + # Detached Action — only returned from the DAP; not registered on the root registry. + echo_action = Action( + name='echo', + kind=ActionKind.TOOL, + fn=echo_fn, + metadata={'name': 'echo'}, + ) + + async def dap_fn() -> DapValue: + return {'tool': [echo_action]} + + ai.define_dynamic_action_provider('mcp', dap_fn) + + # Precondition: `echo` is not a normal root TOOL registration (only in the DAP). + assert 'echo' not in ai.registry._entries.get(ActionKind.TOOL, {}) + + pm.responses = [ + _tool_call_response('echo', {'text': 'hello'}), + _text_response('done'), + ] + + response = await ai.generate( + model='programmableModel', + prompt='use echo', + tools=['mcp:tool/echo'], + ) + + assert response.text == 'done' + assert call_log == ['hello'] + # Postcondition: resolving/running the tool via the child registry + DAP still does not + # persist `echo` under the root registry as a static tool (same check as above). + assert 'echo' not in ai.registry._entries.get(ActionKind.TOOL, {}) + + +@pytest.mark.asyncio +async def test_dap_tools_do_not_pollute_root_registry() -> None: + """After generate, DAP-resolved tools are not cached in the root registry.""" + ai = Genkit() + pm, _ = define_programmable_model(ai) + + class Inp(BaseModel): + x: str + + async def tool_fn(inp: Inp) -> str: + return inp.x + + # Create an Action directly — NOT registered in root via register_action + dap_only_action = Action(name='dap_only_tool', kind=ActionKind.TOOL, fn=tool_fn, metadata={'name': 'dap_only_tool'}) + + async def dap_fn() -> DapValue: + return {'tool': [dap_only_action]} + + ai.define_dynamic_action_provider('mcp', dap_fn) + + pm.responses = [_text_response('no tools called')] + + await ai.generate( + model='programmableModel', + prompt='hi', + tools=['mcp:tool/dap_only_tool'], + ) + + # Root registry should NOT have dap_only_tool cached — it was never registered there + root_tools = ai.registry._entries.get(ActionKind.TOOL, {}) + assert 'dap_only_tool' not in root_tools + + +@pytest.mark.asyncio +async def test_wildcard_tools_in_generate() -> None: + """Wildcard tool pattern is expanded before generate resolves tools.""" + ai = Genkit() + pm, _ = define_programmable_model(ai) + + call_log: list[str] = [] + + class InpA(BaseModel): + x: str + + class InpB(BaseModel): + x: str + + async def tool_a_fn(inp: InpA) -> str: + call_log.append(f'a:{inp.x}') + return f'a:{inp.x}' + + async def tool_b_fn(inp: InpB) -> str: + call_log.append(f'b:{inp.x}') + return f'b:{inp.x}' + + tool_a = ai.registry.register_action(name='tool_a', kind=ActionKind.TOOL, fn=tool_a_fn, metadata={'name': 'tool_a'}) + tool_b = ai.registry.register_action(name='tool_b', kind=ActionKind.TOOL, fn=tool_b_fn, metadata={'name': 'tool_b'}) + + async def dap_fn() -> DapValue: + return {'tool': [tool_a, tool_b]} + + ai.define_dynamic_action_provider('mcp', dap_fn) + + pm.responses = [ + _tool_call_response('tool_a', {'x': 'hi'}), + _text_response('finished'), + ] + + response = await ai.generate( + model='programmableModel', + prompt='use a tool', + tools=['mcp:tool/*'], + ) + + assert response.text == 'finished' + assert call_log == ['a:hi'] + + +@pytest.mark.asyncio +async def test_wildcard_tools_avoids_shadowing_conflict() -> None: + """Explicit wildcard provider paths should not be shadowed by earlier providers.""" + ai = Genkit() + pm, _ = define_programmable_model(ai) + + call_log: list[str] = [] + + class Inp(BaseModel): + x: str + + async def echo1_fn(inp: Inp) -> str: + call_log.append('mcp1') + return 'echo 1' + + async def echo2_fn(inp: Inp) -> str: + call_log.append('mcp2') + return 'echo 2' + + # Detached Actions (not registered in root registry directly) + echo1_action = Action(name='echo', kind=ActionKind.TOOL, fn=echo1_fn, metadata={'name': 'echo'}) + echo2_action = Action(name='echo', kind=ActionKind.TOOL, fn=echo2_fn, metadata={'name': 'echo'}) + + async def dap1_fn() -> DapValue: + return {'tool': [echo1_action]} + + async def dap2_fn() -> DapValue: + return {'tool': [echo2_action]} + + # Register mcp1 first. If resolution falls back to an unqualified lookup, mcp1 will "win". + ai.define_dynamic_action_provider('mcp1', dap1_fn) + ai.define_dynamic_action_provider('mcp2', dap2_fn) + + # The model calls the 'echo' tool + pm.responses = [ + _tool_call_response('echo', {'x': 'hello'}), + _text_response('finished'), + ] + + response = await ai.generate( + model='programmableModel', + prompt='use echo', + # Crucially, we explicitly request tools from mcp2 ONLY + tools=['mcp2:tool/*'], + ) + + assert response.text == 'finished' + + # If the bug is present, this will fail because it will fall back to the unqualified + # global loop and find mcp1's 'echo' tool instead. + assert call_log == ['mcp2'] diff --git a/py/packages/genkit/tests/genkit/ai/resource_integration_test.py b/py/packages/genkit/tests/genkit/ai/resource_integration_test.py index eedf3892b5..63d00d3ffe 100644 --- a/py/packages/genkit/tests/genkit/ai/resource_integration_test.py +++ b/py/packages/genkit/tests/genkit/ai/resource_integration_test.py @@ -75,29 +75,28 @@ async def test_dynamic_action_provider_resource() -> None: """Test dynamic action provider with resources.""" registry = Registry() - # Register a dynamic provider that handles any "dynamic://*" uri + # Register a dynamic provider that handles any "dynamic://*" uri (DAP-qualified: provider:resource/). async def provider_fn(input: dict[str, object], ctx: ActionRunContext) -> object: kind = cast(ActionKind, input['kind']) name = cast(str, input['name']) - if kind == ActionKind.RESOURCE and name.startswith('dynamic://'): + if kind != ActionKind.RESOURCE: + return None + # DAP-qualified resource refs only: ``provider:resource/`` (see run() full name). + if ':resource/' not in name: + return None + inner_uri = name.split(':resource/', 1)[1] + if not inner_uri.startswith('dynamic://'): + return None - async def dyn_res_fn(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: - return ResourceOutput(content=[Part(root=TextPart(text=f'Dynamic content for {input.uri}'))]) + async def dyn_res_fn(input: ResourceInput, ctx: ActionRunContext) -> ResourceOutput: + return ResourceOutput(content=[Part(root=TextPart(text=f'Dynamic content for {input.uri}'))]) - return resource({'uri': name}, dyn_res_fn) - return None + return resource({'uri': inner_uri}, dyn_res_fn) - # Register the provider as an action (it effectively acts as a factory) - # Note: Accessing internal structure for test setup as register_action expects specific signature - # But we want to register it under DYNAMIC_ACTION_PROVIDER kind. registry.register_action(kind=ActionKind.DYNAMIC_ACTION_PROVIDER, name='test-provider', fn=provider_fn) - # Register mock model - # Register mock model async def mock_model(input: ModelRequest, ctx: ActionRunContext) -> ModelResponse: - # Verify docs are empty assert not input.docs - # Verify dynamic hydration assert input.messages[0].content[0].root.text == 'Dynamic content for dynamic://bar' return ModelResponse(message=Message(role=Role.MODEL, content=[Part(root=TextPart(text='Done'))])) @@ -106,7 +105,7 @@ async def mock_model(input: ModelRequest, ctx: ActionRunContext) -> ModelRespons options = GenerateActionOptions( model='mock-model', messages=[Message(role=Role.USER, content=[Part(root=ResourcePart(resource=Resource1(uri='dynamic://bar')))])], - resources=['dynamic://bar'], + resources=['test-provider:resource/dynamic://bar'], ) response = await generate_action(registry, options) diff --git a/py/packages/genkit/tests/genkit/core/action_test.py b/py/packages/genkit/tests/genkit/core/action_test.py index 79610a4d04..3adfae42c0 100644 --- a/py/packages/genkit/tests/genkit/core/action_test.py +++ b/py/packages/genkit/tests/genkit/core/action_test.py @@ -14,8 +14,10 @@ Action, ActionKind, ActionRunContext, + DapQualifiedName, create_action_key, parse_action_key, + parse_dap_qualified_name, parse_plugin_name_from_action_name, ) from genkit._core._error import GenkitError @@ -72,6 +74,15 @@ def test_parse_action_key_invalid_format() -> None: parse_action_key(key) +def test_parse_dap_qualified_name() -> None: + """Parse provider:innerKind/innerName segments.""" + assert parse_dap_qualified_name('my-dap:tool/echo') == DapQualifiedName('my-dap', 'tool', 'echo') + assert parse_dap_qualified_name('plugin/foo:model/bar') is None + assert parse_dap_qualified_name('plain-name') is None + assert parse_dap_qualified_name('no-slash:toolonly') is None + assert parse_dap_qualified_name(':tool/x') is None + + def test_create_action_key() -> None: """Create action key.""" assert create_action_key(ActionKind.CUSTOM, 'foo') == '/custom/foo' diff --git a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py index 060d05f13f..160e77583c 100644 --- a/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py +++ b/py/packages/genkit/tests/genkit/core/endpoints/reflection_test.py @@ -45,7 +45,7 @@ import pytest_asyncio from httpx import ASGITransport, AsyncClient -from genkit._core._action import ActionKind, ActionMetadata +from genkit._core._action import ActionKind from genkit._core._reflection import create_reflection_asgi_app from genkit._core._registry import Registry @@ -66,6 +66,8 @@ async def asgi_client(mock_registry: MagicMock) -> AsyncIterator[AsyncClient]: Returns: An AsyncClient configured to make requests to the test ASGI app. """ + mock_registry.list_actions = AsyncMock(return_value={}) + mock_registry.list_resolvable_actions = AsyncMock(return_value={}) app = create_reflection_asgi_app(mock_registry) transport = ASGITransport(app=app) client = AsyncClient(transport=transport, base_url='http://test') @@ -86,21 +88,16 @@ async def test_health_check(asgi_client: AsyncClient) -> None: async def test_list_actions(asgi_client: AsyncClient, mock_registry: MagicMock) -> None: """Test that the actions list endpoint returns registered actions.""" - # Mock the async list_actions method to return a list of ActionMetadata - async def mock_list_actions_async(allowed_kinds: list[ActionKind] | None = None) -> list[ActionMetadata]: - return [ - ActionMetadata( - kind=ActionKind.CUSTOM, - name='action1', - ) - ] - - # Mock resolve_actions_by_kind to return empty dict (no registered actions in this test) - async def mock_resolve_actions_by_kind(kind: ActionKind) -> dict: - return {} - - mock_registry.list_actions = mock_list_actions_async - mock_registry.resolve_actions_by_kind = mock_resolve_actions_by_kind + async def mock_list_resolvable() -> dict[str, dict[str, object]]: + return { + '/custom/action1': { + 'key': '/custom/action1', + 'name': 'action1', + 'type': ActionKind.CUSTOM, + } + } + + mock_registry.list_resolvable_actions = mock_list_resolvable response = await asgi_client.get('/api/actions') assert response.status_code == 200 result = response.json() diff --git a/py/packages/genkit/tests/genkit/core/registry_test.py b/py/packages/genkit/tests/genkit/core/registry_test.py index 43cf4452f5..cac65b4aee 100644 --- a/py/packages/genkit/tests/genkit/core/registry_test.py +++ b/py/packages/genkit/tests/genkit/core/registry_test.py @@ -12,7 +12,8 @@ import pytest from genkit import Genkit, Plugin -from genkit._core._action import Action, ActionKind, ActionMetadata +from genkit._core._action import Action, ActionKind, ActionMetadata, create_action_key +from genkit._core._dap import DapValue, define_dynamic_action_provider from genkit._core._registry import Registry @@ -54,6 +55,54 @@ async def test_resolve_action_by_key_invalid_format() -> None: await registry.resolve_action_by_key('invalid_key') +@pytest.mark.asyncio +async def test_resolve_action_via_dynamic_action_provider() -> None: + """Registry resolves DAP tools only for DAP-qualified names (host:kind/name).""" + registry = Registry() + + async def tool_fn(x: str) -> str: + return x + + inner = Action( + name='inner-tool', + kind=ActionKind.TOOL, + fn=tool_fn, + metadata={'name': 'inner-tool'}, + ) + + async def dap_fn() -> DapValue: + return {'tool': [inner]} + + define_dynamic_action_provider(registry, 'my-dap', dap_fn) + + got = await registry.resolve_action(ActionKind.TOOL, 'my-dap:tool/inner-tool') + assert got is inner + + +@pytest.mark.asyncio +async def test_resolve_action_by_key_dap_qualified() -> None: + """DAP-qualified keys resolve nested actions.""" + registry = Registry() + + async def tool_fn(x: str) -> str: + return x + + inner = Action( + name='inner-tool', + kind=ActionKind.TOOL, + fn=tool_fn, + metadata={'name': 'inner-tool'}, + ) + + async def dap_fn() -> DapValue: + return {'tool': [inner]} + + define_dynamic_action_provider(registry, 'my-dap', dap_fn) + + got = await registry.resolve_action_by_key('/dynamic-action-provider/my-dap:tool/inner-tool') + assert got is inner + + @pytest.mark.asyncio async def test_resolve_action_from_plugin() -> None: """Resolve action from plugin test.""" @@ -79,8 +128,8 @@ async def list_actions(self) -> list[ActionMetadata]: ai = Genkit(plugins=[MyPlugin()]) - metas = await ai.registry.list_actions() - assert metas == [ActionMetadata(kind=ActionKind.MODEL, name='myplugin/foo')] + catalog = await ai.registry.list_resolvable_actions() + assert catalog['/model/myplugin/foo']['name'] == 'myplugin/foo' action = await ai.registry.resolve_action(ActionKind.MODEL, 'myplugin/foo') @@ -139,3 +188,249 @@ async def noop() -> None: assert resolved.name == 'self_ref' # Factory should have been called exactly once (re-entrant call skipped) assert call_count == 1 + + +# ============================================================================= +# Child registry tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_new_child_is_child() -> None: + """new_child() returns a child whose is_child is True.""" + parent = Registry() + child = parent.new_child() + assert child.is_child + assert not parent.is_child + assert child.parent is parent + + +@pytest.mark.asyncio +async def test_child_resolves_parent_action() -> None: + """Child registry falls back to parent for resolve_action.""" + parent = Registry() + action = parent.register_action(name='shared', kind=ActionKind.CUSTOM, fn=_identity) + + child = parent.new_child() + got = await child.resolve_action(ActionKind.CUSTOM, 'shared') + assert got is action + + +@pytest.mark.asyncio +async def test_child_action_does_not_pollute_parent() -> None: + """Actions registered on child are invisible to parent.""" + parent = Registry() + child = parent.new_child() + child.register_action(name='child_only', kind=ActionKind.CUSTOM, fn=_identity) + + assert await parent.resolve_action(ActionKind.CUSTOM, 'child_only') is None + assert await child.resolve_action(ActionKind.CUSTOM, 'child_only') is not None + + +@pytest.mark.asyncio +async def test_child_shadows_parent_action() -> None: + """Child action with the same name takes precedence over parent.""" + parent = Registry() + parent_action = parent.register_action(name='shared', kind=ActionKind.CUSTOM, fn=_identity) + + child = parent.new_child() + + async def child_fn(x: object) -> object: + return x + + child_action = child.register_action(name='shared', kind=ActionKind.CUSTOM, fn=child_fn) + + assert await child.resolve_action(ActionKind.CUSTOM, 'shared') is child_action + assert await parent.resolve_action(ActionKind.CUSTOM, 'shared') is parent_action + + +def test_child_inherits_default_model() -> None: + """Child inherits default_model from parent if not set locally.""" + parent = Registry() + parent.default_model = 'gemini-pro' + + child = parent.new_child() + assert child.default_model == 'gemini-pro' + + child.default_model = 'gemini-flash' + assert child.default_model == 'gemini-flash' + assert parent.default_model == 'gemini-pro' + + +def test_child_inherits_lookup_value() -> None: + """Child falls back to parent for lookup_value.""" + parent = Registry() + parent.register_value('format', 'json', {'json': True}) + + child = parent.new_child() + assert child.lookup_value('format', 'json') == {'json': True} + + # Local override shadows parent + child.register_value('format', 'json', {'json': False}) + assert child.lookup_value('format', 'json') == {'json': False} + assert parent.lookup_value('format', 'json') == {'json': True} + + +@pytest.mark.asyncio +async def test_child_resolvable_includes_parent_plugin() -> None: + """list_resolvable_actions on child includes parent plugin rows not shadowed locally.""" + + class ParentPlugin(Plugin): + name = 'parentplugin' + + async def init(self) -> list[Action]: + return [] + + async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + return None + + async def list_actions(self) -> list[ActionMetadata]: + return [ActionMetadata(kind=ActionKind.MODEL, name='my-model')] + + parent = Registry() + parent.register_plugin(ParentPlugin()) + + child = parent.new_child() + catalog = await child.list_resolvable_actions() + assert '/model/parentplugin/my-model' in catalog + assert catalog['/model/parentplugin/my-model']['name'] == 'parentplugin/my-model' + + +@pytest.mark.asyncio +async def test_child_resolvable_local_tool_shadows_parent_plugin_metadata() -> None: + """A tool registered on the child must not inherit parent plugin metadata for the same name.""" + + class ParentPlugin(Plugin): + name = 'parentplugin' + + async def init(self) -> list[Action]: + return [] + + async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + return None + + async def list_actions(self) -> list[ActionMetadata]: + return [ + ActionMetadata( + kind=ActionKind.TOOL, + name='parentplugin/shared-name', + description='from parent plugin', + ) + ] + + async def local_tool(_: str) -> str: + return 'local' + + parent = Registry() + parent.register_plugin(ParentPlugin()) + child = parent.new_child() + child.register_action( + kind=ActionKind.TOOL, + name='parentplugin/shared-name', + fn=local_tool, + description='from child registry', + ) + + catalog = await child.list_resolvable_actions() + entry = catalog['/tool/parentplugin/shared-name'] + assert entry['description'] == 'from child registry' + assert entry['description'] != 'from parent plugin' + + +@pytest.mark.asyncio +async def test_child_resolvable_dap_tool_shadows_parent_plugin_metadata() -> None: + """DAP-exposed nested actions must shadow parent plugin metadata for the same (kind, name).""" + + class ParentPlugin(Plugin): + name = 'parentplugin' + + async def init(self) -> list[Action]: + return [] + + async def resolve(self, action_type: ActionKind, name: str) -> Action | None: + return None + + async def list_actions(self) -> list[ActionMetadata]: + return [ + ActionMetadata( + kind=ActionKind.TOOL, + name='parentplugin/mcp-tool', + description='stale parent schema', + ) + ] + + async def mcp_tool_fn(_: str) -> str: + return 'mcp' + + mcp_tool = Action( + kind=ActionKind.TOOL, + name='parentplugin/mcp-tool', + fn=mcp_tool_fn, + description='from mcp', + ) + + parent = Registry() + parent.register_plugin(ParentPlugin()) + child = parent.new_child() + + async def dap_fn() -> DapValue: + return {'tool': [mcp_tool]} + + define_dynamic_action_provider(child, 'mcp', dap_fn) + + catalog = await child.list_resolvable_actions() + entry = catalog['/tool/parentplugin/mcp-tool'] + assert entry['description'] == 'from mcp' + assert entry['description'] != 'stale parent schema' + + +@pytest.mark.asyncio +async def test_list_resolvable_registered_canonical_coexists_with_qualified_dap_rows() -> None: + """Same canonical tool path from registration and from DAP: both catalog shapes appear. + + The qualified DAP metadata row (under ``dynamic-action-provider``) is always merged; + the canonical ``/tool/...`` row prefers the action already in the registry when both + would describe the same path. + """ + tool_name = 'suite/same-canonical' + + async def registered_fn(_: str) -> str: + return 'registered' + + async def dap_nested_fn(_: str) -> str: + return 'dap' + + dap_nested = Action( + kind=ActionKind.TOOL, + name=tool_name, + fn=dap_nested_fn, + description='from dap nested', + ) + + registry = Registry() + registry.register_action( + kind=ActionKind.TOOL, + name=tool_name, + fn=registered_fn, + description='from registry registration', + ) + + async def dap_fn() -> DapValue: + return {'tool': [dap_nested]} + + define_dynamic_action_provider(registry, 'mcp', dap_fn) + + catalog = await registry.list_resolvable_actions() + + canonical = create_action_key(ActionKind.TOOL, tool_name) + record_key = f'mcp:tool/{tool_name}' + qualified = create_action_key(ActionKind.DYNAMIC_ACTION_PROVIDER, record_key) + provider_key = create_action_key(ActionKind.DYNAMIC_ACTION_PROVIDER, 'mcp') + + assert canonical in catalog + assert catalog[canonical]['description'] == 'from registry registration' + + assert qualified in catalog + assert catalog[qualified]['key'] == qualified + + assert provider_key in catalog