[Mlir-commits] [mlir] [mlir][bufferization] BufferDeallocationOpInterface: support custom ownership update logic (PR #66350)

Martin Erhart llvmlistbot at llvm.org
Thu Sep 14 05:03:15 PDT 2023


https://github.com/maerhart updated https://github.com/llvm/llvm-project/pull/66350:

>From 6c1b5d2f630360bd0fc44acc229d78b7101bc68d Mon Sep 17 00:00:00 2001
From: Martin Erhart <merhart at google.com>
Date: Tue, 12 Sep 2023 15:20:05 +0000
Subject: [PATCH] [mlir][bufferization] BufferDeallocationOpInterface: support
 custom ownership update logic

Add a method to the BufferDeallocationOpInterface that allows operations
implementing the interface to provide custom logic for computing the ownership
indicators of the values they define. As a demonstration, this new method is
implemented by the `arith.select` operation.
---
 .../BufferDeallocationOpInterfaceImpl.h       | 22 +++++
 .../IR/BufferDeallocationOpInterface.h        |  4 +-
 .../IR/BufferDeallocationOpInterface.td       | 29 ++++++-
 mlir/include/mlir/InitAllDialects.h           |  2 +
 .../BufferDeallocationOpInterfaceImpl.cpp     | 85 +++++++++++++++++++
 .../Dialect/Arith/Transforms/CMakeLists.txt   |  1 +
 .../IR/BufferDeallocationOpInterface.cpp      |  4 +-
 .../OwnershipBasedBufferDeallocation.cpp      | 46 ++++++++--
 .../dealloc-callop-interface.mlir             | 10 +--
 9 files changed, 187 insertions(+), 16 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h
 create mode 100644 mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp

diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h
new file mode 100644
index 000000000000000..16cec1a82b5c86c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h
@@ -0,0 +1,22 @@
+//===- BufferDeallocationOpInterfaceImpl.h ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arith {
+/// Registers the BufferDeallocationOpInterface external models for arith ops.
+void registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
index b88270f1c150a27..7ac4592de7875fb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
@@ -142,8 +142,8 @@ class DeallocationState {
   /// a new SSA value, returned as the first element of the pair, which has
   /// 'Unique' ownership and can be used instead of the passed Value with the
   /// the ownership indicator returned as the second element of the pair.
-  std::pair<Value, Value> getMemrefWithUniqueOwnership(OpBuilder &builder,
-                                                       Value memref);
+  std::pair<Value, Value>
+  getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block);
 
   /// Given two basic blocks and the values passed via block arguments to the
   /// destination block, compute the list of MemRefs that have to be retained in
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
index c35fe417184ffd4..3e11432c65c5f08 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
@@ -39,7 +39,34 @@ def BufferDeallocationOpInterface :
         /*retType=*/"FailureOr<Operation *>",
         /*methodName=*/"process",
         /*args=*/(ins "DeallocationState &":$state,
-                      "const DeallocationOptions &":$options)>
+                      "const DeallocationOptions &":$options)>,
+      InterfaceMethod<
+        /*desc=*/[{
+          This method allows the implementing operation to specify custom logic
+          to materialize an ownership indicator value for the given MemRef typed
+          value it defines (including block arguments of nested regions). Since
+          the operation itself has more information about its semantics the
+          materialized IR can be more efficient compared to the default
+          implementation and avoid cloning MemRefs and/or doing alias checking
+          at runtime.
+          Note that the same logic could also be implemented in the 'process'
+          method above; however, the IR would then always be materialized. If
+          the IR computing an updated ownership indicator should only be
+          materialized when needed, implement it using this method instead
+          (this is especially important if it creates operations that cannot
+          easily be canonicalized away).
+        }],
+        /*retType=*/"std::pair<Value, Value>",
+        /*methodName=*/"materializeUniqueOwnershipForMemref",
+        /*args=*/(ins "DeallocationState &":$state,
+                      "const DeallocationOptions &":$options,
+                      "OpBuilder &":$builder,
+                      "Value":$memref),
+        /*methodBody=*/[{}],
+        /*defaultImplementation=*/[{
+          return state.getMemrefWithUniqueOwnership(
+            builder, memref, memref.getParentBlock());
+        }]>,
   ];
 }
 
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index ee91bfa57d12a39..0182ab93929cb8c 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -133,6 +134,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
 
   // Register all external models.
   affine::registerValueBoundsOpInterfaceExternalModels(registry);
+  arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
   arith::registerBufferizableOpInterfaceExternalModels(registry);
   arith::registerValueBoundsOpInterfaceExternalModels(registry);
   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
new file mode 100644
index 000000000000000..f2e7732e8ea4aa3
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -0,0 +1,85 @@
+//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace {
+/// Provides custom logic to materialize ownership indicator values for the
+/// result value of 'arith.select'. Instead of cloning or runtime alias
+/// checking, this implementation inserts another `arith.select` to choose the
+/// ownership indicator of the operand in the same way the original
+/// `arith.select` chooses the MemRef operand. If the ownership of at least one
+/// operand is not 'Unique', fall back to the default implementation.
+///
+/// Example:
+/// ```mlir
+/// // let ownership(%m0) := %o0
+/// // let ownership(%m1) := %o1
+/// %res = arith.select %cond, %m0, %m1
+/// ```
+/// The default implementation would insert a clone and replace all uses of the
+/// result of `arith.select` with that clone:
+/// ```mlir
+/// %res = arith.select %cond, %m0, %m1
+/// %clone = bufferization.clone %res
+/// // let ownership(%res) := 'Unknown'
+/// // let ownership(%clone) := %true
+/// // replace all uses of %res with %clone
+/// ```
+/// This implementation, on the other hand, materializes the following:
+/// ```mlir
+/// %res = arith.select %cond, %m0, %m1
+/// %res_ownership = arith.select %cond, %o0, %o1
+/// // let ownership(%res) := %res_ownership
+/// ```
+struct SelectOpInterface
+    : public BufferDeallocationOpInterface::ExternalModel<SelectOpInterface,
+                                                          arith::SelectOp> {
+  FailureOr<Operation *> process(Operation *op, DeallocationState &state,
+                                 const DeallocationOptions &options) const {
+    return op; // nothing to do
+  }
+
+  std::pair<Value, Value>
+  materializeUniqueOwnershipForMemref(Operation *op, DeallocationState &state,
+                                      const DeallocationOptions &options,
+                                      OpBuilder &builder, Value value) const {
+    auto selectOp = cast<arith::SelectOp>(op);
+    assert(value == selectOp.getResult() &&
+           "Value not defined by this operation");
+
+    // Fall back to the generic logic unless both operands are 'Unique'.
+    Block *block = value.getParentBlock();
+    Ownership trueOwn = state.getOwnership(selectOp.getTrueValue(), block);
+    Ownership falseOwn = state.getOwnership(selectOp.getFalseValue(), block);
+    if (!trueOwn.isUnique() || !falseOwn.isUnique())
+      return state.getMemrefWithUniqueOwnership(builder, value, block);
+
+    Value ownership = builder.create<arith::SelectOp>(
+        op->getLoc(), selectOp.getCondition(), trueOwn.getIndicator(),
+        falseOwn.getIndicator());
+    return {value, ownership};
+  }
+};
+
+} // namespace
+
+void mlir::arith::registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *context, ArithDialect * /*dialect*/) {
+    arith::SelectOp::attachInterface<SelectOpInterface>(*context);
+  });
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index a9b86b4d99256c0..02240601bcd35a1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRArithTransforms
+  BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   EmulateUnsupportedFloats.cpp
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index 2314cee2ff2c158..407d75e2426e9f9 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -134,8 +134,8 @@ void DeallocationState::getLiveMemrefsIn(Block *block,
 
 std::pair<Value, Value>
 DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder,
-                                                Value memref) {
-  auto iter = ownershipMap.find({memref, memref.getParentBlock()});
+                                                Value memref, Block *block) {
+  auto iter = ownershipMap.find({memref, block});
   assert(iter != ownershipMap.end() &&
          "Value must already have been registered in the ownership map");
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index d4b8e0dff67bae4..02fb4d3c42fa521 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -376,13 +376,24 @@ class BufferDeallocation {
   /// Given an SSA value of MemRef type, returns the same of a new SSA value
   /// which has 'Unique' ownership where the ownership indicator is guaranteed
   /// to be always 'true'.
-  Value getMemrefWithGuaranteedOwnership(OpBuilder &builder, Value memref);
+  Value materializeMemrefWithGuaranteedOwnership(OpBuilder &builder,
+                                                 Value memref, Block *block);
 
   /// Returns whether the given operation implements FunctionOpInterface, has
   /// private visibility, and the private-function-dynamic-ownership pass option
   /// is enabled.
   bool isFunctionWithoutDynamicOwnership(Operation *op);
 
+  /// Given an SSA value of MemRef type, this function queries the
+  /// BufferDeallocationOpInterface of the op defining 'memref' (or, for block
+  /// arguments, the parent op of the block) for a materialized ownership
+  /// indicator. If that op does not implement the interface or if the block
+  /// for which the materialized value is requested does not match the block in
+  /// which 'memref' is defined, the default implementation in
+  /// `DeallocationState::getMemrefWithUniqueOwnership` is queried instead.
+  std::pair<Value, Value>
+  materializeUniqueOwnership(OpBuilder &builder, Value memref, Block *block);
+
   /// Checks all the preconditions for operations implementing the
   /// FunctionOpInterface that have to hold for the deallocation to be
   /// applicable:
@@ -428,6 +439,28 @@ class BufferDeallocation {
 // BufferDeallocation Implementation
 //===----------------------------------------------------------------------===//
 
+std::pair<Value, Value>
+BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
+                                               Block *block) {
+  // Custom materialization is only possible when the ownership indicator is
+  // requested for the block in which 'memref' is defined; otherwise (and for
+  // owners that do not implement the interface) use the generic logic.
+  if (memref.getParentBlock() == block) {
+    // The owner is the defining op for op results and the parent op of the
+    // block for block arguments.
+    Operation *owner = memref.getDefiningOp();
+    if (owner == nullptr)
+      owner = block->getParentOp();
+
+    auto deallocOpInterface = dyn_cast<BufferDeallocationOpInterface>(owner);
+    if (deallocOpInterface)
+      return deallocOpInterface.materializeUniqueOwnershipForMemref(
+          state, options, builder, memref);
+  }
+
+  return state.getMemrefWithUniqueOwnership(builder, memref, block);
+}
+
 static bool regionOperatesOnMemrefValues(Region &region) {
   WalkResult result = region.walk([](Block *block) {
     if (llvm::any_of(block->getArguments(), isMemref))
@@ -677,11 +710,11 @@ BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
   return newOp.getOperation();
 }
 
-Value BufferDeallocation::getMemrefWithGuaranteedOwnership(OpBuilder &builder,
-                                                           Value memref) {
+Value BufferDeallocation::materializeMemrefWithGuaranteedOwnership(
+    OpBuilder &builder, Value memref, Block *block) {
   // First, make sure we at least have 'Unique' ownership already.
   std::pair<Value, Value> newMemrefAndOnwership =
-      state.getMemrefWithUniqueOwnership(builder, memref);
+      materializeUniqueOwnership(builder, memref, block);
   Value newMemref = newMemrefAndOnwership.first;
   Value condition = newMemrefAndOnwership.second;
 
@@ -785,7 +818,7 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
         continue;
       }
       auto [memref, condition] =
-          state.getMemrefWithUniqueOwnership(builder, operand);
+          materializeUniqueOwnership(builder, operand, op->getBlock());
       newOperands.push_back(memref);
       ownershipIndicatorsToAdd.push_back(condition);
     }
@@ -868,7 +901,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
       if (!isMemref(val.get()))
         continue;
 
-      val.set(getMemrefWithGuaranteedOwnership(builder, val.get()));
+      val.set(materializeMemrefWithGuaranteedOwnership(builder, val.get(),
+                                                       op->getBlock()));
     }
   }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
index 67128fee3dfe0ab..bff06d4499938df 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
@@ -95,15 +95,15 @@ func.func @function_call_requries_merged_ownership_mid_block(%arg0: i1) {
 //  CHECK-NEXT:   return
 
 // CHECK-DYNAMIC-LABEL: func @function_call_requries_merged_ownership_mid_block
+//  CHECK-DYNAMIC-SAME: ([[ARG0:%.+]]: i1)
 //       CHECK-DYNAMIC:   [[ALLOC0:%.+]] = memref.alloc(
 //  CHECK-DYNAMIC-NEXT:   [[ALLOC1:%.+]] = memref.alloca(
-//  CHECK-DYNAMIC-NEXT:   [[SELECT:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]]
-//  CHECK-DYNAMIC-NEXT:   [[CLONE:%.+]] = bufferization.clone [[SELECT]]
-//  CHECK-DYNAMIC-NEXT:   [[RET:%.+]]:2 = call @f([[CLONE]], %true{{[0-9_]*}})
+//  CHECK-DYNAMIC-NEXT:   [[SELECT:%.+]] = arith.select [[ARG0]], [[ALLOC0]], [[ALLOC1]]
+//  CHECK-DYNAMIC-NEXT:   [[RET:%.+]]:2 = call @f([[SELECT]], [[ARG0]])
 //  CHECK-DYNAMIC-NEXT:   test.copy
 //  CHECK-DYNAMIC-NEXT:   [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
-//  CHECK-DYNAMIC-NEXT:   bufferization.dealloc ([[ALLOC0]], [[CLONE]], [[BASE]] :
-//  CHECK-DYNAMIC-SAME:     if (%true{{[0-9_]*}}, %true{{[0-9_]*}}, [[RET]]#1)
+//  CHECK-DYNAMIC-NEXT:   bufferization.dealloc ([[ALLOC0]], [[BASE]] :
+//  CHECK-DYNAMIC-SAME:     if (%true{{[0-9_]*}}, [[RET]]#1)
 //   CHECK-DYNAMIC-NOT:     retain
 //  CHECK-DYNAMIC-NEXT:   return
 



More information about the Mlir-commits mailing list