Skip to content

Commit b0a1a77

Browse files
committed
Refactor tests to use pytest.mark.parametrize instead of for loops
1 parent 53746de commit b0a1a77

File tree

3 files changed

+32
-25
lines changed

3 files changed

+32
-25
lines changed

tests/test_div.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,22 @@
77

88

99
@skip_if_cuda_not_available
10+
@pytest.mark.parametrize(
11+
"rounding_mode",
12+
[
13+
None,
14+
pytest.param(
15+
"trunc", marks=pytest.mark.skip(reason="TODO: Test for `trunc` mode later.")
16+
),
17+
"floor",
18+
],
19+
)
1020
@pytest.mark.parametrize(*generate_arguments())
11-
def test_div(shape, dtype, device, rtol, atol):
21+
def test_div(shape, rounding_mode, dtype, device, rtol, atol):
1222
input = torch.randn(shape, dtype=dtype, device=device)
1323
other = torch.randn(shape, dtype=dtype, device=device)
1424

15-
for rounding_mode in (None, "trunc", "floor"):
16-
# TODO: Test for `trunc` mode later.
17-
if rounding_mode == "trunc":
18-
continue
25+
ninetoothed_output = ntops.torch.div(input, other, rounding_mode=rounding_mode)
26+
reference_output = torch.div(input, other, rounding_mode=rounding_mode)
1927

20-
ninetoothed_output = ntops.torch.div(input, other, rounding_mode=rounding_mode)
21-
reference_output = torch.div(input, other, rounding_mode=rounding_mode)
22-
23-
assert torch.allclose(
24-
ninetoothed_output, reference_output, rtol=rtol, atol=atol
25-
)
28+
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)

tests/test_gelu.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,20 @@
88

99

1010
@skip_if_cuda_not_available
11+
@pytest.mark.parametrize(
12+
"approximate",
13+
(
14+
"none",
15+
pytest.param(
16+
"tanh", marks=pytest.mark.skip(reason="TODO: Test for `tanh` mode later.")
17+
),
18+
),
19+
)
1120
@pytest.mark.parametrize(*generate_arguments())
12-
def test_gelu(shape, dtype, device, rtol, atol):
21+
def test_gelu(shape, approximate, dtype, device, rtol, atol):
1322
input = torch.randn(shape, dtype=dtype, device=device)
1423

15-
for approximate in ("none", "tanh"):
16-
ninetoothed_output = ntops.torch.gelu(input)
17-
reference_output = F.gelu(input)
24+
ninetoothed_output = ntops.torch.gelu(input, approximate=approximate)
25+
reference_output = F.gelu(input, approximate=approximate)
1826

19-
assert torch.allclose(
20-
ninetoothed_output, reference_output, rtol=rtol, atol=atol
21-
)
27+
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)

tests/test_relu.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,12 @@
88

99

1010
@skip_if_cuda_not_available
11+
@pytest.mark.parametrize("inplace", (False, True))
1112
@pytest.mark.parametrize(*generate_arguments())
12-
def test_relu(shape, dtype, device, rtol, atol):
13+
def test_relu(shape, inplace, dtype, device, rtol, atol):
1314
input = torch.randn(shape, dtype=dtype, device=device)
1415

15-
for inplace in (False, True):
16-
ninetoothed_output = ntops.torch.relu(input, inplace)
17-
reference_output = F.relu(input, inplace)
16+
ninetoothed_output = ntops.torch.relu(input, inplace)
17+
reference_output = F.relu(input, inplace)
1818

19-
assert torch.allclose(
20-
ninetoothed_output, reference_output, rtol=rtol, atol=atol
21-
)
19+
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)