
Torch compile fails with unknown error #100


Description

@JoeZhang-0x000

Is there an existing issue for this?

  • I have searched the existing issues.

Describe the bug:

When trying to compile a simple function with @torch.compile, it fails with the following error.
Interestingly, in a previous environment (torch==2.6, triton==3.2), a similar issue could sometimes be resolved by removing the docstring from the application function being compiled. With the current versions, however, it fails either way.
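
For context, "removing the docstring" meant compiling an application function otherwise identical to the one in the repro below, just without its docstring. A minimal sketch of that old workaround (nothing else changed):

def application(input, weight, eps):
    # Same body as in the repro below; only the docstring is deleted.
    input_square = input * input
    input_square_mean = ntl.sum(input_square) / input.shape[-1]
    input = input * ntl.rsqrt(input_square_mean + eps) * weight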

To reproduce:

from ninetoothed import Tensor, make, block_size, Symbol
import ninetoothed.language as ntl
import torch
from functools import lru_cache
from torch import nn
from icecream import ic

BLOCK_SIZE = Symbol("BLOCK_SIZE", constexpr=True)

def arrangement(input, weight, eps, BLOCK_SIZE=BLOCK_SIZE):
    '''
    input: (..., N)
    weight: (..., N)
    ->
    input_arranged: (..., 1)x(1, N)
    weight: (..., 1)x(1, N)
    '''
    ndim = len(input.shape)
    arrange_shape = tuple(1 for _ in range(ndim-1)) + (BLOCK_SIZE,)
    expand_shape = tuple(input.shape[:-1]) + (-1,)

    def _squeeze(x):
        for _ in range(ndim - 1):
            x.dtype = x.dtype.squeeze(0)
        return x

    input_arranged = input.tile(arrange_shape)
    input_arranged = _squeeze(input_arranged)
    weight_arranged = weight.tile(arrange_shape).expand(expand_shape)
    weight_arranged = _squeeze(weight_arranged)

    # subs = {
    #      input: Tensor(shape=(2, 3, 8)),
    #      weight: Tensor(shape=(1, 1, 8)),
    #      BLOCK_SIZE: 8
    # }

    # ic(input_arranged.eval(subs).shape)
    # ic(weight_arranged.eval(subs).shape)
    # exit(0)

    return input_arranged, weight_arranged, eps


def application(input, weight, eps):
    '''
    !!!!!!!!!!
    Deleting this docstring may resolve this issue in some versions.
    '''
    input_square = input * input
    input_square_mean = ntl.sum(input_square) / input.shape[-1]

    input = input * ntl.rsqrt(input_square_mean + eps) * weight

@lru_cache(1)
def premake(ndim):
    kernel = make(arrangement, application, (Tensor(ndim), Tensor(ndim), Tensor(0)), max_num_configs=2)
    return kernel

def rms_forward(input, weight, eps):
    # print('rms_forward', input.shape)
    assert weight.dim() == 1
    ndim = input.dim()
    weight = weight.view((1,)*(ndim-1) + (-1,))
    premake(ndim)(input, weight, eps, BLOCK_SIZE=input.shape[-1])
    return input

class RMSNorm(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        device: torch.device | str | None = None,
    ) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size, device=device))

    @torch.compile
    def rms_forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        return rms_forward(x, self.weight, self.eps)

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            return self.rms_forward(x)
        else:
            raise NotImplementedError("Not implemented!")

if __name__ == "__main__":
    DEVICE = "cuda"
    input = torch.randn((16384, 1024), device=DEVICE, dtype=torch.float16)
    input_c = input.clone()
    weight = torch.randn((1024,), device=DEVICE, dtype=torch.float16)
    eps = 1e-5
    def _rms_forward(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.float()
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + eps))
        x = x.to(orig_dtype).mul_(weight)
        return x
    torch_output = _rms_forward(input_c, weight, eps)
    output = rms_forward(input, weight, eps)
    
    ic(torch.allclose(torch_output, output, atol=1e-2, rtol=1e-2))
    ic((output - torch_output).abs().max().item())


    N = 1024
    layer = RMSNorm(N, device=DEVICE)
    x = torch.randn((2, 128, N), device=DEVICE, dtype=torch.float16)

    ic(layer(x))

Actual behavior:

ic| torch.allclose(torch_output, output, atol=1e-2, rtol=1e-2): True
Traceback (most recent call last):
  File "/home/workspace/nano-vllm/test/test_compile2.py", line 127, in <module>
    ic(layer(x))
       ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/workspace/nano-vllm/test/test_compile2.py", line 100, in forward
    return self.rms_forward(x)
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 663, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py", line 1541, in _call_user_compiler
    raise BackendCompilerFailed(
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/output_graph.py", line 1516, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 2349, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 2248, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/backends/common.py", line 101, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1160, in aot_module_simplified
    compiled_fn = AOTAutogradCache.load(
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/autograd_cache.py", line 779, in load
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1145, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 820, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 219, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 479, in __call__
    return self.compiler_fn(gm, example_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 2103, in fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 631, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/repro/after_aot.py", line 124, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/compile_fx.py", line 714, in _compile_fx_inner
    (key_info, cache_info) = FxGraphCache.prepare_key(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/codecache.py", line 1246, in prepare_key
    key, debug_lines = compiled_fx_graph_hash(
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/codecache.py", line 858, in compiled_fx_graph_hash
    details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/codecache.py", line 830, in __init__
    self.system_info = CacheBase.get_system()
                       ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/codecache.py", line 165, in get_system
    from triton.compiler.compiler import triton_key
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
ImportError: cannot import name 'triton_key' from 'triton.compiler.compiler' (/usr/local/lib/python3.12/dist-packages/triton/compiler/compiler.py)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
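
For reference, the extra logging the message asks for can be enabled from Python by setting the environment variables before torch is first imported; a minimal sketch:

import os

# Request dynamo's internal stack trace and extra logs, as suggested above.
os.environ["TORCHDYNAMO_VERBOSE"] = "1"
os.environ["TORCH_LOGS"] = "+dynamo"

import torch  # must be imported after the variables are set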

Environment details:

pytorch-triton 3.2.0+git4b3bb1f8b.nvinternal
torch 2.7.0a0+79aa17489c.nv25.4
triton 3.5.0
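
The failing import can also be checked on its own, outside torch.compile; a minimal sketch based on the import that torch/_inductor/codecache.py performs in the traceback above:

# Check whether the symbol Inductor's cache-key code imports is available
# in the triton installation that Python actually picks up.
try:
    from triton.compiler.compiler import triton_key
    print("triton_key is importable:", triton_key)
except ImportError as err:
    import triton
    print("triton", triton.__version__, "->", err)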
