[llvm] Handle VECREDUCE intrinsics in NVPTX backend (PR #136253)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 17 21:36:48 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Princeton Ferro (Prince781)

<details>
<summary>Changes</summary>

Lower VECREDUCE intrinsics to tree reductions when reassociation is allowed, and add support for `min3`/`max3` from PTX ISA 8.8 (to be released next week).

---

Patch is 94.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136253.diff


6 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTX.td (+8-2) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+230) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+6) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+54) 
- (modified) llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h (+2) 
- (added) llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll (+1908) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTX.td b/llvm/lib/Target/NVPTX/NVPTX.td
index 5467ae011a208..d4dc278cfa648 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.td
+++ b/llvm/lib/Target/NVPTX/NVPTX.td
@@ -36,17 +36,19 @@ class FeaturePTX<int version>:
 
 foreach sm = [20, 21, 30, 32, 35, 37, 50, 52, 53,
               60, 61, 62, 70, 72, 75, 80, 86, 87,
-              89, 90, 100, 101, 120] in
+              89, 90, 100, 101, 103, 120, 121] in
   def SM#sm: FeatureSM<""#sm, !mul(sm, 10)>;
 
 def SM90a: FeatureSM<"90a", 901>;
 def SM100a: FeatureSM<"100a", 1001>;
 def SM101a: FeatureSM<"101a", 1011>;
+def SM103a: FeatureSM<"103a", 1031>;
 def SM120a: FeatureSM<"120a", 1201>;
+def SM121a: FeatureSM<"121a", 1211>;
 
 foreach version = [32, 40, 41, 42, 43, 50, 60, 61, 62, 63, 64, 65,
                    70, 71, 72, 73, 74, 75, 76, 77, 78,
-                   80, 81, 82, 83, 84, 85, 86, 87] in
+                   80, 81, 82, 83, 84, 85, 86, 87, 88] in
   def PTX#version: FeaturePTX<version>;
 
 //===----------------------------------------------------------------------===//
@@ -81,8 +83,12 @@ def : Proc<"sm_100", [SM100, PTX86]>;
 def : Proc<"sm_100a", [SM100a, PTX86]>;
 def : Proc<"sm_101", [SM101, PTX86]>;
 def : Proc<"sm_101a", [SM101a, PTX86]>;
+def : Proc<"sm_103", [SM103, PTX88]>;
+def : Proc<"sm_103a", [SM103a, PTX88]>;
 def : Proc<"sm_120", [SM120, PTX87]>;
 def : Proc<"sm_120a", [SM120a, PTX87]>;
+def : Proc<"sm_121", [SM121, PTX88]>;
+def : Proc<"sm_121a", [SM121a, PTX88]>;
 
 def NVPTXInstrInfo : InstrInfo {
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 9bde2a976e164..3a14aa47e5065 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -831,6 +831,26 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Vector reduction operations. These are transformed into a tree evaluation
+  // of nodes which may initially be illegal.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    MVT EltVT = VT.getVectorElementType();
+    if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
+        EltVT == MVT::f64) {
+      setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                          ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
+                          ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                          ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                         VT, Custom);
+    } else if (EltVT.isScalarInteger()) {
+      setOperationAction(
+          {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
+           ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
+           ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
+          VT, Custom);
+    }
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -1082,6 +1102,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2136,6 +2160,194 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+/// A generic routine for constructing a tree reduction on a vector operand.
+/// This method differs from iterative splitting in DAGTypeLegalizer by
+/// progressively grouping elements bottom-up.
+static SDValue BuildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  // now build the computation graph in place at each level
+  SmallVector<SDValue> Level = Elements;
+  unsigned OpIdx = 0;
+  while (Level.size() > 1) {
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // partially reduce all elements in level
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      if (ReducedLevel.empty()) {
+        // The current operator requires more inputs than there are operands at
+        // this level. Pick a smaller operator and retry.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // Otherwise, we just have a remainder, which we push to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociations are allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector;
+  SDValue Accumulator;
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // special case with accumulator as first arg
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // default case
+    Vector = Op.getOperand(0);
+  }
+  EVT EltTy = Vector.getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+  bool IsReassociatable;
+
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
+    ScalarOps = {{ISD::FADD, 2}};
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
+    ScalarOps = {{ISD::FMUL, 2}};
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
+    ScalarOps.push_back({ISD::FMAXNUM, 2});
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
+    ScalarOps.push_back({ISD::FMINNUM, 2});
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
+    ScalarOps.push_back({ISD::FMAXIMUM, 2});
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
+    ScalarOps.push_back({ISD::FMINIMUM, 2});
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_ADD:
+    ScalarOps = {{ISD::ADD, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_MUL:
+    ScalarOps = {{ISD::MUL, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_UMAX:
+    ScalarOps = {{ISD::UMAX, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_UMIN:
+    ScalarOps = {{ISD::UMIN, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_SMAX:
+    ScalarOps = {{ISD::SMAX, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_SMIN:
+    ScalarOps = {{ISD::SMIN, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_AND:
+    ScalarOps = {{ISD::AND, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_OR:
+    ScalarOps = {{ISD::OR, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_XOR:
+    ScalarOps = {{ISD::XOR, 2}};
+    IsReassociatable = true;
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  EVT VectorTy = Vector.getValueType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+
+  // scalarize vector
+  SmallVector<SDValue> Elements(NumElts);
+  for (unsigned I = 0, E = NumElts; I != E; ++I) {
+    Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
+                              DAG.getConstant(I, DL, MVT::i64));
+  }
+
+  // Lower to tree reduction.
+  if (IsReassociatable || Flags.hasAllowReassociation()) {
+    // we don't expect an accumulator for reassociatable vector reduction ops
+    assert(!Accumulator && "unexpected accumulator");
+    return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+  }
+
+  // Lower to sequential reduction.
+  for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
+    assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
+    const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
+
+    if (!Accumulator) {
+      if (I + DefaultGroupSize <= NumElts) {
+        Accumulator = DAG.getNode(
+            DefaultScalarOp, DL, EltTy,
+            ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
+        I += DefaultGroupSize;
+      }
+    }
+
+    if (Accumulator) {
+      for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
+        SmallVector<SDValue> Operands = {Accumulator};
+        for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
+          Operands.push_back(Elements[I + K]);
+        Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
+      }
+    }
+  }
+
+  return Accumulator;
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
@@ -2879,6 +3091,24 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FADD:
+  case ISD::VECREDUCE_SEQ_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+  case ISD::VECREDUCE_ADD:
+  case ISD::VECREDUCE_MUL:
+  case ISD::VECREDUCE_UMAX:
+  case ISD::VECREDUCE_UMIN:
+  case ISD::VECREDUCE_SMAX:
+  case ISD::VECREDUCE_SMIN:
+  case ISD::VECREDUCE_AND:
+  case ISD::VECREDUCE_OR:
+  case ISD::VECREDUCE_XOR:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index dd90746f6d9d6..0880f3e8edfe8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -73,6 +73,11 @@ enum NodeType : unsigned {
   UNPACK_VECTOR,
 
   FCOPYSIGN,
+  FMAXNUM3,
+  FMINNUM3,
+  FMAXIMUM3,
+  FMINIMUM3,
+
   DYNAMIC_STACKALLOC,
   STACKRESTORE,
   STACKSAVE,
@@ -296,6 +301,7 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 16b489afddf5c..721b578ba72ee 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -368,6 +368,46 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
                Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
 }
 
+// 3-input min/max (sm_100+) for f32 only
+multiclass FMINIMUMMAXIMUM3<string OpcStr, SDNode OpNode> {
+   def f32rrr_ftz :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+               !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+               Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+   def f32rri_ftz :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+               !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, f32:$b, fpimm:$c))]>,
+               Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+   def f32rii_ftz :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+               !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+               Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+   def f32rrr :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+               !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+               Requires<[hasPTX<88>, hasSM<100>]>;
+   def f32rri :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+               !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, Float32Regs:$b, fpimm:$c))]>,
+               Requires<[hasPTX<88>, hasSM<100>]>;
+   def f32rii :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+               !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+               Requires<[hasPTX<88>, hasSM<100>]>;
+}
+
 // Template for instructions which take three FP args.  The
 // instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
 //
@@ -1101,6 +1141,20 @@ defm FMAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
 defm FMINNAN : FMINIMUMMAXIMUM<"min.NaN", /* NaN */ true, fminimum>;
 defm FMAXNAN : FMINIMUMMAXIMUM<"max.NaN", /* NaN */ true, fmaximum>;
 
+def nvptx_fminnum3 : SDNode<"NVPTXISD::FMINNUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaxnum3 : SDNode<"NVPTXISD::FMAXNUM3", SDTFPTernaryOp,
+                             [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fminimum3 : SDNode<"NVPTXISD::FMINIMUM3", SDTFPTernaryOp,
+                             [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaximum3 : SDNode<"NVPTXISD::FMAXIMUM3", SDTFPTernaryOp,
+                             [SDNPCommutative, SDNPAssociative]>;
+
+defm FMIN3 : FMINIMUMMAXIMUM3<"min", nvptx_fminnum3>;
+defm FMAX3 : FMINIMUMMAXIMUM3<"max", nvptx_fmaxnum3>;
+defm FMINNAN3 : FMINIMUMMAXIMUM3<"min.NaN", nvptx_fminimum3>;
+defm FMAXNAN3 : FMINIMUMMAXIMUM3<"max.NaN", nvptx_fmaximum3>;
+
 defm FABS  : F2<"abs", fabs>;
 defm FNEG  : F2<"neg", fneg>;
 defm FABS_H: F2_Support_Half<"abs", fabs>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 9e77f628da7a7..2cc81a064152f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -83,6 +83,8 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
   }
   unsigned getMinVectorRegisterBitWidth() const { return 32; }
 
+  bool shouldExpandReduction(const IntrinsicInst *II) const { return false; }
+
   // We don't want to prevent inlining because of target-cpu and -features
   // attributes that were added to newer versions of LLVM/Clang: There are
   // no incompatible functions in PTX, ptxas will throw errors in such cases.
diff --git a/llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll b/llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll
new file mode 100644
index 0000000000000..a9101ba3ca651
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll
@@ -0,0 +1,1908 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --extra_scrub --version 5
+; RUN: llc < %s -mcpu=sm_80 -mattr=+ptx70 -O0 \
+; RUN:      -disable-post-ra -verify-machineinstrs \
+; RUN: | FileCheck -check-prefixes CHECK,CHECK-SM80 %s
+; RUN: %if ptxas-12.9 %{ llc < %s -mcpu=sm_80 -mattr=+ptx70 -O0 \
+; RUN:      -disable-post-ra -verify-machineinstrs \
+; RUN: | %ptxas-verify -arch=sm_80 %}
+; RUN: llc < %s -mcpu=sm_100 -mattr=+ptx88 -O0 \
+; RUN:      -disable-post-ra -verify-machineinstrs \
+; RUN: | FileCheck -check-prefixes CHECK,CHECK-SM100 %s
+; RUN: %if ptxas-12.9 %{ llc < %s -mcpu=sm_100 -mattr=+ptx88 -O0 \
+; RUN:      -disable-post-ra -verify-machineinstrs \
+; RUN: | %ptxas-verify -arch=sm_100 %}
+target triple = "nvptx64-nvidia-cuda"
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+
+; Check straight line reduction.
+define half @reduce_fadd_half(<8 x half> %in) {
+; CHECK-LABEL: reduce_fadd_half(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<18>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_param_0];
+; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r1;
+; CHECK-NEXT:    mov.b16 %rs3, 0x0000;
+; CHECK-NEXT:    add.rn.f16 %rs4, %rs1, %rs3;
+; CHECK-NEXT:    add.rn.f16 %rs5, %rs4, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs6, %rs7}, %r2;
+; CHECK-NEXT:    add.rn.f16 %rs8, %rs5, %rs6;
+; CHECK-NEXT:    add.rn.f16 %rs9, %rs8, %rs7;
+; CHECK-NEXT:    mov.b32 {%rs10, %rs11}, %r3;
+; CHECK-NEXT:    add.rn.f16 %rs12, %rs9, %rs10;
+; CHECK-NEXT:    add.rn.f16 %rs13, %rs12, %rs11;
+; CHECK-NEXT:    mov.b32 {%rs14, %rs15}, %r4;
+; CHECK-NEXT:    add.rn.f16 %rs16, %rs13, %rs14;
+; CHECK-NEXT:    add.rn.f16 %rs17, %rs16, %rs15;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs17;
+; CHECK-NEXT:    ret;
+  %res = call half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
+  ret half %res
+}
+
+; Check tree reduction.
+define half @reduce_fadd_half_reassoc(<8 x half> %in) {
+; CHECK-LABEL: reduce_fadd_half_reassoc(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<18>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_reassoc_param_0];
+; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
+; CHECK-NEXT:    add.rn.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs4, %rs5}, %r3;
+; CHECK-NEXT:    add.rn.f16 %rs6, %rs4, %rs5;
+; CHECK-NEXT:    add.rn.f16 %rs7, %rs6, %rs3;
+; CHECK-NEXT:    mov.b32 {%rs8, %rs9}, %r2;
+; CHECK-NEXT:    add.rn.f16 %rs10, %rs8, %rs9;
+; CHECK-NEXT:    mov.b32 {%rs11, %rs12}, %r1;
+; CHECK-NEXT:    add.rn.f16 %rs13, %rs11, %rs12;
+; CHECK-NEXT:    add.rn.f16 %rs14, %rs13, %rs10;
+; CHECK-NEXT:    add.rn.f16 %rs15, %rs14, %rs7;
+; CHECK-NEXT:    mov.b16 %rs16, 0x0000;
+; CHECK-NEXT:    add.rn.f16 %rs17, %rs15, %rs16;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs17;
+; CHECK-NEXT:    ret;
+  %res = call reassoc half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
+  ret half %res
+}
+
+; Check tree reduction with non-power of 2 size.
+define half @reduce_fadd_half_reassoc_nonpow2(<7 x half> %in) {
+; CHECK-LABEL: reduce_fadd_half_reassoc_nonpow2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<16>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [reduce_fadd_half_reassoc_nonpow2_param_0+8];
+; CHECK-NEXT:    mov.b32 {%rs5, %rs6}, %r1;
+; CHECK-NEXT:    ld.param.b16 %rs7, [reduce_fadd_half_reassoc_nonpow2_param_0+12];
+; CHECK-NEXT:    ld.param....
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/136253


More information about the llvm-commits mailing list