[Mlir-commits] [mlir] [MLIR][ODS] Allow operations to specify interfaces using the `HasParent` trait constraint (PR #66196)

Morten Borup Petersen llvmlistbot at llvm.org
Tue Sep 19 00:28:09 PDT 2023


https://github.com/mortbopet updated https://github.com/llvm/llvm-project/pull/66196

>From 25b8553a1ce58a961c610e1d4a6fb61bdfd8286f Mon Sep 17 00:00:00 2001
From: Morten Borup Petersen <morten_bp at live.dk>
Date: Wed, 13 Sep 2023 11:46:31 +0000
Subject: [PATCH 1/3] [MLIR] Allow operations to have interfaces as parents

... by emitting an operation name for interfaces. The name is only emitted for the parent interface (i.e. base interfaces are not considered).

This change is needed by the trait verifier at https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/OpDefinition.h#L1324C38-L1324C38, which must report _some_ name for the expected parent to the user.

Open question: should interfaces be identified by `getOperationName` rather than by something like `getInterfaceName`? A separate name is possible, but it would complicate code that wants to treat operation types and interface types uniformly.
---
 mlir/test/mlir-tblgen/op-interface.td      |  4 ++++
 mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 16 +++++++++++++++-
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td
index 17bd631fe250d16..80878c9b3205176 100644
--- a/mlir/test/mlir-tblgen/op-interface.td
+++ b/mlir/test/mlir-tblgen/op-interface.td
@@ -167,6 +167,10 @@ def DeclareMethodsOp : Op<TestDialect, "declare_methods_op",
 def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
       [DeclareOpInterfaceMethods<TestOpInterface, ["default_foo"]>]>;
 
+
+// DECL: /// Returns the name of this interface.
+// DECL: static ::llvm::StringLiteral getOperationName() { return ::llvm::StringLiteral( "TestOpInterface"); }
+
 // DECL-LABEL: TestOpInterfaceInterfaceTraits
 // DECL: class TestOpInterface : public ::mlir::OpInterface<TestOpInterface, detail::TestOpInterfaceInterfaceTraits>
 
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 9672a02cc08f68c..153543ab083525e 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -494,6 +494,16 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
   os << "  };\n";
 }
 
+static void emitInterfaceNameGetter(const Interface &interface,
+                                    raw_ostream &os) {
+  if (!isa<OpInterface>(interface))
+    return;
+  os << "  /// Returns the name of this interface.\n"
+     << "  static ::llvm::StringLiteral getOperationName() { return "
+        "::llvm::StringLiteral( \""
+     << interface.getName() << "\"); }\n";
+}
+
 static void emitInterfaceDeclMethods(const Interface &interface,
                                      raw_ostream &os, StringRef valueType,
                                      bool isOpInterface,
@@ -553,6 +563,9 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
                       "  struct Trait : public detail::{0}Trait<{1}> {{};\n",
                       interfaceName, valueTemplate);
 
+  // Emit the name of the interface.
+  emitInterfaceNameGetter(interface, os);
+
   // Insert the method declarations.
   bool isOpInterface = isa<OpInterface>(interface);
   emitInterfaceDeclMethods(interface, os, valueType, isOpInterface,
@@ -588,7 +601,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
        << "    auto* interface = getInterfaceFor(base);\n"
        << "    if (!interface)\n"
           "      return false;\n"
-          "    " << interfaceName << " odsInterfaceInstance(base, interface);\n"
+          "    "
+       << interfaceName << " odsInterfaceInstance(base, interface);\n"
        << "    " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
        << "\n  }\n";
   }

>From e0ad57ed68da274f28adb26d7bf16cea9b285730 Mon Sep 17 00:00:00 2001
From: Morten Borup Petersen <morten_bp at live.dk>
Date: Wed, 13 Sep 2023 12:20:12 +0000
Subject: [PATCH 2/3] fix test

---
 mlir/test/mlir-tblgen/op-interface.td | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td
index 80878c9b3205176..534cac07084539e 100644
--- a/mlir/test/mlir-tblgen/op-interface.td
+++ b/mlir/test/mlir-tblgen/op-interface.td
@@ -168,15 +168,16 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
       [DeclareOpInterfaceMethods<TestOpInterface, ["default_foo"]>]>;
 
 
-// DECL: /// Returns the name of this interface.
-// DECL: static ::llvm::StringLiteral getOperationName() { return ::llvm::StringLiteral( "TestOpInterface"); }
-
 // DECL-LABEL: TestOpInterfaceInterfaceTraits
 // DECL: class TestOpInterface : public ::mlir::OpInterface<TestOpInterface, detail::TestOpInterfaceInterfaceTraits>
 
+// DECL: /// Returns the name of this interface.
+// DECL: static ::llvm::StringLiteral getOperationName() { return ::llvm::StringLiteral( "TestOpInterface"); }
+
 // DECL: /// some function comment
 // DECL: int foo(int input);
 
+
 // DECL-LABEL: struct TestOpInterfaceVerifyTrait
 // DECL: verifyTrait
 

>From 4e1f546869ad26fa1cb7413ff892a93a0f63bb15 Mon Sep 17 00:00:00 2001
From: Morten Borup Petersen <morten_bp at live.dk>
Date: Tue, 19 Sep 2023 07:24:56 +0000
Subject: [PATCH 3/3] getOperationName -> getInterfaceName, add
 HasInterfaceParent trait, TestDialect tests

---
 mlir/include/mlir/IR/OpBase.td               |  5 +++++
 mlir/include/mlir/IR/OpDefinition.h          | 18 ++++++++++++++++++
 mlir/test/IR/traits.mlir                     | 18 ++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestInterfaces.td |  4 ++++
 mlir/test/lib/Dialect/Test/TestOps.td        |  8 ++++++++
 mlir/test/mlir-tblgen/op-interface.td        |  2 +-
 mlir/tools/mlir-tblgen/OpInterfacesGen.cpp   |  2 +-
 7 files changed, 55 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 236dd74839dfb04..c5e2c2c8f871f33 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -141,6 +141,11 @@ class ParentOneOf<list<string> ops>
     : ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>,
       StructuralOpTrait;
 
+// Op's parent operation implements the provided interface.
+class HasInterfaceParent<string interface>
+    : ParamNativeOpTrait<"HasInterfaceParent", interface>,
+      StructuralOpTrait;
+
 // Op result type is derived from the first attribute. If the attribute is an
 // subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
 // attribute content is used.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 895f17dfe1d07c8..73d2b21c7bbcf1d 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1334,6 +1334,24 @@ struct HasParent {
   };
 };
 
+/// This class provides a verifier for ops that are expecting their parent to
+/// implement a specific interface.
+template <typename ParentInterfaceType>
+struct HasInterfaceParent {
+  template <typename ConcreteType>
+  class Impl : public TraitBase<ConcreteType, Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      if (llvm::isa_and_nonnull<ParentInterfaceType>(op->getParentOp()))
+        return success();
+
+      return op->emitOpError()
+             << "expects parent op to implement interface '"
+             << ParentInterfaceType::getInterfaceName() << "'";
+    }
+  };
+};
+
 /// A trait for operations that have an attribute specifying operand segments.
 ///
 /// Certain operations can have multiple variadic operands and their size
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 0402ebe75875086..079ddff344dbb05 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -668,3 +668,21 @@ func.func @failed_attr_traits() {
   "test.attr_with_trait"() {attr = 42 : i32} : () -> ()
   return
 }
+
+// -----
+
+func.func @has_interface_parent_trait() {
+  // CHECK: "test.interface_parent"() ({
+  // CHECK:    "test.interface_child"() : () -> ()
+  "test.interface_parent"() ({
+    "test.interface_child"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+func.func @illegal_interface_parent_trait() {
+  // expected-error at +1 {{'test.interface_child' op expects parent op to implement interface 'TestInterfaceParentInterface'}}
+  "test.interface_child"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td
index e2ed27bdf0203e3..32a438efaddf764 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.td
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td
@@ -147,4 +147,8 @@ def TestOptionallyImplementedOpInterface
   }];
 }
 
+def TestInterfaceParentInterface : OpInterface<"TestInterfaceParentInterface"> {
+    let cppNamespace = "::mlir";
+}
+
 #endif // MLIR_TEST_DIALECT_TEST_INTERFACES
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9ceadab8fa4a086..99a309168892ab6 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -750,6 +750,14 @@ def ParentOp1 : TEST_Op<"parent1"> {
 def ChildWithParentOneOf : TEST_Op<"child_with_parent_one_of",
                                 [ParentOneOf<["ParentOp", "ParentOp1"]>]>;
 
+// HasInterfaceParent trait
+def InterfaceParentOp : TEST_Op<"interface_parent",
+                [SingleBlock, NoTerminator, TestInterfaceParentInterface]> {
+  let regions = (region SizedRegion<1>:$region);
+}
+def InterfaceChildOp : TEST_Op<"interface_child",
+                               [HasInterfaceParent<"mlir::TestInterfaceParentInterface">]>;
+
 def TerminatorOp : TEST_Op<"finish", [Terminator]>;
 def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",
     [SingleBlockImplicitTerminator<"TerminatorOp">]> {
diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td
index 534cac07084539e..fc6c793d08037c3 100644
--- a/mlir/test/mlir-tblgen/op-interface.td
+++ b/mlir/test/mlir-tblgen/op-interface.td
@@ -172,7 +172,7 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
 // DECL: class TestOpInterface : public ::mlir::OpInterface<TestOpInterface, detail::TestOpInterfaceInterfaceTraits>
 
 // DECL: /// Returns the name of this interface.
-// DECL: static ::llvm::StringLiteral getOperationName() { return ::llvm::StringLiteral( "TestOpInterface"); }
+// DECL: static ::llvm::StringLiteral getInterfaceName() { return ::llvm::StringLiteral( "TestOpInterface"); }
 
 // DECL: /// some function comment
 // DECL: int foo(int input);
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 153543ab083525e..fab6e03b3e6a8c3 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -499,7 +499,7 @@ static void emitInterfaceNameGetter(const Interface &interface,
   if (!isa<OpInterface>(interface))
     return;
   os << "  /// Returns the name of this interface.\n"
-     << "  static ::llvm::StringLiteral getOperationName() { return "
+     << "  static ::llvm::StringLiteral getInterfaceName() { return "
         "::llvm::StringLiteral( \""
      << interface.getName() << "\"); }\n";
 }



More information about the Mlir-commits mailing list