[Mlir-commits] [mlir] [mlir][bufferization] BufferDeallocationOpInterface: support custom ownership update logic (PR #66350)
Martin Erhart
llvmlistbot at llvm.org
Thu Sep 14 05:03:15 PDT 2023
https://github.com/maerhart updated https://github.com/llvm/llvm-project/pull/66350:
>From 6c1b5d2f630360bd0fc44acc229d78b7101bc68d Mon Sep 17 00:00:00 2001
From: Martin Erhart <merhart at google.com>
Date: Tue, 12 Sep 2023 15:20:05 +0000
Subject: [PATCH] [mlir][bufferization] BufferDeallocationOpInterface: support
custom ownership update logic
Add a method to the BufferDeallocationOpInterface that allows operations
implementing the interface to provide custom logic to compute the ownership
indicators of the values they define. As an example, this new method is
implemented by the `arith.select` operation.
---
.../BufferDeallocationOpInterfaceImpl.h | 22 +++++
.../IR/BufferDeallocationOpInterface.h | 4 +-
.../IR/BufferDeallocationOpInterface.td | 29 ++++++-
mlir/include/mlir/InitAllDialects.h | 2 +
.../BufferDeallocationOpInterfaceImpl.cpp | 85 +++++++++++++++++++
.../Dialect/Arith/Transforms/CMakeLists.txt | 1 +
.../IR/BufferDeallocationOpInterface.cpp | 4 +-
.../OwnershipBasedBufferDeallocation.cpp | 46 ++++++++--
.../dealloc-callop-interface.mlir | 10 +--
9 files changed, 187 insertions(+), 16 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h
create mode 100644 mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h
new file mode 100644
index 000000000000000..16cec1a82b5c86c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h
@@ -0,0 +1,22 @@
+//===- BufferDeallocationOpInterfaceImpl.h ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+namespace arith {
+/// Registers the BufferDeallocationOpInterface external models for arith ops.
+void registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
index b88270f1c150a27..7ac4592de7875fb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h
@@ -142,8 +142,8 @@ class DeallocationState {
/// a new SSA value, returned as the first element of the pair, which has
/// 'Unique' ownership and can be used instead of the passed Value with the
/// the ownership indicator returned as the second element of the pair.
- std::pair<Value, Value> getMemrefWithUniqueOwnership(OpBuilder &builder,
- Value memref);
+ std::pair<Value, Value>
+ getMemrefWithUniqueOwnership(OpBuilder &builder, Value memref, Block *block);
/// Given two basic blocks and the values passed via block arguments to the
/// destination block, compute the list of MemRefs that have to be retained in
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
index c35fe417184ffd4..3e11432c65c5f08 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td
@@ -39,7 +39,34 @@ def BufferDeallocationOpInterface :
/*retType=*/"FailureOr<Operation *>",
/*methodName=*/"process",
/*args=*/(ins "DeallocationState &":$state,
- "const DeallocationOptions &":$options)>
+ "const DeallocationOptions &":$options)>,
+ InterfaceMethod<
+ /*desc=*/[{
+ This method allows the implementing operation to specify custom logic
+ to materialize an ownership indicator value for the given MemRef typed
+ value it defines (including block arguments of nested regions). Since
+ the operation itself has more information about its semantics, the
+ materialized IR can be more efficient compared to the default
+ implementation and avoid cloning MemRefs and/or doing alias checking
+ at runtime.
+ Note that the same logic could also be implemented in the 'process'
+ method above; however, the IR would then always be materialized. If
+ it's desirable to only materialize the IR to compute an updated
+ ownership indicator when needed, it should be implemented using this
+ method (which is especially important if operations are created that
+ cannot be easily canonicalized away anymore).
+ }],
+ /*retType=*/"std::pair<Value, Value>",
+ /*methodName=*/"materializeUniqueOwnershipForMemref",
+ /*args=*/(ins "DeallocationState &":$state,
+ "const DeallocationOptions &":$options,
+ "OpBuilder &":$builder,
+ "Value":$memref),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return state.getMemrefWithUniqueOwnership(
+ builder, memref, memref.getParentBlock());
+ }]>,
];
}
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index ee91bfa57d12a39..0182ab93929cb8c 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -20,6 +20,7 @@
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -133,6 +134,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
// Register all external models.
affine::registerValueBoundsOpInterfaceExternalModels(registry);
+ arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
new file mode 100644
index 000000000000000..f2e7732e8ea4aa3
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -0,0 +1,85 @@
+//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+
+namespace {
+/// Provides custom logic to materialize ownership indicator values for the
+/// result value of 'arith.select'. Instead of cloning or runtime alias
+/// checking, this implementation inserts another `arith.select` to choose the
+/// ownership indicator of the operand in the same way the original
+/// `arith.select` chooses the MemRef operand. If either operand's ownership
+/// is not 'Unique', fall back to the default implementation.
+///
+/// Example:
+/// ```mlir
+/// // let ownership(%m0) := %o0
+/// // let ownership(%m1) := %o1
+/// %res = arith.select %cond, %m0, %m1
+/// ```
+/// The default implementation would insert a clone and replace all uses of the
+/// result of `arith.select` with that clone:
+/// ```mlir
+/// %res = arith.select %cond, %m0, %m1
+/// %clone = bufferization.clone %res
+/// // let ownership(%res) := 'Unknown'
+/// // let ownership(%clone) := %true
+/// // replace all uses of %res with %clone
+/// ```
+/// This implementation, on the other hand, materializes the following:
+/// ```mlir
+/// %res = arith.select %cond, %m0, %m1
+/// %res_ownership = arith.select %cond, %o0, %o1
+/// // let ownership(%res) := %res_ownership
+/// ```
+struct SelectOpInterface
+    : public BufferDeallocationOpInterface::ExternalModel<SelectOpInterface,
+                                                          arith::SelectOp> {
+  FailureOr<Operation *> process(Operation *op, DeallocationState &state,
+                                 const DeallocationOptions &options) const {
+    return op; // Ownership is only materialized lazily, on request.
+  }
+
+  std::pair<Value, Value>
+  materializeUniqueOwnershipForMemref(Operation *op, DeallocationState &state,
+                                      const DeallocationOptions &options,
+                                      OpBuilder &builder, Value value) const {
+    auto selOp = cast<arith::SelectOp>(op);
+    assert(value == selOp.getResult() &&
+           "Value not defined by this operation");
+    Block *parent = value.getParentBlock();
+    Ownership whenTrue = state.getOwnership(selOp.getTrueValue(), parent);
+    Ownership whenFalse = state.getOwnership(selOp.getFalseValue(), parent);
+
+    // Selecting between the two indicators requires both to be 'Unique';
+    // otherwise defer to the generic (clone-based) materialization.
+    if (!whenTrue.isUnique() || !whenFalse.isUnique())
+      return state.getMemrefWithUniqueOwnership(builder, value, parent);
+    Value indicator = builder.create<arith::SelectOp>(
+        op->getLoc(), selOp.getCondition(), whenTrue.getIndicator(),
+        whenFalse.getIndicator());
+    return {selOp.getResult(), indicator};
+  }
+};
+
+} // namespace
+
+void mlir::arith::registerBufferDeallocationOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, ArithDialect *dialect) {
+    SelectOp::attachInterface<SelectOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index a9b86b4d99256c0..02240601bcd35a1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRArithTransforms
+ BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
EmulateUnsupportedFloats.cpp
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
index 2314cee2ff2c158..407d75e2426e9f9 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp
@@ -134,8 +134,8 @@ void DeallocationState::getLiveMemrefsIn(Block *block,
std::pair<Value, Value>
DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder,
- Value memref) {
- auto iter = ownershipMap.find({memref, memref.getParentBlock()});
+ Value memref, Block *block) {
+ auto iter = ownershipMap.find({memref, block});
assert(iter != ownershipMap.end() &&
"Value must already have been registered in the ownership map");
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index d4b8e0dff67bae4..02fb4d3c42fa521 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -376,13 +376,24 @@ class BufferDeallocation {
/// Given an SSA value of MemRef type, returns the same of a new SSA value
/// which has 'Unique' ownership where the ownership indicator is guaranteed
/// to be always 'true'.
- Value getMemrefWithGuaranteedOwnership(OpBuilder &builder, Value memref);
+ Value materializeMemrefWithGuaranteedOwnership(OpBuilder &builder,
+ Value memref, Block *block);
/// Returns whether the given operation implements FunctionOpInterface, has
/// private visibility, and the private-function-dynamic-ownership pass option
/// is enabled.
bool isFunctionWithoutDynamicOwnership(Operation *op);
+ /// Given an SSA value of MemRef type, this function queries the
+ /// BufferDeallocationOpInterface of the defining operation of 'memref' for a
+ /// materialized ownership indicator for 'memref'. If the op does not
+ /// implement the interface or if the block for which the materialized value
+ /// is requested does not match the block in which 'memref' is defined, the
+ /// default implementation in
+ /// `DeallocationState::getMemrefWithUniqueOwnership` is queried instead.
+ std::pair<Value, Value>
+ materializeUniqueOwnership(OpBuilder &builder, Value memref, Block *block);
+
/// Checks all the preconditions for operations implementing the
/// FunctionOpInterface that have to hold for the deallocation to be
/// applicable:
@@ -428,6 +439,28 @@ class BufferDeallocation {
// BufferDeallocation Implementation
//===----------------------------------------------------------------------===//
+std::pair<Value, Value>
+BufferDeallocation::materializeUniqueOwnership(OpBuilder &builder, Value memref,
+                                               Block *block) {
+  // The interface can only materialize ownership indicators in the same block
+  // as the defining op, so it is only consulted in that case.
+  if (memref.getParentBlock() == block) {
+    // Block arguments have no defining op; in that case ask the op whose
+    // region contains the block.
+    Operation *owner = memref.getDefiningOp()
+                           ? memref.getDefiningOp()
+                           : memref.getParentBlock()->getParentOp();
+    if (auto iface = dyn_cast<BufferDeallocationOpInterface>(owner))
+      return iface.materializeUniqueOwnershipForMemref(state, options, builder,
+                                                       memref);
+  }
+
+  // Either 'memref' is requested for a different block or its owner does not
+  // implement the interface: fall back to the default implementation in
+  // DeallocationState.
+  return state.getMemrefWithUniqueOwnership(builder, memref, block);
+}
+
static bool regionOperatesOnMemrefValues(Region &region) {
WalkResult result = region.walk([](Block *block) {
if (llvm::any_of(block->getArguments(), isMemref))
@@ -677,11 +710,11 @@ BufferDeallocation::handleInterface(RegionBranchOpInterface op) {
return newOp.getOperation();
}
-Value BufferDeallocation::getMemrefWithGuaranteedOwnership(OpBuilder &builder,
- Value memref) {
+Value BufferDeallocation::materializeMemrefWithGuaranteedOwnership(
+ OpBuilder &builder, Value memref, Block *block) {
// First, make sure we at least have 'Unique' ownership already.
std::pair<Value, Value> newMemrefAndOnwership =
- state.getMemrefWithUniqueOwnership(builder, memref);
+ materializeUniqueOwnership(builder, memref, block);
Value newMemref = newMemrefAndOnwership.first;
Value condition = newMemrefAndOnwership.second;
@@ -785,7 +818,7 @@ FailureOr<Operation *> BufferDeallocation::handleInterface(CallOpInterface op) {
continue;
}
auto [memref, condition] =
- state.getMemrefWithUniqueOwnership(builder, operand);
+ materializeUniqueOwnership(builder, operand, op->getBlock());
newOperands.push_back(memref);
ownershipIndicatorsToAdd.push_back(condition);
}
@@ -868,7 +901,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
if (!isMemref(val.get()))
continue;
- val.set(getMemrefWithGuaranteedOwnership(builder, val.get()));
+ val.set(materializeMemrefWithGuaranteedOwnership(builder, val.get(),
+ op->getBlock()));
}
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
index 67128fee3dfe0ab..bff06d4499938df 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-callop-interface.mlir
@@ -95,15 +95,15 @@ func.func @function_call_requries_merged_ownership_mid_block(%arg0: i1) {
// CHECK-NEXT: return
// CHECK-DYNAMIC-LABEL: func @function_call_requries_merged_ownership_mid_block
+// CHECK-DYNAMIC-SAME: ([[ARG0:%.+]]: i1)
// CHECK-DYNAMIC: [[ALLOC0:%.+]] = memref.alloc(
// CHECK-DYNAMIC-NEXT: [[ALLOC1:%.+]] = memref.alloca(
-// CHECK-DYNAMIC-NEXT: [[SELECT:%.+]] = arith.select{{.*}}[[ALLOC0]], [[ALLOC1]]
-// CHECK-DYNAMIC-NEXT: [[CLONE:%.+]] = bufferization.clone [[SELECT]]
-// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[CLONE]], %true{{[0-9_]*}})
+// CHECK-DYNAMIC-NEXT: [[SELECT:%.+]] = arith.select [[ARG0]], [[ALLOC0]], [[ALLOC1]]
+// CHECK-DYNAMIC-NEXT: [[RET:%.+]]:2 = call @f([[SELECT]], [[ARG0]])
// CHECK-DYNAMIC-NEXT: test.copy
// CHECK-DYNAMIC-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[RET]]#0
-// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[CLONE]], [[BASE]] :
-// CHECK-DYNAMIC-SAME: if (%true{{[0-9_]*}}, %true{{[0-9_]*}}, [[RET]]#1)
+// CHECK-DYNAMIC-NEXT: bufferization.dealloc ([[ALLOC0]], [[BASE]] :
+// CHECK-DYNAMIC-SAME: if (%true{{[0-9_]*}}, [[RET]]#1)
// CHECK-DYNAMIC-NOT: retain
// CHECK-DYNAMIC-NEXT: return
More information about the Mlir-commits
mailing list