[Mlir-commits] [mlir] 93297e4 - [MLIR] Replace a not of a comparison with appropriate comparison

William S. Moses llvmlistbot at llvm.org
Tue May 4 08:32:31 PDT 2021


Author: William S. Moses
Date: 2021-05-04T11:23:29-04:00
New Revision: 93297e4bacd99f8c6711c136a4000c8526a7ea31

URL: https://github.com/llvm/llvm-project/commit/93297e4bacd99f8c6711c136a4000c8526a7ea31
DIFF: https://github.com/llvm/llvm-project/commit/93297e4bacd99f8c6711c136a4000c8526a7ea31.diff

LOG: [MLIR] Replace a not of a comparison with appropriate comparison

Differential Revision: https://reviews.llvm.org/D101710

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 6152b6b4b41a6..8ee34ebe4c28d 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2292,6 +2292,7 @@ def XOrOp : IntBinaryOp<"xor", [Commutative]> {
     ```
   }];
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 51b832805cca8..5b10e12c248eb 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3011,6 +3011,80 @@ OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
                                         [](APInt a, APInt b) { return a ^ b; });
 }
 
+/// Returns the predicate whose result is the logical negation of `pred`, so
+/// that `cmpi(negated, a, b) == !cmpi(pred, a, b)` for all operands `a`, `b`.
+/// Equality predicates swap with each other, and each ordering predicate maps
+/// to its complement with the same signedness (e.g. slt <-> sge, ule <-> ugt).
+/// Used by the NotICmp canonicalization pattern below.
+static CmpIPredicate negateCmpIPredicate(CmpIPredicate pred) {
+  switch (pred) {
+  case CmpIPredicate::eq:
+    return CmpIPredicate::ne;
+  case CmpIPredicate::ne:
+    return CmpIPredicate::eq;
+
+  case CmpIPredicate::slt:
+    return CmpIPredicate::sge;
+  case CmpIPredicate::sle:
+    return CmpIPredicate::sgt;
+  case CmpIPredicate::sgt:
+    return CmpIPredicate::sle;
+  case CmpIPredicate::sge:
+    return CmpIPredicate::slt;
+
+  case CmpIPredicate::ult:
+    return CmpIPredicate::uge;
+  case CmpIPredicate::ule:
+    return CmpIPredicate::ugt;
+  case CmpIPredicate::ugt:
+    return CmpIPredicate::ule;
+  case CmpIPredicate::uge:
+    return CmpIPredicate::ult;
+  }
+  llvm_unreachable("unknown cmpi predicate");
+}
+
+namespace {
+/// Replace a logical not of an integer comparison with a single comparison
+/// using the negated predicate, for example:
+///   not(cmpi eq, A, B)  => cmpi ne, A, B
+///   not(cmpi ult, A, B) => cmpi uge, A, B
+/// A logical not is expressed as `xor val, 1` on the boolean cmpi result.
+///
+/// Only operand #1 is checked for the constant. The `xor 1, val` spelling is
+/// relied upon to have been normalized to `xor val, 1` first: XOrOp is
+/// Commutative, and folding moves constants to the trailing operands.
+struct NotICmp : public OpRewritePattern<XOrOp> {
+  using OpRewritePattern<XOrOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(XOrOp op,
+                                PatternRewriter &rewriter) const override {
+    // The xor mask must be the constant 1 (i.e. `true` for an i1 value).
+    APInt constValue;
+    if (!matchPattern(op.getOperand(1), m_ConstantInt(&constValue)))
+      return failure();
+    if (constValue != 1)
+      return failure();
+
+    // The value being negated must be produced directly by a cmpi.
+    auto prev = op.getOperand(0).getDefiningOp<CmpIOp>();
+    if (!prev)
+      return failure();
+
+    // Create a new cmpi instead of updating `prev` in place: the original
+    // comparison may have other users that still need the un-negated result.
+    rewriter.replaceOpWithNewOp<CmpIOp>(
+        op, negateCmpIPredicate(prev.predicate()), prev.lhs(), prev.rhs());
+    return success();
+  }
+};
+} // namespace
+
+void XOrOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns,
+                                        MLIRContext *ctx) {
+  patterns.insert<NotICmp>(ctx); // not(cmpi P, a, b) => cmpi !P, a, b
+}
+
 //===----------------------------------------------------------------------===//
 // ZeroExtendIOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 15dbde7d2757f..f67ba447cf3d1 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -538,3 +538,113 @@ func @tripleSubSub3(%arg0: index) -> index {
   %add2 = subi %add1, %c42 : index
   return %add2 : index
 }
+
+// CHECK-LABEL: @notCmpEQ
+//       CHECK:   %[[RES:.+]] = cmpi ne, %arg0, %arg1 : i8
+//       CHECK:   return %[[RES]]
+func @notCmpEQ(%arg0: i8, %arg1: i8) -> i1 {
+  %one = constant true
+  %pred = cmpi "eq", %arg0, %arg1 : i8
+  %negated = xor %pred, %one : i1
+  return %negated : i1
+}
+
+// CHECK-LABEL: @notCmpEQ2
+//       CHECK:   %[[RES:.+]] = cmpi ne, %arg0, %arg1 : i8
+//       CHECK:   return %[[RES]]
+func @notCmpEQ2(%arg0: i8, %arg1: i8) -> i1 {
+  %one = constant true
+  %pred = cmpi "eq", %arg0, %arg1 : i8
+  %negated = xor %one, %pred : i1
+  return %negated : i1
+}
+
+// CHECK-LABEL: @notCmpNE
+//       CHECK:   %[[RES:.+]] = cmpi eq, %arg0, %arg1 : i8
+//       CHECK:   return %[[RES]]
+func @notCmpNE(%arg0: i8, %arg1: i8) -> i1 {
+  %one = constant true
+  %pred = cmpi "ne", %arg0, %arg1 : i8
+  %negated = xor %pred, %one : i1
+  return %negated : i1
+}
+
+// CHECK-LABEL: @notCmpSLT
+//       CHECK:   %[[RES:.+]] = cmpi sge, %arg0, %arg1 : i8
+//       CHECK:   return %[[RES]]
+func @notCmpSLT(%arg0: i8, %arg1: i8) -> i1 {
+  %one = constant true
+  %pred = cmpi "slt", %arg0, %arg1 : i8
+  %negated = xor %pred, %one : i1
+  return %negated : i1
+}
+
+// CHECK-LABEL: @notCmpSLE
+//       CHECK:   %[[RES:.+]] = cmpi sgt, %arg0, %arg1 : i8
+//       CHECK:   return %[[RES]]
+func @notCmpSLE(%arg0: i8, %arg1: i8) -> i1 {
+  %one = constant true
+  %pred = cmpi "sle", %arg0, %arg1 : i8
+  %negated = xor %pred, %one : i1
+  return %negated : i1
+}
+
+// CHECK-LABEL: @notCmpSGT
+//       CHECK:   %[[RES:.+]] = cmpi sle, %arg0, %arg1 : i8
+//       CHECK:   return %[[RES]]
+func @notCmpSGT(%arg0: i8, %arg1: i8) -> i1 {
+  %one = constant true
+  %pred = cmpi "sgt", %arg0, %arg1 : i8
+  %negated = xor %pred, %one : i1
+  return %negated : i1
+}
+
+// CHECK-LABEL: @notCmpSGE
+//       CHECK:   %[[RES:.+]] = cmpi slt, %arg0, %arg1 : i8
+//       CHECK:   return %[[RES]]
+func @notCmpSGE(%arg0: i8, %arg1: i8) -> i1 {
+  %one = constant true
+  %pred = cmpi "sge", %arg0, %arg1 : i8
+  %negated = xor %pred, %one : i1
+  return %negated : i1
+}
+
+// CHECK-LABEL: @notCmpULT
+//       CHECK:   %[[RES:.+]] = cmpi uge, %arg0, %arg1 : i8
+//       CHECK:   return %[[RES]]
+func @notCmpULT(%arg0: i8, %arg1: i8) -> i1 {
+  %one = constant true
+  %pred = cmpi "ult", %arg0, %arg1 : i8
+  %negated = xor %pred, %one : i1
+  return %negated : i1
+}
+
+// CHECK-LABEL: @notCmpULE
+//       CHECK:   %[[RES:.+]] = cmpi ugt, %arg0, %arg1 : i8
+//       CHECK:   return %[[RES]]
+func @notCmpULE(%arg0: i8, %arg1: i8) -> i1 {
+  %one = constant true
+  %pred = cmpi "ule", %arg0, %arg1 : i8
+  %negated = xor %pred, %one : i1
+  return %negated : i1
+}
+
+// CHECK-LABEL: @notCmpUGT
+//       CHECK:   %[[RES:.+]] = cmpi ule, %arg0, %arg1 : i8
+//       CHECK:   return %[[RES]]
+func @notCmpUGT(%arg0: i8, %arg1: i8) -> i1 {
+  %one = constant true
+  %pred = cmpi "ugt", %arg0, %arg1 : i8
+  %negated = xor %pred, %one : i1
+  return %negated : i1
+}
+
+// CHECK-LABEL: @notCmpUGE
+//       CHECK:   %[[RES:.+]] = cmpi ult, %arg0, %arg1 : i8
+//       CHECK:   return %[[RES]]
+func @notCmpUGE(%arg0: i8, %arg1: i8) -> i1 {
+  %one = constant true
+  %pred = cmpi "uge", %arg0, %arg1 : i8
+  %negated = xor %pred, %one : i1
+  return %negated : i1
+}


        


More information about the Mlir-commits mailing list