[llvm] [LV] Check full partial reduction chains in order. (PR #168036)
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 18 09:36:13 PST 2025
================
@@ -8016,31 +8019,41 @@ void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
// Check if each use of a chain's two extends is a partial reduction
// and only add those that don't have non-partial reduction users.
- for (auto Pair : PartialReductionChains) {
- PartialReductionChain Chain = Pair.first;
- if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
- (!Chain.ExtendB || ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
- ScaledReductionMap.try_emplace(Chain.Reduction, Pair.second);
+ for (const auto &[_, Chains] : ChainsByPhi) {
+ for (const auto &[Chain, Scale] : Chains) {
+ if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
+ (!Chain.ExtendB ||
+ ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
+ ScaledReductionMap.try_emplace(Chain.Reduction, Scale);
+ }
}
// Check that all partial reductions in a chain are only used by other
// partial reductions with the same scale factor. Otherwise we end up creating
// users of scaled reductions where the types of the other operands don't
// match.
- for (const auto &[Chain, Scale] : PartialReductionChains) {
- auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) {
- auto *UI = cast<Instruction>(U);
- if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader()) {
- return all_of(UI->users(), [ScaleVal, this](const User *U) {
- auto *UI = cast<Instruction>(U);
- return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal;
- });
+ for (const auto &[Phi, Chains] : ChainsByPhi) {
+ for (const auto &[Chain, Scale] : Chains) {
+ auto AllUsersPartialRdx = [ScaleVal = Scale, this](const User *U) {
+ auto *UI = cast<Instruction>(U);
+ if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader()) {
+ return all_of(UI->users(), [ScaleVal, this](const User *U) {
+ auto *UI = cast<Instruction>(U);
+ return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal;
+ });
+ }
+ return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal ||
+ !OrigLoop->contains(UI->getParent());
+ };
----------------
sdesmalen-arm wrote:
This code is quite dizzying, and I'm a bit puzzled why it needs to go through the users() of `Chain.Reduction` and then through the users() of each of those users, when I would think this information is already encoded in the `ChainsByPhi` map itself.
My understanding is that `ChainsByPhi` is a map of `PHI -> [ <chain, scale>, <chain, scale>, ... ]` records, where each `chain` is one operation in that particular reduction chain, defined by its own extend instructions, the extend user (e.g. a `mul` or another operand in the chain), and the reduction op itself (e.g. `add` or `sub`). If so, why would it not be enough to check that all `<chain, scale>` records have the same value for `scale`?
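A minimal sketch of the check suggested above, assuming `ChainsByPhi` maps each header phi to its `<chain, scale>` records as described, and reading the suggestion as "reject every chain of a phi whose records disagree on the scale". `ChainSketch`, `ChainsByPhiSketch`, `allChainsHaveSameScale`, and `collectUniformlyScaled` are hypothetical names for illustration, not LLVM APIs; each chain is reduced to the name of its reduction instruction, and the extend-user checks from the first loop are omitted. It does not model the existing users()-based check, so whether the two are equivalent remains the open question of this review.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for PartialReductionChain: only the reduction's name
// is kept; the extends and the extend user are omitted from this sketch.
struct ChainSketch {
  std::string Reduction;
};

// Hypothetical model of ChainsByPhi: phi name -> [(chain, scale), ...].
using ChainRecords = std::vector<std::pair<ChainSketch, unsigned>>;
using ChainsByPhiSketch = std::map<std::string, ChainRecords>;

// The suggested test: every <chain, scale> record for one phi must carry the
// same scale factor. An empty list trivially passes.
bool allChainsHaveSameScale(const ChainRecords &Chains) {
  if (Chains.empty())
    return true;
  const unsigned First = Chains.front().second;
  return std::all_of(Chains.begin(), Chains.end(),
                     [First](const auto &Entry) { return Entry.second == First; });
}

// Builds a reduction -> scale map (mirroring ScaledReductionMap), accepting a
// phi's chains only when they all agree on the scale factor.
std::map<std::string, unsigned>
collectUniformlyScaled(const ChainsByPhiSketch &ChainsByPhi) {
  std::map<std::string, unsigned> Result;
  for (const auto &[Phi, Chains] : ChainsByPhi) {
    if (!allChainsHaveSameScale(Chains))
      continue; // Mixed scales: reject every chain belonging to this phi.
    for (const auto &[Chain, Scale] : Chains) {
      assert(Scale > 0 && "scale factors are expected to be positive");
      Result.emplace(Chain.Reduction, Scale);
    }
  }
  return Result;
}
```

With this shape the decision is made once per phi from the records alone, without walking users of users; a phi whose chains use, say, scales 4 and 2 contributes nothing to the result.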
https://github.com/llvm/llvm-project/pull/168036