[Mlir-commits] [mlir] 08aa956 - [mlir][bufferization]-Replace only one use in TensorEmptyElimination (#118958)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 18 13:57:18 PST 2024


Author: Amir Bishara
Date: 2024-12-18T23:57:13+02:00
New Revision: 08aa95638713a37407367e0e158df6fb82509725

URL: https://github.com/llvm/llvm-project/commit/08aa95638713a37407367e0e158df6fb82509725
DIFF: https://github.com/llvm/llvm-project/commit/08aa95638713a37407367e0e158df6fb82509725.diff

LOG: [mlir][bufferization]-Replace only one use in TensorEmptyElimination (#118958)

In many cases, EmptyTensorElimination cannot transform or eliminate a
`tensor.empty` whose value flows into the source of an op implementing
`SubsetInsertionOpInterface`.

There are two main reasons for this:

1- The pass fails to find a legal insertion point for the
`subsetExtract` op that is about to replace the empty tensor. This could
be mitigated by moving the values needed to build the `subsetExtract`
closer to the empty tensor being eliminated, which increases the chance
of finding a legal insertion point.

2- EmptyTensorElimination replaces all uses of the `tensor.empty` in a
single step, rather than only the specific use that was visited in the
use-def chain (when traversing from the `tensor.insert_slice`).
Replacing every use can introduce additional read effects on the subset
extract/subview after bufferization, which should not happen.

Both cases can result in many copies during the subsequent
bufferization that cannot be canonicalized away.

The first case shows up when a `tensor.empty` is followed by a
`SubsetInsertionOpInterface` op (in simple terms, `tensor.insert_slice`)
that was lowered from `tensor/tosa.concat`.
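To make the first case concrete, here is a minimal C++ sketch of the insertion-point search performed by `findValidInsertionPoint`, under a simplifying assumption: every op sits in one straight-line block and is identified by its index, so dominance reduces to comparing indices. The function name mirrors the pass, but this signature, the index-based model and the example positions are hypothetical. The checks follow the patched code: every needed value must properly dominate the candidate, and the candidate must dominate the one use being replaced.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical toy model (not the MLIR API): all ops live in one
// straight-line block and are identified by their index in it.
//  - A value defined by the op at index `def` properly dominates the op at
//    index `q` iff def < q.
//  - The op at index `q` dominates the op at index `u` iff q <= u.
// Returns the index of the op before which the subset extraction would be
// inserted, or std::nullopt if no candidate is legal.
std::optional<int> findValidInsertionPoint(int emptyTensorPos, int userPos,
                                           const std::vector<int> &neededDefs) {
  assert(emptyTensorPos >= 0 && emptyTensorPos <= userPos &&
         "tensor.empty must come before its user");
  // Candidates: the tensor.empty itself, then right after each needed def.
  std::vector<int> candidates{emptyTensorPos};
  for (int def : neededDefs)
    candidates.push_back(def + 1);

  for (int q : candidates) {
    // All needed values (e.g. the insert_slice destination) must be in scope.
    bool inScope = true;
    for (int def : neededDefs)
      if (def >= q)
        inScope = false;
    if (!inScope)
      continue;
    // The candidate must come before the single use being replaced.
    if (q > userPos)
      continue;
    return q;
  }
  return std::nullopt;
}
```

With the op order of `@fail_to_eliminate_any_empty_tensors` (`%empty_1` at 0, its `linalg.fill` at 1, `%cancatenated_empty` at 4), the only candidate that has the destination in scope comes after the fill, so `findValidInsertionPoint(0, 1, {4})` finds nothing. In `@succeed_to_eliminate_one_empty_tensor` the destination is defined first, and `findValidInsertionPoint(1, 2, {0})` accepts index 1.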

The second case shows up when a `tensor.empty` has many uses: the
transformation is applied only once, because all uses are replaced at
once.
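The difference between the old and new replacement strategies can be sketched with a hypothetical toy IR in C++. The `Op` struct and both helpers are illustrative, not MLIR APIs: the old path behaves like `rewriter.replaceOp` on the `tensor.empty`, rewriting every operand that refers to it, while the new path rewrites a single operand, as the patch does with `setOperand` inside `modifyOpInPlace`. The example encodes the two uses of `%empty_1` from `@mutli_use_of_the_same_tensor_empty_creates_non_existent_read`.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy IR: an op is a name plus the names of its operand values.
struct Op {
  std::string name;
  std::vector<std::string> operands;
};

// Old strategy: rewrite every operand that refers to `from`
// (akin to replacing the tensor.empty op itself).
void replaceAllUses(std::vector<Op> &ops, const std::string &from,
                    const std::string &to) {
  for (Op &op : ops)
    for (std::string &value : op.operands)
      if (value == from)
        value = to;
}

// New strategy: rewrite only operand `operandNo` of `ops[opIdx]`, i.e. the
// single use found while walking the use-def chain.
void replaceOneUse(std::vector<Op> &ops, std::size_t opIdx,
                   std::size_t operandNo, const std::string &to) {
  assert(opIdx < ops.size() && operandNo < ops[opIdx].operands.size());
  ops[opIdx].operands[operandNo] = to;
}

// The two uses of %empty_1 in the lit test
// @mutli_use_of_the_same_tensor_empty_creates_non_existent_read.
std::vector<Op> makeExample() {
  return {
      {"linalg.fill", {"%cst_1", "%empty_1"}},   // outs(%empty_1): traced use
      {"linalg.generic", {"%empty_1", "%arg2"}}, // ins(%empty_1): a read
  };
}
```

With `replaceAllUses`, the `linalg.generic` would also start reading the injected subset extraction, which is the extra read this commit avoids; `replaceOneUse` leaves that operand on `%empty_1`.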

The first commit in this PR only adds lit tests for the cases above
(NFC) to show how the transform currently behaves; follow-up MRs will
make small changes to handle these cases.

The second commit in this PR replaces only the specific use that was
visited in the `use-def` chain (when traversing from the
`tensor.insert_slice`'s source).
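The tracking step can be sketched as follows, again as a hypothetical C++ toy rather than the real API. Unlike `findValueInReverseUseDefChain`, which uses a worklist and may branch, this walk follows a single aliasing operand per value; it records every operand it crosses, and the caller then picks the recorded operand whose value is the `tensor.empty`. `OperandRef`, `findEmptyTensor` and `findUseToBeReplaced` are illustrative names.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <set>
#include <string>
#include <tuple>

// Hypothetical toy model: operand #operandNo of op `owner`, holding the value
// named `operandValue`.
struct OperandRef {
  std::string owner;
  int operandNo;
  std::string operandValue;
  bool operator<(const OperandRef &other) const {
    return std::tie(owner, operandNo, operandValue) <
           std::tie(other.owner, other.operandNo, other.operandValue);
  }
};

// Walk backwards from `source`, following the aliasing operand recorded for
// each value in `aliasingOperand`, until a value in `emptyTensors` is found.
// Every operand crossed on the way is inserted into `visited`.
std::optional<std::string>
findEmptyTensor(const std::string &source,
                const std::map<std::string, OperandRef> &aliasingOperand,
                const std::set<std::string> &emptyTensors,
                std::set<OperandRef> &visited) {
  // Mirrors clearing `visitedOpOperands` before each insertion op.
  assert(visited.empty() && "expected a fresh set per insertion op");
  std::string value = source;
  while (!emptyTensors.count(value)) {
    auto it = aliasingOperand.find(value);
    // Stop at a leaf that is not a tensor.empty, or when revisiting (cycle).
    if (it == aliasingOperand.end() || !visited.insert(it->second).second)
      return std::nullopt;
    value = it->second.operandValue;
  }
  return value;
}

// Pick the recorded operand that is a use of `emptyTensor`, if any.
std::optional<OperandRef>
findUseToBeReplaced(const std::set<OperandRef> &visited,
                    const std::string &emptyTensor) {
  for (const OperandRef &ref : visited)
    if (ref.operandValue == emptyTensor)
      return ref;
  return std::nullopt;
}
```

Starting from `%res_2` in `@mutli_use_of_the_same_tensor_empty`, only the `outs` operand of the second `linalg.fill` is recorded, so that is the one use replaced. If the insertion op's source is the `tensor.empty` itself, nothing is recorded and no use is returned, which matches the `continue` in the patched loop.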

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 4866e31b19d5de..983f7a29cb2206 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -459,7 +459,8 @@ class AnalysisState {
   /// Starting from `value`, follow the use-def chain in reverse, always
   /// selecting the aliasing OpOperands. Find and return Values for which
   /// `condition` evaluates to true. OpOperands of such matching Values are not
-  /// traversed any further.
+  /// traversed any further, the visited aliasing opOperands will be preserved
+  /// through `visitedOpOperands`.
   ///
   /// When reaching the end of a chain, also return the last Value of that
   /// chain if `config.alwaysIncludeLeaves` is set.
@@ -484,7 +485,8 @@ class AnalysisState {
   /// `config`.
   SetVector<Value> findValueInReverseUseDefChain(
       Value value, llvm::function_ref<bool(Value)> condition,
-      TraversalConfig config = TraversalConfig()) const;
+      TraversalConfig config = TraversalConfig(),
+      llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;
 
   /// Find the values that may define the contents of the given value at
   /// runtime. A block argument is always a definition. An OpResult is a

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 065739ea8e5951..f8a7a22787404b 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -483,10 +483,12 @@ bool AnalysisState::isValueRead(Value value) const {
 // Starting from `value`, follow the use-def chain in reverse, always selecting
 // the aliasing OpOperands. Find and return Values for which `condition`
 // evaluates to true. OpOperands of such matching Values are not traversed any
-// further.
+// further, the visited aliasing opOperands will be preserved through
+// `visitedOpOperands`.
 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
     Value value, llvm::function_ref<bool(Value)> condition,
-    TraversalConfig config) const {
+    TraversalConfig config,
+    llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
   llvm::DenseSet<Value> visited;
   llvm::SetVector<Value> result, workingSet;
   workingSet.insert(value);
@@ -553,6 +555,8 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
       }
 
       workingSet.insert(a.opOperand->get());
+      if (visitedOpOperands)
+        visitedOpOperands->insert(a.opOperand);
     }
   }
 

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index cb2efef5c038b1..abc0635a2cdff0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -48,27 +48,20 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
   return true;
 }
 
-/// Return true if the given `insertionPoint` dominates all uses of
-/// `emptyTensorOp`.
-static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
-                                        Operation *insertionPoint,
-                                        Operation *emptyTensorOp) {
-  return llvm::all_of(emptyTensorOp->getUsers(), [&](Operation *user) {
-    return domInfo.dominates(insertionPoint, user);
-  });
-}
-
-/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
-/// that the replacement may use any value from `neededValues`.
+/// Find a valid insertion point for a replacement of `emptyTensorOp`'s
+/// use of `user` operation, assuming that the replacement may use any
+/// value from `neededValues`.
 static Operation *
-findValidInsertionPoint(Operation *emptyTensorOp,
+findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
                         const SmallVector<Value> &neededValues) {
   DominanceInfo domInfo;
+  Operation *candidateInsertionPoint = emptyTensorOp;
 
-  // Gather all possible insertion points: the location of `emptyTensorOp` and
-  // right after the definition of each value in `neededValues`.
+  // Gather all possible insertion points: the location of
+  // `candidateInsertionPoint` and right after the definition of each value in
+  // `neededValues`.
   SmallVector<Operation *> insertionPointCandidates;
-  insertionPointCandidates.push_back(emptyTensorOp);
+  insertionPointCandidates.push_back(candidateInsertionPoint);
   for (Value val : neededValues) {
     // Note: The anchor op is using all of `neededValues`, so:
     // * in case of a block argument: There must be at least one op in the block
@@ -90,8 +83,8 @@ findValidInsertionPoint(Operation *emptyTensorOp,
     if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                             neededValues))
       continue;
-    // Check if the insertion point is before all uses.
-    if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
+    // Check if the insertion point is before the use to be replaced.
+    if (!domInfo.dominates(insertionPoint, user))
       continue;
     return insertionPoint;
   }
@@ -103,8 +96,9 @@ findValidInsertionPoint(Operation *emptyTensorOp,
 LogicalResult mlir::bufferization::eliminateEmptyTensors(
     RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
   OpBuilder::InsertionGuard g(rewriter);
-
+  llvm::DenseSet<OpOperand *> visitedOpOperands;
   op->walk([&](SubsetInsertionOpInterface op) {
+    visitedOpOperands.clear();
     OpOperand &source = op.getSourceOperand();
     // Skip operands that do not bufferize inplace. "tensor.empty" could still
     // be replaced, but the transformation may not be beneficial.
@@ -131,16 +125,28 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
     config.followSameTypeOrCastsOnly = true;
     SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
         source.get(), /*condition=*/
-        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
-        config);
+        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
+        &visitedOpOperands);
 
     for (Value v : emptyTensors) {
       Operation *emptyTensorOp = v.getDefiningOp();
 
+      // Find the use to be replaced from the use-def chain.
+      auto iter = llvm::find_if(
+          visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
+            return llvm::count(emptyTensorOp->getUses(), *opOperand);
+          });
+      // This could be achieved when a use of `emptyTensorOp` is being
+      // consumed by `SubsetInsertionOpInterface`'s source directly.
+      if (iter == visitedOpOperands.end())
+        continue;
+      OpOperand *useToBeReplaced = *iter;
+      Operation *user = useToBeReplaced->getOwner();
+
       // Find a suitable insertion point. If no suitable insertion point for
       // the replacement can be found, skip this replacement.
       Operation *insertionPoint =
-          findValidInsertionPoint(emptyTensorOp, neededValues);
+          findValidInsertionPoint(emptyTensorOp, user, neededValues);
       if (!insertionPoint)
         continue;
 
@@ -159,8 +165,10 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
         replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
                                                       replacement);
       }
-      // Replace the tensor::EmptyOp.
-      rewriter.replaceOp(emptyTensorOp, replacement);
+      // Replace the specific use of the tensor::EmptyOp.
+      rewriter.modifyOpInPlace(user, [&]() {
+        user->setOperand(useToBeReplaced->getOperandNumber(), replacement);
+      });
       state.resetCache();
     }
 

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
index 2ba8246a8d5254..9150986f4c2a2a 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
@@ -55,6 +55,7 @@ func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<
   //      CHECK: tensor.extract_slice
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
   %cst = arith.constant 0.000000e+00 : f32
+  //      CHECK: bufferization.alloc_tensor(%arg1)
   %0 = tensor.empty(%arg1) : tensor<?xf32>
 
   //      CHECK: bufferization.alloc_tensor(%arg1)

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index efe59af97d9649..26434774730e1b 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -365,3 +365,103 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
   bufferization.materialize_in_destination %selected in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
   return
 }
+
+// -----
+
+// `EmptyTensorElimination` fails to find a valid insertion
+// point for the new injected `SubsetExtraction`.
+// CHECK-LABEL:   func.func @fail_to_eliminate_any_empty_tensors
+func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  // CHECK: memref.alloc
+  // CHECK: memref.alloc
+  // CHECK: memref.alloc
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %empty_2 = tensor.empty() : tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_2 : tensor<5x6x128xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @succeed_to_eliminate_one_empty_tensor
+func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
+  // CHECK: memref.alloc
+  // CHECK-NOT: memref.alloc
+  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %empty_2 = tensor.empty() : tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_2 : tensor<5x6x128xf32>
+}
+
+// -----
+
+// `EmptyTensorElimination` will replace the specific use of the tensor
+// empty with the new injected `SubsetExtraction`, i.e. the specific use
+// which has been tracked.
+
+// CHECK-ELIM-LABEL:   func.func @mutli_use_of_the_same_tensor_empty
+// CHECK-LABEL:   func.func @mutli_use_of_the_same_tensor_empty
+func.func @mutli_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
+  // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
+  // CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  // CHECK-NOT: memref.copy
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_2 : tensor<5x6x128xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read
+// CHECK-ELIM-LABEL:   func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read
+func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: tensor<5x6x128xf32> , %arg2: tensor<5x6x64xf32>)
+    -> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) {
+  %cst_1 = arith.constant 1.0 : f32
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x64xf32>
+  // CHECK-NOT: memref.alloc
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %res_2 = linalg.generic{
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  }
+  ins(%empty_1 : tensor<5x6x64xf32>)
+  outs(%arg2 :tensor<5x6x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %res = arith.addf %in, %in : f32
+    linalg.yield %res : f32
+  } -> tensor<5x6x64xf32>
+  // CHECK-NOT: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %arg1[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
+}


        

