[Mlir-commits] [mlir] [MLIR][Arith] Canonicalize and / or / xor const triples (PR #195588)
Max Graey
llvmlistbot at llvm.org
Sun May 3 23:11:39 PDT 2026
https://github.com/MaxGraey updated https://github.com/llvm/llvm-project/pull/195588
>From a4b6d12ade27e60bd546df86d29d6219bfe90283 Mon Sep 17 00:00:00 2001
From: MaxGraey <maxgraey at gmail.com>
Date: Mon, 4 May 2026 08:54:56 +0300
Subject: [PATCH 1/2] [MLIR][Arith] Canonicalize and/or/xor const triples
---
.../Dialect/Arith/IR/ArithCanonicalization.td | 36 +++++
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 15 ++
mlir/test/Dialect/Arith/canonicalize.mlir | 144 ++++++++++++++++++
3 files changed, 195 insertions(+)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
index 66f4ace265201..af33e9ae4f3a7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -24,6 +24,15 @@ def SubIntAttrs : NativeCodeCall<"subIntegerAttrs($_builder, $0, $1, $2)">;
// Multiply two integer attributes and create a new one with the result.
def MulIntAttrs : NativeCodeCall<"mulIntegerAttrs($_builder, $0, $1, $2)">;
+// The three helpers below take the matched op's result value ($0) and two
+// integer attributes ($1, $2), and return a new integer attribute.
+// Bitwise-and two integer attributes and create a new one with the result.
+def AndIntAttrs : NativeCodeCall<"andIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Bitwise-or two integer attributes and create a new one with the result.
+def OrIntAttrs : NativeCodeCall<"orIntegerAttrs($_builder, $0, $1, $2)">;
+
+// Bitwise-xor two integer attributes and create a new one with the result.
+def XorIntAttrs : NativeCodeCall<"xorIntegerAttrs($_builder, $0, $1, $2)">;
+
// Merge overflow flags from 2 ops, selecting the most conservative combination.
def MergeOverflow : NativeCodeCall<"mergeOverflowFlags($0, $1)">;
@@ -209,6 +218,15 @@ def MulUIExtendedToMulI :
// xori is commutative and will be canonicalized to have its constants appear
// as the second operand.
+// xori(xori(x, c0), c1) -> xori(x, c0 ^ c1)
+// Both constant attributes must have exactly the result type; constant-like
+// ops whose attribute type differs from their result type are not combined.
+def XOrIXOrIConstant :
+ Pat<(Arith_XOrIOp:$res
+ (Arith_XOrIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_XOrIOp $x, (Arith_ConstantOp (XorIntAttrs $res, $c0, $c1))),
+ [(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c0),
+ (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1)]>;
+
// not(cmpi(pred, a, b)) -> cmpi(~pred, a, b), where not(x) is xori(x, 1)
def InvertPredicate : NativeCodeCall<"invertPredicate($0)">;
def XOrINotCmpI :
@@ -323,6 +341,15 @@ def ExtSIOfExtUI :
// AndIOp
//===----------------------------------------------------------------------===//
+// andi(andi(x, c0), c1) -> andi(x, c0 & c1)
+// Both constant attributes must have exactly the result type; constant-like
+// ops whose attribute type differs from their result type are not combined.
+def AndIAndIConstant :
+ Pat<(Arith_AndIOp:$res
+ (Arith_AndIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_AndIOp $x, (Arith_ConstantOp (AndIntAttrs $res, $c0, $c1))),
+ [(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c0),
+ (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1)]>;
+
// and extui(x), extui(y) -> extui(and(x,y))
// AND can only clear bits, so if either operand is non-negative (sign bit 0),
// the result also has sign bit 0. Preserves nneg if either operand has it.
@@ -342,6 +369,15 @@ def AndOfExtSI :
// OrIOp
//===----------------------------------------------------------------------===//
+// ori(ori(x, c0), c1) -> ori(x, c0 | c1)
+// Both constant attributes must have exactly the result type; constant-like
+// ops whose attribute type differs from their result type are not combined.
+def OrIOrIConstant :
+ Pat<(Arith_OrIOp:$res
+ (Arith_OrIOp $x, (ConstantLikeMatcher APIntAttr:$c0)),
+ (ConstantLikeMatcher APIntAttr:$c1)),
+ (Arith_OrIOp $x, (Arith_ConstantOp (OrIntAttrs $res, $c0, $c1))),
+ [(Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c0),
+ (Constraint<CPred<"$0.getType() == cast<IntegerAttr>($1).getType()">> $res, $c1)]>;
+
// or extui(x), extui(y) -> extui(or(x,y))
// OR preserves nneg only if both operands are non-negative: or of two values
// with sign bit 0 has sign bit 0.
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index bcdab9ee8e978..87b29a3f3beb7 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -67,6 +67,21 @@ static IntegerAttr mulIntegerAttrs(PatternRewriter &builder, Value res,
return applyToIntegerAttrs(builder, res, lhs, rhs, std::multiplies<APInt>());
}
+// Bitwise-and two integer attributes via applyToIntegerAttrs; `res` is the
+// value of the matched op whose result the new attribute replaces.
+static IntegerAttr andIntegerAttrs(PatternRewriter &builder, Value res,
+ Attribute lhs, Attribute rhs) {
+ return applyToIntegerAttrs(builder, res, lhs, rhs, std::bit_and<APInt>());
+}
+
+// Bitwise-or two integer attributes via applyToIntegerAttrs.
+static IntegerAttr orIntegerAttrs(PatternRewriter &builder, Value res,
+ Attribute lhs, Attribute rhs) {
+ return applyToIntegerAttrs(builder, res, lhs, rhs, std::bit_or<APInt>());
+}
+
+// Bitwise-xor two integer attributes via applyToIntegerAttrs.
+static IntegerAttr xorIntegerAttrs(PatternRewriter &builder, Value res,
+ Attribute lhs, Attribute rhs) {
+ return applyToIntegerAttrs(builder, res, lhs, rhs, std::bit_xor<APInt>());
+}
+
// Merge overflow flags from 2 ops, selecting the most conservative combination.
static IntegerOverflowFlagsAttr
mergeOverflowFlags(IntegerOverflowFlagsAttr val1,
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 02626bef856d3..287f903ab7192 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -579,6 +579,102 @@ func.func @orOfExtUI_nneg_mixed(%arg0: i8, %arg1: i8) -> i64 {
// -----
+// Nested andi with constant operands collapses to one andi: 12 & 10 = 8.
+// CHECK-LABEL: @andOfAndConstant
+// CHECK: %[[cres:.+]] = arith.constant 8 : i32
+// CHECK: %[[and:.+]] = arith.andi %arg0, %[[cres]] : i32
+// CHECK: return %[[and]]
+func.func @andOfAndConstant(%arg0: i32) -> i32 {
+ %c12 = arith.constant 12 : i32
+ %c10 = arith.constant 10 : i32
+ %a1 = arith.andi %arg0, %c12 : i32
+ %a2 = arith.andi %a1, %c10 : i32
+ return %a2 : i32
+}
+
+// -----
+
+// Same combination as for i32 (12 & 10 = 8), but on the index type.
+// CHECK-LABEL: @andOfAndConstantIndex
+// CHECK: %[[cres:.+]] = arith.constant 8 : index
+// CHECK: %[[and:.+]] = arith.andi %arg0, %[[cres]] : index
+// CHECK: return %[[and]]
+func.func @andOfAndConstantIndex(%arg0: index) -> index {
+ %c12 = arith.constant 12 : index
+ %c10 = arith.constant 10 : index
+ %a1 = arith.andi %arg0, %c12 : index
+ %a2 = arith.andi %a1, %c10 : index
+ return %a2 : index
+}
+
+// -----
+
+// Nested ori with constant operands collapses to one ori: 12 | 10 = 14.
+// CHECK-LABEL: @orOfOrConstant
+// CHECK: %[[cres:.+]] = arith.constant 14 : i32
+// CHECK: %[[or:.+]] = arith.ori %arg0, %[[cres]] : i32
+// CHECK: return %[[or]]
+func.func @orOfOrConstant(%arg0: i32) -> i32 {
+ %c12 = arith.constant 12 : i32
+ %c10 = arith.constant 10 : i32
+ %o1 = arith.ori %arg0, %c12 : i32
+ %o2 = arith.ori %o1, %c10 : i32
+ return %o2 : i32
+}
+
+// -----
+
+// Same combination as for i32 (12 | 10 = 14), but on the index type.
+// CHECK-LABEL: @orOfOrConstantIndex
+// CHECK: %[[cres:.+]] = arith.constant 14 : index
+// CHECK: %[[or:.+]] = arith.ori %arg0, %[[cres]] : index
+// CHECK: return %[[or]]
+func.func @orOfOrConstantIndex(%arg0: index) -> index {
+ %c12 = arith.constant 12 : index
+ %c10 = arith.constant 10 : index
+ %o1 = arith.ori %arg0, %c12 : index
+ %o2 = arith.ori %o1, %c10 : index
+ return %o2 : index
+}
+
+// -----
+
+// Negative test case: the test.constant value attributes are i64 while the results are i32, so the nested andi constants must not be combined.
+// CHECK-LABEL: func.func @nested_andi() -> i32 {
+// CHECK: %[[VAL_0:.*]] = "test.constant"() <{value = 2147483647 : i64}> : () -> i32
+// CHECK: %[[VAL_1:.*]] = "test.constant"() <{value = -2147483648 : i64}> : () -> i32
+// CHECK: %[[VAL_2:.*]] = "test.constant"() <{value = 2147483648 : i64}> : () -> i32
+// CHECK: %[[VAL_3:.*]] = arith.andi %[[VAL_0]], %[[VAL_1]] : i32
+// CHECK: %[[VAL_4:.*]] = arith.andi %[[VAL_3]], %[[VAL_2]] : i32
+// CHECK: return %[[VAL_4]] : i32
+// CHECK: }
+func.func @nested_andi() -> (i32) {
+ %0 = "test.constant"() {value = 0x7fffffff} : () -> i32
+ %1 = "test.constant"() {value = -2147483648} : () -> i32
+ %2 = "test.constant"() {value = 0x80000000} : () -> i32
+ %4 = arith.andi %0, %1 : i32
+ %5 = arith.andi %4, %2 : i32
+ return %5 : i32
+}
+
+// -----
+
+// Negative test case: the test.constant value attributes are i64 while the results are i32, so the nested ori constants must not be combined.
+// CHECK-LABEL: func.func @nested_ori() -> i32 {
+// CHECK: %[[VAL_0:.*]] = "test.constant"() <{value = 2147483647 : i64}> : () -> i32
+// CHECK: %[[VAL_1:.*]] = "test.constant"() <{value = -2147483648 : i64}> : () -> i32
+// CHECK: %[[VAL_2:.*]] = "test.constant"() <{value = 2147483648 : i64}> : () -> i32
+// CHECK: %[[VAL_3:.*]] = arith.ori %[[VAL_0]], %[[VAL_1]] : i32
+// CHECK: %[[VAL_4:.*]] = arith.ori %[[VAL_3]], %[[VAL_2]] : i32
+// CHECK: return %[[VAL_4]] : i32
+// CHECK: }
+func.func @nested_ori() -> (i32) {
+ %0 = "test.constant"() {value = 0x7fffffff} : () -> i32
+ %1 = "test.constant"() {value = -2147483648} : () -> i32
+ %2 = "test.constant"() {value = 0x80000000} : () -> i32
+ %4 = arith.ori %0, %1 : i32
+ %5 = arith.ori %4, %2 : i32
+ return %5 : i32
+}
+
+// -----
+
// CHECK-LABEL: @indexCastOfSignExtend
// CHECK: %[[res:.+]] = arith.index_cast %arg0 : i8 to index
// CHECK: return %[[res]]
@@ -2130,6 +2226,54 @@ func.func @xorOfExtUI_nneg_mixed(%arg0: i8, %arg1: i8) -> i64 {
// -----
+// Nested xori with constant operands collapses to one xori: 12 ^ 10 = 6.
+// CHECK-LABEL: @xorOfXorConstant
+// CHECK: %[[cres:.+]] = arith.constant 6 : i32
+// CHECK: %[[xor:.+]] = arith.xori %arg0, %[[cres]] : i32
+// CHECK: return %[[xor]]
+func.func @xorOfXorConstant(%arg0: i32) -> i32 {
+ %c12 = arith.constant 12 : i32
+ %c10 = arith.constant 10 : i32
+ %x1 = arith.xori %arg0, %c12 : i32
+ %x2 = arith.xori %x1, %c10 : i32
+ return %x2 : i32
+}
+
+// -----
+
+// Same combination as for i32 (12 ^ 10 = 6), but on the index type.
+// CHECK-LABEL: @xorOfXorConstantIndex
+// CHECK: %[[cres:.+]] = arith.constant 6 : index
+// CHECK: %[[xor:.+]] = arith.xori %arg0, %[[cres]] : index
+// CHECK: return %[[xor]]
+func.func @xorOfXorConstantIndex(%arg0: index) -> index {
+ %c12 = arith.constant 12 : index
+ %c10 = arith.constant 10 : index
+ %x1 = arith.xori %arg0, %c12 : index
+ %x2 = arith.xori %x1, %c10 : index
+ return %x2 : index
+}
+
+// -----
+
+// Negative test case: the test.constant value attributes are i64 while the results are i32, so the nested xori constants must not be combined.
+// CHECK-LABEL: func.func @nested_xori() -> i32 {
+// CHECK: %[[VAL_0:.*]] = "test.constant"() <{value = 2147483647 : i64}> : () -> i32
+// CHECK: %[[VAL_1:.*]] = "test.constant"() <{value = -2147483648 : i64}> : () -> i32
+// CHECK: %[[VAL_2:.*]] = "test.constant"() <{value = 2147483648 : i64}> : () -> i32
+// CHECK: %[[VAL_3:.*]] = arith.xori %[[VAL_0]], %[[VAL_1]] : i32
+// CHECK: %[[VAL_4:.*]] = arith.xori %[[VAL_3]], %[[VAL_2]] : i32
+// CHECK: return %[[VAL_4]] : i32
+// CHECK: }
+func.func @nested_xori() -> (i32) {
+ %0 = "test.constant"() {value = 0x7fffffff} : () -> i32
+ %1 = "test.constant"() {value = -2147483648} : () -> i32
+ %2 = "test.constant"() {value = 0x80000000} : () -> i32
+ %4 = arith.xori %0, %1 : i32
+ %5 = arith.xori %4, %2 : i32
+ return %5 : i32
+}
+
+// -----
+
// CHECK-LABEL: @bitcastSameType(
// CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
func.func @bitcastSameType(%arg : f32) -> f32 {
>From 8dc3f7088a45ac2a6446991efc000c071cda0110 Mon Sep 17 00:00:00 2001
From: MaxGraey <maxgraey at gmail.com>
Date: Mon, 4 May 2026 09:11:25 +0300
Subject: [PATCH 2/2] fix build: register new patterns
---
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 87b29a3f3beb7..87c877de24981 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1105,7 +1105,7 @@ OpFoldResult arith::XOrIOp::fold(FoldAdaptor adaptor) {
void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
+ patterns.add<XOrIXOrIConstant, XOrINotCmpI, XOrIOfExtUI, XOrIOfExtSI>(context);
}
//===----------------------------------------------------------------------===//
@@ -1828,7 +1828,7 @@ LogicalResult arith::ScalingTruncFOp::verify() {
void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<AndOfExtUI, AndOfExtSI>(context);
+ patterns.add<AndIAndIConstant, AndOfExtUI, AndOfExtSI>(context);
}
//===----------------------------------------------------------------------===//
@@ -1837,7 +1837,7 @@ void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
- patterns.add<OrOfExtUI, OrOfExtSI>(context);
+ patterns.add<OrIOrIConstant, OrOfExtUI, OrOfExtSI>(context);
}
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list