[Mlir-commits] [mlir] 26e9042 - [mlir][linalg][bufferize][NFC] Decouple ComprehensiveBufferize from tensor dialect

Matthias Springer llvmlistbot at llvm.org
Wed Nov 17 23:16:07 PST 2021


Author: Matthias Springer
Date: 2021-11-18T16:11:24+09:00
New Revision: 26e90423f4b81de7d4a6011134308c3e454964c0

URL: https://github.com/llvm/llvm-project/commit/26e90423f4b81de7d4a6011134308c3e454964c0
DIFF: https://github.com/llvm/llvm-project/commit/26e90423f4b81de7d4a6011134308c3e454964c0.diff

LOG: [mlir][linalg][bufferize][NFC] Decouple ComprehensiveBufferize from tensor dialect

Add a new BufferizableOpInterface method `isNotConflicting` that can be used to implement custom analysis rules.

Differential Revision: https://reviews.llvm.org/D113961

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index 4742b51623e10..757eca50eb5f0 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -215,6 +215,29 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*defaultImplementation=*/[{
           return false;
         }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return `true` if the `uRead` and `uWrite` do not constitute a RaW
+          conflict. If they are conflicting or if it is unknown whether they are
+          conflicting, return `false`. This method will never be called with
+          OpOperands that do not have a tensor type. At least one of the two
+          given OpOperands belongs to this operation.
+
+          This method can be implemented to specify custom RaW analysis rules.
+          If this method returns `true` the given OpOperands are not considered
+          to be conflicting and do not force out-of-place bufferization. (There
+          may still be other conflicts that do.)
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"isNotConflicting",
+        /*args=*/(ins "OpOperand *":$uRead,
+                      "OpOperand *":$uWrite,
+                      "const BufferizationAliasInfo &":$aliasInfo),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return false;
+        }]
       >
   ];
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 697b894f89908..96fc066e7553e 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -281,24 +281,6 @@ static std::string printValueInfo(Value value, bool prefix) {
 // Bufferization-specific alias analysis.
 //===----------------------------------------------------------------------===//
 
-/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
-/// equivalent operand / result and same offset/sizes/strides specification).
-///
-/// This is one particular type of relationship between ops on tensors that
-/// reduce to an equivalence on buffers. This should be generalized and
-/// exposed as interfaces on the proper types.
-static bool
-areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
-                             ExtractSliceOp st, InsertSliceOp sti) {
-  if (!st || !sti)
-    return false;
-  if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
-    return false;
-  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
-    return false;
-  return true;
-}
-
 /// Return true if opOperand has been decided to bufferize in-place.
 static bool isInplaceMemoryWrite(OpOperand &opOperand,
                                  const BufferizationAliasInfo &aliasInfo) {
@@ -368,21 +350,6 @@ static bool aliasesInPlaceWrite(Value value,
   return foundInplaceWrite;
 }
 
-/// Return true if `value` is originating from an ExtractSliceOp that matches
-/// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
-                                      Value value, InsertSliceOp insertOp) {
-  auto condition = [&](Value val) {
-    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
-        return true;
-    return false;
-  };
-
-  return llvm::all_of(findValueInReverseUseDefChain(value, condition),
-                      condition);
-}
-
 /// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
 /// properly dominates `b` and `b` is not inside `a`.
 static bool happensBefore(Operation *a, Operation *b,
@@ -450,6 +417,21 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
       if (uConflictingWrite == uRead)
         continue;
 
+      // No conflict if the op interface says so.
+      if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(readingOp))
+        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
+                                            aliasInfo))
+          continue;
+
+      if (conflictingWritingOp != readingOp)
+        if (auto bufferizableOp =
+                dyn_cast<BufferizableOpInterface>(conflictingWritingOp))
+          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
+                                              aliasInfo))
+            continue;
+
+      // Special rules for branches.
+      // TODO: Use an interface.
       if (scf::insideMutuallyExclusiveBranches(readingOp, conflictingWritingOp))
         continue;
 
@@ -478,73 +460,6 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
       if (getAliasingOpResult(*uConflictingWrite) == lastWrite)
         continue;
 
-      // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
-      // uRead is an InsertSliceOp...
-      if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
-        // As an example, consider the following IR.
-        //
-        // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace= [true] }
-        // %1 = linalg.fill %cst, %0 {inplace= [true] }
-        // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-        //     {inplace= [true] }
-
-        // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
-        if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-            hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
-                                      insertSliceOp))
-          // Case 1: The main insight is that InsertSliceOp reads only part of
-          // the destination tensor. The overwritten area is not read. If
-          // uConflictingWrite writes into exactly the memory location that is
-          // being read by uRead, this is not a conflict.
-          //
-          // In the above example:
-          // uRead             = OpOperand 1 (%t) of tensor.insert_slice
-          // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
-          //
-          // The read of %t does not conflict with the write of the FillOp
-          // (same aliases!) because the area that the FillOp operates on is
-          // exactly the one that is *not* read via %t.
-          continue;
-
-        if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
-            uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-            hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
-          // Case 2: The read of the source tensor and the write to the dest
-          // tensor via an InsertSliceOp is not a conflict if the read is
-          // reading exactly that part of an equivalent tensor that the
-          // InsertSliceOp is writing.
-          //
-          // In the above example:
-          // uRead             = OpOperand 0 (%1) of tensor.insert_slice
-          // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-          continue;
-      }
-
-      // If uConflictingWrite is an InsertSliceOp...
-      if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
-        // As an example, consider the following IR.
-        //
-        // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace= [true] }
-        // %1 = linalg.fill %cst, %0 {inplace= [true] }
-        // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
-        //     {inplace= [true] }
-        // %3 = vector.transfer_read %1, %cst
-        //
-        // In the above example:
-        // uRead             = OpOperand 0 (%1) of vector.transfer_read
-        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
-        // lastWrite         = %1
-        //
-        // This is not a conflict because the InsertSliceOp overwrites the
-        // memory segment of %1 with the exact same data. (Effectively, there
-        // is no memory write here.)
-        if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-            aliasInfo.areEquivalentBufferizedValues(uRead->get(),
-                                                    insertSliceOp.source()) &&
-            hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
-                                      insertSliceOp))
-          continue;
-
       // All requirements are met. Conflict found!
       LDBG("CONFLICT CONFIRMED!\n\n");
       return true;
@@ -2321,6 +2236,24 @@ struct ExtractOpInterface
   }
 };
 
+/// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e.
+/// equivalent operand / result and same offset/sizes/strides specification).
+///
+/// This is one particular type of relationship between ops on tensors that
+/// reduce to an equivalence on buffers. This should be generalized and
+/// exposed as interfaces on the proper types.
+static bool
+areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
+                             ExtractSliceOp st, InsertSliceOp sti) {
+  if (!st || !sti)
+    return false;
+  if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
+    return false;
+  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+    return false;
+  return true;
+}
+
 /// Return true if the source of a `insertSliceOp` bufferizes to an
 /// equivalent ExtractSliceOp that bufferizes inplace.
 static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
@@ -2345,6 +2278,21 @@ static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
   return foundOp;
 }
 
+/// Return true if `value` is originating from an ExtractSliceOp that matches
+/// the given InsertSliceOp.
+static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
+                                      Value value, InsertSliceOp insertOp) {
+  auto condition = [&](Value val) {
+    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+      if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
+        return true;
+    return false;
+  };
+
+  return llvm::all_of(findValueInReverseUseDefChain(value, condition),
+                      condition);
+}
+
 struct InsertSliceOpInterface
     : public BufferizableOpInterface::ExternalModel<InsertSliceOpInterface,
                                                     tensor::InsertSliceOp> {
@@ -2371,6 +2319,82 @@ struct InsertSliceOpInterface
     return BufferRelation::Equivalent;
   }
 
+  bool isNotConflicting(Operation *op, OpOperand *uRead,
+                        OpOperand *uConflictingWrite,
+                        const BufferizationAliasInfo &aliasInfo) const {
+    Operation *readingOp = uRead->getOwner();
+    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+    // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
+    // uRead is an InsertSliceOp...
+    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(readingOp)) {
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+
+      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
+                                    insertSliceOp))
+        // Case 1: The main insight is that InsertSliceOp reads only part of
+        // the destination tensor. The overwritten area is not read. If
+        // uConflictingWrite writes into exactly the memory location that is
+        // being read by uRead, this is not a conflict.
+        //
+        // In the above example:
+        // uRead             = OpOperand 1 (%t) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+        //
+        // The read of %t does not conflict with the write of the FillOp
+        // (same aliases!) because the area that the FillOp operates on is
+        // exactly the one that is *not* read via %t.
+        return true;
+
+      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
+        // Case 2: The read of the source tensor and the write to the dest
+        // tensor via an InsertSliceOp is not a conflict if the read is
+        // reading exactly that part of an equivalent tensor that the
+        // InsertSliceOp is writing.
+        //
+        // In the above example:
+        // uRead             = OpOperand 0 (%1) of tensor.insert_slice
+        // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+        return true;
+    }
+
+    // If uConflictingWrite is an InsertSliceOp...
+    if (auto insertSliceOp = dyn_cast<InsertSliceOp>(conflictingWritingOp))
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
+      // %1 = linalg.fill %cst, %0 {inplace= [true] }
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //     {inplace= [true] }
+      // %3 = vector.transfer_read %1, %cst
+      //
+      // In the above example:
+      // uRead             = OpOperand 0 (%1) of vector.transfer_read
+      // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      // lastWrite         = %1
+      //
+      // This is not a conflict because the InsertSliceOp overwrites the
+      // memory segment of %1 with the exact same data. (Effectively, there
+      // is no memory write here.)
+      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          aliasInfo.areEquivalentBufferizedValues(uRead->get(),
+                                                  insertSliceOp.source()) &&
+          hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
+                                    insertSliceOp))
+        return true;
+
+    return false;
+  }
+
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BufferizationState &state) const {
     // insert_slice ops arise from tiling and bufferizing them out-of-place is


        


More information about the Mlir-commits mailing list