[llvm] be57381 - [InstCombine] Create a class to lazily track computed known bits (#66611)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 17 09:10:23 PDT 2023


Author: Dhruv Chawla
Date: 2023-10-17T21:40:18+05:30
New Revision: be57381a4a08b0b6a89d5b5fdec0880b202e99f4

URL: https://github.com/llvm/llvm-project/commit/be57381a4a08b0b6a89d5b5fdec0880b202e99f4
DIFF: https://github.com/llvm/llvm-project/commit/be57381a4a08b0b6a89d5b5fdec0880b202e99f4.diff

LOG: [InstCombine] Create a class to lazily track computed known bits (#66611)

This patch adds a new class "WithCache" which stores a pointer to
any type passable to computeKnownBits along with KnownBits
information which is computed on-demand when getKnownBits()
is called. This allows reusing the known bits information when it is
passed as an argument to multiple functions.

It also changes a few functions to accept WithCache(s) so that
known bits information computed in some callees can be propagated to
others from the top level visitAddSub caller.

This gives a speedup of 0.14%:
https://llvm-compile-time-tracker.com/compare.php?from=499d41cef2e7bbb65804f6a815b9fa8b27efce0f&to=fbea87f1f1e6d5552e2bc309f8e201a3af6d28ec&stat=instructions:u

Added: 
    llvm/include/llvm/Analysis/WithCache.h

Modified: 
    llvm/include/llvm/Analysis/ValueTracking.h
    llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 25272e0581c9385..0e02d0d5b4865da 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -17,6 +17,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/Analysis/SimplifyQuery.h"
+#include "llvm/Analysis/WithCache.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/FMF.h"
@@ -90,6 +91,12 @@ KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
                            const DominatorTree *DT = nullptr,
                            bool UseInstrInfo = true);
 
+KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
+                           unsigned Depth, const SimplifyQuery &Q);
+
+KnownBits computeKnownBits(const Value *V, unsigned Depth,
+                           const SimplifyQuery &Q);
+
 /// Compute known bits from the range metadata.
 /// \p KnownZero the set of bits that are known to be zero
 /// \p KnownOne the set of bits that are known to be one
@@ -107,7 +114,8 @@ KnownBits analyzeKnownBitsFromAndXorOr(
     bool UseInstrInfo = true);
 
 /// Return true if LHS and RHS have no common bits set.
-bool haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
+bool haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
+                         const WithCache<const Value *> &RHSCache,
                          const SimplifyQuery &SQ);
 
 /// Return true if the given value is known to have exactly one bit set when
@@ -847,9 +855,12 @@ OverflowResult computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS,
                                              const SimplifyQuery &SQ);
 OverflowResult computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
                                            const SimplifyQuery &SQ);
-OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, const Value *RHS,
-                                             const SimplifyQuery &SQ);
-OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
+OverflowResult
+computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
+                              const WithCache<const Value *> &RHS,
+                              const SimplifyQuery &SQ);
+OverflowResult computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
+                                           const WithCache<const Value *> &RHS,
                                            const SimplifyQuery &SQ);
 /// This version also leverages the sign bit of Add if known.
 OverflowResult computeOverflowForSignedAdd(const AddOperator *Add,

diff  --git a/llvm/include/llvm/Analysis/WithCache.h b/llvm/include/llvm/Analysis/WithCache.h
new file mode 100644
index 000000000000000..8065c45738f840b
--- /dev/null
+++ b/llvm/include/llvm/Analysis/WithCache.h
@@ -0,0 +1,71 @@
+//===- llvm/Analysis/WithCache.h - KnownBits cache for pointers -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Store a pointer to any type along with the KnownBits information for it
+// that is computed lazily (if required).
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_WITHCACHE_H
+#define LLVM_ANALYSIS_WITHCACHE_H
+
+#include "llvm/IR/Value.h"
+#include "llvm/Support/KnownBits.h"
+#include <type_traits>
+
+namespace llvm {
+struct SimplifyQuery;
+KnownBits computeKnownBits(const Value *V, unsigned Depth,
+                           const SimplifyQuery &Q);
+
+template <typename Arg> class WithCache {
+  static_assert(std::is_pointer_v<Arg>, "WithCache requires a pointer type!");
+
+  using UnderlyingType = std::remove_pointer_t<Arg>;
+  constexpr static bool IsConst = std::is_const_v<Arg>;
+
+  template <typename T, bool Const>
+  using conditionally_const_t = std::conditional_t<Const, const T, T>;
+
+  using PointerType = conditionally_const_t<UnderlyingType *, IsConst>;
+  using ReferenceType = conditionally_const_t<UnderlyingType &, IsConst>;
+
+  // Store the presence of the KnownBits information in one of the bits of
+  // Pointer.
+  // true  -> present
+  // false -> absent
+  mutable PointerIntPair<PointerType, 1, bool> Pointer;
+  mutable KnownBits Known;
+
+  void calculateKnownBits(const SimplifyQuery &Q) const {
+    Known = computeKnownBits(Pointer.getPointer(), 0, Q);
+    Pointer.setInt(true);
+  }
+
+public:
+  WithCache(PointerType Pointer) : Pointer(Pointer, false) {}
+  WithCache(PointerType Pointer, const KnownBits &Known)
+      : Pointer(Pointer, true), Known(Known) {}
+
+  [[nodiscard]] PointerType getValue() const { return Pointer.getPointer(); }
+
+  [[nodiscard]] const KnownBits &getKnownBits(const SimplifyQuery &Q) const {
+    if (!hasKnownBits())
+      calculateKnownBits(Q);
+    return Known;
+  }
+
+  [[nodiscard]] bool hasKnownBits() const { return Pointer.getInt(); }
+
+  operator PointerType() const { return Pointer.getPointer(); }
+  PointerType operator->() const { return Pointer.getPointer(); }
+  ReferenceType operator*() const { return *Pointer.getPointer(); }
+};
+} // namespace llvm
+
+#endif

diff  --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index dcfcc8f41dd58d0..f8b3874267ded3b 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -510,15 +510,18 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
                                              SQ.getWithInstruction(CxtI));
   }
 
-  OverflowResult computeOverflowForUnsignedAdd(const Value *LHS,
-                                               const Value *RHS,
-                                               const Instruction *CxtI) const {
+  OverflowResult
+  computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
+                                const WithCache<const Value *> &RHS,
+                                const Instruction *CxtI) const {
     return llvm::computeOverflowForUnsignedAdd(LHS, RHS,
                                                SQ.getWithInstruction(CxtI));
   }
 
-  OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS,
-                                             const Instruction *CxtI) const {
+  OverflowResult
+  computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
+                              const WithCache<const Value *> &RHS,
+                              const Instruction *CxtI) const {
     return llvm::computeOverflowForSignedAdd(LHS, RHS,
                                              SQ.getWithInstruction(CxtI));
   }

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 82310444326d6bb..1e0281b3f1bd79e 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -33,6 +33,7 @@
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/Analysis/WithCache.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
@@ -178,17 +179,11 @@ void llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
       SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
 }
 
-static KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
-                                  unsigned Depth, const SimplifyQuery &Q);
-
-static KnownBits computeKnownBits(const Value *V, unsigned Depth,
-                                  const SimplifyQuery &Q);
-
 KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
                                  unsigned Depth, AssumptionCache *AC,
                                  const Instruction *CxtI,
                                  const DominatorTree *DT, bool UseInstrInfo) {
-  return ::computeKnownBits(
+  return computeKnownBits(
       V, Depth, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
 }
 
@@ -196,13 +191,17 @@ KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
                                  const DataLayout &DL, unsigned Depth,
                                  AssumptionCache *AC, const Instruction *CxtI,
                                  const DominatorTree *DT, bool UseInstrInfo) {
-  return ::computeKnownBits(
+  return computeKnownBits(
       V, DemandedElts, Depth,
       SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
 }
 
-bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
+bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
+                               const WithCache<const Value *> &RHSCache,
                                const SimplifyQuery &SQ) {
+  const Value *LHS = LHSCache.getValue();
+  const Value *RHS = RHSCache.getValue();
+
   assert(LHS->getType() == RHS->getType() &&
          "LHS and RHS should have the same type");
   assert(LHS->getType()->isIntOrIntVectorTy() &&
@@ -250,12 +249,9 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
         match(LHS, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))
       return true;
   }
-  IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType());
-  KnownBits LHSKnown(IT->getBitWidth());
-  KnownBits RHSKnown(IT->getBitWidth());
-  ::computeKnownBits(LHS, LHSKnown, 0, SQ);
-  ::computeKnownBits(RHS, RHSKnown, 0, SQ);
-  return KnownBits::haveNoCommonBitsSet(LHSKnown, RHSKnown);
+
+  return KnownBits::haveNoCommonBitsSet(LHSCache.getKnownBits(SQ),
+                                        RHSCache.getKnownBits(SQ));
 }
 
 bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
@@ -1784,19 +1780,19 @@ static void computeKnownBitsFromOperator(const Operator *I,
 
 /// Determine which bits of V are known to be either zero or one and return
 /// them.
-KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
-                           unsigned Depth, const SimplifyQuery &Q) {
+KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
+                                 unsigned Depth, const SimplifyQuery &Q) {
   KnownBits Known(getBitWidth(V->getType(), Q.DL));
-  computeKnownBits(V, DemandedElts, Known, Depth, Q);
+  ::computeKnownBits(V, DemandedElts, Known, Depth, Q);
   return Known;
 }
 
 /// Determine which bits of V are known to be either zero or one and return
 /// them.
-KnownBits computeKnownBits(const Value *V, unsigned Depth,
-                           const SimplifyQuery &Q) {
+KnownBits llvm::computeKnownBits(const Value *V, unsigned Depth,
+                                 const SimplifyQuery &Q) {
   KnownBits Known(getBitWidth(V->getType(), Q.DL));
-  computeKnownBits(V, Known, Depth, Q);
+  ::computeKnownBits(V, Known, Depth, Q);
   return Known;
 }
 
@@ -6256,10 +6252,11 @@ static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
 
 /// Combine constant ranges from computeConstantRange() and computeKnownBits().
 static ConstantRange
-computeConstantRangeIncludingKnownBits(const Value *V, bool ForSigned,
+computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
+                                       bool ForSigned,
                                        const SimplifyQuery &SQ) {
-  KnownBits Known = ::computeKnownBits(V, /*Depth=*/0, SQ);
-  ConstantRange CR1 = ConstantRange::fromKnownBits(Known, ForSigned);
+  ConstantRange CR1 =
+      ConstantRange::fromKnownBits(V.getKnownBits(SQ), ForSigned);
   ConstantRange CR2 = computeConstantRange(V, ForSigned, SQ.IIQ.UseInstrInfo);
   ConstantRange::PreferredRangeType RangeType =
       ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned;
@@ -6269,8 +6266,8 @@ computeConstantRangeIncludingKnownBits(const Value *V, bool ForSigned,
 OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
                                                    const Value *RHS,
                                                    const SimplifyQuery &SQ) {
-  KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
-  KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
+  KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
+  KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
   ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false);
   ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false);
   return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange));
@@ -6307,17 +6304,18 @@ OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
     // product is exactly the minimum negative number.
     // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
     // For simplicity we just check if at least one side is not negative.
-    KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
-    KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
+    KnownBits LHSKnown = computeKnownBits(LHS, /*Depth=*/0, SQ);
+    KnownBits RHSKnown = computeKnownBits(RHS, /*Depth=*/0, SQ);
     if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
       return OverflowResult::NeverOverflows;
   }
   return OverflowResult::MayOverflow;
 }
 
-OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS,
-                                                   const Value *RHS,
-                                                   const SimplifyQuery &SQ) {
+OverflowResult
+llvm::computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
+                                    const WithCache<const Value *> &RHS,
+                                    const SimplifyQuery &SQ) {
   ConstantRange LHSRange =
       computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
   ConstantRange RHSRange =
@@ -6325,10 +6323,10 @@ OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS,
   return mapOverflowResult(LHSRange.unsignedAddMayOverflow(RHSRange));
 }
 
-static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
-                                                  const Value *RHS,
-                                                  const AddOperator *Add,
-                                                  const SimplifyQuery &SQ) {
+static OverflowResult
+computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
+                            const WithCache<const Value *> &RHS,
+                            const AddOperator *Add, const SimplifyQuery &SQ) {
   if (Add && Add->hasNoSignedWrap()) {
     return OverflowResult::NeverOverflows;
   }
@@ -6944,9 +6942,10 @@ OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
                                        Add, SQ);
 }
 
-OverflowResult llvm::computeOverflowForSignedAdd(const Value *LHS,
-                                                 const Value *RHS,
-                                                 const SimplifyQuery &SQ) {
+OverflowResult
+llvm::computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
+                                  const WithCache<const Value *> &RHS,
+                                  const SimplifyQuery &SQ) {
   return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, SQ);
 }
 

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 44f6e37cb3b4417..87181650e758729 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1566,7 +1566,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     return replaceInstUsesWith(I, Constant::getNullValue(I.getType()));
 
   // A+B --> A|B iff A and B have no bits set in common.
-  if (haveNoCommonBitsSet(LHS, RHS, SQ.getWithInstruction(&I)))
+  WithCache<const Value *> LHSCache(LHS), RHSCache(RHS);
+  if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(&I)))
     return BinaryOperator::CreateOr(LHS, RHS);
 
   if (Instruction *Ext = narrowMathIfNoOverflow(I))
@@ -1661,11 +1662,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   // willNotOverflowUnsignedAdd to reduce the number of invocations of
   // computeKnownBits.
   bool Changed = false;
-  if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) {
+  if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHSCache, RHSCache, I)) {
     Changed = true;
     I.setHasNoSignedWrap(true);
   }
-  if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedAdd(LHS, RHS, I)) {
+  if (!I.hasNoUnsignedWrap() &&
+      willNotOverflowUnsignedAdd(LHSCache, RHSCache, I)) {
     Changed = true;
     I.setHasNoUnsignedWrap(true);
   }

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 83c127a0ef012ad..a53d67b2899b700 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -295,13 +295,15 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
 
   Instruction *transformSExtICmp(ICmpInst *Cmp, SExtInst &Sext);
 
-  bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS,
+  bool willNotOverflowSignedAdd(const WithCache<const Value *> &LHS,
+                                const WithCache<const Value *> &RHS,
                                 const Instruction &CxtI) const {
     return computeOverflowForSignedAdd(LHS, RHS, &CxtI) ==
            OverflowResult::NeverOverflows;
   }
 
-  bool willNotOverflowUnsignedAdd(const Value *LHS, const Value *RHS,
+  bool willNotOverflowUnsignedAdd(const WithCache<const Value *> &LHS,
+                                  const WithCache<const Value *> &RHS,
                                   const Instruction &CxtI) const {
     return computeOverflowForUnsignedAdd(LHS, RHS, &CxtI) ==
            OverflowResult::NeverOverflows;


        


More information about the llvm-commits mailing list