[llvm] [LLVM][NVPTX]Add BF16 vector instruction and fix lowering rules (PR #69415)
Han Shen via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 20 01:20:41 PDT 2023
https://github.com/shenh10 updated https://github.com/llvm/llvm-project/pull/69415
>From b7585d8f33fe1385f25ff1b9b8d5285458d522ce Mon Sep 17 00:00:00 2001
From: shenhan03 <shenhan03 at kuaishou.com>
Date: Wed, 18 Oct 2023 11:40:12 +0800
Subject: [PATCH] Fix BF16 instruction lowering rules and provide most test
---
llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp | 4 +-
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 15 +-
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h | 1 +
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 63 ++-
llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 1 +
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 37 +-
.../NVPTX/bf16x2-instructions-approx.ll | 42 ++
.../test/CodeGen/NVPTX/bf16x2-instructions.ll | 516 ++++++++++++++++++
8 files changed, 660 insertions(+), 19 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
create mode 100644 llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index f19beea3a3ed8b7..ae7e3bc4384bdc0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1623,8 +1623,10 @@ SDValue SelectionDAGLegalize::ExpandFCOPYSIGN(SDNode *Node) const {
SignMask);
// If FABS is legal transform FCOPYSIGN(x, y) => sign(x) ? -FABS(x) : FABS(X)
+ // We don't do this for bf16 since the other path needs fewer instructions.
EVT FloatVT = Mag.getValueType();
- if (TLI.isOperationLegalOrCustom(ISD::FABS, FloatVT) &&
+ if (FloatVT != MVT::bf16 &&
+ TLI.isOperationLegalOrCustom(ISD::FABS, FloatVT) &&
TLI.isOperationLegalOrCustom(ISD::FNEG, FloatVT)) {
SDValue AbsValue = DAG.getNode(ISD::FABS, DL, FloatVT, Mag);
SDValue NegValue = DAG.getNode(ISD::FNEG, DL, FloatVT, AbsValue);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 68391cdb6ff172b..894a8636f45851e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -105,7 +105,9 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
case NVPTXISD::SETP_F16X2:
SelectSETP_F16X2(N);
return;
-
+ case NVPTXISD::SETP_BF16X2:
+ SelectSETP_BF16X2(N);
+ return;
case NVPTXISD::LoadV2:
case NVPTXISD::LoadV4:
if (tryLoadVector(N))
@@ -608,6 +610,17 @@ bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
return true;
}
+// Select an NVPTXISD::SETP_BF16X2 node into the SETP_bf16x2rr machine
+// instruction, which produces two i1 results. Always reports success.
+bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
+  SDLoc Loc(N);
+  // Operand 2 carries the condition code; translate it to the PTX compare
+  // mode and wrap it as an i32 target constant for the instruction.
+  const auto &CondCode = *cast<CondCodeSDNode>(N->getOperand(2));
+  SDValue CmpModeOp = CurDAG->getTargetConstant(
+      getPTXCmpMode(CondCode, useF32FTZ()), Loc, MVT::i32);
+  SDNode *Compare =
+      CurDAG->getMachineNode(NVPTX::SETP_bf16x2rr, Loc, MVT::i1, MVT::i1,
+                             N->getOperand(0), N->getOperand(1), CmpModeOp);
+  ReplaceNode(N, Compare);
+  return true;
+}
+
// Find all instances of extract_vector_elt that use this v2f16 vector
// and coalesce them into a scattering move instruction.
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 06922331f5e2059..84c8432047ca314 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -74,6 +74,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool tryBFE(SDNode *N);
bool tryConstantFP(SDNode *N);
bool SelectSETP_F16X2(SDNode *N);
+ bool SelectSETP_BF16X2(SDNode *N);
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 617009334dd201c..0d6096f0c587b59 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -436,6 +436,15 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
case ISD::FADD:
case ISD::FMUL:
case ISD::FSUB:
+ case ISD::SELECT:
+ case ISD::SELECT_CC:
+ case ISD::SETCC:
+ case ISD::FEXP2:
+ case ISD::FCEIL:
+ case ISD::FFLOOR:
+ case ISD::FNEARBYINT:
+ case ISD::FRINT:
+ case ISD::FTRUNC:
IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
break;
}
@@ -488,8 +497,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2bf16, Expand);
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2bf16, Expand);
- setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
+ setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
+ if (getOperationAction(ISD::SETCC, MVT::bf16) == Promote)
+ AddPromotedToType(ISD::SETCC, MVT::bf16, MVT::f32);
// Conversion to/from i16/i16x2 is always legal.
setOperationAction(ISD::BUILD_VECTOR, MVT::v2i16, Custom);
@@ -719,9 +730,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
setFP16OperationAction(Op, MVT::f16, Legal, Promote);
setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
- setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
// bf16 must be promoted to f32.
+ setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
if (getOperationAction(Op, MVT::bf16) == Promote)
AddPromotedToType(Op, MVT::bf16, MVT::f32);
}
@@ -741,21 +752,23 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// These map to conversion instructions for scalar FP types.
for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
ISD::FROUNDEVEN, ISD::FTRUNC}) {
- setOperationAction(Op, MVT::bf16, Legal);
setOperationAction(Op, MVT::f16, Legal);
setOperationAction(Op, MVT::f32, Legal);
setOperationAction(Op, MVT::f64, Legal);
setOperationAction(Op, MVT::v2f16, Expand);
setOperationAction(Op, MVT::v2bf16, Expand);
+ setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
+ if (getOperationAction(Op, MVT::bf16) == Promote)
+ AddPromotedToType(Op, MVT::bf16, MVT::f32);
}
setOperationAction(ISD::FROUND, MVT::f16, Promote);
setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
- setOperationAction(ISD::FROUND, MVT::bf16, Promote);
setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
setOperationAction(ISD::FROUND, MVT::f32, Custom);
setOperationAction(ISD::FROUND, MVT::f64, Custom);
-
+ setOperationAction(ISD::FROUND, MVT::bf16, Promote);
+ AddPromotedToType(ISD::FROUND, MVT::bf16, MVT::f32);
// 'Expand' implements FCOPYSIGN without calling an external library.
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
@@ -769,14 +782,26 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// promoted to f32. v2f16 is expanded to f16, which is then promoted
// to f32.
for (const auto &Op :
- {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FABS}) {
+ {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS}) {
setOperationAction(Op, MVT::f16, Promote);
- setOperationAction(Op, MVT::bf16, Promote);
setOperationAction(Op, MVT::f32, Legal);
setOperationAction(Op, MVT::f64, Legal);
setOperationAction(Op, MVT::v2f16, Expand);
setOperationAction(Op, MVT::v2bf16, Expand);
+ setOperationAction(Op, MVT::bf16, Promote);
+ AddPromotedToType(Op, MVT::bf16, MVT::f32);
+ }
+ for (const auto &Op : {ISD::FABS}) {
+ setOperationAction(Op, MVT::f16, Promote);
+ setOperationAction(Op, MVT::f32, Legal);
+ setOperationAction(Op, MVT::f64, Legal);
+ setOperationAction(Op, MVT::v2f16, Expand);
+ setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
+ setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
+ if (getOperationAction(Op, MVT::bf16) == Promote)
+ AddPromotedToType(Op, MVT::bf16, MVT::f32);
}
+
// max.f16, max.f16x2 and max.NaN are supported on sm_80+.
auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) {
bool IsAtLeastSm80 = STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
@@ -784,11 +809,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
};
for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Promote), Promote);
- setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
setOperationAction(Op, MVT::f32, Legal);
setOperationAction(Op, MVT::f64, Legal);
setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
+ setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
+ if (getOperationAction(Op, MVT::bf16) == Promote)
+ AddPromotedToType(Op, MVT::bf16, MVT::f32);
}
for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand);
@@ -920,6 +947,8 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "NVPTXISD::PRMT";
case NVPTXISD::SETP_F16X2:
return "NVPTXISD::SETP_F16X2";
+ case NVPTXISD::SETP_BF16X2:
+ return "NVPTXISD::SETP_BF16X2";
case NVPTXISD::Dummy:
return "NVPTXISD::Dummy";
case NVPTXISD::MUL_WIDE_SIGNED:
@@ -5382,12 +5411,17 @@ static SDValue PerformSHLCombine(SDNode *N,
}
static SDValue PerformSETCCCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI) {
+ TargetLowering::DAGCombinerInfo &DCI,
+ unsigned int SmVersion) {
EVT CCType = N->getValueType(0);
SDValue A = N->getOperand(0);
SDValue B = N->getOperand(1);
- if (CCType != MVT::v2i1 || A.getValueType() != MVT::v2f16)
+ EVT AType = A.getValueType();
+ if (!(CCType == MVT::v2i1 && (AType == MVT::v2f16 || AType == MVT::v2bf16)))
+ return SDValue();
+
+ if (A.getValueType() == MVT::v2bf16 && SmVersion < 90)
return SDValue();
SDLoc DL(N);
@@ -5395,9 +5429,10 @@ static SDValue PerformSETCCCombine(SDNode *N,
// convert back to v2i1. The returned result will be scalarized by
// the legalizer, but the comparison will remain a single vector
// instruction.
- SDValue CCNode = DCI.DAG.getNode(NVPTXISD::SETP_F16X2, DL,
- DCI.DAG.getVTList(MVT::i1, MVT::i1),
- {A, B, N->getOperand(2)});
+ SDValue CCNode = DCI.DAG.getNode(
+ A.getValueType() == MVT::v2f16 ? NVPTXISD::SETP_F16X2
+ : NVPTXISD::SETP_BF16X2,
+ DL, DCI.DAG.getVTList(MVT::i1, MVT::i1), {A, B, N->getOperand(2)});
return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, CCType, CCNode.getValue(0),
CCNode.getValue(1));
}
@@ -5536,7 +5571,7 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SREM:
return PerformREMCombine(N, DCI, OptLevel);
case ISD::SETCC:
- return PerformSETCCCombine(N, DCI);
+ return PerformSETCCCombine(N, DCI, STI.getSmVersion());
case ISD::LOAD:
return PerformLOADCombine(N, DCI);
case NVPTXISD::StoreRetval:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index f6932db2aeb0b9e..54e34dedc6675e8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -57,6 +57,7 @@ enum NodeType : unsigned {
MUL_WIDE_UNSIGNED,
IMAD,
SETP_F16X2,
+ SETP_BF16X2,
BFE,
BFI,
PRMT,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index b0b96b94a125752..7c04d0451f7a185 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -555,6 +555,34 @@ multiclass F2<string OpcStr, SDNode OpNode> {
[(set Float32Regs:$dst, (OpNode Float32Regs:$a))]>;
}
+// Unary FP instructions on the 16-bit types: scalar bf16/f16 live in
+// Int16Regs, the packed two-element forms live in Int32Regs.
+// NOTE(review): the .ftz f16 defs come before the plain f16 defs, presumably
+// so they are preferred when doF32FTZ holds -- confirm against ISel ordering.
+multiclass F2_Support_Half<string OpcStr, SDNode OpNode> {
+  def bf16 : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a),
+                       !strconcat(OpcStr, ".bf16 \t$dst, $a;"),
+                       [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a)))]>,
+             Requires<[hasSM<80>, hasPTX<70>]>;
+  // PTX names the packed type .bf16x2; v2bf16 is only the LLVM value type.
+  def bf16x2 : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a),
+                         !strconcat(OpcStr, ".bf16x2 \t$dst, $a;"),
+                         [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a)))]>,
+               Requires<[hasSM<80>, hasPTX<70>]>;
+  def f16_ftz : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a),
+                          !strconcat(OpcStr, ".ftz.f16 \t$dst, $a;"),
+                          [(set Int16Regs:$dst, (OpNode (f16 Int16Regs:$a)))]>,
+                Requires<[hasSM<53>, hasPTX<65>, doF32FTZ]>;
+  def f16x2_ftz : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a),
+                            !strconcat(OpcStr, ".ftz.f16x2 \t$dst, $a;"),
+                            [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a)))]>,
+                  Requires<[hasSM<53>, hasPTX<65>, doF32FTZ]>;
+  def f16 : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a),
+                      !strconcat(OpcStr, ".f16 \t$dst, $a;"),
+                      [(set Int16Regs:$dst, (OpNode (f16 Int16Regs:$a)))]>,
+            Requires<[hasSM<53>, hasPTX<65>]>;
+  def f16x2 : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a),
+                        !strconcat(OpcStr, ".f16x2 \t$dst, $a;"),
+                        [(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a)))]>,
+              Requires<[hasSM<53>, hasPTX<65>]>;
+}
+
//===----------------------------------------------------------------------===//
// NVPTX Instructions.
//===----------------------------------------------------------------------===//
@@ -1139,6 +1167,9 @@ defm FMAXNAN : F3<"max.NaN", fmaximum>;
defm FABS : F2<"abs", fabs>;
defm FNEG : F2<"neg", fneg>;
+defm FABS_H: F2_Support_Half<"abs", fabs>;
+defm FNEG_H: F2_Support_Half<"neg", fneg>;
+
defm FSQRT : F2<"sqrt.rn", fsqrt>;
//
@@ -1936,14 +1967,14 @@ def SETP_bf16rr :
NVPTXInst<(outs Int1Regs:$dst),
(ins Int16Regs:$a, Int16Regs:$b, CmpMode:$cmp),
"setp${cmp:base}${cmp:ftz}.bf16 \t$dst, $a, $b;",
- []>, Requires<[hasBF16Math]>;
+ []>, Requires<[hasBF16Math, hasPTX<78>, hasSM<90>]>;
def SETP_bf16x2rr :
NVPTXInst<(outs Int1Regs:$p, Int1Regs:$q),
(ins Int32Regs:$a, Int32Regs:$b, CmpMode:$cmp),
"setp${cmp:base}${cmp:ftz}.bf16x2 \t$p|$q, $a, $b;",
[]>,
- Requires<[hasBF16Math]>;
+ Requires<[hasBF16Math, hasPTX<78>, hasSM<90>]>;
// FIXME: This doesn't appear to be correct. The "set" mnemonic has the form
@@ -1974,7 +2005,7 @@ defm SET_b64 : SET<"b64", Int64Regs, i64imm>;
defm SET_s64 : SET<"s64", Int64Regs, i64imm>;
defm SET_u64 : SET<"u64", Int64Regs, i64imm>;
defm SET_f16 : SET<"f16", Int16Regs, f16imm>;
-defm SET_bf16 : SET<"bf16", Int16Regs, bf16imm>;
+defm SET_bf16 : SET<"bf16", Int16Regs, bf16imm>, Requires<[hasPTX<78>, hasSM<90>]>;
defm SET_f32 : SET<"f32", Float32Regs, f32imm>;
defm SET_f64 : SET<"f64", Float64Regs, f64imm>;
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
new file mode 100644
index 000000000000000..dde8555c35af4c8
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
@@ -0,0 +1,42 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 --enable-unsafe-fp-math | FileCheck --check-prefixes=CHECK %s
+; RUN: %if ptxas-11.8 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 --enable-unsafe-fp-math | %ptxas-verify -arch=sm_80 %}
+
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+
+declare <2 x bfloat> @llvm.sin.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.cos.f16(<2 x bfloat> %a) #0
+
+; CHECK-LABEL: test_sin(
+; CHECK: ld.param.b32 [[A:%r[0-9]+]], [test_sin_param_0];
+; CHECK: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]
+; CHECK-DAG: cvt.f32.bf16 [[AF0:%f[0-9]+]], [[A0]];
+; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]];
+; CHECK-DAG: sin.approx.f32 [[RF0:%f[0-9]+]], [[AF0]];
+; CHECK-DAG: sin.approx.f32 [[RF1:%f[0-9]+]], [[AF1]];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
+; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 {
+ %r = call <2 x bfloat> @llvm.sin.f16(<2 x bfloat> %a)
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_cos(
+; CHECK: ld.param.b32 [[A:%r[0-9]+]], [test_cos_param_0];
+; CHECK: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]
+; CHECK-DAG: cvt.f32.bf16 [[AF0:%f[0-9]+]], [[A0]];
+; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]];
+; CHECK-DAG: cos.approx.f32 [[RF0:%f[0-9]+]], [[AF0]];
+; CHECK-DAG: cos.approx.f32 [[RF1:%f[0-9]+]], [[AF1]];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
+; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define <2 x bfloat> @test_cos(<2 x bfloat> %a) #0 #1 {
+ %r = call <2 x bfloat> @llvm.cos.f16(<2 x bfloat> %a)
+ ret <2 x bfloat> %r
+}
+
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
new file mode 100644
index 000000000000000..38a5cf281da9dff
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
@@ -0,0 +1,516 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 | FileCheck --check-prefixes=CHECK,SM80 %s
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx78 | FileCheck --check-prefixes=CHECK,SM90 %s
+; RUN: %if ptxas-11.8 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 | %ptxas-verify -arch=sm_80 %}
+; RUN: %if ptxas-11.8 %{ llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx78 | %ptxas-verify -arch=sm_90 %}
+
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+
+; CHECK-LABEL: test_ret_const(
+; CHECK: mov.b32 [[T:%r[0-9]+]], 1073758080;
+; CHECK: st.param.b32 [func_retval0+0], [[T]];
+; CHECK-NEXT: ret;
+
+define <2 x bfloat> @test_ret_const() #0 {
+ ret <2 x bfloat> <bfloat 1.0, bfloat 2.0>
+}
+
+; Check that we can lower fadd with immediate arguments.
+; CHECK-LABEL: test_fadd_imm_0(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_fadd_imm_0_param_0];
+;
+; SM90-DAG: mov.b32 [[I:%r[0-9]+]], 1073758080;
+; SM90-DAG: add.rn.bf16x2 [[R:%r[0-9]+]], [[A]], [[I]];
+;
+; SM80-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]
+; SM80-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]]
+; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]
+; SM80-DAG: add.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], 0f3F800000;
+; SM80-DAG: add.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], 0f40000000;
+; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]
+; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]
+; SM80-DAG: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+;
+; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define <2 x bfloat> @test_fadd_imm_0(<2 x bfloat> %a) #0 {
+ %r = fadd <2 x bfloat> <bfloat 1.0, bfloat 2.0>, %a
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fadd_imm_1(
+; CHECK: ld.param.b16 [[A:%rs[0-9]+]], [test_fadd_imm_1_param_0];
+; SM90: mov.b16 [[B:%rs[0-9]+]], 0x3F80;
+; SM90: add.rn.bf16 [[R:%rs[0-9]+]], [[A]], [[B]];
+
+; SM80-DAG: cvt.f32.bf16 [[FA:%f[0-9]+]], [[A]];
+; SM80: add.rn.f32 [[FR:%f[0-9]+]], [[FA]], 0f3F800000;
+; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[FR]];
+
+; CHECK: st.param.b16 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define bfloat @test_fadd_imm_1(bfloat %a) #0 {
+ %r = fadd bfloat %a, 1.0
+ ret bfloat %r
+}
+
+; CHECK-LABEL: test_fsubx2(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_fsubx2_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_fsubx2_param_1];
+; SM90: sub.rn.bf16x2 [[R:%r[0-9]+]], [[A]], [[B]];
+
+; SM80-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
+; SM80-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]];
+; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
+; SM80-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]];
+; SM80-DAG: cvt.f32.bf16 [[FB0:%f[0-9]+]], [[B0]];
+; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
+; SM80-DAG: sub.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
+; SM80-DAG: sub.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
+; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]};
+
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+
+define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = fsub <2 x bfloat> %a, %b
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fmulx2(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_fmulx2_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_fmulx2_param_1];
+; SM90: mul.rn.bf16x2 [[R:%r[0-9]+]], [[A]], [[B]];
+
+; SM80-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
+; SM80-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]];
+; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
+; SM80-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]];
+; SM80-DAG: cvt.f32.bf16 [[FB0:%f[0-9]+]], [[B0]];
+; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
+; SM80-DAG: mul.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
+; SM80-DAG: mul.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
+; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]};
+
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+
+define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = fmul <2 x bfloat> %a, %b
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fdiv(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_fdiv_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_fdiv_param_1];
+; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]
+; CHECK-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]]
+; CHECK-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]];
+; CHECK-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
+; CHECK-DAG: cvt.f32.bf16 [[FB0:%f[0-9]+]], [[B0]];
+; CHECK-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
+; CHECK-DAG: div.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
+; CHECK-DAG: div.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
+; CHECK-NEXT: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = fdiv <2 x bfloat> %a, %b
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fneg(
+; CHECK-DAG: ld.param.u32 [[A:%r[0-9]+]], [test_fneg_param_0];
+
+; CHECK-DAG: xor.b32 [[IHH0:%r[0-9]+]], [[A]], -2147450880;
+; CHECK-NEXT: st.param.b32 [func_retval0+0], [[IHH0]];
+; CHECK-NEXT: ret;
+define <2 x bfloat> @test_fneg(<2 x bfloat> %a) #0 {
+ %r = fneg <2 x bfloat> %a
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: .func test_ldst_v2bf16(
+; CHECK-DAG: ld.param.u64 %[[A:rd[0-9]+]], [test_ldst_v2bf16_param_0];
+; CHECK-DAG: ld.param.u64 %[[B:rd[0-9]+]], [test_ldst_v2bf16_param_1];
+; CHECK-DAG: ld.b32 [[E:%r[0-9]+]], [%[[A]]]
+; CHECK-DAG: st.b32 [%[[B]]], [[E]];
+; CHECK: ret;
+define void @test_ldst_v2bf16(<2 x bfloat>* %a, <2 x bfloat>* %b) {
+ %t1 = load <2 x bfloat>, <2 x bfloat>* %a
+ store <2 x bfloat> %t1, <2 x bfloat>* %b, align 16
+ ret void
+}
+
+; CHECK-LABEL: .func test_ldst_v3bf16(
+; CHECK-DAG: ld.param.u64 %[[A:rd[0-9]+]], [test_ldst_v3bf16_param_0];
+; CHECK-DAG: ld.param.u64 %[[B:rd[0-9]+]], [test_ldst_v3bf16_param_1];
+; -- v3 is inconvenient to capture as it's lowered as ld.b64 + fair
+; number of bitshifting instructions that may change at llvm's whim.
+; So we only verify that we only issue correct number of writes using
+; correct offset, but not the values we write.
+; CHECK-DAG: ld.u64
+; CHECK-DAG: st.u32 [%[[B]]],
+; CHECK-DAG: st.b16 [%[[B]]+4],
+; CHECK: ret;
+define void @test_ldst_v3bf16(<3 x bfloat>* %a, <3 x bfloat>* %b) {
+ %t1 = load <3 x bfloat>, <3 x bfloat>* %a
+ store <3 x bfloat> %t1, <3 x bfloat>* %b, align 16
+ ret void
+}
+
+declare <2 x bfloat> @test_callee(<2 x bfloat> %a, <2 x bfloat> %b) #0
+
+; CHECK-LABEL: test_call(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_call_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_call_param_1];
+; CHECK: {
+; CHECK-DAG: .param .align 4 .b8 param0[4];
+; CHECK-DAG: .param .align 4 .b8 param1[4];
+; CHECK-DAG: st.param.b32 [param0+0], [[A]];
+; CHECK-DAG: st.param.b32 [param1+0], [[B]];
+; CHECK-DAG: .param .align 4 .b8 retval0[4];
+; CHECK: call.uni (retval0),
+; CHECK-NEXT: test_callee,
+; CHECK: );
+; CHECK-NEXT: ld.param.b32 [[R:%r[0-9]+]], [retval0+0];
+; CHECK-NEXT: }
+; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define <2 x bfloat> @test_call(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = call <2 x bfloat> @test_callee(<2 x bfloat> %a, <2 x bfloat> %b)
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_select(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_select_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_select_param_1];
+; CHECK-DAG: ld.param.u8 [[C:%rs[0-9]+]], [test_select_param_2]
+; CHECK-DAG: setp.eq.b16 [[PRED:%p[0-9]+]], %rs{{.*}}, 1;
+; CHECK-NEXT: selp.b32 [[R:%r[0-9]+]], [[A]], [[B]], [[PRED]];
+; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define <2 x bfloat> @test_select(<2 x bfloat> %a, <2 x bfloat> %b, i1 zeroext %c) #0 {
+ %r = select i1 %c, <2 x bfloat> %a, <2 x bfloat> %b
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_select_cc(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_select_cc_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_select_cc_param_1];
+; CHECK-DAG: ld.param.b32 [[C:%r[0-9]+]], [test_select_cc_param_2];
+; CHECK-DAG: ld.param.b32 [[D:%r[0-9]+]], [test_select_cc_param_3];
+;
+; SM90: setp.neu.bf16x2 [[P0:%p[0-9]+]]|[[P1:%p[0-9]+]], [[C]], [[D]]
+;
+; SM80-DAG: mov.b32 {[[C0:%rs[0-9]+]], [[C1:%rs[0-9]+]]}, [[C]]
+; SM80-DAG: mov.b32 {[[D0:%rs[0-9]+]], [[D1:%rs[0-9]+]]}, [[D]]
+; SM80-DAG: cvt.f32.bf16 [[DF0:%f[0-9]+]], [[D0]];
+; SM80-DAG: cvt.f32.bf16 [[CF0:%f[0-9]+]], [[C0]];
+; SM80-DAG: cvt.f32.bf16 [[DF1:%f[0-9]+]], [[D1]];
+; SM80-DAG: cvt.f32.bf16 [[CF1:%f[0-9]+]], [[C1]];
+; SM80-DAG: setp.neu.f32 [[P0:%p[0-9]+]], [[CF0]], [[DF0]]
+; SM80-DAG: setp.neu.f32 [[P1:%p[0-9]+]], [[CF1]], [[DF1]]
+;
+; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]
+; CHECK-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]]
+; CHECK-DAG: selp.b16 [[R0:%rs[0-9]+]], [[A0]], [[B0]], [[P0]];
+; CHECK-DAG: selp.b16 [[R1:%rs[0-9]+]], [[A1]], [[B1]], [[P1]];
+; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define <2 x bfloat> @test_select_cc(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c, <2 x bfloat> %d) #0 {
+ %cc = fcmp une <2 x bfloat> %c, %d
+ %r = select <2 x i1> %cc, <2 x bfloat> %a, <2 x bfloat> %b
+ ret <2 x bfloat> %r
+}
+
+
+; CHECK-LABEL: test_select_cc_f32_bf16(
+; CHECK-DAG: ld.param.v2.f32 {[[A0:%f[0-9]+]], [[A1:%f[0-9]+]]}, [test_select_cc_f32_bf16_param_0];
+; CHECK-DAG: ld.param.b32 [[C:%r[0-9]+]], [test_select_cc_f32_bf16_param_2];
+; CHECK-DAG: ld.param.b32 [[D:%r[0-9]+]], [test_select_cc_f32_bf16_param_3];
+; SM90: setp.neu.bf16x2 [[P0:%p[0-9]+]]|[[P1:%p[0-9]+]], [[C]], [[D]]
+; CHECK-DAG: ld.param.v2.f32 {[[B0:%f[0-9]+]], [[B1:%f[0-9]+]]}, [test_select_cc_f32_bf16_param_1];
+
+; SM80-DAG: mov.b32 {[[C0:%rs[0-9]+]], [[C1:%rs[0-9]+]]}, [[C]]
+; SM80-DAG: mov.b32 {[[D0:%rs[0-9]+]], [[D1:%rs[0-9]+]]}, [[D]]
+; SM80-DAG: cvt.f32.bf16 [[DF0:%f[0-9]+]], [[D0]];
+; SM80-DAG: cvt.f32.bf16 [[CF0:%f[0-9]+]], [[C0]];
+; SM80-DAG: cvt.f32.bf16 [[DF1:%f[0-9]+]], [[D1]];
+; SM80-DAG: cvt.f32.bf16 [[CF1:%f[0-9]+]], [[C1]];
+; SM80-DAG: setp.neu.f32 [[P0:%p[0-9]+]], [[CF0]], [[DF0]]
+; SM80-DAG: setp.neu.f32 [[P1:%p[0-9]+]], [[CF1]], [[DF1]]
+;
+; CHECK-DAG: selp.f32 [[R0:%f[0-9]+]], [[A0]], [[B0]], [[P0]];
+; CHECK-DAG: selp.f32 [[R1:%f[0-9]+]], [[A1]], [[B1]], [[P1]];
+; CHECK-NEXT: st.param.v2.f32 [func_retval0+0], {[[R0]], [[R1]]};
+; CHECK-NEXT: ret;
+define <2 x float> @test_select_cc_f32_bf16(<2 x float> %a, <2 x float> %b,
+ <2 x bfloat> %c, <2 x bfloat> %d) #0 {
+ %cc = fcmp une <2 x bfloat> %c, %d
+ %r = select <2 x i1> %cc, <2 x float> %a, <2 x float> %b
+ ret <2 x float> %r
+}
+
+; CHECK-LABEL: test_select_cc_bf16_f32(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_select_cc_bf16_f32_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_select_cc_bf16_f32_param_1];
+; CHECK-DAG: ld.param.v2.f32 {[[C0:%f[0-9]+]], [[C1:%f[0-9]+]]}, [test_select_cc_bf16_f32_param_2];
+; CHECK-DAG: ld.param.v2.f32 {[[D0:%f[0-9]+]], [[D1:%f[0-9]+]]}, [test_select_cc_bf16_f32_param_3];
+; CHECK-DAG: setp.neu.f32 [[P0:%p[0-9]+]], [[C0]], [[D0]]
+; CHECK-DAG: setp.neu.f32 [[P1:%p[0-9]+]], [[C1]], [[D1]]
+; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]
+; CHECK-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]]
+; CHECK-DAG: selp.b16 [[R0:%rs[0-9]+]], [[A0]], [[B0]], [[P0]];
+; CHECK-DAG: selp.b16 [[R1:%rs[0-9]+]], [[A1]], [[B1]], [[P1]];
+; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+define <2 x bfloat> @test_select_cc_bf16_f32(<2 x bfloat> %a, <2 x bfloat> %b,
+ <2 x float> %c, <2 x float> %d) #0 {
+ %cc = fcmp une <2 x float> %c, %d
+ %r = select <2 x i1> %cc, <2 x bfloat> %a, <2 x bfloat> %b
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fptrunc_2xfloat(
+; CHECK: ld.param.v2.f32 {[[A0:%f[0-9]+]], [[A1:%f[0-9]+]]}, [test_fptrunc_2xfloat_param_0];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[A0]];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[A1]];
+; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define <2 x bfloat> @test_fptrunc_2xfloat(<2 x float> %a) #0 {
+ %r = fptrunc <2 x float> %a to <2 x bfloat>
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fpext_2xfloat(
+; CHECK: ld.param.b32 [[A:%r[0-9]+]], [test_fpext_2xfloat_param_0];
+; CHECK: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]
+; CHECK-DAG: cvt.f32.bf16 [[R0:%f[0-9]+]], [[A0]];
+; CHECK-DAG: cvt.f32.bf16 [[R1:%f[0-9]+]], [[A1]];
+; CHECK-NEXT: st.param.v2.f32 [func_retval0+0], {[[R0]], [[R1]]};
+; CHECK: ret;
+define <2 x float> @test_fpext_2xfloat(<2 x bfloat> %a) #0 {
+ %r = fpext <2 x bfloat> %a to <2 x float>
+ ret <2 x float> %r
+}
+
+; CHECK-LABEL: test_bitcast_2xbf16_to_2xi16(
+; CHECK: ld.param.u32 [[A:%r[0-9]+]], [test_bitcast_2xbf16_to_2xi16_param_0];
+; CHECK: st.param.b32 [func_retval0+0], [[A]]
+; CHECK: ret;
+define <2 x i16> @test_bitcast_2xbf16_to_2xi16(<2 x bfloat> %a) #0 {
+ %r = bitcast <2 x bfloat> %a to <2 x i16>
+ ret <2 x i16> %r
+}
+
+
+; CHECK-LABEL: test_bitcast_2xi16_to_2xbf16(
+; CHECK: ld.param.b32 [[R:%r[0-9]+]], [test_bitcast_2xi16_to_2xbf16_param_0];
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define <2 x bfloat> @test_bitcast_2xi16_to_2xbf16(<2 x i16> %a) #0 {
+ %r = bitcast <2 x i16> %a to <2 x bfloat>
+ ret <2 x bfloat> %r
+}
+
+declare <2 x bfloat> @llvm.sqrt.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.powi.f16(<2 x bfloat> %a, <2 x i32> %b) #0
+declare <2 x bfloat> @llvm.sin.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.cos.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.pow.f16(<2 x bfloat> %a, <2 x bfloat> %b) #0
+declare <2 x bfloat> @llvm.exp.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.exp2.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.log.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.log10.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.log2.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.fma.f16(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) #0
+declare <2 x bfloat> @llvm.fabs.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.minnum.f16(<2 x bfloat> %a, <2 x bfloat> %b) #0
+declare <2 x bfloat> @llvm.maxnum.f16(<2 x bfloat> %a, <2 x bfloat> %b) #0
+declare <2 x bfloat> @llvm.copysign.f16(<2 x bfloat> %a, <2 x bfloat> %b) #0
+declare <2 x bfloat> @llvm.floor.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.ceil.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.trunc.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.rint.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.nearbyint.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.round.f16(<2 x bfloat> %a) #0
+declare <2 x bfloat> @llvm.fmuladd.f16(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) #0
+
+
+; sqrt on <2 x bfloat>: the checks expect the packed value to be split into
+; two 16-bit lanes, each lane widened to f32, sqrt.rn.f32 applied per lane,
+; and the results narrowed back to bf16 and repacked into one 32-bit register.
+; CHECK-LABEL: test_sqrt(
+; CHECK: ld.param.b32 [[A:%r[0-9]+]], [test_sqrt_param_0];
+; CHECK: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]
+; CHECK-DAG: cvt.f32.bf16 [[AF0:%f[0-9]+]], [[A0]];
+; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]];
+; CHECK-DAG: sqrt.rn.f32 [[RF0:%f[0-9]+]], [[AF0]];
+; CHECK-DAG: sqrt.rn.f32 [[RF1:%f[0-9]+]], [[AF1]];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
+; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define <2 x bfloat> @test_sqrt(<2 x bfloat> %a) #0 {
+ %r = call <2 x bfloat> @llvm.sqrt.f16(<2 x bfloat> %a)
+ ret <2 x bfloat> %r
+}
+
+; fmuladd on <2 x bfloat>: the three packed operands are expected to feed a
+; single fma.rn.bf16x2, whose result is stored by the very next instruction.
+; CHECK-LABEL: test_fmuladd(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_fmuladd_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_fmuladd_param_1];
+; CHECK-DAG: ld.param.b32 [[C:%r[0-9]+]], [test_fmuladd_param_2];
+;
+; CHECK: fma.rn.bf16x2 [[RA:%r[0-9]+]], [[A]], [[B]], [[C]];
+; CHECK-NEXT: st.param.b32 [func_retval0+0], [[RA]];
+; CHECK: ret;
+define <2 x bfloat> @test_fmuladd(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c) #0 {
+ %r = call <2 x bfloat> @llvm.fmuladd.f16(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bfloat> %c)
+ ret <2 x bfloat> %r
+}
+
+; fabs on <2 x bfloat>: expected to be one 32-bit AND with 2147450879
+; (0x7FFF7FFF), i.e. the top bit of each 16-bit half is cleared while the
+; value stays packed.
+; CHECK-LABEL: test_fabs(
+; CHECK: ld.param.u32 [[A:%r[0-9]+]], [test_fabs_param_0];
+; CHECK: and.b32 [[R:%r[0-9]+]], [[A]], 2147450879;
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define <2 x bfloat> @test_fabs(<2 x bfloat> %a) #0 {
+ %r = call <2 x bfloat> @llvm.fabs.f16(<2 x bfloat> %a)
+ ret <2 x bfloat> %r
+}
+
+
+; minnum on <2 x bfloat>: both packed operands are expected to go directly
+; into a single min.bf16x2 without being split into lanes.
+; CHECK-LABEL: test_minnum(
+; CHECK-DAG: ld.param.b32 [[AF0:%r[0-9]+]], [test_minnum_param_0];
+; CHECK-DAG: ld.param.b32 [[BF0:%r[0-9]+]], [test_minnum_param_1];
+; CHECK-DAG: min.bf16x2 [[RF0:%r[0-9]+]], [[AF0]], [[BF0]];
+; CHECK: st.param.b32 [func_retval0+0], [[RF0]];
+; CHECK: ret;
+define <2 x bfloat> @test_minnum(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = call <2 x bfloat> @llvm.minnum.f16(<2 x bfloat> %a, <2 x bfloat> %b)
+ ret <2 x bfloat> %r
+}
+
+; maxnum on <2 x bfloat>: both packed operands are expected to go directly
+; into a single max.bf16x2 without being split into lanes.
+; CHECK-LABEL: test_maxnum(
+; CHECK-DAG: ld.param.b32 [[AF0:%r[0-9]+]], [test_maxnum_param_0];
+; CHECK-DAG: ld.param.b32 [[BF0:%r[0-9]+]], [test_maxnum_param_1];
+; CHECK-DAG: max.bf16x2 [[RF0:%r[0-9]+]], [[AF0]], [[BF0]];
+; CHECK: st.param.b32 [func_retval0+0], [[RF0]];
+; CHECK: ret;
+define <2 x bfloat> @test_maxnum(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = call <2 x bfloat> @llvm.maxnum.f16(<2 x bfloat> %a, <2 x bfloat> %b)
+ ret <2 x bfloat> %r
+}
+
+
+
+; floor on <2 x bfloat>, checked per lane after unpacking.
+; Under the SM90 prefix each lane is expected to use cvt.rmi.bf16.bf16
+; directly; under the SM80 prefix each lane is widened to f32, rounded with
+; cvt.rmi.f32.f32 and narrowed back with cvt.rn.bf16.f32.
+; CHECK-LABEL: test_floor(
+; CHECK: ld.param.b32 [[A:%r[0-9]+]], [test_floor_param_0];
+; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
+; SM90: cvt.rmi.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]];
+; SM90: cvt.rmi.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]];
+; SM80-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]];
+; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
+; SM80-DAG: cvt.rmi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]];
+; SM80-DAG: cvt.rmi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
+; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define <2 x bfloat> @test_floor(<2 x bfloat> %a) #0 {
+ %r = call <2 x bfloat> @llvm.floor.f16(<2 x bfloat> %a)
+ ret <2 x bfloat> %r
+}
+
+; ceil on <2 x bfloat>, checked per lane after unpacking.
+; Under the SM90 prefix each lane is expected to use cvt.rpi.bf16.bf16
+; directly; under the SM80 prefix each lane is widened to f32, rounded with
+; cvt.rpi.f32.f32 and narrowed back with cvt.rn.bf16.f32.
+; CHECK-LABEL: test_ceil(
+; CHECK: ld.param.b32 [[A:%r[0-9]+]], [test_ceil_param_0];
+; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
+; SM90: cvt.rpi.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]];
+; SM90: cvt.rpi.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]];
+; SM80-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]];
+; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
+; SM80-DAG: cvt.rpi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]];
+; SM80-DAG: cvt.rpi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
+; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define <2 x bfloat> @test_ceil(<2 x bfloat> %a) #0 {
+ %r = call <2 x bfloat> @llvm.ceil.f16(<2 x bfloat> %a)
+ ret <2 x bfloat> %r
+}
+
+; trunc on <2 x bfloat>: under the SM90 prefix each unpacked lane is expected
+; to use cvt.rzi.bf16.bf16, and the two results are repacked.
+; NOTE(review): unlike test_floor and test_ceil, this test has no lines with
+; the SM80 prefix, so R0 and R1 are bound only on the SM90 lines. If the file
+; is also run with an SM80 prefix, the final mov.b32 line would depend on
+; values captured by an earlier test - confirm against the RUN lines.
+; CHECK-LABEL: test_trunc(
+; CHECK: ld.param.b32 [[A:%r[0-9]+]], [test_trunc_param_0];
+; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
+; SM90: cvt.rzi.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]];
+; SM90: cvt.rzi.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]];
+; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define <2 x bfloat> @test_trunc(<2 x bfloat> %a) #0 {
+ %r = call <2 x bfloat> @llvm.trunc.f16(<2 x bfloat> %a)
+ ret <2 x bfloat> %r
+}
+
+; rint on <2 x bfloat>: under the SM90 prefix each unpacked lane is expected
+; to use cvt.rni.bf16.bf16, and the two results are repacked.
+; NOTE(review): unlike test_floor and test_ceil, this test has no lines with
+; the SM80 prefix, so R0 and R1 are bound only on the SM90 lines. If the file
+; is also run with an SM80 prefix, the final mov.b32 line would depend on
+; values captured by an earlier test - confirm against the RUN lines.
+; CHECK-LABEL: test_rint(
+; CHECK: ld.param.b32 [[A:%r[0-9]+]], [test_rint_param_0];
+; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
+; SM90: cvt.rni.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]];
+; SM90: cvt.rni.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]];
+; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define <2 x bfloat> @test_rint(<2 x bfloat> %a) #0 {
+ %r = call <2 x bfloat> @llvm.rint.f16(<2 x bfloat> %a)
+ ret <2 x bfloat> %r
+}
+
+; round on <2 x bfloat>: only the characteristic part of the expansion is
+; checked, not the full instruction sequence.
+; CHECK-LABEL: test_round(
+; CHECK: ld.param.b32 {{.*}}, [test_round_param_0];
+; Verify that round is built from the 32-bit sign mask (-2147483648 =
+; 0x80000000) OR-ed with 1056964608 (0x3F000000, the f32 encoding of 0.5).
+; The and/or pair is expected twice, presumably once per lane.
+; CHECK: and.b32 [[R1:%r[0-9]+]], {{.*}}, -2147483648;
+; CHECK: or.b32 {{.*}}, [[R1]], 1056964608;
+; CHECK: and.b32 [[R2:%r[0-9]+]], {{.*}}, -2147483648;
+; CHECK: or.b32 {{.*}}, [[R2]], 1056964608;
+; CHECK: st.param.b32 [func_retval0+0], {{.*}};
+; CHECK: ret;
+define <2 x bfloat> @test_round(<2 x bfloat> %a) #0 {
+ %r = call <2 x bfloat> @llvm.round.f16(<2 x bfloat> %a)
+ ret <2 x bfloat> %r
+}
+
+; copysign on <2 x bfloat>, expected as 16-bit integer bit manipulation per
+; lane: the lane of %a is masked with 32767 (0x7FFF, everything but the top
+; bit), the lane of %b is masked with -32768 (0x8000, the top bit only), and
+; the two are OR-ed together before the lanes are repacked.
+; CHECK-LABEL: test_copysign(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_copysign_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_copysign_param_1];
+; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]
+; CHECK-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]]
+; CHECK-DAG: and.b16 [[AX0:%rs[0-9]+]], [[A0]], 32767;
+; CHECK-DAG: and.b16 [[AX1:%rs[0-9]+]], [[A1]], 32767;
+; CHECK-DAG: and.b16 [[BX0:%rs[0-9]+]], [[B0]], -32768;
+; CHECK-DAG: and.b16 [[BX1:%rs[0-9]+]], [[B1]], -32768;
+; CHECK-DAG: or.b16 [[RS0:%rs[0-9]+]], [[AX0]], [[BX0]];
+; CHECK-DAG: or.b16 [[RS1:%rs[0-9]+]], [[AX1]], [[BX1]];
+; CHECK-DAG: mov.b32 [[R:%r[0-9]+]], {[[RS0]], [[RS1]]}
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define <2 x bfloat> @test_copysign(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = call <2 x bfloat> @llvm.copysign.f16(<2 x bfloat> %a, <2 x bfloat> %b)
+ ret <2 x bfloat> %r
+}
+
More information about the llvm-commits
mailing list