[Mlir-commits] [mlir] [MLIR] Fix `BubbleDownVectorBitCastForExtract` crash on non-static index (PR #116518)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Nov 16 17:39:10 PST 2024


https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/116518

From 91342ccb77fec96a3ce77e8e125090f917272865 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Sun, 17 Nov 2024 01:22:35 +0000
Subject: [PATCH] [MLIR] Fix `BubbleDownVectorBitCastForExtract` crash on
 non-static index

Previously, the pattern did not expect a non-static index: when the
index was a non-constant value, it hit an assertion and crashed.

This patch makes the pattern return failure gracefully instead of
crashing.
---
 .../Dialect/Vector/Transforms/VectorTransforms.cpp  | 12 +++++++++---
 mlir/test/Dialect/Vector/vector-transforms.mlir     | 13 +++++++++++++
 2 files changed, 22 insertions(+), 3 deletions(-)
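
The change replaces an assertion with an optional return plus an early exit.
Below is a minimal standalone sketch of that idea, not the MLIR code itself:
`FoldResult` and `DynamicValue` are hypothetical stand-ins for
`OpFoldResult` and a runtime SSA value, `tryRewrite` is a hypothetical
wrapper that only models the success/failure outcome, and the empty-list
guard is an addition of the sketch that is not part of the patch.

```cpp
#include <cstdint>
#include <optional>
#include <variant>
#include <vector>

// Hypothetical stand-in for a runtime (non-constant) index value.
struct DynamicValue {
  int id;
};

// Hypothetical stand-in for mlir::OpFoldResult: either a compile-time
// constant (int64_t) or a runtime value.
using FoldResult = std::variant<int64_t, DynamicValue>;

// Returns the first index if it is a constant, std::nullopt otherwise.
// The empty-list guard is an extra safety check added for this sketch.
std::optional<uint64_t>
getFirstIntValue(const std::vector<FoldResult> &values) {
  if (values.empty())
    return std::nullopt;
  if (const int64_t *constant = std::get_if<int64_t>(&values.front()))
    return static_cast<uint64_t>(*constant);
  return std::nullopt;
}

// Models the pattern's control flow: false plays the role of
// `return failure();`, true means the rewrite would go ahead.
bool tryRewrite(const std::vector<FoldResult> &position) {
  std::optional<uint64_t> index = getFirstIntValue(position);
  if (!index)
    return false;
  return true;
}
```

With a constant position the sketch proceeds; with a dynamic position it
bails out and leaves the input untouched, which is the behavior the new test
checks for.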

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 7f6b2303f86e10..55ab00e0cf3ba9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -596,12 +596,18 @@ struct BubbleDownVectorBitCastForExtract
     unsigned expandRatio =
         castDstType.getNumElements() / castSrcType.getNumElements();
 
-    auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
-      assert(values[0].is<Attribute>() && "Unexpected non-constant index");
+    auto getFirstIntValue =
+        [](ArrayRef<OpFoldResult> values) -> std::optional<uint64_t> {
+      if (!values[0].is<Attribute>())
+        return std::nullopt;
       return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
     };
 
-    uint64_t index = getFirstIntValue(extractOp.getMixedPosition());
+    std::optional<uint64_t> optIndex =
+        getFirstIntValue(extractOp.getMixedPosition());
+    if (!optIndex)
+      return failure();
+    uint64_t index = *optIndex;
 
     // Get the single scalar (as a vector) in the source value that packs the
     // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 89e8ca1d93109c..de12a87253a673 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -433,3 +433,16 @@ func.func @vec_0D(%arg0: vector<f32>) -> vector<i32> {
   %0 = vector.bitcast %arg0 : vector<f32> to vector<i32>
   return %0 : vector<i32>
 }
+
+// Make sure the pattern does not crash on a dynamic-index `vector.extract`:
+func.func @vector_extract_dynamic_index(%arg0 : vector<4xi32>, %index : index) -> i16 {
+  %0 = vector.bitcast %arg0 : vector<4xi32> to vector<8xi16>
+  %1 = vector.extract %0[%index] : i16 from vector<8xi16>
+  return %1 : i16
+}
+
+// CHECK-LABEL: func.func @vector_extract_dynamic_index
+// CHECK-SAME: (%[[VEC:.+]]: vector<4xi32>, %[[IDX:.+]]: index) -> i16 {
+// CHECK: %[[BC:.+]] = vector.bitcast %[[VEC]] : vector<4xi32> to vector<8xi16>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BC]][%[[IDX]]] : i16 from vector<8xi16>
+// CHECK: return %[[EXTRACT]]


