30 changes: 27 additions & 3 deletions src/art/client.py
@@ -1,12 +1,18 @@
 import os
-from typing import Any, Iterable, Literal, TypedDict, cast
+from typing import Any, Iterable, Literal, Type, TypedDict, TypeVar, cast
 
 import httpx
 from openai import AsyncOpenAI, BaseModel, _exceptions
-from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options
+from openai._base_client import (
+    AsyncAPIClient,
+    AsyncPaginator,
+    FinalRequestOptions,
+    make_request_options,
+)
 from openai._compat import cached_property
 from openai._qs import Querystring
 from openai._resource import AsyncAPIResource
 from openai._streaming import AsyncStream
 from openai._types import NOT_GIVEN, NotGiven, Omit
 from openai._utils import is_mapping, maybe_transform
 from openai._version import __version__
@@ -17,6 +23,8 @@
 
 from .trajectories import TrajectoryGroup
 
+ResponseT = TypeVar("ResponseT")
+
 
 class Model(BaseModel):
     id: str
@@ -214,7 +222,23 @@ def __init__(
             version=__version__,
             base_url=base_url or "https://api.training.wandb.ai/v1",
             _strict_response_validation=False,
-            max_retries=0,
+            max_retries=3,
         )
 
+    @override
+    async def request(
+        self,
+        cast_to: Type[ResponseT],
+        options: FinalRequestOptions,
+        *,
+        stream: bool = False,
+        stream_cls: type[AsyncStream[Any]] | None = None,
+    ) -> ResponseT | AsyncStream[Any]:
+        # Disable retries for POST requests
+        if options.method.upper() == "POST":
+            options.max_retries = 0
+        return await super().request(
+            cast_to=cast_to, options=options, stream=stream, stream_cls=stream_cls
+        )
+
     @cached_property
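Taken together, this file's changes give the client a method-aware retry policy: the default retry budget rises from 0 to 3, and the new request() override zeroes it out again for POST requests, which are not generally safe to replay. Below is a minimal, hypothetical sketch of the same pattern applied to a plain AsyncOpenAI client; the NoPostRetryClient name and the usage comment are invented for illustration, while the override mechanics mirror the diff above.

from typing import Any, Type, TypeVar

from openai import AsyncOpenAI
from openai._base_client import FinalRequestOptions
from openai._streaming import AsyncStream
from typing_extensions import override

ResponseT = TypeVar("ResponseT")


class NoPostRetryClient(AsyncOpenAI):
    """Retries idempotent requests only (hypothetical example, not the ART client)."""

    @override
    async def request(
        self,
        cast_to: Type[ResponseT],
        options: FinalRequestOptions,
        *,
        stream: bool = False,
        stream_cls: type[AsyncStream[Any]] | None = None,
    ) -> ResponseT | AsyncStream[Any]:
        # POST requests create or mutate server-side state, so a blind retry
        # could duplicate work; force their retry budget to zero per request.
        if options.method.upper() == "POST":
            options.max_retries = 0
        return await super().request(
            cast_to=cast_to, options=options, stream=stream, stream_cls=stream_cls
        )


# Usage: GET/list calls may be retried up to 3 times on transient failures,
# while POST calls are attempted exactly once.
# client = NoPostRetryClient(api_key="...", max_retries=3)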
2 changes: 1 addition & 1 deletion src/art/serverless/backend.py
@@ -136,7 +136,7 @@ async def _train_model(
         num_sequences: int | None = None
         pbar: tqdm.tqdm | None = None
         while True:
-            await asyncio.sleep(0.5)
+            await asyncio.sleep(1)
             async for event in self._client.training_jobs.events.list(
                 training_job_id=training_job.id, after=after or NOT_GIVEN
             ):
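The only change in this file is the polling cadence in _train_model: the loop that tails training-job events now sleeps one second between polls instead of half a second. A rough, hedged sketch of that polling shape is below; client stands in for the ART client defined in src/art/client.py, and the event-field handling (event.id, the terminal-event check) is an assumption for illustration. Only the 1-second sleep and the after cursor come from the diff.

import asyncio
from typing import Any

from openai._types import NOT_GIVEN


async def watch_training_job(client: Any, training_job_id: str) -> None:
    # Hypothetical loop: poll the events list once per second, advancing an
    # `after` cursor so each poll only returns events not yet seen.
    after: str | None = None
    while True:
        await asyncio.sleep(1)
        async for event in client.training_jobs.events.list(
            training_job_id=training_job_id, after=after or NOT_GIVEN
        ):
            after = event.id  # assumed field: advance the cursor past this event
            print(event)
            # Assumed terminal condition; the real loop also drives the tqdm
            # progress bar declared in the hunk above.
            if getattr(event, "type", None) in ("completed", "failed"):
                return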