[Mlir-commits] [mlir] 0e5f258 - [mlir][linalg][bufferize][NFC] Simplify InsertSliceOp bufferization

Matthias Springer llvmlistbot at llvm.org
Thu Jan 6 00:40:51 PST 2022


Author: Matthias Springer
Date: 2022-01-06T17:35:45+09:00
New Revision: 0e5f258452b053cc3374754efaeabe3c30f42482

URL: https://github.com/llvm/llvm-project/commit/0e5f258452b053cc3374754efaeabe3c30f42482
DIFF: https://github.com/llvm/llvm-project/commit/0e5f258452b053cc3374754efaeabe3c30f42482.diff

LOG: [mlir][linalg][bufferize][NFC] Simplify InsertSliceOp bufferization

No need to keep track of equivalent extract_slice / insert_slice tensors during bufferization. Just emit a copy; it will fold away.

Note: The analysis still keeps track of equivalent tensors to make the correct inplace bufferization decisions.

Differential Revision: https://reviews.llvm.org/D116684

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
index ca620138d6433..29355ef338f3a 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h
@@ -9,8 +9,6 @@
 #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
 #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_TENSOR_INTERFACE_IMPL_H
 
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
-
 namespace mlir {
 
 class DialectRegistry;
@@ -19,12 +17,6 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace tensor_ext {
 
-struct InplaceInsertSliceOpAnalysis : public PostAnalysisStep {
-  LogicalResult run(Operation *op, BufferizationState &state,
-                    BufferizationAliasInfo &aliasInfo,
-                    SmallVector<Operation *> &newOps) override;
-};
-
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
 
 } // namespace tensor_ext

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 0f91e52a5227e..9ee1d23d5d8af 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -23,20 +23,6 @@ namespace tensor_ext {
 using tensor::ExtractSliceOp;
 using tensor::InsertSliceOp;
 
-namespace {
-/// Extra bufferization state that is required for bufferization of tensor ops.
-struct TensorBufferizationState : public DialectBufferizationState {
-  /// InsertSliceOps that bufferize inplace and do not require a copy.
-  DenseSet<Operation *> insertSliceOpsWithoutCopy;
-};
-} // namespace
-
-static TensorBufferizationState &
-getTensorBufferizationState(BufferizationState &state) {
-  return state.getDialectState<TensorBufferizationState>(
-      tensor::TensorDialect::getDialectNamespace());
-}
-
 struct CastOpInterface
     : public BufferizableOpInterface::ExternalModel<CastOpInterface,
                                                     tensor::CastOp> {
@@ -274,23 +260,6 @@ areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
   return true;
 }
 
-/// Return true if the source of a `insertSliceOp` bufferizes to an
-/// equivalent ExtractSliceOp that bufferizes inplace.
-static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
-    const BufferizationAliasInfo &aliasInfo, InsertSliceOp insertSliceOp) {
-  bool foundOp = false;
-  aliasInfo.applyOnEquivalenceClass(insertSliceOp.source(), [&](Value value) {
-    auto extractSliceOp = value.getDefiningOp<ExtractSliceOp>();
-    if (extractSliceOp &&
-        areEquivalentExtractSliceOps(aliasInfo, extractSliceOp,
-                                     insertSliceOp) &&
-        aliasInfo.isInPlace(extractSliceOp->getResult(0))) {
-      foundOp = true;
-    }
-  });
-  return foundOp;
-}
-
 /// Return true if `value` is originating from an ExtractSliceOp that matches
 /// the given InsertSliceOp.
 static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
@@ -419,7 +388,6 @@ struct InsertSliceOpInterface
     // TODO: be very loud about it or even consider failing the pass.
     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
     Location loc = insertSliceOp.getLoc();
-    TensorBufferizationState &tensorState = getTensorBufferizationState(state);
 
     // When bufferizing out-of-place, `getResultBuffer` allocates.
     Value dstMemref =
@@ -427,24 +395,22 @@ struct InsertSliceOpInterface
     if (!dstMemref)
       return failure();
 
-    bool needCopy =
-        !tensorState.insertSliceOpsWithoutCopy.contains(insertSliceOp);
-    if (needCopy) {
-      // Take a subview of the dst.
-      auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
-      auto subviewMemRefType =
-          memref::SubViewOp::inferRankReducedResultType(
-              insertSliceOp.getSourceType().getRank(), dstMemrefType,
-              insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
-              insertSliceOp.getMixedStrides())
-              .cast<MemRefType>();
-      Value subView = rewriter.create<memref::SubViewOp>(
-          loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
-          insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
-      // Copy tensor.
-      Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
-      state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);
-    }
+    // Take a subview of the dst.
+    auto dstMemrefType = dstMemref.getType().cast<MemRefType>();
+    auto subviewMemRefType =
+        memref::SubViewOp::inferRankReducedResultType(
+            insertSliceOp.getSourceType().getRank(), dstMemrefType,
+            insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+            insertSliceOp.getMixedStrides())
+            .cast<MemRefType>();
+    Value subView = rewriter.create<memref::SubViewOp>(
+        loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
+        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+
+    // Copy tensor. If this tensor.insert_slice has a matching
+    // tensor.extract_slice, the copy operation will eventually fold away.
+    Value srcMemref = state.lookupBuffer(rewriter, insertSliceOp.source());
+    state.createMemCpy(rewriter, insertSliceOp.getLoc(), srcMemref, subView);
 
     state.replaceOp(rewriter, op, dstMemref);
     return success();
@@ -456,25 +422,6 @@ struct InsertSliceOpInterface
 } // namespace linalg
 } // namespace mlir
 
-LogicalResult mlir::linalg::comprehensive_bufferize::tensor_ext::
-    InplaceInsertSliceOpAnalysis::run(Operation *op, BufferizationState &state,
-                                      BufferizationAliasInfo &aliasInfo,
-                                      SmallVector<Operation *> &newOps) {
-  auto &tensorState = getTensorBufferizationState(state);
-  op->walk([&](InsertSliceOp insertSliceOp) {
-    // A copy of the source buffer is needed if either:
-    //   - The producer of `source` is not inplace. This is the case where a
-    //     slice is computed out of place into the inplace full tensor.
-    //   - The result is not inplace. This is the case where the whole tensor is
-    //     cloned and the clone needs to be updated.
-    if (isSourceEquivalentToAMatchingInplaceExtractSliceOp(aliasInfo,
-                                                           insertSliceOp) &&
-        state.isInPlace(insertSliceOp->getResult(0)))
-      tensorState.insertSliceOpsWithoutCopy.insert(insertSliceOp);
-  });
-  return success();
-}
-
 void mlir::linalg::comprehensive_bufferize::tensor_ext::
     registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
   registry.addOpInterface<tensor::CastOp, tensor_ext::CastOpInterface>();

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index c5fdf402d9412..13e18001d82ee 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -94,9 +94,6 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
   // Enable InitTensorOp elimination.
   options->addPostAnalysisStep<
       linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
-  // TODO: Find a way to enable this step automatically when bufferizing tensor
-  // dialect ops.
-  options->addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
   if (!allowReturnMemref)
     options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
 

diff  --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index f4a43aab1ebb9..59e53eaba56a9 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -99,9 +99,6 @@ void TestComprehensiveFunctionBufferize::runOnFunction() {
   // Enable InitTensorOp elimination.
   options->addPostAnalysisStep<
       linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
-  // TODO: Find a way to enable this step automatically when bufferizing
-  // tensor dialect ops.
-  options->addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
   if (!allowReturnMemref)
     options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
 


        


More information about the Mlir-commits mailing list