Conversation
DeBERTa/apps/train.py
Outdated
| with torch.no_grad(): | ||
| trainer.train_step(batch['input_ids'], batch['type_ids'], batch['position_ids'], batch['input_mask'], batch['labels']) | ||
| # conversion fails now with: | ||
| # site-packages/torch/onnx/utils.py:617: UserWarning: ONNX export failed on ATen operator broadcast_tensors |
There was a problem hiding this comment.
broadcast_tensor and mse_loss are ops that are not implemented in ONNX currently. To get unblocked need to modify functional.py as per below comment
DeBERTa/apps/train.py
Outdated
| with torch.no_grad(): | ||
| trainer.train_step(batch['input_ids'], batch['type_ids'], batch['position_ids'], batch['input_mask'], batch['labels']) | ||
| # conversion fails now with: | ||
| # site-packages/torch/onnx/utils.py:617: UserWarning: ONNX export failed on ATen operator broadcast_tensors |
There was a problem hiding this comment.
mse_loss implementation in https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L2682 uses 2 ops that are not implemented: broadcast_tensors() and mse_loss(). Working around this to get unblocked, made a patch:
#expanded_input, expanded_target = torch.broadcast_tensors(input, target)
expanded_input = input + torch.zeros(target.size())
expanded_target = target + torch.zeros(input.size())
#ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
t = expanded_input - expanded_target
t = t * t
ret = torch.mean(t)
| self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) | ||
| self.v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) | ||
| # Looks like params below are never updated and const, so removing them | ||
| #self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) |
There was a problem hiding this comment.
q_bias and v_bias are always const, so commenting them out
Previous iterations i tried to redefine StableDropout to inherit from nn.Dropout, but it led to regression in model stats. Could not figure out why. If i do change this way there is no regression. Something was missing with just redefining StableDropout. |
5315a01 to
c81eb40
Compare
79dbe25 to
e59f09f
Compare
Changes needed to convert DeBerta to ONNX