Skip to content

Commit c61285d

Browse files
committed
Enhance API functionality and improve UI components
- Integrated OpenAI client into the API handler, allowing for dynamic model and API key usage in queries. - Updated the configuration to return the active AI provider and default model for better status reporting. - Refactored query handling to support optional API key and model parameters in requests. - Enhanced the ChatConsole component with session management features, including renaming and closing threads. - Improved the ProjectHeaderBar and TopBar components to support theme switching and project management. - Added new themes and refined the theme management system for better user experience. These changes aim to enhance the overall functionality and user experience of the demo application, making it more flexible and user-friendly.
1 parent 863cc22 commit c61285d

21 files changed

Lines changed: 1683 additions & 120 deletions

demo/go/internal/api/handler.go

Lines changed: 34 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/rithulkamesh/docproc/demo/internal/grade"
1212
"github.com/rithulkamesh/docproc/demo/internal/mq"
1313
"github.com/rithulkamesh/docproc/demo/internal/rag"
14+
"github.com/sashabaranov/go-openai"
1415
)
1516

1617
type Handler struct {
@@ -67,14 +68,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
6768

6869
func (h *Handler) status(w http.ResponseWriter, r *http.Request) {
6970
writeJSON(w, map[string]any{
70-
"ok": true,
71-
"rag_backend": "embedding",
72-
"rag_configured": h.rag != nil,
73-
"database_provider": "pgvector",
74-
"primary_ai": "openai",
75-
"namespace": "default",
76-
"default_rag_model": h.cfg.OpenAIModel,
77-
"embedding_deployment": nil,
71+
"ok": true,
72+
"rag_backend": "embedding",
73+
"rag_configured": h.rag != nil,
74+
"database_provider": "pgvector",
75+
"primary_ai": h.cfg.PrimaryAI(),
76+
"namespace": "default",
77+
"default_rag_model": h.cfg.DefaultRAGModel(),
78+
"embedding_deployment": nil,
7879
})
7980
}
8081

@@ -84,8 +85,11 @@ func (h *Handler) embedCheck(w http.ResponseWriter, r *http.Request) {
8485

8586
func (h *Handler) query(w http.ResponseWriter, r *http.Request) {
8687
var body struct {
87-
Query string `json:"query"`
88-
Prompt string `json:"prompt"`
88+
Query string `json:"query"`
89+
Prompt string `json:"prompt"`
90+
APIKey string `json:"api_key"`
91+
Provider string `json:"provider"`
92+
Model string `json:"model"`
8993
}
9094
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
9195
writeError(w, "invalid JSON", http.StatusBadRequest)
@@ -99,11 +103,17 @@ func (h *Handler) query(w http.ResponseWriter, r *http.Request) {
99103
writeError(w, "missing query or prompt", http.StatusBadRequest)
100104
return
101105
}
106+
// RAG is required for embeddings and retrieval; api_key/model in body override chat only
102107
if h.rag == nil {
103-
writeJSON(w, map[string]any{"answer": "RAG not configured. Set OPENAI_API_KEY or AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT in .env.", "sources": []any{}})
108+
writeJSON(w, map[string]any{"answer": "RAG not configured. Set OPENAI_API_KEY or AZURE_OPENAI_* in .env.", "sources": []any{}})
104109
return
105110
}
106-
answer, sources, err := h.rag.Query(r.Context(), q)
111+
var chatClient *openai.Client
112+
model := strings.TrimSpace(body.Model)
113+
if body.APIKey != "" {
114+
chatClient = openai.NewClient(strings.TrimSpace(body.APIKey))
115+
}
116+
answer, sources, err := h.rag.QueryWithClient(r.Context(), q, chatClient, model)
107117
if err != nil {
108118
writeError(w, err.Error(), http.StatusInternalServerError)
109119
return
@@ -124,8 +134,11 @@ func (h *Handler) query(w http.ResponseWriter, r *http.Request) {
124134

125135
func (h *Handler) queryStream(w http.ResponseWriter, r *http.Request) {
126136
var body struct {
127-
Query string `json:"query"`
128-
Prompt string `json:"prompt"`
137+
Query string `json:"query"`
138+
Prompt string `json:"prompt"`
139+
APIKey string `json:"api_key"`
140+
Provider string `json:"provider"`
141+
Model string `json:"model"`
129142
}
130143
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
131144
writeError(w, "invalid JSON", http.StatusBadRequest)
@@ -140,7 +153,7 @@ func (h *Handler) queryStream(w http.ResponseWriter, r *http.Request) {
140153
return
141154
}
142155
if h.rag == nil {
143-
writeError(w, "RAG not configured", http.StatusServiceUnavailable)
156+
writeError(w, "RAG not configured. Set OPENAI_API_KEY or AZURE_OPENAI_* in .env.", http.StatusServiceUnavailable)
144157
return
145158
}
146159
prompt, sources, err := h.rag.GetContextForQuery(r.Context(), q)
@@ -172,7 +185,12 @@ func (h *Handler) queryStream(w http.ResponseWriter, r *http.Request) {
172185
if f, ok := w.(http.Flusher); ok {
173186
f.Flush()
174187
}
175-
if err := h.rag.StreamCompletion(ctx, prompt, w); err != nil {
188+
var streamClient *openai.Client
189+
model := strings.TrimSpace(body.Model)
190+
if body.APIKey != "" {
191+
streamClient = openai.NewClient(strings.TrimSpace(body.APIKey))
192+
}
193+
if err := h.rag.StreamCompletionWithClient(ctx, prompt, w, streamClient, model); err != nil {
176194
return
177195
}
178196
}

demo/go/internal/config/config.go

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,28 @@ func (c *Config) HasAI() bool {
4747
return c.OpenAIKey != "" || (c.AzureAPIKey != "" && c.AzureEndpoint != "")
4848
}
4949

50+
// PrimaryAI returns the active provider id for status: "openai" or "azure".
51+
func (c *Config) PrimaryAI() string {
52+
if c.OpenAIKey != "" {
53+
return "openai"
54+
}
55+
if c.AzureAPIKey != "" && c.AzureEndpoint != "" {
56+
return "azure"
57+
}
58+
return "openai"
59+
}
60+
61+
// DefaultRAGModel returns the chat model name used for RAG (OpenAI or Azure deployment).
62+
func (c *Config) DefaultRAGModel() string {
63+
if c.OpenAIKey != "" {
64+
return c.OpenAIModel
65+
}
66+
if c.AzureAPIKey != "" && c.AzureEndpoint != "" {
67+
return c.AzureDeployment
68+
}
69+
return c.OpenAIModel
70+
}
71+
5072
// AIClient returns an OpenAI-compatible client and model names (chat, embedding) using the default provider:
5173
// OPENAI_API_KEY if set, else AZURE_OPENAI_* if set. Returns (nil, "", "") when neither is configured.
5274
func (c *Config) AIClient() (client *openai.Client, chatModel, embeddingModel string) {

demo/go/internal/rag/rag.go

Lines changed: 32 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -73,12 +73,26 @@ func (r *RAG) DeleteByDocumentID(ctx context.Context, documentID string) error {
7373
}
7474

7575
func (r *RAG) Query(ctx context.Context, question string) (answer string, sources []map[string]interface{}, err error) {
76+
return r.QueryWithClient(ctx, question, nil, "")
77+
}
78+
79+
// QueryWithClient runs RAG query using optional client and model override.
80+
// If client is nil, uses r.client; if chatModel is empty, uses r.chatModel.
81+
func (r *RAG) QueryWithClient(ctx context.Context, question string, client *openai.Client, chatModel string) (answer string, sources []map[string]interface{}, err error) {
82+
chatClient := r.client
83+
model := r.chatModel
84+
if client != nil {
85+
chatClient = client
86+
}
87+
if chatModel != "" {
88+
model = chatModel
89+
}
7690
prompt, sources, err := r.GetContextForQuery(ctx, question)
7791
if err != nil {
7892
return "", sources, err
7993
}
80-
chatResp, err := r.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
81-
Model: r.chatModel,
94+
chatResp, err := chatClient.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
95+
Model: model,
8296
Messages: []openai.ChatCompletionMessage{
8397
{Role: openai.ChatMessageRoleUser, Content: prompt},
8498
},
@@ -148,8 +162,22 @@ Answer in a direct, helpful way. Do not add sign-offs or closing phrases (e.g. "
148162
}
149163

150164
func (r *RAG) StreamCompletion(ctx context.Context, prompt string, w io.Writer) error {
151-
stream, err := r.client.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
152-
Model: r.chatModel,
165+
return r.StreamCompletionWithClient(ctx, prompt, w, nil, "")
166+
}
167+
168+
// StreamCompletionWithClient streams chat completion with optional client and model override.
169+
// If client is nil, uses r.client; if chatModel is empty, uses r.chatModel.
170+
func (r *RAG) StreamCompletionWithClient(ctx context.Context, prompt string, w io.Writer, client *openai.Client, chatModel string) error {
171+
chatClient := r.client
172+
model := r.chatModel
173+
if client != nil {
174+
chatClient = client
175+
}
176+
if chatModel != "" {
177+
model = chatModel
178+
}
179+
stream, err := chatClient.CreateChatCompletionStream(ctx, openai.ChatCompletionRequest{
180+
Model: model,
153181
Messages: []openai.ChatCompletionMessage{
154182
{Role: openai.ChatMessageRoleUser, Content: prompt},
155183
},

demo/web/src/App.tsx

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import { Routes, Route } from 'react-router-dom'
22
import './App.css'
33
import { WorkspaceProvider } from '@/context/WorkspaceContext'
4+
import { AIProviderProvider } from '@/context/AIProviderContext'
45
import { AppShell } from '@/components/shell/AppShell'
56
import { CommandPalette } from '@/components/shell/CommandPalette'
67
import { KnowledgeCanvas } from '@/components/canvas/KnowledgeCanvas'
@@ -14,7 +15,8 @@ import { AssessmentSubmissionsView } from '@/views/AssessmentSubmissionsView'
1415
function App() {
1516
return (
1617
<WorkspaceProvider>
17-
<Routes>
18+
<AIProviderProvider>
19+
<Routes>
1820
<Route
1921
path="/"
2022
element={
@@ -75,6 +77,7 @@ function App() {
7577
}
7678
/>
7779
</Routes>
80+
</AIProviderProvider>
7881
</WorkspaceProvider>
7982
)
8083
}

demo/web/src/api/query.ts

Lines changed: 30 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,35 @@
11
import { apiClient } from './client'
22
import type { RagResponse, RagSource } from '../types'
33

4+
/** Optional per-request AI config. Sent to backend; backend uses for chat when provided. */
5+
export interface AIRequestConfig {
6+
api_key?: string | null
7+
model?: string | null
8+
provider?: string | null
9+
}
10+
411
interface QueryRequestBody {
512
prompt: string
13+
query?: string
614
top_k?: number
715
model?: string | null
16+
api_key?: string | null
17+
provider?: string | null
818
}
919

10-
export async function runQuery(prompt: string, topK = 10, model?: string | null): Promise<RagResponse> {
11-
return apiClient.post<RagResponse>('/query', { prompt, top_k: topK, model: model ?? undefined } satisfies QueryRequestBody)
20+
export async function runQuery(
21+
prompt: string,
22+
topK = 10,
23+
options?: { model?: string | null; api_key?: string | null; provider?: string | null } | null
24+
): Promise<RagResponse> {
25+
const body: QueryRequestBody = {
26+
prompt,
27+
top_k: topK,
28+
model: options?.model ?? undefined,
29+
api_key: options?.api_key ?? undefined,
30+
provider: options?.provider ?? undefined,
31+
}
32+
return apiClient.post<RagResponse>('/query', body)
1233
}
1334

1435
export interface QueryStreamCallbacks {
@@ -20,12 +41,17 @@ export interface QueryStreamCallbacks {
2041

2142
export async function runQueryStream(
2243
prompt: string,
23-
callbacks: QueryStreamCallbacks
44+
callbacks: QueryStreamCallbacks,
45+
options?: AIRequestConfig | null
2446
): Promise<boolean> {
47+
const body: Record<string, unknown> = { prompt }
48+
if (options?.api_key != null && options.api_key !== '') body.api_key = options.api_key
49+
if (options?.model != null && options.model !== '') body.model = options.model
50+
if (options?.provider != null && options.provider !== '') body.provider = options.provider
2551
const res = await fetch(`${apiClient.baseUrl}/query/stream`, {
2652
method: 'POST',
2753
headers: { 'Content-Type': 'application/json' },
28-
body: JSON.stringify({ prompt }),
54+
body: JSON.stringify(body),
2955
})
3056
if (!res.ok) {
3157
if (res.status === 404) return false

demo/web/src/components/ChatConsole.tsx

Lines changed: 35 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,12 @@ import {
1515
loadSessions,
1616
saveSessions,
1717
sessionTitleFromMessage,
18+
closeSession as closeSessionInStore,
19+
renameSession as renameSessionInStore,
1820
type ChatSession,
1921
type SessionsState,
2022
} from '../lib/chatSessions'
23+
import { ThreadBar } from './ThreadBar'
2124
const STREAM_CHUNK_INTERVAL_MS = 28
2225

2326
const SIGN_OFF_PATTERNS = [
@@ -170,6 +173,27 @@ export function ChatConsole({ documents, selectedDocumentId, projectId }: ChatCo
170173
setError(null)
171174
}, [])
172175

176+
const handleCloseThread = useCallback((id: string) => {
177+
setSessionsState((prev) => {
178+
const nextSessions = closeSessionInStore(prev.sessions, id)
179+
let nextActiveId = prev.activeId
180+
if (prev.activeId === id) {
181+
const idx = prev.sessions.findIndex((s) => s.id === id)
182+
const nextIdx = idx < nextSessions.length ? idx : nextSessions.length - 1
183+
nextActiveId = nextSessions[nextIdx]?.id ?? null
184+
}
185+
return { sessions: nextSessions, activeId: nextActiveId }
186+
})
187+
setError(null)
188+
}, [])
189+
190+
const handleRenameThread = useCallback((id: string, title: string) => {
191+
setSessionsState((prev) => ({
192+
...prev,
193+
sessions: renameSessionInStore(prev.sessions, id, title),
194+
}))
195+
}, [])
196+
173197
useEffect(() => {
174198
if (defaultModel != null && defaultModel !== '' && !ragModel) setRagModel(defaultModel)
175199
}, [defaultModel, ragModel])
@@ -216,7 +240,7 @@ export function ChatConsole({ documents, selectedDocumentId, projectId }: ChatCo
216240
setSending(true)
217241
setError(null)
218242
try {
219-
const res = await runQuery(userMessage.content, 5, ragModel.trim() || undefined)
243+
const res = await runQuery(userMessage.content, 5, { model: ragModel.trim() || undefined })
220244
if (res.answer.startsWith('Query failed:')) {
221245
const raw = res.answer.replace(/^Query failed:\s*/, '').trim()
222246
setError(normalizeQueryError(raw))
@@ -284,10 +308,17 @@ export function ChatConsole({ documents, selectedDocumentId, projectId }: ChatCo
284308

285309
return (
286310
<div className="chat-console" style={{ display: 'flex', flexDirection: 'column', minHeight: 0, background: 'var(--color-bg-alt)' }}>
311+
<ThreadBar
312+
threads={sessions}
313+
activeId={activeSessionId}
314+
onSelect={handleSelectSession}
315+
onNew={handleNewChat}
316+
onRename={handleRenameThread}
317+
onClose={handleCloseThread}
318+
/>
287319
<div
288320
style={{
289-
padding: 'var(--space-lg)',
290-
borderBottom: `1px solid ${'var(--color-border-light)'}`,
321+
padding: 'var(--space-sm) var(--space-lg)',
291322
fontSize: 'var(--text-sm)',
292323
color: 'var(--color-text-muted)',
293324
display: 'flex',
@@ -296,31 +327,7 @@ export function ChatConsole({ documents, selectedDocumentId, projectId }: ChatCo
296327
gap: 'var(--space-md)',
297328
}}
298329
>
299-
<Button type="button" variant="secondary" onClick={handleNewChat} style={{ flexShrink: 0 }}>
300-
New chat
301-
</Button>
302-
<select
303-
value={activeSessionId ?? ''}
304-
onChange={(e) => handleSelectSession(e.target.value === '' ? null : e.target.value)}
305-
title="Switch chat session"
306-
style={{
307-
maxWidth: 220,
308-
padding: '6px 10px',
309-
borderRadius: 'var(--radius-sm)',
310-
border: `1px solid ${'var(--color-border-light)'}`,
311-
fontSize: 'var(--text-sm)',
312-
background: 'var(--color-bg)',
313-
color: 'var(--color-text)',
314-
}}
315-
>
316-
<option value="">New chat</option>
317-
{sessions.map((s) => (
318-
<option key={s.id} value={s.id}>
319-
{s.title}
320-
</option>
321-
))}
322-
</select>
323-
<span style={{ marginLeft: 'auto' }}>
330+
<span>
324331
Grounded in{' '}
325332
<strong>
326333
{completedCount} document{completedCount === 1 ? '' : 's'}

0 commit comments

Comments (0)