[Mlir-commits] [mlir] [mlir][arith-to-spirv] Fix null dereference when converting trunci/extui with tensor types (PR #183654)

Mehdi Amini llvmlistbot at llvm.org
Thu Feb 26 16:49:59 PST 2026


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/183654

`getScalarOrVectorConstInt` only handles `VectorType` and `IntegerType`, returning `nullptr` for any other type (e.g., a `RankedTensorType` that slips through after type emulation maps `tensor<Nxi16>` to `tensor<Nxi32>` with the same destination type). The callers in `TruncIPattern` and `ExtUIPattern` passed this null value directly to `spirv::BitwiseAndOp::create`, causing a null-pointer dereference in `OperandStorage`.

Similarly, the signed-extension pattern passes the result of `getScalarOrVectorConstInt` as a shift amount to `ShiftLeftLogicalOp::create` without a null check.
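
For readers unfamiliar with the emulation, the shift amount in question is `dstBW - srcBW`, as the patch context shows. The following is a minimal standalone sketch of that arithmetic on plain integers, not the MLIR code itself. The helper name `signExtendViaShifts` and the fixed 32-bit container are assumptions made for illustration, and the arithmetic right shift of a negative value is only guaranteed by the standard from C++20 onward (mainstream compilers behave this way earlier as well).

```cpp
#include <cstdint>

// Hypothetical helper (not part of MLIR): sign-extend the low `srcBW` bits of
// a value stored in a 32-bit container. Precondition: 0 < srcBW <= 32.
inline int32_t signExtendViaShifts(uint32_t value, unsigned srcBW) {
  const unsigned dstBW = 32;
  // Same quantity the patch computes as the shift constant.
  const unsigned shift = dstBW - srcBW;
  // Shift left to discard every bit above the original bitwidth.
  const uint32_t shifted = value << shift;
  // Shift right arithmetically so the original sign bit is replicated.
  return static_cast<int32_t>(shifted) >> shift;
}
```

With `srcBW = 16`, the input `0x8000` comes back as `-32768`, and any stale bits above bit 15 are dropped before the sign is replicated.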

Add `if (!mask)` / `if (!shiftSize)` guards that return a match failure in all three cases, converting the crash into a proper legalization failure.
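
The guard pattern can be sketched outside of MLIR as follows. This is a simplified model under stated assumptions, not the actual implementation: `TypeKind`, `Constant`, `RewriteResult`, `makeScalarOrVectorConst`, and `lowerTruncWithMask` are hypothetical stand-ins, and `std::optional` plays the role of the null `Value` that `getScalarOrVectorConstInt` returns for unsupported types.

```cpp
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical type classification standing in for MLIR types.
enum class TypeKind { Integer, Vector, Tensor };

// Hypothetical constant standing in for an MLIR Value.
struct Constant {
  TypeKind kind;
  uint64_t value;
};

// Stand-in for getScalarOrVectorConstInt: only integer and vector kinds are
// handled; anything else yields "no value" (a null Value in the real code).
std::optional<Constant> makeScalarOrVectorConst(TypeKind kind, uint64_t value) {
  if (kind == TypeKind::Integer || kind == TypeKind::Vector)
    return Constant{kind, value};
  return std::nullopt;
}

// Stand-in for the result of a conversion pattern.
struct RewriteResult {
  bool succeeded;
  std::string message;
};

// Stand-in for the truncation pattern: build a mask of `bitwidth` trailing
// ones, and bail out with a failure instead of using a missing mask.
RewriteResult lowerTruncWithMask(TypeKind dstKind, unsigned bitwidth) {
  const uint64_t ones = bitwidth >= 64 ? ~0ULL : ((1ULL << bitwidth) - 1);
  std::optional<Constant> mask = makeScalarOrVectorConst(dstKind, ones);
  if (!mask)
    return {false, "unsupported type for mask"};
  // The real pattern would create the bitwise AND with the mask here.
  return {true, ""};
}
```

In this model, a tensor destination type produces a reported failure rather than a dereference of a missing value, which mirrors the intent of the guards.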

Fixes #178214

>From a5562d1d60c4d23279ab2392379fcad96fc52092 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 26 Feb 2026 16:49:10 -0800
Subject: [PATCH] [mlir][arith-to-spirv] Fix null dereference when converting
 trunci/extui with tensor types

`getScalarOrVectorConstInt` only handles `VectorType` and `IntegerType`,
returning `nullptr` for any other type (e.g., a `RankedTensorType` that
slips through after type emulation maps `tensor<Nxi16>` to `tensor<Nxi32>`
with the same destination type). The callers in `TruncIPattern` and
`ExtUIPattern` passed this null value directly to `spirv::BitwiseAndOp::create`,
causing a null-pointer dereference in `OperandStorage`.

Similarly, the signed-extension pattern passes the result of
`getScalarOrVectorConstInt` as a shift amount to `ShiftLeftLogicalOp::create`
without a null check.

Add `if (!mask)` / `if (!shiftSize)` guards that return a match failure
in all three cases, converting the crash into a proper legalization failure.

Fixes #178214
---
 mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp      |  6 ++++++
 .../ArithToSPIRV/arith-to-spirv-unsupported.mlir       | 10 ++++++++++
 2 files changed, 16 insertions(+)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 299c8afffb2e5..0bc001b5d576a 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -725,6 +725,8 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
       assert(srcBW < dstBW);
       Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
                                                   rewriter, op.getLoc());
+      if (!shiftSize)
+        return rewriter.notifyMatchFailure(op, "unsupported type for shift");
 
       // First shift left to sequeeze out all leading bits beyond the original
       // bitwidth. Here we need to use the original source and result type's
@@ -800,6 +802,8 @@ struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
       Value mask = getScalarOrVectorConstInt(
           dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
           op.getLoc());
+      if (!mask)
+        return rewriter.notifyMatchFailure(op, "unsupported type for mask");
       rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
                                                        adaptor.getIn(), mask);
     } else {
@@ -868,6 +872,8 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
       unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
       Value mask = getScalarOrVectorConstInt(
           dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
+      if (!mask)
+        return rewriter.notifyMatchFailure(op, "unsupported type for mask");
       rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
                                                        adaptor.getIn(), mask);
     } else {
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 9d7ab2be096ef..92b587d5ed1e4 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -122,6 +122,16 @@ func.func @unsupported_constant_tensor_2xf64_0() {
 
 // -----
 
+// Regression test: arith.trunci on tensor types should not crash
+// (https://github.com/llvm/llvm-project/issues/178214).
+func.func @trunci_tensor_no_crash(%arg0: tensor<1xi32>) -> tensor<1xi16> {
+  // expected-error @+1 {{failed to legalize operation 'arith.trunci'}}
+  %0 = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+  return %0 : tensor<1xi16>
+}
+
+// -----
+
 func.func @constant_dense_resource_non_existant() {
   // expected-error @+2 {{failed to legalize operation 'arith.constant'}}
   // expected-error @+1 {{could not find resource blob}}


