[Mlir-commits] [mlir] 2b2889b - Add index::CmpOp canonicalization.

Weiwei Chen llvmlistbot at llvm.org
Mon Aug 14 19:59:14 PDT 2023


Author: Weiwei Chen
Date: 2023-08-14T22:56:28-04:00
New Revision: 2b2889b723b5d8ebd205251900ba327cb877c88f

URL: https://github.com/llvm/llvm-project/commit/2b2889b723b5d8ebd205251900ba327cb877c88f
DIFF: https://github.com/llvm/llvm-project/commit/2b2889b723b5d8ebd205251900ba327cb877c88f.diff

LOG: Add index::CmpOp canonicalization.
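Add a canonicalization pattern for index::CmpOp that folds a subtraction compared against zero: `(x - y) cmp 0` becomes `x cmp y`, and `0 cmp (x - y)` becomes `y cmp x`.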

Differential Revision: https://reviews.llvm.org/D157903

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Index/IR/IndexOps.h
    mlir/include/mlir/Dialect/Index/IR/IndexOps.td
    mlir/lib/Dialect/Index/IR/IndexOps.cpp
    mlir/test/Dialect/Index/index-canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
index d9daf892d27e4d..8984f979c1fee9 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
@@ -23,10 +23,15 @@
 // Forward Declarations
 //===----------------------------------------------------------------------===//
 
+namespace mlir {
+// Referenced by the ODS-generated `CmpOp::canonicalize` declaration.
+class PatternRewriter;
+} // namespace mlir
+
 namespace mlir::index {
 enum class IndexCmpPredicate : uint32_t;
 class IndexCmpPredicateAttr;
 } // namespace mlir::index
 
 //===----------------------------------------------------------------------===//
 // ODS-Generated Declarations

diff  --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index 5cb179dd70fd2f..e45bedc8206f80 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -542,6 +542,7 @@ def Index_CmpOp : IndexOp<"cmp"> {
   let results = (outs I1:$result);
   let assemblyFormat = "`` $pred `(` $lhs `,` $rhs `)` attr-dict";
   let hasFolder = 1;
+  let hasCanonicalizeMethod = 1;
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index d315d48ca2d2af..b6d802876c15ed 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -549,6 +550,54 @@ OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
   return {};
 }
 
+/// Canonicalize a comparison between a subtraction result and zero:
+///   `(x - y) cmp 0`  ->  `x cmp y`
+///   `0 cmp (x - y)`  ->  `y cmp x`
+///
+/// `index.sub` wraps on overflow, so the rewrite is only applied for
+/// predicates where it is intended to be equivalent:
+///   - `eq` / `ne`: exact under wrapping arithmetic (x - y == 0 iff x == y).
+///   - signed predicates: exact only when `x - y` does not overflow.
+///     NOTE(review): with x = signed minimum and y = 1, `(x - y) slt 0` is
+///     false but `x slt y` is true; consider restricting these as well.
+/// Unsigned predicates are never rewritten: no value is `ult 0`, so
+/// `(x - y) ult 0` is always false while `x ult y` may be true.
+LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
+  switch (op.getPred()) {
+  case IndexCmpPredicate::ULT:
+  case IndexCmpPredicate::ULE:
+  case IndexCmpPredicate::UGT:
+  case IndexCmpPredicate::UGE:
+    return rewriter.notifyMatchFailure(
+        op.getLoc(), "unsigned predicates are not preserved by this rewrite");
+  default:
+    break;
+  }
+
+  IntegerAttr cmpRhs;
+  IntegerAttr cmpLhs;
+  bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) &&
+                   cmpRhs.getValue().isZero();
+  bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) &&
+                   cmpLhs.getValue().isZero();
+  if (!rhsIsZero && !lhsIsZero)
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "cmp is not comparing something with 0");
+
+  // The operand that is not the zero constant must be produced by a sub.
+  Value nonZero = rhsIsZero ? op.getLhs() : op.getRhs();
+  auto subOp = nonZero.getDefiningOp<SubOp>();
+  if (!subOp)
+    return rewriter.notifyMatchFailure(
+        op.getLoc(), "non-zero operand is not a result of subtraction");
+
+  // `(x - y) cmp 0` keeps the operand order; `0 cmp (x - y)` swaps it.
+  Value newLhs = rhsIsZero ? subOp.getLhs() : subOp.getRhs();
+  Value newRhs = rhsIsZero ? subOp.getRhs() : subOp.getLhs();
+  rewriter.replaceOpWithNewOp<CmpOp>(op, op.getPred(), newLhs, newRhs);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 2a1fb39f1d483b..67308ffbe55ac6 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -473,7 +473,7 @@ func.func @xor() -> index {
 }
 
 // CHECK-LABEL: @cmp
-func.func @cmp() -> (i1, i1, i1, i1) {
+func.func @cmp(%arg0: index) -> (i1, i1, i1, i1, i1, i1) {
   %a = index.constant 0
   %b = index.constant -1
   %c = index.constant -2
@@ -484,10 +484,19 @@ func.func @cmp() -> (i1, i1, i1, i1) {
   %2 = index.cmp ne(%d, %a)
   %3 = index.cmp sgt(%b, %a)
 
+  %4 = index.sub %a, %arg0
+  %5 = index.cmp sgt(%4, %a)
+
+  %6 = index.sub %a, %arg0
+  %7 = index.cmp sgt(%a, %6)
+
   // CHECK-DAG: %[[TRUE:.*]] = index.bool.constant true
   // CHECK-DAG: %[[FALSE:.*]] = index.bool.constant false
+  // CHECK-DAG: [[IDX0:%.*]] = index.constant 0
+  // CHECK-DAG: [[V4:%.*]] = index.cmp sgt([[IDX0]], %arg0)
+  // CHECK-DAG: [[V5:%.*]] = index.cmp sgt(%arg0, [[IDX0]])
-  // CHECK: return %[[FALSE]], %[[TRUE]], %[[TRUE]], %[[FALSE]]
-  return %0, %1, %2, %3 : i1, i1, i1, i1
+  // CHECK: return %[[FALSE]], %[[TRUE]], %[[TRUE]], %[[FALSE]], [[V4]], [[V5]]
+  return %0, %1, %2, %3, %5, %7 : i1, i1, i1, i1, i1, i1
 }
 
 // CHECK-LABEL: @cmp_nofold


        


More information about the Mlir-commits mailing list