[llvm] [ValueTracking] NFC: Allow tracking values through AddrSpaceCasts (PR #70483)

Jeffrey Byrnes via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 27 10:28:58 PDT 2023


https://github.com/jrbyrnes created https://github.com/llvm/llvm-project/pull/70483

Provide the capability to compute known bits through AddrSpaceCasts by querying the target via a new TargetTransformInfo hook (computeKnownBitsAddrSpaceCast). This is mostly useful when trying to determine pointer alignments.

From f0bd69f50ddbccb871871fe1655ece1d1702b3f9 Mon Sep 17 00:00:00 2001
From: Jeffrey Byrnes <Jeffrey.Byrnes at amd.com>
Date: Tue, 24 Oct 2023 13:18:17 -0700
Subject: [PATCH] [ValueTracking] NFC: Allow tracking values through
 AddrSpaceCasts

Change-Id: I7b26e4e90dad483086ff170b5454e1ce69afe7d8
---
 llvm/include/llvm/Analysis/SimplifyQuery.h    |  9 ++-
 .../llvm/Analysis/TargetTransformInfo.h       | 19 ++++++
 .../llvm/Analysis/TargetTransformInfoImpl.h   |  9 +++
 llvm/include/llvm/Analysis/ValueTracking.h    |  1 +
 llvm/include/llvm/Transforms/Utils/Local.h    |  3 +-
 llvm/lib/Analysis/TargetTransformInfo.cpp     |  7 +++
 llvm/lib/Analysis/ValueTracking.cpp           | 58 ++++++++++++++-----
 llvm/lib/Transforms/IPO/GlobalOpt.cpp         | 15 +++--
 llvm/lib/Transforms/Scalar/InferAlignment.cpp | 10 +++-
 llvm/lib/Transforms/Utils/Local.cpp           |  5 +-
 .../Vectorize/LoadStoreVectorizer.cpp         | 10 ++--
 11 files changed, 114 insertions(+), 32 deletions(-)

diff --git a/llvm/include/llvm/Analysis/SimplifyQuery.h b/llvm/include/llvm/Analysis/SimplifyQuery.h
index f9cc3029221d679..0ff8a6059802cb5 100644
--- a/llvm/include/llvm/Analysis/SimplifyQuery.h
+++ b/llvm/include/llvm/Analysis/SimplifyQuery.h
@@ -16,6 +16,7 @@ namespace llvm {
 class AssumptionCache;
 class DominatorTree;
 class TargetLibraryInfo;
+class TargetTransformInfo;
 
 /// InstrInfoQuery provides an interface to query additional information for
 /// instructions like metadata or keywords like nsw, which provides conservative
@@ -62,6 +63,7 @@ struct SimplifyQuery {
   const DominatorTree *DT = nullptr;
   AssumptionCache *AC = nullptr;
   const Instruction *CxtI = nullptr;
+  const TargetTransformInfo *TTI = nullptr;
 
   // Wrapper to query additional information for instructions like metadata or
   // keywords like nsw, which provides conservative results if those cannot
@@ -86,9 +88,10 @@ struct SimplifyQuery {
 
   SimplifyQuery(const DataLayout &DL, const DominatorTree *DT,
                 AssumptionCache *AC = nullptr,
-                const Instruction *CXTI = nullptr, bool UseInstrInfo = true,
-                bool CanUseUndef = true)
-      : DL(DL), DT(DT), AC(AC), CxtI(CXTI), IIQ(UseInstrInfo),
+                const Instruction *CXTI = nullptr,
+                const TargetTransformInfo *TTI = nullptr,
+                bool UseInstrInfo = true, bool CanUseUndef = true)
+      : DL(DL), DT(DT), AC(AC), CxtI(CXTI), TTI(TTI), IIQ(UseInstrInfo),
         CanUseUndef(CanUseUndef) {}
 
   SimplifyQuery getWithInstruction(const Instruction *I) const {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 5234ef8788d9e96..e4cc3a618612e05 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -29,6 +29,7 @@
 #include "llvm/Support/AtomicOrdering.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/InstructionCost.h"
+#include "llvm/Support/KnownBits.h"
 #include <functional>
 #include <optional>
 #include <utility>
@@ -67,6 +68,7 @@ class User;
 class Value;
 class VPIntrinsic;
 struct KnownBits;
+struct SimplifyQuery;
 
 /// Information about a load/store intrinsic defined by the target.
 struct MemIntrinsicInfo {
@@ -1674,6 +1676,11 @@ class TargetTransformInfo {
 
   /// @}
 
+  std::optional<KnownBits>
+  computeKnownBitsAddrSpaceCast(unsigned DestAS, unsigned SrcAS,
+                                const APInt &DemandedElts, KnownBits &Known,
+                                const SimplifyQuery &Q) const;
+
 private:
   /// The abstract base class used to type erase specific TTI
   /// implementations.
@@ -2041,6 +2048,10 @@ class TargetTransformInfo::Concept {
   getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
   virtual bool hasArmWideBranch(bool Thumb) const = 0;
   virtual unsigned getMaxNumArgs() const = 0;
+  virtual std::optional<KnownBits>
+  computeKnownBitsAddrSpaceCast(unsigned DestAS, unsigned SrcAS,
+                                const APInt &DemandedElts, KnownBits &Known,
+                                const SimplifyQuery &Q) const = 0;
 };
 
 template <typename T>
@@ -2757,6 +2768,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   unsigned getMaxNumArgs() const override {
     return Impl.getMaxNumArgs();
   }
+
+  std::optional<KnownBits>
+  computeKnownBitsAddrSpaceCast(unsigned DestAS, unsigned SrcAS,
+                                const APInt &DemandedElts, KnownBits &Known,
+                                const SimplifyQuery &Q) const override {
+    return Impl.computeKnownBitsAddrSpaceCast(DestAS, SrcAS, DemandedElts,
+                                              Known, Q);
+  }
 };
 
 template <typename T>
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index e14915443513990..9add41d3b6aa72e 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -28,6 +28,8 @@
 namespace llvm {
 
 class Function;
+struct KnownBits;
+struct SimplifyQuery;
 
 /// Base class for use as a mix-in that aids implementing
 /// a TargetTransformInfo-compatible class.
@@ -895,6 +897,13 @@ class TargetTransformInfoImplBase {
 
   unsigned getMaxNumArgs() const { return UINT_MAX; }
 
+  std::optional<KnownBits>
+  computeKnownBitsAddrSpaceCast(unsigned DestAS, unsigned SrcAS,
+                                const APInt &DemandedElts, KnownBits &Known,
+                                const SimplifyQuery &Q) const {
+    return std::nullopt;
+  }
+
 protected:
   // Obtain the minimum required size to hold the value (without the sign)
   // In case of a vector it returns the min required size for one element.
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 0e02d0d5b4865da..141b1fa2ea3de1b 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -81,6 +81,7 @@ KnownBits computeKnownBits(const Value *V, const DataLayout &DL,
                            unsigned Depth = 0, AssumptionCache *AC = nullptr,
                            const Instruction *CxtI = nullptr,
                            const DominatorTree *DT = nullptr,
+                           const TargetTransformInfo *TTI = nullptr,
                            bool UseInstrInfo = true);
 
 /// Returns the known bits rather than passing by reference.
diff --git a/llvm/include/llvm/Transforms/Utils/Local.h b/llvm/include/llvm/Transforms/Utils/Local.h
index fa8405a6191eba8..c533af76b84740f 100644
--- a/llvm/include/llvm/Transforms/Utils/Local.h
+++ b/llvm/include/llvm/Transforms/Utils/Local.h
@@ -235,7 +235,8 @@ Align getOrEnforceKnownAlignment(Value *V, MaybeAlign PrefAlign,
                                  const DataLayout &DL,
                                  const Instruction *CxtI = nullptr,
                                  AssumptionCache *AC = nullptr,
-                                 const DominatorTree *DT = nullptr);
+                                 const DominatorTree *DT = nullptr,
+                                 const TargetTransformInfo *TTI = nullptr);
 
 /// Try to infer an alignment for the specified pointer.
 inline Align getKnownAlignment(Value *V, const DataLayout &DL,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index aad14f21d114619..5aa824d7dd6e128 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1248,6 +1248,13 @@ bool TargetTransformInfo::hasActiveVectorLength(unsigned Opcode, Type *DataType,
   return TTIImpl->hasActiveVectorLength(Opcode, DataType, Alignment);
 }
 
+std::optional<KnownBits> TargetTransformInfo::computeKnownBitsAddrSpaceCast(
+    unsigned DestAS, unsigned SrcAS, const APInt &DemandedElts,
+    KnownBits &Known, const SimplifyQuery &Q) const {
+  return TTIImpl->computeKnownBitsAddrSpaceCast(DestAS, SrcAS, DemandedElts,
+                                                Known, Q);
+}
+
 TargetTransformInfo::Concept::~Concept() = default;
 
 TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {}
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index c303d261107eb19..a5b22b74428fc63 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -32,6 +32,7 @@
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/Analysis/WithCache.h"
 #include "llvm/IR/Argument.h"
@@ -166,7 +167,7 @@ void llvm::computeKnownBits(const Value *V, KnownBits &Known,
                             const DominatorTree *DT, bool UseInstrInfo) {
   ::computeKnownBits(
       V, Known, Depth,
-      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
+      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), nullptr, UseInstrInfo));
 }
 
 void llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
@@ -176,15 +177,18 @@ void llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
                             bool UseInstrInfo) {
   ::computeKnownBits(
       V, DemandedElts, Known, Depth,
-      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
+      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), nullptr, UseInstrInfo));
 }
 
 KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
                                  unsigned Depth, AssumptionCache *AC,
                                  const Instruction *CxtI,
-                                 const DominatorTree *DT, bool UseInstrInfo) {
+                                 const DominatorTree *DT,
+                                 const TargetTransformInfo *TTI,
+                                 bool UseInstrInfo) {
   return computeKnownBits(
-      V, Depth, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
+      V, Depth,
+      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), TTI, UseInstrInfo, true));
 }
 
 KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
@@ -193,7 +197,7 @@ KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
                                  const DominatorTree *DT, bool UseInstrInfo) {
   return computeKnownBits(
       V, DemandedElts, Depth,
-      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
+      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), nullptr, UseInstrInfo));
 }
 
 bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
@@ -270,7 +274,7 @@ bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL,
                                   const DominatorTree *DT, bool UseInstrInfo) {
   return ::isKnownToBeAPowerOfTwo(
       V, OrZero, Depth,
-      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
+      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), nullptr, UseInstrInfo));
 }
 
 static bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
@@ -283,14 +287,16 @@ bool llvm::isKnownNonZero(const Value *V, const DataLayout &DL, unsigned Depth,
                           AssumptionCache *AC, const Instruction *CxtI,
                           const DominatorTree *DT, bool UseInstrInfo) {
   return ::isKnownNonZero(
-      V, Depth, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
+      V, Depth,
+      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), nullptr, UseInstrInfo));
 }
 
 bool llvm::isKnownNonNegative(const Value *V, const DataLayout &DL,
                               unsigned Depth, AssumptionCache *AC,
                               const Instruction *CxtI, const DominatorTree *DT,
                               bool UseInstrInfo) {
-  KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT, UseInstrInfo);
+  KnownBits Known =
+      computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo);
   return Known.isNonNegative();
 }
 
@@ -309,7 +315,8 @@ bool llvm::isKnownPositive(const Value *V, const DataLayout &DL, unsigned Depth,
 bool llvm::isKnownNegative(const Value *V, const DataLayout &DL, unsigned Depth,
                            AssumptionCache *AC, const Instruction *CxtI,
                            const DominatorTree *DT, bool UseInstrInfo) {
-  KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT, UseInstrInfo);
+  KnownBits Known =
+      computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo);
   return Known.isNegative();
 }
 
@@ -322,7 +329,7 @@ bool llvm::isKnownNonEqual(const Value *V1, const Value *V2,
                            bool UseInstrInfo) {
   return ::isKnownNonEqual(
       V1, V2, 0,
-      SimplifyQuery(DL, DT, AC, safeCxtI(V2, V1, CxtI), UseInstrInfo));
+      SimplifyQuery(DL, DT, AC, safeCxtI(V2, V1, CxtI), nullptr, UseInstrInfo));
 }
 
 static bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth,
@@ -334,7 +341,7 @@ bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask,
                              const DominatorTree *DT, bool UseInstrInfo) {
   return ::MaskedValueIsZero(
       V, Mask, Depth,
-      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
+      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), nullptr, UseInstrInfo));
 }
 
 static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
@@ -353,7 +360,8 @@ unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL,
                                   const Instruction *CxtI,
                                   const DominatorTree *DT, bool UseInstrInfo) {
   return ::ComputeNumSignBits(
-      V, Depth, SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo));
+      V, Depth,
+      SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), nullptr, UseInstrInfo));
 }
 
 unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
@@ -961,7 +969,7 @@ KnownBits llvm::analyzeKnownBitsFromAndXorOr(
 
   return getKnownBitsFromAndXorOr(
       I, DemandedElts, KnownLHS, KnownRHS, Depth,
-      SimplifyQuery(DL, DT, AC, safeCxtI(I, CxtI), UseInstrInfo));
+      SimplifyQuery(DL, DT, AC, safeCxtI(I, CxtI), nullptr, UseInstrInfo));
 }
 
 ConstantRange llvm::getVScaleRange(const Function *F, unsigned BitWidth) {
@@ -988,7 +996,6 @@ static void computeKnownBitsFromOperator(const Operator *I,
                                          KnownBits &Known, unsigned Depth,
                                          const SimplifyQuery &Q) {
   unsigned BitWidth = Known.getBitWidth();
-
   KnownBits Known2(BitWidth);
   switch (I->getOpcode()) {
   default: break;
@@ -1775,6 +1782,28 @@ static void computeKnownBitsFromOperator(const Operator *I,
                                   Depth + 1))
       computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
     break;
+  case Instruction::AddrSpaceCast: {
+    auto ASC = dyn_cast<AddrSpaceCastOperator>(I);
+    assert(ASC);
+    unsigned SrcAS = ASC->getSrcAddressSpace();
+    unsigned DestAS = ASC->getDestAddressSpace();
+    if (Q.TTI) {
+      // If we cannot look through the AddrSpaceCast and convert back, then the
+      // whole process fails and reverts to using the previous state of
+      // KnownBits.
+      auto Known3 = Q.TTI->computeKnownBitsAddrSpaceCast(
+          DestAS, SrcAS, DemandedElts, Known, Q);
+      if (!Known3)
+        break;
+      computeKnownBits(I->getOperand(0), DemandedElts, *Known3, Depth + 1, Q);
+      Known3 = Q.TTI->computeKnownBitsAddrSpaceCast(SrcAS, DestAS, DemandedElts,
+                                                    *Known3, Q);
+      if (!Known3)
+        break;
+      Known = *Known3;
+    }
+    break;
+  }
   }
 }
 
@@ -1840,6 +1869,7 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
   }
 
   Type *ScalarTy = Ty->getScalarType();
+
   if (ScalarTy->isPointerTy()) {
     assert(BitWidth == Q.DL.getPointerTypeSizeInBits(ScalarTy) &&
            "V and Known should have same BitWidth");
diff --git a/llvm/lib/Transforms/IPO/GlobalOpt.cpp b/llvm/lib/Transforms/IPO/GlobalOpt.cpp
index 56a52d13c20d0cc..8aa97391e641213 100644
--- a/llvm/lib/Transforms/IPO/GlobalOpt.cpp
+++ b/llvm/lib/Transforms/IPO/GlobalOpt.cpp
@@ -506,7 +506,9 @@ static void transferSRADebugInfo(GlobalVariable *GV, GlobalVariable *NGV,
 /// program in a more fine-grained way.  We have determined that this
 /// transformation is safe already.  We return the first global variable we
 /// insert so that the caller can reprocess it.
-static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) {
+static GlobalVariable *
+SRAGlobal(GlobalVariable *GV, const DataLayout &DL,
+          function_ref<TargetTransformInfo &(Function &)> GetTTI) {
   assert(GV->hasLocalLinkage());
 
   // Collect types to split into.
@@ -608,9 +610,14 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) {
       assert(NGV && "Must have replacement global for this offset");
 
       // Update the pointer operand and recalculate alignment.
+
       Align PrefAlign = DL.getPrefTypeAlign(getLoadStoreType(V));
-      Align NewAlign =
-          getOrEnforceKnownAlignment(NGV, PrefAlign, DL, cast<Instruction>(V));
+      TargetTransformInfo *FTTI = nullptr;
+      if (auto I = dyn_cast<Instruction>(V)) {
+        FTTI = &GetTTI(*I->getFunction());
+      }
+      Align NewAlign = getOrEnforceKnownAlignment(
+          NGV, PrefAlign, DL, cast<Instruction>(V), nullptr, nullptr, FTTI);
 
       if (auto *LI = dyn_cast<LoadInst>(V)) {
         LI->setOperand(0, NGV);
@@ -1530,7 +1537,7 @@ processInternalGlobal(GlobalVariable *GV, const GlobalStatus &GS,
   }
   if (!GV->getInitializer()->getType()->isSingleValueType()) {
     const DataLayout &DL = GV->getParent()->getDataLayout();
-    if (SRAGlobal(GV, DL))
+    if (SRAGlobal(GV, DL, GetTTI))
       return true;
   }
   Value *StoredOnceValue = GS.getStoredOnceValue();
diff --git a/llvm/lib/Transforms/Scalar/InferAlignment.cpp b/llvm/lib/Transforms/Scalar/InferAlignment.cpp
index b75b8d486fbbe8b..63144e3dd98573c 100644
--- a/llvm/lib/Transforms/Scalar/InferAlignment.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAlignment.cpp
@@ -13,6 +13,7 @@
 
 #include "llvm/Transforms/Scalar/InferAlignment.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/InitializePasses.h"
@@ -47,7 +48,8 @@ static bool tryToImproveAlign(
   return false;
 }
 
-bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) {
+bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT,
+                    TargetTransformInfo &TTI) {
   const DataLayout &DL = F.getParent()->getDataLayout();
   bool Changed = false;
 
@@ -70,7 +72,8 @@ bool inferAlignment(Function &F, AssumptionCache &AC, DominatorTree &DT) {
     for (Instruction &I : BB) {
       Changed |= tryToImproveAlign(
           DL, &I, [&](Value *PtrOp, Align OldAlign, Align PrefAlign) {
-            KnownBits Known = computeKnownBits(PtrOp, DL, 0, &AC, &I, &DT);
+            KnownBits Known =
+                computeKnownBits(PtrOp, DL, 0, &AC, &I, &DT, &TTI);
             unsigned TrailZ = std::min(Known.countMinTrailingZeros(),
                                        +Value::MaxAlignmentExponent);
             return Align(1ull << std::min(Known.getBitWidth() - 1, TrailZ));
@@ -85,7 +88,8 @@ PreservedAnalyses InferAlignmentPass::run(Function &F,
                                           FunctionAnalysisManager &AM) {
   AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
   DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
-  inferAlignment(F, AC, DT);
+  TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
+  inferAlignment(F, AC, DT, TTI);
   // Changes to alignment shouldn't invalidated analyses.
   return PreservedAnalyses::all();
 }
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index a255db7dafa79d0..31b0a1e78301b7c 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -1443,11 +1443,12 @@ Align llvm::getOrEnforceKnownAlignment(Value *V, MaybeAlign PrefAlign,
                                        const DataLayout &DL,
                                        const Instruction *CxtI,
                                        AssumptionCache *AC,
-                                       const DominatorTree *DT) {
+                                       const DominatorTree *DT,
+                                       const TargetTransformInfo *TTI) {
   assert(V->getType()->isPointerTy() &&
          "getOrEnforceKnownAlignment expects a pointer!");
 
-  KnownBits Known = computeKnownBits(V, DL, 0, AC, CxtI, DT);
+  KnownBits Known = computeKnownBits(V, DL, 0, AC, CxtI, DT, TTI);
   unsigned TrailZ = Known.countMinTrailingZeros();
 
   // Avoid trouble with ridiculously large TrailZ values, such as
diff --git a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
index 73a80702671922b..054120abe0744c5 100644
--- a/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp
@@ -809,7 +809,7 @@ std::vector<Chain> Vectorizer::splitChainByAlignment(Chain &C) {
       if (IsAllocaAccess && Alignment.value() % SizeBytes != 0 &&
           IsAllowedAndFast(PrefAlign)) {
         Align NewAlign = getOrEnforceKnownAlignment(
-            PtrOperand, PrefAlign, DL, C[CBegin].Inst, nullptr, &DT);
+            PtrOperand, PrefAlign, DL, C[CBegin].Inst, nullptr, &DT, &TTI);
         if (NewAlign >= Alignment) {
           LLVM_DEBUG(dbgs()
                      << "LSV: splitByChain upgrading alloca alignment from "
@@ -876,10 +876,10 @@ bool Vectorizer::vectorizeChain(Chain &C) {
   // If this is a load/store of an alloca, we might have upgraded the alloca's
   // alignment earlier.  Get the new alignment.
   if (AS == DL.getAllocaAddrSpace()) {
-    Alignment = std::max(
-        Alignment,
-        getOrEnforceKnownAlignment(getLoadStorePointerOperand(C[0].Inst),
-                                   MaybeAlign(), DL, C[0].Inst, nullptr, &DT));
+    Alignment = std::max(Alignment,
+                         getOrEnforceKnownAlignment(
+                             getLoadStorePointerOperand(C[0].Inst),
+                             MaybeAlign(), DL, C[0].Inst, nullptr, &DT, &TTI));
   }
 
   // All elements of the chain must have the same scalar-type size.



More information about the llvm-commits mailing list