|
1 | 1 | import os |
2 | | -from typing import Any, Iterable, Literal, TypedDict, cast |
| 2 | +from typing import Any, Iterable, Literal, Type, TypedDict, TypeVar, cast |
3 | 3 |
|
4 | 4 | import httpx |
5 | 5 | from openai import AsyncOpenAI, BaseModel, _exceptions |
6 | | -from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options |
| 6 | +from openai._base_client import ( |
| 7 | + AsyncAPIClient, |
| 8 | + AsyncPaginator, |
| 9 | + FinalRequestOptions, |
| 10 | + make_request_options, |
| 11 | +) |
7 | 12 | from openai._compat import cached_property |
8 | 13 | from openai._qs import Querystring |
9 | 14 | from openai._resource import AsyncAPIResource |
| 15 | +from openai._streaming import AsyncStream |
10 | 16 | from openai._types import NOT_GIVEN, NotGiven, Omit |
11 | 17 | from openai._utils import is_mapping, maybe_transform |
12 | 18 | from openai._version import __version__ |
|
17 | 23 |
|
18 | 24 | from .trajectories import TrajectoryGroup |
19 | 25 |
|
| 26 | +ResponseT = TypeVar("ResponseT") |
| 27 | + |
20 | 28 |
|
21 | 29 | class Model(BaseModel): |
22 | 30 | id: str |
@@ -214,7 +222,23 @@ def __init__( |
214 | 222 | version=__version__, |
215 | 223 | base_url=base_url or "https://api.training.wandb.ai/v1", |
216 | 224 | _strict_response_validation=False, |
217 | | - max_retries=0, |
| 225 | + max_retries=3, |
| 226 | + ) |
| 227 | + |
| 228 | + @override |
| 229 | + async def request( |
| 230 | + self, |
| 231 | + cast_to: Type[ResponseT], |
| 232 | + options: FinalRequestOptions, |
| 233 | + *, |
| 234 | + stream: bool = False, |
| 235 | + stream_cls: type[AsyncStream[Any]] | None = None, |
| 236 | + ) -> ResponseT | AsyncStream[Any]: |
| 237 | + # Disable retries for POST requests |
| 238 | + if options.method.upper() == "POST": |
| 239 | + options.max_retries = 0 |
| 240 | + return await super().request( |
| 241 | + cast_to=cast_to, options=options, stream=stream, stream_cls=stream_cls |
218 | 242 | ) |
219 | 243 |
|
220 | 244 | @cached_property |
|
0 commit comments