[llvm-branch-commits] [mlir] [mlir][LLVM] Delete `getFixedVectorType` and `getScalableVectorType` (PR #135051)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Apr 9 10:07:53 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

The LLVM dialect no longer has its own vector types. It uses `mlir::VectorType` everywhere. Remove `LLVM::getFixedVectorType` and `LLVM::getScalableVectorType` and use `VectorType::get` instead. This commit addresses a [comment](https://github.com/llvm/llvm-project/pull/133286#discussion_r2022192500) on the PR that deleted the LLVM vector types.

Depends on #134981.



---
Full diff: https://github.com/llvm/llvm-project/pull/135051.diff


7 Files Affected:

- (modified) mlir/docs/Dialects/LLVM.md (-4) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h (-8) 
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+16-17) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (-12) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+14-9) 
- (modified) mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp (+10-14) 
- (modified) mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp (+5-4) 


``````````diff
diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md
index 468f69c419071..4b5d518ca4eab 100644
--- a/mlir/docs/Dialects/LLVM.md
+++ b/mlir/docs/Dialects/LLVM.md
@@ -336,10 +336,6 @@ compatible with the LLVM dialect:
     vector type compatible with the LLVM dialect;
 -   `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
     of elements in any vector type compatible with the LLVM dialect;
--   `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type
-    with the given element type and size; the resulting type is either a
-    built-in or an LLVM dialect vector type depending on which one supports the
-    given element type.
 
 #### Examples of Compatible Vector Types
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index a2a76c49a2bda..17561f79d135a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -126,14 +126,6 @@ Type getVectorType(Type elementType, unsigned numElements,
 /// and length.
 Type getVectorType(Type elementType, const llvm::ElementCount &numElements);
 
-/// Creates an LLVM dialect-compatible type with the given element type and
-/// length.
-Type getFixedVectorType(Type elementType, unsigned numElements);
-
-/// Creates an LLVM dialect-compatible type with the given element type and
-/// length.
-Type getScalableVectorType(Type elementType, unsigned numElements);
-
 /// Returns the size of the given primitive LLVM dialect-compatible type
 /// (including vectors) in bits, for example, the size of i16 is 16 and
 /// the size of vector<4xi16> is 64. Returns 0 for non-primitive
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 51507c6507b69..69fa62c8196e4 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -61,13 +61,13 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
 static Type inferIntrinsicResultType(Type vectorResultType) {
   MLIRContext *ctx = vectorResultType.getContext();
   auto a = cast<LLVM::LLVMArrayType>(vectorResultType);
-  auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2);
+  auto f16x2Ty = VectorType::get(2, Float16Type::get(ctx));
   auto i32Ty = IntegerType::get(ctx, 32);
-  auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
+  auto i32x2Ty = VectorType::get(2, i32Ty);
   Type f64Ty = Float64Type::get(ctx);
-  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
+  Type f64x2Ty = VectorType::get(2, f64Ty);
   Type f32Ty = Float32Type::get(ctx);
-  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
+  Type f32x2Ty = VectorType::get(2, f32Ty);
   if (a.getElementType() == f16x2Ty) {
     return LLVM::LLVMStructType::getLiteral(
         ctx, SmallVector<Type>(a.getNumElements(), f16x2Ty));
@@ -85,7 +85,7 @@ static Type inferIntrinsicResultType(Type vectorResultType) {
         ctx,
         SmallVector<Type>(static_cast<size_t>(a.getNumElements()) * 2, f32Ty));
   }
-  if (a.getElementType() == LLVM::getFixedVectorType(f32Ty, 1)) {
+  if (a.getElementType() == VectorType::get(1, f32Ty)) {
     return LLVM::LLVMStructType::getLiteral(
         ctx, SmallVector<Type>(static_cast<size_t>(a.getNumElements()), f32Ty));
   }
@@ -106,11 +106,11 @@ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType,
   Type i32Ty = rewriter.getI32Type();
   Type f32Ty = rewriter.getF32Type();
   Type f64Ty = rewriter.getF64Type();
-  Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2);
-  Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2);
-  Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2);
-  Type f32x2Ty = LLVM::getFixedVectorType(f32Ty, 2);
-  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
+  Type f16x2Ty = VectorType::get(2, rewriter.getF16Type());
+  Type i32x2Ty = VectorType::get(2, i32Ty);
+  Type f64x2Ty = VectorType::get(2, f64Ty);
+  Type f32x2Ty = VectorType::get(2, f32Ty);
+  Type f32x1Ty = VectorType::get(1, f32Ty);
 
   auto makeConst = [&](int32_t index) -> Value {
     return rewriter.create<LLVM::ConstantOp>(loc, IntegerType::get(ctx, 32),
@@ -181,9 +181,9 @@ static SmallVector<Value> unpackOperandVector(ImplicitLocOpBuilder &b,
   Type f64Ty = b.getF64Type();
   Type f32Ty = b.getF32Type();
   Type i64Ty = b.getI64Type();
-  Type i8x4Ty = LLVM::getFixedVectorType(b.getI8Type(), 4);
-  Type i4x8Ty = LLVM::getFixedVectorType(b.getIntegerType(4), 8);
-  Type f32x1Ty = LLVM::getFixedVectorType(f32Ty, 1);
+  Type i8x4Ty = VectorType::get(4, b.getI8Type());
+  Type i4x8Ty = VectorType::get(8, b.getIntegerType(4));
+  Type f32x1Ty = VectorType::get(1, f32Ty);
   auto arrayTy = cast<LLVM::LLVMArrayType>(operand.getType());
 
   for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) {
@@ -268,8 +268,8 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
     if (!vectorResultType) {
       return failure();
     }
-    Type innerVectorType = LLVM::getFixedVectorType(
-        vectorResultType.getElementType(), vectorResultType.getDimSize(1));
+    Type innerVectorType = VectorType::get(vectorResultType.getDimSize(1),
+                                           vectorResultType.getElementType());
 
     int64_t num32BitRegs = vectorResultType.getDimSize(0);
 
@@ -627,8 +627,7 @@ struct NVGPUMmaSparseSyncLowering
 
     // Bitcast the sparse metadata from vector<2xf16> to an i32.
     Value sparseMetadata = adaptor.getSparseMetadata();
-    if (sparseMetadata.getType() !=
-        LLVM::getFixedVectorType(rewriter.getI16Type(), 2))
+    if (sparseMetadata.getType() != VectorType::get(2, rewriter.getI16Type()))
       return op->emitOpError() << "Expected metadata type to be LLVM "
                                   "VectorType of 2 i16 elements";
     sparseMetadata =
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index b3c2a29309528..29cf38c1fefea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -851,18 +851,6 @@ Type mlir::LLVM::getVectorType(Type elementType,
                        /*isScalable=*/false);
 }
 
-Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) {
-  assert(VectorType::isValidElementType(elementType) &&
-         "incompatible element type");
-  return VectorType::get(numElements, elementType);
-}
-
-Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) {
-  // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
-  // scalable/non-scalable.
-  return VectorType::get(numElements, elementType, /*scalableDims=*/true);
-}
-
 llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) {
   assert(isCompatibleType(type) &&
          "expected a type compatible with the LLVM dialect");
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 09bff6101edd3..b9d6952f67671 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -144,7 +144,7 @@ LogicalResult BulkStoreOp::verify() {
 std::optional<mlir::NVVM::MMATypes>
 MmaOp::inferOperandMMAType(Type operandElType, bool isAccumulator) {
   auto half2Type =
-      LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
+      VectorType::get(2, Float16Type::get(operandElType.getContext()));
   if (operandElType.isF64())
     return NVVM::MMATypes::f64;
   if (operandElType.isF16() || operandElType == half2Type)
@@ -243,7 +243,8 @@ void MmaOp::print(OpAsmPrinter &p) {
   p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);
 
   // Print the types of the operands and result.
-  p << " : " << "(";
+  p << " : "
+    << "(";
   llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                              frags[1].regs[0].getType(),
                                              frags[2].regs[0].getType()},
@@ -404,7 +405,7 @@ LogicalResult MmaOp::verify() {
   MLIRContext *context = getContext();
   auto f16Ty = Float16Type::get(context);
   auto i32Ty = IntegerType::get(context, 32);
-  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
+  auto f16x2Ty = VectorType::get(2, f16Ty);
   auto f32Ty = Float32Type::get(context);
   auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
       context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
@@ -506,7 +507,7 @@ LogicalResult MmaOp::verify() {
       expectedA.emplace_back(1, f64Ty);
       expectedB.emplace_back(1, f64Ty);
       expectedC.emplace_back(2, f64Ty);
-      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
+      // expectedC.emplace_back(1, VectorType::get(2, f64Ty));
       expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
           context, SmallVector<Type>(2, f64Ty)));
       allowedShapes.push_back({8, 8, 4});
@@ -992,7 +993,9 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
   ss << "},";
   // Need to map read/write registers correctly.
   regCnt = (regCnt * 2);
-  ss << " $" << (regCnt) << "," << " $" << (regCnt + 1) << "," << " p";
+  ss << " $" << (regCnt) << ","
+     << " $" << (regCnt + 1) << ","
+     << " p";
   if (getTypeD() != WGMMATypes::s32) {
     ss << ", $" << (regCnt + 3) << ",  $" << (regCnt + 4);
   }
@@ -1219,7 +1222,7 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
             : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
 
 #define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                       \
-  [&]() -> auto {                                                              \
+  [&]() -> auto{                                                               \
     switch (dims) {                                                            \
     case 1:                                                                    \
       return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                    \
@@ -1234,7 +1237,8 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
     default:                                                                   \
       llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");     \
     }                                                                          \
-  }()
+  }                                                                            \
+  ()
 
 llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
     int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
@@ -1364,13 +1368,14 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
           : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
 
 #define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
-  [&]() -> auto {                                                              \
+  [&]() -> auto{                                                               \
     if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32)                              \
       return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
     if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64)                              \
       return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
     return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
-  }()
+  }                                                                            \
+  ()
 
 llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
   auto curOp = cast<NVVM::Tcgen05CpOp>(op);
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 39cca7d363e0d..e80360aa08ed5 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -103,16 +103,15 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
 
   Type elType = type.vectorType.getElementType();
   if (elType.isF16()) {
-    return FragmentElementInfo{
-        LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32,
-        inferNumRegistersPerMatrixFragment(type)};
+    return FragmentElementInfo{VectorType::get(2, Float16Type::get(ctx)), 2, 32,
+                               inferNumRegistersPerMatrixFragment(type)};
   }
 
   // f64 operand
   Type f64Ty = Float64Type::get(ctx);
   if (elType.isF64()) {
     return isAccum
-               ? FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128,
+               ? FragmentElementInfo{VectorType::get(2, f64Ty), 2, 128,
                                      inferNumRegistersPerMatrixFragment(type)}
                : FragmentElementInfo{f64Ty, 1, 64,
                                      inferNumRegistersPerMatrixFragment(type)};
@@ -120,30 +119,27 @@ nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
 
   // int8 operand
   if (elType.isInteger(8)) {
-    return FragmentElementInfo{
-        LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32,
-        inferNumRegistersPerMatrixFragment(type)};
+    return FragmentElementInfo{VectorType::get(4, IntegerType::get(ctx, 8)), 4,
+                               32, inferNumRegistersPerMatrixFragment(type)};
   }
 
   // int4 operand
   if (elType.isInteger(4)) {
-    return FragmentElementInfo{
-        LLVM::getFixedVectorType(IntegerType::get(ctx, 4), 8), 8, 32,
-        inferNumRegistersPerMatrixFragment(type)};
+    return FragmentElementInfo{VectorType::get(8, IntegerType::get(ctx, 4)), 8,
+                               32, inferNumRegistersPerMatrixFragment(type)};
   }
 
   // Integer 32bit acc operands
   if (elType.isInteger(32)) {
-    return FragmentElementInfo{
-        LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64,
-        inferNumRegistersPerMatrixFragment(type)};
+    return FragmentElementInfo{VectorType::get(2, IntegerType::get(ctx, 32)), 2,
+                               64, inferNumRegistersPerMatrixFragment(type)};
   }
 
   // Floating point 32bit operands
   if (elType.isF32()) {
     Type f32Ty = Float32Type::get(ctx);
     return isAccum
-               ? FragmentElementInfo{LLVM::getFixedVectorType(f32Ty, 2), 2, 64,
+               ? FragmentElementInfo{VectorType::get(2, f32Ty), 2, 64,
                                      inferNumRegistersPerMatrixFragment(type)}
                : FragmentElementInfo{f32Ty, 1, 32,
                                      inferNumRegistersPerMatrixFragment(type)};
diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
index bc9765fff2953..c46aa3e80d51a 100644
--- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
+++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp
@@ -124,14 +124,15 @@ class TypeFromLLVMIRTranslatorImpl {
 
   /// Translates the given fixed-vector type.
   Type translate(llvm::FixedVectorType *type) {
-    return LLVM::getFixedVectorType(translateType(type->getElementType()),
-                                    type->getNumElements());
+    return VectorType::get(type->getNumElements(),
+                           translateType(type->getElementType()));
   }
 
   /// Translates the given scalable-vector type.
   Type translate(llvm::ScalableVectorType *type) {
-    return LLVM::getScalableVectorType(translateType(type->getElementType()),
-                                       type->getMinNumElements());
+    return VectorType::get(type->getMinNumElements(),
+                           translateType(type->getElementType()),
+                           /*scalable=*/true);
   }
 
   /// Translates the given target extension type.

``````````

</details>


https://github.com/llvm/llvm-project/pull/135051


More information about the llvm-branch-commits mailing list