[Mlir-commits] [mlir] 5cdc2bb - [mlir] Move SymbolOpInterfaces "classof" check to a proper "extraClassOf" interface field

River Riddle llvmlistbot at llvm.org
Wed Jan 18 19:16:48 PST 2023


Author: River Riddle
Date: 2023-01-18T19:16:30-08:00
New Revision: 5cdc2bbc7588e7b046ac5c7f79a84ef18978a83a

URL: https://github.com/llvm/llvm-project/commit/5cdc2bbc7588e7b046ac5c7f79a84ef18978a83a
DIFF: https://github.com/llvm/llvm-project/commit/5cdc2bbc7588e7b046ac5c7f79a84ef18978a83a.diff

LOG: [mlir] Move SymbolOpInterfaces "classof" check to a proper "extraClassOf" interface field

SymbolOpInterface overrides the base classof to provide support
for optionally implementing the interface. This is currently placed
in the extraClassDeclarations, but that is kind of awkward given that
it requires underlying knowledge of how the base classof is implemented.
This commit adds a proper "extraClassOf" field to allow interfaces to
implement this, which abstracts away the default classof logic.

Differential Revision: https://reviews.llvm.org/D140197

Added: 
    

Modified: 
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/IR/SymbolInterfaces.td
    mlir/include/mlir/Support/InterfaceSupport.h
    mlir/include/mlir/TableGen/Format.h
    mlir/include/mlir/TableGen/Interfaces.h
    mlir/lib/TableGen/CodeGenHelpers.cpp
    mlir/lib/TableGen/Format.cpp
    mlir/lib/TableGen/Interfaces.cpp
    mlir/test/mlir-tblgen/op-interface.td
    mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
    mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
    mlir/unittests/TableGen/FormatTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 00b70ab3de6a3..ea9976ca2e24f 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2048,6 +2048,13 @@ class Interface<string name> {
   // An optional code block containing extra declarations to place in both
   // the interface and trait declaration.
   code extraSharedClassDeclaration = "";
+
+  // An optional code block for adding additional "classof" logic. This can
+  // be used to better enable "optional" interfaces, where an entity only
+  // implements the interface if some dynamic characteristic holds.
+  // `$_attr`/`$_op`/`$_type` may be used to refer to an instance of the
+  // entity being checked.
+  code extraClassOf = "";
 }
 
 // AttrInterface represents an interface registered to an attribute.

diff  --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td
index 50737746f3d6d..a3a6833a475b8 100644
--- a/mlir/include/mlir/IR/SymbolInterfaces.td
+++ b/mlir/include/mlir/IR/SymbolInterfaces.td
@@ -174,28 +174,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
     return success();
   }];
 
-  let extraClassDeclaration = [{
-    /// Convenience version of `getNameAttr` that returns a StringRef.
-    StringRef getName() {
-      return getNameAttr().getValue();
-    }
-
-    /// Convenience version of `setName` that take a StringRef.
-    void setName(StringRef name) {
-      setName(StringAttr::get(this->getContext(), name));
-    }
-
-    /// Custom classof that handles the case where the symbol is optional.
-    static bool classof(Operation *op) {
-      auto *opConcept = getInterfaceFor(op);
-      if (!opConcept)
-        return false;
-      return !opConcept->isOptionalSymbol(opConcept, op) ||
-             op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
-    }
-  }];
-
-  let extraTraitClassDeclaration = [{
+  let extraSharedClassDeclaration = [{
     using Visibility = mlir::SymbolTable::Visibility;
 
     /// Convenience version of `getNameAttr` that returns a StringRef.
@@ -208,6 +187,11 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
       setName(StringAttr::get($_op->getContext(), name));
     }
   }];
+
+  // Add additional classof checks to properly handle "optional" symbols.
+  let extraClassOf = [{
+    return $_op->hasAttr(::mlir::SymbolTable::getSymbolAttrName());
+  }];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Support/InterfaceSupport.h b/mlir/include/mlir/Support/InterfaceSupport.h
index d8f63e08ea1dc..6ba7b33189843 100644
--- a/mlir/include/mlir/Support/InterfaceSupport.h
+++ b/mlir/include/mlir/Support/InterfaceSupport.h
@@ -110,6 +110,12 @@ class Interface : public BaseType {
            "expected value to provide interface instance");
   }
 
+  /// Constructor for a known concept.
+  Interface(ValueT t, Concept *conceptImpl)
+      : BaseType(t), conceptImpl(conceptImpl) {
+    assert(!t || ConcreteType::getInterfaceFor(t) == conceptImpl);
+  }
+
   /// Constructor for DenseMapInfo's empty key and tombstone key.
   Interface(ValueT t, std::nullptr_t) : BaseType(t), conceptImpl(nullptr) {}
 

diff  --git a/mlir/include/mlir/TableGen/Format.h b/mlir/include/mlir/TableGen/Format.h
index 60d5887ffcbb1..79d3d26a9d68d 100644
--- a/mlir/include/mlir/TableGen/Format.h
+++ b/mlir/include/mlir/TableGen/Format.h
@@ -44,7 +44,6 @@ class FmtContext {
     None,
     Custom,  // For custom placeholders
     Builder, // For the $_builder placeholder
-    Op,      // For the $_op placeholder
     Self,    // For the $_self placeholder
   };
 
@@ -58,7 +57,6 @@ class FmtContext {
 
   // Setters for builtin placeholders
   FmtContext &withBuilder(Twine subst);
-  FmtContext &withOp(Twine subst);
   FmtContext &withSelf(Twine subst);
 
   std::optional<StringRef> getSubstFor(PHKind placeholder) const;

diff  --git a/mlir/include/mlir/TableGen/Interfaces.h b/mlir/include/mlir/TableGen/Interfaces.h
index aeef36087fb8f..7168f130c73a7 100644
--- a/mlir/include/mlir/TableGen/Interfaces.h
+++ b/mlir/include/mlir/TableGen/Interfaces.h
@@ -95,6 +95,9 @@ class Interface {
   // trait classes.
   std::optional<StringRef> getExtraSharedClassDeclaration() const;
 
+  // Return the extra classof method code.
+  std::optional<StringRef> getExtraClassOf() const;
+
   // Return the verify method body if it has one.
   std::optional<StringRef> getVerify() const;
 

diff  --git a/mlir/lib/TableGen/CodeGenHelpers.cpp b/mlir/lib/TableGen/CodeGenHelpers.cpp
index 5caefb416ea2f..193e8c1ce374b 100644
--- a/mlir/lib/TableGen/CodeGenHelpers.cpp
+++ b/mlir/lib/TableGen/CodeGenHelpers.cpp
@@ -190,7 +190,7 @@ void StaticVerifierFunctionEmitter::emitConstraints(
     const ConstraintMap &constraints, StringRef selfName,
     const char *const codeTemplate) {
   FmtContext ctx;
-  ctx.withOp("*op").withSelf(selfName);
+  ctx.addSubst("_op", "*op").withSelf(selfName);
   for (auto &it : constraints) {
     os << formatv(codeTemplate, it.second,
                   tgfmt(it.first.getConditionTemplate(), &ctx),
@@ -216,7 +216,7 @@ void StaticVerifierFunctionEmitter::emitRegionConstraints() {
 
 void StaticVerifierFunctionEmitter::emitPatternConstraints() {
   FmtContext ctx;
-  ctx.withOp("*op").withBuilder("rewriter").withSelf("type");
+  ctx.addSubst("_op", "*op").withBuilder("rewriter").withSelf("type");
   for (auto &it : typeConstraints) {
     os << formatv(patternAttrOrTypeConstraintCode, it.second,
                   tgfmt(it.first.getConditionTemplate(), &ctx),
@@ -240,9 +240,9 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() {
 /// because ops use cached identifiers.
 static bool canUniqueAttrConstraint(Attribute attr) {
   FmtContext ctx;
-  auto test =
-      tgfmt(attr.getConditionTemplate(), &ctx.withSelf("attr").withOp("*op"))
-          .str();
+  auto test = tgfmt(attr.getConditionTemplate(),
+                    &ctx.withSelf("attr").addSubst("_op", "*op"))
+                  .str();
   return !StringRef(test).contains("<no-subst-found>");
 }
 

diff  --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp
index 25952157d9b25..03f888b139f8e 100644
--- a/mlir/lib/TableGen/Format.cpp
+++ b/mlir/lib/TableGen/Format.cpp
@@ -38,11 +38,6 @@ FmtContext &FmtContext::withBuilder(Twine subst) {
   return *this;
 }
 
-FmtContext &FmtContext::withOp(Twine subst) {
-  builtinSubstMap[PHKind::Op] = subst.str();
-  return *this;
-}
-
 FmtContext &FmtContext::withSelf(Twine subst) {
   builtinSubstMap[PHKind::Self] = subst.str();
   return *this;
@@ -69,7 +64,6 @@ std::optional<StringRef> FmtContext::getSubstFor(StringRef placeholder) const {
 FmtContext::PHKind FmtContext::getPlaceHolderKind(StringRef str) {
   return StringSwitch<FmtContext::PHKind>(str)
       .Case("_builder", FmtContext::PHKind::Builder)
-      .Case("_op", FmtContext::PHKind::Op)
       .Case("_self", FmtContext::PHKind::Self)
       .Case("", FmtContext::PHKind::None)
       .Default(FmtContext::PHKind::Custom);

diff  --git a/mlir/lib/TableGen/Interfaces.cpp b/mlir/lib/TableGen/Interfaces.cpp
index 4ddfa2f359328..bd56f6b027007 100644
--- a/mlir/lib/TableGen/Interfaces.cpp
+++ b/mlir/lib/TableGen/Interfaces.cpp
@@ -116,6 +116,11 @@ std::optional<StringRef> Interface::getExtraSharedClassDeclaration() const {
   return value.empty() ? std::optional<StringRef>() : value;
 }
 
+std::optional<StringRef> Interface::getExtraClassOf() const {
+  auto value = def->getValueAsString("extraClassOf");
+  return value.empty() ? std::optional<StringRef>() : value;
+}
+
 // Return the body for this method if it has one.
 std::optional<StringRef> Interface::getVerify() const {
   // Only OpInterface supports the verify method.

diff  --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td
index ab041982f0276..8129eb1a8b3bb 100644
--- a/mlir/test/mlir-tblgen/op-interface.td
+++ b/mlir/test/mlir-tblgen/op-interface.td
@@ -4,6 +4,17 @@
 
 include "mlir/IR/OpBase.td"
 
+def ExtraClassOfInterface : OpInterface<"ExtraClassOfInterface"> {
+  let extraClassOf = "return $_op->someOtherMethod();";
+}
+
+// DECL: class ExtraClassOfInterface
+// DECL:   static bool classof(::mlir::Operation * base) {
+// DECL-NEXT:     if (!getInterfaceFor(base))
+// DECL-NEXT:       return false;
+// DECL-NEXT:     return base->someOtherMethod();
+// DECL-NEXT:   }
+
 def ExtraShardDeclsInterface : OpInterface<"ExtraShardDeclsInterface"> {
   let extraSharedClassDeclaration = [{
     bool sharedMethodDeclaration() {

diff  --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 83937f46e6649..7ed29f91d3d64 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -819,7 +819,7 @@ OpEmitter::OpEmitter(const Operator &op,
               formatExtraDefinitions(op)),
       staticVerifierEmitter(staticVerifierEmitter),
       emitHelper(op, /*emitForOp=*/true) {
-  verifyCtx.withOp("(*this->getOperation())");
+  verifyCtx.addSubst("_op", "(*this->getOperation())");
   verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");
 
   genTraits();

diff  --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
index 9e84d1985ff5f..363bec72649dd 100644
--- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -108,6 +108,8 @@ class InterfaceGenerator {
   StringRef interfaceBaseType;
   /// The name of the typename for the value template.
   StringRef valueTemplate;
+  /// The name of the substitution variable for the value.
+  StringRef substVar;
   /// The format context to use for methods.
   tblgen::FmtContext nonStaticMethodFmt;
   tblgen::FmtContext traitMethodFmt;
@@ -121,11 +123,12 @@ struct AttrInterfaceGenerator : public InterfaceGenerator {
     valueType = "::mlir::Attribute";
     interfaceBaseType = "AttributeInterface";
     valueTemplate = "ConcreteAttr";
+    substVar = "_attr";
     StringRef castCode = "(tablegen_opaque_val.cast<ConcreteAttr>())";
-    nonStaticMethodFmt.addSubst("_attr", castCode).withSelf(castCode);
-    traitMethodFmt.addSubst("_attr",
+    nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
+    traitMethodFmt.addSubst(substVar,
                             "(*static_cast<const ConcreteAttr *>(this))");
-    extraDeclsFmt.addSubst("_attr", "(*this)");
+    extraDeclsFmt.addSubst(substVar, "(*this)");
   }
 };
 /// A specialized generator for operation interfaces.
@@ -135,12 +138,13 @@ struct OpInterfaceGenerator : public InterfaceGenerator {
     valueType = "::mlir::Operation *";
     interfaceBaseType = "OpInterface";
     valueTemplate = "ConcreteOp";
+    substVar = "_op";
     StringRef castCode = "(llvm::cast<ConcreteOp>(tablegen_opaque_val))";
     nonStaticMethodFmt.addSubst("_this", "impl")
-        .withOp(castCode)
+        .addSubst(substVar, castCode)
         .withSelf(castCode);
-    traitMethodFmt.withOp("(*static_cast<ConcreteOp *>(this))");
-    extraDeclsFmt.withOp("(*this)");
+    traitMethodFmt.addSubst(substVar, "(*static_cast<ConcreteOp *>(this))");
+    extraDeclsFmt.addSubst(substVar, "(*this)");
   }
 };
 /// A specialized generator for type interfaces.
@@ -150,11 +154,12 @@ struct TypeInterfaceGenerator : public InterfaceGenerator {
     valueType = "::mlir::Type";
     interfaceBaseType = "TypeInterface";
     valueTemplate = "ConcreteType";
+    substVar = "_type";
     StringRef castCode = "(tablegen_opaque_val.cast<ConcreteType>())";
-    nonStaticMethodFmt.addSubst("_type", castCode).withSelf(castCode);
-    traitMethodFmt.addSubst("_type",
+    nonStaticMethodFmt.addSubst(substVar, castCode).withSelf(castCode);
+    traitMethodFmt.addSubst(substVar,
                             "(*static_cast<const ConcreteType *>(this))");
-    extraDeclsFmt.addSubst("_type", "(*this)");
+    extraDeclsFmt.addSubst(substVar, "(*this)");
   }
 };
 } // namespace
@@ -434,7 +439,7 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
     assert(isa<OpInterface>(interface) && "only OpInterface supports 'verify'");
 
     tblgen::FmtContext verifyCtx;
-    verifyCtx.withOp("op");
+    verifyCtx.addSubst("_op", "op");
     os << llvm::formatv(
               "    static ::mlir::LogicalResult {0}(::mlir::Operation *op) ",
               (interface.verifyWithRegions() ? "verifyRegionTrait"
@@ -506,6 +511,17 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
           interface.getExtraSharedClassDeclaration())
     os << tblgen::tgfmt(*extraDecls, &extraDeclsFmt);
 
+  // Emit classof code if necessary.
+  if (std::optional<StringRef> extraClassOf = interface.getExtraClassOf()) {
+    auto extraClassOfFmt = tblgen::FmtContext();
+    extraClassOfFmt.addSubst(substVar, "base");
+    os << "  static bool classof(" << valueType << " base) {\n"
+       << "    if (!getInterfaceFor(base))\n"
+          "      return false;\n"
+       << "    " << tblgen::tgfmt(extraClassOf->trim(), &extraClassOfFmt)
+       << "\n  }\n";
+  }
+
   os << "};\n";
 
   os << "namespace detail {\n";

diff  --git a/mlir/unittests/TableGen/FormatTest.cpp b/mlir/unittests/TableGen/FormatTest.cpp
index 0cae408bc3fb8..ef00cb4a1c24b 100644
--- a/mlir/unittests/TableGen/FormatTest.cpp
+++ b/mlir/unittests/TableGen/FormatTest.cpp
@@ -105,12 +105,6 @@ TEST(FormatTest, PlaceHolderFmtStrWithBuilder) {
   EXPECT_THAT(result, StrEq("bbb"));
 }
 
-TEST(FormatTest, PlaceHolderFmtStrWithOp) {
-  FmtContext ctx;
-  std::string result = std::string(tgfmt("$_op", &ctx.withOp("ooo")));
-  EXPECT_THAT(result, StrEq("ooo"));
-}
-
 TEST(FormatTest, PlaceHolderMissingCtx) {
   std::string result = std::string(tgfmt("$_op", nullptr));
   EXPECT_THAT(result, StrEq("$_op<no-subst-found>"));


        


More information about the Mlir-commits mailing list