Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 46 additions & 8 deletions agentkit/apps/agent_server_app/agent_server_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
import logging
from contextlib import asynccontextmanager
from typing import Any
from typing_extensions import override

import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from google.adk.a2a.utils.agent_to_a2a import to_a2a
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.run_config import RunConfig, StreamingMode
from google.adk.apps.app import App, ResumabilityConfig
from google.adk.artifacts.in_memory_artifact_service import (
InMemoryArtifactService,
)
Expand All @@ -41,9 +41,9 @@
from google.adk.utils.context_utils import Aclosing
from google.genai import types
from opentelemetry import trace
from veadk import Agent
from typing_extensions import override
from veadk import Agent, Runner
from veadk.memory.short_term_memory import ShortTermMemory
from veadk.runner import Runner

from agentkit.apps.agent_server_app.middleware import (
AgentkitTelemetryHTTPMiddleware,
Expand All @@ -68,7 +68,6 @@ def load_agent(self, agent_name: str) -> BaseAgent:
def list_agents(self) -> list[str]:
return [self.agent.name]

@override
def list_agents_detailed(self) -> list[dict[str, Any]]:
name = self.agent.name
description = getattr(self.agent, "description", "") or ""
Expand All @@ -82,21 +81,52 @@ def list_agents_detailed(self) -> list[dict[str, Any]]:
]


class AgentKitAdkWebServer(AdkWebServer):
def __init__(self, *args, **kwargs) -> None:
self.enable_resume = kwargs.pop("enable_resume", False)
super().__init__(*args, **kwargs)

@override
def _create_runner(self, agentic_app: App) -> Runner:
"""Create a runner with common services."""
logger.debug(f"Enable resume: {self.enable_resume}")
try:
agentic_app.resumability_config = ResumabilityConfig(
is_resumable=self.enable_resume
)
runner = Runner(
app=agentic_app,
artifact_service=self.artifact_service,
session_service=self.session_service,
memory_service=self.memory_service,
credential_service=self.credential_service,
)
return runner
except Exception as e:
logger.error(
f"Set resume config to runner failed: {e}. Please check your google-adk version."
)
raise e


class AgentkitAgentServerApp(BaseAgentkitApp):
def __init__(
self,
agent: BaseAgent,
short_term_memory: BaseSessionService | ShortTermMemory,
enable_resume: bool = False,
) -> None:
super().__init__()

self.enable_resume = enable_resume

_artifact_service = InMemoryArtifactService()
_credential_service = InMemoryCredentialService()

_eval_sets_manager = LocalEvalSetsManager(agents_dir=".")
_eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=".")

self.server = AdkWebServer(
self.server = AgentKitAdkWebServer(
agent_loader=AgentKitAgentLoader(agent),
session_service=short_term_memory
if isinstance(short_term_memory, BaseSessionService)
Expand All @@ -109,6 +139,7 @@ def __init__(
eval_sets_manager=_eval_sets_manager,
eval_set_results_manager=_eval_set_results_manager,
agents_dir=".",
enable_resume=self.enable_resume,
)

runner = Runner(agent=agent)
Expand All @@ -117,7 +148,9 @@ def __init__(
@asynccontextmanager
async def lifespan(app: FastAPI):
# trigger A2A server app startup
logger.info("Triggering A2A server app startup within API server...")
logger.info(
"Triggering A2A server app startup within API server..."
)
for handler in _a2a_server_app.router.on_startup:
await handler()
yield
Expand All @@ -144,7 +177,9 @@ async def _invoke_compat(request: Request):
# Determine app_name from loader
app_names = self.server.agent_loader.list_agents()
if not app_names:
raise HTTPException(status_code=404, detail="No agents configured")
raise HTTPException(
status_code=404, detail="No agents configured"
)
app_name = app_names[0]

# Parse payload and convert to ADK Content
Expand Down Expand Up @@ -193,7 +228,9 @@ async def event_generator():
user_id=user_id,
session_id=session_id,
new_message=content,
run_config=RunConfig(streaming_mode=StreamingMode.SSE),
run_config=RunConfig(
streaming_mode=StreamingMode.SSE
),
)
) as agen:
async for event in agen:
Expand Down Expand Up @@ -231,3 +268,4 @@ async def event_generator():
def run(self, host: str, port: int = 8000) -> None:
"""Run the app with Uvicorn server."""
uvicorn.run(self.app, host=host, port=port)
uvicorn.run(self.app, host=host, port=port)