[Mlir-commits] [mlir] [mlir][vector] Fix a crash in `VectorExtractOpConversion` (PR #115717)

Longsheng Mou llvmlistbot at llvm.org
Mon Nov 11 06:01:09 PST 2024


https://github.com/CoTinker updated https://github.com/llvm/llvm-project/pull/115717

>From 14191445e73ec9c0a655ce338d7d0063848b1223 Mon Sep 17 00:00:00 2001
From: Longsheng Mou <moulongsheng at huawei.com>
Date: Mon, 11 Nov 2024 21:43:49 +0800
Subject: [PATCH] [mlir][vector] Fix a crash in `VectorExtractOpConversion`

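This PR fixes a crash in `VectorExtractOpConversion` when `vector.extract`
extracts a scalar and the size of `position` is smaller than the rank of the
source vector, which is possible when the remaining trailing dimensions form a
single-element vector.
E.g., `vector.extract %arg0[0]: f32 from vector<4x1xf32>`.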
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp      |  5 ++---
 .../VectorToLLVM/vector-to-llvm.mlir          | 20 +++++++++++++++++++
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 58ca84c8d7bca6..9cb09e9d9c248d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1105,9 +1105,8 @@ class VectorExtractOpConversion
 
     // One-shot extraction of vector from array (only requires extractvalue).
     // Except for extracting 1-element vectors.
-    if (isa<VectorType>(resultType) &&
-        position.size() !=
-            static_cast<size_t>(extractOp.getSourceVectorType().getRank())) {
+    if (position.size() <
+        static_cast<size_t>(extractOp.getSourceVectorType().getRank())) {
       if (extractOp.hasDynamicPosition())
         return failure();
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 03bcb341efea2f..26f1f3e5190df5 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1216,6 +1216,26 @@ func.func @extract_vec_1d_from_vec_3d_f32_scalable(%arg0: vector<4x3x[16]xf32>)
 
 // -----
 
+func.func @extract_scalar_from_vec_2d_f32(%arg0: vector<4x1xf32>) -> f32 {
+  %0 = vector.extract %arg0[0]: f32 from vector<4x1xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32
+//       CHECK:   %[[T0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.array<4 x vector<1xf32>>
+//       CHECK:   %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<1xf32> to f32
+//       CHECK:   return %[[T1]] : f32
+
+func.func @extract_scalar_from_vec_2d_f32_scalable(%arg0: vector<4x[1]xf32>) -> f32 {
+  %0 = vector.extract %arg0[0]: f32 from vector<4x[1]xf32>
+  return %0 : f32
+}
+// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_scalable
+//       CHECK:   %[[T0:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.array<4 x vector<[1]xf32>>
+//       CHECK:   %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<[1]xf32> to f32
+//       CHECK:   return %[[T1]] : f32
+
+// -----
+
 func.func @extract_scalar_from_vec_3d_f32(%arg0: vector<4x3x16xf32>) -> f32 {
   %0 = vector.extract %arg0[0, 0, 0]: f32 from vector<4x3x16xf32>
   return %0 : f32



More information about the Mlir-commits mailing list