[llvm] [SCEV] Use Step and Start to check if SCEVWrapPredicate is implied. (PR #118184)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 16 06:03:25 PST 2024


================
@@ -14934,10 +14936,52 @@ SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
 
 const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
 
-bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
+                                ScalarEvolution &SE) const {
   const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
+  if (!Op)
+    return false;
+
+  if (setFlags(Flags, Op->Flags) != Flags)
+    return false;
+
+  if (Op->AR == AR)
+    return true;
+
+  if (Flags != SCEVWrapPredicate::IncrementNSSW &&
+      Flags != SCEVWrapPredicate::IncrementNUSW)
+    return false;
 
-  return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
+  bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
+  const SCEV *Step = AR->getStepRecurrence(SE);
+  const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
+
+  // If both steps are positive, this implies N, if N's start and step are
+  // ULE/SLE (for NSUW/NSSW) than this'.
+  if (SE.isKnownPositive(Step) && SE.isKnownPositive(OpStep)) {
+    const SCEV *OpStart = Op->AR->getStart();
+    const SCEV *Start = AR->getStart();
+    if (SE.getTypeSizeInBits(Step->getType()) >
+        SE.getTypeSizeInBits(OpStep->getType())) {
+      OpStep = SE.getZeroExtendExpr(OpStep, Step->getType());
+    } else {
+      Step = IsNUW ? SE.getNoopOrZeroExtend(Step, OpStep->getType())
----------------
fhahn wrote:

Adjusted, thanks!

https://github.com/llvm/llvm-project/pull/118184


More information about the llvm-commits mailing list