[Mlir-commits] [mlir] edae8f6 - [mlir] Make `classof` substitution in interface use an instance (#65492)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 7 00:47:01 PDT 2023


Author: Markus Böck
Date: 2023-09-07T09:46:57+02:00
New Revision: edae8f6ce29a980d83761f59f81b88167a0fd815

URL: https://github.com/llvm/llvm-project/commit/edae8f6ce29a980d83761f59f81b88167a0fd815
DIFF: https://github.com/llvm/llvm-project/commit/edae8f6ce29a980d83761f59f81b88167a0fd815.diff

LOG: [mlir] Make `classof` substitution in interface use an instance (#65492)

The substitution supported by `extraClassOf` is currently limited to
the base instance, i.e. `Operation*`, `Type` or `Attribute`, which
limits the kinds of checks that can be performed in the `classof`
implementation.

Since the interface concept is already fetched before the user code
runs, we can use it to construct an instance of the interface, allowing
its methods to be used in the `classof` check.

Since an instance of the interface provides access to the base class
methods through the `->` operator, I've gone ahead and replaced the
substitution of `$_op/$_type/$_attr` with an interface instance. This is
also consistent with `extraSharedClassDeclaration` and other methods
generated in the interface class, which perform the same substitution.

Added: 
    

Modified: 
    mlir/include/mlir/IR/Interfaces.td
    mlir/test/lib/Dialect/Test/TestInterfaces.td
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-interface.td
    mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
    mlir/unittests/IR/InterfaceTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Interfaces.td b/mlir/include/mlir/IR/Interfaces.td
index 4b95ee9de4c270..0cbe3fa25c9e70 100644
--- a/mlir/include/mlir/IR/Interfaces.td
+++ b/mlir/include/mlir/IR/Interfaces.td
@@ -114,7 +114,7 @@ class Interface<string name, list<Interface> baseInterfacesArg = []> {
   // be used to better enable "optional" interfaces, where an entity only
   // implements the interface if some dynamic characteristic holds.
   // `$_attr`/`$_op`/`$_type` may be used to refer to an instance of the
-  // entity being checked.
+  // interface instance being checked.
   code extraClassOf = "";
 
   // An optional set of base interfaces that this interface

diff  --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td
index d7ca7089f8f988..e2ed27bdf0203e 100644
--- a/mlir/test/lib/Dialect/Test/TestInterfaces.td
+++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td
@@ -134,4 +134,17 @@ class TestEffects<list<TestEffect> effects = []>
 
 def TestConcreteEffect : TestEffect<"TestEffects::Concrete">;
 
+def TestOptionallyImplementedOpInterface
+    : OpInterface<"TestOptionallyImplementedOpInterface"> {
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<"", "bool", "getImplementsInterface", (ins)>,
+  ];
+
+  let extraClassOf = [{
+    return $_op.getImplementsInterface();
+  }];
+}
+
 #endif // MLIR_TEST_DIALECT_TEST_INTERFACES

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 3c9f230f60c56f..2d7f5b0043ba0f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2952,4 +2952,10 @@ def TestStoreWithARegionTerminator : TEST_Op<"store_with_a_region_terminator",
   let assemblyFormat = "attr-dict";
 }
 
+def TestOpOptionallyImplementingInterface
+    : TEST_Op<"op_optionally_implementing_interface",
+        [TestOptionallyImplementedOpInterface]> {
+  let arguments = (ins BoolAttr:$implementsInterface);
+}
+
 #endif // TEST_OPS

diff  --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td
index feffe2097dbdaa..6ca9f15bd02209 100644
--- a/mlir/test/mlir-tblgen/op-interface.td
+++ b/mlir/test/mlir-tblgen/op-interface.td
@@ -11,9 +11,11 @@ def ExtraClassOfInterface : OpInterface<"ExtraClassOfInterface"> {
 
 // DECL: class ExtraClassOfInterface
 // DECL:   static bool classof(::mlir::Operation * base) {
-// DECL-NEXT:     if (!getInterfaceFor(base))
+// DECL-NEXT:     auto* concept = getInterfaceFor(base);
+// DECL-NEXT:     if (!concept)
 // DECL-NEXT:       return false;
-// DECL-NEXT:     return base->someOtherMethod();
+// DECL-NEXT:     ExtraClassOfInterface odsInterfaceInstance(base, concept);
+// DECL-NEXT:     return odsInterfaceInstance->someOtherMethod();
 // DECL-NEXT:   }
 
 def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> {

diff  --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 7b83a1eb69753e..bdc8482ce5d272 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -582,10 +582,12 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
   // Emit classof code if necessary.
   if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
     auto extraClassOfFmt = tblgen::FmtContext();
-    extraClassOfFmt.addSubst(substVar, "base");
+    extraClassOfFmt.addSubst(substVar, "odsInterfaceInstance");
     os << "  static bool classof(" << valueType << " base) {\n"
-       << "    if (!getInterfaceFor(base))\n"
+       << "    auto* concept = getInterfaceFor(base);\n"
+       << "    if (!concept)\n"
           "      return false;\n"
+          "    " << interfaceName << " odsInterfaceInstance(base, concept);\n"
        << "    " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
        << "\n  }\n";
   }

diff  --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp
index 9b20d3c1219e63..2be9e70dd59e8f 100644
--- a/mlir/unittests/IR/InterfaceTest.cpp
+++ b/mlir/unittests/IR/InterfaceTest.cpp
@@ -57,3 +57,15 @@ TEST(InterfaceTest, TypeInterfaceDenseMapKey) {
   EXPECT_TRUE(typeSet.contains(type2));
   EXPECT_FALSE(typeSet.contains(type3));
 }
+
+TEST(InterfaceTest, TestCustomClassOf) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect>();
+
+  OpBuilder builder(&context);
+  auto op = builder.create<TestOpOptionallyImplementingInterface>(
+      builder.getUnknownLoc(), /*implementsInterface=*/true);
+  EXPECT_TRUE(isa<TestOptionallyImplementedOpInterface>(*op));
+  op.setImplementsInterface(false);
+  EXPECT_FALSE(isa<TestOptionallyImplementedOpInterface>(*op));
+}


        


More information about the Mlir-commits mailing list