[Mlir-commits] [mlir] [MLIR] VectorToLLVM: Fix vector.insert conversion for 0-D vectors, and add a test (PR #128810)
Benoit Jacob
llvmlistbot at llvm.org
Tue Feb 25 19:46:35 PST 2025
https://github.com/bjacob created https://github.com/llvm/llvm-project/pull/128810
The VectorToLLVM conversion of `vector.insert` mishandled the case of a 0-D destination vector, as in:
```mlir
%0 = vector.insert %src, %dst[] : f32 into vector<f32>
```
Since the LLVM type conversion converts `vector<f32>` to `vector<1xf32>`, the op must be rewritten into an `llvm.insertelement` at position 0 of that `vector<1xf32>`. Instead, the existing code simply returned the source value, as if the converted type were the scalar type.
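To make the type mismatch concrete, here is a minimal C++ model. It is an illustration only, not MLIR or LLVM code, and it assumes the converted `vector<1xf32>` can be represented as `std::array<float, 1>`. The helper `insertElement` is hypothetical and stands in for `llvm.insertelement`. The fixed lowering writes the source into element 0; in this model the old lowering's "return the source value" would not even type-check, since a `float` is not a `std::array<float, 1>`.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Assumed model: the converted 0-D vector `vector<f32>` becomes a
// one-element LLVM vector `vector<1xf32>`, modeled as std::array<float, 1>.
using Vec1xF32 = std::array<float, 1>;

// Hypothetical stand-in for llvm.insertelement: returns a copy of `vec`
// with the element at `index` replaced by `value`.
template <std::size_t N>
std::array<float, N> insertElement(std::array<float, N> vec, float value,
                                   std::int64_t index) {
  // LLVM yields poison for an out-of-range index; this model asserts instead.
  assert(index >= 0 && static_cast<std::size_t>(index) < N);
  vec[static_cast<std::size_t>(index)] = value;
  return vec;
}

// Model of the fixed lowering of
//   vector.insert %src, %dst[] : f32 into vector<f32>
// The empty position becomes the constant index 0 into the 1-element vector.
Vec1xF32 lowerInsert0d(float src, Vec1xF32 dst) {
  return insertElement(dst, src, /*index=*/0);
}
```

The result keeps the converted destination type, which is exactly what replacing the op with `adaptor.getSource()` failed to do.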
Tests added: there were previously no tests covering the `vector.insert` conversions.
From 7e937c9897add659f316d10848a971973744600a Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Tue, 25 Feb 2025 21:42:27 -0600
Subject: [PATCH] Fix vector.insert
Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
.../VectorToLLVM/ConvertVectorToLLVM.cpp | 10 +++++--
.../VectorToLLVM/vector-to-llvm.mlir | 29 +++++++++++++++++++
2 files changed, 36 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c9d637ce81f93..a5e9e9bf6498b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1233,11 +1233,15 @@ class VectorInsertOpConversion
SmallVector<OpFoldResult> positionVec = getMixedValues(
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
- // Overwrite entire vector with value. Should be handled by folder, but
- // just to be safe.
ArrayRef<OpFoldResult> position(positionVec);
+ // Case of empty position, used with 0-D destination vector. In that case,
+ // the converted destination type is a LLVM vector of size 1, and we need
+ // a 0 as the position.
if (position.empty()) {
- rewriter.replaceOp(insertOp, adaptor.getSource());
+ rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+ insertOp, llvmResultType, adaptor.getDest(), adaptor.getSource(),
+ rewriter.create<LLVM::ConstantOp>(loc,
+ rewriter.getI64IntegerAttr(0)));
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 36b37a137ac1e..72ca06ba7d9a4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1787,3 +1787,32 @@ func.func @step() -> vector<4xindex> {
%0 = vector.step : vector<4xindex>
return %0 : vector<4xindex>
}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.insert
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @insert_0d
+// CHECK: llvm.insertelement {{.*}} : vector<1xf32>
+func.func @insert_0d(%src: f32, %dst: vector<f32>) -> vector<f32> {
+ %0 = vector.insert %src, %dst[] : f32 into vector<f32>
+ return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @insert_1d
+// CHECK: llvm.insertelement {{.*}} : vector<2xf32>
+func.func @insert_1d(%src: f32, %dst: vector<2xf32>) -> vector<2xf32> {
+ %0 = vector.insert %src, %dst[1] : f32 into vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @insert_2d
+// CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf32>>
+// CHECK: llvm.insertelement {{.*}} : vector<2xf32>
+// CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf32>>
+func.func @insert_2d(%src: f32, %dst: vector<2x2xf32>) -> vector<2x2xf32> {
+ %0 = vector.insert %src, %dst[1, 0] : f32 into vector<2x2xf32>
+ return %0 : vector<2x2xf32>
+}
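The `@insert_2d` CHECK lines expect an `llvm.extractvalue` / `llvm.insertelement` / `llvm.insertvalue` sequence on the converted type `!llvm.array<2 x vector<2xf32>>`. A minimal C++ model of that sequence, assuming the array of 1-D vectors maps to nested `std::array`s (the names `insertElement` and `lowerInsert2d` are hypothetical, used only for illustration):

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Assumed model of the converted types.
using Vec2xF32 = std::array<float, 2>;          // vector<2xf32>
using Arr2xVec2xF32 = std::array<Vec2xF32, 2>;  // !llvm.array<2 x vector<2xf32>>

// Hypothetical stand-in for llvm.insertelement on the innermost 1-D vector.
Vec2xF32 insertElement(Vec2xF32 vec, float value, std::int64_t index) {
  assert(index >= 0 && index < 2);
  vec[static_cast<std::size_t>(index)] = value;
  return vec;
}

// Model of
//   vector.insert %src, %dst[i, j] : f32 into vector<2x2xf32>
// lowered as extract row, insert element, insert row back.
Arr2xVec2xF32 lowerInsert2d(float src, Arr2xVec2xF32 dst, std::size_t i,
                            std::int64_t j) {
  Vec2xF32 row = dst[i];             // llvm.extractvalue %dst[i]
  row = insertElement(row, src, j);  // llvm.insertelement at index j
  dst[i] = row;                      // llvm.insertvalue back into %dst[i]
  return dst;
}
```

Only the leading position indexes the outer array; the last position indexes the innermost LLVM vector, which is why the 0-D case, with no position at all, still needs an explicit index 0.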