[Mlir-commits] [mlir] [mlir][bufferization] Buffer deallocation: skip ops that do not operate on buffers (PR #75126)

Matthias Springer llvmlistbot at llvm.org
Tue Dec 12 01:29:29 PST 2023


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/75126

>From ef5371446026d8d2905d53e552e2f567d23c6c2f Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 12 Dec 2023 18:27:16 +0900
Subject: [PATCH] [mlir][bufferization] Buffer deallocation: skip ops that do
 not operate on buffers

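Skip ops that do not operate on buffers during ownership-based buffer deallocation. An op is skipped only if it has no buffer operands, buffer results or buffer region entry block arguments, has no Allocate/Free side effect, and is not a terminator. Terminators are never skipped because bufferization.dealloc ops are inserted in front of them.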
---
 .../OwnershipBasedBufferDeallocation.cpp      | 62 ++++++++++---------
 .../dealloc-region-branchop-interface.mlir    | 11 ++--
 mlir/test/lib/Dialect/Test/TestOps.td         | 11 ++++
 3 files changed, 50 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 38ffae68a43de..2fe51192fd1fd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -48,6 +48,32 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
 
 static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
 
+/// Return "true" if the given op is guaranteed to have neither an "Allocate"
+/// nor a "Free" side effect. Ops without known effects yield "false".
+static bool hasNoAllocateOrFreeSideEffect(Operation *op) {
+  if (isa<MemoryEffectOpInterface>(op))
+    return !hasEffect<MemoryEffects::Allocate>(op) &&
+           !hasEffect<MemoryEffects::Free>(op);
+  // If the op does not implement the MemoryEffectOpInterface but has
+  // recursive memory effects, then this op in isolation (without its body) does
+  // not have any side effects. The ops inside the regions of this op will be
+  // processed separately.
+  return op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
+}
+
+/// Return "true" if any operand, any result, or any entry block argument of
+/// a non-empty region of the given op is a memref, and "false" otherwise.
+static bool hasBufferSemantics(Operation *op) {
+  auto entryBlockTakesMemref = [](Region &r) {
+    return !r.empty() && llvm::any_of(r.front().getArguments(), isMemref);
+  };
+  if (llvm::any_of(op->getOperands(), isMemref))
+    return true;
+  if (llvm::any_of(op->getResults(), isMemref))
+    return true;
+  return llvm::any_of(op->getRegions(), entryBlockTakesMemref);
+}
+
 //===----------------------------------------------------------------------===//
 // Backedges analysis
 //===----------------------------------------------------------------------===//
@@ -462,31 +488,6 @@ BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
   return state.getMemrefWithUniqueOwnership(builder, memref, block);
 }
 
-static bool regionOperatesOnMemrefValues(Region &region) {
-  auto checkBlock = [](Block *block) {
-    if (llvm::any_of(block->getArguments(), isMemref))
-      return WalkResult::interrupt();
-    for (Operation &op : *block) {
-      if (llvm::any_of(op.getOperands(), isMemref))
-        return WalkResult::interrupt();
-      if (llvm::any_of(op.getResults(), isMemref))
-        return WalkResult::interrupt();
-    }
-    return WalkResult::advance();
-  };
-  WalkResult result = region.walk(checkBlock);
-  if (result.wasInterrupted())
-    return true;
-
-  // Note: Block::walk/Region::walk visits only blocks that are nested under
-  // nested operations, but not direct children.
-  for (Block &block : region)
-    if (checkBlock(&block).wasInterrupted())
-      return true;
-
-  return false;
-}
-
 LogicalResult
 BufferDeallocation::verifyFunctionPreconditions(FunctionOpInterface op) {
   // (1) Ensure that there are supported loops only (no explicit control flow
@@ -512,9 +513,8 @@ LogicalResult BufferDeallocation::verifyOperationPreconditions(Operation *op) {
   size_t size = regions.size();
   if (((size == 1 && !op->getResults().empty()) || size > 1) &&
       !dyn_cast<RegionBranchOpInterface>(op)) {
-    if (llvm::any_of(regions, regionOperatesOnMemrefValues))
-      return op->emitError("All operations with attached regions need to "
-                           "implement the RegionBranchOpInterface.");
+    return op->emitError("All operations with attached regions need to "
+                         "implement the RegionBranchOpInterface.");
   }
 
   // (2) The pass does not work properly when deallocations are already present.
@@ -648,6 +648,12 @@ LogicalResult BufferDeallocation::deallocate(Block *block) {
   // For each operation in the block, handle the interfaces that affect aliasing
   // and ownership of memrefs.
   for (Operation &op : llvm::make_early_inc_range(*block)) {
+    // Skip ops that do not operate on buffers, have no Allocate/Free side
+    // effect and are not terminators. (bufferization.dealloc ops are inserted
+    // in front of terminators, so terminators cannot be skipped.)
+    if (!op.hasTrait<OpTrait::IsTerminator>() && !hasBufferSemantics(&op) &&
+        hasNoAllocateOrFreeSideEffect(&op))
+      continue;
     FailureOr<Operation *> result = handleAllInterfaces(&op);
     if (failed(result))
       return failure();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
index 1a8a930bc9002..a39568efca168 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
@@ -516,8 +516,8 @@ func.func @assumingOp(
 // does not deal with any MemRef values.
 
 func.func @noRegionBranchOpInterface() {
-  %0 = "test.bar"() ({
-    %1 = "test.bar"() ({
+  %0 = "test.one_region_with_recursive_memory_effects"() ({
+    %1 = "test.one_region_with_recursive_memory_effects"() ({
       "test.yield"() : () -> ()
     }) : () -> (i32)
     "test.yield"() : () -> ()
@@ -531,7 +531,7 @@ func.func @noRegionBranchOpInterface() {
 // This is not allowed in buffer deallocation.
 
 func.func @noRegionBranchOpInterface() {
-  %0 = "test.bar"() ({
+  %0 = "test.one_region_with_recursive_memory_effects"() ({
     // expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
     %1 = "test.bar"() ({
       %2 = "test.get_memref"() : () -> memref<2xi32>
@@ -545,11 +545,10 @@ func.func @noRegionBranchOpInterface() {
 // -----
 
 // Test Case: The op "test.bar" does not implement the RegionBranchOpInterface.
-// This is not allowed in buffer deallocation.
+// But it also does not operate on buffers, so we don't care.
 
 func.func @noRegionBranchOpInterface() {
-  // expected-error@+1 {{All operations with attached regions need to implement the RegionBranchOpInterface.}}
-  %0 = "test.bar"() ({
+  %0 = "test.one_region_with_recursive_memory_effects"() ({
     %2 = "test.get_memref"() : () -> memref<2xi32>
     %3 = "test.foo"(%2) : (memref<2xi32>) -> (i32)
     "test.yield"(%3) : (i32) -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 96f66c2ca06ec..f8e8f204a1c86 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -480,6 +480,17 @@ def IsolatedRegionsOp : TEST_Op<"isolated_regions", [IsolatedFromAbove]> {
   let assemblyFormat = "attr-dict-with-keyword $regions";
 }
 
+def OneRegionWithRecursiveMemoryEffectsOp
+    : TEST_Op<"one_region_with_recursive_memory_effects", [
+        RecursiveMemoryEffects]> {
+  let description = [{
+    Op that has one region and recursive side effects. The
+    RegionBranchOpInterface is not implemented on this op.
+  }];
+  let results = (outs AnyType:$result);
+  let regions = (region SizedRegion<1>:$body);
+}
+
 //===----------------------------------------------------------------------===//
 // NoTerminator Operation
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list