Skip to content

Commit 75007d8

Browse files
Do not rename external func decls during --inline-nested-modules (#2244)
**Context:** One of the first steps in the compilation pipeline is `--inline-nested-modules`. This pass lifts the per-qnode module up into the global module of the qjit. Because (usually) the symbol table of a module is isolated from above, multiple modules in parallel may contain the same names inside. Therefore when inlining child modules into the global module, we need to rename them to unique names. However, there's a small problem. Each qnode module has their own specific transform sequence, so the quantum passes (applied during `--apply-transform-sequence`) have to be resolved before `--inline-nested-module`. If a pass generates a function declaration to an external API (e.g. runtime functions), that name must not be altered. **Description of the Change:** Do not perform renaming on external function declarations from within the qnode modules during `--inline-nested-modules`. Only inline the first occurrence of such func decls from the qnode modules into the global qjit module. **Benefits:** Quantum passes that generate calls to other APIs can work. The flipside is also true! API developers do not have to maintain multiple aliases for their API functions (e.g. gridsynth pass #2140 ) [sc-105020] --------- Co-authored-by: Joseph Lee <[email protected]>
1 parent 15be710 commit 75007d8

File tree

4 files changed

+106
-7
lines changed

4 files changed

+106
-7
lines changed

doc/releases/changelog-dev.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,13 @@
263263

264264
This fix enables automatic qubit management to be used with gradients.
265265

266+
* The `--inline-nested-module` pass no longer renames external function declarations.
267+
[(#2244)](https://github.com/PennyLaneAI/catalyst/pull/2244)
268+
269+
This pass inlines the qnode MLIR modules into the global QJIT MLIR module. If a qnode module
270+
contains function declarations to external APIs, the names of these declarations must
271+
stay unchanged. This change enables quantum passes to generate calls to external APIs.
272+
266273
<h3>Internal changes ⚙️</h3>
267274

268275
* Resource tracking now writes out at device destruction time instead of qubit deallocation

mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565

6666
#include <deque>
6767

68+
#include "llvm/ADT/SmallSet.h"
69+
6870
#include "mlir/Dialect/Func/IR/FuncOps.h"
6971
#include "mlir/Pass/Pass.h"
7072
#include "mlir/Pass/PassManager.h"
@@ -183,14 +185,17 @@ struct AnnotateWithFullyQualifiedName : public OpInterfaceRewritePattern<SymbolO
183185

184186
struct RenameFunctionsPattern : public RewritePattern {
185187
/// This overload constructs a pattern that matches any operation type.
186-
RenameFunctionsPattern(MLIRContext *context, SmallVector<Operation *> *symbolTables)
187-
: RewritePattern(MatchAnyOpTypeTag(), 1, context), _symbolTables(symbolTables)
188+
RenameFunctionsPattern(MLIRContext *context, SmallVector<Operation *> *symbolTables,
189+
llvm::SmallSet<StringRef, 8> *externalFuncDeclNames)
190+
: RewritePattern(MatchAnyOpTypeTag(), 1, context), _symbolTables(symbolTables),
191+
_externalFuncDeclNames(externalFuncDeclNames)
188192
{
189193
}
190194

191195
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override;
192196

193197
SmallVector<Operation *> *_symbolTables;
198+
llvm::SmallSet<StringRef, 8> *_externalFuncDeclNames;
194199
};
195200

196201
static constexpr llvm::StringRef hasBeenRenamedAttrName = "catalyst.unique_names";
@@ -228,8 +233,21 @@ LogicalResult RenameFunctionsPattern::matchAndRewrite(Operation *child,
228233
for (auto &region : child->getRegions()) {
229234
for (auto &block : region.getBlocks()) {
230235
for (auto &op : block) {
231-
if (!isa<SymbolOpInterface>(op))
236+
if (!isa<SymbolOpInterface>(op)) {
232237
continue;
238+
}
239+
240+
// We should not rename external function declarations, as they can be
241+
// names required by other APIs.
242+
// We record these external func decls during the rename pattern.
243+
// Then during the actual inlining stage, only the first occurrence of the
244+
// per-module func decls of these external decls should be inlined.
245+
if (auto f = dyn_cast<func::FuncOp>(op)) {
246+
if (f.isExternal()) {
247+
_externalFuncDeclNames->insert(f.getName());
248+
continue;
249+
}
250+
}
233251

234252
if (failed(childSymTab.renameToUnique(&op, raw_tables))) {
235253
// TODO: Check for error in one of the tests.
@@ -249,7 +267,23 @@ LogicalResult RenameFunctionsPattern::matchAndRewrite(Operation *child,
249267

250268
struct InlineNestedModule : public RewritePattern {
251269
/// This overload constructs a pattern that matches any operation type.
252-
InlineNestedModule(MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), 1, context) {}
270+
InlineNestedModule(MLIRContext *context,
271+
const llvm::SmallSet<StringRef, 8> &externalFuncDeclNames)
272+
: RewritePattern(MatchAnyOpTypeTag(), 1, context),
273+
_externalFuncDeclNames(externalFuncDeclNames)
274+
{
275+
}
276+
277+
const llvm::SmallSet<StringRef, 8> &_externalFuncDeclNames;
278+
279+
// Note: mlir expects pattern objects to be const.
280+
// In other words, repeated applications of a rewrite pattern should not have dependency on each
281+
// other.
282+
// This --inline-nested-module pass is breaking this assumption.
283+
//
284+
// TODO: refactor this pass to not use the pattern rewriter, but just raw logic in a
285+
// `runOnOperation()`
286+
mutable llvm::SmallSet<StringRef, 8> alreadyInlinedFuncDeclNames;
253287

254288
LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override
255289
{
@@ -261,6 +295,31 @@ struct InlineNestedModule : public RewritePattern {
261295
}
262296

263297
auto parent = op->getParentOp();
298+
assert(parent->hasTrait<OpTrait::SymbolTable>() &&
299+
"the direct parent of a qnode module must be a module op");
300+
301+
// Look for the func decls in the current qnode module
302+
// If it is a recorded external func decl, erase it if it already has been inlined.
303+
SmallVector<Operation *> _erasureWorklist;
304+
for (auto &region : op->getRegions()) {
305+
auto funcOps = region.getOps<func::FuncOp>();
306+
for (auto f : funcOps) {
307+
StringRef funcName = f.getName();
308+
if (f.isExternal() && _externalFuncDeclNames.contains(funcName)) {
309+
if (alreadyInlinedFuncDeclNames.contains(funcName)) {
310+
_erasureWorklist.push_back(f);
311+
}
312+
else {
313+
alreadyInlinedFuncDeclNames.insert(funcName);
314+
}
315+
}
316+
}
317+
}
318+
319+
for (auto _op : _erasureWorklist) {
320+
rewriter.eraseOp(_op);
321+
}
322+
264323
// Can't generalize getting a region other than the zero-th one.
265324
rewriter.inlineRegionBefore(op->getRegion(0), &parent->getRegion(0).back());
266325
Block *inlinedBlock = &parent->getRegion(0).front();
@@ -360,6 +419,7 @@ struct CleanupPattern : public RewritePattern {
360419
return failure();
361420
}
362421
rewriter.modifyOpInPlace(op, [&] { op->removeAttr(fullyQualifiedNameAttr); });
422+
363423
return success();
364424
}
365425
};
@@ -426,15 +486,16 @@ struct InlineNestedSymbolTablePass : PassWrapper<InlineNestedSymbolTablePass, Op
426486
return WalkResult::skip();
427487
});
428488

429-
renameFunctions.add<RenameFunctionsPattern>(context, &symbolTables);
489+
llvm::SmallSet<StringRef, 8> externalFuncDeclNames;
490+
renameFunctions.add<RenameFunctionsPattern>(context, &symbolTables, &externalFuncDeclNames);
430491

431492
bool run = _stopAfterStep >= 2 || _stopAfterStep == 0;
432493
if (run && failed(applyPatternsGreedily(symbolTable, std::move(renameFunctions), config))) {
433494
signalPassFailure();
434495
}
435496

436497
RewritePatternSet inlineNested(context);
437-
inlineNested.add<InlineNestedModule>(context);
498+
inlineNested.add<InlineNestedModule>(context, externalFuncDeclNames);
438499
run = _stopAfterStep >= 3 || _stopAfterStep == 0;
439500
if (run && failed(applyPatternsGreedily(symbolTable, std::move(inlineNested), config))) {
440501
signalPassFailure();

mlir/test/Catalyst/NestedModule2.mlir

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,3 @@ module @outer {
6161
}
6262
}
6363
}
64-

mlir/test/Catalyst/NestedModule5.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,35 @@ module @outer {
2929
catalyst.launch_kernel @inner::@f() : () -> ()
3030
}
3131

32+
// -----
33+
34+
// Test external API func decl names are preserved
35+
module @global {
36+
37+
// CHECK-DAG: func.func private @f()
38+
// CHECK-DAG: func.func @main_0() {
39+
// CHECK-DAG: call @f() : () -> ()
40+
// CHECK-DAG: return
41+
// CHECK-DAG: }
42+
// CHECK-DAG: func.func @main_1() {
43+
// CHECK-DAG: call @f() : () -> ()
44+
// CHECK-DAG: return
45+
// CHECK-DAG: }
46+
47+
module @local0 {
48+
func.func private @f()
49+
func.func @main() {
50+
func.call @f() : () -> ()
51+
return
52+
}
53+
}
54+
55+
module @local1 {
56+
func.func private @f()
57+
func.func @main() {
58+
func.call @f() : () -> ()
59+
return
60+
}
61+
}
62+
63+
}

0 commit comments

Comments
 (0)