[llvm-branch-commits] [mlir] [mlir][IR] Auto-generate element type verification for VectorType (PR #102449)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Aug 8 03:42:25 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
#102326 enables verification of type parameters that are type constraints. As a result, the element type verification for `VectorType` (and possibly other builtin types in the future) can now be auto-generated.
Also remove redundant error checking in the vector type parser: the element type and dimensions are already checked by the verifier (which is called from `getChecked`).
Depends on #102326.
---
Full diff: https://github.com/llvm/llvm-project/pull/102449.diff
3 Files Affected:
- (modified) mlir/include/mlir/IR/BuiltinTypes.td (+3-1)
- (modified) mlir/lib/AsmParser/TypeParser.cpp (+3-10)
- (modified) mlir/test/IR/invalid-builtin-types.mlir (+3-3)
``````````diff
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 365edcf68d8b94..4b3add2035263c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,6 +17,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/CommonTypeConstraints.td"
// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
// This is to differentiate the types here with the ones in OpBase.td. We should
@@ -1146,7 +1147,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
- "Type":$elementType,
+ AnyTypeOf<[AnyInteger, Index, AnyFloat]>:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
@@ -1173,6 +1174,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
/// type. In particular, vectors can consist of integer, index, or float
/// primitives.
static bool isValidElementType(Type t) {
+ // TODO: Auto-generate this function from $elementType.
return ::llvm::isa<IntegerType, IndexType, FloatType>(t);
}
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 542eaeefe57f12..f070c072c43296 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -458,31 +458,24 @@ Type Parser::parseTupleType() {
/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
///
VectorType Parser::parseVectorType() {
+ SMLoc loc = getToken().getLoc();
consumeToken(Token::kw_vector);
if (parseToken(Token::less, "expected '<' in vector type"))
return nullptr;
+ // Parse the dimensions.
SmallVector<int64_t, 4> dimensions;
SmallVector<bool, 4> scalableDims;
if (parseVectorDimensionList(dimensions, scalableDims))
return nullptr;
- if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
- return emitError(getToken().getLoc(),
- "vector types must have positive constant sizes"),
- nullptr;
// Parse the element type.
- auto typeLoc = getToken().getLoc();
auto elementType = parseType();
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
- if (!VectorType::isValidElementType(elementType))
- return emitError(typeLoc, "vector elements must be int/index/float type"),
- nullptr;
-
- return VectorType::get(dimensions, elementType, scalableDims);
+ return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 9884212e916c1f..07854a25000feb 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -120,17 +120,17 @@ func.func @illegaltype(i21312312323120) // expected-error {{invalid integer widt
// -----
// Test no nested vector.
-// expected-error @+1 {{vector elements must be int/index/float type}}
+// expected-error @+1 {{failed to verify 'elementType': integer or index or floating-point}}
func.func @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>)
// -----
-// expected-error @+1 {{vector types must have positive constant sizes}}
+// expected-error @+1 {{vector types must have positive constant sizes but got 0}}
func.func @zero_vector_type() -> vector<0xi32>
// -----
-// expected-error @+1 {{vector types must have positive constant sizes}}
+// expected-error @+1 {{vector types must have positive constant sizes but got 1, 0}}
func.func @zero_in_vector_type() -> vector<1x0xi32>
// -----
``````````
</details>
https://github.com/llvm/llvm-project/pull/102449
More information about the llvm-branch-commits mailing list