[Mlir-commits] [mlir] [mlir][ODS] Verify type constraints in Types and Attributes (PR #102326)
Matthias Springer
llvmlistbot at llvm.org
Fri Aug 9 11:52:08 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/102326
>From 9968cd55e4f606728b90854f0ff6f0ec8b04a1ae Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 9 Aug 2024 20:21:44 +0200
Subject: [PATCH 1/2] [mlir][ODS] Consistent `cppType` / `cppClassName` usage
Make sure that the usage of `cppType` and `cppClassName` of type and attribute definitions/constraints is consistent in TableGen.
- `cppClassName`: The C++ class name of the type or attribute.
- `cppType`: The fully qualified C++ class name: C++ namespace and C++ class name.
---
mlir/include/mlir/IR/AttrTypeBase.td | 18 +++++-----
mlir/include/mlir/IR/CommonAttrConstraints.td | 6 ++--
mlir/include/mlir/IR/CommonTypeConstraints.td | 34 +++++++++----------
mlir/include/mlir/IR/Constraints.td | 4 +--
mlir/include/mlir/TableGen/Type.h | 4 +--
mlir/lib/TableGen/Type.cpp | 17 ++--------
mlir/lib/Tools/PDLL/Parser/Parser.cpp | 5 ++-
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 19 ++---------
mlir/tools/mlir-tblgen/OpFormatGen.cpp | 6 ++--
9 files changed, 44 insertions(+), 69 deletions(-)
diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 91c9283de8bd41..d176b36068f7a5 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -256,7 +256,7 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
AttrOrTypeDef<"Attr", name, traits, baseCppClass> {
// The name of the C++ Attribute class.
string cppClassName = name # "Attr";
- let storageType = dialect.cppNamespace # "::" # name # "Attr";
+ let storageType = dialect.cppNamespace # "::" # cppClassName;
// The underlying C++ value type
let returnType = dialect.cppNamespace # "::" # cppClassName;
@@ -275,12 +275,10 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
//
// For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will
// expand to `getAttrOfType<IntegerAttr>("val").getValue().getSExtValue()`.
- let convertFromStorage = "::llvm::cast<" # dialect.cppNamespace #
- "::" # cppClassName # ">($_self)";
+ let convertFromStorage = "::llvm::cast<" # cppType # ">($_self)";
// The predicate for when this def is used as a constraint.
- let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace #
- "::" # cppClassName # ">($_self)">;
+ let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
}
// Define a new type, named `name`, belonging to `dialect` that inherits from
@@ -289,6 +287,9 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: DialectType<dialect, CPred<"">, /*descr*/"", name # "Type">,
AttrOrTypeDef<"Type", name, traits, baseCppClass> {
+ // The name of the C++ Type class.
+ string cppClassName = name # "Type";
+
// Make it possible to use such type as parameters for other types.
string cppType = dialect.cppNamespace # "::" # cppClassName;
@@ -297,12 +298,11 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
// A constant builder provided when the type has no parameters.
let builderCall = !if(!empty(parameters),
- "$_builder.getType<" # dialect.cppNamespace #
- "::" # cppClassName # ">()",
+ "$_builder.getType<" # cppType # ">()",
"");
+
// The predicate for when this def is used as a constraint.
- let predicate = CPred<"::llvm::isa<" # dialect.cppNamespace #
- "::" # cppClassName # ">($_self)">;
+ let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index d99bde1f87ef00..6774a7c568315d 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -169,14 +169,14 @@ def AnyAttr : Attr<CPred<"true">, "any attribute"> {
// Any attribute from the given list
class AnyAttrOf<list<Attr> allowedAttrs, string summary = "",
- string cppClassName = "::mlir::Attribute",
+ string cppType = "::mlir::Attribute",
string fromStorage = "$_self"> : Attr<
// Satisfy any of the allowed attribute's condition
Or<!foreach(allowedattr, allowedAttrs, allowedattr.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedAttrs, t.summary), " or "),
summary)> {
- let returnType = cppClassName;
+ let returnType = cppType;
let convertFromStorage = fromStorage;
}
@@ -369,7 +369,7 @@ def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> {
}
class TypeAttrOf<Type ty>
- : TypeAttrBase<ty.cppClassName, "type attribute of " # ty.summary,
+ : TypeAttrBase<ty.cppType, "type attribute of " # ty.summary,
ty.predicate> {
let constBuilderCall = "::mlir::TypeAttr::get($0)";
}
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 5b6ec167fa2420..65e1b1cbdc905a 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -98,23 +98,23 @@ def HasValueSemanticsPred : CPred<"$_self.hasTrait<::mlir::ValueSemantics>()">;
// A type, carries type constraints.
class Type<Pred condition, string descr = "",
- string cppClassName = "::mlir::Type"> :
- TypeConstraint<condition, descr, cppClassName> {
+ string cppType = "::mlir::Type"> :
+ TypeConstraint<condition, descr, cppType> {
string description = "";
string builderCall = "";
}
// Allows providing an alternative name and summary to an existing type def.
class TypeAlias<Type t, string summary = t.summary> :
- Type<t.predicate, summary, t.cppClassName> {
+ Type<t.predicate, summary, t.cppType> {
let description = t.description;
let builderCall = t.builderCall;
}
// A type of a specific dialect.
class DialectType<Dialect d, Pred condition, string descr = "",
- string cppClassName = "::mlir::Type"> :
- Type<condition, descr, cppClassName> {
+ string cppType = "::mlir::Type"> :
+ Type<condition, descr, cppType> {
Dialect dialect = d;
}
@@ -122,7 +122,7 @@ class DialectType<Dialect d, Pred condition, string descr = "",
// class is used for supporting variadic operands/results.
class Variadic<Type type> : TypeConstraint<type.predicate,
"variadic of " # type.summary,
- type.cppClassName> {
+ type.cppType> {
Type baseType = type;
int minSize = 0;
}
@@ -140,7 +140,7 @@ class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
// An optional type constraint. It expands to either zero or one of the base
// type. This class is used for supporting optional operands/results.
class Optional<Type type> : TypeConstraint<type.predicate, type.summary,
- type.cppClassName> {
+ type.cppType> {
Type baseType = type;
}
@@ -172,33 +172,33 @@ def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
// Any type from the given list
class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
- string cppClassName = "::mlir::Type"> : Type<
+ string cppType = "::mlir::Type"> : Type<
// Satisfy any of the allowed types' conditions.
Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypeList, t.summary), " or "),
summary),
- cppClassName> {
+ cppType> {
list<Type> allowedTypes = allowedTypeList;
}
// A type that satisfies the constraints of all given types.
class AllOfType<list<Type> allowedTypeList, string summary = "",
- string cppClassName = "::mlir::Type"> : Type<
+ string cppType = "::mlir::Type"> : Type<
// Satisfy all of the allowed types' conditions.
And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypeList, t.summary), " and "),
summary),
- cppClassName> {
+ cppType> {
list<Type> allowedTypes = allowedTypeList;
}
// A type that satisfies additional predicates.
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
- string cppClassName = type.cppClassName> : Type<
+ string cppType = type.cppType> : Type<
And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
- summary, cppClassName>;
+ summary, cppType>;
// Integer types.
@@ -375,23 +375,23 @@ def FunctionType : Type<CPred<"::llvm::isa<::mlir::FunctionType>($_self)">,
// A container type is a type that has another type embedded within it.
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
- string descr, string cppClassName = "::mlir::Type"> :
+ string descr, string cppType = "::mlir::Type"> :
// First, check the container predicate. Then, substitute the extracted
// element into the element type checker.
Type<And<[containerPred,
SubstLeaves<"$_self", !cast<string>(elementTypeCall),
etype.predicate>]>,
- descr # " of " # etype.summary # " values", cppClassName>;
+ descr # " of " # etype.summary # " values", cppType>;
class ShapedContainerType<list<Type> allowedTypes,
Pred containerPred, string descr,
- string cppClassName = "::mlir::Type"> :
+ string cppType = "::mlir::Type"> :
Type<And<[containerPred,
Concat<"[](::mlir::Type elementType) { return ",
SubstLeaves<"$_self", "elementType",
AnyTypeOf<allowedTypes>.predicate>,
"; }(::llvm::cast<::mlir::ShapedType>($_self).getElementType())">]>,
- descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppClassName>;
+ descr # " of " # AnyTypeOf<allowedTypes>.summary # " values", cppType>;
// Whether a shaped type is ranked.
def HasRankPred : CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasRank()">;
diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td
index a026d58ccffb8e..39bc55db63da1a 100644
--- a/mlir/include/mlir/IR/Constraints.td
+++ b/mlir/include/mlir/IR/Constraints.td
@@ -149,10 +149,10 @@ class Constraint<Pred pred, string desc = ""> {
// Subclass for constraints on a type.
class TypeConstraint<Pred predicate, string summary = "",
- string cppClassNameParam = "::mlir::Type"> :
+ string cppTypeParam = "::mlir::Type"> :
Constraint<predicate, summary> {
// The name of the C++ Type class if known, or Type if not.
- string cppClassName = cppClassNameParam;
+ string cppType = cppTypeParam;
}
// Subclass for constraints on an attribute.
diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h
index 06cf4f5730d565..b974ac281041bc 100644
--- a/mlir/include/mlir/TableGen/Type.h
+++ b/mlir/include/mlir/TableGen/Type.h
@@ -56,8 +56,8 @@ class TypeConstraint : public Constraint {
// returns std::nullopt otherwise.
std::optional<StringRef> getBuilderCall() const;
- // Return the C++ class name for this type (which may just be ::mlir::Type).
- std::string getCPPClassName() const;
+ // Return the C++ type for this type (which may just be ::mlir::Type).
+ StringRef getCppType() const;
};
// Wrapper class with helper methods for accessing Types defined in TableGen.
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
index e9f2394dd540af..cda752297988bb 100644
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -59,20 +59,9 @@ std::optional<StringRef> TypeConstraint::getBuilderCall() const {
.Default([](auto *) { return std::nullopt; });
}
-// Return the C++ class name for this type (which may just be ::mlir::Type).
-std::string TypeConstraint::getCPPClassName() const {
- StringRef className = def->getValueAsString("cppClassName");
-
- // If the class name is already namespace resolved, use it.
- if (className.contains("::"))
- return className.str();
-
- // Otherwise, check to see if there is a namespace from a dialect to prepend.
- if (const llvm::RecordVal *value = def->getValue("dialect")) {
- Dialect dialect(cast<const llvm::DefInit>(value->getValue())->getDef());
- return (dialect.getCppNamespace() + "::" + className).str();
- }
- return className.str();
+// Return the C++ type for this type (which may just be ::mlir::Type).
+StringRef TypeConstraint::getCppType() const {
+ return def->getValueAsString("cppType");
}
Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
index 1f0df033d43398..01c78e280080ee 100644
--- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp
+++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp
@@ -879,8 +879,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
-> const ods::TypeConstraint & {
return odsContext.insertTypeConstraint(
cst.constraint.getUniqueDefName(),
- processDoc(cst.constraint.getSummary()),
- cst.constraint.getCPPClassName());
+ processDoc(cst.constraint.getSummary()), cst.constraint.getCppType());
};
auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange {
return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)};
@@ -944,7 +943,7 @@ void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
tblgen::TypeConstraint constraint(def);
decls.push_back(createODSNativePDLLConstraintDecl<ast::TypeConstraintDecl>(
constraint, convertLocToRange(def->getLoc().front()), typeTy,
- constraint.getCPPClassName()));
+ constraint.getCppType()));
}
/// OpInterfaces.
ast::Type opTy = ast::OperationType::get(ctx);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index a2ceefb34db453..66dbb16760ebb0 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2085,21 +2085,8 @@ static void generateValueRangeStartAndEnd(
}
static std::string generateTypeForGetter(const NamedTypeConstraint &value) {
- std::string str = "::mlir::Value";
- /// If the CPPClassName is not a fully qualified type. Uses of types
- /// across Dialect fail because they are not in the correct namespace. So we
- /// dont generate TypedValue unless the type is fully qualified.
- /// getCPPClassName doesn't return the fully qualified path for
- /// `mlir::pdl::OperationType` see
- /// https://github.com/llvm/llvm-project/issues/57279.
- /// Adaptor will have values that are not from the type of their operation and
- /// this is expected, so we dont generate TypedValue for Adaptor
- if (value.constraint.getCPPClassName() != "::mlir::Type" &&
- StringRef(value.constraint.getCPPClassName()).starts_with("::"))
- str = llvm::formatv("::mlir::TypedValue<{0}>",
- value.constraint.getCPPClassName())
- .str();
- return str;
+ return llvm::formatv("::mlir::TypedValue<{0}>", value.constraint.getCppType())
+ .str();
}
// Generates the named operand getter methods for the given Operator `op` and
@@ -3944,7 +3931,7 @@ void OpEmitter::genTraits() {
// For single result ops with a known specific type, generate a OneTypedResult
// trait.
if (numResults == 1 && numVariadicResults == 0) {
- auto cppName = op.getResults().begin()->constraint.getCPPClassName();
+ auto cppName = op.getResults().begin()->constraint.getCppType();
opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl");
}
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 27ad79a5c1efed..9a95f495b77658 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1657,7 +1657,7 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
TypeSwitch<FormatElement *>(dir->getArg())
.Case<OperandVariable, ResultVariable>([&](auto operand) {
body << formatv(parserCode,
- operand->getVar()->constraint.getCPPClassName(),
+ operand->getVar()->constraint.getCppType(),
listName);
})
.Default([&](auto operand) {
@@ -2603,7 +2603,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
}
if (var && !var->isVariadicOfVariadic() && !var->isVariadic() &&
!var->isOptional()) {
- std::string cppClass = var->constraint.getCPPClassName();
+ StringRef cppType = var->constraint.getCppType();
if (dir->shouldBeQualified()) {
body << " _odsPrinter << " << op.getGetterName(var->name)
<< "().getType();\n";
@@ -2612,7 +2612,7 @@ void OperationFormat::genElementPrinter(FormatElement *element,
body << " {\n"
<< " auto type = " << op.getGetterName(var->name)
<< "().getType();\n"
- << " if (auto validType = ::llvm::dyn_cast<" << cppClass
+ << " if (auto validType = ::llvm::dyn_cast<" << cppType
<< ">(type))\n"
<< " _odsPrinter.printStrippedAttrOrType(validType);\n"
<< " else\n"
>From 27f3ffae411d6b0f5227850ab11a0ad69090daca Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Wed, 7 Aug 2024 17:49:45 +0200
Subject: [PATCH 2/2] [mlir][ODS] Verify type constraints in Types and
Attributes
When a type/attribute is defined in TableGen, a type constraint can be used for parameters, but the type constraint verification was missing.
Example:
```
def TestTypeVerification : Test_Type<"TestTypeVerification"> {
let parameters = (ins AnyTypeOf<[I16, I32]>:$param);
// ...
}
```
No verification code was generated to ensure that `$param` is I16 or I32.
When type constraints are present, a new method will be generated for types and attributes: `verifyInvariantsImpl`. (The naming is similar to op verifiers.) The user-provided verifier is called `verify` (no change). There is now a new entry point to type/attribute verification: `verifyInvariants`. This function calls both `verifyInvariantsImpl` and `verify`. If neither of those two verifications is present, the `verifyInvariants` function is not generated.
When a type/attribute is not defined in TableGen, but a verifier is needed, users can implement the `verifyInvariants` function. (This function was previously called `verify`.)
---
mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h | 7 +-
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 12 ++-
mlir/include/mlir/Dialect/Quant/QuantTypes.h | 43 +++++----
.../mlir/Dialect/SPIRV/IR/SPIRVAttributes.h | 14 +--
.../mlir/Dialect/SPIRV/IR/SPIRVTypes.h | 10 +-
mlir/include/mlir/IR/StorageUniquerSupport.h | 10 +-
mlir/include/mlir/IR/Types.h | 2 +-
mlir/include/mlir/TableGen/AttrOrTypeDef.h | 8 ++
mlir/include/mlir/TableGen/Class.h | 2 +
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 6 +-
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 12 +--
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 39 ++++----
mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp | 9 +-
mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 9 +-
.../SparseTensor/IR/SparseTensorDialect.cpp | 8 ++
mlir/lib/TableGen/AttrOrTypeDef.cpp | 13 +++
mlir/test/IR/test-verifiers-type.mlir | 9 ++
mlir/test/lib/Dialect/Test/TestAttrDefs.td | 1 -
mlir/test/lib/Dialect/Test/TestTypeDefs.td | 6 ++
mlir/test/mlir-tblgen/attr-or-type-format.td | 79 ++++++++++++++-
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp | 96 +++++++++++++++++--
.../tools/mlir-tblgen/AttrOrTypeFormatGen.cpp | 2 +-
22 files changed, 304 insertions(+), 93 deletions(-)
create mode 100644 mlir/test/IR/test-verifiers-type.mlir
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 96e1935bd0a841..57acd72610415f 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -148,9 +148,10 @@ class MMAMatrixType
/// Verify that shape and elementType are actually allowed for the
/// MMAMatrixType.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<int64_t> shape, Type elementType,
- StringRef operand);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType,
+ StringRef operand);
/// Get number of dims.
unsigned getNumDims() const;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 1befdfa74f67c5..2ea589a7c4c3bd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -180,11 +180,13 @@ class LLVMStructType
ArrayRef<Type> getBody() const;
/// Verifies that the type about to be constructed is well-formed.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- StringRef, bool);
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<Type> types, bool);
- using Base::verify;
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError, StringRef,
+ bool);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<Type> types, bool);
+ using Base::verifyInvariants;
/// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
/// DataLayout instance and query it instead.
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index de5aed0a91a209..57a2aa29833657 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -54,10 +54,10 @@ class QuantizedType : public Type {
/// The maximum number of bits supported for storage types.
static constexpr unsigned MaxStorageBits = 32;
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- unsigned flags, Type storageType,
- Type expressedType, int64_t storageTypeMin,
- int64_t storageTypeMax);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+ Type storageType, Type expressedType, int64_t storageTypeMin,
+ int64_t storageTypeMax);
/// Support method to enable LLVM-style type casting.
static bool classof(Type type);
@@ -214,10 +214,10 @@ class AnyQuantizedType
int64_t storageTypeMax);
/// Verifies construction invariants and issues errors/warnings.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- unsigned flags, Type storageType,
- Type expressedType, int64_t storageTypeMin,
- int64_t storageTypeMax);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+ Type storageType, Type expressedType, int64_t storageTypeMin,
+ int64_t storageTypeMax);
};
/// Represents a family of uniform, quantized types.
@@ -276,11 +276,11 @@ class UniformQuantizedType
int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
/// Verifies construction invariants and issues errors/warnings.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- unsigned flags, Type storageType,
- Type expressedType, double scale,
- int64_t zeroPoint, int64_t storageTypeMin,
- int64_t storageTypeMax);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+ Type storageType, Type expressedType, double scale,
+ int64_t zeroPoint, int64_t storageTypeMin,
+ int64_t storageTypeMax);
/// Gets the scale term. The scale designates the difference between the real
/// values corresponding to consecutive quantized values differing by 1.
@@ -338,12 +338,12 @@ class UniformQuantizedPerAxisType
int64_t storageTypeMin, int64_t storageTypeMax);
/// Verifies construction invariants and issues errors/warnings.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- unsigned flags, Type storageType,
- Type expressedType, ArrayRef<double> scales,
- ArrayRef<int64_t> zeroPoints,
- int32_t quantizedDimension,
- int64_t storageTypeMin, int64_t storageTypeMax);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+ Type storageType, Type expressedType,
+ ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+ int32_t quantizedDimension, int64_t storageTypeMin,
+ int64_t storageTypeMax);
/// Gets the quantization scales. The scales designate the difference between
/// the real values corresponding to consecutive quantized values differing
@@ -403,8 +403,9 @@ class CalibratedQuantizedType
double min, double max);
/// Verifies construction invariants and issues errors/warnings.
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- Type expressedType, double min, double max);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ Type expressedType, double min, double max);
double getMin() const;
double getMax() const;
};
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 5ebfa9ca5ec25c..2bdd7a5bf3dd83 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -76,9 +76,10 @@ class InterfaceVarABIAttr
/// Returns `spirv::StorageClass`.
std::optional<StorageClass> getStorageClass();
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- IntegerAttr descriptorSet, IntegerAttr binding,
- IntegerAttr storageClass);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ IntegerAttr descriptorSet, IntegerAttr binding,
+ IntegerAttr storageClass);
static constexpr StringLiteral name = "spirv.interface_var_abi";
};
@@ -128,9 +129,10 @@ class VerCapExtAttr
/// Returns the capabilities as an integer array attribute.
ArrayAttr getCapabilitiesAttr();
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- IntegerAttr version, ArrayAttr capabilities,
- ArrayAttr extensions);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ IntegerAttr version, ArrayAttr capabilities,
+ ArrayAttr extensions);
static constexpr StringLiteral name = "spirv.ver_cap_ext";
};
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 55f0c787b44403..e2d04553d91b8b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -258,8 +258,9 @@ class SampledImageType
static SampledImageType
getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- Type imageType);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ Type imageType);
Type getImageType() const;
@@ -462,8 +463,9 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
Type columnType, uint32_t columnCount);
- static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
- Type columnType, uint32_t columnCount);
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ Type columnType, uint32_t columnCount);
/// Returns true if the matrix elements are vectors of float elements.
static bool isValidColumnType(Type columnType);
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index fb64f15162df5b..d6ccbbd8579947 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -176,8 +176,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
template <typename... Args>
static ConcreteT get(MLIRContext *ctx, Args &&...args) {
// Ensure that the invariants are correct for construction.
- assert(
- succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
+ assert(succeeded(
+ ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
}
@@ -198,7 +198,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
MLIRContext *ctx, Args... args) {
// If the construction invariants fail then we return a null attribute.
- if (failed(ConcreteT::verify(emitErrorFn, args...)))
+ if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
return ConcreteT();
return UniquerT::template get<ConcreteT>(ctx, args...);
}
@@ -226,7 +226,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
/// Default implementation that just returns success.
template <typename... Args>
- static LogicalResult verify(Args... args) {
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitErrorFn,
+ Args... args) {
return success();
}
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 60dc8fee0f4a96..dfc7e69472c891 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -34,7 +34,7 @@ class AsmState;
/// Derived type classes are expected to implement several required
/// implementation hooks:
/// * Optional:
-/// - static LogicalResult verify(
+/// - static LogicalResult verifyInvariants(
/// function_ref<InFlightDiagnostic()> emitError,
/// Args... args)
/// * This method is invoked when calling the 'TypeBase::get/getChecked'
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 19c3a9183ec2cf..36744c85bc7086 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -16,6 +16,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Builder.h"
+#include "mlir/TableGen/Constraint.h"
#include "mlir/TableGen/Trait.h"
namespace llvm {
@@ -85,6 +86,9 @@ class AttrOrTypeParameter {
/// Get an optional C++ parameter parser.
std::optional<StringRef> getParser() const;
+ /// If this is a type constraint, return it.
+ std::optional<Constraint> getConstraint() const;
+
/// Get an optional C++ parameter printer.
std::optional<StringRef> getPrinter() const;
@@ -198,6 +202,10 @@ class AttrOrTypeDef {
/// method.
bool genVerifyDecl() const;
+ /// Return true if we need to generate any type constraint verification and
+ /// the getChecked method.
+ bool genVerifyInvariantsImpl() const;
+
/// Returns the def's extra class declaration code.
std::optional<StringRef> getExtraDecls() const;
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 855952d19492db..f750a34a3b2ba4 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -67,6 +67,8 @@ class MethodParameter {
/// Get the C++ type.
StringRef getType() const { return type; }
+ /// Get the C++ parameter name.
+ StringRef getName() const { return name; }
/// Returns true if the parameter has a default value.
bool hasDefaultValue() const { return !defaultValue.empty(); }
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7bc2668310ddb0..a1f87a637a6141 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
}
LogicalResult
-MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<int64_t> shape, Type elementType,
- StringRef operand) {
+MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType,
+ StringRef operand) {
if (operand != "AOp" && operand != "BOp" && operand != "COp")
return emitError() << "operand expected to be one of AOp, BOp or COp";
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index dc7aef8ef7f850..7f10a15ff31ff9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
bool LLVMStructType::isValidElementType(Type type) {
return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
- LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
- type);
+ LLVMFunctionType, LLVMTokenType>(type);
}
LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
@@ -492,14 +491,15 @@ ArrayRef<Type> LLVMStructType::getBody() const {
: getImpl()->getTypeList();
}
-LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
- StringRef, bool) {
+LogicalResult
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
+ bool) {
return success();
}
LogicalResult
-LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<Type> types, bool) {
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<Type> types, bool) {
for (Type t : types)
if (!isValidElementType(t))
return emitError() << "invalid LLVM structure element type: " << t;
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 81e3b914755be2..c2ba9c04e8771d 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -29,9 +29,10 @@ bool QuantizedType::classof(Type type) {
}
LogicalResult
-QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
- unsigned flags, Type storageType, Type expressedType,
- int64_t storageTypeMin, int64_t storageTypeMax) {
+QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ unsigned flags, Type storageType,
+ Type expressedType, int64_t storageTypeMin,
+ int64_t storageTypeMax) {
// Verify that the storage type is integral.
// This restriction may be lifted at some point in favor of using bf16
// or f16 as exact representations on hardware where that is advantageous.
@@ -233,11 +234,13 @@ AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
}
LogicalResult
-AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
- unsigned flags, Type storageType, Type expressedType,
- int64_t storageTypeMin, int64_t storageTypeMax) {
- if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
- storageTypeMin, storageTypeMax))) {
+AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ unsigned flags, Type storageType,
+ Type expressedType, int64_t storageTypeMin,
+ int64_t storageTypeMax) {
+ if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+ expressedType, storageTypeMin,
+ storageTypeMax))) {
return failure();
}
@@ -268,12 +271,13 @@ UniformQuantizedType UniformQuantizedType::getChecked(
storageTypeMin, storageTypeMax);
}
-LogicalResult UniformQuantizedType::verify(
+LogicalResult UniformQuantizedType::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
- if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
- storageTypeMin, storageTypeMax))) {
+ if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+ expressedType, storageTypeMin,
+ storageTypeMax))) {
return failure();
}
@@ -321,13 +325,14 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
quantizedDimension, storageTypeMin, storageTypeMax);
}
-LogicalResult UniformQuantizedPerAxisType::verify(
+LogicalResult UniformQuantizedPerAxisType::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax) {
- if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
- storageTypeMin, storageTypeMax))) {
+ if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+ expressedType, storageTypeMin,
+ storageTypeMax))) {
return failure();
}
@@ -380,9 +385,9 @@ CalibratedQuantizedType CalibratedQuantizedType::getChecked(
min, max);
}
-LogicalResult
-CalibratedQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
- Type expressedType, double min, double max) {
+LogicalResult CalibratedQuantizedType::verifyInvariants(
+ function_ref<InFlightDiagnostic()> emitError, Type expressedType,
+ double min, double max) {
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
index 8a0ee7a3d81367..b71be23fdf47d0 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp
@@ -162,7 +162,7 @@ spirv::InterfaceVarABIAttr::getStorageClass() {
return std::nullopt;
}
-LogicalResult spirv::InterfaceVarABIAttr::verify(
+LogicalResult spirv::InterfaceVarABIAttr::verifyInvariants(
function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet,
IntegerAttr binding, IntegerAttr storageClass) {
if (!descriptorSet.getType().isSignlessInteger(32))
@@ -257,10 +257,9 @@ ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
return llvm::cast<ArrayAttr>(getImpl()->capabilities);
}
-LogicalResult
-spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- IntegerAttr version, ArrayAttr capabilities,
- ArrayAttr extensions) {
+LogicalResult spirv::VerCapExtAttr::verifyInvariants(
+ function_ref<InFlightDiagnostic()> emitError, IntegerAttr version,
+ ArrayAttr capabilities, ArrayAttr extensions) {
if (!version.getType().isSignlessInteger(32))
return emitError() << "expected 32-bit integer for version";
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 3808620bdffa6d..c5590905b75045 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -817,8 +817,8 @@ SampledImageType::getChecked(function_ref<InFlightDiagnostic()> emitError,
Type SampledImageType::getImageType() const { return getImpl()->imageType; }
LogicalResult
-SampledImageType::verify(function_ref<InFlightDiagnostic()> emitError,
- Type imageType) {
+SampledImageType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ Type imageType) {
if (!llvm::isa<ImageType>(imageType))
return emitError() << "expected image type";
@@ -1181,8 +1181,9 @@ MatrixType MatrixType::getChecked(function_ref<InFlightDiagnostic()> emitError,
columnCount);
}
-LogicalResult MatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
- Type columnType, uint32_t columnCount) {
+LogicalResult
+MatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ Type columnType, uint32_t columnCount) {
if (columnCount < 2 || columnCount > 4)
return emitError() << "matrix can have 2, 3, or 4 columns only";
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 1135ea32fe1abb..a284aa2f1f020f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1227,6 +1227,14 @@ StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}
+StorageSpecifierType
+StorageSpecifierType::getChecked(function_ref<InFlightDiagnostic()> emitError,
+ MLIRContext *ctx,
+ SparseTensorEncodingAttr encoding) {
+ return Base::getChecked(emitError, ctx,
+ getNormalizedEncodingForSpecifier(encoding));
+}
+
//===----------------------------------------------------------------------===//
// SparseTensorDialect Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index c9dbb3bc76b1fa..9b9d9fd2317d99 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -184,6 +184,12 @@ bool AttrOrTypeDef::genVerifyDecl() const {
return def->getValueAsBit("genVerifyDecl");
}
+bool AttrOrTypeDef::genVerifyInvariantsImpl() const {
+ return any_of(parameters, [](const AttrOrTypeParameter &p) {
+ return p.getConstraint() != std::nullopt;
+ });
+}
+
std::optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? std::optional<StringRef>() : value;
@@ -331,6 +337,13 @@ std::optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
+std::optional<Constraint> AttrOrTypeParameter::getConstraint() const {
+ if (auto *param = dyn_cast<llvm::DefInit>(getDef()))
+ if (param->getDef()->isSubClassOf("Constraint"))
+ return Constraint(param->getDef());
+ return std::nullopt;
+}
+
//===----------------------------------------------------------------------===//
// AttributeSelfTypeParameter
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/test-verifiers-type.mlir b/mlir/test/IR/test-verifiers-type.mlir
new file mode 100644
index 00000000000000..96d0005eb7a19d
--- /dev/null
+++ b/mlir/test/IR/test-verifiers-type.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// CHECK: "test.type_producer"() : () -> !test.type_verification<i16>
+"test.type_producer"() : () -> !test.type_verification<i16>
+
+// -----
+
+// expected-error @below{{failed to verify 'param': 16-bit signless integer or 32-bit signless integer}}
+"test.type_producer"() : () -> !test.type_verification<f16>
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 8f109f8ce5e6dd..b3b94bd0ffea31 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -270,7 +270,6 @@ def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> {
let assemblyFormat = "`<` $a `>`";
let skipDefaultBuilders = 1;
- let genVerifyDecl = 1;
let builders = [AttrBuilder<(ins "int":$a), [{
return ::mlir::IntegerAttr::get(::mlir::IndexType::get($_ctxt), a);
}], "::mlir::Attribute">];
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index d96152a0826f96..830475bed4e444 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -392,4 +392,10 @@ def TestRecursiveAlias
}];
}
+def TestTypeVerification : Test_Type<"TestTypeVerification"> {
+ let parameters = (ins AnyTypeOf<[I16, I32]>:$param);
+ let mnemonic = "type_verification";
+ let assemblyFormat = "`<` $param `>`";
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 2884c4ed6a9081..c5348409e8e44f 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -1,8 +1,9 @@
-// RUN: sed 's/DEFAULT_TYPE_PARSER/0/' %s | mlir-tblgen -gen-attrdef-defs -I %S/../../include | FileCheck %s --check-prefix=ATTR
-// RUN: sed 's/DEFAULT_TYPE_PARSER/0/' %s | mlir-tblgen -gen-typedef-defs -I %S/../../include | FileCheck %s --check-prefix=TYPE
-// RUN: sed 's/DEFAULT_TYPE_PARSER/1/' %s | mlir-tblgen -gen-typedef-defs -I %S/../../include | FileCheck %s --check-prefix=TYPE --check-prefix=DEFAULT_TYPE_PARSER
+// RUN: sed 's/DEFAULT_TYPE_PARSER/0/' %s | mlir-tblgen -gen-attrdef-defs -attrdefs-dialect=TestDialect -I %S/../../include | FileCheck %s --check-prefix=ATTR
+// RUN: sed 's/DEFAULT_TYPE_PARSER/0/' %s | mlir-tblgen -gen-typedef-defs -typedefs-dialect=TestDialect -I %S/../../include | FileCheck %s --check-prefix=TYPE
+// RUN: sed 's/DEFAULT_TYPE_PARSER/1/' %s | mlir-tblgen -gen-typedef-defs -typedefs-dialect=TestDialect -I %S/../../include | FileCheck %s --check-prefix=TYPE --check-prefix=DEFAULT_TYPE_PARSER
include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/OpBase.td"
@@ -663,6 +664,78 @@ def TypeO : TestType<"TestQ"> {
let assemblyFormat = "(custom<AB>($a)^ `x`) : (`y`)?";
}
+// Test attr / type verification.
+
+// TYPE: ::llvm::LogicalResult TestPType::verifyInvariantsImpl(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::Type a) {
+// TYPE: if (!(((a.isSignlessInteger(16))) || ((a.isSignlessInteger(32))))) {
+// TYPE: emitError() << "failed to verify 'a': 16-bit signless integer or 32-bit signless integer";
+// TYPE: return ::mlir::failure();
+// TYPE: }
+// TYPE: return ::mlir::success();
+// TYPE: }
+
+// TYPE: ::llvm::LogicalResult TestPType::verifyInvariants(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::Type a) {
+// TYPE: if (::mlir::failed(verifyInvariantsImpl(emitError, a)))
+// TYPE: return ::mlir::failure();
+// TYPE: if (::mlir::failed(verify(emitError, a)))
+// TYPE: return ::mlir::failure();
+// TYPE: return ::mlir::success();
+// TYPE: }
+
+def TypeP : TestType<"TestP"> {
+ let parameters = (ins AnyTypeOf<[I16, I32]>:$a);
+ let mnemonic = "type_p";
+ let genVerifyDecl = 1;
+ let assemblyFormat = "$a";
+}
+
+// ATTR: ::llvm::LogicalResult TestRAttr::verifyInvariantsImpl(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::IntegerType a) {
+// ATTR: if (!((a.isSignlessInteger(32)))) {
+// ATTR: emitError() << "failed to verify 'a': 32-bit signless integer";
+// ATTR: return ::mlir::failure();
+// ATTR: }
+// ATTR: return ::mlir::success();
+// ATTR: }
+
+// ATTR: ::llvm::LogicalResult TestRAttr::verifyInvariants(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::IntegerType a) {
+// ATTR: if (::mlir::failed(verifyInvariantsImpl(emitError, a)))
+// ATTR: return ::mlir::failure();
+// ATTR: if (::mlir::failed(verify(emitError, a)))
+// ATTR: return ::mlir::failure();
+// ATTR: return ::mlir::success();
+// ATTR: }
+
+def AttrR : TestAttr<"TestR"> {
+ let parameters = (ins I32:$a);
+ let mnemonic = "attr_r";
+ let genVerifyDecl = 1;
+ let assemblyFormat = "$a";
+}
+
+// TYPE: ::llvm::LogicalResult TestSType::verifyInvariantsImpl(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::ArrayAttr a) {
+// TYPE: if (!((::llvm::isa<::mlir::ArrayAttr>(a)))) {
+// TYPE: emitError() << "failed to verify 'a': A collection of other Attribute values";
+// TYPE: return ::mlir::failure();
+// TYPE: }
+// TYPE: return ::mlir::success();
+// TYPE: }
+
+// TYPE: ::llvm::LogicalResult TestSType::verifyInvariants(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::ArrayAttr a) {
+// TYPE: if (::mlir::failed(verifyInvariantsImpl(emitError, a)))
+// TYPE: return ::mlir::failure();
+// TYPE: if (::mlir::failed(verify(emitError, a)))
+// TYPE: return ::mlir::failure();
+// TYPE: return ::mlir::success();
+// TYPE: }
+
+def TypeS : TestType<"TestS"> {
+ // TODO: Support attribute constraints as parameters.
+ let parameters = (ins Builtin_ArrayAttr:$a);
+ let mnemonic = "type_s";
+ let genVerifyDecl = 1;
+ let assemblyFormat = "$a";
+}
+
// DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
// DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
// DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 8cc8314418104c..71ba6a5c73da9e 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -93,8 +93,14 @@ class DefGen {
void emitDialectName();
/// Emit attribute or type builders.
void emitBuilders();
- /// Emit a verifier for the def.
- void emitVerifier();
+ /// Emit a verifier declaration for custom verification (impl. provided by
+ /// the users).
+ void emitVerifierDecl();
+ /// Emit a verifier that checks type constraints.
+ void emitInvariantsVerifierImpl();
+ /// Emit an entry point for verification that calls the invariants and
+ /// custom verifier.
+ void emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier);
/// Emit parsers and printers.
void emitParserPrinter();
/// Emit parameter accessors, if required.
@@ -188,9 +194,17 @@ DefGen::DefGen(const AttrOrTypeDef &def)
emitName();
// Emit the dialect name.
emitDialectName();
- // Emit the verifier.
- if (storageCls && def.genVerifyDecl())
- emitVerifier();
+ // Emit verification of type constraints.
+ bool genVerifyInvariantsImpl = def.genVerifyInvariantsImpl();
+ if (storageCls && genVerifyInvariantsImpl)
+ emitInvariantsVerifierImpl();
+ // Emit the custom verifier (written by the user).
+ bool genVerifyDecl = def.genVerifyDecl();
+ if (storageCls && genVerifyDecl)
+ emitVerifierDecl();
+ // Emit the "verifyInvariants" function if there is any verification at all.
+ if (storageCls)
+ emitInvariantsVerifier(genVerifyInvariantsImpl, genVerifyDecl);
// Emit the mnemonic, if there is one, and any associated parser and printer.
if (def.getMnemonic())
emitParserPrinter();
@@ -295,24 +309,88 @@ void DefGen::emitDialectName() {
void DefGen::emitBuilders() {
if (!def.skipDefaultBuilders()) {
emitDefaultBuilder();
- if (def.genVerifyDecl())
+ if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
emitCheckedBuilder();
}
for (auto &builder : def.getBuilders()) {
emitCustomBuilder(builder);
- if (def.genVerifyDecl())
+ if (def.genVerifyDecl() || def.genVerifyInvariantsImpl())
emitCheckedCustomBuilder(builder);
}
}
-void DefGen::emitVerifier() {
- defCls.declare<UsingDeclaration>("Base::getChecked");
+void DefGen::emitVerifierDecl() {
defCls.declareStaticMethod(
"::llvm::LogicalResult", "verify",
getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
"emitError"}}));
}
+static const char *const patternParameterVerificationCode = R"(
+if (!({0})) {
+ emitError() << "failed to verify '{1}': {2}";
+ return ::mlir::failure();
+}
+)";
+
+void DefGen::emitInvariantsVerifierImpl() {
+ SmallVector<MethodParameter> builderParams = getBuilderParams(
+ {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
+ Method *verifier =
+ defCls.addMethod("::llvm::LogicalResult", "verifyInvariantsImpl",
+ Method::Static, builderParams);
+ verifier->body().indent();
+
+ // Generate verification for each parameter that is a type constraint.
+ for (auto it : llvm::enumerate(def.getParameters())) {
+ const AttrOrTypeParameter &param = it.value();
+ std::optional<Constraint> constraint = param.getConstraint();
+ // No verification needed for parameters that are not type constraints.
+ if (!constraint.has_value())
+ continue;
+ FmtContext ctx;
+ // Note: Skip over the first method parameter (`emitError`).
+ ctx.withSelf(builderParams[it.index() + 1].getName());
+ std::string condition = tgfmt(constraint->getConditionTemplate(), &ctx);
+ verifier->body() << formatv(patternParameterVerificationCode, condition,
+ param.getName(), constraint->getSummary())
+ << "\n";
+ }
+ verifier->body() << "return ::mlir::success();";
+}
+
+void DefGen::emitInvariantsVerifier(bool hasImpl, bool hasCustomVerifier) {
+ if (!hasImpl && !hasCustomVerifier)
+ return;
+ defCls.declare<UsingDeclaration>("Base::getChecked");
+ SmallVector<MethodParameter> builderParams = getBuilderParams(
+ {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}});
+ Method *verifier =
+ defCls.addMethod("::llvm::LogicalResult", "verifyInvariants",
+ Method::Static, builderParams);
+ verifier->body().indent();
+
+ auto emitVerifierCall = [&](StringRef name) {
+ verifier->body() << strfmt("if (::mlir::failed({0}(", name);
+ llvm::interleaveComma(
+ llvm::map_range(builderParams,
+ [](auto &param) { return param.getName(); }),
+ verifier->body());
+ verifier->body() << ")))\n";
+ verifier->body() << " return ::mlir::failure();\n";
+ };
+
+ if (hasImpl) {
+ // Call the verifier that checks the type constraints.
+ emitVerifierCall("verifyInvariantsImpl");
+ }
+ if (hasCustomVerifier) {
+ // Call the custom verifier that is provided by the user.
+ emitVerifierCall("verify");
+ }
+ verifier->body() << "return ::mlir::success();";
+}
+
void DefGen::emitParserPrinter() {
auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
"::llvm::StringLiteral", "getMnemonic");
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index dacc20b6ba2086..a4ae271edb6bd2 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -323,7 +323,7 @@ void DefFormat::genParser(MethodBody &os) {
// Generate call to the attribute or type builder. Use the checked getter
// if one was generated.
- if (def.genVerifyDecl()) {
+ if (def.genVerifyDecl() || def.genVerifyInvariantsImpl()) {
os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
&ctx, def.getCppClassName());
} else {
More information about the Mlir-commits
mailing list