[Mlir-commits] [mlir] [mlir] Add base class type aliases for rewrites/conversions. NFC. (PR #158433)
    Jakub Kuderski 
    llvmlistbot at llvm.org
       
    Sat Sep 13 12:52:29 PDT 2025
    
    
  
https://github.com/kuhar created https://github.com/llvm/llvm-project/pull/158433
This is to simplify writing rewrite/conversion patterns that usually start with:
```c++
struct MyPattern : public OpRewritePattern<MyOp> {
  using OpRewritePattern::OpRewritePattern;
```
and allow for:
```c++
struct MyPattern : public OpRewritePattern<MyOp> {
  using Base::Base;
```
similar to how we enable it for pass classes.
From 918730938658cda069037e5084e30de14a7db96a Mon Sep 17 00:00:00 2001
From: Jakub Kuderski <jakub at nod-labs.com>
Date: Sat, 13 Sep 2025 15:43:59 -0400
Subject: [PATCH] [mlir] Add base class type aliases for rewrites/conversions.
 NFC.
This is to simplify writing rewrite/conversion patterns that usually
start with:
```c++
struct MyPattern : public OpRewritePattern<MyOp> {
  using OpRewritePattern::OpRewritePattern;
```
and allow for:
```c++
struct MyPattern : public OpRewritePattern<MyOp> {
  using Base::Base;
```
similar to how this is already enabled for pass classes.
---
 mlir/include/mlir/IR/PatternMatch.h              | 10 ++++++++++
 mlir/include/mlir/Transforms/DialectConversion.h | 16 ++++++++++++++++
 mlir/test/lib/Dialect/Test/TestPatterns.cpp      | 11 ++++++++---
 3 files changed, 34 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 7b0b9cef9c5bd..576481a6e7215 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -312,6 +312,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
 template <typename SourceOp>
 struct OpRewritePattern
     : public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  /// Type alias to allow derived classes to inherit constructors with
+  /// `using Base::Base;`.
+  using Base = OpRewritePattern;
 
   /// Patterns must specify the root operation name they match against, and can
   /// also specify the benefit of the pattern matching and a list of generated
@@ -328,6 +331,9 @@ struct OpRewritePattern
 template <typename SourceOp>
 struct OpInterfaceRewritePattern
     : public mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  /// Type alias to allow derived classes to inherit constructors with
+  /// `using Base::Base;`.
+  using Base = OpInterfaceRewritePattern;
 
   OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
       : mlir::detail::OpOrInterfaceRewritePatternBase<SourceOp>(
@@ -341,6 +347,10 @@ struct OpInterfaceRewritePattern
 template <template <typename> class TraitType>
 class OpTraitRewritePattern : public RewritePattern {
 public:
+  /// Type alias to allow derived classes to inherit constructors with
+  /// `using Base::Base;`.
+  using Base = OpTraitRewritePattern;
+
   OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
       : RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
                        benefit, context) {}
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index bfbe12d2a5668..6ef649e8fc13a 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -40,6 +40,10 @@ class Value;
 /// registered using addConversion and addMaterialization, respectively.
 class TypeConverter {
 public:
+  /// Type alias to allow derived classes to inherit constructors with
+  /// `using Base::Base;`.
+  using Base = TypeConverter;
+
   virtual ~TypeConverter() = default;
   TypeConverter() = default;
   // Copy the registered conversions, but not the caches
@@ -679,6 +683,10 @@ class ConversionPattern : public RewritePattern {
 template <typename SourceOp>
 class OpConversionPattern : public ConversionPattern {
 public:
+  /// Type alias to allow derived classes to inherit constructors with
+  /// `using Base::Base;`.
+  using Base = OpConversionPattern;
+
   using OpAdaptor = typename SourceOp::Adaptor;
   using OneToNOpAdaptor =
       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
@@ -729,6 +737,10 @@ class OpConversionPattern : public ConversionPattern {
 template <typename SourceOp>
 class OpInterfaceConversionPattern : public ConversionPattern {
 public:
+  /// Type alias to allow derived classes to inherit constructors with
+  /// `using Base::Base;`.
+  using Base = OpInterfaceConversionPattern;
+
   OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
       : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
                           SourceOp::getInterfaceID(), benefit, context) {}
@@ -773,6 +785,10 @@ class OpInterfaceConversionPattern : public ConversionPattern {
 template <template <typename> class TraitType>
 class OpTraitConversionPattern : public ConversionPattern {
 public:
+  /// Type alias to allow derived classes to inherit constructors with
+  /// `using Base::Base;`.
+  using Base = OpTraitConversionPattern;
+
   OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
       : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
                           TypeID::get<TraitType>(), benefit, context) {}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 93b007c792ad9..f8b5144e3acb2 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -114,7 +115,8 @@ struct FoldingPattern : public RewritePattern {
 struct FolderInsertBeforePreviouslyFoldedConstantPattern
     : public OpRewritePattern<TestCastOp> {
 public:
-  using OpRewritePattern<TestCastOp>::OpRewritePattern;
+  static_assert(std::is_same_v<Base, OpRewritePattern<TestCastOp>>);
+  using Base::Base;
 
   LogicalResult matchAndRewrite(TestCastOp op,
                                 PatternRewriter &rewriter) const override {
@@ -1306,7 +1308,8 @@ class TestReplaceWithValidConsumer : public ConversionPattern {
 /// b) or: drops all block arguments and replaces each with 2x the first
 ///    operand.
 class TestConvertBlockArgs : public OpConversionPattern<ConvertBlockArgsOp> {
-  using OpConversionPattern<ConvertBlockArgsOp>::OpConversionPattern;
+  static_assert(std::is_same_v<Base, OpConversionPattern<ConvertBlockArgsOp>>);
+  using Base::Base;
 
   LogicalResult
   matchAndRewrite(ConvertBlockArgsOp op, OpAdaptor adaptor,
@@ -1431,7 +1434,9 @@ class TestTypeConsumerOpPattern
 
 namespace {
 struct TestTypeConverter : public TypeConverter {
-  using TypeConverter::TypeConverter;
+  static_assert(std::is_same_v<Base, TypeConverter>);
+  using Base::Base;
+
   TestTypeConverter() {
     addConversion(convertType);
     addSourceMaterialization(materializeCast);
    
    
More information about the Mlir-commits
mailing list