[llvm] 97dbf38 - [SCEVExpander] Add SCEVUseVisitor and use it in SCEVExpander (NFC) (#188863)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 2 08:08:06 PDT 2026


Author: Florian Hahn
Date: 2026-04-02T15:08:01Z
New Revision: 97dbf38c9c495ce9fb958137957cb7794ef3285b

URL: https://github.com/llvm/llvm-project/commit/97dbf38c9c495ce9fb958137957cb7794ef3285b
DIFF: https://github.com/llvm/llvm-project/commit/97dbf38c9c495ce9fb958137957cb7794ef3285b.diff

LOG: [SCEVExpander] Add SCEVUseVisitor and use it in SCEVExpander (NFC) (#188863)

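Add SCEVUseVisitor, a new visitor class whose visit methods receive a
SCEVUse (or, for the per-kind visitXXX methods, a typed
SCEVUseT<const SCEVXXX *>) instead of a const SCEV *. Use it in
SCEVExpander, so that the expander can take use-specific flags into
account in the future.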

PR: https://github.com/llvm/llvm-project/pull/188863

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
    llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
    llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index f20b73f9f358c..08c7e43e708b8 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -68,17 +68,22 @@ LLVM_ABI extern bool VerifySCEV;
 
 class SCEV;
 
-struct SCEVUse : PointerIntPair<const SCEV *, 2> {
-  SCEVUse() : PointerIntPair() { setFromOpaqueValue(nullptr); }
-  SCEVUse(const SCEV *S) : PointerIntPair() {
-    setFromOpaqueValue(reinterpret_cast<void *>(const_cast<SCEV *>(S)));
-  }
-  SCEVUse(const SCEV *S, unsigned Flags) : PointerIntPair(S, Flags) {}
+template <typename SCEVPtrT = const SCEV *>
+struct SCEVUseT : PointerIntPair<SCEVPtrT, 2> {
+  using Base = PointerIntPair<SCEVPtrT, 2>;
+
+  SCEVUseT() : Base() { Base::setFromOpaqueValue(nullptr); }
+  SCEVUseT(SCEVPtrT S) : SCEVUseT(S, 0) {}
+  SCEVUseT(SCEVPtrT S, unsigned Flags) : Base(S, Flags) {}
+  template <typename OtherPtrT, typename = std::enable_if_t<
+                                    std::is_convertible_v<OtherPtrT, SCEVPtrT>>>
+  SCEVUseT(const SCEVUseT<OtherPtrT> &Other)
+      : Base(Other.getPointer(), Other.getInt()) {}
 
-  operator const SCEV *() const { return getPointer(); }
-  const SCEV *operator->() const { return getPointer(); }
+  operator SCEVPtrT() const { return Base::getPointer(); }
+  SCEVPtrT operator->() const { return Base::getPointer(); }
 
-  void *getRawPointer() const { return getOpaqueValue(); }
+  void *getRawPointer() const { return Base::getOpaqueValue(); }
 
   /// Returns true of the SCEVUse is canonical, i.e. no SCEVUse flags set in any
   /// operands.
@@ -87,9 +92,9 @@ struct SCEVUse : PointerIntPair<const SCEV *, 2> {
   /// Return the canonical SCEV for this SCEVUse.
   const SCEV *getCanonical() const;
 
-  unsigned getFlags() const { return getInt(); }
+  unsigned getFlags() const { return Base::getInt(); }
 
-  bool operator==(const SCEVUse &RHS) const {
+  bool operator==(const SCEVUseT &RHS) const {
     return getRawPointer() == RHS.getRawPointer();
   }
 
@@ -103,6 +108,11 @@ struct SCEVUse : PointerIntPair<const SCEV *, 2> {
   void dump() const;
 };
 
+/// Deduction guide for various SCEV subclass pointers.
+template <typename SCEVPtrT> SCEVUseT(SCEVPtrT) -> SCEVUseT<SCEVPtrT>;
+
+using SCEVUse = SCEVUseT<const SCEV *>;
+
 /// Provide PointerLikeTypeTraits for SCEVUse, so it can be used with
 /// SmallPtrSet, among others.
 template <> struct PointerLikeTypeTraits<SCEVUse> {
@@ -145,6 +155,31 @@ template <> struct simplify_type<SCEVUse> {
   }
 };
 
+/// Provide CastInfo for SCEVUseT so that cast<SCEVUseT<const To *>>(use)
+/// returns SCEVUseT<const To *> with flags preserved.
+template <typename ToSCEVPtrT>
+struct CastInfo<SCEVUseT<ToSCEVPtrT>, SCEVUse,
+                std::enable_if_t<!is_simple_type<SCEVUse>::value>> {
+  using To = std::remove_cv_t<std::remove_pointer_t<ToSCEVPtrT>>;
+  using CastReturnType = SCEVUseT<ToSCEVPtrT>;
+
+  static bool isPossible(const SCEVUse &U) { return isa<To>(U.getPointer()); }
+  static CastReturnType doCast(const SCEVUse &U) {
+    return {cast<To>(U.getPointer()), U.getFlags()};
+  }
+  static CastReturnType castFailed() { return CastReturnType(nullptr); }
+  static CastReturnType doCastIfPossible(const SCEVUse &U) {
+    if (!isPossible(U))
+      return castFailed();
+    return doCast(U);
+  }
+};
+
+template <typename ToSCEVPtrT>
+struct CastInfo<SCEVUseT<ToSCEVPtrT>, const SCEVUse,
+                std::enable_if_t<!is_simple_type<const SCEVUse>::value>>
+    : CastInfo<SCEVUseT<ToSCEVPtrT>, SCEVUse> {};
+
 /// This class represents an analyzed expression in the program.  These are
 /// opaque objects that the client is not allowed to do much with directly.
 ///
@@ -2655,9 +2690,27 @@ template <> struct DenseMapInfo<ScalarEvolution::FoldID> {
   }
 };
 
-inline const SCEV *SCEVUse::getCanonical() const {
-  return getPointer()->getCanonical();
+template <> inline const SCEV *SCEVUseT<const SCEV *>::getCanonical() const {
+  return Base::getPointer()->getCanonical();
+}
+
+template <typename SCEVPtrT>
+void SCEVUseT<SCEVPtrT>::print(raw_ostream &OS) const {
+  Base::getPointer()->print(OS);
+  SCEV::NoWrapFlags Flags = static_cast<SCEV::NoWrapFlags>(Base::getInt());
+  if (Flags & SCEV::FlagNUW)
+    OS << "(u nuw)";
+  if (Flags & SCEV::FlagNSW)
+    OS << "(u nsw)";
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+template <typename SCEVPtrT>
+LLVM_DUMP_METHOD void SCEVUseT<SCEVPtrT>::dump() const {
+  print(dbgs());
+  dbgs() << '\n';
 }
+#endif
 
 } // end namespace llvm
 

diff  --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
index 5199bbede7f84..2fc928d5955d1 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -672,6 +672,71 @@ template <typename SC, typename RetVal = void> struct SCEVVisitor {
   }
 };
 
+/// A visitor class for SCEVUse.
+template <typename SC, typename RetVal = void> struct SCEVUseVisitor {
+  RetVal visit(SCEVUse S) {
+    switch (S->getSCEVType()) {
+    case scConstant:
+      return ((SC *)this)
+          ->visitConstant(cast<SCEVUseT<const SCEVConstant *>>(S));
+    case scVScale:
+      return ((SC *)this)->visitVScale(cast<SCEVUseT<const SCEVVScale *>>(S));
+    case scPtrToAddr:
+      return ((SC *)this)
+          ->visitPtrToAddrExpr(cast<SCEVUseT<const SCEVPtrToAddrExpr *>>(S));
+    case scPtrToInt:
+      return ((SC *)this)
+          ->visitPtrToIntExpr(cast<SCEVUseT<const SCEVPtrToIntExpr *>>(S));
+    case scTruncate:
+      return ((SC *)this)
+          ->visitTruncateExpr(cast<SCEVUseT<const SCEVTruncateExpr *>>(S));
+    case scZeroExtend:
+      return ((SC *)this)
+          ->visitZeroExtendExpr(cast<SCEVUseT<const SCEVZeroExtendExpr *>>(S));
+    case scSignExtend:
+      return ((SC *)this)
+          ->visitSignExtendExpr(cast<SCEVUseT<const SCEVSignExtendExpr *>>(S));
+    case scAddExpr:
+      return ((SC *)this)->visitAddExpr(cast<SCEVUseT<const SCEVAddExpr *>>(S));
+    case scMulExpr:
+      return ((SC *)this)->visitMulExpr(cast<SCEVUseT<const SCEVMulExpr *>>(S));
+    case scUDivExpr:
+      return ((SC *)this)
+          ->visitUDivExpr(cast<SCEVUseT<const SCEVUDivExpr *>>(S));
+    case scAddRecExpr:
+      return ((SC *)this)
+          ->visitAddRecExpr(cast<SCEVUseT<const SCEVAddRecExpr *>>(S));
+    case scSMaxExpr:
+      return ((SC *)this)
+          ->visitSMaxExpr(cast<SCEVUseT<const SCEVSMaxExpr *>>(S));
+    case scUMaxExpr:
+      return ((SC *)this)
+          ->visitUMaxExpr(cast<SCEVUseT<const SCEVUMaxExpr *>>(S));
+    case scSMinExpr:
+      return ((SC *)this)
+          ->visitSMinExpr(cast<SCEVUseT<const SCEVSMinExpr *>>(S));
+    case scUMinExpr:
+      return ((SC *)this)
+          ->visitUMinExpr(cast<SCEVUseT<const SCEVUMinExpr *>>(S));
+    case scSequentialUMinExpr:
+      return ((SC *)this)
+          ->visitSequentialUMinExpr(
+              cast<SCEVUseT<const SCEVSequentialUMinExpr *>>(S));
+    case scUnknown:
+      return ((SC *)this)->visitUnknown(cast<SCEVUseT<const SCEVUnknown *>>(S));
+    case scCouldNotCompute:
+      return ((SC *)this)
+          ->visitCouldNotCompute(
+              cast<SCEVUseT<const SCEVCouldNotCompute *>>(S));
+    }
+    llvm_unreachable("Unknown SCEV kind!");
+  }
+
+  RetVal visitCouldNotCompute(SCEVUseT<const SCEVCouldNotCompute *> S) {
+    llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
+  }
+};
+
 /// Visit all nodes in the expression tree using worklist traversal.
 ///
 /// Visitor implements:

diff  --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index d678e23afc18d..7b045e150c76b 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -23,7 +23,8 @@ template <typename Pattern> bool match(const SCEV *S, const Pattern &P) {
   return P.match(S);
 }
 
-template <typename Pattern> bool match(const SCEVUse U, const Pattern &P) {
+template <typename SCEVPtrT, typename Pattern>
+bool match(const SCEVUseT<SCEVPtrT> U, const Pattern &P) {
   return P.match(U.getPointer());
 }
 
@@ -87,10 +88,10 @@ template <typename Class> struct bind_ty {
   }
 };
 
-template <> struct bind_ty<SCEVUse> {
-  SCEVUse &VR;
+template <typename SCEVPtrT> struct bind_ty<SCEVUseT<SCEVPtrT>> {
+  SCEVUseT<SCEVPtrT> &VR;
 
-  bind_ty(SCEVUse &V) : VR(V) {}
+  bind_ty(SCEVUseT<SCEVPtrT> &V) : VR(V) {}
 
   template <typename ITy> bool match(ITy *V) const {
     VR = V;
@@ -100,7 +101,11 @@ template <> struct bind_ty<SCEVUse> {
 
 /// Match a SCEV, capturing it if we match.
 inline bind_ty<const SCEV> m_SCEV(const SCEV *&V) { return V; }
-inline bind_ty<SCEVUse> m_SCEV(SCEVUse &V) { return V; }
+
+template <typename SCEVPtrT>
+inline bind_ty<SCEVUseT<SCEVPtrT>> m_SCEV(SCEVUseT<SCEVPtrT> &V) {
+  return V;
+}
 inline bind_ty<const SCEVConstant> m_SCEVConstant(const SCEVConstant *&V) {
   return V;
 }

diff  --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
index 2559d7b89b020..42355f5841eab 100644
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -61,7 +61,7 @@ struct PoisonFlags {
 /// Clients should create an instance of this class when rewriting is needed,
 /// and destroy it when finished to allow the release of the associated
 /// memory.
-class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
+class SCEVExpander : public SCEVUseVisitor<SCEVExpander, Value *> {
   friend class SCEVExpanderCleaner;
 
   ScalarEvolution &SE;
@@ -74,7 +74,7 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
   bool PreserveLCSSA;
 
   // InsertedExpressions caches Values for reuse, so must track RAUW.
-  DenseMap<std::pair<const SCEV *, Instruction *>, TrackingVH<Value>>
+  DenseMap<std::pair<SCEVUse, Instruction *>, TrackingVH<Value>>
       InsertedExpressions;
 
   // InsertedValues only flags inserted instructions so needs no RAUW.
@@ -179,7 +179,7 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
   const char *DebugType;
 #endif
 
-  friend struct SCEVVisitor<SCEVExpander, Value *>;
+  friend struct SCEVUseVisitor<SCEVExpander, Value *>;
 
 public:
   /// Construct a SCEVExpander in "canonical" mode.
@@ -313,9 +313,8 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
 
   /// Insert code to directly compute the specified SCEV expression into the
   /// program.  The code is inserted into the specified block.
-  LLVM_ABI Value *expandCodeFor(const SCEV *SH, Type *Ty,
-                                BasicBlock::iterator I);
-  Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I) {
+  LLVM_ABI Value *expandCodeFor(SCEVUse SH, Type *Ty, BasicBlock::iterator I);
+  Value *expandCodeFor(SCEVUse SH, Type *Ty, Instruction *I) {
     return expandCodeFor(SH, Ty, I->getIterator());
   }
 
@@ -323,7 +322,7 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
   /// program.  The code is inserted into the SCEVExpander's current
   /// insertion point. If a type is specified, the result will be expanded to
   /// have that type, with a cast if necessary.
-  LLVM_ABI Value *expandCodeFor(const SCEV *SH, Type *Ty = nullptr);
+  LLVM_ABI Value *expandCodeFor(SCEVUse SH, Type *Ty = nullptr);
 
   /// Generates a code sequence that evaluates this predicate.  The inserted
   /// instructions will be at position \p Loc.  The result will be of type i1
@@ -475,15 +474,15 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
   /// DropPoisonGeneratingInsts is populated with instructions for which
   /// poison-generating flags must be dropped if the value is reused.
   Value *FindValueInExprValueMap(
-      const SCEV *S, const Instruction *InsertPt,
+      SCEVUse S, const Instruction *InsertPt,
       SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts);
 
-  LLVM_ABI Value *expand(const SCEV *S);
-  Value *expand(const SCEV *S, BasicBlock::iterator I) {
+  LLVM_ABI Value *expand(SCEVUse S);
+  Value *expand(SCEVUse S, BasicBlock::iterator I) {
     setInsertPoint(I);
     return expand(S);
   }
-  Value *expand(const SCEV *S, Instruction *I) {
+  Value *expand(SCEVUse S, Instruction *I) {
     setInsertPoint(I);
     return expand(S);
   }
@@ -491,42 +490,45 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
   /// Determine the most "relevant" loop for the given SCEV.
   const Loop *getRelevantLoop(const SCEV *);
 
-  Value *expandMinMaxExpr(const SCEVNAryExpr *S, Intrinsic::ID IntrinID,
-                          Twine Name, bool IsSequential = false);
+  Value *expandMinMaxExpr(SCEVUseT<const SCEVNAryExpr *> S,
+                          Intrinsic::ID IntrinID, Twine Name,
+                          bool IsSequential = false);
 
-  Value *visitConstant(const SCEVConstant *S) { return S->getValue(); }
+  Value *visitConstant(SCEVUseT<const SCEVConstant *> S) {
+    return S->getValue();
+  }
 
-  Value *visitVScale(const SCEVVScale *S);
+  Value *visitVScale(SCEVUseT<const SCEVVScale *> S);
 
-  Value *visitPtrToAddrExpr(const SCEVPtrToAddrExpr *S);
+  Value *visitPtrToAddrExpr(SCEVUseT<const SCEVPtrToAddrExpr *> S);
 
-  Value *visitPtrToIntExpr(const SCEVPtrToIntExpr *S);
+  Value *visitPtrToIntExpr(SCEVUseT<const SCEVPtrToIntExpr *> S);
 
-  Value *visitTruncateExpr(const SCEVTruncateExpr *S);
+  Value *visitTruncateExpr(SCEVUseT<const SCEVTruncateExpr *> S);
 
-  Value *visitZeroExtendExpr(const SCEVZeroExtendExpr *S);
+  Value *visitZeroExtendExpr(SCEVUseT<const SCEVZeroExtendExpr *> S);
 
-  Value *visitSignExtendExpr(const SCEVSignExtendExpr *S);
+  Value *visitSignExtendExpr(SCEVUseT<const SCEVSignExtendExpr *> S);
 
-  Value *visitAddExpr(const SCEVAddExpr *S);
+  Value *visitAddExpr(SCEVUseT<const SCEVAddExpr *> S);
 
-  Value *visitMulExpr(const SCEVMulExpr *S);
+  Value *visitMulExpr(SCEVUseT<const SCEVMulExpr *> S);
 
-  Value *visitUDivExpr(const SCEVUDivExpr *S);
+  Value *visitUDivExpr(SCEVUseT<const SCEVUDivExpr *> S);
 
-  Value *visitAddRecExpr(const SCEVAddRecExpr *S);
+  Value *visitAddRecExpr(SCEVUseT<const SCEVAddRecExpr *> S);
 
-  Value *visitSMaxExpr(const SCEVSMaxExpr *S);
+  Value *visitSMaxExpr(SCEVUseT<const SCEVSMaxExpr *> S);
 
-  Value *visitUMaxExpr(const SCEVUMaxExpr *S);
+  Value *visitUMaxExpr(SCEVUseT<const SCEVUMaxExpr *> S);
 
-  Value *visitSMinExpr(const SCEVSMinExpr *S);
+  Value *visitSMinExpr(SCEVUseT<const SCEVSMinExpr *> S);
 
-  Value *visitUMinExpr(const SCEVUMinExpr *S);
+  Value *visitUMinExpr(SCEVUseT<const SCEVUMinExpr *> S);
 
-  Value *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S);
+  Value *visitSequentialUMinExpr(SCEVUseT<const SCEVSequentialUMinExpr *> S);
 
-  Value *visitUnknown(const SCEVUnknown *S) { return S->getValue(); }
+  Value *visitUnknown(SCEVUseT<const SCEVUnknown *> S) { return S->getValue(); }
 
   LLVM_ABI void rememberInstruction(Value *I);
 
@@ -536,8 +538,8 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
 
   bool isExpandedAddRecExprPHI(PHINode *PN, Instruction *IncV, const Loop *L);
 
-  Value *tryToReuseLCSSAPhi(const SCEVAddRecExpr *S);
-  Value *expandAddRecExprLiterally(const SCEVAddRecExpr *);
+  Value *tryToReuseLCSSAPhi(SCEVUseT<const SCEVAddRecExpr *> S);
+  Value *expandAddRecExprLiterally(SCEVUseT<const SCEVAddRecExpr *> S);
   PHINode *getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
                                      const Loop *L, Type *&TruncTy,
                                      bool &InvertStep);

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 73dc298e66e6f..0e999e41a9e3e 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -336,22 +336,6 @@ void SCEV::computeAndSetCanonical(ScalarEvolution &SE) {
   }
 }
 
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-LLVM_DUMP_METHOD void SCEVUse::dump() const {
-  print(dbgs());
-  dbgs() << '\n';
-}
-#endif
-
-void SCEVUse::print(raw_ostream &OS) const {
-  getPointer()->print(OS);
-  SCEV::NoWrapFlags Flags = static_cast<SCEV::NoWrapFlags>(getInt());
-  if (Flags & SCEV::FlagNUW)
-    OS << "(u nuw)";
-  if (Flags & SCEV::FlagNSW)
-    OS << "(u nsw)";
-}
-
 //===----------------------------------------------------------------------===//
 // Implementation of the SCEV class.
 //

diff  --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index ac60837584763..3509a204a3d4e 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -523,7 +523,7 @@ class LoopCompare {
 
 }
 
-Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
+Value *SCEVExpander::visitAddExpr(SCEVUseT<const SCEVAddExpr *> S) {
   // Recognize the canonical representation of an unsimplifed urem.
   const SCEV *URemLHS = nullptr;
   const SCEV *URemRHS = nullptr;
@@ -595,7 +595,7 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
   return Sum;
 }
 
-Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
+Value *SCEVExpander::visitMulExpr(SCEVUseT<const SCEVMulExpr *> S) {
   Type *Ty = S->getType();
 
   // Collect all the mul operands in a loop, along with their associated loops.
@@ -687,7 +687,7 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
   return Prod;
 }
 
-Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
+Value *SCEVExpander::visitUDivExpr(SCEVUseT<const SCEVUDivExpr *> S) {
   Value *LHS = expand(S->getLHS());
   if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getRHS())) {
     const APInt &RHS = SC->getAPInt();
@@ -1156,7 +1156,8 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
   return PN;
 }
 
-Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
+Value *
+SCEVExpander::expandAddRecExprLiterally(SCEVUseT<const SCEVAddRecExpr *> S) {
   const Loop *L = S->getLoop();
 
   // Determine a normalized form of this expression, which is the expression
@@ -1246,7 +1247,7 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
   return Result;
 }
 
-Value *SCEVExpander::tryToReuseLCSSAPhi(const SCEVAddRecExpr *S) {
+Value *SCEVExpander::tryToReuseLCSSAPhi(SCEVUseT<const SCEVAddRecExpr *> S) {
   Type *STy = S->getType();
   const Loop *L = S->getLoop();
   BasicBlock *EB = L->getExitBlock();
@@ -1308,7 +1309,7 @@ Value *SCEVExpander::tryToReuseLCSSAPhi(const SCEVAddRecExpr *S) {
   return nullptr;
 }
 
-Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
+Value *SCEVExpander::visitAddRecExpr(SCEVUseT<const SCEVAddRecExpr *> S) {
   // In canonical mode we compute the addrec as an expression of a canonical IV
   // using evaluateAtIteration and expand the resulting SCEV expression. This
   // way we avoid introducing new IVs to carry on the computation of the addrec
@@ -1448,7 +1449,7 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
   return expand(T);
 }
 
-Value *SCEVExpander::visitPtrToAddrExpr(const SCEVPtrToAddrExpr *S) {
+Value *SCEVExpander::visitPtrToAddrExpr(SCEVUseT<const SCEVPtrToAddrExpr *> S) {
   Value *V = expand(S->getOperand());
   Type *Ty = S->getType();
 
@@ -1468,29 +1469,31 @@ Value *SCEVExpander::visitPtrToAddrExpr(const SCEVPtrToAddrExpr *S) {
                            GetOptimalInsertionPointForCastOf(V));
 }
 
-Value *SCEVExpander::visitPtrToIntExpr(const SCEVPtrToIntExpr *S) {
+Value *SCEVExpander::visitPtrToIntExpr(SCEVUseT<const SCEVPtrToIntExpr *> S) {
   Value *V = expand(S->getOperand());
   return ReuseOrCreateCast(V, S->getType(), CastInst::PtrToInt,
                            GetOptimalInsertionPointForCastOf(V));
 }
 
-Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) {
+Value *SCEVExpander::visitTruncateExpr(SCEVUseT<const SCEVTruncateExpr *> S) {
   Value *V = expand(S->getOperand());
   return Builder.CreateTrunc(V, S->getType());
 }
 
-Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) {
+Value *
+SCEVExpander::visitZeroExtendExpr(SCEVUseT<const SCEVZeroExtendExpr *> S) {
   Value *V = expand(S->getOperand());
   return Builder.CreateZExt(V, S->getType(), "",
                             SE.isKnownNonNegative(S->getOperand()));
 }
 
-Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) {
+Value *
+SCEVExpander::visitSignExtendExpr(SCEVUseT<const SCEVSignExtendExpr *> S) {
   Value *V = expand(S->getOperand());
   return Builder.CreateSExt(V, S->getType());
 }
 
-Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
+Value *SCEVExpander::expandMinMaxExpr(SCEVUseT<const SCEVNAryExpr *> S,
                                       Intrinsic::ID IntrinID, Twine Name,
                                       bool IsSequential) {
   bool PrevSafeMode = SafeUDivMode;
@@ -1519,38 +1522,39 @@ Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
   return LHS;
 }
 
-Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) {
+Value *SCEVExpander::visitSMaxExpr(SCEVUseT<const SCEVSMaxExpr *> S) {
   return expandMinMaxExpr(S, Intrinsic::smax, "smax");
 }
 
-Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) {
+Value *SCEVExpander::visitUMaxExpr(SCEVUseT<const SCEVUMaxExpr *> S) {
   return expandMinMaxExpr(S, Intrinsic::umax, "umax");
 }
 
-Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) {
+Value *SCEVExpander::visitSMinExpr(SCEVUseT<const SCEVSMinExpr *> S) {
   return expandMinMaxExpr(S, Intrinsic::smin, "smin");
 }
 
-Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) {
+Value *SCEVExpander::visitUMinExpr(SCEVUseT<const SCEVUMinExpr *> S) {
   return expandMinMaxExpr(S, Intrinsic::umin, "umin");
 }
 
-Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) {
-  return expandMinMaxExpr(S, Intrinsic::umin, "umin", /*IsSequential*/true);
+Value *SCEVExpander::visitSequentialUMinExpr(
+    SCEVUseT<const SCEVSequentialUMinExpr *> S) {
+  return expandMinMaxExpr(S, Intrinsic::umin, "umin",
+                          /*IsSequential*/ true);
 }
 
-Value *SCEVExpander::visitVScale(const SCEVVScale *S) {
+Value *SCEVExpander::visitVScale(SCEVUseT<const SCEVVScale *> S) {
   return Builder.CreateVScale(S->getType());
 }
 
-Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty,
+Value *SCEVExpander::expandCodeFor(SCEVUse SH, Type *Ty,
                                    BasicBlock::iterator IP) {
   setInsertPoint(IP);
-  Value *V = expandCodeFor(SH, Ty);
-  return V;
+  return expandCodeFor(SH, Ty);
 }
 
-Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) {
+Value *SCEVExpander::expandCodeFor(SCEVUse SH, Type *Ty) {
   // Expand the code for this SCEV.
   Value *V = expand(SH);
 
@@ -1563,7 +1567,7 @@ Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) {
 }
 
 Value *SCEVExpander::FindValueInExprValueMap(
-    const SCEV *S, const Instruction *InsertPt,
+    SCEVUse S, const Instruction *InsertPt,
     SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
   // If the expansion is not in CanonicalMode, and the SCEV contains any
   // sub scAddRecExpr type SCEV, it is required to expand the SCEV literally.
@@ -1602,7 +1606,7 @@ Value *SCEVExpander::FindValueInExprValueMap(
 // literally, to prevent LSR's transformed SCEV from being reverted. Otherwise,
 // the expansion will try to reuse Value from ExprValueMap, and only when it
 // fails, expand the SCEV literally.
-Value *SCEVExpander::expand(const SCEV *S) {
+Value *SCEVExpander::expand(SCEVUse S) {
   // Compute an insertion point for this SCEV object. Hoist the instructions
   // as far out in the loop nest as possible.
   BasicBlock::iterator InsertPt = Builder.GetInsertPoint();


        


More information about the llvm-commits mailing list