[Mlir-commits] [mlir] 847710f - [mlir][linalg][bufferize] Add dialect filter to BufferizationOptions

Matthias Springer llvmlistbot at llvm.org
Wed Dec 8 06:51:46 PST 2021


Author: Matthias Springer
Date: 2021-12-08T23:51:18+09:00
New Revision: 847710f7b77ea4e3cd43f62b5b7d920ac47405a5

URL: https://github.com/llvm/llvm-project/commit/847710f7b77ea4e3cd43f62b5b7d920ac47405a5
DIFF: https://github.com/llvm/llvm-project/commit/847710f7b77ea4e3cd43f62b5b7d920ac47405a5.diff

LOG: [mlir][linalg][bufferize] Add dialect filter to BufferizationOptions

This adds a new option `dialectFilter` to BufferizationOptions. Only ops from dialects that are allow-listed in the filter are bufferized. Other ops are left unbufferized. Note: This option requires `allowUnknownOps = true`.

To make use of `dialectFilter`, BufferizationOptions or BufferizationState must be passed to various helper functions.

The purpose of this change is to provide better infrastructure for partial bufferization, which will be fully activated in a subsequent change.

Differential Revision: https://reviews.llvm.org/D114691

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
    mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
    mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 527f11d4d93e6..df327aa5e243c 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -30,6 +30,7 @@ namespace comprehensive_bufferize {
 static constexpr int64_t kBufferAlignments = 128;
 
 class BufferizationAliasInfo;
+class BufferizableOpInterface;
 struct BufferizationOptions;
 class BufferizationState;
 struct PostAnalysisStep;
@@ -92,6 +93,33 @@ struct BufferizationOptions {
         std::make_unique<Step>(std::forward<Args>(args)...));
   }
 
+  /// Return `true` if the op is allowed to be bufferized.
+  bool isOpAllowed(Operation *op) const {
+    if (!dialectFilter.hasValue())
+      return true;
+    return dialectFilter->contains(op->getDialect()->getNamespace());
+  }
+
+  /// Allow-list the given dialects in the dialect filter. Only ops from
+  /// allow-listed dialects will be bufferized.
+  template <typename... DialectTs>
+  void addToDialectFilter() {
+    // The following expands a call to addToDialectFilterImpl for each dialect
+    // in 'DialectTs'. This magic is necessary due to a limitation in the places
+    // that a parameter pack can be expanded in c++11.
+    // FIXME: In c++17 this can be simplified by using 'fold expressions'.
+    (void)std::initializer_list<int>{
+        0, (addToDialectFilterImpl<DialectTs>(), 0)...};
+  }
+
+  /// Try to cast the given op to BufferizableOpInterface if the op is
+  /// allow-listed.
+  BufferizableOpInterface dynCastBufferizableOp(Operation *op) const;
+
+  /// Try to cast the defining op of the given value to
+  /// BufferizableOpInterface if that op is allow-listed.
+  BufferizableOpInterface dynCastBufferizableOp(Value value) const;
+
   /// Helper functions for allocation, deallocation, memory copying.
   std::unique_ptr<AllocationCallbacks> allocationFns;
 
@@ -114,6 +142,25 @@ struct BufferizationOptions {
 
   /// Registered post analysis steps.
   PostAnalysisStepList postAnalysisSteps;
+
+  /// Only bufferize ops from dialects that are allow-listed by the filter.
+  /// All other ops are ignored. This option controls the scope of partial
+  /// bufferization.
+  ///
+  /// Note: If no filter is specified, all ops are bufferized (as long as they
+  /// implement BufferizableOpInterface). If a filter is specified,
+  /// `allowUnknownOps` should be enabled. Otherwise, bufferization would fail
+  /// when encountering an op that is forbidden by the filter.
+  Optional<DenseSet<StringRef>> dialectFilter;
+
+private:
+  /// Allow-list a dialect in the dialect filter.
+  template <typename DialectT>
+  void addToDialectFilterImpl() {
+    if (!dialectFilter.hasValue())
+      dialectFilter.emplace();
+    dialectFilter->insert(DialectT::getDialectNamespace());
+  }
 };
 
 /// Specify fine-grain relationship between buffers to enable more analysis.
@@ -128,7 +175,8 @@ enum class BufferRelation {
 /// equivalence classes to support bufferization.
 class BufferizationAliasInfo {
 public:
-  explicit BufferizationAliasInfo(Operation *rootOp);
+  explicit BufferizationAliasInfo(Operation *rootOp,
+                                  const BufferizationOptions &options);
 
   // BufferizationAliasInfo should be passed as a reference.
   BufferizationAliasInfo(const BufferizationAliasInfo &) = delete;
@@ -265,7 +313,7 @@ bool isValueRead(Value value);
 /// starting the traversal from Value 1, the resulting SetVector is:
 /// { 2, 7, 8, 5 }
 llvm::SetVector<Value>
-findValueInReverseUseDefChain(Value value,
+findValueInReverseUseDefChain(Value value, const BufferizationOptions &options,
                               std::function<bool(Value)> condition);
 
 /// Find the Value of the last preceding write of a given Value.
@@ -276,7 +324,7 @@ findValueInReverseUseDefChain(Value value,
 ///
 /// Note: When reaching an end of the reverse SSA use-def chain, that value
 /// is returned regardless of whether it is a memory write or not.
-Value findLastPrecedingWrite(Value value);
+Value findLastPrecedingWrite(Value value, const BufferizationOptions &options);
 
 /// Dialect-specific bufferization state. Analysis/bufferization information
 /// that is specific to ops from a certain dialect can be stored in derived
@@ -300,7 +348,7 @@ struct DialectBufferizationState {
 class BufferizationState {
 public:
   BufferizationState(Operation *op, const BufferizationOptions &options)
-      : aliasInfo(op), options(options), builder(op->getContext()) {}
+      : aliasInfo(op, options), options(options), builder(op->getContext()) {}
 
   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
index 6a35c0e3bb525..a81b52d1433f7 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td
@@ -266,6 +266,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*methodName=*/"isNotConflicting",
         /*args=*/(ins "OpOperand *":$uRead,
                       "OpOperand *":$uWrite,
+                      "BufferizationState &":$state,
                       "const BufferizationAliasInfo &":$aliasInfo),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index 3f9c8979e1844..ffb8a7a25c273 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -78,7 +78,8 @@ BufferizationOptions::BufferizationOptions()
 // BufferizationAliasInfo
 //===----------------------------------------------------------------------===//
 
-BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
+BufferizationAliasInfo::BufferizationAliasInfo(
+    Operation *rootOp, const BufferizationOptions &options) {
   rootOp->walk([&](Operation *op) {
     for (Value v : op->getResults())
       if (v.getType().isa<TensorType>())
@@ -93,6 +94,8 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
   // Set up alias sets for OpResults that must bufferize in-place. This should
   // be done before making any other bufferization decisions.
   rootOp->walk([&](BufferizableOpInterface bufferizableOp) {
+    if (!options.isOpAllowed(bufferizableOp))
+      return WalkResult::skip();
     for (OpResult opResult : bufferizableOp->getOpResults()) {
       if (opResult.getType().isa<TensorType>())
         if (bufferizableOp.mustBufferizeInPlace(opResult)) {
@@ -105,6 +108,7 @@ BufferizationAliasInfo::BufferizationAliasInfo(Operation *rootOp) {
           markInPlace(opResult);
         }
     }
+    return WalkResult::advance();
   });
 }
 
@@ -197,6 +201,21 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
   }
 }
 
+BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
+    BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
+  if (isOpAllowed(op))
+    return dyn_cast<BufferizableOpInterface>(op);
+  return nullptr;
+}
+
+BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
+    BufferizationOptions::dynCastBufferizableOp(Value value) const {
+  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
+    if (isOpAllowed(bufferizableOp.getOperation()))
+      return bufferizableOp;
+  return nullptr;
+}
+
 /// Determine which OpOperand* will alias with `result` if the op is bufferized
 /// in place. Return an empty vector if the op is not bufferizable.
 SmallVector<OpOperand *>
@@ -283,7 +302,8 @@ bool mlir::linalg::comprehensive_bufferize::isValueRead(Value value) {
 // further.
 llvm::SetVector<Value>
 mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
-    Value value, std::function<bool(Value)> condition) {
+    Value value, const BufferizationOptions &options,
+    std::function<bool(Value)> condition) {
   llvm::SetVector<Value> result, workingSet;
   workingSet.insert(value);
 
@@ -296,7 +316,7 @@ mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
 
     OpResult opResult = value.cast<OpResult>();
     SmallVector<OpOperand *> opOperands = getAliasingOpOperand(opResult);
-    if (opOperands.empty()) {
+    if (opOperands.empty() || !options.isOpAllowed(value.getDefiningOp())) {
       result.insert(value);
       continue;
     }
@@ -310,13 +330,13 @@ mlir::linalg::comprehensive_bufferize::findValueInReverseUseDefChain(
 
 // Find the Value of the last preceding write of a given Value.
 Value mlir::linalg::comprehensive_bufferize::findLastPrecedingWrite(
-    Value value) {
+    Value value, const BufferizationOptions &options) {
   SetVector<Value> result =
-      findValueInReverseUseDefChain(value, [](Value value) {
+      findValueInReverseUseDefChain(value, options, [&](Value value) {
         Operation *op = value.getDefiningOp();
         if (!op)
           return true;
-        auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
+        auto bufferizableOp = options.dynCastBufferizableOp(op);
         if (!bufferizableOp)
           return true;
         return bufferizableOp.isMemoryWrite(value.cast<OpResult>());
@@ -374,9 +394,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
     // Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
     // use-def chain, it returns that value, regardless of whether it is a
     // memory write or not.
-    Value lastWrite = findLastPrecedingWrite(operand);
-    if (auto bufferizableOp =
-            lastWrite.getDefiningOp<BufferizableOpInterface>())
+    Value lastWrite = findLastPrecedingWrite(operand, options);
+    if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
       if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>()))
         skipCopy = true;
     // Do not copy if the copied data is never read.
@@ -433,7 +452,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
 
   // Bufferize using `BufferizableOpInterface`. Interface implementations are
   // responsible for bufferizing nested ops.
-  if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op)) {
+  if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
     b.setInsertionPoint(op);
     return bufferizableOp.bufferize(b, state);
   }
@@ -640,8 +659,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
     if (options.allowUnknownOps) {
       // `tensor` was not bufferized yet. This should never happen with
       // bufferizable ops.
-      assert(!tensor.getDefiningOp<BufferizableOpInterface>() &&
-             "tensor is not mapped");
+      assert(!options.dynCastBufferizableOp(tensor) && "tensor is not mapped");
       // Insert to_memref op.
       OpBuilder b(tensor.getContext());
       setInsertionPointAfter(b, tensor);

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index 7f2ae60b1309b..6cfae7cc702e5 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -256,13 +256,13 @@ static bool aliasesNonWritableBuffer(Value value,
   aliasInfo.applyOnAliases(value, [&](Value v) {
     // Query BufferizableOpInterface to see if the OpResult is writable.
     // TODO: Out-of-place bufferized OpResult could be considered writable.
-    if (auto bufferizableOp = v.getDefiningOp<BufferizableOpInterface>())
+    if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(v))
       if (bufferizableOp && bufferizableOp.isWritable(v, state))
         return;
 
     // Query BufferizableOpInterface to see if the BlockArgument is writable.
     if (auto bbArg = v.dyn_cast<BlockArgument>())
-      if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(
+      if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(
               bbArg.getOwner()->getParentOp()))
         if (bufferizableOp.isWritable(bbArg, state))
           return;
@@ -324,11 +324,12 @@ static bool happensBefore(Operation *a, Operation *b,
 /// A conflict is: According to SSA use-def chains, a read R is supposed to read
 /// the result of a write W1. But because of bufferization decisions, R actually
 /// reads another write W2.
-static bool
-hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
-                              const DenseSet<OpOperand *> &usesWrite,
-                              const DominanceInfo &domInfo,
-                              const BufferizationAliasInfo &aliasInfo) {
+static bool hasReadAfterWriteInterference(
+    const DenseSet<OpOperand *> &usesRead,
+    const DenseSet<OpOperand *> &usesWrite, const DominanceInfo &domInfo,
+    BufferizationState &state, const BufferizationAliasInfo &aliasInfo) {
+  const BufferizationOptions &options = state.getOptions();
+
   for (OpOperand *uRead : usesRead) {
     Operation *readingOp = uRead->getOwner();
 
@@ -341,7 +342,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
     // In the above example, if uRead is the OpOperand of reading_op, lastWrite
     // is %0. Note that operations that create an alias but do not write (such
     // as ExtractSliceOp) are skipped.
-    Value lastWrite = findLastPrecedingWrite(uRead->get());
+    Value lastWrite = findLastPrecedingWrite(uRead->get(), options);
 
     // Look for conflicting memory writes. Potential conflicts are writes to an
     // alias that have been decided to bufferize inplace.
@@ -370,15 +371,15 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
         continue;
 
       // No conflict if the op interface says so.
-      if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(readingOp))
-        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
+      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp))
+        if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
                                             aliasInfo))
           continue;
 
       if (conflictingWritingOp != readingOp)
         if (auto bufferizableOp =
-                dyn_cast<BufferizableOpInterface>(conflictingWritingOp))
-          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
+                options.dynCastBufferizableOp(conflictingWritingOp))
+          if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state,
                                               aliasInfo))
             continue;
 
@@ -452,7 +453,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
 /// involving aliases of the given OpOperand are checked.
 bool wouldCreateReadAfterWriteInterference(
     OpOperand &operand, OpResult result, const DominanceInfo &domInfo,
-    const BufferizationAliasInfo &aliasInfo,
+    BufferizationState &state, const BufferizationAliasInfo &aliasInfo,
     bool checkConsistencyOnly = false) {
 #ifndef NDEBUG
   if (result) {
@@ -496,7 +497,8 @@ bool wouldCreateReadAfterWriteInterference(
   if (!checkConsistencyOnly && bufferizesToMemoryWrite(operand))
     usesWrite.insert(&operand);
 
-  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, aliasInfo);
+  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state,
+                                       aliasInfo);
 }
 
 /// Return true if bufferizing `opOperand` inplace with `opResult` would create
@@ -555,7 +557,7 @@ static LogicalResult bufferizableInPlaceAnalysisImpl(
 
   bool foundInterference =
       wouldCreateWriteToNonWritableBuffer(operand, result, aliasInfo, state) ||
-      wouldCreateReadAfterWriteInterference(operand, result, domInfo,
+      wouldCreateReadAfterWriteInterference(operand, result, domInfo, state,
                                             aliasInfo);
 
   if (foundInterference)
@@ -603,7 +605,7 @@ static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
   for (Operation *op : reverse(ops))
     for (OpOperand &opOperand : op->getOpOperands())
       if (opOperand.get().getType().isa<TensorType>())
-        if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+        if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op))
           if (OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand))
             if (failed(bufferizableInPlaceAnalysisImpl(
                     opOperand, opResult, aliasInfo, state, domInfo)))
@@ -633,9 +635,10 @@ static LogicalResult inPlaceAnalysis(Operation *op,
 
 /// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
 static void equivalenceAnalysis(SmallVector<Operation *> &ops,
-                                BufferizationAliasInfo &aliasInfo) {
+                                BufferizationAliasInfo &aliasInfo,
+                                const BufferizationOptions &options) {
   for (Operation *op : ops)
-    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
       for (OpResult opResult : op->getOpResults())
         if (opResult.getType().isa<TensorType>())
           if (aliasInfo.isInPlace(opResult)) {
@@ -652,7 +655,8 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
 /// Analyze equivalence of tied OpResult/OpOperand pairs of all ops contained
 /// in `op`.
 static void equivalenceAnalysis(Operation *op,
-                                BufferizationAliasInfo &aliasInfo) {
+                                BufferizationAliasInfo &aliasInfo,
+                                const BufferizationOptions &options) {
   // Traverse ops in PostOrder: Nested ops first, then enclosing ops.
   SmallVector<Operation *> ops;
   op->walk<WalkOrder::PostOrder>([&](Operation *op) {
@@ -662,21 +666,23 @@ static void equivalenceAnalysis(Operation *op,
     ops.push_back(op);
   });
 
-  equivalenceAnalysis(ops, aliasInfo);
+  equivalenceAnalysis(ops, aliasInfo, options);
 }
 
 /// Assert that the current bufferization decisions are consistent.
 static LogicalResult
 checkAliasInfoConsistency(Operation *op, const DominanceInfo &domInfo,
+                          BufferizationState &state,
                           const BufferizationAliasInfo &aliasInfo) {
+  const BufferizationOptions &options = state.getOptions();
   Operation *inconsistentOp = nullptr;
   WalkResult walkResult = op->walk([&](Operation *op) {
-    if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
+    if (auto bufferizableOp = options.dynCastBufferizableOp(op))
       for (OpOperand &opOperand : op->getOpOperands())
         if (opOperand.get().getType().isa<TensorType>()) {
           OpResult opResult = bufferizableOp.getAliasingOpResult(opOperand);
           if (wouldCreateReadAfterWriteInterference(
-                  opOperand, opResult, domInfo, aliasInfo,
+                  opOperand, opResult, domInfo, state, aliasInfo,
                   /*checkConsistencyOnly=*/true)) {
             // This error can happen for two reasons. Either the input IR
             // already has a read-after-write conflict. Or certain
@@ -723,14 +729,14 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
   DominanceInfo domInfo(op);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
-  if (failed(checkAliasInfoConsistency(op, domInfo, aliasInfo)))
+  if (failed(checkAliasInfoConsistency(op, domInfo, state, aliasInfo)))
     return failure();
 
   // If the analysis fails, just return.
   if (failed(inPlaceAnalysis(op, aliasInfo, state, domInfo,
                              options.analysisFuzzerSeed)))
     return failure();
-  equivalenceAnalysis(op, aliasInfo);
+  equivalenceAnalysis(op, aliasInfo, options);
 
   auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
     for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
@@ -740,7 +746,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
       // Analyze ops that were created by the PostAnalysisStep.
       if (failed(inPlaceAnalysis(newOps, aliasInfo, state, domInfo)))
         return failure();
-      equivalenceAnalysis(newOps, aliasInfo);
+      equivalenceAnalysis(newOps, aliasInfo, options);
     }
     return success();
   };

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
index b68f1a2da5162..3ac95dbe18c4f 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -388,6 +388,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
         std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
         SmallVector<Operation *> &newOps) {
   OpBuilder b(op->getContext());
+  const BufferizationOptions &options = state.getOptions();
 
   WalkResult status = op->walk([&](Operation *op) {
     for (OpOperand &operand : op->getOpOperands()) {
@@ -396,7 +397,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
         continue;
 
       SetVector<Value> maybeInitTensor =
-          findValueInReverseUseDefChain(operand.get(), [&](Value val) {
+          findValueInReverseUseDefChain(operand.get(), options, [&](Value val) {
             // Continue traversal until this function returns true.
             OpResult opResult = val.dyn_cast<OpResult>();
             if (!opResult)

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index ca38d27e121e1..cfc04be793b12 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -276,6 +276,7 @@ static bool isSourceEquivalentToAMatchingInplaceExtractSliceOp(
 /// Return true if `value` is originating from an ExtractSliceOp that matches
 /// the given InsertSliceOp.
 static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
+                                      const BufferizationOptions &options,
                                       Value value, InsertSliceOp insertOp) {
   auto condition = [&](Value val) {
     if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
@@ -284,7 +285,7 @@ static bool hasMatchingExtractSliceOp(const BufferizationAliasInfo &aliasInfo,
     return false;
   };
 
-  return llvm::all_of(findValueInReverseUseDefChain(value, condition),
+  return llvm::all_of(findValueInReverseUseDefChain(value, options, condition),
                       condition);
 }
 
@@ -311,7 +312,7 @@ struct InsertSliceOpInterface
   }
 
   bool isNotConflicting(Operation *op, OpOperand *uRead,
-                        OpOperand *uConflictingWrite,
+                        OpOperand *uConflictingWrite, BufferizationState &state,
                         const BufferizationAliasInfo &aliasInfo) const {
     Operation *readingOp = uRead->getOwner();
     Operation *conflictingWritingOp = uConflictingWrite->getOwner();
@@ -328,8 +329,8 @@ struct InsertSliceOpInterface
 
       // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
       if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(aliasInfo, uConflictingWrite->get(),
-                                    insertSliceOp))
+          hasMatchingExtractSliceOp(aliasInfo, state.getOptions(),
+                                    uConflictingWrite->get(), insertSliceOp))
         // Case 1: The main insight is that InsertSliceOp reads only part of
         // the destination tensor. The overwritten area is not read. If
         // uConflictingWrite writes into exactly the memory location that is
@@ -346,7 +347,8 @@ struct InsertSliceOpInterface
 
       if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
           uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
-          hasMatchingExtractSliceOp(aliasInfo, uRead->get(), insertSliceOp))
+          hasMatchingExtractSliceOp(aliasInfo, state.getOptions(), uRead->get(),
+                                    insertSliceOp))
         // Case 2: The read of the source tensor and the write to the dest
         // tensor via an InsertSliceOp is not a conflict if the read is
         // reading exactly that part of an equivalent tensor that the
@@ -379,8 +381,8 @@ struct InsertSliceOpInterface
       if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
           aliasInfo.areEquivalentBufferizedValues(uRead->get(),
                                                   insertSliceOp.source()) &&
-          hasMatchingExtractSliceOp(aliasInfo, insertSliceOp.source(),
-                                    insertSliceOp))
+          hasMatchingExtractSliceOp(aliasInfo, state.getOptions(),
+                                    insertSliceOp.source(), insertSliceOp))
         return true;
 
     return false;

diff  --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
index 6da6b2a514dc5..2870d40f076c2 100644
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -8,6 +8,8 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
 // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
 
+// RUN: mlir-opt %s -allow-unregistered-dialect -test-comprehensive-function-bufferize="dialect-filter=tensor allow-unknown-ops allow-return-memref" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-TENSOR
+
 // CHECK-LABEL: func @use_of_unknown_op_1(
 //  CHECK-SAME:     %[[m1:.*]]: memref<?xf32
 func @use_of_unknown_op_1(%t1: tensor<?xf32> {linalg.inplaceable = true})
@@ -148,3 +150,20 @@ func @unknown_op_not_writable(
   // CHECK: return %[[alloc]]
   return %1 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-TENSOR-LABEL: func @simple_tensor_test(
+//  CHECK-TENSOR-SAME:     %[[t1:.*]]: tensor<?xf32>
+func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
+  // CHECK-TENSOR: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
+  %c0 = arith.constant 0 : index
+  // CHECK-TENSOR: %[[alloc:.*]] = memref.alloc
+  // CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK-TENSOR: memref.copy %[[t1_memref]], %[[casted]]
+  // CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]]
+  %0 = tensor.insert %f into %t1[%c0] : tensor<?xf32>
+  // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
+  // CHECK-TENSOR: return %[[casted_tensor]]
+  return %0 : tensor<?xf32>
+}

diff  --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
index 5ac15a979d00c..fae27fc1a3f4c 100644
--- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp
@@ -85,6 +85,10 @@ struct TestComprehensiveFunctionBufferize
       *this, "analysis-fuzzer-seed",
       llvm::cl::desc("Analyze ops in random order with a given seed (fuzzer)"),
       llvm::cl::init(0)};
+  ListOption<std::string> dialectFilter{
+      *this, "dialect-filter",
+      llvm::cl::desc("Bufferize only ops from the specified dialects"),
+      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
 };
 } // namespace
 
@@ -104,6 +108,12 @@ void TestComprehensiveFunctionBufferize::runOnFunction() {
   options.testAnalysisOnly = testAnalysisOnly;
   options.analysisFuzzerSeed = analysisFuzzerSeed;
 
+  if (dialectFilter.hasValue()) {
+    options.dialectFilter.emplace();
+    for (const std::string &dialectNamespace : dialectFilter)
+      options.dialectFilter->insert(dialectNamespace);
+  }
+
   Operation *op = getFunction().getOperation();
   if (failed(runComprehensiveBufferize(op, options)))
     return;


        


More information about the Mlir-commits mailing list