[Mlir-commits] [mlir] [mlir][Vector] Support scalar 'vector.insert' in vector linearization (PR #146954)
Diego Caballero
llvmlistbot at llvm.org
Thu Jul 3 12:43:26 PDT 2025
https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/146954
This PR adds support for linearizing the insertion of a scalar element by just linearizing the `vector.insert` op.
>From e7a32be59e78d022073c67dcaa7edc913d845e0e Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Thu, 3 Jul 2025 04:19:28 +0000
Subject: [PATCH] [mlir][Vector] Support scalar 'vector.insert' in vector
linearization
Inserting a scalar element was not supported.
---
.../Vector/Transforms/VectorLinearize.cpp | 83 ++++++++++++-------
mlir/test/Dialect/Vector/linearize.mlir | 16 +++-
2 files changed, 70 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..7cac1cbafdd64 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -445,16 +445,36 @@ struct LinearizeVectorExtract final
}
};
-/// This pattern converts the InsertOp to a ShuffleOp that works on a
-/// linearized vector.
-/// Following,
-/// vector.insert %source %destination [ position ]
-/// is converted to :
-/// %source_1d = vector.shape_cast %source
-/// %destination_1d = vector.shape_cast %destination
-/// %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
-/// ] %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the position of the original insert.
+/// This pattern linearizes `vector.insert` operations. It generates a 1-D
+/// version of the `vector.insert` operation when inserting a scalar into a
+/// vector. It generates a 1-D `vector.shuffle` operation when inserting a
+/// vector into another vector.
+///
+/// Example #1:
+///
+/// %0 = vector.insert %source, %destination[0] :
+/// vector<2x4xf32> into vector<2x2x4xf32>
+///
+/// is converted to:
+///
+/// %0 = vector.shape_cast %source : vector<2x4xf32> to vector<8xf32>
+/// %1 = vector.shape_cast %destination :
+/// vector<2x2x4xf32> to vector<16xf32>
+/// %2 = vector.shuffle %1, %0 [16, 17, 18, 19, 20, 21, 22, 23,
+/// 8, 9, 10, 11, 12, 13, 14, 15] :
+/// vector<16xf32>, vector<8xf32>
+/// %3 = vector.shape_cast %2 : vector<16xf32> to vector<2x2x4xf32>
+///
+/// Example #2:
+///
+/// %0 = vector.insert %source, %destination[1, 2]: f32 into vector<2x4xf32>
+///
+/// is converted to:
+///
+/// %0 = vector.shape_cast %destination : vector<2x4xf32> to vector<8xf32>
+/// %1 = vector.insert %source, %0[6]: f32 into vector<8xf32>
+/// %2 = vector.shape_cast %1 : vector<8xf32> to vector<2x4xf32>
+///
struct LinearizeVectorInsert final
: public OpConversionPattern<vector::InsertOp> {
using OpConversionPattern::OpConversionPattern;
@@ -468,25 +488,19 @@ struct LinearizeVectorInsert final
insertOp.getDestVectorType());
assert(dstTy && "vector type destination expected.");
- // dynamic position is not supported
+ // Dynamic position is not supported.
if (insertOp.hasDynamicPosition())
return rewriter.notifyMatchFailure(insertOp,
"dynamic position is not supported.");
auto srcTy = insertOp.getValueToStoreType();
auto srcAsVec = dyn_cast<VectorType>(srcTy);
- uint64_t srcSize = 0;
- if (srcAsVec) {
- srcSize = srcAsVec.getNumElements();
- } else {
- return rewriter.notifyMatchFailure(insertOp,
- "scalars are not supported.");
- }
+ uint64_t srcSize = srcAsVec ? srcAsVec.getNumElements() : 1;
auto dstShape = insertOp.getDestVectorType().getShape();
const auto dstSize = insertOp.getDestVectorType().getNumElements();
auto dstSizeForOffsets = dstSize;
- // compute linearized offset
+ // Compute linearized offset.
int64_t linearizedOffset = 0;
auto offsetsNd = insertOp.getStaticPosition();
for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
@@ -494,22 +508,35 @@ struct LinearizeVectorInsert final
linearizedOffset += offset * dstSizeForOffsets;
}
+ Location loc = insertOp.getLoc();
+ Value valueToStore = adaptor.getValueToStore();
+
+ if (!isa<VectorType>(valueToStore.getType())) {
+ // Scalar case: generate a 1-D insert.
+ Value result = rewriter.createOrFold<vector::InsertOp>(
+ loc, valueToStore, adaptor.getDest(), linearizedOffset);
+ rewriter.replaceOp(insertOp, result);
+ return success();
+ }
+
+ // Vector case: generate a shuffle.
llvm::SmallVector<int64_t, 2> indices(dstSize);
auto *origValsUntil = indices.begin();
std::advance(origValsUntil, linearizedOffset);
- std::iota(indices.begin(), origValsUntil,
- 0); // original values that remain [0, offset)
+
+ // Original values that remain [0, offset).
+ std::iota(indices.begin(), origValsUntil, 0);
auto *newValsUntil = origValsUntil;
std::advance(newValsUntil, srcSize);
- std::iota(origValsUntil, newValsUntil,
- dstSize); // new values [offset, offset+srcNumElements)
- std::iota(newValsUntil, indices.end(),
- linearizedOffset + srcSize); // the rest of original values
- // [offset+srcNumElements, end)
+ // New values [offset, offset+srcNumElements).
+ std::iota(origValsUntil, newValsUntil, dstSize);
+ // The rest of original values [offset+srcNumElements, end).
+ std::iota(newValsUntil, indices.end(), linearizedOffset + srcSize);
- rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
- insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices);
+ Value result = rewriter.createOrFold<vector::ShuffleOp>(
+ loc, dstTy, adaptor.getDest(), valueToStore, indices);
+ rewriter.replaceOp(insertOp, result);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9cbf319ffddb2..894171500d9d6 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -294,6 +294,20 @@ func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x
// -----
+// CHECK-LABEL: test_vector_insert_scalar
+// CHECK-SAME: (%[[DEST:.*]]: vector<2x4xf32>, %[[SRC:.*]]: f32) -> vector<2x4xf32> {
+func.func @test_vector_insert_scalar(%arg0: vector<2x4xf32>, %arg1: f32) -> vector<2x4xf32> {
+
+ // CHECK: %[[DEST_1D:.*]] = vector.shape_cast %[[DEST]] : vector<2x4xf32> to vector<8xf32>
+ // CHECK: %[[INSERT_1D:.*]] = vector.insert %[[SRC]], %[[DEST_1D]] [6] : f32 into vector<8xf32>
+ // CHECK: %[[RES:.*]] = vector.shape_cast %[[INSERT_1D]] : vector<8xf32> to vector<2x4xf32>
+ // CHECK: return %[[RES]] : vector<2x4xf32>
+ %0 = vector.insert %arg1, %arg0[1, 2]: f32 into vector<2x4xf32>
+ return %0 : vector<2x4xf32>
+}
+
+// -----
+
// CHECK-LABEL: test_vector_insert
// CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
@@ -444,7 +458,7 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
// CHECK-LABEL: linearize_create_mask
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
-
+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
// CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index
More information about the Mlir-commits
mailing list