[Mlir-commits] [mlir] [mlir][spirv] Fix crash in ArithToSPIRV trunc-i lowering on tensor types (PR #179009)
Samarth Narang
llvmlistbot at llvm.org
Mon Feb 2 07:37:44 PST 2026
https://github.com/snarang181 updated https://github.com/llvm/llvm-project/pull/179009
>From 9e06c65061052ab474ea0ea6a333e86cff1bb6ba Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 20:55:24 -0500
Subject: [PATCH 1/5] [mlir][spirv] Avoid crash lowering arith.trunci on tensor
types
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 28 +++++++++++++++++--
.../arith-to-spirv-unsupported.mlir | 16 +++++++++++
2 files changed, 42 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index fa8788de6c3d2..cfee6ad6493c6 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -122,6 +122,13 @@ static bool isBoolScalarOrVector(Type type) {
return false;
}
+/// Returns true if the given `type` is an integer scalar or vector type.
+static bool isIntScalarOrVectorType(Type type) {
+ auto eltTy = getElementTypeOrSelf(type);
+ return eltTy && isa<IntegerType>(eltTy) &&
+ (isa<IntegerType>(type) || isa<VectorType>(type));
+}
+
/// Creates a scalar/vector integer constant.
static Value getScalarOrVectorConstInt(Type type, uint64_t value,
OpBuilder &builder, Location loc) {
@@ -853,15 +860,32 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
LogicalResult
matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- Type srcType = adaptor.getIn().getType();
Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);
+ // Trunc-to-i1 case is handled in a separate pattern.
if (isBoolScalarOrVector(dstType))
return failure();
- if (dstType == srcType) {
+ Type srcType = op.getIn().getType();
+ Type convertedSrcType = getTypeConverter()->convertType(srcType);
+ if (!convertedSrcType)
+ return getTypeConversionFailure(rewriter, op);
+
+ // Ensure we are only lowering scalar/vector integer truncs.
+ // This prevents trying to build SPIR-V ops on tensors.
+ if (!isIntScalarOrVectorType(convertedSrcType) ||
+ !isIntScalarOrVectorType(dstType))
+ return rewriter.notifyMatchFailure(
+ op, "only int scalar or vector type for SPIR-V");
+
+ // Ensure the adaptor operand type matches the converted source type.
+ if (adaptor.getIn().getType() != convertedSrcType)
+ return rewriter.notifyMatchFailure(
+ op, "adaptor operand type does not match converted source type");
+
+ if (dstType == convertedSrcType) {
// We can have the same source and destination type due to type emulation.
// Perform bit masking to make sure we don't pollute downstream consumers
// with unwanted bits. Here we need to use the original result type's
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 9d7ab2be096ef..fa9440b3059db 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -199,3 +199,19 @@ func.func @type_conversion_failure(%arg0: i32) {
}
} // end module
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// Tensor truncation is not supported; ensure conversion fails gracefully (no crash).
+func.func @unsupported_tensor_trunci(%arg0: tensor<1xi32>) -> tensor<1xi16> {
+// expected-error@+1 {{failed to legalize operation 'arith.trunci'}}
+%t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+return %t : tensor<1xi16>
+}
+
+} // end module
>From a1bffeb2c1eeb2da4cb7f497b3588e32357926f0 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Fri, 30 Jan 2026 21:47:50 -0500
Subject: [PATCH 2/5] Fix formatting
---
.../Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index fa9440b3059db..6e327fd57e78b 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -209,9 +209,9 @@ module attributes {
// Tensor truncation is not supported; ensure conversion fails gracefully (no crash).
func.func @unsupported_tensor_trunci(%arg0: tensor<1xi32>) -> tensor<1xi16> {
-// expected-error@+1 {{failed to legalize operation 'arith.trunci'}}
-%t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
-return %t : tensor<1xi16>
+ // expected-error@+1 {{failed to legalize operation 'arith.trunci'}}
+ %t = arith.trunci %arg0 : tensor<1xi32> to tensor<1xi16>
+ return %t : tensor<1xi16>
}
} // end module
>From 97b91853bc746f34edbc33ec7ef279d6aec88f95 Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Sun, 1 Feb 2026 21:36:32 -0500
Subject: [PATCH 3/5] Add same check for other ops too
Signed-off-by: Samarth Narang <snarang at utexas.edu>
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 35 +++++++++++++++++++
1 file changed, 35 insertions(+)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index cfee6ad6493c6..b741b36fcc90a 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -129,6 +129,13 @@ static bool isIntScalarOrVectorType(Type type) {
(isa<IntegerType>(type) || isa<VectorType>(type));
}
+/// Returns true if the given `type` is an integer or float scalar or vector.
+static bool isIntOrFloatScalarOrVectorType(Type type) {
+ auto eltTy = getElementTypeOrSelf(type);
+ return eltTy && eltTy.isIntOrFloat() &&
+ (isa<VectorType>(type) || type == eltTy);
+}
+
/// Creates a scalar/vector integer constant.
static Value getScalarOrVectorConstInt(Type type, uint64_t value,
OpBuilder &builder, Location loc) {
@@ -718,10 +725,19 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
if (isBoolScalarOrVector(srcType))
return failure();
+ // extSI is only meaningful for integer scalar/vector types in SPIR-V.
+ if (!isIntScalarOrVectorType(srcType))
+ return rewriter.notifyMatchFailure(
+ op, "expected integer scalar or vector input type");
+
Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);
+ if (!isIntScalarOrVectorType(dstType))
+ return rewriter.notifyMatchFailure(
+ op, "expected integer scalar or vector result type");
+
if (dstType == srcType) {
// We can have the same source and destination type due to type emulation.
// Perform bit shifting to make sure we have the proper leading set bits.
@@ -794,10 +810,18 @@ struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
if (isBoolScalarOrVector(srcType))
return failure();
+ if (!isIntScalarOrVectorType(srcType))
+ return rewriter.notifyMatchFailure(
+ op, "expected integer scalar or vector input type");
+
Type dstType = getTypeConverter()->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);
+ if (!isIntScalarOrVectorType(dstType))
+ return rewriter.notifyMatchFailure(
+ op, "expected integer scalar or vector result type");
+
if (dstType == srcType) {
// We can have the same source and destination type due to type emulation.
// Perform bit masking to make sure we don't pollute downstream consumers
@@ -839,6 +863,10 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
Location loc = op.getLoc();
auto srcType = adaptor.getOperands().front().getType();
+ if (!isIntScalarOrVectorType(srcType))
+ return rewriter.notifyMatchFailure(
+ op, "expected integer scalar or vector source type");
+
// Check if (x & 1) == 1.
Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
Value maskedSrc = spirv::BitwiseAndOp::create(
@@ -943,6 +971,13 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
return failure();
+ if (!isIntOrFloatScalarOrVectorType(srcType) ||
+ !isIntOrFloatScalarOrVectorType(dstType))
+ return rewriter.notifyMatchFailure(
+ op, llvm::formatv(
+ "expected int/float scalar or vector types, got {0} -> {1}",
+ srcType, dstType));
+
if (dstType == srcType) {
// Due to type conversion, we are seeing the same source and target type.
// Then we can just erase this operation by forwarding its operand.
>From 6a6137fb774fa0800f42dc32eaee702e43954bef Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Mon, 2 Feb 2026 10:36:34 -0500
Subject: [PATCH 4/5] Address review comments
Do not have separate handling for int/fp
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 47 +++++++++++--------
1 file changed, 27 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index b741b36fcc90a..ca5f2a6ada961 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -122,18 +122,15 @@ static bool isBoolScalarOrVector(Type type) {
return false;
}
-/// Returns true if the given `type` is an integer scalar or vector type.
-static bool isIntScalarOrVectorType(Type type) {
- auto eltTy = getElementTypeOrSelf(type);
- return eltTy && isa<IntegerType>(eltTy) &&
- (isa<IntegerType>(type) || isa<VectorType>(type));
-}
+/// Returns true if `type` is a SPIR-V scalar type or a vector of SPIR-V scalar.
+static bool isScalarOrVectorOfScalar(Type type) {
+ if (isa<spirv::ScalarType>(type))
+ return true;
+
+ if (auto vecTy = dyn_cast<VectorType>(type))
+ return isa<spirv::ScalarType>(vecTy.getElementType());
-/// Returns true if the given `type` is an integer or float scalar or vector.
-static bool isIntOrFloatScalarOrVectorType(Type type) {
- auto eltTy = getElementTypeOrSelf(type);
- return eltTy && eltTy.isIntOrFloat() &&
- (isa<VectorType>(type) || type == eltTy);
+ return false;
}
/// Creates a scalar/vector integer constant.
@@ -726,7 +723,7 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
return failure();
// extSI is only meaningful for integer scalar/vector types in SPIR-V.
- if (!isIntScalarOrVectorType(srcType))
+ if (!isScalarOrVectorOfScalar(srcType))
return rewriter.notifyMatchFailure(
op, "expected integer scalar or vector input type");
@@ -734,7 +731,7 @@ struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
if (!dstType)
return getTypeConversionFailure(rewriter, op);
- if (!isIntScalarOrVectorType(dstType))
+ if (!isScalarOrVectorOfScalar(dstType))
return rewriter.notifyMatchFailure(
op, "expected integer scalar or vector result type");
@@ -810,7 +807,7 @@ struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
if (isBoolScalarOrVector(srcType))
return failure();
- if (!isIntScalarOrVectorType(srcType))
+ if (!isScalarOrVectorOfScalar(srcType))
return rewriter.notifyMatchFailure(
op, "expected integer scalar or vector input type");
@@ -818,7 +815,7 @@ struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
if (!dstType)
return getTypeConversionFailure(rewriter, op);
- if (!isIntScalarOrVectorType(dstType))
+ if (!isScalarOrVectorOfScalar(dstType))
return rewriter.notifyMatchFailure(
op, "expected integer scalar or vector result type");
@@ -863,7 +860,7 @@ struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
Location loc = op.getLoc();
auto srcType = adaptor.getOperands().front().getType();
- if (!isIntScalarOrVectorType(srcType))
+ if (!isScalarOrVectorOfScalar(srcType))
return rewriter.notifyMatchFailure(
op, "expected integer scalar or vector source type");
@@ -903,8 +900,8 @@ struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
// Ensure we are only lowering scalar/vector integer truncs.
// This prevents trying to build SPIR-V ops on tensors.
- if (!isIntScalarOrVectorType(convertedSrcType) ||
- !isIntScalarOrVectorType(dstType))
+ if (!isScalarOrVectorOfScalar(convertedSrcType) ||
+ !isScalarOrVectorOfScalar(dstType))
return rewriter.notifyMatchFailure(
op, "only int scalar or vector type for SPIR-V");
@@ -963,6 +960,7 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+
Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
Type dstType = this->getTypeConverter()->convertType(op.getType());
if (!dstType)
@@ -971,8 +969,17 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
return failure();
- if (!isIntOrFloatScalarOrVectorType(srcType) ||
- !isIntOrFloatScalarOrVectorType(dstType))
+ auto isIntOrFloatScalarOrVectorNotI1 = [](Type type) {
+ Type elt = getElementTypeOrSelf(type);
+ if (!elt || (!elt.isIntOrFloat()))
+ return false;
+ if (elt.isInteger(1))
+ return false;
+ return (type == elt) || isa<VectorType>(type);
+ };
+
+ if (!isIntOrFloatScalarOrVectorNotI1(srcType) ||
+ !isIntOrFloatScalarOrVectorNotI1(dstType))
return rewriter.notifyMatchFailure(
op, llvm::formatv(
"expected int/float scalar or vector types, got {0} -> {1}",
>From 12bd36d93f255c14845b1f09f2d1314def5a0f0f Mon Sep 17 00:00:00 2001
From: Samarth Narang <snarang at utexas.edu>
Date: Mon, 2 Feb 2026 10:37:05 -0500
Subject: [PATCH 5/5] add tests
---
.../arith-to-spirv-unsupported.mlir | 30 +++++++++++++++++++
1 file changed, 30 insertions(+)
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 6e327fd57e78b..f249e3fb65bdc 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -215,3 +215,33 @@ func.func @unsupported_tensor_trunci(%arg0: tensor<1xi32>) -> tensor<1xi16> {
}
} // end module
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+ // Tensor extui is not supported; ensure conversion fails (and does not crash).
+ func.func @unsupported_tensor_extui(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+ // expected-error @+1 {{failed to legalize operation 'arith.extui'}}
+ %t = arith.extui %arg0 : tensor<1xi8> to tensor<1xi32>
+ return %t : tensor<1xi32>
+ }
+} // end module
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+ // Tensor extsi is not supported; ensure conversion fails (and does not crash).
+ func.func @unsupported_tensor_extsi(%arg0: tensor<1xi8>) -> tensor<1xi32> {
+ // expected-error @+1 {{failed to legalize operation 'arith.extsi'}}
+ %t = arith.extsi %arg0 : tensor<1xi8> to tensor<1xi32>
+ return %t : tensor<1xi32>
+ }
+} // end module
More information about the Mlir-commits
mailing list