[Mlir-commits] [mlir] [mlir][spirv] Fix crash in ArithToSPIRV trunc-i lowering on tensor types (PR #179009)
llvmlistbot at llvm.org
Fri Jan 30 17:57:37 PST 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir-spirv
Author: Samarth Narang (snarang181)
<details>
<summary>Changes</summary>
This PR fixes a crash in --convert-arith-to-spirv when lowering arith.trunci with tensor operand/result types (e.g., tensor<1xi32> -> tensor<1xi16>). Previously, the arith.trunci conversion patterns could proceed even when the operand types were not legal SPIR-V value types, and then attempt to create SPIR-V ops (e.g., spirv.BitwiseAndOp / spirv.SConvertOp) with unconverted tensor operands, which caused a hard abort during op construction.
The fix adds explicit legality checks in the arith.trunci lowering patterns to ensure only supported scalar/vector integer types are lowered. Unsupported tensor cases now fail legalization cleanly (with an error diagnostic) instead of crashing.
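As a rough illustration of the guard described above, here is a minimal standalone sketch. It does not use the MLIR API: `ToyType`, `Kind`, `Outcome`, `isIntScalarOrVector`, and `lowerTrunc` are hypothetical stand-ins invented for this example, and the real check is the `isIntScalarOrVectorType` helper shown in the diff below.

```cpp
#include <cassert>

// Hypothetical toy model of a type; NOT the real mlir::Type API.
enum class Kind { Integer, Float, Vector, Tensor };

struct ToyType {
  Kind kind;        // outer kind of the type
  Kind elementKind; // element kind for Vector/Tensor; same as `kind` for scalars
};

// Models the shape of the patch's helper: the element type must be an
// integer, and the type itself must be a scalar integer or a vector.
static bool isIntScalarOrVector(const ToyType &type) {
  const bool hasIntElement = type.elementKind == Kind::Integer;
  const bool isScalarOrVector =
      type.kind == Kind::Integer || type.kind == Kind::Vector;
  return hasIntElement && isScalarOrVector;
}

enum class Outcome { Lowered, MatchFailure };

// Models the early exit: report a match failure instead of trying to build
// an op on an unsupported type.
static Outcome lowerTrunc(const ToyType &src, const ToyType &dst) {
  if (!isIntScalarOrVector(src) || !isIntScalarOrVector(dst))
    return Outcome::MatchFailure;
  return Outcome::Lowered;
}

// Small self-check covering the scalar, vector, and tensor cases.
static void checkExamples() {
  const ToyType scalarInt{Kind::Integer, Kind::Integer};
  const ToyType vectorInt{Kind::Vector, Kind::Integer};
  const ToyType tensorInt{Kind::Tensor, Kind::Integer};
  assert(lowerTrunc(scalarInt, scalarInt) == Outcome::Lowered);
  assert(lowerTrunc(vectorInt, vectorInt) == Outcome::Lowered);
  assert(lowerTrunc(tensorInt, tensorInt) == Outcome::MatchFailure);
}
```

In this toy model, a tensor of integers is rejected even though its element type is an integer, because the outer type is neither a scalar nor a vector. That is the same distinction the new helper draws by checking both the element type and the container type.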
Fixes https://github.com/llvm/llvm-project/issues/178214
---
Full diff: https://github.com/llvm/llvm-project/pull/179009.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+26-2)
- (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir (+16)
``````````diff
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index fa8788de6c3d2..cfee6ad6493c6 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -122,6 +122,13 @@ static bool isBoolScalarOrVector(Type type) {
return false;
}
+/// Returns true if the given `type` is an integer scalar or vector type.
+static bool isIntScalarOrVectorType(Type type) {
+ auto eltTy = getElementTypeOrSelf(type);
+ return eltTy && isa<IntegerType>(eltTy) &&
+ (isa<IntegerType>(type) || isa<VectorType>(type));
+}
+
/// Creates a scalar/vector integer constant.
static Value getScalarOrVectorConstInt(Type type, uint64_t value,
OpBuilder &builder, Location loc) {
@@ -853,15 +860,32 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
LogicalResult
matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type srcType = adaptor.getIn().getType();
Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);
+ // Trunc-to-i1 case is handled in a separate pattern.
if (isBoolScalarOrVector(dstType))
return failure();
- if (dstType == srcType) {
+ Type srcType = op.getIn().getType();
+ Type convertedSrcType = getTypeConverter()->convertType(srcType);
+ if (!convertedSrcType)
+ return getTypeConversionFailure(rewriter, op);
+
+ // Ensure we are only lowering scalar/vector integer truncs.
+ // This prevents trying to build SPIR-V ops on tensors.
+ if (!isIntScalarOrVectorType(convertedSrcType) ||
+ !isIntScalarOrVectorType(dstType))
+ return rewriter.notifyMatchFailure(
+ op, "only int scalar or vector type for SPIR-V");
+
+ // Ensure the adaptor operand type matches the converted source type.
+ if (adaptor.getIn().getType() != convertedSrcType)
+ return rewriter.notifyMatchFailure(
+ op, "adaptor operand type does not match converted source type");
+
+ if (dstType == convertedSrcType) {
// We can have the same source and destination type due to type emulation.
// Perform bit masking to make sure we don't pollute downstream consumers
// with unwanted bits. Here we need to use the original result type's
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 9d7ab2be096ef..fa9440b3059db 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -199,3 +199,19 @@ func.func @type_conversion_failure(%arg0: i32) {
}
} // end module
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// Tensor truncation is not supported; ensure conversion fails gracefully (no crash).
+func.func @unsupported_tensor_trunci(%arg0: tensor<1xi32>) -> tensor<1xi16> {
+// expected-error@+1 {{failed to legalize operation 'arith.trunci'}}
+%t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+return %t : tensor<1xi16>
+}
+
+} // end module
``````````
</details>
https://github.com/llvm/llvm-project/pull/179009