Skip to content

Commit cf83ed2

Browse files
🧪 fix(tests): CLI unit test (#3161)
* Use right parent class for registering the usage Signed-off-by: Ashwin Vaidya <[email protected]> * Add weights_only kwarg Signed-off-by: Ashwin Vaidya <[email protected]> * Add missing weights_only kwarg Signed-off-by: Ashwin Vaidya <[email protected]> --------- Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 0b70a1d commit cf83ed2

File tree

3 files changed

+25
-12
lines changed

3 files changed

+25
-12
lines changed

src/anomalib/cli/cli.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ def init_parser(**kwargs) -> ArgumentParser:
9999
def subcommands() -> dict[str, set[str]]:
100100
"""Skip predict subcommand as it is added later."""
101101
return {
102-
"fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
103-
"validate": {"model", "dataloaders", "datamodule"},
104-
"test": {"model", "dataloaders", "datamodule"},
102+
"fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule", "weights_only"},
103+
"validate": {"model", "dataloaders", "datamodule", "weights_only"},
104+
"test": {"model", "dataloaders", "datamodule", "weights_only"},
105105
}
106106

107107
@staticmethod

src/anomalib/cli/utils/help_formatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def add_usage(self, usage: str | None, actions: list, *args, **kwargs) -> None:
327327
elif self.verbosity_level == 1:
328328
actions = [action for action in actions if action.dest in REQUIRED_ARGUMENTS[self.subcommand]]
329329

330-
super().add_usage(usage, actions, *args, **kwargs)
330+
super(RichHelpFormatter, self).add_usage(usage, actions, *args, **kwargs)
331331

332332
def add_argument(self, action: argparse.Action) -> None:
333333
"""Add an argument to the help formatter.

src/anomalib/engine/engine.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -407,9 +407,15 @@ def fit(
407407
self._setup_trainer(model)
408408
if model.learning_type in {LearningType.ZERO_SHOT, LearningType.FEW_SHOT}:
409409
# if the model is zero-shot or few-shot, we only need to run validate for normalization and thresholding
410-
self.trainer.validate(model, val_dataloaders, datamodule=datamodule, ckpt_path=ckpt_path)
410+
self.trainer.validate(
411+
model,
412+
val_dataloaders,
413+
datamodule=datamodule,
414+
ckpt_path=ckpt_path,
415+
weights_only=False,
416+
)
411417
else:
412-
self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
418+
self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule, ckpt_path, weights_only=False)
413419

414420
def validate(
415421
self,
@@ -457,7 +463,7 @@ def validate(
457463
ckpt_path = Path(ckpt_path).resolve()
458464
if model:
459465
self._setup_trainer(model)
460-
return self.trainer.validate(model, dataloaders, ckpt_path, verbose, datamodule)
466+
return self.trainer.validate(model, dataloaders, ckpt_path, verbose, datamodule, weights_only=False)
461467

462468
def test(
463469
self,
@@ -551,7 +557,7 @@ def test(
551557
if self._should_run_validation(model or self.model, ckpt_path):
552558
logger.info("Running validation before testing to collect normalization metrics and/or thresholds.")
553559
self.trainer.validate(model, dataloaders, None, verbose=False, datamodule=datamodule)
554-
return self.trainer.test(model, dataloaders, ckpt_path, verbose, datamodule)
560+
return self.trainer.test(model, dataloaders, ckpt_path, verbose, datamodule, weights_only=False)
555561

556562
def predict(
557563
self,
@@ -658,9 +664,10 @@ def predict(
658664
ckpt_path=None,
659665
verbose=False,
660666
datamodule=datamodule,
667+
weights_only=False,
661668
)
662669

663-
return self.trainer.predict(model, dataloaders, datamodule, return_predictions, ckpt_path)
670+
return self.trainer.predict(model, dataloaders, datamodule, return_predictions, ckpt_path, weights_only=False)
664671

665672
def train(
666673
self,
@@ -716,8 +723,14 @@ def train(
716723
# if the model is zero-shot or few-shot, we only need to run validate for normalization and thresholding
717724
self.trainer.validate(model, val_dataloaders, None, verbose=False, datamodule=datamodule)
718725
else:
719-
self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
720-
return self.trainer.test(model, test_dataloaders, ckpt_path=ckpt_path, datamodule=datamodule)
726+
self.trainer.fit(model, train_dataloaders, val_dataloaders, datamodule, ckpt_path, weights_only=False)
727+
return self.trainer.test(
728+
model,
729+
test_dataloaders,
730+
ckpt_path=ckpt_path,
731+
datamodule=datamodule,
732+
weights_only=False,
733+
)
721734

722735
def export(
723736
self,
@@ -816,7 +829,7 @@ def export(
816829
self._setup_trainer(model)
817830
if ckpt_path:
818831
ckpt_path = Path(ckpt_path).resolve()
819-
model = model.__class__.load_from_checkpoint(ckpt_path)
832+
model = model.__class__.load_from_checkpoint(ckpt_path, weights_only=False)
820833

821834
if export_root is None:
822835
export_root = Path(self.trainer.default_root_dir)

0 commit comments

Comments
 (0)