From b2b58267f1c282047541ce81eb9cc046a6095c21 Mon Sep 17 00:00:00 2001 From: Angky William Date: Tue, 7 Oct 2025 16:56:58 -0700 Subject: [PATCH 1/3] Retry --- src/art/client.py | 30 +++++++++++++++++++++++++++--- src/art/serverless/backend.py | 2 +- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/art/client.py b/src/art/client.py index fe7dbdde..a45d1278 100644 --- a/src/art/client.py +++ b/src/art/client.py @@ -1,12 +1,18 @@ import os -from typing import Any, Iterable, Literal, TypedDict, cast +from typing import Any, Iterable, Literal, Type, TypedDict, TypeVar, cast import httpx from openai import AsyncOpenAI, BaseModel, _exceptions -from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options +from openai._base_client import ( + AsyncAPIClient, + AsyncPaginator, + FinalRequestOptions, + make_request_options, +) from openai._compat import cached_property from openai._qs import Querystring from openai._resource import AsyncAPIResource +from openai._streaming import AsyncStream from openai._types import NOT_GIVEN, NotGiven, Omit from openai._utils import is_mapping, maybe_transform from openai._version import __version__ @@ -17,6 +23,8 @@ from .trajectories import TrajectoryGroup +ResponseT = TypeVar("ResponseT") + class Model(BaseModel): id: str @@ -214,7 +222,23 @@ def __init__( version=__version__, base_url=base_url or "https://api.training.wandb.ai/v1", _strict_response_validation=False, - max_retries=0, + max_retries=3, + ) + + @override + async def request( + self, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + *, + stream: bool = False, + stream_cls: type[AsyncStream[Any]] | None = None, + ) -> ResponseT | AsyncStream[Any]: + # Disable retries for POST requests + if options.method.upper() == "POST": + options.max_retries = 0 + return await super().request( + cast_to=cast_to, options=options, stream=stream, stream_cls=stream_cls ) @cached_property diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 8ca957bf..dad32856 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -136,7 +136,7 @@ async def _train_model( num_sequences: int | None = None pbar: tqdm.tqdm | None = None while True: - await asyncio.sleep(0.5) + await asyncio.sleep(3) async for event in self._client.training_jobs.events.list( training_job_id=training_job.id, after=after or NOT_GIVEN ): From aca354ee21e7b4beebcd6371dc116e64ee7aeefc Mon Sep 17 00:00:00 2001 From: Angky William Date: Tue, 7 Oct 2025 18:48:06 -0700 Subject: [PATCH 2/3] Upgrade pyproject to new convention --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7b172a43..6b8ba063 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,7 +95,9 @@ asyncio_mode = "auto" [tool.uv] required-version = ">=0.6.15" -dev-dependencies = [ + +[dependency-groups] +dev = [ "black>=25.1.0", "ipykernel>=6.29.5", "ipywidgets>=8.1.5", From 5eedee625e6cf83846d02bf0e980903ba1acd9f2 Mon Sep 17 00:00:00 2001 From: Angky William Date: Tue, 7 Oct 2025 18:56:55 -0700 Subject: [PATCH 3/3] polling to 1s --- pyproject.toml | 4 +--- src/art/serverless/backend.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6b8ba063..7b172a43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,9 +95,7 @@ asyncio_mode = "auto" [tool.uv] required-version = ">=0.6.15" - -[dependency-groups] -dev = [ +dev-dependencies = [ "black>=25.1.0", "ipykernel>=6.29.5", "ipywidgets>=8.1.5", diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index dad32856..604faea5 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -136,7 +136,7 @@ async def _train_model( num_sequences: int | None = None pbar: tqdm.tqdm | None = None while True: - await asyncio.sleep(3) + await asyncio.sleep(1) async for event in self._client.training_jobs.events.list( training_job_id=training_job.id, after=after or NOT_GIVEN ):