@@ -675,6 +675,148 @@ def _get_wandb_run(self, model: Model) -> Run | None:
675675 # Experimental support for S3
676676 # ------------------------------------------------------------------
677677
async def _experimental_pull_model_checkpoint(
    self,
    model: "TrainableModel",
    *,
    step: int | Literal["latest"] | None = None,
    local_path: str | None = None,
    s3_bucket: str | None = None,
    prefix: str | None = None,
    verbose: bool = False,
) -> str:
    """Pull a model checkpoint to a local path.

    For LocalBackend, this:
    1. Resolves the step (explicit int, or the latest step taken from S3
       when ``s3_bucket`` is given, otherwise from ``self._path``).
    2. Checks whether the checkpoint exists in the original training
       location (under ``self._path``).
    3. If it does not exist there, pulls it from S3 into the original
       training location (requires ``s3_bucket``).
    4. If ``local_path`` is provided, copies the checkpoint to
       ``<local_path>/<step:04d>``; an already-present target directory is
       reused as-is instead of being downloaded/copied again.
    5. Returns the final checkpoint path.

    Args:
        model: The model to pull checkpoint for.
        step: The step to pull. Can be an int for a specific step,
            or "latest" to pull the latest checkpoint. If None, pulls latest.
        local_path: Custom directory to save/copy the checkpoint to.
            If None, returns checkpoint from backend's default art path.
        s3_bucket: S3 bucket to pull from if checkpoint doesn't exist locally.
        prefix: S3 prefix.
        verbose: Whether to print verbose output.

    Returns:
        Path to the local checkpoint directory.

    Raises:
        ValueError: If the latest step is requested with ``s3_bucket`` set
            and no checkpoint is found in S3.
        FileNotFoundError: If the checkpoint is missing from the original
            training location and no ``s3_bucket`` was specified.
    """
    # Determine which step to use
    resolved_step: int
    if step is None or step == "latest":
        if s3_bucket is not None:
            # Get latest from S3
            from art.utils.s3_checkpoint_utils import (
                get_latest_checkpoint_step_from_s3,
            )

            latest_step = await get_latest_checkpoint_step_from_s3(
                model_name=model.name,
                project=model.project,
                s3_bucket=s3_bucket,
                prefix=prefix,
            )
            if latest_step is None:
                raise ValueError(
                    f"No checkpoints found in S3 for {model.project}/{model.name}"
                )
            resolved_step = latest_step
        else:
            # Get latest from default training location
            resolved_step = get_model_step(model, self._path)
    else:
        resolved_step = step

    # Where the checkpoint lives (or will live) in the training layout.
    original_checkpoint_dir = get_step_checkpoint_dir(
        get_model_dir(model=model, art_path=self._path), resolved_step
    )
    # Flat "<local_path>/<step>" destination, only when a custom path is requested.
    target_checkpoint_dir = (
        os.path.join(local_path, f"{resolved_step:04d}")
        if local_path is not None
        else None
    )
    found_locally = os.path.exists(original_checkpoint_dir)

    if not found_locally and s3_bucket is None:
        raise FileNotFoundError(
            f"Checkpoint not found at {original_checkpoint_dir} and no S3 bucket specified"
        )

    # Reuse an existing copy at the custom location. Checking this before
    # any download/copy avoids a needless S3 pull and prevents
    # shutil.copytree from raising FileExistsError on an existing target.
    if target_checkpoint_dir is not None and os.path.exists(target_checkpoint_dir):
        if verbose:
            print(
                f"Checkpoint already exists at target location: {target_checkpoint_dir}"
            )
        return target_checkpoint_dir

    if found_locally:
        if target_checkpoint_dir is None:
            # No custom location, return original
            if verbose:
                print(
                    f"Checkpoint step {resolved_step} exists at {original_checkpoint_dir}"
                )
            return original_checkpoint_dir
    else:
        # Checkpoint doesn't exist in original location: pull it from S3
        # into the default structure first.
        if verbose:
            print(f"Pulling checkpoint step {resolved_step} from S3...")
        await pull_model_from_s3(
            model_name=model.name,
            project=model.project,
            step=resolved_step,
            s3_bucket=s3_bucket,
            prefix=prefix,
            verbose=verbose,
            art_path=self._path,
            exclude=["logs", "trajectories"],
        )
        if target_checkpoint_dir is None:
            return original_checkpoint_dir

    # Copy from the original location to the custom flat location.
    import shutil

    if verbose:
        print(
            f"Copying checkpoint from {original_checkpoint_dir} to {target_checkpoint_dir}..."
        )
    parent_dir = os.path.dirname(target_checkpoint_dir)
    if parent_dir:
        # dirname is "" when local_path is empty; os.makedirs("") would fail.
        os.makedirs(parent_dir, exist_ok=True)
    shutil.copytree(original_checkpoint_dir, target_checkpoint_dir)
    if verbose:
        if found_locally:
            print("✓ Checkpoint copied successfully")
        else:
            print("✓ Checkpoint copied to custom location")
    return target_checkpoint_dir
819+
678820 async def _experimental_pull_from_s3 (
679821 self ,
680822 model : Model ,
@@ -950,23 +1092,53 @@ async def _experimental_deploy(
9501092 s3_bucket : str | None = None ,
9511093 prefix : str | None = None ,
9521094 verbose : bool = False ,
953- pull_s3 : bool = True ,
1095+ pull_checkpoint : bool = True ,
9541096 wait_for_completion : bool = True ,
9551097 ) -> LoRADeploymentJob :
9561098 """
9571099 Deploy the model's latest checkpoint to a hosted inference endpoint.
9581100
9591101 Together is currently the only supported provider. See link for supported base models:
9601102 https://docs.together.ai/docs/lora-inference#supported-base-models
1103+
1104+ Args:
1105+ deploy_to: The deployment provider.
1106+ model: The model to deploy.
1107+ step: The checkpoint step to deploy. If None, deploys latest.
1108+ s3_bucket: S3 bucket for checkpoint storage and presigned URL.
1109+ prefix: S3 prefix.
1110+ verbose: Whether to print verbose output.
1111+ pull_checkpoint: Whether to pull the checkpoint first (from S3 if needed).
1112+ wait_for_completion: Whether to wait for deployment to complete.
9611113 """
1114+ # Step 1: Pull checkpoint to local path if needed
1115+ if pull_checkpoint :
1116+ checkpoint_path = await self ._experimental_pull_model_checkpoint (
1117+ model ,
1118+ step = step ,
1119+ s3_bucket = s3_bucket ,
1120+ prefix = prefix ,
1121+ verbose = verbose ,
1122+ )
1123+ # Extract step from checkpoint path if not provided
1124+ if step is None :
1125+ step = int (os .path .basename (checkpoint_path ))
1126+ else :
1127+ # Checkpoint should already exist locally
1128+ if step is None :
1129+ step = get_model_step (model , self ._path )
1130+ checkpoint_path = get_step_checkpoint_dir (
1131+ get_model_dir (model = model , art_path = self ._path ), step
1132+ )
1133+
1134+ # Step 2: Deploy from local checkpoint
9621135 return await deploy_model (
9631136 deploy_to = deploy_to ,
9641137 model = model ,
1138+ checkpoint_path = checkpoint_path ,
9651139 step = step ,
9661140 s3_bucket = s3_bucket ,
9671141 prefix = prefix ,
9681142 verbose = verbose ,
969- pull_s3 = pull_s3 ,
9701143 wait_for_completion = wait_for_completion ,
971- art_path = self ._path ,
9721144 )
0 commit comments