[Mlir-commits] [mlir] cdb7675 - [mlir][bufferize][NFC] Make PostAnalysisSteps a function

Matthias Springer llvmlistbot at llvm.org
Wed Feb 9 02:03:19 PST 2022


Author: Matthias Springer
Date: 2022-02-09T18:56:06+09:00
New Revision: cdb7675c2649ee91c4e97d84daa76c98cb93b9c4

URL: https://github.com/llvm/llvm-project/commit/cdb7675c2649ee91c4e97d84daa76c98cb93b9c4
DIFF: https://github.com/llvm/llvm-project/commit/cdb7675c2649ee91c4e97d84daa76c98cb93b9c4.diff

LOG: [mlir][bufferize][NFC] Make PostAnalysisSteps a function

They used to be classes with a virtual `run` function. This was inconvenient because post-analysis steps are stored in BufferizationOptions; because of this design choice, BufferizationOptions was not copyable.

Differential Revision: https://reviews.llvm.org/D119258

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
    mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
    mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
    mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 3176d6fa337bc..609a1bb520c9d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -47,9 +47,6 @@ struct BufferizationOptions {
 
   BufferizationOptions();
 
-  // BufferizationOptions cannot be copied.
-  BufferizationOptions(const BufferizationOptions &other) = delete;
-
   /// Return `true` if the op is allowed to be bufferized.
   bool isOpAllowed(Operation *op) const {
     if (!hasFilter)

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index a56287995aa96..e2ede4b63d2f3 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -82,7 +82,7 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options);
 void populateBufferizationPattern(const BufferizationState &state,
                                   RewritePatternSet &patterns);
 
-std::unique_ptr<BufferizationOptions> getPartialBufferizationOptions();
+BufferizationOptions getPartialBufferizationOptions();
 
 } // namespace bufferization
 } // namespace mlir

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 93b8be9c7c7aa..8e9e09663ec13 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -20,35 +20,25 @@ class AnalysisBufferizationState;
 class BufferizationAliasInfo;
 struct AnalysisBufferizationOptions;
 
-/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
+/// PostAnalysisStepFns can be registered with `BufferizationOptions` and are
 /// executed after the analysis, but before bufferization. They can be used to
-/// implement custom dialect-specific optimizations.
-struct PostAnalysisStep {
-  virtual ~PostAnalysisStep() = default;
-
-  /// Run the post analysis step. This function may modify the IR, but must keep
-  /// `aliasInfo` consistent. Newly created operations and operations that
-  /// should be re-analyzed must be added to `newOps`.
-  virtual LogicalResult run(Operation *op, BufferizationState &state,
-                            BufferizationAliasInfo &aliasInfo,
-                            SmallVector<Operation *> &newOps) = 0;
-};
+/// implement custom dialect-specific optimizations. They may modify the IR, but
+/// must keep `aliasInfo` consistent. Newly created operations and operations
+/// that should be re-analyzed must be added to `newOps`.
+using PostAnalysisStepFn = std::function<LogicalResult(
+    Operation *, BufferizationState &, BufferizationAliasInfo &,
+    SmallVector<Operation *> &)>;
 
-using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
+using PostAnalysisStepList = SmallVector<PostAnalysisStepFn>;
 
 /// Options for analysis-enabled bufferization.
 struct AnalysisBufferizationOptions : public BufferizationOptions {
   AnalysisBufferizationOptions() = default;
 
-  // AnalysisBufferizationOptions cannot be copied.
-  AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete;
-
   /// Register a "post analysis" step. Such steps are executed after the
   /// analysis, but before bufferization.
-  template <typename Step, typename... Args>
-  void addPostAnalysisStep(Args... args) {
-    postAnalysisSteps.emplace_back(
-        std::make_unique<Step>(std::forward<Args>(args)...));
+  void addPostAnalysisStep(PostAnalysisStepFn fn) {
+    postAnalysisSteps.push_back(fn);
   }
 
   /// Registered post analysis steps.

diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
index 06145d028d4d1..010fd565faa92 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -18,42 +18,38 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace linalg_ext {
 
-struct InitTensorEliminationStep : public bufferization::PostAnalysisStep {
-  /// A function that matches anchor OpOperands for InitTensorOp elimination.
-  /// If an OpOperand is matched, the function should populate the SmallVector
-  /// with all values that are needed during `RewriteFn` to produce the
-  /// replacement value.
-  using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
-
-  /// A function that rewrites matched anchors.
-  using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
-
-  /// Try to eliminate InitTensorOps inside `op`.
-  ///
-  /// * `rewriteFunc` generates the replacement for the InitTensorOp.
-  /// * Only InitTensorOps that are anchored on a matching OpOperand as per
-  ///   `anchorMatchFunc` are considered. "Anchored" means that there is a path
-  ///   on the reverse SSA use-def chain, starting from the OpOperand and always
-  ///   following the aliasing  OpOperand, that eventually ends at a single
-  ///   InitTensorOp.
-  /// * The result of `rewriteFunc` must usually be analyzed for inplacability.
-  ///   This analysis can be skipped with `skipAnalysis`.
-  LogicalResult
-  eliminateInitTensors(Operation *op, bufferization::BufferizationState &state,
-                       bufferization::BufferizationAliasInfo &aliasInfo,
-                       AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
-                       SmallVector<Operation *> &newOps);
-};
+/// A function that matches anchor OpOperands for InitTensorOp elimination.
+/// If an OpOperand is matched, the function should populate the SmallVector
+/// with all values that are needed during `RewriteFn` to produce the
+/// replacement value.
+using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
+
+/// A function that rewrites matched anchors.
+using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
+
+/// Try to eliminate InitTensorOps inside `op`.
+///
+/// * `rewriteFunc` generates the replacement for the InitTensorOp.
+/// * Only InitTensorOps that are anchored on a matching OpOperand as per
+///   `anchorMatchFunc` are considered. "Anchored" means that there is a path
+///   on the reverse SSA use-def chain, starting from the OpOperand and always
+///   following the aliasing  OpOperand, that eventually ends at a single
+///   InitTensorOp.
+/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
+///   This analysis can be skipped with `skipAnalysis`.
+LogicalResult
+eliminateInitTensors(Operation *op, bufferization::BufferizationState &state,
+                     bufferization::BufferizationAliasInfo &aliasInfo,
+                     AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
+                     SmallVector<Operation *> &newOps);
 
 /// Try to eliminate InitTensorOps inside `op` that are anchored on an
 /// InsertSliceOp, i.e., if it is eventually inserted into another tensor
 /// (and some other conditions are met).
-struct InsertSliceAnchoredInitTensorEliminationStep
-    : public InitTensorEliminationStep {
-  LogicalResult run(Operation *op, bufferization::BufferizationState &state,
-                    bufferization::BufferizationAliasInfo &aliasInfo,
-                    SmallVector<Operation *> &newOps) override;
-};
+LogicalResult insertSliceAnchoredInitTensorEliminationStep(
+    Operation *op, bufferization::BufferizationState &state,
+    bufferization::BufferizationAliasInfo &aliasInfo,
+    SmallVector<Operation *> &newOps);
 
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
 

diff --git a/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
index ea8969e004686..dfeb9514409fb 100644
--- a/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h
@@ -14,16 +14,21 @@
 namespace mlir {
 class DialectRegistry;
 
+namespace bufferization {
+class BufferizationState;
+class BufferizationAliasInfo;
+} // namespace bufferization
+
 namespace scf {
 /// Assert that yielded values of an scf.for op are aliasing their corresponding
 /// bbArgs. This is required because the i-th OpResult of an scf.for op is
 /// currently assumed to alias with the i-th iter_arg (in the absence of
 /// conflicts).
-struct AssertScfForAliasingProperties : public bufferization::PostAnalysisStep {
-  LogicalResult run(Operation *op, bufferization::BufferizationState &state,
-                    bufferization::BufferizationAliasInfo &aliasInfo,
-                    SmallVector<Operation *> &newOps) override;
-};
+LogicalResult
+assertScfForAliasingProperties(Operation *op,
+                               bufferization::BufferizationState &state,
+                               bufferization::BufferizationAliasInfo &aliasInfo,
+                               SmallVector<Operation *> &newOps);
 
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
 } // namespace scf

diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
index 1fafd255d60d3..d88ee4a65e26c 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp
@@ -29,16 +29,15 @@ struct ArithmeticBufferizePass
   }
 
   void runOnOperation() override {
-    std::unique_ptr<BufferizationOptions> options =
-        getPartialBufferizationOptions();
+    BufferizationOptions options = getPartialBufferizationOptions();
     if (constantOpOnly) {
-      options->addToOperationFilter<arith::ConstantOp>();
+      options.addToOperationFilter<arith::ConstantOp>();
     } else {
-      options->addToDialectFilter<arith::ArithmeticDialect>();
+      options.addToDialectFilter<arith::ArithmeticDialect>();
     }
-    options->bufferAlignment = alignment;
+    options.bufferAlignment = alignment;
 
-    if (failed(bufferizeOp(getOperation(), *options)))
+    if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();
   }
 

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 2b2d7cceeabb4..c7468da6132f5 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -253,12 +253,11 @@ void bufferization::populateBufferizationPattern(
   patterns.add<BufferizationPattern>(patterns.getContext(), state);
 }
 
-std::unique_ptr<BufferizationOptions>
-bufferization::getPartialBufferizationOptions() {
-  auto options = std::make_unique<BufferizationOptions>();
-  options->allowReturnMemref = true;
-  options->allowUnknownOps = true;
-  options->createDeallocs = false;
-  options->fullyDynamicLayoutMaps = false;
+BufferizationOptions bufferization::getPartialBufferizationOptions() {
+  BufferizationOptions options;
+  options.allowReturnMemref = true;
+  options.allowUnknownOps = true;
+  options.createDeallocs = false;
+  options.fullyDynamicLayoutMaps = false;
   return options;
 }

diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index c21f7f9704b56..d03a287af5a0e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -698,52 +698,51 @@ annotateOpsWithBufferizationMarkers(Operation *op,
 // aliasing values, which is stricter than needed. We can currently not check
 // for aliasing values because the analysis is a maybe-alias analysis and we
 // need a must-alias analysis here.
-struct AssertDestinationPassingStyle : public PostAnalysisStep {
-  LogicalResult run(Operation *op, BufferizationState &state,
-                    BufferizationAliasInfo &aliasInfo,
-                    SmallVector<Operation *> &newOps) override {
-    LogicalResult status = success();
-    DominanceInfo domInfo(op);
-    op->walk([&](Operation *returnOp) {
-      if (!isRegionReturnLike(returnOp))
-        return WalkResult::advance();
-
-      for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
-        Value returnVal = returnValOperand.get();
-        // Skip non-tensor values.
-        if (!returnVal.getType().isa<TensorType>())
-          continue;
+static LogicalResult
+assertDestinationPassingStyle(Operation *op, BufferizationState &state,
+                              BufferizationAliasInfo &aliasInfo,
+                              SmallVector<Operation *> &newOps) {
+  LogicalResult status = success();
+  DominanceInfo domInfo(op);
+  op->walk([&](Operation *returnOp) {
+    if (!isRegionReturnLike(returnOp))
+      return WalkResult::advance();
 
-        bool foundEquivValue = false;
-        aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
-          if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
-            Operation *definingOp = bbArg.getOwner()->getParentOp();
-            if (definingOp->isProperAncestor(returnOp))
-              foundEquivValue = true;
-            return;
-          }
+    for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
+      Value returnVal = returnValOperand.get();
+      // Skip non-tensor values.
+      if (!returnVal.getType().isa<TensorType>())
+        continue;
 
-          Operation *definingOp = equivVal.getDefiningOp();
-          if (definingOp->getBlock()->findAncestorOpInBlock(
-                  *returnOp->getParentOp()))
-            // Skip ops that happen after `returnOp` and parent ops.
-            if (happensBefore(definingOp, returnOp, domInfo))
-              foundEquivValue = true;
-        });
-
-        if (!foundEquivValue)
-          status =
-              returnOp->emitError()
-              << "operand #" << returnValOperand.getOperandNumber()
-              << " of ReturnLike op does not satisfy destination passing style";
-      }
+      bool foundEquivValue = false;
+      aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
+        if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
+          Operation *definingOp = bbArg.getOwner()->getParentOp();
+          if (definingOp->isProperAncestor(returnOp))
+            foundEquivValue = true;
+          return;
+        }
 
-      return WalkResult::advance();
-    });
+        Operation *definingOp = equivVal.getDefiningOp();
+        if (definingOp->getBlock()->findAncestorOpInBlock(
+                *returnOp->getParentOp()))
+          // Skip ops that happen after `returnOp` and parent ops.
+          if (happensBefore(definingOp, returnOp, domInfo))
+            foundEquivValue = true;
+      });
+
+      if (!foundEquivValue)
+        status =
+            returnOp->emitError()
+            << "operand #" << returnValOperand.getOperandNumber()
+            << " of ReturnLike op does not satisfy destination passing style";
+    }
 
-    return status;
-  }
-};
+    return WalkResult::advance();
+  });
+
+  return status;
+}
 
 LogicalResult bufferization::analyzeOp(Operation *op,
                                        AnalysisBufferizationState &state) {
@@ -761,12 +760,11 @@ LogicalResult bufferization::analyzeOp(Operation *op,
     return failure();
   equivalenceAnalysis(op, aliasInfo, state);
 
-  for (const std::unique_ptr<PostAnalysisStep> &step :
-       options.postAnalysisSteps) {
+  for (const PostAnalysisStepFn &fn : options.postAnalysisSteps) {
     SmallVector<Operation *> newOps;
-    if (failed(step->run(op, state, aliasInfo, newOps)))
+    if (failed(fn(op, state, aliasInfo, newOps)))
       return failure();
-    // Analyze ops that were created by the PostAnalysisStep.
+    // Analyze ops that were created by the PostAnalysisStepFn.
     if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
       return failure();
     equivalenceAnalysis(newOps, aliasInfo, state);
@@ -774,8 +772,7 @@ LogicalResult bufferization::analyzeOp(Operation *op,
 
   if (!options.allowReturnMemref) {
     SmallVector<Operation *> newOps;
-    if (failed(
-            AssertDestinationPassingStyle().run(op, state, aliasInfo, newOps)))
+    if (failed(assertDestinationPassingStyle(op, state, aliasInfo, newOps)))
       return failure();
   }
 

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 493044aa53aaa..984b42b59c7f9 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -524,11 +524,10 @@ findValidInsertionPoint(Operation *initTensorOp,
 /// chain, starting from the OpOperand and always following the aliasing
 /// OpOperand, that eventually ends at a single InitTensorOp.
 LogicalResult
-mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep::
-    eliminateInitTensors(Operation *op, BufferizationState &state,
-                         BufferizationAliasInfo &aliasInfo,
-                         AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
-                         SmallVector<Operation *> &newOps) {
+mlir::linalg::comprehensive_bufferize::linalg_ext::eliminateInitTensors(
+    Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
+    AnchorMatchFn anchorMatchFunc, RewriteFn rewriteFunc,
+    SmallVector<Operation *> &newOps) {
   OpBuilder b(op->getContext());
 
   WalkResult status = op->walk([&](Operation *op) {
@@ -628,7 +627,7 @@ mlir::linalg::comprehensive_bufferize::linalg_ext::InitTensorEliminationStep::
 /// Note that the newly inserted ExtractSliceOp may have to bufferize
 /// out-of-place due to RaW conflicts.
 LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
-    InsertSliceAnchoredInitTensorEliminationStep::run(
+    insertSliceAnchoredInitTensorEliminationStep(
         Operation *op, BufferizationState &state,
         BufferizationAliasInfo &aliasInfo, SmallVector<Operation *> &newOps) {
   return eliminateInitTensors(

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index 63ffb09320076..6f04a2fd40c27 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -16,11 +16,12 @@
 // their respective callers.
 //
 // After analyzing a FuncOp, additional information about its bbArgs is
-// gathered through PostAnalysisSteps and stored in `ModuleBufferizationState`.
+// gathered through PostAnalysisStepFns and stored in
+// `ModuleBufferizationState`.
 //
-// * `EquivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each
+// * `equivalentFuncOpBBArgsAnalysis` determines the equivalent bbArg for each
 //   tensor return value (if any).
-// * `FuncOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
+// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
 //   read/written.
 //
 // Only tensors that are equivalent to some FuncOp bbArg may be returned.
@@ -47,7 +48,7 @@
 // modify the buffer of `%t1` directly. The CallOp in `caller` must bufferize
 // out-of-place because `%t0` is modified by the callee but read by the
 // tensor.extract op. The analysis of CallOps decides whether an OpOperand must
-// bufferize out-of-place based on results of `FuncOpBbArgReadWriteAnalysis`.
+// bufferize out-of-place based on results of `funcOpBbArgReadWriteAnalysis`.
 // ```
 // func @callee(%t1 : tensor<?xf32>) -> tensor<?xf32> {
 //   %f = ... : f32
@@ -62,7 +63,7 @@
 // }
 // ```
 //
-// Note: If a function is external, `FuncOpBbArgReadWriteAnalysis` cannot
+// Note: If a function is external, `funcOpBbArgReadWriteAnalysis` cannot
 // analyze the function body. In such a case, the CallOp analysis conservatively
 // assumes that each tensor OpOperand is both read and written.
 //
@@ -159,55 +160,55 @@ static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
 }
 
 namespace {
-/// Store function BlockArguments that are equivalent to a returned value in
-/// ModuleBufferizationState.
-struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
-  /// Annotate IR with the results of the analysis. For testing purposes only.
-  static void annotateReturnOp(OpOperand &returnVal, BlockArgument bbArg) {
-    const char *kEquivalentArgsAttr = "__equivalent_func_args__";
-    Operation *op = returnVal.getOwner();
-
-    SmallVector<int64_t> equivBbArgs;
-    if (op->hasAttr(kEquivalentArgsAttr)) {
-      auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
-      equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
-        return a.cast<IntegerAttr>().getValue().getSExtValue();
-      }));
-    } else {
-      equivBbArgs.append(op->getNumOperands(), -1);
-    }
-    equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
 
-    OpBuilder b(op->getContext());
-    op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
+/// Annotate IR with the results of the analysis. For testing purposes only.
+static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
+                                          BlockArgument bbArg) {
+  const char *kEquivalentArgsAttr = "__equivalent_func_args__";
+  Operation *op = returnVal.getOwner();
+
+  SmallVector<int64_t> equivBbArgs;
+  if (op->hasAttr(kEquivalentArgsAttr)) {
+    auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
+    equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
+      return a.cast<IntegerAttr>().getValue().getSExtValue();
+    }));
+  } else {
+    equivBbArgs.append(op->getNumOperands(), -1);
   }
+  equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
 
-  LogicalResult run(Operation *op, BufferizationState &state,
-                    BufferizationAliasInfo &aliasInfo,
-                    SmallVector<Operation *> &newOps) override {
-    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+  OpBuilder b(op->getContext());
+  op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
+}
 
-    // Support only single return-terminated block in the function.
-    auto funcOp = cast<FuncOp>(op);
-    ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-    assert(returnOp && "expected func with single return op");
-
-    for (OpOperand &returnVal : returnOp->getOpOperands())
-      if (returnVal.get().getType().isa<RankedTensorType>())
-        for (BlockArgument bbArg : funcOp.getArguments())
-          if (bbArg.getType().isa<RankedTensorType>())
-            if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
-                                                        bbArg)) {
-              moduleState
-                  .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
-                  bbArg.getArgNumber();
-              if (state.getOptions().testAnalysisOnly)
-                annotateReturnOp(returnVal, bbArg);
-            }
+/// Store function BlockArguments that are equivalent to a returned value in
+/// ModuleBufferizationState.
+static LogicalResult
+equivalentFuncOpBBArgsAnalysis(Operation *op, BufferizationState &state,
+                               BufferizationAliasInfo &aliasInfo,
+                               SmallVector<Operation *> &newOps) {
+  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
 
-    return success();
-  }
-};
+  // Support only single return-terminated block in the function.
+  auto funcOp = cast<FuncOp>(op);
+  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  assert(returnOp && "expected func with single return op");
+
+  for (OpOperand &returnVal : returnOp->getOpOperands())
+    if (returnVal.get().getType().isa<RankedTensorType>())
+      for (BlockArgument bbArg : funcOp.getArguments())
+        if (bbArg.getType().isa<RankedTensorType>())
+          if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) {
+            moduleState
+                .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
+                bbArg.getArgNumber();
+            if (state.getOptions().testAnalysisOnly)
+              annotateEquivalentReturnBbArg(returnVal, bbArg);
+          }
+
+  return success();
+}
 
 /// Return true if the buffer of the given tensor value is written to. Must not
 /// be called for values inside not yet analyzed functions. (Post-analysis
@@ -239,38 +240,37 @@ static bool isValueWritten(Value value, const BufferizationState &state,
 }
 
 /// Determine which FuncOp bbArgs are read and which are written. If this
-/// PostAnalysisStep is run on a function with unknown ops, it will
+/// PostAnalysisStepFn is run on a function with unknown ops, it will
 /// conservatively assume that such ops bufferize to a read + write.
-struct FuncOpBbArgReadWriteAnalysis : public PostAnalysisStep {
-  LogicalResult run(Operation *op, BufferizationState &state,
-                    BufferizationAliasInfo &aliasInfo,
-                    SmallVector<Operation *> &newOps) override {
-    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
-    auto funcOp = cast<FuncOp>(op);
-
-    // If the function has no body, conservatively assume that all args are
-    // read + written.
-    if (funcOp.getBody().empty()) {
-      for (BlockArgument bbArg : funcOp.getArguments()) {
-        moduleState.readBbArgs.insert(bbArg);
-        moduleState.writtenBbArgs.insert(bbArg);
-      }
-
-      return success();
-    }
+static LogicalResult
+funcOpBbArgReadWriteAnalysis(Operation *op, BufferizationState &state,
+                             BufferizationAliasInfo &aliasInfo,
+                             SmallVector<Operation *> &newOps) {
+  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+  auto funcOp = cast<FuncOp>(op);
 
+  // If the function has no body, conservatively assume that all args are
+  // read + written.
+  if (funcOp.getBody().empty()) {
     for (BlockArgument bbArg : funcOp.getArguments()) {
-      if (!bbArg.getType().isa<TensorType>())
-        continue;
-      if (state.isValueRead(bbArg))
-        moduleState.readBbArgs.insert(bbArg);
-      if (isValueWritten(bbArg, state, aliasInfo))
-        moduleState.writtenBbArgs.insert(bbArg);
+      moduleState.readBbArgs.insert(bbArg);
+      moduleState.writtenBbArgs.insert(bbArg);
     }
 
     return success();
   }
-};
+
+  for (BlockArgument bbArg : funcOp.getArguments()) {
+    if (!bbArg.getType().isa<TensorType>())
+      continue;
+    if (state.isValueRead(bbArg))
+      moduleState.readBbArgs.insert(bbArg);
+    if (isValueWritten(bbArg, state, aliasInfo))
+      moduleState.writtenBbArgs.insert(bbArg);
+  }
+
+  return success();
+}
 } // namespace
 
 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
@@ -983,10 +983,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     return failure();
 
   // Collect bbArg/return value information after the analysis.
-  options->postAnalysisSteps.emplace_back(
-      std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
-  options->postAnalysisSteps.emplace_back(
-      std::make_unique<FuncOpBbArgReadWriteAnalysis>());
+  options->postAnalysisSteps.push_back(equivalentFuncOpBBArgsAnalysis);
+  options->postAnalysisSteps.push_back(funcOpBbArgReadWriteAnalysis);
 
   // Analyze ops.
   for (FuncOp funcOp : moduleState.orderedFuncOps) {

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 7597ff8be1ff5..708db1e089072 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -125,12 +125,12 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
 
   // Enable InitTensorOp elimination.
   if (initTensorElimination) {
-    options->addPostAnalysisStep<
-        linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
+    options->addPostAnalysisStep(
+        linalg_ext::insertSliceAnchoredInitTensorEliminationStep);
   }
 
   // Only certain scf.for ops are supported by the analysis.
-  options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
+  options->addPostAnalysisStep(scf::assertScfForAliasingProperties);
 
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);

diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index f26f7b9ec890e..cc4147fdc2691 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -432,7 +432,7 @@ struct YieldOpInterface
 } // namespace scf
 } // namespace mlir
 
-LogicalResult mlir::scf::AssertScfForAliasingProperties::run(
+LogicalResult mlir::scf::assertScfForAliasingProperties(
     Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
     SmallVector<Operation *> &newOps) {
   LogicalResult status = success();

diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index 1d435c64d9287..7d8ee6ff3ee67 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -30,11 +30,10 @@ using namespace bufferization;
 namespace {
 struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
   void runOnOperation() override {
-    std::unique_ptr<BufferizationOptions> options =
-        getPartialBufferizationOptions();
-    options->addToDialectFilter<tensor::TensorDialect>();
+    BufferizationOptions options = getPartialBufferizationOptions();
+    options.addToDialectFilter<tensor::TensorDialect>();
 
-    if (failed(bufferizeOp(getOperation(), *options)))
+    if (failed(bufferizeOp(getOperation(), options)))
       signalPassFailure();
   }
 

diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index 230a412cd2c2c..28535ba8248a9 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -104,7 +104,7 @@ void TestComprehensiveFunctionBufferize::runOnOperation() {
   auto options = std::make_unique<AnalysisBufferizationOptions>();
 
   if (!allowReturnMemref)
-    options->addPostAnalysisStep<scf::AssertScfForAliasingProperties>();
+    options->addPostAnalysisStep(scf::assertScfForAliasingProperties);
 
   options->allowReturnMemref = allowReturnMemref;
   options->allowUnknownOps = allowUnknownOps;


        


More information about the Mlir-commits mailing list