[Mlir-commits] [mlir] d30fcad - [mlir][linalg][bufferize] Op interface implementation for Bufferization dialect ops
Matthias Springer
llvmlistbot at llvm.org
Thu Dec 2 23:31:05 PST 2021
Author: Matthias Springer
Date: 2021-12-03T16:25:44+09:00
New Revision: d30fcadf07ee552f20156ea90be2fdb54cb9cb08
URL: https://github.com/llvm/llvm-project/commit/d30fcadf07ee552f20156ea90be2fdb54cb9cb08
DIFF: https://github.com/llvm/llvm-project/commit/d30fcadf07ee552f20156ea90be2fdb54cb9cb08.diff
LOG: [mlir][linalg][bufferize] Op interface implementation for Bufferization dialect ops
This change provides `BufferizableOpInterface` implementations for ops from the Bufferization dialect. These ops are needed at the bufferization boundaries for partial bufferization.
Differential Revision: https://reviews.llvm.org/D114618
Added:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
Modified:
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h
new file mode 100644
index 0000000000000..23c17f4b188f5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h
@@ -0,0 +1,27 @@
+//===- BufferizationInterfaceImpl.h - Bufferization Impl. of Op Interface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H
+#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace bufferization_ext {
+
+/// Register the BufferizableOpInterface external models for
+/// bufferization::ToMemrefOp and bufferization::ToTensorOp with `registry`.
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace bufferization_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 4d7c445e0a805..e0c5a1020447e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -416,10 +416,6 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
BufferizationState &state) {
OpBuilder b(op->getContext());
- // Skip ToMemrefOp and ToTensorOp.
- if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
- return success();
-
// Check if op has tensor results or operands.
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
new file mode 100644
index 0000000000000..c8a2649842402
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -0,0 +1,101 @@
+//===- BufferizationInterfaceImpl.cpp - Bufferization Impl. of Interface --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+using namespace linalg;
+using namespace comprehensive_bufferize;
+
+namespace mlir {
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace bufferization_ext {
+
+// TODO: These ops should implement BufferizableOpInterface directly when moved
+// to the Bufferization dialect.
+
+// TODO: These implementations are conservative and will likely have to be
+// loosened for partial bufferization.
+
+/// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory
+/// location of the incoming tensor once it will be bufferized. In the
+/// analysis, the incoming tensor is assumed to bufferize to a memory read and
+/// to an inplace memory write, since it is unknown what will happen to the
+/// resulting memref.
+///
+/// Note: Only the memory read is expressed through this interface
+/// (`bufferizesToMemoryRead`). The inplace memory write is special-cased for
+/// ToMemrefOp in `isInplaceMemoryWrite` (ComprehensiveBufferize.cpp).
+struct ToMemrefOpInterface
+    : public BufferizableOpInterface::ExternalModel<ToMemrefOpInterface,
+                                                    bufferization::ToMemrefOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    // It is unknown whether the resulting MemRef will be read or not.
+    return true;
+  }
+
+  /// Report no aliasing OpOperands for the (memref) result.
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {};
+  }
+
+  /// Report no aliasing OpResult for the tensor operand. Together with
+  /// `getAliasingOpOperand`, ToMemrefOp has no aliasing OpOperand/OpResult
+  /// pairs.
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  /// Nothing to rewrite: the op is left in the IR unchanged.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    return success();
+  }
+};
+
+/// ToTensorOp conceptually loads a tensor from a memory location. Such ops do
+/// not lower any further, and they should have disappeared by the time the
+/// input is fully bufferized.
+///
+/// The analysis has no information about the memref that is loaded from by
+/// the ToTensorOp. We have to assume that the loaded tensor may, after
+/// bufferization, alias with any other bufferized tensor. Since ToTensorOp and
+/// ToMemrefOp have no aliasing OpOperand/OpResult pairs, this cannot be
+/// encoded directly in the analysis. However, declaring ToTensorOp results as
+/// not writable also enforces a buffer copy and has the same effect.
+struct ToTensorOpInterface
+    : public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
+                                                    bufferization::ToTensorOp> {
+  /// Report no aliasing OpOperands for the tensor result.
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {};
+  }
+
+  /// Map the result tensor to the memref operand as its buffer. The op itself
+  /// is not rewritten here.
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto toTensorOp = cast<bufferization::ToTensorOp>(op);
+    state.mapBuffer(toTensorOp.result(), toTensorOp.memref());
+    return success();
+  }
+
+  bool isWritable(Operation *op, Value value) const {
+    // It is unknown whether the MemRef operand is writable or not.
+    return false;
+  }
+};
+
+} // namespace bufferization_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+/// Attach the BufferizableOpInterface external models defined in this file to
+/// bufferization::ToMemrefOp and bufferization::ToTensorOp.
+void mlir::linalg::comprehensive_bufferize::bufferization_ext::
+    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addOpInterface<bufferization::ToMemrefOp,
+                          bufferization_ext::ToMemrefOpInterface>();
+  registry.addOpInterface<bufferization::ToTensorOp,
+                          bufferization_ext::ToTensorOpInterface>();
+}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index 68d5d037d72af..f03319669fd3b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
AffineInterfaceImpl.cpp
ArithInterfaceImpl.cpp
BufferizableOpInterface.cpp
+ BufferizationInterfaceImpl.cpp
ComprehensiveBufferize.cpp
LinalgInterfaceImpl.cpp
ModuleBufferization.cpp
@@ -80,6 +81,7 @@ add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl
)
add_mlir_dialect_library(MLIRComprehensiveBufferize
+ BufferizationInterfaceImpl.cpp
ComprehensiveBufferize.cpp
ModuleBufferization.cpp
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 6cbc308281e22..d4571d3ac6702 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -239,6 +239,12 @@ static std::string printValueInfo(Value value, bool prefix) {
/// Return true if opOperand has been decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
const BufferizationAliasInfo &aliasInfo) {
+ // The analysis does not know what happens to the result of a ToMemrefOp, so
+ // we assume that it is written to.
+ // TODO: This is a conservative implementation. This rule will have to be
+ // relaxed for partial bufferization.
+ if (isa<bufferization::ToMemrefOp>(opOperand.getOwner()))
+ return true;
// OpOperands without an aliasing OpResult do not write.
OpResult opResult = getAliasingOpResult(opOperand);
if (!opResult)
@@ -453,14 +459,23 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
/// If `checkConsistencyOnly` is true, this function checks if there is a
/// read-after-write conflict without bufferizing `operand` inplace. This would
/// indicate a problem with the current inplace bufferization decisions.
+///
+/// Note: If `checkConsistencyOnly`, this function may be called with a null
+/// OpResult. In that case, only the consistency of bufferization decisions
+/// involving aliases of the given OpOperand are checked.
bool wouldCreateReadAfterWriteInterference(
OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
const BufferizationAliasInfo &aliasInfo,
bool checkConsistencyOnly = false) {
#ifndef NDEBUG
- SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
- assert(llvm::find(opOperands, &operand) != opOperands.end() &&
- "operand and result do not match");
+ if (result) {
+ SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
+ assert(llvm::find(opOperands, &operand) != opOperands.end() &&
+ "operand and result do not match");
+ } else {
+ assert(checkConsistencyOnly &&
+ "result not provided, can only check consistency");
+ }
#endif // NDEBUG
// Helper function to iterate on aliases of `root` and capture the reads.
@@ -486,9 +501,11 @@ bool wouldCreateReadAfterWriteInterference(
// Collect reads and writes of all aliases of OpOperand and OpResult.
DenseSet<OpOperand *> usesRead, usesWrite;
getAliasingReads(usesRead, operand.get());
- getAliasingReads(usesRead, result);
+ if (result)
+ getAliasingReads(usesRead, result);
getAliasingInplaceWrites(usesWrite, operand.get());
- getAliasingInplaceWrites(usesWrite, result);
+ if (result)
+ getAliasingInplaceWrites(usesWrite, result);
if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
usesWrite.insert(&operand);
@@ -673,25 +690,38 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
return res;
}
-#ifndef NDEBUG
-/// Assert that the current bufferization decisions are consistent.
+/// Check that the current bufferization decisions are consistent, i.e., that
+/// no tensor OpOperand of a bufferizable op would create a read-after-write
+/// conflict given the inplace decisions made so far. On the first conflict,
+/// emit an error on the offending op and return failure.
-static void checkAliasInfoConsistency(FuncOp funcOp,
-                                      const DominanceInfo &domInfo,
-                                      const BufferizationAliasInfo &aliasInfo) {
-  funcOp.walk([&](Operation *op) {
+static LogicalResult
+checkAliasInfoConsistency(FuncOp funcOp, const DominanceInfo &domInfo,
+                          const BufferizationAliasInfo &aliasInfo) {
+  Operation *inconsistentOp = nullptr;
+  WalkResult walkResult = funcOp.walk([&](Operation *op) {
     if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
       for (OpOperand &opOperand : op->getOpOperands())
-        if (opOperand.get().getType().isa<TensorType>())
-          if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
-            // If this assertion fails, there is probably an inconsistent
-            // combination of "mustBufferizeInPlace" decisions.
-            assert(!wouldCreateReadAfterWriteInterference(
-                       opOperand, opResult, domInfo, aliasInfo,
-                       /*checkConsistencyOnly=*/true) &&
-                   "found read after write conflict before running analysis");
+        if (opOperand.get().getType().isa<TensorType>()) {
+          // `opResult` may be null (e.g., for ToMemrefOp). That is allowed
+          // because only consistency is checked here.
+          OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand);
+          if (wouldCreateReadAfterWriteInterference(
+                  opOperand, opResult, domInfo, aliasInfo,
+                  /*checkConsistencyOnly=*/true)) {
+            // This error can happen for two reasons. Either the input IR
+            // already has a read-after-write conflict. Or certain
+            // "mustBufferizeInPlace" interface methods are implemented
+            // incorrectly.
+            inconsistentOp = op;
+            return WalkResult::interrupt();
+          }
+        }
+    return WalkResult::advance();
   });
+
+  if (walkResult.wasInterrupted()) {
+    // This can currently happen in one situation: When a tensor is passed into
+    // a ToMemrefOp and read by another op afterwards. ToMemrefOps are
+    // currently handled conservatively. Once a tensor is passed into a
+    // ToMemrefOp, it may no longer be read.
+    return inconsistentOp->emitError("input IR has RaW conflict");
+  }
+  return success();
 }
-#endif
/// Annotate the IR with the result of the analysis. For testing/debugging only.
static void
@@ -720,9 +750,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
if (funcOp.body().empty())
return success();
-#ifndef NDEBUG
- checkAliasInfoConsistency(funcOp, domInfo, aliasInfo);
-#endif // NDEBUG
+ if (failed(checkAliasInfoConsistency(funcOp, domInfo, aliasInfo)))
+ return failure();
// If the analysis fails, just return.
if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index ca713d151238d..1910298213334 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
@@ -47,6 +48,7 @@ struct LinalgComprehensiveModuleBufferize
arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
+ bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 616d84545e5c2..2e2792b1146cd 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1492,3 +1492,44 @@ func @main_func(%A : tensor<?xf32> {linalg.inplaceable = true},
%0 = call @some_use(%A, %v) : (tensor<?xf32>, vector<5xf32>) -> (tensor<?xf32>)
return %0 : tensor<?xf32>
}
+
+// -----
+
+// CHECK-LABEL: func @to_tensor_op_not_writable
+func @to_tensor_op_not_writable(%m: memref<?xf32>, %v: vector<5xf32>,
+                                %idx1: index, %idx2: index)
+    -> vector<10xf32> {
+  %0 = bufferization.to_tensor %m : memref<?xf32>
+
+  // Write to the tensor. Cannot be inplace due to to_tensor: results of
+  // bufferization.to_tensor are declared not writable.
+  // CHECK: vector.transfer_write
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+  %w = vector.transfer_write %v, %0[%idx1] : vector<5xf32>, tensor<?xf32>
+
+  // Read from the tensor and return result.
+  %cst = arith.constant 0.0 : f32
+  %r = vector.transfer_read %w[%idx2], %cst : tensor<?xf32>, vector<10xf32>
+  return %r : vector<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @to_memref_op_is_reading
+// %t1 is passed to bufferization.to_memref further down, which the analysis
+// treats as a read of %t1. The earlier transfer_write into %t1 therefore must
+// not bufferize inplace: it would overwrite data the to_memref may still read.
+func @to_memref_op_is_reading(%t1: tensor<?xf32> {linalg.inplaceable = true},
+ %idx1: index, %idx2: index, %idx3: index,
+ %v1: vector<5xf32>)
+ -> (vector<5xf32>, vector<5xf32>) {
+ // Write + read to/from tensor.
+ // CHECK: vector.transfer_write
+ // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+ %1 = vector.transfer_write %v1, %t1[%idx2] : vector<5xf32>, tensor<?xf32>
+ %cst = arith.constant 0.0 : f32
+ %r1 = vector.transfer_read %1[%idx3], %cst : tensor<?xf32>, vector<5xf32>
+
+ // Write + read to/from same memref.
+ // The memref ops below operate on the result of to_memref, outside the
+ // tensor-level analysis.
+ %0 = bufferization.to_memref %t1 : memref<?xf32>
+ vector.transfer_write %v1, %0[%idx1] : vector<5xf32>, memref<?xf32>
+ %r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32>
+
+ return %r1, %r2 : vector<5xf32>, vector<5xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index c0a91e3f1df83..edeb0c07da0f2 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -167,3 +167,23 @@ func @main() -> tensor<4xi32> {
}
return %r: tensor<4xi32>
}
+
+// -----
+
+// Negative test: bufferization.to_memref is treated as an inplace write of
+// %t1, and %t1 is read afterwards by vector.transfer_read. The consistency
+// check that runs before the inplace analysis rejects this input.
+func @to_memref_op_is_writing(
+ %t1: tensor<?xf32> {linalg.inplaceable = true}, %idx1: index,
+ %idx2: index, %idx3: index, %v1: vector<5xf32>) -> (vector<5xf32>, vector<5xf32>) {
+ // This is a RaW conflict because to_memref is an inplace write and %t1 is
+ // read further down. This will likely have to change with partial
+ // bufferization.
+
+ // Note: the expected-error below must stay directly above to_memref (@+1).
+ // expected-error @+1 {{input IR has RaW conflict}}
+ %0 = bufferization.to_memref %t1 : memref<?xf32>
+
+ // Read from both.
+ %cst = arith.constant 0.0 : f32
+ %r1 = vector.transfer_read %t1[%idx3], %cst : tensor<?xf32>, vector<5xf32>
+ %r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32>
+
+ return %r1, %r2 : vector<5xf32>, vector<5xf32>
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index a68aa7b74150e..7667ce971aab3 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6673,10 +6673,12 @@ cc_library(
cc_library(
name = "ComprehensiveBufferize",
srcs = [
+ "lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp",
"lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp",
"lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp",
],
hdrs = [
+ "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h",
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h",
"include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h",
],
More information about the Mlir-commits
mailing list