[llvm] [NVPTX] Add 3-operand fmin/fmax DAGCombines (PR #159729)

via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 19 02:13:20 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Lewis Crawford (LewisCrawford)

<details>
<summary>Changes</summary>

Add DAG combine patterns that fuse a pair of nested 2-operand min/max nodes into a single 3-operand min/max instruction. The fold applies only to f32, and only when targeting PTX 8.8+ on sm_100+.

---
Full diff: https://github.com/llvm/llvm-project/pull/159729.diff


2 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+66-4) 
- (added) llvm/test/CodeGen/NVPTX/fmax3.ll (+260) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d3fb657851fe2..307e1c6f7c227 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -841,10 +841,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
 
   // We have some custom DAG combine patterns for these nodes
-  setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
-                       ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
-                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
-                       ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
+  setTargetDAGCombine(
+      {ISD::ADD,          ISD::AND,           ISD::EXTRACT_VECTOR_ELT,
+       ISD::FADD,         ISD::FMAXNUM,       ISD::FMINNUM,
+       ISD::FMAXIMUM,     ISD::FMINIMUM,      ISD::FMAXIMUMNUM,
+       ISD::FMINIMUMNUM,  ISD::MUL,           ISD::SHL,
+       ISD::SREM,         ISD::UREM,          ISD::VSELECT,
+       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
+       ISD::STORE,        ISD::ZERO_EXTEND,   ISD::SIGN_EXTEND});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5316,6 +5320,56 @@ static SDValue PerformFADDCombine(SDNode *N,
   return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
 }
 
+/// Map a 2-input FP min/max ISD opcode to its 3-input NVPTXISD counterpart.
+static NVPTXISD::NodeType getMinMax3Opcode(unsigned MinMax2Opcode) {
+  switch (MinMax2Opcode) {
+  case ISD::FMAXNUM:
+  case ISD::FMAXIMUMNUM:
+    return NVPTXISD::FMAXNUM3;
+  case ISD::FMINNUM:
+  case ISD::FMINIMUMNUM:
+    return NVPTXISD::FMINNUM3;
+  case ISD::FMAXIMUM:
+    return NVPTXISD::FMAXIMUM3;
+  case ISD::FMINIMUM:
+    return NVPTXISD::FMINIMUM3;
+  default:
+    llvm_unreachable("Invalid 2-input min/max opcode");
+  }
+}
+
+/// PerformFMinMaxCombine - Fold (op (op a, b), c) or (op a, (op b, c)) into
+/// (op3 a, b, c). NOTE(review): only N's flags are kept, not the inner op's.
+static SDValue PerformFMinMaxCombine(SDNode *N,
+                                     TargetLowering::DAGCombinerInfo &DCI,
+                                     unsigned PTXVersion, unsigned SmVersion) {
+
+  // 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100)
+    return SDValue();
+
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  unsigned MinMaxOp2 = N->getOpcode();
+  NVPTXISD::NodeType MinMaxOp3 = getMinMax3Opcode(MinMaxOp2);
+
+  if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
+    // One-use LHS match: (maxnum (maxnum a, b), c) -> (maxnum3 a, b, c)
+    SDValue A = Op0.getOperand(0);
+    SDValue B = Op0.getOperand(1);
+    SDValue C = Op1;
+    return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
+  } else if (Op1->getOpcode() == MinMaxOp2 && Op1->hasOneUse()) {
+    // One-use RHS match: (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
+    SDValue A = Op0;
+    SDValue B = Op1.getOperand(0);
+    SDValue C = Op1.getOperand(1);
+    return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
+  }
+  return SDValue();
+}
+
 static SDValue PerformREMCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  CodeGenOptLevel OptLevel) {
@@ -5996,6 +6050,14 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformEXTRACTCombine(N, DCI);
   case ISD::FADD:
     return PerformFADDCombine(N, DCI, OptLevel);
+  case ISD::FMAXNUM:
+  case ISD::FMINNUM:
+  case ISD::FMAXIMUM:
+  case ISD::FMINIMUM:
+  case ISD::FMAXIMUMNUM:
+  case ISD::FMINIMUMNUM:
+    return PerformFMinMaxCombine(N, DCI, STI.getPTXVersion(),
+                                 STI.getSmVersion());
   case ISD::LOAD:
   case NVPTXISD::LoadV2:
   case NVPTXISD::LoadV4:
diff --git a/llvm/test/CodeGen/NVPTX/fmax3.ll b/llvm/test/CodeGen/NVPTX/fmax3.ll
new file mode 100644
index 0000000000000..9339b2e247af4
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fmax3.ll
@@ -0,0 +1,260 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -march=nvptx64 -mcpu=sm_100f -o - %s | FileCheck %s
+
+target triple = "nvptx64-nvidia-cuda"
+target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
+
+define void @test_fmaxnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaxnum3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fmaxnum3_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fmaxnum3_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fmaxnum3_param_2];
+; CHECK-NEXT:    max.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fmaxnum3_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %max_ab = call float @llvm.maxnum.f32(float %a, float %b)
+  %max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c)
+  store float %max_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+define void @test_fminnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fminnum3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fminnum3_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fminnum3_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fminnum3_param_2];
+; CHECK-NEXT:    min.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fminnum3_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %min_ab = call float @llvm.minnum.f32(float %a, float %b)
+  %min_abc = call float @llvm.minnum.f32(float %min_ab, float %c)
+  store float %min_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+define void @test_fmaximum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaximum3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fmaximum3_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fmaximum3_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fmaximum3_param_2];
+; CHECK-NEXT:    max.NaN.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fmaximum3_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %max_ab = call float @llvm.maximum.f32(float %a, float %b)
+  %max_abc = call float @llvm.maximum.f32(float %max_ab, float %c)
+  store float %max_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+define void @test_fminimum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fminimum3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fminimum3_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fminimum3_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fminimum3_param_2];
+; CHECK-NEXT:    min.NaN.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fminimum3_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %min_ab = call float @llvm.minimum.f32(float %a, float %b)
+  %min_abc = call float @llvm.minimum.f32(float %min_ab, float %c)
+  store float %min_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+define void @test_fmaximumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaximumnum3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fmaximumnum3_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fmaximumnum3_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fmaximumnum3_param_2];
+; CHECK-NEXT:    max.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fmaximumnum3_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %max_ab = call float @llvm.maximumnum.f32(float %a, float %b)
+  %max_abc = call float @llvm.maximumnum.f32(float %max_ab, float %c)
+  store float %max_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+define void @test_fminimumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fminimumnum3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fminimumnum3_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fminimumnum3_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fminimumnum3_param_2];
+; CHECK-NEXT:    min.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fminimumnum3_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %min_ab = call float @llvm.minimumnum.f32(float %a, float %b)
+  %min_abc = call float @llvm.minimumnum.f32(float %min_ab, float %c)
+  store float %min_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+; Test commuted operands (second operand is the nested operation)
+define void @test_fmaxnum3_commuted(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaxnum3_commuted(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_fmaxnum3_commuted_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_fmaxnum3_commuted_param_1];
+; CHECK-NEXT:    ld.param.b32 %r3, [test_fmaxnum3_commuted_param_2];
+; CHECK-NEXT:    max.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_fmaxnum3_commuted_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %max_bc = call float @llvm.maxnum.f32(float %b, float %c)
+  %max_abc = call float @llvm.maxnum.f32(float %a, float %max_bc)
+  store float %max_abc, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+; NEGATIVE TEST: Mixed min/max operations should not combine
+define void @test_mixed_minmax_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_mixed_minmax_no_combine(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_mixed_minmax_no_combine_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_mixed_minmax_no_combine_param_1];
+; CHECK-NEXT:    min.f32 %r3, %r1, %r2;
+; CHECK-NEXT:    ld.param.b32 %r4, [test_mixed_minmax_no_combine_param_2];
+; CHECK-NEXT:    max.f32 %r5, %r3, %r4;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_mixed_minmax_no_combine_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r5;
+; CHECK-NEXT:    ret;
+entry:
+  %min_ab = call float @llvm.minnum.f32(float %a, float %b)
+  %max_result = call float @llvm.maxnum.f32(float %min_ab, float %c)
+  store float %max_result, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+; NEGATIVE TEST: Mixed maxnum/maximum operations should not combine
+define void @test_mixed_maxnum_maximum_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_mixed_maxnum_maximum_no_combine(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_mixed_maxnum_maximum_no_combine_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_mixed_maxnum_maximum_no_combine_param_1];
+; CHECK-NEXT:    max.f32 %r3, %r1, %r2;
+; CHECK-NEXT:    ld.param.b32 %r4, [test_mixed_maxnum_maximum_no_combine_param_2];
+; CHECK-NEXT:    max.NaN.f32 %r5, %r3, %r4;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_mixed_maxnum_maximum_no_combine_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r5;
+; CHECK-NEXT:    ret;
+entry:
+  %maxnum_ab = call float @llvm.maxnum.f32(float %a, float %b)
+  %maximum_result = call float @llvm.maximum.f32(float %maxnum_ab, float %c)
+  store float %maximum_result, ptr addrspace(1) %output, align 4
+  ret void
+}
+
+; NEGATIVE TEST: f16 should not be combined (only f32 supported)
+define void @test_f16_no_combine(half %a, half %b, half %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_f16_no_combine(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<6>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b16 %rs1, [test_f16_no_combine_param_0];
+; CHECK-NEXT:    ld.param.b16 %rs2, [test_f16_no_combine_param_1];
+; CHECK-NEXT:    max.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT:    ld.param.b16 %rs4, [test_f16_no_combine_param_2];
+; CHECK-NEXT:    max.f16 %rs5, %rs3, %rs4;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_f16_no_combine_param_3];
+; CHECK-NEXT:    st.global.b16 [%rd1], %rs5;
+; CHECK-NEXT:    ret;
+entry:
+  %max_ab = call half @llvm.maxnum.f16(half %a, half %b)
+  %max_abc = call half @llvm.maxnum.f16(half %max_ab, half %c)
+  store half %max_abc, ptr addrspace(1) %output, align 2
+  ret void
+}
+
+; NEGATIVE TEST: Multiple uses of intermediate result should not combine
+define void @test_multiple_uses_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output1, ptr addrspace(1) %output2) {
+; CHECK-LABEL: test_multiple_uses_no_combine(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b32 %r1, [test_multiple_uses_no_combine_param_0];
+; CHECK-NEXT:    ld.param.b32 %r2, [test_multiple_uses_no_combine_param_1];
+; CHECK-NEXT:    max.f32 %r3, %r1, %r2;
+; CHECK-NEXT:    ld.param.b32 %r4, [test_multiple_uses_no_combine_param_2];
+; CHECK-NEXT:    max.f32 %r5, %r3, %r4;
+; CHECK-NEXT:    ld.param.b64 %rd1, [test_multiple_uses_no_combine_param_3];
+; CHECK-NEXT:    st.global.b32 [%rd1], %r3;
+; CHECK-NEXT:    ld.param.b64 %rd2, [test_multiple_uses_no_combine_param_4];
+; CHECK-NEXT:    st.global.b32 [%rd2], %r5;
+; CHECK-NEXT:    ret;
+entry:
+  %max_ab = call float @llvm.maxnum.f32(float %a, float %b)
+  %max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c)
+  ; Multiple uses of %max_ab should prevent combining
+  store float %max_ab, ptr addrspace(1) %output1, align 4
+  store float %max_abc, ptr addrspace(1) %output2, align 4
+  ret void
+}
+
+; Declare all the intrinsics we need
+declare float @llvm.maxnum.f32(float, float) #0
+declare float @llvm.minnum.f32(float, float) #0
+declare float @llvm.maximum.f32(float, float) #0
+declare float @llvm.minimum.f32(float, float) #0
+declare float @llvm.maximumnum.f32(float, float) #0
+declare float @llvm.minimumnum.f32(float, float) #0
+declare half @llvm.maxnum.f16(half, half) #0
+
+attributes #0 = { nounwind readnone speculatable willreturn }

``````````

</details>


https://github.com/llvm/llvm-project/pull/159729


More information about the llvm-commits mailing list