[Mlir-commits] [mlir] 662c6fc - [mlir] [bufferize] fix bufferize deallocation error in nest symbol table (#98476)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jul 14 21:52:49 PDT 2024


Author: donald chen
Date: 2024-07-15T12:52:46+08:00
New Revision: 662c6fc74c1ae7fdefd27524dfaeb7f3e9a1f553

URL: https://github.com/llvm/llvm-project/commit/662c6fc74c1ae7fdefd27524dfaeb7f3e9a1f553
DIFF: https://github.com/llvm/llvm-project/commit/662c6fc74c1ae7fdefd27524dfaeb7f3e9a1f553.diff

LOG: [mlir] [bufferize] fix bufferize deallocation error in nest symbol table (#98476)

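In nested symbol tables, the dealloc_helper function generated by the
lower-deallocations pass was incorrectly positioned, causing calls to it
to fail. This patch fixes the issue by building a separate helper in each
symbol table that contains a multi-memref dealloc.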

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
    mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
    mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
    mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index e053e6c97e143..c12ed7f5d0180 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -18,6 +18,9 @@ class FuncOp;
 namespace bufferization {
 struct OneShotBufferizationOptions;
 
+/// Maps from symbol table to its corresponding dealloc helper function.
+using DeallocHelperMap = llvm::DenseMap<Operation *, func::FuncOp>;
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
@@ -46,7 +49,7 @@ std::unique_ptr<Pass> createLowerDeallocationsPass();
 /// Adds the conversion pattern of the `bufferization.dealloc` operation to the
 /// given pattern set for use in other transformation passes.
 void populateBufferizationDeallocLoweringPattern(
-    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc);
+    RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap);
 
 /// Construct the library function needed for the fully generic
 /// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass.

diff  --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
index 2aae39f51b940..f9903071be084 100644
--- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
+++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp
@@ -132,27 +132,30 @@ struct BufferizationToMemRefPass
       return;
     }
 
-    func::FuncOp helperFuncOp;
+    bufferization::DeallocHelperMap deallocHelperFuncMap;
     if (auto module = dyn_cast<ModuleOp>(getOperation())) {
       OpBuilder builder =
           OpBuilder::atBlockBegin(&module.getBodyRegion().front());
-      SymbolTable symbolTable(module);
 
       // Build dealloc helper function if there are deallocs.
       getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
-        if (deallocOp.getMemrefs().size() > 1) {
-          helperFuncOp = bufferization::buildDeallocationLibraryFunction(
-              builder, getOperation()->getLoc(), symbolTable);
-          return WalkResult::interrupt();
+        Operation *symtableOp =
+            deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
+        if (deallocOp.getMemrefs().size() > 1 &&
+            !deallocHelperFuncMap.contains(symtableOp)) {
+          SymbolTable symbolTable(symtableOp);
+          func::FuncOp helperFuncOp =
+              bufferization::buildDeallocationLibraryFunction(
+                  builder, getOperation()->getLoc(), symbolTable);
+          deallocHelperFuncMap[symtableOp] = helperFuncOp;
         }
-        return WalkResult::advance();
       });
     }
 
     RewritePatternSet patterns(&getContext());
     patterns.add<CloneOpConversion>(patterns.getContext());
-    bufferization::populateBufferizationDeallocLoweringPattern(patterns,
-                                                               helperFuncOp);
+    bufferization::populateBufferizationDeallocLoweringPattern(
+        patterns, deallocHelperFuncMap);
 
     ConversionTarget target(getContext());
     target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
index 7fb46918ab1e8..9e2c91bad7bfd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp
@@ -300,8 +300,9 @@ class DeallocOpConversion
         MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
         retainCondsMemref);
 
+    Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
     rewriter.create<func::CallOp>(
-        op.getLoc(), deallocHelperFunc,
+        op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
         SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
                            castedCondsMemref, castedDeallocCondsMemref,
                            castedRetainCondsMemref});
@@ -338,9 +339,11 @@ class DeallocOpConversion
   }
 
 public:
-  DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
+  DeallocOpConversion(
+      MLIRContext *context,
+      const bufferization::DeallocHelperMap &deallocHelperFuncMap)
       : OpConversionPattern<bufferization::DeallocOp>(context),
-        deallocHelperFunc(deallocHelperFunc) {}
+        deallocHelperFuncMap(deallocHelperFuncMap) {}
 
   LogicalResult
   matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
@@ -360,7 +363,8 @@ class DeallocOpConversion
     if (adaptor.getMemrefs().size() == 1)
       return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
 
-    if (!deallocHelperFunc)
+    Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
+    if (!deallocHelperFuncMap.contains(symtableOp))
       return op->emitError(
           "library function required for generic lowering, but cannot be "
           "automatically inserted when operating on functions");
@@ -369,7 +373,7 @@ class DeallocOpConversion
   }
 
 private:
-  func::FuncOp deallocHelperFunc;
+  const bufferization::DeallocHelperMap &deallocHelperFuncMap;
 };
 } // namespace
 
@@ -385,26 +389,29 @@ struct LowerDeallocationsPass
       return;
     }
 
-    func::FuncOp helperFuncOp;
+    bufferization::DeallocHelperMap deallocHelperFuncMap;
     if (auto module = dyn_cast<ModuleOp>(getOperation())) {
       OpBuilder builder =
           OpBuilder::atBlockBegin(&module.getBodyRegion().front());
-      SymbolTable symbolTable(module);
 
       // Build dealloc helper function if there are deallocs.
       getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
-        if (deallocOp.getMemrefs().size() > 1) {
-          helperFuncOp = bufferization::buildDeallocationLibraryFunction(
-              builder, getOperation()->getLoc(), symbolTable);
-          return WalkResult::interrupt();
+        Operation *symtableOp =
+            deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
+        if (deallocOp.getMemrefs().size() > 1 &&
+            !deallocHelperFuncMap.contains(symtableOp)) {
+          SymbolTable symbolTable(symtableOp);
+          func::FuncOp helperFuncOp =
+              bufferization::buildDeallocationLibraryFunction(
+                  builder, getOperation()->getLoc(), symbolTable);
+          deallocHelperFuncMap[symtableOp] = helperFuncOp;
         }
-        return WalkResult::advance();
       });
     }
 
     RewritePatternSet patterns(&getContext());
-    bufferization::populateBufferizationDeallocLoweringPattern(patterns,
-                                                               helperFuncOp);
+    bufferization::populateBufferizationDeallocLoweringPattern(
+        patterns, deallocHelperFuncMap);
 
     ConversionTarget target(getContext());
     target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
@@ -535,8 +542,10 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
 }
 
 void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
-    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc) {
-  patterns.add<DeallocOpConversion>(patterns.getContext(), deallocLibraryFunc);
+    RewritePatternSet &patterns,
+    const bufferization::DeallocHelperMap &deallocHelperFuncMap) {
+  patterns.add<DeallocOpConversion>(patterns.getContext(),
+                                    deallocHelperFuncMap);
 }
 
 std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass() {

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
index 5fedd45555fcd..edffcbdd0ba7d 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir
@@ -154,3 +154,44 @@ func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>
 // CHECK-NEXT:     memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]]
 // CHECK-NEXT:   }
 // CHECK-NEXT:   return
+
+// -----
+
+// This test checks that the dealloc_helper function is generated in each
+// nested symbol table operation that needs it, and only once per symbol table.
+module @conversion_nest_module_dealloc_helper {
+  func.func @top_level_func(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+    %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+    func.return %0#0, %0#1 : i1, i1
+  }
+  module @nested_module_not_need_dealloc_helper {
+    func.func @nested_module_not_need_dealloc_helper_func(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) {
+      %0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>)
+      return %0#0, %0#1 : i1, i1
+    }
+  }
+  module @nested_module_need_dealloc_helper {
+    func.func @nested_module_need_dealloc_helper_func0(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+      %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+      func.return %0#0, %0#1 : i1, i1
+    }
+    func.func @nested_module_need_dealloc_helper_func1(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
+      %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
+      func.return %0#0, %0#1 : i1, i1
+    }
+  }
+}
+
+// CHECK:     module @conversion_nest_module_dealloc_helper {
+// CHECK:       func.func @top_level_func
+// CHECK:         call @dealloc_helper
+// CHECK:       module @nested_module_not_need_dealloc_helper {
+// CHECK:         func.func @nested_module_not_need_dealloc_helper_func
+// CHECK-NOT:       @dealloc_helper
+// CHECK:       module @nested_module_need_dealloc_helper {
+// CHECK:         func.func @nested_module_need_dealloc_helper_func0
+// CHECK:           call @dealloc_helper
+// CHECK:         func.func @nested_module_need_dealloc_helper_func1
+// CHECK:           call @dealloc_helper
+// CHECK:         func.func private @dealloc_helper
+// CHECK:       func.func private @dealloc_helper


        


More information about the Mlir-commits mailing list