[Mlir-commits] [flang] [mlir] [MLIR][ODS] Add support for overloading interface methods (PR #161828)

Mehdi Amini llvmlistbot at llvm.org
Fri Oct 3 07:48:41 PDT 2025


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/161828

>From d603c931b76549340e220bb31c01abf266e6ac64 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Fri, 3 Oct 2025 04:47:35 -0700
Subject: [PATCH] [MLIR][ODS] Add support for overloading interface methods

This allows defining multiple interface methods with the same name
but different arguments (i.e. overloaded interface methods).
---
 .../include/flang/Optimizer/HLFIR/HLFIROps.td |  9 +++++-
 mlir/include/mlir/TableGen/Interfaces.h       |  7 +++-
 mlir/lib/TableGen/Interfaces.cpp              | 17 ++++++++--
 mlir/test/lib/Dialect/Test/TestInterfaces.td  | 10 ++++++
 mlir/test/lib/Dialect/Test/TestTypes.cpp      |  4 +++
 mlir/test/lib/IR/TestInterfaces.cpp           |  2 ++
 mlir/test/mlir-tblgen/interfaces.mlir         |  2 ++
 mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp   | 19 ++++++++++-
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp   | 27 +++++++++++++++-
 mlir/tools/mlir-tblgen/OpInterfacesGen.cpp    | 32 +++++++++++--------
 10 files changed, 108 insertions(+), 21 deletions(-)

diff --git a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
index 90512586a6520..218435a44c24f 100644
--- a/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
+++ b/flang/include/flang/Optimizer/HLFIR/HLFIROps.td
@@ -269,6 +269,9 @@ def hlfir_DesignateOp : hlfir_Op<"designate", [AttrSizedOperandSegments,
     using Triplet = std::tuple<mlir::Value, mlir::Value, mlir::Value>;
     using Subscript = std::variant<mlir::Value, Triplet>;
     using Subscripts = llvm::SmallVector<Subscript, 8>;
+    void setFortranAttrs(fir::FortranVariableFlagsEnum flags) {
+      this->setFortranAttrs(std::optional<fir::FortranVariableFlagsEnum>(flags));
+    }
   }];
 
   let builders = [
@@ -319,7 +322,7 @@ def hlfir_ParentComponentOp : hlfir_Op<"parent_comp", [AttrSizedOperandSegments,
     // Implement FortranVariableInterface interface. Parent components have
     // no attributes (pointer, allocatable or contiguous can only be added
     // to regular components).
-    std::optional<fir::FortranVariableFlagsEnum> getFortranAttrs() const {
+    std::optional<fir::FortranVariableFlagsEnum> getFortranAttrs() {
       return std::nullopt;
     }
     void setFortranAttrs(fir::FortranVariableFlagsEnum flags) {}
@@ -882,6 +885,10 @@ def hlfir_AssociateOp : hlfir_Op<"associate", [AttrSizedOperandSegments,
       CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>];
 
   let extraClassDeclaration = [{
+    void setFortranAttrs(fir::FortranVariableFlagsEnum flags) {
+      this->setFortranAttrs(std::optional<fir::FortranVariableFlagsEnum>(flags));
+    }
+
     /// Override FortranVariableInterface default implementation
     mlir::Value getBase() {
       return getResult(0);
diff --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h
index 15f667e0ffce0..7280ad2e3bc17 100644
--- a/mlir/include/mlir/TableGen/Interfaces.h
+++ b/mlir/include/mlir/TableGen/Interfaces.h
@@ -32,7 +32,7 @@ class InterfaceMethod {
     StringRef name;
   };
 
-  explicit InterfaceMethod(const llvm::Record *def);
+  explicit InterfaceMethod(const llvm::Record *def, std::string dedupName);
 
   // Return the return type of this method.
   StringRef getReturnType() const;
@@ -40,6 +40,9 @@ class InterfaceMethod {
   // Return the name of this method.
   StringRef getName() const;
 
+  // Return the dedup name of this method.
+  StringRef getDedupName() const;
+
   // Return if this method is static.
   bool isStatic() const;
 
@@ -62,6 +65,8 @@ class InterfaceMethod {
 
   // The arguments of this method.
   SmallVector<Argument, 2> arguments;
+
+  std::string dedupName;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index ec7adf3b02c21..be39fc3a71cf4 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -25,7 +25,8 @@ using llvm::StringInit;
 // InterfaceMethod
 //===----------------------------------------------------------------------===//
 
-InterfaceMethod::InterfaceMethod(const Record *def) : def(def) {
+InterfaceMethod::InterfaceMethod(const Record *def, std::string dedupName)
+    : def(def), dedupName(dedupName) {
   const DagInit *args = def->getValueAsDag("arguments");
   for (unsigned i = 0, e = args->getNumArgs(); i != e; ++i) {
     arguments.push_back({cast<StringInit>(args->getArg(i))->getValue(),
@@ -42,6 +43,9 @@ StringRef InterfaceMethod::getName() const {
   return def->getValueAsString("name");
 }
 
+// Return this method's name, suffixed if needed to be unique in its interface.
+StringRef InterfaceMethod::getDedupName() const { return dedupName; }
+
 // Return if this method is static.
 bool InterfaceMethod::isStatic() const {
   return def->isSubClassOf("StaticInterfaceMethod");
@@ -83,8 +87,15 @@ Interface::Interface(const Record *def) : def(def) {
 
   // Initialize the interface methods.
   auto *listInit = dyn_cast<ListInit>(def->getValueInit("methods"));
-  for (const Init *init : listInit->getElements())
-    methods.emplace_back(cast<DefInit>(init)->getDef());
+  StringSet<> dedupNames;
+  for (const Init *init : listInit->getElements()) {
+    std::string name =
+        cast<DefInit>(init)->getDef()->getValueAsString("name").str();
+    while (!dedupNames.insert(name).second) {
+      name = name + "_" + std::to_string(dedupNames.size());
+    }
+    methods.emplace_back(cast<DefInit>(init)->getDef(), name);
+  }
 
   // Initialize the interface base classes.
   auto *basesInit = dyn_cast<ListInit>(def->getValueInit("baseInterfaces"));
diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td
index d3d96ea5a65a4..3697e38ac4c7d 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.td
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td
@@ -44,6 +44,16 @@ def TestTypeInterface
     InterfaceMethod<"Prints the type name.",
       "void", "printTypeC", (ins "::mlir::Location":$loc)
     >,
+    // Check that we can have multiple method with the same name.
+    InterfaceMethod<"Prints the type name, with a value prefixed.",
+      "void", "printTypeC", (ins "::mlir::Location":$loc, "int":$value)
+    >,
+    InterfaceMethod<"Prints the type name, with a value prefixed.",
+      "void", "printTypeC", (ins "::mlir::Location":$loc, "float":$value),
+      [{}], /*defaultImplementation=*/[{
+        emitRemark(loc) << $_type << " - " << value << " - Float TestC";
+      }]
+    >,
     // It should be possible to use the interface type name as result type
     // as well as in the implementation.
     InterfaceMethod<"Prints the type name and returns the type as interface.",
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index bea043f56fe21..614121f1d43dd 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -245,6 +245,10 @@ void TestType::printTypeC(Location loc) const {
   emitRemark(loc) << *this << " - TestC";
 }
 
+void TestType::printTypeC(Location loc, int value) const {
+  emitRemark(loc) << *this << " - " << value << " - Int TestC";
+} // `int` overload: TestTypeInterface declares it without a default impl.
+
 //===----------------------------------------------------------------------===//
 // TestTypeWithLayout
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/IR/TestInterfaces.cpp b/mlir/test/lib/IR/TestInterfaces.cpp
index 2dd3fe245e220..881019dbfd50d 100644
--- a/mlir/test/lib/IR/TestInterfaces.cpp
+++ b/mlir/test/lib/IR/TestInterfaces.cpp
@@ -31,6 +31,8 @@ struct TestTypeInterfaces
           testInterface.printTypeA(op->getLoc());
           testInterface.printTypeB(op->getLoc());
           testInterface.printTypeC(op->getLoc());
+          testInterface.printTypeC(op->getLoc(), 42);
+          testInterface.printTypeC(op->getLoc(), 3.14f);
           testInterface.printTypeD(op->getLoc());
           // Just check that we can assign the result to a variable of interface
           // type.
diff --git a/mlir/test/mlir-tblgen/interfaces.mlir b/mlir/test/mlir-tblgen/interfaces.mlir
index 5c1ec613b387a..b5d694f75734c 100644
--- a/mlir/test/mlir-tblgen/interfaces.mlir
+++ b/mlir/test/mlir-tblgen/interfaces.mlir
@@ -3,6 +3,8 @@
 // expected-remark at below {{'!test.test_type' - TestA}}
 // expected-remark at below {{'!test.test_type' - TestB}}
 // expected-remark at below {{'!test.test_type' - TestC}}
+// expected-remark at below {{'!test.test_type' - 42 - Int TestC}}
+// expected-remark at below {{'!test.test_type' - 3.140000e+00 - Float TestC}}
 // expected-remark at below {{'!test.test_type' - TestD}}
 // expected-remark at below {{'!test.test_type' - TestRet}}
 // expected-remark at below {{'!test.test_type' - TestE}}
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index b9115657d6bf3..15b03b85727f6 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -130,6 +130,9 @@ class DefGen {
   void emitTraitMethods(const InterfaceTrait &trait);
   /// Emit a trait method.
   void emitTraitMethod(const InterfaceMethod &method);
+  /// Generate a using declaration for a trait method.
+  void genTraitMethodUsingDecl(const InterfaceTrait &trait,
+                               const InterfaceMethod &method);
 
   //===--------------------------------------------------------------------===//
   // OpAsm{Type,Attr}Interface Default Method Emission
@@ -176,6 +179,9 @@ class DefGen {
   StringRef valueType;
   /// The prefix/suffix of the TableGen def name, either "Attr" or "Type".
   StringRef defType;
+
+  /// The set of using declarations for trait methods.
+  llvm::StringSet<> interfaceUsingNames;
 };
 } // namespace
 
@@ -632,8 +638,10 @@ void DefGen::emitTraitMethods(const InterfaceTrait &trait) {
     // Don't declare if the method has a body. Or if the method has a default
     // implementation and the def didn't request that it always be declared.
     if (method.getBody() || (method.getDefaultImplementation() &&
-                             !alwaysDeclared.count(method.getName())))
+                             !alwaysDeclared.count(method.getName()))) {
+      genTraitMethodUsingDecl(trait, method);
       continue;
+    }
     emitTraitMethod(method);
   }
 }
@@ -649,6 +657,15 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) {
                    std::move(params));
 }
 
+void DefGen::genTraitMethodUsingDecl(const InterfaceTrait &trait,
+                                     const InterfaceMethod &method) {
+  std::string name = (llvm::Twine(trait.getFullyQualifiedTraitName()) + "<" +
+                      def.getCppClassName() + ">::" + method.getName()).str();
+  if (method.getDefaultImplementation() && // Only defaults live in the trait.
+      interfaceUsingNames.insert(name).second)
+    defCls.declare<UsingDeclaration>(std::move(name));
+}
+
 //===----------------------------------------------------------------------===//
 // OpAsm{Type,Attr}Interface Default Method Emission
 
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 7e8e559baf878..70c462bb667b2 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -789,6 +789,14 @@ class OpEmitter {
   Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
                                bool declaration = true);
 
+  // Generate a `using` declaration for the op interface method to include
+  // the default implementation from the interface trait.
+  // This is needed when the interface defines multiple methods with the same
+  // name, but some have a default implementation and some don't.
+  UsingDeclaration *
+  genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
+                                const tblgen::InterfaceMethod &method);
+
   // Generate the side effect interface methods.
   void genSideEffectInterfaceMethods();
 
@@ -815,6 +823,10 @@ class OpEmitter {
 
   // Helper for emitting op code.
   OpOrAdaptorHelper emitHelper;
+
+  // Keep track of the interface using declarations that have been generated to
+  // avoid duplicates.
+  llvm::StringSet<> interfaceUsingNames;
 };
 
 } // namespace
@@ -3672,8 +3684,10 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {
     // Don't declare if the method has a default implementation and the op
     // didn't request that it always be declared.
     if (method.getDefaultImplementation() &&
-        !alwaysDeclaredMethods.count(method.getName()))
+        !alwaysDeclaredMethods.count(method.getName())) {
+      genOpInterfaceMethodUsingDecl(opTrait, method);
       continue;
+    }
     // Interface methods are allowed to overlap with existing methods, so don't
     // check if pruned.
     (void)genOpInterfaceMethod(method);
@@ -3692,6 +3706,17 @@ Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
                            std::move(paramList));
 }
 
+UsingDeclaration *
+OpEmitter::genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait,
+                                         const InterfaceMethod &method) {
+  // `Trait<Op>::name`: one using-declaration covers all overloads of `name`.
+  std::string fqName = (llvm::Twine(opTrait->getFullyQualifiedTraitName()) +
+                        "<" + op.getCppClassName() + ">::" + method.getName())
+                           .str();
+  bool isNew = interfaceUsingNames.insert(fqName).second;
+  return isNew ? opClass.declare<UsingDeclaration>(std::move(fqName)) : nullptr;
+}
+
 void OpEmitter::genOpInterfaceMethods() {
   for (const auto &trait : op.getTraits()) {
     if (const auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 3cc1636ac3317..9dedd55005f87 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -42,10 +42,10 @@ static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) {
 /// Emit the method name and argument list for the given method. If 'addThisArg'
 /// is true, then an argument is added to the beginning of the argument list for
 /// the concrete value.
-static void emitMethodNameAndArgs(const InterfaceMethod &method,
+static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name,
                                   raw_ostream &os, StringRef valueType,
                                   bool addThisArg, bool addConst) {
-  os << method.getName() << '(';
+  os << name << '(';
   if (addThisArg) {
     if (addConst)
       os << "const ";
@@ -183,11 +183,13 @@ static void emitInterfaceDefMethods(StringRef interfaceQualName,
     emitInterfaceMethodDoc(method, os);
     emitCPPType(method.getReturnType(), os);
     os << interfaceQualName << "::";
-    emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+    emitMethodNameAndArgs(method, method.getName(), os, valueType,
+                          /*addThisArg=*/false,
                           /*addConst=*/!isOpInterface);
 
     // Forward to the method on the concrete operation type.
-    os << " {\n      return " << implValue << "->" << method.getName() << '(';
+    os << " {\n      return " << implValue << "->" << method.getDedupName()
+       << '(';
     if (!method.isStatic()) {
       os << implValue << ", ";
       os << (isOpInterface ? "getOperation()" : "*this");
@@ -239,7 +241,7 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) {
   for (auto &method : interface.getMethods()) {
     os << "    ";
     emitCPPType(method.getReturnType(), os);
-    os << "(*" << method.getName() << ")(";
+    os << "(*" << method.getDedupName() << ")(";
     if (!method.isStatic()) {
       os << "const Concept *impl, ";
       emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", ");
@@ -289,13 +291,13 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
     os << "    " << modelClass << "() : Concept{";
     llvm::interleaveComma(
         interface.getMethods(), os,
-        [&](const InterfaceMethod &method) { os << method.getName(); });
+        [&](const InterfaceMethod &method) { os << method.getDedupName(); });
     os << "} {}\n\n";
 
     // Insert each of the virtual method overrides.
     for (auto &method : interface.getMethods()) {
       emitCPPType(method.getReturnType(), os << "    static inline ");
-      emitMethodNameAndArgs(method, os, valueType,
+      emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
                             /*addThisArg=*/!method.isStatic(),
                             /*addConst=*/false);
       os << ";\n";
@@ -319,7 +321,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
     if (method.isStatic())
       os << "static ";
     emitCPPType(method.getReturnType(), os);
-    os << method.getName() << "(";
+    os << method.getDedupName() << "(";
     if (!method.isStatic()) {
       emitCPPType(valueType, os);
       os << "tablegen_opaque_val";
@@ -350,7 +352,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
     emitCPPType(method.getReturnType(), os);
     os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
        << valueTemplate << ">::";
-    emitMethodNameAndArgs(method, os, valueType,
+    emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
                           /*addThisArg=*/!method.isStatic(),
                           /*addConst=*/false);
     os << " {\n  ";
@@ -384,7 +386,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
     emitCPPType(method.getReturnType(), os);
     os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<"
        << valueTemplate << ">::";
-    emitMethodNameAndArgs(method, os, valueType,
+    emitMethodNameAndArgs(method, method.getDedupName(), os, valueType,
                           /*addThisArg=*/!method.isStatic(),
                           /*addConst=*/false);
     os << " {\n  ";
@@ -396,7 +398,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
       os << "return static_cast<const " << valueTemplate << " *>(impl)->";
 
     // Add the arguments to the call.
-    os << method.getName() << '(';
+    os << method.getDedupName() << '(';
     if (!method.isStatic())
       os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", ");
     llvm::interleaveComma(
@@ -416,7 +418,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
        << "InterfaceTraits::ExternalModel<ConcreteModel, " << valueTemplate
        << ">::";
 
-    os << method.getName() << "(";
+    os << method.getDedupName() << "(";
     if (!method.isStatic()) {
       emitCPPType(valueType, os);
       os << "tablegen_opaque_val";
@@ -477,7 +479,8 @@ void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) {
     emitInterfaceMethodDoc(method, os, "    ");
     os << "    " << (method.isStatic() ? "static " : "");
     emitCPPType(method.getReturnType(), os);
-    emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+    emitMethodNameAndArgs(method, method.getName(), os, valueType,
+                          /*addThisArg=*/false,
                           /*addConst=*/!isOpInterface && !method.isStatic());
     os << " {\n      " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt)
        << "\n    }\n";
@@ -514,7 +517,8 @@ static void emitInterfaceDeclMethods(const Interface &interface,
   for (auto &method : interface.getMethods()) {
     emitInterfaceMethodDoc(method, os, "  ");
     emitCPPType(method.getReturnType(), os << "  ");
-    emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false,
+    emitMethodNameAndArgs(method, method.getName(), os, valueType,
+                          /*addThisArg=*/false,
                           /*addConst=*/!isOpInterface);
     os << ";\n";
   }



More information about the Mlir-commits mailing list