[Mlir-commits] [mlir] 1b2bda8 - [mlir][linalg][bufferize] Add PostAnalysisStep

Matthias Springer llvmlistbot at llvm.org
Thu Nov 11 16:52:34 PST 2021


Author: Matthias Springer
Date: 2021-11-12T09:51:06+09:00
New Revision: 1b2bda8d1a82b0ffc49eb485824fca93db1b0aac

URL: https://github.com/llvm/llvm-project/commit/1b2bda8d1a82b0ffc49eb485824fca93db1b0aac
DIFF: https://github.com/llvm/llvm-project/commit/1b2bda8d1a82b0ffc49eb485824fca93db1b0aac.diff

LOG: [mlir][linalg][bufferize] Add PostAnalysisStep

This helper struct allows users of ComprehensiveBufferize to inject "post analysis" steps that are executed after the analysis but before the bufferization.

Differential Revision: https://reviews.llvm.org/D113458

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 3c3f601a9385b..07f065d909e56 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -18,6 +18,8 @@
 
 namespace mlir {
 class BlockAndValueMapping;
+class DominanceInfo;
+class FuncOp;
 
 namespace linalg {
 namespace comprehensive_bufferize {
@@ -266,6 +268,20 @@ struct BufferizationState {
 /// bufferization is necessary.
 Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
 
+/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
+/// executed after the analysis, but before bufferization. They can be used
+/// to implement custom dialect-specific optimizations.
+struct PostAnalysisStep {
+  virtual ~PostAnalysisStep() = default;
+
+  /// Run the post analysis step. This function may modify the IR, but must keep
+  /// `aliasInfo` consistent. Newly created operations and operations that
+  /// should be re-analyzed must be stored in `newOps`.
+  virtual LogicalResult run(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+                            DominanceInfo &domInfo,
+                            SmallVector<Operation *> &newOps) = 0;
+};
+
 } // namespace comprehensive_bufferize
 } // namespace linalg
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index 72a5c700c10a4..5df3b63aa60aa 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -63,40 +63,59 @@ bufferizeOp(Operation *op, BufferizationState &state,
 /// Register external models implemented for the `BufferizableOpInterface`.
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
 
-/// Try to eliminate InitTensorOps inside `funcOp`.
-///
-/// * `rewriteFunc` generates the replacement for the InitTensorOp.
-/// * Only InitTensorOps that are anchored on a matching OpOperand as per
-///   `anchorMatchFunc` are considered. "Anchored" means that there is a path on
-///   the reverse SSA use-def chain, starting from the OpOperand and always
-///   following the aliasing  OpOperand, that eventually ends at a single
-///   InitTensorOp.
-/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
-///   This analysis can be skipped with `skipAnalysis`.
-LogicalResult initTensorElimination(
-    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
-    std::function<bool(OpOperand &)> anchorMatchFunc,
-    std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
-    bool skipAnalysis = false);
-
-/// Try to eliminate InitTensorOps inside funcOp that are anchored on an
-/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
-/// (and some other conditions are met).
-LogicalResult eliminateInsertSliceAnchoredInitTensorOps(
-    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo);
-
 struct BufferizationOptions {
   BufferizationOptions();
 
+  /// Register a "post analysis" step of type `Step`, constructed from `args`.
+  /// Such steps are executed after the analysis, but before bufferization.
+  template <typename Step, typename... Args>
+  void addPostAnalysisStep(Args &&...args) {
+    postAnalysisSteps.emplace_back(
+        std::make_unique<Step>(std::forward<Args>(args)...));
+  }
+
   std::unique_ptr<AllocationCallbacks> allocationFns;
   bool allowReturnMemref = false;
   unsigned analysisFuzzerSeed = 0;
   bool testAnalysisOnly = false;
+  std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
 };
 
 LogicalResult runComprehensiveBufferize(ModuleOp moduleOp,
                                         const BufferizationOptions &options);
 
+namespace linalg_ext {
+
+struct InitTensorEliminationStep : public PostAnalysisStep {
+  /// Try to eliminate InitTensorOps inside `funcOp`.
+  ///
+  /// * `rewriteFunc` generates the replacement for the InitTensorOp.
+  /// * Only InitTensorOps that are anchored on a matching OpOperand as per
+  ///   `anchorMatchFunc` are considered. "Anchored" means that there is a path
+  ///   on the reverse SSA use-def chain, starting from the OpOperand and always
+  ///   following the aliasing OpOperand, that eventually ends at a single
+  ///   InitTensorOp.
+  /// * Ops that define replacement values are appended to `newOps` so that
+  ///   they can be analyzed for inplacability afterwards.
+  LogicalResult eliminateInitTensors(
+      FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
+      std::function<bool(OpOperand &)> anchorMatchFunc,
+      std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
+      SmallVector<Operation *> &newOps);
+};
+
+/// Try to eliminate InitTensorOps inside `funcOp` that are anchored on an
+/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
+/// (and some other conditions are met).
+struct InsertSliceAnchoredInitTensorEliminationStep
+    : public InitTensorEliminationStep {
+  LogicalResult run(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+                    DominanceInfo &domInfo,
+                    SmallVector<Operation *> &newOps) override;
+};
+
+} // namespace linalg_ext
+
 } // namespace comprehensive_bufferize
 } // namespace linalg
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index cbfbef467b6f3..bd6babfc456f3 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -1595,11 +1595,13 @@ static void layoutPostProcessing(ModuleOp moduleOp) {
 /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
 /// chain, starting from the OpOperand and always following the aliasing
 /// OpOperand, that eventually ends at a single InitTensorOp.
-LogicalResult mlir::linalg::comprehensive_bufferize::initTensorElimination(
-    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
-    std::function<bool(OpOperand &)> anchorMatchFunc,
-    std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
-    bool skipAnalysis) {
+LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
+    InitTensorEliminationStep::eliminateInitTensors(
+        FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+        DominanceInfo &domInfo,
+        std::function<bool(OpOperand &)> anchorMatchFunc,
+        std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
+        SmallVector<Operation *> &newOps) {
   OpBuilder b(funcOp->getContext());
 
   WalkResult status = funcOp->walk([&](Operation *op) {
@@ -1647,14 +1649,9 @@ LogicalResult mlir::linalg::comprehensive_bufferize::initTensorElimination(
       aliasInfo.unionAliasSets(initTensor, replacement);
       aliasInfo.unionEquivalenceClasses(initTensor, replacement);
 
-      // Run analysis on the newly created op.
-      if (auto opResult = replacement.dyn_cast<OpResult>()) {
-        if (!skipAnalysis) {
-          SmallVector<Operation *> ops(1, replacement.getDefiningOp());
-          if (failed(inPlaceAnalysis(ops, aliasInfo, domInfo)))
-            return WalkResult::interrupt();
-        }
-      }
+      // Register replacement ops.
+      if (Operation *newOp = replacement.getDefiningOp())
+        newOps.push_back(newOp);
     }
 
     // Advance to the next operation.
@@ -1692,11 +1689,11 @@ LogicalResult mlir::linalg::comprehensive_bufferize::initTensorElimination(
 ///
 /// Note that the newly inserted ExtractSliceOp may have to bufferize
 /// out-of-place due to RaW conflicts.
-LogicalResult mlir::linalg::comprehensive_bufferize::
-    eliminateInsertSliceAnchoredInitTensorOps(FuncOp funcOp,
-                                              BufferizationAliasInfo &aliasInfo,
-                                              DominanceInfo &domInfo) {
-  return initTensorElimination(
+LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
+    InsertSliceAnchoredInitTensorEliminationStep::run(
+        FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+        DominanceInfo &domInfo, SmallVector<Operation *> &newOps) {
+  return eliminateInitTensors(
       funcOp, aliasInfo, domInfo,
       [&](OpOperand &operand) {
         auto insertSliceOp = dyn_cast<InsertSliceOp>(operand.getOwner());
@@ -1713,7 +1710,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::
             loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(),
             insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
         return extractOp.result();
-      });
+      },
+      newOps);
 }
 
 #ifndef NDEBUG
@@ -1793,11 +1791,15 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
                                          options.analysisFuzzerSeed)))
       return failure();
 
-    // Try to eliminate InitTensorOps to avoid new allocations during the
-    // bufferization phase.
-    if (failed(eliminateInsertSliceAnchoredInitTensorOps(funcOp, aliasInfo,
-                                                         domInfo)))
-      return failure();
+    for (const std::unique_ptr<PostAnalysisStep> &step :
+         options.postAnalysisSteps) {
+      SmallVector<Operation *> newOps;
+      if (failed(step->run(funcOp, aliasInfo, domInfo, newOps)))
+        return failure();
+      // Analyze ops that were created by the PostAnalysisStep.
+      if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
+        return failure();
+    }
 
     // Bufferization phase.
     if (!options.testAnalysisOnly) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index f14ab5af50a6a..20b1894755fda 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -71,6 +71,10 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
   options.analysisFuzzerSeed = analysisFuzzerSeed;
   options.testAnalysisOnly = testAnalysisOnly;
 
+  // Enable InitTensorOp elimination.
+  options.addPostAnalysisStep<
+      linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
+
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
 


        


More information about the Mlir-commits mailing list