[llvm] Handle VECREDUCE intrinsics in NVPTX backend (PR #136253)

Princeton Ferro via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 18 17:28:59 PDT 2025


================
@@ -2128,6 +2152,194 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+/// A generic routine for constructing a tree reduction on a vector operand.
+/// This method differs from iterative splitting in DAGTypeLegalizer by
+/// progressively grouping elements bottom-up.
+static SDValue BuildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  // now build the computation graph in place at each level
+  SmallVector<SDValue> Level = Elements;
+  unsigned OpIdx = 0;
+  while (Level.size() > 1) {
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // partially reduce all elements in level
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      if (ReducedLevel.empty()) {
+        // The current operator requires more inputs than there are operands at
+        // this level. Pick a smaller operator and retry.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // Otherwise, we just have a remainder, which we push to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociations are allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector;
+  SDValue Accumulator;
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // special case with accumulator as first arg
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // default case
+    Vector = Op.getOperand(0);
+  }
+  EVT EltTy = Vector.getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+  bool IsReassociatable;
+
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
+    ScalarOps = {{ISD::FADD, 2}};
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
+    ScalarOps = {{ISD::FMUL, 2}};
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMAX:
----------------
Prince781 wrote:

It's because of quirks in how NaNs are handled. Here's a table that shows what happens according to the IEEE 754 standard (the case for `min` is analogous):

Operation | `f(a, qNaN)` | `f(a, sNaN)` | `f(+0.0, -0.0)`
----------|--------------|--------------|----------------
`maxNum`  | `a`          | `qNaN`       | either (unspecified)
`maximum` | `qNaN`       | `qNaN`       | `+0.0`

- `a` is a valid number
- There are two types of NaNs: `qNaN` means "quiet NaN", while `sNaN` means "signaling NaN."
- `NaN` used here means either kind of NaN when both operands are the same.
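- `maxNum` is defined in IEEE 754-2008[^1]. It is the semantics of `ISD::FMAXNUM` / `@llvm.maxnum`, and therefore of `VECREDUCE_FMAX` / `@llvm.vector.reduce.fmax`.
- `maximum` is defined in IEEE 754-2019[^3]. It is the semantics of `ISD::FMAXIMUM` / `@llvm.maximum`, and therefore of `VECREDUCE_FMAXIMUM` / `@llvm.vector.reduce.fmaximum`.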

Example of non-associativity[^2]:

```
a = 1.0, b = 1.0, c = sNaN
maxNum(a, maxNum(b, c)) = maxNum(a, qNaN) = a = 1.0
maxNum(maxNum(a, b), c) = maxNum(1.0, sNaN) = qNaN
```

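`maximum` can be reassociated, because it returns `qNaN` whenever either operand is a NaN of either kind, so the grouping of operands cannot change the result.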
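
To make the table concrete, here is a small symbolic model of the two operations in C++. This is only a sketch: `Kind`, `Val`, `maxNum`, `maximum`, and `exampleIsNonAssociative` are hypothetical names for illustration and are not part of LLVM or this patch. NaN kinds are tracked as tags rather than bit patterns, and signed zeros are not modeled.

```cpp
#include <cassert>

// Hypothetical symbolic model: a value is a number, a quiet NaN, or a
// signaling NaN. Signed zeros are intentionally not modeled.
enum class Kind { Num, QNaN, SNaN };

struct Val {
  Kind kind;
  double num; // Only meaningful when kind == Kind::Num.
};

// IEEE 754-2008 maxNum, following the table above:
//   any sNaN operand  -> qNaN
//   exactly one qNaN  -> the other operand
//   both qNaN         -> qNaN
Val maxNum(Val a, Val b) {
  if (a.kind == Kind::SNaN || b.kind == Kind::SNaN)
    return {Kind::QNaN, 0.0};
  if (a.kind == Kind::QNaN)
    return b;
  if (b.kind == Kind::QNaN)
    return a;
  return a.num >= b.num ? a : b;
}

// IEEE 754-2019 maximum: any NaN operand (quiet or signaling) -> qNaN.
Val maximum(Val a, Val b) {
  if (a.kind != Kind::Num || b.kind != Kind::Num)
    return {Kind::QNaN, 0.0};
  return a.num >= b.num ? a : b;
}

// Replays the example above: a = 1.0, b = 1.0, c = sNaN.
// Returns true when the two groupings of maxNum disagree.
bool exampleIsNonAssociative() {
  const Val a{Kind::Num, 1.0}, b{Kind::Num, 1.0}, c{Kind::SNaN, 0.0};
  const Val right = maxNum(a, maxNum(b, c)); // maxNum(a, qNaN) -> a
  const Val left = maxNum(maxNum(a, b), c);  // maxNum(1.0, sNaN) -> qNaN
  return right.kind != left.kind;
}
```

In this model, both groupings of `maximum` on the same inputs yield `qNaN`, while the two groupings of `maxNum` disagree, which is why only the former is safe to lower as a tree.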

[^1]: http://www.dsc.ufcg.edu.br/~cnum/modulos/Modulo2/IEEE754_2008.pdf
[^2]: https://grouper.ieee.org/groups/msc/ANSI_IEEE-Std-754-2019/background/minNum_maxNum_Removal_Demotion_v3.pdf
[^3]: https://www-users.cse.umn.edu/~vinals/tspot_files/phys4041/2020/IEEE%20Standard%20754-2019.pdf

https://github.com/llvm/llvm-project/pull/136253
