
Commit b209dc1

Allow conversations to have "previous messages"

This commit modifies the `Chat` method (in a backwards-compatible way) to allow users to start a conversation with previously exchanged messages, i.e. messages already exchanged between the user and the assistant. This practically allows "loading" previous conversations in order to continue them. To support this, the `Chat` method now accepts a variadic list of `types.Message` values. The messages field of each backend's `Conversation` type is now exported as well, so that users may save it.

1 parent ae51057
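For illustration, here is a minimal sketch (not part of this commit) of how the new variadic parameter might be used to save and resume a conversation. The module path, model name, file path, and backend construction are placeholders; the type assertion to `*openai.Conversation` assumes the concrete type returned by `Chat` in this commit, since the `types.Conversation` interface itself does not expose `Messages`.

```go
// Hypothetical usage sketch. Import paths and model name are illustrative.
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"os"

	"github.com/gofireflyio/aiac/libaiac/openai" // illustrative import path
	"github.com/gofireflyio/aiac/libaiac/types"  // illustrative import path
)

func main() {
	var backend *openai.OpenAI // assumed to be constructed/configured elsewhere

	// Start a conversation and exchange one message.
	conv := backend.Chat("gpt-4o").(*openai.Conversation)
	if _, err := conv.Send(context.Background(), "Generate a Dockerfile for nginx"); err != nil {
		panic(err)
	}

	// Messages is now exported, so the history can be persisted.
	saved, _ := json.Marshal(conv.Messages)
	_ = os.WriteFile("conversation.json", saved, 0o644)

	// Later: load the saved history and continue where we left off.
	data, _ := os.ReadFile("conversation.json")
	var history []types.Message
	_ = json.Unmarshal(data, &history)

	resumed := backend.Chat("gpt-4o", history...)
	res, err := resumed.Send(context.Background(), "Now add a healthcheck")
	if err != nil {
		panic(err)
	}
	fmt.Println(res.FullOutput)
}
```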

4 files changed: +81 −29 lines

libaiac/bedrock/chat.go (34 additions, 9 deletions)

@@ -13,19 +13,44 @@ import (
 // Conversation is a struct used to converse with a Bedrock chat model. It
 // maintains all messages sent/received in order to maintain context.
 type Conversation struct {
-    backend  *Bedrock
-    model    string
-    messages []bedrocktypes.Message
+    // Messages is the list of all messages exchanged between the user and the
+    // assistant.
+    Messages []bedrocktypes.Message
+
+    backend *Bedrock
+    model   string
 }
 
 // Chat initiates a conversation with a Bedrock chat model. A conversation
 // maintains context, allowing to send further instructions to modify the output
-// from previous requests.
-func (backend *Bedrock) Chat(model string) types.Conversation {
-    return &Conversation{
+// from previous requests. The name of the model to use must be provided. Users
+// can also supply zero or more "previous messages" that may have been exchanged
+// in the past. This practically allows "loading" previous conversations and
+// continuing them.
+func (backend *Bedrock) Chat(model string, msgs ...types.Message) types.Conversation {
+    chat := &Conversation{
         backend: backend,
         model:   model,
     }
+
+    if len(msgs) > 0 {
+        chat.Messages = make([]bedrocktypes.Message, len(msgs))
+        for i := range msgs {
+            role := bedrocktypes.ConversationRoleUser
+            if msgs[i].Role == "assistant" {
+                role = bedrocktypes.ConversationRoleAssistant
+            }
+
+            chat.Messages[i] = bedrocktypes.Message{
+                Role: role,
+                Content: []bedrocktypes.ContentBlock{
+                    &bedrocktypes.ContentBlockMemberText{Value: msgs[i].Content},
+                },
+            }
+        }
+    }
+
+    return chat
 }
 
 // Send sends the provided message to the backend and returns a Response object.
@@ -36,7 +61,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
     res types.Response,
     err error,
 ) {
-    conv.messages = append(conv.messages, bedrocktypes.Message{
+    conv.Messages = append(conv.Messages, bedrocktypes.Message{
         Role: bedrocktypes.ConversationRoleUser,
         Content: []bedrocktypes.ContentBlock{
             &bedrocktypes.ContentBlockMemberText{Value: prompt},
@@ -45,7 +70,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
 
     input := bedrockruntime.ConverseInput{
         ModelId:  aws.String(conv.model),
-        Messages: conv.messages,
+        Messages: conv.Messages,
         InferenceConfig: &bedrocktypes.InferenceConfiguration{
             Temperature: aws.Float32(0.2),
         },
@@ -76,7 +101,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
     res.TokensUsed = int64(*output.Usage.TotalTokens)
     res.StopReason = string(output.StopReason)
 
-    conv.messages = append(conv.messages, outputMsg)
+    conv.Messages = append(conv.Messages, outputMsg)
 
     if res.Code, ok = types.ExtractCode(res.FullOutput); !ok {
         res.Code = res.FullOutput
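A note on the Bedrock conversion above: any loaded message whose role is not "assistant" is mapped to the user role, since Bedrock's Converse API only distinguishes user and assistant turns.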

libaiac/ollama/chat.go (21 additions, 9 deletions)

@@ -11,9 +11,12 @@ import (
 // Conversation is a struct used to converse with an Ollama chat model. It
 // maintains all messages sent/received in order to maintain context.
 type Conversation struct {
-    backend  *Ollama
-    model    string
-    messages []types.Message
+    // Messages is the list of all messages exchanged between the user and the
+    // assistant.
+    Messages []types.Message
+
+    backend *Ollama
+    model   string
 }
 
 type chatResponse struct {
@@ -23,12 +26,21 @@ type chatResponse struct {
 
 // Chat initiates a conversation with an Ollama chat model. A conversation
 // maintains context, allowing to send further instructions to modify the output
-// from previous requests.
-func (backend *Ollama) Chat(model string) types.Conversation {
-    return &Conversation{
+// from previous requests. The name of the model to use must be provided. Users
+// can also supply zero or more "previous messages" that may have been exchanged
+// in the past. This practically allows "loading" previous conversations and
+// continuing them.
+func (backend *Ollama) Chat(model string, msgs ...types.Message) types.Conversation {
+    chat := &Conversation{
         backend: backend,
         model:   model,
     }
+
+    if len(msgs) > 0 {
+        chat.Messages = msgs
+    }
+
+    return chat
 }
 
 // Send sends the provided message to the API and returns a Response object.
@@ -41,15 +53,15 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
 ) {
     var answer chatResponse
 
-    conv.messages = append(conv.messages, types.Message{
+    conv.Messages = append(conv.Messages, types.Message{
         Role:    "user",
         Content: prompt,
     })
 
     err = conv.backend.NewRequest("POST", "/chat").
         JSONBody(map[string]interface{}{
             "model":    conv.model,
-            "messages": conv.messages,
+            "messages": conv.Messages,
             "options": map[string]interface{}{
                 "temperature": 0.2,
             },
@@ -61,7 +73,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
         return res, fmt.Errorf("failed sending prompt: %w", err)
     }
 
-    conv.messages = append(conv.messages, answer.Message)
+    conv.Messages = append(conv.Messages, answer.Message)
 
     res.FullOutput = strings.TrimSpace(answer.Message.Content)
     if answer.Done {

libaiac/openai/chat.go (21 additions, 9 deletions)

@@ -12,9 +12,12 @@ import (
 // maintains all messages sent/received in order to maintain context just like
 // using ChatGPT.
 type Conversation struct {
-    backend  *OpenAI
-    model    string
-    messages []types.Message
+    // Messages is the list of all messages exchanged between the user and the
+    // assistant.
+    Messages []types.Message
+
+    backend *OpenAI
+    model   string
 }
 
 type chatResponse struct {
@@ -30,12 +33,21 @@ type chatResponse struct {
 
 // Chat initiates a conversation with an OpenAI chat model. A conversation
 // maintains context, allowing to send further instructions to modify the output
-// from previous requests, just like using the ChatGPT website.
-func (backend *OpenAI) Chat(model string) types.Conversation {
-    return &Conversation{
+// from previous requests, just like using the ChatGPT website. The name of the
+// model to use must be provided. Users can also supply zero or more "previous
+// messages" that may have been exchanged in the past. This practically allows
+// "loading" previous conversations and continuing them.
+func (backend *OpenAI) Chat(model string, msgs ...types.Message) types.Conversation {
+    chat := &Conversation{
         backend: backend,
         model:   model,
     }
+
+    if len(msgs) > 0 {
+        chat.Messages = msgs
+    }
+
+    return chat
 }
 
 // Send sends the provided message to the API and returns a Response object.
@@ -48,7 +60,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
 ) {
     var answer chatResponse
 
-    conv.messages = append(conv.messages, types.Message{
+    conv.Messages = append(conv.Messages, types.Message{
         Role:    "user",
         Content: prompt,
     })
@@ -62,7 +74,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
         NewRequest("POST", fmt.Sprintf("/chat/completions%s", apiVersion)).
         JSONBody(map[string]interface{}{
             "model":       conv.model,
-            "messages":    conv.messages,
+            "messages":    conv.Messages,
             "temperature": 0.2,
         }).
         Into(&answer).
@@ -75,7 +87,7 @@ func (conv *Conversation) Send(ctx context.Context, prompt string) (
         return res, types.ErrNoResults
     }
 
-    conv.messages = append(conv.messages, answer.Choices[0].Message)
+    conv.Messages = append(conv.Messages, answer.Choices[0].Message)
 
     res.FullOutput = strings.TrimSpace(answer.Choices[0].Message.Content)
     res.APIKeyUsed = conv.backend.apiKey

libaiac/types/interfaces.go (5 additions, 2 deletions)

@@ -8,8 +8,11 @@ type Backend interface {
     // ListModels returns a list of all models supported by the backend.
     ListModels(context.Context) ([]string, error)
 
-    // Chat initiates a conversation with an LLM backend.
-    Chat(string) Conversation
+    // Chat initiates a conversation with an LLM backend. The name of the model
+    // to use must be provided. Users can also supply zero or more "previous
+    // messages" that may have been exchanged in the past. This practically
+    // allows "loading" previous conversations and continuing them.
+    Chat(string, ...Message) Conversation
 }
 
 // Conversation is an interface that must be implemented in order to support
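Because `Chat` is part of the `Backend` interface, resuming a conversation works uniformly across all three backends. A hypothetical helper (not part of this commit; package name and import path are illustrative) could look like this:

```go
package chatutil // hypothetical helper package, not part of this commit

import "github.com/gofireflyio/aiac/libaiac/types" // illustrative import path

// ResumeChat resumes a conversation on any backend. With an empty history it
// behaves exactly like the old single-argument Chat call, which is what keeps
// this change backwards compatible.
func ResumeChat(b types.Backend, model string, history []types.Message) types.Conversation {
	return b.Chat(model, history...)
}
```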
