[Mlir-commits] [mlir] [mlir][ODS] Consistent `cppType` / `cppClassName` usage (PR #102657)

Matthias Springer llvmlistbot at llvm.org
Fri Aug 9 11:27:06 PDT 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/102657

Make the usage of `cppType` and `cppClassName` in TableGen type and attribute definitions/constraints consistent.

- `cppClassName`: The unqualified C++ class name of the type or attribute (no namespace).
- `cppType`: The fully qualified C++ class name, i.e., the C++ namespace followed by `::` and the C++ class name.

This also includes some minor cleanups.

Fixes #57279.


>From 9968cd55e4f606728b90854f0ff6f0ec8b04a1ae Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 9 Aug 2024 20:21:44 +0200
Subject: [PATCH] [mlir][ODS] Consistent `cppType` / `cppClassName` usage

Make the usage of `cppType` and `cppClassName` in TableGen type and attribute definitions/constraints consistent.

- `cppClassName`: The unqualified C++ class name of the type or attribute (no namespace).
- `cppType`: The fully qualified C++ class name, i.e., the C++ namespace followed by `::` and the C++ class name.
---
 mlir/include/mlir/IR/AttrTypeBase.td          | 18 +++++-----
 mlir/include/mlir/IR/CommonAttrConstraints.td |  6 ++--
 mlir/include/mlir/IR/CommonTypeConstraints.td | 34 +++++++++----------
 mlir/include/mlir/IR/Constraints.td           |  4 +--
 mlir/include/mlir/TableGen/Type.h             |  4 +--
 mlir/lib/TableGen/Type.cpp                    | 17 ++--------
 mlir/lib/Tools/PDLL/Parser/Parser.cpp         |  5 ++-
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp   | 19 ++---------
 mlir/tools/mlir-tblgen/OpFormatGen.cpp        |  6 ++--
 9 files changed, 44 insertions(+), 69 deletions(-)

diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 91c9283de8bd41..d176b36068f7a5 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -256,7 +256,7 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
       AttrOrTypeDef<"Attr", name, traits, baseCppClass> {
   // The name of the C++ Attribute class.
   string cppClassName = name # "Attr";
-  let storageType = dialect.cppNamespace # "::" # name # "Attr";
+  let storageType = dialect.cppNamespace # "::" # cppClassName;
 
   // The underlying C++ value type
   let returnType = dialect.cppNamespace # "::" # cppClassName;
@@ -275,12 +275,10 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
   //
   // For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will
   // expand to `getAttrOfType<IntegerAttr>("val").getValue().getSExtValue()`.
-  let convertFromStorage = "::llvm::cast<" # dialect.cppNamespace #
-                                 "::" # cppClassName # ">($_self)";
+  let convertFromStorage = "::llvm::cast<" # cppType # ">($_self)";
 
   // The predicate for when this def is used as a constraint.
-  let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace #
-                                 "::" # cppClassName # ">($_self)">;
+  let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
 }
 
 // Define a new type, named `name`, belonging to `dialect` that inherits from
@@ -289,6 +287,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
               string baseCppClass = "::mlir::Type">
     : DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
       AttrOrTypeDef<"Type", name, traits, baseCppClass> {
+  // The name of the C++ Type class.
+  string cppClassName = name # "Type";
+
   // Make it possible to use such type as parameters for other types.
   string cppType = dialect.cppNamespace # "::" # cppClassName;
 
@@ -297,12 +298,11 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
 
   // A constant builder provided when the type has no parameters.
   let builderCall = !if(!empty(parameters),
-                           "$_builder.getType<" # dialect.cppNamespace #
-                               "::" # cppClassName # ">()",
+                           "$_builder.getType<" # cppType # ">()",
                            "");
+
   // The predicate for when this def is used as a constraint.
-  let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace #
-                                 "::" # cppClassName # ">($_self)">;
+  let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index d99bde1f87ef00..6774a7c568315d 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -169,14 +169,14 @@ def AnyAttr : Attr<CPred<"true">, "any attribute"> {
 
 // Any attribute from the given list
 class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
-                string cppClassName = "::mlir::Attribute",
+                string cppType = "::mlir::Attribute",
                 string fromStorage = "$_self"> : Attr<
     // Satisfy any of the allowed attribute's condition
     Or<!foreach(allowedattr, allowedAttrs, allowedattr.predicate)>,
     !if(!eq(summary, ""),
         !interleave(!foreach(t, allowedAttrs, t.summary), " or "),
         summary)> {
-    let returnType = cppClassName;
+    let returnType = cppType;
     let convertFromStorage = fromStorage;
 }
 
@@ -369,7 +369,7 @@ def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> {
 }
 
 class TypeAttrOf<Type ty>
-   : TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary,
+   : TypeAttrBase<ty.cppType, "type attribute of " # ty.summary,
                     ty.predicate> {
   let constBuilderCall = "::mlir::TypeAttr::get($0)";
 }
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 5b6ec167fa2420..65e1b1cbdc905a 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -98,23 +98,23 @@ def HasValueSemanticsPred : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">;
 
 // A type, carries type constraints.
 class Type<Pred condition, string descr = "",
-           string cppClassName = "::mlir::Type"> :
-    TypeConstraint<condition, descr, cppClassName> {
+           string cppType = "::mlir::Type"> :
+    TypeConstraint<condition, descr, cppType> {
   string description = "";
   string builderCall = "";
 }
 
 // Allows providing an alternative name and summary to an existing type def.
 class TypeAlias<Type t, string summary = t.summary> :
-    Type<t.predicate, summary, t.cppClassName> {
+    Type<t.predicate, summary, t.cppType> {
   let description = t.description;
   let builderCall = t.builderCall;
 }
 
 // A type of a specific dialect.
 class DialectType<Dialect d, Pred condition, string descr = "",
-                  string cppClassName = "::mlir::Type"> :
-    Type<condition, descr, cppClassName> {
+                  string cppType = "::mlir::Type"> :
+    Type<condition, descr, cppType> {
   Dialect dialect = d;
 }
 
@@ -122,7 +122,7 @@ class DialectType<Dialect d, Pred condition, string descr = "",
 // class is used for supporting variadic operands/results.
 class Variadic<Type type> : TypeConstraint<type.predicate,
                                            "variadic of " # type.summary,
-                                           type.cppClassName> {
+                                           type.cppType> {
   Type baseType = type;
   int minSize = 0;
 }
@@ -140,7 +140,7 @@ class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
 // An optional type constraint. It expands to either zero or one of the base
 // type. This class is used for supporting optional operands/results.
 class Optional<Type type> : TypeConstraint<type.predicate, type.summary,
-                                           type.cppClassName> {
+                                           type.cppType> {
   Type baseType = type;
 }
 
@@ -172,33 +172,33 @@ def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
 
 // Any type from the given list
 class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
-                string cppClassName = "::mlir::Type"> : Type<
+                string cppType = "::mlir::Type"> : Type<
     // Satisfy any of the allowed types' conditions.
     Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
     !if(!eq(summary, ""),
         !interleave(!foreach(t, allowedTypeList, t.summary), " or "),
         summary),
-    cppClassName> {
+    cppType> {
   list<Type> allowedTypes = allowedTypeList;
 }
 
 // A type that satisfies the constraints of all given types.
 class AllOfType<list<Type> allowedTypeList, string summary = "",
-                string cppClassName = "::mlir::Type"> : Type<
+                string cppType = "::mlir::Type"> : Type<
     // Satisfy all of the allowed types' conditions.
     And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
     !if(!eq(summary, ""),
         !interleave(!foreach(t, allowedTypeList, t.summary), " and "),
         summary),
-    cppClassName> {
+    cppType> {
   list<Type> allowedTypes = allowedTypeList;
 }
 
 // A type that satisfies additional predicates.
 class ConfinedType<Type type, list<Pred> predicates, string summary = "",
-                   string cppClassName = type.cppClassName> : Type<
+                   string cppType = type.cppType> : Type<
     And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
-    summary, cppClassName>;
+    summary, cppType>;
 
 // Integer types.
 
@@ -375,23 +375,23 @@ def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
 
 // A container type is a type that has another type embedded within it.
 class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
-                    string descr, string cppClassName = "::mlir::Type"> :
+                    string descr, string cppType = "::mlir::Type"> :
     // First, check the container predicate.  Then, substitute the extracted
     // element into the element type checker.
     Type<And<[containerPred,
                 SubstLeaves<"$_self", !cast<string>(elementTypeCall),
                 etype.predicate>]>,
-         descr # " of " # etype.summary # " values", cppClassName>;
+         descr # " of " # etype.summary # " values", cppType>;
 
 class ShapedContainerType<list<Type> allowedTypes,
                           Pred containerPred, string descr,
-                          string cppClassName = "::mlir::Type"> :
+                          string cppType = "::mlir::Type"> :
     Type<And<[containerPred,
               Concat<"[](::mlir::Type elementType) { return ",
                 SubstLeaves<"$_self", "elementType",
                 AnyTypeOf<allowedTypes>.predicate>,
                 "; }(::llvm::cast<::mlir::ShapedType>($_self).getElementType())">]>,
-         descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppClassName>;
+         descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppType>;
 
 // Whether a shaped type is ranked.
 def HasRankPred : CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasRank()">;
diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td
index a026d58ccffb8e..39bc55db63da1a 100644
--- a/mlir/include/mlir/IR/Constraints.td
+++ b/mlir/include/mlir/IR/Constraints.td
@@ -149,10 +149,10 @@ class Constraint<Pred pred, string desc = ""> {
 
 // Subclass for constraints on a type.
 class TypeConstraint<Pred predicate, string summary = "",
-                     string cppClassNameParam = "::mlir::Type"> :
+                     string cppTypeParam = "::mlir::Type"> :
     Constraint<predicate, summary> {
   // The name of the C++ Type class if known, or Type if not.
-  string cppClassName = cppClassNameParam;
+  string cppType = cppTypeParam;
 }
 
 // Subclass for constraints on an attribute.
diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h
index 06cf4f5730d565..b974ac281041bc 100644
--- a/mlir/include/mlir/TableGen/Type.h
+++ b/mlir/include/mlir/TableGen/Type.h
@@ -56,8 +56,8 @@ class TypeConstraint : public Constraint {
   // returns std::nullopt otherwise.
   std::optional<StringRef> getBuilderCall() const;
 
-  // Return the C++ class name for this type (which may just be ::mlir::Type).
-  std::string getCPPClassName() const;
+  // Return the C++ type for this type (which may just be ::mlir::Type).
+  StringRef getCppType() const;
 };
 
 // Wrapper class with helper methods for accessing Types defined in TableGen.
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index e9f2394dd540af..cda752297988bb 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -59,20 +59,9 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
       .Default([](auto *) { return std::nullopt; });
 }
 
-// Return the C++ class name for this type (which may just be ::mlir::Type).
-std::string TypeConstraint::getCPPClassName() const {
-  StringRef className = def->getValueAsString("cppClassName");
-
-  // If the class name is already namespace resolved, use it.
-  if (className.contains("::"))
-    return className.str();
-
-  // Otherwise, check to see if there is a namespace from a dialect to prepend.
-  if (const llvm::RecordVal *value = def->getValue("dialect")) {
-    Dialect dialect(cast<const llvm::DefInit>(value->getValue())->getDef());
-    return (dialect.getCppNamespace() + "::" + className).str();
-  }
-  return className.str();
+// Return the C++ type for this type (which may just be ::mlir::Type).
+StringRef TypeConstraint::getCppType() const {
+  return def->getValueAsString("cppType");
 }
 
 Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 1f0df033d43398..01c78e280080ee 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -879,8 +879,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
       -> const ods::TypeConstraint & {
     return odsContext.insertTypeConstraint(
         cst.constraint.getUniqueDefName(),
-        processDoc(cst.constraint.getSummary()),
-        cst.constraint.getCPPClassName());
+        processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
   };
   auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
     return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
@@ -944,7 +943,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
     tblgen::TypeConstraint constraint(def);
     decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
         constraint, convertLocToRange(def->getLoc().front()), typeTy,
-        constraint.getCPPClassName()));
+        constraint.getCppType()));
   }
   /// OpInterfaces.
   ast::Type opTy = ast::OperationType::get(ctx);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index a2ceefb34db453..66dbb16760ebb0 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2085,21 +2085,8 @@ static void generateValueRangeStartAndEnd(
 }
 
 static std::string generateTypeForGetter(const NamedTypeConstraint &value) {
-  std::string str = "::mlir::Value";
-  /// If the CPPClassName is not a fully qualified type. Uses of types
-  /// across Dialect fail because they are not in the correct namespace. So we
-  /// dont generate TypedValue unless the type is fully qualified.
-  /// getCPPClassName doesn't return the fully qualified path for
-  /// `mlir::pdl::OperationType` see
-  /// https://github.com/llvm/llvm-project/issues/57279.
-  /// Adaptor will have values that are not from the type of their operation and
-  /// this is expected, so we dont generate TypedValue for Adaptor
-  if (value.constraint.getCPPClassName() != "::mlir::Type" &&
-      StringRef(value.constraint.getCPPClassName()).starts_with("::"))
-    str = llvm::formatv("::mlir::TypedValue<{0}>",
-                        value.constraint.getCPPClassName())
-              .str();
-  return str;
+  return llvm::formatv("::mlir::TypedValue<{0}>", value.constraint.getCppType())
+      .str();
 }
 
 // Generates the named operand getter methods for the given Operator `op` and
@@ -3944,7 +3931,7 @@ void OpEmitter::genTraits() {
   // For single result ops with a known specific type, generate a OneTypedResult
   // trait.
   if (numResults == 1 && numVariadicResults == 0) {
-    auto cppName = op.getResults().begin()->constraint.getCPPClassName();
+    auto cppName = op.getResults().begin()->constraint.getCppType();
     opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
   }
 
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 27ad79a5c1efed..9a95f495b77658 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1657,7 +1657,7 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
       TypeSwitch<FormatElement *>(dir->getArg())
           .Case<OperandVariable, ResultVariable>([&](auto operand) {
             body << formatv(parserCode,
-                            operand->getVar()->constraint.getCPPClassName(),
+                            operand->getVar()->constraint.getCppType(),
                             listName);
           })
           .Default([&](auto operand) {
@@ -2603,7 +2603,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
     }
     if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
         !var->isOptional()) {
-      std::string cppClass = var->constraint.getCPPClassName();
+      StringRef cppType = var->constraint.getCppType();
       if (dir->shouldBeQualified()) {
         body << "   _odsPrinter << " << op.getGetterName(var->name)
              << "().getType();\n";
@@ -2612,7 +2612,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
       body << "  {\n"
            << "    auto type = " << op.getGetterName(var->name)
            << "().getType();\n"
-           << "    if (auto validType = ::llvm::dyn_cast<" << cppClass
+           << "    if (auto validType = ::llvm::dyn_cast<" << cppType
            << ">(type))\n"
            << "      _odsPrinter.printStrippedAttrOrType(validType);\n"
            << "   else\n"



More information about the Mlir-commits mailing list