[Mlir-commits] [mlir] 37317f5 - [mlir][linalg][bufferize] Decouple BufferizationAliasInfo

Matthias Springer llvmlistbot at llvm.org
Thu Nov 4 19:51:22 PDT 2021


Author: Matthias Springer
Date: 2021-11-05T11:41:00+09:00
New Revision: 37317f5bd21297af49b1f3968e0b44cb3596f653

URL: https://github.com/llvm/llvm-project/commit/37317f5bd21297af49b1f3968e0b44cb3596f653
DIFF: https://github.com/llvm/llvm-project/commit/37317f5bd21297af49b1f3968e0b44cb3596f653.diff

LOG: [mlir][linalg][bufferize] Decouple BufferizationAliasInfo

Move dialect-specific and analysis-specific functions out of BufferizationAliasInfo. BufferizationAliasInfo's only job now is to keep track of aliases.

This is in preparation for further decoupling ComprehensiveBufferize from various dialects.

Differential Revision: https://reviews.llvm.org/D112992

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
index 94cb52b4bca5..4decaee4238b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -48,14 +48,6 @@ class BufferizationAliasInfo {
   /// `alias`. Additionally, merge their equivalence classes.
   void insertNewBufferEquivalence(Value newValue, Value alias);
 
-  /// Return true if, under current bufferization decisions, the buffer of
-  /// `value` is not writable.
-  bool aliasesNonWritableBuffer(Value value) const;
-
-  /// Return true if the buffer to which `operand` would bufferize is equivalent
-  /// to some buffer write.
-  bool aliasesInPlaceWrite(Value v) const;
-
   /// Set the inPlace bufferization spec to true.
   /// Merge result's and operand's aliasing sets and iterate to a fixed point.
   void bufferizeInPlace(OpResult result, OpOperand &operand);
@@ -63,23 +55,6 @@ class BufferizationAliasInfo {
   /// Set the inPlace bufferization spec to false.
   void bufferizeOutOfPlace(OpResult result);
 
-  /// Return true if `value` has an ExtractSliceOp matching the given
-  /// InsertSliceOp in its reverse SSA use-def chain.
-  bool hasMatchingExtractSliceOp(Value value,
-                                 tensor::InsertSliceOp insertOp) const;
-
-  /// Return true if bufferizing `opOperand` inplace with `opResult` would
-  /// create a write to a non-writable buffer.
-  bool wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
-                                           OpResult opResult) const;
-
-  /// Assume that result bufferizes in-place with one of the operation's
-  /// operands. Return true if it is possible to find an inplace write W that
-  /// creates a conflict.
-  bool
-  wouldCreateReadAfterWriteInterference(OpOperand &operand, OpResult result,
-                                        const DominanceInfo &domInfo) const;
-
   /// Return true if `v1` and `v2` bufferize to equivalent buffers.
   bool areEquivalentBufferizedValues(Value v1, Value v2) const {
     // Return `false` if we have no information about `v1` or `v2`.
@@ -91,14 +66,13 @@ class BufferizationAliasInfo {
            equivalentInfo.getLeaderValue(v2);
   }
 
-  /// Return true if the source of an `insertSliceOp` bufferizes to an
-  /// equivalent ExtractSliceOp.
-  bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
-      tensor::InsertSliceOp insertSliceOp) const;
-
   /// Apply `fun` to all the members of the equivalence class of `v`.
   void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
 
+  /// Apply `fun` to all aliases of `v`.
+  void applyOnAliases(Value v, function_ref<void(Value)> fun) const;
+
+  // TODO: Move these out of BufferizationAliasInfo.
   /// Return true if the value is known to bufferize to writable memory.
   bool bufferizesToWritableMemory(Value v) const;
 
@@ -128,22 +102,6 @@ class BufferizationAliasInfo {
   /// Check that aliasInfo for `v` exists and return a reference to it.
   EquivalenceClassRangeType getAliases(Value v) const;
 
-  /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
-  /// equivalent operand / result and same offset/sizes/strides specification).
-  ///
-  /// This is one particular type of relationship between ops on tensors that
-  /// reduce to an equivalence on buffers. This should be generalized and
-  /// exposed as interfaces on the proper types.
-  bool areEquivalentExtractSliceOps(tensor::ExtractSliceOp st,
-                                    tensor::InsertSliceOp sti) const;
-
-  /// Given sets of uses and writes, return true if there is a RaW conflict
-  /// under the assumption that all given reads/writes alias the same buffer and
-  /// that all given writes bufferize inplace.
-  bool hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
-                                     const DenseSet<OpOperand *> &usesWrite,
-                                     const DominanceInfo &domInfo) const;
-
   /// Set of tensors that are known to bufferize to writable memory.
   llvm::DenseSet<Value> bufferizeToWritableMemory;
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 63145b157d55..28e432a88ad3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -508,6 +508,24 @@ static BufferRelation bufferRelation(OpOperand &opOperand) {
 // Bufferization-specific alias analysis.
 //===----------------------------------------------------------------------===//
 
+/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
+/// equivalent operand / result and same offset/sizes/strides specification).
+///
+/// This is one particular type of relationship between ops on tensors that
+/// reduce to an equivalence on buffers. This should be generalized and
+/// exposed as interfaces on the proper types.
+static bool
+areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
+                             ExtractSliceOp st, InsertSliceOp sti) {
+  if (!st || !sti)
+    return false;
+  if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
+    return false;
+  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+    return false;
+  return true;
+}
+
 /// Return true if opOperand has been decided to bufferize in-place.
 static bool isInplaceMemoryWrite(OpOperand &opOperand) {
   // Ops that do not bufferize to a memory write, cannot be write in-place.
@@ -567,24 +585,27 @@ void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
 
 /// Return true if, under current bufferization decisions, the buffer of `value`
 /// is not writable.
-bool BufferizationAliasInfo::aliasesNonWritableBuffer(Value value) const {
+static bool aliasesNonWritableBuffer(Value value,
+                                     const BufferizationAliasInfo &aliasInfo) {
   LDBG("----Start aliasesNonWritableBuffer\n");
-  for (Value v : getAliases(value)) {
+  bool foundNonWritableBuffer = false;
+  aliasInfo.applyOnAliases(value, [&](Value v) {
     LDBG("-----------examine: " << printValueInfo(v) << '\n');
-    if (bufferizesToWritableMemory(v)) {
+    if (aliasInfo.bufferizesToWritableMemory(v)) {
       LDBG("-----------Value is known to be writable -> skip: "
            << printValueInfo(v) << '\n');
-      continue;
+      return;
     }
 
     if (auto bbArg = v.dyn_cast<BlockArgument>()) {
       if (getInPlace(bbArg) == InPlaceSpec::True) {
         LDBG("-----------bbArg is writable -> skip: " << printValueInfo(bbArg)
                                                       << '\n');
-        continue;
+        return;
       }
       LDBG("-----------notWritable bbArg\n");
-      return true;
+      foundNonWritableBuffer = true;
+      return;
     }
 
     auto bufferizableOp = dyn_cast<BufferizableOpInterface>(v.getDefiningOp());
@@ -592,11 +613,15 @@ bool BufferizationAliasInfo::aliasesNonWritableBuffer(Value value) const {
       // Unknown ops are treated conservatively: Assume that it is illegal to
       // write to their OpResults in-place.
       LDBG("-----------notWritable op\n");
-      return true;
+      foundNonWritableBuffer = true;
+      return;
     }
-  }
-  LDBG("---->value is writable\n");
-  return false;
+  });
+
+  if (!foundNonWritableBuffer)
+    LDBG("---->value is writable\n");
+
+  return foundNonWritableBuffer;
 }
 
 bool BufferizationAliasInfo::bufferizesToWritableMemory(Value v) const {
@@ -610,20 +635,26 @@ void BufferizationAliasInfo::setBufferizesToWritableMemory(Value v) {
 
 /// Return true if the buffer to which `operand` would bufferize is equivalent
 /// to some buffer write.
-bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const {
+static bool aliasesInPlaceWrite(Value value,
+                                const BufferizationAliasInfo &aliasInfo) {
   LDBG("----Start aliasesInPlaceWrite\n");
   LDBG("-------for : " << printValueInfo(value) << '\n');
-  for (Value v : getAliases(value)) {
+  bool foundInplaceWrite = false;
+  aliasInfo.applyOnAliases(value, [&](Value v) {
     for (auto &use : v.getUses()) {
       if (isInplaceMemoryWrite(use)) {
         LDBG("-----------wants to bufferize to inPlace write: "
              << printOperationInfo(use.getOwner()) << '\n');
-        return true;
+        foundInplaceWrite = true;
+        return;
       }
     }
-  }
-  LDBG("----------->does not alias an inplace write\n");
-  return false;
+  });
+
+  if (!foundInplaceWrite)
+    LDBG("----------->does not alias an inplace write\n");
+
+  return foundInplaceWrite;
 }
 
 /// Set the inPlace bufferization spec to true.
@@ -731,11 +762,11 @@ static Value findLastPrecedingWrite(Value value) {
 
 /// Return true if `value` is originating from an ExtractSliceOp that matches
 /// the given InsertSliceOp.
-bool BufferizationAliasInfo::hasMatchingExtractSliceOp(
-    Value value, InsertSliceOp insertOp) const {
+static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
+                                      Value value, InsertSliceOp insertOp) {
   auto condition = [&](Value val) {
     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentExtractSliceOps(extractOp, insertOp))
+      if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
         return true;
     return false;
   };
@@ -766,10 +797,11 @@ static bool happensBefore(Operation *a, Operation *b,
 /// A conflict is: According to SSA use-def chains, a read R is supposed to read
 /// the result of a write W1. But because of bufferization decisions, R actually
 /// reads another write W2.
-bool BufferizationAliasInfo::hasReadAfterWriteInterference(
-    const DenseSet<OpOperand *> &usesRead,
-    const DenseSet<OpOperand *> &usesWrite,
-    const DominanceInfo &domInfo) const {
+static bool
+hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
+                              const DenseSet<OpOperand *> &usesWrite,
+                              const DominanceInfo &domInfo,
+                              const BufferizationAliasInfo &aliasInfo) {
   for (OpOperand *uRead : usesRead) {
     Operation *readingOp = uRead->getOwner();
 
@@ -850,7 +882,8 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
 
         // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
         if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-            hasMatchingExtractSliceOp(uConflictingWrite->get(), insertSliceOp))
+            hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
+                                      insertSliceOp))
           // Case 1: The main insight is that InsertSliceOp reads only part of
           // the destination tensor. The overwritten area is not read. If
           // uConflictingWrite writes into exactly the memory location that is
@@ -867,7 +900,7 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
 
         if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
             uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-            hasMatchingExtractSliceOp(uRead->get(), insertSliceOp))
+            hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
           // Case 2: The read of the source tensor and the write to the dest
           // tensor via an InsertSliceOp is not a conflict if the read is
           // reading exactly that part of an equivalent tensor that the
@@ -910,8 +943,9 @@ bool BufferizationAliasInfo::hasReadAfterWriteInterference(
 /// * However, adding an alias {%0, %t} would mean that the second
 ///   TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
 ///   would no longer be reading the result of %1.
-bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
-    OpOperand &operand, OpResult result, const DominanceInfo &domInfo) const {
+static bool wouldCreateReadAfterWriteInterference(
+    OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
+    const BufferizationAliasInfo &aliasInfo) {
 #ifndef NDEBUG
   SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
   assert(llvm::find(opOperands, &operand) != opOperands.end() &&
@@ -920,20 +954,22 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
 
   // Helper function to iterate on aliases of `root` and capture the reads.
   auto getAliasingReads = [&](DenseSet<OpOperand *> &res, Value root) {
-    for (Value alias : getAliases(root))
+    aliasInfo.applyOnAliases(root, [&](Value alias) {
       for (auto &use : alias.getUses())
         // Read to a value that aliases root.
         if (bufferizesToMemoryRead(use))
           res.insert(&use);
+    });
   };
 
   // Helper function to iterate on aliases of `root` and capture the writes.
   auto getAliasingInplaceWrites = [&](DenseSet<OpOperand *> &res, Value root) {
-    for (Value alias : getAliases(root))
+    aliasInfo.applyOnAliases(root, [&](Value alias) {
       for (auto &use : alias.getUses())
         // Inplace write to a value that aliases root.
         if (isInplaceMemoryWrite(use))
           res.insert(&use);
+    });
   };
 
   // Collect reads and writes of all aliases of OpOperand and OpResult.
@@ -945,13 +981,14 @@ bool BufferizationAliasInfo::wouldCreateReadAfterWriteInterference(
   if (bufferizesToMemoryWrite(operand))
     usesWrite.insert(&operand);
 
-  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo);
+  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, aliasInfo);
 }
 
 /// Return true if bufferizing `opOperand` inplace with `opResult` would create
 /// a write to a non-writable buffer.
-bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer(
-    OpOperand &opOperand, OpResult opResult) const {
+static bool
+wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
+                                    const BufferizationAliasInfo &aliasInfo) {
 #ifndef NDEBUG
   SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
   assert(llvm::find(opOperands, &opOperand) != opOperands.end() &&
@@ -961,15 +998,15 @@ bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer(
   // Certain buffers are not writeable:
   //   1. A function bbArg that is not inplaceable or
   //   2. A constant op.
-  assert(!aliasesNonWritableBuffer(opResult) &&
+  assert(!aliasesNonWritableBuffer(opResult, aliasInfo) &&
          "expected that opResult does not alias non-writable buffer");
-  bool nonWritable = aliasesNonWritableBuffer(opOperand.get());
+  bool nonWritable = aliasesNonWritableBuffer(opOperand.get(), aliasInfo);
   if (!nonWritable)
     return false;
 
   // This is a problem only if the buffer is written to via some alias.
-  bool hasWrite = aliasesInPlaceWrite(opResult) ||
-                  aliasesInPlaceWrite(opOperand.get()) ||
+  bool hasWrite = aliasesInPlaceWrite(opResult, aliasInfo) ||
+                  aliasesInPlaceWrite(opOperand.get(), aliasInfo) ||
                   bufferizesToMemoryWrite(opOperand);
   if (!hasWrite)
     return false;
@@ -978,28 +1015,6 @@ bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer(
   return true;
 }
 
-/// Return true if the source of a `insertSliceOp` bufferizes to an
-/// equivalent ExtractSliceOp that bufferizes inplace.
-bool BufferizationAliasInfo::isSourceEquivalentToAMatchingInplaceExtractSliceOp(
-    InsertSliceOp insertSliceOp) const {
-  LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp
-                                                              << '\n');
-  auto leaderIt = equivalentInfo.findLeader(insertSliceOp.source());
-  for (auto mit = leaderIt, meit = equivalentInfo.member_end(); mit != meit;
-       ++mit) {
-    auto extractSliceOp =
-        dyn_cast_or_null<ExtractSliceOp>(mit->getDefiningOp());
-    if (extractSliceOp &&
-        areEquivalentExtractSliceOps(extractSliceOp, insertSliceOp) &&
-        getInPlace(extractSliceOp.result()) == InPlaceSpec::True) {
-      LDBG("\tfound: " << *mit->getDefiningOp() << '\n');
-      return true;
-    }
-  }
-  LDBG("\tnot equivalent\n");
-  return false;
-}
-
 /// Apply `fun` to all the members of the equivalence class of `v`.
 void BufferizationAliasInfo::applyOnEquivalenceClass(
     Value v, function_ref<void(Value)> fun) const {
@@ -1010,6 +1025,15 @@ void BufferizationAliasInfo::applyOnEquivalenceClass(
   }
 }
 
+/// Apply `fun` to all aliases of `v`.
+void BufferizationAliasInfo::applyOnAliases(
+    Value v, function_ref<void(Value)> fun) const {
+  auto leaderIt = aliasInfo.findLeader(v);
+  for (auto mit = leaderIt, meit = aliasInfo.member_end(); mit != meit; ++mit) {
+    fun(*mit);
+  }
+}
+
 void BufferizationAliasInfo::printAliases(raw_ostream &os) const {
   os << "\n/===================== AliasInfo =====================\n";
   for (auto it = aliasInfo.begin(), eit = aliasInfo.end(); it != eit; ++it) {
@@ -1066,20 +1090,6 @@ void BufferizationAliasInfo::dumpEquivalences() const {
   printEquivalences(llvm::errs());
 }
 
-/// This is one particular type of relationship between ops on tensors that
-/// reduce to an equivalence on buffers. This should be generalized and exposed
-/// as interfaces on the proper types.
-bool BufferizationAliasInfo::areEquivalentExtractSliceOps(
-    ExtractSliceOp st, InsertSliceOp sti) const {
-  if (!st || !sti)
-    return false;
-  if (!equivalentInfo.isEquivalent(st.source(), sti.dest()))
-    return false;
-  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
-    return false;
-  return true;
-}
-
 //===----------------------------------------------------------------------===//
 // Forward declarations.
 //===----------------------------------------------------------------------===//
@@ -1475,8 +1485,9 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
                                    << printValueInfo(result) << '\n');
 
   bool foundInterference =
-      aliasInfo.wouldCreateWriteToNonWritableBuffer(operand, result) ||
-      aliasInfo.wouldCreateReadAfterWriteInterference(operand, result, domInfo);
+      wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo) ||
+      wouldCreateReadAfterWriteInterference(operand, result, domInfo,
+                                            aliasInfo);
 
   if (foundInterference)
     aliasInfo.bufferizeOutOfPlace(result);
@@ -3276,6 +3287,30 @@ struct ExtractOpInterface
   }
 };
 
+/// Return true if the source of an `insertSliceOp` bufferizes to an
+/// equivalent ExtractSliceOp that bufferizes inplace.
+static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
+    const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
+  LDBG("isSourceEquivalentToAMatchingInplaceExtractSliceOp: " << *insertSliceOp
+                                                              << '\n');
+  bool foundOp = false;
+  aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
+    auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
+    if (extractSliceOp &&
+        areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
+                                     insertSliceOp) &&
+        getInPlace(extractSliceOp.result()) == InPlaceSpec::True) {
+      LDBG("\tfound: " << *extractSliceOp.getOperation() << '\n');
+      foundOp = true;
+    }
+  });
+
+  if (!foundOp)
+    LDBG("\tnot equivalent\n");
+
+  return foundOp;
+}
+
 struct InsertSliceOpInterface
     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                     tensor::InsertSliceOp> {
@@ -3345,8 +3380,8 @@ struct InsertSliceOpInterface
     //     cloned and the clone needs to be updated.
     auto inPlace = getInPlace(insertSliceOp->getResult(0));
     // TODO: Is this necessary?
-    if (!aliasInfo.isSourceEquivalentToAMatchingInplaceExtractSliceOp(
-            insertSliceOp) ||
+    if (!isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo,
+                                                            insertSliceOp) ||
         inPlace != InPlaceSpec::True) {
       LDBG("insert_slice needs extra source copy: " << insertSliceOp.source()
                                                     << " -> copy\n");


        


More information about the Mlir-commits mailing list