[llvm] cc1b9ac - [NVPTX] Lower fp16 fminnum, fmaxnum to native on sm_80.
Christian Sigg via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 12 23:52:39 PST 2022
Author: Christian Sigg
Date: 2022-01-13T08:52:31+01:00
New Revision: cc1b9acf550d13702a20cf7e5bac8a94bf2202ba
URL: https://github.com/llvm/llvm-project/commit/cc1b9acf550d13702a20cf7e5bac8a94bf2202ba
DIFF: https://github.com/llvm/llvm-project/commit/cc1b9acf550d13702a20cf7e5bac8a94bf2202ba.diff
LOG: [NVPTX] Lower fp16 fminnum, fmaxnum to native on sm_80.
Reviewed By: bkramer, tra
Differential Revision: https://reviews.llvm.org/D117122
Added:
Modified:
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
llvm/test/CodeGen/NVPTX/math-intrins.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index e2f6b69fc530a..faa873f07cadf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -560,10 +560,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(Op, MVT::f64, Legal);
setOperationAction(Op, MVT::v2f16, Expand);
}
- setOperationAction(ISD::FMINNUM, MVT::f16, Promote);
- setOperationAction(ISD::FMAXNUM, MVT::f16, Promote);
- setOperationAction(ISD::FMINIMUM, MVT::f16, Promote);
- setOperationAction(ISD::FMAXIMUM, MVT::f16, Promote);
+ // Native f16/f16x2 min and max need fp16 math, sm_80+ and PTX version >= 70.
+ if (STI.allowFP16Math() && STI.getSmVersion() >= 80 &&
+ STI.getPTXVersion() >= 70) {
+ setOperationAction(ISD::FMINNUM, MVT::f16, Legal);
+ setOperationAction(ISD::FMAXNUM, MVT::f16, Legal);
+ setOperationAction(ISD::FMINNUM, MVT::v2f16, Legal);
+ setOperationAction(ISD::FMAXNUM, MVT::v2f16, Legal);
+ }
// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
// No FPOW or FREM in PTX.
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 96386af569de6..03f7cd0c71524 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -249,6 +249,32 @@ multiclass F3<string OpcStr, SDNode OpNode> {
(ins Float32Regs:$a, f32imm:$b),
!strconcat(OpcStr, ".f32 \t$dst, $a, $b;"),
[(set Float32Regs:$dst, (OpNode Float32Regs:$a, fpimm:$b))]>;
+
+ def f16rr_ftz :
+ NVPTXInst<(outs Float16Regs:$dst),
+ (ins Float16Regs:$a, Float16Regs:$b),
+ !strconcat(OpcStr, ".ftz.f16 \t$dst, $a, $b;"),
+ [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
+ Requires<[useFP16Math, doF32FTZ]>;
+ def f16rr :
+ NVPTXInst<(outs Float16Regs:$dst),
+ (ins Float16Regs:$a, Float16Regs:$b),
+ !strconcat(OpcStr, ".f16 \t$dst, $a, $b;"),
+ [(set Float16Regs:$dst, (OpNode Float16Regs:$a, Float16Regs:$b))]>,
+ Requires<[useFP16Math]>;
+
+ def f16x2rr_ftz :
+ NVPTXInst<(outs Float16x2Regs:$dst),
+ (ins Float16x2Regs:$a, Float16x2Regs:$b),
+ !strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a, $b;"),
+ [(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
+ Requires<[useFP16Math, doF32FTZ]>;
+ def f16x2rr :
+ NVPTXInst<(outs Float16x2Regs:$dst),
+ (ins Float16x2Regs:$a, Float16x2Regs:$b),
+ !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
+ [(set Float16x2Regs:$dst, (OpNode Float16x2Regs:$a, Float16x2Regs:$b))]>,
+ Requires<[useFP16Math]>;
}
// Template for instructions which take three FP args. The
diff --git a/llvm/test/CodeGen/NVPTX/math-intrins.ll b/llvm/test/CodeGen/NVPTX/math-intrins.ll
index 9c0c4ae85e1f1..3ab63a04a4674 100644
--- a/llvm/test/CodeGen/NVPTX/math-intrins.ll
+++ b/llvm/test/CodeGen/NVPTX/math-intrins.ll
@@ -1,4 +1,6 @@
-; RUN: llc < %s | FileCheck %s
+; RUN: llc < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
+; RUN: llc < %s -mcpu=sm_80 | FileCheck %s --check-prefixes=CHECK,CHECK-F16
+; RUN: llc < %s -mcpu=sm_80 --nvptx-no-f16-math | FileCheck %s --check-prefixes=CHECK,CHECK-NOF16
target triple = "nvptx64-nvidia-cuda"
; Checks that llvm intrinsics for math functions are correctly lowered to PTX.
@@ -17,10 +19,14 @@ declare float @llvm.trunc.f32(float) #0
declare double @llvm.trunc.f64(double) #0
declare float @llvm.fabs.f32(float) #0
declare double @llvm.fabs.f64(double) #0
+declare half @llvm.minnum.f16(half, half) #0
declare float @llvm.minnum.f32(float, float) #0
declare double @llvm.minnum.f64(double, double) #0
+declare <2 x half> @llvm.minnum.v2f16(<2 x half>, <2 x half>) #0
+declare half @llvm.maxnum.f16(half, half) #0
declare float @llvm.maxnum.f32(float, float) #0
declare double @llvm.maxnum.f64(double, double) #0
+declare <2 x half> @llvm.maxnum.v2f16(<2 x half>, <2 x half>) #0
declare float @llvm.fma.f32(float, float, float) #0
declare double @llvm.fma.f64(double, double, double) #0
@@ -193,6 +199,14 @@ define double @abs_double(double %a) {
; ---- min ----
+; CHECK-LABEL: min_half
+define half @min_half(half %a, half %b) {
+ ; CHECK-NOF16: min.f32
+ ; CHECK-F16: min.f16
+ %x = call half @llvm.minnum.f16(half %a, half %b)
+ ret half %x
+}
+
; CHECK-LABEL: min_float
define float @min_float(float %a, float %b) {
; CHECK: min.f32
@@ -228,8 +242,25 @@ define double @min_double(double %a, double %b) {
ret double %x
}
+; CHECK-LABEL: min_v2half
+define <2 x half> @min_v2half(<2 x half> %a, <2 x half> %b) {
+ ; CHECK-NOF16: min.f32
+ ; CHECK-NOF16: min.f32
+ ; CHECK-F16: min.f16x2
+ %x = call <2 x half> @llvm.minnum.v2f16(<2 x half> %a, <2 x half> %b)
+ ret <2 x half> %x
+}
+
; ---- max ----
+; CHECK-LABEL: max_half
+define half @max_half(half %a, half %b) {
+ ; CHECK-NOF16: max.f32
+ ; CHECK-F16: max.f16
+ %x = call half @llvm.maxnum.f16(half %a, half %b)
+ ret half %x
+}
+
; CHECK-LABEL: max_imm1
define float @max_imm1(float %a) {
; CHECK: max.f32
@@ -265,6 +296,15 @@ define double @max_double(double %a, double %b) {
ret double %x
}
+; CHECK-LABEL: max_v2half
+define <2 x half> @max_v2half(<2 x half> %a, <2 x half> %b) {
+ ; CHECK-NOF16: max.f32
+ ; CHECK-NOF16: max.f32
+ ; CHECK-F16: max.f16x2
+ %x = call <2 x half> @llvm.maxnum.v2f16(<2 x half> %a, <2 x half> %b)
+ ret <2 x half> %x
+}
+
; ---- fma ----
; CHECK-LABEL: @fma_float
More information about the llvm-commits mailing list