diff --git a/src/art/client.py b/src/art/client.py index ceb9bf1f..8d23d4cd 100644 --- a/src/art/client.py +++ b/src/art/client.py @@ -5,6 +5,7 @@ from typing import Any, AsyncIterator, Iterable, Literal, TypedDict, cast import httpx +from openai import AsyncOpenAI, BaseModel, _exceptions from openai._base_client import AsyncAPIClient, AsyncPaginator, make_request_options from openai._compat import cached_property from openai._qs import Querystring @@ -17,8 +18,6 @@ from openai.resources.models import AsyncModels # noqa: F401 from typing_extensions import override -from openai import AsyncOpenAI, BaseModel, _exceptions - from .trajectories import TrajectoryGroup @@ -291,7 +290,9 @@ def events(self) -> TrainingJobEvents: class TrainingJobEvent(BaseModel): id: str - type: Literal["training_started", "gradient_step", "training_ended"] + type: Literal[ + "training_started", "gradient_step", "training_ended", "training_failed" + ] data: dict[str, Any] diff --git a/src/art/gather.py b/src/art/gather.py index 830ce82d..a9a37624 100644 --- a/src/art/gather.py +++ b/src/art/gather.py @@ -190,7 +190,8 @@ def record_metrics(context: "GatherContext", trajectory: Trajectory) -> None: if logprobs: # TODO: probably shouldn't average this trajectory.metrics["completion_tokens"] = sum( - len(l.content or l.refusal or []) for l in logprobs # noqa: E741 + len(l.content or l.refusal or []) + for l in logprobs # noqa: E741 ) / len(logprobs) context.metric_sums["reward"] += trajectory.reward # type: ignore context.metric_divisors["reward"] += 1 diff --git a/src/art/openai.py b/src/art/openai.py index 039f42a8..a56fcabd 100644 --- a/src/art/openai.py +++ b/src/art/openai.py @@ -128,9 +128,9 @@ def update_chat_completion( choice.message.tool_calls[tool_call.index].id = tool_call.id if tool_call.function: if tool_call.function.name: - choice.message.tool_calls[tool_call.index].function.name = ( - tool_call.function.name - ) + choice.message.tool_calls[ + tool_call.index + ].function.name = tool_call.function.name if tool_call.function.arguments: choice.message.tool_calls[ tool_call.index diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index 4d6a5d3b..dda0c26a 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -156,6 +156,11 @@ async def _train_model( continue elif event.type == "training_ended": return + elif event.type == "training_failed": + error_message = event.data.get( + "error_message", "Training failed with an unknown error" + ) + raise RuntimeError(f"Training job failed: {error_message}") after = event.id # ------------------------------------------------------------------