[Mlir-commits] [mlir] [MLIR][Vector] Fix crash in BitCastOp::fold for index element type (PR #183572)
llvmlistbot at llvm.org
Thu Feb 26 08:59:47 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-vector
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
`BitCastOp::fold` called `Type::getIntOrFloatBitWidth()` on the source element type without first verifying it satisfies `isIntOrFloat()`. When the source vector has `index` element type (e.g. `vector<16xindex>`), the assertion `only integers and floats have a bitwidth` fires.
Add an `srcElemType.isIntOrFloat()` guard to the condition so that the constant-folding path is skipped when the source element type is neither an integer nor a float.
Fixes #177835
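To make the failure mode concrete, here is a minimal, self-contained C++ toy model of the guarded condition. It does not use MLIR: `Kind`, `ElemType`, and `splatFoldBitWidths` are hypothetical stand-ins for the builtin element types, `mlir::Type`, and the relevant branch of `BitCastOp::fold`. As with `Type::getIntOrFloatBitWidth()`, asking the toy `index` type for its bit width trips an assertion, so the widths are only queried after `isIntOrFloat()` has succeeded.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <utility>

// Hypothetical stand-in for MLIR element kinds; this is NOT the mlir::Type API.
enum class Kind { Integer, Float, Index };

struct ElemType {
  Kind kind;
  unsigned width;  // Only meaningful for Integer and Float.

  bool isIntOrFloat() const {
    return kind == Kind::Integer || kind == Kind::Float;
  }

  // Models the precondition of Type::getIntOrFloatBitWidth(): asking an
  // `index` type for its width is a programming error, so it asserts.
  unsigned getIntOrFloatBitWidth() const {
    assert(isIntOrFloat() && "only integers and floats have a bitwidth");
    return width;
  }
};

// Toy version of the patched condition: bit widths are queried only when the
// destination is an integer AND the source is an integer or float. Otherwise
// folding is skipped (std::nullopt) and the bitcast stays in the IR.
std::optional<std::pair<uint64_t, uint64_t>>
splatFoldBitWidths(ElemType srcElemType, ElemType dstElemType) {
  if (dstElemType.kind == Kind::Integer && srcElemType.isIntOrFloat()) {
    uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
    uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
    return std::make_pair(srcBitWidth, dstBitWidth);
  }
  return std::nullopt;
}
```

With the `isIntOrFloat()` check in place, `splatFoldBitWidths` returns `std::nullopt` for an `index` source, so the bitcast is left unfolded as the new `@bitcast_index_no_fold` test expects. Dropping the check would make the `index` case hit the assertion, which is the toy analogue of the reported crash.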
---
Full diff: https://github.com/llvm/llvm-project/pull/183572.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-1)
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+13)
``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 613adeb5eeaaf..25c2fe71f5ff4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6743,7 +6743,7 @@ OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
if (intPack.isSplat()) {
auto splat = intPack.getSplatValue<IntegerAttr>();
- if (llvm::isa<IntegerType>(dstElemType)) {
+ if (llvm::isa<IntegerType>(dstElemType) && srcElemType.isIntOrFloat()) {
uint64_t srcBitWidth = srcElemType.getIntOrFloatBitWidth();
uint64_t dstBitWidth = dstElemType.getIntOrFloatBitWidth();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 8126389212ce6..82b2cb633d1c9 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1371,6 +1371,19 @@ func.func @bitcast_i8_to_i32() -> (vector<4xi32>, vector<4xi32>) {
// -----
+// Verify that bitcast with index source element type does not crash (the fold
+// must not call getIntOrFloatBitWidth on a non-integer/float type).
+// CHECK-LABEL: func @bitcast_index_no_fold
+// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<16xindex>
+// CHECK: vector.bitcast %[[CST]] : vector<16xindex> to vector<16xi64>
+func.func @bitcast_index_no_fold() -> vector<16xi64> {
+ %cst = arith.constant dense<0> : vector<16xindex>
+ %0 = vector.bitcast %cst : vector<16xindex> to vector<16xi64>
+ return %0 : vector<16xi64>
+}
+
+// -----
+
// CHECK-LABEL: broadcast_poison
// CHECK: %[[POISON:.*]] = ub.poison : vector<4x6xi8>
// CHECK: return %[[POISON]] : vector<4x6xi8>
``````````
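When the guard passes, the fold is working with a splat constant, and bitcasting a splat to a wider integer element type must yield the source bit pattern repeated across each wider lane (an `i8` splat of `0x01` reinterpreted as `i32` gives `0x01010101`, regardless of lane order, since every byte is identical). The diff only shows the bit-width computation, so the sketch below is an assumption about the arithmetic, not the code in `VectorOps.cpp`: it uses plain `uint64_t` instead of `llvm::APInt`, and `widenSplatBits` is a hypothetical name. The existing `@bitcast_i8_to_i32` test named in the hunk header presumably exercises this kind of widening.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: reinterpret a splat of srcBitWidth-bit integers as a
// splat of dstBitWidth-bit integers. Because every source lane holds the same
// bits, each wider lane is the source pattern repeated. Limited to <= 64 bits.
uint64_t widenSplatBits(uint64_t splatBits, unsigned srcBitWidth,
                        unsigned dstBitWidth) {
  assert(srcBitWidth > 0 && srcBitWidth <= dstBitWidth && dstBitWidth <= 64);
  assert(dstBitWidth % srcBitWidth == 0 && "widths must divide evenly");
  // Keep only the low srcBitWidth bits of the splat value.
  uint64_t mask = srcBitWidth == 64 ? ~0ULL : ((1ULL << srcBitWidth) - 1);
  uint64_t lane = splatBits & mask;
  uint64_t result = 0;
  // Place one copy of the lane at every srcBitWidth-aligned offset.
  for (unsigned shift = 0; shift < dstBitWidth; shift += srcBitWidth)
    result |= lane << shift;
  return result;
}
```

None of this arithmetic is reachable for `vector<16xindex>`, because `index` has no fixed bit width to repeat; that is exactly the case the new guard routes away from the folding path.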
</details>
https://github.com/llvm/llvm-project/pull/183572