[Mlir-commits] [mlir] [mlir][arith-to-spirv] Fix null dereference when converting trunci/extui with tensor types (PR #183654)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 26 16:50:34 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
`getScalarOrVectorConstInt` only handles `VectorType` and `IntegerType`, returning `nullptr` for any other type (e.g., a `RankedTensorType`, which can reach these patterns when type emulation maps `tensor<Nxi16>` to `tensor<Nxi32>` so that the converted source and destination types are identical). The callers in `TruncIPattern` and `ExtUIPattern` passed this null value directly to `spirv::BitwiseAndOp::create`, causing a null-pointer dereference in `OperandStorage`.
Similarly, the signed-extension pattern passes the result of `getScalarOrVectorConstInt` as a shift amount to `ShiftLeftLogicalOp::create` without a null check.
Add `if (!mask)` / `if (!shiftSize)` guards that return a match failure in all three cases, converting the crash into a proper legalization failure.
Fixes #178214
---
Full diff: https://github.com/llvm/llvm-project/pull/183654.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+6)
- (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir (+10)
``````````diff
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 299c8afffb2e5..0bc001b5d576a 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -725,6 +725,8 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
assert(srcBW < dstBW);
Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
rewriter, op.getLoc());
+ if (!shiftSize)
+ return rewriter.notifyMatchFailure(op, "unsupported type for shift");
// First shift left to sequeeze out all leading bits beyond the original
// bitwidth. Here we need to use the original source and result type's
@@ -800,6 +802,8 @@ struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
Value mask = getScalarOrVectorConstInt(
dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
op.getLoc());
+ if (!mask)
+ return rewriter.notifyMatchFailure(op, "unsupported type for mask");
rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
adaptor.getIn(), mask);
} else {
@@ -868,6 +872,8 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
Value mask = getScalarOrVectorConstInt(
dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
+ if (!mask)
+ return rewriter.notifyMatchFailure(op, "unsupported type for mask");
rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
adaptor.getIn(), mask);
} else {
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 9d7ab2be096ef..92b587d5ed1e4 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -122,6 +122,16 @@ func.func @unsupported_constant_tensor_2xf64_0() {
// -----
+// Regression test: arith.trunci on tensor types should not crash
+// (https://github.com/llvm/llvm-project/issues/178214).
+func.func @trunci_tensor_no_crash(%arg0: tensor<1xi32>) -> tensor<1xi16> {
+ // expected-error @+1 {{failed to legalize operation 'arith.trunci'}}
+ %0 = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+ return %0 : tensor<1xi16>
+}
+
+// -----
+
func.func @constant_dense_resource_non_existant() {
// expected-error @+2 {{failed to legalize operation 'arith.constant'}}
// expected-error @+1 {{could not find resource blob}}
``````````
</details>
https://github.com/llvm/llvm-project/pull/183654
More information about the Mlir-commits mailing list