@@ -407,9 +407,15 @@ def fit(
407407 self ._setup_trainer (model )
408408 if model .learning_type in {LearningType .ZERO_SHOT , LearningType .FEW_SHOT }:
409409 # if the model is zero-shot or few-shot, we only need to run validate for normalization and thresholding
410- self .trainer .validate (model , val_dataloaders , datamodule = datamodule , ckpt_path = ckpt_path )
410+ self .trainer .validate (
411+ model ,
412+ val_dataloaders ,
413+ datamodule = datamodule ,
414+ ckpt_path = ckpt_path ,
415+ weights_only = False ,
416+ )
411417 else :
412- self .trainer .fit (model , train_dataloaders , val_dataloaders , datamodule , ckpt_path )
418+ self .trainer .fit (model , train_dataloaders , val_dataloaders , datamodule , ckpt_path , weights_only = False )
413419
414420 def validate (
415421 self ,
@@ -457,7 +463,7 @@ def validate(
457463 ckpt_path = Path (ckpt_path ).resolve ()
458464 if model :
459465 self ._setup_trainer (model )
460- return self .trainer .validate (model , dataloaders , ckpt_path , verbose , datamodule )
466+ return self .trainer .validate (model , dataloaders , ckpt_path , verbose , datamodule , weights_only = False )
461467
462468 def test (
463469 self ,
@@ -551,7 +557,7 @@ def test(
551557 if self ._should_run_validation (model or self .model , ckpt_path ):
552558 logger .info ("Running validation before testing to collect normalization metrics and/or thresholds." )
553559 self .trainer .validate (model , dataloaders , None , verbose = False , datamodule = datamodule )
554- return self .trainer .test (model , dataloaders , ckpt_path , verbose , datamodule )
560+ return self .trainer .test (model , dataloaders , ckpt_path , verbose , datamodule , weights_only = False )
555561
556562 def predict (
557563 self ,
@@ -658,9 +664,10 @@ def predict(
658664 ckpt_path = None ,
659665 verbose = False ,
660666 datamodule = datamodule ,
667+ weights_only = False ,
661668 )
662669
663- return self .trainer .predict (model , dataloaders , datamodule , return_predictions , ckpt_path )
670+ return self .trainer .predict (model , dataloaders , datamodule , return_predictions , ckpt_path , weights_only = False )
664671
665672 def train (
666673 self ,
@@ -716,8 +723,14 @@ def train(
716723 # if the model is zero-shot or few-shot, we only need to run validate for normalization and thresholding
717724 self .trainer .validate (model , val_dataloaders , None , verbose = False , datamodule = datamodule )
718725 else :
719- self .trainer .fit (model , train_dataloaders , val_dataloaders , datamodule , ckpt_path )
720- return self .trainer .test (model , test_dataloaders , ckpt_path = ckpt_path , datamodule = datamodule )
726+ self .trainer .fit (model , train_dataloaders , val_dataloaders , datamodule , ckpt_path , weights_only = False )
727+ return self .trainer .test (
728+ model ,
729+ test_dataloaders ,
730+ ckpt_path = ckpt_path ,
731+ datamodule = datamodule ,
732+ weights_only = False ,
733+ )
721734
722735 def export (
723736 self ,
@@ -816,7 +829,7 @@ def export(
816829 self ._setup_trainer (model )
817830 if ckpt_path :
818831 ckpt_path = Path (ckpt_path ).resolve ()
819- model = model .__class__ .load_from_checkpoint (ckpt_path )
832+ model = model .__class__ .load_from_checkpoint (ckpt_path , weights_only = False )
820833
821834 if export_root is None :
822835 export_root = Path (self .trainer .default_root_dir )
0 commit comments