[llvm] [NVPTX] Use fast-math flags when lowering sin, cos, frem (PR #133121)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Wed Mar 26 16:26:27 PDT 2025
https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/133121
>From 0019ac76fa2babc9ebdbf3824d1696e0a91f34cd Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 26 Mar 2025 16:09:50 +0000
Subject: [PATCH 1/3] pre-commit tests
---
llvm/test/CodeGen/NVPTX/fast-math.ll | 14 ++
llvm/test/CodeGen/NVPTX/frem.ll | 300 +++++++++++++++++++++++++++
2 files changed, 314 insertions(+)
create mode 100644 llvm/test/CodeGen/NVPTX/frem.ll
diff --git a/llvm/test/CodeGen/NVPTX/fast-math.ll b/llvm/test/CodeGen/NVPTX/fast-math.ll
index d45ce15298f9d..4cb6a35e796fb 100644
--- a/llvm/test/CodeGen/NVPTX/fast-math.ll
+++ b/llvm/test/CodeGen/NVPTX/fast-math.ll
@@ -131,6 +131,20 @@ define float @fadd_ftz(float %a, float %b) #1 {
declare float @llvm.sin.f32(float)
declare float @llvm.cos.f32(float)
+; CHECK-LABEL: fsin_approx_afn
+; CHECK: sin.approx.f32
+define float @fsin_approx_afn(float %a) {
+ %r = tail call afn float @llvm.sin.f32(float %a)
+ ret float %r
+}
+
+; CHECK-LABEL: fcos_approx_afn
+; CHECK: cos.approx.f32
+define float @fcos_approx_afn(float %a) {
+ %r = tail call afn float @llvm.cos.f32(float %a)
+ ret float %r
+}
+
; CHECK-LABEL: fsin_approx
; CHECK: sin.approx.f32
define float @fsin_approx(float %a) #0 {
diff --git a/llvm/test/CodeGen/NVPTX/frem.ll b/llvm/test/CodeGen/NVPTX/frem.ll
new file mode 100644
index 0000000000000..f40bb1303820d
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/frem.ll
@@ -0,0 +1,300 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s --enable-unsafe-fp-math | FileCheck %s --check-prefixes=FAST
+; RUN: llc < %s | FileCheck %s --check-prefixes=NORMAL
+
+
+target triple = "nvptx64-unknown-cuda"
+
+define half @frem_f16(half %a, half %b) {
+; FAST-LABEL: frem_f16(
+; FAST: {
+; FAST-NEXT: .reg .b16 %rs<4>;
+; FAST-NEXT: .reg .f32 %f<7>;
+; FAST-EMPTY:
+; FAST-NEXT: // %bb.0:
+; FAST-NEXT: ld.param.b16 %rs1, [frem_f16_param_0];
+; FAST-NEXT: ld.param.b16 %rs2, [frem_f16_param_1];
+; FAST-NEXT: cvt.f32.f16 %f1, %rs2;
+; FAST-NEXT: cvt.f32.f16 %f2, %rs1;
+; FAST-NEXT: div.rn.f32 %f3, %f2, %f1;
+; FAST-NEXT: cvt.rzi.f32.f32 %f4, %f3;
+; FAST-NEXT: mul.f32 %f5, %f4, %f1;
+; FAST-NEXT: sub.f32 %f6, %f2, %f5;
+; FAST-NEXT: cvt.rn.f16.f32 %rs3, %f6;
+; FAST-NEXT: st.param.b16 [func_retval0], %rs3;
+; FAST-NEXT: ret;
+;
+; NORMAL-LABEL: frem_f16(
+; NORMAL: {
+; NORMAL-NEXT: .reg .pred %p<2>;
+; NORMAL-NEXT: .reg .b16 %rs<4>;
+; NORMAL-NEXT: .reg .f32 %f<8>;
+; NORMAL-EMPTY:
+; NORMAL-NEXT: // %bb.0:
+; NORMAL-NEXT: ld.param.b16 %rs1, [frem_f16_param_0];
+; NORMAL-NEXT: ld.param.b16 %rs2, [frem_f16_param_1];
+; NORMAL-NEXT: cvt.f32.f16 %f1, %rs2;
+; NORMAL-NEXT: cvt.f32.f16 %f2, %rs1;
+; NORMAL-NEXT: div.rn.f32 %f3, %f2, %f1;
+; NORMAL-NEXT: cvt.rzi.f32.f32 %f4, %f3;
+; NORMAL-NEXT: mul.f32 %f5, %f4, %f1;
+; NORMAL-NEXT: sub.f32 %f6, %f2, %f5;
+; NORMAL-NEXT: testp.infinite.f32 %p1, %f1;
+; NORMAL-NEXT: selp.f32 %f7, %f2, %f6, %p1;
+; NORMAL-NEXT: cvt.rn.f16.f32 %rs3, %f7;
+; NORMAL-NEXT: st.param.b16 [func_retval0], %rs3;
+; NORMAL-NEXT: ret;
+ %r = frem half %a, %b
+ ret half %r
+}
+
+define float @frem_f32(float %a, float %b) {
+; FAST-LABEL: frem_f32(
+; FAST: {
+; FAST-NEXT: .reg .f32 %f<7>;
+; FAST-EMPTY:
+; FAST-NEXT: // %bb.0:
+; FAST-NEXT: ld.param.f32 %f1, [frem_f32_param_0];
+; FAST-NEXT: ld.param.f32 %f2, [frem_f32_param_1];
+; FAST-NEXT: div.rn.f32 %f3, %f1, %f2;
+; FAST-NEXT: cvt.rzi.f32.f32 %f4, %f3;
+; FAST-NEXT: mul.f32 %f5, %f4, %f2;
+; FAST-NEXT: sub.f32 %f6, %f1, %f5;
+; FAST-NEXT: st.param.f32 [func_retval0], %f6;
+; FAST-NEXT: ret;
+;
+; NORMAL-LABEL: frem_f32(
+; NORMAL: {
+; NORMAL-NEXT: .reg .pred %p<2>;
+; NORMAL-NEXT: .reg .f32 %f<8>;
+; NORMAL-EMPTY:
+; NORMAL-NEXT: // %bb.0:
+; NORMAL-NEXT: ld.param.f32 %f1, [frem_f32_param_0];
+; NORMAL-NEXT: ld.param.f32 %f2, [frem_f32_param_1];
+; NORMAL-NEXT: div.rn.f32 %f3, %f1, %f2;
+; NORMAL-NEXT: cvt.rzi.f32.f32 %f4, %f3;
+; NORMAL-NEXT: mul.f32 %f5, %f4, %f2;
+; NORMAL-NEXT: sub.f32 %f6, %f1, %f5;
+; NORMAL-NEXT: testp.infinite.f32 %p1, %f2;
+; NORMAL-NEXT: selp.f32 %f7, %f1, %f6, %p1;
+; NORMAL-NEXT: st.param.f32 [func_retval0], %f7;
+; NORMAL-NEXT: ret;
+ %r = frem float %a, %b
+ ret float %r
+}
+
+define double @frem_f64(double %a, double %b) {
+; FAST-LABEL: frem_f64(
+; FAST: {
+; FAST-NEXT: .reg .f64 %fd<7>;
+; FAST-EMPTY:
+; FAST-NEXT: // %bb.0:
+; FAST-NEXT: ld.param.f64 %fd1, [frem_f64_param_0];
+; FAST-NEXT: ld.param.f64 %fd2, [frem_f64_param_1];
+; FAST-NEXT: div.rn.f64 %fd3, %fd1, %fd2;
+; FAST-NEXT: cvt.rzi.f64.f64 %fd4, %fd3;
+; FAST-NEXT: mul.f64 %fd5, %fd4, %fd2;
+; FAST-NEXT: sub.f64 %fd6, %fd1, %fd5;
+; FAST-NEXT: st.param.f64 [func_retval0], %fd6;
+; FAST-NEXT: ret;
+;
+; NORMAL-LABEL: frem_f64(
+; NORMAL: {
+; NORMAL-NEXT: .reg .pred %p<2>;
+; NORMAL-NEXT: .reg .f64 %fd<8>;
+; NORMAL-EMPTY:
+; NORMAL-NEXT: // %bb.0:
+; NORMAL-NEXT: ld.param.f64 %fd1, [frem_f64_param_0];
+; NORMAL-NEXT: ld.param.f64 %fd2, [frem_f64_param_1];
+; NORMAL-NEXT: div.rn.f64 %fd3, %fd1, %fd2;
+; NORMAL-NEXT: cvt.rzi.f64.f64 %fd4, %fd3;
+; NORMAL-NEXT: mul.f64 %fd5, %fd4, %fd2;
+; NORMAL-NEXT: sub.f64 %fd6, %fd1, %fd5;
+; NORMAL-NEXT: testp.infinite.f64 %p1, %fd2;
+; NORMAL-NEXT: selp.f64 %fd7, %fd1, %fd6, %p1;
+; NORMAL-NEXT: st.param.f64 [func_retval0], %fd7;
+; NORMAL-NEXT: ret;
+ %r = frem double %a, %b
+ ret double %r
+}
+
+define half @frem_f16_ninf(half %a, half %b) {
+; FAST-LABEL: frem_f16_ninf(
+; FAST: {
+; FAST-NEXT: .reg .b16 %rs<4>;
+; FAST-NEXT: .reg .f32 %f<7>;
+; FAST-EMPTY:
+; FAST-NEXT: // %bb.0:
+; FAST-NEXT: ld.param.b16 %rs1, [frem_f16_ninf_param_0];
+; FAST-NEXT: ld.param.b16 %rs2, [frem_f16_ninf_param_1];
+; FAST-NEXT: cvt.f32.f16 %f1, %rs2;
+; FAST-NEXT: cvt.f32.f16 %f2, %rs1;
+; FAST-NEXT: div.rn.f32 %f3, %f2, %f1;
+; FAST-NEXT: cvt.rzi.f32.f32 %f4, %f3;
+; FAST-NEXT: mul.f32 %f5, %f4, %f1;
+; FAST-NEXT: sub.f32 %f6, %f2, %f5;
+; FAST-NEXT: cvt.rn.f16.f32 %rs3, %f6;
+; FAST-NEXT: st.param.b16 [func_retval0], %rs3;
+; FAST-NEXT: ret;
+;
+; NORMAL-LABEL: frem_f16_ninf(
+; NORMAL: {
+; NORMAL-NEXT: .reg .pred %p<2>;
+; NORMAL-NEXT: .reg .b16 %rs<4>;
+; NORMAL-NEXT: .reg .f32 %f<8>;
+; NORMAL-EMPTY:
+; NORMAL-NEXT: // %bb.0:
+; NORMAL-NEXT: ld.param.b16 %rs1, [frem_f16_ninf_param_0];
+; NORMAL-NEXT: ld.param.b16 %rs2, [frem_f16_ninf_param_1];
+; NORMAL-NEXT: cvt.f32.f16 %f1, %rs2;
+; NORMAL-NEXT: cvt.f32.f16 %f2, %rs1;
+; NORMAL-NEXT: div.rn.f32 %f3, %f2, %f1;
+; NORMAL-NEXT: cvt.rzi.f32.f32 %f4, %f3;
+; NORMAL-NEXT: mul.f32 %f5, %f4, %f1;
+; NORMAL-NEXT: sub.f32 %f6, %f2, %f5;
+; NORMAL-NEXT: testp.infinite.f32 %p1, %f1;
+; NORMAL-NEXT: selp.f32 %f7, %f2, %f6, %p1;
+; NORMAL-NEXT: cvt.rn.f16.f32 %rs3, %f7;
+; NORMAL-NEXT: st.param.b16 [func_retval0], %rs3;
+; NORMAL-NEXT: ret;
+ %r = frem ninf half %a, %b
+ ret half %r
+}
+
+define float @frem_f32_ninf(float %a, float %b) {
+; FAST-LABEL: frem_f32_ninf(
+; FAST: {
+; FAST-NEXT: .reg .f32 %f<7>;
+; FAST-EMPTY:
+; FAST-NEXT: // %bb.0:
+; FAST-NEXT: ld.param.f32 %f1, [frem_f32_ninf_param_0];
+; FAST-NEXT: ld.param.f32 %f2, [frem_f32_ninf_param_1];
+; FAST-NEXT: div.rn.f32 %f3, %f1, %f2;
+; FAST-NEXT: cvt.rzi.f32.f32 %f4, %f3;
+; FAST-NEXT: mul.f32 %f5, %f4, %f2;
+; FAST-NEXT: sub.f32 %f6, %f1, %f5;
+; FAST-NEXT: st.param.f32 [func_retval0], %f6;
+; FAST-NEXT: ret;
+;
+; NORMAL-LABEL: frem_f32_ninf(
+; NORMAL: {
+; NORMAL-NEXT: .reg .pred %p<2>;
+; NORMAL-NEXT: .reg .f32 %f<8>;
+; NORMAL-EMPTY:
+; NORMAL-NEXT: // %bb.0:
+; NORMAL-NEXT: ld.param.f32 %f1, [frem_f32_ninf_param_0];
+; NORMAL-NEXT: ld.param.f32 %f2, [frem_f32_ninf_param_1];
+; NORMAL-NEXT: div.rn.f32 %f3, %f1, %f2;
+; NORMAL-NEXT: cvt.rzi.f32.f32 %f4, %f3;
+; NORMAL-NEXT: mul.f32 %f5, %f4, %f2;
+; NORMAL-NEXT: sub.f32 %f6, %f1, %f5;
+; NORMAL-NEXT: testp.infinite.f32 %p1, %f2;
+; NORMAL-NEXT: selp.f32 %f7, %f1, %f6, %p1;
+; NORMAL-NEXT: st.param.f32 [func_retval0], %f7;
+; NORMAL-NEXT: ret;
+ %r = frem ninf float %a, %b
+ ret float %r
+}
+
+define double @frem_f64_ninf(double %a, double %b) {
+; FAST-LABEL: frem_f64_ninf(
+; FAST: {
+; FAST-NEXT: .reg .f64 %fd<7>;
+; FAST-EMPTY:
+; FAST-NEXT: // %bb.0:
+; FAST-NEXT: ld.param.f64 %fd1, [frem_f64_ninf_param_0];
+; FAST-NEXT: ld.param.f64 %fd2, [frem_f64_ninf_param_1];
+; FAST-NEXT: div.rn.f64 %fd3, %fd1, %fd2;
+; FAST-NEXT: cvt.rzi.f64.f64 %fd4, %fd3;
+; FAST-NEXT: mul.f64 %fd5, %fd4, %fd2;
+; FAST-NEXT: sub.f64 %fd6, %fd1, %fd5;
+; FAST-NEXT: st.param.f64 [func_retval0], %fd6;
+; FAST-NEXT: ret;
+;
+; NORMAL-LABEL: frem_f64_ninf(
+; NORMAL: {
+; NORMAL-NEXT: .reg .pred %p<2>;
+; NORMAL-NEXT: .reg .f64 %fd<8>;
+; NORMAL-EMPTY:
+; NORMAL-NEXT: // %bb.0:
+; NORMAL-NEXT: ld.param.f64 %fd1, [frem_f64_ninf_param_0];
+; NORMAL-NEXT: ld.param.f64 %fd2, [frem_f64_ninf_param_1];
+; NORMAL-NEXT: div.rn.f64 %fd3, %fd1, %fd2;
+; NORMAL-NEXT: cvt.rzi.f64.f64 %fd4, %fd3;
+; NORMAL-NEXT: mul.f64 %fd5, %fd4, %fd2;
+; NORMAL-NEXT: sub.f64 %fd6, %fd1, %fd5;
+; NORMAL-NEXT: testp.infinite.f64 %p1, %fd2;
+; NORMAL-NEXT: selp.f64 %fd7, %fd1, %fd6, %p1;
+; NORMAL-NEXT: st.param.f64 [func_retval0], %fd7;
+; NORMAL-NEXT: ret;
+ %r = frem ninf double %a, %b
+ ret double %r
+}
+
+define float @frem_f32_imm1(float %a) {
+; FAST-LABEL: frem_f32_imm1(
+; FAST: {
+; FAST-NEXT: .reg .f32 %f<6>;
+; FAST-EMPTY:
+; FAST-NEXT: // %bb.0:
+; FAST-NEXT: ld.param.f32 %f1, [frem_f32_imm1_param_0];
+; FAST-NEXT: div.rn.f32 %f2, %f1, 0f40E00000;
+; FAST-NEXT: cvt.rzi.f32.f32 %f3, %f2;
+; FAST-NEXT: mul.f32 %f4, %f3, 0f40E00000;
+; FAST-NEXT: sub.f32 %f5, %f1, %f4;
+; FAST-NEXT: st.param.f32 [func_retval0], %f5;
+; FAST-NEXT: ret;
+;
+; NORMAL-LABEL: frem_f32_imm1(
+; NORMAL: {
+; NORMAL-NEXT: .reg .pred %p<2>;
+; NORMAL-NEXT: .reg .f32 %f<7>;
+; NORMAL-EMPTY:
+; NORMAL-NEXT: // %bb.0:
+; NORMAL-NEXT: ld.param.f32 %f1, [frem_f32_imm1_param_0];
+; NORMAL-NEXT: div.rn.f32 %f2, %f1, 0f40E00000;
+; NORMAL-NEXT: cvt.rzi.f32.f32 %f3, %f2;
+; NORMAL-NEXT: mul.f32 %f4, %f3, 0f40E00000;
+; NORMAL-NEXT: sub.f32 %f5, %f1, %f4;
+; NORMAL-NEXT: testp.infinite.f32 %p1, 0f40E00000;
+; NORMAL-NEXT: selp.f32 %f6, %f1, %f5, %p1;
+; NORMAL-NEXT: st.param.f32 [func_retval0], %f6;
+; NORMAL-NEXT: ret;
+ %r = frem float %a, 7.0
+ ret float %r
+}
+
+define float @frem_f32_imm2(float %a) {
+; FAST-LABEL: frem_f32_imm2(
+; FAST: {
+; FAST-NEXT: .reg .f32 %f<7>;
+; FAST-EMPTY:
+; FAST-NEXT: // %bb.0:
+; FAST-NEXT: ld.param.f32 %f1, [frem_f32_imm2_param_0];
+; FAST-NEXT: mov.b32 %f2, 0f40E00000;
+; FAST-NEXT: div.rn.f32 %f3, %f2, %f1;
+; FAST-NEXT: cvt.rzi.f32.f32 %f4, %f3;
+; FAST-NEXT: mul.f32 %f5, %f4, %f1;
+; FAST-NEXT: sub.f32 %f6, %f2, %f5;
+; FAST-NEXT: st.param.f32 [func_retval0], %f6;
+; FAST-NEXT: ret;
+;
+; NORMAL-LABEL: frem_f32_imm2(
+; NORMAL: {
+; NORMAL-NEXT: .reg .pred %p<2>;
+; NORMAL-NEXT: .reg .f32 %f<8>;
+; NORMAL-EMPTY:
+; NORMAL-NEXT: // %bb.0:
+; NORMAL-NEXT: ld.param.f32 %f1, [frem_f32_imm2_param_0];
+; NORMAL-NEXT: mov.b32 %f2, 0f40E00000;
+; NORMAL-NEXT: div.rn.f32 %f3, %f2, %f1;
+; NORMAL-NEXT: cvt.rzi.f32.f32 %f4, %f3;
+; NORMAL-NEXT: mul.f32 %f5, %f4, %f1;
+; NORMAL-NEXT: sub.f32 %f6, %f2, %f5;
+; NORMAL-NEXT: testp.infinite.f32 %p1, %f1;
+; NORMAL-NEXT: selp.f32 %f7, %f2, %f6, %p1;
+; NORMAL-NEXT: st.param.f32 [func_retval0], %f7;
+; NORMAL-NEXT: ret;
+ %r = frem float 7.0, %a
+ ret float %r
+}
>From 44e112a2cc6d2b256f3c5f7307798bf3967d3a58 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 26 Mar 2025 16:29:44 +0000
Subject: [PATCH 2/3] [NVPTX] Use fast-math flags when lowering sin, cos, frem
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 32 +++++
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 126 +++---------------
llvm/test/CodeGen/NVPTX/f16-instructions.ll | 8 +-
llvm/test/CodeGen/NVPTX/f16x2-instructions.ll | 8 +-
llvm/test/CodeGen/NVPTX/frem.ll | 108 +++++++--------
5 files changed, 108 insertions(+), 174 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 06e221777b7ea..8a4b83365ae84 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -18,6 +18,7 @@
#include "NVPTXTargetMachine.h"
#include "NVPTXTargetObjectFile.h"
#include "NVPTXUtilities.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
@@ -932,6 +933,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(Op, MVT::bf16, Promote);
AddPromotedToType(Op, MVT::bf16, MVT::f32);
}
+ setOperationAction(ISD::FREM, {MVT::f32, MVT::f64}, Custom);
setOperationAction(ISD::FABS, {MVT::f32, MVT::f64}, Legal);
if (STI.getPTXVersion() >= 65) {
@@ -2819,6 +2821,34 @@ static SDValue lowerROT(SDValue Op, SelectionDAG &DAG) {
SDLoc(Op), Opcode, DAG);
}
+static SDValue lowerFREM(SDValue Op, SelectionDAG &DAG,
+ bool AllowUnsafeFPMath) {
+ // Expand (frem x, y) into (fsub x, (fmul (ftrunc (fdiv x, y)), y)), i.e.
+ // "poor man's fmod()". LLVM's frem returns x when y is infinite, so that
+ // case is patched up by the select below unless infinities can be ignored.
+ SDLoc DL(Op);
+ SDValue X = Op->getOperand(0);
+ SDValue Y = Op->getOperand(1);
+ EVT Ty = Op.getValueType();
+
+ SDValue Div = DAG.getNode(ISD::FDIV, DL, Ty, X, Y);
+ SDValue Trunc = DAG.getNode(ISD::FTRUNC, DL, Ty, Div);
+ SDValue Mul =
+ DAG.getNode(ISD::FMUL, DL, Ty, Trunc, Y, SDNodeFlags::AllowContract);
+ SDValue Sub =
+ DAG.getNode(ISD::FSUB, DL, Ty, X, Mul, SDNodeFlags::AllowContract);
+
+ if (AllowUnsafeFPMath || Op->getFlags().hasNoInfs())
+ return Sub;
+
+ // If Y is +/-inf, return X: |Y| == +inf selects X, otherwise Sub.
+ SDValue AbsY = DAG.getNode(ISD::FABS, DL, Ty, Y);
+ SDValue Inf =
+ DAG.getConstantFP(APFloat::getInf(Ty.getFltSemantics()), DL, Ty);
+ SDValue IsInf = DAG.getSetCC(DL, MVT::i1, AbsY, Inf, ISD::SETEQ);
+ return DAG.getSelect(DL, Ty, IsInf, X, Sub);
+}
+
SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
switch (Op.getOpcode()) {
@@ -2913,6 +2943,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::CTPOP:
case ISD::CTLZ:
return lowerCTLZCTPOP(Op, DAG);
+ case ISD::FREM:
+ return lowerFREM(Op, DAG, allowUnsafeFPMath(DAG.getMachineFunction()));
default:
llvm_unreachable("Custom lowering not defined for operation");
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 1786503a6dd4e..fe9bb621b481c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -150,9 +150,6 @@ def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
def doMulWide : Predicate<"doMulWide">;
-def allowUnsafeFPMath : Predicate<"allowUnsafeFPMath()">;
-def noUnsafeFPMath : Predicate<"!allowUnsafeFPMath()">;
-
def do_DIVF32_APPROX : Predicate<"getDivF32Level()==0">;
def do_DIVF32_FULL : Predicate<"getDivF32Level()==1">;
@@ -211,6 +208,12 @@ class ValueToRegClass<ValueType T> {
// Some Common Instruction Class Templates
//===----------------------------------------------------------------------===//
+class OneUse1<SDPatternOperator operator>
+ : PatFrag<(ops node:$A), (operator node:$A), [{ return N->hasOneUse(); }]>;
+
+class fpimm_pos_inf<ValueType vt>
+ : FPImmLeaf<vt, [{ return Imm.isPosInfinity(); }]>;
+
// Utility class to wrap up information about a register and DAG type for more
// convenient iteration and parameterization
class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm> {
@@ -442,7 +445,7 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
class BinOpAllowsFMA<SDPatternOperator operator>
: PatFrag<(ops node:$A, node:$B),
(operator node:$A, node:$B), [{
- return allowFMA() || N->getFlags().hasAllowContract();;
+ return allowFMA() || N->getFlags().hasAllowContract();
}]>;
multiclass F3_fma_component<string op_str, SDNode op_node> {
@@ -693,10 +696,7 @@ let hasSideEffects = false in {
defm CVT_to_tf32_rz_relu_satf : CVT_TO_TF32<"rz.relu.satfinite", [hasPTX<86>, hasSM<100>]>;
}
-def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{
- return N->hasOneUse();
-}]>;
-
+def fpround_oneuse : OneUse1<fpround>;
def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse f32:$lo)),
(bf16 (fpround_oneuse f32:$hi)))),
(CVT_bf16x2_f32 $hi, $lo, CvtRN)>,
@@ -786,18 +786,14 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
// Test Instructions
//-----------------------------------
+def fabs_oneuse : OneUse1<fabs>;
+
def TESTINF_f32r : NVPTXInst<(outs Int1Regs:$p), (ins Float32Regs:$a),
"testp.infinite.f32 \t$p, $a;",
- []>;
-def TESTINF_f32i : NVPTXInst<(outs Int1Regs:$p), (ins f32imm:$a),
- "testp.infinite.f32 \t$p, $a;",
- []>;
+ [(set i1:$p, (seteq (fabs_oneuse f32:$a), fpimm_pos_inf<f32>))]>;
def TESTINF_f64r : NVPTXInst<(outs Int1Regs:$p), (ins Float64Regs:$a),
"testp.infinite.f64 \t$p, $a;",
- []>;
-def TESTINF_f64i : NVPTXInst<(outs Int1Regs:$p), (ins f64imm:$a),
- "testp.infinite.f64 \t$p, $a;",
- []>;
+ [(set i1:$p, (seteq (fabs_oneuse f64:$a), fpimm_pos_inf<f64>))]>;
//-----------------------------------
// Integer Arithmetic
@@ -1362,99 +1358,19 @@ defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
// sin/cos
+
+class UnaryOpAllowsApproxFn<SDPatternOperator operator>
+ : PatFrag<(ops node:$A),
+ (operator node:$A), [{
+ return allowUnsafeFPMath() || N->getFlags().hasApproximateFuncs();
+}]>;
+
def SINF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
"sin.approx.f32 \t$dst, $src;",
- [(set f32:$dst, (fsin f32:$src))]>,
- Requires<[allowUnsafeFPMath]>;
+ [(set f32:$dst, (UnaryOpAllowsApproxFn<fsin> f32:$src))]>;
def COSF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
"cos.approx.f32 \t$dst, $src;",
- [(set f32:$dst, (fcos f32:$src))]>,
- Requires<[allowUnsafeFPMath]>;
-
-// Lower (frem x, y) into (sub x, (mul (ftrunc (div x, y)) y)),
-// i.e. "poor man's fmod()". When y is infinite, x is returned. This matches the
-// semantics of LLVM's frem.
-
-// frem - f32 FTZ
-def : Pat<(frem f32:$x, f32:$y),
- (FSUBf32rr_ftz $x, (FMULf32rr_ftz (CVT_f32_f32
- (FDIV32rr_prec_ftz $x, $y), CvtRZI_FTZ),
- $y))>,
- Requires<[doF32FTZ, allowUnsafeFPMath]>;
-def : Pat<(frem f32:$x, fpimm:$y),
- (FSUBf32rr_ftz $x, (FMULf32ri_ftz (CVT_f32_f32
- (FDIV32ri_prec_ftz $x, fpimm:$y), CvtRZI_FTZ),
- fpimm:$y))>,
- Requires<[doF32FTZ, allowUnsafeFPMath]>;
-
-def : Pat<(frem f32:$x, f32:$y),
- (SELP_f32rr $x,
- (FSUBf32rr_ftz $x, (FMULf32rr_ftz (CVT_f32_f32
- (FDIV32rr_prec_ftz $x, $y), CvtRZI_FTZ),
- $y)),
- (TESTINF_f32r $y))>,
- Requires<[doF32FTZ, noUnsafeFPMath]>;
-def : Pat<(frem f32:$x, fpimm:$y),
- (SELP_f32rr $x,
- (FSUBf32rr_ftz $x, (FMULf32ri_ftz (CVT_f32_f32
- (FDIV32ri_prec_ftz $x, fpimm:$y), CvtRZI_FTZ),
- fpimm:$y)),
- (TESTINF_f32i fpimm:$y))>,
- Requires<[doF32FTZ, noUnsafeFPMath]>;
-
-// frem - f32
-def : Pat<(frem f32:$x, f32:$y),
- (FSUBf32rr $x, (FMULf32rr (CVT_f32_f32
- (FDIV32rr_prec $x, $y), CvtRZI),
- $y))>,
- Requires<[allowUnsafeFPMath]>;
-def : Pat<(frem f32:$x, fpimm:$y),
- (FSUBf32rr $x, (FMULf32ri (CVT_f32_f32
- (FDIV32ri_prec $x, fpimm:$y), CvtRZI),
- fpimm:$y))>,
- Requires<[allowUnsafeFPMath]>;
-
-def : Pat<(frem f32:$x, f32:$y),
- (SELP_f32rr $x,
- (FSUBf32rr $x, (FMULf32rr (CVT_f32_f32
- (FDIV32rr_prec $x, $y), CvtRZI),
- $y)),
- (TESTINF_f32r Float32Regs:$y))>,
- Requires<[noUnsafeFPMath]>;
-def : Pat<(frem f32:$x, fpimm:$y),
- (SELP_f32rr $x,
- (FSUBf32rr $x, (FMULf32ri (CVT_f32_f32
- (FDIV32ri_prec $x, fpimm:$y), CvtRZI),
- fpimm:$y)),
- (TESTINF_f32i fpimm:$y))>,
- Requires<[noUnsafeFPMath]>;
-
-// frem - f64
-def : Pat<(frem f64:$x, f64:$y),
- (FSUBf64rr $x, (FMULf64rr (CVT_f64_f64
- (FDIV64rr $x, $y), CvtRZI),
- $y))>,
- Requires<[allowUnsafeFPMath]>;
-def : Pat<(frem f64:$x, fpimm:$y),
- (FSUBf64rr $x, (FMULf64ri (CVT_f64_f64
- (FDIV64ri $x, fpimm:$y), CvtRZI),
- fpimm:$y))>,
- Requires<[allowUnsafeFPMath]>;
-
-def : Pat<(frem f64:$x, f64:$y),
- (SELP_f64rr $x,
- (FSUBf64rr $x, (FMULf64rr (CVT_f64_f64
- (FDIV64rr $x, $y), CvtRZI),
- $y)),
- (TESTINF_f64r Float64Regs:$y))>,
- Requires<[noUnsafeFPMath]>;
-def : Pat<(frem f64:$x, fpimm:$y),
- (SELP_f64rr $x,
- (FSUBf64rr $x, (FMULf64ri (CVT_f64_f64
- (FDIV64ri $x, fpimm:$y), CvtRZI),
- fpimm:$y)),
- (TESTINF_f64r $y))>,
- Requires<[noUnsafeFPMath]>;
+ [(set f32:$dst, (UnaryOpAllowsApproxFn<fcos> f32:$src))]>;
//-----------------------------------
// Bitwise operations
diff --git a/llvm/test/CodeGen/NVPTX/f16-instructions.ll b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
index 70d1167bbb6e2..b34dfc4e19766 100644
--- a/llvm/test/CodeGen/NVPTX/f16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
@@ -200,14 +200,14 @@ define half @test_fdiv(half %a, half %b) #0 {
; CHECK-NOFTZ-DAG: cvt.f32.f16 [[FB:%f[0-9]+]], [[B]];
; CHECK-NOFTZ-NEXT: div.rn.f32 [[D:%f[0-9]+]], [[FA]], [[FB]];
; CHECK-NOFTZ-NEXT: cvt.rzi.f32.f32 [[DI:%f[0-9]+]], [[D]];
-; CHECK-NOFTZ-NEXT: mul.f32 [[RI:%f[0-9]+]], [[DI]], [[FB]];
-; CHECK-NOFTZ-NEXT: sub.f32 [[RF:%f[0-9]+]], [[FA]], [[RI]];
+; CHECK-NOFTZ-NEXT: neg.f32 [[DNEG:%f[0-9]+]], [[DI]];
+; CHECK-NOFTZ-NEXT: fma.rn.f32 [[RF:%f[0-9]+]], [[DNEG]], [[FB]], [[FA]];
; CHECK-F16-FTZ-DAG: cvt.ftz.f32.f16 [[FA:%f[0-9]+]], [[A]];
; CHECK-F16-FTZ-DAG: cvt.ftz.f32.f16 [[FB:%f[0-9]+]], [[B]];
; CHECK-F16-FTZ-NEXT: div.rn.ftz.f32 [[D:%f[0-9]+]], [[FA]], [[FB]];
; CHECK-F16-FTZ-NEXT: cvt.rzi.ftz.f32.f32 [[DI:%f[0-9]+]], [[D]];
-; CHECK-F16-FTZ-NEXT: mul.ftz.f32 [[RI:%f[0-9]+]], [[DI]], [[FB]];
-; CHECK-F16-FTZ-NEXT: sub.ftz.f32 [[RF:%f[0-9]+]], [[FA]], [[RI]];
+; CHECK-F16-FTZ-NEXT: neg.ftz.f32 [[DNEG:%f[0-9]+]], [[DI]];
+; CHECK-F16-FTZ-NEXT: fma.rn.ftz.f32 [[RF:%f[0-9]+]], [[DNEG]], [[FB]], [[FA]];
; CHECK-NEXT: testp.infinite.f32 [[ISBINF:%p[0-9]+]], [[FB]];
; CHECK-NEXT: selp.f32 [[RESULT:%f[0-9]+]], [[FA]], [[RF]], [[ISBINF]];
; CHECK-NEXT: cvt.rn.f16.f32 [[R:%rs[0-9]+]], [[RESULT]];
diff --git a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
index 539e810c83cbd..d78b68dc501da 100644
--- a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll
@@ -362,8 +362,8 @@ define <2 x half> @test_frem(<2 x half> %a, <2 x half> %b) #0 {
; CHECK-NEXT: cvt.f32.f16 %f2, %rs4;
; CHECK-NEXT: div.rn.f32 %f3, %f2, %f1;
; CHECK-NEXT: cvt.rzi.f32.f32 %f4, %f3;
-; CHECK-NEXT: mul.f32 %f5, %f4, %f1;
-; CHECK-NEXT: sub.f32 %f6, %f2, %f5;
+; CHECK-NEXT: neg.f32 %f5, %f4;
+; CHECK-NEXT: fma.rn.f32 %f6, %f5, %f1, %f2;
; CHECK-NEXT: testp.infinite.f32 %p1, %f1;
; CHECK-NEXT: selp.f32 %f7, %f2, %f6, %p1;
; CHECK-NEXT: cvt.rn.f16.f32 %rs5, %f7;
@@ -371,8 +371,8 @@ define <2 x half> @test_frem(<2 x half> %a, <2 x half> %b) #0 {
; CHECK-NEXT: cvt.f32.f16 %f9, %rs3;
; CHECK-NEXT: div.rn.f32 %f10, %f9, %f8;
; CHECK-NEXT: cvt.rzi.f32.f32 %f11, %f10;
-; CHECK-NEXT: mul.f32 %f12, %f11, %f8;
-; CHECK-NEXT: sub.f32 %f13, %f9, %f12;
+; CHECK-NEXT: neg.f32 %f12, %f11;
+; CHECK-NEXT: fma.rn.f32 %f13, %f12, %f8, %f9;
; CHECK-NEXT: testp.infinite.f32 %p2, %f8;
; CHECK-NEXT: selp.f32 %f14, %f9, %f13, %p2;
; CHECK-NEXT: cvt.rn.f16.f32 %rs6, %f14;
diff --git a/llvm/test/CodeGen/NVPTX/frem.ll b/llvm/test/CodeGen/NVPTX/frem.ll
index f40bb1303820d..89e1f2e4c0055 100644
--- a/llvm/test/CodeGen/NVPTX/frem.ll
+++ b/llvm/test/CodeGen/NVPTX/frem.ll
@@ -16,10 +16,10 @@ define half @frem_f16(half %a, half %b) {
; FAST-NEXT: ld.param.b16 %rs2, [frem_f16_param_1];
; FAST-NEXT: cvt.f32.f16 %f1, %rs2;
; FAST-NEXT: cvt.f32.f16 %f2, %rs1;
-; FAST-NEXT: div.rn.f32 %f3, %f2, %f1;
+; FAST-NEXT: div.approx.f32 %f3, %f2, %f1;
; FAST-NEXT: cvt.rzi.f32.f32 %f4, %f3;
-; FAST-NEXT: mul.f32 %f5, %f4, %f1;
-; FAST-NEXT: sub.f32 %f6, %f2, %f5;
+; FAST-NEXT: neg.f32 %f5, %f4;
+; FAST-NEXT: fma.rn.f32 %f6, %f5, %f1, %f2;
; FAST-NEXT: cvt.rn.f16.f32 %rs3, %f6;
; FAST-NEXT: st.param.b16 [func_retval0], %rs3;
; FAST-NEXT: ret;
@@ -37,8 +37,8 @@ define half @frem_f16(half %a, half %b) {
; NORMAL-NEXT: cvt.f32.f16 %f2, %rs1;
; NORMAL-NEXT: div.rn.f32 %f3, %f2, %f1;
; NORMAL-NEXT: cvt.rzi.f32.f32 %f4, %f3;
-; NORMAL-NEXT: mul.f32 %f5, %f4, %f1;
-; NORMAL-NEXT: sub.f32 %f6, %f2, %f5;
+; NORMAL-NEXT: neg.f32 %f5, %f4;
+; NORMAL-NEXT: fma.rn.f32 %f6, %f5, %f1, %f2;
; NORMAL-NEXT: testp.infinite.f32 %p1, %f1;
; NORMAL-NEXT: selp.f32 %f7, %f2, %f6, %p1;
; NORMAL-NEXT: cvt.rn.f16.f32 %rs3, %f7;
@@ -56,10 +56,10 @@ define float @frem_f32(float %a, float %b) {
; FAST-NEXT: // %bb.0:
; FAST-NEXT: ld.param.f32 %f1, [frem_f32_param_0];
; FAST-NEXT: ld.param.f32 %f2, [frem_f32_param_1];
-; FAST-NEXT: div.rn.f32 %f3, %f1, %f2;
+; FAST-NEXT: div.approx.f32 %f3, %f1, %f2;
; FAST-NEXT: cvt.rzi.f32.f32 %f4, %f3;
-; FAST-NEXT: mul.f32 %f5, %f4, %f2;
-; FAST-NEXT: sub.f32 %f6, %f1, %f5;
+; FAST-NEXT: neg.f32 %f5, %f4;
+; FAST-NEXT: fma.rn.f32 %f6, %f5, %f2, %f1;
; FAST-NEXT: st.param.f32 [func_retval0], %f6;
; FAST-NEXT: ret;
;
@@ -73,8 +73,8 @@ define float @frem_f32(float %a, float %b) {
; NORMAL-NEXT: ld.param.f32 %f2, [frem_f32_param_1];
; NORMAL-NEXT: div.rn.f32 %f3, %f1, %f2;
; NORMAL-NEXT: cvt.rzi.f32.f32 %f4, %f3;
-; NORMAL-NEXT: mul.f32 %f5, %f4, %f2;
-; NORMAL-NEXT: sub.f32 %f6, %f1, %f5;
+; NORMAL-NEXT: neg.f32 %f5, %f4;
+; NORMAL-NEXT: fma.rn.f32 %f6, %f5, %f2, %f1;
; NORMAL-NEXT: testp.infinite.f32 %p1, %f2;
; NORMAL-NEXT: selp.f32 %f7, %f1, %f6, %p1;
; NORMAL-NEXT: st.param.f32 [func_retval0], %f7;
@@ -93,8 +93,8 @@ define double @frem_f64(double %a, double %b) {
; FAST-NEXT: ld.param.f64 %fd2, [frem_f64_param_1];
; FAST-NEXT: div.rn.f64 %fd3, %fd1, %fd2;
; FAST-NEXT: cvt.rzi.f64.f64 %fd4, %fd3;
-; FAST-NEXT: mul.f64 %fd5, %fd4, %fd2;
-; FAST-NEXT: sub.f64 %fd6, %fd1, %fd5;
+; FAST-NEXT: neg.f64 %fd5, %fd4;
+; FAST-NEXT: fma.rn.f64 %fd6, %fd5, %fd2, %fd1;
; FAST-NEXT: st.param.f64 [func_retval0], %fd6;
; FAST-NEXT: ret;
;
@@ -108,8 +108,8 @@ define double @frem_f64(double %a, double %b) {
; NORMAL-NEXT: ld.param.f64 %fd2, [frem_f64_param_1];
; NORMAL-NEXT: div.rn.f64 %fd3, %fd1, %fd2;
; NORMAL-NEXT: cvt.rzi.f64.f64 %fd4, %fd3;
-; NORMAL-NEXT: mul.f64 %fd5, %fd4, %fd2;
-; NORMAL-NEXT: sub.f64 %fd6, %fd1, %fd5;
+; NORMAL-NEXT: neg.f64 %fd5, %fd4;
+; NORMAL-NEXT: fma.rn.f64 %fd6, %fd5, %fd2, %fd1;
; NORMAL-NEXT: testp.infinite.f64 %p1, %fd2;
; NORMAL-NEXT: selp.f64 %fd7, %fd1, %fd6, %p1;
; NORMAL-NEXT: st.param.f64 [func_retval0], %fd7;
@@ -129,19 +129,18 @@ define half @frem_f16_ninf(half %a, half %b) {
; FAST-NEXT: ld.param.b16 %rs2, [frem_f16_ninf_param_1];
; FAST-NEXT: cvt.f32.f16 %f1, %rs2;
; FAST-NEXT: cvt.f32.f16 %f2, %rs1;
-; FAST-NEXT: div.rn.f32 %f3, %f2, %f1;
+; FAST-NEXT: div.approx.f32 %f3, %f2, %f1;
; FAST-NEXT: cvt.rzi.f32.f32 %f4, %f3;
-; FAST-NEXT: mul.f32 %f5, %f4, %f1;
-; FAST-NEXT: sub.f32 %f6, %f2, %f5;
+; FAST-NEXT: neg.f32 %f5, %f4;
+; FAST-NEXT: fma.rn.f32 %f6, %f5, %f1, %f2;
; FAST-NEXT: cvt.rn.f16.f32 %rs3, %f6;
; FAST-NEXT: st.param.b16 [func_retval0], %rs3;
; FAST-NEXT: ret;
;
; NORMAL-LABEL: frem_f16_ninf(
; NORMAL: {
-; NORMAL-NEXT: .reg .pred %p<2>;
; NORMAL-NEXT: .reg .b16 %rs<4>;
-; NORMAL-NEXT: .reg .f32 %f<8>;
+; NORMAL-NEXT: .reg .f32 %f<7>;
; NORMAL-EMPTY:
; NORMAL-NEXT: // %bb.0:
; NORMAL-NEXT: ld.param.b16 %rs1, [frem_f16_ninf_param_0];
@@ -150,11 +149,9 @@ define half @frem_f16_ninf(half %a, half %b) {
; NORMAL-NEXT: cvt.f32.f16 %f2, %rs1;
; NORMAL-NEXT: div.rn.f32 %f3, %f2, %f1;
; NORMAL-NEXT: cvt.rzi.f32.f32 %f4, %f3;
-; NORMAL-NEXT: mul.f32 %f5, %f4, %f1;
-; NORMAL-NEXT: sub.f32 %f6, %f2, %f5;
-; NORMAL-NEXT: testp.infinite.f32 %p1, %f1;
-; NORMAL-NEXT: selp.f32 %f7, %f2, %f6, %p1;
-; NORMAL-NEXT: cvt.rn.f16.f32 %rs3, %f7;
+; NORMAL-NEXT: neg.f32 %f5, %f4;
+; NORMAL-NEXT: fma.rn.f32 %f6, %f5, %f1, %f2;
+; NORMAL-NEXT: cvt.rn.f16.f32 %rs3, %f6;
; NORMAL-NEXT: st.param.b16 [func_retval0], %rs3;
; NORMAL-NEXT: ret;
%r = frem ninf half %a, %b
@@ -169,28 +166,25 @@ define float @frem_f32_ninf(float %a, float %b) {
; FAST-NEXT: // %bb.0:
; FAST-NEXT: ld.param.f32 %f1, [frem_f32_ninf_param_0];
; FAST-NEXT: ld.param.f32 %f2, [frem_f32_ninf_param_1];
-; FAST-NEXT: div.rn.f32 %f3, %f1, %f2;
+; FAST-NEXT: div.approx.f32 %f3, %f1, %f2;
; FAST-NEXT: cvt.rzi.f32.f32 %f4, %f3;
-; FAST-NEXT: mul.f32 %f5, %f4, %f2;
-; FAST-NEXT: sub.f32 %f6, %f1, %f5;
+; FAST-NEXT: neg.f32 %f5, %f4;
+; FAST-NEXT: fma.rn.f32 %f6, %f5, %f2, %f1;
; FAST-NEXT: st.param.f32 [func_retval0], %f6;
; FAST-NEXT: ret;
;
; NORMAL-LABEL: frem_f32_ninf(
; NORMAL: {
-; NORMAL-NEXT: .reg .pred %p<2>;
-; NORMAL-NEXT: .reg .f32 %f<8>;
+; NORMAL-NEXT: .reg .f32 %f<7>;
; NORMAL-EMPTY:
; NORMAL-NEXT: // %bb.0:
; NORMAL-NEXT: ld.param.f32 %f1, [frem_f32_ninf_param_0];
; NORMAL-NEXT: ld.param.f32 %f2, [frem_f32_ninf_param_1];
; NORMAL-NEXT: div.rn.f32 %f3, %f1, %f2;
; NORMAL-NEXT: cvt.rzi.f32.f32 %f4, %f3;
-; NORMAL-NEXT: mul.f32 %f5, %f4, %f2;
-; NORMAL-NEXT: sub.f32 %f6, %f1, %f5;
-; NORMAL-NEXT: testp.infinite.f32 %p1, %f2;
-; NORMAL-NEXT: selp.f32 %f7, %f1, %f6, %p1;
-; NORMAL-NEXT: st.param.f32 [func_retval0], %f7;
+; NORMAL-NEXT: neg.f32 %f5, %f4;
+; NORMAL-NEXT: fma.rn.f32 %f6, %f5, %f2, %f1;
+; NORMAL-NEXT: st.param.f32 [func_retval0], %f6;
; NORMAL-NEXT: ret;
%r = frem ninf float %a, %b
ret float %r
@@ -206,26 +200,23 @@ define double @frem_f64_ninf(double %a, double %b) {
; FAST-NEXT: ld.param.f64 %fd2, [frem_f64_ninf_param_1];
; FAST-NEXT: div.rn.f64 %fd3, %fd1, %fd2;
; FAST-NEXT: cvt.rzi.f64.f64 %fd4, %fd3;
-; FAST-NEXT: mul.f64 %fd5, %fd4, %fd2;
-; FAST-NEXT: sub.f64 %fd6, %fd1, %fd5;
+; FAST-NEXT: neg.f64 %fd5, %fd4;
+; FAST-NEXT: fma.rn.f64 %fd6, %fd5, %fd2, %fd1;
; FAST-NEXT: st.param.f64 [func_retval0], %fd6;
; FAST-NEXT: ret;
;
; NORMAL-LABEL: frem_f64_ninf(
; NORMAL: {
-; NORMAL-NEXT: .reg .pred %p<2>;
-; NORMAL-NEXT: .reg .f64 %fd<8>;
+; NORMAL-NEXT: .reg .f64 %fd<7>;
; NORMAL-EMPTY:
; NORMAL-NEXT: // %bb.0:
; NORMAL-NEXT: ld.param.f64 %fd1, [frem_f64_ninf_param_0];
; NORMAL-NEXT: ld.param.f64 %fd2, [frem_f64_ninf_param_1];
; NORMAL-NEXT: div.rn.f64 %fd3, %fd1, %fd2;
; NORMAL-NEXT: cvt.rzi.f64.f64 %fd4, %fd3;
-; NORMAL-NEXT: mul.f64 %fd5, %fd4, %fd2;
-; NORMAL-NEXT: sub.f64 %fd6, %fd1, %fd5;
-; NORMAL-NEXT: testp.infinite.f64 %p1, %fd2;
-; NORMAL-NEXT: selp.f64 %fd7, %fd1, %fd6, %p1;
-; NORMAL-NEXT: st.param.f64 [func_retval0], %fd7;
+; NORMAL-NEXT: neg.f64 %fd5, %fd4;
+; NORMAL-NEXT: fma.rn.f64 %fd6, %fd5, %fd2, %fd1;
+; NORMAL-NEXT: st.param.f64 [func_retval0], %fd6;
; NORMAL-NEXT: ret;
%r = frem ninf double %a, %b
ret double %r
@@ -234,31 +225,26 @@ define double @frem_f64_ninf(double %a, double %b) {
define float @frem_f32_imm1(float %a) {
; FAST-LABEL: frem_f32_imm1(
; FAST: {
-; FAST-NEXT: .reg .f32 %f<6>;
+; FAST-NEXT: .reg .f32 %f<5>;
; FAST-EMPTY:
; FAST-NEXT: // %bb.0:
; FAST-NEXT: ld.param.f32 %f1, [frem_f32_imm1_param_0];
-; FAST-NEXT: div.rn.f32 %f2, %f1, 0f40E00000;
+; FAST-NEXT: mul.f32 %f2, %f1, 0f3E124925;
; FAST-NEXT: cvt.rzi.f32.f32 %f3, %f2;
-; FAST-NEXT: mul.f32 %f4, %f3, 0f40E00000;
-; FAST-NEXT: sub.f32 %f5, %f1, %f4;
-; FAST-NEXT: st.param.f32 [func_retval0], %f5;
+; FAST-NEXT: fma.rn.f32 %f4, %f3, 0fC0E00000, %f1;
+; FAST-NEXT: st.param.f32 [func_retval0], %f4;
; FAST-NEXT: ret;
;
; NORMAL-LABEL: frem_f32_imm1(
; NORMAL: {
-; NORMAL-NEXT: .reg .pred %p<2>;
-; NORMAL-NEXT: .reg .f32 %f<7>;
+; NORMAL-NEXT: .reg .f32 %f<5>;
; NORMAL-EMPTY:
; NORMAL-NEXT: // %bb.0:
; NORMAL-NEXT: ld.param.f32 %f1, [frem_f32_imm1_param_0];
; NORMAL-NEXT: div.rn.f32 %f2, %f1, 0f40E00000;
; NORMAL-NEXT: cvt.rzi.f32.f32 %f3, %f2;
-; NORMAL-NEXT: mul.f32 %f4, %f3, 0f40E00000;
-; NORMAL-NEXT: sub.f32 %f5, %f1, %f4;
-; NORMAL-NEXT: testp.infinite.f32 %p1, 0f40E00000;
-; NORMAL-NEXT: selp.f32 %f6, %f1, %f5, %p1;
-; NORMAL-NEXT: st.param.f32 [func_retval0], %f6;
+; NORMAL-NEXT: fma.rn.f32 %f4, %f3, 0fC0E00000, %f1;
+; NORMAL-NEXT: st.param.f32 [func_retval0], %f4;
; NORMAL-NEXT: ret;
%r = frem float %a, 7.0
ret float %r
@@ -272,10 +258,10 @@ define float @frem_f32_imm2(float %a) {
; FAST-NEXT: // %bb.0:
; FAST-NEXT: ld.param.f32 %f1, [frem_f32_imm2_param_0];
; FAST-NEXT: mov.b32 %f2, 0f40E00000;
-; FAST-NEXT: div.rn.f32 %f3, %f2, %f1;
+; FAST-NEXT: div.approx.f32 %f3, %f2, %f1;
; FAST-NEXT: cvt.rzi.f32.f32 %f4, %f3;
-; FAST-NEXT: mul.f32 %f5, %f4, %f1;
-; FAST-NEXT: sub.f32 %f6, %f2, %f5;
+; FAST-NEXT: neg.f32 %f5, %f4;
+; FAST-NEXT: fma.rn.f32 %f6, %f5, %f1, 0f40E00000;
; FAST-NEXT: st.param.f32 [func_retval0], %f6;
; FAST-NEXT: ret;
;
@@ -289,10 +275,10 @@ define float @frem_f32_imm2(float %a) {
; NORMAL-NEXT: mov.b32 %f2, 0f40E00000;
; NORMAL-NEXT: div.rn.f32 %f3, %f2, %f1;
; NORMAL-NEXT: cvt.rzi.f32.f32 %f4, %f3;
-; NORMAL-NEXT: mul.f32 %f5, %f4, %f1;
-; NORMAL-NEXT: sub.f32 %f6, %f2, %f5;
+; NORMAL-NEXT: neg.f32 %f5, %f4;
+; NORMAL-NEXT: fma.rn.f32 %f6, %f5, %f1, 0f40E00000;
; NORMAL-NEXT: testp.infinite.f32 %p1, %f1;
-; NORMAL-NEXT: selp.f32 %f7, %f2, %f6, %p1;
+; NORMAL-NEXT: selp.f32 %f7, 0f40E00000, %f6, %p1;
; NORMAL-NEXT: st.param.f32 [func_retval0], %f7;
; NORMAL-NEXT: ret;
%r = frem float 7.0, %a
>From f4307e0794e04577ca65ede63e77576763ef2df1 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 26 Mar 2025 23:26:15 +0000
Subject: [PATCH 3/3] address comments
---
llvm/test/CodeGen/NVPTX/frem.ll | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/llvm/test/CodeGen/NVPTX/frem.ll b/llvm/test/CodeGen/NVPTX/frem.ll
index 89e1f2e4c0055..73debfbfcdf49 100644
--- a/llvm/test/CodeGen/NVPTX/frem.ll
+++ b/llvm/test/CodeGen/NVPTX/frem.ll
@@ -1,6 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc < %s --enable-unsafe-fp-math | FileCheck %s --check-prefixes=FAST
-; RUN: llc < %s | FileCheck %s --check-prefixes=NORMAL
+; RUN: llc < %s --enable-unsafe-fp-math -mcpu=sm_60 | FileCheck %s --check-prefixes=FAST
+; RUN: llc < %s -mcpu=sm_60 | FileCheck %s --check-prefixes=NORMAL
target triple = "nvptx64-unknown-cuda"
More information about the llvm-commits
mailing list