[Mlir-commits] [mlir] [mlir][Vector] Support scalar 'vector.insert' in vector linearization (PR #146954)

Diego Caballero llvmlistbot at llvm.org
Thu Jul 3 12:43:26 PDT 2025


https://github.com/dcaballe created https://github.com/llvm/llvm-project/pull/146954

This PR adds support for linearizing the insertion of a scalar element by simply linearizing the `vector.insert` op itself.

>From e7a32be59e78d022073c67dcaa7edc913d845e0e Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Thu, 3 Jul 2025 04:19:28 +0000
Subject: [PATCH] [mlir][Vector] Support scalar 'vector.insert' in vector
 linearization

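Previously, inserting a scalar element into a vector was not supported by the linearization pattern. Scalar inserts are now lowered to a 1-D `vector.insert` on the linearized destination.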
---
 .../Vector/Transforms/VectorLinearize.cpp     | 83 ++++++++++++-------
 mlir/test/Dialect/Vector/linearize.mlir       | 16 +++-
 2 files changed, 70 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..7cac1cbafdd64 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -445,16 +445,36 @@ struct LinearizeVectorExtract final
   }
 };
 
-/// This pattern converts the InsertOp to a ShuffleOp that works on a
-/// linearized vector.
-/// Following,
-///   vector.insert %source %destination [ position ]
-/// is converted to :
-///   %source_1d = vector.shape_cast %source
-///   %destination_1d = vector.shape_cast %destination
-///   %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
-///   ] %out_nd = vector.shape_cast %out_1d
-/// `shuffle_indices_1d` is computed using the position of the original insert.
+/// This pattern linearizes `vector.insert` operations. It generates a 1-D
+/// version of the `vector.insert` operation when inserting a scalar into a
+/// vector. It generates a 1-D `vector.shuffle` operation when inserting a
+/// vector into another vector.
+///
+/// Example #1:
+///
+///     %0 = vector.insert %source, %destination[0] :
+///       vector<2x4xf32> into vector<2x2x4xf32>
+///
+///   is converted to:
+///
+///     %0 = vector.shape_cast %source : vector<2x4xf32> to vector<8xf32>
+///     %1 = vector.shape_cast %destination :
+///            vector<2x2x4xf32> to vector<16xf32>
+///     %2 = vector.shuffle %1, %0 [16, 17, 18, 19, 20, 21, 22, 23,
+///                                  8, 9, 10, 11, 12, 13, 14, 15] :
+///            vector<16xf32>, vector<8xf32>
+///     %3 = vector.shape_cast %2 : vector<16xf32> to vector<2x2x4xf32>
+///
+/// Example #2:
+///
+///     %0 = vector.insert %source, %destination[1, 2]: f32 into vector<2x4xf32>
+///
+///   is converted to:
+///
+///     %0 = vector.shape_cast %destination : vector<2x4xf32> to vector<8xf32>
+///     %1 = vector.insert %source, %0[6]: f32 into vector<8xf32>
+///     %2 = vector.shape_cast %1 : vector<8xf32> to vector<2x4xf32>
+///
 struct LinearizeVectorInsert final
     : public OpConversionPattern<vector::InsertOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -468,25 +488,19 @@ struct LinearizeVectorInsert final
         insertOp.getDestVectorType());
     assert(dstTy && "vector type destination expected.");
 
-    // dynamic position is not supported
+    // Dynamic position is not supported.
     if (insertOp.hasDynamicPosition())
       return rewriter.notifyMatchFailure(insertOp,
                                          "dynamic position is not supported.");
     auto srcTy = insertOp.getValueToStoreType();
     auto srcAsVec = dyn_cast<VectorType>(srcTy);
-    uint64_t srcSize = 0;
-    if (srcAsVec) {
-      srcSize = srcAsVec.getNumElements();
-    } else {
-      return rewriter.notifyMatchFailure(insertOp,
-                                         "scalars are not supported.");
-    }
+    uint64_t srcSize = srcAsVec ? srcAsVec.getNumElements() : 1;
 
     auto dstShape = insertOp.getDestVectorType().getShape();
     const auto dstSize = insertOp.getDestVectorType().getNumElements();
     auto dstSizeForOffsets = dstSize;
 
-    // compute linearized offset
+    // Compute linearized offset.
     int64_t linearizedOffset = 0;
     auto offsetsNd = insertOp.getStaticPosition();
     for (auto [dim, offset] : llvm::enumerate(offsetsNd)) {
@@ -494,22 +508,35 @@ struct LinearizeVectorInsert final
       linearizedOffset += offset * dstSizeForOffsets;
     }
 
+    Location loc = insertOp.getLoc();
+    Value valueToStore = adaptor.getValueToStore();
+
+    if (!isa<VectorType>(valueToStore.getType())) {
+      // Scalar case: generate a 1-D insert.
+      Value result = rewriter.createOrFold<vector::InsertOp>(
+          loc, valueToStore, adaptor.getDest(), linearizedOffset);
+      rewriter.replaceOp(insertOp, result);
+      return success();
+    }
+
+    // Vector case: generate a shuffle.
     llvm::SmallVector<int64_t, 2> indices(dstSize);
     auto *origValsUntil = indices.begin();
     std::advance(origValsUntil, linearizedOffset);
-    std::iota(indices.begin(), origValsUntil,
-              0); // original values that remain [0, offset)
+
+    // Original values that remain [0, offset).
+    std::iota(indices.begin(), origValsUntil, 0);
     auto *newValsUntil = origValsUntil;
     std::advance(newValsUntil, srcSize);
-    std::iota(origValsUntil, newValsUntil,
-              dstSize); // new values [offset, offset+srcNumElements)
-    std::iota(newValsUntil, indices.end(),
-              linearizedOffset + srcSize); // the rest of original values
-                                           // [offset+srcNumElements, end)
+    // New values [offset, offset+srcNumElements).
+    std::iota(origValsUntil, newValsUntil, dstSize);
+    // The rest of original values [offset+srcNumElements, end).
+    std::iota(newValsUntil, indices.end(), linearizedOffset + srcSize);
 
-    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        insertOp, dstTy, adaptor.getDest(), adaptor.getValueToStore(), indices);
+    Value result = rewriter.createOrFold<vector::ShuffleOp>(
+        loc, dstTy, adaptor.getDest(), valueToStore, indices);
 
+    rewriter.replaceOp(insertOp, result);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9cbf319ffddb2..894171500d9d6 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -294,6 +294,20 @@ func.func @test_vector_extract_scalable(%arg0: vector<2x8x[2]xf32>) -> vector<8x
 
 // -----
 
+// CHECK-LABEL: test_vector_insert_scalar
+// CHECK-SAME: (%[[DEST:.*]]: vector<2x4xf32>, %[[SRC:.*]]: f32) -> vector<2x4xf32> {
+func.func @test_vector_insert_scalar(%arg0: vector<2x4xf32>, %arg1: f32) -> vector<2x4xf32> {
+
+  // CHECK: %[[DEST_1D:.*]] = vector.shape_cast %[[DEST]] : vector<2x4xf32> to vector<8xf32>
+  // CHECK: %[[INSERT_1D:.*]] = vector.insert %[[SRC]], %[[DEST_1D]] [6] : f32 into vector<8xf32>
+  // CHECK: %[[RES:.*]] = vector.shape_cast %[[INSERT_1D]] : vector<8xf32> to vector<2x4xf32>
+  // CHECK: return %[[RES]] : vector<2x4xf32>
+  %0 = vector.insert %arg1, %arg0[1, 2]: f32 into vector<2x4xf32>
+  return %0 : vector<2x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: test_vector_insert
 // CHECK-SAME: (%[[DEST:.*]]: vector<2x8x4xf32>, %[[SRC:.*]]: vector<8x4xf32>) -> vector<2x8x4xf32> {
 func.func @test_vector_insert(%arg0: vector<2x8x4xf32>, %arg1: vector<8x4xf32>) -> vector<2x8x4xf32> {
@@ -444,7 +458,7 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
 // CHECK-LABEL: linearize_create_mask
 // CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
 func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
-  
+
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
   // CHECK: %[[CMP:.*]] = arith.cmpi sgt, %[[ARG0]], %[[C0]] : index
   // CHECK: %[[INDEXCAST:.*]] = arith.index_cast %[[CMP]] : i1 to index



More information about the Mlir-commits mailing list