[Mlir-commits] [mlir] [mlir][arith-to-spirv] Fix null dereference when converting trunci/extui with tensor types (PR #183654)

llvmlistbot at llvm.org
Thu Feb 26 16:50:34 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

<details>
<summary>Changes</summary>

`getScalarOrVectorConstInt` only handles `VectorType` and `IntegerType` and returns `nullptr` for any other type. A `RankedTensorType` can still reach it. After type emulation maps `tensor<Nxi16>` to `tensor<Nxi32>`, an `arith.trunci` or `arith.extui` between those two tensor types converts to identical source and destination types, which selects the masking path. The callers in `TruncIPattern` and `ExtUIPattern` then passed the null mask directly to `spirv::BitwiseAndOp::create`, causing a null-pointer dereference in `OperandStorage`.
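Some context on the masking path. When the source and destination convert to the same type, `TruncIPattern` and `ExtUIPattern` clear every bit above the original bitwidth using a mask built from `llvm::maskTrailingOnes<uint64_t>(bitwidth)`. The standalone C++ sketch below models that arithmetic for a single scalar. `maskTrailingOnes64` and `truncInPlace` are hypothetical names. The helper reimplements the documented semantics of `llvm::maskTrailingOnes` so the sketch does not depend on LLVM headers.

```cpp
#include <cassert>
#include <cstdint>

// Same semantics as llvm::maskTrailingOnes<uint64_t>(bits): the low `bits`
// bits are set and the rest are clear. Reimplemented to stay standalone.
uint64_t maskTrailingOnes64(unsigned bits) {
  assert(bits <= 64 && "mask wider than the result type");
  return bits == 0 ? 0 : (~uint64_t{0} >> (64 - bits));
}

// Hypothetical scalar model of the emulated same-type branch: the value keeps
// its 32-bit container, so the bits above `dstBW` are cleared with a bitwise
// AND (spirv.BitwiseAnd in the real lowering).
uint32_t truncInPlace(uint32_t value, unsigned dstBW) {
  return static_cast<uint32_t>(value & maskTrailingOnes64(dstBW));
}
```

For a bitwidth of 16 the mask is `0xFFFF`, so `truncInPlace(0x12345678, 16)` yields `0x5678`. In a wider container, zero-extension is the same operation, which explains why `ExtUIPattern` takes the same approach.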

`ExtSIPattern` has the same problem. It passes the result of `getScalarOrVectorConstInt` as the shift amount to `ShiftLeftLogicalOp::create` without a null check.
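On the same-type path, the shift amount `dstBW - srcBW` is what performs the sign extension. The diff shows only the left shift. This sketch assumes the usual idiom of a logical shift left followed by an arithmetic shift right by the same amount, applied to a 32-bit container. `signExtendLowBits` is a hypothetical name. The code assumes C++20, which defines the conversion of an out-of-range unsigned value to `int32_t` and the right shift of a negative value.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model of shift-based sign extension inside an emulated 32-bit
// container: the low `srcBW` bits of `raw` are treated as a signed value.
// Requires C++20 for the well-defined unsigned->signed conversion and the
// arithmetic right shift of negative values.
int32_t signExtendLowBits(uint32_t raw, unsigned srcBW) {
  const unsigned dstBW = 32;
  assert(srcBW > 0 && srcBW < dstBW);
  unsigned shiftSize = dstBW - srcBW;
  // Logical shift left: moves the original sign bit into bit 31 and
  // discards any stale bits above the original bitwidth.
  uint32_t shifted = raw << shiftSize;
  // Arithmetic shift right: copies bit 31 back down over the high bits.
  return static_cast<int32_t>(shifted) >> shiftSize;
}
```

For example, `signExtendLowBits(0xFFFF, 16)` is `-1`. Garbage above bit 15 is ignored as well: `0x12348000` yields `-32768`.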

Add `if (!mask)` / `if (!shiftSize)` guards that return a match failure in all three cases. This turns the crash into a proper legalization failure.
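Here is a toy model of the guard itself. It does not use MLIR. `Kind`, `ToyResult`, `toyConstInt`, and `toyTruncISameType` are hypothetical stand-ins. A `std::optional` plays the role of a possibly-null `Value`, and `ToyResult` stands in for the `LogicalResult` returned by `rewriter.notifyMatchFailure`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical stand-ins; not the MLIR API.
enum class Kind { Integer, Vector, RankedTensor };

struct ToyResult {
  bool succeeded;
  std::string reason; // empty on success
};

// Models the described behavior of getScalarOrVectorConstInt:
// only integer and vector kinds produce a constant.
std::optional<uint64_t> toyConstInt(Kind kind, uint64_t value) {
  if (kind == Kind::Integer || kind == Kind::Vector)
    return value;
  return std::nullopt;
}

// Models the same-type branch of TruncIPattern with the added guard.
ToyResult toyTruncISameType(Kind dstKind, unsigned originalBW) {
  assert(originalBW > 0 && originalBW < 64);
  std::optional<uint64_t> mask =
      toyConstInt(dstKind, (uint64_t{1} << originalBW) - 1);
  if (!mask) // bail out instead of building an op from a missing operand
    return {false, "unsupported type for mask"};
  return {true, ""};
}
```

A failed match leaves the op unconverted. The conversion driver then reports `failed to legalize operation 'arith.trunci'`, and the new regression test expects exactly that diagnostic.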

Fixes #178214

---
Full diff: https://github.com/llvm/llvm-project/pull/183654.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+6) 
- (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir (+10) 


``````````diff
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 299c8afffb2e5..0bc001b5d576a 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -725,6 +725,8 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
       assert(srcBW < dstBW);
       Value shiftSize = getScalarOrVectorConstInt(dstType, dstBW - srcBW,
                                                   rewriter, op.getLoc());
+      if (!shiftSize)
+        return rewriter.notifyMatchFailure(op, "unsupported type for shift");
 
       // First shift left to sequeeze out all leading bits beyond the original
       // bitwidth. Here we need to use the original source and result type's
@@ -800,6 +802,8 @@ struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
       Value mask = getScalarOrVectorConstInt(
           dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
           op.getLoc());
+      if (!mask)
+        return rewriter.notifyMatchFailure(op, "unsupported type for mask");
       rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
                                                        adaptor.getIn(), mask);
     } else {
@@ -868,6 +872,8 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
       unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
       Value mask = getScalarOrVectorConstInt(
           dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter, op.getLoc());
+      if (!mask)
+        return rewriter.notifyMatchFailure(op, "unsupported type for mask");
       rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
                                                        adaptor.getIn(), mask);
     } else {
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 9d7ab2be096ef..92b587d5ed1e4 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -122,6 +122,16 @@ func.func @unsupported_constant_tensor_2xf64_0() {
 
 // -----
 
+// Regression test: arith.trunci on tensor types should not crash
+// (https://github.com/llvm/llvm-project/issues/178214).
+func.func @trunci_tensor_no_crash(%arg0: tensor<1xi32>) -> tensor<1xi16> {
+  // expected-error @+1 {{failed to legalize operation 'arith.trunci'}}
+  %0 = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+  return %0 : tensor<1xi16>
+}
+
+// -----
+
 func.func @constant_dense_resource_non_existant() {
   // expected-error @+2 {{failed to legalize operation 'arith.constant'}}
   // expected-error @+1 {{could not find resource blob}}

``````````

</details>


https://github.com/llvm/llvm-project/pull/183654

