[Mlir-commits] [mlir] 7d4aa1f - [mlir][IR] Auto-generate element type verification for VectorType (#102449)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Aug 11 23:02:17 PDT 2024


Author: Matthias Springer
Date: 2024-08-12T08:02:14+02:00
New Revision: 7d4aa1ff6bab27b5442f4765336fa827479d7bbc

URL: https://github.com/llvm/llvm-project/commit/7d4aa1ff6bab27b5442f4765336fa827479d7bbc
DIFF: https://github.com/llvm/llvm-project/commit/7d4aa1ff6bab27b5442f4765336fa827479d7bbc.diff

LOG: [mlir][IR] Auto-generate element type verification for VectorType (#102449)

#102326 enables verification of type parameters that are type
constraints. The element type verification for `VectorType` (and maybe
other builtin types in the future) can now be auto-generated.

Also remove redundant error checking in the vector type parser: element
type and dimensions are already checked by the verifier (which is called
from `getChecked`).

Depends on #102326.

Added: 
    

Modified: 
    mlir/include/mlir/IR/BuiltinTypes.td
    mlir/lib/AsmParser/TypeParser.cpp
    mlir/test/IR/invalid-builtin-types.mlir
    mlir/test/python/ir/builtin_types.py

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 365edcf68d8b94..4b3add2035263c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,6 +17,7 @@
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/CommonTypeConstraints.td"
 
 // TODO: Currently the types defined in this file are prefixed with `Builtin_`.
 // This is to differentiate the types here with the ones in OpBase.td. We should
@@ -1146,7 +1147,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    "Type":$elementType,
+    AnyTypeOf<[AnyInteger, Index, AnyFloat]>:$elementType,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
@@ -1173,6 +1174,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
     /// type. In particular, vectors can consist of integer, index, or float
     /// primitives.
     static bool isValidElementType(Type t) {
+      // TODO: Auto-generate this function from $elementType.
       return ::llvm::isa<IntegerType, IndexType, FloatType>(t);
     }
 

diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 542eaeefe57f12..f070c072c43296 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -458,31 +458,24 @@ Type Parser::parseTupleType() {
 /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
 ///
 VectorType Parser::parseVectorType() {
+  SMLoc loc = getToken().getLoc();
   consumeToken(Token::kw_vector);
 
   if (parseToken(Token::less, "expected '<' in vector type"))
     return nullptr;
 
+  // Parse the dimensions.
   SmallVector<int64_t, 4> dimensions;
   SmallVector<bool, 4> scalableDims;
   if (parseVectorDimensionList(dimensions, scalableDims))
     return nullptr;
-  if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
-    return emitError(getToken().getLoc(),
-                     "vector types must have positive constant sizes"),
-           nullptr;
 
   // Parse the element type.
-  auto typeLoc = getToken().getLoc();
   auto elementType = parseType();
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  if (!VectorType::isValidElementType(elementType))
-    return emitError(typeLoc, "vector elements must be int/index/float type"),
-           nullptr;
-
-  return VectorType::get(dimensions, elementType, scalableDims);
+  return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.

diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 9884212e916c1f..07854a25000feb 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -120,17 +120,17 @@ func.func @illegaltype(i21312312323120) // expected-error {{invalid integer widt
 // -----
 
 // Test no nested vector.
-// expected-error@+1 {{vector elements must be int/index/float type}}
+// expected-error@+1 {{failed to verify 'elementType': integer or index or floating-point}}
 func.func @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>)
 
 // -----
 
-// expected-error @+1 {{vector types must have positive constant sizes}}
+// expected-error @+1 {{vector types must have positive constant sizes but got 0}}
 func.func @zero_vector_type() -> vector<0xi32>
 
 // -----
 
-// expected-error @+1 {{vector types must have positive constant sizes}}
+// expected-error @+1 {{vector types must have positive constant sizes but got 1, 0}}
 func.func @zero_in_vector_type() -> vector<1x0xi32>
 
 // -----

diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
index 2161f110ac31e2..f95cccc54105ed 100644
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -345,7 +345,7 @@ def testVectorType():
             VectorType.get(shape, none)
         except MLIRError as e:
             # CHECK: Invalid type:
-            # CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
+            # CHECK: error: unknown: failed to verify 'elementType': integer or index or floating-point
             print(e)
         else:
             print("Exception not produced")


        


More information about the Mlir-commits mailing list