[Mlir-commits] [mlir] cb778c3 - [mlir][std] Fold comparisons when the operands are equal
Stephan Herhut
llvmlistbot at llvm.org
Fri Nov 20 04:26:56 PST 2020
Author: Stephan Herhut
Date: 2020-11-20T13:26:41+01:00
New Revision: cb778c34237c384821b7bc961f15d139a10b0ca7
URL: https://github.com/llvm/llvm-project/commit/cb778c34237c384821b7bc961f15d139a10b0ca7
DIFF: https://github.com/llvm/llvm-project/commit/cb778c34237c384821b7bc961f15d139a10b0ca7.diff
LOG: [mlir][std] Fold comparisons when the operands are equal
When both operands are the same SSA value, a comparison can be decided statically from its predicate alone.
Differential Revision: https://reviews.llvm.org/D91856
Added:
Modified:
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/test/Dialect/Standard/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 342d73273dd6..6e755daa2669 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -916,17 +916,41 @@ bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
llvm_unreachable("unknown comparison predicate");
}
+// Decides `x <predicate> x`, i.e. a comparison whose two operands are the same
+// value. Predicates that admit equality (eq and the non-strict orderings) hold;
+// predicates that exclude equality (ne and the strict orderings) do not. The
+// switch is kept exhaustive without a `default` so that adding a predicate
+// triggers a -Wswitch warning here.
+static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
+ switch (predicate) {
+ // Predicates that exclude equality can never hold for x vs. x.
+ case CmpIPredicate::ne:
+ case CmpIPredicate::slt:
+ case CmpIPredicate::ult:
+ case CmpIPredicate::sgt:
+ case CmpIPredicate::ugt:
+ return false;
+ // Predicates that admit equality always hold for x vs. x.
+ case CmpIPredicate::eq:
+ case CmpIPredicate::sle:
+ case CmpIPredicate::ule:
+ case CmpIPredicate::sge:
+ case CmpIPredicate::uge:
+ return true;
+ }
+ llvm_unreachable("unknown comparison predicate");
+}
+
// Constant folding hook for comparisons.
OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "cmpi takes two arguments");
+ if (lhs() == rhs()) {
+ auto val = applyCmpPredicateToEqualOperands(getPredicate());
+ return BoolAttr::get(val, getContext());
+ }
+
auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
if (!lhs || !rhs)
return {};
auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
- return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
+ return BoolAttr::get(val, getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 1e2e4a5bf116..51475371244b 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -59,3 +59,25 @@ func @dim_of_dynamic_tensor_from_elements(%arg0: index, %arg1: index) -> index {
%1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
return %1 : index
}
+
+// Test case: Folding of comparisons with equal operands.
+// Every cmpi below compares %arg0 with itself, so canonicalization is expected
+// to replace each result with a constant. The predicates that admit equality
+// (eq, sle, sge, ule, uge) become `true`, and those that exclude it (ne, slt,
+// sgt, ult, ugt) become `false`. The order of the returned values mirrors the
+// order of the ops.
+// CHECK-LABEL: @cmpi_equal_operands
+// CHECK-DAG: %[[T:.*]] = constant true
+// CHECK-DAG: %[[F:.*]] = constant false
+// CHECK: return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
+// CHECK-SAME: %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
+func @cmpi_equal_operands(%arg0: i64)
+ -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+ %0 = cmpi "eq", %arg0, %arg0 : i64
+ %1 = cmpi "sle", %arg0, %arg0 : i64
+ %2 = cmpi "sge", %arg0, %arg0 : i64
+ %3 = cmpi "ule", %arg0, %arg0 : i64
+ %4 = cmpi "uge", %arg0, %arg0 : i64
+ %5 = cmpi "ne", %arg0, %arg0 : i64
+ %6 = cmpi "slt", %arg0, %arg0 : i64
+ %7 = cmpi "sgt", %arg0, %arg0 : i64
+ %8 = cmpi "ult", %arg0, %arg0 : i64
+ %9 = cmpi "ugt", %arg0, %arg0 : i64
+ return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
+ : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+}
More information about the Mlir-commits mailing list