[Mlir-commits] [mlir] [mlir] [bufferize] fix bufferize deallocation error in nested symbol table (PR #98476)
donald chen
llvmlistbot at llvm.org
Thu Jul 11 06:02:47 PDT 2024
https://github.com/cxy-1993 created https://github.com/llvm/llvm-project/pull/98476
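In nested symbol tables, the dealloc_helper function generated by the lower-deallocations pass was only created in the top-level module's symbol table, so calls to it from inside nested symbol tables (e.g. nested modules) failed to resolve. This patch fixes the issue by creating one helper per symbol table that needs it.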
>From efc079b050d551fef950f42d4ecc21bf9479eff8 Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Thu, 11 Jul 2024 12:56:41 +0000
Subject: [PATCH] [mlir] [bufferize] fix bufferize deallocation error in nested
symbol table
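In nested symbol tables, the dealloc_helper function generated by the
lower-deallocations pass was only created in the top-level module's symbol
table, so calls to it from inside nested symbol tables failed to resolve. This
patch fixes the issue by creating one helper per symbol table that needs it.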
---
.../Dialect/Bufferization/Transforms/Passes.h | 3 +-
.../BufferizationToMemRef.cpp | 21 ++++++----
.../Transforms/LowerDeallocations.cpp | 41 +++++++++++--------
.../Transforms/lower-deallocations.mlir | 41 +++++++++++++++++++
4 files changed, 80 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index e053e6c97e143..298b2165f0e82 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -46,7 +46,8 @@ std::unique_ptr<Pass> createLowerDeallocationsPass();
/// Adds the conversion pattern of the `bufferization.dealloc` operation to the
/// given pattern set for use in other transformation passes.
void populateBufferizationDeallocLoweringPattern(
- RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc);
+ RewritePatternSet &patterns,
+ const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap);
/// Construct the library function needed for the fully generic
/// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass.
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 2aae39f51b940..4de204994f519 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -132,27 +132,30 @@ struct BufferizationToMemRefPass
return;
}
- func::FuncOp helperFuncOp;
+ llvm::DenseMap<Operation *, func::FuncOp> deallocHelperFuncMap;
if (auto module = dyn_cast<ModuleOp>(getOperation())) {
OpBuilder builder =
OpBuilder::atBlockBegin(&module.getBodyRegion().front());
- SymbolTable symbolTable(module);
// Build dealloc helper function if there are deallocs.
getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
- if (deallocOp.getMemrefs().size() > 1) {
- helperFuncOp = bufferization::buildDeallocationLibraryFunction(
- builder, getOperation()->getLoc(), symbolTable);
- return WalkResult::interrupt();
+ Operation *symtableOp =
+ deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
+ if (deallocOp.getMemrefs().size() > 1 &&
+ !deallocHelperFuncMap.contains(symtableOp)) {
+ SymbolTable symbolTable(symtableOp);
+ func::FuncOp helperFuncOp =
+ bufferization::buildDeallocationLibraryFunction(
+ builder, getOperation()->getLoc(), symbolTable);
+ deallocHelperFuncMap[symtableOp] = helperFuncOp;
}
- return WalkResult::advance();
});
}
RewritePatternSet patterns(&getContext());
patterns.add<CloneOpConversion>(patterns.getContext());
- bufferization::populateBufferizationDeallocLoweringPattern(patterns,
- helperFuncOp);
+ bufferization::populateBufferizationDeallocLoweringPattern(
+ patterns, deallocHelperFuncMap);
ConversionTarget target(getContext());
target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
index 7fb46918ab1e8..17987f7322144 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
@@ -300,8 +300,9 @@ class DeallocOpConversion
MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
retainCondsMemref);
+ Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
rewriter.create<func::CallOp>(
- op.getLoc(), deallocHelperFunc,
+ op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
castedCondsMemref, castedDeallocCondsMemref,
castedRetainCondsMemref});
@@ -338,9 +339,11 @@ class DeallocOpConversion
}
public:
- DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
+ DeallocOpConversion(
+ MLIRContext *context,
+ const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap)
: OpConversionPattern<bufferization::DeallocOp>(context),
- deallocHelperFunc(deallocHelperFunc) {}
+ deallocHelperFuncMap(deallocHelperFuncMap) {}
LogicalResult
matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
@@ -360,7 +363,8 @@ class DeallocOpConversion
if (adaptor.getMemrefs().size() == 1)
return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
- if (!deallocHelperFunc)
+ Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+ if (!deallocHelperFuncMap.contains(symtableOp))
return op->emitError(
"library function required for generic lowering, but cannot be "
"automatically inserted when operating on functions");
@@ -369,7 +373,7 @@ class DeallocOpConversion
}
private:
- func::FuncOp deallocHelperFunc;
+ const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap;
};
} // namespace
@@ -385,26 +389,29 @@ struct LowerDeallocationsPass
return;
}
- func::FuncOp helperFuncOp;
+ llvm::DenseMap<Operation *, func::FuncOp> deallocHelperFuncMap;
if (auto module = dyn_cast<ModuleOp>(getOperation())) {
OpBuilder builder =
OpBuilder::atBlockBegin(&module.getBodyRegion().front());
- SymbolTable symbolTable(module);
// Build dealloc helper function if there are deallocs.
getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
- if (deallocOp.getMemrefs().size() > 1) {
- helperFuncOp = bufferization::buildDeallocationLibraryFunction(
- builder, getOperation()->getLoc(), symbolTable);
- return WalkResult::interrupt();
+ Operation *symtableOp =
+ deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
+ if (deallocOp.getMemrefs().size() > 1 &&
+ !deallocHelperFuncMap.contains(symtableOp)) {
+ SymbolTable symbolTable(symtableOp);
+ func::FuncOp helperFuncOp =
+ bufferization::buildDeallocationLibraryFunction(
+ builder, getOperation()->getLoc(), symbolTable);
+ deallocHelperFuncMap[symtableOp] = helperFuncOp;
}
- return WalkResult::advance();
});
}
RewritePatternSet patterns(&getContext());
- bufferization::populateBufferizationDeallocLoweringPattern(patterns,
- helperFuncOp);
+ bufferization::populateBufferizationDeallocLoweringPattern(
+ patterns, deallocHelperFuncMap);
ConversionTarget target(getContext());
target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
@@ -535,8 +542,10 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
}
void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
- RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc) {
- patterns.add<DeallocOpConversion>(patterns.getContext(), deallocLibraryFunc);
+ RewritePatternSet &patterns,
+ const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap) {
+ patterns.add<DeallocOpConversion>(patterns.getContext(),
+ deallocHelperFuncMap);
}
std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass() {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
index 5fedd45555fcd..2d83a2a1ec28d 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
@@ -154,3 +154,44 @@ func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>
// CHECK-NEXT: memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]]
// CHECK-NEXT: }
// CHECK-NEXT: return
+
+// -----
+
+// This test checks that a dealloc_helper function is generated in each nested
+// symbol table operation that needs one, and that it is generated only once per
+// symbol table.
+module @conversion_nest_module_dealloc_helper {
+ func.func @top_level_func(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+ func.return %0#0, %0#1 : i1, i1
+ }
+ module @nested_module_not_need_dealloc_helper {
+ func.func @nested_module_not_need_dealloc_helper_func(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>)
+ return %0#0, %0#1 : i1, i1
+ }
+ }
+ module @nested_module_need_dealloc_helper {
+ func.func @nested_module_need_dealloc_helper_func0(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+ func.return %0#0, %0#1 : i1, i1
+ }
+ func.func @nested_module_need_dealloc_helper_func1(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+ %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+ func.return %0#0, %0#1 : i1, i1
+ }
+ }
+}
+
+// CHECK: module @conversion_nest_module_dealloc_helper {
+// CHECK: func.func @top_level_func
+// CHECK: call @dealloc_helper
+// CHECK: module @nested_module_not_need_dealloc_helper {
+// CHECK: func.func @nested_module_not_need_dealloc_helper_func
+// CHECK-NOT: @dealloc_helper
+// CHECK: module @nested_module_need_dealloc_helper {
+// CHECK: func.func @nested_module_need_dealloc_helper_func0
+// CHECK: call @dealloc_helper
+// CHECK: func.func @nested_module_need_dealloc_helper_func1
+// CHECK: call @dealloc_helper
+// CHECK: func.func private @dealloc_helper
+// CHECK: func.func private @dealloc_helper
More information about the Mlir-commits
mailing list