[Mlir-commits] [mlir] [mlir][ArithToSPIRV] Fix crash converting arith.addi/subi/muli on i1 types (PR #189239)
Mehdi Amini
llvmlistbot at llvm.org
Sun Mar 29 06:13:10 PDT 2026
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/189239
arith.addi, arith.subi, and arith.muli on i1 (boolean) operands were incorrectly lowered to spirv.IAdd, spirv.ISub, and spirv.IMul, which require 8/16/32/64-bit integer types and reject i1, causing SPIR-V verification to fail.
Fix by adding three new boolean-specialized conversion patterns (AddIOpBooleanPattern, SubIOpBooleanPattern, MulIOpBooleanPattern) modeled after the existing XOrIOpBooleanPattern:
- addi on i1: lowers to spirv.LogicalNotEqual (addition mod 2 = XOR)
- subi on i1: lowers to spirv.LogicalNotEqual (subtraction mod 2 = XOR)
- muli on i1: lowers to spirv.LogicalAnd (multiplication mod 2 = AND)
ElementwiseArithOpPattern is updated to reject boolean operand types, so i1 operations can only be matched by the new specialized patterns.
Fixes #61162
Assisted-by: Claude Code
>From d255518ca1a37ccc3ce62d7db7c60b502ab721b1 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 15:54:39 -0700
Subject: [PATCH] [mlir][ArithToSPIRV] Fix crash converting
arith.addi/subi/muli on i1 types
arith.addi, arith.subi, and arith.muli on i1 (boolean) operands were
incorrectly lowered to spirv.IAdd, spirv.ISub, and spirv.IMul, which
require 8/16/32/64-bit integer types and reject i1, causing SPIR-V
verification to fail.
Fix by adding three new boolean-specialized conversion patterns
(AddIOpBooleanPattern, SubIOpBooleanPattern, MulIOpBooleanPattern)
modeled after the existing XOrIOpBooleanPattern:
- addi on i1: lowers to spirv.LogicalNotEqual (addition mod 2 = XOR)
- subi on i1: lowers to spirv.LogicalNotEqual (subtraction mod 2 = XOR)
- muli on i1: lowers to spirv.LogicalAnd (multiplication mod 2 = AND)
ElementwiseArithOpPattern is updated to reject boolean operand types, so
i1 operations can only be matched by the new specialized patterns.
Fixes #61162
Assisted-by: Claude Code
---
.../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 80 ++++++++++++++++++-
.../ArithToSPIRV/arith-to-spirv.mlir | 22 +++++
2 files changed, 99 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index d6b1e9552fbc5..59501697a0f20 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -182,6 +182,11 @@ struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(adaptor.getOperands().size() <= 3);
+ // Reject boolean types to allow specialized boolean patterns to handle
+ // them (e.g., addi/subi on i1 should use LogicalNotEqual, not IAdd/ISub).
+ if (!adaptor.getOperands().empty() &&
+ isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+ return failure();
auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
Type dstType = converter->convertType(op.getType());
if (!dstType) {
@@ -572,6 +577,75 @@ struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
}
};
+/// Converts a binary integer arith op to a single SPIR-V logical op when its
+/// operands are i1 or a vector of i1. The SPIR-V integer arithmetic ops
+/// (spirv.IAdd / spirv.ISub / spirv.IMul) reject boolean types, but on i1 the
+/// wrap-around (mod 2) arithmetic collapses to one logical operation:
+///   - addition mod 2       == XOR == spirv.LogicalNotEqual
+///   - subtraction mod 2    == XOR == spirv.LogicalNotEqual
+///   - multiplication mod 2 == AND == spirv.LogicalAnd
+/// Non-boolean operands are rejected (returns failure()) so that the generic
+/// integer lowering handles them instead.
+///
+/// Op             - the arith op to match (two operands).
+/// SPIRVLogicalOp - the SPIR-V logical op that implements Op on booleans.
+template <typename Op, typename SPIRVLogicalOp>
+struct BooleanArithOpPattern final : public OpConversionPattern<Op> {
+  using OpConversionPattern<Op>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() == 2);
+
+    // Only i1 / vector-of-i1 operands are handled by this pattern.
+    if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+      return failure();
+
+    Type dstType = this->getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    rewriter.replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
+                                                adaptor.getOperands());
+    return success();
+  }
+};
+
+/// Lowers arith.addi on i1 (or vector of i1) to spirv.LogicalNotEqual.
+using AddIOpBooleanPattern =
+    BooleanArithOpPattern<arith::AddIOp, spirv::LogicalNotEqualOp>;
+
+/// Lowers arith.subi on i1 (or vector of i1) to spirv.LogicalNotEqual.
+using SubIOpBooleanPattern =
+    BooleanArithOpPattern<arith::SubIOp, spirv::LogicalNotEqualOp>;
+
+/// Lowers arith.muli on i1 (or vector of i1) to spirv.LogicalAnd.
+using MulIOpBooleanPattern =
+    BooleanArithOpPattern<arith::MulIOp, spirv::LogicalAndOp>;
+
//===----------------------------------------------------------------------===//
// UIToFPOp
//===----------------------------------------------------------------------===//
@@ -1410,9 +1484,9 @@ void mlir::arith::populateArithToSPIRVPatterns(
patterns.add<
ConstantCompositeOpPattern,
ConstantScalarOpPattern,
- ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
- ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
- ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
+ AddIOpBooleanPattern, ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
+ SubIOpBooleanPattern, ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
+ MulIOpBooleanPattern, ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 9c726b8643a46..cf8579ad882b8 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -272,6 +272,28 @@ func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
return
}
+// CHECK-LABEL: @bool_arith_scalar
+// CHECK-SAME: (%[[LHS:.+]]: i1, %[[RHS:.+]]: i1)
+func.func @bool_arith_scalar(%arg0 : i1, %arg1 : i1) {
+  // SPIR-V integer arithmetic rejects i1; add/sub mod 2 is XOR and mul mod 2
+  // is AND, so these must lower to logical ops on the original operands.
+  // CHECK: spirv.LogicalNotEqual %[[LHS]], %[[RHS]] : i1
+  %0 = arith.addi %arg0, %arg1 : i1
+  // CHECK: spirv.LogicalNotEqual %[[LHS]], %[[RHS]] : i1
+  %1 = arith.subi %arg0, %arg1 : i1
+  // CHECK: spirv.LogicalAnd %[[LHS]], %[[RHS]] : i1
+  %2 = arith.muli %arg0, %arg1 : i1
+  return
+}
+
+// CHECK-LABEL: @bool_arith_vector
+// CHECK-SAME: (%[[LHS:.+]]: vector<4xi1>, %[[RHS:.+]]: vector<4xi1>)
+func.func @bool_arith_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+  // Vector-of-i1 variants must lower element-wise to the same logical ops.
+  // CHECK: spirv.LogicalNotEqual %[[LHS]], %[[RHS]] : vector<4xi1>
+  %0 = arith.addi %arg0, %arg1 : vector<4xi1>
+  // CHECK: spirv.LogicalNotEqual %[[LHS]], %[[RHS]] : vector<4xi1>
+  %1 = arith.subi %arg0, %arg1 : vector<4xi1>
+  // CHECK: spirv.LogicalAnd %[[LHS]], %[[RHS]] : vector<4xi1>
+  %2 = arith.muli %arg0, %arg1 : vector<4xi1>
+  return
+}
+
// CHECK-LABEL: @shift_scalar
func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
// CHECK: spirv.ShiftLeftLogical
More information about the Mlir-commits
mailing list