thop calculates torch.nn module params incorrectly #217

@qsimeon

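# pip install -q thop   (in a Jupyter cell: !pip install -q thop)

import torch
from thop import profile, clever_format

# DEVICE originally came from a local `utils` module; defined inline so the snippet runs standalone
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"DEVICE: {DEVICE}")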

# Assuming 'model' is your PyTorch model and 'input' is a tensor representing the input to the model
input = torch.randn(1, 10)  # example input: a single sequence of length 10
input = input.to(DEVICE)

model = torch.nn.Linear(10, 11)  # example model
model = model.to(DEVICE)
# model.eval()

# Measure MACs (thop reports multiply-accumulates, not FLOPs) and the parameter count
macs, params = profile(model, inputs=(input,), verbose=False)

print(f"MACs: {macs}, params: {params}")

macs, params = clever_format([macs, params], "%.3f")

print(f"MACs: {macs}, params: {params}")

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Parameters: {num_params}")

Outputs:

DEVICE: cpu
MACs: 11000.0, params: 0
MACs: 11.000K, params: 0.000B
Parameters: 1012
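For reference, the expected counts for `torch.nn.Linear(10, 11)` can be worked out by hand: 10 × 11 weights plus 11 biases gives 121 parameters, and with one multiply-accumulate per weight per sample (assumed here to be the convention thop's Linear counter uses), a single input sample costs 110 MACs. The helper names below are hypothetical. These values differ from the 1012 and 11000 printed above, so that output may have come from a slightly different model or input shape than the snippet shows.

```python
def linear_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameters in a fully connected layer like torch.nn.Linear (hypothetical helper)."""
    weights = in_features * out_features   # weight matrix is out_features x in_features
    biases = out_features if bias else 0   # one bias term per output feature
    return weights + biases


def linear_macs(in_features: int, out_features: int, batch: int = 1) -> int:
    """MACs for one forward pass, assuming one multiply-accumulate per weight per sample."""
    return batch * in_features * out_features


print(linear_param_count(10, 11))  # 121
print(linear_macs(10, 11))         # 110
```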


For some reason, the ResNet-50 example from the README seems to work fine:

import torch
from torchvision.models import resnet50
from thop import profile, clever_format

model = resnet50()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input,))

macs, params = clever_format([macs, params], "%.3f")
print(f"MACs: {macs}, params: {params}")

Output:

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
MACs: 4.134G, params: 25.557M

Why doesn't this work for basic torch.nn modules such as Linear and LSTM, yet works for more complicated models?
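One plausible explanation, offered as an assumption rather than a confirmed diagnosis: recent thop releases appear to total the counts with a recursive helper (`dfs_count` in `thop/profile.py`) that starts the parameter total at 0 for the module passed to `profile` and only adds the counts of its children, while the operation total starts from the root's own counter. A bare `nn.Linear` or `nn.LSTM` has no children, so its own parameters would be skipped even though its MACs are reported; ResNet-50 keeps every parameter inside child modules, so nothing is lost. The pure-Python sketch models that behaviour with a hypothetical `FakeModule` class (not a thop or PyTorch type).

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class FakeModule:
    """Hypothetical stand-in for an nn.Module: its own counts plus child modules."""
    name: str
    own_params: int = 0
    own_ops: int = 0
    children: List["FakeModule"] = field(default_factory=list)


def dfs_count(module: FakeModule) -> Tuple[int, int]:
    # Assumed behaviour: ops start from the module's own counter,
    # but params start at 0 and are only accumulated from children.
    total_ops, total_params = module.own_ops, 0
    for child in module.children:
        if child.children:
            # Container module: recurse into its children.
            ops, params = dfs_count(child)
        else:
            # Leaf module with a counting hook: use its own counters.
            ops, params = child.own_ops, child.own_params
        total_ops += ops
        total_params += params
    return total_ops, total_params


bare_linear = FakeModule("Linear", own_params=121, own_ops=110)
print(dfs_count(bare_linear))                                           # (110, 0)
print(dfs_count(FakeModule("Sequential", children=[bare_linear])))      # (110, 121)
```

If this reading is right, wrapping the layer as `torch.nn.Sequential(model)` before calling `profile` would be a workaround, though the behaviour may depend on the installed thop version.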
