[Mlir-commits] [mlir] [MLIR] VectorToLLVM: Fix vector.insert conversion for 0-D vectors, and add a test (PR #128810)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 25 19:47:10 PST 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Benoit Jacob (bjacob)

<details>
<summary>Changes</summary>

The handling of `vector.insert` in VectorToLLVM was incorrect for a 0-D destination vector, as in:

```mlir
%0 = vector.insert %src, %dst[] : f32 into vector<f32>
```

Since the type conversion to LLVM converts `vector<f32>` to `vector<1xf32>`, the op must be rewritten into an `llvm.insertelement` into that `vector<1xf32>` at position 0. Instead, the existing code simply returned the source value, as if the converted type were the scalar type.
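To make the type mismatch concrete, here is a toy C++ model. It is only an analogy, and every name in it (`LLVMVec`, `insertElement`, `lowerInsert0D`) is hypothetical, not an MLIR or LLVM API. A converted LLVM vector of N floats is modeled as `std::array<float, N>`, and `llvm.insertelement` is modeled as returning a copy with one lane replaced. Under this model the 0-D destination `vector<f32>` becomes a one-lane array, so the lowering must write lane 0 and return the array. Returning the scalar `src`, as the old code did, would not even type-check here.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for an LLVM fixed-length vector of N x f32.
template <std::size_t N>
using LLVMVec = std::array<float, N>;

// Models llvm.insertelement: returns a copy of `vec` with lane `idx`
// replaced by `value`. The input vector itself is left unchanged.
template <std::size_t N>
LLVMVec<N> insertElement(LLVMVec<N> vec, float value, std::size_t idx) {
  assert(idx < N && "lane index out of range");
  vec[idx] = value;
  return vec;
}

// Models the fixed lowering of:
//   vector.insert %src, %dst[] : f32 into vector<f32>
// The 0-D vector<f32> is converted to vector<1xf32>, so the empty
// position becomes lane 0 of a one-lane vector.
LLVMVec<1> lowerInsert0D(float src, LLVMVec<1> dst) {
  return insertElement(dst, src, /*idx=*/0);
}
```

For example, `lowerInsert0D(3.5f, LLVMVec<1>{0.0f})` yields a one-lane array whose only element is `3.5f`. The result keeps the vector type rather than collapsing to a bare `float`.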

Tests added. Previously, there were no tests covering the `vector.insert` conversions.

---
Full diff: https://github.com/llvm/llvm-project/pull/128810.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+7-3) 
- (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+29) 


``````````diff
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index c9d637ce81f93..a5e9e9bf6498b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1233,11 +1233,15 @@ class VectorInsertOpConversion
     SmallVector<OpFoldResult> positionVec = getMixedValues(
         adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
-    // Overwrite entire vector with value. Should be handled by folder, but
-    // just to be safe.
     ArrayRef<OpFoldResult> position(positionVec);
+    // Case of empty position, used with 0-D destination vector. In that case,
+    // the converted destination type is an LLVM vector of size 1, and we need
+    // a 0 as the position.
     if (position.empty()) {
-      rewriter.replaceOp(insertOp, adaptor.getSource());
+      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+          insertOp, llvmResultType, adaptor.getDest(), adaptor.getSource(),
+          rewriter.create<LLVM::ConstantOp>(loc,
+                                            rewriter.getI64IntegerAttr(0)));
       return success();
     }
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 36b37a137ac1e..72ca06ba7d9a4 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1787,3 +1787,32 @@ func.func @step() -> vector<4xindex> {
   %0 = vector.step : vector<4xindex>
   return %0 : vector<4xindex>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.insert
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @insert_0d
+// CHECK: llvm.insertelement {{.*}} : vector<1xf32>
+func.func @insert_0d(%src: f32, %dst: vector<f32>) -> vector<f32> {
+  %0 = vector.insert %src, %dst[] : f32 into vector<f32>
+  return %0 : vector<f32>
+}
+
+// CHECK-LABEL: @insert_1d
+// CHECK: llvm.insertelement {{.*}} : vector<2xf32>
+func.func @insert_1d(%src: f32, %dst: vector<2xf32>) -> vector<2xf32> {
+  %0 = vector.insert %src, %dst[1] : f32 into vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @insert_2d
+// CHECK: llvm.extractvalue {{.*}} : !llvm.array<2 x vector<2xf32>>
+// CHECK: llvm.insertelement {{.*}} : vector<2xf32>
+// CHECK: llvm.insertvalue {{.*}} : !llvm.array<2 x vector<2xf32>>
+func.func @insert_2d(%src: f32, %dst: vector<2x2xf32>) -> vector<2x2xf32> {
+  %0 = vector.insert %src, %dst[1, 0] : f32 into vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}

``````````
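The `@insert_2d` CHECK lines follow from how n-D vectors are lowered. According to those checks, `vector<2x2xf32>` becomes `!llvm.array<2 x vector<2xf32>>`, so inserting at `[1, 0]` takes three steps: extract the inner vector at index 1, insert the scalar at lane 0, and write the vector back. Below is a rough C++ analogue of that sequence. It uses `std::array` as a stand-in for the LLVM types, and the names `LLVMArrayOfVec` and `lowerInsert2D` are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins: vector<Cols x f32> and
// !llvm.array<Rows x vector<Cols x f32>>.
template <std::size_t N>
using LLVMVec = std::array<float, N>;

template <std::size_t Rows, std::size_t Cols>
using LLVMArrayOfVec = std::array<LLVMVec<Cols>, Rows>;

// Models the lowering of:
//   vector.insert %src, %dst[i, j] : f32 into vector<Rows x Cols x f32>
// Each step is annotated with the LLVM dialect op it loosely corresponds to.
template <std::size_t Rows, std::size_t Cols>
LLVMArrayOfVec<Rows, Cols> lowerInsert2D(float src,
                                         LLVMArrayOfVec<Rows, Cols> dst,
                                         std::size_t i, std::size_t j) {
  assert(i < Rows && j < Cols && "position out of range");
  LLVMVec<Cols> row = dst[i];  // ~ llvm.extractvalue (outer index i)
  row[j] = src;                // ~ llvm.insertelement (lane j)
  dst[i] = row;                // ~ llvm.insertvalue (outer index i)
  return dst;
}
```

Only the innermost dimension maps to an LLVM vector, which is why the test expects a single `llvm.insertelement` surrounded by an aggregate `extractvalue`/`insertvalue` pair.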

</details>


https://github.com/llvm/llvm-project/pull/128810

