[Mlir-commits] [mlir] d30fcad - [mlir][linalg][bufferize] Op interface implementation for Bufferization dialect ops

Matthias Springer llvmlistbot at llvm.org
Thu Dec 2 23:31:05 PST 2021


Author: Matthias Springer
Date: 2021-12-03T16:25:44+09:00
New Revision: d30fcadf07ee552f20156ea90be2fdb54cb9cb08

URL: https://github.com/llvm/llvm-project/commit/d30fcadf07ee552f20156ea90be2fdb54cb9cb08
DIFF: https://github.com/llvm/llvm-project/commit/d30fcadf07ee552f20156ea90be2fdb54cb9cb08.diff

LOG: [mlir][linalg][bufferize] Op interface implementation for Bufferization dialect ops

This change provides `BufferizableOpInterface` implementations for ops from the Bufferization dialect. These ops are needed at the bufferization boundaries for partial bufferization.

Differential Revision: https://reviews.llvm.org/D114618

Added: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp

Modified: 
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h
new file mode 100644
index 0000000000000..23c17f4b188f5
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h
@@ -0,0 +1,27 @@
+//===- BufferizationInterfaceImpl.h - Bufferization Impl. of Op Interface -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H
+#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H
+
+namespace mlir {
+
+class DialectRegistry;
+
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace bufferization_ext {
+
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+
+} // namespace bufferization_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_BUFFERIZATION_INTERFACE_IMPL_H

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 4d7c445e0a805..e0c5a1020447e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -416,10 +416,6 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
                                                  BufferizationState &state) {
   OpBuilder b(op->getContext());
 
-  // Skip ToMemrefOp and ToTensorOp.
-  if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp>(op))
-    return success();
-
   // Check if op has tensor results or operands.
   auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
   bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
new file mode 100644
index 0000000000000..c8a2649842402
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -0,0 +1,101 @@
+//===- BufferizationInterfaceImpl.cpp - Bufferization Impl. of Interface --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+using namespace linalg;
+using namespace comprehensive_bufferize;
+
+namespace mlir {
+namespace linalg {
+namespace comprehensive_bufferize {
+namespace bufferization_ext {
+
+// TODO: These ops should implement BufferizableOpInterface directly when moved
+// to the Bufferization dialect.
+
+// TODO: These implementations are conservative and will likely have to be
+// loosened for partial bufferization.
+
+/// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory
+/// location of the incoming tensor once it will be bufferized. In the analysis,
+/// the incoming tensor is assumed to bufferize to a memory read and to an
+/// inplace memory write, since it is unknown what will happen to the resulting
+/// memref.
+struct ToMemrefOpInterface
+    : public BufferizableOpInterface::ExternalModel<ToMemrefOpInterface,
+                                                    bufferization::ToMemrefOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const {
+    // It is unknown whether the resulting MemRef will be read or not.
+    return true;
+  }
+
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {};
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const {
+    return OpResult();
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    return success();
+  }
+};
+
+/// ToTensorOp conceptually loads a tensor from a memory location. Such ops do
+/// not lower any further, and they should have disappeared by the time the
+/// input is fully bufferized.
+///
+/// The analysis has no information about the memref that is loaded from by the
+/// ToTensorOp. We have to assume that the loaded tensor may after bufferization
+/// potentially alias with any other bufferized tensor. Since ToTensorOp and
+/// ToMemrefOp have no aliasing OpOperand/OpResult pairs, this cannot be encoded
+/// directly in the analysis. However, declaring ToTensorOp results as not
+/// writable also enforces a buffer copy and has the same effect.
+struct ToTensorOpInterface
+    : public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
+                                                    bufferization::ToTensorOp> {
+  SmallVector<OpOperand *> getAliasingOpOperand(Operation *op,
+                                                OpResult opResult) const {
+    return {};
+  }
+
+  LogicalResult bufferize(Operation *op, OpBuilder &b,
+                          BufferizationState &state) const {
+    auto tensorLoadOp = cast<bufferization::ToTensorOp>(op);
+    state.mapBuffer(tensorLoadOp.result(), tensorLoadOp.memref());
+    return success();
+  }
+
+  bool isWritable(Operation *op, Value value) const {
+    // It is unknown whether the MemRef operand is writable or not.
+    return false;
+  }
+};
+
+} // namespace bufferization_ext
+} // namespace comprehensive_bufferize
+} // namespace linalg
+} // namespace mlir
+
+void mlir::linalg::comprehensive_bufferize::bufferization_ext::
+    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addOpInterface<bufferization::ToMemrefOp,
+                          bufferization_ext::ToMemrefOpInterface>();
+  registry.addOpInterface<bufferization::ToTensorOp,
+                          bufferization_ext::ToTensorOpInterface>();
+}

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
index 68d5d037d72af..f03319669fd3b 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
@@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
   AffineInterfaceImpl.cpp
   ArithInterfaceImpl.cpp
   BufferizableOpInterface.cpp
+  BufferizationInterfaceImpl.cpp
   ComprehensiveBufferize.cpp
   LinalgInterfaceImpl.cpp
   ModuleBufferization.cpp
@@ -80,6 +81,7 @@ add_mlir_dialect_library(MLIRVectorBufferizableOpInterfaceImpl
 )
 
 add_mlir_dialect_library(MLIRComprehensiveBufferize
+  BufferizationInterfaceImpl.cpp
   ComprehensiveBufferize.cpp
   ModuleBufferization.cpp
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 6cbc308281e22..d4571d3ac6702 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -239,6 +239,12 @@ static std::string printValueInfo(Value value, bool prefix) {
 /// Return true if opOperand has been decided to bufferize in-place.
 static bool isInplaceMemoryWrite(OpOperand &opOperand,
                                  const BufferizationAliasInfo &aliasInfo) {
+  // The analysis does not know what happens to the result of a ToMemrefOp, so
+  // we assume that it is written to.
+  // TODO: This is a conservative implementation. This rule will have to be
+  // relaxed for partial bufferization.
+  if (isa<bufferization::ToMemrefOp>(opOperand.getOwner()))
+    return true;
   // OpOperands without an aliasing OpResult do not write.
   OpResult opResult = getAliasingOpResult(opOperand);
   if (!opResult)
@@ -453,14 +459,23 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
 /// If `checkConsistencyOnly` is true, this function checks if there is a
 /// read-after-write conflict without bufferizing `operand` inplace. This would
 /// indicate a problem with the current inplace bufferization decisions.
+///
+/// Note: If `checkConsistencyOnly`, this function may be called with a null
+/// OpResult. In that case, only the consistency of bufferization decisions
+/// involving aliases of the given OpOperand is checked.
 bool wouldCreateReadAfterWriteInterference(
     OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
     const BufferizationAliasInfo &aliasInfo,
     bool checkConsistencyOnly = false) {
 #ifndef NDEBUG
-  SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
-  assert(llvm::find(opOperands, &operand) != opOperands.end() &&
-         "operand and result do not match");
+  if (result) {
+    SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
+    assert(llvm::find(opOperands, &operand) != opOperands.end() &&
+           "operand and result do not match");
+  } else {
+    assert(checkConsistencyOnly &&
+           "result not provided, can only check consistency");
+  }
 #endif // NDEBUG
 
   // Helper function to iterate on aliases of `root` and capture the reads.
@@ -486,9 +501,11 @@ bool wouldCreateReadAfterWriteInterference(
   // Collect reads and writes of all aliases of OpOperand and OpResult.
   DenseSet<OpOperand *> usesRead, usesWrite;
   getAliasingReads(usesRead, operand.get());
-  getAliasingReads(usesRead, result);
+  if (result)
+    getAliasingReads(usesRead, result);
   getAliasingInplaceWrites(usesWrite, operand.get());
-  getAliasingInplaceWrites(usesWrite, result);
+  if (result)
+    getAliasingInplaceWrites(usesWrite, result);
   if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
     usesWrite.insert(&operand);
 
@@ -673,25 +690,38 @@ inPlaceAnalysisFuncOpBody(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
   return res;
 }
 
-#ifndef NDEBUG
 /// Assert that the current bufferization decisions are consistent.
-static void checkAliasInfoConsistency(FuncOp funcOp,
-                                      const DominanceInfo &domInfo,
-                                      const BufferizationAliasInfo &aliasInfo) {
-  funcOp.walk([&](Operation *op) {
+static LogicalResult
+checkAliasInfoConsistency(FuncOp funcOp, const DominanceInfo &domInfo,
+                          const BufferizationAliasInfo &aliasInfo) {
+  Operation *inconsistentOp = nullptr;
+  WalkResult walkResult = funcOp.walk([&](Operation *op) {
     if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
       for (OpOperand &opOperand : op->getOpOperands())
-        if (opOperand.get().getType().isa<TensorType>())
-          if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
-            // If this assertion fails, there is probably an inconsistent
-            // combination of "mustBufferizeInPlace" decisions.
-            assert(!wouldCreateReadAfterWriteInterference(
-                       opOperand, opResult, domInfo, aliasInfo,
-                       /*checkConsistencyOnly=*/true) &&
-                   "found read after write conflict before running analysis");
+        if (opOperand.get().getType().isa<TensorType>()) {
+          OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand);
+          if (wouldCreateReadAfterWriteInterference(
+                  opOperand, opResult, domInfo, aliasInfo,
+                  /*checkConsistencyOnly=*/true)) {
+            // This error can happen for two reasons. Either the input IR
+            // already has a read-after-write conflict. Or certain
+            // "mustBufferizeInPlace" interface methods are implemented
+            // incorrectly.
+            inconsistentOp = op;
+            return WalkResult::interrupt();
+          }
+        }
+    return WalkResult::advance();
   });
+
+  if (walkResult.wasInterrupted())
+    // This can currently happen in one situation: When a tensor is passed into
+    // a ToMemrefOp and read by another op subsequently. ToMemrefOps are
+    // currently handled conservatively. Once a tensor is passed into a
+    // ToMemrefOp, it may no longer be read.
+    return inconsistentOp->emitError("input IR has RaW conflict");
+  return success();
 }
-#endif
 
 /// Annotate the IR with the result of the analysis. For testing/debugging only.
 static void
@@ -720,9 +750,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
   if (funcOp.body().empty())
     return success();
 
-#ifndef NDEBUG
-  checkAliasInfoConsistency(funcOp, domInfo, aliasInfo);
-#endif // NDEBUG
+  if (failed(checkAliasInfoConsistency(funcOp, domInfo, aliasInfo)))
+    return failure();
 
   // If the analysis fails, just return.
   if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index ca713d151238d..1910298213334 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
 #include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
@@ -47,6 +48,7 @@ struct LinalgComprehensiveModuleBufferize
                 arith::ArithmeticDialect, StandardOpsDialect, AffineDialect>();
     affine_ext::registerBufferizableOpInterfaceExternalModels(registry);
     arith_ext::registerBufferizableOpInterfaceExternalModels(registry);
+    bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
     linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
     scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
     std_ext::registerBufferizableOpInterfaceExternalModels(registry);

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
index 616d84545e5c2..2e2792b1146cd 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -1492,3 +1492,44 @@ func @main_func(%A : tensor<?xf32> {linalg.inplaceable = true},
   %0 = call @some_use(%A, %v) : (tensor<?xf32>, vector<5xf32>) -> (tensor<?xf32>)
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @to_tensor_op_not_writable
+func @to_tensor_op_not_writable(%m: memref<?xf32>, %v:  vector<5xf32>,
+                                %idx1: index, %idx2: index)
+    -> vector<10xf32> {
+  %0 = bufferization.to_tensor %m : memref<?xf32>
+
+  // Write to the tensor. Cannot be inplace due to to_tensor.
+  //      CHECK: vector.transfer_write
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+  %w = vector.transfer_write %v, %0[%idx1] : vector<5xf32>, tensor<?xf32>
+
+  // Read from the tensor and return result.
+  %cst = arith.constant 0.0 : f32
+  %r = vector.transfer_read %w[%idx2], %cst : tensor<?xf32>, vector<10xf32>
+  return %r : vector<10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @to_memref_op_is_reading
+func @to_memref_op_is_reading(%t1: tensor<?xf32> {linalg.inplaceable = true},
+                              %idx1: index, %idx2: index, %idx3: index,
+                              %v1: vector<5xf32>)
+    -> (vector<5xf32>, vector<5xf32>) {
+  // Write + read to/from tensor.
+  //      CHECK: vector.transfer_write
+  // CHECK-SAME: {__inplace_results_attr__ = ["false"]
+  %1 = vector.transfer_write %v1, %t1[%idx2] : vector<5xf32>, tensor<?xf32>
+  %cst = arith.constant 0.0 : f32
+  %r1 = vector.transfer_read %1[%idx3], %cst : tensor<?xf32>, vector<5xf32>
+
+  // Write + read to/from same memref.
+  %0 = bufferization.to_memref %t1 : memref<?xf32>
+  vector.transfer_write %v1, %0[%idx1] : vector<5xf32>, memref<?xf32>
+  %r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32>
+
+  return %r1, %r2 : vector<5xf32>, vector<5xf32>
+}

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
index c0a91e3f1df83..edeb0c07da0f2 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir
@@ -167,3 +167,23 @@ func @main() -> tensor<4xi32> {
   }
   return %r: tensor<4xi32>
 }
+
+// -----
+
+func @to_memref_op_is_writing(
+    %t1: tensor<?xf32> {linalg.inplaceable = true}, %idx1: index,
+    %idx2: index, %idx3: index, %v1: vector<5xf32>) -> (vector<5xf32>, vector<5xf32>) {
+  // This is a RaW conflict because to_memref is an inplace write and %t1 is
+  // read further down. This will likely have to change with partial
+  // bufferization.
+
+  // expected-error @+1 {{input IR has RaW conflict}}
+  %0 = bufferization.to_memref %t1 : memref<?xf32>
+
+  // Read from both.
+  %cst = arith.constant 0.0 : f32
+  %r1 = vector.transfer_read %t1[%idx3], %cst : tensor<?xf32>, vector<5xf32>
+  %r2 = vector.transfer_read %0[%idx3], %cst : memref<?xf32>, vector<5xf32>
+
+  return %r1, %r2 : vector<5xf32>, vector<5xf32>
+}

diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index a68aa7b74150e..7667ce971aab3 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -6673,10 +6673,12 @@ cc_library(
 cc_library(
     name = "ComprehensiveBufferize",
     srcs = [
+        "lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp",
         "lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp",
         "lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp",
     ],
     hdrs = [
+        "include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.h",
         "include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h",
         "include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h",
     ],


        


More information about the Mlir-commits mailing list