Description
Is there an existing issue for this?
- I have searched the existing issues.
Describe the bug:
When trying to compile a simple function with @torch.compile, it fails with the following error.
Interestingly, in a previous environment (torch 2.6, triton 3.2), a similar issue could sometimes be resolved by removing the docstring from the application function being compiled. With the current versions, however, it fails regardless of whether the docstring is present.
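For what it's worth, the failure does not look specific to the ninetoothed kernel: the traceback below dies while inductor hashes the graph for its cache, so presumably any @torch.compile'd function would hit the same import in this environment. A minimal check (a sketch; assumes the same container/venv as the repro below):

import torch

@torch.compile  # default inductor backend
def f(x):
    return x * 2

# If the triton_key import is broken, this first call should raise the same
# BackendCompilerFailed / ImportError as the full reproduction below.
f(torch.randn(8, device="cuda"))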
To reproduce:
from ninetoothed import Tensor, make, block_size, Symbol
import ninetoothed.language as ntl
import torch
from functools import lru_cache
from torch import nn
from icecream import ic

BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True)


def arrangement(
    input,
    weight,
    eps,
    BLOCK_SIZE=BLOCK_SIZE,
):
    '''
    input: (..., N)
    weight: (..., N)
    ->
    input_arranged: (..., 1)x(1, N)
    weight: (..., 1)x(1, N)
    '''
    ndim = len(input.shape)
    arrange_shape = tuple(1 for _ in range(ndim - 1)) + (BLOCK_SIZE,)
    expand_shape = tuple(input.shape[:-1]) + (-1,)

    def _squeeze(x):
        for _ in range(ndim - 1):
            x.dtype = x.dtype.squeeze(0)
        return x

    input_arranged = input.tile(arrange_shape)
    input_arranged = _squeeze(input_arranged)
    weight_arranged = weight.tile(arrange_shape).expand(expand_shape)
    weight_arranged = _squeeze(weight_arranged)
    # subs = {
    #     input: Tensor(shape=(2, 3, 8)),
    #     weight: Tensor(shape=(1, 1, 8)),
    #     BLOCK_SIZE: 8,
    # }
    # ic(input_arranged.eval(subs).shape)
    # ic(weight_arranged.eval(subs).shape)
    # exit(0)
    return input_arranged, weight_arranged, eps


def application(input, weight, eps):
    '''
    !!!!!!!!!!
    Deleting this docstring may resolve the issue in some versions.
    '''
    input_square = input * input
    input_square_mean = ntl.sum(input_square) / input.shape[-1]
    input = input * ntl.rsqrt(input_square_mean + eps) * weight


@lru_cache(1)
def premake(ndim):
    kernel = make(
        arrangement,
        application,
        (Tensor(ndim), Tensor(ndim), Tensor(0)),
        max_num_configs=2,
    )
    return kernel


def rms_forward(input, weight, eps):
    # print('rms_forward', input.shape)
    assert weight.dim() == 1
    ndim = input.dim()
    weight = weight.view((1,) * (ndim - 1) + (-1,))
    premake(ndim)(input, weight, eps, BLOCK_SIZE=input.shape[-1])
    return input


class RMSNorm(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        device: torch.device | str | None = None,
    ) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size, device=device))

    @torch.compile
    def rms_forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        return rms_forward(x, self.weight, self.eps)

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            return self.rms_forward(x)
        else:
            raise NotImplementedError("Not implemented!")


if __name__ == "__main__":
    DEVICE = "cuda"
    input = torch.randn((16384, 1024), device=DEVICE, dtype=torch.float16)
    input_c = input.clone()
    weight = torch.randn((1024,), device=DEVICE, dtype=torch.float16)
    eps = 1e-5

    def _rms_forward(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # Eager-mode reference implementation for comparison.
        orig_dtype = x.dtype
        x = x.float()
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + eps))
        x = x.to(orig_dtype).mul_(weight)
        return x

    torch_output = _rms_forward(input_c, weight, eps)
    output = rms_forward(input, weight, eps)
    ic(torch.allclose(torch_output, output, atol=1e-2, rtol=1e-2))
    ic((output - torch_output).abs().max().item())

    N = 1024
    layer = RMSNorm(N, device=DEVICE)
    x = torch.randn((2, 128, N), device=DEVICE, dtype=torch.float16)
    ic(layer(x))
Expected behavior:
ic| torch.allclose(torch_output, output, atol=1e-2, rtol=1e-2): True
and the subsequent layer(x) call should likewise return the normalized tensor without raising.
Actual behavior:
The direct kernel call passes the allclose check, but the @torch.compile'd path crashes:
Traceback (most recent call last):
  File "/home/workspace/nano-vllm/test/test_compile2.py", line 127, in <module>
    ic(layer(x))
    ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/workspace/nano-vllm/test/test_compile2.py", line 100, in forward
    return self.rms_forward(x)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 663, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py", line 1541, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py", line 1516, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 2349, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 2248, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/backends/common.py", line 101, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1160, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 779, in load
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1145, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 820, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 219, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 479, in __call__
    return self.compiler_fn(gm, example_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 2103, in fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 631, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 714, in _compile_fx_inner
    (key_info, cache_info) = FxGraphCache.prepare_key(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/codecache.py", line 1246, in prepare_key
    key, debug_lines = compiled_fx_graph_hash(
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/codecache.py", line 858, in compiled_fx_graph_hash
    details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/codecache.py", line 830, in __init__
    self.system_info = CacheBase.get_system()
                       ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/codecache.py", line 165, in get_system
    from triton.compiler.compiler import triton_key
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
ImportError: cannot import name 'triton_key' from 'triton.compiler.compiler' (/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
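The traceback bottoms out in torch/_inductor/codecache.py doing from triton.compiler.compiler import triton_key, and that name is missing from the Triton that actually gets imported. The failing import can be reproduced directly, which also confirms which Triton the interpreter resolves (a quick probe, nothing ninetoothed-specific):

import triton

# Show which Triton build / module path is actually imported.
print(triton.__version__, triton.__file__)

# The exact import that inductor's codecache performs (see traceback above):
try:
    from triton.compiler.compiler import triton_key
    print("triton_key import OK")
except ImportError as e:
    print("triton_key import failed:", e)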
Environment details:
pytorch-triton 3.2.0+git4b3bb1f8b.nvinternal
torch 2.7.0a0+79aa17489c.nv25.4
triton 3.5.0
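Note that the list above contains both pytorch-triton 3.2.0 (the build this torch wheel appears to pin) and a standalone triton 3.5.0. Both distributions install the same triton module, so whichever was installed last shadows the other, and that version mismatch against torch 2.7 may be what breaks the triton_key import. A sketch to compare what pip and the interpreter each see (package names taken from the list above):

import importlib.metadata as md

import triton

# The module the interpreter actually resolves:
print("imported triton:", triton.__version__, "from", triton.__file__)

# What pip believes is installed; both entries can claim the triton module.
for dist in ("triton", "pytorch-triton"):
    try:
        print(dist, md.version(dist))
    except md.PackageNotFoundError:
        print(dist, "not installed")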