[Mlir-commits] [mlir] dc7ad19 - [mlir][bufferize][NFC] Optimize read-only tensor detection

Matthias Springer llvmlistbot at llvm.org
Thu Feb 9 00:07:25 PST 2023


Author: Matthias Springer
Date: 2023-02-09T09:07:14+01:00
New Revision: dc7ad194c77cbe5b49514bebcfcfc4afd9eb8439

URL: https://github.com/llvm/llvm-project/commit/dc7ad194c77cbe5b49514bebcfcfc4afd9eb8439
DIFF: https://github.com/llvm/llvm-project/commit/dc7ad194c77cbe5b49514bebcfcfc4afd9eb8439.diff

LOG: [mlir][bufferize][NFC] Optimize read-only tensor detection

Check alias sets instead of traversing the IR.

Differential Revision: https://reviews.llvm.org/D143500

Added: 
    

Modified: 
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 6420feb51fa1d..8a7d660f488f9 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -724,62 +724,42 @@ static void annotateNonWritableTensor(Value value) {
   }
 }
 
-/// Check the reverse SSA use-def chain (following aliasing OpOperands) for
-/// non-writable tensor values. Stop searching when an out-of-place bufferized
-/// OpOperand was found (or when the OpOperand was not bufferized yet).
-/// `currentOpOperand` is assumed to be in-place, even if that decision was not
-/// materialized in `aliasInfo` yet.
-static bool
-hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
-                                      const OneShotAnalysisState &state) {
-  SmallVector<Value> worklist;
-  worklist.push_back(value);
-  while (!worklist.empty()) {
-    Value nextVal = worklist.pop_back_val();
-    if (!state.isWritable(nextVal)) {
-      if (state.getOptions().printConflicts)
-        annotateNonWritableTensor(nextVal);
-      return true;
-    }
-
-    // If `nextVal` is not a BlockArgument: End of use-def chain reached.
-    auto opResult = nextVal.dyn_cast<OpResult>();
-    if (!opResult)
-      continue;
-
-    // Follow reverse SSA use-def chain.
-    AliasingOpOperandList aliasingOpOperands =
-        state.getAliasingOpOperands(opResult);
-    for (OpOperand *opOperand : aliasingOpOperands)
-      if (state.isInPlace(*opOperand) || currentOpOperand == opOperand)
-        worklist.push_back(opOperand->get());
-  }
-  return false;
-}
-
 /// Return true if bufferizing `operand` inplace would create a write to a
 /// non-writable buffer.
 static bool
 wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
                                     OneShotAnalysisState &state,
                                     bool checkConsistencyOnly = false) {
-  // Collect writes of all aliases of OpOperand and OpResult.
-  DenseSet<OpOperand *> usesWrite;
-  getAliasingInplaceWrites(usesWrite, operand.get(), state);
-  for (OpResult result : state.getAliasingOpResults(operand)) {
-    getAliasingInplaceWrites(usesWrite, result, state);
+  bool foundWrite =
+      !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);
+
+  if (!foundWrite) {
+    // Collect writes of all aliases of OpOperand and OpResult.
+    DenseSet<OpOperand *> usesWrite;
+    getAliasingInplaceWrites(usesWrite, operand.get(), state);
+    for (OpResult result : state.getAliasingOpResults(operand))
+      getAliasingInplaceWrites(usesWrite, result, state);
+    foundWrite = !usesWrite.empty();
   }
-  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
-    usesWrite.insert(&operand);
 
-  // Assuming that `operand` bufferizes in-place: For each write (to each
-  // alias), check if there is a non-writable tensor in the reverse SSA use-def
-  // chain.
-  for (OpOperand *uWrite : usesWrite) {
-    if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand, state)) {
-      LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
-      return true;
+  if (!foundWrite)
+    return false;
+
+  // Look for a read-only tensor among all aliases.
+  bool foundReadOnly = false;
+  auto checkReadOnly = [&](Value v) {
+    if (!state.isWritable(v)) {
+      foundReadOnly = true;
+      if (state.getOptions().printConflicts)
+        annotateNonWritableTensor(v);
     }
+  };
+  state.applyOnAliases(operand.get(), checkReadOnly);
+  for (OpResult result : state.getAliasingOpResults(operand))
+    state.applyOnAliases(result, checkReadOnly);
+  if (foundReadOnly) {
+    LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
+    return true;
   }
 
   return false;


        


More information about the Mlir-commits mailing list