[Mlir-commits] [mlir] [MLIR][Vector]: Generalize conversion of `vector.insert` to LLVM in line with `vector.extract` (PR #128915)

Benoit Jacob llvmlistbot at llvm.org
Wed Feb 26 12:13:59 PST 2025


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/128915

From d18ba6757b8b012e1290b3f015e9bf187fc973e0 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 26 Feb 2025 11:18:11 -0600
Subject: [PATCH 1/2] like-117731-but-for-insert

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 111 ++++++++++--------
 .../vector-to-llvm-interface.mlir             |  22 +++-
 2 files changed, 78 insertions(+), 55 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c9d637ce81f93..e1c7547774c3b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -35,13 +35,6 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-// Helper to reduce vector type by *all* but one rank at back.
-static VectorType reducedVectorTypeBack(VectorType tp) {
-  assert((tp.getRank() > 1) && "unlowerable vector type");
-  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
-                         tp.getScalableDims().take_back());
-}
-
 // Helper that picks the proper sequence for inserting.
 static Value insertOne(ConversionPatternRewriter &rewriter,
                        const LLVMTypeConverter &typeConverter, Location loc,
@@ -1223,7 +1216,6 @@ class VectorInsertOpConversion
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = insertOp->getLoc();
-    auto sourceType = insertOp.getSourceType();
     auto destVectorType = insertOp.getDestVectorType();
     auto llvmResultType = typeConverter->convertType(destVectorType);
     // Bail if result type cannot be lowered.
@@ -1233,53 +1225,74 @@ class VectorInsertOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
-    // Overwrite entire vector with value. Should be handled by folder, but
-    // just to be safe.
-    ArrayRef<OpFoldResult> position(positionVec);
-    if (position.empty()) {
-      rewriter.replaceOp(insertOp, adaptor.getSource());
-      return success();
-    }
-
-    // One-shot insertion of a vector into an array (only requires insertvalue).
-    if (isa<VectorType>(sourceType)) {
-      if (insertOp.hasDynamicPosition())
-        return failure();
-
-      Value inserted = rewriter.create<LLVM::InsertValueOp>(
-          loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
-      rewriter.replaceOp(insertOp, inserted);
-      return success();
+    // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
+    // its explanatory comment about how N-D vectors are converted as nested
+    // aggregates (llvm.array's) of 1D vectors.
+    //
+    // There are 3 steps here, vs 2 in VectorExtractOpConversion:
+    // - Extraction of a 1D vector from the nested aggregate: llvm.extractvalue.
+    // - Insertion into the 1D vector: llvm.insertelement.
+    // - Insertion of the 1D vector into the nested aggregate: llvm.insertvalue.
+
+    // Determine if we need to extract/insert a 1D vector out of the aggregate.
+    bool is1DVectorWithinAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
+    // Determine if we need to insert a scalar into the 1D vector.
+    bool isScalarWithin1DVector =
+        static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
+
+    ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
+        positionVec.begin(),
+        isScalarWithin1DVector ? positionVec.size() - 1 : positionVec.size());
+    OpFoldResult positionOfScalarWithin1DVector;
+    if (destVectorType.getRank() == 0) {
+      // Since the LLVM type converter converts 0D vectors to 1D vectors, we
+      // need to create a 0 here as the position into the 1D vector.
+      Type idxType = typeConverter->convertType(rewriter.getIndexType());
+      positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
+    } else if (isScalarWithin1DVector) {
+      positionOfScalarWithin1DVector = positionVec.back();
     }
 
-    // Potential extraction of 1-D vector from array.
-    Value extracted = adaptor.getDest();
-    auto oneDVectorType = destVectorType;
-    if (position.size() > 1) {
-      if (insertOp.hasDynamicPosition())
-        return failure();
-
-      oneDVectorType = reducedVectorTypeBack(destVectorType);
-      extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, extracted, getAsIntegers(position.drop_back()));
+    // We are going to mutate this 1D vector until it is either the final
+    // result (in the non-aggregate case) or the value that needs to be
+    // inserted into the aggregate result.
+    Value vector1d;
+    if (isScalarWithin1DVector) {
+      // Scalar-into-1D-vector case, so we know we will have to create a
+      // InsertElementOp. The question is into what destination.
+      if (is1DVectorWithinAggregate) {
+        // Aggregate case: the destination for the InsertElementOp needs to be
+        // extracted from the aggregate.
+        if (!llvm::all_of(positionOf1DVectorWithinAggregate,
+                          llvm::IsaPred<Attribute>)) {
+          // llvm.extractvalue does not support dynamic dimensions.
+          return failure();
+        }
+        vector1d = rewriter.create<LLVM::ExtractValueOp>(
+            loc, adaptor.getDest(),
+            getAsIntegers(positionOf1DVectorWithinAggregate));
+      } else {
+        // No-aggregate case. The destination for the InsertElementOp is just
+        // the insertOp's destination.
+        vector1d = adaptor.getDest();
+      }
+      // Insert the scalar into the 1D vector.
+      vector1d = rewriter.create<LLVM::InsertElementOp>(
+          loc, vector1d.getType(), vector1d, adaptor.getSource(),
+          getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
+    } else {
+      // No scalar insertion. The 1D vector is just the source.
+      vector1d = adaptor.getSource();
     }
 
-    // Insertion of an element into a 1-D LLVM vector.
-    Value inserted = rewriter.create<LLVM::InsertElementOp>(
-        loc, typeConverter->convertType(oneDVectorType), extracted,
-        adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
-
-    // Potential insertion of resulting 1-D vector into array.
-    if (position.size() > 1) {
-      if (insertOp.hasDynamicPosition())
-        return failure();
-
-      inserted = rewriter.create<LLVM::InsertValueOp>(
-          loc, adaptor.getDest(), inserted,
-          getAsIntegers(position.drop_back()));
+    Value result = vector1d;
+    if (is1DVectorWithinAggregate) {
+      result = rewriter.create<LLVM::InsertValueOp>(
+          loc, adaptor.getDest(), vector1d,
+          getAsIntegers(positionOf1DVectorWithinAggregate));
     }
 
-    rewriter.replaceOp(insertOp, inserted);
+    rewriter.replaceOp(insertOp, result);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index fa7c030538401..7e60e62363ceb 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -628,6 +628,16 @@ func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(%arg0: f
 // vector.insert
 //===----------------------------------------------------------------------===//
 
+func.func @insert_scalar_into_vec_0d(%src: f32, %dst: vector<f32>) -> vector<f32> {
+  %0 = vector.insert %src, %dst[] : f32 into vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @insert_scalar_into_vec_0d
+//       CHECK: llvm.insertelement {{.*}} : vector<1xf32>
+
+// -----
+
 func.func @insert_scalar_into_vec_1d_f32(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.insert %arg0, %arg1[3] : f32 into vector<4xf32>
   return %0 : vector<4xf32>
@@ -780,10 +790,10 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %a
   return %0 : vector<1x16xf32>
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
-
 // CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx(
-//       CHECK:   vector.insert
+//       CHECK:   llvm.extractvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
+//       CHECK:   llvm.insertelement {{.*}} : vector<16xf32>
+//       CHECK:   llvm.insertvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
 
 // -----
 
@@ -793,10 +803,10 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
   return %0 : vector<1x[16]xf32>
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
-
 // CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(
-//       CHECK:   vector.insert
+//       CHECK:   llvm.extractvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
+//       CHECK:   llvm.insertelement {{.*}} : vector<[16]xf32>
+//       CHECK:   llvm.insertvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
 
 // -----
 

From d8e2291ab1afff2cfbeb7f5a4f2a45ebb8739b63 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 26 Feb 2025 14:08:02 -0600
Subject: [PATCH 2/2] review comments

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 49 +++++++++++--------
 1 file changed, 28 insertions(+), 21 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index e1c7547774c3b..94efec61a466c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1229,38 +1229,48 @@ class VectorInsertOpConversion
     // its explanatory comment about how N-D vectors are converted as nested
     // aggregates (llvm.array's) of 1D vectors.
     //
-    // There are 3 steps here, vs 2 in VectorExtractOpConversion:
-    // - Extraction of a 1D vector from the nested aggregate: llvm.extractvalue.
-    // - Insertion into the 1D vector: llvm.insertelement.
-    // - Insertion of the 1D vector into the nested aggregate: llvm.insertvalue.
+    // The innermost dimension of the destination vector, when converted to a
+    // nested aggregate form, will always be a 1D vector.
+    //
+    // * If the insertion is happening into the innermost dimension of the
+    //   destination vector:
+    //   - If the destination is a nested aggregate, extract a 1D vector out of
+    //     the aggregate. This can be done using llvm.extractvalue. The
+    //     destination is now guaranteed to be a 1D vector, to which we are
+    //     inserting.
+    //   - Do the insertion into the 1D destination vector, and make the result
+    //     the new source nested aggregate. This can be done using
+    //     llvm.insertelement.
+    // * Insert the source nested aggregate into the destination nested
+    //   aggregate.
 
     // Determine if we need to extract/insert a 1D vector out of the aggregate.
-    bool is1DVectorWithinAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
+    bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
     // Determine if we need to insert a scalar into the 1D vector.
-    bool isScalarWithin1DVector =
+    bool insertIntoInnermostDim =
         static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
 
     ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
         positionVec.begin(),
-        isScalarWithin1DVector ? positionVec.size() - 1 : positionVec.size());
+        insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
     OpFoldResult positionOfScalarWithin1DVector;
     if (destVectorType.getRank() == 0) {
       // Since the LLVM type converter converts 0D vectors to 1D vectors, we
       // need to create a 0 here as the position into the 1D vector.
       Type idxType = typeConverter->convertType(rewriter.getIndexType());
       positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
-    } else if (isScalarWithin1DVector) {
+    } else if (insertIntoInnermostDim) {
       positionOfScalarWithin1DVector = positionVec.back();
     }
 
     // We are going to mutate this 1D vector until it is either the final
     // result (in the non-aggregate case) or the value that needs to be
     // inserted into the aggregate result.
-    Value vector1d;
-    if (isScalarWithin1DVector) {
+    Value sourceAggregate = adaptor.getSource();
+    if (insertIntoInnermostDim) {
       // Scalar-into-1D-vector case, so we know we will have to create a
       // InsertElementOp. The question is into what destination.
-      if (is1DVectorWithinAggregate) {
+      if (isNestedAggregate) {
         // Aggregate case: the destination for the InsertElementOp needs to be
         // extracted from the aggregate.
         if (!llvm::all_of(positionOf1DVectorWithinAggregate,
@@ -1268,27 +1278,24 @@ class VectorInsertOpConversion
           // llvm.extractvalue does not support dynamic dimensions.
           return failure();
         }
-        vector1d = rewriter.create<LLVM::ExtractValueOp>(
+        sourceAggregate = rewriter.create<LLVM::ExtractValueOp>(
             loc, adaptor.getDest(),
             getAsIntegers(positionOf1DVectorWithinAggregate));
       } else {
         // No-aggregate case. The destination for the InsertElementOp is just
         // the insertOp's destination.
-        vector1d = adaptor.getDest();
+        sourceAggregate = adaptor.getDest();
       }
       // Insert the scalar into the 1D vector.
-      vector1d = rewriter.create<LLVM::InsertElementOp>(
-          loc, vector1d.getType(), vector1d, adaptor.getSource(),
+      sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
+          loc, sourceAggregate.getType(), sourceAggregate, adaptor.getSource(),
           getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
-    } else {
-      // No scalar insertion. The 1D vector is just the source.
-      vector1d = adaptor.getSource();
     }
 
-    Value result = vector1d;
-    if (is1DVectorWithinAggregate) {
+    Value result = sourceAggregate;
+    if (isNestedAggregate) {
       result = rewriter.create<LLVM::InsertValueOp>(
-          loc, adaptor.getDest(), vector1d,
+          loc, adaptor.getDest(), sourceAggregate,
           getAsIntegers(positionOf1DVectorWithinAggregate));
     }
 



More information about the Mlir-commits mailing list