[Mlir-commits] [mlir] [MLIR][Arith] Canonicalize and / or / xor const triples (PR #195588)
Max Graey
llvmlistbot at llvm.org
Sun May 3 23:02:21 PDT 2026
https://github.com/MaxGraey created https://github.com/llvm/llvm-project/pull/195588
```
(x & c1) & c2 --> x & (c1 & c2)
(x | c1) | c2 --> x | (c1 | c2)
(x ^ c1) ^ c2 --> x ^ (c1 ^ c2)
```
Proofs:
- [(x & c1) & c2 --> x & (c1 & c2)](https://alive2.llvm.org/ce/z/VdSbpp)
- [(x | c1) | c2 --> x | (c1 | c2)](https://alive2.llvm.org/ce/z/vbA5NZ)
- [(x ^ c1) ^ c2 --> x ^ (c1 ^ c2)](https://alive2.llvm.org/ce/z/t_7v_h)
Co-authored-by: Claude Opus 4.6 (1M context) [noreply at anthropic.com](mailto:noreply at anthropic.com)
>From a4b6d12ade27e60bd546df86d29d6219bfe90283 Mon Sep 17 00:00:00 2001
From: MaxGraey <maxgraey at gmail.com>
Date: Mon, 4 May 2026 08:54:56 +0300
Subject: [PATCH] [MLIR][Arith] Canonicalize and/or/xor const triples
---
.../Dialect/Arith/IR/ArithCanonicalization.td | 36 +++++
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 15 ++
mlir/test/Dialect/Arith/canonicalize.mlir | 144 ++++++++++++++++++
3 files changed, 195 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 66f4ace265201..af33e9ae4f3a7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,6 +24,15 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
// Multiply two integer attributes and create a new one with the result.
def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
+// Bitwise-and two integer attributes ($1 & $2) and create a new one with the
+// result. $0 is the value produced by the matched op; it is forwarded to
+// andIntegerAttrs in ArithOps.cpp.
+def AndIntAttrs : NativeCodeCall<"andIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Bitwise-or two integer attributes ($1 | $2) and create a new one with the
+// result. $0 is the value produced by the matched op; it is forwarded to
+// orIntegerAttrs in ArithOps.cpp.
+def OrIntAttrs : NativeCodeCall<"orIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Bitwise-xor two integer attributes ($1 ^ $2) and create a new one with the
+// result. $0 is the value produced by the matched op; it is forwarded to
+// xorIntegerAttrs in ArithOps.cpp.
+def XorIntAttrs : NativeCodeCall<"xorIntegerAttrs($_builder, $0, $1, $2)">;
+
// Merge overflow flags from 2 ops, selecting the most conservative combination.
def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
@@ -209,6 +218,15 @@ def MulUIExtendedToMulI :
// xori is commutative and will be canonicalized to have its constants appear
// as the second operand.
+// xori(xori(x, c0), c1) -> xori(x, c0 ^ c1)
+// The two constraints require both constant attributes (cast to IntegerAttr)
+// to have exactly the type of $res. Constants whose attribute type differs
+// from the op type (e.g. i64-valued attributes on an op producing i32) are
+// therefore not merged. Splat vector constants are presumably not matched by
+// APIntAttr -- TODO confirm whether that coverage is wanted.
+// NOTE(review): the ArithOps.cpp part of this patch only adds the attribute
+// helpers; confirm XOrIXOrIConstant is also added to the XOrIOp
+// canonicalization pattern list, otherwise this rewrite never runs.
+def XOrIXOrIConstant :
+ Pat<(Arith_XOrIOp:$res
+ (Arith_XOrIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_XOrIOp $x, (Arith_ConstantOp (XorIntAttrs $res, $c0, $c1))),
+ [(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c0),
+ (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1)]>;
+
// not(cmpi(pred, a, b)) -> cmpi(~pred, a, b), where not(x) is xori(x, 1)
def InvertPredicate : NativeCodeCall<"invertPredicate($0)">;
def XOrINotCmpI :
@@ -323,6 +341,15 @@ def ExtSIOfExtUI :
// AndIOp
//===----------------------------------------------------------------------===//
+// andi(andi(x, c0), c1) -> andi(x, c0 & c1)
+// The two constraints require both constant attributes (cast to IntegerAttr)
+// to have exactly the type of $res. Constants whose attribute type differs
+// from the op type (e.g. i64-valued attributes on an op producing i32) are
+// therefore not merged. Splat vector constants are presumably not matched by
+// APIntAttr -- TODO confirm whether that coverage is wanted.
+// NOTE(review): the ArithOps.cpp part of this patch only adds the attribute
+// helpers; confirm AndIAndIConstant is also added to the AndIOp
+// canonicalization pattern list, otherwise this rewrite never runs.
+def AndIAndIConstant :
+ Pat<(Arith_AndIOp:$res
+ (Arith_AndIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_AndIOp $x, (Arith_ConstantOp (AndIntAttrs $res, $c0, $c1))),
+ [(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c0),
+ (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1)]>;
+
// and extui(x), extui(y) -> extui(and(x,y))
// AND can only clear bits, so if either operand is non-negative (sign bit 0),
// the result also has sign bit 0. Preserves nneg if either operand has it.
@@ -342,6 +369,15 @@ def AndOfExtSI :
// OrIOp
//===----------------------------------------------------------------------===//
+// ori(ori(x, c0), c1) -> ori(x, c0 | c1)
+// The two constraints require both constant attributes (cast to IntegerAttr)
+// to have exactly the type of $res. Constants whose attribute type differs
+// from the op type (e.g. i64-valued attributes on an op producing i32) are
+// therefore not merged. Splat vector constants are presumably not matched by
+// APIntAttr -- TODO confirm whether that coverage is wanted.
+// NOTE(review): the ArithOps.cpp part of this patch only adds the attribute
+// helpers; confirm OrIOrIConstant is also added to the OrIOp
+// canonicalization pattern list, otherwise this rewrite never runs.
+def OrIOrIConstant :
+ Pat<(Arith_OrIOp:$res
+ (Arith_OrIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_OrIOp $x, (Arith_ConstantOp (OrIntAttrs $res, $c0, $c1))),
+ [(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c0),
+ (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1)]>;
+
// or extui(x), extui(y) -> extui(or(x,y))
// OR preserves nneg only if both operands are non-negative: or of two values
// with sign bit 0 has sign bit 0.
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index bcdab9ee8e978..87b29a3f3beb7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -67,6 +67,21 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}
+// Combine two integer attributes with bitwise AND via applyToIntegerAttrs.
+// `res` is the value of the op being rewritten. Called from the AndIntAttrs
+// NativeCodeCall in ArithCanonicalization.td.
+static IntegerAttr andIntegerAttrs(PatternRewriter &builder, Value res,
+ Attribute lhs, Attribute rhs) {
+ return applyToIntegerAttrs(builder, res, lhs, rhs, std::bit_and<APInt>());
+}
+
+// Combine two integer attributes with bitwise OR via applyToIntegerAttrs.
+// `res` is the value of the op being rewritten. Called from the OrIntAttrs
+// NativeCodeCall in ArithCanonicalization.td.
+static IntegerAttr orIntegerAttrs(PatternRewriter &builder, Value res,
+ Attribute lhs, Attribute rhs) {
+ return applyToIntegerAttrs(builder, res, lhs, rhs, std::bit_or<APInt>());
+}
+
+// Combine two integer attributes with bitwise XOR via applyToIntegerAttrs.
+// `res` is the value of the op being rewritten. Called from the XorIntAttrs
+// NativeCodeCall in ArithCanonicalization.td.
+static IntegerAttr xorIntegerAttrs(PatternRewriter &builder, Value res,
+ Attribute lhs, Attribute rhs) {
+ return applyToIntegerAttrs(builder, res, lhs, rhs, std::bit_xor<APInt>());
+}
+
// Merge overflow flags from 2 ops, selecting the most conservative combination.
static IntegerOverflowFlagsAttr
mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 02626bef856d3..287f903ab7192 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -579,6 +579,102 @@ func.func @orOfExtUI_nneg_mixed(%arg0: i8, %arg1: i8) -> i64 {
// -----
+// andi(andi(%arg0, 12), 10) becomes andi(%arg0, 8), since 12 & 10 == 8.
+// CHECK-LABEL: @andOfAndConstant
+// CHECK: %[[cres:.+]] = arith.constant 8 : i32
+// CHECK: %[[and:.+]] = arith.andi %arg0, %[[cres]] : i32
+// CHECK: return %[[and]]
+func.func @andOfAndConstant(%arg0: i32) -> i32 {
+ %c12 = arith.constant 12 : i32
+ %c10 = arith.constant 10 : i32
+ %a1 = arith.andi %arg0, %c12 : i32
+ %a2 = arith.andi %a1, %c10 : i32
+ return %a2 : i32
+}
+
+// -----
+
+// Same and-of-and constant merge on index operands: 12 & 10 == 8.
+// CHECK-LABEL: @andOfAndConstantIndex
+// CHECK: %[[cres:.+]] = arith.constant 8 : index
+// CHECK: %[[and:.+]] = arith.andi %arg0, %[[cres]] : index
+// CHECK: return %[[and]]
+func.func @andOfAndConstantIndex(%arg0: index) -> index {
+ %c12 = arith.constant 12 : index
+ %c10 = arith.constant 10 : index
+ %a1 = arith.andi %arg0, %c12 : index
+ %a2 = arith.andi %a1, %c10 : index
+ return %a2 : index
+}
+
+// -----
+
+// ori(ori(%arg0, 12), 10) becomes ori(%arg0, 14), since 12 | 10 == 14.
+// CHECK-LABEL: @orOfOrConstant
+// CHECK: %[[cres:.+]] = arith.constant 14 : i32
+// CHECK: %[[or:.+]] = arith.ori %arg0, %[[cres]] : i32
+// CHECK: return %[[or]]
+func.func @orOfOrConstant(%arg0: i32) -> i32 {
+ %c12 = arith.constant 12 : i32
+ %c10 = arith.constant 10 : i32
+ %o1 = arith.ori %arg0, %c12 : i32
+ %o2 = arith.ori %o1, %c10 : i32
+ return %o2 : i32
+}
+
+// -----
+
+// Same or-of-or constant merge on index operands: 12 | 10 == 14.
+// CHECK-LABEL: @orOfOrConstantIndex
+// CHECK: %[[cres:.+]] = arith.constant 14 : index
+// CHECK: %[[or:.+]] = arith.ori %arg0, %[[cres]] : index
+// CHECK: return %[[or]]
+func.func @orOfOrConstantIndex(%arg0: index) -> index {
+ %c12 = arith.constant 12 : index
+ %c10 = arith.constant 10 : index
+ %o1 = arith.ori %arg0, %c12 : index
+ %o2 = arith.ori %o1, %c10 : index
+ return %o2 : index
+}
+
+// -----
+
+// Negative test: the test.constant attributes are i64 while the andi ops
+// produce i32. Neither the andi folder nor the and-of-and constant pattern may
+// combine them, so both andi ops must survive canonicalization unchanged.
+// CHECK-LABEL: func.func @nested_andi() -> i32 {
+// CHECK: %[[VAL_0:.*]] = "test.constant"() <{value = 2147483647 : i64}> : () -> i32
+// CHECK: %[[VAL_1:.*]] = "test.constant"() <{value = -2147483648 : i64}> : () -> i32
+// CHECK: %[[VAL_2:.*]] = "test.constant"() <{value = 2147483648 : i64}> : () -> i32
+// CHECK: %[[VAL_3:.*]] = arith.andi %[[VAL_0]], %[[VAL_1]] : i32
+// CHECK: %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : i32
+// CHECK: return %[[VAL_4]] : i32
+// CHECK: }
+func.func @nested_andi() -> (i32) {
+ %0 = "test.constant"() {value = 0x7fffffff} : () -> i32
+ %1 = "test.constant"() {value = -2147483648} : () -> i32
+ %2 = "test.constant"() {value = 0x80000000} : () -> i32
+ %4 = arith.andi %0, %1 : i32
+ %5 = arith.andi %4, %2 : i32
+ return %5 : i32
+}
+
+// -----
+
+// Negative test: the test.constant attributes are i64 while the ori ops
+// produce i32. Neither the ori folder nor the or-of-or constant pattern may
+// combine them, so both ori ops must survive canonicalization unchanged.
+// CHECK-LABEL: func.func @nested_ori() -> i32 {
+// CHECK: %[[VAL_0:.*]] = "test.constant"() <{value = 2147483647 : i64}> : () -> i32
+// CHECK: %[[VAL_1:.*]] = "test.constant"() <{value = -2147483648 : i64}> : () -> i32
+// CHECK: %[[VAL_2:.*]] = "test.constant"() <{value = 2147483648 : i64}> : () -> i32
+// CHECK: %[[VAL_3:.*]] = arith.ori %[[VAL_0]], %[[VAL_1]] : i32
+// CHECK: %[[VAL_4:.*]] = arith.ori %[[VAL_3]], %[[VAL_2]] : i32
+// CHECK: return %[[VAL_4]] : i32
+// CHECK: }
+func.func @nested_ori() -> (i32) {
+ %0 = "test.constant"() {value = 0x7fffffff} : () -> i32
+ %1 = "test.constant"() {value = -2147483648} : () -> i32
+ %2 = "test.constant"() {value = 0x80000000} : () -> i32
+ %4 = arith.ori %0, %1 : i32
+ %5 = arith.ori %4, %2 : i32
+ return %5 : i32
+}
+
+// -----
+
// CHECK-LABEL: @indexCastOfSignExtend
// CHECK: %[[res:.+]] = arith.index_cast %arg0 : i8 to index
// CHECK: return %[[res]]
@@ -2130,6 +2226,54 @@ func.func @xorOfExtUI_nneg_mixed(%arg0: i8, %arg1: i8) -> i64 {
// -----
+// xori(xori(%arg0, 12), 10) becomes xori(%arg0, 6), since 12 ^ 10 == 6.
+// CHECK-LABEL: @xorOfXorConstant
+// CHECK: %[[cres:.+]] = arith.constant 6 : i32
+// CHECK: %[[xor:.+]] = arith.xori %arg0, %[[cres]] : i32
+// CHECK: return %[[xor]]
+func.func @xorOfXorConstant(%arg0: i32) -> i32 {
+ %c12 = arith.constant 12 : i32
+ %c10 = arith.constant 10 : i32
+ %x1 = arith.xori %arg0, %c12 : i32
+ %x2 = arith.xori %x1, %c10 : i32
+ return %x2 : i32
+}
+
+// -----
+
+// Same xor-of-xor constant merge on index operands: 12 ^ 10 == 6.
+// CHECK-LABEL: @xorOfXorConstantIndex
+// CHECK: %[[cres:.+]] = arith.constant 6 : index
+// CHECK: %[[xor:.+]] = arith.xori %arg0, %[[cres]] : index
+// CHECK: return %[[xor]]
+func.func @xorOfXorConstantIndex(%arg0: index) -> index {
+ %c12 = arith.constant 12 : index
+ %c10 = arith.constant 10 : index
+ %x1 = arith.xori %arg0, %c12 : index
+ %x2 = arith.xori %x1, %c10 : index
+ return %x2 : index
+}
+
+// -----
+
+// Negative test: the test.constant attributes are i64 while the xori ops
+// produce i32. Neither the xori folder nor the xor-of-xor constant pattern may
+// combine them, so both xori ops must survive canonicalization unchanged.
+// CHECK-LABEL: func.func @nested_xori() -> i32 {
+// CHECK: %[[VAL_0:.*]] = "test.constant"() <{value = 2147483647 : i64}> : () -> i32
+// CHECK: %[[VAL_1:.*]] = "test.constant"() <{value = -2147483648 : i64}> : () -> i32
+// CHECK: %[[VAL_2:.*]] = "test.constant"() <{value = 2147483648 : i64}> : () -> i32
+// CHECK: %[[VAL_3:.*]] = arith.xori %[[VAL_0]], %[[VAL_1]] : i32
+// CHECK: %[[VAL_4:.*]] = arith.xori %[[VAL_3]], %[[VAL_2]] : i32
+// CHECK: return %[[VAL_4]] : i32
+// CHECK: }
+func.func @nested_xori() -> (i32) {
+ %0 = "test.constant"() {value = 0x7fffffff} : () -> i32
+ %1 = "test.constant"() {value = -2147483648} : () -> i32
+ %2 = "test.constant"() {value = 0x80000000} : () -> i32
+ %4 = arith.xori %0, %1 : i32
+ %5 = arith.xori %4, %2 : i32
+ return %5 : i32
+}
+
+// -----
+
// CHECK-LABEL: @bitcastSameType(
// CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
func.func @bitcastSameType(%arg : f32) -> f32 {
More information about the Mlir-commits
mailing list