Skip to content

Commit 88ad686

Browse files
authored
Merge pull request #58 from InfiniTensor/develop-matmul
Add `matmul` operator
2 parents 7c03487 + 63fbc24 commit 88ad686

File tree

3 files changed

+42
-0
lines changed

3 files changed

+42
-0
lines changed

src/ntops/torch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ntops.torch.layer_norm import layer_norm
2020
from ntops.torch.le import le
2121
from ntops.torch.lt import lt
22+
from ntops.torch.matmul import matmul
2223
from ntops.torch.mm import mm
2324
from ntops.torch.mul import mul
2425
from ntops.torch.ne import ne
@@ -58,6 +59,7 @@
5859
"layer_norm",
5960
"le",
6061
"lt",
62+
"matmul",
6163
"mm",
6264
"mul",
6365
"ne",

src/ntops/torch/matmul.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import ntops


def matmul(input, other, *, out=None):
    """Matrix product of two tensors, dispatching to ``mm`` or ``bmm``.

    Mirrors the 2D/3D subset of ``torch.matmul``: two 2D tensors are
    multiplied with ``ntops.torch.mm``; otherwise both operands are
    promoted to 3D (a leading batch dimension of 1 is inserted where
    needed) and multiplied with ``ntops.torch.bmm``.

    Parameters
    ----------
    input : tensor
        Left operand; must be 2D or 3D.
    other : tensor
        Right operand; must be 2D or 3D.
    out : tensor, optional
        Output tensor forwarded to the underlying ``mm``/``bmm`` call.

    Raises
    ------
    ValueError
        If either operand is not 2D or 3D.
    """
    # Validate with an explicit raise rather than `assert`: asserts are
    # stripped under `python -O`, which would let unsupported ranks fall
    # through to `bmm` with incorrect results.
    if input.ndim not in (2, 3) or other.ndim not in (2, 3):
        raise ValueError("Currently, only 2D and 3D tensors are supported.")

    if input.ndim == 2 and other.ndim == 2:
        return ntops.torch.mm(input, other, out=out)

    # NOTE(review): for a mixed 2D/3D pair the 2D operand is unsqueezed to
    # batch size 1 — this presumably relies on ntops.torch.bmm broadcasting
    # (or the caller using matching batch sizes); verify against bmm's
    # contract, since torch.bmm itself requires equal batch sizes.
    if input.ndim < 3:
        input = input.unsqueeze(0)

    if other.ndim < 3:
        other = other.unsqueeze(0)

    return ntops.torch.bmm(input, other, out=out)

tests/test_matmul.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytest
2+
import torch
3+
4+
import ntops
5+
from tests.skippers import skip_if_cuda_not_available
6+
from tests.test_mm import generate_arguments
7+
8+
9+
@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
@pytest.mark.parametrize("b", (None, 1, 2, 3))
def test_matmul(b, m, n, k, dtype, device, rtol, atol):
    """Compare ntops.torch.matmul against torch.matmul for 2D and batched inputs."""
    if b is None:
        lhs_shape, rhs_shape = (m, k), (k, n)
    else:
        lhs_shape, rhs_shape = (b, m, k), (b, k, n)

    lhs = torch.randn(lhs_shape, dtype=dtype, device=device)
    rhs = torch.randn(rhs_shape, dtype=dtype, device=device)

    actual = ntops.torch.matmul(lhs, rhs)
    expected = torch.matmul(lhs, rhs)

    assert torch.allclose(actual, expected, rtol=rtol, atol=atol)

0 commit comments

Comments (0)