[Mlir-commits] [mlir] [mlir][ArithToSPIRV] Fix crash converting arith.addi/subi/muli on i1 types (PR #189239)

Mehdi Amini llvmlistbot at llvm.org
Sun Mar 29 06:13:10 PDT 2026


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/189239

arith.addi, arith.subi, and arith.muli on i1 (boolean) operands were incorrectly lowered to spirv.IAdd, spirv.ISub, and spirv.IMul, which require 8/16/32/64-bit integer types and reject i1, causing SPIR-V verification to fail.

Fix by adding three new boolean-specialized conversion patterns (AddIOpBooleanPattern, SubIOpBooleanPattern, MulIOpBooleanPattern) modeled after the existing XOrIOpBooleanPattern:
- addi on i1: lowers to spirv.LogicalNotEqual (addition mod 2 = XOR)
- subi on i1: lowers to spirv.LogicalNotEqual (subtraction mod 2 = XOR)
- muli on i1: lowers to spirv.LogicalAnd (multiplication mod 2 = AND)

ElementwiseArithOpPattern is updated to reject boolean types, so that i1 operations can only be matched by the specialized boolean patterns.

Fixes #61162

Assisted-by: Claude Code

>From d255518ca1a37ccc3ce62d7db7c60b502ab721b1 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 15:54:39 -0700
Subject: [PATCH] [mlir][ArithToSPIRV] Fix crash converting
 arith.addi/subi/muli on i1 types

arith.addi, arith.subi, and arith.muli on i1 (boolean) operands were
incorrectly lowered to spirv.IAdd, spirv.ISub, and spirv.IMul, which
require 8/16/32/64-bit integer types and reject i1, causing SPIR-V
verification to fail.

Fix by adding three new boolean-specialized conversion patterns
(AddIOpBooleanPattern, SubIOpBooleanPattern, MulIOpBooleanPattern)
modeled after the existing XOrIOpBooleanPattern:
- addi on i1: lowers to spirv.LogicalNotEqual (addition mod 2 = XOR)
- subi on i1: lowers to spirv.LogicalNotEqual (subtraction mod 2 = XOR)
- muli on i1: lowers to spirv.LogicalAnd (multiplication mod 2 = AND)

ElementwiseArithOpPattern is updated to reject boolean types, so that
i1 operations can only be matched by the specialized boolean patterns.

Fixes #61162

Assisted-by: Claude Code
---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 80 ++++++++++++++++++-
 .../ArithToSPIRV/arith-to-spirv.mlir          | 22 +++++
 2 files changed, 99 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index d6b1e9552fbc5..59501697a0f20 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -182,6 +182,11 @@ struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     assert(adaptor.getOperands().size() <= 3);
+    // Reject boolean types to allow specialized boolean patterns to handle
+    // them (e.g., addi/subi on i1 should use LogicalNotEqual, not IAdd/ISub).
+    if (!adaptor.getOperands().empty() &&
+        isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+      return failure();
     auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
     Type dstType = converter->convertType(op.getType());
     if (!dstType) {
@@ -572,6 +577,75 @@ struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
   }
 };
 
+/// Lowers arith.addi on i1 (or vector of i1) to spirv.LogicalNotEqual: in
+/// one-bit arithmetic a + b == a XOR b, which holds iff the operands differ.
+struct AddIOpBooleanPattern final : public OpConversionPattern<arith::AddIOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(arith::AddIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ValueRange operands = adaptor.getOperands();
+    assert(operands.size() == 2);
+    // Non-i1 additions are handled by ElementwiseArithOpPattern (IAdd).
+    if (!isBoolScalarOrVector(operands.front().getType()))
+      return failure();
+
+    if (Type resultType = getTypeConverter()->convertType(op.getType())) {
+      rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
+          op, resultType, operands);
+      return success();
+    }
+    return getTypeConversionFailure(rewriter, op);
+  }
+};
+
+/// Lowers arith.subi on i1 (or vector of i1) to spirv.LogicalNotEqual: in
+/// one-bit arithmetic a - b == a XOR b, which holds iff the operands differ.
+struct SubIOpBooleanPattern final : public OpConversionPattern<arith::SubIOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ValueRange operands = adaptor.getOperands();
+    assert(operands.size() == 2);
+    // Non-i1 subtractions are handled by ElementwiseArithOpPattern (ISub).
+    if (!isBoolScalarOrVector(operands.front().getType()))
+      return failure();
+
+    if (Type resultType = getTypeConverter()->convertType(op.getType())) {
+      rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(
+          op, resultType, operands);
+      return success();
+    }
+    return getTypeConversionFailure(rewriter, op);
+  }
+};
+
+/// Lowers arith.muli on i1 (or vector of i1) to spirv.LogicalAnd: in one-bit
+/// arithmetic a * b is 1 only when both operands are 1, i.e. a AND b.
+struct MulIOpBooleanPattern final : public OpConversionPattern<arith::MulIOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(arith::MulIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    ValueRange operands = adaptor.getOperands();
+    assert(operands.size() == 2);
+    // Non-i1 multiplies are handled by ElementwiseArithOpPattern (IMul).
+    if (!isBoolScalarOrVector(operands.front().getType()))
+      return failure();
+
+    if (Type resultType = getTypeConverter()->convertType(op.getType())) {
+      rewriter.replaceOpWithNewOp<spirv::LogicalAndOp>(
+          op, resultType, operands);
+      return success();
+    }
+    return getTypeConversionFailure(rewriter, op);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // UIToFPOp
 //===----------------------------------------------------------------------===//
@@ -1410,9 +1484,9 @@ void mlir::arith::populateArithToSPIRVPatterns(
   patterns.add<
     ConstantCompositeOpPattern,
     ConstantScalarOpPattern,
-    ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
-    ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
-    ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
+    AddIOpBooleanPattern, ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
+    SubIOpBooleanPattern, ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
+    MulIOpBooleanPattern, ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
     spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
     spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
     spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 9c726b8643a46..cf8579ad882b8 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -272,6 +272,28 @@ func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
   return
 }
 
+// CHECK-LABEL: @bool_arith_scalar
+func.func @bool_arith_scalar(%arg0 : i1, %arg1 : i1) {
+  // CHECK: spirv.LogicalNotEqual %{{.*}}, %{{.*}} : i1
+  %0 = arith.addi %arg0, %arg1 : i1
+  // CHECK: spirv.LogicalNotEqual %{{.*}}, %{{.*}} : i1
+  %1 = arith.subi %arg0, %arg1 : i1
+  // CHECK: spirv.LogicalAnd %{{.*}}, %{{.*}} : i1
+  %2 = arith.muli %arg0, %arg1 : i1
+  return
+}
+
+// CHECK-LABEL: @bool_arith_vector
+func.func @bool_arith_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+  // CHECK: spirv.LogicalNotEqual %{{.*}}, %{{.*}} : vector<4xi1>
+  %0 = arith.addi %arg0, %arg1 : vector<4xi1>
+  // CHECK: spirv.LogicalNotEqual %{{.*}}, %{{.*}} : vector<4xi1>
+  %1 = arith.subi %arg0, %arg1 : vector<4xi1>
+  // CHECK: spirv.LogicalAnd %{{.*}}, %{{.*}} : vector<4xi1>
+  %2 = arith.muli %arg0, %arg1 : vector<4xi1>
+  return
+}
+
 // CHECK-LABEL: @shift_scalar
 func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
   // CHECK: spirv.ShiftLeftLogical



More information about the Mlir-commits mailing list