[llvm] [LV] Check full partial reduction chains in order. (PR #168036)

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 18 09:36:13 PST 2025


================
@@ -8016,31 +8019,41 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
 
   // Check if each use of a chain's two extends is a partial reduction
   // and only add those that don't have non-partial reduction users.
-  for (auto Pair : PartialReductionChains) {
-    PartialReductionChain Chain = Pair.first;
-    if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
-        (!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
-      ScaledReductionMap.try_emplace(Chain.Reduction, Pair.second);
+  for (const auto &[_, Chains] : ChainsByPhi) {
+    for (const auto &[Chain, Scale] : Chains) {
+      if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
+          (!Chain.ExtendB ||
+           ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
+        ScaledReductionMap.try_emplace(Chain.Reduction, Scale);
+    }
   }
 
   // Check that all partial reductions in a chain are only used by other
   // partial reductions with the same scale factor. Otherwise we end up creating
   // users of scaled reductions where the types of the other operands don't
   // match.
-  for (const auto &[Chain, Scale] : PartialReductionChains) {
-    auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) {
-      auto *UI = cast<Instruction>(U);
-      if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader()) {
-        return all_of(UI->users(), [ScaleVal, this](const User *U) {
-          auto *UI = cast<Instruction>(U);
-          return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal;
-        });
+  for (const auto &[Phi, Chains] : ChainsByPhi) {
+    for (const auto &[Chain, Scale] : Chains) {
+      auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) {
+        auto *UI = cast<Instruction>(U);
+        if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader()) {
+          return all_of(UI->users(), [ScaleVal, this](const User *U) {
+            auto *UI = cast<Instruction>(U);
+            return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal;
+          });
+        }
+        return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal ||
+               !OrigLoop->contains(UI->getParent());
+      };
----------------
sdesmalen-arm wrote:

This code is quite dizzying, and I'm a bit puzzled why it needs to walk the users() of `Chain.Reduction` and then the users() of each of those, when I would expect this information to already be encoded in the `ChainsByPhi` map itself.

My understanding is that `ChainsByPhi` is a map of `PHI -> [ <chain, scale>, <chain, scale>, ... ]` records, where each `chain` is one operation in that particular reduction chain, defined by its own extend instructions, the extend user (e.g. a `mul` or another operand in the chain), and the reduction op itself (e.g. `add` or `sub`). If so, why would it not be enough to check that all `<chain, scale>` records for a PHI have the same value for `scale`?
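To make the question concrete, here is a minimal, self-contained C++ sketch of the simpler check described above, assuming `ChainsByPhi` has exactly the `PHI -> [ <chain, scale>, ... ]` shape outlined in this comment. `PartialReductionChain` is reduced to hypothetical string fields (the real one holds instruction pointers), and `ChainWithScale`, `ChainsByPhiMap`, `allChainsHaveSameScale` and `everyPhiHasUniformScale` are illustrative names, not LLVM APIs.

```cpp
#include <algorithm>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for PartialReductionChain: the real fields are
// instruction pointers; plain names are used so the sketch compiles alone.
struct PartialReductionChain {
  std::string Reduction; // the reduction op, e.g. an add or sub
  std::string ExtendA;   // first extend feeding the chain
  std::string ExtendB;   // second extend, empty if the chain has only one
};

// One <chain, scale> record.
using ChainWithScale = std::pair<PartialReductionChain, unsigned>;

// Hypothetical model of ChainsByPhi: PHI -> [<chain, scale>, ...].
using ChainsByPhiMap = std::map<std::string, std::vector<ChainWithScale>>;

// True if every record in Chains uses the same scale factor. An empty list is
// trivially uniform: std::all_of never calls the predicate in that case, so
// Chains.front() is only read when the list is non-empty.
bool allChainsHaveSameScale(const std::vector<ChainWithScale> &Chains) {
  return std::all_of(Chains.begin(), Chains.end(),
                     [&Chains](const ChainWithScale &C) {
                       return C.second == Chains.front().second;
                     });
}

// Applies the per-PHI check to every PHI recorded in the map.
bool everyPhiHasUniformScale(const ChainsByPhiMap &ChainsByPhi) {
  return std::all_of(ChainsByPhi.begin(), ChainsByPhi.end(),
                     [](const auto &Entry) {
                       return allChainsHaveSameScale(Entry.second);
                     });
}
```

This sketch only compares the scales recorded in the map. It does not model the users() walk in the patch, which also accepts users outside `OrigLoop`, so whether the map alone carries enough information is precisely what the question above asks.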

https://github.com/llvm/llvm-project/pull/168036

