[llvm] [NVPTX] Add 3-operand fmin/fmax DAGCombines (PR #159729)
Lewis Crawford via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 19 06:38:43 PDT 2025
https://github.com/LewisCrawford updated https://github.com/llvm/llvm-project/pull/159729
>From d2a3b9f6e6da9a4ac91213d79ae4b6093f6b0907 Mon Sep 17 00:00:00 2001
From: Lewis Crawford <lcrawford at nvidia.com>
Date: Fri, 19 Sep 2025 09:07:11 +0000
Subject: [PATCH 1/2] [NVPTX] Add 3-operand fmin/fmax DAGCombines
Add NVPTX DAG combines that fuse a pair of same-opcode 2-operand f32
min/max nodes into a single 3-operand min/max instruction (only for
PTX 8.8+ and sm_100+).
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 70 +++++-
llvm/test/CodeGen/NVPTX/fmax3.ll | 260 ++++++++++++++++++++
2 files changed, 326 insertions(+), 4 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/fmax3.ll
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d3fb657851fe2..307e1c6f7c227 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -841,10 +841,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
// We have some custom DAG combine patterns for these nodes
- setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
- ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
- ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
+ setTargetDAGCombine(
+ {ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT,
+ ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM,
+ ISD::FMAXIMUM, ISD::FMINIMUM, ISD::FMAXIMUMNUM,
+ ISD::FMINIMUMNUM, ISD::MUL, ISD::SHL,
+ ISD::SREM, ISD::UREM, ISD::VSELECT,
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
+ ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5316,6 +5320,56 @@ static SDValue PerformFADDCombine(SDNode *N,
return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
}
+/// Map a 2-input FP min/max ISD opcode to the matching 3-input NVPTXISD opcode.
+static NVPTXISD::NodeType getMinMax3Opcode(unsigned MinMax2Opcode) {
+  switch (MinMax2Opcode) {
+  case ISD::FMAXIMUM:
+    return NVPTXISD::FMAXIMUM3;
+  case ISD::FMINIMUM:
+    return NVPTXISD::FMINIMUM3;
+  case ISD::FMAXNUM:
+  case ISD::FMAXIMUMNUM:
+    return NVPTXISD::FMAXNUM3;
+  case ISD::FMINNUM:
+  case ISD::FMINIMUMNUM:
+    return NVPTXISD::FMINNUM3;
+  default:
+    llvm_unreachable("Invalid 2-input min/max opcode");
+  }
+}
+
+/// PerformFMinMaxCombine - Fold (op (op a, b), c) or (op a, (op b, c)) into
+/// (op3 a, b, c) for the opcodes getMinMax3Opcode maps; else SDValue().
+static SDValue PerformFMinMaxCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ unsigned PTXVersion, unsigned SmVersion) {
+ // NOTE(review): fused node takes N's flags only; the inner node's are dropped.
+ // 3-input min/max requires PTX 8.8+ and SM_100+, and only supports f32s
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::f32 || PTXVersion < 88 || SmVersion < 100)
+ return SDValue();
+
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ unsigned MinMaxOp2 = N->getOpcode();
+ NVPTXISD::NodeType MinMaxOp3 = getMinMax3Opcode(MinMaxOp2);
+ // Only an identical inner opcode fuses; mixed min/max pairs are left alone.
+ if (Op0.getOpcode() == MinMaxOp2 && Op0.hasOneUse()) {
+ // (op (op a, b), c) -> (op3 a, b, c); Op0's only user is N, so it dies.
+ SDValue A = Op0.getOperand(0);
+ SDValue B = Op0.getOperand(1);
+ SDValue C = Op1;
+ return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
+ } else if (Op1->getOpcode() == MinMaxOp2 && Op1->hasOneUse()) {
+ // (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
+ SDValue A = Op0;
+ SDValue B = Op1.getOperand(0);
+ SDValue C = Op1.getOperand(1);
+ return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
+ }
+ return SDValue();
+}
+
static SDValue PerformREMCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
@@ -5996,6 +6050,14 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformEXTRACTCombine(N, DCI);
case ISD::FADD:
return PerformFADDCombine(N, DCI, OptLevel);
+ case ISD::FMAXNUM:
+ case ISD::FMINNUM:
+ case ISD::FMAXIMUM:
+ case ISD::FMINIMUM:
+ case ISD::FMAXIMUMNUM:
+ case ISD::FMINIMUMNUM:
+ return PerformFMinMaxCombine(N, DCI, STI.getPTXVersion(),
+ STI.getSmVersion());
case ISD::LOAD:
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
diff --git a/llvm/test/CodeGen/NVPTX/fmax3.ll b/llvm/test/CodeGen/NVPTX/fmax3.ll
new file mode 100644
index 0000000000000..9339b2e247af4
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fmax3.ll
@@ -0,0 +1,260 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -march=nvptx64 -mcpu=sm_100f -o - %s | FileCheck %s
+
+target triple = "nvptx64-nvidia-cuda"
+target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
+
+; maxnum(maxnum(a, b), c) folds into one 3-input max.f32.
+define void @test_fmaxnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaxnum3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fmaxnum3_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fmaxnum3_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fmaxnum3_param_2];
+; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaxnum3_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %max_ab = call float @llvm.maxnum.f32(float %a, float %b)
+ %max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c)
+ store float %max_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+; minnum(minnum(a, b), c) folds into one 3-input min.f32.
+define void @test_fminnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fminnum3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fminnum3_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fminnum3_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fminnum3_param_2];
+; CHECK-NEXT: min.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fminnum3_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %min_ab = call float @llvm.minnum.f32(float %a, float %b)
+ %min_abc = call float @llvm.minnum.f32(float %min_ab, float %c)
+ store float %min_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+; maximum(maximum(a, b), c) folds into one 3-input max.NaN.f32.
+define void @test_fmaximum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaximum3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fmaximum3_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fmaximum3_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fmaximum3_param_2];
+; CHECK-NEXT: max.NaN.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaximum3_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %max_ab = call float @llvm.maximum.f32(float %a, float %b)
+ %max_abc = call float @llvm.maximum.f32(float %max_ab, float %c)
+ store float %max_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+; minimum(minimum(a, b), c) folds into one 3-input min.NaN.f32.
+define void @test_fminimum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fminimum3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fminimum3_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fminimum3_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fminimum3_param_2];
+; CHECK-NEXT: min.NaN.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fminimum3_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %min_ab = call float @llvm.minimum.f32(float %a, float %b)
+ %min_abc = call float @llvm.minimum.f32(float %min_ab, float %c)
+ store float %min_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+; maximumnum shares the maxnum lowering: one 3-input max.f32 (no .NaN).
+define void @test_fmaximumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaximumnum3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fmaximumnum3_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fmaximumnum3_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fmaximumnum3_param_2];
+; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaximumnum3_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %max_ab = call float @llvm.maximumnum.f32(float %a, float %b)
+ %max_abc = call float @llvm.maximumnum.f32(float %max_ab, float %c)
+ store float %max_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+; minimumnum shares the minnum lowering: one 3-input min.f32 (no .NaN).
+define void @test_fminimumnum3(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fminimumnum3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fminimumnum3_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fminimumnum3_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fminimumnum3_param_2];
+; CHECK-NEXT: min.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fminimumnum3_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %min_ab = call float @llvm.minimumnum.f32(float %a, float %b)
+ %min_abc = call float @llvm.minimumnum.f32(float %min_ab, float %c)
+ store float %min_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+; Commuted form a op (b op c): the nested maxnum is operand 1; still one max.f32.
+define void @test_fmaxnum3_commuted(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_fmaxnum3_commuted(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_fmaxnum3_commuted_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_fmaxnum3_commuted_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [test_fmaxnum3_commuted_param_2];
+; CHECK-NEXT: max.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_fmaxnum3_commuted_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %max_bc = call float @llvm.maxnum.f32(float %b, float %c)
+ %max_abc = call float @llvm.maxnum.f32(float %a, float %max_bc)
+ store float %max_abc, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+; NEGATIVE TEST: minnum feeding maxnum must not fuse (only identical opcodes do)
+define void @test_mixed_minmax_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_mixed_minmax_no_combine(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_mixed_minmax_no_combine_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_mixed_minmax_no_combine_param_1];
+; CHECK-NEXT: min.f32 %r3, %r1, %r2;
+; CHECK-NEXT: ld.param.b32 %r4, [test_mixed_minmax_no_combine_param_2];
+; CHECK-NEXT: max.f32 %r5, %r3, %r4;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_mixed_minmax_no_combine_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r5;
+; CHECK-NEXT: ret;
+entry:
+ %min_ab = call float @llvm.minnum.f32(float %a, float %b)
+ %max_result = call float @llvm.maxnum.f32(float %min_ab, float %c)
+ store float %max_result, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+; NEGATIVE TEST: maxnum feeding maximum must not fuse (max.f32 vs max.NaN.f32)
+define void @test_mixed_maxnum_maximum_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_mixed_maxnum_maximum_no_combine(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_mixed_maxnum_maximum_no_combine_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_mixed_maxnum_maximum_no_combine_param_1];
+; CHECK-NEXT: max.f32 %r3, %r1, %r2;
+; CHECK-NEXT: ld.param.b32 %r4, [test_mixed_maxnum_maximum_no_combine_param_2];
+; CHECK-NEXT: max.NaN.f32 %r5, %r3, %r4;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_mixed_maxnum_maximum_no_combine_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r5;
+; CHECK-NEXT: ret;
+entry:
+ %maxnum_ab = call float @llvm.maxnum.f32(float %a, float %b)
+ %maximum_result = call float @llvm.maximum.f32(float %maxnum_ab, float %c)
+ store float %maximum_result, ptr addrspace(1) %output, align 4
+ ret void
+}
+
+; NEGATIVE TEST: f16 is not fused; the 3-input form is only formed for f32
+define void @test_f16_no_combine(half %a, half %b, half %c, ptr addrspace(1) %output) {
+; CHECK-LABEL: test_f16_no_combine(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<6>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b16 %rs1, [test_f16_no_combine_param_0];
+; CHECK-NEXT: ld.param.b16 %rs2, [test_f16_no_combine_param_1];
+; CHECK-NEXT: max.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT: ld.param.b16 %rs4, [test_f16_no_combine_param_2];
+; CHECK-NEXT: max.f16 %rs5, %rs3, %rs4;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_f16_no_combine_param_3];
+; CHECK-NEXT: st.global.b16 [%rd1], %rs5;
+; CHECK-NEXT: ret;
+entry:
+ %max_ab = call half @llvm.maxnum.f16(half %a, half %b)
+ %max_abc = call half @llvm.maxnum.f16(half %max_ab, half %c)
+ store half %max_abc, ptr addrspace(1) %output, align 2
+ ret void
+}
+
+; NEGATIVE TEST: the inner max also feeds a store, so it has two users and is not fused
+define void @test_multiple_uses_no_combine(float %a, float %b, float %c, ptr addrspace(1) %output1, ptr addrspace(1) %output2) {
+; CHECK-LABEL: test_multiple_uses_no_combine(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<6>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b32 %r1, [test_multiple_uses_no_combine_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [test_multiple_uses_no_combine_param_1];
+; CHECK-NEXT: max.f32 %r3, %r1, %r2;
+; CHECK-NEXT: ld.param.b32 %r4, [test_multiple_uses_no_combine_param_2];
+; CHECK-NEXT: max.f32 %r5, %r3, %r4;
+; CHECK-NEXT: ld.param.b64 %rd1, [test_multiple_uses_no_combine_param_3];
+; CHECK-NEXT: st.global.b32 [%rd1], %r3;
+; CHECK-NEXT: ld.param.b64 %rd2, [test_multiple_uses_no_combine_param_4];
+; CHECK-NEXT: st.global.b32 [%rd2], %r5;
+; CHECK-NEXT: ret;
+entry:
+ %max_ab = call float @llvm.maxnum.f32(float %a, float %b)
+ %max_abc = call float @llvm.maxnum.f32(float %max_ab, float %c)
+ ; Multiple uses of %max_ab should prevent combining
+ store float %max_ab, ptr addrspace(1) %output1, align 4
+ store float %max_abc, ptr addrspace(1) %output2, align 4
+ ret void
+}
+
+; Intrinsic declarations used by the tests above.
+declare float @llvm.maxnum.f32(float, float) #0
+declare float @llvm.minnum.f32(float, float) #0
+declare float @llvm.maximum.f32(float, float) #0
+declare float @llvm.minimum.f32(float, float) #0
+declare float @llvm.maximumnum.f32(float, float) #0
+declare float @llvm.minimumnum.f32(float, float) #0
+declare half @llvm.maxnum.f16(half, half) #0
+
+attributes #0 = { nounwind memory(none) speculatable willreturn }
>From 3f162aace550ff9fe2eb37d63ec5c472a5552442 Mon Sep 17 00:00:00 2001
From: Lewis Crawford <lcrawford at nvidia.com>
Date: Fri, 19 Sep 2025 13:36:57 +0000
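Subject: [PATCH 2/2] Change Op1->getOpcode/hasOneUse to Op1.getOpcode/hasOneUse
Use . instead of -> on the SDValue Op1, for consistency with how Op0 is
used. For getOpcode() the two forms are equivalent: SDValue's operator->
returns the inner SDNode, and SDValue::getOpcode() just forwards to that
node. For hasOneUse() they differ in general (SDValue::hasOneUse() counts
uses of that one result, SDNode::hasOneUse() counts uses of all of the
node's results), but min/max nodes produce a single result, so behaviour is
unchanged here.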
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 307e1c6f7c227..ca8a3f69f991d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5360,7 +5360,7 @@ static SDValue PerformFMinMaxCombine(SDNode *N,
SDValue B = Op0.getOperand(1);
SDValue C = Op1;
return DCI.DAG.getNode(MinMaxOp3, SDLoc(N), VT, A, B, C, N->getFlags());
- } else if (Op1->getOpcode() == MinMaxOp2 && Op1->hasOneUse()) {
+ } else if (Op1.getOpcode() == MinMaxOp2 && Op1.hasOneUse()) {
// (maxnum a, (maxnum b, c)) -> (maxnum3 a, b, c)
SDValue A = Op0;
SDValue B = Op1.getOperand(0);
More information about the llvm-commits
mailing list