diff --git a/src/art/client.py b/src/art/client.py index fe7dbdde..a45d1278 100644 --- a/src/art/client.py +++ b/src/art/client.py @@ -1,12 +1,18 @@ import os -from typing import Any, Iterable, Literal, TypedDict, cast +from typing import Any, Iterable, Literal, Type, TypedDict, TypeVar, cast import httpx from openai import AsyncOpenAI, BaseModel, _exceptions -from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options +from openai._base_client import ( + AsyncAPIClient, + AsyncPaginator, + FinalRequestOptions, + make_request_options, +) from openai._compat import cached_property from openai._qs import Querystring from openai._resource import AsyncAPIResource +from openai._streaming import AsyncStream from openai._types import NOT_GIVEN, NotGiven, Omit from openai._utils import is_mapping, maybe_transform from openai._version import __version__ @@ -17,6 +23,8 @@ from .trajectories import TrajectoryGroup +ResponseT = TypeVar("ResponseT") + class Model(BaseModel): id: str @@ -214,7 +222,23 @@ def __init__( version=__version__, base_url=base_url or "https://api.training.wandb.ai/v1", _strict_response_validation=False, - max_retries=0, + max_retries=3, + ) + + @override + async def request( + self, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + *, + stream: bool = False, + stream_cls: type[AsyncStream[Any]] | None = None, + ) -> ResponseT | AsyncStream[Any]: + # Disable retries for POST requests + if options.method.upper() == "POST": + options.max_retries = 0 + return await super().request( + cast_to=cast_to, options=options, stream=stream, stream_cls=stream_cls ) @cached_property diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 8ca957bf..604faea5 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -136,7 +136,7 @@ async def _train_model( num_sequences: int | None = None pbar: tqdm.tqdm | None = None while True: - await asyncio.sleep(0.5) + await asyncio.sleep(1) async for event in self._client.training_jobs.events.list( training_job_id=training_job.id, after=after or NOT_GIVEN ):