
Commit e2a5cda

KG-381 Add input and total tokens count in internal Mock Test framework (#1145)
## Motivation and Context

[KG-381](https://youtrack.jetbrains.com/issue/KG-381) Add input and total tokens count in internal Mock Test framework.

1. The new `updateTokenCounts` function calculates input, output, and total token counts.
2. Updated response metadata creation to include token counts where applicable.
3. Removed the unused `inputTokensCount` variable.
4. Added tests for the token counter in the MockLLM responses.

## Breaking Changes

None.

---

#### Type of the changes

- [x] New feature (non-breaking change which adds functionality)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Documentation update
- [x] Tests improvement
- [x] Refactoring

#### Checklist

- [x] The pull request has a description of the proposed change
- [x] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request
- [x] The pull request uses **`develop`** as the base branch
- [x] Tests for the changes have been added
- [x] All new and existing tests passed

##### Additional steps for pull requests adding a new feature

- [x] An issue describing the proposed change exists
- [x] The pull request includes a link to the issue
- [x] The change was discussed and approved in the issue
- [x] Docs have been added / updated
1 parent f88af15 commit e2a5cda
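
To illustrate the accounting rule this commit introduces, here is a minimal sketch. `Tokenizer` and `ResponseMetaInfo.create` (with its `inputTokensCount`/`outputTokensCount`/`totalTokensCount` parameters) are taken from the diffs below; the whitespace tokenizer, the `countTokens(text: String): Int` signature (inferred from the calls in the diff), and the use of `kotlinx.datetime.Clock.System` are assumptions for the example:

```kotlin
import ai.koog.prompt.message.ResponseMetaInfo
import ai.koog.prompt.tokenizer.Tokenizer
import kotlinx.datetime.Clock

// Assumed trivial tokenizer for deterministic counts; any Tokenizer works.
val tokenizer = object : Tokenizer {
    override fun countTokens(text: String): Int = text.split(Regex("\\s+")).size
}

fun main() {
    val input = "What is the capital of France?" // last request message
    val output = "Paris"                         // mocked response

    val inputTokens = tokenizer.countTokens(input)   // 6 with the whitespace tokenizer
    val outputTokens = tokenizer.countTokens(output) // 1

    // The mock framework now fills all three counts, with total = input + output.
    val metaInfo = ResponseMetaInfo.create(
        Clock.System,
        inputTokensCount = inputTokens,
        outputTokensCount = outputTokens,
        totalTokensCount = inputTokens + outputTokens
    )
    println(metaInfo)
}
```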

File tree

4 files changed: +383 -35 lines changed

agents/agents-features/agents-features-opentelemetry/src/jvmTest/kotlin/ai/koog/agents/features/opentelemetry/feature/span/OpenTelemetryInferenceSpanTest.kt

Lines changed: 5 additions & 2 deletions

```diff
@@ -445,7 +445,7 @@ class OpenTelemetryInferenceSpanTest : OpenTelemetryTestBase() {
     }
 
     @Test
-    fun `test inference span contains tokes data`() = runTest {
+    fun `test inference span contains tokens data`() = runTest {
        val userInput = USER_PROMPT_PARIS
        val mockLLMResponse = MOCK_LLM_RESPONSE_PARIS
        val model = defaultModel
@@ -492,7 +492,10 @@ class OpenTelemetryInferenceSpanTest : OpenTelemetryTestBase() {
                "gen_ai.operation.name" to "chat",
                "gen_ai.request.temperature" to temperature,
                "gen_ai.response.finish_reasons" to listOf(FinishReasonType.Stop.id),
-               "gen_ai.usage.output_tokens" to tokenizer.countTokens(text = mockLLMResponse).toLong()
+               "gen_ai.usage.input_tokens" to tokenizer.countTokens(text = userInput).toLong(),
+               "gen_ai.usage.output_tokens" to tokenizer.countTokens(text = mockLLMResponse).toLong(),
+               "gen_ai.usage.total_tokens" to tokenizer.countTokens(text = userInput)
+                   .toLong() + tokenizer.countTokens(text = mockLLMResponse).toLong(),
            ),
            "events" to mapOf(
                "gen_ai.system.message" to mapOf(
```

agents/agents-test/src/commonMain/kotlin/ai/koog/agents/testing/tools/MockLLMBuilder.kt

Lines changed: 91 additions & 24 deletions

```diff
@@ -114,7 +114,7 @@ public class MockLLMBuilder(private val clock: Clock, private val tokenizer: Tok
     *
     * Useful in scenarios where the mock response handling involves mixed results
     * from the LLM, and there is a need to differentiate between handling the general
-    * last message vs the last assistant-specific message.
+    * last message vs. the last assistant-specific message.
     */
    public var handleLastAssistantMessage: Boolean = false
 
@@ -184,7 +184,12 @@ public class MockLLMBuilder(private val clock: Clock, private val tokenizer: Tok
                id = toolCallId,
                tool = tool.name,
                content = toolContent,
-               metaInfo = ResponseMetaInfo.create(clock, outputTokensCount = tokenizer?.countTokens(toolContent))
+               metaInfo = ResponseMetaInfo.create(
+                   clock,
+                   inputTokensCount = null, // Will be updated at runtime with actual input
+                   outputTokensCount = tokenizer?.countTokens(toolContent),
+                   totalTokensCount = null // Will be calculated at runtime
+               )
            )
        )
    }
@@ -208,7 +213,12 @@ public class MockLLMBuilder(private val clock: Clock, private val tokenizer: Tok
                id = null,
                tool = tool.name,
                content = toolContent,
-               metaInfo = ResponseMetaInfo.create(clock, outputTokensCount = tokenizer?.countTokens(toolContent))
+               metaInfo = ResponseMetaInfo.create(
+                   clock,
+                   inputTokensCount = null, // Will be updated at runtime with actual input
+                   outputTokensCount = tokenizer?.countTokens(toolContent),
+                   totalTokensCount = null // Will be calculated at runtime
+               )
            )
        )
    }
@@ -231,7 +241,12 @@ public class MockLLMBuilder(private val clock: Clock, private val tokenizer: Tok
                id = null,
                tool = tool.name,
                content = toolContent,
-               metaInfo = ResponseMetaInfo.create(clock, outputTokensCount = tokenizer?.countTokens(toolContent))
+               metaInfo = ResponseMetaInfo.create(
+                   clock,
+                   inputTokensCount = null, // Will be updated at runtime with actual input
+                   outputTokensCount = tokenizer?.countTokens(toolContent),
+                   totalTokensCount = null // Will be calculated at runtime
+               )
            )
        }
    }
@@ -253,7 +268,12 @@ public class MockLLMBuilder(private val clock: Clock, private val tokenizer: Tok
                id = null,
                tool = tool.name,
                content = toolContent,
-               metaInfo = ResponseMetaInfo.create(clock, outputTokensCount = tokenizer?.countTokens(toolContent))
+               metaInfo = ResponseMetaInfo.create(
+                   clock,
+                   inputTokensCount = null, // Will be updated at runtime with actual input
+                   outputTokensCount = tokenizer?.countTokens(toolContent),
+                   totalTokensCount = null // Will be calculated at runtime
+               )
            )
        }
    }
@@ -278,7 +298,12 @@ public class MockLLMBuilder(private val clock: Clock, private val tokenizer: Tok
                id = null,
                tool = tool.name,
                content = toolContent,
-               metaInfo = ResponseMetaInfo.create(clock, outputTokensCount = tokenizer?.countTokens(toolContent))
+               metaInfo = ResponseMetaInfo.create(
+                   clock,
+                   inputTokensCount = null, // Will be updated at runtime with actual input
+                   outputTokensCount = tokenizer?.countTokens(toolContent),
+                   totalTokensCount = null // Will be calculated at runtime
+               )
            )
        }
    }
@@ -306,7 +331,12 @@ public class MockLLMBuilder(private val clock: Clock, private val tokenizer: Tok
                id = null,
                tool = tool.name,
                content = toolContent,
-               metaInfo = ResponseMetaInfo.create(clock, outputTokensCount = tokenizer?.countTokens(toolContent))
+               metaInfo = ResponseMetaInfo.create(
+                   clock,
+                   inputTokensCount = null, // Cannot determine input tokens for conditional matches without the actual input string
+                   outputTokensCount = tokenizer?.countTokens(toolContent),
+                   totalTokensCount = null // Will be calculated at runtime
+               )
            )
        )
    }
@@ -330,7 +360,12 @@ public class MockLLMBuilder(private val clock: Clock, private val tokenizer: Tok
                id = null,
                tool = tool.name,
                content = toolContent,
-               metaInfo = ResponseMetaInfo.create(clock, outputTokensCount = tokenizer?.countTokens(toolContent))
+               metaInfo = ResponseMetaInfo.create(
+                   clock,
+                   inputTokensCount = null, // Cannot determine input tokens for conditional matches without the actual input string
+                   outputTokensCount = tokenizer?.countTokens(toolContent),
+                   totalTokensCount = null // Will be calculated at runtime
+               )
            )
        }
    }
@@ -366,7 +401,12 @@ public class MockLLMBuilder(private val clock: Clock, private val tokenizer: Tok
                id = null,
                tool = tool.name,
                content = toolContent,
-               metaInfo = ResponseMetaInfo.create(clock, outputTokensCount = tokenizer?.countTokens(toolContent))
+               metaInfo = ResponseMetaInfo.create(
+                   clock,
+                   inputTokensCount = tokenizer?.countTokens(pattern),
+                   outputTokensCount = tokenizer?.countTokens(toolContent),
+                   totalTokensCount = null // Will be calculated at runtime
+               )
            )
        }
    }
@@ -792,24 +832,35 @@ public class MockLLMBuilder(private val clock: Clock, private val tokenizer: Tok
            texts.map { text ->
                Message.Assistant(
                    text,
-                   ResponseMetaInfo.create(clock, outputTokensCount = tokenizer?.countTokens(text))
+                   ResponseMetaInfo.create(
+                       clock,
+                       inputTokensCount = null, // Will be updated at runtime with actual input
+                       outputTokensCount = tokenizer?.countTokens(text),
+                       totalTokensCount = null // Will be calculated at runtime
+                   )
                )
            }
        }
 
-       val combinedExactMatches = (processedAssistantExactMatches.keys + toolCallExactMatches.keys).associateWith { key ->
-           val assistantList = processedAssistantExactMatches[key] ?: emptyList()
-           val toolCallList = toolCallExactMatches[key] ?: emptyList()
-           assistantList + toolCallList
-       }
+       val combinedExactMatches =
+           (processedAssistantExactMatches.keys + toolCallExactMatches.keys).associateWith { key ->
+               val assistantList = processedAssistantExactMatches[key] ?: emptyList()
+               val toolCallList = toolCallExactMatches[key] ?: emptyList()
+               assistantList + toolCallList
+           }
 
        // Partial Matches
        val processedAssistantPartialMatches = assistantPartialMatches.mapValues { (_, value) ->
            val texts = value.map { text -> text.trimIndent() }
            texts.map { text ->
                Message.Assistant(
                    text,
-                   ResponseMetaInfo.create(clock, outputTokensCount = tokenizer?.countTokens(text))
+                   ResponseMetaInfo.create(
+                       clock,
+                       inputTokensCount = null, // Will be updated at runtime with actual input
+                       outputTokensCount = tokenizer?.countTokens(text),
+                       totalTokensCount = null // Will be calculated at runtime
+                   )
                )
            }
        }
@@ -827,23 +878,39 @@ public class MockLLMBuilder(private val clock: Clock, private val tokenizer: Tok
            textResponse.map { response ->
                Message.Assistant(
                    content = response,
-                   metaInfo = ResponseMetaInfo.create(clock)
+                   metaInfo = ResponseMetaInfo.create(
+                       clock,
+                       inputTokensCount = null, // Cannot determine input tokens for conditional matches without the actual input string
+                       outputTokensCount = tokenizer?.countTokens(response),
+                       totalTokensCount = null // Will be calculated at runtime
+                   )
                )
            }
        } ?: emptyMap()
 
-       val combinedConditionalMatches = (processedAssistantConditionalMatches.keys + toolCallConditionalMatches.keys).associateWith { key ->
-           buildList {
-               processedAssistantConditionalMatches[key]?.let { addAll(it) }
-               toolCallConditionalMatches[key]?.let { addAll(it) }
+       val combinedConditionalMatches =
+           (processedAssistantConditionalMatches.keys + toolCallConditionalMatches.keys).associateWith { key ->
+               buildList {
+                   processedAssistantConditionalMatches[key]?.let { addAll(it) }
+                   toolCallConditionalMatches[key]?.let { addAll(it) }
+               }
            }
-       }
 
        val responseMatcher = ResponseMatcher(
            partialMatches = combinedPartialMatches.takeIf { it.isNotEmpty() },
            exactMatches = combinedExactMatches.takeIf { it.isNotEmpty() },
            conditional = combinedConditionalMatches,
-           defaultResponse = listOf(Message.Assistant(defaultResponse, ResponseMetaInfo.create(clock)))
+           defaultResponse = listOf(
+               Message.Assistant(
+                   defaultResponse,
+                   ResponseMetaInfo.create(
+                       clock,
+                       inputTokensCount = null, // Will be updated at runtime with actual input
+                       outputTokensCount = tokenizer?.countTokens(defaultResponse),
+                       totalTokensCount = null // Will be calculated at runtime
+                   )
+               )
+           )
        )
 
        val moderationResponseMatcher = ResponseMatcher(
@@ -942,7 +1009,7 @@ public class DefaultResponseReceiver(
 * @param clock: A clock that is used for mock message timestamps
 * @param tokenizer: Tokenizer that will be used to estimate token counts in mock messages
 * @param init A lambda with receiver that configures the mock LLM executor
-* @return A configured PromptExecutor for testing
+* @return Сonfigured PromptExecutor for testing
 *
 * Example usage:
 * ```kotlin
```
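
A pattern worth noting in these hunks: at build time the input prompt is not yet known, so the builder stores placeholder `null` counts that the executor recomputes per request. A minimal sketch of the two phases, under the same assumptions as above (the helper names `buildTimeMetaOf`/`requestTimeMetaOf` are hypothetical, purely for illustration):

```kotlin
import ai.koog.prompt.message.ResponseMetaInfo
import ai.koog.prompt.tokenizer.Tokenizer
import kotlinx.datetime.Clock

// Build time (MockLLMBuilder): only the canned output is known.
fun buildTimeMetaOf(clock: Clock, tokenizer: Tokenizer?, output: String): ResponseMetaInfo =
    ResponseMetaInfo.create(
        clock,
        inputTokensCount = null, // unknown until a request arrives
        outputTokensCount = tokenizer?.countTokens(output),
        totalTokensCount = null // derived once the input is known
    )

// Request time (MockLLMExecutor.updateTokenCounts): recompute with the real input.
fun requestTimeMetaOf(clock: Clock, tokenizer: Tokenizer, input: String, output: String): ResponseMetaInfo {
    val inputTokens = tokenizer.countTokens(input)
    val outputTokens = tokenizer.countTokens(output)
    return ResponseMetaInfo.create(
        clock,
        inputTokensCount = inputTokens,
        outputTokensCount = outputTokens,
        totalTokensCount = inputTokens + outputTokens
    )
}
```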

agents/agents-test/src/commonMain/kotlin/ai/koog/agents/testing/tools/MockLLMExecutor.kt

Lines changed: 52 additions & 9 deletions

```diff
@@ -7,6 +7,7 @@ import ai.koog.prompt.dsl.Prompt
 import ai.koog.prompt.executor.model.PromptExecutor
 import ai.koog.prompt.llm.LLModel
 import ai.koog.prompt.message.Message
+import ai.koog.prompt.message.ResponseMetaInfo
 import ai.koog.prompt.streaming.StreamFrame
 import ai.koog.prompt.streaming.toStreamFrame
 import ai.koog.prompt.tokenizer.Tokenizer
@@ -38,9 +39,9 @@ internal class ResponseMatcher<TResponse>(
 *
 * This class simulates an LLM by returning predefined responses based on the input prompt.
 * It supports different types of matching:
-* 1. Exact matching - Returns a response when the input exactly matches a pattern
+* 1. Exact matching - Returns a response when the input exactly matches pattern
 * 2. Partial matching - Returns a response when the input contains a pattern
-* 3. Conditional matching - Returns a response when the input satisfies a condition
+* 3. Conditional matching - Returns a response when the input satisfies condition
 * 4. Default response - Returns a default response when no other matches are found
 *
 * It also supports tool calls and can be configured to return specific tool results.
@@ -138,7 +139,7 @@ internal class MockLLMExecutor(
     * 1. First checking for exact matches
     * 2. Then checking for partial matches
     * 3. Then checking for conditional matches
-    * 4. Finally returning the default response if no matches are found
+    * 4. Finally, returning the default response if no matches are found
     *
     * @param prompt The prompt to handle
     * @return The appropriate response based on the configured matches
@@ -147,8 +148,6 @@ internal class MockLLMExecutor(
        logger.debug { "Handling prompt with messages:" }
        prompt.messages.forEach { logger.debug { "Message content: ${it.content.take(300)}..." } }
 
-       val inputTokensCount = tokenizer?.let { prompt.messages.map { it.content }.sumOf(it::countTokens) }
-
        val lastMessage = getLastMessage(prompt) ?: return responseMatcher.defaultResponse
 
        // Check the exact response match
@@ -170,20 +169,19 @@ internal class MockLLMExecutor(
        }
 
        // Check request conditions
-       val conditionals = getConditionalResponse(lastMessage, inputTokensCount) ?: listOf()
+       val conditionals = getConditionalResponse(lastMessage) ?: listOf()
 
        val result = (exactMatchedResponse ?: listOf()) + partiallyMatchedResponse + conditionals
        if (result.any()) {
-           return result
+           return updateTokenCounts(result, lastMessage.content)
        }
 
        // Process the default LLM response
-       return responseMatcher.defaultResponse
+       return updateTokenCounts(responseMatcher.defaultResponse, lastMessage.content)
    }
 
    private fun getConditionalResponse(
        lastMessage: Message,
-       inputTokensCount: Int?
    ): List<Message.Response>? = if (!responseMatcher.conditional.isNullOrEmpty()) {
        responseMatcher.conditional.entries.firstOrNull { it.key(lastMessage.content) }?.let { (_, response) ->
            logger.debug { "Returning response for conditional match: $response" }
@@ -193,6 +191,51 @@ internal class MockLLMExecutor(
        emptyList()
    }
 
+   /**
+    * Updates the token counts in response metadata to use the input string.
+    */
+   private fun updateTokenCounts(
+       responses: List<Message.Response>,
+       input: String,
+   ): List<Message.Response> {
+       if (tokenizer == null) return responses
+
+       val inputTokenCount = tokenizer.countTokens(input)
+
+       return responses.map { response ->
+           when (response) {
+               is Message.Assistant -> {
+                   val outputTokenCount = tokenizer.countTokens(response.content)
+                   val updatedMetaInfo = ResponseMetaInfo.create(
+                       clock = clock,
+                       inputTokensCount = inputTokenCount,
+                       outputTokensCount = outputTokenCount,
+                       totalTokensCount = inputTokenCount + outputTokenCount
+                   )
+                   Message.Assistant(response.content, updatedMetaInfo)
+               }
+
+               is Message.Tool.Call -> {
+                   val outputTokenCount = tokenizer.countTokens(response.content)
+                   val updatedMetaInfo = ResponseMetaInfo.create(
+                       clock = clock,
+                       inputTokensCount = inputTokenCount,
+                       outputTokensCount = outputTokenCount,
+                       totalTokensCount = inputTokenCount + outputTokenCount
+                   )
+                   Message.Tool.Call(
+                       id = response.id,
+                       tool = response.tool,
+                       content = response.content,
+                       metaInfo = updatedMetaInfo
+                   )
+               }
+
+               else -> response // Keep other response types unchanged
+           }
+       }
+   }
+
    /*
    Additional helper functions
    */
```
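
Taken together, a test can now read usage counts straight off mock responses. A hedged usage sketch follows; the `getMockExecutor` entry point (whose clock/tokenizer/init parameters match the KDoc in MockLLMBuilder.kt above), its package, and the `mockLLMAnswer ... onRequestContains ...` DSL come from Koog's testing documentation, not from this diff, so treat them as assumptions:

```kotlin
import ai.koog.agents.testing.tools.getMockExecutor // assumed location, same package as MockLLMBuilder
import ai.koog.prompt.tokenizer.Tokenizer

// Deterministic tokenizer so expected counts are easy to compute by hand.
val testTokenizer = object : Tokenizer {
    override fun countTokens(text: String): Int = text.split(Regex("\\s+")).size
}

val executor = getMockExecutor(tokenizer = testTokenizer) {
    mockLLMAnswer("Paris") onRequestContains "capital of France"
}

// Via updateTokenCounts, every response the mock returns now carries:
//   metaInfo.inputTokensCount  = tokens in the last request message
//   metaInfo.outputTokensCount = tokens in the canned response
//   metaInfo.totalTokensCount  = input + output
```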
