[Mlir-commits] [mlir] c69c9e0 - [mlir] Remove LLVMType, LLVM dialect types now derive Type directly

Alex Zinenko llvmlistbot at llvm.org
Tue Jan 5 08:37:03 PST 2021


Author: Alex Zinenko
Date: 2021-01-05T17:36:54+01:00
New Revision: c69c9e0f0fd2e5b72c4a1947822a4961b8630123

URL: https://github.com/llvm/llvm-project/commit/c69c9e0f0fd2e5b72c4a1947822a4961b8630123
DIFF: https://github.com/llvm/llvm-project/commit/c69c9e0f0fd2e5b72c4a1947822a4961b8630123.diff

LOG: [mlir] Remove LLVMType, LLVM dialect types now derive Type directly

BEGIN_PUBLIC
[mlir] Remove LLVMType, LLVM dialect types now derive Type directly

This class has become a simple `isa` hook with no proper functionality.
Removing it will allow us to eventually make the LLVM dialect type infrastructure
open, i.e., support non-LLVM types inside container types, which itself will
make the type conversion more progressive.

Introduce a function `LLVM::isCompatibleType` to be used instead of
`isa<LLVMType>`. For now, this is strictly equivalent.
END_PUBLIC

Depends On D93681

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D93713

Added: 
    

Modified: 
    mlir/docs/Tutorials/Toy/Ch-6.md
    mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
    mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
    mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
    mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
    mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
    mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
    mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
    mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
    mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
    mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
    mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
    mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
    mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
    mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
    mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
    mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
    mlir/lib/Target/LLVMIR/TypeTranslation.cpp
    mlir/test/Dialect/LLVMIR/invalid.mlir
    mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md
index 86805c10831a..c2211412e5c4 100644
--- a/mlir/docs/Tutorials/Toy/Ch-6.md
+++ b/mlir/docs/Tutorials/Toy/Ch-6.md
@@ -37,10 +37,11 @@ static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
 
   // Create a function declaration for printf, the signature is:
   //   * `i32 (i8*, ...)`
-  auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
-  auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
-  auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy,
-                                                  /*isVarArg=*/true);
+  auto llvmI32Ty = LLVM::LLVMIntegerType::get(context, 32);
+  auto llvmI8PtrTy =
+      LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8));
+  auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy,
+                                                /*isVarArg=*/true);
 
   // Insert the printf function into the body of the parent module.
   PatternRewriter::InsertionGuard insertGuard(rewriter);

diff  --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index d5c1e923fab9..686a8ca01d9e 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -34,7 +34,6 @@ class UnrankedMemRefType;
 
 namespace LLVM {
 class LLVMDialect;
-class LLVMType;
 class LLVMPointerType;
 } // namespace LLVM
 
@@ -71,8 +70,8 @@ class LLVMTypeConverter : public TypeConverter {
   /// Convert a function type.  The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
   /// is populated with argument mapping.
-  LLVM::LLVMType convertFunctionSignature(FunctionType funcTy, bool isVariadic,
-                                          SignatureConversion &result);
+  Type convertFunctionSignature(FunctionType funcTy, bool isVariadic,
+                                SignatureConversion &result);
 
   /// Convert a non-empty list of types to be returned from a function into a
   /// supported LLVM IR type.  In particular, if more than one value is
@@ -118,14 +117,14 @@ class LLVMTypeConverter : public TypeConverter {
 
   /// Converts the function type to a C-compatible format, in particular using
   /// pointers to memref descriptors for arguments.
-  LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);
+  Type convertFunctionTypeCWrapper(FunctionType type);
 
   /// Returns the data layout to use during and after conversion.
   const llvm::DataLayout &getDataLayout() { return options.dataLayout; }
 
   /// Gets the LLVM representation of the index type. The returned type is an
   /// integer type with the size configured for this type converter.
-  LLVM::LLVMType getIndexType();
+  Type getIndexType();
 
   /// Gets the bitwidth of the index type when converted to LLVM.
   unsigned getIndexTypeBitwidth() { return options.indexBitwidth; }
@@ -185,8 +184,8 @@ class LLVMTypeConverter : public TypeConverter {
   /// - `!llvm.i64`, `!llvm.i64` (sizes),
   /// - `!llvm.i64`, `!llvm.i64` (strides).
   /// These types can be recomposed to a memref descriptor struct.
-  SmallVector<LLVM::LLVMType, 5>
-  getMemRefDescriptorFields(MemRefType type, bool unpackAggregates);
+  SmallVector<Type, 5> getMemRefDescriptorFields(MemRefType type,
+                                                 bool unpackAggregates);
 
   /// Convert an unranked memref type into a list of non-aggregate LLVM IR types
   /// that will form the unranked memref descriptor. In particular, this list
@@ -197,7 +196,7 @@ class LLVMTypeConverter : public TypeConverter {
   /// !llvm.i64 (rank)
   /// !llvm<"i8*"> (type-erased pointer).
   /// These types can be recomposed to a unranked memref descriptor struct.
-  SmallVector<LLVM::LLVMType, 2> getUnrankedMemRefDescriptorFields();
+  SmallVector<Type, 2> getUnrankedMemRefDescriptorFields();
 
   // Convert an unranked memref type to an LLVM type that captures the
   // runtime rank and a pointer to the static ranked memref desc
@@ -417,31 +416,30 @@ class UnrankedMemRefDescriptor : public StructBuilder {
 
   /// Builds IR extracting the allocated pointer from the descriptor.
   static Value allocatedPtr(OpBuilder &builder, Location loc,
-                            Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType);
+                            Value memRefDescPtr, Type elemPtrPtrType);
   /// Builds IR inserting the allocated pointer into the descriptor.
   static void setAllocatedPtr(OpBuilder &builder, Location loc,
-                              Value memRefDescPtr,
-                              LLVM::LLVMType elemPtrPtrType,
+                              Value memRefDescPtr, Type elemPtrPtrType,
                               Value allocatedPtr);
 
   /// Builds IR extracting the aligned pointer from the descriptor.
   static Value alignedPtr(OpBuilder &builder, Location loc,
                           LLVMTypeConverter &typeConverter, Value memRefDescPtr,
-                          LLVM::LLVMType elemPtrPtrType);
+                          Type elemPtrPtrType);
   /// Builds IR inserting the aligned pointer into the descriptor.
   static void setAlignedPtr(OpBuilder &builder, Location loc,
                             LLVMTypeConverter &typeConverter,
-                            Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType,
+                            Value memRefDescPtr, Type elemPtrPtrType,
                             Value alignedPtr);
 
   /// Builds IR extracting the offset from the descriptor.
   static Value offset(OpBuilder &builder, Location loc,
                       LLVMTypeConverter &typeConverter, Value memRefDescPtr,
-                      LLVM::LLVMType elemPtrPtrType);
+                      Type elemPtrPtrType);
   /// Builds IR inserting the offset into the descriptor.
   static void setOffset(OpBuilder &builder, Location loc,
                         LLVMTypeConverter &typeConverter, Value memRefDescPtr,
-                        LLVM::LLVMType elemPtrPtrType, Value offset);
+                        Type elemPtrPtrType, Value offset);
 
   /// Builds IR extracting the pointer to the first element of the size array.
   static Value sizeBasePtr(OpBuilder &builder, Location loc,
@@ -490,17 +488,17 @@ class ConvertToLLVMPattern : public ConversionPattern {
 
   /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
   /// defined by the used type converter.
-  LLVM::LLVMType getIndexType() const;
+  Type getIndexType() const;
 
   /// Gets the MLIR type wrapping the LLVM integer type whose bit width
   /// corresponds to that of a LLVM pointer type.
-  LLVM::LLVMType getIntPtrType(unsigned addressSpace = 0) const;
+  Type getIntPtrType(unsigned addressSpace = 0) const;
 
   /// Gets the MLIR type wrapping the LLVM void type.
-  LLVM::LLVMType getVoidType() const;
+  Type getVoidType() const;
 
   /// Get the MLIR type wrapping the LLVM i8* type.
-  LLVM::LLVMType getVoidPtrType() const;
+  Type getVoidPtrType() const;
 
   /// Create an LLVM dialect operation defining the given index constant.
   Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index ef5efc9c8281..3b7cd5ea9184 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -49,8 +49,8 @@ def LLVM_Dialect : Dialect {
 
 // LLVM dialect type.
 def LLVM_Type : DialectType<LLVM_Dialect,
-                            CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
-                            "LLVM dialect type">;
+                            CPred<"::mlir::LLVM::isCompatibleType($_self)">,
+                            "LLVM dialect-compatible type">;
 
 // Type constraint accepting LLVM integer types.
 def LLVM_AnyInteger : Type<
@@ -223,9 +223,9 @@ class ListIntSubst<string pattern, list<int> values> {
 // or result in the operation.
 def LLVM_IntrPatterns {
   string operand =
-    [{convertType(opInst.getOperand($0).getType().cast<LLVM::LLVMType>())}];
+    [{convertType(opInst.getOperand($0).getType())}];
   string result =
-    [{convertType(opInst.getResult($0).getType().cast<LLVM::LLVMType>())}];
+    [{convertType(opInst.getResult($0).getType())}];
   string structResult =
     [{convertType(opInst.getResult(0).getType().cast<LLVM::LLVMStructType>()
                                                    .getBody()[$0])}];

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 4968b33f47a4..428ca6783afd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -49,9 +49,8 @@ def LLVM_VoidResultTypeOpBuilder :
   OpBuilderDAG<(ins "Type":$resultType, "ValueRange":$operands,
     CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
   [{
-    auto llvmType = resultType.dyn_cast<LLVMType>(); (void)llvmType;
-    assert(llvmType && "result must be an LLVM type");
-    assert(llvmType.isa<LLVMVoidType>() &&
+    assert(isCompatibleType(resultType) && "result must be an LLVM type");
+    assert(resultType.isa<LLVMVoidType>() &&
            "for zero-result operands, only 'void' is accepted as result type");
     build($_builder, $_state, operands, attributes);
   }]>;
@@ -443,7 +442,7 @@ def LLVM_CallOp : LLVM_Op<"call"> {
     OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
     [{
-      LLVMType resultType = func.getType().getReturnType();
+      Type resultType = func.getType().getReturnType();
       if (!resultType.isa<LLVM::LLVMVoidType>())
         $_state.addTypes(resultType);
       $_state.addAttribute("callee", $_builder.getSymbolRefAttr(func));
@@ -756,23 +755,21 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof"> {
   }];
 
   let builders = [
-    OpBuilderDAG<(ins "LLVMType":$resType, "StringRef":$name,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
-    [{
-      $_state.addAttribute("global_name",$_builder.getSymbolRefAttr(name));
-      $_state.addAttributes(attrs);
-      $_state.addTypes(resType);}]>,
     OpBuilderDAG<(ins "GlobalOp":$global,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       build($_builder, $_state,
             LLVM::LLVMPointerType::get(global.getType(), global.addr_space()),
-            global.sym_name(), attrs);}]>,
+            global.sym_name());
+      $_state.addAttributes(attrs);
+    }]>,
     OpBuilderDAG<(ins "LLVMFuncOp":$func,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
     [{
       build($_builder, $_state,
-            LLVM::LLVMPointerType::get(func.getType()), func.getName(), attrs);}]>
+            LLVM::LLVMPointerType::get(func.getType()), func.getName());
+      $_state.addAttributes(attrs);
+    }]>
   ];
 
   let extraClassDeclaration = [{
@@ -883,15 +880,15 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global",
   let regions = (region AnyRegion:$initializer);
 
   let builders = [
-    OpBuilderDAG<(ins "LLVMType":$type, "bool":$isConstant, "Linkage":$linkage,
+    OpBuilderDAG<(ins "Type":$type, "bool":$isConstant, "Linkage":$linkage,
       "StringRef":$name, "Attribute":$value, CArg<"unsigned", "0">:$addrSpace,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
   ];
 
   let extraClassDeclaration = [{
     /// Return the LLVM type of the global.
-    LLVMType getType() {
-      return type().cast<LLVMType>();
+    Type getType() {
+      return type();
     }
     /// Return the initializer attribute if it exists, or a null attribute.
     Attribute getValueOrNull() {
@@ -957,7 +954,7 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func",
   let skipDefaultBuilders = 1;
 
   let builders = [
-    OpBuilderDAG<(ins "StringRef":$name, "LLVMType":$type,
+    OpBuilderDAG<(ins "StringRef":$name, "Type":$type,
       CArg<"Linkage", "Linkage::External">:$linkage,
       CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
       CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)>

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 7c7731946ba8..6f78118a5b00 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -45,52 +45,13 @@ class LLVMFP128Type;
 class LLVMX86FP80Type;
 class LLVMIntegerType;
 
-//===----------------------------------------------------------------------===//
-// LLVMType.
-//===----------------------------------------------------------------------===//
-
-/// Base class for LLVM dialect types.
-///
-/// The LLVM dialect in MLIR fully reflects the LLVM IR type system, prodiving a
-/// separate MLIR type for each LLVM IR type. All types are represented as
-/// separate subclasses and are compatible with the isa/cast infrastructure.
-///
-/// The LLVM dialect type system is closed: parametric types can only refer to
-/// other LLVM dialect types. This is consistent with LLVM IR and enables a more
-/// concise pretty-printing format.
-///
-/// Similarly to other MLIR types, LLVM dialect types are owned by the MLIR
-/// context, have an immutable identifier (for most types except identified
-/// structs, the entire type is the identifier) and are thread-safe.
-///
-/// This class is a thin common base class for different types available in the
-/// LLVM dialect. It intentionally does not provide the API similar to
-/// llvm::Type to avoid confusion and highlight potentially expensive operations
-/// (e.g., type creation in MLIR takes a lock, so it's better to cache types).
-class LLVMType : public Type {
-public:
-  /// Inherit base constructors.
-  using Type::Type;
-
-  /// Support for PointerLikeTypeTraits.
-  using Type::getAsOpaquePointer;
-  static LLVMType getFromOpaquePointer(const void *ptr) {
-    return LLVMType(static_cast<ImplType *>(const_cast<void *>(ptr)));
-  }
-
-  /// Support for isa/cast.
-  static bool classof(Type type);
-
-  LLVMDialect &getDialect();
-};
-
 //===----------------------------------------------------------------------===//
 // Trivial types.
 //===----------------------------------------------------------------------===//
 
 // Batch-define trivial types.
 #define DEFINE_TRIVIAL_LLVM_TYPE(ClassName)                                    \
-  class ClassName : public Type::TypeBase<ClassName, LLVMType, TypeStorage> {  \
+  class ClassName : public Type::TypeBase<ClassName, Type, TypeStorage> {      \
   public:                                                                      \
     using Base::Base;                                                          \
   }
@@ -117,30 +78,30 @@ DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType);
 /// LLVM dialect array type. It is an aggregate type representing consecutive
 /// elements in memory, parameterized by the number of elements and the element
 /// type.
-class LLVMArrayType : public Type::TypeBase<LLVMArrayType, LLVMType,
+class LLVMArrayType : public Type::TypeBase<LLVMArrayType, Type,
                                             detail::LLVMTypeAndSizeStorage> {
 public:
   /// Inherit base constructors.
   using Base::Base;
 
   /// Checks if the given type can be used inside an array type.
-  static bool isValidElementType(LLVMType type);
+  static bool isValidElementType(Type type);
 
   /// Gets or creates an instance of LLVM dialect array type containing
   /// `numElements` of `elementType`, in the same context as `elementType`.
-  static LLVMArrayType get(LLVMType elementType, unsigned numElements);
-  static LLVMArrayType getChecked(Location loc, LLVMType elementType,
+  static LLVMArrayType get(Type elementType, unsigned numElements);
+  static LLVMArrayType getChecked(Location loc, Type elementType,
                                   unsigned numElements);
 
   /// Returns the element type of the array.
-  LLVMType getElementType();
+  Type getElementType();
 
   /// Returns the number of elements in the array type.
   unsigned getNumElements();
 
   /// Verifies that the type about to be constructed is well-formed.
   static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    LLVMType elementType,
+                                                    Type elementType,
                                                     unsigned numElements);
 };
 
@@ -152,46 +113,46 @@ class LLVMArrayType : public Type::TypeBase<LLVMArrayType, LLVMType,
 /// which can have multiple), a list of parameter types and can optionally be
 /// variadic.
 class LLVMFunctionType
-    : public Type::TypeBase<LLVMFunctionType, LLVMType,
+    : public Type::TypeBase<LLVMFunctionType, Type,
                             detail::LLVMFunctionTypeStorage> {
 public:
   /// Inherit base constructors.
   using Base::Base;
 
   /// Checks if the given type can be used an argument in a function type.
-  static bool isValidArgumentType(LLVMType type);
+  static bool isValidArgumentType(Type type);
 
   /// Checks if the given type can be used as a result in a function type.
-  static bool isValidResultType(LLVMType type);
+  static bool isValidResultType(Type type);
 
   /// Returns whether the function is variadic.
   bool isVarArg();
 
   /// Gets or creates an instance of LLVM dialect function in the same context
   /// as the `result` type.
-  static LLVMFunctionType get(LLVMType result, ArrayRef<LLVMType> arguments,
+  static LLVMFunctionType get(Type result, ArrayRef<Type> arguments,
                               bool isVarArg = false);
-  static LLVMFunctionType getChecked(Location loc, LLVMType result,
-                                     ArrayRef<LLVMType> arguments,
+  static LLVMFunctionType getChecked(Location loc, Type result,
+                                     ArrayRef<Type> arguments,
                                      bool isVarArg = false);
 
   /// Returns the result type of the function.
-  LLVMType getReturnType();
+  Type getReturnType();
 
   /// Returns the number of arguments to the function.
   unsigned getNumParams();
 
   /// Returns `i`-th argument of the function. Asserts on out-of-bounds.
-  LLVMType getParamType(unsigned i);
+  Type getParamType(unsigned i);
 
   /// Returns a list of argument types of the function.
-  ArrayRef<LLVMType> getParams();
-  ArrayRef<LLVMType> params() { return getParams(); }
+  ArrayRef<Type> getParams();
+  ArrayRef<Type> params() { return getParams(); }
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult
-  verifyConstructionInvariants(Location loc, LLVMType result,
-                               ArrayRef<LLVMType> arguments, bool);
+  static LogicalResult verifyConstructionInvariants(Location loc, Type result,
+                                                    ArrayRef<Type> arguments,
+                                                    bool);
 };
 
 //===----------------------------------------------------------------------===//
@@ -199,7 +160,7 @@ class LLVMFunctionType
 //===----------------------------------------------------------------------===//
 
 /// LLVM dialect signless integer type parameterized by bitwidth.
-class LLVMIntegerType : public Type::TypeBase<LLVMIntegerType, LLVMType,
+class LLVMIntegerType : public Type::TypeBase<LLVMIntegerType, Type,
                                               detail::LLVMIntegerTypeStorage> {
 public:
   /// Inherit base constructor.
@@ -225,31 +186,31 @@ class LLVMIntegerType : public Type::TypeBase<LLVMIntegerType, LLVMType,
 /// LLVM dialect pointer type. This type typically represents a reference to an
 /// object in memory. It is parameterized by the element type and the address
 /// space.
-class LLVMPointerType : public Type::TypeBase<LLVMPointerType, LLVMType,
+class LLVMPointerType : public Type::TypeBase<LLVMPointerType, Type,
                                               detail::LLVMPointerTypeStorage> {
 public:
   /// Inherit base constructors.
   using Base::Base;
 
   /// Checks if the given type can have a pointer type pointing to it.
-  static bool isValidElementType(LLVMType type);
+  static bool isValidElementType(Type type);
 
   /// Gets or creates an instance of LLVM dialect pointer type pointing to an
   /// object of `pointee` type in the given address space. The pointer type is
   /// created in the same context as `pointee`.
-  static LLVMPointerType get(LLVMType pointee, unsigned addressSpace = 0);
-  static LLVMPointerType getChecked(Location loc, LLVMType pointee,
+  static LLVMPointerType get(Type pointee, unsigned addressSpace = 0);
+  static LLVMPointerType getChecked(Location loc, Type pointee,
                                     unsigned addressSpace = 0);
 
   /// Returns the pointed-to type.
-  LLVMType getElementType();
+  Type getElementType();
 
   /// Returns the address space of the pointer.
   unsigned getAddressSpace();
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    LLVMType pointee, unsigned);
+  static LogicalResult verifyConstructionInvariants(Location loc, Type pointee,
+                                                    unsigned);
 };
 
 //===----------------------------------------------------------------------===//
@@ -280,14 +241,14 @@ class LLVMPointerType : public Type::TypeBase<LLVMPointerType, LLVMType,
 ///
 /// Note that the packedness of the struct takes place in uniquing of literal
 /// structs, but does not in uniquing of identified structs.
-class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
+class LLVMStructType : public Type::TypeBase<LLVMStructType, Type,
                                              detail::LLVMStructTypeStorage> {
 public:
   /// Inherit base constructors.
   using Base::Base;
 
   /// Checks if the given type can be contained in a structure type.
-  static bool isValidElementType(LLVMType type);
+  static bool isValidElementType(Type type);
 
   /// Gets or creates an identified struct with the given name in the provided
   /// context. Note that unlike llvm::StructType::create, this function will
@@ -302,16 +263,14 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
   /// the struct by appending a `.` followed by a number to the name. Renaming
   /// happens even if the existing struct has the same body.
   static LLVMStructType getNewIdentified(MLIRContext *context, StringRef name,
-                                         ArrayRef<LLVMType> elements,
+                                         ArrayRef<Type> elements,
                                          bool isPacked = false);
 
   /// Gets or creates a literal struct with the given body in the provided
   /// context.
-  static LLVMStructType getLiteral(MLIRContext *context,
-                                   ArrayRef<LLVMType> types,
+  static LLVMStructType getLiteral(MLIRContext *context, ArrayRef<Type> types,
                                    bool isPacked = false);
-  static LLVMStructType getLiteralChecked(Location loc,
-                                          ArrayRef<LLVMType> types,
+  static LLVMStructType getLiteralChecked(Location loc, ArrayRef<Type> types,
                                           bool isPacked = false);
 
   /// Gets or creates an intentionally-opaque identified struct. Such a struct
@@ -329,7 +288,7 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
   /// different thread modified the struct after it was created. Most callers
   /// are likely to assert this always succeeds, but it is possible to implement
   /// a local renaming scheme based on the result of this call.
-  LogicalResult setBody(ArrayRef<LLVMType> types, bool isPacked);
+  LogicalResult setBody(ArrayRef<Type> types, bool isPacked);
 
   /// Checks if a struct is packed.
   bool isPacked();
@@ -347,12 +306,12 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
   StringRef getName();
 
   /// Returns the list of element types contained in a non-opaque struct.
-  ArrayRef<LLVMType> getBody();
+  ArrayRef<Type> getBody();
 
   /// Verifies that the type about to be constructed is well-formed.
   static LogicalResult verifyConstructionInvariants(Location, StringRef, bool);
-  static LogicalResult
-  verifyConstructionInvariants(Location loc, ArrayRef<LLVMType> types, bool);
+  static LogicalResult verifyConstructionInvariants(Location loc,
+                                                    ArrayRef<Type> types, bool);
 };
 
 //===----------------------------------------------------------------------===//
@@ -362,26 +321,26 @@ class LLVMStructType : public Type::TypeBase<LLVMStructType, LLVMType,
 /// LLVM dialect vector type, represents a sequence of elements that can be
 /// processed as one, typically in SIMD context. This is a base class for fixed
 /// and scalable vectors.
-class LLVMVectorType : public LLVMType {
+class LLVMVectorType : public Type {
 public:
   /// Inherit base constructor.
-  using LLVMType::LLVMType;
+  using Type::Type;
 
   /// Support type casting functionality.
   static bool classof(Type type);
 
   /// Checks if the given type can be used in a vector type.
-  static bool isValidElementType(LLVMType type);
+  static bool isValidElementType(Type type);
 
   /// Returns the element type of the vector.
-  LLVMType getElementType();
+  Type getElementType();
 
   /// Returns the number of elements in the vector.
   llvm::ElementCount getElementCount();
 
   /// Verifies that the type about to be constructed is well-formed.
   static LogicalResult verifyConstructionInvariants(Location loc,
-                                                    LLVMType elementType,
+                                                    Type elementType,
                                                     unsigned numElements);
 };
 
@@ -401,8 +360,8 @@ class LLVMFixedVectorType
 
   /// Gets or creates a fixed vector type containing `numElements` of
   /// `elementType` in the same context as `elementType`.
-  static LLVMFixedVectorType get(LLVMType elementType, unsigned numElements);
-  static LLVMFixedVectorType getChecked(Location loc, LLVMType elementType,
+  static LLVMFixedVectorType get(Type elementType, unsigned numElements);
+  static LLVMFixedVectorType getChecked(Location loc, Type elementType,
                                         unsigned numElements);
 
   /// Returns the number of elements in the fixed vector.
@@ -426,9 +385,8 @@ class LLVMScalableVectorType
 
   /// Gets or creates a scalable vector type containing a non-zero multiple of
   /// `minNumElements` of `elementType` in the same context as `elementType`.
-  static LLVMScalableVectorType get(LLVMType elementType,
-                                    unsigned minNumElements);
-  static LLVMScalableVectorType getChecked(Location loc, LLVMType elementType,
+  static LLVMScalableVectorType get(Type elementType, unsigned minNumElements);
+  static LLVMScalableVectorType getChecked(Location loc, Type elementType,
                                            unsigned minNumElements);
 
   /// Returns the scaling factor of the number of elements in the vector. The
@@ -443,10 +401,10 @@ class LLVMScalableVectorType
 
 namespace detail {
 /// Parses an LLVM dialect type.
-LLVMType parseType(DialectAsmParser &parser);
+Type parseType(DialectAsmParser &parser);
 
 /// Prints an LLVM Dialect type.
-void printType(LLVMType type, DialectAsmPrinter &printer);
+void printType(Type type, DialectAsmPrinter &printer);
 } // namespace detail
 
 //===----------------------------------------------------------------------===//
@@ -454,7 +412,30 @@ void printType(LLVMType type, DialectAsmPrinter &printer);
 //===----------------------------------------------------------------------===//
 
 /// Returns `true` if the given type is compatible with the LLVM dialect.
-inline bool isCompatibleType(Type type) { return type.isa<LLVMType>(); }
+inline bool isCompatibleType(Type type) {
+  // clang-format off
+  return type.isa<
+      LLVMArrayType,
+      LLVMBFloatType,
+      LLVMDoubleType,
+      LLVMFP128Type,
+      LLVMFloatType,
+      LLVMFunctionType,
+      LLVMHalfType,
+      LLVMIntegerType,
+      LLVMLabelType,
+      LLVMMetadataType,
+      LLVMPPCFP128Type,
+      LLVMPointerType,
+      LLVMStructType,
+      LLVMTokenType,
+      LLVMVectorType,
+      LLVMVoidType,
+      LLVMX86FP80Type,
+      LLVMX86MMXType
+  >();
+  // clang-format on
+}
 
 inline bool isCompatibleFloatingPointType(Type type) {
   return type.isa<LLVMHalfType, LLVMBFloatType, LLVMFloatType, LLVMDoubleType,
@@ -470,46 +451,4 @@ llvm::TypeSize getPrimitiveTypeSizeInBits(Type type);
 } // namespace LLVM
 } // namespace mlir
 
-//===----------------------------------------------------------------------===//
-// Support for hashing and containers.
-//===----------------------------------------------------------------------===//
-
-namespace llvm {
-
-// LLVMType instances hash just like pointers.
-template <>
-struct DenseMapInfo<mlir::LLVM::LLVMType> {
-  static mlir::LLVM::LLVMType getEmptyKey() {
-    void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
-    return mlir::LLVM::LLVMType(
-        static_cast<mlir::LLVM::LLVMType::ImplType *>(pointer));
-  }
-  static mlir::LLVM::LLVMType getTombstoneKey() {
-    void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
-    return mlir::LLVM::LLVMType(
-        static_cast<mlir::LLVM::LLVMType::ImplType *>(pointer));
-  }
-  static unsigned getHashValue(mlir::LLVM::LLVMType val) {
-    return mlir::hash_value(val);
-  }
-  static bool isEqual(mlir::LLVM::LLVMType lhs, mlir::LLVM::LLVMType rhs) {
-    return lhs == rhs;
-  }
-};
-
-// LLVMType behaves like a pointer similarly to mlir::Type.
-template <>
-struct PointerLikeTypeTraits<mlir::LLVM::LLVMType> {
-  static inline void *getAsVoidPointer(mlir::LLVM::LLVMType type) {
-    return const_cast<void *>(type.getAsOpaquePointer());
-  }
-  static inline mlir::LLVM::LLVMType getFromVoidPointer(void *ptr) {
-    return mlir::LLVM::LLVMType::getFromOpaquePointer(ptr);
-  }
-  static constexpr int NumLowBitsAvailable =
-      PointerLikeTypeTraits<mlir::Type>::NumLowBitsAvailable;
-};
-
-} // namespace llvm
-
 #endif // MLIR_DIALECT_LLVMIR_LLVMTYPES_H_

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index c6d2ded073e6..e073126450d6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -149,11 +149,11 @@ def ROCDL_MubufLoadOp :
                  LLVM_Type:$slc)>{
   string llvmBuilder = [{
       $res = createIntrinsicCall(builder,
-          llvm::Intrinsic::amdgcn_buffer_load, {$rsrc, $vindex, $offset, $glc, 
+          llvm::Intrinsic::amdgcn_buffer_load, {$rsrc, $vindex, $offset, $glc,
           $slc}, {$_resultType});
   }];
   let parser = [{ return parseROCDLMubufLoadOp(parser, result); }];
-  let printer = [{ 
+  let printer = [{
     Operation *op = this->getOperation();
     p << op->getName() << " " << op->getOperands()
       << " : " << op->getResultTypes();
@@ -169,7 +169,7 @@ def ROCDL_MubufStoreOp :
                  LLVM_Type:$glc,
                  LLVM_Type:$slc)>{
   string llvmBuilder = [{
-    auto vdataType = convertType(op.vdata().getType().cast<LLVM::LLVMType>());
+    auto vdataType = convertType(op.vdata().getType());
     createIntrinsicCall(builder,
           llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex,
           $offset, $glc, $slc}, {vdataType});

diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index 5259ed7fe182..7691adfeef14 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -104,7 +104,7 @@ class ModuleTranslation {
                                          llvm::IRBuilder<> &builder);
 
   /// Converts the type from MLIR LLVM dialect to LLVM.
-  llvm::Type *convertType(LLVMType type);
+  llvm::Type *convertType(Type type);
 
   static std::unique_ptr<llvm::Module>
   prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,

diff  --git a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
index 71924b0c61ca..6b8e29ad7d9c 100644
--- a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h
@@ -24,12 +24,11 @@ class Type;
 
 namespace mlir {
 
+class Type;
 class MLIRContext;
 
 namespace LLVM {
 
-class LLVMType;
-
 namespace detail {
 class TypeToLLVMIRTranslatorImpl;
 class TypeFromLLVMIRTranslatorImpl;
@@ -47,11 +46,10 @@ class TypeToLLVMIRTranslator {
   /// that this will perform type conversion and store its results for future
   /// uses.
   // TODO: this should be removed when MLIR has proper data layout.
-  unsigned getPreferredAlignment(LLVM::LLVMType type,
-                                 const llvm::DataLayout &layout);
+  unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout);
 
   /// Translates the given MLIR LLVM dialect type to LLVM IR.
-  llvm::Type *translateType(LLVM::LLVMType type);
+  llvm::Type *translateType(Type type);
 
 private:
   /// Private implementation.
@@ -67,7 +65,7 @@ class TypeFromLLVMIRTranslator {
   ~TypeFromLLVMIRTranslator();
 
   /// Translates the given LLVM IR type to the MLIR LLVM dialect.
-  LLVM::LLVMType translateType(llvm::Type *type);
+  Type translateType(llvm::Type *type);
 
 private:
   /// Private implementation.

diff  --git a/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
index 5742cd790e77..8a5790352263 100644
--- a/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp
@@ -36,15 +36,14 @@ using VectorScaleOpLowering =
     OneToOneConvertToLLVMPattern<VectorScaleOp, LLVM::vector_scale>;
 
 // Extract an LLVM IR type from the LLVM IR dialect type.
-static LLVM::LLVMType unwrap(Type type) {
+static Type unwrap(Type type) {
   if (!type)
     return nullptr;
   auto *mlirContext = type.getContext();
-  auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
-  if (!wrappedLLVMType)
+  if (!LLVM::isCompatibleType(type))
     emitError(UnknownLoc::get(mlirContext),
               "conversion resulted in a non-LLVM type");
-  return wrappedLLVMType;
+  return type;
 }
 
 static Optional<Type>

diff  --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 517c8d2c6f56..93854c3cc05c 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -149,7 +149,7 @@ struct AsyncAPI {
   }
 
   // Auxiliary coroutine resume intrinsic wrapper.
-  static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) {
+  static Type resumeFunctionType(MLIRContext *ctx) {
     auto voidTy = LLVM::LLVMVoidType::get(ctx);
     auto i8Ptr = opaquePointerType(ctx);
     return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false);
@@ -203,13 +203,12 @@ static constexpr const char *kCoroEnd = "llvm.coro.end";
 static constexpr const char *kCoroFree = "llvm.coro.free";
 static constexpr const char *kCoroResume = "llvm.coro.resume";
 
 /// Adds an LLVM function declaration to a module.
 static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder,
-                            StringRef name, LLVM::LLVMType ret,
-                            ArrayRef<LLVM::LLVMType> params) {
+                            StringRef name, Type ret, ArrayRef<Type> params) {
   if (module.lookupSymbol(name))
     return;
-  LLVM::LLVMType type = LLVM::LLVMFunctionType::get(ret, params);
+  Type type = LLVM::LLVMFunctionType::get(ret, params);
   builder.create<LLVM::LLVMFuncOp>(name, type);
 }
 
@@ -386,8 +385,7 @@ static CoroMachinery setupCoroMachinery(FuncOp func) {
   // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt
   auto sizeOf = [&](ValueType valueType) -> Value {
     auto storedType = converter.convertType(valueType.getValueType());
-    auto storagePtrType =
-        LLVM::LLVMPointerType::get(storedType.cast<LLVM::LLVMType>());
+    auto storagePtrType = LLVM::LLVMPointerType::get(storedType);
 
     // %Size = getelementptr %T* null, int 1
     // %SizeI = ptrtoint %T* %Size to i32
@@ -949,8 +947,7 @@ class AwaitValueOpLowering : public AwaitOpLoweringBase<AwaitOp, ValueType> {
     // Cast from i8* to the pointer pointer to LLVM type.
     auto llvmValueType = getTypeConverter()->convertType(valueType);
     auto castedStorage = rewriter.create<LLVM::BitcastOp>(
-        loc, LLVM::LLVMPointerType::get(llvmValueType.cast<LLVM::LLVMType>()),
-        storage.getResult(0));
+        loc, LLVM::LLVMPointerType::get(llvmValueType), storage.getResult(0));
 
     // Load from the async value storage.
     auto loaded = rewriter.create<LLVM::LoadOp>(loc, castedStorage.getResult());
@@ -1015,9 +1012,7 @@ class YieldOpLowering : public ConversionPattern {
 
       // Cast storage pointer to the yielded value type.
       auto castedStorage = rewriter.create<LLVM::BitcastOp>(
-          loc,
-          LLVM::LLVMPointerType::get(
-              yieldValue.getType().cast<LLVM::LLVMType>()),
+          loc, LLVM::LLVMPointerType::get(yieldValue.getType()),
           storage.getResult(0));
 
       // Store the yielded value into the async value storage.

diff  --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
index 6859834de67f..cd16df12bfae 100644
--- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
+++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp
@@ -52,8 +52,8 @@ class GpuToLLVMConversionPass
 
 class FunctionCallBuilder {
 public:
-  FunctionCallBuilder(StringRef functionName, LLVM::LLVMType returnType,
-                      ArrayRef<LLVM::LLVMType> argumentTypes)
+  FunctionCallBuilder(StringRef functionName, Type returnType,
+                      ArrayRef<Type> argumentTypes)
       : functionName(functionName),
         functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
   LLVM::CallOp create(Location loc, OpBuilder &builder,
@@ -73,15 +73,14 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
 protected:
   MLIRContext *context = &this->getTypeConverter()->getContext();
 
-  LLVM::LLVMType llvmVoidType = LLVM::LLVMVoidType::get(context);
-  LLVM::LLVMType llvmPointerType =
+  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
+  Type llvmPointerType =
       LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8));
-  LLVM::LLVMType llvmPointerPointerType =
-      LLVM::LLVMPointerType::get(llvmPointerType);
-  LLVM::LLVMType llvmInt8Type = LLVM::LLVMIntegerType::get(context, 8);
-  LLVM::LLVMType llvmInt32Type = LLVM::LLVMIntegerType::get(context, 32);
-  LLVM::LLVMType llvmInt64Type = LLVM::LLVMIntegerType::get(context, 64);
-  LLVM::LLVMType llvmIntPtrType = LLVM::LLVMIntegerType::get(
+  Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
+  Type llvmInt8Type = LLVM::LLVMIntegerType::get(context, 8);
+  Type llvmInt32Type = LLVM::LLVMIntegerType::get(context, 32);
+  Type llvmInt64Type = LLVM::LLVMIntegerType::get(context, 64);
+  Type llvmIntPtrType = LLVM::LLVMIntegerType::get(
       context, this->getTypeConverter()->getPointerBitwidth(0));
 
   FunctionCallBuilder moduleLoadCallBuilder = {
@@ -321,7 +320,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
 static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                      ConversionPatternRewriter &rewriter) {
   if (!llvm::all_of(operands, [](Value value) {
-        return value.getType().isa<LLVM::LLVMType>();
+        return LLVM::isCompatibleType(value.getType());
       }))
     return rewriter.notifyMatchFailure(
         op, "Cannot convert if operands aren't of LLVM type.");
@@ -511,10 +510,10 @@ Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
       loc, launchOp.getOperands().take_back(numKernelOperands),
       operands.take_back(numKernelOperands), builder);
   auto numArguments = arguments.size();
-  SmallVector<LLVM::LLVMType, 4> argumentTypes;
+  SmallVector<Type, 4> argumentTypes;
   argumentTypes.reserve(numArguments);
   for (auto argument : arguments)
-    argumentTypes.push_back(argument.getType().cast<LLVM::LLVMType>());
+    argumentTypes.push_back(argument.getType());
   auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
                                                            argumentTypes);
   auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 3acc73415ef1..37c4469676c6 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -38,7 +38,7 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
       uint64_t numElements = type.getNumElements();
 
       auto elementType = typeConverter->convertType(type.getElementType())
-                             .template cast<LLVM::LLVMType>();
+                             .template cast<Type>();
       auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
       std::string name = std::string(
           llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
@@ -126,7 +126,7 @@ struct GPUFuncOpLowering : ConvertOpToLLVMPattern<gpu::GPUFuncOp> {
         // memory space and does not support `alloca`s with addrspace(5).
         auto ptrType = LLVM::LLVMPointerType::get(
             typeConverter->convertType(type.getElementType())
-                .template cast<LLVM::LLVMType>(),
+                .template cast<Type>(),
             AllocaAddrSpace);
         Value numElements = rewriter.create<LLVM::ConstantOp>(
             gpuFuncOp.getLoc(), int64Ty,

diff  --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
index 631eca5cd32d..5b98d3cee1fb 100644
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -40,7 +40,6 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
   matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     using LLVM::LLVMFuncOp;
-    using LLVM::LLVMType;
 
     static_assert(
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
@@ -54,9 +53,8 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
     for (Value operand : operands)
       castedOperands.push_back(maybeCast(operand, rewriter));
 
-    LLVMType resultType =
-        castedOperands.front().getType().cast<LLVM::LLVMType>();
-    LLVMType funcType = getFunctionType(resultType, castedOperands);
+    Type resultType = castedOperands.front().getType();
+    Type funcType = getFunctionType(resultType, castedOperands);
     StringRef funcName = getFunctionName(
         funcType.cast<LLVM::LLVMFunctionType>().getReturnType());
     if (funcName.empty())
@@ -80,7 +78,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
 
 private:
   Value maybeCast(Value operand, PatternRewriter &rewriter) const {
-    LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
+    Type type = operand.getType();
     if (!type.isa<LLVM::LLVMHalfType>())
       return operand;
 
@@ -89,17 +87,15 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
         operand);
   }
 
-  LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
-                                 ArrayRef<Value> operands) const {
-    using LLVM::LLVMType;
-    SmallVector<LLVMType, 1> operandTypes;
+  Type getFunctionType(Type resultType, ArrayRef<Value> operands) const {
+    SmallVector<Type, 1> operandTypes;
     for (Value operand : operands) {
-      operandTypes.push_back(operand.getType().cast<LLVMType>());
+      operandTypes.push_back(operand.getType());
     }
     return LLVM::LLVMFunctionType::get(resultType, operandTypes);
   }
 
-  StringRef getFunctionName(LLVM::LLVMType type) const {
+  StringRef getFunctionName(Type type) const {
     if (type.isa<LLVM::LLVMFloatType>())
       return f32Func;
     if (type.isa<LLVM::LLVMDoubleType>())
@@ -107,8 +103,7 @@ struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> {
     return "";
   }
 
-  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName,
-                                     LLVM::LLVMType funcType,
+  LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType,
                                      Operation *op) const {
     using LLVM::LLVMFuncOp;
 

diff  --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index f747f519c66b..09877bfd19b3 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -56,7 +56,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
     Location loc = op->getLoc();
     gpu::ShuffleOpAdaptor adaptor(operands);
 
-    auto valueTy = adaptor.value().getType().cast<LLVM::LLVMType>();
+    auto valueTy = adaptor.value().getType();
     auto int32Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 32);
     auto predTy = LLVM::LLVMIntegerType::get(rewriter.getContext(), 1);
     auto resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),

diff  --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
index 4b657d25f51e..6c88e5774239 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -65,7 +65,7 @@ class VulkanLaunchFuncToVulkanCallsPass
     llvmInt64Type = LLVM::LLVMIntegerType::get(&getContext(), 64);
   }
 
-  LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) {
+  Type getMemRefType(uint32_t rank, Type elemenType) {
     // According to the MLIR doc memref argument is converted into a
     // pointer-to-struct argument of type:
     // template <typename Elem, size_t Rank>
@@ -89,10 +89,10 @@ class VulkanLaunchFuncToVulkanCallsPass
          llvmArrayRankElementSizeType, llvmArrayRankElementSizeType});
   }
 
-  LLVM::LLVMType getVoidType() { return llvmVoidType; }
-  LLVM::LLVMType getPointerType() { return llvmPointerType; }
-  LLVM::LLVMType getInt32Type() { return llvmInt32Type; }
-  LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
+  Type getVoidType() { return llvmVoidType; }
+  Type getPointerType() { return llvmPointerType; }
+  Type getInt32Type() { return llvmInt32Type; }
+  Type getInt64Type() { return llvmInt64Type; }
 
   /// Creates an LLVM global for the given `name`.
   Value createEntryPointNameConstant(StringRef name, Location loc,
@@ -128,10 +128,10 @@ class VulkanLaunchFuncToVulkanCallsPass
 
   /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`.
   LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor,
-                                        uint32_t &rank, LLVM::LLVMType &type);
+                                        uint32_t &rank, Type &type);
 
   /// Returns a string representation from the given `type`.
-  StringRef stringifyType(LLVM::LLVMType type) {
+  StringRef stringifyType(Type type) {
     if (type.isa<LLVM::LLVMFloatType>())
       return "Float";
     if (type.isa<LLVM::LLVMHalfType>())
@@ -152,11 +152,11 @@ class VulkanLaunchFuncToVulkanCallsPass
   void runOnOperation() override;
 
 private:
-  LLVM::LLVMType llvmFloatType;
-  LLVM::LLVMType llvmVoidType;
-  LLVM::LLVMType llvmPointerType;
-  LLVM::LLVMType llvmInt32Type;
-  LLVM::LLVMType llvmInt64Type;
+  Type llvmFloatType;
+  Type llvmVoidType;
+  Type llvmPointerType;
+  Type llvmInt32Type;
+  Type llvmInt64Type;
 
   // TODO: Use an associative array to support multiple vulkan launch calls.
   std::pair<StringAttr, StringAttr> spirvAttributes;
@@ -230,7 +230,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
 
     auto ptrToMemRefDescriptor = en.value();
     uint32_t rank = 0;
-    LLVM::LLVMType type;
+    Type type;
     if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) {
       cInterfaceVulkanLaunchCallOp.emitError()
           << "invalid memref descriptor " << ptrToMemRefDescriptor.getType();
@@ -258,7 +258,7 @@ void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls(
 }
 
 LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
-    Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) {
+    Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) {
   auto llvmPtrDescriptorTy =
       ptrToMemRefDescriptor.getType().dyn_cast<LLVM::LLVMPointerType>();
   if (!llvmPtrDescriptorTy)
@@ -324,12 +324,11 @@ void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
   }
 
   for (unsigned i = 1; i <= 3; i++) {
-    SmallVector<LLVM::LLVMType, 5> types{
-        LLVM::LLVMFloatType::get(&getContext()),
-        LLVM::LLVMIntegerType::get(&getContext(), 32),
-        LLVM::LLVMIntegerType::get(&getContext(), 16),
-        LLVM::LLVMIntegerType::get(&getContext(), 8),
-        LLVM::LLVMHalfType::get(&getContext())};
+    SmallVector<Type, 5> types{LLVM::LLVMFloatType::get(&getContext()),
+                               LLVM::LLVMIntegerType::get(&getContext(), 32),
+                               LLVM::LLVMIntegerType::get(&getContext(), 16),
+                               LLVM::LLVMIntegerType::get(&getContext(), 8),
+                               LLVM::LLVMHalfType::get(&getContext())};
     for (auto type : types) {
       std::string fnName = "bindMemRef" + std::to_string(i) + "D" +
                            std::string(stringifyType(type));

diff  --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index c86ae710aac6..95e7c34a1710 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -67,11 +67,9 @@ using llvm_alloca = ValueBuilder<LLVM::AllocaOp>;
 using llvm_return = OperationBuilder<LLVM::ReturnOp>;
 
 template <typename T>
-static LLVMType getPtrToElementType(T containerType,
-                                    LLVMTypeConverter &lowering) {
-  return lowering.convertType(containerType.getElementType())
-      .template cast<LLVMType>()
-      .getPointerTo();
+static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) {
+  return LLVMPointerType::get(
+      lowering.convertType(containerType.getElementType()));
 }
 
 /// Convert the given range descriptor type to the LLVMIR dialect.
@@ -84,8 +82,7 @@ static LLVMType getPtrToElementType(T containerType,
 /// };
 static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) {
   auto *context = t.getContext();
-  auto int64Ty = converter.convertType(IntegerType::get(context, 64))
-                     .cast<LLVM::LLVMType>();
+  auto int64Ty = converter.convertType(IntegerType::get(context, 64));
   return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty});
 }
 
@@ -206,8 +203,7 @@ class SliceOpConversion : public ConvertOpToLLVMPattern<SliceOp> {
     BaseViewConversionHelper baseDesc(adaptor.view());
 
     auto memRefType = sliceOp.getBaseViewType();
-    auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64))
-                       .cast<LLVM::LLVMType>();
+    auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64));
 
     BaseViewConversionHelper desc(
         typeConverter->convertType(sliceOp.getShapedType()));

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index c293dacf8ec5..8c42c2a7dc92 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -184,7 +184,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
       kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
           rewriter.getUnknownLoc(), newKernelFuncName,
           LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
-                                      ArrayRef<LLVM::LLVMType>()));
+                                      ArrayRef<Type>()));
       rewriter.setInsertionPoint(launchOp);
     }
 
@@ -234,7 +234,7 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
         OpBuilder::InsertionGuard guard(rewriter);
         rewriter.setInsertionPointToStart(module.getBody());
         dstGlobal = rewriter.create<LLVM::GlobalOp>(
-            loc, dstGlobalType.cast<LLVM::LLVMType>(),
+            loc, dstGlobalType,
             /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute());
         rewriter.setInsertionPoint(launchOp);
       }

diff  --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 76a4e0b2e07e..2ebb24b5aaeb 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -65,7 +65,7 @@ static unsigned getBitWidth(Type type) {
 }
 
-/// Returns the bit width of LLVMType integer or vector.
+/// Returns the bit width of an LLVM dialect integer or vector type.
-static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) {
+static unsigned getLLVMTypeBitWidth(Type type) {
   auto vectorType = type.dyn_cast<LLVM::LLVMVectorType>();
   return (vectorType ? vectorType.getElementType() : type)
       .cast<LLVM::LLVMIntegerType>()
@@ -115,16 +115,15 @@ static Value createFPConstant(Location loc, Type srcType, Type dstType,
 ///   - `BitFieldSExtract`
 ///   - `BitFieldUExtract`
 /// Truncates or extends the value. If the bitwidth of the value is the same as
-/// `dstType` bitwidth, the value remains unchanged.
-static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType,
+/// `llvmType` bitwidth, the value remains unchanged.
+static Value optionallyTruncateOrExtend(Location loc, Value value,
+                                        Type llvmType,
                                         PatternRewriter &rewriter) {
   auto srcType = value.getType();
-  auto llvmType = dstType.cast<LLVM::LLVMType>();
   unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
-  unsigned valueBitWidth =
-      srcType.isa<LLVM::LLVMType>()
-          ? getLLVMTypeBitWidth(srcType.cast<LLVM::LLVMType>())
-          : getBitWidth(srcType);
+  unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
+                               ? getLLVMTypeBitWidth(srcType)
+                               : getBitWidth(srcType);
 
   if (valueBitWidth < targetBitWidth)
     return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
@@ -193,7 +192,7 @@ convertStructTypeWithOffset(spirv::StructType type,
 
   auto elementsVector = llvm::to_vector<8>(
       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
-        return converter.convertType(elementType).cast<LLVM::LLVMType>();
+        return converter.convertType(elementType);
       }));
   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
                                           /*isPacked=*/false);
@@ -204,7 +203,7 @@ static Type convertStructTypePacked(spirv::StructType type,
                                     LLVMTypeConverter &converter) {
   auto elementsVector = llvm::to_vector<8>(
       llvm::map_range(type.getElementTypes(), [&](Type elementType) {
-        return converter.convertType(elementType).cast<LLVM::LLVMType>();
+        return converter.convertType(elementType);
       }));
   return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
                                           /*isPacked=*/true);
@@ -255,8 +254,7 @@ static Optional<Type> convertArrayType(spirv::ArrayType type,
       !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
     return llvm::None;
 
-  auto llvmElementType =
-      converter.convertType(elementType).cast<LLVM::LLVMType>();
+  auto llvmElementType = converter.convertType(elementType);
   unsigned numElements = type.getNumElements();
   return LLVM::LLVMArrayType::get(llvmElementType, numElements);
 }
@@ -265,8 +263,7 @@ static Optional<Type> convertArrayType(spirv::ArrayType type,
 /// modelled at the moment.
 static Type convertPointerType(spirv::PointerType type,
                                TypeConverter &converter) {
-  auto pointeeType =
-      converter.convertType(type.getPointeeType()).cast<LLVM::LLVMType>();
+  auto pointeeType = converter.convertType(type.getPointeeType());
   return LLVM::LLVMPointerType::get(pointeeType);
 }
 
@@ -277,8 +274,7 @@ static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
                                               TypeConverter &converter) {
   if (type.getArrayStride() != 0)
     return llvm::None;
-  auto elementType =
-      converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
+  auto elementType = converter.convertType(type.getElementType());
   return LLVM::LLVMArrayType::get(elementType, 0);
 }
 
@@ -336,8 +332,7 @@ class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
     auto dstType = typeConverter.convertType(op.pointer().getType());
     if (!dstType)
       return failure();
-    rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
-        op, dstType.cast<LLVM::LLVMType>(), op.variable());
+    rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable());
     return success();
   }
 };
@@ -667,7 +662,7 @@ class ExecutionModePattern
     //   int32_t values[];          // optional values
     // };
     auto llvmI32Type = LLVM::LLVMIntegerType::get(context, 32);
-    SmallVector<LLVM::LLVMType, 2> fields;
+    SmallVector<Type, 2> fields;
     fields.push_back(llvmI32Type);
     ArrayAttr values = op.values();
     if (!values.empty()) {
@@ -757,8 +752,7 @@ class GlobalVariablePattern
                        ? LLVM::Linkage::Private
                        : LLVM::Linkage::External;
     rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
-        op, dstType.cast<LLVM::LLVMType>(), isConstant, linkage, op.sym_name(),
-        Attribute());
+        op, dstType, isConstant, linkage, op.sym_name(), Attribute());
     return success();
   }
 };

diff  --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 233c2eadc77c..39680a28a33e 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -41,15 +41,15 @@ using namespace mlir;
 #define PASS_NAME "convert-std-to-llvm"
 
 // Extract an LLVM IR type from the LLVM IR dialect type.
-static LLVM::LLVMType unwrap(Type type) {
+static Type unwrap(Type type) {
   if (!type)
     return nullptr;
   auto *mlirContext = type.getContext();
-  auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
-  if (!wrappedLLVMType)
+  if (!LLVM::isCompatibleType(type))
     emitError(UnknownLoc::get(mlirContext),
-              "conversion resulted in a non-LLVM type");
-  return wrappedLLVMType;
+              "conversion resulted in a non-LLVM type ")
+        << type;
+  return type;
 }
 
 /// Callback to convert function argument types. It converts a MemRef function
@@ -120,8 +120,11 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
       [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
   addConversion([&](VectorType type) { return convertVectorType(type); });
 
-  // LLVMType is legal, so add a pass-through conversion.
-  addConversion([](LLVM::LLVMType type) { return type; });
+  // LLVM-compatible types are legal, so add a pass-through conversion.
+  addConversion([](Type type) {
+    return LLVM::isCompatibleType(type) ? llvm::Optional<Type>(type)
+                                        : llvm::None;
+  });
 
   // Materialization for memrefs creates descriptor structs from individual
   // values constituting them, when descriptors are used, i.e. more than one
@@ -170,7 +173,7 @@ MLIRContext &LLVMTypeConverter::getContext() {
   return *getDialect()->getContext();
 }
 
-LLVM::LLVMType LLVMTypeConverter::getIndexType() {
+Type LLVMTypeConverter::getIndexType() {
   return LLVM::LLVMIntegerType::get(&getContext(), getIndexTypeBitwidth());
 }
 
@@ -205,7 +208,7 @@ Type LLVMTypeConverter::convertFloatType(FloatType type) {
 static constexpr unsigned kRealPosInComplexNumberStruct = 0;
 static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
 Type LLVMTypeConverter::convertComplexType(ComplexType type) {
-  auto elementType = convertType(type.getElementType()).cast<LLVM::LLVMType>();
+  auto elementType = convertType(type.getElementType());
   return LLVM::LLVMStructType::getLiteral(&getContext(),
                                           {elementType, elementType});
 }
@@ -214,7 +217,7 @@ Type LLVMTypeConverter::convertComplexType(ComplexType type) {
 // pointer-to-function types.
 Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
   SignatureConversion conversion(type.getNumInputs());
-  LLVM::LLVMType converted =
+  Type converted =
       convertFunctionSignature(type, /*isVariadic=*/false, conversion);
   return LLVM::LLVMPointerType::get(converted);
 }
@@ -224,7 +227,7 @@ Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
 // argument and result types.  If MLIR Function has zero results, the LLVM
 // Function has one VoidType result.  If MLIR Function has more than one result,
 // they are into an LLVM StructType in their order of appearance.
-LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
+Type LLVMTypeConverter::convertFunctionSignature(
     FunctionType funcTy, bool isVariadic,
     LLVMTypeConverter::SignatureConversion &result) {
   // Select the argument converter depending on the calling convention.
@@ -240,7 +243,7 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
     result.addInputs(en.index(), converted);
   }
 
-  SmallVector<LLVM::LLVMType, 8> argTypes;
+  SmallVector<Type, 8> argTypes;
   argTypes.reserve(llvm::size(result.getConvertedTypes()));
   for (Type type : result.getConvertedTypes())
     argTypes.push_back(unwrap(type));
@@ -248,10 +251,9 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
   // If function does not return anything, create the void result type,
   // if it returns on element, convert it, otherwise pack the result types into
   // a struct.
-  LLVM::LLVMType resultType =
-      funcTy.getNumResults() == 0
-          ? LLVM::LLVMVoidType::get(&getContext())
-          : unwrap(packFunctionResults(funcTy.getResults()));
+  Type resultType = funcTy.getNumResults() == 0
+                        ? LLVM::LLVMVoidType::get(&getContext())
+                        : unwrap(packFunctionResults(funcTy.getResults()));
   if (!resultType)
     return {};
   return LLVM::LLVMFunctionType::get(resultType, argTypes, isVariadic);
@@ -259,23 +261,21 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
 
 /// Converts the function type to a C-compatible format, in particular using
 /// pointers to memref descriptors for arguments.
-LLVM::LLVMType
-LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
-  SmallVector<LLVM::LLVMType, 4> inputs;
+Type LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
+  SmallVector<Type, 4> inputs;
 
   for (Type t : type.getInputs()) {
-    auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
-    if (!converted)
+    auto converted = convertType(t);
+    if (!converted || !LLVM::isCompatibleType(converted))
       return {};
     if (t.isa<MemRefType, UnrankedMemRefType>())
       converted = LLVM::LLVMPointerType::get(converted);
     inputs.push_back(converted);
   }
 
-  LLVM::LLVMType resultType =
-      type.getNumResults() == 0
-          ? LLVM::LLVMVoidType::get(&getContext())
-          : unwrap(packFunctionResults(type.getResults()));
+  Type resultType = type.getNumResults() == 0
+                        ? LLVM::LLVMVoidType::get(&getContext())
+                        : unwrap(packFunctionResults(type.getResults()));
   if (!resultType)
     return {};
 
@@ -316,19 +316,19 @@ static constexpr unsigned kStridePosInMemRefDescriptor = 4;
 ///    Index sizes[Rank]; // omitted when rank == 0
 ///    Index strides[Rank]; // omitted when rank == 0
 ///  };
-SmallVector<LLVM::LLVMType, 5>
+SmallVector<Type, 5>
 LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
                                              bool unpackAggregates) {
   assert(isStrided(type) &&
          "Non-strided layout maps must have been normalized away");
 
-  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
+  Type elementType = unwrap(convertType(type.getElementType()));
   if (!elementType)
     return {};
   auto ptrTy = LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
   auto indexTy = getIndexType();
 
-  SmallVector<LLVM::LLVMType, 5> results = {ptrTy, ptrTy, indexTy};
+  SmallVector<Type, 5> results = {ptrTy, ptrTy, indexTy};
   auto rank = type.getRank();
   if (rank == 0)
     return results;
@@ -345,7 +345,7 @@ LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
 Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
   // When converting a MemRefType to a struct with descriptor fields, do not
   // unpack the `sizes` and `strides` arrays.
-  SmallVector<LLVM::LLVMType, 5> types =
+  SmallVector<Type, 5> types =
       getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
   return LLVM::LLVMStructType::getLiteral(&getContext(), types);
 }
@@ -360,8 +360,7 @@ static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;
 /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
 ///    stack allocated (alloca) copy of a MemRef descriptor that got casted to
 ///    be unranked.
-SmallVector<LLVM::LLVMType, 2>
-LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
+SmallVector<Type, 2> LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
   return {getIndexType(), LLVM::LLVMPointerType::get(
                               LLVM::LLVMIntegerType::get(&getContext(), 8))};
 }
@@ -395,7 +394,7 @@ Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
   if (ShapedType::isDynamicStrideOrOffset(offset))
     return {};
 
-  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
+  Type elementType = unwrap(convertType(type.getElementType()));
   if (!elementType)
     return {};
   return LLVM::LLVMPointerType::get(elementType, type.getMemorySpace());
@@ -409,7 +408,7 @@ Type LLVMTypeConverter::convertVectorType(VectorType type) {
   auto elementType = unwrap(convertType(type.getElementType()));
   if (!elementType)
     return {};
-  LLVM::LLVMType vectorType =
+  Type vectorType =
       LLVM::LLVMFixedVectorType::get(elementType, type.getShape().back());
   auto shape = type.getShape();
   for (int i = shape.size() - 2; i >= 0; --i)
@@ -454,10 +453,9 @@ ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
 // StructBuilder implementation
 //===----------------------------------------------------------------------===//
 
-StructBuilder::StructBuilder(Value v) : value(v) {
+StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) {
   assert(value != nullptr && "value cannot be null");
-  structType = value.getType().dyn_cast<LLVM::LLVMType>();
-  assert(structType && "expected llvm type");
+  assert(LLVM::isCompatibleType(structType) && "expected llvm type");
 }
 
 Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
@@ -479,7 +477,7 @@ void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
 
 ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
                                                  Location loc, Type type) {
-  Value val = builder.create<LLVM::UndefOp>(loc, type.cast<LLVM::LLVMType>());
+  Value val = builder.create<LLVM::UndefOp>(loc, type);
   return ComplexStructBuilder(val);
 }
 
@@ -518,8 +516,7 @@ MemRefDescriptor::MemRefDescriptor(Value descriptor)
 MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
                                          Type descriptorType) {
 
-  Value descriptor =
-      builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
+  Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
   return MemRefDescriptor(descriptor);
 }
 
@@ -620,9 +617,8 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
 
 Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
                              int64_t rank) {
-  auto indexTy = indexType.cast<LLVM::LLVMType>();
-  auto indexPtrTy = LLVM::LLVMPointerType::get(indexTy);
-  auto arrayTy = LLVM::LLVMArrayType::get(indexTy, rank);
+  auto indexPtrTy = LLVM::LLVMPointerType::get(indexType);
+  auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank);
   auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy);
 
   // Copy size values to stack-allocated memory.
@@ -774,8 +770,7 @@ UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
 UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
                                                          Location loc,
                                                          Type descriptorType) {
-  Value descriptor =
-      builder.create<LLVM::UndefOp>(loc, descriptorType.cast<LLVM::LLVMType>());
+  Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
   return UnrankedMemRefDescriptor(descriptor);
 }
 Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
@@ -828,7 +823,7 @@ void UnrankedMemRefDescriptor::computeSizes(
     return;
 
   // Cache the index type.
-  LLVM::LLVMType indexType = typeConverter.getIndexType();
+  Type indexType = typeConverter.getIndexType();
 
   // Initialize shared constants.
   Value one = createIndexAttrConstant(builder, loc, indexType, 1);
@@ -868,7 +863,7 @@ void UnrankedMemRefDescriptor::computeSizes(
 
 Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
                                              Value memRefDescPtr,
-                                             LLVM::LLVMType elemPtrPtrType) {
+                                             Type elemPtrPtrType) {
 
   Value elementPtrPtr =
       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
@@ -877,7 +872,7 @@ Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
 
 void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
                                                Value memRefDescPtr,
-                                               LLVM::LLVMType elemPtrPtrType,
+                                               Type elemPtrPtrType,
                                                Value allocatedPtr) {
   Value elementPtrPtr =
       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
@@ -887,7 +882,7 @@ void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
 Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
                                            LLVMTypeConverter &typeConverter,
                                            Value memRefDescPtr,
-                                           LLVM::LLVMType elemPtrPtrType) {
+                                           Type elemPtrPtrType) {
   Value elementPtrPtr =
       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
 
@@ -901,7 +896,7 @@ Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
 void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
                                              LLVMTypeConverter &typeConverter,
                                              Value memRefDescPtr,
-                                             LLVM::LLVMType elemPtrPtrType,
+                                             Type elemPtrPtrType,
                                              Value alignedPtr) {
   Value elementPtrPtr =
       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
@@ -916,7 +911,7 @@ void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
 Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
                                        LLVMTypeConverter &typeConverter,
                                        Value memRefDescPtr,
-                                       LLVM::LLVMType elemPtrPtrType) {
+                                       Type elemPtrPtrType) {
   Value elementPtrPtr =
       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
 
@@ -932,8 +927,7 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
 void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
                                          LLVMTypeConverter &typeConverter,
                                          Value memRefDescPtr,
-                                         LLVM::LLVMType elemPtrPtrType,
-                                         Value offset) {
+                                         Type elemPtrPtrType, Value offset) {
   Value elementPtrPtr =
       builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
 
@@ -949,16 +943,15 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
 Value UnrankedMemRefDescriptor::sizeBasePtr(
     OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
     Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) {
-  LLVM::LLVMType elemPtrTy = elemPtrPtrType.getElementType();
-  LLVM::LLVMType indexTy = typeConverter.getIndexType();
-  LLVM::LLVMType structPtrTy =
+  Type elemPtrTy = elemPtrPtrType.getElementType();
+  Type indexTy = typeConverter.getIndexType();
+  Type structPtrTy =
       LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral(
           indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy}));
   Value structPtr =
       builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
 
-  LLVM::LLVMType int32_type =
-      unwrap(typeConverter.convertType(builder.getI32Type()));
+  Type int32_type = unwrap(typeConverter.convertType(builder.getI32Type()));
   Value zero =
       createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0);
   Value three = builder.create<LLVM::ConstantOp>(loc, int32_type,
@@ -970,8 +963,7 @@ Value UnrankedMemRefDescriptor::sizeBasePtr(
 Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
                                      LLVMTypeConverter typeConverter,
                                      Value sizeBasePtr, Value index) {
-  LLVM::LLVMType indexPtrTy =
-      LLVM::LLVMPointerType::get(typeConverter.getIndexType());
+  Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
   Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
                                                    ValueRange({index}));
   return builder.create<LLVM::LoadOp>(loc, sizeStoreGep);
@@ -981,8 +973,7 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
                                        LLVMTypeConverter typeConverter,
                                        Value sizeBasePtr, Value index,
                                        Value size) {
-  LLVM::LLVMType indexPtrTy =
-      LLVM::LLVMPointerType::get(typeConverter.getIndexType());
+  Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
   Value sizeStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
                                                    ValueRange({index}));
   builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep);
@@ -991,8 +982,7 @@ void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
 Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
                                               LLVMTypeConverter &typeConverter,
                                               Value sizeBasePtr, Value rank) {
-  LLVM::LLVMType indexPtrTy =
-      LLVM::LLVMPointerType::get(typeConverter.getIndexType());
+  Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
   return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
                                      ValueRange({rank}));
 }
@@ -1001,8 +991,7 @@ Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
                                        LLVMTypeConverter typeConverter,
                                        Value strideBasePtr, Value index,
                                        Value stride) {
-  LLVM::LLVMType indexPtrTy =
-      LLVM::LLVMPointerType::get(typeConverter.getIndexType());
+  Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
   Value strideStoreGep = builder.create<LLVM::GEPOp>(
       loc, indexPtrTy, strideBasePtr, ValueRange({index}));
   return builder.create<LLVM::LoadOp>(loc, strideStoreGep);
@@ -1012,8 +1001,7 @@ void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
                                          LLVMTypeConverter typeConverter,
                                          Value strideBasePtr, Value index,
                                          Value stride) {
-  LLVM::LLVMType indexPtrTy =
-      LLVM::LLVMPointerType::get(typeConverter.getIndexType());
+  Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType());
   Value strideStoreGep = builder.create<LLVM::GEPOp>(
       loc, indexPtrTy, strideBasePtr, ValueRange({index}));
   builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep);
@@ -1028,22 +1016,21 @@ LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
   return *getTypeConverter()->getDialect();
 }
 
-LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
+Type ConvertToLLVMPattern::getIndexType() const {
   return getTypeConverter()->getIndexType();
 }
 
-LLVM::LLVMType
-ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
+Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const {
   return LLVM::LLVMIntegerType::get(
       &getTypeConverter()->getContext(),
       getTypeConverter()->getPointerBitwidth(addressSpace));
 }
 
-LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const {
+Type ConvertToLLVMPattern::getVoidType() const {
   return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext());
 }
 
-LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const {
+Type ConvertToLLVMPattern::getVoidPtrType() const {
   return LLVM::LLVMPointerType::get(
       LLVM::LLVMIntegerType::get(&getTypeConverter()->getContext(), 8));
 }
@@ -1084,7 +1071,7 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
         index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
   }
 
-  LLVM::LLVMType elementPtrType = memRefDescriptor.getElementPtrType();
+  Type elementPtrType = memRefDescriptor.getElementPtrType();
   return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index)
                : base;
 }
@@ -1161,8 +1148,8 @@ Value ConvertToLLVMPattern::getSizeInBytes(
   //   %0 = getelementptr %elementType* null, %indexType 1
   //   %1 = ptrtoint %elementType* %0 to %indexType
   // which is a common pattern of getting the size of a type in bytes.
-  auto convertedPtrType = LLVM::LLVMPointerType::get(
-      typeConverter->convertType(type).cast<LLVM::LLVMType>());
+  auto convertedPtrType =
+      LLVM::LLVMPointerType::get(typeConverter->convertType(type));
   auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
   auto gep = rewriter.create<LLVM::GEPOp>(
       loc, convertedPtrType,
@@ -1276,7 +1263,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
                                  FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) {
   OpBuilder::InsertionGuard guard(builder);
 
-  LLVM::LLVMType wrapperType =
+  Type wrapperType =
       typeConverter.convertFunctionTypeCWrapper(funcOp.getType());
   // This conversion can only fail if it could not convert one of the argument
   // types. But since it has been applies to a non-wrapper function before, it
@@ -1318,8 +1305,7 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
                     builder, loc, typeConverter, unrankedMemRefType,
                     wrapperArgsRange.take_front(numToDrop));
 
-      auto ptrTy =
-          LLVM::LLVMPointerType::get(packed.getType().cast<LLVM::LLVMType>());
+      auto ptrTy = LLVM::LLVMPointerType::get(packed.getType());
       Value one = builder.create<LLVM::ConstantOp>(
           loc, typeConverter.convertType(builder.getIndexType()),
           builder.getIntegerAttr(builder.getIndexType(), 1));
@@ -1494,9 +1480,9 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
 // 1-D LLVM vectors.
 struct NDVectorTypeInfo {
   // LLVM array struct which encodes n-D vectors.
-  LLVM::LLVMType llvmArrayTy;
+  Type llvmArrayTy;
   // LLVM vector type which encodes the inner 1-D vector type.
-  LLVM::LLVMType llvmVectorTy;
+  Type llvmVectorTy;
   // Multiplicity of llvmArrayTy to llvmVectorTy.
   SmallVector<int64_t, 4> arraySizes;
 };
@@ -1510,10 +1496,11 @@ static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
                                                 LLVMTypeConverter &converter) {
   assert(vectorType.getRank() > 1 && "expected >1D vector type");
   NDVectorTypeInfo info;
-  info.llvmArrayTy =
-      converter.convertType(vectorType).dyn_cast<LLVM::LLVMType>();
-  if (!info.llvmArrayTy)
+  info.llvmArrayTy = converter.convertType(vectorType);
+  if (!info.llvmArrayTy || !LLVM::isCompatibleType(info.llvmArrayTy)) {
+    info.llvmArrayTy = nullptr;
     return info;
+  }
   info.arraySizes.reserve(vectorType.getRank() - 1);
   auto llvmTy = info.llvmArrayTy;
   while (llvmTy.isa<LLVM::LLVMArrayType>()) {
@@ -1610,14 +1597,14 @@ LogicalResult LLVM::detail::oneToOneRewrite(
 
 static LogicalResult handleMultidimensionalVectors(
     Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
-    std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
+    std::function<Value(Type, ValueRange)> createOperand,
     ConversionPatternRewriter &rewriter) {
   auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
   if (!vectorType)
     return failure();
   auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter);
   auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
-  auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
+  auto llvmArrayTy = operands[0].getType();
   if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
     return failure();
 
@@ -1645,14 +1632,14 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
 
   // Cannot convert ops if their operands are not of LLVM type.
   if (!llvm::all_of(operands.getTypes(),
-                    [](Type t) { return t.isa<LLVM::LLVMType>(); }))
+                    [](Type t) { return isCompatibleType(t); }))
     return failure();
 
-  auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
+  auto llvmArrayTy = operands[0].getType();
   if (!llvmArrayTy.isa<LLVM::LLVMArrayType>())
     return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
 
-  auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy,
+  auto callback = [op, targetOp, &rewriter](Type llvmVectorTy,
                                             ValueRange operands) {
     OperationState state(op->getLoc(), targetOp);
     state.addTypes(llvmVectorTy);
@@ -1896,16 +1883,18 @@ struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
                   ConversionPatternRewriter &rewriter) const override {
     // If constant refers to a function, convert it to "addressof".
     if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
-      auto type = typeConverter->convertType(op.getResult().getType())
-                      .dyn_cast_or_null<LLVM::LLVMType>();
-      if (!type)
+      auto type = typeConverter->convertType(op.getResult().getType());
+      if (!type || !LLVM::isCompatibleType(type))
         return rewriter.notifyMatchFailure(op, "failed to convert result type");
 
-      NamedAttrList attrs(op->getAttrDictionary());
-      attrs.erase("value");
-      rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
-          op, type.cast<LLVM::LLVMType>(), symbolRef.getValue(),
-          attrs.getAttrs());
+      auto newOp = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), type,
+                                                      symbolRef.getValue());
+      for (const NamedAttribute &attr : op->getAttrs()) {
+        if (attr.first.strref() == "value")
+          continue;
+        newOp.setAttr(attr.first, attr.second);
+      }
+      rewriter.replaceOp(op, newOp->getResults());
       return success();
     }
 
@@ -1947,11 +1936,11 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern {
   Value createAllocCall(Location loc, StringRef name, Type ptrType,
                         ArrayRef<Value> params, ModuleOp module,
                         ConversionPatternRewriter &rewriter) const {
-    SmallVector<LLVM::LLVMType, 2> paramTypes;
+    SmallVector<Type, 2> paramTypes;
     auto allocFuncOp = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
     if (!allocFuncOp) {
       for (Value param : params)
-        paramTypes.push_back(param.getType().cast<LLVM::LLVMType>());
+        paramTypes.push_back(param.getType());
       auto allocFuncType =
           LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes);
       OpBuilder::InsertionGuard guard(rewriter);
@@ -2206,10 +2195,10 @@ static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
   // Get frequently used types.
   MLIRContext *context = builder.getContext();
   auto voidType = LLVM::LLVMVoidType::get(context);
-  LLVM::LLVMType voidPtrType =
+  Type voidPtrType =
       LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8));
   auto i1Type = LLVM::LLVMIntegerType::get(context, 1);
-  LLVM::LLVMType indexType = typeConverter.getIndexType();
+  Type indexType = typeConverter.getIndexType();
 
   // Find the malloc and free, or declare them if necessary.
   auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
@@ -2389,17 +2378,15 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
 };
 
 /// Returns the LLVM type of the global variable given the memref type `type`.
-static LLVM::LLVMType
-convertGlobalMemrefTypeToLLVM(MemRefType type,
-                              LLVMTypeConverter &typeConverter) {
+static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
+                                          LLVMTypeConverter &typeConverter) {
   // LLVM type for a global memref will be a multi-dimension array. For
   // declarations or uninitialized global memrefs, we can potentially flatten
   // this to a 1D array. However, for global_memref's with an initial value,
   // we do not intend to flatten the ElementsAttribute when going from std ->
   // LLVM dialect, so the LLVM type needs to me a multi-dimension array.
-  LLVM::LLVMType elementType =
-      unwrap(typeConverter.convertType(type.getElementType()));
-  LLVM::LLVMType arrayTy = elementType;
+  Type elementType = unwrap(typeConverter.convertType(type.getElementType()));
+  Type arrayTy = elementType;
   // Shape has the outermost dim at index 0, so need to walk it backwards
   for (int64_t dim : llvm::reverse(type.getShape()))
     arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim);
@@ -2417,8 +2404,7 @@ struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
     if (!isConvertibleAndHasIdentityMaps(type))
       return failure();
 
-    LLVM::LLVMType arrayTy =
-        convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
+    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
 
     LLVM::Linkage linkage =
         global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;
@@ -2457,17 +2443,15 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
     MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
     unsigned memSpace = type.getMemorySpace();
 
-    LLVM::LLVMType arrayTy =
-        convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
+    Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
     auto addressOf = rewriter.create<LLVM::AddressOfOp>(
         loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name());
 
     // Get the address of the first element in the array by creating a GEP with
     // the address of the GV as the base, and (rank + 1) number of 0 indices.
-    LLVM::LLVMType elementType =
+    Type elementType =
         unwrap(typeConverter->convertType(type.getElementType()));
-    LLVM::LLVMType elementPtrType =
-        LLVM::LLVMPointerType::get(elementType, memSpace);
+    Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace);
 
     SmallVector<Value, 4> operands = {addressOf};
     operands.insert(operands.end(), type.getRank() + 1,
@@ -2497,10 +2481,9 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
   matchAndRewrite(RsqrtOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
     RsqrtOp::Adaptor transformed(operands);
-    auto operandType =
-        transformed.operand().getType().dyn_cast<LLVM::LLVMType>();
+    auto operandType = transformed.operand().getType();
 
-    if (!operandType)
+    if (!operandType || !LLVM::isCompatibleType(operandType))
       return failure();
 
     auto loc = op.getLoc();
@@ -2528,7 +2511,7 @@ struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
 
     return handleMultidimensionalVectors(
         op.getOperation(), operands, *getTypeConverter(),
-        [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
+        [&](Type llvmVectorTy, ValueRange operands) {
           auto splatAttr = SplatElementsAttr::get(
               mlir::VectorType::get(
                   {llvmVectorTy.cast<LLVM::LLVMFixedVectorType>()
@@ -2620,13 +2603,11 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
       // ptr = ExtractValueOp src, 1
       auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
       // castPtr = BitCastOp i8* to structTy*
-      auto castPtr = rewriter
-                         .create<LLVM::BitcastOp>(
-                             loc,
-                             LLVM::LLVMPointerType::get(
-                                 targetStructType.cast<LLVM::LLVMType>()),
-                             ptr)
-                         .getResult();
+      auto castPtr =
+          rewriter
+              .create<LLVM::BitcastOp>(
+                  loc, LLVM::LLVMPointerType::get(targetStructType), ptr)
+              .getResult();
       // struct = LoadOp castPtr
       auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
       rewriter.replaceOp(memRefCastOp, loadOp.getResult());
@@ -2659,9 +2640,8 @@ static void extractPointersAndOffset(Location loc,
   unsigned memorySpace =
       operandType.cast<UnrankedMemRefType>().getMemorySpace();
   Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
-  LLVM::LLVMType llvmElementType =
-      unwrap(typeConverter.convertType(elementType));
-  LLVM::LLVMType elementPtrPtrType = LLVM::LLVMPointerType::get(
+  Type llvmElementType = unwrap(typeConverter.convertType(elementType));
+  Type elementPtrPtrType = LLVM::LLVMPointerType::get(
       LLVM::LLVMPointerType::get(llvmElementType, memorySpace));
 
   // Extract pointer to the underlying ranked memref descriptor and cast it to
@@ -2809,8 +2789,7 @@ struct MemRefReshapeOpLowering
                              &allocatedPtr, &alignedPtr, &offset);
 
     // Set pointers and offset.
-    LLVM::LLVMType llvmElementType =
-        unwrap(typeConverter->convertType(elementType));
+    Type llvmElementType = unwrap(typeConverter->convertType(elementType));
     auto elementPtrPtrType = LLVM::LLVMPointerType::get(
         LLVM::LLVMPointerType::get(llvmElementType, addressSpace));
     UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
@@ -2835,7 +2814,7 @@ struct MemRefReshapeOpLowering
         rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex);
 
     Block *initBlock = rewriter.getInsertionBlock();
-    LLVM::LLVMType indexType = getTypeConverter()->getIndexType();
+    Type indexType = getTypeConverter()->getIndexType();
     Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint());
 
     Block *condBlock = rewriter.createBlock(initBlock->getParent(), {},
@@ -2865,7 +2844,7 @@ struct MemRefReshapeOpLowering
     rewriter.setInsertionPointToStart(bodyBlock);
 
     // Copy size from shape to descriptor.
-    LLVM::LLVMType llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
+    Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType);
     Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
         loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg});
     Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep);
@@ -2957,9 +2936,8 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<DimOp> {
     Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);
     Value scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
         loc,
-        LLVM::LLVMPointerType::get(
-            typeConverter->convertType(scalarMemRefType).cast<LLVM::LLVMType>(),
-            addressSpace),
+        LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType),
+                                   addressSpace),
         underlyingRankedDesc);
 
     // Get pointer to offset field of memref<element_type> descriptor.
@@ -3435,8 +3413,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
 
     auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
     auto sourceElementTy =
-        typeConverter->convertType(sourceMemRefType.getElementType())
-            .dyn_cast_or_null<LLVM::LLVMType>();
+        typeConverter->convertType(sourceMemRefType.getElementType());
 
     auto viewMemRefType = subViewOp.getType();
     auto inferredType = SubViewOp::inferResultType(
@@ -3446,11 +3423,12 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
                             extractFromI64ArrayAttr(subViewOp.static_strides()))
                             .cast<MemRefType>();
     auto targetElementTy =
-        typeConverter->convertType(viewMemRefType.getElementType())
-            .dyn_cast<LLVM::LLVMType>();
-    auto targetDescTy = typeConverter->convertType(viewMemRefType)
-                            .dyn_cast_or_null<LLVM::LLVMType>();
-    if (!sourceElementTy || !targetDescTy)
+        typeConverter->convertType(viewMemRefType.getElementType());
+    auto targetDescTy = typeConverter->convertType(viewMemRefType);
+    if (!sourceElementTy || !targetDescTy || !targetElementTy ||
+        !LLVM::isCompatibleType(sourceElementTy) ||
+        !LLVM::isCompatibleType(targetElementTy) ||
+        !LLVM::isCompatibleType(targetDescTy))
       return failure();
 
     // Extract the offset and strides from the type.
@@ -3461,7 +3439,7 @@ struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
       return failure();
 
     // Create the descriptor.
-    if (!operands.front().getType().isa<LLVM::LLVMType>())
+    if (!LLVM::isCompatibleType(operands.front().getType()))
       return failure();
     MemRefDescriptor sourceMemRef(operands.front());
     auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);
@@ -3650,11 +3628,11 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<ViewOp> {
 
     auto viewMemRefType = viewOp.getType();
     auto targetElementTy =
-        typeConverter->convertType(viewMemRefType.getElementType())
-            .dyn_cast<LLVM::LLVMType>();
-    auto targetDescTy =
-        typeConverter->convertType(viewMemRefType).dyn_cast<LLVM::LLVMType>();
-    if (!targetDescTy)
+        typeConverter->convertType(viewMemRefType.getElementType());
+    auto targetDescTy = typeConverter->convertType(viewMemRefType);
+    if (!targetDescTy || !targetElementTy ||
+        !LLVM::isCompatibleType(targetElementTy) ||
+        !LLVM::isCompatibleType(targetDescTy))
       return viewOp.emitWarning("Target descriptor type not converted to LLVM"),
              failure();
 
@@ -3849,9 +3827,7 @@ struct GenericAtomicRMWOpLowering
 
     auto loc = atomicOp.getLoc();
     GenericAtomicRMWOp::Adaptor adaptor(operands);
-    LLVM::LLVMType valueType =
-        typeConverter->convertType(atomicOp.getResult().getType())
-            .cast<LLVM::LLVMType>();
+    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());
 
     // Split the block into initial, loop, and ending parts.
     auto *initBlock = rewriter.getInsertionBlock();
@@ -4060,12 +4036,11 @@ Type LLVMTypeConverter::packFunctionResults(ArrayRef<Type> types) {
   if (types.size() == 1)
     return convertCallingConventionType(types.front());
 
-  SmallVector<LLVM::LLVMType, 8> resultTypes;
+  SmallVector<Type, 8> resultTypes;
   resultTypes.reserve(types.size());
   for (auto t : types) {
-    auto converted =
-        convertCallingConventionType(t).dyn_cast_or_null<LLVM::LLVMType>();
-    if (!converted)
+    auto converted = convertCallingConventionType(t);
+    if (!converted || !LLVM::isCompatibleType(converted))
       return {};
     resultTypes.push_back(converted);
   }
@@ -4080,8 +4055,7 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
   auto indexType = IndexType::get(context);
   // Alloca with proper alignment. We do not expect optimizations of this
   // alloca op and so we omit allocating at the entry block.
-  auto ptrType =
-      LLVM::LLVMPointerType::get(operand.getType().cast<LLVM::LLVMType>());
+  auto ptrType = LLVM::LLVMPointerType::get(operand.getType());
   Value one = builder.create<LLVM::ConstantOp>(loc, int64Ty,
                                                IntegerAttr::get(indexType, 1));
   Value allocated =

diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index e023b6da460b..535cdcb7dfd7 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -152,8 +152,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
   // stop depending on translation.
   llvm::LLVMContext llvmContext;
   align = LLVM::TypeToLLVMIRTranslator(llvmContext)
-              .getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
-                                     typeConverter.getDataLayout());
+              .getPreferredAlignment(elementTy, typeConverter.getDataLayout());
   return success();
 }
 
@@ -193,7 +192,7 @@ static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter,
   Value base;
   if (failed(getBase(rewriter, loc, memref, memRefType, base)))
     return failure();
-  auto pType = LLVM::LLVMPointerType::get(type.template cast<LLVM::LLVMType>());
+  auto pType = LLVM::LLVMPointerType::get(type);
   base = rewriter.create<LLVM::BitcastOp>(loc, pType, base);
   ptr = rewriter.create<LLVM::GEPOp>(loc, pType, base);
   return success();
@@ -1401,7 +1400,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
 
   // Helper for printer method declaration (first hit) and lookup.
   static Operation *getPrint(Operation *op, StringRef name,
-                             ArrayRef<LLVM::LLVMType> params) {
+                             ArrayRef<Type> params) {
     auto module = op->getParentOfType<ModuleOp>();
     auto func = module.lookupSymbol<LLVM::LLVMFuncOp>(name);
     if (func)

diff  --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 3e3ddc6aaff6..d27f097a3baa 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -30,8 +30,8 @@ using namespace mlir::vector;
 static LogicalResult replaceTransferOpWithMubuf(
     ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
     LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp,
-    LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
-    Value &offsetSizeInBytes, Value &glc, Value &slc) {
+    Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
+    Value &glc, Value &slc) {
   rewriter.replaceOpWithNewOp<ROCDL::MubufLoadOp>(
       xferOp, vecTy, dwordConfig, vindex, offsetSizeInBytes, glc, slc);
   return success();
@@ -40,8 +40,8 @@ static LogicalResult replaceTransferOpWithMubuf(
 static LogicalResult replaceTransferOpWithMubuf(
     ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
     LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
-    LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
-    Value &offsetSizeInBytes, Value &glc, Value &slc) {
+    Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes,
+    Value &glc, Value &slc) {
   auto adaptor = TransferWriteOpAdaptor(operands);
   rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(xferOp, adaptor.vector(),
                                                    dwordConfig, vindex,
@@ -121,16 +121,16 @@ class VectorTransferConversion : public ConvertOpToLLVMPattern<ConcreteOp> {
     Type i64Ty = rewriter.getIntegerType(64);
     Value i64x2Ty = rewriter.create<LLVM::BitcastOp>(
         loc,
-        LLVM::LLVMFixedVectorType::get(
-            toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
+        LLVM::LLVMFixedVectorType::get(toLLVMTy(i64Ty).template cast<Type>(),
+                                       2),
         constConfig);
     Value dataPtrAsI64 = rewriter.create<LLVM::PtrToIntOp>(
-        loc, toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), dataPtr);
+        loc, toLLVMTy(i64Ty).template cast<Type>(), dataPtr);
     Value zero = this->createIndexConstant(rewriter, loc, 0);
     Value dwordConfig = rewriter.create<LLVM::InsertElementOp>(
         loc,
-        LLVM::LLVMFixedVectorType::get(
-            toLLVMTy(i64Ty).template cast<LLVM::LLVMType>(), 2),
+        LLVM::LLVMFixedVectorType::get(toLLVMTy(i64Ty).template cast<Type>(),
+                                       2),
         i64x2Ty, dataPtrAsI64, zero);
     dwordConfig =
         rewriter.create<LLVM::BitcastOp>(loc, toLLVMTy(i32Vecx4), dwordConfig);

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 765538ca7a53..0a9b61628384 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -101,14 +101,14 @@ static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
 
   // The result type is either i1 or a vector type <? x i1> if the inputs are
   // vectors.
-  LLVMType resultType = LLVMIntegerType::get(builder.getContext(), 1);
-  auto argType = type.dyn_cast<LLVM::LLVMType>();
-  if (!argType)
-    return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
-  if (auto vecArgType = argType.dyn_cast<LLVM::LLVMFixedVectorType>())
+  Type resultType = LLVMIntegerType::get(builder.getContext(), 1);
+  if (!isCompatibleType(type))
+    return parser.emitError(trailingTypeLoc,
+                            "expected LLVM dialect-compatible type");
+  if (auto vecArgType = type.dyn_cast<LLVM::LLVMFixedVectorType>())
     resultType =
         LLVMFixedVectorType::get(resultType, vecArgType.getNumElements());
-  assert(!argType.isa<LLVM::LLVMScalableVectorType>() &&
+  assert(!type.isa<LLVM::LLVMScalableVectorType>() &&
          "unhandled scalable vector");
 
   result.addTypes({resultType});
@@ -546,21 +546,21 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
       return parser.emitError(trailingTypeLoc,
                               "expected function with 0 or 1 result");
 
-    LLVM::LLVMType llvmResultType;
+    Type llvmResultType;
     if (funcType.getNumResults() == 0) {
       llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
     } else {
-      llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
-      if (!llvmResultType)
+      llvmResultType = funcType.getResult(0);
+      if (!isCompatibleType(llvmResultType))
         return parser.emitError(trailingTypeLoc,
                                 "expected result to have LLVM type");
     }
 
-    SmallVector<LLVM::LLVMType, 8> argTypes;
+    SmallVector<Type, 8> argTypes;
     argTypes.reserve(funcType.getNumInputs());
     for (Type ty : funcType.getInputs()) {
-      if (auto argType = ty.dyn_cast<LLVM::LLVMType>())
-        argTypes.push_back(argType);
+      if (isCompatibleType(ty))
+        argTypes.push_back(ty);
       else
         return parser.emitError(trailingTypeLoc,
                                 "expected LLVM types as inputs");
@@ -693,7 +693,7 @@ static LogicalResult verify(CallOp &op) {
 
   // Type for the callee, we'll get it differently depending if it is a direct
   // or indirect call.
-  LLVMType fnType;
+  Type fnType;
 
   bool isIndirect = false;
 
@@ -704,14 +704,10 @@ static LogicalResult verify(CallOp &op) {
     if (!op.getNumOperands())
       return op.emitOpError(
           "must have either a `callee` attribute or at least an operand");
-    fnType = op.getOperand(0).getType().dyn_cast<LLVMType>();
-    if (!fnType)
-      return op.emitOpError("indirect call to a non-llvm type: ")
-             << op.getOperand(0).getType();
-    auto ptrType = fnType.dyn_cast<LLVMPointerType>();
+    auto ptrType = op.getOperand(0).getType().dyn_cast<LLVMPointerType>();
     if (!ptrType)
       return op.emitOpError("indirect call expects a pointer as callee: ")
-             << fnType;
+             << ptrType;
     fnType = ptrType.getElementType();
   } else {
     Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
@@ -825,21 +821,21 @@ static ParseResult parseCallOp(OpAsmParser &parser, OperationState &result) {
                               "expected function with 0 or 1 result");
 
     Builder &builder = parser.getBuilder();
-    LLVM::LLVMType llvmResultType;
+    Type llvmResultType;
     if (funcType.getNumResults() == 0) {
       llvmResultType = LLVM::LLVMVoidType::get(builder.getContext());
     } else {
-      llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
-      if (!llvmResultType)
+      llvmResultType = funcType.getResult(0);
+      if (!isCompatibleType(llvmResultType))
         return parser.emitError(trailingTypeLoc,
                                 "expected result to have LLVM type");
     }
 
-    SmallVector<LLVM::LLVMType, 8> argTypes;
+    SmallVector<Type, 8> argTypes;
     argTypes.reserve(funcType.getNumInputs());
     for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) {
-      auto argType = funcType.getInput(i).dyn_cast<LLVM::LLVMType>();
-      if (!argType)
+      auto argType = funcType.getInput(i);
+      if (!isCompatibleType(argType))
         return parser.emitError(trailingTypeLoc,
                                 "expected LLVM types as inputs");
       argTypes.push_back(argType);
@@ -922,13 +918,13 @@ static void printExtractValueOp(OpAsmPrinter &p, ExtractValueOp &op) {
 // `containerType`.  Position is an integer array attribute where each value
 // is a zero-based position of the element in the aggregate type.  Return the
 // resulting type wrapped in MLIR, or nullptr on error.
-static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser,
-                                                       Type containerType,
-                                                       ArrayAttr positionAttr,
-                                                       llvm::SMLoc attributeLoc,
-                                                       llvm::SMLoc typeLoc) {
-  auto llvmType = containerType.dyn_cast<LLVM::LLVMType>();
-  if (!llvmType)
+static Type getInsertExtractValueElementType(OpAsmParser &parser,
+                                             Type containerType,
+                                             ArrayAttr positionAttr,
+                                             llvm::SMLoc attributeLoc,
+                                             llvm::SMLoc typeLoc) {
+  Type llvmType = containerType;
+  if (!isCompatibleType(containerType))
     return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr;
 
   // Infer the element type from the structure type: iteratively step inside the
@@ -1162,7 +1158,7 @@ static LogicalResult verify(AddressOfOp op) {
 /// the name of the attribute in ODS.
 static StringRef getLinkageAttrName() { return "linkage"; }
 
-void GlobalOp::build(OpBuilder &builder, OperationState &result, LLVMType type,
+void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
                      bool isConstant, Linkage linkage, StringRef name,
                      Attribute value, unsigned addrSpace,
                      ArrayRef<NamedAttribute> attrs) {
@@ -1212,14 +1208,13 @@ static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
 /// report the error, the user is expected to produce an appropriate message.
 // TODO: make the size depend on data layout rather than on the conversion
 // pass option, and pull that information here.
-static LogicalResult verifyCastWithIndex(LLVMType llvmType) {
+static LogicalResult verifyCastWithIndex(Type llvmType) {
   return success(llvmType.isa<LLVMIntegerType>());
 }
 
 /// Checks if `llvmType` is dialect cast-compatible with built-in `type` and
 /// reports errors to the location of `op`.
-static LogicalResult verifyCast(DialectCastOp op, LLVMType llvmType,
-                                Type type) {
+static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type) {
   // Index is compatible with any integer.
   if (type.isIndex()) {
     if (succeeded(verifyCastWithIndex(llvmType)))
@@ -1387,14 +1382,13 @@ static LogicalResult verifyCast(DialectCastOp op, LLVMType llvmType,
 }
 
 static LogicalResult verify(DialectCastOp op) {
-  if (auto llvmType = op.getType().dyn_cast<LLVMType>())
-    return verifyCast(op, llvmType, op.in().getType());
+  if (isCompatibleType(op.getType()))
+    return verifyCast(op, op.getType(), op.in().getType());
 
-  auto llvmType = op.in().getType().dyn_cast<LLVMType>();
-  if (!llvmType)
+  if (!isCompatibleType(op.in().getType()))
     return op->emitOpError("expected one LLVM type and one built-in type");
 
-  return verifyCast(op, llvmType, op.getType());
+  return verifyCast(op, op.in().getType(), op.getType());
 }
 
 // Parses one of the keywords provided in the list `keywords` and returns the
@@ -1597,7 +1591,7 @@ Block *LLVMFuncOp::addEntryBlock() {
 }
 
 void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
-                       StringRef name, LLVMType type, LLVM::Linkage linkage,
+                       StringRef name, Type type, LLVM::Linkage linkage,
                        ArrayRef<NamedAttribute> attrs,
                        ArrayRef<DictionaryAttr> argAttrs) {
   result.addRegion();
@@ -1633,23 +1627,23 @@ static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
   }
 
   // Convert inputs to LLVM types, exit early on error.
-  SmallVector<LLVMType, 4> llvmInputs;
+  SmallVector<Type, 4> llvmInputs;
   for (auto t : inputs) {
-    auto llvmTy = t.dyn_cast<LLVMType>();
-    if (!llvmTy) {
+    if (!isCompatibleType(t)) {
       parser.emitError(loc, "failed to construct function type: expected LLVM "
                             "type for function arguments");
       return {};
     }
-    llvmInputs.push_back(llvmTy);
+    llvmInputs.push_back(t);
   }
 
   // No output is denoted as "void" in LLVM type system.
-  LLVMType llvmOutput = outputs.empty() ? LLVMVoidType::get(b.getContext())
-                                        : outputs.front().dyn_cast<LLVMType>();
-  if (!llvmOutput) {
+  Type llvmOutput =
+      outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
+  if (!isCompatibleType(llvmOutput)) {
     parser.emitError(loc, "failed to construct function type: expected LLVM "
-                          "type for function results");
+                          "type for function results")
+        << llvmOutput;
     return {};
   }
   return LLVMFunctionType::get(llvmOutput, llvmInputs,
@@ -1720,7 +1714,7 @@ static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
   for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
     argTypes.push_back(fnType.getParamType(i));
 
-  LLVMType returnType = fnType.getReturnType();
+  Type returnType = fnType.getReturnType();
   if (!returnType.isa<LLVMVoidType>())
     resTypes.push_back(returnType);
 
@@ -1792,11 +1786,10 @@ static LogicalResult verify(LLVMFuncOp op) {
   Block &entryBlock = op.front();
   for (unsigned i = 0; i < numArguments; ++i) {
     Type argType = entryBlock.getArgument(i).getType();
-    auto argLLVMType = argType.dyn_cast<LLVMType>();
-    if (!argLLVMType)
+    if (!isCompatibleType(argType))
       return op.emitOpError("entry block argument #")
              << i << " is not of LLVM type";
-    if (op.getType().getParamType(i) != argLLVMType)
+    if (op.getType().getParamType(i) != argType)
       return op.emitOpError("the type of entry block argument #")
              << i << " does not match the function signature";
   }
@@ -1889,7 +1882,7 @@ static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) {
 //                 attribute-dict? `:` type
 static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
                                     OperationState &result) {
-  LLVMType type;
+  Type type;
   OpAsmParser::OperandType ptr, val;
   if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) ||
       parser.parseComma() || parser.parseOperand(val) ||
@@ -1907,11 +1900,11 @@ static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
 
 static LogicalResult verify(AtomicRMWOp op) {
   auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
-  auto valType = op.val().getType().cast<LLVM::LLVMType>();
+  auto valType = op.val().getType();
   if (valType != ptrType.getElementType())
     return op.emitOpError("expected LLVM IR element type for operand #0 to "
                           "match type for operand #1");
-  auto resType = op.res().getType().cast<LLVM::LLVMType>();
+  auto resType = op.res().getType();
   if (resType != valType)
     return op.emitOpError(
         "expected LLVM IR result type to match type for operand #1");
@@ -1954,7 +1947,7 @@ static void printAtomicCmpXchgOp(OpAsmPrinter &p, AtomicCmpXchgOp &op) {
 static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
                                         OperationState &result) {
   auto &builder = parser.getBuilder();
-  LLVMType type;
+  Type type;
   OpAsmParser::OperandType ptr, cmp, val;
   if (parser.parseOperand(ptr) || parser.parseComma() ||
       parser.parseOperand(cmp) || parser.parseComma() ||
@@ -1981,8 +1974,8 @@ static LogicalResult verify(AtomicCmpXchgOp op) {
   auto ptrType = op.ptr().getType().cast<LLVM::LLVMPointerType>();
   if (!ptrType)
     return op.emitOpError("expected LLVM IR pointer type for operand #0");
-  auto cmpType = op.cmp().getType().cast<LLVM::LLVMType>();
-  auto valType = op.val().getType().cast<LLVM::LLVMType>();
+  auto cmpType = op.cmp().getType();
+  auto valType = op.val().getType();
   if (cmpType != ptrType.getElementType() || cmpType != valType)
     return op.emitOpError("expected LLVM IR element type for operand #0 to "
                           "match type for all other operands");
@@ -2088,7 +2081,7 @@ Type LLVMDialect::parseType(DialectAsmParser &parser) const {
 
 /// Print a type registered to this dialect.
 void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const {
-  return detail::printType(type.cast<LLVMType>(), os);
+  return detail::printType(type, os);
 }
 
 LogicalResult LLVMDialect::verifyDataLayoutString(

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 574d0aa8c37f..3d72e254f338 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -19,11 +19,11 @@ using namespace mlir::LLVM;
 // Printing.
 //===----------------------------------------------------------------------===//
 
-static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
+static void printTypeImpl(llvm::raw_ostream &os, Type type,
                           llvm::SetVector<StringRef> &stack);
 
 /// Returns the keyword to use for the given type.
-static StringRef getTypeKeyword(LLVMType type) {
+static StringRef getTypeKeyword(Type type) {
   return TypeSwitch<Type, StringRef>(type)
       .Case<LLVMVoidType>([&](Type) { return "void"; })
       .Case<LLVMHalfType>([&](Type) { return "half"; })
@@ -64,7 +64,7 @@ static void printStructTypeBody(llvm::raw_ostream &os, LLVMStructType type,
   os << '(';
   if (type.isIdentified())
     stack.insert(type.getName());
-  llvm::interleaveComma(type.getBody(), os, [&](LLVMType subtype) {
+  llvm::interleaveComma(type.getBody(), os, [&](Type subtype) {
     printTypeImpl(os, subtype, stack);
   });
   if (type.isIdentified())
@@ -109,9 +109,9 @@ static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType,
   os << '<';
   printTypeImpl(os, funcType.getReturnType(), stack);
   os << " (";
-  llvm::interleaveComma(
-      funcType.getParams(), os,
-      [&os, &stack](LLVMType subtype) { printTypeImpl(os, subtype, stack); });
+  llvm::interleaveComma(funcType.getParams(), os, [&os, &stack](Type subtype) {
+    printTypeImpl(os, subtype, stack);
+  });
   if (funcType.isVarArg()) {
     if (funcType.getNumParams() != 0)
       os << ", ";
@@ -129,7 +129,7 @@ static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType,
 ///   struct<"c", (ptr<struct<"b", (ptr<struct<"c">>)>>,
 ///                ptr<struct<"b", (ptr<struct<"c">>)>>)>
 /// note that "b" is printed twice.
-static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
+static void printTypeImpl(llvm::raw_ostream &os, Type type,
                           llvm::SetVector<StringRef> &stack) {
   if (!type) {
     os << "<<NULL-TYPE>>";
@@ -171,7 +171,7 @@ static void printTypeImpl(llvm::raw_ostream &os, LLVMType type,
     return printFunctionType(os, funcType, stack);
 }
 
-void mlir::LLVM::detail::printType(LLVMType type, DialectAsmPrinter &printer) {
+void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) {
   llvm::SetVector<StringRef> stack;
   return printTypeImpl(printer.getStream(), type, stack);
 }
@@ -180,13 +180,13 @@ void mlir::LLVM::detail::printType(LLVMType type, DialectAsmPrinter &printer) {
 // Parsing.
 //===----------------------------------------------------------------------===//
 
-static LLVMType parseTypeImpl(DialectAsmParser &parser,
-                              llvm::SetVector<StringRef> &stack);
+static Type parseTypeImpl(DialectAsmParser &parser,
+                          llvm::SetVector<StringRef> &stack);
 
 /// Helper to be chained with other parsing functions.
 static ParseResult parseTypeImpl(DialectAsmParser &parser,
                                  llvm::SetVector<StringRef> &stack,
-                                 LLVMType &result) {
+                                 Type &result) {
   result = parseTypeImpl(parser, stack);
   return success(result != nullptr);
 }
@@ -196,7 +196,7 @@ static ParseResult parseTypeImpl(DialectAsmParser &parser,
 static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
                                           llvm::SetVector<StringRef> &stack) {
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
-  LLVMType returnType;
+  Type returnType;
   if (parser.parseLess() || parseTypeImpl(parser, stack, returnType) ||
       parser.parseLParen())
     return LLVMFunctionType();
@@ -210,7 +210,7 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
   }
 
   // Parse arguments.
-  SmallVector<LLVMType, 8> argTypes;
+  SmallVector<Type, 8> argTypes;
   do {
     if (succeeded(parser.parseOptionalEllipsis())) {
       if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
@@ -235,7 +235,7 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
 static LLVMPointerType parsePointerType(DialectAsmParser &parser,
                                         llvm::SetVector<StringRef> &stack) {
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
-  LLVMType elementType;
+  Type elementType;
   if (parser.parseLess() || parseTypeImpl(parser, stack, elementType))
     return LLVMPointerType();
 
@@ -255,7 +255,7 @@ static LLVMVectorType parseVectorType(DialectAsmParser &parser,
                                       llvm::SetVector<StringRef> &stack) {
   SmallVector<int64_t, 2> dims;
   llvm::SMLoc dimPos;
-  LLVMType elementType;
+  Type elementType;
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
       parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
@@ -286,7 +286,7 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser,
                                     llvm::SetVector<StringRef> &stack) {
   SmallVector<int64_t, 1> dims;
   llvm::SMLoc sizePos;
-  LLVMType elementType;
+  Type elementType;
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
       parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
@@ -305,11 +305,11 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser,
 /// error at `subtypesLoc` in case of failure, uses `stack` to make sure the
 /// types printed in the error message look like they did when parsed.
 static LLVMStructType trySetStructBody(LLVMStructType type,
-                                       ArrayRef<LLVMType> subtypes,
-                                       bool isPacked, DialectAsmParser &parser,
+                                       ArrayRef<Type> subtypes, bool isPacked,
+                                       DialectAsmParser &parser,
                                        llvm::SMLoc subtypesLoc,
                                        llvm::SetVector<StringRef> &stack) {
-  for (LLVMType t : subtypes) {
+  for (Type t : subtypes) {
     if (!LLVMStructType::isValidElementType(t)) {
       parser.emitError(subtypesLoc)
           << "invalid LLVM structure element type: " << t;
@@ -389,12 +389,12 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
 
   // Parse subtypes. For identified structs, put the identifier of the struct on
   // the stack to support self-references in the recursive calls.
-  SmallVector<LLVMType, 4> subtypes;
+  SmallVector<Type, 4> subtypes;
   llvm::SMLoc subtypesLoc = parser.getCurrentLocation();
   do {
     if (isIdentified)
       stack.insert(name);
-    LLVMType type = parseTypeImpl(parser, stack);
+    Type type = parseTypeImpl(parser, stack);
     if (!type)
       return LLVMStructType();
     subtypes.push_back(type);
@@ -413,8 +413,8 @@ static LLVMStructType parseStructType(DialectAsmParser &parser,
 }
 
 /// Parses one of the LLVM dialect types.
-static LLVMType parseTypeImpl(DialectAsmParser &parser,
-                              llvm::SetVector<StringRef> &stack) {
+static Type parseTypeImpl(DialectAsmParser &parser,
+                          llvm::SetVector<StringRef> &stack) {
   // Special case for integers (i[1-9][0-9]*) that are literals rather than
   // keywords for the parser, so they are not caught by the main dispatch below.
   // Try parsing it a built-in integer type instead.
@@ -425,11 +425,11 @@ static LLVMType parseTypeImpl(DialectAsmParser &parser,
   OptionalParseResult result = parser.parseOptionalType(maybeIntegerType);
   if (result.hasValue()) {
     if (failed(*result))
-      return LLVMType();
+      return Type();
 
     if (!maybeIntegerType.isSignlessInteger()) {
       parser.emitError(keyLoc) << "unexpected type, expected i* or keyword";
-      return LLVMType();
+      return Type();
     }
     return LLVMIntegerType::getChecked(
         loc, maybeIntegerType.getIntOrFloatBitWidth());
@@ -438,9 +438,9 @@ static LLVMType parseTypeImpl(DialectAsmParser &parser,
   // Dispatch to concrete functions.
   StringRef key;
   if (failed(parser.parseKeyword(&key)))
-    return LLVMType();
+    return Type();
 
-  return StringSwitch<function_ref<LLVMType()>>(key)
+  return StringSwitch<function_ref<Type()>>(key)
       .Case("void", [&] { return LLVMVoidType::get(ctx); })
       .Case("half", [&] { return LLVMHalfType::get(ctx); })
       .Case("bfloat", [&] { return LLVMBFloatType::get(ctx); })
@@ -460,11 +460,11 @@ static LLVMType parseTypeImpl(DialectAsmParser &parser,
       .Case("struct", [&] { return parseStructType(parser, stack); })
       .Default([&] {
         parser.emitError(keyLoc) << "unknown LLVM type: " << key;
-        return LLVMType();
+        return Type();
       })();
 }
 
-LLVMType mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
+Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
   llvm::SetVector<StringRef> stack;
   return parseTypeImpl(parser, stack);
 }

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 3d75245a1fb3..c982abf8ad72 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -24,44 +24,32 @@
 using namespace mlir;
 using namespace mlir::LLVM;
 
-//===----------------------------------------------------------------------===//
-// LLVMType.
-//===----------------------------------------------------------------------===//
-
-bool LLVMType::classof(Type type) {
-  return llvm::isa<LLVMDialect>(type.getDialect());
-}
-
-LLVMDialect &LLVMType::getDialect() {
-  return static_cast<LLVMDialect &>(Type::getDialect());
-}
-
 //===----------------------------------------------------------------------===//
 // Array type.
 //===----------------------------------------------------------------------===//
 
-bool LLVMArrayType::isValidElementType(LLVMType type) {
+bool LLVMArrayType::isValidElementType(Type type) {
   return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
                    LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
 }
 
-LLVMArrayType LLVMArrayType::get(LLVMType elementType, unsigned numElements) {
+LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) {
   assert(elementType && "expected non-null subtype");
   return Base::get(elementType.getContext(), elementType, numElements);
 }
 
-LLVMArrayType LLVMArrayType::getChecked(Location loc, LLVMType elementType,
+LLVMArrayType LLVMArrayType::getChecked(Location loc, Type elementType,
                                         unsigned numElements) {
   assert(elementType && "expected non-null subtype");
   return Base::getChecked(loc, elementType, numElements);
 }
 
-LLVMType LLVMArrayType::getElementType() { return getImpl()->elementType; }
+Type LLVMArrayType::getElementType() { return getImpl()->elementType; }
 
 unsigned LLVMArrayType::getNumElements() { return getImpl()->numElements; }
 
 LogicalResult
-LLVMArrayType::verifyConstructionInvariants(Location loc, LLVMType elementType,
+LLVMArrayType::verifyConstructionInvariants(Location loc, Type elementType,
                                             unsigned numElements) {
   if (!isValidElementType(elementType))
     return emitError(loc, "invalid array element type: ") << elementType;
@@ -72,52 +60,50 @@ LLVMArrayType::verifyConstructionInvariants(Location loc, LLVMType elementType,
 // Function type.
 //===----------------------------------------------------------------------===//
 
-bool LLVMFunctionType::isValidArgumentType(LLVMType type) {
+bool LLVMFunctionType::isValidArgumentType(Type type) {
   return !type.isa<LLVMVoidType, LLVMFunctionType>();
 }
 
-bool LLVMFunctionType::isValidResultType(LLVMType type) {
+bool LLVMFunctionType::isValidResultType(Type type) {
   return !type.isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>();
 }
 
-LLVMFunctionType LLVMFunctionType::get(LLVMType result,
-                                       ArrayRef<LLVMType> arguments,
+LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef<Type> arguments,
                                        bool isVarArg) {
   assert(result && "expected non-null result");
   return Base::get(result.getContext(), result, arguments, isVarArg);
 }
 
-LLVMFunctionType LLVMFunctionType::getChecked(Location loc, LLVMType result,
-                                              ArrayRef<LLVMType> arguments,
+LLVMFunctionType LLVMFunctionType::getChecked(Location loc, Type result,
+                                              ArrayRef<Type> arguments,
                                               bool isVarArg) {
   assert(result && "expected non-null result");
   return Base::getChecked(loc, result, arguments, isVarArg);
 }
 
-LLVMType LLVMFunctionType::getReturnType() {
-  return getImpl()->getReturnType();
-}
+Type LLVMFunctionType::getReturnType() { return getImpl()->getReturnType(); }
 
 unsigned LLVMFunctionType::getNumParams() {
   return getImpl()->getArgumentTypes().size();
 }
 
-LLVMType LLVMFunctionType::getParamType(unsigned i) {
+Type LLVMFunctionType::getParamType(unsigned i) {
   return getImpl()->getArgumentTypes()[i];
 }
 
 bool LLVMFunctionType::isVarArg() { return getImpl()->isVariadic(); }
 
-ArrayRef<LLVMType> LLVMFunctionType::getParams() {
+ArrayRef<Type> LLVMFunctionType::getParams() {
   return getImpl()->getArgumentTypes();
 }
 
-LogicalResult LLVMFunctionType::verifyConstructionInvariants(
-    Location loc, LLVMType result, ArrayRef<LLVMType> arguments, bool) {
+LogicalResult
+LLVMFunctionType::verifyConstructionInvariants(Location loc, Type result,
+                                               ArrayRef<Type> arguments, bool) {
   if (!isValidResultType(result))
     return emitError(loc, "invalid function result type: ") << result;
 
-  for (LLVMType arg : arguments)
+  for (Type arg : arguments)
     if (!isValidArgumentType(arg))
       return emitError(loc, "invalid function argument type: ") << arg;
 
@@ -150,27 +136,27 @@ LogicalResult LLVMIntegerType::verifyConstructionInvariants(Location loc,
 // Pointer type.
 //===----------------------------------------------------------------------===//
 
-bool LLVMPointerType::isValidElementType(LLVMType type) {
+bool LLVMPointerType::isValidElementType(Type type) {
   return !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType,
                    LLVMLabelType>();
 }
 
-LLVMPointerType LLVMPointerType::get(LLVMType pointee, unsigned addressSpace) {
+LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) {
   assert(pointee && "expected non-null subtype");
   return Base::get(pointee.getContext(), pointee, addressSpace);
 }
 
-LLVMPointerType LLVMPointerType::getChecked(Location loc, LLVMType pointee,
+LLVMPointerType LLVMPointerType::getChecked(Location loc, Type pointee,
                                             unsigned addressSpace) {
   return Base::getChecked(loc, pointee, addressSpace);
 }
 
-LLVMType LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
+Type LLVMPointerType::getElementType() { return getImpl()->pointeeType; }
 
 unsigned LLVMPointerType::getAddressSpace() { return getImpl()->addressSpace; }
 
 LogicalResult LLVMPointerType::verifyConstructionInvariants(Location loc,
-                                                            LLVMType pointee,
+                                                            Type pointee,
                                                             unsigned) {
   if (!isValidElementType(pointee))
     return emitError(loc, "invalid pointer element type: ") << pointee;
@@ -181,7 +167,7 @@ LogicalResult LLVMPointerType::verifyConstructionInvariants(Location loc,
 // Struct type.
 //===----------------------------------------------------------------------===//
 
-bool LLVMStructType::isValidElementType(LLVMType type) {
+bool LLVMStructType::isValidElementType(Type type) {
   return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
                    LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>();
 }
@@ -198,7 +184,7 @@ LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc,
 
 LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
                                                 StringRef name,
-                                                ArrayRef<LLVMType> elements,
+                                                ArrayRef<Type> elements,
                                                 bool isPacked) {
   std::string stringName = name.str();
   unsigned counter = 0;
@@ -214,13 +200,12 @@ LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context,
 }
 
 LLVMStructType LLVMStructType::getLiteral(MLIRContext *context,
-                                          ArrayRef<LLVMType> types,
-                                          bool isPacked) {
+                                          ArrayRef<Type> types, bool isPacked) {
   return Base::get(context, types, isPacked);
 }
 
 LLVMStructType LLVMStructType::getLiteralChecked(Location loc,
-                                                 ArrayRef<LLVMType> types,
+                                                 ArrayRef<Type> types,
                                                  bool isPacked) {
   return Base::getChecked(loc, types, isPacked);
 }
@@ -233,7 +218,7 @@ LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) {
   return Base::getChecked(loc, name, /*opaque=*/true);
 }
 
-LogicalResult LLVMStructType::setBody(ArrayRef<LLVMType> types, bool isPacked) {
+LogicalResult LLVMStructType::setBody(ArrayRef<Type> types, bool isPacked) {
   assert(isIdentified() && "can only set bodies of identified structs");
   assert(llvm::all_of(types, LLVMStructType::isValidElementType) &&
          "expected valid body types");
@@ -248,7 +233,7 @@ bool LLVMStructType::isOpaque() {
 }
 bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); }
 StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); }
-ArrayRef<LLVMType> LLVMStructType::getBody() {
+ArrayRef<Type> LLVMStructType::getBody() {
   return isIdentified() ? getImpl()->getIdentifiedStructBody()
                         : getImpl()->getTypeList();
 }
@@ -258,10 +243,10 @@ LogicalResult LLVMStructType::verifyConstructionInvariants(Location, StringRef,
   return success();
 }
 
-LogicalResult
-LLVMStructType::verifyConstructionInvariants(Location loc,
-                                             ArrayRef<LLVMType> types, bool) {
-  for (LLVMType t : types)
+LogicalResult LLVMStructType::verifyConstructionInvariants(Location loc,
+                                                           ArrayRef<Type> types,
+                                                           bool) {
+  for (Type t : types)
     if (!isValidElementType(t))
       return emitError(loc, "invalid LLVM structure element type: ") << t;
 
@@ -272,7 +257,7 @@ LLVMStructType::verifyConstructionInvariants(Location loc,
 // Vector types.
 //===----------------------------------------------------------------------===//
 
-bool LLVMVectorType::isValidElementType(LLVMType type) {
+bool LLVMVectorType::isValidElementType(Type type) {
   return type.isa<LLVMIntegerType, LLVMPointerType>() ||
          mlir::LLVM::isCompatibleFloatingPointType(type);
 }
@@ -282,7 +267,7 @@ bool LLVMVectorType::classof(Type type) {
   return type.isa<LLVMFixedVectorType, LLVMScalableVectorType>();
 }
 
-LLVMType LLVMVectorType::getElementType() {
+Type LLVMVectorType::getElementType() {
   // Both derived classes share the implementation type.
   return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType;
 }
@@ -296,7 +281,7 @@ llvm::ElementCount LLVMVectorType::getElementCount() {
 
 /// Verifies that the type about to be constructed is well-formed.
 LogicalResult
-LLVMVectorType::verifyConstructionInvariants(Location loc, LLVMType elementType,
+LLVMVectorType::verifyConstructionInvariants(Location loc, Type elementType,
                                              unsigned numElements) {
   if (numElements == 0)
     return emitError(loc, "the number of vector elements must be positive");
@@ -307,14 +292,14 @@ LLVMVectorType::verifyConstructionInvariants(Location loc, LLVMType elementType,
   return success();
 }
 
-LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType,
+LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType,
                                              unsigned numElements) {
   assert(elementType && "expected non-null subtype");
   return Base::get(elementType.getContext(), elementType, numElements);
 }
 
 LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc,
-                                                    LLVMType elementType,
+                                                    Type elementType,
                                                     unsigned numElements) {
   assert(elementType && "expected non-null subtype");
   return Base::getChecked(loc, elementType, numElements);
@@ -324,14 +309,14 @@ unsigned LLVMFixedVectorType::getNumElements() {
   return getImpl()->numElements;
 }
 
-LLVMScalableVectorType LLVMScalableVectorType::get(LLVMType elementType,
+LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType,
                                                    unsigned minNumElements) {
   assert(elementType && "expected non-null subtype");
   return Base::get(elementType.getContext(), elementType, minNumElements);
 }
 
 LLVMScalableVectorType
-LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType,
+LLVMScalableVectorType::getChecked(Location loc, Type elementType,
                                    unsigned minNumElements) {
   assert(elementType && "expected non-null subtype");
   return Base::getChecked(loc, elementType, minNumElements);
@@ -351,16 +336,16 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
 
   return llvm::TypeSwitch<Type, llvm::TypeSize>(type)
       .Case<LLVMHalfType, LLVMBFloatType>(
-          [](LLVMType) { return llvm::TypeSize::Fixed(16); })
-      .Case<LLVMFloatType>([](LLVMType) { return llvm::TypeSize::Fixed(32); })
+          [](Type) { return llvm::TypeSize::Fixed(16); })
+      .Case<LLVMFloatType>([](Type) { return llvm::TypeSize::Fixed(32); })
       .Case<LLVMDoubleType, LLVMX86MMXType>(
-          [](LLVMType) { return llvm::TypeSize::Fixed(64); })
+          [](Type) { return llvm::TypeSize::Fixed(64); })
       .Case<LLVMIntegerType>([](LLVMIntegerType intTy) {
         return llvm::TypeSize::Fixed(intTy.getBitWidth());
       })
-      .Case<LLVMX86FP80Type>([](LLVMType) { return llvm::TypeSize::Fixed(80); })
+      .Case<LLVMX86FP80Type>([](Type) { return llvm::TypeSize::Fixed(80); })
       .Case<LLVMPPCFP128Type, LLVMFP128Type>(
-          [](LLVMType) { return llvm::TypeSize::Fixed(128); })
+          [](Type) { return llvm::TypeSize::Fixed(128); })
       .Case<LLVMVectorType>([](LLVMVectorType t) {
         llvm::TypeSize elementSize =
             getPrimitiveTypeSizeInBits(t.getElementType());

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index c2f689be493a..f8d6518b23aa 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -53,19 +53,18 @@ static ParseResult parseNVVMShflSyncBflyOp(OpAsmParser &parser,
       parser.addTypeToList(resultType, result.types))
     return failure();
 
-  auto type = resultType.cast<LLVM::LLVMType>();
   for (auto &attr : result.attributes) {
     if (attr.first != "return_value_and_is_valid")
       continue;
-    auto structType = type.dyn_cast<LLVM::LLVMStructType>();
+    auto structType = resultType.dyn_cast<LLVM::LLVMStructType>();
     if (structType && !structType.getBody().empty())
-      type = structType.getBody()[0];
+      resultType = structType.getBody()[0];
     break;
   }
 
   auto int32Ty =
       LLVM::LLVMIntegerType::get(parser.getBuilder().getContext(), 32);
-  return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty},
+  return parser.resolveOperands(ops, {int32Ty, resultType, int32Ty, int32Ty},
                                 parser.getNameLoc(), result.operands);
 }
 

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
index 1d147beafb40..a02a77d5a4c0 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
+++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h
@@ -72,7 +72,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
     Key(StringRef name, bool opaque)
         : name(name), identified(true), packed(false), opaque(opaque) {}
     /// Constructs a key for a literal struct.
-    Key(ArrayRef<LLVMType> types, bool packed)
+    Key(ArrayRef<Type> types, bool packed)
         : types(types), identified(false), packed(packed), opaque(false) {}
 
     /// Checks a specific property of the struct.
@@ -96,7 +96,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
     }
 
     /// Returns the list of type contained in the key of a literal struct.
-    ArrayRef<LLVMType> getTypeList() const {
+    ArrayRef<Type> getTypeList() const {
       assert(!isIdentified() &&
              "identified struct key cannot have a type list");
       return types;
@@ -138,7 +138,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
     }
 
   private:
-    ArrayRef<LLVMType> types;
+    ArrayRef<Type> types;
     StringRef name;
     bool identified;
     bool packed;
@@ -153,18 +153,18 @@ struct LLVMStructTypeStorage : public TypeStorage {
   }
 
   /// Returns the list of types (partially) identifying a literal struct.
-  ArrayRef<LLVMType> getTypeList() const {
+  ArrayRef<Type> getTypeList() const {
     // If this triggers, use getIdentifiedStructBody() instead.
     assert(!isIdentified() && "requested typelist on an identified struct");
-    return ArrayRef<LLVMType>(static_cast<const LLVMType *>(keyPtr), keySize());
+    return ArrayRef<Type>(static_cast<const Type *>(keyPtr), keySize());
   }
 
   /// Returns the list of types contained in an identified struct.
-  ArrayRef<LLVMType> getIdentifiedStructBody() const {
+  ArrayRef<Type> getIdentifiedStructBody() const {
     // If this triggers, use getTypeList() instead.
     assert(isIdentified() &&
            "requested struct body on a non-identified struct");
-    return ArrayRef<LLVMType>(identifiedBodyArray, identifiedBodySize());
+    return ArrayRef<Type>(identifiedBodyArray, identifiedBodySize());
   }
 
   /// Checks whether the struct is identified.
@@ -199,7 +199,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
   /// as initialized and can no longer be mutated.
   LLVMStructTypeStorage(const KeyTy &key) {
     if (!key.isIdentified()) {
-      ArrayRef<LLVMType> types = key.getTypeList();
+      ArrayRef<Type> types = key.getTypeList();
       keyPtr = static_cast<const void *>(types.data());
       setKeySize(types.size());
       llvm::Bitfield::set<KeyFlagPacked>(keySizeAndFlags, key.isPacked());
@@ -232,7 +232,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
   /// initialized, succeeds only if the body is equal to the current body. Fails
   /// if the struct is marked as intentionally opaque. The struct will be marked
   /// as initialized as a result of this operation and can no longer be changed.
-  LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef<LLVMType> body,
+  LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef<Type> body,
                        bool packed) {
     if (!isIdentified())
       return failure();
@@ -244,7 +244,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
                                                 true);
     llvm::Bitfield::set<MutableFlagPacked>(identifiedBodySizeAndFlags, packed);
 
-    ArrayRef<LLVMType> typesInAllocator = allocator.copyInto(body);
+    ArrayRef<Type> typesInAllocator = allocator.copyInto(body);
     identifiedBodyArray = typesInAllocator.data();
     setIdentifiedBodySize(typesInAllocator.size());
 
@@ -310,7 +310,7 @@ struct LLVMStructTypeStorage : public TypeStorage {
   const void *keyPtr = nullptr;
 
   /// Pointer to the first type contained in an identified struct.
-  const LLVMType *identifiedBodyArray = nullptr;
+  const Type *identifiedBodyArray = nullptr;
 
   /// Size of the uniquing key combined with identified/literal and
   /// packedness bits. Must only be used through the Key* bitfields.
@@ -328,12 +328,11 @@ struct LLVMStructTypeStorage : public TypeStorage {
 /// Type storage for LLVM dialect function types. These are uniqued using the
 /// list of types they contain and the vararg bit.
 struct LLVMFunctionTypeStorage : public TypeStorage {
-  using KeyTy = std::tuple<LLVMType, ArrayRef<LLVMType>, bool>;
+  using KeyTy = std::tuple<Type, ArrayRef<Type>, bool>;
 
   /// Construct a storage from the given components. The list is expected to be
   /// allocated in the context.
-  LLVMFunctionTypeStorage(LLVMType result, ArrayRef<LLVMType> arguments,
-                          bool variadic)
+  LLVMFunctionTypeStorage(Type result, ArrayRef<Type> arguments, bool variadic)
       : argumentTypes(arguments) {
     returnTypeAndVariadic.setPointerAndInt(result, variadic);
   }
@@ -359,19 +358,19 @@ struct LLVMFunctionTypeStorage : public TypeStorage {
   }
 
   /// Returns the list of function argument types.
-  ArrayRef<LLVMType> getArgumentTypes() const { return argumentTypes; }
+  ArrayRef<Type> getArgumentTypes() const { return argumentTypes; }
 
   /// Checks whether the function type is variadic.
   bool isVariadic() const { return returnTypeAndVariadic.getInt(); }
 
   /// Returns the function result type.
-  LLVMType getReturnType() const { return returnTypeAndVariadic.getPointer(); }
+  Type getReturnType() const { return returnTypeAndVariadic.getPointer(); }
 
 private:
   /// Function result type packed with the variadic bit.
-  llvm::PointerIntPair<LLVMType, 1, bool> returnTypeAndVariadic;
+  llvm::PointerIntPair<Type, 1, bool> returnTypeAndVariadic;
   /// Argument types.
-  ArrayRef<LLVMType> argumentTypes;
+  ArrayRef<Type> argumentTypes;
 };
 
 //===----------------------------------------------------------------------===//
@@ -402,7 +401,7 @@ struct LLVMIntegerTypeStorage : public TypeStorage {
 /// Storage type for LLVM dialect pointer types. These are uniqued by a pair of
 /// element type and address space.
 struct LLVMPointerTypeStorage : public TypeStorage {
-  using KeyTy = std::tuple<LLVMType, unsigned>;
+  using KeyTy = std::tuple<Type, unsigned>;
 
   LLVMPointerTypeStorage(const KeyTy &key)
       : pointeeType(std::get<0>(key)), addressSpace(std::get<1>(key)) {}
@@ -417,7 +416,7 @@ struct LLVMPointerTypeStorage : public TypeStorage {
     return std::make_tuple(pointeeType, addressSpace) == key;
   }
 
-  LLVMType pointeeType;
+  Type pointeeType;
   unsigned addressSpace;
 };
 
@@ -429,7 +428,7 @@ struct LLVMPointerTypeStorage : public TypeStorage {
 /// number: arrays, fixed and scalable vectors. The actual semantics of the
 /// type is defined by its kind.
 struct LLVMTypeAndSizeStorage : public TypeStorage {
-  using KeyTy = std::tuple<LLVMType, unsigned>;
+  using KeyTy = std::tuple<Type, unsigned>;
 
   LLVMTypeAndSizeStorage(const KeyTy &key)
       : elementType(std::get<0>(key)), numElements(std::get<1>(key)) {}
@@ -444,7 +443,7 @@ struct LLVMTypeAndSizeStorage : public TypeStorage {
     return std::make_tuple(elementType, numElements) == key;
   }
 
-  LLVMType elementType;
+  Type elementType;
   unsigned numElements;
 };
 

diff  --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index 9786751ef4b0..89e7decc8152 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -68,8 +68,8 @@ class Importer {
   LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block);
   /// Imports `inst` and populates instMap[inst] with the imported Value.
   LogicalResult processInstruction(llvm::Instruction *inst);
-  /// Creates an LLVMType for `type`.
-  LLVMType processType(llvm::Type *type);
+  /// Creates an LLVM-compatible MLIR type for `type`.
+  Type processType(llvm::Type *type);
   /// `value` is an SSA-use. Return the remapped version of `value` or a
   /// placeholder that will be remapped later if this is an instruction that
   /// has not yet been visited.
@@ -87,7 +87,7 @@ class Importer {
                                   SmallVectorImpl<Value> &blockArguments);
   /// Returns the builtin type equivalent to be used in attributes for the given
   /// LLVM IR dialect type.
-  Type getStdTypeForAttr(LLVMType type);
+  Type getStdTypeForAttr(Type type);
   /// Return `value` as an attribute to attach to a GlobalOp.
   Attribute getConstantAsAttr(llvm::Constant *value);
   /// Return `c` as an MLIR Value. This could either be a ConstantOp, or
@@ -150,8 +150,8 @@ Location Importer::processDebugLoc(const llvm::DebugLoc &loc,
                              context);
 }
 
-LLVMType Importer::processType(llvm::Type *type) {
-  if (LLVMType result = typeTranslator.translateType(type))
+Type Importer::processType(llvm::Type *type) {
+  if (Type result = typeTranslator.translateType(type))
     return result;
 
   // FIXME: Diagnostic should be able to natively handle types that have
@@ -168,7 +168,7 @@ LLVMType Importer::processType(llvm::Type *type) {
 // equivalents. Array types are converted to ranked tensors; nested array types
 // are converted to multi-dimensional tensors or vectors, depending on the
 // innermost type being a scalar or a vector.
-Type Importer::getStdTypeForAttr(LLVMType type) {
+Type Importer::getStdTypeForAttr(Type type) {
   if (!type)
     return nullptr;
 
@@ -252,7 +252,7 @@ Attribute Importer::getConstantAsAttr(llvm::Constant *value) {
 
   // Convert constant data to a dense elements attribute.
   if (auto *cd = dyn_cast<llvm::ConstantDataSequential>(value)) {
-    LLVMType type = processType(cd->getElementType());
+    Type type = processType(cd->getElementType());
     if (!type)
       return nullptr;
 
@@ -315,7 +315,7 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
   Attribute valueAttr;
   if (GV->hasInitializer())
     valueAttr = getConstantAsAttr(GV->getInitializer());
-  LLVMType type = processType(GV->getValueType());
+  Type type = processType(GV->getValueType());
   if (!type)
     return nullptr;
   GlobalOp op = b.create<GlobalOp>(
@@ -338,7 +338,7 @@ Value Importer::processConstant(llvm::Constant *c) {
   if (Attribute attr = getConstantAsAttr(c)) {
     // These constants can be represented as attributes.
     OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
-    LLVMType type = processType(c->getType());
+    Type type = processType(c->getType());
     if (!type)
       return nullptr;
     if (auto symbolRef = attr.dyn_cast<FlatSymbolRefAttr>())
@@ -347,7 +347,7 @@ Value Importer::processConstant(llvm::Constant *c) {
     return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
   }
   if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
-    LLVMType type = processType(cn->getType());
+    Type type = processType(cn->getType());
     if (!type)
       return nullptr;
     return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
@@ -370,7 +370,7 @@ Value Importer::processConstant(llvm::Constant *c) {
     return instMap[c] = instMap[i];
   }
   if (auto *ue = dyn_cast<llvm::UndefValue>(c)) {
-    LLVMType type = processType(ue->getType());
+    Type type = processType(ue->getType());
     if (!type)
       return nullptr;
     return instMap[c] = bEntry.create<UndefOp>(UnknownLoc::get(context), type);
@@ -388,7 +388,7 @@ Value Importer::processValue(llvm::Value *value) {
   // this instruction yet, create an unknown op and remap it later.
   if (isa<llvm::Instruction>(value)) {
     OperationState state(UnknownLoc::get(context), "llvm.unknown");
-    LLVMType type = processType(value->getType());
+    Type type = processType(value->getType());
     if (!type)
       return nullptr;
     state.addTypes(type);
@@ -578,7 +578,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
     }
     state.addOperands(ops);
     if (!inst->getType()->isVoidTy()) {
-      LLVMType type = processType(inst->getType());
+      Type type = processType(inst->getType());
       if (!type)
         return failure();
       state.addTypes(type);
@@ -629,7 +629,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
     return success();
   }
   case llvm::Instruction::PHI: {
-    LLVMType type = processType(inst->getType());
+    Type type = processType(inst->getType());
     if (!type)
       return failure();
     v = b.getInsertionBlock()->addArgument(type);
@@ -648,7 +648,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
 
     SmallVector<Type, 2> tys;
     if (!ci->getType()->isVoidTy()) {
-      LLVMType type = processType(inst->getType());
+      Type type = processType(inst->getType());
       if (!type)
         return failure();
       tys.push_back(type);

diff  --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index c28588d32ad6..da9c734fbd80 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -762,17 +762,17 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
   if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
     // TODO: refactor function type creation which usually occurs in std-LLVM
     // conversion.
-    SmallVector<LLVM::LLVMType, 8> operandTypes;
+    SmallVector<Type, 8> operandTypes;
     operandTypes.reserve(inlineAsmOp.operands().size());
     for (auto t : inlineAsmOp.operands().getTypes())
-      operandTypes.push_back(t.cast<LLVM::LLVMType>());
+      operandTypes.push_back(t);
 
-    LLVM::LLVMType resultType;
+    Type resultType;
     if (inlineAsmOp.getNumResults() == 0) {
       resultType = LLVM::LLVMVoidType::get(mlirModule->getContext());
     } else {
       assert(inlineAsmOp.getNumResults() == 1);
-      resultType = inlineAsmOp.getResultTypes()[0].cast<LLVM::LLVMType>();
+      resultType = inlineAsmOp.getResultTypes()[0];
     }
     auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes);
     llvm::InlineAsm *inlineAsmInst =
@@ -813,7 +813,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
   }
 
   if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
-    llvm::Type *ty = convertType(lpOp.getType().cast<LLVMType>());
+    llvm::Type *ty = convertType(lpOp.getType());
     llvm::LandingPadInst *lpi =
         builder.CreateLandingPad(ty, lpOp.getNumOperands());
 
@@ -872,8 +872,8 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
                              blockMapping[switchOp.defaultDestination()],
                              switchOp.caseDestinations().size(), branchWeights);
 
-    auto *ty = llvm::cast<llvm::IntegerType>(
-        convertType(switchOp.value().getType().cast<LLVMType>()));
+    auto *ty =
+        llvm::cast<llvm::IntegerType>(convertType(switchOp.value().getType()));
     for (auto i :
          llvm::zip(switchOp.case_values()->cast<DenseIntElementsAttr>(),
                    switchOp.caseDestinations()))
@@ -927,8 +927,8 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
     unsigned numPredecessors =
         std::distance(predecessors.begin(), predecessors.end());
     for (auto arg : bb.getArguments()) {
-      auto wrappedType = arg.getType().dyn_cast<LLVM::LLVMType>();
-      if (!wrappedType)
+      auto wrappedType = arg.getType();
+      if (!isCompatibleType(wrappedType))
         return emitError(bb.front().getLoc(),
                          "block argument does not have an LLVM type");
       llvm::Type *type = convertType(wrappedType);
@@ -1094,7 +1094,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
             argIdx, LLVMDialect::getNoAliasAttrName())) {
       // NB: Attribute already verified to be boolean, so check if we can indeed
       // attach the attribute to this argument, based on its type.
-      auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
+      auto argTy = mlirArg.getType();
       if (!argTy.isa<LLVM::LLVMPointerType>())
         return func.emitError(
             "llvm.noalias attribute attached to LLVM non-pointer argument");
@@ -1106,7 +1106,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
             argIdx, LLVMDialect::getAlignAttrName())) {
       // NB: Attribute already verified to be int, so check if we can indeed
       // attach the attribute to this argument, based on its type.
-      auto argTy = mlirArg.getType().dyn_cast<LLVM::LLVMType>();
+      auto argTy = mlirArg.getType();
       if (!argTy.isa<LLVM::LLVMPointerType>())
         return func.emitError(
             "llvm.align attribute attached to LLVM non-pointer argument");
@@ -1190,7 +1190,7 @@ LogicalResult ModuleTranslation::convertFunctions() {
   return success();
 }
 
-llvm::Type *ModuleTranslation::convertType(LLVMType type) {
+llvm::Type *ModuleTranslation::convertType(Type type) {
   return typeTranslator.translateType(type);
 }
 

diff  --git a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp
index 2a4325f0df97..ecde56cc78f6 100644
--- a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp
@@ -27,14 +27,14 @@ class TypeToLLVMIRTranslatorImpl {
   TypeToLLVMIRTranslatorImpl(llvm::LLVMContext &context) : context(context) {}
 
   /// Translates a single type.
-  llvm::Type *translateType(LLVM::LLVMType type) {
+  llvm::Type *translateType(Type type) {
     // If the conversion is already known, just return it.
     if (knownTranslations.count(type))
       return knownTranslations.lookup(type);
 
     // Dispatch to an appropriate function.
     llvm::Type *translated =
-        llvm::TypeSwitch<LLVM::LLVMType, llvm::Type *>(type)
+        llvm::TypeSwitch<Type, llvm::Type *>(type)
             .Case([this](LLVM::LLVMVoidType) {
               return llvm::Type::getVoidTy(context);
             })
@@ -76,7 +76,7 @@ class TypeToLLVMIRTranslatorImpl {
                   LLVM::LLVMStructType, LLVM::LLVMFixedVectorType,
                   LLVM::LLVMScalableVectorType>(
                 [this](auto type) { return this->translate(type); })
-            .Default([](LLVM::LLVMType t) -> llvm::Type * {
+            .Default([](Type t) -> llvm::Type * {
               llvm_unreachable("unknown LLVM dialect type");
             });
 
@@ -147,7 +147,7 @@ class TypeToLLVMIRTranslatorImpl {
   }
 
   /// Translates a list of types.
-  void translateTypes(ArrayRef<LLVM::LLVMType> types,
+  void translateTypes(ArrayRef<Type> types,
                       SmallVectorImpl<llvm::Type *> &result) {
     result.reserve(result.size() + types.size());
     for (auto type : types)
@@ -161,7 +161,7 @@ class TypeToLLVMIRTranslatorImpl {
   /// results to avoid repeated recursive calls and makes sure identified
   /// structs with the same name (that is, equal) are resolved to an existing
   /// type instead of creating a new type.
-  llvm::DenseMap<LLVM::LLVMType, llvm::Type *> knownTranslations;
+  llvm::DenseMap<Type, llvm::Type *> knownTranslations;
 };
 } // end namespace detail
 } // end namespace LLVM
@@ -172,12 +172,12 @@ LLVM::TypeToLLVMIRTranslator::TypeToLLVMIRTranslator(llvm::LLVMContext &context)
 
 LLVM::TypeToLLVMIRTranslator::~TypeToLLVMIRTranslator() {}
 
-llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(LLVM::LLVMType type) {
+llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(Type type) {
   return impl->translateType(type);
 }
 
 unsigned LLVM::TypeToLLVMIRTranslator::getPreferredAlignment(
-    LLVM::LLVMType type, const llvm::DataLayout &layout) {
+    Type type, const llvm::DataLayout &layout) {
   return layout.getPrefTypeAlignment(translateType(type));
 }
 
@@ -191,12 +191,12 @@ class TypeFromLLVMIRTranslatorImpl {
   TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {}
 
   /// Translates the given type.
-  LLVM::LLVMType translateType(llvm::Type *type) {
+  Type translateType(llvm::Type *type) {
     if (knownTranslations.count(type))
       return knownTranslations.lookup(type);
 
-    LLVM::LLVMType translated =
-        llvm::TypeSwitch<llvm::Type *, LLVM::LLVMType>(type)
+    Type translated =
+        llvm::TypeSwitch<llvm::Type *, Type>(type)
             .Case<llvm::ArrayType, llvm::FunctionType, llvm::IntegerType,
                   llvm::PointerType, llvm::StructType, llvm::FixedVectorType,
                   llvm::ScalableVectorType>(
@@ -211,7 +211,7 @@ class TypeFromLLVMIRTranslatorImpl {
 private:
   /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
   /// type.
-  LLVM::LLVMType translatePrimitiveType(llvm::Type *type) {
+  Type translatePrimitiveType(llvm::Type *type) {
     if (type->isVoidTy())
       return LLVM::LLVMVoidType::get(&context);
     if (type->isHalfTy())
@@ -238,33 +238,33 @@ class TypeFromLLVMIRTranslatorImpl {
   }
 
   /// Translates the given array type.
-  LLVM::LLVMType translate(llvm::ArrayType *type) {
+  Type translate(llvm::ArrayType *type) {
     return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
                                     type->getNumElements());
   }
 
   /// Translates the given function type.
-  LLVM::LLVMType translate(llvm::FunctionType *type) {
-    SmallVector<LLVM::LLVMType, 8> paramTypes;
+  Type translate(llvm::FunctionType *type) {
+    SmallVector<Type, 8> paramTypes;
     translateTypes(type->params(), paramTypes);
     return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
                                        paramTypes, type->isVarArg());
   }
 
   /// Translates the given integer type.
-  LLVM::LLVMType translate(llvm::IntegerType *type) {
+  Type translate(llvm::IntegerType *type) {
     return LLVM::LLVMIntegerType::get(&context, type->getBitWidth());
   }
 
   /// Translates the given pointer type.
-  LLVM::LLVMType translate(llvm::PointerType *type) {
+  Type translate(llvm::PointerType *type) {
     return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
                                       type->getAddressSpace());
   }
 
   /// Translates the given structure type.
-  LLVM::LLVMType translate(llvm::StructType *type) {
-    SmallVector<LLVM::LLVMType, 8> subtypes;
+  Type translate(llvm::StructType *type) {
+    SmallVector<Type, 8> subtypes;
     if (type->isLiteral()) {
       translateTypes(type->subtypes(), subtypes);
       return LLVM::LLVMStructType::getLiteral(&context, subtypes,
@@ -286,20 +286,20 @@ class TypeFromLLVMIRTranslatorImpl {
   }
 
   /// Translates the given fixed-vector type.
-  LLVM::LLVMType translate(llvm::FixedVectorType *type) {
+  Type translate(llvm::FixedVectorType *type) {
     return LLVM::LLVMFixedVectorType::get(translateType(type->getElementType()),
                                           type->getNumElements());
   }
 
   /// Translates the given scalable-vector type.
-  LLVM::LLVMType translate(llvm::ScalableVectorType *type) {
+  Type translate(llvm::ScalableVectorType *type) {
     return LLVM::LLVMScalableVectorType::get(
         translateType(type->getElementType()), type->getMinNumElements());
   }
 
   /// Translates a list of types.
   void translateTypes(ArrayRef<llvm::Type *> types,
-                      SmallVectorImpl<LLVM::LLVMType> &result) {
+                      SmallVectorImpl<Type> &result) {
     result.reserve(result.size() + types.size());
     for (llvm::Type *type : types)
       result.push_back(translateType(type));
@@ -307,7 +307,7 @@ class TypeFromLLVMIRTranslatorImpl {
 
   /// Map of known translations. Serves as a cache and as recursion stopper for
   /// translating recursive structs.
-  llvm::DenseMap<llvm::Type *, LLVM::LLVMType> knownTranslations;
+  llvm::DenseMap<llvm::Type *, Type> knownTranslations;
 
   /// The context in which MLIR types are created.
   MLIRContext &context;
@@ -321,6 +321,6 @@ LLVM::TypeFromLLVMIRTranslator::TypeFromLLVMIRTranslator(MLIRContext &context)
 
 LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {}
 
-LLVM::LLVMType LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
+Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) {
   return impl->translateType(type);
 }

diff  --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index d02c252c0bf3..5e2f666c5b83 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -163,7 +163,7 @@ func @call_non_llvm() {
 // -----
 
 func @call_non_llvm_indirect(%arg0 : i32) {
-  // expected-error@+1 {{'llvm.call' op operand #0 must be LLVM dialect type, but got 'i32'}}
+  // expected-error@+1 {{'llvm.call' op operand #0 must be LLVM dialect-compatible type, but got 'i32'}}
   "llvm.call"(%arg0) : (i32) -> ()
 }
 

diff  --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
index f748b56b1bdf..428b61589134 100644
--- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
+++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp
@@ -135,7 +135,7 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) {
     } else if (isResultName(op, name)) {
       bs << formatv("valueMapping[op.{0}()]", name);
     } else if (name == "_resultType") {
-      bs << "convertType(op.getResult().getType().cast<LLVM::LLVMType>())";
+      bs << "convertType(op.getResult().getType())";
     } else if (name == "_hasResult") {
       bs << "opInst.getNumResults() == 1";
     } else if (name == "_location") {


        


More information about the Mlir-commits mailing list