@@ -365,6 +365,7 @@ def fit(
365365 val_dataloaders : EVAL_DATALOADERS | None = None ,
366366 datamodule : AnomalibDataModule | None = None ,
367367 ckpt_path : str | Path | None = None ,
368+ ** kwargs , # noqa: ARG002
368369 ) -> None :
369370 """Fit the model using the trainer.
370371
@@ -379,6 +380,7 @@ def fit(
379380 Defaults to None.
380381 ckpt_path (str | None, optional): Checkpoint path. If provided, the model will be loaded from this path.
381382 Defaults to None.
383+ **kwargs: Additional arguments passed to PyTorch Lightning Trainer's fit method.
382384
383385 CLI Usage:
384386 1. you can pick a model, and you can run through the MVTec dataset.
@@ -415,7 +417,14 @@ def fit(
415417 weights_only = False ,
416418 )
417419 else :
418- self .trainer .fit (model , train_dataloaders , val_dataloaders , datamodule , ckpt_path , weights_only = False )
420+ self .trainer .fit (
421+ model ,
422+ train_dataloaders ,
423+ val_dataloaders ,
424+ datamodule ,
425+ ckpt_path ,
426+ weights_only = False ,
427+ )
419428
420429 def validate (
421430 self ,
@@ -424,6 +433,7 @@ def validate(
424433 ckpt_path : str | Path | None = None ,
425434 verbose : bool = True ,
426435 datamodule : AnomalibDataModule | None = None ,
436+ ** kwargs , # noqa: ARG002
427437 ) -> _EVALUATE_OUTPUT | None :
428438 """Validate the model using the trainer.
429439
@@ -441,6 +451,7 @@ def validate(
441451 AnomalibDataModule` that defines the
442452 :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.
443453 Defaults to None.
454+ **kwargs: Additional arguments passed to PyTorch Lightning Trainer's validate method.
444455
445456 Returns:
446457 _EVALUATE_OUTPUT | None: Validation results.
@@ -472,6 +483,7 @@ def test(
472483 ckpt_path : str | Path | None = None ,
473484 verbose : bool = True ,
474485 datamodule : AnomalibDataModule | None = None ,
486+ ** kwargs , # noqa: ARG002
475487 ) -> _EVALUATE_OUTPUT :
476488 """Test the model using the trainer.
477489
@@ -497,6 +509,7 @@ def test(
497509 A :class:`~lightning.pytorch.core.datamodule.AnomalibDataModule` that defines
498510 the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.
499511 Defaults to None.
512+ **kwargs: Additional arguments passed to PyTorch Lightning Trainer's test method.
500513
501514 Returns:
502515 _EVALUATE_OUTPUT: A List of dictionaries containing the test results. 1 dict per dataloader.
@@ -556,7 +569,7 @@ def test(
556569
557570 if self ._should_run_validation (model or self .model , ckpt_path ):
558571 logger .info ("Running validation before testing to collect normalization metrics and/or thresholds." )
559- self .trainer .validate (model , dataloaders , None , verbose = False , datamodule = datamodule )
572+ self .trainer .validate (model , dataloaders , None , verbose = False , datamodule = datamodule , weights_only = False )
560573 return self .trainer .test (model , dataloaders , ckpt_path , verbose , datamodule , weights_only = False )
561574
562575 def predict (
@@ -568,6 +581,7 @@ def predict(
568581 return_predictions : bool | None = None ,
569582 ckpt_path : str | Path | None = None ,
570583 data_path : str | Path | None = None ,
584+ ** kwargs , # noqa: ARG002
571585 ) -> _PREDICT_OUTPUT | None :
572586 """Predict using the model using the trainer.
573587
@@ -601,6 +615,7 @@ def predict(
601615 data_path (str | Path | None):
602616 Path to the image or folder containing images to generate predictions for.
603617 Defaults to None.
618+ **kwargs: Additional arguments passed to PyTorch Lightning Trainer's predict method.
604619
605620 Returns:
606621 _PREDICT_OUTPUT | None: Predictions.
@@ -747,6 +762,7 @@ def export(
747762 ov_kwargs : dict [str , Any ] | None = None ,
748763 onnx_kwargs : dict [str , Any ] | None = None ,
749764 ckpt_path : str | Path | None = None ,
765+ ** kwargs , # noqa: ARG002
750766 ) -> Path | None :
751767 r"""Export the model in PyTorch, ONNX or OpenVINO format.
752768
@@ -785,6 +801,7 @@ def export(
785801 See https://pytorch.org/docs/stable/onnx.html#torch.onnx.export for details.
786802 Defaults to ``None``.
787803 ckpt_path (str | Path | None): Checkpoint path. If provided, the model will be loaded from this path.
804+ **kwargs: Additional arguments.
788805
789806 Returns:
790807 Path: Path to the exported model.
0 commit comments