[Mlir-commits] [mlir] [MLIR][Vector]: Generalize conversion of `vector.insert` to LLVM in line with `vector.extract` (PR #128915)
Kunwar Grover
llvmlistbot at llvm.org
Wed Feb 26 12:26:01 PST 2025
================
@@ -1233,53 +1225,74 @@ class VectorInsertOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
-    // Overwrite entire vector with value. Should be handled by folder, but
-    // just to be safe.
-    ArrayRef<OpFoldResult> position(positionVec);
-    if (position.empty()) {
-      rewriter.replaceOp(insertOp, adaptor.getSource());
-      return success();
-    }
-
-    // One-shot insertion of a vector into an array (only requires insertvalue).
-    if (isa<VectorType>(sourceType)) {
-      if (insertOp.hasDynamicPosition())
-        return failure();
-
-      Value inserted = rewriter.create<LLVM::InsertValueOp>(
-          loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
-      rewriter.replaceOp(insertOp, inserted);
-      return success();
+    // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
+    // its explanatory comment about how N-D vectors are converted as nested
+    // aggregates (llvm.array's) of 1D vectors.
+    //
+    // There are 3 steps here, vs 2 in VectorExtractOpConversion:
+    // - Extraction of a 1D vector from the nested aggregate: llvm.extractvalue.
+    // - Insertion into the 1D vector: llvm.insertelement.
+    // - Insertion of the 1D vector into the nested aggregate: llvm.insertvalue.
+
+    // Determine if we need to extract/insert a 1D vector out of the aggregate.
+    bool is1DVectorWithinAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
+    // Determine if we need to insert a scalar into the 1D vector.
+    bool isScalarWithin1DVector =
+        static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
+
+    ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
+        positionVec.begin(),
+        isScalarWithin1DVector ? positionVec.size() - 1 : positionVec.size());
+    OpFoldResult positionOfScalarWithin1DVector;
+    if (destVectorType.getRank() == 0) {
+      // Since the LLVM type converter converts 0D vectors to 1D vectors, we
+      // need to create a 0 here as the position into the 1D vector.
+      Type idxType = typeConverter->convertType(rewriter.getIndexType());
+      positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
+    } else if (isScalarWithin1DVector) {
+      positionOfScalarWithin1DVector = positionVec.back();
     }
----------------
Groverkss wrote:
I see. We can keep it as is.
https://github.com/llvm/llvm-project/pull/128915
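For readers following the three-step scheme described in the hunk above, here is a minimal, self-contained C++ sketch of the same idea. It does not use MLIR or LLVM APIs: LoweredVector, Source, makeZeros, leafIndex and insertIntoLowered are hypothetical names, and the nested llvm.array aggregate is modeled as a row-major list of 1D vectors. It assumes the inserted value is either a scalar or a whole 1D vector (inserting larger sub-vectors is left out). What it illustrates is the position split: the leading entries select the 1D vector in the aggregate, and the last entry (or 0 for a 0D vector) selects the element.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical model, not MLIR/LLVM API: a lowered N-D vector is an aggregate
// of 1-D vectors. The leading (rank - 1) dimensions index the aggregate and the
// innermost dimension is the 1-D vector itself. A 0-D vector is modeled as one
// 1-D vector of length 1, mirroring the LLVM type converter's 0-D -> 1-D rule.
struct LoweredVector {
  std::vector<int64_t> shape;             // original vector shape; empty for 0-D
  std::vector<std::vector<float>> leaves; // the 1-D vectors, in row-major order
};

// The value being inserted: either a scalar or a whole 1-D vector.
using Source = std::variant<float, std::vector<float>>;

LoweredVector makeZeros(const std::vector<int64_t>& shape) {
  int64_t numLeaves = 1;
  for (std::size_t i = 0; i + 1 < shape.size(); ++i) numLeaves *= shape[i];
  const int64_t leafLen = shape.empty() ? 1 : shape.back();
  return {shape, std::vector<std::vector<float>>(
                     numLeaves, std::vector<float>(leafLen, 0.0f))};
}

// Stand-in for the static index path of llvm.extractvalue / llvm.insertvalue:
// maps a position in the aggregate (leading dimensions) to a leaf index.
int64_t leafIndex(const LoweredVector& v, const std::vector<int64_t>& aggPos) {
  int64_t idx = 0;
  for (std::size_t i = 0; i < aggPos.size(); ++i) {
    if (aggPos[i] < 0 || aggPos[i] >= v.shape[i])
      throw std::out_of_range("aggregate position out of bounds");
    idx = idx * v.shape[i] + aggPos[i];
  }
  return idx;
}

// Sketch of the three-step scheme: extract the 1-D vector, insert the scalar
// into it, then insert the 1-D vector back into the aggregate.
LoweredVector insertIntoLowered(LoweredVector dest, const Source& source,
                                const std::vector<int64_t>& position) {
  const int64_t rank = static_cast<int64_t>(dest.shape.size());
  const int64_t posSize = static_cast<int64_t>(position.size());
  if (posSize > rank || posSize < rank - 1)
    throw std::invalid_argument("sketch only inserts scalars or 1-D vectors");

  // Mirrors isScalarWithin1DVector: a full position addresses one element.
  const bool isScalarWithin1DVector = posSize == rank;

  // Leading part of the position selects the 1-D vector (guarding the 0-D
  // case, whose position is empty).
  const int64_t aggLen =
      (isScalarWithin1DVector && posSize > 0) ? posSize - 1 : posSize;
  const std::vector<int64_t> aggPos(position.begin(), position.begin() + aggLen);
  const int64_t leaf = leafIndex(dest, aggPos);

  if (!isScalarWithin1DVector) {
    // Whole 1-D vector: a single llvm.insertvalue for rank >= 2; for rank 1
    // the value is simply replaced.
    const auto& vec = std::get<std::vector<float>>(source);
    if (vec.size() != dest.leaves[leaf].size())
      throw std::invalid_argument("1-D vector length mismatch");
    dest.leaves[leaf] = vec;
    return dest;
  }

  // Element position: 0 for a 0-D vector, else the last position entry.
  const int64_t elemPos = rank == 0 ? 0 : position.back();
  // Step 1 (llvm.extractvalue): take out the 1-D vector. For rank <= 1 there is
  // only one leaf and the real lowering has no aggregate to extract from.
  std::vector<float> vec1D = dest.leaves[leaf];
  // Step 2 (llvm.insertelement): put the scalar into the 1-D vector.
  if (elemPos < 0 || elemPos >= static_cast<int64_t>(vec1D.size()))
    throw std::out_of_range("element position out of bounds");
  vec1D[elemPos] = std::get<float>(source);
  // Step 3 (llvm.insertvalue): write the 1-D vector back into the aggregate.
  dest.leaves[leaf] = std::move(vec1D);
  return dest;
}
```

For example, on a zero-initialized 2x3x4 value, inserting 7.0f at position {1, 2, 3} updates element 3 of leaf 1 * 3 + 2 = 5 through all three steps, while inserting a length-4 vector at {0, 1} replaces leaf 1 directly, which corresponds to the single llvm.insertvalue of the removed "one-shot" path. A 0D destination takes the empty position and writes element 0, matching the getZeroAttr branch.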