[llvm] [SCEV] Preserve divisibility info when creating UMax/SMax expressions. (PR #160012)

via llvm-commits llvm-commits at lists.llvm.org
Sun Sep 21 13:41:43 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: Florian Hahn (fhahn)

<details>
<summary>Changes</summary>

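Currently we generate (S|U)Max(1, Op) for Op >= 1, which may discard divisibility information about Op. This patch rewrites such SMax/UMax expressions so that, instead of 1, the constant operand is the largest power of two known to divide all non-constant operands (i.e. 1 shifted left by their minimum number of known trailing zero bits), thereby preserving the common divisibility of the operands.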

---
Full diff: https://github.com/llvm/llvm-project/pull/160012.diff


2 Files Affected:

- (modified) llvm/lib/Analysis/ScalarEvolution.cpp (+22-2) 
- (modified) llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll (+2-2) 


``````````diff
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index b08399b381f34..ee1f92a4197e8 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15850,12 +15850,17 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
         To = SE.getUMaxExpr(FromRewritten, RHS);
         if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
           EnqueueOperands(UMin);
+        if (RHS->isOne())
+          ExprsToRewrite.push_back(From);
         break;
       case CmpInst::ICMP_SGT:
       case CmpInst::ICMP_SGE:
         To = SE.getSMaxExpr(FromRewritten, RHS);
-        if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
+        if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten)) {
           EnqueueOperands(SMin);
+        }
+        if (RHS->isOne())
+          ExprsToRewrite.push_back(From);
         break;
       case CmpInst::ICMP_EQ:
         if (isa<SCEVConstant>(RHS))
@@ -15986,7 +15991,22 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
     for (const SCEV *Expr : ExprsToRewrite) {
       const SCEV *RewriteTo = Guards.RewriteMap[Expr];
       Guards.RewriteMap.erase(Expr);
-      Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
+      const SCEV *Rewritten = Guards.rewrite(RewriteTo);
+
+      // Try to strengthen divisibility of SMax/UMax expressions coming from >=
+      // 1 conditions.
+      if (auto *SMax = dyn_cast<SCEVSMaxExpr>(Rewritten)) {
+        unsigned MinTrailingZeros = SE.getMinTrailingZeros(SMax->getOperand(1));
+        for (const SCEV *Op : drop_begin(SMax->operands(), 2))
+          MinTrailingZeros =
+              std::min(MinTrailingZeros, SE.getMinTrailingZeros(Op));
+        if (MinTrailingZeros != 0)
+          Rewritten = SE.getSMaxExpr(
+              SE.getConstant(APInt(SMax->getType()->getScalarSizeInBits(), 1)
+                                 .shl(MinTrailingZeros)),
+              SMax);
+      }
+      Guards.RewriteMap.insert({Expr, Rewritten});
     }
   }
 }
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll b/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
index 8d091a00ed4b9..d38010403dad7 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
@@ -61,7 +61,7 @@ define void @umin(i32 noundef %a, i32 noundef %b) {
 ; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
 ; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is i32 2147483646
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
-; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 2
 ;
 ; void umin(unsigned a, unsigned b) {
 ;   a *= 2;
@@ -157,7 +157,7 @@ define void @smin(i32 noundef %a, i32 noundef %b) {
 ; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
 ; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is i32 2147483646
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
-; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
+; CHECK-NEXT:  Loop %for.body: Trip multiple is 2
 ;
 ; void smin(signed a, signed b) {
 ;   a *= 2;

``````````

</details>


https://github.com/llvm/llvm-project/pull/160012


More information about the llvm-commits mailing list