diff --git a/examples/tic_tac_toe/tic-tac-toe.py b/examples/tic_tac_toe/tic-tac-toe.py index f456d350..72d3b759 100644 --- a/examples/tic_tac_toe/tic-tac-toe.py +++ b/examples/tic_tac_toe/tic-tac-toe.py @@ -8,6 +8,7 @@ from rollout import TicTacToeScenario, rollout import art +from art.utils.deployment import TogetherDeploymentConfig, deploy_model from art.utils.strip_logprobs import strip_logprobs load_dotenv() @@ -75,18 +76,28 @@ async def main(): if DEPLOY_MODEL: print("deploying") - deployment_result = await backend._experimental_deploy( - deploy_to="together", + # Pull checkpoint (already local since we just trained, but ensures correct path) + checkpoint_path = await backend._experimental_pull_model_checkpoint( + model, + step=STEP, + s3_bucket=os.environ.get("BACKUP_BUCKET"), + verbose=True, + ) + + # Deploy to Together + deployment_result = await deploy_model( model=model, + checkpoint_path=checkpoint_path, step=STEP, + provider="together", + config=TogetherDeploymentConfig( + s3_bucket=os.environ.get("BACKUP_BUCKET"), + wait_for_completion=True, + ), verbose=True, - pull_s3=False, - wait_for_completion=True, ) - if deployment_result.status == "Failed": - raise Exception(f"Deployment failed: {deployment_result.failure_reason}") - deployed_model_name = deployment_result.model_name + deployed_model_name = deployment_result.inference_model_name lora_model = art.Model( name=deployed_model_name, diff --git a/examples/tic_tac_toe_self_play/deploy_step.py b/examples/tic_tac_toe_self_play/deploy_step.py index 858c55b4..c9594619 100644 --- a/examples/tic_tac_toe_self_play/deploy_step.py +++ b/examples/tic_tac_toe_self_play/deploy_step.py @@ -6,6 +6,7 @@ from train import BASE_MODEL, CLUSTER_NAME, MODEL_NAME, PROJECT_NAME import art +from art.utils.deployment import TogetherDeploymentConfig, deploy_model async def deploy_step(): @@ -44,18 +45,28 @@ async def deploy_step(): backend = LocalBackend() - deployment_result = await backend._experimental_deploy( - deploy_to="together", + # Pull checkpoint from S3 + checkpoint_path = await backend._experimental_pull_model_checkpoint( + model, + step=args.step, + s3_bucket=os.environ.get("BACKUP_BUCKET"), + verbose=True, + ) + + # Deploy to Together + deployment_result = await deploy_model( model=model, + checkpoint_path=checkpoint_path, step=args.step, + provider="together", + config=TogetherDeploymentConfig( + s3_bucket=os.environ.get("BACKUP_BUCKET"), + wait_for_completion=True, + ), verbose=True, - pull_s3=True, - wait_for_completion=True, ) - if deployment_result.status == "Failed": - raise Exception(f"Deployment failed: {deployment_result.failure_reason}") - deployed_model_name = deployment_result.model_name + deployed_model_name = deployment_result.inference_model_name lora_model = art.Model( name=deployed_model_name, diff --git a/pyproject.toml b/pyproject.toml index 834bc23e..077cb4ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ backend = [ "setproctitle>=1.3.6", "tblib>=3.0.0", "setuptools>=78.1.0", - "wandb==0.21.0", + "wandb==0.22.1", "polars>=1.26.0", "transformers==4.53.2", "trl==0.20.0", diff --git a/scripts/deploy-model.py b/scripts/deploy-model.py index f43c1486..004afb71 100644 --- a/scripts/deploy-model.py +++ b/scripts/deploy-model.py @@ -5,8 +5,9 @@ from dotenv import load_dotenv import art -from art.utils.deploy_model import deploy_model +from art.utils.deployment import TogetherDeploymentConfig, deploy_model from art.utils.get_model_step import get_model_step +from art.utils.output_dirs import get_model_dir, get_step_checkpoint_dir from art.utils.s3 import pull_model_from_s3 load_dotenv() @@ -86,23 +87,25 @@ async def deploy() -> None: f"using checkpoints from s3://{backup_bucket}…" ) + # Construct the checkpoint path from the pulled model + checkpoint_path = get_step_checkpoint_dir( + get_model_dir(model=model, art_path=args.art_path), step + ) + deployment_result = await deploy_model( - deploy_to="together", model=model, + checkpoint_path=checkpoint_path, step=step, + provider="together", + config=TogetherDeploymentConfig( + s3_bucket=backup_bucket, + wait_for_completion=True, + ), verbose=True, - pull_s3=False, - wait_for_completion=True, - art_path=args.art_path, ) - if deployment_result.status == "Failed": - raise RuntimeError(f"Deployment failed: {deployment_result.failure_reason}") - print("Deployment successful! ✨") - print( - f"Model deployed at Together under name: {deployment_result.model_name} (job_id={deployment_result.job_id})" - ) + print(f"Model deployed under name: {deployment_result.inference_model_name}") if __name__ == "__main__": diff --git a/src/art/backend.py b/src/art/backend.py index 9fa95c0e..f762a098 100644 --- a/src/art/backend.py +++ b/src/art/backend.py @@ -1,11 +1,17 @@ import json +import warnings from typing import TYPE_CHECKING, AsyncIterator, Literal import httpx from tqdm import auto as tqdm from art.utils import log_http_errors -from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider +from art.utils.deployment import ( + DeploymentResult, + Provider, + TogetherDeploymentConfig, + WandbDeploymentConfig, +) from . import dev from .trajectories import TrajectoryGroup @@ -143,10 +149,18 @@ async def _experimental_pull_from_s3( ) -> None: """Download the model directory from S3 into file system where the LocalBackend is running. Right now this can be used to pull trajectory logs for processing or model checkpoints. + .. deprecated:: + This method is deprecated. Use `_experimental_pull_model_checkpoint` instead. + Args: only_step: If specified, only pull this specific step. Can be an int for a specific step, or "latest" to pull only the latest checkpoint. If None, pulls all steps. """ + warnings.warn( + "_experimental_pull_from_s3 is deprecated. Use _experimental_pull_model_checkpoint instead.", + DeprecationWarning, + stacklevel=2, + ) response = await self._client.post( "/_experimental_pull_from_s3", json={ @@ -223,38 +237,3 @@ async def _experimental_fork_checkpoint( timeout=600, ) response.raise_for_status() - - @log_http_errors - async def _experimental_deploy( - self, - deploy_to: LoRADeploymentProvider, - model: "TrainableModel", - step: int | None = None, - s3_bucket: str | None = None, - prefix: str | None = None, - verbose: bool = False, - pull_s3: bool = True, - wait_for_completion: bool = True, - ) -> LoRADeploymentJob: - """ - Deploy the model's latest checkpoint to a hosted inference endpoint. - - Together is currently the only supported provider. See link for supported base models: - https://docs.together.ai/docs/lora-inference#supported-base-models - """ - response = await self._client.post( - "/_experimental_deploy", - json={ - "deploy_to": deploy_to, - "model": model.safe_model_dump(), - "step": step, - "s3_bucket": s3_bucket, - "prefix": prefix, - "verbose": verbose, - "pull_s3": pull_s3, - "wait_for_completion": wait_for_completion, - }, - timeout=600, - ) - response.raise_for_status() - return LoRADeploymentJob(**response.json()) diff --git a/src/art/cli.py b/src/art/cli.py index 6c7c4903..ca96d794 100644 --- a/src/art/cli.py +++ b/src/art/cli.py @@ -15,7 +15,11 @@ from .model import Model, TrainableModel from .trajectories import TrajectoryGroup from .types import TrainConfig -from .utils.deploy_model import LoRADeploymentProvider +from .utils.deployment import ( + Provider, + TogetherDeploymentConfig, + WandbDeploymentConfig, +) load_dotenv() @@ -126,26 +130,4 @@ async def _experimental_push_to_s3( delete=delete, ) - @app.post("/_experimental_deploy") - async def _experimental_deploy( - deploy_to: LoRADeploymentProvider = Body(...), - model: TrainableModel = Body(...), - step: int | None = Body(None), - s3_bucket: str | None = Body(None), - prefix: str | None = Body(None), - verbose: bool = Body(False), - pull_s3: bool = Body(True), - wait_for_completion: bool = Body(True), - ): - return await backend._experimental_deploy( - deploy_to=deploy_to, - model=model, - step=step, - s3_bucket=s3_bucket, - prefix=prefix, - verbose=verbose, - pull_s3=pull_s3, - wait_for_completion=wait_for_completion, - ) - uvicorn.run(app, host=host, port=port, loop="asyncio") diff --git a/src/art/local/backend.py b/src/art/local/backend.py index 13a906b4..5a0bc619 100644 --- a/src/art/local/backend.py +++ b/src/art/local/backend.py @@ -3,6 +3,7 @@ import math import os import subprocess +import warnings from datetime import datetime from types import TracebackType from typing import AsyncIterator, Literal, cast @@ -22,9 +23,11 @@ from wandb.sdk.wandb_run import Run from weave.trace.weave_client import WeaveClient -from art.utils.deploy_model import ( - LoRADeploymentJob, - LoRADeploymentProvider, +from art.utils.deployment import ( + DeploymentResult, + Provider, + TogetherDeploymentConfig, + WandbDeploymentConfig, deploy_model, ) from art.utils.old_benchmarking.calculate_step_metrics import calculate_step_std_dev @@ -675,6 +678,146 @@ def _get_wandb_run(self, model: Model) -> Run | None: # Experimental support for S3 # ------------------------------------------------------------------ + async def _experimental_pull_model_checkpoint( + self, + model: "TrainableModel", + *, + step: int | Literal["latest"] | None = None, + local_path: str | None = None, + s3_bucket: str | None = None, + prefix: str | None = None, + verbose: bool = False, + ) -> str: + """Pull a model checkpoint to a local path. + + For LocalBackend, this: + 1. When step is "latest" or None, checks both local storage and S3 (if provided) + to find the latest checkpoint, preferring local if steps are equal + 2. If checkpoint exists locally, uses it (optionally copying to local_path) + 3. If checkpoint doesn't exist locally but s3_bucket is provided, pulls from S3 + 4. Returns the final checkpoint path + + Args: + model: The model to pull checkpoint for. + step: The step to pull. Can be an int for a specific step, + or "latest" to pull the latest checkpoint. If None, pulls latest. + local_path: Custom directory to save/copy the checkpoint to. + If None, returns checkpoint from backend's default art path. + s3_bucket: S3 bucket to check/pull from. When step is "latest", both + local storage and S3 are checked to find the true latest. + prefix: S3 prefix. + verbose: Whether to print verbose output. + + Returns: + Path to the local checkpoint directory. + """ + # Determine which step to use + resolved_step: int + if step is None or step == "latest": + # Check both local storage and S3 (if provided) for the latest checkpoint + local_latest_step: int | None = None + s3_latest_step: int | None = None + + # Get latest from local storage + try: + local_latest_step = get_model_step(model, self._path) + if local_latest_step == 0: + # get_model_step returns 0 if no checkpoints exist + local_latest_step = None + except Exception: + local_latest_step = None + + # Get latest from S3 if bucket provided + if s3_bucket is not None: + from art.utils.s3_checkpoint_utils import ( + get_latest_checkpoint_step_from_s3, + ) + + s3_latest_step = await get_latest_checkpoint_step_from_s3( + model_name=model.name, + project=model.project, + s3_bucket=s3_bucket, + prefix=prefix, + ) + + # Determine which source has the latest checkpoint + if local_latest_step is None and s3_latest_step is None: + raise ValueError( + f"No checkpoints found for {model.project}/{model.name} in local storage or S3" + ) + elif local_latest_step is None: + resolved_step = s3_latest_step # type: ignore[assignment] + if verbose: + print(f"Using latest checkpoint from S3: step {resolved_step}") + elif s3_latest_step is None: + resolved_step = local_latest_step + if verbose: + print( + f"Using latest checkpoint from local storage: step {resolved_step}" + ) + elif local_latest_step >= s3_latest_step: + # Prefer local if equal or greater + resolved_step = local_latest_step + if verbose: + print( + f"Using latest checkpoint from local storage: step {resolved_step} " + ) + else: + resolved_step = s3_latest_step + if verbose: + print(f"Using latest checkpoint from S3: step {resolved_step} ") + else: + resolved_step = step + + # Check if checkpoint exists in the original training location + original_checkpoint_dir = get_step_checkpoint_dir( + get_model_dir(model=model, art_path=self._path), resolved_step + ) + + # Step 1: Ensure checkpoint exists at original_checkpoint_dir + if not os.path.exists(original_checkpoint_dir): + if s3_bucket is None: + raise FileNotFoundError( + f"Checkpoint not found at {original_checkpoint_dir} and no S3 bucket specified" + ) + if verbose: + print(f"Pulling checkpoint step {resolved_step} from S3...") + await pull_model_from_s3( + model_name=model.name, + project=model.project, + step=resolved_step, + s3_bucket=s3_bucket, + prefix=prefix, + verbose=verbose, + art_path=self._path, + exclude=["logs", "trajectories"], + ) + # Validate that the checkpoint was actually downloaded + if not os.path.exists(original_checkpoint_dir) or not os.listdir( + original_checkpoint_dir + ): + raise FileNotFoundError(f"Checkpoint step {resolved_step} not found") + + # Step 2: Handle local_path if provided + if local_path is not None: + if verbose: + print( + f"Copying checkpoint from {original_checkpoint_dir} to {local_path}..." + ) + import shutil + + os.makedirs(local_path, exist_ok=True) + shutil.copytree(original_checkpoint_dir, local_path, dirs_exist_ok=True) + if verbose: + print(f"✓ Checkpoint copied successfully") + return local_path + + if verbose: + print( + f"Checkpoint step {resolved_step} exists at {original_checkpoint_dir}" + ) + return original_checkpoint_dir + async def _experimental_pull_from_s3( self, model: Model, @@ -690,6 +833,10 @@ async def _experimental_pull_from_s3( latest_only: bool = False, ) -> None: """Download the model directory from S3 into local Backend storage. Right now this can be used to pull trajectory logs for processing or model checkpoints. + + .. deprecated:: + This method is deprecated. Use `_experimental_pull_model_checkpoint` instead. + Args: model: The model to pull from S3. step: DEPRECATED. Use only_step instead. @@ -702,6 +849,11 @@ async def _experimental_pull_from_s3( only_step: If specified, only pull this specific step. Can be an int for a specific step, or "latest" to pull only the latest checkpoint. If None, pulls all steps. """ + warnings.warn( + "_experimental_pull_from_s3 is deprecated. Use _experimental_pull_model_checkpoint instead.", + DeprecationWarning, + stacklevel=2, + ) # Handle backward compatibility and new only_step parameter if only_step is None and latest_only: @@ -941,32 +1093,3 @@ async def _experimental_fork_checkpoint( print( f"Successfully forked checkpoint from {from_model} (step {selected_step}) to {model.name}" ) - - async def _experimental_deploy( - self, - deploy_to: LoRADeploymentProvider, - model: "TrainableModel", - step: int | None = None, - s3_bucket: str | None = None, - prefix: str | None = None, - verbose: bool = False, - pull_s3: bool = True, - wait_for_completion: bool = True, - ) -> LoRADeploymentJob: - """ - Deploy the model's latest checkpoint to a hosted inference endpoint. - - Together is currently the only supported provider. See link for supported base models: - https://docs.together.ai/docs/lora-inference#supported-base-models - """ - return await deploy_model( - deploy_to=deploy_to, - model=model, - step=step, - s3_bucket=s3_bucket, - prefix=prefix, - verbose=verbose, - pull_s3=pull_s3, - wait_for_completion=wait_for_completion, - art_path=self._path, - ) diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py index f7b9f858..23dbca87 100644 --- a/src/art/preprocessing/tokenize.py +++ b/src/art/preprocessing/tokenize.py @@ -2,7 +2,7 @@ import random from dataclasses import dataclass from itertools import takewhile -from typing import Generator, cast +from typing import Any, Generator, cast import torch from PIL import Image @@ -154,7 +154,7 @@ def tokenize_trajectory( return None messages_and_choices = history.messages_and_choices[: last_assistant_index + 1] messages = get_messages(messages_and_choices) - tools = ( + tools: Any = ( [{"type": "function", "function": tool} for tool in history.tools] if history.tools is not None else None diff --git a/src/art/serverless/backend.py b/src/art/serverless/backend.py index bad5dd6a..9fce2753 100644 --- a/src/art/serverless/backend.py +++ b/src/art/serverless/backend.py @@ -1,11 +1,17 @@ import asyncio +import warnings from typing import TYPE_CHECKING, AsyncIterator, Literal from openai._types import NOT_GIVEN from tqdm import auto as tqdm from art.serverless.client import Client, ExperimentalTrainingConfig -from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider +from art.utils.deployment import ( + DeploymentResult, + Provider, + TogetherDeploymentConfig, + WandbDeploymentConfig, +) from .. import dev from ..backend import Backend @@ -54,6 +60,26 @@ async def register( model.id = client_model.id model.entity = client_model.entity + async def delete( + self, + model: "Model", + ) -> None: + """ + Deletes a model from the Backend. + + Args: + model: An art.Model instance to delete. + """ + from art import TrainableModel + + if not isinstance(model, TrainableModel): + print( + "Deleting a non-trainable model from the Serverless backend is not supported." + ) + return + assert model.id is not None, "Model ID is required" + await self._client.models.delete(model_id=model.id) + def _model_inference_name(self, model: "TrainableModel") -> str: assert model.entity is not None, "Model entity is required" return f"wandb-artifact:///{model.entity}/{model.project}/{model.name}" @@ -133,7 +159,9 @@ async def _train_model( epsilon_high=dev_config.get("epsilon_high"), importance_sampling_level=dev_config.get("importance_sampling_level"), learning_rate=config.learning_rate, - max_negative_advantage_importance_sampling_weight=dev_config.get("max_negative_advantage_importance_sampling_weight"), + max_negative_advantage_importance_sampling_weight=dev_config.get( + "max_negative_advantage_importance_sampling_weight" + ), ppo=dev_config.get("ppo"), precalculate_logprobs=dev_config.get("precalculate_logprobs"), scale_rewards=dev_config.get("scale_rewards"), @@ -167,9 +195,93 @@ async def _train_model( after = event.id # ------------------------------------------------------------------ - # Experimental support for S3 + # Experimental support for S3 and checkpoints # ------------------------------------------------------------------ + async def _experimental_pull_model_checkpoint( + self, + model: "TrainableModel", + *, + step: int | Literal["latest"] | None = None, + local_path: str | None = None, + verbose: bool = False, + ) -> str: + """Pull a model checkpoint from W&B artifacts to a local path. + + For ServerlessBackend, this downloads the checkpoint from W&B artifact storage. + + Args: + model: The model to pull checkpoint for. + step: The step to pull. Can be an int for a specific step, + or "latest" to pull the latest checkpoint. If None, pulls latest. + local_path: Local directory to save the checkpoint. If None, uses temporary directory. + verbose: Whether to print verbose output. + + Returns: + Path to the local checkpoint directory. + """ + import os + import tempfile + + import wandb + + assert model.id is not None, "Model ID is required" + + # If entity is not set, use the user's default entity from W&B + api = wandb.Api(api_key=self._client.api_key) + if model.entity is None: + model.entity = api.default_entity + if verbose: + print(f"Using default W&B entity: {model.entity}") + + # Determine which step to use + resolved_step: int + if step is None or step == "latest": + # Get latest checkpoint from API + async for checkpoint in self._client.models.checkpoints.list( + limit=1, order="desc", model_id=model.id + ): + resolved_step = checkpoint.step + break + else: + raise ValueError(f"No checkpoints found for model {model.name}") + else: + resolved_step = step + + if verbose: + print(f"Downloading checkpoint step {resolved_step} from W&B artifacts...") + + # Download from W&B artifacts + # The artifact name follows the pattern: {entity}/{project}/{model_name}:step{step} + artifact_name = ( + f"{model.entity}/{model.project}/{model.name}:step{resolved_step}" + ) + + # Use wandb API to download (api was already created above for entity lookup) + artifact = api.artifact(artifact_name, type="lora") + + # Determine download path + if local_path is None: + # Create a temporary directory that won't be cleaned up automatically + checkpoint_dir = os.path.join( + tempfile.gettempdir(), + "art_checkpoints", + model.project, + model.name, + f"{resolved_step:04d}", + ) + else: + # Custom location - copy directly to local_path + checkpoint_dir = local_path + + # Download artifact + os.makedirs(checkpoint_dir, exist_ok=True) + artifact.download(root=checkpoint_dir) + if verbose: + print(f"Downloaded checkpoint to {checkpoint_dir}") + + return checkpoint_dir + async def _experimental_pull_from_s3( self, model: "Model", @@ -180,6 +292,12 @@ async def _experimental_pull_from_s3( delete: bool = False, only_step: int | Literal["latest"] | None = None, ) -> None: + """Deprecated. Use `_experimental_pull_model_checkpoint` instead.""" + warnings.warn( + "_experimental_pull_from_s3 is deprecated. Use _experimental_pull_model_checkpoint instead.", + DeprecationWarning, + stacklevel=2, + ) raise NotImplementedError async def _experimental_push_to_s3( @@ -204,16 +322,3 @@ async def _experimental_fork_checkpoint( prefix: str | None = None, ) -> None: raise NotImplementedError - - async def _experimental_deploy( - self, - deploy_to: LoRADeploymentProvider, - model: "TrainableModel", - step: int | None = None, - s3_bucket: str | None = None, - prefix: str | None = None, - verbose: bool = False, - pull_s3: bool = True, - wait_for_completion: bool = True, - ) -> LoRADeploymentJob: - raise NotImplementedError diff --git a/src/art/serverless/client.py b/src/art/serverless/client.py index 5c79d732..386f8661 100644 --- a/src/art/serverless/client.py +++ b/src/art/serverless/client.py @@ -120,6 +120,12 @@ async def log( cast_to=type(None), ) + async def delete(self, *, model_id: str) -> None: + return await self._delete( + f"/preview/models/{model_id}", + cast_to=type(None), + ) + @cached_property def checkpoints(self) -> "Checkpoints": return Checkpoints(cast(AsyncOpenAI, self._client)) diff --git a/src/art/skypilot/backend.py b/src/art/skypilot/backend.py index 8d910e7d..68bdfb78 100644 --- a/src/art/skypilot/backend.py +++ b/src/art/skypilot/backend.py @@ -1,7 +1,7 @@ import asyncio import os from importlib.metadata import PackageNotFoundError, version -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Literal, cast import semver import sky @@ -268,6 +268,115 @@ async def _prepare_backend_for_training( return (vllm_base_url, api_key) + async def _experimental_pull_model_checkpoint( + self, + model: "TrainableModel", + *, + step: int | Literal["latest"] | None = None, + local_path: str | None = None, + s3_bucket: str, + prefix: str | None = None, + verbose: bool = False, + ) -> str: + """Pull a model checkpoint to the client machine. + + For SkyPilotBackend, this always pulls from S3 to the client machine + (where the script is running), not to the remote cluster. + + Args: + model: The model to pull checkpoint for. + step: The step to pull. Can be an int for a specific step, + or "latest" to pull the latest checkpoint. If None, pulls latest. + local_path: Local directory to save the checkpoint. If None, uses default paths. + s3_bucket: S3 bucket to pull from. + prefix: S3 prefix. + verbose: Whether to print verbose output. + + Returns: + Path to the local checkpoint directory on the client machine. + """ + import os + + from art.utils.output_dirs import ( + get_default_art_path, + get_model_dir, + get_step_checkpoint_dir, + ) + from art.utils.s3 import pull_model_from_s3 + from art.utils.s3_checkpoint_utils import get_latest_checkpoint_step_from_s3 + + # Determine which step to use + resolved_step: int + if step is None or step == "latest": + # Get latest from S3 + latest_step = await get_latest_checkpoint_step_from_s3( + model_name=model.name, + project=model.project, + s3_bucket=s3_bucket, + prefix=prefix, + ) + if latest_step is None: + raise ValueError( + f"No checkpoints found in S3 for {model.project}/{model.name}" + ) + resolved_step = latest_step + else: + resolved_step = step + + # Determine target location + art_path = get_default_art_path() + if local_path is not None: + # Custom location - copy directly to local_path + checkpoint_dir = local_path + else: + # Standard ART structure + checkpoint_dir = get_step_checkpoint_dir( + get_model_dir(model=model, art_path=art_path), resolved_step + ) + + # Pull from S3 to client machine + if verbose: + print( + f"Pulling checkpoint step {resolved_step} from S3 to client machine..." + ) + + if local_path is not None: + # For custom location, pull to default location first, then copy + await pull_model_from_s3( + model_name=model.name, + project=model.project, + step=resolved_step, + s3_bucket=s3_bucket, + prefix=prefix, + verbose=verbose, + art_path=art_path, + exclude=["logs", "trajectories"], + ) + # Copy to custom location + temp_checkpoint_dir = get_step_checkpoint_dir( + get_model_dir(model=model, art_path=art_path), resolved_step + ) + import shutil + + os.makedirs(local_path, exist_ok=True) + shutil.copytree(temp_checkpoint_dir, local_path, dirs_exist_ok=True) + if verbose: + print(f"✓ Checkpoint copied to {local_path}") + else: + # Pull directly to standard location + await pull_model_from_s3( + model_name=model.name, + project=model.project, + step=resolved_step, + s3_bucket=s3_bucket, + prefix=prefix, + verbose=verbose, + art_path=art_path, + exclude=["logs", "trajectories"], + ) + + return checkpoint_dir + async def down(self) -> None: await to_thread_typed( lambda: sky.stream_and_get(sky.down(cluster_name=self._cluster_name)) diff --git a/src/art/utils/deploy_model.py b/src/art/utils/deploy_model.py index 4be83458..34a5b8af 100644 --- a/src/art/utils/deploy_model.py +++ b/src/art/utils/deploy_model.py @@ -1,285 +1,57 @@ -import asyncio -import json -import os -import time -from enum import Enum -from typing import TYPE_CHECKING, Any - -import aiohttp -from pydantic import BaseModel +""" +DEPRECATED: This module is deprecated. Import from art.utils.deployment instead. + +This file re-exports from the new location for backwards compatibility. +""" + +# Re-export everything from the new deployment module +from art.utils.deployment import ( + # New API + DeploymentConfig, + DeploymentResult, + # Legacy API + LoRADeploymentJob, + LoRADeploymentProvider, + Provider, + TogetherDeploymentConfig, + WandbDeploymentConfig, + deploy_model, + deploy_wandb, +) -from art.errors import ( - LoRADeploymentTimedOutError, - UnsupportedBaseModelDeploymentError, - UnsupportedLoRADeploymentProviderError, +# Also export these for any code that imports them directly +from art.utils.deployment.together import ( + TOGETHER_SUPPORTED_BASE_MODELS, + TogetherJobStatus, +) +from art.utils.deployment.wandb import ( + WANDB_SUPPORTED_BASE_MODELS, ) + +# Keep these imports for any code that uses them from art.utils.get_model_step import get_model_step from art.utils.output_dirs import get_default_art_path from art.utils.s3 import archive_and_presign_step_url, pull_model_from_s3 -if TYPE_CHECKING: - from art.model import TrainableModel - - -class LoRADeploymentProvider(str, Enum): - TOGETHER = "together" - - -class LoRADeploymentJobStatus(str, Enum): - QUEUED = "Queued" - RUNNING = "Running" - COMPLETE = "Complete" - FAILED = "Failed" - - -class LoRADeploymentJob(BaseModel): - status: LoRADeploymentJobStatus - job_id: str - model_name: str - failure_reason: str | None - - -def init_together_session() -> aiohttp.ClientSession: - """ - Initializes a session for interacting with Together. - """ - if "TOGETHER_API_KEY" not in os.environ: - raise ValueError("TOGETHER_API_KEY is not set, cannot deploy LoRA to Together") - session = aiohttp.ClientSession() - session.headers.update( - { - "Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}", - "Content-Type": "application/json", - } - ) - return session - - -def model_checkpoint_id(model: "TrainableModel", step: int) -> str: - """ - Generates a unique ID for a model checkpoint. - """ - return f"{model.project}-{model.name}-{step}" - - -TOGETHER_SUPPORTED_BASE_MODELS = [ - "meta-llama/Meta-Llama-3.1-8B-Instruct", - "meta-llama/Meta-Llama-3.1-70B-Instruct", - "Qwen/Qwen2.5-14B-Instruct", - "Qwen/Qwen2.5-72B-Instruct", +__all__ = [ + # New API + "DeploymentConfig", + "DeploymentResult", + "Provider", + "TogetherDeploymentConfig", + "WandbDeploymentConfig", + "deploy_model", + "deploy_wandb", + # Legacy API + "LoRADeploymentJob", + "LoRADeploymentProvider", + # Constants + "TOGETHER_SUPPORTED_BASE_MODELS", + "WANDB_SUPPORTED_BASE_MODELS", + "TogetherJobStatus", + # Utilities (for backwards compat) + "get_model_step", + "get_default_art_path", + "archive_and_presign_step_url", + "pull_model_from_s3", ] - - -async def deploy_together( - model: "TrainableModel", - presigned_url: str, - step: int, - verbose: bool = False, -) -> dict[str, Any]: - """ - Deploys a model to Together. Supported base models: - - * meta-llama/Meta-Llama-3.1-8B-Instruct - * meta-llama/Meta-Llama-3.1-70B-Instruct - * Qwen/Qwen2.5-14B-Instruct - * Qwen/Qwen2.5-72B-Instruct - """ - # check if base model is supported for serverless LoRA deployment by Together - if model.base_model not in TOGETHER_SUPPORTED_BASE_MODELS: - raise UnsupportedBaseModelDeploymentError( - message=f"Base model {model.base_model} is not supported for serverless LoRA deployment by Together. Supported models: {TOGETHER_SUPPORTED_BASE_MODELS}" - ) - - async with init_together_session() as session: - session.headers.update( - { - "Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}", - "Content-Type": "application/json", - } - ) - - async with session.post( - url="https://api.together.xyz/v1/models", - json={ - "model_name": model_checkpoint_id(model=model, step=step), - "model_source": presigned_url, - "model_type": "adapter", - "base_model": model.base_model, - "description": f"Deployed from ART. Project: {model.project}. Model: {model.name}. Step: {step}", - }, - ) as response: - if response.status != 200: - print("Error uploading to Together:", await response.text()) - response.raise_for_status() - result = await response.json() - if verbose: - print(f"Successfully uploaded to Together: {result}") - return result - - -def convert_together_job_status( - status: str, message: str | None = None -) -> LoRADeploymentJobStatus: - MODEL_ALREADY_EXISTS_ERROR_MESSAGE = "409 Client Error: Conflict for url: https://api.together.ai/api/admin/entity/Model" - if ( - status == "Error" - and message is not None - and MODEL_ALREADY_EXISTS_ERROR_MESSAGE in message - ): - return LoRADeploymentJobStatus.COMPLETE - if status == "Bad" or status == "Error": - return LoRADeploymentJobStatus.FAILED - if status == "Retry Queued": - return LoRADeploymentJobStatus.QUEUED - return LoRADeploymentJobStatus(status) - - -async def find_existing_together_job_id( - model: "TrainableModel", - step: int, -) -> str | None: - """ - Finds an existing model deployment job in Together. - """ - checkpoint_id = model_checkpoint_id(model, step) - async with init_together_session() as session: - async with session.get(url="https://api.together.xyz/v1/jobs") as response: - response.raise_for_status() - result = await response.json() - jobs = result["data"] - # ensure we get the most recent job - jobs.sort(key=lambda x: x["updated_at"], reverse=True) - for job in jobs: - if checkpoint_id in job["args"]["modelName"]: - return job["job_id"] - return None - - -async def check_together_job_status( - job_id: str, verbose: bool = False -) -> LoRADeploymentJob: - """ - Checks the status of a model deployment job in Together. - """ - async with init_together_session() as session: - async with session.get( - url=f"https://api.together.xyz/v1/jobs/{job_id}" - ) as response: - response.raise_for_status() - result = await response.json() - if verbose: - print(f"Job status: {json.dumps(result, indent=4)}") - - last_update = result["status_updates"][-1] - status_body = LoRADeploymentJob( - status=convert_together_job_status( - result["status"], last_update.get("message") - ), - job_id=job_id, - model_name=result["args"]["modelName"], - failure_reason=result.get("failure_reason"), - ) - - if status_body.status == LoRADeploymentJobStatus.FAILED: - status_body.failure_reason = last_update.get("message") - return status_body - - -async def wait_for_together_job( - job_id: str, verbose: bool = False -) -> LoRADeploymentJob: - """ - Waits for a model deployment job to complete in Together. - - Checks the status every 15 seconds for 5 minutes. - """ - print(f"checking status of job {job_id} every 15 seconds for 5 minutes") - start_time = time.time() - max_time = start_time + 300 - while time.time() < max_time: - job_status = await check_together_job_status(job_id, verbose) - if job_status.status == "Complete" or job_status.status == "Failed": - return job_status - await asyncio.sleep(15) - - raise LoRADeploymentTimedOutError( - message=f"LoRA deployment timed out after 5 minutes. Job ID: {job_id}" - ) - - -async def deploy_model( - deploy_to: LoRADeploymentProvider, - model: "TrainableModel", - step: int | None = None, - s3_bucket: str | None = None, - prefix: str | None = None, - verbose: bool = False, - pull_s3: bool = True, - wait_for_completion: bool = True, - art_path: str | None = get_default_art_path(), -) -> LoRADeploymentJob: - """ - Deploy the model's latest checkpoint to a hosted inference endpoint. - - Together is currently the only supported provider. See link for supported base models: - https://docs.together.ai/docs/lora-inference#supported-base-models - """ - - art_path = art_path or get_default_art_path() - os.makedirs(art_path, exist_ok=True) - if pull_s3: - # pull the latest step from S3 - await pull_model_from_s3( - model_name=model.name, - project=model.project, - step=step, - s3_bucket=s3_bucket, - prefix=prefix, - verbose=verbose, - art_path=art_path, - ) - - if step is None: - step = get_model_step(model, art_path) - - presigned_url = await archive_and_presign_step_url( - model_name=model.name, - project=model.project, - step=step, - s3_bucket=s3_bucket, - prefix=prefix, - verbose=verbose, - art_path=art_path, - ) - - if deploy_to == LoRADeploymentProvider.TOGETHER: - existing_job_id = await find_existing_together_job_id(model, step) - existing_job = None - if existing_job_id is not None: - existing_job = await check_together_job_status( - existing_job_id, verbose=verbose - ) - - if not existing_job or existing_job.status == "Failed": - deployment_result = await deploy_together( - model=model, - presigned_url=presigned_url, - step=step, - verbose=verbose, - ) - job_id = deployment_result["data"]["job_id"] - else: - job_id = existing_job_id - assert job_id is not None - print( - f"Previous deployment for {model.name} at step {step} has status '{existing_job.status}', skipping redployment" - ) - - if wait_for_completion: - return await wait_for_together_job(job_id, verbose=verbose) - else: - return await check_together_job_status(job_id, verbose=verbose) - - raise UnsupportedLoRADeploymentProviderError( - f"Unsupported deployment option: {deploy_to}" - ) diff --git a/src/art/utils/deployment/__init__.py b/src/art/utils/deployment/__init__.py new file mode 100644 index 00000000..20ef2b62 --- /dev/null +++ b/src/art/utils/deployment/__init__.py @@ -0,0 +1,35 @@ +"""Deployment utilities for deploying trained models to inference endpoints.""" + +from .common import ( + DeploymentConfig, + DeploymentResult, + Provider, + deploy_model, +) + +# Legacy exports for backwards compatibility +from .legacy import ( + LoRADeploymentJob, + LoRADeploymentProvider, +) +from .together import ( + TogetherDeploymentConfig, +) +from .wandb import ( + WandbDeploymentConfig, + deploy_wandb, +) + +__all__ = [ + # New API + "DeploymentConfig", + "DeploymentResult", + "Provider", + "TogetherDeploymentConfig", + "WandbDeploymentConfig", + "deploy_model", + "deploy_wandb", + # Legacy API + "LoRADeploymentJob", + "LoRADeploymentProvider", +] diff --git a/src/art/utils/deployment/common.py b/src/art/utils/deployment/common.py new file mode 100644 index 00000000..b1bf34f4 --- /dev/null +++ b/src/art/utils/deployment/common.py @@ -0,0 +1,117 @@ +"""Common types and the main deploy_model function.""" + +import os +from typing import TYPE_CHECKING, Literal + +from pydantic import BaseModel + +if TYPE_CHECKING: + from art.model import TrainableModel + + +Provider = Literal["together", "wandb"] + + +class DeploymentConfig(BaseModel): + """Base class for deployment configurations.""" + + pass + + +class DeploymentResult(BaseModel): + """Result of a deployment operation.""" + + inference_model_name: str + """The model name to use for inference (e.g., wandb-artifact:///entity/project/name:step1)""" + + +async def deploy_model( + model: "TrainableModel", + checkpoint_path: str, + step: int, + provider: Provider, + config: DeploymentConfig | None = None, + verbose: bool = False, +) -> DeploymentResult: + """Deploy a model checkpoint to a hosted inference endpoint. + + This function assumes the checkpoint is already available locally. Use + Backend.pull_model_checkpoint() to download checkpoints first. + + Args: + model: The TrainableModel to deploy. + checkpoint_path: Local path to the checkpoint directory. + step: The step number of the checkpoint. + provider: The deployment provider ("together" or "wandb"). + config: Provider-specific deployment configuration. + - For "together": TogetherDeploymentConfig (required) + - For "wandb": WandbDeploymentConfig (optional) + verbose: Whether to print verbose output. + + Returns: + DeploymentResult with the inference model name. + + Example: + ```python + # Deploy to W&B (config optional) + result = await deploy_model( + model=model, + checkpoint_path="/path/to/checkpoint", + step=5, + provider="wandb", + ) + print(result.inference_model_name) + # wandb-artifact:///entity/project/model:step5 + + # Deploy to Together (config required) + result = await deploy_model( + model=model, + checkpoint_path="/path/to/checkpoint", + step=5, + provider="together", + config=TogetherDeploymentConfig(s3_bucket="my-bucket"), + ) + ``` + """ + # Import here to avoid circular imports + from .together import TogetherDeploymentConfig, deploy_to_together + from .wandb import WandbDeploymentConfig, deploy_wandb + + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}") + + if provider == "wandb": + # W&B config is optional - use defaults if not provided + if config is not None and not isinstance(config, WandbDeploymentConfig): + raise TypeError( + f"Expected WandbDeploymentConfig for provider 'wandb', got {type(config).__name__}" + ) + inference_name = deploy_wandb( + model=model, + checkpoint_path=checkpoint_path, + step=step, + verbose=verbose, + ) + return DeploymentResult(inference_model_name=inference_name) + + if provider == "together": + # Together config is required + if config is None: + raise ValueError( + "Config is required for provider 'together'. " + "Please provide a TogetherDeploymentConfig with at least s3_bucket specified." + ) + if not isinstance(config, TogetherDeploymentConfig): + raise TypeError( + f"Expected TogetherDeploymentConfig for provider 'together', got {type(config).__name__}" + ) + inference_name = await deploy_to_together( + model=model, + checkpoint_path=checkpoint_path, + step=step, + config=config, + verbose=verbose, + ) + return DeploymentResult(inference_model_name=inference_name) + + raise ValueError(f"Unsupported provider: {provider}. Use 'together' or 'wandb'.") diff --git a/src/art/utils/deployment/legacy.py b/src/art/utils/deployment/legacy.py new file mode 100644 index 00000000..62673f70 --- /dev/null +++ b/src/art/utils/deployment/legacy.py @@ -0,0 +1,23 @@ +"""Legacy exports for backwards compatibility.""" + +from enum import Enum + +from pydantic import BaseModel + +from .together import TogetherJobStatus + + +class LoRADeploymentProvider(str, Enum): + """Legacy enum for deployment providers.""" + + TOGETHER = "together" + WANDB = "wandb" + + +class LoRADeploymentJob(BaseModel): + """Legacy result class for deployment jobs.""" + + status: TogetherJobStatus + job_id: str + model_name: str + failure_reason: str | None diff --git a/src/art/utils/deployment/together.py b/src/art/utils/deployment/together.py new file mode 100644 index 00000000..a82f48c2 --- /dev/null +++ b/src/art/utils/deployment/together.py @@ -0,0 +1,250 @@ +"""Together deployment functionality.""" + +import asyncio +import json +import os +import time +from enum import Enum +from typing import TYPE_CHECKING, Any + +import aiohttp +from pydantic import BaseModel + +from art.errors import ( + LoRADeploymentTimedOutError, + UnsupportedBaseModelDeploymentError, +) +from art.utils.s3 import archive_and_presign_step_url + +from .common import DeploymentConfig + +if TYPE_CHECKING: + from art.model import TrainableModel + + +class TogetherDeploymentConfig(DeploymentConfig): + """Configuration for deploying to Together. + + See supported base models: https://docs.together.ai/docs/lora-inference#supported-base-models + + Attributes: + s3_bucket: S3 bucket to upload the checkpoint archive to (for presigned URL). + prefix: S3 prefix for the upload. + wait_for_completion: Whether to wait for deployment to complete (default: True). + """ + + s3_bucket: str | None = None + prefix: str | None = None + wait_for_completion: bool = True + + +class TogetherJobStatus(str, Enum): + QUEUED = "Queued" + RUNNING = "Running" + COMPLETE = "Complete" + FAILED = "Failed" + + +class TogetherJob(BaseModel): + status: TogetherJobStatus + job_id: str + model_name: str + failure_reason: str | None + + +TOGETHER_SUPPORTED_BASE_MODELS = [ + "meta-llama/Meta-Llama-3.1-8B-Instruct", + "meta-llama/Meta-Llama-3.1-70B-Instruct", + "Qwen/Qwen2.5-14B-Instruct", + "Qwen/Qwen2.5-72B-Instruct", +] + + +def _init_session() -> aiohttp.ClientSession: + """Initializes a session for interacting with Together.""" + if "TOGETHER_API_KEY" not in os.environ: + raise ValueError("TOGETHER_API_KEY is not set, cannot deploy LoRA to Together") + session = aiohttp.ClientSession() + session.headers.update( + { + "Authorization": f"Bearer {os.environ['TOGETHER_API_KEY']}", + "Content-Type": "application/json", + } + ) + return session + + +def _model_checkpoint_id(model: "TrainableModel", step: int) -> str: + """Generates a unique ID for a model checkpoint.""" + return f"{model.project}-{model.name}-{step}" + + +async def _upload_model( + model: "TrainableModel", + presigned_url: str, + step: int, + verbose: bool = False, +) -> dict[str, Any]: + """Uploads a model to Together.""" + if model.base_model not in TOGETHER_SUPPORTED_BASE_MODELS: + raise UnsupportedBaseModelDeploymentError( + message=f"Base model {model.base_model} is not supported for serverless LoRA deployment by Together. Supported models: {TOGETHER_SUPPORTED_BASE_MODELS}" + ) + + async with _init_session() as session: + async with session.post( + url="https://api.together.xyz/v1/models", + json={ + "model_name": _model_checkpoint_id(model=model, step=step), + "model_source": presigned_url, + "model_type": "adapter", + "base_model": model.base_model, + "description": f"Deployed from ART. Project: {model.project}. Model: {model.name}. Step: {step}", + }, + ) as response: + if response.status != 200: + print("Error uploading to Together:", await response.text()) + response.raise_for_status() + result = await response.json() + if verbose: + print(f"Successfully uploaded to Together: {result}") + return result + + +def _convert_job_status(status: str, message: str | None = None) -> TogetherJobStatus: + MODEL_ALREADY_EXISTS_ERROR_MESSAGE = "409 Client Error: Conflict for url: https://api.together.ai/api/admin/entity/Model" + if ( + status == "Error" + and message is not None + and MODEL_ALREADY_EXISTS_ERROR_MESSAGE in message + ): + return TogetherJobStatus.COMPLETE + if status == "Bad" or status == "Error": + return TogetherJobStatus.FAILED + if status == "Retry Queued": + return TogetherJobStatus.QUEUED + return TogetherJobStatus(status) + + +async def _find_existing_job_id( + model: "TrainableModel", + step: int, +) -> str | None: + """Finds an existing model deployment job in Together.""" + checkpoint_id = _model_checkpoint_id(model, step) + async with _init_session() as session: + async with session.get(url="https://api.together.xyz/v1/jobs") as response: + response.raise_for_status() + result = await response.json() + jobs = result["data"] + jobs.sort(key=lambda x: x["updated_at"], reverse=True) + for job in jobs: + if checkpoint_id in job["args"]["modelName"]: + return job["job_id"] + return None + + +async def _check_job_status(job_id: str, verbose: bool = False) -> TogetherJob: + """Checks the status of a model deployment job in Together.""" + async with _init_session() as session: + async with session.get( + url=f"https://api.together.xyz/v1/jobs/{job_id}" + ) as response: + response.raise_for_status() + result = await response.json() + if verbose: + print(f"Job status: {json.dumps(result, indent=4)}") + + last_update = result["status_updates"][-1] + status_body = TogetherJob( + status=_convert_job_status( + result["status"], last_update.get("message") + ), + job_id=job_id, + model_name=result["args"]["modelName"], + failure_reason=result.get("failure_reason"), + ) + + if status_body.status == TogetherJobStatus.FAILED: + status_body.failure_reason = last_update.get("message") + return status_body + + +async def _wait_for_job(job_id: str, verbose: bool = False) -> TogetherJob: + """Waits for a model deployment job to complete in Together.""" + print(f"checking status of job {job_id} every 15 seconds for 5 minutes") + start_time = time.time() + max_time = start_time + 300 + while time.time() < max_time: + job_status = await _check_job_status(job_id, verbose) + if job_status.status in (TogetherJobStatus.COMPLETE, TogetherJobStatus.FAILED): + return job_status + await asyncio.sleep(15) + + raise LoRADeploymentTimedOutError( + message=f"LoRA deployment timed out after 5 minutes. Job ID: {job_id}" + ) + + +async def deploy_to_together( + model: "TrainableModel", + checkpoint_path: str, + step: int, + config: TogetherDeploymentConfig, + verbose: bool = False, +) -> str: + """Deploy a model checkpoint to Together. + + Args: + model: The TrainableModel to deploy. + checkpoint_path: Local path to the checkpoint directory. + step: The step number of the checkpoint. + config: Together deployment configuration. + verbose: Whether to print verbose output. + + Returns: + The inference model name. + """ + # Archive and upload to S3 to get a presigned URL for Together + presigned_url = await archive_and_presign_step_url( + model_name=model.name, + project=model.project, + step=step, + s3_bucket=config.s3_bucket, + prefix=config.prefix, + verbose=verbose, + checkpoint_path=checkpoint_path, + ) + + existing_job_id = await _find_existing_job_id(model, step) + existing_job = None + if existing_job_id is not None: + existing_job = await _check_job_status(existing_job_id, verbose=verbose) + + if not existing_job or existing_job.status == TogetherJobStatus.FAILED: + deployment_result = await _upload_model( + model=model, + presigned_url=presigned_url, + step=step, + verbose=verbose, + ) + job_id = deployment_result["data"]["job_id"] + else: + job_id = existing_job_id + assert job_id is not None + print( + f"Previous deployment for {model.name} at step {step} has status '{existing_job.status}', skipping redeployment" + ) + + if config.wait_for_completion: + job = await _wait_for_job(job_id, verbose=verbose) + else: + job = await _check_job_status(job_id, verbose=verbose) + + if job.status == TogetherJobStatus.FAILED: + raise RuntimeError( + f"Together deployment failed for {model.name} step {step}. " + f"Job ID: {job.job_id}. Reason: {job.failure_reason}" + ) + + return job.model_name diff --git a/src/art/utils/deployment/wandb.py b/src/art/utils/deployment/wandb.py new file mode 100644 index 00000000..a02ff7da --- /dev/null +++ b/src/art/utils/deployment/wandb.py @@ -0,0 +1,97 @@ +"""W&B deployment functionality.""" + +import os +from typing import TYPE_CHECKING + +from art.errors import UnsupportedBaseModelDeploymentError + +from .common import DeploymentConfig + +if TYPE_CHECKING: + from art.model import TrainableModel + + +class WandbDeploymentConfig(DeploymentConfig): + """Configuration for deploying to W&B. + + Supported base models: + - OpenPipe/Qwen3-14B-Instruct + - Qwen/Qwen2.5-14B-Instruct + """ + + pass + + +WANDB_SUPPORTED_BASE_MODELS = [ + "OpenPipe/Qwen3-14B-Instruct", + "Qwen/Qwen2.5-14B-Instruct", +] + + +def deploy_wandb( + model: "TrainableModel", + checkpoint_path: str, + step: int, + verbose: bool = False, +) -> str: + """Deploy a model to W&B by uploading a LoRA artifact. + + Args: + model: The TrainableModel to deploy. + checkpoint_path: Local path to the checkpoint directory. + step: The step number of the checkpoint. + verbose: Whether to print verbose output. + + Returns: + The model name for inference: wandb-artifact:///{entity}/{project}/{name}:step{step} + """ + import wandb + + if model.base_model not in WANDB_SUPPORTED_BASE_MODELS: + raise UnsupportedBaseModelDeploymentError( + message=f"Base model {model.base_model} is not supported for serverless LoRA deployment by W&B. Supported models: {WANDB_SUPPORTED_BASE_MODELS}" + ) + + if "WANDB_API_KEY" not in os.environ: + raise ValueError("WANDB_API_KEY is not set, cannot deploy LoRA to W&B") + + # Get the user's default entity from W&B if not set + if model.entity is None: + api = wandb.Api() + model.entity = api.default_entity + + if verbose: + print(f"Uploading checkpoint from {checkpoint_path} to W&B...") + + run = wandb.init( + entity=model.entity, + project=model.project, + settings=wandb.Settings(api_key=os.environ["WANDB_API_KEY"]), + ) + try: + artifact = wandb.Artifact( + model.name, + type="lora", + metadata={"wandb.base_model": model.base_model}, + storage_region="coreweave-us", + ) + artifact.add_dir(checkpoint_path) + artifact = run.log_artifact(artifact, aliases=[f"step{step}", "latest"]) + try: + artifact = artifact.wait() + except ValueError as e: + if "Unable to fetch artifact with id" in str(e): + if verbose: + print(f"Warning: {e}") + else: + raise e + finally: + run.finish() + + inference_name = ( + f"wandb-artifact:///{model.entity}/{model.project}/{model.name}:step{step}" + ) + if verbose: + print(f"Successfully deployed to W&B. Inference model name: {inference_name}") + + return inference_name diff --git a/src/art/utils/s3.py b/src/art/utils/s3.py index f6f6798d..f29c0dbb 100644 --- a/src/art/utils/s3.py +++ b/src/art/utils/s3.py @@ -268,16 +268,32 @@ async def archive_and_presign_step_url( verbose: bool = False, delete: bool = False, art_path: str | None = None, + checkpoint_path: str | None = None, ) -> str: - """Get a presigned URL for a step in a model.""" - model_output_dir = get_output_dir_from_model_properties( - project=project, - name=model_name, - art_path=art_path, - ) - local_step_dir = get_step_checkpoint_dir(model_output_dir, step) - if not os.path.exists(local_step_dir): - raise ValueError(f"Local step directory does not exist: {local_step_dir}") + """Get a presigned URL for a step in a model. + + Args: + model_name: Name of the model. + project: Project name. + step: Step number. + s3_bucket: S3 bucket to upload to. + prefix: S3 prefix. + verbose: Whether to print verbose output. + delete: Whether to delete after upload. + art_path: Path to ART directory (used if checkpoint_path not provided). + checkpoint_path: Direct path to the checkpoint directory. If provided, uses this + instead of constructing from art_path. + """ + if checkpoint_path is None: + model_output_dir = get_output_dir_from_model_properties( + project=project, + name=model_name, + art_path=art_path, + ) + checkpoint_path = get_step_checkpoint_dir(model_output_dir, step) + + if not os.path.exists(checkpoint_path): + raise ValueError(f"Local step directory does not exist: {checkpoint_path}") s3_step_path = build_s3_zipped_step_path( model_name=model_name, @@ -292,11 +308,11 @@ async def archive_and_presign_step_url( # Create zip archive archive_path = os.path.join(temp_dir, "model.zip") with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as zipf: - for root, _, files in os.walk(local_step_dir): + for root, _, files in os.walk(checkpoint_path): for file in files: file_path = os.path.join(root, file) # Add file to zip with relative path - arcname = os.path.relpath(file_path, local_step_dir) + arcname = os.path.relpath(file_path, checkpoint_path) zipf.write(file_path, arcname) await ensure_bucket_exists(s3_bucket) diff --git a/uv.lock b/uv.lock index 3c27fd39..a5c87d7e 100644 --- a/uv.lock +++ b/uv.lock @@ -4217,7 +4217,7 @@ requires-dist = [ { name = "unsloth", marker = "extra == 'backend'", specifier = "==2025.10.3" }, { name = "unsloth-zoo", marker = "extra == 'backend'", specifier = "==2025.10.3" }, { name = "vllm", marker = "extra == 'backend'", specifier = ">=0.9.2,<=0.10.0" }, - { name = "wandb", marker = "extra == 'backend'", specifier = "==0.21.0" }, + { name = "wandb", marker = "extra == 'backend'", specifier = "==0.22.1" }, { name = "weave", specifier = ">=0.51.51" }, ] provides-extras = ["plotting", "backend", "skypilot", "langgraph"] @@ -7897,7 +7897,7 @@ wheels = [ [[package]] name = "wandb" -version = "0.21.0" +version = "0.22.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -7911,18 +7911,17 @@ dependencies = [ { name = "sentry-sdk" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/73/09/c84264a219e20efd615e4d5d150cc7d359d57d51328d3fa94ee02d70ed9c/wandb-0.21.0.tar.gz", hash = "sha256:473e01ef200b59d780416062991effa7349a34e51425d4be5ff482af2dc39e02", size = 40085784, upload-time = "2025-07-02T00:24:15.516Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/38/dd/65eac086e1bc337bb5f0eed65ba1fe4a6dbc62c97f094e8e9df1ef83ffed/wandb-0.21.0-py3-none-any.whl", hash = "sha256:316e8cd4329738f7562f7369e6eabeeb28ef9d473203f7ead0d03e5dba01c90d", size = 6504284, upload-time = "2025-07-02T00:23:46.671Z" }, - { url = "https://files.pythonhosted.org/packages/17/a7/80556ce9097f59e10807aa68f4a9b29d736a90dca60852a9e2af1641baf8/wandb-0.21.0-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:701d9cbdfcc8550a330c1b54a26f1585519180e0f19247867446593d34ace46b", size = 21717388, upload-time = "2025-07-02T00:23:49.348Z" }, - { url = "https://files.pythonhosted.org/packages/23/ae/660bc75aa37bd23409822ea5ed616177d94873172d34271693c80405c820/wandb-0.21.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:01689faa6b691df23ba2367e0a1ecf6e4d0be44474905840098eedd1fbcb8bdf", size = 21141465, upload-time = "2025-07-02T00:23:52.602Z" }, - { url = "https://files.pythonhosted.org/packages/23/ab/9861929530be56557c74002868c85d0d8ac57050cc21863afe909ae3d46f/wandb-0.21.0-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:55d3f42ddb7971d1699752dff2b85bcb5906ad098d18ab62846c82e9ce5a238d", size = 21793511, upload-time = "2025-07-02T00:23:55.447Z" }, - { url = "https://files.pythonhosted.org/packages/de/52/e5cad2eff6fbed1ac06f4a5b718457fa2fd437f84f5c8f0d31995a2ef046/wandb-0.21.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:893508f0c7da48917448daa5cd622c27ce7ce15119adaa861185034c2bd7b14c", size = 20704643, upload-time = "2025-07-02T00:23:58.255Z" }, - { url = "https://files.pythonhosted.org/packages/83/8f/6bed9358cc33767c877b221d4f565e1ddf00caf4bbbe54d2e3bbc932c6a7/wandb-0.21.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4e8245a8912247ddf7654f7b5330f583a6c56ab88fee65589158490d583c57d", size = 22243012, upload-time = "2025-07-02T00:24:01.423Z" }, - { url = "https://files.pythonhosted.org/packages/be/61/9048015412ea5ca916844af55add4fed7c21fe1ad70bb137951e70b550c5/wandb-0.21.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2e4c4f951e0d02755e315679bfdcb5bc38c1b02e2e5abc5432b91a91bb0cf246", size = 20716440, upload-time = "2025-07-02T00:24:04.198Z" }, - { url = "https://files.pythonhosted.org/packages/02/d9/fcd2273d8ec3f79323e40a031aba5d32d6fa9065702010eb428b5ffbab62/wandb-0.21.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:873749966eeac0069e0e742e6210641b6227d454fb1dae2cf5c437c6ed42d3ca", size = 22320652, upload-time = "2025-07-02T00:24:07.175Z" }, - { url = "https://files.pythonhosted.org/packages/80/68/b8308db6b9c3c96dcd03be17c019aee105e1d7dc1e74d70756cdfb9241c6/wandb-0.21.0-py3-none-win32.whl", hash = "sha256:9d3cccfba658fa011d6cab9045fa4f070a444885e8902ae863802549106a5dab", size = 21484296, upload-time = "2025-07-02T00:24:10.147Z" }, - { url = "https://files.pythonhosted.org/packages/cf/96/71cc033e8abd00e54465e68764709ed945e2da2d66d764f72f4660262b22/wandb-0.21.0-py3-none-win_amd64.whl", hash = "sha256:28a0b2dad09d7c7344ac62b0276be18a2492a5578e4d7c84937a3e1991edaac7", size = 21484301, upload-time = "2025-07-02T00:24:12.658Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/ac/6f/be255b13157cab6c6670594171f59f67bd8a89f20d1978dc4eb892e2de27/wandb-0.22.1.tar.gz", hash = "sha256:6a1d668ecd6bd6531a73f6f7cfec0a93a08ef578c16ccf7167168c52cbf8cb12", size = 40246806, upload-time = "2025-09-29T17:15:55.207Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/73/de1d62301ef5d084160221637f34a821b7ed90b0769698b7b420686608ab/wandb-0.22.1-py3-none-macosx_12_0_arm64.whl", hash = "sha256:d862d0d28919556a5c32977138449812a2d127853e20b0deb39a5ab17700230f", size = 18370574, upload-time = "2025-09-29T17:15:18.236Z" }, + { url = "https://files.pythonhosted.org/packages/42/0e/4f60d9c7f1fa9d249dcbe70c6bad1573a9cfc070d00c3d8dbd62f715938d/wandb-0.22.1-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:ce57213de331717270020f7e07b098a0ea37646550b63758eabf8cb05eeb066f", size = 19392851, upload-time = "2025-09-29T17:15:22.093Z" }, + { url = "https://files.pythonhosted.org/packages/c6/8b/757ede4a581eece5e72ade51fd4b43cfedbd3e39b85fe29d0198bc98131b/wandb-0.22.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f5a0ac652c23bf88e12bf0c04e911ff4f95696ac60a3612d81e54f1f8d89f3c5", size = 18171463, upload-time = "2025-09-29T17:15:24.588Z" }, + { url = "https://files.pythonhosted.org/packages/a2/e6/275d60292183d4de89fc9053887192f978fd8612e55c8f7719aa5c81bbd1/wandb-0.22.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dc14dd06938c282900dd9f72cbaaed45368b0c6b9bc2ffd1f45d07eeb13095b", size = 19585538, upload-time = "2025-09-29T17:15:28.432Z" }, + { url = "https://files.pythonhosted.org/packages/a8/5c/4199abb92d06de6ebd63ee33551ba0de6d91a814ac42e372dec6d8009ea0/wandb-0.22.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:1b2d2b9d8d1ea8aea3c2cbf9de7696105432886ba9845c50e7cc71613aa6c8ef", size = 18210525, upload-time = "2025-09-29T17:15:33.459Z" }, + { url = "https://files.pythonhosted.org/packages/0d/00/a7719c048115825861a31435fa911887c9949b20096dbc7307e11b3c981b/wandb-0.22.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:051906d7f22bdf8c07c8837ffc6d9ae357d61dcd74cfb7d29fd0243e03da8f4a", size = 19680055, upload-time = "2025-09-29T17:15:38.521Z" }, + { url = "https://files.pythonhosted.org/packages/6f/5d/e6270557a315211e880c1efa9c1cab577f945d81168bc1c1187fd483f1bb/wandb-0.22.1-py3-none-win32.whl", hash = "sha256:b2df59bd70771329f27171f55d25d5557731bb0674d60db4735c173a8fb8076d", size = 18769036, upload-time = "2025-09-29T17:15:43.19Z" }, + { url = "https://files.pythonhosted.org/packages/92/fe/34cdfd491ea6c89495794f361102b727b922adcc4f3eedb47c8aa16984c3/wandb-0.22.1-py3-none-win_amd64.whl", hash = "sha256:c1b442e6de805d78743321200a27099517509f9e4aa2e6d330211a4809f932d7", size = 18769038, upload-time = "2025-09-29T17:15:46.977Z" }, + { url = "https://files.pythonhosted.org/packages/1e/9e/fe95f5d48ff10215b7d7e67dc998cba3f660027829fac2a67c79ce89e985/wandb-0.22.1-py3-none-win_arm64.whl", hash = "sha256:52758008c9ef9e7201113af08d6015322a699ebe3497a6e6fc885b39f5652b4d", size = 17077774, upload-time = "2025-09-29T17:15:50.588Z" }, ] [[package]]