[Mlir-commits] [mlir] [mlir][spirv] Fix crash in ArithToSPIRV trunc-i lowering on tensor types (PR #179009)

llvmlistbot at llvm.org
Fri Jan 30 17:57:37 PST 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir-spirv

Author: Samarth Narang (snarang181)

<details>
<summary>Changes</summary>

This PR fixes a crash in --convert-arith-to-spirv when lowering arith.trunci with tensor operand/result types (e.g., tensor<1xi32> -> tensor<1xi16>). The arith.trunci conversion pattern could proceed even when the operand types were not legal SPIR-V value types. It would then try to create SPIR-V ops (e.g., spirv.BitwiseAndOp / spirv.SConvertOp) with unconverted tensor operands, causing a hard abort during op construction.

The fix adds explicit legality checks to the arith.trunci lowering pattern (TruncIPattern) so that only supported scalar/vector integer types are lowered. Unsupported tensor cases now fail legalization cleanly, reported through the standard "failed to legalize operation" error diagnostic, instead of crashing.
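
The order of the new guards can be modeled outside MLIR. The C++ sketch below is a hypothetical toy: ToyType, Shape, isIntScalarOrVector and checkTrunc are illustrative names, not MLIR classes or functions. It assumes type conversion has already succeeded and mirrors only the two checks the diff adds: both types must be integer scalars or vectors, and the adaptor operand type must equal the converted source type.

```cpp
#include <optional>
#include <string>

// Hypothetical, simplified stand-ins for MLIR types. These are NOT MLIR's
// Type classes; they carry just enough information to model the new guards.
enum class Shape { Scalar, Vector, Tensor };

struct ToyType {
  Shape shape;
  bool isInteger;    // element type is an integer type
  unsigned bitWidth; // element bit width

  bool operator==(const ToyType &o) const {
    return shape == o.shape && isInteger == o.isInteger &&
           bitWidth == o.bitWidth;
  }
  bool operator!=(const ToyType &o) const { return !(*this == o); }
};

// Same intent as the PR's isIntScalarOrVectorType: an integer element type
// held either directly (scalar) or in a vector. Tensors are rejected.
bool isIntScalarOrVector(const ToyType &t) {
  return t.isInteger && (t.shape == Shape::Scalar || t.shape == Shape::Vector);
}

// Models the order of checks added to the trunc-i pattern. Returns
// std::nullopt when lowering may proceed, otherwise a match-failure reason
// (the reason strings are copied from the diff).
std::optional<std::string> checkTrunc(const ToyType &convertedSrc,
                                      const ToyType &adaptorSrc,
                                      const ToyType &dst) {
  // Reject anything that is not an integer scalar/vector, e.g. tensors.
  if (!isIntScalarOrVector(convertedSrc) || !isIntScalarOrVector(dst))
    return std::string("only int scalar or vector type for SPIR-V");
  // Refuse to build ops if the operand was not actually converted.
  if (adaptorSrc != convertedSrc)
    return std::string(
        "adaptor operand type does not match converted source type");
  return std::nullopt;
}
```

In this model, a tensor<1xi32> -> tensor<1xi16> pair is rejected before any op would be built. That corresponds to the pattern returning a match failure, after which the conversion driver reports "failed to legalize operation 'arith.trunci'", as the new test expects.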

Fixes https://github.com/llvm/llvm-project/issues/178214

---
Full diff: https://github.com/llvm/llvm-project/pull/179009.diff


2 Files Affected:

- (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+26-2) 
- (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir (+16) 


``````````diff
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index fa8788de6c3d2..cfee6ad6493c6 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -122,6 +122,13 @@ static bool isBoolScalarOrVector(Type type) {
   return false;
 }
 
+/// Returns true if the given `type` is an integer scalar or vector type.
+static bool isIntScalarOrVectorType(Type type) {
+  auto eltTy = getElementTypeOrSelf(type);
+  return eltTy && isa<IntegerType>(eltTy) &&
+         (isa<IntegerType>(type) || isa<VectorType>(type));
+}
+
 /// Creates a scalar/vector integer constant.
 static Value getScalarOrVectorConstInt(Type type, uint64_t value,
                                        OpBuilder &builder, Location loc) {
@@ -853,15 +860,32 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
   LogicalResult
   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type srcType = adaptor.getIn().getType();
     Type dstType = getTypeConverter()->convertType(op.getType());
     if (!dstType)
       return getTypeConversionFailure(rewriter, op);
 
+    // Trunc-to-i1 case is handled in a separate pattern.
     if (isBoolScalarOrVector(dstType))
       return failure();
 
-    if (dstType == srcType) {
+    Type srcType = op.getIn().getType();
+    Type convertedSrcType = getTypeConverter()->convertType(srcType);
+    if (!convertedSrcType)
+      return getTypeConversionFailure(rewriter, op);
+
+    // Ensure we are only lowering scalar/vector integer truncs.
+    // This prevents trying to build SPIR-V ops on tensors.
+    if (!isIntScalarOrVectorType(convertedSrcType) ||
+        !isIntScalarOrVectorType(dstType))
+      return rewriter.notifyMatchFailure(
+          op, "only int scalar or vector type for SPIR-V");
+
+    // Ensure the adaptor operand type matches the converted source type.
+    if (adaptor.getIn().getType() != convertedSrcType)
+      return rewriter.notifyMatchFailure(
+          op, "adaptor operand type does not match converted source type");
+
+    if (dstType == convertedSrcType) {
       // We can have the same source and destination type due to type emulation.
       // Perform bit masking to make sure we don't pollute downstream consumers
       // with unwanted bits. Here we need to use the original result type's
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 9d7ab2be096ef..fa9440b3059db 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -199,3 +199,19 @@ func.func @type_conversion_failure(%arg0: i32) {
 }
 
 } // end module
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// Tensor truncation is not supported; ensure conversion fails gracefully (no crash).
+func.func @unsupported_tensor_trunci(%arg0: tensor<1xi32>) -> tensor<1xi16> {
+// expected-error @+1 {{failed to legalize operation 'arith.trunci'}}
+%t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+return %t : tensor<1xi16>
+}
+
+} // end module

``````````
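
In the unchanged context at the end of the first hunk, the same-type branch masks the result when type emulation makes the converted source and destination types identical. The C++ snippet below sketches that arithmetic on plain 32-bit integers. It assumes, for example, that i16 is emulated with 32-bit storage, and that the mask keeps the low bits of the original result width. The function name emulatedTrunc is hypothetical, and the diff context is cut off before showing exactly how the pattern builds its mask.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch: when a narrow integer (e.g. i16) is emulated with a
// wider storage type (e.g. i32), a trunc from i32 to i16 becomes i32 -> i32
// after type conversion. Masking keeps only the low `dstBits` bits so stale
// high bits do not leak into later consumers.
uint32_t emulatedTrunc(uint32_t value, unsigned dstBits) {
  assert(dstBits > 0 && dstBits <= 32 && "width must fit the storage type");
  // Build a mask of `dstBits` trailing ones; avoid the undefined 1u << 32.
  uint32_t mask = dstBits == 32 ? UINT32_MAX : ((1u << dstBits) - 1u);
  return value & mask;
}
```

For example, emulatedTrunc(0x12345678, 16) yields 0x5678, the same low bits a genuine i32 -> i16 truncation would keep, which is why a bitwise AND is enough when the types already match.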

</details>


https://github.com/llvm/llvm-project/pull/179009

