[Mlir-commits] [mlir] 8dbbb22 - [mlir][Linalg] NFC - Refactor and simplify Promotion

Nicolas Vasilache llvmlistbot at llvm.org
Mon May 11 07:46:34 PDT 2020


Author: Nicolas Vasilache
Date: 2020-05-11T10:44:45-04:00
New Revision: 8dbbb223834d1715bc9869aa409a4b0f52816da3

URL: https://github.com/llvm/llvm-project/commit/8dbbb223834d1715bc9869aa409a4b0f52816da3
DIFF: https://github.com/llvm/llvm-project/commit/8dbbb223834d1715bc9869aa409a4b0f52816da3.diff

LOG: [mlir][Linalg] NFC - Refactor and simplify Promotion

Summary: This revision introduces LinalgPromotionOptions to make it easier to control how promotion patterns are applied. It also simplifies the different entry points into Promotion in preparation for behavior changes in subsequent revisions.

Differential Revision: https://reviews.llvm.org/D79489

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/lib/Transforms/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index b67ff776ea4a..896b31835fb4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -89,11 +89,30 @@ LinalgOp interchange(LinalgOp op, ArrayRef<unsigned> interchangeVector);
 /// Returns a list of PromotionInfo which hold the promoted buffer and the
 /// full and partial views indexing into the buffer.
 // TODO: revisit dynamicBuffers option.
-LinalgOp promoteSubViewOperands(OpBuilder &b, LinalgOp op,
-                                llvm::SetVector<Value> subViews,
-                                bool dynamicBuffers = false,
-                                int64_t alignment = 0,
-                                OperationFolder *folder = nullptr);
+struct LinalgPromotionOptions {
+  /// Indices of subViews to promote. If `None`, try to promote all operands.
+  Optional<DenseSet<unsigned>> operandsToPromote = None;
+  LinalgPromotionOptions &setOperandsToPromote(ArrayRef<int64_t> operands) {
+    operandsToPromote = DenseSet<unsigned>();
+    operandsToPromote->insert(operands.begin(), operands.end());
+    return *this;
+  }
+  /// Allow the use of dynamicaly-sized buffers.
+  bool dynamicBuffers = false;
+  LinalgPromotionOptions &setDynamicBuffers(unsigned dynamic) {
+    dynamicBuffers = dynamic;
+    return *this;
+  }
+  /// Alignment of promoted buffer. If `None` do not specify alignment.
+  Optional<unsigned> alignment = None;
+  LinalgPromotionOptions &setAlignment(unsigned align) {
+    alignment = align;
+    return *this;
+  }
+};
+LinalgOp promoteSubViews(OpBuilder &b, LinalgOp op,
+                         LinalgPromotionOptions options,
+                         OperationFolder *folder = nullptr);
 
 /// Emit a suitable vector form for a Linalg op with fully static shape.
 void vectorizeLinalgOp(OpBuilder &builder, Operation *op);
@@ -125,8 +144,8 @@ interchangeGenericLinalgOpPrecondition(Operation *op,
                                        ArrayRef<unsigned> interchangeVector);
 
 /// Promote std.subviews feeding linalg operations.
-LogicalResult promoteSubviewsLinalgOpPrecondition(
-    Operation *op, Optional<DenseSet<unsigned>> operandIndicesToPromote = None);
+LogicalResult promoteSubviewsPrecondition(Operation *op,
+                                          LinalgPromotionOptions options);
 
 /// Rewrite a linalg.generic into a suitable vector.contraction op.
 LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
@@ -242,13 +261,12 @@ struct LinalgInterchangePattern : public LinalgBaseInterchangePattern {
 ///
 /// Linalg promotion patterns.
 ///
-/// Apply the `promoteSubViewOperands` transformation as a pattern.
+/// Apply the `promoteSubViews` transformation as a pattern.
 /// `marker` controls LinalgTransformMarker matching and update when specified.
-/// See `promoteSubViewOperands` for more details.
+/// See `promoteSubViews` for more details.
 struct LinalgBasePromotionPattern : public RewritePattern {
   LinalgBasePromotionPattern(StringRef opName, MLIRContext *context,
-                             ArrayRef<unsigned> operandsToPromote = {},
-                             unsigned alignment = 0,
+                             LinalgPromotionOptions options,
                              LinalgMarker marker = LinalgMarker(),
                              PatternBenefit benefit = 1);
   LogicalResult matchAndRewrite(Operation *op,
@@ -257,35 +275,17 @@ struct LinalgBasePromotionPattern : public RewritePattern {
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
   LinalgMarker marker;
-  /// Indices of subViews to promote.
-  SmallVector<unsigned, 4> operandsToPromote;
-  /// Alignment of promoted buffer.
-  unsigned alignment;
+  /// Promotion options.
+  LinalgPromotionOptions options;
 };
 
 template <typename OpTy>
 struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
-  LinalgPromotionPattern(MLIRContext *context,
-                         ArrayRef<unsigned> operandsToPromote = {},
-                         unsigned alignment = 0,
+  LinalgPromotionPattern(MLIRContext *context, LinalgPromotionOptions options,
                          LinalgMarker marker = LinalgMarker(),
                          PatternBenefit benefit = 1)
-      : LinalgBasePromotionPattern(OpTy::getOperationName(), context,
-                                   operandsToPromote, alignment, marker,
-                                   benefit) {}
-  LinalgPromotionPattern(MLIRContext *context,
-                         ArrayRef<unsigned> operandsToPromote,
-                         LinalgMarker marker = LinalgMarker(),
-                         PatternBenefit benefit = 1)
-      : LinalgPromotionPattern(context, operandsToPromote, 0, marker, benefit) {
-  }
-  LinalgPromotionPattern(MLIRContext *context, unsigned alignment,
-                         LinalgMarker marker = LinalgMarker(),
-                         PatternBenefit benefit = 1)
-      : LinalgPromotionPattern(context, {}, alignment, marker, benefit) {}
-  LinalgPromotionPattern(MLIRContext *context, LinalgMarker marker,
-                         PatternBenefit benefit = 1)
-      : LinalgPromotionPattern(context, {}, 0, marker, benefit) {}
+      : LinalgBasePromotionPattern(OpTy::getOperationName(), context, options,
+                                   marker, benefit) {}
 };
 
 ///
@@ -342,8 +342,6 @@ struct LinalgLoweringPattern : public RewritePattern {
       return failure();
     if (failed(marker.checkAndNotify(rewriter, linalgOp)))
       return failure();
-    if (failed(promoteSubviewsLinalgOpPrecondition(op)))
-      return failure();
 
     if (loweringType == LinalgLoweringType::LibraryCall) {
       // TODO: Move lowering to library calls here.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index 8e93ea355a12..86c5ceaef579 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -45,7 +45,45 @@ using folded_std_view = FoldedValueBuilder<ViewOp>;
 
 #define DEBUG_TYPE "linalg-promotion"
 
-/// If `size` comes from an AffineMinOp and one of the dimensions of AffineMin
+namespace {
+
+/// Helper struct that captures the information required to apply the
+/// transformation on each op. This bridges the abstraction gap with the
+/// user-facing API which exposes positional arguments to control which operands
+/// are promoted.
+struct LinalgOpInstancePromotionOptions {
+  LinalgOpInstancePromotionOptions(LinalgOp op,
+                                   const LinalgPromotionOptions &options);
+  /// SubViews to promote.
+  SetVector<Value> subViews;
+  /// Allow the use of dynamicaly-sized buffers.
+  bool dynamicBuffers;
+  /// Alignment of promoted buffer.
+  Optional<unsigned> alignment;
+};
+} // namespace
+
+LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
+    LinalgOp linalgOp, const LinalgPromotionOptions &options)
+    : subViews(), dynamicBuffers(options.dynamicBuffers),
+      alignment(options.alignment) {
+  if (options.operandsToPromote.hasValue()) {
+    for (unsigned idx : options.operandsToPromote.getValue()) {
+      auto *op = linalgOp.getBuffer(idx).getDefiningOp();
+      if (auto sv = dyn_cast_or_null<SubViewOp>(op))
+        subViews.insert(sv);
+    }
+  } else {
+    unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers();
+    for (unsigned idx = 0; idx < nBuffers; ++idx) {
+      auto *op = linalgOp.getBuffer(idx).getDefiningOp();
+      if (auto sv = dyn_cast_or_null<SubViewOp>(op))
+        subViews.insert(sv);
+    }
+  }
+}
+
+/// If `size` comes from an AffineMinOp and one of the values of AffineMinOp
 /// is a constant then return a new value set to the smallest such constant.
 /// Otherwise return size.
 static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
@@ -53,25 +91,26 @@ static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
   auto affineMinOp = dyn_cast_or_null<AffineMinOp>(size.getDefiningOp());
   if (!affineMinOp)
     return size;
-  if (!llvm::any_of(affineMinOp.getAffineMap().getResults(), [](AffineExpr e) {
-        return e.dyn_cast<AffineConstantExpr>();
-      }))
-    return size;
   int64_t minConst = std::numeric_limits<int64_t>::max();
   for (auto e : affineMinOp.getAffineMap().getResults())
     if (auto cst = e.dyn_cast<AffineConstantExpr>())
       minConst = std::min(minConst, cst.getValue());
-  assert(minConst != std::numeric_limits<int64_t>::max());
-  return b.create<ConstantIndexOp>(loc, minConst);
+  return (minConst == std::numeric_limits<int64_t>::max())
+             ? size
+             : b.create<ConstantIndexOp>(loc, minConst);
 }
 
+/// Alloc a new buffer of `size`. If `dynamicBuffers` is true allocate exactly
+/// the size needed, otherwise try to allocate a static bounding box.
 static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
-                         OperationFolder *folder, int64_t alignment = 0) {
+                         OperationFolder *folder,
+                         Optional<unsigned> alignment = None) {
   auto *ctx = size.getContext();
   auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8);
   IntegerAttr alignment_attr;
-  if (alignment)
-    alignment_attr = IntegerAttr::get(IntegerType::get(64, ctx), alignment);
+  if (alignment.hasValue())
+    alignment_attr =
+        IntegerAttr::get(IntegerType::get(64, ctx), alignment.getValue());
   if (!dynamicBuffers)
     if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
       return std_alloc(
@@ -100,11 +139,11 @@ static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
 // To account for general boundary effects, padding must be performed on the
 // boundary tiles. For now this is done with an unconditional `fill` op followed
 // by a partial `copy` op.
-static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
-                                           SubViewOp subView,
-                                           bool dynamicBuffers,
-                                           int64_t alignment,
-                                           OperationFolder *folder) {
+static PromotionInfo promoteSubviewAsNewBuffer(OpBuilder &b, Location loc,
+                                               SubViewOp subView,
+                                               bool dynamicBuffers,
+                                               Optional<unsigned> alignment,
+                                               OperationFolder *folder) {
   auto zero = folded_std_constant_index(folder, 0);
   auto one = folded_std_constant_index(folder, 1);
 
@@ -117,8 +156,10 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
   for (auto en : llvm::enumerate(subView.getRanges())) {
     auto rank = en.index();
     auto rangeValue = en.value();
-    // Try to extract a tight constant
+    // Try to extract a tight constant.
+    LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
     Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size);
+    LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
     allocSize = folded_std_muli(folder, allocSize, size);
     fullSizes.push_back(size);
     partialSizes.push_back(folded_std_dim(folder, subView, rank));
@@ -136,26 +177,26 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc,
   return PromotionInfo{buffer, fullLocalView, partialLocalView};
 }
 
-SmallVector<PromotionInfo, 8>
-mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
-                              ArrayRef<Value> subViews, bool dynamicBuffers,
-                              int64_t alignment, OperationFolder *folder) {
-  if (subViews.empty())
+static SmallVector<PromotionInfo, 8>
+promoteSubViews(OpBuilder &b, Location loc,
+                LinalgOpInstancePromotionOptions options,
+                OperationFolder *folder) {
+  if (options.subViews.empty())
     return {};
 
   ScopedContext scope(b, loc);
   SmallVector<PromotionInfo, 8> res;
-  res.reserve(subViews.size());
+  res.reserve(options.subViews.size());
   DenseMap<Value, PromotionInfo> promotionInfoMap;
-  for (auto v : subViews) {
+  for (auto v : options.subViews) {
     SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
-    auto promotionInfo = promoteFullTileBuffer(b, loc, subView, dynamicBuffers,
-                                               alignment, folder);
+    auto promotionInfo = promoteSubviewAsNewBuffer(
+        b, loc, subView, options.dynamicBuffers, options.alignment, folder);
     promotionInfoMap.insert(std::make_pair(subView.getResult(), promotionInfo));
     res.push_back(promotionInfo);
   }
 
-  for (auto v : subViews) {
+  for (auto v : options.subViews) {
     SubViewOp subView = cast<SubViewOp>(v.getDefiningOp());
     auto info = promotionInfoMap.find(v);
     if (info == promotionInfoMap.end())
@@ -172,7 +213,7 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
     linalg_fill(info->second.fullLocalView, fillVal);
   }
 
-  for (auto v : subViews) {
+  for (auto v : options.subViews) {
     auto info = promotionInfoMap.find(v);
     if (info == promotionInfoMap.end())
       continue;
@@ -182,11 +223,9 @@ mlir::linalg::promoteSubViews(OpBuilder &b, Location loc,
   return res;
 }
 
-LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
-                                              SetVector<Value> subViews,
-                                              bool dynamicBuffers,
-                                              int64_t alignment,
-                                              OperationFolder *folder) {
+static void promoteSubViews(OpBuilder &b, LinalgOp op,
+                            LinalgOpInstancePromotionOptions options,
+                            OperationFolder *folder) {
   assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
 
   if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
@@ -196,17 +235,15 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
   }
 
   // 1. Promote the specified views and use them in the new op.
-  ScopedContext scope(b, op.getLoc());
-  auto promotedBufferAndViews =
-      promoteSubViews(b, op.getLoc(), subViews.getArrayRef(), dynamicBuffers,
-                      alignment, folder);
+  auto loc = op.getLoc();
+  auto promotedBufferAndViews = promoteSubViews(b, loc, options, folder);
   SmallVector<Value, 8> opViews;
   opViews.reserve(op.getNumInputsAndOutputs());
   SmallVector<std::pair<Value, Value>, 8> writebackViews;
-  writebackViews.reserve(subViews.size());
+  writebackViews.reserve(promotedBufferAndViews.size());
   unsigned promotedIdx = 0;
   for (auto view : op.getInputsAndOutputBuffers()) {
-    if (subViews.count(view) != 0) {
+    if (options.subViews.count(view) != 0) {
       opViews.push_back(promotedBufferAndViews[promotedIdx].fullLocalView);
       writebackViews.emplace_back(std::make_pair(
           view, promotedBufferAndViews[promotedIdx].partialLocalView));
@@ -219,67 +256,55 @@ LinalgOp mlir::linalg::promoteSubViewOperands(OpBuilder &b, LinalgOp op,
   // 2. Append all other operands as they appear, this enforces that such
   // operands are not views. This is to support cases such as FillOp taking
   // extra scalars etc.
-  auto operands = getAssumedNonViewOperands(op);
-  opViews.append(operands.begin(), operands.end());
-  LinalgOp res = op.clone(b, op.getLoc(), opViews);
+  // Keep a reference to output buffers;
+  DenseSet<Value> originalOutputs(op.getOutputBuffers().begin(),
+                                  op.getOutputBuffers().end());
+  op.getOperation()->setOperands(0, opViews.size(), opViews);
 
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointAfter(op);
+  ScopedContext scope(b, loc);
   // 3. Emit write-back for the promoted output views: copy the partial view.
-  for (auto viewAndPartialLocalView : writebackViews) {
-    // WARNING: MUST use the old op to determine whether the operand view is an
-    // output.
-    bool isOutput =
-        op.getIndexOfOutputBuffer(viewAndPartialLocalView.first).hasValue();
-    if (isOutput)
+  for (auto viewAndPartialLocalView : writebackViews)
+    if (originalOutputs.count(viewAndPartialLocalView.first))
       linalg_copy(viewAndPartialLocalView.second,
                   viewAndPartialLocalView.first);
-  }
 
-  // 4. Dealloc local buffers.
+  // 4. Dealloc all local buffers.
   for (const auto &pi : promotedBufferAndViews)
     std_dealloc(pi.buffer);
-
-  return res;
 }
 
-static void promoteSubViews(FuncOp f, bool dynamicBuffers) {
-  SmallVector<LinalgOp, 8> toErase;
-  OperationFolder folder(f.getContext());
-  f.walk([dynamicBuffers, &folder, &toErase](LinalgOp op) {
-    if (!op.hasBufferSemantics())
-      return;
-
-    // TODO(ntv) some heuristic here to decide what to promote. Atm only float
-    // and integer buffers can be promoted.
-    SetVector<Value> subViews;
-    OpBuilder b(op);
-    for (auto it : op.getInputsAndOutputBuffers())
-      if (auto sv = dyn_cast_or_null<SubViewOp>(it.getDefiningOp()))
-        if (sv.getType().getElementType().isSignlessIntOrFloat())
-          subViews.insert(sv);
-    if (!subViews.empty()) {
-      promoteSubViewOperands(b, op, subViews, dynamicBuffers, 0, &folder);
-      toErase.push_back(op);
-    }
-  });
-  for (auto op : toErase)
-    op.erase();
-}
-
-LogicalResult mlir::linalg::promoteSubviewsLinalgOpPrecondition(
-    Operation *op, llvm::Optional<DenseSet<unsigned>> operandIndicesToPromote) {
+LogicalResult
+mlir::linalg::promoteSubviewsPrecondition(Operation *op,
+                                          LinalgPromotionOptions options) {
   LinalgOp linOp = dyn_cast<LinalgOp>(op);
   // Transformation applies to buffers only.
   if (!linOp || !linOp.hasBufferSemantics())
     return failure();
+  // Check that at least one of the requested operands is indeed a subview.
   for (auto en : llvm::enumerate(linOp.getInputsAndOutputBuffers())) {
     auto sv = isa_and_nonnull<SubViewOp>(en.value().getDefiningOp());
-    if (sv && (!operandIndicesToPromote.hasValue() ||
-               operandIndicesToPromote->count(en.index())))
-      return success();
+    if (sv) {
+      if (!options.operandsToPromote.hasValue() ||
+          options.operandsToPromote->count(en.index()))
+        return success();
+    }
   }
+  // TODO: Check all subviews requested are bound by a static constant.
+  // TODO: Check that the total footprint fits within a given size.
   return failure();
 }
 
+LinalgOp mlir::linalg::promoteSubViews(OpBuilder &b, LinalgOp linalgOp,
+                                       LinalgPromotionOptions options,
+                                       OperationFolder *folder) {
+  LinalgOpInstancePromotionOptions linalgOptions(linalgOp, options);
+  ::promoteSubViews(
+      b, linalgOp, LinalgOpInstancePromotionOptions(linalgOp, options), folder);
+  return linalgOp;
+}
+
 namespace {
 struct LinalgPromotionPass : public LinalgPromotionBase<LinalgPromotionPass> {
   LinalgPromotionPass() = default;
@@ -288,11 +313,20 @@ struct LinalgPromotionPass : public LinalgPromotionBase<LinalgPromotionPass> {
   }
 
   void runOnFunction() override {
-    promoteSubViews(getFunction(), dynamicBuffers);
+    OperationFolder folder(&getContext());
+    getFunction().walk([this, &folder](LinalgOp op) {
+      auto options = LinalgPromotionOptions().setDynamicBuffers(dynamicBuffers);
+      if (failed(promoteSubviewsPrecondition(op, options)))
+        return;
+      LLVM_DEBUG(llvm::dbgs() << "Promote: " << *(op.getOperation()) << "\n");
+      OpBuilder b(op);
+      promoteSubViews(b, op, options, &folder);
+    });
   }
 };
 } // namespace
 
+// TODO: support more transformation options in the pass.
 std::unique_ptr<OperationPass<FuncOp>>
 mlir::createLinalgPromotionPass(bool dynamicBuffers) {
   return std::make_unique<LinalgPromotionPass>(dynamicBuffers);

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index e229b10072f0..175c6c8ef096 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -160,51 +160,23 @@ LogicalResult mlir::linalg::LinalgBaseInterchangePattern::matchAndRewrite(
 }
 
 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
-    StringRef opName, MLIRContext *context,
-    ArrayRef<unsigned> operandsToPromote, unsigned alignment,
+    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
     LinalgMarker marker, PatternBenefit benefit)
     : RewritePattern(opName, {}, benefit, context), marker(marker),
-      operandsToPromote(operandsToPromote.begin(), operandsToPromote.end()),
-      alignment(alignment) {}
+      options(options) {}
 
 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
+  if (failed(marker.checkAndNotify(rewriter, op)))
     return failure();
-  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+  if (failed(promoteSubviewsPrecondition(op, options)))
     return failure();
-  if (operandsToPromote.empty()) {
-    if (failed(promoteSubviewsLinalgOpPrecondition(op, llvm::None)))
-      return failure();
-  } else {
-    DenseSet<unsigned> set;
-    set.insert(operandsToPromote.begin(), operandsToPromote.end());
-    if (failed(promoteSubviewsLinalgOpPrecondition(op, set)))
-      return failure();
-  }
-
-  llvm::SetVector<Value> subViews;
-  if (!operandsToPromote.empty()) {
-    for (unsigned idx : operandsToPromote) {
-      auto *op = linalgOp.getBuffer(idx).getDefiningOp();
-      if (auto sv = dyn_cast_or_null<SubViewOp>(op))
-        subViews.insert(sv);
-    }
-  } else {
-    unsigned nBuffers = linalgOp.getNumInputsAndOutputBuffers();
-    for (unsigned idx = 0; idx < nBuffers; ++idx) {
-      auto *op = linalgOp.getBuffer(idx).getDefiningOp();
-      if (auto sv = dyn_cast_or_null<SubViewOp>(op))
-        subViews.insert(sv);
-    }
-  }
-
-  auto promotedOp =
-      promoteSubViewOperands(rewriter, op, subViews, /*dynamicBuffers=*/false,
-                             /*alignment=*/alignment);
-  marker.replaceLinalgMarker(rewriter, promotedOp.getOperation());
-  rewriter.eraseOp(op);
+  rewriter.updateRootInPlace(op, [&]() {
+    auto promotedOp = promoteSubViews(rewriter, op, options);
+    (void)promotedOp;
+    assert(promotedOp && "Unexpected pattern failure");
+    marker.replaceLinalgMarker(rewriter, op);
+  });
   return success();
 }
 

diff  --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index f3861c38fa60..eb27a7ae0034 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -120,15 +120,13 @@ static void applyPatterns(FuncOp funcOp) {
   // Linalg subview operands promotion.
   //===--------------------------------------------------------------------===//
   patterns.insert<LinalgPromotionPattern<MatmulOp>>(
-      ctx, LinalgMarker({"_promote_views_"}, "_views_promoted_"));
+      ctx, LinalgPromotionOptions(),
+      LinalgMarker({"_promote_views_"}, "_views_promoted_"));
   patterns.insert<LinalgPromotionPattern<MatmulOp>>(
-      ctx,
-      /*operandsToPromote=*/ArrayRef<unsigned>{0},
+      ctx, LinalgPromotionOptions().setOperandsToPromote({0}),
       LinalgMarker({"_promote_first_view_"}, "_first_view_promoted_"));
   patterns.insert<LinalgPromotionPattern<FillOp>>(
-      ctx,
-      /*operandsToPromote=*/ArrayRef<unsigned>{0},
-      /*alignment=*/32,
+      ctx, LinalgPromotionOptions().setOperandsToPromote({0}).setAlignment(32),
       LinalgMarker({"_promote_views_aligned_"}, "_views_aligned_promoted_"));
 
   applyPatternsAndFoldGreedily(funcOp, patterns);


        


More information about the Mlir-commits mailing list