Skip to content

Commit 66c82f0

Browse files
committed
refactor
1 parent 8f7bb5a commit 66c82f0

File tree

8 files changed

+96
-261
lines changed

8 files changed

+96
-261
lines changed

examples/tic_tac_toe/tic-tac-toe.py

Lines changed: 17 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -8,6 +8,7 @@
88
from rollout import TicTacToeScenario, rollout
99

1010
import art
11+
from art.utils.deployment import TogetherDeploymentConfig, deploy_model
1112
from art.utils.strip_logprobs import strip_logprobs
1213

1314
load_dotenv()
@@ -75,18 +76,27 @@ async def main():
7576

7677
if DEPLOY_MODEL:
7778
print("deploying")
78-
deployment_result = await backend._experimental_deploy(
79-
deploy_to="together",
79+
# Pull checkpoint (already local since we just trained, but ensures correct path)
80+
checkpoint_path = await backend._experimental_pull_model_checkpoint(
81+
model,
82+
step=STEP,
83+
verbose=True,
84+
)
85+
86+
# Deploy to Together
87+
deployment_result = await deploy_model(
8088
model=model,
89+
checkpoint_path=checkpoint_path,
8190
step=STEP,
91+
provider="together",
92+
config=TogetherDeploymentConfig(
93+
s3_bucket=os.environ.get("BACKUP_BUCKET"),
94+
wait_for_completion=True,
95+
),
8296
verbose=True,
83-
pull_s3=False,
84-
wait_for_completion=True,
8597
)
86-
if deployment_result.status == "Failed":
87-
raise Exception(f"Deployment failed: {deployment_result.failure_reason}")
8898

89-
deployed_model_name = deployment_result.model_name
99+
deployed_model_name = deployment_result.inference_model_name
90100

91101
lora_model = art.Model(
92102
name=deployed_model_name,

examples/tic_tac_toe_self_play/deploy_step.py

Lines changed: 18 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -6,6 +6,7 @@
66
from train import BASE_MODEL, CLUSTER_NAME, MODEL_NAME, PROJECT_NAME
77

88
import art
9+
from art.utils.deployment import TogetherDeploymentConfig, deploy_model
910

1011

1112
async def deploy_step():
@@ -44,18 +45,28 @@ async def deploy_step():
4445

4546
backend = LocalBackend()
4647

47-
deployment_result = await backend._experimental_deploy(
48-
deploy_to="together",
48+
# Pull checkpoint from S3
49+
checkpoint_path = await backend._experimental_pull_model_checkpoint(
50+
model,
51+
step=args.step,
52+
s3_bucket=os.environ.get("BACKUP_BUCKET"),
53+
verbose=True,
54+
)
55+
56+
# Deploy to Together
57+
deployment_result = await deploy_model(
4958
model=model,
59+
checkpoint_path=checkpoint_path,
5060
step=args.step,
61+
provider="together",
62+
config=TogetherDeploymentConfig(
63+
s3_bucket=os.environ.get("BACKUP_BUCKET"),
64+
wait_for_completion=True,
65+
),
5166
verbose=True,
52-
pull_s3=True,
53-
wait_for_completion=True,
5467
)
55-
if deployment_result.status == "Failed":
56-
raise Exception(f"Deployment failed: {deployment_result.failure_reason}")
5768

58-
deployed_model_name = deployment_result.model_name
69+
deployed_model_name = deployment_result.inference_model_name
5970

6071
lora_model = art.Model(
6172
name=deployed_model_name,

src/art/backend.py

Lines changed: 0 additions & 37 deletions
Original file line number · Diff line number · Diff line change
@@ -237,40 +237,3 @@ async def _experimental_fork_checkpoint(
237237
timeout=600,
238238
)
239239
response.raise_for_status()
240-
241-
@log_http_errors
242-
async def _experimental_deploy(
243-
self,
244-
provider: Provider,
245-
model: "TrainableModel",
246-
step: int | None = None,
247-
config: TogetherDeploymentConfig | WandbDeploymentConfig | None = None,
248-
verbose: bool = False,
249-
pull_checkpoint: bool = True,
250-
) -> DeploymentResult:
251-
"""
252-
Deploy the model's latest checkpoint to a hosted inference endpoint.
253-
254-
Args:
255-
provider: The deployment provider ("together" or "wandb").
256-
model: The model to deploy.
257-
step: The checkpoint step to deploy. If None, deploys latest.
258-
config: Provider-specific deployment configuration.
259-
verbose: Whether to print verbose output.
260-
pull_checkpoint: Whether to pull the checkpoint first.
261-
"""
262-
response = await self._client.post(
263-
"/_experimental_deploy",
264-
json={
265-
"provider": provider,
266-
"model": model.safe_model_dump(),
267-
"step": step,
268-
"config": config.model_dump() if config else None,
269-
"config_type": type(config).__name__ if config else None,
270-
"verbose": verbose,
271-
"pull_checkpoint": pull_checkpoint,
272-
},
273-
timeout=600,
274-
)
275-
response.raise_for_status()
276-
return DeploymentResult(**response.json())

src/art/cli.py

Lines changed: 0 additions & 43 deletions
Original file line number · Diff line number · Diff line change
@@ -114,47 +114,4 @@ async def _experimental_pull_from_s3(
114114
delete=delete,
115115
)
116116

117-
@app.post("/_experimental_push_to_s3")
118-
async def _experimental_push_to_s3(
119-
model: Model = Body(...),
120-
s3_bucket: str | None = Body(None),
121-
prefix: str | None = Body(None),
122-
verbose: bool = Body(False),
123-
delete: bool = Body(False),
124-
):
125-
await backend._experimental_push_to_s3(
126-
model=model,
127-
s3_bucket=s3_bucket,
128-
prefix=prefix,
129-
verbose=verbose,
130-
delete=delete,
131-
)
132-
133-
@app.post("/_experimental_deploy")
134-
async def _experimental_deploy(
135-
provider: Provider = Body(...),
136-
model: TrainableModel = Body(...),
137-
step: int | None = Body(None),
138-
config: dict | None = Body(None),
139-
config_type: str | None = Body(None),
140-
verbose: bool = Body(False),
141-
pull_checkpoint: bool = Body(True),
142-
):
143-
# Reconstruct config object from serialized data
144-
parsed_config = None
145-
if config is not None and config_type is not None:
146-
if config_type == "TogetherDeploymentConfig":
147-
parsed_config = TogetherDeploymentConfig(**config)
148-
elif config_type == "WandbDeploymentConfig":
149-
parsed_config = WandbDeploymentConfig(**config)
150-
151-
return await backend._experimental_deploy(
152-
provider=provider,
153-
model=model,
154-
step=step,
155-
config=parsed_config,
156-
verbose=verbose,
157-
pull_checkpoint=pull_checkpoint,
158-
)
159-
160117
uvicorn.run(app, host=host, port=port, loop="asyncio")

src/art/local/backend.py

Lines changed: 35 additions & 131 deletions
Original file line number · Diff line number · Diff line change
@@ -774,85 +774,49 @@ async def _experimental_pull_model_checkpoint(
774774
get_model_dir(model=model, art_path=self._path), resolved_step
775775
)
776776

777-
# Check if checkpoint exists in original location
778-
if os.path.exists(original_checkpoint_dir):
779-
if local_path is not None:
780-
# Copy from original location to custom location
781-
target_checkpoint_dir = os.path.join(local_path, f"{resolved_step:04d}")
782-
if os.path.exists(target_checkpoint_dir):
783-
if verbose:
784-
print(
785-
f"Checkpoint already exists at target location: {target_checkpoint_dir}"
786-
)
787-
return target_checkpoint_dir
788-
else:
789-
if verbose:
790-
print(
791-
f"Copying checkpoint from {original_checkpoint_dir} to {target_checkpoint_dir}..."
792-
)
793-
import shutil
794-
795-
os.makedirs(os.path.dirname(target_checkpoint_dir), exist_ok=True)
796-
shutil.copytree(original_checkpoint_dir, target_checkpoint_dir)
797-
if verbose:
798-
print(f"✓ Checkpoint copied successfully")
799-
return target_checkpoint_dir
800-
else:
801-
# No custom location, return original
802-
if verbose:
803-
print(
804-
f"Checkpoint step {resolved_step} exists at {original_checkpoint_dir}"
805-
)
806-
return original_checkpoint_dir
807-
else:
808-
# Checkpoint doesn't exist in original location, try S3
777+
# Step 1: Ensure checkpoint exists at original_checkpoint_dir
778+
if not os.path.exists(original_checkpoint_dir):
809779
if s3_bucket is None:
810780
raise FileNotFoundError(
811781
f"Checkpoint not found at {original_checkpoint_dir} and no S3 bucket specified"
812782
)
813-
# Pull from S3
814783
if verbose:
815784
print(f"Pulling checkpoint step {resolved_step} from S3...")
816-
817-
if local_path is not None:
818-
# Pull to custom location, then copy to flat structure
819-
# First pull to default structure
820-
await pull_model_from_s3(
821-
model_name=model.name,
822-
project=model.project,
823-
step=resolved_step,
824-
s3_bucket=s3_bucket,
825-
prefix=prefix,
826-
verbose=verbose,
827-
art_path=self._path,
828-
exclude=["logs", "trajectories"],
785+
await pull_model_from_s3(
786+
model_name=model.name,
787+
project=model.project,
788+
step=resolved_step,
789+
s3_bucket=s3_bucket,
790+
prefix=prefix,
791+
verbose=verbose,
792+
art_path=self._path,
793+
exclude=["logs", "trajectories"],
794+
)
795+
# Validate that the checkpoint was actually downloaded
796+
if not os.path.exists(original_checkpoint_dir) or not os.listdir(
797+
original_checkpoint_dir
798+
):
799+
raise FileNotFoundError(f"Checkpoint step {resolved_step} not found")
800+
801+
# Step 2: Handle local_path if provided
802+
if local_path is not None:
803+
if verbose:
804+
print(
805+
f"Copying checkpoint from {original_checkpoint_dir} to {local_path}..."
829806
)
830-
# Now copy to custom flat location
831-
target_checkpoint_dir = os.path.join(local_path, f"{resolved_step:04d}")
832-
if verbose:
833-
print(
834-
f"Copying checkpoint from {original_checkpoint_dir} to {target_checkpoint_dir}..."
835-
)
836-
import shutil
807+
import shutil
837808

838-
os.makedirs(os.path.dirname(target_checkpoint_dir), exist_ok=True)
839-
shutil.copytree(original_checkpoint_dir, target_checkpoint_dir)
840-
if verbose:
841-
print(f"✓ Checkpoint copied to custom location")
842-
return target_checkpoint_dir
843-
else:
844-
# Pull to default location
845-
await pull_model_from_s3(
846-
model_name=model.name,
847-
project=model.project,
848-
step=resolved_step,
849-
s3_bucket=s3_bucket,
850-
prefix=prefix,
851-
verbose=verbose,
852-
art_path=self._path,
853-
exclude=["logs", "trajectories"],
854-
)
855-
return original_checkpoint_dir
809+
os.makedirs(local_path, exist_ok=True)
810+
shutil.copytree(original_checkpoint_dir, local_path, dirs_exist_ok=True)
811+
if verbose:
812+
print(f"✓ Checkpoint copied successfully")
813+
return local_path
814+
815+
if verbose:
816+
print(
817+
f"Checkpoint step {resolved_step} exists at {original_checkpoint_dir}"
818+
)
819+
return original_checkpoint_dir
856820

857821
async def _experimental_pull_from_s3(
858822
self,
@@ -1129,63 +1093,3 @@ async def _experimental_fork_checkpoint(
11291093
print(
11301094
f"Successfully forked checkpoint from {from_model} (step {selected_step}) to {model.name}"
11311095
)
1132-
1133-
async def _experimental_deploy(
1134-
self,
1135-
provider: Provider,
1136-
model: "TrainableModel",
1137-
step: int | None = None,
1138-
config: TogetherDeploymentConfig | WandbDeploymentConfig | None = None,
1139-
verbose: bool = False,
1140-
pull_checkpoint: bool = True,
1141-
) -> DeploymentResult:
1142-
"""
1143-
Deploy the model's latest checkpoint to a hosted inference endpoint.
1144-
1145-
Args:
1146-
provider: The deployment provider ("together" or "wandb").
1147-
model: The model to deploy.
1148-
step: The checkpoint step to deploy. If None, deploys latest.
1149-
config: Provider-specific deployment configuration.
1150-
- For "together": TogetherDeploymentConfig (required)
1151-
- For "wandb": WandbDeploymentConfig (optional)
1152-
verbose: Whether to print verbose output.
1153-
pull_checkpoint: Whether to pull the checkpoint first.
1154-
"""
1155-
# Step 1: Pull checkpoint to local path if needed
1156-
if pull_checkpoint:
1157-
s3_bucket = (
1158-
config.s3_bucket
1159-
if isinstance(config, TogetherDeploymentConfig)
1160-
else None
1161-
)
1162-
prefix = (
1163-
config.prefix if isinstance(config, TogetherDeploymentConfig) else None
1164-
)
1165-
checkpoint_path = await self._experimental_pull_model_checkpoint(
1166-
model,
1167-
step=step,
1168-
s3_bucket=s3_bucket,
1169-
prefix=prefix,
1170-
verbose=verbose,
1171-
)
1172-
# Extract step from checkpoint path if not provided
1173-
if step is None:
1174-
step = int(os.path.basename(checkpoint_path))
1175-
else:
1176-
# Checkpoint should already exist locally
1177-
if step is None:
1178-
step = get_model_step(model, self._path)
1179-
checkpoint_path = get_step_checkpoint_dir(
1180-
get_model_dir(model=model, art_path=self._path), step
1181-
)
1182-
1183-
# Step 2: Deploy from local checkpoint
1184-
return await deploy_model(
1185-
model=model,
1186-
checkpoint_path=checkpoint_path,
1187-
step=step,
1188-
provider=provider,
1189-
config=config,
1190-
verbose=verbose,
1191-
)

0 commit comments

Comments (0)