[Mlir-commits] [mlir] [mlir][spirv] Fix crash in ArithToSPIRV trunc-i lowering on tensor types (PR #179009)

Samarth Narang llvmlistbot at llvm.org
Mon Feb 2 07:37:44 PST 2026


https://github.com/snarang181 updated https://github.com/llvm/llvm-project/pull/179009

>From 9e06c65061052ab474ea0ea6a333e86cff1bb6ba Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 20:55:24 -0500
Subject: [PATCH 1/5] [mlir][spirv] Avoid crash lowering arith.trunci on tensor
 types

---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 28 +++++++++++++++++--
 .../arith-to-spirv-unsupported.mlir           | 16 +++++++++++
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index fa8788de6c3d2..cfee6ad6493c6 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -122,6 +122,13 @@ static bool isBoolScalarOrVector(Type type) {
   return false;
 }
 
+/// Returns true if the given `type` is an integer scalar or vector type.
+static bool isIntScalarOrVectorType(Type type) {
+  auto eltTy = getElementTypeOrSelf(type);
+  return eltTy && isa<IntegerType>(eltTy) &&
+         (isa<IntegerType>(type) || isa<VectorType>(type));
+}
+
 /// Creates a scalar/vector integer constant.
 static Value getScalarOrVectorConstInt(Type type, uint64_t value,
                                        OpBuilder &builder, Location loc) {
@@ -853,15 +860,32 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
   LogicalResult
   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type srcType = adaptor.getIn().getType();
     Type dstType = getTypeConverter()->convertType(op.getType());
     if (!dstType)
       return getTypeConversionFailure(rewriter, op);
 
+    // Trunc-to-i1 case is handled in a separate pattern.
     if (isBoolScalarOrVector(dstType))
       return failure();
 
-    if (dstType == srcType) {
+    Type srcType = op.getIn().getType();
+    Type convertedSrcType = getTypeConverter()->convertType(srcType);
+    if (!convertedSrcType)
+      return getTypeConversionFailure(rewriter, op);
+
+    // Ensure we are only lowering scalar/vector integer truncs.
+    // This prevents trying to build SPIR-V ops on tensors.
+    if (!isIntScalarOrVectorType(convertedSrcType) ||
+        !isIntScalarOrVectorType(dstType))
+      return rewriter.notifyMatchFailure(
+          op, "only int scalar or vector type for SPIR-V");
+
+    // Ensure the adaptor operand type matches the converted source type.
+    if (adaptor.getIn().getType() != convertedSrcType)
+      return rewriter.notifyMatchFailure(
+          op, "adaptor operand type does not match converted source type");
+
+    if (dstType == convertedSrcType) {
       // We can have the same source and destination type due to type emulation.
       // Perform bit masking to make sure we don't pollute downstream consumers
       // with unwanted bits. Here we need to use the original result type's
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 9d7ab2be096ef..fa9440b3059db 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -199,3 +199,19 @@ func.func @type_conversion_failure(%arg0: i32) {
 }
 
 } // end module
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// Tensor truncation is not supported; ensure conversion fails gracefully (no crash).
+func.func @unsupported_tensor_trunci(%arg0: tensor<1xi32>) -> tensor<1xi16> {
+// expected-error@+1 {{failed to legalize operation 'arith.trunci'}}
+%t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+return %t : tensor<1xi16>
+}
+
+} // end module

>From a1bffeb2c1eeb2da4cb7f497b3588e32357926f0 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 21:47:50 -0500
Subject: [PATCH 2/5] Fix formatting

---
 .../Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index fa9440b3059db..6e327fd57e78b 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -209,9 +209,9 @@ module attributes {
 
 // Tensor truncation is not supported; ensure conversion fails gracefully (no crash).
 func.func @unsupported_tensor_trunci(%arg0: tensor<1xi32>) -> tensor<1xi16> {
-// expected-error@+1 {{failed to legalize operation 'arith.trunci'}}
-%t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
-return %t : tensor<1xi16>
+  // expected-error@+1 {{failed to legalize operation 'arith.trunci'}}
+  %t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+  return %t : tensor<1xi16>
 }
 
 } // end module

>From 97b91853bc746f34edbc33ec7ef279d6aec88f95 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Sun, 1 Feb 2026 21:36:32 -0500
Subject: [PATCH 3/5] Add the same check for other ops too

Signed-off-by: Samarth Narang <snarang at utexas.edu>
---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 35 +++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index cfee6ad6493c6..b741b36fcc90a 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -129,6 +129,13 @@ static bool isIntScalarOrVectorType(Type type) {
          (isa<IntegerType>(type) || isa<VectorType>(type));
 }
 
+/// Returns true if the given `type` is an integer or float scalar or vector.
+static bool isIntOrFloatScalarOrVectorType(Type type) {
+  auto eltTy = getElementTypeOrSelf(type);
+  return eltTy && eltTy.isIntOrFloat() &&
+         (isa<VectorType>(type) || type == eltTy);
+}
+
 /// Creates a scalar/vector integer constant.
 static Value getScalarOrVectorConstInt(Type type, uint64_t value,
                                        OpBuilder &builder, Location loc) {
@@ -718,10 +725,19 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
     if (isBoolScalarOrVector(srcType))
       return failure();
 
+    // extSI is only meaningful for integer scalar/vector types in SPIR-V.
+    if (!isIntScalarOrVectorType(srcType))
+      return rewriter.notifyMatchFailure(
+          op, "expected integer scalar or vector input type");
+
     Type dstType = getTypeConverter()->convertType(op.getType());
     if (!dstType)
       return getTypeConversionFailure(rewriter, op);
 
+    if (!isIntScalarOrVectorType(dstType))
+      return rewriter.notifyMatchFailure(
+          op, "expected integer scalar or vector result type");
+
     if (dstType == srcType) {
       // We can have the same source and destination type due to type emulation.
       // Perform bit shifting to make sure we have the proper leading set bits.
@@ -794,10 +810,18 @@ struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
     if (isBoolScalarOrVector(srcType))
       return failure();
 
+    if (!isIntScalarOrVectorType(srcType))
+      return rewriter.notifyMatchFailure(
+          op, "expected integer scalar or vector input type");
+
     Type dstType = getTypeConverter()->convertType(op.getType());
     if (!dstType)
       return getTypeConversionFailure(rewriter, op);
 
+    if (!isIntScalarOrVectorType(dstType))
+      return rewriter.notifyMatchFailure(
+          op, "expected integer scalar or vector result type");
+
     if (dstType == srcType) {
       // We can have the same source and destination type due to type emulation.
       // Perform bit masking to make sure we don't pollute downstream consumers
@@ -839,6 +863,10 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
 
     Location loc = op.getLoc();
     auto srcType = adaptor.getOperands().front().getType();
+    if (!isIntScalarOrVectorType(srcType))
+      return rewriter.notifyMatchFailure(
+          op, "expected integer scalar or vector source type");
+
     // Check if (x & 1) == 1.
     Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
     Value maskedSrc = spirv::BitwiseAndOp::create(
@@ -943,6 +971,13 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
     if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
       return failure();
 
+    if (!isIntOrFloatScalarOrVectorType(srcType) ||
+        !isIntOrFloatScalarOrVectorType(dstType))
+      return rewriter.notifyMatchFailure(
+          op, llvm::formatv(
+                  "expected int/float scalar or vector types, got {0} -> {1}",
+                  srcType, dstType));
+
     if (dstType == srcType) {
       // Due to type conversion, we are seeing the same source and target type.
       // Then we can just erase this operation by forwarding its operand.

>From 6a6137fb774fa0800f42dc32eaee702e43954bef Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Mon, 2 Feb 2026 10:36:34 -0500
Subject: [PATCH 4/5] Address review comments

Do not use separate type checks for integer and floating-point types.
---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 47 +++++++++++--------
 1 file changed, 27 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index b741b36fcc90a..ca5f2a6ada961 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -122,18 +122,15 @@ static bool isBoolScalarOrVector(Type type) {
   return false;
 }
 
-/// Returns true if the given `type` is an integer scalar or vector type.
-static bool isIntScalarOrVectorType(Type type) {
-  auto eltTy = getElementTypeOrSelf(type);
-  return eltTy && isa<IntegerType>(eltTy) &&
-         (isa<IntegerType>(type) || isa<VectorType>(type));
-}
+/// Returns true if `type` is a SPIR-V scalar type or a vector of SPIR-V scalar.
+static bool isScalarOrVectorOfScalar(Type type) {
+  if (isa<spirv::ScalarType>(type))
+    return true;
+
+  if (auto vecTy = dyn_cast<VectorType>(type))
+    return isa<spirv::ScalarType>(vecTy.getElementType());
 
-/// Returns true if the given `type` is an integer or float scalar or vector.
-static bool isIntOrFloatScalarOrVectorType(Type type) {
-  auto eltTy = getElementTypeOrSelf(type);
-  return eltTy && eltTy.isIntOrFloat() &&
-         (isa<VectorType>(type) || type == eltTy);
+  return false;
 }
 
 /// Creates a scalar/vector integer constant.
@@ -726,7 +723,7 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
       return failure();
 
     // extSI is only meaningful for integer scalar/vector types in SPIR-V.
-    if (!isIntScalarOrVectorType(srcType))
+    if (!isScalarOrVectorOfScalar(srcType))
       return rewriter.notifyMatchFailure(
           op, "expected integer scalar or vector input type");
 
@@ -734,7 +731,7 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
     if (!dstType)
       return getTypeConversionFailure(rewriter, op);
 
-    if (!isIntScalarOrVectorType(dstType))
+    if (!isScalarOrVectorOfScalar(dstType))
       return rewriter.notifyMatchFailure(
           op, "expected integer scalar or vector result type");
 
@@ -810,7 +807,7 @@ struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
     if (isBoolScalarOrVector(srcType))
       return failure();
 
-    if (!isIntScalarOrVectorType(srcType))
+    if (!isScalarOrVectorOfScalar(srcType))
       return rewriter.notifyMatchFailure(
           op, "expected integer scalar or vector input type");
 
@@ -818,7 +815,7 @@ struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
     if (!dstType)
       return getTypeConversionFailure(rewriter, op);
 
-    if (!isIntScalarOrVectorType(dstType))
+    if (!isScalarOrVectorOfScalar(dstType))
       return rewriter.notifyMatchFailure(
           op, "expected integer scalar or vector result type");
 
@@ -863,7 +860,7 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
 
     Location loc = op.getLoc();
     auto srcType = adaptor.getOperands().front().getType();
-    if (!isIntScalarOrVectorType(srcType))
+    if (!isScalarOrVectorOfScalar(srcType))
       return rewriter.notifyMatchFailure(
           op, "expected integer scalar or vector source type");
 
@@ -903,8 +900,8 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
 
     // Ensure we are only lowering scalar/vector integer truncs.
     // This prevents trying to build SPIR-V ops on tensors.
-    if (!isIntScalarOrVectorType(convertedSrcType) ||
-        !isIntScalarOrVectorType(dstType))
+    if (!isScalarOrVectorOfScalar(convertedSrcType) ||
+        !isScalarOrVectorOfScalar(dstType))
       return rewriter.notifyMatchFailure(
           op, "only int scalar or vector type for SPIR-V");
 
@@ -963,6 +960,7 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
   LogicalResult
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
     Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
     Type dstType = this->getTypeConverter()->convertType(op.getType());
     if (!dstType)
@@ -971,8 +969,17 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
     if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
       return failure();
 
-    if (!isIntOrFloatScalarOrVectorType(srcType) ||
-        !isIntOrFloatScalarOrVectorType(dstType))
+    auto isIntOrFloatScalarOrVectorNotI1 = [](Type type) {
+      Type elt = getElementTypeOrSelf(type);
+      if (!elt || (!elt.isIntOrFloat()))
+        return false;
+      if (elt.isInteger(1))
+        return false;
+      return (type == elt) || isa<VectorType>(type);
+    };
+
+    if (!isIntOrFloatScalarOrVectorNotI1(srcType) ||
+        !isIntOrFloatScalarOrVectorNotI1(dstType))
       return rewriter.notifyMatchFailure(
           op, llvm::formatv(
                   "expected int/float scalar or vector types, got {0} -> {1}",

>From 12bd36d93f255c14845b1f09f2d1314def5a0f0f Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Mon, 2 Feb 2026 10:37:05 -0500
Subject: [PATCH 5/5] Add tests for tensor extui/extsi lowering

---
 .../arith-to-spirv-unsupported.mlir           | 30 +++++++++++++++++++
 1 file changed, 30 insertions(+)

diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 6e327fd57e78b..f249e3fb65bdc 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -215,3 +215,33 @@ func.func @unsupported_tensor_trunci(%arg0: tensor<1xi32>) -> tensor<1xi16> {
 }
 
 } // end module
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+  // Tensor extui is not supported; ensure conversion fails (and does not crash).
+  func.func @unsupported_tensor_extui(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+    // expected-error @+1 {{failed to legalize operation 'arith.extui'}}
+    %t = arith.extui %arg0 : tensor<1xi8> to tensor<1xi32>
+    return %t : tensor<1xi32>
+  }
+} // end module
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+  // Tensor extsi is not supported; ensure conversion fails (and does not crash).
+  func.func @unsupported_tensor_extsi(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+    // expected-error @+1 {{failed to legalize operation 'arith.extsi'}}
+    %t = arith.extsi %arg0 : tensor<1xi8> to tensor<1xi32>
+    return %t : tensor<1xi32>
+  }
+} // end module



More information about the Mlir-commits mailing list