@@ -774,85 +774,49 @@ async def _experimental_pull_model_checkpoint(
774774 get_model_dir (model = model , art_path = self ._path ), resolved_step
775775 )
776776
777- # Check if checkpoint exists in original location
778- if os .path .exists (original_checkpoint_dir ):
779- if local_path is not None :
780- # Copy from original location to custom location
781- target_checkpoint_dir = os .path .join (local_path , f"{ resolved_step :04d} " )
782- if os .path .exists (target_checkpoint_dir ):
783- if verbose :
784- print (
785- f"Checkpoint already exists at target location: { target_checkpoint_dir } "
786- )
787- return target_checkpoint_dir
788- else :
789- if verbose :
790- print (
791- f"Copying checkpoint from { original_checkpoint_dir } to { target_checkpoint_dir } ..."
792- )
793- import shutil
794-
795- os .makedirs (os .path .dirname (target_checkpoint_dir ), exist_ok = True )
796- shutil .copytree (original_checkpoint_dir , target_checkpoint_dir )
797- if verbose :
798- print (f"✓ Checkpoint copied successfully" )
799- return target_checkpoint_dir
800- else :
801- # No custom location, return original
802- if verbose :
803- print (
804- f"Checkpoint step { resolved_step } exists at { original_checkpoint_dir } "
805- )
806- return original_checkpoint_dir
807- else :
808- # Checkpoint doesn't exist in original location, try S3
777+ # Step 1: Ensure checkpoint exists at original_checkpoint_dir
778+ if not os .path .exists (original_checkpoint_dir ):
809779 if s3_bucket is None :
810780 raise FileNotFoundError (
811781 f"Checkpoint not found at { original_checkpoint_dir } and no S3 bucket specified"
812782 )
813- # Pull from S3
814783 if verbose :
815784 print (f"Pulling checkpoint step { resolved_step } from S3..." )
816-
817- if local_path is not None :
818- # Pull to custom location, then copy to flat structure
819- # First pull to default structure
820- await pull_model_from_s3 (
821- model_name = model .name ,
822- project = model .project ,
823- step = resolved_step ,
824- s3_bucket = s3_bucket ,
825- prefix = prefix ,
826- verbose = verbose ,
827- art_path = self ._path ,
828- exclude = ["logs" , "trajectories" ],
785+ await pull_model_from_s3 (
786+ model_name = model .name ,
787+ project = model .project ,
788+ step = resolved_step ,
789+ s3_bucket = s3_bucket ,
790+ prefix = prefix ,
791+ verbose = verbose ,
792+ art_path = self ._path ,
793+ exclude = ["logs" , "trajectories" ],
794+ )
795+ # Validate that the checkpoint was actually downloaded
796+ if not os .path .exists (original_checkpoint_dir ) or not os .listdir (
797+ original_checkpoint_dir
798+ ):
799+ raise FileNotFoundError (f"Checkpoint step { resolved_step } not found" )
800+
801+ # Step 2: Handle local_path if provided
802+ if local_path is not None :
803+ if verbose :
804+ print (
805+ f"Copying checkpoint from { original_checkpoint_dir } to { local_path } ..."
829806 )
830- # Now copy to custom flat location
831- target_checkpoint_dir = os .path .join (local_path , f"{ resolved_step :04d} " )
832- if verbose :
833- print (
834- f"Copying checkpoint from { original_checkpoint_dir } to { target_checkpoint_dir } ..."
835- )
836- import shutil
807+ import shutil
837808
838- os .makedirs (os .path .dirname (target_checkpoint_dir ), exist_ok = True )
839- shutil .copytree (original_checkpoint_dir , target_checkpoint_dir )
840- if verbose :
841- print (f"✓ Checkpoint copied to custom location" )
842- return target_checkpoint_dir
843- else :
844- # Pull to default location
845- await pull_model_from_s3 (
846- model_name = model .name ,
847- project = model .project ,
848- step = resolved_step ,
849- s3_bucket = s3_bucket ,
850- prefix = prefix ,
851- verbose = verbose ,
852- art_path = self ._path ,
853- exclude = ["logs" , "trajectories" ],
854- )
855- return original_checkpoint_dir
809+ os .makedirs (local_path , exist_ok = True )
810+ shutil .copytree (original_checkpoint_dir , local_path , dirs_exist_ok = True )
811+ if verbose :
812+ print (f"✓ Checkpoint copied successfully" )
813+ return local_path
814+
815+ if verbose :
816+ print (
817+ f"Checkpoint step { resolved_step } exists at { original_checkpoint_dir } "
818+ )
819+ return original_checkpoint_dir
856820
857821 async def _experimental_pull_from_s3 (
858822 self ,
@@ -1129,63 +1093,3 @@ async def _experimental_fork_checkpoint(
11291093 print (
11301094 f"Successfully forked checkpoint from { from_model } (step { selected_step } ) to { model .name } "
11311095 )
1132-
1133- async def _experimental_deploy (
1134- self ,
1135- provider : Provider ,
1136- model : "TrainableModel" ,
1137- step : int | None = None ,
1138- config : TogetherDeploymentConfig | WandbDeploymentConfig | None = None ,
1139- verbose : bool = False ,
1140- pull_checkpoint : bool = True ,
1141- ) -> DeploymentResult :
1142- """
1143- Deploy the model's latest checkpoint to a hosted inference endpoint.
1144-
1145- Args:
1146- provider: The deployment provider ("together" or "wandb").
1147- model: The model to deploy.
1148- step: The checkpoint step to deploy. If None, deploys latest.
1149- config: Provider-specific deployment configuration.
1150- - For "together": TogetherDeploymentConfig (required)
1151- - For "wandb": WandbDeploymentConfig (optional)
1152- verbose: Whether to print verbose output.
1153- pull_checkpoint: Whether to pull the checkpoint first.
1154- """
1155- # Step 1: Pull checkpoint to local path if needed
1156- if pull_checkpoint :
1157- s3_bucket = (
1158- config .s3_bucket
1159- if isinstance (config , TogetherDeploymentConfig )
1160- else None
1161- )
1162- prefix = (
1163- config .prefix if isinstance (config , TogetherDeploymentConfig ) else None
1164- )
1165- checkpoint_path = await self ._experimental_pull_model_checkpoint (
1166- model ,
1167- step = step ,
1168- s3_bucket = s3_bucket ,
1169- prefix = prefix ,
1170- verbose = verbose ,
1171- )
1172- # Extract step from checkpoint path if not provided
1173- if step is None :
1174- step = int (os .path .basename (checkpoint_path ))
1175- else :
1176- # Checkpoint should already exist locally
1177- if step is None :
1178- step = get_model_step (model , self ._path )
1179- checkpoint_path = get_step_checkpoint_dir (
1180- get_model_dir (model = model , art_path = self ._path ), step
1181- )
1182-
1183- # Step 2: Deploy from local checkpoint
1184- return await deploy_model (
1185- model = model ,
1186- checkpoint_path = checkpoint_path ,
1187- step = step ,
1188- provider = provider ,
1189- config = config ,
1190- verbose = verbose ,
1191- )
0 commit comments