[Mlir-commits] [mlir] [MLIR][Vector]: Generalize conversion of `vector.insert` to LLVM in line with `vector.extract` (PR #128915)
Benoit Jacob
llvmlistbot at llvm.org
Wed Feb 26 12:13:59 PST 2025
https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/128915
>From d18ba6757b8b012e1290b3f015e9bf187fc973e0 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 26 Feb 2025 11:18:11 -0600
Subject: [PATCH 1/2] like-117731-but-for-insert
Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 111 ++++++++++--------
.../vector-to-llvm-interface.mlir | 22 +++-
2 files changed, 78 insertions(+), 55 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c9d637ce81f93..e1c7547774c3b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -35,13 +35,6 @@
using namespace mlir;
using namespace mlir::vector;
-// Helper to reduce vector type by *all* but one rank at back.
-static VectorType reducedVectorTypeBack(VectorType tp) {
- assert((tp.getRank() > 1) && "unlowerable vector type");
- return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
- tp.getScalableDims().take_back());
-}
-
// Helper that picks the proper sequence for inserting.
static Value insertOne(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter, Location loc,
@@ -1223,7 +1216,6 @@ class VectorInsertOpConversion
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = insertOp->getLoc();
- auto sourceType = insertOp.getSourceType();
auto destVectorType = insertOp.getDestVectorType();
auto llvmResultType = typeConverter->convertType(destVectorType);
// Bail if result type cannot be lowered.
@@ -1233,53 +1225,74 @@ class VectorInsertOpConversion
SmallVector<OpFoldResult> positionVec = getMixedValues(
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
- // Overwrite entire vector with value. Should be handled by folder, but
- // just to be safe.
- ArrayRef<OpFoldResult> position(positionVec);
- if (position.empty()) {
- rewriter.replaceOp(insertOp, adaptor.getSource());
- return success();
- }
-
- // One-shot insertion of a vector into an array (only requires insertvalue).
- if (isa<VectorType>(sourceType)) {
- if (insertOp.hasDynamicPosition())
- return failure();
-
- Value inserted = rewriter.create<LLVM::InsertValueOp>(
- loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
- rewriter.replaceOp(insertOp, inserted);
- return success();
+ // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
+ // its explanatory comment about how N-D vectors are converted as nested
+ // aggregates (llvm.array's) of 1D vectors.
+ //
+ // There are 3 steps here, vs 2 in VectorExtractOpConversion:
+ // - Extraction of a 1D vector from the nested aggregate: llvm.extractvalue.
+ // - Insertion into the 1D vector: llvm.insertelement.
+ // - Insertion of the 1D vector into the nested aggregate: llvm.insertvalue.
+
+ // Determine if we need to extract/insert a 1D vector out of the aggregate.
+ bool is1DVectorWithinAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
+ // Determine if we need to insert a scalar into the 1D vector.
+ bool isScalarWithin1DVector =
+ static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
+
+ ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
+ positionVec.begin(),
+ isScalarWithin1DVector ? positionVec.size() - 1 : positionVec.size());
+ OpFoldResult positionOfScalarWithin1DVector;
+ if (destVectorType.getRank() == 0) {
+ // Since the LLVM type converter converts 0D vectors to 1D vectors, we
+ // need to create a 0 here as the position into the 1D vector.
+ Type idxType = typeConverter->convertType(rewriter.getIndexType());
+ positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
+ } else if (isScalarWithin1DVector) {
+ positionOfScalarWithin1DVector = positionVec.back();
}
- // Potential extraction of 1-D vector from array.
- Value extracted = adaptor.getDest();
- auto oneDVectorType = destVectorType;
- if (position.size() > 1) {
- if (insertOp.hasDynamicPosition())
- return failure();
-
- oneDVectorType = reducedVectorTypeBack(destVectorType);
- extracted = rewriter.create<LLVM::ExtractValueOp>(
- loc, extracted, getAsIntegers(position.drop_back()));
+ // We are going to mutate this 1D vector until it is either the final
+ // result (in the non-aggregate case) or the value that needs to be
+ // inserted into the aggregate result.
+ Value vector1d;
+ if (isScalarWithin1DVector) {
+ // Scalar-into-1D-vector case, so we know we will have to create an
+ // InsertElementOp. The question is into what destination.
+ if (is1DVectorWithinAggregate) {
+ // Aggregate case: the destination for the InsertElementOp needs to be
+ // extracted from the aggregate.
+ if (!llvm::all_of(positionOf1DVectorWithinAggregate,
+ llvm::IsaPred<Attribute>)) {
+ // llvm.extractvalue only supports static (constant) positions.
+ return failure();
+ }
+ vector1d = rewriter.create<LLVM::ExtractValueOp>(
+ loc, adaptor.getDest(),
+ getAsIntegers(positionOf1DVectorWithinAggregate));
+ } else {
+ // No-aggregate case. The destination for the InsertElementOp is just
+ // the insertOp's destination.
+ vector1d = adaptor.getDest();
+ }
+ // Insert the scalar into the 1D vector.
+ vector1d = rewriter.create<LLVM::InsertElementOp>(
+ loc, vector1d.getType(), vector1d, adaptor.getSource(),
+ getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
+ } else {
+ // No scalar insertion. The 1D vector is just the source.
+ vector1d = adaptor.getSource();
}
- // Insertion of an element into a 1-D LLVM vector.
- Value inserted = rewriter.create<LLVM::InsertElementOp>(
- loc, typeConverter->convertType(oneDVectorType), extracted,
- adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
-
- // Potential insertion of resulting 1-D vector into array.
- if (position.size() > 1) {
- if (insertOp.hasDynamicPosition())
- return failure();
-
- inserted = rewriter.create<LLVM::InsertValueOp>(
- loc, adaptor.getDest(), inserted,
- getAsIntegers(position.drop_back()));
+ Value result = vector1d;
+ if (is1DVectorWithinAggregate) {
+ result = rewriter.create<LLVM::InsertValueOp>(
+ loc, adaptor.getDest(), vector1d,
+ getAsIntegers(positionOf1DVectorWithinAggregate));
}
- rewriter.replaceOp(insertOp, inserted);
+ rewriter.replaceOp(insertOp, result);
return success();
}
};
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index fa7c030538401..7e60e62363ceb 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -628,6 +628,16 @@ func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(%arg0: f
// vector.insert
//===----------------------------------------------------------------------===//
+func.func @insert_scalar_into_vec_0d(%src: f32, %dst: vector<f32>) -> vector<f32> {
+ %0 = vector.insert %src, %dst[] : f32 into vector<f32>
+ return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @insert_scalar_into_vec_0d
+// CHECK: llvm.insertelement {{.*}} : vector<1xf32>
+
+// -----
+
func.func @insert_scalar_into_vec_1d_f32(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
%0 = vector.insert %arg0, %arg1[3] : f32 into vector<4xf32>
return %0 : vector<4xf32>
@@ -780,10 +790,10 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %a
return %0 : vector<1x16xf32>
}
-// Multi-dim vectors are not supported but this test shouldn't crash.
-
// CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx(
-// CHECK: vector.insert
+// CHECK: llvm.extractvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
+// CHECK: llvm.insertelement {{.*}} : vector<16xf32>
+// CHECK: llvm.insertvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
// -----
@@ -793,10 +803,10 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
return %0 : vector<1x[16]xf32>
}
-// Multi-dim vectors are not supported but this test shouldn't crash.
-
// CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(
-// CHECK: vector.insert
+// CHECK: llvm.extractvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
+// CHECK: llvm.insertelement {{.*}} : vector<[16]xf32>
+// CHECK: llvm.insertvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
// -----
>From d8e2291ab1afff2cfbeb7f5a4f2a45ebb8739b63 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 26 Feb 2025 14:08:02 -0600
Subject: [PATCH 2/2] review comments
Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 49 +++++++++++--------
1 file changed, 28 insertions(+), 21 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index e1c7547774c3b..94efec61a466c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1229,38 +1229,48 @@ class VectorInsertOpConversion
// its explanatory comment about how N-D vectors are converted as nested
// aggregates (llvm.array's) of 1D vectors.
//
- // There are 3 steps here, vs 2 in VectorExtractOpConversion:
- // - Extraction of a 1D vector from the nested aggregate: llvm.extractvalue.
- // - Insertion into the 1D vector: llvm.insertelement.
- // - Insertion of the 1D vector into the nested aggregate: llvm.insertvalue.
+ // The innermost dimension of the destination vector, when converted to a
+ // nested aggregate form, will always be a 1D vector.
+ //
+ // * If the insertion is happening into the innermost dimension of the
+ // destination vector:
+ // - If the destination is a nested aggregate, extract a 1D vector out of
+ // the aggregate. This can be done using llvm.extractvalue. The
+ // destination is now guaranteed to be a 1D vector, to which we are
+ // inserting.
+ // - Do the insertion into the 1D destination vector, and make the result
+ // the new source nested aggregate. This can be done using
+ // llvm.insertelement.
+ // * Insert the source nested aggregate into the destination nested
+ // aggregate.
// Determine if we need to extract/insert a 1D vector out of the aggregate.
- bool is1DVectorWithinAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
+ bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
// Determine if we need to insert a scalar into the 1D vector.
- bool isScalarWithin1DVector =
+ bool insertIntoInnermostDim =
static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
positionVec.begin(),
- isScalarWithin1DVector ? positionVec.size() - 1 : positionVec.size());
+ insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
OpFoldResult positionOfScalarWithin1DVector;
if (destVectorType.getRank() == 0) {
// Since the LLVM type converter converts 0D vectors to 1D vectors, we
// need to create a 0 here as the position into the 1D vector.
Type idxType = typeConverter->convertType(rewriter.getIndexType());
positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
- } else if (isScalarWithin1DVector) {
+ } else if (insertIntoInnermostDim) {
positionOfScalarWithin1DVector = positionVec.back();
}
// We are going to mutate this 1D vector until it is either the final
// result (in the non-aggregate case) or the value that needs to be
// inserted into the aggregate result.
- Value vector1d;
- if (isScalarWithin1DVector) {
+ Value sourceAggregate = adaptor.getSource();
+ if (insertIntoInnermostDim) {
// Scalar-into-1D-vector case, so we know we will have to create an
// InsertElementOp. The question is into what destination.
- if (is1DVectorWithinAggregate) {
+ if (isNestedAggregate) {
// Aggregate case: the destination for the InsertElementOp needs to be
// extracted from the aggregate.
if (!llvm::all_of(positionOf1DVectorWithinAggregate,
@@ -1268,27 +1278,24 @@ class VectorInsertOpConversion
// llvm.extractvalue only supports static (constant) positions.
return failure();
}
- vector1d = rewriter.create<LLVM::ExtractValueOp>(
+ sourceAggregate = rewriter.create<LLVM::ExtractValueOp>(
loc, adaptor.getDest(),
getAsIntegers(positionOf1DVectorWithinAggregate));
} else {
// No-aggregate case. The destination for the InsertElementOp is just
// the insertOp's destination.
- vector1d = adaptor.getDest();
+ sourceAggregate = adaptor.getDest();
}
// Insert the scalar into the 1D vector.
- vector1d = rewriter.create<LLVM::InsertElementOp>(
- loc, vector1d.getType(), vector1d, adaptor.getSource(),
+ sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
+ loc, sourceAggregate.getType(), sourceAggregate, adaptor.getSource(),
getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
- } else {
- // No scalar insertion. The 1D vector is just the source.
- vector1d = adaptor.getSource();
}
- Value result = vector1d;
- if (is1DVectorWithinAggregate) {
+ Value result = sourceAggregate;
+ if (isNestedAggregate) {
result = rewriter.create<LLVM::InsertValueOp>(
- loc, adaptor.getDest(), vector1d,
+ loc, adaptor.getDest(), sourceAggregate,
getAsIntegers(positionOf1DVectorWithinAggregate));
}
More information about the Mlir-commits
mailing list