[llvm] [IR] Add helper for comparing KnownBits with IR predicate (NFC) (PR #115878)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 12 07:02:26 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-globalisel

Author: Nikita Popov (nikic)

<details>
<summary>Changes</summary>

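Add an `ICmpInst::compare()` overload that accepts `KnownBits`, similar to the existing one that accepts `APInt`. It returns the comparison result if it can be determined from the known bits, and `std::nullopt` otherwise. The helper is not added directly to `KnownBits` (or `APInt`) for layering reasons: those classes live in lower-level libraries that must not depend on IR types such as `ICmpInst::Predicate`.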

---
Full diff: https://github.com/llvm/llvm-project/pull/115878.diff


3 Files Affected:

- (modified) llvm/include/llvm/IR/Instructions.h (+6) 
- (modified) llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp (+1-34) 
- (modified) llvm/lib/IR/Instructions.cpp (+30) 


``````````diff
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index b6575d4c85724c..bc29a4801e4ff4 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -48,6 +48,7 @@ class APInt;
 class BasicBlock;
 class ConstantInt;
 class DataLayout;
+struct KnownBits;
 class StringRef;
 class Type;
 class Value;
@@ -1305,6 +1306,11 @@ class ICmpInst: public CmpInst {
   static bool compare(const APInt &LHS, const APInt &RHS,
                       ICmpInst::Predicate Pred);
 
+  /// Return result of `LHS Pred RHS`, if it can be determined from the
+  /// KnownBits. Otherwise return nullopt.
+  static std::optional<bool> compare(const KnownBits &LHS, const KnownBits &RHS,
+                                     ICmpInst::Predicate Pred);
+
   // Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Instruction *I) {
     return I->getOpcode() == Instruction::ICmp;
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 0945e7334ac9d4..1a1a1c28ef1500 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -4442,40 +4442,7 @@ bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI,
 
   if (!KnownVal) {
     auto KnownLHS = KB->getKnownBits(MI.getOperand(2).getReg());
-    switch (Pred) {
-    default:
-      llvm_unreachable("Unexpected G_ICMP predicate?");
-    case CmpInst::ICMP_EQ:
-      KnownVal = KnownBits::eq(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_NE:
-      KnownVal = KnownBits::ne(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_SGE:
-      KnownVal = KnownBits::sge(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_SGT:
-      KnownVal = KnownBits::sgt(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_SLE:
-      KnownVal = KnownBits::sle(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_SLT:
-      KnownVal = KnownBits::slt(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_UGE:
-      KnownVal = KnownBits::uge(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_UGT:
-      KnownVal = KnownBits::ugt(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_ULE:
-      KnownVal = KnownBits::ule(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_ULT:
-      KnownVal = KnownBits::ult(KnownLHS, KnownRHS);
-      break;
-    }
+    KnownVal = ICmpInst::compare(KnownLHS, KnownRHS, Pred);
   }
 
   if (!KnownVal)
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 05e340ffa20a07..5b89a27126150a 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -40,6 +40,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CheckedArithmetic.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/ModRef.h"
 #include "llvm/Support/TypeSize.h"
@@ -3837,6 +3838,35 @@ bool FCmpInst::compare(const APFloat &LHS, const APFloat &RHS,
   }
 }
 
+std::optional<bool> ICmpInst::compare(const KnownBits &LHS,
+                                      const KnownBits &RHS,
+                                      ICmpInst::Predicate Pred) {
+  switch (Pred) {
+  case ICmpInst::ICMP_EQ:
+    return KnownBits::eq(LHS, RHS);
+  case ICmpInst::ICMP_NE:
+    return KnownBits::ne(LHS, RHS);
+  case ICmpInst::ICMP_UGE:
+    return KnownBits::uge(LHS, RHS);
+  case ICmpInst::ICMP_UGT:
+    return KnownBits::ugt(LHS, RHS);
+  case ICmpInst::ICMP_ULE:
+    return KnownBits::ule(LHS, RHS);
+  case ICmpInst::ICMP_ULT:
+    return KnownBits::ult(LHS, RHS);
+  case ICmpInst::ICMP_SGE:
+    return KnownBits::sge(LHS, RHS);
+  case ICmpInst::ICMP_SGT:
+    return KnownBits::sgt(LHS, RHS);
+  case ICmpInst::ICMP_SLE:
+    return KnownBits::sle(LHS, RHS);
+  case ICmpInst::ICMP_SLT:
+    return KnownBits::slt(LHS, RHS);
+  default:
+    llvm_unreachable("Unexpected non-integer predicate.");
+  }
+}
+
 CmpInst::Predicate CmpInst::getFlippedSignednessPredicate(Predicate pred) {
   assert(CmpInst::isRelational(pred) &&
          "Call only with non-equality predicates!");

``````````

</details>


https://github.com/llvm/llvm-project/pull/115878


More information about the llvm-commits mailing list