diff --git a/thop/profile.py b/thop/profile.py index edcdaae..2639924 100644 --- a/thop/profile.py +++ b/thop/profile.py @@ -155,7 +155,7 @@ def add_hooks(m): def profile( model: nn.Module, - inputs, + kwargs, custom_ops=None, verbose=True, ret_layer_info=False, @@ -208,7 +208,7 @@ def add_hooks(m: nn.Module): model.apply(add_hooks) with torch.no_grad(): - model(*inputs) + model(**kwargs,) def dfs_count(module: nn.Module, prefix="\t") -> (int, int): total_ops, total_params = module.total_ops.item(), 0