[Mlir-commits] [mlir] [MLIR] Fix `BubbleDownVectorBitCastForExtract` crash on non-static index (PR #116518)
llvmlistbot at llvm.org
Sat Nov 16 17:30:38 PST 2024
https://github.com/lialan created https://github.com/llvm/llvm-project/pull/116518
Previously, the pattern did not expect a non-static index: when the extract position is a non-constant value, it crashes (the index lambda asserts that the index is a constant).
This patch makes the pattern return `failure()` gracefully instead of crashing.
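A minimal, standalone C++ sketch of the same early-exit idiom, outside MLIR. `MixedIndex`, `getFirstIntValue` and `tryBubbleDown` here are hypothetical stand-ins, not MLIR APIs: `std::variant<int64_t, std::string>` plays the role of `OpFoldResult` (a constant attribute vs. a runtime SSA value), and returning `false` plays the role of `return failure();`. Unlike the patch, the sketch also guards against an empty position list.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for mlir::OpFoldResult: an index is either a
// compile-time constant (int64_t) or a runtime SSA value (named by a string).
using MixedIndex = std::variant<int64_t, std::string>;

// Mirrors the patched lambda: return std::nullopt instead of asserting when
// the first index is not a constant.
std::optional<uint64_t> getFirstIntValue(const std::vector<MixedIndex> &values) {
  if (values.empty()) // extra guard, not present in the patch
    return std::nullopt;
  if (const int64_t *cst = std::get_if<int64_t>(&values[0]))
    return static_cast<uint64_t>(*cst);
  return std::nullopt; // dynamic index: no compile-time value
}

// Stand-in for matchAndRewrite: `false` plays the role of `return failure();`,
// i.e. "this pattern does not apply", so nothing is rewritten.
bool tryBubbleDown(const std::vector<MixedIndex> &position) {
  std::optional<uint64_t> optIndex = getFirstIntValue(position);
  if (!optIndex)
    return false;
  uint64_t index = *optIndex;
  (void)index; // the real pattern would build the rewritten IR from `index`
  return true;
}
```

Returning `std::nullopt` instead of asserting moves the decision to the caller; in a rewrite pattern, reporting failure only means the pattern does not match, so the driver leaves the op unchanged, which is what the new test checks.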
>From 3d940ccaf238090edf26032edcbe9ab57683a42b Mon Sep 17 00:00:00 2001
From: Ubuntu <450283+lialan at users.noreply.github.com>
Date: Sun, 17 Nov 2024 01:22:35 +0000
Subject: [PATCH] [MLIR] Fix `BubbleDownVectorBitCastForExtract` crash on
non-static index
Previously, the pattern did not expect a non-static index: when the
extract position is a non-constant value, it crashes (the index lambda
asserts that the index is a constant).
This patch makes the pattern return `failure()` gracefully instead of
crashing.
---
.../Dialect/Vector/Transforms/VectorTransforms.cpp | 10 +++++++---
mlir/test/Dialect/Vector/vector-transforms.mlir | 13 +++++++++++++
2 files changed, 20 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 7f6b2303f86e10..3745bee98f3b85 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -596,12 +596,16 @@ struct BubbleDownVectorBitCastForExtract
unsigned expandRatio =
castDstType.getNumElements() / castSrcType.getNumElements();
- auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> uint64_t {
- assert(values[0].is<Attribute>() && "Unexpected non-constant index");
+ auto getFirstIntValue = [](ArrayRef<OpFoldResult> values) -> std::optional<uint64_t> {
+ if (!values[0].is<Attribute>())
+ return std::nullopt;
return cast<IntegerAttr>(values[0].get<Attribute>()).getInt();
};
- uint64_t index = getFirstIntValue(extractOp.getMixedPosition());
+ std::optional<uint64_t> optIndex = getFirstIntValue(extractOp.getMixedPosition());
+ if (!optIndex)
+ return failure();
+ uint64_t index = *optIndex;
// Get the single scalar (as a vector) in the source value that packs the
// desired scalar. E.g. extract vector<1xf32> from vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
index 89e8ca1d93109c..de12a87253a673 100644
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -433,3 +433,16 @@ func.func @vec_0D(%arg0: vector<f32>) -> vector<i32> {
%0 = vector.bitcast %arg0 : vector<f32> to vector<i32>
return %0 : vector<i32>
}
+
+// Make sure the pattern does not crash on a `vector.extract` with a dynamic index:
+func.func @vector_extract_dynamic_index(%arg0 : vector<4xi32>, %index : index) -> i16 {
+ %0 = vector.bitcast %arg0 : vector<4xi32> to vector<8xi16>
+ %1 = vector.extract %0[%index] : i16 from vector<8xi16>
+ return %1 : i16
+}
+
+// CHECK-LABEL: func.func @vector_extract_dynamic_index
+// CHECK-SAME: (%[[VEC:.+]]: vector<4xi32>, %[[IDX:.+]]: index) -> i16 {
+// CHECK: %[[BC:.+]] = vector.bitcast %[[VEC]] : vector<4xi32> to vector<8xi16>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BC]][%[[IDX]]] : i16 from vector<8xi16>
+// CHECK: return %[[EXTRACT]]
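For intuition on why the pattern needs a constant index: assuming it splits that index by `expandRatio` (the quotient selects the packed source element, the remainder selects the element inside it after re-bitcasting), the arithmetic for the test's `vector<4xi32>` to `vector<8xi16>` cast looks like the hypothetical helper below. With `expandRatio = 8 / 4 = 2`, a constant index 5 maps to source element 2 and sub-element 1; a dynamic `%index` has no compile-time value to split, so the helper, like the patched pattern, bails out. `BubbledIndex` and `splitIndex` are illustrative names, not part of MLIR.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical result: where a constant extract index lands once the
// bitcast is moved below the extract.
struct BubbledIndex {
  uint64_t packedIndex; // element to extract from the bitcast's source
  uint64_t subIndex;    // element to extract after re-bitcasting it
};

// Splits a constant index by the bitcast's expand ratio; returns
// std::nullopt for a dynamic index, mirroring the patched early exit.
std::optional<BubbledIndex> splitIndex(uint64_t srcNumElements,
                                       uint64_t dstNumElements,
                                       std::optional<uint64_t> index) {
  if (!index)
    return std::nullopt; // dynamic index: cannot split at compile time
  assert(srcNumElements > 0 && dstNumElements % srcNumElements == 0 &&
         "expected the bitcast to widen the element count");
  uint64_t expandRatio = dstNumElements / srcNumElements;
  return BubbledIndex{*index / expandRatio, *index % expandRatio};
}
```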