Skip to content

Commit 4ea25c0

Browse files
committed
move metrics - support url for loading ckpt
1 parent 20dd07c commit 4ea25c0

File tree

6 files changed

+14
-9
lines changed

6 files changed

+14
-9
lines changed

LitModel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from optim.schedulers import get_lr_scheduler
99
import models.surgeries
1010
from models.models import get_net
11-
from metrics import Accuracy, wAUC, PE, MD5
11+
from metrics.roc_metrics import wAUC, PE, MD5
12+
from torchmetrics.classification.accuracy import Accuracy
1213

1314
class LitModel(pl.LightningModule):
1415
"""
@@ -41,7 +42,7 @@ def __build_model(self):
4142
self.net = get_net(self.args.model.backbone,
4243
num_classes=self.num_classes,
4344
in_chans=self.in_chans,
44-
imagenet=self.args.ckpt.imagenet,
45+
pretrained=self.args.ckpt.pretrained,
4546
ckpt_path=self.args.ckpt.seed_from)
4647

4748
# 2. Do surgery if needed

cfg/default.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ ckpt:
1919
resume_from: null
2020
seed_from: null
2121
load_fc: True
22-
imagenet: True
22+
pretrained: True
2323

2424
optimizer:
2525
# do not decay batch norm and bias and FC

metrics.py renamed to metrics/roc_metrics.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from tools.numpy_utils import check_nans
1212
from torchmetrics.utilities.data import dim_zero_cat
1313
from torchmetrics.classification.auc import AUC
14-
from torchmetrics.classification.accuracy import Accuracy
1514
from torchmetrics.functional.classification.auc import _auc_update
1615

1716
@_numpy_metric_conversion

models/models.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22
warnings.simplefilter(action='ignore', category=FutureWarning)
33
from functools import partial
4+
from validators import url
45
import numpy as np
56
from torch import nn
67
import timm
@@ -38,12 +39,15 @@
3839

3940
}
4041

41-
def get_net(model_name, num_classes=2, in_chans=3, imagenet=True, ckpt_path=None, strict_loading=False):
42-
net = zoo_params[model_name]['init_op'](num_classes=num_classes, in_chans=in_chans, pretrained=imagenet)
42+
def get_net(model_name, num_classes=2, in_chans=3, pretrained=True, ckpt_path=None, strict_loading=False):
43+
net = zoo_params[model_name]['init_op'](num_classes=num_classes, in_chans=in_chans, pretrained=pretrained)
4344
net.model_name = model_name
44-
45+
4546
if ckpt_path is not None:
46-
state_dict = torch.load(ckpt_path)['state_dict']
47+
if url(ckpt_path):
48+
state_dict = torch.hub.load_state_dict_from_url(ckpt_path)['state_dict']
49+
else:
50+
state_dict = torch.load(ckpt_path)['state_dict']
4751
state_dict = {k.split('net.')[1]: v for k, v in state_dict.items()}
4852

4953
# Check FC compatibility

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ omegaconf==2.1.1
1313
wonderwords==2.2.0
1414
braceexpand==0.1.7
1515
wandb
16+
validators

tests/test_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ ckpt:
1919
resume_from: null
2020
seed_from: null
2121
load_fc: True
22-
imagenet: True
22+
pretrained: True
2323

2424
optimizer:
2525
# do not decay batch norm and bias and FC

0 commit comments

Comments
 (0)