[Mlir-commits] [mlir] 9c8abbf - [mlir][Arithmetic] Support commutative canonicalization for continuous XOrIOp.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Sep 26 00:01:36 PDT 2022
Author: jacquesguan
Date: 2022-09-26T15:01:24+08:00
New Revision: 9c8abbfa0aabb390e7ae79c5309e499fa43899e4
URL: https://github.com/llvm/llvm-project/commit/9c8abbfa0aabb390e7ae79c5309e499fa43899e4
DIFF: https://github.com/llvm/llvm-project/commit/9c8abbfa0aabb390e7ae79c5309e499fa43899e4.diff
LOG: [mlir][Arithmetic] Support commutative canonicalization for continuous XOrIOp.
This patch extends the XOrIOp fold from D116383 (xor(xor(x, a), a) -> x) so that it also applies when the operands of either xor are commuted.
Reviewed By: Mogball
Differential Revision: https://reviews.llvm.org/D134258
Added:
Modified:
mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
mlir/test/Dialect/Arithmetic/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index bd693478a3533..1891ce813919a 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -627,9 +627,21 @@ OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
if (getLhs() == getRhs())
return Builder(getContext()).getZeroAttr(getType());
/// xor(xor(x, a), a) -> x
- if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
+ /// xor(xor(a, x), a) -> x
+ if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
if (prev.getRhs() == getRhs())
return prev.getLhs();
+ if (prev.getLhs() == getRhs())
+ return prev.getRhs();
+ }
+ /// xor(a, xor(x, a)) -> x
+ /// xor(a, xor(a, x)) -> x
+ if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
+ if (prev.getRhs() == getLhs())
+ return prev.getLhs();
+ if (prev.getLhs() == getLhs())
+ return prev.getRhs();
+ }
return constFoldBinaryOp<IntegerAttr>(
operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 649da010cd359..632e7af4a26a3 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -1585,3 +1585,51 @@ func.func @test_andi_not_fold_lhs(%arg0 : index) -> index {
%2 = arith.andi %1, %arg0 : index
return %2 : index
}
+
+// -----
+/// xor(xor(x, a), a) -> x
+
+// CHECK-LABEL: @xorxor0(
+// CHECK-NOT: xori
+// CHECK: return %arg0
+func.func @xorxor0(%a : i32, %b : i32) -> i32 {
+ %c = arith.xori %a, %b : i32
+ %res = arith.xori %c, %b : i32
+ return %res : i32
+}
+
+// -----
+/// xor(xor(a, x), a) -> x
+
+// CHECK-LABEL: @xorxor1(
+// CHECK-NOT: xori
+// CHECK: return %arg0
+func.func @xorxor1(%a : i32, %b : i32) -> i32 {
+ %c = arith.xori %b, %a : i32
+ %res = arith.xori %c, %b : i32
+ return %res : i32
+}
+
+// -----
+/// xor(a, xor(x, a)) -> x
+
+// CHECK-LABEL: @xorxor2(
+// CHECK-NOT: xori
+// CHECK: return %arg0
+func.func @xorxor2(%a : i32, %b : i32) -> i32 {
+ %c = arith.xori %a, %b : i32
+ %res = arith.xori %b, %c : i32
+ return %res : i32
+}
+
+// -----
+/// xor(a, xor(a, x)) -> x
+
+// CHECK-LABEL: @xorxor3(
+// CHECK-NOT: xori
+// CHECK: return %arg0
+func.func @xorxor3(%a : i32, %b : i32) -> i32 {
+ %c = arith.xori %b, %a : i32
+ %res = arith.xori %b, %c : i32
+ return %res : i32
+}
More information about the Mlir-commits
mailing list