Skip to content

Commit d7c2e23

Browse files
author
nik
committed
Reformat the code
1 parent e73e317 commit d7c2e23

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

adala/runtimes/_openai.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def check_if_new_openai_version():
3535

3636
@retry(wait=wait_random(min=5, max=10), stop=stop_after_attempt(3))
3737
def chat_completion_call(model, messages):
38-
return openai.ChatCompletion.create(model=model, messages=messages, timeout=120, request_timeout=120)
38+
return openai.ChatCompletion.create(
39+
model=model, messages=messages, timeout=120, request_timeout=120
40+
)
3941

4042

4143
class OpenAIChatRuntime(Runtime):
@@ -161,9 +163,12 @@ def record_to_record(
161163
completion_text = self.execute(messages)
162164

163165
field_schema = field_schema or {}
164-
if output_field_name in field_schema and field_schema[output_field_name]["type"] == "array":
166+
if (
167+
output_field_name in field_schema
168+
and field_schema[output_field_name]["type"] == "array"
169+
):
165170
# expected output is one item from the array
166-
expected_items = field_schema[output_field_name]['items']['enum']
171+
expected_items = field_schema[output_field_name]["items"]["enum"]
167172
completion_text = self._match_items(completion_text, expected_items)
168173

169174
return {output_field_name: completion_text}
@@ -176,7 +181,12 @@ def _match_items(self, query: str, items: List[str]) -> str:
176181
filtered_items = items
177182

178183
# soft constraint: find the most similar item to the query
179-
scores = list(map(lambda item: difflib.SequenceMatcher(None, query, item).ratio(), filtered_items))
184+
scores = list(
185+
map(
186+
lambda item: difflib.SequenceMatcher(None, query, item).ratio(),
187+
filtered_items,
188+
)
189+
)
180190
matched_item = filtered_items[scores.index(max(scores))]
181191
return matched_item
182192

0 commit comments

Comments
 (0)