[llvm] [SCEV] Add canonical SCEV pointer and construct canonical SCEVs (NFC) (PR #188858)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 31 09:49:33 PDT 2026


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/188858

>From 36abe53f5e88224f2969093d4085832cb6325bec Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 19 Feb 2026 18:02:05 +0000
Subject: [PATCH 1/2] [SCEV] Add canonical SCEV pointer and construct canonical
 SCEVs (NFC)

Add a canonical SCEV pointer to SCEV, to support comparing SCEVs for
equality, even with different use-specific flags.

Currently this should be NFC, as nothing yet sets use flags.

Compile-time impact: https://llvm-compile-time-tracker.com/compare.php?from=13f1fd006243f756417c3ae992342c0674e3f04e&to=e7f46bcf8bef62c619380fbcccbe6073300b69fe&stat=instructions:u

Between +0.03% and +0.05% on stage1 configs
-0.03% (surprisingly) on stage2-O3
---
 llvm/include/llvm/Analysis/ScalarEvolution.h | 52 ++++++++++-
 llvm/lib/Analysis/ScalarEvolution.cpp        | 95 ++++++++++++++++++++
 2 files changed, 143 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 9a97ca218b3d3..69babad03a575 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -80,13 +80,20 @@ struct SCEVUse : PointerIntPair<const SCEV *, 2> {
 
   void *getRawPointer() const { return getOpaqueValue(); }
 
+  /// Returns true if the SCEVUse is canonical, i.e. no SCEVUse flags set in any
+  /// operands.
+  bool isCanonical() const;
+
+  /// Return the canonical SCEV for this SCEVUse.
+  const SCEV *getCanonical() const;
+
   unsigned getFlags() const { return getInt(); }
 
   bool operator==(const SCEVUse &RHS) const {
-    return getRawPointer() == RHS.getRawPointer();
+    return getCanonical() == RHS.getCanonical();
   }
 
-  bool operator==(const SCEV *RHS) const { return getRawPointer() == RHS; }
+  inline bool operator==(const SCEV *RHS) const;
 
   /// Print out the internal representation of this scalar to the specified
   /// stream.  This should really only be used for debugging purposes.
@@ -122,14 +129,24 @@ template <> struct DenseMapInfo<SCEVUse> {
   }
 
   static unsigned getHashValue(SCEVUse U) {
-    return hash_value(U.getRawPointer());
+    return hash_value(U.getCanonical());
   }
 
   static bool isEqual(const SCEVUse LHS, const SCEVUse RHS) {
-    return LHS.getRawPointer() == RHS.getRawPointer();
+    void *L = LHS.getRawPointer();
+    void *R = RHS.getRawPointer();
+    void *Empty = getEmptyKey().getRawPointer();
+    void *Tombstone = getTombstoneKey().getRawPointer();
+    if (L == Empty || L == Tombstone || R == Empty || R == Tombstone)
+      return L == R;
+    return LHS.getCanonical() == RHS.getCanonical();
   }
 };
 
+inline bool SCEVUse::isCanonical() const {
+  return getCanonical() == getPointer() && getFlags() == 0;
+}
+
 template <> struct simplify_type<SCEVUse> {
   using SimpleType = const SCEV *;
 
@@ -159,6 +176,10 @@ class SCEV : public FoldingSetNode {
   /// miscellaneous information.
   unsigned short SubclassData = 0;
 
+  /// Pointer to the canonical version of the SCEV, i.e. one where all operands
+  /// have no SCEVUse flags.
+  const SCEV *CanonicalSCEV = nullptr;
+
 public:
   /// NoWrapFlags are bitfield indices into SubclassData.
   ///
@@ -249,6 +270,16 @@ class SCEV : public FoldingSetNode {
 
   /// This method is used for debugging.
   LLVM_ABI void dump() const;
+
+  /// Compute and set the canonical SCEV, by constructing a SCEV with the same
+  /// operands, but all SCEVUse flags dropped.
+  LLVM_ABI void computeAndSetCanonical(ScalarEvolution &SE);
+
+  /// Return the canonical SCEV.
+  LLVM_ABI const SCEV *getCanonical() const {
+    assert(CanonicalSCEV && "canonical SCEV not yet computed");
+    return CanonicalSCEV;
+  }
 };
 
 // Specialize FoldingSetTrait for SCEV to avoid needing to compute
@@ -271,6 +302,11 @@ inline raw_ostream &operator<<(raw_ostream &OS, const SCEV &S) {
   return OS;
 }
 
+inline raw_ostream &operator<<(raw_ostream &OS, SCEVUse U) {
+  U.print(OS);
+  return OS;
+}
+
 /// An object of this class is returned by queries that could not be answered.
 /// For example, if you ask for the number of iterations of a linked-list
 /// traversal loop, you will get one of these.  None of the standard SCEV
@@ -2629,6 +2665,14 @@ template <> struct DenseMapInfo<ScalarEvolution::FoldID> {
   }
 };
 
+inline const SCEV *SCEVUse::getCanonical() const {
+  return getPointer()->getCanonical();
+}
+
+inline bool SCEVUse::operator==(const SCEV *RHS) const {
+  return getCanonical() == RHS->getCanonical();
+}
+
 } // end namespace llvm
 
 #endif // LLVM_ANALYSIS_SCALAREVOLUTION_H
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index d99488121baba..73dc298e66e6f 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -258,6 +258,84 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
 //                           SCEV class definitions
 //===----------------------------------------------------------------------===//
 
+void SCEV::computeAndSetCanonical(ScalarEvolution &SE) {
+  // Leaf nodes are always their own canonical.
+  switch (getSCEVType()) {
+  case scConstant:
+  case scVScale:
+  case scUnknown:
+    CanonicalSCEV = this;
+    return;
+  default:
+    break;
+  }
+
+  // For all other expressions, check whether any immediate operand has a
+  // different canonical. Since operands are always created before their parent,
+  // their canonical pointers are already set — no recursion needed.
+  bool Changed = false;
+  SmallVector<SCEVUse, 4> CanonOps;
+  for (SCEVUse Op : operands()) {
+    CanonOps.push_back(Op->getCanonical());
+    Changed |= CanonOps.back() != Op.getPointer();
+  }
+
+  if (!Changed) {
+    CanonicalSCEV = this;
+    return;
+  }
+
+  auto *NAry = dyn_cast<SCEVNAryExpr>(this);
+  SCEV::NoWrapFlags Flags = NAry ? NAry->getNoWrapFlags() : SCEV::FlagAnyWrap;
+  switch (getSCEVType()) {
+  case scPtrToAddr:
+    CanonicalSCEV = SE.getPtrToAddrExpr(CanonOps[0]);
+    return;
+  case scPtrToInt:
+    CanonicalSCEV = SE.getPtrToIntExpr(CanonOps[0], getType());
+    return;
+  case scTruncate:
+    CanonicalSCEV = SE.getTruncateExpr(CanonOps[0], getType());
+    return;
+  case scZeroExtend:
+    CanonicalSCEV = SE.getZeroExtendExpr(CanonOps[0], getType());
+    return;
+  case scSignExtend:
+    CanonicalSCEV = SE.getSignExtendExpr(CanonOps[0], getType());
+    return;
+  case scUDivExpr:
+    CanonicalSCEV = SE.getUDivExpr(CanonOps[0], CanonOps[1]);
+    return;
+  case scAddExpr:
+    CanonicalSCEV = SE.getAddExpr(CanonOps, Flags);
+    return;
+  case scMulExpr:
+    CanonicalSCEV = SE.getMulExpr(CanonOps, Flags);
+    return;
+  case scAddRecExpr:
+    CanonicalSCEV = SE.getAddRecExpr(
+        CanonOps, cast<SCEVAddRecExpr>(this)->getLoop(), Flags);
+    return;
+  case scSMaxExpr:
+    CanonicalSCEV = SE.getSMaxExpr(CanonOps);
+    return;
+  case scUMaxExpr:
+    CanonicalSCEV = SE.getUMaxExpr(CanonOps);
+    return;
+  case scSMinExpr:
+    CanonicalSCEV = SE.getSMinExpr(CanonOps);
+    return;
+  case scUMinExpr:
+    CanonicalSCEV = SE.getUMinExpr(CanonOps);
+    return;
+  case scSequentialUMinExpr:
+    CanonicalSCEV = SE.getUMinExpr(CanonOps, /*Sequential=*/true);
+    return;
+  default:
+    llvm_unreachable("Unknown SCEV type");
+  }
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 LLVM_DUMP_METHOD void SCEVUse::dump() const {
   print(dbgs());
@@ -495,6 +573,7 @@ const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
   if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
   SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   return S;
 }
 
@@ -520,6 +599,7 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) {
     return S;
   SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   return S;
 }
 
@@ -1133,6 +1213,7 @@ const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op) {
         SCEV *S = new (SCEVAllocator)
             SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
         UniqueSCEVs.InsertNode(S, IP);
+        S->computeAndSetCanonical(*this);
         registerUser(S, U);
         return static_cast<const SCEV *>(S);
       });
@@ -1166,6 +1247,7 @@ const SCEV *ScalarEvolution::getPtrToAddrExpr(const SCEV *Op) {
         SCEV *S = new (SCEVAllocator)
             SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
         UniqueSCEVs.InsertNode(S, IP);
+        S->computeAndSetCanonical(*this);
         registerUser(S, U);
         return static_cast<const SCEV *>(S);
       });
@@ -1222,6 +1304,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
     SCEV *S =
         new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
     UniqueSCEVs.InsertNode(S, IP);
+    S->computeAndSetCanonical(*this);
     registerUser(S, Op);
     return S;
   }
@@ -1275,6 +1358,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
   SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                  Op, Ty);
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   registerUser(S, Op);
   return S;
 }
@@ -1645,6 +1729,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
     SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                      Op, Ty);
     UniqueSCEVs.InsertNode(S, IP);
+    S->computeAndSetCanonical(*this);
     registerUser(S, Op);
     return S;
   }
@@ -1929,6 +2014,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
   SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                    Op, Ty);
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   registerUser(S, Op);
   return S;
 }
@@ -1985,6 +2071,7 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
     SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                      Op, Ty);
     UniqueSCEVs.InsertNode(S, IP);
+    S->computeAndSetCanonical(*this);
     registerUser(S, Op);
     return S;
   }
@@ -2191,6 +2278,7 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
   SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                    Op, Ty);
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   registerUser(S, Op);
   return S;
 }
@@ -3035,6 +3123,7 @@ const SCEV *ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
     S = new (SCEVAllocator)
         SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
     UniqueSCEVs.InsertNode(S, IP);
+    S->computeAndSetCanonical(*this);
     registerUser(S, Ops);
   }
   S->setNoWrapFlags(Flags);
@@ -3058,6 +3147,7 @@ const SCEV *ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
     S = new (SCEVAllocator)
         SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
     UniqueSCEVs.InsertNode(S, IP);
+    S->computeAndSetCanonical(*this);
     LoopUsers[L].push_back(S);
     registerUser(S, Ops);
   }
@@ -3080,6 +3170,7 @@ const SCEV *ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
     S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
                                         O, Ops.size());
     UniqueSCEVs.InsertNode(S, IP);
+    S->computeAndSetCanonical(*this);
     registerUser(S, Ops);
   }
   S->setNoWrapFlags(Flags);
@@ -3647,6 +3738,7 @@ const SCEV *ScalarEvolution::getUDivExpr(SCEVUse LHS, SCEVUse RHS) {
   SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
                                              LHS, RHS);
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   registerUser(S, ArrayRef<SCEVUse>({LHS, RHS}));
   return S;
 }
@@ -4047,6 +4139,7 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
       SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
 
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   registerUser(S, Ops);
   return S;
 }
@@ -4437,6 +4530,7 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
       SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
 
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   registerUser(S, Ops);
   return S;
 }
@@ -4527,6 +4621,7 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) {
                                             FirstUnknown);
   FirstUnknown = cast<SCEVUnknown>(S);
   UniqueSCEVs.InsertNode(S, IP);
+  S->computeAndSetCanonical(*this);
   return S;
 }
 

>From b62ffc9a8eeb17c9d7dd52c18de1df6016ddd5e3 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 31 Mar 2026 11:35:53 +0100
Subject: [PATCH 2/2] !fixup restore raw pointer compares

---
 llvm/include/llvm/Analysis/ScalarEvolution.h | 18 ++++--------------
 1 file changed, 4 insertions(+), 14 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 69babad03a575..30d02d901f723 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -90,10 +90,10 @@ struct SCEVUse : PointerIntPair<const SCEV *, 2> {
   unsigned getFlags() const { return getInt(); }
 
   bool operator==(const SCEVUse &RHS) const {
-    return getCanonical() == RHS.getCanonical();
+    return getRawPointer() == RHS.getRawPointer();
   }
 
-  inline bool operator==(const SCEV *RHS) const;
+  bool operator==(const SCEV *RHS) const { return getRawPointer() == RHS; }
 
   /// Print out the internal representation of this scalar to the specified
   /// stream.  This should really only be used for debugging purposes.
@@ -129,17 +129,11 @@ template <> struct DenseMapInfo<SCEVUse> {
   }
 
   static unsigned getHashValue(SCEVUse U) {
-    return hash_value(U.getCanonical());
+    return hash_value(U.getRawPointer());
   }
 
   static bool isEqual(const SCEVUse LHS, const SCEVUse RHS) {
-    void *L = LHS.getRawPointer();
-    void *R = RHS.getRawPointer();
-    void *Empty = getEmptyKey().getRawPointer();
-    void *Tombstone = getTombstoneKey().getRawPointer();
-    if (L == Empty || L == Tombstone || R == Empty || R == Tombstone)
-      return L == R;
-    return LHS.getCanonical() == RHS.getCanonical();
+    return LHS.getRawPointer() == RHS.getRawPointer();
   }
 };
 
@@ -2669,10 +2663,6 @@ inline const SCEV *SCEVUse::getCanonical() const {
   return getPointer()->getCanonical();
 }
 
-inline bool SCEVUse::operator==(const SCEV *RHS) const {
-  return getCanonical() == RHS->getCanonical();
-}
-
 } // end namespace llvm
 
 #endif // LLVM_ANALYSIS_SCALAREVOLUTION_H



More information about the llvm-commits mailing list