[Mlir-commits] [mlir] [mlir][vector] Support multi-dimensional vectors in VectorFromElementsLowering (PR #151175)

Yang Bai llvmlistbot at llvm.org
Wed Aug 13 20:25:34 PDT 2025


https://github.com/yangtetris updated https://github.com/llvm/llvm-project/pull/151175

>From cf8174fcde57f8092409171ae19444886aca3625 Mon Sep 17 00:00:00 2001
From: yangtetris <baiyang0132 at gmail.com>
Date: Tue, 29 Jul 2025 23:41:22 +0800
Subject: [PATCH 1/8] [mlir] Support lowering multi-dim vectors in
 VectorFromElementsLowering

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 63 ++++++++++++++++---
 .../vector-to-llvm-interface.mlir             | 24 +++++++
 2 files changed, 79 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 17a79e3815b97..26d056cadb19c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1890,15 +1890,62 @@ struct VectorFromElementsLowering
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = fromElementsOp.getLoc();
     VectorType vectorType = fromElementsOp.getType();
-    // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
-    // Such ops should be handled in the same way as vector.insert.
-    if (vectorType.getRank() > 1)
-      return rewriter.notifyMatchFailure(fromElementsOp,
-                                         "rank > 1 vectors are not supported");
     Type llvmType = typeConverter->convertType(vectorType);
-    Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
-    for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
-      result = vector::InsertOp::create(rewriter, loc, val, result, idx);
+    Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
+
+    Value result;
+    // 0D vectors are converted to legnth-1 1D vectors by LLVMTypeConverter.
+    if (vectorType.getRank() == 0) {
+      result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+      auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
+      result = LLVM::InsertElementOp::create(rewriter, loc, result, adaptor.getElements().front(), index0);
+      rewriter.replaceOp(fromElementsOp, result);
+      return success();
+    }
+    
+    // Build 1D vectors for the innermost dimension
+    int64_t innerDimSize = vectorType.getShape().back();
+    int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
+
+    SmallVector<Value> innerVectors;
+    innerVectors.reserve(numInnerVectors);
+
+    auto innerVectorType = VectorType::get(innerDimSize, vectorType.getElementType());
+    Type llvmInnerType = typeConverter->convertType(innerVectorType);
+
+    int64_t elementInVectorIdx = 0;
+    Value innerVector;
+    for (auto val : adaptor.getElements()) {
+      if (elementInVectorIdx == 0)
+        innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
+      auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, elementInVectorIdx);
+      innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, innerVector, val, position);
+      if (++elementInVectorIdx == innerDimSize) {
+        innerVectors.push_back(innerVector);
+        elementInVectorIdx = 0;
+      }
+    }
+
+    // For 1D vectors, we can just return the first innermost vector.
+    if (vectorType.getRank() == 1) {
+      rewriter.replaceOp(fromElementsOp, innerVectors.front());
+      return success();
+    }
+
+    // Now build the nested aggregate structure from these 1D vectors.
+    result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+    
+    // Use the same iteration approach as VectorBroadcastScalarToNdLowering to
+    // insert the 1D vectors into the aggregate.
+    auto vectorTypeInfo = LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
+    if (!vectorTypeInfo.llvmNDVectorTy)
+      return failure();
+    int64_t vectorIdx = 0;
+    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
+      result = LLVM::InsertValueOp::create(rewriter, loc, result, 
+                                           innerVectors[vectorIdx++], position);
+    });
+    
     rewriter.replaceOp(fromElementsOp, result);
     return success();
   }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 31e17fb3e3cc6..834858c0b7c8f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2286,6 +2286,30 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
 
 // -----
 
+// CHECK-LABEL: func.func @from_elements_3d(
+//  CHECK-SAME:     %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32)
+//       CHECK:   %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32>
+//       CHECK:   %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32>
+//       CHECK:   %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
+//       CHECK:   %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32>
+//       CHECK:   %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32>
+//       CHECK:   %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32>
+//       CHECK:   %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
+//       CHECK:   %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32>
+//       CHECK:   %[[UNDEF_RES:.*]] = llvm.mlir.poison : !llvm.array<2 x array<1 x vector<2xf32>>>
+//       CHECK:   %[[RES_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_RES]][0, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+//       CHECK:   %[[RES_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[RES_0]][1, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32>
+//       CHECK:   return %[[CAST]]
+func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+  %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+  return %0 : vector<2x1x2xf32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.to_elements
 //===----------------------------------------------------------------------===//

>From 8fe493014531f8a227a090e82088393f6df185f8 Mon Sep 17 00:00:00 2001
From: yangtetris <baiyang0132 at gmail.com>
Date: Tue, 29 Jul 2025 23:53:07 +0800
Subject: [PATCH 2/8] Fix code formatting

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 23 +++++++++++--------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 26d056cadb19c..59a09be7738e8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1898,11 +1898,12 @@ struct VectorFromElementsLowering
     if (vectorType.getRank() == 0) {
       result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
       auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
-      result = LLVM::InsertElementOp::create(rewriter, loc, result, adaptor.getElements().front(), index0);
+      result = LLVM::InsertElementOp::create(
+          rewriter, loc, result, adaptor.getElements().front(), index0);
       rewriter.replaceOp(fromElementsOp, result);
       return success();
     }
-    
+
     // Build 1D vectors for the innermost dimension
     int64_t innerDimSize = vectorType.getShape().back();
     int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
@@ -1910,7 +1911,8 @@ struct VectorFromElementsLowering
     SmallVector<Value> innerVectors;
     innerVectors.reserve(numInnerVectors);
 
-    auto innerVectorType = VectorType::get(innerDimSize, vectorType.getElementType());
+    auto innerVectorType =
+        VectorType::get(innerDimSize, vectorType.getElementType());
     Type llvmInnerType = typeConverter->convertType(innerVectorType);
 
     int64_t elementInVectorIdx = 0;
@@ -1918,8 +1920,10 @@ struct VectorFromElementsLowering
     for (auto val : adaptor.getElements()) {
       if (elementInVectorIdx == 0)
         innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
-      auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, elementInVectorIdx);
-      innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, innerVector, val, position);
+      auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType,
+                                               elementInVectorIdx);
+      innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType,
+                                                  innerVector, val, position);
       if (++elementInVectorIdx == innerDimSize) {
         innerVectors.push_back(innerVector);
         elementInVectorIdx = 0;
@@ -1934,18 +1938,19 @@ struct VectorFromElementsLowering
 
     // Now build the nested aggregate structure from these 1D vectors.
     result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
-    
+
     // Use the same iteration approach as VectorBroadcastScalarToNdLowering to
     // insert the 1D vectors into the aggregate.
-    auto vectorTypeInfo = LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
+    auto vectorTypeInfo =
+        LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
     if (!vectorTypeInfo.llvmNDVectorTy)
       return failure();
     int64_t vectorIdx = 0;
     nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
-      result = LLVM::InsertValueOp::create(rewriter, loc, result, 
+      result = LLVM::InsertValueOp::create(rewriter, loc, result,
                                            innerVectors[vectorIdx++], position);
     });
-    
+
     rewriter.replaceOp(fromElementsOp, result);
     return success();
   }

>From 44211421a2c3d0b8b09d19be6bcb1fd20b6fb1c9 Mon Sep 17 00:00:00 2001
From: Yang Bai <baiyang0132 at gmail.com>
Date: Thu, 31 Jul 2025 10:23:23 +0800
Subject: [PATCH 3/8] Fix typo

Co-authored-by: Nicolas Vasilache <Nico.Vasilache at amd.com>
---
 mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 59a09be7738e8..1006605bd9130 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1894,7 +1894,7 @@ struct VectorFromElementsLowering
     Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
 
     Value result;
-    // 0D vectors are converted to legnth-1 1D vectors by LLVMTypeConverter.
+    // 0D vectors are converted to length-1 1D vectors by LLVMTypeConverter.
     if (vectorType.getRank() == 0) {
       result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
       auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);

>From eec412bc4b1b58f0a79b93561f5cc218ed1b824f Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Wed, 30 Jul 2025 23:53:52 -0700
Subject: [PATCH 4/8] Refine: add nDVectorIterate overload and simplify inner-vector loop

---
 .../Conversion/LLVMCommon/VectorPattern.h     |  7 ++++
 .../Conversion/LLVMCommon/VectorPattern.cpp   | 10 ++++++
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 33 ++++++++++---------
 3 files changed, 34 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 964281592cc65..36dcffc79974d 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -49,6 +49,13 @@ SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
 void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
                      function_ref<void(ArrayRef<int64_t>)> fun);
 
+// Overload that accepts VectorType directly and extracts type info internally.
+// Returns failure if the vector type info extraction fails.
+LogicalResult nDVectorIterate(VectorType vectorType,
+                              const LLVMTypeConverter &converter,
+                              OpBuilder &builder,
+                              function_ref<void(ArrayRef<int64_t>)> fun);
+
 LogicalResult handleMultidimensionalVectors(
     Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
     std::function<Value(Type, ValueRange)> createOperand,
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index e7dd0b506e12d..adc7c9f1551e7 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -77,6 +77,16 @@ void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
   }
 }
 
+LogicalResult LLVM::detail::nDVectorIterate(
+    VectorType vectorType, const LLVMTypeConverter &converter,
+    OpBuilder &builder, function_ref<void(ArrayRef<int64_t>)> fun) {
+  auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, converter);
+  if (!vectorTypeInfo.llvmNDVectorTy)
+    return failure();
+  nDVectorIterate(vectorTypeInfo, builder, fun);
+  return success();
+}
+
 LogicalResult LLVM::detail::handleMultidimensionalVectors(
     Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
     std::function<Value(Type, ValueRange)> createOperand,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 1006605bd9130..137cc7a14c7e0 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1904,8 +1904,10 @@ struct VectorFromElementsLowering
       return success();
     }
 
-    // Build 1D vectors for the innermost dimension
+    // Build 1D vectors for the innermost dimension.
     int64_t innerDimSize = vectorType.getShape().back();
+    assert(vectorType.getNumElements() % innerDimSize == 0 &&
+           "innerDimSize must divide vectorType.getNumElements()");
     int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
 
     SmallVector<Value> innerVectors;
@@ -1915,23 +1917,23 @@ struct VectorFromElementsLowering
         VectorType::get(innerDimSize, vectorType.getElementType());
     Type llvmInnerType = typeConverter->convertType(innerVectorType);
 
-    int64_t elementInVectorIdx = 0;
     Value innerVector;
-    for (auto val : adaptor.getElements()) {
+    for (auto [elemIdx, val] : llvm::enumerate(adaptor.getElements())) {
+      int64_t elementInVectorIdx = elemIdx % innerDimSize;
       if (elementInVectorIdx == 0)
         innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
       auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType,
                                                elementInVectorIdx);
       innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType,
                                                   innerVector, val, position);
-      if (++elementInVectorIdx == innerDimSize) {
+      if (elementInVectorIdx == innerDimSize - 1)
         innerVectors.push_back(innerVector);
-        elementInVectorIdx = 0;
-      }
     }
 
     // For 1D vectors, we can just return the first innermost vector.
     if (vectorType.getRank() == 1) {
+      assert(innerVectors.size() == 1 &&
+             "for 1D vectors, innerVectors should have exactly one element");
       rewriter.replaceOp(fromElementsOp, innerVectors.front());
       return success();
     }
@@ -1939,17 +1941,16 @@ struct VectorFromElementsLowering
     // Now build the nested aggregate structure from these 1D vectors.
     result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
 
-    // Use the same iteration approach as VectorBroadcastScalarToNdLowering to
-    // insert the 1D vectors into the aggregate.
-    auto vectorTypeInfo =
-        LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
-    if (!vectorTypeInfo.llvmNDVectorTy)
-      return failure();
+    // Iterate over each position of the first n-1 dimensions and insert the 1D
+    // vectors into the aggregate.
     int64_t vectorIdx = 0;
-    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
-      result = LLVM::InsertValueOp::create(rewriter, loc, result,
-                                           innerVectors[vectorIdx++], position);
-    });
+    if (failed(LLVM::detail::nDVectorIterate(
+            vectorType, *getTypeConverter(), rewriter,
+            [&](ArrayRef<int64_t> position) {
+              result = LLVM::InsertValueOp::create(
+                  rewriter, loc, result, innerVectors[vectorIdx++], position);
+            })))
+      return failure();
 
     rewriter.replaceOp(fromElementsOp, result);
     return success();

>From d23ddd2d75db8ce51038399f03950245ad3f6723 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Tue, 12 Aug 2025 08:30:12 -0700
Subject: [PATCH 5/8] Re-implement with unrolling transformation

---
 .../Conversion/LLVMCommon/VectorPattern.h     |  7 --
 .../Vector/TransformOps/VectorTransformOps.td | 11 +++
 .../Vector/Transforms/LoweringPatterns.h      |  8 +++
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   | 39 ++++++++++
 .../Conversion/LLVMCommon/VectorPattern.cpp   | 10 ---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 71 ++++---------------
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  1 +
 .../TransformOps/VectorTransformOps.cpp       |  5 ++
 .../Vector/Transforms/LowerVectorGather.cpp   | 33 +++------
 ...LowerVectorToFromElementsToShuffleTree.cpp | 42 +++++++++++
 .../vector-to-llvm-interface.mlir             | 24 -------
 .../VectorToLLVM/vector-to-llvm.mlir          | 37 ++++++++++
 .../Vector/vector-from-elements-lowering.mlir | 45 ++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   | 24 +++++++
 .../python/dialects/transform_vector_ext.py   |  2 +
 15 files changed, 234 insertions(+), 125 deletions(-)
 create mode 100644 mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 36dcffc79974d..964281592cc65 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -49,13 +49,6 @@ SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
 void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
                      function_ref<void(ArrayRef<int64_t>)> fun);
 
-// Overload that accepts VectorType directly and extracts type info internally.
-// Returns failure if the vector type info extraction fails.
-LogicalResult nDVectorIterate(VectorType vectorType,
-                              const LLVMTypeConverter &converter,
-                              OpBuilder &builder,
-                              function_ref<void(ArrayRef<int64_t>)> fun);
-
 LogicalResult handleMultidimensionalVectors(
     Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
     std::function<Value(Type, ValueRange)> createOperand,
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 299f198e4ab9c..07a4117a37b2c 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -254,6 +254,17 @@ def ApplyLowerGatherPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyUnrollFromElementsPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.unroll_from_elements",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector from_elements operations should be unrolled
+    along the outermost dimension.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def ApplyLowerScanPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_scan",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index e03f0dabece52..8c2cafe83c791 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -303,6 +303,14 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
 void populateVectorToFromElementsToShuffleTreePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollFromElements]
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension.
+void populateVectorFromElementsUnrollingPatterns(RewritePatternSet &patterns,
+                                                 PatternBenefit benefit = 1);
+
 /// Populate the pattern set with the following patterns:
 ///
 /// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 7cd70e42d363c..8309cdde6ad76 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -238,6 +239,44 @@ Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
 ///      static sizes in `shape`.
 LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
                                        ArrayRef<int64_t> inputVectorSizes);
+
+/// Generic utility for unrolling n-D vector operations to (n-1)-D operations.
+/// This handles the common pattern of:
+/// 1. Check if already 1-D. If so, return failure.
+/// 2. Check for scalable dimensions. If so, return failure.
+/// 3. Create poison initialized result.
+/// 4. Loop through the outermost dimension, execute the UnrollVectorOpFn to
+/// create sub vectors.
+/// 5. Insert the sub vectors back into the final vector.
+/// 6. Replace the original op with the new result.
+using UnrollVectorOpFn =
+    function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;
+
+template <typename VectorOpType>
+LogicalResult unrollVectorOp(VectorOpType op, PatternRewriter &rewriter,
+                             UnrollVectorOpFn unrollFn) {
+  VectorType resultTy = op.getType();
+  if (resultTy.getRank() < 2)
+    return rewriter.notifyMatchFailure(op, "already 1-D");
+
+  // Unrolling doesn't take vscale into account. Pattern is disabled for
+  // vectors with leading scalable dim(s).
+  if (resultTy.getScalableDims().front())
+    return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
+  Location loc = op.getLoc();
+  Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
+  VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
+
+  for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+    Value subVector = unrollFn(rewriter, loc, subTy, i);
+    result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+  }
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 } // namespace vector
 
 /// Constructs a permutation map of invariant memref indices to vector
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index adc7c9f1551e7..e7dd0b506e12d 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -77,16 +77,6 @@ void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
   }
 }
 
-LogicalResult LLVM::detail::nDVectorIterate(
-    VectorType vectorType, const LLVMTypeConverter &converter,
-    OpBuilder &builder, function_ref<void(ArrayRef<int64_t>)> fun) {
-  auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, converter);
-  if (!vectorTypeInfo.llvmNDVectorTy)
-    return failure();
-  nDVectorIterate(vectorTypeInfo, builder, fun);
-  return success();
-}
-
 LogicalResult LLVM::detail::handleMultidimensionalVectors(
     Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
     std::function<Value(Type, ValueRange)> createOperand,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 137cc7a14c7e0..b44df3f0320e8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1890,68 +1890,21 @@ struct VectorFromElementsLowering
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = fromElementsOp.getLoc();
     VectorType vectorType = fromElementsOp.getType();
+    // Only support 1-D vectors. Multi-dimensional vectors should have been
+    // transformed to 1-D vectors by the vector-to-vector transformations before
+    // this.
+    if (vectorType.getRank() > 1)
+      return rewriter.notifyMatchFailure(fromElementsOp,
+                                         "rank > 1 vectors are not supported");
     Type llvmType = typeConverter->convertType(vectorType);
     Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
-
-    Value result;
-    // 0D vectors are converted to length-1 1D vectors by LLVMTypeConverter.
-    if (vectorType.getRank() == 0) {
-      result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
-      auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
-      result = LLVM::InsertElementOp::create(
-          rewriter, loc, result, adaptor.getElements().front(), index0);
-      rewriter.replaceOp(fromElementsOp, result);
-      return success();
+    Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+    for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
+      auto constIdx =
+          LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
+      result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result,
+                                             val, constIdx);
     }
-
-    // Build 1D vectors for the innermost dimension.
-    int64_t innerDimSize = vectorType.getShape().back();
-    assert(vectorType.getNumElements() % innerDimSize == 0 &&
-           "innerDimSize must divide vectorType.getNumElements()");
-    int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
-
-    SmallVector<Value> innerVectors;
-    innerVectors.reserve(numInnerVectors);
-
-    auto innerVectorType =
-        VectorType::get(innerDimSize, vectorType.getElementType());
-    Type llvmInnerType = typeConverter->convertType(innerVectorType);
-
-    Value innerVector;
-    for (auto [elemIdx, val] : llvm::enumerate(adaptor.getElements())) {
-      int64_t elementInVectorIdx = elemIdx % innerDimSize;
-      if (elementInVectorIdx == 0)
-        innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
-      auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType,
-                                               elementInVectorIdx);
-      innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType,
-                                                  innerVector, val, position);
-      if (elementInVectorIdx == innerDimSize - 1)
-        innerVectors.push_back(innerVector);
-    }
-
-    // For 1D vectors, we can just return the first innermost vector.
-    if (vectorType.getRank() == 1) {
-      assert(innerVectors.size() == 1 &&
-             "for 1D vectors, innerVectors should have exactly one element");
-      rewriter.replaceOp(fromElementsOp, innerVectors.front());
-      return success();
-    }
-
-    // Now build the nested aggregate structure from these 1D vectors.
-    result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
-
-    // Iterate over each position of the first n-1 dimensions and insert the 1D
-    // vectors into the aggregate.
-    int64_t vectorIdx = 0;
-    if (failed(LLVM::detail::nDVectorIterate(
-            vectorType, *getTypeConverter(), rewriter,
-            [&](ArrayRef<int64_t> position) {
-              result = LLVM::InsertValueOp::create(
-                  rewriter, loc, result, innerVectors[vectorIdx++], position);
-            })))
-      return failure();
-
     rewriter.replaceOp(fromElementsOp, result);
     return success();
   }
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index cf108690c3741..7ac3bd4aee937 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
+    populateVectorFromElementsUnrollingPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
         arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 2d5cc070558c3..e6917c03d3b26 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -139,6 +139,11 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns(
   vector::populateVectorGatherLoweringPatterns(patterns);
 }
 
+void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateVectorFromElementsUnrollingPatterns(patterns);
+}
+
 void transform::ApplyLowerScanPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   vector::populateVectorScanLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index e062f55f87679..90f21c53246b0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -54,27 +54,13 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
 
   LogicalResult matchAndRewrite(vector::GatherOp op,
                                 PatternRewriter &rewriter) const override {
-    VectorType resultTy = op.getType();
-    if (resultTy.getRank() < 2)
-      return rewriter.notifyMatchFailure(op, "already 1-D");
-
-    // Unrolling doesn't take vscale into account. Pattern is disabled for
-    // vectors with leading scalable dim(s).
-    if (resultTy.getScalableDims().front())
-      return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
-
-    Location loc = op.getLoc();
     Value indexVec = op.getIndexVec();
     Value maskVec = op.getMask();
     Value passThruVec = op.getPassThru();
 
-    Value result = arith::ConstantOp::create(rewriter, loc, resultTy,
-                                             rewriter.getZeroAttr(resultTy));
-
-    VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
-
-    for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
-      int64_t thisIdx[1] = {i};
+    auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
+                              VectorType subTy, int64_t index) {
+      int64_t thisIdx[1] = {index};
 
       Value indexSubVec =
           vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
@@ -82,15 +68,12 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
           vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
       Value passThruSubVec =
           vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
-      Value subGather = vector::GatherOp::create(
-          rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec,
-          maskSubVec, passThruSubVec);
-      result =
-          vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx);
-    }
+      return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
+                                      op.getIndices(), indexSubVec, maskSubVec,
+                                      passThruSubVec);
+    };
 
-    rewriter.replaceOp(op, result);
-    return success();
+    return unrollVectorOp(op, rewriter, unrollGatherFn);
   }
 };
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
index 6407a868abd85..3ed81fecefc41 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -735,6 +735,43 @@ struct LowerVectorToFromElementsToShuffleTreePass
   }
 };
 
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
+///
+/// ==>
+///
+/// %0   = ub.poison : vector<2x3xf32>
+/// %v0  = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
+/// %1   = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
+/// %v1  = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
+/// %v   = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
+/// ```
+/// Sub-vector `i` is built from the contiguous element slice [i*N, (i+1)*N),
+/// where N is its element count. Applied exhaustively, this produces 1-D
+/// from_elements ops (`unrollVectorOp` rejects rank < 2 / scalable lead dims).
+struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::FromElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    ValueRange allElements = op.getElements();
+
+    auto unrollFromElementsFn = [&](PatternRewriter &b, Location loc,
+                                    VectorType subTy, int64_t index) {
+      int64_t subTyNumElements = subTy.getNumElements();
+      assert((index + 1) * subTyNumElements <=
+                 static_cast<int64_t>(allElements.size()) && "out of bounds");
+      ValueRange subElements =
+          allElements.slice(index * subTyNumElements, subTyNumElements);
+      return vector::FromElementsOp::create(b, loc, subTy, subElements);
+    };
+
+    return unrollVectorOp(op, rewriter, unrollFromElementsFn);
+  }
+};
+
 } // namespace
 
 void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns(
@@ -742,3 +779,8 @@ void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns(
   patterns.add<ToFromElementsToShuffleTreeRewrite>(patterns.getContext(),
                                                    benefit);
 }
+
+void mlir::vector::populateVectorFromElementsUnrollingPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 834858c0b7c8f..31e17fb3e3cc6 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2286,30 +2286,6 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
 
 // -----
 
-// CHECK-LABEL: func.func @from_elements_3d(
-//  CHECK-SAME:     %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32)
-//       CHECK:   %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32>
-//       CHECK:   %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64
-//       CHECK:   %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32>
-//       CHECK:   %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
-//       CHECK:   %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32>
-//       CHECK:   %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32>
-//       CHECK:   %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
-//       CHECK:   %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32>
-//       CHECK:   %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
-//       CHECK:   %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32>
-//       CHECK:   %[[UNDEF_RES:.*]] = llvm.mlir.poison : !llvm.array<2 x array<1 x vector<2xf32>>>
-//       CHECK:   %[[RES_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_RES]][0, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
-//       CHECK:   %[[RES_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[RES_0]][1, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
-//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32>
-//       CHECK:   return %[[CAST]]
-func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
-  %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
-  return %0 : vector<2x1x2xf32>
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // vector.to_elements
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 72810b5dddaa3..fb8a5b436797d 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1737,3 +1737,40 @@ func.func @step() -> vector<4xindex> {
   %0 = vector.step : vector<4xindex>
   return %0 : vector<4xindex>
 }
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.from_elements
+//===----------------------------------------------------------------------===//
+
+// NOTE: For now, we unroll multi-dimensional from_elements ops with pattern `UnrollFromElements`
+// and then convert the 1-D from_elements ops to llvm.
+
+// CHECK-LABEL: func @from_elements_3d
+//  CHECK-SAME:  %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32)
+//       CHECK:  %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32>
+//       CHECK:  %[[UNDEF_RES_LLVM:.*]] = builtin.unrealized_conversion_cast %[[UNDEF_RES]] : vector<2x1x2xf32> to !llvm.array<2 x array<1 x vector<2xf32>>>
+//       CHECK:  %[[UNDEF_VEC_RANK_2:.*]] = ub.poison : vector<1x2xf32>
+//       CHECK:  %[[UNDEF_VEC_RANK_2_LLVM:.*]] = builtin.unrealized_conversion_cast %[[UNDEF_VEC_RANK_2]] : vector<1x2xf32> to !llvm.array<1 x vector<2xf32>>
+//       CHECK:  %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32>
+//       CHECK:  %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:  %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32>
+//       CHECK:  %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
+//       CHECK:  %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32>
+//       CHECK:  %[[RES_RANK_2_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_VEC_RANK_2_LLVM]][0] : !llvm.array<1 x vector<2xf32>>
+//       CHECK:  %[[RES_0:.*]] = llvm.insertvalue %[[RES_RANK_2_0]], %[[UNDEF_RES_LLVM]][0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+//       CHECK:  %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32>
+//       CHECK:  %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:  %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32>
+//       CHECK:  %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
+//       CHECK:  %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32>
+//       CHECK:  %[[RES_RANK_2_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[UNDEF_VEC_RANK_2_LLVM]][0] : !llvm.array<1 x vector<2xf32>>
+//       CHECK:  %[[RES_1:.*]] = llvm.insertvalue %[[RES_RANK_2_1]], %[[RES_0]][1] : !llvm.array<2 x array<1 x vector<2xf32>>>
+//       CHECK:  %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32>
+//       CHECK:  return %[[CAST]]
+func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+  %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+  return %0 : vector<2x1x2xf32>
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir
new file mode 100644
index 0000000000000..1c2e07086d093
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -test-unroll-vector-from-elements | FileCheck %s --check-prefix=CHECK-UNROLL
+
+//===----------------------------------------------------------------------===//
+// Test UnrollFromElements.
+//===----------------------------------------------------------------------===//
+
+// CHECK-UNROLL-LABEL: @unroll_from_elements_2d
+// CHECK-UNROLL-SAME:    (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+// CHECK-UNROLL-NEXT:    %[[UNDEF_RES:.*]] = ub.poison : vector<2x2xf32>
+// CHECK-UNROLL-NEXT:    %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+// CHECK-UNROLL-NEXT:    %[[RES_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RES]] [0] : vector<2xf32> into vector<2x2xf32>
+// CHECK-UNROLL-NEXT:    %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
+// CHECK-UNROLL-NEXT:    %[[RES_1:.*]] = vector.insert %[[VEC_1]], %[[RES_0]] [1] : vector<2xf32> into vector<2x2xf32>
+// CHECK-UNROLL-NEXT:    return %[[RES_1]] : vector<2x2xf32>
+func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> {
+  %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
+
+// CHECK-UNROLL-LABEL: @unroll_from_elements_3d
+// CHECK-UNROLL-SAME:    (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+// CHECK-UNROLL-NEXT:    %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32>
+// CHECK-UNROLL-NEXT:    %[[UNDEF_RANK_2:.*]] = ub.poison : vector<1x2xf32>
+// CHECK-UNROLL-NEXT:    %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+// CHECK-UNROLL-NEXT:    %[[RANK_2_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
+// CHECK-UNROLL-NEXT:    %[[RES_0:.*]] = vector.insert %[[RANK_2_0]], %[[UNDEF_RES]] [0] : vector<1x2xf32> into vector<2x1x2xf32>
+// CHECK-UNROLL-NEXT:    %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
+// CHECK-UNROLL-NEXT:    %[[RANK_2_1:.*]] = vector.insert %[[VEC_1]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32>
+// CHECK-UNROLL-NEXT:    %[[RES_1:.*]] = vector.insert %[[RANK_2_1]], %[[RES_0]] [1] : vector<1x2xf32> into vector<2x1x2xf32>
+// CHECK-UNROLL-NEXT:    return %[[RES_1]] : vector<2x1x2xf32>
+func.func @unroll_from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+  %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+  return %0 : vector<2x1x2xf32>
+}
+
+// 1-D vector.from_elements should not be unrolled.
+
+// CHECK-UNROLL-LABEL: @negative_unroll_from_elements_1d
+// CHECK-UNROLL-SAME:    (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK-UNROLL-NEXT:         %[[RES:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+// CHECK-UNROLL-NEXT:    return %[[RES]] : vector<2xf32>
+func.func @negative_unroll_from_elements_1d(%arg0: f32, %arg1: f32) -> vector<2xf32> {
+  %0 = vector.from_elements %arg0, %arg1 : vector<2xf32>
+  return %0 : vector<2xf32>
+}
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index f89c944b5c564..dd35ad11e80ac 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -786,6 +786,28 @@ struct TestVectorGatherLowering
   }
 };
 
+struct TestUnrollVectorFromElements
+    : public PassWrapper<TestUnrollVectorFromElements,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorFromElements)
+
+  StringRef getArgument() const final {
+    return "test-unroll-vector-from-elements";
+  }
+  StringRef getDescription() const final {
+    return "Test unrolling patterns for from_elements ops";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<func::FuncDialect, vector::VectorDialect, ub::UBDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorFromElementsUnrollingPatterns(patterns);
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestFoldArithExtensionIntoVectorContractPatterns
     : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
                          OperationPass<func::FuncOp>> {
@@ -1059,6 +1081,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorGatherLowering>();
 
+  PassRegistration<TestUnrollVectorFromElements>();
+
   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 
   PassRegistration<TestVectorEmulateMaskedLoadStore>();
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index a51f2154d1f7d..5a648fe073315 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -46,6 +46,8 @@ def non_configurable_patterns():
     vector.ApplyLowerOuterProductPatternsOp()
     # CHECK: transform.apply_patterns.vector.lower_gather
     vector.ApplyLowerGatherPatternsOp()
+    # CHECK: transform.apply_patterns.vector.unroll_from_elements
+    vector.ApplyUnrollFromElementsPatternsOp()
     # CHECK: transform.apply_patterns.vector.lower_scan
     vector.ApplyLowerScanPatternsOp()
     # CHECK: transform.apply_patterns.vector.lower_shape_cast

>From 1a5b07570337ed21f008139351f5abdae4a6a3cd Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Tue, 12 Aug 2025 09:24:28 -0700
Subject: [PATCH 6/8] fix test

---
 mlir/test/Dialect/Vector/vector-gather-lowering.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 5be267c1be984..9c2a508671e06 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -81,7 +81,7 @@ func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %
 // CHECK-SAME:      %[[PASS:.*]]: vector<2x[3]xf32>
 // CHECK:         %[[C0:.*]] = arith.constant 0 : index
 // CHECK:         %[[C1:.*]] = arith.constant 1 : index
-// CHECK:         %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x[3]xf32>
+// CHECK:         %[[INIT:.*]] = ub.poison : vector<2x[3]xf32>
 // CHECK:         %[[IDXVEC0:.*]] = vector.extract %[[IDXVEC]][0] : vector<[3]xindex> from vector<2x[3]xindex>
 // CHECK:         %[[MASK0:.*]] = vector.extract %[[MASK]][0] : vector<[3]xi1> from vector<2x[3]xi1>
 // CHECK:         %[[PASS0:.*]] = vector.extract %[[PASS]][0] : vector<[3]xf32> from vector<2x[3]xf32>

>From 8c4e7488a8aca41bac7a651d508f069b48509a6d Mon Sep 17 00:00:00 2001
From: Yang Bai <baiyang0132 at gmail.com>
Date: Thu, 14 Aug 2025 10:52:19 +0800
Subject: [PATCH 7/8] Update
 mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Co-authored-by: James Newling <james.newling at gmail.com>
---
 mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index fb8a5b436797d..0de435e4d77dd 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1745,7 +1745,7 @@ func.func @step() -> vector<4xindex> {
 // vector.from_elements
 //===----------------------------------------------------------------------===//
 
-// NOTE: For now, we unroll multi-dimensional from_elements ops with pattern `UnrollFromElements`
+// NOTE: We unroll multi-dimensional from_elements ops with pattern `UnrollFromElements`
 // and then convert the 1-D from_elements ops to llvm.
 
 // CHECK-LABEL: func @from_elements_3d

>From bc5ad743b5dce4f4adf76ee0d98db4669bcefe44 Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Wed, 13 Aug 2025 19:58:16 -0700
Subject: [PATCH 8/8] refine according to comments from reviewers

---
 .../Vector/Transforms/LoweringPatterns.h      |  4 +--
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   | 26 ++-----------------
 .../VectorToLLVM/ConvertVectorToLLVMPass.cpp  |  2 +-
 .../TransformOps/VectorTransformOps.cpp       |  2 +-
 ...LowerVectorToFromElementsToShuffleTree.cpp |  2 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 26 +++++++++++++++++++
 .../Dialect/Vector/TestVectorTransforms.cpp   |  2 +-
 7 files changed, 34 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 8c2cafe83c791..47f96112a9433 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -308,8 +308,8 @@ void populateVectorToFromElementsToShuffleTreePatterns(
 /// [UnrollFromElements]
 /// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
 /// outermost dimension.
-void populateVectorFromElementsUnrollingPatterns(RewritePatternSet &patterns,
-                                                 PatternBenefit benefit = 1);
+void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
+                                                PatternBenefit benefit = 1);
 
 /// Populate the pattern set with the following patterns:
 ///
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 8309cdde6ad76..2699d9acec00b 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -252,30 +252,8 @@ LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
 using UnrollVectorOpFn =
     function_ref<Value(PatternRewriter &, Location, VectorType, int64_t)>;
 
-template <typename VectorOpType>
-LogicalResult unrollVectorOp(VectorOpType op, PatternRewriter &rewriter,
-                             UnrollVectorOpFn unrollFn) {
-  VectorType resultTy = op.getType();
-  if (resultTy.getRank() < 2)
-    return rewriter.notifyMatchFailure(op, "already 1-D");
-
-  // Unrolling doesn't take vscale into account. Pattern is disabled for
-  // vectors with leading scalable dim(s).
-  if (resultTy.getScalableDims().front())
-    return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
-
-  Location loc = op.getLoc();
-  Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
-  VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
-
-  for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
-    Value subVector = unrollFn(rewriter, loc, subTy, i);
-    result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
-  }
-
-  rewriter.replaceOp(op, result);
-  return success();
-}
+LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter,
+                             UnrollVectorOpFn unrollFn);
 
 } // namespace vector
 
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7ac3bd4aee937..9852df6970fdc 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -94,7 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorStepLoweringPatterns(patterns);
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
-    populateVectorFromElementsUnrollingPatterns(patterns);
+    populateVectorFromElementsLoweringPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
         arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index e6917c03d3b26..fe066dc04ad55 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -141,7 +141,7 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns(
 
 void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  vector::populateVectorFromElementsUnrollingPatterns(patterns);
+  vector::populateVectorFromElementsLoweringPatterns(patterns);
 }
 
 void transform::ApplyLowerScanPatternsOp::populatePatterns(
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
index 3ed81fecefc41..c82507cc09e23 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -780,7 +780,7 @@ void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns(
                                                    benefit);
 }
 
-void mlir::vector::populateVectorFromElementsUnrollingPatterns(
+void mlir::vector::populateVectorFromElementsLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
 }
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 10ed2bcfb35a3..e887bdf7b8709 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -391,3 +391,29 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
   }
   return success();
 }
+
+LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
+                                     vector::UnrollVectorOpFn unrollFn) {
+  assert(op->getNumResults() == 1 && "expected single result");
+  // `cast<>` already asserts (in debug builds) that the result is a vector.
+  auto resultTy = cast<VectorType>(op->getResult(0).getType());
+  if (resultTy.getRank() < 2)
+    return rewriter.notifyMatchFailure(op, "already 1-D");
+
+  // Unrolling doesn't take vscale into account. Pattern is disabled for
+  // vectors with leading scalable dim(s).
+  if (resultTy.getScalableDims().front())
+    return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
+  Location loc = op->getLoc();
+  Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
+  VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
+
+  for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+    Value subVector = unrollFn(rewriter, loc, subTy, i);
+    result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+  }
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index dd35ad11e80ac..bb1598ee3efe5 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -803,7 +803,7 @@ struct TestUnrollVectorFromElements
 
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    populateVectorFromElementsUnrollingPatterns(patterns);
+    populateVectorFromElementsLoweringPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 };



More information about the Mlir-commits mailing list