[Mlir-commits] [mlir] 4a66160 - [mlir][Linalg] NFC - Modernize APIs and get rid of unnecessary tiling patterns.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jan 6 13:27:40 PST 2022


Author: Nicolas Vasilache
Date: 2022-01-06T16:27:35-05:00
New Revision: 4a661602ef2db22272cbb39bdb179996dbfa54b1

URL: https://github.com/llvm/llvm-project/commit/4a661602ef2db22272cbb39bdb179996dbfa54b1
DIFF: https://github.com/llvm/llvm-project/commit/4a661602ef2db22272cbb39bdb179996dbfa54b1.diff

LOG: [mlir][Linalg] NFC - Modernize APIs and get rid of unnecessary tiling patterns.

Tiling patterns can be reduced to a single pattern by using interface-based patterns.

Differential Revision: https://reviews.llvm.org/D116733

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
    mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 7592094410632..4b55caed849d3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -169,9 +169,14 @@ struct TiledLinalgOp {
   SmallVector<Operation *, 8> loops;
   SmallVector<Value, 4> tensorResults;
 };
-FailureOr<TiledLinalgOp> tileLinalgOp(OpBuilder &b, LinalgOp op,
+FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op,
                                       const LinalgTilingOptions &options);
 
+/// Peel the loops of `res` selected by the indices in `peeledLoops`.
+void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
+                       ArrayRef<int64_t> peeledLoops,
+                       LinalgTilingLoopType loopType);
+
 /// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
 /// proceeds as follows:
 /// - Find outer parallel loops in these ops that can be fused.
@@ -594,24 +599,35 @@ struct LinalgTilingOptions {
 RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
 void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
 
-/// Base pattern that applies the tiling transformation specified by `options`.
-/// Abort and return failure in 2 cases:
-///   1. if the tiling specification is invalid and tiling fails to occur.
-///   2. if tiling occurs but `options.paddingValueComputationFunction` is set
-///      and some operand shape cannot be bounded statically.
-struct LinalgBaseTilingPattern : public RewritePattern {
-  // Entry point to match any LinalgOp OpInterface.
-  LinalgBaseTilingPattern(
+///
+/// Linalg tiling pattern.
+///
+/// Apply the `tileLinalgOp` transformation as a pattern.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `tileLinalgOp` for more details.
+// TODO: TiledOpInterface
+struct LinalgTilingPattern : public OpInterfaceRewritePattern<LinalgOp> {
+  /// Construct a generic pattern applied to all LinalgOps that satisfy `f`.
+  LinalgTilingPattern(
       MLIRContext *context, LinalgTilingOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
-  // Entry point to match a specific Linalg op.
-  LinalgBaseTilingPattern(
+
+  /// Construct a pattern specifically applied to `opName`.
+  LinalgTilingPattern(
       StringRef opName, MLIRContext *context, LinalgTilingOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
-  LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
-                                    TiledLinalgOp &result) const;
+
+  /// `matchAndRewrite` implementation that returns the significant transformed
+  /// pieces of IR.
+  FailureOr<TiledLinalgOp>
+  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(LinalgOp op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
@@ -620,68 +636,6 @@ struct LinalgBaseTilingPattern : public RewritePattern {
   LinalgTilingOptions options;
 };
 
-template <typename OpTy>
-struct LinalgTilingPattern : public LinalgBaseTilingPattern {
-  /// SFINAE: This constructor can only trigger for concrete ops that have a
-  /// static `getOperationName` method.
-  template <typename ConcreateOpTy = OpTy>
-  LinalgTilingPattern(
-      MLIRContext *context, LinalgTilingOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : LinalgBaseTilingPattern(ConcreateOpTy::getOperationName(), context,
-                                options, filter, benefit) {}
-
-  /// This constructor is available to anyone.
-  LinalgTilingPattern(
-      StringRef opName, MLIRContext *context, LinalgTilingOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : LinalgBaseTilingPattern(opName, context, options, filter, benefit) {}
-
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override {
-    TiledLinalgOp tiledLinalgOp;
-    if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
-                                                            tiledLinalgOp)))
-      return failure();
-    if (tiledLinalgOp.tensorResults.empty())
-      rewriter.eraseOp(op);
-    else
-      rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
-    return success();
-  }
-};
-
-struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern {
-  /// Entry point to match any LinalgOp OpInterface.
-  /// MatchAnyOpTag-based constructor with a mandatory `filter`.
-  LinalgGenericTilingPattern(
-      MLIRContext *context, LinalgTransformationFilter filter,
-      LinalgTilingOptions options = LinalgTilingOptions(),
-      PatternBenefit benefit = 1)
-      : LinalgBaseTilingPattern(context, options, filter, benefit) {}
-  /// Entry point to match a specific Linalg op.
-  LinalgGenericTilingPattern(
-      StringRef opName, MLIRContext *context, LinalgTilingOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : LinalgBaseTilingPattern(opName, context, options, filter, benefit) {}
-
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override {
-    TiledLinalgOp tiledLinalgOp;
-    if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
-                                                            tiledLinalgOp)))
-      return failure();
-    if (tiledLinalgOp.tensorResults.empty())
-      rewriter.eraseOp(op);
-    else
-      rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
-    return success();
-  }
-};
-
 ///
 /// Linalg padding pattern.
 ///
@@ -1395,6 +1349,32 @@ struct ExtractSliceOfPadTensorSwapPattern
                                 PatternRewriter &rewriter) const override;
 };
 
+//===----------------------------------------------------------------------===//
+// Helper classes for type list expansion.
+//===----------------------------------------------------------------------===//
+template <typename... OpTypes>
+class TilingPatterns;
+
+template <>
+class TilingPatterns<> {
+public:
+  static void insert(RewritePatternSet &patterns,
+                     const LinalgTilingOptions &options,
+                     const LinalgTransformationFilter &f) {}
+};
+
+template <typename OpTy, typename... OpTypes>
+class TilingPatterns<OpTy, OpTypes...> {
+public:
+  static void insert(RewritePatternSet &patterns,
+                     const LinalgTilingOptions &options,
+                     const LinalgTransformationFilter &f) {
+    patterns.add<LinalgTilingPattern>(OpTy::getOperationName(),
+                                      patterns.getContext(), options, f);
+    TilingPatterns<OpTypes...>::insert(patterns, options, f);
+  }
+};
+
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index f426af01d8722..bc53a719a4741 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -784,7 +784,9 @@ tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
       tileSizes[i] = zero;
   LinalgTilingOptions tileFusedLoopsOptions = options;
   tileFusedLoopsOptions.setTileSizes(tileSizes);
-  return tileLinalgOp(b, op, tileFusedLoopsOptions);
+  // TODO: Propagate RewriterBase everywhere.
+  IRRewriter rewriter(b);
+  return tileLinalgOp(rewriter, op, tileFusedLoopsOptions);
 }
 
 /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 6bdcc192e27aa..eb1415dabde2d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -283,10 +283,14 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
                           tileInterchange.begin(), tileInterchange.end()))
                       .setTileSizes(tileSizes)
                       .setLoopType(LinalgTilingLoopType::Loops);
-  Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions);
+
+  // TODO: Propagate RewriterBase everywhere.
+  IRRewriter rewriter(b);
+  FailureOr<TiledLinalgOp> tiledRootOp =
+      tileLinalgOp(rewriter, rootOp, tilingOptions);
 
   // Exit if tiling the root operation fails.
-  if (!tiledRootOp.hasValue())
+  if (failed(tiledRootOp))
     return failure();
 
   // Replace all uses of the root operation if it has been tiled before. All

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
index 136f38feedf35..859f3f8521b13 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -1,4 +1,4 @@
-//===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===//
+//===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -93,14 +93,13 @@ struct LinalgStrategyTilePass
     if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
       return;
 
-    RewritePatternSet tilingPattern(funcOp.getContext());
-    if (!anchorOpName.empty()) {
-      tilingPattern.add<LinalgGenericTilingPattern>(
-          anchorOpName, funcOp.getContext(), options, filter);
-    } else {
-      tilingPattern.add<LinalgGenericTilingPattern>(funcOp.getContext(), filter,
-                                                    options);
-    }
+    MLIRContext *ctx = funcOp.getContext();
+    RewritePatternSet tilingPattern(ctx);
+    if (!anchorOpName.empty())
+      tilingPattern.add<LinalgTilingPattern>(anchorOpName, ctx, options,
+                                             filter);
+    else
+      tilingPattern.add<LinalgTilingPattern>(ctx, options, filter);
     (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
   }
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index cb2987973ea51..89ca83375c0f6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -51,7 +51,7 @@ using LoopIndexToRangeIndexMap = DenseMap<int, int>;
 // a map from loop indices of the LinalgOp to the corresponding non-empty range
 // indices of newly created loops.
 static std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap>
-makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
+makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map,
                     ValueRange allShapeSizes, ValueRange allTileSizes) {
   assert(allTileSizes.size() == map.getNumResults());
   // Apply `map` to get shape sizes in loop order.
@@ -129,7 +129,7 @@ makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map,
 // TODO: Investigate whether mixing implicit and explicit indices
 // does not lead to losing information.
 static void
-transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
+transformIndexOps(RewriterBase &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
                   const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) {
   SmallVector<Value> allIvs(op.getNumLoops(), nullptr);
   for (auto &en : enumerate(allIvs)) {
@@ -144,7 +144,7 @@ transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl<Value> &ivs,
 // Insert a tile `source` into the destination tensor `dest`. The position at
 // which the tile is inserted (as well as size of tile) is taken from a given
 // ExtractSliceOp `sliceOp`.
-static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
+static Value insertSliceIntoTensor(RewriterBase &b, Location loc,
                                    tensor::ExtractSliceOp sliceOp, Value source,
                                    Value dest) {
   return b.create<tensor::InsertSliceOp>(
@@ -155,7 +155,7 @@ static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
 
 template <typename LoopTy>
 static FailureOr<TiledLinalgOp>
-tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
+tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
                  const LinalgTilingOptions &options) {
   auto nLoops = op.getNumLoops();
   // Initial tile sizes may be too big, only take the first nLoops.
@@ -216,7 +216,7 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
   LinalgOp res = op;
   SmallVector<Value, 4> ivs, tensorResults;
   auto tiledLoopBodyBuilder =
-      [&](OpBuilder &b, Location loc, ValueRange localIvs,
+      [&](OpBuilder &builder, Location loc, ValueRange localIvs,
           ValueRange operandValuesToUse) -> scf::ValueVector {
     ivs.assign(localIvs.begin(), localIvs.end());
 
@@ -255,9 +255,12 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
       // TODO: use an interface/adaptor to avoid leaking position in
       // `tiledOperands`.
       Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
+      // TODO: Propagate RewriterBase everywhere.
+      IRRewriter rewriter(b);
       if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
-        tensorResults.push_back(insertSliceIntoTensor(
-            b, loc, sliceOp, res->getResult(resultIdx), sliceOp.source()));
+        tensorResults.push_back(insertSliceIntoTensor(rewriter, loc, sliceOp,
+                                                      res->getResult(resultIdx),
+                                                      sliceOp.source()));
       } else {
         tensorResults.push_back(res->getResult(resultIdx));
       }
@@ -299,7 +302,7 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
 
 template <typename LoopTy>
 FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
-    OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) {
+    RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
 
@@ -321,7 +324,7 @@ FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
 }
 
 FailureOr<TiledLinalgOp>
-mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
+mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op,
                            const LinalgTilingOptions &options) {
   switch (options.loopType) {
   case LinalgTilingLoopType::Loops:
@@ -338,7 +341,7 @@ mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op,
 /// Generate a loop nest around a given PadTensorOp (for tiling). `newPadOp`
 /// and `loopNest` are output parameters that return the new (tiled) PadTensorOp
 /// and the loop nest.
-static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
+static LogicalResult tilePadTensorOp(RewriterBase &builder, PadTensorOp op,
                                      PadTensorOp &newPadOp, LoopNest &loopNest,
                                      const LinalgTilingOptions &options) {
   Location loc = op.getLoc();
@@ -384,8 +387,10 @@ static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op,
         auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>();
         assert(sliceOp && "expected ExtractSliceOp");
         // Insert the tile into the output tensor.
+        // TODO: Propagate RewriterBase everywhere.
+        IRRewriter rewriter(b);
         Value yieldValue =
-            insertSliceIntoTensor(b, loc, sliceOp, sliceOp, iterArgs[0]);
+            insertSliceIntoTensor(rewriter, loc, sliceOp, sliceOp, iterArgs[0]);
         return scf::ValueVector({yieldValue});
       });
   return success();
@@ -434,31 +439,6 @@ class CanonicalizationPatternList<OpTy, OpTypes...> {
     CanonicalizationPatternList<OpTypes...>::insert(patterns);
   }
 };
-
-/// Helper classes for type list expansion.
-template <typename... OpTypes>
-class RewritePatternList;
-
-template <>
-class RewritePatternList<> {
-public:
-  static void insert(RewritePatternSet &patterns,
-                     const LinalgTilingOptions &options) {}
-};
-
-template <typename OpTy, typename... OpTypes>
-class RewritePatternList<OpTy, OpTypes...> {
-public:
-  static void insert(RewritePatternSet &patterns,
-                     const LinalgTilingOptions &options) {
-    auto *ctx = patterns.getContext();
-    patterns.add<LinalgTilingPattern<OpTy>>(
-        ctx, options,
-        LinalgTransformationFilter(ArrayRef<StringAttr>{},
-                                   StringAttr::get(ctx, "tiled")));
-    RewritePatternList<OpTypes...>::insert(patterns, options);
-  }
-};
 } // namespace
 
 RewritePatternSet
@@ -500,11 +480,14 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
 /// Populate the given list with patterns that apply Linalg tiling.
 static void insertTilingPatterns(RewritePatternSet &patterns,
                                  const LinalgTilingOptions &options) {
-  RewritePatternList<GenericOp,
+  auto *ctx = patterns.getContext();
+  LinalgTransformationFilter f(ArrayRef<StringAttr>{},
+                               StringAttr::get(ctx, "tiled"));
+  TilingPatterns<GenericOp,
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
-                     >::insert(patterns, options);
-  patterns.add<PadTensorOpTilingPattern>(patterns.getContext(), options);
+                 >::insert(patterns, options, f);
+  patterns.add<PadTensorOpTilingPattern>(ctx, options);
 }
 
 static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 177a2abda6e7d..c1482f44b4cd0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1,4 +1,4 @@
-//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
+//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -284,19 +284,6 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
   return paddedSubviewResults;
 }
 
-/// Linalg base tiling pattern.
-mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
-    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
-    LinalgTransformationFilter filter, PatternBenefit benefit)
-    : RewritePattern(opName, benefit, context), filter(std::move(filter)),
-      options(std::move(options)) {}
-
-mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
-    MLIRContext *context, LinalgTilingOptions options,
-    LinalgTransformationFilter filter, PatternBenefit benefit)
-    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
-      filter(std::move(filter)), options(std::move(options)) {}
-
 /// Try to peel a loop `op` and return the new result.
 // TODO: Add support for scf.parallel and affine.for loops.
 static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter, Operation *op) {
@@ -325,14 +312,15 @@ static SmallVector<Value, 4> peelLoop(RewriterBase &rewriter,
 }
 
 /// Peel loops after tiling.
-static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
-                      const LinalgTilingOptions &options) {
-  for (int64_t loop : options.peeledLoops) {
+void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res,
+                                     ArrayRef<int64_t> peeledLoops,
+                                     LinalgTilingLoopType loopType) {
+  for (int64_t loop : peeledLoops) {
     assert(loop < static_cast<int64_t>(res.loops.size()) &&
            "requested peeling of non-existing loop");
     SmallVector<Value, 4> loopResults;
     Operation *loopOp = res.loops[loop];
-    if (options.loopType == LinalgTilingLoopType::TiledLoops) {
+    if (loopType == LinalgTilingLoopType::TiledLoops) {
       assert(llvm::all_of(
                  res.loops,
                  [&](Operation *op) { return op == res.loops.front(); }) &&
@@ -352,28 +340,6 @@ static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res,
   }
 }
 
-LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
-    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
-    return failure();
-  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
-    return failure();
-
-  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);
-
-  if (!res)
-    return failure();
-  // Clear filter to stop recursive pattern application.
-  filter.replaceLinalgTransformationFilter(rewriter, res->op);
-
-  // Peel loops.
-  peelLoops(rewriter, *res, options);
-
-  result = *res;
-  return success();
-}
-
 static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {
   if (tiledOp.loops.empty())
     return tiledOp.op.getOperation()->getResults();
@@ -459,9 +425,9 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
       })) {
     LinalgTilingOptions unfusedTilingOptions = tilingOptions;
     unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
-    Optional<TiledLinalgOp> unfusedTiledOp =
+    FailureOr<TiledLinalgOp> unfusedTiledOp =
         tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
-    if (!unfusedTiledOp)
+    if (failed(unfusedTiledOp))
       return failure();
     rewriter.replaceOp(tiledAndFusedOps->op,
                        getTiledOpResult(unfusedTiledOp.getValue()));
@@ -485,6 +451,48 @@ LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
   return success();
 }
 
+/// Linalg tiling pattern.
+mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
+    MLIRContext *context, LinalgTilingOptions options,
+    LinalgTransformationFilter f, PatternBenefit benefit)
+    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+      filter(std::move(f)), options(std::move(options)) {}
+
+mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
+    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
+    LinalgTransformationFilter f, PatternBenefit benefit)
+    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+      filter(std::move(f)), options(std::move(options)) {
+  this->filter.addFilter([opName](Operation *op) {
+    return success(op->getName().getStringRef() == opName);
+  });
+}
+
+FailureOr<TiledLinalgOp>
+mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
+    LinalgOp op, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, op)))
+    return failure();
+
+  FailureOr<TiledLinalgOp> res = tileLinalgOp(rewriter, op, options);
+  if (failed(res))
+    return failure();
+
+  // Clear filter to stop recursive pattern application.
+  // This must be done here to properly propagate to peeling branches.
+  filter.replaceLinalgTransformationFilter(rewriter, res->op);
+
+  // Peel the loops of the TiledLinalgOp.
+  peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType);
+
+  if (res->tensorResults.empty())
+    rewriter.eraseOp(op);
+  else
+    rewriter.replaceOp(op, res->tensorResults);
+
+  return res;
+}
+
 /// Linalg padding pattern.
 mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
     MLIRContext *context, LinalgPaddingOptions options,

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4a597f64d72ff..14593800b16f1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1178,8 +1178,9 @@ static void populateVectorizationPatterns(
 
   constexpr static StringRef kTiledMarker = "TILED";
   constexpr static StringRef kPromotedMarker = "PROMOTED";
-  tilingPatterns.add<LinalgTilingPattern<ConvOp>>(
-      context, LinalgTilingOptions().setTileSizes(tileSizes),
+  tilingPatterns.add<LinalgTilingPattern>(
+      ConvOp::getOperationName(), context,
+      LinalgTilingOptions().setTileSizes(tileSizes),
       LinalgTransformationFilter(ArrayRef<StringAttr>{},
                                  StringAttr::get(kTiledMarker, context)));
 

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index d4119f26c8197..0c8ab052a88c1 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -138,32 +138,36 @@ static void applyPatterns(FuncOp funcOp) {
   //===--------------------------------------------------------------------===//
   // Linalg tiling patterns.
   //===--------------------------------------------------------------------===//
-  patterns.add<LinalgTilingPattern<MatmulOp>>(
-      ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
+  patterns.add<LinalgTilingPattern>(
+      MatmulOp::getOperationName(), ctx,
+      LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
       LinalgTransformationFilter(StringAttr::get(ctx, "MEM"),
                                  StringAttr::get(ctx, "L3")));
-  patterns.add<LinalgTilingPattern<MatmulOp>>(
-      ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
+  patterns.add<LinalgTilingPattern>(
+      MatmulOp::getOperationName(), ctx,
+      LinalgTilingOptions().setTileSizes({200, 300, 400}),
       LinalgTransformationFilter(StringAttr::get(ctx, "L3"),
                                  StringAttr::get(ctx, "L2")));
-  patterns.add<LinalgTilingPattern<MatmulOp>>(
-      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
+  patterns.add<LinalgTilingPattern>(
+      MatmulOp::getOperationName(), ctx,
+      LinalgTilingOptions().setTileSizes({20, 30, 40}),
       LinalgTransformationFilter(StringAttr::get(ctx, "L2"),
                                  StringAttr::get(ctx, "L1")));
-  patterns.add<LinalgTilingPattern<MatmulOp>>(
-      ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
+  patterns.add<LinalgTilingPattern>(
+      MatmulOp::getOperationName(), ctx,
+      LinalgTilingOptions().setTileSizes({2, 3, 4}),
       LinalgTransformationFilter(StringAttr::get(ctx, "L1"),
                                  StringAttr::get(ctx, "REG")));
 
-  patterns.add<LinalgTilingPattern<MatvecOp>>(
-      ctx,
+  patterns.add<LinalgTilingPattern>(
+      MatvecOp::getOperationName(), ctx,
       LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
           LinalgTilingLoopType::ParallelLoops),
       LinalgTransformationFilter(ArrayRef<StringAttr>{},
                                  StringAttr::get(ctx, "L1")));
 
-  patterns.add<LinalgTilingPattern<DotOp>>(
-      ctx, LinalgTilingOptions().setTileSizes(8000),
+  patterns.add<LinalgTilingPattern>(
+      DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000),
       LinalgTransformationFilter(
           ArrayRef<StringAttr>{StringAttr::get(ctx, "MEM"),
                                StringAttr::get(ctx, "L3"),
@@ -173,32 +177,34 @@ static void applyPatterns(FuncOp funcOp) {
   //===--------------------------------------------------------------------===//
   // Linalg tiling and permutation patterns.
   //===--------------------------------------------------------------------===//
-  patterns.add<LinalgTilingPattern<MatmulOp>>(
-      ctx,
+  patterns.add<LinalgTilingPattern>(
+      MatmulOp::getOperationName(), ctx,
       LinalgTilingOptions()
           .setTileSizes({2000, 3000, 4000})
           .setInterchange({1, 2, 0}),
       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
                                  StringAttr::get(ctx, "L2__with_perm__")));
-  patterns.add<LinalgTilingPattern<MatmulOp>>(
-      ctx,
+  patterns.add<LinalgTilingPattern>(
+      MatmulOp::getOperationName(), ctx,
       LinalgTilingOptions()
           .setTileSizes({200, 300, 400})
           .setInterchange({1, 0, 2}),
       LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"),
                                  StringAttr::get(ctx, "L1__with_perm__")));
-  patterns.add<LinalgTilingPattern<MatmulOp>>(
-      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
+  patterns.add<LinalgTilingPattern>(
+      MatmulOp::getOperationName(), ctx,
+      LinalgTilingOptions().setTileSizes({20, 30, 40}),
       LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"),
                                  StringAttr::get(ctx, "REG__with_perm__")));
 
-  patterns.add<LinalgTilingPattern<MatvecOp>>(
-      ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
+  patterns.add<LinalgTilingPattern>(
+      MatvecOp::getOperationName(), ctx,
+      LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
       LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"),
                                  StringAttr::get(ctx, "L1__with_perm__")));
 
-  patterns.add<LinalgTilingPattern<MatmulOp>>(
-      ctx,
+  patterns.add<LinalgTilingPattern>(
+      MatmulOp::getOperationName(), ctx,
       LinalgTilingOptions()
           .setTileSizes({16, 8, 4})
           .setInterchange({1, 2, 0})
@@ -274,8 +280,8 @@ static void fillL1TilingAndMatmulToVectorPatterns(
     SmallVectorImpl<RewritePatternSet> &patternsVector) {
   MLIRContext *ctx = funcOp.getContext();
   patternsVector.emplace_back(
-      ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
-               ctx,
+      ctx, std::make_unique<LinalgTilingPattern>(
+               MatmulOp::getOperationName(), ctx,
                LinalgTilingOptions()
                    .setTileSizes({8, 12, 16})
                    .setInterchange({1, 0, 2}),
@@ -339,8 +345,9 @@ static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
 
 static void fillPromotionCallBackPatterns(MLIRContext *ctx,
                                           RewritePatternSet &patterns) {
-  patterns.add<LinalgTilingPattern<MatmulOp>>(
-      ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
+  patterns.add<LinalgTilingPattern>(
+      MatmulOp::getOperationName(), ctx,
+      LinalgTilingOptions().setTileSizes({16, 16, 16}),
       LinalgTransformationFilter(StringAttr::get(ctx, "START"),
                                  StringAttr::get(ctx, "PROMOTE")));
   patterns.add<LinalgPromotionPattern<MatmulOp>>(
@@ -382,8 +389,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
         2, DistributionMethod::CyclicNumProcsEqNumIters);
     cyclicNprocsEqNiters.procInfo =
         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
-    patterns.add<LinalgTilingPattern<MatmulOp>>(
-        context,
+    patterns.add<LinalgTilingPattern>(
+        MatmulOp::getOperationName(), context,
         LinalgTilingOptions()
             .setTileSizes({8, 8, 4})
             .setLoopType(LinalgTilingLoopType::ParallelLoops)
@@ -399,8 +406,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
         2, DistributionMethod::CyclicNumProcsGeNumIters);
     cyclicNprocsGeNiters.procInfo =
         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
-    patterns.add<LinalgTilingPattern<MatmulOp>>(
-        context,
+    patterns.add<LinalgTilingPattern>(
+        MatmulOp::getOperationName(), context,
         LinalgTilingOptions()
             .setTileSizes({8, 8, 4})
             .setLoopType(LinalgTilingLoopType::ParallelLoops)
@@ -416,8 +423,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
                                                   DistributionMethod::Cyclic);
     cyclicNprocsDefault.procInfo =
         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
-    patterns.add<LinalgTilingPattern<MatmulOp>>(
-        context,
+    patterns.add<LinalgTilingPattern>(
+        MatmulOp::getOperationName(), context,
         LinalgTilingOptions()
             .setTileSizes({8, 8, 4})
             .setLoopType(LinalgTilingLoopType::ParallelLoops)
@@ -433,8 +440,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
         DistributionMethod::CyclicNumProcsEqNumIters,
         DistributionMethod::CyclicNumProcsGeNumIters};
     cyclicNprocsMixed1.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
-    patterns.add<LinalgTilingPattern<MatmulOp>>(
-        context,
+    patterns.add<LinalgTilingPattern>(
+        MatmulOp::getOperationName(), context,
         LinalgTilingOptions()
             .setTileSizes({8, 8, 4})
             .setLoopType(LinalgTilingLoopType::ParallelLoops)
@@ -450,8 +457,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
         DistributionMethod::CyclicNumProcsGeNumIters,
         DistributionMethod::Cyclic};
     cyclicNprocsMixed2.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
-    patterns.add<LinalgTilingPattern<MatmulOp>>(
-        context,
+    patterns.add<LinalgTilingPattern>(
+        MatmulOp::getOperationName(), context,
         LinalgTilingOptions()
             .setTileSizes({8, 8, 4})
             .setLoopType(LinalgTilingLoopType::ParallelLoops)
@@ -468,8 +475,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
         DistributionMethod::CyclicNumProcsEqNumIters};
     cyclicNprocsMixed3.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
 
-    patterns.add<LinalgTilingPattern<MatmulOp>>(
-        context,
+    patterns.add<LinalgTilingPattern>(
+        MatmulOp::getOperationName(), context,
         LinalgTilingOptions()
             .setTileSizes({8, 8, 4})
             .setLoopType(LinalgTilingLoopType::ParallelLoops)
@@ -485,8 +492,8 @@ static void fillTileAndDistributePatterns(MLIRContext *context,
                                                    DistributionMethod::Cyclic);
     cyclicNprocsEqNiters.procInfo =
         getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
-    patterns.add<LinalgTilingPattern<MatmulOp>>(
-        context,
+    patterns.add<LinalgTilingPattern>(
+        MatmulOp::getOperationName(), context,
         LinalgTilingOptions()
             .setTileSizes({8, 8, 4})
             .setLoopType(LinalgTilingLoopType::Loops)
@@ -507,8 +514,8 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
     fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
   } else if (testMatmulToVectorPatterns2dTiling) {
     stage1Patterns.emplace_back(
-        ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
-                 ctx,
+        ctx, std::make_unique<LinalgTilingPattern>(
+                 MatmulOp::getOperationName(), ctx,
                  LinalgTilingOptions()
                      .setTileSizes({768, 264, 768})
                      .setInterchange({1, 2, 0}),
@@ -589,10 +596,9 @@ static void applyTilePattern(FuncOp funcOp, const std::string &loopType,
   } else {
     linalgTilingOptions.setTileSizes(tileSizes);
   }
-  tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>,
-                    linalg::LinalgTilingPattern<linalg::GenericOp>>(
-      context, linalgTilingOptions,
-      linalg::LinalgTransformationFilter(StringAttr::get(context, "tile")));
+  linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile"));
+  TilingPatterns<linalg::MatmulOp, linalg::GenericOp>::insert(
+      tilingPattern, linalgTilingOptions, f);
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern));
 }
 


        


More information about the Mlir-commits mailing list