[Mlir-commits] [mlir] [mlir][arith-to-spirv] Fix null dereference when converting trunci/extui with tensor types (PR #183654)
Mehdi Amini
llvmlistbot at llvm.org
Thu Feb 26 16:49:59 PST 2026
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/183654
`getScalarOrVectorConstInt` only handles `VectorType` and `IntegerType`, returning `nullptr` for any other type (e.g., a `RankedTensorType`, which can reach these patterns when type emulation maps `tensor<Nxi16>` to `tensor<Nxi32>`, so that the converted destination type matches the source type). The callers in `TruncIPattern` and `ExtUIPattern` passed this null value directly to `spirv::BitwiseAndOp::create`, causing a null-pointer dereference in `OperandStorage`.
Similarly, the signed-extension pattern (`ExtSIPattern`) passes the result of `getScalarOrVectorConstInt` as a shift amount to `ShiftLeftLogicalOp::create` without a null check.
Add `if (!mask)` / `if (!shiftSize)` guards that return a match failure in all three cases, converting the crash into a proper legalization failure.
Fixes #178214
>From a5562d1d60c4d23279ab2392379fcad96fc52092 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Feb 2026 16:49:10 -0800
Subject: [PATCH] [mlir][arith-to-spirv] Fix null dereference when converting
trunci/extui with tensor types
`getScalarOrVectorConstInt` only handles `VectorType` and `IntegerType`,
returning `nullptr` for any other type (e.g., a `RankedTensorType`, which
can reach these patterns when type emulation maps `tensor<Nxi16>` to
`tensor<Nxi32>`, so that the converted destination type matches the source
type). The callers in `TruncIPattern` and
`ExtUIPattern` passed this null value directly to `spirv::BitwiseAndOp::create`,
causing a null-pointer dereference in `OperandStorage`.
Similarly, the signed-extension pattern (`ExtSIPattern`) passes the result of
`getScalarOrVectorConstInt` as a shift amount to `ShiftLeftLogicalOp::create`
without a null check.
Add `if (!mask)` / `if (!shiftSize)` guards that return a match failure
in all three cases, converting the crash into a proper legalization failure.
Fixes #178214
---
mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 6 ++++++
.../ArithToSPIRV/arith-to-spirv-unsupported.mlir | 10 ++++++++++
2 files changed, 16 insertions(+)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 299c8afffb2e5..0bc001b5d576a 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -725,6 +725,8 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
assert(srcBW < dstBW);
Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
rewriter, op.getLoc());
+ if (!shiftSize)
+ return rewriter.notifyMatchFailure(op, "unsupported type for shift");
// First shift left to sequeeze out all leading bits beyond the original
// bitwidth. Here we need to use the original source and result type's
@@ -800,6 +802,8 @@ struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
Value mask = getScalarOrVectorConstInt(
dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
op.getLoc());
+ if (!mask)
+ return rewriter.notifyMatchFailure(op, "unsupported type for mask");
rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
adaptor.getIn(), mask);
} else {
@@ -868,6 +872,8 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
Value mask = getScalarOrVectorConstInt(
dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
+ if (!mask)
+ return rewriter.notifyMatchFailure(op, "unsupported type for mask");
rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
adaptor.getIn(), mask);
} else {
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 9d7ab2be096ef..92b587d5ed1e4 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -122,6 +122,16 @@ func.func @unsupported_constant_tensor_2xf64_0() {
// -----
+// Regression test: arith.trunci on tensor types should not crash
+// (https://github.com/llvm/llvm-project/issues/178214).
+func.func @trunci_tensor_no_crash(%arg0: tensor<1xi32>) -> tensor<1xi16> {
+ // expected-error @+1 {{failed to legalize operation 'arith.trunci'}}
+ %0 = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+ return %0 : tensor<1xi16>
+}
+
+// -----
+
func.func @constant_dense_resource_non_existant() {
// expected-error @+2 {{failed to legalize operation 'arith.constant'}}
// expected-error @+1 {{could not find resource blob}}
More information about the Mlir-commits
mailing list