[llvm] [NFC] Add CmpIntrinsic class to represent calls to UCMP/SCMP intrinsics (PR #98177)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 9 08:49:06 PDT 2024


https://github.com/Poseydon42 created https://github.com/llvm/llvm-project/pull/98177

None

>From 9cc2c2795d1eaf23f7fc5a503a65ac4a4f7a60b4 Mon Sep 17 00:00:00 2001
From: Poseydon42 <vvmposeydon at gmail.com>
Date: Tue, 9 Jul 2024 16:48:16 +0100
Subject: [PATCH] [NFC] Add CmpIntrinsic class to represent calls to UCMP/SCMP
 intrinsics

---
 llvm/include/llvm/IR/IntrinsicInst.h          | 37 +++++++++++++++++++
 .../Scalar/CorrelatedValuePropagation.cpp     | 32 ++++++++--------
 2 files changed, 52 insertions(+), 17 deletions(-)

diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
index 3963a5c8ab8f9..2a37c06dd2c3c 100644
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -834,6 +834,43 @@ class MinMaxIntrinsic : public IntrinsicInst {
   }
 };
 
+/// Represents a call to llvm.scmp / llvm.ucmp: a three-way comparison of two
+/// operands (1 if LHS > RHS, 0 if equal, -1 if less); scmp is the signed one.
+class CmpIntrinsic : public IntrinsicInst {
+public:
+  static bool classof(const IntrinsicInst *I) {
+    Intrinsic::ID ID = I->getIntrinsicID();
+    return ID == Intrinsic::scmp || ID == Intrinsic::ucmp;
+  }
+  static bool classof(const Value *V) {
+    return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
+  }
+
+  /// The left-hand compared operand (argument 0).
+  Value *getLHS() const { return const_cast<Value *>(getArgOperand(0)); }
+  /// The right-hand compared operand (argument 1).
+  Value *getRHS() const { return const_cast<Value *>(getArgOperand(1)); }
+
+  /// True for scmp (signed comparison), false for ucmp.
+  static bool isSigned(Intrinsic::ID ID) { return ID == Intrinsic::scmp; }
+  bool isSigned() const { return isSigned(getIntrinsicID()); }
+
+  /// The icmp predicate for "LHS > RHS" with the intrinsic's signedness.
+  static CmpInst::Predicate getGTPredicate(Intrinsic::ID ID) {
+    return isSigned(ID) ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
+  }
+  CmpInst::Predicate getGTPredicate() const {
+    return getGTPredicate(getIntrinsicID());
+  }
+  /// The icmp predicate for "LHS < RHS" with the intrinsic's signedness.
+  static CmpInst::Predicate getLTPredicate(Intrinsic::ID ID) {
+    return isSigned(ID) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
+  }
+  CmpInst::Predicate getLTPredicate() const {
+    return getLTPredicate(getIntrinsicID());
+  }
+};
+
 /// This class represents an intrinsic that is based on a binary operation.
 /// This includes op.with.overflow and saturating add/sub intrinsics.
 class BinaryOpIntrinsic : public IntrinsicInst {
diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
index 20f5dba413212..596078407edd1 100644
--- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -538,29 +538,28 @@ static bool processAbsIntrinsic(IntrinsicInst *II, LazyValueInfo *LVI) {
   return false;
 }
 
-static bool processCmpIntrinsic(IntrinsicInst *II, LazyValueInfo *LVI) {
-  bool IsSigned = II->getIntrinsicID() == Intrinsic::scmp;
-  ConstantRange LHS_CR = LVI->getConstantRangeAtUse(II->getOperandUse(0),
-                                                    /*UndefAllowed*/ false);
-  ConstantRange RHS_CR = LVI->getConstantRangeAtUse(II->getOperandUse(1),
-                                                    /*UndefAllowed*/ false);
+static bool processCmpIntrinsic(CmpIntrinsic *CI, LazyValueInfo *LVI) {
+  ConstantRange LHS_CR =
+      LVI->getConstantRangeAtUse(CI->getOperandUse(0), /*UndefAllowed*/ false);
+  ConstantRange RHS_CR =
+      LVI->getConstantRangeAtUse(CI->getOperandUse(1), /*UndefAllowed*/ false);
 
-  if (LHS_CR.icmp(IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, RHS_CR)) {
+  if (LHS_CR.icmp(CI->getGTPredicate(), RHS_CR)) {
     ++NumCmpIntr;
-    II->replaceAllUsesWith(ConstantInt::get(II->getType(), 1));
-    II->eraseFromParent();
+    CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 1));
+    CI->eraseFromParent();
     return true;
   }
-  if (LHS_CR.icmp(IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, RHS_CR)) {
+  if (LHS_CR.icmp(CI->getLTPredicate(), RHS_CR)) {
     ++NumCmpIntr;
-    II->replaceAllUsesWith(ConstantInt::getSigned(II->getType(), -1));
-    II->eraseFromParent();
+    CI->replaceAllUsesWith(ConstantInt::getSigned(CI->getType(), -1));
+    CI->eraseFromParent();
     return true;
   }
   if (LHS_CR.icmp(ICmpInst::ICMP_EQ, RHS_CR)) {
     ++NumCmpIntr;
-    II->replaceAllUsesWith(ConstantInt::get(II->getType(), 0));
-    II->eraseFromParent();
+    CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0));
+    CI->eraseFromParent();
     return true;
   }
 
@@ -658,9 +657,8 @@ static bool processCallSite(CallBase &CB, LazyValueInfo *LVI) {
     return processAbsIntrinsic(&cast<IntrinsicInst>(CB), LVI);
   }
 
-  if (CB.getIntrinsicID() == Intrinsic::scmp ||
-      CB.getIntrinsicID() == Intrinsic::ucmp) {
-    return processCmpIntrinsic(&cast<IntrinsicInst>(CB), LVI);
+  if (auto *CI = dyn_cast<CmpIntrinsic>(&CB)) {
+    return processCmpIntrinsic(CI, LVI);
+  }
 
   if (auto *MM = dyn_cast<MinMaxIntrinsic>(&CB)) {



More information about the llvm-commits mailing list