Commit a5fc129

Bugfix/#486: provide redis_connection for creating all Object Models (#487)
* [mod] provide redis_connection during model creation for consistency
* [add] documentation for the `get_models` method
* [add] set Redis connection details in a shared variable for the unit tests
* [add] provide the Redis connection for the embedded model
* [add] add a docstring to the Counter class

Signed-off-by: Anurag Wagh <[email protected]>
1 parent e811c62 commit a5fc129
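
The fix relies on a redis_om detail spelled out in the new docstring: a model only talks to a specific Redis instance if its `Meta` class sets `database` explicitly; otherwise it silently falls back to the default connection pool. A minimal sketch of that pattern (the model name, prefix, and URL below are illustrative, not taken from the repo):

    from redis_om import Field, JsonModel, get_redis_connection

    # Assumed local Redis instance; any reachable URL works the same way.
    con = get_redis_connection(url="redis://localhost:6379")

    class Demo(JsonModel):                       # hypothetical model for illustration
        name: str = Field(index=True)

        class Meta:
            global_key_prefix = "gptcache_demo"  # hypothetical prefix
            model_key_prefix = "demo"
            database = con                       # without this, redis_om uses its default pool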

File tree

2 files changed: +44 −16 lines

gptcache/manager/scalar_data/redis_storage.py

Lines changed: 32 additions & 9 deletions
@@ -20,17 +20,32 @@
 from redis_om import JsonModel, EmbeddedJsonModel, NotFoundError, Field, Migrator
 
 
-def get_models(global_key):
+def get_models(global_key: str, redis_connection: Redis):
+    """
+    Get all the models for the given global key and redis connection.
+    :param global_key: Global key will be used as a prefix for all the keys
+    :type global_key: str
+
+    :param redis_connection: Redis connection to use for all the models.
+    Note: This needs to be explicitly mentioned in `Meta` class for each Object Model,
+    otherwise it will use the default connection from the pool.
+    :type redis_connection: Redis
+    """
+
     class Counter:
+        """
+        counter collection
+        """
         key_name = global_key + ":counter"
+        database = redis_connection
 
         @classmethod
-        def incr(cls, con: Redis):
-            con.incr(cls.key_name)
+        def incr(cls):
+            cls.database.incr(cls.key_name)
 
         @classmethod
-        def get(cls, con: Redis):
-            return con.get(cls.key_name)
+        def get(cls):
+            return cls.database.get(cls.key_name)
 
     class Embedding:
         """
@@ -75,6 +90,9 @@ class Answers(EmbeddedJsonModel):
         answer: str
         answer_type: int
 
+        class Meta:
+            database = redis_connection
+
     class Questions(JsonModel):
         """
         questions collection
@@ -89,6 +107,7 @@ class Questions(JsonModel):
         class Meta:
             global_key_prefix = global_key
             model_key_prefix = "questions"
+            database = redis_connection
 
     class Sessions(JsonModel):
         """
@@ -98,6 +117,7 @@ class Sessions(JsonModel):
         class Meta:
             global_key_prefix = global_key
             model_key_prefix = "sessions"
+            database = redis_connection
 
         session_id: str = Field(index=True)
         session_question: str
@@ -111,6 +131,7 @@ class QuestionDeps(JsonModel):
         class Meta:
             global_key_prefix = global_key
             model_key_prefix = "ques_deps"
+            database = redis_connection
 
         question_id: str = Field(index=True)
         dep_name: str
@@ -125,6 +146,7 @@ class Report(JsonModel):
         class Meta:
             global_key_prefix = global_key
             model_key_prefix = "report"
+            database = redis_connection
 
         user_question: str
         cache_question_id: int = Field(index=True)
@@ -194,16 +216,16 @@ def __init__(
             self._session,
             self._counter,
             self._report,
-        ) = get_models(global_key_prefix)
+        ) = get_models(global_key_prefix, redis_connection=self.con)
 
         Migrator().run()
 
     def create(self):
         pass
 
     def _insert(self, data: CacheData, pipeline: Pipeline = None):
-        self._counter.incr(self.con)
-        pk = str(self._counter.get(self.con))
+        self._counter.incr()
+        pk = str(self._counter.get())
         answers = data.answers if isinstance(data.answers, list) else [data.answers]
         all_data = []
         for answer in answers:
@@ -360,7 +382,8 @@ def delete_session(self, keys: List[str]):
         self._session.delete_many(sessions_to_delete, pipeline)
         pipeline.execute()
 
-    def report_cache(self, user_question, cache_question, cache_question_id, cache_answer, similarity_value, cache_delta_time):
+    def report_cache(self, user_question, cache_question, cache_question_id, cache_answer, similarity_value,
+                     cache_delta_time):
         self._report(
             user_question=user_question,
             cache_question=cache_question,
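
Taken together, after this change every Object Model returned by get_models is pinned to the connection passed in, rather than to whatever the default pool points at; RedisCacheStorage passes its own self.con. A rough illustration of wiring get_models up directly (the prefix and URL are assumed, not from the repo):

    from redis_om import Migrator, get_redis_connection
    from gptcache.manager.scalar_data.redis_storage import get_models

    con = get_redis_connection(url="redis://localhost:6379")    # assumed local Redis
    models = get_models("gptcache_demo", redis_connection=con)  # all returned models use `con`
    Migrator().run()                                             # build indexes, as __init__ does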

tests/unit_tests/manager/test_redis_cache_storage.py

Lines changed: 12 additions & 7 deletions
@@ -4,28 +4,30 @@
 import numpy as np
 
 from gptcache.manager.scalar_data.base import CacheData, Question
-from gptcache.manager.scalar_data.redis_storage import RedisCacheStorage
+from gptcache.manager.scalar_data.redis_storage import RedisCacheStorage, get_models
 from gptcache.utils import import_redis
 
 import_redis()
-from redis_om import get_redis_connection
+from redis_om import get_redis_connection, RedisModel
 
 
 class TestRedisStorage(unittest.TestCase):
     test_dbname = "gptcache_test"
+    url = "redis://default:default@localhost:6379"
 
     def setUp(cls) -> None:
         cls._clear_test_db()
 
     @staticmethod
     def _clear_test_db():
-        r = get_redis_connection()
+        r = get_redis_connection(url=TestRedisStorage.url)
         r.flushall()
         r.flushdb()
         time.sleep(1)
 
     def test_normal(self):
-        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname)
+        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname,
+                                          url=self.url)
         data = []
         for i in range(1, 10):
             data.append(
@@ -61,7 +63,8 @@ def test_normal(self):
         assert redis_storage.count(is_all=True) == 7
 
     def test_with_deps(self):
-        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname)
+        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname,
+                                          url=self.url)
         data_id = redis_storage.batch_insert(
             [
                 CacheData(
@@ -98,7 +101,8 @@ def test_with_deps(self):
         assert ret.question.deps[1].dep_type == 1
 
     def test_create_on(self):
-        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname)
+        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname,
+                                          url=self.url)
         redis_storage.create()
         data = []
         for i in range(1, 10):
@@ -124,7 +128,8 @@ def test_create_on(self):
         assert last_access1 < last_access2
 
     def test_session(self):
-        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname)
+        redis_storage = RedisCacheStorage(global_key_prefix=self.test_dbname,
+                                          url=self.url)
         data = []
         for i in range(1, 11):
             data.append(
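
The tests now build every storage instance against an explicit URL instead of relying on redis_om's implicit default connection. A short sketch of the same pattern outside the test class (the prefix mirrors test_dbname; the URL is the one hard-coded in the tests and assumes a local Redis accepting those credentials):

    from gptcache.manager.scalar_data.redis_storage import RedisCacheStorage

    url = "redis://default:default@localhost:6379"  # assumed local Redis, as in the test class
    redis_storage = RedisCacheStorage(global_key_prefix="gptcache_test", url=url)
    redis_storage.create()  # a no-op for this backend, mirroring test_create_on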
