[llvm] [Semilattice] Introduce for dataflow analysis with KnownBits (PR #177616)
Ramkumar Ramachandra via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 9 08:21:22 PDT 2026
https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/177616
>From e67faca01738944716f112c293e0485c5744c73b Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <artagnon at tenstorrent.com>
Date: Tue, 20 Jan 2026 02:57:33 +0000
Subject: [PATCH 1/4] [KnownBitsAnalysis] Introduce KnownBitsDataflow
Introduce the first component of KnownBitsAnalysis, KnownBitsDataflow,
that caches KnownBits for an entire function, with an invalidation API
that invalidates in a dataflow-dependent manner. There are no APIs to
modify the created dataflow graph yet.
---
.../include/llvm/Analysis/KnownBitsAnalysis.h | 100 +++++
llvm/include/llvm/Support/KnownBits.h | 8 +
llvm/lib/Analysis/CMakeLists.txt | 1 +
llvm/lib/Analysis/KnownBitsAnalysis.cpp | 150 +++++++
llvm/unittests/Analysis/CMakeLists.txt | 1 +
.../Analysis/KnownBitsAnalysisTest.cpp | 371 ++++++++++++++++++
6 files changed, 631 insertions(+)
create mode 100644 llvm/include/llvm/Analysis/KnownBitsAnalysis.h
create mode 100644 llvm/lib/Analysis/KnownBitsAnalysis.cpp
create mode 100644 llvm/unittests/Analysis/KnownBitsAnalysisTest.cpp
diff --git a/llvm/include/llvm/Analysis/KnownBitsAnalysis.h b/llvm/include/llvm/Analysis/KnownBitsAnalysis.h
new file mode 100644
index 0000000000000..f9a0baaf2dd1b
--- /dev/null
+++ b/llvm/include/llvm/Analysis/KnownBitsAnalysis.h
@@ -0,0 +1,100 @@
+//===- KnownBitsAnalysis.h - An analysis that caches KnownBits ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Used to cache KnownBits for the entire function, with dataflow-dependent
+// invalidation.
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_KNOWNBITSANALYSIS_H
+#define LLVM_ANALYSIS_KNOWNBITSANALYSIS_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/KnownBits.h"
+
+namespace llvm {
+class Value;
+class User;
+class Function;
+class DataLayout;
+class raw_ostream;
+
+class KnownBitsDataflow : protected DenseMap<const Value *, KnownBits> {
+  /// The roots are the arguments of the function, plus PHI nodes and
+  /// Instructions like fptosi in each Basic Block, filtered to integer,
+  /// pointer, or vector-of-integer/pointer types.
+ SmallVector<const Value *> Roots;
+
+ SmallVector<const Value *> getLeaves(ArrayRef<const Value *> Roots) const;
+ void emplace_all_conflict(const Value *V); // NOLINT
+
+ template <typename RangeT>
+ SmallVector<Value *> insert_range(RangeT R); // NOLINT
+
+ void recurseInsertChildren(ArrayRef<Value *> R);
+ template <typename RangeT> void initialize(RangeT R);
+ void initialize(Function &F);
+
+protected:
+ const DataLayout &DL;
+ LLVM_ABI_FOR_TEST KnownBits &getKnownBits(const Value *V) {
+ return operator[](V);
+ }
+ LLVM_ABI_FOR_TEST void setAllConflict(const Value *V) {
+ assert(contains(V) && "Expected Value in map");
+ getKnownBits(V).setAllConflict();
+ }
+ LLVM_ABI_FOR_TEST void setAllZero(const Value *V) {
+ assert(contains(V) && "Expected Value in map");
+ getKnownBits(V).setAllZero();
+ }
+ LLVM_ABI_FOR_TEST void setAllOnes(const Value *V) {
+ assert(contains(V) && "Expected Value in map");
+ getKnownBits(V).setAllOnes();
+ }
+ LLVM_ABI_FOR_TEST bool isAllConflict(const Value *V) const {
+ return at(V).isAllConflict();
+ }
+ LLVM_ABI_FOR_TEST ArrayRef<const Value *> getRoots() const;
+ LLVM_ABI_FOR_TEST SmallVector<const Value *> getLeaves() const;
+ LLVM_ABI_FOR_TEST void intersectWith(const Value *V, KnownBits Known) {
+ assert(contains(V) && "Expected Value in map");
+ KnownBits &K = getKnownBits(V);
+ K = K.intersectWith(Known);
+ }
+
+public:
+ LLVM_ABI KnownBitsDataflow(Function &F);
+ LLVM_ABI KnownBitsDataflow(const KnownBitsDataflow &) = delete;
+ LLVM_ABI KnownBitsDataflow &operator=(const KnownBitsDataflow &) = delete;
+
+ LLVM_ABI bool empty() const {
+ return DenseMap<const Value *, KnownBits>::empty();
+ }
+ LLVM_ABI size_t size() const {
+ return DenseMap<const Value *, KnownBits>::size();
+ }
+ LLVM_ABI bool contains(const Value *V) const {
+ return DenseMap<const Value *, KnownBits>::contains(V);
+ }
+ LLVM_ABI KnownBits at(const Value *V) const {
+ return DenseMap<const Value *, KnownBits>::at(V);
+ }
+
+ /// Invalidates KnownBits corresponding to \p V, and all dependent values in
+ /// dataflow, and returns the invalidated leaves.
+ LLVM_ABI SmallVector<const Value *> invalidate(const Value *V);
+
+ LLVM_ABI void print(raw_ostream &OS) const;
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+} // end namespace llvm
+
+#endif // LLVM_ANALYSIS_KNOWNBITSANALYSIS_H
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 2ac4d330714a1..647d45eec3001 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -40,6 +40,11 @@ struct KnownBits {
/// Create a known bits object of BitWidth bits initialized to unknown.
KnownBits(unsigned BitWidth) : Zero(BitWidth, 0), One(BitWidth, 0) {}
+ /// Create a known bits object of BitWidth bits initialized to AllConflict.
+ static KnownBits getAllConflict(unsigned BitWidth) {
+ return {APInt::getAllOnes(BitWidth), APInt::getAllOnes(BitWidth)};
+ }
+
/// Get the bit width of this value.
unsigned getBitWidth() const {
assert(Zero.getBitWidth() == One.getBitWidth() &&
@@ -65,6 +70,9 @@ struct KnownBits {
/// Returns true if we don't know any bits.
bool isUnknown() const { return Zero.isZero() && One.isZero(); }
+ /// Returns true if all bits conflict.
+ bool isAllConflict() const { return Zero.isAllOnes() && One.isAllOnes(); }
+
/// Returns true if we don't know the sign bit.
bool isSignUnknown() const {
return !Zero.isSignBitSet() && !One.isSignBitSet();
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
index f3586c66cb056..75def0dd0e323 100644
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -98,6 +98,7 @@ add_llvm_component_library(LLVMAnalysis
InstructionSimplify.cpp
InteractiveModelRunner.cpp
KernelInfo.cpp
+ KnownBitsAnalysis.cpp
LastRunTrackingAnalysis.cpp
LazyBranchProbabilityInfo.cpp
LazyBlockFrequencyInfo.cpp
diff --git a/llvm/lib/Analysis/KnownBitsAnalysis.cpp b/llvm/lib/Analysis/KnownBitsAnalysis.cpp
new file mode 100644
index 0000000000000..ba6246f2a9b19
--- /dev/null
+++ b/llvm/lib/Analysis/KnownBitsAnalysis.cpp
@@ -0,0 +1,150 @@
+//===- KnownBitsAnalysis.cpp - An analysis that caches KnownBits ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/KnownBitsAnalysis.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/GraphTraits.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/KnownBits.h"
+
+using namespace llvm;
+
+template <typename NodeRef, typename ChildIteratorType>
+struct NodeGraphTraitsBase {
+ static NodeRef getEntryNode(NodeRef N) { return N; }
+ static ChildIteratorType child_begin(NodeRef N) { // NOLINT
+ return N->user_begin();
+ }
+ static ChildIteratorType child_end(NodeRef N) { // NOLINT
+ return N->user_end();
+ }
+};
+
+template <>
+struct GraphTraits<Value *>
+ : public NodeGraphTraitsBase<Value *, Value::user_iterator> {
+ using NodeRef = Value *;
+ using ChildIteratorType = Value::user_iterator;
+};
+
+template <>
+struct GraphTraits<const Value *>
+ : public NodeGraphTraitsBase<const Value *, Value::const_user_iterator> {
+ using NodeRef = const Value *;
+ using ChildIteratorType = Value::const_user_iterator;
+};
+
+/// A wrapper around make_filter_range that keeps only values of integer,
+/// pointer, or vector-of-integer/pointer type, excluding those matching \p ExcludeFn.
+template <typename RangeT>
+static auto filter_range( // NOLINT
+ RangeT R, std::function<bool(const Value *)> ExcludeFn = [](const Value *) {
+ return false;
+ }) {
+ return make_filter_range(R, [&](const Value *V) {
+ return !ExcludeFn(V) && V->getType()->getScalarType()->isIntOrPtrTy();
+ });
+}
+
+static bool isLeaf(const Value *V) { return filter_range(V->users()).empty(); }
+
+ArrayRef<const Value *> KnownBitsDataflow::getRoots() const { return Roots; }
+
+SmallVector<const Value *>
+KnownBitsDataflow::getLeaves(ArrayRef<const Value *> Roots) const {
+ SetVector<const Value *> Leaves;
+ for (const Value *R : Roots)
+ for (const Value *N : filter_range(depth_first(R)))
+ if (isLeaf(N))
+ Leaves.insert(N);
+ return Leaves.takeVector();
+}
+
+SmallVector<const Value *> KnownBitsDataflow::getLeaves() const {
+ return getLeaves(getRoots());
+}
+
+void KnownBitsDataflow::emplace_all_conflict(const Value *V) {
+ emplace_or_assign(V, KnownBits::getAllConflict(DL.getTypeSizeInBits(
+ V->getType()->getScalarType())));
+}
+
+template <typename RangeT>
+SmallVector<Value *> KnownBitsDataflow::insert_range(RangeT R) {
+ SmallVector<Value *> Filtered(
+ filter_range(R, bind_front(&KnownBitsDataflow::contains, this)));
+ for (Value *V : Filtered)
+ emplace_all_conflict(V);
+ return Filtered;
+}
+
+void KnownBitsDataflow::recurseInsertChildren(ArrayRef<Value *> R) {
+ for (Value *V : R)
+ recurseInsertChildren(insert_range(
+ map_range(V->users(), [](User *U) -> Value * { return U; })));
+}
+
+template <typename RangeT> void KnownBitsDataflow::initialize(RangeT R) {
+ for (auto *V : R) {
+ emplace_all_conflict(V);
+ recurseInsertChildren(V);
+ }
+ append_range(Roots, R);
+}
+
+void KnownBitsDataflow::initialize(Function &F) {
+ // First, initialize with function arguments.
+ initialize(filter_range(llvm::make_pointer_range(F.args())));
+ for (BasicBlock &BB : F) {
+ // Now initialize with all Instructions in the BB that weren't seen.
+ initialize(filter_range(llvm::make_pointer_range(BB),
+ bind_front(&KnownBitsDataflow::contains, this)));
+ }
+}
+
+KnownBitsDataflow::KnownBitsDataflow(Function &F) : DL(F.getDataLayout()) {
+ initialize(F);
+}
+
+SmallVector<const Value *> KnownBitsDataflow::invalidate(const Value *V) {
+ SmallVector<const Value *> Leaves;
+ for (const Value *N : filter_range(depth_first(V))) {
+ setAllConflict(N);
+ if (isLeaf(N))
+ Leaves.push_back(N);
+ }
+ return Leaves;
+}
+
+void KnownBitsDataflow::print(raw_ostream &OS) const {
+ for (const Value *R : getRoots()) {
+ OS << "^ ";
+ R->print(OS);
+ OS << " | ";
+ at(R).print(OS);
+ OS << "\n";
+ for (const Value *V : filter_range(drop_begin(depth_first(R)))) {
+ if (isLeaf(V))
+ OS << "$ ";
+ else
+ OS << " ";
+ V->print(OS);
+ OS << " | ";
+ at(V).print(OS);
+ OS << "\n";
+ }
+ }
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+LLVM_DUMP_METHOD void KnownBitsDataflow::dump() const { print(dbgs()); }
+#endif
diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
index 50bf4539e7984..51147f83da1b8 100644
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -36,6 +36,7 @@ set(ANALYSIS_TEST_SOURCES
IR2VecTest.cpp
IRSimilarityIdentifierTest.cpp
IVDescriptorsTest.cpp
+ KnownBitsAnalysisTest.cpp
LastRunTrackingAnalysisTest.cpp
LazyCallGraphTest.cpp
LoadsTest.cpp
diff --git a/llvm/unittests/Analysis/KnownBitsAnalysisTest.cpp b/llvm/unittests/Analysis/KnownBitsAnalysisTest.cpp
new file mode 100644
index 0000000000000..ea1539f05b538
--- /dev/null
+++ b/llvm/unittests/Analysis/KnownBitsAnalysisTest.cpp
@@ -0,0 +1,371 @@
+//===- KnownBitsAnalysisTest.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/KnownBitsAnalysis.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+static Instruction *findInstructionByName(Function *F, StringRef Name) {
+ for (Instruction &I : instructions(F))
+ if (I.getName() == Name)
+ return &I;
+ return nullptr;
+}
+
+std::unique_ptr<Module> parseIR(LLVMContext &Ctx, StringRef Assembly) {
+ SMDiagnostic Err;
+ std::unique_ptr<Module> M = parseAssemblyString(Assembly, Err, Ctx);
+ if (!M)
+ Err.print(__FILE__, errs());
+ return M;
+}
+
+/// KnownBitsDataflow tests follow.
+struct KnownBitsDataflowForTest : public KnownBitsDataflow {
+ KnownBitsDataflowForTest(Function &F) : KnownBitsDataflow(F) {}
+ ArrayRef<const Value *> getRoots() const {
+ return KnownBitsDataflow::getRoots();
+ }
+ SmallVector<const Value *> getLeaves() const {
+ return KnownBitsDataflow::getLeaves();
+ }
+ void intersectWith(const Value *V, KnownBits Known) {
+ return KnownBitsDataflow::intersectWith(V, Known);
+ }
+ void setAllZero(const Value *V) { return KnownBitsDataflow::setAllZero(V); }
+ void setAllOnes(const Value *V) { return KnownBitsDataflow::setAllOnes(V); }
+ bool isAllConflict(const Value *V) const {
+ return KnownBitsDataflow::isAllConflict(V);
+ }
+ bool isAllOnes(const Value *V) const {
+ KnownBits Known = at(V);
+ return Known.Zero.isZero() && Known.One.isAllOnes();
+ }
+};
+
+TEST(KnownBitsDataflow, BasicConstruction) {
+ LLVMContext Ctx;
+ std::unique_ptr<Module> M = parseIR(Ctx, R"(
+define void @test(i32 %n) {
+entry:
+ br label %loop
+loop:
+ %phi_counter = phi i32 [ 0, %entry ], [ %next_counter, %loop ]
+ %phi_result = phi i32 [ 1, %entry ], [ %result, %loop ]
+ %counter = add i32 %phi_counter, 1
+ %result = mul i32 %phi_result, 2
+ %next_counter = add i32 %counter, 1
+ %cond = icmp slt i32 %next_counter, %n
+ br i1 %cond, label %loop, label %exit
+exit:
+ store i32 %result, ptr poison
+ ret void
+})");
+ Function *F = M->getFunction("test");
+ KnownBitsDataflowForTest Lat(*F);
+ Argument *ArgN = &*F->arg_begin();
+ Instruction *Counter = findInstructionByName(F, "counter");
+ Instruction *NextCounter = findInstructionByName(F, "next_counter");
+ Instruction *Result = findInstructionByName(F, "result");
+ Instruction *PhiCounter = findInstructionByName(F, "phi_counter");
+ Instruction *PhiResult = findInstructionByName(F, "phi_result");
+ Instruction *Cond = findInstructionByName(F, "cond");
+
+ EXPECT_TRUE(Lat.isAllConflict(ArgN));
+ EXPECT_TRUE(Lat.isAllConflict(PhiCounter));
+ EXPECT_TRUE(Lat.isAllConflict(PhiResult));
+ EXPECT_TRUE(Lat.isAllConflict(Counter));
+ EXPECT_TRUE(Lat.isAllConflict(Result));
+ EXPECT_TRUE(Lat.isAllConflict(NextCounter));
+ EXPECT_TRUE(Lat.isAllConflict(Cond));
+ EXPECT_EQ(Lat.size(), 7u);
+}
+
+TEST(KnownBitsDataflow, ConstructionWithIntAndPtr) {
+ LLVMContext Ctx;
+ std::unique_ptr<Module> M = parseIR(Ctx, R"(
+define void @test(i32 %int_arg, float %float_arg, ptr %ptr_arg, <2 x i32> %vec_int_arg, <2 x ptr> %vec_ptr_arg) {
+entry:
+ br i1 poison, label %then, label %else
+then:
+ %int_val = add i32 %int_arg, 1
+ %float_val = fadd float %float_arg, 1.0
+ %vec_val = add <2 x i32> %vec_int_arg, <i32 1, i32 2>
+ br label %merge
+else:
+ %fpconv = fptoui float %float_arg to i32
+ %int_val2 = mul i32 %int_arg, %fpconv
+ %ptr_val = getelementptr i8, ptr %ptr_arg, i32 4
+ %vec_val2 = mul <2 x i32> %vec_int_arg, <i32 3, i32 4>
+ br label %merge
+merge:
+ %phi_int = phi i32 [ %int_val, %then ], [ %int_val2, %else ]
+ %phi_float = phi float [ %float_val, %then ], [ %float_arg, %else ]
+ %phi_ptr = phi ptr [ %ptr_arg, %then ], [ %ptr_val, %else ]
+ %phi_vec = phi <2 x i32> [ %vec_val, %then ], [ %vec_val2, %else ]
+ %final_int = add i32 %phi_int, 5
+ %vec_ptr_conv = ptrtoint <2 x ptr> %vec_ptr_arg to <2 x i32>
+ %final_vec = add <2 x i32> %phi_vec, %vec_ptr_conv
+ store float %phi_float, ptr %phi_ptr
+ ret void
+})");
+ Function *F = M->getFunction("test");
+ KnownBitsDataflowForTest Lat(*F);
+ auto *ArgIt = F->arg_begin();
+ Argument *IntArg = &*ArgIt++;
+ Argument *FloatArg = &*ArgIt++;
+ Argument *PtrArg = &*ArgIt++;
+ Argument *VecIntArg = &*ArgIt++;
+ Argument *VecPtrArg = &*ArgIt++;
+ Instruction *IntVal = findInstructionByName(F, "int_val");
+ Instruction *FloatVal = findInstructionByName(F, "float_val");
+ Instruction *VecVal = findInstructionByName(F, "vec_val");
+ Instruction *IntVal2 = findInstructionByName(F, "int_val2");
+ Instruction *PtrVal = findInstructionByName(F, "ptr_val");
+ Instruction *VecVal2 = findInstructionByName(F, "vec_val2");
+ Instruction *PhiInt = findInstructionByName(F, "phi_int");
+ Instruction *PhiFloat = findInstructionByName(F, "phi_float");
+ Instruction *PhiPtr = findInstructionByName(F, "phi_ptr");
+ Instruction *PhiVec = findInstructionByName(F, "phi_vec");
+ Instruction *FinalInt = findInstructionByName(F, "final_int");
+ Instruction *FPConv = findInstructionByName(F, "fpconv");
+ Instruction *VecPtrConv = findInstructionByName(F, "vec_ptr_conv");
+ Instruction *FinalVec = findInstructionByName(F, "final_vec");
+
+ EXPECT_THAT(Lat.getRoots(), ::testing::ElementsAre(IntArg, PtrArg, VecIntArg,
+ VecPtrArg, FPConv));
+ EXPECT_THAT(Lat.getLeaves(),
+ ::testing::ElementsAre(FinalInt, PhiPtr, FinalVec));
+
+ EXPECT_TRUE(Lat.isAllConflict(IntArg));
+ EXPECT_FALSE(Lat.contains(FloatArg));
+ EXPECT_TRUE(Lat.isAllConflict(PtrArg));
+ EXPECT_TRUE(Lat.isAllConflict(VecIntArg));
+ EXPECT_TRUE(Lat.isAllConflict(VecPtrArg));
+
+ EXPECT_TRUE(Lat.isAllConflict(IntVal));
+ EXPECT_TRUE(Lat.isAllConflict(IntVal2));
+ EXPECT_TRUE(Lat.isAllConflict(PhiInt));
+ EXPECT_TRUE(Lat.isAllConflict(FinalInt));
+ EXPECT_TRUE(Lat.isAllConflict(VecVal));
+ EXPECT_TRUE(Lat.isAllConflict(VecVal2));
+ EXPECT_TRUE(Lat.isAllConflict(PhiVec));
+ EXPECT_TRUE(Lat.isAllConflict(FPConv));
+ EXPECT_TRUE(Lat.isAllConflict(VecPtrConv));
+ EXPECT_TRUE(Lat.isAllConflict(FinalVec));
+ EXPECT_FALSE(Lat.contains(FloatVal));
+ EXPECT_TRUE(Lat.isAllConflict(PtrVal));
+ EXPECT_FALSE(Lat.contains(PhiFloat));
+ EXPECT_TRUE(Lat.isAllConflict(PhiPtr));
+
+ EXPECT_EQ(Lat.size(), 16u);
+}
+
+TEST(KnownBitsDataflow, ConstructionWithNestedLoop) {
+ LLVMContext Ctx;
+ std::unique_ptr<Module> M = parseIR(Ctx, R"(
+define void @test(i32 %n, i32 %m) {
+entry:
+ br label %outer_loop
+outer_loop:
+ %outer_phi = phi i32 [ 0, %entry ], [ %outer_next, %outer_latch ]
+ br label %inner_loop
+inner_loop:
+ %inner_phi = phi i32 [ 0, %outer_loop ], [ %inner_next, %inner_loop ]
+ %inner_next = add i32 %inner_phi, 1
+ %inner_cond = icmp slt i32 %inner_next, %m
+ br i1 %inner_cond, label %inner_loop, label %outer_latch
+outer_latch:
+ %outer_next = add i32 %outer_phi, 1
+ %outer_cond = icmp slt i32 %outer_next, %n
+ br i1 %outer_cond, label %outer_loop, label %exit
+exit:
+ ret void
+})");
+ Function *F = M->getFunction("test");
+ KnownBitsDataflowForTest Lat(*F);
+ auto *ArgIt = F->arg_begin();
+ Argument *ArgN = &*ArgIt++;
+ Argument *ArgM = &*ArgIt;
+ Instruction *OuterPHI = findInstructionByName(F, "outer_phi");
+ Instruction *InnerPHI = findInstructionByName(F, "inner_phi");
+ Instruction *InnerNext = findInstructionByName(F, "inner_next");
+ Instruction *OuterNext = findInstructionByName(F, "outer_next");
+ Instruction *InnerCond = findInstructionByName(F, "inner_cond");
+ Instruction *OuterCond = findInstructionByName(F, "outer_cond");
+
+ EXPECT_THAT(Lat.getRoots(),
+ ::testing::ElementsAre(ArgN, ArgM, OuterPHI, InnerPHI));
+ EXPECT_THAT(Lat.getLeaves(), ::testing::ElementsAre(OuterCond, InnerCond));
+
+ EXPECT_TRUE(Lat.isAllConflict(ArgN));
+ EXPECT_TRUE(Lat.isAllConflict(ArgM));
+ EXPECT_TRUE(Lat.isAllConflict(OuterPHI));
+ EXPECT_TRUE(Lat.isAllConflict(InnerPHI));
+ EXPECT_TRUE(Lat.isAllConflict(InnerNext));
+ EXPECT_TRUE(Lat.isAllConflict(OuterNext));
+ EXPECT_TRUE(Lat.isAllConflict(InnerCond));
+ EXPECT_TRUE(Lat.isAllConflict(OuterCond));
+ EXPECT_EQ(Lat.size(), 8u);
+}
+
+TEST(KnownBitsDataflow, InvalidateKnownBitsSingleBB) {
+ LLVMContext Ctx;
+ std::unique_ptr<Module> M = parseIR(Ctx, R"(
+define void @test(i32 %arg, <2 x i32> %vec_arg) {
+ %counter = add i32 %arg, 1
+ %result = mul i32 %counter, 2
+ %next_counter = add i32 %result, 3
+ %branch_val = sub i32 %next_counter, 1
+ %merge_val = add i32 %branch_val, 5
+ store i32 %merge_val, ptr poison
+ ret void
+})");
+ Function *F = M->getFunction("test");
+ KnownBitsDataflowForTest Lat(*F);
+ Argument *Arg = &*F->arg_begin();
+  Argument *VecArg = &*std::next(F->arg_begin());
+ Instruction *Counter = findInstructionByName(F, "counter");
+ Instruction *NextCounter = findInstructionByName(F, "next_counter");
+ Instruction *Result = findInstructionByName(F, "result");
+ Instruction *BranchVal = findInstructionByName(F, "branch_val");
+ Instruction *MergeVal = findInstructionByName(F, "merge_val");
+ KnownBits Known32(32);
+ Known32.setAllOnes();
+
+ Lat.intersectWith(Arg, Known32);
+ Lat.intersectWith(Counter, Known32);
+ Lat.intersectWith(Result, Known32);
+ Lat.intersectWith(NextCounter, Known32);
+ Lat.intersectWith(BranchVal, Known32);
+ Lat.intersectWith(MergeVal, Known32);
+ Lat.setAllOnes(VecArg);
+
+ EXPECT_TRUE(Lat.isAllOnes(Arg));
+ EXPECT_TRUE(Lat.isAllOnes(Counter));
+ EXPECT_TRUE(Lat.isAllOnes(Result));
+ EXPECT_TRUE(Lat.isAllOnes(NextCounter));
+ EXPECT_TRUE(Lat.isAllOnes(BranchVal));
+ EXPECT_TRUE(Lat.isAllOnes(MergeVal));
+ EXPECT_TRUE(Lat.isAllOnes(VecArg));
+
+ auto InvalidatedLeaves = Lat.invalidate(Counter);
+ EXPECT_THAT(InvalidatedLeaves, ::testing::ElementsAre(MergeVal));
+
+ EXPECT_TRUE(Lat.isAllOnes(Arg));
+ EXPECT_TRUE(Lat.isAllOnes(VecArg));
+ EXPECT_TRUE(Lat.isAllConflict(Counter));
+ EXPECT_TRUE(Lat.isAllConflict(Result));
+ EXPECT_TRUE(Lat.isAllConflict(NextCounter));
+ EXPECT_TRUE(Lat.isAllConflict(BranchVal));
+ EXPECT_TRUE(Lat.isAllConflict(MergeVal));
+}
+
+TEST(KnownBitsDataflow, InvalidateKnownBitsMultipleBBs) {
+ LLVMContext Ctx;
+ std::unique_ptr<Module> M = parseIR(Ctx, R"(
+define void @test(i32 %n, i1 %cond) {
+entry:
+ %counter = add i32 %n, 1
+ br i1 %cond, label %then, label %else
+then:
+ %branch_val = mul i32 %counter, 2
+ br label %merge
+else:
+ %result = add i32 %counter, 3
+ br label %merge
+merge:
+ %merge_val = phi i32 [ %branch_val, %then ], [ %result, %else ]
+ %next_counter = add i32 %merge_val, 1
+ store i32 %next_counter, ptr poison
+ ret void
+})");
+ Function *F = M->getFunction("test");
+ KnownBitsDataflowForTest Lat(*F);
+ auto *ArgIt = F->arg_begin();
+ Argument *ArgN = &*ArgIt++;
+ Argument *ArgCond = &*ArgIt;
+ Instruction *Counter = findInstructionByName(F, "counter");
+ Instruction *NextCounter = findInstructionByName(F, "next_counter");
+ Instruction *Result = findInstructionByName(F, "result");
+ Instruction *BranchVal = findInstructionByName(F, "branch_val");
+ Instruction *MergeVal = findInstructionByName(F, "merge_val");
+ KnownBits Known32(32);
+ Known32.setAllOnes();
+
+ Lat.intersectWith(ArgN, Known32);
+ Lat.intersectWith(Counter, Known32);
+ Lat.intersectWith(BranchVal, Known32);
+ Lat.intersectWith(Result, Known32);
+ Lat.intersectWith(MergeVal, Known32);
+ Lat.intersectWith(NextCounter, Known32);
+ Lat.setAllOnes(ArgCond);
+
+ EXPECT_TRUE(Lat.isAllOnes(Counter));
+ EXPECT_TRUE(Lat.isAllOnes(BranchVal));
+ EXPECT_TRUE(Lat.isAllOnes(Result));
+ EXPECT_TRUE(Lat.isAllOnes(MergeVal));
+ EXPECT_TRUE(Lat.isAllOnes(NextCounter));
+
+ auto InvalidatedLeaves = Lat.invalidate(Result);
+ EXPECT_THAT(InvalidatedLeaves, ::testing::ElementsAre(NextCounter));
+
+ EXPECT_TRUE(Lat.isAllOnes(ArgN));
+ EXPECT_TRUE(Lat.isAllOnes(ArgCond));
+ EXPECT_TRUE(Lat.isAllOnes(Counter));
+ EXPECT_TRUE(Lat.isAllOnes(BranchVal));
+ EXPECT_TRUE(Lat.isAllConflict(Result));
+ EXPECT_TRUE(Lat.isAllConflict(MergeVal));
+ EXPECT_TRUE(Lat.isAllConflict(NextCounter));
+}
+
+TEST(KnownBitsDataflow, Print) {
+ LLVMContext Ctx;
+ std::unique_ptr<Module> M = parseIR(Ctx, R"(
+define void @test(i32 %n) {
+entry:
+ br label %loop
+loop:
+ %phi_counter = phi i32 [ 0, %entry ], [ %next_counter, %loop ]
+ %counter = add i32 %phi_counter, 1
+ %result = mul i32 %counter, 2
+ %next_counter = add i32 %result, 1
+ %cond = icmp slt i32 %next_counter, %n
+ br i1 %cond, label %loop, label %exit
+exit:
+ ret void
+})");
+ Function *F = M->getFunction("test");
+ KnownBitsDataflowForTest Lat(*F);
+ Instruction *Result = findInstructionByName(F, "result");
+ Lat.setAllZero(Result);
+ std::string ActualOutput;
+ raw_string_ostream OS(ActualOutput);
+ Lat.print(OS);
+ std::string ExpectedOutput =
+ R"(^ i32 %n | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+$ %cond = icmp slt i32 %next_counter, %n | !
+^ %phi_counter = phi i32 [ 0, %entry ], [ %next_counter, %loop ] | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ %counter = add i32 %phi_counter, 1 | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ %result = mul i32 %counter, 2 | 00000000000000000000000000000000
+ %next_counter = add i32 %result, 1 | !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+$ %cond = icmp slt i32 %next_counter, %n | !
+)";
+ EXPECT_EQ(ActualOutput, ExpectedOutput);
+}
+} // namespace
>From 035bdb887bce103efe10f9761960c5a34bcc474d Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <artagnon at tenstorrent.com>
Date: Thu, 19 Feb 2026 11:08:48 +0000
Subject: [PATCH 2/4] [KnownBitsCache] Copy functions from VT verbatim
---
.../include/llvm/Analysis/KnownBitsAnalysis.h | 8 +
llvm/include/llvm/Analysis/ValueTracking.h | 92 +-
.../llvm/Analysis/ValueTrackingHelper.h | 150 ++
llvm/lib/Analysis/KnownBitsAnalysis.cpp | 1472 +++++++++++++++++
llvm/lib/Analysis/ValueTracking.cpp | 30 +-
.../Analysis/KnownBitsAnalysisTest.cpp | 70 +
6 files changed, 1715 insertions(+), 107 deletions(-)
create mode 100644 llvm/include/llvm/Analysis/ValueTrackingHelper.h
diff --git a/llvm/include/llvm/Analysis/KnownBitsAnalysis.h b/llvm/include/llvm/Analysis/KnownBitsAnalysis.h
index f9a0baaf2dd1b..8eceef2ed3c9a 100644
--- a/llvm/include/llvm/Analysis/KnownBitsAnalysis.h
+++ b/llvm/include/llvm/Analysis/KnownBitsAnalysis.h
@@ -95,6 +95,14 @@ class KnownBitsDataflow : protected DenseMap<const Value *, KnownBits> {
LLVM_DUMP_METHOD void dump() const;
#endif
};
+
+class KnownBitsCache : protected KnownBitsDataflow {
+ void compute(ArrayRef<const Value *> Leaves);
+
+public:
+ KnownBitsCache(Function &F);
+ LLVM_ABI KnownBits getOrCompute(const Value *V);
+};
} // end namespace llvm
#endif // LLVM_ANALYSIS_KNOWNBITSANALYSIS_H
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 5f1e7773be8d3..59520dc9cd344 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -15,6 +15,7 @@
#define LLVM_ANALYSIS_VALUETRACKING_H
#include "llvm/Analysis/SimplifyQuery.h"
+#include "llvm/Analysis/ValueTrackingHelper.h"
#include "llvm/Analysis/WithCache.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
@@ -91,17 +92,6 @@ LLVM_ABI KnownBits computeKnownBits(const Value *V, const SimplifyQuery &Q,
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known,
const SimplifyQuery &Q, unsigned Depth = 0);
-/// Compute known bits from the range metadata.
-/// \p KnownZero the set of bits that are known to be zero
-/// \p KnownOne the set of bits that are known to be one
-LLVM_ABI void computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
- KnownBits &Known);
-
-/// Merge bits known from context-dependent facts into Known.
-LLVM_ABI void computeKnownBitsFromContext(const Value *V, KnownBits &Known,
- const SimplifyQuery &Q,
- unsigned Depth = 0);
-
/// Using KnownBits LHS/RHS produce the known bits for logic op (and/xor/or).
LLVM_ABI KnownBits analyzeKnownBitsFromAndXorOr(const Operator *I,
const KnownBits &KnownLHS,
@@ -109,13 +99,6 @@ LLVM_ABI KnownBits analyzeKnownBitsFromAndXorOr(const Operator *I,
const SimplifyQuery &SQ,
unsigned Depth = 0);
-/// Adjust \p Known for the given select \p Arm to include information from the
-/// select \p Cond.
-LLVM_ABI void adjustKnownBitsForSelectArm(KnownBits &Known, Value *Cond,
- Value *Arm, bool Invert,
- const SimplifyQuery &Q,
- unsigned Depth = 0);
-
/// Adjust \p Known for the given select \p Arm to include information from the
/// select \p Cond.
LLVM_ABI void adjustKnownFPClassForSelectArm(KnownFPClass &Known, Value *Cond,
@@ -675,10 +658,6 @@ LLVM_ABI OverflowResult computeOverflowForSignedSub(const Value *LHS,
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO,
const DominatorTree &DT);
-/// Determine the possible constant range of vscale with the given bit width,
-/// based on the vscale_range function attribute.
-LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth);
-
/// Determine the possible constant range of an integer or vector of integer
/// value. This is intended as a cheap, non-recursive check.
LLVM_ABI ConstantRange computeConstantRange(const Value *V, bool ForSigned,
@@ -790,44 +769,6 @@ LLVM_ABI bool canCreatePoison(const Operator *Op,
/// impliesPoison returns true.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V);
-/// Return true if this function can prove that V does not have undef bits
-/// and is never poison. If V is an aggregate value or vector, check whether
-/// all elements (except padding) are not undef or poison.
-/// Note that this is different from canCreateUndefOrPoison because the
-/// function assumes Op's operands are not poison/undef.
-///
-/// If CtxI and DT are specified this method performs flow-sensitive analysis
-/// and returns true if it is guaranteed to be never undef or poison
-/// immediately before the CtxI.
-LLVM_ABI bool
-isGuaranteedNotToBeUndefOrPoison(const Value *V, AssumptionCache *AC = nullptr,
- const Instruction *CtxI = nullptr,
- const DominatorTree *DT = nullptr,
- unsigned Depth = 0);
-
-/// Returns true if V cannot be poison, but may be undef.
-LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V,
- AssumptionCache *AC = nullptr,
- const Instruction *CtxI = nullptr,
- const DominatorTree *DT = nullptr,
- unsigned Depth = 0);
-
-inline bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC,
- BasicBlock::iterator CtxI,
- const DominatorTree *DT = nullptr,
- unsigned Depth = 0) {
- // Takes an iterator as a position, passes down to Instruction *
- // implementation.
- return isGuaranteedNotToBePoison(V, AC, &*CtxI, DT, Depth);
-}
-
-/// Returns true if V cannot be undef, but may be poison.
-LLVM_ABI bool isGuaranteedNotToBeUndef(const Value *V,
- AssumptionCache *AC = nullptr,
- const Instruction *CtxI = nullptr,
- const DominatorTree *DT = nullptr,
- unsigned Depth = 0);
-
/// Return true if undefined behavior would provable be executed on the path to
/// OnPathTo if Root produced a posion result. Note that this doesn't say
/// anything about whether OnPathTo is actually executed or whether Root is
@@ -955,37 +896,6 @@ LLVM_ABI APInt getMinMaxLimit(SelectPatternFlavor SPF, unsigned BitWidth);
LLVM_ABI std::pair<Intrinsic::ID, bool>
canConvertToMinOrMaxIntrinsic(ArrayRef<Value *> VL);
-/// Attempt to match a simple first order recurrence cycle of the form:
-/// %iv = phi Ty [%Start, %Entry], [%Inc, %backedge]
-/// %inc = binop %iv, %step
-/// OR
-/// %iv = phi Ty [%Start, %Entry], [%Inc, %backedge]
-/// %inc = binop %step, %iv
-///
-/// A first order recurrence is a formula with the form: X_n = f(X_(n-1))
-///
-/// A couple of notes on subtleties in that definition:
-/// * The Step does not have to be loop invariant. In math terms, it can
-/// be a free variable. We allow recurrences with both constant and
-/// variable coefficients. Callers may wish to filter cases where Step
-/// does not dominate P.
-/// * For non-commutative operators, we will match both forms. This
-/// results in some odd recurrence structures. Callers may wish to filter
-/// out recurrences where the phi is not the LHS of the returned operator.
-/// * Because of the structure matched, the caller can assume as a post
-/// condition of the match the presence of a Loop with P's parent as it's
-/// header *except* in unreachable code. (Dominance decays in unreachable
-/// code.)
-///
-/// NOTE: This is intentional simple. If you want the ability to analyze
-/// non-trivial loop conditons, see ScalarEvolution instead.
-LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
- Value *&Start, Value *&Step);
-
-/// Analogous to the above, but starting from the binary operator
-LLVM_ABI bool matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
- Value *&Start, Value *&Step);
-
/// Attempt to match a simple value-accumulating recurrence of the form:
/// %llvm.intrinsic.acc = phi Ty [%Init, %Entry], [%llvm.intrinsic, %backedge]
/// %llvm.intrinsic = call Ty @llvm.intrinsic(%OtherOp, %llvm.intrinsic.acc)
diff --git a/llvm/include/llvm/Analysis/ValueTrackingHelper.h b/llvm/include/llvm/Analysis/ValueTrackingHelper.h
new file mode 100644
index 0000000000000..a934837fd962c
--- /dev/null
+++ b/llvm/include/llvm/Analysis/ValueTrackingHelper.h
@@ -0,0 +1,150 @@
+//===- ValueTrackingHelper.h - Helper functions for ValueTracking ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_VALUETRACKINGHELPER_H
+#define LLVM_ANALYSIS_VALUETRACKINGHELPER_H
+
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/Support/Compiler.h"
+
+namespace llvm {
+class Type;
+class Use;
+class Value;
+class Instruction;
+class PHINode;
+class BinaryOperator;
+class IntrinsicInst;
+class ShuffleVectorInst;
+class Function;
+class DataLayout;
+class MDNode;
+class ConstantRange;
+class AssumptionCache;
+class DominatorTree;
+class CmpPredicate;
+struct KnownBits;
+struct SimplifyQuery;
+
+/// Attempt to match a simple first order recurrence cycle of the form:
+/// %iv = phi Ty [%Start, %Entry], [%Inc, %backedge]
+/// %inc = binop %iv, %step
+/// OR
+/// %iv = phi Ty [%Start, %Entry], [%Inc, %backedge]
+/// %inc = binop %step, %iv
+///
+/// A first order recurrence is a formula with the form: X_n = f(X_(n-1))
+///
+/// A couple of notes on subtleties in that definition:
+/// * The Step does not have to be loop invariant. In math terms, it can
+/// be a free variable. We allow recurrences with both constant and
+/// variable coefficients. Callers may wish to filter cases where Step
+/// does not dominate P.
+/// * For non-commutative operators, we will match both forms. This
+/// results in some odd recurrence structures. Callers may wish to filter
+/// out recurrences where the phi is not the LHS of the returned operator.
+/// * Because of the structure matched, the caller can assume as a post
+/// condition of the match the presence of a Loop with P's parent as its
+/// header *except* in unreachable code. (Dominance decays in unreachable
+/// code.)
+///
+/// NOTE: This is intentionally simple. If you want the ability to analyze
+/// non-trivial loop conditions, see ScalarEvolution instead.
+LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
+ Value *&Start, Value *&Step);
+
+/// Analogous to the above, but starting from the binary operator
+LLVM_ABI bool matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
+ Value *&Start, Value *&Step);
+
+/// Determine the possible constant range of vscale with the given bit width,
+/// based on the vscale_range function attribute.
+LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth);
+
+/// Return true if this function can prove that V does not have undef bits
+/// and is never poison. If V is an aggregate value or vector, check whether
+/// all elements (except padding) are not undef or poison.
+/// Note that this is different from canCreateUndefOrPoison because the
+/// function assumes Op's operands are not poison/undef.
+///
+/// If CtxI and DT are specified this method performs flow-sensitive analysis
+/// and returns true if it is guaranteed to be never undef or poison
+/// immediately before the CtxI.
+LLVM_ABI bool
+isGuaranteedNotToBeUndefOrPoison(const Value *V, AssumptionCache *AC = nullptr,
+ const Instruction *CtxI = nullptr,
+ const DominatorTree *DT = nullptr,
+ unsigned Depth = 0);
+
+/// Returns true if V cannot be poison, but may be undef.
+LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V,
+ AssumptionCache *AC = nullptr,
+ const Instruction *CtxI = nullptr,
+ const DominatorTree *DT = nullptr,
+ unsigned Depth = 0);
+
+inline bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC,
+ BasicBlock::iterator CtxI,
+ const DominatorTree *DT = nullptr,
+ unsigned Depth = 0) {
+ // Takes an iterator as a position, passes down to Instruction *
+ // implementation.
+ return isGuaranteedNotToBePoison(V, AC, &*CtxI, DT, Depth);
+}
+
+/// Returns true if V cannot be undef, but may be poison.
+LLVM_ABI bool isGuaranteedNotToBeUndef(const Value *V,
+ AssumptionCache *AC = nullptr,
+ const Instruction *CtxI = nullptr,
+ const DominatorTree *DT = nullptr,
+ unsigned Depth = 0);
+
+/// Return the boolean condition value in the context of the given instruction
+/// if it is known based on dominating conditions.
+LLVM_ABI std::optional<bool>
+isImpliedByDomCondition(const Value *Cond, const Instruction *ContextI,
+ const DataLayout &DL);
+LLVM_ABI std::optional<bool>
+isImpliedByDomCondition(CmpPredicate Pred, const Value *LHS, const Value *RHS,
+ const Instruction *ContextI, const DataLayout &DL);
+
+/// Adjust \p Known for the given select \p Arm to include information from the
+/// select \p Cond.
+LLVM_ABI void adjustKnownBitsForSelectArm(KnownBits &Known, Value *Cond,
+ Value *Arm, bool Invert,
+ const SimplifyQuery &Q,
+ unsigned Depth = 0);
+
+/// Compute known bits from the range metadata.
+/// \p KnownZero the set of bits that are known to be zero
+/// \p KnownOne the set of bits that are known to be one
+LLVM_ABI void computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
+ KnownBits &Known);
+
+/// Merge bits known from context-dependent facts into Known.
+LLVM_ABI void computeKnownBitsFromContext(const Value *V, KnownBits &Known,
+ const SimplifyQuery &Q,
+ unsigned Depth = 0);
+
+namespace vthelper {
+LLVM_ABI unsigned getBitWidth(Type *Ty, const DataLayout &DL);
+LLVM_ABI const Instruction *safeCxtI(const Value *V, const Instruction *CxtI);
+LLVM_ABI void breakSelfRecursivePHI(const Use *U, const PHINode *PHI,
+ Value *&ValOut, Instruction *&CtxIOut,
+ const PHINode **PhiOut = nullptr);
+LLVM_ABI void unionWithMinMaxIntrinsicClamp(const IntrinsicInst *II,
+ KnownBits &Known);
+LLVM_ABI bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
+ const SimplifyQuery &Q, unsigned Depth = 0);
+LLVM_ABI bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
+ const APInt &DemandedElts,
+ APInt &DemandedLHS, APInt &DemandedRHS);
+} // namespace vthelper
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/Analysis/KnownBitsAnalysis.cpp b/llvm/lib/Analysis/KnownBitsAnalysis.cpp
index ba6246f2a9b19..a60a6e4f63833 100644
--- a/llvm/lib/Analysis/KnownBitsAnalysis.cpp
+++ b/llvm/lib/Analysis/KnownBitsAnalysis.cpp
@@ -11,12 +11,25 @@
#include "llvm/ADT/GraphTraits.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/Analysis/SimplifyQuery.h"
+#include "llvm/Analysis/ValueTrackingHelper.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/GlobalAlias.h"
+#include "llvm/IR/IntrinsicsRISCV.h"
+#include "llvm/IR/IntrinsicsX86.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
+#include "llvm/TargetParser/RISCVTargetParser.h"
using namespace llvm;
+using namespace vthelper;
+using namespace PatternMatch;
template <typename NodeRef, typename ChildIteratorType>
struct NodeGraphTraitsBase {
@@ -148,3 +161,1462 @@ void KnownBitsDataflow::print(raw_ostream &OS) const {
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD void KnownBitsDataflow::dump() const { print(dbgs()); }
#endif
+
+constexpr unsigned MaxAnalysisRecursionDepth = 6;
+
+static void computeKnownBits(const Value *V, const APInt &DemandedElts,
+ KnownBits &Known, const SimplifyQuery &Q,
+ unsigned Depth);
+
+static KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
+ const SimplifyQuery &Q, unsigned Depth) {
+ KnownBits Known(getBitWidth(V->getType(), Q.DL));
+ ::computeKnownBits(V, DemandedElts, Known, Q, Depth);
+ return Known;
+}
+
+static void computeKnownBits(const Value *V, KnownBits &Known,
+ const SimplifyQuery &Q, unsigned Depth) {
+ // Since the number of lanes in a scalable vector is unknown at compile time,
+ // we track one bit which is implicitly broadcast to all lanes. This means
+ // that all lanes in a scalable vector are considered demanded.
+ auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+ APInt DemandedElts =
+ FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
+ ::computeKnownBits(V, DemandedElts, Known, Q, Depth);
+}
+
+static KnownBits computeKnownBits(const Value *V, const SimplifyQuery &Q,
+ unsigned Depth) {
+ KnownBits Known(getBitWidth(V->getType(), Q.DL));
+ computeKnownBits(V, Known, Q, Depth);
+ return Known;
+}
+
+static KnownBits computeKnownBitsForHorizontalOperation(
+ const Operator *I, const APInt &DemandedElts, const SimplifyQuery &Q,
+ unsigned Depth,
+ const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
+ KnownBitsFunc) {
+ APInt DemandedEltsLHS, DemandedEltsRHS;
+ getHorizDemandedEltsForFirstOperand(Q.DL.getTypeSizeInBits(I->getType()),
+ DemandedElts, DemandedEltsLHS,
+ DemandedEltsRHS);
+
+ const auto ComputeForSingleOpFunc =
+ [Depth, &Q, KnownBitsFunc](const Value *Op, APInt &DemandedEltsOp) {
+ return KnownBitsFunc(
+ computeKnownBits(Op, DemandedEltsOp, Q, Depth + 1),
+ computeKnownBits(Op, DemandedEltsOp << 1, Q, Depth + 1));
+ };
+
+ if (DemandedEltsRHS.isZero())
+ return ComputeForSingleOpFunc(I->getOperand(0), DemandedEltsLHS);
+ if (DemandedEltsLHS.isZero())
+ return ComputeForSingleOpFunc(I->getOperand(1), DemandedEltsRHS);
+
+ return ComputeForSingleOpFunc(I->getOperand(0), DemandedEltsLHS)
+ .intersectWith(ComputeForSingleOpFunc(I->getOperand(1), DemandedEltsRHS));
+}
+
+static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
+ const APInt &DemandedElts,
+ KnownBits &KnownOut,
+ const SimplifyQuery &Q,
+ unsigned Depth) {
+
+ Type *Ty = Op0->getType();
+ const unsigned BitWidth = Ty->getScalarSizeInBits();
+
+ // Only handle scalar types for now
+ if (Ty->isVectorTy())
+ return;
+
+ // Try to match: a * (b - c) + c * d.
+ // When a == 1 => A == nullptr, the same applies to d/D as well.
+ const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
+ const Instruction *SubBC = nullptr;
+
+ const auto MatchSubBC = [&]() {
+ // (b - c) can have two forms that interest us:
+ //
+ // 1. sub nuw %b, %c
+ // 2. xor %c, %b
+ //
+ // For the first case, nuw flag guarantees our requirement b >= c.
+ //
+ // The second case might happen when the analysis can infer that b is a mask
+ // for c and we can transform sub operation into xor (that is usually true
+ // for constant b's). Even though xor is symmetrical, canonicalization
+ // ensures that the constant will be the RHS. We have additional checks
+ // later on to ensure that this xor operation is equivalent to subtraction.
+ return m_Instruction(SubBC, m_CombineOr(m_NUWSub(m_Value(B), m_Value(C)),
+ m_Xor(m_Value(C), m_Value(B))));
+ };
+
+ const auto MatchASubBC = [&]() {
+ // Cases:
+ // - a * (b - c)
+ // - (b - c) * a
+ // - (b - c) <- a implicitly equals 1
+ return m_CombineOr(m_c_Mul(m_Value(A), MatchSubBC()), MatchSubBC());
+ };
+
+ const auto MatchCD = [&]() {
+ // Cases:
+ // - d * c
+ // - c * d
+ // - c <- d implicitly equals 1
+ return m_CombineOr(m_c_Mul(m_Value(D), m_Specific(C)), m_Specific(C));
+ };
+
+ const auto Match = [&](const Value *LHS, const Value *RHS) {
+ // We do use m_Specific(C) in MatchCD, so we have to make sure that
+ // C is bound to something, i.e. match(LHS, MatchASubBC()) absolutely
+ // has to evaluate first and return true.
+ //
+ // If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
+ return match(LHS, MatchASubBC()) && match(RHS, MatchCD());
+ };
+
+ if (!Match(Op0, Op1) && !Match(Op1, Op0))
+ return;
+
+ const auto ComputeKnownBitsOrOne = [&](const Value *V) {
+ // For some of the values we use the convention of leaving
+ // it nullptr to signify an implicit constant 1.
+ return V ? computeKnownBits(V, DemandedElts, Q, Depth + 1)
+ : KnownBits::makeConstant(APInt(BitWidth, 1));
+ };
+
+ // Check that all operands are non-negative
+ const KnownBits KnownA = ComputeKnownBitsOrOne(A);
+ if (!KnownA.isNonNegative())
+ return;
+
+ const KnownBits KnownD = ComputeKnownBitsOrOne(D);
+ if (!KnownD.isNonNegative())
+ return;
+
+ const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q, Depth + 1);
+ if (!KnownB.isNonNegative())
+ return;
+
+ const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q, Depth + 1);
+ if (!KnownC.isNonNegative())
+ return;
+
+ // If we matched subtraction as xor, we need to actually check that xor
+ // is semantically equivalent to subtraction.
+ //
+ // For that to be true, b has to be a mask for c or that b's known
+ // ones cover all known and possible ones of c.
+ if (SubBC->getOpcode() == Instruction::Xor &&
+ !KnownC.getMaxValue().isSubsetOf(KnownB.getMinValue()))
+ return;
+
+ const APInt MaxA = KnownA.getMaxValue();
+ const APInt MaxD = KnownD.getMaxValue();
+ const APInt MaxAD = APIntOps::umax(MaxA, MaxD);
+ const APInt MaxB = KnownB.getMaxValue();
+
+ // We can't infer leading zeros info if the upper-bound estimate wraps.
+ bool Overflow;
+ const APInt UpperBound = MaxAD.umul_ov(MaxB, Overflow);
+
+ if (Overflow)
+ return;
+
+ // If we know that x <= y and both are positive, then x has at least the same
+ // number of leading zeros as y.
+ const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero();
+ KnownOut.Zero.setHighBits(MinimumNumberOfLeadingZeros);
+}
+
+static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
+ bool NSW, bool NUW,
+ const APInt &DemandedElts,
+ KnownBits &KnownOut, KnownBits &Known2,
+ const SimplifyQuery &Q, unsigned Depth) {
+ computeKnownBits(Op1, DemandedElts, KnownOut, Q, Depth + 1);
+
+ // If one operand is unknown and we have no nowrap information,
+ // the result will be unknown independently of the second operand.
+ if (KnownOut.isUnknown() && !NSW && !NUW)
+ return;
+
+ computeKnownBits(Op0, DemandedElts, Known2, Q, Depth + 1);
+ KnownOut = KnownBits::computeForAddSub(Add, NSW, NUW, Known2, KnownOut);
+
+ if (!Add && NSW && !KnownOut.isNonNegative() &&
+ isImpliedByDomCondition(ICmpInst::ICMP_SLE, Op1, Op0, Q.CxtI, Q.DL)
+ .value_or(false))
+ KnownOut.makeNonNegative();
+
+ if (Add)
+ // Try to match lerp pattern and combine results
+ computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, KnownOut, Q, Depth);
+}
+
+static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
+ bool NUW, const APInt &DemandedElts,
+ KnownBits &Known, KnownBits &Known2,
+ const SimplifyQuery &Q, unsigned Depth) {
+ computeKnownBits(Op1, DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(Op0, DemandedElts, Known2, Q, Depth + 1);
+
+ bool isKnownNegative = false;
+ bool isKnownNonNegative = false;
+ // If the multiplication is known not to overflow, compute the sign bit.
+ if (NSW) {
+ if (Op0 == Op1) {
+ // The product of a number with itself is non-negative.
+ isKnownNonNegative = true;
+ } else {
+ bool isKnownNonNegativeOp1 = Known.isNonNegative();
+ bool isKnownNonNegativeOp0 = Known2.isNonNegative();
+ bool isKnownNegativeOp1 = Known.isNegative();
+ bool isKnownNegativeOp0 = Known2.isNegative();
+ // The product of two numbers with the same sign is non-negative.
+ isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) ||
+ (isKnownNonNegativeOp1 && isKnownNonNegativeOp0);
+ if (!isKnownNonNegative && NUW) {
+ // mul nuw nsw with a factor > 1 is non-negative.
+ KnownBits One = KnownBits::makeConstant(APInt(Known.getBitWidth(), 1));
+ isKnownNonNegative = KnownBits::sgt(Known, One).value_or(false) ||
+ KnownBits::sgt(Known2, One).value_or(false);
+ }
+
+ // The product of a negative number and a non-negative number is either
+ // negative or zero.
+ if (!isKnownNonNegative)
+ isKnownNegative =
+ (isKnownNegativeOp1 && isKnownNonNegativeOp0 &&
+ Known2.isNonZero()) ||
+ (isKnownNegativeOp0 && isKnownNonNegativeOp1 && Known.isNonZero());
+ }
+ }
+
+ bool SelfMultiply = Op0 == Op1;
+ if (SelfMultiply)
+ SelfMultiply &=
+ isGuaranteedNotToBeUndef(Op0, Q.AC, Q.CxtI, Q.DT, Depth + 1);
+
+ // MISSING: computeNumSignBits in case of SelfMultiply.
+
+ Known = KnownBits::mul(Known, Known2, SelfMultiply);
+
+ // Only make use of no-wrap flags if we failed to compute the sign bit
+ // directly. This matters if the multiplication always overflows, in
+ // which case we prefer to follow the result of the direct computation,
+ // though as the program is invoking undefined behaviour we can choose
+ // whatever we like here.
+ if (isKnownNonNegative && !Known.isNegative())
+ Known.makeNonNegative();
+ else if (isKnownNegative && !Known.isNonNegative())
+ Known.makeNegative();
+}
+
+static void computeKnownBitsFromShiftOperator(
+ const Operator *I, const APInt &DemandedElts, KnownBits &Known,
+ KnownBits &Known2, const SimplifyQuery &Q, unsigned Depth,
+ function_ref<KnownBits(const KnownBits &, const KnownBits &, bool)> KF) {
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known, Q, Depth + 1);
+ // To limit compile-time impact, only query isKnownNonZero() if we know at
+ // least something about the shift amount.
+ bool ShAmtNonZero =
+ Known.isNonZero() ||
+ (Known.getMaxValue().ult(Known.getBitWidth()) &&
+ isKnownNonZero(I->getOperand(1), DemandedElts, Q, Depth + 1));
+ Known = KF(Known2, Known, ShAmtNonZero);
+}
+
+static KnownBits
+getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
+ const KnownBits &KnownLHS, const KnownBits &KnownRHS,
+ const SimplifyQuery &Q, unsigned Depth) {
+ unsigned BitWidth = KnownLHS.getBitWidth();
+ KnownBits KnownOut(BitWidth);
+ bool IsAnd = false;
+ bool HasKnownOne = !KnownLHS.One.isZero() || !KnownRHS.One.isZero();
+ Value *X = nullptr, *Y = nullptr;
+
+ switch (I->getOpcode()) {
+ case Instruction::And:
+ KnownOut = KnownLHS & KnownRHS;
+ IsAnd = true;
+ // and(x, -x) is a common idiom that will clear all but the lowest set
+ // bit. If we have a single known bit in x, we can clear all bits
+ // above it.
+ // TODO: instcombine often reassociates independent `and` which can hide
+ // this pattern. Try to match and(x, and(-x, y)) / and(and(x, y), -x).
+ if (HasKnownOne && match(I, m_c_And(m_Value(X), m_Neg(m_Deferred(X))))) {
+ // -(-x) == x so using whichever (LHS/RHS) gets us a better result.
+ if (KnownLHS.countMaxTrailingZeros() <= KnownRHS.countMaxTrailingZeros())
+ KnownOut = KnownLHS.blsi();
+ else
+ KnownOut = KnownRHS.blsi();
+ }
+ break;
+ case Instruction::Or:
+ KnownOut = KnownLHS | KnownRHS;
+ break;
+ case Instruction::Xor:
+ KnownOut = KnownLHS ^ KnownRHS;
+ // xor(x, x-1) is a common idiom that will clear all but the lowest set
+ // bit. If we have a single known bit in x, we can clear all bits
+ // above it.
+ // TODO: xor(x, x-1) is often rewritten as xor(x, x-C) where C !=
+ // -1 but for the purpose of demanded bits (xor(x, x-C) &
+ // Demanded) == (xor(x, x-1) & Demanded). Extend the xor pattern
+ // to use arbitrary C if xor(x, x-C) is the same as xor(x, x-1).
+ if (HasKnownOne &&
+ match(I, m_c_Xor(m_Value(X), m_Add(m_Deferred(X), m_AllOnes())))) {
+ const KnownBits &XBits = I->getOperand(0) == X ? KnownLHS : KnownRHS;
+ KnownOut = XBits.blsmsk();
+ }
+ break;
+ default:
+ llvm_unreachable("Invalid Op used in 'analyzeKnownBitsFromAndXorOr'");
+ }
+
+ // and(x, add (x, -1)) is a common idiom that always clears the low bit;
+ // xor/or(x, add (x, -1)) is an idiom that will always set the low bit.
+ // here we handle the more general case of adding any odd number by
+ // matching the form and/xor/or(x, add(x, y)) where y is odd.
+ // TODO: This could be generalized to clearing any bit set in y where the
+ // following bit is known to be unset in y.
+ if (!KnownOut.Zero[0] && !KnownOut.One[0] &&
+ (match(I, m_c_BinOp(m_Value(X), m_c_Add(m_Deferred(X), m_Value(Y)))) ||
+ match(I, m_c_BinOp(m_Value(X), m_Sub(m_Deferred(X), m_Value(Y)))) ||
+ match(I, m_c_BinOp(m_Value(X), m_Sub(m_Value(Y), m_Deferred(X)))))) {
+ KnownBits KnownY(BitWidth);
+ computeKnownBits(Y, DemandedElts, KnownY, Q, Depth + 1);
+ if (KnownY.countMinTrailingOnes() > 0) {
+ if (IsAnd)
+ KnownOut.Zero.setBit(0);
+ else
+ KnownOut.One.setBit(0);
+ }
+ }
+ return KnownOut;
+}
+
+static void computeKnownBitsFromOperator(const Operator *I,
+ const APInt &DemandedElts,
+ KnownBits &Known,
+ const SimplifyQuery &Q,
+ unsigned Depth) {
+ unsigned BitWidth = Known.getBitWidth();
+
+ KnownBits Known2(BitWidth);
+ switch (I->getOpcode()) {
+ default:
+ break;
+ case Instruction::Load:
+ if (MDNode *MD =
+ Q.IIQ.getMetadata(cast<LoadInst>(I), LLVMContext::MD_range))
+ computeKnownBitsFromRangeMetadata(*MD, Known);
+ break;
+ case Instruction::And:
+ computeKnownBits(I->getOperand(1), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+
+ Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Q, Depth);
+ break;
+ case Instruction::Or:
+ computeKnownBits(I->getOperand(1), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+
+ Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Q, Depth);
+ break;
+ case Instruction::Xor:
+ computeKnownBits(I->getOperand(1), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+
+ Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Q, Depth);
+ break;
+ case Instruction::Mul: {
+ bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
+ bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
+ computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, NUW,
+ DemandedElts, Known, Known2, Q, Depth);
+ break;
+ }
+ case Instruction::UDiv: {
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known =
+ KnownBits::udiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
+ break;
+ }
+ case Instruction::SDiv: {
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known =
+ KnownBits::sdiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
+ break;
+ }
+ case Instruction::Select: {
+ auto ComputeForArm = [&](Value *Arm, bool Invert) {
+ KnownBits Res(Known.getBitWidth());
+ computeKnownBits(Arm, DemandedElts, Res, Q, Depth + 1);
+ adjustKnownBitsForSelectArm(Res, I->getOperand(0), Arm, Invert, Q, Depth);
+ return Res;
+ };
+ // Only known if known in both the LHS and RHS.
+ Known =
+ ComputeForArm(I->getOperand(1), /*Invert=*/false)
+ .intersectWith(ComputeForArm(I->getOperand(2), /*Invert=*/true));
+ break;
+ }
+ case Instruction::FPTrunc:
+ case Instruction::FPExt:
+ case Instruction::FPToUI:
+ case Instruction::FPToSI:
+ case Instruction::SIToFP:
+ case Instruction::UIToFP:
+ break; // Can't work with floating point.
+ case Instruction::PtrToInt:
+ case Instruction::PtrToAddr:
+ case Instruction::IntToPtr:
+ // Fall through and handle them the same as zext/trunc.
+ [[fallthrough]];
+ case Instruction::ZExt:
+ case Instruction::Trunc: {
+ Type *SrcTy = I->getOperand(0)->getType();
+
+ unsigned SrcBitWidth;
+ // Note that we handle pointer operands here because of inttoptr/ptrtoint
+ // which fall through here.
+ Type *ScalarTy = SrcTy->getScalarType();
+ SrcBitWidth = ScalarTy->isPointerTy()
+ ? Q.DL.getPointerTypeSizeInBits(ScalarTy)
+ : Q.DL.getTypeSizeInBits(ScalarTy);
+
+ assert(SrcBitWidth && "SrcBitWidth can't be zero");
+ Known = Known.anyextOrTrunc(SrcBitWidth);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ if (auto *Inst = dyn_cast<PossiblyNonNegInst>(I);
+ Inst && Inst->hasNonNeg() && !Known.isNegative())
+ Known.makeNonNegative();
+ Known = Known.zextOrTrunc(BitWidth);
+ break;
+ }
+ case Instruction::BitCast: {
+ Type *SrcTy = I->getOperand(0)->getType();
+ if (SrcTy->isIntOrPtrTy() &&
+ // TODO: For now, not handling conversions like:
+ // (bitcast i64 %x to <2 x i32>)
+ !I->getType()->isVectorTy()) {
+ computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+ break;
+ }
+
+ // MISSING: computeKnownFPClass to handle bitcast from floating-point to
+ // integer.
+
+ // Handle cast from vector integer type to scalar or vector integer.
+ auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcTy);
+ if (!SrcVecTy || !SrcVecTy->getElementType()->isIntegerTy() ||
+ !I->getType()->isIntOrIntVectorTy() ||
+ isa<ScalableVectorType>(I->getType()))
+ break;
+
+ unsigned NumElts = DemandedElts.getBitWidth();
+ bool IsLE = Q.DL.isLittleEndian();
+ // Look through a cast from narrow vector elements to wider type.
+ // Examples: v4i32 -> v2i64, v3i8 -> v24
+ unsigned SubBitWidth = SrcVecTy->getScalarSizeInBits();
+ if (BitWidth % SubBitWidth == 0) {
+ // Known bits are automatically intersected across demanded elements of a
+ // vector. So for example, if a bit is computed as known zero, it must be
+ // zero across all demanded elements of the vector.
+ //
+ // For this bitcast, each demanded element of the output is sub-divided
+ // across a set of smaller vector elements in the source vector. To get
+ // the known bits for an entire element of the output, compute the known
+ // bits for each sub-element sequentially. This is done by shifting the
+ // one-set-bit demanded elements parameter across the sub-elements for
+ // consecutive calls to computeKnownBits. We are using the demanded
+ // elements parameter as a mask operator.
+ //
+ // The known bits of each sub-element are then inserted into place
+ // (dependent on endian) to form the full result of known bits.
+ unsigned SubScale = BitWidth / SubBitWidth;
+ APInt SubDemandedElts = APInt::getZero(NumElts * SubScale);
+ for (unsigned i = 0; i != NumElts; ++i) {
+ if (DemandedElts[i])
+ SubDemandedElts.setBit(i * SubScale);
+ }
+
+ KnownBits KnownSrc(SubBitWidth);
+ for (unsigned i = 0; i != SubScale; ++i) {
+ computeKnownBits(I->getOperand(0), SubDemandedElts.shl(i), KnownSrc, Q,
+ Depth + 1);
+ unsigned ShiftElt = IsLE ? i : SubScale - 1 - i;
+ Known.insertBits(KnownSrc, ShiftElt * SubBitWidth);
+ }
+ }
+ // Look through a cast from wider vector elements to narrow type.
+ // Examples: v2i64 -> v4i32
+ if (SubBitWidth % BitWidth == 0) {
+ unsigned SubScale = SubBitWidth / BitWidth;
+ KnownBits KnownSrc(SubBitWidth);
+ APInt SubDemandedElts =
+ APIntOps::ScaleBitMask(DemandedElts, NumElts / SubScale);
+ computeKnownBits(I->getOperand(0), SubDemandedElts, KnownSrc, Q,
+ Depth + 1);
+
+ Known.setAllConflict();
+ for (unsigned i = 0; i != NumElts; ++i) {
+ if (DemandedElts[i]) {
+ unsigned Shifts = IsLE ? i : NumElts - 1 - i;
+ unsigned Offset = (Shifts % SubScale) * BitWidth;
+ Known = Known.intersectWith(KnownSrc.extractBits(BitWidth, Offset));
+ if (Known.isUnknown())
+ break;
+ }
+ }
+ }
+ break;
+ }
+ case Instruction::SExt: {
+ // Compute the bits in the result that are not present in the input.
+ unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
+
+ Known = Known.trunc(SrcBitWidth);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ // If the sign bit of the input is known set or clear, then we know the
+ // top bits of the result.
+ Known = Known.sext(BitWidth);
+ break;
+ }
+ case Instruction::Shl: {
+ bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
+ bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
+ auto KF = [NUW, NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+ bool ShAmtNonZero) {
+ return KnownBits::shl(KnownVal, KnownAmt, NUW, NSW, ShAmtNonZero);
+ };
+ computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
+ KF);
+ // Trailing zeros of a left-shifted constant never decrease.
+ const APInt *C;
+ if (match(I->getOperand(0), m_APInt(C)))
+ Known.Zero.setLowBits(C->countr_zero());
+ break;
+ }
+ case Instruction::LShr: {
+ bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+ auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+ bool ShAmtNonZero) {
+ return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
+ };
+ computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
+ KF);
+ // Leading zeros of a right-shifted constant never decrease.
+ const APInt *C;
+ if (match(I->getOperand(0), m_APInt(C)))
+ Known.Zero.setHighBits(C->countl_zero());
+ break;
+ }
+ case Instruction::AShr: {
+ bool Exact = Q.IIQ.isExact(cast<BinaryOperator>(I));
+ auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
+ bool ShAmtNonZero) {
+ return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
+ };
+ computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
+ KF);
+ break;
+ }
+ case Instruction::Sub: {
+ bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
+ bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
+ computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, NUW,
+ DemandedElts, Known, Known2, Q, Depth);
+ break;
+ }
+ case Instruction::Add: {
+ bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
+ bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
+ computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, NUW,
+ DemandedElts, Known, Known2, Q, Depth);
+ break;
+ }
+ case Instruction::SRem:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known = KnownBits::srem(Known, Known2);
+ break;
+
+ case Instruction::URem:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known = KnownBits::urem(Known, Known2);
+ break;
+ case Instruction::Alloca:
+ Known.Zero.setLowBits(Log2(cast<AllocaInst>(I)->getAlign()));
+ break;
+ case Instruction::GetElementPtr: {
+ // Analyze all of the subscripts of this getelementptr instruction
+ // to determine if we can prove known low zero bits.
+ computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+ // Accumulate the constant indices in a separate variable
+ // to minimize the number of calls to computeForAddSub.
+ unsigned IndexWidth = Q.DL.getIndexTypeSizeInBits(I->getType());
+ APInt AccConstIndices(IndexWidth, 0);
+
+ auto AddIndexToKnown = [&](KnownBits IndexBits) {
+ if (IndexWidth == BitWidth) {
+ // Note that inbounds does *not* guarantee nsw for the addition, as only
+ // the offset is signed, while the base address is unsigned.
+ Known = KnownBits::add(Known, IndexBits);
+ } else {
+ // If the index width is smaller than the pointer width, only add the
+ // value to the low bits.
+ assert(IndexWidth < BitWidth &&
+ "Index width can't be larger than pointer width");
+ Known.insertBits(KnownBits::add(Known.trunc(IndexWidth), IndexBits), 0);
+ }
+ };
+
+ gep_type_iterator GTI = gep_type_begin(I);
+ for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) {
+ // TrailZ can only become smaller, short-circuit if we hit zero.
+ if (Known.isUnknown())
+ break;
+
+ Value *Index = I->getOperand(i);
+
+ // Handle case when index is zero.
+ Constant *CIndex = dyn_cast<Constant>(Index);
+ if (CIndex && CIndex->isNullValue())
+ continue;
+
+ if (StructType *STy = GTI.getStructTypeOrNull()) {
+ // Handle struct member offset arithmetic.
+
+ assert(CIndex &&
+ "Access to structure field must be known at compile time");
+
+ if (CIndex->getType()->isVectorTy())
+ Index = CIndex->getSplatValue();
+
+ unsigned Idx = cast<ConstantInt>(Index)->getZExtValue();
+ const StructLayout *SL = Q.DL.getStructLayout(STy);
+ uint64_t Offset = SL->getElementOffset(Idx);
+ AccConstIndices += Offset;
+ continue;
+ }
+
+ // Handle array index arithmetic.
+ Type *IndexedTy = GTI.getIndexedType();
+ if (!IndexedTy->isSized()) {
+ Known.resetAll();
+ break;
+ }
+
+ TypeSize Stride = GTI.getSequentialElementStride(Q.DL);
+ uint64_t StrideInBytes = Stride.getKnownMinValue();
+ if (!Stride.isScalable()) {
+ // Fast path for constant offset.
+ if (auto *CI = dyn_cast<ConstantInt>(Index)) {
+ AccConstIndices +=
+ CI->getValue().sextOrTrunc(IndexWidth) * StrideInBytes;
+ continue;
+ }
+ }
+
+ KnownBits IndexBits =
+ computeKnownBits(Index, Q, Depth + 1).sextOrTrunc(IndexWidth);
+ KnownBits ScalingFactor(IndexWidth);
+ // Multiply by current sizeof type.
+ // &A[i] == A + i * sizeof(*A[i]).
+ if (Stride.isScalable()) {
+ // For scalable types the only thing we know about sizeof is
+ // that this is a multiple of the minimum size.
+ ScalingFactor.Zero.setLowBits(llvm::countr_zero(StrideInBytes));
+ } else {
+ ScalingFactor =
+ KnownBits::makeConstant(APInt(IndexWidth, StrideInBytes));
+ }
+ AddIndexToKnown(KnownBits::mul(IndexBits, ScalingFactor));
+ }
+ if (!Known.isUnknown() && !AccConstIndices.isZero())
+ AddIndexToKnown(KnownBits::makeConstant(AccConstIndices));
+ break;
+ }
+ case Instruction::PHI: {
+ const PHINode *P = cast<PHINode>(I);
+ BinaryOperator *BO = nullptr;
+ Value *R = nullptr, *L = nullptr;
+ if (matchSimpleRecurrence(P, BO, R, L)) {
+ // Handle the case of a simple two-predecessor recurrence PHI.
+ // There's a lot more that could theoretically be done here, but
+ // this is sufficient to catch some interesting cases.
+ unsigned Opcode = BO->getOpcode();
+
+ switch (Opcode) {
+ // If this is a shift recurrence, we know the bits being shifted in. We
+ // can combine that with information about the start value of the
+ // recurrence to conclude facts about the result. If this is a udiv
+ // recurrence, we know that the result can never exceed either the
+ // numerator or the start value, whichever is greater.
+ case Instruction::LShr:
+ case Instruction::AShr:
+ case Instruction::Shl:
+ case Instruction::UDiv:
+ if (BO->getOperand(0) != I)
+ break;
+ [[fallthrough]];
+
+ // For a urem recurrence, the result can never exceed the start value. The
+ // phi could either be the numerator or the denominator.
+ case Instruction::URem: {
+ // We have matched a recurrence of the form:
+ // %iv = [R, %entry], [%iv.next, %backedge]
+ // %iv.next = shift_op %iv, L
+
+ // Recurse with the phi context to avoid concern about whether facts
+ // inferred hold at original context instruction. TODO: It may be
+ // correct to use the original context. IF warranted, explore and
+ // add sufficient tests to cover.
+ SimplifyQuery RecQ = Q.getWithoutCondContext();
+ RecQ.CxtI = P;
+ computeKnownBits(R, DemandedElts, Known2, RecQ, Depth + 1);
+ switch (Opcode) {
+ case Instruction::Shl:
+ // A shl recurrence will only increase the trailing zeros
+ Known.Zero.setLowBits(Known2.countMinTrailingZeros());
+ break;
+ case Instruction::LShr:
+ case Instruction::UDiv:
+ case Instruction::URem:
+ // lshr, udiv, and urem recurrences will preserve the leading zeros of
+ // the start value.
+ Known.Zero.setHighBits(Known2.countMinLeadingZeros());
+ break;
+ case Instruction::AShr:
+ // An ashr recurrence will extend the initial sign bit
+ Known.Zero.setHighBits(Known2.countMinLeadingZeros());
+ Known.One.setHighBits(Known2.countMinLeadingOnes());
+ break;
+ }
+ break;
+ }
+
+ // Check for operations that have the property that if
+ // both their operands have low zero bits, the result
+ // will have low zero bits.
+ case Instruction::Add:
+ case Instruction::Sub:
+ case Instruction::And:
+ case Instruction::Or:
+ case Instruction::Mul: {
+ // Change the context instruction to the "edge" that flows into the
+ // phi. This is important because that is where the value is actually
+ // "evaluated" even though it is used later somewhere else. (see also
+ // D69571).
+ SimplifyQuery RecQ = Q.getWithoutCondContext();
+
+ unsigned OpNum = P->getOperand(0) == R ? 0 : 1;
+ Instruction *RInst = P->getIncomingBlock(OpNum)->getTerminator();
+ Instruction *LInst = P->getIncomingBlock(1 - OpNum)->getTerminator();
+
+ // Ok, we have a PHI of the form L op= R. Check for low
+ // zero bits.
+ RecQ.CxtI = RInst;
+ computeKnownBits(R, DemandedElts, Known2, RecQ, Depth + 1);
+
+ // We need to take the minimum number of known bits
+ KnownBits Known3(BitWidth);
+ RecQ.CxtI = LInst;
+ computeKnownBits(L, DemandedElts, Known3, RecQ, Depth + 1);
+
+ Known.Zero.setLowBits(std::min(Known2.countMinTrailingZeros(),
+ Known3.countMinTrailingZeros()));
+
+ auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(BO);
+ if (!OverflowOp || !Q.IIQ.hasNoSignedWrap(OverflowOp))
+ break;
+
+ switch (Opcode) {
+ // If initial value of recurrence is nonnegative, and we are adding
+ // a nonnegative number with nsw, the result can only be nonnegative
+ // or poison value regardless of the number of times we execute the
+ // add in phi recurrence. If initial value is negative and we are
+ // adding a negative number with nsw, the result can only be
+ // negative or poison value. Similar arguments apply to sub and mul.
+ //
+ // (add non-negative, non-negative) --> non-negative
+ // (add negative, negative) --> negative
+ case Instruction::Add: {
+ if (Known2.isNonNegative() && Known3.isNonNegative())
+ Known.makeNonNegative();
+ else if (Known2.isNegative() && Known3.isNegative())
+ Known.makeNegative();
+ break;
+ }
+
+ // (sub nsw non-negative, negative) --> non-negative
+ // (sub nsw negative, non-negative) --> negative
+ case Instruction::Sub: {
+ if (BO->getOperand(0) != I)
+ break;
+ if (Known2.isNonNegative() && Known3.isNegative())
+ Known.makeNonNegative();
+ else if (Known2.isNegative() && Known3.isNonNegative())
+ Known.makeNegative();
+ break;
+ }
+
+ // (mul nsw non-negative, non-negative) --> non-negative
+ case Instruction::Mul:
+ if (Known2.isNonNegative() && Known3.isNonNegative())
+ Known.makeNonNegative();
+ break;
+
+ default:
+ break;
+ }
+ break;
+ }
+
+ default:
+ break;
+ }
+ }
+
+ // Unreachable blocks may have zero-operand PHI nodes.
+ if (P->getNumIncomingValues() == 0)
+ break;
+
+ // Otherwise take the unions of the known bit sets of the operands,
+ // taking conservative care to avoid excessive recursion.
+ if (Depth < MaxAnalysisRecursionDepth - 1 && Known.isUnknown()) {
+ // Skip if every incoming value references to ourself.
+ if (isa_and_nonnull<UndefValue>(P->hasConstantValue()))
+ break;
+
+ Known.setAllConflict();
+ for (const Use &U : P->operands()) {
+ Value *IncValue;
+ const PHINode *CxtPhi;
+ Instruction *CxtI;
+ breakSelfRecursivePHI(&U, P, IncValue, CxtI, &CxtPhi);
+ // Skip direct self references.
+ if (IncValue == P)
+ continue;
+
+ // Change the context instruction to the "edge" that flows into the
+ // phi. This is important because that is where the value is actually
+ // "evaluated" even though it is used later somewhere else. (see also
+ // D69571).
+ SimplifyQuery RecQ = Q.getWithoutCondContext().getWithInstruction(CxtI);
+
+ Known2 = KnownBits(BitWidth);
+
+ // Recurse, but cap the recursion to one level, because we don't
+ // want to waste time spinning around in loops.
+ // TODO: See if we can base recursion limiter on number of incoming phi
+ // edges so we don't overly clamp analysis.
+ computeKnownBits(IncValue, DemandedElts, Known2, RecQ,
+ MaxAnalysisRecursionDepth - 1);
+
+ // See if we can further use a conditional branch into the phi
+ // to help us determine the range of the value.
+ if (!Known2.isConstant()) {
+ CmpPredicate Pred;
+ const APInt *RHSC;
+ BasicBlock *TrueSucc, *FalseSucc;
+ // TODO: Use RHS Value and compute range from its known bits.
+ if (match(RecQ.CxtI,
+ m_Br(m_c_ICmp(Pred, m_Specific(IncValue), m_APInt(RHSC)),
+ m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) {
+ // Check for cases of duplicate successors.
+ if ((TrueSucc == CxtPhi->getParent()) !=
+ (FalseSucc == CxtPhi->getParent())) {
+ // If we're using the false successor, invert the predicate.
+ if (FalseSucc == CxtPhi->getParent())
+ Pred = CmpInst::getInversePredicate(Pred);
+ // Get the knownbits implied by the incoming phi condition.
+ auto CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC);
+ KnownBits KnownUnion = Known2.unionWith(CR.toKnownBits());
+ // We can have conflicts here if we are analyzing dead code (it's
+ // impossible for us to reach this BB based on the icmp).
+ if (KnownUnion.hasConflict()) {
+ // No reason to continue analyzing in a known dead region, so
+ // just resetAll and break. This will cause us to also exit the
+ // outer loop.
+ Known.resetAll();
+ break;
+ }
+ Known2 = KnownUnion;
+ }
+ }
+ }
+
+ Known = Known.intersectWith(Known2);
+ // If all bits have been ruled out, there's no need to check
+ // more operands.
+ if (Known.isUnknown())
+ break;
+ }
+ }
+ break;
+ }
+ case Instruction::Call:
+ case Instruction::Invoke: {
+ // If range metadata is attached to this call, set known bits from that,
+ // and then intersect with known bits based on other properties of the
+ // function.
+ if (MDNode *MD =
+ Q.IIQ.getMetadata(cast<Instruction>(I), LLVMContext::MD_range))
+ computeKnownBitsFromRangeMetadata(*MD, Known);
+
+ const auto *CB = cast<CallBase>(I);
+
+ if (std::optional<ConstantRange> Range = CB->getRange())
+ Known = Known.unionWith(Range->toKnownBits());
+
+ if (const Value *RV = CB->getReturnedArgOperand()) {
+ if (RV->getType() == I->getType()) {
+ computeKnownBits(RV, Known2, Q, Depth + 1);
+ Known = Known.unionWith(Known2);
+ // If the function doesn't return properly for all input values
+ // (e.g. unreachable exits) then there might be conflicts between the
+ // argument value and the range metadata. Simply discard the known bits
+ // in case of conflicts.
+ if (Known.hasConflict())
+ Known.resetAll();
+ }
+ }
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+ switch (II->getIntrinsicID()) {
+ default:
+ break;
+ case Intrinsic::abs: {
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ bool IntMinIsPoison = match(II->getArgOperand(1), m_One());
+ Known = Known.unionWith(Known2.abs(IntMinIsPoison));
+ break;
+ }
+ case Intrinsic::bitreverse:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ Known = Known.unionWith(Known2.reverseBits());
+ break;
+ case Intrinsic::bswap:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ Known = Known.unionWith(Known2.byteSwap());
+ break;
+ case Intrinsic::ctlz: {
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ // If we have a known 1, its position is our upper bound.
+ unsigned PossibleLZ = Known2.countMaxLeadingZeros();
+ // If this call is poison for 0 input, the result will be less than 2^n.
+ if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext()))
+ PossibleLZ = std::min(PossibleLZ, BitWidth - 1);
+ unsigned LowBits = llvm::bit_width(PossibleLZ);
+ Known.Zero.setBitsFrom(LowBits);
+ break;
+ }
+ case Intrinsic::cttz: {
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ // If we have a known 1, its position is our upper bound.
+ unsigned PossibleTZ = Known2.countMaxTrailingZeros();
+ // If this call is poison for 0 input, the result will be less than 2^n.
+ if (II->getArgOperand(1) == ConstantInt::getTrue(II->getContext()))
+ PossibleTZ = std::min(PossibleTZ, BitWidth - 1);
+ unsigned LowBits = llvm::bit_width(PossibleTZ);
+ Known.Zero.setBitsFrom(LowBits);
+ break;
+ }
+ case Intrinsic::ctpop: {
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ // We can bound the space the count needs. Also, bits known to be zero
+ // can't contribute to the population.
+ unsigned BitsPossiblySet = Known2.countMaxPopulation();
+ unsigned LowBits = llvm::bit_width(BitsPossiblySet);
+ Known.Zero.setBitsFrom(LowBits);
+ // TODO: we could bound KnownOne using the lower bound on the number
+ // of bits which might be set provided by popcnt KnownOne2.
+ break;
+ }
+ case Intrinsic::fshr:
+ case Intrinsic::fshl: {
+ const APInt *SA;
+ if (!match(I->getOperand(2), m_APInt(SA)))
+ break;
+
+ // Normalize to funnel shift left.
+ uint64_t ShiftAmt = SA->urem(BitWidth);
+ if (II->getIntrinsicID() == Intrinsic::fshr)
+ ShiftAmt = BitWidth - ShiftAmt;
+
+ KnownBits Known3(BitWidth);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known3, Q, Depth + 1);
+
+ Known2 <<= ShiftAmt;
+ Known3 >>= BitWidth - ShiftAmt;
+ Known = Known2.unionWith(Known3);
+ break;
+ }
+ case Intrinsic::clmul:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known = KnownBits::clmul(Known, Known2);
+ break;
+ case Intrinsic::uadd_sat:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known = KnownBits::uadd_sat(Known, Known2);
+ break;
+ case Intrinsic::usub_sat:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known = KnownBits::usub_sat(Known, Known2);
+ break;
+ case Intrinsic::sadd_sat:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known = KnownBits::sadd_sat(Known, Known2);
+ break;
+ case Intrinsic::ssub_sat:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known = KnownBits::ssub_sat(Known, Known2);
+ break;
+ // Vec reverse preserves bits from input vec.
+ case Intrinsic::vector_reverse:
+ computeKnownBits(I->getOperand(0), DemandedElts.reverseBits(), Known, Q,
+ Depth + 1);
+ break;
+ // for min/max/and/or reduce, any bit common to each element in the
+ // input vec is set in the output.
+ case Intrinsic::vector_reduce_and:
+ case Intrinsic::vector_reduce_or:
+ case Intrinsic::vector_reduce_umax:
+ case Intrinsic::vector_reduce_umin:
+ case Intrinsic::vector_reduce_smax:
+ case Intrinsic::vector_reduce_smin:
+ computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+ break;
+ case Intrinsic::vector_reduce_xor: {
+ computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+ // The zeros common to all vecs are zero in the output.
+ // If the number of elements is odd, then the common ones remain. If the
+ // number of elements is even, then the common ones become zeros.
+ auto *VecTy = cast<VectorType>(I->getOperand(0)->getType());
+ // Even, so the ones become zeros.
+ bool EvenCnt = VecTy->getElementCount().isKnownEven();
+ if (EvenCnt)
+ Known.Zero |= Known.One;
+ // Maybe even element count so need to clear ones.
+ if (VecTy->isScalableTy() || EvenCnt)
+ Known.One.clearAllBits();
+ break;
+ }
+ case Intrinsic::vector_reduce_add: {
+ auto *VecTy = dyn_cast<FixedVectorType>(I->getOperand(0)->getType());
+ if (!VecTy)
+ break;
+ computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+ Known = Known.reduceAdd(VecTy->getNumElements());
+ break;
+ }
+ case Intrinsic::umin:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known = KnownBits::umin(Known, Known2);
+ break;
+ case Intrinsic::umax:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known = KnownBits::umax(Known, Known2);
+ break;
+ case Intrinsic::smin:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known = KnownBits::smin(Known, Known2);
+ unionWithMinMaxIntrinsicClamp(II, Known);
+ break;
+ case Intrinsic::smax:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known = KnownBits::smax(Known, Known2);
+ unionWithMinMaxIntrinsicClamp(II, Known);
+ break;
+ case Intrinsic::ptrmask: {
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+
+ const Value *Mask = I->getOperand(1);
+ Known2 = KnownBits(Mask->getType()->getScalarSizeInBits());
+ computeKnownBits(Mask, DemandedElts, Known2, Q, Depth + 1);
+ // TODO: 1-extend would be more precise.
+ Known &= Known2.anyextOrTrunc(BitWidth);
+ break;
+ }
+ case Intrinsic::x86_sse2_pmulh_w:
+ case Intrinsic::x86_avx2_pmulh_w:
+ case Intrinsic::x86_avx512_pmulh_w_512:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known = KnownBits::mulhs(Known, Known2);
+ break;
+ case Intrinsic::x86_sse2_pmulhu_w:
+ case Intrinsic::x86_avx2_pmulhu_w:
+ case Intrinsic::x86_avx512_pmulhu_w_512:
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ Known = KnownBits::mulhu(Known, Known2);
+ break;
+ case Intrinsic::x86_sse42_crc32_64_64:
+ Known.Zero.setBitsFrom(32);
+ break;
+ case Intrinsic::x86_ssse3_phadd_d_128:
+ case Intrinsic::x86_ssse3_phadd_w_128:
+ case Intrinsic::x86_avx2_phadd_d:
+ case Intrinsic::x86_avx2_phadd_w: {
+ Known = computeKnownBitsForHorizontalOperation(
+ I, DemandedElts, Q, Depth,
+ [](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
+ return KnownBits::add(KnownLHS, KnownRHS);
+ });
+ break;
+ }
+ case Intrinsic::x86_ssse3_phadd_sw_128:
+ case Intrinsic::x86_avx2_phadd_sw: {
+ Known = computeKnownBitsForHorizontalOperation(
+ I, DemandedElts, Q, Depth, KnownBits::sadd_sat);
+ break;
+ }
+ case Intrinsic::x86_ssse3_phsub_d_128:
+ case Intrinsic::x86_ssse3_phsub_w_128:
+ case Intrinsic::x86_avx2_phsub_d:
+ case Intrinsic::x86_avx2_phsub_w: {
+ Known = computeKnownBitsForHorizontalOperation(
+ I, DemandedElts, Q, Depth,
+ [](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
+ return KnownBits::sub(KnownLHS, KnownRHS);
+ });
+ break;
+ }
+ case Intrinsic::x86_ssse3_phsub_sw_128:
+ case Intrinsic::x86_avx2_phsub_sw: {
+ Known = computeKnownBitsForHorizontalOperation(
+ I, DemandedElts, Q, Depth, KnownBits::ssub_sat);
+ break;
+ }
+ case Intrinsic::riscv_vsetvli:
+ case Intrinsic::riscv_vsetvlimax: {
+ bool HasAVL = II->getIntrinsicID() == Intrinsic::riscv_vsetvli;
+ const ConstantRange Range = getVScaleRange(II->getFunction(), BitWidth);
+ uint64_t SEW = RISCVVType::decodeVSEW(
+ cast<ConstantInt>(II->getArgOperand(HasAVL))->getZExtValue());
+ RISCVVType::VLMUL VLMUL = static_cast<RISCVVType::VLMUL>(
+ cast<ConstantInt>(II->getArgOperand(1 + HasAVL))->getZExtValue());
+ uint64_t MaxVLEN =
+ Range.getUnsignedMax().getZExtValue() * RISCV::RVVBitsPerBlock;
+ uint64_t MaxVL = MaxVLEN / RISCVVType::getSEWLMULRatio(SEW, VLMUL);
+
+ // Result of vsetvli must be not larger than AVL.
+ if (HasAVL)
+ if (auto *CI = dyn_cast<ConstantInt>(II->getArgOperand(0)))
+ MaxVL = std::min(MaxVL, CI->getZExtValue());
+
+ unsigned KnownZeroFirstBit = Log2_32(MaxVL) + 1;
+ if (BitWidth > KnownZeroFirstBit)
+ Known.Zero.setBitsFrom(KnownZeroFirstBit);
+ break;
+ }
+ case Intrinsic::vscale: {
+ if (!II->getParent() || !II->getFunction())
+ break;
+
+ Known = getVScaleRange(II->getFunction(), BitWidth).toKnownBits();
+ break;
+ }
+ }
+ }
+ break;
+ }
+ case Instruction::ShuffleVector: {
+ if (auto *Splat = getSplatValue(I)) {
+ computeKnownBits(Splat, Known, Q, Depth + 1);
+ break;
+ }
+
+ auto *Shuf = dyn_cast<ShuffleVectorInst>(I);
+ // FIXME: Do we need to handle ConstantExpr involving shufflevectors?
+ if (!Shuf) {
+ Known.resetAll();
+ return;
+ }
+ // For undef elements, we don't know anything about the common state of
+ // the shuffle result.
+ APInt DemandedLHS, DemandedRHS;
+ if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS)) {
+ Known.resetAll();
+ return;
+ }
+ Known.setAllConflict();
+ if (!!DemandedLHS) {
+ const Value *LHS = Shuf->getOperand(0);
+ computeKnownBits(LHS, DemandedLHS, Known, Q, Depth + 1);
+ // If we don't know any bits, early out.
+ if (Known.isUnknown())
+ break;
+ }
+ if (!!DemandedRHS) {
+ const Value *RHS = Shuf->getOperand(1);
+ computeKnownBits(RHS, DemandedRHS, Known2, Q, Depth + 1);
+ Known = Known.intersectWith(Known2);
+ }
+ break;
+ }
+ case Instruction::InsertElement: {
+ if (isa<ScalableVectorType>(I->getType())) {
+ Known.resetAll();
+ return;
+ }
+ const Value *Vec = I->getOperand(0);
+ const Value *Elt = I->getOperand(1);
+ auto *CIdx = dyn_cast<ConstantInt>(I->getOperand(2));
+ unsigned NumElts = DemandedElts.getBitWidth();
+ APInt DemandedVecElts = DemandedElts;
+ bool NeedsElt = true;
+ // If we know the index we are inserting to, clear it from the Vec check.
+ if (CIdx && CIdx->getValue().ult(NumElts)) {
+ DemandedVecElts.clearBit(CIdx->getZExtValue());
+ NeedsElt = DemandedElts[CIdx->getZExtValue()];
+ }
+
+ Known.setAllConflict();
+ if (NeedsElt) {
+ computeKnownBits(Elt, Known, Q, Depth + 1);
+ // If we don't know any bits, early out.
+ if (Known.isUnknown())
+ break;
+ }
+
+ if (!DemandedVecElts.isZero()) {
+ computeKnownBits(Vec, DemandedVecElts, Known2, Q, Depth + 1);
+ Known = Known.intersectWith(Known2);
+ }
+ break;
+ }
+ case Instruction::ExtractElement: {
+ // Look through extract element. If the index is non-constant or
+ // out-of-range demand all elements, otherwise just the extracted element.
+ const Value *Vec = I->getOperand(0);
+ const Value *Idx = I->getOperand(1);
+ auto *CIdx = dyn_cast<ConstantInt>(Idx);
+ if (isa<ScalableVectorType>(Vec->getType())) {
+ // FIXME: there's probably *something* we can do with scalable vectors
+ Known.resetAll();
+ break;
+ }
+ unsigned NumElts = cast<FixedVectorType>(Vec->getType())->getNumElements();
+ APInt DemandedVecElts = APInt::getAllOnes(NumElts);
+ if (CIdx && CIdx->getValue().ult(NumElts))
+ DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
+ computeKnownBits(Vec, DemandedVecElts, Known, Q, Depth + 1);
+ break;
+ }
+ case Instruction::ExtractValue:
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I->getOperand(0))) {
+ const ExtractValueInst *EVI = cast<ExtractValueInst>(I);
+ if (EVI->getNumIndices() != 1)
+ break;
+ if (EVI->getIndices()[0] == 0) {
+ switch (II->getIntrinsicID()) {
+ default:
+ break;
+ case Intrinsic::uadd_with_overflow:
+ case Intrinsic::sadd_with_overflow:
+ computeKnownBitsAddSub(
+ true, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false,
+ /* NUW=*/false, DemandedElts, Known, Known2, Q, Depth);
+ break;
+ case Intrinsic::usub_with_overflow:
+ case Intrinsic::ssub_with_overflow:
+ computeKnownBitsAddSub(
+ false, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false,
+ /* NUW=*/false, DemandedElts, Known, Known2, Q, Depth);
+ break;
+ case Intrinsic::umul_with_overflow:
+ case Intrinsic::smul_with_overflow:
+ computeKnownBitsMul(II->getArgOperand(0), II->getArgOperand(1), false,
+ false, DemandedElts, Known, Known2, Q, Depth);
+ break;
+ }
+ }
+ }
+ break;
+ case Instruction::Freeze:
+ if (isGuaranteedNotToBePoison(I->getOperand(0), Q.AC, Q.CxtI, Q.DT,
+ Depth + 1))
+ computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+ break;
+ }
+}
+
+/// Core recursive entry point: compute the known bits of \p V, restricted to
+/// the vector lanes selected by \p DemandedElts, into \p Known.
+///
+/// The ordering of the checks below is significant:
+///  1. Constants (splat ints, null/zero aggregates, constant vectors) are
+///   answered directly, without recursing or charging depth.
+///  2. Undef yields nothing; all other ConstantData must already be handled.
+///  3. An argument's range(...) attribute seeds Known -- note this happens
+///   before the recursion-depth cap, so it applies even at maximum depth.
+///  4. Operator / global-alias / absolute-symbol analysis runs only below
+///   the depth cap.
+///  5. Pointer alignment and context facts (e.g. assumes) strictly refine
+///   Known last.
+static void computeKnownBits(const Value *V, const APInt &DemandedElts,
+ KnownBits &Known, const SimplifyQuery &Q,
+ unsigned Depth) {
+ if (!DemandedElts) {
+ // No demanded elts, better to assume we don't know anything.
+ Known.resetAll();
+ return;
+ }
+
+ assert(V && "No Value?");
+ assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
+
+#ifndef NDEBUG
+ // Debug-only sanity checks: the DemandedElts mask width must match the
+ // fixed vector element count (or be the 1-bit mask for scalars and
+ // scalable vectors), and Known's bit width must match V's scalar width.
+ Type *Ty = V->getType();
+ unsigned BitWidth = Known.getBitWidth();
+
+ assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) &&
+ "Not integer or pointer type!");
+
+ if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
+ assert(
+ FVTy->getNumElements() == DemandedElts.getBitWidth() &&
+ "DemandedElt width should equal the fixed vector number of elements");
+ } else {
+ assert(DemandedElts == APInt(1, 1) &&
+ "DemandedElt width should be 1 for scalars or scalable vectors");
+ }
+
+ Type *ScalarTy = Ty->getScalarType();
+ if (ScalarTy->isPointerTy()) {
+ assert(BitWidth == Q.DL.getPointerTypeSizeInBits(ScalarTy) &&
+ "V and Known should have same BitWidth");
+ } else {
+ assert(BitWidth == Q.DL.getTypeSizeInBits(ScalarTy) &&
+ "V and Known should have same BitWidth");
+ }
+#endif
+
+ const APInt *C;
+ if (match(V, m_APInt(C))) {
+ // We know all of the bits for a scalar constant or a splat vector constant!
+ Known = KnownBits::makeConstant(*C);
+ return;
+ }
+ // Null and aggregate-zero are all-zeros.
+ if (isa<ConstantPointerNull>(V) || isa<ConstantAggregateZero>(V)) {
+ Known.setAllZero();
+ return;
+ }
+ // Handle a constant vector by taking the intersection of the known bits of
+ // each element.
+ if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(V)) {
+ assert(!isa<ScalableVectorType>(V->getType()));
+ // We know that CDV must be a vector of integers. Take the intersection of
+ // each element.
+ Known.setAllConflict();
+ for (unsigned i = 0, e = CDV->getNumElements(); i != e; ++i) {
+ if (!DemandedElts[i])
+ continue;
+ APInt Elt = CDV->getElementAsAPInt(i);
+ Known.Zero &= ~Elt;
+ Known.One &= Elt;
+ }
+ if (Known.hasConflict())
+ Known.resetAll();
+ return;
+ }
+
+ if (const auto *CV = dyn_cast<ConstantVector>(V)) {
+ assert(!isa<ScalableVectorType>(V->getType()));
+ // We know that CV must be a vector of integers. Take the intersection of
+ // each element.
+ Known.setAllConflict();
+ for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) {
+ if (!DemandedElts[i])
+ continue;
+ Constant *Element = CV->getAggregateElement(i);
+ if (isa<PoisonValue>(Element))
+ continue;
+ auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element);
+ if (!ElementCI) {
+ // A non-integer element (e.g. a constant expression) defeats the
+ // per-element intersection; give up on the whole vector.
+ Known.resetAll();
+ return;
+ }
+ const APInt &Elt = ElementCI->getValue();
+ Known.Zero &= ~Elt;
+ Known.One &= Elt;
+ }
+ if (Known.hasConflict())
+ Known.resetAll();
+ return;
+ }
+
+ // Start out not knowing anything.
+ Known.resetAll();
+
+ // We can't imply anything about undefs.
+ if (isa<UndefValue>(V))
+ return;
+
+ // There's no point in looking through other users of ConstantData for
+ // assumptions. Confirm that we've handled them all.
+ assert(!isa<ConstantData>(V) && "Unhandled constant data!");
+
+ // Seed from a range(...) attribute on an argument, if present. This runs
+ // before the depth cap below, so arguments still contribute at max depth.
+ if (const auto *A = dyn_cast<Argument>(V))
+ if (std::optional<ConstantRange> Range = A->getRange())
+ Known = Range->toKnownBits();
+
+ // All recursive calls that increase depth must come after this.
+ if (Depth == MaxAnalysisRecursionDepth)
+ return;
+
+ // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has
+ // the bits of its aliasee.
+ if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
+ if (!GA->isInterposable())
+ computeKnownBits(GA->getAliasee(), Known, Q, Depth + 1);
+ return;
+ }
+
+ if (const Operator *I = dyn_cast<Operator>(V))
+ computeKnownBitsFromOperator(I, DemandedElts, Known, Q, Depth);
+ else if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
+ if (std::optional<ConstantRange> CR = GV->getAbsoluteSymbolRange())
+ Known = CR->toKnownBits();
+ }
+
+ // Aligned pointers have trailing zeros - refine Known.Zero set
+ if (isa<PointerType>(V->getType())) {
+ Align Alignment = V->getPointerAlignment(Q.DL);
+ Known.Zero.setLowBits(Log2(Alignment));
+ }
+
+ // computeKnownBitsFromContext strictly refines Known.
+ // Therefore, we run them after computeKnownBitsFromOperator.
+
+ // Check whether we can determine known bits from context such as assumes.
+ computeKnownBitsFromContext(V, Known, Q, Depth);
+}
+
+/// Populate the cache slot of every value in \p Leaves by running the
+/// KnownBits computation from depth zero. The slot obtained via
+/// getKnownBits() is filled in place.
+void KnownBitsCache::compute(ArrayRef<const Value *> Leaves) {
+ for (const Value *Leaf : Leaves)
+ computeKnownBits(Leaf, getKnownBits(Leaf), DL, /*Depth=*/0);
+}
+
+/// Return the cached KnownBits for \p V, computing them on first use.
+/// A slot still in the all-conflict state has never been computed (it is the
+/// sentinel the cache initializes slots with -- TODO confirm against
+/// KnownBitsDataflow), so fill it before returning.
+KnownBits KnownBitsCache::getOrCompute(const Value *V) {
+ KnownBits &Slot = getKnownBits(V);
+ bool NeverComputed = Slot.isAllConflict();
+ if (NeverComputed)
+ compute(V);
+ return Slot;
+}
+
+/// Construct the cache over \p F: build the underlying KnownBitsDataflow,
+/// then eagerly compute KnownBits for the values it reports via getLeaves();
+/// other values are filled lazily through getOrCompute().
+KnownBitsCache::KnownBitsCache(Function &F) : KnownBitsDataflow(F) {
+ compute(getLeaves());
+}
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 3c499862642db..2109cdfbedc42 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -83,6 +83,7 @@
#include <utility>
using namespace llvm;
+using namespace vthelper;
using namespace llvm::PatternMatch;
// Controls the number of uses of the value searched for possible
@@ -96,7 +97,7 @@ static constexpr unsigned MaxInstrsToCheckForFree = 16;
/// Returns the bitwidth of the given scalar or pointer type. For vector types,
/// returns the element type's bitwidth.
-static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
+unsigned vthelper::getBitWidth(Type *Ty, const DataLayout &DL) {
if (unsigned BitWidth = Ty->getScalarSizeInBits())
return BitWidth;
@@ -105,7 +106,7 @@ static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
// Given the provided Value and, potentially, a context instruction, return
// the preferred context instruction (if any).
-static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) {
+const Instruction *vthelper::safeCxtI(const Value *V, const Instruction *CxtI) {
// If we've been provided with a context instruction, then use that (provided
// it has been inserted).
if (CxtI && CxtI->getParent())
@@ -119,9 +120,9 @@ static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) {
return nullptr;
}
-static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
- const APInt &DemandedElts,
- APInt &DemandedLHS, APInt &DemandedRHS) {
+bool vthelper::getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
+ const APInt &DemandedElts,
+ APInt &DemandedLHS, APInt &DemandedRHS) {
if (isa<ScalableVectorType>(Shuf->getType())) {
assert(DemandedElts == APInt(1,1));
DemandedLHS = DemandedRHS = DemandedElts;
@@ -275,9 +276,6 @@ bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL,
Depth);
}
-static bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
- const SimplifyQuery &Q, unsigned Depth);
-
bool llvm::isKnownNonNegative(const Value *V, const SimplifyQuery &SQ,
unsigned Depth) {
return computeKnownBits(V, SQ, Depth).isNonNegative();
@@ -780,9 +778,9 @@ static bool cmpExcludesZero(CmpInst::Predicate Pred, const Value *RHS) {
return true;
}
-static void breakSelfRecursivePHI(const Use *U, const PHINode *PHI,
- Value *&ValOut, Instruction *&CtxIOut,
- const PHINode **PhiOut = nullptr) {
+void vthelper::breakSelfRecursivePHI(const Use *U, const PHINode *PHI,
+ Value *&ValOut, Instruction *&CtxIOut,
+ const PHINode **PhiOut) {
ValOut = U->get();
if (ValOut == PHI)
return;
@@ -1393,8 +1391,8 @@ static bool isSignedMinMaxIntrinsicClamp(const IntrinsicInst *II,
return CLow->sle(*CHigh);
}
-static void unionWithMinMaxIntrinsicClamp(const IntrinsicInst *II,
- KnownBits &Known) {
+void vthelper::unionWithMinMaxIntrinsicClamp(const IntrinsicInst *II,
+ KnownBits &Known) {
const APInt *CLow, *CHigh;
if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
Known = Known.unionWith(
@@ -3689,8 +3687,8 @@ static bool isKnownNonZeroFromOperator(const Operator *I,
/// specified, perform context-sensitive analysis and return true if the
/// pointer couldn't possibly be null at the specified instruction.
/// Supports values with integer or pointer type and vectors of integers.
-bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
- const SimplifyQuery &Q, unsigned Depth) {
+bool vthelper::isKnownNonZero(const Value *V, const APInt &DemandedElts,
+ const SimplifyQuery &Q, unsigned Depth) {
Type *Ty = V->getType();
#ifndef NDEBUG
@@ -3792,7 +3790,7 @@ bool llvm::isKnownNonZero(const Value *V, const SimplifyQuery &Q,
auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
APInt DemandedElts =
FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
- return ::isKnownNonZero(V, DemandedElts, Q, Depth);
+ return vthelper::isKnownNonZero(V, DemandedElts, Q, Depth);
}
/// If the pair of operators are the same invertible function, return the
diff --git a/llvm/unittests/Analysis/KnownBitsAnalysisTest.cpp b/llvm/unittests/Analysis/KnownBitsAnalysisTest.cpp
index ea1539f05b538..09f4fc66b11a5 100644
--- a/llvm/unittests/Analysis/KnownBitsAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/KnownBitsAnalysisTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/KnownBitsAnalysis.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
@@ -368,4 +369,73 @@ define void @test(i32 %n) {
)";
EXPECT_EQ(ActualOutput, ExpectedOutput);
}
+
+/// KnownBitsCache tests follow.
+TEST(KnownBitsCache, KnownBitsComputationParity) {
+ LLVMContext Ctx;
+ std::unique_ptr<Module> M = parseIR(Ctx, R"(
+define void @test(i32 %int_arg, float %float_arg, ptr %ptr_arg, <2 x i32> %vec_int_arg, <2 x ptr> %vec_ptr_arg) {
+entry:
+ br i1 poison, label %then, label %else
+then:
+ %int_val = add i32 %int_arg, 1
+ %float_val = fadd float %float_arg, 1.0
+ %vec_val = add <2 x i32> %vec_int_arg, <i32 1, i32 2>
+ br label %merge
+else:
+ %fpconv = fptoui float %float_arg to i32
+ %int_val2 = mul i32 %int_arg, %fpconv
+ %ptr_val = getelementptr i8, ptr %ptr_arg, i32 4
+ %vec_val2 = mul <2 x i32> %vec_int_arg, <i32 3, i32 4>
+ br label %merge
+merge:
+ %phi_int = phi i32 [ %int_val, %then ], [ %int_val2, %else ]
+ %phi_float = phi float [ %float_val, %then ], [ %float_arg, %else ]
+ %phi_ptr = phi ptr [ %ptr_arg, %then ], [ %ptr_val, %else ]
+ %phi_vec = phi <2 x i32> [ %vec_val, %then ], [ %vec_val2, %else ]
+ %final_int = add i32 %phi_int, 5
+ %vec_ptr_conv = ptrtoint <2 x ptr> %vec_ptr_arg to <2 x i32>
+ %final_vec = add <2 x i32> %phi_vec, %vec_ptr_conv
+ store float %phi_float, ptr %phi_ptr
+ ret void
+})");
+ Function *F = M->getFunction("test");
+ const DataLayout &DL = F->getDataLayout();
+ KnownBitsCache Lat(*F);
+ auto *ArgIt = F->arg_begin();
+ Argument *IntArg = &*ArgIt++;
+ ArgIt++;
+ Argument *PtrArg = &*ArgIt++;
+ Argument *VecIntArg = &*ArgIt++;
+ Argument *VecPtrArg = &*ArgIt++;
+ Instruction *IntVal = findInstructionByName(F, "int_val");
+ Instruction *VecVal = findInstructionByName(F, "vec_val");
+ Instruction *IntVal2 = findInstructionByName(F, "int_val2");
+ Instruction *PtrVal = findInstructionByName(F, "ptr_val");
+ Instruction *VecVal2 = findInstructionByName(F, "vec_val2");
+ Instruction *PhiInt = findInstructionByName(F, "phi_int");
+ Instruction *PhiPtr = findInstructionByName(F, "phi_ptr");
+ Instruction *PhiVec = findInstructionByName(F, "phi_vec");
+ Instruction *FinalInt = findInstructionByName(F, "final_int");
+ Instruction *FPConv = findInstructionByName(F, "fpconv");
+ Instruction *VecPtrConv = findInstructionByName(F, "vec_ptr_conv");
+ Instruction *FinalVec = findInstructionByName(F, "final_vec");
+
+ EXPECT_EQ(Lat.getOrCompute(IntArg), computeKnownBits(IntArg, DL));
+ EXPECT_EQ(Lat.getOrCompute(PtrArg), computeKnownBits(PtrArg, DL));
+ EXPECT_EQ(Lat.getOrCompute(VecIntArg), computeKnownBits(VecIntArg, DL));
+ EXPECT_EQ(Lat.getOrCompute(VecPtrArg), computeKnownBits(VecPtrArg, DL));
+ EXPECT_EQ(Lat.getOrCompute(IntVal), computeKnownBits(IntVal, DL));
+ EXPECT_EQ(Lat.getOrCompute(VecVal), computeKnownBits(VecVal, DL));
+ EXPECT_EQ(Lat.getOrCompute(IntVal2), computeKnownBits(IntVal2, DL));
+ EXPECT_EQ(Lat.getOrCompute(PtrVal), computeKnownBits(PtrVal, DL));
+ EXPECT_EQ(Lat.getOrCompute(VecVal2), computeKnownBits(VecVal2, DL));
+ EXPECT_EQ(Lat.getOrCompute(PhiInt), computeKnownBits(PhiInt, DL));
+ EXPECT_EQ(Lat.getOrCompute(PhiPtr), computeKnownBits(PhiPtr, DL));
+ EXPECT_EQ(Lat.getOrCompute(PhiVec), computeKnownBits(PhiVec, DL));
+ EXPECT_EQ(Lat.getOrCompute(FinalInt), computeKnownBits(FinalInt, DL));
+ EXPECT_EQ(Lat.getOrCompute(FPConv), computeKnownBits(FPConv, DL));
+ EXPECT_EQ(Lat.getOrCompute(VecPtrConv), computeKnownBits(VecPtrConv, DL));
+ EXPECT_EQ(Lat.getOrCompute(FinalVec), computeKnownBits(FinalVec, DL));
+}
} // namespace
>From 7d787c6bec410a6b34e45733926852daf4c86b84 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <artagnon at tenstorrent.com>
Date: Fri, 27 Feb 2026 11:39:08 +0000
Subject: [PATCH 3/4] [KnownBitsCache] Avoid redundant KnownBits computations
---
.../include/llvm/Analysis/KnownBitsAnalysis.h | 37 ++
.../llvm/Analysis/ValueTrackingHelper.h | 3 +
llvm/lib/Analysis/KnownBitsAnalysis.cpp | 508 +++++++++---------
llvm/lib/Analysis/ValueTracking.cpp | 12 +-
.../Analysis/KnownBitsAnalysisTest.cpp | 41 +-
5 files changed, 321 insertions(+), 280 deletions(-)
diff --git a/llvm/include/llvm/Analysis/KnownBitsAnalysis.h b/llvm/include/llvm/Analysis/KnownBitsAnalysis.h
index 8eceef2ed3c9a..d77a3c505eacd 100644
--- a/llvm/include/llvm/Analysis/KnownBitsAnalysis.h
+++ b/llvm/include/llvm/Analysis/KnownBitsAnalysis.h
@@ -21,8 +21,10 @@ namespace llvm {
class Value;
class User;
class Function;
+class Operator;
class DataLayout;
class raw_ostream;
+struct SimplifyQuery;
class KnownBitsDataflow : protected DenseMap<const Value *, KnownBits> {
/// The roots are the arguments of the function, and PHI nodes and
@@ -97,6 +99,41 @@ class KnownBitsDataflow : protected DenseMap<const Value *, KnownBits> {
};
class KnownBitsCache : protected KnownBitsDataflow {
+ KnownBits computeKnownBitsForHorizontalOperation(
+ const Operator *I, const APInt &DemandedElts, const SimplifyQuery &Q,
+ const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
+ KnownBitsFunc);
+ void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
+ const APInt &DemandedElts,
+ KnownBits &KnownOut,
+ const SimplifyQuery &Q);
+ void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
+ bool NSW, bool NUW, const APInt &DemandedElts,
+ KnownBits &KnownOut, KnownBits &Known2,
+ const SimplifyQuery &Q);
+ void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
+ bool NUW, const APInt &DemandedElts,
+ KnownBits &Known, KnownBits &Known2,
+ const SimplifyQuery &Q);
+ void computeKnownBitsFromShiftOperator(
+ const Operator *I, const APInt &DemandedElts, KnownBits &Known,
+ KnownBits &Known2, const SimplifyQuery &Q,
+ function_ref<KnownBits(const KnownBits &, const KnownBits &, bool)> KF);
+ KnownBits getKnownBitsFromAndXorOr(const Operator *I,
+ const APInt &DemandedElts,
+ const KnownBits &KnownLHS,
+ const KnownBits &KnownRHS,
+ const SimplifyQuery &Q);
+ void computeKnownBitsFromOperator(const Operator *I,
+ const APInt &DemandedElts, KnownBits &Known,
+ const SimplifyQuery &Q);
+ KnownBits computeKnownBits(const Value *V, const SimplifyQuery &Q);
+ KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
+ const SimplifyQuery &Q);
+ void computeKnownBits(const Value *V, const APInt &DemandedElts,
+ KnownBits &Known, const SimplifyQuery &Q);
+ void computeKnownBits(const Value *V, KnownBits &Known,
+ const SimplifyQuery &Q);
void compute(ArrayRef<const Value *> Leaves);
public:
diff --git a/llvm/include/llvm/Analysis/ValueTrackingHelper.h b/llvm/include/llvm/Analysis/ValueTrackingHelper.h
index a934837fd962c..13bf9b7f9d2db 100644
--- a/llvm/include/llvm/Analysis/ValueTrackingHelper.h
+++ b/llvm/include/llvm/Analysis/ValueTrackingHelper.h
@@ -144,6 +144,9 @@ LLVM_ABI bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
LLVM_ABI bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
const APInt &DemandedElts,
APInt &DemandedLHS, APInt &DemandedRHS);
+LLVM_ABI void computeKnownBits(const Value *V, const APInt &DemandedElts,
+ KnownBits &Known, const SimplifyQuery &Q,
+ unsigned Depth);
} // namespace vthelper
} // namespace llvm
diff --git a/llvm/lib/Analysis/KnownBitsAnalysis.cpp b/llvm/lib/Analysis/KnownBitsAnalysis.cpp
index a60a6e4f63833..df85178aa7333 100644
--- a/llvm/lib/Analysis/KnownBitsAnalysis.cpp
+++ b/llvm/lib/Analysis/KnownBitsAnalysis.cpp
@@ -164,38 +164,8 @@ LLVM_DUMP_METHOD void KnownBitsDataflow::dump() const { print(dbgs()); }
constexpr unsigned MaxAnalysisRecursionDepth = 6;
-static void computeKnownBits(const Value *V, const APInt &DemandedElts,
- KnownBits &Known, const SimplifyQuery &Q,
- unsigned Depth);
-
-static KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
- const SimplifyQuery &Q, unsigned Depth) {
- KnownBits Known(getBitWidth(V->getType(), Q.DL));
- ::computeKnownBits(V, DemandedElts, Known, Q, Depth);
- return Known;
-}
-
-static void computeKnownBits(const Value *V, KnownBits &Known,
- const SimplifyQuery &Q, unsigned Depth) {
- // Since the number of lanes in a scalable vector is unknown at compile time,
- // we track one bit which is implicitly broadcast to all lanes. This means
- // that all lanes in a scalable vector are considered demanded.
- auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
- APInt DemandedElts =
- FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
- ::computeKnownBits(V, DemandedElts, Known, Q, Depth);
-}
-
-static KnownBits computeKnownBits(const Value *V, const SimplifyQuery &Q,
- unsigned Depth) {
- KnownBits Known(getBitWidth(V->getType(), Q.DL));
- computeKnownBits(V, Known, Q, Depth);
- return Known;
-}
-
-static KnownBits computeKnownBitsForHorizontalOperation(
+KnownBits KnownBitsCache::computeKnownBitsForHorizontalOperation(
const Operator *I, const APInt &DemandedElts, const SimplifyQuery &Q,
- unsigned Depth,
const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
KnownBitsFunc) {
APInt DemandedEltsLHS, DemandedEltsRHS;
@@ -203,12 +173,11 @@ static KnownBits computeKnownBitsForHorizontalOperation(
DemandedElts, DemandedEltsLHS,
DemandedEltsRHS);
- const auto ComputeForSingleOpFunc =
- [Depth, &Q, KnownBitsFunc](const Value *Op, APInt &DemandedEltsOp) {
- return KnownBitsFunc(
- computeKnownBits(Op, DemandedEltsOp, Q, Depth + 1),
- computeKnownBits(Op, DemandedEltsOp << 1, Q, Depth + 1));
- };
+ const auto ComputeForSingleOpFunc = [&](const Value *Op,
+ APInt &DemandedEltsOp) {
+ return KnownBitsFunc(computeKnownBits(Op, DemandedEltsOp, Q),
+ computeKnownBits(Op, DemandedEltsOp << 1, Q));
+ };
if (DemandedEltsRHS.isZero())
return ComputeForSingleOpFunc(I->getOperand(0), DemandedEltsLHS);
@@ -219,11 +188,11 @@ static KnownBits computeKnownBitsForHorizontalOperation(
.intersectWith(ComputeForSingleOpFunc(I->getOperand(1), DemandedEltsRHS));
}
-static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
- const APInt &DemandedElts,
- KnownBits &KnownOut,
- const SimplifyQuery &Q,
- unsigned Depth) {
+void KnownBitsCache::computeKnownBitsFromLerpPattern(const Value *Op0,
+ const Value *Op1,
+ const APInt &DemandedElts,
+ KnownBits &KnownOut,
+ const SimplifyQuery &Q) {
Type *Ty = Op0->getType();
const unsigned BitWidth = Ty->getScalarSizeInBits();
@@ -285,7 +254,7 @@ static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
const auto ComputeKnownBitsOrOne = [&](const Value *V) {
// For some of the values we use the convention of leaving
// it nullptr to signify an implicit constant 1.
- return V ? computeKnownBits(V, DemandedElts, Q, Depth + 1)
+ return V ? computeKnownBits(V, DemandedElts, Q)
: KnownBits::makeConstant(APInt(BitWidth, 1));
};
@@ -298,11 +267,11 @@ static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
if (!KnownD.isNonNegative())
return;
- const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q, Depth + 1);
+ const KnownBits KnownB = computeKnownBits(B, DemandedElts, Q);
if (!KnownB.isNonNegative())
return;
- const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q, Depth + 1);
+ const KnownBits KnownC = computeKnownBits(C, DemandedElts, Q);
if (!KnownC.isNonNegative())
return;
@@ -333,19 +302,20 @@ static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
KnownOut.Zero.setHighBits(MinimumNumberOfLeadingZeros);
}
-static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
- bool NSW, bool NUW,
- const APInt &DemandedElts,
- KnownBits &KnownOut, KnownBits &Known2,
- const SimplifyQuery &Q, unsigned Depth) {
- computeKnownBits(Op1, DemandedElts, KnownOut, Q, Depth + 1);
+void KnownBitsCache::computeKnownBitsAddSub(bool Add, const Value *Op0,
+ const Value *Op1, bool NSW,
+ bool NUW, const APInt &DemandedElts,
+ KnownBits &KnownOut,
+ KnownBits &Known2,
+ const SimplifyQuery &Q) {
+ computeKnownBits(Op1, DemandedElts, KnownOut, Q);
// If one operand is unknown and we have no nowrap information,
// the result will be unknown independently of the second operand.
if (KnownOut.isUnknown() && !NSW && !NUW)
return;
- computeKnownBits(Op0, DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(Op0, DemandedElts, Known2, Q);
KnownOut = KnownBits::computeForAddSub(Add, NSW, NUW, Known2, KnownOut);
if (!Add && NSW && !KnownOut.isNonNegative() &&
@@ -355,15 +325,16 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
if (Add)
// Try to match lerp pattern and combine results
- computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, KnownOut, Q, Depth);
+ computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, KnownOut, Q);
}
-static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
- bool NUW, const APInt &DemandedElts,
- KnownBits &Known, KnownBits &Known2,
- const SimplifyQuery &Q, unsigned Depth) {
- computeKnownBits(Op1, DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(Op0, DemandedElts, Known2, Q, Depth + 1);
+void KnownBitsCache::computeKnownBitsMul(const Value *Op0, const Value *Op1,
+ bool NSW, bool NUW,
+ const APInt &DemandedElts,
+ KnownBits &Known, KnownBits &Known2,
+ const SimplifyQuery &Q) {
+ computeKnownBits(Op1, DemandedElts, Known, Q);
+ computeKnownBits(Op0, DemandedElts, Known2, Q);
bool isKnownNegative = false;
bool isKnownNonNegative = false;
@@ -399,8 +370,7 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
bool SelfMultiply = Op0 == Op1;
if (SelfMultiply)
- SelfMultiply &=
- isGuaranteedNotToBeUndef(Op0, Q.AC, Q.CxtI, Q.DT, Depth + 1);
+ SelfMultiply &= isGuaranteedNotToBeUndef(Op0, Q.AC, Q.CxtI, Q.DT);
// MISSING: computeNumSignBits in case of SelfMultiply.
@@ -417,25 +387,25 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
Known.makeNegative();
}
-static void computeKnownBitsFromShiftOperator(
+void KnownBitsCache::computeKnownBitsFromShiftOperator(
const Operator *I, const APInt &DemandedElts, KnownBits &Known,
- KnownBits &Known2, const SimplifyQuery &Q, unsigned Depth,
+ KnownBits &Known2, const SimplifyQuery &Q,
function_ref<KnownBits(const KnownBits &, const KnownBits &, bool)> KF) {
- computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known, Q);
// To limit compile-time impact, only query isKnownNonZero() if we know at
// least something about the shift amount.
bool ShAmtNonZero =
- Known.isNonZero() ||
- (Known.getMaxValue().ult(Known.getBitWidth()) &&
- isKnownNonZero(I->getOperand(1), DemandedElts, Q, Depth + 1));
+ Known.isNonZero() || (Known.getMaxValue().ult(Known.getBitWidth()) &&
+ isKnownNonZero(I->getOperand(1), DemandedElts, Q));
Known = KF(Known2, Known, ShAmtNonZero);
}
-static KnownBits
-getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
- const KnownBits &KnownLHS, const KnownBits &KnownRHS,
- const SimplifyQuery &Q, unsigned Depth) {
+KnownBits KnownBitsCache::getKnownBitsFromAndXorOr(const Operator *I,
+ const APInt &DemandedElts,
+ const KnownBits &KnownLHS,
+ const KnownBits &KnownRHS,
+ const SimplifyQuery &Q) {
unsigned BitWidth = KnownLHS.getBitWidth();
KnownBits KnownOut(BitWidth);
bool IsAnd = false;
@@ -492,7 +462,7 @@ getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
match(I, m_c_BinOp(m_Value(X), m_Sub(m_Deferred(X), m_Value(Y)))) ||
match(I, m_c_BinOp(m_Value(X), m_Sub(m_Value(Y), m_Deferred(X)))))) {
KnownBits KnownY(BitWidth);
- computeKnownBits(Y, DemandedElts, KnownY, Q, Depth + 1);
+ computeKnownBits(Y, DemandedElts, KnownY, Q);
if (KnownY.countMinTrailingOnes() > 0) {
if (IsAnd)
KnownOut.Zero.setBit(0);
@@ -503,11 +473,10 @@ getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
return KnownOut;
}
-static void computeKnownBitsFromOperator(const Operator *I,
- const APInt &DemandedElts,
- KnownBits &Known,
- const SimplifyQuery &Q,
- unsigned Depth) {
+void KnownBitsCache::computeKnownBitsFromOperator(const Operator *I,
+ const APInt &DemandedElts,
+ KnownBits &Known,
+ const SimplifyQuery &Q) {
unsigned BitWidth = Known.getBitWidth();
KnownBits Known2(BitWidth);
@@ -520,40 +489,40 @@ static void computeKnownBitsFromOperator(const Operator *I,
computeKnownBitsFromRangeMetadata(*MD, Known);
break;
case Instruction::And:
- computeKnownBits(I->getOperand(1), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q);
- Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Q, Depth);
+ Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Q);
break;
case Instruction::Or:
- computeKnownBits(I->getOperand(1), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q);
- Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Q, Depth);
+ Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Q);
break;
case Instruction::Xor:
- computeKnownBits(I->getOperand(1), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q);
- Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Q, Depth);
+ Known = getKnownBitsFromAndXorOr(I, DemandedElts, Known2, Known, Q);
break;
case Instruction::Mul: {
bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, NUW,
- DemandedElts, Known, Known2, Q, Depth);
+ DemandedElts, Known, Known2, Q);
break;
}
case Instruction::UDiv: {
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known =
KnownBits::udiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
break;
}
case Instruction::SDiv: {
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known =
KnownBits::sdiv(Known, Known2, Q.IIQ.isExact(cast<BinaryOperator>(I)));
break;
@@ -561,8 +530,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
case Instruction::Select: {
auto ComputeForArm = [&](Value *Arm, bool Invert) {
KnownBits Res(Known.getBitWidth());
- computeKnownBits(Arm, DemandedElts, Res, Q, Depth + 1);
- adjustKnownBitsForSelectArm(Res, I->getOperand(0), Arm, Invert, Q, Depth);
+ computeKnownBits(Arm, DemandedElts, Res, Q);
+ adjustKnownBitsForSelectArm(Res, I->getOperand(0), Arm, Invert, Q);
return Res;
};
// Only known if known in both the LHS and RHS.
@@ -597,7 +566,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
assert(SrcBitWidth && "SrcBitWidth can't be zero");
Known = Known.anyextOrTrunc(SrcBitWidth);
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
if (auto *Inst = dyn_cast<PossiblyNonNegInst>(I);
Inst && Inst->hasNonNeg() && !Known.isNegative())
Known.makeNonNegative();
@@ -610,7 +579,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
// TODO: For now, not handling conversions like:
// (bitcast i64 %x to <2 x i32>)
!I->getType()->isVectorTy()) {
- computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), Known, Q);
break;
}
@@ -653,8 +622,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
KnownBits KnownSrc(SubBitWidth);
for (unsigned i = 0; i != SubScale; ++i) {
- computeKnownBits(I->getOperand(0), SubDemandedElts.shl(i), KnownSrc, Q,
- Depth + 1);
+ computeKnownBits(I->getOperand(0), SubDemandedElts.shl(i), KnownSrc, Q);
unsigned ShiftElt = IsLE ? i : SubScale - 1 - i;
Known.insertBits(KnownSrc, ShiftElt * SubBitWidth);
}
@@ -666,8 +634,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
KnownBits KnownSrc(SubBitWidth);
APInt SubDemandedElts =
APIntOps::ScaleBitMask(DemandedElts, NumElts / SubScale);
- computeKnownBits(I->getOperand(0), SubDemandedElts, KnownSrc, Q,
- Depth + 1);
+ computeKnownBits(I->getOperand(0), SubDemandedElts, KnownSrc, Q);
Known.setAllConflict();
for (unsigned i = 0; i != NumElts; ++i) {
@@ -687,7 +654,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
Known = Known.trunc(SrcBitWidth);
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
// If the sign bit of the input is known set or clear, then we know the
// top bits of the result.
Known = Known.sext(BitWidth);
@@ -700,8 +667,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
bool ShAmtNonZero) {
return KnownBits::shl(KnownVal, KnownAmt, NUW, NSW, ShAmtNonZero);
};
- computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
- KF);
+ computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, KF);
// Trailing zeros of a right-shifted constant never decrease.
const APInt *C;
if (match(I->getOperand(0), m_APInt(C)))
@@ -714,8 +680,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
bool ShAmtNonZero) {
return KnownBits::lshr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
};
- computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
- KF);
+ computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, KF);
// Leading zeros of a left-shifted constant never decrease.
const APInt *C;
if (match(I->getOperand(0), m_APInt(C)))
@@ -728,33 +693,32 @@ static void computeKnownBitsFromOperator(const Operator *I,
bool ShAmtNonZero) {
return KnownBits::ashr(KnownVal, KnownAmt, ShAmtNonZero, Exact);
};
- computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
- KF);
+ computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, KF);
break;
}
case Instruction::Sub: {
bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, NUW,
- DemandedElts, Known, Known2, Q, Depth);
+ DemandedElts, Known, Known2, Q);
break;
}
case Instruction::Add: {
bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, NUW,
- DemandedElts, Known, Known2, Q, Depth);
+ DemandedElts, Known, Known2, Q);
break;
}
case Instruction::SRem:
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known = KnownBits::srem(Known, Known2);
break;
case Instruction::URem:
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known = KnownBits::urem(Known, Known2);
break;
case Instruction::Alloca:
@@ -763,7 +727,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
case Instruction::GetElementPtr: {
// Analyze all of the subscripts of this getelementptr instruction
// to determine if we can prove known low zero bits.
- computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), Known, Q);
// Accumulate the constant indices in a separate variable
// to minimize the number of calls to computeForAddSub.
unsigned IndexWidth = Q.DL.getIndexTypeSizeInBits(I->getType());
@@ -830,8 +794,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
}
}
- KnownBits IndexBits =
- computeKnownBits(Index, Q, Depth + 1).sextOrTrunc(IndexWidth);
+ KnownBits IndexBits = computeKnownBits(Index, Q).sextOrTrunc(IndexWidth);
KnownBits ScalingFactor(IndexWidth);
// Multiply by current sizeof type.
// &A[i] == A + i * sizeof(*A[i]).
@@ -886,7 +849,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
// add sufficient tests to cover.
SimplifyQuery RecQ = Q.getWithoutCondContext();
RecQ.CxtI = P;
- computeKnownBits(R, DemandedElts, Known2, RecQ, Depth + 1);
+ computeKnownBits(R, DemandedElts, Known2, RecQ);
switch (Opcode) {
case Instruction::Shl:
// A shl recurrence will only increase the tailing zeros
@@ -929,12 +892,12 @@ static void computeKnownBitsFromOperator(const Operator *I,
// Ok, we have a PHI of the form L op= R. Check for low
// zero bits.
RecQ.CxtI = RInst;
- computeKnownBits(R, DemandedElts, Known2, RecQ, Depth + 1);
+ computeKnownBits(R, DemandedElts, Known2, RecQ);
// We need to take the minimum number of known bits
KnownBits Known3(BitWidth);
RecQ.CxtI = LInst;
- computeKnownBits(L, DemandedElts, Known3, RecQ, Depth + 1);
+ computeKnownBits(L, DemandedElts, Known3, RecQ);
Known.Zero.setLowBits(std::min(Known2.countMinTrailingZeros(),
Known3.countMinTrailingZeros()));
@@ -996,78 +959,79 @@ static void computeKnownBitsFromOperator(const Operator *I,
// Otherwise take the unions of the known bit sets of the operands,
// taking conservative care to avoid excessive recursion.
- if (Depth < MaxAnalysisRecursionDepth - 1 && Known.isUnknown()) {
- // Skip if every incoming value references to ourself.
- if (isa_and_nonnull<UndefValue>(P->hasConstantValue()))
- break;
+ // Skip if every incoming value references ourselves.
+ if (isa_and_nonnull<UndefValue>(P->hasConstantValue()))
+ break;
- Known.setAllConflict();
- for (const Use &U : P->operands()) {
- Value *IncValue;
- const PHINode *CxtPhi;
- Instruction *CxtI;
- breakSelfRecursivePHI(&U, P, IncValue, CxtI, &CxtPhi);
- // Skip direct self references.
- if (IncValue == P)
- continue;
+ Known.setAllConflict();
+ for (const Use &U : P->operands()) {
+ Value *IncValue;
+ const PHINode *CxtPhi;
+ Instruction *CxtI;
+ breakSelfRecursivePHI(&U, P, IncValue, CxtI, &CxtPhi);
+ // Skip direct self references.
+ if (IncValue == P)
+ continue;
- // Change the context instruction to the "edge" that flows into the
- // phi. This is important because that is where the value is actually
- // "evaluated" even though it is used later somewhere else. (see also
- // D69571).
- SimplifyQuery RecQ = Q.getWithoutCondContext().getWithInstruction(CxtI);
-
- Known2 = KnownBits(BitWidth);
-
- // Recurse, but cap the recursion to one level, because we don't
- // want to waste time spinning around in loops.
- // TODO: See if we can base recursion limiter on number of incoming phi
- // edges so we don't overly clamp analysis.
- computeKnownBits(IncValue, DemandedElts, Known2, RecQ,
- MaxAnalysisRecursionDepth - 1);
-
- // See if we can further use a conditional branch into the phi
- // to help us determine the range of the value.
- if (!Known2.isConstant()) {
- CmpPredicate Pred;
- const APInt *RHSC;
- BasicBlock *TrueSucc, *FalseSucc;
- // TODO: Use RHS Value and compute range from its known bits.
- if (match(RecQ.CxtI,
- m_Br(m_c_ICmp(Pred, m_Specific(IncValue), m_APInt(RHSC)),
- m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) {
- // Check for cases of duplicate successors.
- if ((TrueSucc == CxtPhi->getParent()) !=
- (FalseSucc == CxtPhi->getParent())) {
- // If we're using the false successor, invert the predicate.
- if (FalseSucc == CxtPhi->getParent())
- Pred = CmpInst::getInversePredicate(Pred);
- // Get the knownbits implied by the incoming phi condition.
- auto CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC);
- KnownBits KnownUnion = Known2.unionWith(CR.toKnownBits());
- // We can have conflicts here if we are analyzing deadcode (its
- // impossible for us reach this BB based the icmp).
- if (KnownUnion.hasConflict()) {
- // No reason to continue analyzing in a known dead region, so
- // just resetAll and break. This will cause us to also exit the
- // outer loop.
- Known.resetAll();
- break;
- }
- Known2 = KnownUnion;
+ // Change the context instruction to the "edge" that flows into the
+ // phi. This is important because that is where the value is actually
+ // "evaluated" even though it is used later somewhere else. (see also
+ // D69571).
+ SimplifyQuery RecQ = Q.getWithoutCondContext().getWithInstruction(CxtI);
+
+ Known2 = KnownBits(BitWidth);
+
+ // Recurse, but cap the recursion to one level, because we don't
+ // want to waste time spinning around in loops.
+ // TODO: See if we can base recursion limiter on number of incoming phi
+ // edges so we don't overly clamp analysis.
+ vthelper::computeKnownBits(IncValue, DemandedElts, Known2, RecQ,
+ MaxAnalysisRecursionDepth - 1);
+
+ // See if we can further use a conditional branch into the phi
+ // to help us determine the range of the value.
+ if (!Known2.isConstant()) {
+ CmpPredicate Pred;
+ const APInt *RHSC;
+ BasicBlock *TrueSucc, *FalseSucc;
+ // TODO: Use RHS Value and compute range from its known bits.
+ if (match(RecQ.CxtI,
+ m_Br(m_c_ICmp(Pred, m_Specific(IncValue), m_APInt(RHSC)),
+ m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) {
+ // Check for cases of duplicate successors.
+ if ((TrueSucc == CxtPhi->getParent()) !=
+ (FalseSucc == CxtPhi->getParent())) {
+ // If we're using the false successor, invert the predicate.
+ if (FalseSucc == CxtPhi->getParent())
+ Pred = CmpInst::getInversePredicate(Pred);
+ // Get the knownbits implied by the incoming phi condition.
+ auto CR = ConstantRange::makeExactICmpRegion(Pred, *RHSC);
+ KnownBits KnownUnion = Known2.unionWith(CR.toKnownBits());
+ // We can have conflicts here if we are analyzing dead code (it's
+ // impossible for us to reach this BB based on the icmp).
+ if (KnownUnion.hasConflict()) {
+ // No reason to continue analyzing in a known dead region, so
+ // just resetAll and break. This will cause us to also exit the
+ // outer loop.
+ Known.resetAll();
+ break;
}
+ Known2 = KnownUnion;
}
}
-
- Known = Known.intersectWith(Known2);
- // If all bits have been ruled out, there's no need to check
- // more operands.
- if (Known.isUnknown())
- break;
}
+
+ // Update KnownBits of IncValue, since we're calling VT's
+ // computeKnownBits.
+ emplace_or_assign(IncValue, Known2);
+
+ Known = Known.intersectWith(Known2);
+ // If all bits have been ruled out, there's no need to check
+ // more operands.
+ if (Known.isUnknown())
+ break;
}
- break;
- }
+ } break;
case Instruction::Call:
case Instruction::Invoke: {
// If range metadata is attached to this call, set known bits from that,
@@ -1084,7 +1048,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
if (const Value *RV = CB->getReturnedArgOperand()) {
if (RV->getType() == I->getType()) {
- computeKnownBits(RV, Known2, Q, Depth + 1);
+ computeKnownBits(RV, Known2, Q);
Known = Known.unionWith(Known2);
// If the function doesn't return properly for all input values
// (e.g. unreachable exits) then there might be conflicts between the
@@ -1099,21 +1063,21 @@ static void computeKnownBitsFromOperator(const Operator *I,
default:
break;
case Intrinsic::abs: {
- computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q);
bool IntMinIsPoison = match(II->getArgOperand(1), m_One());
Known = Known.unionWith(Known2.abs(IntMinIsPoison));
break;
}
case Intrinsic::bitreverse:
- computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q);
Known = Known.unionWith(Known2.reverseBits());
break;
case Intrinsic::bswap:
- computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q);
Known = Known.unionWith(Known2.byteSwap());
break;
case Intrinsic::ctlz: {
- computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q);
// If we have a known 1, its position is our upper bound.
unsigned PossibleLZ = Known2.countMaxLeadingZeros();
// If this call is poison for 0 input, the result will be less than 2^n.
@@ -1124,7 +1088,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Intrinsic::cttz: {
- computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q);
// If we have a known 1, its position is our upper bound.
unsigned PossibleTZ = Known2.countMaxTrailingZeros();
// If this call is poison for 0 input, the result will be less than 2^n.
@@ -1135,7 +1099,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Intrinsic::ctpop: {
- computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q);
// We can bound the space the count needs. Also, bits known to be zero
// can't contribute to the population.
unsigned BitsPossiblySet = Known2.countMaxPopulation();
@@ -1157,8 +1121,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
ShiftAmt = BitWidth - ShiftAmt;
KnownBits Known3(BitWidth);
- computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known3, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known2, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known3, Q);
Known2 <<= ShiftAmt;
Known3 >>= BitWidth - ShiftAmt;
@@ -1166,34 +1130,34 @@ static void computeKnownBitsFromOperator(const Operator *I,
break;
}
case Intrinsic::clmul:
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known = KnownBits::clmul(Known, Known2);
break;
case Intrinsic::uadd_sat:
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known = KnownBits::uadd_sat(Known, Known2);
break;
case Intrinsic::usub_sat:
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known = KnownBits::usub_sat(Known, Known2);
break;
case Intrinsic::sadd_sat:
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known = KnownBits::sadd_sat(Known, Known2);
break;
case Intrinsic::ssub_sat:
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known = KnownBits::ssub_sat(Known, Known2);
break;
// Vec reverse preserves bits from input vec.
case Intrinsic::vector_reverse:
- computeKnownBits(I->getOperand(0), DemandedElts.reverseBits(), Known, Q,
- Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts.reverseBits(), Known,
+ Q);
break;
// for min/max/and/or reduce, any bit common to each element in the
// input vec is set in the output.
@@ -1203,10 +1167,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
case Intrinsic::vector_reduce_umin:
case Intrinsic::vector_reduce_smax:
case Intrinsic::vector_reduce_smin:
- computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), Known, Q);
break;
case Intrinsic::vector_reduce_xor: {
- computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), Known, Q);
// The zeros common to all vecs are zero in the output.
// If the number of elements is odd, then the common ones remain. If the
// number of elements is even, then the common ones becomes zeros.
@@ -1224,38 +1188,38 @@ static void computeKnownBitsFromOperator(const Operator *I,
auto *VecTy = dyn_cast<FixedVectorType>(I->getOperand(0)->getType());
if (!VecTy)
break;
- computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), Known, Q);
Known = Known.reduceAdd(VecTy->getNumElements());
break;
}
case Intrinsic::umin:
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known = KnownBits::umin(Known, Known2);
break;
case Intrinsic::umax:
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known = KnownBits::umax(Known, Known2);
break;
case Intrinsic::smin:
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known = KnownBits::smin(Known, Known2);
unionWithMinMaxIntrinsicClamp(II, Known);
break;
case Intrinsic::smax:
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known = KnownBits::smax(Known, Known2);
unionWithMinMaxIntrinsicClamp(II, Known);
break;
case Intrinsic::ptrmask: {
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
const Value *Mask = I->getOperand(1);
Known2 = KnownBits(Mask->getType()->getScalarSizeInBits());
- computeKnownBits(Mask, DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(Mask, DemandedElts, Known2, Q);
// TODO: 1-extend would be more precise.
Known &= Known2.anyextOrTrunc(BitWidth);
break;
@@ -1263,15 +1227,15 @@ static void computeKnownBitsFromOperator(const Operator *I,
case Intrinsic::x86_sse2_pmulh_w:
case Intrinsic::x86_avx2_pmulh_w:
case Intrinsic::x86_avx512_pmulh_w_512:
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known = KnownBits::mulhs(Known, Known2);
break;
case Intrinsic::x86_sse2_pmulhu_w:
case Intrinsic::x86_avx2_pmulhu_w:
case Intrinsic::x86_avx512_pmulhu_w_512:
- computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth + 1);
- computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q, Depth + 1);
+ computeKnownBits(I->getOperand(0), DemandedElts, Known, Q);
+ computeKnownBits(I->getOperand(1), DemandedElts, Known2, Q);
Known = KnownBits::mulhu(Known, Known2);
break;
case Intrinsic::x86_sse42_crc32_64_64:
@@ -1282,7 +1246,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
case Intrinsic::x86_avx2_phadd_d:
case Intrinsic::x86_avx2_phadd_w: {
Known = computeKnownBitsForHorizontalOperation(
- I, DemandedElts, Q, Depth,
+ I, DemandedElts, Q,
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
return KnownBits::add(KnownLHS, KnownRHS);
});
@@ -1290,8 +1254,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
}
case Intrinsic::x86_ssse3_phadd_sw_128:
case Intrinsic::x86_avx2_phadd_sw: {
- Known = computeKnownBitsForHorizontalOperation(
- I, DemandedElts, Q, Depth, KnownBits::sadd_sat);
+ Known = computeKnownBitsForHorizontalOperation(I, DemandedElts, Q,
+ KnownBits::sadd_sat);
break;
}
case Intrinsic::x86_ssse3_phsub_d_128:
@@ -1299,7 +1263,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
case Intrinsic::x86_avx2_phsub_d:
case Intrinsic::x86_avx2_phsub_w: {
Known = computeKnownBitsForHorizontalOperation(
- I, DemandedElts, Q, Depth,
+ I, DemandedElts, Q,
[](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
return KnownBits::sub(KnownLHS, KnownRHS);
});
@@ -1307,8 +1271,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
}
case Intrinsic::x86_ssse3_phsub_sw_128:
case Intrinsic::x86_avx2_phsub_sw: {
- Known = computeKnownBitsForHorizontalOperation(
- I, DemandedElts, Q, Depth, KnownBits::ssub_sat);
+ Known = computeKnownBitsForHorizontalOperation(I, DemandedElts, Q,
+ KnownBits::ssub_sat);
break;
}
case Intrinsic::riscv_vsetvli:
@@ -1346,7 +1310,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
}
case Instruction::ShuffleVector: {
if (auto *Splat = getSplatValue(I)) {
- computeKnownBits(Splat, Known, Q, Depth + 1);
+ computeKnownBits(Splat, Known, Q);
break;
}
@@ -1359,21 +1323,22 @@ static void computeKnownBitsFromOperator(const Operator *I,
// For undef elements, we don't know anything about the common state of
// the shuffle result.
APInt DemandedLHS, DemandedRHS;
- if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS)) {
+ if (!vthelper::getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS,
+ DemandedRHS)) {
Known.resetAll();
return;
}
Known.setAllConflict();
if (!!DemandedLHS) {
const Value *LHS = Shuf->getOperand(0);
- computeKnownBits(LHS, DemandedLHS, Known, Q, Depth + 1);
+ computeKnownBits(LHS, DemandedLHS, Known, Q);
// If we don't know any bits, early out.
if (Known.isUnknown())
break;
}
if (!!DemandedRHS) {
const Value *RHS = Shuf->getOperand(1);
- computeKnownBits(RHS, DemandedRHS, Known2, Q, Depth + 1);
+ computeKnownBits(RHS, DemandedRHS, Known2, Q);
Known = Known.intersectWith(Known2);
}
break;
@@ -1397,14 +1362,14 @@ static void computeKnownBitsFromOperator(const Operator *I,
Known.setAllConflict();
if (NeedsElt) {
- computeKnownBits(Elt, Known, Q, Depth + 1);
+ computeKnownBits(Elt, Known, Q);
// If we don't know any bits, early out.
if (Known.isUnknown())
break;
}
if (!DemandedVecElts.isZero()) {
- computeKnownBits(Vec, DemandedVecElts, Known2, Q, Depth + 1);
+ computeKnownBits(Vec, DemandedVecElts, Known2, Q);
Known = Known.intersectWith(Known2);
}
break;
@@ -1424,7 +1389,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
APInt DemandedVecElts = APInt::getAllOnes(NumElts);
if (CIdx && CIdx->getValue().ult(NumElts))
DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
- computeKnownBits(Vec, DemandedVecElts, Known, Q, Depth + 1);
+ computeKnownBits(Vec, DemandedVecElts, Known, Q);
break;
}
case Instruction::ExtractValue:
@@ -1440,34 +1405,36 @@ static void computeKnownBitsFromOperator(const Operator *I,
case Intrinsic::sadd_with_overflow:
computeKnownBitsAddSub(
true, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false,
- /* NUW=*/false, DemandedElts, Known, Known2, Q, Depth);
+ /* NUW=*/false, DemandedElts, Known, Known2, Q);
break;
case Intrinsic::usub_with_overflow:
case Intrinsic::ssub_with_overflow:
computeKnownBitsAddSub(
false, II->getArgOperand(0), II->getArgOperand(1), /*NSW=*/false,
- /* NUW=*/false, DemandedElts, Known, Known2, Q, Depth);
+ /* NUW=*/false, DemandedElts, Known, Known2, Q);
break;
case Intrinsic::umul_with_overflow:
case Intrinsic::smul_with_overflow:
computeKnownBitsMul(II->getArgOperand(0), II->getArgOperand(1), false,
- false, DemandedElts, Known, Known2, Q, Depth);
+ false, DemandedElts, Known, Known2, Q);
break;
}
}
}
break;
case Instruction::Freeze:
- if (isGuaranteedNotToBePoison(I->getOperand(0), Q.AC, Q.CxtI, Q.DT,
- Depth + 1))
- computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
+ if (isGuaranteedNotToBePoison(I->getOperand(0), Q.AC, Q.CxtI, Q.DT))
+ computeKnownBits(I->getOperand(0), Known, Q);
break;
}
}
-static void computeKnownBits(const Value *V, const APInt &DemandedElts,
- KnownBits &Known, const SimplifyQuery &Q,
- unsigned Depth) {
+void KnownBitsCache::computeKnownBits(const Value *V, const APInt &DemandedElts,
+ KnownBits &Known,
+ const SimplifyQuery &Q) {
+ if (!Known.isAllConflict())
+ return;
+
if (!DemandedElts) {
// No demanded elts, better to assume we don't know anything.
Known.resetAll();
@@ -1475,7 +1442,6 @@ static void computeKnownBits(const Value *V, const APInt &DemandedElts,
}
assert(V && "No Value?");
- assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
#ifndef NDEBUG
Type *Ty = V->getType();
@@ -1573,20 +1539,16 @@ static void computeKnownBits(const Value *V, const APInt &DemandedElts,
if (std::optional<ConstantRange> Range = A->getRange())
Known = Range->toKnownBits();
- // All recursive calls that increase depth must come after this.
- if (Depth == MaxAnalysisRecursionDepth)
- return;
-
// A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has
// the bits of its aliasee.
if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) {
if (!GA->isInterposable())
- computeKnownBits(GA->getAliasee(), Known, Q, Depth + 1);
+ computeKnownBits(GA->getAliasee(), Known, Q);
return;
}
if (const Operator *I = dyn_cast<Operator>(V))
- computeKnownBitsFromOperator(I, DemandedElts, Known, Q, Depth);
+ computeKnownBitsFromOperator(I, DemandedElts, Known, Q);
else if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
if (std::optional<ConstantRange> CR = GV->getAbsoluteSymbolRange())
Known = CR->toKnownBits();
@@ -1602,15 +1564,53 @@ static void computeKnownBits(const Value *V, const APInt &DemandedElts,
// Therefore, we run them after computeKnownBitsFromOperator.
// Check whether we can determine known bits from context such as assumes.
- computeKnownBitsFromContext(V, Known, Q, Depth);
+ computeKnownBitsFromContext(V, Known, Q);
+
+ // Make sure we recursively compute KnownBits of all operands, for cases not
+ // already handled.
+ if (auto *I = dyn_cast<Instruction>(V))
+ for (Value *Op : I->operands())
+ if (contains(Op))
+ computeKnownBits(Op, Q);
+}
+
+KnownBits KnownBitsCache::computeKnownBits(const Value *V,
+ const SimplifyQuery &Q) {
+ assert(contains(V) && "KnownBits information for Value not tracked");
+ KnownBits &Known = getKnownBits(V);
+ computeKnownBits(V, Known, Q);
+ return Known;
+}
+
+KnownBits KnownBitsCache::computeKnownBits(const Value *V,
+ const APInt &DemandedElts,
+ const SimplifyQuery &Q) {
+ assert(contains(V) && "KnownBits information for Value not tracked");
+ KnownBits &Known = getKnownBits(V);
+ computeKnownBits(V, DemandedElts, Known, Q);
+ return Known;
+}
+
+void KnownBitsCache::computeKnownBits(const Value *V, KnownBits &Known,
+ const SimplifyQuery &Q) {
+ // Since the number of lanes in a scalable vector is unknown at compile time,
+ // we track one bit which is implicitly broadcast to all lanes. This means
+ // that all lanes in a scalable vector are considered demanded.
+ auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+ APInt DemandedElts =
+ FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
+ computeKnownBits(V, DemandedElts, Known, Q);
}
void KnownBitsCache::compute(ArrayRef<const Value *> Leaves) {
- for (const Value *V : Leaves)
- computeKnownBits(V, getKnownBits(V), DL, 0);
+ for (const Value *V : Leaves) {
+ assert(contains(V) && "KnownBits information for Value not tracked");
+ computeKnownBits(V, getKnownBits(V), DL);
+ }
}
KnownBits KnownBitsCache::getOrCompute(const Value *V) {
+ assert(contains(V) && "KnownBits information for Value not tracked");
KnownBits &Known = getKnownBits(V);
if (Known.isAllConflict())
compute(V);
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 2109cdfbedc42..2ea202eab1fe0 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -135,10 +135,6 @@ bool vthelper::getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
DemandedElts, DemandedLHS, DemandedRHS);
}
-static void computeKnownBits(const Value *V, const APInt &DemandedElts,
- KnownBits &Known, const SimplifyQuery &Q,
- unsigned Depth);
-
void llvm::computeKnownBits(const Value *V, KnownBits &Known,
const SimplifyQuery &Q, unsigned Depth) {
// Since the number of lanes in a scalable vector is unknown at compile time,
@@ -147,7 +143,7 @@ void llvm::computeKnownBits(const Value *V, KnownBits &Known,
auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
APInt DemandedElts =
FVTy ? APInt::getAllOnes(FVTy->getNumElements()) : APInt(1, 1);
- ::computeKnownBits(V, DemandedElts, Known, Q, Depth);
+ vthelper::computeKnownBits(V, DemandedElts, Known, Q, Depth);
}
void llvm::computeKnownBits(const Value *V, KnownBits &Known,
@@ -2436,9 +2432,9 @@ KnownBits llvm::computeKnownBits(const Value *V, const SimplifyQuery &Q,
/// where V is a vector, known zero, and known one values are the
/// same width as the vector element, and the bit is set only if it is true
/// for all of the demanded elements in the vector specified by DemandedElts.
-void computeKnownBits(const Value *V, const APInt &DemandedElts,
- KnownBits &Known, const SimplifyQuery &Q,
- unsigned Depth) {
+void vthelper::computeKnownBits(const Value *V, const APInt &DemandedElts,
+ KnownBits &Known, const SimplifyQuery &Q,
+ unsigned Depth) {
if (!DemandedElts) {
// No demanded elts, better to assume we don't know anything.
Known.resetAll();
diff --git a/llvm/unittests/Analysis/KnownBitsAnalysisTest.cpp b/llvm/unittests/Analysis/KnownBitsAnalysisTest.cpp
index 09f4fc66b11a5..bc9c2f5e0bed7 100644
--- a/llvm/unittests/Analysis/KnownBitsAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/KnownBitsAnalysisTest.cpp
@@ -371,7 +371,12 @@ define void @test(i32 %n) {
}
/// KnownBitsCache tests follow.
-TEST(KnownBitsCache, KnownBitsComputationParity) {
+struct KnownBitsCacheForTest : public KnownBitsCache {
+ KnownBitsCacheForTest(Function &F) : KnownBitsCache(F) {}
+ KnownBits at(const Value *V) const { return KnownBitsCache::at(V); }
+};
+
+TEST(KnownBitsCache, Initialization) {
LLVMContext Ctx;
std::unique_ptr<Module> M = parseIR(Ctx, R"(
define void @test(i32 %int_arg, float %float_arg, ptr %ptr_arg, <2 x i32> %vec_int_arg, <2 x ptr> %vec_ptr_arg) {
@@ -401,7 +406,7 @@ define void @test(i32 %int_arg, float %float_arg, ptr %ptr_arg, <2 x i32> %vec_i
})");
Function *F = M->getFunction("test");
const DataLayout &DL = F->getDataLayout();
- KnownBitsCache Lat(*F);
+ KnownBitsCacheForTest Lat(*F);
auto *ArgIt = F->arg_begin();
Argument *IntArg = &*ArgIt++;
ArgIt++;
@@ -421,21 +426,21 @@ define void @test(i32 %int_arg, float %float_arg, ptr %ptr_arg, <2 x i32> %vec_i
Instruction *VecPtrConv = findInstructionByName(F, "vec_ptr_conv");
Instruction *FinalVec = findInstructionByName(F, "final_vec");
- EXPECT_EQ(Lat.getOrCompute(IntArg), computeKnownBits(IntArg, DL));
- EXPECT_EQ(Lat.getOrCompute(PtrArg), computeKnownBits(PtrArg, DL));
- EXPECT_EQ(Lat.getOrCompute(VecIntArg), computeKnownBits(VecIntArg, DL));
- EXPECT_EQ(Lat.getOrCompute(VecPtrArg), computeKnownBits(VecPtrArg, DL));
- EXPECT_EQ(Lat.getOrCompute(IntVal), computeKnownBits(IntVal, DL));
- EXPECT_EQ(Lat.getOrCompute(VecVal), computeKnownBits(VecVal, DL));
- EXPECT_EQ(Lat.getOrCompute(IntVal2), computeKnownBits(IntVal2, DL));
- EXPECT_EQ(Lat.getOrCompute(PtrVal), computeKnownBits(PtrVal, DL));
- EXPECT_EQ(Lat.getOrCompute(VecVal2), computeKnownBits(VecVal2, DL));
- EXPECT_EQ(Lat.getOrCompute(PhiInt), computeKnownBits(PhiInt, DL));
- EXPECT_EQ(Lat.getOrCompute(PhiPtr), computeKnownBits(PhiPtr, DL));
- EXPECT_EQ(Lat.getOrCompute(PhiVec), computeKnownBits(PhiVec, DL));
- EXPECT_EQ(Lat.getOrCompute(FinalInt), computeKnownBits(FinalInt, DL));
- EXPECT_EQ(Lat.getOrCompute(FPConv), computeKnownBits(FPConv, DL));
- EXPECT_EQ(Lat.getOrCompute(VecPtrConv), computeKnownBits(VecPtrConv, DL));
- EXPECT_EQ(Lat.getOrCompute(FinalVec), computeKnownBits(FinalVec, DL));
+ EXPECT_EQ(Lat.at(IntArg), computeKnownBits(IntArg, DL));
+ EXPECT_EQ(Lat.at(PtrArg), computeKnownBits(PtrArg, DL));
+ EXPECT_EQ(Lat.at(VecIntArg), computeKnownBits(VecIntArg, DL));
+ EXPECT_EQ(Lat.at(VecPtrArg), computeKnownBits(VecPtrArg, DL));
+ EXPECT_EQ(Lat.at(IntVal), computeKnownBits(IntVal, DL));
+ EXPECT_EQ(Lat.at(VecVal), computeKnownBits(VecVal, DL));
+ EXPECT_EQ(Lat.at(IntVal2), computeKnownBits(IntVal2, DL));
+ EXPECT_EQ(Lat.at(PtrVal), computeKnownBits(PtrVal, DL));
+ EXPECT_EQ(Lat.at(VecVal2), computeKnownBits(VecVal2, DL));
+ EXPECT_EQ(Lat.at(PhiInt), computeKnownBits(PhiInt, DL));
+ EXPECT_EQ(Lat.at(PhiPtr), computeKnownBits(PhiPtr, DL));
+ EXPECT_EQ(Lat.at(PhiVec), computeKnownBits(PhiVec, DL));
+ EXPECT_EQ(Lat.at(FinalInt), computeKnownBits(FinalInt, DL));
+ EXPECT_EQ(Lat.at(FPConv), computeKnownBits(FPConv, DL));
+ EXPECT_EQ(Lat.at(VecPtrConv), computeKnownBits(VecPtrConv, DL));
+ EXPECT_EQ(Lat.at(FinalVec), computeKnownBits(FinalVec, DL));
}
} // namespace
>From 1b11d8026f8bd06f4f8cd19909d59d3b4c9b8532 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <artagnon at tenstorrent.com>
Date: Mon, 9 Mar 2026 15:18:41 +0000
Subject: [PATCH 4/4] [KnownBitsAnalysis] Attempt to fix win, per Simon's
diagnosis
---
llvm/lib/Analysis/KnownBitsAnalysis.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Analysis/KnownBitsAnalysis.cpp b/llvm/lib/Analysis/KnownBitsAnalysis.cpp
index df85178aa7333..0abc08a8f470f 100644
--- a/llvm/lib/Analysis/KnownBitsAnalysis.cpp
+++ b/llvm/lib/Analysis/KnownBitsAnalysis.cpp
@@ -63,7 +63,7 @@ static auto filter_range( // NOLINT
RangeT R, std::function<bool(const Value *)> ExcludeFn = [](const Value *) {
return false;
}) {
- return make_filter_range(R, [&](const Value *V) {
+ return make_filter_range(R, [=](const Value *V) {
return !ExcludeFn(V) && V->getType()->getScalarType()->isIntOrPtrTy();
});
}
More information about the llvm-commits
mailing list