[Mlir-commits] [mlir] 8f61191 - [mlir][linalg][bufferize] Add mustBufferizeInPlace to op interface

Matthias Springer llvmlistbot at llvm.org
Wed Nov 10 02:33:30 PST 2021


Author: Matthias Springer
Date: 2021-11-10T19:33:11+09:00
New Revision: 8f6119128f28c2e8a5a92ae230b9af32861e6c87

URL: https://github.com/llvm/llvm-project/commit/8f6119128f28c2e8a5a92ae230b9af32861e6c87
DIFF: https://github.com/llvm/llvm-project/commit/8f6119128f28c2e8a5a92ae230b9af32861e6c87.diff

LOG: [mlir][linalg][bufferize] Add mustBufferizeInPlace to op interface

This is useful for ops such as scf::IfOp, whose results always bufferize in-place.

This commit is in preparation for decoupling BufferizationAliasInfo from the SCF dialect.

Differential Revision: https://reviews.llvm.org/D113339

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index c2dd8b1321b4..04f2f133dff4 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -89,6 +89,25 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
                 });
           }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return `true` if the given OpResult must bufferize in-place with its
+          corresponding aliasing OpOperand. Alias sets and inplace attributes
+          will be set up accordingly before making any other bufferization
+          decisions. This method will never be called on OpResults that do not
+          have a tensor type.
+
+          Note: This method may not return `true` if the given OpResult does not
+          have an aliasing OpOperand.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"mustBufferizeInPlace",
+        /*args=*/(ins "OpResult":$opResult),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return false;
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Return the OpResult that aliases with a given OpOperand when

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index d285e530c204..675686fead27 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -538,18 +538,20 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
             createAliasInfoEntry(bbArg);
   });
 
-  // The return value of an scf::IfOp aliases with both yield values.
-  rootOp->walk([&](scf::IfOp ifOp) {
-    if (ifOp->getNumResults() > 0) {
-      for (auto it : llvm::zip(ifOp.thenYield().results(),
-                               ifOp.elseYield().results(), ifOp.results())) {
-        aliasInfo.unionSets(std::get<0>(it), std::get<1>(it));
-        aliasInfo.unionSets(std::get<0>(it), std::get<2>(it));
-      }
-
-      // scf::IfOp always bufferizes in-place.
-      for (OpResult opResult : ifOp->getResults())
-        setInPlaceOpResult(opResult, InPlaceSpec::True);
+  // Set up alias sets for OpResults that must bufferize in-place. This should
+  // be done before making any other bufferization decisions.
+  rootOp->walk([&](BufferizableOpInterface bufferizableOp) {
+    for (OpResult opResult : bufferizableOp->getOpResults()) {
+      if (opResult.getType().isa<TensorType>())
+        if (bufferizableOp.mustBufferizeInPlace(opResult)) {
+          SmallVector<OpOperand *> operands =
+              bufferizableOp.getAliasingOpOperand(opResult);
+          assert(!operands.empty() &&
+                 "expected that OpResult has aliasing OpOperand");
+          for (OpOperand *operand : operands)
+            aliasInfo.unionSets(operand->get(), opResult);
+          setInPlaceOpResult(opResult, InPlaceSpec::True);
+        }
     }
   });
 }
@@ -951,9 +953,14 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
 /// * However, adding an alias {%0, %t} would mean that the second
 ///   TransferWriteOp overwrites the first one. Therefore, the TransferReadOp
 ///   would no longer be reading the result of %1.
+///
+/// If `checkConsistencyOnly` is true, this function checks if there is a
+/// read-after-write conflict without bufferizing `operand` inplace. This would
+/// indicate a problem with the current inplace bufferization decisions.
 bool wouldCreateReadAfterWriteInterference(
     OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
-    const BufferizationAliasInfo &aliasInfo) {
+    const BufferizationAliasInfo &aliasInfo,
+    bool checkConsistencyOnly = false) {
 #ifndef NDEBUG
   SmallVector<OpOperand *> opOperands = getAliasingOpOperand(result);
   assert(llvm::find(opOperands, &operand) != opOperands.end() &&
@@ -986,7 +993,7 @@ bool wouldCreateReadAfterWriteInterference(
   getAliasingReads(usesRead, result);
   getAliasingInplaceWrites(usesWrite, operand.get());
   getAliasingInplaceWrites(usesWrite, result);
-  if (bufferizesToMemoryWrite(operand))
+  if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
     usesWrite.insert(&operand);
 
   return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, aliasInfo);
@@ -2229,6 +2236,24 @@ LogicalResult mlir::linalg::eliminateInsertSliceAnchoredInitTensorOps(
       });
 }
 
+/// Assert that the current bufferization decisions are consistent.
+static void checkAliasInfoConsistency(FuncOp funcOp,
+                                      const DominanceInfo &domInfo,
+                                      const BufferizationAliasInfo &aliasInfo) {
+  funcOp.walk([&](Operation *op) {
+    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+      for (OpOperand &opOperand : op->getOpOperands())
+        if (opOperand.get().getType().isa<TensorType>())
+          if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
+            // If this assertion fails, there is probably an inconsistent
+            // combination of "mustBufferizeInPlace" decisions.
+            assert(!wouldCreateReadAfterWriteInterference(
+                       opOperand, opResult, domInfo, aliasInfo,
+                       /*checkConsistencyOnly=*/true) &&
+                   "found read after write conflict before running analysis");
+  });
+}
+
 LogicalResult
 mlir::linalg::runComprehensiveBufferize(ModuleOp moduleOp,
                                         const BufferizationOptions &options) {
@@ -2240,6 +2265,7 @@ mlir::linalg::runComprehensiveBufferize(ModuleOp moduleOp,
 
   DominanceInfo domInfo(moduleOp);
   BufferizationAliasInfo aliasInfo(moduleOp);
+
   // Interestingly, all function args that are not visible outside of a module
   // can be fully bufferized inplace by guaranteeing the CallOp is bufferized
   // inplace. Therefore, we just bufferize funcOp as if none of its results were
@@ -2260,6 +2286,10 @@ mlir::linalg::runComprehensiveBufferize(ModuleOp moduleOp,
         if (bbArg.getType().isa<TensorType>())
           setInPlaceFuncArgument(bbArg);
 
+#ifndef NDEBUG
+    checkAliasInfoConsistency(funcOp, domInfo, aliasInfo);
+#endif // NDEBUG
+
     // If the analysis fails, just return.
     if (failed(inPlaceAnalysisFuncOpBody(funcOp, aliasInfo, domInfo,
                                          options.analysisFuzzerSeed)))
@@ -2778,6 +2808,12 @@ struct IfOpInterface
     return true;
   }
 
+  bool mustBufferizeInPlace(Operation *op, OpResult opResult) const {
+    // IfOp results always bufferize in-place. Since they have no OpOperands,
+    // they are mostly ignored by the analysis once alias sets are set up.
+    return true;
+  }
+
   LogicalResult bufferize(Operation *op, OpBuilder &b,
                           BlockAndValueMapping &bvm,
                           BufferizationAliasInfo &aliasInfo,


        


More information about the Mlir-commits mailing list