[Mlir-commits] [mlir] 87c770b - [mlir][bufferization][NFC] Put inplacability conflict resolution in op interface
Matthias Springer
llvmlistbot at llvm.org
Thu Jun 9 13:10:26 PDT 2022
Author: Matthias Springer
Date: 2022-06-09T22:06:44+02:00
New Revision: 87c770bbd04462369950b1e6940e4f9ee4fc6de3
URL: https://github.com/llvm/llvm-project/commit/87c770bbd04462369950b1e6940e4f9ee4fc6de3
DIFF: https://github.com/llvm/llvm-project/commit/87c770bbd04462369950b1e6940e4f9ee4fc6de3.diff
LOG: [mlir][bufferization][NFC] Put inplacability conflict resolution in op interface
The TensorCopyInsertion pass resolves out-of-place bufferization decisions by inserting explicit `bufferization.alloc_tensor` ops. This change moves that functionality into a new BufferizableOpInterface method, so that it can be overridden by op implementations. Some op bufferizations must insert additional `alloc_tensor` ops to make sure that certain aliasing invariants are not violated (e.g., scf::ForOp). This will be addressed in a subsequent change.
Differential Revision: https://reviews.llvm.org/D126817
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index f69d612bbcdb9..e47e8478d4ad7 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -192,6 +192,32 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
llvm_unreachable("bufferRelation not implemented");
}]
>,
+    InterfaceMethod<
+      /*desc=*/[{
+        Resolve all inplacability conflicts by inserting explicit
+        `bufferization.alloc_tensor` ops. Examples of inplacability conflicts
+        are read-after-write conflicts or writes into non-writable buffers.
+
+        This method should rewrite the IR in such a way that for each tensor
+        OpOperand t, buffer(t) can be used directly during bufferization.
+        Bufferization then no longer has to care about inplacability
+        conflicts.
+
+        This method can query analysis information from the given analysis
+        state.
+
+        The default implementation resolves conflicts of the op's tensor
+        OpOperands via `resolveTensorOpOperandConflicts`.
+      }],
+      /*retType=*/"LogicalResult",
+      /*methodName=*/"resolveConflicts",
+      /*args=*/(ins "RewriterBase &":$rewriter,
+                    "const AnalysisState &":$state),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto bufferizableOp =
+            cast<BufferizableOpInterface>($_op.getOperation());
+        return bufferizableOp.resolveTensorOpOperandConflicts(
+            rewriter, state);
+      }]
+    >,
InterfaceMethod<
/*desc=*/[{
Bufferize this op, i.e., rewrite it into a memref-based equivalent.
@@ -301,6 +327,11 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
];
let extraClassDeclaration = [{
+ /// Resolve out-of-place tensor OpOperands with explicit allocations in the
+ /// form of `bufferization.alloc_tensor` ops.
+ LogicalResult resolveTensorOpOperandConflicts(
+ RewriterBase &rewriter, const AnalysisState &state);
+
/// Return `true` if the given OpOperand creates an alias but does neither
/// read nor write. This implies that `bufferizesToMemoryRead` and
/// `bufferizesToMemoryWrite` must return `false`. This method will never
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 387b5eb269398..de8e30415a6d1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -18,6 +18,10 @@
#include "mlir/IR/Value.h"
#include "llvm/Support/Debug.h"
+//===----------------------------------------------------------------------===//
+// BufferizableOpInterface
+//===----------------------------------------------------------------------===//
+
namespace mlir {
namespace bufferization {
@@ -38,6 +42,31 @@ using namespace bufferization;
constexpr const ::llvm::StringLiteral
bufferization::BufferizableOpInterface::kInplaceableAttrName;
+/// Default inplacability conflict resolution: for every tensor OpOperand of
+/// this op that the analysis decided to bufferize out-of-place, create a
+/// `bufferization.alloc_tensor` copy of the operand value right before the op
+/// and make the op use that copy instead.
+///
+/// The insertion point is set here and restored on exit, so this helper
+/// neither depends on nor clobbers the caller's insertion point. Op
+/// implementations that override `resolveConflicts` may call it after having
+/// created other IR.
+///
+/// Returns failure (after emitting an error on the op) if an out-of-place
+/// operand has an unranked tensor type; such copies are not supported.
+LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
+    RewriterBase &rewriter, const AnalysisState &state) {
+  OpBuilder::InsertionGuard g(rewriter);
+  Operation *op = getOperation();
+  // Copies are materialized immediately before the op that consumes them.
+  rewriter.setInsertionPoint(op);
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    Type operandType = opOperand.get().getType();
+    // Only tensor operands can have inplacability conflicts.
+    if (!operandType.isa<TensorType>())
+      continue;
+    // In-place operands need no copy.
+    if (state.isInPlace(opOperand))
+      continue;
+    if (operandType.isa<UnrankedTensorType>())
+      return op->emitError("copies of unranked tensors are not supported");
+    // A tensor type that is not unranked is necessarily ranked.
+    auto tensorType = operandType.cast<RankedTensorType>();
+    // Mark the copy as escaping if any OpResult aliasing this operand is
+    // yielded.
+    SmallVector<OpResult> aliasingOpResults =
+        state.getAliasingOpResult(opOperand);
+    bool escape = llvm::any_of(
+        aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
+    Value copy = rewriter.create<AllocTensorOp>(
+        op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
+    rewriter.updateRootInPlace(op, [&]() { opOperand.set(copy); });
+  }
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// OpFilter
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index 16b4b0b0d2ca7..e04f1e386ee91 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -43,7 +43,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
LogicalResult
mlir::bufferization::insertTensorCopies(Operation *op,
const AnalysisState &state) {
- OpBuilder builder(op->getContext());
+ IRRewriter rewriter(op->getContext());
WalkResult result = op->walk([&](Operation *op) {
auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
if (!bufferizableOp)
@@ -55,31 +55,15 @@ mlir::bufferization::insertTensorCopies(Operation *op,
if (allocTensorOp.escape())
return WalkResult::advance();
bool escape = state.isTensorYielded(allocTensorOp.result());
- allocTensorOp.escapeAttr(builder.getBoolAttr(escape));
+ allocTensorOp.escapeAttr(rewriter.getBoolAttr(escape));
return WalkResult::advance();
}
- // Find out-of-place tensor OpOperands and resolve them with an explicit
- // tensor copy in the form of an AllocTensorOp.
- builder.setInsertionPoint(op);
- for (OpOperand &opOperand : op->getOpOperands()) {
- if (opOperand.get().getType().isa<UnrankedTensorType>()) {
- op->emitError("copies of unranked tensors are not supported");
- return WalkResult::interrupt();
- }
- auto tensorType = opOperand.get().getType().dyn_cast<RankedTensorType>();
- if (!tensorType)
- continue;
- if (state.isInPlace(opOperand))
- continue;
- SmallVector<OpResult> aliasingOpResults =
- state.getAliasingOpResult(opOperand);
- bool escape = llvm::any_of(
- aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
- Value copy = builder.create<AllocTensorOp>(
- op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
- opOperand.set(copy);
- }
+ // Find inplacability conflicts and resolve them. (Typically with explicit
+ // tensor copies in the form of AllocTensorOps.)
+ rewriter.setInsertionPoint(op);
+ if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
+ return WalkResult::interrupt();
return WalkResult::advance();
});
More information about the Mlir-commits
mailing list