[llvm] [SCEVExpander] Add SCEVUseVisitor and use it in SCEVExpander (NFC) (PR #188863)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 2 02:41:22 PDT 2026
https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/188863
>From a72b9ce029cfcce98b3f80197951b58288a47727 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 24 Mar 2026 21:47:36 +0000
Subject: [PATCH 1/5] [SCEVExpander] Add SCEVUseVisitor and use it in
SCEVExpander (NFC)
Add SCEVUseVisitor, a new visitor class whose visit methods receive a
SCEVUse instead of a const SCEV*. Use it in SCEVExpander so that the
expander can take use-specific flags into account in the future.
---
.../Analysis/ScalarEvolutionExpressions.h | 49 +++++++++++++
.../Utils/ScalarEvolutionExpander.h | 51 +++++++-------
.../Utils/ScalarEvolutionExpander.cpp | 68 +++++++++++--------
3 files changed, 114 insertions(+), 54 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
index 5199bbede7f84..bd03bf1146a46 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -672,6 +672,55 @@ template <typename SC, typename RetVal = void> struct SCEVVisitor {
}
};
+/// A visitor class for SCEVUse.
+template <typename SC, typename RetVal = void> struct SCEVUseVisitor {
+ RetVal visit(SCEVUse S) {
+ switch (S->getSCEVType()) {
+ case scConstant:
+ return ((SC *)this)->visitConstant(S);
+ case scVScale:
+ return ((SC *)this)->visitVScale(S);
+ case scPtrToAddr:
+ return ((SC *)this)->visitPtrToAddrExpr(S);
+ case scPtrToInt:
+ return ((SC *)this)->visitPtrToIntExpr(S);
+ case scTruncate:
+ return ((SC *)this)->visitTruncateExpr(S);
+ case scZeroExtend:
+ return ((SC *)this)->visitZeroExtendExpr(S);
+ case scSignExtend:
+ return ((SC *)this)->visitSignExtendExpr(S);
+ case scAddExpr:
+ return ((SC *)this)->visitAddExpr(S);
+ case scMulExpr:
+ return ((SC *)this)->visitMulExpr(S);
+ case scUDivExpr:
+ return ((SC *)this)->visitUDivExpr(S);
+ case scAddRecExpr:
+ return ((SC *)this)->visitAddRecExpr(S);
+ case scSMaxExpr:
+ return ((SC *)this)->visitSMaxExpr(S);
+ case scUMaxExpr:
+ return ((SC *)this)->visitUMaxExpr(S);
+ case scSMinExpr:
+ return ((SC *)this)->visitSMinExpr(S);
+ case scUMinExpr:
+ return ((SC *)this)->visitUMinExpr(S);
+ case scSequentialUMinExpr:
+ return ((SC *)this)->visitSequentialUMinExpr(S);
+ case scUnknown:
+ return ((SC *)this)->visitUnknown(S);
+ case scCouldNotCompute:
+ return ((SC *)this)->visitCouldNotCompute(S);
+ }
+ llvm_unreachable("Unknown SCEV kind!");
+ }
+
+ RetVal visitCouldNotCompute(SCEVUse S) {
+ llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
+ }
+};
+
/// Visit all nodes in the expression tree using worklist traversal.
///
/// Visitor implements:
diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
index 2559d7b89b020..50b1ad055cf27 100644
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -61,7 +61,7 @@ struct PoisonFlags {
/// Clients should create an instance of this class when rewriting is needed,
/// and destroy it when finished to allow the release of the associated
/// memory.
-class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
+class SCEVExpander : public SCEVUseVisitor<SCEVExpander, Value *> {
friend class SCEVExpanderCleaner;
ScalarEvolution &SE;
@@ -179,7 +179,7 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
const char *DebugType;
#endif
- friend struct SCEVVisitor<SCEVExpander, Value *>;
+ friend struct SCEVUseVisitor<SCEVExpander, Value *>;
public:
/// Construct a SCEVExpander in "canonical" mode.
@@ -313,9 +313,8 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
/// Insert code to directly compute the specified SCEV expression into the
/// program. The code is inserted into the specified block.
- LLVM_ABI Value *expandCodeFor(const SCEV *SH, Type *Ty,
- BasicBlock::iterator I);
- Value *expandCodeFor(const SCEV *SH, Type *Ty, Instruction *I) {
+ LLVM_ABI Value *expandCodeFor(SCEVUse SH, Type *Ty, BasicBlock::iterator I);
+ Value *expandCodeFor(SCEVUse SH, Type *Ty, Instruction *I) {
return expandCodeFor(SH, Ty, I->getIterator());
}
@@ -323,7 +322,7 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
/// program. The code is inserted into the SCEVExpander's current
/// insertion point. If a type is specified, the result will be expanded to
/// have that type, with a cast if necessary.
- LLVM_ABI Value *expandCodeFor(const SCEV *SH, Type *Ty = nullptr);
+ LLVM_ABI Value *expandCodeFor(SCEVUse SH, Type *Ty = nullptr);
/// Generates a code sequence that evaluates this predicate. The inserted
/// instructions will be at position \p Loc. The result will be of type i1
@@ -478,12 +477,12 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
const SCEV *S, const Instruction *InsertPt,
SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts);
- LLVM_ABI Value *expand(const SCEV *S);
- Value *expand(const SCEV *S, BasicBlock::iterator I) {
+ LLVM_ABI Value *expand(SCEVUse S);
+ Value *expand(SCEVUse S, BasicBlock::iterator I) {
setInsertPoint(I);
return expand(S);
}
- Value *expand(const SCEV *S, Instruction *I) {
+ Value *expand(SCEVUse S, Instruction *I) {
setInsertPoint(I);
return expand(S);
}
@@ -494,39 +493,39 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
Value *expandMinMaxExpr(const SCEVNAryExpr *S, Intrinsic::ID IntrinID,
Twine Name, bool IsSequential = false);
- Value *visitConstant(const SCEVConstant *S) { return S->getValue(); }
+ Value *visitConstant(SCEVUse S) { return cast<SCEVConstant>(S)->getValue(); }
- Value *visitVScale(const SCEVVScale *S);
+ Value *visitVScale(SCEVUse S);
- Value *visitPtrToAddrExpr(const SCEVPtrToAddrExpr *S);
+ Value *visitPtrToAddrExpr(SCEVUse S);
- Value *visitPtrToIntExpr(const SCEVPtrToIntExpr *S);
+ Value *visitPtrToIntExpr(SCEVUse S);
- Value *visitTruncateExpr(const SCEVTruncateExpr *S);
+ Value *visitTruncateExpr(SCEVUse S);
- Value *visitZeroExtendExpr(const SCEVZeroExtendExpr *S);
+ Value *visitZeroExtendExpr(SCEVUse S);
- Value *visitSignExtendExpr(const SCEVSignExtendExpr *S);
+ Value *visitSignExtendExpr(SCEVUse S);
- Value *visitAddExpr(const SCEVAddExpr *S);
+ Value *visitAddExpr(SCEVUse S);
- Value *visitMulExpr(const SCEVMulExpr *S);
+ Value *visitMulExpr(SCEVUse S);
- Value *visitUDivExpr(const SCEVUDivExpr *S);
+ Value *visitUDivExpr(SCEVUse S);
- Value *visitAddRecExpr(const SCEVAddRecExpr *S);
+ Value *visitAddRecExpr(SCEVUse S);
- Value *visitSMaxExpr(const SCEVSMaxExpr *S);
+ Value *visitSMaxExpr(SCEVUse S);
- Value *visitUMaxExpr(const SCEVUMaxExpr *S);
+ Value *visitUMaxExpr(SCEVUse S);
- Value *visitSMinExpr(const SCEVSMinExpr *S);
+ Value *visitSMinExpr(SCEVUse S);
- Value *visitUMinExpr(const SCEVUMinExpr *S);
+ Value *visitUMinExpr(SCEVUse S);
- Value *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S);
+ Value *visitSequentialUMinExpr(SCEVUse S);
- Value *visitUnknown(const SCEVUnknown *S) { return S->getValue(); }
+ Value *visitUnknown(SCEVUse S) { return cast<SCEVUnknown>(S)->getValue(); }
LLVM_ABI void rememberInstruction(Value *I);
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index ac60837584763..0a570da01ba15 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -523,7 +523,8 @@ class LoopCompare {
}
-Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
+Value *SCEVExpander::visitAddExpr(SCEVUse SU) {
+ const SCEVAddExpr *S = cast<SCEVAddExpr>(SU);
// Recognize the canonical representation of an unsimplifed urem.
const SCEV *URemLHS = nullptr;
const SCEV *URemRHS = nullptr;
@@ -595,7 +596,8 @@ Value *SCEVExpander::visitAddExpr(const SCEVAddExpr *S) {
return Sum;
}
-Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
+Value *SCEVExpander::visitMulExpr(SCEVUse SU) {
+ const SCEVMulExpr *S = cast<SCEVMulExpr>(SU);
Type *Ty = S->getType();
// Collect all the mul operands in a loop, along with their associated loops.
@@ -687,7 +689,8 @@ Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
return Prod;
}
-Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
+Value *SCEVExpander::visitUDivExpr(SCEVUse SU) {
+ const SCEVUDivExpr *S = cast<SCEVUDivExpr>(SU);
Value *LHS = expand(S->getLHS());
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getRHS())) {
const APInt &RHS = SC->getAPInt();
@@ -1308,7 +1311,8 @@ Value *SCEVExpander::tryToReuseLCSSAPhi(const SCEVAddRecExpr *S) {
return nullptr;
}
-Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
+Value *SCEVExpander::visitAddRecExpr(SCEVUse SU) {
+ const SCEVAddRecExpr *S = cast<SCEVAddRecExpr>(SU);
// In canonical mode we compute the addrec as an expression of a canonical IV
// using evaluateAtIteration and expand the resulting SCEV expression. This
// way we avoid introducing new IVs to carry on the computation of the addrec
@@ -1448,7 +1452,8 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) {
return expand(T);
}
-Value *SCEVExpander::visitPtrToAddrExpr(const SCEVPtrToAddrExpr *S) {
+Value *SCEVExpander::visitPtrToAddrExpr(SCEVUse SU) {
+ const SCEVPtrToAddrExpr *S = cast<SCEVPtrToAddrExpr>(SU);
Value *V = expand(S->getOperand());
Type *Ty = S->getType();
@@ -1468,24 +1473,28 @@ Value *SCEVExpander::visitPtrToAddrExpr(const SCEVPtrToAddrExpr *S) {
GetOptimalInsertionPointForCastOf(V));
}
-Value *SCEVExpander::visitPtrToIntExpr(const SCEVPtrToIntExpr *S) {
+Value *SCEVExpander::visitPtrToIntExpr(SCEVUse SU) {
+ const SCEVPtrToIntExpr *S = cast<SCEVPtrToIntExpr>(SU);
Value *V = expand(S->getOperand());
return ReuseOrCreateCast(V, S->getType(), CastInst::PtrToInt,
GetOptimalInsertionPointForCastOf(V));
}
-Value *SCEVExpander::visitTruncateExpr(const SCEVTruncateExpr *S) {
+Value *SCEVExpander::visitTruncateExpr(SCEVUse SU) {
+ const SCEVTruncateExpr *S = cast<SCEVTruncateExpr>(SU);
Value *V = expand(S->getOperand());
return Builder.CreateTrunc(V, S->getType());
}
-Value *SCEVExpander::visitZeroExtendExpr(const SCEVZeroExtendExpr *S) {
+Value *SCEVExpander::visitZeroExtendExpr(SCEVUse SU) {
+ const SCEVZeroExtendExpr *S = cast<SCEVZeroExtendExpr>(SU);
Value *V = expand(S->getOperand());
return Builder.CreateZExt(V, S->getType(), "",
SE.isKnownNonNegative(S->getOperand()));
}
-Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) {
+Value *SCEVExpander::visitSignExtendExpr(SCEVUse SU) {
+ const SCEVSignExtendExpr *S = cast<SCEVSignExtendExpr>(SU);
Value *V = expand(S->getOperand());
return Builder.CreateSExt(V, S->getType());
}
@@ -1519,38 +1528,38 @@ Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
return LHS;
}
-Value *SCEVExpander::visitSMaxExpr(const SCEVSMaxExpr *S) {
- return expandMinMaxExpr(S, Intrinsic::smax, "smax");
+Value *SCEVExpander::visitSMaxExpr(SCEVUse SU) {
+ return expandMinMaxExpr(cast<SCEVSMaxExpr>(SU), Intrinsic::smax, "smax");
}
-Value *SCEVExpander::visitUMaxExpr(const SCEVUMaxExpr *S) {
- return expandMinMaxExpr(S, Intrinsic::umax, "umax");
+Value *SCEVExpander::visitUMaxExpr(SCEVUse SU) {
+ return expandMinMaxExpr(cast<SCEVUMaxExpr>(SU), Intrinsic::umax, "umax");
}
-Value *SCEVExpander::visitSMinExpr(const SCEVSMinExpr *S) {
- return expandMinMaxExpr(S, Intrinsic::smin, "smin");
+Value *SCEVExpander::visitSMinExpr(SCEVUse SU) {
+ return expandMinMaxExpr(cast<SCEVSMinExpr>(SU), Intrinsic::smin, "smin");
}
-Value *SCEVExpander::visitUMinExpr(const SCEVUMinExpr *S) {
- return expandMinMaxExpr(S, Intrinsic::umin, "umin");
+Value *SCEVExpander::visitUMinExpr(SCEVUse SU) {
+ return expandMinMaxExpr(cast<SCEVUMinExpr>(SU), Intrinsic::umin, "umin");
}
-Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) {
- return expandMinMaxExpr(S, Intrinsic::umin, "umin", /*IsSequential*/true);
+Value *SCEVExpander::visitSequentialUMinExpr(SCEVUse SU) {
+ return expandMinMaxExpr(cast<SCEVSequentialUMinExpr>(SU), Intrinsic::umin,
+ "umin", /*IsSequential*/ true);
}
-Value *SCEVExpander::visitVScale(const SCEVVScale *S) {
- return Builder.CreateVScale(S->getType());
+Value *SCEVExpander::visitVScale(SCEVUse SU) {
+ return Builder.CreateVScale(cast<SCEVVScale>(SU)->getType());
}
-Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty,
+Value *SCEVExpander::expandCodeFor(SCEVUse SH, Type *Ty,
BasicBlock::iterator IP) {
setInsertPoint(IP);
- Value *V = expandCodeFor(SH, Ty);
- return V;
+ return expandCodeFor(SH, Ty);
}
-Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) {
+Value *SCEVExpander::expandCodeFor(SCEVUse SH, Type *Ty) {
// Expand the code for this SCEV.
Value *V = expand(SH);
@@ -1602,7 +1611,7 @@ Value *SCEVExpander::FindValueInExprValueMap(
// literally, to prevent LSR's transformed SCEV from being reverted. Otherwise,
// the expansion will try to reuse Value from ExprValueMap, and only when it
// fails, expand the SCEV literally.
-Value *SCEVExpander::expand(const SCEV *S) {
+Value *SCEVExpander::expand(SCEVUse S) {
// Compute an insertion point for this SCEV object. Hoist the instructions
// as far out in the loop nest as possible.
BasicBlock::iterator InsertPt = Builder.GetInsertPoint();
@@ -1654,7 +1663,7 @@ Value *SCEVExpander::expand(const SCEV *S) {
}
// Check to see if we already expanded this here.
- auto I = InsertedExpressions.find(std::make_pair(S, &*InsertPt));
+ auto I = InsertedExpressions.find(std::make_pair(S.getPointer(), &*InsertPt));
if (I != InsertedExpressions.end())
return I->second;
@@ -1696,7 +1705,10 @@ Value *SCEVExpander::expand(const SCEV *S) {
// the expression at this insertion point. If the mapped value happened to be
// a postinc expansion, it could be reused by a non-postinc user, but only if
// its insertion point was already at the head of the loop.
- InsertedExpressions[std::make_pair(S, &*InsertPt)] = V;
+ // Only cache canonical SCEVUses (without use-specific flags) to prevent
+ // re-use of expansions with incorrect flags.
+ if (S.isCanonical())
+ InsertedExpressions[std::make_pair(S.getPointer(), &*InsertPt)] = V;
return V;
}
>From 53cb224b6bf45121ff14edca713584e978b4fdc1 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 31 Mar 2026 11:37:09 +0100
Subject: [PATCH 2/5] !fixup cache expansions for SCEVUse
---
.../llvm/Transforms/Utils/ScalarEvolutionExpander.h | 2 +-
llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp | 7 ++-----
2 files changed, 3 insertions(+), 6 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
index 50b1ad055cf27..b980fe95d6d8a 100644
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -74,7 +74,7 @@ class SCEVExpander : public SCEVUseVisitor<SCEVExpander, Value *> {
bool PreserveLCSSA;
// InsertedExpressions caches Values for reuse, so must track RAUW.
- DenseMap<std::pair<const SCEV *, Instruction *>, TrackingVH<Value>>
+ DenseMap<std::pair<SCEVUse, Instruction *>, TrackingVH<Value>>
InsertedExpressions;
// InsertedValues only flags inserted instructions so needs no RAUW.
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 0a570da01ba15..71d4727e998b0 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -1663,7 +1663,7 @@ Value *SCEVExpander::expand(SCEVUse S) {
}
// Check to see if we already expanded this here.
- auto I = InsertedExpressions.find(std::make_pair(S.getPointer(), &*InsertPt));
+ auto I = InsertedExpressions.find(std::make_pair(S, &*InsertPt));
if (I != InsertedExpressions.end())
return I->second;
@@ -1705,10 +1705,7 @@ Value *SCEVExpander::expand(SCEVUse S) {
// the expression at this insertion point. If the mapped value happened to be
// a postinc expansion, it could be reused by a non-postinc user, but only if
// its insertion point was already at the head of the loop.
- // Only cache canonical SCEVUses (without use-specific flags) to prevent
- // re-use of expansions with incorrect flags.
- if (S.isCanonical())
- InsertedExpressions[std::make_pair(S.getPointer(), &*InsertPt)] = V;
+ InsertedExpressions[std::make_pair(S, &*InsertPt)] = V;
return V;
}
>From 0a0a929aaf30bd5eb8f1d023d4241c1f675814dc Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 31 Mar 2026 19:10:32 +0100
Subject: [PATCH 3/5] !fixup Add templatized SCEVUseT<>
---
llvm/include/llvm/Analysis/ScalarEvolution.h | 31 ++++++---
.../Analysis/ScalarEvolutionExpressions.h | 47 +++++++------
.../Analysis/ScalarEvolutionPatternMatch.h | 15 +++--
.../Utils/ScalarEvolutionExpander.h | 45 +++++++------
llvm/lib/Analysis/ScalarEvolution.cpp | 9 ++-
.../Utils/ScalarEvolutionExpander.cpp | 67 +++++++++----------
6 files changed, 122 insertions(+), 92 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index f20b73f9f358c..3ee8ff1f56efc 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -68,17 +68,22 @@ LLVM_ABI extern bool VerifySCEV;
class SCEV;
-struct SCEVUse : PointerIntPair<const SCEV *, 2> {
- SCEVUse() : PointerIntPair() { setFromOpaqueValue(nullptr); }
- SCEVUse(const SCEV *S) : PointerIntPair() {
- setFromOpaqueValue(reinterpret_cast<void *>(const_cast<SCEV *>(S)));
- }
- SCEVUse(const SCEV *S, unsigned Flags) : PointerIntPair(S, Flags) {}
+template <typename SCEVPtrT = const SCEV *>
+struct SCEVUseT : PointerIntPair<SCEVPtrT, 2> {
+ using Base = PointerIntPair<SCEVPtrT, 2>;
+
+ SCEVUseT() : Base() { Base::setFromOpaqueValue(nullptr); }
+ SCEVUseT(SCEVPtrT S) : SCEVUseT(S, 0) {}
+ SCEVUseT(SCEVPtrT S, unsigned Flags) : Base(S, Flags) {}
+ template <typename OtherPtrT, typename = std::enable_if_t<
+ std::is_convertible_v<OtherPtrT, SCEVPtrT>>>
+ SCEVUseT(const SCEVUseT<OtherPtrT> &Other)
+ : Base(Other.getPointer(), Other.getInt()) {}
- operator const SCEV *() const { return getPointer(); }
- const SCEV *operator->() const { return getPointer(); }
+ operator const SCEV *() const { return Base::getPointer(); }
+ SCEVPtrT operator->() const { return Base::getPointer(); }
- void *getRawPointer() const { return getOpaqueValue(); }
+ void *getRawPointer() const { return Base::getOpaqueValue(); }
/// Returns true of the SCEVUse is canonical, i.e. no SCEVUse flags set in any
/// operands.
@@ -89,7 +94,8 @@ struct SCEVUse : PointerIntPair<const SCEV *, 2> {
unsigned getFlags() const { return getInt(); }
- bool operator==(const SCEVUse &RHS) const {
+
+ bool operator==(const SCEVUseT &RHS) const {
return getRawPointer() == RHS.getRawPointer();
}
@@ -103,6 +109,11 @@ struct SCEVUse : PointerIntPair<const SCEV *, 2> {
void dump() const;
};
+/// Deduction guide for various SCEV subclass pointers.
+template <typename SCEVPtrT> SCEVUseT(SCEVPtrT) -> SCEVUseT<SCEVPtrT>;
+
+using SCEVUse = SCEVUseT<const SCEV *>;
+
/// Provide PointerLikeTypeTraits for SCEVUse, so it can be used with
/// SmallPtrSet, among others.
template <> struct PointerLikeTypeTraits<SCEVUse> {
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
index bd03bf1146a46..7958fed19b00c 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -677,46 +677,55 @@ template <typename SC, typename RetVal = void> struct SCEVUseVisitor {
RetVal visit(SCEVUse S) {
switch (S->getSCEVType()) {
case scConstant:
- return ((SC *)this)->visitConstant(S);
+ return ((SC *)this)->visitConstant(SCEVUseT(cast<const SCEVConstant>(S)));
case scVScale:
- return ((SC *)this)->visitVScale(S);
+ return ((SC *)this)->visitVScale(SCEVUseT(cast<const SCEVVScale>(S)));
case scPtrToAddr:
- return ((SC *)this)->visitPtrToAddrExpr(S);
+ return ((SC *)this)
+ ->visitPtrToAddrExpr(SCEVUseT(cast<const SCEVPtrToAddrExpr>(S)));
case scPtrToInt:
- return ((SC *)this)->visitPtrToIntExpr(S);
+ return ((SC *)this)
+ ->visitPtrToIntExpr(SCEVUseT(cast<const SCEVPtrToIntExpr>(S)));
case scTruncate:
- return ((SC *)this)->visitTruncateExpr(S);
+ return ((SC *)this)
+ ->visitTruncateExpr(SCEVUseT(cast<const SCEVTruncateExpr>(S)));
case scZeroExtend:
- return ((SC *)this)->visitZeroExtendExpr(S);
+ return ((SC *)this)
+ ->visitZeroExtendExpr(SCEVUseT(cast<const SCEVZeroExtendExpr>(S)));
case scSignExtend:
- return ((SC *)this)->visitSignExtendExpr(S);
+ return ((SC *)this)
+ ->visitSignExtendExpr(SCEVUseT(cast<const SCEVSignExtendExpr>(S)));
case scAddExpr:
- return ((SC *)this)->visitAddExpr(S);
+ return ((SC *)this)->visitAddExpr(SCEVUseT(cast<const SCEVAddExpr>(S)));
case scMulExpr:
- return ((SC *)this)->visitMulExpr(S);
+ return ((SC *)this)->visitMulExpr(SCEVUseT(cast<const SCEVMulExpr>(S)));
case scUDivExpr:
- return ((SC *)this)->visitUDivExpr(S);
+ return ((SC *)this)->visitUDivExpr(SCEVUseT(cast<const SCEVUDivExpr>(S)));
case scAddRecExpr:
- return ((SC *)this)->visitAddRecExpr(S);
+ return ((SC *)this)
+ ->visitAddRecExpr(SCEVUseT(cast<const SCEVAddRecExpr>(S)));
case scSMaxExpr:
- return ((SC *)this)->visitSMaxExpr(S);
+ return ((SC *)this)->visitSMaxExpr(SCEVUseT(cast<const SCEVSMaxExpr>(S)));
case scUMaxExpr:
- return ((SC *)this)->visitUMaxExpr(S);
+ return ((SC *)this)->visitUMaxExpr(SCEVUseT(cast<const SCEVUMaxExpr>(S)));
case scSMinExpr:
- return ((SC *)this)->visitSMinExpr(S);
+ return ((SC *)this)->visitSMinExpr(SCEVUseT(cast<const SCEVSMinExpr>(S)));
case scUMinExpr:
- return ((SC *)this)->visitUMinExpr(S);
+ return ((SC *)this)->visitUMinExpr(SCEVUseT(cast<const SCEVUMinExpr>(S)));
case scSequentialUMinExpr:
- return ((SC *)this)->visitSequentialUMinExpr(S);
+ return ((SC *)this)
+ ->visitSequentialUMinExpr(
+ SCEVUseT(cast<const SCEVSequentialUMinExpr>(S)));
case scUnknown:
- return ((SC *)this)->visitUnknown(S);
+ return ((SC *)this)->visitUnknown(SCEVUseT(cast<const SCEVUnknown>(S)));
case scCouldNotCompute:
- return ((SC *)this)->visitCouldNotCompute(S);
+ return ((SC *)this)
+ ->visitCouldNotCompute(SCEVUseT(cast<const SCEVCouldNotCompute>(S)));
}
llvm_unreachable("Unknown SCEV kind!");
}
- RetVal visitCouldNotCompute(SCEVUse S) {
+ RetVal visitCouldNotCompute(SCEVUseT<const SCEVCouldNotCompute *> S) {
llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
}
};
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
index d678e23afc18d..7b045e150c76b 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h
@@ -23,7 +23,8 @@ template <typename Pattern> bool match(const SCEV *S, const Pattern &P) {
return P.match(S);
}
-template <typename Pattern> bool match(const SCEVUse U, const Pattern &P) {
+template <typename SCEVPtrT, typename Pattern>
+bool match(const SCEVUseT<SCEVPtrT> U, const Pattern &P) {
return P.match(U.getPointer());
}
@@ -87,10 +88,10 @@ template <typename Class> struct bind_ty {
}
};
-template <> struct bind_ty<SCEVUse> {
- SCEVUse &VR;
+template <typename SCEVPtrT> struct bind_ty<SCEVUseT<SCEVPtrT>> {
+ SCEVUseT<SCEVPtrT> &VR;
- bind_ty(SCEVUse &V) : VR(V) {}
+ bind_ty(SCEVUseT<SCEVPtrT> &V) : VR(V) {}
template <typename ITy> bool match(ITy *V) const {
VR = V;
@@ -100,7 +101,11 @@ template <> struct bind_ty<SCEVUse> {
/// Match a SCEV, capturing it if we match.
inline bind_ty<const SCEV> m_SCEV(const SCEV *&V) { return V; }
-inline bind_ty<SCEVUse> m_SCEV(SCEVUse &V) { return V; }
+
+template <typename SCEVPtrT>
+inline bind_ty<SCEVUseT<SCEVPtrT>> m_SCEV(SCEVUseT<SCEVPtrT> &V) {
+ return V;
+}
inline bind_ty<const SCEVConstant> m_SCEVConstant(const SCEVConstant *&V) {
return V;
}
diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
index b980fe95d6d8a..f62ce0b8c008b 100644
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -490,42 +490,45 @@ class SCEVExpander : public SCEVUseVisitor<SCEVExpander, Value *> {
/// Determine the most "relevant" loop for the given SCEV.
const Loop *getRelevantLoop(const SCEV *);
- Value *expandMinMaxExpr(const SCEVNAryExpr *S, Intrinsic::ID IntrinID,
- Twine Name, bool IsSequential = false);
+ Value *expandMinMaxExpr(SCEVUseT<const SCEVNAryExpr *> S,
+ Intrinsic::ID IntrinID, Twine Name,
+ bool IsSequential = false);
- Value *visitConstant(SCEVUse S) { return cast<SCEVConstant>(S)->getValue(); }
+ Value *visitConstant(SCEVUseT<const SCEVConstant *> S) {
+ return S->getValue();
+ }
- Value *visitVScale(SCEVUse S);
+ Value *visitVScale(SCEVUseT<const SCEVVScale *> S);
- Value *visitPtrToAddrExpr(SCEVUse S);
+ Value *visitPtrToAddrExpr(SCEVUseT<const SCEVPtrToAddrExpr *> S);
- Value *visitPtrToIntExpr(SCEVUse S);
+ Value *visitPtrToIntExpr(SCEVUseT<const SCEVPtrToIntExpr *> S);
- Value *visitTruncateExpr(SCEVUse S);
+ Value *visitTruncateExpr(SCEVUseT<const SCEVTruncateExpr *> S);
- Value *visitZeroExtendExpr(SCEVUse S);
+ Value *visitZeroExtendExpr(SCEVUseT<const SCEVZeroExtendExpr *> S);
- Value *visitSignExtendExpr(SCEVUse S);
+ Value *visitSignExtendExpr(SCEVUseT<const SCEVSignExtendExpr *> S);
- Value *visitAddExpr(SCEVUse S);
+ Value *visitAddExpr(SCEVUseT<const SCEVAddExpr *> S);
- Value *visitMulExpr(SCEVUse S);
+ Value *visitMulExpr(SCEVUseT<const SCEVMulExpr *> S);
- Value *visitUDivExpr(SCEVUse S);
+ Value *visitUDivExpr(SCEVUseT<const SCEVUDivExpr *> S);
- Value *visitAddRecExpr(SCEVUse S);
+ Value *visitAddRecExpr(SCEVUseT<const SCEVAddRecExpr *> S);
- Value *visitSMaxExpr(SCEVUse S);
+ Value *visitSMaxExpr(SCEVUseT<const SCEVSMaxExpr *> S);
- Value *visitUMaxExpr(SCEVUse S);
+ Value *visitUMaxExpr(SCEVUseT<const SCEVUMaxExpr *> S);
- Value *visitSMinExpr(SCEVUse S);
+ Value *visitSMinExpr(SCEVUseT<const SCEVSMinExpr *> S);
- Value *visitUMinExpr(SCEVUse S);
+ Value *visitUMinExpr(SCEVUseT<const SCEVUMinExpr *> S);
- Value *visitSequentialUMinExpr(SCEVUse S);
+ Value *visitSequentialUMinExpr(SCEVUseT<const SCEVSequentialUMinExpr *> S);
- Value *visitUnknown(SCEVUse S) { return cast<SCEVUnknown>(S)->getValue(); }
+ Value *visitUnknown(SCEVUseT<const SCEVUnknown *> S) { return S->getValue(); }
LLVM_ABI void rememberInstruction(Value *I);
@@ -535,8 +538,8 @@ class SCEVExpander : public SCEVUseVisitor<SCEVExpander, Value *> {
bool isExpandedAddRecExprPHI(PHINode *PN, Instruction *IncV, const Loop *L);
- Value *tryToReuseLCSSAPhi(const SCEVAddRecExpr *S);
- Value *expandAddRecExprLiterally(const SCEVAddRecExpr *);
+ Value *tryToReuseLCSSAPhi(SCEVUseT<const SCEVAddRecExpr *> S);
+ Value *expandAddRecExprLiterally(SCEVUseT<const SCEVAddRecExpr *> S);
PHINode *getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
const Loop *L, Type *&TruncTy,
bool &InvertStep);
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 73dc298e66e6f..6411552791bee 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -343,7 +343,7 @@ LLVM_DUMP_METHOD void SCEVUse::dump() const {
}
#endif
-void SCEVUse::print(raw_ostream &OS) const {
+template <> void SCEVUseT<const SCEV *>::print(raw_ostream &OS) const {
getPointer()->print(OS);
SCEV::NoWrapFlags Flags = static_cast<SCEV::NoWrapFlags>(getInt());
if (Flags & SCEV::FlagNUW)
@@ -352,6 +352,13 @@ void SCEVUse::print(raw_ostream &OS) const {
OS << "(u nsw)";
}
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+template <> LLVM_DUMP_METHOD void SCEVUseT<const SCEV *>::dump() const {
+ print(dbgs());
+ dbgs() << '\n';
+}
+#endif
+
//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 71d4727e998b0..138b3e7699b2e 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -523,8 +523,7 @@ class LoopCompare {
}
-Value *SCEVExpander::visitAddExpr(SCEVUse SU) {
- const SCEVAddExpr *S = cast<SCEVAddExpr>(SU);
+Value *SCEVExpander::visitAddExpr(SCEVUseT<const SCEVAddExpr *> S) {
// Recognize the canonical representation of an unsimplifed urem.
const SCEV *URemLHS = nullptr;
const SCEV *URemRHS = nullptr;
@@ -596,8 +595,7 @@ Value *SCEVExpander::visitAddExpr(SCEVUse SU) {
return Sum;
}
-Value *SCEVExpander::visitMulExpr(SCEVUse SU) {
- const SCEVMulExpr *S = cast<SCEVMulExpr>(SU);
+Value *SCEVExpander::visitMulExpr(SCEVUseT<const SCEVMulExpr *> S) {
Type *Ty = S->getType();
// Collect all the mul operands in a loop, along with their associated loops.
@@ -689,8 +687,7 @@ Value *SCEVExpander::visitMulExpr(SCEVUse SU) {
return Prod;
}
-Value *SCEVExpander::visitUDivExpr(SCEVUse SU) {
- const SCEVUDivExpr *S = cast<SCEVUDivExpr>(SU);
+Value *SCEVExpander::visitUDivExpr(SCEVUseT<const SCEVUDivExpr *> S) {
Value *LHS = expand(S->getLHS());
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getRHS())) {
const APInt &RHS = SC->getAPInt();
@@ -1159,12 +1156,13 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized,
return PN;
}
-Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
+Value *
+SCEVExpander::expandAddRecExprLiterally(SCEVUseT<const SCEVAddRecExpr *> S) {
const Loop *L = S->getLoop();
// Determine a normalized form of this expression, which is the expression
// before any post-inc adjustment is made.
- const SCEVAddRecExpr *Normalized = S;
+ const SCEVAddRecExpr *Normalized = S.getPointer();
if (PostIncLoops.count(L)) {
PostIncLoopSet Loops;
Loops.insert(L);
@@ -1249,7 +1247,7 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) {
return Result;
}
-Value *SCEVExpander::tryToReuseLCSSAPhi(const SCEVAddRecExpr *S) {
+Value *SCEVExpander::tryToReuseLCSSAPhi(SCEVUseT<const SCEVAddRecExpr *> S) {
Type *STy = S->getType();
const Loop *L = S->getLoop();
BasicBlock *EB = L->getExitBlock();
@@ -1262,7 +1260,7 @@ Value *SCEVExpander::tryToReuseLCSSAPhi(const SCEVAddRecExpr *S) {
auto CanReuse = [&](const SCEV *ExitSCEV) -> const SCEV * {
if (isa<SCEVCouldNotCompute>(ExitSCEV))
return nullptr;
- const SCEV *Diff = SE.getMinusSCEV(S, ExitSCEV);
+ const SCEV *Diff = SE.getMinusSCEV(S.getPointer(), ExitSCEV);
const SCEV *Op = Diff;
match(Op, m_scev_Add(m_SCEVConstant(), m_SCEV(Op)));
match(Op, m_scev_Mul(m_scev_AllOnes(), m_SCEV(Op)));
@@ -1311,8 +1309,7 @@ Value *SCEVExpander::tryToReuseLCSSAPhi(const SCEVAddRecExpr *S) {
return nullptr;
}
-Value *SCEVExpander::visitAddRecExpr(SCEVUse SU) {
- const SCEVAddRecExpr *S = cast<SCEVAddRecExpr>(SU);
+Value *SCEVExpander::visitAddRecExpr(SCEVUseT<const SCEVAddRecExpr *> S) {
// In canonical mode we compute the addrec as an expression of a canonical IV
// using evaluateAtIteration and expand the resulting SCEV expression. This
// way we avoid introducing new IVs to carry on the computation of the addrec
@@ -1452,8 +1449,7 @@ Value *SCEVExpander::visitAddRecExpr(SCEVUse SU) {
return expand(T);
}
-Value *SCEVExpander::visitPtrToAddrExpr(SCEVUse SU) {
- const SCEVPtrToAddrExpr *S = cast<SCEVPtrToAddrExpr>(SU);
+Value *SCEVExpander::visitPtrToAddrExpr(SCEVUseT<const SCEVPtrToAddrExpr *> S) {
Value *V = expand(S->getOperand());
Type *Ty = S->getType();
@@ -1473,33 +1469,31 @@ Value *SCEVExpander::visitPtrToAddrExpr(SCEVUse SU) {
GetOptimalInsertionPointForCastOf(V));
}
-Value *SCEVExpander::visitPtrToIntExpr(SCEVUse SU) {
- const SCEVPtrToIntExpr *S = cast<SCEVPtrToIntExpr>(SU);
+Value *SCEVExpander::visitPtrToIntExpr(SCEVUseT<const SCEVPtrToIntExpr *> S) {
Value *V = expand(S->getOperand());
return ReuseOrCreateCast(V, S->getType(), CastInst::PtrToInt,
GetOptimalInsertionPointForCastOf(V));
}
-Value *SCEVExpander::visitTruncateExpr(SCEVUse SU) {
- const SCEVTruncateExpr *S = cast<SCEVTruncateExpr>(SU);
+Value *SCEVExpander::visitTruncateExpr(SCEVUseT<const SCEVTruncateExpr *> S) {
Value *V = expand(S->getOperand());
return Builder.CreateTrunc(V, S->getType());
}
-Value *SCEVExpander::visitZeroExtendExpr(SCEVUse SU) {
- const SCEVZeroExtendExpr *S = cast<SCEVZeroExtendExpr>(SU);
+Value *
+SCEVExpander::visitZeroExtendExpr(SCEVUseT<const SCEVZeroExtendExpr *> S) {
Value *V = expand(S->getOperand());
return Builder.CreateZExt(V, S->getType(), "",
SE.isKnownNonNegative(S->getOperand()));
}
-Value *SCEVExpander::visitSignExtendExpr(SCEVUse SU) {
- const SCEVSignExtendExpr *S = cast<SCEVSignExtendExpr>(SU);
+Value *
+SCEVExpander::visitSignExtendExpr(SCEVUseT<const SCEVSignExtendExpr *> S) {
Value *V = expand(S->getOperand());
return Builder.CreateSExt(V, S->getType());
}
-Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
+Value *SCEVExpander::expandMinMaxExpr(SCEVUseT<const SCEVNAryExpr *> S,
Intrinsic::ID IntrinID, Twine Name,
bool IsSequential) {
bool PrevSafeMode = SafeUDivMode;
@@ -1528,29 +1522,30 @@ Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
return LHS;
}
-Value *SCEVExpander::visitSMaxExpr(SCEVUse SU) {
- return expandMinMaxExpr(cast<SCEVSMaxExpr>(SU), Intrinsic::smax, "smax");
+Value *SCEVExpander::visitSMaxExpr(SCEVUseT<const SCEVSMaxExpr *> S) {
+ return expandMinMaxExpr(S, Intrinsic::smax, "smax");
}
-Value *SCEVExpander::visitUMaxExpr(SCEVUse SU) {
- return expandMinMaxExpr(cast<SCEVUMaxExpr>(SU), Intrinsic::umax, "umax");
+Value *SCEVExpander::visitUMaxExpr(SCEVUseT<const SCEVUMaxExpr *> S) {
+ return expandMinMaxExpr(S, Intrinsic::umax, "umax");
}
-Value *SCEVExpander::visitSMinExpr(SCEVUse SU) {
- return expandMinMaxExpr(cast<SCEVSMinExpr>(SU), Intrinsic::smin, "smin");
+Value *SCEVExpander::visitSMinExpr(SCEVUseT<const SCEVSMinExpr *> S) {
+ return expandMinMaxExpr(S, Intrinsic::smin, "smin");
}
-Value *SCEVExpander::visitUMinExpr(SCEVUse SU) {
- return expandMinMaxExpr(cast<SCEVUMinExpr>(SU), Intrinsic::umin, "umin");
+Value *SCEVExpander::visitUMinExpr(SCEVUseT<const SCEVUMinExpr *> S) {
+ return expandMinMaxExpr(S, Intrinsic::umin, "umin");
}
-Value *SCEVExpander::visitSequentialUMinExpr(SCEVUse SU) {
- return expandMinMaxExpr(cast<SCEVSequentialUMinExpr>(SU), Intrinsic::umin,
- "umin", /*IsSequential*/ true);
+Value *SCEVExpander::visitSequentialUMinExpr(
+ SCEVUseT<const SCEVSequentialUMinExpr *> S) {
+ return expandMinMaxExpr(S, Intrinsic::umin, "umin",
+ /*IsSequential*/ true);
}
-Value *SCEVExpander::visitVScale(SCEVUse SU) {
- return Builder.CreateVScale(cast<SCEVVScale>(SU)->getType());
+Value *SCEVExpander::visitVScale(SCEVUseT<const SCEVVScale *> S) {
+ return Builder.CreateVScale(S->getType());
}
Value *SCEVExpander::expandCodeFor(SCEVUse SH, Type *Ty,
>From bd43c97bce33b5df200629655074f71ab8628be5 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 1 Apr 2026 11:16:36 +0100
Subject: [PATCH 4/5] !fixup adjust conversion return type, remove getPointer()
calls
---
llvm/include/llvm/Analysis/ScalarEvolution.h | 12 +++++++-----
llvm/lib/Analysis/ScalarEvolution.cpp | 11 ++---------
.../lib/Transforms/Utils/ScalarEvolutionExpander.cpp | 4 ++--
3 files changed, 11 insertions(+), 16 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 3ee8ff1f56efc..536ecbe7442a5 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -80,7 +80,7 @@ struct SCEVUseT : PointerIntPair<SCEVPtrT, 2> {
SCEVUseT(const SCEVUseT<OtherPtrT> &Other)
: Base(Other.getPointer(), Other.getInt()) {}
- operator const SCEV *() const { return Base::getPointer(); }
+ operator SCEVPtrT() const { return Base::getPointer(); }
SCEVPtrT operator->() const { return Base::getPointer(); }
void *getRawPointer() const { return Base::getOpaqueValue(); }
@@ -92,8 +92,7 @@ struct SCEVUseT : PointerIntPair<SCEVPtrT, 2> {
/// Return the canonical SCEV for this SCEVUse.
const SCEV *getCanonical() const;
- unsigned getFlags() const { return getInt(); }
-
+ unsigned getFlags() const { return Base::getInt(); }
bool operator==(const SCEVUseT &RHS) const {
return getRawPointer() == RHS.getRawPointer();
@@ -114,6 +113,9 @@ template <typename SCEVPtrT> SCEVUseT(SCEVPtrT) -> SCEVUseT<SCEVPtrT>;
using SCEVUse = SCEVUseT<const SCEV *>;
+template <> void SCEVUseT<const SCEV *>::print(raw_ostream &OS) const;
+template <> void SCEVUseT<const SCEV *>::dump() const;
+
/// Provide PointerLikeTypeTraits for SCEVUse, so it can be used with
/// SmallPtrSet, among others.
template <> struct PointerLikeTypeTraits<SCEVUse> {
@@ -2666,8 +2668,8 @@ template <> struct DenseMapInfo<ScalarEvolution::FoldID> {
}
};
-inline const SCEV *SCEVUse::getCanonical() const {
- return getPointer()->getCanonical();
+template <> inline const SCEV *SCEVUseT<const SCEV *>::getCanonical() const {
+ return Base::getPointer()->getCanonical();
}
} // end namespace llvm
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 6411552791bee..825633c7f83c2 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -336,16 +336,9 @@ void SCEV::computeAndSetCanonical(ScalarEvolution &SE) {
}
}
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-LLVM_DUMP_METHOD void SCEVUse::dump() const {
- print(dbgs());
- dbgs() << '\n';
-}
-#endif
-
template <> void SCEVUseT<const SCEV *>::print(raw_ostream &OS) const {
- getPointer()->print(OS);
- SCEV::NoWrapFlags Flags = static_cast<SCEV::NoWrapFlags>(getInt());
+ Base::getPointer()->print(OS);
+ SCEV::NoWrapFlags Flags = static_cast<SCEV::NoWrapFlags>(Base::getInt());
if (Flags & SCEV::FlagNUW)
OS << "(u nuw)";
if (Flags & SCEV::FlagNSW)
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 138b3e7699b2e..7c762a53cebdf 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -1162,7 +1162,7 @@ SCEVExpander::expandAddRecExprLiterally(SCEVUseT<const SCEVAddRecExpr *> S) {
// Determine a normalized form of this expression, which is the expression
// before any post-inc adjustment is made.
- const SCEVAddRecExpr *Normalized = S.getPointer();
+ const SCEVAddRecExpr *Normalized = S;
if (PostIncLoops.count(L)) {
PostIncLoopSet Loops;
Loops.insert(L);
@@ -1260,7 +1260,7 @@ Value *SCEVExpander::tryToReuseLCSSAPhi(SCEVUseT<const SCEVAddRecExpr *> S) {
auto CanReuse = [&](const SCEV *ExitSCEV) -> const SCEV * {
if (isa<SCEVCouldNotCompute>(ExitSCEV))
return nullptr;
- const SCEV *Diff = SE.getMinusSCEV(S.getPointer(), ExitSCEV);
+ const SCEV *Diff = SE.getMinusSCEV(S, ExitSCEV);
const SCEV *Op = Diff;
match(Op, m_scev_Add(m_SCEVConstant(), m_SCEV(Op)));
match(Op, m_scev_Mul(m_scev_AllOnes(), m_SCEV(Op)));
>From 962ab0d725871abab04082798251b4693d18b533 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 2 Apr 2026 10:10:12 +0100
Subject: [PATCH 5/5] !fixup add CastInfo to cast between SCEVUseT types.
---
llvm/include/llvm/Analysis/ScalarEvolution.h | 46 +++++++++++++++++--
.../Analysis/ScalarEvolutionExpressions.h | 43 +++++++++--------
.../Utils/ScalarEvolutionExpander.h | 2 +-
llvm/lib/Analysis/ScalarEvolution.cpp | 16 -------
.../Utils/ScalarEvolutionExpander.cpp | 2 +-
5 files changed, 70 insertions(+), 39 deletions(-)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 536ecbe7442a5..08c7e43e708b8 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -113,9 +113,6 @@ template <typename SCEVPtrT> SCEVUseT(SCEVPtrT) -> SCEVUseT<SCEVPtrT>;
using SCEVUse = SCEVUseT<const SCEV *>;
-template <> void SCEVUseT<const SCEV *>::print(raw_ostream &OS) const;
-template <> void SCEVUseT<const SCEV *>::dump() const;
-
/// Provide PointerLikeTypeTraits for SCEVUse, so it can be used with
/// SmallPtrSet, among others.
template <> struct PointerLikeTypeTraits<SCEVUse> {
@@ -158,6 +155,31 @@ template <> struct simplify_type<SCEVUse> {
}
};
+/// Provide CastInfo for SCEVUseT so that cast<SCEVUseT<const To *>>(use)
+/// returns SCEVUseT<const To *> with flags preserved.
+template <typename ToSCEVPtrT>
+struct CastInfo<SCEVUseT<ToSCEVPtrT>, SCEVUse,
+ std::enable_if_t<!is_simple_type<SCEVUse>::value>> {
+ using To = std::remove_cv_t<std::remove_pointer_t<ToSCEVPtrT>>;
+ using CastReturnType = SCEVUseT<ToSCEVPtrT>;
+
+ static bool isPossible(const SCEVUse &U) { return isa<To>(U.getPointer()); }
+ static CastReturnType doCast(const SCEVUse &U) {
+ return {cast<To>(U.getPointer()), U.getFlags()};
+ }
+ static CastReturnType castFailed() { return CastReturnType(nullptr); }
+ static CastReturnType doCastIfPossible(const SCEVUse &U) {
+ if (!isPossible(U))
+ return castFailed();
+ return doCast(U);
+ }
+};
+
+template <typename ToSCEVPtrT>
+struct CastInfo<SCEVUseT<ToSCEVPtrT>, const SCEVUse,
+ std::enable_if_t<!is_simple_type<const SCEVUse>::value>>
+ : CastInfo<SCEVUseT<ToSCEVPtrT>, SCEVUse> {};
+
/// This class represents an analyzed expression in the program. These are
/// opaque objects that the client is not allowed to do much with directly.
///
@@ -2672,6 +2694,24 @@ template <> inline const SCEV *SCEVUseT<const SCEV *>::getCanonical() const {
return Base::getPointer()->getCanonical();
}
+template <typename SCEVPtrT>
+void SCEVUseT<SCEVPtrT>::print(raw_ostream &OS) const {
+ Base::getPointer()->print(OS);
+ SCEV::NoWrapFlags Flags = static_cast<SCEV::NoWrapFlags>(Base::getInt());
+ if (Flags & SCEV::FlagNUW)
+ OS << "(u nuw)";
+ if (Flags & SCEV::FlagNSW)
+ OS << "(u nsw)";
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+template <typename SCEVPtrT>
+LLVM_DUMP_METHOD void SCEVUseT<SCEVPtrT>::dump() const {
+ print(dbgs());
+ dbgs() << '\n';
+}
+#endif
+
} // end namespace llvm
#endif // LLVM_ANALYSIS_SCALAREVOLUTION_H
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
index 7958fed19b00c..2fc928d5955d1 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -677,50 +677,57 @@ template <typename SC, typename RetVal = void> struct SCEVUseVisitor {
RetVal visit(SCEVUse S) {
switch (S->getSCEVType()) {
case scConstant:
- return ((SC *)this)->visitConstant(SCEVUseT(cast<const SCEVConstant>(S)));
+ return ((SC *)this)
+ ->visitConstant(cast<SCEVUseT<const SCEVConstant *>>(S));
case scVScale:
- return ((SC *)this)->visitVScale(SCEVUseT(cast<const SCEVVScale>(S)));
+ return ((SC *)this)->visitVScale(cast<SCEVUseT<const SCEVVScale *>>(S));
case scPtrToAddr:
return ((SC *)this)
- ->visitPtrToAddrExpr(SCEVUseT(cast<const SCEVPtrToAddrExpr>(S)));
+ ->visitPtrToAddrExpr(cast<SCEVUseT<const SCEVPtrToAddrExpr *>>(S));
case scPtrToInt:
return ((SC *)this)
- ->visitPtrToIntExpr(SCEVUseT(cast<const SCEVPtrToIntExpr>(S)));
+ ->visitPtrToIntExpr(cast<SCEVUseT<const SCEVPtrToIntExpr *>>(S));
case scTruncate:
return ((SC *)this)
- ->visitTruncateExpr(SCEVUseT(cast<const SCEVTruncateExpr>(S)));
+ ->visitTruncateExpr(cast<SCEVUseT<const SCEVTruncateExpr *>>(S));
case scZeroExtend:
return ((SC *)this)
- ->visitZeroExtendExpr(SCEVUseT(cast<const SCEVZeroExtendExpr>(S)));
+ ->visitZeroExtendExpr(cast<SCEVUseT<const SCEVZeroExtendExpr *>>(S));
case scSignExtend:
return ((SC *)this)
- ->visitSignExtendExpr(SCEVUseT(cast<const SCEVSignExtendExpr>(S)));
+ ->visitSignExtendExpr(cast<SCEVUseT<const SCEVSignExtendExpr *>>(S));
case scAddExpr:
- return ((SC *)this)->visitAddExpr(SCEVUseT(cast<const SCEVAddExpr>(S)));
+ return ((SC *)this)->visitAddExpr(cast<SCEVUseT<const SCEVAddExpr *>>(S));
case scMulExpr:
- return ((SC *)this)->visitMulExpr(SCEVUseT(cast<const SCEVMulExpr>(S)));
+ return ((SC *)this)->visitMulExpr(cast<SCEVUseT<const SCEVMulExpr *>>(S));
case scUDivExpr:
- return ((SC *)this)->visitUDivExpr(SCEVUseT(cast<const SCEVUDivExpr>(S)));
+ return ((SC *)this)
+ ->visitUDivExpr(cast<SCEVUseT<const SCEVUDivExpr *>>(S));
case scAddRecExpr:
return ((SC *)this)
- ->visitAddRecExpr(SCEVUseT(cast<const SCEVAddRecExpr>(S)));
+ ->visitAddRecExpr(cast<SCEVUseT<const SCEVAddRecExpr *>>(S));
case scSMaxExpr:
- return ((SC *)this)->visitSMaxExpr(SCEVUseT(cast<const SCEVSMaxExpr>(S)));
+ return ((SC *)this)
+ ->visitSMaxExpr(cast<SCEVUseT<const SCEVSMaxExpr *>>(S));
case scUMaxExpr:
- return ((SC *)this)->visitUMaxExpr(SCEVUseT(cast<const SCEVUMaxExpr>(S)));
+ return ((SC *)this)
+ ->visitUMaxExpr(cast<SCEVUseT<const SCEVUMaxExpr *>>(S));
case scSMinExpr:
- return ((SC *)this)->visitSMinExpr(SCEVUseT(cast<const SCEVSMinExpr>(S)));
+ return ((SC *)this)
+ ->visitSMinExpr(cast<SCEVUseT<const SCEVSMinExpr *>>(S));
case scUMinExpr:
- return ((SC *)this)->visitUMinExpr(SCEVUseT(cast<const SCEVUMinExpr>(S)));
+ return ((SC *)this)
+ ->visitUMinExpr(cast<SCEVUseT<const SCEVUMinExpr *>>(S));
case scSequentialUMinExpr:
return ((SC *)this)
->visitSequentialUMinExpr(
- SCEVUseT(cast<const SCEVSequentialUMinExpr>(S)));
+ cast<SCEVUseT<const SCEVSequentialUMinExpr *>>(S));
case scUnknown:
- return ((SC *)this)->visitUnknown(SCEVUseT(cast<const SCEVUnknown>(S)));
+ return ((SC *)this)->visitUnknown(cast<SCEVUseT<const SCEVUnknown *>>(S));
case scCouldNotCompute:
return ((SC *)this)
- ->visitCouldNotCompute(SCEVUseT(cast<const SCEVCouldNotCompute>(S)));
+ ->visitCouldNotCompute(
+ cast<SCEVUseT<const SCEVCouldNotCompute *>>(S));
}
llvm_unreachable("Unknown SCEV kind!");
}
diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
index f62ce0b8c008b..42355f5841eab 100644
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -474,7 +474,7 @@ class SCEVExpander : public SCEVUseVisitor<SCEVExpander, Value *> {
/// DropPoisonGeneratingInsts is populated with instructions for which
/// poison-generating flags must be dropped if the value is reused.
Value *FindValueInExprValueMap(
- const SCEV *S, const Instruction *InsertPt,
+ SCEVUse S, const Instruction *InsertPt,
SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts);
LLVM_ABI Value *expand(SCEVUse S);
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 825633c7f83c2..0e999e41a9e3e 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -336,22 +336,6 @@ void SCEV::computeAndSetCanonical(ScalarEvolution &SE) {
}
}
-template <> void SCEVUseT<const SCEV *>::print(raw_ostream &OS) const {
- Base::getPointer()->print(OS);
- SCEV::NoWrapFlags Flags = static_cast<SCEV::NoWrapFlags>(Base::getInt());
- if (Flags & SCEV::FlagNUW)
- OS << "(u nuw)";
- if (Flags & SCEV::FlagNSW)
- OS << "(u nsw)";
-}
-
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-template <> LLVM_DUMP_METHOD void SCEVUseT<const SCEV *>::dump() const {
- print(dbgs());
- dbgs() << '\n';
-}
-#endif
-
//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 7c762a53cebdf..3509a204a3d4e 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -1567,7 +1567,7 @@ Value *SCEVExpander::expandCodeFor(SCEVUse SH, Type *Ty) {
}
Value *SCEVExpander::FindValueInExprValueMap(
- const SCEV *S, const Instruction *InsertPt,
+ SCEVUse S, const Instruction *InsertPt,
SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
// If the expansion is not in CanonicalMode, and the SCEV contains any
// sub scAddRecExpr type SCEV, it is required to expand the SCEV literally.
More information about the llvm-commits
mailing list