|
7 | 7 |
|
8 | 8 |
|
9 | 9 | @skip_if_cuda_not_available |
| 10 | +@pytest.mark.parametrize( |
| 11 | + "rounding_mode", |
| 12 | + [ |
| 13 | + None, |
| 14 | + pytest.param( |
| 15 | + "trunc", marks=pytest.mark.skip(reason="TODO: Test for `trunc` mode later.") |
| 16 | + ), |
| 17 | + "floor", |
| 18 | + ], |
| 19 | +) |
10 | 20 | @pytest.mark.parametrize(*generate_arguments()) |
11 | | -def test_div(shape, dtype, device, rtol, atol): |
| 21 | +def test_div(shape, rounding_mode, dtype, device, rtol, atol): |
12 | 22 | input = torch.randn(shape, dtype=dtype, device=device) |
13 | 23 | other = torch.randn(shape, dtype=dtype, device=device) |
14 | 24 |
|
15 | | - for rounding_mode in (None, "trunc", "floor"): |
16 | | - # TODO: Test for `trunc` mode later. |
17 | | - if rounding_mode == "trunc": |
18 | | - continue |
| 25 | + ninetoothed_output = ntops.torch.div(input, other, rounding_mode=rounding_mode) |
| 26 | + reference_output = torch.div(input, other, rounding_mode=rounding_mode) |
19 | 27 |
|
20 | | - ninetoothed_output = ntops.torch.div(input, other, rounding_mode=rounding_mode) |
21 | | - reference_output = torch.div(input, other, rounding_mode=rounding_mode) |
22 | | - |
23 | | - assert torch.allclose( |
24 | | - ninetoothed_output, reference_output, rtol=rtol, atol=atol |
25 | | - ) |
| 28 | + assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol) |
0 commit comments