[llvm-branch-commits] [llvm] [SCEV] Rewrite to always create canonical SCEV. (PR #185042)

Florian Hahn via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Mar 6 08:50:23 PST 2026


https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/185042

Compute the canonical SCEV on construction by getting the canonical
SCEVs of all operands and using them to construct the canonical SCEV.

Depends on https://github.com/llvm/llvm-project/pull/185040

>From a62672ccab0e0614c323c303e133e182ef21c1d1 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 19 Feb 2026 18:02:05 +0000
Subject: [PATCH 1/2] [SCEV] Add canonical SCEV pointer and use for SCEVUse
 cmps (NFC)

Add a canonical SCEV pointer, to be used in combination with SCEVUse.
For now, SCEVUses with different flags (or whose wrapped SCEVs differ
due to different flags) are considered equal if they have the same
canonical SCEV.
---
 llvm/include/llvm/Analysis/ScalarEvolution.h | 39 ++++++++++++++++++--
 1 file changed, 35 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index c59a652509e27..77db49d44583a 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -78,10 +78,14 @@ struct SCEVUse : PointerIntPair<const SCEV *, 2> {
 
   void *getRawPointer() const { return getOpaqueValue(); }
 
+  bool isCanonical() const;
+
+  const SCEV *getCanonical() const;
+
   unsigned getFlags() const { return getInt(); }
 
   bool operator==(const SCEVUse &RHS) const {
-    return getRawPointer() == RHS.getRawPointer();
+    return getCanonical() == RHS.getCanonical();
   }
 
   bool operator==(const SCEV *RHS) const { return getRawPointer() == RHS; }
@@ -120,14 +124,24 @@ template <> struct DenseMapInfo<SCEVUse> {
   }
 
   static unsigned getHashValue(SCEVUse U) {
-    return hash_value(U.getRawPointer());
+    return hash_value(U.getCanonical());
   }
 
   static bool isEqual(const SCEVUse LHS, const SCEVUse RHS) {
-    return LHS.getRawPointer() == RHS.getRawPointer();
+    void *L = LHS.getRawPointer();
+    void *R = RHS.getRawPointer();
+    if (L == reinterpret_cast<void *>(-1) ||
+        L == reinterpret_cast<void *>(-2) ||
+        R == reinterpret_cast<void *>(-1) || R == reinterpret_cast<void *>(-2))
+      return L == R;
+    return LHS.getCanonical() == RHS.getCanonical();
   }
 };
 
+inline bool SCEVUse::isCanonical() const {
+  return getCanonical() == getPointer();
+}
+
 template <> struct simplify_type<SCEVUse> {
   using SimpleType = const SCEV *;
 
@@ -157,6 +171,11 @@ class SCEV : public FoldingSetNode {
   /// miscellaneous information.
   unsigned short SubclassData = 0;
 
+private:
+  /// Pointer to the canonical SCEV for this node. SCEVs that differ only in
+  /// no-wrap flags share the same canonical SCEV.
+  const SCEV *CanonicalSCEV;
+
 public:
   /// NoWrapFlags are bitfield indices into SubclassData.
   ///
@@ -204,7 +223,8 @@ class SCEV : public FoldingSetNode {
 
   explicit SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                 unsigned short ExpressionSize)
-      : FastID(ID), SCEVType(SCEVTy), ExpressionSize(ExpressionSize) {}
+      : FastID(ID), SCEVType(SCEVTy), ExpressionSize(ExpressionSize),
+        CanonicalSCEV(this) {}
   SCEV(const SCEV &) = delete;
   SCEV &operator=(const SCEV &) = delete;
 
@@ -247,6 +267,8 @@ class SCEV : public FoldingSetNode {
 
   /// This method is used for debugging.
   LLVM_ABI void dump() const;
+
+  const SCEV *getCanonical() const { return CanonicalSCEV; }
 };
 
 // Specialize FoldingSetTrait for SCEV to avoid needing to compute
@@ -269,6 +291,11 @@ inline raw_ostream &operator<<(raw_ostream &OS, const SCEV &S) {
   return OS;
 }
 
+inline raw_ostream &operator<<(raw_ostream &OS, SCEVUse U) {
+  U.print(OS);
+  return OS;
+}
+
 /// An object of this class is returned by queries that could not be answered.
 /// For example, if you ask for the number of iterations of a linked-list
 /// traversal loop, you will get one of these.  None of the standard SCEV
@@ -2636,6 +2663,10 @@ template <> struct DenseMapInfo<ScalarEvolution::FoldID> {
   }
 };
 
+inline const SCEV *SCEVUse::getCanonical() const {
+  return getPointer()->getCanonical();
+}
+
 } // end namespace llvm
 
 #endif // LLVM_ANALYSIS_SCALAREVOLUTION_H

>From 41750b850b0de073886cc8a8eb02b18469eceb06 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 19 Feb 2026 19:17:47 +0000
Subject: [PATCH 2/2] [SCEV] Rewrite to always create canonical SCEV.

Compute the canonical SCEV on construction by getting the canonical
SCEVs of all operands and using them to construct the canonical SCEV.
---
 llvm/include/llvm/Analysis/ScalarEvolution.h |  4 +-
 llvm/lib/Analysis/ScalarEvolution.cpp        | 91 ++++++++++++++++++++
 2 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 77db49d44583a..297fcf03d600d 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -224,7 +224,7 @@ class SCEV : public FoldingSetNode {
   explicit SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                 unsigned short ExpressionSize)
       : FastID(ID), SCEVType(SCEVTy), ExpressionSize(ExpressionSize),
-        CanonicalSCEV(this) {}
+        CanonicalSCEV(nullptr) {}
   SCEV(const SCEV &) = delete;
   SCEV &operator=(const SCEV &) = delete;
 
@@ -269,6 +269,8 @@ class SCEV : public FoldingSetNode {
   LLVM_ABI void dump() const;
 
   const SCEV *getCanonical() const { return CanonicalSCEV; }
+
+  LLVM_ABI void computeAndSetCanonical(ScalarEvolution &SE);
 };
 
 // Specialize FoldingSetTrait for SCEV to avoid needing to compute
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 89dfb20bade47..43aca59938b9f 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -258,6 +258,80 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
 //                           SCEV class definitions
 //===----------------------------------------------------------------------===//
 
+void SCEV::computeAndSetCanonical(ScalarEvolution &SE) {
+  switch (getSCEVType()) {
+  case scConstant:
+  case scVScale:
+  case scUnknown:
+    CanonicalSCEV = this;
+    return;
+  default:
+    break;
+  }
+
+  bool Changed = false;
+  SmallVector<SCEVUse, 4> CanonOps;
+  for (SCEVUse Op : operands()) {
+    CanonOps.push_back(Op->getCanonical());
+    Changed |= CanonOps.back() != Op.getPointer();
+  }
+
+  if (!Changed) {
+    CanonicalSCEV = this;
+    return;
+  }
+
+  auto *NAry = dyn_cast<SCEVNAryExpr>(this);
+  SCEV::NoWrapFlags Flags = NAry ? NAry->getNoWrapFlags() : SCEV::FlagAnyWrap;
+  switch (getSCEVType()) {
+  case scPtrToAddr:
+    CanonicalSCEV = SE.getPtrToAddrExpr(CanonOps[0]);
+    return;
+  case scPtrToInt:
+    CanonicalSCEV = SE.getPtrToIntExpr(CanonOps[0], getType());
+    return;
+  case scTruncate:
+    CanonicalSCEV = SE.getTruncateExpr(CanonOps[0], getType());
+    return;
+  case scZeroExtend:
+    CanonicalSCEV = SE.getZeroExtendExpr(CanonOps[0], getType());
+    return;
+  case scSignExtend:
+    CanonicalSCEV = SE.getSignExtendExpr(CanonOps[0], getType());
+    return;
+  case scUDivExpr:
+    CanonicalSCEV = SE.getUDivExpr(CanonOps[0], CanonOps[1]);
+    return;
+  case scAddExpr:
+    CanonicalSCEV = SE.getAddExpr(CanonOps, Flags);
+    return;
+  case scMulExpr:
+    CanonicalSCEV = SE.getMulExpr(CanonOps, Flags);
+    return;
+  case scAddRecExpr:
+    CanonicalSCEV = SE.getAddRecExpr(
+        CanonOps, cast<SCEVAddRecExpr>(this)->getLoop(), Flags);
+    return;
+  case scSMaxExpr:
+    CanonicalSCEV = SE.getSMaxExpr(CanonOps);
+    return;
+  case scUMaxExpr:
+    CanonicalSCEV = SE.getUMaxExpr(CanonOps);
+    return;
+  case scSMinExpr:
+    CanonicalSCEV = SE.getSMinExpr(CanonOps);
+    return;
+  case scUMinExpr:
+    CanonicalSCEV = SE.getUMinExpr(CanonOps);
+    return;
+  case scSequentialUMinExpr:
+    CanonicalSCEV = SE.getUMinExpr(CanonOps, /*Sequential=*/true);
+    return;
+  default:
+    llvm_unreachable("Unknown SCEV type");
+  }
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 LLVM_DUMP_METHOD void SCEVUse::dump() const {
   print(dbgs());
@@ -494,6 +568,7 @@ const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
   void *IP = nullptr;
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
   SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
+  S->computeAndSetCanonical(*this);
   UniqueSCEVs.InsertNode(S, IP);
   return S;
 }
@@ -519,6 +594,7 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) {
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
     return S;
   SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
+  S->computeAndSetCanonical(*this);
   UniqueSCEVs.InsertNode(S, IP);
   return S;
 }
@@ -1133,6 +1209,7 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op) {
         SCEV *S = new (SCEVAllocator)
             SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
         UniqueSCEVs.InsertNode(S, IP);
+        S->computeAndSetCanonical(*this);
         registerUser(S, U);
         return static_cast<const SCEV *>(S);
       });
@@ -1166,6 +1243,7 @@ const SCEV *ScalarEvolution::getPtrToAddrExpr(const SCEV *Op) {
         SCEV *S = new (SCEVAllocator)
             SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
         UniqueSCEVs.InsertNode(S, IP);
+        S->computeAndSetCanonical(*this);
         registerUser(S, U);
         return static_cast<const SCEV *>(S);
       });
@@ -1222,6 +1300,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
     SCEV *S =
         new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
     UniqueSCEVs.InsertNode(S, IP);
+    S->computeAndSetCanonical(*this);
     registerUser(S, Op);
     return S;
   }
@@ -1275,6 +1354,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
   SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                  Op, Ty);
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   registerUser(S, Op);
   return S;
 }
@@ -1645,6 +1725,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
     SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                      Op, Ty);
     UniqueSCEVs.InsertNode(S, IP);
+    S->computeAndSetCanonical(*this);
     registerUser(S, Op);
     return S;
   }
@@ -1929,6 +2010,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
   SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                    Op, Ty);
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   registerUser(S, Op);
   return S;
 }
@@ -1985,6 +2067,7 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
     SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                      Op, Ty);
     UniqueSCEVs.InsertNode(S, IP);
+    S->computeAndSetCanonical(*this);
     registerUser(S, Op);
     return S;
   }
@@ -2191,6 +2274,7 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
   SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                    Op, Ty);
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   registerUser(S, Op);
   return S;
 }
@@ -3035,6 +3119,7 @@ const SCEV *ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
     S = new (SCEVAllocator)
         SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
     UniqueSCEVs.InsertNode(S, IP);
+    S->computeAndSetCanonical(*this);
     registerUser(S, Ops);
   }
   S->setNoWrapFlags(Flags);
@@ -3058,6 +3143,7 @@ const SCEV *ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
     S = new (SCEVAllocator)
         SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
     UniqueSCEVs.InsertNode(S, IP);
+    S->computeAndSetCanonical(*this);
     LoopUsers[L].push_back(S);
     registerUser(S, Ops);
   }
@@ -3080,6 +3166,7 @@ const SCEV *ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
     S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
                                         O, Ops.size());
     UniqueSCEVs.InsertNode(S, IP);
+    S->computeAndSetCanonical(*this);
     registerUser(S, Ops);
   }
   S->setNoWrapFlags(Flags);
@@ -3647,6 +3734,7 @@ const SCEV *ScalarEvolution::getUDivExpr(SCEVUse LHS, SCEVUse RHS) {
   SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
                                              LHS, RHS);
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   registerUser(S, ArrayRef<SCEVUse>({LHS, RHS}));
   return S;
 }
@@ -4047,6 +4135,7 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
       SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
 
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   registerUser(S, Ops);
   return S;
 }
@@ -4437,6 +4526,7 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
       SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
 
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   registerUser(S, Ops);
   return S;
 }
@@ -4527,6 +4617,7 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) {
                                             FirstUnknown);
   FirstUnknown = cast<SCEVUnknown>(S);
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   return S;
 }
 



More information about the llvm-branch-commits mailing list