[Mlir-commits] [mlir] [mlir][spirv] Fix crash in ArithToSPIRV trunc-i lowering on tensor types (PR #179009)
Samarth Narang
llvmlistbot at llvm.org
Fri Jan 30 18:48:10 PST 2026
https://github.com/snarang181 updated https://github.com/llvm/llvm-project/pull/179009
>From 17c2ab6403de417804a927cc44235639e362ddd5 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 20:55:24 -0500
Subject: [PATCH 1/2] [mlir][spirv] Avoid crash lowering arith.trunci on tensor
types
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 28 +++++++++++++++++--
.../arith-to-spirv-unsupported.mlir | 16 +++++++++++
2 files changed, 42 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index fa8788de6c3d2..cfee6ad6493c6 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -122,6 +122,13 @@ static bool isBoolScalarOrVector(Type type) {
return false;
}
+/// Returns true if `type` is an integer scalar or a vector of integers.
+/// Tensors and other shaped types are rejected. `type` must be non-null.
+static bool isIntScalarOrVectorType(Type type) {
+  auto vecTy = dyn_cast<VectorType>(type);
+  return isa<IntegerType>(vecTy ? vecTy.getElementType() : type);
+}
+
/// Creates a scalar/vector integer constant.
static Value getScalarOrVectorConstInt(Type type, uint64_t value,
OpBuilder &builder, Location loc) {
@@ -853,15 +860,32 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
LogicalResult
matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type srcType = adaptor.getIn().getType();
Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);
+ // Trunc-to-i1 case is handled in a separate pattern.
if (isBoolScalarOrVector(dstType))
return failure();
- if (dstType == srcType) {
+ Type srcType = op.getIn().getType();
+ Type convertedSrcType = getTypeConverter()->convertType(srcType);
+ if (!convertedSrcType)
+ return getTypeConversionFailure(rewriter, op);
+
+ // Ensure we are only lowering scalar/vector integer truncs.
+ // This prevents trying to build SPIR-V ops on tensors.
+ if (!isIntScalarOrVectorType(convertedSrcType) ||
+ !isIntScalarOrVectorType(dstType))
+ return rewriter.notifyMatchFailure(
+ op, "only int scalar or vector type for SPIR-V");
+
+ // Ensure the adaptor operand type matches the converted source type.
+ if (adaptor.getIn().getType() != convertedSrcType)
+ return rewriter.notifyMatchFailure(
+ op, "adaptor operand type does not match converted source type");
+
+ if (dstType == convertedSrcType) {
// We can have the same source and destination type due to type emulation.
// Perform bit masking to make sure we don't pollute downstream consumers
// with unwanted bits. Here we need to use the original result type's
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 9d7ab2be096ef..fa9440b3059db 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -199,3 +199,19 @@ func.func @type_conversion_failure(%arg0: i32) {
}
} // end module
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// Tensor truncation is not supported; ensure conversion fails gracefully (no crash).
+func.func @unsupported_tensor_trunci(%arg0: tensor<1xi32>) -> tensor<1xi16> {
+// expected-error at +1 {{failed to legalize operation 'arith.trunci'}}
+%t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+return %t : tensor<1xi16>
+}
+
+} // end module
>From a5d294ab39ed3f0ac204c8d63e0c373b072dc360 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 21:47:50 -0500
Subject: [PATCH 2/2] Fix formatting
---
.../Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index fa9440b3059db..6e327fd57e78b 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -209,9 +209,9 @@ module attributes {
// Tensor truncation is not supported; ensure conversion fails gracefully (no crash).
func.func @unsupported_tensor_trunci(%arg0: tensor<1xi32>) -> tensor<1xi16> {
-// expected-error at +1 {{failed to legalize operation 'arith.trunci'}}
-%t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
-return %t : tensor<1xi16>
+ // expected-error at +1 {{failed to legalize operation 'arith.trunci'}}
+ %t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+ return %t : tensor<1xi16>
}
} // end module
More information about the Mlir-commits mailing list