[Mlir-commits] [mlir] [MLIR][ODS] Add support for overloading interface methods (PR #161828)
Mehdi Amini
llvmlistbot at llvm.org
Fri Oct 3 04:50:54 PDT 2025
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/161828
This allows defining multiple interface methods with the same name but different arguments.
>From 150cb5b9e1a7bb7c2c9d96d545c5b76c75476fcc Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Fri, 3 Oct 2025 04:47:35 -0700
Subject: [PATCH] [MLIR][ODS] Add support for overloading interface methods
This allows defining multiple interface methods with the same name
but different arguments.
---
mlir/test/lib/Dialect/Test/TestTypes.cpp | 4 +++
mlir/test/lib/IR/TestInterfaces.cpp | 1 +
mlir/test/mlir-tblgen/interfaces.mlir | 1 +
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 27 ++++++++++++++++-
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 32 ++++++++++++---------
5 files changed, 50 insertions(+), 15 deletions(-)
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index bea043f56fe21..9076c7e54d7bf 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -245,6 +245,12 @@ void TestType::printTypeC(Location loc) const {
emitRemark(loc) << *this << " - TestC";
}
+// Overload of printTypeC taking an extra value; used to test overloaded
+// interface methods.
+void TestType::printTypeC(Location loc, int value) const {
+  auto remark = emitRemark(loc);
+  remark << *this << " - " << value << " - TestC";
+}
+
//===----------------------------------------------------------------------===//
// TestTypeWithLayout
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/IR/TestInterfaces.cpp b/mlir/test/lib/IR/TestInterfaces.cpp
index 2dd3fe245e220..e021f78e1142d 100644
--- a/mlir/test/lib/IR/TestInterfaces.cpp
+++ b/mlir/test/lib/IR/TestInterfaces.cpp
@@ -31,6 +31,7 @@ struct TestTypeInterfaces
testInterface.printTypeA(op->getLoc());
testInterface.printTypeB(op->getLoc());
testInterface.printTypeC(op->getLoc());
+ testInterface.printTypeC(op->getLoc(), 42);
testInterface.printTypeD(op->getLoc());
// Just check that we can assign the result to a variable of interface
// type.
diff --git a/mlir/test/mlir-tblgen/interfaces.mlir b/mlir/test/mlir-tblgen/interfaces.mlir
index 5c1ec613b387a..927cfd728bcd4 100644
--- a/mlir/test/mlir-tblgen/interfaces.mlir
+++ b/mlir/test/mlir-tblgen/interfaces.mlir
@@ -3,6 +3,7 @@
// expected-remark@below {{'!test.test_type' - TestA}}
// expected-remark@below {{'!test.test_type' - TestB}}
// expected-remark@below {{'!test.test_type' - TestC}}
+// expected-remark@below {{'!test.test_type' - 42 - TestC}}
// expected-remark@below {{'!test.test_type' - TestD}}
// expected-remark@below {{'!test.test_type' - TestRet}}
// expected-remark@below {{'!test.test_type' - TestE}}
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 7e8e559baf878..4c6519cd2f7bf 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -789,6 +789,14 @@ class OpEmitter {
Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
bool declaration = true);
+ // Generate a `using` declaration for the op interface method to include
+ // the default implementation from the interface trait.
+ // This is needed when the interface defines multiple methods with the same
+ // name, but some have a default implementation and some don't.
+ UsingDeclaration *
+ genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
+ const tblgen::InterfaceMethod &method);
+
// Generate the side effect interface methods.
void genSideEffectInterfaceMethods();
@@ -815,6 +823,10 @@ class OpEmitter {
// Helper for emitting op code.
OpOrAdaptorHelper emitHelper;
+
+ // Keep track of the interface using declarations that have been generated to
+ // avoid duplicates.
+ llvm::StringSet<> interfaceUsingNames;
};
} // namespace
@@ -3672,8 +3684,10 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
// Don't declare if the method has a default implementation and the op
// didn't request that it always be declared.
if (method.getDefaultImplementation() &&
- !alwaysDeclaredMethods.count(method.getName()))
+ !alwaysDeclaredMethods.count(method.getName())) {
+ genOpInterfaceMethodUsingDecl(opTrait, method);
continue;
+ }
// Interface methods are allowed to overlap with existing methods, so don't
// check if pruned.
(void)genOpInterfaceMethod(method);
@@ -3692,6 +3706,39 @@ Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
std::move(paramList));
}
+/// Declares `using <InterfaceTrait><<Op>>::<method>;` in the op class so that
+/// the trait's default implementation of `method` remains visible. This is
+/// only required when the interface overloads the method name: the op then
+/// declares the non-defaulted overloads itself, and in C++ a declaration in
+/// the derived class hides every same-named member of its bases.
+/// Returns the new declaration, or nullptr if none is needed or one was
+/// already emitted for this name.
+UsingDeclaration *
+OpEmitter::genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
+                                         const InterfaceMethod &method) {
+  llvm::StringRef methodName = method.getName();
+
+  // Skip names that are not overloaded in this interface: emitting a
+  // using-declaration for every defaulted method would bloat the generated
+  // class and alter overload resolution on ops whose own members happen to
+  // share the name with a different signature.
+  auto interface = opTrait->getInterface();
+  unsigned numOverloads = 0;
+  for (const InterfaceMethod &other : interface.getMethods())
+    if (other.getName() == methodName)
+      ++numOverloads;
+  if (numOverloads < 2)
+    return nullptr;
+
+  std::string name = (llvm::Twine(opTrait->getFullyQualifiedTraitName()) + "<" +
+                      op.getQualCppClassName() + ">::" + methodName)
+                         .str();
+  // Overloads share a single using-declaration, so emit it once per name.
+  if (interfaceUsingNames.insert(name).second)
+    return opClass.declare<UsingDeclaration>(std::move(name));
+  return nullptr;
+}
+
void OpEmitter::genOpInterfaceMethods() {
for (const auto &trait : op.getTraits()) {
if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 3cc1636ac3317..9dedd55005f87 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -42,10 +42,10 @@ static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
/// Emit the method name and argument list for the given method. If 'addThisArg'
/// is true, then an argument is added to the beginning of the argument list for
/// the concrete value.
-static void emitMethodNameAndArgs(const InterfaceMethod &method,
+static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
raw_ostream &os, StringRef valueType,
bool addThisArg, bool addConst) {
- os << method.getName() << '(';
+ os << name << '(';
if (addThisArg) {
if (addConst)
os << "const ";
@@ -183,11 +183,13 @@ static void emitInterfaceDefMethods(StringRef interfaceQualName,
emitInterfaceMethodDoc(method, os);
emitCPPType(method.getReturnType(), os);
os << interfaceQualName << "::";
- emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+ emitMethodNameAndArgs(method, method.getName(), os, valueType,
+ /*addThisArg=*/false,
/*addConst=*/!isOpInterface);
// Forward to the method on the concrete operation type.
- os << " {\n return " << implValue << "->" << method.getName() << '(';
+ os << " {\n return " << implValue << "->" << method.getDedupName()
+ << '(';
if (!method.isStatic()) {
os << implValue << ", ";
os << (isOpInterface ? "getOperation()" : "*this");
@@ -239,7 +241,7 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
for (auto &method : interface.getMethods()) {
os << " ";
emitCPPType(method.getReturnType(), os);
- os << "(*" << method.getName() << ")(";
+ os << "(*" << method.getDedupName() << ")(";
if (!method.isStatic()) {
os << "const Concept *impl, ";
emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
@@ -289,13 +291,13 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
os << " " << modelClass << "() : Concept{";
llvm::interleaveComma(
interface.getMethods(), os,
- [&](const InterfaceMethod &method) { os << method.getName(); });
+ [&](const InterfaceMethod &method) { os << method.getDedupName(); });
os << "} {}\n\n";
// Insert each of the virtual method overrides.
for (auto &method : interface.getMethods()) {
emitCPPType(method.getReturnType(), os << " static inline ");
- emitMethodNameAndArgs(method, os, valueType,
+ emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
/*addThisArg=*/!method.isStatic(),
/*addConst=*/false);
os << ";\n";
@@ -319,7 +321,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
if (method.isStatic())
os << "static ";
emitCPPType(method.getReturnType(), os);
- os << method.getName() << "(";
+ os << method.getDedupName() << "(";
if (!method.isStatic()) {
emitCPPType(valueType, os);
os << "tablegen_opaque_val";
@@ -350,7 +352,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
emitCPPType(method.getReturnType(), os);
os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
<< valueTemplate << ">::";
- emitMethodNameAndArgs(method, os, valueType,
+ emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
/*addThisArg=*/!method.isStatic(),
/*addConst=*/false);
os << " {\n ";
@@ -384,7 +386,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
emitCPPType(method.getReturnType(), os);
os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
<< valueTemplate << ">::";
- emitMethodNameAndArgs(method, os, valueType,
+ emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
/*addThisArg=*/!method.isStatic(),
/*addConst=*/false);
os << " {\n ";
@@ -396,7 +398,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
os << "return static_cast<const " << valueTemplate << " *>(impl)->";
// Add the arguments to the call.
- os << method.getName() << '(';
+ os << method.getDedupName() << '(';
if (!method.isStatic())
os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
llvm::interleaveComma(
@@ -416,7 +418,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
<< "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
<< ">::";
- os << method.getName() << "(";
+ os << method.getDedupName() << "(";
if (!method.isStatic()) {
emitCPPType(valueType, os);
os << "tablegen_opaque_val";
@@ -477,7 +479,8 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
emitInterfaceMethodDoc(method, os, " ");
os << " " << (method.isStatic() ? "static " : "");
emitCPPType(method.getReturnType(), os);
- emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+ emitMethodNameAndArgs(method, method.getName(), os, valueType,
+ /*addThisArg=*/false,
/*addConst=*/!isOpInterface && !method.isStatic());
os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
<< "\n }\n";
@@ -514,7 +517,8 @@ static void emitInterfaceDeclMethods(const Interface &interface,
for (auto &method : interface.getMethods()) {
emitInterfaceMethodDoc(method, os, " ");
emitCPPType(method.getReturnType(), os << " ");
- emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+ emitMethodNameAndArgs(method, method.getName(), os, valueType,
+ /*addThisArg=*/false,
/*addConst=*/!isOpInterface);
os << ";\n";
}
More information about the Mlir-commits
mailing list