[Mlir-commits] [mlir] [mlir][Vector] Fix n-D vector.extract/insert lowering to LLVM (PR #87591)

Diego Caballero llvmlistbot at llvm.org
Fri Apr 5 12:10:51 PDT 2024


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/87591

>From 87148c492b3fa7a11a7c82d033ecf720c574121a Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Thu, 4 Apr 2024 01:13:28 +0000
Subject: [PATCH 1/2] [mlir][Vector] Fix n-D vector.extract/insert lowering to
 LLVM

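The lowering of n-D vector.extract/insert ops to LLVM is not supported, but
if one of these ops accidentally reaches the vector-to-llvm conversion patterns,
we end up with a rather puzzling crash. The crash comes from indexing the
adaptor's dynamic position operands with a position's index among *all*
(static and dynamic) positions, which goes out of bounds whenever a dynamic
index follows a static one (e.g. `[0, %idx]`). This PR rebuilds the mixed
position with `getMixedValues` instead, which fixes that crash and gracefully
bails out in those cases.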
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      | 20 +++----------
 .../VectorToLLVM/vector-to-llvm.mlir          | 28 +++++++++++++------
 2 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 337f8bb6ab99ed..85d10f326e260e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1082,14 +1082,8 @@ class VectorExtractOpConversion
     if (!llvmResultType)
       return failure();
 
-    SmallVector<OpFoldResult> positionVec;
-    for (auto [idx, pos] : llvm::enumerate(extractOp.getMixedPosition())) {
-      if (pos.is<Value>())
-        // Make sure we use the value that has been already converted to LLVM.
-        positionVec.push_back(adaptor.getDynamicPosition()[idx]);
-      else
-        positionVec.push_back(pos);
-    }
+    SmallVector<OpFoldResult> positionVec = getMixedValues(
+        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
     // Extract entire vector. Should be handled by folder, but just to be safe.
     ArrayRef<OpFoldResult> position(positionVec);
@@ -1209,14 +1203,8 @@ class VectorInsertOpConversion
     if (!llvmResultType)
       return failure();
 
-    SmallVector<OpFoldResult> positionVec;
-    for (auto [idx, pos] : llvm::enumerate(insertOp.getMixedPosition())) {
-      if (pos.is<Value>())
-        // Make sure we use the value that has been already converted to LLVM.
-        positionVec.push_back(adaptor.getDynamicPosition()[idx]);
-      else
-        positionVec.push_back(pos);
-    }
+    SmallVector<OpFoldResult> positionVec = getMixedValues(
+        adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
 
     // Overwrite entire vector with value. Should be handled by folder, but
     // just to be safe.
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index e94e51d49a98b7..e2528bebdf3faa 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -738,6 +738,18 @@ func.func @extract_element_with_value_1d(%arg0: vector<16xf32>, %arg1: index) ->
 
 // -----
 
+func.func @extract_element_with_value_2d(%arg0: vector<1x16xf32>, %arg1: index) -> f32 {
+  %0 = vector.extract %arg0[0, %arg1]: f32 from vector<1x16xf32>
+  return %0 : f32
+}
+
+// Multi-dim vectors are not supported but this test shouldn't crash.
+
+// CHECK-LABEL: @extract_element_with_value_2d(
+//       CHECK:   vector.extract
+
+// -----
+
 // CHECK-LABEL: @insert_element_0d
 // CHECK-SAME: %[[A:.*]]: f32,
 func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
@@ -840,16 +852,16 @@ func.func @insert_element_into_vec_3d(%arg0: f32, %arg1: vector<4x8x16xf32>) ->
 
 // -----
 
-func.func @insert_element_with_value_1d(%arg0: vector<16xf32>, %arg1: f32, %arg2: index)
-                                      -> vector<16xf32> {
-  %0 = vector.insert %arg1, %arg0[%arg2]: f32 into vector<16xf32>
-  return %0 : vector<16xf32>
+func.func @insert_element_with_value_2d(%base: vector<1x16xf32>, %value: f32, %idx: index)
+                                        -> vector<1x16xf32> {
+  %0 = vector.insert %value, %base[0, %idx]: f32 into vector<1x16xf32>
+  return %0 : vector<1x16xf32>
 }
 
-// CHECK-LABEL: @insert_element_with_value_1d
-//  CHECK-SAME:   %[[DST:.+]]: vector<16xf32>, %[[SRC:.+]]: f32, %[[INDEX:.+]]: index
-//       CHECK:   %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
-//       CHECK:   llvm.insertelement %[[SRC]], %[[DST]][%[[UC]] : i64] : vector<16xf32>
+// Multi-dim vectors are not supported but this test shouldn't crash.
+
+// CHECK-LABEL: @insert_element_with_value_2d(
+//       CHECK:   vector.insert
 
 // -----
 

>From 849a04bff1cf46294d0ea840518a7219ac287ffd Mon Sep 17 00:00:00 2001
From: Diego Caballero <diegocaballero at google.com>
Date: Fri, 5 Apr 2024 18:57:17 +0000
Subject: [PATCH 2/2] Restore the insert_element_with_value_1d test

---
 .../Conversion/VectorToLLVM/vector-to-llvm.mlir     | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index e2528bebdf3faa..1712d3d745b766 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -852,6 +852,19 @@ func.func @insert_element_into_vec_3d(%arg0: f32, %arg1: vector<4x8x16xf32>) ->
 
 // -----
 
+func.func @insert_element_with_value_1d(%arg0: vector<16xf32>, %arg1: f32, %arg2: index)
+                                      -> vector<16xf32> {
+  %0 = vector.insert %arg1, %arg0[%arg2]: f32 into vector<16xf32>
+  return %0 : vector<16xf32>
+}
+
+// CHECK-LABEL: @insert_element_with_value_1d
+//  CHECK-SAME:   %[[DST:.+]]: vector<16xf32>, %[[SRC:.+]]: f32, %[[INDEX:.+]]: index
+//       CHECK:   %[[UC:.+]] = builtin.unrealized_conversion_cast %[[INDEX]] : index to i64
+//       CHECK:   llvm.insertelement %[[SRC]], %[[DST]][%[[UC]] : i64] : vector<16xf32>
+
+// -----
+
 func.func @insert_element_with_value_2d(%base: vector<1x16xf32>, %value: f32, %idx: index)
                                         -> vector<1x16xf32> {
   %0 = vector.insert %value, %base[0, %idx]: f32 into vector<1x16xf32>



More information about the Mlir-commits mailing list