diff --git a/agents/agents-core/build.gradle.kts b/agents/agents-core/build.gradle.kts index 126414a4d7..ff33ace2a6 100644 --- a/agents/agents-core/build.gradle.kts +++ b/agents/agents-core/build.gradle.kts @@ -17,6 +17,7 @@ kotlin { api(project(":utils")) api(project(":prompt:prompt-executor:prompt-executor-model")) api(project(":prompt:prompt-llm")) + api(project(":prompt:prompt-processor")) api(project(":prompt:prompt-structure")) api(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client")) diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/AIAgentSimpleStrategies.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/AIAgentSimpleStrategies.kt index 7fa193ace2..81c326a2fb 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/AIAgentSimpleStrategies.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/AIAgentSimpleStrategies.kt @@ -13,6 +13,7 @@ import ai.koog.agents.core.dsl.extension.onAssistantMessage import ai.koog.agents.core.dsl.extension.onMultipleAssistantMessages import ai.koog.agents.core.dsl.extension.onMultipleToolCalls import ai.koog.agents.core.dsl.extension.onToolCall +import ai.koog.prompt.processor.ResponseProcessor /** * Creates a single-run strategy for an AI agent. @@ -28,19 +29,26 @@ import ai.koog.agents.core.dsl.extension.onToolCall * - SingleRunMode.SINGLE: Executes without allowing multiple simultaneous tool calls. * - SingleRunMode.SEQUENTIAL: Executes simultaneous tool calls sequentially. * - SingleRunMode.PARALLEL: Executes multiple tool calls in parallel. + * @param responseProcessor The processor applied to all LLM responses. If null, no processing is applied. Defaults to null. * @return An instance of AIAgentStrategy configured according to the specified single-run mode.
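+ *
+ * A usage sketch (assuming a tool registry is available; FixJsonToolCall comes from the prompt-processor module):
+ * ```kotlin
+ * val strategy = singleRunStrategy(ToolCalls.SEQUENTIAL, responseProcessor = FixJsonToolCall(toolRegistry))
+ * ```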
 */ -public fun singleRunStrategy(runMode: ToolCalls = ToolCalls.SINGLE_RUN_SEQUENTIAL): AIAgentGraphStrategy<String, String> = +public fun singleRunStrategy( + runMode: ToolCalls = ToolCalls.SINGLE_RUN_SEQUENTIAL, + responseProcessor: ResponseProcessor? = null, +): AIAgentGraphStrategy<String, String> = when (runMode) { - ToolCalls.SEQUENTIAL -> singleRunWithParallelAbility(false) - ToolCalls.PARALLEL -> singleRunWithParallelAbility(true) - ToolCalls.SINGLE_RUN_SEQUENTIAL -> singleRunModeStrategy() + ToolCalls.SEQUENTIAL -> singleRunWithParallelAbility(false, responseProcessor) + ToolCalls.PARALLEL -> singleRunWithParallelAbility(true, responseProcessor) + ToolCalls.SINGLE_RUN_SEQUENTIAL -> singleRunModeStrategy(responseProcessor) } -private fun singleRunWithParallelAbility(parallelTools: Boolean) = strategy("single_run_sequential") { - val nodeCallLLM by nodeLLMRequestMultiple() +private fun singleRunWithParallelAbility( + parallelTools: Boolean, + responseProcessor: ResponseProcessor? +) = strategy("single_run_sequential") { + val nodeCallLLM by nodeLLMRequestMultiple(responseProcessor = responseProcessor) val nodeExecuteTool by nodeExecuteMultipleTools(parallelTools = parallelTools) - val nodeSendToolResult by nodeLLMSendMultipleToolResults() + val nodeSendToolResult by nodeLLMSendMultipleToolResults(responseProcessor = responseProcessor) edge(nodeStart forwardTo nodeCallLLM) edge(nodeCallLLM forwardTo nodeExecuteTool onMultipleToolCalls { true }) @@ -61,10 +69,10 @@ private fun singleRunWithParallelAbility(parallelTools: Boolean) = strategy("sin edge(nodeSendToolResult forwardTo nodeExecuteTool onMultipleToolCalls { true }) } -private fun singleRunModeStrategy() = strategy("single_run") { - val nodeCallLLM by nodeLLMRequest() +private fun singleRunModeStrategy(responseProcessor: ResponseProcessor?) = strategy("single_run") { + val nodeCallLLM by nodeLLMRequest(responseProcessor = responseProcessor) val nodeExecuteTool by nodeExecuteTool() - val nodeSendToolResult by nodeLLMSendToolResult() + val nodeSendToolResult by nodeLLMSendToolResult(responseProcessor = responseProcessor) edge(nodeStart forwardTo nodeCallLLM) edge(nodeCallLLM forwardTo nodeExecuteTool onToolCall { true }) diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSession.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSession.kt index e4dd3f400c..f29e04187b 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSession.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSession.kt @@ -11,6 +11,8 @@ import ai.koog.prompt.llm.LLModel import ai.koog.prompt.message.LLMChoice import ai.koog.prompt.message.Message import ai.koog.prompt.params.LLMParams +import ai.koog.prompt.processor.ResponseProcessor +import ai.koog.prompt.processor.executeProcessed import ai.koog.prompt.streaming.StreamFrame import ai.koog.prompt.structure.StructureFixingParser import ai.koog.prompt.structure.StructuredRequestConfig @@ -114,13 +116,24 @@ public sealed class AIAgentLLMSession( return executor.executeStreaming(preparedPrompt, model, tools) } - protected suspend fun executeMultiple(prompt: Prompt, tools: List<ToolDescriptor>): List<Message.Response> { + protected suspend fun executeMultiple( + prompt: Prompt, + tools: List<ToolDescriptor>, + responseProcessor: ResponseProcessor? = null + ): List<Message.Response> { val preparedPrompt = preparePrompt(prompt, tools) - return executor.execute(preparedPrompt, model, tools) + return if (responseProcessor == null) { + executor.execute(preparedPrompt, model, tools) + } else { + executor.executeProcessed(preparedPrompt, model, tools, responseProcessor) + } } - protected suspend fun executeSingle(prompt: Prompt, tools: List<ToolDescriptor>): Message.Response = - executeMultiple(prompt, tools).first() + protected suspend fun executeSingle( + prompt: Prompt, + tools: List<ToolDescriptor>, + responseProcessor: ResponseProcessor? = null + ): Message.Response = executeMultiple(prompt, tools, responseProcessor).first() /** * Sends a request to the language model without utilizing any tools and returns the response. @@ -152,14 +165,17 @@ public sealed class AIAgentLLMSession( * This method updates the session's prompt configuration to mark tool usage as required before * executing the request. Additionally, it ensures the session is active before proceeding. * + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null.
 * @return The response from the language model after executing the request with enforced tool usage. */ - public open suspend fun requestLLMOnlyCallingTools(): Message.Response { + public open suspend fun requestLLMOnlyCallingTools( + responseProcessor: ResponseProcessor? = null + ): Message.Response { validateSession() val promptWithOnlyCallingTools = prompt.withUpdatedParams { toolChoice = LLMParams.ToolChoice.Required } - return executeSingle(promptWithOnlyCallingTools, tools) + return executeSingle(promptWithOnlyCallingTools, tools, responseProcessor) } /** @@ -173,16 +189,20 @@ public sealed class AIAgentLLMSession( * @param tool The tool to be used for the request, represented by a [ToolDescriptor] instance. * This parameter ensures that the language model utilizes the specified tool * during the interaction. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return The response from the language model as a [Message.Response] instance after * processing the request with the enforced tool. */ - public open suspend fun requestLLMForceOneTool(tool: ToolDescriptor): Message.Response { + public open suspend fun requestLLMForceOneTool( + tool: ToolDescriptor, + responseProcessor: ResponseProcessor? = null + ): Message.Response { validateSession() check(tools.contains(tool)) { "Unable to force call to tool `${tool.name}` because it is not defined" } val promptWithForcingOneTool = prompt.withUpdatedParams { toolChoice = LLMParams.ToolChoice.Named(tool.name) } - return executeSingle(promptWithForcingOneTool, tools) + return executeSingle(promptWithForcingOneTool, tools, responseProcessor) } /** @@ -194,11 +214,15 @@ public sealed class AIAgentLLMSession( * * @param tool The tool to be used for the request, represented as an instance of [Tool]. This parameter ensures * the specified tool is utilized during the LLM interaction. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return The response from the language model as a [Message.Response] instance after processing the request with the * enforced tool. */ - public open suspend fun requestLLMForceOneTool(tool: Tool<*, *>): Message.Response { - return requestLLMForceOneTool(tool.descriptor) + public open suspend fun requestLLMForceOneTool( + tool: Tool<*, *>, + responseProcessor: ResponseProcessor? = null + ): Message.Response { + return requestLLMForceOneTool(tool.descriptor, responseProcessor) } /** @@ -207,9 +231,9 @@ public sealed class AIAgentLLMSession( * * @return The first response message from the LLM after executing the request. */ - public open suspend fun requestLLM(): Message.Response { + public open suspend fun requestLLM(responseProcessor: ResponseProcessor? = null): Message.Response { validateSession() - return executeSingle(prompt, tools) + return executeSingle(prompt, tools, responseProcessor) } /** @@ -248,11 +272,12 @@ public sealed class AIAgentLLMSession( * Before executing the request, the session state is validated to ensure * it is active and usable. * + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return a list of responses from the language model */ - public open suspend fun requestLLMMultiple(): List<Message.Response> { + public open suspend fun requestLLMMultiple(responseProcessor: ResponseProcessor?
= null): List<Message.Response> { validateSession() - return executeMultiple(prompt, tools) + return executeMultiple(prompt, tools, responseProcessor) } /** diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMWriteSession.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMWriteSession.kt index ff1daffc8d..8bbb7b5a88 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMWriteSession.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMWriteSession.kt @@ -14,6 +14,7 @@ import ai.koog.prompt.executor.model.PromptExecutor import ai.koog.prompt.llm.LLModel import ai.koog.prompt.message.Message import ai.koog.prompt.params.LLMParams +import ai.koog.prompt.processor.ResponseProcessor import ai.koog.prompt.streaming.StreamFrame import ai.koog.prompt.structure.StructureDefinition import ai.koog.prompt.structure.StructureFixingParser @@ -381,20 +382,27 @@ public class AIAgentLLMWriteSession internal constructor( * Requests a response from the Large Language Model (LLM) while also processing * the response by updating the current prompt with the received message. * + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return The response received from the Large Language Model (LLM). */ - override suspend fun requestLLMOnlyCallingTools(): Message.Response { - return super.requestLLMOnlyCallingTools().also { response -> appendPrompt { message(response) } } + override suspend fun requestLLMOnlyCallingTools(responseProcessor: ResponseProcessor?): Message.Response { + return super.requestLLMOnlyCallingTools(responseProcessor) + .also { response -> appendPrompt { message(response) } } } /** * Requests an LLM (Large Language Model) to forcefully utilize a specific tool during its operation. * * @param tool A descriptor object representing the tool to be enforced for use by the LLM. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return A response message received from the LLM after executing the enforced tool request. */ - override suspend fun requestLLMForceOneTool(tool: ToolDescriptor): Message.Response { - return super.requestLLMForceOneTool(tool).also { response -> appendPrompt { message(response) } } + override suspend fun requestLLMForceOneTool( + tool: ToolDescriptor, + responseProcessor: ResponseProcessor? + ): Message.Response { + return super.requestLLMForceOneTool(tool, responseProcessor) + .also { response -> appendPrompt { message(response) } } } /** @@ -402,20 +410,26 @@ public class AIAgentLLMWriteSession internal constructor( * and updates the prompt based on the generated response. * * @param tool The tool that will be enforced and executed. It contains the input and output types. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return The response generated after executing the provided tool. */ - override suspend fun requestLLMForceOneTool(tool: Tool<*, *>): Message.Response { - return super.requestLLMForceOneTool(tool).also { response -> appendPrompt { message(response) } } + override suspend fun requestLLMForceOneTool( + tool: Tool<*, *>, + responseProcessor: ResponseProcessor?
+ ): Message.Response { + return super.requestLLMForceOneTool(tool, responseProcessor) + .also { response -> appendPrompt { message(response) } } } /** - * Makes an asynchronous request to a Large Language Model (LLM) and updates the current prompt - * with the response received from the LLM. + * Makes an asynchronous request to a Large Language Model (LLM), processes the response using the provided + * responseProcessor, and updates the current prompt with the processed response. * + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return A [Message.Response] object containing the response from the LLM. */ - override suspend fun requestLLM(): Message.Response { - return super.requestLLM().also { response -> + override suspend fun requestLLM(responseProcessor: ResponseProcessor?): Message.Response { + return super.requestLLM(responseProcessor).also { response -> appendPrompt { message(response) } } } /** @@ -427,10 +441,11 @@ public class AIAgentLLMWriteSession internal constructor( * response is subsequently used to update the session's prompt. The prompt updating mechanism * allows stateful interactions with the LLM, maintaining context across multiple requests. * + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return A list of `Message.Response` containing the results from the LLM. */ - override suspend fun requestLLMMultiple(): List<Message.Response> { - return super.requestLLMMultiple().also { responses -> + override suspend fun requestLLMMultiple(responseProcessor: ResponseProcessor?): List<Message.Response> { + return super.requestLLMMultiple(responseProcessor).also { responses -> appendPrompt { responses.forEach { message(it) } } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentFunctionalContextExt.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentFunctionalContextExt.kt index d19caab204..1a87a4b172 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentFunctionalContextExt.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentFunctionalContextExt.kt @@ -10,6 +10,7 @@ import ai.koog.agents.core.tools.ToolArgs import ai.koog.agents.core.tools.ToolDescriptor import ai.koog.agents.core.tools.ToolResult import ai.koog.prompt.message.Message +import ai.koog.prompt.processor.ResponseProcessor import ai.koog.prompt.streaming.StreamFrame import ai.koog.prompt.structure.StructureDefinition import ai.koog.prompt.structure.StructureFixingParser @@ -24,10 +25,12 @@ import kotlinx.serialization.serializer * * @param message The content of the message to be sent to the LLM. * @param allowToolCalls Specifies whether tool calls are allowed during the LLM interaction. Defaults to `true`. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. */ public suspend fun AIAgentFunctionalContext.requestLLM( message: String, - allowToolCalls: Boolean = true + allowToolCalls: Boolean = true, + responseProcessor: ResponseProcessor?
= null ): Message.Response { return llm.writeSession { appendPrompt { user(message) } if (allowToolCalls) { - requestLLM() + requestLLM(responseProcessor) } else { requestLLMWithoutTools() } @@ -190,15 +193,19 @@ public suspend fun AIAgentFunctionalContext.requestLLMStreaming( * The message becomes part of the current prompt, and multiple responses from the LLM are collected. * * @param message The content of the message to be sent to the LLM. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return A list of LLM responses. */ -public suspend fun AIAgentFunctionalContext.requestLLMMultiple(message: String): List<Message.Response> { +public suspend fun AIAgentFunctionalContext.requestLLMMultiple( + message: String, + responseProcessor: ResponseProcessor? = null +): List<Message.Response> { return llm.writeSession { appendPrompt { user(message) } - requestLLMMultiple() + requestLLMMultiple(responseProcessor) } } @@ -207,15 +214,19 @@ public suspend fun AIAgentFunctionalContext.requestLLMMultiple(message: String): * The message becomes part of the current prompt, and the LLM is instructed to only use tools. * * @param message The content of the message to be sent to the LLM. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return The LLM response containing tool calls. */ -public suspend fun AIAgentFunctionalContext.requestLLMOnlyCallingTools(message: String): Message.Response { +public suspend fun AIAgentFunctionalContext.requestLLMOnlyCallingTools( + message: String, + responseProcessor: ResponseProcessor? = null +): Message.Response { return llm.writeSession { appendPrompt { user(message) } - requestLLMOnlyCallingTools() + requestLLMOnlyCallingTools(responseProcessor) } } @@ -225,18 +236,20 @@ public suspend fun AIAgentFunctionalContext.requestLLMOnlyCallingTools(message: * * @param message The content of the message to be sent to the LLM. * @param tool The tool descriptor that the LLM must use. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return The LLM response containing the tool call. */ public suspend fun AIAgentFunctionalContext.requestLLMForceOneTool( message: String, - tool: ToolDescriptor + tool: ToolDescriptor, + responseProcessor: ResponseProcessor? = null ): Message.Response { return llm.writeSession { appendPrompt { user(message) } - requestLLMForceOneTool(tool) + requestLLMForceOneTool(tool, responseProcessor) } } @@ -246,18 +259,20 @@ public suspend fun AIAgentFunctionalContext.requestLLMForceOneTool( * * @param message The content of the message to be sent to the LLM. * @param tool The tool that the LLM must use. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return The LLM response containing the tool call. */ public suspend fun AIAgentFunctionalContext.requestLLMForceOneTool( message: String, - tool: Tool<*, *> + tool: Tool<*, *>, + responseProcessor: ResponseProcessor? = null ): Message.Response { return llm.writeSession { appendPrompt { user(message) } - requestLLMForceOneTool(tool) + requestLLMForceOneTool(tool, responseProcessor) } } @@ -294,9 +309,13 @@ public suspend fun AIAgentFunctionalContext.executeMultipleTools( * Adds a tool result to the prompt and requests an LLM response.
 * * @param toolResult The tool result to add to the prompt. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return The LLM response. */ -public suspend fun AIAgentFunctionalContext.sendToolResult(toolResult: ReceivedToolResult): Message.Response { +public suspend fun AIAgentFunctionalContext.sendToolResult( + toolResult: ReceivedToolResult, + responseProcessor: ResponseProcessor? = null +): Message.Response { return llm.writeSession { appendPrompt { tool { @@ -304,7 +323,7 @@ public suspend fun AIAgentFunctionalContext.sendToolResult(toolResult: ReceivedT } } - requestLLM() + requestLLM(responseProcessor) } } @@ -312,10 +331,12 @@ public suspend fun AIAgentFunctionalContext.sendToolResult(toolResult: ReceivedT * Adds multiple tool results to the prompt and gets multiple LLM responses. * * @param results The list of tool results to add to the prompt. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @return A list of LLM responses. */ public suspend fun AIAgentFunctionalContext.sendMultipleToolResults( - results: List<ReceivedToolResult> + results: List<ReceivedToolResult>, + responseProcessor: ResponseProcessor? = null ): List<Message.Response> { return llm.writeSession { appendPrompt { @@ -324,7 +345,7 @@ public suspend fun AIAgentFunctionalContext.sendMultipleToolResults( } } - requestLLMMultiple() + requestLLMMultiple(responseProcessor) } } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt index 775b345550..0c88d29232 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt @@ -15,6 +15,7 @@ import ai.koog.prompt.dsl.PromptBuilder import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.LLModel import ai.koog.prompt.message.Message +import ai.koog.prompt.processor.ResponseProcessor import ai.koog.prompt.streaming.StreamFrame import ai.koog.prompt.streaming.toMessageResponses import ai.koog.prompt.structure.StructureDefinition @@ -79,11 +80,13 @@ public inline fun AIAgentSubgraphBuilderBase<*, *>.nodeUpdatePrompt( /** * A node that appends a user message to the LLM prompt and gets a response where the LLM can only call tools. * + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. * @param name Optional name for the node. */ @AIAgentBuilderDslMarker public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageOnlyCallingTools( - name: String? = null + name: String? = null, + responseProcessor: ResponseProcessor? = null ): AIAgentNodeDelegate<String, Message.Response> = node(name) { message -> llm.writeSession { @@ -91,7 +94,7 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageOnlyCallingTools( user(message) } - requestLLMOnlyCallingTools() + requestLLMOnlyCallingTools(responseProcessor) } } @@ -100,11 +103,13 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageOnlyCallingTools( * * @param name Optional node name. * @param tool Tool descriptor the LLM is required to use. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. */ @AIAgentBuilderDslMarker public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageForceOneTool( name: String?
= null, - tool: ToolDescriptor + tool: ToolDescriptor, + responseProcessor: ResponseProcessor? = null ): AIAgentNodeDelegate<String, Message.Response> = node(name) { message -> llm.writeSession { @@ -112,7 +117,7 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageForceOneTool( user(message) } - requestLLMForceOneTool(tool) + requestLLMForceOneTool(tool, responseProcessor) } } @@ -121,24 +126,28 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageForceOneTool( * * @param name Optional node name. * @param tool Tool the LLM is required to use. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. */ @AIAgentBuilderDslMarker public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMessageForceOneTool( name: String? = null, - tool: Tool<*, *> + tool: Tool<*, *>, + responseProcessor: ResponseProcessor? = null ): AIAgentNodeDelegate<String, Message.Response> = - nodeLLMSendMessageForceOneTool(name, tool.descriptor) + nodeLLMSendMessageForceOneTool(name, tool.descriptor, responseProcessor) /** * A node that appends a user message to the LLM prompt and gets a response with optional tool usage. * * @param name Optional node name. * @param allowToolCalls Controls whether LLM can use tools (default: true). + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. */ @AIAgentBuilderDslMarker public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequest( name: String? = null, - allowToolCalls: Boolean = true + allowToolCalls: Boolean = true, + responseProcessor: ResponseProcessor? = null ): AIAgentNodeDelegate<String, Message.Response> = node(name) { message -> llm.writeSession { @@ -147,7 +156,7 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequest( } if (allowToolCalls) { - requestLLM() + requestLLM(responseProcessor) } else { requestLLMWithoutTools() } @@ -287,10 +296,12 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequestStreaming( * A node that appends a user message to the LLM prompt and gets multiple LLM responses with tool calls enabled. * * @param name Optional node name. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. */ @AIAgentBuilderDslMarker public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequestMultiple( - name: String? = null + name: String? = null, + responseProcessor: ResponseProcessor? = null ): AIAgentNodeDelegate<String, List<Message.Response>> = node(name) { message -> llm.writeSession { @@ -298,7 +309,7 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMRequestMultiple( user(message) } - requestLLMMultiple() + requestLLMMultiple(responseProcessor) } } @@ -390,10 +401,12 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeExecuteTool( * A node that adds a tool result to the prompt and requests an LLM response. * * @param name Optional node name. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. */ @AIAgentBuilderDslMarker public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendToolResult( - name: String? = null + name: String? = null, + responseProcessor: ResponseProcessor?
= null, ): AIAgentNodeDelegate<ReceivedToolResult, Message.Response> = node(name) { result -> llm.writeSession { @@ -403,7 +416,7 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendToolResult( } } - requestLLM() + requestLLM(responseProcessor) } } @@ -435,12 +448,15 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeExecuteMultipleTools( * @param parallelTools A flag to determine if the tool calls should be executed concurrently. * If true, all tool calls are executed in parallel; otherwise, they are * executed sequentially. Default value is false. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. + * * @return An instance of [AIAgentNodeDelegate] that takes a list of tool calls as input * and returns the corresponding list of tool responses. */ public fun AIAgentSubgraphBuilderBase<*, *>.nodeExecuteMultipleToolsAndSendResults( name: String? = null, parallelTools: Boolean = false, + responseProcessor: ResponseProcessor? = null, ): AIAgentNodeDelegate<List<Message.Tool.Call>, List<Message.Response>> = node(name) { toolCalls -> val results = if (parallelTools) { @@ -456,7 +472,7 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeExecuteMultipleToolsAndSendResul } } - requestLLMMultiple() + requestLLMMultiple(responseProcessor) } } @@ -464,10 +480,12 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeExecuteMultipleToolsAndSendResul * A node that adds multiple tool results to the prompt and gets multiple LLM responses. * * @param name Optional node name. + * @param responseProcessor The processor applied to the LLM response. If null, no processing is applied. Defaults to null. */ @AIAgentBuilderDslMarker public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMultipleToolResults( - name: String? = null + name: String? = null, + responseProcessor: ResponseProcessor?
= null, ): AIAgentNodeDelegate<List<ReceivedToolResult>, List<Message.Response>> = node(name) { results -> llm.writeSession { @@ -477,7 +495,7 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeLLMSendMultipleToolResults( } } - requestLLMMultiple() + requestLLMMultiple(responseProcessor) } } diff --git a/build.gradle.kts b/build.gradle.kts index 6c3053d21a..8530afd368 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -244,6 +244,7 @@ dependencies { dokka(project(":prompt:prompt-llm")) dokka(project(":prompt:prompt-markdown")) dokka(project(":prompt:prompt-model")) + dokka(project(":prompt:prompt-processor")) dokka(project(":prompt:prompt-structure")) dokka(project(":prompt:prompt-tokenizer")) dokka(project(":prompt:prompt-xml")) diff --git a/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixture.kt b/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixture.kt index 02f164b3e8..936ab81282 100644 --- a/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixture.kt +++ b/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixture.kt @@ -41,6 +41,7 @@ class OllamaTestFixture { val model = OllamaModels.Meta.LLAMA_3_2 val visionModel = OllamaModels.Granite.GRANITE_3_2_VISION val moderationModel = OllamaModels.Meta.LLAMA_GUARD_3 + val modelsWithHallucinations = listOf(OllamaModels.Meta.LLAMA_3_2, OllamaModels.Groq.LLAMA_3_GROK_TOOL_USE_8B) private lateinit var ollamaContainer: GenericContainer<*> @@ -67,6 +68,7 @@ class OllamaTestFixture { client.getModelOrNull(model.id, pullIfMissing = true) client.getModelOrNull(visionModel.id, pullIfMissing = true) client.getModelOrNull(moderationModel.id, pullIfMissing = true) + modelsWithHallucinations.forEach { client.getModelOrNull(it.id, pullIfMissing = true) } } catch (e: Exception) { logger.error(e) { "Failed to pull models: ${e.message}" } cleanContainer() diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaAgentIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaAgentIntegrationTest.kt index 7672c3ac9b..32c36613f8 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaAgentIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaAgentIntegrationTest.kt @@ -4,6 +4,7 @@ import ai.koog.agents.core.agent.AIAgent import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.agent.context.agentInput import ai.koog.agents.core.agent.entity.AIAgentGraphStrategy +import ai.koog.agents.core.agent.singleRunStrategy import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy import ai.koog.agents.core.dsl.extension.nodeExecuteTool @@ -16,20 +17,33 @@ import ai.koog.agents.features.eventHandler.feature.EventHandler import ai.koog.integration.tests.InjectOllamaTestFixture import ai.koog.integration.tests.OllamaTestFixture import ai.koog.integration.tests.OllamaTestFixtureExtension +import ai.koog.integration.tests.utils.RetryUtils.withRetry import ai.koog.integration.tests.utils.annotations.Retry import ai.koog.integration.tests.utils.annotations.RetryExtension import ai.koog.integration.tests.utils.tools.AnswerVerificationTool +import ai.koog.integration.tests.utils.tools.FileOperationsTools import ai.koog.integration.tests.utils.tools.GenericParameterTool import ai.koog.integration.tests.utils.tools.GeographyQueryTool +import ai.koog.prompt.dsl.Prompt import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.model.PromptExecutor
+import ai.koog.prompt.llm.LLModel import ai.koog.prompt.llm.OllamaModels +import ai.koog.prompt.markdown.markdown import ai.koog.prompt.params.LLMParams +import ai.koog.prompt.processor.FixToolCallLLMBased +import ai.koog.prompt.processor.ResponseProcessorApi import io.kotest.matchers.string.shouldContain import io.kotest.matchers.string.shouldNotBeBlank import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.Test import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource +import java.util.stream.Stream +import kotlin.test.BeforeTest +import kotlin.test.assertContains +import kotlin.test.assertEquals import kotlin.time.Duration.Companion.seconds @ExtendWith(OllamaTestFixtureExtension::class) @@ -40,8 +54,20 @@ class OllamaAgentIntegrationTest : AIAgentTestBase() { private lateinit var fixture: OllamaTestFixture private val executor get() = fixture.executor private val model get() = fixture.model + private val modelsWithHallucinations get() = fixture.modelsWithHallucinations + + @JvmStatic + private fun modelsWithHallucinations(): Stream<LLModel> = + Stream.of(*modelsWithHallucinations.toTypedArray()) + } + + @BeforeTest + fun clearToolCalls() { + toolCalls.clear() } + + private val toolCalls = mutableListOf<String>() + private fun createTestStrategy() = strategy("test-ollama") { val askCapitalSubgraph by subgraph("ask-capital") { val definePrompt by node { @@ -57,7 +83,7 @@ class OllamaAgentIntegrationTest : AIAgentTestBase() { ALWAYS generate valid JSON responses. ALWAYS call tool correctly, with valid arguments. NEVER provide tool call in result body. - + Example tool call: { "id":"ollama_tool_call_3743609160", @@ -137,7 +163,9 @@ class OllamaAgentIntegrationTest : AIAgentTestBase() { private fun createAgent( executor: PromptExecutor, strategy: AIAgentGraphStrategy<String, String>, - toolRegistry: ToolRegistry + toolRegistry: ToolRegistry, + llmModel: LLModel = model, + prompt: Prompt = prompt("test-ollama", LLMParams(temperature = 0.0)) {} ): AIAgent<String, String> { val promptsAndResponses = mutableListOf<String>() return AIAgent( promptExecutor = executor, strategy = strategy, agentConfig = AIAgentConfig( - prompt("test-ollama", LLMParams(temperature = 0.0)) {}, - model, - 20 + prompt, + llmModel, + 20, ), toolRegistry = toolRegistry ) { install(EventHandler) { + onToolCallStarting { eventContext -> + toolCalls.add(eventContext.tool.name) + } + onLLMCallStarting { eventContext -> val promptText = eventContext.prompt.messages.joinToString { "${it.role.name}: ${it.content}" } promptsAndResponses.add("PROMPT_WITH_TOOLS: $promptText") @@ -173,4 +205,62 @@ class OllamaAgentIntegrationTest : AIAgentTestBase() { .shouldNotBeBlank() .shouldContain("Paris") } + + @OptIn(ResponseProcessorApi::class) + @ParameterizedTest + @MethodSource("modelsWithHallucinations") + fun ollama_testFixToolCallLLMBased(llmModel: LLModel) = runTest(timeout = 600.seconds) { + withRetry(5) { + val fileTools = FileOperationsTools() + fileTools.createNewFileWithText( + pathInProject = "scores.txt", + text = """ + name,age,score + Alice,25,85 + Bob,30,92 + Charlie,22,78 + """.trimIndent() + ) + val toolRegistry = ToolRegistry { + tool(fileTools.readFileContentTool) + tool(fileTools.createNewFileWithTextTool) + } + + val responseProcessor = FixToolCallLLMBased(toolRegistry) + val strategy = singleRunStrategy(responseProcessor = responseProcessor) + + val prompt =
prompt("test-file-operations", LLMParams(temperature = 0.5)) { + system { + markdown { + +"You are a helpful assistant that can work with files using tools." + +"Perform all actions using tools." + +"When you completed the task, answer with a single word: \"Done!\"." + +"Do not include any summary in the final message." + } + } + } + + val agent = createAgent(executor, strategy, toolRegistry, llmModel, prompt) + + val request = """ + I have created a file named "scores.txt" in the project directory. + The file contains the data about the students. + + Your task: + Read the data to understand the format of the file. + Create a "compute_scores.py" file to compute the average score. + Do not summarize results in the end. + + Note: + Make sure that all paths are relative to the project directory, e.g. "scores.csv", "compute_scores.py". + """.trimIndent() + + agent.run(request) + + assertContains(toolCalls, "ReadFileContent", "readFileContent tool should be called") + assertContains(toolCalls, "CreateNewFileWithText", "createNewFileWithText tool should be called") + + assertEquals(2, fileTools.fileContentsByPath.size, "A script with average score should be created") + } + } } diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/annotations/RetryExtension.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/annotations/RetryExtension.kt index e124e9737e..648f50a529 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/annotations/RetryExtension.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/annotations/RetryExtension.kt @@ -6,6 +6,8 @@ import org.junit.jupiter.api.extension.InvocationInterceptor import org.junit.jupiter.api.extension.ReflectiveInvocationContext import org.opentest4j.TestAbortedException import java.lang.reflect.Method +import kotlin.test.AfterTest +import kotlin.test.BeforeTest class RetryExtension : InvocationInterceptor { companion object { @@ -110,6 +112,32 @@ class RetryExtension : InvocationInterceptor { val testMethod = invocationContext.executable val arguments = invocationContext.arguments - testMethod.invoke(testInstance, *arguments.toTypedArray()) + executeBeforeTestMethods(testInstance) + + try { + testMethod.invoke(testInstance, *arguments.toTypedArray()) + } finally { + executeAfterTestMethods(testInstance) + } + } + + private fun executeAnnotatedMethods(testInstance: Any, annotationClass: Class) { + // Find all methods in the test class that are annotated with the specified annotation + val annotatedMethods = testInstance.javaClass.methods + .filter { method -> method.isAnnotationPresent(annotationClass) } + + // Execute each annotated method + annotatedMethods.forEach { method -> + method.isAccessible = true + method.invoke(testInstance) + } + } + + private fun executeBeforeTestMethods(testInstance: Any) { + executeAnnotatedMethods(testInstance, BeforeTest::class.java) + } + + private fun executeAfterTestMethods(testInstance: Any) { + executeAnnotatedMethods(testInstance, AfterTest::class.java) } } diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/tools/FileOperationsTools.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/tools/FileOperationsTools.kt new file mode 100644 index 0000000000..a633d86d95 --- /dev/null +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/tools/FileOperationsTools.kt @@ -0,0 +1,51 @@ +package ai.koog.integration.tests.utils.tools + +import 
+import kotlinx.serialization.Serializable + +class FileOperationsTools { + val fileContentsByPath = mutableMapOf<String, String>() + + val createNewFileWithTextTool = CreateNewFileWithText(this) + val readFileContentTool = ReadFileContent(this) + + class CreateNewFileWithText(private val fileOperationsTools: FileOperationsTools) : + SimpleTool<CreateNewFileWithText.Args>() { + @Serializable + data class Args( + val pathInProject: String, + val text: String + ) + + override val argsSerializer = Args.serializer() + override val description = "Creates a new file at the specified path with the provided text content" + + override suspend fun doExecute(args: Args): String { + return fileOperationsTools.createNewFileWithText(args.pathInProject, args.text) + } + } + + class ReadFileContent(private val fileOperationsTools: FileOperationsTools) : SimpleTool<ReadFileContent.Args>() { + @Serializable + data class Args( + val pathInProject: String + ) + + override val argsSerializer = Args.serializer() + + override val description = "Reads the content of a file at the specified path" + + override suspend fun doExecute(args: Args): String { + return fileOperationsTools.readFileContent(args.pathInProject) + } + } + + fun createNewFileWithText(pathInProject: String, text: String): String { + fileContentsByPath[pathInProject] = text + return "OK" + } + + fun readFileContent(pathInProject: String): String { + return fileContentsByPath[pathInProject] ?: "Error: file not found" + } +} diff --git a/koog-agents/build.gradle.kts b/koog-agents/build.gradle.kts index cf1d42adb1..fa334927ab 100644 --- a/koog-agents/build.gradle.kts +++ b/koog-agents/build.gradle.kts @@ -72,6 +72,7 @@ val included = setOf( ":prompt:prompt-llm", ":prompt:prompt-markdown", ":prompt:prompt-model", + ":prompt:prompt-processor", ":prompt:prompt-structure", ":prompt:prompt-tokenizer", ":prompt:prompt-xml", diff --git a/prompt/prompt-processor/Module.md b/prompt/prompt-processor/Module.md new file mode 100644 index 0000000000..88c40a03cc --- /dev/null +++ b/prompt/prompt-processor/Module.md @@ -0,0 +1,67 @@ +# Module prompt-processor + +A module for processing and fixing LLM responses. + +### Overview + +The prompt-processor module provides utilities for post-processing responses from language models. Its primary focus is +fixing incorrectly formatted tool calls that may occur when LLMs generate responses. The module includes both JSON-based +fixes for common formatting issues and LLM-based approaches for more complex corrections. + +Key components: + +- **ResponseProcessor**: An abstract base class for implementing response processors, with support for chaining multiple + processors together. +- **FixJsonToolCall**: A processor that fixes invalid tool call JSONs, handling incorrect keys and missing escapes. +- **FixToolCallLLMBased**: An advanced processor that uses the LLM itself to iteratively fix incorrectly generated tool + calls.
+ +### Examples of usage + +Basic usage with FixJsonToolCall + +```kotlin +val processor = FixJsonToolCall(toolRegistry) + +// Execute a prompt with response processing +val responses = executor.executeProcessed(prompt, model, tools, processor) +``` + +Customizing JSON key mappings for different LLM providers + +```kotlin +val customConfig = ToolCallJsonFixConfig( + idJsonKeys = ToolCallJsonFixConfig.defaultIdJsonKeys + listOf("custom_id"), + nameJsonKeys = ToolCallJsonFixConfig.defaultNameJsonKeys + listOf("function_name"), + argsJsonKeys = ToolCallJsonFixConfig.defaultArgsJsonKeys + listOf("function_args") +) + +val processor = FixJsonToolCall(toolRegistry, customConfig) +``` + +Chaining multiple processors + +```kotlin +val processor1 = FixJsonToolCall(toolRegistry) +val processor2 = FixToolCallLLMBased(toolRegistry) + +val chainedProcessor = processor1 + processor2 +val responses = executor.executeProcessed(prompt, model, tools, chainedProcessor) +``` + +Using a processor in an agentic strategy + +```kotlin +// uses the processor for all LLM requests in this strategy +val strategy = singleRunStrategy(responseProcessor = processor) +``` + +```kotlin +val strategy = strategy("strategy-name") { + // ... + // uses the processor for LLM calls in this node + // you can provide a processor to any node that calls the LLM + val callLLM by nodeLLMRequest(responseProcessor = processor) + // ... +} +```
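+
+Implementing a custom processor (a minimal sketch that only logs responses and passes them through; it mirrors
+the fallback-processor example in the FixToolCallLLMBased docs, and the names are illustrative)
+
+```kotlin
+@OptIn(ResponseProcessorApi::class)
+val loggingProcessor = object : ResponseProcessor() {
+    override suspend fun process(
+        executor: PromptExecutor,
+        prompt: Prompt,
+        model: LLModel,
+        tools: List<ToolDescriptor>,
+        responses: List<Message.Response>
+    ): List<Message.Response> {
+        // Inspect every response before it reaches the agent
+        responses.forEach { println("LLM responded: ${it.content}") }
+        return responses // pass responses through unchanged
+    }
+}
+```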
 diff --git a/prompt/prompt-processor/build.gradle.kts b/prompt/prompt-processor/build.gradle.kts new file mode 100644 index 0000000000..e3dc2e0300 --- /dev/null +++ b/prompt/prompt-processor/build.gradle.kts @@ -0,0 +1,47 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":prompt:prompt-executor:prompt-executor-model")) + + api(project(":prompt:prompt-markdown")) + api(libs.kotlinx.serialization.json) + implementation(libs.oshai.kotlin.logging) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) + implementation(project(":agents:agents-test")) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/FixJsonToolCall.kt b/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/FixJsonToolCall.kt new file mode 100644 index 0000000000..3145ae1446 --- /dev/null +++ b/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/FixJsonToolCall.kt @@ -0,0 +1,45 @@ +package ai.koog.prompt.processor + +import ai.koog.agents.core.tools.ToolDescriptor +import ai.koog.agents.core.tools.ToolRegistry +import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.executor.model.PromptExecutor +import ai.koog.prompt.llm.LLModel +import ai.koog.prompt.message.Message +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlin.jvm.JvmStatic + +/** + * A response processor that fixes invalid tool call jsons. + * Fixes incorrectly formatted jsons, e.g. + * - incorrect tool id / name / arguments keys + * - missing escapes in strings + * + * @param toolRegistry The tool registry with available tools + * @param toolCallJsonFixConfig Configuration for parsing and fixing tool call json + */ +@ResponseProcessorApi +public class FixJsonToolCall( + private val toolRegistry: ToolRegistry, + private val toolCallJsonFixConfig: ToolCallJsonFixConfig = ToolCallJsonFixConfig() +) : ResponseProcessor() { + + private companion object { + @JvmStatic + private val logger = KotlinLogging.logger {} + } + + override suspend fun process( + executor: PromptExecutor, + prompt: Prompt, + model: LLModel, + tools: List<ToolDescriptor>, + responses: List<Message.Response> + ): List<Message.Response> = responses.map { response -> + logger.info { "Updating message: $response" } + response + as? Message.Tool.Call + ?: toolCallJsonFixConfig.extractToolCall(response.content, response.metaInfo, toolRegistry) + ?: response + }.also { logger.info { "Updated messages: $it" } } +} diff --git a/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/FixToolCallLLMBased.kt b/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/FixToolCallLLMBased.kt new file mode 100644 index 0000000000..b7df993e1d --- /dev/null +++ b/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/FixToolCallLLMBased.kt @@ -0,0 +1,187 @@ +package ai.koog.prompt.processor + +import ai.koog.agents.core.tools.ToolDescriptor +import ai.koog.agents.core.tools.ToolRegistry +import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.model.PromptExecutor +import ai.koog.prompt.llm.LLModel +import ai.koog.prompt.message.Message +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlin.jvm.JvmStatic + +/** + * A response processor that fixes incorrectly communicated tool calls. + * + * Applies an LLM-based approach to fix incorrectly generated tool calls. + * Iteratively asks the LLM to update a message until it is a correct tool call. + * + * The first step is to identify if the corrections are needed. + * It is done by + * (a) Asking the LLM if the message intends to call a tool if the message is [Message.Assistant] + * (b) Trying to parse the name and parameters if the message is [Message.Tool.Call] + * + * The main step is to fix the message (if needed). + * The processor runs a loop asking the LLM to fix the message. + * On every iteration, the processor provides the LLM with the current message and the feedback on it. + * If the LLM fails to return a correct tool call message in [maxRetries] iterations, the fallback processor is used. + * If no fallback processor is provided, the original message is returned. + * + * Some use-cases: + * + * 1. Simple usage: + * ```kotlin + * val processor = FixToolCallLLMBased(toolRegistry) // Tool registry is required + * ``` + * + * 2. Customizing the json keys: + * + * ```kotlin + * val processor = FixToolCallLLMBased( + * toolRegistry, + * ToolCallJsonFixConfig( + * idJsonKeys = ToolCallJsonFixConfig.defaultIdJsonKeys + listOf("custom_id_keys", ...), + * nameJsonKeys = ToolCallJsonFixConfig.defaultNameJsonKeys + listOf("custom_name_keys", ...), + * argsJsonKeys = ToolCallJsonFixConfig.defaultArgsJsonKeys + listOf("custom_args_keys", ...), + * ), // Add custom json keys produced by your LLM + * ) + * ``` + * + * 3. Using a fallback processor. Here the fallback processor calls another (e.g.
better but more expensive) LLM to fix the message: + * ``` + * val betterModel = OpenAIModels.Chat.GPT4o + * val fallbackProcessor = object : ResponseProcessor() { + * override suspend fun process( + * executor: PromptExecutor, + * prompt: Prompt, + * model: LLModel, + * tools: List<ToolDescriptor>, + * responses: List<Message.Response> + * ): List<Message.Response> { + * val promptFixing = prompt(prompt) { + * user("please fix the following incorrectly generated tool call messages: $responses") + * } + * return executor.execute(promptFixing, betterModel, tools) // use a better LLM + * } + * } + * + * val processor = FixToolCallLLMBased( + * toolRegistry, + * fallbackProcessor = fallbackProcessor + * ) + * ```
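 *
 * 4. Chaining with another processor. The `+` operator builds a [ResponseProcessor.Chain] that applies
 * the processors in order (a sketch):
 * ```
 * val processor = FixJsonToolCall(toolRegistry) + FixToolCallLLMBased(toolRegistry)
 * ```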
 * * @param toolRegistry The tool registry with available tools + * @param toolCallJsonFixConfig Configuration for parsing and fixing tool call json + * @param assessToolCallIntentSystemMessage The system message to ask LLM if a tool call was intended + * @param fixToolCallSystemMessage The system message to ask LLM to fix a tool call + * @param invalidJsonFeedback The message sent to the LLM when tool call json is invalid + * @param invalidNameFeedback The message sent to the LLM when the tool name is invalid + * @param invalidArgumentsFeedback The message sent to the LLM when tool arguments are invalid + * @param fallbackProcessor The fallback processor to use if LLM fails to fix a tool call. + * Defaults to null, meaning that the original message is returned if the LLM fails to fix a tool call. + * @param preprocessor A processor applied to all responses from the LLM. Defaults to [FixJsonToolCall] + * @param maxRetries The maximum number of iterations in the main loop + */ +@ResponseProcessorApi +public class FixToolCallLLMBased( + private val toolRegistry: ToolRegistry, + private val toolCallJsonFixConfig: ToolCallJsonFixConfig = ToolCallJsonFixConfig(), + private val preprocessor: ResponseProcessor = FixJsonToolCall(toolRegistry, toolCallJsonFixConfig), + private val fallbackProcessor: ResponseProcessor? = null, + private val assessToolCallIntentSystemMessage: String = Messages.assessToolCallIntent, + private val fixToolCallSystemMessage: String = Messages.fixToolCall, + private val invalidJsonFeedback: (List<ToolDescriptor>) -> String = Messages::invalidJsonFeedback, + private val invalidNameFeedback: (String, List<ToolDescriptor>) -> String = Messages::invalidNameFeedback, + private val invalidArgumentsFeedback: (String, ToolDescriptor) -> String = Messages::invalidArgumentsFeedback, + private val maxRetries: Int = 3, +) : ResponseProcessor() { + + private companion object { + @JvmStatic + private val logger = KotlinLogging.logger {} + } + + init { + require(maxRetries > 0) { "maxRetries must be greater than 0" } + } + + override suspend fun process( + executor: PromptExecutor, + prompt: Prompt, + model: LLModel, + tools: List<ToolDescriptor>, + responses: List<Message.Response> + ): List<Message.Response> = responses.map processSingleMessage@{ response -> + logger.info { "Updating message: $response" } + + var result = preprocessor.process(executor, prompt, model, tools, response) + if (!isToolCallIntended(executor, prompt, model, result)) return@processSingleMessage result + + var fixToolCallPrompt = prompt(prompt.withMessages { emptyList() }) { + system(fixToolCallSystemMessage) + } + + var i = 0 + + while (i++ < maxRetries) { + val feedback = getFeedback(result, tools) ?: return@processSingleMessage result + fixToolCallPrompt = prompt(fixToolCallPrompt) { + message(result) + user(feedback) + } + result = executor.executeProcessed(fixToolCallPrompt, model, tools, preprocessor).first() + } + + // use fallback with the initial prompt + fallbackProcessor?.process(executor, prompt, model, tools, response) ?: response + }.also { + logger.info { "Updated messages: $it" } + } + + private suspend fun isToolCallIntended( + executor: PromptExecutor, + prompt: Prompt, + model: LLModel, + response: Message.Response + ): Boolean { + if (response is Message.Tool.Call) return true + + val toolCallIntentPrompt = prompt(prompt.withMessages { emptyList() }) { + system(assessToolCallIntentSystemMessage) + user(response.content) + } + + val decision = executor.execute(toolCallIntentPrompt, model, emptyList()).first() + + return decision is Message.Tool.Call || + decision.content.contains( + Messages.INTENDED_TOOL_CALL, + ignoreCase = true + ) + } + + private fun getFeedback( + message: Message.Response, + tools: List<ToolDescriptor>, + ): String? { + val toolName = (message as?
Message.Tool.Call)?.tool + ?: toolCallJsonFixConfig.getToolName(message.content) + ?: return invalidJsonFeedback(tools) + + if (!tools.any { it.name == toolName }) { + return invalidNameFeedback(toolName, tools) + } + + val tool = toolRegistry.getTool(toolName) + + try { + tool.decodeArgs((message as Message.Tool.Call).contentJson) + } catch (e: Exception) { + val errorMessage = e.message ?: "Unknown error" + return invalidArgumentsFeedback(errorMessage, tool.descriptor) + } + + return null + } +} diff --git a/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/Messages.kt b/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/Messages.kt new file mode 100644 index 0000000000..a46e439887 --- /dev/null +++ b/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/Messages.kt @@ -0,0 +1,133 @@ +package ai.koog.prompt.processor + +import ai.koog.agents.core.tools.ToolDescriptor +import ai.koog.prompt.markdown.markdown + +internal object Messages { + const val INTENDED_TOOL_CALL = "YES" + const val NOT_INTENDED_TOOL_CALL = "NO" + + val assessToolCallIntent = + markdown { + +"You are a helpful assistant specialized in determining if a message is intended to call a tool." + br() + + h2("TASK") + +"The user will provide you with a message." + +"Your task is to determine if this message is intended to call a tool or if it's just a regular assistant message." + br() + + h2("INSTRUCTIONS") + +"Determine if the message is meant to call a tool or perform a specific action that would require a tool call." + +"Important: Distinguish between reports about actions and actual intent to perform actions:" + bulleted { + item("If the message only reports or describes what was done or what happened, it is NOT a tool call") + item("If the message expresses an intent to perform an action or request an action to be performed, it IS a tool call") + item("If the message contains both, it IS a tool call") + } + + h3("EXAMPLES OF INTENT PHRASES") + +"These phrases indicate an intent to perform an action (IS a tool call):" + bulleted { + item("Now I will do that") + item("Now I have to do that") + item("Let me search for that") + item("When a message contains a json-like structure designed to call a tool: {name: <tool_name>, args: <tool_args>}") + } + + h3("EXAMPLES OF REGULAR ASSISTANT MESSAGES") + +"These phrases indicate a regular assistant message (IS NOT a tool call):" + bulleted { + item("Done!") + item("I'm done!") + item("I completed the task.") + } + + h2("RESPONSE FORMAT") + +"Respond with ONLY ONE of these exact options:" + + bulleted { + item("$INTENDED_TOOL_CALL - if a tool call was intended") + item("$NOT_INTENDED_TOOL_CALL - if no tool call was intended") + } + br() + } + + val fixToolCall = + markdown { + +"You are a helpful assistant specialized in fixing tool call formats." + br() + + h2("TASK") + +"You will see a tool call message with an incorrect format: invalid JSON or incorrect tool names." + +"Your task is to convert the message to the proper format." + br() + + h2("COMMON ISSUES TO FIX") + bulleted { + item("Intent message instead of direct tool call") + item("Invalid JSON syntax") + item("Missing required parameters for the tool") + item("Incorrect tool names (misspelled or non-existent)") + } + br() + + +"Your goal is to fix the format while preserving the original intention of the message." + +"YOUR RESPONSE MUST BE A TOOL CALL MESSAGE IN THE CORRECT FORMAT!"
+ br() + } + + fun invalidJsonFeedback(tools: List<ToolDescriptor>) = + markdown { + +"The message appears to be intending to call a tool, but it's not in the proper tool call format." + br() + + +"Please generate a proper tool call message based on the provided message." + br() + + h2("IMPORTANT INSTRUCTIONS") + bulleted { + item("DO NOT explain what you're going to do - just call the tool directly") + item("DO NOT respond with text descriptions - use the JSON format") + } + br() + + h2("POSSIBLE ISSUES") + bulleted { + item("The message shows an intention to call a tool but does not produce a tool call") + item("Incorrect json formatting in tool call json: unescaped characters, missing quotes, etc.") + } + + h2("Available tools") + showTools(tools) + } + + fun invalidNameFeedback(toolName: String, tools: List<ToolDescriptor>) = + markdown { + +"Tool name \"$toolName\" is not recognized." + br() + + +"Available tools:" + showTools(tools) + } + + fun invalidArgumentsFeedback(errorMessage: String, tool: ToolDescriptor) = + markdown { + +"Failed to parse tool arguments with error: $errorMessage" + br() + + +"$tool" + br() + + +"Please rewrite the tool call using proper JSON format." + } + + fun showTools(tools: List<ToolDescriptor>) = + markdown { + bulleted { + tools.forEach { tool -> + item(tool.name) + } + } + } +} diff --git a/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/PromptExecutorExtension.kt b/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/PromptExecutorExtension.kt new file mode 100644 index 0000000000..465319b48f --- /dev/null +++ b/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/PromptExecutorExtension.kt @@ -0,0 +1,20 @@ +package ai.koog.prompt.processor + +import ai.koog.agents.core.tools.ToolDescriptor +import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.executor.model.PromptExecutor +import ai.koog.prompt.llm.LLModel +import ai.koog.prompt.message.Message + +/** + * Executes the given prompt and processes responses using the given [responseProcessor].
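+ *
+ * A usage sketch (assuming a configured executor, prompt, model, tools, and tool registry):
+ * ```kotlin
+ * val responses = executor.executeProcessed(prompt, model, tools, FixJsonToolCall(toolRegistry))
+ * ```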
diff --git a/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/ResponseProcessor.kt b/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/ResponseProcessor.kt
new file mode 100644
index 0000000000..8653e429e8
--- /dev/null
+++ b/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/ResponseProcessor.kt
@@ -0,0 +1,74 @@
+package ai.koog.prompt.processor
+
+import ai.koog.agents.core.tools.ToolDescriptor
+import ai.koog.prompt.dsl.Prompt
+import ai.koog.prompt.executor.model.PromptExecutor
+import ai.koog.prompt.llm.LLModel
+import ai.koog.prompt.message.Message
+
+/**
+ * Opt-in annotation for the ResponseProcessor API.
+ */
+@RequiresOptIn(
+    message = "The ResponseProcessor API can update LLM responses and make additional calls to the LLM. Opt in only after you have reviewed the API and understand the risks."
+)
+public annotation class ResponseProcessorApi
+
+/**
+ * A processor for handling and modifying LLM responses.
+ */
+public abstract class ResponseProcessor @ResponseProcessorApi constructor() {
+
+    /**
+     * Processes the given LLM responses.
+     * These responses were received from [executor] using [prompt], [model], and [tools].
+     */
+    internal abstract suspend fun process(
+        executor: PromptExecutor,
+        prompt: Prompt,
+        model: LLModel,
+        tools: List<ToolDescriptor>,
+        responses: List<Message.Response>
+    ): List<Message.Response>
+
+    internal suspend fun process(
+        executor: PromptExecutor,
+        prompt: Prompt,
+        model: LLModel,
+        tools: List<ToolDescriptor>,
+        response: Message.Response
+    ) = process(
+        executor,
+        prompt,
+        model,
+        tools,
+        listOf(response)
+    ).first()
+
+    /**
+     * Chains multiple response processors together, applying them in order.
+     */
+    @OptIn(ResponseProcessorApi::class)
+    public class Chain(vararg processors: ResponseProcessor) : ResponseProcessor() {
+        private val processors = processors.toList()
+
+        override suspend fun process(
+            executor: PromptExecutor,
+            prompt: Prompt,
+            model: LLModel,
+            tools: List<ToolDescriptor>,
+            responses: List<Message.Response>
+        ): List<Message.Response> {
+            var result = responses
+            for (processor in processors) {
+                result = processor.process(executor, prompt, model, tools, result)
+            }
+            return result
+        }
+    }
+
+    /**
+     * Chains this processor with [other] into a single processor.
+     */
+    public operator fun plus(other: ResponseProcessor): ResponseProcessor = Chain(this, other)
+}
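Note that process is internal, so response processors can only be implemented inside this module, which the test fixtures later in this diff rely on. As a sketch of the extension point, a minimal pass-through processor as it could be written module-internally; the logging side effect is invented for illustration:

    import ai.koog.agents.core.tools.ToolDescriptor
    import ai.koog.prompt.dsl.Prompt
    import ai.koog.prompt.executor.model.PromptExecutor
    import ai.koog.prompt.llm.LLModel
    import ai.koog.prompt.message.Message

    @OptIn(ResponseProcessorApi::class)
    internal class LoggingProcessor : ResponseProcessor() {
        override suspend fun process(
            executor: PromptExecutor,
            prompt: Prompt,
            model: LLModel,
            tools: List<ToolDescriptor>,
            responses: List<Message.Response>
        ): List<Message.Response> {
            // Pass the responses through unchanged; the println is illustrative only.
            responses.forEach { println("LLM response: ${it.content}") }
            return responses
        }
    }

Chaining composes left to right, so LoggingProcessor() + FixJsonToolCall(registry) would log the raw responses before the JSON fixer rewrites them.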
diff --git a/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/ToolCallJsonFixConfig.kt b/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/ToolCallJsonFixConfig.kt
new file mode 100644
index 0000000000..1e0c1332a2
--- /dev/null
+++ b/prompt/prompt-processor/src/commonMain/kotlin/ai/koog/prompt/processor/ToolCallJsonFixConfig.kt
@@ -0,0 +1,195 @@
+package ai.koog.prompt.processor
+
+import ai.koog.agents.core.tools.ToolParameterType
+import ai.koog.agents.core.tools.ToolRegistry
+import ai.koog.prompt.message.Message
+import ai.koog.prompt.message.ResponseMetaInfo
+import kotlinx.serialization.DeserializationStrategy
+import kotlinx.serialization.Serializable
+import kotlinx.serialization.encoding.Decoder
+import kotlinx.serialization.json.Json
+import kotlinx.serialization.json.JsonDecoder
+import kotlinx.serialization.json.JsonObject
+import kotlinx.serialization.json.JsonPrimitive
+import kotlinx.serialization.json.buildJsonObject
+import kotlinx.serialization.serializer
+
+/**
+ * Configuration for parsing and fixing tool call JSON.
+ *
+ * @param json The JSON parser to use for parsing tool call JSON
+ * @param idJsonKeys The keys used by various models for the tool ID in tool call JSON
+ * @param nameJsonKeys The keys used by various models for the tool name in tool call JSON
+ * @param argsJsonKeys The keys used by various models for the tool arguments in tool call JSON
+ */
+public class ToolCallJsonFixConfig(
+    private val json: Json = defaultJson,
+    private val idJsonKeys: List<String> = defaultIdJsonKeys,
+    private val nameJsonKeys: List<String> = defaultNameJsonKeys,
+    private val argsJsonKeys: List<String> = defaultArgsJsonKeys,
+) {
+
+    private fun fixJsonString(jsonValue: String): String {
+        // Remove the surrounding quotes to work with the content
+        val content = if (jsonValue.startsWith('"') && jsonValue.endsWith('"')) {
+            jsonValue.drop(1).dropLast(1)
+        } else {
+            jsonValue
+        }
+
+        val correctString = StringBuilder()
+        var i = 0
+        while (i < content.length) {
+            if (content[i] == '\\' && i + 1 < content.length) {
+                when (content[i + 1]) {
+                    '"' -> '"'
+                    '\\' -> '\\'
+                    '/' -> '/'
+                    'b' -> '\b'
+                    'n' -> '\n'
+                    'r' -> '\r'
+                    't' -> '\t'
+                    else -> null
+                }?.let {
+                    correctString.append(it)
+                    i++
+                }
+            } else {
+                // Regular character
+                correctString.append(content[i])
+            }
+            i++
+        }
+
+        return correctString.toString()
+    }
+
+    internal fun getToolName(messageContent: String): String? {
+        val toolKeyPattern = getKeyPattern(nameJsonKeys)
+        val toolNameRegex = """$toolKeyPattern\s*:\s*"([a-zA-Z0-9_]+)"""".toRegex()
+        return toolNameRegex.find(messageContent)?.groupValues?.get(2)
+    }
+
+    internal fun extractToolCall(
+        messageContent: String,
+        metaInfo: ResponseMetaInfo,
+        toolRegistry: ToolRegistry,
+    ): Message.Tool.Call? {
+        runCatching {
+            val decodedToolCall = json.decodeFromString(toolCallDeserializer, messageContent)
+            return Message.Tool.Call(
+                decodedToolCall.id,
+                decodedToolCall.tool,
+                decodedToolCall.args.toString(),
+                metaInfo
+            )
+        }
+
+        val toolName = getToolName(messageContent) ?: return null
+
+        val params = runCatching {
+            toolRegistry.getTool(toolName).descriptor.requiredParameters
+        }.getOrNull() ?: return null
+
+        val argsKeyPattern = getKeyPattern(argsJsonKeys)
+        val argsPattern = """\{\s*${params.joinToString("\\s*,\\s*") { """"${it.name}"\s*:\s*(.+)""" }}\s*\}\s*\}"""
+        val argsRegex = """$argsKeyPattern\s*:\s*$argsPattern""".toRegex()
+        val argsMatch = argsRegex.find(messageContent)?.groupValues
+        val args = argsMatch?.drop(2) ?: return null
+
+        val fixedArgs = buildJsonObject {
+            params.zip(args).forEach { (param, argValue) ->
+                val key = param.name
+                val value = when (param.type) {
+                    is ToolParameterType.String -> JsonPrimitive(fixJsonString(argValue))
+                    else -> json.parseToJsonElement(argValue)
+                }
+                put(key, value)
+            }
+        }.toString()
+
+        val idKeyPattern = getKeyPattern(idJsonKeys)
+        val idRegex = """$idKeyPattern\s*:\s*"([a-zA-Z0-9_]+)"""".toRegex()
+        // The key pattern contributes group 1, so the captured ID value is group 2.
+        val id = idRegex.find(messageContent)?.groupValues?.get(2)
+
+        return Message.Tool.Call(id, toolName, fixedArgs, metaInfo)
+    }
+
+    @Serializable
+    private data class ToolCall(
+        val id: String? = null,
+        val tool: String,
+        val args: JsonObject
+    )
+
+    private val toolCallDeserializer
+        get() = object : DeserializationStrategy<ToolCall> {
+            private val baseSerializer = serializer<ToolCall>()
+            override val descriptor = baseSerializer.descriptor
+
+            override fun deserialize(decoder: Decoder): ToolCall {
+                require(decoder is JsonDecoder) { "This serializer can only be used with JSON" }
+
+                val jsonElement = decoder.decodeJsonElement()
+                require(jsonElement is JsonObject) { "Expected a JSON object" }
+
+                return deserializeFromJsonObject(jsonElement, decoder.json)
+            }
+
+            private fun deserializeFromJsonObject(jsonObject: JsonObject, json: Json): ToolCall {
+                var objectToDeserialize = findNestedObject(jsonObject) ?: jsonObject
+
+                objectToDeserialize = updateKey(objectToDeserialize, idJsonKeys, "id")
+                objectToDeserialize = updateKey(objectToDeserialize, nameJsonKeys, "tool")
+                objectToDeserialize = updateKey(objectToDeserialize, argsJsonKeys, "args")
+
+                return json.decodeFromJsonElement(baseSerializer, objectToDeserialize)
+            }
+
+            private fun findNestedObject(jsonObject: JsonObject): JsonObject? =
+                jsonObject.takeIf { it.size == 1 }?.entries?.first()?.value as? JsonObject
+
+            private fun updateKey(
+                jsonObject: JsonObject,
+                expectedKeys: List<String>,
+                updatedKey: String
+            ) = buildJsonObject {
+                for ((key, value) in jsonObject) {
+                    put(if (key in expectedKeys) updatedKey else key, value)
+                }
+            }
+        }
+
+    /**
+     * Companion object with defaults for the JSON configuration of [ToolCallJsonFixConfig]
+     */
+    public companion object {
+
+        /**
+         * Default JSON configuration used to parse tool call JSON
+         */
+        public val defaultJson: Json = Json {
+            ignoreUnknownKeys = true
+            isLenient = true
+        }
+
+        /**
+         * Keys used by various models for the tool ID in tool call JSON
+         */
+        public val defaultIdJsonKeys: List<String> = listOf("id", "tool_call_id")
+
+        /**
+         * Keys used by various models for the tool name in tool call JSON
+         */
+        public val defaultNameJsonKeys: List<String> = listOf("name", "tool", "tool_name")
+
+        /**
+         * Keys used by various models for the tool arguments in tool call JSON
+         */
+        public val defaultArgsJsonKeys: List<String> = listOf("arguments", "args", "parameters", "params", "tool_args")
+
+        private fun getKeyPattern(keys: List<String>): String {
+            return """"(${keys.joinToString("|")})""""
+        }
+    }
+}
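A minimal sketch of extending the key aliases via the constructor above; the "function" and "function_args" aliases are invented for the example, while the defaults already cover common model outputs:

    import ai.koog.prompt.processor.ToolCallJsonFixConfig

    // Unlisted parameters (json, idJsonKeys) keep their defaults.
    val lenientConfig = ToolCallJsonFixConfig(
        nameJsonKeys = ToolCallJsonFixConfig.defaultNameJsonKeys + "function",
        argsJsonKeys = ToolCallJsonFixConfig.defaultArgsJsonKeys + "function_args",
    )

With this config, getToolName would also match a malformed response that names its tool under a "function" key.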
diff --git a/prompt/prompt-processor/src/commonTest/kotlin/ai/koog/prompt/processor/FixJsonToolCallTest.kt b/prompt/prompt-processor/src/commonTest/kotlin/ai/koog/prompt/processor/FixJsonToolCallTest.kt
new file mode 100644
index 0000000000..6e257984cd
--- /dev/null
+++ b/prompt/prompt-processor/src/commonTest/kotlin/ai/koog/prompt/processor/FixJsonToolCallTest.kt
@@ -0,0 +1,221 @@
+package ai.koog.prompt.processor
+
+import ai.koog.agents.testing.tools.getMockExecutor
+import ai.koog.prompt.dsl.prompt
+import ai.koog.prompt.executor.clients.openai.OpenAIModels
+import ai.koog.prompt.message.Message
+import ai.koog.prompt.message.ResponseMetaInfo
+import kotlinx.coroutines.test.runTest
+import kotlinx.datetime.Clock
+import kotlinx.datetime.Instant
+import kotlinx.serialization.json.JsonPrimitive
+import kotlinx.serialization.json.buildJsonObject
+import kotlin.test.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertIs
+import kotlin.test.assertNotNull
+
+@OptIn(ResponseProcessorApi::class)
+class FixJsonToolCallTest {
+    private companion object {
+        private val testClock: Clock = object : Clock {
+            override fun now(): Instant = Instant.parse("2023-01-01T00:00:00Z")
+        }
+
+        private val testMetaInfo = ResponseMetaInfo.create(testClock)
+
+        private const val ID = "123"
+        private const val TOOL = "plus"
+        private const val ARGS = """{"a":5,"b":3}"""
+
+        private val validJson = """
+            {
+                "tool": "$TOOL",
+                "args": $ARGS
+            }
+        """.trimIndent()
+
+        private val validJsonWithId = """
+            {
+                "id": "$ID",
+                "tool": "$TOOL",
+                "args": $ARGS
+            }
+        """.trimIndent()
+
+        private val validJsonWithAlternativeJsonKeys = """
+            {
+                "tool_call_id": "$ID",
+                "tool_name": "$TOOL",
+                "parameters": $ARGS
+            }
+        """.trimIndent()
+
+        private val nestedJson = """
+            {
+                "function_call": {
+                    "name": "$TOOL",
+                    "arguments": $ARGS
+                }
+            }
+        """.trimIndent()
+
+        private val invalidJson = """
+            {
+                "args": $ARGS
+            }
+        """.trimIndent()
+
+        private val taggedJson = """
+            Some text before the tool call
+
+            $validJson
+
+            Some text after the tool call
+        """.trimIndent()
+
+        private val multipleToolCallsJson = """
+            $validJson
+            Some text in between
+
+            {
+                "tool": "weather",
+                "args": {
+                    "location": "New York"
+                }
+            }
+        """.trimIndent()
+
+        private val executor = getMockExecutor { }
+        private val prompt = prompt("test-prompt") { }
+        private val model = OpenAIModels.Chat.GPT4o
+        private val toolRegistry = Tools.toolRegistry
+        private val tools = toolRegistry.tools.map { it.descriptor }
+
+        private val fixJsonToolCall = FixJsonToolCall(toolRegistry)
+
+        private fun validateToolCallResult(result: Message.Response, checkId: Boolean = false) {
+            assertIs<Message.Tool.Call>(result)
+            if (checkId) assertEquals(ID, result.id)
+            assertEquals(TOOL, result.tool)
+            assertEquals(ARGS, result.content)
+        }
+    }
+
+    @Test
+    fun test_shouldParseValidJson() = runTest {
+        val message = Message.Assistant(validJson, metaInfo = testMetaInfo)
+        val result = process(message)
+
+        validateToolCallResult(result)
+    }
+
+    @Test
+    fun test_shouldParseToolCallWithId() = runTest {
+        val message = Message.Assistant(validJsonWithId, metaInfo = testMetaInfo)
+        val result = process(message)
+
+        validateToolCallResult(result, checkId = true)
+    }
+
+    @Test
+    fun test_shouldParseAlternativeJsonKeys() = runTest {
+        val message = Message.Assistant(validJsonWithAlternativeJsonKeys, metaInfo = testMetaInfo)
+        val result = process(message)
+
+        validateToolCallResult(result)
+    }
+
+    @Test
+    fun test_shouldParseNestedJson() = runTest {
+        val message = Message.Assistant(nestedJson, metaInfo = testMetaInfo)
+        val result = process(message)
+
+        validateToolCallResult(result)
+    }
+
+    @Test
+    fun test_shouldNotParseInvalidJson() = runTest {
+        val message = Message.Assistant(invalidJson, metaInfo = testMetaInfo)
+        val result = process(message)
+
+        assertEquals(message, result)
+    }
+
+    @Test
+    fun test_shouldParseTaggedJson() = runTest {
+        val message = Message.Assistant(taggedJson, metaInfo = testMetaInfo)
+        val result = process(message)
+
+        validateToolCallResult(result)
+    }
+
+    @Test
+    fun test_shouldParseFirstOfMultipleJsons() = runTest {
+        val message = Message.Assistant(multipleToolCallsJson, metaInfo = testMetaInfo)
+        val result = process(message)
+
+        validateToolCallResult(result)
+    }
+
+    @Test
+    fun test_shouldParseCorrectEscapes() = runTest {
+        val text1 = "Test \"quoted\" string"
+        val text2 = "Test a string with\na new line"
+
+        val validJson = """
+            {
+                "tool": "string_tool",
+                "args": {
+                    "text1": "Test \"quoted\" string",
+                    "text2": "Test a string with\na new line"
+                }
+            }
+        """.trimIndent()
+
+        val message = Message.Assistant(validJson, metaInfo = testMetaInfo)
+        val result = process(message)
+
+        assertNotNull(result)
+        assertIs<Message.Tool.Call>(result)
+        assertEquals("string_tool", result.tool)
+
+        val expectedContentJson = buildJsonObject {
+            put("text1", JsonPrimitive(text1))
+            put("text2", JsonPrimitive(text2))
+        }
+        assertEquals(expectedContentJson.toString(), result.content)
+    }
+
+    @Test
+    fun test_shouldParseIncorrectEscapes() = runTest {
+        val text1 = "Test \"quoted\" string"
+        val text2 = "Test a string with\na new line"
+
+        val malformedJson = """
+            {
+                "tool": "string_tool",
+                "args": {
+                    "text1": "$text1",
+                    "text2": "Test a string with\na new line"
+                }
+            }
+        """.trimIndent()
+
+        val message = Message.Assistant(malformedJson, metaInfo = testMetaInfo)
+        val result = process(message)
+
+        assertNotNull(result)
+        assertIs<Message.Tool.Call>(result)
+        assertEquals("string_tool", result.tool)
+
+        val expectedContentJson = buildJsonObject {
+            put("text1", JsonPrimitive(text1))
+            put("text2", JsonPrimitive(text2))
+        }
+        assertEquals(expectedContentJson.toString(), result.content)
+    }
+
+    private suspend fun process(response: Message.Response) =
+        fixJsonToolCall.process(executor, prompt, model, tools, response)
+}
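The last two tests hinge on the fixJsonString repair path. A sketch, reusing the fixtures above (possible in the same module, since process is internal), of what that repair produces:

    // Invalid JSON: the inner quotes around "quoted" are not escaped.
    suspend fun demoEscapeRepair(): Message.Response {
        val malformed = Message.Assistant(
            """{"tool": "string_tool", "args": {"text1": "Test "quoted" string", "text2": "ok"}}""",
            metaInfo = testMetaInfo
        )
        // Strict decoding fails, extractToolCall falls back to the regex path,
        // and fixJsonString re-encodes String parameters, so the returned
        // Message.Tool.Call carries valid JSON:
        //   {"text1":"Test \"quoted\" string","text2":"ok"}
        return fixJsonToolCall.process(executor, prompt, model, tools, malformed)
    }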
diff --git a/prompt/prompt-processor/src/commonTest/kotlin/ai/koog/prompt/processor/FixToolCallLLMBasedTest.kt b/prompt/prompt-processor/src/commonTest/kotlin/ai/koog/prompt/processor/FixToolCallLLMBasedTest.kt
new file mode 100644
index 0000000000..05e8a5b6f2
--- /dev/null
+++ b/prompt/prompt-processor/src/commonTest/kotlin/ai/koog/prompt/processor/FixToolCallLLMBasedTest.kt
@@ -0,0 +1,181 @@
+package ai.koog.prompt.processor
+
+import ai.koog.agents.core.tools.ToolDescriptor
+import ai.koog.prompt.dsl.ModerationResult
+import ai.koog.prompt.dsl.Prompt
+import ai.koog.prompt.dsl.prompt
+import ai.koog.prompt.executor.clients.openai.OpenAIModels
+import ai.koog.prompt.executor.model.PromptExecutor
+import ai.koog.prompt.llm.LLModel
+import ai.koog.prompt.message.Message
+import ai.koog.prompt.message.ResponseMetaInfo
+import ai.koog.prompt.streaming.StreamFrame
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.test.runTest
+import kotlinx.datetime.Clock
+import kotlinx.datetime.Instant
+import kotlin.test.Test
+import kotlin.test.assertContains
+import kotlin.test.assertEquals
+
+@OptIn(ResponseProcessorApi::class)
+class FixToolCallLLMBasedTest {
+    private companion object {
+        private val testClock: Clock = object : Clock {
+            override fun now(): Instant = Instant.parse("2023-01-01T00:00:00Z")
+        }
+
+        private val testMetaInfo = ResponseMetaInfo.create(testClock)
+
+        private val notIntendedToolCall = Message.Assistant(Messages.NOT_INTENDED_TOOL_CALL, metaInfo = testMetaInfo)
+        private val intendedToolCall = Message.Assistant(Messages.INTENDED_TOOL_CALL, metaInfo = testMetaInfo)
+        private val toolCallMessage = Message.Tool.Call(
+            id = null,
+            tool = "plus",
+            content = """{"a":5,"b":3}""",
+            metaInfo = testMetaInfo
+        )
+
+        private val toolRegistry = Tools.toolRegistry
+
+        private val tools = toolRegistry.tools.map { it.descriptor }
+        private val prompt = prompt("test-prompt") { }
+        private val model = OpenAIModels.Chat.GPT4o
+
+        private val message = Message.Assistant("I want to use the calculator tool", metaInfo = testMetaInfo)
+
+        val processor = FixToolCallLLMBased(toolRegistry)
+    }
+
+    private class MockExecutor(
+        private val responses: List<Message.Response>,
+    ) : PromptExecutor {
+        private var index = 0
+        val prompts = mutableListOf<Prompt>()
+
+        override suspend fun execute(
+            prompt: Prompt,
+            model: LLModel,
+            tools: List<ToolDescriptor>
+        ): List<Message.Response> =
+            listOf(responses[index++]).also { prompts.add(prompt) }
+
+        override fun executeStreaming(prompt: Prompt, model: LLModel, tools: List<ToolDescriptor>): Flow<StreamFrame> =
+            error("Not supported")
+
+        override suspend fun moderate(prompt: Prompt, model: LLModel): ModerationResult = error("Not supported")
+        override fun close() {}
+    }
+
+    private suspend fun process(
+        executor: PromptExecutor,
+        response: Message.Response,
+        processor: ResponseProcessor
+    ) = processor.process(executor, prompt, model, tools, response)
+
+    @Test
+    fun test_shouldStopIfToolCallNotIntended() = runTest {
+        val executor = MockExecutor(listOf(notIntendedToolCall))
+        val result = process(executor, message, processor)
+
+        assertEquals(message, result)
+        assertEquals(
+            Messages.assessToolCallIntent,
+            executor.prompts.last().messages.dropLast(1).last().content
+        )
+    }
+
+    @Test
+    fun test_shouldFixAssistantMessage() = runTest {
+        val executor = MockExecutor(listOf(intendedToolCall, toolCallMessage))
+        val result = process(executor, message, processor)
+
+        assertEquals(toolCallMessage, result)
+        assertEquals(
+            Messages.fixToolCall,
+            executor.prompts.last().messages.dropLast(2).last().content
+        )
+        assertEquals(
+            Messages.invalidJsonFeedback(tools),
+            executor.prompts.last().messages.last().content
+        )
+    }
+
+    @Test
+    fun test_shouldFixInvalidToolName() = runTest {
+        val executor = MockExecutor(listOf(toolCallMessage))
+        val message = toolCallMessage.copy(tool = "minus")
+        val result = process(executor, message, processor)
+
+        assertEquals(toolCallMessage, result)
+        assertEquals(
+            Messages.invalidNameFeedback("minus", tools),
+            executor.prompts.last().messages.last().content
+        )
+    }
+
+    @Test
+    fun test_shouldFixIncorrectArguments() = runTest {
+        val executor = MockExecutor(listOf(toolCallMessage))
+        val message = Message.Tool.Call(
+            id = null,
+            tool = "plus",
+            content = """{"x":5,"y":3}""",
+            metaInfo = testMetaInfo
+        )
+        val result = process(executor, message, processor)
+
+        assertEquals(toolCallMessage, result)
+        assertContains(
+            executor.prompts.last().messages.last().content,
+            "Failed to parse tool arguments with error"
+        )
+    }
+
+    @Test
+    fun test_shouldRetry() = runTest {
+        val executor = MockExecutor(listOf(intendedToolCall, toolCallMessage.copy(tool = "minus"), toolCallMessage))
+        val result = process(executor, message, processor)
+
+        assertEquals(toolCallMessage, result)
+    }
+
+    @Test
+    fun test_shouldStopWhenMaxRetriesReached() = runTest {
+        val executor = MockExecutor(listOf(intendedToolCall, toolCallMessage.copy(tool = "minus")))
+        val processor = FixToolCallLLMBased(
+            toolRegistry = toolRegistry,
+            maxRetries = 1,
+        )
+        val result = process(executor, message, processor)
+
+        assertEquals(message, result)
+    }
+
+    @Test
+    fun test_shouldApplyFallbackWhenMaxRetriesReached() = runTest {
+        val executor = MockExecutor(listOf(intendedToolCall, toolCallMessage))
+        val fallbackExecutor = MockExecutor(listOf(toolCallMessage))
+
+        @OptIn(ResponseProcessorApi::class)
+        val fallbackProcessor = object : ResponseProcessor() {
+            override suspend fun process(
+                executor: PromptExecutor,
+                prompt: Prompt,
+                model: LLModel,
+                tools: List<ToolDescriptor>,
+                responses: List<Message.Response>
+            ): List<Message.Response> = fallbackExecutor.execute(prompt, model, tools)
+        }
+
+        val processor = FixToolCallLLMBased(
+            toolRegistry = toolRegistry,
+            fallbackProcessor = fallbackProcessor,
+            maxRetries = 1,
+        )
+
+        val result = process(executor, message, processor)
+
+        assertEquals(toolCallMessage, result)
+    }
+}
diff --git a/prompt/prompt-processor/src/commonTest/kotlin/ai/koog/prompt/processor/Tools.kt b/prompt/prompt-processor/src/commonTest/kotlin/ai/koog/prompt/processor/Tools.kt
new file mode 100644
index 0000000000..a11a1fdebf
--- /dev/null
+++ b/prompt/prompt-processor/src/commonTest/kotlin/ai/koog/prompt/processor/Tools.kt
@@ -0,0 +1,83 @@
+package ai.koog.prompt.processor
+
+import ai.koog.agents.core.tools.Tool
+import ai.koog.agents.core.tools.ToolDescriptor
+import ai.koog.agents.core.tools.ToolParameterDescriptor
+import ai.koog.agents.core.tools.ToolParameterType
+import ai.koog.agents.core.tools.ToolRegistry
+import ai.koog.agents.core.tools.annotations.LLMDescription
+import kotlinx.serialization.KSerializer
+import kotlinx.serialization.Serializable
+import kotlin.jvm.JvmInline
+
+object Tools {
+    object PlusTool : Tool<PlusTool.Args, PlusTool.Result>() {
+        override val description: String = "Adds a and b"
+        override val name: String = "plus"
+
+        @Serializable
+        data class Args(
+            @property:LLMDescription("First number")
+            val a: Float,
+            @property:LLMDescription("Second number")
+            val b: Float
+        )
+
+        @Serializable
+        @JvmInline
+        value class Result(val result: Float)
+
+        override val argsSerializer = Args.serializer()
+        override val resultSerializer: KSerializer<Result> = Result.serializer()
+
+        override suspend fun execute(args: Args): Result {
+            return Result(args.a + args.b)
+        }
+    }
+
+    object StringTool : Tool<StringTool.Args, StringTool.Result>() {
+        @Serializable
+        data class Args(
+            @property:LLMDescription("First string")
+            val text1: String,
+            @property:LLMDescription("Second string")
+            val text2: String
+        )
+
+        @Serializable
+        @JvmInline
+        value class Result(val result: String)
+
+        override val name: String = "string_tool"
+        override val description: String = "A tool that takes string parameters"
+        override val argsSerializer = Args.serializer()
+        override val resultSerializer: KSerializer<Result> = Result.serializer()
+
+        override val descriptor = ToolDescriptor(
+            name = "string_tool",
+            description = "A tool that takes string parameters",
+            requiredParameters = listOf(
+                ToolParameterDescriptor(
+                    name = "text1",
+                    description = "First string",
+                    type = ToolParameterType.String
+                ),
+                ToolParameterDescriptor(
+                    name = "text2",
+                    description = "Second string",
+                    type = ToolParameterType.String
+                )
+            )
+        )
+
+        override suspend fun execute(args: Args): Result {
+            return Result(args.text1 + " " + args.text2)
+        }
+    }
+
+    val tools = listOf(PlusTool, StringTool)
+    val toolRegistry = ToolRegistry {
+        tool(PlusTool)
+        tool(StringTool)
+    }
+}
diff --git a/settings.gradle.kts b/settings.gradle.kts
index 08b15441e3..f235cd0472 100755
--- a/settings.gradle.kts
+++ b/settings.gradle.kts
@@ -58,6 +58,7 @@ include(":prompt:prompt-executor:prompt-executor-model")
 include(":prompt:prompt-llm")
 include(":prompt:prompt-markdown")
 include(":prompt:prompt-model")
+include(":prompt:prompt-processor")
 include(":prompt:prompt-structure")
 include(":prompt:prompt-tokenizer")
 include(":prompt:prompt-xml")
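To tie the pieces together, a sketch of plugging both fixers into the singleRunStrategy parameter this change introduces. The function name and the registry parameter are placeholders; the ordering puts the cheap deterministic JSON repair before the LLM-based repair:

    import ai.koog.agents.core.agent.singleRunStrategy
    import ai.koog.agents.core.tools.ToolRegistry
    import ai.koog.prompt.processor.FixJsonToolCall
    import ai.koog.prompt.processor.FixToolCallLLMBased
    import ai.koog.prompt.processor.ResponseProcessorApi

    // FixJsonToolCall rewrites malformed tool-call JSON deterministically;
    // anything it cannot repair falls through to FixToolCallLLMBased, which
    // asks the LLM to reformat the message.
    @OptIn(ResponseProcessorApi::class)
    fun toolFixingStrategy(registry: ToolRegistry) = singleRunStrategy(
        responseProcessor = FixJsonToolCall(registry) + FixToolCallLLMBased(registry)
    )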