[Mlir-commits] [mlir] [MLIR] feat(mlir-tblgen): Add support for dialect interfaces (PR #170046)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 30 11:01:31 PST 2025


https://github.com/aidint updated https://github.com/llvm/llvm-project/pull/170046

>From 504bb5c0dd7ce2907938d5bb1d0649e56537d04a Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Sun, 30 Nov 2025 19:40:51 +0100
Subject: [PATCH 1/2] feat(mlir-tblgen): Add support for dialect interfaces

---
 mlir/include/mlir/IR/Interfaces.td            |   5 +
 mlir/include/mlir/TableGen/Interfaces.h       |   7 +
 mlir/lib/TableGen/Interfaces.cpp              |   8 +
 mlir/test/mlir-tblgen/dialect-interface.td    |  66 +++++++
 mlir/tools/mlir-tblgen/CMakeLists.txt         |   1 +
 .../mlir-tblgen/DialectInterfacesGen.cpp      | 176 ++++++++++++++++++
 6 files changed, 263 insertions(+)
 create mode 100644 mlir/test/mlir-tblgen/dialect-interface.td
 create mode 100644 mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp

diff --git a/mlir/include/mlir/IR/Interfaces.td b/mlir/include/mlir/IR/Interfaces.td
index 0cbe3fa25c9e7..746ad3408f424 100644
--- a/mlir/include/mlir/IR/Interfaces.td
+++ b/mlir/include/mlir/IR/Interfaces.td
@@ -147,6 +147,11 @@ class TypeInterface<string name, list<Interface> baseInterfaces = []>
 			!if(!empty(cppNamespace),"", cppNamespace # "::") # name
     >;
 
+// DialectInterface represents an interface registered to a dialect. It
+// deliberately does not derive from OpInterfaceTrait: a dialect interface is
+// not a trait that can be attached to an operation.
+class DialectInterface<string name, list<Interface> baseInterfaces = []>
+  : Interface<name, baseInterfaces>;
+
 // Whether to declare the interface methods in the user entity's header. This
 // class simply wraps an Interface but is used to indicate that the method
 // declarations should be generated. This class takes an optional set of methods
diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h
index 7c36cbc1192ac..f62d21da467a1 100644
--- a/mlir/include/mlir/TableGen/Interfaces.h
+++ b/mlir/include/mlir/TableGen/Interfaces.h
@@ -157,6 +157,13 @@ struct TypeInterface : public Interface {
 
   static bool classof(const Interface *interface);
 };
+
+/// Wrapper around a TableGen record deriving from the `DialectInterface`
+/// class, i.e. an interface that is registered to a dialect.
+struct DialectInterface : Interface {
+  using Interface::Interface;
+
+  static bool classof(const Interface *interface);
+};
+
 } // namespace tblgen
 } // namespace mlir
 
diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index b0ad3ee59a089..77a6cecebbeaf 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -208,3 +208,11 @@ bool OpInterface::classof(const Interface *interface) {
 bool TypeInterface::classof(const Interface *interface) {
   return interface->getDef().isSubClassOf("TypeInterface");
 }
+
+//===----------------------------------------------------------------------===//
+// DialectInterface
+//===----------------------------------------------------------------------===//
+
+/// An interface is a DialectInterface iff its record derives from the
+/// TableGen `DialectInterface` class.
+bool DialectInterface::classof(const Interface *interface) {
+  const auto &record = interface->getDef();
+  return record.isSubClassOf("DialectInterface");
+}
diff --git a/mlir/test/mlir-tblgen/dialect-interface.td b/mlir/test/mlir-tblgen/dialect-interface.td
new file mode 100644
index 0000000000000..9b424bf501be3
--- /dev/null
+++ b/mlir/test/mlir-tblgen/dialect-interface.td
@@ -0,0 +1,66 @@
+// RUN: mlir-tblgen -gen-dialect-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+
+include "mlir/IR/Interfaces.td"
+
+def NoDefaultMethodInterface : DialectInterface<"NoDefaultMethodInterface"> {
+  let description = [{
+    This is an example dialect interface without default method body.
+  }];
+
+  let cppNamespace = "::mlir::example";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Check if it's an example dialect",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "isExampleDialect",
+      /*args=*/        (ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/        "second method to check if multiple methods supported",
+      /*returnType=*/  "unsigned",
+      /*methodName=*/  "supportSecondMethod",
+      /*args=*/        (ins "::mlir::Type":$type)
+    >
+  ];
+}
+
+// Methods without a body are emitted as pure virtual members, and the class
+// gets a protected constructor taking the owning dialect.
+// DECL:   class NoDefaultMethodInterface : public {{.*}}DialectInterface::Base<NoDefaultMethodInterface>
+// DECL:   virtual bool isExampleDialect() const = 0;
+// DECL:   virtual unsigned supportSecondMethod(::mlir::Type type) const = 0;
+// DECL:   protected:
+// DECL-NEXT:   NoDefaultMethodInterface(::mlir::Dialect *dialect) : Base(dialect) {}
+
+def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface"> {
+  let description = [{
+    This is an example dialect interface with default method bodies.
+  }];
+
+  let cppNamespace = "::mlir::example";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Check if it's an example dialect",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "isExampleDialect",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{
+        return true;
+      }]
+    >,
+    InterfaceMethod<
+      /*desc=*/        "second method to check if multiple methods supported",
+      /*returnType=*/  "unsigned",
+      /*methodName=*/  "supportSecondMethod",
+      /*args=*/        (ins "::mlir::Type":$type)
+    >
+  ];
+}
+
+// A method with a body is declared virtual but not pure, a method without one
+// in the same interface stays pure virtual, and the body is emitted as an
+// out-of-line definition after the class declarations.
+// DECL:  virtual bool isExampleDialect() const;
+// DECL:  virtual unsigned supportSecondMethod(::mlir::Type type) const = 0;
+// DECL:  bool ::mlir::example::WithDefaultMethodInterface::isExampleDialect() const {
+// DECL-NEXT:  return true;
+// DECL-NEXT: }
+
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index 2a7ef7e0576c8..d7087cba3c874 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -12,6 +12,7 @@ add_tablegen(mlir-tblgen MLIR
   AttrOrTypeFormatGen.cpp
   BytecodeDialectGen.cpp
   DialectGen.cpp
+  DialectInterfacesGen.cpp
   DirectiveCommonGen.cpp
   EnumsGen.cpp
   EnumPythonBindingGen.cpp
diff --git a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
new file mode 100644
index 0000000000000..2fc500343501c
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
@@ -0,0 +1,176 @@
+//===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// DialectInterfacesGen generates declarations for dialect interfaces, along
+// with out-of-line definitions of interface methods that provide a body.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CppGenUtilities.h"
+#include "DocGenUtilities.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Interfaces.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/CodeGenHelpers.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace mlir;
+using llvm::Record;
+using llvm::RecordKeeper;
+using mlir::tblgen::Interface;
+using mlir::tblgen::InterfaceMethod;
+
+/// Emit a string corresponding to a C++ type, followed by a space if necessary.
+/// Leading/trailing whitespace of `type` is dropped. No space is added after a
+/// type ending in `&` or `*`, so the following name is glued to the sigil.
+/// Returns `os` for chaining.
+static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
+  type = type.trim();
+  os << type;
+  // StringRef::back() must not be called on an empty string, so an empty type
+  // is treated like any non-pointer/non-reference type.
+  if (type.empty() || (type.back() != '&' && type.back() != '*'))
+    os << " ";
+  return os;
+}
+
+/// Print `name(<type0> <arg0>, <type1> <arg1>, ...) const` for `method` to
+/// `os`; every emitted dialect interface method is const-qualified.
+static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
+                                  raw_ostream &os) {
+  auto printArgument = [&os](const InterfaceMethod::Argument &argument) {
+    os << argument.type << " " << argument.name;
+  };
+  os << name << "(";
+  llvm::interleave(method.getArguments(), os, printArgument, ", ");
+  os << ") const";
+}
+
+/// Collect every record deriving from the TableGen `DialectInterface` class
+/// whose definition lies in the top-level input file; records that come from
+/// other buffers (e.g. included files) are skipped. Input order is preserved.
+static std::vector<const Record *>
+getAllInterfaceDefinitions(const RecordKeeper &records) {
+  std::vector<const Record *> mainFileDefs;
+  for (const Record *def :
+       records.getAllDerivedDefinitions("DialectInterface")) {
+    bool inMainFile = llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) ==
+                      llvm::SrcMgr.getMainFileID();
+    if (inMainFile)
+      mainFileDefs.push_back(def);
+  }
+  return mainFileDefs;
+}
+
+namespace {
+/// Generator that turns the `DialectInterface` records of the main TableGen
+/// input file into C++ interface declarations.
+class DialectInterfaceGenerator {
+public:
+  DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
+      : defs{getAllInterfaceDefinitions(records)}, os{os} {}
+
+  /// Emit the complete output for all collected interfaces; returns false.
+  bool emitInterfaceDecls();
+
+protected:
+  /// Emit the class declaration for `interface` inside its C++ namespace.
+  void emitInterfaceDecl(const Interface &interface);
+  /// Emit definitions for the methods of `interface` that have a body.
+  void emitInterfaceMethodsDef(const Interface &interface);
+
+  /// Interface records to generate code for.
+  std::vector<const Record *> defs;
+  /// Stream receiving all generated code.
+  raw_ostream &os;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// GEN: Interface declarations
+//===----------------------------------------------------------------------===//
+
+/// Emit the description of `method` as a comment, if it has one. `prefix` is
+/// forwarded to tblgen::emitDescriptionComment (callers pass indentation).
+static void emitInterfaceMethodDoc(const InterfaceMethod &method,
+                                   raw_ostream &os, StringRef prefix = "") {
+  std::optional<StringRef> desc = method.getDescription();
+  if (!desc)
+    return;
+  tblgen::emitDescriptionComment(*desc, os, prefix);
+}
+
+/// Emit the in-class declaration of every method of `interface`, each preceded
+/// by its description comment. Methods without a body are declared pure
+/// virtual; methods with a body are declared virtual and defined out of line by
+/// emitInterfaceMethodsDef. Static methods are rejected with a fatal error.
+static void emitInterfaceDeclMethods(const Interface &interface,
+                                     raw_ostream &os) {
+  for (const InterfaceMethod &method : interface.getMethods()) {
+    // Every method is emitted as a `virtual ... const` member, which cannot
+    // represent a static method; fail loudly instead of generating a
+    // declaration that silently changes the method's meaning.
+    if (method.isStatic())
+      llvm::PrintFatalError(interface.getDef().getLoc(),
+                            "static method '" + method.getName() +
+                                "' is not supported in dialect interface '" +
+                                interface.getName() + "'");
+    emitInterfaceMethodDoc(method, os, "  ");
+    os << "  virtual ";
+    emitCPPType(method.getReturnType(), os);
+    emitMethodNameAndArgs(method, method.getName(), os);
+    // Without a body, the dialect must provide the implementation.
+    if (!method.getBody())
+      os << " = 0";
+    os << ";\n";
+  }
+}
+
+/// Emit the C++ class for `interface` inside its C++ namespace. The class
+/// derives from `::mlir::DialectInterface::Base<Name>`, declares the interface
+/// methods publicly, and has a protected constructor taking the dialect.
+void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
+  // The namespace stays open until nsEmitter goes out of scope at the end of
+  // this function.
+  llvm::NamespaceEmitter nsEmitter(os, interface.getCppNamespace());
+
+  const StringRef name = interface.getName();
+  const std::optional<StringRef> description = interface.getDescription();
+  tblgen::emitSummaryAndDescComments(os, "", description.value_or(""));
+
+  os << "class " << name << " : public ::mlir::DialectInterface::Base<" << name
+     << "> {\n"
+     << "public:\n";
+  emitInterfaceDeclMethods(interface, os);
+  os << "\nprotected:\n"
+     << "  " << name << "(::mlir::Dialect *dialect) : Base(dialect) {}\n"
+     << "};\n";
+}
+
+/// Emit a definition for each method of `interface` that provides a body. The
+/// definitions are written outside any namespace block, so names are fully
+/// qualified with the interface's C++ namespace.
+///
+/// They are marked `inline` because they go into the generated declarations
+/// output, which is presumably included from headers (and therefore by several
+/// translation units); non-inline definitions there would violate the
+/// one-definition rule and fail to link with duplicate symbols.
+void DialectInterfaceGenerator::emitInterfaceMethodsDef(
+    const Interface &interface) {
+  for (const InterfaceMethod &method : interface.getMethods()) {
+    std::optional<StringRef> body = method.getBody();
+    if (!body)
+      continue;
+    os << "inline ";
+    emitCPPType(method.getReturnType(), os);
+    os << interface.getCppNamespace() << "::" << interface.getName() << "::";
+    emitMethodNameAndArgs(method, method.getName(), os);
+    os << " {\n  " << *body << "\n}\n";
+  }
+}
+
+/// Emit the file header, then the class declarations of all interfaces, then
+/// the out-of-line method definitions. Always returns false.
+bool DialectInterfaceGenerator::emitInterfaceDecls() {
+  llvm::emitSourceFileHeader("Dialect Interface Declarations", os);
+
+  // Order by record ID so interfaces appear in the same order as in the
+  // TableGen file.
+  std::vector<const Record *> ordered(defs.begin(), defs.end());
+  llvm::sort(ordered, [](const Record *a, const Record *b) {
+    return a->getID() < b->getID();
+  });
+
+  // All class declarations first...
+  for (const Record *rec : ordered)
+    emitInterfaceDecl(Interface(rec));
+  os << "\n";
+  // ...then the definitions of methods that have a body.
+  for (const Record *rec : ordered)
+    emitInterfaceMethodsDef(Interface(rec));
+
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: Interface registration hooks
+//===----------------------------------------------------------------------===//
+
+/// Registers the `-gen-dialect-interface-decls` mlir-tblgen backend, which runs
+/// DialectInterfaceGenerator::emitInterfaceDecls over the parsed records.
+static mlir::GenRegistration genDecls(
+    "gen-dialect-interface-decls",
+    "Generate dialect interface declarations.",
+    [](const RecordKeeper &records, raw_ostream &os) {
+      return DialectInterfaceGenerator(records, os).emitInterfaceDecls();
+    });

>From 4ca9d25f184b808021289ea921a3494a01c1ff93 Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Sun, 30 Nov 2025 20:00:59 +0100
Subject: [PATCH 2/2] resolve clang-format problem

---
 mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
index 2fc500343501c..0ce4c3ef603b3 100644
--- a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
@@ -169,8 +169,7 @@ bool DialectInterfaceGenerator::emitInterfaceDecls() {
 //===----------------------------------------------------------------------===//
 
 static mlir::GenRegistration genDecls(
-    "gen-dialect-interface-decls",
-    "Generate dialect interface declarations.",
+    "gen-dialect-interface-decls", "Generate dialect interface declarations.",
     [](const RecordKeeper &records, raw_ostream &os) {
       return DialectInterfaceGenerator(records, os).emitInterfaceDecls();
     });



More information about the Mlir-commits mailing list