diff --git a/thop/profile.py b/thop/profile.py
index 4b98364..59dbfdc 100644
--- a/thop/profile.py
+++ b/thop/profile.py
@@ -2,7 +2,7 @@
 from thop.vision.basic_hooks import *
 from thop.rnn_hooks import *
 
-
+from copy import deepcopy
 
 # logger = logging.getLogger(__name__)
 # logger.setLevel(logging.INFO)
@@ -182,13 +182,12 @@ def add_hooks(m: nn.Module):
             handler_collection[m] = (m.register_forward_hook(fn), m.register_forward_hook(count_parameters))
         types_collection.add(m_type)
 
-    prev_training_status = model.training
-
-    model.eval()
-    model.apply(add_hooks)
+    model_cpy = deepcopy(model)
+    model_cpy.eval()
+    model_cpy.apply(add_hooks)
 
     with torch.no_grad():
-        model(*inputs)
+        model_cpy(*inputs)
 
     def dfs_count(module: nn.Module, prefix="\t") -> (int, int):
         total_ops, total_params = 0, 0
@@ -206,14 +205,8 @@ def dfs_count(module: nn.Module, prefix="\t") -> (int, int):
         # print(prefix, module._get_name(), (total_ops.item(), total_params.item()))
         return total_ops, total_params
 
-    total_ops, total_params = dfs_count(model)
+    total_ops, total_params = dfs_count(model_cpy)
 
-    # reset model to original status
-    model.train(prev_training_status)
-    for m, (op_handler, params_handler) in handler_collection.items():
-        op_handler.remove()
-        params_handler.remove()
-        m._buffers.pop("total_ops")
-        m._buffers.pop("total_params")
+    del model_cpy
 
     return total_ops, total_params
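
After this change, profile() runs the forward pass on a deepcopy of the model, so the caller's model is never mutated: it keeps its training/eval state and picks up no "total_ops"/"total_params" buffers or hook handles from the counting pass. A minimal sketch of the new contract, assuming a thop install with this patch applied (the toy model below is purely illustrative):

import torch
import torch.nn as nn
from thop import profile

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 8, 3))
model.train()  # deliberately leave the model in training mode

macs, params = profile(model, inputs=(torch.randn(1, 3, 32, 32),))

# The caller's model is untouched: still in training mode, and the
# counting hooks left no "total_ops"/"total_params" buffers behind.
assert model.training
assert all("total_ops" not in m._buffers for m in model.modules())
print("MACs:", macs, "params:", params)

The tradeoff is that deepcopy temporarily doubles the model's memory footprint and requires every attribute of the model to be deep-copyable; in exchange, the manual cleanup block (restoring the training flag, removing hooks, popping buffers) can be deleted outright.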