[llvm] [AMDGPU] Fold fmed3 when inputs include infinity (PR #144824)
Darren Wihandi via llvm-commits
llvm-commits at lists.llvm.org
Thu Jun 19 11:58:51 PDT 2025
https://github.com/fairywreath updated https://github.com/llvm/llvm-project/pull/144824
>From 2956b11dc09fe5efd669c534b24c8ab43d1a3198 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Wed, 18 Jun 2025 22:10:05 -0400
Subject: [PATCH 1/2] [AMDGPU] Fold fmed3 when inputs include infinity
---
.../AMDGPU/AMDGPUInstCombineIntrinsic.cpp | 40 ++++++++-
.../Transforms/InstCombine/AMDGPU/fmed3.ll | 90 +++++++++++++++++++
2 files changed, 129 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index 5477c5eae9392..7554c6953d76f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -1039,7 +1039,6 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
const APFloat *ConstSrc1 = nullptr;
const APFloat *ConstSrc2 = nullptr;
- // TODO: Also can fold to 2 operands with infinities.
if ((match(Src0, m_APFloat(ConstSrc0)) && ConstSrc0->isNaN()) ||
isa<UndefValue>(Src0)) {
switch (fpenvIEEEMode(II)) {
@@ -1088,6 +1087,45 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
case KnownIEEEMode::Unknown:
break;
}
+ } else if (match(Src0, m_APFloat(ConstSrc0)) && ConstSrc0->isInfinity()) {
+ switch (fpenvIEEEMode(II)) {
+ case KnownIEEEMode::On:
+ V = ConstSrc0->isNegative() ? IC.Builder.CreateMinNum(Src1, Src2)
+ : IC.Builder.CreateMaxNum(Src1, Src2);
+ break;
+ case KnownIEEEMode::Off:
+ V = ConstSrc0->isNegative() ? IC.Builder.CreateMinimumNum(Src1, Src2)
+ : IC.Builder.CreateMaximumNum(Src1, Src2);
+ break;
+ case KnownIEEEMode::Unknown:
+ break;
+ }
+ } else if (match(Src1, m_APFloat(ConstSrc1)) && ConstSrc1->isInfinity()) {
+ switch (fpenvIEEEMode(II)) {
+ case KnownIEEEMode::On:
+ V = ConstSrc1->isNegative() ? IC.Builder.CreateMinNum(Src0, Src2)
+ : IC.Builder.CreateMaxNum(Src0, Src2);
+ break;
+ case KnownIEEEMode::Off:
+ V = ConstSrc1->isNegative() ? IC.Builder.CreateMinimumNum(Src0, Src2)
+ : IC.Builder.CreateMaximumNum(Src0, Src2);
+ break;
+ case KnownIEEEMode::Unknown:
+ break;
+ }
+ } else if (match(Src2, m_APFloat(ConstSrc2)) && ConstSrc2->isInfinity()) {
+ switch (fpenvIEEEMode(II)) {
+ case KnownIEEEMode::On:
+ V = ConstSrc2->isNegative() ? IC.Builder.CreateMinNum(Src0, Src1)
+ : IC.Builder.CreateMaxNum(Src0, Src1);
+ break;
+ case KnownIEEEMode::Off:
+ V = ConstSrc2->isNegative() ? IC.Builder.CreateMinimumNum(Src0, Src1)
+ : IC.Builder.CreateMaximumNum(Src0, Src1);
+ break;
+ case KnownIEEEMode::Unknown:
+ break;
+ }
}
if (V) {
diff --git a/llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll b/llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll
index d9311008bd680..361a2b8280910 100644
--- a/llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll
+++ b/llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll
@@ -521,6 +521,96 @@ define float @fmed3_neg2_3_snan1_f32(float %x, float %y) #1 {
ret float %med3
}
+define float @fmed3_inf_x_y_f32(float %x, float %y) #1 {
+; IEEE1-LABEL: define float @fmed3_inf_x_y_f32(
+; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE1-NEXT: [[MED3:%.*]] = call float @llvm.maxnum.f32(float [[X]], float [[Y]])
+; IEEE1-NEXT: ret float [[MED3]]
+;
+; IEEE0-LABEL: define float @fmed3_inf_x_y_f32(
+; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE0-NEXT: [[MED3:%.*]] = call float @llvm.maximumnum.f32(float [[X]], float [[Y]])
+; IEEE0-NEXT: ret float [[MED3]]
+;
+ %med3 = call float @llvm.amdgcn.fmed3.f32(float 0x7FF0000000000000, float %x, float %y)
+ ret float %med3
+}
+
+define float @fmed3_x_inf_y_f32(float %x, float %y) #1 {
+; IEEE1-LABEL: define float @fmed3_x_inf_y_f32(
+; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE1-NEXT: [[MED3:%.*]] = call float @llvm.maxnum.f32(float [[X]], float [[Y]])
+; IEEE1-NEXT: ret float [[MED3]]
+;
+; IEEE0-LABEL: define float @fmed3_x_inf_y_f32(
+; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE0-NEXT: [[MED3:%.*]] = call float @llvm.maximumnum.f32(float [[X]], float [[Y]])
+; IEEE0-NEXT: ret float [[MED3]]
+;
+ %med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float 0x7FF0000000000000, float %y)
+ ret float %med3
+}
+
+define float @fmed3_x_y_inf_f32(float %x, float %y) #1 {
+; IEEE1-LABEL: define float @fmed3_x_y_inf_f32(
+; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE1-NEXT: [[MED3:%.*]] = call float @llvm.maxnum.f32(float [[X]], float [[Y]])
+; IEEE1-NEXT: ret float [[MED3]]
+;
+; IEEE0-LABEL: define float @fmed3_x_y_inf_f32(
+; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE0-NEXT: [[MED3:%.*]] = call float @llvm.maximumnum.f32(float [[X]], float [[Y]])
+; IEEE0-NEXT: ret float [[MED3]]
+;
+ %med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float 0x7FF0000000000000)
+ ret float %med3
+}
+
+define float @fmed3_ninf_x_y_f32(float %x, float %y) #1 {
+; IEEE1-LABEL: define float @fmed3_ninf_x_y_f32(
+; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE1-NEXT: [[MED3:%.*]] = call float @llvm.minnum.f32(float [[X]], float [[Y]])
+; IEEE1-NEXT: ret float [[MED3]]
+;
+; IEEE0-LABEL: define float @fmed3_ninf_x_y_f32(
+; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE0-NEXT: [[MED3:%.*]] = call float @llvm.minimumnum.f32(float [[X]], float [[Y]])
+; IEEE0-NEXT: ret float [[MED3]]
+;
+ %med3 = call float @llvm.amdgcn.fmed3.f32(float 0xFFF0000000000000, float %x, float %y)
+ ret float %med3
+}
+
+define float @fmed3_x_ninf_y_f32(float %x, float %y) #1 {
+; IEEE1-LABEL: define float @fmed3_x_ninf_y_f32(
+; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE1-NEXT: [[MED3:%.*]] = call float @llvm.minnum.f32(float [[X]], float [[Y]])
+; IEEE1-NEXT: ret float [[MED3]]
+;
+; IEEE0-LABEL: define float @fmed3_x_ninf_y_f32(
+; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE0-NEXT: [[MED3:%.*]] = call float @llvm.minimumnum.f32(float [[X]], float [[Y]])
+; IEEE0-NEXT: ret float [[MED3]]
+;
+ %med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float 0xFFF0000000000000, float %y)
+ ret float %med3
+}
+
+define float @fmed3_x_y_ninf_f32(float %x, float %y) #1 {
+; IEEE1-LABEL: define float @fmed3_x_y_ninf_f32(
+; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE1-NEXT: [[MED3:%.*]] = call float @llvm.minnum.f32(float [[X]], float [[Y]])
+; IEEE1-NEXT: ret float [[MED3]]
+;
+; IEEE0-LABEL: define float @fmed3_x_y_ninf_f32(
+; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE0-NEXT: [[MED3:%.*]] = call float @llvm.minimumnum.f32(float [[X]], float [[Y]])
+; IEEE0-NEXT: ret float [[MED3]]
+;
+ %med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float 0xFFF0000000000000)
+ ret float %med3
+}
+
; --------------------------------------------------------------------
; llvm.amdgcn.fmed3 with default mode implied by shader CC
; --------------------------------------------------------------------
>From 706242313821a8b41c2d002ee6a24b690a192a38 Mon Sep 17 00:00:00 2001
From: fairywreath <nerradfour at gmail.com>
Date: Thu, 19 Jun 2025 12:44:28 -0600
Subject: [PATCH 2/2] Merge with nan handling and add comment
---
.../AMDGPU/AMDGPUInstCombineIntrinsic.cpp | 85 ++++++++-----------
1 file changed, 34 insertions(+), 51 deletions(-)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index 7554c6953d76f..b8996fb97f1cb 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -1031,6 +1031,14 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
// s1 _nan: min(s0, s2)
// s2 _nan: min(s0, s1)
+ // med3 behavior with infinity
+ // s0 +inf: max(s1, s2)
+ // s1 +inf: max(s0, s2)
+ // s2 +inf: max(s0, s1)
+ // s0 -inf: min(s1, s2)
+ // s1 -inf: min(s0, s2)
+ // s2 -inf: min(s0, s1)
+
// Checking for NaN before canonicalization provides better fidelity when
// mapping other operations onto fmed3 since the order of operands is
// unchanged.
@@ -1039,89 +1047,64 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
const APFloat *ConstSrc1 = nullptr;
const APFloat *ConstSrc2 = nullptr;
- if ((match(Src0, m_APFloat(ConstSrc0)) && ConstSrc0->isNaN()) ||
+ if ((match(Src0, m_APFloat(ConstSrc0)) &&
+ (ConstSrc0->isNaN() || ConstSrc0->isInfinity())) ||
isa<UndefValue>(Src0)) {
+ const bool IsPosInfinity = ConstSrc0 && ConstSrc0->isPosInfinity();
switch (fpenvIEEEMode(II)) {
case KnownIEEEMode::On:
// TODO: If Src2 is snan, does it need quieting?
- if (ConstSrc0 && ConstSrc0->isSignaling())
+ if (ConstSrc0 && ConstSrc0->isNaN() && ConstSrc0->isSignaling())
return IC.replaceInstUsesWith(II, Src2);
- V = IC.Builder.CreateMinNum(Src1, Src2);
+
+ V = IsPosInfinity ? IC.Builder.CreateMaxNum(Src1, Src2)
+ : IC.Builder.CreateMinNum(Src1, Src2);
break;
case KnownIEEEMode::Off:
- V = IC.Builder.CreateMinimumNum(Src1, Src2);
+ V = IsPosInfinity ? IC.Builder.CreateMaximumNum(Src1, Src2)
+ : IC.Builder.CreateMinimumNum(Src1, Src2);
break;
case KnownIEEEMode::Unknown:
break;
}
- } else if ((match(Src1, m_APFloat(ConstSrc1)) && ConstSrc1->isNaN()) ||
+ } else if ((match(Src1, m_APFloat(ConstSrc1)) &&
+ (ConstSrc1->isNaN() || ConstSrc1->isInfinity())) ||
isa<UndefValue>(Src1)) {
+ const bool IsPosInfinity = ConstSrc1 && ConstSrc1->isPosInfinity();
switch (fpenvIEEEMode(II)) {
case KnownIEEEMode::On:
// TODO: If Src2 is snan, does it need quieting?
- if (ConstSrc1 && ConstSrc1->isSignaling())
+ if (ConstSrc1 && ConstSrc1->isNaN() && ConstSrc1->isSignaling())
return IC.replaceInstUsesWith(II, Src2);
- V = IC.Builder.CreateMinNum(Src0, Src2);
+ V = IsPosInfinity ? IC.Builder.CreateMaxNum(Src0, Src2)
+ : IC.Builder.CreateMinNum(Src0, Src2);
break;
case KnownIEEEMode::Off:
- V = IC.Builder.CreateMinimumNum(Src0, Src2);
+ V = IsPosInfinity ? IC.Builder.CreateMaximumNum(Src0, Src2)
+ : IC.Builder.CreateMinimumNum(Src0, Src2);
break;
case KnownIEEEMode::Unknown:
break;
}
- } else if ((match(Src2, m_APFloat(ConstSrc2)) && ConstSrc2->isNaN()) ||
+ } else if ((match(Src2, m_APFloat(ConstSrc2)) &&
+ (ConstSrc2->isNaN() || ConstSrc2->isInfinity())) ||
isa<UndefValue>(Src2)) {
switch (fpenvIEEEMode(II)) {
case KnownIEEEMode::On:
- if (ConstSrc2 && ConstSrc2->isSignaling()) {
+ if (ConstSrc2 && ConstSrc2->isNaN() && ConstSrc2->isSignaling()) {
auto *Quieted = ConstantFP::get(II.getType(), ConstSrc2->makeQuiet());
return IC.replaceInstUsesWith(II, Quieted);
}
- V = IC.Builder.CreateMinNum(Src0, Src1);
- break;
- case KnownIEEEMode::Off:
- V = IC.Builder.CreateMaximumNum(Src0, Src1);
- break;
- case KnownIEEEMode::Unknown:
- break;
- }
- } else if (match(Src0, m_APFloat(ConstSrc0)) && ConstSrc0->isInfinity()) {
- switch (fpenvIEEEMode(II)) {
- case KnownIEEEMode::On:
- V = ConstSrc0->isNegative() ? IC.Builder.CreateMinNum(Src1, Src2)
- : IC.Builder.CreateMaxNum(Src1, Src2);
- break;
- case KnownIEEEMode::Off:
- V = ConstSrc0->isNegative() ? IC.Builder.CreateMinimumNum(Src1, Src2)
- : IC.Builder.CreateMaximumNum(Src1, Src2);
- break;
- case KnownIEEEMode::Unknown:
- break;
- }
- } else if (match(Src1, m_APFloat(ConstSrc1)) && ConstSrc1->isInfinity()) {
- switch (fpenvIEEEMode(II)) {
- case KnownIEEEMode::On:
- V = ConstSrc1->isNegative() ? IC.Builder.CreateMinNum(Src0, Src2)
- : IC.Builder.CreateMaxNum(Src0, Src2);
- break;
- case KnownIEEEMode::Off:
- V = ConstSrc1->isNegative() ? IC.Builder.CreateMinimumNum(Src0, Src2)
- : IC.Builder.CreateMaximumNum(Src0, Src2);
- break;
- case KnownIEEEMode::Unknown:
- break;
- }
- } else if (match(Src2, m_APFloat(ConstSrc2)) && ConstSrc2->isInfinity()) {
- switch (fpenvIEEEMode(II)) {
- case KnownIEEEMode::On:
- V = ConstSrc2->isNegative() ? IC.Builder.CreateMinNum(Src0, Src1)
- : IC.Builder.CreateMaxNum(Src0, Src1);
+ V = (ConstSrc2 && ConstSrc2->isPosInfinity())
+ ? IC.Builder.CreateMaxNum(Src0, Src1)
+ : IC.Builder.CreateMinNum(Src0, Src1);
break;
case KnownIEEEMode::Off:
- V = ConstSrc2->isNegative() ? IC.Builder.CreateMinimumNum(Src0, Src1)
- : IC.Builder.CreateMaximumNum(Src0, Src1);
+ V = (ConstSrc2 && ConstSrc2->isNegInfinity())
+ ? IC.Builder.CreateMinimumNum(Src0, Src1)
+ : IC.Builder.CreateMaximumNum(Src0, Src1);
break;
case KnownIEEEMode::Unknown:
break;
More information about the llvm-commits mailing list