[Mlir-commits] [mlir] 46e3166 - [mlir][linalg][bufferize][NFC] Refactor BufferizationOption ownership

Matthias Springer llvmlistbot at llvm.org
Wed Jan 5 03:29:41 PST 2022


Author: Matthias Springer
Date: 2022-01-05T20:24:54+09:00
New Revision: 46e316651f7870766d0698870081dea1536260f7

URL: https://github.com/llvm/llvm-project/commit/46e316651f7870766d0698870081dea1536260f7
DIFF: https://github.com/llvm/llvm-project/commit/46e316651f7870766d0698870081dea1536260f7.diff

LOG: [mlir][linalg][bufferize][NFC] Refactor BufferizationOption ownership

Pass unique_ptr<BufferizationOptions> to the bufferization. This allows the bufferization to enqueue additional PostAnalysisSteps. Because each bufferization run takes ownership of (and may modify) the options object, a new BufferizationOptions object must be constructed when running bufferization a second time.

Differential Revision: https://reviews.llvm.org/D116101

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 6d0a32cfd3442..bda6c25b28774 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -423,12 +423,11 @@ class BufferizationState {
 private:
   friend LogicalResult
   runComprehensiveBufferize(Operation *op, const BufferizationOptions &options,
-                            BufferizationState &state,
-                            const PostAnalysisStepList &extraSteps);
+                            BufferizationState &state);
 
   friend LogicalResult
   runComprehensiveBufferize(ModuleOp moduleOp,
-                            const BufferizationOptions &options);
+                            std::unique_ptr<BufferizationOptions> options);
 
   /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
   /// functions and `runComprehensiveBufferize` may access this object.

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index 93a697fbeb393..a984f5749380d 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -18,17 +18,17 @@ namespace comprehensive_bufferize {
 
 struct BufferizationOptions;
 class BufferizationState;
-struct PostAnalysisStep;
 
 /// Bufferize the given operation. Reuses an existing BufferizationState object.
-LogicalResult runComprehensiveBufferize(
-    Operation *op, const BufferizationOptions &options,
-    BufferizationState &state,
-    const std::vector<std::unique_ptr<PostAnalysisStep>> &extraSteps);
+/// This function overload is for internal usage only.
+LogicalResult runComprehensiveBufferize(Operation *op,
+                                        const BufferizationOptions &options,
+                                        BufferizationState &state);
 
 /// Bufferize the given operation.
-LogicalResult runComprehensiveBufferize(Operation *op,
-                                        const BufferizationOptions &options);
+LogicalResult
+runComprehensiveBufferize(Operation *op,
+                          std::unique_ptr<BufferizationOptions> options);
 
 } // namespace comprehensive_bufferize
 } // namespace linalg

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
index 01f687ecf0cf2..88ccd0ea7727a 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_MODULE_BUFFERIZATION_H
 #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_MODULE_BUFFERIZATION_H
 
+#include <memory>
+
 namespace mlir {
 
 class DialectRegistry;
@@ -22,8 +24,9 @@ struct BufferizationOptions;
 
 /// Bufferize the given module. This bufferizations performs a simple function
 /// call analysis to determine which function arguments are inplaceable.
-LogicalResult runComprehensiveBufferize(ModuleOp moduleOp,
-                                        const BufferizationOptions &options);
+LogicalResult
+runComprehensiveBufferize(ModuleOp moduleOp,
+                          std::unique_ptr<BufferizationOptions> options);
 
 namespace std_ext {
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 485fb735b3ee4..c46746f6813f9 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -650,15 +650,14 @@ annotateOpsWithBufferizationMarkers(Operation *op,
 }
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
-    Operation *op, const BufferizationOptions &options) {
-  BufferizationState state(op, options);
-  PostAnalysisStepList extraSteps;
-  return runComprehensiveBufferize(op, options, state, extraSteps);
+    Operation *op, std::unique_ptr<BufferizationOptions> options) {
+  BufferizationState state(op, *options);
+  return runComprehensiveBufferize(op, *options, state);
 }
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     Operation *op, const BufferizationOptions &options,
-    BufferizationState &state, const PostAnalysisStepList &extraSteps) {
+    BufferizationState &state) {
 
   DominanceInfo domInfo(op);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
@@ -672,23 +671,16 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     return failure();
   equivalenceAnalysis(op, aliasInfo, state);
 
-  auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
-    for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
-      SmallVector<Operation *> newOps;
-      if (failed(step->run(op, state, aliasInfo, newOps)))
-        return failure();
-      // Analyze ops that were created by the PostAnalysisStep.
-      if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
-        return failure();
-      equivalenceAnalysis(newOps, aliasInfo, state);
-    }
-    return success();
-  };
-
-  if (failed(runPostAnalysisSteps(extraSteps)))
-    return failure();
-  if (failed(runPostAnalysisSteps(options.postAnalysisSteps)))
-    return failure();
+  for (const std::unique_ptr<PostAnalysisStep> &step :
+       options.postAnalysisSteps) {
+    SmallVector<Operation *> newOps;
+    if (failed(step->run(op, state, aliasInfo, newOps)))
+      return failure();
+    // Analyze ops that were created by the PostAnalysisStep.
+    if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
+      return failure();
+    equivalenceAnalysis(newOps, aliasInfo, state);
+  }
 
   // Annotate operations if we only want to report the analysis.
   if (options.testAnalysisOnly) {

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 49687ccacd3d9..e7a5330ef3999 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -724,8 +724,8 @@ static void annotateOpsWithBufferizationMarkers(FuncOp funcOp,
 }
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
-    ModuleOp moduleOp, const BufferizationOptions &options) {
-  BufferizationState state(moduleOp, options);
+    ModuleOp moduleOp, std::unique_ptr<BufferizationOptions> options) {
+  BufferizationState state(moduleOp, *options);
   ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
@@ -743,24 +743,23 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     if (funcOp.body().empty())
       continue;
 
-    // Register extra post analysis steps. These cannot be stored in `options`
-    // because `options` is immutable.
-    PostAnalysisStepList extraSteps;
-    extraSteps.emplace_back(std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
+    // Collect bbArg/return value information after the analysis.
+    options->postAnalysisSteps.emplace_back(
+        std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
 
     // Gather equivalence info for CallOps.
     equivalenceAnalysis(funcOp, aliasInfo, moduleState);
 
     // Analyze and bufferize funcOp.
-    if (failed(runComprehensiveBufferize(funcOp, options, state, extraSteps)))
+    if (failed(runComprehensiveBufferize(funcOp, *options, state)))
       return failure();
 
     // Add annotations to function arguments.
-    if (options.testAnalysisOnly)
+    if (options->testAnalysisOnly)
       annotateOpsWithBufferizationMarkers(funcOp, state);
   }
 
-  if (options.testAnalysisOnly)
+  if (options->testAnalysisOnly)
     return success();
 
   for (FuncOp funcOp : moduleState.orderedFuncOps) {
@@ -769,7 +768,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     if (failed(bufferizeFuncOpBoundary(funcOp, state)))
       return failure();
 
-    if (!options.allowReturnMemref &&
+    if (!options->allowReturnMemref &&
         llvm::any_of(funcOp.getType().getResults(), [](Type t) {
           return t.isa<MemRefType, UnrankedMemRefType>();
         })) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 7ee9caf004afa..c5fdf402d9412 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -73,42 +73,42 @@ static Optional<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
 }
 
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
-  BufferizationOptions options;
+  auto options = std::make_unique<BufferizationOptions>();
   if (useAlloca) {
-    options.allocationFns->allocationFn = allocationFnUsingAlloca;
-    options.allocationFns->deallocationFn = [](OpBuilder &b, Location loc,
-                                               Value v) {};
+    options->allocationFns->allocationFn = allocationFnUsingAlloca;
+    options->allocationFns->deallocationFn = [](OpBuilder &b, Location loc,
+                                                Value v) {};
   }
   // TODO: Change to memref::CopyOp (default memCpyFn).
-  options.allocationFns->memCpyFn = [](OpBuilder &b, Location loc, Value from,
-                                       Value to) {
+  options->allocationFns->memCpyFn = [](OpBuilder &b, Location loc, Value from,
+                                        Value to) {
     b.create<linalg::CopyOp>(loc, from, to);
   };
 
-  options.allowReturnMemref = allowReturnMemref;
-  options.allowUnknownOps = allowUnknownOps;
-  options.analysisFuzzerSeed = analysisFuzzerSeed;
-  options.testAnalysisOnly = testAnalysisOnly;
-  options.printConflicts = printConflicts;
+  options->allowReturnMemref = allowReturnMemref;
+  options->allowUnknownOps = allowUnknownOps;
+  options->analysisFuzzerSeed = analysisFuzzerSeed;
+  options->testAnalysisOnly = testAnalysisOnly;
+  options->printConflicts = printConflicts;
 
   // Enable InitTensorOp elimination.
-  options.addPostAnalysisStep<
+  options->addPostAnalysisStep<
       linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
   // TODO: Find a way to enable this step automatically when bufferizing tensor
   // dialect ops.
-  options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+  options->addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
   if (!allowReturnMemref)
-    options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
+    options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
 
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
 
-  if (failed(runComprehensiveBufferize(moduleOp, options))) {
+  if (failed(runComprehensiveBufferize(moduleOp, std::move(options)))) {
     signalPassFailure();
     return;
   }
 
-  if (options.testAnalysisOnly)
+  if (testAnalysisOnly)
     return;
 
   OpPassManager cleanupPipeline("builtin.module");

diff  --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index d54948e23844d..f4a43aab1ebb9 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -94,30 +94,30 @@ struct TestComprehensiveFunctionBufferize
 } // namespace
 
 void TestComprehensiveFunctionBufferize::runOnFunction() {
-  BufferizationOptions options;
+  auto options = std::make_unique<BufferizationOptions>();
 
   // Enable InitTensorOp elimination.
-  options.addPostAnalysisStep<
+  options->addPostAnalysisStep<
       linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
   // TODO: Find a way to enable this step automatically when bufferizing
   // tensor dialect ops.
-  options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+  options->addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
   if (!allowReturnMemref)
-    options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
+    options->addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
 
-  options.allowReturnMemref = allowReturnMemref;
-  options.allowUnknownOps = allowUnknownOps;
-  options.testAnalysisOnly = testAnalysisOnly;
-  options.analysisFuzzerSeed = analysisFuzzerSeed;
+  options->allowReturnMemref = allowReturnMemref;
+  options->allowUnknownOps = allowUnknownOps;
+  options->testAnalysisOnly = testAnalysisOnly;
+  options->analysisFuzzerSeed = analysisFuzzerSeed;
 
   if (dialectFilter.hasValue()) {
-    options.dialectFilter.emplace();
+    options->dialectFilter.emplace();
     for (const std::string &dialectNamespace : dialectFilter)
-      options.dialectFilter->insert(dialectNamespace);
+      options->dialectFilter->insert(dialectNamespace);
   }
 
   Operation *op = getFunction().getOperation();
-  if (failed(runComprehensiveBufferize(op, options)))
+  if (failed(runComprehensiveBufferize(op, std::move(options))))
     return;
 
   OpPassManager cleanupPipeline("builtin.func");


        


More information about the Mlir-commits mailing list