[Mlir-commits] [mlir] a43641c - [mlir][bufferization] Fix `regionOperatesOnMemrefValues` (#75016)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 11 15:56:27 PST 2023


Author: Matthias Springer
Date: 2023-12-12T08:56:23+09:00
New Revision: a43641c9dbd7e61d10f130858b55cf011260cebf

URL: https://github.com/llvm/llvm-project/commit/a43641c9dbd7e61d10f130858b55cf011260cebf
DIFF: https://github.com/llvm/llvm-project/commit/a43641c9dbd7e61d10f130858b55cf011260cebf.diff

LOG: [mlir][bufferization] Fix `regionOperatesOnMemrefValues` (#75016)

`Region::walk([](Block *b) {...})` does not enumerate blocks that are
direct children of the region; it only visits blocks inside operations
nested in the region. These direct-child blocks must be checked manually.

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
    mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 9459cc43547faf..38ffae68a43de2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -463,7 +463,7 @@ BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
 }
 
 static bool regionOperatesOnMemrefValues(Region &region) {
-  WalkResult result = region.walk([](Block *block) {
+  auto checkBlock = [](Block *block) {
     if (llvm::any_of(block->getArguments(), isMemref))
       return WalkResult::interrupt();
     for (Operation &op : *block) {
@@ -473,8 +473,18 @@ static bool regionOperatesOnMemrefValues(Region &region) {
         return WalkResult::interrupt();
     }
     return WalkResult::advance();
-  });
-  return result.wasInterrupted();
+  };
+  WalkResult result = region.walk(checkBlock);
+  if (result.wasInterrupted())
+    return true;
+
+  // Note: Block::walk/Region::walk visits only blocks that are nested under
+  // nested operations, but not direct children.
+  for (Block &block : region)
+    if (checkBlock(&block).wasInterrupted())
+      return true;
+
+  return false;
 }
 
 LogicalResult

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
index ad7c4c783e907f..1a8a930bc9002b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
@@ -531,8 +531,8 @@ func.func @noRegionBranchOpInterface() {
 // This is not allowed in buffer deallocation.
 
 func.func @noRegionBranchOpInterface() {
-  // expected-error at +1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
   %0 = "test.bar"() ({
+    // expected-error at +1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
     %1 = "test.bar"() ({
       %2 = "test.get_memref"() : () -> memref<2xi32>
       "test.yield"(%2) : (memref<2xi32>) -> ()
@@ -544,6 +544,21 @@ func.func @noRegionBranchOpInterface() {
 
 // -----
 
+// Test Case: The op "test.bar" does not implement the RegionBranchOpInterface.
+// This is not allowed in buffer deallocation.
+
+func.func @noRegionBranchOpInterface() {
+  // expected-error at +1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
+  %0 = "test.bar"() ({
+    %2 = "test.get_memref"() : () -> memref<2xi32>
+    %3 = "test.foo"(%2) : (memref<2xi32>) -> (i32)
+    "test.yield"(%3) : (i32) -> ()
+  }) : () -> (i32)
+  "test.terminator"() : () -> ()
+}
+
+// -----
+
 func.func @while_two_arg(%arg0: index) {
   %a = memref.alloc(%arg0) : memref<?xf32>
   scf.while (%arg1 = %a, %arg2 = %a) : (memref<?xf32>, memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {


        


More information about the Mlir-commits mailing list