Skip to content

Commit d8195eb

Browse files
authored
Merge pull request #10 from stg-tud/feature/structured-output
Feature/structured output
2 parents 5ae891d + ab87add commit d8195eb

14 files changed

Lines changed: 1617 additions & 55 deletions

File tree

engine/build.gradle.kts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,7 @@ dependencies {
99
implementation(libs.ktor.client.java)
1010
implementation(libs.ktor.client.content.negotiation)
1111
implementation(libs.ktor.serialization.json)
12+
testImplementation(kotlin("test"))
13+
testImplementation(libs.kotlinx.serialization.json)
14+
testImplementation(libs.kotlinx.coroutines.core)
1215
}

engine/src/main/kotlin/de/tuda/stg/securecoder/engine/llm/FilesInContextPromptBuilder.kt renamed to engine/src/main/kotlin/de/tuda/stg/securecoder/engine/file/FilesInContextPromptBuilder.kt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
package de.tuda.stg.securecoder.engine.llm
1+
package de.tuda.stg.securecoder.engine.file
22

33
import de.tuda.stg.securecoder.filesystem.FileSystem
44

5-
65
object FilesInContextPromptBuilder {
76
suspend fun build(files: Iterable<FileSystem.File>, edit: Boolean = false) = buildString {
87
if (files.count() == 0) {
98
appendLine("You have no files in the context.")
109
appendLine("If you saw files they are only part of the prompt and dont exists yet!")
1110
if (edit) {
12-
appendLine("You may create new files (keep in mind that searchedText needs to be empty in this case!)")
11+
appendLine("You may create new files (keep in mind that searched text needs to be empty in this case!)")
1312
}
1413
return@buildString
1514
}
@@ -20,4 +19,4 @@ object FilesInContextPromptBuilder {
2019
appendLine("<<<END FILE>>>")
2120
}
2221
}
23-
}
22+
}

engine/src/main/kotlin/de/tuda/stg/securecoder/engine/file/edit/EditFilesLlmWrapper.kt

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@ import de.tuda.stg.securecoder.engine.llm.ChatMessage.Role
66
import de.tuda.stg.securecoder.engine.llm.LlmClient
77
import de.tuda.stg.securecoder.filesystem.FileSystem
88
import de.tuda.stg.securecoder.engine.llm.ChatExchange
9-
import kotlinx.coroutines.flow.collect
10-
import kotlinx.coroutines.flow.map
11-
import kotlinx.coroutines.flow.toList
129
import kotlin.collections.plusAssign
1310

1411
class EditFilesLlmWrapper(
@@ -76,7 +73,7 @@ class EditFilesLlmWrapper(
7673
appendLine("It violated the required format.")
7774
appendLine("Errors:")
7875
messages.forEach { appendLine(it) }
79-
appendLine("Respond again with ONLY <EDITN> blocks that strictly follow the rules. Do NOT include prose, markdown, or explanations.")
76+
appendLine("Respond again with ONLY edit blocks that strictly follow the rules. Do NOT include prose, markdown, or explanations.")
8077
appendLine("IMPORTANT: Resend the COMPLETE set of edits you intend to apply from your previous message")
8178
}
8279
}
@@ -155,12 +152,7 @@ class EditFilesLlmWrapper(
155152
return ParseResult.Err(allErrors)
156153
}
157154

158-
val seen = HashSet<Pair<String, String>>()
159-
val deduped = results.filter { sr ->
160-
seen.add(sr.fileName to sr.searchedText.text)
161-
}
162-
163-
return ParseResult.Ok(Changes(deduped))
155+
return ParseResult.Ok(Changes(results))
164156
}
165157

166158
private fun getTextByXMLTag(container: String, tag: String): String? {
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package de.tuda.stg.securecoder.engine.file.edit
2+
3+
import de.tuda.stg.securecoder.engine.file.edit.Changes.SearchedText
4+
import de.tuda.stg.securecoder.engine.llm.ChatMessage
5+
import de.tuda.stg.securecoder.engine.llm.ChatMessage.Role
6+
import de.tuda.stg.securecoder.engine.llm.LlmClient
7+
import de.tuda.stg.securecoder.engine.llm.LLMDescription
8+
import de.tuda.stg.securecoder.engine.llm.chatStructured
9+
import de.tuda.stg.securecoder.filesystem.FileSystem
10+
import de.tuda.stg.securecoder.engine.llm.ChatExchange
11+
import kotlinx.serialization.Serializable
12+
import kotlinx.serialization.encodeToString
13+
import kotlinx.serialization.json.Json
14+
import kotlin.collections.plusAssign
15+
16+
/**
 * Requests file edits from an [LlmClient] as structured output ([StructuredEdits]) and validates
 * every edit against a [FileSystem] before handing the result back to the caller.
 *
 * When validation fails, the collected errors are appended to the conversation as a user message
 * and the model is asked again, up to a configurable number of attempts.
 */
class StructuredEditFilesLlmWrapper(
    private val llmClient: LlmClient
) {
    //TODO path => **uri** ; EditFilesLlmWrapper should be separate from the filesystem implementation
    private val prompt = """
        Your task it is to produce code. The agent will just parse the code you produce. So dont do a extensive review in your final answer!

        It's acceptable to add multiple *search/REPLACE* sections if you need to change multiple parts of the file.
        To create a file: search must be empty and replace must contain the entire file content
        Each *search* pattern must match the existing source code exactly once, line for line, character for character, including all comments, docstrings, etc.
        Do not use a part of the line as *search* pattern. You must use full lines.
        Include enough lines to make code inside *search* pattern uniquely identifiable. A *search* pattern that produces multiple matches in the source code will be rejected as an error.
        Do not add backslashes to escape special characters. Write the code exactly as it should appear in the intended programming language.
        Do not use git diff style (+ and - at the beginning of the line) for *search/REPLACE* blocks.
        Do not use line numbers in *search/REPLACE* blocks. Do not enclose the *search/REPLACE* block or any of its components in triple quotes. Use only tags to separate the parameters.
        Do not use the same value for *search* and *REPLACE* parameters, as this will make no changes.

        If you need to edit a file again after making changes, use the latest version of the code that includes all your modifications applied during **current session**.
    """.trimIndent()

    /**
     * Runs the edit conversation until the model returns a usable set of edits or [attempts] is exhausted.
     *
     * @param messages initial conversation; it is copied, the caller's list is never modified
     * @param fileSystem used to look up the current content of each edited file for validation
     * @param params generation parameters forwarded to the client unchanged
     * @param onParseError invoked after every rejected response with the validation errors and the
     *        exchange (LLM input plus the JSON-encoded response) that produced them
     * @param attempts maximum number of LLM calls; with a value of 0 or less the model is never called
     * @return the full conversation including all retries, and the accepted [Changes] or `null` when
     *         every attempt was rejected
     */
    suspend fun chat(
        messages: List<ChatMessage>,
        fileSystem: FileSystem,
        params: LlmClient.GenerationParams = LlmClient.GenerationParams(),
        onParseError: suspend (parseErrors: List<String>, llm: ChatExchange) -> Unit = { _, _ -> },
        attempts: Int = 3
    ): ChatResult {
        // Work on a private copy under a distinct name instead of shadowing the parameter.
        val conversation = messages.toMutableList()
        appendPromptToLastSystem(conversation)
        repeat(attempts) {
            val llmInput = conversation.toList()
            val structured = llmClient.chatStructured<StructuredEdits>(llmInput, params)
            // Keep the encoded response in a local: it is recorded as the assistant turn and,
            // on failure, reported as the output of this exchange.
            val llmOutput = Json.encodeToString(structured)
            conversation += ChatMessage(Role.Assistant, llmOutput)
            when (val result = validateAndConvert(structured, fileSystem)) {
                is ParseResult.Ok -> return ChatResult(conversation, result.value)
                is ParseResult.Err -> {
                    conversation += ChatMessage(Role.User, result.buildMessage())
                    // Report the model's response, not the correction message appended just above
                    // (which is what `conversation.last()` would be at this point).
                    // NOTE(review): assumes ChatExchange's second argument is the LLM output - confirm.
                    onParseError(result.messages, ChatExchange(llmInput, llmOutput))
                }
            }
        }
        return ChatResult(conversation, null)
    }

    /**
     * Outcome of [chat]: the complete message history and the accepted changes (`null` on failure).
     */
    data class ChatResult(val messages: List<ChatMessage>, val changes: Changes?) {
        /** Returns the last assistant message; throws if the history contains none. */
        fun changesMessage() = messages.last { it.role == Role.Assistant }
    }

    /** Result of validating one structured response. */
    sealed interface ParseResult {
        /** The response produced at least one applicable edit. */
        data class Ok(val value: Changes) : ParseResult

        /** The response was rejected; [messages] holds one entry per problem found. */
        data class Err(val messages: List<String>) : ParseResult {
            /** Builds the user-facing correction message sent back to the model. */
            fun buildMessage() = buildString {
                appendLine("Your previous output could not be applied.")
                appendLine("It violated the required format.")
                appendLine("Errors:")
                messages.forEach { appendLine(it) }
                appendLine("Respond again with ONLY edit blocks that strictly follow the rules. Do NOT include prose, markdown, or explanations.")
                appendLine("IMPORTANT: Resend the COMPLETE set of edits you intend to apply from your previous message")
            }
        }
    }

    /**
     * Converts the structured response into [Changes], checking each edit against the file system.
     *
     * An edit is rejected when its path is blank, when `search` equals `replace`, or when
     * [ApplyChanges.match] reports an error for the current file content.
     */
    private suspend fun validateAndConvert(structured: StructuredEdits, fileSystem: FileSystem): ParseResult {
        val accepted = mutableListOf<Changes.SearchReplace>()
        val allErrors = mutableListOf<String>()
        if (structured.edits.isEmpty()) {
            allErrors += "No edits provided. Provide at least one edit block."
            return ParseResult.Err(allErrors)
        }
        for (edit in structured.edits) {
            val file = edit.filePath.trim()
            val searchPart = edit.search
            val replacePart = edit.replace
            if (file.isEmpty()) {
                allErrors += "`filePath` should not be empty"
                continue
            }
            if (searchPart == replacePart) {
                allErrors += "`search` and `replace` parameters are the same"
                continue
            }
            val replace = Changes.SearchReplace(file, SearchedText(searchPart), replacePart)
            // `content` is null when the file system has no such file.
            val content = fileSystem.getFile(file)?.content()
            val match = ApplyChanges.match(content, replace.searchedText)
            if (match is Matcher.MatchResult.Error) {
                allErrors += ApplyChanges.buildErrorMessage(file, searchPart, match)
                continue
            }
            accepted += replace
        }
        // NOTE(review): a response is only rejected when *no* edit survived validation. If some edits
        // pass and others fail, the failures in `allErrors` are dropped and no retry happens.
        // Confirm this partial acceptance is intended.
        if (accepted.isEmpty()) return ParseResult.Err(allErrors)
        return ParseResult.Ok(Changes(accepted))
    }

    /**
     * Appends the edit instructions to the last system message, or adds a new system message
     * at the end of the conversation when none exists.
     */
    private fun appendPromptToLastSystem(messages: MutableList<ChatMessage>) {
        // Built once so both branches are guaranteed to send the same instruction text.
        val instructions = "$prompt\n\nRespond ONLY with a JSON object that matches the provided schema. Do not include explanations."
        val lastSystemIndex = messages.indexOfLast { it.role == Role.System }
        if (lastSystemIndex >= 0) {
            val existing = messages[lastSystemIndex]
            messages[lastSystemIndex] = ChatMessage(Role.System, "${existing.content}\n\n$instructions")
        } else {
            messages += ChatMessage(Role.System, instructions)
        }
    }

    /** Root object of the structured response requested from the model. */
    @Serializable
    data class StructuredEdits(
        @LLMDescription("List of edit operations to apply")
        val edits: List<EditOperation>
    )

    /** A single search/replace operation on one file. */
    @Serializable
    data class EditOperation(
        @LLMDescription("The full **uri** of the file that will be modified")
        val filePath: String,
        @LLMDescription("A continuous, yet concise block of lines to search for in the existing source code (*search* pattern). If this section is empty, the lines from `replace` will be added to the end of the file.")
        val search: String,
        @LLMDescription("The lines to replace the existing code found using `search`. If this section is empty, the lines specified in `search` will be removed.")
        val replace: String,
    )
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package de.tuda.stg.securecoder.engine.llm
2+
3+
import kotlinx.serialization.ExperimentalSerializationApi
4+
import kotlinx.serialization.KSerializer
5+
import kotlinx.serialization.descriptors.PolymorphicKind
6+
import kotlinx.serialization.descriptors.PrimitiveKind
7+
import kotlinx.serialization.descriptors.SerialDescriptor
8+
import kotlinx.serialization.descriptors.SerialKind
9+
import kotlinx.serialization.descriptors.StructureKind
10+
import kotlinx.serialization.json.JsonArray
11+
import kotlinx.serialization.json.JsonObject
12+
import kotlinx.serialization.json.JsonObjectBuilder
13+
import kotlinx.serialization.json.JsonPrimitive
14+
import kotlinx.serialization.json.buildJsonArray
15+
import kotlinx.serialization.json.buildJsonObject
16+
17+
/**
 * Derives a JSON-Schema-like [JsonObject] from a kotlinx.serialization [SerialDescriptor].
 *
 * Supported kinds: primitives, enums, lists, string-keyed maps, classes and objects.
 * Polymorphic, contextual, nullable and recursive types are rejected with an [IllegalStateException].
 * [LLMDescription] annotations on classes and properties are emitted as `description` entries.
 */
@OptIn(ExperimentalSerializationApi::class)
class KxJsonSchemaFormat {
    /**
     * Builds the schema for the type handled by [serializer].
     *
     * @throws IllegalStateException if the type (or any nested type) is not supported
     */
    fun <T> format(serializer: KSerializer<T>): JsonObject =
        schemaForDescriptor(serializer.descriptor, seen = HashSet())

    /**
     * Recursively converts [desc] into a schema object.
     *
     * @param seen serial names of the classes/objects currently being expanded on this path,
     *        used to detect self-referencing types
     */
    private fun schemaForDescriptor(desc: SerialDescriptor, seen: MutableSet<String>): JsonObject {
        val key = desc.serialName
        // Only classes and objects take part in recursion detection. Built-in collection
        // descriptors share one serial name per collection type regardless of their element type,
        // so tracking them would wrongly reject nested collections such as List<List<String>>.
        // Recursion through a collection is still caught when the enclosing class is reached again.
        // NOTE(review): generic classes nested in themselves (Box<Box<Int>>) may still share a
        // serial name and be reported as recursive - confirm if such types are needed.
        val tracksRecursion = desc.kind == StructureKind.CLASS || desc.kind == StructureKind.OBJECT
        if (tracksRecursion && !seen.add(key)) {
            throw IllegalStateException("Recursive type detected: $key")
        }
        val jsonType = when (desc.kind) {
            PrimitiveKind.BOOLEAN -> type("boolean")
            PrimitiveKind.BYTE, PrimitiveKind.SHORT, PrimitiveKind.INT, PrimitiveKind.LONG -> type("integer")
            PrimitiveKind.FLOAT, PrimitiveKind.DOUBLE -> type("number")
            PrimitiveKind.CHAR, PrimitiveKind.STRING -> type("string")
            SerialKind.ENUM -> type("string") {
                put("enum", buildJsonArray {
                    for (i in 0 until desc.elementsCount) {
                        add(JsonPrimitive(desc.getElementName(i)))
                    }
                })
            }
            StructureKind.LIST -> type("array") {
                put("items", schemaForDescriptor(desc.getElementDescriptor(0), seen))
            }
            StructureKind.MAP -> type("object") {
                // Element 0 is the key descriptor, element 1 the value descriptor.
                val keyDesc = desc.getElementDescriptor(0)
                if (keyDesc.kind != PrimitiveKind.STRING) {
                    throw IllegalStateException("Map keys must be strings, but was ${keyDesc.serialName}")
                }
                put("additionalProperties", schemaForDescriptor(desc.getElementDescriptor(1), seen))
            }
            StructureKind.CLASS, StructureKind.OBJECT -> type("object") {
                put("properties", buildJsonObject {
                    for (i in 0 until desc.elementsCount) {
                        val name = desc.getElementName(i)
                        val childSchema = schemaForDescriptor(desc.getElementDescriptor(i), seen)
                        val propDesc = getDescription(desc.getElementAnnotations(i))
                        put(name, if (propDesc != null) addDescription(childSchema, propDesc) else childSchema)
                    }
                })
                val required = desc.requiredElements()
                if (required.isNotEmpty()) {
                    put("required", JsonArray(required.map { JsonPrimitive(it) }))
                }
                put("additionalProperties", JsonPrimitive(false))
            }
            PolymorphicKind.SEALED, PolymorphicKind.OPEN, SerialKind.CONTEXTUAL
                -> throw IllegalStateException("Polymorphic types are not supported")
        }
        // Leaving this descriptor: siblings of the same class are allowed, only nesting is not.
        if (tracksRecursion) seen.remove(key)
        if (desc.isNullable) {
            throw IllegalStateException("Nullable types are not supported")
        }
        val selfDesc = getDescription(desc.annotations)
        return if (selfDesc != null) addDescription(jsonType, selfDesc) else jsonType
    }

    /** Creates `{"type": name, ...}` with extra entries contributed by [builderAction]. */
    private fun type(name: String, builderAction: JsonObjectBuilder.() -> Unit = {}): JsonObject =
        buildJsonObject {
            put("type", JsonPrimitive(name))
            builderAction()
        }

    /** Names of all elements that are not marked optional, in declaration order. */
    private fun SerialDescriptor.requiredElements(): List<String> = (0 until elementsCount)
        .filter { !isElementOptional(it) }
        .map { getElementName(it) }

    /** Text of the first [LLMDescription] among [annotations], or `null` if there is none. */
    private fun getDescription(annotations: List<Annotation>): String? =
        annotations.filterIsInstance<LLMDescription>().firstOrNull()?.text

    /** Returns a copy of [obj] with `description` set to [text] (replacing an existing entry). */
    private fun addDescription(obj: JsonObject, text: String): JsonObject =
        buildJsonObject {
            obj.forEach { (k, v) -> put(k, v) }
            put("description", JsonPrimitive(text))
        }
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package de.tuda.stg.securecoder.engine.llm
2+
3+
import kotlinx.serialization.ExperimentalSerializationApi
4+
import kotlinx.serialization.SerialInfo
5+
6+
/**
 * Attaches a human-readable description to a serializable class or property.
 *
 * Being a [SerialInfo] annotation, it is available through the serial descriptor's annotation
 * lists; `KxJsonSchemaFormat` reads [text] from there and emits it as the `description`
 * entry of the generated schema.
 *
 * @property text the description to expose
 */
@SerialInfo
@OptIn(ExperimentalSerializationApi::class)
@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.PROPERTY, AnnotationTarget.CLASS)
annotation class LLMDescription(val text: String)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
11
package de.tuda.stg.securecoder.engine.llm
22

3+
import kotlinx.serialization.KSerializer
4+
import kotlinx.serialization.serializer
5+
36
/**
 * Client abstraction for talking to a chat-style LLM.
 *
 * Implementations are [AutoCloseable]; callers are expected to close the client when done.
 */
interface LlmClient : AutoCloseable {
    /**
     * Optional generation settings passed along with a request.
     *
     * NOTE(review): `null` presumably means "use the backend default" - confirm with implementations.
     */
    data class GenerationParams(
        val temperature: Double? = null,
        val maxTokens: Int? = null,
    )

    /**
     * Sends [messages] to the model and returns its reply as plain text.
     */
    suspend fun chat(
        messages: List<ChatMessage>,
        params: GenerationParams = GenerationParams(),
    ): String

    /**
     * Sends [messages] to the model and returns its reply as a value of type [T].
     *
     * @param serializer serializer describing the expected response type
     */
    suspend fun <T> chatStructured(
        messages: List<ChatMessage>,
        serializer: KSerializer<T>,
        params: GenerationParams = GenerationParams(),
    ): T
}
23+
24+
/**
 * Convenience overload of [LlmClient.chatStructured] that resolves the serializer for [T]
 * from the reified type argument, so callers only pass the messages and optional parameters.
 */
suspend inline fun <reified T> LlmClient.chatStructured(
    messages: List<ChatMessage>,
    params: LlmClient.GenerationParams = LlmClient.GenerationParams(),
): T {
    val resolvedSerializer: KSerializer<T> = serializer()
    return chatStructured(messages, resolvedSerializer, params)
}

0 commit comments

Comments
 (0)