[Mlir-commits] [mlir] [mlir][IR] Add `ScalarTypeInterface` and use as `VectorType` element type (PR #132400)
Matthias Springer
llvmlistbot at llvm.org
Fri Mar 21 06:51:21 PDT 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/132400
This commit adds a new builtin type interface: `ScalarTypeInterface`
Instead of maintaining a list of valid element types for `VectorType`, restrict valid element types to `ScalarTypeInterface`.
>From 85c0b6be5c046b342987ff3523836bd87806e971 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Fri, 21 Mar 2025 14:49:28 +0100
Subject: [PATCH] [mlir][IR] Add `ScalarTypeInterface`
---
.../include/mlir/Dialect/ArmSME/Utils/Utils.h | 2 +-
.../include/mlir/IR/BuiltinDialectBytecode.td | 4 +-
mlir/include/mlir/IR/BuiltinTypeInterfaces.td | 35 +++++++++-
mlir/include/mlir/IR/BuiltinTypes.h | 6 +-
mlir/include/mlir/IR/BuiltinTypes.td | 39 ++++++++---
mlir/include/mlir/IR/CommonTypeConstraints.td | 4 ++
mlir/lib/AsmParser/TypeParser.cpp | 9 ++-
mlir/lib/CAPI/IR/BuiltinTypes.cpp | 13 ++--
.../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 9 +--
.../ArithToAMDGPU/ArithToAMDGPU.cpp | 6 +-
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 6 +-
.../ArmNeon2dToIntr/ArmNeon2dToIntr.cpp | 3 +-
.../Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | 3 +-
.../Conversion/LLVMCommon/TypeConverter.cpp | 5 +-
.../Conversion/MathToSPIRV/MathToSPIRV.cpp | 5 +-
.../Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp | 3 +-
.../Conversion/VectorToGPU/VectorToGPU.cpp | 5 +-
.../Conversion/VectorToSCF/VectorToSCF.cpp | 5 +-
.../AMDGPU/Transforms/EmulateAtomics.cpp | 2 +-
.../Affine/Transforms/SuperVectorize.cpp | 10 +--
.../Arith/Transforms/EmulateWideInt.cpp | 3 +-
.../LowerContractionToSMMLAPattern.cpp | 8 +--
mlir/lib/Dialect/ArmSME/IR/Utils.cpp | 2 +-
.../Dialect/ControlFlow/IR/ControlFlowOps.cpp | 15 ++--
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 14 ++--
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 8 ++-
.../Linalg/Transforms/Vectorization.cpp | 70 ++++++++++++-------
.../Transforms/PolynomialApproximation.cpp | 9 ++-
.../NVGPU/TransformOps/NVGPUTransformOps.cpp | 3 +-
mlir/lib/Dialect/Quant/IR/QuantTypes.cpp | 15 ++--
.../Dialect/Quant/Utils/UniformSupport.cpp | 3 +-
.../Dialect/SPIRV/IR/SPIRVOpDefinition.cpp | 3 +-
.../SPIRV/Transforms/SPIRVConversion.cpp | 13 ++--
.../Transforms/UnifyAliasedResourcePass.cpp | 6 +-
mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp | 2 +-
.../Transforms/SparseVectorization.cpp | 3 +-
.../lib/Dialect/SparseTensor/Utils/Merger.cpp | 3 +-
mlir/lib/Dialect/Traits.cpp | 2 +-
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++--
.../Transforms/LowerVectorBroadcast.cpp | 6 +-
.../Vector/Transforms/LowerVectorContract.cpp | 3 +-
.../Vector/Transforms/LowerVectorGather.cpp | 6 +-
.../Vector/Transforms/VectorDistribute.cpp | 7 +-
.../Transforms/VectorEmulateNarrowType.cpp | 42 +++++++----
.../Vector/Transforms/VectorTransforms.cpp | 5 +-
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 3 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 7 +-
mlir/lib/IR/BuiltinTypes.cpp | 8 ++-
mlir/lib/Target/LLVMIR/ModuleImport.cpp | 6 +-
.../SPIRV/Deserialization/Deserializer.cpp | 3 +-
mlir/test/IR/invalid-builtin-types.mlir | 2 +-
.../MathToVCIX/TestMathToVCIXConversion.cpp | 3 +-
.../Dialect/ArmSME/TileTypeConversionTest.cpp | 2 +-
mlir/unittests/IR/ShapedTypeTest.cpp | 10 +--
54 files changed, 317 insertions(+), 164 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
bool isMultipleOfSMETileVectorType(VectorType vType);
/// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
/// Erase trivially dead tile ops from a function.
void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
Type:$elementType
)> {
let printerPredicate = "!$_val.isScalable()";
+ // Note: Element type must implement ScalarTypeInterface.
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
}
def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
)> {
let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
}
}
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
include "mlir/IR/OpBase.td"
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ Indication that this type is a scalar type.
+
+ The bitwidth of a scalar type is a fixed constant but may be unknown in the
+ absence of data layout information.
+
+ Scalar types are POD (plain-old-data) entities that have an in-memory
+ representation: scalar values can be loaded/stored from/to memory, so
+ abstract types like function types or async tokens cannot be scalar types.
+
+ Scalar types should be limited to types that can lower to something that
+ egress dialects would consider a valid vector element type.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+ bitwidth that is known in the absence of data layout information.
+ }],
+ "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
scalableDims(other.getScalableDims()) {}
/// Build from scratch.
- Builder(ArrayRef<int64_t> shape, Type elementType,
+ Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}
@@ -286,7 +286,7 @@ class VectorType::Builder {
return *this;
}
- Builder &setElementType(Type newElementType) {
+ Builder &setElementType(ScalarTypeInterface newElementType) {
elementType = newElementType;
return *this;
}
@@ -312,7 +312,7 @@ class VectorType::Builder {
}
private:
- Type elementType;
+ ScalarTypeInterface elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
};
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+ let extraClassDeclaration = [{
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
+ }];
}
// Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
+
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
}];
}
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
// IndexType
//===----------------------------------------------------------------------===//
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
let extraClassDeclaration = [{
static IndexType get(MLIRContext *context);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return std::nullopt;
+ }
+
/// Storage bit width used for IndexType by internal compiler data
/// structures.
static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
// IntegerType
//===----------------------------------------------------------------------===//
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+ : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
/// Return null if the scaled element type cannot be represented.
IntegerType scaleElementBitwidth(unsigned scale);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return static_cast<uint64_t>(getWidth());
+ }
+
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
- let cppFunctionName = "isValidVectorTypeElementType";
-}
-
def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
- Builtin_VectorTypeElementType:$elementType,
+ AnyScalarType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
TypeBuilderWithInferredContext<(ins
- "ArrayRef<int64_t>":$shape, "Type":$elementType,
+ "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
), [{
// While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
list<Pred> predicateList = predicates;
}
+def AnyScalarType : Type<
+ CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+ "scalar type", "::mlir::ScalarTypeInterface">;
+
// Integer types.
// Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
- return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+ auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+ if (!scalarElementType) {
+ emitWrongTokenError("vector type requires scalar element type");
+ return nullptr;
+ }
+
+ return getChecked<VectorType>(loc, dimensions, scalarElementType,
+ scalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType elementType) {
return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape, MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
const bool *scalable, MlirType elementType) {
- return wrap(VectorType::get(
- llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
- llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+ return wrap(
+ VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
int64_t numBits =
vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
Type intrinsicInType = numBits <= 32
? (Type)rewriter.getIntegerType(numBits)
: (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
operand =
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
}
- auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
- 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+ auto llvmVecType = typeConverter->convertType(
+ mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+ cast<ScalarTypeInterface>(llvmSrcIntType)));
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
operand = rewriter.create<LLVM::InsertElementOp>(
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
auto inVectorTy = dyn_cast<VectorType>(in.getType());
- VectorType truncResType = VectorType::get(4, outElemType);
+ VectorType truncResType =
+ VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- VectorType truncResType = VectorType::get(2, outElemType);
+ VectorType truncResType =
+ VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
auto inVectorTy = dyn_cast<VectorType>(in.getType());
// Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
dstAttrType =
RankedTensorType::get(dstAttrType.getShape(), dstElemType);
else
- dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+ dstAttrType = VectorType::get(dstAttrType.getShape(),
+ cast<ScalarTypeInterface>(dstElemType));
dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
}
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
// cases. Extend them to 32-bit and do comparision then.
Type type = rewriter.getI32Type();
if (auto vectorType = dyn_cast<VectorType>(dstType))
- type = VectorType::get(vectorType.getShape(), type);
+ type = VectorType::get(vectorType.getShape(),
+ cast<ScalarTypeInterface>(type));
Value extLhs =
rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
/// arm.neon.intr.sdot
LogicalResult matchAndRewrite(Sdot2dOp op,
PatternRewriter &rewriter) const override {
- Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+ ScalarTypeInterface elemType =
+ cast<VectorType>(op.getB().getType()).getElementType();
int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
Sdot2dOp::kReductionSize;
VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
Type i1Type = builder.getI1Type();
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
- i1Type = VectorType::get(vecType.getShape(), i1Type);
+ i1Type =
+ VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
if (!elementType)
return {};
if (type.getShape().empty())
- return VectorType::get({1}, elementType);
- Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+ Type vectorType = VectorType::get(type.getShape().back(),
+ cast<ScalarTypeInterface>(elementType),
type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
int count = vectorType.getNumElements();
- intType = VectorType::get(count, intType);
+ intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
SmallVector<Value> signSplat(count, signMask);
signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
auto operandType = adaptor.getRhs().getType();
if (auto vectorType = dyn_cast<VectorType>(operandType)) {
auto shape = vectorType.getShape();
- intType = VectorType::get(shape, scalarIntType);
+ intType =
+ VectorType::get(shape, cast<ScalarTypeInterface>(scalarIntType));
}
// Per GL Pow extended instruction spec:
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 6e0adfc1e0ff3..d1e10ef2e80f7 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -152,7 +152,8 @@ static Value optionallyTruncateOrExtend(Location loc, Value value,
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
const TypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
- auto vectorType = VectorType::get(numElements, toBroadcast.getType());
+ auto vectorType = VectorType::get(
+ numElements, cast<ScalarTypeInterface>(toBroadcast.getType()));
auto llvmVectorType = typeConverter.convertType(vectorType);
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
Value broadcasted = rewriter.create<LLVM::PoisonOp>(loc, llvmVectorType);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index eaefe9e385793..5ed167bde0899 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -631,7 +631,7 @@ getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo ®Info) {
Type elType = regInfo.registerLLVMType;
if (auto vecType = dyn_cast<VectorType>(elType))
elType = vecType.getElementType();
- return VectorType::get(shape, elType);
+ return VectorType::get(shape, cast<ScalarTypeInterface>(elType));
}
/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
@@ -802,7 +802,8 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
// must load each element individually.
if (!isTransposeLoad) {
if (!isa<VectorType>(loadedElType)) {
- loadedElType = VectorType::get({1}, loadedElType);
+ loadedElType =
+ VectorType::get({1}, cast<ScalarTypeInterface>(loadedElType));
}
for (int i = 0; i < vectorType.getShape()[0]; i++) {
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 95db831185590..8cb35e1cab935 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1480,7 +1480,10 @@ struct UnrollTransferWriteConversion
// argument into `transfer_write` to become a scalar. We solve
// this by broadcasting the scalar to a 0D vector.
xferVec = b.create<vector::BroadcastOp>(
- loc, VectorType::get({}, extracted.getType()), extracted);
+ loc,
+ VectorType::get(
+ {}, cast<ScalarTypeInterface>(extracted.getType())),
+ extracted);
} else {
xferVec = extracted;
}
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
index 7dd4be66d2bd6..87c94f23b5152 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
@@ -98,7 +98,7 @@ static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
int64_t bitwidth =
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
- Type allBitsType = rewriter.getIntegerType(bitwidth);
+ auto allBitsType = rewriter.getIntegerType(bitwidth);
auto allBitsVecType = VectorType::get({1}, allBitsType);
Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index eaaafaf68767e..38df408ad3b02 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -936,7 +936,8 @@ isVectorizableLoopPtrFactory(const DenseSet<Operation *> ¶llelLoops,
static VectorType getVectorType(Type scalarTy,
const VectorizationStrategy *strategy) {
assert(!isa<VectorType>(scalarTy) && "Expected scalar type");
- return VectorType::get(strategy->vectorSizes, scalarTy);
+ return VectorType::get(strategy->vectorSizes,
+ cast<ScalarTypeInterface>(scalarTy));
}
/// Tries to transform a scalar constant into a vector constant. Returns the
@@ -1195,7 +1196,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
VectorizationState &state) {
MemRefType memRefType = loadOp.getMemRefType();
Type elementType = memRefType.getElementType();
- auto vectorType = VectorType::get(state.strategy->vectorSizes, elementType);
+ auto vectorType = VectorType::get(state.strategy->vectorSizes,
+ cast<ScalarTypeInterface>(elementType));
// Replace map operands with operands from the vector loop nest.
SmallVector<Value, 8> mapOperands;
@@ -1426,7 +1428,8 @@ static Operation *widenOp(Operation *op, VectorizationState &state) {
SmallVector<Type, 8> vectorTypes;
for (Value result : op->getResults())
vectorTypes.push_back(
- VectorType::get(state.strategy->vectorSizes, result.getType()));
+ VectorType::get(state.strategy->vectorSizes,
+ cast<ScalarTypeInterface>(result.getType())));
SmallVector<Value, 8> vectorOperands;
for (Value operand : op->getOperands()) {
@@ -1832,7 +1835,6 @@ verifyLoopNesting(const std::vector<SmallVector<AffineForOp, 2>> &loops) {
return success();
}
-
/// External utility to vectorize affine loops in 'loops' using the n-D
/// vectorization factors in 'vectorSizes'. By default, each vectorization
/// factor is applied inner-to-outer to the loops of each loop nest.
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
index 61f8d82a615d8..3d00efa72ec59 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
@@ -581,7 +581,8 @@ struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> {
Type narrowTy =
rewriter.getIntegerType(typeConverter->getMaxTargetIntBitWidth());
if (auto vecTy = dyn_cast<VectorType>(resultType))
- narrowTy = VectorType::get(vecTy.getShape(), narrowTy);
+ narrowTy = VectorType::get(vecTy.getShape(),
+ cast<ScalarTypeInterface>(narrowTy));
// Sign or zero-extend the result. Let the matching conversion pattern
// legalize the extension op.
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..013fb0019755b 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -162,10 +162,10 @@ class LowerContractionToSMMLAPattern
Value tiledAcc =
extractOperand(op.getAcc(), accPermutationMap, accOffsets);
- auto inputElementType =
- cast<ShapedType>(tiledLhs.getType()).getElementType();
- auto accElementType =
- cast<ShapedType>(tiledAcc.getType()).getElementType();
+ auto inputElementType = cast<ScalarTypeInterface>(
+ cast<ShapedType>(tiledLhs.getType()).getElementType());
+ auto accElementType = cast<ScalarTypeInterface>(
+ cast<ShapedType>(tiledAcc.getType()).getElementType());
auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
auto outputExpandedType = VectorType::get({2, 2}, accElementType);
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 1f7305a5f8141..3975b400950ec 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -111,7 +111,7 @@ bool isMultipleOfSMETileVectorType(VectorType vType) {
vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
}
-VectorType getSMETileTypeForElement(Type elementType) {
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType) {
unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
}
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index edd7f607f24f4..9f7082ca93605 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -89,7 +89,7 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
return failure();
}
-// This side effect models "program termination".
+// This side effect models "program termination".
void AssertOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
@@ -480,8 +480,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
ArrayRef<ValueRange> caseOperands) {
DenseIntElementsAttr caseValuesAttr;
if (!caseValues.empty()) {
- ShapedType caseValueType = VectorType::get(
- static_cast<int64_t>(caseValues.size()), value.getType());
+ ShapedType caseValueType =
+ VectorType::get(static_cast<int64_t>(caseValues.size()),
+ cast<ScalarTypeInterface>(value.getType()));
caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
}
build(builder, result, value, defaultDestination, defaultOperands,
@@ -494,8 +495,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
ArrayRef<ValueRange> caseOperands) {
DenseIntElementsAttr caseValuesAttr;
if (!caseValues.empty()) {
- ShapedType caseValueType = VectorType::get(
- static_cast<int64_t>(caseValues.size()), value.getType());
+ ShapedType caseValueType =
+ VectorType::get(static_cast<int64_t>(caseValues.size()),
+ cast<ScalarTypeInterface>(value.getType()));
caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
}
build(builder, result, value, defaultDestination, defaultOperands,
@@ -550,7 +552,8 @@ static ParseResult parseSwitchOpCases(
if (!values.empty()) {
ShapedType caseValueType =
- VectorType::get(static_cast<int64_t>(values.size()), flagType);
+ VectorType::get(static_cast<int64_t>(values.size()),
+ cast<ScalarTypeInterface>(flagType));
caseValues = DenseIntElementsAttr::get(caseValueType, values);
}
return success();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 5370de501a85c..833eb96baadc1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -548,8 +548,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
ArrayRef<int32_t> branchWeights) {
DenseIntElementsAttr caseValuesAttr;
if (!caseValues.empty()) {
- ShapedType caseValueType = VectorType::get(
- static_cast<int64_t>(caseValues.size()), value.getType());
+ ShapedType caseValueType =
+ VectorType::get(static_cast<int64_t>(caseValues.size()),
+ cast<ScalarTypeInterface>(value.getType()));
caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
}
@@ -564,8 +565,9 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
ArrayRef<int32_t> branchWeights) {
DenseIntElementsAttr caseValuesAttr;
if (!caseValues.empty()) {
- ShapedType caseValueType = VectorType::get(
- static_cast<int64_t>(caseValues.size()), value.getType());
+ ShapedType caseValueType =
+ VectorType::get(static_cast<int64_t>(caseValues.size()),
+ cast<ScalarTypeInterface>(value.getType()));
caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
}
@@ -611,8 +613,8 @@ static ParseResult parseSwitchOpCases(
if (failed(parser.parseCommaSeparatedList(parseCase)))
return failure();
- ShapedType caseValueType =
- VectorType::get(static_cast<int64_t>(values.size()), flagType);
+ ShapedType caseValueType = VectorType::get(
+ static_cast<int64_t>(values.size()), cast<ScalarTypeInterface>(flagType));
caseValues = DenseIntElementsAttr::get(caseValueType, values);
return parser.parseRSquare();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 8f39ede721c92..5e790de461cea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -946,7 +946,8 @@ Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements,
// LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
// scalable/non-scalable.
- return VectorType::get(numElements, elementType, {isScalable});
+ return VectorType::get(numElements, cast<ScalarTypeInterface>(elementType),
+ {isScalable});
}
Type mlir::LLVM::getVectorType(Type elementType,
@@ -966,7 +967,7 @@ Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
"to be either builtin or LLVM dialect type");
if (useLLVM)
return LLVMFixedVectorType::get(elementType, numElements);
- return VectorType::get(numElements, elementType);
+ return VectorType::get(numElements, cast<ScalarTypeInterface>(elementType));
}
Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
@@ -981,7 +982,8 @@ Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
// LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
// scalable/non-scalable.
- return VectorType::get(numElements, elementType, /*scalableDims=*/true);
+ return VectorType::get(numElements, cast<ScalarTypeInterface>(elementType),
+ /*scalableDims=*/true);
}
llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2dcd897330d1e..e4909c4ee0f6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -249,7 +249,8 @@ struct VectorizationState {
scalableDims.append(scalableVecDims.begin(), scalableVecDims.end());
}
- return VectorType::get(vectorShape, elementType, scalableDims);
+ return VectorType::get(vectorShape, cast<ScalarTypeInterface>(elementType),
+ scalableDims);
}
/// Masks an operation with the canonical vector mask if the operation needs
@@ -1338,9 +1339,10 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
assert(vecOperand && "Vector operand couldn't be found");
if (firstMaxRankedType) {
- auto vecType = VectorType::get(firstMaxRankedType.getShape(),
- getElementTypeOrSelf(vecOperand.getType()),
- firstMaxRankedType.getScalableDims());
+ auto vecType = VectorType::get(
+ firstMaxRankedType.getShape(),
+ cast<ScalarTypeInterface>(getElementTypeOrSelf(vecOperand.getType())),
+ firstMaxRankedType.getScalableDims());
vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
} else {
vecOperands.push_back(vecOperand);
@@ -1351,7 +1353,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
for (Type resultType : op->getResultTypes()) {
resultTypes.push_back(
firstMaxRankedType
- ? VectorType::get(firstMaxRankedType.getShape(), resultType,
+ ? VectorType::get(firstMaxRankedType.getShape(),
+ cast<ScalarTypeInterface>(resultType),
firstMaxRankedType.getScalableDims())
: resultType);
}
@@ -1632,8 +1635,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
// Create ShapeCastOp.
SmallVector<int64_t> destShape(inputVectorSizes);
destShape.append(innerTiles.begin(), innerTiles.end());
- auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape),
- packOp.getDestType().getElementType());
+ auto tiledPackType = VectorType::get(
+ getTiledPackShape(packOp, destShape),
+ cast<ScalarTypeInterface>(packOp.getDestType().getElementType()));
auto shapeCastOp =
rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead);
@@ -1768,8 +1772,9 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
// Collapse the vector to the size required by result.
RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
stripMineTensorType, packMetadata.reassociations);
- mlir::VectorType vecCollapsedType =
- VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+ mlir::VectorType vecCollapsedType = VectorType::get(
+ collapsedType.getShape(),
+ cast<ScalarTypeInterface>(collapsedType.getElementType()));
vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
loc, vecCollapsedType, transposeOp->getResult(0));
@@ -2473,8 +2478,10 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter,
!VectorType::isValidElementType(dstElementType))
return failure();
- auto readType = VectorType::get(srcType.getShape(), srcElementType);
- auto writeType = VectorType::get(dstType.getShape(), dstElementType);
+ auto readType = VectorType::get(srcType.getShape(),
+ cast<ScalarTypeInterface>(srcElementType));
+ auto writeType = VectorType::get(dstType.getShape(),
+ cast<ScalarTypeInterface>(dstElementType));
Location loc = copyOp->getLoc();
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
@@ -2839,7 +2846,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
return failure();
}
}
- auto vecType = VectorType::get(vecShape, sourceType.getElementType());
+ auto vecType = VectorType::get(
+ vecShape, cast<ScalarTypeInterface>(sourceType.getElementType()));
// 3. Generate TransferReadOp + TransferWriteOp
ReifiedRankedShapedTypeDims reifiedSrcSizes;
@@ -2943,8 +2951,9 @@ struct PadOpVectorizationWithInsertSlicePattern
if (insertOp.getDest() == padOp.getResult())
return failure();
- auto vecType = VectorType::get(padOp.getType().getShape(),
- padOp.getType().getElementType());
+ auto vecType = VectorType::get(
+ padOp.getType().getShape(),
+ cast<ScalarTypeInterface>(padOp.getType().getElementType()));
unsigned vecRank = vecType.getRank();
unsigned tensorRank = insertOp.getType().getRank();
@@ -3366,9 +3375,12 @@ struct Conv1DGenerator
Type lhsEltType = lhsShapedType.getElementType();
Type rhsEltType = rhsShapedType.getElementType();
Type resEltType = resShapedType.getElementType();
- auto lhsType = VectorType::get(lhsShape, lhsEltType);
- auto rhsType = VectorType::get(rhsShape, rhsEltType);
- auto resType = VectorType::get(resShape, resEltType);
+ auto lhsType =
+ VectorType::get(lhsShape, cast<ScalarTypeInterface>(lhsEltType));
+ auto rhsType =
+ VectorType::get(rhsShape, cast<ScalarTypeInterface>(rhsEltType));
+ auto resType =
+ VectorType::get(resShape, cast<ScalarTypeInterface>(resEltType));
// Zero padding with the corresponding dimensions for lhs, rhs and res.
SmallVector<Value> lhsPadding(lhsShape.size(), zero);
SmallVector<Value> rhsPadding(rhsShape.size(), zero);
@@ -3595,13 +3607,14 @@ struct Conv1DGenerator
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
cSize},
- lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
+ cast<ScalarTypeInterface>(lhsEltType),
+ /*scalableDims=*/{false, false, scalableChDim});
VectorType rhsType =
- VectorType::get({kwSize, cSize}, rhsEltType,
+ VectorType::get({kwSize, cSize}, cast<ScalarTypeInterface>(rhsEltType),
/*scalableDims=*/{false, scalableChDim});
- VectorType resType =
- VectorType::get({nSize, wSize, cSize}, resEltType,
- /*scalableDims=*/{false, false, scalableChDim});
+ VectorType resType = VectorType::get(
+ {nSize, wSize, cSize}, cast<ScalarTypeInterface>(resEltType),
+ /*scalableDims=*/{false, false, scalableChDim});
// Masks the input xfer Op along the channel dim, iff the corresponding
// scalable flag is set.
@@ -3685,10 +3698,10 @@ struct Conv1DGenerator
// Note - the scalable flags are ignored as flattening combined with
// scalable vectorization is not supported.
SmallVector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
- auto lhsTypeAfterFlattening =
- VectorType::get(inOutFlattenSliceSizes, lhsEltType);
- auto resTypeAfterFlattening =
- VectorType::get(inOutFlattenSliceSizes, resEltType);
+ auto lhsTypeAfterFlattening = VectorType::get(
+ inOutFlattenSliceSizes, cast<ScalarTypeInterface>(lhsEltType));
+ auto resTypeAfterFlattening = VectorType::get(
+ inOutFlattenSliceSizes, cast<ScalarTypeInterface>(resEltType));
// Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
for (int64_t kw = 0; kw < kwSize; ++kw) {
@@ -3708,7 +3721,10 @@ struct Conv1DGenerator
if (flatten) {
// Un-flatten the output vector (restore the channel dimension)
resVals[w] = rewriter.create<vector::ShapeCastOp>(
- loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
+ loc,
+ VectorType::get(inOutSliceSizes,
+ cast<ScalarTypeInterface>(resEltType)),
+ resVals[w]);
}
}
}
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index a26e380232a91..bdfedabe23648 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -64,7 +64,8 @@ static std::optional<VectorShape> vectorShape(Value value) {
// Broadcasts scalar type into vector type (iff shape is non-scalar).
static Type broadcast(Type type, std::optional<VectorShape> shape) {
assert(!isa<VectorType>(type) && "must be scalar type");
- return shape ? VectorType::get(shape->sizes, type, shape->scalableFlags)
+ return shape ? VectorType::get(shape->sizes, cast<ScalarTypeInterface>(type),
+ shape->scalableFlags)
: type;
}
@@ -156,7 +157,8 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
// Stitch results together into one large vector.
Type resultEltType = cast<VectorType>(results[0].getType()).getElementType();
- Type resultExpandedType = VectorType::get(expandedShape, resultEltType);
+ Type resultExpandedType =
+ VectorType::get(expandedShape, cast<ScalarTypeInterface>(resultEltType));
Value result = builder.create<arith::ConstantOp>(
resultExpandedType, builder.getZeroAttr(resultExpandedType));
@@ -166,7 +168,8 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
// Reshape back to the original vector shape.
return builder.create<vector::ShapeCastOp>(
- VectorType::get(inputShape, resultEltType), result);
+ VectorType::get(inputShape, cast<ScalarTypeInterface>(resultEltType)),
+ result);
}
//----------------------------------------------------------------------------//
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 556922a64b093..b2b914cd66424 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -684,7 +684,8 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
Type elementType = getElementTypeOrSelf(memref.getType());
- auto vt = VectorType::get(vectorShape, elementType);
+ auto vt =
+ VectorType::get(vectorShape, cast<ScalarTypeInterface>(elementType));
Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
foreachIndividualVectorElement(
res,
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 7c0d369648651..7281e0da7f7f2 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -6,9 +6,9 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
-#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
@@ -34,7 +34,7 @@ double getMaxScale(Type expressedType) {
return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble();
}
-} // namespace
+} // namespace
unsigned QuantizedType::getFlags() const {
return static_cast<ImplType *>(impl)->flags;
@@ -146,7 +146,7 @@ Type QuantizedType::castFromStorageType(Type candidateType) {
if (llvm::isa<VectorType>(candidateType)) {
// i.e. tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
return VectorType::get(llvm::cast<VectorType>(candidateType).getShape(),
- getStorageType());
+ llvm::cast<ScalarTypeInterface>(getStorageType()));
}
return nullptr;
@@ -172,7 +172,8 @@ Type QuantizedType::castToStorageType(Type quantizedType) {
return UnrankedTensorType::get(storageType);
}
if (llvm::isa<VectorType>(quantizedType)) {
- return VectorType::get(sType.getShape(), storageType);
+ return VectorType::get(sType.getShape(),
+ llvm::cast<ScalarTypeInterface>(storageType));
}
}
@@ -200,7 +201,8 @@ Type QuantizedType::castFromExpressedType(Type candidateType) {
}
if (llvm::isa<VectorType>(candidateType)) {
// i.e. tensor<4xf32> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">>
- return VectorType::get(candidateShapedType.getShape(), *this);
+ return VectorType::get(candidateShapedType.getShape(),
+ llvm::cast<ScalarTypeInterface>(*this));
}
}
@@ -227,7 +229,8 @@ Type QuantizedType::castToExpressedType(Type quantizedType) {
return UnrankedTensorType::get(expressedType);
}
if (llvm::isa<VectorType>(quantizedType)) {
- return VectorType::get(sType.getShape(), expressedType);
+ return VectorType::get(sType.getShape(),
+ llvm::cast<ScalarTypeInterface>(expressedType));
}
}
diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index 62c7a7128d63a..7cd7bc8da8509 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -39,7 +39,8 @@ Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
if (dyn_cast<UnrankedTensorType>(inputType))
return UnrankedTensorType::get(elementalType);
if (auto vectorType = dyn_cast<VectorType>(inputType))
- return VectorType::get(vectorType.getShape(), elementalType);
+ return VectorType::get(vectorType.getShape(),
+ cast<ScalarTypeInterface>(elementalType));
// If the expressed types match, just use the new elemental type.
if (elementalType.getExpressedType() == expressedType)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
index d8dfe164458e2..0b0a309c02c3e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
@@ -42,7 +42,8 @@ static Type getUnaryOpResultType(Type operandType) {
Builder builder(operandType.getContext());
Type resultType = builder.getIntegerType(1);
if (auto vecType = llvm::dyn_cast<VectorType>(operandType))
- return VectorType::get(vecType.getNumElements(), resultType);
+ return VectorType::get(vecType.getNumElements(),
+ cast<ScalarTypeInterface>(resultType));
return resultType;
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index a60410d01ac57..77305de066c1a 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -366,7 +366,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
- return VectorType::get(type.getShape(), elementType);
+ return VectorType::get(type.getShape(),
+ cast<ScalarTypeInterface>(elementType));
}
if (type.getRank() <= 1 && type.getNumElements() == 1)
@@ -392,7 +393,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
auto elementType =
convertScalarType(targetEnv, options, scalarType, storageClass);
if (elementType)
- return VectorType::get(type.getShape(), elementType);
+ return VectorType::get(type.getShape(),
+ cast<ScalarTypeInterface>(elementType));
return nullptr;
}
@@ -417,7 +419,7 @@ convertComplexType(const spirv::TargetEnv &targetEnv,
return nullptr;
}
- return VectorType::get(2, elementType);
+ return VectorType::get(2, cast<ScalarTypeInterface>(elementType));
}
/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
@@ -770,8 +772,9 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
case spirv::BuiltIn::WorkgroupId:
case spirv::BuiltIn::LocalInvocationId:
case spirv::BuiltIn::GlobalInvocationId: {
- auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
- spirv::StorageClass::Input);
+ auto ptrType = spirv::PointerType::get(
+ VectorType::get({3}, cast<ScalarTypeInterface>(integerType)),
+ spirv::StorageClass::Input);
std::string name = getBuiltinVarName(builtin, prefix, suffix);
newVarOp =
builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
index 07cf26926a1df..e0337ae7e9162 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp
@@ -496,7 +496,8 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
Type vectorType = srcElemType;
if (!isa<VectorType>(srcElemType))
- vectorType = VectorType::get({ratio}, dstElemType);
+ vectorType =
+ VectorType::get({ratio}, cast<ScalarTypeInterface>(dstElemType));
// If both the source and destination are vector types, we need to make
// sure the scalar type is the same for composite construction later.
@@ -511,7 +512,8 @@ struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
// SPIR-V.
Type castType = srcElemVecType.getElementType();
if (count > 1)
- castType = VectorType::get({count}, castType);
+ castType =
+ VectorType::get({count}, cast<ScalarTypeInterface>(castType));
for (Value &c : components)
c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
diff --git a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
index b19495bc37445..9a416eb15ef81 100644
--- a/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
+++ b/mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp
@@ -118,7 +118,7 @@ Type VulkanLayoutUtils::decorateType(VectorType vectorType,
// times its scalar alignment."
size = elementSize * numElements;
alignment = numElements == 2 ? elementAlignment * 2 : elementAlignment * 4;
- return VectorType::get(numElements, memberType);
+ return VectorType::get(numElements, cast<ScalarTypeInterface>(memberType));
}
Type VulkanLayoutUtils::decorateType(spirv::ArrayType arrayType,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index b2eca539194a8..54e43089dc8e3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -57,7 +57,8 @@ static bool isInvariantArg(BlockArgument arg, Block *block) {
/// Constructs vector type for element type.
static VectorType vectorType(VL vl, Type etp) {
- return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
+ return VectorType::get(vl.vectorLength, cast<ScalarTypeInterface>(etp),
+ vl.enableVLAVectorization);
}
/// Constructs vector type from a memref value.
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 0258f797143cb..acd508a9b35d3 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -1236,7 +1236,8 @@ Type Merger::inferType(ExprId e, Value src) const {
// Inspect source type. For vector types, apply the same
// vectorization to the destination type.
if (auto vtp = dyn_cast<VectorType>(src.getType()))
- return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims());
+ return VectorType::get(vtp.getNumElements(), cast<ScalarTypeInterface>(dtp),
+ vtp.getScalableDims());
return dtp;
}
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index a7aa25eae2644..cc26463a84d53 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -179,7 +179,7 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
// Compose the final broadcasted type
if (resultCompositeKind == VectorType::getTypeID())
- return VectorType::get(resultShape, elementType);
+ return VectorType::get(resultShape, cast<ScalarTypeInterface>(elementType));
if (resultCompositeKind == RankedTensorType::getTypeID())
return RankedTensorType::get(resultShape, elementType);
return elementType;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 8e0e723cf4ed3..73fe27bf12e1f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2419,7 +2419,8 @@ Value BroadcastOp::createOrFoldBroadcastOp(
Location loc = value.getLoc();
Type elementType = getElementTypeOrSelf(value.getType());
VectorType srcVectorType = llvm::dyn_cast<VectorType>(value.getType());
- VectorType dstVectorType = VectorType::get(dstShape, elementType);
+ VectorType dstVectorType =
+ VectorType::get(dstShape, cast<ScalarTypeInterface>(elementType));
// Step 2. If scalar -> dstShape broadcast, just do it.
if (!srcVectorType) {
@@ -2481,7 +2482,8 @@ Value BroadcastOp::createOrFoldBroadcastOp(
.empty() &&
"unexpected \"dim-1\" broadcast");
- VectorType broadcastType = VectorType::get(broadcastShape, elementType);
+ VectorType broadcastType =
+ VectorType::get(broadcastShape, cast<ScalarTypeInterface>(elementType));
assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
vector::BroadcastableToResult::Success &&
"must be broadcastable");
@@ -5914,9 +5916,9 @@ void TypeCastOp::build(OpBuilder &builder, OperationState &result,
Value source) {
result.addOperands(source);
MemRefType memRefType = llvm::cast<MemRefType>(source.getType());
- VectorType vectorType =
- VectorType::get(extractShape(memRefType),
- getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
+ VectorType vectorType = VectorType::get(
+ extractShape(memRefType), cast<ScalarTypeInterface>(getElementTypeOrSelf(
+ getElementTypeOrSelf(memRefType))));
result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
memRefType.getMemorySpace()));
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index fec3c6c52e5e4..225df20e37faf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -112,9 +112,9 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
// %a = [%u, %v]
// ..
// %x = [%a,%b,%c,%d]
- VectorType resType =
- VectorType::get(dstType.getShape().drop_front(), eltType,
- dstType.getScalableDims().drop_front());
+ VectorType resType = VectorType::get(
+ dstType.getShape().drop_front(), cast<ScalarTypeInterface>(eltType),
+ dstType.getScalableDims().drop_front());
Value result = rewriter.create<ub::PoisonOp>(loc, dstType);
if (m == 0) {
// Stetch at start.
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index c6627b5ec0d77..c659bfc67a21b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1367,7 +1367,8 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
mul = rew.create<vector::ShapeCastOp>(
loc,
VectorType::get({lhsRows, rhsColumns},
- getElementTypeOrSelf(op.getAcc().getType())),
+ cast<ScalarTypeInterface>(
+ getElementTypeOrSelf(op.getAcc().getType()))),
mul);
// ACC must be C(m, n) or C(n, m).
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 3b38505becd18..e22a3c0f4dfc6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -107,7 +107,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
/// ```mlir
/// %subview = memref.subview %M (...)
/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32,
+/// strided<[3]>>
/// ```
/// ==>
/// ```mlir
@@ -200,7 +201,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
Location loc = op.getLoc();
Type elemTy = resultTy.getElementType();
// Vector type with a single element. Used to generate `vector.loads`.
- VectorType elemVecTy = VectorType::get({1}, elemTy);
+ VectorType elemVecTy =
+ VectorType::get({1}, cast<ScalarTypeInterface>(elemTy));
Value condMask = op.getMask();
Value base = op.getBase();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index e214257de2cdf..7953e91f65f4c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1192,7 +1192,8 @@ struct WarpOpExtractScalar : public WarpDistributionPattern {
return failure();
int64_t elementsPerLane =
extractSrcType.getShape()[0] / warpOp.getWarpSize();
- distributedVecType = VectorType::get({elementsPerLane}, elType);
+ distributedVecType =
+ VectorType::get({elementsPerLane}, cast<ScalarTypeInterface>(elType));
} else {
distributedVecType = extractSrcType;
}
@@ -1711,8 +1712,8 @@ struct WarpOpReduction : public WarpDistributionPattern {
// Return vector that will be reduced from the WarpExecuteOnLane0Op.
unsigned operandIndex = yieldOperand->getOperandNumber();
SmallVector<Value> yieldValues = {reductionOp.getVector()};
- SmallVector<Type> retTypes = {
- VectorType::get({numElements}, reductionOp.getType())};
+ SmallVector<Type> retTypes = {VectorType::get(
+ {numElements}, cast<ScalarTypeInterface>(reductionOp.getType()))};
if (reductionOp.getAcc()) {
yieldValues.push_back(reductionOp.getAcc());
retTypes.push_back(reductionOp.getAcc().getType());
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index cf6efaa04ae44..19424e7854e5b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -297,12 +297,14 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
emulatedElemTy.getIntOrFloatBitWidth();
auto newLoad = rewriter.create<vector::LoadOp>(
- loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
- getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+ loc,
+ VectorType::get(numContainerElemsToLoad,
+ cast<ScalarTypeInterface>(containerElemTy)),
+ base, getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
return rewriter.create<vector::BitCastOp>(
loc,
VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
- emulatedElemTy),
+ cast<ScalarTypeInterface>(emulatedElemTy)),
newLoad);
}
@@ -358,7 +360,8 @@ static void atomicRMW(OpBuilder &builder, Location loc,
// Load the original value from memory, and cast it to the original element
// type.
- auto oneElemVecType = VectorType::get({1}, origValue.getType());
+ auto oneElemVecType =
+ VectorType::get({1}, cast<ScalarTypeInterface>(origValue.getType()));
Value origVecValue = builder.create<vector::FromElementsOp>(
loc, oneElemVecType, ValueRange{origValue});
@@ -378,8 +381,9 @@ static void nonAtomicRMW(OpBuilder &builder, Location loc,
VectorValue valueToStore, Value mask) {
assert(valueToStore.getType().getRank() == 1 && "expected 1-D vector");
- auto oneElemVecType =
- VectorType::get({1}, linearizedMemref.getType().getElementType());
+ auto oneElemVecType = VectorType::get(
+ {1},
+ cast<ScalarTypeInterface>(linearizedMemref.getType().getElementType()));
Value origVecValue = builder.create<vector::LoadOp>(
loc, oneElemVecType, linearizedMemref, ValueRange{linearizedIndex});
origVecValue = builder.create<vector::BitCastOp>(loc, valueToStore.getType(),
@@ -559,7 +563,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// Basic case: storing full bytes.
auto numElements = origElements / emulatedPerContainerElem;
auto bitCast = rewriter.create<vector::BitCastOp>(
- loc, VectorType::get(numElements, containerElemTy),
+ loc,
+ VectorType::get(numElements,
+ cast<ScalarTypeInterface>(containerElemTy)),
op.getValueToStore());
rewriter.replaceOpWithNewOp<vector::StoreOp>(
op, bitCast.getResult(), memrefBase,
@@ -665,7 +671,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
auto storeType = VectorType::get(
{originType.getNumElements() / emulatedPerContainerElem},
- memrefElemType);
+ cast<ScalarTypeInterface>(memrefElemType));
auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
fullWidthStorePart);
rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
@@ -794,7 +800,8 @@ struct ConvertVectorMaskedStore final
auto numElements = (origElements + emulatedPerContainerElem - 1) /
emulatedPerContainerElem;
- auto newType = VectorType::get(numElements, containerElemTy);
+ auto newType = VectorType::get(numElements,
+ cast<ScalarTypeInterface>(containerElemTy));
auto passThru = rewriter.create<arith::ConstantOp>(
loc, newType, rewriter.getZeroAttr(newType));
@@ -803,7 +810,8 @@ struct ConvertVectorMaskedStore final
newMask.value()->getResult(0), passThru);
auto newBitCastType =
- VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
+ VectorType::get(numElements * emulatedPerContainerElem,
+ cast<ScalarTypeInterface>(emulatedElemTy));
Value valueToStore =
rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
valueToStore = rewriter.create<arith::SelectOp>(
@@ -1032,9 +1040,11 @@ struct ConvertVectorMaskedLoad final
auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
emulatedPerContainerElem);
- auto loadType = VectorType::get(numElements, containerElemTy);
+ auto loadType = VectorType::get(numElements,
+ cast<ScalarTypeInterface>(containerElemTy));
auto newBitcastType =
- VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
+ VectorType::get(numElements * emulatedPerContainerElem,
+ cast<ScalarTypeInterface>(emulatedElemTy));
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -1188,13 +1198,17 @@ struct ConvertVectorTransferRead final
emulatedPerContainerElem);
auto newRead = rewriter.create<vector::TransferReadOp>(
- loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
+ loc,
+ VectorType::get(numElements,
+ cast<ScalarTypeInterface>(containerElemTy)),
+ adaptor.getSource(),
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
newPadding);
auto bitCast = rewriter.create<vector::BitCastOp>(
loc,
- VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
+ VectorType::get(numElements * emulatedPerContainerElem,
+ cast<ScalarTypeInterface>(emulatedElemTy)),
newRead);
Value result = bitCast->getResult(0);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index dc46ed17a374d..1339e3f49eab2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -607,7 +607,8 @@ struct BubbleDownVectorBitCastForExtract
Location loc = extractOp.getLoc();
Value packedValue = rewriter.create<vector::ExtractOp>(
loc, castOp.getSource(), index / expandRatio);
- Type packedVecType = VectorType::get(/*shape=*/{1}, packedValue.getType());
+ Type packedVecType = VectorType::get(
+ /*shape=*/{1}, cast<ScalarTypeInterface>(packedValue.getType()));
Value zero = rewriter.create<arith::ConstantOp>(
loc, packedVecType, rewriter.getZeroAttr(packedVecType));
packedValue = rewriter.create<vector::InsertOp>(loc, packedValue, zero,
@@ -1059,7 +1060,7 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
// If we can assume all indices fit in 32-bit, we perform the vector
// comparison in 32-bit to get a higher degree of SIMD parallelism.
// Otherwise we perform the vector comparison using 64-bit indices.
- Type idxType =
+ ScalarTypeInterface idxType =
force32BitVectorIndices ? rewriter.getI32Type() : rewriter.getI64Type();
DenseIntElementsAttr indicesAttr;
if (dim == 0 && force32BitVectorIndices) {
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 7b56cd0cf0e91..fa0ac4e47bac9 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -337,7 +337,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
auto sourceShape = sourceShapedType.getShape();
assert(sourceShape.size() == readShape.size() && "expected same ranks.");
auto maskType = VectorType::get(readShape, builder.getI1Type());
- auto vectorType = VectorType::get(readShape, padValue.getType());
+ auto vectorType =
+ VectorType::get(readShape, cast<ScalarTypeInterface>(padValue.getType()));
assert(padValue.getType() == sourceShapedType.getElementType() &&
"expected same pad element type to match source element type");
int64_t readRank = readShape.size();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 78c242571935c..31ccb14de0cf0 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -368,8 +368,9 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
"tensor descriptor shape is not distributable");
if (chunkSize > 1)
return VectorType::get({chunkSize / wiDataSize, wiDataSize},
- getElementType());
- return VectorType::get({wiDataSize}, getElementType());
+ llvm::cast<ScalarTypeInterface>(getElementType()));
+ return VectorType::get({wiDataSize},
+ llvm::cast<ScalarTypeInterface>(getElementType()));
}
// Case 2: block loads/stores
@@ -393,7 +394,7 @@ FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
tensorSize *= getArrayLength();
return VectorType::get({tensorSize / (sgSize * wiDataSize), wiDataSize},
- getElementType());
+ llvm::cast<ScalarTypeInterface>(getElementType()));
}
} // namespace xegpu
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 3924d082f0628..2c8d75aaf2594 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -211,11 +211,12 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
//===----------------------------------------------------------------------===//
bool VectorType::isValidElementType(Type t) {
- return isValidVectorTypeElementType(t);
+ return llvm::isa<ScalarTypeInterface>(t);
}
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
- ArrayRef<int64_t> shape, Type elementType,
+ ArrayRef<int64_t> shape,
+ ScalarTypeInterface elementType,
ArrayRef<bool> scalableDims) {
if (!isValidElementType(elementType))
return emitError()
@@ -248,7 +249,8 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
- return VectorType::get(shape.value_or(getShape()), elementType,
+ return VectorType::get(shape.value_or(getShape()),
+ llvm::cast<ScalarTypeInterface>(elementType),
getScalableDims());
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index a07189ae1323c..54c540b28fdbd 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -139,8 +139,8 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder,
if (iface.isConvertibleInstruction(inst->getOpcode()))
return iface.convertInstruction(odsBuilder, inst, llvmOperands,
moduleImport);
- // TODO: Implement the `convertInstruction` hooks in the
- // `LLVMDialectLLVMIRImportInterface` and move the following include there.
+ // TODO: Implement the `convertInstruction` hooks in the
+ // `LLVMDialectLLVMIRImportInterface` and move the following include there.
#include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc"
return failure();
}
@@ -813,7 +813,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
SmallVector<int64_t> shape(arrayShape);
shape.push_back(numElements.getKnownMinValue());
- return VectorType::get(shape, elementType);
+ return VectorType::get(shape, cast<ScalarTypeInterface>(elementType));
}
Type ModuleImport::getBuiltinTypeForAttr(Type type) {
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 273817d53d308..6b2726970e94e 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -882,7 +882,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
<< operands[1];
}
- typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
+ typeMap[operands[0]] =
+ VectorType::get({operands[2]}, cast<ScalarTypeInterface>(elementTy));
} break;
case spirv::Opcode::OpTypePointer: {
return processOpTypePointer(operands);
diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 51612446d2e6a..5be76d7fd3878 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -115,7 +115,7 @@ func.func @illegaltype(i21312312323120) // expected-error {{invalid integer widt
// -----
// Test no nested vector.
-// expected-error at +1 {{failed to verify 'elementType': integer or index or floating-point}}
+// expected-error at +1 {{failed to verify 'elementType': vector type requires scalar element type}}
func.func @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>)
// -----
diff --git a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp
index 1e45ab57ebcc7..67eb832c471d3 100644
--- a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp
+++ b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp
@@ -48,7 +48,8 @@ static std::pair<unsigned, VectorType> legalizeVectorType(const Type &type) {
const unsigned lmul = eltCount * sew / 64;
unsigned n = lmul > 8 ? llvm::Log2_32(lmul) - 2 : 1;
- return {n, VectorType::get({eltCount >> (n - 1)}, eltTy, {true})};
+ return {n, VectorType::get({eltCount >> (n - 1)},
+ cast<ScalarTypeInterface>(eltTy), {true})};
}
/// Replace math.cos(v) operation with vcix.v.iv(v).
diff --git a/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp b/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp
index 305f879489813..98fb71d9355ee 100644
--- a/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp
+++ b/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp
@@ -31,7 +31,7 @@ TEST_F(ArmSMETest, TestTileTypeConversion) {
populateArmSMEToLLVMConversionPatterns(llvmConverterWithArmSMEConversion,
patterns);
- Type i32 = IntegerType::get(&context, 32);
+ auto i32 = IntegerType::get(&context, 32);
auto smeTileType = VectorType::get({4, 4}, i32, {true, true});
// An unmodified LLVMTypeConverer should fail to convert an ArmSME tile type.
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
index bc4066ed210e8..abb33f5bedea1 100644
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -110,10 +110,10 @@ TEST(ShapedTypeTest, CloneTensor) {
TEST(ShapedTypeTest, CloneVector) {
MLIRContext context;
- Type i32 = IntegerType::get(&context, 32);
- Type f32 = Float32Type::get(&context);
+ auto i32 = IntegerType::get(&context, 32);
+ auto f32 = Float32Type::get(&context);
- Type vectorOriginalType = i32;
+ auto vectorOriginalType = i32;
llvm::SmallVector<int64_t> vectorOriginalShape({10, 20});
ShapedType vectorType =
VectorType::get(vectorOriginalShape, vectorOriginalType);
@@ -123,7 +123,7 @@ TEST(ShapedTypeTest, CloneVector) {
ASSERT_EQ(vectorType.clone(vectorNewShape),
VectorType::get(vectorNewShape, vectorOriginalType));
// Update type.
- Type vectorNewType = f32;
+ auto vectorNewType = f32;
ASSERT_NE(vectorOriginalType, vectorNewType);
ASSERT_EQ(vectorType.clone(vectorNewType),
VectorType::get(vectorOriginalShape, vectorNewType));
@@ -134,7 +134,7 @@ TEST(ShapedTypeTest, CloneVector) {
TEST(ShapedTypeTest, VectorTypeBuilder) {
MLIRContext context;
- Type f32 = Float32Type::get(&context);
+ auto f32 = Float32Type::get(&context);
SmallVector<int64_t> shape{2, 4, 8, 9, 1};
SmallVector<bool> scalableDims{true, false, true, false, false};
More information about the Mlir-commits
mailing list