[Mlir-commits] [mlir] [mlir][ODS] Consistent `cppType` / `cppClassName` usage (PR #102657)
Matthias Springer
llvmlistbot at llvm.org
Fri Aug 9 11:27:06 PDT 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/102657
Make sure that `cppType` and `cppClassName` are used consistently across type and attribute definitions/constraints in TableGen:
- `cppClassName`: The unqualified C++ class name of the type or attribute.
- `cppType`: The fully qualified C++ class name, i.e., the C++ namespace followed by `::` and the C++ class name.
This change also includes some minor cleanups.
Fixes #57279.
>From 9968cd55e4f606728b90854f0ff6f0ec8b04a1ae Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 9 Aug 2024 20:21:44 +0200
Subject: [PATCH] [mlir][ODS] Consistent `cppType` / `cppClassName` usage
Make sure that `cppType` and `cppClassName` are used consistently across type and attribute definitions/constraints in TableGen:
- `cppClassName`: The unqualified C++ class name of the type or attribute.
- `cppType`: The fully qualified C++ class name, i.e., the C++ namespace followed by `::` and the C++ class name.
---
mlir/include/mlir/IR/AttrTypeBase.td | 18 +++++-----
mlir/include/mlir/IR/CommonAttrConstraints.td | 6 ++--
mlir/include/mlir/IR/CommonTypeConstraints.td | 34 +++++++++----------
mlir/include/mlir/IR/Constraints.td | 4 +--
mlir/include/mlir/TableGen/Type.h | 4 +--
mlir/lib/TableGen/Type.cpp | 17 ++--------
mlir/lib/Tools/PDLL/Parser/Parser.cpp | 5 ++-
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 19 ++---------
mlir/tools/mlir-tblgen/OpFormatGen.cpp | 6 ++--
9 files changed, 44 insertions(+), 69 deletions(-)
diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 91c9283de8bd41..d176b36068f7a5 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -256,7 +256,7 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
AttrOrTypeDef<"Attr", name, traits, baseCppClass> {
// The name of the C++ Attribute class.
string cppClassName = name # "Attr";
- let storageType = dialect.cppNamespace # "::" # name # "Attr";
+ let storageType = dialect.cppNamespace # "::" # cppClassName;
// The underlying C++ value type
let returnType = dialect.cppNamespace # "::" # cppClassName;
@@ -275,12 +275,10 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
//
// For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will
// expand to `getAttrOfType<IntegerAttr>("val").getValue().getSExtValue()`.
- let convertFromStorage = "::llvm::cast<" # dialect.cppNamespace #
- "::" # cppClassName # ">($_self)";
+ let convertFromStorage = "::llvm::cast<" # cppType # ">($_self)";
// The predicate for when this def is used as a constraint.
- let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace #
- "::" # cppClassName # ">($_self)">;
+ let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
}
// Define a new type, named `name`, belonging to `dialect` that inherits from
@@ -289,6 +287,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
AttrOrTypeDef<"Type", name, traits, baseCppClass> {
+ // The name of the C++ Type class.
+ string cppClassName = name # "Type";
+
// Make it possible to use such type as parameters for other types.
string cppType = dialect.cppNamespace # "::" # cppClassName;
@@ -297,12 +298,11 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
- "$_builder.getType<" # dialect.cppNamespace #
- "::" # cppClassName # ">()",
+ "$_builder.getType<" # cppType # ">()",
"");
+
// The predicate for when this def is used as a constraint.
- let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace #
- "::" # cppClassName # ">($_self)">;
+ let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index d99bde1f87ef00..6774a7c568315d 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -169,14 +169,14 @@ def AnyAttr : Attr<CPred<"true">, "any attribute"> {
// Any attribute from the given list
class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
- string cppClassName = "::mlir::Attribute",
+ string cppType = "::mlir::Attribute",
string fromStorage = "$_self"> : Attr<
// Satisfy any of the allowed attribute's condition
Or<!foreach(allowedattr, allowedAttrs, allowedattr.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedAttrs, t.summary), " or "),
summary)> {
- let returnType = cppClassName;
+ let returnType = cppType;
let convertFromStorage = fromStorage;
}
@@ -369,7 +369,7 @@ def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> {
}
class TypeAttrOf<Type ty>
- : TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary,
+ : TypeAttrBase<ty.cppType, "type attribute of " # ty.summary,
ty.predicate> {
let constBuilderCall = "::mlir::TypeAttr::get($0)";
}
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 5b6ec167fa2420..65e1b1cbdc905a 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -98,23 +98,23 @@ def HasValueSemanticsPred : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">;
// A type, carries type constraints.
class Type<Pred condition, string descr = "",
- string cppClassName = "::mlir::Type"> :
- TypeConstraint<condition, descr, cppClassName> {
+ string cppType = "::mlir::Type"> :
+ TypeConstraint<condition, descr, cppType> {
string description = "";
string builderCall = "";
}
// Allows providing an alternative name and summary to an existing type def.
class TypeAlias<Type t, string summary = t.summary> :
- Type<t.predicate, summary, t.cppClassName> {
+ Type<t.predicate, summary, t.cppType> {
let description = t.description;
let builderCall = t.builderCall;
}
// A type of a specific dialect.
class DialectType<Dialect d, Pred condition, string descr = "",
- string cppClassName = "::mlir::Type"> :
- Type<condition, descr, cppClassName> {
+ string cppType = "::mlir::Type"> :
+ Type<condition, descr, cppType> {
Dialect dialect = d;
}
@@ -122,7 +122,7 @@ class DialectType<Dialect d, Pred condition, string descr = "",
// class is used for supporting variadic operands/results.
class Variadic<Type type> : TypeConstraint<type.predicate,
"variadic of " # type.summary,
- type.cppClassName> {
+ type.cppType> {
Type baseType = type;
int minSize = 0;
}
@@ -140,7 +140,7 @@ class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
// An optional type constraint. It expands to either zero or one of the base
// type. This class is used for supporting optional operands/results.
class Optional<Type type> : TypeConstraint<type.predicate, type.summary,
- type.cppClassName> {
+ type.cppType> {
Type baseType = type;
}
@@ -172,33 +172,33 @@ def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
// Any type from the given list
class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
- string cppClassName = "::mlir::Type"> : Type<
+ string cppType = "::mlir::Type"> : Type<
// Satisfy any of the allowed types' conditions.
Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypeList, t.summary), " or "),
summary),
- cppClassName> {
+ cppType> {
list<Type> allowedTypes = allowedTypeList;
}
// A type that satisfies the constraints of all given types.
class AllOfType<list<Type> allowedTypeList, string summary = "",
- string cppClassName = "::mlir::Type"> : Type<
+ string cppType = "::mlir::Type"> : Type<
// Satisfy all of the allowed types' conditions.
And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypeList, t.summary), " and "),
summary),
- cppClassName> {
+ cppType> {
list<Type> allowedTypes = allowedTypeList;
}
// A type that satisfies additional predicates.
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
- string cppClassName = type.cppClassName> : Type<
+ string cppType = type.cppType> : Type<
And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
- summary, cppClassName>;
+ summary, cppType>;
// Integer types.
@@ -375,23 +375,23 @@ def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
- string descr, string cppClassName = "::mlir::Type"> :
+ string descr, string cppType = "::mlir::Type"> :
// First, check the container predicate. Then, substitute the extracted
// element into the element type checker.
Type<And<[containerPred,
SubstLeaves<"$_self", !cast<string>(elementTypeCall),
etype.predicate>]>,
- descr # " of " # etype.summary # " values", cppClassName>;
+ descr # " of " # etype.summary # " values", cppType>;
class ShapedContainerType<list<Type> allowedTypes,
Pred containerPred, string descr,
- string cppClassName = "::mlir::Type"> :
+ string cppType = "::mlir::Type"> :
Type<And<[containerPred,
Concat<"[](::mlir::Type elementType) { return ",
SubstLeaves<"$_self", "elementType",
AnyTypeOf<allowedTypes>.predicate>,
"; }(::llvm::cast<::mlir::ShapedType>($_self).getElementType())">]>,
- descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppClassName>;
+ descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppType>;
// Whether a shaped type is ranked.
def HasRankPred : CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasRank()">;
diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td
index a026d58ccffb8e..39bc55db63da1a 100644
--- a/mlir/include/mlir/IR/Constraints.td
+++ b/mlir/include/mlir/IR/Constraints.td
@@ -149,10 +149,10 @@ class Constraint<Pred pred, string desc = ""> {
// Subclass for constraints on a type.
class TypeConstraint<Pred predicate, string summary = "",
- string cppClassNameParam = "::mlir::Type"> :
+ string cppTypeParam = "::mlir::Type"> :
Constraint<predicate, summary> {
// The name of the C++ Type class if known, or Type if not.
- string cppClassName = cppClassNameParam;
+ string cppType = cppTypeParam;
}
// Subclass for constraints on an attribute.
diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h
index 06cf4f5730d565..b974ac281041bc 100644
--- a/mlir/include/mlir/TableGen/Type.h
+++ b/mlir/include/mlir/TableGen/Type.h
@@ -56,8 +56,8 @@ class TypeConstraint : public Constraint {
// returns std::nullopt otherwise.
std::optional<StringRef> getBuilderCall() const;
- // Return the C++ class name for this type (which may just be ::mlir::Type).
- std::string getCPPClassName() const;
+ // Return the C++ type for this type (which may just be ::mlir::Type).
+ StringRef getCppType() const;
};
// Wrapper class with helper methods for accessing Types defined in TableGen.
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index e9f2394dd540af..cda752297988bb 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -59,20 +59,9 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
.Default([](auto *) { return std::nullopt; });
}
-// Return the C++ class name for this type (which may just be ::mlir::Type).
-std::string TypeConstraint::getCPPClassName() const {
- StringRef className = def->getValueAsString("cppClassName");
-
- // If the class name is already namespace resolved, use it.
- if (className.contains("::"))
- return className.str();
-
- // Otherwise, check to see if there is a namespace from a dialect to prepend.
- if (const llvm::RecordVal *value = def->getValue("dialect")) {
- Dialect dialect(cast<const llvm::DefInit>(value->getValue())->getDef());
- return (dialect.getCppNamespace() + "::" + className).str();
- }
- return className.str();
+// Return the C++ type for this type (which may just be ::mlir::Type).
+StringRef TypeConstraint::getCppType() const {
+ return def->getValueAsString("cppType");
}
Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 1f0df033d43398..01c78e280080ee 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -879,8 +879,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
-> const ods::TypeConstraint & {
return odsContext.insertTypeConstraint(
cst.constraint.getUniqueDefName(),
- processDoc(cst.constraint.getSummary()),
- cst.constraint.getCPPClassName());
+ processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
};
auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
@@ -944,7 +943,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
tblgen::TypeConstraint constraint(def);
decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
constraint, convertLocToRange(def->getLoc().front()), typeTy,
- constraint.getCPPClassName()));
+ constraint.getCppType()));
}
/// OpInterfaces.
ast::Type opTy = ast::OperationType::get(ctx);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index a2ceefb34db453..66dbb16760ebb0 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2085,21 +2085,8 @@ static void generateValueRangeStartAndEnd(
}
static std::string generateTypeForGetter(const NamedTypeConstraint &value) {
- std::string str = "::mlir::Value";
- /// If the CPPClassName is not a fully qualified type. Uses of types
- /// across Dialect fail because they are not in the correct namespace. So we
- /// dont generate TypedValue unless the type is fully qualified.
- /// getCPPClassName doesn't return the fully qualified path for
- /// `mlir::pdl::OperationType` see
- /// https://github.com/llvm/llvm-project/issues/57279.
- /// Adaptor will have values that are not from the type of their operation and
- /// this is expected, so we dont generate TypedValue for Adaptor
- if (value.constraint.getCPPClassName() != "::mlir::Type" &&
- StringRef(value.constraint.getCPPClassName()).starts_with("::"))
- str = llvm::formatv("::mlir::TypedValue<{0}>",
- value.constraint.getCPPClassName())
- .str();
- return str;
+ return llvm::formatv("::mlir::TypedValue<{0}>", value.constraint.getCppType())
+ .str();
}
// Generates the named operand getter methods for the given Operator `op` and
@@ -3944,7 +3931,7 @@ void OpEmitter::genTraits() {
// For single result ops with a known specific type, generate a OneTypedResult
// trait.
if (numResults == 1 && numVariadicResults == 0) {
- auto cppName = op.getResults().begin()->constraint.getCPPClassName();
+ auto cppName = op.getResults().begin()->constraint.getCppType();
opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
}
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 27ad79a5c1efed..9a95f495b77658 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1657,7 +1657,7 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
TypeSwitch<FormatElement *>(dir->getArg())
.Case<OperandVariable, ResultVariable>([&](auto operand) {
body << formatv(parserCode,
- operand->getVar()->constraint.getCPPClassName(),
+ operand->getVar()->constraint.getCppType(),
listName);
})
.Default([&](auto operand) {
@@ -2603,7 +2603,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
}
if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
!var->isOptional()) {
- std::string cppClass = var->constraint.getCPPClassName();
+ StringRef cppType = var->constraint.getCppType();
if (dir->shouldBeQualified()) {
body << " _odsPrinter << " << op.getGetterName(var->name)
<< "().getType();\n";
@@ -2612,7 +2612,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
body << " {\n"
<< " auto type = " << op.getGetterName(var->name)
<< "().getType();\n"
- << " if (auto validType = ::llvm::dyn_cast<" << cppClass
+ << " if (auto validType = ::llvm::dyn_cast<" << cppType
<< ">(type))\n"
<< " _odsPrinter.printStrippedAttrOrType(validType);\n"
<< " else\n"
More information about the Mlir-commits
mailing list