[llvm] cc1b9ac - [NVPTX] Lower fp16 fminnum, fmaxnum to native on sm_80.

Christian Sigg via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 12 23:52:39 PST 2022


Author: Christian Sigg
Date: 2022-01-13T08:52:31+01:00
New Revision: cc1b9acf550d13702a20cf7e5bac8a94bf2202ba

URL: https://github.com/llvm/llvm-project/commit/cc1b9acf550d13702a20cf7e5bac8a94bf2202ba
DIFF: https://github.com/llvm/llvm-project/commit/cc1b9acf550d13702a20cf7e5bac8a94bf2202ba.diff

LOG: [NVPTX] Lower fp16 fminnum, fmaxnum to native on sm_80.

Reviewed By: bkramer, tra

Differential Revision: https://reviews.llvm.org/D117122

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
    llvm/test/CodeGen/NVPTX/math-intrins.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index e2f6b69fc530a..faa873f07cadf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -560,10 +560,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     setOperationAction(Op, MVT::f64, Legal);
     setOperationAction(Op, MVT::v2f16, Expand);
   }
-  setOperationAction(ISD::FMINNUM, MVT::f16, Promote);
-  setOperationAction(ISD::FMAXNUM, MVT::f16, Promote);
-  setOperationAction(ISD::FMINIMUM, MVT::f16, Promote);
-  setOperationAction(ISD::FMAXIMUM, MVT::f16, Promote);
+  // min/max.{f16,f16x2} need FP16 math, sm_80+ and PTX ISA 7.0+.
+  if (STI.allowFP16Math() && STI.getSmVersion() >= 80 &&
+      STI.getPTXVersion() >= 70) {
+    setOperationAction(ISD::FMINNUM, MVT::f16, Legal);
+    setOperationAction(ISD::FMAXNUM, MVT::f16, Legal);
+    setOperationAction(ISD::FMINNUM, MVT::v2f16, Legal);
+    setOperationAction(ISD::FMAXNUM, MVT::v2f16, Legal);
+  }
 
   // No FEXP2, FLOG2.  The PTX ex2 and log2 functions are always approximate.
   // No FPOW or FREM in PTX.

diff  --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 96386af569de6..03f7cd0c71524 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -249,6 +249,32 @@ multiclass F3<string OpcStr, SDNode OpNode> {
                (ins Float32Regs:$a, f32imm:$b),
                !strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
                [(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>;
+   // f16 forms (no sm_80 predicate here); _ftz also requires doF32FTZ.
+   def f16rr_ftz :
+     NVPTXInst<(outs Float16Regs:$dst),
+               (ins Float16Regs:$a, Float16Regs:$b),
+               !strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"),
+               [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
+               Requires<[useFP16Math, doF32FTZ]>;
+   def f16rr :
+     NVPTXInst<(outs Float16Regs:$dst),
+               (ins Float16Regs:$a, Float16Regs:$b),
+               !strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
+               [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
+               Requires<[useFP16Math]>;
+   // f16x2 forms on Float16x2Regs, predicated like the f16 ones above.
+   def f16x2rr_ftz :
+     NVPTXInst<(outs Float16x2Regs:$dst),
+               (ins Float16x2Regs:$a, Float16x2Regs:$b),
+               !strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"),
+               [(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
+               Requires<[useFP16Math, doF32FTZ]>;
+   def f16x2rr :
+     NVPTXInst<(outs Float16x2Regs:$dst),
+               (ins Float16x2Regs:$a, Float16x2Regs:$b),
+               !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
+               [(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
+               Requires<[useFP16Math]>;
 }
 
 // Template for instructions which take three FP args.  The

diff  --git a/llvm/test/CodeGen/NVPTX/math-intrins.ll b/llvm/test/CodeGen/NVPTX/math-intrins.ll
index 9c0c4ae85e1f1..3ab63a04a4674 100644
--- a/llvm/test/CodeGen/NVPTX/math-intrins.ll
+++ b/llvm/test/CodeGen/NVPTX/math-intrins.ll
@@ -1,4 +1,6 @@
-; RUN: llc < %s | FileCheck %s
+; RUN: llc < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
+; RUN: llc < %s -mcpu=sm_80 | FileCheck %s --check-prefixes=CHECK,CHECK-F16
+; RUN: llc < %s -mcpu=sm_80 --nvptx-no-f16-math | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
 target triple = "nvptx64-nvidia-cuda"
 
 ; Checks that llvm intrinsics for math functions are correctly lowered to PTX.
@@ -17,10 +19,14 @@ declare float @llvm.trunc.f32(float) #0
 declare double @llvm.trunc.f64(double) #0
 declare float @llvm.fabs.f32(float) #0
 declare double @llvm.fabs.f64(double) #0
+declare half @llvm.minnum.f16(half, half) #0
 declare float @llvm.minnum.f32(float, float) #0
 declare double @llvm.minnum.f64(double, double) #0
+declare <2 x half> @llvm.minnum.v2f16(<2 x half>, <2 x half>) #0
+declare half @llvm.maxnum.f16(half, half) #0
 declare float @llvm.maxnum.f32(float, float) #0
 declare double @llvm.maxnum.f64(double, double) #0
+declare <2 x half> @llvm.maxnum.v2f16(<2 x half>, <2 x half>) #0
 declare float @llvm.fma.f32(float, float, float) #0
 declare double @llvm.fma.f64(double, double, double) #0
 
@@ -193,6 +199,14 @@ define double @abs_double(double %a) {
 
 ; ---- min ----
 
+; CHECK-LABEL: min_half
+define half @min_half(half %a, half %b) { ; min.f32, or native min.f16
+  ; CHECK-NOF16: min.f32
+  ; CHECK-F16: min.f16
+  %x = call half @llvm.minnum.f16(half %a, half %b)
+  ret half %x
+}
+
 ; CHECK-LABEL: min_float
 define float @min_float(float %a, float %b) {
   ; CHECK: min.f32
@@ -228,8 +242,25 @@ define double @min_double(double %a, double %b) {
   ret double %x
 }
 
+; CHECK-LABEL: min_v2half
+define <2 x half> @min_v2half(<2 x half> %a, <2 x half> %b) { ; 2x min.f32 or min.f16x2
+  ; CHECK-NOF16: min.f32
+  ; CHECK-NOF16: min.f32
+  ; CHECK-F16: min.f16x2
+  %x = call <2 x half> @llvm.minnum.v2f16(<2 x half> %a, <2 x half> %b)
+  ret <2 x half> %x
+}
+
 ; ---- max ----
 
+; CHECK-LABEL: max_half
+define half @max_half(half %a, half %b) { ; max.f32, or native max.f16
+  ; CHECK-NOF16: max.f32
+  ; CHECK-F16: max.f16
+  %x = call half @llvm.maxnum.f16(half %a, half %b)
+  ret half %x
+}
+
 ; CHECK-LABEL: max_imm1
 define float @max_imm1(float %a) {
   ; CHECK: max.f32
@@ -265,6 +296,15 @@ define double @max_double(double %a, double %b) {
   ret double %x
 }
 
+; CHECK-LABEL: max_v2half
+define <2 x half> @max_v2half(<2 x half> %a, <2 x half> %b) { ; 2x max.f32 or max.f16x2
+  ; CHECK-NOF16: max.f32
+  ; CHECK-NOF16: max.f32
+  ; CHECK-F16: max.f16x2
+  %x = call <2 x half> @llvm.maxnum.v2f16(<2 x half> %a, <2 x half> %b)
+  ret <2 x half> %x
+}
+
 ; ---- fma ----
 
 ; CHECK-LABEL: @fma_float


        


More information about the llvm-commits mailing list