[llvm] [DAG] Fold nested add(add(reduce(a), b), add(reduce(c), d)) (PR #115150)
David Sherwood via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 25 08:57:05 PST 2024
================
@@ -1329,6 +1329,38 @@ SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
DAG.getNode(Opc, DL, N0.getOperand(0).getValueType(),
N0.getOperand(0), N1.getOperand(0)));
}
+
+ // Reassociate op(op(vecreduce(a), b), op(vecreduce(c), d)) into
+ // op(vecreduce(op(a, c)), op(b, d)), to combine the reductions into a
+ // single node.
+ SDValue A, B, C, D, RedA, RedB;
+ if (sd_match(N0, m_OneUse(m_c_BinOp(
+ Opc,
+ m_AllOf(m_OneUse(m_UnaryOp(RedOpc, m_Value(A))),
+ m_Value(RedA)),
+ m_Value(B)))) &&
+ sd_match(N1, m_OneUse(m_c_BinOp(
+ Opc,
+ m_AllOf(m_OneUse(m_UnaryOp(RedOpc, m_Value(C))),
+ m_Value(RedB)),
+ m_Value(D)))) &&
+ !sd_match(B, m_UnaryOp(RedOpc, m_Value())) &&
+ !sd_match(D, m_UnaryOp(RedOpc, m_Value())) &&
+ A.getValueType() == C.getValueType() &&
+ hasOperation(Opc, A.getValueType()) &&
+ TLI.shouldReassociateReduction(RedOpc, VT)) {
+ if ((Opc == ISD::FADD || Opc == ISD::FMUL) &&
+ (!N0->getFlags().hasAllowReassociation() ||
+ !N1->getFlags().hasAllowReassociation() ||
+ !RedA->getFlags().hasAllowReassociation() ||
+ !RedB->getFlags().hasAllowReassociation()))
+ return SDValue();
+ SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
+ SDValue Op = DAG.getNode(Opc, DL, A.getValueType(), A, C);
+ SDValue Red = DAG.getNode(RedOpc, DL, VT, Op);
----------------
david-arm wrote:
Sorry to be picky, but is there still a problem here with flags for FP operations? Perhaps I'm just being overly cautious, but I'm thinking of an example where the flags on the top-level `fadd` differ from those on the two individual reductions or on the inner `fadd`s. In that case we could effectively end up promoting the reduction to have flags that it didn't have previously. I must admit I'm a bit unsure about what happens in general for DAG combines in cases like this:
```
%r1 = call fast float @llvm.vector.reduce.fadd.f32.v4f32(float -0.0, <4 x float> %a)
%a1 = fadd fast float %r1, %c
%r2 = call reassoc float @llvm.vector.reduce.fadd.f32.v4f32(float -0.0, <4 x float> %b)
%a2 = fadd reassoc float %r2, %d
%r = fadd fast float %a1, %a2
```
Might it be simpler to require all nested operations to have the same flags as the `Flags` passed in if that solves the problem you care about?
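
To make the concern concrete, here is a minimal standalone sketch of the two options (intersecting the flags versus requiring them all to match). It is a toy model only: `FMF`, `FlagSet`, `intersectFlags` and `allSameFlags` are hypothetical names invented for illustration, and it does not use LLVM's `SDNodeFlags` or any DAGCombiner API.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

// Toy stand-in for fast-math flags (hypothetical, NOT LLVM's SDNodeFlags).
enum FMF : std::uint8_t {
  NoNaNs = 1 << 0,
  NoInfs = 1 << 1,
  NoSignedZeros = 1 << 2,
  AllowReciprocal = 1 << 3,
  AllowContract = 1 << 4,
  ApproxFuncs = 1 << 5,
  AllowReassoc = 1 << 6,
  Fast = 0x7F, // all of the above, as `fast` implies in IR
};

using FlagSet = std::uint8_t;

// Option 1: keep only the flags that every participating node carries,
// so the combined node never gains a flag that some input lacked.
inline FlagSet intersectFlags(std::initializer_list<FlagSet> Nodes) {
  FlagSet Result = Fast;
  for (FlagSet F : Nodes)
    Result &= F;
  return Result;
}

// Option 2: only allow the fold when every nested node has exactly the
// same flags as the ones passed in.
inline bool allSameFlags(FlagSet Expected,
                         std::initializer_list<FlagSet> Nodes) {
  for (FlagSet F : Nodes)
    if (F != Expected)
      return false;
  return true;
}

// Models the IR example above: %r1/%a1/%r are `fast`, %r2/%a2 are `reassoc`.
inline void checkReviewExample() {
  FlagSet R1 = Fast, A1 = Fast, R2 = AllowReassoc, A2 = AllowReassoc;
  FlagSet Top = Fast;
  // Every node allows reassociation, so a reassoc-only check would pass...
  assert((intersectFlags({Top, A1, A2, R1, R2}) & AllowReassoc) != 0);
  // ...but only `reassoc` is common to all five nodes, not full `fast`.
  assert(intersectFlags({Top, A1, A2, R1, R2}) == AllowReassoc);
  // Requiring identical flags would simply reject this case.
  assert(!allSameFlags(Top, {A1, A2, R1, R2}));
}
```

In this model, reusing the top-level flags for the merged reduction would give it `nnan`, `ninf` and the rest, which `%r2` never had. Intersecting keeps only `reassoc`, while the stricter equality check skips the fold entirely.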
https://github.com/llvm/llvm-project/pull/115150