[Mlir-commits] [mlir] [mlir][ArithToSPIRV] Fix crash converting arith.addi/subi/muli on i1 types (PR #189239)

Mehdi Amini llvmlistbot at llvm.org
Mon Mar 30 03:40:56 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/189239

>From 55c28b68c170fe8154712bbb29d5ffe5752918b4 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 15:54:39 -0700
Subject: [PATCH] [mlir][ArithToSPIRV] Fix crash converting
 arith.addi/subi/muli on i1 types

arith.addi, arith.subi, and arith.muli on i1 (boolean) operands were
incorrectly lowered to spirv.IAdd, spirv.ISub, and spirv.IMul, which
require 8/16/32/64-bit integer types and reject i1, causing SPIR-V
verification to fail.

Fix by adding a boolean-specialized conversion pattern template,
BoolIOpPattern<ArithOp, SPIRVOp>, modeled after the existing
XOrIOpBooleanPattern and instantiated once per op:
- addi on i1: lowers to spirv.LogicalNotEqual (addition mod 2 = XOR)
- subi on i1: lowers to spirv.LogicalNotEqual (subtraction mod 2 = XOR)
- muli on i1: lowers to spirv.LogicalAnd (multiplication mod 2 = AND)

ElementwiseArithOpPattern is updated to reject boolean types, so only
the specialized BoolIOpPattern instances can match i1 operands.

Fixes #61162

Assisted-by: Claude Code
---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 29 +++++++++++++++++++
 .../ArithToSPIRV/arith-to-spirv.mlir          | 22 ++++++++++++++
 2 files changed, 51 insertions(+)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index d6b1e9552fbc5..265e3d9fc0bc8 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -182,6 +182,11 @@ struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     assert(adaptor.getOperands().size() <= 3);
+    // Reject boolean types to allow specialized boolean patterns to handle
+    // them (e.g., addi/subi on i1 should use LogicalNotEqual, not IAdd/ISub).
+    if (!adaptor.getOperands().empty() &&
+        isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+      return failure();
     auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
     Type dstType = converter->convertType(op.getType());
     if (!dstType) {
@@ -572,6 +577,27 @@ struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
   }
 };
 
+/// Lowers a binary arith integer op on i1 or vector-of-i1 operands to SPIRVOp.
+/// Arithmetic on i1 is mod 2, so addi/subi are XOR and muli is AND.
+template <typename ArithOp, typename SPIRVOp>
+struct BoolIOpPattern final : public OpConversionPattern<ArithOp> {
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    assert(adaptor.getOperands().size() == 2 && "expected a binary op");
+    if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+      return failure();
+
+    Type dstType = this->getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+    rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // UIToFPOp
 //===----------------------------------------------------------------------===//
@@ -1410,8 +1436,11 @@ void mlir::arith::populateArithToSPIRVPatterns(
   patterns.add<
     ConstantCompositeOpPattern,
     ConstantScalarOpPattern,
+    BoolIOpPattern<arith::AddIOp, spirv::LogicalNotEqualOp>,
     ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
+    BoolIOpPattern<arith::SubIOp, spirv::LogicalNotEqualOp>,
     ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
+    BoolIOpPattern<arith::MulIOp, spirv::LogicalAndOp>,
     ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
     spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
     spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 9c726b8643a46..cf8579ad882b8 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -272,6 +272,28 @@ func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
   return
 }
 
+// CHECK-LABEL: @bool_arith_scalar
+func.func @bool_arith_scalar(%arg0 : i1, %arg1 : i1) {
+  // CHECK: spirv.LogicalNotEqual %{{.*}}, %{{.*}} : i1
+  %0 = arith.addi %arg0, %arg1 : i1
+  // CHECK: spirv.LogicalNotEqual %{{.*}}, %{{.*}} : i1
+  %1 = arith.subi %arg0, %arg1 : i1
+  // CHECK: spirv.LogicalAnd %{{.*}}, %{{.*}} : i1
+  %2 = arith.muli %arg0, %arg1 : i1
+  return
+}
+
+// CHECK-LABEL: @bool_arith_vector
+func.func @bool_arith_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+  // CHECK: spirv.LogicalNotEqual %{{.*}}, %{{.*}} : vector<4xi1>
+  %0 = arith.addi %arg0, %arg1 : vector<4xi1>
+  // CHECK: spirv.LogicalNotEqual %{{.*}}, %{{.*}} : vector<4xi1>
+  %1 = arith.subi %arg0, %arg1 : vector<4xi1>
+  // CHECK: spirv.LogicalAnd %{{.*}}, %{{.*}} : vector<4xi1>
+  %2 = arith.muli %arg0, %arg1 : vector<4xi1>
+  return
+}
+
 // CHECK-LABEL: @shift_scalar
 func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
   // CHECK: spirv.ShiftLeftLogical



More information about the Mlir-commits mailing list