[Mlir-commits] [mlir] bae1517 - [mlir] Add verification to LLVM dialect types

Alex Zinenko llvmlistbot at llvm.org
Tue Aug 11 08:22:02 PDT 2020


Author: Alex Zinenko
Date: 2020-08-11T17:21:52+02:00
New Revision: bae1517266bf4ce85a32390323fd463f28ae9d0c

URL: https://github.com/llvm/llvm-project/commit/bae1517266bf4ce85a32390323fd463f28ae9d0c
DIFF: https://github.com/llvm/llvm-project/commit/bae1517266bf4ce85a32390323fd463f28ae9d0c.diff

LOG: [mlir] Add verification to LLVM dialect types

Now that LLVM dialect types are implemented directly in the dialect, we can use
MLIR hooks for verifying type construction invariants. Implement the verifiers
and use them in the parser.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D85663

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/test/Dialect/LLVMIR/types-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 449ade85a5b4..38a70d510e50 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -142,7 +142,6 @@ class LLVMType : public Type {
   LLVMType getPointerTo(unsigned addrSpace = 0);
   LLVMType getPointerElementTy();
   bool isPointerTy();
-  static bool isValidPointerElementType(LLVMType type);
 
   /// Struct type utilities.
   LLVMType getStructElementType(unsigned i);
@@ -287,15 +286,25 @@ class LLVMArrayType : public Type::TypeBase<LLVMArrayType, LLVMType,
   /// Inherit base constructors.
   using Base::Base;
 
+  /// Checks if the given type can be used inside an array type.
+  static bool isValidElementType(LLVMType type);
+
   /// Gets or creates an instance of LLVM dialect array type containing
   /// `numElements` of `elementType`, in the same context as `elementType`.
   static LLVMArrayType get(LLVMType elementType, unsigned numElements);
+  static LLVMArrayType getChecked(Location loc, LLVMType elementType,
+                                  unsigned numElements);
 
   /// Returns the element type of the array.
   LLVMType getElementType();
 
   /// Returns the number of elements in the array type.
   unsigned getNumElements();
+
+  /// Verifies that the type about to be constructed is well-formed.
+  static LogicalResult verifyConstructionInvariants(Location loc,
+                                                    LLVMType elementType,
+                                                    unsigned numElements);
 };
 
 //===----------------------------------------------------------------------===//
@@ -312,10 +321,22 @@ class LLVMFunctionType
   /// Inherit base constructors.
   using Base::Base;
 
+  /// Checks if the given type can be used as an argument in a function type.
+  static bool isValidArgumentType(LLVMType type);
+
+  /// Checks if the given type can be used as a result in a function type.
+  static bool isValidResultType(LLVMType type);
+
+  /// Returns whether the function is variadic.
+  bool isVarArg();
+
   /// Gets or creates an instance of LLVM dialect function in the same context
   /// as the `result` type.
   static LLVMFunctionType get(LLVMType result, ArrayRef<LLVMType> arguments,
                               bool isVarArg = false);
+  static LLVMFunctionType getChecked(Location loc, LLVMType result,
+                                     ArrayRef<LLVMType> arguments,
+                                     bool isVarArg = false);
 
   /// Returns the result type of the function.
   LLVMType getReturnType();
@@ -326,12 +347,14 @@ class LLVMFunctionType
   /// Returns `i`-th argument of the function. Asserts on out-of-bounds.
   LLVMType getParamType(unsigned i);
 
-  /// Returns whether the function is variadic.
-  bool isVarArg();
-
   /// Returns a list of argument types of the function.
   ArrayRef<LLVMType> getParams();
   ArrayRef<LLVMType> params() { return getParams(); }
+
+  /// Verifies that the type about to be constructed is well-formed.
+  static LogicalResult
+  verifyConstructionInvariants(Location loc, LLVMType result,
+                               ArrayRef<LLVMType> arguments, bool);
 };
 
 //===----------------------------------------------------------------------===//
@@ -348,9 +371,14 @@ class LLVMIntegerType : public Type::TypeBase<LLVMIntegerType, LLVMType,
   /// Gets or creates an instance of the integer of the specified `bitwidth` in
   /// the given context.
   static LLVMIntegerType get(MLIRContext *ctx, unsigned bitwidth);
+  static LLVMIntegerType getChecked(Location loc, unsigned bitwidth);
 
   /// Returns the bitwidth of this integer type.
   unsigned getBitWidth();
+
+  /// Verifies that the type about to be constructed is well-formed.
+  static LogicalResult verifyConstructionInvariants(Location loc,
+                                                    unsigned bitwidth);
 };
 
 //===----------------------------------------------------------------------===//
@@ -366,16 +394,25 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, LLVMType,
   /// Inherit base constructors.
   using Base::Base;
 
+  /// Checks if the given type can have a pointer type pointing to it.
+  static bool isValidElementType(LLVMType type);
+
   /// Gets or creates an instance of LLVM dialect pointer type pointing to an
   /// object of `pointee` type in the given address space. The pointer type is
   /// created in the same context as `pointee`.
   static LLVMPointerType get(LLVMType pointee, unsigned addressSpace = 0);
+  static LLVMPointerType getChecked(Location loc, LLVMType pointee,
+                                    unsigned addressSpace = 0);
 
   /// Returns the pointed-to type.
   LLVMType getElementType();
 
   /// Returns the address space of the pointer.
   unsigned getAddressSpace();
+
+  /// Verifies that the type about to be constructed is well-formed.
+  static LogicalResult verifyConstructionInvariants(Location loc,
+                                                    LLVMType pointee, unsigned);
 };
 
 //===----------------------------------------------------------------------===//
@@ -412,18 +449,25 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
   /// Inherit base construtors.
   using Base::Base;
 
+  /// Checks if the given type can be contained in a structure type.
+  static bool isValidElementType(LLVMType type);
+
   /// Gets or creates an identified struct with the given name in the provided
   /// context. Note that unlike llvm::StructType::create, this function will
   /// _NOT_ rename a struct in case a struct with the same name already exists
   /// in the context. Instead, it will just return the existing struct,
   /// similarly to the rest of MLIR type ::get methods.
   static LLVMStructType getIdentified(MLIRContext *context, StringRef name);
+  static LLVMStructType getIdentifiedChecked(Location loc, StringRef name);
 
   /// Gets or creates a literal struct with the given body in the provided
   /// context.
   static LLVMStructType getLiteral(MLIRContext *context,
                                    ArrayRef<LLVMType> types,
                                    bool isPacked = false);
+  static LLVMStructType getLiteralChecked(Location loc,
+                                          ArrayRef<LLVMType> types,
+                                          bool isPacked = false);
 
   /// Gets or creates an intentionally-opaque identified struct. Such a struct
   /// cannot have its body set. To create an opaque struct with a mutable body,
@@ -432,6 +476,7 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
   /// already exists in the context. Instead, it will just return the existing
   /// struct, similarly to the rest of MLIR type ::get methods.
   static LLVMStructType getOpaque(StringRef name, MLIRContext *context);
+  static LLVMStructType getOpaqueChecked(Location loc, StringRef name);
 
   /// Set the body of an identified struct. Returns failure if the body could
   /// not be set, e.g. if the struct already has a body or if it was marked as
@@ -458,6 +503,11 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
 
   /// Returns the list of element types contained in a non-opaque struct.
   ArrayRef<LLVMType> getBody();
+
+  /// Verifies that the type about to be constructed is well-formed.
+  static LogicalResult verifyConstructionInvariants(Location, StringRef, bool);
+  static LogicalResult
+  verifyConstructionInvariants(Location loc, ArrayRef<LLVMType> types, bool);
 };
 
 //===----------------------------------------------------------------------===//
@@ -475,11 +525,19 @@ class LLVMVectorType : public LLVMType {
   /// Support type casting functionality.
   static bool classof(Type type);
 
+  /// Checks if the given type can be used in a vector type.
+  static bool isValidElementType(LLVMType type);
+
   /// Returns the element type of the vector.
   LLVMType getElementType();
 
   /// Returns the number of elements in the vector.
   llvm::ElementCount getElementCount();
+
+  /// Verifies that the type about to be constructed is well-formed.
+  static LogicalResult verifyConstructionInvariants(Location loc,
+                                                    LLVMType elementType,
+                                                    unsigned numElements);
 };
 
 //===----------------------------------------------------------------------===//
@@ -494,10 +552,13 @@ class LLVMFixedVectorType
 public:
   /// Inherit base constructor.
   using Base::Base;
+  using LLVMVectorType::verifyConstructionInvariants;
 
   /// Gets or creates a fixed vector type containing `numElements` of
   /// `elementType` in the same context as `elementType`.
   static LLVMFixedVectorType get(LLVMType elementType, unsigned numElements);
+  static LLVMFixedVectorType getChecked(Location loc, LLVMType elementType,
+                                        unsigned numElements);
 
   /// Returns the number of elements in the fixed vector.
   unsigned getNumElements();
@@ -516,11 +577,14 @@ class LLVMScalableVectorType
 public:
   /// Inherit base constructor.
   using Base::Base;
+  using LLVMVectorType::verifyConstructionInvariants;
 
   /// Gets or creates a scalable vector type containing a non-zero multiple of
   /// `minNumElements` of `elementType` in the same context as `elementType`.
   static LLVMScalableVectorType get(LLVMType elementType,
                                     unsigned minNumElements);
+  static LLVMScalableVectorType getChecked(Location loc, LLVMType elementType,
+                                           unsigned minNumElements);
 
   /// Returns the scaling factor of the number of elements in the vector. The
   /// vector contains at least the resulting number of elements, or any non-zero

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 515e120888b4..f1299c065d36 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -1113,7 +1113,7 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
 }
 
 static LogicalResult verify(GlobalOp op) {
-  if (!LLVMType::isValidPointerElementType(op.getType()))
+  if (!LLVMPointerType::isValidElementType(op.getType()))
     return op.emitOpError(
         "expects type to be a valid element type for an LLVM pointer");
   if (op.getParentOp() && !satisfiesLLVMModule(op.getParentOp()))

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 7df3ebe7c5b4..59ae48ebf3e7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -217,6 +217,7 @@ static ParseResult parseTypeImpl(DialectAsmParser &parser,
 ///   llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
 static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
                                           llvm::SetVector<StringRef> &stack) {
+  Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   LLVMType returnType;
   if (parser.parseLess() || parseTypeImpl(parser, stack, returnType) ||
       parser.parseLParen())
@@ -225,7 +226,8 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
   // Function type without arguments.
   if (succeeded(parser.parseOptionalRParen())) {
     if (succeeded(parser.parseGreater()))
-      return LLVMFunctionType::get(returnType, {}, /*isVarArg=*/false);
+      return LLVMFunctionType::getChecked(loc, returnType, {},
+                                          /*isVarArg=*/false);
     return LLVMFunctionType();
   }
 
@@ -235,7 +237,8 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
     if (succeeded(parser.parseOptionalEllipsis())) {
       if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
         return LLVMFunctionType();
-      return LLVMFunctionType::get(returnType, argTypes, /*isVarArg=*/true);
+      return LLVMFunctionType::getChecked(loc, returnType, argTypes,
+                                          /*isVarArg=*/true);
     }
 
     argTypes.push_back(parseTypeImpl(parser, stack));
@@ -245,13 +248,15 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
 
   if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
     return LLVMFunctionType();
-  return LLVMFunctionType::get(returnType, argTypes, /*isVarArg=*/false);
+  return LLVMFunctionType::getChecked(loc, returnType, argTypes,
+                                      /*isVarArg=*/false);
 }
 
 /// Parses an LLVM dialect pointer type.
 ///   llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
 static LLVMPointerType parsePointerType(DialectAsmParser &parser,
                                         llvm::SetVector<StringRef> &stack) {
+  Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   LLVMType elementType;
   if (parser.parseLess() || parseTypeImpl(parser, stack, elementType))
     return LLVMPointerType();
@@ -262,7 +267,7 @@ static LLVMPointerType parsePointerType(DialectAsmParser &parser,
     return LLVMPointerType();
   if (failed(parser.parseGreater()))
     return LLVMPointerType();
-  return LLVMPointerType::get(elementType, addressSpace);
+  return LLVMPointerType::getChecked(loc, elementType, addressSpace);
 }
 
 /// Parses an LLVM dialect vector type.
@@ -273,6 +278,7 @@ static LLVMVectorType parseVectorType(DialectAsmParser &parser,
   SmallVector<int64_t, 2> dims;
   llvm::SMLoc dimPos;
   LLVMType elementType;
+  Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
       parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
       parseTypeImpl(parser, stack, elementType) || parser.parseGreater())
@@ -291,9 +297,9 @@ static LLVMVectorType parseVectorType(DialectAsmParser &parser,
   }
 
   bool isScalable = dims.size() == 2;
-  return isScalable ? static_cast<LLVMVectorType>(
-                          LLVMScalableVectorType::get(elementType, dims[1]))
-                    : LLVMFixedVectorType::get(elementType, dims[0]);
+  if (isScalable)
+    return LLVMScalableVectorType::getChecked(loc, elementType, dims[1]);
+  return LLVMFixedVectorType::getChecked(loc, elementType, dims[0]);
 }
 
 /// Parses an LLVM dialect array type.
@@ -303,6 +309,7 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser,
   SmallVector<int64_t, 1> dims;
   llvm::SMLoc sizePos;
   LLVMType elementType;
+  Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
       parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
       parseTypeImpl(parser, stack, elementType) || parser.parseGreater())
@@ -313,7 +320,7 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser,
     return LLVMArrayType();
   }
 
-  return LLVMArrayType::get(elementType, dims[0]);
+  return LLVMArrayType::getChecked(loc, elementType, dims[0]);
 }
 
 /// Attempts to set the body of an identified structure type. Reports a parsing
@@ -324,6 +331,14 @@ static LLVMStructType trySetStructBody(LLVMStructType type,
                                        bool isPacked, DialectAsmParser &parser,
                                        llvm::SMLoc subtypesLoc,
                                        llvm::SetVector<StringRef> &stack) {
+  for (LLVMType t : subtypes) {
+    if (!LLVMStructType::isValidElementType(t)) {
+      parser.emitError(subtypesLoc)
+          << "invalid LLVM structure element type: " << t;
+      return LLVMStructType();
+    }
+  }
+
   if (succeeded(type.setBody(subtypes, isPacked)))
     return type;
 
@@ -343,7 +358,7 @@ static LLVMStructType trySetStructBody(LLVMStructType type,
 ///               | `struct<` string-literal `, opaque>`
 static LLVMStructType parseStructType(DialectAsmParser &parser,
                                       llvm::SetVector<StringRef> &stack) {
-  MLIRContext *ctx = parser.getBuilder().getContext();
+  Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
 
   if (failed(parser.parseLess()))
     return LLVMStructType();
@@ -357,7 +372,7 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
     if (stack.count(name)) {
       if (failed(parser.parseGreater()))
         return LLVMStructType();
-      return LLVMStructType::getIdentified(ctx, name);
+      return LLVMStructType::getIdentifiedChecked(loc, name);
     }
     if (failed(parser.parseComma()))
       return LLVMStructType();
@@ -371,7 +386,7 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
              LLVMStructType();
     if (failed(parser.parseGreater()))
       return LLVMStructType();
-    auto type = LLVMStructType::getOpaque(name, ctx);
+    auto type = LLVMStructType::getOpaqueChecked(loc, name);
     if (!type.isOpaque()) {
       parser.emitError(kwLoc, "redeclaring defined struct as opaque");
       return LLVMStructType();
@@ -389,8 +404,8 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
     if (failed(parser.parseGreater()))
       return LLVMStructType();
     if (!isIdentified)
-      return LLVMStructType::getLiteral(ctx, {}, isPacked);
-    auto type = LLVMStructType::getIdentified(ctx, name);
+      return LLVMStructType::getLiteralChecked(loc, {}, isPacked);
+    auto type = LLVMStructType::getIdentifiedChecked(loc, name);
     return trySetStructBody(type, {}, isPacked, parser, kwLoc, stack);
   }
 
@@ -414,8 +429,8 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
 
   // Construct the struct with body.
   if (!isIdentified)
-    return LLVMStructType::getLiteral(ctx, subtypes, isPacked);
-  auto type = LLVMStructType::getIdentified(ctx, name);
+    return LLVMStructType::getLiteralChecked(loc, subtypes, isPacked);
+  auto type = LLVMStructType::getIdentifiedChecked(loc, name);
   return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc, stack);
 }
 
@@ -428,6 +443,7 @@ static LLVMType parseTypeImpl(DialectAsmParser &parser,
   Type maybeIntegerType;
   MLIRContext *ctx = parser.getBuilder().getContext();
   llvm::SMLoc keyLoc = parser.getCurrentLocation();
+  Location loc = parser.getEncodedSourceLoc(keyLoc);
   OptionalParseResult result = parser.parseOptionalType(maybeIntegerType);
   if (result.hasValue()) {
     if (failed(*result))
@@ -437,7 +453,8 @@ static LLVMType parseTypeImpl(DialectAsmParser &parser,
       parser.emitError(keyLoc) << "unexpected type, expected i* or keyword";
       return LLVMType();
     }
-    return LLVMIntegerType::get(ctx, maybeIntegerType.getIntOrFloatBitWidth());
+    return LLVMIntegerType::getChecked(
+        loc, maybeIntegerType.getIntOrFloatBitWidth());
   }
 
   // Dispatch to concrete functions.

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 7d052e3d644e..8ff7fc56edda 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -110,11 +110,6 @@ LLVMType LLVMType::getPointerElementTy() {
 
 bool LLVMType::isPointerTy() { return isa<LLVMPointerType>(); }
 
-bool LLVMType::isValidPointerElementType(LLVMType type) {
-  return !type.isa<LLVMVoidType>() && !type.isa<LLVMTokenType>() &&
-         !type.isa<LLVMMetadataType>() && !type.isa<LLVMLabelType>();
-}
-
 //----------------------------------------------------------------------------//
 // Struct type utilities.
 
@@ -227,19 +222,46 @@ LLVMType LLVMType::setStructTyBody(LLVMType structType,
 //===----------------------------------------------------------------------===//
 // Array type.
 
+bool LLVMArrayType::isValidElementType(LLVMType type) {
+  return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
+                   LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
+}
+
 LLVMArrayType LLVMArrayType::get(LLVMType elementType, unsigned numElements) {
   assert(elementType && "expected non-null subtype");
   return Base::get(elementType.getContext(), LLVMType::ArrayType, elementType,
                    numElements);
 }
 
+LLVMArrayType LLVMArrayType::getChecked(Location loc, LLVMType elementType,
+                                        unsigned numElements) {
+  assert(elementType && "expected non-null subtype");
+  return Base::getChecked(loc, LLVMType::ArrayType, elementType, numElements);
+}
+
 LLVMType LLVMArrayType::getElementType() { return getImpl()->elementType; }
 
 unsigned LLVMArrayType::getNumElements() { return getImpl()->numElements; }
 
+LogicalResult
+LLVMArrayType::verifyConstructionInvariants(Location loc, LLVMType elementType,
+                                            unsigned numElements) {
+  if (!isValidElementType(elementType))
+    return emitError(loc, "invalid array element type: ") << elementType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Function type.
 
+bool LLVMFunctionType::isValidArgumentType(LLVMType type) {
+  return !type.isa<LLVMVoidType, LLVMFunctionType>();
+}
+
+bool LLVMFunctionType::isValidResultType(LLVMType type) {
+  return !type.isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>();
+}
+
 LLVMFunctionType LLVMFunctionType::get(LLVMType result,
                                        ArrayRef<LLVMType> arguments,
                                        bool isVarArg) {
@@ -248,6 +270,14 @@ LLVMFunctionType LLVMFunctionType::get(LLVMType result,
                    arguments, isVarArg);
 }
 
+LLVMFunctionType LLVMFunctionType::getChecked(Location loc, LLVMType result,
+                                              ArrayRef<LLVMType> arguments,
+                                              bool isVarArg) {
+  assert(result && "expected non-null result");
+  return Base::getChecked(loc, LLVMType::FunctionType, result, arguments,
+                          isVarArg);
+}
+
 LLVMType LLVMFunctionType::getReturnType() {
   return getImpl()->getReturnType();
 }
@@ -266,6 +296,18 @@ ArrayRef<LLVMType> LLVMFunctionType::getParams() {
   return getImpl()->getArgumentTypes();
 }
 
+LogicalResult LLVMFunctionType::verifyConstructionInvariants(
+    Location loc, LLVMType result, ArrayRef<LLVMType> arguments, bool) {
+  if (!isValidResultType(result))
+    return emitError(loc, "invalid function result type: ") << result;
+
+  for (LLVMType arg : arguments)
+    if (!isValidArgumentType(arg))
+      return emitError(loc, "invalid function argument type: ") << arg;
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Integer type.
 
@@ -273,41 +315,93 @@ LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) {
   return Base::get(ctx, LLVMType::IntegerType, bitwidth);
 }
 
+LLVMIntegerType LLVMIntegerType::getChecked(Location loc, unsigned bitwidth) {
+  return Base::getChecked(loc, LLVMType::IntegerType, bitwidth);
+}
+
 unsigned LLVMIntegerType::getBitWidth() { return getImpl()->bitwidth; }
 
+LogicalResult LLVMIntegerType::verifyConstructionInvariants(Location loc,
+                                                            unsigned bitwidth) {
+  constexpr int maxSupportedBitwidth = (1 << 24);
+  if (bitwidth >= maxSupportedBitwidth)
+    return emitError(loc, "integer type too wide");
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pointer type.
 
+bool LLVMPointerType::isValidElementType(LLVMType type) {
+  return !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
+                   LLVMLabelType>();
+}
+
 LLVMPointerType LLVMPointerType::get(LLVMType pointee, unsigned addressSpace) {
   assert(pointee && "expected non-null subtype");
   return Base::get(pointee.getContext(), LLVMType::PointerType, pointee,
                    addressSpace);
 }
 
+LLVMPointerType LLVMPointerType::getChecked(Location loc, LLVMType pointee,
+                                            unsigned addressSpace) {
+  return Base::getChecked(loc, LLVMType::PointerType, pointee, addressSpace);
+}
+
 LLVMType LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
 
 unsigned LLVMPointerType::getAddressSpace() { return getImpl()->addressSpace; }
 
+LogicalResult LLVMPointerType::verifyConstructionInvariants(Location loc,
+                                                            LLVMType pointee,
+                                                            unsigned) {
+  if (!isValidElementType(pointee))
+    return emitError(loc, "invalid pointer element type: ") << pointee;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Struct type.
 
+bool LLVMStructType::isValidElementType(LLVMType type) {
+  return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
+                   LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
+}
+
 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
                                              StringRef name) {
   return Base::get(context, LLVMType::StructType, name, /*opaque=*/false);
 }
 
+LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc,
+                                                    StringRef name) {
+  return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/false);
+}
+
 LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
                                           ArrayRef<LLVMType> types,
                                           bool isPacked) {
   return Base::get(context, LLVMType::StructType, types, isPacked);
 }
 
+LLVMStructType LLVMStructType::getLiteralChecked(Location loc,
+                                                 ArrayRef<LLVMType> types,
+                                                 bool isPacked) {
+  return Base::getChecked(loc, LLVMType::StructType, types, isPacked);
+}
+
 LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) {
   return Base::get(context, LLVMType::StructType, name, /*opaque=*/true);
 }
 
+LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) {
+  return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/true);
+}
+
 LogicalResult LLVMStructType::setBody(ArrayRef<LLVMType> types, bool isPacked) {
   assert(isIdentified() && "can only set bodies of identified structs");
+  assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
+         "expected valid body types");
   return Base::mutate(types, isPacked);
 }
 
@@ -323,9 +417,29 @@ ArrayRef<LLVMType> LLVMStructType::getBody() {
                         : getImpl()->getTypeList();
 }
 
+LogicalResult LLVMStructType::verifyConstructionInvariants(Location, StringRef,
+                                                           bool) {
+  return success();
+}
+
+LogicalResult
+LLVMStructType::verifyConstructionInvariants(Location loc,
+                                             ArrayRef<LLVMType> types, bool) {
+  for (LLVMType t : types)
+    if (!isValidElementType(t))
+      return emitError(loc, "invalid LLVM structure element type: ") << t;
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Vector types.
 
+bool LLVMVectorType::isValidElementType(LLVMType type) {
+  return type.isa<LLVMIntegerType, LLVMPointerType>() ||
+         type.isFloatingPointTy();
+}
+
 /// Support type casting functionality.
 bool LLVMVectorType::classof(Type type) {
   return type.isa<LLVMFixedVectorType, LLVMScalableVectorType>();
@@ -343,12 +457,32 @@ llvm::ElementCount LLVMVectorType::getElementCount() {
       isa<LLVMScalableVectorType>());
 }
 
+/// Verifies that the type about to be constructed is well-formed.
+LogicalResult
+LLVMVectorType::verifyConstructionInvariants(Location loc, LLVMType elementType,
+                                             unsigned numElements) {
+  if (numElements == 0)
+    return emitError(loc, "the number of vector elements must be positive");
+
+  if (!isValidElementType(elementType))
+    return emitError(loc, "invalid vector element type");
+
+  return success();
+}
+
 LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType,
                                              unsigned numElements) {
   assert(elementType && "expected non-null subtype");
   return Base::get(elementType.getContext(), LLVMType::FixedVectorType,
-                   elementType, numElements)
-      .cast<LLVMFixedVectorType>();
+                   elementType, numElements);
+}
+
+LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc,
+                                                    LLVMType elementType,
+                                                    unsigned numElements) {
+  assert(elementType && "expected non-null subtype");
+  return Base::getChecked(loc, LLVMType::FixedVectorType, elementType,
+                          numElements);
 }
 
 unsigned LLVMFixedVectorType::getNumElements() {
@@ -359,8 +493,15 @@ LLVMScalableVectorType LLVMScalableVectorType::get(LLVMType elementType,
                                                    unsigned minNumElements) {
   assert(elementType && "expected non-null subtype");
   return Base::get(elementType.getContext(), LLVMType::ScalableVectorType,
-                   elementType, minNumElements)
-      .cast<LLVMScalableVectorType>();
+                   elementType, minNumElements);
+}
+
+LLVMScalableVectorType
+LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType,
+                                   unsigned minNumElements) {
+  assert(elementType && "expected non-null subtype");
+  return Base::getChecked(loc, LLVMType::ScalableVectorType, elementType,
+                          minNumElements);
 }
 
 unsigned LLVMScalableVectorType::getMinNumElements() {

diff  --git a/mlir/test/Dialect/LLVMIR/types-invalid.mlir b/mlir/test/Dialect/LLVMIR/types-invalid.mlir
index f7a75da46775..29da75aae584 100644
--- a/mlir/test/Dialect/LLVMIR/types-invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/types-invalid.mlir
@@ -1,5 +1,33 @@
 // RUN: mlir-opt --allow-unregistered-dialect -split-input-file -verify-diagnostics %s
 
+func @array_of_void() {
+  // expected-error @+1 {{invalid array element type}}
+  "some.op"() : () -> !llvm.array<4 x void>
+}
+
+// -----
+
+func @function_returning_function() {
+  // expected-error @+1 {{invalid function result type}}
+  "some.op"() : () -> !llvm.func<func<void ()> ()>
+}
+
+// -----
+
+func @function_taking_function() {
+  // expected-error @+1 {{invalid function argument type}}
+  "some.op"() : () -> !llvm.func<void (func<void ()>)>
+}
+
+// -----
+
+func @void_pointer() {
+  // expected-error @+1 {{invalid pointer element type}}
+  "some.op"() : () -> !llvm.ptr<void>
+}
+
+// -----
+
 func @repeated_struct_name() {
   "some.op"() : () -> !llvm.struct<"a", (ptr<struct<"a">>)>
   // expected-error @+2 {{identified type already used with a different body}}
@@ -74,6 +102,20 @@ func @explicitly_opaque_struct() {
 
 // -----
 
+func @literal_struct_with_void() {
+  // expected-error @+1 {{invalid LLVM structure element type}}
+  "some.op"() : () -> !llvm.struct<(void)>
+}
+
+// -----
+
+func @identified_struct_with_void() {
+  // expected-error @+1 {{invalid LLVM structure element type}}
+  "some.op"() : () -> !llvm.struct<"a", (void)>
+}
+
+// -----
+
 func @dynamic_vector() {
   // expected-error @+1 {{expected '? x <integer> x <type>' or '<integer> x <type>'}}
   "some.op"() : () -> !llvm.vec<? x float>
@@ -93,3 +135,23 @@ func @unscalable_vector() {
   "some.op"() : () -> !llvm.vec<4 x 4 x i32>
 }
 
+// -----
+
+func @zero_vector() {
+  // expected-error @+1 {{the number of vector elements must be positive}}
+  "some.op"() : () -> !llvm.vec<0 x i32>
+}
+
+// -----
+
+func @nested_vector() {
+  // expected-error @+1 {{invalid vector element type}}
+  "some.op"() : () -> !llvm.vec<2 x vec<2 x i32>>
+}
+
+// -----
+
+func @scalable_void_vector() {
+  // expected-error @+1 {{invalid vector element type}}
+  "some.op"() : () -> !llvm.vec<? x 4 x void>
+}


        


More information about the Mlir-commits mailing list