Skip to content

Commit 10437df

Browse files
authored
feat: Retry GET and DELETE (#431)
* Retry * Upgrade pyproject to new convention * polling to 1s
1 parent fe5bd7d commit 10437df

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

src/art/client.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
import os
2-
from typing import Any, Iterable, Literal, TypedDict, cast
2+
from typing import Any, Iterable, Literal, Type, TypedDict, TypeVar, cast
33

44
import httpx
55
from openai import AsyncOpenAI, BaseModel, _exceptions
6-
from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options
6+
from openai._base_client import (
7+
AsyncAPIClient,
8+
AsyncPaginator,
9+
FinalRequestOptions,
10+
make_request_options,
11+
)
712
from openai._compat import cached_property
813
from openai._qs import Querystring
914
from openai._resource import AsyncAPIResource
15+
from openai._streaming import AsyncStream
1016
from openai._types import NOT_GIVEN, NotGiven, Omit
1117
from openai._utils import is_mapping, maybe_transform
1218
from openai._version import __version__
@@ -17,6 +23,8 @@
1723

1824
from .trajectories import TrajectoryGroup
1925

26+
ResponseT = TypeVar("ResponseT")
27+
2028

2129
class Model(BaseModel):
2230
id: str
@@ -214,7 +222,23 @@ def __init__(
214222
version=__version__,
215223
base_url=base_url or "https://api.training.wandb.ai/v1",
216224
_strict_response_validation=False,
217-
max_retries=0,
225+
max_retries=3,
226+
)
227+
228+
@override
229+
async def request(
230+
self,
231+
cast_to: Type[ResponseT],
232+
options: FinalRequestOptions,
233+
*,
234+
stream: bool = False,
235+
stream_cls: type[AsyncStream[Any]] | None = None,
236+
) -> ResponseT | AsyncStream[Any]:
237+
# Disable retries for POST requests
238+
if options.method.upper() == "POST":
239+
options.max_retries = 0
240+
return await super().request(
241+
cast_to=cast_to, options=options, stream=stream, stream_cls=stream_cls
218242
)
219243

220244
@cached_property

src/art/serverless/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ async def _train_model(
136136
num_sequences: int | None = None
137137
pbar: tqdm.tqdm | None = None
138138
while True:
139-
await asyncio.sleep(0.5)
139+
await asyncio.sleep(1)
140140
async for event in self._client.training_jobs.events.list(
141141
training_job_id=training_job.id, after=after or NOT_GIVEN
142142
):

0 commit comments

Comments
 (0)