[Mlir-commits] [mlir] 1f479c1 - [mlir][bufferization] Improve findValueInReverseUseDefChain signature

Matthias Springer llvmlistbot at llvm.org
Mon May 15 06:32:08 PDT 2023


Author: Matthias Springer
Date: 2023-05-15T15:31:56+02:00
New Revision: 1f479c1e46d111a6f001cf4ee24290f60f13257d

URL: https://github.com/llvm/llvm-project/commit/1f479c1e46d111a6f001cf4ee24290f60f13257d
DIFF: https://github.com/llvm/llvm-project/commit/1f479c1e46d111a6f001cf4ee24290f60f13257d.diff

LOG: [mlir][bufferization] Improve findValueInReverseUseDefChain signature

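Instead of passing traversal options as a long list of arguments, store them in a TraversalConfig object and pass that object. The new TraversalConfig also adds two traversal options, `followInPlaceOnly` and `followUnknownOps`.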

Differential Revision: https://reviews.llvm.org/D143927

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 2dbd113547e91..45d705c444a7e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -392,6 +392,23 @@ struct BufferizationOptions {
 /// Return `true` if the given value is a BlockArgument of a func::FuncOp.
 bool isFunctionArgument(Value value);
 
+/// Traversal parameters for `findValueInReverseUseDefChain`.
+struct TraversalConfig {
+  /// If set, leaves (Values at which the traversal stops without a match)
+  /// are returned even if they do not satisfy the given `condition`.
+  bool alwaysIncludeLeaves = true;
+
+  /// If set, out-of-place/undecided OpOperands are not followed.
+  bool followInPlaceOnly = false;
+
+  /// If set, non-equivalent OpOperands are not followed.
+  bool followEquivalentOnly = false;
+
+  /// If unset, the traversal stops at ops that are not bufferizable or that
+  /// are excluded by the OpFilter of BufferizationOptions.
+  bool followUnknownOps = false;
+};
+
 /// AnalysisState provides a variety of helper functions for dealing with
 /// tensor values.
 class AnalysisState {
@@ -437,9 +454,8 @@ class AnalysisState {
   /// `condition` evaluates to true. OpOperands of such matching Values are not
   /// traversed any further.
   ///
-  /// When reaching the end of a chain (BlockArgument or Value without aliasing
-  /// OpOperands), also return the last Value of that chain if
-  /// `alwaysIncludeLeaves` is set.
+  /// When the traversal stops at a Value that does not match `condition`,
+  /// also return that Value if `config.alwaysIncludeLeaves` is set.
   ///
   /// Example:
   ///
@@ -457,10 +473,11 @@ class AnalysisState {
   /// starting the traversal from Value 1, the resulting SetVector is:
   /// { 2, 7, 8, 5 }
   ///
-  /// If `followEquivalentOnly` is set, only equivalent OpOperands are selected.
+  /// Additional stopping conditions for the traversal can be specified in
+  /// `config`.
   SetVector<Value> findValueInReverseUseDefChain(
       Value value, llvm::function_ref<bool(Value)> condition,
-      bool followEquivalentOnly = false, bool alwaysIncludeLeaves = true) const;
+      TraversalConfig config = TraversalConfig()) const;
 
   /// Find the values that may define the contents of the given value at
   /// runtime. A block argument is always a definition. An OpResult is a

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index cb69a9e5879c0..712693ddd53a1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -367,10 +367,7 @@ BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
 
 BufferizableOpInterface
 BufferizationOptions::dynCastBufferizableOp(Value value) const {
-  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
-    if (isOpAllowed(bufferizableOp.getOperation()))
-      return bufferizableOp;
-  return nullptr;
+  return dynCastBufferizableOp(getOwnerOfValue(value));
 }
 
 void BufferizationOptions::setFunctionBoundaryTypeConversion(
@@ -500,7 +497,7 @@ bool AnalysisState::isValueRead(Value value) const {
 // further.
 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
     Value value, llvm::function_ref<bool(Value)> condition,
-    bool followEquivalentOnly, bool alwaysIncludeLeaves) const {
+    TraversalConfig config) const {
   llvm::SetVector<Value> result, workingSet;
   workingSet.insert(value);
 
@@ -512,7 +509,7 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
     }
 
     if (llvm::isa<BlockArgument>(value)) {
-      if (alwaysIncludeLeaves)
+      if (config.alwaysIncludeLeaves)
         result.insert(value);
       continue;
     }
@@ -520,26 +517,42 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
     OpResult opResult = llvm::cast<OpResult>(value);
     BufferizableOpInterface bufferizableOp =
         options.dynCastBufferizableOp(opResult.getDefiningOp());
-    AliasingOpOperandList aliases = getAliasingOpOperands(opResult);
+    if (!config.followUnknownOps && !bufferizableOp) {
+      // Stop iterating if `followUnknownOps` is unset and the op is either
+      // not bufferizable or excluded in the OpFilter.
+      if (config.alwaysIncludeLeaves)
+        result.insert(value);
+      continue;
+    }
 
-    // Stop iterating in either one of these cases:
-    // * The current op is not bufferizable or excluded in the filter.
-    // * There are no OpOperands to follow.
-    if (!bufferizableOp || aliases.getNumAliases() == 0) {
-      if (alwaysIncludeLeaves)
+    AliasingOpOperandList aliases = getAliasingOpOperands(opResult);
+    if (aliases.getNumAliases() == 0) {
+      // The traversal ends naturally if there are no more OpOperands that
+      // could be followed.
+      if (config.alwaysIncludeLeaves)
         result.insert(value);
       continue;
     }
 
     for (AliasingOpOperand a : aliases) {
-      if (followEquivalentOnly && a.relation != BufferRelation::Equivalent) {
+      if (config.followEquivalentOnly &&
+          a.relation != BufferRelation::Equivalent) {
         // Stop iterating if `followEquivalentOnly` is set but the alias is not
         // equivalent.
-        if (alwaysIncludeLeaves)
+        if (config.alwaysIncludeLeaves)
           result.insert(value);
-      } else {
-        workingSet.insert(a.opOperand->get());
+        continue;
       }
+
+      if (config.followInPlaceOnly && !isInPlace(*a.opOperand)) {
+        // Stop iterating if `followInPlaceOnly` is set but the alias is
+        // out-of-place.
+        if (config.alwaysIncludeLeaves)
+          result.insert(value);
+        continue;
+      }
+
+      workingSet.insert(a.opOperand->get());
     }
   }
 
@@ -548,9 +561,10 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
 
 // Find the values that define the contents of the given value.
 llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
+  TraversalConfig config;
+  config.alwaysIncludeLeaves = false;
   return findValueInReverseUseDefChain(
-      value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
-      /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);
+      value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
 }
 
 AnalysisState::AnalysisState(const BufferizationOptions &options)
@@ -927,12 +941,12 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
       return false;
     return state.bufferizesToMemoryWrite(v);
   };
+  TraversalConfig config;
+  config.alwaysIncludeLeaves = false;
   for (AliasingOpOperand alias : opOperands) {
     if (!state
              .findValueInReverseUseDefChain(alias.opOperand->get(),
-                                            isMemoryWriteInsideOp,
-                                            /*followEquivalentOnly=*/false,
-                                            /*alwaysIncludeLeaves=*/false)
+                                            isMemoryWriteInsideOp, config)
              .empty())
       return true;
     }

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 58475d225ce8b..76d424867af61 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -132,10 +132,13 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
       // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
       // equivalent tensors. I.e., stop when there are ops such as extract_slice
       // on the path.
+      TraversalConfig config;
+      config.followEquivalentOnly = true;
+      config.alwaysIncludeLeaves = false; // Only return tensor.empty values.
       SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
           operand.get(), /*condition=*/
           [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); },
-          /*followEquivalentOnly=*/true, /*alwaysIncludeLeaves=*/false);
+          config);
 
       for (Value v : emptyTensors) {
         Operation *emptyTensorOp = v.getDefiningOp();

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 6da512699cc7b..a9f05b21282de 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -775,11 +775,8 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
 // Find the values that define the contents of the given value.
 const llvm::SetVector<Value> &
 OneShotAnalysisState::findDefinitionsCached(Value value) {
-  if (!cachedDefinitions.count(value)) {
-    cachedDefinitions[value] = findValueInReverseUseDefChain(
-        value, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
-        /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false);
-  }
+  if (!cachedDefinitions.count(value))
+    cachedDefinitions[value] = findDefinitions(value); // Memoize.
   return cachedDefinitions[value];
 }
 


        


More information about the Mlir-commits mailing list