[llvm] [SCEV] Collect and merge loop guards through PHI nodes with multiple incoming values (PR #113915)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 28 07:52:37 PDT 2024


================
@@ -15596,6 +15603,66 @@ ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
     Terms.emplace_back(LoopEntryPredicate->getCondition(),
                        LoopEntryPredicate->getSuccessor(0) == Pair.second);
   }
+  // Finally, if we stopped climbing the predecessor chain because
+  // there wasn't a unique one to continue, try to collect conditions
+  // for PHINodes by recursively following all of their incoming
+  // blocks and try to merge the found conditions to build a new one
+  // for the Phi.
+  if (Pair.second->hasNPredecessorsOrMore(2)) {
+    for (auto &Phi : Pair.second->phis()) {
+      if (!SE.isSCEVable(Phi.getType()))
+        continue;
+
+      using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
+      auto GetMinMaxConst = [&SE, &VisitedBlocks, &Pair,
+                             &Phi](unsigned int In) -> MinMaxPattern {
+        LoopGuards G(SE);
+        if (VisitedBlocks.insert(Phi.getIncomingBlock(In)).second)
+          collectFromBlock(SE, G, Pair.second, Phi.getIncomingBlock(In),
+                           VisitedBlocks);
+        const SCEV *S = G.RewriteMap[SE.getSCEV(Phi.getIncomingValue(In))];
+        auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S);
+        if (!SM)
+          return {nullptr, scCouldNotCompute};
+        if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
+          return {C0, SM->getSCEVType()};
+        if (const SCEVConstant *C1 = dyn_cast<SCEVConstant>(SM->getOperand(1)))
+          return {C1, SM->getSCEVType()};
+        return {nullptr, scCouldNotCompute};
+      };
+      auto MergeMinMaxConst = [](MinMaxPattern P1,
+                                 MinMaxPattern P2) -> MinMaxPattern {
+        auto [C1, T1] = P1;
+        auto [C2, T2] = P2;
+        if (!C1 || !C2 || T1 != T2)
+          return {nullptr, scCouldNotCompute};
+        switch (T1) {
+        case scUMaxExpr:
+          return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
+        case scSMaxExpr:
+          return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
+        case scUMinExpr:
+          return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
+        case scSMinExpr:
+          return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
----------------
fhahn wrote:

We should have tests for all min/max expression kinds (umax, smax, umin, smin).

Could you also add some negative tests?

https://github.com/llvm/llvm-project/pull/113915


More information about the llvm-commits mailing list