[Mlir-commits] [mlir] cb778c3 - [mlir][std] Fold comparisons when the operands are equal

Stephan Herhut llvmlistbot at llvm.org
Fri Nov 20 04:26:56 PST 2020


Author: Stephan Herhut
Date: 2020-11-20T13:26:41+01:00
New Revision: cb778c34237c384821b7bc961f15d139a10b0ca7

URL: https://github.com/llvm/llvm-project/commit/cb778c34237c384821b7bc961f15d139a10b0ca7
DIFF: https://github.com/llvm/llvm-project/commit/cb778c34237c384821b7bc961f15d139a10b0ca7.diff

LOG: [mlir][std] Fold comparisons when the operands are equal

When both operands of a comparison are the same SSA value, the result is determined by the predicate alone and can be decided statically.

Differential Revision: https://reviews.llvm.org/D91856

Added: 
    

Modified: 
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/test/Dialect/Standard/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 342d73273dd6..6e755daa2669 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -916,17 +916,58 @@ bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
   llvm_unreachable("unknown comparison predicate");
 }
 
+// Returns true if the predicate is true for two equal operands.
+static bool applyCmpPredicateToEqualOperands(CmpIPredicate predicate) {
+  switch (predicate) {
+  case CmpIPredicate::eq:
+  case CmpIPredicate::sle:
+  case CmpIPredicate::sge:
+  case CmpIPredicate::ule:
+  case CmpIPredicate::uge:
+    return true;
+  case CmpIPredicate::ne:
+  case CmpIPredicate::slt:
+  case CmpIPredicate::sgt:
+  case CmpIPredicate::ult:
+  case CmpIPredicate::ugt:
+    return false;
+  }
+  llvm_unreachable("unknown comparison predicate");
+}
+
+// Returns a constant attribute holding `value` that matches the comparison
+// result type `type`: a BoolAttr for a scalar i1 result, or a splat
+// DenseElementsAttr for a statically shaped vector/tensor of i1. Returns a
+// null attribute (i.e. no fold) for shaped types without a static shape,
+// because DenseElementsAttr can only be built for static shapes.
+static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
+  auto boolAttr = BoolAttr::get(value, ctx);
+  auto shapedType = type.dyn_cast<ShapedType>();
+  if (!shapedType)
+    return boolAttr;
+  if (!shapedType.hasStaticShape())
+    return {};
+  return DenseElementsAttr::get(shapedType, boolAttr);
+}
+
 // Constant folding hook for comparisons.
 OpFoldResult CmpIOp::fold(ArrayRef<Attribute> operands) {
   assert(operands.size() == 2 && "cmpi takes two arguments");
 
+  // cmpi(pred, x, x) is decided by the predicate alone. The result may be a
+  // vector or tensor of i1, so the folded constant must match that shape.
+  if (lhs() == rhs()) {
+    auto val = applyCmpPredicateToEqualOperands(getPredicate());
+    return getBoolAttribute(getType(), getContext(), val);
+  }
+
   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
   if (!lhs || !rhs)
     return {};
 
   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
-  return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val));
+  return BoolAttr::get(val, getContext());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
index 1e2e4a5bf116..51475371244b 100644
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -59,3 +59,38 @@ func @dim_of_dynamic_tensor_from_elements(%arg0: index, %arg1: index) -> index {
   %1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
   return %1 : index
 }
+
+// Test case: Folding of comparisons with equal operands.
+// CHECK-LABEL: @cmpi_equal_operands
+//   CHECK-DAG:   %[[T:.*]] = constant true
+//   CHECK-DAG:   %[[F:.*]] = constant false
+//       CHECK:   return %[[T]], %[[T]], %[[T]], %[[T]], %[[T]],
+//  CHECK-SAME:          %[[F]], %[[F]], %[[F]], %[[F]], %[[F]]
+func @cmpi_equal_operands(%arg0: i64)
+    -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+  %0 = cmpi "eq", %arg0, %arg0 : i64
+  %1 = cmpi "sle", %arg0, %arg0 : i64
+  %2 = cmpi "sge", %arg0, %arg0 : i64
+  %3 = cmpi "ule", %arg0, %arg0 : i64
+  %4 = cmpi "uge", %arg0, %arg0 : i64
+  %5 = cmpi "ne", %arg0, %arg0 : i64
+  %6 = cmpi "slt", %arg0, %arg0 : i64
+  %7 = cmpi "sgt", %arg0, %arg0 : i64
+  %8 = cmpi "ult", %arg0, %arg0 : i64
+  %9 = cmpi "ugt", %arg0, %arg0 : i64
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9
+      : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+}
+
+// Test case: Folding of comparisons with equal vector operands yields splat
+// constants of the vector result type.
+// CHECK-LABEL: @cmpi_equal_vector_operands
+//   CHECK-DAG:   %[[T:.*]] = constant dense<true> : vector<4xi1>
+//   CHECK-DAG:   %[[F:.*]] = constant dense<false> : vector<4xi1>
+//       CHECK:   return %[[T]], %[[F]]
+func @cmpi_equal_vector_operands(%arg0: vector<4xi64>)
+    -> (vector<4xi1>, vector<4xi1>) {
+  %0 = cmpi "eq", %arg0, %arg0 : vector<4xi64>
+  %1 = cmpi "ne", %arg0, %arg0 : vector<4xi64>
+  return %0, %1 : vector<4xi1>, vector<4xi1>
+}


        


More information about the Mlir-commits mailing list