Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 106 additions & 15 deletions py/packages/genkit/src/genkit/_ai/_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
)
from genkit._ai._resource import ResourceArgument, ResourceInput, find_matching_resource, resolve_resources
from genkit._ai._tools import Tool, ToolInterruptError
from genkit._core._action import Action, ActionKind, ActionRunContext
from genkit._core._action import (
GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR,
Action,
ActionKind,
ActionRunContext,
)
from genkit._core._error import GenkitError
from genkit._core._logger import get_logger
from genkit._core._model import GenerateActionOptions
Expand All @@ -61,6 +66,51 @@
logger = get_logger(__name__)


async def expand_wildcard_tools(registry: Registry, tool_names: list[str]) -> list[str]:
"""Expand DAP wildcard tool names into individual registry keys.

A wildcard has the form ``<provider>:tool/*`` (or ``<provider>:tool/<prefix>*``).
Each match becomes a full DAP key
``/dynamic-action-provider/<provider>:<actionType>/<toolName>`` so later resolution
stays bound to that provider (no ambiguous bare-name lookup across DAPs).

Non-wildcard names are passed through unchanged.
"""
expanded: list[str] = []
for name in tool_names:
if not name.endswith('*') or ':' not in name:
expanded.append(name)
continue

colon = name.index(':')
provider_name = name[:colon]
rest = name[colon + 1 :] # e.g. "tool/*" or "tool/prefix*"

provider_action = await registry.resolve_action(ActionKind.DYNAMIC_ACTION_PROVIDER, provider_name)
if provider_action is None:
expanded.append(name)
continue

dap = getattr(provider_action, GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR, None)
if dap is None:
expanded.append(name)
continue

if '/' not in rest:
expanded.append(name)
continue

action_type, action_pattern = rest.split('/', 1)
metas = await dap.list_action_metadata(action_type, action_pattern)
for meta in metas:
tool_name = meta.get('name')
if tool_name:
tn = str(tool_name)
expanded.append(f'/dynamic-action-provider/{provider_name}:{action_type}/{tn}')

return expanded


def tools_to_action_names(
tools: Sequence[str | Tool] | None,
) -> list[str] | None:
Expand Down Expand Up @@ -158,20 +208,27 @@ async def _generate_action(
context: dict[str, Any] | None = None,
) -> ModelResponse:
"""Execute a generation request with tool calling and middleware support."""
model, tools, format_def = await resolve_parameters(registry, raw_request)
effective_registry = registry if registry.is_child else registry.new_child()

tools_in = raw_request.tools
if tools_in:
raw_request = raw_request.model_copy()
raw_request.tools = await expand_wildcard_tools(effective_registry, tools_in)

model, tools, format_def = await resolve_parameters(effective_registry, raw_request)

raw_request, formatter = apply_format(raw_request, format_def)

if raw_request.resources:
raw_request = await apply_resources(registry, raw_request)
raw_request = await apply_resources(effective_registry, raw_request)

assert_valid_tool_names(raw_request)
assert_valid_tool_names(tools)

(
revised_request,
interrupted_response,
resumed_tool_message,
) = await _resolve_resume_options(registry, raw_request)
) = await _resolve_resume_options(effective_registry, raw_request)

# NOTE: in the future we should make it possible to interrupt a restart, but
# at the moment it's too complicated because it's not clear how to return a
Expand Down Expand Up @@ -374,7 +431,7 @@ def message_parser(msg: Message) -> Any: # noqa: ANN401
revised_model_msg,
tool_msg,
transfer_preamble,
) = await resolve_tool_requests(registry, raw_request, generated_msg)
) = await resolve_tool_requests(effective_registry, raw_request, generated_msg)

# if an interrupt message is returned, stop the tool loop and return a
# response.
Expand Down Expand Up @@ -408,7 +465,7 @@ def message_parser(msg: Message) -> Any: # noqa: ANN401

# then recursively call for another loop
return await _generate_action(
registry,
effective_registry,
raw_request=next_request,
# middleware: middleware,
current_turn=current_turn + 1,
Expand Down Expand Up @@ -584,10 +641,30 @@ async def apply_resources(registry: Registry, raw_request: GenerateActionOptions
return new_request


def assert_valid_tool_names(_raw_request: GenerateActionOptions) -> None:
"""Validate tool names in the request. (TODO: not yet implemented)."""
# TODO(#4338): implement me
pass
def _tool_short_name_for_model(name: str) -> str:
"""Return the last path segment of a tool name."""
if '/' not in name:
return name
return name[name.rfind('/') + 1 :]


def assert_valid_tool_names(tools: list[Action[Any, Any, Any]]) -> None:
"""Reject overlapping model-facing tool names before the model is called.

Two resolved tools that share the same short name (segment after the last ``/``)
cannot both appear in one generate request.
"""
if not tools:
return
seen: dict[str, str] = {}
for tool in tools:
short = _tool_short_name_for_model(tool.name)
if short in seen:
raise GenkitError(
status='INVALID_ARGUMENT',
message=(f"Cannot provide two tools with the same name: '{tool.name}' and '{seen[short]}'"),
)
seen[short] = tool.name


async def resolve_parameters(
Expand All @@ -605,9 +682,10 @@ async def resolve_parameters(
tools: list[Action[Any, Any, Any]] = []
if request.tools:
for tool_name in request.tools:
tool_action = await registry.resolve_action(ActionKind.TOOL, tool_name)
if tool_action is None:
raise Exception(f'Unable to resolve tool {tool_name}')
try:
tool_action = await resolve_tool(registry, tool_name)
except GenkitError as e:
raise Exception(f'Unable to resolve tool {tool_name}') from e
tools.append(tool_action)

format_def: FormatDef | None = None
Expand Down Expand Up @@ -665,7 +743,12 @@ async def resolve_tool_requests(
tool_dict: dict[str, Action] = {}
if request.tools:
for tool_name in request.tools:
tool_dict[tool_name] = await resolve_tool(registry, tool_name)
tool_action = await resolve_tool(registry, tool_name)
tool_dict[tool_name] = tool_action
# Model tool calls use ToolDefinition.name (short); wildcard expansion uses full DAP keys.
short = tool_action.name
if short not in tool_dict:
tool_dict[short] = tool_action

revised_model_message = message.model_copy(deep=True)

Expand Down Expand Up @@ -762,11 +845,19 @@ async def _resolve_tool_request(tool: Action, tool_request_part: ToolRequestPart
async def resolve_tool(registry: Registry, tool_ref: str | Tool) -> Action:
"""Resolve a tool from a registry name or a Tool instance.

Accepts full action keys (``/dynamic-action-provider/...``), DAP-qualified
names (``provider:tool/name``), or plain registered tool names.

Used when building ModelRequest (for example from to_generate_request).
"""
if isinstance(tool_ref, Tool):
return tool_ref.action

if tool_ref.startswith('/'):
tool = await registry.resolve_action_by_key(tool_ref)
if tool is not None:
return tool

tool = await registry.resolve_action(kind=ActionKind.TOOL, name=tool_ref)
if tool is None:
raise GenkitError(status='NOT_FOUND', message=f'Unable to resolve tool {tool_ref}')
Expand Down
56 changes: 54 additions & 2 deletions py/packages/genkit/src/genkit/_core/_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
import asyncio
import inspect
import json
import re
import time
from collections.abc import AsyncIterator, Awaitable, Callable, Generator, Mapping
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, ClassVar, Generic, cast, get_type_hints
from typing import Any, ClassVar, Generic, NamedTuple, cast, get_type_hints

from opentelemetry import trace as trace_api
from opentelemetry.trace import Span
Expand Down Expand Up @@ -228,6 +229,43 @@ def extract_action_args_and_types(
# =============================================================================


GENKIT_DYNAMIC_ACTION_PROVIDER_ATTR = '_genkit_dynamic_action_provider'
# Nested actions in a DAP cache store their qualified reflection key via this attribute name:
# ``/dynamic-action-provider/<id>:<type>/<name>``.
GENKIT_DAP_QUALIFIED_KEY_ATTR = '_genkit_dap_qualified_key'


class DapQualifiedName(NamedTuple):
"""Segments of a DAP-qualified name ``provider:innerKind/innerName``."""

provider: str
inner_kind: str
inner_name: str


def parse_dap_qualified_name(name: str) -> DapQualifiedName | None:
"""Parse DAP-qualified segment ``provider:innerKind/innerName``.

Used when the action key kind is ``dynamic-action-provider`` and the name
references a nested action exposed by a provider (e.g. MCP tools).

Pattern: ``[provider]:[inner_kind]/[inner_name]`` — no slashes in the
provider segment (``plugin/foo`` is not a valid provider host).

Returns:
A :class:`DapQualifiedName` if the string matches; otherwise ``None`` so
callers can treat the name as a plain dynamic-action-provider id.
"""
# Pattern: [provider]:[inner_kind]/[inner_name]; no '/' or ':' in provider.
match = re.match(r'^([^/:]+):([^/:]+)/(.+)$', name)
if not match:
return None
provider, inner_kind, inner_name = match.groups()
if not provider or not inner_kind or not inner_name:
return None
return DapQualifiedName(provider, inner_kind, inner_name)


def parse_action_key(key: str) -> tuple[ActionKind, str]:
"""Parse '/<kind>/<name>' key into (ActionKind, name)."""
tokens = key.split('/')
Expand All @@ -246,11 +284,25 @@ def parse_action_key(key: str) -> tuple[ActionKind, str]:
return kind, name


def create_action_key(kind: ActionKind, name: str) -> str:
def create_action_key(kind: ActionKind | str, name: str) -> str:
"""Create '/<kind>/<name>' key."""
return f'/{kind}/{name}'


def parse_dap_provider_host(name: str) -> str | None:
"""Return the segment before the first ``:`` when the name has multiple ``:``-split parts.

If there is no ``:``, or the first segment is empty, returns ``None``.
"""
parts = name.split(':')
if len(parts) < 2:
return None
host = parts[0]
if not host:
return None
return host


# =============================================================================
# Action core
# =============================================================================
Expand Down
Loading
Loading