[Mlir-commits] [mlir] [mlir][index] Fold `cmp(x, x)` when `x` isn't a constant (PR #78812)

Xiangxi Guo llvmlistbot at llvm.org
Fri Jan 19 15:53:56 PST 2024


https://github.com/StrongerXi created https://github.com/llvm/llvm-project/pull/78812

Such cases show up in the middle of optimization passes, e.g., after some rewrites followed by CSE. The current folder only folds such cases when the inputs are constant; this patch extends it to fold them even when the inputs are non-constant.

>From d9396f746650e8ca745ddc0b6741b5a79d80c67e Mon Sep 17 00:00:00 2001
From: Ryan Guo <ryanguo at modular.com>
Date: Fri, 19 Jan 2024 15:29:54 -0800
Subject: [PATCH] [mlir][index] Fold `cmp(x, x)` when `x` isn't a constant

Such cases show up in the middle of optimization passes, e.g., after
some rewrites followed by CSE. The current folder only folds such cases
when the inputs are constant; this patch extends it to fold them even
when the inputs are non-constant.
---
 mlir/lib/Dialect/Index/IR/IndexOps.cpp        | 22 +++++++++++++++++++
 .../Dialect/Index/index-canonicalize.mlir     | 20 +++++++++++++++++
 2 files changed, 42 insertions(+)

diff --git a/mlir/lib/Dialect/Index/IR/IndexOps.cpp b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
index b506397742772a..42401dae217ce1 100644
--- a/mlir/lib/Dialect/Index/IR/IndexOps.cpp
+++ b/mlir/lib/Dialect/Index/IR/IndexOps.cpp
@@ -578,6 +578,26 @@ static std::optional<bool> foldCmpOfMaxOrMin(Operation *lhsOp,
                                 lhsRange, ConstantIntRanges::constant(cstB));
 }
 
+/// Return the result of `cmp(pred, x, x)`. It depends only on `pred`: the
+/// reflexive predicates hold for every `x`, `ne` and the strict ones never do.
+static bool compareSameArgs(IndexCmpPredicate pred) {
+  switch (pred) {
+  case IndexCmpPredicate::EQ:
+  case IndexCmpPredicate::SGE:
+  case IndexCmpPredicate::SLE:
+  case IndexCmpPredicate::UGE:
+  case IndexCmpPredicate::ULE:
+    return true;
+  case IndexCmpPredicate::NE:
+  case IndexCmpPredicate::SGT:
+  case IndexCmpPredicate::SLT:
+  case IndexCmpPredicate::UGT:
+  case IndexCmpPredicate::ULT:
+    return false;
+  }
+  llvm_unreachable("unknown predicate in compareSameArgs");
+}
+
 OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
   // Attempt to fold if both inputs are constant.
   auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
@@ -606,6 +626,11 @@ OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
       return BoolAttr::get(getContext(), *result64);
   }
 
+  // Fold `cmp(x, x)`. The result depends only on the predicate, so this also
+  // applies when `x` is not a constant (e.g. after CSE merged both operands).
+  if (getLhs() == getRhs())
+    return BoolAttr::get(getContext(), compareSameArgs(getPred()));
+
   return {};
 }
 
diff --git a/mlir/test/Dialect/Index/index-canonicalize.mlir b/mlir/test/Dialect/Index/index-canonicalize.mlir
index db03505350b77e..37aa33bfde952e 100644
--- a/mlir/test/Dialect/Index/index-canonicalize.mlir
+++ b/mlir/test/Dialect/Index/index-canonicalize.mlir
@@ -499,6 +499,27 @@ func.func @cmp(%arg0: index) -> (i1, i1, i1, i1, i1, i1) {
   return %0, %1, %2, %3, %5, %7 : i1, i1, i1, i1, i1, i1
 }
 
+// `cmp(x, x)` folds for every predicate even when `x` is not a constant.
+// CHECK-LABEL: @cmp_same_args
+func.func @cmp_same_args(%a: index) -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
+  %0 = index.cmp eq(%a, %a)
+  %1 = index.cmp sge(%a, %a)
+  %2 = index.cmp sle(%a, %a)
+  %3 = index.cmp uge(%a, %a)
+  %4 = index.cmp ule(%a, %a)
+  %5 = index.cmp ne(%a, %a)
+  %6 = index.cmp sgt(%a, %a)
+  %7 = index.cmp slt(%a, %a)
+  %8 = index.cmp ugt(%a, %a)
+  %9 = index.cmp ult(%a, %a)
+
+  // CHECK-DAG: %[[TRUE:.*]] = index.bool.constant true
+  // CHECK-DAG: %[[FALSE:.*]] = index.bool.constant false
+  // CHECK-NEXT: return %[[TRUE]], %[[TRUE]], %[[TRUE]], %[[TRUE]], %[[TRUE]],
+  // CHECK-SAME: %[[FALSE]], %[[FALSE]], %[[FALSE]], %[[FALSE]], %[[FALSE]]
+  return %0, %1, %2, %3, %4, %5, %6, %7, %8, %9 : i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
+}
+
 // CHECK-LABEL: @cmp_nofold
 func.func @cmp_nofold() -> i1 {
   %lhs = index.constant 1



More information about the Mlir-commits mailing list