[Mlir-commits] [mlir] [mlir][ArithToSPIRV] Fix invalid SPIRV and crashes when lowering integer ops on i1 (PR #189239)
Mehdi Amini
llvmlistbot at llvm.org
Mon Mar 30 16:46:11 PDT 2026
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/189239
>From 55c28b68c170fe8154712bbb29d5ffe5752918b4 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 15:54:39 -0700
Subject: [PATCH 1/2] [mlir][ArithToSPIRV] Fix crash converting
arith.addi/subi/muli on i1 types
arith.addi, arith.subi, and arith.muli on i1 (boolean) operands were
incorrectly lowered to spirv.IAdd, spirv.ISub, and spirv.IMul, which
require 8/16/32/64-bit integer types and reject i1, causing SPIRV
verification to fail.
Fix by adding a boolean-specialized conversion pattern template
(BoolIOpPattern<ArithOp, SPIRVOp>), instantiated once each for addi, subi,
and muli and modeled after the existing XOrIOpBooleanPattern:
- addi on i1: lowers to spirv.LogicalNotEqual (addition mod 2 = XOR)
- subi on i1: lowers to spirv.LogicalNotEqual (subtraction mod 2 = XOR)
- muli on i1: lowers to spirv.LogicalAnd (multiplication mod 2 = AND)
ElementwiseArithOpPattern is updated to reject boolean types so the
specialized patterns take priority.
Fixes #61162
Assisted-by: Claude Code
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 29 +++++++++++++++++++
.../ArithToSPIRV/arith-to-spirv.mlir | 22 ++++++++++++++
2 files changed, 51 insertions(+)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index d6b1e9552fbc5..265e3d9fc0bc8 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -182,6 +182,11 @@ struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(adaptor.getOperands().size() <= 3);
+ // Reject boolean types to allow specialized boolean patterns to handle
+ // them (e.g., addi/subi on i1 should use LogicalNotEqual, not IAdd/ISub).
+ if (!adaptor.getOperands().empty() &&
+ isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+ return failure();
auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
Type dstType = converter->convertType(op.getType());
if (!dstType) {
@@ -572,6 +577,27 @@ struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
}
};
+/// Converts an arith integer op to the given SPIR-V boolean op if the type is
+/// i1 or vector of i1.
+template <typename ArithOp, typename SPIRVOp>
+struct BoolIOpPattern final : public OpConversionPattern<ArithOp> {
+ using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+ return failure();
+
+ Type dstType = this->getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
+ rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// UIToFPOp
//===----------------------------------------------------------------------===//
@@ -1410,8 +1436,11 @@ void mlir::arith::populateArithToSPIRVPatterns(
patterns.add<
ConstantCompositeOpPattern,
ConstantScalarOpPattern,
+ BoolIOpPattern<arith::AddIOp, spirv::LogicalNotEqualOp>,
ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
+ BoolIOpPattern<arith::SubIOp, spirv::LogicalNotEqualOp>,
ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
+ BoolIOpPattern<arith::MulIOp, spirv::LogicalAndOp>,
ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 9c726b8643a46..cf8579ad882b8 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -272,6 +272,28 @@ func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
return
}
+// CHECK-LABEL: @bool_arith_scalar
+func.func @bool_arith_scalar(%arg0 : i1, %arg1 : i1) {
+ // CHECK: spirv.LogicalNotEqual
+ %0 = arith.addi %arg0, %arg1 : i1
+ // CHECK: spirv.LogicalNotEqual
+ %1 = arith.subi %arg0, %arg1 : i1
+ // CHECK: spirv.LogicalAnd
+ %2 = arith.muli %arg0, %arg1 : i1
+ return
+}
+
+// CHECK-LABEL: @bool_arith_vector
+func.func @bool_arith_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+ // CHECK: spirv.LogicalNotEqual
+ %0 = arith.addi %arg0, %arg1 : vector<4xi1>
+ // CHECK: spirv.LogicalNotEqual
+ %1 = arith.subi %arg0, %arg1 : vector<4xi1>
+ // CHECK: spirv.LogicalAnd
+ %2 = arith.muli %arg0, %arg1 : vector<4xi1>
+ return
+}
+
// CHECK-LABEL: @shift_scalar
func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
// CHECK: spirv.ShiftLeftLogical
>From d9dc180a8ec856d924455600a69ebf38f0ea8c67 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 30 Mar 2026 04:38:21 -0700
Subject: [PATCH 2/2] [mlir][ArithToSPIRV] Fix i1 lowerings for div, rem,
shift, and min/max ops
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Extend the i1 boolean pattern fixes to cover all integer ops that were
previously generating invalid SPIRV (e.g. spirv.UDiv on i1) or crashing:
- divui/divsi(a,b) on i1 → spirv.LogicalAnd (b=0 is division by zero, i.e. UB, so only b=1 must be handled; there a & b = a)
- remui/remsi(a,b) on i1 → spirv.LogicalAnd(a, spirv.LogicalNot(b))
- shli/shrui(a,b) on i1 → spirv.LogicalAnd(a, spirv.LogicalNot(b))
- shrsi(a,b) on i1 → identity (arithmetic right shift of i1 = input)
- maxui/minsi(a,b) on i1 → spirv.LogicalOr
- maxsi/minui(a,b) on i1 → spirv.LogicalAnd
All new boolean patterns are given benefit=2 so they take priority over
the generic spirv::ElementwiseOpPattern (benefit=1), which has no i1 guard
and would otherwise generate invalid SPIRV for 1-bit integer types.
Also add test coverage for all of the above in arith-to-spirv.mlir.
Assisted-by: Claude Code
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 92 ++++++++++++++++++-
.../ArithToSPIRV/arith-to-spirv.mlir | 57 ++++++++++++
2 files changed, 144 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 265e3d9fc0bc8..660f450ff9f27 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -578,10 +578,26 @@ struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
};
/// Converts an arith integer op to the given SPIR-V boolean op if the type is
-/// i1 or vector of i1.
+/// i1 or vector of i1. Each mapping follows from the boolean truth table of
+/// the operation:
+/// addi(a, b) = a ^ b (add mod 2 = XOR = LogicalNotEqual)
+/// subi(a, b) = a ^ b (sub mod 2 = XOR = LogicalNotEqual)
+/// muli(a, b) = a & b (1*1=1, else 0 = LogicalAnd)
+/// divui(a, b) = a & b (a/1=a, a/0=UB; truth table matches AND)
+/// divsi(a, b) = a & b (same as divui on i1)
+/// maxsi(a, b) = a & b (signed i1: 1 represents -1, so max is 0 unless both
+/// are 1)
+/// maxui(a, b) = a | b (unsigned max on i1: 1 when either operand is 1)
+/// minsi(a, b) = a | b (signed i1: -1 < 0, so min is 1 when either operand
+/// is 1)
+/// minui(a, b) = a & b (unsigned min on i1: 1 only when both operands are
+/// 1)
template <typename ArithOp, typename SPIRVOp>
struct BoolIOpPattern final : public OpConversionPattern<ArithOp> {
- using OpConversionPattern<ArithOp>::OpConversionPattern;
+ BoolIOpPattern(const TypeConverter &converter, MLIRContext *context)
+ // benefit=2: takes priority over the generic ElementwiseArithOpPattern
+ // (benefit=1) when the operand type is i1.
+ : OpConversionPattern<ArithOp>(converter, context, /*benefit=*/2) {}
LogicalResult
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
@@ -598,6 +614,61 @@ struct BoolIOpPattern final : public OpConversionPattern<ArithOp> {
}
};
+/// Converts an arith binary op on i1 to spirv.LogicalAnd(lhs,
+/// spirv.LogicalNot(rhs)). This covers shift-left, shift-right-unsigned, and
+/// unsigned remainder on i1:
+/// shli(a, b) = a & ~b (shift left clears the bit when b=1)
+/// shrui(a, b) = a & ~b (shift right unsigned clears the bit when b=1)
+/// remui(a, b) = a & ~b (only defined when b=1; a%1=0, and ~b=~1=0, so AND
+/// gives 0)
+/// remsi(a, b) = a & ~b (only defined when b=1; a%1=0, and ~b=~1=0, so AND
+/// gives 0)
+template <typename ArithOp>
+struct BoolIOpAndNotPattern final : public OpConversionPattern<ArithOp> {
+ BoolIOpAndNotPattern(const TypeConverter &converter, MLIRContext *context)
+ // benefit=2: takes priority over the generic ElementwiseArithOpPattern
+ // (benefit=1) when the operand type is i1.
+ : OpConversionPattern<ArithOp>(converter, context, /*benefit=*/2) {}
+
+ LogicalResult
+ matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+ return failure();
+
+ Type dstType = this->getTypeConverter()->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
+ Location loc = op.getLoc();
+ Value notRhs = spirv::LogicalNotOp::create(rewriter, loc, dstType,
+ adaptor.getOperands()[1]);
+ rewriter.replaceOpWithNewOp<spirv::LogicalAndOp>(
+ op, dstType, adaptor.getOperands()[0], notRhs);
+ return success();
+ }
+};
+
+/// Converts arith.shrsi on i1 to identity: arithmetic right shift of a 1-bit
+/// signed value always yields the original value (0 >> n = 0, -1 >> n = -1).
+struct ShRSIBoolPattern final : public OpConversionPattern<arith::ShRSIOp> {
+ ShRSIBoolPattern(const TypeConverter &converter, MLIRContext *context)
+ // benefit=2: takes priority over the generic spirv::ElementwiseOpPattern
+ // (benefit=1) when the operand type is i1.
+ : OpConversionPattern<arith::ShRSIOp>(converter, context,
+ /*benefit=*/2) {}
+
+ LogicalResult
+ matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+ return failure();
+
+ rewriter.replaceOp(op, adaptor.getOperands().front());
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// UIToFPOp
//===----------------------------------------------------------------------===//
@@ -1436,21 +1507,28 @@ void mlir::arith::populateArithToSPIRVPatterns(
patterns.add<
ConstantCompositeOpPattern,
ConstantScalarOpPattern,
- BoolIOpPattern<arith::AddIOp, spirv::LogicalNotEqualOp>,
+ BoolIOpPattern<arith::AddIOp, spirv::LogicalNotEqualOp>, // add mod 2 = XOR = not-equal
ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
- BoolIOpPattern<arith::SubIOp, spirv::LogicalNotEqualOp>,
+ BoolIOpPattern<arith::SubIOp, spirv::LogicalNotEqualOp>, // sub mod 2 = XOR = not-equal
ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
- BoolIOpPattern<arith::MulIOp, spirv::LogicalAndOp>,
+ BoolIOpPattern<arith::MulIOp, spirv::LogicalAndOp>, // 1*1=1, else 0 = AND
ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
+ BoolIOpPattern<arith::DivUIOp, spirv::LogicalAndOp>, // a/1=a, a/0=UB; truth table = AND
spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
+ BoolIOpPattern<arith::DivSIOp, spirv::LogicalAndOp>, // same as divui on i1
spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
+ BoolIOpAndNotPattern<arith::RemUIOp>, // remui(a,b) = a & ~b (see pattern comment)
spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
+ BoolIOpAndNotPattern<arith::RemSIOp>, // remsi(a,b) = a & ~b (see pattern comment)
RemSIOpGLPattern, RemSIOpCLPattern,
BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
XOrIOpLogicalPattern, XOrIOpBooleanPattern,
+ BoolIOpAndNotPattern<arith::ShLIOp>, // shli(a,b) = a & ~b (see pattern comment)
ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
+ BoolIOpAndNotPattern<arith::ShRUIOp>, // shrui(a,b) = a & ~b (see pattern comment)
spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
+ ShRSIBoolPattern,
spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
@@ -1483,6 +1561,10 @@ void mlir::arith::populateArithToSPIRVPatterns(
MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
+ BoolIOpPattern<arith::MaxSIOp, spirv::LogicalAndOp>, // signed i1: 1=-1, so max=0 unless both are 1
+ BoolIOpPattern<arith::MaxUIOp, spirv::LogicalOrOp>, // unsigned max on i1: 1 when either is 1
+ BoolIOpPattern<arith::MinSIOp, spirv::LogicalOrOp>, // signed i1: -1<0, so min=1 when either is 1
+ BoolIOpPattern<arith::MinUIOp, spirv::LogicalAndOp>, // unsigned min on i1: 1 only when both are 1
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index cf8579ad882b8..31b70177a0d19 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -320,6 +320,63 @@ func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
// -----
+// Test i1 lowerings for shift, div, rem, and min/max ops.
+
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @bool_shift_div_rem_scalar
+func.func @bool_shift_div_rem_scalar(%arg0 : i1, %arg1 : i1) {
+ // shli(a,b) = a & ~b
+ // CHECK: %[[NOTB:.+]] = spirv.LogicalNot %arg1
+ // CHECK: spirv.LogicalAnd %arg0, %[[NOTB]]
+ %0 = arith.shli %arg0, %arg1 : i1
+ // shrui(a,b) = a & ~b
+ // CHECK: %[[NOTB:.+]] = spirv.LogicalNot %arg1
+ // CHECK: spirv.LogicalAnd %arg0, %[[NOTB]]
+ %1 = arith.shrui %arg0, %arg1 : i1
+ // shrsi(a,b) = a (arithmetic right shift of i1 is identity)
+ // CHECK-NOT: spirv.ShiftRightArithmetic
+ %2 = arith.shrsi %arg0, %arg1 : i1
+ // divui(a,b) = a & b (only valid for b=1, same as muli)
+ // CHECK: spirv.LogicalAnd %arg0, %arg1
+ %3 = arith.divui %arg0, %arg1 : i1
+ // divsi(a,b) = a & b (only non-UB/non-overflow case: 0/-1 = 0)
+ // CHECK: spirv.LogicalAnd %arg0, %arg1
+ %4 = arith.divsi %arg0, %arg1 : i1
+ // remui(a,b) = a & ~b (a % 1 = 0 for valid b=1)
+ // CHECK: %[[NOTB:.+]] = spirv.LogicalNot %arg1
+ // CHECK: spirv.LogicalAnd %arg0, %[[NOTB]]
+ %5 = arith.remui %arg0, %arg1 : i1
+ // remsi(a,b) = a & ~b
+ // CHECK: %[[NOTB:.+]] = spirv.LogicalNot %arg1
+ // CHECK: spirv.LogicalAnd %arg0, %[[NOTB]]
+ %6 = arith.remsi %arg0, %arg1 : i1
+ return
+}
+
+// CHECK-LABEL: @bool_minmax_scalar
+func.func @bool_minmax_scalar(%arg0 : i1, %arg1 : i1) {
+ // maxui(a,b) = a | b (unsigned max of two booleans is OR)
+ // CHECK: spirv.LogicalOr %arg0, %arg1
+ %0 = arith.maxui %arg0, %arg1 : i1
+ // maxsi(a,b) = a & b (signed max: max(-1,0)=0, so max(true,false)=false → AND)
+ // CHECK: spirv.LogicalAnd %arg0, %arg1
+ %1 = arith.maxsi %arg0, %arg1 : i1
+ // minui(a,b) = a & b (unsigned min of two booleans is AND)
+ // CHECK: spirv.LogicalAnd %arg0, %arg1
+ %2 = arith.minui %arg0, %arg1 : i1
+ // minsi(a,b) = a | b (signed min: min(-1,0)=-1, so min(true,false)=true → OR)
+ // CHECK: spirv.LogicalOr %arg0, %arg1
+ %3 = arith.minsi %arg0, %arg1 : i1
+ return
+}
+
+} // end module
+
+// -----
+
//===----------------------------------------------------------------------===//
// arith.cmpf
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list