[llvm] 3ccaabe - [NVPTX] Lower llvm.roundeven to cvt.rni

Benjamin Kramer via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 25 04:37:11 PDT 2022


Author: Benjamin Kramer
Date: 2022-08-25T13:36:22+02:00
New Revision: 3ccaabe0517d70ce0bd47305e0d2477df3b0dff9

URL: https://github.com/llvm/llvm-project/commit/3ccaabe0517d70ce0bd47305e0d2477df3b0dff9
DIFF: https://github.com/llvm/llvm-project/commit/3ccaabe0517d70ce0bd47305e0d2477df3b0dff9.diff

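LOG: [NVPTX] Lower llvm.roundeven to cvt.rni

PTX's cvt.rni rounds to the nearest integer and breaks ties toward the even
integer, which is exactly the semantics of llvm.roundeven. ISD::FROUNDEVEN is
therefore marked Legal for f16, f32 and f64 and selected directly to cvt.rni
(cvt.rni.ftz when flushing denormals to zero).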

Added:
    (none)

Modified: 
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
    llvm/test/CodeGen/NVPTX/f16-instructions.ll
    llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
    llvm/test/CodeGen/NVPTX/math-intrins.ll

Removed:
    (none)


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 8264032b765a9..ab2340f4a7af8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -571,7 +571,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
 
   // These map to conversion instructions for scalar FP types.
   for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
-                         ISD::FTRUNC}) {
+                         ISD::FROUNDEVEN, ISD::FTRUNC}) {
     setOperationAction(Op, MVT::f16, Legal);
     setOperationAction(Op, MVT::f32, Legal);
     setOperationAction(Op, MVT::f64, Legal);

diff  --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 7b66d6280e964..a4c67cd4564c3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3084,7 +3084,7 @@ def : Pat<(f64 (fpextend Float32Regs:$a)),
 def retflag : SDNode<"NVPTXISD::RET_FLAG", SDTNone,
                      [SDNPHasChain, SDNPOptInGlue]>;
 
-// fceil, ffloor, fround, ftrunc.
+// fceil, ffloor, froundeven, ftrunc.
 
 multiclass CVT_ROUND<SDNode OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
   def : Pat<(OpNode Float16Regs:$a),
@@ -3099,6 +3099,7 @@ multiclass CVT_ROUND<SDNode OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
 
 defm : CVT_ROUND<fceil, CvtRPI, CvtRPI_FTZ>;
 defm : CVT_ROUND<ffloor, CvtRMI, CvtRMI_FTZ>;
+defm : CVT_ROUND<froundeven, CvtRNI, CvtRNI_FTZ>;
 defm : CVT_ROUND<ftrunc, CvtRZI, CvtRZI_FTZ>;
 
 // nearbyint and rint are implemented as rounding to nearest even.  This isn't

diff  --git a/llvm/test/CodeGen/NVPTX/f16-instructions.ll b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
index d623cda75ec40..4f59c44216a2d 100644
--- a/llvm/test/CodeGen/NVPTX/f16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
@@ -847,6 +847,7 @@ declare half @llvm.trunc.f16(half %a) #0
 declare half @llvm.rint.f16(half %a) #0
 declare half @llvm.nearbyint.f16(half %a) #0
 declare half @llvm.round.f16(half %a) #0
+declare half @llvm.roundeven.f16(half %a) #0
 declare half @llvm.fmuladd.f16(half %a, half %b, half %c) #0
 
 ; CHECK-LABEL: test_sqrt(
@@ -1127,6 +1128,16 @@ define half @test_nearbyint(half %a) #0 {
   ret half %r
 }
 
+; CHECK-LABEL: test_roundeven(
+; CHECK:      ld.param.b16    [[A:%h[0-9]+]], [test_roundeven_param_0];
+; CHECK:      cvt.rni.f16.f16 [[R:%h[0-9]+]], [[A]];
+; CHECK:      st.param.b16    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define half @test_roundeven(half %a) #0 {
+  %r = call half @llvm.roundeven.f16(half %a)
+  ret half %r
+}
+
 ; CHECK-LABEL: test_round(
 ; CHECK:      ld.param.b16    {{.*}}, [test_round_param_0];
 ; check the use of sign mask and 0.5 to implement round

diff  --git a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
index 44d83b967054e..f45093ea74d33 100644
--- a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
@@ -1059,6 +1059,7 @@ declare <2 x half> @llvm.trunc.f16(<2 x half> %a) #0
 declare <2 x half> @llvm.rint.f16(<2 x half> %a) #0
 declare <2 x half> @llvm.nearbyint.f16(<2 x half> %a) #0
 declare <2 x half> @llvm.round.f16(<2 x half> %a) #0
+declare <2 x half> @llvm.roundeven.f16(<2 x half> %a) #0
 declare <2 x half> @llvm.fmuladd.f16(<2 x half> %a, <2 x half> %b, <2 x half> %c) #0
 
 ; CHECK-LABEL: test_sqrt(
@@ -1426,6 +1427,19 @@ define <2 x half> @test_nearbyint(<2 x half> %a) #0 {
   ret <2 x half> %r
 }
 
+; CHECK-LABEL: test_roundeven(
+; CHECK:      ld.param.b32    [[A:%hh[0-9]+]], [test_roundeven_param_0];
+; CHECK-DAG:  mov.b32         {[[A0:%h[0-9]+]], [[A1:%h[0-9]+]]}, [[A]];
+; CHECK-DAG:  cvt.rni.f16.f16 [[R1:%h[0-9]+]], [[A1]];
+; CHECK-DAG:  cvt.rni.f16.f16 [[R0:%h[0-9]+]], [[A0]];
+; CHECK:      mov.b32         [[R:%hh[0-9]+]], {[[R0]], [[R1]]}
+; CHECK:      st.param.b32    [func_retval0+0], [[R]];
+; CHECK:      ret;
+define <2 x half> @test_roundeven(<2 x half> %a) #0 {
+  %r = call <2 x half> @llvm.roundeven.f16(<2 x half> %a)
+  ret <2 x half> %r
+}
+
 ; CHECK-LABEL: test_round(
 ; CHECK:      ld.param.b32    {{.*}}, [test_round_param_0];
 ; check the use of sign mask and 0.5 to implement round

diff  --git a/llvm/test/CodeGen/NVPTX/math-intrins.ll b/llvm/test/CodeGen/NVPTX/math-intrins.ll
index efd8ed0a37b7a..d31844549a322 100644
--- a/llvm/test/CodeGen/NVPTX/math-intrins.ll
+++ b/llvm/test/CodeGen/NVPTX/math-intrins.ll
@@ -19,6 +19,8 @@ declare float @llvm.nearbyint.f32(float) #0
 declare double @llvm.nearbyint.f64(double) #0
 declare float @llvm.rint.f32(float) #0
 declare double @llvm.rint.f64(double) #0
+declare float @llvm.roundeven.f32(float) #0
+declare double @llvm.roundeven.f64(double) #0
 declare float @llvm.trunc.f32(float) #0
 declare double @llvm.trunc.f64(double) #0
 declare float @llvm.fabs.f32(float) #0
@@ -155,6 +157,29 @@ define double @rint_double(double %a) {
   ret double %b
 }
 
+; ---- roundeven ----
+
+; CHECK-LABEL: roundeven_float
+define float @roundeven_float(float %a) {
+  ; CHECK: cvt.rni.f32.f32
+  %b = call float @llvm.roundeven.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: roundeven_float_ftz
+define float @roundeven_float_ftz(float %a) #1 {
+  ; CHECK: cvt.rni.ftz.f32.f32
+  %b = call float @llvm.roundeven.f32(float %a)
+  ret float %b
+}
+
+; CHECK-LABEL: roundeven_double
+define double @roundeven_double(double %a) {
+  ; CHECK: cvt.rni.f64.f64
+  %b = call double @llvm.roundeven.f64(double %a)
+  ret double %b
+}
+
 ; ---- trunc ----
 
 ; CHECK-LABEL: trunc_float


        


More information about the llvm-commits mailing list