11#include " mlir/IR/BuiltinOps.h" // mlir::ModuleOp
2- #include " mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
3- #include " mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
4- #include " mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
5- #include " mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
62#include " mlir/Target/LLVMIR/LLVMTranslationInterface.h"
73#include " mlir/Target/LLVMIR/ModuleTranslation.h"
84#include " llvm/ADT/SmallVector.h"
95#include " llvm/IR/LLVMContext.h"
10- #include " llvm/IR/Module.h"
11- #include " llvm/IR/PassManager.h"
12- #include " llvm/IRReader/IRReader.h"
13- #include " llvm/Pass.h"
14- #include " llvm/Passes/OptimizationLevel.h"
15- #include " llvm/Passes/PassBuilder.h"
16- #include " llvm/Support/CodeGen.h"
17- #include " llvm/Target/TargetMachine.h"
18- #include " llvm/Transforms/InstCombine/InstCombine.h"
19- #include < pybind11/pybind11.h>
20- #include < pybind11/stl.h>
21-
22- #ifdef _WIN32
23- #define WIN32_LEAN_AND_MEAN
24- #include < windows.h>
25- #else
26- #include < dlfcn.h>
27- #endif
28- #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
29- #include " mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
30- #include " mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
31- #include " mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
32- #include " mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
33- #include " mlir/Target/LLVMIR/Export.h"
34- #include " mlir/Target/LLVMIR/LLVMTranslationInterface.h"
35- #include " triton/Tools/Sys/GetEnv.hpp"
36- #include " llvm/Linker/Linker.h"
37- #include < filesystem>
38- #include < iterator>
39-
40- #include " llvm/ADT/APInt.h"
41- #include " llvm/ADT/STLExtras.h"
42- #include " llvm/ADT/SmallVector.h"
43- #include " llvm/Analysis/TargetTransformInfo.h"
44- #include " llvm/IR/CallingConv.h"
45- #include " llvm/IR/Constants.h"
46- #include " llvm/IR/IRBuilder.h"
476#include " llvm/IR/LegacyPassManager.h"
487#include " llvm/IR/Module.h"
8+ #include " llvm/IR/PassManager.h"
499#include " llvm/IR/Verifier.h"
5010#include " llvm/IRReader/IRReader.h"
5111#include " llvm/Linker/Linker.h"
5212#include " llvm/MC/TargetRegistry.h"
5313#include " llvm/Pass.h"
5414#include " llvm/Passes/OptimizationLevel.h"
5515#include " llvm/Passes/PassBuilder.h"
56- #include " llvm/Support/CommandLine.h"
57- #include " llvm/Support/Error.h"
58- #include " llvm/Support/FormatVariadic.h"
59- #include " llvm/Support/SourceMgr.h"
60- #include " llvm/Support/TargetSelect.h"
16+ #include " llvm/Support/CodeGen.h"
6117#include " llvm/Target/TargetMachine.h"
6218#include " llvm/Transforms/IPO/AlwaysInliner.h"
6319#include " llvm/Transforms/InstCombine/InstCombine.h"
20+ #include < pybind11/pybind11.h>
21+ #include < pybind11/stl.h>
6422
6523namespace py = pybind11;
6624
@@ -72,23 +30,6 @@ struct BreakStructPhiNodesPass : PassInfoMixin<BreakStructPhiNodesPass> {
7230} // namespace llvm
7331
7432using namespace llvm ;
75- //
76- // TODO: move to python
77- static void initLLVM () {
78- static std::once_flag init_flag;
79- std::call_once (init_flag, []() {
80- LLVMInitializeNVPTXTargetInfo ();
81- LLVMInitializeNVPTXTarget ();
82- LLVMInitializeNVPTXTargetMC ();
83- LLVMInitializeNVPTXAsmPrinter ();
84-
85- LLVMInitializeAMDGPUTarget ();
86- LLVMInitializeAMDGPUTargetInfo ();
87- LLVMInitializeAMDGPUTargetMC ();
88- LLVMInitializeAMDGPUAsmParser ();
89- LLVMInitializeAMDGPUAsmPrinter ();
90- });
91- }
9233
9334std::string translateLLVMIRToASM (llvm::Module &module ,
9435 const std::string &triple,
@@ -97,7 +38,6 @@ std::string translateLLVMIRToASM(llvm::Module &module,
9738 const std::vector<std::string> &flags,
9839 bool enable_fp_fusion, bool isObject) {
9940 using namespace mlir ;
100- initLLVM ();
10141 // options
10242 auto options = llvm::cl::getRegisteredOptions ();
10343 for (std::string flag : flags) {
@@ -152,26 +92,6 @@ std::string translateLLVMIRToASM(llvm::Module &module,
15292
15393using ret = py::return_value_policy;
15494
155- void findKernels (llvm::Module &M, std::set<llvm::Function *> &functions) {
156- llvm::NamedMDNode *annotations = M.getNamedMetadata (" nvvm.annotations" );
157- assert (annotations);
158- for (auto *Node : annotations->operands ()) {
159- if (Node->getNumOperands () < 3 )
160- continue ;
161- llvm::Metadata *Op = Node->getOperand (0 ).get ();
162- auto *ValueAsMetadata = llvm::dyn_cast<llvm::ValueAsMetadata>(Op);
163- if (!ValueAsMetadata)
164- continue ;
165- auto *F = llvm::dyn_cast<llvm::Function>(ValueAsMetadata->getValue ());
166- if (!F)
167- continue ;
168- llvm::Metadata *Property = Node->getOperand (1 ).get ();
169- if (auto *MDString = llvm::dyn_cast<llvm::MDString>(Property))
170- if (MDString->getString () == " kernel" )
171- functions.insert (F);
172- }
173- }
174-
17595void init_triton_llvm (py::module &&m) {
17696
17797 py::class_<llvm::LLVMContext>(m, " context" , py::module_local ())
@@ -198,18 +118,9 @@ void init_triton_llvm(py::module &&m) {
198118 m.attr (" OPTIMIZE_Os" ) = (llvm::OptimizationLevel::Os);
199119 m.attr (" OPTIMIZE_Oz" ) = (llvm::OptimizationLevel::Oz);
200120
201- m.def (" to_module" ,
202- [](mlir::ModuleOp &mod, llvm::LLVMContext &ctx, std::string name) {
203- // TODO: dialects can be registered earlier...
204- // This shouldn't depend on ROCDL or NVVM
205- mlir::DialectRegistry registry;
206- mlir::registerBuiltinDialectTranslation (registry);
207- mlir::registerLLVMDialectTranslation (registry);
208- mlir::registerROCDLDialectTranslation (registry);
209- mlir::registerNVVMDialectTranslation (registry);
210- mod->getContext ()->appendDialectRegistry (registry);
211- return mlir::translateModuleToLLVMIR (mod, ctx);
212- });
121+ m.def (" to_module" , [](mlir::ModuleOp &mod, llvm::LLVMContext &ctx) {
122+ return mlir::translateModuleToLLVMIR (mod, ctx);
123+ });
213124
214125 m.def (" optimize_module" , [](llvm::Module *mod,
215126 const llvm::OptimizationLevel &opt) {
@@ -255,8 +166,7 @@ void init_triton_llvm(py::module &&m) {
255166 " translate_to_asm" ,
256167 [](std::string llvmIR, std::string triple, std::string proc,
257168 std::string features, std::vector<std::string> flags,
258- bool enable_fp_fusion,
259- bool isObject) -> std::tuple<py::object, std::string> {
169+ bool enable_fp_fusion, bool isObject) -> py::object {
260170 py::gil_scoped_release allow_threads;
261171 // create LLVM module from C++
262172 llvm::LLVMContext context;
@@ -270,35 +180,15 @@ void init_triton_llvm(py::module &&m) {
270180 " failed to parse IR: " + error.getMessage () +
271181 " lineno: " + std::to_string (error.getLineNo ()));
272182 }
273- // Get name of kernel in the module
274- std::set<llvm::Function *> kernels;
275- findKernels (*module , kernels);
276- assert (kernels.size () == 1 );
277- std::string name = (*kernels.begin ())->getName ().str ();
278183 std::string obj = translateLLVMIRToASM (
279184 *module , triple, proc, features, flags, enable_fp_fusion, isObject);
280185 if (isObject)
281- return std::make_tuple ( py::bytes (obj), name );
186+ return py::bytes (obj);
282187 else
283- return std::make_tuple ( py::str (obj), name );
188+ return py::str (obj);
284189 },
285190 ret::take_ownership);
286191
287- m.def (" set_nvvm_reflect_ftz" , [](llvm::Module *mod) {
288- // please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters
289- // this will enable fast math path in libdevice
290- // for example, when enable nvvm-reflect-ftz, sqrt.approx.f32 will change to
291- // sqrt.approx.ftz.f32
292- using namespace llvm ;
293- auto &ctx = mod->getContext ();
294- Type *i32 = Type::getInt32Ty (ctx);
295- Metadata *mdFour = ConstantAsMetadata::get (ConstantInt::getSigned (i32 , 4 ));
296- Metadata *mdName = MDString::get (ctx, " nvvm-reflect-ftz" );
297- Metadata *mdOne = ConstantAsMetadata::get (ConstantInt::getSigned (i32 , 1 ));
298- MDNode *reflect = MDNode::get (ctx, {mdFour, mdName, mdOne});
299- mod->addModuleFlag (reflect);
300- });
301-
302192 m.def (" link_extern_lib" , [](llvm::Module *mod, std::string path) {
303193 llvm::SMDiagnostic err;
304194 auto &ctx = mod->getContext ();
0 commit comments