[Mlir-commits] [mlir] [MLIR][Vector]: Generalize conversion of `vector.insert` to LLVM in line with `vector.extract` (PR #128915)

Kunwar Grover llvmlistbot at llvm.org
Wed Feb 26 11:52:40 PST 2025


================
@@ -1233,53 +1225,74 @@ class VectorInsertOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
-    // Overwrite entire vector with value. Should be handled by folder, but
-    // just to be safe.
-    ArrayRef<OpFoldResult> position(positionVec);
-    if (position.empty()) {
-      rewriter.replaceOp(insertOp, adaptor.getSource());
-      return success();
-    }
-
-    // One-shot insertion of a vector into an array (only requires insertvalue).
-    if (isa<VectorType>(sourceType)) {
-      if (insertOp.hasDynamicPosition())
-        return failure();
-
-      Value inserted = rewriter.create<LLVM::InsertValueOp>(
-          loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
-      rewriter.replaceOp(insertOp, inserted);
-      return success();
+    // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
+    // its explanatory comment about how N-D vectors are converted as nested
+    // aggregates (llvm.array's) of 1D vectors.
+    //
+    // There are 3 steps here, vs 2 in VectorExtractOpConversion:
+    // - Extraction of a 1D vector from the nested aggregate: llvm.extractvalue.
+    // - Insertion into the 1D vector: llvm.insertelement.
+    // - Insertion of the 1D vector into the nested aggregate: llvm.insertvalue.
+
+    // Determine if we need to extract/insert a 1D vector out of the aggregate.
+    bool is1DVectorWithinAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
+    // Determine if we need to insert a scalar into the 1D vector.
+    bool isScalarWithin1DVector =
+        static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
+
+    ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
+        positionVec.begin(),
+        isScalarWithin1DVector ? positionVec.size() - 1 : positionVec.size());
+    OpFoldResult positionOfScalarWithin1DVector;
+    if (destVectorType.getRank() == 0) {
+      // Since the LLVM type converter converts 0D vectors to 1D vectors, we
+      // need to create a 0 here as the position into the 1D vector.
+      Type idxType = typeConverter->convertType(rewriter.getIndexType());
+      positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
+    } else if (isScalarWithin1DVector) {
+      positionOfScalarWithin1DVector = positionVec.back();
     }
 
-    // Potential extraction of 1-D vector from array.
-    Value extracted = adaptor.getDest();
-    auto oneDVectorType = destVectorType;
-    if (position.size() > 1) {
-      if (insertOp.hasDynamicPosition())
-        return failure();
-
-      oneDVectorType = reducedVectorTypeBack(destVectorType);
-      extracted = rewriter.create<LLVM::ExtractValueOp>(
-          loc, extracted, getAsIntegers(position.drop_back()));
+    // We are going to mutate this 1D vector until it is either the final
+    // result (in the non-aggregate case) or the value that needs to be
+    // inserted into the aggregate result.
+    Value vector1d;
----------------
Groverkss wrote:

nit: Rename `vector1d` to `sourceAggregate`. This value may not always be a 1D vector.
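
To illustrate the point, here is a minimal standalone sketch of the position split in the hunk above. It uses plain integers instead of `OpFoldResult` and calls no MLIR API. The names `SplitPosition`, `splitPosition`, and `carriedValueRank` are hypothetical and not part of the patch. It assumes that the value carried toward the final `llvm.insertvalue` has rank `destRank - position.size()` when the position is not full rank, and it returns early for rank 0 so that no unsigned `size() - 1` is computed on an empty position.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of the two positions computed in the pattern.
struct SplitPosition {
  // Indices used with llvm.extractvalue / llvm.insertvalue on the aggregate.
  std::vector<int64_t> aggregatePosition;
  // Index used with llvm.insertelement, present only for scalar insertion.
  std::optional<int64_t> scalarIndex;
};

SplitPosition splitPosition(int64_t destRank,
                            const std::vector<int64_t> &position) {
  assert(static_cast<int64_t>(position.size()) <= destRank &&
         "position cannot have more indices than the destination rank");
  SplitPosition split;
  if (destRank == 0) {
    // 0D vectors are converted to 1D vectors, so the scalar sits at index 0.
    split.scalarIndex = 0;
    return split;
  }
  // A full-rank position addresses a scalar inside an innermost 1D vector.
  bool isScalarWithin1DVector =
      static_cast<int64_t>(position.size()) == destRank;
  split.aggregatePosition = position;
  if (isScalarWithin1DVector) {
    split.scalarIndex = position.back();
    split.aggregatePosition.pop_back();
  }
  return split;
}

// Rank of the value that ends up inserted into the aggregate (assumption:
// a partial position inserts the source sub-vector as-is).
int64_t carriedValueRank(int64_t destRank, std::size_t positionSize) {
  if (destRank == 0 || static_cast<int64_t>(positionSize) == destRank)
    return 1; // scalar case: the carried value is the updated 1D vector
  return destRank - static_cast<int64_t>(positionSize);
}
```

For a rank-3 destination and position `[1]`, `carriedValueRank(3, 1)` is 2, so under these assumptions the carried value would itself be a nested aggregate rather than a 1D vector, which is why the name `vector1d` can mislead.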

https://github.com/llvm/llvm-project/pull/128915

