[Mlir-commits] [mlir] 7371f69 - [MLIR][Vector]: Generalize conversion of `vector.insert` to LLVM in line with `vector.extract` (#128915)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 26 13:35:39 PST 2025


Author: Benoit Jacob
Date: 2025-02-26T16:35:36-05:00
New Revision: 7371f691b97986fd3f32d8618131ca40788c7b8b

URL: https://github.com/llvm/llvm-project/commit/7371f691b97986fd3f32d8618131ca40788c7b8b
DIFF: https://github.com/llvm/llvm-project/commit/7371f691b97986fd3f32d8618131ca40788c7b8b.diff

LOG: [MLIR][Vector]: Generalize conversion of `vector.insert` to LLVM in line with `vector.extract` (#128915)

This does for `vector.insert` what
https://github.com/llvm/llvm-project/pull/117731 did for
`vector.extract`.

The insert case is slightly more involved because the 1-D vector being
inserted into may itself first need to be extracted from the destination
aggregate.

As the test shows, this fixes two previously unsupported cases:
- Dynamic indices
- 0-D vectors.

---------

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
    mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c9d637ce81f93..94efec61a466c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -35,13 +35,6 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-// Helper to reduce vector type by *all* but one rank at back.
-static VectorType reducedVectorTypeBack(VectorType tp) {
-  assert((tp.getRank() > 1) && "unlowerable vector type");
-  return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
-                         tp.getScalableDims().take_back());
-}
-
 // Helper that picks the proper sequence for inserting.
 static Value insertOne(ConversionPatternRewriter &rewriter,
                        const LLVMTypeConverter &typeConverter, Location loc,
@@ -1223,7 +1216,6 @@ class VectorInsertOpConversion
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = insertOp->getLoc();
-    auto sourceType = insertOp.getSourceType();
     auto destVectorType = insertOp.getDestVectorType();
     auto llvmResultType = typeConverter->convertType(destVectorType);
     // Bail if result type cannot be lowered.
@@ -1233,53 +1225,81 @@ class VectorInsertOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
-    // Overwrite entire vector with value. Should be handled by folder, but
-    // just to be safe.
-    ArrayRef<OpFoldResult> position(positionVec);
-    if (position.empty()) {
-      rewriter.replaceOp(insertOp, adaptor.getSource());
-      return success();
-    }
-
-    // One-shot insertion of a vector into an array (only requires insertvalue).
-    if (isa<VectorType>(sourceType)) {
-      if (insertOp.hasDynamicPosition())
-        return failure();
-
-      Value inserted = rewriter.create<LLVM::InsertValueOp>(
-          loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
-      rewriter.replaceOp(insertOp, inserted);
-      return success();
+    // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
+    // its explanatory comment about how N-D vectors are converted as nested
+    // aggregates (llvm.array's) of 1D vectors.
+    //
+    // The innermost dimension of the destination vector, when converted to a
+    // nested aggregate form, will always be a 1D vector.
+    //
+    // * If the insertion is happening into the innermost dimension of the
+    //   destination vector:
+    //   - If the destination is a nested aggregate, extract a 1D vector out of
+    //     the aggregate. This can be done using llvm.extractvalue. The
+    //     destination is now guaranteed to be a 1D vector, to which we are
+    //     inserting.
+    //   - Do the insertion into the 1D destination vector, and make the result
+    //     the new source nested aggregate. This can be done using
+    //     llvm.insertelement.
+    // * Insert the source nested aggregate into the destination nested
+    //   aggregate.
+
+    // Determine if we need to extract/insert a 1D vector out of the aggregate.
+    bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
+    // Determine if we need to insert a scalar into the 1D vector.
+    bool insertIntoInnermostDim =
+        static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
+
+    ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
+        positionVec.begin(),
+        insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
+    OpFoldResult positionOfScalarWithin1DVector;
+    if (destVectorType.getRank() == 0) {
+      // Since the LLVM type converter converts 0D vectors to 1D vectors, we
+      // need to create a 0 here as the position into the 1D vector.
+      Type idxType = typeConverter->convertType(rewriter.getIndexType());
+      positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
+    } else if (insertIntoInnermostDim) {
+      positionOfScalarWithin1DVector = positionVec.back();
     }
 
-    // Potential extraction of 1-D vector from array.
-    Value extracted = adaptor.getDest();
-    auto oneDVectorType = destVectorType;
-    if (position.size() > 1) {
-      if (insertOp.hasDynamicPosition())
-        return failure();
-
-      oneDVectorType = reducedVectorTypeBack(destVectorType);
-      extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, extracted, getAsIntegers(position.drop_back()));
+    // We are going to mutate this 1D vector until it is either the final
+    // result (in the non-aggregate case) or the value that needs to be
+    // inserted into the aggregate result.
+    Value sourceAggregate = adaptor.getSource();
+    if (insertIntoInnermostDim) {
+      // Scalar-into-1D-vector case, so we know we will have to create an
+      // InsertElementOp. The question is into what destination.
+      if (isNestedAggregate) {
+        // Aggregate case: the destination for the InsertElementOp needs to be
+        // extracted from the aggregate.
+        if (!llvm::all_of(positionOf1DVectorWithinAggregate,
+                          llvm::IsaPred<Attribute>)) {
+          // llvm.extractvalue does not support dynamic positions.
+          return failure();
+        }
+        sourceAggregate = rewriter.create<LLVM::ExtractValueOp>(
+            loc, adaptor.getDest(),
+            getAsIntegers(positionOf1DVectorWithinAggregate));
+      } else {
+        // No-aggregate case. The destination for the InsertElementOp is just
+        // the insertOp's destination.
+        sourceAggregate = adaptor.getDest();
+      }
+      // Insert the scalar into the 1D vector.
+      sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
+          loc, sourceAggregate.getType(), sourceAggregate, adaptor.getSource(),
+          getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
     }
 
-    // Insertion of an element into a 1-D LLVM vector.
-    Value inserted = rewriter.create<LLVM::InsertElementOp>(
-        loc, typeConverter->convertType(oneDVectorType), extracted,
-        adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
-
-    // Potential insertion of resulting 1-D vector into array.
-    if (position.size() > 1) {
-      if (insertOp.hasDynamicPosition())
-        return failure();
-
-      inserted = rewriter.create<LLVM::InsertValueOp>(
-          loc, adaptor.getDest(), inserted,
-          getAsIntegers(position.drop_back()));
+    Value result = sourceAggregate;
+    if (isNestedAggregate) {
+      result = rewriter.create<LLVM::InsertValueOp>(
+          loc, adaptor.getDest(), sourceAggregate,
+          getAsIntegers(positionOf1DVectorWithinAggregate));
     }
 
-    rewriter.replaceOp(insertOp, inserted);
+    rewriter.replaceOp(insertOp, result);
     return success();
   }
 };

diff  --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index fa7c030538401..c3f06dd4d5dd1 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -628,6 +628,16 @@ func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(%arg0: f
 // vector.insert
 //===----------------------------------------------------------------------===//
 
+func.func @insert_scalar_into_vec_0d(%src: f32, %dst: vector<f32>) -> vector<f32> {
+  %0 = vector.insert %src, %dst[] : f32 into vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @insert_scalar_into_vec_0d
+//       CHECK: llvm.insertelement {{.*}} : vector<1xf32>
+
+// -----
+
 func.func @insert_scalar_into_vec_1d_f32(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.insert %arg0, %arg1[3] : f32 into vector<4xf32>
   return %0 : vector<4xf32>
@@ -780,10 +790,10 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %a
   return %0 : vector<1x16xf32>
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
-
 // CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx(
-//       CHECK:   vector.insert
+//       CHECK:   llvm.extractvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
+//       CHECK:   llvm.insertelement {{.*}} : vector<16xf32>
+//       CHECK:   llvm.insertvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
 
 // -----
 
@@ -793,10 +803,26 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
   return %0 : vector<1x[16]xf32>
 }
 
-// Multi-dim vectors are not supported but this test shouldn't crash.
-
 // CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(
-//       CHECK:   vector.insert
+//       CHECK:   llvm.extractvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
+//       CHECK:   llvm.insertelement {{.*}} : vector<[16]xf32>
+//       CHECK:   llvm.insertvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
+
+
+// -----
+
+func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_fail(%arg0: vector<2x16xf32>, %arg1: f32, %idx: index)
+                                        -> vector<2x16xf32> {
+  %0 = vector.insert %arg1, %arg0[%idx, 0]: f32 into vector<2x16xf32>
+  return %0 : vector<2x16xf32>
+}
+
+// Currently fails to convert because of the dynamic index in a non-innermost
+// dimension, which converts to an llvm.array: llvm.extractvalue does not
+// support dynamic positions.
+
+// CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx_fail
+//       CHECK: vector.insert
 
 // -----
 


        


More information about the Mlir-commits mailing list