diff --git a/context-graph/actions-graph/LICENSE b/context-graph/actions-graph/LICENSE new file mode 100644 index 00000000..29e3cdde --- /dev/null +++ b/context-graph/actions-graph/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Memgraph + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/context-graph/actions-graph/README.md b/context-graph/actions-graph/README.md new file mode 100644 index 00000000..75cbbc65 --- /dev/null +++ b/context-graph/actions-graph/README.md @@ -0,0 +1,253 @@ +# Actions Graph + +Store and track LLM actions, tool calls, and sessions in Memgraph. + +Actions Graph provides a graph-based storage system for tracking all LLM interactions, including: +- **Tool Calls**: Function/tool invocations by the LLM +- **Tool Results**: Outputs from tool executions +- **Messages**: User, assistant, and system messages +- **Structured Outputs**: Validated JSON outputs from the LLM +- **Subagent Events**: Subagent lifecycle tracking +- **Sessions**: Conversation session management + +## Features + +- 📊 **Graph-based Storage**: Store actions as nodes with relationships in Memgraph +- 🔗 **Temporal Sequences**: Track action order with `FOLLOWED_BY` relationships +- 🌳 **Nested Actions**: Support for parent-child action relationships (e.g., subagents) +- 🏷️ **Session Management**: Create, track, and query sessions with tags +- 📈 **Analytics**: Built-in queries for tool usage stats and session summaries +- 🤖 **Claude Agent SDK Integration**: Ready-to-use hooks for automatic tracking + +## Installation + +```bash +pip install actions-graph +``` + +For Claude Agent SDK integration: +```bash +pip install actions-graph[claude-agent] +``` + +## Quick Start + +### Basic Usage + +```python +from actions_graph import ActionsGraph, Session, ToolCall + +# Initialize the graph +graph = ActionsGraph() +graph.setup() # Create indexes and constraints + +# Create a session +session = Session( + session_id="session-123", + model="claude-sonnet-4-20250514", + working_directory="/path/to/project", +) +graph.create_session(session) + +# Record a tool call +tool_call = graph.record_tool_call( + session_id="session-123", + tool_name="Read", + tool_input={"file_path": "/path/to/file.py"}, + tool_use_id="tool-use-001", +) + +# Record the result +tool_result = graph.record_tool_result( + session_id="session-123", + tool_use_id="tool-use-001", + tool_name="Read", + content="def hello():\n print('Hello, World!')", +) + +# Get session summary +summary = graph.get_session_summary("session-123") +print(f"Actions: {summary['action_count']}, Tools: {summary['tool_call_count']}") +``` + +### Claude Agent SDK Integration + +```python +import asyncio +from actions_graph import ActionsGraph +from actions_graph.hooks import create_tracking_hooks +from claude_agent_sdk import query, ClaudeAgentOptions + +async def main(): + # Initialize graph + graph = ActionsGraph() + graph.setup() + + # Create tracking hooks + hooks = create_tracking_hooks( + graph, + session_id="my-session-123", + session_kwargs={ + "model": "claude-sonnet-4-20250514", + "working_directory": "/path/to/project", + "tags": ["code-review", "python"], + }, + ) + + # Run with automatic tracking + async for message in query( + prompt="Review the code in src/main.py for potential bugs", + options=ClaudeAgentOptions( + hooks=hooks, + allowed_tools=["Read", "Glob", "Grep"], + permission_mode="acceptEdits", + ), + ): + if hasattr(message, "result"): + print(message.result) + + # Query the recorded actions + actions = graph.get_session_actions("my-session-123") + print(f"Recorded {len(actions)} actions") + + # Get tool usage stats + stats = graph.get_tool_usage_stats("my-session-123") + for stat in stats: + print(f"{stat['tool_name']}: {stat['call_count']} calls") + +asyncio.run(main()) +``` + +## Graph Schema + +### Nodes + +- **Session**: LLM conversation sessions + - Properties: `session_id`, `started_at`, `ended_at`, `status`, `model`, `total_cost_usd`, etc. + +- **Action**: Individual actions with type-specific labels + - Labels: `ToolCall`, `ToolResult`, `Message`, `StructuredOutput`, `SubagentEvent`, etc. + - Properties: `action_id`, `session_id`, `action_type`, `timestamp`, `status`, etc. + +- **Tool**: Tool definitions + - Properties: `name`, `is_mcp`, `mcp_server` + +- **Tag**: Session/action tags + - Properties: `name` + +### Relationships + +``` +(:Session)-[:HAS_ACTION]->(:Action) +(:Action)-[:FOLLOWED_BY]->(:Action) +(:Action)-[:PARENT_OF]->(:Action) +(:Session)-[:FORKED_FROM]->(:Session) +(:Action)-[:USED_TOOL]->(:Tool) +(:Session)-[:HAS_TAG]->(:Tag) +``` + +## API Reference + +### ActionsGraph + +Main class for interacting with the graph. + +```python +graph = ActionsGraph() + +# Setup +graph.setup() # Create indexes +graph.drop() # Remove indexes +graph.clear() # Clear all data + +# Sessions +graph.create_session(session) +graph.get_session(session_id) +graph.end_session(session_id, status=ActionStatus.COMPLETED) +graph.list_sessions(limit=100, status=None, tag=None) + +# Actions +graph.record_action(action) +graph.record_tool_call(session_id, tool_name, tool_input, ...) +graph.record_tool_result(session_id, tool_use_id, tool_name, content, ...) +graph.record_message(session_id, role, content, ...) +graph.get_action(action_id) +graph.get_session_actions(session_id, action_type=None, limit=1000) + +# Analytics +graph.get_tool_usage_stats(session_id=None) +graph.get_action_sequence(session_id, include_content=False) +graph.get_session_summary(session_id) +``` + +### Action Types + +| Type | Model Class | Description | +|------|-------------|-------------| +| `tool_call` | `ToolCall` | Tool/function invocation | +| `tool_result` | `ToolResult` | Tool execution result | +| `user_message` | `Message` | User input | +| `assistant_message` | `Message` | LLM response | +| `system_message` | `Message` | System messages | +| `structured_output` | `StructuredOutput` | Validated JSON output | +| `subagent_start` | `SubagentEvent` | Subagent started | +| `subagent_stop` | `SubagentEvent` | Subagent completed | +| `error` | `ErrorEvent` | Error occurred | +| `permission_request` | `PermissionRequest` | Permission requested | +| `rate_limit` | `RateLimitEvent` | Rate limit event | + +### Hooks + +For Claude Agent SDK integration: + +```python +from actions_graph.hooks import create_tracking_hooks, ActionTracker + +# Simple usage +hooks = create_tracking_hooks(graph, session_id) + +# Advanced usage with custom tracker +tracker = ActionTracker( + graph, + session_id, + track_tool_calls=True, + track_tool_results=True, + track_messages=True, + track_subagents=True, + track_permissions=True, + track_errors=True, +) +``` + +## Example Queries + +### Find sessions with errors + +```python +sessions = graph.list_sessions(status=ActionStatus.FAILED) +``` + +### Get all tool calls in a session + +```python +from actions_graph import ActionType + +tool_calls = graph.get_session_actions( + session_id, + action_type=ActionType.TOOL_CALL, +) +``` + +### Custom Cypher queries + +```python +rows = graph._db.query(""" + MATCH (s:Session {session_id: $session_id})-[:HAS_ACTION]->(a:ToolCall) + RETURN a.tool_name AS tool, count(*) AS count + ORDER BY count DESC +""", params={"session_id": "my-session"}) +``` + +## License + +MIT diff --git a/context-graph/actions-graph/pyproject.toml b/context-graph/actions-graph/pyproject.toml new file mode 100644 index 00000000..000c2e6c --- /dev/null +++ b/context-graph/actions-graph/pyproject.toml @@ -0,0 +1,32 @@ +[project] +name = "actions-graph" +version = "0.1.0" +description = "Store and track LLM actions, tool calls, and sessions in Memgraph" +readme = "README.md" +license = { text = "MIT" } +requires-python = ">=3.10" +classifiers = [ + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +dependencies = [ + "memgraph-toolbox", +] + +[project.optional-dependencies] +claude-agent = [ + "claude-agent-sdk>=0.1.0", +] +test = [ + "pytest>=9.0.3", + "pytest-asyncio>=0.24.0", +] + +[tool.uv.sources] +memgraph-toolbox = { workspace = true } + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" diff --git a/context-graph/actions-graph/src/actions_graph/__init__.py b/context-graph/actions-graph/src/actions_graph/__init__.py new file mode 100644 index 00000000..ecdaa5cd --- /dev/null +++ b/context-graph/actions-graph/src/actions_graph/__init__.py @@ -0,0 +1,80 @@ +"""Actions Graph: Store and track LLM actions, tool calls, and sessions in Memgraph. + +This package provides a graph-based storage system for tracking all LLM interactions, +including tool calls, messages, structured outputs, and session management. + +Quick Start: + from actions_graph import ActionsGraph, Session, ToolCall + + # Initialize the graph + graph = ActionsGraph() + graph.setup() + + # Create a session + session = Session(session_id="my-session-123") + graph.create_session(session) + + # Record tool calls + graph.record_tool_call( + session_id="my-session-123", + tool_name="Read", + tool_input={"file_path": "/path/to/file"}, + ) + +Integration with Claude Agent SDK: + from actions_graph import ActionsGraph + from actions_graph.hooks import create_tracking_hooks + from claude_agent_sdk import query, ClaudeAgentOptions + + graph = ActionsGraph() + graph.setup() + + hooks = create_tracking_hooks(graph, session_id="my-session-123") + + async for message in query( + prompt="Analyze this codebase", + options=ClaudeAgentOptions( + hooks=hooks, + allowed_tools=["Read", "Glob", "Grep"], + ), + ): + print(message) +""" + +from .core import ActionsGraph +from .models import ( + Action, + ActionStatus, + ActionType, + ActionValidationError, + ErrorEvent, + Message, + MessageRole, + PermissionRequest, + RateLimitEvent, + Session, + StructuredOutput, + SubagentEvent, + ToolCall, + ToolResult, +) + +__all__ = [ + "Action", + "ActionStatus", + "ActionType", + "ActionValidationError", + "ActionsGraph", + "ErrorEvent", + "Message", + "MessageRole", + "PermissionRequest", + "RateLimitEvent", + "Session", + "StructuredOutput", + "SubagentEvent", + "ToolCall", + "ToolResult", +] + +__version__ = "0.1.0" diff --git a/context-graph/actions-graph/src/actions_graph/core.py b/context-graph/actions-graph/src/actions_graph/core.py new file mode 100644 index 00000000..f0df11d7 --- /dev/null +++ b/context-graph/actions-graph/src/actions_graph/core.py @@ -0,0 +1,930 @@ +"""Core ActionsGraph class for storing and querying LLM actions in Memgraph. + +This module provides the main interface for persisting and analyzing +LLM actions, tool calls, and sessions in a Memgraph graph database. + +Graph Schema: + Nodes: + - (:Session) - LLM conversation sessions + - (:Action) - Individual actions with labels for type (ToolCall, Message, etc.) + - (:Tool) - Tool definitions + - (:Tag) - Session/action tags + + Relationships: + - (:Session)-[:HAS_ACTION]->(:Action) + - (:Action)-[:FOLLOWED_BY]->(:Action) - Temporal sequence + - (:Action)-[:PARENT_OF]->(:Action) - Nested actions (e.g., subagent) + - (:Session)-[:FORKED_FROM]->(:Session) + - (:Action)-[:USED_TOOL]->(:Tool) + - (:Session)-[:HAS_TAG]->(:Tag) +""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from typing import Any + +from memgraph_toolbox.api.memgraph import Memgraph + +from .models import ( + Action, + ActionStatus, + ActionType, + ErrorEvent, + Message, + MessageRole, + PermissionRequest, + RateLimitEvent, + Session, + StructuredOutput, + SubagentEvent, + ToolCall, + ToolResult, +) + + +class ActionsGraph: + """Store and query LLM actions and sessions in Memgraph. + + Provides methods for: + - Creating and managing sessions + - Recording tool calls, messages, and other actions + - Querying action history and analytics + - Building action sequence graphs + """ + + def __init__(self, memgraph: Memgraph | None = None, **kwargs: Any): + """Initialize ActionsGraph. + + Args: + memgraph: An existing Memgraph client instance. If not provided, + a new one is created using kwargs / environment variables. + **kwargs: Forwarded to Memgraph() when memgraph is None. + """ + self._db = memgraph or Memgraph(**kwargs) + self._last_action_id: dict[str, str] = {} # session_id -> last action_id + + # ------------------------------------------------------------------ + # Schema setup + # ------------------------------------------------------------------ + + def setup(self) -> None: + """Create constraints and indexes required for action storage.""" + # Session constraints and indexes + self._db.query("CREATE CONSTRAINT ON (s:Session) ASSERT s.session_id IS UNIQUE;") + self._db.query("CREATE INDEX ON :Session(session_id);") + self._db.query("CREATE INDEX ON :Session(started_at);") + self._db.query("CREATE INDEX ON :Session(status);") + + # Action constraints and indexes + self._db.query("CREATE CONSTRAINT ON (a:Action) ASSERT a.action_id IS UNIQUE;") + self._db.query("CREATE INDEX ON :Action(action_id);") + self._db.query("CREATE INDEX ON :Action(session_id);") + self._db.query("CREATE INDEX ON :Action(timestamp);") + self._db.query("CREATE INDEX ON :Action(action_type);") + + # Tool indexes + self._db.query("CREATE INDEX ON :Tool(name);") + + # Tag indexes + self._db.query("CREATE INDEX ON :Tag(name);") + + def drop(self) -> None: + """Remove all action-related constraints and indexes.""" + import contextlib + + with contextlib.suppress(Exception): + self._db.query("DROP CONSTRAINT ON (s:Session) ASSERT s.session_id IS UNIQUE;") + with contextlib.suppress(Exception): + self._db.query("DROP CONSTRAINT ON (a:Action) ASSERT a.action_id IS UNIQUE;") + # Indexes are dropped with constraints in most cases + + def clear(self) -> None: + """Remove all session and action data from the graph.""" + self._db.query("MATCH (n) WHERE n:Session OR n:Action OR n:Tool DETACH DELETE n;") + self._last_action_id.clear() + + # ------------------------------------------------------------------ + # Session operations + # ------------------------------------------------------------------ + + def create_session(self, session: Session) -> Session: + """Create a new session in the graph. + + Args: + session: Session object to persist + + Returns: + The persisted session + """ + self._db.query( + """ + CREATE (s:Session { + session_id: $session_id, + started_at: $started_at, + ended_at: $ended_at, + status: $status, + model: $model, + total_cost_usd: $total_cost_usd, + total_input_tokens: $total_input_tokens, + total_output_tokens: $total_output_tokens, + working_directory: $working_directory, + git_branch: $git_branch, + metadata: $metadata, + parent_session_id: $parent_session_id + }) + """, + params={ + "session_id": session.session_id, + "started_at": session.started_at, + "ended_at": session.ended_at, + "status": session.status.value, + "model": session.model, + "total_cost_usd": session.total_cost_usd, + "total_input_tokens": session.total_input_tokens, + "total_output_tokens": session.total_output_tokens, + "working_directory": session.working_directory, + "git_branch": session.git_branch, + "metadata": json.dumps(session.metadata), + "parent_session_id": session.parent_session_id, + }, + ) + + # Handle forked sessions + if session.parent_session_id: + self._db.query( + """ + MATCH (child:Session {session_id: $session_id}) + MATCH (parent:Session {session_id: $parent_session_id}) + MERGE (child)-[:FORKED_FROM]->(parent) + """, + params={ + "session_id": session.session_id, + "parent_session_id": session.parent_session_id, + }, + ) + + # Handle tags + if session.tags: + self._db.query( + """ + MATCH (s:Session {session_id: $session_id}) + UNWIND $tags AS tag_name + MERGE (t:Tag {name: tag_name}) + MERGE (s)-[:HAS_TAG]->(t) + """, + params={"session_id": session.session_id, "tags": session.tags}, + ) + + return session + + def get_session(self, session_id: str) -> Session | None: + """Retrieve a session by ID. + + Args: + session_id: Unique session identifier + + Returns: + Session object or None if not found + """ + rows = self._db.query( + """ + MATCH (s:Session {session_id: $session_id}) + OPTIONAL MATCH (s)-[:HAS_TAG]->(t:Tag) + RETURN s.session_id AS session_id, + s.started_at AS started_at, + s.ended_at AS ended_at, + s.status AS status, + s.model AS model, + s.total_cost_usd AS total_cost_usd, + s.total_input_tokens AS total_input_tokens, + s.total_output_tokens AS total_output_tokens, + s.working_directory AS working_directory, + s.git_branch AS git_branch, + s.metadata AS metadata, + s.parent_session_id AS parent_session_id, + collect(t.name) AS tags + """, + params={"session_id": session_id}, + ) + + if not rows: + return None + + row = rows[0] + return self._row_to_session(row) + + def end_session( + self, + session_id: str, + *, + status: ActionStatus = ActionStatus.COMPLETED, + total_cost_usd: float | None = None, + total_input_tokens: int | None = None, + total_output_tokens: int | None = None, + ) -> Session | None: + """Mark a session as ended. + + Args: + session_id: Session to end + status: Final status + total_cost_usd: Total cost in USD + total_input_tokens: Total input tokens + total_output_tokens: Total output tokens + + Returns: + Updated session or None if not found + """ + ended_at = datetime.now(timezone.utc).isoformat() + + sets = ["s.ended_at = $ended_at", "s.status = $status"] + params: dict[str, Any] = { + "session_id": session_id, + "ended_at": ended_at, + "status": status.value, + } + + if total_cost_usd is not None: + sets.append("s.total_cost_usd = $total_cost_usd") + params["total_cost_usd"] = total_cost_usd + if total_input_tokens is not None: + sets.append("s.total_input_tokens = $total_input_tokens") + params["total_input_tokens"] = total_input_tokens + if total_output_tokens is not None: + sets.append("s.total_output_tokens = $total_output_tokens") + params["total_output_tokens"] = total_output_tokens + + self._db.query( + f""" + MATCH (s:Session {{session_id: $session_id}}) + SET {", ".join(sets)} + """, + params=params, + ) + + # Clean up last action tracking + self._last_action_id.pop(session_id, None) + + return self.get_session(session_id) + + def list_sessions( + self, + *, + limit: int = 100, + status: ActionStatus | None = None, + tag: str | None = None, + ) -> list[Session]: + """List sessions with optional filtering. + + Args: + limit: Maximum number of sessions to return + status: Filter by status + tag: Filter by tag + + Returns: + List of sessions ordered by start time (newest first) + """ + where_clauses = [] + params: dict[str, Any] = {"limit": limit} + + if status: + where_clauses.append("s.status = $status") + params["status"] = status.value + + if tag: + where_clauses.append("EXISTS((s)-[:HAS_TAG]->(:Tag {name: $tag}))") + params["tag"] = tag + + where_str = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + rows = self._db.query( + f""" + MATCH (s:Session) + {where_str} + OPTIONAL MATCH (s)-[:HAS_TAG]->(t:Tag) + RETURN s.session_id AS session_id, + s.started_at AS started_at, + s.ended_at AS ended_at, + s.status AS status, + s.model AS model, + s.total_cost_usd AS total_cost_usd, + s.total_input_tokens AS total_input_tokens, + s.total_output_tokens AS total_output_tokens, + s.working_directory AS working_directory, + s.git_branch AS git_branch, + s.metadata AS metadata, + s.parent_session_id AS parent_session_id, + collect(t.name) AS tags + ORDER BY s.started_at DESC + LIMIT $limit + """, + params=params, + ) + + return [self._row_to_session(row) for row in rows] + + # ------------------------------------------------------------------ + # Action operations + # ------------------------------------------------------------------ + + def record_action(self, action: Action) -> Action: + """Record an action in the graph. + + Creates the action node, links it to the session, and creates + temporal FOLLOWED_BY relationships with the previous action. + + Args: + action: Action to record + + Returns: + The recorded action + """ + # Determine additional labels based on action type + type_labels = self._get_type_labels(action) + labels_str = ":Action" + "".join(f":{lbl}" for lbl in type_labels) + + # Build properties based on action type + props = self._action_to_props(action) + + # Create the action node + self._db.query( + f""" + CREATE (a{labels_str} {{ + action_id: $action_id, + session_id: $session_id, + action_type: $action_type, + timestamp: $timestamp, + status: $status, + duration_ms: $duration_ms, + parent_action_id: $parent_action_id, + metadata: $metadata, + properties: $properties + }}) + """, + params={ + "action_id": action.action_id, + "session_id": action.session_id, + "action_type": action.action_type.value, + "timestamp": action.timestamp, + "status": action.status.value, + "duration_ms": action.duration_ms, + "parent_action_id": action.parent_action_id, + "metadata": json.dumps(action.metadata), + "properties": json.dumps(props), + }, + ) + + # Link to session + self._db.query( + """ + MATCH (s:Session {session_id: $session_id}) + MATCH (a:Action {action_id: $action_id}) + MERGE (s)-[:HAS_ACTION]->(a) + """, + params={"session_id": action.session_id, "action_id": action.action_id}, + ) + + # Create temporal sequence + last_action_id = self._last_action_id.get(action.session_id) + if last_action_id: + self._db.query( + """ + MATCH (prev:Action {action_id: $prev_id}) + MATCH (curr:Action {action_id: $curr_id}) + MERGE (prev)-[:FOLLOWED_BY]->(curr) + """, + params={"prev_id": last_action_id, "curr_id": action.action_id}, + ) + + self._last_action_id[action.session_id] = action.action_id + + # Handle parent action relationship + if action.parent_action_id: + self._db.query( + """ + MATCH (parent:Action {action_id: $parent_id}) + MATCH (child:Action {action_id: $child_id}) + MERGE (parent)-[:PARENT_OF]->(child) + """, + params={"parent_id": action.parent_action_id, "child_id": action.action_id}, + ) + + # Handle tool references + if isinstance(action, ToolCall): + self._db.query( + """ + MATCH (a:Action {action_id: $action_id}) + MERGE (t:Tool {name: $tool_name}) + ON CREATE SET t.is_mcp = $is_mcp, t.mcp_server = $mcp_server + MERGE (a)-[:USED_TOOL]->(t) + """, + params={ + "action_id": action.action_id, + "tool_name": action.tool_name, + "is_mcp": action.is_mcp, + "mcp_server": action.mcp_server, + }, + ) + + return action + + def record_tool_call( + self, + session_id: str, + tool_name: str, + tool_input: dict[str, Any], + tool_use_id: str | None = None, + **kwargs: Any, + ) -> ToolCall: + """Convenience method to record a tool call. + + Args: + session_id: Session ID + tool_name: Name of the tool + tool_input: Tool input parameters + tool_use_id: Optional tool use ID for correlation + **kwargs: Additional ToolCall fields + + Returns: + The recorded ToolCall + """ + action = ToolCall( + session_id=session_id, + tool_name=tool_name, + tool_input=tool_input, + tool_use_id=tool_use_id, + **kwargs, + ) + return self.record_action(action) # type: ignore + + def record_tool_result( + self, + session_id: str, + tool_use_id: str, + tool_name: str, + content: str | list[dict[str, Any]] | None = None, + is_error: bool = False, + error_message: str | None = None, + **kwargs: Any, + ) -> ToolResult: + """Convenience method to record a tool result. + + Args: + session_id: Session ID + tool_use_id: ID of the tool call + tool_name: Name of the tool + content: Result content + is_error: Whether the execution failed + error_message: Error message if failed + **kwargs: Additional ToolResult fields + + Returns: + The recorded ToolResult + """ + action = ToolResult( + session_id=session_id, + tool_use_id=tool_use_id, + tool_name=tool_name, + content=content, + is_error=is_error, + error_message=error_message, + **kwargs, + ) + return self.record_action(action) # type: ignore + + def record_message( + self, + session_id: str, + role: MessageRole, + content: str | list[dict[str, Any]], + **kwargs: Any, + ) -> Message: + """Convenience method to record a message. + + Args: + session_id: Session ID + role: Message role (user, assistant, system) + content: Message content + **kwargs: Additional Message fields + + Returns: + The recorded Message + """ + action = Message( + session_id=session_id, + role=role, + content=content, + **kwargs, + ) + return self.record_action(action) # type: ignore + + def get_action(self, action_id: str) -> Action | None: + """Retrieve an action by ID. + + Args: + action_id: Unique action identifier + + Returns: + Action object or None if not found + """ + rows = self._db.query( + """ + MATCH (a:Action {action_id: $action_id}) + RETURN a.action_id AS action_id, + a.session_id AS session_id, + a.action_type AS action_type, + a.timestamp AS timestamp, + a.status AS status, + a.duration_ms AS duration_ms, + a.parent_action_id AS parent_action_id, + a.metadata AS metadata, + a.properties AS properties, + labels(a) AS labels + """, + params={"action_id": action_id}, + ) + + if not rows: + return None + + return self._row_to_action(rows[0]) + + def get_session_actions( + self, + session_id: str, + *, + action_type: ActionType | None = None, + limit: int = 1000, + ) -> list[Action]: + """Get all actions for a session. + + Args: + session_id: Session ID + action_type: Filter by action type + limit: Maximum number of actions + + Returns: + List of actions ordered by timestamp + """ + where_clauses = ["a.session_id = $session_id"] + params: dict[str, Any] = {"session_id": session_id, "limit": limit} + + if action_type: + where_clauses.append("a.action_type = $action_type") + params["action_type"] = action_type.value + + where_str = f"WHERE {' AND '.join(where_clauses)}" + + rows = self._db.query( + f""" + MATCH (a:Action) + {where_str} + RETURN a.action_id AS action_id, + a.session_id AS session_id, + a.action_type AS action_type, + a.timestamp AS timestamp, + a.status AS status, + a.duration_ms AS duration_ms, + a.parent_action_id AS parent_action_id, + a.metadata AS metadata, + a.properties AS properties, + labels(a) AS labels + ORDER BY a.timestamp + LIMIT $limit + """, + params=params, + ) + + return [self._row_to_action(row) for row in rows] + + # ------------------------------------------------------------------ + # Analytics + # ------------------------------------------------------------------ + + def get_tool_usage_stats( + self, + session_id: str | None = None, + ) -> list[dict[str, Any]]: + """Get tool usage statistics. + + Args: + session_id: Optional session filter + + Returns: + List of tool usage statistics + """ + where_clause = "" + params: dict[str, Any] = {} + + if session_id: + where_clause = "WHERE a.session_id = $session_id" + params["session_id"] = session_id + + rows = self._db.query( + f""" + MATCH (a:Action:ToolCall)-[:USED_TOOL]->(t:Tool) + {where_clause} + RETURN t.name AS tool_name, + t.is_mcp AS is_mcp, + t.mcp_server AS mcp_server, + count(a) AS call_count, + avg(a.duration_ms) AS avg_duration_ms, + sum(CASE WHEN a.status = 'failed' THEN 1 ELSE 0 END) AS error_count + ORDER BY call_count DESC + """, + params=params, + ) + + return [dict(row) for row in rows] + + def get_action_sequence( + self, + session_id: str, + *, + include_content: bool = False, + ) -> list[dict[str, Any]]: + """Get the sequence of actions in a session. + + Args: + session_id: Session ID + include_content: Whether to include full action content + + Returns: + List of actions in sequence with relationships + """ + rows = self._db.query( + """ + MATCH (s:Session {session_id: $session_id})-[:HAS_ACTION]->(a:Action) + OPTIONAL MATCH (a)-[:FOLLOWED_BY]->(next:Action) + OPTIONAL MATCH (a)-[:PARENT_OF]->(child:Action) + RETURN a.action_id AS action_id, + a.action_type AS action_type, + a.timestamp AS timestamp, + a.status AS status, + a.properties AS properties, + next.action_id AS next_action_id, + collect(DISTINCT child.action_id) AS child_action_ids + ORDER BY a.timestamp + """, + params={"session_id": session_id}, + ) + + result = [] + for row in rows: + item = { + "action_id": row["action_id"], + "action_type": row["action_type"], + "timestamp": row["timestamp"], + "status": row["status"], + "next_action_id": row["next_action_id"], + "child_action_ids": row["child_action_ids"], + } + if include_content: + item["properties"] = json.loads(row["properties"]) if row["properties"] else {} + result.append(item) + + return result + + def get_session_summary(self, session_id: str) -> dict[str, Any]: + """Get a summary of a session. + + Args: + session_id: Session ID + + Returns: + Summary statistics for the session + """ + rows = self._db.query( + """ + MATCH (s:Session {session_id: $session_id}) + OPTIONAL MATCH (s)-[:HAS_ACTION]->(a:Action) + RETURN s.session_id AS session_id, + s.started_at AS started_at, + s.ended_at AS ended_at, + s.status AS status, + s.model AS model, + s.total_cost_usd AS total_cost_usd, + s.total_input_tokens AS total_input_tokens, + s.total_output_tokens AS total_output_tokens, + count(a) AS action_count, + sum(CASE WHEN a.action_type = 'tool_call' THEN 1 ELSE 0 END) AS tool_call_count, + sum(CASE WHEN a.action_type = 'user_message' THEN 1 ELSE 0 END) AS user_message_count, + sum(CASE WHEN a.action_type = 'assistant_message' THEN 1 ELSE 0 END) AS assistant_message_count, + sum(CASE WHEN a.status = 'failed' THEN 1 ELSE 0 END) AS error_count + """, + params={"session_id": session_id}, + ) + + if not rows: + return {} + + row = rows[0] + return { + "session_id": row["session_id"], + "started_at": row["started_at"], + "ended_at": row["ended_at"], + "status": row["status"], + "model": row["model"], + "total_cost_usd": row["total_cost_usd"], + "total_input_tokens": row["total_input_tokens"], + "total_output_tokens": row["total_output_tokens"], + "action_count": row["action_count"], + "tool_call_count": row["tool_call_count"], + "user_message_count": row["user_message_count"], + "assistant_message_count": row["assistant_message_count"], + "error_count": row["error_count"], + } + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _get_type_labels(self, action: Action) -> list[str]: + """Get additional labels for an action based on its type.""" + labels = [] + if isinstance(action, ToolCall): + labels.append("ToolCall") + elif isinstance(action, ToolResult): + labels.append("ToolResult") + elif isinstance(action, Message): + labels.append("Message") + labels.append(action.role.value.title() + "Message") + elif isinstance(action, StructuredOutput): + labels.append("StructuredOutput") + elif isinstance(action, SubagentEvent): + labels.append("SubagentEvent") + elif isinstance(action, PermissionRequest): + labels.append("PermissionRequest") + elif isinstance(action, ErrorEvent): + labels.append("ErrorEvent") + elif isinstance(action, RateLimitEvent): + labels.append("RateLimitEvent") + return labels + + def _action_to_props(self, action: Action) -> dict[str, Any]: + """Extract type-specific properties from an action.""" + props: dict[str, Any] = {} + + if isinstance(action, ToolCall): + props["tool_name"] = action.tool_name + props["tool_input"] = action.tool_input + props["tool_use_id"] = action.tool_use_id + props["is_mcp"] = action.is_mcp + props["mcp_server"] = action.mcp_server + elif isinstance(action, ToolResult): + props["tool_use_id"] = action.tool_use_id + props["tool_name"] = action.tool_name + props["content"] = action.content + props["is_error"] = action.is_error + props["error_message"] = action.error_message + elif isinstance(action, Message): + props["role"] = action.role.value + props["content"] = action.content + props["message_id"] = action.message_id + props["model"] = action.model + props["usage"] = action.usage + elif isinstance(action, StructuredOutput): + props["output_type"] = action.output_type + props["output_data"] = action.output_data + props["schema"] = action.schema + props["validation_passed"] = action.validation_passed + elif isinstance(action, SubagentEvent): + props["agent_id"] = action.agent_id + props["agent_type"] = action.agent_type + props["description"] = action.description + props["result"] = action.result + props["usage"] = action.usage + elif isinstance(action, PermissionRequest): + props["tool_name"] = action.tool_name + props["tool_input"] = action.tool_input + props["decision"] = action.decision + props["reason"] = action.reason + elif isinstance(action, ErrorEvent): + props["error_type"] = action.error_type + props["error_message"] = action.error_message + props["error_details"] = action.error_details + props["recoverable"] = action.recoverable + elif isinstance(action, RateLimitEvent): + props["rate_limit_status"] = action.rate_limit_status + props["rate_limit_type"] = action.rate_limit_type + props["resets_at"] = action.resets_at + props["utilization"] = action.utilization + + return props + + def _row_to_session(self, row: dict[str, Any]) -> Session: + """Convert a database row to a Session object.""" + metadata = row.get("metadata") + if isinstance(metadata, str): + metadata = json.loads(metadata) + + return Session( + session_id=row["session_id"], + started_at=row["started_at"], + ended_at=row.get("ended_at"), + status=ActionStatus(row["status"]) if row.get("status") else ActionStatus.IN_PROGRESS, + model=row.get("model"), + total_cost_usd=row.get("total_cost_usd"), + total_input_tokens=row.get("total_input_tokens", 0), + total_output_tokens=row.get("total_output_tokens", 0), + working_directory=row.get("working_directory"), + git_branch=row.get("git_branch"), + tags=row.get("tags", []), + metadata=metadata or {}, + parent_session_id=row.get("parent_session_id"), + ) + + def _row_to_action(self, row: dict[str, Any]) -> Action: + """Convert a database row to an Action object.""" + action_type = ActionType(row["action_type"]) + props = json.loads(row["properties"]) if row.get("properties") else {} + metadata = json.loads(row["metadata"]) if row.get("metadata") else {} + + base_kwargs = { + "action_id": row["action_id"], + "session_id": row["session_id"], + "timestamp": row["timestamp"], + "status": ActionStatus(row["status"]) if row.get("status") else ActionStatus.COMPLETED, + "duration_ms": row.get("duration_ms"), + "parent_action_id": row.get("parent_action_id"), + "metadata": metadata, + } + + if action_type == ActionType.TOOL_CALL: + return ToolCall( + **base_kwargs, + tool_name=props.get("tool_name", ""), + tool_input=props.get("tool_input", {}), + tool_use_id=props.get("tool_use_id"), + is_mcp=props.get("is_mcp", False), + mcp_server=props.get("mcp_server"), + ) + elif action_type == ActionType.TOOL_RESULT: + return ToolResult( + **base_kwargs, + tool_use_id=props.get("tool_use_id", ""), + tool_name=props.get("tool_name", ""), + content=props.get("content"), + is_error=props.get("is_error", False), + error_message=props.get("error_message"), + ) + elif action_type in ( + ActionType.USER_MESSAGE, + ActionType.ASSISTANT_MESSAGE, + ActionType.SYSTEM_MESSAGE, + ): + role = MessageRole(props.get("role", "user")) + return Message( + **base_kwargs, + role=role, + content=props.get("content", ""), + message_id=props.get("message_id"), + model=props.get("model"), + usage=props.get("usage"), + ) + elif action_type == ActionType.STRUCTURED_OUTPUT: + return StructuredOutput( + **base_kwargs, + output_type=props.get("output_type", ""), + output_data=props.get("output_data"), + schema=props.get("schema"), + validation_passed=props.get("validation_passed", True), + ) + elif action_type in (ActionType.SUBAGENT_START, ActionType.SUBAGENT_STOP): + event = SubagentEvent( + **base_kwargs, + agent_id=props.get("agent_id", ""), + agent_type=props.get("agent_type", ""), + description=props.get("description", ""), + result=props.get("result"), + usage=props.get("usage"), + ) + event.action_type = action_type + return event + elif action_type == ActionType.PERMISSION_REQUEST: + return PermissionRequest( + **base_kwargs, + tool_name=props.get("tool_name", ""), + tool_input=props.get("tool_input", {}), + decision=props.get("decision"), + reason=props.get("reason"), + ) + elif action_type == ActionType.ERROR: + return ErrorEvent( + **base_kwargs, + error_type=props.get("error_type", ""), + error_message=props.get("error_message", ""), + error_details=props.get("error_details"), + recoverable=props.get("recoverable", True), + ) + elif action_type == ActionType.RATE_LIMIT: + return RateLimitEvent( + **base_kwargs, + rate_limit_status=props.get("rate_limit_status", ""), + rate_limit_type=props.get("rate_limit_type"), + resets_at=props.get("resets_at"), + utilization=props.get("utilization"), + ) + else: + return Action(**base_kwargs, action_type=action_type) diff --git a/context-graph/actions-graph/src/actions_graph/hooks.py b/context-graph/actions-graph/src/actions_graph/hooks.py new file mode 100644 index 00000000..5e2710b7 --- /dev/null +++ b/context-graph/actions-graph/src/actions_graph/hooks.py @@ -0,0 +1,568 @@ +"""Integration hooks for Claude Agent SDK. + +This module provides hooks that automatically track LLM actions +from the Claude Agent SDK and persist them to Memgraph. + +Usage: + from actions_graph import ActionsGraph + from actions_graph.hooks import create_tracking_hooks + + graph = ActionsGraph() + hooks = create_tracking_hooks(graph, session_id="my-session") + + # Use with Claude Agent SDK + options = ClaudeAgentOptions( + hooks=hooks, + ... + ) +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from .models import ( + ActionStatus, + ActionType, + ErrorEvent, + Message, + MessageRole, + PermissionRequest, + RateLimitEvent, + Session, + SubagentEvent, + ToolCall, + ToolResult, +) + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from .core import ActionsGraph + + # Type hints for Claude Agent SDK (optional dependency) + HookCallback = Callable[[dict[str, Any], str | None, dict[str, Any]], Awaitable[dict[str, Any]]] + + +class ActionTracker: + """Tracks LLM actions and persists them to Memgraph. + + This class provides hook callbacks compatible with the Claude Agent SDK + that automatically record all actions to an ActionsGraph instance. + """ + + def __init__( + self, + graph: ActionsGraph, + session_id: str, + *, + track_tool_calls: bool = True, + track_tool_results: bool = True, + track_messages: bool = True, + track_subagents: bool = True, + track_permissions: bool = True, + track_errors: bool = True, + track_rate_limits: bool = True, + ): + """Initialize ActionTracker. + + Args: + graph: ActionsGraph instance for persistence + session_id: Session ID to associate actions with + track_tool_calls: Whether to track PreToolUse events + track_tool_results: Whether to track PostToolUse events + track_messages: Whether to track UserPromptSubmit events + track_subagents: Whether to track SubagentStart/Stop events + track_permissions: Whether to track PermissionRequest events + track_errors: Whether to track PostToolUseFailure events + track_rate_limits: Whether to track rate limit events + """ + self.graph = graph + self.session_id = session_id + self.track_tool_calls = track_tool_calls + self.track_tool_results = track_tool_results + self.track_messages = track_messages + self.track_subagents = track_subagents + self.track_permissions = track_permissions + self.track_errors = track_errors + self.track_rate_limits = track_rate_limits + + # Track tool use IDs to parent action IDs + self._tool_use_to_action: dict[str, str] = {} + + async def pre_tool_use( + self, + input_data: dict[str, Any], + tool_use_id: str | None, + _context: dict[str, Any], + ) -> dict[str, Any]: + """Hook callback for PreToolUse events. + + Records tool calls before they are executed. + """ + if not self.track_tool_calls: + return {} + + tool_name = input_data.get("tool_name", "") + tool_input = input_data.get("tool_input", {}) + actual_tool_use_id = input_data.get("tool_use_id") or tool_use_id + + # Get parent action from subagent context + parent_action_id = None + agent_id = input_data.get("agent_id") + if agent_id: + # This is a subagent tool call + parent_action_id = self._tool_use_to_action.get(agent_id) + + action = ToolCall( + session_id=self.session_id, + tool_name=tool_name, + tool_input=tool_input, + tool_use_id=actual_tool_use_id, + status=ActionStatus.IN_PROGRESS, + parent_action_id=parent_action_id, + metadata={ + "agent_id": agent_id, + "agent_type": input_data.get("agent_type"), + "cwd": input_data.get("cwd"), + }, + ) + + self.graph.record_action(action) + + # Track tool use ID to action ID mapping + if actual_tool_use_id: + self._tool_use_to_action[actual_tool_use_id] = action.action_id + + return {} + + async def post_tool_use( + self, + input_data: dict[str, Any], + tool_use_id: str | None, + _context: dict[str, Any], + ) -> dict[str, Any]: + """Hook callback for PostToolUse events. + + Records tool results after execution. + """ + if not self.track_tool_results: + return {} + + tool_name = input_data.get("tool_name", "") + actual_tool_use_id = input_data.get("tool_use_id") or tool_use_id + tool_response = input_data.get("tool_response") + + # Determine if this is an error response + is_error = False + error_message = None + content = tool_response + + if isinstance(tool_response, dict): + is_error = tool_response.get("is_error", False) + error_message = tool_response.get("error") + content = tool_response.get("content", tool_response) + + # Find parent tool call action + parent_action_id = None + if actual_tool_use_id: + parent_action_id = self._tool_use_to_action.get(actual_tool_use_id) + + action = ToolResult( + session_id=self.session_id, + tool_use_id=actual_tool_use_id or "", + tool_name=tool_name, + content=content if isinstance(content, (str, list)) else str(content), + is_error=is_error, + error_message=error_message, + parent_action_id=parent_action_id, + metadata={ + "agent_id": input_data.get("agent_id"), + "agent_type": input_data.get("agent_type"), + }, + ) + + self.graph.record_action(action) + return {} + + async def post_tool_use_failure( + self, + input_data: dict[str, Any], + tool_use_id: str | None, + _context: dict[str, Any], + ) -> dict[str, Any]: + """Hook callback for PostToolUseFailure events. + + Records tool execution failures. + """ + if not self.track_errors: + return {} + + tool_name = input_data.get("tool_name", "") + actual_tool_use_id = input_data.get("tool_use_id") or tool_use_id + error = input_data.get("error", "Unknown error") + is_interrupt = input_data.get("is_interrupt", False) + + # Find parent tool call action + parent_action_id = None + if actual_tool_use_id: + parent_action_id = self._tool_use_to_action.get(actual_tool_use_id) + + action = ErrorEvent( + session_id=self.session_id, + error_type="tool_failure", + error_message=error, + error_details={ + "tool_name": tool_name, + "tool_use_id": actual_tool_use_id, + "tool_input": input_data.get("tool_input", {}), + "is_interrupt": is_interrupt, + }, + recoverable=not is_interrupt, + parent_action_id=parent_action_id, + ) + + self.graph.record_action(action) + return {} + + async def user_prompt_submit( + self, + input_data: dict[str, Any], + _tool_use_id: str | None, + _context: dict[str, Any], + ) -> dict[str, Any]: + """Hook callback for UserPromptSubmit events. + + Records user messages. + """ + if not self.track_messages: + return {} + + prompt = input_data.get("prompt", "") + + action = Message( + session_id=self.session_id, + role=MessageRole.USER, + content=prompt, + metadata={ + "cwd": input_data.get("cwd"), + "permission_mode": input_data.get("permission_mode"), + }, + ) + + self.graph.record_action(action) + return {} + + async def subagent_start( + self, + input_data: dict[str, Any], + _tool_use_id: str | None, + _context: dict[str, Any], + ) -> dict[str, Any]: + """Hook callback for SubagentStart events. + + Records when subagents start. + """ + if not self.track_subagents: + return {} + + agent_id = input_data.get("agent_id", "") + agent_type = input_data.get("agent_type", "") + + action = SubagentEvent( + session_id=self.session_id, + agent_id=agent_id, + agent_type=agent_type, + status=ActionStatus.IN_PROGRESS, + ) + action.action_type = ActionType.SUBAGENT_START + + self.graph.record_action(action) + + # Track agent_id to action mapping for nested tool calls + self._tool_use_to_action[agent_id] = action.action_id + + return {} + + async def subagent_stop( + self, + input_data: dict[str, Any], + _tool_use_id: str | None, + _context: dict[str, Any], + ) -> dict[str, Any]: + """Hook callback for SubagentStop events. + + Records when subagents complete. + """ + if not self.track_subagents: + return {} + + agent_id = input_data.get("agent_id", "") + agent_type = input_data.get("agent_type", "") + stop_hook_active = input_data.get("stop_hook_active", False) + + # Find the start event + parent_action_id = self._tool_use_to_action.get(agent_id) + + action = SubagentEvent( + session_id=self.session_id, + agent_id=agent_id, + agent_type=agent_type, + status=ActionStatus.COMPLETED, + parent_action_id=parent_action_id, + metadata={ + "stop_hook_active": stop_hook_active, + "agent_transcript_path": input_data.get("agent_transcript_path"), + }, + ) + action.action_type = ActionType.SUBAGENT_STOP + + self.graph.record_action(action) + return {} + + async def permission_request( + self, + input_data: dict[str, Any], + _tool_use_id: str | None, + _context: dict[str, Any], + ) -> dict[str, Any]: + """Hook callback for PermissionRequest events. + + Records permission requests. + """ + if not self.track_permissions: + return {} + + tool_name = input_data.get("tool_name", "") + tool_input = input_data.get("tool_input", {}) + + action = PermissionRequest( + session_id=self.session_id, + tool_name=tool_name, + tool_input=tool_input, + status=ActionStatus.PENDING, + metadata={ + "permission_suggestions": input_data.get("permission_suggestions", []), + }, + ) + + self.graph.record_action(action) + return {} + + async def notification( + self, + input_data: dict[str, Any], + _tool_use_id: str | None, + _context: dict[str, Any], + ) -> dict[str, Any]: + """Hook callback for Notification events. + + Records system notifications. + """ + message = input_data.get("message", "") + title = input_data.get("title") + notification_type = input_data.get("notification_type", "") + + action = Message( + session_id=self.session_id, + role=MessageRole.SYSTEM, + content=message, + metadata={ + "title": title, + "notification_type": notification_type, + }, + ) + + self.graph.record_action(action) + return {} + + async def stop( + self, + _input_data: dict[str, Any], + _tool_use_id: str | None, + _context: dict[str, Any], + ) -> dict[str, Any]: + """Hook callback for Stop events. + + Records when execution stops. + """ + # Update session as ended + self.graph.end_session(self.session_id, status=ActionStatus.COMPLETED) + return {} + + +def create_tracking_hooks( + graph: ActionsGraph, + session_id: str, + *, + create_session: bool = True, + session_kwargs: dict[str, Any] | None = None, + **tracker_kwargs: Any, +) -> dict[str, list[Any]]: + """Create Claude Agent SDK hooks for action tracking. + + This function creates a complete set of hooks that can be passed + directly to ClaudeAgentOptions.hooks. + + Args: + graph: ActionsGraph instance for persistence + session_id: Session ID to associate actions with + create_session: Whether to create the session in the graph + session_kwargs: Additional kwargs for Session creation + **tracker_kwargs: Additional kwargs for ActionTracker + + Returns: + Dictionary of hook event names to hook matchers, ready for + use with ClaudeAgentOptions.hooks + + Example: + from actions_graph import ActionsGraph + from actions_graph.hooks import create_tracking_hooks + from claude_agent_sdk import query, ClaudeAgentOptions + + graph = ActionsGraph() + graph.setup() + + hooks = create_tracking_hooks(graph, session_id="my-session-123") + + async for message in query( + prompt="Hello!", + options=ClaudeAgentOptions( + hooks=hooks, + allowed_tools=["Read", "Write"], + ), + ): + print(message) + """ + # Try to import HookMatcher from claude_agent_sdk + try: + from claude_agent_sdk import HookMatcher + except ImportError: + # Create a simple HookMatcher-like class if SDK not available + class HookMatcher: # type: ignore[no-redef] + def __init__( + self, + matcher: str | None = None, + hooks: list[Any] | None = None, + timeout: float | None = None, + ): + self.matcher = matcher + self.hooks = hooks or [] + self.timeout = timeout + + # Create session if requested + if create_session: + session_kwargs = session_kwargs or {} + session = Session(session_id=session_id, **session_kwargs) + graph.create_session(session) + + # Create tracker + tracker = ActionTracker(graph, session_id, **tracker_kwargs) + + # Build hooks dictionary + return { + "PreToolUse": [HookMatcher(hooks=[tracker.pre_tool_use])], + "PostToolUse": [HookMatcher(hooks=[tracker.post_tool_use])], + "PostToolUseFailure": [HookMatcher(hooks=[tracker.post_tool_use_failure])], + "UserPromptSubmit": [HookMatcher(hooks=[tracker.user_prompt_submit])], + "SubagentStart": [HookMatcher(hooks=[tracker.subagent_start])], + "SubagentStop": [HookMatcher(hooks=[tracker.subagent_stop])], + "PermissionRequest": [HookMatcher(hooks=[tracker.permission_request])], + "Notification": [HookMatcher(hooks=[tracker.notification])], + "Stop": [HookMatcher(hooks=[tracker.stop])], + } + + +def create_message_handler( + graph: ActionsGraph, + session_id: str, +) -> Callable[[Any], None]: + """Create a message handler for processing Claude Agent SDK messages. + + This handler can be used to process messages from the Claude Agent SDK + and record them to the ActionsGraph. + + Args: + graph: ActionsGraph instance for persistence + session_id: Session ID to associate actions with + + Returns: + A callable that processes SDK messages + + Example: + from actions_graph import ActionsGraph + from actions_graph.hooks import create_message_handler + from claude_agent_sdk import query + + graph = ActionsGraph() + handler = create_message_handler(graph, "my-session") + + async for message in query(prompt="Hello!"): + handler(message) + """ + + def handler(message: Any) -> None: + """Process a Claude Agent SDK message and record to graph.""" + # Handle different message types + message_type = getattr(message, "type", None) or type(message).__name__ + + if message_type == "AssistantMessage" or hasattr(message, "content"): + # Assistant message + content_blocks = getattr(message, "content", []) + model = getattr(message, "model", None) + usage = getattr(message, "usage", None) + message_id = getattr(message, "message_id", None) + + # Extract text content + text_content = [] + for block in content_blocks: + if hasattr(block, "text"): + text_content.append({"type": "text", "text": block.text}) + elif hasattr(block, "name"): + # Tool use block - already handled by hooks + pass + elif hasattr(block, "thinking"): + text_content.append({"type": "thinking", "thinking": block.thinking}) + + if text_content: + action = Message( + session_id=session_id, + role=MessageRole.ASSISTANT, + content=text_content, + model=model, + usage=dict(usage) if usage else None, + message_id=message_id, + ) + graph.record_action(action) + + elif message_type == "ResultMessage" or hasattr(message, "subtype"): + # Result message - update session + subtype = getattr(message, "subtype", "") + total_cost = getattr(message, "total_cost_usd", None) + usage = getattr(message, "usage", {}) + + status = ActionStatus.COMPLETED + if subtype.startswith("error"): + status = ActionStatus.FAILED + + graph.end_session( + session_id, + status=status, + total_cost_usd=total_cost, + total_input_tokens=usage.get("input_tokens") if usage else None, + total_output_tokens=usage.get("output_tokens") if usage else None, + ) + + elif message_type == "RateLimitEvent" or hasattr(message, "rate_limit_info"): + # Rate limit event + info = getattr(message, "rate_limit_info", message) + action = RateLimitEvent( + session_id=session_id, + rate_limit_status=getattr(info, "status", ""), + rate_limit_type=getattr(info, "rate_limit_type", None), + resets_at=getattr(info, "resets_at", None), + utilization=getattr(info, "utilization", None), + ) + graph.record_action(action) + + return handler diff --git a/context-graph/actions-graph/src/actions_graph/models.py b/context-graph/actions-graph/src/actions_graph/models.py new file mode 100644 index 00000000..cb49cfd8 --- /dev/null +++ b/context-graph/actions-graph/src/actions_graph/models.py @@ -0,0 +1,360 @@ +"""Data models for tracking LLM actions and sessions in Memgraph. + +This module defines the core data structures for representing: +- Sessions: LLM conversation sessions with metadata +- Actions: Base class for all trackable LLM actions +- ToolCalls: Invocations of tools by the LLM +- ToolResults: Results returned from tool executions +- Messages: User and assistant messages in the conversation +- StructuredOutputs: Validated structured outputs from the LLM + +These models are designed to be compatible with the Claude Agent SDK +and can be extended for other LLM frameworks. +""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from typing import Any +from uuid import uuid4 + + +class ActionValidationError(ValueError): + """Raised when an action field violates validation rules.""" + + +# Validation patterns +_SESSION_ID_RE = re.compile(r"^[a-zA-Z0-9_-]{1,128}$") +_ACTION_ID_RE = re.compile(r"^[a-zA-Z0-9_-]{1,128}$") + + +def _utc_now() -> str: + """Return current UTC timestamp in ISO format.""" + return datetime.now(timezone.utc).isoformat() + + +def _generate_id() -> str: + """Generate a unique action ID.""" + return str(uuid4()) + + +def validate_session_id(session_id: str) -> str: + """Validate session ID format.""" + if not session_id or not _SESSION_ID_RE.match(session_id): + raise ActionValidationError(f"session_id must match pattern {_SESSION_ID_RE.pattern}, got: {session_id!r}") + return session_id + + +def validate_action_id(action_id: str) -> str: + """Validate action ID format.""" + if not action_id or not _ACTION_ID_RE.match(action_id): + raise ActionValidationError(f"action_id must match pattern {_ACTION_ID_RE.pattern}, got: {action_id!r}") + return action_id + + +class ActionType(str, Enum): + """Types of actions that can be tracked.""" + + TOOL_CALL = "tool_call" + TOOL_RESULT = "tool_result" + USER_MESSAGE = "user_message" + ASSISTANT_MESSAGE = "assistant_message" + SYSTEM_MESSAGE = "system_message" + STRUCTURED_OUTPUT = "structured_output" + SUBAGENT_START = "subagent_start" + SUBAGENT_STOP = "subagent_stop" + SESSION_START = "session_start" + SESSION_END = "session_end" + ERROR = "error" + PERMISSION_REQUEST = "permission_request" + RATE_LIMIT = "rate_limit" + + +class ActionStatus(str, Enum): + """Status of an action.""" + + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + FAILED = "failed" + BLOCKED = "blocked" + DENIED = "denied" + + +class MessageRole(str, Enum): + """Role of a message sender.""" + + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + + +@dataclass +class Session: + """Represents an LLM conversation session. + + A session contains multiple actions and maintains state + across the conversation. + + Attributes: + session_id: Unique identifier for the session + started_at: ISO timestamp when session started + ended_at: ISO timestamp when session ended (if completed) + status: Current status of the session + model: LLM model used in the session + total_cost_usd: Estimated total cost in USD + total_input_tokens: Total input tokens consumed + total_output_tokens: Total output tokens generated + working_directory: Working directory for the session + git_branch: Git branch at start of session + tags: Optional tags for categorization + metadata: Additional session metadata + parent_session_id: ID of parent session if forked + """ + + session_id: str + started_at: str = field(default_factory=_utc_now) + ended_at: str | None = None + status: ActionStatus = ActionStatus.IN_PROGRESS + model: str | None = None + total_cost_usd: float | None = None + total_input_tokens: int = 0 + total_output_tokens: int = 0 + working_directory: str | None = None + git_branch: str | None = None + tags: list[str] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + parent_session_id: str | None = None + + def __post_init__(self): + validate_session_id(self.session_id) + if self.parent_session_id: + validate_session_id(self.parent_session_id) + + +@dataclass +class Action: + """Base class for all trackable LLM actions. + + Actions represent discrete events in an LLM session, such as + tool calls, messages, or structured outputs. + + Attributes: + action_id: Unique identifier for the action + session_id: ID of the session this action belongs to + action_type: Type of the action + timestamp: ISO timestamp when action occurred + status: Current status of the action + duration_ms: Duration in milliseconds (if applicable) + parent_action_id: ID of parent action (for nested actions) + metadata: Additional action metadata + """ + + action_id: str = field(default_factory=_generate_id) + session_id: str = "" + action_type: ActionType = ActionType.TOOL_CALL + timestamp: str = field(default_factory=_utc_now) + status: ActionStatus = ActionStatus.COMPLETED + duration_ms: int | None = None + parent_action_id: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + validate_action_id(self.action_id) + if self.session_id: + validate_session_id(self.session_id) + if self.parent_action_id: + validate_action_id(self.parent_action_id) + + +@dataclass +class ToolCall(Action): + """Represents a tool/function call made by the LLM. + + Attributes: + tool_name: Name of the tool being called + tool_input: Input parameters passed to the tool + tool_use_id: Unique ID for this tool use (for correlating results) + is_mcp: Whether this is an MCP (Model Context Protocol) tool + mcp_server: MCP server name if applicable + """ + + tool_name: str = "" + tool_input: dict[str, Any] = field(default_factory=dict) + tool_use_id: str | None = None + is_mcp: bool = False + mcp_server: str | None = None + + def __post_init__(self): + super().__post_init__() + self.action_type = ActionType.TOOL_CALL + if self.tool_name.startswith("mcp__"): + self.is_mcp = True + parts = self.tool_name.split("__") + if len(parts) >= 2: + self.mcp_server = parts[1] + + +@dataclass +class ToolResult(Action): + """Represents the result of a tool execution. + + Attributes: + tool_use_id: ID of the tool call this result corresponds to + tool_name: Name of the tool that was executed + content: Result content (text or structured) + is_error: Whether the tool execution resulted in an error + error_message: Error message if is_error is True + """ + + tool_use_id: str = "" + tool_name: str = "" + content: str | list[dict[str, Any]] | None = None + is_error: bool = False + error_message: str | None = None + + def __post_init__(self): + super().__post_init__() + self.action_type = ActionType.TOOL_RESULT + if self.is_error: + self.status = ActionStatus.FAILED + + +@dataclass +class Message(Action): + """Represents a message in the conversation. + + Attributes: + role: Role of the message sender + content: Message content (text or content blocks) + message_id: API message ID (if available) + model: Model that generated the message (for assistant messages) + usage: Token usage for this message + """ + + role: MessageRole = MessageRole.USER + content: str | list[dict[str, Any]] = "" + message_id: str | None = None + model: str | None = None + usage: dict[str, int] | None = None + + def __post_init__(self): + super().__post_init__() + if self.role == MessageRole.USER: + self.action_type = ActionType.USER_MESSAGE + elif self.role == MessageRole.ASSISTANT: + self.action_type = ActionType.ASSISTANT_MESSAGE + else: + self.action_type = ActionType.SYSTEM_MESSAGE + + +@dataclass +class StructuredOutput(Action): + """Represents a validated structured output from the LLM. + + Attributes: + output_type: Type/schema name of the structured output + output_data: The structured output data + schema: JSON schema used for validation (if any) + validation_passed: Whether the output passed validation + """ + + output_type: str = "" + output_data: Any = None + schema: dict[str, Any] | None = None + validation_passed: bool = True + + def __post_init__(self): + super().__post_init__() + self.action_type = ActionType.STRUCTURED_OUTPUT + + +@dataclass +@dataclass +class SubagentEvent(Action): + """Represents a subagent lifecycle event. + + Attributes: + agent_id: Unique identifier for the subagent + agent_type: Type of the subagent + description: Description of the subagent task + result: Result from subagent (if completed) + usage: Token usage for the subagent + + Note: + The action_type should be set by the caller to either + SUBAGENT_START or SUBAGENT_STOP after construction. + """ + + agent_id: str = "" + agent_type: str = "" + description: str = "" + result: str | None = None + usage: dict[str, Any] | None = None + + +@dataclass +class PermissionRequest(Action): + """Represents a permission request event. + + Attributes: + tool_name: Tool requesting permission + tool_input: Input for the tool + decision: Permission decision (allow/deny/ask) + reason: Reason for the decision + """ + + tool_name: str = "" + tool_input: dict[str, Any] = field(default_factory=dict) + decision: str | None = None + reason: str | None = None + + def __post_init__(self): + super().__post_init__() + self.action_type = ActionType.PERMISSION_REQUEST + + +@dataclass +class ErrorEvent(Action): + """Represents an error that occurred during execution. + + Attributes: + error_type: Type/classification of the error + error_message: Human-readable error message + error_details: Additional error details + recoverable: Whether the error is recoverable + """ + + error_type: str = "" + error_message: str = "" + error_details: dict[str, Any] | None = None + recoverable: bool = True + + def __post_init__(self): + super().__post_init__() + self.action_type = ActionType.ERROR + self.status = ActionStatus.FAILED + + +@dataclass +class RateLimitEvent(Action): + """Represents a rate limit event. + + Attributes: + rate_limit_status: Status (allowed, allowed_warning, rejected) + rate_limit_type: Type of rate limit + resets_at: Unix timestamp when limit resets + utilization: Fraction of limit consumed (0.0 to 1.0) + """ + + rate_limit_status: str = "" + rate_limit_type: str | None = None + resets_at: int | None = None + utilization: float | None = None + + def __post_init__(self): + super().__post_init__() + self.action_type = ActionType.RATE_LIMIT diff --git a/context-graph/actions-graph/tests/__init__.py b/context-graph/actions-graph/tests/__init__.py new file mode 100644 index 00000000..9ce957de --- /dev/null +++ b/context-graph/actions-graph/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for actions_graph package.""" diff --git a/context-graph/actions-graph/tests/test_e2e.py b/context-graph/actions-graph/tests/test_e2e.py new file mode 100644 index 00000000..82a309df --- /dev/null +++ b/context-graph/actions-graph/tests/test_e2e.py @@ -0,0 +1,318 @@ +"""End-to-end tests for actions_graph with Memgraph. + +These tests require a running Memgraph instance. +""" + +import pytest +from actions_graph import ( + ActionsGraph, + ActionStatus, + ActionType, + MessageRole, + Session, + ToolCall, +) + + +@pytest.fixture +def graph(): + """Create a fresh ActionsGraph instance for testing.""" + import contextlib + + g = ActionsGraph() + with contextlib.suppress(Exception): + g.setup() # Constraints may already exist + g.clear() + yield g + g.clear() + + +class TestActionsGraphSetup: + """Tests for ActionsGraph setup and teardown.""" + + def test_setup_and_drop(self, graph: ActionsGraph): + """Test setting up and dropping the schema.""" + # Schema should already be set up by fixture + # Just verify we can create a session + session = Session(session_id="test-setup-session") + graph.create_session(session) + retrieved = graph.get_session("test-setup-session") + assert retrieved is not None + + +class TestSessionOperations: + """Tests for session CRUD operations.""" + + def test_create_and_get_session(self, graph: ActionsGraph): + """Test creating and retrieving a session.""" + session = Session( + session_id="test-session-001", + model="claude-sonnet-4-20250514", + working_directory="/test/project", + tags=["test", "e2e"], + ) + graph.create_session(session) + + retrieved = graph.get_session("test-session-001") + assert retrieved is not None + assert retrieved.session_id == "test-session-001" + assert retrieved.model == "claude-sonnet-4-20250514" + assert "test" in retrieved.tags + + def test_end_session(self, graph: ActionsGraph): + """Test ending a session.""" + session = Session(session_id="test-end-session") + graph.create_session(session) + + ended = graph.end_session( + "test-end-session", + status=ActionStatus.COMPLETED, + total_cost_usd=0.05, + total_input_tokens=1000, + total_output_tokens=500, + ) + + assert ended is not None + assert ended.status == ActionStatus.COMPLETED + assert ended.ended_at is not None + assert ended.total_cost_usd == 0.05 + + def test_list_sessions(self, graph: ActionsGraph): + """Test listing sessions.""" + for i in range(3): + session = Session( + session_id=f"list-test-{i}", + tags=["list-test"], + ) + graph.create_session(session) + + sessions = graph.list_sessions(tag="list-test") + assert len(sessions) == 3 + + def test_forked_session(self, graph: ActionsGraph): + """Test creating a forked session.""" + # Create parent session + parent = Session(session_id="parent-session") + graph.create_session(parent) + + # Create forked session + forked = Session( + session_id="forked-session", + parent_session_id="parent-session", + ) + graph.create_session(forked) + + retrieved = graph.get_session("forked-session") + assert retrieved is not None + assert retrieved.parent_session_id == "parent-session" + + +class TestActionOperations: + """Tests for action CRUD operations.""" + + def test_record_tool_call(self, graph: ActionsGraph): + """Test recording a tool call.""" + session = Session(session_id="tool-call-session") + graph.create_session(session) + + tool_call = graph.record_tool_call( + session_id="tool-call-session", + tool_name="Read", + tool_input={"file_path": "/test/file.py"}, + tool_use_id="tool-001", + ) + + assert tool_call.tool_name == "Read" + assert tool_call.action_type == ActionType.TOOL_CALL + + # Retrieve and verify + retrieved = graph.get_action(tool_call.action_id) + assert retrieved is not None + assert isinstance(retrieved, ToolCall) + assert retrieved.tool_name == "Read" + + def test_record_tool_result(self, graph: ActionsGraph): + """Test recording a tool result.""" + session = Session(session_id="tool-result-session") + graph.create_session(session) + + result = graph.record_tool_result( + session_id="tool-result-session", + tool_use_id="tool-001", + tool_name="Read", + content="file contents", + ) + + assert result.content == "file contents" + assert result.is_error is False + + def test_record_message(self, graph: ActionsGraph): + """Test recording a message.""" + session = Session(session_id="message-session") + graph.create_session(session) + + message = graph.record_message( + session_id="message-session", + role=MessageRole.USER, + content="Hello!", + ) + + assert message.role == MessageRole.USER + assert message.action_type == ActionType.USER_MESSAGE + + def test_action_sequence(self, graph: ActionsGraph): + """Test that actions form a sequence with FOLLOWED_BY.""" + session = Session(session_id="sequence-session") + graph.create_session(session) + + # Record multiple actions + graph.record_message( + session_id="sequence-session", + role=MessageRole.USER, + content="First message", + ) + graph.record_tool_call( + session_id="sequence-session", + tool_name="Read", + tool_input={"file_path": "/test.py"}, + ) + graph.record_message( + session_id="sequence-session", + role=MessageRole.ASSISTANT, + content="Response", + ) + + # Get sequence + sequence = graph.get_action_sequence("sequence-session") + assert len(sequence) == 3 + + # Verify FOLLOWED_BY relationships + assert sequence[0]["next_action_id"] == sequence[1]["action_id"] + assert sequence[1]["next_action_id"] == sequence[2]["action_id"] + + def test_get_session_actions(self, graph: ActionsGraph): + """Test getting all actions for a session.""" + session = Session(session_id="get-actions-session") + graph.create_session(session) + + # Record several actions + for i in range(5): + graph.record_tool_call( + session_id="get-actions-session", + tool_name=f"Tool{i}", + tool_input={}, + ) + + actions = graph.get_session_actions("get-actions-session") + assert len(actions) == 5 + + def test_filter_actions_by_type(self, graph: ActionsGraph): + """Test filtering actions by type.""" + session = Session(session_id="filter-actions-session") + graph.create_session(session) + + # Record mixed actions + graph.record_message( + session_id="filter-actions-session", + role=MessageRole.USER, + content="Message", + ) + graph.record_tool_call( + session_id="filter-actions-session", + tool_name="Read", + tool_input={}, + ) + graph.record_tool_call( + session_id="filter-actions-session", + tool_name="Write", + tool_input={}, + ) + + # Filter by tool call + tool_calls = graph.get_session_actions( + "filter-actions-session", + action_type=ActionType.TOOL_CALL, + ) + assert len(tool_calls) == 2 + + +class TestAnalytics: + """Tests for analytics queries.""" + + def test_tool_usage_stats(self, graph: ActionsGraph): + """Test getting tool usage statistics.""" + session = Session(session_id="stats-session") + graph.create_session(session) + + # Record tool calls + for _ in range(3): + graph.record_tool_call( + session_id="stats-session", + tool_name="Read", + tool_input={}, + ) + for _ in range(2): + graph.record_tool_call( + session_id="stats-session", + tool_name="Write", + tool_input={}, + ) + + stats = graph.get_tool_usage_stats("stats-session") + assert len(stats) == 2 + + # Read should have more calls + read_stats = next(s for s in stats if s["tool_name"] == "Read") + assert read_stats["call_count"] == 3 + + def test_session_summary(self, graph: ActionsGraph): + """Test getting a session summary.""" + session = Session(session_id="summary-session") + graph.create_session(session) + + # Add various actions + graph.record_message( + session_id="summary-session", + role=MessageRole.USER, + content="Hello", + ) + graph.record_tool_call( + session_id="summary-session", + tool_name="Read", + tool_input={}, + ) + graph.record_message( + session_id="summary-session", + role=MessageRole.ASSISTANT, + content="Hi!", + ) + + summary = graph.get_session_summary("summary-session") + assert summary["action_count"] == 3 + assert summary["user_message_count"] == 1 + assert summary["assistant_message_count"] == 1 + assert summary["tool_call_count"] == 1 + + +class TestMCPTools: + """Tests for MCP tool handling.""" + + def test_mcp_tool_tracking(self, graph: ActionsGraph): + """Test that MCP tools are correctly identified and tracked.""" + session = Session(session_id="mcp-session") + graph.create_session(session) + + tool_call = graph.record_tool_call( + session_id="mcp-session", + tool_name="mcp__playwright__browser_click", + tool_input={"selector": "button"}, + ) + + assert tool_call.is_mcp is True + assert tool_call.mcp_server == "playwright" + + # Verify in tool stats + stats = graph.get_tool_usage_stats("mcp-session") + assert len(stats) == 1 + assert stats[0]["is_mcp"] is True + assert stats[0]["mcp_server"] == "playwright" diff --git a/context-graph/actions-graph/tests/test_models.py b/context-graph/actions-graph/tests/test_models.py new file mode 100644 index 00000000..f2846657 --- /dev/null +++ b/context-graph/actions-graph/tests/test_models.py @@ -0,0 +1,256 @@ +"""Tests for actions_graph.models module.""" + +import pytest +from actions_graph.models import ( + ActionStatus, + ActionType, + ActionValidationError, + ErrorEvent, + Message, + MessageRole, + PermissionRequest, + RateLimitEvent, + Session, + StructuredOutput, + SubagentEvent, + ToolCall, + ToolResult, +) + + +class TestSession: + """Tests for Session model.""" + + def test_create_session(self): + """Test creating a valid session.""" + session = Session(session_id="test-session-123") + assert session.session_id == "test-session-123" + assert session.status == ActionStatus.IN_PROGRESS + assert session.started_at is not None + assert session.ended_at is None + + def test_create_session_with_all_fields(self): + """Test creating a session with all fields.""" + session = Session( + session_id="test-session-456", + model="claude-sonnet-4-20250514", + working_directory="/path/to/project", + git_branch="main", + tags=["test", "demo"], + metadata={"key": "value"}, + ) + assert session.model == "claude-sonnet-4-20250514" + assert session.working_directory == "/path/to/project" + assert session.git_branch == "main" + assert session.tags == ["test", "demo"] + assert session.metadata == {"key": "value"} + + def test_invalid_session_id(self): + """Test that invalid session IDs raise an error.""" + with pytest.raises(ActionValidationError): + Session(session_id="") + + with pytest.raises(ActionValidationError): + Session(session_id="invalid session id with spaces") + + def test_forked_session(self): + """Test creating a forked session.""" + session = Session( + session_id="forked-session", + parent_session_id="original-session", + ) + assert session.parent_session_id == "original-session" + + +class TestToolCall: + """Tests for ToolCall model.""" + + def test_create_tool_call(self): + """Test creating a tool call.""" + tool_call = ToolCall( + session_id="session-123", + tool_name="Read", + tool_input={"file_path": "/path/to/file.py"}, + tool_use_id="tool-use-001", + ) + assert tool_call.tool_name == "Read" + assert tool_call.tool_input == {"file_path": "/path/to/file.py"} + assert tool_call.tool_use_id == "tool-use-001" + assert tool_call.action_type == ActionType.TOOL_CALL + + def test_mcp_tool_detection(self): + """Test automatic MCP tool detection.""" + tool_call = ToolCall( + session_id="session-123", + tool_name="mcp__playwright__browser_click", + tool_input={"selector": "button"}, + ) + assert tool_call.is_mcp is True + assert tool_call.mcp_server == "playwright" + + def test_non_mcp_tool(self): + """Test non-MCP tool detection.""" + tool_call = ToolCall( + session_id="session-123", + tool_name="Write", + tool_input={"content": "hello"}, + ) + assert tool_call.is_mcp is False + assert tool_call.mcp_server is None + + +class TestToolResult: + """Tests for ToolResult model.""" + + def test_create_tool_result(self): + """Test creating a tool result.""" + result = ToolResult( + session_id="session-123", + tool_use_id="tool-use-001", + tool_name="Read", + content="file contents here", + ) + assert result.tool_use_id == "tool-use-001" + assert result.content == "file contents here" + assert result.is_error is False + assert result.action_type == ActionType.TOOL_RESULT + + def test_error_result(self): + """Test creating an error result.""" + result = ToolResult( + session_id="session-123", + tool_use_id="tool-use-002", + tool_name="Write", + is_error=True, + error_message="Permission denied", + ) + assert result.is_error is True + assert result.error_message == "Permission denied" + assert result.status == ActionStatus.FAILED + + +class TestMessage: + """Tests for Message model.""" + + def test_create_user_message(self): + """Test creating a user message.""" + message = Message( + session_id="session-123", + role=MessageRole.USER, + content="Hello, Claude!", + ) + assert message.role == MessageRole.USER + assert message.content == "Hello, Claude!" + assert message.action_type == ActionType.USER_MESSAGE + + def test_create_assistant_message(self): + """Test creating an assistant message.""" + message = Message( + session_id="session-123", + role=MessageRole.ASSISTANT, + content=[{"type": "text", "text": "Hello!"}], + model="claude-sonnet-4-20250514", + ) + assert message.role == MessageRole.ASSISTANT + assert message.action_type == ActionType.ASSISTANT_MESSAGE + assert message.model == "claude-sonnet-4-20250514" + + def test_create_system_message(self): + """Test creating a system message.""" + message = Message( + session_id="session-123", + role=MessageRole.SYSTEM, + content="Session started", + ) + assert message.action_type == ActionType.SYSTEM_MESSAGE + + +class TestStructuredOutput: + """Tests for StructuredOutput model.""" + + def test_create_structured_output(self): + """Test creating a structured output.""" + output = StructuredOutput( + session_id="session-123", + output_type="code_review", + output_data={"issues": [], "suggestions": ["Add tests"]}, + validation_passed=True, + ) + assert output.output_type == "code_review" + assert output.output_data["suggestions"] == ["Add tests"] + assert output.action_type == ActionType.STRUCTURED_OUTPUT + + +class TestSubagentEvent: + """Tests for SubagentEvent model.""" + + def test_create_subagent_start(self): + """Test creating a subagent start event.""" + event = SubagentEvent( + session_id="session-123", + agent_id="subagent-001", + agent_type="code-reviewer", + description="Review the code changes", + ) + event.action_type = ActionType.SUBAGENT_START + assert event.agent_id == "subagent-001" + assert event.action_type == ActionType.SUBAGENT_START + + def test_create_subagent_stop(self): + """Test creating a subagent stop event.""" + event = SubagentEvent( + session_id="session-123", + agent_id="subagent-001", + agent_type="code-reviewer", + result="Review completed successfully", + ) + event.action_type = ActionType.SUBAGENT_STOP + assert event.result == "Review completed successfully" + + +class TestErrorEvent: + """Tests for ErrorEvent model.""" + + def test_create_error_event(self): + """Test creating an error event.""" + error = ErrorEvent( + session_id="session-123", + error_type="api_error", + error_message="Rate limit exceeded", + recoverable=True, + ) + assert error.error_type == "api_error" + assert error.error_message == "Rate limit exceeded" + assert error.recoverable is True + assert error.status == ActionStatus.FAILED + assert error.action_type == ActionType.ERROR + + +class TestPermissionRequest: + """Tests for PermissionRequest model.""" + + def test_create_permission_request(self): + """Test creating a permission request.""" + request = PermissionRequest( + session_id="session-123", + tool_name="Bash", + tool_input={"command": "rm -rf /tmp/test"}, + ) + assert request.tool_name == "Bash" + assert request.action_type == ActionType.PERMISSION_REQUEST + + +class TestRateLimitEvent: + """Tests for RateLimitEvent model.""" + + def test_create_rate_limit_event(self): + """Test creating a rate limit event.""" + event = RateLimitEvent( + session_id="session-123", + rate_limit_status="allowed_warning", + rate_limit_type="five_hour", + utilization=0.8, + ) + assert event.rate_limit_status == "allowed_warning" + assert event.utilization == 0.8 + assert event.action_type == ActionType.RATE_LIMIT diff --git a/pyproject.toml b/pyproject.toml index a2a449a6..db2dc832 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ members = [ "agents/sql2graph/", "unstructured2graph", "context-graph/skills-graph", + "context-graph/actions-graph", ] [tool.ruff] diff --git a/uv.lock b/uv.lock index 6ad208fe..5e666930 100644 --- a/uv.lock +++ b/uv.lock @@ -54,6 +54,7 @@ resolution-markers = [ [manifest] members = [ + "actions-graph", "langchain-memgraph", "lightrag-memgraph", "mcp-memgraph", @@ -176,6 +177,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl", hash = "sha256:cf1a3efb96c18f7b152eb0fa7490f3710b19c3f395699358f08decca2b8b62e0", size = 383744 }, ] +[[package]] +name = "actions-graph" +version = "0.1.0" +source = { editable = "context-graph/actions-graph" } +dependencies = [ + { name = "memgraph-toolbox" }, +] + +[package.optional-dependencies] +claude-agent = [ + { name = "claude-agent-sdk" }, +] +test = [ + { name = "pytest" }, + { name = "pytest-asyncio" }, +] + +[package.metadata] +requires-dist = [ + { name = "claude-agent-sdk", marker = "extra == 'claude-agent'", specifier = ">=0.1.0" }, + { name = "memgraph-toolbox", editable = "memgraph-toolbox" }, + { name = "pytest", marker = "extra == 'test'", specifier = ">=9.0.3" }, + { name = "pytest-asyncio", marker = "extra == 'test'", specifier = ">=0.24.0" }, +] +provides-extras = ["claude-agent", "test"] + [[package]] name = "aiofile" version = "3.9.0" @@ -1002,6 +1029,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/db/8f/61959034484a4a7c527811f4721e75d02d653a35afb0b6054474d8185d4c/charset_normalizer-3.4.7-py3-none-any.whl", hash = "sha256:3dce51d0f5e7951f8bb4900c257dad282f49190fdbebecd4ba99bcc41fef404d", size = 61958 }, ] +[[package]] +name = "claude-agent-sdk" +version = "0.1.71" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "mcp" }, + { name = "sniffio" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/04/4c/e877f084c3c5bedc695d4045451af303b869b1c3dc302ea854e092fd882f/claude_agent_sdk-0.1.71.tar.gz", hash = "sha256:89ac5e4dd0fecf3e62dcbea69dca096921136fe7549daf52c546eacce9b70131", size = 241282 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/b1/9ccbf6ee447b8ce02acca5c65617ba36828129309f69b8c782120bc25591/claude_agent_sdk-0.1.71-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f9996b9c03d9b75bfa5eea5d57acc3d4c736421f21d2c75d169f1e3a5b6f0f43", size = 63577286 }, + { url = "https://files.pythonhosted.org/packages/bc/bf/39b5f3d13f12f0ed193f4a8ed02678d59097ce939b9da30cd67a15b3cb25/claude_agent_sdk-0.1.71-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:d6d5ed0441f3780d3f692d9bf86c99045af0dc925321d5549ab107c33d0f9695", size = 65435355 }, + { url = "https://files.pythonhosted.org/packages/f2/74/fc8dff93e0e79295acd89f588b94c9ba5f61ae60f59ecaec0d58474afabe/claude_agent_sdk-0.1.71-py3-none-manylinux_2_17_aarch64.whl", hash = "sha256:fbb8ce166b9cc861ffb706f437c40b940cbeca37ed6c8846c90dd9af885d12c5", size = 76748516 }, + { url = "https://files.pythonhosted.org/packages/05/0c/ea00389db12f0d2feb429190b54bf102213bb5e848c95a71ffad7a0284f6/claude_agent_sdk-0.1.71-py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:97f6c2cc36a216d4bac5f6cd418042585627d6b73c342d8041cf63dcfb97f4de", size = 76939021 }, + { url = "https://files.pythonhosted.org/packages/8e/0a/8de473d525e5a0058e75c16bb1315a44ea7502b0825ab63f13bdf0324256/claude_agent_sdk-0.1.71-py3-none-win_amd64.whl", hash = "sha256:76288c59a5d25aab5df4e6cf71cd4550c0bb43a2542046da14e9670b07546095", size = 78318549 }, +] + [[package]] name = "click" version = "8.2.1"