[Mlir-commits] [mlir] 9cd7e88 - [mlir][Linalg] NFC - Modernize more transformation patterns.

Nicolas Vasilache llvmlistbot at llvm.org
Thu Jan 6 14:41:26 PST 2022


Author: Nicolas Vasilache
Date: 2022-01-06T17:40:23-05:00
New Revision: 9cd7e880fd22ff3f8058e81dcd59c2f03074053d

URL: https://github.com/llvm/llvm-project/commit/9cd7e880fd22ff3f8058e81dcd59c2f03074053d
DIFF: https://github.com/llvm/llvm-project/commit/9cd7e880fd22ff3f8058e81dcd59c2f03074053d.diff

LOG: [mlir][Linalg] NFC - Modernize more transformation patterns.

Differential Revision: https://reviews.llvm.org/D116763

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
    mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4b55caed849d..72726f7b006b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -442,11 +442,19 @@ struct LinalgTransformationFilter {
       filters.push_back(f);
     return *this;
   }
+
   template <typename... OpTypes>
   LinalgTransformationFilter &addOpFilter() {
     return addFilter(
         [](Operation *op) { return success(isa<OpTypes...>(op)); });
   }
+
+  /// Only match ops named `opName`. The name is copied, not referenced.
+  LinalgTransformationFilter &addOpNameFilter(StringRef opName) {
+    return addFilter([name = opName.str()](Operation *op) {
+      return success(op->getName().getStringRef() == name);
+    });
+  }
   LinalgTransformationFilter &setMatchByDefault() {
     matchByDefault = true;
     return *this;
@@ -607,7 +615,7 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
 /// See `tiling` for more details.
 // TODO: TiledOpInterface
 struct LinalgTilingPattern : public OpInterfaceRewritePattern<LinalgOp> {
-  /// Construct a generic pattern applied to all LinalgOp that verify `f`.
+  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
   LinalgTilingPattern(
       MLIRContext *context, LinalgTilingOptions options,
       LinalgTransformationFilter f = LinalgTransformationFilter(),
@@ -643,20 +651,29 @@ struct LinalgTilingPattern : public OpInterfaceRewritePattern<LinalgOp> {
 /// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `padding` for more details.
 struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
-  // Entry point to match any LinalgOp OpInterface.
+  /// Construct a generic pattern applied to all LinalgOps that satisfy `f`.
   LinalgPaddingPattern(
       MLIRContext *context,
       LinalgPaddingOptions options = LinalgPaddingOptions(),
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
-  // Entry point to match a specific LinalgOp.
+
+  /// Construct a pattern specifically applied to `opName`.
   LinalgPaddingPattern(
       StringRef opName, MLIRContext *context,
       LinalgPaddingOptions options = LinalgPaddingOptions(),
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
-  LogicalResult matchAndRewrite(LinalgOp,
-                                PatternRewriter &rewriter) const override;
+
+  /// `matchAndRewrite` implementation that returns the significant transformed
+  /// pieces of IR.
+  FailureOr<LinalgOp> returningMatchAndRewrite(LinalgOp op,
+                                               PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(LinalgOp op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
@@ -679,7 +696,7 @@ struct LinalgBaseTileAndFusePattern : public RewritePattern {
       StringRef opName, MLIRContext *context,
       const LinalgDependenceGraph &dependenceGraph,
       LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
       LinalgTransformationFilter originalOpMarker =
           LinalgTransformationFilter(),
@@ -711,14 +728,14 @@ struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
   LinalgTileAndFusePattern(
       MLIRContext *context, const LinalgDependenceGraph &dependenceGraph,
       LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
       LinalgTransformationFilter originalOpMarker =
           LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
       : LinalgBaseTileAndFusePattern(
             OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
-            fusionOptions, filter, fusedOpMarker, originalOpMarker, benefit) {}
+            fusionOptions, f, fusedOpMarker, originalOpMarker, benefit) {}
 };
 
 ///
@@ -731,13 +748,13 @@ struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
   // Entry point to match any LinalgOp.
   LinalgTileAndFuseTensorOpsPattern(
       MLIRContext *context, LinalgTilingAndFusionOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
   // Entry point to match a specific LinalgOp.
   LinalgTileAndFuseTensorOpsPattern(
       StringRef opName, MLIRContext *context,
       LinalgTilingAndFusionOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override;
@@ -757,12 +774,22 @@ struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
 /// See `interchange` for more details.
 struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  /// GenericOp-specific constructor with an optional filter `f`.
   GenericOpInterchangePattern(
       MLIRContext *context, ArrayRef<unsigned> interchangeVector,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override;
+
+  /// `matchAndRewrite` implementation that returns the significant transformed
+  /// pieces of IR.
+  FailureOr<GenericOp>
+  returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
@@ -777,19 +804,29 @@ struct GenericOpInterchangePattern : public OpRewritePattern<GenericOp> {
 /// Apply the `generalization` transformation as a pattern.
 /// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `generalization` for more details.
-struct LinalgGeneralizationPattern : public RewritePattern {
-  // Entry point to match any LinalgOp OpInterface.
+struct LinalgGeneralizationPattern
+    : public OpInterfaceRewritePattern<LinalgOp> {
+  /// Construct a generic pattern applied to all LinalgOps that satisfy `f`.
   LinalgGeneralizationPattern(
       MLIRContext *context,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
-  // Entry point to match a specific Linalg op.
+
+  /// Construct a pattern specifically applied to `opName`.
   LinalgGeneralizationPattern(
       StringRef opName, MLIRContext *context,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override;
+
+  /// `matchAndRewrite` implementation that returns the significant transformed
+  /// pieces of IR.
+  FailureOr<GenericOp>
+  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(LinalgOp op,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
@@ -806,13 +843,13 @@ struct LinalgBasePromotionPattern : public RewritePattern {
   /// Entry point to match any LinalgOp OpInterface.
   /// MatchAnyOpTag-based constructor with a mandatory `filter`.
   LinalgBasePromotionPattern(
-      MLIRContext *context, LinalgTransformationFilter filter,
+      MLIRContext *context, LinalgTransformationFilter f,
       LinalgPromotionOptions options = LinalgPromotionOptions(),
       PatternBenefit benefit = 1);
   /// Entry point to match a specific Linalg op.
   LinalgBasePromotionPattern(
       StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
 
   LogicalResult matchAndRewrite(Operation *op,
@@ -832,16 +869,16 @@ struct LinalgPromotionPattern : public LinalgBasePromotionPattern {
   template <typename ConcreateOpTy = OpTy>
   LinalgPromotionPattern(
       MLIRContext *context, LinalgPromotionOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
       : LinalgBasePromotionPattern(OpTy::getOperationName(), context, options,
-                                   filter, benefit) {}
+                                   f, benefit) {}
   /// This constructor is available to anyone.
   LinalgPromotionPattern(
       StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
-      : LinalgBasePromotionPattern(opName, context, options, filter, benefit) {}
+      : LinalgBasePromotionPattern(opName, context, options, f, benefit) {}
 };
 
 ///
@@ -852,39 +889,28 @@ struct LinalgVectorizationOptions {};
 
 /// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `vectorizeLinalgOp` for more details.
-struct LinalgBaseVectorizationPattern : public RewritePattern {
-  /// MatchAnyOpTag-based constructor with a mandatory `filter`.
-  LinalgBaseVectorizationPattern(MLIRContext *context,
-                                 LinalgTransformationFilter filter,
-                                 PatternBenefit benefit = 1);
-  /// Name-based constructor with an optional `filter`.
-  LinalgBaseVectorizationPattern(
+struct LinalgVectorizationPattern : public OpInterfaceRewritePattern<LinalgOp> {
+  /// Construct a generic pattern applied to all LinalgOps that satisfy `f`.
+  LinalgVectorizationPattern(
+      MLIRContext *context,
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
+      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
+      PatternBenefit benefit = 1);
+
+  /// Construct a pattern restricted to ops named `opName` that satisfy `f`.
+  LinalgVectorizationPattern(
       StringRef opName, MLIRContext *context,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
-  LogicalResult matchAndRewrite(Operation *op,
+
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override;

 private:
   /// LinalgTransformMarker handles special attribute manipulations.
   LinalgTransformationFilter filter;
-};
-
-struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern {
-  /// These constructors are available to anyone.
-  /// MatchAnyOpTag-based constructor with a mandatory `filter`.
-  LinalgVectorizationPattern(
-      MLIRContext *context, LinalgTransformationFilter filter,
-      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
-      PatternBenefit benefit = 1)
-      : LinalgBaseVectorizationPattern(context, filter, benefit) {}
-  /// Name-based constructor with an optional `filter`.
-  LinalgVectorizationPattern(
-      StringRef opName, MLIRContext *context,
-      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : LinalgBaseVectorizationPattern(opName, context, filter, benefit) {}
+  LinalgVectorizationOptions options; // Not yet read by matchAndRewrite.
 };
 
 //===----------------------------------------------------------------------===//
@@ -1008,48 +1034,6 @@ struct LinalgVectorLoweringOptions {
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
-/// Trait to check if T provides a `getOperationName` method.
-template <typename T, typename... Args>
-using has_get_operation_name = decltype(T::getOperationName());
-template <typename T>
-using detect_has_get_operation_name =
-    llvm::is_detected<has_get_operation_name, T>;
-
-/// SFINAE helper for single C++ op with a `getOperationName` method.
-template <
-    typename OpType,
-    typename = std::enable_if_t<detect_has_get_operation_name<OpType>::value>,
-    typename = void>
-void insertVectorizationPatternImpl(RewritePatternSet &patternList,
-                                    linalg::LinalgVectorizationOptions options,
-                                    linalg::LinalgTransformationFilter f) {
-  patternList.add<linalg::LinalgVectorizationPattern>(
-      OpType::getOperationName(), patternList.getContext(), options, f);
-}
-
-/// SFINAE helper for single C++ class without a `getOperationName` method (e.g.
-/// an OpInterface).
-template <typename OpType, typename = std::enable_if_t<
-                               !detect_has_get_operation_name<OpType>::value>>
-void insertVectorizationPatternImpl(RewritePatternSet &patternList,
-                                    linalg::LinalgVectorizationOptions options,
-                                    linalg::LinalgTransformationFilter f) {
-  patternList.add<linalg::LinalgVectorizationPattern>(
-      patternList.getContext(), f.addOpFilter<OpType>(), options);
-}
-
-/// Variadic helper function to insert vectorization patterns for C++ ops.
-template <typename... OpTypes>
-void insertVectorizationPatterns(RewritePatternSet &patternList,
-                                 linalg::LinalgVectorizationOptions options,
-                                 linalg::LinalgTransformationFilter f =
-                                     linalg::LinalgTransformationFilter()) {
-  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
-  (void)std::initializer_list<int>{
-      0,
-      (insertVectorizationPatternImpl<OpTypes>(patternList, options, f), 0)...};
-}
-
 ///
 /// Linalg lowering patterns.
 ///
@@ -1067,10 +1051,10 @@ template <typename OpTy>
 struct LinalgLoweringPattern : public RewritePattern {
   LinalgLoweringPattern(
       MLIRContext *context, LinalgLoweringType loweringType,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
       : RewritePattern(OpTy::getOperationName(), benefit, context),
-        filter(filter), loweringType(loweringType) {}
+        filter(std::move(f)), loweringType(loweringType) {}
 
   // TODO: Move implementation to .cpp once named ops are auto-generated.
   LogicalResult matchAndRewrite(Operation *op,
@@ -1352,6 +1336,29 @@ struct ExtractSliceOfPadTensorSwapPattern
 //===----------------------------------------------------------------------===//
 // Helper classes for type list expansion.
 //===----------------------------------------------------------------------===//
+/// Registers one LinalgVectorizationPattern per op type in `OpTypes...`.
+template <typename... OpTypes>
+class VectorizationPatterns;
+
+template <>
+class VectorizationPatterns<> {
+public:
+  static void insert(RewritePatternSet &, const LinalgVectorizationOptions &,
+                     const LinalgTransformationFilter &) {}
+};
+
+template <typename OpTy, typename... OpTypes>
+class VectorizationPatterns<OpTy, OpTypes...> {
+public:
+  static void insert(RewritePatternSet &patterns,
+                     const LinalgVectorizationOptions &options,
+                     const LinalgTransformationFilter &f) {
+    patterns.add<LinalgVectorizationPattern>(OpTy::getOperationName(),
+                                             patterns.getContext(), options, f);
+    VectorizationPatterns<OpTypes...>::insert(patterns, options, f);
+  }
+};
+
 template <typename... OpTypes>
 class TilingPatterns;
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index c1482f44b4cd..22f95653701c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -357,11 +357,11 @@ mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
     StringRef opName, MLIRContext *context,
     const LinalgDependenceGraph &dependenceGraph,
     LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
-    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
+    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
     LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
     : RewritePattern(opName, benefit, context, {}),
       dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
-      fusionOptions(std::move(fusionOptions)), filter(std::move(filter)),
+      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
       fusedOpMarker(std::move(fusedOpMarker)),
       originalOpMarker(std::move(originalOpMarker)) {}
 
@@ -462,11 +462,7 @@ mlir::linalg::LinalgTilingPattern::LinalgTilingPattern(
     StringRef opName, MLIRContext *context, LinalgTilingOptions options,
     LinalgTransformationFilter f, PatternBenefit benefit)
     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
-      filter(std::move(f)), options(std::move(options)) {
-  this->filter.addFilter([opName](Operation *op) {
-    return success(op->getName().getStringRef() == opName);
-  });
-}
+      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
 
 FailureOr<TiledLinalgOp>
 mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
@@ -496,21 +492,18 @@ mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite(
 /// Linalg padding pattern.
 mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
     MLIRContext *context, LinalgPaddingOptions options,
-    LinalgTransformationFilter filter, PatternBenefit benefit)
+    LinalgTransformationFilter f, PatternBenefit benefit)
     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
-      filter(std::move(filter)), options(std::move(options)) {}
+      filter(std::move(f)), options(std::move(options)) {}
 
 mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
     StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
-    LinalgTransformationFilter filter, PatternBenefit benefit)
+    LinalgTransformationFilter f, PatternBenefit benefit)
     : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
-      filter(std::move(filter)), options(std::move(options)) {
-  this->filter.addFilter([opName](Operation *op) {
-    return success(op->getName().getStringRef() == opName);
-  });
-}
+      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
 
-LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
+FailureOr<LinalgOp>
+mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
     LinalgOp linalgOp, PatternRewriter &rewriter) const {
   if (!linalgOp.hasTensorSemantics())
     return failure();
@@ -549,24 +542,24 @@ LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
   // Replace the original operation to pad.
   rewriter.replaceOp(linalgOp, newResults.getValue());
   filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
-  return success();
+  return paddedOp;
 }
 
 /// Linalg tile and fuse tensor ops pattern.
 mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
     LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
                                       LinalgTilingAndFusionOptions options,
-                                      LinalgTransformationFilter filter,
+                                      LinalgTransformationFilter f,
                                       PatternBenefit benefit)
     : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
-      filter(std::move(filter)), options(std::move(options)) {}
+      filter(std::move(f)), options(std::move(options)) {}
 
 mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
     LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
                                       LinalgTilingAndFusionOptions options,
-                                      LinalgTransformationFilter filter,
+                                      LinalgTransformationFilter f,
                                       PatternBenefit benefit)
-    : RewritePattern(opName, benefit, context), filter(std::move(filter)),
+    : RewritePattern(opName, benefit, context), filter(std::move(f)),
       options(std::move(options)) {}
 
 LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
@@ -624,11 +617,12 @@ LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
 /// Linalg generic interchange pattern.
 mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
     MLIRContext *context, ArrayRef<unsigned> interchangeVector,
-    LinalgTransformationFilter filter, PatternBenefit benefit)
-    : OpRewritePattern(context, benefit), filter(std::move(filter)),
+    LinalgTransformationFilter f, PatternBenefit benefit)
+    : OpRewritePattern(context, benefit), filter(std::move(f)),
       interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
 
-LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
+FailureOr<GenericOp>
+mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite(
     GenericOp genericOp, PatternRewriter &rewriter) const {
   if (failed(filter.checkAndNotify(rewriter, genericOp)))
     return failure();
@@ -645,41 +639,38 @@ LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
 
 /// Linalg generalization pattern.
 mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
-    MLIRContext *context, LinalgTransformationFilter filter,
-    PatternBenefit benefit)
-    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
-      filter(std::move(filter)) {}
+    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
+    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+      filter(std::move(f)) {}
 
 mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
-    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
+    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
     PatternBenefit benefit)
-    : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)) {}
+    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+      filter(f.addOpNameFilter(opName)) {}
 
-LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
-  // TODO: Interface pattern.
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
-    return failure();
-  if (failed(filter.checkAndNotify(rewriter, op)))
+FailureOr<GenericOp>
+mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
+    LinalgOp linalgOp, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
   FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
   if (failed(genericOp))
     return failure();
   filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
-  return success();
+  return genericOp;
 }
 
 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
-    MLIRContext *context, LinalgTransformationFilter filter,
+    MLIRContext *context, LinalgTransformationFilter f,
     LinalgPromotionOptions options, PatternBenefit benefit)
     : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
-      filter(std::move(filter)), options(std::move(options)) {}
+      filter(std::move(f)), options(std::move(options)) {}
 
 mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
     StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
-    LinalgTransformationFilter filter, PatternBenefit benefit)
-    : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)),
+    LinalgTransformationFilter f, PatternBenefit benefit)
+    : RewritePattern(opName, benefit, context, {}), filter(std::move(f)),
       options(std::move(options)) {}
 
 LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
@@ -704,24 +695,21 @@ LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
   return success();
 }
 
-mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
-    MLIRContext *context, LinalgTransformationFilter filter,
-    PatternBenefit benefit)
-    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
-      filter(std::move(filter)) {}
+mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
+    MLIRContext *context, LinalgTransformationFilter f,
+    LinalgVectorizationOptions options, PatternBenefit benefit)
+    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+      filter(std::move(f)), options(std::move(options)) {}

-mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
-    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
-    PatternBenefit benefit)
-    : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)) {}
+mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
+    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
+    LinalgTransformationFilter f, PatternBenefit benefit)
+    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+      filter(f.addOpNameFilter(opName)), options(std::move(options)) {}
 
-LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
-  // TODO: Interface-based rewrite.
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
-    return failure();
-  if (failed(filter.checkAndNotify(rewriter, op)))
+LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
+    LinalgOp linalgOp, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
   return vectorize(rewriter, linalgOp);
 }
@@ -947,10 +935,10 @@ struct DownscaleSizeOneWindowed2DConvolution final
     : public OpRewritePattern<Conv2DNhwcHwcfOp> {
   DownscaleSizeOneWindowed2DConvolution(
       MLIRContext *context,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
       : OpRewritePattern<Conv2DNhwcHwcfOp>(context, benefit),
-        filter(std::move(filter)) {}
+        filter(std::move(f)) {}
 
   LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
                                 PatternRewriter &rewriter) const override {
@@ -1033,10 +1021,10 @@ struct DownscaleDepthwiseConv2DNhwcHwcOp final
     : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
   DownscaleDepthwiseConv2DNhwcHwcOp(
       MLIRContext *context,
-      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1)
       : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
-        filter(std::move(filter)) {}
+        filter(std::move(f)) {}
 
   LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
                                 PatternRewriter &rewriter) const override {

diff  --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
index 0c8ab052a88c..aad40c672c38 100644
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -300,8 +300,7 @@ static void fillL1TilingAndMatmulToVectorPatterns(
                MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
                LinalgTransformationFilter(StringAttr::get(ctx, "VEC"))));
   patternsVector.back().add<LinalgVectorizationPattern>(
-      ctx, LinalgTransformationFilter().addFilter(
-               [](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
+      ctx, LinalgTransformationFilter().addOpFilter<FillOp, CopyOp>());
 }
 
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list