[Mlir-commits] [mlir] [mlir][bufferization]-Support unhandled cases in EmptyTensorElimination (PR #118958)

Amir Bishara llvmlistbot at llvm.org
Wed Dec 11 07:47:25 PST 2024


https://github.com/amirBish updated https://github.com/llvm/llvm-project/pull/118958

>From c4577b1422eeae39d498533dfb1e933525b2b8fe Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Wed, 4 Dec 2024 23:50:29 +0200
Subject: [PATCH 1/4] [mlir][bufferization]-Add lit tests for unhandled cases
 in EmptyTensorElimination

In many cases, EmptyTensorElimination cannot transform or eliminate
the empty tensor that is inserted through a
`SubsetInsertionOpInterface` op.

There are two major reasons for that:

1- It fails to find a legal/suitable insertion point for the
`subsetExtract` which is about to replace the empty tensor. We can
mitigate this by moving the values needed to build the `subsetExtract`
near the empty tensor (which is about to be eliminated), thus
increasing the probability of finding a legal insertion point.

2- The EmptyTensorElimination transform replaces all of the
tensor.empty's uses at once, rather than replacing only the specific
use which was visited in the use-def chain (when traversing from the
tensor.insert_slice). Replacing all the uses of the tensor.empty may
introduce additional read effects after bufferization of the specific
subset extract/subview, which should not be the case.

Both cases may result in many copies in the subsequent bufferization
which cannot be canonicalized away.

The first case can be observed when a `tensor.empty` is followed by a
`SubsetInsertionOpInterface` op (or, in simple words, a
`tensor.insert_slice`) which has been lowered from `tensor/tosa.concat`.

The second case can be observed when a `tensor.empty` has many uses;
the transformation is then applied only once, since all of the uses
have been replaced at once.

This MR only adds the lit tests for the cases shown above (NFC), to
emphasize how the transform works. Upcoming MRs will upload slight
changes to handle these cases.
---
 ...ot-bufferize-empty-tensor-elimination.mlir | 98 +++++++++++++++++++
 1 file changed, 98 insertions(+)

diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index efe59af97d9649..9d9bb443160465 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -365,3 +365,101 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
   bufferization.materialize_in_destination %selected in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
   return
 }
+
+// -----
+
+// `EmptyTensorElimination` fails to find a valid insertion
+// point for the new injected `SubsetExtraction`.
+// CHECK-LABEL:   func.func @fail_to_eliminate_any_empty_tensors
+func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  // CHECK: memref.alloc
+  // CHECK: memref.alloc
+  // CHECK: memref.alloc
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %empty_2 = tensor.empty() : tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_2 : tensor<5x6x128xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @succeed_to_eliminate_one_empty_tensor
+func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  // CHECK: memref.alloc
+  // CHECK: memref.alloc
+  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %empty_2 = tensor.empty() : tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_2 : tensor<5x6x128xf32>
+}
+
+// -----
+
+// `EmptyTensorElimination` replaces all of the uses of the tensor
+// empty with the new injected `SubsetExtraction`, without to consider
+// the specific use has been tracked, sometimes creating a non existent
+// bufferization conflicts.
+
+// CHECK-ELIM-LABEL:   func.func @mutli_use_of_the_same_tensor_empty
+// CHECK-LABEL:   func.func @mutli_use_of_the_same_tensor_empty
+func.func @mutli_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
+  // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
+  // CHECK-ELIM: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  // CHECK: memref.copy
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_2 : tensor<5x6x128xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read
+func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: tensor<5x6x128xf32> , %arg2: tensor<5x6x64xf32>)
+    -> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) {
+  %cst_1 = arith.constant 1.0 : f32
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  // CHECK: memref.alloc
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %res_2 = linalg.generic{
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  }
+  ins(%empty_1 : tensor<5x6x64xf32>)
+  outs(%arg2 :tensor<5x6x64xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %res = arith.addf %in, %in : f32
+    linalg.yield %res : f32
+  } -> tensor<5x6x64xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %arg1[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
+}

>From 4a308069fcd2a81407b4ab5feab891c1ad224c6c Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Thu, 5 Dec 2024 16:14:00 +0200
Subject: [PATCH 2/4] [mlir][bufferization]-Try to move the needed values for
 subsetExtract in EmptyTensorElimination

In this MR, we handle the case where we may fail to find a
legal/suitable insertion point for the subsetExtract which is about
to replace the empty tensor.

For this reason, we now also try to move the values needed to create
the `subsetExtract` before the candidate insertion point (the
tensor.empty about to be eliminated).
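
The idea can be illustrated with a toy model; this is only a sketch,
not the MLIR implementation. It assumes a single straight-line block,
where "dominates" reduces to "appears earlier", and it models ops as
plain strings. The names `Block`, `Operands`, `indexOf` and
`hoistBefore` are hypothetical and exist only for this illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Toy model: op names in program order, and for each op the ops it reads.
using Block = std::vector<std::string>;
using Operands = std::map<std::string, std::vector<std::string>>;

// Position of `op` in the block (assumes `op` is present).
static std::size_t indexOf(const Block &block, const std::string &op) {
  return std::find(block.begin(), block.end(), op) - block.begin();
}

// Try to make `needed` available at `anchor` (assumes needed != anchor).
// Returns true if `needed` is (now) placed before `anchor`.
static bool hoistBefore(Block &block, const Operands &operands,
                        const std::string &needed,
                        const std::string &anchor) {
  std::size_t anchorIdx = indexOf(block, anchor);
  if (indexOf(block, needed) < anchorIdx)
    return true; // Already in scope at the anchor; nothing to do.

  // Moving is only safe if every operand of `needed` is already
  // defined before the anchor. Only one level is checked.
  auto it = operands.find(needed);
  if (it != operands.end())
    for (const std::string &operand : it->second)
      if (indexOf(block, operand) >= anchorIdx)
        return false;

  block.erase(block.begin() + indexOf(block, needed));
  block.insert(block.begin() + indexOf(block, anchor), needed);
  return true;
}
```

In this model a destination `tensor.empty` without operands can be
hoisted above the empty tensor to be eliminated, whereas one that
depends on a value defined later (e.g. a `tensor.dim`) stays in place.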
---
 .../Transforms/EmptyTensorElimination.cpp     | 48 ++++++++++++++-----
 ...ot-bufferize-empty-tensor-elimination.mlir | 14 +++---
 2 files changed, 44 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index cb2efef5c038b1..3925e858f50514 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -28,23 +28,32 @@ namespace bufferization {
 using namespace mlir;
 using namespace mlir::bufferization;
 
+/// Return true if `val` is in scope at the given
+/// `insertionPoint`.
+static bool valueDominateInsertionPoint(const DominanceInfo &domInfo,
+                                        Operation *insertionPoint, Value val) {
+  if (auto bbArg = dyn_cast<BlockArgument>(val)) {
+    Block *owner = bbArg.getOwner();
+    if (!owner->findAncestorOpInBlock(*insertionPoint))
+      return false;
+  } else {
+    auto opResult = cast<OpResult>(val);
+    if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
+      return false;
+  }
+  return true;
+}
+
 /// Return true if all `neededValues` are in scope at the given
 /// `insertionPoint`.
 static bool
 neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
                                    Operation *insertionPoint,
                                    const SmallVector<Value> &neededValues) {
-  for (Value val : neededValues) {
-    if (auto bbArg = dyn_cast<BlockArgument>(val)) {
-      Block *owner = bbArg.getOwner();
-      if (!owner->findAncestorOpInBlock(*insertionPoint))
-        return false;
-    } else {
-      auto opResult = cast<OpResult>(val);
-      if (!domInfo.properlyDominates(opResult.getOwner(), insertionPoint))
-        return false;
-    }
-  }
+  for (Value val : neededValues)
+    if (!valueDominateInsertionPoint(domInfo, insertionPoint, val))
+      return false;
+
   return true;
 }
 
@@ -65,6 +74,23 @@ findValidInsertionPoint(Operation *emptyTensorOp,
                         const SmallVector<Value> &neededValues) {
   DominanceInfo domInfo;
 
+  // Trying to move the needed values before the `emptyTensorOp`.
+  for (Value val : neededValues) {
+    if (valueDominateInsertionPoint(domInfo, emptyTensorOp, val))
+      continue;
+    Operation *definingOp = val.getDefiningOp();
+    if (!definingOp)
+      continue;
+
+    bool isItSafeToMoveOp =
+        llvm::all_of(definingOp->getOperands(), [&](Value operand) {
+          return valueDominateInsertionPoint(domInfo, emptyTensorOp, operand);
+        });
+
+    if (isItSafeToMoveOp)
+      definingOp->moveBefore(emptyTensorOp);
+  }
+
   // Gather all possible insertion points: the location of `emptyTensorOp` and
   // right after the definition of each value in `neededValues`.
   SmallVector<Operation *> insertionPointCandidates;
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 9d9bb443160465..db29adef993959 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -368,21 +368,21 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
 
 // -----
 
-// `EmptyTensorElimination` fails to find a valid insertion
-// point for the new injected `SubsetExtraction`.
+// `EmptyTensorElimination` finds a valid insertion
+// point for the new injected `SubsetExtraction` by
+// trying to move the needed value for the extraction.
 // CHECK-LABEL:   func.func @fail_to_eliminate_any_empty_tensors
 func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
   %cst_1 = arith.constant 1.0 : f32
   %cst_2 = arith.constant 2.0 : f32
   // CHECK: memref.alloc
-  // CHECK: memref.alloc
-  // CHECK: memref.alloc
+  // CHECK-NOT: memref.alloc
   %empty_1 = tensor.empty() : tensor<5x6x64xf32>
   %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
   %empty_2 = tensor.empty() : tensor<5x6x64xf32>
   %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
   %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
-  // CHECK: memref.copy
+  // CHECK-NOT: memref.copy
   %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
       : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
@@ -397,13 +397,13 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
   %cst_1 = arith.constant 1.0 : f32
   %cst_2 = arith.constant 2.0 : f32
   // CHECK: memref.alloc
-  // CHECK: memref.alloc
+  // CHECK-NOT: memref.alloc
   %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
   %empty_1 = tensor.empty() : tensor<5x6x64xf32>
   %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
   %empty_2 = tensor.empty() : tensor<5x6x64xf32>
   %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
-  // CHECK: memref.copy
+  // CHECK-NOT: memref.copy
   %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
       : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]

>From 8087fc02219f25ee5c6c2a4347f618e97d8ed741 Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Fri, 6 Dec 2024 13:27:20 +0200
Subject: [PATCH 3/4] [mlir][bufferization]-Replace only one use in
 TensorEmptyElimination

This MR handles the second case, where we want to replace only the
specific use which was visited in the `use-def` chain (when
traversing from the tensor.insert_slice's source).

Replacing all the uses of the tensor.empty may introduce additional
read effects after bufferization of the specific subset
extract/subview, which should not be the case. Replacing only the
visited use avoids them, thus eliminating potential copies.
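
The difference between the two replacement strategies can be sketched
with a toy model (an illustration only, not the MLIR API). `ToyOp`,
`replaceAllUses` and `replaceOneUse` are hypothetical names, and values
are modelled as plain strings.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy operation: a name plus the names of the values it reads.
struct ToyOp {
  std::string name;
  std::vector<std::string> operands;
};

// Previous behaviour (toy): rewrite every use of `from` in the block.
static void replaceAllUses(std::vector<ToyOp> &block,
                           const std::string &from,
                           const std::string &to) {
  for (ToyOp &op : block)
    for (std::string &operand : op.operands)
      if (operand == from)
        operand = to;
}

// Behaviour after this change (toy): rewrite only the single operand
// slot that was visited while walking the use-def chain.
static void replaceOneUse(ToyOp &user, std::size_t operandNumber,
                          const std::string &to) {
  user.operands[operandNumber] = to;
}
```

With a fill writing into `empty_1` and an unrelated op reading
`empty_1`, `replaceAllUses` makes the unrelated op read the extracted
slice as well, while `replaceOneUse` leaves it reading `empty_1`.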
---
 .../IR/BufferizableOpInterface.h              |  6 +-
 .../IR/BufferizableOpInterface.cpp            |  8 ++-
 .../Transforms/EmptyTensorElimination.cpp     | 64 +++++++++++--------
 ...ize-analysis-empty-tensor-elimination.mlir |  7 +-
 ...ot-bufferize-empty-tensor-elimination.mlir | 25 ++++----
 5 files changed, 65 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 4866e31b19d5de..58afdd5df6a67b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -459,7 +459,8 @@ class AnalysisState {
   /// Starting from `value`, follow the use-def chain in reverse, always
   /// selecting the aliasing OpOperands. Find and return Values for which
   /// `condition` evaluates to true. OpOperands of such matching Values are not
-  /// traversed any further.
+  /// traversed any further. The visited aliasing OpOperands are recorded in
+  /// `visitedOpOperands`, if provided.
   ///
   /// When reaching the end of a chain, also return the last Value of that
   /// chain if `config.alwaysIncludeLeaves` is set.
@@ -484,7 +485,8 @@ class AnalysisState {
   /// `config`.
   SetVector<Value> findValueInReverseUseDefChain(
       Value value, llvm::function_ref<bool(Value)> condition,
-      TraversalConfig config = TraversalConfig()) const;
+      TraversalConfig config = TraversalConfig(),
+      llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;
 
   /// Find the values that may define the contents of the given value at
   /// runtime. A block argument is always a definition. An OpResult is a
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 065739ea8e5951..03d8f2a9342a13 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -483,10 +483,12 @@ bool AnalysisState::isValueRead(Value value) const {
 // Starting from `value`, follow the use-def chain in reverse, always selecting
 // the aliasing OpOperands. Find and return Values for which `condition`
 // evaluates to true. OpOperands of such matching Values are not traversed any
-// further.
+// further. The visited aliasing OpOperands are recorded in
+// `visitedOpOperands`, if provided.
 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
     Value value, llvm::function_ref<bool(Value)> condition,
-    TraversalConfig config) const {
+    TraversalConfig config,
+    llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
   llvm::DenseSet<Value> visited;
   llvm::SetVector<Value> result, workingSet;
   workingSet.insert(value);
@@ -553,6 +555,8 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
       }
 
       workingSet.insert(a.opOperand->get());
+      if (visitedOpOperands)
+        visitedOpOperands->insert(a.opOperand);
     }
   }
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 3925e858f50514..f814c3b2a6dad5 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -57,26 +57,23 @@ neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
   return true;
 }
 
-/// Return true if the given `insertionPoint` dominates all uses of
-/// `emptyTensorOp`.
-static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
-                                        Operation *insertionPoint,
-                                        Operation *emptyTensorOp) {
-  return llvm::all_of(emptyTensorOp->getUsers(), [&](Operation *user) {
-    return domInfo.dominates(insertionPoint, user);
-  });
-}
-
-/// Find a valid insertion point for a replacement of `emptyTensorOp`, assuming
-/// that the replacement may use any value from `neededValues`.
+/// Find a valid insertion point for a replacement of `useToBeEliminated`,
+/// assuming that the replacement may use any value from `neededValues`.
 static Operation *
-findValidInsertionPoint(Operation *emptyTensorOp,
+findValidInsertionPoint(OpOperand *useToBeEliminated,
                         const SmallVector<Value> &neededValues) {
   DominanceInfo domInfo;
 
+  Operation *candidateInsertionPoint = useToBeEliminated->getOwner();
+  assert(isa<OpResult>(useToBeEliminated->get()) && "expected a result value");
+  // `tensor.empty` and its user are in different blocks.
+  if (useToBeEliminated->getOwner()->getBlock() !=
+      useToBeEliminated->get().getDefiningOp()->getBlock())
+    candidateInsertionPoint = useToBeEliminated->get().getDefiningOp();
+
   // Trying to move the needed values before the `emptyTensorOp`.
   for (Value val : neededValues) {
-    if (valueDominateInsertionPoint(domInfo, emptyTensorOp, val))
+    if (valueDominateInsertionPoint(domInfo, candidateInsertionPoint, val))
       continue;
     Operation *definingOp = val.getDefiningOp();
     if (!definingOp)
@@ -84,17 +81,19 @@ findValidInsertionPoint(Operation *emptyTensorOp,
 
     bool isItSafeToMoveOp =
         llvm::all_of(definingOp->getOperands(), [&](Value operand) {
-          return valueDominateInsertionPoint(domInfo, emptyTensorOp, operand);
+          return valueDominateInsertionPoint(domInfo, candidateInsertionPoint,
+                                             operand);
         });
 
     if (isItSafeToMoveOp)
-      definingOp->moveBefore(emptyTensorOp);
+      definingOp->moveBefore(candidateInsertionPoint);
   }
 
-  // Gather all possible insertion points: the location of `emptyTensorOp` and
-  // right after the definition of each value in `neededValues`.
+  // Gather all possible insertion points: the location of
+  // `candidateInsertionPoint` and right after the definition of each value in
+  // `neededValues`.
   SmallVector<Operation *> insertionPointCandidates;
-  insertionPointCandidates.push_back(emptyTensorOp);
+  insertionPointCandidates.push_back(candidateInsertionPoint);
   for (Value val : neededValues) {
     // Note: The anchor op is using all of `neededValues`, so:
     // * in case of a block argument: There must be at least one op in the block
@@ -116,8 +115,8 @@ findValidInsertionPoint(Operation *emptyTensorOp,
     if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
                                             neededValues))
       continue;
-    // Check if the insertion point is before all uses.
-    if (!insertionPointDominatesUses(domInfo, insertionPoint, emptyTensorOp))
+    // Check if the insertion point is before the use to be replaced.
+    if (!domInfo.dominates(insertionPoint, useToBeEliminated->getOwner()))
       continue;
     return insertionPoint;
   }
@@ -129,8 +128,9 @@ findValidInsertionPoint(Operation *emptyTensorOp,
 LogicalResult mlir::bufferization::eliminateEmptyTensors(
     RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
   OpBuilder::InsertionGuard g(rewriter);
-
+  llvm::DenseSet<OpOperand *> visitedOpOperands;
   op->walk([&](SubsetInsertionOpInterface op) {
+    visitedOpOperands.clear();
     OpOperand &source = op.getSourceOperand();
     // Skip operands that do not bufferize inplace. "tensor.empty" could still
     // be replaced, but the transformation may not be beneficial.
@@ -157,16 +157,25 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
     config.followSameTypeOrCastsOnly = true;
     SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
         source.get(), /*condition=*/
-        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
-        config);
+        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
+        &visitedOpOperands);
 
     for (Value v : emptyTensors) {
       Operation *emptyTensorOp = v.getDefiningOp();
 
+      // Find the use to be replaced from the use-def chain
+      auto iter = llvm::find_if(
+          visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
+            return llvm::count(emptyTensorOp->getUses(), *opOperand);
+          });
+      if (iter == visitedOpOperands.end())
+        continue;
+      OpOperand *useToBeReplaced = *iter;
+
       // Find a suitable insertion point. If no suitable insertion point for
       // the replacement can be found, skip this replacement.
       Operation *insertionPoint =
-          findValidInsertionPoint(emptyTensorOp, neededValues);
+          findValidInsertionPoint(useToBeReplaced, neededValues);
       if (!insertionPoint)
         continue;
 
@@ -185,8 +194,9 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
         replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(),
                                                       replacement);
       }
-      // Replace the tensor::EmptyOp.
-      rewriter.replaceOp(emptyTensorOp, replacement);
+      // Replace the specific use of the tensor::EmptyOp.
+      useToBeReplaced->getOwner()->setOperand(
+          useToBeReplaced->getOperandNumber(), replacement);
       state.resetCache();
     }
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
index 2ba8246a8d5254..3a643f37e66007 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis-empty-tensor-elimination.mlir
@@ -52,9 +52,8 @@ func.func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {bufferization.wri
 
 // CHECK-LABEL: func @buffer_forwarding_conflict_with_different_element_type
 func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<?xf32> {bufferization.writable = true}, %arg1: index) -> (tensor<?xf32>, tensor<?xf32>) {
-  //      CHECK: tensor.extract_slice
-  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
   %cst = arith.constant 0.000000e+00 : f32
+  //      CHECK: bufferization.alloc_tensor(%arg1)
   %0 = tensor.empty(%arg1) : tensor<?xf32>
 
   //      CHECK: bufferization.alloc_tensor(%arg1)
@@ -64,6 +63,10 @@ func.func @buffer_forwarding_conflict_with_different_element_type(%arg0: tensor<
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
   %2 = linalg.copy ins(%0 : tensor<?xf32>) outs(%1 : tensor<?xbf16>) -> tensor<?xbf16>
 
+
+  //      CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none"]
+
   //      CHECK: linalg.copy
   // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]
   %3 = linalg.copy ins(%2 : tensor<?xbf16>) outs(%0 : tensor<?xf32>) -> tensor<?xf32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index db29adef993959..af11638b80adf4 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -396,7 +396,7 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
 func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
   %cst_1 = arith.constant 1.0 : f32
   %cst_2 = arith.constant 2.0 : f32
-  // CHECK: memref.alloc
+  // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
   // CHECK-NOT: memref.alloc
   %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
   %empty_1 = tensor.empty() : tensor<5x6x64xf32>
@@ -413,10 +413,9 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
 
 // -----
 
-// `EmptyTensorElimination` replaces all of the uses of the tensor
-// empty with the new injected `SubsetExtraction`, without to consider
-// the specific use has been tracked, sometimes creating a non existent
-// bufferization conflicts.
+// `EmptyTensorElimination` will replace the specific use of the tensor
+// empty with the new injected `SubsetExtraction`, i.e. the specific use
+// which has been tracked.
 
 // CHECK-ELIM-LABEL:   func.func @mutli_use_of_the_same_tensor_empty
 // CHECK-LABEL:   func.func @mutli_use_of_the_same_tensor_empty
@@ -425,15 +424,16 @@ func.func @mutli_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
   %cst_2 = arith.constant 2.0 : f32
   %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
   %empty_1 = tensor.empty() : tensor<5x6x64xf32>
-  // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
-  // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
-  // CHECK-ELIM: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
+  // CHECK-ELIM: %[[VAL_4:.*]] = tensor.extract_slice %[[VAL_2:.*]]
+  // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_4]] : tensor<5x6x64xf32>)
+  // CHECK-ELIM: %[[VAL_6:.*]] = tensor.insert_slice
   %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  // CHECK-ELIM: %[[VAL_7:.*]] = tensor.extract_slice %[[VAL_6]]
+  // CHECK-ELIM: %[[VAL_8:.*]] = linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_7]] : tensor<5x6x64xf32>)
   %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
-  // CHECK: memref.copy
+  // CHECK-NOT: memref.copy
   %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
       : tensor<5x6x64xf32> into tensor<5x6x128xf32>
-  // CHECK: memref.copy
   %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
       : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   return %inserted_slice_2 : tensor<5x6x128xf32>
@@ -446,7 +446,8 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
     -> (tensor<5x6x128xf32>, tensor<5x6x64xf32>) {
   %cst_1 = arith.constant 1.0 : f32
   %empty_1 = tensor.empty() : tensor<5x6x64xf32>
-  // CHECK: memref.alloc
+  // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x64xf32>
+  // CHECK-NOT: memref.alloc
   %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
   %res_2 = linalg.generic{
     indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
@@ -458,7 +459,7 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
     %res = arith.addf %in, %in : f32
     linalg.yield %res : f32
   } -> tensor<5x6x64xf32>
-  // CHECK: memref.copy
+  // CHECK-NOT: memref.copy
   %inserted_slice_1 = tensor.insert_slice %res_1 into %arg1[0, 0, 0][5, 6, 64][1, 1, 1]
       : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>

>From e40d0ef4c74865456fbb2bfa7a6227507eea087c Mon Sep 17 00:00:00 2001
From: Amir Bishara <amir.bishara at mobileye.com>
Date: Wed, 11 Dec 2024 17:45:38 +0200
Subject: [PATCH 4/4] [mlir][bufferization]-add lit test to show where
 insertion point still fail

---
 ...ot-bufferize-empty-tensor-elimination.mlir | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index af11638b80adf4..20948434ee0973 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -392,6 +392,27 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
 
 // -----
 
+// CHECK-LABEL:   func.func @empty_tensor_destination_with_dynamic_index
+func.func @empty_tensor_destination_with_dynamic_index(%arg0: tensor<?xf32>) -> tensor<?x6x128xf32> {
+  %cst_0 = arith.constant 0 : index
+  %cst_1 = arith.constant 1.0 : f32
+  %cst_2 = arith.constant 2.0 : f32
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %empty_2 = tensor.empty() : tensor<5x6x64xf32>
+  %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
+  %dim = tensor.dim %arg0, %cst_0 : tensor<?xf32>
+  %cancatenated_empty = tensor.empty(%dim) : tensor<?x6x128xf32>
+  // CHECK: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<?x6x128xf32>
+  %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<?x6x128xf32>
+  return %inserted_slice_2 : tensor<?x6x128xf32>
+}
+
+// -----
+
 // CHECK-LABEL:   func.func @succeed_to_eliminate_one_empty_tensor
 func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
   %cst_1 = arith.constant 1.0 : f32


