[Mlir-commits] [mlir] 5f60c48 - [mlir][linalg][bufferize][NFC] Make init_tensor elimination a separate pre-processing pass

Matthias Springer <llvmlistbot at llvm.org>
Wed May 4 01:19:28 PDT 2022


Author: Matthias Springer
Date: 2022-05-04T17:17:27+09:00
New Revision: 5f60c4825b351189379f147ad636f19fa5060e5c

URL: https://github.com/llvm/llvm-project/commit/5f60c4825b351189379f147ad636f19fa5060e5c
DIFF: https://github.com/llvm/llvm-project/commit/5f60c4825b351189379f147ad636f19fa5060e5c.diff

LOG: [mlir][linalg][bufferize][NFC] Make init_tensor elimination a separate pre-processing pass

This commit decouples init_tensor elimination from the rest of the bufferization:
insert_slice-anchored init_tensor ops are now eliminated by a standalone
pre-processing pass that runs before One-Shot Bufferize (see the pipeline sketch
below).

Differential Revision: https://reviews.llvm.org/D124853
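
As a rough illustration of the new flow (a minimal sketch, not part of this
patch; the helper name buildBufferizePipeline is illustrative, and it assumes
the standard pass headers and the existing
bufferization::createOneShotBufferizePass() constructor), init_tensor
elimination can now be scheduled as its own pass ahead of bufferization,
mirroring the updated RUN lines in the tests below:

    #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
    #include "mlir/Dialect/Linalg/Passes.h"
    #include "mlir/Pass/PassManager.h"

    // Sketch: run init_tensor elimination as a separate pre-processing pass,
    // then bufferize with One-Shot Bufferize. This is roughly the C++
    // analogue of "-linalg-eliminate-init-tensors -one-shot-bufferize".
    void buildBufferizePipeline(mlir::OpPassManager &pm) {
      // New standalone pre-processing pass added by this patch.
      pm.addPass(mlir::createLinalgInitTensorEliminationPass());
      // One-Shot Bufferize; options such as bufferize-function-boundaries and
      // allow-return-allocs can be set via OneShotBufferizationOptions.
      pm.addPass(mlir::bufferization::createOneShotBufferizePass());
    }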

Added: 
    mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir
    mlir/test/Dialect/Linalg/one-shot-bufferize-init-tensor-elimination.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.h
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h
    mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Removed: 
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 73226b2587a54..3510b2f1f984a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -73,6 +73,10 @@ std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass();
 std::unique_ptr<Pass> createLinalgComprehensiveModuleBufferizePass(
     const bufferization::OneShotBufferizationOptions &options);
 
+/// Create a pass that tries to eliminate init_tensor ops that are anchored on
+/// insert_slice ops.
+std::unique_ptr<Pass> createLinalgInitTensorEliminationPass();
+
 /// Create a pass to convert Linalg operations which work on tensors to use
 /// buffers instead.
 std::unique_ptr<OperationPass<func::FuncOp>> createLinalgBufferizePass();

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 06f0e217986d7..2c0287de0fcac 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -62,10 +62,6 @@ def LinalgComprehensiveModuleBufferize :
     Option<"analysisFuzzerSeed", "analysis-fuzzer-seed", "unsigned",
            /*default=*/"0",
            "Analyze ops in random order with a given seed (fuzzer)">,
-    Option<"initTensorElimination", "init-tensor-elimination", "bool",
-            /*default=*/"false",
-           "(Experimental) Try to eliminate init_tensor operations that are "
-           "anchored at an insert_slice op">,
     Option<"createDeallocs", "create-deallocs", "bool", /*default=*/"true",
            "Specify if buffers should be deallocated. For compatibility with "
            "core bufferization passes.">,
@@ -73,6 +69,18 @@ def LinalgComprehensiveModuleBufferize :
   let constructor = "mlir::createLinalgComprehensiveModuleBufferizePass()";
 }
 
+def LinalgInitTensorElimination : Pass<"linalg-eliminate-init-tensors"> {
+  let summary = "Try to eliminate all init_tensor ops.";
+  let description = [{
+    This pass tries to eliminate all insert_slice op-anchored init_tensor ops.
+    I.e., when a value that aliases with an init_tensor op is inserted into
+    another tensor, this pass tries to rewrite the IR such that the
+    destination tensor of the insert_slice op is used directly instead of the
+    init_tensor result.
+  }];
+  let constructor = "mlir::createLinalgInitTensorEliminationPass()";
+}
+
 def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {
   let summary = "Remove unit-extent dimension in Linalg ops on tensors";
   let constructor = "mlir::createLinalgFoldUnitExtentDimsPass()";

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h
index 64c3232e86372..8fc406ceef80c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h
@@ -33,21 +33,16 @@ using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
 ///   on the reverse SSA use-def chain, starting from the OpOperand and always
 ///   following the aliasing  OpOperand, that eventually ends at a single
 ///   InitTensorOp.
-/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
-///   This analysis can be skipped with `skipAnalysis`.
-LogicalResult
-eliminateInitTensors(Operation *op, bufferization::AnalysisState &state,
-                     bufferization::BufferizationAliasInfo &aliasInfo,
-                     AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
-                     SmallVector<Operation *> &newOps);
+LogicalResult eliminateInitTensors(RewriterBase &rewriter, Operation *op,
+                                   bufferization::AnalysisState &state,
+                                   AnchorMatchFn anchorMatchFunc,
+                                   RewriteFn rewriteFunc);
 
 /// Try to eliminate InitTensorOps inside `op` that are anchored on an
 /// InsertSliceOp, i.e., if it is eventually inserted into another tensor
 /// (and some other conditions are met).
 LogicalResult insertSliceAnchoredInitTensorEliminationStep(
-    Operation *op, bufferization::AnalysisState &state,
-    bufferization::BufferizationAliasInfo &aliasInfo,
-    SmallVector<Operation *> &newOps);
+    RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state);
 
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
 

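For out-of-tree callers of this entry point, a minimal sketch of driving the new
rewriter-based signature (it mirrors the runOnOperation body of the new
standalone pass further down in this patch; the wrapper name
runInitTensorElimination is illustrative only):

    #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
    #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    // Run One-Shot analysis on `op`, then eliminate insert_slice-anchored
    // init_tensor ops via the updated rewriter-based API.
    static LogicalResult runInitTensorElimination(Operation *op) {
      bufferization::OneShotBufferizationOptions options;
      bufferization::OneShotAnalysisState state(op, options);
      if (failed(bufferization::analyzeOp(op, state)))
        return failure();
      IRRewriter rewriter(op->getContext());
      return linalg::insertSliceAnchoredInitTensorEliminationStep(rewriter, op,
                                                                  state);
    }
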
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 6146175debb7d..11fea3adb76c0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -334,16 +334,17 @@ findValidInsertionPoint(Operation *initTensorOp,
 /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
 /// chain, starting from the OpOperand and always following the aliasing
 /// OpOperand, that eventually ends at a single InitTensorOp.
-LogicalResult mlir::linalg::eliminateInitTensors(
-    Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo,
-    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
-    SmallVector<Operation *> &newOps) {
-  OpBuilder b(op->getContext());
+LogicalResult mlir::linalg::eliminateInitTensors(RewriterBase &rewriter,
+                                                 Operation *op,
+                                                 AnalysisState &state,
+                                                 AnchorMatchFn anchorMatchFunc,
+                                                 RewriteFn rewriteFunc) {
+  OpBuilder::InsertionGuard g(rewriter);
 
   WalkResult status = op->walk([&](Operation *op) {
     for (OpOperand &operand : op->getOpOperands()) {
       // Skip operands that do not bufferize inplace.
-      if (!aliasInfo.isInPlace(operand))
+      if (!state.isInPlace(operand))
         continue;
       // All values that are needed to create the replacement op.
       SmallVector<Value> neededValues;
@@ -359,14 +360,14 @@ LogicalResult mlir::linalg::eliminateInitTensors(
             SmallVector<OpOperand *> opOperands =
                 state.getAliasingOpOperand(opResult);
             if (!llvm::all_of(opOperands, [&](OpOperand *operand) {
-                  return aliasInfo.isInPlace(*operand);
+                  return state.isInPlace(*operand);
                 }))
               return true;
             // Only equivalent tensors are supported at the moment.
             // TODO: Support cases such as extract_slice(init_tensor)
             return !llvm::all_of(opOperands, [&](OpOperand *operand) {
-              return aliasInfo.areEquivalentBufferizedValues(operand->get(),
-                                                             opResult);
+              return state.areEquivalentBufferizedValues(operand->get(),
+                                                         opResult);
             });
           });
 
@@ -384,21 +385,13 @@ LogicalResult mlir::linalg::eliminateInitTensors(
         continue;
 
       // Create a replacement for the InitTensorOp.
-      b.setInsertionPoint(insertionPoint);
-      Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
+      rewriter.setInsertionPoint(insertionPoint);
+      Value replacement = rewriteFunc(rewriter, initTensor.getLoc(), operand);
       if (!replacement)
         continue;
 
-      // Uses of the InitTensorOp are replaced here, but the op is not deleted.
-      // InitTensorOps without uses are ignored by the bufferization.
-      initTensor.replaceAllUsesWith(replacement);
-      aliasInfo.createAliasInfoEntry(replacement);
-      aliasInfo.unionAliasSets(initTensor, replacement);
-      aliasInfo.unionEquivalenceClasses(initTensor, replacement);
-
-      // Register replacement ops.
-      if (Operation *newOp = replacement.getDefiningOp())
-        newOps.push_back(newOp);
+      // Replace the InitTensorOp.
+      rewriter.replaceOp(initTensor.getDefiningOp(), replacement);
     }
 
     // Advance to the next operation.
@@ -428,28 +421,20 @@ LogicalResult mlir::linalg::eliminateInitTensors(
 ///
 /// Starting from an InsertSliceOp, an InitTensorOp at the end of the insert
 /// source's reverse use-def chain is eliminated if:
-/// * The InsertSliceOp was decided to bufferize inplace.
 /// * On the reverse use-def chain path from the InsertSliceOp to the
 ///   InitTensorOp, all ops were decided to bufferize inplace and the buffer
 ///   relation is "equivalent" (TODO: can be relaxed if needed).
 /// * The reverse use-def chain has exactly one end, which is the InitTensorOp.
-///
-/// Note that the newly inserted ExtractSliceOp may have to bufferize
-/// out-of-place due to RaW conflicts.
 LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
-    Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo,
-    SmallVector<Operation *> &newOps) {
+    RewriterBase &rewriter, Operation *op, AnalysisState &state) {
   return eliminateInitTensors(
-      op, state, aliasInfo,
+      rewriter, op, state,
       /*anchorMatchFunc=*/
       [&](OpOperand &operand, SmallVector<Value> &neededValues) {
         auto insertSliceOp =
             dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
         if (!insertSliceOp)
           return false;
-        // Only inplace bufferized InsertSliceOps are eligible.
-        if (!aliasInfo.isInPlace(insertSliceOp->getOpOperand(1) /*dest*/))
-          return false;
         if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
           return false;
 
@@ -487,8 +472,7 @@ LogicalResult mlir::linalg::insertSliceAnchoredInitTensorEliminationStep(
         auto extractOp = b.create<tensor::ExtractSliceOp>(
             loc, t, insertOp.dest(), mixedOffsets, mixedSizes, mixedStrides);
         return extractOp.result();
-      },
-      newOps);
+      });
 }
 
 void mlir::linalg::registerBufferizableOpInterfaceExternalModels(

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 13b760c6dd959..bbb013d955332 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -63,6 +63,17 @@ struct LinalgComprehensiveModuleBufferize
 private:
   llvm::Optional<OneShotBufferizationOptions> options;
 };
+
+struct LinalgInitTensorElimination
+    : public LinalgInitTensorEliminationBase<LinalgInitTensorElimination> {
+  LinalgInitTensorElimination() = default;
+
+  void runOnOperation() override;
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+};
 } // namespace
 
 static void applyEnablingTransformations(ModuleOp moduleOp) {
@@ -100,9 +111,6 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
     opt.testAnalysisOnly = testAnalysisOnly;
     opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
     opt.bufferizeFunctionBoundaries = true;
-    if (initTensorElimination) {
-      opt.addPostAnalysisStep(insertSliceAnchoredInitTensorEliminationStep);
-    }
   } else {
     opt = *options;
   }
@@ -125,6 +133,20 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
   (void)runPipeline(cleanupPipeline, moduleOp);
 }
 
+void LinalgInitTensorElimination::runOnOperation() {
+  Operation *op = getOperation();
+  OneShotBufferizationOptions options;
+  OneShotAnalysisState state(op, options);
+  if (failed(analyzeOp(op, state))) {
+    signalPassFailure();
+    return;
+  }
+
+  IRRewriter rewriter(op->getContext());
+  if (failed(insertSliceAnchoredInitTensorEliminationStep(rewriter, op, state)))
+    signalPassFailure();
+}
+
 std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
   return std::make_unique<LinalgComprehensiveModuleBufferize>();
 }
@@ -133,3 +155,7 @@ std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass(
     const OneShotBufferizationOptions &options) {
   return std::make_unique<LinalgComprehensiveModuleBufferize>(options);
 }
+
+std::unique_ptr<Pass> mlir::createLinalgInitTensorEliminationPass() {
+  return std::make_unique<LinalgInitTensorElimination>();
+}

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir
similarity index 92%
rename from mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir
rename to mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir
index dcad535385bb5..512407907aa3c 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis-init-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis-init-tensor-elimination.mlir
@@ -1,6 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-allocs init-tensor-elimination" -split-input-file | FileCheck %s
-
-// -----
+// RUN: mlir-opt %s -linalg-eliminate-init-tensors -one-shot-bufferize="bufferize-function-boundaries test-analysis-only allow-return-allocs" -split-input-file | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // InitTensorOp elimination

diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-init-tensor-elimination.mlir
similarity index 96%
rename from mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
rename to mlir/test/Dialect/Linalg/one-shot-bufferize-init-tensor-elimination.mlir
index 72c350fea5cb8..053efe8728837 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-init-tensor-elimination.mlir
@@ -1,6 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-allocs init-tensor-elimination" -canonicalize -split-input-file | FileCheck %s
-
-// -----
+// RUN: mlir-opt %s -linalg-eliminate-init-tensors -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs" -canonicalize -split-input-file | FileCheck %s
 
 //      CHECK: func @buffer_forwarding_conflict(
 // CHECK-SAME:   %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>

