[Mlir-commits] [mlir] [mlir] [bufferize] fix bufferize deallocation error in nested symbol table (PR #98476)

donald chen llvmlistbot at llvm.org
Thu Jul 11 06:02:47 PDT 2024


https://github.com/cxy-1993 created https://github.com/llvm/llvm-project/pull/98476

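In nested symbol tables, the dealloc_helper function generated by the lower-deallocations pass was always inserted into the top-level module, so calls to it from nested symbol tables (e.g. nested modules) failed to resolve. This patch fixes the issue by generating the helper once in each symbol table that needs it.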

>From efc079b050d551fef950f42d4ecc21bf9479eff8 Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Thu, 11 Jul 2024 12:56:41 +0000
Subject: [PATCH] [mlir] [bufferize] fix bufferize deallocation error in nested
 symbol table

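In nested symbol tables, the dealloc_helper function generated by the
lower-deallocations pass was always inserted into the top-level module, so
calls to it from nested symbol tables failed to resolve. This patch fixes the
issue by generating the helper once in each symbol table that needs it.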
---
 .../Dialect/Bufferization/Transforms/Passes.h |  3 +-
 .../BufferizationToMemRef.cpp                 | 21 ++++++----
 .../Transforms/LowerDeallocations.cpp         | 41 +++++++++++--------
 .../Transforms/lower-deallocations.mlir       | 41 +++++++++++++++++++
 4 files changed, 80 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index e053e6c97e143..298b2165f0e82 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -46,7 +46,8 @@ std::unique_ptr<Pass> createLowerDeallocationsPass();
 /// Adds the conversion pattern of the `bufferization.dealloc` operation to the
 /// given pattern set for use in other transformation passes.
 void populateBufferizationDeallocLoweringPattern(
-    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc);
+    RewritePatternSet &patterns,
+    const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap);
 
 /// Construct the library function needed for the fully generic
 /// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass.
diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 2aae39f51b940..4de204994f519 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -132,27 +132,30 @@ struct BufferizationToMemRefPass
       return;
     }
 
-    func::FuncOp helperFuncOp;
+    llvm::DenseMap<Operation *, func::FuncOp> deallocHelperFuncMap;
     if (auto module = dyn_cast<ModuleOp>(getOperation())) {
       OpBuilder builder =
           OpBuilder::atBlockBegin(&module.getBodyRegion().front());
-      SymbolTable symbolTable(module);
 
       // Build dealloc helper function if there are deallocs.
       getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
-        if (deallocOp.getMemrefs().size() > 1) {
-          helperFuncOp = bufferization::buildDeallocationLibraryFunction(
-              builder, getOperation()->getLoc(), symbolTable);
-          return WalkResult::interrupt();
+        Operation *symtableOp =
+            deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
+        if (deallocOp.getMemrefs().size() > 1 &&
+            !deallocHelperFuncMap.contains(symtableOp)) {
+          SymbolTable symbolTable(symtableOp);
+          func::FuncOp helperFuncOp =
+              bufferization::buildDeallocationLibraryFunction(
+                  builder, getOperation()->getLoc(), symbolTable);
+          deallocHelperFuncMap[symtableOp] = helperFuncOp;
         }
-        return WalkResult::advance();
       });
     }
 
     RewritePatternSet patterns(&getContext());
     patterns.add<CloneOpConversion>(patterns.getContext());
-    bufferization::populateBufferizationDeallocLoweringPattern(patterns,
-                                                               helperFuncOp);
+    bufferization::populateBufferizationDeallocLoweringPattern(
+        patterns, deallocHelperFuncMap);
 
     ConversionTarget target(getContext());
     target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
index 7fb46918ab1e8..17987f7322144 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
@@ -300,8 +300,9 @@ class DeallocOpConversion
         MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
         retainCondsMemref);
 
+    Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
     rewriter.create<func::CallOp>(
-        op.getLoc(), deallocHelperFunc,
+        op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
         SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
                            castedCondsMemref, castedDeallocCondsMemref,
                            castedRetainCondsMemref});
@@ -338,9 +339,11 @@ class DeallocOpConversion
   }
 
 public:
-  DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
+  DeallocOpConversion(
+      MLIRContext *context,
+      const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap)
       : OpConversionPattern<bufferization::DeallocOp>(context),
-        deallocHelperFunc(deallocHelperFunc) {}
+        deallocHelperFuncMap(deallocHelperFuncMap) {}
 
   LogicalResult
   matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
@@ -360,7 +363,8 @@ class DeallocOpConversion
     if (adaptor.getMemrefs().size() == 1)
       return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
 
-    if (!deallocHelperFunc)
+    Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+    if (!deallocHelperFuncMap.contains(symtableOp))
       return op->emitError(
           "library function required for generic lowering, but cannot be "
           "automatically inserted when operating on functions");
@@ -369,7 +373,7 @@ class DeallocOpConversion
   }
 
 private:
-  func::FuncOp deallocHelperFunc;
+  const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap;
 };
 } // namespace
 
@@ -385,26 +389,29 @@ struct LowerDeallocationsPass
       return;
     }
 
-    func::FuncOp helperFuncOp;
+    llvm::DenseMap<Operation *, func::FuncOp> deallocHelperFuncMap;
     if (auto module = dyn_cast<ModuleOp>(getOperation())) {
       OpBuilder builder =
           OpBuilder::atBlockBegin(&module.getBodyRegion().front());
-      SymbolTable symbolTable(module);
 
       // Build dealloc helper function if there are deallocs.
       getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
-        if (deallocOp.getMemrefs().size() > 1) {
-          helperFuncOp = bufferization::buildDeallocationLibraryFunction(
-              builder, getOperation()->getLoc(), symbolTable);
-          return WalkResult::interrupt();
+        Operation *symtableOp =
+            deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
+        if (deallocOp.getMemrefs().size() > 1 &&
+            !deallocHelperFuncMap.contains(symtableOp)) {
+          SymbolTable symbolTable(symtableOp);
+          func::FuncOp helperFuncOp =
+              bufferization::buildDeallocationLibraryFunction(
+                  builder, getOperation()->getLoc(), symbolTable);
+          deallocHelperFuncMap[symtableOp] = helperFuncOp;
         }
-        return WalkResult::advance();
       });
     }
 
     RewritePatternSet patterns(&getContext());
-    bufferization::populateBufferizationDeallocLoweringPattern(patterns,
-                                                               helperFuncOp);
+    bufferization::populateBufferizationDeallocLoweringPattern(
+        patterns, deallocHelperFuncMap);
 
     ConversionTarget target(getContext());
     target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
@@ -535,8 +542,10 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
 }
 
 void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
-    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc) {
-  patterns.add<DeallocOpConversion>(patterns.getContext(), deallocLibraryFunc);
+    RewritePatternSet &patterns,
+    const llvm::DenseMap<Operation *, func::FuncOp> &deallocHelperFuncMap) {
+  patterns.add<DeallocOpConversion>(patterns.getContext(),
+                                    deallocHelperFuncMap);
 }
 
 std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass() {
diff --git a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
index 5fedd45555fcd..2d83a2a1ec28d 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
@@ -154,3 +154,44 @@ func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>
 // CHECK-NEXT:     memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]]
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return
+
+// -----
+
+// This test check dealloc_helper function is generated on each nested symbol
+// table operation when needed and only generate once.
+module @conversion_nest_module_dealloc_helper {
+  func.func @top_level_func(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+    %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+    func.return %0#0, %0#1 : i1, i1
+  }
+  module @nested_module_not_need_dealloc_helper {
+    func.func @nested_module_not_need_dealloc_helper_func(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) {
+      %0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>)
+      return %0#0, %0#1 : i1, i1
+    }
+  }
+  module @nested_module_need_dealloc_helper {
+    func.func @nested_module_need_dealloc_helper_func0(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+      %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+      func.return %0#0, %0#1 : i1, i1
+    }
+    func.func @nested_module_need_dealloc_helper_func1(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+      %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+      func.return %0#0, %0#1 : i1, i1
+    }
+  }
+}
+
+// CHECK:     module @conversion_nest_module_dealloc_helper {
+// CHECK:       func.func @top_level_func
+// CHECK:         call @dealloc_helper
+// CHECK:       module @nested_module_not_need_dealloc_helper {
+// CHECK:         func.func @nested_module_not_need_dealloc_helper_func
+// CHECK-NOT:       @dealloc_helper
+// CHECK:       module @nested_module_need_dealloc_helper {
+// CHECK:         func.func @nested_module_need_dealloc_helper_func0
+// CHECK:           call @dealloc_helper
+// CHECK:         func.func @nested_module_need_dealloc_helper_func1
+// CHECK:           call @dealloc_helper
+// CHECK:         func.func private @dealloc_helper
+// CHECK:       func.func private @dealloc_helper



More information about the Mlir-commits mailing list