Skip to content

Commit 7d3f045

Browse files
authored
[FRONTEND] more modular dialect registration; remove some unnecessary includes (triton-lang#2847)
1 parent 9c3ec7b commit 7d3f045

File tree

7 files changed

+90
-161
lines changed

7 files changed

+90
-161
lines changed

python/src/ir.cc

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
#include "mlir/Pass/Pass.h"
1313
#include "mlir/Pass/PassManager.h"
1414
#include "mlir/Support/FileUtilities.h"
15+
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
16+
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
1517
#include "mlir/Transforms/Passes.h"
1618
#include "triton/Analysis/Allocation.h"
17-
#include "triton/Dialect/NVGPU/IR/Dialect.h"
1819
#include "triton/Dialect/Triton/IR/Dialect.h"
1920
#include "triton/Dialect/Triton/IR/Types.h"
20-
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
2121
#include "triton/Tools/Sys/GetEnv.hpp"
2222
#include <pybind11/pybind11.h>
2323
#include <pybind11/stl.h>
@@ -184,17 +184,20 @@ void init_triton_ir(py::module &&m) {
184184
.value("RTNE", mlir::triton::RoundingMode::RTNE);
185185

186186
py::class_<mlir::MLIRContext>(m, "context", py::module_local())
187-
.def(py::init<>())
188-
.def("load_triton", [](mlir::MLIRContext &self) {
189-
self.getOrLoadDialect<mlir::triton::TritonDialect>();
190-
self.getOrLoadDialect<mlir::index::IndexDialect>();
191-
self.getOrLoadDialect<mlir::triton::TritonDialect>();
192-
self.getOrLoadDialect<mlir::gpu::GPUDialect>();
193-
// we load LLVM because the frontend uses LLVM.undef for
194-
// some placeholders
195-
self.getOrLoadDialect<mlir::LLVM::LLVMDialect>();
196-
self.getOrLoadDialect<mlir::tensor::TensorDialect>();
197-
});
187+
.def(py::init<>());
188+
189+
m.def("load_dialects", [](mlir::MLIRContext &context) {
190+
mlir::DialectRegistry registry;
191+
registry.insert<mlir::triton::TritonDialect,
192+
mlir::triton::gpu::TritonGPUDialect,
193+
mlir::math::MathDialect, mlir::arith::ArithDialect,
194+
mlir::index::IndexDialect, mlir::scf::SCFDialect,
195+
mlir::cf::ControlFlowDialect, mlir::LLVM::LLVMDialect>();
196+
mlir::registerBuiltinDialectTranslation(registry);
197+
mlir::registerLLVMDialectTranslation(registry);
198+
context.appendDialectRegistry(registry);
199+
context.loadAllAvailableDialects();
200+
});
198201

199202
py::class_<mlir::Type>(m, "type", py::module_local())
200203
.def("is_integer", &mlir::Type::isInteger)
@@ -426,19 +429,6 @@ void init_triton_ir(py::module &&m) {
426429
m.def(
427430
"parse_mlir_module",
428431
[](const std::string &inputFilename, mlir::MLIRContext &context) {
429-
// initialize registry
430-
// note: we initialize llvm for undef
431-
mlir::DialectRegistry registry;
432-
registry.insert<
433-
mlir::triton::TritonDialect, mlir::triton::gpu::TritonGPUDialect,
434-
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
435-
mlir::triton::nvgpu::NVGPUDialect, mlir::math::MathDialect,
436-
mlir::arith::ArithDialect, mlir::index::IndexDialect,
437-
mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect,
438-
mlir::LLVM::LLVMDialect>();
439-
context.appendDialectRegistry(registry);
440-
context.loadAllAvailableDialects();
441-
442432
// parse module
443433
mlir::OwningOpRef<mlir::ModuleOp> module =
444434
mlir::parseSourceFile<mlir::ModuleOp>(inputFilename, &context);

python/src/llvm.cc

Lines changed: 10 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,24 @@
11
#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp
2-
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
3-
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
4-
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
5-
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
62
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
73
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
84
#include "llvm/ADT/SmallVector.h"
95
#include "llvm/IR/LLVMContext.h"
10-
#include "llvm/IR/Module.h"
11-
#include "llvm/IR/PassManager.h"
12-
#include "llvm/IRReader/IRReader.h"
13-
#include "llvm/Pass.h"
14-
#include "llvm/Passes/OptimizationLevel.h"
15-
#include "llvm/Passes/PassBuilder.h"
16-
#include "llvm/Support/CodeGen.h"
17-
#include "llvm/Target/TargetMachine.h"
18-
#include "llvm/Transforms/InstCombine/InstCombine.h"
19-
#include <pybind11/pybind11.h>
20-
#include <pybind11/stl.h>
21-
22-
#ifdef _WIN32
23-
#define WIN32_LEAN_AND_MEAN
24-
#include <windows.h>
25-
#else
26-
#include <dlfcn.h>
27-
#endif
28-
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
29-
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
30-
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
31-
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
32-
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
33-
#include "mlir/Target/LLVMIR/Export.h"
34-
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
35-
#include "triton/Tools/Sys/GetEnv.hpp"
36-
#include "llvm/Linker/Linker.h"
37-
#include <filesystem>
38-
#include <iterator>
39-
40-
#include "llvm/ADT/APInt.h"
41-
#include "llvm/ADT/STLExtras.h"
42-
#include "llvm/ADT/SmallVector.h"
43-
#include "llvm/Analysis/TargetTransformInfo.h"
44-
#include "llvm/IR/CallingConv.h"
45-
#include "llvm/IR/Constants.h"
46-
#include "llvm/IR/IRBuilder.h"
476
#include "llvm/IR/LegacyPassManager.h"
487
#include "llvm/IR/Module.h"
8+
#include "llvm/IR/PassManager.h"
499
#include "llvm/IR/Verifier.h"
5010
#include "llvm/IRReader/IRReader.h"
5111
#include "llvm/Linker/Linker.h"
5212
#include "llvm/MC/TargetRegistry.h"
5313
#include "llvm/Pass.h"
5414
#include "llvm/Passes/OptimizationLevel.h"
5515
#include "llvm/Passes/PassBuilder.h"
56-
#include "llvm/Support/CommandLine.h"
57-
#include "llvm/Support/Error.h"
58-
#include "llvm/Support/FormatVariadic.h"
59-
#include "llvm/Support/SourceMgr.h"
60-
#include "llvm/Support/TargetSelect.h"
16+
#include "llvm/Support/CodeGen.h"
6117
#include "llvm/Target/TargetMachine.h"
6218
#include "llvm/Transforms/IPO/AlwaysInliner.h"
6319
#include "llvm/Transforms/InstCombine/InstCombine.h"
20+
#include <pybind11/pybind11.h>
21+
#include <pybind11/stl.h>
6422

6523
namespace py = pybind11;
6624

@@ -72,23 +30,6 @@ struct BreakStructPhiNodesPass : PassInfoMixin<BreakStructPhiNodesPass> {
7230
} // namespace llvm
7331

7432
using namespace llvm;
75-
//
76-
// TODO: move to python
77-
static void initLLVM() {
78-
static std::once_flag init_flag;
79-
std::call_once(init_flag, []() {
80-
LLVMInitializeNVPTXTargetInfo();
81-
LLVMInitializeNVPTXTarget();
82-
LLVMInitializeNVPTXTargetMC();
83-
LLVMInitializeNVPTXAsmPrinter();
84-
85-
LLVMInitializeAMDGPUTarget();
86-
LLVMInitializeAMDGPUTargetInfo();
87-
LLVMInitializeAMDGPUTargetMC();
88-
LLVMInitializeAMDGPUAsmParser();
89-
LLVMInitializeAMDGPUAsmPrinter();
90-
});
91-
}
9233

9334
std::string translateLLVMIRToASM(llvm::Module &module,
9435
const std::string &triple,
@@ -97,7 +38,6 @@ std::string translateLLVMIRToASM(llvm::Module &module,
9738
const std::vector<std::string> &flags,
9839
bool enable_fp_fusion, bool isObject) {
9940
using namespace mlir;
100-
initLLVM();
10141
// options
10242
auto options = llvm::cl::getRegisteredOptions();
10343
for (std::string flag : flags) {
@@ -152,26 +92,6 @@ std::string translateLLVMIRToASM(llvm::Module &module,
15292

15393
using ret = py::return_value_policy;
15494

155-
void findKernels(llvm::Module &M, std::set<llvm::Function *> &functions) {
156-
llvm::NamedMDNode *annotations = M.getNamedMetadata("nvvm.annotations");
157-
assert(annotations);
158-
for (auto *Node : annotations->operands()) {
159-
if (Node->getNumOperands() < 3)
160-
continue;
161-
llvm::Metadata *Op = Node->getOperand(0).get();
162-
auto *ValueAsMetadata = llvm::dyn_cast<llvm::ValueAsMetadata>(Op);
163-
if (!ValueAsMetadata)
164-
continue;
165-
auto *F = llvm::dyn_cast<llvm::Function>(ValueAsMetadata->getValue());
166-
if (!F)
167-
continue;
168-
llvm::Metadata *Property = Node->getOperand(1).get();
169-
if (auto *MDString = llvm::dyn_cast<llvm::MDString>(Property))
170-
if (MDString->getString() == "kernel")
171-
functions.insert(F);
172-
}
173-
}
174-
17595
void init_triton_llvm(py::module &&m) {
17696

17797
py::class_<llvm::LLVMContext>(m, "context", py::module_local())
@@ -198,18 +118,9 @@ void init_triton_llvm(py::module &&m) {
198118
m.attr("OPTIMIZE_Os") = (llvm::OptimizationLevel::Os);
199119
m.attr("OPTIMIZE_Oz") = (llvm::OptimizationLevel::Oz);
200120

201-
m.def("to_module",
202-
[](mlir::ModuleOp &mod, llvm::LLVMContext &ctx, std::string name) {
203-
// TODO: dialects can be registered earlier...
204-
// This shouldn't depend on ROCDL or NVVM
205-
mlir::DialectRegistry registry;
206-
mlir::registerBuiltinDialectTranslation(registry);
207-
mlir::registerLLVMDialectTranslation(registry);
208-
mlir::registerROCDLDialectTranslation(registry);
209-
mlir::registerNVVMDialectTranslation(registry);
210-
mod->getContext()->appendDialectRegistry(registry);
211-
return mlir::translateModuleToLLVMIR(mod, ctx);
212-
});
121+
m.def("to_module", [](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) {
122+
return mlir::translateModuleToLLVMIR(mod, ctx);
123+
});
213124

214125
m.def("optimize_module", [](llvm::Module *mod,
215126
const llvm::OptimizationLevel &opt) {
@@ -255,8 +166,7 @@ void init_triton_llvm(py::module &&m) {
255166
"translate_to_asm",
256167
[](std::string llvmIR, std::string triple, std::string proc,
257168
std::string features, std::vector<std::string> flags,
258-
bool enable_fp_fusion,
259-
bool isObject) -> std::tuple<py::object, std::string> {
169+
bool enable_fp_fusion, bool isObject) -> py::object {
260170
py::gil_scoped_release allow_threads;
261171
// create LLVM module from C++
262172
llvm::LLVMContext context;
@@ -270,35 +180,15 @@ void init_triton_llvm(py::module &&m) {
270180
"failed to parse IR: " + error.getMessage() +
271181
"lineno: " + std::to_string(error.getLineNo()));
272182
}
273-
// Get name of kernel in the module
274-
std::set<llvm::Function *> kernels;
275-
findKernels(*module, kernels);
276-
assert(kernels.size() == 1);
277-
std::string name = (*kernels.begin())->getName().str();
278183
std::string obj = translateLLVMIRToASM(
279184
*module, triple, proc, features, flags, enable_fp_fusion, isObject);
280185
if (isObject)
281-
return std::make_tuple(py::bytes(obj), name);
186+
return py::bytes(obj);
282187
else
283-
return std::make_tuple(py::str(obj), name);
188+
return py::str(obj);
284189
},
285190
ret::take_ownership);
286191

287-
m.def("set_nvvm_reflect_ftz", [](llvm::Module *mod) {
288-
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
289-
// this will enable fast math path in libdevice
290-
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
291-
// sqrt.approx.ftz.f32
292-
using namespace llvm;
293-
auto &ctx = mod->getContext();
294-
Type *i32 = Type::getInt32Ty(ctx);
295-
Metadata *mdFour = ConstantAsMetadata::get(ConstantInt::getSigned(i32, 4));
296-
Metadata *mdName = MDString::get(ctx, "nvvm-reflect-ftz");
297-
Metadata *mdOne = ConstantAsMetadata::get(ConstantInt::getSigned(i32, 1));
298-
MDNode *reflect = MDNode::get(ctx, {mdFour, mdName, mdOne});
299-
mod->addModuleFlag(reflect);
300-
});
301-
302192
m.def("link_extern_lib", [](llvm::Module *mod, std::string path) {
303193
llvm::SMDiagnostic err;
304194
auto &ctx = mod->getContext();

python/src/nvidia.cc

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
#include "mlir/Pass/Pass.h"
22
#include "mlir/Pass/PassManager.h"
3+
#include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
34
#include "passes.h"
45
#include "triton/Conversion/NVGPUToLLVM/Passes.h"
56
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
7+
#include "triton/Dialect/NVGPU/IR/Dialect.h"
68
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
79
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
10+
#include "llvm/IR/Constants.h"
11+
#include "llvm/Support/TargetSelect.h"
812
#include <pybind11/pybind11.h>
913
#include <pybind11/stl.h>
1014
#include <pybind11/stl_bind.h>
@@ -18,7 +22,7 @@ void init_triton_nvidia_passes_ttgpuir(py::module &&m) {
1822
ADD_PASS_WRAPPER_1("add_rewrite_tensor_pointer",
1923
mlir::createTritonGPURewriteTensorPointerPass, int);
2024
// TODO: it is weird to pass mlir::triton::NVVM here since the conversion is
21-
// nvidia-specific
25+
// nvidia-specific
2226
m.def("add_to_llvmir", [](mlir::PassManager &pm, int32_t capability,
2327
mlir::triton::gpu::TMAMetadataTy *tmaMetadata) {
2428
pm.addPass(createConvertTritonGPUToLLVMPass(capability, mlir::triton::NVVM,
@@ -98,4 +102,41 @@ void init_triton_nvidia(py::module &&m) {
98102
.def_readwrite("TMADescArgIdx",
99103
&mlir::triton::gpu::TMAInfo::TMADescArgIdx);
100104
py::bind_vector<std::vector<mlir::triton::gpu::TMAInfo>>(m, "TMAInfos");
105+
106+
// load dialects
107+
m.def("load_dialects", [](mlir::MLIRContext &context) {
108+
mlir::DialectRegistry registry;
109+
registry.insert<mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
110+
mlir::triton::nvgpu::NVGPUDialect>();
111+
mlir::registerNVVMDialectTranslation(registry);
112+
context.appendDialectRegistry(registry);
113+
context.loadAllAvailableDialects();
114+
});
115+
116+
// init llvm
117+
m.def("init_llvm", []() {
118+
static std::once_flag init_flag;
119+
std::call_once(init_flag, []() {
120+
LLVMInitializeNVPTXTargetInfo();
121+
LLVMInitializeNVPTXTarget();
122+
LLVMInitializeNVPTXTargetMC();
123+
LLVMInitializeNVPTXAsmPrinter();
124+
});
125+
});
126+
127+
// TODO: could be done in python if we had a generic interface to set metadata
128+
m.def("set_nvvm_reflect_ftz", [](llvm::Module *mod) {
129+
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
130+
// this will enable fast math path in libdevice
131+
// for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
132+
// sqrt.approx.ftz.f32
133+
using namespace llvm;
134+
auto &ctx = mod->getContext();
135+
Type *i32 = Type::getInt32Ty(ctx);
136+
auto *mdFour = ConstantAsMetadata::get(ConstantInt::getSigned(i32, 4));
137+
auto *mdName = MDString::get(ctx, "nvvm-reflect-ftz");
138+
auto *mdOne = ConstantAsMetadata::get(ConstantInt::getSigned(i32, 1));
139+
auto *reflect = MDNode::get(ctx, {mdFour, mdName, mdOne});
140+
mod->addModuleFlag(reflect);
141+
});
101142
}

python/triton/compiler/backends/cuda.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ def parse_options(self, opts) -> Any:
7474
args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
7575
return CUDAOptions(**args)
7676

77+
@staticmethod
78+
def load_dialects(ctx):
79+
nvidia.load_dialects(ctx)
80+
7781
@staticmethod
7882
def make_ttir(mod, metadata, opt):
7983
pm = ir.pass_manager(mod.context)
@@ -179,9 +183,10 @@ def make_llir(src, metadata, options, capability):
179183
passes.llvmir.add_di_scope(pm)
180184
pm.run(mod)
181185
# LLVM-IR (MLIR) -> LLVM-IR (LLVM)
186+
nvidia.init_llvm()
182187
context = llvm.context()
183-
llvm_mod = llvm.to_module(mod, context, "LLVMModule")
184-
llvm.set_nvvm_reflect_ftz(llvm_mod)
188+
llvm_mod = llvm.to_module(mod, context)
189+
nvidia.set_nvvm_reflect_ftz(llvm_mod)
185190
if options.extern_libs:
186191
for name, path in options.extern_libs:
187192
llvm.link_extern_lib(llvm_mod, path)
@@ -201,9 +206,12 @@ def make_llir(src, metadata, options, capability):
201206
@staticmethod
202207
def make_ptx(src, metadata, opt, capability):
203208
proc = 'sm_90a' if capability == 90 else f'sm_{capability}'
204-
ret, name = llvm.translate_to_asm(src, 'nvptx64-nvidia-cuda', proc, '', ['nvptx-short-ptr'],
205-
opt.enable_fp_fusion, False)
206-
metadata["name"] = name
209+
ret = llvm.translate_to_asm(src, 'nvptx64-nvidia-cuda', proc, '', ['nvptx-short-ptr'], opt.enable_fp_fusion,
210+
False)
211+
# Find kernel names (there should only be one)
212+
names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
213+
assert len(names) == 1
214+
metadata["name"] = names[0]
207215
# post-process
208216
ptx_version = opt.ptx_version
209217
if ptx_version is None:

python/triton/compiler/code_generator.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,10 +1189,8 @@ def kernel_suffix(signature, specialization):
11891189
return suffix
11901190

11911191

1192-
def ast_to_ttir(fn, specialization, options):
1192+
def ast_to_ttir(fn, specialization, context, options):
11931193
attrs = specialization.attrs
1194-
context = ir.context()
1195-
context.load_triton()
11961194
# create kernel prototype
11971195
cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
11981196
constants = {cst_key(key): value for key, value in specialization.constants.items()}

0 commit comments

Comments
 (0)