[llvm] 2c3f82b - [NVPTX] Fix NVPTX lowering of frem when denominator is infinite.

Christian Sigg via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 5 00:28:02 PST 2023


Author: Benjamin Chetioui
Date: 2023-01-05T09:27:54+01:00
New Revision: 2c3f82b7759691f3b67f7e5940e95ac3434b1a9c

URL: https://github.com/llvm/llvm-project/commit/2c3f82b7759691f3b67f7e5940e95ac3434b1a9c
DIFF: https://github.com/llvm/llvm-project/commit/2c3f82b7759691f3b67f7e5940e95ac3434b1a9c.diff

LOG: [NVPTX] Fix NVPTX lowering of frem when denominator is infinite.

For finite x, `frem x, {+,-}inf` must return x to match the specification of LLVM's frem (which follows libm's fmod).

Reviewed By: tra

Differential Revision: https://reviews.llvm.org/D140846

Added: 
    

Modified: 
    llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
    llvm/test/CodeGen/NVPTX/f16-instructions.ll
    llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
    llvm/test/CodeGen/NVPTX/fast-math.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a114d92397c91..b6a1394119805 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -132,6 +132,7 @@ def doMulWide      : Predicate<"doMulWide">;
 def allowFMA : Predicate<"allowFMA()">;
 def noFMA : Predicate<"!allowFMA()">;
 def allowUnsafeFPMath : Predicate<"allowUnsafeFPMath()">;
+def noUnsafeFPMath : Predicate<"!allowUnsafeFPMath()">; // Inverse of above.
 
 def do_DIVF32_APPROX : Predicate<"getDivF32Level()==0">;
 def do_DIVF32_FULL : Predicate<"getDivF32Level()==1">;
@@ -166,7 +167,7 @@ def hasSM80 : Predicate<"Subtarget->getSmVersion() >= 80">;
 def hasSM86 : Predicate<"Subtarget->getSmVersion() >= 86">;
 
 // non-sync shfl instructions are not available on sm_70+ in PTX6.4+
-def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70" 
+def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
                           "&& Subtarget->getPTXVersion() >= 64)">;
 
 def useShortPtr : Predicate<"useShortPointers()">;
@@ -192,7 +193,7 @@ class ValueToRegClass<ValueType T> {
      !eq(name, "af32"): Float32ArgRegs,
      !eq(name, "if64"): Float64ArgRegs,
     );
-} 
+}
 
 
 //===----------------------------------------------------------------------===//
@@ -597,6 +598,99 @@ multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
   defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>;
 }
 
+//-----------------------------------
+// Selection instructions (selp)
+//-----------------------------------
+
+// TODO: slct (select on the sign of a third operand) is not defined yet.
+
+// SELP defines selp variants without selection patterns; they are only used
+// explicitly within this file. SELP_PATTERN also matches (select $p, $a, $b).
+let hasSideEffects = false in {
+  multiclass SELP<string TypeStr, RegisterClass RC, Operand ImmCls> {
+    def rr : NVPTXInst<(outs RC:$dst),
+                       (ins RC:$a, RC:$b, Int1Regs:$p),
+                       !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
+    def ri : NVPTXInst<(outs RC:$dst),
+                       (ins RC:$a, ImmCls:$b, Int1Regs:$p),
+                       !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
+    def ir : NVPTXInst<(outs RC:$dst),
+                       (ins ImmCls:$a, RC:$b, Int1Regs:$p),
+                       !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
+    def ii : NVPTXInst<(outs RC:$dst),
+                       (ins ImmCls:$a, ImmCls:$b, Int1Regs:$p),
+                       !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
+  }
+
+  multiclass SELP_PATTERN<string TypeStr, ValueType T, RegisterClass RC,
+                          Operand ImmCls, SDNode ImmNode> {
+    def rr :
+      NVPTXInst<(outs RC:$dst),
+                (ins RC:$a, RC:$b, Int1Regs:$p),
+                !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
+                [(set (T RC:$dst), (select Int1Regs:$p, (T RC:$a), (T RC:$b)))]>;
+    def ri :
+      NVPTXInst<(outs RC:$dst),
+                (ins RC:$a, ImmCls:$b, Int1Regs:$p),
+                !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
+                [(set (T RC:$dst), (select Int1Regs:$p, (T RC:$a), (T ImmNode:$b)))]>;
+    def ir :
+      NVPTXInst<(outs RC:$dst),
+                (ins ImmCls:$a, RC:$b, Int1Regs:$p),
+                !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
+                [(set (T RC:$dst), (select Int1Regs:$p, ImmNode:$a, (T RC:$b)))]>;
+    def ii :
+      NVPTXInst<(outs RC:$dst),
+                (ins ImmCls:$a, ImmCls:$b, Int1Regs:$p),
+                !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
+                [(set (T RC:$dst), (select Int1Regs:$p, ImmNode:$a, ImmNode:$b))]>;
+  }
+}
+
+// Don't pattern match on selp.{s,u}{16,32,64} -- selp.b{16,32,64} is just as
+// good.
+defm SELP_b16 : SELP_PATTERN<"b16", i16, Int16Regs, i16imm, imm>;
+defm SELP_s16 : SELP<"s16", Int16Regs, i16imm>;
+defm SELP_u16 : SELP<"u16", Int16Regs, i16imm>;
+defm SELP_b32 : SELP_PATTERN<"b32", i32, Int32Regs, i32imm, imm>;
+defm SELP_s32 : SELP<"s32", Int32Regs, i32imm>;
+defm SELP_u32 : SELP<"u32", Int32Regs, i32imm>;
+defm SELP_b64 : SELP_PATTERN<"b64", i64, Int64Regs, i64imm, imm>;
+defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>;
+defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>;
+defm SELP_f16 : SELP_PATTERN<"b16", f16, Float16Regs, f16imm, fpimm>;
+
+defm SELP_f32 : SELP_PATTERN<"f32", f32, Float32Regs, f32imm, fpimm>;
+defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>;
+
+// This does not work as tablegen fails to infer the type of 'imm'.
+// def v2f16imm : Operand<v2f16>;
+// defm SELP_f16x2 : SELP_PATTERN<"b32", v2f16, Float16x2Regs, v2f16imm, imm>;
+
+def SELP_f16x2rr :
+    NVPTXInst<(outs Float16x2Regs:$dst),
+              (ins Float16x2Regs:$a, Float16x2Regs:$b, Int1Regs:$p),
+              "selp.b32 \t$dst, $a, $b, $p;",
+              [(set Float16x2Regs:$dst,
+                    (select Int1Regs:$p, (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>;
+
+//-----------------------------------
+// Test instructions: testp.infinite sets predicate $p iff $a is +/-inf.
+//-----------------------------------
+
+def TESTINF_f32r : NVPTXInst<(outs Int1Regs:$p), (ins Float32Regs:$a),
+                             "testp.infinite.f32 \t$p, $a;",
+                             []>;
+def TESTINF_f32i : NVPTXInst<(outs Int1Regs:$p), (ins f32imm:$a),
+                             "testp.infinite.f32 \t$p, $a;",
+                             []>;
+def TESTINF_f64r : NVPTXInst<(outs Int1Regs:$p), (ins Float64Regs:$a),
+                             "testp.infinite.f64 \t$p, $a;",
+                             []>;
+def TESTINF_f64i : NVPTXInst<(outs Int1Regs:$p), (ins f64imm:$a),
+                             "testp.infinite.f64 \t$p, $a;",
+                             []>;
+
 //-----------------------------------
 // Integer Arithmetic
 //-----------------------------------
@@ -1154,39 +1248,89 @@ def COSF:  NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
                       Requires<[allowUnsafeFPMath]>;
 
 // Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
-// i.e. "poor man's fmod()"
+// i.e. "poor man's fmod()". When y is infinite, x is returned, which matches
+// LLVM's frem for finite x. TODO: frem(+-inf, +-inf) should be NaN, not x.
 
 // frem - f32 FTZ
 def : Pat<(frem Float32Regs:$x, Float32Regs:$y),
           (FSUBf32rr_ftz Float32Regs:$x, (FMULf32rr_ftz (CVT_f32_f32
             (FDIV32rr_prec_ftz Float32Regs:$x, Float32Regs:$y), CvtRZI_FTZ),
              Float32Regs:$y))>,
-          Requires<[doF32FTZ]>;
+          Requires<[doF32FTZ, allowUnsafeFPMath]>;
 def : Pat<(frem Float32Regs:$x, fpimm:$y),
           (FSUBf32rr_ftz Float32Regs:$x, (FMULf32ri_ftz (CVT_f32_f32
             (FDIV32ri_prec_ftz Float32Regs:$x, fpimm:$y), CvtRZI_FTZ),
              fpimm:$y))>,
-          Requires<[doF32FTZ]>;
+          Requires<[doF32FTZ, allowUnsafeFPMath]>;
+
+def : Pat<(frem Float32Regs:$x, Float32Regs:$y),
+          (SELP_f32rr Float32Regs:$x,
+            (FSUBf32rr_ftz Float32Regs:$x, (FMULf32rr_ftz (CVT_f32_f32
+              (FDIV32rr_prec_ftz Float32Regs:$x, Float32Regs:$y), CvtRZI_FTZ),
+              Float32Regs:$y)),
+            (TESTINF_f32r Float32Regs:$y))>,
+          Requires<[doF32FTZ, noUnsafeFPMath]>;
+def : Pat<(frem Float32Regs:$x, fpimm:$y),
+          (SELP_f32rr Float32Regs:$x,
+            (FSUBf32rr_ftz Float32Regs:$x, (FMULf32ri_ftz (CVT_f32_f32
+              (FDIV32ri_prec_ftz Float32Regs:$x, fpimm:$y), CvtRZI_FTZ),
+              fpimm:$y)),
+            (TESTINF_f32i fpimm:$y))>,
+          Requires<[doF32FTZ, noUnsafeFPMath]>;
 
 // frem - f32
 def : Pat<(frem Float32Regs:$x, Float32Regs:$y),
           (FSUBf32rr Float32Regs:$x, (FMULf32rr (CVT_f32_f32
             (FDIV32rr_prec Float32Regs:$x, Float32Regs:$y), CvtRZI),
-             Float32Regs:$y))>;
+             Float32Regs:$y))>,
+          Requires<[allowUnsafeFPMath]>;
 def : Pat<(frem Float32Regs:$x, fpimm:$y),
           (FSUBf32rr Float32Regs:$x, (FMULf32ri (CVT_f32_f32
             (FDIV32ri_prec Float32Regs:$x, fpimm:$y), CvtRZI),
-             fpimm:$y))>;
+             fpimm:$y))>,
+          Requires<[allowUnsafeFPMath]>;
+
+def : Pat<(frem Float32Regs:$x, Float32Regs:$y),
+          (SELP_f32rr Float32Regs:$x,
+            (FSUBf32rr Float32Regs:$x, (FMULf32rr (CVT_f32_f32
+              (FDIV32rr_prec Float32Regs:$x, Float32Regs:$y), CvtRZI),
+              Float32Regs:$y)),
+            (TESTINF_f32r Float32Regs:$y))>,
+          Requires<[noUnsafeFPMath]>;
+def : Pat<(frem Float32Regs:$x, fpimm:$y),
+          (SELP_f32rr Float32Regs:$x,
+            (FSUBf32rr Float32Regs:$x, (FMULf32ri (CVT_f32_f32
+              (FDIV32ri_prec Float32Regs:$x, fpimm:$y), CvtRZI),
+              fpimm:$y)),
+            (TESTINF_f32i fpimm:$y))>,
+          Requires<[noUnsafeFPMath]>;
 
 // frem - f64
 def : Pat<(frem Float64Regs:$x, Float64Regs:$y),
           (FSUBf64rr Float64Regs:$x, (FMULf64rr (CVT_f64_f64
             (FDIV64rr Float64Regs:$x, Float64Regs:$y), CvtRZI),
-             Float64Regs:$y))>;
+             Float64Regs:$y))>,
+          Requires<[allowUnsafeFPMath]>;
 def : Pat<(frem Float64Regs:$x, fpimm:$y),
           (FSUBf64rr Float64Regs:$x, (FMULf64ri (CVT_f64_f64
             (FDIV64ri Float64Regs:$x, fpimm:$y), CvtRZI),
-             fpimm:$y))>;
+             fpimm:$y))>,
+          Requires<[allowUnsafeFPMath]>;
+
+def : Pat<(frem Float64Regs:$x, Float64Regs:$y),
+          (SELP_f64rr Float64Regs:$x,
+            (FSUBf64rr Float64Regs:$x, (FMULf64rr (CVT_f64_f64
+              (FDIV64rr Float64Regs:$x, Float64Regs:$y), CvtRZI),
+               Float64Regs:$y)),
+            (TESTINF_f64r Float64Regs:$y))>,
+          Requires<[noUnsafeFPMath]>;
+def : Pat<(frem Float64Regs:$x, fpimm:$y),
+          (SELP_f64rr Float64Regs:$x,
+            (FSUBf64rr Float64Regs:$x, (FMULf64ri (CVT_f64_f64
+              (FDIV64ri Float64Regs:$x, fpimm:$y), CvtRZI),
+              fpimm:$y)),
+            (TESTINF_f64i fpimm:$y))>,
+          Requires<[noUnsafeFPMath]>;
 
 //-----------------------------------
 // Bitwise operations
@@ -1568,82 +1712,6 @@ defm SET_f16 : SET<"f16", Float16Regs, f16imm>;
 defm SET_f32 : SET<"f32", Float32Regs, f32imm>;
 defm SET_f64 : SET<"f64", Float64Regs, f64imm>;
 
-//-----------------------------------
-// Selection instructions (selp)
-//-----------------------------------
-
-// FIXME: Missing slct
-
-// selp instructions that don't have any pattern matches; we explicitly use
-// them within this file.
-let hasSideEffects = false in {
-  multiclass SELP<string TypeStr, RegisterClass RC, Operand ImmCls> {
-    def rr : NVPTXInst<(outs RC:$dst),
-                       (ins RC:$a, RC:$b, Int1Regs:$p),
-                       !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
-    def ri : NVPTXInst<(outs RC:$dst),
-                       (ins RC:$a, ImmCls:$b, Int1Regs:$p),
-                       !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
-    def ir : NVPTXInst<(outs RC:$dst),
-                       (ins ImmCls:$a, RC:$b, Int1Regs:$p),
-                       !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
-    def ii : NVPTXInst<(outs RC:$dst),
-                       (ins ImmCls:$a, ImmCls:$b, Int1Regs:$p),
-                       !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"), []>;
-  }
-
-  multiclass SELP_PATTERN<string TypeStr, ValueType T, RegisterClass RC,
-                          Operand ImmCls, SDNode ImmNode> {
-    def rr :
-      NVPTXInst<(outs RC:$dst),
-                (ins RC:$a, RC:$b, Int1Regs:$p),
-                !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
-                [(set (T RC:$dst), (select Int1Regs:$p, (T RC:$a), (T RC:$b)))]>;
-    def ri :
-      NVPTXInst<(outs RC:$dst),
-                (ins RC:$a, ImmCls:$b, Int1Regs:$p),
-                !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
-                [(set (T RC:$dst), (select Int1Regs:$p, (T RC:$a), (T ImmNode:$b)))]>;
-    def ir :
-      NVPTXInst<(outs RC:$dst),
-                (ins ImmCls:$a, RC:$b, Int1Regs:$p),
-                !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
-                [(set (T RC:$dst), (select Int1Regs:$p, ImmNode:$a, (T RC:$b)))]>;
-    def ii :
-      NVPTXInst<(outs RC:$dst),
-                (ins ImmCls:$a, ImmCls:$b, Int1Regs:$p),
-                !strconcat("selp.", TypeStr, " \t$dst, $a, $b, $p;"),
-                [(set (T RC:$dst), (select Int1Regs:$p, ImmNode:$a, ImmNode:$b))]>;
-  }
-}
-
-// Don't pattern match on selp.{s,u}{16,32,64} -- selp.b{16,32,64} is just as
-// good.
-defm SELP_b16 : SELP_PATTERN<"b16", i16, Int16Regs, i16imm, imm>;
-defm SELP_s16 : SELP<"s16", Int16Regs, i16imm>;
-defm SELP_u16 : SELP<"u16", Int16Regs, i16imm>;
-defm SELP_b32 : SELP_PATTERN<"b32", i32, Int32Regs, i32imm, imm>;
-defm SELP_s32 : SELP<"s32", Int32Regs, i32imm>;
-defm SELP_u32 : SELP<"u32", Int32Regs, i32imm>;
-defm SELP_b64 : SELP_PATTERN<"b64", i64, Int64Regs, i64imm, imm>;
-defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>;
-defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>;
-defm SELP_f16 : SELP_PATTERN<"b16", f16, Float16Regs, f16imm, fpimm>;
-
-defm SELP_f32 : SELP_PATTERN<"f32", f32, Float32Regs, f32imm, fpimm>;
-defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>;
-
-// This does not work as tablegen fails to infer the type of 'imm'.
-//def v2f16imm : Operand<v2f16>;
-//defm SELP_f16x2 : SELP_PATTERN<"b32", v2f16, Float16x2Regs, v2f16imm, imm>;
-
-def SELP_f16x2rr :
-    NVPTXInst<(outs Float16x2Regs:$dst),
-              (ins Float16x2Regs:$a, Float16x2Regs:$b, Int1Regs:$p),
-              "selp.b32 \t$dst, $a, $b, $p;",
-              [(set Float16x2Regs:$dst,
-                    (select Int1Regs:$p, (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>;
-
 //-----------------------------------
 // Data Movement (Load / Store, Move)
 //-----------------------------------
@@ -1879,7 +1947,7 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
   def : Pat<(i1 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))),
             (SETP_f16rr Float16Regs:$a, Float16Regs:$b, ModeFTZ)>,
         Requires<[useFP16Math,doF32FTZ]>;
-  def : Pat<(i1 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))), 
+  def : Pat<(i1 (OpNode (f16 Float16Regs:$a), (f16 Float16Regs:$b))),
             (SETP_f16rr Float16Regs:$a, Float16Regs:$b, Mode)>,
         Requires<[useFP16Math]>;
   def : Pat<(i1 (OpNode (f16 Float16Regs:$a), fpimm:$b)),
@@ -2700,7 +2768,7 @@ let mayStore=1, hasSideEffects=0 in {
 
 //---- Conversion ----
 
-class F_BITCONVERT<string SzStr, ValueType TIn, ValueType TOut, 
+class F_BITCONVERT<string SzStr, ValueType TIn, ValueType TOut,
   NVPTXRegClass regclassIn = ValueToRegClass<TIn>.ret,
   NVPTXRegClass regclassOut = ValueToRegClass<TOut>.ret> :
            NVPTXInst<(outs regclassOut:$d), (ins regclassIn:$a),

diff  --git a/llvm/test/CodeGen/NVPTX/f16-instructions.ll b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
index ca432fe1715e1..2ed795de28ff1 100644
--- a/llvm/test/CodeGen/NVPTX/f16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
@@ -196,8 +196,10 @@ define half @test_fdiv(half %a, half %b) #0 {
 ; CHECK-F16-FTZ-NEXT: cvt.rzi.ftz.f32.f32 [[DI:%f[0-9]+]], [[D]];
 ; CHECK-F16-FTZ-NEXT: mul.ftz.f32         [[RI:%f[0-9]+]], [[DI]], [[FB]];
 ; CHECK-F16-FTZ-NEXT: sub.ftz.f32         [[RF:%f[0-9]+]], [[FA]], [[RI]];
-; CHECK-NEXT: cvt.rn.f16.f32  [[R:%h[0-9]+]], [[RF]];
-; CHECK-NEXT: st.param.b16    [func_retval0+0], [[R]];
+; CHECK-NEXT: testp.infinite.f32 [[BINF:%p[0-9]+]], [[FB]];
+; CHECK-NEXT: selp.f32           [[FRES:%f[0-9]+]], [[FA]], [[RF]], [[BINF]];
+; CHECK-NEXT: cvt.rn.f16.f32     [[R:%h[0-9]+]], [[FRES]];
+; CHECK-NEXT: st.param.b16       [func_retval0+0], [[R]];
 ; CHECK-NEXT: ret;
 define half @test_frem(half %a, half %b) #0 {
   %r = frem half %a, %b

diff  --git a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
index c83e370af1fbf..4cbe46b633ac8 100644
--- a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
@@ -240,26 +240,30 @@ define <2 x half> @test_fdiv(<2 x half> %a, <2 x half> %b) #0 {
 
 ; CHECK-LABEL: test_frem(
 ; -- Load two 16x2 inputs and split them into f16 elements
-; CHECK-DAG:  ld.param.b32    [[A:%hh[0-9]+]], [test_frem_param_0];
-; CHECK-DAG:  ld.param.b32    [[B:%hh[0-9]+]], [test_frem_param_1];
+; CHECK-DAG:  ld.param.b32       [[A:%hh[0-9]+]], [test_frem_param_0];
+; CHECK-DAG:  ld.param.b32       [[B:%hh[0-9]+]], [test_frem_param_1];
 ; -- Split into elements
-; CHECK-DAG:  mov.b32         {[[A0:%h[0-9]+]], [[A1:%h[0-9]+]]}, [[A]]
-; CHECK-DAG:  mov.b32         {[[B0:%h[0-9]+]], [[B1:%h[0-9]+]]}, [[B]]
+; CHECK-DAG:  mov.b32            {[[A0:%h[0-9]+]], [[A1:%h[0-9]+]]}, [[A]]
+; CHECK-DAG:  mov.b32            {[[B0:%h[0-9]+]], [[B1:%h[0-9]+]]}, [[B]]
 ; -- promote to f32.
-; CHECK-DAG:  cvt.f32.f16     [[FA0:%f[0-9]+]], [[A0]];
-; CHECK-DAG:  cvt.f32.f16     [[FB0:%f[0-9]+]], [[B0]];
-; CHECK-DAG:  cvt.f32.f16     [[FA1:%f[0-9]+]], [[A1]];
-; CHECK-DAG:  cvt.f32.f16     [[FB1:%f[0-9]+]], [[B1]];
+; CHECK-DAG:  cvt.f32.f16        [[FA0:%f[0-9]+]], [[A0]];
+; CHECK-DAG:  cvt.f32.f16        [[FB0:%f[0-9]+]], [[B0]];
+; CHECK-DAG:  cvt.f32.f16        [[FA1:%f[0-9]+]], [[A1]];
+; CHECK-DAG:  cvt.f32.f16        [[FB1:%f[0-9]+]], [[B1]];
 ; -- frem(a[0],b[0]).
-; CHECK-DAG:  div.rn.f32      [[FD0:%f[0-9]+]], [[FA0]], [[FB0]];
-; CHECK-DAG:  cvt.rzi.f32.f32 [[DI0:%f[0-9]+]], [[FD0]];
-; CHECK-DAG:  mul.f32         [[RI0:%f[0-9]+]], [[DI0]], [[FB0]];
-; CHECK-DAG:  sub.f32         [[RF0:%f[0-9]+]], [[FA0]], [[RI0]];
+; CHECK-DAG:  div.rn.f32         [[FD0:%f[0-9]+]], [[FA0]], [[FB0]];
+; CHECK-DAG:  cvt.rzi.f32.f32    [[DI0:%f[0-9]+]], [[FD0]];
+; CHECK-DAG:  mul.f32            [[RI0:%f[0-9]+]], [[DI0]], [[FB0]];
+; CHECK-DAG:  sub.f32            [[RSUB0:%f[0-9]+]], [[FA0]], [[RI0]];
+; CHECK-DAG:  testp.infinite.f32 [[B0INF:%p[0-9]+]], [[FB0]];
+; CHECK-DAG:  selp.f32           [[RF0:%f[0-9]+]], [[FA0]], [[RSUB0]], [[B0INF]];
 ; -- frem(a[1],b[1]).
-; CHECK-DAG:  div.rn.f32      [[FD1:%f[0-9]+]], [[FA1]], [[FB1]];
-; CHECK-DAG:  cvt.rzi.f32.f32 [[DI1:%f[0-9]+]], [[FD1]];
-; CHECK-DAG:  mul.f32         [[RI1:%f[0-9]+]], [[DI1]], [[FB1]];
-; CHECK-DAG:  sub.f32         [[RF1:%f[0-9]+]], [[FA1]], [[RI1]];
+; CHECK-DAG:  div.rn.f32         [[FD1:%f[0-9]+]], [[FA1]], [[FB1]];
+; CHECK-DAG:  cvt.rzi.f32.f32    [[DI1:%f[0-9]+]], [[FD1]];
+; CHECK-DAG:  mul.f32            [[RI1:%f[0-9]+]], [[DI1]], [[FB1]];
+; CHECK-DAG:  sub.f32            [[RSUB1:%f[0-9]+]], [[FA1]], [[RI1]];
+; CHECK-DAG:  testp.infinite.f32 [[B1INF:%p[0-9]+]], [[FB1]];
+; CHECK-DAG:  selp.f32           [[RF1:%f[0-9]+]], [[FA1]], [[RSUB1]], [[B1INF]];
 ; -- convert back to f16.
 ; CHECK-DAG:  cvt.rn.f16.f32  [[R0:%h[0-9]+]], [[RF0]];
 ; CHECK-DAG:  cvt.rn.f16.f32  [[R1:%h[0-9]+]], [[RF1]];

diff  --git a/llvm/test/CodeGen/NVPTX/fast-math.ll b/llvm/test/CodeGen/NVPTX/fast-math.ll
index ceeef54a9295f..cf86de2d54938 100644
--- a/llvm/test/CodeGen/NVPTX/fast-math.ll
+++ b/llvm/test/CodeGen/NVPTX/fast-math.ll
@@ -241,5 +241,26 @@ define float @repeated_div_fast_ftz_sel(i1 %pred, float %a, float %b, float %div
   ret float %w
 }
 
+; CHECK-LABEL: frem(
+define float @frem(float %a, float %b) #0 {
+  ; CHECK-NOT: testp.infinite
+  %rem = frem float %a, %b
+  ret float %rem
+}
+
+; CHECK-LABEL: frem_ftz(
+define float @frem_ftz(float %a, float %b) #0 #1 {
+  ; CHECK-NOT: testp.infinite
+  %rem = frem float %a, %b
+  ret float %rem
+}
+
+; CHECK-LABEL: frem_f64(
+define double @frem_f64(double %a, double %b) #0 {
+  ; CHECK-NOT: testp.infinite
+  %rem = frem double %a, %b
+  ret double %rem
+}
+
 attributes #0 = { "unsafe-fp-math" = "true" }
 attributes #1 = { "denormal-fp-math-f32" = "preserve-sign" }


        


More information about the llvm-commits mailing list