|
1 | 1 | import asyncio |
2 | | -from typing import TYPE_CHECKING, AsyncIterator, Literal |
| 2 | +from typing import TYPE_CHECKING, AsyncIterator, Literal, cast |
| 3 | +import os |
3 | 4 |
|
4 | 5 | from art.client import Client |
5 | 6 | from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider |
@@ -56,12 +57,16 @@ def _model_inference_name(self, model: "TrainableModel") -> str: |
56 | 57 | assert model.entity is not None, "Model entity is required" |
57 | 58 | return f"{model.entity}/{model.project}/{model.name}" |
58 | 59 |
|
59 | | - async def _get_step(self, model: "TrainableModel") -> int: |
60 | | - assert model.id is not None, "Model ID is required" |
61 | | - checkpoint = await self._client.checkpoints.retrieve( |
62 | | - model_id=model.id, step="latest" |
63 | | - ) |
64 | | - return checkpoint.step |
| 60 | + |
| 61 | + async def _get_step(self, model: "Model") -> int: |
| 62 | + if model.trainable: |
| 63 | + assert model.id is not None, "Model ID is required" |
| 64 | + checkpoint = await self._client.checkpoints.retrieve( |
| 65 | + model_id=model.id, step="latest" |
| 66 | + ) |
| 67 | + return checkpoint.step |
| 68 | + # Non-trainable models do not have checkpoints/steps; default to 0 |
| 69 | + return 0 |
65 | 70 |
|
66 | 71 | async def _delete_checkpoints( |
67 | 72 | self, |
@@ -99,27 +104,16 @@ async def _log( |
99 | 104 | trajectory_groups: list[TrajectoryGroup], |
100 | 105 | split: str = "val", |
101 | 106 | ) -> None: |
102 | | - # TODO: Implement proper serverless logging via API |
103 | | - # For now, write to local jsonl file as a placeholder |
104 | | - import os |
105 | | - from pathlib import Path |
| 107 | + # TODO: log trajectories to local file system? |
106 | 108 |
|
107 | | - from ..utils.trajectory_logging import serialize_trajectory_groups |
108 | | - |
109 | | - # Create log directory (configurable via env var) |
110 | | - log_base = os.getenv("ART_SERVERLESS_LOG_DIR", "/tmp/serverless-training-logs") |
111 | | - log_dir = Path(log_base) / model.name / split |
112 | | - log_dir.mkdir(parents=True, exist_ok=True) |
113 | | - |
114 | | - # Get current step |
115 | | - step = await model.get_step() |
116 | | - file_path = log_dir / f"{step:04d}.jsonl" |
| 109 | + if not model.trainable: |
| 110 | + print(f"Model {model.name} is not trainable; skipping logging.") |
| 111 | + return |
117 | 112 |
|
118 | | - # Write trajectory groups to jsonl |
119 | | - with open(file_path, "w") as f: |
120 | | - f.write(serialize_trajectory_groups(trajectory_groups)) |
| 113 | + await self._client.checkpoints.log_trajectories( |
| 114 | + model_id=model.id, trajectory_groups=trajectory_groups, split=split |
| 115 | + ) |
121 | 116 |
|
122 | | - print(f"[ServerlessBackend] Logged {len(trajectory_groups)} groups to {file_path}") |
123 | 117 |
|
124 | 118 | async def _train_model( |
125 | 119 | self, |
|
0 commit comments