[Mlir-commits] [mlir] [MLIR] feat(mlir-tblgen): Add support for dialect interfaces (PR #170046)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 30 10:51:03 PST 2025


https://github.com/aidint created https://github.com/llvm/llvm-project/pull/170046

Currently, dialect interfaces can't be defined in ODS. This PR adds support for them. It follows the same approach as the other interfaces and builds on top of the `Interface` class defined in `mlir/TableGen/Interfaces.h`.

Given the following input:
```tablegen
#ifndef MY_INTERFACES
#define MY_INTERFACES

include "mlir/IR/Interfaces.td"

def DialectInlinerInterface : DialectInterface<"DialectInlinerInterface"> {
  let description = [{
     Define a base inlining interface class to allow for dialects to opt-in to the inliner.
  }];

  let cppNamespace = "::mlir";

  let methods = [
    InterfaceMethod<
      /*desc=*/        [{
        Returns true if the given region 'src' can be inlined into the region
        'dest' that is attached to an operation registered to the current dialect.
        'valueMapping' contains any remapped values from within the 'src' region.
        This can be used to examine what values will replace entry arguments into
        the 'src' region, for example.
      }],
      /*returnType=*/  "bool",
      /*methodName=*/  "isLegalToInline",
      /*args=*/        (ins "::mlir::Region *":$dest, "::mlir::Region *":$src, "::mlir::IRMapping &":$valueMapping),
      /*methodBody=*/  [{
        return true;
      }]
      >
  ];
}


#endif

```

It will generate the following code:
```cpp
/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\
|*                                                                            *|
|* Dialect Interface Declarations                                             *|
|*                                                                            *|
|* Automatically generated file, do not edit!                                 *|
|*                                                                            *|
\*===----------------------------------------------------------------------===*/

namespace mlir {

/// Define a base inlining interface class to allow for dialects to opt-in to the inliner.
class DialectInlinerInterface : public ::mlir::DialectInterface::Base<DialectInlinerInterface> {
public:

  /// Returns true if the given region 'src' can be inlined into the region
  /// 'dest' that is attached to an operation registered to the current dialect.
  /// 'valueMapping' contains any remapped values from within the 'src' region.
  /// This can be used to examine what values will replace entry arguments into
  /// the 'src' region, for example.
  virtual bool isLegalToInline(::mlir::Region * dest, ::mlir::Region * src, ::mlir::IRMapping & valueMapping) const;

protected:
  DialectInlinerInterface(::mlir::Dialect *dialect) : Base(dialect) {}
};

} // namespace mlir

bool ::mlir::DialectInlinerInterface::isLegalToInline(::mlir::Region * dest, ::mlir::Region * src, ::mlir::IRMapping & valueMapping) const {
  return true;
}
```

From 504bb5c0dd7ce2907938d5bb1d0649e56537d04a Mon Sep 17 00:00:00 2001
From: aidint <at.aidin at gmail.com>
Date: Sun, 30 Nov 2025 19:40:51 +0100
Subject: [PATCH] feat(mlir-tblgen): Add support for dialect interfaces

---
 mlir/include/mlir/IR/Interfaces.td            |   5 +
 mlir/include/mlir/TableGen/Interfaces.h       |   7 +
 mlir/lib/TableGen/Interfaces.cpp              |   8 +
 mlir/test/mlir-tblgen/dialect-interface.td    |  66 +++++++
 mlir/tools/mlir-tblgen/CMakeLists.txt         |   1 +
 .../mlir-tblgen/DialectInterfacesGen.cpp      | 176 ++++++++++++++++++
 6 files changed, 263 insertions(+)
 create mode 100644 mlir/test/mlir-tblgen/dialect-interface.td
 create mode 100644 mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp

diff --git a/mlir/include/mlir/IR/Interfaces.td b/mlir/include/mlir/IR/Interfaces.td
index 0cbe3fa25c9e7..746ad3408f424 100644
--- a/mlir/include/mlir/IR/Interfaces.td
+++ b/mlir/include/mlir/IR/Interfaces.td
@@ -147,6 +147,11 @@ class TypeInterface<string name, list<Interface> baseInterfaces = []>
 			!if(!empty(cppNamespace),"", cppNamespace # "::") # name
     >;
 
+// DialectInterface represents an interface registered to a dialect. Unlike
+// OpInterface it provides no trait, since it cannot be attached to operations.
+class DialectInterface<string name, list<Interface> baseInterfaces = []>
+  : Interface<name, baseInterfaces>;
+
 // Whether to declare the interface methods in the user entity's header. This
 // class simply wraps an Interface but is used to indicate that the method
 // declarations should be generated. This class takes an optional set of methods
diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h
index 7c36cbc1192ac..f62d21da467a1 100644
--- a/mlir/include/mlir/TableGen/Interfaces.h
+++ b/mlir/include/mlir/TableGen/Interfaces.h
@@ -157,6 +157,13 @@ struct TypeInterface : public Interface {
 
   static bool classof(const Interface *interface);
 };
+
+// An interface that is registered to a Dialect, i.e. a TableGen record that
+// derives from the `DialectInterface` class.
+struct DialectInterface : public Interface {
+  using Interface::Interface;
+
+  static bool classof(const Interface *interface);
+};
+
 } // namespace tblgen
 } // namespace mlir
 
diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index b0ad3ee59a089..77a6cecebbeaf 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -208,3 +208,11 @@ bool OpInterface::classof(const Interface *interface) {
 bool TypeInterface::classof(const Interface *interface) {
   return interface->getDef().isSubClassOf("TypeInterface");
 }
+
+//===----------------------------------------------------------------------===//
+// DialectInterface
+//===----------------------------------------------------------------------===//
+
+bool DialectInterface::classof(const Interface *interface) {
+  // Dialect interfaces are recognized purely by their TableGen base class.
+  const auto &record = interface->getDef();
+  return record.isSubClassOf("DialectInterface");
+}
diff --git a/mlir/test/mlir-tblgen/dialect-interface.td b/mlir/test/mlir-tblgen/dialect-interface.td
new file mode 100644
index 0000000000000..9b424bf501be3
--- /dev/null
+++ b/mlir/test/mlir-tblgen/dialect-interface.td
@@ -0,0 +1,66 @@
+// RUN: mlir-tblgen -gen-dialect-interface-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+
+include "mlir/IR/Interfaces.td"
+
+def NoDefaultMethod : DialectInterface<"NoDefaultMethod"> {
+  let description = [{
+    This is an example dialect interface without default method body.
+  }];
+
+  let cppNamespace = "::mlir::example";
+
+  // Neither method supplies a body, so both are declared pure virtual.
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Check if it's an example dialect",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "isExampleDialect",
+      /*args=*/        (ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/        "second method to check if multiple methods supported",
+      /*returnType=*/  "unsigned",
+      /*methodName=*/  "supportSecondMethod",
+      /*args=*/        (ins "::mlir::Type":$type)
+    >
+  ];
+}
+
+// DECL:   class NoDefaultMethod : public {{.*}}DialectInterface::Base<NoDefaultMethod>
+// DECL:   virtual bool isExampleDialect() const = 0;
+// DECL:   virtual unsigned supportSecondMethod(::mlir::Type type) const = 0;
+// DECL:   protected:
+// DECL-NEXT:   NoDefaultMethod(::mlir::Dialect *dialect) : Base(dialect) {}
+
+def WithDefaultMethodInterface : DialectInterface<"WithDefaultMethodInterface"> {
+  let description = [{
+    This is an example dialect interface with default method bodies.
+  }];
+
+  let cppNamespace = "::mlir::example";
+
+  // Only `isExampleDialect` has a body and therefore an out-of-line
+  // definition; `supportSecondMethod` remains pure virtual.
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Check if it's an example dialect",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "isExampleDialect",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{
+          return true;
+         }]
+    >,
+    InterfaceMethod<
+      /*desc=*/        "second method to check if multiple methods supported",
+      /*returnType=*/  "unsigned",
+      /*methodName=*/  "supportSecondMethod",
+      /*args=*/        (ins "::mlir::Type":$type)
+    >
+  ];
+}
+
+// DECL:  virtual bool isExampleDialect() const;
+// DECL:  bool ::mlir::example::WithDefaultMethodInterface::isExampleDialect() const {
+// DECL-NEXT:  return true;
+// DECL-NEXT: }
+
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index 2a7ef7e0576c8..d7087cba3c874 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -12,6 +12,7 @@ add_tablegen(mlir-tblgen MLIR
   AttrOrTypeFormatGen.cpp
   BytecodeDialectGen.cpp
   DialectGen.cpp
+  DialectInterfacesGen.cpp
   DirectiveCommonGen.cpp
   EnumsGen.cpp
   EnumPythonBindingGen.cpp
diff --git a/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
new file mode 100644
index 0000000000000..2fc500343501c
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/DialectInterfacesGen.cpp
@@ -0,0 +1,176 @@
+//===- DialectInterfacesGen.cpp - MLIR dialect interface utility generator ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// DialectInterfacesGen generates C++ declarations for dialect interfaces,
+// including out-of-line bodies for methods with a default implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CppGenUtilities.h"
+#include "DocGenUtilities.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Interfaces.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/TableGen/CodeGenHelpers.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace mlir;
+using llvm::Record;
+using llvm::RecordKeeper;
+using mlir::tblgen::Interface;
+using mlir::tblgen::InterfaceMethod;
+
+/// Emit a string corresponding to a C++ type, followed by a space if necessary.
+/// Pointer and reference types get no trailing space so that the following
+/// name attaches to the declarator (`Foo *name`). An empty (or all-whitespace)
+/// type emits nothing instead of calling `back()` on an empty string.
+static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
+  type = type.trim();
+  os << type;
+  if (!type.empty() && type.back() != '&' && type.back() != '*')
+    os << " ";
+  return os;
+}
+
+/// Writes `name(<type> <arg>, ...) const` for `method` to `os`. Each argument
+/// type is printed exactly as spelled in TableGen, followed by one space and
+/// the argument name.
+static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
+                                  raw_ostream &os) {
+  os << name << '(';
+  llvm::interleave(
+      method.getArguments(), os,
+      [&os](const InterfaceMethod::Argument &arg) {
+        os << arg.type << ' ' << arg.name;
+      },
+      ", ");
+  os << ") const";
+}
+
+/// Returns every `DialectInterface` record defined in the main TableGen input
+/// file; records that originate from included files are skipped.
+static std::vector<const Record *>
+getAllInterfaceDefinitions(const RecordKeeper &records) {
+  const unsigned mainFileID = llvm::SrcMgr.getMainFileID();
+  std::vector<const Record *> result;
+  for (const Record *def :
+       records.getAllDerivedDefinitions("DialectInterface")) {
+    // Keep only records whose first location lies in the main file.
+    if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) == mainFileID)
+      result.push_back(def);
+  }
+  return result;
+}
+
+namespace {
+/// Generator behind `-gen-dialect-interface-decls`. It emits one C++ class
+/// declaration per `DialectInterface` record defined in the main TableGen
+/// file, followed by the out-of-line bodies of methods that have a default
+/// implementation.
+class DialectInterfaceGenerator {
+public:
+  /// Collects the `DialectInterface` records of the main input file; all
+  /// output is written to `os`.
+  DialectInterfaceGenerator(const RecordKeeper &records, raw_ostream &os)
+      : defs(getAllInterfaceDefinitions(records)), os(os) {}
+
+  /// Emits the file header, every interface declaration, then every default
+  /// method definition. Currently always returns false.
+  bool emitInterfaceDecls();
+
+protected:
+  /// Emits the class declaration of one interface inside its C++ namespace.
+  void emitInterfaceDecl(const Interface &interface);
+  /// Emits out-of-line definitions for the methods of `interface` that
+  /// provide a default body.
+  void emitInterfaceMethodsDef(const Interface &interface);
+
+  /// The set of interface records to emit.
+  std::vector<const Record *> defs;
+  /// The stream to emit to.
+  raw_ostream &os;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// GEN: Interface declarations
+//===----------------------------------------------------------------------===//
+
+/// Emits the TableGen description of `method`, if any, as a doc comment whose
+/// lines start with `prefix`.
+static void emitInterfaceMethodDoc(const InterfaceMethod &method,
+                                   raw_ostream &os, StringRef prefix = "") {
+  std::optional<StringRef> desc = method.getDescription();
+  if (!desc)
+    return;
+  tblgen::emitDescriptionComment(*desc, os, prefix);
+}
+
+/// Declares every method of `interface` as a `virtual ... const` member,
+/// preceded by its documentation. Methods without a default body are made
+/// pure virtual.
+static void emitInterfaceDeclMethods(const Interface &interface,
+                                     raw_ostream &os) {
+  for (const InterfaceMethod &method : interface.getMethods()) {
+    emitInterfaceMethodDoc(method, os, "  ");
+    os << "  virtual ";
+    emitCPPType(method.getReturnType(), os);
+    emitMethodNameAndArgs(method, method.getName(), os);
+    os << (method.getBody() ? ";\n" : " = 0;\n");
+  }
+}
+
+/// Emits `interface` as a class deriving from
+/// `::mlir::DialectInterface::Base<Name>`, inside the interface's C++
+/// namespace. Any `extraClassDeclaration` given in TableGen is copied verbatim
+/// into the public section, after the interface methods.
+void DialectInterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
+  // Opens the C++ namespace here; it is closed when `ns` goes out of scope,
+  // i.e. after the class declaration has been written.
+  llvm::NamespaceEmitter ns(os, interface.getCppNamespace());
+
+  StringRef interfaceName = interface.getName();
+
+  tblgen::emitSummaryAndDescComments(os, "",
+                                     interface.getDescription().value_or(""));
+
+  // Emit the main interface class declaration.
+  os << llvm::formatv(
+      "class {0} : public ::mlir::DialectInterface::Base<{0}> {{\n"
+      "public:\n",
+      interfaceName);
+
+  emitInterfaceDeclMethods(interface, os);
+
+  // User-provided declarations must not be silently dropped.
+  if (std::optional<StringRef> extraDecls =
+          interface.getExtraClassDeclaration())
+    os << *extraDecls << "\n";
+
+  os << llvm::formatv("\nprotected:\n"
+                      "  {0}(::mlir::Dialect *dialect) : Base(dialect) {{}\n",
+                      interfaceName);
+
+  os << "};\n";
+}
+
+/// Returns true if printing `type` directly in front of a `::`-prefixed
+/// qualified name would make the compiler read both as one name. For example,
+/// `::mlir::Type ::ns::Iface::method()` parses as
+/// `::mlir::Type::ns::Iface::method()`. Pointer/reference types and types that
+/// end in a fundamental-type keyword or cv-qualifier are unaffected.
+static bool returnTypeAbsorbsScope(StringRef type) {
+  type = type.trim();
+  if (type.empty() || type.back() == '&' || type.back() == '*')
+    return false;
+  // Trailing identifier of the type. It is empty when the type ends in `>` or
+  // `)`, which also combine with a following `::`. If every character is an
+  // identifier character, find_last_not_of returns npos and npos + 1 wraps to
+  // 0, selecting the whole string.
+  StringRef lastWord = type.substr(
+      type.find_last_not_of("abcdefghijklmnopqrstuvwxyz"
+                            "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_") +
+      1);
+  static const StringRef safeKeywords[] = {
+      "void",    "bool",  "char", "char8_t", "char16_t", "char32_t",
+      "wchar_t", "short", "int",  "long",    "signed",   "unsigned",
+      "float",   "double", "auto", "const",  "volatile"};
+  return !llvm::is_contained(safeKeywords, lastWord);
+}
+
+/// Emits out-of-line definitions for the methods of `interface` that have a
+/// default body. Pure virtual methods get no definition.
+void DialectInterfaceGenerator::emitInterfaceMethodsDef(
+    const Interface &interface) {
+  StringRef cppNamespace = interface.getCppNamespace();
+  // `<ns>::<Iface>::<method>` starts with `::` for `::`-prefixed and empty
+  // namespaces; only then can the return type swallow the qualified name.
+  bool nameStartsWithScope =
+      cppNamespace.empty() || cppNamespace.starts_with("::");
+
+  for (const InterfaceMethod &method : interface.getMethods()) {
+    std::optional<StringRef> body = method.getBody();
+    if (!body)
+      continue;
+
+    StringRef returnType = method.getReturnType();
+    bool useTrailingReturn =
+        nameStartsWithScope && returnTypeAbsorbsScope(returnType);
+
+    // These definitions go into the declarations output, which is meant to be
+    // included like a header. `inline` keeps multiple inclusions from
+    // producing duplicate-symbol link errors.
+    os << "inline ";
+    if (useTrailingReturn)
+      os << "auto ";
+    else
+      emitCPPType(returnType, os);
+    os << cppNamespace << "::" << interface.getName() << "::";
+    emitMethodNameAndArgs(method, method.getName(), os);
+    if (useTrailingReturn)
+      os << " -> " << returnType.trim();
+    os << " {\n  " << *body << "\n}\n";
+  }
+}
+
+bool DialectInterfaceGenerator::emitInterfaceDecls() {
+  llvm::emitSourceFileHeader("Dialect Interface Declarations", os);
+
+  // Record IDs grow in definition order, so ordering by ID reproduces the
+  // order in which the interfaces appear in the TableGen input.
+  std::vector<const Record *> ordered(defs.begin(), defs.end());
+  llvm::sort(ordered, llvm::LessRecordByID());
+
+  // All class declarations come first, then the default method bodies.
+  for (const Record *def : ordered)
+    emitInterfaceDecl(Interface(def));
+  os << "\n";
+  for (const Record *def : ordered)
+    emitInterfaceMethodsDef(Interface(def));
+
+  // `false` reports success to the generator registration.
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: Interface registration hooks
+//===----------------------------------------------------------------------===//
+
+/// Entry point for `-gen-dialect-interface-decls`.
+static bool emitDialectInterfaceDecls(const RecordKeeper &records,
+                                      raw_ostream &os) {
+  DialectInterfaceGenerator generator(records, os);
+  return generator.emitInterfaceDecls();
+}
+
+static mlir::GenRegistration
+    genDecls("gen-dialect-interface-decls",
+             "Generate dialect interface declarations.",
+             emitDialectInterfaceDecls);



More information about the Mlir-commits mailing list