
Commit 07db497

added support for weaviate vector database (#493)

* added support for the weaviate vector database
* added support for a local (embedded) db for the weaviate vector store
* added a unit test case for the weaviate vector store
* resolved the unit test case error for the weaviate vector store
* increased code coverage and resolved pylint issues (pylint: disabled C0413)

Signed-off-by: pranaychandekar <[email protected]>
1 parent f60c303 commit 07db497

File tree: 6 files changed (+255, -2 lines)

examples/data_manager/vector_store.py

Lines changed: 3 additions & 2 deletions
@@ -18,7 +18,8 @@ def run():
         'milvus',
         'chromadb',
         'docarray',
-        'redis'
+        'redis',
+        'weaviate',
     ]
     for vector_store in vector_stores:
         cache_base = CacheBase('sqlite')
@@ -40,4 +41,4 @@ def run():


 if __name__ == '__main__':
-    run()
+    run()
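With this change the example loops over 'weaviate' alongside the other stores. The body of the loop is only partially visible in the hunk above; a minimal sketch of what one iteration does for the new entry, assuming the CacheBase / VectorBase / get_data_manager helpers the example already imports from gptcache.manager:

# Sketch only: how the example's loop wires a cache together for the new
# 'weaviate' entry. The helper names come from gptcache.manager and are not
# part of this diff hunk.
from gptcache.manager import CacheBase, VectorBase, get_data_manager

cache_base = CacheBase('sqlite')
# No url is passed, so the wrapper falls back to an embedded Weaviate
# instance (see the Weaviate class added later in this commit).
vector_base = VectorBase('weaviate')
data_manager = get_data_manager(cache_base, vector_base)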

gptcache/manager/vector_data/manager.py

Lines changed: 28 additions & 0 deletions
@@ -27,6 +27,9 @@

 COLLECTION_NAME = "gptcache"

+WEAVIATE_TIMEOUT_CONFIG = (10, 60)
+WEAVIATE_STARTUP_PERIOD = 5
+

 # pylint: disable=import-outside-toplevel
 class VectorBase:
@@ -257,6 +260,31 @@ def get(name, **kwargs):
                 flush_interval_sec=flush_interval_sec,
                 index_params=index_params,
             )
+        elif name == "weaviate":
+            from gptcache.manager.vector_data.weaviate import Weaviate
+
+            url = kwargs.get("url", None)
+            auth_client_secret = kwargs.get("auth_client_secret", None)
+            timeout_config = kwargs.get("timeout_config", WEAVIATE_TIMEOUT_CONFIG)
+            proxies = kwargs.get("proxies", None)
+            trust_env = kwargs.get("trust_env", False)
+            additional_headers = kwargs.get("additional_headers", None)
+            startup_period = kwargs.get("startup_period", WEAVIATE_STARTUP_PERIOD)
+            embedded_options = kwargs.get("embedded_options", None)
+            additional_config = kwargs.get("additional_config", None)
+
+            vector_base = Weaviate(
+                url=url,
+                auth_client_secret=auth_client_secret,
+                timeout_config=timeout_config,
+                proxies=proxies,
+                trust_env=trust_env,
+                additional_headers=additional_headers,
+                startup_period=startup_period,
+                embedded_options=embedded_options,
+                additional_config=additional_config,
+                top_k=top_k,
+            )
         else:
             raise NotFoundError("vector store", name)
         return vector_base
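The new branch simply forwards its keyword arguments to the Weaviate wrapper, defaulting the timeout and startup period to the module constants above. A short sketch of the two ways the factory can be driven; the URL below is a placeholder, not something taken from the commit:

# Sketch only: driving the new "weaviate" branch of VectorBase.get().
from gptcache.manager import VectorBase

# 1) Embedded/local: no url, so the wrapper starts an EmbeddedOptions() instance.
local_store = VectorBase("weaviate", top_k=5)

# 2) Remote: url (and optionally auth_client_secret, timeout_config, proxies, ...)
#    is forwarded verbatim to weaviate.Client. The address is a placeholder.
remote_store = VectorBase("weaviate", url="http://localhost:8080", top_k=5)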
gptcache/manager/vector_data/weaviate.py

Lines changed: 182 additions & 0 deletions (new file)

from typing import List, Optional, Tuple, Union

import numpy as np

from gptcache.utils import import_weaviate
from gptcache.utils.log import gptcache_log
from gptcache.manager.vector_data.base import VectorBase, VectorData

import_weaviate()

from weaviate import Client
from weaviate.auth import AuthCredentials
from weaviate.config import Config
from weaviate.embedded import EmbeddedOptions
from weaviate.types import NUMBERS


class Weaviate(VectorBase):
    """
    vector store: Weaviate
    """

    TIMEOUT_TYPE = Union[Tuple[NUMBERS, NUMBERS], NUMBERS]

    def __init__(
        self,
        url: Optional[str] = None,
        auth_client_secret: Optional[AuthCredentials] = None,
        timeout_config: TIMEOUT_TYPE = (10, 60),
        proxies: Union[dict, str, None] = None,
        trust_env: bool = False,
        additional_headers: Optional[dict] = None,
        startup_period: Optional[int] = 5,
        embedded_options: Optional[EmbeddedOptions] = None,
        additional_config: Optional[Config] = None,
        top_k: Optional[int] = 1,
    ) -> None:

        # With neither a url nor embedded options, run an embedded (in-process) Weaviate.
        if url is None and embedded_options is None:
            embedded_options = EmbeddedOptions()

        self.client = Client(
            url=url,
            auth_client_secret=auth_client_secret,
            timeout_config=timeout_config,
            proxies=proxies,
            trust_env=trust_env,
            additional_headers=additional_headers,
            startup_period=startup_period,
            embedded_options=embedded_options,
            additional_config=additional_config,
        )

        self._create_class()
        self.top_k = top_k

    def _create_class(self):
        class_schema = self._get_default_class_schema()

        self.class_name = class_schema.get("class")

        if self.client.schema.exists(self.class_name):
            gptcache_log.warning(
                "The %s collection already exists, and it will be used directly.",
                self.class_name,
            )
        else:
            self.client.schema.create_class(class_schema)

    @staticmethod
    def _get_default_class_schema() -> dict:
        return {
            "class": "GPTCache",
            "description": "LLM response cache",
            "properties": [
                {
                    "name": "data_id",
                    "dataType": ["int"],
                    "description": "The data-id generated by GPTCache for vectors.",
                }
            ],
            "vectorIndexConfig": {"distance": "cosine"},
        }

    def mul_add(self, datas: List[VectorData]):
        with self.client.batch(batch_size=100, dynamic=True) as batch:
            for data in datas:
                properties = {
                    "data_id": data.id,
                }

                batch.add_data_object(
                    data_object=properties, class_name=self.class_name, vector=data.data
                )

    def search(self, data: np.ndarray, top_k: int = -1):
        if top_k == -1:
            top_k = self.top_k

        result = (
            self.client.query.get(class_name=self.class_name, properties=["data_id"])
            .with_near_vector(content={"vector": data})
            .with_additional(["distance"])
            .with_limit(top_k)
            .do()
        )

        # Return (distance, data_id) pairs for the nearest objects.
        return list(
            map(
                lambda x: (x["_additional"]["distance"], x["data_id"]),
                result["data"]["Get"][self.class_name],
            )
        )

    def _get_uuids(self, data_ids):
        # Resolve GPTCache data_ids to the Weaviate object UUIDs that hold them.
        uuid_list = []

        for data_id in data_ids:
            res = (
                self.client.query.get(
                    class_name=self.class_name, properties=["data_id"]
                )
                .with_where(
                    {"path": ["data_id"], "operator": "Equal", "valueInt": data_id}
                )
                .with_additional(["id"])
                .do()
            )

            uuid_list.append(
                res["data"]["Get"][self.class_name][0]["_additional"]["id"]
            )

        return uuid_list

    def delete(self, ids):
        uuids = self._get_uuids(ids)

        for uuid in uuids:
            self.client.data_object.delete(class_name=self.class_name, uuid=uuid)

    def rebuild(self, ids=None):
        return

    def flush(self):
        self.client.batch.flush()

    def close(self):
        self.flush()

    def get_embeddings(self, data_id: int):
        results = (
            self.client.query.get(class_name=self.class_name, properties=["data_id"])
            .with_where(
                {
                    "path": ["data_id"],
                    "operator": "Equal",
                    "valueInt": data_id,
                }
            )
            .with_additional(["vector"])
            .with_limit(1)
            .do()
        )

        results = results["data"]["Get"][self.class_name]

        if len(results) < 1:
            return None

        vec_emb = np.asarray(results[0]["_additional"]["vector"], dtype="float32")
        return vec_emb

    def update_embeddings(self, data_id: int, emb: np.ndarray):
        # Replace the stored vector by deleting the old object and recreating it.
        self.delete([data_id])

        properties = {
            "data_id": data_id,
        }

        self.client.data_object.create(
            data_object=properties, class_name=self.class_name, vector=emb
        )
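Taken together: mul_add batches vectors into the GPTCache class, search returns (distance, data_id) tuples, and delete / update_embeddings first resolve GPTCache ids to Weaviate object UUIDs. A brief usage sketch against the class as written above; the vectors are random placeholders:

# Usage sketch for the wrapper above; data values are illustrative only.
import numpy as np

from gptcache.manager.vector_data.base import VectorData
from gptcache.manager.vector_data.weaviate import Weaviate

store = Weaviate(top_k=3)  # no url / embedded_options -> embedded Weaviate

vectors = np.random.randn(10, 8).astype(np.float32)
store.mul_add([VectorData(id=i, data=v) for i, v in enumerate(vectors)])

hits = store.search(vectors[0])      # list of (distance, data_id) tuples
store.update_embeddings(0, vectors[1])
store.delete([2, 3])
store.close()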

gptcache/utils/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -41,6 +41,7 @@
     "import_fastapi",
     "import_redis",
     "import_qdrant",
+    "import_weaviate",
 ]

 import importlib.util
@@ -262,3 +263,7 @@ def import_redis():

 def import_starlette():
     _check_library("starlette")
+
+
+def import_weaviate():
+    _check_library("weaviate-client")

pylint.conf

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@ disable=abstract-method,
        zip-builtin-not-iterating,
        missing-module-docstring,
        super-init-not-called,
+       wrong-import-position


 [REPORTS]

New unit test case for the Weaviate vector store

Lines changed: 36 additions & 0 deletions (new file)

import unittest
import numpy as np

from gptcache.manager.vector_data import VectorBase
from gptcache.manager.vector_data.base import VectorData


class TestWeaviateDB(unittest.TestCase):
    def test_normal(self):
        size = 1000
        dim = 512
        top_k = 10

        # Go through the VectorBase factory; with no url this starts an
        # embedded Weaviate instance.
        db = VectorBase(
            "weaviate",
            top_k=top_k
        )

        db._create_class()
        data = np.random.randn(size, dim).astype(np.float32)
        db.mul_add([VectorData(id=i, data=v) for v, i in zip(data, range(size))])
        self.assertEqual(len(db.search(data[0])), top_k)
        db.mul_add([VectorData(id=size, data=data[0])])
        ret = db.search(data[0])
        self.assertIn(ret[0][1], [0, size])
        self.assertIn(ret[1][1], [0, size])
        db.delete([0, 1, 2, 3, 4, 5, size])
        ret = db.search(data[0])
        self.assertNotIn(ret[0][1], [0, size])
        db.rebuild()
        db.update_embeddings(6, data[7])
        emb = db.get_embeddings(6)
        self.assertEqual(emb.tolist(), data[7].tolist())
        emb = db.get_embeddings(0)
        self.assertIsNone(emb)
        db.close()
