[llvm-branch-commits] [mlir] [mlir][LLVM] Delete `getFixedVectorType` and `getScalableVectorType` (PR #135051)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Apr 9 10:07:53 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
The LLVM dialect no longer has its own vector types; it uses `mlir::VectorType` everywhere. This commit removes `LLVM::getFixedVectorType` and `LLVM::getScalableVectorType` and uses `VectorType::get` instead. It addresses a [comment](https://github.com/llvm/llvm-project/pull/133286#discussion_r2022192500) on the PR that deleted the LLVM vector types.
Depends on #134981.
---
Full diff: https://github.com/llvm/llvm-project/pull/135051.diff
7 Files Affected:
- (modified) mlir/docs/Dialects/LLVM.md (-4)
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h (-8)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+16-17)
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (-12)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+14-9)
- (modified) mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp (+10-14)
- (modified) mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp (+5-4)
``````````diff
diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md
index 468f69c419071..4b5d518ca4eab 100644
--- a/mlir/docs/Dialects/LLVM.md
+++ b/mlir/docs/Dialects/LLVM.md
@@ -336,10 +336,6 @@ compatible with the LLVM dialect:
vector type compatible with the LLVM dialect;
- `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
of elements in any vector type compatible with the LLVM dialect;
-- `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type
- with the given element type and size; the resulting type is either a
- built-in or an LLVM dialect vector type depending on which one supports the
- given element type.
#### Examples of Compatible Vector Types
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index a2a76c49a2bda..17561f79d135a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -126,14 +126,6 @@ Type getVectorType(Type elementType, unsigned numElements,
/// and length.
Type getVectorType(Type elementType, const llvm::ElementCount &numElements);
-/// Creates an LLVM dialect-compatible type with the given element type and
-/// length.
-Type getFixedVectorType(Type elementType, unsigned numElements);
-
-/// Creates an LLVM dialect-compatible type with the given element type and
-/// length.
-Type getScalableVectorType(Type elementType, unsigned numElements);
-
/// Returns the size of the given primitive LLVM dialect-compatible type
/// (including vectors) in bits, for example, the size of i16 is 16 and
/// the size of vector<4xi16> is 64. Returns 0 for non-primitive
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 51507c6507b69..69fa62c8196e4 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -61,13 +61,13 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
static Type inferIntrinsicResultType(Type vectorResultType) {
MLIRContext *ctx = vectorResultType.getContext();
auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
- auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
+ auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
auto i32Ty = IntegerType::get(ctx, 32);
- auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+ auto i32x2Ty = VectorType::get(2, i32Ty);
Type f64Ty = Float64Type::get(ctx);
- Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+ Type f64x2Ty = VectorType::get(2, f64Ty);
Type f32Ty = Float32Type::get(ctx);
- Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
+ Type f32x2Ty = VectorType::get(2, f32Ty);
if (a.getElementType() == f16x2Ty) {
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
@@ -85,7 +85,7 @@ static Type inferIntrinsicResultType(Type vectorResultType) {
ctx,
SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
}
- if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
+ if (a.getElementType() == VectorType::get(1, f32Ty)) {
return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
}
@@ -106,11 +106,11 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
Type i32Ty = rewriter.getI32Type();
Type f32Ty = rewriter.getF32Type();
Type f64Ty = rewriter.getF64Type();
- Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
- Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
- Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
- Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
- Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
+ Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
+ Type i32x2Ty = VectorType::get(2, i32Ty);
+ Type f64x2Ty = VectorType::get(2, f64Ty);
+ Type f32x2Ty = VectorType::get(2, f32Ty);
+ Type f32x1Ty = VectorType::get(1, f32Ty);
auto makeConst = [&](int32_t index) -> Value {
return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
@@ -181,9 +181,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
Type f64Ty = b.getF64Type();
Type f32Ty = b.getF32Type();
Type i64Ty = b.getI64Type();
- Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
- Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
- Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
+ Type i8x4Ty = VectorType::get(4, b.getI8Type());
+ Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
+ Type f32x1Ty = VectorType::get(1, f32Ty);
auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
@@ -268,8 +268,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
if (!vectorResultType) {
return failure();
}
- Type innerVectorType = LLVM::getFixedVectorType(
- vectorResultType.getElementType(), vectorResultType.getDimSize(1));
+ Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
+ vectorResultType.getElementType());
int64_t num32BitRegs = vectorResultType.getDimSize(0);
@@ -627,8 +627,7 @@ struct NVGPUMmaSparseSyncLowering
// Bitcast the sparse metadata from vector<2xf16> to an i32.
Value sparseMetadata = adaptor.getSparseMetadata();
- if (sparseMetadata.getType() !=
- LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
+ if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
return op->emitOpError() << "Expected metadata type to be LLVM "
"VectorType of 2 i16 elements";
sparseMetadata =
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index b3c2a29309528..29cf38c1fefea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -851,18 +851,6 @@ Type mlir::LLVM::getVectorType(Type elementType,
/*isScalable=*/false);
}
-Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
- assert(VectorType::isValidElementType(elementType) &&
- "incompatible element type");
- return VectorType::get(numElements, elementType);
-}
-
-Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
- // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
- // scalable/non-scalable.
- return VectorType::get(numElements, elementType, /*scalableDims=*/true);
-}
-
llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
assert(isCompatibleType(type) &&
"expected a type compatible with the LLVM dialect");
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 09bff6101edd3..b9d6952f67671 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -144,7 +144,7 @@ LogicalResult BulkStoreOp::verify() {
std::optional<mlir::NVVM::MMATypes>
MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
auto half2Type =
- LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
+ VectorType::get(2, Float16Type::get(operandElType.getContext()));
if (operandElType.isF64())
return NVVM::MMATypes::f64;
if (operandElType.isF16() || operandElType == half2Type)
@@ -243,7 +243,8 @@ void MmaOp::print(OpAsmPrinter &p) {
p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
// Print the types of the operands and result.
- p << " : " << "(";
+ p << " : "
+ << "(";
llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
frags[1].regs[0].getType(),
frags[2].regs[0].getType()},
@@ -404,7 +405,7 @@ LogicalResult MmaOp::verify() {
MLIRContext *context = getContext();
auto f16Ty = Float16Type::get(context);
auto i32Ty = IntegerType::get(context, 32);
- auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
+ auto f16x2Ty = VectorType::get(2, f16Ty);
auto f32Ty = Float32Type::get(context);
auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
@@ -506,7 +507,7 @@ LogicalResult MmaOp::verify() {
expectedA.emplace_back(1, f64Ty);
expectedB.emplace_back(1, f64Ty);
expectedC.emplace_back(2, f64Ty);
- // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
+ // expectedC.emplace_back(1, VectorType::get(2, f64Ty));
expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
context, SmallVector<Type>(2, f64Ty)));
allowedShapes.push_back({8, 8, 4});
@@ -992,7 +993,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
ss << "},";
// Need to map read/write registers correctly.
regCnt = (regCnt * 2);
- ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
+ ss << " $" << (regCnt) << ","
+ << " $" << (regCnt + 1) << ","
+ << " p";
if (getTypeD() != WGMMATypes::s32) {
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
}
@@ -1219,7 +1222,7 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
: CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col) \
- [&]() -> auto { \
+ [&]() -> auto{ \
switch (dims) { \
case 1: \
return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile); \
@@ -1234,7 +1237,8 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
default: \
llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp."); \
} \
- }()
+ } \
+ ()
llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
@@ -1364,13 +1368,14 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
: TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
- [&]() -> auto { \
+ [&]() -> auto{ \
if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32) \
return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64) \
return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
- }()
+ } \
+ ()
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
auto curOp = cast<NVVM::Tcgen05CpOp>(op);
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 39cca7d363e0d..e80360aa08ed5 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -103,16 +103,15 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
Type elType = type.vectorType.getElementType();
if (elType.isF16()) {
- return FragmentElementInfo{
- LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
- inferNumRegistersPerMatrixFragment(type)};
+ return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2, 32,
+ inferNumRegistersPerMatrixFragment(type)};
}
// f64 operand
Type f64Ty = Float64Type::get(ctx);
if (elType.isF64()) {
return isAccum
- ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
+ ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
inferNumRegistersPerMatrixFragment(type)}
: FragmentElementInfo{f64Ty, 1, 64,
inferNumRegistersPerMatrixFragment(type)};
@@ -120,30 +119,27 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
// int8 operand
if (elType.isInteger(8)) {
- return FragmentElementInfo{
- LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
- inferNumRegistersPerMatrixFragment(type)};
+ return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4,
+ 32, inferNumRegistersPerMatrixFragment(type)};
}
// int4 operand
if (elType.isInteger(4)) {
- return FragmentElementInfo{
- LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
- inferNumRegistersPerMatrixFragment(type)};
+ return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8,
+ 32, inferNumRegistersPerMatrixFragment(type)};
}
// Integer 32bit acc operands
if (elType.isInteger(32)) {
- return FragmentElementInfo{
- LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
- inferNumRegistersPerMatrixFragment(type)};
+ return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2,
+ 64, inferNumRegistersPerMatrixFragment(type)};
}
// Floating point 32bit operands
if (elType.isF32()) {
Type f32Ty = Float32Type::get(ctx);
return isAccum
- ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
+ ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
inferNumRegistersPerMatrixFragment(type)}
: FragmentElementInfo{f32Ty, 1, 32,
inferNumRegistersPerMatrixFragment(type)};
diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
index bc9765fff2953..c46aa3e80d51a 100644
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -124,14 +124,15 @@ class TypeFromLLVMIRTranslatorImpl {
/// Translates the given fixed-vector type.
Type translate(llvm::FixedVectorType *type) {
- return LLVM::getFixedVectorType(translateType(type->getElementType()),
- type->getNumElements());
+ return VectorType::get(type->getNumElements(),
+ translateType(type->getElementType()));
}
/// Translates the given scalable-vector type.
Type translate(llvm::ScalableVectorType *type) {
- return LLVM::getScalableVectorType(translateType(type->getElementType()),
- type->getMinNumElements());
+ return VectorType::get(type->getMinNumElements(),
+ translateType(type->getElementType()),
+ /*scalable=*/true);
}
/// Translates the given target extension type.
``````````
</details>
https://github.com/llvm/llvm-project/pull/135051
More information about the llvm-branch-commits
mailing list