[llvm-branch-commits] [llvm] [SCEV] Rewrite to always create canonical SCEV. (PR #185042)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Mar 6 08:51:00 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: Florian Hahn (fhahn)
<details>
<summary>Changes</summary>
Compute the canonical SCEV on construction by taking the canonical
SCEVs of all operands and using them to construct the canonical SCEV.
Depends on https://github.com/llvm/llvm-project/pull/185040 (included in PR)
---
Full diff: https://github.com/llvm/llvm-project/pull/185042.diff
2 Files Affected:
- (modified) llvm/include/llvm/Analysis/ScalarEvolution.h (+37-4)
- (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+91)
``````````diff
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index c59a652509e27..297fcf03d600d 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -78,10 +78,14 @@ struct SCEVUse : PointerIntPair<const SCEV *, 2> {
void *getRawPointer() const { return getOpaqueValue(); }
+ bool isCanonical() const;
+
+ const SCEV *getCanonical() const;
+
unsigned getFlags() const { return getInt(); }
+ /// Two uses compare equal when they refer to the same canonical SCEV, so the
+ /// flag bits stored in the uses do not affect the result.
+ /// NOTE(review): the `const SCEV *` overload below still compares the raw
+ /// pointer (including flag bits); confirm this asymmetry is intended.
bool operator==(const SCEVUse &RHS) const {
- return getRawPointer() == RHS.getRawPointer();
+ return getCanonical() == RHS.getCanonical();
}
bool operator==(const SCEV *RHS) const { return getRawPointer() == RHS; }
@@ -120,14 +124,24 @@ template <> struct DenseMapInfo<SCEVUse> {
}
+ /// Hash the canonical SCEV so that keys which isEqual() treats as equal
+ /// (same canonical SCEV) always produce the same hash.
static unsigned getHashValue(SCEVUse U) {
- return hash_value(U.getRawPointer());
+ return hash_value(U.getCanonical());
}
static bool isEqual(const SCEVUse LHS, const SCEVUse RHS) {
- return LHS.getRawPointer() == RHS.getRawPointer();
+    // The empty and tombstone keys do not point at real SCEVs, so they must
+    // be compared by identity instead of being dereferenced by getCanonical().
+    // Take the sentinel values from getEmptyKey()/getTombstoneKey() so this
+    // check cannot drift out of sync with the key definitions.
+    void *L = LHS.getRawPointer();
+    void *R = RHS.getRawPointer();
+    void *Empty = getEmptyKey().getRawPointer();
+    void *Tombstone = getTombstoneKey().getRawPointer();
+    if (L == Empty || L == Tombstone || R == Empty || R == Tombstone)
+      return L == R;
+    // Real keys are equal when they share a canonical SCEV; this matches
+    // getHashValue(), which hashes the canonical SCEV.
+    return LHS.getCanonical() == RHS.getCanonical();
}
};
+/// A use is canonical when the SCEV it points to is its own canonical form.
+inline bool SCEVUse::isCanonical() const {
+  const SCEV *Pointee = getPointer();
+  return Pointee == getCanonical();
+}
+
template <> struct simplify_type<SCEVUse> {
using SimpleType = const SCEV *;
@@ -157,6 +171,11 @@ class SCEV : public FoldingSetNode {
/// miscellaneous information.
unsigned short SubclassData = 0;
+private:
+ /// Pointer to the canonical SCEV for this node. SCEVs that differ only in
+ /// no-wrap flags share the same canonical SCEV.
+ const SCEV *CanonicalSCEV;
+
public:
/// NoWrapFlags are bitfield indices into SubclassData.
///
@@ -204,7 +223,8 @@ class SCEV : public FoldingSetNode {
+ /// CanonicalSCEV starts out null; computeAndSetCanonical() fills it in once
+ /// the node has been created.
explicit SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
unsigned short ExpressionSize)
- : FastID(ID), SCEVType(SCEVTy), ExpressionSize(ExpressionSize) {}
+ : FastID(ID), SCEVType(SCEVTy), ExpressionSize(ExpressionSize),
+ CanonicalSCEV(nullptr) {}
SCEV(const SCEV &) = delete;
SCEV &operator=(const SCEV &) = delete;
@@ -247,6 +267,10 @@ class SCEV : public FoldingSetNode {
/// This method is used for debugging.
LLVM_ABI void dump() const;
+
+  /// Return the canonical SCEV for this node. SCEVs that differ only in
+  /// no-wrap flags share the same canonical SCEV.
+  ///
+  /// A node whose canonical form was never computed (CanonicalSCEV is still
+  /// null, e.g. one not created through the uniquing factories) is treated as
+  /// its own canonical form. Otherwise every such node would hash identically
+  /// and compare equal to the others through SCEVUse::operator==.
+  const SCEV *getCanonical() const {
+    return CanonicalSCEV ? CanonicalSCEV : this;
+  }
+
+  /// Compute and cache the canonical SCEV for this node, using \p SE to build
+  /// it from the canonical forms of the operands when those differ.
+  LLVM_ABI void computeAndSetCanonical(ScalarEvolution &SE);
};
// Specialize FoldingSetTrait for SCEV to avoid needing to compute
@@ -269,6 +293,11 @@ inline raw_ostream &operator<<(raw_ostream &OS, const SCEV &S) {
return OS;
}
+/// Print \p Ref to \p OS via SCEVUse::print and return the stream so that
+/// output operations can be chained.
+inline raw_ostream &operator<<(raw_ostream &OS, SCEVUse Ref) {
+  Ref.print(OS);
+  return OS;
+}
+
/// An object of this class is returned by queries that could not be answered.
/// For example, if you ask for the number of iterations of a linked-list
/// traversal loop, you will get one of these. None of the standard SCEV
@@ -2636,6 +2665,10 @@ template <> struct DenseMapInfo<ScalarEvolution::FoldID> {
}
};
+/// Return the canonical SCEV of the expression this use points to. The flag
+/// bits stored in the use itself are stripped by getPointer() and play no part.
+inline const SCEV *SCEVUse::getCanonical() const {
+  const SCEV *Pointee = getPointer();
+  return Pointee->getCanonical();
+}
+
} // end namespace llvm
#endif // LLVM_ANALYSIS_SCALAREVOLUTION_H
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 89dfb20bade47..43aca59938b9f 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -258,6 +258,80 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
// SCEV class definitions
//===----------------------------------------------------------------------===//
+/// Compute and cache the canonical form of this SCEV in CanonicalSCEV.
+///
+/// Constants, vscale and unknowns are their own canonical form. For any other
+/// expression, every operand is replaced by its canonical SCEV, which also
+/// drops the flag bits stored in the operand's SCEVUse. If that leaves the
+/// operand list untouched, this node is already canonical. Otherwise the
+/// ScalarEvolution factory matching this node's kind is called on the
+/// canonical operands, and its result is recorded as the canonical SCEV.
+///
+/// The factory calls may create new SCEVs, so the non-leaf creation sites
+/// invoke this only after UniqueSCEVs.InsertNode(). Creating nodes between
+/// FindNodeOrInsertPos() and InsertNode() could invalidate the insert
+/// position.
+void SCEV::computeAndSetCanonical(ScalarEvolution &SE) {
+  switch (getSCEVType()) {
+  case scConstant:
+  case scVScale:
+  case scUnknown:
+    // Leaf expressions: nothing to canonicalize.
+    CanonicalSCEV = this;
+    return;
+  default:
+    break;
+  }
+
+  bool Changed = false;
+  SmallVector<SCEVUse, 4> CanonOps;
+  for (SCEVUse Op : operands()) {
+    SCEVUse CanonOp(Op->getCanonical());
+    // Compare the raw (pointer + flag bits) values, not just the pointees.
+    // An operand that already points at a canonical SCEV but carries flag
+    // bits in its use still differs from the flag-free canonical operand.
+    // Treating it as unchanged would leave a node with flagged operands
+    // claiming to be its own canonical form.
+    Changed |= CanonOp.getRawPointer() != Op.getRawPointer();
+    CanonOps.push_back(CanonOp);
+  }
+
+  if (!Changed) {
+    CanonicalSCEV = this;
+    return;
+  }
+
+  // NOTE(review): at the add/mul creation sites this runs before
+  // setNoWrapFlags(), so these are presumably still the flags the node was
+  // constructed with. Confirm that forwarding them to the canonical node is
+  // intended.
+  auto *NAry = dyn_cast<SCEVNAryExpr>(this);
+  SCEV::NoWrapFlags Flags = NAry ? NAry->getNoWrapFlags() : SCEV::FlagAnyWrap;
+  switch (getSCEVType()) {
+  case scPtrToAddr:
+    CanonicalSCEV = SE.getPtrToAddrExpr(CanonOps[0]);
+    return;
+  case scPtrToInt:
+    CanonicalSCEV = SE.getPtrToIntExpr(CanonOps[0], getType());
+    return;
+  case scTruncate:
+    CanonicalSCEV = SE.getTruncateExpr(CanonOps[0], getType());
+    return;
+  case scZeroExtend:
+    CanonicalSCEV = SE.getZeroExtendExpr(CanonOps[0], getType());
+    return;
+  case scSignExtend:
+    CanonicalSCEV = SE.getSignExtendExpr(CanonOps[0], getType());
+    return;
+  case scUDivExpr:
+    CanonicalSCEV = SE.getUDivExpr(CanonOps[0], CanonOps[1]);
+    return;
+  case scAddExpr:
+    CanonicalSCEV = SE.getAddExpr(CanonOps, Flags);
+    return;
+  case scMulExpr:
+    CanonicalSCEV = SE.getMulExpr(CanonOps, Flags);
+    return;
+  case scAddRecExpr:
+    CanonicalSCEV = SE.getAddRecExpr(
+        CanonOps, cast<SCEVAddRecExpr>(this)->getLoop(), Flags);
+    return;
+  case scSMaxExpr:
+    CanonicalSCEV = SE.getSMaxExpr(CanonOps);
+    return;
+  case scUMaxExpr:
+    CanonicalSCEV = SE.getUMaxExpr(CanonOps);
+    return;
+  case scSMinExpr:
+    CanonicalSCEV = SE.getSMinExpr(CanonOps);
+    return;
+  case scUMinExpr:
+    CanonicalSCEV = SE.getUMinExpr(CanonOps);
+    return;
+  case scSequentialUMinExpr:
+    CanonicalSCEV = SE.getUMinExpr(CanonOps, /*Sequential=*/true);
+    return;
+  default:
+    llvm_unreachable("Unknown SCEV type");
+  }
+}
+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD void SCEVUse::dump() const {
print(dbgs());
@@ -494,6 +568,7 @@ const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
void *IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
+ S->computeAndSetCanonical(*this);
UniqueSCEVs.InsertNode(S, IP);
return S;
}
@@ -519,6 +594,7 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) {
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
return S;
SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
+ S->computeAndSetCanonical(*this);
UniqueSCEVs.InsertNode(S, IP);
return S;
}
@@ -1133,6 +1209,7 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op) {
SCEV *S = new (SCEVAllocator)
SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
registerUser(S, U);
return static_cast<const SCEV *>(S);
});
@@ -1166,6 +1243,7 @@ const SCEV *ScalarEvolution::getPtrToAddrExpr(const SCEV *Op) {
SCEV *S = new (SCEVAllocator)
SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
registerUser(S, U);
return static_cast<const SCEV *>(S);
});
@@ -1222,6 +1300,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
SCEV *S =
new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
registerUser(S, Op);
return S;
}
@@ -1275,6 +1354,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
registerUser(S, Op);
return S;
}
@@ -1645,6 +1725,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
registerUser(S, Op);
return S;
}
@@ -1929,6 +2010,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
registerUser(S, Op);
return S;
}
@@ -1985,6 +2067,7 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
registerUser(S, Op);
return S;
}
@@ -2191,6 +2274,7 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
Op, Ty);
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
registerUser(S, Op);
return S;
}
@@ -3035,6 +3119,7 @@ const SCEV *ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
S = new (SCEVAllocator)
SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
registerUser(S, Ops);
}
S->setNoWrapFlags(Flags);
@@ -3058,6 +3143,7 @@ const SCEV *ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
S = new (SCEVAllocator)
SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
LoopUsers[L].push_back(S);
registerUser(S, Ops);
}
@@ -3080,6 +3166,7 @@ const SCEV *ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
O, Ops.size());
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
registerUser(S, Ops);
}
S->setNoWrapFlags(Flags);
@@ -3647,6 +3734,7 @@ const SCEV *ScalarEvolution::getUDivExpr(SCEVUse LHS, SCEVUse RHS) {
SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
LHS, RHS);
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
registerUser(S, ArrayRef<SCEVUse>({LHS, RHS}));
return S;
}
@@ -4047,6 +4135,7 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
registerUser(S, Ops);
return S;
}
@@ -4437,6 +4526,7 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
registerUser(S, Ops);
return S;
}
@@ -4527,6 +4617,7 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) {
FirstUnknown);
FirstUnknown = cast<SCEVUnknown>(S);
UniqueSCEVs.InsertNode(S, IP);
+ S->computeAndSetCanonical(*this);
return S;
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/185042
More information about the llvm-branch-commits mailing list