[Mlir-commits] [mlir] 9c8abbf - [mlir][Arithmetic] Support commutative canonicalization for continuous XOrIOp.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 26 00:01:36 PDT 2022


Author: jacquesguan
Date: 2022-09-26T15:01:24+08:00
New Revision: 9c8abbfa0aabb390e7ae79c5309e499fa43899e4

URL: https://github.com/llvm/llvm-project/commit/9c8abbfa0aabb390e7ae79c5309e499fa43899e4
DIFF: https://github.com/llvm/llvm-project/commit/9c8abbfa0aabb390e7ae79c5309e499fa43899e4.diff

LOG: [mlir][Arithmetic] Support commutative canonicalization for continuous XOrIOp.

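This patch extends the xor(xor(x, a), a) -> x folding from D116383 so that it also matches when the operands of either xor are commuted (e.g. xor(a, xor(a, x)) -> x).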

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D134258

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index bd693478a3533..1891ce813919a 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -627,9 +627,21 @@ OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
   if (getLhs() == getRhs())
     return Builder(getContext()).getZeroAttr(getType());
   /// xor(xor(x, a), a) -> x
-  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
+  /// xor(xor(a, x), a) -> x
+  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>()) {
     if (prev.getRhs() == getRhs())
       return prev.getLhs();
+    if (prev.getLhs() == getRhs())
+      return prev.getRhs();
+  }
+  /// xor(a, xor(x, a)) -> x
+  /// xor(a, xor(a, x)) -> x
+  if (arith::XOrIOp prev = getRhs().getDefiningOp<arith::XOrIOp>()) {
+    if (prev.getRhs() == getLhs())
+      return prev.getLhs();
+    if (prev.getLhs() == getLhs())
+      return prev.getRhs();
+  }
 
   return constFoldBinaryOp<IntegerAttr>(
       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index 649da010cd359..632e7af4a26a3 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -1585,3 +1585,51 @@ func.func @test_andi_not_fold_lhs(%arg0 : index) -> index {
     %2 = arith.andi %1, %arg0 : index
     return %2 : index
 }
+
+// -----
+/// xor(xor(x, a), a) -> x
+
+// CHECK-LABEL: @xorxor0(
+//       CHECK-NOT: xori
+//       CHECK:   return %arg0
+func.func @xorxor0(%a : i32, %b : i32) -> i32 {
+  %c = arith.xori %a, %b : i32
+  %res = arith.xori %c, %b : i32
+  return %res : i32
+}
+
+// -----
+/// xor(xor(a, x), a) -> x
+
+// CHECK-LABEL: @xorxor1(
+//       CHECK-NOT: xori
+//       CHECK:   return %arg0
+func.func @xorxor1(%a : i32, %b : i32) -> i32 {
+  %c = arith.xori %b, %a : i32
+  %res = arith.xori %c, %b : i32
+  return %res : i32
+}
+
+// -----
+/// xor(a, xor(x, a)) -> x
+
+// CHECK-LABEL: @xorxor2(
+//       CHECK-NOT: xori
+//       CHECK:   return %arg0
+func.func @xorxor2(%a : i32, %b : i32) -> i32 {
+  %c = arith.xori %a, %b : i32
+  %res = arith.xori %b, %c : i32
+  return %res : i32
+}
+
+// -----
+/// xor(a, xor(a, x)) -> x
+
+// CHECK-LABEL: @xorxor3(
+//       CHECK-NOT: xori
+//       CHECK:   return %arg0
+func.func @xorxor3(%a : i32, %b : i32) -> i32 {
+  %c = arith.xori %b, %a : i32
+  %res = arith.xori %b, %c : i32
+  return %res : i32
+}


        


More information about the Mlir-commits mailing list