[Mlir-commits] [mlir] [MLIR][Vector]: Generalize conversion of `vector.insert` to LLVM in line with `vector.extract` (PR #128915)

Benoit Jacob llvmlistbot at llvm.org
Wed Feb 26 09:21:45 PST 2025


https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/128915

This does for `vector.insert` what https://github.com/llvm/llvm-project/pull/117731 did for `vector.extract`: it generalizes the conversion to LLVM.

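It is a bit more complicated than the `vector.extract` case, because the 1-D vector that serves as the insertion destination may itself first need to be extracted from the enclosing aggregate.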

As the test shows, this fixes two previously unsupported cases:
- Dynamic indices
- 0-D vectors.

From d18ba6757b8b012e1290b3f015e9bf187fc973e0 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 26 Feb 2025 11:18:11 -0600
Subject: [PATCH] like-117731-but-for-insert

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 111 ++++++++++--------
 .../vector-to-llvm-interface.mlir             |  22 +++-
 2 files changed, 78 insertions(+), 55 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c9d637ce81f93..e1c7547774c3b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -35,13 +35,6 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-// Helper to reduce vector type by *all* but one rank at back.
-static VectorType reducedVectorTypeBack(VectorType tp) {
-  assert((tp.getRank() > 1) && "unlowerable vector type");
-  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
-                         tp.getScalableDims().take_back());
-}
-
 // Helper that picks the proper sequence for inserting.
 static Value insertOne(ConversionPatternRewriter &rewriter,
                        const LLVMTypeConverter &typeConverter, Location loc,
@@ -1223,7 +1216,6 @@ class VectorInsertOpConversion
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = insertOp->getLoc();
-    auto sourceType = insertOp.getSourceType();
     auto destVectorType = insertOp.getDestVectorType();
     auto llvmResultType = typeConverter->convertType(destVectorType);
     // Bail if result type cannot be lowered.
@@ -1233,53 +1225,74 @@ class VectorInsertOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
-    // Overwrite entire vector with value. Should be handled by folder, but
-    // just to be safe.
-    ArrayRef<OpFoldResult> position(positionVec);
-    if (position.empty()) {
-      rewriter.replaceOp(insertOp, adaptor.getSource());
-      return success();
-    }
-
-    // One-shot insertion of a vector into an array (only requires insertvalue).
-    if (isa<VectorType>(sourceType)) {
-      if (insertOp.hasDynamicPosition())
-        return failure();
-
-      Value inserted = rewriter.create<LLVM::InsertValueOp>(
-          loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
-      rewriter.replaceOp(insertOp, inserted);
-      return success();
+    // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
+    // its explanatory comment about how N-D vectors are converted as nested
+    // aggregates (llvm.array's) of 1D vectors.
+    //
+    // There are 3 steps here, vs 2 in VectorExtractOpConversion:
+    // - Extraction of a 1D vector from the nested aggregate: llvm.extractvalue.
+    // - Insertion into the 1D vector: llvm.insertelement.
+    // - Insertion of the 1D vector into the nested aggregate: llvm.insertvalue.
+
+    // Determine if we need to extract/insert a 1D vector out of the aggregate.
+    bool is1DVectorWithinAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
+    // Determine if we need to insert a scalar into the 1D vector.
+    bool isScalarWithin1DVector =
+        static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
+
+    ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
+        positionVec.begin(),
+        isScalarWithin1DVector ? positionVec.size() - 1 : positionVec.size());
+    OpFoldResult positionOfScalarWithin1DVector;
+    if (destVectorType.getRank() == 0) {
+      // Since the LLVM type converter converts 0D vectors to 1D vectors, we
+      // need to create a 0 here as the position into the 1D vector.
+      Type idxType = typeConverter->convertType(rewriter.getIndexType());
+      positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
+    } else if (isScalarWithin1DVector) {
+      positionOfScalarWithin1DVector = positionVec.back();
     }
 
-    // Potential extraction of 1-D vector from array.
-    Value extracted = adaptor.getDest();
-    auto oneDVectorType = destVectorType;
-    if (position.size() > 1) {
-      if (insertOp.hasDynamicPosition())
-        return failure();
-
-      oneDVectorType = reducedVectorTypeBack(destVectorType);
-      extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, extracted, getAsIntegers(position.drop_back()));
+    // We are going to mutate this 1D vector until it is either the final
+    // result (in the non-aggregate case) or the value that needs to be
+    // inserted into the aggregate result.
+    Value vector1d;
+    if (isScalarWithin1DVector) {
+      // Scalar-into-1D-vector case, so we know we will have to create a
+      // InsertElementOp. The question is into what destination.
+      if (is1DVectorWithinAggregate) {
+        // Aggregate case: the destination for the InsertElementOp needs to be
+        // extracted from the aggregate.
+        if (!llvm::all_of(positionOf1DVectorWithinAggregate,
+                          llvm::IsaPred<Attribute>)) {
+          // llvm.extractvalue does not support dynamic dimensions.
+          return failure();
+        }
+        vector1d = rewriter.create<LLVM::ExtractValueOp>(
+            loc, adaptor.getDest(),
+            getAsIntegers(positionOf1DVectorWithinAggregate));
+      } else {
+        // No-aggregate case. The destination for the InsertElementOp is just
+        // the insertOp's destination.
+        vector1d = adaptor.getDest();
+      }
+      // Insert the scalar into the 1D vector.
+      vector1d = rewriter.create<LLVM::InsertElementOp>(
+          loc, vector1d.getType(), vector1d, adaptor.getSource(),
+          getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
+    } else {
+      // No scalar insertion. The 1D vector is just the source.
+      vector1d = adaptor.getSource();
     }
 
-    // Insertion of an element into a 1-D LLVM vector.
-    Value inserted = rewriter.create<LLVM::InsertElementOp>(
-        loc, typeConverter->convertType(oneDVectorType), extracted,
-        adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
-
-    // Potential insertion of resulting 1-D vector into array.
-    if (position.size() > 1) {
-      if (insertOp.hasDynamicPosition())
-        return failure();
-
-      inserted = rewriter.create<LLVM::InsertValueOp>(
-          loc, adaptor.getDest(), inserted,
-          getAsIntegers(position.drop_back()));
+    Value result = vector1d;
+    if (is1DVectorWithinAggregate) {
+      result = rewriter.create<LLVM::InsertValueOp>(
+          loc, adaptor.getDest(), vector1d,
+          getAsIntegers(positionOf1DVectorWithinAggregate));
     }
 
-    rewriter.replaceOp(insertOp, inserted);
+    rewriter.replaceOp(insertOp, result);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index fa7c030538401..7e60e62363ceb 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -628,6 +628,16 @@ func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(%arg0: f
 // vector.insert
 //===----------------------------------------------------------------------===//
 
+func.func @insert_scalar_into_vec_0d(%src: f32, %dst: vector<f32>) -> vector<f32> {
+  %0 = vector.insert %src, %dst[] : f32 into vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @insert_scalar_into_vec_0d
+//       CHECK: llvm.insertelement {{.*}} : vector<1xf32>
+
+// -----
+
 func.func @insert_scalar_into_vec_1d_f32(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.insert %arg0, %arg1[3] : f32 into vector<4xf32>
   return %0 : vector<4xf32>
@@ -780,10 +790,10 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %a
   return %0 : vector<1x16xf32>
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
-
 // CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx(
-//       CHECK:   vector.insert
+//       CHECK:   llvm.extractvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
+//       CHECK:   llvm.insertelement {{.*}} : vector<16xf32>
+//       CHECK:   llvm.insertvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
 
 // -----
 
@@ -793,10 +803,10 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
   return %0 : vector<1x[16]xf32>
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
-
 // CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(
-//       CHECK:   vector.insert
+//       CHECK:   llvm.extractvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
+//       CHECK:   llvm.insertelement {{.*}} : vector<[16]xf32>
+//       CHECK:   llvm.insertvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
 
 // -----
 



More information about the Mlir-commits mailing list