[llvm] 8f3e646 - AMDGPU: Fold fmed3 of fpext sources to f16 fmed3

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu May 18 00:45:00 PDT 2023


Author: Matt Arsenault
Date: 2023-05-18T08:34:46+01:00
New Revision: 8f3e64624c2e49b61ee578aec493260a59a35e80

URL: https://github.com/llvm/llvm-project/commit/8f3e64624c2e49b61ee578aec493260a59a35e80
DIFF: https://github.com/llvm/llvm-project/commit/8f3e64624c2e49b61ee578aec493260a59a35e80.diff

LOG: AMDGPU: Fold fmed3 of fpext sources to f16 fmed3

InstCombine already narrows minnum/maxnum of fpext'd half operands
to the f16 form. Applying the same narrowing to fmed3 means we no
longer need to explicitly emit the 16-bit fmed3 when we are not yet
sure the target supports 16-bit instructions.

Added: 
    

Modified: 
    llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
    llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index 0fac9c8a6df6..6e8878e7aa2a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -23,6 +23,7 @@
 #include <optional>
 
 using namespace llvm;
+using namespace llvm::PatternMatch;
 
 #define DEBUG_TYPE "AMDGPUtti"
 
@@ -355,6 +356,26 @@ bool GCNTTIImpl::canSimplifyLegacyMulToMul(const Instruction &I,
   return false;
 }
 
+/// Match an fpext from half to float, or a constant exactly representable as half.
+static bool matchFPExtFromF16(Value *Arg, Value *&FPExtSrc) {
+  if (match(Arg, m_OneUse(m_FPExt(m_Value(FPExtSrc)))))
+    return FPExtSrc->getType()->isHalfTy();
+
+  ConstantFP *CFP;
+  if (match(Arg, m_ConstantFP(CFP))) {
+    bool LosesInfo;
+    APFloat Val(CFP->getValueAPF());
+    // Also reject non-opOK status: an sNaN is quieted with opInvalidOp.
+    if (Val.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven,
+                    &LosesInfo) != APFloat::opOK || LosesInfo)
+      return false;
+    FPExtSrc = ConstantFP::get(Type::getHalfTy(Arg->getContext()), Val);
+    return true;
+  }
+
+  return false;
+}
+
 // Trim all zero components from the end of the vector \p UseV and return
 // an appropriate bitset with known elements.
 static APInt trimTrailingZerosInVector(InstCombiner &IC, Value *UseV,
@@ -732,6 +753,20 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
       }
     }
 
+    if (!ST->hasMed3_16())
+      break;
+
+    Value *X, *Y, *Z;
+
+    // Repeat floating-point width reduction done for minnum/maxnum.
+    // fmed3((fpext X), (fpext Y), (fpext Z)) -> fpext (fmed3(X, Y, Z))
+    if (matchFPExtFromF16(Src0, X) && matchFPExtFromF16(Src1, Y) &&
+        matchFPExtFromF16(Src2, Z)) {
+      Value *NewCall = IC.Builder.CreateIntrinsic(IID, {X->getType()},
+                                                  {X, Y, Z}, &II, II.getName());
+      return new FPExtInst(NewCall, II.getType());
+    }
+
     break;
   }
   case Intrinsic::amdgcn_icmp:

diff  --git a/llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll b/llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll
index 2370eaf844dc..a31b47b2ca6e 100644
--- a/llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll
+++ b/llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll
@@ -24,10 +24,8 @@ define float @fmed3_f32_fpext_f16(half %arg0, half %arg1, half %arg2) #1 {
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16
 ; GFX9-SAME: (half [[ARG0:%.*]], half [[ARG1:%.*]], half [[ARG2:%.*]]) #[[ATTR1:[0-9]+]] {
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[ARG0]] to float
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[ARG1]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG1_EXT]], float [[ARG2_EXT]])
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[ARG0]], half [[ARG1]], half [[ARG2]])
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg0.ext = fpext half %arg0 to float
@@ -48,10 +46,8 @@ define float @fmed3_f32_fpext_f16_flags(half %arg0, half %arg1, half %arg2) #1 {
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16_flags
 ; GFX9-SAME: (half [[ARG0:%.*]], half [[ARG1:%.*]], half [[ARG2:%.*]]) #[[ATTR1]] {
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[ARG0]] to float
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[ARG1]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call nsz float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG1_EXT]], float [[ARG2_EXT]])
+; GFX9-NEXT:    [[MED31:%.*]] = call nsz half @llvm.amdgcn.fmed3.f16(half [[ARG0]], half [[ARG1]], half [[ARG2]])
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg0.ext = fpext half %arg0 to float
@@ -71,9 +67,8 @@ define float @fmed3_f32_fpext_f16_k0(half %arg1, half %arg2) #1 {
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16_k0
 ; GFX9-SAME: (half [[ARG1:%.*]], half [[ARG2:%.*]]) #[[ATTR1]] {
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[ARG1]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG1_EXT]], float [[ARG2_EXT]], float 2.000000e+00)
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[ARG1]], half [[ARG2]], half 0xH4000)
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg1.ext = fpext half %arg1 to float
@@ -92,9 +87,8 @@ define float @fmed3_f32_fpext_f16_k1(half %arg0, half %arg2) #1 {
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16_k1
 ; GFX9-SAME: (half [[ARG0:%.*]], half [[ARG2:%.*]]) #[[ATTR1]] {
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[ARG0]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG2_EXT]], float 2.000000e+00)
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[ARG0]], half [[ARG2]], half 0xH4000)
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg0.ext = fpext half %arg0 to float
@@ -113,9 +107,8 @@ define float @fmed3_f32_fpext_f16_k2(half %arg0, half %arg1) #1 {
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16_k2
 ; GFX9-SAME: (half [[ARG0:%.*]], half [[ARG1:%.*]]) #[[ATTR1]] {
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[ARG0]] to float
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[ARG1]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG1_EXT]], float 2.000000e+00)
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[ARG0]], half [[ARG1]], half 0xH4000)
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg0.ext = fpext half %arg0 to float
@@ -133,8 +126,8 @@ define float @fmed3_f32_fpext_f16_k0_k1(half %arg2) #1 {
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16_k0_k1
 ; GFX9-SAME: (half [[ARG2:%.*]]) #[[ATTR1]] {
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG2_EXT]], float 0.000000e+00, float 1.600000e+01)
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[ARG2]], half 0xH0000, half 0xH4C00)
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg2.ext = fpext half %arg2 to float
@@ -151,8 +144,8 @@ define float @fmed3_f32_fpext_f16_k0_k2(half %arg1) #1 {
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16_k0_k2
 ; GFX9-SAME: (half [[ARG1:%.*]]) #[[ATTR1]] {
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[ARG1]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG1_EXT]], float 0.000000e+00, float 2.000000e+00)
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[ARG1]], half 0xH0000, half 0xH4000)
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg1.ext = fpext half %arg1 to float
@@ -177,10 +170,8 @@ define float @fmed3_f32_fpext_f16_fabs(half %arg0, half %arg1, half %arg2) #1 {
 ; GFX9-NEXT:    [[FABS_ARG0:%.*]] = call half @llvm.fabs.f16(half [[ARG0]])
 ; GFX9-NEXT:    [[FABS_ARG1:%.*]] = call half @llvm.fabs.f16(half [[ARG1]])
 ; GFX9-NEXT:    [[FABS_ARG2:%.*]] = call half @llvm.fabs.f16(half [[ARG2]])
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[FABS_ARG0]] to float
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[FABS_ARG1]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[FABS_ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG1_EXT]], float [[ARG2_EXT]])
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[FABS_ARG0]], half [[FABS_ARG1]], half [[FABS_ARG2]])
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %fabs.arg0 = call half @llvm.fabs.f16(half %arg0)
@@ -208,12 +199,10 @@ define float @fmed3_fabs_f32_fpext_f16(half %arg0, half %arg1, half %arg2) #1 {
 ; GFX9-LABEL: define float @fmed3_fabs_f32_fpext_f16
 ; GFX9-SAME: (half [[ARG0:%.*]], half [[ARG1:%.*]], half [[ARG2:%.*]]) #[[ATTR1]] {
 ; GFX9-NEXT:    [[TMP1:%.*]] = call half @llvm.fabs.f16(half [[ARG0]])
-; GFX9-NEXT:    [[FABS_EXT_ARG0:%.*]] = fpext half [[TMP1]] to float
 ; GFX9-NEXT:    [[TMP2:%.*]] = call half @llvm.fabs.f16(half [[ARG1]])
-; GFX9-NEXT:    [[FABS_EXT_ARG1:%.*]] = fpext half [[TMP2]] to float
 ; GFX9-NEXT:    [[TMP3:%.*]] = call half @llvm.fabs.f16(half [[ARG2]])
-; GFX9-NEXT:    [[FABS_EXT_ARG2:%.*]] = fpext half [[TMP3]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[FABS_EXT_ARG0]], float [[FABS_EXT_ARG1]], float [[FABS_EXT_ARG2]])
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[TMP1]], half [[TMP2]], half [[TMP3]])
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg0.ext = fpext half %arg0 to float
@@ -243,10 +232,8 @@ define float @fmed3_f32_fpext_f16_fneg(half %arg0, half %arg1, half %arg2) #1 {
 ; GFX9-NEXT:    [[FNEG_ARG0:%.*]] = fneg half [[ARG0]]
 ; GFX9-NEXT:    [[FNEG_ARG1:%.*]] = fneg half [[ARG1]]
 ; GFX9-NEXT:    [[FNEG_ARG2:%.*]] = fneg half [[ARG2]]
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[FNEG_ARG0]] to float
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[FNEG_ARG1]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[FNEG_ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG1_EXT]], float [[ARG2_EXT]])
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[FNEG_ARG0]], half [[FNEG_ARG1]], half [[FNEG_ARG2]])
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %fneg.arg0 = fneg half %arg0
@@ -315,10 +302,8 @@ define float @fmed3_f32_fpext_f16_fneg_fabs(half %arg0, half %arg1, half %arg2)
 ; GFX9-NEXT:    [[FNEG_FABS_ARG0:%.*]] = fneg half [[FABS_ARG0]]
 ; GFX9-NEXT:    [[FNEG_FABS_ARG1:%.*]] = fneg half [[FABS_ARG1]]
 ; GFX9-NEXT:    [[FNEG_FABS_ARG2:%.*]] = fneg half [[FABS_ARG2]]
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[FNEG_FABS_ARG0]] to float
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[FNEG_FABS_ARG1]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[FNEG_FABS_ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG1_EXT]], float [[ARG2_EXT]])
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[FNEG_FABS_ARG0]], half [[FNEG_FABS_ARG1]], half [[FNEG_FABS_ARG2]])
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %fabs.arg0 = call half @llvm.fabs.f16(half %arg0)


        


More information about the llvm-commits mailing list