[Mlir-commits] [mlir] [mlir][vector] Support multi-dimensional vectors in VectorFromElementsLowering (PR #151175)

Yang Bai llvmlistbot at llvm.org
Tue Jul 29 08:48:47 PDT 2025


https://github.com/yangtetris created https://github.com/llvm/llvm-project/pull/151175

This patch extends the `VectorFromElementsLowering` conversion pattern to support 
vectors of any rank, removing the previous restriction to 1D vectors only.

**Implementation Details:**
1. **0D vectors**: Handled explicitly since LLVMTypeConverter converts them to 
   length-1 1D vectors
2. **1D vectors**: Direct construction using llvm.insertelement operations
3. **N-D vectors**: Two-phase construction:
   - Build 1D vectors for the innermost dimension using `llvm.insertelement`
   - Assemble them into the nested aggregate structure using `llvm.insertvalue` 
     and `nDVectorIterate`
4. **Efficiency**: The lowering emits LLVM dialect operations directly instead of going through intermediate `vector.insert` operations

**Example:**
```mlir
// Before: Failed for rank > 1
%v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32>

// After: Converts to nested aggregate (%c0 and %c1 are i64 constants 0 and 1)
%poison_1d = llvm.mlir.poison : vector<2xf32>
%inner0_0 = llvm.insertelement %e0, %poison_1d[%c0 : i64] : vector<2xf32>
%inner0 = llvm.insertelement %e1, %inner0_0[%c1 : i64] : vector<2xf32>
%inner1_0 = llvm.insertelement %e2, %poison_1d[%c0 : i64] : vector<2xf32>
%inner1 = llvm.insertelement %e3, %inner1_0[%c1 : i64] : vector<2xf32>
%poison = llvm.mlir.poison : !llvm.array<2 x vector<2xf32>>
%result_0 = llvm.insertvalue %inner0, %poison[0] : !llvm.array<2 x vector<2xf32>>
%result = llvm.insertvalue %inner1, %result_0[1] : !llvm.array<2 x vector<2xf32>>
```

>From cf8174fcde57f8092409171ae19444886aca3625 Mon Sep 17 00:00:00 2001
From: yangtetris <baiyang0132 at gmail.com>
Date: Tue, 29 Jul 2025 23:41:22 +0800
Subject: [PATCH] [mlir] Support lowering multi-dim vectors in
 VectorFromElementsLowering

---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 63 ++++++++++++++++---
 .../vector-to-llvm-interface.mlir             | 24 +++++++
 2 files changed, 79 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 17a79e3815b97..26d056cadb19c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1890,15 +1890,62 @@ struct VectorFromElementsLowering
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = fromElementsOp.getLoc();
     VectorType vectorType = fromElementsOp.getType();
-    // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
-    // Such ops should be handled in the same way as vector.insert.
-    if (vectorType.getRank() > 1)
-      return rewriter.notifyMatchFailure(fromElementsOp,
-                                         "rank > 1 vectors are not supported");
     Type llvmType = typeConverter->convertType(vectorType);
-    Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
-    for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
-      result = vector::InsertOp::create(rewriter, loc, val, result, idx);
+    Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
+
+    Value result;
+    // 0D vectors are converted to length-1 1D vectors by LLVMTypeConverter.
+    if (vectorType.getRank() == 0) {
+      result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+      auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
+      result = LLVM::InsertElementOp::create(rewriter, loc, result, adaptor.getElements().front(), index0);
+      rewriter.replaceOp(fromElementsOp, result);
+      return success();
+    }
+    
+    // Build 1D vectors for the innermost dimension
+    int64_t innerDimSize = vectorType.getShape().back();
+    int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
+
+    SmallVector<Value> innerVectors;
+    innerVectors.reserve(numInnerVectors);
+
+    auto innerVectorType = VectorType::get(innerDimSize, vectorType.getElementType());
+    Type llvmInnerType = typeConverter->convertType(innerVectorType);
+
+    int64_t elementInVectorIdx = 0;
+    Value innerVector;
+    for (auto val : adaptor.getElements()) {
+      if (elementInVectorIdx == 0)
+        innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
+      auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, elementInVectorIdx);
+      innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, innerVector, val, position);
+      if (++elementInVectorIdx == innerDimSize) {
+        innerVectors.push_back(innerVector);
+        elementInVectorIdx = 0;
+      }
+    }
+
+    // For 1D vectors, we can just return the first innermost vector.
+    if (vectorType.getRank() == 1) {
+      rewriter.replaceOp(fromElementsOp, innerVectors.front());
+      return success();
+    }
+
+    // Now build the nested aggregate structure from these 1D vectors.
+    result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+    
+    // Use the same iteration approach as VectorBroadcastScalarToNdLowering to
+    // insert the 1D vectors into the aggregate.
+    auto vectorTypeInfo = LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
+    if (!vectorTypeInfo.llvmNDVectorTy)
+      return failure();
+    int64_t vectorIdx = 0;
+    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
+      result = LLVM::InsertValueOp::create(rewriter, loc, result, 
+                                           innerVectors[vectorIdx++], position);
+    });
+    
     rewriter.replaceOp(fromElementsOp, result);
     return success();
   }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 31e17fb3e3cc6..834858c0b7c8f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2286,6 +2286,30 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
 
 // -----
 
+// CHECK-LABEL: func.func @from_elements_3d(
+//  CHECK-SAME:     %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32)
+//       CHECK:   %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32>
+//       CHECK:   %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32>
+//       CHECK:   %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
+//       CHECK:   %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32>
+//       CHECK:   %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32>
+//       CHECK:   %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32>
+//       CHECK:   %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
+//       CHECK:   %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32>
+//       CHECK:   %[[UNDEF_RES:.*]] = llvm.mlir.poison : !llvm.array<2 x array<1 x vector<2xf32>>>
+//       CHECK:   %[[RES_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_RES]][0, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+//       CHECK:   %[[RES_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[RES_0]][1, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32>
+//       CHECK:   return %[[CAST]]
+func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+  %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+  return %0 : vector<2x1x2xf32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.to_elements
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list