[Mlir-commits] [mlir] 87c770b - [mlir][bufferization][NFC] Put inplacability conflict resolution in op interface

Matthias Springer llvmlistbot at llvm.org
Thu Jun 9 13:10:26 PDT 2022


Author: Matthias Springer
Date: 2022-06-09T22:06:44+02:00
New Revision: 87c770bbd04462369950b1e6940e4f9ee4fc6de3

URL: https://github.com/llvm/llvm-project/commit/87c770bbd04462369950b1e6940e4f9ee4fc6de3
DIFF: https://github.com/llvm/llvm-project/commit/87c770bbd04462369950b1e6940e4f9ee4fc6de3.diff

LOG: [mlir][bufferization][NFC] Put inplacability conflict resolution in op interface

The TensorCopyInsertion pass resolves out-of-place bufferization decisions by inserting explicit `bufferization.alloc_tensor` ops. This change moves that functionality into a new BufferizableOpInterface method, so that it can be overridden by op implementations. Some op bufferizations must insert additional `alloc_tensor` ops to make sure that certain aliasing invariants are not violated (e.g., scf::ForOp). This will be addressed in a subsequent change.

Differential Revision: https://reviews.llvm.org/D126817

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index f69d612bbcdb9..e47e8478d4ad7 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -192,6 +192,32 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           llvm_unreachable("bufferRelation not implemented");
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Resolve all inplacability conflicts by inserting explicit
+          `bufferization.alloc_tensor` ops. Examples of inplacability conflicts
+          are read-after-write conflicts or writes into non-writable buffers.
+
+          This method should rewrite the IR in such a way that for each tensor
+          OpOperand t, buffer(t) can be used directly during bufferization.
+          Bufferization then no longer has to account for inplacability
+          conflicts.
+
+          This method can query analysis information from the given analysis
+          state.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"resolveConflicts",
+        /*args=*/(ins "RewriterBase &":$rewriter,
+                      "const AnalysisState &":$state),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          auto bufferizableOp =
+              cast<BufferizableOpInterface>($_op.getOperation());
+          return bufferizableOp.resolveTensorOpOperandConflicts(
+              rewriter, state);
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Bufferize this op, i.e., rewrite it into a memref-based equivalent.
@@ -301,6 +327,11 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
   ];
 
   let extraClassDeclaration = [{
+    /// Resolve out-of-place tensor OpOperands with explicit allocations in the
+    /// form of `bufferization.alloc_tensor` ops.
+    LogicalResult resolveTensorOpOperandConflicts(
+        RewriterBase &rewriter, const AnalysisState &state);
+
     /// Return `true` if the given OpOperand creates an alias but does neither
     /// read nor write. This implies that `bufferizesToMemoryRead` and
     /// `bufferizesToMemoryWrite` must return `false`. This method will never

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 387b5eb269398..de8e30415a6d1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -18,6 +18,10 @@
 #include "mlir/IR/Value.h"
 #include "llvm/Support/Debug.h"
 
+//===----------------------------------------------------------------------===//
+// BufferizableOpInterface
+//===----------------------------------------------------------------------===//
+
 namespace mlir {
 namespace bufferization {
 
@@ -38,6 +42,31 @@ using namespace bufferization;
 constexpr const ::llvm::StringLiteral
     bufferization::BufferizableOpInterface::kInplaceableAttrName;
 
+LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
+    RewriterBase &rewriter, const AnalysisState &state) {
+  // Create copies right before `op`; restore the caller's insertion point.
+  OpBuilder::InsertionGuard guard(rewriter);
+  Operation *op = getOperation();
+  rewriter.setInsertionPoint(op);
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    Type operandType = opOperand.get().getType();
+    if (!operandType.isa<TensorType>() || state.isInPlace(opOperand))
+      continue;
+    if (operandType.isa<UnrankedTensorType>())
+      return op->emitError("copies of unranked tensors are not supported");
+    // Ranked is the only remaining TensorType kind, so this cannot fail.
+    auto tensorType = operandType.cast<RankedTensorType>();
+    SmallVector<OpResult> aliasingOpResults =
+        state.getAliasingOpResult(opOperand);
+    bool escape = llvm::any_of(
+        aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
+    Value copy = rewriter.create<AllocTensorOp>(
+        op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
+    rewriter.updateRootInPlace(op, [&]() { opOperand.set(copy); });
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // OpFilter
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index 16b4b0b0d2ca7..e04f1e386ee91 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -43,7 +43,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
 LogicalResult
 mlir::bufferization::insertTensorCopies(Operation *op,
                                         const AnalysisState &state) {
-  OpBuilder builder(op->getContext());
+  IRRewriter rewriter(op->getContext());
   WalkResult result = op->walk([&](Operation *op) {
     auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
     if (!bufferizableOp)
@@ -55,31 +55,15 @@ mlir::bufferization::insertTensorCopies(Operation *op,
       if (allocTensorOp.escape())
         return WalkResult::advance();
       bool escape = state.isTensorYielded(allocTensorOp.result());
-      allocTensorOp.escapeAttr(builder.getBoolAttr(escape));
+      allocTensorOp.escapeAttr(rewriter.getBoolAttr(escape));
       return WalkResult::advance();
     }
 
-    // Find out-of-place tensor OpOperands and resolve them with an explicit
-    // tensor copy in the form of an AllocTensorOp.
-    builder.setInsertionPoint(op);
-    for (OpOperand &opOperand : op->getOpOperands()) {
-      if (opOperand.get().getType().isa<UnrankedTensorType>()) {
-        op->emitError("copies of unranked tensors are not supported");
-        return WalkResult::interrupt();
-      }
-      auto tensorType = opOperand.get().getType().dyn_cast<RankedTensorType>();
-      if (!tensorType)
-        continue;
-      if (state.isInPlace(opOperand))
-        continue;
-      SmallVector<OpResult> aliasingOpResults =
-          state.getAliasingOpResult(opOperand);
-      bool escape = llvm::any_of(
-          aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
-      Value copy = builder.create<AllocTensorOp>(
-          op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
-      opOperand.set(copy);
-    }
+    // Find inplacability conflicts and resolve them. (Typically with explicit
+    // tensor copies in the form of AllocTensorOps.)
+    rewriter.setInsertionPoint(op);
+    if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
+      return WalkResult::interrupt();
 
     return WalkResult::advance();
   });


        


More information about the Mlir-commits mailing list