[Mlir-commits] [mlir] [MLIR] Fix `BubbleDownVectorBitCastForExtract` crash on non-static index (PR #116518)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 18 14:45:10 PST 2024
https://github.com/lialan updated https://github.com/llvm/llvm-project/pull/116518
>From 001d08f6b27c41ebb8f19af2b0e3eca9bdf341c6 Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Sun, 17 Nov 2024 01:22:35 +0000
Subject: [PATCH] [MLIR] Fix `BubbleDownVectorBitCastForExtract` crash on
non-static index
Previously the pattern did not expect a non-static index: when the
index was a non-constant value, it crashed.
This patch makes the pattern return failure gracefully instead of crashing.
---
.../Dialect/Vector/Transforms/VectorTransforms.cpp | 11 +++++------
mlir/test/Dialect/Vector/vector-transforms.mlir | 13 +++++++++++++
2 files changed, 18 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 7f6b2303f86e10..20cd9cba6909a6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -596,12 +596,11 @@ struct BubbleDownVectorBitCastForExtract
unsigned expandRatio =
castDstType.getNumElements() / castSrcType.getNumElements();
- auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
- assert(values[0].is<Attribute>() && "Unexpected non-constant index");
- return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
- };
-
- uint64_t index = getFirstIntValue(extractOp.getMixedPosition());
+ // Bail out on an empty or dynamic (SSA value) position; only constants map.
+ auto mixedPos = extractOp.getMixedPosition();
+ if (mixedPos.empty() || !isa<Attribute>(mixedPos[0]))
+ return failure();
+ uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt();
// Get the single scalar (as a vector) in the source value that packs the
// desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 89e8ca1d93109c..de12a87253a673 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -433,3 +433,16 @@ func.func @vec_0D(%arg0: vector<f32>) -> vector<i32> {
%0 = vector.bitcast %arg0 : vector<f32> to vector<i32>
return %0 : vector<i32>
}
+
+// A dynamic-index `vector.extract` of a `vector.bitcast` must not crash and is left unchanged:
+func.func @vector_extract_dynamic_index(%arg0 : vector<4xi32>, %index : index) -> i16 {
+ %0 = vector.bitcast %arg0 : vector<4xi32> to vector<8xi16>
+ %1 = vector.extract %0[%index] : i16 from vector<8xi16>
+ return %1 : i16
+}
+
+// CHECK-LABEL: func.func @vector_extract_dynamic_index
+// CHECK-SAME: (%[[VEC:.+]]: vector<4xi32>, %[[IDX:.+]]: index) -> i16 {
+// CHECK: %[[BC:.+]] = vector.bitcast %[[VEC]] : vector<4xi32> to vector<8xi16>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BC]][%[[IDX]]] : i16 from vector<8xi16>
+// CHECK: return %[[EXTRACT]]
More information about the Mlir-commits
mailing list