[Mlir-commits] [mlir] 3135548 - [mlir][linalg][bufferize][NFC] Split analysis-related code from BufferizationState/Options

Matthias Springer llvmlistbot at llvm.org
Wed Jan 19 05:06:33 PST 2022


Author: Matthias Springer
Date: 2022-01-19T21:59:23+09:00
New Revision: 31355482e5189c489efa98607d86215480d17452

URL: https://github.com/llvm/llvm-project/commit/31355482e5189c489efa98607d86215480d17452
DIFF: https://github.com/llvm/llvm-project/commit/31355482e5189c489efa98607d86215480d17452.diff

LOG: [mlir][linalg][bufferize][NFC] Split analysis-related code from BufferizationState/Options

This separates the analysis (and its helpers/data structures) more clearly from the rest of the bufferization.

Differential Revision: https://reviews.llvm.org/D117477

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
    mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 8fb05125deeb5..86a58001ac735 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -68,14 +68,6 @@ struct BufferizationOptions {
   // BufferizationOptions cannot be copied.
   BufferizationOptions(const BufferizationOptions &other) = delete;
 
-  /// Register a "post analysis" step. Such steps are executed after the
-  /// analysis, but before bufferization.
-  template <typename Step, typename... Args>
-  void addPostAnalysisStep(Args... args) {
-    postAnalysisSteps.emplace_back(
-        std::make_unique<Step>(std::forward<Args>(args)...));
-  }
-
   /// Return `true` if the op is allowed to be bufferized.
   bool isOpAllowed(Operation *op) const {
     if (!dialectFilter.hasValue())
@@ -134,9 +126,6 @@ struct BufferizationOptions {
   /// For debugging only. Should be used together with `testAnalysisOnly`.
   bool printConflicts = false;
 
-  /// Registered post analysis steps.
-  PostAnalysisStepList postAnalysisSteps;
-
   /// Only bufferize ops from dialects that are allowed-listed by the filter.
   /// All other ops are ignored. This option controls the scope of partial
   /// bufferization.
@@ -157,6 +146,25 @@ struct BufferizationOptions {
   }
 };
 
+/// Options for analysis-enabled bufferization.
+struct AnalysisBufferizationOptions : public BufferizationOptions {
+  AnalysisBufferizationOptions() = default;
+
+  // AnalysisBufferizationOptions cannot be copied.
+  AnalysisBufferizationOptions(const AnalysisBufferizationOptions &) = delete;
+
+  /// Register a "post analysis" step. Such steps are executed after the
+  /// analysis, but before bufferization.
+  template <typename Step, typename... Args>
+  void addPostAnalysisStep(Args... args) {
+    postAnalysisSteps.emplace_back(
+        std::make_unique<Step>(std::forward<Args>(args)...));
+  }
+
+  /// Registered post analysis steps.
+  PostAnalysisStepList postAnalysisSteps;
+};
+
 /// Specify fine-grain relationship between buffers to enable more analysis.
 enum class BufferRelation {
   None,
@@ -198,11 +206,6 @@ class BufferizationAliasInfo {
     return equivalentInfo.isEquivalent(v1, v2);
   }
 
-  /// Return true if `v1` and `v2` bufferize to aliasing buffers.
-  bool areAliasingBufferizedValues(Value v1, Value v2) const {
-    return aliasInfo.isEquivalent(v1, v2);
-  }
-
   /// Union the alias sets of `v1` and `v2`.
   void unionAliasSets(Value v1, Value v2) { aliasInfo.unionSets(v1, v2); }
 
@@ -276,11 +279,6 @@ struct DialectBufferizationState {
 /// tensor values and memref buffers.
 class BufferizationState {
 public:
-  BufferizationState(Operation *op, const BufferizationOptions &options);
-
-  // BufferizationState should be passed as a reference.
-  BufferizationState(const BufferizationState &) = delete;
-
   /// Determine which OpOperand* will alias with `result` if the op is
   /// bufferized in place. Return an empty vector if the op is not bufferizable.
   SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
@@ -344,7 +342,10 @@ class BufferizationState {
   SetVector<Value> findLastPrecedingWrite(Value value) const;
 
   /// Return `true` if the given OpResult has been decided to bufferize inplace.
-  bool isInPlace(OpOperand &opOperand) const;
+  virtual bool isInPlace(OpOperand &opOperand) const = 0;
+
+  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
+  virtual bool areEquivalentBufferizedValues(Value v1, Value v2) const = 0;
 
   /// Return the buffer (memref) for a given OpOperand (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place
@@ -374,14 +375,15 @@ class BufferizationState {
   /// Return a reference to the BufferizationOptions.
   const BufferizationOptions &getOptions() const { return options; }
 
-  /// Return a reference to the BufferizationAliasInfo.
-  BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
+protected:
+  BufferizationState(const BufferizationOptions &options);
 
-private:
-  /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
-  /// functions and `runComprehensiveBufferize` may access this object.
-  BufferizationAliasInfo aliasInfo;
+  // BufferizationState should be passed as a reference.
+  BufferizationState(const BufferizationState &) = delete;
+
+  ~BufferizationState() = default;
 
+private:
   /// Dialect-specific bufferization state.
   DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
 
@@ -389,6 +391,33 @@ class BufferizationState {
   const BufferizationOptions &options;
 };
 
+/// State for analysis-enabled bufferization. This class keeps track of alias
+/// sets (via BufferizationAliasInfo) to decide if tensor OpOperands should
+/// bufferize in-place.
+class AnalysisBufferizationState : public BufferizationState {
+public:
+  AnalysisBufferizationState(Operation *op,
+                             const AnalysisBufferizationOptions &options);
+
+  AnalysisBufferizationState(const AnalysisBufferizationState &) = delete;
+
+  virtual ~AnalysisBufferizationState() = default;
+
+  /// Return a reference to the BufferizationAliasInfo.
+  BufferizationAliasInfo &getAliasInfo() { return aliasInfo; }
+
+  /// Return `true` if the given OpOperand was decided to bufferize in-place.
+  bool isInPlace(OpOperand &opOperand) const override;
+
+  /// Return true if `v1` and `v2` bufferize to equivalent buffers.
+  bool areEquivalentBufferizedValues(Value v1, Value v2) const override;
+
+private:
+  /// `aliasInfo` keeps track of aliasing and equivalent values. Only internal
+  /// functions and `runComprehensiveBufferize` may access this object.
+  BufferizationAliasInfo aliasInfo;
+};
+
 /// Replace an op with replacement values. The op is deleted. Tensor OpResults
 /// must be replaced with memref values.
 void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
@@ -483,7 +512,6 @@ struct AllocationHoistingBarrierOnly
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo,
                                 const BufferizationState &state) const {
     return BufferRelation::None;
   }

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index 0f642b6eedeef..6569b0df68123 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -183,7 +183,6 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*retType=*/"BufferRelation",
         /*methodName=*/"bufferRelation",
         /*args=*/(ins "OpResult":$opResult,
-                      "const BufferizationAliasInfo &":$aliasInfo,
                       "const BufferizationState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
@@ -284,8 +283,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*methodName=*/"isNotConflicting",
         /*args=*/(ins "OpOperand *":$uRead,
                       "OpOperand *":$uWrite,
-                      "const BufferizationState &":$state,
-                      "const BufferizationAliasInfo &":$aliasInfo),
+                      "const BufferizationState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return false;

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index bd1cd17875c98..6a53295babcd8 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -16,22 +16,22 @@ namespace mlir {
 namespace linalg {
 namespace comprehensive_bufferize {
 
+class AnalysisBufferizationState;
 class BufferizationAliasInfo;
-struct BufferizationOptions;
+struct AnalysisBufferizationOptions;
 class BufferizationState;
 
 /// Analyze `op` and its nested ops. Bufferization decisions are stored in
 /// `state`.
-LogicalResult analyzeOp(Operation *op, BufferizationState &state);
+LogicalResult analyzeOp(Operation *op, AnalysisBufferizationState &state);
 
 /// Bufferize `op` and its nested ops. Bufferization decisions are stored in
 /// `state`.
 LogicalResult bufferizeOp(Operation *op, const BufferizationState &state);
 
 /// Run Comprehensive Bufferize on the given op: Analysis + Bufferization
-LogicalResult
-runComprehensiveBufferize(Operation *op,
-                          std::unique_ptr<BufferizationOptions> options);
+LogicalResult runComprehensiveBufferize(
+    Operation *op, std::unique_ptr<AnalysisBufferizationOptions> options);
 
 } // namespace comprehensive_bufferize
 } // namespace linalg

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
index cde14939e3fe2..6b4039f5283ff 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h
@@ -20,14 +20,13 @@ class ModuleOp;
 namespace linalg {
 namespace comprehensive_bufferize {
 
-struct BufferizationOptions;
+struct AnalysisBufferizationOptions;
 
 /// Run Module Bufferization on the given module. Performs a simple function
 /// call analysis to determine which function arguments are inplaceable. Then
 /// analyzes and bufferizes FuncOps one-by-one with Comprehensive Bufferization.
-LogicalResult
-runComprehensiveBufferize(ModuleOp moduleOp,
-                          std::unique_ptr<BufferizationOptions> options);
+LogicalResult runComprehensiveBufferize(
+    ModuleOp moduleOp, std::unique_ptr<AnalysisBufferizationOptions> options);
 
 namespace std_ext {
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 885a70f56c64b..be9e919fbb628 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -288,8 +288,13 @@ llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
 }
 
 mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
-    Operation *op, const BufferizationOptions &options)
-    : aliasInfo(op), options(options) {
+    const BufferizationOptions &options)
+    : options(options) {}
+
+mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
+    AnalysisBufferizationState(Operation *op,
+                               const AnalysisBufferizationOptions &options)
+    : BufferizationState(options), aliasInfo(op) {
   // Set up alias sets for OpResults that must bufferize in-place. This should
   // be done before making any other bufferization decisions.
   op->walk([&](BufferizableOpInterface bufferizableOp) {
@@ -353,7 +358,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
   Value operand = opOperand.get();
   Value operandBuffer = lookupBuffer(rewriter, operand);
 
-  if (forceInPlace || aliasInfo.isInPlace(opOperand))
+  if (forceInPlace || isInPlace(opOperand))
     return operandBuffer;
 
   // Bufferizing out-of-place: Allocate a new buffer.
@@ -597,11 +602,16 @@ bool mlir::linalg::comprehensive_bufferize::isFunctionArgument(Value value) {
   return isa<FuncOp>(bbArg.getOwner()->getParentOp());
 }
 
-bool mlir::linalg::comprehensive_bufferize::BufferizationState::isInPlace(
-    OpOperand &opOperand) const {
+bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
+    isInPlace(OpOperand &opOperand) const {
   return aliasInfo.isInPlace(opOperand);
 }
 
+bool mlir::linalg::comprehensive_bufferize::AnalysisBufferizationState::
+    areEquivalentBufferizedValues(Value v1, Value v2) const {
+  return aliasInfo.areEquivalentBufferizedValues(v1, v2);
+}
+
 MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
     ShapedType shapedType, MemRefLayoutAttrInterface layout,
     Attribute memorySpace) {

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 7955f7b35b61b..67fd483641669 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -252,15 +252,13 @@ static bool hasReadAfterWriteInterference(
 
       // No conflict if the op interface says so.
       if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
-        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
-                                            aliasInfo))
+        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
           continue;
 
       if (conflictingWritingOp != readingOp)
         if (auto bufferizableOp =
                 options.dynCastBufferizableOp(conflictingWritingOp))
-          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
-                                              aliasInfo))
+          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state))
             continue;
 
       // Ops are not conflicting if they are in mutually exclusive regions.
@@ -496,7 +494,7 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
           for (OpOperand *opOperand :
                bufferizableOp.getAliasingOpOperand(opResult, state))
             if (state.isInPlace(*opOperand))
-              if (bufferizableOp.bufferRelation(opResult, aliasInfo, state) ==
+              if (bufferizableOp.bufferRelation(opResult, state) ==
                   BufferRelation::Equivalent)
                 aliasInfo.unionEquivalenceClasses(opResult, opOperand->get());
 }
@@ -687,12 +685,12 @@ checkBufferizationResult(Operation *op, const BufferizationOptions &options) {
   return success();
 }
 
-LogicalResult
-mlir::linalg::comprehensive_bufferize::analyzeOp(Operation *op,
-                                                 BufferizationState &state) {
+LogicalResult mlir::linalg::comprehensive_bufferize::analyzeOp(
+    Operation *op, AnalysisBufferizationState &state) {
   DominanceInfo domInfo(op);
   BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
-  const BufferizationOptions &options = state.getOptions();
+  const auto &options =
+      static_cast<const AnalysisBufferizationOptions &>(state.getOptions());
 
   if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
     return failure();
@@ -740,8 +738,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::bufferizeOp(
 }
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
-    Operation *op, std::unique_ptr<BufferizationOptions> options) {
-  BufferizationState state(op, *options);
+    Operation *op, std::unique_ptr<AnalysisBufferizationOptions> options) {
+  AnalysisBufferizationState state(op, *options);
   if (failed(analyzeOp(op, state)))
     return failure();
   if (options->testAnalysisOnly)

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index 6d153522811d7..6cce30d165a34 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -193,7 +193,6 @@ struct LinalgOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo,
                                 const BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }
@@ -264,7 +263,6 @@ struct TiledLoopOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo,
                                 const BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index ec2d6d758e0e0..fae2682c239fb 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -737,7 +737,6 @@ struct CallOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo,
                                 const BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }
@@ -964,9 +963,9 @@ annotateOpsWithBufferizationMarkers(FuncOp funcOp,
 }
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
-    ModuleOp moduleOp, std::unique_ptr<BufferizationOptions> options) {
+    ModuleOp moduleOp, std::unique_ptr<AnalysisBufferizationOptions> options) {
   IRRewriter rewriter(moduleOp.getContext());
-  BufferizationState state(moduleOp, *options);
+  AnalysisBufferizationState state(moduleOp, *options);
   ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
   BufferizationAliasInfo &aliasInfo = state.getAliasInfo();
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
index 6337cd023dd22..97cba18592c93 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp
@@ -124,7 +124,6 @@ struct ExecuteRegionOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo,
                                 const BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }
@@ -247,7 +246,6 @@ struct IfOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo,
                                 const BufferizationState &state) const {
     // IfOp results are equivalent to their corresponding yield values if both
     // yield values are equivalent to each other.
@@ -255,7 +253,7 @@ struct IfOpInterface
     SmallVector<OpOperand *> yieldValues =
         bufferizableOp.getAliasingOpOperand(opResult, state);
     assert(yieldValues.size() == 2 && "expected 2 yield values");
-    bool equivalentYields = aliasInfo.areEquivalentBufferizedValues(
+    bool equivalentYields = state.areEquivalentBufferizedValues(
         yieldValues[0]->get(), yieldValues[1]->get());
     return equivalentYields ? BufferRelation::Equivalent : BufferRelation::None;
   }
@@ -291,7 +289,6 @@ struct ForOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo,
                                 const BufferizationState &state) const {
     // ForOp results are equivalent to their corresponding init_args if the
     // corresponding iter_args and yield values are equivalent.
@@ -299,7 +296,7 @@ struct ForOpInterface
     OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
     auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
     auto yieldOp = cast<scf::YieldOp>(&forOp.getLoopBody().front().back());
-    bool equivalentYield = aliasInfo.areEquivalentBufferizedValues(
+    bool equivalentYield = state.areEquivalentBufferizedValues(
         bbArg, yieldOp->getOperand(opResult.getResultNumber()));
     return equivalentYield ? BufferRelation::Equivalent : BufferRelation::None;
   }
@@ -408,7 +405,9 @@ mlir::linalg::comprehensive_bufferize::scf_ext::AssertScfForAliasingProperties::
       OpOperand &forOperand = forOp.getOpOperandForResult(
           forOp->getResult(operand.getOperandNumber()));
       auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-      if (!aliasInfo.areAliasingBufferizedValues(operand.get(), bbArg)) {
+      // Note: This is overly strict. We should check for aliasing bufferized
+      // values. But we don't have a "must-alias" analysis yet.
+      if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
         // TODO: this could get resolved with copies but it can also turn into
         // swaps so we need to be careful about order of copies.
         status =

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
index 1f8cee5cf7adb..5e603b3765763 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.cpp
@@ -62,7 +62,6 @@ struct SelectOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo,
                                 const BufferizationState &state) const {
     return BufferRelation::None;
   }

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index a31832452ea84..dc0742b99022d 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -42,7 +42,6 @@ struct CastOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo,
                                 const BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }
@@ -137,7 +136,6 @@ struct ExtractSliceOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo,
                                 const BufferizationState &state) const {
     return BufferRelation::None;
   }
@@ -273,7 +271,6 @@ struct InsertOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo,
                                 const BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }
@@ -285,12 +282,12 @@ struct InsertOpInterface
 /// This is one particular type of relationship between ops on tensors that
 /// reduce to an equivalence on buffers. This should be generalized and
 /// exposed as interfaces on the proper types.
-static bool
-areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
-                             ExtractSliceOp st, InsertSliceOp sti) {
+static bool areEquivalentExtractSliceOps(const BufferizationState &state,
+                                         ExtractSliceOp st, InsertSliceOp sti) {
   if (!st || !sti)
     return false;
-  if (!aliasInfo.areEquivalentBufferizedValues(st.source(), sti.dest()))
+  // The slice must be extracted from a buffer equivalent to the insert dest.
+  if (!state.areEquivalentBufferizedValues(st.source(), sti.dest()))
     return false;
   if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
     return false;
@@ -299,12 +296,11 @@ areEquivalentExtractSliceOps(const BufferizationAliasInfo &aliasInfo,
 
 /// Return true if `value` is originating from an ExtractSliceOp that matches
 /// the given InsertSliceOp.
-static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
-                                      const BufferizationState &state,
+static bool hasMatchingExtractSliceOp(const BufferizationState &state,
                                       Value value, InsertSliceOp insertOp) {
   auto condition = [&](Value val) {
     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
-      if (areEquivalentExtractSliceOps(aliasInfo, extractOp, insertOp))
+      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
         return true;
     return false;
   };
@@ -336,15 +332,13 @@ struct InsertSliceOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo,
                                 const BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }
 
   bool isNotConflicting(Operation *op, OpOperand *uRead,
                         OpOperand *uConflictingWrite,
-                        const BufferizationState &state,
-                        const BufferizationAliasInfo &aliasInfo) const {
+                        const BufferizationState &state) const {
     Operation *readingOp = uRead->getOwner();
     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
 
@@ -360,7 +354,7 @@ struct InsertSliceOpInterface
 
       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(aliasInfo, state, uConflictingWrite->get(),
+          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
                                     insertSliceOp))
         // Case 1: The main insight is that InsertSliceOp reads only part of
         // the destination tensor. The overwritten area is not read. If
@@ -378,8 +372,7 @@ struct InsertSliceOpInterface
 
       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(aliasInfo, state, uRead->get(),
-                                    insertSliceOp))
+          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
         // Case 2: The read of the source tensor and the write to the dest
         // tensor via an InsertSliceOp is not a conflict if the read is
         // reading exactly that part of an equivalent tensor that the
@@ -410,9 +403,9 @@ struct InsertSliceOpInterface
       // memory segment of %1 with the exact same data. (Effectively, there
       // is no memory write here.)
       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          aliasInfo.areEquivalentBufferizedValues(uRead->get(),
-                                                  insertSliceOp.source()) &&
-          hasMatchingExtractSliceOp(aliasInfo, state, insertSliceOp.source(),
+          state.areEquivalentBufferizedValues(uRead->get(),
+                                              insertSliceOp.source()) &&
+          hasMatchingExtractSliceOp(state, insertSliceOp.source(),
                                     insertSliceOp))
         return true;
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
index 0b3d8ff6d2663..4dd7fdcebc034 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.cpp
@@ -85,7 +85,6 @@ struct TransferWriteOpInterface
   }
 
   BufferRelation bufferRelation(Operation *op, OpResult opResult,
-                                const BufferizationAliasInfo &aliasInfo,
                                 const BufferizationState &state) const {
     return BufferRelation::Equivalent;
   }

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
index 36da9e52ea3c7..dfc25bae2a475 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -75,7 +75,7 @@ static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
 }
 
 void LinalgComprehensiveModuleBufferize::runOnOperation() {
-  auto options = std::make_unique<BufferizationOptions>();
+  auto options = std::make_unique<AnalysisBufferizationOptions>();
   if (useAlloca) {
     options->allocationFn = allocationFnUsingAlloca;
     options->deallocationFn = [](OpBuilder &b, Location loc, Value v) {

diff  --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index d0852e9c65189..87677ef2383de 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -101,7 +101,7 @@ struct TestComprehensiveFunctionBufferize
 } // namespace
 
 void TestComprehensiveFunctionBufferize::runOnOperation() {
-  auto options = std::make_unique<BufferizationOptions>();
+  auto options = std::make_unique<AnalysisBufferizationOptions>();
 
   if (!allowReturnMemref)
     options->addPostAnalysisStep<scf_ext::AssertScfForAliasingProperties>();


        


More information about the Mlir-commits mailing list