[Mlir-commits] [mlir] 6cdd34b - [mlir][tensor][bufferize] Bufferize inserts into equivalent tensors in-place

Matthias Springer llvmlistbot at llvm.org
Wed Oct 5 23:14:52 PDT 2022


Author: Matthias Springer
Date: 2022-10-06T15:06:33+09:00
New Revision: 6cdd34b9739a41caeda14f63bceab6fec7fd0ae5

URL: https://github.com/llvm/llvm-project/commit/6cdd34b9739a41caeda14f63bceab6fec7fd0ae5
DIFF: https://github.com/llvm/llvm-project/commit/6cdd34b9739a41caeda14f63bceab6fec7fd0ae5.diff

LOG: [mlir][tensor][bufferize] Bufferize inserts into equivalent tensors in-place

Inserting a tensor into an equivalent tensor is a no-op after bufferization. No alloc is needed.

Differential Revision: https://reviews.llvm.org/D132662

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index cd2be2925c68..2c9dd66c45c4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -390,8 +390,12 @@ class AnalysisState {
   /// In the above example, Values with a star satisfy the condition. When
   /// starting the traversal from Value 1, the resulting SetVector is:
   /// { 2, 7, 8, 5 }
-  SetVector<Value> findValueInReverseUseDefChain(
-      Value value, llvm::function_ref<bool(Value)> condition) const;
+  ///
+  /// If `followEquivalentOnly` is set, only equivalent OpOperands are followed.
+  SetVector<Value>
+  findValueInReverseUseDefChain(Value value,
+                                llvm::function_ref<bool(Value)> condition,
+                                bool followEquivalentOnly = false) const;
 
   /// Find the Values of the last preceding write of a given Value.
   ///

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 0d3495100fda..3c135068af1a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -398,7 +398,8 @@ bool AnalysisState::isValueRead(Value value) const {
 // evaluates to true. OpOperands of such matching Values are not traversed any
 // further.
 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
-    Value value, llvm::function_ref<bool(Value)> condition) const {
+    Value value, llvm::function_ref<bool(Value)> condition,
+    bool followEquivalentOnly) const {
   llvm::SetVector<Value> result, workingSet;
   workingSet.insert(value);
 
@@ -410,8 +411,19 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
     }
 
     OpResult opResult = value.cast<OpResult>();
+    BufferizableOpInterface bufferizableOp =
+        options.dynCastBufferizableOp(opResult.getDefiningOp());
     SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
-    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
+
+    // Stop iterating in any of these cases:
+    // * The current op is not bufferizable or excluded in the filter.
+    // * There are no OpOperands to follow.
+    // * There is an OpOperand, but it is not an equivalent tensor (only if
+    //   `followEquivalentOnly` is set).
+    if (!bufferizableOp || opOperands.empty() ||
+        (followEquivalentOnly &&
+         bufferizableOp.bufferRelation(opResult, *this) !=
+             BufferRelation::Equivalent)) {
       result.insert(value);
       continue;
     }

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 8a92129675ca..16e84a42064f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -611,9 +611,9 @@ struct InsertOpInterface
 /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
 /// equivalent operand / result and same offset/sizes/strides specification).
 template <typename OpTy>
-static bool areEquivalentExtractSliceOps(const AnalysisState &state,
-                                         ExtractSliceOp extractSliceOp,
-                                         OpTy insertSliceOp) {
+static bool areEquivalentSlices(const AnalysisState &state,
+                                ExtractSliceOp extractSliceOp,
+                                OpTy insertSliceOp) {
   if (!extractSliceOp || !insertSliceOp)
     return false;
   if (extractSliceOp != insertSliceOp &&
@@ -626,20 +626,31 @@ static bool areEquivalentExtractSliceOps(const AnalysisState &state,
   return true;
 }
 
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
+/// Return true if `value` is originating from the InsertSliceOp's destination
+/// or an ExtractSliceOp that matches the given InsertSliceOp.
 template <typename OpTy>
-static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
-                                      OpTy insertSliceOp) {
-  auto condition = [&](Value val) {
+static bool matchesInsertDestination(const AnalysisState &state, Value value,
+                                     OpTy insertSliceOp) {
+  // Look for matching slices.
+  auto matchesSlice = [&](Value val) {
     if (auto extractSliceOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentExtractSliceOps(state, extractSliceOp, insertSliceOp))
+      if (areEquivalentSlices(state, extractSliceOp, insertSliceOp))
         return true;
     return false;
   };
+  if (llvm::all_of(state.findValueInReverseUseDefChain(value, matchesSlice),
+                   matchesSlice))
+    return true;
 
-  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
-                      condition);
+  // Look for equivalent values.
+  auto isEquivalent = [&](Value val) {
+    return state.areEquivalentBufferizedValues(val, insertSliceOp.getDest());
+  };
+  if (llvm::all_of(state.findValueInReverseUseDefChain(
+                       value, isEquivalent, /*followEquivalentOnly=*/true),
+                   isEquivalent))
+    return true;
+  return false;
 }
 
 template <typename OpTy>
@@ -661,8 +672,8 @@ static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
 
     // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
     if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-        hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
-                                  insertSliceOp))
+        matchesInsertDestination(state, uConflictingWrite->get(),
+                                 insertSliceOp))
       // Case 1: The main insight is that InsertSliceOp reads only part of
       // the destination tensor. The overwritten area is not read. If
       // uConflictingWrite writes into exactly the memory location that is
@@ -679,7 +690,7 @@ static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
 
     if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
         uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-        hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
+        matchesInsertDestination(state, uRead->get(), insertSliceOp))
       // Case 2: The read of the source tensor and the write to the dest
       // tensor via an InsertSliceOp is not a conflict if the read is
       // reading exactly that part of an equivalent tensor that the
@@ -712,8 +723,8 @@ static bool isNotConflictingInsertSliceLikeOp(Operation *op, OpOperand *uRead,
     if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
         state.areEquivalentBufferizedValues(uRead->get(),
                                             insertSliceOp.getSource()) &&
-        hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
-                                  insertSliceOp))
+        matchesInsertDestination(state, insertSliceOp.getSource(),
+                                 insertSliceOp))
       return true;
 
   return false;

diff  --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index 3fc9f1ce1fc9..e80027013d1c 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -126,15 +126,12 @@ func.func @insert_slice_fun_not_inplace(
 
 // -----
 
-// CHECK-LABEL: func @tensor_cast_not_in_place(
-//  CHECK-SAME:     %[[A:.*]]: memref<?xf32{{.*}}>, %[[B:.*]]: memref<?xf32{{.*}}>
-//       CHECK:   %[[alloc:.*]] = memref.alloc
-//       CHECK:   memref.copy %[[A]], %[[alloc]]
+// CHECK-LABEL: func @tensor_cast_in_place(
+//  CHECK-SAME:     %[[A:.*]]: memref<?xf32{{.*}}>
 //       CHECK:   %[[subview:.*]] = memref.subview %[[A]][{{.*}}] [4] [1] : {{.*}} to memref<4xf32
-//       CHECK:   memref.copy %[[alloc]], %[[subview]]
-func.func @tensor_cast_not_in_place(
-    %A : tensor<?xf32> {bufferization.writable = true},
-    %B : tensor<?xf32> {bufferization.writable = false}, %idx: index)
+//       CHECK:   memref.copy %[[A]], %[[subview]]
+func.func @tensor_cast_in_place(
+    %A : tensor<?xf32> {bufferization.writable = true}, %idx: index)
   -> (tensor<?xf32>)
 {
   %r0 = tensor.cast %A : tensor<?xf32> to tensor<4xf32>
@@ -243,3 +240,16 @@ func.func @dealloc_pad_buffer(%t1: tensor<?x10xindex>, %l2: index, %h1: index,
   %r = tensor.extract %0[%idx, %idx] : tensor<?x?xindex>
   return %r : index
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_equivalent_tensor
+func.func @insert_equivalent_tensor(%t: tensor<10xf32>) -> tensor<10xf32> {
+  // CHECK-NOT: memref.alloc
+  %cst = arith.constant 4.200000e+01 : f32
+  // CHECK: linalg.fill
+  %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<10xf32>) -> tensor<10xf32>
+  // CHECK-NOT: memref.copy
+  %1 = tensor.insert_slice %0 into %t[0][10][1] : tensor<10xf32> into tensor<10xf32>
+  return %1 : tensor<10xf32>
+}


        


More information about the Mlir-commits mailing list