[Mlir-commits] [mlir] 36ec848 - [mlir][linalg][bufferize][NFC] Change findValueInReverseUseDefChain signature

Matthias Springer llvmlistbot at llvm.org
Thu Oct 21 01:34:15 PDT 2021


Author: Matthias Springer
Date: 2021-10-21T17:34:05+09:00
New Revision: 36ec848dc7186b9713bb69ade134f9b2b7d65070

URL: https://github.com/llvm/llvm-project/commit/36ec848dc7186b9713bb69ade134f9b2b7d65070
DIFF: https://github.com/llvm/llvm-project/commit/36ec848dc7186b9713bb69ade134f9b2b7d65070.diff

LOG: [mlir][linalg][bufferize][NFC] Change findValueInReverseUseDefChain signature

This commit is in preparation for scf.if support.

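* `condition` in findValueInReverseUseDefChain now takes a Value instead of an `OpOperand &`.
* The `returnLast` parameter is removed: when the end of the use-def chain is reached (a BlockArgument or a Value without aliasing OpOperands), the last Value of the chain is always returned.
* Return a SetVector<Value> instead of a single Value. This SetVector always contains exactly one Value at the moment.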

Differential Revision: https://reviews.llvm.org/D111928

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index be18cd135a47..0655d97b4ccb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -836,28 +836,37 @@ void BufferizationAliasInfo::bufferizeOutOfPlace(OpResult result) {
 
 /// Starting from `value`, follow the use-def chain in reverse, always selecting
 /// the corresponding aliasing OpOperand. Try to find and return a Value for
-/// which `condition` evaluates to true for the aliasing OpOperand. Return an
-/// empty Value if no such Value was found. If `returnLast`, return the last
-/// Value (at the end of the chain), even if it does not satisfy the condition.
-static Value
+/// which `condition` evaluates to true.
+///
+/// When reaching the end of the chain (BlockArgument or Value without aliasing
+/// OpOperands), return the last Value of the chain.
+///
+/// Note: The returned SetVector contains exactly one element.
+static llvm::SetVector<Value>
 findValueInReverseUseDefChain(Value value,
-                              std::function<bool(OpOperand &)> condition,
-                              bool returnLast = false) {
-  while (value.isa<OpResult>()) {
-    auto opResult = value.cast<OpResult>();
+                              std::function<bool(Value)> condition) {
+  llvm::SetVector<Value> result, workingSet;
+  workingSet.insert(value);
+
+  while (!workingSet.empty()) {
+    Value value = workingSet.pop_back_val();
+    if (condition(value) || value.isa<BlockArgument>()) {
+      result.insert(value);
+      continue;
+    }
+
+    OpResult opResult = value.cast<OpResult>();
     SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
-    assert(opOperands.size() <= 1 && "more than 1 OpOperand not supported yet");
-    if (opOperands.empty())
-      // No aliasing OpOperand. This could be an unsupported op or an op without
-      // a tensor arg such as InitTensorOp. This is the end of the chain.
-      return returnLast ? value : Value();
-    OpOperand *opOperand = opOperands.front();
-    if (condition(*opOperand))
-      return value;
-    value = opOperand->get();
+    if (opOperands.empty()) {
+      result.insert(value);
+      continue;
+    }
+
+    assert(opOperands.size() == 1 && "multiple OpOperands not supported yet");
+    workingSet.insert(opOperands.front()->get());
   }
-  // Value is a BlockArgument. Reached the end of the chain.
-  return returnLast ? value : Value();
+
+  return result;
 }
 
 /// Find the Value (result) of the last preceding write of a given Value.
@@ -866,20 +875,41 @@ findValueInReverseUseDefChain(Value value,
 /// Furthermore, BlockArguments are also assumed to be writes. There is no
 /// analysis across block boundaries.
 static Value findLastPrecedingWrite(Value value) {
-  return findValueInReverseUseDefChain(value, bufferizesToMemoryWrite, true);
+  SetVector<Value> result =
+      findValueInReverseUseDefChain(value, [](Value value) {
+        Operation *op = value.getDefiningOp();
+        if (!op)
+          return true;
+        if (!hasKnownBufferizationAliasingBehavior(op))
+          return true;
+
+        SmallVector<OpOperand *> opOperands =
+            getAliasingOpOperand(value.cast<OpResult>());
+        assert(opOperands.size() <= 1 &&
+               "op with multiple aliasing OpOperands not expected");
+
+        if (opOperands.empty())
+          return true;
+
+        return bufferizesToMemoryWrite(*opOperands.front());
+      });
+  assert(result.size() == 1 && "expected exactly one result");
+  return result.front();
 }
 
 /// Return true if `value` is originating from an ExtractSliceOp that matches
 /// the given InsertSliceOp.
 bool BufferizationAliasInfo::hasMatchingExtractSliceOp(
     Value value, InsertSliceOp insertOp) const {
-  return static_cast<bool>(
-      findValueInReverseUseDefChain(value, [&](OpOperand &opOperand) {
-        if (auto extractOp = dyn_cast<ExtractSliceOp>(opOperand.getOwner()))
-          if (areEquivalentExtractSliceOps(extractOp, insertOp))
-            return true;
-        return false;
-      }));
+  auto condition = [&](Value val) {
+    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+      if (areEquivalentExtractSliceOps(extractOp, insertOp))
+        return true;
+    return false;
+  };
+
+  return llvm::all_of(findValueInReverseUseDefChain(value, condition),
+                      condition);
 }
 
 /// Given sets of uses and writes, return true if there is a RaW conflict under


        


More information about the Mlir-commits mailing list