[llvm] [polly] [SCEV] Move NoWrapFlags definition outside SCEV scope, use for SCEVUse. (PR #190199)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 3 09:16:45 PDT 2026


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/190199

>From 8b1dbadc4de43e4039c5e579e46fad7655f27425 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 25 Mar 2026 11:13:15 +0000
Subject: [PATCH 1/3] [SCEV] Move NoWrapFlags definition outside SCEV scope,
 use for SCEVUse.

The patch moves the NoWrapFlags definition out of SCEV's scope so the
flags can be re-used for SCEVUse. SCEVUse gets an additional
getNoWrapFlags helper that returns the union of the underlying
expression's SCEV flags and the use-specific flags.

SCEVExpander has been updated to use this new helper.

To avoid changes elsewhere, the original names are re-exposed in SCEV
via a type alias and static constexpr members. Not sure if there's a
nicer way; one alternative would be to define the enum in a struct and
have SCEV inherit from it.

The patch also clarifies that the SCEVUse flags encode only NUW/NSW, and
makes getInt, setInt and setPointer private to avoid potential misuse.
---
 llvm/include/llvm/Analysis/ScalarEvolution.h  | 124 +++++++++++-------
 .../Analysis/ScalarEvolutionExpressions.h     |  14 ++
 .../Utils/ScalarEvolutionExpander.cpp         |  18 +--
 3 files changed, 99 insertions(+), 57 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 08c7e43e708b8..d529e29d2c096 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -66,6 +66,52 @@ enum SCEVTypes : unsigned short;
 
 LLVM_ABI extern bool VerifySCEV;
 
+/// NoWrapFlags are bit masks stored in SCEV's SubclassData.
+///
+/// Add and Mul expressions may have no-unsigned-wrap <NUW> or
+/// no-signed-wrap <NSW> properties, which are derived from the IR
+/// operator. NSW is a misnomer that we use to mean no signed overflow or
+/// underflow.
+///
+/// AddRec expressions may have a no-self-wraparound <NW> property if, in
+/// the integer domain, abs(step) * max-iteration(loop) <=
+/// unsigned-max(bitwidth).  This means that the recurrence will never reach
+/// its start value if the step is non-zero.  Computing the same value on
+/// each iteration is not considered wrapping, and recurrences with step = 0
+/// are trivially <NW>.  <NW> is independent of the sign of step and the
+/// value the add recurrence starts with.
+///
+/// Note that NUW and NSW are also valid properties of a recurrence, and
+/// either implies NW. For convenience, NW will be set for a recurrence
+/// whenever either NUW or NSW are set.
+///
+/// We require that the flag on a SCEV apply to the entire scope in which
+/// that SCEV is defined.  A SCEV's scope is the set of locations dominated by
+/// a defining location, which is in turn described by the following rules:
+/// * A SCEVUnknown is at the point of definition of the Value.
+/// * A SCEVConstant is defined at all points.
+/// * A SCEVAddRec is defined starting with the header of the associated
+///   loop.
+/// * All other SCEVs are defined at the earliest point all operands are
+///   defined.
+///
+/// The above rules describe a maximally hoisted form (without regard to
+/// potential control dependence).  A SCEV is defined anywhere a
+/// corresponding instruction could be defined in said maximally hoisted
+/// form.  Note that SCEVUDivExpr (currently the only expression type which
+/// can trap) can be defined per these rules in regions where it would trap
+/// at runtime.  A SCEV being defined does not require the existence of any
+/// instruction within the defined scope.
+namespace SCEVWrap {
+enum NoWrapFlags {
+  FlagAnyWrap = 0,    // No guarantee.
+  FlagNW = (1 << 0),  // No self-wrap.
+  FlagNUW = (1 << 1), // No unsigned wrap.
+  FlagNSW = (1 << 2), // No signed wrap.
+  NoWrapMask = (1 << 3) - 1
+};
+} // namespace SCEVWrap
+
 class SCEV;
 
 template <typename SCEVPtrT = const SCEV *>
@@ -73,12 +119,13 @@ struct SCEVUseT : PointerIntPair<SCEVPtrT, 2> {
   using Base = PointerIntPair<SCEVPtrT, 2>;
 
   SCEVUseT() : Base() { Base::setFromOpaqueValue(nullptr); }
-  SCEVUseT(SCEVPtrT S) : SCEVUseT(S, 0) {}
-  SCEVUseT(SCEVPtrT S, unsigned Flags) : Base(S, Flags) {}
+  SCEVUseT(SCEVPtrT S) : Base(S, 0) {}
+  /// Construct with NoWrapFlags; only NUW/NSW are encoded, NW is dropped.
+  SCEVUseT(SCEVPtrT S, SCEVWrap::NoWrapFlags Flags) : Base(S, Flags >> 1) {}
   template <typename OtherPtrT, typename = std::enable_if_t<
                                     std::is_convertible_v<OtherPtrT, SCEVPtrT>>>
   SCEVUseT(const SCEVUseT<OtherPtrT> &Other)
-      : Base(Other.getPointer(), Other.getInt()) {}
+      : SCEVUseT(Other.getPointer(), Other.getUseNoWrapFlags()) {}
 
   operator SCEVPtrT() const { return Base::getPointer(); }
   SCEVPtrT operator->() const { return Base::getPointer(); }
@@ -92,7 +139,19 @@ struct SCEVUseT : PointerIntPair<SCEVPtrT, 2> {
   /// Return the canonical SCEV for this SCEVUse.
   const SCEV *getCanonical() const;
 
-  unsigned getFlags() const { return Base::getInt(); }
+  /// Return the no-wrap flags for this SCEVUse, which is the union of the
+  /// use-specific flags and the underlying SCEV's flags, masked by \p Mask.
+  inline SCEVWrap::NoWrapFlags
+  getNoWrapFlags(SCEVWrap::NoWrapFlags Mask = SCEVWrap::NoWrapMask) const;
+
+  /// Return only the use-specific no-wrap flags (NUW/NSW) without the
+  /// underlying SCEV's flags.
+  SCEVWrap::NoWrapFlags getUseNoWrapFlags() const {
+    unsigned UseFlags = Base::getInt() << 1;
+    if (UseFlags & (SCEVWrap::FlagNUW | SCEVWrap::FlagNSW))
+      UseFlags |= SCEVWrap::FlagNW;
+    return SCEVWrap::NoWrapFlags(UseFlags);
+  }
 
   bool operator==(const SCEVUseT &RHS) const {
     return getRawPointer() == RHS.getRawPointer();
@@ -106,6 +165,11 @@ struct SCEVUseT : PointerIntPair<SCEVPtrT, 2> {
 
   /// This method is used for debugging.
   void dump() const;
+
+private:
+  using Base::getInt;
+  using Base::setInt;
+  using Base::setPointer;
 };
 
 /// Deduction guide for various SCEV subclass pointers.
@@ -165,7 +229,7 @@ struct CastInfo<SCEVUseT<ToSCEVPtrT>, SCEVUse,
 
   static bool isPossible(const SCEVUse &U) { return isa<To>(U.getPointer()); }
   static CastReturnType doCast(const SCEVUse &U) {
-    return {cast<To>(U.getPointer()), U.getFlags()};
+    return CastReturnType(cast<To>(U.getPointer()), U.getUseNoWrapFlags());
   }
   static CastReturnType castFailed() { return CastReturnType(nullptr); }
   static CastReturnType doCastIfPossible(const SCEVUse &U) {
@@ -206,49 +270,13 @@ class SCEV : public FoldingSetNode {
   const SCEV *CanonicalSCEV = nullptr;
 
 public:
-  /// NoWrapFlags are bitfield indices into SubclassData.
-  ///
-  /// Add and Mul expressions may have no-unsigned-wrap <NUW> or
-  /// no-signed-wrap <NSW> properties, which are derived from the IR
-  /// operator. NSW is a misnomer that we use to mean no signed overflow or
-  /// underflow.
-  ///
-  /// AddRec expressions may have a no-self-wraparound <NW> property if, in
-  /// the integer domain, abs(step) * max-iteration(loop) <=
-  /// unsigned-max(bitwidth).  This means that the recurrence will never reach
-  /// its start value if the step is non-zero.  Computing the same value on
-  /// each iteration is not considered wrapping, and recurrences with step = 0
-  /// are trivially <NW>.  <NW> is independent of the sign of step and the
-  /// value the add recurrence starts with.
-  ///
-  /// Note that NUW and NSW are also valid properties of a recurrence, and
-  /// either implies NW. For convenience, NW will be set for a recurrence
-  /// whenever either NUW or NSW are set.
-  ///
-  /// We require that the flag on a SCEV apply to the entire scope in which
-  /// that SCEV is defined.  A SCEV's scope is set of locations dominated by
-  /// a defining location, which is in turn described by the following rules:
-  /// * A SCEVUnknown is at the point of definition of the Value.
-  /// * A SCEVConstant is defined at all points.
-  /// * A SCEVAddRec is defined starting with the header of the associated
-  ///   loop.
-  /// * All other SCEVs are defined at the earlest point all operands are
-  ///   defined.
-  ///
-  /// The above rules describe a maximally hoisted form (without regards to
-  /// potential control dependence).  A SCEV is defined anywhere a
-  /// corresponding instruction could be defined in said maximally hoisted
-  /// form.  Note that SCEVUDivExpr (currently the only expression type which
-  /// can trap) can be defined per these rules in regions where it would trap
-  /// at runtime.  A SCEV being defined does not require the existence of any
-  /// instruction within the defined scope.
-  enum NoWrapFlags {
-    FlagAnyWrap = 0,    // No guarantee.
-    FlagNW = (1 << 0),  // No self-wrap.
-    FlagNUW = (1 << 1), // No unsigned wrap.
-    FlagNSW = (1 << 2), // No signed wrap.
-    NoWrapMask = (1 << 3) - 1
-  };
+  /// Expose SCEVWrap::NoWrapFlags as SCEV::NoWrapFlags.
+  using NoWrapFlags = SCEVWrap::NoWrapFlags;
+  static constexpr auto FlagAnyWrap = SCEVWrap::FlagAnyWrap;
+  static constexpr auto FlagNW = SCEVWrap::FlagNW;
+  static constexpr auto FlagNUW = SCEVWrap::FlagNUW;
+  static constexpr auto FlagNSW = SCEVWrap::FlagNSW;
+  static constexpr auto NoWrapMask = SCEVWrap::NoWrapMask;
 
   explicit SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                 unsigned short ExpressionSize)
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
index 2fc928d5955d1..828e29cd66d7b 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -1046,6 +1046,20 @@ class SCEVLoopAddRecRewriter
   LoopToScevMapT &Map;
 };
 
+template <typename SCEVPtrT>
+inline SCEVWrap::NoWrapFlags
+SCEVUseT<SCEVPtrT>::getNoWrapFlags(SCEVWrap::NoWrapFlags Mask) const {
+  unsigned Flags = SCEV::FlagAnyWrap;
+  if (auto *NAry = dyn_cast<SCEVNAryExpr>(Base::getPointer()))
+    Flags = NAry->getNoWrapFlags();
+  // Use-flags only encode NUW/NSW in 2 bits; shift to align with NoWrapFlags.
+  unsigned UseFlags = Base::getInt() << 1;
+  // NUW or NSW implies NW.
+  if (UseFlags & (SCEVWrap::FlagNUW | SCEVWrap::FlagNSW))
+    UseFlags |= SCEVWrap::FlagNW;
+  return SCEVWrap::NoWrapFlags((Flags | UseFlags) & Mask);
+}
+
 } // end namespace llvm
 
 #endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 3509a204a3d4e..f83c73eec147d 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -573,7 +573,7 @@ Value *SCEVExpander::visitAddExpr(SCEVUseT<const SCEVAddExpr *> S) {
             X = SE.getSCEV(U->getValue());
         NewOps.push_back(X);
       }
-      Sum = expandAddToGEP(SE.getAddExpr(NewOps), Sum, S->getNoWrapFlags());
+      Sum = expandAddToGEP(SE.getAddExpr(NewOps), Sum, S.getNoWrapFlags());
     } else if (Op->isNonConstantNegative()) {
       // Instead of doing a negate and add, just do a subtract.
       Value *W = expand(SE.getNegativeSCEV(Op));
@@ -586,7 +586,7 @@ Value *SCEVExpander::visitAddExpr(SCEVUseT<const SCEVAddExpr *> S) {
       // Canonicalize a constant to the RHS.
       if (isa<Constant>(Sum))
         std::swap(Sum, W);
-      Sum = InsertBinop(Instruction::Add, Sum, W, S->getNoWrapFlags(),
+      Sum = InsertBinop(Instruction::Add, Sum, W, S.getNoWrapFlags(),
                         /*IsSafeToHoist*/ true);
       ++I;
     }
@@ -670,7 +670,7 @@ Value *SCEVExpander::visitMulExpr(SCEVUseT<const SCEVMulExpr *> S) {
       if (match(W, m_Power2(RHS))) {
         // Canonicalize Prod*(1<<C) to Prod<<C.
         assert(!Ty->isVectorTy() && "vector types are not SCEVable");
-        auto NWFlags = S->getNoWrapFlags();
+        auto NWFlags = S.getNoWrapFlags();
         // clear nsw flag if shl will produce poison value.
         if (RHS->logBase2() == RHS->getBitWidth() - 1)
           NWFlags = ScalarEvolution::clearFlags(NWFlags, SCEV::FlagNSW);
@@ -678,7 +678,7 @@ Value *SCEVExpander::visitMulExpr(SCEVUseT<const SCEVMulExpr *> S) {
                            ConstantInt::get(Ty, RHS->logBase2()), NWFlags,
                            /*IsSafeToHoist*/ true);
       } else {
-        Prod = InsertBinop(Instruction::Mul, Prod, W, S->getNoWrapFlags(),
+        Prod = InsertBinop(Instruction::Mul, Prod, W, S.getNoWrapFlags(),
                            /*IsSafeToHoist*/ true);
       }
     }
@@ -1340,8 +1340,8 @@ Value *SCEVExpander::visitAddRecExpr(SCEVUseT<const SCEVAddRecExpr *> S) {
     SmallVector<SCEVUse, 4> NewOps(S->getNumOperands());
     for (unsigned i = 0, e = S->getNumOperands(); i != e; ++i)
       NewOps[i] = SE.getAnyExtendExpr(S->getOperand(i), CanonicalIV->getType());
-    Value *V = expand(SE.getAddRecExpr(NewOps, S->getLoop(),
-                                       S->getNoWrapFlags(SCEV::FlagNW)));
+    Value *V = expand(
+        SE.getAddRecExpr(NewOps, S->getLoop(), S.getNoWrapFlags(SCEV::FlagNW)));
     BasicBlock::iterator NewInsertPt =
         findInsertPointAfter(cast<Instruction>(V), &*Builder.GetInsertPoint());
     V = expand(SE.getTruncateExpr(SE.getUnknown(V), Ty), NewInsertPt);
@@ -1358,13 +1358,13 @@ Value *SCEVExpander::visitAddRecExpr(SCEVUseT<const SCEVAddRecExpr *> S) {
     if (isa<PointerType>(S->getType())) {
       Value *StartV = expand(SE.getPointerBase(S));
       return expandAddToGEP(SE.removePointerBase(S), StartV,
-                            S->getNoWrapFlags(SCEV::FlagNUW));
+                            S.getNoWrapFlags(SCEV::FlagNUW));
     }
 
     SmallVector<SCEVUse, 4> NewOps(S->operands());
     NewOps[0] = SE.getConstant(Ty, 0);
-    const SCEV *Rest = SE.getAddRecExpr(NewOps, L,
-                                        S->getNoWrapFlags(SCEV::FlagNW));
+    const SCEV *Rest =
+        SE.getAddRecExpr(NewOps, L, S.getNoWrapFlags(SCEV::FlagNW));
 
     // Just do a normal add. Pre-expand the operands to suppress folding.
     //

>From 7b479015e9a966361837a6a49cd609095089faf6 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Thu, 2 Apr 2026 18:26:40 +0100
Subject: [PATCH 2/3] !fixup use enum class SCEVNoWrapFlags

---
 llvm/include/llvm/Analysis/ScalarEvolution.h  | 53 ++++++++++---------
 .../Analysis/ScalarEvolutionExpressions.h     | 35 +++++++-----
 llvm/lib/Analysis/LoopAccessAnalysis.cpp      |  2 +-
 llvm/lib/Analysis/ScalarEvolution.cpp         | 49 ++++++++---------
 .../Utils/ScalarEvolutionExpander.cpp         | 13 ++---
 5 files changed, 78 insertions(+), 74 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index d529e29d2c096..cb46e47a59d85 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -22,6 +22,7 @@
 
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitmaskEnum.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/FoldingSet.h"
@@ -102,15 +103,14 @@ LLVM_ABI extern bool VerifySCEV;
 /// can trap) can be defined per these rules in regions where it would trap
 /// at runtime.  A SCEV being defined does not require the existence of any
 /// instruction within the defined scope.
-namespace SCEVWrap {
-enum NoWrapFlags {
+enum class SCEVNoWrapFlags {
   FlagAnyWrap = 0,    // No guarantee.
   FlagNW = (1 << 0),  // No self-wrap.
   FlagNUW = (1 << 1), // No unsigned wrap.
   FlagNSW = (1 << 2), // No signed wrap.
-  NoWrapMask = (1 << 3) - 1
+  NoWrapMask = (1 << 3) - 1,
+  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/NoWrapMask)
 };
-} // namespace SCEVWrap
 
 class SCEV;
 
@@ -121,7 +121,8 @@ struct SCEVUseT : PointerIntPair<SCEVPtrT, 2> {
   SCEVUseT() : Base() { Base::setFromOpaqueValue(nullptr); }
   SCEVUseT(SCEVPtrT S) : Base(S, 0) {}
   /// Construct with NoWrapFlags; only NUW/NSW are encoded, NW is dropped.
-  SCEVUseT(SCEVPtrT S, SCEVWrap::NoWrapFlags Flags) : Base(S, Flags >> 1) {}
+  SCEVUseT(SCEVPtrT S, SCEVNoWrapFlags Flags)
+      : Base(S, static_cast<unsigned>(Flags) >> 1) {}
   template <typename OtherPtrT, typename = std::enable_if_t<
                                     std::is_convertible_v<OtherPtrT, SCEVPtrT>>>
   SCEVUseT(const SCEVUseT<OtherPtrT> &Other)
@@ -141,16 +142,17 @@ struct SCEVUseT : PointerIntPair<SCEVPtrT, 2> {
 
   /// Return the no-wrap flags for this SCEVUse, which is the union of the
   /// use-specific flags and the underlying SCEV's flags, masked by \p Mask.
-  inline SCEVWrap::NoWrapFlags
-  getNoWrapFlags(SCEVWrap::NoWrapFlags Mask = SCEVWrap::NoWrapMask) const;
+  SCEVNoWrapFlags
+  getNoWrapFlags(SCEVNoWrapFlags Mask = SCEVNoWrapFlags::NoWrapMask) const;
 
   /// Return only the use-specific no-wrap flags (NUW/NSW) without the
   /// underlying SCEV's flags.
-  SCEVWrap::NoWrapFlags getUseNoWrapFlags() const {
-    unsigned UseFlags = Base::getInt() << 1;
-    if (UseFlags & (SCEVWrap::FlagNUW | SCEVWrap::FlagNSW))
-      UseFlags |= SCEVWrap::FlagNW;
-    return SCEVWrap::NoWrapFlags(UseFlags);
+  SCEVNoWrapFlags getUseNoWrapFlags() const {
+    SCEVNoWrapFlags UseFlags =
+        static_cast<SCEVNoWrapFlags>(Base::getInt() << 1);
+    if (any(UseFlags & (SCEVNoWrapFlags::FlagNUW | SCEVNoWrapFlags::FlagNSW)))
+      UseFlags |= SCEVNoWrapFlags::FlagNW;
+    return UseFlags;
   }
 
   bool operator==(const SCEVUseT &RHS) const {
@@ -270,13 +272,12 @@ class SCEV : public FoldingSetNode {
   const SCEV *CanonicalSCEV = nullptr;
 
 public:
-  /// Expose SCEVWrap::NoWrapFlags as SCEV::NoWrapFlags.
-  using NoWrapFlags = SCEVWrap::NoWrapFlags;
-  static constexpr auto FlagAnyWrap = SCEVWrap::FlagAnyWrap;
-  static constexpr auto FlagNW = SCEVWrap::FlagNW;
-  static constexpr auto FlagNUW = SCEVWrap::FlagNUW;
-  static constexpr auto FlagNSW = SCEVWrap::FlagNSW;
-  static constexpr auto NoWrapMask = SCEVWrap::NoWrapMask;
+  using NoWrapFlags = SCEVNoWrapFlags;
+  static constexpr auto FlagAnyWrap = SCEVNoWrapFlags::FlagAnyWrap;
+  static constexpr auto FlagNW = SCEVNoWrapFlags::FlagNW;
+  static constexpr auto FlagNUW = SCEVNoWrapFlags::FlagNUW;
+  static constexpr auto FlagNSW = SCEVNoWrapFlags::FlagNSW;
+  static constexpr auto NoWrapMask = SCEVNoWrapFlags::NoWrapMask;
 
   explicit SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                 unsigned short ExpressionSize)
@@ -636,16 +637,16 @@ class ScalarEvolution {
   /// Convenient NoWrapFlags manipulation that hides enum casts and is
   /// visible in the ScalarEvolution name space.
   [[nodiscard]] static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags,
-                                                   int Mask) {
-    return (SCEV::NoWrapFlags)(Flags & Mask);
+                                                   SCEV::NoWrapFlags Mask) {
+    return Flags & Mask;
   }
   [[nodiscard]] static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags,
                                                   SCEV::NoWrapFlags OnFlags) {
-    return (SCEV::NoWrapFlags)(Flags | OnFlags);
+    return Flags | OnFlags;
   }
   [[nodiscard]] static SCEV::NoWrapFlags
   clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags) {
-    return (SCEV::NoWrapFlags)(Flags & ~OffFlags);
+    return Flags & ~OffFlags;
   }
   [[nodiscard]] static bool hasFlags(SCEV::NoWrapFlags Flags,
                                      SCEV::NoWrapFlags TestFlags) {
@@ -2725,10 +2726,10 @@ template <> inline const SCEV *SCEVUseT<const SCEV *>::getCanonical() const {
 template <typename SCEVPtrT>
 void SCEVUseT<SCEVPtrT>::print(raw_ostream &OS) const {
   Base::getPointer()->print(OS);
-  SCEV::NoWrapFlags Flags = static_cast<SCEV::NoWrapFlags>(Base::getInt());
-  if (Flags & SCEV::FlagNUW)
+  SCEV::NoWrapFlags Flags = getUseNoWrapFlags();
+  if (any(Flags & SCEV::FlagNUW))
     OS << "(u nuw)";
-  if (Flags & SCEV::FlagNSW)
+  if (any(Flags & SCEV::FlagNSW))
     OS << "(u nsw)";
 }
 
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
index 828e29cd66d7b..0216153fb1a21 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -234,7 +234,8 @@ class SCEVNAryExpr : public SCEV {
   ArrayRef<SCEVUse> operands() const { return ArrayRef(Operands, NumOperands); }
 
   NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
-    return (NoWrapFlags)(SubclassData & Mask);
+    return static_cast<NoWrapFlags>(static_cast<unsigned>(SubclassData) &
+                                    static_cast<unsigned>(Mask));
   }
 
   bool hasNoUnsignedWrap() const {
@@ -274,7 +275,9 @@ class SCEVCommutativeExpr : public SCEVNAryExpr {
   }
 
   /// Set flags for a non-recurrence without clearing previously set flags.
-  void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
+  void setNoWrapFlags(NoWrapFlags Flags) {
+    SubclassData |= static_cast<unsigned short>(Flags);
+  }
 };
 
 /// This node represents an addition of some number of SCEVs.
@@ -402,9 +405,9 @@ class SCEVAddRecExpr : public SCEVNAryExpr {
   /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
   /// to make it easier to propagate flags.
   void setNoWrapFlags(NoWrapFlags Flags) {
-    if (Flags & (FlagNUW | FlagNSW))
+    if (any(Flags & (FlagNUW | FlagNSW)))
       Flags = ScalarEvolution::setFlags(Flags, FlagNW);
-    SubclassData |= Flags;
+    SubclassData |= static_cast<unsigned short>(Flags);
   }
 
   /// Return the value of this chain of recurrences at the specified
@@ -453,7 +456,7 @@ class SCEVMinMaxExpr : public SCEVCommutativeExpr {
       : SCEVCommutativeExpr(ID, T, O, N) {
     assert(isMinMaxType(T));
     // Min and max never overflow
-    setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
+    setNoWrapFlags(FlagNUW | FlagNSW);
   }
 
 public:
@@ -539,7 +542,9 @@ class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
   }
 
   /// Set flags for a non-recurrence without clearing previously set flags.
-  void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
+  void setNoWrapFlags(NoWrapFlags Flags) {
+    SubclassData |= static_cast<unsigned short>(Flags);
+  }
 
 protected:
   /// Note: Constructing subclasses via this constructor is allowed
@@ -548,7 +553,7 @@ class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
       : SCEVNAryExpr(ID, T, O, N) {
     assert(isSequentialMinMaxType(T));
     // Min and max never overflow
-    setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
+    setNoWrapFlags(FlagNUW | FlagNSW);
   }
 
 public:
@@ -1047,17 +1052,19 @@ class SCEVLoopAddRecRewriter
 };
 
 template <typename SCEVPtrT>
-inline SCEVWrap::NoWrapFlags
-SCEVUseT<SCEVPtrT>::getNoWrapFlags(SCEVWrap::NoWrapFlags Mask) const {
-  unsigned Flags = SCEV::FlagAnyWrap;
+inline SCEVNoWrapFlags
+SCEVUseT<SCEVPtrT>::getNoWrapFlags(SCEVNoWrapFlags Mask) const {
+  unsigned Flags = static_cast<unsigned>(SCEV::FlagAnyWrap);
   if (auto *NAry = dyn_cast<SCEVNAryExpr>(Base::getPointer()))
-    Flags = NAry->getNoWrapFlags();
+    Flags = static_cast<unsigned>(NAry->getNoWrapFlags());
   // Use-flags only encode NUW/NSW in 2 bits; shift to align with NoWrapFlags.
   unsigned UseFlags = Base::getInt() << 1;
   // NUW or NSW implies NW.
-  if (UseFlags & (SCEVWrap::FlagNUW | SCEVWrap::FlagNSW))
-    UseFlags |= SCEVWrap::FlagNW;
-  return SCEVWrap::NoWrapFlags((Flags | UseFlags) & Mask);
+  if (UseFlags & static_cast<unsigned>(SCEVNoWrapFlags::FlagNUW |
+                                       SCEVNoWrapFlags::FlagNSW))
+    UseFlags |= static_cast<unsigned>(SCEVNoWrapFlags::FlagNW);
+  return static_cast<SCEVNoWrapFlags>((Flags | UseFlags) &
+                                      static_cast<unsigned>(Mask));
 }
 
 } // end namespace llvm
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 20f0cd404ab9c..2b9efd22131c6 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -1028,7 +1028,7 @@ static bool isNoWrap(PredicatedScalarEvolution &PSE, const SCEVAddRecExpr *AR,
                      const DominatorTree &DT,
                      std::optional<int64_t> Stride = std::nullopt) {
   // FIXME: This should probably only return true for NUW.
-  if (AR->getNoWrapFlags(SCEV::NoWrapMask))
+  if (any(AR->getNoWrapFlags(SCEV::NoWrapMask)))
     return true;
 
   if (Ptr && PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW))
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 0820440790bb4..5c8484b3d8c17 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -1485,7 +1485,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
   //
 
   const SCEV *BECount = SE->getBackedgeTakenCount(L);
-  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
+  if (PreAR && any(PreAR->getNoWrapFlags(WrapType)) &&
       !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
     return PreStart;
 
@@ -1496,7 +1496,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
       SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
                      (SE->*GetExtendExpr)(Step, WideTy, Depth));
   if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
-    if (PreAR && AR->getNoWrapFlags(WrapType)) {
+    if (PreAR && any(AR->getNoWrapFlags(WrapType))) {
       // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
       // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
       // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`.  Cache this fact.
@@ -1595,7 +1595,7 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
 
     // Give up if we don't already have the add recurrence we need because
     // actually constructing an add recurrence is relatively expensive.
-    if (PreAR && PreAR->getNoWrapFlags(WrapType)) {  // proves (2)
+    if (PreAR && any(PreAR->getNoWrapFlags(WrapType))) { // proves (2)
       const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
       ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
       const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
@@ -1871,8 +1871,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
           const SCEV *SResidual =
               getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
           const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
-          return getAddExpr(SZExtD, SZExtR,
-                            (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+          return getAddExpr(SZExtD, SZExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
                             Depth + 1);
         }
       }
@@ -1926,8 +1925,7 @@ const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
         const SCEV *SResidual =
             getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
         const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
-        return getAddExpr(SZExtD, SZExtR,
-                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+        return getAddExpr(SZExtD, SZExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
                           Depth + 1);
       }
     }
@@ -2100,8 +2098,7 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
         const SCEV *SResidual =
             getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
         const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
-        return getAddExpr(SSExtD, SSExtR,
-                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+        return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
                           Depth + 1);
       }
     }
@@ -2224,8 +2221,7 @@ const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
           const SCEV *SResidual =
               getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
           const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
-          return getAddExpr(SSExtD, SSExtR,
-                            (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+          return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
                             Depth + 1);
         }
       }
@@ -2554,7 +2550,7 @@ static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE,
   (void)CanAnalyze;
   assert(CanAnalyze && "don't call from other places!");
 
-  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
+  SCEV::NoWrapFlags SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
   SCEV::NoWrapFlags SignOrUnsignWrap =
       ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
 
@@ -2564,8 +2560,7 @@ static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE,
   };
 
   if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
-    Flags =
-        ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
+    Flags = ScalarEvolution::setFlags(Flags, SignOrUnsignMask);
 
   SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
 
@@ -5927,8 +5922,7 @@ const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
 
   if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
     setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
-                   (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
-                                       proveNoWrapViaConstantRanges(AR)));
+                   (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
   }
 
   // We can add Flags to the post-inc expression only if we
@@ -6057,9 +6051,9 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
         insertValueToMap(PN, PHISCEV);
 
         if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
-          setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
-                         (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
-                                             proveNoWrapViaConstantRanges(AR)));
+          setNoWrapFlags(
+              const_cast<SCEVAddRecExpr *>(AR),
+              (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
         }
 
         // We can add Flags to the post-inc expression only if we
@@ -8217,11 +8211,12 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
         auto Flags = SCEV::FlagAnyWrap;
         if (BO->Op) {
           auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
-          if ((MulFlags & SCEV::FlagNSW) &&
-              ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
-            Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW);
-          if (MulFlags & SCEV::FlagNUW)
-            Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW);
+          if (any(MulFlags & SCEV::FlagNSW) &&
+              (any(MulFlags & SCEV::FlagNUW) ||
+               SA->getValue().ult(BitWidth - 1)))
+            Flags = Flags | SCEV::FlagNSW;
+          if (any(MulFlags & SCEV::FlagNUW))
+            Flags = Flags | SCEV::FlagNUW;
         }
 
         ConstantInt *X = ConstantInt::get(
@@ -13362,7 +13357,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   // implicit/exceptional) which causes the loop to execute before the
   // exiting instruction we're analyzing would trigger UB.
   auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
-  bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
+  bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
   ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
 
   const SCEV *Stride = IV->getStepRecurrence(*this);
@@ -13478,7 +13473,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
   if (!isLoopInvariant(RHS, L)) {
     const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
     if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
-        RHSAddRec->getNoWrapFlags()) {
+        any(RHSAddRec->getNoWrapFlags())) {
       // The structure of loop we are trying to calculate backedge count of:
       //
       //  left = left_start
@@ -13741,7 +13736,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
     return getCouldNotCompute();
 
   auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
-  bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
+  bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
   ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
 
   const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index f83c73eec147d..a560c324b4f1e 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -302,9 +302,9 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode,
       auto canGenerateIncompatiblePoison = [&Flags](Instruction *I) {
         // Ensure that no-wrap flags match.
         if (isa<OverflowingBinaryOperator>(I)) {
-          if (I->hasNoSignedWrap() != (Flags & SCEV::FlagNSW))
+          if (I->hasNoSignedWrap() != any(Flags & SCEV::FlagNSW))
             return true;
-          if (I->hasNoUnsignedWrap() != (Flags & SCEV::FlagNUW))
+          if (I->hasNoUnsignedWrap() != any(Flags & SCEV::FlagNUW))
             return true;
         }
         // Conservatively, do not use any instruction which has any of exact
@@ -341,9 +341,9 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode,
   // InstSimplifyFolder.
   Instruction *BO = Builder.Insert(BinaryOperator::Create(Opcode, LHS, RHS));
   BO->setDebugLoc(Loc);
-  if (Flags & SCEV::FlagNUW)
+  if (any(Flags & SCEV::FlagNUW))
     BO->setHasNoUnsignedWrap();
-  if (Flags & SCEV::FlagNSW)
+  if (any(Flags & SCEV::FlagNSW))
     BO->setHasNoSignedWrap();
 
   return BO;
@@ -382,8 +382,9 @@ Value *SCEVExpander::expandAddToGEP(const SCEV *Offset, Value *V,
          SE.DT.dominates(cast<Instruction>(V), &*Builder.GetInsertPoint()));
 
   Value *Idx = expand(Offset);
-  GEPNoWrapFlags NW = (Flags & SCEV::FlagNUW) ? GEPNoWrapFlags::noUnsignedWrap()
-                                              : GEPNoWrapFlags::none();
+  GEPNoWrapFlags NW = any(Flags & SCEV::FlagNUW)
+                          ? GEPNoWrapFlags::noUnsignedWrap()
+                          : GEPNoWrapFlags::none();
 
   // Fold a GEP with constant operands.
   if (Constant *CLHS = dyn_cast<Constant>(V))

>From 81996ce19dc02233cac15a1dbe32369604c7ce3b Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 3 Apr 2026 13:06:41 +0100
Subject: [PATCH 3/3] !fixup polly and extra test

---
 llvm/include/llvm/Analysis/ScalarEvolution.h  |  4 +-
 .../Analysis/ScalarEvolutionExpressions.h     | 16 ++----
 .../Utils/ScalarEvolutionExpanderTest.cpp     | 49 +++++++++++++++++++
 polly/lib/Support/SCEVAffinator.cpp           |  2 +-
 .../test/CodeGen/non_affine_float_compare.ll  |  3 +-
 5 files changed, 57 insertions(+), 17 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index cb46e47a59d85..8ff1b8e6749f2 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -634,8 +634,8 @@ class ScalarEvolution {
     ProperlyDominatesBlock ///< The SCEV properly dominates the block.
   };
 
-  /// Convenient NoWrapFlags manipulation that hides enum casts and is
-  /// visible in the ScalarEvolution name space.
+  /// Convenient NoWrapFlags manipulation. TODO: Remove this helper and use
+  /// NoWrapFlags' operator& directly at call sites.
   [[nodiscard]] static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags,
                                                    SCEV::NoWrapFlags Mask) {
     return Flags & Mask;
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
index 0216153fb1a21..f103a789716ca 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
@@ -234,8 +234,7 @@ class SCEVNAryExpr : public SCEV {
   ArrayRef<SCEVUse> operands() const { return ArrayRef(Operands, NumOperands); }
 
   NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
-    return static_cast<NoWrapFlags>(static_cast<unsigned>(SubclassData) &
-                                    static_cast<unsigned>(Mask));
+    return static_cast<NoWrapFlags>(SubclassData) & Mask;
   }
 
   bool hasNoUnsignedWrap() const {
@@ -1054,17 +1053,10 @@ class SCEVLoopAddRecRewriter
 template <typename SCEVPtrT>
 inline SCEVNoWrapFlags
 SCEVUseT<SCEVPtrT>::getNoWrapFlags(SCEVNoWrapFlags Mask) const {
-  unsigned Flags = static_cast<unsigned>(SCEV::FlagAnyWrap);
+  SCEVNoWrapFlags Flags = SCEVNoWrapFlags::FlagAnyWrap;
   if (auto *NAry = dyn_cast<SCEVNAryExpr>(Base::getPointer()))
-    Flags = static_cast<unsigned>(NAry->getNoWrapFlags());
-  // Use-flags only encode NUW/NSW in 2 bits; shift to align with NoWrapFlags.
-  unsigned UseFlags = Base::getInt() << 1;
-  // NUW or NSW implies NW.
-  if (UseFlags & static_cast<unsigned>(SCEVNoWrapFlags::FlagNUW |
-                                       SCEVNoWrapFlags::FlagNSW))
-    UseFlags |= static_cast<unsigned>(SCEVNoWrapFlags::FlagNW);
-  return static_cast<SCEVNoWrapFlags>((Flags | UseFlags) &
-                                      static_cast<unsigned>(Mask));
+    Flags = NAry->getNoWrapFlags();
+  return (Flags | getUseNoWrapFlags()) & Mask;
 }
 
 } // end namespace llvm
diff --git a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
index 268b9313da882..5dcc833d8893f 100644
--- a/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
+++ b/llvm/unittests/Transforms/Utils/ScalarEvolutionExpanderTest.cpp
@@ -987,4 +987,53 @@ TEST_F(ScalarEvolutionExpanderTest, GEPFlags) {
   EXPECT_EQ(GEP->getNoWrapFlags(), GEPNoWrapFlags::none());
 }
 
+// Test that InsertBinop reuses an existing shl whose no-wrap flags match.
+TEST_F(ScalarEvolutionExpanderTest, InsertBinopReuseShlWithMatchingFlags) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString("define void @f(i64 %n) { "
+                                                  "  ret void "
+                                                  "}",
+                                                  Err, C);
+
+  assert(M && "Could not parse module?");
+  assert(!verifyModule(*M) && "Must have been well formed!");
+
+  Function *F = M->getFunction("f");
+  ASSERT_NE(F, nullptr);
+
+  ScalarEvolution SE = buildSE(*F);
+
+  auto *I64Ty = Type::getInt64Ty(C);
+  const SCEV *N = SE.getSCEV(F->getArg(0));
+  const SCEV *Four = SE.getConstant(I64Ty, 4);
+  const SCEV *Mul = SE.getMulExpr(N, Four, SCEV::FlagNUW | SCEV::FlagNSW);
+  const SCEV *Expr1 = SE.getAddExpr(Mul, SE.getConstant(I64Ty, 8));
+  const SCEV *Expr2 = SE.getAddExpr(Mul, SE.getConstant(I64Ty, 16));
+
+  // Expand with separate expanders so the InsertedExpressions cache doesn't
+  // apply.
+  auto *InsertBefore = F->getEntryBlock().getTerminator();
+  SCEVExpander Exp1(SE, "expander");
+  Value *V1 = Exp1.expandCodeFor(Expr1, nullptr, InsertBefore);
+  SCEVExpander Exp2(SE, "expander");
+  Value *V2 = Exp2.expandCodeFor(Expr2, nullptr, InsertBefore);
+
+  // The expansions differ (different constants), but must share one shl.
+  EXPECT_NE(V1, V2);
+
+  Value *Shl1 = nullptr, *Shl2 = nullptr;
+  Value *Arg = F->getArg(0);
+  ASSERT_TRUE(match(V1, m_Add(m_Value(Shl1), m_SpecificInt(8))));
+  ASSERT_TRUE(match(V2, m_Add(m_Value(Shl2), m_SpecificInt(16))));
+  EXPECT_EQ(Shl1, Shl2) << "Expected the shl to be reused";
+
+  // Verify the shared shl has the expected form and flags.
+  auto *ShlInst = dyn_cast<Instruction>(Shl1);
+  ASSERT_NE(ShlInst, nullptr);
+  EXPECT_TRUE(match(ShlInst, m_Shl(m_Specific(Arg), m_SpecificInt(2))));
+  EXPECT_TRUE(ShlInst->hasNoUnsignedWrap());
+  EXPECT_TRUE(ShlInst->hasNoSignedWrap());
+}
+
 } // end namespace llvm
diff --git a/polly/lib/Support/SCEVAffinator.cpp b/polly/lib/Support/SCEVAffinator.cpp
index b55fa62f0e187..6455f4dc1da7c 100644
--- a/polly/lib/Support/SCEVAffinator.cpp
+++ b/polly/lib/Support/SCEVAffinator.cpp
@@ -137,7 +137,7 @@ PWACtx SCEVAffinator::checkForWrapping(const SCEV *Expr, PWACtx PWAC) const {
   // whereas n is the number of bits of the Expr, hence:
   //   n = bitwidth(ExprType)
 
-  if (IgnoreIntegerWrapping || (getNoWrapFlags(Expr) & SCEV::FlagNSW))
+  if (IgnoreIntegerWrapping || any(getNoWrapFlags(Expr) & SCEV::FlagNSW))
     return PWAC;
 
   isl::pw_aff PWAMod = addModuloSemantic(PWAC.first, Expr->getType());
diff --git a/polly/test/CodeGen/non_affine_float_compare.ll b/polly/test/CodeGen/non_affine_float_compare.ll
index 9709e231a4e86..d1a38e5bd6d69 100644
--- a/polly/test/CodeGen/non_affine_float_compare.ll
+++ b/polly/test/CodeGen/non_affine_float_compare.ll
@@ -12,8 +12,7 @@
 ; CHECK:   %[[offset:.*]] = shl nuw nsw i64 %polly.indvar, 2
 ; CHECK:   %scevgep[[R0:[0-9]*]] = getelementptr i8, ptr %A, i64 %[[offset]]
 ; CHECK:   %tmp3_p_scalar_ = load float, ptr %scevgep[[R0]], align 4, !alias.scope !2, !noalias !5
-; CHECK:   %[[offset2:.*]] = shl nuw nsw i64 %polly.indvar, 2
-; CHECK:   %scevgep[[R2:[0-9]*]] = getelementptr i8, ptr %scevgep{{[0-9]*}}, i64 %[[offset2]]
+; CHECK:   %scevgep[[R2:[0-9]*]] = getelementptr i8, ptr %scevgep{{[0-9]*}}, i64 %[[offset]]
 ; CHECK:   %tmp6_p_scalar_ = load float, ptr %scevgep[[R2]], align 4, !alias.scope !2, !noalias !5
 ; CHECK:   %p_tmp7 = fcmp oeq float %tmp3_p_scalar_, %tmp6_p_scalar_
 ; CHECK:   br i1 %p_tmp7, label %polly.stmt.bb8, label %polly.stmt.bb12.[[R:[a-zA-Z_.0-9]*]]



More information about the llvm-commits mailing list