[Mlir-commits] [mlir] [mlir][ArithToSPIRV] Fix invalid SPIRV and crashes when lowering integer ops on i1 (PR #189239)

Mehdi Amini llvmlistbot at llvm.org
Mon Mar 30 16:46:11 PDT 2026


https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/189239

>From 55c28b68c170fe8154712bbb29d5ffe5752918b4 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 28 Mar 2026 15:54:39 -0700
Subject: [PATCH 1/2] [mlir][ArithToSPIRV] Fix crash converting
 arith.addi/subi/muli on i1 types

arith.addi, arith.subi, and arith.muli on i1 (boolean) operands were
incorrectly lowered to spirv.IAdd, spirv.ISub, and spirv.IMul, which
require 8/16/32/64-bit integer types and reject i1, causing SPIRV
verification to fail.

Fix by adding three new boolean-specialized conversion patterns
(AddIOpBooleanPattern, SubIOpBooleanPattern, MulIOpBooleanPattern)
modeled after the existing XOrIOpBooleanPattern:
- addi on i1: lowers to spirv.LogicalNotEqual (addition mod 2 = XOR)
- subi on i1: lowers to spirv.LogicalNotEqual (subtraction mod 2 = XOR)
- muli on i1: lowers to spirv.LogicalAnd (multiplication mod 2 = AND)

ElementwiseArithOpPattern is updated to reject boolean types so the
specialized patterns take priority.

Fixes #61162

Assisted-by: Claude Code
---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 29 +++++++++++++++++++
 .../ArithToSPIRV/arith-to-spirv.mlir          | 22 ++++++++++++++
 2 files changed, 51 insertions(+)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index d6b1e9552fbc5..265e3d9fc0bc8 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -182,6 +182,11 @@ struct ElementwiseArithOpPattern final : OpConversionPattern<Op> {
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     assert(adaptor.getOperands().size() <= 3);
+    // Reject boolean types to allow specialized boolean patterns to handle
+    // them (e.g., addi/subi on i1 should use LogicalNotEqual, not IAdd/ISub).
+    if (!adaptor.getOperands().empty() &&
+        isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+      return failure();
     auto converter = this->template getTypeConverter<SPIRVTypeConverter>();
     Type dstType = converter->convertType(op.getType());
     if (!dstType) {
@@ -572,6 +577,27 @@ struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
   }
 };
 
+/// Converts an arith integer op to the given SPIR-V boolean op if the type is
+/// i1 or vector of i1.
+template <typename ArithOp, typename SPIRVOp>
+struct BoolIOpPattern final : public OpConversionPattern<ArithOp> {
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+      return failure(); // Non-i1 types are left to the integer patterns.
+
+    Type dstType = this->getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    rewriter.replaceOpWithNewOp<SPIRVOp>(op, dstType, adaptor.getOperands());
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // UIToFPOp
 //===----------------------------------------------------------------------===//
@@ -1410,8 +1436,11 @@ void mlir::arith::populateArithToSPIRVPatterns(
   patterns.add<
     ConstantCompositeOpPattern,
     ConstantScalarOpPattern,
+    BoolIOpPattern<arith::AddIOp, spirv::LogicalNotEqualOp>,
     ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
+    BoolIOpPattern<arith::SubIOp, spirv::LogicalNotEqualOp>,
     ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
+    BoolIOpPattern<arith::MulIOp, spirv::LogicalAndOp>,
     ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
     spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
     spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 9c726b8643a46..cf8579ad882b8 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -272,6 +272,28 @@ func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
   return
 }
 
+// CHECK-LABEL: @bool_arith_scalar
+func.func @bool_arith_scalar(%arg0 : i1, %arg1 : i1) {
+  // CHECK: spirv.LogicalNotEqual %arg0, %arg1
+  %0 = arith.addi %arg0, %arg1 : i1
+  // CHECK: spirv.LogicalNotEqual %arg0, %arg1
+  %1 = arith.subi %arg0, %arg1 : i1
+  // CHECK: spirv.LogicalAnd %arg0, %arg1
+  %2 = arith.muli %arg0, %arg1 : i1
+  return
+}
+
+// CHECK-LABEL: @bool_arith_vector
+func.func @bool_arith_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+  // CHECK: spirv.LogicalNotEqual %arg0, %arg1
+  %0 = arith.addi %arg0, %arg1 : vector<4xi1>
+  // CHECK: spirv.LogicalNotEqual %arg0, %arg1
+  %1 = arith.subi %arg0, %arg1 : vector<4xi1>
+  // CHECK: spirv.LogicalAnd %arg0, %arg1
+  %2 = arith.muli %arg0, %arg1 : vector<4xi1>
+  return
+}
+
 // CHECK-LABEL: @shift_scalar
 func.func @shift_scalar(%arg0 : i32, %arg1 : i32) {
   // CHECK: spirv.ShiftLeftLogical

>From d9dc180a8ec856d924455600a69ebf38f0ea8c67 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Mon, 30 Mar 2026 04:38:21 -0700
Subject: [PATCH 2/2] [mlir][ArithToSPIRV] Fix i1 lowerings for div, rem,
 shift, and min/max ops
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Extend the i1 boolean pattern fixes to cover all integer ops that were
previously generating invalid SPIRV (e.g. spirv.UDiv on i1) or crashing:

- divui/divsi(a,b) on i1 → spirv.LogicalAnd (division by 0 is UB, so only b=1 matters and a/1 = a = a & b)
- remui/remsi(a,b) on i1 → spirv.LogicalAnd(a, spirv.LogicalNot(b))
- shli/shrui(a,b) on i1  → spirv.LogicalAnd(a, spirv.LogicalNot(b))
- shrsi(a,b) on i1       → identity (arithmetic right shift of i1 = input)
- maxui/minsi(a,b) on i1 → spirv.LogicalOr
- maxsi/minui(a,b) on i1 → spirv.LogicalAnd

All new boolean patterns are given benefit=2 so they take priority over
the generic spirv::ElementwiseOpPattern (benefit=1), which has no i1 guard
and would otherwise generate invalid SPIRV for 1-bit integer types.

Also add test coverage for all of the above in arith-to-spirv.mlir.

Assisted-by: Claude Code
---
 .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp  | 92 ++++++++++++++++++-
 .../ArithToSPIRV/arith-to-spirv.mlir          | 57 ++++++++++++
 2 files changed, 144 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 265e3d9fc0bc8..660f450ff9f27 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -578,10 +578,26 @@ struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
 };
 
 /// Converts an arith integer op to the given SPIR-V boolean op if the type is
-/// i1 or vector of i1.
+/// i1 or vector of i1. Each mapping follows from the boolean truth table of
+/// the operation:
+///   addi(a, b)  = a ^ b  (add mod 2 = XOR = LogicalNotEqual)
+///   subi(a, b)  = a ^ b  (sub mod 2 = XOR = LogicalNotEqual)
+///   muli(a, b)  = a & b  (1*1=1, else 0 = LogicalAnd)
+///   divui(a, b) = a & b  (a/1=a, a/0=UB; truth table matches AND)
+///   divsi(a, b) = a & b  (same as divui on i1)
+///   maxsi(a, b) = a & b  (signed i1: 1 represents -1, so max is 0 unless both
+///                        are 1)
+///   maxui(a, b) = a | b  (unsigned max on i1: 1 when either operand is 1)
+///   minsi(a, b) = a | b  (signed i1: -1 < 0, so min is 1 when either operand
+///                        is 1)
+///   minui(a, b) = a & b  (unsigned min on i1: 1 only when both operands are
+///                        1)
 template <typename ArithOp, typename SPIRVOp>
 struct BoolIOpPattern final : public OpConversionPattern<ArithOp> {
-  using OpConversionPattern<ArithOp>::OpConversionPattern;
+  BoolIOpPattern(const TypeConverter &converter, MLIRContext *context)
+      // benefit=2: takes priority over the generic ElementwiseArithOpPattern
+      // (benefit=1) when the operand type is i1.
+      : OpConversionPattern<ArithOp>(converter, context, /*benefit=*/2) {}
 
   LogicalResult
   matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
@@ -598,6 +614,61 @@ struct BoolIOpPattern final : public OpConversionPattern<ArithOp> {
   }
 };
 
+/// Converts an arith binary op on i1 to spirv.LogicalAnd(lhs,
+/// spirv.LogicalNot(rhs)). This covers shift-left, shift-right-unsigned, and
+/// unsigned and signed remainder on i1:
+///   shli(a, b)  = a & ~b  (b=0: a; b=1 >= bitwidth is poison, so 0 is valid)
+///   shrui(a, b) = a & ~b  (b=0: a; b=1 >= bitwidth is poison, so 0 is valid)
+///   remui(a, b) = a & ~b  (only defined when b=1; a%1=0, and ~b=~1=0, so AND
+///                         gives 0)
+///   remsi(a, b) = a & ~b  (only defined when b=1, i.e. -1; a%-1=0, and ~b=0,
+///                         so AND gives 0)
+template <typename ArithOp>
+struct BoolIOpAndNotPattern final : public OpConversionPattern<ArithOp> {
+  BoolIOpAndNotPattern(const TypeConverter &converter, MLIRContext *context)
+      // benefit=2: takes priority over the generic shift/remainder lowerings
+      // (default benefit=1) when the operand type is i1.
+      : OpConversionPattern<ArithOp>(converter, context, /*benefit=*/2) {}
+
+  LogicalResult
+  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+      return failure(); // Non-i1: left to the generic lowerings.
+
+    Type dstType = this->getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    Location loc = op.getLoc();
+    Value notRhs = spirv::LogicalNotOp::create(rewriter, loc, dstType,
+                                               adaptor.getOperands()[1]);
+    rewriter.replaceOpWithNewOp<spirv::LogicalAndOp>(
+        op, dstType, adaptor.getOperands()[0], notRhs);
+    return success();
+  }
+};
+
+/// Converts arith.shrsi on i1 to identity: a shift by 0 returns the input, and
+/// a shift by 1 (>= bitwidth) yields poison, so forwarding the lhs is valid.
+struct ShRSIBoolPattern final : public OpConversionPattern<arith::ShRSIOp> {
+  ShRSIBoolPattern(const TypeConverter &converter, MLIRContext *context)
+      // benefit=2: takes priority over the generic spirv::ElementwiseOpPattern
+      // (benefit=1) when the operand type is i1.
+      : OpConversionPattern<arith::ShRSIOp>(converter, context,
+                                            /*benefit=*/2) {}
+
+  LogicalResult
+  matchAndRewrite(arith::ShRSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
+      return failure();
+
+    rewriter.replaceOp(op, adaptor.getOperands().front()); // Forward the lhs.
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // UIToFPOp
 //===----------------------------------------------------------------------===//
@@ -1436,21 +1507,28 @@ void mlir::arith::populateArithToSPIRVPatterns(
   patterns.add<
     ConstantCompositeOpPattern,
     ConstantScalarOpPattern,
-    BoolIOpPattern<arith::AddIOp, spirv::LogicalNotEqualOp>,
+    BoolIOpPattern<arith::AddIOp, spirv::LogicalNotEqualOp>, // add mod 2 = XOR = not-equal
     ElementwiseArithOpPattern<arith::AddIOp, spirv::IAddOp>,
-    BoolIOpPattern<arith::SubIOp, spirv::LogicalNotEqualOp>,
+    BoolIOpPattern<arith::SubIOp, spirv::LogicalNotEqualOp>, // sub mod 2 = XOR = not-equal
     ElementwiseArithOpPattern<arith::SubIOp, spirv::ISubOp>,
-    BoolIOpPattern<arith::MulIOp, spirv::LogicalAndOp>,
+    BoolIOpPattern<arith::MulIOp, spirv::LogicalAndOp>,      // 1*1=1, else 0 = AND
     ElementwiseArithOpPattern<arith::MulIOp, spirv::IMulOp>,
+    BoolIOpPattern<arith::DivUIOp, spirv::LogicalAndOp>,     // a/1=a, a/0=UB; truth table = AND
     spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
+    BoolIOpPattern<arith::DivSIOp, spirv::LogicalAndOp>,     // same as divui on i1
     spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
+    BoolIOpAndNotPattern<arith::RemUIOp>,  // remui(a,b) = a & ~b (see pattern comment)
     spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
+    BoolIOpAndNotPattern<arith::RemSIOp>,  // remsi(a,b) = a & ~b (see pattern comment)
     RemSIOpGLPattern, RemSIOpCLPattern,
     BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
     BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
     XOrIOpLogicalPattern, XOrIOpBooleanPattern,
+    BoolIOpAndNotPattern<arith::ShLIOp>,   // shli(a,b)  = a & ~b (see pattern comment)
     ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
+    BoolIOpAndNotPattern<arith::ShRUIOp>,  // shrui(a,b) = a & ~b (see pattern comment)
     spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
+    ShRSIBoolPattern,
     spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
     spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
     spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
@@ -1483,6 +1561,10 @@ void mlir::arith::populateArithToSPIRVPatterns(
     MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
     MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
     MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
+    BoolIOpPattern<arith::MaxSIOp, spirv::LogicalAndOp>, // signed i1: 1=-1, so max=0 unless both are 1
+    BoolIOpPattern<arith::MaxUIOp, spirv::LogicalOrOp>,  // unsigned max on i1: 1 when either is 1
+    BoolIOpPattern<arith::MinSIOp, spirv::LogicalOrOp>,  // signed i1: -1<0, so min=1 when either is 1
+    BoolIOpPattern<arith::MinUIOp, spirv::LogicalAndOp>, // unsigned min on i1: 1 only when both are 1
     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
     spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
     spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index cf8579ad882b8..31b70177a0d19 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -320,6 +320,63 @@ func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) {
 
 // -----
 
+// Test i1 lowerings for shift, div, rem, and min/max ops (scalar i1 only).
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @bool_shift_div_rem_scalar
+func.func @bool_shift_div_rem_scalar(%arg0 : i1, %arg1 : i1) {
+  // shli(a,b) = a & ~b  (shifting by 1 >= bitwidth is poison)
+  // CHECK: %[[NOTB:.+]] = spirv.LogicalNot %arg1
+  // CHECK: spirv.LogicalAnd %arg0, %[[NOTB]]
+  %0 = arith.shli %arg0, %arg1 : i1
+  // shrui(a,b) = a & ~b  (shifting by 1 >= bitwidth is poison)
+  // CHECK: %[[NOTB:.+]] = spirv.LogicalNot %arg1
+  // CHECK: spirv.LogicalAnd %arg0, %[[NOTB]]
+  %1 = arith.shrui %arg0, %arg1 : i1
+  // shrsi(a,b) = a: the op is replaced by %arg0, so no SPIR-V op is emitted
+  // CHECK-NOT: spirv.ShiftRightArithmetic
+  %2 = arith.shrsi %arg0, %arg1 : i1
+  // divui(a,b) = a & b  (only valid for b=1, same as muli)
+  // CHECK: spirv.LogicalAnd %arg0, %arg1
+  %3 = arith.divui %arg0, %arg1 : i1
+  // divsi(a,b) = a & b  (only non-UB/non-overflow case: 0/-1 = 0)
+  // CHECK: spirv.LogicalAnd %arg0, %arg1
+  %4 = arith.divsi %arg0, %arg1 : i1
+  // remui(a,b) = a & ~b  (a % 1 = 0 for valid b=1)
+  // CHECK: %[[NOTB:.+]] = spirv.LogicalNot %arg1
+  // CHECK: spirv.LogicalAnd %arg0, %[[NOTB]]
+  %5 = arith.remui %arg0, %arg1 : i1
+  // remsi(a,b) = a & ~b  (a % -1 = 0 for valid b=1, i.e. -1)
+  // CHECK: %[[NOTB:.+]] = spirv.LogicalNot %arg1
+  // CHECK: spirv.LogicalAnd %arg0, %[[NOTB]]
+  %6 = arith.remsi %arg0, %arg1 : i1
+  return
+}
+
+// CHECK-LABEL: @bool_minmax_scalar
+func.func @bool_minmax_scalar(%arg0 : i1, %arg1 : i1) {
+  // maxui(a,b) = a | b  (unsigned max of two booleans is OR)
+  // CHECK: spirv.LogicalOr %arg0, %arg1
+  %0 = arith.maxui %arg0, %arg1 : i1
+  // maxsi(a,b) = a & b  (signed max: max(-1,0)=0, so max(true,false)=false → AND)
+  // CHECK: spirv.LogicalAnd %arg0, %arg1
+  %1 = arith.maxsi %arg0, %arg1 : i1
+  // minui(a,b) = a & b  (unsigned min of two booleans is AND)
+  // CHECK: spirv.LogicalAnd %arg0, %arg1
+  %2 = arith.minui %arg0, %arg1 : i1
+  // minsi(a,b) = a | b  (signed min: min(-1,0)=-1, so min(true,false)=true → OR)
+  // CHECK: spirv.LogicalOr %arg0, %arg1
+  %3 = arith.minsi %arg0, %arg1 : i1
+  return
+}
+
+} // end module
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // arith.cmpf
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list