Skip to content

Commit 5cdc6a3

Browse files
committed
Add download checkpoint
1 parent a0ae38a commit 5cdc6a3

File tree

10 files changed

+715
-49
lines changed

10 files changed

+715
-49
lines changed

dev/stage.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import asyncio
2+
import os
3+
import random
4+
from itertools import permutations
5+
6+
import art
7+
import litellm
8+
from art.serverless.backend import ServerlessBackend
9+
from art.utils.deploy_model import (
10+
LoRADeploymentProvider,
11+
deploy_model,
12+
)
13+
from art.utils.litellm import convert_litellm_choice_to_openai
14+
from dotenv import load_dotenv
15+
from litellm.types.utils import Choices, ModelResponse
16+
17+
load_dotenv()
18+
19+
20+
async def rollout(model: art.Model, scenario: str, step: int) -> art.Trajectory:
    """Run one single-turn rollout for ``scenario`` and score the reply.

    The scenario text is sent as the only user message to the model's
    inference endpoint through litellm. The reply is rewarded by exact
    string match: "maybe" earns 1.0, "no" earns 0.75, "yes" earns 0.5 and
    any other reply earns 0.0.

    Args:
        model: Model whose inference name, base URL and API key are used
            for the completion request.
        scenario: Prompt text sent as the user message.
        step: Training step, recorded in the metrics under "run_step".

    Returns:
        A trajectory holding the prompt, the sampled choice, the reward and
        the metrics.
    """
    prompt: art.Messages = [{"role": "user", "content": scenario}]
    completion = await litellm.acompletion(
        messages=prompt,
        model=f"openai/{model.get_inference_name()}",
        max_tokens=100,
        timeout=100,
        base_url=model.inference_base_url,
        api_key=model.inference_api_key,
    )
    # The isinstance asserts narrow litellm's loosely typed return values.
    assert isinstance(completion, ModelResponse)
    first_choice = completion.choices[0]
    assert isinstance(first_choice, Choices)
    reply = first_choice.message.content
    assert isinstance(reply, str)

    # Exact-match lookup; replies not listed here fall back to a reward of 0.0.
    reward_by_reply = {"yes": 0.5, "no": 0.75, "maybe": 1.0}
    return art.Trajectory(
        messages_and_choices=[*prompt, convert_litellm_choice_to_openai(first_choice)],
        reward=reward_by_reply.get(reply, 0.0),
        metrics={"custom_metric": random.random(), "run_step": step},
    )
53+
54+
55+
async def main() -> None:
    """Train a throwaway model for one step, then download and deploy it.

    In order, this registers a randomly named trainable model on the
    serverless backend, builds the yes/no/maybe prompts and splits them in
    half into validation and training sets, gathers trajectories for both,
    logs the validation groups, trains on the training groups, prunes
    checkpoints, pulls the latest checkpoint into this script's directory,
    lists the downloaded files and deploys the checkpoint to Together.

    Raises:
        RuntimeError: If the WANDB_API_KEY environment variable is not set.
    """
    # Credentials are read from the environment (or the .env file loaded at
    # import time) and must never be committed to source control.
    # NOTE(review): WANDB_API_KEY is assumed to be the variable the serverless
    # backend expects -- confirm against the backend documentation. Any key
    # that was previously hard-coded here should be treated as leaked and
    # rotated.
    api_key = os.environ.get("WANDB_API_KEY")
    if not api_key:
        raise RuntimeError(
            "WANDB_API_KEY is not set; export it or add it to your .env file"
        )
    backend = ServerlessBackend(
        base_url="https://api.qa.training.wandb.ai/v1",
        api_key=api_key,
    )
    model = art.TrainableModel(
        name="".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=8)),
        project="yes-no-maybe",
        base_model="Qwen/Qwen2.5-14B-Instruct",
    )
    await model.register(backend)
    print(f"Created model: {model.name}")

    def with_quotes(w: str) -> str:
        return f"'{w}'"

    # Every prefix x quoting x word-permutation combination. Three-word
    # prompts are comma separated; two-word prompts are joined with "or".
    scenarios = [
        f"{prefix} with {', '.join([with_quotes(w) if use_quotes else w for w in words]) if len(words) == 3 else f'{words[0]}' + (f' or {words[1]}' if len(words) > 1 else '')}"
        for prefix in ["respond", "just respond"]
        for use_quotes in [True, False]
        for words in (
            list(p) for n in [3, 2] for p in permutations(["yes", "no", "maybe"], n)
        )
    ]
    # Fixed seed so the validation/training split is the same on every run.
    random.seed(42)
    random.shuffle(scenarios)
    val_scenarios = scenarios[: len(scenarios) // 2]
    train_scenarios = scenarios[len(scenarios) // 2 :]

    has_printed_step_warning = False
    target_steps = 1  # Train for a single step
    starting_step = await model.get_step()

    for _step in range(starting_step, starting_step + target_steps):
        # The step reported by the model is authoritative; warn once if it
        # drifts from the loop counter.
        step = await model.get_step()
        if step != _step and not has_printed_step_warning:
            print(f"Warning: Step mismatch: {step} != {_step}")
            has_printed_step_warning = True
        val_groups = await art.gather_trajectory_groups(
            (
                art.TrajectoryGroup(rollout(model, scenario, step) for _ in range(8))
                for scenario in val_scenarios
            ),
            pbar_desc=f"gather(val:{step})",
        )
        train_groups = await art.gather_trajectory_groups(
            (
                art.TrajectoryGroup(rollout(model, scenario, step) for _ in range(8))
                for scenario in train_scenarios
            ),
            pbar_desc=f"gather(train:{step})",
        )
        await model.log(val_groups)
        await model.train(
            train_groups,
            config=art.TrainConfig(learning_rate=5e-5),
            _config=art.dev.TrainConfig(precalculate_logprobs=True),
        )
        await model.delete_checkpoints(best_checkpoint_metric="train/reward")

    # Download the latest checkpoint to local directory (same folder as this script)
    print("\n" + "=" * 80)
    print("Downloading checkpoint to local directory...")
    print("=" * 80)

    script_dir = os.path.dirname(os.path.abspath(__file__))
    checkpoint_path = await backend._experimental_pull_model_checkpoint(
        model, step="latest", local_path=script_dir, verbose=True
    )

    print(f"\n✓ Checkpoint downloaded to: {checkpoint_path}")
    print("\nFiles in checkpoint directory:")
    print("-" * 80)

    # List all files in the checkpoint directory
    for root, _dirs, files in os.walk(checkpoint_path):
        level = root.replace(checkpoint_path, "").count(os.sep)
        indent = " " * 2 * level
        print(f"{indent}{os.path.basename(root)}/")
        subindent = " " * 2 * (level + 1)
        for file in files:
            file_size = os.path.getsize(os.path.join(root, file))
            # Format file size nicely
            if file_size < 1024:
                size_str = f"{file_size}B"
            elif file_size < 1024 * 1024:
                size_str = f"{file_size / 1024:.1f}KB"
            else:
                size_str = f"{file_size / (1024 * 1024):.1f}MB"
            print(f"{subindent}{file} ({size_str})")

    # Deploy the checkpoint to Together
    print("\n" + "=" * 80)
    print("Deploying checkpoint to Together...")
    print("=" * 80)

    # Extract step number from checkpoint path. normpath drops any trailing
    # separator, which would otherwise make basename() return "".
    final_step = int(os.path.basename(os.path.normpath(checkpoint_path)))

    deployment_job = await deploy_model(
        deploy_to=LoRADeploymentProvider.TOGETHER,
        model=model,
        checkpoint_path=checkpoint_path,
        step=final_step,
        s3_bucket=None,  # Will use default S3 bucket for presigned URL
        verbose=True,
        wait_for_completion=True,
    )

    print("\n✓ Deployment complete!")
    print(f" Status: {deployment_job.status}")
    print(f" Job ID: {deployment_job.job_id}")
    print(f" Model Name: {deployment_job.model_name}")
    if deployment_job.failure_reason:
        print(f" Failure Reason: {deployment_job.failure_reason}")

    print("\n" + "=" * 80)
    print(f"Training complete! Model: {model.name}")
    print("=" * 80)


if __name__ == "__main__":
    asyncio.run(main())

scripts/deploy-model.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
from dotenv import load_dotenv
66

77
import art
8-
from art.utils.deploy_model import deploy_model
8+
from art.utils.deploy_model import LoRADeploymentProvider, deploy_model
99
from art.utils.get_model_step import get_model_step
10+
from art.utils.output_dirs import get_model_dir, get_step_checkpoint_dir
1011
from art.utils.s3 import pull_model_from_s3
1112

1213
load_dotenv()
@@ -86,14 +87,19 @@ async def deploy() -> None:
8687
f"using checkpoints from s3://{backup_bucket}…"
8788
)
8889

90+
# Construct the checkpoint path from the pulled model
91+
checkpoint_path = get_step_checkpoint_dir(
92+
get_model_dir(model=model, art_path=args.art_path), step
93+
)
94+
8995
deployment_result = await deploy_model(
90-
deploy_to="together",
96+
deploy_to=LoRADeploymentProvider.TOGETHER,
9197
model=model,
98+
checkpoint_path=checkpoint_path,
9299
step=step,
100+
s3_bucket=backup_bucket,
93101
verbose=True,
94-
pull_s3=False,
95102
wait_for_completion=True,
96-
art_path=args.art_path,
97103
)
98104

99105
if deployment_result.status == "Failed":

src/art/backend.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import TYPE_CHECKING, AsyncIterator, Literal
2+
from typing import TYPE_CHECKING, Any, AsyncIterator, Literal
33

44
import httpx
55
from tqdm import auto as tqdm
@@ -130,6 +130,49 @@ async def _train_model(
130130
# Experimental support for S3
131131
# ------------------------------------------------------------------
132132

133+
@log_http_errors
async def _experimental_pull_model_checkpoint(
    self,
    model: "TrainableModel",
    *,
    step: int | Literal["latest"] | None = None,
    local_path: str | None = None,
    verbose: bool = False,
    **kwargs: Any,
) -> str:
    """Ask the backend server to pull a model checkpoint to a local path.

    Sends a POST request to the ``/_experimental_pull_model_checkpoint``
    endpoint and returns the ``checkpoint_path`` field of the JSON reply.
    The checkpoint is fetched from the backend's storage (S3 for
    LocalBackend/SkyPilot, W&B artifacts for ServerlessBackend).

    Args:
        model: The model whose checkpoint should be pulled.
        step: An int for a specific step, or "latest" for the latest
            checkpoint. None also pulls the latest checkpoint.
        local_path: Local directory to save the checkpoint in. If None,
            the default art path is used.
        verbose: Whether to print verbose output.
        **kwargs: Backend-specific parameters, forwarded in the request body:
            - s3_bucket (str | None): S3 bucket to pull from
              (LocalBackend/SkyPilotBackend)
            - prefix (str | None): S3 prefix (LocalBackend/SkyPilotBackend)

    Returns:
        Path to the local checkpoint directory.
    """
    # Extra keyword arguments ride along next to the fixed fields.
    payload = {
        "model": model.safe_model_dump(),
        "step": step,
        "local_path": local_path,
        "verbose": verbose,
        **kwargs,
    }
    reply = await self._client.post(
        "/_experimental_pull_model_checkpoint",
        json=payload,
        timeout=600,
    )
    # Raise on HTTP error statuses before trying to read the body.
    reply.raise_for_status()
    reply_body = reply.json()
    return reply_body["checkpoint_path"]
175+
133176
@log_http_errors
134177
async def _experimental_pull_from_s3(
135178
self,
@@ -233,7 +276,7 @@ async def _experimental_deploy(
233276
s3_bucket: str | None = None,
234277
prefix: str | None = None,
235278
verbose: bool = False,
236-
pull_s3: bool = True,
279+
pull_checkpoint: bool = True,
237280
wait_for_completion: bool = True,
238281
) -> LoRADeploymentJob:
239282
"""
@@ -251,7 +294,7 @@ async def _experimental_deploy(
251294
"s3_bucket": s3_bucket,
252295
"prefix": prefix,
253296
"verbose": verbose,
254-
"pull_s3": pull_s3,
297+
"pull_checkpoint": pull_checkpoint,
255298
"wait_for_completion": wait_for_completion,
256299
},
257300
timeout=600,

src/art/cli.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,31 @@ async def _experimental_push_to_s3(
126126
delete=delete,
127127
)
128128

129+
# HTTP endpoint that forwards a checkpoint-pull request to the backend and
# returns the resulting path as {"checkpoint_path": ...}.
@app.post("/_experimental_pull_model_checkpoint")
async def _experimental_pull_model_checkpoint(
    model: TrainableModel = Body(...),
    step: int | str | None = Body(None),
    local_path: str | None = Body(None),
    verbose: bool = Body(False),
    s3_bucket: str | None = Body(None),
    prefix: str | None = Body(None),
):
    # Only forward the backend-specific parameters the caller actually
    # supplied, so the backend is never handed an explicit None for them.
    optional_params = {"s3_bucket": s3_bucket, "prefix": prefix}
    backend_kwargs = {
        key: value for key, value in optional_params.items() if value is not None
    }

    pulled_path = await backend._experimental_pull_model_checkpoint(
        model=model,
        step=step,
        local_path=local_path,
        verbose=verbose,
        **backend_kwargs,
    )
    return {"checkpoint_path": pulled_path}
153+
129154
@app.post("/_experimental_deploy")
130155
async def _experimental_deploy(
131156
deploy_to: LoRADeploymentProvider = Body(...),
@@ -134,7 +159,7 @@ async def _experimental_deploy(
134159
s3_bucket: str | None = Body(None),
135160
prefix: str | None = Body(None),
136161
verbose: bool = Body(False),
137-
pull_s3: bool = Body(True),
162+
pull_checkpoint: bool = Body(True),
138163
wait_for_completion: bool = Body(True),
139164
):
140165
return await backend._experimental_deploy(
@@ -144,7 +169,7 @@ async def _experimental_deploy(
144169
s3_bucket=s3_bucket,
145170
prefix=prefix,
146171
verbose=verbose,
147-
pull_s3=pull_s3,
172+
pull_checkpoint=pull_checkpoint,
148173
wait_for_completion=wait_for_completion,
149174
)
150175

0 commit comments

Comments
 (0)