[Mlir-commits] [mlir] d9111f1 - [mlir][bufferization]-Refactor findValueInReverseUseDefChain to accept opOperand (#121304)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 30 11:18:42 PST 2024


Author: Amir Bishara
Date: 2024-12-30T21:18:38+02:00
New Revision: d9111f19d2ea53d8ce105b3d09425394ccf37969

URL: https://github.com/llvm/llvm-project/commit/d9111f19d2ea53d8ce105b3d09425394ccf37969
DIFF: https://github.com/llvm/llvm-project/commit/d9111f19d2ea53d8ce105b3d09425394ccf37969.diff

LOG: [mlir][bufferization]-Refactor findValueInReverseUseDefChain to accept opOperand (#121304)

Edit the `findValueInReverseUseDefChain` method to accept an `OpOperand`
instead of a `Value`. This change ensures that the populated
`visitedOpOperands` argument is fully accurate and contains the
opOperand from which the reverse use-def chain traversal starts.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
    mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 983f7a29cb2206..d1a102e2a6e4e8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -456,7 +456,7 @@ class AnalysisState {
   /// read by themselves (e.g., ExtractSliceOp).
   bool isValueRead(Value value) const;
 
-  /// Starting from `value`, follow the use-def chain in reverse, always
+  /// Starting from `opOperand`, follow the use-def chain in reverse, always
   /// selecting the aliasing OpOperands. Find and return Values for which
   /// `condition` evaluates to true. OpOperands of such matching Values are not
   /// traversed any further, the visited aliasing opOperands will be preserved
@@ -484,7 +484,7 @@ class AnalysisState {
   /// Additional stopping conditions for the traversal can be specified in
   /// `config`.
   SetVector<Value> findValueInReverseUseDefChain(
-      Value value, llvm::function_ref<bool(Value)> condition,
+      OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
       TraversalConfig config = TraversalConfig(),
       llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;
 
@@ -520,7 +520,7 @@ class AnalysisState {
   ///
   /// Note: OpResults of unknown ops are handled conservatively and assumed to
   /// be definitions.
-  SetVector<Value> findDefinitions(Value value) const;
+  SetVector<Value> findDefinitions(OpOperand *opOperand) const;
 
   /// Return `true` if the given OpResult has been decided to bufferize inplace.
   virtual bool isInPlace(OpOperand &opOperand) const;

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index d50a3042aeeacf..bd23a19f747285 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -127,9 +127,9 @@ class OneShotAnalysisState : public AnalysisState {
   /// Return true if the buffer of the given tensor value is writable.
   bool isWritable(Value value) const;
 
-  /// Find the definitions of the given tensor value or retrieve them from the
-  /// cache.
-  const SetVector<Value> &findDefinitionsCached(Value value);
+  /// Find the definitions of the given operand's value or
+  /// retrieve them from the cache.
+  const SetVector<Value> &findDefinitionsCached(OpOperand *opOperand);
 
   /// Reset cached data structures.
   void resetCache() override;

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 349841f06959c3..1eb27e44810b0d 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -480,18 +480,21 @@ bool AnalysisState::isValueRead(Value value) const {
   return false;
 }
 
-// Starting from `value`, follow the use-def chain in reverse, always selecting
-// the aliasing OpOperands. Find and return Values for which `condition`
-// evaluates to true. OpOperands of such matching Values are not traversed any
-// further, the visited aliasing opOperands will be preserved through
-// `visitedOpOperands`.
+// Starting from `opOperand`, follow the use-def chain in reverse, always
+// selecting the aliasing OpOperands. Find and return Values for which
+// `condition` evaluates to true. Uses of such matching Values are not
+// traversed any further, the visited aliasing opOperands will be preserved
+// through `visitedOpOperands`.
 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
-    Value value, llvm::function_ref<bool(Value)> condition,
+    OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
     TraversalConfig config,
     llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
   llvm::DenseSet<Value> visited;
   llvm::SetVector<Value> result, workingSet;
-  workingSet.insert(value);
+  workingSet.insert(opOperand->get());
+
+  if (visitedOpOperands)
+    visitedOpOperands->insert(opOperand);
 
   while (!workingSet.empty()) {
     Value value = workingSet.pop_back_val();
@@ -563,12 +566,14 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
   return result;
 }
 
-// Find the values that define the contents of the given value.
-llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
+// Find the values that define the contents of the given operand's value.
+llvm::SetVector<Value>
+AnalysisState::findDefinitions(OpOperand *opOperand) const {
   TraversalConfig config;
   config.alwaysIncludeLeaves = false;
   return findValueInReverseUseDefChain(
-      value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
+      opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
+      config);
 }
 
 AnalysisState::AnalysisState(const BufferizationOptions &options)
@@ -892,7 +897,7 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
   config.alwaysIncludeLeaves = false;
   for (AliasingOpOperand alias : opOperands) {
     if (!state
-             .findValueInReverseUseDefChain(alias.opOperand->get(),
+             .findValueInReverseUseDefChain(alias.opOperand,
                                             isMemoryWriteInsideOp, config)
              .empty())
       return true;

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 98c3d8d0adc6d2..2c4e362101f8f6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -143,7 +143,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
     // %3 = tensor.insert_slice %2 into ...
     config.followSameTypeOrCastsOnly = true;
     SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
-        source.get(), /*condition=*/
+        &source, /*condition=*/
         [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
         &visitedOpOperands);
 
@@ -155,10 +155,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
           visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
             return llvm::count(emptyTensorOp->getUses(), *opOperand);
           });
-      // This could be achieved when a use of `emptyTensorOp` is being
-      // consumed by `SubsetInsertionOpInterface`'s source directly.
-      if (iter == visitedOpOperands.end())
-        continue;
+
+      assert(iter != visitedOpOperands.end() && "could not find use");
       OpOperand *useToBeReplaced = *iter;
       Operation *user = useToBeReplaced->getOwner();
       auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index d1e6acef324fbd..fc1b221b4f0369 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -196,7 +196,12 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
 
       // If there is no preceding definition, the tensor contents are
       // undefined.
-      if (findDefinitionsCached(opResult).empty())
+      if (opResult.getUses().empty())
+        continue;
+      // It does not really matter which use to take to search about
+      // the value's definitions.
+      OpOperand *opOperand = &(*opResult.getUses().begin());
+      if (findDefinitionsCached(opOperand).empty())
         for (OpOperand &use : opResult.getUses())
           undefinedTensorUses.insert(&use);
     }
@@ -464,7 +469,8 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
 /// indexing. I.e., the tensor types do not change along the use-def chain,
 /// apart from static <-> dynamic dim casts.
 static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
-                                                   Value start, Value other) {
+                                                   OpOperand *start,
+                                                   Value other) {
   TraversalConfig config;
   config.followEquivalentOnly = true;
   config.alwaysIncludeLeaves = false;
@@ -475,9 +481,10 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
               .empty();
 }
 
-/// Return "true" if `value` is originating from a subset that is equivalent to
-/// the subset that `subsetOp` inserts into.
-static bool matchesInsertDestination(const AnalysisState &state, Value value,
+/// Return "true" if the given operand's value is originating from a subset
+/// that is equivalent to the subset that `subsetOp` inserts into.
+static bool matchesInsertDestination(const AnalysisState &state,
+                                     OpOperand *opOperand,
                                      SubsetInsertionOpInterface subsetOp) {
   auto matchingSubset = [&](Value val) {
     if (auto opResult = dyn_cast<OpResult>(val))
@@ -490,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value,
   // There may be multiple leaves at which the reverse SSA use-def chain lookup
   // terminates. All of them must be equivalent subsets.
   SetVector<Value> backwardSlice =
-      state.findValueInReverseUseDefChain(value, matchingSubset);
+      state.findValueInReverseUseDefChain(opOperand, matchingSubset);
   return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
 }
 
@@ -516,7 +523,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
     //     {inplace= [true] }
 
     if (uRead == &subsetOp.getDestinationOperand() &&
-        matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
+        matchesInsertDestination(state, uConflictingWrite, subsetOp))
       // Case 1: The main insight is that InsertSliceOp reads only part of
       // the destination tensor. The overwritten area is not read. If
       // uConflictingWrite writes into exactly the memory location that is
@@ -533,7 +540,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
 
     if (uRead == &subsetOp.getSourceOperand() &&
         uConflictingWrite == &subsetOp.getDestinationOperand() &&
-        matchesInsertDestination(state, uRead->get(), subsetOp))
+        matchesInsertDestination(state, uRead, subsetOp))
       // Case 2: The read of the source tensor and the write to the dest
       // tensor via an InsertSliceOp is not a conflict if the read is
       // reading exactly that part of an equivalent tensor that the
@@ -567,8 +574,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
     if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
         state.areEquivalentBufferizedValues(
             uRead->get(), subsetOp.getSourceOperand().get()) &&
-        matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
-                                 subsetOp))
+        matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp))
       return true;
 
   return false;
@@ -600,9 +606,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
       // even though that op just bufferizes to an allocation but does define
       // the contents of the buffer.
       SetVector<Value> definitionsOrLeaves =
-          state.findValueInReverseUseDefChain(
-              uConflictingWrite->get(),
-              [&](Value v) { return state.bufferizesToMemoryWrite(v); });
+          state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) {
+            return state.bufferizesToMemoryWrite(v);
+          });
       assert(!definitionsOrLeaves.empty() &&
              "expected at least one definition or leaf");
 
@@ -641,8 +647,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
     // In the above example, if uRead is the OpOperand of reading_op, the
     // definition is %0. Note that operations that create an alias but do not
     // bufferize to a memory write (such as ExtractSliceOp) are skipped.
-    const SetVector<Value> &definitions =
-        state.findDefinitionsCached(uRead->get());
+    const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
     if (definitions.empty()) {
       // Fast path: No conflict if there are no definitions.
       LLVM_DEBUG(llvm::dbgs()
@@ -714,9 +719,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
             if (bufferizableOp.bufferizesToElementwiseAccess(
                     state, {uRead, uConflictingWrite})) {
               if (hasEquivalentValueInReverseUseDefChain(
-                      state, uRead->get(), uConflictingWrite->get()) ||
+                      state, uRead, uConflictingWrite->get()) ||
                   hasEquivalentValueInReverseUseDefChain(
-                      state, uConflictingWrite->get(), uRead->get())) {
+                      state, uConflictingWrite, uRead->get())) {
                 LLVM_DEBUG(
                     llvm::dbgs()
                     << "  no conflict: op bufferizes to element-wise access\n");
@@ -965,11 +970,12 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
 
-// Find the values that define the contents of the given value.
+// Find the values that define the contents of the given operand's value.
 const llvm::SetVector<Value> &
-OneShotAnalysisState::findDefinitionsCached(Value value) {
+OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
+  Value value = opOperand->get();
   if (!cachedDefinitions.count(value))
-    cachedDefinitions[value] = findDefinitions(value);
+    cachedDefinitions[value] = findDefinitions(opOperand);
   return cachedDefinitions[value];
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 6801b68a853815..6c1087730ebba8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -553,7 +553,7 @@ Value linalg::bufferizeToAllocation(
     Value alloc = createAllocationForTensor(
         rewriter, op->getLoc(), operand->get(), options, memorySpace);
     allocs.push_back(alloc);
-    if (!state.findDefinitions(operand->get()).empty()) {
+    if (!state.findDefinitions(operand).empty()) {
       // Initialize buffer with a copy of the operand data. Not needed if the
       // tensor is uninitialized.
       createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 4776883ed95c5c..b710bde87f9f33 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -59,7 +59,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
       config.followEquivalentOnly = true;
       config.alwaysIncludeLeaves = false;
       SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
-          in->get(), /*condition=*/
+          in, /*condition=*/
           [&](Value val) {
             return val.getDefiningOp<tensor::EmptyOp>() &&
                    val.getType() == in->get().getType();

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 26434774730e1b..820fb3dfa5e5e0 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -465,3 +465,14 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
       : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @direct_use_of_tensor_empty
+func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
+  // CHECK-NOT: memref.alloc
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  %inserted_slice_1 = tensor.insert_slice %empty_1 into %arg0[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_1 : tensor<5x6x128xf32>
+}


        


More information about the Mlir-commits mailing list