[Mlir-commits] [mlir] [mlir][vector] Follow-up improvements for multi-dimensional vector.from_elements support (PR #154664)

Yang Bai llvmlistbot at llvm.org
Fri Aug 22 00:01:53 PDT 2025


https://github.com/yangtetris updated https://github.com/llvm/llvm-project/pull/154664

>From 9471daa510a13aa76853ed64dceeccaa5c9f636c Mon Sep 17 00:00:00 2001
From: Yang Bai <yangb at nvidia.com>
Date: Wed, 20 Aug 2025 20:24:22 -0700
Subject: [PATCH 1/2] [mlir] support ND->1D flattening for vector.from_elements
 op

---
 .../Vector/Transforms/VectorLinearize.cpp     | 39 ++++++++++++++++++-
 mlir/test/Dialect/Vector/linearize.mlir       | 14 +++++++
 2 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 491b448e9e1e9..2cb6d47f37128 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -762,6 +762,42 @@ struct LinearizeVectorStore final
   }
 };
 
+/// This pattern linearizes `vector.from_elements` by rebuilding the op with
+/// the 1-D result type given by the type converter, keeping every element
+/// operand as is. The pattern emits no `vector.shape_cast` itself; the cast in
+/// the example presumably comes from the conversion driver -- TODO confirm.
+///
+/// Example:
+///
+///     %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
+///
+///   is converted to:
+///
+///     %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32>
+///     %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
+///
+struct LinearizeVectorFromElements final
+    : public OpConversionPattern<vector::FromElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorFromElements(const TypeConverter &typeConverter,
+                              MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+  LogicalResult
+  matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType dstTy =
+        getTypeConverter()->convertType<VectorType>(fromElementsOp.getType());
+    assert(dstTy && "vector type destination expected.");
+    // Forward the original element operands; only the result type changes.
+    auto elements = fromElementsOp.getElements();
+    assert(elements.size() == static_cast<size_t>(dstTy.getNumElements()) &&
+           "expected same number of elements");
+    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(fromElementsOp, dstTy,
+                                                        elements);
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -854,7 +890,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
            LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
-           LinearizeVectorStore>(typeConverter, patterns.getContext());
+           LinearizeVectorStore, LinearizeVectorFromElements>(
+          typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 2e630bf93622e..5e8bfd0698b33 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -524,3 +524,17 @@ func.func @linearize_vector_store_scalable(%arg0: memref<2x8xf32>, %arg1: vector
   vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
   return
 }
+
+// -----
+
+// Test pattern LinearizeVectorFromElements: a 2x2 from_elements is expected to
+// become a 4-element 1-D from_elements that is shape_cast back to 2x2.
+// CHECK-LABEL: test_vector_from_elements
+// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32
+func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> {
+  // CHECK: %[[FROM_ELEMENTS:.*]] = vector.from_elements %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]] : vector<4xf32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[FROM_ELEMENTS]] : vector<4xf32> to vector<2x2xf32>
+  // CHECK: return %[[CAST]] : vector<2x2xf32>
+  %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
+  return %1 : vector<2x2xf32>
+}

>From 5aa6bfa807d49b37e7b6689b426cd96bd2735419 Mon Sep 17 00:00:00 2001
From: Yang Bai <baiyang0132 at gmail.com>
Date: Fri, 22 Aug 2025 15:01:45 +0800
Subject: [PATCH 2/2] Update
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Co-authored-by: James Newling <james.newling at gmail.com>
---
 mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 2cb6d47f37128..1c8750e33c475 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -771,7 +771,7 @@ struct LinearizeVectorStore final
 ///
 ///     %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
 ///
-///   is converted to:
+/// is converted to:
 ///
 ///     %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32>
 ///     %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>



More information about the Mlir-commits mailing list