[Mlir-commits] [mlir] bb83520 - [mlir][linalg][bufferize] Generalize InitTensorOp elimination

Matthias Springer llvmlistbot at llvm.org
Wed Nov 3 21:59:05 PDT 2021


Author: Matthias Springer
Date: 2021-11-04T13:53:12+09:00
New Revision: bb83520dce13b9804b3282f29a8ff0886384a362

URL: https://github.com/llvm/llvm-project/commit/bb83520dce13b9804b3282f29a8ff0886384a362
DIFF: https://github.com/llvm/llvm-project/commit/bb83520dce13b9804b3282f29a8ff0886384a362.diff

LOG: [mlir][linalg][bufferize] Generalize InitTensorOp elimination

This allows external users of Comprehensive Bufferize to specify their own InitTensorOp elimination procedures.

Differential Revision: https://reviews.llvm.org/D112686
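
For illustration, a minimal sketch of how a downstream user might plug a custom elimination procedure into the new `initTensorElimination` entry point. This sketch is not part of the commit; the `isCustomAnchor` predicate and the `buildCustomReplacement` helper are hypothetical and only show the shapes of the two callbacks declared in the header below.

    // Hypothetical downstream driver, assuming the surrounding MLIR headers
    // are included and `using namespace mlir;` is in effect.
    LogicalResult runCustomInitTensorElimination(
        FuncOp funcOp, linalg::BufferizationAliasInfo &aliasInfo,
        DominanceInfo &domInfo) {
      return linalg::initTensorElimination(
          funcOp, aliasInfo, domInfo,
          /*anchorMatchFunc=*/
          [](OpOperand &operand) {
            // Decide whether this OpOperand anchors an InitTensorOp that the
            // downstream user knows how to replace (hypothetical predicate).
            return isCustomAnchor(operand);
          },
          /*rewriteFunc=*/
          [](OpBuilder &b, Location loc, OpOperand &operand) -> Value {
            // Build the replacement value for the InitTensorOp, e.g. a slice
            // of an existing destination tensor (hypothetical helper).
            // Returning a null Value skips the rewrite for this operand.
            return buildCustomReplacement(b, loc, operand);
          },
          /*skipAnalysis=*/false);
    }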

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
index b7ce96d9fedf3..e3b59d5daa60d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h
@@ -195,6 +195,29 @@ bufferizeOp(Operation *op, BlockAndValueMapping &bvm,
 
 /// Register external models implemented for the `BufferizableOpInterface`.
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+
+/// Try to eliminate InitTensorOps inside `funcOp`.
+///
+/// * `rewriteFunc` generates the replacement for the InitTensorOp.
+/// * Only InitTensorOps that are anchored on a matching OpOperand as per
+///   `anchorMatchFunc` are considered. "Anchored" means that there is a path on
+///   the reverse SSA use-def chain, starting from the OpOperand and always
+///   following the aliasing OpOperand, that eventually ends at a single
+///   InitTensorOp.
+/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
+///   This analysis can be skipped with `skipAnalysis`.
+LogicalResult initTensorElimination(
+    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
+    std::function<bool(OpOperand &)> anchorMatchFunc,
+    std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
+    bool skipAnalysis = false);
+
+/// Try to eliminate InitTensorOps inside funcOp that are anchored on an
+/// InsertSliceOp, i.e., InitTensorOps whose value is eventually inserted into
+/// another tensor (and some other conditions are met).
+LogicalResult eliminateInsertSliceAnchoredInitTensorOps(
+    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo);
+
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
index 1fe5835d282bf..4fbde63c4d108 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -2150,6 +2150,78 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
   }
 }
 
+/// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp is replaced
+/// with the result of `rewriteFunc` if it is anchored on a matching
+/// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
+/// chain, starting from the OpOperand and always following the aliasing
+/// OpOperand, that eventually ends at a single InitTensorOp.
+LogicalResult mlir::linalg::initTensorElimination(
+    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
+    std::function<bool(OpOperand &)> anchorMatchFunc,
+    std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
+    bool skipAnalysis) {
+  OpBuilder b(funcOp->getContext());
+
+  WalkResult status = funcOp->walk([&](Operation *op) {
+    for (OpOperand &operand : op->getOpOperands()) {
+      // Is this a matching OpOperand?
+      if (!anchorMatchFunc(operand))
+        continue;
+
+      SetVector<Value> maybeInitTensor =
+          findValueInReverseUseDefChain(operand.get(), [](Value val) {
+            // Continue traversal until this function returns true.
+            OpResult opResult = val.dyn_cast<OpResult>();
+            if (!opResult)
+              return true;
+            if (getInPlace(opResult) != InPlaceSpec::True)
+              return true;
+            // Only equivalent tensors are supported at the moment.
+            // TODO: Support cases such as extract_slice(init_tensor).
+            SmallVector<OpOperand *> opOperands =
+                getAliasingOpOperand(opResult);
+            if (!llvm::all_of(opOperands, [](OpOperand *operand) {
+                  return bufferRelation(*operand) == BufferRelation::Equivalent;
+                }))
+              return true;
+            return false;
+          });
+
+      // Replace only if the reverse use-def chain ends at exactly one
+      // InitTensorOp.
+      if (maybeInitTensor.size() != 1 ||
+          !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
+        return WalkResult::skip();
+      Value initTensor = maybeInitTensor.front();
+
+      // Create a replacement for the InitTensorOp.
+      b.setInsertionPoint(initTensor.getDefiningOp());
+      Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
+      if (!replacement)
+        continue;
+
+      // Uses of the InitTensorOp are replaced here, but the op is not deleted.
+      // InitTensorOps without uses are ignored by the bufferization.
+      initTensor.replaceAllUsesWith(replacement);
+      aliasInfo.createAliasInfoEntry(replacement);
+
+      // Run analysis on the newly created op.
+      if (auto opResult = replacement.dyn_cast<OpResult>()) {
+        if (!skipAnalysis) {
+          SmallVector<Operation *> ops(1, replacement.getDefiningOp());
+          if (failed(inPlaceAnalysis(ops, aliasInfo, domInfo)))
+            return WalkResult::interrupt();
+        }
+      }
+    }
+
+    // Advance to the next operation.
+    return WalkResult::advance();
+  });
+
+  return failure(status.wasInterrupted());
+}
+
 /// Try to eliminate InitTensorOps inside funcOp. An InitTensorOp can be
 /// eliminated if it is eventually inserted into another tensor (and some other
 /// conditions are met).
@@ -2178,60 +2250,26 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
 ///
 /// Note that the newly inserted ExtractSliceOp may have to bufferize
 /// out-of-place due to RaW conflicts.
-static LogicalResult runInitTensorElimination(FuncOp funcOp,
-                                              BufferizationAliasInfo &aliasInfo,
-                                              DominanceInfo &domInfo) {
-  OpBuilder b(funcOp->getContext());
-
-  WalkResult status = funcOp->walk([&](tensor::InsertSliceOp insertOp) {
-    // Only inplace bufferized InsertSliceOps are eligible.
-    if (getInPlace(insertOp->getOpResult(0)) != InPlaceSpec::True)
-      return WalkResult::skip();
-
-    SetVector<Value> maybeInitTensor =
-        findValueInReverseUseDefChain(insertOp.source(), [](Value val) {
-          // Continue traversal until this function returns true.
-          OpResult opResult = val.dyn_cast<OpResult>();
-          if (!opResult)
-            return true;
-          if (getInPlace(opResult) != InPlaceSpec::True)
-            return true;
-          // Only equivalent tensors are supported at the moment. E.g., when
-          // taking a tensor.extract_slice of an init_tensor, we can currently
-          // not eliminate the init_tensor.
-          SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
-          if (!llvm::all_of(opOperands, [](OpOperand *operand) {
-                return bufferRelation(*operand) == BufferRelation::Equivalent;
-              }))
-            return true;
+LogicalResult mlir::linalg::eliminateInsertSliceAnchoredInitTensorOps(
+    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo) {
+  return initTensorElimination(
+      funcOp, aliasInfo, domInfo,
+      [](OpOperand &operand) {
+        auto insertSliceOp = dyn_cast<InsertSliceOp>(operand.getOwner());
+        if (!insertSliceOp)
           return false;
-        });
-    // Replace only if the InsertSliceOp source originates from exactly one
-    // InitTensorOp.
-    if (maybeInitTensor.size() != 1 ||
-        !maybeInitTensor.front().getDefiningOp<InitTensorOp>())
-      return WalkResult::skip();
-    Value initTensor = maybeInitTensor.front();
-
-    b.setInsertionPoint(initTensor.getDefiningOp());
-    auto extractOp = b.create<tensor::ExtractSliceOp>(
-        initTensor.getLoc(), insertOp.dest(), insertOp.getMixedOffsets(),
-        insertOp.getMixedSizes(), insertOp.getMixedStrides());
-    // Uses of the InitTensorOp are replaced here, but the op is not deleted.
-    // InitTensorOps without uses are ignored by the bufferization.
-    initTensor.replaceAllUsesWith(extractOp.result());
-    aliasInfo.createAliasInfoEntry(extractOp.result());
-
-    // Run analysis on the ExtractSliceOp.
-    if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(
-            extractOp->getOpOperand(0), aliasInfo, domInfo)))
-      return WalkResult::interrupt();
-
-    // Advance to the next operation.
-    return WalkResult::advance();
-  });
-
-  return failure(status.wasInterrupted());
+        // Only inplace bufferized InsertSliceOps are eligible.
+        if (getInPlace(insertSliceOp->getOpResult(0)) != InPlaceSpec::True)
+          return false;
+        return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
+      },
+      [](OpBuilder &b, Location loc, OpOperand &operand) {
+        auto insertSliceOp = cast<InsertSliceOp>(operand.getOwner());
+        auto extractOp = b.create<tensor::ExtractSliceOp>(
+            loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(),
+            insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+        return extractOp.result();
+      });
 }
 
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
@@ -2291,7 +2329,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
 
     // Try to eliminate InitTensorOps to avoid new allocations during the
     // bufferization phase.
-    if (failed(runInitTensorElimination(funcOp, aliasInfo, domInfo))) {
+    if (failed(eliminateInsertSliceAnchoredInitTensorOps(funcOp, aliasInfo,
+                                                         domInfo))) {
       signalPassFailure();
       return;
     }


        


More information about the Mlir-commits mailing list