[Mlir-commits] [mlir] [mlir][IR] Add `ScalarTypeInterface` and use as `VectorType` element type (PR #132400)

Matthias Springer llvmlistbot at llvm.org
Fri Mar 21 06:51:21 PDT 2025


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/132400

This commit adds a new builtin type interface: `ScalarTypeInterface`.

Instead of maintaining a hard-coded list of valid element types for `VectorType`, restrict valid element types to types that implement `ScalarTypeInterface`.

>From 85c0b6be5c046b342987ff3523836bd87806e971 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 21 Mar 2025 14:49:28 +0100
Subject: [PATCH] [mlir][IR] Add `ScalarTypeInterface`

---
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |  2 +-
 .../include/mlir/IR/BuiltinDialectBytecode.td |  4 +-
 mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 35 +++++++++-
 mlir/include/mlir/IR/BuiltinTypes.h           |  6 +-
 mlir/include/mlir/IR/BuiltinTypes.td          | 39 ++++++++---
 mlir/include/mlir/IR/CommonTypeConstraints.td |  4 ++
 mlir/lib/AsmParser/TypeParser.cpp             |  9 ++-
 mlir/lib/CAPI/IR/BuiltinTypes.cpp             | 13 ++--
 .../AMDGPUToROCDL/AMDGPUToROCDL.cpp           |  9 +--
 .../ArithToAMDGPU/ArithToAMDGPU.cpp           |  6 +-
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  |  6 +-
 .../ArmNeon2dToIntr/ArmNeon2dToIntr.cpp       |  3 +-
 .../Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp    |  3 +-
 .../Conversion/LLVMCommon/TypeConverter.cpp   |  5 +-
 .../Conversion/MathToSPIRV/MathToSPIRV.cpp    |  5 +-
 .../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp    |  3 +-
 .../Conversion/VectorToGPU/VectorToGPU.cpp    |  5 +-
 .../Conversion/VectorToSCF/VectorToSCF.cpp    |  5 +-
 .../AMDGPU/Transforms/EmulateAtomics.cpp      |  2 +-
 .../Affine/Transforms/SuperVectorize.cpp      | 10 +--
 .../Arith/Transforms/EmulateWideInt.cpp       |  3 +-
 .../LowerContractionToSMMLAPattern.cpp        |  8 +--
 mlir/lib/Dialect/ArmSME/IR/Utils.cpp          |  2 +-
 .../Dialect/ControlFlow/IR/ControlFlowOps.cpp | 15 ++--
 mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp    | 14 ++--
 mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp      |  8 ++-
 .../Linalg/Transforms/Vectorization.cpp       | 70 ++++++++++++-------
 .../Transforms/PolynomialApproximation.cpp    |  9 ++-
 .../NVGPU/TransformOps/NVGPUTransformOps.cpp  |  3 +-
 mlir/lib/Dialect/Quant/IR/QuantTypes.cpp      | 15 ++--
 .../Dialect/Quant/Utils/UniformSupport.cpp    |  3 +-
 .../Dialect/SPIRV/IR/SPIRVOpDefinition.cpp    |  3 +-
 .../SPIRV/Transforms/SPIRVConversion.cpp      | 13 ++--
 .../Transforms/UnifyAliasedResourcePass.cpp   |  6 +-
 mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp  |  2 +-
 .../Transforms/SparseVectorization.cpp        |  3 +-
 .../lib/Dialect/SparseTensor/Utils/Merger.cpp |  3 +-
 mlir/lib/Dialect/Traits.cpp                   |  2 +-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 12 ++--
 .../Transforms/LowerVectorBroadcast.cpp       |  6 +-
 .../Vector/Transforms/LowerVectorContract.cpp |  3 +-
 .../Vector/Transforms/LowerVectorGather.cpp   |  6 +-
 .../Vector/Transforms/VectorDistribute.cpp    |  7 +-
 .../Transforms/VectorEmulateNarrowType.cpp    | 42 +++++++----
 .../Vector/Transforms/VectorTransforms.cpp    |  5 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |  3 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  7 +-
 mlir/lib/IR/BuiltinTypes.cpp                  |  8 ++-
 mlir/lib/Target/LLVMIR/ModuleImport.cpp       |  6 +-
 .../SPIRV/Deserialization/Deserializer.cpp    |  3 +-
 mlir/test/IR/invalid-builtin-types.mlir       |  2 +-
 .../MathToVCIX/TestMathToVCIXConversion.cpp   |  3 +-
 .../Dialect/ArmSME/TileTypeConversionTest.cpp |  2 +-
 mlir/unittests/IR/ShapedTypeTest.cpp          | 10 +--
 54 files changed, 317 insertions(+), 164 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
 bool isMultipleOfSMETileVectorType(VectorType vType);
 
 /// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
 
 /// Erase trivially dead tile ops from a function.
 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
   Type:$elementType
 )> {
   let printerPredicate = "!$_val.isScalable()";
+  // Note: Element type must implement ScalarTypeInterface.
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
 )> {
   let printerPredicate = "$_val.isScalable()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
 }
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
 
 include "mlir/IR/OpBase.td"
 
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Indication that this type is a scalar type.
+
+    The bitwidth of a scalar type is a fixed constant but may be unknown in the
+    absence of data layout information.
+
+    Scalar types are POD (plain-old-data) entities that have an in-memory
+    representation: scalar values can be loaded from and stored to memory, so
+    abstract types like function types or async tokens cannot be scalar types.
+
+    Scalar types should be limited to types that can lower to something that
+    egress dialects would consider a valid vector element type.
+  }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Return the bitwidth of this type if it has an inherent bitwidth, i.e., a
+      bitwidth that is known in the absence of data layout information.
+    }],
+    "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
   let cppNamespace = "::mlir";
   let description = [{
     This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
         scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
-  Builder(ArrayRef<int64_t> shape, Type elementType,
+  Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
           ArrayRef<bool> scalableDims = {})
       : elementType(elementType), shape(shape), scalableDims(scalableDims) {}
 
@@ -286,7 +286,7 @@ class VectorType::Builder {
     return *this;
   }
 
-  Builder &setElementType(Type newElementType) {
+  Builder &setElementType(ScalarTypeInterface newElementType) {
     elementType = newElementType;
     return *this;
   }
@@ -312,7 +312,7 @@ class VectorType::Builder {
   }
 
 private:
-  Type elementType;
+  ScalarTypeInterface elementType;
   CopyOnWriteArrayRef<int64_t> shape;
   CopyOnWriteArrayRef<bool> scalableDims;
 };
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
         DeclareTypeInterfaceMethods<
             FloatTypeInterface,
             ["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+  let extraClassDeclaration = [{
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
+  }];
 }
 
 // Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
     : Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
+
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() {
+      return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+    }
   }];
 }
 
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
 // IndexType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
   let summary = "Integer-like type with unknown platform-dependent bit width";
   let description = [{
     Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
   let extraClassDeclaration = [{
     static IndexType get(MLIRContext *context);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return std::nullopt;
+    }
+
     /// Storage bit width used for IndexType by internal compiler data
     /// structures.
     static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
 // IntegerType
 //===----------------------------------------------------------------------===//
 
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+    : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
   let summary = "Integer type with arbitrary precision up to a fixed limit";
   let description = [{
     Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
     /// Return null if the scaled element type cannot be represented.
     IntegerType scaleElementBitwidth(unsigned scale);
 
+    /// Return the bitwidth of this type. This is an interface method of
+    /// ScalarTypeInterface.
+    std::optional<uint64_t> getInherentBitwidth() const {
+      return static_cast<uint64_t>(getWidth());
+    }
+
     /// Integer representation maximal bitwidth.
     /// Note: This is aligned with the maximum width of llvm::IntegerType.
     static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
 // VectorType
 //===----------------------------------------------------------------------===//
 
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
-  let cppFunctionName = "isValidVectorTypeElementType";
-}
-
 def Builtin_Vector : Builtin_Type<"Vector", "vector",
     [ShapedTypeInterface, ValueSemantics], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    Builtin_VectorTypeElementType:$elementType,
+    AnyScalarType:$elementType,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType,
+      "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
       CArg<"ArrayRef<bool>", "{}">:$scalableDims
     ), [{
       // While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
     list<Pred> predicateList = predicates;
 }
 
+def AnyScalarType : Type<
+    CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+        "scalar type", "::mlir::ScalarTypeInterface">;
+
 // Integer types.
 
 // Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+  auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+  if (!scalarElementType) {
+    emitWrongTokenError("vector type requires scalar element type");
+    return nullptr;
+  }
+
+  return getChecked<VectorType>(loc, dimensions, scalarElementType,
+                                scalableDims);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
 MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                            MlirType elementType) {
   return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-                              unwrap(elementType)));
+                              cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                   const int64_t *shape, MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType)));
+      cast<ScalarTypeInterface>(unwrap(elementType))));
 }
 
 MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                    const bool *scalable, MlirType elementType) {
-  return wrap(VectorType::get(
-      llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
-      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+  return wrap(
+      VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+                      cast<ScalarTypeInterface>(unwrap(elementType)),
+                      llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
 MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                           MlirType elementType) {
   return wrap(VectorType::getChecked(
       unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType),
+      cast<ScalarTypeInterface>(unwrap(elementType)),
       llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
 }
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
-    Type i32 = rewriter.getI32Type();
+    auto i32 = rewriter.getI32Type();
 
     // Get the type size in bytes.
     DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
 
   int64_t numBits =
       vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
-  Type i32 = rewriter.getI32Type();
+  auto i32 = rewriter.getI32Type();
   Type intrinsicInType = numBits <= 32
                              ? (Type)rewriter.getIntegerType(numBits)
                              : (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
           operand =
               rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
         }
-        auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
-            32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+        auto llvmVecType = typeConverter->convertType(
+            mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+                                  cast<ScalarTypeInterface>(llvmSrcIntType)));
         Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
         operand = rewriter.create<LLVM::InsertElementOp>(
             loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
-  VectorType truncResType = VectorType::get(4, outElemType);
+  VectorType truncResType =
+      VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
   if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  VectorType truncResType = VectorType::get(2, outElemType);
+  VectorType truncResType =
+      VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
   auto inVectorTy = dyn_cast<VectorType>(in.getType());
 
   // Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
         dstAttrType =
             RankedTensorType::get(dstAttrType.getShape(), dstElemType);
       else
-        dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+        dstAttrType = VectorType::get(dstAttrType.getShape(),
+                                      cast<ScalarTypeInterface>(dstElemType));
 
       dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
     }
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
       // cases. Extend them to 32-bit and do comparision then.
       Type type = rewriter.getI32Type();
       if (auto vectorType = dyn_cast<VectorType>(dstType))
-        type = VectorType::get(vectorType.getShape(), type);
+        type = VectorType::get(vectorType.getShape(),
+                               cast<ScalarTypeInterface>(type));
       Value extLhs =
           rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
       Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
   /// arm.neon.intr.sdot
   LogicalResult matchAndRewrite(Sdot2dOp op,
                                 PatternRewriter &rewriter) const override {
-    Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+    ScalarTypeInterface elemType =
+        cast<VectorType>(op.getB().getType()).getElementType();
     int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
                  Sdot2dOp::kReductionSize;
     VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
   auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
   Type i1Type = builder.getI1Type();
   if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
-    i1Type = VectorType::get(vecType.getShape(), i1Type);
+    i1Type =
+        VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
   Value cmp = builder.create<LLVM::FCmpOp>(
       loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
       lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
   if (!elementType)
     return {};
   if (type.getShape().empty())
-    return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
+    return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+  Type vectorType = VectorType::get(type.getShape().back(),
+                                    cast<ScalarTypeInterface>(elementType),
                                     type.getScalableDims().back());
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
     if (auto vectorType = dyn_cast<VectorType>(type)) {
       assert(vectorType.getRank() == 1);
       int count = vectorType.getNumElements();
-      intType = VectorType::get(count, intType);
+      intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
 
       SmallVector<Value> signSplat(count, signMask);
       signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
     auto operandType = adaptor.getRhs().getType();
     if (auto vectorType = dyn_cast<VectorType>(operandType)) {
       auto shape = vectorType.getShape();
-      intType = VectorType::get(shape, scalarIntType);
+      intType =
+          VectorType::get(shape, cast<ScalarTypeInterface>(scalarIntType));
     }
 
     // Per GL Pow extended instruction spec:
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 6e0adfc1e0ff3..d1e10ef2e80f7 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -152,7 +152,8 @@ static Value optionallyTruncateOrExtend(Location loc, Value value,
 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
                        const TypeConverter &typeConverter,
                        ConversionPatternRewriter &rewriter) {
-  auto vectorType = VectorType::get(numElements, toBroadcast.getType());
+  auto vectorType = VectorType::get(
+      numElements, cast<ScalarTypeInterface>(toBroadcast.getType()));
   auto llvmVectorType = typeConverter.convertType(vectorType);
   auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
   Value broadcasted = rewriter.create<LLVM::PoisonOp>(loc, llvmVectorType);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index eaefe9e385793..5ed167bde0899 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -631,7 +631,7 @@ getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
   Type elType = regInfo.registerLLVMType;
   if (auto vecType = dyn_cast<VectorType>(elType))
     elType = vecType.getElementType();
-  return VectorType::get(shape, elType);
+  return VectorType::get(shape, cast<ScalarTypeInterface>(elType));
 }
 
 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
@@ -802,7 +802,8 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
   // must load each element individually.
   if (!isTransposeLoad) {
     if (!isa<VectorType>(loadedElType)) {
-      loadedElType = VectorType::get({1}, loadedElType);
+      loadedElType =
+          VectorType::get({1}, cast<ScalarTypeInterface>(loadedElType));
     }
 
     for (int i = 0; i < vectorType.getShape()[0]; i++) {
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 95db831185590..8cb35e1cab935 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1480,7 +1480,10 @@ struct UnrollTransferWriteConversion
               // argument into `transfer_write` to become a scalar. We solve
               // this by broadcasting the scalar to a 0D vector.
               xferVec = b.create<vector::BroadcastOp>(
-                  loc, VectorType::get({}, extracted.getType()), extracted);
+                  loc,
+                  VectorType::get(
+                      {}, cast<ScalarTypeInterface>(extracted.getType())),
+                  extracted);
             } else {
               xferVec = extracted;
             }
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
index 7dd4be66d2bd6..87c94f23b5152 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
@@ -98,7 +98,7 @@ static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
 
   int64_t bitwidth =
       vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
-  Type allBitsType = rewriter.getIntegerType(bitwidth);
+  auto allBitsType = rewriter.getIntegerType(bitwidth);
   auto allBitsVecType = VectorType::get({1}, allBitsType);
   Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
   Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index eaaafaf68767e..38df408ad3b02 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -936,7 +936,8 @@ isVectorizableLoopPtrFactory(const DenseSet<Operation *> &parallelLoops,
 static VectorType getVectorType(Type scalarTy,
                                 const VectorizationStrategy *strategy) {
   assert(!isa<VectorType>(scalarTy) && "Expected scalar type");
-  return VectorType::get(strategy->vectorSizes, scalarTy);
+  return VectorType::get(strategy->vectorSizes,
+                         cast<ScalarTypeInterface>(scalarTy));
 }
 
 /// Tries to transform a scalar constant into a vector constant. Returns the
@@ -1195,7 +1196,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
                                       VectorizationState &state) {
   MemRefType memRefType = loadOp.getMemRefType();
   Type elementType = memRefType.getElementType();
-  auto vectorType = VectorType::get(state.strategy->vectorSizes, elementType);
+  auto vectorType = VectorType::get(state.strategy->vectorSizes,
+                                    cast<ScalarTypeInterface>(elementType));
 
   // Replace map operands with operands from the vector loop nest.
   SmallVector<Value, 8> mapOperands;
@@ -1426,7 +1428,8 @@ static Operation *widenOp(Operation *op, VectorizationState &state) {
   SmallVector<Type, 8> vectorTypes;
   for (Value result : op->getResults())
     vectorTypes.push_back(
-        VectorType::get(state.strategy->vectorSizes, result.getType()));
+        VectorType::get(state.strategy->vectorSizes,
+                        cast<ScalarTypeInterface>(result.getType())));
 
   SmallVector<Value, 8> vectorOperands;
   for (Value operand : op->getOperands()) {
@@ -1832,7 +1835,6 @@ verifyLoopNesting(const std::vector<SmallVector<AffineForOp, 2>> &loops) {
   return success();
 }
 
-
 /// External utility to vectorize affine loops in 'loops' using the n-D
 /// vectorization factors in 'vectorSizes'. By default, each vectorization
 /// factor is applied inner-to-outer to the loops of each loop nest.
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 61f8d82a615d8..3d00efa72ec59 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -581,7 +581,8 @@ struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
     Type narrowTy =
         rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
     if (auto vecTy = dyn_cast<VectorType>(resultType))
-      narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
+      narrowTy = VectorType::get(vecTy.getShape(),
+                                 cast<ScalarTypeInterface>(narrowTy));
 
     // Sign or zero-extend the result. Let the matching conversion pattern
     // legalize the extension op.
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..013fb0019755b 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -162,10 +162,10 @@ class LowerContractionToSMMLAPattern
       Value tiledAcc =
           extractOperand(op.getAcc(), accPermutationMap, accOffsets);
 
-      auto inputElementType =
-          cast<ShapedType>(tiledLhs.getType()).getElementType();
-      auto accElementType =
-          cast<ShapedType>(tiledAcc.getType()).getElementType();
+      auto inputElementType = cast<ScalarTypeInterface>(
+          cast<ShapedType>(tiledLhs.getType()).getElementType());
+      auto accElementType = cast<ScalarTypeInterface>(
+          cast<ShapedType>(tiledAcc.getType()).getElementType());
       auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
       auto outputExpandedType = VectorType::get({2, 2}, accElementType);
 
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 1f7305a5f8141..3975b400950ec 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -111,7 +111,7 @@ bool isMultipleOfSMETileVectorType(VectorType vType) {
          vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
 }
 
-VectorType getSMETileTypeForElement(Type elementType) {
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType) {
   unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
   return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
 }
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index edd7f607f24f4..9f7082ca93605 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -89,7 +89,7 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
   return failure();
 }
 
-// This side effect models "program termination". 
+// This side effect models "program termination".
 void AssertOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -480,8 +480,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                      ArrayRef<ValueRange> caseOperands) {
   DenseIntElementsAttr caseValuesAttr;
   if (!caseValues.empty()) {
-    ShapedType caseValueType = VectorType::get(
-        static_cast<int64_t>(caseValues.size()), value.getType());
+    ShapedType caseValueType =
+        VectorType::get(static_cast<int64_t>(caseValues.size()),
+                        cast<ScalarTypeInterface>(value.getType()));
     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
   }
   build(builder, result, value, defaultDestination, defaultOperands,
@@ -494,8 +495,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                      ArrayRef<ValueRange> caseOperands) {
   DenseIntElementsAttr caseValuesAttr;
   if (!caseValues.empty()) {
-    ShapedType caseValueType = VectorType::get(
-        static_cast<int64_t>(caseValues.size()), value.getType());
+    ShapedType caseValueType =
+        VectorType::get(static_cast<int64_t>(caseValues.size()),
+                        cast<ScalarTypeInterface>(value.getType()));
     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
   }
   build(builder, result, value, defaultDestination, defaultOperands,
@@ -550,7 +552,8 @@ static ParseResult parseSwitchOpCases(
 
   if (!values.empty()) {
     ShapedType caseValueType =
-        VectorType::get(static_cast<int64_t>(values.size()), flagType);
+        VectorType::get(static_cast<int64_t>(values.size()),
+                        cast<ScalarTypeInterface>(flagType));
     caseValues = DenseIntElementsAttr::get(caseValueType, values);
   }
   return success();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5370de501a85c..833eb96baadc1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -548,8 +548,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                      ArrayRef<int32_t> branchWeights) {
   DenseIntElementsAttr caseValuesAttr;
   if (!caseValues.empty()) {
-    ShapedType caseValueType = VectorType::get(
-        static_cast<int64_t>(caseValues.size()), value.getType());
+    ShapedType caseValueType =
+        VectorType::get(static_cast<int64_t>(caseValues.size()),
+                        cast<ScalarTypeInterface>(value.getType()));
     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
   }
 
@@ -564,8 +565,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                      ArrayRef<int32_t> branchWeights) {
   DenseIntElementsAttr caseValuesAttr;
   if (!caseValues.empty()) {
-    ShapedType caseValueType = VectorType::get(
-        static_cast<int64_t>(caseValues.size()), value.getType());
+    ShapedType caseValueType =
+        VectorType::get(static_cast<int64_t>(caseValues.size()),
+                        cast<ScalarTypeInterface>(value.getType()));
     caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
   }
 
@@ -611,8 +613,8 @@ static ParseResult parseSwitchOpCases(
   if (failed(parser.parseCommaSeparatedList(parseCase)))
     return failure();
 
-  ShapedType caseValueType =
-      VectorType::get(static_cast<int64_t>(values.size()), flagType);
+  ShapedType caseValueType = VectorType::get(
+      static_cast<int64_t>(values.size()), cast<ScalarTypeInterface>(flagType));
   caseValues = DenseIntElementsAttr::get(caseValueType, values);
   return parser.parseRSquare();
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 8f39ede721c92..5e790de461cea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -946,7 +946,8 @@ Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
 
   // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
   // scalable/non-scalable.
-  return VectorType::get(numElements, elementType, {isScalable});
+  return VectorType::get(numElements, cast<ScalarTypeInterface>(elementType),
+                         {isScalable});
 }
 
 Type mlir::LLVM::getVectorType(Type elementType,
@@ -966,7 +967,7 @@ Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
                                    "to be either builtin or LLVM dialect type");
   if (useLLVM)
     return LLVMFixedVectorType::get(elementType, numElements);
-  return VectorType::get(numElements, elementType);
+  return VectorType::get(numElements, cast<ScalarTypeInterface>(elementType));
 }
 
 Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
@@ -981,7 +982,8 @@ Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
 
   // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
   // scalable/non-scalable.
-  return VectorType::get(numElements, elementType, /*scalableDims=*/true);
+  return VectorType::get(numElements, cast<ScalarTypeInterface>(elementType),
+                         /*scalableDims=*/true);
 }
 
 llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2dcd897330d1e..e4909c4ee0f6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -249,7 +249,8 @@ struct VectorizationState {
       scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
     }
 
-    return VectorType::get(vectorShape, elementType, scalableDims);
+    return VectorType::get(vectorShape, cast<ScalarTypeInterface>(elementType),
+                           scalableDims);
   }
 
   /// Masks an operation with the canonical vector mask if the operation needs
@@ -1338,9 +1339,10 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
     assert(vecOperand && "Vector operand couldn't be found");
 
     if (firstMaxRankedType) {
-      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
-                                     getElementTypeOrSelf(vecOperand.getType()),
-                                     firstMaxRankedType.getScalableDims());
+      auto vecType = VectorType::get(
+          firstMaxRankedType.getShape(),
+          cast<ScalarTypeInterface>(getElementTypeOrSelf(vecOperand.getType())),
+          firstMaxRankedType.getScalableDims());
       vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
     } else {
       vecOperands.push_back(vecOperand);
@@ -1351,7 +1353,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
   for (Type resultType : op->getResultTypes()) {
     resultTypes.push_back(
         firstMaxRankedType
-            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
+            ? VectorType::get(firstMaxRankedType.getShape(),
+                              cast<ScalarTypeInterface>(resultType),
                               firstMaxRankedType.getScalableDims())
             : resultType);
   }
@@ -1632,8 +1635,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
   destShape.append(innerTiles.begin(), innerTiles.end());
-  auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
-                                       packOp.getDestType().getElementType());
+  auto tiledPackType = VectorType::get(
+      getTiledPackShape(packOp, destShape),
+      cast<ScalarTypeInterface>(packOp.getDestType().getElementType()));
   auto shapeCastOp =
       rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
 
@@ -1768,8 +1772,9 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   // Collapse the vector to the size required by result.
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMineTensorType, packMetadata.reassociations);
-  mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+  mlir::VectorType vecCollapsedType = VectorType::get(
+      collapsedType.getShape(),
+      cast<ScalarTypeInterface>(collapsedType.getElementType()));
   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       loc, vecCollapsedType, transposeOp->getResult(0));
 
@@ -2473,8 +2478,10 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
       !VectorType::isValidElementType(dstElementType))
     return failure();
 
-  auto readType = VectorType::get(srcType.getShape(), srcElementType);
-  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
+  auto readType = VectorType::get(srcType.getShape(),
+                                  cast<ScalarTypeInterface>(srcElementType));
+  auto writeType = VectorType::get(dstType.getShape(),
+                                   cast<ScalarTypeInterface>(dstElementType));
 
   Location loc = copyOp->getLoc();
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
@@ -2839,7 +2846,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
       return failure();
     }
   }
-  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
+  auto vecType = VectorType::get(
+      vecShape, cast<ScalarTypeInterface>(sourceType.getElementType()));
 
   // 3. Generate TransferReadOp + TransferWriteOp
   ReifiedRankedShapedTypeDims reifiedSrcSizes;
@@ -2943,8 +2951,9 @@ struct PadOpVectorizationWithInsertSlicePattern
     if (insertOp.getDest() == padOp.getResult())
       return failure();
 
-    auto vecType = VectorType::get(padOp.getType().getShape(),
-                                   padOp.getType().getElementType());
+    auto vecType = VectorType::get(
+        padOp.getType().getShape(),
+        cast<ScalarTypeInterface>(padOp.getType().getElementType()));
     unsigned vecRank = vecType.getRank();
     unsigned tensorRank = insertOp.getType().getRank();
 
@@ -3366,9 +3375,12 @@ struct Conv1DGenerator
     Type lhsEltType = lhsShapedType.getElementType();
     Type rhsEltType = rhsShapedType.getElementType();
     Type resEltType = resShapedType.getElementType();
-    auto lhsType = VectorType::get(lhsShape, lhsEltType);
-    auto rhsType = VectorType::get(rhsShape, rhsEltType);
-    auto resType = VectorType::get(resShape, resEltType);
+    auto lhsType =
+        VectorType::get(lhsShape, cast<ScalarTypeInterface>(lhsEltType));
+    auto rhsType =
+        VectorType::get(rhsShape, cast<ScalarTypeInterface>(rhsEltType));
+    auto resType =
+        VectorType::get(resShape, cast<ScalarTypeInterface>(resEltType));
     // Zero padding with the corresponding dimensions for lhs, rhs and res.
     SmallVector<Value> lhsPadding(lhsShape.size(), zero);
     SmallVector<Value> rhsPadding(rhsShape.size(), zero);
@@ -3595,13 +3607,14 @@ struct Conv1DGenerator
          //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
          ((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
          cSize},
-        lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
+        cast<ScalarTypeInterface>(lhsEltType),
+        /*scalableDims=*/{false, false, scalableChDim});
     VectorType rhsType =
-        VectorType::get({kwSize, cSize}, rhsEltType,
+        VectorType::get({kwSize, cSize}, cast<ScalarTypeInterface>(rhsEltType),
                         /*scalableDims=*/{false, scalableChDim});
-    VectorType resType =
-        VectorType::get({nSize, wSize, cSize}, resEltType,
-                        /*scalableDims=*/{false, false, scalableChDim});
+    VectorType resType = VectorType::get(
+        {nSize, wSize, cSize}, cast<ScalarTypeInterface>(resEltType),
+        /*scalableDims=*/{false, false, scalableChDim});
 
     // Masks the input xfer Op along the channel dim, iff the corresponding
     // scalable flag is set.
@@ -3685,10 +3698,10 @@ struct Conv1DGenerator
     // Note - the scalable flags are ignored as flattening combined with
     // scalable vectorization is not supported.
     SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
-    auto lhsTypeAfterFlattening =
-        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
-    auto resTypeAfterFlattening =
-        VectorType::get(inOutFlattenSliceSizes, resEltType);
+    auto lhsTypeAfterFlattening = VectorType::get(
+        inOutFlattenSliceSizes, cast<ScalarTypeInterface>(lhsEltType));
+    auto resTypeAfterFlattening = VectorType::get(
+        inOutFlattenSliceSizes, cast<ScalarTypeInterface>(resEltType));
 
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
@@ -3708,7 +3721,10 @@ struct Conv1DGenerator
         if (flatten) {
           // Un-flatten the output vector (restore the channel dimension)
           resVals[w] = rewriter.create<vector::ShapeCastOp>(
-              loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
+              loc,
+              VectorType::get(inOutSliceSizes,
+                              cast<ScalarTypeInterface>(resEltType)),
+              resVals[w]);
         }
       }
     }
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index a26e380232a91..bdfedabe23648 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -64,7 +64,8 @@ static std::optional<VectorShape> vectorShape(Value value) {
 // Broadcasts scalar type into vector type (iff shape is non-scalar).
 static Type broadcast(Type type, std::optional<VectorShape> shape) {
   assert(!isa<VectorType>(type) && "must be scalar type");
-  return shape ? VectorType::get(shape->sizes, type, shape->scalableFlags)
+  return shape ? VectorType::get(shape->sizes, cast<ScalarTypeInterface>(type),
+                                 shape->scalableFlags)
                : type;
 }
 
@@ -156,7 +157,8 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
 
   // Stitch results together into one large vector.
   Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
-  Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
+  Type resultExpandedType =
+      VectorType::get(expandedShape, cast<ScalarTypeInterface>(resultEltType));
   Value result = builder.create<arith::ConstantOp>(
       resultExpandedType, builder.getZeroAttr(resultExpandedType));
 
@@ -166,7 +168,8 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
 
   // Reshape back to the original vector shape.
   return builder.create<vector::ShapeCastOp>(
-      VectorType::get(inputShape, resultEltType), result);
+      VectorType::get(inputShape, cast<ScalarTypeInterface>(resultEltType)),
+      result);
 }
 
 //----------------------------------------------------------------------------//
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 556922a64b093..b2b914cd66424 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -684,7 +684,8 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
   auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
 
   Type elementType = getElementTypeOrSelf(memref.getType());
-  auto vt = VectorType::get(vectorShape, elementType);
+  auto vt =
+      VectorType::get(vectorShape, cast<ScalarTypeInterface>(elementType));
   Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
   foreachIndividualVectorElement(
       res,
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 7c0d369648651..7281e0da7f7f2 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -6,9 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 #include "TypeDetail.h"
 #include "mlir/Dialect/Quant/IR/Quant.h"
-#include "mlir/Dialect/Quant/IR/QuantTypes.h"
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/MLIRContext.h"
@@ -34,7 +34,7 @@ double getMaxScale(Type expressedType) {
   return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
 }
 
-}  // namespace
+} // namespace
 
 unsigned QuantizedType::getFlags() const {
   return static_cast<ImplType *>(impl)->flags;
@@ -146,7 +146,7 @@ Type QuantizedType::castFromStorageType(Type candidateType) {
   if (llvm::isa<VectorType>(candidateType)) {
     // i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
     return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(),
-                           getStorageType());
+                           llvm::cast<ScalarTypeInterface>(getStorageType()));
   }
 
   return nullptr;
@@ -172,7 +172,8 @@ Type QuantizedType::castToStorageType(Type quantizedType) {
       return UnrankedTensorType::get(storageType);
     }
     if (llvm::isa<VectorType>(quantizedType)) {
-      return VectorType::get(sType.getShape(), storageType);
+      return VectorType::get(sType.getShape(),
+                             llvm::cast<ScalarTypeInterface>(storageType));
     }
   }
 
@@ -200,7 +201,8 @@ Type QuantizedType::castFromExpressedType(Type candidateType) {
     }
     if (llvm::isa<VectorType>(candidateType)) {
       // i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
-      return VectorType::get(candidateShapedType.getShape(), *this);
+      return VectorType::get(candidateShapedType.getShape(),
+                             llvm::cast<ScalarTypeInterface>(*this));
     }
   }
 
@@ -227,7 +229,8 @@ Type QuantizedType::castToExpressedType(Type quantizedType) {
       return UnrankedTensorType::get(expressedType);
     }
     if (llvm::isa<VectorType>(quantizedType)) {
-      return VectorType::get(sType.getShape(), expressedType);
+      return VectorType::get(sType.getShape(),
+                             llvm::cast<ScalarTypeInterface>(expressedType));
     }
   }
 
diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index 62c7a7128d63a..7cd7bc8da8509 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -39,7 +39,8 @@ Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
   if (dyn_cast<UnrankedTensorType>(inputType))
     return UnrankedTensorType::get(elementalType);
   if (auto vectorType = dyn_cast<VectorType>(inputType))
-    return VectorType::get(vectorType.getShape(), elementalType);
+    return VectorType::get(vectorType.getShape(),
+                           cast<ScalarTypeInterface>(elementalType));
 
   // If the expressed types match, just use the new elemental type.
   if (elementalType.getExpressedType() == expressedType)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index d8dfe164458e2..0b0a309c02c3e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -42,7 +42,8 @@ static Type getUnaryOpResultType(Type operandType) {
   Builder builder(operandType.getContext());
   Type resultType = builder.getIntegerType(1);
   if (auto vecType = llvm::dyn_cast<VectorType>(operandType))
-    return VectorType::get(vecType.getNumElements(), resultType);
+    return VectorType::get(vecType.getNumElements(),
+                           cast<ScalarTypeInterface>(resultType));
   return resultType;
 }
 
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index a60410d01ac57..77305de066c1a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -366,7 +366,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
       return nullptr;
     }
 
-    return VectorType::get(type.getShape(), elementType);
+    return VectorType::get(type.getShape(),
+                           cast<ScalarTypeInterface>(elementType));
   }
 
   if (type.getRank() <= 1 && type.getNumElements() == 1)
@@ -392,7 +393,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   auto elementType =
       convertScalarType(targetEnv, options, scalarType, storageClass);
   if (elementType)
-    return VectorType::get(type.getShape(), elementType);
+    return VectorType::get(type.getShape(),
+                           cast<ScalarTypeInterface>(elementType));
   return nullptr;
 }
 
@@ -417,7 +419,7 @@ convertComplexType(const spirv::TargetEnv &targetEnv,
     return nullptr;
   }
 
-  return VectorType::get(2, elementType);
+  return VectorType::get(2, cast<ScalarTypeInterface>(elementType));
 }
 
 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
@@ -770,8 +772,9 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
   case spirv::BuiltIn::WorkgroupId:
   case spirv::BuiltIn::LocalInvocationId:
   case spirv::BuiltIn::GlobalInvocationId: {
-    auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
-                                           spirv::StorageClass::Input);
+    auto ptrType = spirv::PointerType::get(
+        VectorType::get({3}, cast<ScalarTypeInterface>(integerType)),
+        spirv::StorageClass::Input);
     std::string name = getBuiltinVarName(builtin, prefix, suffix);
     newVarOp =
         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index 07cf26926a1df..e0337ae7e9162 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -496,7 +496,8 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
 
       Type vectorType = srcElemType;
       if (!isa<VectorType>(srcElemType))
-        vectorType = VectorType::get({ratio}, dstElemType);
+        vectorType =
+            VectorType::get({ratio}, cast<ScalarTypeInterface>(dstElemType));
 
       // If both the source and destination are vector types, we need to make
       // sure the scalar type is the same for composite construction later.
@@ -511,7 +512,8 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
             // SPIR-V.
             Type castType = srcElemVecType.getElementType();
             if (count > 1)
-              castType = VectorType::get({count}, castType);
+              castType =
+                  VectorType::get({count}, cast<ScalarTypeInterface>(castType));
 
             for (Value &c : components)
               c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
index b19495bc37445..9a416eb15ef81 100644
--- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
@@ -118,7 +118,7 @@ Type VulkanLayoutUtils::decorateType(VectorType vectorType,
   // times its scalar alignment."
   size = elementSize * numElements;
   alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4;
-  return VectorType::get(numElements, memberType);
+  return VectorType::get(numElements, cast<ScalarTypeInterface>(memberType));
 }
 
 Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index b2eca539194a8..54e43089dc8e3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -57,7 +57,8 @@ static bool isInvariantArg(BlockArgument arg, Block *block) {
 
 /// Constructs vector type for element type.
 static VectorType vectorType(VL vl, Type etp) {
-  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
+  return VectorType::get(vl.vectorLength, cast<ScalarTypeInterface>(etp),
+                         vl.enableVLAVectorization);
 }
 
 /// Constructs vector type from a memref value.
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 0258f797143cb..acd508a9b35d3 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1236,7 +1236,8 @@ Type Merger::inferType(ExprId e, Value src) const {
   // Inspect source type. For vector types, apply the same
   // vectorization to the destination type.
   if (auto vtp = dyn_cast<VectorType>(src.getType()))
-    return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
+    return VectorType::get(vtp.getNumElements(), cast<ScalarTypeInterface>(dtp),
+                           vtp.getScalableDims());
   return dtp;
 }
 
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index a7aa25eae2644..cc26463a84d53 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -179,7 +179,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
 
   // Compose the final broadcasted type
   if (resultCompositeKind == VectorType::getTypeID())
-    return VectorType::get(resultShape, elementType);
+    return VectorType::get(resultShape, cast<ScalarTypeInterface>(elementType));
   if (resultCompositeKind == RankedTensorType::getTypeID())
     return RankedTensorType::get(resultShape, elementType);
   return elementType;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..73fe27bf12e1f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2419,7 +2419,8 @@ Value BroadcastOp::createOrFoldBroadcastOp(
   Location loc = value.getLoc();
   Type elementType = getElementTypeOrSelf(value.getType());
   VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
-  VectorType dstVectorType = VectorType::get(dstShape, elementType);
+  VectorType dstVectorType =
+      VectorType::get(dstShape, cast<ScalarTypeInterface>(elementType));
 
   // Step 2. If scalar -> dstShape broadcast, just do it.
   if (!srcVectorType) {
@@ -2481,7 +2482,8 @@ Value BroadcastOp::createOrFoldBroadcastOp(
              .empty() &&
          "unexpected \"dim-1\" broadcast");
 
-  VectorType broadcastType = VectorType::get(broadcastShape, elementType);
+  VectorType broadcastType =
+      VectorType::get(broadcastShape, cast<ScalarTypeInterface>(elementType));
   assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
              vector::BroadcastableToResult::Success &&
          "must be broadcastable");
@@ -5914,9 +5916,9 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result,
                        Value source) {
   result.addOperands(source);
   MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
-  VectorType vectorType =
-      VectorType::get(extractShape(memRefType),
-                      getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
+  VectorType vectorType = VectorType::get(
+      extractShape(memRefType), cast<ScalarTypeInterface>(getElementTypeOrSelf(
+                                    getElementTypeOrSelf(memRefType))));
   result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
                                   memRefType.getMemorySpace()));
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index fec3c6c52e5e4..225df20e37faf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -112,9 +112,9 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     //   %a = [%u, %v]
     //   ..
     //   %x = [%a,%b,%c,%d]
-    VectorType resType =
-        VectorType::get(dstType.getShape().drop_front(), eltType,
-                        dstType.getScalableDims().drop_front());
+    VectorType resType = VectorType::get(
+        dstType.getShape().drop_front(), cast<ScalarTypeInterface>(eltType),
+        dstType.getScalableDims().drop_front());
     Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
     if (m == 0) {
       // Stetch at start.
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index c6627b5ec0d77..c659bfc67a21b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1367,7 +1367,8 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
   mul = rew.create<vector::ShapeCastOp>(
       loc,
       VectorType::get({lhsRows, rhsColumns},
-                      getElementTypeOrSelf(op.getAcc().getType())),
+                      cast<ScalarTypeInterface>(
+                          getElementTypeOrSelf(op.getAcc().getType()))),
       mul);
 
   // ACC must be C(m, n) or C(n, m).
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 3b38505becd18..e22a3c0f4dfc6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -107,7 +107,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
 /// ```mlir
 ///   %subview = memref.subview %M (...)
 ///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+///   %gather = vector.gather %subview[%idxs] (...)
+///     : memref<100xf32, strided<[3]>>
 /// ```
 /// ==>
 /// ```mlir
@@ -200,7 +201,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
     Location loc = op.getLoc();
     Type elemTy = resultTy.getElementType();
     // Vector type with a single element. Used to generate `vector.loads`.
-    VectorType elemVecTy = VectorType::get({1}, elemTy);
+    VectorType elemVecTy =
+        VectorType::get({1}, cast<ScalarTypeInterface>(elemTy));
 
     Value condMask = op.getMask();
     Value base = op.getBase();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e214257de2cdf..7953e91f65f4c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1192,7 +1192,8 @@ struct WarpOpExtractScalar : public WarpDistributionPattern {
         return failure();
       int64_t elementsPerLane =
           extractSrcType.getShape()[0] / warpOp.getWarpSize();
-      distributedVecType = VectorType::get({elementsPerLane}, elType);
+      distributedVecType =
+          VectorType::get({elementsPerLane}, cast<ScalarTypeInterface>(elType));
     } else {
       distributedVecType = extractSrcType;
     }
@@ -1711,8 +1712,8 @@ struct WarpOpReduction : public WarpDistributionPattern {
     // Return vector that will be reduced from the WarpExecuteOnLane0Op.
     unsigned operandIndex = yieldOperand->getOperandNumber();
     SmallVector<Value> yieldValues = {reductionOp.getVector()};
-    SmallVector<Type> retTypes = {
-        VectorType::get({numElements}, reductionOp.getType())};
+    SmallVector<Type> retTypes = {VectorType::get(
+        {numElements}, cast<ScalarTypeInterface>(reductionOp.getType()))};
     if (reductionOp.getAcc()) {
       yieldValues.push_back(reductionOp.getAcc());
       retTypes.push_back(reductionOp.getAcc().getType());
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index cf6efaa04ae44..19424e7854e5b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -297,12 +297,14 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
   auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
                                   emulatedElemTy.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
-      loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
-      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      loc,
+      VectorType::get(numContainerElemsToLoad,
+                      cast<ScalarTypeInterface>(containerElemTy)),
+      base, getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
       loc,
       VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
-                      emulatedElemTy),
+                      cast<ScalarTypeInterface>(emulatedElemTy)),
       newLoad);
 }
 
@@ -358,7 +360,8 @@ static void atomicRMW(OpBuilder &builder, Location loc,
 
   // Load the original value from memory, and cast it to the original element
   // type.
-  auto oneElemVecType = VectorType::get({1}, origValue.getType());
+  auto oneElemVecType =
+      VectorType::get({1}, cast<ScalarTypeInterface>(origValue.getType()));
   Value origVecValue = builder.create<vector::FromElementsOp>(
       loc, oneElemVecType, ValueRange{origValue});
 
@@ -378,8 +381,9 @@ static void nonAtomicRMW(OpBuilder &builder, Location loc,
                          VectorValue valueToStore, Value mask) {
   assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
 
-  auto oneElemVecType =
-      VectorType::get({1}, linearizedMemref.getType().getElementType());
+  auto oneElemVecType = VectorType::get(
+      {1},
+      cast<ScalarTypeInterface>(linearizedMemref.getType().getElementType()));
   Value origVecValue = builder.create<vector::LoadOp>(
       loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
   origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
@@ -559,7 +563,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       // Basic case: storing full bytes.
       auto numElements = origElements / emulatedPerContainerElem;
       auto bitCast = rewriter.create<vector::BitCastOp>(
-          loc, VectorType::get(numElements, containerElemTy),
+          loc,
+          VectorType::get(numElements,
+                          cast<ScalarTypeInterface>(containerElemTy)),
           op.getValueToStore());
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
           op, bitCast.getResult(), memrefBase,
@@ -665,7 +671,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
       auto storeType = VectorType::get(
           {originType.getNumElements() / emulatedPerContainerElem},
-          memrefElemType);
+          cast<ScalarTypeInterface>(memrefElemType));
       auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                         fullWidthStorePart);
       rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
@@ -794,7 +800,8 @@ struct ConvertVectorMaskedStore final
 
     auto numElements = (origElements + emulatedPerContainerElem - 1) /
                        emulatedPerContainerElem;
-    auto newType = VectorType::get(numElements, containerElemTy);
+    auto newType = VectorType::get(numElements,
+                                   cast<ScalarTypeInterface>(containerElemTy));
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
 
@@ -803,7 +810,8 @@ struct ConvertVectorMaskedStore final
         newMask.value()->getResult(0), passThru);
 
     auto newBitCastType =
-        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
+        VectorType::get(numElements * emulatedPerContainerElem,
+                        cast<ScalarTypeInterface>(emulatedElemTy));
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
@@ -1032,9 +1040,11 @@ struct ConvertVectorMaskedLoad final
 
     auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
                                         emulatedPerContainerElem);
-    auto loadType = VectorType::get(numElements, containerElemTy);
+    auto loadType = VectorType::get(numElements,
+                                    cast<ScalarTypeInterface>(containerElemTy));
     auto newBitcastType =
-        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
+        VectorType::get(numElements * emulatedPerContainerElem,
+                        cast<ScalarTypeInterface>(emulatedElemTy));
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -1188,13 +1198,17 @@ struct ConvertVectorTransferRead final
                                         emulatedPerContainerElem);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
+        loc,
+        VectorType::get(numElements,
+                        cast<ScalarTypeInterface>(containerElemTy)),
+        adaptor.getSource(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newPadding);
 
     auto bitCast = rewriter.create<vector::BitCastOp>(
         loc,
-        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
+        VectorType::get(numElements * emulatedPerContainerElem,
+                        cast<ScalarTypeInterface>(emulatedElemTy)),
         newRead);
 
     Value result = bitCast->getResult(0);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index dc46ed17a374d..1339e3f49eab2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -607,7 +607,8 @@ struct BubbleDownVectorBitCastForExtract
     Location loc = extractOp.getLoc();
     Value packedValue = rewriter.create<vector::ExtractOp>(
         loc, castOp.getSource(), index / expandRatio);
-    Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
+    Type packedVecType = VectorType::get(
+        /*shape=*/{1}, cast<ScalarTypeInterface>(packedValue.getType()));
     Value zero = rewriter.create<arith::ConstantOp>(
         loc, packedVecType, rewriter.getZeroAttr(packedVecType));
     packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
@@ -1059,7 +1060,7 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
   // If we can assume all indices fit in 32-bit, we perform the vector
   // comparison in 32-bit to get a higher degree of SIMD parallelism.
   // Otherwise we perform the vector comparison using 64-bit indices.
-  Type idxType =
+  ScalarTypeInterface idxType =
       force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
   DenseIntElementsAttr indicesAttr;
   if (dim == 0 && force32BitVectorIndices) {
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 7b56cd0cf0e91..fa0ac4e47bac9 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -337,7 +337,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == readShape.size() && "expected same ranks.");
   auto maskType = VectorType::get(readShape, builder.getI1Type());
-  auto vectorType = VectorType::get(readShape, padValue.getType());
+  auto vectorType =
+      VectorType::get(readShape, cast<ScalarTypeInterface>(padValue.getType()));
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = readShape.size();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 78c242571935c..31ccb14de0cf0 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -368,8 +368,9 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
            "tensor descriptor shape is not distributable");
     if (chunkSize > 1)
       return VectorType::get({chunkSize / wiDataSize, wiDataSize},
-                             getElementType());
-    return VectorType::get({wiDataSize}, getElementType());
+                             llvm::cast<ScalarTypeInterface>(getElementType()));
+    return VectorType::get({wiDataSize},
+                           llvm::cast<ScalarTypeInterface>(getElementType()));
   }
 
   // Case 2: block loads/stores
@@ -393,7 +394,7 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
   tensorSize *= getArrayLength();
 
   return VectorType::get({tensorSize / (sgSize * wiDataSize), wiDataSize},
-                         getElementType());
+                         llvm::cast<ScalarTypeInterface>(getElementType()));
 }
 
 } // namespace xegpu
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 3924d082f0628..2c8d75aaf2594 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -211,11 +211,12 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
 //===----------------------------------------------------------------------===//
 
 bool VectorType::isValidElementType(Type t) {
-  return isValidVectorTypeElementType(t);
+  return llvm::isa<ScalarTypeInterface>(t);
 }
 
 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
-                                 ArrayRef<int64_t> shape, Type elementType,
+                                 ArrayRef<int64_t> shape,
+                                 ScalarTypeInterface elementType,
                                  ArrayRef<bool> scalableDims) {
   if (!isValidElementType(elementType))
     return emitError()
@@ -248,7 +249,8 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
 
 VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
                                  Type elementType) const {
-  return VectorType::get(shape.value_or(getShape()), elementType,
+  return VectorType::get(shape.value_or(getShape()),
+                         llvm::cast<ScalarTypeInterface>(elementType),
                          getScalableDims());
 }
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index a07189ae1323c..54c540b28fdbd 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -139,8 +139,8 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
   if (iface.isConvertibleInstruction(inst->getOpcode()))
     return iface.convertInstruction(odsBuilder, inst, llvmOperands,
                                     moduleImport);
-  // TODO: Implement the `convertInstruction` hooks in the
-  // `LLVMDialectLLVMIRImportInterface` and move the following include there.
+    // TODO: Implement the `convertInstruction` hooks in the
+    // `LLVMDialectLLVMIRImportInterface` and move the following include there.
 #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
   return failure();
 }
@@ -813,7 +813,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
 
   SmallVector<int64_t> shape(arrayShape);
   shape.push_back(numElements.getKnownMinValue());
-  return VectorType::get(shape, elementType);
+  return VectorType::get(shape, cast<ScalarTypeInterface>(elementType));
 }
 
 Type ModuleImport::getBuiltinTypeForAttr(Type type) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 273817d53d308..6b2726970e94e 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -882,7 +882,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
       return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
              << operands[1];
     }
-    typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
+    typeMap[operands[0]] =
+        VectorType::get({operands[2]}, cast<ScalarTypeInterface>(elementTy));
   } break;
   case spirv::Opcode::OpTypePointer: {
     return processOpTypePointer(operands);
diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 51612446d2e6a..5be76d7fd3878 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -115,7 +115,7 @@ func.func @illegaltype(i21312312323120) // expected-error {{invalid integer widt
 // -----
 
 // Test no nested vector.
-// expected-error at +1 {{failed to verify 'elementType': integer or index or floating-point}}
+// expected-error at +1 {{failed to verify 'elementType': vector type requires scalar element type}}
 func.func @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>)
 
 // -----
diff --git a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp
index 1e45ab57ebcc7..67eb832c471d3 100644
--- a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp
+++ b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp
@@ -48,7 +48,8 @@ static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) {
   const unsigned lmul = eltCount * sew / 64;
 
   unsigned n = lmul > 8 ? llvm::Log2_32(lmul) - 2 : 1;
-  return {n, VectorType::get({eltCount >> (n - 1)}, eltTy, {true})};
+  return {n, VectorType::get({eltCount >> (n - 1)},
+                             cast<ScalarTypeInterface>(eltTy), {true})};
 }
 
 /// Replace math.cos(v) operation with vcix.v.iv(v).
diff --git a/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp b/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp
index 305f879489813..98fb71d9355ee 100644
--- a/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp
+++ b/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp
@@ -31,7 +31,7 @@ TEST_F(ArmSMETest, TestTileTypeConversion) {
   populateArmSMEToLLVMConversionPatterns(llvmConverterWithArmSMEConversion,
                                          patterns);
 
-  Type i32 = IntegerType::get(&context, 32);
+  auto i32 = IntegerType::get(&context, 32);
   auto smeTileType = VectorType::get({4, 4}, i32, {true, true});
 
   // An unmodified LLVMTypeConverer should fail to convert an ArmSME tile type.
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index bc4066ed210e8..abb33f5bedea1 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -110,10 +110,10 @@ TEST(ShapedTypeTest, CloneTensor) {
 TEST(ShapedTypeTest, CloneVector) {
   MLIRContext context;
 
-  Type i32 = IntegerType::get(&context, 32);
-  Type f32 = Float32Type::get(&context);
+  auto i32 = IntegerType::get(&context, 32);
+  auto f32 = Float32Type::get(&context);
 
-  Type vectorOriginalType = i32;
+  auto vectorOriginalType = i32;
   llvm::SmallVector<int64_t> vectorOriginalShape({10, 20});
   ShapedType vectorType =
       VectorType::get(vectorOriginalShape, vectorOriginalType);
@@ -123,7 +123,7 @@ TEST(ShapedTypeTest, CloneVector) {
   ASSERT_EQ(vectorType.clone(vectorNewShape),
             VectorType::get(vectorNewShape, vectorOriginalType));
   // Update type.
-  Type vectorNewType = f32;
+  auto vectorNewType = f32;
   ASSERT_NE(vectorOriginalType, vectorNewType);
   ASSERT_EQ(vectorType.clone(vectorNewType),
             VectorType::get(vectorOriginalShape, vectorNewType));
@@ -134,7 +134,7 @@ TEST(ShapedTypeTest, CloneVector) {
 
 TEST(ShapedTypeTest, VectorTypeBuilder) {
   MLIRContext context;
-  Type f32 = Float32Type::get(&context);
+  auto f32 = Float32Type::get(&context);
 
   SmallVector<int64_t> shape{2, 4, 8, 9, 1};
   SmallVector<bool> scalableDims{true, false, true, false, false};



More information about the Mlir-commits mailing list