[Mlir-commits] [mlir] f46c2c4 - [MLIR] Convert DialectReductionPatternInterface using ODS (#180640)
llvmlistbot at llvm.org
Mon Feb 16 10:04:43 PST 2026
Author: AidinT
Date: 2026-02-16T20:04:38+02:00
New Revision: f46c2c46d19996911271c40d4b4bc4cc6ad0e591
URL: https://github.com/llvm/llvm-project/commit/f46c2c46d19996911271c40d4b4bc4cc6ad0e591
DIFF: https://github.com/llvm/llvm-project/commit/f46c2c46d19996911271c40d4b4bc4cc6ad0e591.diff
LOG: [MLIR] Convert DialectReductionPatternInterface using ODS (#180640)
This PR converts `DialectReductionPatternInterface` to an ODS definition.
It also introduces a new interface method class,
`PureVirtualInterfaceMethod`, which emits the method as pure virtual.
Added:
mlir/include/mlir/Reducer/DialectReductionPatternInterface.td
Modified:
mlir/include/mlir/IR/Interfaces.td
mlir/include/mlir/Reducer/CMakeLists.txt
mlir/include/mlir/Reducer/ReductionPatternInterface.h
mlir/include/mlir/TableGen/Interfaces.h
mlir/lib/Reducer/CMakeLists.txt
mlir/lib/TableGen/Interfaces.cpp
mlir/test/mlir-tblgen/dialect-interface.td
mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Interfaces.td b/mlir/include/mlir/IR/Interfaces.td
index 149a254fa7a0d..e16de2942a043 100644
--- a/mlir/include/mlir/IR/Interfaces.td
+++ b/mlir/include/mlir/IR/Interfaces.td
@@ -85,6 +85,11 @@ class StaticInterfaceMethod<string desc, string retTy, string methodName,
: InterfaceMethod<desc, retTy, methodName, args, methodBody,
defaultImplementation>;
+// This class represents a pure virtual interface method.
+class PureVirtualInterfaceMethod<string desc, string retTy, string methodName,
+ dag args = (ins)>
+ : InterfaceMethod<desc, retTy, methodName, args>;
+
// This class represents a interface method declaration.
class InterfaceMethodDeclaration<string desc, string retTy, string methodName,
dag args = (ins)>
diff --git a/mlir/include/mlir/Reducer/CMakeLists.txt b/mlir/include/mlir/Reducer/CMakeLists.txt
index 3d09f87c6f17e..37a19a481adc4 100644
--- a/mlir/include/mlir/Reducer/CMakeLists.txt
+++ b/mlir/include/mlir/Reducer/CMakeLists.txt
@@ -2,4 +2,8 @@ set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Reducer)
add_mlir_generic_tablegen_target(MLIRReducerIncGen)
+set(LLVM_TARGET_DEFINITIONS DialectReductionPatternInterface.td)
+mlir_tablegen(DialectReductionPatternInterface.h.inc -gen-dialect-interface-decls)
+add_mlir_generic_tablegen_target(MLIRDialectReductionPatternInterfaceIncGen)
+
add_mlir_doc(Passes ReducerPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Reducer/DialectReductionPatternInterface.td b/mlir/include/mlir/Reducer/DialectReductionPatternInterface.td
new file mode 100644
index 0000000000000..8f33897d7a75d
--- /dev/null
+++ b/mlir/include/mlir/Reducer/DialectReductionPatternInterface.td
@@ -0,0 +1,59 @@
+#ifndef MLIR_INTERFACES_DIALECTREDUCTIONPATTERNINTERFACE
+#define MLIR_INTERFACES_DIALECTREDUCTIONPATTERNINTERFACE
+
+include "mlir/IR/Interfaces.td"
+
+def DialectReductionPatternInterface : DialectInterface<"DialectReductionPatternInterface"> {
+ let description = [{
+ This is used to report the reduction patterns for a Dialect. While using
+ mlir-reduce to reduce a module, we may want to transform certain cases into
+ simpler forms by applying certain rewrite patterns. Implement the
+ `populateReductionPatterns` to report those patterns by adding them to the
+ RewritePatternSet.
+
+ Example:
+ MyDialectReductionPattern::populateReductionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<TensorOpReduction>(patterns.getContext());
+ }
+
+ For DRR, mlir-tblgen will generate a helper function
+ `populateWithGenerated` which has the same signature therefore you can
+ delegate to the helper function as well.
+
+ Example:
+ MyDialectReductionPattern::populateReductionPatterns(
+ RewritePatternSet &patterns) {
+ // Include the autogen file somewhere above.
+ populateWithGenerated(patterns);
+ }
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ PureVirtualInterfaceMethod<[{
+ Patterns provided here are intended to transform operations from a complex
+ form to a simpler form, without breaking the semantics of the program
+ being reduced. For example, you may want to replace the
+ tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
+ replacing an operation with a constant.
+ }],
+ "void", "populateReductionPatterns",
+ (ins "::mlir::RewritePatternSet &":$patterns)
+ >,
+ InterfaceMethod<[{
+ This method extends `populateReductionPatterns` by allowing reduction
+ patterns to use a `Tester` instance. Some reduction patterns may need to
+ run tester to determine whether certain transformations preserve the
+ "interesting" behavior of the program. This is mostly useful when pattern
+ should choose between multiple modifications.
+ }],
+ "void", "populateReductionPatternsWithTester",
+ (ins "::mlir::RewritePatternSet &":$patterns, "::mlir::Tester &":$tester),
+ [{}]
+ >
+ ];
+}
+
+
+#endif
diff --git a/mlir/include/mlir/Reducer/ReductionPatternInterface.h b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
index a33877dc0bd77..7d7a1eeea27ac 100644
--- a/mlir/include/mlir/Reducer/ReductionPatternInterface.h
+++ b/mlir/include/mlir/Reducer/ReductionPatternInterface.h
@@ -13,52 +13,9 @@
#include "mlir/Reducer/Tester.h"
namespace mlir {
-
class RewritePatternSet;
-
-/// This is used to report the reduction patterns for a Dialect. While using
-/// mlir-reduce to reduce a module, we may want to transform certain cases into
-/// simpler forms by applying certain rewrite patterns. Implement the
-/// `populateReductionPatterns` to report those patterns by adding them to the
-/// RewritePatternSet.
-///
-/// Example:
-/// MyDialectReductionPattern::populateReductionPatterns(
-/// RewritePatternSet &patterns) {
-/// patterns.add<TensorOpReduction>(patterns.getContext());
-/// }
-///
-/// For DRR, mlir-tblgen will generate a helper function
-/// `populateWithGenerated` which has the same signature therefore you can
-/// delegate to the helper function as well.
-///
-/// Example:
-/// MyDialectReductionPattern::populateReductionPatterns(
-/// RewritePatternSet &patterns) {
-/// // Include the autogen file somewhere above.
-/// populateWithGenerated(patterns);
-/// }
-class DialectReductionPatternInterface
- : public DialectInterface::Base<DialectReductionPatternInterface> {
-public:
- /// Patterns provided here are intended to transform operations from a complex
- /// form to a simpler form, without breaking the semantics of the program
- /// being reduced. For example, you may want to replace the
- /// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
- /// replacing an operation with a constant.
- virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
-
- /// This method extends `populateReductionPatterns` by allowing reduction
- /// patterns to use a `Tester` instance. Some reduction patterns may need to
- /// run tester to determine whether certain transformations preserve the
- /// "interesting" behavior of the program. This is mostly useful when pattern
- /// should choose between multiple modifications.
- virtual void populateReductionPatternsWithTester(RewritePatternSet &patterns,
- Tester &tester) const {}
-
-protected:
- DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
-};
} // namespace mlir
+#include "mlir/Reducer/DialectReductionPatternInterface.h.inc"
+
#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h
index 09b5d6cbf6aa6..5b2bb34f8c153 100644
--- a/mlir/include/mlir/TableGen/Interfaces.h
+++ b/mlir/include/mlir/TableGen/Interfaces.h
@@ -47,6 +47,9 @@ class InterfaceMethod {
// Return if this method is static.
bool isStatic() const;
+ // Return if the method is a pure virtual one.
+ bool isPureVirtual() const;
+
// Return if the method is only a declaration.
bool isDeclaration() const;
diff --git a/mlir/lib/Reducer/CMakeLists.txt b/mlir/lib/Reducer/CMakeLists.txt
index a723263e4f41a..68864e373c993 100644
--- a/mlir/lib/Reducer/CMakeLists.txt
+++ b/mlir/lib/Reducer/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_library(MLIRReduce
DEPENDS
MLIRReducerIncGen
+ MLIRDialectReductionPatternInterfaceIncGen
)
mlir_check_all_link_libraries(MLIRReduce)
diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index f92ef18710941..f4fa65777f585 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -52,6 +52,11 @@ bool InterfaceMethod::isStatic() const {
return def->isSubClassOf("StaticInterfaceMethod");
}
+// Return if the method is a pure virtual one.
+bool InterfaceMethod::isPureVirtual() const {
+ return def->isSubClassOf("PureVirtualInterfaceMethod");
+}
+
// Return if the method is only a declaration.
bool InterfaceMethod::isDeclaration() const {
return def->isSubClassOf("InterfaceMethodDeclaration");
diff --git a/mlir/test/mlir-tblgen/dialect-interface.td b/mlir/test/mlir-tblgen/dialect-interface.td
index d9035a63b2d2e..5e41fdedfa761 100644
--- a/mlir/test/mlir-tblgen/dialect-interface.td
+++ b/mlir/test/mlir-tblgen/dialect-interface.td
@@ -22,9 +22,9 @@ def NoDefaultMethod : DialectInterface<"NoDefaultMethod"> {
// DECL: class NoDefaultMethod : public {{.*}}DialectInterface::Base<NoDefaultMethod>
// DECL: public:
-// DECL: NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {}
// DECL: virtual bool isExampleDialect() const {}
// DECL: virtual unsigned supportSecondMethod(::mlir::Type type) const {}
+// DECL: NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {}
def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface"> {
let description = [{
@@ -73,3 +73,24 @@ def OnlyDeclarationInterfaceWithExtraDecls : DialectInterface<"OnlyDeclarationIn
// DECL: class OnlyDeclarationInterfaceWithExtraDecls : public {{.*}}DialectInterface::Base<OnlyDeclarationInterfaceWithExtraDecls>
// DECL: virtual void exampleMethodDeclaration(::mlir::Type type) const;
// DECL: using DeclType = int;
+
+def PureVirtualInterface : DialectInterface<"PureVirtualInterface"> {
+ let description = [{
+ This is an example dialect interface with pure virtual methods.
+ }];
+
+ let cppNamespace = "::mlir::example";
+
+ let methods = [
+ PureVirtualInterfaceMethod<
+ "Check if it's an example dialect", "bool", "isExampleDialect",
+ (ins)
+ >
+ ];
+}
+
+// DECL: class PureVirtualInterface : public {{.*}}DialectInterface::Base<PureVirtualInterface>
+// DECL: public:
+// DECL: virtual bool isExampleDialect() const = 0;
+// DECL: protected:
+// DECL-NEXT: PureVirtualInterface(::mlir::Dialect *dialect) : Base(dialect) {}
diff --git a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
index 6ad426c78226d..e695b8c761895 100644
--- a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
@@ -113,6 +113,11 @@ static void emitInterfaceMethodsDef(const DialectInterface &interface,
continue;
}
+ if (method.isPureVirtual()) {
+ ios << " = 0;\n";
+ continue;
+ }
+
// if it is not a method declaration, then it's a normal interface method.
ios << " {";
@@ -126,6 +131,27 @@ static void emitInterfaceMethodsDef(const DialectInterface &interface,
}
}
+static void emitConstructor(const DialectInterface &interface,
+ raw_ostream &os) {
+
+ raw_indented_ostream ios(os);
+
+ // We consider a constructor protected if interface has at least one pure
+ // virtual method
+ auto hasProtectedConstructor =
+ llvm::any_of(interface.getMethods(), [](const InterfaceMethod &method) {
+ return method.isPureVirtual();
+ });
+
+ ios.indent(0);
+ if (hasProtectedConstructor)
+ ios << "protected:\n";
+
+ ios.indent(2);
+ ios << llvm::formatv("{0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n",
+ interface.getName());
+}
+
void DialectInterfaceGenerator::emitInterfaceDecl(
const DialectInterface &interface) {
llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
@@ -135,9 +161,8 @@ void DialectInterfaceGenerator::emitInterfaceDecl(
// Emit the main interface class declaration.
os << llvm::formatv(
- "class {0} : public ::mlir::DialectInterface::Base<{0}> {\n"
- "public:\n"
- " {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n",
+ "class {0} : public ::mlir::DialectInterface::Base<{0}> {{\n"
+ "public:\n",
interface.getName());
emitInterfaceMethodsDef(interface, os);
@@ -151,6 +176,10 @@ void DialectInterfaceGenerator::emitInterfaceDecl(
ios << "\n";
}
+ os << "\n";
+
+ emitConstructor(interface, os);
+
os << "};\n";
}
More information about the Mlir-commits mailing list