[clang] 250f2bb - adding bf16 support to NVPTX
Artem Belevich via cfe-commits
cfe-commits at lists.llvm.org
Wed Jun 28 11:57:30 PDT 2023
Author: root
Date: 2023-06-28T11:57:13-07:00
New Revision: 250f2bb2c6a9c288faeb821585e9394697c561d8
URL: https://github.com/llvm/llvm-project/commit/250f2bb2c6a9c288faeb821585e9394697c561d8
DIFF: https://github.com/llvm/llvm-project/commit/250f2bb2c6a9c288faeb821585e9394697c561d8.diff
LOG: adding bf16 support to NVPTX
Until now, bf16 support has been added to the PTX codegen piecemeal. This patch aims to complete the set of instructions and code paths required to support the bf16 data type.
Reviewed By: tra
Differential Revision: https://reviews.llvm.org/D144911
Co-authored-by: Artem Belevich <tra at google.com>
Added:
llvm/test/CodeGen/NVPTX/bf16-instructions.ll
llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70-autoupgrade.ll
llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72-autoupgrade.ll
Modified:
clang/include/clang/Basic/BuiltinsNVPTX.def
clang/test/CodeGen/builtins-nvptx.c
clang/test/CodeGenCUDA/bf16.cu
llvm/include/llvm/IR/IntrinsicsNVVM.td
llvm/lib/IR/AutoUpgrade.cpp
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
llvm/lib/Target/NVPTX/NVPTXMCExpr.h
llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
llvm/lib/Target/NVPTX/NVPTXSubtarget.h
llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
llvm/test/CodeGen/NVPTX/convert-sm80.ll
llvm/test/CodeGen/NVPTX/f16-instructions.ll
llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70.ll
llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72.ll
llvm/test/CodeGen/NVPTX/param-load-store.ll
Removed:
################################################################################
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def
index 3275d50a85a4b..f645ad25cbd86 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.def
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -173,16 +173,20 @@ TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmin_bf16, "UsUsUs", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "UsUsUs", "",
+TARGET_BUILTIN(__nvvm_fmin_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "yyy", "", AND(SM_86, PTX72))
+TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "yyy", "",
AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmin_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
+TARGET_BUILTIN(__nvvm_fmin_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "V2yV2yV2y", "",
AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
+TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "V2yV2yV2y", "",
AND(SM_86, PTX72))
BUILTIN(__nvvm_fmin_f, "fff", "")
BUILTIN(__nvvm_fmin_ftz_f, "fff", "")
@@ -215,16 +219,20 @@ TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmax_bf16, "UsUsUs", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "UsUsUs", "",
+TARGET_BUILTIN(__nvvm_fmax_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16, "yyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "yyy", "", AND(SM_86, PTX72))
+TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "yyy", "",
AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmax_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
+TARGET_BUILTIN(__nvvm_fmax_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "V2yV2yV2y", "",
AND(SM_86, PTX72))
-TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
+TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "V2yV2yV2y", "",
AND(SM_86, PTX72))
BUILTIN(__nvvm_fmax_f, "fff", "")
BUILTIN(__nvvm_fmax_ftz_f, "fff", "")
@@ -352,10 +360,10 @@ TARGET_BUILTIN(__nvvm_fma_rn_sat_f16x2, "V2hV2hV2hV2h", "", AND(SM_53, PTX42))
TARGET_BUILTIN(__nvvm_fma_rn_ftz_sat_f16x2, "V2hV2hV2hV2h", "", AND(SM_53, PTX42))
TARGET_BUILTIN(__nvvm_fma_rn_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fma_rn_ftz_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fma_rn_bf16, "UsUsUsUs", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16, "UsUsUsUs", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fma_rn_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70))
-TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fma_rn_bf16, "yyyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16, "yyyy", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fma_rn_bf16x2, "V2yV2yV2yV2y", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16x2, "V2yV2yV2yV2y", "", AND(SM_80, PTX70))
BUILTIN(__nvvm_fma_rn_ftz_f, "ffff", "")
BUILTIN(__nvvm_fma_rn_f, "ffff", "")
BUILTIN(__nvvm_fma_rz_ftz_f, "ffff", "")
@@ -543,20 +551,20 @@ BUILTIN(__nvvm_ull2d_rp, "dULLi", "")
BUILTIN(__nvvm_f2h_rn_ftz, "Usf", "")
BUILTIN(__nvvm_f2h_rn, "Usf", "")
-TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "ZUiff", "", AND(SM_80,PTX70))
-TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "ZUiff", "", AND(SM_80,PTX70))
-TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "ZUiff", "", AND(SM_80,PTX70))
-TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "ZUiff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "V2yff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "V2yff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "V2yff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "V2yff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rn, "V2hff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rn_relu, "V2hff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rz, "V2hff", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_ff2f16x2_rz_relu, "V2hff", "", AND(SM_80,PTX70))
-TARGET_BUILTIN(__nvvm_f2bf16_rn, "ZUsf", "", AND(SM_80,PTX70))
-TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "ZUsf", "", AND(SM_80,PTX70))
-TARGET_BUILTIN(__nvvm_f2bf16_rz, "ZUsf", "", AND(SM_80,PTX70))
-TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rn, "yf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "yf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rz, "yf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "yf", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))
@@ -1024,10 +1032,10 @@ TARGET_BUILTIN(__nvvm_cp_async_wait_all, "v", "", AND(SM_80,PTX70))
// bf16, bf16x2 abs, neg
-TARGET_BUILTIN(__nvvm_abs_bf16, "UsUs", "", AND(SM_80,PTX70))
-TARGET_BUILTIN(__nvvm_abs_bf16x2, "ZUiZUi", "", AND(SM_80,PTX70))
-TARGET_BUILTIN(__nvvm_neg_bf16, "UsUs", "", AND(SM_80,PTX70))
-TARGET_BUILTIN(__nvvm_neg_bf16x2, "ZUiZUi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_abs_bf16, "yy", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_abs_bf16x2, "V2yV2y", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_neg_bf16, "yy", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_neg_bf16x2, "V2yV2y", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_mapa, "v*v*i", "", AND(SM_90, PTX78))
TARGET_BUILTIN(__nvvm_mapa_shared_cluster, "v*3v*3i", "", AND(SM_90, PTX78))
diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c
index 75cb6835049c6..353f3ebb608c2 100644
--- a/clang/test/CodeGen/builtins-nvptx.c
+++ b/clang/test/CodeGen/builtins-nvptx.c
@@ -899,13 +899,13 @@ __device__ void nvvm_async_copy(__attribute__((address_space(3))) void* dst, __a
// CHECK-LABEL: nvvm_cvt_sm80
__device__ void nvvm_cvt_sm80() {
#if __CUDA_ARCH__ >= 800
- // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rn(1, 1);
- // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rn_relu(1, 1);
- // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rz(1, 1);
- // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2bf16x2_rz_relu(1, 1);
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn(float 1.000000e+00, float 1.000000e+00)
@@ -917,13 +917,13 @@ __device__ void nvvm_cvt_sm80() {
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
__nvvm_ff2f16x2_rz_relu(1, 1);
- // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
__nvvm_f2bf16_rn(1);
- // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
__nvvm_f2bf16_rn_relu(1);
- // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
__nvvm_f2bf16_rz(1);
- // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
__nvvm_f2bf16_rz_relu(1);
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00)
@@ -932,32 +932,32 @@ __device__ void nvvm_cvt_sm80() {
// CHECK: ret void
}
+// Operands for the bf16 min/max/fma tests below. The NaN operands are built
+// from raw bit patterns with __builtin_bit_cast: a C-style cast such as
+// (__bf16)0xFFC1 converts the integer's numeric value (65473) to bf16 instead
+// of reinterpreting its bits, which yields a large finite value, not a NaN.
+// NOTE(review): NAN32 is an integer bit pattern and its call sites write
+// (float)NAN32, which is likewise a numeric conversion (~2.14e9) rather than
+// a NaN; left unchanged here.
+#define NAN32 0x7FBFFFFF
+#define NAN16 __builtin_bit_cast(__bf16, (unsigned short)0x7FBF)
+#define BF16 ((__bf16)0.1f)
+#define BF16_2 ((__bf16)0.2f)
+#define NANBF16 __builtin_bit_cast(__bf16, (unsigned short)0xFFC1)
+#define BF16X2 {BF16, BF16}
+#define BF16X2_2 {BF16_2, BF16_2}
+#define NANBF16X2 {NANBF16, NANBF16}
+
// CHECK-LABEL: nvvm_abs_neg_bf16_bf16x2_sm80
__device__ void nvvm_abs_neg_bf16_bf16x2_sm80() {
#if __CUDA_ARCH__ >= 800
- // CHECK_PTX70_SM80: call i16 @llvm.nvvm.abs.bf16(i16 -1)
- __nvvm_abs_bf16(0xFFFF);
- // CHECK_PTX70_SM80: call i32 @llvm.nvvm.abs.bf16x2(i32 -1)
- __nvvm_abs_bf16x2(0xFFFFFFFF);
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.abs.bf16(bfloat 0xR3DCD)
+ __nvvm_abs_bf16(BF16);
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.abs.bf16x2(<2 x bfloat> <bfloat 0xR3DCD, bfloat 0xR3DCD>)
+ __nvvm_abs_bf16x2(BF16X2);
- // CHECK_PTX70_SM80: call i16 @llvm.nvvm.neg.bf16(i16 -1)
- __nvvm_neg_bf16(0xFFFF);
- // CHECK_PTX70_SM80: call i32 @llvm.nvvm.neg.bf16x2(i32 -1)
- __nvvm_neg_bf16x2(0xFFFFFFFF);
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.neg.bf16(bfloat 0xR3DCD)
+ __nvvm_neg_bf16(BF16);
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.neg.bf16x2(<2 x bfloat> <bfloat 0xR3DCD, bfloat 0xR3DCD>)
+ __nvvm_neg_bf16x2(BF16X2);
#endif
// CHECK: ret void
}
-#define NAN32 0x7FBFFFFF
-#define NAN16 0x7FBF
-#define BF16 0x1234
-#define BF16_2 0x4321
-#define NANBF16 0xFFC1
-#define BF16X2 0x12341234
-#define BF16X2_2 0x32343234
-#define NANBF16X2 0xFFC1FFC1
-
// CHECK-LABEL: nvvm_min_max_sm80
__device__ void nvvm_min_max_sm80() {
#if __CUDA_ARCH__ >= 800
@@ -967,14 +967,22 @@ __device__ void nvvm_min_max_sm80() {
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmin.ftz.nan.f
__nvvm_fmin_ftz_nan_f(0.1f, (float)NAN32);
- // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmin.bf16
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.bf16
__nvvm_fmin_bf16(BF16, BF16_2);
- // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmin.nan.bf16
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.ftz.bf16
+ __nvvm_fmin_ftz_bf16(BF16, BF16_2);
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.nan.bf16
__nvvm_fmin_nan_bf16(BF16, NANBF16);
- // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmin.bf16x2
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.ftz.nan.bf16
+ __nvvm_fmin_ftz_nan_bf16(BF16, NANBF16);
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.bf16x2
__nvvm_fmin_bf16x2(BF16X2, BF16X2_2);
- // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmin.nan.bf16x2
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.ftz.bf16x2
+ __nvvm_fmin_ftz_bf16x2(BF16X2, BF16X2_2);
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.nan.bf16x2
__nvvm_fmin_nan_bf16x2(BF16X2, NANBF16X2);
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.ftz.nan.bf16x2
+ __nvvm_fmin_ftz_nan_bf16x2(BF16X2, NANBF16X2);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.nan.f
__nvvm_fmax_nan_f(0.1f, 0.11f);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
@@ -984,14 +992,22 @@ __device__ void nvvm_min_max_sm80() {
__nvvm_fmax_nan_f(0.1f, (float)NAN32);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
__nvvm_fmax_ftz_nan_f(0.1f, (float)NAN32);
- // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmax.bf16
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.bf16
__nvvm_fmax_bf16(BF16, BF16_2);
- // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmax.nan.bf16
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.ftz.bf16
+ __nvvm_fmax_ftz_bf16(BF16, BF16_2);
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.nan.bf16
__nvvm_fmax_nan_bf16(BF16, NANBF16);
- // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmax.bf16x2
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.ftz.nan.bf16
+ __nvvm_fmax_ftz_nan_bf16(BF16, NANBF16);
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.bf16x2
__nvvm_fmax_bf16x2(BF16X2, BF16X2_2);
- // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmax.nan.bf16x2
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.ftz.bf16x2
+ __nvvm_fmax_ftz_bf16x2(BF16X2, BF16X2_2);
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.nan.bf16x2
__nvvm_fmax_nan_bf16x2(NANBF16X2, BF16X2);
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.ftz.nan.bf16x2
+ __nvvm_fmax_ftz_nan_bf16x2(NANBF16X2, BF16X2);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.nan.f
__nvvm_fmax_nan_f(0.1f, (float)NAN32);
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
@@ -1004,14 +1020,14 @@ __device__ void nvvm_min_max_sm80() {
// CHECK-LABEL: nvvm_fma_bf16_bf16x2_sm80
__device__ void nvvm_fma_bf16_bf16x2_sm80() {
#if __CUDA_ARCH__ >= 800
- // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.bf16
- __nvvm_fma_rn_bf16(0x1234, 0x7FBF, 0x1234);
- // CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.relu.bf16
- __nvvm_fma_rn_relu_bf16(0x1234, 0x7FBF, 0x1234);
- // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.bf16x2
- __nvvm_fma_rn_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF);
- // CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.relu.bf16x2
- __nvvm_fma_rn_relu_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF);
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fma.rn.bf16
+ __nvvm_fma_rn_bf16(BF16, BF16_2, BF16_2);
+ // CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fma.rn.relu.bf16
+ __nvvm_fma_rn_relu_bf16(BF16, BF16_2, BF16_2);
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2
+ __nvvm_fma_rn_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
+ // CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2
+ __nvvm_fma_rn_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
#endif
// CHECK: ret void
}
@@ -1020,13 +1036,13 @@ __device__ void nvvm_fma_bf16_bf16x2_sm80() {
__device__ void nvvm_min_max_sm86() {
#if __CUDA_ARCH__ >= 860
- // CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmin.xorsign.abs.bf16
+ // CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmin.xorsign.abs.bf16
__nvvm_fmin_xorsign_abs_bf16(BF16, BF16_2);
- // CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16
+ // CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmin.nan.xorsign.abs.bf16
__nvvm_fmin_nan_xorsign_abs_bf16(BF16, NANBF16);
- // CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2
+ // CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmin.xorsign.abs.bf16x2
__nvvm_fmin_xorsign_abs_bf16x2(BF16X2, BF16X2_2);
- // CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2
+ // CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2
__nvvm_fmin_nan_xorsign_abs_bf16x2(BF16X2, NANBF16X2);
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmin.xorsign.abs.f
__nvvm_fmin_xorsign_abs_f(-0.1f, 0.1f);
@@ -1037,13 +1053,13 @@ __device__ void nvvm_min_max_sm86() {
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f
__nvvm_fmin_ftz_nan_xorsign_abs_f(-0.1f, (float)NAN32);
- // CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmax.xorsign.abs.bf16
+ // CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmax.xorsign.abs.bf16
__nvvm_fmax_xorsign_abs_bf16(BF16, BF16_2);
- // CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16
+ // CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmax.nan.xorsign.abs.bf16
__nvvm_fmax_nan_xorsign_abs_bf16(BF16, NANBF16);
- // CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2
+ // CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmax.xorsign.abs.bf16x2
__nvvm_fmax_xorsign_abs_bf16x2(BF16X2, BF16X2_2);
- // CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2
+ // CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2
__nvvm_fmax_nan_xorsign_abs_bf16x2(BF16X2, NANBF16X2);
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmax.xorsign.abs.f
__nvvm_fmax_xorsign_abs_f(-0.1f, 0.1f);
diff --git a/clang/test/CodeGenCUDA/bf16.cu b/clang/test/CodeGenCUDA/bf16.cu
index 32082904c4d81..3c443420dbd36 100644
--- a/clang/test/CodeGenCUDA/bf16.cu
+++ b/clang/test/CodeGenCUDA/bf16.cu
@@ -8,7 +8,7 @@
// CHECK-LABEL: .visible .func _Z8test_argPDF16bDF16b(
// CHECK: .param .b64 _Z8test_argPDF16bDF16b_param_0,
-// CHECK: .param .b16 _Z8test_argPDF16bDF16b_param_1
+// CHECK: .param .align 2 .b8 _Z8test_argPDF16bDF16b_param_1[2]
//
__device__ void test_arg(__bf16 *out, __bf16 in) {
// CHECK-DAG: ld.param.u64 %[[A:rd[0-9]+]], [_Z8test_argPDF16bDF16b_param_0];
@@ -20,8 +20,8 @@ __device__ void test_arg(__bf16 *out, __bf16 in) {
}
-// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z8test_retDF16b(
-// CHECK: .param .b16 _Z8test_retDF16b_param_0
+// CHECK-LABEL: .visible .func (.param .align 2 .b8 func_retval0[2]) _Z8test_retDF16b(
+// CHECK: .param .align 2 .b8 _Z8test_retDF16b_param_0[2]
__device__ __bf16 test_ret( __bf16 in) {
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z8test_retDF16b_param_0];
return in;
@@ -31,12 +31,12 @@ __device__ __bf16 test_ret( __bf16 in) {
__device__ __bf16 external_func( __bf16 in);
-// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z9test_callDF16b(
-// CHECK: .param .b16 _Z9test_callDF16b_param_0
+// CHECK-LABEL: .visible .func (.param .align 2 .b8 func_retval0[2]) _Z9test_callDF16b(
+// CHECK: .param .align 2 .b8 _Z9test_callDF16b_param_0[2]
__device__ __bf16 test_call( __bf16 in) {
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z9test_callDF16b_param_0];
// CHECK: st.param.b16 [param0+0], %[[R]];
-// CHECK: .param .b32 retval0;
+// CHECK: .param .align 2 .b8 retval0[2];
// CHECK: call.uni (retval0),
// CHECK-NEXT: _Z13external_funcDF16b,
// CHECK-NEXT: (
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 7e4ad18cf5321..914f6c36a3e4a 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -595,19 +595,21 @@ let TargetPrefix = "nvvm" in {
[IntrNoMem, IntrSpeculatable, Commutative]>;
}
- foreach variant = ["_bf16", "_nan_bf16", "_xorsign_abs_bf16",
- "_nan_xorsign_abs_bf16"] in {
+ foreach variant = ["_bf16", "_ftz_bf16", "_nan_bf16", "_ftz_nan_bf16",
+ "_xorsign_abs_bf16", "_ftz_xorsign_abs_bf16", "_nan_xorsign_abs_bf16",
+ "_ftz_nan_xorsign_abs_bf16"] in {
def int_nvvm_f # operation # variant :
ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>,
- DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty],
+ DefaultAttrsIntrinsic<[llvm_bfloat_ty], [llvm_bfloat_ty, llvm_bfloat_ty],
[IntrNoMem, IntrSpeculatable, Commutative]>;
}
- foreach variant = ["_bf16x2", "_nan_bf16x2", "_xorsign_abs_bf16x2",
- "_nan_xorsign_abs_bf16x2"] in {
+ foreach variant = ["_bf16x2", "_ftz_bf16x2", "_nan_bf16x2",
+ "_ftz_nan_bf16x2", "_xorsign_abs_bf16x2", "_ftz_xorsign_abs_bf16x2",
+ "_nan_xorsign_abs_bf16x2", "_ftz_nan_xorsign_abs_bf16x2"] in {
def int_nvvm_f # operation # variant :
ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>,
- DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty],
+ DefaultAttrsIntrinsic<[llvm_v2bf16_ty], [llvm_v2bf16_ty, llvm_v2bf16_ty],
[IntrNoMem, IntrSpeculatable, Commutative]>;
}
}
@@ -774,10 +776,10 @@ let TargetPrefix = "nvvm" in {
foreach unary = ["abs", "neg"] in {
def int_nvvm_ # unary # _bf16 :
ClangBuiltin<!strconcat("__nvvm_", unary, "_bf16")>,
- DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty], [IntrNoMem]>;
+ DefaultAttrsIntrinsic<[llvm_bfloat_ty], [llvm_bfloat_ty], [IntrNoMem]>;
def int_nvvm_ # unary # _bf16x2 :
ClangBuiltin<!strconcat("__nvvm_", unary, "_bf16x2")>,
- DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem]>;
+ DefaultAttrsIntrinsic<[llvm_v2bf16_ty], [llvm_v2bf16_ty], [IntrNoMem]>;
}
//
@@ -870,17 +872,19 @@ let TargetPrefix = "nvvm" in {
[IntrNoMem, IntrSpeculatable]>;
}
- foreach variant = ["_rn_bf16", "_rn_relu_bf16"] in {
+ foreach variant = ["_rn_bf16", "_rn_ftz_bf16", "_rn_sat_bf16",
+ "_rn_ftz_sat_bf16", "_rn_relu_bf16", "_rn_ftz_relu_bf16"] in {
def int_nvvm_fma # variant : ClangBuiltin<!strconcat("__nvvm_fma", variant)>,
- DefaultAttrsIntrinsic<[llvm_i16_ty],
- [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty],
+ DefaultAttrsIntrinsic<[llvm_bfloat_ty],
+ [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty],
[IntrNoMem, IntrSpeculatable]>;
}
- foreach variant = ["_rn_bf16x2", "_rn_relu_bf16x2"] in {
+ foreach variant = ["_rn_bf16x2", "_rn_ftz_bf16x2", "_rn_sat_bf16x2",
+ "_rn_ftz_sat_bf16x2", "_rn_relu_bf16x2", "_rn_ftz_relu_bf16x2"] in {
def int_nvvm_fma # variant : ClangBuiltin<!strconcat("__nvvm_fma", variant)>,
- DefaultAttrsIntrinsic<[llvm_i32_ty],
- [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
+ DefaultAttrsIntrinsic<[llvm_v2bf16_ty],
+ [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty],
[IntrNoMem, IntrSpeculatable]>;
}
@@ -1232,14 +1236,19 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_f2h_rn : ClangBuiltin<"__nvvm_f2h_rn">,
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+ def int_nvvm_bf2h_rn_ftz : ClangBuiltin<"__nvvm_bf2h_rn_ftz">,
+ DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_bfloat_ty], [IntrNoMem, IntrSpeculatable]>;
+ def int_nvvm_bf2h_rn : ClangBuiltin<"__nvvm_bf2h_rn">,
+ DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_bfloat_ty], [IntrNoMem, IntrSpeculatable]>;
+
def int_nvvm_ff2bf16x2_rn : ClangBuiltin<"__nvvm_ff2bf16x2_rn">,
- Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+ Intrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_ff2bf16x2_rn_relu : ClangBuiltin<"__nvvm_ff2bf16x2_rn_relu">,
- Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+ Intrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_ff2bf16x2_rz : ClangBuiltin<"__nvvm_ff2bf16x2_rz">,
- Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+ Intrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_ff2bf16x2_rz_relu : ClangBuiltin<"__nvvm_ff2bf16x2_rz_relu">,
- Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+ Intrinsic<[llvm_v2bf16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
def int_nvvm_ff2f16x2_rn : ClangBuiltin<"__nvvm_ff2f16x2_rn">,
Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
@@ -1251,13 +1260,13 @@ let TargetPrefix = "nvvm" in {
Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_f2bf16_rn : ClangBuiltin<"__nvvm_f2bf16_rn">,
- Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+ Intrinsic<[llvm_bfloat_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_f2bf16_rn_relu : ClangBuiltin<"__nvvm_f2bf16_rn_relu">,
- Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+ Intrinsic<[llvm_bfloat_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_f2bf16_rz : ClangBuiltin<"__nvvm_f2bf16_rz">,
- Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+ Intrinsic<[llvm_bfloat_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_f2bf16_rz_relu : ClangBuiltin<"__nvvm_f2bf16_rz_relu">,
- Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+ Intrinsic<[llvm_bfloat_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_f2tf32_rna : ClangBuiltin<"__nvvm_f2tf32_rna">,
Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp
index f53f32f749fef..d26f39b16bb35 100644
--- a/llvm/lib/IR/AutoUpgrade.cpp
+++ b/llvm/lib/IR/AutoUpgrade.cpp
@@ -29,6 +29,7 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/IntrinsicsARM.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/IR/IntrinsicsWebAssembly.h"
#include "llvm/IR/IntrinsicsX86.h"
@@ -591,6 +592,71 @@ static bool UpgradeX86IntrinsicFunction(Function *F, StringRef Name,
return false;
}
+/// Maps the name of an NVVM bf16/bf16x2 intrinsic to the ID of its
+/// bfloat-typed definition, or returns Intrinsic::not_intrinsic if \p Name is
+/// not one of them. \p Name is the intrinsic name without the "llvm.nvvm."
+/// prefix, spelled with '.' separators exactly as it appears in IR (e.g.
+/// "fma.rn.ftz.bf16" for int_nvvm_fma_rn_ftz_bf16); a key spelled with '_'
+/// would never match. The auto-upgrader uses this to move legacy
+/// i16/i32-typed declarations onto the bfloat / <2 x bfloat> signatures.
+static Intrinsic::ID ShouldUpgradeNVPTXBF16Intrinsic(StringRef Name) {
+  return StringSwitch<Intrinsic::ID>(Name)
+      .Case("abs.bf16", Intrinsic::nvvm_abs_bf16)
+      .Case("abs.bf16x2", Intrinsic::nvvm_abs_bf16x2)
+      .Case("fma.rn.bf16", Intrinsic::nvvm_fma_rn_bf16)
+      .Case("fma.rn.bf16x2", Intrinsic::nvvm_fma_rn_bf16x2)
+      .Case("fma.rn.ftz.bf16", Intrinsic::nvvm_fma_rn_ftz_bf16)
+      .Case("fma.rn.ftz.bf16x2", Intrinsic::nvvm_fma_rn_ftz_bf16x2)
+      .Case("fma.rn.ftz.relu.bf16", Intrinsic::nvvm_fma_rn_ftz_relu_bf16)
+      .Case("fma.rn.ftz.relu.bf16x2", Intrinsic::nvvm_fma_rn_ftz_relu_bf16x2)
+      .Case("fma.rn.ftz.sat.bf16", Intrinsic::nvvm_fma_rn_ftz_sat_bf16)
+      .Case("fma.rn.ftz.sat.bf16x2", Intrinsic::nvvm_fma_rn_ftz_sat_bf16x2)
+      .Case("fma.rn.relu.bf16", Intrinsic::nvvm_fma_rn_relu_bf16)
+      .Case("fma.rn.relu.bf16x2", Intrinsic::nvvm_fma_rn_relu_bf16x2)
+      .Case("fma.rn.sat.bf16", Intrinsic::nvvm_fma_rn_sat_bf16)
+      .Case("fma.rn.sat.bf16x2", Intrinsic::nvvm_fma_rn_sat_bf16x2)
+      .Case("fmax.bf16", Intrinsic::nvvm_fmax_bf16)
+      .Case("fmax.bf16x2", Intrinsic::nvvm_fmax_bf16x2)
+      .Case("fmax.ftz.bf16", Intrinsic::nvvm_fmax_ftz_bf16)
+      .Case("fmax.ftz.bf16x2", Intrinsic::nvvm_fmax_ftz_bf16x2)
+      .Case("fmax.ftz.nan.bf16", Intrinsic::nvvm_fmax_ftz_nan_bf16)
+      .Case("fmax.ftz.nan.bf16x2", Intrinsic::nvvm_fmax_ftz_nan_bf16x2)
+      .Case("fmax.ftz.nan.xorsign.abs.bf16",
+            Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_bf16)
+      .Case("fmax.ftz.nan.xorsign.abs.bf16x2",
+            Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_bf16x2)
+      .Case("fmax.ftz.xorsign.abs.bf16",
+            Intrinsic::nvvm_fmax_ftz_xorsign_abs_bf16)
+      .Case("fmax.ftz.xorsign.abs.bf16x2",
+            Intrinsic::nvvm_fmax_ftz_xorsign_abs_bf16x2)
+      .Case("fmax.nan.bf16", Intrinsic::nvvm_fmax_nan_bf16)
+      .Case("fmax.nan.bf16x2", Intrinsic::nvvm_fmax_nan_bf16x2)
+      .Case("fmax.nan.xorsign.abs.bf16",
+            Intrinsic::nvvm_fmax_nan_xorsign_abs_bf16)
+      .Case("fmax.nan.xorsign.abs.bf16x2",
+            Intrinsic::nvvm_fmax_nan_xorsign_abs_bf16x2)
+      .Case("fmax.xorsign.abs.bf16", Intrinsic::nvvm_fmax_xorsign_abs_bf16)
+      .Case("fmax.xorsign.abs.bf16x2", Intrinsic::nvvm_fmax_xorsign_abs_bf16x2)
+      .Case("fmin.bf16", Intrinsic::nvvm_fmin_bf16)
+      .Case("fmin.bf16x2", Intrinsic::nvvm_fmin_bf16x2)
+      .Case("fmin.ftz.bf16", Intrinsic::nvvm_fmin_ftz_bf16)
+      .Case("fmin.ftz.bf16x2", Intrinsic::nvvm_fmin_ftz_bf16x2)
+      .Case("fmin.ftz.nan.bf16", Intrinsic::nvvm_fmin_ftz_nan_bf16)
+      .Case("fmin.ftz.nan.bf16x2", Intrinsic::nvvm_fmin_ftz_nan_bf16x2)
+      .Case("fmin.ftz.nan.xorsign.abs.bf16",
+            Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_bf16)
+      .Case("fmin.ftz.nan.xorsign.abs.bf16x2",
+            Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_bf16x2)
+      .Case("fmin.ftz.xorsign.abs.bf16",
+            Intrinsic::nvvm_fmin_ftz_xorsign_abs_bf16)
+      .Case("fmin.ftz.xorsign.abs.bf16x2",
+            Intrinsic::nvvm_fmin_ftz_xorsign_abs_bf16x2)
+      .Case("fmin.nan.bf16", Intrinsic::nvvm_fmin_nan_bf16)
+      .Case("fmin.nan.bf16x2", Intrinsic::nvvm_fmin_nan_bf16x2)
+      .Case("fmin.nan.xorsign.abs.bf16",
+            Intrinsic::nvvm_fmin_nan_xorsign_abs_bf16)
+      .Case("fmin.nan.xorsign.abs.bf16x2",
+            Intrinsic::nvvm_fmin_nan_xorsign_abs_bf16x2)
+      .Case("fmin.xorsign.abs.bf16", Intrinsic::nvvm_fmin_xorsign_abs_bf16)
+      .Case("fmin.xorsign.abs.bf16x2", Intrinsic::nvvm_fmin_xorsign_abs_bf16x2)
+      .Case("neg.bf16", Intrinsic::nvvm_neg_bf16)
+      .Case("neg.bf16x2", Intrinsic::nvvm_neg_bf16x2)
+      .Default(Intrinsic::not_intrinsic);
+}
+
static bool UpgradeIntrinsicFunction1(Function *F, Function *&NewFn) {
assert(F && "Illegal to upgrade a non-existent Function.");
@@ -1082,7 +1148,12 @@ static bool UpgradeIntrinsicFunction1(Function *F, Function *&NewFn) {
{F->getReturnType()});
return true;
}
-
+ IID = ShouldUpgradeNVPTXBF16Intrinsic(Name);
+ if (IID != Intrinsic::not_intrinsic &&
+ !F->getReturnType()->getScalarType()->isBFloatTy()) {
+ NewFn = nullptr;
+ return true;
+ }
// The following nvvm intrinsics correspond exactly to an LLVM idiom, but
// not to an intrinsic alone. We expand them in UpgradeIntrinsicCall.
//
@@ -4049,11 +4120,34 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
{Arg->getType()}),
Arg, "ctpop");
Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc");
- } else if (IsNVVM && Name == "h2f") {
- Rep = Builder.CreateCall(Intrinsic::getDeclaration(
+ } else if (IsNVVM) {
+ if (Name == "h2f") {
+ Rep =
+ Builder.CreateCall(Intrinsic::getDeclaration(
F->getParent(), Intrinsic::convert_from_fp16,
{Builder.getFloatTy()}),
CI->getArgOperand(0), "h2f");
+ } else {
+ Intrinsic::ID IID = ShouldUpgradeNVPTXBF16Intrinsic(Name);
+ if (IID != Intrinsic::not_intrinsic &&
+ !F->getReturnType()->getScalarType()->isBFloatTy()) {
+ rename(F);
+ NewFn = Intrinsic::getDeclaration(F->getParent(), IID);
+ SmallVector<Value *, 2> Args;
+ for (size_t I = 0; I < NewFn->arg_size(); ++I) {
+ Value *Arg = CI->getArgOperand(I);
+ Type *OldType = Arg->getType();
+ Type *NewType = NewFn->getArg(I)->getType();
+ Args.push_back((OldType->isIntegerTy() &&
+ NewType->getScalarType()->isBFloatTy())
+ ? Builder.CreateBitCast(Arg, NewType)
+ : Arg);
+ }
+ Rep = Builder.CreateCall(NewFn, Args);
+ if (F->getReturnType()->isIntegerTy())
+ Rep = Builder.CreateBitCast(Rep, F->getReturnType());
+ }
+ }
} else if (IsARM) {
Rep = UpgradeARMIntrinsicCall(Name, CI, F, Builder);
} else if (IsAMDGCN) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 179306b59b0ff..fd032676dcf64 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -272,6 +272,10 @@ bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
MCOp = MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
break;
+ case Type::BFloatTyID:
+ MCOp = MCOperand::createExpr(
+ NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
+ break;
case Type::FloatTyID:
MCOp = MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
@@ -330,6 +334,11 @@ MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
return MCOperand::createExpr(Expr);
}
+static bool ShouldPassAsArray(Type *Ty) { // Types emitted as .b8 byte-array params.
+ return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
+ Ty->isHalfTy() || Ty->isBFloatTy();
+}
+
void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
const DataLayout &DL = getDataLayout();
const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
@@ -341,11 +350,11 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
if (Ty->getTypeID() == Type::VoidTyID)
return;
-
O << " (";
if (isABI) {
- if (Ty->isFloatingPointTy() || (Ty->isIntegerTy() && !Ty->isIntegerTy(128))) {
+ if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
+ !ShouldPassAsArray(Ty)) {
unsigned size = 0;
if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
size = ITy->getBitWidth();
@@ -353,16 +362,12 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
assert(Ty->isFloatingPointTy() && "Floating point type expected here");
size = Ty->getPrimitiveSizeInBits();
}
- // PTX ABI requires all scalar return values to be at least 32
- // bits in size. fp16 normally uses .b16 as its storage type in
- // PTX, so its size must be adjusted here, too.
size = promoteScalarArgumentSize(size);
-
O << ".param .b" << size << " func_retval0";
} else if (isa<PointerType>(Ty)) {
O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
<< " func_retval0";
- } else if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
+ } else if (ShouldPassAsArray(Ty)) {
unsigned totalsz = DL.getTypeAllocSize(Ty);
unsigned retAlignment = 0;
if (!getAlign(*F, 0, retAlignment))
@@ -1355,8 +1360,10 @@ NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
}
break;
}
+ case Type::BFloatTyID:
case Type::HalfTyID:
- // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
+ // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
+ // PTX assembly.
return "b16";
case Type::FloatTyID:
return "f32";
@@ -1510,7 +1517,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
};
if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
- if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
+ if (ShouldPassAsArray(Ty)) {
// Just print .param .align <a> .b8 .param[size];
// <a> = optimal alignment for the element type; always multiple of
// PAL.getParamAlignment
@@ -1581,12 +1588,7 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
} else if (PTy) {
assert(PTySizeInBits && "Invalid pointer size");
sz = PTySizeInBits;
- } else if (Ty->isHalfTy())
- // PTX ABI requires all scalar parameters to be at least 32
- // bits in size. fp16 normally uses .b16 as its storage type
- // in PTX, so its size must be adjusted here, too.
- sz = 32;
- else
+ } else
sz = Ty->getPrimitiveSizeInBits();
if (isABI)
O << "\t.param .b" << sz << " ";
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index cb8a1867c44f0..db69431cceefc 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -500,7 +500,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
SelectAddrSpaceCast(N);
return;
case ISD::ConstantFP:
- if (tryConstantFP16(N))
+ if (tryConstantFP(N))
return;
break;
default:
@@ -524,15 +524,17 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
}
}
-// There's no way to specify FP16 immediates in .f16 ops, so we have to
-// load them into an .f16 register first.
-bool NVPTXDAGToDAGISel::tryConstantFP16(SDNode *N) {
- if (N->getValueType(0) != MVT::f16)
+// There's no way to specify FP16 and BF16 immediates in .(b)f16 ops, so we
+// have to load them into an .(b)f16 register first.
+bool NVPTXDAGToDAGISel::tryConstantFP(SDNode *N) {
+ if (N->getValueType(0) != MVT::f16 && N->getValueType(0) != MVT::bf16)
return false;
SDValue Val = CurDAG->getTargetConstantFP(
- cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), MVT::f16);
- SDNode *LoadConstF16 =
- CurDAG->getMachineNode(NVPTX::LOAD_CONST_F16, SDLoc(N), MVT::f16, Val);
+ cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), N->getValueType(0));
+ SDNode *LoadConstF16 = CurDAG->getMachineNode( // Despite the name, also used for bf16.
+ (N->getValueType(0) == MVT::f16 ? NVPTX::LOAD_CONST_F16
+ : NVPTX::LOAD_CONST_BF16),
+ SDLoc(N), N->getValueType(0), Val);
ReplaceNode(N, LoadConstF16);
return true;
}
@@ -612,9 +614,9 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
// We only care about f16x2 as it's the only real vector type we
// need to deal with.
- if (Vector.getSimpleValueType() != MVT::v2f16)
+ MVT VT = Vector.getSimpleValueType();
+ if (!(VT == MVT::v2f16 || VT == MVT::v2bf16))
return false;
-
// Find and record all uses of this vector that extract element 0 or 1.
SmallVector<SDNode *, 4> E0, E1;
for (auto *U : Vector.getNode()->uses()) {
@@ -640,8 +642,9 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
// Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
// into f16,f16 SplitF16x2(V)
- SDNode *ScatterOp = CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N),
- MVT::f16, MVT::f16, Vector);
+ MVT EltVT = VT.getVectorElementType(); // f16 or bf16
+ SDNode *ScatterOp =
+ CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
for (auto *Node : E0)
ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
for (auto *Node : E1)
@@ -1258,10 +1261,11 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
// vectors of f16 are loaded/stored as multiples of v2f16 elements.
- if (EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) {
- assert(NumElts % 2 == 0 && "Vector must have even number of elements");
- EltVT = MVT::v2f16;
- NumElts /= 2;
+ if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
+ (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16)) {
+ assert(NumElts % 2 == 0 && "Vector must have even number of elements");
+ EltVT = N->getValueType(0); // v2f16 or v2bf16
+ NumElts /= 2;
}
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 2a8ee5089ca02..25bb73cd55361 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -71,7 +71,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool tryTextureIntrinsic(SDNode *N);
bool trySurfaceIntrinsic(SDNode *N);
bool tryBFE(SDNode *N);
- bool tryConstantFP16(SDNode *N);
+ bool tryConstantFP(SDNode *N); // Selects f16 and bf16 ConstantFP nodes only.
bool SelectSETP_F16X2(SDNode *N);
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 98a877cbafec9..fa050bcdc3412 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -149,6 +149,14 @@ static bool IsPTXVectorType(MVT VT) {
}
}
+static bool Isv2f16Orv2bf16Type(EVT VT) {
+ return (VT == MVT::v2f16 || VT == MVT::v2bf16);
+}
+
+static bool Isf16Orbf16Type(EVT VT) {
+ return (VT == MVT::f16 || VT == MVT::bf16); // EVT: safe on extended types, no getSimpleVT() needed.
+}
+
/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
/// into their primitive components.
@@ -199,7 +207,7 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
// Vectors with an even number of f16 elements will be passed to
// us as an array of v2f16/v2bf16 elements. We must match this so we
// stay in sync with Ins/Outs.
- if ((EltVT == MVT::f16 || EltVT == MVT::bf16) && NumElts % 2 == 0) {
+ if (Isf16Orbf16Type(EltVT) && NumElts % 2 == 0) {
EltVT = EltVT == MVT::f16 ? MVT::v2f16 : MVT::v2bf16;
NumElts /= 2;
}
@@ -404,6 +412,21 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action);
};
+ auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
+ LegalizeAction NoBF16Action) {
+ bool IsOpSupported = STI.hasBF16Math();
+ // A few instructions (fadd/fmul/fsub) require sm_90 and PTX 7.8+.
+ switch(Op) {
+ case ISD::FADD:
+ case ISD::FMUL:
+ case ISD::FSUB:
+ IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
+ break;
+ }
+ setOperationAction(
+ Op, VT, IsOpSupported ? Action : NoBF16Action);
+ };
+
addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
@@ -426,6 +449,16 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
+ // Conversion to/from BF16/BF16x2 is always legal.
+ setOperationAction(ISD::SINT_TO_FP, MVT::bf16, Legal);
+ setOperationAction(ISD::FP_TO_SINT, MVT::bf16, Legal);
+ setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Custom);
+ setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2bf16, Custom);
+ setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2bf16, Expand);
+ setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2bf16, Expand);
+
+ setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
+ setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
// Operations not directly supported by NVPTX.
for (MVT VT : {MVT::f16, MVT::v2f16, MVT::f32, MVT::f64, MVT::i1, MVT::i8,
MVT::i16, MVT::i32, MVT::i64}) {
@@ -482,17 +515,25 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// Turn FP extload into load/fpextend
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
+ setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand); // bf16 handled like f16
+ setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
// Turn FP truncstore into trunc + store.
// FIXME: vector types should also be expanded
setTruncStoreAction(MVT::f32, MVT::f16, Expand);
setTruncStoreAction(MVT::f64, MVT::f16, Expand);
+ setTruncStoreAction(MVT::f32, MVT::bf16, Expand); // bf16 handled like f16
+ setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
setTruncStoreAction(MVT::f64, MVT::f32, Expand);
// PTX does not support load / store predicate registers
@@ -569,9 +610,9 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
ISD::SREM, ISD::UREM});
- // setcc for f16x2 needs special handling to prevent legalizer's
- // attempt to scalarize it due to v2i1 not being legal.
- if (STI.allowFP16Math())
+ // setcc for f16x2 and bf16x2 needs special handling to prevent
+ // legalizer's attempt to scalarize it due to v2i1 not being legal.
+ if (STI.allowFP16Math() || STI.hasBF16Math())
setTargetDAGCombine(ISD::SETCC);
// Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -583,6 +624,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
setFP16OperationAction(Op, MVT::f16, Legal, Promote);
setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
+ setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
+ setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
+ // bf16 must be promoted to f32.
+ if (getOperationAction(Op, MVT::bf16) == Promote)
+ AddPromotedToType(Op, MVT::bf16, MVT::f32);
}
// f16/f16x2 neg was introduced in PTX 60, SM_53.
@@ -593,19 +639,25 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::FNEG, VT,
IsFP16FP16x2NegAvailable ? Legal : Expand);
+ setBF16OperationAction(ISD::FNEG, MVT::bf16, Legal, Expand);
+ setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
// (would be) Library functions.
// These map to conversion instructions for scalar FP types.
for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
ISD::FROUNDEVEN, ISD::FTRUNC}) {
+ setOperationAction(Op, MVT::bf16, Legal);
setOperationAction(Op, MVT::f16, Legal);
setOperationAction(Op, MVT::f32, Legal);
setOperationAction(Op, MVT::f64, Legal);
setOperationAction(Op, MVT::v2f16, Expand);
+ setOperationAction(Op, MVT::v2bf16, Expand);
}
setOperationAction(ISD::FROUND, MVT::f16, Promote);
setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
+ setOperationAction(ISD::FROUND, MVT::bf16, Promote);
+ setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
setOperationAction(ISD::FROUND, MVT::f32, Custom);
setOperationAction(ISD::FROUND, MVT::f64, Custom);
@@ -613,6 +665,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// 'Expand' implements FCOPYSIGN without calling an external library.
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
+ setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
+ setOperationAction(ISD::FCOPYSIGN, MVT::v2bf16, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand);
@@ -622,9 +676,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
for (const auto &Op :
{ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FABS}) {
setOperationAction(Op, MVT::f16, Promote);
+ setOperationAction(Op, MVT::bf16, Promote);
setOperationAction(Op, MVT::f32, Legal);
setOperationAction(Op, MVT::f64, Legal);
setOperationAction(Op, MVT::v2f16, Expand);
+ setOperationAction(Op, MVT::v2bf16, Expand);
}
// max.f16, max.f16x2 and max.NaN are supported on sm_80+.
auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) {
@@ -633,14 +689,18 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
};
for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Promote), Promote);
+ setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
setOperationAction(Op, MVT::f32, Legal);
setOperationAction(Op, MVT::f64, Legal);
setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
+ setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
}
for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand);
+ setBF16OperationAction(Op, MVT::bf16, Legal, Expand); // gate on hasBF16Math, not allowFP16Math
setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
+ setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
}
// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
@@ -1258,7 +1318,7 @@ NVPTXTargetLowering::getPreferredVectorAction(MVT VT) const {
if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
VT.getScalarType() == MVT::i1)
return TypeSplitVector;
- if (VT == MVT::v2f16)
+ if (Isv2f16Orv2bf16Type(VT))
return TypeLegal;
return TargetLoweringBase::getPreferredVectorAction(VT);
}
@@ -1321,6 +1381,11 @@ NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(NVPTXISD::Wrapper, dl, PtrVT, Op);
}
+static bool IsTypePassedAsArray(const Type *Ty) { // Keep in sync with NVPTXAsmPrinter's ShouldPassAsArray.
+ return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||
+ Ty->isHalfTy() || Ty->isBFloatTy();
+}
+
std::string NVPTXTargetLowering::getPrototype(
const DataLayout &DL, Type *retTy, const ArgListTy &Args,
const SmallVectorImpl<ISD::OutputArg> &Outs, MaybeAlign retAlignment,
@@ -1341,7 +1406,8 @@ std::string NVPTXTargetLowering::getPrototype(
O << "()";
} else {
O << "(";
- if (retTy->isFloatingPointTy() || (retTy->isIntegerTy() && !retTy->isIntegerTy(128))) {
+ if ((retTy->isFloatingPointTy() || retTy->isIntegerTy()) &&
+ !IsTypePassedAsArray(retTy)) {
unsigned size = 0;
if (auto *ITy = dyn_cast<IntegerType>(retTy)) {
size = ITy->getBitWidth();
@@ -1358,8 +1424,7 @@ std::string NVPTXTargetLowering::getPrototype(
O << ".param .b" << size << " _";
} else if (isa<PointerType>(retTy)) {
O << ".param .b" << PtrVT.getSizeInBits() << " _";
- } else if (retTy->isAggregateType() || retTy->isVectorTy() ||
- retTy->isIntegerTy(128)) {
+ } else if (IsTypePassedAsArray(retTy)) {
O << ".param .align " << (retAlignment ? retAlignment->value() : 0)
<< " .b8 _[" << DL.getTypeAllocSize(retTy) << "]";
} else {
@@ -1381,7 +1446,7 @@ std::string NVPTXTargetLowering::getPrototype(
first = false;
if (!Outs[OIdx].Flags.isByVal()) {
- if (Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128)) {
+ if (IsTypePassedAsArray(Ty)) {
unsigned ParamAlign = 0;
const CallInst *CallI = cast<CallInst>(&CB);
// +1 because index 0 is reserved for return type alignment
@@ -1408,13 +1473,9 @@ std::string NVPTXTargetLowering::getPrototype(
sz = promoteScalarArgumentSize(sz);
} else if (isa<PointerType>(Ty)) {
sz = PtrVT.getSizeInBits();
- } else if (Ty->isHalfTy())
- // PTX ABI requires all scalar parameters to be at least 32
- // bits in size. fp16 normally uses .b16 as its storage type
- // in PTX, so its size must be adjusted here, too.
- sz = 32;
- else
+ } else {
sz = Ty->getPrimitiveSizeInBits();
+ }
O << ".param .b" << sz << " ";
O << "_";
continue;
@@ -1577,6 +1638,7 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
bool NeedAlign; // Does argument declaration specify alignment?
+ bool PassAsArray = IsByVal || IsTypePassedAsArray(Ty); // Declared as .param .align <align> .b8 [size].
if (IsVAArg) {
if (ParamCount == FirstVAArg) {
SDValue DeclareParamOps[] = {
@@ -1586,10 +1648,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
VADeclareParam = Chain = DAG.getNode(NVPTXISD::DeclareParam, dl,
DeclareParamVTs, DeclareParamOps);
}
- NeedAlign = IsByVal || Ty->isAggregateType() || Ty->isVectorTy() ||
- Ty->isIntegerTy(128);
- } else if (IsByVal || Ty->isAggregateType() || Ty->isVectorTy() ||
- Ty->isIntegerTy(128)) {
+ NeedAlign = PassAsArray;
+ } else if (PassAsArray) {
// declare .param .align <align> .b8 .param<n>[<size>];
SDValue DeclareParamOps[] = {
Chain, DAG.getConstant(ArgAlign.value(), dl, MVT::i32),
@@ -1739,15 +1799,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
ComputeValueVTs(*this, DL, RetTy, resvtparts);
// Declare
- // .param .align 16 .b8 retval0[<size-in-bytes>], or
+ // .param .align N .b8 retval0[<size-in-bytes>], or
// .param .b<size-in-bits> retval0
unsigned resultsz = DL.getTypeAllocSizeInBits(RetTy);
- // Emit ".param .b<size-in-bits> retval0" instead of byte arrays only for
- // these three types to match the logic in
- // NVPTXAsmPrinter::printReturnValStr and NVPTXTargetLowering::getPrototype.
- // Plus, this behavior is consistent with nvcc's.
- if (RetTy->isFloatingPointTy() || RetTy->isPointerTy() ||
- (RetTy->isIntegerTy() && !RetTy->isIntegerTy(128))) {
+ if (!IsTypePassedAsArray(RetTy)) { // Must match printReturnValStr and getPrototype.
resultsz = promoteScalarArgumentSize(resultsz);
SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, dl, MVT::i32),
@@ -2043,7 +2098,7 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
// generates good SASS in both cases.
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
SelectionDAG &DAG) const {
- if (!(Op->getValueType(0) == MVT::v2f16 &&
+ if (!(Isv2f16Orv2bf16Type(Op->getValueType(0)) &&
isa<ConstantFPSDNode>(Op->getOperand(0)) &&
isa<ConstantFPSDNode>(Op->getOperand(1))))
return Op;
@@ -2054,7 +2109,7 @@ SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
cast<ConstantFPSDNode>(Op->getOperand(1))->getValueAPF().bitcastToAPInt();
SDValue Const =
DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
- return DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2f16, Const);
+ return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
}
SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
@@ -2415,7 +2470,7 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
// loads and have to handle it here.
- if (Op.getValueType() == MVT::v2f16) {
+ if (Isv2f16Orv2bf16Type(Op.getValueType())) {
LoadSDNode *Load = cast<LoadSDNode>(Op);
EVT MemVT = Load->getMemoryVT();
if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -2460,7 +2515,7 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
// stores and have to handle it here.
- if (VT == MVT::v2f16 &&
+ if (Isv2f16Orv2bf16Type(VT) &&
!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
VT, *Store->getMemOperand()))
return expandUnalignedStore(Store, DAG);
@@ -2551,7 +2606,7 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
// v8f16 is a special case. PTX doesn't have st.v8.f16
// instruction. Instead, we split the vector into v2f16 chunks and
// store them with st.v4.b32.
- assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+ assert(Isf16Orbf16Type(EltVT.getSimpleVT()) &&
"Wrong type for the vector.");
Opcode = NVPTXISD::StoreV4;
StoreF16x2 = true;
@@ -2567,11 +2622,12 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
// Combine f16,f16 -> v2f16
NumElts /= 2;
for (unsigned i = 0; i < NumElts; ++i) {
- SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+ SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
DAG.getIntPtrConstant(i * 2, DL));
- SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+ SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
DAG.getIntPtrConstant(i * 2 + 1, DL));
- SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2f16, E0, E1);
+ EVT VecVT = EVT::getVectorVT(*DAG.getContext(), EltVT, 2); // v2f16 or v2bf16
+ SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, E0, E1);
Ops.push_back(V2);
}
} else {
@@ -2672,7 +2728,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (theArgs[i]->use_empty()) {
// argument is dead
- if (Ty->isAggregateType() || Ty->isIntegerTy(128)) {
+ if (IsTypePassedAsArray(Ty) && !Ty->isVectorTy()) {
SmallVector<EVT, 16> vtparts;
ComputePTXValueVTs(*this, DAG.getDataLayout(), Ty, vtparts);
@@ -2737,9 +2793,9 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
EVT LoadVT = EltVT;
if (EltVT == MVT::i1)
LoadVT = MVT::i8;
- else if (EltVT == MVT::v2f16)
+ else if (Isv2f16Orv2bf16Type(EltVT))
// getLoad needs a vector type, but it can't handle
- // vectors which contain v2f16 elements. So we must load
+ // vectors which contain v2f16 or v2bf16 elements. So we must load
// using i32 here and then bitcast back.
LoadVT = MVT::i32;
@@ -2763,8 +2819,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (EltVT == MVT::i1)
Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
// v2f16 was loaded as an i32. Now we must bitcast it back.
- else if (EltVT == MVT::v2f16)
- Elt = DAG.getNode(ISD::BITCAST, dl, MVT::v2f16, Elt);
+ else if (Isv2f16Orv2bf16Type(EltVT))
+ Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
// If a promoted integer type is used, truncate down to the original
MVT PromotedVT;
@@ -5194,7 +5250,7 @@ static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
// v8f16 is a special case. PTX doesn't have ld.v8.f16
// instruction. Instead, we split the vector into v2f16 chunks and
// load them with ld.v4.b32.
- assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+ assert(Isf16Orbf16Type(EltVT.getSimpleVT()) &&
"Unsupported v8 vector type.");
LoadF16x2 = true;
Opcode = NVPTXISD::LoadV4;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 54237af13d8dc..b98f76ed4b38d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -19,6 +19,8 @@ let hasSideEffects = false in {
let OperandType = "OPERAND_IMMEDIATE" in {
def f16imm : Operand<f16>;
+ def bf16imm : Operand<bf16>;
+
}
// List of vector specific properties
@@ -154,6 +156,7 @@ def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70"
def useShortPtr : Predicate<"useShortPointers()">;
def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
+def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">;
// Helper class to aid conversion between ValueType and a matching RegisterClass.
@@ -304,6 +307,31 @@ multiclass F3<string OpcStr, SDNode OpNode> {
!strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
[(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
Requires<[useFP16Math]>;
+ def bf16rr_ftz :
+ NVPTXInst<(outs Int16Regs:$dst),
+ (ins Int16Regs:$a, Int16Regs:$b),
+ !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
+ [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+ Requires<[hasBF16Math, doF32FTZ]>;
+ def bf16rr :
+ NVPTXInst<(outs Int16Regs:$dst),
+ (ins Int16Regs:$a, Int16Regs:$b),
+ !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
+ [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+ Requires<[hasBF16Math]>;
+
+ def bf16x2rr_ftz :
+ NVPTXInst<(outs Int32Regs:$dst),
+ (ins Int32Regs:$a, Int32Regs:$b),
+ !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
+ [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+ Requires<[hasBF16Math, doF32FTZ]>;
+ def bf16x2rr :
+ NVPTXInst<(outs Int32Regs:$dst),
+ (ins Int32Regs:$a, Int32Regs:$b),
+ !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
+ [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+ Requires<[hasBF16Math]>;
}
// Template for instructions which take three FP args. The
@@ -378,7 +406,31 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
!strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
[(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
Requires<[useFP16Math, allowFMA]>;
+ def bf16rr_ftz :
+ NVPTXInst<(outs Int16Regs:$dst),
+ (ins Int16Regs:$a, Int16Regs:$b),
+ !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
+ [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+ Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
+ def bf16rr :
+ NVPTXInst<(outs Int16Regs:$dst),
+ (ins Int16Regs:$a, Int16Regs:$b),
+ !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
+ [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+ Requires<[hasBF16Math, allowFMA]>;
+ def bf16x2rr_ftz :
+ NVPTXInst<(outs Int32Regs:$dst),
+ (ins Int32Regs:$a, Int32Regs:$b),
+ !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
+ [(set (v2bf16 Int32Regs:$dst), (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+ Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
+ def bf16x2rr :
+ NVPTXInst<(outs Int32Regs:$dst),
+ (ins Int32Regs:$a, Int32Regs:$b),
+ !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
+ [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+ Requires<[hasBF16Math, allowFMA]>;
// These have strange names so we don't perturb existing mir tests.
def _rnf64rr :
NVPTXInst<(outs Float64Regs:$dst),
@@ -440,6 +492,30 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
!strconcat(OpcStr, ".rn.f16x2 \t$dst, $a, $b;"),
[(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
Requires<[useFP16Math, noFMA]>;
+ def _rnbf16rr_ftz :
+ NVPTXInst<(outs Int16Regs:$dst),
+ (ins Int16Regs:$a, Int16Regs:$b),
+ !strconcat(OpcStr, ".rn.ftz.bf16 \t$dst, $a, $b;"),
+ [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+ Requires<[hasBF16Math, noFMA, doF32FTZ]>;
+ def _rnbf16rr :
+ NVPTXInst<(outs Int16Regs:$dst),
+ (ins Int16Regs:$a, Int16Regs:$b),
+ !strconcat(OpcStr, ".rn.bf16 \t$dst, $a, $b;"),
+ [(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
+ Requires<[hasBF16Math, noFMA]>;
+ def _rnbf16x2rr_ftz :
+ NVPTXInst<(outs Int32Regs:$dst),
+ (ins Int32Regs:$a, Int32Regs:$b),
+ !strconcat(OpcStr, ".rn.ftz.bf16x2 \t$dst, $a, $b;"),
+ [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+ Requires<[hasBF16Math, noFMA, doF32FTZ]>;
+ def _rnbf16x2rr :
+ NVPTXInst<(outs Int32Regs:$dst),
+ (ins Int32Regs:$a, Int32Regs:$b),
+ !strconcat(OpcStr, ".rn.bf16x2 \t$dst, $a, $b;"),
+ [(set Int32Regs:$dst, (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
+ Requires<[hasBF16Math, noFMA]>;
}
// Template for operations which take two f32 or f64 operands. Provides three
@@ -470,62 +546,86 @@ let hasSideEffects = false in {
// Generate a cvt to the given type from all possible types. Each instance
// takes a CvtMode immediate that defines the conversion mode to use. It can
// be CvtNONE to omit a conversion mode.
- multiclass CVT_FROM_ALL<string FromName, RegisterClass RC> {
+ multiclass CVT_FROM_ALL<string ToType, RegisterClass RC, list<Predicate> Preds = []> {
def _s8 :
NVPTXInst<(outs RC:$dst),
(ins Int16Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
- FromName, ".s8 \t$dst, $src;"), []>;
+ ToType, ".s8 \t$dst, $src;"), []>,
+ Requires<Preds>;
def _u8 :
NVPTXInst<(outs RC:$dst),
(ins Int16Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
- FromName, ".u8 \t$dst, $src;"), []>;
+ ToType, ".u8 \t$dst, $src;"), []>,
+ Requires<Preds>;
def _s16 :
NVPTXInst<(outs RC:$dst),
(ins Int16Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
- FromName, ".s16 \t$dst, $src;"), []>;
+ ToType, ".s16 \t$dst, $src;"), []>,
+ Requires<Preds>;
def _u16 :
NVPTXInst<(outs RC:$dst),
(ins Int16Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
- FromName, ".u16 \t$dst, $src;"), []>;
+ ToType, ".u16 \t$dst, $src;"), []>,
+ Requires<Preds>;
def _s32 :
NVPTXInst<(outs RC:$dst),
(ins Int32Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
- FromName, ".s32 \t$dst, $src;"), []>;
+ ToType, ".s32 \t$dst, $src;"), []>,
+ Requires<Preds>;
def _u32 :
NVPTXInst<(outs RC:$dst),
(ins Int32Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
- FromName, ".u32 \t$dst, $src;"), []>;
+ ToType, ".u32 \t$dst, $src;"), []>,
+ Requires<Preds>;
def _s64 :
NVPTXInst<(outs RC:$dst),
(ins Int64Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
- FromName, ".s64 \t$dst, $src;"), []>;
+ ToType, ".s64 \t$dst, $src;"), []>,
+ Requires<Preds>;
def _u64 :
NVPTXInst<(outs RC:$dst),
(ins Int64Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
- FromName, ".u64 \t$dst, $src;"), []>;
+ ToType, ".u64 \t$dst, $src;"), []>,
+ Requires<Preds>;
def _f16 :
NVPTXInst<(outs RC:$dst),
(ins Int16Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
- FromName, ".f16 \t$dst, $src;"), []>;
+ ToType, ".f16 \t$dst, $src;"), []>,
+ Requires<Preds>;
+ def _bf16 :
+ NVPTXInst<(outs RC:$dst),
+ (ins Int16Regs:$src, CvtMode:$mode),
+ !strconcat("cvt${mode:base}${mode:ftz}${mode:relu}${mode:sat}.",
+ ToType, ".bf16 \t$dst, $src;"), []>,
+ Requires<!if(!eq(ToType, "f32"),
+ // bf16->f32 was introduced early.
+ [hasPTX<71>, hasSM<80>],
+ // bf16->everything else needs sm90/ptx78
+ [hasPTX<78>, hasSM<90>])>;
def _f32 :
NVPTXInst<(outs RC:$dst),
(ins Float32Regs:$src, CvtMode:$mode),
- !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
- FromName, ".f32 \t$dst, $src;"), []>;
+ !strconcat("cvt${mode:base}${mode:ftz}${mode:relu}${mode:sat}.",
+ ToType, ".f32 \t$dst, $src;"), []>,
+ Requires<!if(!eq(ToType, "bf16"),
+ // f32->bf16 was introduced early.
+ [hasPTX<70>, hasSM<80>],
+ Preds)>;
def _f64 :
NVPTXInst<(outs RC:$dst),
(ins Float64Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
- FromName, ".f64 \t$dst, $src;"), []>;
+ ToType, ".f64 \t$dst, $src;"), []>,
+ Requires<Preds>;
}
// Generate cvts from all types to all types.
@@ -538,6 +638,7 @@ let hasSideEffects = false in {
defm CVT_s64 : CVT_FROM_ALL<"s64", Int64Regs>;
defm CVT_u64 : CVT_FROM_ALL<"u64", Int64Regs>;
defm CVT_f16 : CVT_FROM_ALL<"f16", Int16Regs>;
+ defm CVT_bf16 : CVT_FROM_ALL<"bf16", Int16Regs, [hasPTX<78>, hasSM<90>]>; // the f32 source form is relaxed to sm_80/PTX 7.0 inside CVT_FROM_ALL
defm CVT_f32 : CVT_FROM_ALL<"f32", Float32Regs>;
defm CVT_f64 : CVT_FROM_ALL<"f64", Float64Regs>;
@@ -556,18 +657,7 @@ let hasSideEffects = false in {
def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
"cvt.s64.s32 \t$dst, $src;", []>;
-multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
- def _f32 :
- NVPTXInst<(outs RC:$dst),
- (ins Float32Regs:$src, CvtMode:$mode),
- !strconcat("cvt${mode:base}${mode:relu}.",
- FromName, ".f32 \t$dst, $src;"), []>,
- Requires<[hasPTX<70>, hasSM<80>]>;
- }
-
- defm CVT_bf16 : CVT_FROM_FLOAT_SM80<"bf16", Int16Regs>;
-
- multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
+ multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
def _f32 :
NVPTXInst<(outs RC:$dst),
(ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
@@ -641,6 +731,7 @@ defm SELP_b64 : SELP_PATTERN<"b64", i64, Int64Regs, i64imm, imm>;
defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>;
defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>;
defm SELP_f16 : SELP_PATTERN<"b16", f16, Int16Regs, f16imm, fpimm>;
+defm SELP_bf16 : SELP_PATTERN<"b16", bf16, Int16Regs, bf16imm, fpimm>;
defm SELP_f32 : SELP_PATTERN<"f32", f32, Float32Regs, f32imm, fpimm>;
defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>;
@@ -1005,7 +1096,9 @@ def DoubleConst1 : PatLeaf<(fpimm), [{
def LOAD_CONST_F16 :
NVPTXInst<(outs Int16Regs:$dst), (ins f16imm:$a),
"mov.b16 \t$dst, $a;", []>;
-
+def LOAD_CONST_BF16 :
+ NVPTXInst<(outs Int16Regs:$dst), (ins bf16imm:$a),
+ "mov.b16 \t$dst, $a;", []>;
defm FADD : F3_fma_component<"add", fadd>;
defm FSUB : F3_fma_component<"sub", fsub>;
defm FMUL : F3_fma_component<"mul", fmul>;
@@ -1033,6 +1126,20 @@ def FNEG16 : FNEG_F16_F16X2<"neg.f16", f16, Int16Regs, True>;
def FNEG16x2_ftz : FNEG_F16_F16X2<"neg.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
def FNEG16x2 : FNEG_F16_F16X2<"neg.f16x2", v2f16, Int32Regs, True>;
+//
+// BF16 NEG
+//
+// PTX only defines neg.bf16 / neg.bf16x2 (no .ftz modifier, unlike the f16
+// forms), so no doF32FTZ-predicated variants are defined for bf16.
+
+class FNEG_BF16_F16X2<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> :
+ NVPTXInst<(outs RC:$dst), (ins RC:$src),
+ !strconcat(OpcStr, " \t$dst, $src;"),
+ [(set RC:$dst, (fneg (T RC:$src)))]>,
+ Requires<[hasBF16Math, hasPTX<70>, hasSM<80>, Pred]>;
+def BFNEG16 : FNEG_BF16_F16X2<"neg.bf16", bf16, Int16Regs, True>;
+def BFNEG16x2 : FNEG_BF16_F16X2<"neg.bf16x2", v2bf16, Int32Regs, True>;
+
//
// F64 division
//
@@ -1211,13 +1318,24 @@ multiclass FMA_F16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred>
Requires<[useFP16Math, Pred]>;
}
-defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
-defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
-defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
-defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>;
-defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
-defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
-defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
+multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
+ def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+ !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
+ [(set RC:$dst, (fma (T RC:$a), (T RC:$b), (T RC:$c)))]>,
+ Requires<[hasBF16Math, Pred]>;
+}
+
+defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
+defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
+defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
+defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>;
+// PTX defines fma.rn{.relu}.bf16 / .bf16x2 without a .ftz modifier, so only
+// the non-FTZ forms are provided for bf16 (they also match under doF32FTZ).
+defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
+defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>;
+defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
+defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
+defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
// sin/cos
def SINF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
@@ -1661,6 +1779,18 @@ def SETP_f16x2rr :
"setp${cmp:base}${cmp:ftz}.f16x2 \t$p|$q, $a, $b;",
[]>,
Requires<[useFP16Math]>;
+def SETP_bf16rr :
+ NVPTXInst<(outs Int1Regs:$dst),
+ (ins Int16Regs:$a, Int16Regs:$b, CmpMode:$cmp),
+ "setp${cmp:base}${cmp:ftz}.bf16 \t$dst, $a, $b;",
+ []>, Requires<[hasBF16Math]>; // NOTE(review): PTX ISA lists setp.bf16 as sm_90/PTX 7.8 with no .ftz -- confirm
+
+def SETP_bf16x2rr :
+ NVPTXInst<(outs Int1Regs:$p, Int1Regs:$q),
+ (ins Int32Regs:$a, Int32Regs:$b, CmpMode:$cmp),
+ "setp${cmp:base}${cmp:ftz}.bf16x2 \t$p|$q, $a, $b;",
+ []>,
+ Requires<[hasBF16Math]>; // NOTE(review): same sm_90/PTX 7.8, no-.ftz caveat as SETP_bf16rr -- confirm
// FIXME: This doesn't appear to be correct. The "set" mnemonic has the form
@@ -1691,6 +1821,7 @@ defm SET_b64 : SET<"b64", Int64Regs, i64imm>;
defm SET_s64 : SET<"s64", Int64Regs, i64imm>;
defm SET_u64 : SET<"u64", Int64Regs, i64imm>;
defm SET_f16 : SET<"f16", Int16Regs, f16imm>;
+defm SET_bf16 : SET<"bf16", Int16Regs, bf16imm>;
defm SET_f32 : SET<"f32", Float32Regs, f32imm>;
defm SET_f64 : SET<"f64", Float64Regs, f64imm>;
@@ -1959,6 +2090,26 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
(SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Int16Regs:$b, Mode)>,
Requires<[useFP16Math]>;
+ // bf16 -> pred
+ def : Pat<(i1 (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))),
+ (SETP_bf16rr Int16Regs:$a, Int16Regs:$b, ModeFTZ)>,
+ Requires<[hasBF16Math,doF32FTZ]>;
+ def : Pat<(i1 (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))),
+ (SETP_bf16rr Int16Regs:$a, Int16Regs:$b, Mode)>,
+ Requires<[hasBF16Math]>;
+ def : Pat<(i1 (OpNode (bf16 Int16Regs:$a), fpimm:$b)),
+ (SETP_bf16rr Int16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), ModeFTZ)>,
+ Requires<[hasBF16Math,doF32FTZ]>;
+ def : Pat<(i1 (OpNode (bf16 Int16Regs:$a), fpimm:$b)),
+ (SETP_bf16rr Int16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), Mode)>,
+ Requires<[hasBF16Math]>;
+ def : Pat<(i1 (OpNode fpimm:$a, (bf16 Int16Regs:$b))),
+ (SETP_bf16rr (LOAD_CONST_BF16 fpimm:$a), Int16Regs:$b, ModeFTZ)>,
+ Requires<[hasBF16Math,doF32FTZ]>;
+ def : Pat<(i1 (OpNode fpimm:$a, (bf16 Int16Regs:$b))),
+ (SETP_bf16rr (LOAD_CONST_BF16 fpimm:$a), Int16Regs:$b, Mode)>,
+ Requires<[hasBF16Math]>;
+
// f32 -> pred
def : Pat<(i1 (OpNode Float32Regs:$a, Float32Regs:$b)),
(SETP_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>,
@@ -2004,6 +2155,26 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
(SET_f16ir (LOAD_CONST_F16 fpimm:$a), Int16Regs:$b, Mode)>,
Requires<[useFP16Math]>;
+ // bf16 -> i32
+ def : Pat<(i32 (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))),
+ (SET_bf16rr Int16Regs:$a, Int16Regs:$b, ModeFTZ)>,
+ Requires<[hasBF16Math, doF32FTZ]>;
+ def : Pat<(i32 (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))),
+ (SET_bf16rr Int16Regs:$a, Int16Regs:$b, Mode)>,
+ Requires<[hasBF16Math]>;
+ def : Pat<(i32 (OpNode (bf16 Int16Regs:$a), fpimm:$b)),
+ (SET_bf16rr Int16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), ModeFTZ)>,
+ Requires<[hasBF16Math, doF32FTZ]>;
+ def : Pat<(i32 (OpNode (bf16 Int16Regs:$a), fpimm:$b)),
+ (SET_bf16rr Int16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), Mode)>,
+ Requires<[hasBF16Math]>;
+ def : Pat<(i32 (OpNode fpimm:$a, (bf16 Int16Regs:$b))),
+ (SET_bf16ir (LOAD_CONST_BF16 fpimm:$a), Int16Regs:$b, ModeFTZ)>,
+ Requires<[hasBF16Math, doF32FTZ]>;
+ def : Pat<(i32 (OpNode fpimm:$a, (bf16 Int16Regs:$b))),
+ (SET_bf16ir (LOAD_CONST_BF16 fpimm:$a), Int16Regs:$b, Mode)>,
+ Requires<[hasBF16Math]>;
+
// f32 -> i32
def : Pat<(i32 (OpNode Float32Regs:$a, Float32Regs:$b)),
(SET_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>,
@@ -2430,7 +2601,7 @@ def MoveParamSymbolI32 : MoveParamSymbolInst<Int32Regs, i32imm, i32, ".b32">;
def MoveParamI16 :
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src),
- "cvt.u16.u32 \t$dst, $src;", // ??? Why cvt.u16.u32 ?
+ "cvt.u16.u32 \t$dst, $src;", // ??? Why cvt.u16.u32 ?
[(set i16:$dst, (MoveParam i16:$src))]>;
def MoveParamF64 : MoveParamInst<f64, Float64Regs, ".f64">;
def MoveParamF32 : MoveParamInst<f32, Float32Regs, ".f32">;
@@ -2776,7 +2947,7 @@ def: Pat<(vt (bitconvert (i16 Int16Regs:$a))),
def: Pat<(i16 (bitconvert (vt Int16Regs:$a))),
(ProxyRegI16 Int16Regs:$a)>;
}
-
+
// NOTE: pred->fp are currently sub-optimal due to an issue in TableGen where
// we cannot specify floating-point literals in isel patterns. Therefore, we
// use an integer selp to select either 1 or 0 and then cvt to floating-point.
@@ -2801,6 +2972,26 @@ def : Pat<(f16 (uint_to_fp Int32Regs:$a)),
def : Pat<(f16 (uint_to_fp Int64Regs:$a)),
(CVT_f16_u64 Int64Regs:$a, CvtRN)>;
+// sint -> bf16
+def : Pat<(bf16 (sint_to_fp Int1Regs:$a)),
+ (CVT_bf16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
+def : Pat<(bf16 (sint_to_fp Int16Regs:$a)),
+ (CVT_bf16_s16 Int16Regs:$a, CvtRN)>;
+def : Pat<(bf16 (sint_to_fp Int32Regs:$a)),
+ (CVT_bf16_s32 Int32Regs:$a, CvtRN)>;
+def : Pat<(bf16 (sint_to_fp Int64Regs:$a)),
+ (CVT_bf16_s64 Int64Regs:$a, CvtRN)>;
+
+// uint -> bf16
+def : Pat<(bf16 (uint_to_fp Int1Regs:$a)),
+ (CVT_bf16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
+def : Pat<(bf16 (uint_to_fp Int16Regs:$a)),
+ (CVT_bf16_u16 Int16Regs:$a, CvtRN)>;
+def : Pat<(bf16 (uint_to_fp Int32Regs:$a)),
+ (CVT_bf16_u32 Int32Regs:$a, CvtRN)>;
+def : Pat<(bf16 (uint_to_fp Int64Regs:$a)),
+ (CVT_bf16_u64 Int64Regs:$a, CvtRN)>;
+
// sint -> f32
def : Pat<(f32 (sint_to_fp Int1Regs:$a)),
(CVT_f32_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
@@ -2862,6 +3053,25 @@ def : Pat<(i32 (fp_to_uint (f16 Int16Regs:$a))),
def : Pat<(i64 (fp_to_uint (f16 Int16Regs:$a))),
(CVT_u64_f16 Int16Regs:$a, CvtRZI)>;
+// bf16 -> sint
+def : Pat<(i1 (fp_to_sint (bf16 Int16Regs:$a))),
+ (SETP_b16ri Int16Regs:$a, 0, CmpEQ)>;
+def : Pat<(i16 (fp_to_sint (bf16 Int16Regs:$a))),
+ (CVT_s16_bf16 (bf16 Int16Regs:$a), CvtRZI)>;
+def : Pat<(i32 (fp_to_sint (bf16 Int16Regs:$a))),
+ (CVT_s32_bf16 (bf16 Int16Regs:$a), CvtRZI)>;
+def : Pat<(i64 (fp_to_sint (bf16 Int16Regs:$a))),
+ (CVT_s64_bf16 Int16Regs:$a, CvtRZI)>;
+
+// bf16 -> uint
+def : Pat<(i1 (fp_to_uint (bf16 Int16Regs:$a))),
+ (SETP_b16ri Int16Regs:$a, 0, CmpEQ)>;
+def : Pat<(i16 (fp_to_uint (bf16 Int16Regs:$a))),
+ (CVT_u16_bf16 Int16Regs:$a, CvtRZI)>;
+def : Pat<(i32 (fp_to_uint (bf16 Int16Regs:$a))),
+ (CVT_u32_bf16 Int16Regs:$a, CvtRZI)>;
+def : Pat<(i64 (fp_to_uint (bf16 Int16Regs:$a))),
+ (CVT_u64_bf16 Int16Regs:$a, CvtRZI)>;
// f32 -> sint
def : Pat<(i1 (fp_to_sint Float32Regs:$a)),
(SETP_b32ri (BITCONVERT_32_F2I Float32Regs:$a), 0, CmpEQ)>;
@@ -3009,6 +3219,9 @@ def : Pat<(select Int32Regs:$pred, Int64Regs:$a, Int64Regs:$b),
def : Pat<(select Int32Regs:$pred, (f16 Int16Regs:$a), (f16 Int16Regs:$b)),
(SELP_f16rr Int16Regs:$a, Int16Regs:$b,
(SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
+def : Pat<(select Int32Regs:$pred, (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)),
+ (SELP_bf16rr Int16Regs:$a, Int16Regs:$b,
+ (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
def : Pat<(select Int32Regs:$pred, Float32Regs:$a, Float32Regs:$b),
(SELP_f32rr Float32Regs:$a, Float32Regs:$b,
(SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
@@ -3080,6 +3293,13 @@ def : Pat<(f16 (extractelt (v2f16 Int32Regs:$src), 1)),
def : Pat<(v2f16 (build_vector (f16 Int16Regs:$a), (f16 Int16Regs:$b))),
(V2I16toI32 Int16Regs:$a, Int16Regs:$b)>;
+def : Pat<(bf16 (extractelt (v2bf16 Int32Regs:$src), 0)),
+ (I32toI16L Int32Regs:$src)>;
+def : Pat<(bf16 (extractelt (v2bf16 Int32Regs:$src), 1)),
+ (I32toI16H Int32Regs:$src)>;
+def : Pat<(v2bf16 (build_vector (bf16 Int16Regs:$a), (bf16 Int16Regs:$b))),
+ (V2I16toI32 Int16Regs:$a, Int16Regs:$b)>;
+
// Count leading zeros
let hasSideEffects = false in {
def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a),
@@ -3147,10 +3367,17 @@ def : Pat<(i32 (zext (i16 (ctpop Int16Regs:$a)))),
def : Pat<(f16 (fpround Float32Regs:$a)),
(CVT_f16_f32 Float32Regs:$a, CvtRN)>;
+// fpround f32 -> bf16
+def : Pat<(bf16 (fpround Float32Regs:$a)),
+ (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
+
// fpround f64 -> f16
def : Pat<(f16 (fpround Float64Regs:$a)),
(CVT_f16_f64 Float64Regs:$a, CvtRN)>;
+// fpround f64 -> bf16
+def : Pat<(bf16 (fpround Float64Regs:$a)),
+ (CVT_bf16_f64 Float64Regs:$a, CvtRN)>;
// fpround f64 -> f32
def : Pat<(f32 (fpround Float64Regs:$a)),
(CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>;
@@ -3162,11 +3389,20 @@ def : Pat<(f32 (fpextend (f16 Int16Regs:$a))),
(CVT_f32_f16 Int16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
def : Pat<(f32 (fpextend (f16 Int16Regs:$a))),
(CVT_f32_f16 Int16Regs:$a, CvtNONE)>;
+// fpextend bf16 -> f32
+def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))),
+ (CVT_f32_bf16 Int16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))),
+ (CVT_f32_bf16 Int16Regs:$a, CvtNONE)>;
// fpextend f16 -> f64
def : Pat<(f64 (fpextend (f16 Int16Regs:$a))),
(CVT_f64_f16 Int16Regs:$a, CvtNONE)>;
+// fpextend bf16 -> f64
+def : Pat<(f64 (fpextend (bf16 Int16Regs:$a))),
+ (CVT_f64_bf16 Int16Regs:$a, CvtNONE)>;
+
// fpextend f32 -> f64
def : Pat<(f64 (fpextend Float32Regs:$a)),
(CVT_f64_f32 Float32Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
@@ -3181,6 +3417,8 @@ def retglue : SDNode<"NVPTXISD::RET_GLUE", SDTNone,
multiclass CVT_ROUND<SDNode OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
def : Pat<(OpNode (f16 Int16Regs:$a)),
(CVT_f16_f16 Int16Regs:$a, Mode)>;
+ def : Pat<(OpNode (bf16 Int16Regs:$a)),
+ (CVT_bf16_bf16 Int16Regs:$a, Mode)>;
def : Pat<(OpNode Float32Regs:$a),
(CVT_f32_f32 Float32Regs:$a, ModeFTZ)>, Requires<[doF32FTZ]>;
def : Pat<(OpNode Float32Regs:$a),
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index bfc79d383191b..f0de0144d410e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -998,6 +998,18 @@ multiclass FMA_INST {
FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, Int16Regs,
[hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, Int16Regs, [hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_ftz_bf16", int_nvvm_fma_rn_ftz_bf16, Int16Regs, // NOTE(review): PTX documents fma.rn{.relu}.bf16 only; verify the ftz/sat bf16 spellings below with ptxas
+ [hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_sat_bf16", int_nvvm_fma_rn_sat_bf16, Int16Regs,
+ [hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_ftz_sat_bf16", int_nvvm_fma_rn_ftz_sat_bf16, Int16Regs,
+ [hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, Int16Regs,
+ [hasPTX<70>, hasSM<80>]>,
+ FMA_TUPLE<"_rn_ftz_relu_bf16", int_nvvm_fma_rn_ftz_relu_bf16, Int16Regs,
+ [hasPTX<70>, hasSM<80>]>,
+
FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, Int32Regs,
[hasPTX<42>, hasSM<53>]>,
FMA_TUPLE<"_rn_ftz_f16x2", int_nvvm_fma_rn_ftz_f16x2, Int32Regs,
@@ -1010,11 +1022,6 @@ multiclass FMA_INST {
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2,
Int32Regs, [hasPTX<70>, hasSM<80>]>,
-
- FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, Int16Regs, [hasPTX<70>, hasSM<80>]>,
- FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, Int16Regs,
- [hasPTX<70>, hasSM<80>]>,
-
FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, Int32Regs,
[hasPTX<70>, hasSM<80>]>,
FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, Int32Regs,
@@ -2207,10 +2214,6 @@ defm INT_PTX_LDU_G_v2i16_ELE
: VLDU_G_ELE_V2<"v2.u16 \t{{$dst1, $dst2}}, [$src];", Int16Regs>;
defm INT_PTX_LDU_G_v2i32_ELE
: VLDU_G_ELE_V2<"v2.u32 \t{{$dst1, $dst2}}, [$src];", Int32Regs>;
-defm INT_PTX_LDU_G_v2f16_ELE
- : VLDU_G_ELE_V2<"v2.b16 \t{{$dst1, $dst2}}, [$src];", Int16Regs>;
-defm INT_PTX_LDU_G_v2f16x2_ELE
- : VLDU_G_ELE_V2<"v2.b32 \t{{$dst1, $dst2}}, [$src];", Int32Regs>;
defm INT_PTX_LDU_G_v2f32_ELE
: VLDU_G_ELE_V2<"v2.f32 \t{{$dst1, $dst2}}, [$src];", Float32Regs>;
defm INT_PTX_LDU_G_v2i64_ELE
diff --git a/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp b/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
index 5ec1b2425e68f..95125eb41bc05 100644
--- a/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
@@ -34,6 +34,11 @@ void NVPTXFloatMCExpr::printImpl(raw_ostream &OS, const MCAsmInfo *MAI) const {
NumHex = 4;
APF.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored);
break;
+ case VK_NVPTX_BFLOAT_PREC_FLOAT:
+ OS << "0x";
+ NumHex = 4;
+ APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &Ignored);
+ break;
case VK_NVPTX_SINGLE_PREC_FLOAT:
OS << "0f";
NumHex = 8;
diff --git a/llvm/lib/Target/NVPTX/NVPTXMCExpr.h b/llvm/lib/Target/NVPTX/NVPTXMCExpr.h
index 440fa1310003e..ef99def06c4da 100644
--- a/llvm/lib/Target/NVPTX/NVPTXMCExpr.h
+++ b/llvm/lib/Target/NVPTX/NVPTXMCExpr.h
@@ -21,6 +21,7 @@ class NVPTXFloatMCExpr : public MCTargetExpr {
public:
enum VariantKind {
VK_NVPTX_None,
+ VK_NVPTX_BFLOAT_PREC_FLOAT, // FP constant in bfloat-precision
VK_NVPTX_HALF_PREC_FLOAT, // FP constant in half-precision
VK_NVPTX_SINGLE_PREC_FLOAT, // FP constant in single-precision
VK_NVPTX_DOUBLE_PREC_FLOAT // FP constant in double-precision
@@ -40,6 +41,11 @@ class NVPTXFloatMCExpr : public MCTargetExpr {
static const NVPTXFloatMCExpr *create(VariantKind Kind, const APFloat &Flt,
MCContext &Ctx);
+ static const NVPTXFloatMCExpr *createConstantBFPHalf(const APFloat &Flt,
+ MCContext &Ctx) {
+ return create(VK_NVPTX_BFLOAT_PREC_FLOAT, Flt, Ctx);
+ }
+
static const NVPTXFloatMCExpr *createConstantFPHalf(const APFloat &Flt,
MCContext &Ctx) {
return create(VK_NVPTX_HALF_PREC_FLOAT, Flt, Ctx);
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
index 2347f46449d5f..7fa64af196b93 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
@@ -26,7 +26,6 @@ static cl::opt<bool>
NoF16Math("nvptx-no-f16-math", cl::Hidden,
cl::desc("NVPTX Specific: Disable generation of f16 math ops."),
cl::init(false));
-
// Pin the vtable to this file.
void NVPTXSubtarget::anchor() {}
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 920f5bb94689d..93af11c258b48 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -76,6 +76,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
inline bool hasHWROT32() const { return SmVersion >= 32; }
bool hasImageHandles() const;
bool hasFP16Math() const { return SmVersion >= 53; }
+ bool hasBF16Math() const { return SmVersion >= 80; }
bool allowFP16Math() const;
bool hasMaskOperator() const { return PTXVersion >= 71; }
bool hasNoReturn() const { return SmVersion >= 30 && PTXVersion >= 64; }
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index f39934ae13e80..c73721da46e35 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -204,6 +204,14 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
return {Intrinsic::fma, FTZ_MustBeOff, true};
case Intrinsic::nvvm_fma_rn_ftz_f16x2:
return {Intrinsic::fma, FTZ_MustBeOn, true};
+ case Intrinsic::nvvm_fma_rn_bf16:
+ return {Intrinsic::fma, FTZ_MustBeOff, true};
+ case Intrinsic::nvvm_fma_rn_ftz_bf16:
+ return {Intrinsic::fma, FTZ_MustBeOn, true};
+ case Intrinsic::nvvm_fma_rn_bf16x2:
+ return {Intrinsic::fma, FTZ_MustBeOff, true};
+ case Intrinsic::nvvm_fma_rn_ftz_bf16x2:
+ return {Intrinsic::fma, FTZ_MustBeOn, true};
case Intrinsic::nvvm_fmax_d:
return {Intrinsic::maxnum, FTZ_Any};
case Intrinsic::nvvm_fmax_f:
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
new file mode 100644
index 0000000000000..3373cf1401aae
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -0,0 +1,194 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck --check-prefixes=CHECK,SM80 %s
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx78 | FileCheck --check-prefixes=CHECK,SM90 %s
+; RUN: %if ptxas-11.8 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 | %ptxas-verify -arch=sm_80 %}
+; RUN: %if ptxas-11.8 %{ llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx78 | %ptxas-verify -arch=sm_90 %}
+
+; LDST: .b8 bfloat_array[8] = {1, 2, 3, 4, 5, 6, 7, 8};
+@"bfloat_array" = addrspace(1) constant [4 x bfloat]
+ [bfloat 0xR0201, bfloat 0xR0403, bfloat 0xR0605, bfloat 0xR0807]
+
+; CHECK-LABEL: test_fadd(
+; CHECK-DAG: ld.param.b16 [[A:%rs[0-9]+]], [test_fadd_param_0];
+; CHECK-DAG: ld.param.b16 [[B:%rs[0-9]+]], [test_fadd_param_1];
+; SM90: add.rn.bf16 [[R:%rs[0-9]+]], [[A]], [[B]];
+;
+; SM80-DAG: cvt.f32.bf16 [[FA:%f[0-9]+]], [[A]];
+; SM80-DAG: cvt.f32.bf16 [[FB:%f[0-9]+]], [[B]];
+; SM80: add.rn.f32 [[FR:%f[0-9]+]], [[FA]], [[FB]];
+; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[FR]];
+; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define bfloat @test_fadd(bfloat %0, bfloat %1) {
+ %3 = fadd bfloat %0, %1
+ ret bfloat %3
+}
+
+; CHECK-LABEL: test_fsub(
+; CHECK-DAG: ld.param.b16 [[A:%rs[0-9]+]], [test_fsub_param_0];
+; CHECK-DAG: ld.param.b16 [[B:%rs[0-9]+]], [test_fsub_param_1];
+; SM90: sub.rn.bf16 [[R:%rs[0-9]+]], [[A]], [[B]];
+;
+; SM80-DAG: cvt.f32.bf16 [[FA:%f[0-9]+]], [[A]];
+; SM80-DAG: cvt.f32.bf16 [[FB:%f[0-9]+]], [[B]];
+; SM80: sub.rn.f32 [[FR:%f[0-9]+]], [[FA]], [[FB]];
+; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[FR]];
+; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define bfloat @test_fsub(bfloat %0, bfloat %1) {
+ %3 = fsub bfloat %0, %1
+ ret bfloat %3
+}
+
+; CHECK-LABEL: test_faddx2(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_faddx2_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_faddx2_param_1];
+; SM90: add.rn.bf16x2 [[R:%r[0-9]+]], [[A]], [[B]];
+
+; SM80-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
+; SM80-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]];
+; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
+; SM80-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]];
+; SM80-DAG: cvt.f32.bf16 [[FB0:%f[0-9]+]], [[B0]];
+; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
+; SM80-DAG: add.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
+; SM80-DAG: add.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
+; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]};
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+
+define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = fadd <2 x bfloat> %a, %b
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fsubx2(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_fsubx2_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_fsubx2_param_1];
+; SM90: sub.rn.bf16x2 [[R:%r[0-9]+]], [[A]], [[B]];
+
+; SM80-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
+; SM80-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]];
+; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
+; SM80-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]];
+; SM80-DAG: cvt.f32.bf16 [[FB0:%f[0-9]+]], [[B0]];
+; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
+; SM80-DAG: sub.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
+; SM80-DAG: sub.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
+; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]};
+
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+
+define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = fsub <2 x bfloat> %a, %b
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fmulx2(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_fmulx2_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_fmulx2_param_1];
+; SM90: mul.rn.bf16x2 [[R:%r[0-9]+]], [[A]], [[B]];
+
+; SM80-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
+; SM80-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]];
+; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
+; SM80-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]];
+; SM80-DAG: cvt.f32.bf16 [[FB0:%f[0-9]+]], [[B0]];
+; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
+; SM80-DAG: mul.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
+; SM80-DAG: mul.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
+; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
+; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]};
+
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+
+define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = fmul <2 x bfloat> %a, %b
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fdiv(
+; CHECK-DAG: ld.param.b32 [[A:%r[0-9]+]], [test_fdiv_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%r[0-9]+]], [test_fdiv_param_1];
+; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]]
+; CHECK-DAG: mov.b32 {[[B0:%rs[0-9]+]], [[B1:%rs[0-9]+]]}, [[B]]
+; CHECK-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]];
+; CHECK-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
+; CHECK-DAG: cvt.f32.bf16 [[FB0:%f[0-9]+]], [[B0]];
+; CHECK-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
+; CHECK-DAG: div.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
+; CHECK-DAG: div.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
+; CHECK-NEXT: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = fdiv <2 x bfloat> %a, %b
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_extract_0(
+; CHECK: ld.param.b16 [[A:%rs[0-9]+]], [test_extract_0_param_0];
+; CHECK: st.param.b16 [func_retval0+0], [[A]];
+; CHECK: ret;
+
+define bfloat @test_extract_0(<2 x bfloat> %a) #0 {
+ %e = extractelement <2 x bfloat> %a, i32 0
+ ret bfloat %e
+}
+
+; CHECK-LABEL: test_extract_1(
+; CHECK: ld.param.b16 [[A:%rs[0-9]+]], [test_extract_1_param_0+2];
+; CHECK: st.param.b16 [func_retval0+0], [[A]];
+; CHECK: ret;
+
+define bfloat @test_extract_1(<2 x bfloat> %a) #0 {
+ %e = extractelement <2 x bfloat> %a, i32 1
+ ret bfloat %e
+}
+
+; CHECK-LABEL: test_fpext_float(
+; CHECK: ld.param.b16 [[A:%rs[0-9]+]], [test_fpext_float_param_0];
+; CHECK: cvt.f32.bf16 [[R:%f[0-9]+]], [[A]];
+; CHECK: st.param.f32 [func_retval0+0], [[R]];
+; CHECK: ret;
+define float @test_fpext_float(bfloat %a) #0 {
+ %r = fpext bfloat %a to float
+ ret float %r
+}
+
+; CHECK-LABEL: test_fptrunc_float(
+; CHECK: ld.param.f32 [[A:%f[0-9]+]], [test_fptrunc_float_param_0];
+; CHECK: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[A]];
+; CHECK: st.param.b16 [func_retval0+0], [[R]];
+; CHECK: ret;
+define bfloat @test_fptrunc_float(float %a) #0 {
+ %r = fptrunc float %a to bfloat
+ ret bfloat %r
+}
+
+; CHECK-LABEL: test_fadd_imm_1(
+; CHECK: ld.param.b16 [[A:%rs[0-9]+]], [test_fadd_imm_1_param_0];
+; SM90: mov.b16 [[B:%rs[0-9]+]], 0x3F80;
+; SM90: add.rn.bf16 [[R:%rs[0-9]+]], [[A]], [[B]];
+
+; SM80-DAG: cvt.f32.bf16 [[FA:%f[0-9]+]], [[A]];
+; SM80: add.rn.f32 [[FR:%f[0-9]+]], [[FA]], 0f3F800000;
+; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[FR]];
+
+; CHECK: st.param.b16 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+define bfloat @test_fadd_imm_1(bfloat %a) #0 {
+ %r = fadd bfloat %a, 1.0
+ ret bfloat %r
+}
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
index 6aac2dd18775e..4e30cebfe9025 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm80.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
@@ -3,45 +3,45 @@
; CHECK-LABEL: cvt_rn_bf16x2_f32
-define i32 @cvt_rn_bf16x2_f32(float %f1, float %f2) {
+define <2 x bfloat> @cvt_rn_bf16x2_f32(float %f1, float %f2) {
; CHECK: cvt.rn.bf16x2.f32
- %val = call i32 @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2);
+ %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2);
-ret i32 %val
+ret <2 x bfloat> %val
}
; CHECK-LABEL: cvt_rn_relu_bf16x2_f32
-define i32 @cvt_rn_relu_bf16x2_f32(float %f1, float %f2) {
+define <2 x bfloat> @cvt_rn_relu_bf16x2_f32(float %f1, float %f2) {
; CHECK: cvt.rn.relu.bf16x2.f32
-%val = call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2);
+%val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2);
-ret i32 %val
+ret <2 x bfloat> %val
}
; CHECK-LABEL: cvt_rz_bf16x2_f32
-define i32 @cvt_rz_bf16x2_f32(float %f1, float %f2) {
+define <2 x bfloat> @cvt_rz_bf16x2_f32(float %f1, float %f2) {
; CHECK: cvt.rz.bf16x2.f32
- %val = call i32 @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2);
+ %val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2);
-ret i32 %val
+ret <2 x bfloat> %val
}
; CHECK-LABEL: cvt_rz_relu_bf16x2_f32
-define i32 @cvt_rz_relu_bf16x2_f32(float %f1, float %f2) {
+define <2 x bfloat> @cvt_rz_relu_bf16x2_f32(float %f1, float %f2) {
; CHECK: cvt.rz.relu.bf16x2.f32
-%val = call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2);
+%val = call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2);
-ret i32 %val
+ret <2 x bfloat> %val
}
-declare i32 @llvm.nvvm.ff2bf16x2.rn(float, float)
-declare i32 @llvm.nvvm.ff2bf16x2.rn.relu(float, float)
-declare i32 @llvm.nvvm.ff2bf16x2.rz(float, float)
-declare i32 @llvm.nvvm.ff2bf16x2.rz.relu(float, float)
+declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float, float)
+declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float, float)
+declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float, float)
+declare <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float, float)
; CHECK-LABEL: cvt_rn_f16x2_f32
define <2 x half> @cvt_rn_f16x2_f32(float %f1, float %f2) {
@@ -85,45 +85,45 @@ declare <2 x half> @llvm.nvvm.ff2f16x2.rz(float, float)
declare <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float, float)
; CHECK-LABEL: cvt_rn_bf16_f32
-define i16 @cvt_rn_bf16_f32(float %f1) {
+define bfloat @cvt_rn_bf16_f32(float %f1) {
; CHECK: cvt.rn.bf16.f32
- %val = call i16 @llvm.nvvm.f2bf16.rn(float %f1);
+ %val = call bfloat @llvm.nvvm.f2bf16.rn(float %f1);
-ret i16 %val
+ret bfloat %val
}
; CHECK-LABEL: cvt_rn_relu_bf16_f32
-define i16 @cvt_rn_relu_bf16_f32(float %f1) {
+define bfloat @cvt_rn_relu_bf16_f32(float %f1) {
; CHECK: cvt.rn.relu.bf16.f32
-%val = call i16 @llvm.nvvm.f2bf16.rn.relu(float %f1);
+%val = call bfloat @llvm.nvvm.f2bf16.rn.relu(float %f1);
-ret i16 %val
+ret bfloat %val
}
; CHECK-LABEL: cvt_rz_bf16_f32
-define i16 @cvt_rz_bf16_f32(float %f1) {
+define bfloat @cvt_rz_bf16_f32(float %f1) {
; CHECK: cvt.rz.bf16.f32
- %val = call i16 @llvm.nvvm.f2bf16.rz(float %f1);
+ %val = call bfloat @llvm.nvvm.f2bf16.rz(float %f1);
-ret i16 %val
+ret bfloat %val
}
; CHECK-LABEL: cvt_rz_relu_bf16_f32
-define i16 @cvt_rz_relu_bf16_f32(float %f1) {
+define bfloat @cvt_rz_relu_bf16_f32(float %f1) {
; CHECK: cvt.rz.relu.bf16.f32
-%val = call i16 @llvm.nvvm.f2bf16.rz.relu(float %f1);
+%val = call bfloat @llvm.nvvm.f2bf16.rz.relu(float %f1);
-ret i16 %val
+ret bfloat %val
}
-declare i16 @llvm.nvvm.f2bf16.rn(float)
-declare i16 @llvm.nvvm.f2bf16.rn.relu(float)
-declare i16 @llvm.nvvm.f2bf16.rz(float)
-declare i16 @llvm.nvvm.f2bf16.rz.relu(float)
+declare bfloat @llvm.nvvm.f2bf16.rn(float)
+declare bfloat @llvm.nvvm.f2bf16.rn.relu(float)
+declare bfloat @llvm.nvvm.f2bf16.rz(float)
+declare bfloat @llvm.nvvm.f2bf16.rz.relu(float)
; CHECK-LABEL: cvt_rna_tf32_f32
define i32 @cvt_rna_tf32_f32(float %f1) {
diff --git a/llvm/test/CodeGen/NVPTX/f16-instructions.ll b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
index 55fde7837487b..deea2e3b557f1 100644
--- a/llvm/test/CodeGen/NVPTX/f16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-instructions.ll
@@ -246,11 +246,11 @@ declare half @test_callee(half %a, half %b) #0
; CHECK-DAG: ld.param.b16 [[A:%rs[0-9]+]], [test_call_param_0];
; CHECK-DAG: ld.param.b16 [[B:%rs[0-9]+]], [test_call_param_1];
; CHECK: {
-; CHECK-DAG: .param .b32 param0;
-; CHECK-DAG: .param .b32 param1;
+; CHECK-DAG: .param .align 2 .b8 param0[2];
+; CHECK-DAG: .param .align 2 .b8 param1[2];
; CHECK-DAG: st.param.b16 [param0+0], [[A]];
; CHECK-DAG: st.param.b16 [param1+0], [[B]];
-; CHECK-DAG: .param .b32 retval0;
+; CHECK-DAG: .param .align 2 .b8 retval0[2];
; CHECK: call.uni (retval0),
; CHECK-NEXT: test_callee,
; CHECK: );
@@ -267,11 +267,11 @@ define half @test_call(half %a, half %b) #0 {
; CHECK-DAG: ld.param.b16 [[A:%rs[0-9]+]], [test_call_flipped_param_0];
; CHECK-DAG: ld.param.b16 [[B:%rs[0-9]+]], [test_call_flipped_param_1];
; CHECK: {
-; CHECK-DAG: .param .b32 param0;
-; CHECK-DAG: .param .b32 param1;
+; CHECK-DAG: .param .align 2 .b8 param0[2];
+; CHECK-DAG: .param .align 2 .b8 param1[2];
; CHECK-DAG: st.param.b16 [param0+0], [[B]];
; CHECK-DAG: st.param.b16 [param1+0], [[A]];
-; CHECK-DAG: .param .b32 retval0;
+; CHECK-DAG: .param .align 2 .b8 retval0[2];
; CHECK: call.uni (retval0),
; CHECK-NEXT: test_callee,
; CHECK: );
@@ -288,11 +288,11 @@ define half @test_call_flipped(half %a, half %b) #0 {
; CHECK-DAG: ld.param.b16 [[A:%rs[0-9]+]], [test_tailcall_flipped_param_0];
; CHECK-DAG: ld.param.b16 [[B:%rs[0-9]+]], [test_tailcall_flipped_param_1];
; CHECK: {
-; CHECK-DAG: .param .b32 param0;
-; CHECK-DAG: .param .b32 param1;
+; CHECK-DAG: .param .align 2 .b8 param0[2];
+; CHECK-DAG: .param .align 2 .b8 param1[2];
; CHECK-DAG: st.param.b16 [param0+0], [[B]];
; CHECK-DAG: st.param.b16 [param1+0], [[A]];
-; CHECK-DAG: .param .b32 retval0;
+; CHECK-DAG: .param .align 2 .b8 retval0[2];
; CHECK: call.uni (retval0),
; CHECK-NEXT: test_callee,
; CHECK: );
diff --git a/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70-autoupgrade.ll b/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70-autoupgrade.ll
new file mode 100644
index 0000000000000..34b9c08509326
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70-autoupgrade.ll
@@ -0,0 +1,366 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s
+; RUN: %if ptxas-11.0 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | %ptxas-verify -arch=sm_80 %}
+
+declare i16 @llvm.nvvm.abs.bf16(i16)
+declare i32 @llvm.nvvm.abs.bf16x2(i32)
+declare i16 @llvm.nvvm.neg.bf16(i16)
+declare i32 @llvm.nvvm.neg.bf16x2(i32)
+
+declare float @llvm.nvvm.fmin.nan.f(float, float)
+declare float @llvm.nvvm.fmin.ftz.nan.f(float, float)
+declare half @llvm.nvvm.fmin.f16(half, half)
+declare half @llvm.nvvm.fmin.ftz.f16(half, half)
+declare half @llvm.nvvm.fmin.nan.f16(half, half)
+declare half @llvm.nvvm.fmin.ftz.nan.f16(half, half)
+declare <2 x half> @llvm.nvvm.fmin.f16x2(<2 x half>, <2 x half>)
+declare <2 x half> @llvm.nvvm.fmin.ftz.f16x2(<2 x half>, <2 x half>)
+declare <2 x half> @llvm.nvvm.fmin.nan.f16x2(<2 x half>, <2 x half>)
+declare <2 x half> @llvm.nvvm.fmin.ftz.nan.f16x2(<2 x half>, <2 x half>)
+declare i16 @llvm.nvvm.fmin.bf16(i16, i16)
+declare i16 @llvm.nvvm.fmin.nan.bf16(i16, i16)
+declare i32 @llvm.nvvm.fmin.bf16x2(i32, i32)
+declare i32 @llvm.nvvm.fmin.nan.bf16x2(i32, i32)
+
+declare float @llvm.nvvm.fmax.nan.f(float, float)
+declare float @llvm.nvvm.fmax.ftz.nan.f(float, float)
+declare half @llvm.nvvm.fmax.f16(half, half)
+declare half @llvm.nvvm.fmax.ftz.f16(half, half)
+declare half @llvm.nvvm.fmax.nan.f16(half, half)
+declare half @llvm.nvvm.fmax.ftz.nan.f16(half, half)
+declare <2 x half> @llvm.nvvm.fmax.f16x2(<2 x half>, <2 x half>)
+declare <2 x half> @llvm.nvvm.fmax.ftz.f16x2(<2 x half>, <2 x half>)
+declare <2 x half> @llvm.nvvm.fmax.nan.f16x2(<2 x half>, <2 x half>)
+declare <2 x half> @llvm.nvvm.fmax.ftz.nan.f16x2(<2 x half>, <2 x half>)
+declare i16 @llvm.nvvm.fmax.bf16(i16, i16)
+declare i16 @llvm.nvvm.fmax.nan.bf16(i16, i16)
+declare i32 @llvm.nvvm.fmax.bf16x2(i32, i32)
+declare i32 @llvm.nvvm.fmax.nan.bf16x2(i32, i32)
+
+declare half @llvm.nvvm.fma.rn.relu.f16(half, half, half)
+declare half @llvm.nvvm.fma.rn.ftz.relu.f16(half, half, half)
+declare <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half>, <2 x half>, <2 x half>)
+declare <2 x half> @llvm.nvvm.fma.rn.ftz.relu.f16x2(<2 x half>, <2 x half>, <2 x half>)
+declare i16 @llvm.nvvm.fma.rn.bf16(i16, i16, i16)
+declare i16 @llvm.nvvm.fma.rn.relu.bf16(i16, i16, i16)
+declare i32 @llvm.nvvm.fma.rn.bf16x2(i32, i32, i32)
+declare i32 @llvm.nvvm.fma.rn.relu.bf16x2(i32, i32, i32)
+
+; CHECK-LABEL: abs_bf16
+define i16 @abs_bf16(i16 %0) {
+ ; CHECK-NOT: call
+ ; CHECK: abs.bf16
+ %res = call i16 @llvm.nvvm.abs.bf16(i16 %0);
+ ret i16 %res
+}
+
+; CHECK-LABEL: abs_bf16x2
+define i32 @abs_bf16x2(i32 %0) {
+ ; CHECK-NOT: call
+ ; CHECK: abs.bf16x2
+ %res = call i32 @llvm.nvvm.abs.bf16x2(i32 %0);
+ ret i32 %res
+}
+
+; CHECK-LABEL: neg_bf16
+define i16 @neg_bf16(i16 %0) {
+ ; CHECK-NOT: call
+ ; CHECK: neg.bf16
+ %res = call i16 @llvm.nvvm.neg.bf16(i16 %0);
+ ret i16 %res
+}
+
+; CHECK-LABEL: neg_bf16x2
+define i32 @neg_bf16x2(i32 %0) {
+ ; CHECK-NOT: call
+ ; CHECK: neg.bf16x2
+ %res = call i32 @llvm.nvvm.neg.bf16x2(i32 %0);
+ ret i32 %res
+}
+
+; CHECK-LABEL: fmin_nan_f
+define float @fmin_nan_f(float %0, float %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.NaN.f32
+ %res = call float @llvm.nvvm.fmin.nan.f(float %0, float %1);
+ ret float %res
+}
+
+; CHECK-LABEL: fmin_ftz_nan_f
+define float @fmin_ftz_nan_f(float %0, float %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.ftz.NaN.f32
+ %res = call float @llvm.nvvm.fmin.ftz.nan.f(float %0, float %1);
+ ret float %res
+}
+
+; CHECK-LABEL: fmin_f16
+define half @fmin_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.f16
+ %res = call half @llvm.nvvm.fmin.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmin_ftz_f16
+define half @fmin_ftz_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.ftz.f16
+ %res = call half @llvm.nvvm.fmin.ftz.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmin_nan_f16
+define half @fmin_nan_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.NaN.f16
+ %res = call half @llvm.nvvm.fmin.nan.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmin_ftz_nan_f16
+define half @fmin_ftz_nan_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.ftz.NaN.f16
+ %res = call half @llvm.nvvm.fmin.ftz.nan.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmin_f16x2
+define <2 x half> @fmin_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmin.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmin_ftz_f16x2
+define <2 x half> @fmin_ftz_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.ftz.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmin.ftz.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmin_nan_f16x2
+define <2 x half> @fmin_nan_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.NaN.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmin.nan.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmin_ftz_nan_f16x2
+define <2 x half> @fmin_ftz_nan_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.ftz.NaN.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmin.ftz.nan.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmin_bf16
+define i16 @fmin_bf16(i16 %0, i16 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.bf16
+ %res = call i16 @llvm.nvvm.fmin.bf16(i16 %0, i16 %1)
+ ret i16 %res
+}
+
+; CHECK-LABEL: fmin_nan_bf16
+define i16 @fmin_nan_bf16(i16 %0, i16 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.NaN.bf16
+ %res = call i16 @llvm.nvvm.fmin.nan.bf16(i16 %0, i16 %1)
+ ret i16 %res
+}
+
+; CHECK-LABEL: fmin_bf16x2
+define i32 @fmin_bf16x2(i32 %0, i32 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.bf16x2
+ %res = call i32 @llvm.nvvm.fmin.bf16x2(i32 %0, i32 %1)
+ ret i32 %res
+}
+
+; CHECK-LABEL: fmin_nan_bf16x2
+define i32 @fmin_nan_bf16x2(i32 %0, i32 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.NaN.bf16x2
+ %res = call i32 @llvm.nvvm.fmin.nan.bf16x2(i32 %0, i32 %1)
+ ret i32 %res
+}
+
+; CHECK-LABEL: fmax_nan_f
+define float @fmax_nan_f(float %0, float %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.NaN.f32
+ %res = call float @llvm.nvvm.fmax.nan.f(float %0, float %1);
+ ret float %res
+}
+
+; CHECK-LABEL: fmax_ftz_nan_f
+define float @fmax_ftz_nan_f(float %0, float %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.ftz.NaN.f32
+ %res = call float @llvm.nvvm.fmax.ftz.nan.f(float %0, float %1);
+ ret float %res
+}
+
+; CHECK-LABEL: fmax_f16
+define half @fmax_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.f16
+ %res = call half @llvm.nvvm.fmax.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmax_ftz_f16
+define half @fmax_ftz_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.ftz.f16
+ %res = call half @llvm.nvvm.fmax.ftz.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmax_nan_f16
+define half @fmax_nan_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.NaN.f16
+ %res = call half @llvm.nvvm.fmax.nan.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmax_ftz_nan_f16
+define half @fmax_ftz_nan_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.ftz.NaN.f16
+ %res = call half @llvm.nvvm.fmax.ftz.nan.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmax_f16x2
+define <2 x half> @fmax_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmax.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmax_ftz_f16x2
+define <2 x half> @fmax_ftz_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.ftz.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmax.ftz.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmax_nan_f16x2
+define <2 x half> @fmax_nan_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.NaN.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmax.nan.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmax_ftz_nan_f16x2
+define <2 x half> @fmax_ftz_nan_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.ftz.NaN.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmax.ftz.nan.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmax_bf16
+define i16 @fmax_bf16(i16 %0, i16 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.bf16
+ %res = call i16 @llvm.nvvm.fmax.bf16(i16 %0, i16 %1)
+ ret i16 %res
+}
+
+; CHECK-LABEL: fmax_nan_bf16
+define i16 @fmax_nan_bf16(i16 %0, i16 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.NaN.bf16
+ %res = call i16 @llvm.nvvm.fmax.nan.bf16(i16 %0, i16 %1)
+ ret i16 %res
+}
+
+; CHECK-LABEL: fmax_bf16x2
+define i32 @fmax_bf16x2(i32 %0, i32 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.bf16x2
+ %res = call i32 @llvm.nvvm.fmax.bf16x2(i32 %0, i32 %1)
+ ret i32 %res
+}
+
+; CHECK-LABEL: fmax_nan_bf16x2
+define i32 @fmax_nan_bf16x2(i32 %0, i32 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.NaN.bf16x2
+ %res = call i32 @llvm.nvvm.fmax.nan.bf16x2(i32 %0, i32 %1)
+ ret i32 %res
+}
+
+; CHECK-LABEL: fma_rn_relu_f16
+define half @fma_rn_relu_f16(half %0, half %1, half %2) {
+ ; CHECK-NOT: call
+ ; CHECK: fma.rn.relu.f16
+ %res = call half @llvm.nvvm.fma.rn.relu.f16(half %0, half %1, half %2)
+ ret half %res
+}
+
+; CHECK-LABEL: fma_rn_ftz_relu_f16
+define half @fma_rn_ftz_relu_f16(half %0, half %1, half %2) {
+ ; CHECK-NOT: call
+ ; CHECK: fma.rn.ftz.relu.f16
+ %res = call half @llvm.nvvm.fma.rn.ftz.relu.f16(half %0, half %1, half %2)
+ ret half %res
+}
+
+; CHECK-LABEL: fma_rn_relu_f16x2
+define <2 x half> @fma_rn_relu_f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) {
+ ; CHECK-NOT: call
+ ; CHECK: fma.rn.relu.f16x2
+ %res = call <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fma_rn_ftz_relu_f16x2
+define <2 x half> @fma_rn_ftz_relu_f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2) {
+ ; CHECK-NOT: call
+ ; CHECK: fma.rn.ftz.relu.f16x2
+ %res = call <2 x half> @llvm.nvvm.fma.rn.ftz.relu.f16x2(<2 x half> %0, <2 x half> %1, <2 x half> %2)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fma_rn_bf16
+define i16 @fma_rn_bf16(i16 %0, i16 %1, i16 %2) {
+ ; CHECK-NOT: call
+ ; CHECK: fma.rn.bf16
+ %res = call i16 @llvm.nvvm.fma.rn.bf16(i16 %0, i16 %1, i16 %2)
+ ret i16 %res
+}
+
+; CHECK-LABEL: fma_rn_relu_bf16
+define i16 @fma_rn_relu_bf16(i16 %0, i16 %1, i16 %2) {
+ ; CHECK-NOT: call
+ ; CHECK: fma.rn.relu.bf16
+ %res = call i16 @llvm.nvvm.fma.rn.relu.bf16(i16 %0, i16 %1, i16 %2)
+ ret i16 %res
+}
+
+; CHECK-LABEL: fma_rn_bf16x2
+define i32 @fma_rn_bf16x2(i32 %0, i32 %1, i32 %2) {
+ ; CHECK-NOT: call
+ ; CHECK: fma.rn.bf16x2
+ %res = call i32 @llvm.nvvm.fma.rn.bf16x2(i32 %0, i32 %1, i32 %2)
+ ret i32 %res
+}
+
+; CHECK-LABEL: fma_rn_relu_bf16x2
+define i32 @fma_rn_relu_bf16x2(i32 %0, i32 %1, i32 %2) {
+ ; CHECK-NOT: call
+ ; CHECK: fma.rn.relu.bf16x2
+ %res = call i32 @llvm.nvvm.fma.rn.relu.bf16x2(i32 %0, i32 %1, i32 %2)
+ ret i32 %res
+}
diff --git a/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70.ll b/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70.ll
index 34b9c08509326..fe05c8e5ec734 100644
--- a/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70.ll
+++ b/llvm/test/CodeGen/NVPTX/math-intrins-sm80-ptx70.ll
@@ -1,10 +1,10 @@
; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s
; RUN: %if ptxas-11.0 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | %ptxas-verify -arch=sm_80 %}
-declare i16 @llvm.nvvm.abs.bf16(i16)
-declare i32 @llvm.nvvm.abs.bf16x2(i32)
-declare i16 @llvm.nvvm.neg.bf16(i16)
-declare i32 @llvm.nvvm.neg.bf16x2(i32)
+declare bfloat @llvm.nvvm.abs.bf16(bfloat)
+declare <2 x bfloat> @llvm.nvvm.abs.bf16x2(<2 x bfloat>)
+declare bfloat @llvm.nvvm.neg.bf16(bfloat)
+declare <2 x bfloat> @llvm.nvvm.neg.bf16x2(<2 x bfloat>)
declare float @llvm.nvvm.fmin.nan.f(float, float)
declare float @llvm.nvvm.fmin.ftz.nan.f(float, float)
@@ -16,10 +16,10 @@ declare <2 x half> @llvm.nvvm.fmin.f16x2(<2 x half>, <2 x half>)
declare <2 x half> @llvm.nvvm.fmin.ftz.f16x2(<2 x half>, <2 x half>)
declare <2 x half> @llvm.nvvm.fmin.nan.f16x2(<2 x half>, <2 x half>)
declare <2 x half> @llvm.nvvm.fmin.ftz.nan.f16x2(<2 x half>, <2 x half>)
-declare i16 @llvm.nvvm.fmin.bf16(i16, i16)
-declare i16 @llvm.nvvm.fmin.nan.bf16(i16, i16)
-declare i32 @llvm.nvvm.fmin.bf16x2(i32, i32)
-declare i32 @llvm.nvvm.fmin.nan.bf16x2(i32, i32)
+declare bfloat @llvm.nvvm.fmin.bf16(bfloat, bfloat)
+declare bfloat @llvm.nvvm.fmin.nan.bf16(bfloat, bfloat)
+declare <2 x bfloat> @llvm.nvvm.fmin.bf16x2(<2 x bfloat>, <2 x bfloat>)
+declare <2 x bfloat> @llvm.nvvm.fmin.nan.bf16x2(<2 x bfloat>, <2 x bfloat>)
declare float @llvm.nvvm.fmax.nan.f(float, float)
declare float @llvm.nvvm.fmax.ftz.nan.f(float, float)
@@ -31,50 +31,50 @@ declare <2 x half> @llvm.nvvm.fmax.f16x2(<2 x half>, <2 x half>)
declare <2 x half> @llvm.nvvm.fmax.ftz.f16x2(<2 x half>, <2 x half>)
declare <2 x half> @llvm.nvvm.fmax.nan.f16x2(<2 x half>, <2 x half>)
declare <2 x half> @llvm.nvvm.fmax.ftz.nan.f16x2(<2 x half>, <2 x half>)
-declare i16 @llvm.nvvm.fmax.bf16(i16, i16)
-declare i16 @llvm.nvvm.fmax.nan.bf16(i16, i16)
-declare i32 @llvm.nvvm.fmax.bf16x2(i32, i32)
-declare i32 @llvm.nvvm.fmax.nan.bf16x2(i32, i32)
+declare bfloat @llvm.nvvm.fmax.bf16(bfloat, bfloat)
+declare bfloat @llvm.nvvm.fmax.nan.bf16(bfloat, bfloat)
+declare <2 x bfloat> @llvm.nvvm.fmax.bf16x2(<2 x bfloat>, <2 x bfloat>)
+declare <2 x bfloat> @llvm.nvvm.fmax.nan.bf16x2(<2 x bfloat>, <2 x bfloat>)
declare half @llvm.nvvm.fma.rn.relu.f16(half, half, half)
declare half @llvm.nvvm.fma.rn.ftz.relu.f16(half, half, half)
declare <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half>, <2 x half>, <2 x half>)
declare <2 x half> @llvm.nvvm.fma.rn.ftz.relu.f16x2(<2 x half>, <2 x half>, <2 x half>)
-declare i16 @llvm.nvvm.fma.rn.bf16(i16, i16, i16)
-declare i16 @llvm.nvvm.fma.rn.relu.bf16(i16, i16, i16)
-declare i32 @llvm.nvvm.fma.rn.bf16x2(i32, i32, i32)
-declare i32 @llvm.nvvm.fma.rn.relu.bf16x2(i32, i32, i32)
+declare bfloat @llvm.nvvm.fma.rn.bf16(bfloat, bfloat, bfloat)
+declare bfloat @llvm.nvvm.fma.rn.relu.bf16(bfloat, bfloat, bfloat)
+declare <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2(<2 x bfloat>, <2 x bfloat>, <2 x bfloat>)
+declare <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2(<2 x bfloat>, <2 x bfloat>, <2 x bfloat>)
; CHECK-LABEL: abs_bf16
-define i16 @abs_bf16(i16 %0) {
+define bfloat @abs_bf16(bfloat %0) {
; CHECK-NOT: call
; CHECK: abs.bf16
- %res = call i16 @llvm.nvvm.abs.bf16(i16 %0);
- ret i16 %res
+ %res = call bfloat @llvm.nvvm.abs.bf16(bfloat %0);
+ ret bfloat %res
}
; CHECK-LABEL: abs_bf16x2
-define i32 @abs_bf16x2(i32 %0) {
+define <2 x bfloat> @abs_bf16x2(<2 x bfloat> %0) {
; CHECK-NOT: call
; CHECK: abs.bf16x2
- %res = call i32 @llvm.nvvm.abs.bf16x2(i32 %0);
- ret i32 %res
+ %res = call <2 x bfloat> @llvm.nvvm.abs.bf16x2(<2 x bfloat> %0);
+ ret <2 x bfloat> %res
}
; CHECK-LABEL: neg_bf16
-define i16 @neg_bf16(i16 %0) {
+define bfloat @neg_bf16(bfloat %0) {
; CHECK-NOT: call
; CHECK: neg.bf16
- %res = call i16 @llvm.nvvm.neg.bf16(i16 %0);
- ret i16 %res
+ %res = call bfloat @llvm.nvvm.neg.bf16(bfloat %0);
+ ret bfloat %res
}
; CHECK-LABEL: neg_bf16x2
-define i32 @neg_bf16x2(i32 %0) {
+define <2 x bfloat> @neg_bf16x2(<2 x bfloat> %0) {
; CHECK-NOT: call
; CHECK: neg.bf16x2
- %res = call i32 @llvm.nvvm.neg.bf16x2(i32 %0);
- ret i32 %res
+ %res = call <2 x bfloat> @llvm.nvvm.neg.bf16x2(<2 x bfloat> %0);
+ ret <2 x bfloat> %res
}
; CHECK-LABEL: fmin_nan_f
@@ -158,35 +158,35 @@ define <2 x half> @fmin_ftz_nan_f16x2(<2 x half> %0, <2 x half> %1) {
}
; CHECK-LABEL: fmin_bf16
-define i16 @fmin_bf16(i16 %0, i16 %1) {
+define bfloat @fmin_bf16(bfloat %0, bfloat %1) {
; CHECK-NOT: call
; CHECK: min.bf16
- %res = call i16 @llvm.nvvm.fmin.bf16(i16 %0, i16 %1)
- ret i16 %res
+ %res = call bfloat @llvm.nvvm.fmin.bf16(bfloat %0, bfloat %1)
+ ret bfloat %res
}
; CHECK-LABEL: fmin_nan_bf16
-define i16 @fmin_nan_bf16(i16 %0, i16 %1) {
+define bfloat @fmin_nan_bf16(bfloat %0, bfloat %1) {
; CHECK-NOT: call
; CHECK: min.NaN.bf16
- %res = call i16 @llvm.nvvm.fmin.nan.bf16(i16 %0, i16 %1)
- ret i16 %res
+ %res = call bfloat @llvm.nvvm.fmin.nan.bf16(bfloat %0, bfloat %1)
+ ret bfloat %res
}
; CHECK-LABEL: fmin_bf16x2
-define i32 @fmin_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmin_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
; CHECK-NOT: call
; CHECK: min.bf16x2
- %res = call i32 @llvm.nvvm.fmin.bf16x2(i32 %0, i32 %1)
- ret i32 %res
+ %res = call <2 x bfloat> @llvm.nvvm.fmin.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+ ret <2 x bfloat> %res
}
; CHECK-LABEL: fmin_nan_bf16x2
-define i32 @fmin_nan_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmin_nan_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
; CHECK-NOT: call
; CHECK: min.NaN.bf16x2
- %res = call i32 @llvm.nvvm.fmin.nan.bf16x2(i32 %0, i32 %1)
- ret i32 %res
+ %res = call <2 x bfloat> @llvm.nvvm.fmin.nan.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+ ret <2 x bfloat> %res
}
; CHECK-LABEL: fmax_nan_f
@@ -270,35 +270,35 @@ define <2 x half> @fmax_ftz_nan_f16x2(<2 x half> %0, <2 x half> %1) {
}
; CHECK-LABEL: fmax_bf16
-define i16 @fmax_bf16(i16 %0, i16 %1) {
+define bfloat @fmax_bf16(bfloat %0, bfloat %1) {
; CHECK-NOT: call
; CHECK: max.bf16
- %res = call i16 @llvm.nvvm.fmax.bf16(i16 %0, i16 %1)
- ret i16 %res
+ %res = call bfloat @llvm.nvvm.fmax.bf16(bfloat %0, bfloat %1)
+ ret bfloat %res
}
; CHECK-LABEL: fmax_nan_bf16
-define i16 @fmax_nan_bf16(i16 %0, i16 %1) {
+define bfloat @fmax_nan_bf16(bfloat %0, bfloat %1) {
; CHECK-NOT: call
; CHECK: max.NaN.bf16
- %res = call i16 @llvm.nvvm.fmax.nan.bf16(i16 %0, i16 %1)
- ret i16 %res
+ %res = call bfloat @llvm.nvvm.fmax.nan.bf16(bfloat %0, bfloat %1)
+ ret bfloat %res
}
; CHECK-LABEL: fmax_bf16x2
-define i32 @fmax_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmax_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
; CHECK-NOT: call
; CHECK: max.bf16x2
- %res = call i32 @llvm.nvvm.fmax.bf16x2(i32 %0, i32 %1)
- ret i32 %res
+ %res = call <2 x bfloat> @llvm.nvvm.fmax.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+ ret <2 x bfloat> %res
}
; CHECK-LABEL: fmax_nan_bf16x2
-define i32 @fmax_nan_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmax_nan_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
; CHECK-NOT: call
; CHECK: max.NaN.bf16x2
- %res = call i32 @llvm.nvvm.fmax.nan.bf16x2(i32 %0, i32 %1)
- ret i32 %res
+ %res = call <2 x bfloat> @llvm.nvvm.fmax.nan.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+ ret <2 x bfloat> %res
}
; CHECK-LABEL: fma_rn_relu_f16
@@ -334,33 +334,33 @@ define <2 x half> @fma_rn_ftz_relu_f16x2(<2 x half> %0, <2 x half> %1, <2 x half
}
; CHECK-LABEL: fma_rn_bf16
-define i16 @fma_rn_bf16(i16 %0, i16 %1, i16 %2) {
+define bfloat @fma_rn_bf16(bfloat %0, bfloat %1, bfloat %2) {
; CHECK-NOT: call
; CHECK: fma.rn.bf16
- %res = call i16 @llvm.nvvm.fma.rn.bf16(i16 %0, i16 %1, i16 %2)
- ret i16 %res
+ %res = call bfloat @llvm.nvvm.fma.rn.bf16(bfloat %0, bfloat %1, bfloat %2)
+ ret bfloat %res
}
; CHECK-LABEL: fma_rn_relu_bf16
-define i16 @fma_rn_relu_bf16(i16 %0, i16 %1, i16 %2) {
+define bfloat @fma_rn_relu_bf16(bfloat %0, bfloat %1, bfloat %2) {
; CHECK-NOT: call
; CHECK: fma.rn.relu.bf16
- %res = call i16 @llvm.nvvm.fma.rn.relu.bf16(i16 %0, i16 %1, i16 %2)
- ret i16 %res
+ %res = call bfloat @llvm.nvvm.fma.rn.relu.bf16(bfloat %0, bfloat %1, bfloat %2)
+ ret bfloat %res
}
; CHECK-LABEL: fma_rn_bf16x2
-define i32 @fma_rn_bf16x2(i32 %0, i32 %1, i32 %2) {
+define <2 x bfloat> @fma_rn_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %2) {
; CHECK-NOT: call
; CHECK: fma.rn.bf16x2
- %res = call i32 @llvm.nvvm.fma.rn.bf16x2(i32 %0, i32 %1, i32 %2)
- ret i32 %res
+ %res = call <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %2)
+ ret <2 x bfloat> %res
}
; CHECK-LABEL: fma_rn_relu_bf16x2
-define i32 @fma_rn_relu_bf16x2(i32 %0, i32 %1, i32 %2) {
+define <2 x bfloat> @fma_rn_relu_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %2) {
; CHECK-NOT: call
; CHECK: fma.rn.relu.bf16x2
- %res = call i32 @llvm.nvvm.fma.rn.relu.bf16x2(i32 %0, i32 %1, i32 %2)
- ret i32 %res
+ %res = call <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1, <2 x bfloat> %2)
+ ret <2 x bfloat> %res
}
diff --git a/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72-autoupgrade.ll b/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72-autoupgrade.ll
new file mode 100644
index 0000000000000..b745df484bab2
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72-autoupgrade.ll
@@ -0,0 +1,292 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_86 -mattr=+ptx72 | FileCheck %s
+; RUN: %if ptxas-11.2 %{ llc < %s -march=nvptx64 -mcpu=sm_86 -mattr=+ptx72 | %ptxas-verify -arch=sm_86 %}
+
+declare half @llvm.nvvm.fmin.xorsign.abs.f16(half, half)
+declare half @llvm.nvvm.fmin.ftz.xorsign.abs.f16(half, half)
+declare half @llvm.nvvm.fmin.nan.xorsign.abs.f16(half, half)
+declare half @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16(half, half)
+declare <2 x half> @llvm.nvvm.fmin.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare <2 x half> @llvm.nvvm.fmin.ftz.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare <2 x half> @llvm.nvvm.fmin.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare <2 x half> @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare i16 @llvm.nvvm.fmin.xorsign.abs.bf16(i16, i16)
+declare i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16(i16, i16)
+declare i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2(i32, i32)
+declare i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2(i32, i32)
+declare float @llvm.nvvm.fmin.xorsign.abs.f(float, float)
+declare float @llvm.nvvm.fmin.ftz.xorsign.abs.f(float, float)
+declare float @llvm.nvvm.fmin.nan.xorsign.abs.f(float, float)
+declare float @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f(float, float)
+
+declare half @llvm.nvvm.fmax.xorsign.abs.f16(half, half)
+declare half @llvm.nvvm.fmax.ftz.xorsign.abs.f16(half, half)
+declare half @llvm.nvvm.fmax.nan.xorsign.abs.f16(half, half)
+declare half @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16(half, half)
+declare <2 x half> @llvm.nvvm.fmax.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare <2 x half> @llvm.nvvm.fmax.ftz.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare <2 x half> @llvm.nvvm.fmax.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare <2 x half> @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
+declare i16 @llvm.nvvm.fmax.xorsign.abs.bf16(i16, i16)
+declare i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16(i16, i16)
+declare i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2(i32, i32)
+declare i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2(i32, i32)
+declare float @llvm.nvvm.fmax.xorsign.abs.f(float, float)
+declare float @llvm.nvvm.fmax.ftz.xorsign.abs.f(float, float)
+declare float @llvm.nvvm.fmax.nan.xorsign.abs.f(float, float)
+declare float @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f(float, float)
+
+; CHECK-LABEL: fmin_xorsign_abs_f16
+define half @fmin_xorsign_abs_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.xorsign.abs.f16
+ %res = call half @llvm.nvvm.fmin.xorsign.abs.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmin_ftz_xorsign_abs_f16
+define half @fmin_ftz_xorsign_abs_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.ftz.xorsign.abs.f16
+ %res = call half @llvm.nvvm.fmin.ftz.xorsign.abs.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmin_nan_xorsign_abs_f16
+define half @fmin_nan_xorsign_abs_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.NaN.xorsign.abs.f16
+ %res = call half @llvm.nvvm.fmin.nan.xorsign.abs.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmin_ftz_nan_xorsign_abs_f16
+define half @fmin_ftz_nan_xorsign_abs_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.ftz.NaN.xorsign.abs.f16
+ %res = call half @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmin_xorsign_abs_f16x2
+define <2 x half> @fmin_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.xorsign.abs.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmin.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmin_ftz_xorsign_abs_f16x2
+define <2 x half> @fmin_ftz_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.ftz.xorsign.abs.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmin.ftz.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmin_nan_xorsign_abs_f16x2
+define <2 x half> @fmin_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.NaN.xorsign.abs.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmin.nan.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmin_ftz_nan_xorsign_abs_f16x2
+define <2 x half> @fmin_ftz_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.ftz.NaN.xorsign.abs.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmin_xorsign_abs_bf16
+define i16 @fmin_xorsign_abs_bf16(i16 %0, i16 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.xorsign.abs.bf16
+ %res = call i16 @llvm.nvvm.fmin.xorsign.abs.bf16(i16 %0, i16 %1)
+ ret i16 %res
+}
+
+; CHECK-LABEL: fmin_nan_xorsign_abs_bf16
+define i16 @fmin_nan_xorsign_abs_bf16(i16 %0, i16 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.NaN.xorsign.abs.bf16
+ %res = call i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16(i16 %0, i16 %1)
+ ret i16 %res
+}
+
+; CHECK-LABEL: fmin_xorsign_abs_bf16x2
+define i32 @fmin_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.xorsign.abs.bf16x2
+ %res = call i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2(i32 %0, i32 %1)
+ ret i32 %res
+}
+
+; CHECK-LABEL: fmin_nan_xorsign_abs_bf16x2
+define i32 @fmin_nan_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.NaN.xorsign.abs.bf16x2
+ %res = call i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2(i32 %0, i32 %1)
+ ret i32 %res
+}
+
+; CHECK-LABEL: fmin_xorsign_abs_f
+define float @fmin_xorsign_abs_f(float %0, float %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.xorsign.abs.f
+ %res = call float @llvm.nvvm.fmin.xorsign.abs.f(float %0, float %1)
+ ret float %res
+}
+
+; CHECK-LABEL: fmin_ftz_xorsign_abs_f
+define float @fmin_ftz_xorsign_abs_f(float %0, float %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.ftz.xorsign.abs.f
+ %res = call float @llvm.nvvm.fmin.ftz.xorsign.abs.f(float %0, float %1)
+ ret float %res
+}
+
+; CHECK-LABEL: fmin_nan_xorsign_abs_f
+define float @fmin_nan_xorsign_abs_f(float %0, float %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.NaN.xorsign.abs.f
+ %res = call float @llvm.nvvm.fmin.nan.xorsign.abs.f(float %0, float %1)
+ ret float %res
+}
+
+; CHECK-LABEL: fmin_ftz_nan_xorsign_abs_f
+define float @fmin_ftz_nan_xorsign_abs_f(float %0, float %1) {
+ ; CHECK-NOT: call
+ ; CHECK: min.ftz.NaN.xorsign.abs.f
+ %res = call float @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f(float %0, float %1)
+ ret float %res
+}
+
+; CHECK-LABEL: fmax_xorsign_abs_f16
+define half @fmax_xorsign_abs_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.xorsign.abs.f16
+ %res = call half @llvm.nvvm.fmax.xorsign.abs.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmax_ftz_xorsign_abs_f16
+define half @fmax_ftz_xorsign_abs_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.ftz.xorsign.abs.f16
+ %res = call half @llvm.nvvm.fmax.ftz.xorsign.abs.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmax_nan_xorsign_abs_f16
+define half @fmax_nan_xorsign_abs_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.NaN.xorsign.abs.f16
+ %res = call half @llvm.nvvm.fmax.nan.xorsign.abs.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmax_ftz_nan_xorsign_abs_f16
+define half @fmax_ftz_nan_xorsign_abs_f16(half %0, half %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.ftz.NaN.xorsign.abs.f16
+ %res = call half @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16(half %0, half %1)
+ ret half %res
+}
+
+; CHECK-LABEL: fmax_xorsign_abs_f16x2
+define <2 x half> @fmax_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.xorsign.abs.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmax.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmax_ftz_xorsign_abs_f16x2
+define <2 x half> @fmax_ftz_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.ftz.xorsign.abs.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmax.ftz.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmax_nan_xorsign_abs_f16x2
+define <2 x half> @fmax_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.NaN.xorsign.abs.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmax.nan.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmax_ftz_nan_xorsign_abs_f16x2
+define <2 x half> @fmax_ftz_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.ftz.NaN.xorsign.abs.f16x2
+ %res = call <2 x half> @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16x2(<2 x half> %0, <2 x half> %1)
+ ret <2 x half> %res
+}
+
+; CHECK-LABEL: fmax_xorsign_abs_bf16
+define i16 @fmax_xorsign_abs_bf16(i16 %0, i16 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.xorsign.abs.bf16
+ %res = call i16 @llvm.nvvm.fmax.xorsign.abs.bf16(i16 %0, i16 %1)
+ ret i16 %res
+}
+
+; CHECK-LABEL: fmax_nan_xorsign_abs_bf16
+define i16 @fmax_nan_xorsign_abs_bf16(i16 %0, i16 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.NaN.xorsign.abs.bf16
+ %res = call i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16(i16 %0, i16 %1)
+ ret i16 %res
+}
+
+; CHECK-LABEL: fmax_xorsign_abs_bf16x2
+define i32 @fmax_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.xorsign.abs.bf16x2
+ %res = call i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2(i32 %0, i32 %1)
+ ret i32 %res
+}
+
+; CHECK-LABEL: fmax_nan_xorsign_abs_bf16x2
+define i32 @fmax_nan_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.NaN.xorsign.abs.bf16x2
+ %res = call i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2(i32 %0, i32 %1)
+ ret i32 %res
+}
+
+; CHECK-LABEL: fmax_xorsign_abs_f
+define float @fmax_xorsign_abs_f(float %0, float %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.xorsign.abs.f
+ %res = call float @llvm.nvvm.fmax.xorsign.abs.f(float %0, float %1)
+ ret float %res
+}
+
+; CHECK-LABEL: fmax_ftz_xorsign_abs_f
+define float @fmax_ftz_xorsign_abs_f(float %0, float %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.ftz.xorsign.abs.f
+ %res = call float @llvm.nvvm.fmax.ftz.xorsign.abs.f(float %0, float %1)
+ ret float %res
+}
+
+; CHECK-LABEL: fmax_nan_xorsign_abs_f
+define float @fmax_nan_xorsign_abs_f(float %0, float %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.NaN.xorsign.abs.f
+ %res = call float @llvm.nvvm.fmax.nan.xorsign.abs.f(float %0, float %1)
+ ret float %res
+}
+
+; CHECK-LABEL: fmax_ftz_nan_xorsign_abs_f
+define float @fmax_ftz_nan_xorsign_abs_f(float %0, float %1) {
+ ; CHECK-NOT: call
+ ; CHECK: max.ftz.NaN.xorsign.abs.f
+ %res = call float @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f(float %0, float %1)
+ ret float %res
+}
diff --git a/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72.ll b/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72.ll
index b745df484bab2..6d430b052d8fe 100644
--- a/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72.ll
+++ b/llvm/test/CodeGen/NVPTX/math-intrins-sm86-ptx72.ll
@@ -9,10 +9,10 @@ declare <2 x half> @llvm.nvvm.fmin.xorsign.abs.f16x2(<2 x half> , <2 x half>)
declare <2 x half> @llvm.nvvm.fmin.ftz.xorsign.abs.f16x2(<2 x half> , <2 x half>)
declare <2 x half> @llvm.nvvm.fmin.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
declare <2 x half> @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
-declare i16 @llvm.nvvm.fmin.xorsign.abs.bf16(i16, i16)
-declare i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16(i16, i16)
-declare i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2(i32, i32)
-declare i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2(i32, i32)
+declare bfloat @llvm.nvvm.fmin.xorsign.abs.bf16(bfloat, bfloat) ; bf16 intrinsics now take native bfloat (was i16); bf16x2 forms take <2 x bfloat> (was i32)
+declare bfloat @llvm.nvvm.fmin.nan.xorsign.abs.bf16(bfloat, bfloat)
+declare <2 x bfloat> @llvm.nvvm.fmin.xorsign.abs.bf16x2(<2 x bfloat>, <2 x bfloat>)
+declare <2 x bfloat> @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2(<2 x bfloat>, <2 x bfloat>)
declare float @llvm.nvvm.fmin.xorsign.abs.f(float, float)
declare float @llvm.nvvm.fmin.ftz.xorsign.abs.f(float, float)
declare float @llvm.nvvm.fmin.nan.xorsign.abs.f(float, float)
@@ -26,10 +26,10 @@ declare <2 x half> @llvm.nvvm.fmax.xorsign.abs.f16x2(<2 x half> , <2 x half>)
declare <2 x half> @llvm.nvvm.fmax.ftz.xorsign.abs.f16x2(<2 x half> , <2 x half>)
declare <2 x half> @llvm.nvvm.fmax.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
declare <2 x half> @llvm.nvvm.fmax.ftz.nan.xorsign.abs.f16x2(<2 x half> , <2 x half>)
-declare i16 @llvm.nvvm.fmax.xorsign.abs.bf16(i16, i16)
-declare i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16(i16, i16)
-declare i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2(i32, i32)
-declare i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2(i32, i32)
+declare bfloat @llvm.nvvm.fmax.xorsign.abs.bf16(bfloat, bfloat) ; same bfloat / <2 x bfloat> switch as the fmin declarations
+declare bfloat @llvm.nvvm.fmax.nan.xorsign.abs.bf16(bfloat, bfloat)
+declare <2 x bfloat> @llvm.nvvm.fmax.xorsign.abs.bf16x2(<2 x bfloat>, <2 x bfloat>)
+declare <2 x bfloat> @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2(<2 x bfloat>, <2 x bfloat>)
declare float @llvm.nvvm.fmax.xorsign.abs.f(float, float)
declare float @llvm.nvvm.fmax.ftz.xorsign.abs.f(float, float)
declare float @llvm.nvvm.fmax.nan.xorsign.abs.f(float, float)
@@ -100,35 +100,35 @@ define <2 x half> @fmin_ftz_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1)
}
; CHECK-LABEL: fmin_xorsign_abs_bf16
-define i16 @fmin_xorsign_abs_bf16(i16 %0, i16 %1) {
+define bfloat @fmin_xorsign_abs_bf16(bfloat %0, bfloat %1) { ; only IR types change here; the expected PTX instructions are unchanged
; CHECK-NOT: call
; CHECK: min.xorsign.abs.bf16
- %res = call i16 @llvm.nvvm.fmin.xorsign.abs.bf16(i16 %0, i16 %1)
- ret i16 %res
+ %res = call bfloat @llvm.nvvm.fmin.xorsign.abs.bf16(bfloat %0, bfloat %1)
+ ret bfloat %res
}
; CHECK-LABEL: fmin_nan_xorsign_abs_bf16
-define i16 @fmin_nan_xorsign_abs_bf16(i16 %0, i16 %1) {
+define bfloat @fmin_nan_xorsign_abs_bf16(bfloat %0, bfloat %1) {
; CHECK-NOT: call
; CHECK: min.NaN.xorsign.abs.bf16
- %res = call i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16(i16 %0, i16 %1)
- ret i16 %res
+ %res = call bfloat @llvm.nvvm.fmin.nan.xorsign.abs.bf16(bfloat %0, bfloat %1)
+ ret bfloat %res
}
; CHECK-LABEL: fmin_xorsign_abs_bf16x2
-define i32 @fmin_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmin_xorsign_abs_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
; CHECK-NOT: call
; CHECK: min.xorsign.abs.bf16x2
- %res = call i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2(i32 %0, i32 %1)
- ret i32 %res
+ %res = call <2 x bfloat> @llvm.nvvm.fmin.xorsign.abs.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+ ret <2 x bfloat> %res
}
; CHECK-LABEL: fmin_nan_xorsign_abs_bf16x2
-define i32 @fmin_nan_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmin_nan_xorsign_abs_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
; CHECK-NOT: call
; CHECK: min.NaN.xorsign.abs.bf16x2
- %res = call i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2(i32 %0, i32 %1)
- ret i32 %res
+ %res = call <2 x bfloat> @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+ ret <2 x bfloat> %res
}
; CHECK-LABEL: fmin_xorsign_abs_f
@@ -228,35 +228,35 @@ define <2 x half> @fmax_ftz_nan_xorsign_abs_f16x2(<2 x half> %0, <2 x half> %1)
}
; CHECK-LABEL: fmax_xorsign_abs_bf16
-define i16 @fmax_xorsign_abs_bf16(i16 %0, i16 %1) {
+define bfloat @fmax_xorsign_abs_bf16(bfloat %0, bfloat %1) { ; only IR types change here; the expected PTX instructions are unchanged
; CHECK-NOT: call
; CHECK: max.xorsign.abs.bf16
- %res = call i16 @llvm.nvvm.fmax.xorsign.abs.bf16(i16 %0, i16 %1)
- ret i16 %res
+ %res = call bfloat @llvm.nvvm.fmax.xorsign.abs.bf16(bfloat %0, bfloat %1)
+ ret bfloat %res
}
; CHECK-LABEL: fmax_nan_xorsign_abs_bf16
-define i16 @fmax_nan_xorsign_abs_bf16(i16 %0, i16 %1) {
+define bfloat @fmax_nan_xorsign_abs_bf16(bfloat %0, bfloat %1) {
; CHECK-NOT: call
; CHECK: max.NaN.xorsign.abs.bf16
- %res = call i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16(i16 %0, i16 %1)
- ret i16 %res
+ %res = call bfloat @llvm.nvvm.fmax.nan.xorsign.abs.bf16(bfloat %0, bfloat %1)
+ ret bfloat %res
}
; CHECK-LABEL: fmax_xorsign_abs_bf16x2
-define i32 @fmax_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmax_xorsign_abs_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
; CHECK-NOT: call
; CHECK: max.xorsign.abs.bf16x2
- %res = call i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2(i32 %0, i32 %1)
- ret i32 %res
+ %res = call <2 x bfloat> @llvm.nvvm.fmax.xorsign.abs.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+ ret <2 x bfloat> %res
}
; CHECK-LABEL: fmax_nan_xorsign_abs_bf16x2
-define i32 @fmax_nan_xorsign_abs_bf16x2(i32 %0, i32 %1) {
+define <2 x bfloat> @fmax_nan_xorsign_abs_bf16x2(<2 x bfloat> %0, <2 x bfloat> %1) {
; CHECK-NOT: call
; CHECK: max.NaN.xorsign.abs.bf16x2
- %res = call i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2(i32 %0, i32 %1)
- ret i32 %res
+ %res = call <2 x bfloat> @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2(<2 x bfloat> %0, <2 x bfloat> %1)
+ ret <2 x bfloat> %res
}
; CHECK-LABEL: fmax_xorsign_abs_f
diff --git a/llvm/test/CodeGen/NVPTX/param-load-store.ll b/llvm/test/CodeGen/NVPTX/param-load-store.ll
index b05fbaea17087..313a0915d2030 100644
--- a/llvm/test/CodeGen/NVPTX/param-load-store.ll
+++ b/llvm/test/CodeGen/NVPTX/param-load-store.ll
@@ -381,13 +381,13 @@ define <5 x i16> @test_v5i16(<5 x i16> %a) {
ret <5 x i16> %r;
}
-; CHECK: .func (.param .b32 func_retval0)
+; CHECK: .func (.param .align 2 .b8 func_retval0[2])
; CHECK-LABEL: test_f16(
-; CHECK-NEXT: .param .b32 test_f16_param_0
+; CHECK-NEXT: .param .align 2 .b8 test_f16_param_0[2]
; CHECK: ld.param.b16 [[E:%rs[0-9]+]], [test_f16_param_0];
-; CHECK: .param .b32 param0;
+; CHECK: .param .align 2 .b8 param0[2];
; CHECK: st.param.b16 [param0+0], [[E]];
-; CHECK: .param .b32 retval0;
+; CHECK: .param .align 2 .b8 retval0[2];
; CHECK: call.uni (retval0),
; CHECK-NEXT: test_f16,
; CHECK: ld.param.b16 [[R:%rs[0-9]+]], [retval0+0];
@@ -415,6 +415,41 @@ define <2 x half> @test_v2f16(<2 x half> %a) {
ret <2 x half> %r;
}
+; CHECK: .func (.param .align 2 .b8 func_retval0[2])
+; CHECK-LABEL: test_bf16(
+; CHECK-NEXT: .param .align 2 .b8 test_bf16_param_0[2]
+; CHECK: ld.param.b16 [[E:%rs[0-9]+]], [test_bf16_param_0];
+; CHECK: .param .align 2 .b8 param0[2];
+; CHECK: st.param.b16 [param0+0], [[E]];
+; CHECK: .param .align 2 .b8 retval0[2];
+; CHECK: call.uni (retval0),
+; CHECK-NEXT: test_bf16,
+; CHECK: ld.param.b16 [[R:%rs[0-9]+]], [retval0+0];
+; CHECK: st.param.b16 [func_retval0+0], [[R]]
+; CHECK-NEXT: ret;
+define bfloat @test_bf16(bfloat %a) { ; self-call checks callee side (test_bf16_param_0/func_retval0) and caller side (param0/retval0): .align 2 .b8 [2] params moved with b16 ld/st
+ %r = tail call bfloat @test_bf16(bfloat %a);
+ ret bfloat %r;
+}
+
+; CHECK: .func (.param .align 4 .b8 func_retval0[4])
+; CHECK-LABEL: test_v2bf16(
+; CHECK-NEXT: .param .align 4 .b8 test_v2bf16_param_0[4]
+; CHECK: ld.param.b32 [[E:%r[0-9]+]], [test_v2bf16_param_0];
+; CHECK: .param .align 4 .b8 param0[4];
+; CHECK: st.param.b32 [param0+0], [[E]];
+; CHECK: .param .align 4 .b8 retval0[4];
+; CHECK: call.uni (retval0),
+; CHECK-NEXT: test_v2bf16,
+; CHECK: ld.param.b32 [[R:%r[0-9]+]], [retval0+0];
+; CHECK: st.param.b32 [func_retval0+0], [[R]]
+; CHECK-NEXT: ret;
+define <2 x bfloat> @test_v2bf16(<2 x bfloat> %a) { ; both bf16 lanes travel as one .align 4 .b8 [4] param moved with a single b32 ld/st
+ %r = tail call <2 x bfloat> @test_v2bf16(<2 x bfloat> %a);
+ ret <2 x bfloat> %r;
+}
+
+
; CHECK:.func (.param .align 8 .b8 func_retval0[8])
; CHECK-LABEL: test_v3f16(
; CHECK: .param .align 8 .b8 test_v3f16_param_0[8]
More information about the cfe-commits mailing list