[Mlir-commits] [mlir] [MLIR] feat(mlir-tblgen): Add support for dialect interfaces (PR #170046)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Nov 30 11:01:31 PST 2025
https://github.com/aidint updated https://github.com/llvm/llvm-project/pull/170046
>From 504bb5c0dd7ce2907938d5bb1d0649e56537d04a Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Sun, 30 Nov 2025 19:40:51 +0100
Subject: [PATCH 1/2] feat(mlir-tblgen): Add support for dialect interfaces
---
mlir/include/mlir/IR/Interfaces.td | 5 +
mlir/include/mlir/TableGen/Interfaces.h | 7 +
mlir/lib/TableGen/Interfaces.cpp | 8 +
mlir/test/mlir-tblgen/dialect-interface.td | 66 +++++++
mlir/tools/mlir-tblgen/CMakeLists.txt | 1 +
.../mlir-tblgen/DialectInterfacesGen.cpp | 176 ++++++++++++++++++
6 files changed, 263 insertions(+)
create mode 100644 mlir/test/mlir-tblgen/dialect-interface.td
create mode 100644 mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
diff --git a/mlir/include/mlir/IR/Interfaces.td b/mlir/include/mlir/IR/Interfaces.td
index 0cbe3fa25c9e7..746ad3408f424 100644
--- a/mlir/include/mlir/IR/Interfaces.td
+++ b/mlir/include/mlir/IR/Interfaces.td
@@ -147,6 +147,11 @@ class TypeInterface<string name, list<Interface> baseInterfaces = []>
!if(!empty(cppNamespace),"", cppNamespace # "::") # name
>;
+// DialectInterface represents an interface registered to a dialect.
+class DialectInterface<string name, list<Interface> baseInterfaces = []>
+ : Interface<name, baseInterfaces>, OpInterfaceTrait<name>;
+
+
// Whether to declare the interface methods in the user entity's header. This
// class simply wraps an Interface but is used to indicate that the method
// declarations should be generated. This class takes an optional set of methods
diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h
index 7c36cbc1192ac..f62d21da467a1 100644
--- a/mlir/include/mlir/TableGen/Interfaces.h
+++ b/mlir/include/mlir/TableGen/Interfaces.h
@@ -157,6 +157,13 @@ struct TypeInterface : public Interface {
static bool classof(const Interface *interface);
};
+// An interface that is registered to a Dialect.
+struct DialectInterface : public Interface {
+ using Interface::Interface;
+
+ static bool classof(const Interface *interface);
+};
+
} // namespace tblgen
} // namespace mlir
diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index b0ad3ee59a089..77a6cecebbeaf 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -208,3 +208,11 @@ bool OpInterface::classof(const Interface *interface) {
bool TypeInterface::classof(const Interface *interface) {
return interface->getDef().isSubClassOf("TypeInterface");
}
+
+//===----------------------------------------------------------------------===//
+// DialectInterface
+//===----------------------------------------------------------------------===//
+
+bool DialectInterface::classof(const Interface *interface) {
+ return interface->getDef().isSubClassOf("DialectInterface");
+}
diff --git a/mlir/test/mlir-tblgen/dialect-interface.td b/mlir/test/mlir-tblgen/dialect-interface.td
new file mode 100644
index 0000000000000..9b424bf501be3
--- /dev/null
+++ b/mlir/test/mlir-tblgen/dialect-interface.td
@@ -0,0 +1,66 @@
+// RUN: mlir-tblgen -gen-dialect-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+
+include "mlir/IR/Interfaces.td"
+
+def NoDefaultMethod : DialectInterface<"NoDefaultMethod"> {
+ let description = [{
+ This is an example dialect interface without default method body.
+ }];
+
+ let cppNamespace = "::mlir::example";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Check if it's an example dialect",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "isExampleDialect",
+ /*args=*/ (ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/ "second method to check if multiple methods supported",
+ /*returnType=*/ "unsigned",
+ /*methodName=*/ "supportSecondMethod",
+ /*args=*/ (ins "::mlir::Type":$type)
+ >
+
+ ];
+}
+
+// DECL: class NoDefaultMethod : public {{.*}}DialectInterface::Base<NoDefaultMethod>
+// DECL: virtual bool isExampleDialect() const = 0;
+// DECL: virtual unsigned supportSecondMethod(::mlir::Type type) const = 0;
+// DECL: protected:
+// DECL-NEXT: NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {}
+
+def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface"> {
+ let description = [{
+ This is an example dialect interface with default method bodies.
+ }];
+
+ let cppNamespace = "::mlir::example";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Check if it's an example dialect",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "isExampleDialect",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{
+ return true;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "second method to check if multiple methods supported",
+ /*returnType=*/ "unsigned",
+ /*methodName=*/ "supportSecondMethod",
+ /*args=*/ (ins "::mlir::Type":$type)
+ >
+
+ ];
+}
+
+// DECL: virtual bool isExampleDialect() const;
+// DECL: bool ::mlir::example::WithDefaultMethodInterface::isExampleDialect() const {
+// DECL-NEXT: return true;
+// DECL-NEXT: }
+
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index 2a7ef7e0576c8..d7087cba3c874 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -12,6 +12,7 @@ add_tablegen(mlir-tblgen MLIR
AttrOrTypeFormatGen.cpp
BytecodeDialectGen.cpp
DialectGen.cpp
+ DialectInterfacesGen.cpp
DirectiveCommonGen.cpp
EnumsGen.cpp
EnumPythonBindingGen.cpp
diff --git a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
new file mode 100644
index 0000000000000..2fc500343501c
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
@@ -0,0 +1,176 @@
+//===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// DialectInterfacesGen generates C++ declarations for dialect interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CppGenUtilities.h"
+#include "DocGenUtilities.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Interfaces.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/CodeGenHelpers.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace mlir;
+using llvm::Record;
+using llvm::RecordKeeper;
+using mlir::tblgen::Interface;
+using mlir::tblgen::InterfaceMethod;
+
+/// Emit a C++ type, then a space unless it is empty or ends in '&' or '*'.
+static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
+ type = type.trim();
+ os << type;
+ if (!type.empty() && type.back() != '&' && type.back() != '*')
+ os << " ";
+ return os;
+}
+
+/// Emit the method name and argument list for the given method.
+static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
+ raw_ostream &os) {
+ os << name << '(';
+ llvm::interleaveComma(method.getArguments(), os,
+ [&](const InterfaceMethod::Argument &arg) {
+ os << arg.type << " " << arg.name;
+ });
+ os << ") const";
+}
+
+/// Get all DialectInterface definitions that are defined in the main file.
+static std::vector<const Record *>
+getAllInterfaceDefinitions(const RecordKeeper &records) {
+ std::vector<const Record *> defs =
+ records.getAllDerivedDefinitions("DialectInterface");
+
+ llvm::erase_if(defs, [&](const Record *def) {
+ // Ignore interfaces defined outside of the top-level file.
+ return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
+ llvm::SrcMgr.getMainFileID();
+ });
+ return defs;
+}
+
+namespace {
+/// This class is the generator used when processing tablegen dialect
+/// interfaces.
+class DialectInterfaceGenerator {
+public:
+ DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
+ : defs(getAllInterfaceDefinitions(records)), os(os) {}
+
+ bool emitInterfaceDecls();
+
+protected:
+ void emitInterfaceDecl(const Interface &interface);
+ void emitInterfaceMethodsDef(const Interface &interface);
+
+ /// The set of interface records to emit.
+ std::vector<const Record *> defs;
+ /// The stream to emit to.
+ raw_ostream &os;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// GEN: Interface declarations
+//===----------------------------------------------------------------------===//
+
+static void emitInterfaceMethodDoc(const InterfaceMethod &method,
+ raw_ostream &os, StringRef prefix = "") {
+ if (std::optional<StringRef> description = method.getDescription())
+ tblgen::emitDescriptionComment(*description, os, prefix);
+}
+
+static void emitInterfaceDeclMethods(const Interface &interface,
+ raw_ostream &os) {
+ for (auto &method : interface.getMethods()) {
+ emitInterfaceMethodDoc(method, os, " ");
+ os << " virtual ";
+ emitCPPType(method.getReturnType(), os);
+ emitMethodNameAndArgs(method, method.getName(), os);
+ if (!method.getBody())
+ // no default method body
+ os << " = 0";
+ os << ";\n";
+ }
+}
+
+void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
+ llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
+
+ StringRef interfaceName = interface.getName();
+
+ tblgen::emitSummaryAndDescComments(os, "",
+ interface.getDescription().value_or(""));
+
+ // Emit the main interface class declaration.
+ os << llvm::formatv(
+ "class {0} : public ::mlir::DialectInterface::Base<{0}> {{\n"
+ "public:\n",
+ interfaceName);
+
+ emitInterfaceDeclMethods(interface, os);
+ os << llvm::formatv("\nprotected:\n"
+ " {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n",
+ interfaceName);
+
+ os << "};\n";
+}
+
+void DialectInterfaceGenerator::emitInterfaceMethodsDef(
+ const Interface &interface) {
+
+ for (auto &method : interface.getMethods()) {
+ if (auto body = method.getBody()) {
+ emitCPPType(method.getReturnType(), os);
+ os << interface.getCppNamespace() << "::";
+ os << interface.getName() << "::";
+ emitMethodNameAndArgs(method, method.getName(), os);
+ os << " {\n " << body.value() << "\n}\n";
+ }
+ }
+}
+
+bool DialectInterfaceGenerator::emitInterfaceDecls() {
+
+ llvm::emitSourceFileHeader("Dialect Interface Declarations", os);
+
+ // Sort according to ID, so defs are emitted in the order in which they appear
+ // in the Tablegen file.
+ std::vector<const Record *> sortedDefs(defs);
+ llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) {
+ return lhs->getID() < rhs->getID();
+ });
+
+ for (const Record *def : sortedDefs)
+ emitInterfaceDecl(Interface(def));
+
+ os << "\n";
+ for (const Record *def : sortedDefs)
+ emitInterfaceMethodsDef(Interface(def));
+
+ return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: Interface registration hooks
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration genDecls(
+ "gen-dialect-interface-decls",
+ "Generate dialect interface declarations.",
+ [](const RecordKeeper &records, raw_ostream &os) {
+ return DialectInterfaceGenerator(records, os).emitInterfaceDecls();
+ });
>From 4ca9d25f184b808021289ea921a3494a01c1ff93 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Sun, 30 Nov 2025 20:00:59 +0100
Subject: [PATCH 2/2] resolve clang-format problem
---
mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
index 2fc500343501c..0ce4c3ef603b3 100644
--- a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
@@ -169,8 +169,7 @@ bool DialectInterfaceGenerator::emitInterfaceDecls() {
//===----------------------------------------------------------------------===//
static mlir::GenRegistration genDecls(
- "gen-dialect-interface-decls",
- "Generate dialect interface declarations.",
+ "gen-dialect-interface-decls", "Generate dialect interface declarations.",
[](const RecordKeeper &records, raw_ostream &os) {
return DialectInterfaceGenerator(records, os).emitInterfaceDecls();
});
More information about the Mlir-commits
mailing list