Skip to content

Commit eebdbf1

Browse files
authored
Merge branch 'master' into master
2 parents 33e9a90 + c4df981 commit eebdbf1

File tree

7 files changed

+211
-1
lines changed

7 files changed

+211
-1
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
message: >
2+
**ai-proxy**: Fixed an issue where OpenAI chat completion's tool_choice was not converted to Anthropic's format.
3+
type: bugfix
4+
scope: Plugin

kong/llm/drivers/anthropic.lua

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,33 @@ local function to_tools(in_tools)
128128
return out_tools
129129
end
130130

131+
--- Convert an OpenAI-format `tool_choice` value to the Anthropic equivalent.
-- See https://docs.anthropic.com/en/api/messages#body-tool-choice and
-- https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
-- @param openai_tool_choice string|table tool_choice as sent in an OpenAI chat request
-- @return table|nil Anthropic-format tool_choice, or nil (with a warning logged)
--   when the input is not a recognized OpenAI tool_choice value
local function to_tool_choice(openai_tool_choice)
  if type(openai_tool_choice) == "string" then
    if openai_tool_choice == "required" then
      -- OpenAI "required" (model must call some tool) maps to Anthropic "any"
      return {type = "any"}
    elseif openai_tool_choice == "none" or openai_tool_choice == "auto" then
      -- "none" and "auto" have the same meaning on both sides
      return {type = openai_tool_choice}
    else
      kong.log.warn("invalid tool choice string: ", openai_tool_choice, ", expected 'required', 'none', or 'auto'")
      return nil
    end
  end

  if type(openai_tool_choice) == "table" then
    -- guard the "function" field before indexing into it: a malformed request
    -- such as {type = "function"} must be logged and rejected, not raise an error
    local fn = openai_tool_choice["function"]
    if openai_tool_choice.type == "function" and type(fn) == "table" and fn.name then
      return {type = "tool", name = fn.name}
    end

    kong.log.warn("invalid tool choice table: ", cjson.encode(openai_tool_choice))
    return nil
  end

  kong.log.warn("invalid tool choice type: ", type(openai_tool_choice), ", expected string or table")
  return nil
end
157+
131158
local transformers_to = {
132159
["llm/v1/chat"] = function(request_table, model)
133160
local messages = {}
@@ -145,7 +172,7 @@ local transformers_to = {
145172

146173
-- handle function calling translation from OpenAI format
147174
messages.tools = request_table.tools and to_tools(request_table.tools)
148-
messages.tool_choice = request_table.tool_choice
175+
messages.tool_choice = request_table.tool_choice and to_tool_choice(request_table.tool_choice)
149176

150177
return messages, "application/json", nil
151178
end,

spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,27 @@ for _, strategy in helpers.all_strategies() do
147147
}
148148
}
149149
150+
location = "/llm/v1/chat/tool_choice" {
  content_by_lua_block {
    local pl_file = require "pl.file"
    local json = require("cjson.safe")

    ngx.req.read_body()

    -- fail the request with a 500 (echoing the error) unless `value` is truthy
    local function must(value, err)
      if not value then
        ngx.status = 500
        ngx.say(err)
        ngx.exit(ngx.HTTP_INTERNAL_SERVER_ERROR)
      end
      return value
    end

    -- decode the proxied request body and mirror its tool_choice back to the
    -- test client in a response header, so the spec can inspect the conversion
    local decoded = must(json.decode(must(ngx.req.get_body_data())))
    ngx.header["tool-choice"] = json.encode(decoded.tool_choice)

    -- reply with a canned good completion so the plugin's response path succeeds
    ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/good.json"))
  }
}
150171
151172
location = "/llm/v1/completions/good" {
152173
content_by_lua_block {
@@ -374,6 +395,36 @@ for _, strategy in helpers.all_strategies() do
374395
}
375396
--
376397

398+
-- route + plugin pair backing the tool_choice conversion test: proxies to the
-- /llm/v1/chat/tool_choice mock, which echoes the converted tool_choice back
local chat_tool_choice = assert(bp.routes:insert {
  paths = { "/anthropic/llm/v1/chat/tool_choice" },
  protocols = { "http" },
  strip_path = true,
  service = empty_service,
})
bp.plugins:insert {
  name = PLUGIN_NAME,
  route = { id = chat_tool_choice.id },
  config = {
    route_type = "llm/v1/chat",
    auth = {
      header_name = "x-api-key",
      header_value = "anthropic-key",
    },
    model = {
      provider = "anthropic",
      name = "claude-2.1",
      options = {
        max_tokens = 256,
        temperature = 1.0,
        upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/chat/tool_choice",
        anthropic_version = "2023-06-01",
      },
    },
  },
}
--
427+
377428
-- 401 unauthorized
378429
local chat_401 = assert(bp.routes:insert {
379430
service = empty_service,
@@ -740,6 +791,42 @@ for _, strategy in helpers.all_strategies() do
740791
local json = cjson.decode(body)
741792
assert.is_truthy(deepcompare(json.usage, {}))
742793
end)
794+
795+
it("tool_choice conversion", function()
  -- POST the fixture at `input` through the plugin and return the tool_choice
  -- the mock upstream saw (echoed back in the "tool-choice" response header),
  -- or nil when the header is absent
  local function post_and_capture_tool_choice(input)
    -- rewrite the model so we can reuse the same test fixture with different models
    local request = cjson.decode(pl_file.read(input))
    request.model = "claude-2.1" -- anthropic model name

    local r = client:post("/anthropic/llm/v1/chat/tool_choice", {
      headers = {
        ["content-type"] = "application/json",
        ["accept"] = "application/json",
      },
      body = cjson.encode(request),
    })
    r:read_body()

    local sent = r.headers["tool-choice"]
    if sent then
      return cjson.decode(sent)
    end
    return nil
  end

  -- one case per OpenAI tool_choice form, paired with its expected Anthropic shape
  local cases = {
    {input = "spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/tool_choice_auto.json",
     output = {type = "auto"}},
    {input = "spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/tool_choice_none.json",
     output = {type = "none"}},
    {input = "spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/tool_choice_required.json",
     output = {type = "any"}},
    {input = "spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/tool_choice_object_function.json",
     output = {type = "tool", name = "my_function"}},
  }

  for _, case in ipairs(cases) do
    assert.same(case.output, post_and_capture_tool_choice(case.input))
  end
end)
743830
end)
744831

745832
describe("anthropic llm/v1/completions", function()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{
2+
"model": "gpt-4.1",
3+
"messages": [
4+
{
5+
"role": "user",
6+
"content": "What is the weather like in Boston today?"
7+
}
8+
],
9+
"tools": [
10+
{
11+
"type": "function",
12+
"function": {
13+
"name": "get_current_weather",
14+
"description": "Get the current weather in a given location",
15+
"parameters": {
16+
"type": "object",
17+
"properties": {
18+
"location": {
19+
"type": "string",
20+
"description": "The city and state, e.g. San Francisco, CA"
21+
},
22+
"unit": {
23+
"type": "string",
24+
"enum": ["celsius", "fahrenheit"]
25+
}
26+
},
27+
"required": ["location"]
28+
}
29+
}
30+
}
31+
],
32+
"tool_choice": "auto"
33+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"model": "gpt-4.1",
3+
"messages": [
4+
{
5+
"role": "user",
6+
"content": "What is the weather like in Boston today?"
7+
}
8+
],
9+
"tool_choice": "none"
10+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
{
2+
"model": "gpt-4-0613",
3+
"messages": [
4+
{
5+
"role": "system",
6+
"content": "You are an AI assistant."
7+
},
8+
{
9+
"role": "user",
10+
"content": "Tell me a joke."
11+
}
12+
],
13+
"tools": [
14+
{
15+
"type": "function",
16+
"function": {
17+
"name": "my_function",
18+
"parameters": {}
19+
}
20+
}
21+
],
22+
"tool_choice": {
23+
"type": "function",
24+
"function": {
25+
"name": "my_function"
26+
}
27+
}
28+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"model": "gpt-4-0613",
3+
"messages": [
4+
{
5+
"role": "user",
6+
"content": "What is the weather today in Paris?"
7+
}
8+
],
9+
"tools": [
10+
{
11+
"type": "function",
12+
"function": {
13+
"name": "get_current_weather",
14+
"parameters": {
15+
"location": "Paris"
16+
}
17+
}
18+
}
19+
],
20+
"tool_choice": "required"
21+
}

0 commit comments

Comments
 (0)