[llvm] [ValueTracking] Filter out non-interesting conditions (PR #118493)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 3 06:29:13 PST 2024


https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/118493

Addresses the review comment at https://github.com/llvm/llvm-project/pull/117442#discussion_r1855539750


>From 652a01a29fc19cd05708a22b2a7dfbe6f9345d04 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 3 Dec 2024 22:27:59 +0800
Subject: [PATCH] [ValueTracking] Filter out non-interesting conditions

---
 .../include/llvm/Analysis/DomConditionCache.h | 23 +++++-
 llvm/include/llvm/Analysis/ValueTracking.h    |  8 +-
 llvm/lib/Analysis/AssumptionCache.cpp         |  3 +-
 llvm/lib/Analysis/DomConditionCache.cpp       | 25 +++++--
 llvm/lib/Analysis/ValueTracking.cpp           | 75 +++++++++++--------
 .../InstCombine/InstCombineCompares.cpp       |  4 +-
 .../InstCombine/InstCombineSelect.cpp         |  8 +-
 7 files changed, 95 insertions(+), 51 deletions(-)

diff --git a/llvm/include/llvm/Analysis/DomConditionCache.h b/llvm/include/llvm/Analysis/DomConditionCache.h
index ac25803143f49e..4f0d2363eec71b 100644
--- a/llvm/include/llvm/Analysis/DomConditionCache.h
+++ b/llvm/include/llvm/Analysis/DomConditionCache.h
@@ -18,18 +18,34 @@
 #define LLVM_ANALYSIS_DOMCONDITIONCACHE_H
 
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include <cstdint>
 
 namespace llvm {
 
 class Value;
 class BranchInst;
 
+enum class DomConditionFlag : uint8_t {
+  None = 0,
+  KnownBits = 1 << 0,
+  KnownFPClass = 1 << 1,
+  PowerOfTwo = 1 << 2,
+  ICmp = 1 << 3,
+};
+
+LLVM_DECLARE_ENUM_AS_BITMASK(
+    DomConditionFlag,
+    /*LargestValue=*/static_cast<uint8_t>(DomConditionFlag::ICmp));
+
 class DomConditionCache {
 private:
   /// A map of values about which a branch might be providing information.
-  using AffectedValuesMap = DenseMap<Value *, SmallVector<BranchInst *, 1>>;
+  using AffectedValuesMap =
+      DenseMap<Value *,
+               SmallVector<std::pair<BranchInst *, DomConditionFlag>, 1>>;
   AffectedValuesMap AffectedValues;
 
 public:
@@ -40,10 +56,11 @@ class DomConditionCache {
   void removeValue(Value *V) { AffectedValues.erase(V); }
 
   /// Access the list of branches which affect this value.
-  ArrayRef<BranchInst *> conditionsFor(const Value *V) const {
+  ArrayRef<std::pair<BranchInst *, DomConditionFlag>>
+  conditionsFor(const Value *V) const {
     auto AVI = AffectedValues.find_as(const_cast<Value *>(V));
     if (AVI == AffectedValues.end())
-      return ArrayRef<BranchInst *>();
+      return {};
 
     return AVI->second;
   }
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index bd74d27e0c49b1..c887c0b1603e4a 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -14,13 +14,14 @@
 #ifndef LLVM_ANALYSIS_VALUETRACKING_H
 #define LLVM_ANALYSIS_VALUETRACKING_H
 
+#include "DomConditionCache.h"
 #include "llvm/Analysis/SimplifyQuery.h"
 #include "llvm/Analysis/WithCache.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/FMF.h"
-#include "llvm/IR/Instructions.h"
 #include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include <cassert>
 #include <cstdint>
@@ -1275,8 +1276,9 @@ std::optional<bool> isImpliedByDomCondition(CmpInst::Predicate Pred,
 /// Call \p InsertAffected on all Values whose known bits / value may be
 /// affected by the condition \p Cond. Used by AssumptionCache and
 /// DomConditionCache.
-void findValuesAffectedByCondition(Value *Cond, bool IsAssume,
-                                   function_ref<void(Value *)> InsertAffected);
+void findValuesAffectedByCondition(
+    Value *Cond, bool IsAssume,
+    function_ref<void(Value *, DomConditionFlag)> InsertAffected);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
index a0e57ab741dfa8..2a5d742df1f6eb 100644
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -59,7 +59,8 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
   // Note: This code must be kept in-sync with the code in
   // computeKnownBitsFromAssume in ValueTracking.
 
-  auto InsertAffected = [&Affected](Value *V) {
+  // TODO: Use DomConditionFlag to filter out non-interesting conditions.
+  auto InsertAffected = [&Affected](Value *V, DomConditionFlag) {
     Affected.push_back({V, AssumptionCache::ExprResultIdx});
   };
 
diff --git a/llvm/lib/Analysis/DomConditionCache.cpp b/llvm/lib/Analysis/DomConditionCache.cpp
index 66bd15b47901d7..345b2e22a687ba 100644
--- a/llvm/lib/Analysis/DomConditionCache.cpp
+++ b/llvm/lib/Analysis/DomConditionCache.cpp
@@ -10,19 +10,30 @@
 #include "llvm/Analysis/ValueTracking.h"
 using namespace llvm;
 
-static void findAffectedValues(Value *Cond,
-                               SmallVectorImpl<Value *> &Affected) {
-  auto InsertAffected = [&Affected](Value *V) { Affected.push_back(V); };
+static void findAffectedValues(
+    Value *Cond,
+    SmallVectorImpl<std::pair<Value *, DomConditionFlag>> &Affected) {
+  auto InsertAffected = [&Affected](Value *V, DomConditionFlag Flags) {
+    Affected.push_back({V, Flags});
+  };
   findValuesAffectedByCondition(Cond, /*IsAssume=*/false, InsertAffected);
 }
 
 void DomConditionCache::registerBranch(BranchInst *BI) {
   assert(BI->isConditional() && "Must be conditional branch");
-  SmallVector<Value *, 16> Affected;
+  SmallVector<std::pair<Value *, DomConditionFlag>, 16> Affected;
   findAffectedValues(BI->getCondition(), Affected);
-  for (Value *V : Affected) {
+  for (auto [V, Flags] : Affected) {
     auto &AV = AffectedValues[V];
-    if (!is_contained(AV, BI))
-      AV.push_back(BI);
+    bool Exist = false;
+    for (auto &[OtherBI, OtherFlags] : AV) {
+      if (OtherBI == BI) {
+        OtherFlags |= Flags;
+        Exist = true;
+        break;
+      }
+    }
+    if (!Exist)
+      AV.push_back({BI, Flags});
   }
 }
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index d81546d0c9fedc..8d63c0d2508a9a 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -790,7 +790,9 @@ void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
 
   if (Q.DC && Q.DT) {
     // Handle dominating conditions.
-    for (BranchInst *BI : Q.DC->conditionsFor(V)) {
+    for (auto [BI, Flag] : Q.DC->conditionsFor(V)) {
+      if (!any(Flag & DomConditionFlag::KnownBits))
+        continue;
       BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
       if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
         computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
@@ -2299,7 +2301,9 @@ bool llvm::isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth,
 
   // Handle dominating conditions.
   if (Q.DC && Q.CxtI && Q.DT) {
-    for (BranchInst *BI : Q.DC->conditionsFor(V)) {
+    for (auto [BI, Flag] : Q.DC->conditionsFor(V)) {
+      if (!any(Flag & DomConditionFlag::PowerOfTwo))
+        continue;
       Value *Cond = BI->getCondition();
 
       BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
@@ -4930,7 +4934,9 @@ static KnownFPClass computeKnownFPClassFromContext(const Value *V,
 
   if (Q.DC && Q.DT) {
     // Handle dominating conditions.
-    for (BranchInst *BI : Q.DC->conditionsFor(V)) {
+    for (auto [BI, Flag] : Q.DC->conditionsFor(V)) {
+      if (!any(Flag & DomConditionFlag::KnownFPClass))
+        continue;
       Value *Cond = BI->getCondition();
 
       BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
@@ -10014,36 +10020,38 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
   return CR;
 }
 
-static void
-addValueAffectedByCondition(Value *V,
-                            function_ref<void(Value *)> InsertAffected) {
+static void addValueAffectedByCondition(
+    Value *V, function_ref<void(Value *, DomConditionFlag)> InsertAffected,
+    DomConditionFlag Flags) {
   assert(V != nullptr);
   if (isa<Argument>(V) || isa<GlobalValue>(V)) {
-    InsertAffected(V);
+    InsertAffected(V, Flags);
   } else if (auto *I = dyn_cast<Instruction>(V)) {
-    InsertAffected(V);
+    InsertAffected(V, Flags);
 
     // Peek through unary operators to find the source of the condition.
     Value *Op;
     if (match(I, m_CombineOr(m_PtrToInt(m_Value(Op)), m_Trunc(m_Value(Op))))) {
       if (isa<Instruction>(Op) || isa<Argument>(Op))
-        InsertAffected(Op);
+        InsertAffected(Op, Flags);
     }
   }
 }
 
 void llvm::findValuesAffectedByCondition(
-    Value *Cond, bool IsAssume, function_ref<void(Value *)> InsertAffected) {
-  auto AddAffected = [&InsertAffected](Value *V) {
-    addValueAffectedByCondition(V, InsertAffected);
+    Value *Cond, bool IsAssume,
+    function_ref<void(Value *, DomConditionFlag)> InsertAffected) {
+  auto AddAffected = [&InsertAffected](Value *V, DomConditionFlag Flags) {
+    addValueAffectedByCondition(V, InsertAffected, Flags);
   };
 
-  auto AddCmpOperands = [&AddAffected, IsAssume](Value *LHS, Value *RHS) {
+  auto AddCmpOperands = [&AddAffected, IsAssume](Value *LHS, Value *RHS,
+                                                 DomConditionFlag Flags) {
     if (IsAssume) {
-      AddAffected(LHS);
-      AddAffected(RHS);
+      AddAffected(LHS, Flags);
+      AddAffected(RHS, Flags);
     } else if (match(RHS, m_Constant()))
-      AddAffected(LHS);
+      AddAffected(LHS, Flags);
   };
 
   SmallVector<Value *, 8> Worklist;
@@ -10058,9 +10066,9 @@ void llvm::findValuesAffectedByCondition(
     Value *A, *B, *X;
 
     if (IsAssume) {
-      AddAffected(V);
+      AddAffected(V, DomConditionFlag::KnownBits);
       if (match(V, m_Not(m_Value(X))))
-        AddAffected(X);
+        AddAffected(X, DomConditionFlag::KnownBits);
     }
 
     if (match(V, m_LogicalOp(m_Value(A), m_Value(B)))) {
@@ -10074,7 +10082,8 @@ void llvm::findValuesAffectedByCondition(
         Worklist.push_back(B);
       }
     } else if (match(V, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
-      AddCmpOperands(A, B);
+      AddCmpOperands(A, B,
+                     DomConditionFlag::KnownBits | DomConditionFlag::ICmp);
 
       bool HasRHSC = match(B, m_ConstantInt());
       if (ICmpInst::isEquality(Pred)) {
@@ -10084,11 +10093,11 @@ void llvm::findValuesAffectedByCondition(
           // (X << C) or (X >>_s C) or (X >>_u C).
           if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
               match(A, m_Shift(m_Value(X), m_ConstantInt())))
-            AddAffected(X);
+            AddAffected(X, DomConditionFlag::KnownBits);
           else if (match(A, m_And(m_Value(X), m_Value(Y))) ||
                    match(A, m_Or(m_Value(X), m_Value(Y)))) {
-            AddAffected(X);
-            AddAffected(Y);
+            AddAffected(X, DomConditionFlag::KnownBits);
+            AddAffected(Y, DomConditionFlag::KnownBits);
           }
         }
       } else {
@@ -10096,7 +10105,7 @@ void llvm::findValuesAffectedByCondition(
           // Handle (A + C1) u< C2, which is the canonical form of
           // A > C3 && A < C4.
           if (match(A, m_AddLike(m_Value(X), m_ConstantInt())))
-            AddAffected(X);
+            AddAffected(X, DomConditionFlag::KnownBits);
 
           if (ICmpInst::isUnsigned(Pred)) {
             Value *Y;
@@ -10106,12 +10115,12 @@ void llvm::findValuesAffectedByCondition(
             if (match(A, m_And(m_Value(X), m_Value(Y))) ||
                 match(A, m_Or(m_Value(X), m_Value(Y))) ||
                 match(A, m_NUWAdd(m_Value(X), m_Value(Y)))) {
-              AddAffected(X);
-              AddAffected(Y);
+              AddAffected(X, DomConditionFlag::KnownBits);
+              AddAffected(Y, DomConditionFlag::KnownBits);
             }
             // X nuw- Y u> C -> X u> C
             if (match(A, m_NUWSub(m_Value(X), m_Value())))
-              AddAffected(X);
+              AddAffected(X, DomConditionFlag::KnownBits);
           }
         }
 
@@ -10119,29 +10128,29 @@ void llvm::findValuesAffectedByCondition(
         // by computeKnownFPClass().
         if (match(A, m_ElementWiseBitCast(m_Value(X)))) {
           if (Pred == ICmpInst::ICMP_SLT && match(B, m_Zero()))
-            InsertAffected(X);
+            InsertAffected(X, DomConditionFlag::KnownFPClass);
           else if (Pred == ICmpInst::ICMP_SGT && match(B, m_AllOnes()))
-            InsertAffected(X);
+            InsertAffected(X, DomConditionFlag::KnownFPClass);
         }
       }
 
       if (HasRHSC && match(A, m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))
-        AddAffected(X);
+        AddAffected(X, DomConditionFlag::PowerOfTwo);
     } else if (match(V, m_FCmp(Pred, m_Value(A), m_Value(B)))) {
-      AddCmpOperands(A, B);
+      AddCmpOperands(A, B, DomConditionFlag::KnownFPClass);
 
       // fcmp fneg(x), y
       // fcmp fabs(x), y
       // fcmp fneg(fabs(x)), y
       if (match(A, m_FNeg(m_Value(A))))
-        AddAffected(A);
+        AddAffected(A, DomConditionFlag::KnownFPClass);
       if (match(A, m_FAbs(m_Value(A))))
-        AddAffected(A);
+        AddAffected(A, DomConditionFlag::KnownFPClass);
 
     } else if (match(V, m_Intrinsic<Intrinsic::is_fpclass>(m_Value(A),
                                                            m_Value()))) {
       // Handle patterns that computeKnownFPClass() support.
-      AddAffected(A);
+      AddAffected(A, DomConditionFlag::KnownFPClass);
     }
   }
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index fed21db393ed22..5f635cc41a94f7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1385,7 +1385,9 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
     return nullptr;
   };
 
-  for (BranchInst *BI : DC.conditionsFor(X)) {
+  for (auto [BI, Flags] : DC.conditionsFor(X)) {
+    if (!any(Flags & DomConditionFlag::ICmp))
+      continue;
     ICmpInst::Predicate DomPred;
     const APInt *DomC;
     if (!match(BI->getCondition(),
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index c7a0c35d099cc4..e792190f95e082 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -4293,9 +4293,11 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
       (!isa<Constant>(TrueVal) || !isa<Constant>(FalseVal))) {
     // Try to simplify select arms based on KnownBits implied by the condition.
     CondContext CC(CondVal);
-    findValuesAffectedByCondition(CondVal, /*IsAssume=*/false, [&](Value *V) {
-      CC.AffectedValues.insert(V);
-    });
+    findValuesAffectedByCondition(
+        CondVal, /*IsAssume=*/false, [&](Value *V, DomConditionFlag Flags) {
+          if (any(Flags & DomConditionFlag::KnownBits))
+            CC.AffectedValues.insert(V);
+        });
     SimplifyQuery Q = SQ.getWithInstruction(&SI).getWithCondContext(CC);
     if (!CC.AffectedValues.empty()) {
       if (!isa<Constant>(TrueVal) &&



More information about the llvm-commits mailing list