[Mlir-commits] [mlir] [mlir][IR] Add `ScalarTypeInterface` and use as `VectorType` element type (PR #132400)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 21 06:51:59 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit adds a new builtin type interface: `ScalarTypeInterface`.
Instead of maintaining a list of valid element types for `VectorType`, restrict valid element types to `ScalarTypeInterface`.
---
Patch is 77.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132400.diff
54 Files Affected:
- (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+1-1)
- (modified) mlir/include/mlir/IR/BuiltinDialectBytecode.td (+3-1)
- (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+34-1)
- (modified) mlir/include/mlir/IR/BuiltinTypes.h (+3-3)
- (modified) mlir/include/mlir/IR/BuiltinTypes.td (+31-8)
- (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+4)
- (modified) mlir/lib/AsmParser/TypeParser.cpp (+8-1)
- (modified) mlir/lib/CAPI/IR/BuiltinTypes.cpp (+7-6)
- (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+5-4)
- (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+4-2)
- (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+4-2)
- (modified) mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp (+2-1)
- (modified) mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp (+2-1)
- (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+3-2)
- (modified) mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp (+3-2)
- (modified) mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp (+2-1)
- (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+3-2)
- (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+4-1)
- (modified) mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp (+1-1)
- (modified) mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp (+6-4)
- (modified) mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp (+2-1)
- (modified) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp (+4-4)
- (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+1-1)
- (modified) mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp (+9-6)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+8-6)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+5-3)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+43-27)
- (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+6-3)
- (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+2-1)
- (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+9-6)
- (modified) mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp (+2-1)
- (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp (+2-1)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+8-5)
- (modified) mlir/lib/Dialect/SPIRV/Transforms/UnifyAliasedResourcePass.cpp (+4-2)
- (modified) mlir/lib/Dialect/SPIRV/Utils/LayoutUtils.cpp (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+2-1)
- (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-1)
- (modified) mlir/lib/Dialect/Traits.cpp (+1-1)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7-5)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+3-3)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp (+2-1)
- (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+4-2)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+4-3)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+28-14)
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+3-2)
- (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+2-1)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+4-3)
- (modified) mlir/lib/IR/BuiltinTypes.cpp (+5-3)
- (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+3-3)
- (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+2-1)
- (modified) mlir/test/IR/invalid-builtin-types.mlir (+1-1)
- (modified) mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp (+2-1)
- (modified) mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp (+1-1)
- (modified) mlir/unittests/IR/ShapedTypeTest.cpp (+5-5)
``````````diff
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 1f40eb6fc693c..50b419bce78e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -68,7 +68,7 @@ scf::ForOp createLoopOverTileSlices(
bool isMultipleOfSMETileVectorType(VectorType vType);
/// Creates a vector type for the SME tile of `elementType`.
-VectorType getSMETileTypeForElement(Type elementType);
+VectorType getSMETileTypeForElement(ScalarTypeInterface elementType);
/// Erase trivially dead tile ops from a function.
void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
index 0208e8cdbf293..7e79a17119c5a 100644
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -295,6 +295,8 @@ def VectorType : DialectType<(type
Type:$elementType
)> {
let printerPredicate = "!$_val.isScalable()";
+ // Note: Element type must implement ScalarTypeInterface.
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType))";
}
def VectorTypeWithScalableDims : DialectType<(type
@@ -304,7 +306,7 @@ def VectorTypeWithScalableDims : DialectType<(type
)> {
let printerPredicate = "$_val.isScalable()";
// Note: order of serialization does not match order of builder.
- let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)";
+ let cBuilder = "get<$_resultType>(context, shape, llvm::cast<ScalarTypeInterface>(elementType), scalableDims)";
}
}
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..71bd4df762d2c 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,7 +16,40 @@
include "mlir/IR/OpBase.td"
-def FloatTypeInterface : TypeInterface<"FloatType"> {
+//===----------------------------------------------------------------------===//
+// ScalarTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ScalarTypeInterface : TypeInterface<"ScalarTypeInterface"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ Indicates that this type is a scalar type.
+
+ The bitwidth of a scalar type is a fixed constant but may be unknown in the
+ absence of data layout information.
+
+ Scalar types are POD (plain-old-data) entities that have an in-memory
+ representation: scalar values can be loaded/store from/to memory, so
+ abstract types like function types or async tokens cannot be scalar types.
+
+ Scalar types should be limited to types that can lower to something that
+ egress dialects would consider a valid vector element type.
+ }];
+
+ let methods = [
+ InterfaceMethod<[{
+ Return the bitwidth of this type, if it has an inherent bitwidth. I.e., a
+ bitwidth that is known in the absence of data layout information.
+ }],
+ "std::optional<uint64_t>", "getInherentBitwidth", (ins)>
+ ];
+}
+
+//===----------------------------------------------------------------------===//
+// FloatTypeInterface
+//===----------------------------------------------------------------------===//
+
+def FloatTypeInterface : TypeInterface<"FloatType", [ScalarTypeInterface]> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index df1e02732617d..a1950cda6318a 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -275,7 +275,7 @@ class VectorType::Builder {
scalableDims(other.getScalableDims()) {}
/// Build from scratch.
- Builder(ArrayRef<int64_t> shape, Type elementType,
+ Builder(ArrayRef<int64_t> shape, ScalarTypeInterface elementType,
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}
@@ -286,7 +286,7 @@ class VectorType::Builder {
return *this;
}
- Builder &setElementType(Type newElementType) {
+ Builder &setElementType(ScalarTypeInterface newElementType) {
elementType = newElementType;
return *this;
}
@@ -312,7 +312,7 @@ class VectorType::Builder {
}
private:
- Type elementType;
+ ScalarTypeInterface elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
};
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index af474b3e3ec47..2f03d51913855 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -85,6 +85,14 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
+
+ let extraClassDeclaration = [{
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
+ }];
}
// Float types that are cached in MLIRContext.
@@ -93,6 +101,12 @@ class Builtin_CachedFloatType<string name, string mnemonic,
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
+
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() {
+ return static_cast<uint64_t>(::llvm::cast<FloatType>(*this).getWidth());
+ }
}];
}
@@ -447,7 +461,7 @@ def Builtin_Function : Builtin_Type<"Function", "function"> {
// IndexType
//===----------------------------------------------------------------------===//
-def Builtin_Index : Builtin_Type<"Index", "index"> {
+def Builtin_Index : Builtin_Type<"Index", "index", [ScalarTypeInterface]> {
let summary = "Integer-like type with unknown platform-dependent bit width";
let description = [{
Syntax:
@@ -467,6 +481,12 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
let extraClassDeclaration = [{
static IndexType get(MLIRContext *context);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return std::nullopt;
+ }
+
/// Storage bit width used for IndexType by internal compiler data
/// structures.
static constexpr unsigned kInternalStorageBitWidth = 64;
@@ -477,7 +497,8 @@ def Builtin_Index : Builtin_Type<"Index", "index"> {
// IntegerType
//===----------------------------------------------------------------------===//
-def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
+def Builtin_Integer
+ : Builtin_Type<"Integer", "integer", [ScalarTypeInterface]> {
let summary = "Integer type with arbitrary precision up to a fixed limit";
let description = [{
Syntax:
@@ -531,6 +552,12 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
/// Return null if the scaled element type cannot be represented.
IntegerType scaleElementBitwidth(unsigned scale);
+ /// Return the bitwidth of this type. This is an interface method of
+ /// ScalarTypeInterface.
+ std::optional<uint64_t> getInherentBitwidth() const {
+ return static_cast<uint64_t>(getWidth());
+ }
+
/// Integer representation maximal bitwidth.
/// Note: This is aligned with the maximum width of llvm::IntegerType.
static constexpr unsigned kMaxWidth = (1 << 24) - 1;
@@ -1249,10 +1276,6 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//
-def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
- let cppFunctionName = "isValidVectorTypeElementType";
-}
-
def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
@@ -1303,12 +1326,12 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
- Builtin_VectorTypeElementType:$elementType,
+ AnyScalarType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
TypeBuilderWithInferredContext<(ins
- "ArrayRef<int64_t>":$shape, "Type":$elementType,
+ "ArrayRef<int64_t>":$shape, "ScalarTypeInterface":$elementType,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
), [{
// While `scalableDims` is optional, its default value should be
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 601517717978e..709c7dc213ff4 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -203,6 +203,10 @@ class ConfinedType<Type type, list<Pred> predicates, string summary = "",
list<Pred> predicateList = predicates;
}
+def AnyScalarType : Type<
+ CPred<"::llvm::isa<::mlir::ScalarTypeInterface>($_self)">,
+ "scalar type", "::mlir::ScalarTypeInterface">;
+
// Integer types.
// Any integer type irrespective of its width and signedness semantics.
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 21bb0ec3d0d51..1ccd16e1b3abd 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -491,7 +491,14 @@ VectorType Parser::parseVectorType() {
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
- return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
+ auto scalarElementType = dyn_cast<ScalarTypeInterface>(elementType);
+ if (!scalarElementType) {
+ emitWrongTokenError("vector type requires scalar element type");
+ return nullptr;
+ }
+
+ return getChecked<VectorType>(loc, dimensions, scalarElementType,
+ scalableDims);
}
/// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
index a080adf0f8103..80e8c239689bb 100644
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -364,21 +364,22 @@ bool mlirTypeIsAVector(MlirType type) {
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
MlirType elementType) {
return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
const int64_t *shape, MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType)));
+ cast<ScalarTypeInterface>(unwrap(elementType))));
}
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
const bool *scalable, MlirType elementType) {
- return wrap(VectorType::get(
- llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
- llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
+ return wrap(
+ VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
+ llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
@@ -387,7 +388,7 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
MlirType elementType) {
return wrap(VectorType::getChecked(
unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
- unwrap(elementType),
+ cast<ScalarTypeInterface>(unwrap(elementType)),
llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
}
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 949424db7c4d6..bedebabc49087 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -283,7 +283,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
// Get the type size in bytes.
DataLayout dataLayout = DataLayout::closest(gpuOp);
@@ -560,7 +560,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
int64_t numBits =
vectorType.getNumElements() * elemType.getIntOrFloatBitWidth();
- Type i32 = rewriter.getI32Type();
+ auto i32 = rewriter.getI32Type();
Type intrinsicInType = numBits <= 32
? (Type)rewriter.getIntegerType(numBits)
: (Type)VectorType::get(numBits / 32, i32);
@@ -1099,8 +1099,9 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern<DPPOp> {
operand =
rewriter.create<LLVM::BitcastOp>(loc, llvmSrcIntType, operand);
}
- auto llvmVecType = typeConverter->convertType(mlir::VectorType::get(
- 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType));
+ auto llvmVecType = typeConverter->convertType(
+ mlir::VectorType::get(32 / operandType.getIntOrFloatBitWidth(),
+ cast<ScalarTypeInterface>(llvmSrcIntType)));
Value undefVec = rewriter.create<LLVM::UndefOp>(loc, llvmVecType);
operand = rewriter.create<LLVM::InsertElementOp>(
loc, undefVec, operand, createI32Constant(rewriter, loc, 0));
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 27be54728c1a1..d17a610e2ac2a 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -250,7 +250,8 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
auto inVectorTy = dyn_cast<VectorType>(in.getType());
- VectorType truncResType = VectorType::get(4, outElemType);
+ VectorType truncResType =
+ VectorType::get(4, cast<ScalarTypeInterface>(outElemType));
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
@@ -331,7 +332,8 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- VectorType truncResType = VectorType::get(2, outElemType);
+ VectorType truncResType =
+ VectorType::get(2, cast<ScalarTypeInterface>(outElemType));
auto inVectorTy = dyn_cast<VectorType>(in.getType());
// Handle the case where input type is not a vector type
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 9c4dfa27b1447..13ff632c18b40 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -322,7 +322,8 @@ struct ConstantCompositeOpPattern final
dstAttrType =
RankedTensorType::get(dstAttrType.getShape(), dstElemType);
else
- dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
+ dstAttrType = VectorType::get(dstAttrType.getShape(),
+ cast<ScalarTypeInterface>(dstElemType));
dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
}
@@ -908,7 +909,8 @@ class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
// cases. Extend them to 32-bit and do comparision then.
Type type = rewriter.getI32Type();
if (auto vectorType = dyn_cast<VectorType>(dstType))
- type = VectorType::get(vectorType.getShape(), type);
+ type = VectorType::get(vectorType.getShape(),
+ cast<ScalarTypeInterface>(type));
Value extLhs =
rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
Value extRhs =
diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
index 9c6de938a7108..d984ab5d932b4 100644
--- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
+++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp
@@ -33,7 +33,8 @@ class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
/// arm.neon.intr.sdot
LogicalResult matchAndRewrite(Sdot2dOp op,
PatternRewriter &rewriter) const override {
- Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
+ ScalarTypeInterface elemType =
+ cast<VectorType>(op.getB().getType()).getElementType();
int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
Sdot2dOp::kReductionSize;
VectorType flattenedVectorType = VectorType::get({length}, elemType);
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
index 4bd94bcebf290..6a04bd39f2d8e 100644
--- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -303,7 +303,8 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
auto floatType = cast<FloatType>(getElementTypeOrSelf(lhs.getType()));
Type i1Type = builder.getI1Type();
if (auto vecType = dyn_cast<VectorType>(lhs.getType()))
- i1Type = VectorType::get(vecType.getShape(), i1Type);
+ i1Type =
+ VectorType::get(vecType.getShape(), cast<ScalarTypeInterface>(i1Type));
Value cmp = builder.create<LLVM::FCmpOp>(
loc, i1Type, isMin ? LLVM::FCmpPredicate::olt : LLVM::FCmpPredicate::ogt,
lhs, rhs);
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index ea251e4564ea8..d170a1f01dada 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -638,8 +638,9 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
if (!elementType)
return {};
if (type.getShape().empty())
- return VectorType::get({1}, elementType);
- Type vectorType = VectorType::get(type.getShape().back(), elementType,
+ return VectorType::get({1}, cast<ScalarTypeInterface>(elementType));
+ Type vectorType = VectorType::get(type.getShape().back(),
+ cast<ScalarTypeInterface>(elementType),
type.getScalableDims().back());
assert(LLVM::isCompatibleVectorType(vectorType) &&
"expected vector type compatible with the LLVM dialect");
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index 1b83794b5f450..6676477b9e34b 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -154,7 +154,7 @@ struct CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
if (auto vectorType = dyn_cast<VectorType>(type)) {
assert(vectorType.getRank() == 1);
int count = vectorType.getNumElements();
- intType = VectorType::get(count, intType);
+ intType = VectorType::get(count, cast<ScalarTypeInterface>(intType));
SmallVector<Value> signSplat(count, signMask);
signMask =
@@ -380,7 +380,8 @@ struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
auto operandType = adaptor.getRhs().getType();
if (auto vector...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/132400
More information about the Mlir-commits
mailing list