[Mlir-commits] [mlir] 032be23 - [mlir][bufferize] Improve buffer writability analysis

Matthias Springer llvmlistbot at llvm.org
Wed Jun 8 01:12:10 PDT 2022


Author: Matthias Springer
Date: 2022-06-08T10:11:52+02:00
New Revision: 032be2330928995ae264a4886fd2610bc3e49656

URL: https://github.com/llvm/llvm-project/commit/032be2330928995ae264a4886fd2610bc3e49656
DIFF: https://github.com/llvm/llvm-project/commit/032be2330928995ae264a4886fd2610bc3e49656.diff

LOG: [mlir][bufferize] Improve buffer writability analysis

Find writability conflicts (writes to buffers that are not allowed to be written to) by checking SSA use-def chains. This is better than the current writability analysis, which is too conservative and finds false positives.

Differential Revision: https://reviews.llvm.org/D127256

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 8d44d579fbc8..2663e480f281 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -167,6 +167,9 @@ class OneShotAnalysisState : public AnalysisState {
   /// not be called for values inside not yet analyzed functions.
   bool isValueWritten(Value value) const;
 
+  /// Return true if the buffer of the given tensor value is writable.
+  bool isWritable(Value value) const;
+
 private:
   /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
   /// functions and `runOneShotBufferize` may access this object.

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index f29d9a96b0b3..5447f6b0bdc2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -305,6 +305,21 @@ bool OneShotAnalysisState::isValueWritten(Value value) const {
   return isWritten;
 }
 
+bool OneShotAnalysisState::isWritable(Value value) const {
+  // TODO: Out-of-place bufferized value could be considered writable.
+  if (auto bufferizableOp = getOptions().dynCastBufferizableOp(value))
+    return bufferizableOp.isWritable(value, *this);
+
+  // BlockArgument: Query the parent op of the block that owns the argument.
+  if (auto bbArg = value.dyn_cast<BlockArgument>())
+    if (auto bufferizableOp =
+            getOptions().dynCastBufferizableOp(bbArg.getOwner()->getParentOp()))
+      return bufferizableOp.isWritable(bbArg, *this);
+
+  // Not a bufferizable op: The conservative answer is "not writable".
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific alias analysis.
 //===----------------------------------------------------------------------===//
@@ -312,7 +327,7 @@ bool OneShotAnalysisState::isValueWritten(Value value) const {
 /// Return true if opOperand has been decided to bufferize in-place.
 static bool isInplaceMemoryWrite(OpOperand &opOperand,
                                  const BufferizationAliasInfo &aliasInfo,
-                                 AnalysisState &state) {
+                                 const AnalysisState &state) {
   // OpOperands that do not bufferize to a memory write do not write in-place.
   if (!state.bufferizesToMemoryWrite(opOperand))
     return false;
@@ -320,49 +335,6 @@ static bool isInplaceMemoryWrite(OpOperand &opOperand,
   return aliasInfo.isInPlace(opOperand);
 }
 
-/// Return true if, under current bufferization decisions, the buffer of `value`
-/// is not writable.
-static bool aliasesNonWritableBuffer(Value value,
-                                     const BufferizationAliasInfo &aliasInfo,
-                                     AnalysisState &state) {
-  bool foundNonWritableBuffer = false;
-  aliasInfo.applyOnAliases(value, [&](Value v) {
-    // Query BufferizableOpInterface to see if the value is writable.
-    // TODO: Out-of-place bufferized value could be considered writable.
-    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v))
-      if (bufferizableOp && bufferizableOp.isWritable(v, state))
-        return;
-
-    // Query BufferizableOpInterface to see if the BlockArgument is writable.
-    if (auto bbArg = v.dyn_cast<BlockArgument>())
-      if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(
-              bbArg.getOwner()->getParentOp()))
-        if (bufferizableOp.isWritable(bbArg, state))
-          return;
-
-    foundNonWritableBuffer = true;
-  });
-
-  return foundNonWritableBuffer;
-}
-
-/// Return true if the buffer to which `operand` would bufferize is equivalent
-/// to some buffer write.
-static bool aliasesInPlaceWrite(Value value,
-                                const BufferizationAliasInfo &aliasInfo,
-                                AnalysisState &state) {
-  bool foundInplaceWrite = false;
-  aliasInfo.applyOnAliases(value, [&](Value v) {
-    for (auto &use : v.getUses()) {
-      if (isInplaceMemoryWrite(use, aliasInfo, state)) {
-        foundInplaceWrite = true;
-        return;
-      }
-    }
-  });
-  return foundInplaceWrite;
-}
-
 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
 /// properly dominates `b` and `b` is not inside `a`.
 static bool happensBefore(Operation *a, Operation *b,
@@ -604,6 +576,30 @@ static bool hasReadAfterWriteInterference(
   return false;
 }
 
+/// Collect into `res` every in-place write (OpOperand) to an alias of `root`.
+static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
+                                     const BufferizationAliasInfo &aliasInfo,
+                                     const AnalysisState &state) {
+  aliasInfo.applyOnAliases(root, [&](Value alias) {
+    for (auto &use : alias.getUses())
+      // In-place write to a value that aliases `root`.
+      if (isInplaceMemoryWrite(use, aliasInfo, state))
+        res.insert(&use);
+  });
+}
+
+/// Collect into `res` all uses of aliases of `root` that are memory reads.
+static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
+                             const BufferizationAliasInfo &aliasInfo,
+                             const AnalysisState &state) {
+  aliasInfo.applyOnAliases(root, [&](Value alias) {
+    for (auto &use : alias.getUses())
+      // Read of a value that aliases `root`.
+      if (state.bufferizesToMemoryRead(use))
+        res.insert(&use);
+  });
+}
+
 /// Return true if bufferizing `operand` inplace would create a conflict. A read
 /// R and a write W of the same alias set is a conflict if inplace bufferization
 /// of W changes the value read by R to a value different from the one that
@@ -637,33 +633,13 @@ static bool wouldCreateReadAfterWriteInterference(
     OpOperand &operand, const DominanceInfo &domInfo, AnalysisState &state,
     const BufferizationAliasInfo &aliasInfo,
     bool checkConsistencyOnly = false) {
-  // Helper function to iterate on aliases of `root` and capture the reads.
-  auto getAliasingReads = [&](DenseSet<OpOperand *> &res, Value root) {
-    aliasInfo.applyOnAliases(root, [&](Value alias) {
-      for (auto &use : alias.getUses())
-        // Read to a value that aliases root.
-        if (state.bufferizesToMemoryRead(use))
-          res.insert(&use);
-    });
-  };
-
-  // Helper function to iterate on aliases of `root` and capture the writes.
-  auto getAliasingInplaceWrites = [&](DenseSet<OpOperand *> &res, Value root) {
-    aliasInfo.applyOnAliases(root, [&](Value alias) {
-      for (auto &use : alias.getUses())
-        // Inplace write to a value that aliases root.
-        if (isInplaceMemoryWrite(use, aliasInfo, state))
-          res.insert(&use);
-    });
-  };
-
   // Collect reads and writes of all aliases of OpOperand and OpResult.
   DenseSet<OpOperand *> usesRead, usesWrite;
-  getAliasingReads(usesRead, operand.get());
-  getAliasingInplaceWrites(usesWrite, operand.get());
+  getAliasingReads(usesRead, operand.get(), aliasInfo, state);
+  getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
   for (OpResult result : state.getAliasingOpResult(operand)) {
-    getAliasingReads(usesRead, result);
-    getAliasingInplaceWrites(usesWrite, result);
+    getAliasingReads(usesRead, result, aliasInfo, state);
+    getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
   }
   if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
     usesWrite.insert(&operand);
@@ -672,28 +648,60 @@ static bool wouldCreateReadAfterWriteInterference(
                                        aliasInfo);
 }
 
-/// Return true if bufferizing `opOperand` inplace would create a write to a
-/// non-writable buffer.
+/// Return true if a non-writable tensor value is found on the reverse SSA
+/// use-def chain of `value` (following aliasing OpOperands). Stop following
+/// a path at OpOperands that bufferize out-of-place or were not analyzed yet.
+/// `currentOpOperand` is assumed to be in-place, even if that decision was not
+/// materialized in `aliasInfo` yet.
 static bool
-wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
-                                    const BufferizationAliasInfo &aliasInfo,
-                                    AnalysisState &state) {
-  // Certain buffers are not writeable:
-  //   1. A function bbArg that is not inplaceable or
-  //   2. A constant op.
-  bool nonWritable =
-      aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
-  if (!nonWritable)
-    return false;
+hasPrecedingAliasingNonWritableTensor(Value value, OpOperand *currentOpOperand,
+                                      const BufferizationAliasInfo &aliasInfo,
+                                      const OneShotAnalysisState &state) {
+  SmallVector<Value> worklist;
+  worklist.push_back(value);
+  while (!worklist.empty()) {
+    Value nextVal = worklist.pop_back_val();
+    if (!state.isWritable(nextVal))
+      return true;
+
+    // If `nextVal` is not an OpResult: End of use-def chain reached.
+    auto opResult = nextVal.dyn_cast<OpResult>();
+    if (!opResult)
+      continue;
+
+    // Follow reverse SSA use-def chain (only through in-place OpOperands).
+    SmallVector<OpOperand *> aliasingOpOperands =
+        state.getAliasingOpOperand(opResult);
+    for (OpOperand *opOperand : aliasingOpOperands)
+      if (aliasInfo.isInPlace(*opOperand) || currentOpOperand == opOperand)
+        worklist.push_back(opOperand->get());
+  }
+  return false;
+}

-  // This is a problem only if the buffer is written to via some alias.
-  bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
-                  state.bufferizesToMemoryWrite(opOperand);
+/// Return true if bufferizing `operand` in-place would create a write to a
+/// non-writable buffer.
+static bool wouldCreateWriteToNonWritableBuffer(
+    OpOperand &operand, const BufferizationAliasInfo &aliasInfo,
+    OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
+  // Collect in-place writes to all aliases of `operand` and its OpResults.
+  DenseSet<OpOperand *> usesWrite;
+  getAliasingInplaceWrites(usesWrite, operand.get(), aliasInfo, state);
+  for (OpResult result : state.getAliasingOpResult(operand)) {
+    getAliasingInplaceWrites(usesWrite, result, aliasInfo, state);
+  }
+  if (!checkConsistencyOnly && state.bufferizesToMemoryWrite(operand))
+    usesWrite.insert(&operand);

-  for (OpResult opResult : state.getAliasingOpResult(opOperand))
-    hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state);
+  // Assuming that `operand` bufferizes in-place: For each write (to each
+  // alias), check if there is a non-writable tensor in the reverse SSA use-def
+  // chain.
+  for (OpOperand *uWrite : usesWrite)
+    if (hasPrecedingAliasingNonWritableTensor(uWrite->get(), &operand,
+                                              aliasInfo, state))
+      return true;

-  return hasWrite;
+  return false;
 }
 
 //===----------------------------------------------------------------------===//
@@ -702,8 +710,8 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
 
 /// Determine if `operand` can be bufferized in-place.
 static LogicalResult bufferizableInPlaceAnalysisImpl(
-    OpOperand &operand, BufferizationAliasInfo &aliasInfo, AnalysisState &state,
-    const DominanceInfo &domInfo) {
+    OpOperand &operand, BufferizationAliasInfo &aliasInfo,
+    OneShotAnalysisState &state, const DominanceInfo &domInfo) {
   bool foundInterference =
       wouldCreateWriteToNonWritableBuffer(operand, aliasInfo, state) ||
       wouldCreateReadAfterWriteInterference(operand, domInfo, state, aliasInfo);
@@ -736,7 +744,7 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
 /// RaW dependence violations.
 static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
                                      BufferizationAliasInfo &aliasInfo,
-                                     AnalysisState &state,
+                                     OneShotAnalysisState &state,
                                      const DominanceInfo &domInfo,
                                      unsigned analysisFuzzerSeed = 0) {
   if (analysisFuzzerSeed) {
@@ -769,7 +777,7 @@ static bool hasTensorSemantics(Operation *op) {
 /// Analyze all ops that are contained in `op`.
 static LogicalResult inPlaceAnalysis(Operation *op,
                                      BufferizationAliasInfo &aliasInfo,
-                                     AnalysisState &state,
+                                     OneShotAnalysisState &state,
                                      const DominanceInfo &domInfo,
                                      unsigned analysisFuzzerSeed = 0) {
   // Collect ops so we can build our own reverse traversal.

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
index 9fff7f990b39..3145959e767e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-allow-return-allocs.mlir
@@ -31,3 +31,34 @@ func.func @buffer_not_deallocated(%t : tensor<?xf32>, %c : i1) -> tensor<?xf32>
   // CHECK: return %[[r_tensor]]
   return %r : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @write_to_alloc_tensor_or_readonly_tensor(
+//  CHECK-SAME:     %[[arg0:.*]]: tensor<i32>
+func.func @write_to_alloc_tensor_or_readonly_tensor(%arg0: tensor<i32>,
+                                                    %cond: i1, %val: i32)
+  -> tensor<i32>
+{
+  // CHECK: %[[r:.*]] = scf.if {{.*}} {
+  // CHECK:   %[[arg0_m:.*]] = bufferization.to_memref %[[arg0]]
+  // CHECK:   %[[clone:.*]] = bufferization.clone %[[arg0_m]]
+  // CHECK:   scf.yield %[[clone]]
+  // CHECK: } else {
+  // CHECK:   %[[alloc:.*]] = memref.alloc
+  // CHECK:   memref.store %{{.*}}, %[[alloc]]
+  // CHECK:   %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK:   scf.yield %[[casted]]
+  // CHECK: }
+  // CHECK: %[[r_t:.*]] = bufferization.to_tensor %[[r]]
+  // CHECK: memref.dealloc %[[r]]
+  // CHECK: return %[[r_t]]
+  %3 = scf.if %cond -> (tensor<i32>) {
+    scf.yield %arg0 : tensor<i32>
+  } else {
+    %7 = bufferization.alloc_tensor() : tensor<i32>
+    %8 = tensor.insert %val into %7[] : tensor<i32>
+    scf.yield %8 : tensor<i32>
+  }
+  return %3 : tensor<i32>
+}


        


More information about the Mlir-commits mailing list