Skip to content

Commit 45b7f41

Browse files
committed
feat!: allow adding metadata to indexed items.
1 parent 7023783 commit 45b7f41

File tree

2 files changed

+43
-18
lines changed

2 files changed

+43
-18
lines changed

src/lib.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
use std::sync::Arc;
1+
use std::{collections::HashMap, sync::Arc};
22

33
use anyhow::Context;
44
use fastembed::{TextEmbedding, TokenizerFiles, UserDefinedEmbeddingModel};
55
use qdrant_client::{
66
Qdrant,
77
qdrant::{
88
CreateCollectionBuilder, Distance, PointStruct, QuantizationType, Query,
9-
QueryPointsBuilder, ScalarQuantization, UpsertPointsBuilder, VectorParamsBuilder,
9+
QueryPointsBuilder, ScalarQuantization, UpsertPointsBuilder, Value, VectorParamsBuilder,
1010
VectorsConfigBuilder, quantization_config::Quantization,
1111
},
1212
};
@@ -35,8 +35,8 @@ pub struct Item {
3535

3636
#[derive(Serialize, Debug, Clone)]
3737
pub struct SearchResult {
38-
pub id: String,
3938
pub score: f32,
39+
pub payload: HashMap<String, Value>,
4040
}
4141

4242
impl Engine {
@@ -148,12 +148,9 @@ impl Engine {
148148
Ok(result
149149
.result
150150
.into_iter()
151-
.filter_map(|x| match x.payload.get("id") {
152-
Some(v) if v.is_str() => Some(SearchResult {
153-
id: v.as_str().unwrap().into(),
154-
score: x.score,
155-
}),
156-
_ => None,
151+
.map(|x| SearchResult {
152+
score: x.score,
153+
payload: x.payload,
157154
})
158155
.collect())
159156
}

src/main.rs

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ use poem::{
1010
};
1111

1212
use clap::Parser;
13-
use semantic_search_api::{Engine, Item, Payload, SearchResult};
14-
use serde::Deserialize;
13+
use semantic_search_api::{Engine, Item, Payload};
14+
use serde::{Deserialize, Serialize};
1515
use serde_json::Value;
1616
use uuid::{Uuid, uuid};
1717

@@ -63,20 +63,30 @@ async fn get_uuid(Path(id): Path<String>) -> String {
6363
Uuid::new_v5(&UUID_NAMESPACE, id.as_bytes()).to_string()
6464
}
6565

66+
#[derive(Deserialize)]
67+
struct IndexItem {
68+
text: String,
69+
metadata: Option<serde_json::Value>,
70+
}
71+
6672
#[handler]
6773
async fn post_index(
6874
Path(id): Path<String>,
69-
body: Body,
75+
Json(item): Json<IndexItem>,
7076
engine: Data<&Engine>,
7177
) -> Result<Json<Value>> {
7278
let uuid = Uuid::new_v5(&UUID_NAMESPACE, id.as_bytes());
7379
let mut payload = Payload::new();
7480

7581
payload.insert("id", id);
82+
if let Some(metadata) = item.metadata {
83+
payload.insert("metadata", metadata.to_string());
84+
}
85+
7686
engine
7787
.index([Item {
7888
id: uuid,
79-
text: body.into_string().await.map_err(anyhow::Error::from)?,
89+
text: item.text,
8090
payload,
8191
}])
8292
.await?;
@@ -100,15 +110,35 @@ struct PostSearchQuery {
100110
limit: Option<u64>,
101111
}
102112

113+
#[derive(Serialize)]
114+
struct PostSearchResult {
115+
id: String,
116+
score: f32,
117+
metadata: Option<serde_json::Value>,
118+
}
119+
103120
#[handler]
104121
async fn post_search(
105122
body: Body,
106123
engine: Data<&Engine>,
107124
query: Query<PostSearchQuery>,
108-
) -> Result<Json<Vec<SearchResult>>> {
125+
) -> Result<Json<Vec<PostSearchResult>>> {
109126
let text = body.into_string().await.map_err(anyhow::Error::from)?;
110127
let results = engine.search(&text, query.limit).await?;
111-
Ok(Json(results))
128+
Ok(Json(
129+
results
130+
.into_iter()
131+
.map(|x| PostSearchResult {
132+
// TODO: remove unwraps and add error handling.
133+
id: x.payload.get("id").unwrap().as_str().unwrap().clone(),
134+
metadata: x
135+
.payload
136+
.get("metadata")
137+
.map(|x| serde_json::from_str(x.as_str().unwrap()).unwrap()),
138+
score: x.score,
139+
})
140+
.collect::<Vec<_>>(),
141+
))
112142
}
113143

114144
#[tokio::main]
@@ -127,8 +157,6 @@ async fn main() -> anyhow::Result<()> {
127157

128158
let addr = format!("0.0.0.0:{}", args.listen_port);
129159
println!("Starting server on: {addr}");
130-
Server::new(TcpListener::bind(addr))
131-
.run(app)
132-
.await?;
160+
Server::new(TcpListener::bind(addr)).run(app).await?;
133161
Ok(())
134162
}

0 commit comments

Comments
 (0)