diff --git a/src/art/client.py b/src/art/client.py
index 53576b77..9bb56b6b 100644
--- a/src/art/client.py
+++ b/src/art/client.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import os
 from typing import AsyncIterator, Iterable, Literal, TypedDict, cast
@@ -36,6 +38,12 @@ class DeleteCheckpointsResponse(BaseModel):
     not_found_steps: list[int]
 
 
+class LogResponse(BaseModel):
+    """Parsed reply of the trajectory-logging endpoint."""
+
+    success: bool
+
+
 class Checkpoints(AsyncAPIResource):
     async def retrieve(
         self, *, model_id: str, step: int | Literal["latest"]
@@ -81,6 +89,39 @@ async def delete(
             options=dict(max_retries=0),
         )
 
+    async def log_trajectories(
+        self,
+        *,
+        model_id: str,
+        trajectory_groups: list[TrajectoryGroup],
+        split: str = "val",
+    ) -> LogResponse:
+        """Upload trajectory groups for a model via the preview log endpoint.
+
+        Args:
+            model_id: Model whose log receives the trajectories; sent in
+                both the URL path and the request body.
+            trajectory_groups: Groups to log; each is serialized with
+                ``model_dump()``.
+            split: Split label forwarded to the server (default ``"val"``).
+
+        Returns:
+            The response parsed as ``LogResponse``.
+        """
+        return await self._post(
+            f"/preview/models/{model_id}/log",
+            body={
+                "model_id": model_id,
+                "trajectory_groups": [
+                    trajectory_group.model_dump()
+                    for trajectory_group in trajectory_groups
+                ],
+                "split": split,
+            },
+            cast_to=LogResponse,
+            options=dict(max_retries=0),
+        )
+
 
 class Model(BaseModel):
     id: str
diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py
index a89f9cf2..70185912 100644
--- a/src/art/serverless/backend.py
+++ b/src/art/serverless/backend.py
@@ -1,5 +1,6 @@
 import asyncio
+import os
-from typing import TYPE_CHECKING, AsyncIterator, Literal
+from typing import TYPE_CHECKING, AsyncIterator, Literal, cast
 
 from art.client import Client
 from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider
@@ -56,12 +57,19 @@ def _model_inference_name(self, model: "TrainableModel") -> str:
         assert model.entity is not None, "Model entity is required"
         return f"{model.entity}/{model.project}/{model.name}"
 
-    async def _get_step(self, model: "TrainableModel") -> int:
-        assert model.id is not None, "Model ID is required"
-        checkpoint = await self._client.checkpoints.retrieve(
-            model_id=model.id, step="latest"
-        )
-        return checkpoint.step
+    async def _get_step(self, model: "Model") -> int:
+        """Return the step of the model's latest checkpoint.
+
+        Non-trainable models have no checkpoints, so their step is 0.
+        """
+        if model.trainable:
+            assert model.id is not None, "Model ID is required"
+            checkpoint = await self._client.checkpoints.retrieve(
+                model_id=model.id, step="latest"
+            )
+            return checkpoint.step
+        # Non-trainable models do not have checkpoints/steps; default to 0
+        return 0
 
     async def _delete_checkpoints(
         self,
@@ -99,27 +107,21 @@ async def _log(
         trajectory_groups: list[TrajectoryGroup],
         split: str = "val",
     ) -> None:
-        # TODO: Implement proper serverless logging via API
-        # For now, write to local jsonl file as a placeholder
-        import os
-        from pathlib import Path
+        """Log trajectory groups for a trainable model through the API client.
+
+        Non-trainable models are skipped with a printed notice.
+        """
+        # TODO: log trajectories to local file system?
 
-        from ..utils.trajectory_logging import serialize_trajectory_groups
-
-        # Create log directory (configurable via env var)
-        log_base = os.getenv("ART_SERVERLESS_LOG_DIR", "/tmp/serverless-training-logs")
-        log_dir = Path(log_base) / model.name / split
-        log_dir.mkdir(parents=True, exist_ok=True)
-
-        # Get current step
-        step = await model.get_step()
-        file_path = log_dir / f"{step:04d}.jsonl"
+        if not model.trainable:
+            print(f"Model {model.name} is not trainable; skipping logging.")
+            return
 
-        # Write trajectory groups to jsonl
-        with open(file_path, "w") as f:
-            f.write(serialize_trajectory_groups(trajectory_groups))
-
-        print(f"[ServerlessBackend] Logged {len(trajectory_groups)} groups to {file_path}")
+        # The log endpoint interpolates the ID into its URL path, so it must be set.
+        assert model.id is not None, "Model ID is required"
+        await self._client.checkpoints.log_trajectories(
+            model_id=model.id, trajectory_groups=trajectory_groups, split=split
+        )
 
     async def _train_model(
         self,