[Mlir-commits] [mlir] faa9be7 - [mlir][bufferize][NFC] Rename DialectAnalysisState and move to OneShotAnalysis

Matthias Springer llvmlistbot at llvm.org
Tue Nov 22 05:39:08 PST 2022


Author: Matthias Springer
Date: 2022-11-22T14:34:55+01:00
New Revision: faa9be75ee9bfefa6a435f6570997ec3dd3657a3

URL: https://github.com/llvm/llvm-project/commit/faa9be75ee9bfefa6a435f6570997ec3dd3657a3
DIFF: https://github.com/llvm/llvm-project/commit/faa9be75ee9bfefa6a435f6570997ec3dd3657a3.diff

LOG: [mlir][bufferize][NFC] Rename DialectAnalysisState and move to OneShotAnalysis

`DialectAnalysisState` is now `OneShotAnalysisState::Extension`.

This state extension mechanism is needed only for One-Shot Analysis, so it is moved from `BufferizableOpInterface.h` to `OneShotAnalysis.h`.

Extensions are now identified via TypeIDs instead of StringRefs. The API of state extensions is cleaned up and follows the same pattern as other extension mechanisms in MLIR (e.g., `transform::TransformState::Extension`).

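Also delete some dead code: `BufferizationOptions::addDialectStateInitializer` (together with its `DialectStateInitFn` type) and the `LLVM_ATTRIBUTE_UNUSED` debug helper `getFuncOpAnalysisState` in `OneShotModuleBufferize.cpp`. The assertion in `analyzeOp` that relied on `hasDialectState` is removed along with the dialect state API.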

Differential Revision: https://reviews.llvm.org/D135051

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
    mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
    mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 5ea15f94a2c2c..b3bdb43603636 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -23,7 +23,6 @@ namespace bufferization {
 
 class AnalysisState;
 class BufferizableOpInterface;
-struct DialectAnalysisState;
 
 class OpFilter {
 public:
@@ -181,9 +180,6 @@ struct BufferizationOptions {
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
   /// Initializer function for analysis state.
   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
-  /// Initializer function for dialect-specific analysis state.
-  using DialectStateInitFn =
-      std::function<std::unique_ptr<DialectAnalysisState>()>;
   /// Tensor -> MemRef type converter.
   /// Parameters: Value, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
@@ -301,10 +297,6 @@ struct BufferizationOptions {
   /// Initializer functions for analysis state. These can be used to
   /// initialize dialect-specific analysis state.
   SmallVector<AnalysisStateInitFn> stateInitializers;
-
-  /// Add a analysis state initializer that initializes the specified
-  /// dialect-specific analysis state.
-  void addDialectStateInitializer(StringRef name, const DialectStateInitFn &fn);
 };
 
 /// Specify fine-grain relationship between buffers to enable more analysis.
@@ -318,18 +310,6 @@ enum class BufferRelation {
 /// Return `true` if the given value is a BlockArgument of a func::FuncOp.
 bool isFunctionArgument(Value value);
 
-/// Dialect-specific analysis state. Analysis/bufferization information
-/// that is specific to ops from a certain dialect can be stored in derived
-/// variants of this struct.
-struct DialectAnalysisState {
-  DialectAnalysisState() = default;
-
-  virtual ~DialectAnalysisState() = default;
-
-  // Copying state is forbidden. Always pass as reference.
-  DialectAnalysisState(const DialectAnalysisState &) = delete;
-};
-
 /// AnalysisState provides a variety of helper functions for dealing with
 /// tensor values.
 class AnalysisState {
@@ -422,52 +402,29 @@ class AnalysisState {
   /// any given tensor.
   virtual bool isTensorYielded(Value tensor) const;
 
-  /// Return `true` if the given dialect state exists.
-  bool hasDialectState(StringRef name) const {
-    auto it = dialectState.find(name);
-    return it != dialectState.end();
-  }
-
-  /// Return dialect-specific bufferization state.
-  template <typename StateT>
-  Optional<const StateT *> getDialectState(StringRef name) const {
-    auto it = dialectState.find(name);
-    if (it == dialectState.end())
-      return None;
-    return static_cast<const StateT *>(it->getSecond().get());
-  }
-
-  /// Return dialect-specific analysis state or create one if none exists.
-  template <typename StateT>
-  StateT &getOrCreateDialectState(StringRef name) {
-    // Create state if it does not exist yet.
-    if (!hasDialectState(name))
-      dialectState[name] = std::make_unique<StateT>();
-    return static_cast<StateT &>(*dialectState[name]);
-  }
-
-  void insertDialectState(StringRef name,
-                          std::unique_ptr<DialectAnalysisState> state) {
-    assert(!dialectState.count(name) && "dialect state already initialized");
-    dialectState[name] = std::move(state);
-  }
-
   /// Return a reference to the BufferizationOptions.
   const BufferizationOptions &getOptions() const { return options; }
 
-  explicit AnalysisState(const BufferizationOptions &options);
+  AnalysisState(const BufferizationOptions &options);
 
   // AnalysisState should be passed as a reference.
   AnalysisState(const AnalysisState &) = delete;
 
   virtual ~AnalysisState() = default;
 
-private:
-  /// Dialect-specific analysis state.
-  DenseMap<StringRef, std::unique_ptr<DialectAnalysisState>> dialectState;
+  static bool classof(const AnalysisState *base) { return true; }
+
+  TypeID getType() const { return type; }
 
+protected:
+  AnalysisState(const BufferizationOptions &options, TypeID type);
+
+private:
   /// A reference to current bufferization options.
   const BufferizationOptions &options;
+
+  /// The type of analysis.
+  TypeID type;
 };
 
 /// Create an AllocTensorOp for the given shaped value (memref or tensor).
@@ -583,6 +540,8 @@ bool defaultIsRepetitiveRegion(BufferizableOpInterface bufferizableOp,
 } // namespace bufferization
 } // namespace mlir
 
+MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)
+
 //===----------------------------------------------------------------------===//
 // Bufferization Interfaces
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
index 32f7ff1d1a771..8d9caf30a6881 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h
@@ -10,6 +10,8 @@
 #define MLIR_BUFFERIZATION_TRANSFORMS_FUNCBUFFERIZABLEOPINTERFACEIMPL_H
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 
 namespace mlir {
 class DialectRegistry;
@@ -27,7 +29,10 @@ using func::FuncOp;
 
 /// Extra analysis state that is required for bufferization of function
 /// boundaries.
-struct FuncAnalysisState : public DialectAnalysisState {
+struct FuncAnalysisState : public OneShotAnalysisState::Extension {
+  FuncAnalysisState(OneShotAnalysisState &state)
+      : OneShotAnalysisState::Extension(state) {}
+
   // Note: Function arguments and/or function return values may disappear during
   // bufferization. Functions and their CallOps are analyzed and bufferized
   // separately. To ensure that a CallOp analysis/bufferization can access an

diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 8303d79d5e232..672383787e8e4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -138,6 +138,10 @@ class OneShotAnalysisState : public AnalysisState {
 
   ~OneShotAnalysisState() override = default;
 
+  static bool classof(const AnalysisState *base) {
+    return base->getType() == TypeID::get<OneShotAnalysisState>();
+  }
+
   /// Return a reference to the BufferizationAliasInfo.
   BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
 
@@ -172,6 +176,89 @@ class OneShotAnalysisState : public AnalysisState {
   /// Return true if the buffer of the given tensor value is writable.
   bool isWritable(Value value) const;
 
+  /// Base class for OneShotAnalysisState extensions that allow
+  /// OneShotAnalysisState to contain user-specified information in the state
+  /// object. Clients are expected to derive this class, add the desired fields,
+  /// and make the derived class compatible with the MLIR TypeID mechanism.
+  ///
+  /// ```mlir
+  /// class MyExtension final : public OneShotAnalysisState::Extension {
+  /// public:
+  ///   MyExtension(OneShotAnalysisState &state, int myData)
+  ///       : Extension(state) {...}
+  /// private:
+  ///   int mySupplementaryData;
+  /// };
+  /// ```
+  ///
+  /// Instances of this and derived classes are not expected to be created by
+  /// the user, instead they are directly constructed within a
+  /// OneShotAnalysisState. A OneShotAnalysisState can only contain one
+  /// extension with the given TypeID. Extensions can be obtained from a
+  /// OneShotAnalysisState instance.
+  ///
+  /// ```mlir
+  /// state.addExtension<MyExtension>(/*myData=*/42);
+  /// MyExtension *ext = state.getExtension<MyExtension>();
+  /// ext->doSomething();
+  /// ```
+  class Extension {
+    // Allow OneShotAnalysisState to allocate Extensions.
+    friend class OneShotAnalysisState;
+
+  public:
+    /// Base virtual destructor.
+    // Out-of-line definition ensures symbols are emitted in a single object
+    // file.
+    virtual ~Extension();
+
+  protected:
+    /// Constructs an extension of the given state object.
+    Extension(OneShotAnalysisState &state) : state(state) {}
+
+    /// Provides read-only access to the parent OneShotAnalysisState object.
+    const OneShotAnalysisState &getAnalysisState() const { return state; }
+
+  private:
+    /// Back-reference to the state that is being extended.
+    OneShotAnalysisState &state;
+  };
+
+  /// Adds a new Extension of the type specified as template parameter,
+  /// constructing it with the arguments provided. The extension is owned by the
+  /// OneShotAnalysisState. It is expected that the state does not already have
+  /// an extension of the same type. Extension constructors are expected to take
+  /// a reference to OneShotAnalysisState as first argument, automatically
+  /// supplied by this call.
+  template <typename Ty, typename... Args>
+  Ty &addExtension(Args &&...args) {
+    static_assert(
+        std::is_base_of<Extension, Ty>::value,
+        "only a class derived from OneShotAnalysisState::Extension is allowed");
+    auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
+    auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
+    assert(result.second && "extension already added");
+    return *static_cast<Ty *>(result.first->second.get());
+  }
+
+  /// Returns the extension of the specified type.
+  template <typename Ty>
+  Ty *getExtension() {
+    static_assert(
+        std::is_base_of<Extension, Ty>::value,
+        "only a class derived from OneShotAnalysisState::Extension is allowed");
+    auto iter = extensions.find(TypeID::get<Ty>());
+    if (iter == extensions.end())
+      return nullptr;
+    return static_cast<Ty *>(iter->second.get());
+  }
+
+  /// Returns the extension of the specified type.
+  template <typename Ty>
+  const Ty *getExtension() const {
+    return const_cast<OneShotAnalysisState *>(this)->getExtension<Ty>();
+  }
+
 private:
   /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
   /// functions and `runOneShotBufferize` may access this object.
@@ -183,6 +270,10 @@ class OneShotAnalysisState : public AnalysisState {
 
   /// A set of uses of tensors that have undefined contents.
   DenseSet<OpOperand *> undefinedTensorUses;
+
+  /// Extensions attached to the state, identified by the TypeID of their type.
+  /// Only one extension of any given type is allowed.
+  DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
 };
 
 /// Analyze `op` and its nested ops. Bufferization decisions are stored in
@@ -196,4 +287,6 @@ LogicalResult runOneShotBufferize(Operation *op,
 } // namespace bufferization
 } // namespace mlir
 
+MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)
+
 #endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTANALYSIS_H

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index f3a0394a29f05..f79db154193dc 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -32,6 +32,8 @@ namespace bufferization {
 } // namespace bufferization
 } // namespace mlir
 
+MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState)
+
 #define DEBUG_TYPE "bufferizable-op-interface"
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
 #define LDBG(X) LLVM_DEBUG(DBGS() << (X))
@@ -297,12 +299,6 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
   return nullptr;
 }
 
-void BufferizationOptions::addDialectStateInitializer(
-    StringRef name, const DialectStateInitFn &fn) {
-  stateInitializers.push_back(
-      [=](AnalysisState &state) { state.insertDialectState(name, fn()); });
-}
-
 //===----------------------------------------------------------------------===//
 // Helper functions for BufferizableOpInterface
 //===----------------------------------------------------------------------===//
@@ -451,7 +447,10 @@ AnalysisState::findLastPrecedingWrite(Value value) const {
 }
 
 AnalysisState::AnalysisState(const BufferizationOptions &options)
-    : options(options) {
+    : AnalysisState(options, TypeID::get<AnalysisState>()) {}
+
+AnalysisState::AnalysisState(const BufferizationOptions &options, TypeID type)
+    : options(options), type(type) {
   for (const BufferizationOptions::AnalysisStateInitFn &fn :
        options.stateInitializers)
     fn(*this);

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index a441b31c49211..6981cb1ea3ca0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -101,22 +101,23 @@ static FuncOp getCalledFunction(CallOpInterface callOp) {
 /// Get FuncAnalysisState.
 static const FuncAnalysisState &
 getFuncAnalysisState(const AnalysisState &state) {
-  Optional<const FuncAnalysisState *> maybeState =
-      state.getDialectState<FuncAnalysisState>(
-          func::FuncDialect::getDialectNamespace());
-  assert(maybeState && "FuncAnalysisState does not exist");
-  return **maybeState;
+  assert(isa<OneShotAnalysisState>(state) && "expected OneShotAnalysisState");
+  auto *result = static_cast<const OneShotAnalysisState &>(state)
+                     .getExtension<FuncAnalysisState>();
+  assert(result && "FuncAnalysisState does not exist");
+  return *result;
 }
 
 /// Return the state (phase) of analysis of the FuncOp.
 static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
                                                   FuncOp funcOp) {
-  Optional<const FuncAnalysisState *> maybeState =
-      state.getDialectState<FuncAnalysisState>(
-          func::FuncDialect::getDialectNamespace());
-  if (!maybeState.has_value())
+  if (!isa<OneShotAnalysisState>(state))
     return FuncOpAnalysisState::NotAnalyzed;
-  const auto &analyzedFuncOps = maybeState.value()->analyzedFuncOps;
+  auto *funcState = static_cast<const OneShotAnalysisState &>(state)
+                        .getExtension<FuncAnalysisState>();
+  if (!funcState)
+    return FuncOpAnalysisState::NotAnalyzed;
+  const auto &analyzedFuncOps = funcState->analyzedFuncOps;
   auto it = analyzedFuncOps.find(funcOp);
   if (it == analyzedFuncOps.end())
     return FuncOpAnalysisState::NotAnalyzed;

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 6e02740464f5a..c7fd2c36f955f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -57,6 +57,8 @@
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SetVector.h"
 
+MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)
+
 // Run mlir-opt with `-debug-only="one-shot-analysis"` for detailed debug
 // output.
 #define DEBUG_TYPE "one-shot-analysis"
@@ -193,7 +195,8 @@ BufferizationAliasInfo::getAliases(Value v) const {
 
 OneShotAnalysisState::OneShotAnalysisState(
     Operation *op, const OneShotBufferizationOptions &options)
-    : AnalysisState(options), aliasInfo(op) {
+    : AnalysisState(options, TypeID::get<OneShotAnalysisState>()),
+      aliasInfo(op) {
   // Set up alias sets for OpResults that must bufferize in-place. This should
   // be done before making any other bufferization decisions.
   op->walk([&](BufferizableOpInterface bufferizableOp) {
@@ -325,6 +328,8 @@ bool OneShotAnalysisState::isWritable(Value value) const {
   return false;
 }
 
+OneShotAnalysisState::Extension::~Extension() = default;
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific alias analysis.
 //===----------------------------------------------------------------------===//
@@ -1138,11 +1143,6 @@ LogicalResult bufferization::analyzeOp(Operation *op,
   const auto &options =
       static_cast<const OneShotBufferizationOptions &>(state.getOptions());
 
-  // Catch incorrect API usage.
-  assert((state.hasDialectState(func::FuncDialect::getDialectNamespace()) ||
-          !options.bufferizeFunctionBoundaries) &&
-         "must use ModuleBufferize to bufferize function boundaries");
-
   if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
     return failure();
 

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index fb1d50c466f9c..3ba1726d4a1fc 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -76,32 +76,13 @@ using namespace mlir::bufferization::func_ext;
 /// A mapping of FuncOps to their callers.
 using FuncCallerMap = DenseMap<func::FuncOp, DenseSet<Operation *>>;
 
-/// Get FuncAnalysisState.
-static const FuncAnalysisState &
-getFuncAnalysisState(const AnalysisState &state) {
-  Optional<const FuncAnalysisState *> maybeState =
-      state.getDialectState<FuncAnalysisState>(
-          func::FuncDialect::getDialectNamespace());
-  assert(maybeState && "FuncAnalysisState does not exist");
-  return **maybeState;
-}
-
 /// Get or create FuncAnalysisState.
-static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) {
-  return state.getOrCreateDialectState<FuncAnalysisState>(
-      func::FuncDialect::getDialectNamespace());
-}
-
-/// Return the state (phase) of analysis of the FuncOp.
-/// Used for debug modes.
-LLVM_ATTRIBUTE_UNUSED
-static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state,
-                                                  func::FuncOp funcOp) {
-  const FuncAnalysisState &funcState = getFuncAnalysisState(state);
-  auto it = funcState.analyzedFuncOps.find(funcOp);
-  if (it == funcState.analyzedFuncOps.end())
-    return FuncOpAnalysisState::NotAnalyzed;
-  return it->second;
+static FuncAnalysisState &
+getOrCreateFuncAnalysisState(OneShotAnalysisState &state) {
+  auto *result = state.getExtension<FuncAnalysisState>();
+  if (result)
+    return *result;
+  return state.addExtension<FuncAnalysisState>();
 }
 
 /// Return the unique ReturnOp that terminates `funcOp`.
@@ -143,10 +124,9 @@ static void annotateEquivalentReturnBbArg(OpOperand &returnVal,
 
 /// Store function BlockArguments that are equivalent to/aliasing a returned
 /// value in FuncAnalysisState.
-static LogicalResult aliasingFuncOpBBArgsAnalysis(FuncOp funcOp,
-                                                  OneShotAnalysisState &state) {
-  FuncAnalysisState &funcState = getFuncAnalysisState(state);
-
+static LogicalResult
+aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+                             FuncAnalysisState &funcState) {
   // Support only single return-terminated block in the function.
   func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
   assert(returnOp && "expected func with single return op");
@@ -190,10 +170,9 @@ static void annotateFuncArgAccess(func::FuncOp funcOp, BlockArgument bbArg,
 /// Determine which FuncOp bbArgs are read and which are written. When run on a
 /// function with unknown ops, we conservatively assume that such ops bufferize
 /// to a read + write.
-static LogicalResult funcOpBbArgReadWriteAnalysis(FuncOp funcOp,
-                                                  OneShotAnalysisState &state) {
-  FuncAnalysisState &funcState = getFuncAnalysisState(state);
-
+static LogicalResult
+funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
+                             FuncAnalysisState &funcState) {
   // If the function has no body, conservatively assume that all args are
   // read + written.
   if (funcOp.getBody().empty()) {
@@ -246,8 +225,8 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
 // TODO: This does not handle cyclic function call graphs etc.
 static void equivalenceAnalysis(func::FuncOp funcOp,
                                 BufferizationAliasInfo &aliasInfo,
-                                OneShotAnalysisState &state) {
-  FuncAnalysisState &funcState = getFuncAnalysisState(state);
+                                OneShotAnalysisState &state,
+                                FuncAnalysisState &funcState) {
   funcOp->walk([&](func::CallOp callOp) {
     func::FuncOp calledFunction = getCalledFunction(callOp);
     assert(calledFunction && "could not retrieved called func::FuncOp");
@@ -360,7 +339,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
       static_cast<const OneShotBufferizationOptions &>(state.getOptions());
   assert(options.bufferizeFunctionBoundaries &&
          "expected that function boundary bufferization is activated");
-  FuncAnalysisState &funcState = getFuncAnalysisState(state);
+  FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
   BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
 
   // A list of functions in the order in which they are analyzed + bufferized.
@@ -382,15 +361,15 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
     funcState.startFunctionAnalysis(funcOp);
 
     // Gather equivalence info for CallOps.
-    equivalenceAnalysis(funcOp, aliasInfo, state);
+    equivalenceAnalysis(funcOp, aliasInfo, state, funcState);
 
     // Analyze funcOp.
     if (failed(analyzeOp(funcOp, state)))
       return failure();
 
     // Run some extra function analyses.
-    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state)) ||
-        failed(funcOpBbArgReadWriteAnalysis(funcOp, state)))
+    if (failed(aliasingFuncOpBBArgsAnalysis(funcOp, state, funcState)) ||
+        failed(funcOpBbArgReadWriteAnalysis(funcOp, state, funcState)))
       return failure();
 
     // Mark op as fully analyzed.


        


More information about the Mlir-commits mailing list