[llvm] [IR] Add helper for comparing KnownBits with IR predicate (NFC) (PR #115878)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 12 07:01:05 PST 2024


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/115878

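Add `ICmpInst::compare()` overload accepting `KnownBits`, similar to the existing one accepting `APInt`. The helper lives on `ICmpInst` rather than directly on `KnownBits` (or `APInt`) for layering reasons: `KnownBits` is declared in `llvm/Support/KnownBits.h`, and the IR predicate enum is declared in the IR library, which includes that header rather than the other way around.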

>From 463997518fd14fb9df3dea18ee7aadbb0e85b846 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Tue, 12 Nov 2024 15:59:08 +0100
Subject: [PATCH] [IR] Add helper for comparing KnownBits with IR predicate

Add `ICmpInst::compare()` overload accepting `KnownBits`, similar
to the existing one accepting `APInt`. This is not directly part
of KnownBits (or APInt) for layering reasons.
---
 llvm/include/llvm/IR/Instructions.h           |  6 ++++
 .../lib/CodeGen/GlobalISel/CombinerHelper.cpp | 35 +------------------
 llvm/lib/IR/Instructions.cpp                  | 30 ++++++++++++++++
 3 files changed, 37 insertions(+), 34 deletions(-)

diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index b6575d4c85724c..bc29a4801e4ff4 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -48,6 +48,7 @@ class APInt;
 class BasicBlock;
 class ConstantInt;
 class DataLayout;
+struct KnownBits;
 class StringRef;
 class Type;
 class Value;
@@ -1305,6 +1306,11 @@ class ICmpInst: public CmpInst {
   static bool compare(const APInt &LHS, const APInt &RHS,
                       ICmpInst::Predicate Pred);
 
+  /// Return result of `LHS Pred RHS`, if it can be determined from the
+  /// KnownBits. Otherwise return nullopt.
+  static std::optional<bool> compare(const KnownBits &LHS, const KnownBits &RHS,
+                                     ICmpInst::Predicate Pred);
+
   // Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Instruction *I) {
     return I->getOpcode() == Instruction::ICmp;
diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
index 0945e7334ac9d4..1a1a1c28ef1500 100644
--- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
+++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -4442,40 +4442,7 @@ bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI,
 
   if (!KnownVal) {
     auto KnownLHS = KB->getKnownBits(MI.getOperand(2).getReg());
-    switch (Pred) {
-    default:
-      llvm_unreachable("Unexpected G_ICMP predicate?");
-    case CmpInst::ICMP_EQ:
-      KnownVal = KnownBits::eq(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_NE:
-      KnownVal = KnownBits::ne(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_SGE:
-      KnownVal = KnownBits::sge(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_SGT:
-      KnownVal = KnownBits::sgt(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_SLE:
-      KnownVal = KnownBits::sle(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_SLT:
-      KnownVal = KnownBits::slt(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_UGE:
-      KnownVal = KnownBits::uge(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_UGT:
-      KnownVal = KnownBits::ugt(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_ULE:
-      KnownVal = KnownBits::ule(KnownLHS, KnownRHS);
-      break;
-    case CmpInst::ICMP_ULT:
-      KnownVal = KnownBits::ult(KnownLHS, KnownRHS);
-      break;
-    }
+    KnownVal = ICmpInst::compare(KnownLHS, KnownRHS, Pred);
   }
 
   if (!KnownVal)
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 05e340ffa20a07..5b89a27126150a 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -40,6 +40,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CheckedArithmetic.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/ModRef.h"
 #include "llvm/Support/TypeSize.h"
@@ -3837,6 +3838,35 @@ bool FCmpInst::compare(const APFloat &LHS, const APFloat &RHS,
   }
 }
 
+std::optional<bool> ICmpInst::compare(const KnownBits &LHS,
+                                      const KnownBits &RHS,
+                                      ICmpInst::Predicate Pred) {
+  switch (Pred) {
+  case ICmpInst::ICMP_EQ:
+    return KnownBits::eq(LHS, RHS);
+  case ICmpInst::ICMP_NE:
+    return KnownBits::ne(LHS, RHS);
+  case ICmpInst::ICMP_UGE:
+    return KnownBits::uge(LHS, RHS);
+  case ICmpInst::ICMP_UGT:
+    return KnownBits::ugt(LHS, RHS);
+  case ICmpInst::ICMP_ULE:
+    return KnownBits::ule(LHS, RHS);
+  case ICmpInst::ICMP_ULT:
+    return KnownBits::ult(LHS, RHS);
+  case ICmpInst::ICMP_SGE:
+    return KnownBits::sge(LHS, RHS);
+  case ICmpInst::ICMP_SGT:
+    return KnownBits::sgt(LHS, RHS);
+  case ICmpInst::ICMP_SLE:
+    return KnownBits::sle(LHS, RHS);
+  case ICmpInst::ICMP_SLT:
+    return KnownBits::slt(LHS, RHS);
+  default:
+    llvm_unreachable("Unexpected non-integer predicate.");
+  }
+}
+
 CmpInst::Predicate CmpInst::getFlippedSignednessPredicate(Predicate pred) {
   assert(CmpInst::isRelational(pred) &&
          "Call only with non-equality predicates!");



More information about the llvm-commits mailing list