[Mlir-commits] [mlir] f46c2c4 - [MLIR] Convert DialectReductionPatternInterface using ODS (#180640)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 16 10:04:43 PST 2026


Author: AidinT
Date: 2026-02-16T20:04:38+02:00
New Revision: f46c2c46d19996911271c40d4b4bc4cc6ad0e591

URL: https://github.com/llvm/llvm-project/commit/f46c2c46d19996911271c40d4b4bc4cc6ad0e591
DIFF: https://github.com/llvm/llvm-project/commit/f46c2c46d19996911271c40d4b4bc4cc6ad0e591.diff

LOG: [MLIR] Convert DialectReductionPatternInterface using ODS (#180640)

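This PR converts `DialectReductionPatternInterface` to an ODS (TableGen)
definition.

It also introduces a new interface method class, `PureVirtualInterfaceMethod`,
which generates the method as pure virtual (`= 0`). A dialect interface with at
least one such method gets a protected constructor.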

Added: 
    mlir/include/mlir/Reducer/DialectReductionPatternInterface.td

Modified: 
    mlir/include/mlir/IR/Interfaces.td
    mlir/include/mlir/Reducer/CMakeLists.txt
    mlir/include/mlir/Reducer/ReductionPatternInterface.h
    mlir/include/mlir/TableGen/Interfaces.h
    mlir/lib/Reducer/CMakeLists.txt
    mlir/lib/TableGen/Interfaces.cpp
    mlir/test/mlir-tblgen/dialect-interface.td
    mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Interfaces.td b/mlir/include/mlir/IR/Interfaces.td
index 149a254fa7a0d..e16de2942a043 100644
--- a/mlir/include/mlir/IR/Interfaces.td
+++ b/mlir/include/mlir/IR/Interfaces.td
@@ -85,6 +85,12 @@ class StaticInterfaceMethod<string desc, string retTy, string methodName,
     : InterfaceMethod<desc, retTy, methodName, args, methodBody,
                       defaultImplementation>;
 
+// This class represents a pure virtual interface method. It accepts neither a
+// method body nor a default implementation; implementers must override it.
+class PureVirtualInterfaceMethod<string desc, string retTy, string methodName,
+                                 dag args = (ins)>
+    : InterfaceMethod<desc, retTy, methodName, args>;
+
 // This class represents a interface method declaration.
 class InterfaceMethodDeclaration<string desc, string retTy, string methodName,
                              dag args = (ins)>

diff  --git a/mlir/include/mlir/Reducer/CMakeLists.txt b/mlir/include/mlir/Reducer/CMakeLists.txt
index 3d09f87c6f17e..37a19a481adc4 100644
--- a/mlir/include/mlir/Reducer/CMakeLists.txt
+++ b/mlir/include/mlir/Reducer/CMakeLists.txt
@@ -2,4 +2,8 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
 mlir_tablegen(Passes.h.inc -gen-pass-decls -name Reducer)
 add_mlir_generic_tablegen_target(MLIRReducerIncGen)
 
+set(LLVM_TARGET_DEFINITIONS DialectReductionPatternInterface.td)
+mlir_tablegen(DialectReductionPatternInterface.h.inc -gen-dialect-interface-decls)
+add_mlir_generic_tablegen_target(MLIRDialectReductionPatternInterfaceIncGen)
+
 add_mlir_doc(Passes ReducerPasses ./ -gen-pass-doc)

diff  --git a/mlir/include/mlir/Reducer/DialectReductionPatternInterface.td b/mlir/include/mlir/Reducer/DialectReductionPatternInterface.td
new file mode 100644
index 0000000000000..8f33897d7a75d
--- /dev/null
+++ b/mlir/include/mlir/Reducer/DialectReductionPatternInterface.td
@@ -0,0 +1,71 @@
+//===- DialectReductionPatternInterface.td -----------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines the dialect interface through which dialects report rewrite patterns
+// to mlir-reduce.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REDUCER_DIALECTREDUCTIONPATTERNINTERFACE_TD
+#define MLIR_REDUCER_DIALECTREDUCTIONPATTERNINTERFACE_TD
+
+include "mlir/IR/Interfaces.td"
+
+def DialectReductionPatternInterface : DialectInterface<"DialectReductionPatternInterface"> {
+  let description = [{
+    This is used to report the reduction patterns for a Dialect. While using
+    mlir-reduce to reduce a module, we may want to transform certain cases into
+    simpler forms by applying certain rewrite patterns. Implement
+    `populateReductionPatterns` to report those patterns by adding them to the
+    RewritePatternSet.
+
+    Example:
+      void MyDialectReductionPattern::populateReductionPatterns(
+          RewritePatternSet &patterns) const {
+        patterns.add<TensorOpReduction>(patterns.getContext());
+      }
+
+    For DRR, mlir-tblgen will generate a helper function
+    `populateWithGenerated` which has the same signature, therefore you can
+    delegate to the helper function as well.
+
+    Example:
+      void MyDialectReductionPattern::populateReductionPatterns(
+          RewritePatternSet &patterns) const {
+        // Include the autogen file somewhere above.
+        populateWithGenerated(patterns);
+      }
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    PureVirtualInterfaceMethod<[{
+        Patterns provided here are intended to transform operations from a complex
+        form to a simpler form, without breaking the semantics of the program
+        being reduced. For example, you may want to replace a
+        tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
+        replace an operation with a constant.
+      }],
+      "void", "populateReductionPatterns",
+      (ins "::mlir::RewritePatternSet &":$patterns)
+    >,
+    InterfaceMethod<[{
+        This method extends `populateReductionPatterns` by allowing reduction
+        patterns to use a `Tester` instance. Some reduction patterns may need to
+        run the tester to determine whether certain transformations preserve the
+        "interesting" behavior of the program. This is mostly useful when a
+        pattern has to choose between multiple modifications.
+      }],
+      "void", "populateReductionPatternsWithTester",
+      (ins "::mlir::RewritePatternSet &":$patterns, "::mlir::Tester &":$tester),
+      [{}]
+    >
+  ];
+}
+
+#endif // MLIR_REDUCER_DIALECTREDUCTIONPATTERNINTERFACE_TD

diff  --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
index a33877dc0bd77..7d7a1eeea27ac 100644
--- a/mlir/include/mlir/Reducer/ReductionPatternInterface.h
+++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
@@ -13,52 +13,9 @@
 #include "mlir/Reducer/Tester.h"
 
 namespace mlir {
-
 class RewritePatternSet;
-
-/// This is used to report the reduction patterns for a Dialect. While using
-/// mlir-reduce to reduce a module, we may want to transform certain cases into
-/// simpler forms by applying certain rewrite patterns. Implement the
-/// `populateReductionPatterns` to report those patterns by adding them to the
-/// RewritePatternSet.
-///
-/// Example:
-///   MyDialectReductionPattern::populateReductionPatterns(
-///       RewritePatternSet &patterns) {
-///       patterns.add<TensorOpReduction>(patterns.getContext());
-///   }
-///
-/// For DRR, mlir-tblgen will generate a helper function
-/// `populateWithGenerated` which has the same signature therefore you can
-/// delegate to the helper function as well.
-///
-/// Example:
-///   MyDialectReductionPattern::populateReductionPatterns(
-///       RewritePatternSet &patterns) {
-///       // Include the autogen file somewhere above.
-///       populateWithGenerated(patterns);
-///   }
-class DialectReductionPatternInterface
-    : public DialectInterface::Base<DialectReductionPatternInterface> {
-public:
-  /// Patterns provided here are intended to transform operations from a complex
-  /// form to a simpler form, without breaking the semantics of the program
-  /// being reduced. For example, you may want to replace the
-  /// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
-  /// replacing an operation with a constant.
-  virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
-
-  /// This method extends `populateReductionPatterns` by allowing reduction
-  /// patterns to use a `Tester` instance. Some reduction patterns may need to
-  /// run tester to determine whether certain transformations preserve the
-  /// "interesting" behavior of the program. This is mostly useful when pattern
-  /// should choose between multiple modifications.
-  virtual void populateReductionPatternsWithTester(RewritePatternSet &patterns,
-                                                   Tester &tester) const {}
-
-protected:
-  DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
-};
 } // namespace mlir
 
+#include "mlir/Reducer/DialectReductionPatternInterface.h.inc"
+
 #endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H

diff  --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h
index 09b5d6cbf6aa6..5b2bb34f8c153 100644
--- a/mlir/include/mlir/TableGen/Interfaces.h
+++ b/mlir/include/mlir/TableGen/Interfaces.h
@@ -47,6 +47,9 @@ class InterfaceMethod {
   // Return if this method is static.
   bool isStatic() const;
 
+  // Return if this method is a PureVirtualInterfaceMethod (no body).
+  bool isPureVirtual() const;
+
   // Return if the method is only a declaration.
   bool isDeclaration() const;
 

diff  --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt
index a723263e4f41a..68864e373c993 100644
--- a/mlir/lib/Reducer/CMakeLists.txt
+++ b/mlir/lib/Reducer/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRReduce
 
    DEPENDS
    MLIRReducerIncGen
+   MLIRDialectReductionPatternInterfaceIncGen
 )
 
 mlir_check_all_link_libraries(MLIRReduce)

diff  --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index f92ef18710941..f4fa65777f585 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -52,6 +52,11 @@ bool InterfaceMethod::isStatic() const {
   return def->isSubClassOf("StaticInterfaceMethod");
 }
 
+// Return if the method was defined via the PureVirtualInterfaceMethod class.
+bool InterfaceMethod::isPureVirtual() const {
+  return def->isSubClassOf("PureVirtualInterfaceMethod");
+}
+
 // Return if the method is only a declaration.
 bool InterfaceMethod::isDeclaration() const {
   return def->isSubClassOf("InterfaceMethodDeclaration");

diff  --git a/mlir/test/mlir-tblgen/dialect-interface.td b/mlir/test/mlir-tblgen/dialect-interface.td
index d9035a63b2d2e..5e41fdedfa761 100644
--- a/mlir/test/mlir-tblgen/dialect-interface.td
+++ b/mlir/test/mlir-tblgen/dialect-interface.td
@@ -22,9 +22,9 @@ def NoDefaultMethod : DialectInterface<"NoDefaultMethod"> {
 
 // DECL:   class NoDefaultMethod : public {{.*}}DialectInterface::Base<NoDefaultMethod>
 // DECL:   public:
-// DECL:   NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {}
 // DECL:   virtual bool isExampleDialect() const {}
 // DECL:   virtual unsigned supportSecondMethod(::mlir::Type type) const {}
+// DECL:   NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {}
 
 def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface"> {
   let description = [{
@@ -73,3 +73,25 @@ def OnlyDeclarationInterfaceWithExtraDecls : DialectInterface<"OnlyDeclarationIn
 // DECL:   class OnlyDeclarationInterfaceWithExtraDecls : public {{.*}}DialectInterface::Base<OnlyDeclarationInterfaceWithExtraDecls>
 // DECL:   virtual void exampleMethodDeclaration(::mlir::Type type) const;
 // DECL:   using DeclType = int;
+
+// Pure virtual methods are emitted as `= 0` with a protected constructor.
+def PureVirtualInterface : DialectInterface<"PureVirtualInterface"> {
+  let description = [{
+    This is an example dialect interface with pure virtual methods.
+  }];
+
+  let cppNamespace = "::mlir::example";
+
+  let methods = [
+      PureVirtualInterfaceMethod<
+        "Check if it's an example dialect", "bool", "isExampleDialect",
+        (ins)
+      >
+  ];
+}
+
+// DECL:   class PureVirtualInterface : public {{.*}}DialectInterface::Base<PureVirtualInterface>
+// DECL:   public:
+// DECL:   virtual bool isExampleDialect() const = 0;
+// DECL:   protected:
+// DECL-NEXT:   PureVirtualInterface(::mlir::Dialect *dialect) : Base(dialect) {}

diff  --git a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
index 6ad426c78226d..e695b8c761895 100644
--- a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
@@ -113,6 +113,11 @@ static void emitInterfaceMethodsDef(const DialectInterface &interface,
       continue;
     }
 
+    if (method.isPureVirtual()) {
+      ios << " = 0;\n";
+      continue;
+    }
+
     // if it is not a method declaration, then it's a normal interface method.
     ios << " {";
 
@@ -126,6 +131,27 @@ static void emitInterfaceMethodsDef(const DialectInterface &interface,
   }
 }
 
+static void emitConstructor(const DialectInterface &interface,
+                            raw_ostream &os) {
+  raw_indented_ostream ios(os);
+
+  // An interface with at least one pure virtual method is abstract, so its
+  // constructor is only meaningful for subclasses: emit it as protected.
+  bool hasProtectedConstructor =
+      llvm::any_of(interface.getMethods(), [](const InterfaceMethod &method) {
+        return method.isPureVirtual();
+      });
+
+  // Always spell out the access level: the constructor is emitted last, after
+  // any extra class declarations, which may have switched to `private:`.
+  ios.indent(0);
+  ios << (hasProtectedConstructor ? "protected:\n" : "public:\n");
+
+  ios.indent(2);
+  ios << llvm::formatv("{0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n",
+                       interface.getName());
+}
+
 void DialectInterfaceGenerator::emitInterfaceDecl(
     const DialectInterface &interface) {
   llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
@@ -135,9 +161,8 @@ void DialectInterfaceGenerator::emitInterfaceDecl(
 
   // Emit the main interface class declaration.
   os << llvm::formatv(
-      "class {0} : public ::mlir::DialectInterface::Base<{0}> {\n"
-      "public:\n"
-      "  {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n",
+      "class {0} : public ::mlir::DialectInterface::Base<{0}> {{\n"
+      "public:\n",
       interface.getName());
 
   emitInterfaceMethodsDef(interface, os);
@@ -151,6 +176,10 @@ void DialectInterfaceGenerator::emitInterfaceDecl(
     ios << "\n";
   }
 
+  os << "\n";
+
+  emitConstructor(interface, os);
+
   os << "};\n";
 }
 


        


More information about the Mlir-commits mailing list