[Mlir-commits] [mlir] [mlir][vector] Support multi-dimensional vectors in VectorFromElementsLowering (PR #151175)
Yang Bai
llvmlistbot at llvm.org
Wed Jul 30 23:54:06 PDT 2025
https://github.com/yangtetris updated https://github.com/llvm/llvm-project/pull/151175
>From cf8174fcde57f8092409171ae19444886aca3625 Mon Sep 17 00:00:00 2001
From: yangtetris <baiyang0132 at gmail.com>
Date: Tue, 29 Jul 2025 23:41:22 +0800
Subject: [PATCH 1/4] [mlir] Support lowering multi-dim vectors in
VectorFromElementsLowering
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 63 ++++++++++++++++---
.../vector-to-llvm-interface.mlir | 24 +++++++
2 files changed, 79 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 17a79e3815b97..26d056cadb19c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1890,15 +1890,62 @@ struct VectorFromElementsLowering
ConversionPatternRewriter &rewriter) const override {
Location loc = fromElementsOp.getLoc();
VectorType vectorType = fromElementsOp.getType();
- // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
- // Such ops should be handled in the same way as vector.insert.
- if (vectorType.getRank() > 1)
- return rewriter.notifyMatchFailure(fromElementsOp,
- "rank > 1 vectors are not supported");
Type llvmType = typeConverter->convertType(vectorType);
- Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
- for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
- result = vector::InsertOp::create(rewriter, loc, val, result, idx);
+ Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
+
+ Value result;
+ // 0D vectors are converted to legnth-1 1D vectors by LLVMTypeConverter.
+ if (vectorType.getRank() == 0) {
+ result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+ auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result, adaptor.getElements().front(), index0);
+ rewriter.replaceOp(fromElementsOp, result);
+ return success();
+ }
+
+ // Build 1D vectors for the innermost dimension
+ int64_t innerDimSize = vectorType.getShape().back();
+ int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
+
+ SmallVector<Value> innerVectors;
+ innerVectors.reserve(numInnerVectors);
+
+ auto innerVectorType = VectorType::get(innerDimSize, vectorType.getElementType());
+ Type llvmInnerType = typeConverter->convertType(innerVectorType);
+
+ int64_t elementInVectorIdx = 0;
+ Value innerVector;
+ for (auto val : adaptor.getElements()) {
+ if (elementInVectorIdx == 0)
+ innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
+ auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, elementInVectorIdx);
+ innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, innerVector, val, position);
+ if (++elementInVectorIdx == innerDimSize) {
+ innerVectors.push_back(innerVector);
+ elementInVectorIdx = 0;
+ }
+ }
+
+ // For 1D vectors, we can just return the first innermost vector.
+ if (vectorType.getRank() == 1) {
+ rewriter.replaceOp(fromElementsOp, innerVectors.front());
+ return success();
+ }
+
+ // Now build the nested aggregate structure from these 1D vectors.
+ result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+
+ // Use the same iteration approach as VectorBroadcastScalarToNdLowering to
+ // insert the 1D vectors into the aggregate.
+ auto vectorTypeInfo = LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
+ if (!vectorTypeInfo.llvmNDVectorTy)
+ return failure();
+ int64_t vectorIdx = 0;
+ nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
+ result = LLVM::InsertValueOp::create(rewriter, loc, result,
+ innerVectors[vectorIdx++], position);
+ });
+
rewriter.replaceOp(fromElementsOp, result);
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 31e17fb3e3cc6..834858c0b7c8f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2286,6 +2286,30 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
// -----
+// CHECK-LABEL: func.func @from_elements_3d(
+// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32)
+// CHECK: %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32>
+// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32>
+// CHECK: %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32>
+// CHECK: %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32>
+// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32>
+// CHECK: %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32>
+// CHECK: %[[UNDEF_RES:.*]] = llvm.mlir.poison : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_RES]][0, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[RES_0]][1, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32>
+// CHECK: return %[[CAST]]
+func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+ %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+ return %0 : vector<2x1x2xf32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.to_elements
//===----------------------------------------------------------------------===//
>From 8fe493014531f8a227a090e82088393f6df185f8 Mon Sep 17 00:00:00 2001
From: yangtetris <baiyang0132 at gmail.com>
Date: Tue, 29 Jul 2025 23:53:07 +0800
Subject: [PATCH 2/4] Fix code formatting (clang-format)
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 23 +++++++++++--------
1 file changed, 14 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 26d056cadb19c..59a09be7738e8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1898,11 +1898,12 @@ struct VectorFromElementsLowering
if (vectorType.getRank() == 0) {
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
- result = LLVM::InsertElementOp::create(rewriter, loc, result, adaptor.getElements().front(), index0);
+ result = LLVM::InsertElementOp::create(
+ rewriter, loc, result, adaptor.getElements().front(), index0);
rewriter.replaceOp(fromElementsOp, result);
return success();
}
-
+
// Build 1D vectors for the innermost dimension
int64_t innerDimSize = vectorType.getShape().back();
int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
@@ -1910,7 +1911,8 @@ struct VectorFromElementsLowering
SmallVector<Value> innerVectors;
innerVectors.reserve(numInnerVectors);
- auto innerVectorType = VectorType::get(innerDimSize, vectorType.getElementType());
+ auto innerVectorType =
+ VectorType::get(innerDimSize, vectorType.getElementType());
Type llvmInnerType = typeConverter->convertType(innerVectorType);
int64_t elementInVectorIdx = 0;
@@ -1918,8 +1920,10 @@ struct VectorFromElementsLowering
for (auto val : adaptor.getElements()) {
if (elementInVectorIdx == 0)
innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
- auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, elementInVectorIdx);
- innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, innerVector, val, position);
+ auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType,
+ elementInVectorIdx);
+ innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType,
+ innerVector, val, position);
if (++elementInVectorIdx == innerDimSize) {
innerVectors.push_back(innerVector);
elementInVectorIdx = 0;
@@ -1934,18 +1938,19 @@ struct VectorFromElementsLowering
// Now build the nested aggregate structure from these 1D vectors.
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
-
+
// Use the same iteration approach as VectorBroadcastScalarToNdLowering to
// insert the 1D vectors into the aggregate.
- auto vectorTypeInfo = LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
+ auto vectorTypeInfo =
+ LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
if (!vectorTypeInfo.llvmNDVectorTy)
return failure();
int64_t vectorIdx = 0;
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
- result = LLVM::InsertValueOp::create(rewriter, loc, result,
+ result = LLVM::InsertValueOp::create(rewriter, loc, result,
innerVectors[vectorIdx++], position);
});
-
+
rewriter.replaceOp(fromElementsOp, result);
return success();
}
>From 44211421a2c3d0b8b09d19be6bcb1fd20b6fb1c9 Mon Sep 17 00:00:00 2001
From: Yang Bai <baiyang0132 at gmail.com>
Date: Thu, 31 Jul 2025 10:23:23 +0800
Subject: [PATCH 3/4] Fix typo in comment ("legnth" -> "length")
Co-authored-by: Nicolas Vasilache <Nico.Vasilache at amd.com>
---
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 59a09be7738e8..1006605bd9130 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1894,7 +1894,7 @@ struct VectorFromElementsLowering
Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
Value result;
- // 0D vectors are converted to legnth-1 1D vectors by LLVMTypeConverter.
+ // 0D vectors are converted to length-1 1D vectors by LLVMTypeConverter.
if (vectorType.getRank() == 0) {
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
>From eec412bc4b1b58f0a79b93561f5cc218ed1b824f Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Wed, 30 Jul 2025 23:53:52 -0700
Subject: [PATCH 4/4] Refine VectorFromElementsLowering; add nDVectorIterate(VectorType) overload
---
.../Conversion/LLVMCommon/VectorPattern.h | 7 ++++
.../Conversion/LLVMCommon/VectorPattern.cpp | 10 ++++++
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 33 ++++++++++---------
3 files changed, 34 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 964281592cc65..36dcffc79974d 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -49,6 +49,13 @@ SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
function_ref<void(ArrayRef<int64_t>)> fun);
+// Overload that accepts VectorType directly and extracts type info internally.
+// Returns failure if the vector type info extraction fails.
+LogicalResult nDVectorIterate(VectorType vectorType,
+ const LLVMTypeConverter &converter,
+ OpBuilder &builder,
+ function_ref<void(ArrayRef<int64_t>)> fun);
+
LogicalResult handleMultidimensionalVectors(
Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
std::function<Value(Type, ValueRange)> createOperand,
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e7dd0b506e12d..adc7c9f1551e7 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -77,6 +77,16 @@ void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
}
}
+LogicalResult LLVM::detail::nDVectorIterate(
+ VectorType vectorType, const LLVMTypeConverter &converter,
+ OpBuilder &builder, function_ref<void(ArrayRef<int64_t>)> fun) {
+ auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, converter);
+ if (!vectorTypeInfo.llvmNDVectorTy)
+ return failure();
+ nDVectorIterate(vectorTypeInfo, builder, fun);
+ return success();
+}
+
LogicalResult LLVM::detail::handleMultidimensionalVectors(
Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
std::function<Value(Type, ValueRange)> createOperand,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 1006605bd9130..137cc7a14c7e0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1904,8 +1904,10 @@ struct VectorFromElementsLowering
return success();
}
- // Build 1D vectors for the innermost dimension
+ // Build 1D vectors for the innermost dimension.
int64_t innerDimSize = vectorType.getShape().back();
+ assert(vectorType.getNumElements() % innerDimSize == 0 &&
+ "innerDimSize must divide vectorType.getNumElements()");
int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
SmallVector<Value> innerVectors;
@@ -1915,23 +1917,23 @@ struct VectorFromElementsLowering
VectorType::get(innerDimSize, vectorType.getElementType());
Type llvmInnerType = typeConverter->convertType(innerVectorType);
- int64_t elementInVectorIdx = 0;
Value innerVector;
- for (auto val : adaptor.getElements()) {
+ for (auto [elemIdx, val] : llvm::enumerate(adaptor.getElements())) {
+ int64_t elementInVectorIdx = elemIdx % innerDimSize;
if (elementInVectorIdx == 0)
innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType,
elementInVectorIdx);
innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType,
innerVector, val, position);
- if (++elementInVectorIdx == innerDimSize) {
+ if (elementInVectorIdx == innerDimSize - 1)
innerVectors.push_back(innerVector);
- elementInVectorIdx = 0;
- }
}
// For 1D vectors, we can just return the first innermost vector.
if (vectorType.getRank() == 1) {
+ assert(innerVectors.size() == 1 &&
+ "for 1D vectors, innerVectors should have exactly one element");
rewriter.replaceOp(fromElementsOp, innerVectors.front());
return success();
}
@@ -1939,17 +1941,16 @@ struct VectorFromElementsLowering
// Now build the nested aggregate structure from these 1D vectors.
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
- // Use the same iteration approach as VectorBroadcastScalarToNdLowering to
- // insert the 1D vectors into the aggregate.
- auto vectorTypeInfo =
- LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
- if (!vectorTypeInfo.llvmNDVectorTy)
- return failure();
+ // Iterate over each position of the first n-1 dimensions and insert the 1D
+ // vectors into the aggregate.
int64_t vectorIdx = 0;
- nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
- result = LLVM::InsertValueOp::create(rewriter, loc, result,
- innerVectors[vectorIdx++], position);
- });
+ if (failed(LLVM::detail::nDVectorIterate(
+ vectorType, *getTypeConverter(), rewriter,
+ [&](ArrayRef<int64_t> position) {
+ result = LLVM::InsertValueOp::create(
+ rewriter, loc, result, innerVectors[vectorIdx++], position);
+ })))
+ return failure();
rewriter.replaceOp(fromElementsOp, result);
return success();
More information about the Mlir-commits
mailing list