
Commit 8748d59

Add download checkpoint
1 parent a0ae38a commit 8748d59

9 files changed: 537 additions & 49 deletions


scripts/deploy-model.py

Lines changed: 10 additions & 4 deletions
@@ -5,8 +5,9 @@
 from dotenv import load_dotenv
 
 import art
-from art.utils.deploy_model import deploy_model
+from art.utils.deploy_model import LoRADeploymentProvider, deploy_model
 from art.utils.get_model_step import get_model_step
+from art.utils.output_dirs import get_model_dir, get_step_checkpoint_dir
 from art.utils.s3 import pull_model_from_s3
 
 load_dotenv()
@@ -86,14 +87,19 @@ async def deploy() -> None:
         f"using checkpoints from s3://{backup_bucket}…"
     )
 
+    # Construct the checkpoint path from the pulled model
+    checkpoint_path = get_step_checkpoint_dir(
+        get_model_dir(model=model, art_path=args.art_path), step
+    )
+
     deployment_result = await deploy_model(
-        deploy_to="together",
+        deploy_to=LoRADeploymentProvider.TOGETHER,
         model=model,
+        checkpoint_path=checkpoint_path,
         step=step,
+        s3_bucket=backup_bucket,
        verbose=True,
-        pull_s3=False,
         wait_for_completion=True,
-        art_path=args.art_path,
     )
 
     if deployment_result.status == "Failed":
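Usage note (not part of this commit): the two helpers imported here take over the path resolution that the removed art_path/pull_s3 plumbing handled implicitly. A minimal sketch of how they compose; the directory layout is an assumption, only the zero-padded step directory is confirmed by the f"{resolved_step:04d}" formatting in src/art/local/backend.py below:

from art.utils.output_dirs import get_model_dir, get_step_checkpoint_dir

# Hedged sketch; `model` is a TrainableModel.
def resolve_checkpoint(model, art_path: str, step: int) -> str:
    model_dir = get_model_dir(model=model, art_path=art_path)
    # Returns a path ending in a zero-padded step directory such as
    # ".../0042" (assumed layout; only the :04d padding is shown in this commit).
    return get_step_checkpoint_dir(model_dir, step)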

src/art/backend.py

Lines changed: 46 additions & 3 deletions
@@ -1,5 +1,5 @@
 import json
-from typing import TYPE_CHECKING, AsyncIterator, Literal
+from typing import TYPE_CHECKING, Any, AsyncIterator, Literal
 
 import httpx
 from tqdm import auto as tqdm
@@ -130,6 +130,49 @@ async def _train_model(
     # Experimental support for S3
     # ------------------------------------------------------------------
 
+    @log_http_errors
+    async def _experimental_pull_model_checkpoint(
+        self,
+        model: "TrainableModel",
+        *,
+        step: int | Literal["latest"] | None = None,
+        local_path: str | None = None,
+        verbose: bool = False,
+        **kwargs: Any,
+    ) -> str:
+        """Pull a model checkpoint to a local path.
+
+        This method downloads a specific checkpoint from the backend's storage
+        (S3 for LocalBackend/SkyPilot, W&B artifacts for ServerlessBackend)
+        to a local directory.
+
+        Args:
+            model: The model to pull the checkpoint for.
+            step: The step to pull. Can be an int for a specific step,
+                or "latest" to pull the latest checkpoint. If None, pulls the latest.
+            local_path: Local directory to save the checkpoint. If None, uses the default art path.
+            verbose: Whether to print verbose output.
+            **kwargs: Backend-specific parameters:
+                - s3_bucket (str | None): S3 bucket to pull from (LocalBackend/SkyPilotBackend)
+                - prefix (str | None): S3 prefix (LocalBackend/SkyPilotBackend)
+
+        Returns:
+            Path to the local checkpoint directory.
+        """
+        response = await self._client.post(
+            "/_experimental_pull_model_checkpoint",
+            json={
+                "model": model.safe_model_dump(),
+                "step": step,
+                "local_path": local_path,
+                "verbose": verbose,
+                **kwargs,
+            },
+            timeout=600,
+        )
+        response.raise_for_status()
+        return response.json()["checkpoint_path"]
+
     @log_http_errors
     async def _experimental_pull_from_s3(
         self,
@@ -233,7 +276,7 @@ async def _experimental_deploy(
         s3_bucket: str | None = None,
         prefix: str | None = None,
         verbose: bool = False,
-        pull_s3: bool = True,
+        pull_checkpoint: bool = True,
         wait_for_completion: bool = True,
     ) -> LoRADeploymentJob:
         """
@@ -251,7 +294,7 @@ async def _experimental_deploy(
                 "s3_bucket": s3_bucket,
                 "prefix": prefix,
                 "verbose": verbose,
-                "pull_s3": pull_s3,
+                "pull_checkpoint": pull_checkpoint,
                 "wait_for_completion": wait_for_completion,
             },
             timeout=600,
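Usage note (not part of this commit): a minimal sketch of calling the new client method. The `backend` and `model` objects are assumed to already exist (any backend exposing the method above and a TrainableModel); import paths are omitted because they are not shown in this diff:

# Hedged sketch under the assumptions stated above.
async def fetch_checkpoint(backend, model) -> str:
    return await backend._experimental_pull_model_checkpoint(
        model,
        step="latest",                 # or a concrete int step
        verbose=True,
        s3_bucket="my-backup-bucket",  # hypothetical bucket, forwarded via **kwargs
    )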

src/art/cli.py

Lines changed: 27 additions & 2 deletions
@@ -126,6 +126,31 @@ async def _experimental_push_to_s3(
         delete=delete,
     )
 
+@app.post("/_experimental_pull_model_checkpoint")
+async def _experimental_pull_model_checkpoint(
+    model: TrainableModel = Body(...),
+    step: int | str | None = Body(None),
+    local_path: str | None = Body(None),
+    verbose: bool = Body(False),
+    s3_bucket: str | None = Body(None),
+    prefix: str | None = Body(None),
+):
+    # Build kwargs for backend-specific parameters
+    kwargs = {}
+    if s3_bucket is not None:
+        kwargs["s3_bucket"] = s3_bucket
+    if prefix is not None:
+        kwargs["prefix"] = prefix
+
+    checkpoint_path = await backend._experimental_pull_model_checkpoint(
+        model=model,
+        step=step,
+        local_path=local_path,
+        verbose=verbose,
+        **kwargs,
+    )
+    return {"checkpoint_path": checkpoint_path}
+
 @app.post("/_experimental_deploy")
 async def _experimental_deploy(
     deploy_to: LoRADeploymentProvider = Body(...),
@@ -134,7 +159,7 @@ async def _experimental_deploy(
     s3_bucket: str | None = Body(None),
     prefix: str | None = Body(None),
     verbose: bool = Body(False),
-    pull_s3: bool = Body(True),
+    pull_checkpoint: bool = Body(True),
     wait_for_completion: bool = Body(True),
 ):
     return await backend._experimental_deploy(
@@ -144,7 +169,7 @@ async def _experimental_deploy(
         s3_bucket=s3_bucket,
         prefix=prefix,
         verbose=verbose,
-        pull_s3=pull_s3,
+        pull_checkpoint=pull_checkpoint,
         wait_for_completion=wait_for_completion,
     )
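Usage note (not part of this commit): the route mirrors the client method in src/art/backend.py, so a raw HTTP call would look roughly like the sketch below. The base URL and the model payload shape are assumptions; the real payload is produced by model.safe_model_dump():

import asyncio

import httpx

async def main() -> None:
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        response = await client.post(
            "/_experimental_pull_model_checkpoint",
            json={
                "model": {"name": "my-model", "project": "my-project"},  # illustrative
                "step": "latest",
                "verbose": True,
                "s3_bucket": "my-backup-bucket",  # hypothetical bucket
            },
            timeout=600,
        )
        response.raise_for_status()
        print(response.json()["checkpoint_path"])

asyncio.run(main())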

src/art/local/backend.py

Lines changed: 175 additions & 3 deletions
@@ -675,6 +675,148 @@ def _get_wandb_run(self, model: Model) -> Run | None:
     # Experimental support for S3
     # ------------------------------------------------------------------
 
+    async def _experimental_pull_model_checkpoint(
+        self,
+        model: "TrainableModel",
+        *,
+        step: int | Literal["latest"] | None = None,
+        local_path: str | None = None,
+        s3_bucket: str | None = None,
+        prefix: str | None = None,
+        verbose: bool = False,
+    ) -> str:
+        """Pull a model checkpoint to a local path.
+
+        For LocalBackend, this:
+        1. Checks whether the checkpoint exists in the original training location (self._path)
+        2. If it exists and local_path is provided, copies it to the custom location
+        3. If it doesn't exist and s3_bucket is provided, pulls from S3
+        4. Returns the final checkpoint path
+
+        Args:
+            model: The model to pull the checkpoint for.
+            step: The step to pull. Can be an int for a specific step,
+                or "latest" to pull the latest checkpoint. If None, pulls the latest.
+            local_path: Custom directory to save/copy the checkpoint to.
+                If None, returns the checkpoint from the backend's default art path.
+            s3_bucket: S3 bucket to pull from if the checkpoint doesn't exist locally.
+            prefix: S3 prefix.
+            verbose: Whether to print verbose output.
+
+        Returns:
+            Path to the local checkpoint directory.
+        """
+        # Determine which step to use
+        resolved_step: int
+        if step is None or step == "latest":
+            if s3_bucket is not None:
+                # Get the latest from S3
+                from art.utils.s3_checkpoint_utils import (
+                    get_latest_checkpoint_step_from_s3,
+                )
+
+                latest_step = await get_latest_checkpoint_step_from_s3(
+                    model_name=model.name,
+                    project=model.project,
+                    s3_bucket=s3_bucket,
+                    prefix=prefix,
+                )
+                if latest_step is None:
+                    raise ValueError(
+                        f"No checkpoints found in S3 for {model.project}/{model.name}"
+                    )
+                resolved_step = latest_step
+            else:
+                # Get the latest from the default training location
+                resolved_step = get_model_step(model, self._path)
+        else:
+            resolved_step = step
+
+        # Path of the checkpoint in the original training location
+        original_checkpoint_dir = get_step_checkpoint_dir(
+            get_model_dir(model=model, art_path=self._path), resolved_step
+        )
+
+        # Check whether the checkpoint exists in the original location
+        if os.path.exists(original_checkpoint_dir):
+            if local_path is not None:
+                # Copy from the original location to the custom location
+                target_checkpoint_dir = os.path.join(local_path, f"{resolved_step:04d}")
+                if os.path.exists(target_checkpoint_dir):
+                    if verbose:
+                        print(
+                            f"Checkpoint already exists at target location: {target_checkpoint_dir}"
+                        )
+                    return target_checkpoint_dir
+                else:
+                    if verbose:
+                        print(
+                            f"Copying checkpoint from {original_checkpoint_dir} to {target_checkpoint_dir}..."
+                        )
+                    import shutil
+
+                    os.makedirs(os.path.dirname(target_checkpoint_dir), exist_ok=True)
+                    shutil.copytree(original_checkpoint_dir, target_checkpoint_dir)
+                    if verbose:
+                        print("✓ Checkpoint copied successfully")
+                    return target_checkpoint_dir
+            else:
+                # No custom location, return the original
+                if verbose:
+                    print(
+                        f"Checkpoint step {resolved_step} exists at {original_checkpoint_dir}"
+                    )
+                return original_checkpoint_dir
+        else:
+            # Checkpoint doesn't exist in the original location, try S3
+            if s3_bucket is None:
+                raise FileNotFoundError(
+                    f"Checkpoint not found at {original_checkpoint_dir} and no S3 bucket specified"
+                )
+            # Pull from S3
+            if verbose:
+                print(f"Pulling checkpoint step {resolved_step} from S3...")
+
+            if local_path is not None:
+                # Pull to the custom location, then copy to the flat structure:
+                # first pull into the default structure
+                await pull_model_from_s3(
+                    model_name=model.name,
+                    project=model.project,
+                    step=resolved_step,
+                    s3_bucket=s3_bucket,
+                    prefix=prefix,
+                    verbose=verbose,
+                    art_path=self._path,
+                    exclude=["logs", "trajectories"],
+                )
+                # Now copy to the custom flat location
+                target_checkpoint_dir = os.path.join(local_path, f"{resolved_step:04d}")
+                if verbose:
+                    print(
+                        f"Copying checkpoint from {original_checkpoint_dir} to {target_checkpoint_dir}..."
+                    )
+                import shutil
+
+                os.makedirs(os.path.dirname(target_checkpoint_dir), exist_ok=True)
+                shutil.copytree(original_checkpoint_dir, target_checkpoint_dir)
+                if verbose:
+                    print("✓ Checkpoint copied to custom location")
+                return target_checkpoint_dir
+            else:
+                # Pull to the default location
+                await pull_model_from_s3(
+                    model_name=model.name,
+                    project=model.project,
+                    step=resolved_step,
+                    s3_bucket=s3_bucket,
+                    prefix=prefix,
+                    verbose=verbose,
+                    art_path=self._path,
+                    exclude=["logs", "trajectories"],
+                )
+                return original_checkpoint_dir
+
     async def _experimental_pull_from_s3(
         self,
         model: Model,
@@ -950,23 +1092,53 @@ async def _experimental_deploy(
         s3_bucket: str | None = None,
         prefix: str | None = None,
         verbose: bool = False,
-        pull_s3: bool = True,
+        pull_checkpoint: bool = True,
         wait_for_completion: bool = True,
     ) -> LoRADeploymentJob:
         """
         Deploy the model's latest checkpoint to a hosted inference endpoint.
 
         Together is currently the only supported provider. See the link for supported base models:
         https://docs.together.ai/docs/lora-inference#supported-base-models
+
+        Args:
+            deploy_to: The deployment provider.
+            model: The model to deploy.
+            step: The checkpoint step to deploy. If None, deploys the latest.
+            s3_bucket: S3 bucket for checkpoint storage and the presigned URL.
+            prefix: S3 prefix.
+            verbose: Whether to print verbose output.
+            pull_checkpoint: Whether to pull the checkpoint first (from S3 if needed).
+            wait_for_completion: Whether to wait for the deployment to complete.
         """
+        # Step 1: Pull the checkpoint to a local path if needed
+        if pull_checkpoint:
+            checkpoint_path = await self._experimental_pull_model_checkpoint(
+                model,
+                step=step,
+                s3_bucket=s3_bucket,
+                prefix=prefix,
+                verbose=verbose,
+            )
+            # Extract the step from the checkpoint path if not provided
+            if step is None:
+                step = int(os.path.basename(checkpoint_path))
+        else:
+            # The checkpoint should already exist locally
+            if step is None:
+                step = get_model_step(model, self._path)
+            checkpoint_path = get_step_checkpoint_dir(
+                get_model_dir(model=model, art_path=self._path), step
+            )
+
+        # Step 2: Deploy from the local checkpoint
         return await deploy_model(
             deploy_to=deploy_to,
             model=model,
+            checkpoint_path=checkpoint_path,
             step=step,
             s3_bucket=s3_bucket,
             prefix=prefix,
             verbose=verbose,
-            pull_s3=pull_s3,
             wait_for_completion=wait_for_completion,
-            art_path=self._path,
         )
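Usage note (not part of this commit): with pull_checkpoint=True (the default), the deploy path no longer needs the art_path or pull_s3 arguments. A rough sketch of a call against a LocalBackend; `backend` and `model` are assumed to exist and the bucket name is illustrative:

from art.utils.deploy_model import LoRADeploymentProvider

# Hedged sketch under the assumptions stated above.
async def deploy_latest(backend, model):
    job = await backend._experimental_deploy(
        deploy_to=LoRADeploymentProvider.TOGETHER,
        model=model,
        step=None,                     # resolve the latest checkpoint
        s3_bucket="my-backup-bucket",  # hypothetical; used if the checkpoint is not local
        verbose=True,
        pull_checkpoint=True,          # new flag replacing pull_s3
        wait_for_completion=True,
    )
    if job.status == "Failed":
        raise RuntimeError("LoRA deployment failed")
    return job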
