[Mlir-commits] [mlir] 34646a2 - [MLIR][Arith] Fold repeated xor and trunc

William S. Moses llvmlistbot at llvm.org
Fri Jan 7 00:36:16 PST 2022


Author: William S. Moses
Date: 2022-01-07T03:36:10-05:00
New Revision: 34646a2f7ee1564b0f6ff706b32d7206e39aac9f

URL: https://github.com/llvm/llvm-project/commit/34646a2f7ee1564b0f6ff706b32d7206e39aac9f
DIFF: https://github.com/llvm/llvm-project/commit/34646a2f7ee1564b0f6ff706b32d7206e39aac9f.diff

LOG: [MLIR][Arith] Fold repeated xor and trunc

This patch adds two folds: one for a repeated xor (e.g. xor(xor(x, a), a) -> x) and one for a repeated trunc (e.g. trunc(trunc(x)) -> trunc(x)).

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D116383

Added: 
    

Modified: 
    mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
    mlir/test/Dialect/Arithmetic/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
index 1536eeaf48af8..69acc19fb9e4f 100644
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -557,6 +557,10 @@ OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
   /// xor(x, x) -> 0
   if (getLhs() == getRhs())
     return Builder(getContext()).getZeroAttr(getType());
+  /// xor(xor(x, a), a) -> x
+  if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
+    if (prev.getRhs() == getRhs())
+      return prev.getLhs();
 
   return constFoldBinaryOp<IntegerAttr>(
       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
@@ -859,13 +863,19 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "unary operation takes one operand");
+
   // trunci(zexti(a)) -> a
   // trunci(sexti(a)) -> a
   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
     return getOperand().getDefiningOp()->getOperand(0);
 
-  assert(operands.size() == 1 && "unary operation takes one operand");
+  // trunci(trunci(a)) -> trunci(a))
+  if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
+    setOperand(getOperand().getDefiningOp()->getOperand(0));
+    return getResult();
+  }
 
   if (!operands[0])
     return {};

diff  --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
index e4cbe21710ba8..8a17fa9d7f661 100644
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -172,6 +172,15 @@ func @truncConstant(%arg0: i8) -> i16 {
   return %tr : i16
 }
 
+// CHECK-LABEL: @truncTrunc
+//       CHECK:   %[[cres:.+]] = arith.trunci %arg0 : i64 to i8
+//       CHECK:   return %[[cres]]
+func @truncTrunc(%arg0: i64) -> i8 {
+  %tr1 = arith.trunci %arg0 : i64 to i32
+  %tr2 = arith.trunci %tr1 : i32 to i8
+  return %tr2 : i8
+}
+
 // CHECK-LABEL: @truncFPConstant
 //       CHECK:   %[[cres:.+]] = arith.constant 1.000000e+00 : bf16
 //       CHECK:   return %[[cres]]
@@ -427,6 +436,18 @@ func @notCmpUGE(%arg0: i8, %arg1: i8) -> i1 {
 
 // -----
 
+// CHECK-LABEL: @xorxor(
+//       CHECK-NOT: xori
+//       CHECK:   return %arg0
+func @xorxor(%cmp : i1) -> i1 {
+  %true = arith.constant true
+  %ncmp = arith.xori %cmp, %true : i1
+  %nncmp = arith.xori %ncmp, %true : i1
+  return %nncmp : i1
+}
+
+// -----
+
 // CHECK-LABEL: @bitcastSameType(
 // CHECK-SAME: %[[ARG:[a-zA-Z0-9_]*]]
 func @bitcastSameType(%arg : f32) -> f32 {


        


More information about the Mlir-commits mailing list