[Mlir-commits] [mlir] 87b4677 - [mlir][bufferize] Improve resolveConflicts for ExtractSliceOp

Matthias Springer llvmlistbot at llvm.org
Thu Jun 9 13:19:45 PDT 2022


Author: Matthias Springer
Date: 2022-06-09T22:19:37+02:00
New Revision: 87b46776c44e0ef74f97c671eb39c317d8242769

URL: https://github.com/llvm/llvm-project/commit/87b46776c44e0ef74f97c671eb39c317d8242769
DIFF: https://github.com/llvm/llvm-project/commit/87b46776c44e0ef74f97c671eb39c317d8242769.diff

LOG: [mlir][bufferize] Improve resolveConflicts for ExtractSliceOp

It is sometimes better to make a copy of the OpResult instead of the OpOperand, e.g., when bufferizing tensor.extract_slice.
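
As a minimal sketch (mirroring the new test case added below; SSA value names are illustrative only), tensor copy insertion can now resolve an out-of-place extract_slice by copying its small result instead of its potentially large source:

  %0 = tensor.extract_slice %t[10][5][1] : tensor<?xf32> to tensor<5xf32>
  %copy = bufferization.alloc_tensor() copy(%0) {escape = false} : tensor<5xf32>
  %1 = tensor.insert %f into %copy[%idx] : tensor<5xf32>

Copying the OpOperand instead would have required an allocation for the entire tensor<?xf32> source.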

This change will eventually make parts of extract_slice's `bufferize` implementation obsolete (and simplify it): that implementation will then only need to handle in-place OpOperands.

Differential Revision: https://reviews.llvm.org/D126819

Added: 
    mlir/test/Dialect/Tensor/one-shot-bufferize-tensor-copy-insertion.mlir

Modified: 
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index de8e30415a6d1..eb9b5e4cca8be 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -44,7 +44,12 @@ constexpr const ::llvm::StringLiteral
 
 LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
     RewriterBase &rewriter, const AnalysisState &state) {
+  OpBuilder::InsertionGuard g(rewriter);
   Operation *op = getOperation();
+  SmallVector<OpOperand *> outOfPlaceOpOperands;
+  SmallVector<OpResult> outOfPlaceOpResults;
+
+  // Find all out-of-place OpOperands.
   for (OpOperand &opOperand : op->getOpOperands()) {
     Type operandType = opOperand.get().getType();
     if (!operandType.isa<TensorType>())
@@ -53,17 +58,52 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
       continue;
     if (operandType.isa<UnrankedTensorType>())
       return op->emitError("copies of unranked tensors are not supported");
-    auto tensorType = operandType.dyn_cast<RankedTensorType>();
-    if (!tensorType)
-      continue;
+
     SmallVector<OpResult> aliasingOpResults =
         state.getAliasingOpResult(opOperand);
+    if (aliasingOpResults.size() == 1 &&
+        !state.bufferizesToMemoryWrite(opOperand) &&
+        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
+      // The op itself does not write but may create exactly one alias. Instead
+      // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
+      // be smaller than the OpOperand (e.g., in the case of an extract_slice,
+      // where the result is usually a smaller part of the source).
+      outOfPlaceOpResults.push_back(aliasingOpResults.front());
+    } else {
+      // In all other cases, make a copy of the OpOperand.
+      outOfPlaceOpOperands.push_back(&opOperand);
+    }
+  }
+
+  // Insert copies of OpOperands.
+  rewriter.setInsertionPoint(op);
+  for (OpOperand *opOperand : outOfPlaceOpOperands) {
+    auto tensorType = opOperand->get().getType().cast<RankedTensorType>();
+    SmallVector<OpResult> aliasingOpResults =
+        state.getAliasingOpResult(*opOperand);
     bool escape = llvm::any_of(
         aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
     Value copy = rewriter.create<AllocTensorOp>(
-        op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
-    rewriter.updateRootInPlace(op, [&]() { opOperand.set(copy); });
+        op->getLoc(), tensorType, ValueRange(), opOperand->get(), escape);
+    rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
+  }
+
+  // Insert copies of OpResults.
+  rewriter.setInsertionPointAfter(op);
+  for (OpResult opResult : outOfPlaceOpResults) {
+    auto tensorType = opResult.getType().cast<RankedTensorType>();
+    bool escape = state.isTensorYielded(opResult);
+    Value copy = rewriter.create<AllocTensorOp>(op->getLoc(), tensorType,
+                                                ValueRange(), opResult, escape);
+    SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
+        opResult.getUses(), [](OpOperand &use) { return &use; }));
+    for (OpOperand *use : uses) {
+      // Do not update the alloc_tensor op that we just created.
+      if (use->getOwner() != copy.getDefiningOp())
+        rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
+    }
   }
+
   return success();
 }
 

diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize-tensor-copy-insertion.mlir
new file mode 100644
index 0000000000000..d399c65b441f5
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize-tensor-copy-insertion.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt %s -tensor-copy-insertion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -tensor-copy-insertion="bufferize-function-boundaries allow-return-allocs" -split-input-file | FileCheck %s --check-prefix=CHECK-FUNC
+
+// CHECK-LABEL: func @extract_slice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?xf32>
+// CHECK-FUNC-LABEL: func @extract_slice(
+func.func @extract_slice(%t: tensor<?xf32>, %idx: index, %f: f32)
+  -> (tensor<5xf32>, tensor<?xf32>)
+{
+  // CHECK: %[[extract_slice:.*]] = tensor.extract_slice %[[t]][10] [5] [1]
+  %0 = tensor.extract_slice %t[10][5][1] : tensor<?xf32> to tensor<5xf32>
+  // CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() copy(%[[extract_slice]]) {escape = false} : tensor<5xf32>
+  // CHECK-FUNC: bufferization.alloc_tensor() copy(%{{.*}}) {escape = true} : tensor<5xf32>
+  // CHECK: %[[insert:.*]] = tensor.insert %{{.*}} into %[[alloc]]
+  %1 = tensor.insert %f into %0[%idx] : tensor<5xf32>
+  // CHECK: return %[[insert]], %[[t]]
+  return %1, %t : tensor<5xf32>, tensor<?xf32>
+}

