[Mlir-commits] [mlir] [mlir][vector] Support multi-dimensional vectors in VectorFromElementsLowering (PR #151175)
Yang Bai
llvmlistbot at llvm.org
Wed Aug 13 20:25:34 PDT 2025
https://github.com/yangtetris updated https://github.com/llvm/llvm-project/pull/151175
>From cf8174fcde57f8092409171ae19444886aca3625 Mon Sep 17 00:00:00 2001
From: yangtetris <baiyang0132 at gmail.com>
Date: Tue, 29 Jul 2025 23:41:22 +0800
Subject: [PATCH 1/8] [mlir] Support lowering multi-dim vectors in
VectorFromElementsLowering
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 63 ++++++++++++++++---
.../vector-to-llvm-interface.mlir | 24 +++++++
2 files changed, 79 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 17a79e3815b97..26d056cadb19c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1890,15 +1890,62 @@ struct VectorFromElementsLowering
ConversionPatternRewriter &rewriter) const override {
Location loc = fromElementsOp.getLoc();
VectorType vectorType = fromElementsOp.getType();
- // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
- // Such ops should be handled in the same way as vector.insert.
- if (vectorType.getRank() > 1)
- return rewriter.notifyMatchFailure(fromElementsOp,
- "rank > 1 vectors are not supported");
Type llvmType = typeConverter->convertType(vectorType);
- Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
- for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
- result = vector::InsertOp::create(rewriter, loc, val, result, idx);
+ Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
+
+ Value result;
+ // 0D vectors are converted to legnth-1 1D vectors by LLVMTypeConverter.
+ if (vectorType.getRank() == 0) {
+ result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+ auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result, adaptor.getElements().front(), index0);
+ rewriter.replaceOp(fromElementsOp, result);
+ return success();
+ }
+
+ // Build 1D vectors for the innermost dimension
+ int64_t innerDimSize = vectorType.getShape().back();
+ int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
+
+ SmallVector<Value> innerVectors;
+ innerVectors.reserve(numInnerVectors);
+
+ auto innerVectorType = VectorType::get(innerDimSize, vectorType.getElementType());
+ Type llvmInnerType = typeConverter->convertType(innerVectorType);
+
+ int64_t elementInVectorIdx = 0;
+ Value innerVector;
+ for (auto val : adaptor.getElements()) {
+ if (elementInVectorIdx == 0)
+ innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
+ auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, elementInVectorIdx);
+ innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, innerVector, val, position);
+ if (++elementInVectorIdx == innerDimSize) {
+ innerVectors.push_back(innerVector);
+ elementInVectorIdx = 0;
+ }
+ }
+
+ // For 1D vectors, we can just return the first innermost vector.
+ if (vectorType.getRank() == 1) {
+ rewriter.replaceOp(fromElementsOp, innerVectors.front());
+ return success();
+ }
+
+ // Now build the nested aggregate structure from these 1D vectors.
+ result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+
+ // Use the same iteration approach as VectorBroadcastScalarToNdLowering to
+ // insert the 1D vectors into the aggregate.
+ auto vectorTypeInfo = LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
+ if (!vectorTypeInfo.llvmNDVectorTy)
+ return failure();
+ int64_t vectorIdx = 0;
+ nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
+ result = LLVM::InsertValueOp::create(rewriter, loc, result,
+ innerVectors[vectorIdx++], position);
+ });
+
rewriter.replaceOp(fromElementsOp, result);
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 31e17fb3e3cc6..834858c0b7c8f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2286,6 +2286,30 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
// -----
+// CHECK-LABEL: func.func @from_elements_3d(
+// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32)
+// CHECK: %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32>
+// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32>
+// CHECK: %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32>
+// CHECK: %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32>
+// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32>
+// CHECK: %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32>
+// CHECK: %[[UNDEF_RES:.*]] = llvm.mlir.poison : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_RES]][0, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[RES_0]][1, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32>
+// CHECK: return %[[CAST]]
+func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+ %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+ return %0 : vector<2x1x2xf32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.to_elements
//===----------------------------------------------------------------------===//
>From 8fe493014531f8a227a090e82088393f6df185f8 Mon Sep 17 00:00:00 2001
From: yangtetris <baiyang0132 at gmail.com>
Date: Tue, 29 Jul 2025 23:53:07 +0800
Subject: [PATCH 2/8] fix code format
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 23 +++++++++++--------
1 file changed, 14 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 26d056cadb19c..59a09be7738e8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1898,11 +1898,12 @@ struct VectorFromElementsLowering
if (vectorType.getRank() == 0) {
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
- result = LLVM::InsertElementOp::create(rewriter, loc, result, adaptor.getElements().front(), index0);
+ result = LLVM::InsertElementOp::create(
+ rewriter, loc, result, adaptor.getElements().front(), index0);
rewriter.replaceOp(fromElementsOp, result);
return success();
}
-
+
// Build 1D vectors for the innermost dimension
int64_t innerDimSize = vectorType.getShape().back();
int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
@@ -1910,7 +1911,8 @@ struct VectorFromElementsLowering
SmallVector<Value> innerVectors;
innerVectors.reserve(numInnerVectors);
- auto innerVectorType = VectorType::get(innerDimSize, vectorType.getElementType());
+ auto innerVectorType =
+ VectorType::get(innerDimSize, vectorType.getElementType());
Type llvmInnerType = typeConverter->convertType(innerVectorType);
int64_t elementInVectorIdx = 0;
@@ -1918,8 +1920,10 @@ struct VectorFromElementsLowering
for (auto val : adaptor.getElements()) {
if (elementInVectorIdx == 0)
innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
- auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, elementInVectorIdx);
- innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, innerVector, val, position);
+ auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType,
+ elementInVectorIdx);
+ innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType,
+ innerVector, val, position);
if (++elementInVectorIdx == innerDimSize) {
innerVectors.push_back(innerVector);
elementInVectorIdx = 0;
@@ -1934,18 +1938,19 @@ struct VectorFromElementsLowering
// Now build the nested aggregate structure from these 1D vectors.
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
-
+
// Use the same iteration approach as VectorBroadcastScalarToNdLowering to
// insert the 1D vectors into the aggregate.
- auto vectorTypeInfo = LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
+ auto vectorTypeInfo =
+ LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
if (!vectorTypeInfo.llvmNDVectorTy)
return failure();
int64_t vectorIdx = 0;
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
- result = LLVM::InsertValueOp::create(rewriter, loc, result,
+ result = LLVM::InsertValueOp::create(rewriter, loc, result,
innerVectors[vectorIdx++], position);
});
-
+
rewriter.replaceOp(fromElementsOp, result);
return success();
}
>From 44211421a2c3d0b8b09d19be6bcb1fd20b6fb1c9 Mon Sep 17 00:00:00 2001
From: Yang Bai <baiyang0132 at gmail.com>
Date: Thu, 31 Jul 2025 10:23:23 +0800
Subject: [PATCH 3/8] Fix typo
Co-authored-by: Nicolas Vasilache <Nico.Vasilache at amd.com>
---
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 59a09be7738e8..1006605bd9130 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1894,7 +1894,7 @@ struct VectorFromElementsLowering
Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
Value result;
- // 0D vectors are converted to legnth-1 1D vectors by LLVMTypeConverter.
+ // 0D vectors are converted to length-1 1D vectors by LLVMTypeConverter.
if (vectorType.getRank() == 0) {
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
>From eec412bc4b1b58f0a79b93561f5cc218ed1b824f Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Wed, 30 Jul 2025 23:53:52 -0700
Subject: [PATCH 4/8] refine
---
.../Conversion/LLVMCommon/VectorPattern.h | 7 ++++
.../Conversion/LLVMCommon/VectorPattern.cpp | 10 ++++++
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 33 ++++++++++---------
3 files changed, 34 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 964281592cc65..36dcffc79974d 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -49,6 +49,13 @@ SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
function_ref<void(ArrayRef<int64_t>)> fun);
+// Overload that accepts VectorType directly and extracts type info internally.
+// Returns failure if the vector type info extraction fails.
+LogicalResult nDVectorIterate(VectorType vectorType,
+ const LLVMTypeConverter &converter,
+ OpBuilder &builder,
+ function_ref<void(ArrayRef<int64_t>)> fun);
+
LogicalResult handleMultidimensionalVectors(
Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
std::function<Value(Type, ValueRange)> createOperand,
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e7dd0b506e12d..adc7c9f1551e7 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -77,6 +77,16 @@ void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
}
}
+LogicalResult LLVM::detail::nDVectorIterate(
+ VectorType vectorType, const LLVMTypeConverter &converter,
+ OpBuilder &builder, function_ref<void(ArrayRef<int64_t>)> fun) {
+ auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, converter);
+ if (!vectorTypeInfo.llvmNDVectorTy)
+ return failure();
+ nDVectorIterate(vectorTypeInfo, builder, fun);
+ return success();
+}
+
LogicalResult LLVM::detail::handleMultidimensionalVectors(
Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
std::function<Value(Type, ValueRange)> createOperand,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 1006605bd9130..137cc7a14c7e0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1904,8 +1904,10 @@ struct VectorFromElementsLowering
return success();
}
- // Build 1D vectors for the innermost dimension
+ // Build 1D vectors for the innermost dimension.
int64_t innerDimSize = vectorType.getShape().back();
+ assert(vectorType.getNumElements() % innerDimSize == 0 &&
+ "innerDimSize must divide vectorType.getNumElements()");
int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
SmallVector<Value> innerVectors;
@@ -1915,23 +1917,23 @@ struct VectorFromElementsLowering
VectorType::get(innerDimSize, vectorType.getElementType());
Type llvmInnerType = typeConverter->convertType(innerVectorType);
- int64_t elementInVectorIdx = 0;
Value innerVector;
- for (auto val : adaptor.getElements()) {
+ for (auto [elemIdx, val] : llvm::enumerate(adaptor.getElements())) {
+ int64_t elementInVectorIdx = elemIdx % innerDimSize;
if (elementInVectorIdx == 0)
innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType,
elementInVectorIdx);
innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType,
innerVector, val, position);
- if (++elementInVectorIdx == innerDimSize) {
+ if (elementInVectorIdx == innerDimSize - 1)
innerVectors.push_back(innerVector);
- elementInVectorIdx = 0;
- }
}
// For 1D vectors, we can just return the first innermost vector.
if (vectorType.getRank() == 1) {
+ assert(innerVectors.size() == 1 &&
+ "for 1D vectors, innerVectors should have exactly one element");
rewriter.replaceOp(fromElementsOp, innerVectors.front());
return success();
}
@@ -1939,17 +1941,16 @@ struct VectorFromElementsLowering
// Now build the nested aggregate structure from these 1D vectors.
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
- // Use the same iteration approach as VectorBroadcastScalarToNdLowering to
- // insert the 1D vectors into the aggregate.
- auto vectorTypeInfo =
- LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
- if (!vectorTypeInfo.llvmNDVectorTy)
- return failure();
+ // Iterate over each position of the first n-1 dimensions and insert the 1D
+ // vectors into the aggregate.
int64_t vectorIdx = 0;
- nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
- result = LLVM::InsertValueOp::create(rewriter, loc, result,
- innerVectors[vectorIdx++], position);
- });
+ if (failed(LLVM::detail::nDVectorIterate(
+ vectorType, *getTypeConverter(), rewriter,
+ [&](ArrayRef<int64_t> position) {
+ result = LLVM::InsertValueOp::create(
+ rewriter, loc, result, innerVectors[vectorIdx++], position);
+ })))
+ return failure();
rewriter.replaceOp(fromElementsOp, result);
return success();
>From d23ddd2d75db8ce51038399f03950245ad3f6723 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Tue, 12 Aug 2025 08:30:12 -0700
Subject: [PATCH 5/8] re-implemented with unrolling transformation
---
.../Conversion/LLVMCommon/VectorPattern.h | 7 --
.../Vector/TransformOps/VectorTransformOps.td | 11 +++
.../Vector/Transforms/LoweringPatterns.h | 8 +++
.../mlir/Dialect/Vector/Utils/VectorUtils.h | 39 ++++++++++
.../Conversion/LLVMCommon/VectorPattern.cpp | 10 ---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 71 ++++---------------
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 1 +
.../TransformOps/VectorTransformOps.cpp | 5 ++
.../Vector/Transforms/LowerVectorGather.cpp | 33 +++------
...LowerVectorToFromElementsToShuffleTree.cpp | 42 +++++++++++
.../vector-to-llvm-interface.mlir | 24 -------
.../VectorToLLVM/vector-to-llvm.mlir | 37 ++++++++++
.../Vector/vector-from-elements-lowering.mlir | 45 ++++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 24 +++++++
.../python/dialects/transform_vector_ext.py | 2 +
15 files changed, 234 insertions(+), 125 deletions(-)
create mode 100644 mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 36dcffc79974d..964281592cc65 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -49,13 +49,6 @@ SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
function_ref<void(ArrayRef<int64_t>)> fun);
-// Overload that accepts VectorType directly and extracts type info internally.
-// Returns failure if the vector type info extraction fails.
-LogicalResult nDVectorIterate(VectorType vectorType,
- const LLVMTypeConverter &converter,
- OpBuilder &builder,
- function_ref<void(ArrayRef<int64_t>)> fun);
-
LogicalResult handleMultidimensionalVectors(
Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
std::function<Value(Type, ValueRange)> createOperand,
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 299f198e4ab9c..07a4117a37b2c 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -254,6 +254,17 @@ def ApplyLowerGatherPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyUnrollFromElementsPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.unroll_from_elements",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector from_elements operations should be unrolled
+ along the outermost dimension.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyLowerScanPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_scan",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index e03f0dabece52..8c2cafe83c791 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -303,6 +303,14 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
void populateVectorToFromElementsToShuffleTreePatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollFromElements]
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension.
+void populateVectorFromElementsUnrollingPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
/// Populate the pattern set with the following patterns:
///
/// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 7cd70e42d363c..8309cdde6ad76 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -238,6 +239,44 @@ Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
/// static sizes in `shape`.
LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
ArrayRef<int64_t> inputVectorSizes);
+
+/// Generic utility for unrolling n-D vector operations to (n-1)-D operations.
+/// This handles the common pattern of:
+/// 1. Check if already 1-D. If so, return failure.
+/// 2. Check for scalable dimensions. If so, return failure.
+/// 3. Create poison initialized result.
+/// 4. Loop through the outermost dimension, execute the UnrollVectorOpFn to
+/// create sub vectors.
+/// 5. Insert the sub vectors back into the final vector.
+/// 6. Replace the original op with the new result.
+using UnrollVectorOpFn =
+ function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;
+
+template <typename VectorOpType>
+LogicalResult unrollVectorOp(VectorOpType op, PatternRewriter &rewriter,
+ UnrollVectorOpFn unrollFn) {
+ VectorType resultTy = op.getType();
+ if (resultTy.getRank() < 2)
+ return rewriter.notifyMatchFailure(op, "already 1-D");
+
+ // Unrolling doesn't take vscale into account. Pattern is disabled for
+ // vectors with leading scalable dim(s).
+ if (resultTy.getScalableDims().front())
+ return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
+ Location loc = op.getLoc();
+ Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
+ VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
+
+ for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+ Value subVector = unrollFn(rewriter, loc, subTy, i);
+ result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+}
+
} // namespace vector
/// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index adc7c9f1551e7..e7dd0b506e12d 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -77,16 +77,6 @@ void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
}
}
-LogicalResult LLVM::detail::nDVectorIterate(
- VectorType vectorType, const LLVMTypeConverter &converter,
- OpBuilder &builder, function_ref<void(ArrayRef<int64_t>)> fun) {
- auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, converter);
- if (!vectorTypeInfo.llvmNDVectorTy)
- return failure();
- nDVectorIterate(vectorTypeInfo, builder, fun);
- return success();
-}
-
LogicalResult LLVM::detail::handleMultidimensionalVectors(
Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
std::function<Value(Type, ValueRange)> createOperand,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 137cc7a14c7e0..b44df3f0320e8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1890,68 +1890,21 @@ struct VectorFromElementsLowering
ConversionPatternRewriter &rewriter) const override {
Location loc = fromElementsOp.getLoc();
VectorType vectorType = fromElementsOp.getType();
+ // Only support 1-D vectors. Multi-dimensional vectors should have been
+ // transformed to 1-D vectors by the vector-to-vector transformations before
+ // this.
+ if (vectorType.getRank() > 1)
+ return rewriter.notifyMatchFailure(fromElementsOp,
+ "rank > 1 vectors are not supported");
Type llvmType = typeConverter->convertType(vectorType);
Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
-
- Value result;
- // 0D vectors are converted to length-1 1D vectors by LLVMTypeConverter.
- if (vectorType.getRank() == 0) {
- result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
- auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
- result = LLVM::InsertElementOp::create(
- rewriter, loc, result, adaptor.getElements().front(), index0);
- rewriter.replaceOp(fromElementsOp, result);
- return success();
+ Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+ for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
+ auto constIdx =
+ LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
+ result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result,
+ val, constIdx);
}
-
- // Build 1D vectors for the innermost dimension.
- int64_t innerDimSize = vectorType.getShape().back();
- assert(vectorType.getNumElements() % innerDimSize == 0 &&
- "innerDimSize must divide vectorType.getNumElements()");
- int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
-
- SmallVector<Value> innerVectors;
- innerVectors.reserve(numInnerVectors);
-
- auto innerVectorType =
- VectorType::get(innerDimSize, vectorType.getElementType());
- Type llvmInnerType = typeConverter->convertType(innerVectorType);
-
- Value innerVector;
- for (auto [elemIdx, val] : llvm::enumerate(adaptor.getElements())) {
- int64_t elementInVectorIdx = elemIdx % innerDimSize;
- if (elementInVectorIdx == 0)
- innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
- auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType,
- elementInVectorIdx);
- innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType,
- innerVector, val, position);
- if (elementInVectorIdx == innerDimSize - 1)
- innerVectors.push_back(innerVector);
- }
-
- // For 1D vectors, we can just return the first innermost vector.
- if (vectorType.getRank() == 1) {
- assert(innerVectors.size() == 1 &&
- "for 1D vectors, innerVectors should have exactly one element");
- rewriter.replaceOp(fromElementsOp, innerVectors.front());
- return success();
- }
-
- // Now build the nested aggregate structure from these 1D vectors.
- result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
-
- // Iterate over each position of the first n-1 dimensions and insert the 1D
- // vectors into the aggregate.
- int64_t vectorIdx = 0;
- if (failed(LLVM::detail::nDVectorIterate(
- vectorType, *getTypeConverter(), rewriter,
- [&](ArrayRef<int64_t> position) {
- result = LLVM::InsertValueOp::create(
- rewriter, loc, result, innerVectors[vectorIdx++], position);
- })))
- return failure();
-
rewriter.replaceOp(fromElementsOp, result);
return success();
}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index cf108690c3741..7ac3bd4aee937 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
+ populateVectorFromElementsUnrollingPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 2d5cc070558c3..e6917c03d3b26 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -139,6 +139,11 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns(
vector::populateVectorGatherLoweringPatterns(patterns);
}
+void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorFromElementsUnrollingPatterns(patterns);
+}
+
void transform::ApplyLowerScanPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorScanLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index e062f55f87679..90f21c53246b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -54,27 +54,13 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
- VectorType resultTy = op.getType();
- if (resultTy.getRank() < 2)
- return rewriter.notifyMatchFailure(op, "already 1-D");
-
- // Unrolling doesn't take vscale into account. Pattern is disabled for
- // vectors with leading scalable dim(s).
- if (resultTy.getScalableDims().front())
- return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
-
- Location loc = op.getLoc();
Value indexVec = op.getIndexVec();
Value maskVec = op.getMask();
Value passThruVec = op.getPassThru();
- Value result = arith::ConstantOp::create(rewriter, loc, resultTy,
- rewriter.getZeroAttr(resultTy));
-
- VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
-
- for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
- int64_t thisIdx[1] = {i};
+ auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
+ VectorType subTy, int64_t index) {
+ int64_t thisIdx[1] = {index};
Value indexSubVec =
vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
@@ -82,15 +68,12 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
Value passThruSubVec =
vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
- Value subGather = vector::GatherOp::create(
- rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec,
- maskSubVec, passThruSubVec);
- result =
- vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx);
- }
+ return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
+ op.getIndices(), indexSubVec, maskSubVec,
+ passThruSubVec);
+ };
- rewriter.replaceOp(op, result);
- return success();
+ return unrollVectorOp(op, rewriter, unrollGatherFn);
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
index 6407a868abd85..3ed81fecefc41 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -735,6 +735,43 @@ struct LowerVectorToFromElementsToShuffleTreePass
}
};
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
+///
+/// ==>
+///
+/// %0 = ub.poison : vector<2x3xf32>
+/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
+/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
+/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
+/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d from_elements
+/// ops.
+struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::FromElementsOp op,
+ PatternRewriter &rewriter) const override {
+ ValueRange allElements = op.getElements();
+
+ auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
+ VectorType subTy, int64_t index) {
+ size_t subTyNumElements = subTy.getNumElements();
+ assert((index + 1) * subTyNumElements <= allElements.size() &&
+ "out of bounds");
+ ValueRange subElements =
+ allElements.slice(index * subTyNumElements, subTyNumElements);
+ return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
+ };
+
+ return unrollVectorOp(op, rewriter, unrollFromElementsFn);
+ }
+};
+
} // namespace
void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns(
@@ -742,3 +779,8 @@ void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns(
patterns.add<ToFromElementsToShuffleTreeRewrite>(patterns.getContext(),
benefit);
}
+
+void mlir::vector::populateVectorFromElementsUnrollingPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 834858c0b7c8f..31e17fb3e3cc6 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2286,30 +2286,6 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
// -----
-// CHECK-LABEL: func.func @from_elements_3d(
-// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32)
-// CHECK: %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32>
-// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK: %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32>
-// CHECK: %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
-// CHECK: %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32>
-// CHECK: %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32>
-// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK: %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32>
-// CHECK: %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
-// CHECK: %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32>
-// CHECK: %[[UNDEF_RES:.*]] = llvm.mlir.poison : !llvm.array<2 x array<1 x vector<2xf32>>>
-// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_RES]][0, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
-// CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[RES_0]][1, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
-// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32>
-// CHECK: return %[[CAST]]
-func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
- %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
- return %0 : vector<2x1x2xf32>
-}
-
-// -----
-
//===----------------------------------------------------------------------===//
// vector.to_elements
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 72810b5dddaa3..fb8a5b436797d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1737,3 +1737,40 @@ func.func @step() -> vector<4xindex> {
%0 = vector.step : vector<4xindex>
return %0 : vector<4xindex>
}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.from_elements
+//===----------------------------------------------------------------------===//
+
+// NOTE: For now, we unroll multi-dimensional from_elements ops with pattern `UnrollFromElements`
+// and then convert the 1-D from_elements ops to llvm.
+
+// CHECK-LABEL: func @from_elements_3d
+// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32)
+// CHECK: %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32>
+// CHECK: %[[UNDEF_RES_LLVM:.*]] = builtin.unrealized_conversion_cast %[[UNDEF_RES]] : vector<2x1x2xf32> to !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[UNDEF_VEC_RANK_2:.*]] = ub.poison : vector<1x2xf32>
+// CHECK: %[[UNDEF_VEC_RANK_2_LLVM:.*]] = builtin.unrealized_conversion_cast %[[UNDEF_VEC_RANK_2]] : vector<1x2xf32> to !llvm.array<1 x vector<2xf32>>
+// CHECK: %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32>
+// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32>
+// CHECK: %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32>
+// CHECK: %[[RES_RANK_2_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_VEC_RANK_2_LLVM]][0] : !llvm.array<1 x vector<2xf32>>
+// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[RES_RANK_2_0]], %[[UNDEF_RES_LLVM]][0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32>
+// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32>
+// CHECK: %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32>
+// CHECK: %[[RES_RANK_2_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[UNDEF_VEC_RANK_2_LLVM]][0] : !llvm.array<1 x vector<2xf32>>
+// CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[RES_RANK_2_1]], %[[RES_0]][1] : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32>
+// CHECK: return %[[CAST]]
+func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+ %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+ return %0 : vector<2x1x2xf32>
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir
new file mode 100644
index 0000000000000..1c2e07086d093
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -test-unroll-vector-from-elements | FileCheck %s --check-prefix=CHECK-UNROLL
+
+//===----------------------------------------------------------------------===//
+// Test UnrollFromElements.
+//===----------------------------------------------------------------------===//
+
+// CHECK-UNROLL-LABEL: @unroll_from_elements_2d
+// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+// CHECK-UNROLL-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x2xf32>
+// CHECK-UNROLL-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+// CHECK-UNROLL-NEXT: %[[RES_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RES]] [0] : vector<2xf32> into vector<2x2xf32>
+// CHECK-UNROLL-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
+// CHECK-UNROLL-NEXT: %[[RES_1:.*]] = vector.insert %[[VEC_1]], %[[RES_0]] [1] : vector<2xf32> into vector<2x2xf32>
+// CHECK-UNROLL-NEXT: return %[[RES_1]] : vector<2x2xf32>
+func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> {
+ %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
+
+// CHECK-UNROLL-LABEL: @unroll_from_elements_3d
+// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+// CHECK-UNROLL-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32>
+// CHECK-UNROLL-NEXT: %[[UNDEF_RANK_2:.*]] = ub.poison : vector<1x2xf32>
+// CHECK-UNROLL-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+// CHECK-UNROLL-NEXT: %[[RANK_2_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
+// CHECK-UNROLL-NEXT: %[[RES_0:.*]] = vector.insert %[[RANK_2_0]], %[[UNDEF_RES]] [0] : vector<1x2xf32> into vector<2x1x2xf32>
+// CHECK-UNROLL-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
+// CHECK-UNROLL-NEXT: %[[RANK_2_1:.*]] = vector.insert %[[VEC_1]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
+// CHECK-UNROLL-NEXT: %[[RES_1:.*]] = vector.insert %[[RANK_2_1]], %[[RES_0]] [1] : vector<1x2xf32> into vector<2x1x2xf32>
+// CHECK-UNROLL-NEXT: return %[[RES_1]] : vector<2x1x2xf32>
+func.func @unroll_from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+ %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+ return %0 : vector<2x1x2xf32>
+}
+
+// 1-D vector.from_elements should not be unrolled.
+
+// CHECK-UNROLL-LABEL: @negative_unroll_from_elements_1d
+// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK-UNROLL-NEXT: %[[RES:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+// CHECK-UNROLL-NEXT: return %[[RES]] : vector<2xf32>
+func.func @negative_unroll_from_elements_1d(%arg0: f32, %arg1: f32) -> vector<2xf32> {
+ %0 = vector.from_elements %arg0, %arg1 : vector<2xf32>
+ return %0 : vector<2xf32>
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f89c944b5c564..dd35ad11e80ac 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -786,6 +786,28 @@ struct TestVectorGatherLowering
}
};
+struct TestUnrollVectorFromElements
+ : public PassWrapper<TestUnrollVectorFromElements,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorFromElements)
+
+ StringRef getArgument() const final {
+ return "test-unroll-vector-from-elements";
+ }
+ StringRef getDescription() const final {
+ return "Test unrolling patterns for from_elements ops";
+ }
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<func::FuncDialect, vector::VectorDialect, ub::UBDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorFromElementsUnrollingPatterns(patterns);
+ (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestFoldArithExtensionIntoVectorContractPatterns
: public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
OperationPass<func::FuncOp>> {
@@ -1059,6 +1081,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorGatherLowering>();
+ PassRegistration<TestUnrollVectorFromElements>();
+
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
PassRegistration<TestVectorEmulateMaskedLoadStore>();
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index a51f2154d1f7d..5a648fe073315 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -46,6 +46,8 @@ def non_configurable_patterns():
vector.ApplyLowerOuterProductPatternsOp()
# CHECK: transform.apply_patterns.vector.lower_gather
vector.ApplyLowerGatherPatternsOp()
+ # CHECK: transform.apply_patterns.vector.unroll_from_elements
+ vector.ApplyUnrollFromElementsPatternsOp()
# CHECK: transform.apply_patterns.vector.lower_scan
vector.ApplyLowerScanPatternsOp()
# CHECK: transform.apply_patterns.vector.lower_shape_cast
>From 1a5b07570337ed21f008139351f5abdae4a6a3cd Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Tue, 12 Aug 2025 09:24:28 -0700
Subject: [PATCH 6/8] fix test
---
mlir/test/Dialect/Vector/vector-gather-lowering.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 5be267c1be984..9c2a508671e06 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -81,7 +81,7 @@ func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %
// CHECK-SAME: %[[PASS:.*]]: vector<2x[3]xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x[3]xf32>
+// CHECK: %[[INIT:.*]] = ub.poison : vector<2x[3]xf32>
// CHECK: %[[IDXVEC0:.*]] = vector.extract %[[IDXVEC]][0] : vector<[3]xindex> from vector<2x[3]xindex>
// CHECK: %[[MASK0:.*]] = vector.extract %[[MASK]][0] : vector<[3]xi1> from vector<2x[3]xi1>
// CHECK: %[[PASS0:.*]] = vector.extract %[[PASS]][0] : vector<[3]xf32> from vector<2x[3]xf32>
>From 8c4e7488a8aca41bac7a651d508f069b48509a6d Mon Sep 17 00:00:00 2001
From: Yang Bai <baiyang0132 at gmail.com>
Date: Thu, 14 Aug 2025 10:52:19 +0800
Subject: [PATCH 7/8] Update
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Co-authored-by: James Newling <james.newling at gmail.com>
---
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index fb8a5b436797d..0de435e4d77dd 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1745,7 +1745,7 @@ func.func @step() -> vector<4xindex> {
// vector.from_elements
//===----------------------------------------------------------------------===//
-// NOTE: For now, we unroll multi-dimensional from_elements ops with pattern `UnrollFromElements`
+// NOTE: We unroll multi-dimensional from_elements ops with pattern `UnrollFromElements`
// and then convert the 1-D from_elements ops to llvm.
// CHECK-LABEL: func @from_elements_3d
>From bc5ad743b5dce4f4adf76ee0d98db4669bcefe44 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Wed, 13 Aug 2025 19:58:16 -0700
Subject: [PATCH 8/8] refine according to comments from reviewers
---
.../Vector/Transforms/LoweringPatterns.h | 4 +--
.../mlir/Dialect/Vector/Utils/VectorUtils.h | 26 ++-----------------
.../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 2 +-
.../TransformOps/VectorTransformOps.cpp | 2 +-
...LowerVectorToFromElementsToShuffleTree.cpp | 2 +-
mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 26 +++++++++++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 2 +-
7 files changed, 34 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 8c2cafe83c791..47f96112a9433 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -308,8 +308,8 @@ void populateVectorToFromElementsToShuffleTreePatterns(
/// [UnrollFromElements]
/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
/// outermost dimension.
-void populateVectorFromElementsUnrollingPatterns(RewritePatternSet &patterns,
- PatternBenefit benefit = 1);
+void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Populate the pattern set with the following patterns:
///
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 8309cdde6ad76..2699d9acec00b 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -252,30 +252,8 @@ LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
using UnrollVectorOpFn =
function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;
-template <typename VectorOpType>
-LogicalResult unrollVectorOp(VectorOpType op, PatternRewriter &rewriter,
- UnrollVectorOpFn unrollFn) {
- VectorType resultTy = op.getType();
- if (resultTy.getRank() < 2)
- return rewriter.notifyMatchFailure(op, "already 1-D");
-
- // Unrolling doesn't take vscale into account. Pattern is disabled for
- // vectors with leading scalable dim(s).
- if (resultTy.getScalableDims().front())
- return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
-
- Location loc = op.getLoc();
- Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
- VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
-
- for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
- Value subVector = unrollFn(rewriter, loc, subTy, i);
- result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
- }
-
- rewriter.replaceOp(op, result);
- return success();
-}
+LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
+ UnrollVectorOpFn unrollFn);
} // namespace vector
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7ac3bd4aee937..9852df6970fdc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -94,7 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
- populateVectorFromElementsUnrollingPatterns(patterns);
+ populateVectorFromElementsLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index e6917c03d3b26..fe066dc04ad55 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -141,7 +141,7 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns(
void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
- vector::populateVectorFromElementsUnrollingPatterns(patterns);
+ vector::populateVectorFromElementsLoweringPatterns(patterns);
}
void transform::ApplyLowerScanPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
index 3ed81fecefc41..c82507cc09e23 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -780,7 +780,7 @@ void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns(
benefit);
}
-void mlir::vector::populateVectorFromElementsUnrollingPatterns(
+void mlir::vector::populateVectorFromElementsLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
}
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 10ed2bcfb35a3..e887bdf7b8709 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -391,3 +391,29 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
}
return success();
}
+
+LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
+ vector::UnrollVectorOpFn unrollFn) {
+ assert(op->getNumResults() == 1 && "expected single result");
+ assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type");
+ VectorType resultTy = cast<VectorType>(op->getResult(0).getType());
+ if (resultTy.getRank() < 2)
+ return rewriter.notifyMatchFailure(op, "already 1-D");
+
+ // Unrolling doesn't take vscale into account. Pattern is disabled for
+ // vectors with leading scalable dim(s).
+ if (resultTy.getScalableDims().front())
+ return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
+ Location loc = op->getLoc();
+ Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
+ VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
+
+ for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+ Value subVector = unrollFn(rewriter, loc, subTy, i);
+ result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index dd35ad11e80ac..bb1598ee3efe5 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -803,7 +803,7 @@ struct TestUnrollVectorFromElements
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- populateVectorFromElementsUnrollingPatterns(patterns);
+ populateVectorFromElementsLoweringPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
More information about the Mlir-commits
mailing list