Skip to content

Commit 523eb72

Browse files
authored
Report serverless metrics (#424)
* Log metrics to W&B Models and Training Endpoints
* Fix logging
* Update serverless backend logging
* Remove imports
* Remove report metrics stuff
1 parent e940243 commit 523eb72

File tree

2 files changed

+47
-25
lines changed

2 files changed

+47
-25
lines changed

src/art/client.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import os
35
from typing import AsyncIterator, Iterable, Literal, TypedDict, cast
@@ -36,6 +38,11 @@ class DeleteCheckpointsResponse(BaseModel):
3638
not_found_steps: list[int]
3739

3840

41+
42+
class LogResponse(BaseModel):
    """Response model that `log_trajectories` casts the log endpoint's reply to."""

    # Presumably whether the server accepted the log request -- inferred from
    # the field name only; confirm against the API.
    success: bool
44+
45+
3946
class Checkpoints(AsyncAPIResource):
4047
async def retrieve(
4148
self, *, model_id: str, step: int | Literal["latest"]
@@ -81,6 +88,27 @@ async def delete(
8188
options=dict(max_retries=0),
8289
)
8390

91+
async def log_trajectories(
    self,
    *,
    model_id: str,
    trajectory_groups: list[TrajectoryGroup],
    split: str = "val",
) -> LogResponse:
    """POST trajectory groups for a model to its ``/log`` preview endpoint.

    Args:
        model_id: Identifier used in the request path and repeated in the
            request body.
        trajectory_groups: Groups to send; each one is converted with
            ``model_dump()`` before being placed in the body.
        split: Value sent as the ``split`` field of the body. Defaults to
            ``"val"``.

    Returns:
        The result of the POST, cast to ``LogResponse``.
    """
    dumped_groups = [group.model_dump() for group in trajectory_groups]
    payload = {
        "model_id": model_id,
        "trajectory_groups": dumped_groups,
        "split": split,
    }
    # Retries are explicitly disabled for this request.
    return await self._post(
        f"/preview/models/{model_id}/log",
        body=payload,
        cast_to=LogResponse,
        options={"max_retries": 0},
    )
111+
84112

85113
class Model(BaseModel):
86114
id: str

src/art/serverless/backend.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
2-
from typing import TYPE_CHECKING, AsyncIterator, Literal
2+
from typing import TYPE_CHECKING, AsyncIterator, Literal, cast
3+
import os
34

45
from art.client import Client
56
from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider
@@ -56,12 +57,16 @@ def _model_inference_name(self, model: "TrainableModel") -> str:
5657
assert model.entity is not None, "Model entity is required"
5758
return f"{model.entity}/{model.project}/{model.name}"
5859

59-
async def _get_step(self, model: "TrainableModel") -> int:
60-
assert model.id is not None, "Model ID is required"
61-
checkpoint = await self._client.checkpoints.retrieve(
62-
model_id=model.id, step="latest"
63-
)
64-
return checkpoint.step
60+
61+
async def _get_step(self, model: "Model") -> int:
    """Return the step of the model's latest checkpoint, or 0 if not trainable.

    Args:
        model: Model to look up. Its ``trainable`` flag decides whether a
            checkpoint lookup is made; ``id`` must be set when it is.

    Returns:
        ``step`` of the checkpoint retrieved with ``step="latest"`` for a
        trainable model, otherwise ``0``.

    Raises:
        AssertionError: If the model is trainable but ``model.id`` is None.
    """
    if not model.trainable:
        # Non-trainable models do not have checkpoints/steps; default to 0
        return 0

    assert model.id is not None, "Model ID is required"
    latest = await self._client.checkpoints.retrieve(
        model_id=model.id,
        step="latest",
    )
    return latest.step
6570

6671
async def _delete_checkpoints(
6772
self,
@@ -99,27 +104,16 @@ async def _log(
99104
trajectory_groups: list[TrajectoryGroup],
100105
split: str = "val",
101106
) -> None:
102-
# TODO: Implement proper serverless logging via API
103-
# For now, write to local jsonl file as a placeholder
104-
import os
105-
from pathlib import Path
107+
# TODO: log trajectories to local file system?
106108

107-
from ..utils.trajectory_logging import serialize_trajectory_groups
108-
109-
# Create log directory (configurable via env var)
110-
log_base = os.getenv("ART_SERVERLESS_LOG_DIR", "/tmp/serverless-training-logs")
111-
log_dir = Path(log_base) / model.name / split
112-
log_dir.mkdir(parents=True, exist_ok=True)
113-
114-
# Get current step
115-
step = await model.get_step()
116-
file_path = log_dir / f"{step:04d}.jsonl"
109+
if not model.trainable:
110+
print(f"Model {model.name} is not trainable; skipping logging.")
111+
return
117112

118-
# Write trajectory groups to jsonl
119-
with open(file_path, "w") as f:
120-
f.write(serialize_trajectory_groups(trajectory_groups))
113+
await self._client.checkpoints.log_trajectories(
114+
model_id=model.id, trajectory_groups=trajectory_groups, split=split
115+
)
121116

122-
print(f"[ServerlessBackend] Logged {len(trajectory_groups)} groups to {file_path}")
123117

124118
async def _train_model(
125119
self,

0 commit comments

Comments
 (0)