[Mlir-commits] [mlir] [mlir][spirv] Fix crash in ArithToSPIRV trunc-i lowering on tensor types (PR #179009)

Samarth Narang llvmlistbot at llvm.org
Fri Jan 30 18:48:10 PST 2026


https://github.com/snarang181 updated https://github.com/llvm/llvm-project/pull/179009

>From 17c2ab6403de417804a927cc44235639e362ddd5 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 20:55:24 -0500
Subject: [PATCH 1/2] [mlir][spirv] Avoid crash lowering arith.trunci on tensor
 types

---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 28 +++++++++++++++++--
 .../arith-to-spirv-unsupported.mlir           | 16 +++++++++++
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index fa8788de6c3d2..cfee6ad6493c6 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -122,6 +122,13 @@ static bool isBoolScalarOrVector(Type type) {
   return false;
 }
 
+/// Returns true if `type` is an integer, or a vector with integer elements.
+static bool isIntScalarOrVectorType(Type type) {
+  auto eltTy = getElementTypeOrSelf(type);
+  return eltTy && isa<IntegerType>(eltTy) &&
+         (isa<IntegerType>(type) || isa<VectorType>(type));
+}
+
 /// Creates a scalar/vector integer constant.
 static Value getScalarOrVectorConstInt(Type type, uint64_t value,
                                        OpBuilder &builder, Location loc) {
@@ -853,15 +860,32 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
   LogicalResult
   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type srcType = adaptor.getIn().getType();
     Type dstType = getTypeConverter()->convertType(op.getType());
     if (!dstType)
       return getTypeConversionFailure(rewriter, op);
 
+    // Trunc-to-i1 case is handled in a separate pattern.
     if (isBoolScalarOrVector(dstType))
       return failure();
 
-    if (dstType == srcType) {
+    Type srcType = op.getIn().getType();
+    Type convertedSrcType = getTypeConverter()->convertType(srcType);
+    if (!convertedSrcType)
+      return getTypeConversionFailure(rewriter, op);
+
+    // Ensure we are only lowering scalar/vector integer truncs.
+    // This prevents trying to build SPIR-V ops on tensors.
+    if (!isIntScalarOrVectorType(convertedSrcType) ||
+        !isIntScalarOrVectorType(dstType))
+      return rewriter.notifyMatchFailure(
+          op, "only int scalar or vector type for SPIR-V");
+
+    // Ensure the adaptor operand type matches the converted source type.
+    if (adaptor.getIn().getType() != convertedSrcType)
+      return rewriter.notifyMatchFailure(
+          op, "adaptor operand type does not match converted source type");
+
+    if (dstType == convertedSrcType) {
       // We can have the same source and destination type due to type emulation.
       // Perform bit masking to make sure we don't pollute downstream consumers
       // with unwanted bits. Here we need to use the original result type's
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 9d7ab2be096ef..fa9440b3059db 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -199,3 +199,19 @@ func.func @type_conversion_failure(%arg0: i32) {
 }
 
 } // end module
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// Tensor truncation is not supported; ensure conversion fails gracefully (no crash).
+func.func @unsupported_tensor_trunci(%arg0: tensor<1xi32>) -> tensor<1xi16> {
+// expected-error @+1 {{failed to legalize operation 'arith.trunci'}}
+%t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+return %t : tensor<1xi16>
+}
+
+} // end module

>From a5d294ab39ed3f0ac204c8d63e0c373b072dc360 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 21:47:50 -0500
Subject: [PATCH 2/2] Fix formatting

---
 .../Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index fa9440b3059db..6e327fd57e78b 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -209,9 +209,9 @@ module attributes {
 
 // Tensor truncation is not supported; ensure conversion fails gracefully (no crash).
 func.func @unsupported_tensor_trunci(%arg0: tensor<1xi32>) -> tensor<1xi16> {
-// expected-error @+1 {{failed to legalize operation 'arith.trunci'}}
-%t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
-return %t : tensor<1xi16>
+  // expected-error @+1 {{failed to legalize operation 'arith.trunci'}}
+  %t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+  return %t : tensor<1xi16>
 }
 
 } // end module



More information about the Mlir-commits mailing list