25 changes: 18 additions & 7 deletions examples/tic_tac_toe/tic-tac-toe.py
@@ -8,6 +8,7 @@
 from rollout import TicTacToeScenario, rollout
 
 import art
+from art.utils.deployment import TogetherDeploymentConfig, deploy_model
 from art.utils.strip_logprobs import strip_logprobs
 
 load_dotenv()
@@ -75,18 +76,28 @@ async def main():
 
     if DEPLOY_MODEL:
         print("deploying")
-        deployment_result = await backend._experimental_deploy(
-            deploy_to="together",
-            model=model,
-            step=STEP,
-            s3_bucket=os.environ.get("BACKUP_BUCKET"),
-            verbose=True,
-            pull_s3=False,
-            wait_for_completion=True,
-        )
+        # Pull checkpoint (already local since we just trained, but ensures correct path)
+        checkpoint_path = await backend._experimental_pull_model_checkpoint(
+            model,
+            step=STEP,
+            s3_bucket=os.environ.get("BACKUP_BUCKET"),
+            verbose=True,
+        )
+
+        # Deploy to Together
+        deployment_result = await deploy_model(
+            model=model,
+            checkpoint_path=checkpoint_path,
+            step=STEP,
+            provider="together",
+            config=TogetherDeploymentConfig(
+                s3_bucket=os.environ.get("BACKUP_BUCKET"),
+                wait_for_completion=True,
+            ),
+            verbose=True,
+        )
         if deployment_result.status == "Failed":
            raise Exception(f"Deployment failed: {deployment_result.failure_reason}")
 
-        deployed_model_name = deployment_result.model_name
+        deployed_model_name = deployment_result.inference_model_name
 
         lora_model = art.Model(
             name=deployed_model_name,
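The change above replaces the single `backend._experimental_deploy(...)` call with an explicit two-step flow: pull the checkpoint to local disk, then hand it to a provider-specific deployment helper (deploy_step.py below applies the same pattern). A minimal consolidated sketch, assuming `backend` and `model` are configured as in this example; the helper name `redeploy` is hypothetical and the signatures are inferred from the calls in this diff:

import os

from art.utils.deployment import TogetherDeploymentConfig, deploy_model


async def redeploy(backend, model, step: int) -> str:
    # Step 1: ensure the checkpoint for `step` is on local disk,
    # pulling it from the S3 backup bucket if it is not already there.
    checkpoint_path = await backend._experimental_pull_model_checkpoint(
        model,
        step=step,
        s3_bucket=os.environ.get("BACKUP_BUCKET"),
        verbose=True,
    )

    # Step 2: upload the local checkpoint to Together and block until
    # the deployment finishes (wait_for_completion=True).
    result = await deploy_model(
        model=model,
        checkpoint_path=checkpoint_path,
        step=step,
        provider="together",
        config=TogetherDeploymentConfig(
            s3_bucket=os.environ.get("BACKUP_BUCKET"),
            wait_for_completion=True,
        ),
        verbose=True,
    )

    if result.status == "Failed":
        raise RuntimeError(f"Deployment failed: {result.failure_reason}")
    # Note the rename: the served name is now `inference_model_name`, not `model_name`.
    return result.inference_model_name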
25 changes: 18 additions & 7 deletions examples/tic_tac_toe_self_play/deploy_step.py
@@ -6,6 +6,7 @@
 from train import BASE_MODEL, CLUSTER_NAME, MODEL_NAME, PROJECT_NAME
 
 import art
+from art.utils.deployment import TogetherDeploymentConfig, deploy_model
 
 
 async def deploy_step():
@@ -44,18 +45,28 @@ async def deploy_step():
 
     backend = LocalBackend()
 
-    deployment_result = await backend._experimental_deploy(
-        deploy_to="together",
-        model=model,
-        step=args.step,
-        s3_bucket=os.environ.get("BACKUP_BUCKET"),
-        verbose=True,
-        pull_s3=True,
-        wait_for_completion=True,
-    )
+    # Pull checkpoint from S3
+    checkpoint_path = await backend._experimental_pull_model_checkpoint(
+        model,
+        step=args.step,
+        s3_bucket=os.environ.get("BACKUP_BUCKET"),
+        verbose=True,
+    )
+
+    # Deploy to Together
+    deployment_result = await deploy_model(
+        model=model,
+        checkpoint_path=checkpoint_path,
+        step=args.step,
+        provider="together",
+        config=TogetherDeploymentConfig(
+            s3_bucket=os.environ.get("BACKUP_BUCKET"),
+            wait_for_completion=True,
+        ),
+        verbose=True,
+    )
     if deployment_result.status == "Failed":
         raise Exception(f"Deployment failed: {deployment_result.failure_reason}")
 
-    deployed_model_name = deployment_result.model_name
+    deployed_model_name = deployment_result.inference_model_name
 
     lora_model = art.Model(
         name=deployed_model_name,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -30,7 +30,7 @@ backend = [
     "setproctitle>=1.3.6",
     "tblib>=3.0.0",
     "setuptools>=78.1.0",
-    "wandb==0.21.0",
+    "wandb==0.22.1",
     "polars>=1.26.0",
     "transformers==4.53.2",
     "trl==0.20.0",
25 changes: 14 additions & 11 deletions scripts/deploy-model.py
@@ -5,8 +5,9 @@
 from dotenv import load_dotenv
 
 import art
-from art.utils.deploy_model import deploy_model
+from art.utils.deployment import TogetherDeploymentConfig, deploy_model
 from art.utils.get_model_step import get_model_step
+from art.utils.output_dirs import get_model_dir, get_step_checkpoint_dir
 from art.utils.s3 import pull_model_from_s3
 
 load_dotenv()
@@ -86,23 +87,25 @@ async def deploy() -> None:
         f"using checkpoints from s3://{backup_bucket}…"
     )
 
+    # Construct the checkpoint path from the pulled model
+    checkpoint_path = get_step_checkpoint_dir(
+        get_model_dir(model=model, art_path=args.art_path), step
+    )
+
     deployment_result = await deploy_model(
-        deploy_to="together",
         model=model,
+        checkpoint_path=checkpoint_path,
         step=step,
+        provider="together",
+        config=TogetherDeploymentConfig(
+            s3_bucket=backup_bucket,
+            wait_for_completion=True,
+        ),
         verbose=True,
-        pull_s3=False,
-        wait_for_completion=True,
-        art_path=args.art_path,
     )
 
     if deployment_result.status == "Failed":
         raise RuntimeError(f"Deployment failed: {deployment_result.failure_reason}")
 
     print("Deployment successful! ✨")
-    print(
-        f"Model deployed at Together under name: {deployment_result.model_name} (job_id={deployment_result.job_id})"
-    )
+    print(f"Model deployed under name: {deployment_result.inference_model_name}")
 
 
 if __name__ == "__main__":
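The script now reports only `inference_model_name`, the identifier inference clients should use (the old `job_id` print is gone). As a hedged follow-up sketch that is not part of this PR: Together exposes an OpenAI-compatible API, so the deployed LoRA can be smoke-tested with a standard client, assuming TOGETHER_API_KEY is set:

import asyncio
import os

from openai import AsyncOpenAI


async def smoke_test(deployed_name: str) -> None:
    # Together's OpenAI-compatible endpoint serves the deployed adapter by name.
    client = AsyncOpenAI(
        base_url="https://api.together.xyz/v1",
        api_key=os.environ["TOGETHER_API_KEY"],
    )
    chat = await client.chat.completions.create(
        model=deployed_name,
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(chat.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(smoke_test("<inference_model_name printed by the script>"))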
51 changes: 15 additions & 36 deletions src/art/backend.py
@@ -1,11 +1,17 @@
 import json
+import warnings
 from typing import TYPE_CHECKING, AsyncIterator, Literal
 
 import httpx
 from tqdm import auto as tqdm
 
 from art.utils import log_http_errors
-from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider
+from art.utils.deployment import (
+    DeploymentResult,
+    Provider,
+    TogetherDeploymentConfig,
+    WandbDeploymentConfig,
+)
 
 from . import dev
 from .trajectories import TrajectoryGroup
@@ -143,10 +149,18 @@ async def _experimental_pull_from_s3(
     ) -> None:
         """Download the model directory from S3 into file system where the LocalBackend is running. Right now this can be used to pull trajectory logs for processing or model checkpoints.
 
+        .. deprecated::
+            This method is deprecated. Use `_experimental_pull_model_checkpoint` instead.
+
         Args:
             only_step: If specified, only pull this specific step. Can be an int for a specific step,
                 or "latest" to pull only the latest checkpoint. If None, pulls all steps.
         """
+        warnings.warn(
+            "_experimental_pull_from_s3 is deprecated. Use _experimental_pull_model_checkpoint instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         response = await self._client.post(
             "/_experimental_pull_from_s3",
             json={
@@ -223,38 +237,3 @@ async def _experimental_fork_checkpoint(
             timeout=600,
         )
         response.raise_for_status()
-
-    @log_http_errors
-    async def _experimental_deploy(
-        self,
-        deploy_to: LoRADeploymentProvider,
-        model: "TrainableModel",
-        step: int | None = None,
-        s3_bucket: str | None = None,
-        prefix: str | None = None,
-        verbose: bool = False,
-        pull_s3: bool = True,
-        wait_for_completion: bool = True,
-    ) -> LoRADeploymentJob:
-        """
-        Deploy the model's latest checkpoint to a hosted inference endpoint.
-
-        Together is currently the only supported provider. See link for supported base models:
-        https://docs.together.ai/docs/lora-inference#supported-base-models
-        """
-        response = await self._client.post(
-            "/_experimental_deploy",
-            json={
-                "deploy_to": deploy_to,
-                "model": model.safe_model_dump(),
-                "step": step,
-                "s3_bucket": s3_bucket,
-                "prefix": prefix,
-                "verbose": verbose,
-                "pull_s3": pull_s3,
-                "wait_for_completion": wait_for_completion,
-            },
-            timeout=600,
-        )
-        response.raise_for_status()
-        return LoRADeploymentJob(**response.json())
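The deprecation shim added to `_experimental_pull_from_s3` uses the standard library pattern; `stacklevel=2` attributes the warning to the caller's line rather than to backend.py, so users can see which call site needs migrating. A generic sketch of the same pattern, with hypothetical names:

import warnings


def old_api(*args, **kwargs):
    """Deprecated: use `new_api` instead."""
    warnings.warn(
        "old_api is deprecated. Use new_api instead.",
        DeprecationWarning,
        stacklevel=2,  # point the warning at the caller, not this shim
    )
    return new_api(*args, **kwargs)


def new_api(*args, **kwargs):
    return None  # the replacement implementation lives here

Note that Python filters DeprecationWarning out by default except in code run directly under __main__; run with -W default (or a warnings.simplefilter call) to surface it from library code.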
28 changes: 5 additions & 23 deletions src/art/cli.py
@@ -15,7 +15,11 @@
 from .model import Model, TrainableModel
 from .trajectories import TrajectoryGroup
 from .types import TrainConfig
-from .utils.deploy_model import LoRADeploymentProvider
+from .utils.deployment import (
+    Provider,
+    TogetherDeploymentConfig,
+    WandbDeploymentConfig,
+)
 
 load_dotenv()
 
@@ -126,26 +130,4 @@ async def _experimental_push_to_s3(
             delete=delete,
         )
 
-    @app.post("/_experimental_deploy")
-    async def _experimental_deploy(
-        deploy_to: LoRADeploymentProvider = Body(...),
-        model: TrainableModel = Body(...),
-        step: int | None = Body(None),
-        s3_bucket: str | None = Body(None),
-        prefix: str | None = Body(None),
-        verbose: bool = Body(False),
-        pull_s3: bool = Body(True),
-        wait_for_completion: bool = Body(True),
-    ):
-        return await backend._experimental_deploy(
-            deploy_to=deploy_to,
-            model=model,
-            step=step,
-            s3_bucket=s3_bucket,
-            prefix=prefix,
-            verbose=verbose,
-            pull_s3=pull_s3,
-            wait_for_completion=wait_for_completion,
-        )
-
     uvicorn.run(app, host=host, port=port, loop="asyncio")