[llvm] [ValueTracking] Filter out non-interesting conditions (PR #118493)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 3 06:29:13 PST 2024
https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/118493
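Addresses the review comment at https://github.com/llvm/llvm-project/pull/117442#discussion_r1855539750 by tagging each affected value with the analyses that can use the condition, so consumers can skip irrelevant dominating conditions.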
>From 652a01a29fc19cd05708a22b2a7dfbe6f9345d04 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 3 Dec 2024 22:27:59 +0800
Subject: [PATCH] [ValueTracking] Filter out non-interesting conditions
---
.../include/llvm/Analysis/DomConditionCache.h | 23 +++++-
llvm/include/llvm/Analysis/ValueTracking.h | 8 +-
llvm/lib/Analysis/AssumptionCache.cpp | 3 +-
llvm/lib/Analysis/DomConditionCache.cpp | 25 +++++--
llvm/lib/Analysis/ValueTracking.cpp | 75 +++++++++++--------
.../InstCombine/InstCombineCompares.cpp | 4 +-
.../InstCombine/InstCombineSelect.cpp | 8 +-
7 files changed, 95 insertions(+), 51 deletions(-)
diff --git a/llvm/include/llvm/Analysis/DomConditionCache.h b/llvm/include/llvm/Analysis/DomConditionCache.h
index ac25803143f49e..4f0d2363eec71b 100644
--- a/llvm/include/llvm/Analysis/DomConditionCache.h
+++ b/llvm/include/llvm/Analysis/DomConditionCache.h
@@ -18,18 +18,34 @@
#define LLVM_ANALYSIS_DOMCONDITIONCACHE_H
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitmaskEnum.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
+#include <cstdint>
namespace llvm {
class Value;
class BranchInst;
+/// Bit flags recorded for each (value, branch) pair in DomConditionCache and
+/// passed to the callback of findValuesAffectedByCondition. They say which
+/// analyses may derive facts about the value from the branch condition, so
+/// each consumer can skip conditions it cannot use.
+enum class DomConditionFlag : uint8_t {
+ /// No analysis is interested in this value/condition pair.
+ None = 0,
+ /// Checked by computeKnownBitsFromContext.
+ KnownBits = 1 << 0,
+ /// Checked by computeKnownFPClassFromContext.
+ KnownFPClass = 1 << 1,
+ /// Checked by isKnownToBeAPowerOfTwo.
+ PowerOfTwo = 1 << 2,
+ /// The value is an icmp operand in the condition; checked by InstCombine's
+ /// foldICmpWithDominatingICmp.
+ ICmp = 1 << 3,
+};
+
+// NOTE: LargestValue must name the highest flag; update it when adding flags.
+LLVM_DECLARE_ENUM_AS_BITMASK(
+ DomConditionFlag,
+ /*LargestValue=*/static_cast<uint8_t>(DomConditionFlag::ICmp));
+
class DomConditionCache {
private:
/// A map of values about which a branch might be providing information.
- using AffectedValuesMap = DenseMap<Value *, SmallVector<BranchInst *, 1>>;
+ using AffectedValuesMap =
+ DenseMap<Value *,
+ SmallVector<std::pair<BranchInst *, DomConditionFlag>, 1>>;
AffectedValuesMap AffectedValues;
public:
@@ -40,10 +56,11 @@ class DomConditionCache {
void removeValue(Value *V) { AffectedValues.erase(V); }
/// Access the list of branches which affect this value.
- ArrayRef<BranchInst *> conditionsFor(const Value *V) const {
+ ArrayRef<std::pair<BranchInst *, DomConditionFlag>>
+ conditionsFor(const Value *V) const {
auto AVI = AffectedValues.find_as(const_cast<Value *>(V));
if (AVI == AffectedValues.end())
- return ArrayRef<BranchInst *>();
+ return {};
return AVI->second;
}
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index bd74d27e0c49b1..c887c0b1603e4a 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -14,13 +14,14 @@
#ifndef LLVM_ANALYSIS_VALUETRACKING_H
#define LLVM_ANALYSIS_VALUETRACKING_H
+#include "llvm/Analysis/DomConditionCache.h"
#include "llvm/Analysis/SimplifyQuery.h"
#include "llvm/Analysis/WithCache.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/FMF.h"
-#include "llvm/IR/Instructions.h"
#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include <cassert>
#include <cstdint>
@@ -1275,8 +1276,12 @@ std::optional<bool> isImpliedByDomCondition(CmpInst::Predicate Pred,
/// Call \p InsertAffected on all Values whose known bits / value may be
-/// affected by the condition \p Cond. Used by AssumptionCache and
-/// DomConditionCache.
+/// affected by the condition \p Cond. Each call also passes a
+/// DomConditionFlag mask naming the analyses that may derive information
+/// about that Value from \p Cond, so callers can skip irrelevant conditions.
+/// The same Value may be reported more than once with different flags.
+/// Used by AssumptionCache, DomConditionCache and InstCombine.
-void findValuesAffectedByCondition(Value *Cond, bool IsAssume,
- function_ref<void(Value *)> InsertAffected);
+void findValuesAffectedByCondition(
+ Value *Cond, bool IsAssume,
+ function_ref<void(Value *, DomConditionFlag)> InsertAffected);
} // end namespace llvm
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
index a0e57ab741dfa8..2a5d742df1f6eb 100644
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -59,7 +59,8 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
// Note: This code must be kept in-sync with the code in
// computeKnownBitsFromAssume in ValueTracking.
- auto InsertAffected = [&Affected](Value *V) {
+ // TODO: Use DomConditionFlag to filter out non-interesting conditions.
+ auto InsertAffected = [&Affected](Value *V, DomConditionFlag) {
Affected.push_back({V, AssumptionCache::ExprResultIdx});
};
diff --git a/llvm/lib/Analysis/DomConditionCache.cpp b/llvm/lib/Analysis/DomConditionCache.cpp
index 66bd15b47901d7..345b2e22a687ba 100644
--- a/llvm/lib/Analysis/DomConditionCache.cpp
+++ b/llvm/lib/Analysis/DomConditionCache.cpp
@@ -10,19 +10,30 @@
#include "llvm/Analysis/ValueTracking.h"
using namespace llvm;
-static void findAffectedValues(Value *Cond,
- SmallVectorImpl<Value *> &Affected) {
- auto InsertAffected = [&Affected](Value *V) { Affected.push_back(V); };
+/// Append to \p Affected every value that the branch condition \p Cond may
+/// provide information about (non-assume mode), each paired with the
+/// DomConditionFlag mask reported by findValuesAffectedByCondition.
+static void findAffectedValues(
+ Value *Cond,
+ SmallVectorImpl<std::pair<Value *, DomConditionFlag>> &Affected) {
+ auto InsertAffected = [&Affected](Value *V, DomConditionFlag Flags) {
+ Affected.push_back({V, Flags});
+ };
findValuesAffectedByCondition(Cond, /*IsAssume=*/false, InsertAffected);
}
void DomConditionCache::registerBranch(BranchInst *BI) {
assert(BI->isConditional() && "Must be conditional branch");
- SmallVector<Value *, 16> Affected;
+ SmallVector<std::pair<Value *, DomConditionFlag>, 16> Affected;
findAffectedValues(BI->getCondition(), Affected);
- for (Value *V : Affected) {
+ for (auto [V, Flags] : Affected) {
auto &AV = AffectedValues[V];
- if (!is_contained(AV, BI))
- AV.push_back(BI);
+ // One condition may report the same value several times (e.g. from both
+ // sides of a logical and/or), so keep a single entry per branch and merge
+ // the flags into it.
+ auto It = llvm::find_if(
+ AV, [BI](const auto &Entry) { return Entry.first == BI; });
+ if (It != AV.end())
+ It->second |= Flags;
+ else
+ AV.push_back({BI, Flags});
}
}
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index d81546d0c9fedc..8d63c0d2508a9a 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -790,7 +790,9 @@ void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
if (Q.DC && Q.DT) {
// Handle dominating conditions.
- for (BranchInst *BI : Q.DC->conditionsFor(V)) {
+ for (auto [BI, Flag] : Q.DC->conditionsFor(V)) {
+ if (!any(Flag & DomConditionFlag::KnownBits))
+ continue;
BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
computeKnownBitsFromCond(V, BI->getCondition(), Known, Depth, Q,
@@ -2299,7 +2301,9 @@ bool llvm::isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth,
// Handle dominating conditions.
if (Q.DC && Q.CxtI && Q.DT) {
- for (BranchInst *BI : Q.DC->conditionsFor(V)) {
+ for (auto [BI, Flag] : Q.DC->conditionsFor(V)) {
+ if (!any(Flag & DomConditionFlag::PowerOfTwo))
+ continue;
Value *Cond = BI->getCondition();
BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
@@ -4930,7 +4934,9 @@ static KnownFPClass computeKnownFPClassFromContext(const Value *V,
if (Q.DC && Q.DT) {
// Handle dominating conditions.
- for (BranchInst *BI : Q.DC->conditionsFor(V)) {
+ for (auto [BI, Flag] : Q.DC->conditionsFor(V)) {
+ if (!any(Flag & DomConditionFlag::KnownFPClass))
+ continue;
Value *Cond = BI->getCondition();
BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
@@ -10014,36 +10020,38 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
return CR;
}
-static void
-addValueAffectedByCondition(Value *V,
- function_ref<void(Value *)> InsertAffected) {
+/// Report \p V to \p InsertAffected, tagged with \p Flags, if it is an
+/// argument, global value or instruction. For a ptrtoint or trunc
+/// instruction, also report its operand with the same \p Flags when that
+/// operand is an instruction or argument.
+static void addValueAffectedByCondition(
+ Value *V, function_ref<void(Value *, DomConditionFlag)> InsertAffected,
+ DomConditionFlag Flags) {
assert(V != nullptr);
if (isa<Argument>(V) || isa<GlobalValue>(V)) {
- InsertAffected(V);
+ InsertAffected(V, Flags);
} else if (auto *I = dyn_cast<Instruction>(V)) {
- InsertAffected(V);
+ InsertAffected(V, Flags);
// Peek through unary operators to find the source of the condition.
Value *Op;
if (match(I, m_CombineOr(m_PtrToInt(m_Value(Op)), m_Trunc(m_Value(Op))))) {
if (isa<Instruction>(Op) || isa<Argument>(Op))
- InsertAffected(Op);
+ InsertAffected(Op, Flags);
}
}
}
void llvm::findValuesAffectedByCondition(
- Value *Cond, bool IsAssume, function_ref<void(Value *)> InsertAffected) {
- auto AddAffected = [&InsertAffected](Value *V) {
- addValueAffectedByCondition(V, InsertAffected);
+ Value *Cond, bool IsAssume,
+ function_ref<void(Value *, DomConditionFlag)> InsertAffected) {
+ auto AddAffected = [&InsertAffected](Value *V, DomConditionFlag Flags) {
+ addValueAffectedByCondition(V, InsertAffected, Flags);
};
- auto AddCmpOperands = [&AddAffected, IsAssume](Value *LHS, Value *RHS) {
+ auto AddCmpOperands = [&AddAffected, IsAssume](Value *LHS, Value *RHS,
+ DomConditionFlag Flags) {
if (IsAssume) {
- AddAffected(LHS);
- AddAffected(RHS);
+ AddAffected(LHS, Flags);
+ AddAffected(RHS, Flags);
} else if (match(RHS, m_Constant()))
- AddAffected(LHS);
+ AddAffected(LHS, Flags);
};
SmallVector<Value *, 8> Worklist;
@@ -10058,9 +10066,9 @@ void llvm::findValuesAffectedByCondition(
Value *A, *B, *X;
if (IsAssume) {
- AddAffected(V);
+ AddAffected(V, DomConditionFlag::KnownBits);
if (match(V, m_Not(m_Value(X))))
- AddAffected(X);
+ AddAffected(X, DomConditionFlag::KnownBits);
}
if (match(V, m_LogicalOp(m_Value(A), m_Value(B)))) {
@@ -10074,7 +10082,8 @@ void llvm::findValuesAffectedByCondition(
Worklist.push_back(B);
}
} else if (match(V, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
- AddCmpOperands(A, B);
+ AddCmpOperands(A, B,
+ DomConditionFlag::KnownBits | DomConditionFlag::ICmp);
bool HasRHSC = match(B, m_ConstantInt());
if (ICmpInst::isEquality(Pred)) {
@@ -10084,11 +10093,11 @@ void llvm::findValuesAffectedByCondition(
// (X << C) or (X >>_s C) or (X >>_u C).
if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
match(A, m_Shift(m_Value(X), m_ConstantInt())))
- AddAffected(X);
+ AddAffected(X, DomConditionFlag::KnownBits);
else if (match(A, m_And(m_Value(X), m_Value(Y))) ||
match(A, m_Or(m_Value(X), m_Value(Y)))) {
- AddAffected(X);
- AddAffected(Y);
+ AddAffected(X, DomConditionFlag::KnownBits);
+ AddAffected(Y, DomConditionFlag::KnownBits);
}
}
} else {
@@ -10096,7 +10105,7 @@ void llvm::findValuesAffectedByCondition(
// Handle (A + C1) u< C2, which is the canonical form of
// A > C3 && A < C4.
if (match(A, m_AddLike(m_Value(X), m_ConstantInt())))
- AddAffected(X);
+ AddAffected(X, DomConditionFlag::KnownBits);
if (ICmpInst::isUnsigned(Pred)) {
Value *Y;
@@ -10106,12 +10115,12 @@ void llvm::findValuesAffectedByCondition(
if (match(A, m_And(m_Value(X), m_Value(Y))) ||
match(A, m_Or(m_Value(X), m_Value(Y))) ||
match(A, m_NUWAdd(m_Value(X), m_Value(Y)))) {
- AddAffected(X);
- AddAffected(Y);
+ AddAffected(X, DomConditionFlag::KnownBits);
+ AddAffected(Y, DomConditionFlag::KnownBits);
}
// X nuw- Y u> C -> X u> C
if (match(A, m_NUWSub(m_Value(X), m_Value())))
- AddAffected(X);
+ AddAffected(X, DomConditionFlag::KnownBits);
}
}
@@ -10119,29 +10128,29 @@ void llvm::findValuesAffectedByCondition(
// by computeKnownFPClass().
if (match(A, m_ElementWiseBitCast(m_Value(X)))) {
if (Pred == ICmpInst::ICMP_SLT && match(B, m_Zero()))
- InsertAffected(X);
+ InsertAffected(X, DomConditionFlag::KnownFPClass);
else if (Pred == ICmpInst::ICMP_SGT && match(B, m_AllOnes()))
- InsertAffected(X);
+ InsertAffected(X, DomConditionFlag::KnownFPClass);
}
}
if (HasRHSC && match(A, m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))
- AddAffected(X);
+ AddAffected(X, DomConditionFlag::PowerOfTwo);
} else if (match(V, m_FCmp(Pred, m_Value(A), m_Value(B)))) {
- AddCmpOperands(A, B);
+ AddCmpOperands(A, B, DomConditionFlag::KnownFPClass);
// fcmp fneg(x), y
// fcmp fabs(x), y
// fcmp fneg(fabs(x)), y
if (match(A, m_FNeg(m_Value(A))))
- AddAffected(A);
+ AddAffected(A, DomConditionFlag::KnownFPClass);
if (match(A, m_FAbs(m_Value(A))))
- AddAffected(A);
+ AddAffected(A, DomConditionFlag::KnownFPClass);
} else if (match(V, m_Intrinsic<Intrinsic::is_fpclass>(m_Value(A),
m_Value()))) {
// Handle patterns that computeKnownFPClass() support.
- AddAffected(A);
+ AddAffected(A, DomConditionFlag::KnownFPClass);
}
}
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index fed21db393ed22..5f635cc41a94f7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1385,7 +1385,9 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
return nullptr;
};
- for (BranchInst *BI : DC.conditionsFor(X)) {
+ for (auto [BI, Flags] : DC.conditionsFor(X)) {
+ if (!any(Flags & DomConditionFlag::ICmp))
+ continue;
ICmpInst::Predicate DomPred;
const APInt *DomC;
if (!match(BI->getCondition(),
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index c7a0c35d099cc4..e792190f95e082 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -4293,9 +4293,11 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
(!isa<Constant>(TrueVal) || !isa<Constant>(FalseVal))) {
// Try to simplify select arms based on KnownBits implied by the condition.
CondContext CC(CondVal);
- findValuesAffectedByCondition(CondVal, /*IsAssume=*/false, [&](Value *V) {
- CC.AffectedValues.insert(V);
- });
+ findValuesAffectedByCondition(
+ CondVal, /*IsAssume=*/false, [&](Value *V, DomConditionFlag Flags) {
+ if (any(Flags & DomConditionFlag::KnownBits))
+ CC.AffectedValues.insert(V);
+ });
SimplifyQuery Q = SQ.getWithInstruction(&SI).getWithCondContext(CC);
if (!CC.AffectedValues.empty()) {
if (!isa<Constant>(TrueVal) &&
More information about the llvm-commits
mailing list