[llvm] [DAG] Fold nested add(add(reduce(a), b), add(reduce(c), d)) (PR #115150)

David Sherwood via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 25 08:57:05 PST 2024


================
@@ -1329,6 +1329,38 @@ SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
                        DAG.getNode(Opc, DL, N0.getOperand(0).getValueType(),
                                    N0.getOperand(0), N1.getOperand(0)));
   }
+
+  // Reassociate op(op(vecreduce(a), b), op(vecreduce(c), d)) into
+  // op(vecreduce(op(a, c)), op(b, d)), to combine the reductions into a
+  // single node.
+  SDValue A, B, C, D, RedA, RedB;
+  if (sd_match(N0, m_OneUse(m_c_BinOp(
+                       Opc,
+                       m_AllOf(m_OneUse(m_UnaryOp(RedOpc, m_Value(A))),
+                               m_Value(RedA)),
+                       m_Value(B)))) &&
+      sd_match(N1, m_OneUse(m_c_BinOp(
+                       Opc,
+                       m_AllOf(m_OneUse(m_UnaryOp(RedOpc, m_Value(C))),
+                               m_Value(RedB)),
+                       m_Value(D)))) &&
+      !sd_match(B, m_UnaryOp(RedOpc, m_Value())) &&
+      !sd_match(D, m_UnaryOp(RedOpc, m_Value())) &&
+      A.getValueType() == C.getValueType() &&
+      hasOperation(Opc, A.getValueType()) &&
+      TLI.shouldReassociateReduction(RedOpc, VT)) {
+    if ((Opc == ISD::FADD || Opc == ISD::FMUL) &&
+        (!N0->getFlags().hasAllowReassociation() ||
+         !N1->getFlags().hasAllowReassociation() ||
+         !RedA->getFlags().hasAllowReassociation() ||
+         !RedB->getFlags().hasAllowReassociation()))
+      return SDValue();
+    SelectionDAG::FlagInserter FlagsInserter(DAG, Flags);
+    SDValue Op = DAG.getNode(Opc, DL, A.getValueType(), A, C);
+    SDValue Red = DAG.getNode(RedOpc, DL, VT, Op);
----------------
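The reassociation in the hunk above can be checked with a minimal standalone C++ sketch that works on plain arrays instead of DAG nodes. `reduceAdd`, `beforeFold` and `afterFold` are hypothetical helpers, not LLVM APIs. The sketch assumes a left-to-right summation order for the reduction, whereas the real `VECREDUCE_FADD` node leaves the order unspecified.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>

// Hypothetical stand-in for a 4-lane VECREDUCE_ADD / VECREDUCE_FADD:
// a left-to-right sum starting from zero (an assumed order).
template <typename T>
T reduceAdd(const std::array<T, 4> &v) {
  return std::accumulate(v.begin(), v.end(), T{0});
}

// Original shape: add(add(reduce(a), b), add(reduce(c), d)), two reductions.
template <typename T>
T beforeFold(const std::array<T, 4> &a, T b, const std::array<T, 4> &c, T d) {
  return (reduceAdd(a) + b) + (reduceAdd(c) + d);
}

// Folded shape: add(reduce(add(a, c)), add(b, d)), a single reduction.
template <typename T>
T afterFold(const std::array<T, 4> &a, T b, const std::array<T, 4> &c, T d) {
  std::array<T, 4> ac{};
  for (std::size_t i = 0; i < ac.size(); ++i)
    ac[i] = a[i] + c[i]; // lane-wise add(a, c)
  return reduceAdd(ac) + (b + d);
}
```

With wrapping `uint32_t` lanes, both shapes always agree, which is why the integer `add` case needs no extra check. With `float` lanes they can differ. For example, `a = {1e20f, 1, 0, 0}`, `c = {-1e20f, 0, 0, 0}`, `b = d = 0` gives `0.0f` before the fold and `1.0f` after it. That is why the `FADD`/`FMUL` path bails out unless every node involved allows reassociation.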
david-arm wrote:

Sorry to be picky, but is there still a problem here with flags for FP operations? Perhaps I'm just being overly cautious, but I'm thinking of an example where the flags on the top-level `fadd` differ from those on the two individual reductions (or on the inner `fadd`s). Because the new nodes are created under the `FlagInserter` using `Flags`, we could effectively end up promoting the reduction to flags that it didn't have previously. I must admit I'm a bit unsure what happens in general for DAG combines in cases like this:

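```
define float @test(<4 x float> %a, float %c, <4 x float> %b, float %d) {
  %r1 = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.0, <4 x float> %a)
  %a1 = fadd fast float %r1, %c
  %r2 = call reassoc float @llvm.vector.reduce.fadd.v4f32(float -0.0, <4 x float> %b)
  %a2 = fadd reassoc float %r2, %d
  %r = fadd fast float %a1, %a2
  ret float %r
}

declare float @llvm.vector.reduce.fadd.v4f32(float, <4 x float>)
```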
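The concern can be made concrete with a small standalone model of fast-math flags as a bitmask. This is a hypothetical sketch, not LLVM's `SDNodeFlags` API. It assumes, as the `FlagInserter` line in the hunk suggests, that nodes built by the fold receive the top-level `Flags`. For the IR above, the existing guard passes because all four checked nodes carry `reassoc`. Relative to `%r2`, however, the new reduction would gain `nnan`, `ninf`, `nsz`, `arcp`, `contract` and `afn`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical bitmask mirroring LLVM's seven fast-math flags; `fast` sets all.
constexpr uint8_t Reassoc = 1u << 0;
constexpr uint8_t NNaN = 1u << 1;
constexpr uint8_t NInf = 1u << 2;
constexpr uint8_t NSZ = 1u << 3;
constexpr uint8_t ARcp = 1u << 4;
constexpr uint8_t Contract = 1u << 5;
constexpr uint8_t AFn = 1u << 6;
constexpr uint8_t Fast = Reassoc | NNaN | NInf | NSZ | ARcp | Contract | AFn;

// Flags from the IR example: N0 = %a1, N1 = %a2, RedA = %r1, RedB = %r2,
// and the top-level %r supplies Flags.
struct ExampleFlags {
  uint8_t n0 = Fast, n1 = Reassoc, redA = Fast, redB = Reassoc, top = Fast;
};

// Models the FADD/FMUL guard in the hunk: only the reassoc bit of N0, N1,
// RedA and RedB is checked; their other bits are never compared with Flags.
bool currentGuardAccepts(const ExampleFlags &f) {
  return (f.n0 & Reassoc) && (f.n1 & Reassoc) && (f.redA & Reassoc) &&
         (f.redB & Reassoc);
}

// Bits the new reduction would carry that an original reduction lacked.
uint8_t promotedBits(uint8_t original, uint8_t top) {
  return static_cast<uint8_t>(top & ~original);
}
```

Nothing is gained relative to `%r1`, but relative to `%r2` every bit except `reassoc` is new.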

Might it be simpler to require all nested operations to have the same flags as the `Flags` passed in if that solves the problem you care about?
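The suggested stricter guard could look roughly like this. It is again a standalone sketch over a hypothetical bitmask rather than LLVM's real `SDNodeFlags`, and `sameFlagsGuard` is an illustrative name, not an existing function.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical bitmask mirroring LLVM's seven fast-math flags; `fast` sets all.
constexpr uint8_t Reassoc = 1u << 0;
constexpr uint8_t NNaN = 1u << 1;
constexpr uint8_t NInf = 1u << 2;
constexpr uint8_t NSZ = 1u << 3;
constexpr uint8_t ARcp = 1u << 4;
constexpr uint8_t Contract = 1u << 5;
constexpr uint8_t AFn = 1u << 6;
constexpr uint8_t Fast = Reassoc | NNaN | NInf | NSZ | ARcp | Contract | AFn;

// Fold only if N0, N1, RedA and RedB carry exactly the top-level Flags and
// those Flags allow reassociation. Because every node must equal Flags,
// checking Reassoc once on Flags covers the four per-node checks in the hunk.
bool sameFlagsGuard(uint8_t top, uint8_t n0, uint8_t n1, uint8_t redA,
                    uint8_t redB) {
  if ((top & Reassoc) == 0)
    return false;
  return n0 == top && n1 == top && redA == top && redB == top;
}
```

Requiring equality is conservative. It also rejects trees whose inner nodes carry more flags than the top-level node, where the fold would not actually promote anything. In exchange, it guarantees that nodes created under the `FlagInserter` never claim a flag that one of the original nodes lacked.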

https://github.com/llvm/llvm-project/pull/115150

