[Mlir-commits] [mlir] [MLIR][ODS] Allow operations to specify interfaces using the `HasParent` trait constraint (PR #66196)

Morten Borup Petersen llvmlistbot at llvm.org
Tue Sep 19 00:25:09 PDT 2023


https://github.com/mortbopet updated https://github.com/llvm/llvm-project/pull/66196

>From 25b8553a1ce58a961c610e1d4a6fb61bdfd8286f Mon Sep 17 00:00:00 2001
From: Morten Borup Petersen <morten_bp at live.dk>
Date: Wed, 13 Sep 2023 11:46:31 +0000
Subject: [PATCH 1/3] [MLIR] Allow operations to have interfaces as parents

... by emitting a name getter for op interfaces. The name is emitted only for the parent interface itself (i.e., its base interfaces are not considered).

This change is needed by the `HasParent` trait verifier (https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/OpDefinition.h#L1324C38-L1324C38), which must report _some_ name for the expected parent to the user.

Open question: should interfaces be identified by `getOperationName` rather than by, e.g., `getInterfaceName`? The latter is possible, but it would complicate code that wants to treat operation types and interface types uniformly.
---
 mlir/test/mlir-tblgen/op-interface.td      |  4 ++++
 mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 16 +++++++++++++++-
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td
index 17bd631fe250d16..80878c9b3205176 100644
--- a/mlir/test/mlir-tblgen/op-interface.td
+++ b/mlir/test/mlir-tblgen/op-interface.td
@@ -167,6 +167,10 @@ def DeclareMethodsOp : Op<TestDialect, "declare_methods_op",
 def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
       [DeclareOpInterfaceMethods<TestOpInterface, ["default_foo"]>]>;
 
+
+// DECL: /// Returns the name of this interface.
+// DECL: static ::llvm::StringLiteral getOperationName() { return ::llvm::StringLiteral( "TestOpInterface"); }
+
 // DECL-LABEL: TestOpInterfaceInterfaceTraits
 // DECL: class TestOpInterface : public ::mlir::OpInterface<TestOpInterface, detail::TestOpInterfaceInterfaceTraits>
 
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 9672a02cc08f68c..153543ab083525e 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -494,6 +494,16 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
   os << "  };\n";
 }
 
+static void emitInterfaceNameGetter(const Interface &interface,
+                                    raw_ostream &os) {
+  if (!isa<OpInterface>(interface))
+    return;
+  os << "  /// Returns the name of this interface.\n"
+     << "  static ::llvm::StringLiteral getOperationName() { return "
+        "::llvm::StringLiteral( \""
+     << interface.getName() << "\"); }\n";
+}
+
 static void emitInterfaceDeclMethods(const Interface &interface,
                                      raw_ostream &os, StringRef valueType,
                                      bool isOpInterface,
@@ -553,6 +563,9 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
                       "  struct Trait : public detail::{0}Trait<{1}> {{};\n",
                       interfaceName, valueTemplate);
 
+  // Emit the name of the interface.
+  emitInterfaceNameGetter(interface, os);
+
   // Insert the method declarations.
   bool isOpInterface = isa<OpInterface>(interface);
   emitInterfaceDeclMethods(interface, os, valueType, isOpInterface,
@@ -588,7 +601,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
        << "    auto* interface = getInterfaceFor(base);\n"
        << "    if (!interface)\n"
           "      return false;\n"
-          "    " << interfaceName << " odsInterfaceInstance(base, interface);\n"
+          "    "
+       << interfaceName << " odsInterfaceInstance(base, interface);\n"
        << "    " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
        << "\n  }\n";
   }

>From e0ad57ed68da274f28adb26d7bf16cea9b285730 Mon Sep 17 00:00:00 2001
From: Morten Borup Petersen <morten_bp at live.dk>
Date: Wed, 13 Sep 2023 12:20:12 +0000
Subject: [PATCH 2/3] fix test

---
 mlir/test/mlir-tblgen/op-interface.td | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td
index 80878c9b3205176..534cac07084539e 100644
--- a/mlir/test/mlir-tblgen/op-interface.td
+++ b/mlir/test/mlir-tblgen/op-interface.td
@@ -168,15 +168,16 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
       [DeclareOpInterfaceMethods<TestOpInterface, ["default_foo"]>]>;
 
 
-// DECL: /// Returns the name of this interface.
-// DECL: static ::llvm::StringLiteral getOperationName() { return ::llvm::StringLiteral( "TestOpInterface"); }
-
 // DECL-LABEL: TestOpInterfaceInterfaceTraits
 // DECL: class TestOpInterface : public ::mlir::OpInterface<TestOpInterface, detail::TestOpInterfaceInterfaceTraits>
 
+// DECL: /// Returns the name of this interface.
+// DECL: static ::llvm::StringLiteral getOperationName() { return ::llvm::StringLiteral( "TestOpInterface"); }
+
 // DECL: /// some function comment
 // DECL: int foo(int input);
 
+
 // DECL-LABEL: struct TestOpInterfaceVerifyTrait
 // DECL: verifyTrait
 

>From 3d318bb95ebbc45c84349e651b43ae7c2d881dfa Mon Sep 17 00:00:00 2001
From: Morten Borup Petersen <morten_bp at live.dk>
Date: Tue, 19 Sep 2023 07:24:56 +0000
Subject: [PATCH 3/3] getOperationName -> getInterfaceName, add
 HasInterfaceParent trait, TestDialect tests

---
 mlir/include/mlir/IR/OpBase.td               |  5 +++
 mlir/include/mlir/IR/OpDefinition.h          | 38 ++++++++++++++------
 mlir/test/IR/traits.mlir                     | 18 ++++++++++
 mlir/test/lib/Dialect/Test/TestInterfaces.td |  4 +++
 mlir/test/lib/Dialect/Test/TestOps.td        |  8 +++++
 mlir/test/mlir-tblgen/op-interface.td        |  2 +-
 mlir/tools/mlir-tblgen/OpInterfacesGen.cpp   |  2 +-
 7 files changed, 65 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 236dd74839dfb04..c5e2c2c8f871f33 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -141,6 +141,11 @@ class ParentOneOf<list<string> ops>
     : ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>,
       StructuralOpTrait;
 
+// Op's parent op must implement the C++ op interface class `interface`.
+class HasInterfaceParent<string interface>
+    : ParamNativeOpTrait<"HasInterfaceParent", interface>,
+      StructuralOpTrait;
+
 // Op result type is derived from the first attribute. If the attribute is an
 // subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
 // attribute content is used.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 895f17dfe1d07c8..07302525e8f3e2d 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1334,6 +1334,24 @@ struct HasParent {
   };
 };
 
+/// Trait verifying that an op's parent op (which must exist) implements
+/// `ParentInterfaceType`; the error names it via `getInterfaceName()`.
+template <typename ParentInterfaceType>
+struct HasInterfaceParent {
+  template <typename ConcreteType>
+  class Impl : public TraitBase<ConcreteType, Impl> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      if (llvm::isa_and_nonnull<ParentInterfaceType>(op->getParentOp()))
+        return success();
+      // Reached when there is no parent op, or it lacks the interface.
+      return op->emitOpError()
+             << "expects parent op to implement interface '"
+             << ParentInterfaceType::getInterfaceName() << "'";
+    }
+  };
+};
+
 /// A trait for operations that have an attribute specifying operand segments.
 ///
 /// Certain operations can have multiple variadic operands and their size
@@ -1802,9 +1820,9 @@ class Op : public OpState, public Traits<ConcreteType>... {
       llvm::is_detected<has_single_result_fold_t, T>::value;
   /// Trait to check if T provides a general 'fold' method.
   template <typename T, typename... Args>
-  using has_fold_t = decltype(std::declval<T>().fold(
-      std::declval<ArrayRef<Attribute>>(),
-      std::declval<SmallVectorImpl<OpFoldResult> &>()));
+  using has_fold_t = decltype(
+      std::declval<T>().fold(std::declval<ArrayRef<Attribute>>(),
+                             std::declval<SmallVectorImpl<OpFoldResult> &>()));
   template <typename T>
   constexpr static bool has_fold_v = llvm::is_detected<has_fold_t, T>::value;
   /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a
@@ -1817,9 +1835,9 @@ class Op : public OpState, public Traits<ConcreteType>... {
       llvm::is_detected<has_fold_adaptor_single_result_fold_t, T>::value;
   /// Trait to check if T provides a general 'fold' method with a FoldAdaptor.
   template <typename T, typename... Args>
-  using has_fold_adaptor_fold_t = decltype(std::declval<T>().fold(
-      std::declval<typename T::FoldAdaptor>(),
-      std::declval<SmallVectorImpl<OpFoldResult> &>()));
+  using has_fold_adaptor_fold_t = decltype(
+      std::declval<T>().fold(std::declval<typename T::FoldAdaptor>(),
+                             std::declval<SmallVectorImpl<OpFoldResult> &>()));
   template <class T>
   constexpr static bool has_fold_adaptor_v =
       llvm::is_detected<has_fold_adaptor_fold_t, T>::value;
@@ -1833,16 +1851,16 @@ class Op : public OpState, public Traits<ConcreteType>... {
 
   /// Trait to check if printProperties(OpAsmPrinter, T) exist
   template <typename T, typename... Args>
-  using has_print_properties = decltype(printProperties(
-      std::declval<OpAsmPrinter &>(), std::declval<T>()));
+  using has_print_properties = decltype(
+      printProperties(std::declval<OpAsmPrinter &>(), std::declval<T>()));
   template <typename T>
   using detect_has_print_properties =
       llvm::is_detected<has_print_properties, T>;
 
   /// Trait to check if parseProperties(OpAsmParser, T) exist
   template <typename T, typename... Args>
-  using has_parse_properties = decltype(parseProperties(
-      std::declval<OpAsmParser &>(), std::declval<T &>()));
+  using has_parse_properties = decltype(
+      parseProperties(std::declval<OpAsmParser &>(), std::declval<T &>()));
   template <typename T>
   using detect_has_parse_properties =
       llvm::is_detected<has_parse_properties, T>;
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
index 0402ebe75875086..079ddff344dbb05 100644
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -668,3 +668,21 @@ func.func @failed_attr_traits() {
   "test.attr_with_trait"() {attr = 42 : i32} : () -> ()
   return
 }
+
+// -----
+
+func.func @has_interface_parent_trait() {
+  // CHECK: "test.interface_parent"() ({
+  // CHECK:    "test.interface_child"() : () -> ()
+  "test.interface_parent"() ({
+    "test.interface_child"() : () -> ()
+  }) : () -> ()
+  return
+}
+
+// -----
+
+func.func @illegal_interface_parent_trait() {
+  // expected-error at +1 {{'test.interface_child' op expects parent op to implement interface 'TestInterfaceParentInterface'}}
+  "test.interface_child"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td
index e2ed27bdf0203e3..32a438efaddf764 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.td
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td
@@ -147,4 +147,8 @@ def TestOptionallyImplementedOpInterface
   }];
 }
 
+def TestInterfaceParentInterface : OpInterface<"TestInterfaceParentInterface"> {
+    let cppNamespace = "::mlir";
+}
+
 #endif // MLIR_TEST_DIALECT_TEST_INTERFACES
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 9ceadab8fa4a086..99a309168892ab6 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -750,6 +750,14 @@ def ParentOp1 : TEST_Op<"parent1"> {
 def ChildWithParentOneOf : TEST_Op<"child_with_parent_one_of",
                                 [ParentOneOf<["ParentOp", "ParentOp1"]>]>;
 
+// HasInterfaceParent trait
+def InterfaceParentOp : TEST_Op<"interface_parent",
+                [SingleBlock, NoTerminator, TestInterfaceParentInterface]> {
+  let regions = (region SizedRegion<1>:$region);
+}
+def InterfaceChildOp : TEST_Op<"interface_child",
+                               [HasInterfaceParent<"mlir::TestInterfaceParentInterface">]>;
+
 def TerminatorOp : TEST_Op<"finish", [Terminator]>;
 def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",
     [SingleBlockImplicitTerminator<"TerminatorOp">]> {
diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td
index 534cac07084539e..fc6c793d08037c3 100644
--- a/mlir/test/mlir-tblgen/op-interface.td
+++ b/mlir/test/mlir-tblgen/op-interface.td
@@ -172,7 +172,7 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
 // DECL: class TestOpInterface : public ::mlir::OpInterface<TestOpInterface, detail::TestOpInterfaceInterfaceTraits>
 
 // DECL: /// Returns the name of this interface.
-// DECL: static ::llvm::StringLiteral getOperationName() { return ::llvm::StringLiteral( "TestOpInterface"); }
+// DECL: static ::llvm::StringLiteral getInterfaceName() { return ::llvm::StringLiteral( "TestOpInterface"); }
 
 // DECL: /// some function comment
 // DECL: int foo(int input);
diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 153543ab083525e..fab6e03b3e6a8c3 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -499,7 +499,7 @@ static void emitInterfaceNameGetter(const Interface &interface,
   if (!isa<OpInterface>(interface))
     return;
   os << "  /// Returns the name of this interface.\n"
-     << "  static ::llvm::StringLiteral getOperationName() { return "
+     << "  static ::llvm::StringLiteral getInterfaceName() { return "
         "::llvm::StringLiteral( \""
      << interface.getName() << "\"); }\n";
 }



More information about the Mlir-commits mailing list