[Mlir-commits] [mlir] 2b2889b - Add index::CmpOp canonicalization.
Weiwei Chen
llvmlistbot at llvm.org
Mon Aug 14 19:59:14 PDT 2023
Author: Weiwei Chen
Date: 2023-08-14T22:56:28-04:00
New Revision: 2b2889b723b5d8ebd205251900ba327cb877c88f
URL: https://github.com/llvm/llvm-project/commit/2b2889b723b5d8ebd205251900ba327cb877c88f
DIFF: https://github.com/llvm/llvm-project/commit/2b2889b723b5d8ebd205251900ba327cb877c88f.diff
LOG: Add index::CmpOp canonicalization.
Add canonicalization pattern for index::CmpOp
Differential Revision: https://reviews.llvm.org/D157903
Added:
Modified:
mlir/include/mlir/Dialect/Index/IR/IndexOps.h
mlir/include/mlir/Dialect/Index/IR/IndexOps.td
mlir/lib/Dialect/Index/IR/IndexOps.cpp
mlir/test/Dialect/Index/index-canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
index d9daf892d27e4d..8984f979c1fee9 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.h
@@ -23,10 +23,13 @@
// Forward Declarations
//===----------------------------------------------------------------------===//
-namespace mlir::index {
+namespace mlir {
+class PatternRewriter;
+namespace index {
enum class IndexCmpPredicate : uint32_t;
class IndexCmpPredicateAttr;
-} // namespace mlir::index
+} // namespace index
+} // namespace mlir
//===----------------------------------------------------------------------===//
// ODS-Generated Declarations
diff --git a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
index 5cb179dd70fd2f..e45bedc8206f80 100644
--- a/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
+++ b/mlir/include/mlir/Dialect/Index/IR/IndexOps.td
@@ -542,6 +542,7 @@ def Index_CmpOp : IndexOp<"cmp"> {
let results = (outs I1:$result);
let assemblyFormat = "`` $pred `(` $lhs `,` $rhs `)` attr-dict";
let hasFolder = 1;
+ let hasCanonicalizeMethod = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index d315d48ca2d2af..b6d802876c15ed 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -549,6 +550,37 @@ OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
return {};
}
+/// Canonicalize a comparison between a subtraction and zero into a direct
+/// comparison of the subtraction's operands:
+///   `x - y cmp 0` to `x cmp y`
+///   `0 cmp x - y` to `y cmp x`
+///
+/// Returns success after replacing `op` with the new comparison, or a match
+/// failure (leaving the IR untouched) when neither operand is a constant zero,
+/// when the non-zero operand is not produced by an `index.sub`, or when the
+/// predicate is an unsigned ordered one.
+///
+/// The unsigned ordered predicates are excluded because the rewrite is not an
+/// equivalence for them: with x = 1 and y = 2, `x - y ugt 0` is true (the
+/// difference wraps around to a non-zero value) whereas `x ugt y` is false;
+/// likewise `x - y ult 0` is always false whereas `x ult y` is not.
+///
+/// NOTE(review): for the signed ordered predicates the two forms agree only
+/// when `x - y` does not overflow. Confirm this is acceptable for `index`
+/// arithmetic, or restrict the rewrite further to `eq`/`ne`.
+LogicalResult CmpOp::canonicalize(CmpOp op, PatternRewriter &rewriter) {
+  switch (op.getPred()) {
+  case index::IndexCmpPredicate::ULT:
+  case index::IndexCmpPredicate::ULE:
+  case index::IndexCmpPredicate::UGT:
+  case index::IndexCmpPredicate::UGE:
+    return rewriter.notifyMatchFailure(
+        op.getLoc(), "rewrite is not valid for unsigned ordered predicates");
+  default:
+    break;
+  }
+
+  IntegerAttr cmpRhs;
+  IntegerAttr cmpLhs;
+
+  bool rhsIsZero = matchPattern(op.getRhs(), m_Constant(&cmpRhs)) &&
+                   cmpRhs.getValue().isZero();
+  bool lhsIsZero = matchPattern(op.getLhs(), m_Constant(&cmpLhs)) &&
+                   cmpLhs.getValue().isZero();
+  if (!rhsIsZero && !lhsIsZero)
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "cmp is not comparing something with 0");
+
+  // The subtraction must feed the operand opposite to the zero constant.
+  SubOp subOp = rhsIsZero ? op.getLhs().getDefiningOp<index::SubOp>()
+                          : op.getRhs().getDefiningOp<index::SubOp>();
+  if (!subOp)
+    return rewriter.notifyMatchFailure(
+        op.getLoc(), "non-zero operand is not a result of subtraction");
+
+  // `x - y cmp 0` keeps the operand order; `0 cmp x - y` swaps it so that the
+  // predicate itself can be reused unchanged.
+  index::CmpOp newCmp;
+  if (rhsIsZero)
+    newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
+                                           subOp.getLhs(), subOp.getRhs());
+  else
+    newCmp = rewriter.create<index::CmpOp>(op.getLoc(), op.getPred(),
+                                           subOp.getRhs(), subOp.getLhs());
+  rewriter.replaceOp(op, newCmp);
+  return success();
+}
+
+
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index 2a1fb39f1d483b..67308ffbe55ac6 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -473,7 +473,7 @@ func.func @xor() -> index {
}
// CHECK-LABEL: @cmp
-func.func @cmp() -> (i1, i1, i1, i1) {
+func.func @cmp(%arg0: index) -> (i1, i1, i1, i1, i1, i1) {
%a = index.constant 0
%b = index.constant -1
%c = index.constant -2
@@ -484,10 +484,19 @@ func.func @cmp() -> (i1, i1, i1, i1) {
%2 = index.cmp ne(%d, %a)
%3 = index.cmp sgt(%b, %a)
+ %4 = index.sub %a, %arg0
+ %5 = index.cmp sgt(%4, %a)
+
+ %6 = index.sub %a, %arg0
+ %7 = index.cmp sgt(%a, %6)
+
// CHECK-DAG: %[[TRUE:.*]] = index.bool.constant true
// CHECK-DAG: %[[FALSE:.*]] = index.bool.constant false
+ // CHECK-DAG: [[IDX0:%.*]] = index.constant 0
+ // CHECK-DAG: [[V4:%.*]] = index.cmp sgt([[IDX0]], %arg0)
+ // CHECK-DAG: [[V5:%.*]] = index.cmp sgt(%arg0, [[IDX0]])
// CHECK: return %[[FALSE]], %[[TRUE]], %[[TRUE]], %[[FALSE]]
- return %0, %1, %2, %3 : i1, i1, i1, i1
+ return %0, %1, %2, %3, %5, %7 : i1, i1, i1, i1, i1, i1
}
// CHECK-LABEL: @cmp_nofold
More information about the Mlir-commits
mailing list