[clang] 3644726 - [Clang][NVPTX] Add NVPTX intrinsics and builtins for CUDA PTX 6.5 and 7.0 WMMA and MMA instructions
Artem Belevich via cfe-commits
cfe-commits at lists.llvm.org
Tue Jun 29 15:44:58 PDT 2021
Author: Steffen Larsen
Date: 2021-06-29T15:44:07-07:00
New Revision: 3644726a78e37823b1687a7aa8d186e91570ffe2
URL: https://github.com/llvm/llvm-project/commit/3644726a78e37823b1687a7aa8d186e91570ffe2
DIFF: https://github.com/llvm/llvm-project/commit/3644726a78e37823b1687a7aa8d186e91570ffe2.diff
LOG: [Clang][NVPTX] Add NVPTX intrinsics and builtins for CUDA PTX 6.5 and 7.0 WMMA and MMA instructions
Adds NVPTX builtins and intrinsics for the CUDA PTX `wmma.load`, `wmma.store`, `wmma.mma`, and `mma` instructions introduced in PTX ISA 6.5 and 7.0.
PTX ISA descriptions of the new instructions:
- `wmma.load`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-ld
- `wmma.store`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-st
- `wmma.mma`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-wmma-mma
- `mma`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma
Overview of `wmma.mma` and `mma` matrix shape/type combinations added with specific PTX versions: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape
Authored-by: Steffen Larsen <steffen.larsen at codeplay.com>
Co-Authored-by: Stuart Adams <stuart.adams at codeplay.com>
Reviewed By: tra
Differential Revision: https://reviews.llvm.org/D104847
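For orientation, a minimal device-side sketch of how the new double-precision builtins chain together (illustrative only, not part of the patch; real code is normally written against CUDA's <mma.h> fragment API, which lowers to these builtins, and must be compiled for sm_80 with PTX 7.0):

```
// Hypothetical example: D = A * B + C for one 8x8x4 f64 tile.
// Fragment sizes (1 element for A/B, 2 for C/D) and the layout encodings
// (loads/stores: 0 = row, 1 = col; mma: 0 = row.row ... 3 = col.col) follow
// the tables and tests added in this patch.
__device__ void dmma_m8n8k4_example(double *d, const double *a,
                                    const double *b, const double *c,
                                    int ldm) {
  double a_frag[1], b_frag[1], c_frag[2], d_frag[2];
  __dmma_m8n8k4_ld_a(a_frag, a, ldm, 0);                       // row-major A
  __dmma_m8n8k4_ld_b(b_frag, b, ldm, 0);                       // row-major B
  __dmma_m8n8k4_ld_c(c_frag, c, ldm, 0);                       // f64 accumulator
  __dmma_m8n8k4_mma_f64(d_frag, a_frag, b_frag, c_frag, 0, 0); // row.row, no satf
  __dmma_m8n8k4_st_c_f64(d, d_frag, ldm, 0);                   // store row-major
}
```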
Added:
Modified:
clang/include/clang/Basic/BuiltinsNVPTX.def
clang/lib/CodeGen/CGBuiltin.cpp
clang/test/CodeGen/builtins-nvptx-mma.cu
clang/test/CodeGen/builtins-nvptx-mma.py
llvm/include/llvm/IR/IntrinsicsNVVM.td
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
llvm/test/CodeGen/NVPTX/lit.local.cfg
llvm/test/CodeGen/NVPTX/wmma.py
Removed:
################################################################################
diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def
index 98f3c659b7cec..e815138a15c15 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.def
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -759,6 +759,29 @@ TARGET_BUILTIN(__imma_m8n8k32_mma_s4, "vi*iC*iC*iC*IiIi", "", AND(SM_75,PTX63))
TARGET_BUILTIN(__imma_m8n8k32_mma_u4, "vi*iC*iC*iC*IiIi", "", AND(SM_75,PTX63))
TARGET_BUILTIN(__imma_m8n8k32_st_c_i32, "vi*iC*UiIi", "", AND(SM_75,PTX63))
+// Builtins to support double and alternate float WMMA instructions on sm_80
+TARGET_BUILTIN(__dmma_m8n8k4_ld_a, "vd*dC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__dmma_m8n8k4_ld_b, "vd*dC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__dmma_m8n8k4_ld_c, "vd*dC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__dmma_m8n8k4_st_c_f64, "vd*dC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__dmma_m8n8k4_mma_f64, "vd*dC*dC*dC*IiIi", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__mma_bf16_m16n16k16_ld_a, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m16n16k16_ld_b, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m16n16k16_mma_f32, "vf*iC*iC*fC*IiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m8n32k16_ld_a, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m8n32k16_ld_b, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m8n32k16_mma_f32, "vf*iC*iC*fC*IiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m32n8k16_ld_a, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m32n8k16_ld_b, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_bf16_m32n8k16_mma_f32, "vf*iC*iC*fC*IiIi", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__mma_tf32_m16n16k8_ld_a, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_tf32_m16n16k8_ld_b, "vi*iC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_tf32_m16n16k8_ld_c, "vf*fC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_m16n16k8_st_c_f32, "vf*fC*UiIi", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__mma_tf32_m16n16k8_mma_f32, "vf*iC*iC*fC*IiIi", "", AND(SM_80,PTX70))
+
// Async Copy
TARGET_BUILTIN(__nvvm_cp_async_mbarrier_arrive, "vWi*", "", AND(SM_80,PTX70))
TARGET_BUILTIN(__nvvm_cp_async_mbarrier_arrive_shared, "vWi*3", "", AND(SM_80,PTX70))
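As a reading aid, the prototype strings above use Clang's Builtins.def encoding: `v` = void, `d` = double, `f` = float, `i` = int, `*` = pointer, `C` = const, `U` = unsigned, and the `I` prefix marks an argument that must be an integer constant expression. A few of the new builtins therefore correspond roughly to the following C prototypes (an interpretation of the encoding, not declarations from the patch):

```
// Approximate expansions of the prototype strings above.
void __dmma_m8n8k4_ld_a(double *dst, const double *src, unsigned ldm,
                        int layout);                        // "vd*dC*UiIi"
void __dmma_m8n8k4_mma_f64(double *dst, const double *a, const double *b,
                           const double *c, int layout,
                           int satf);                       // "vd*dC*dC*dC*IiIi"
// bf16/tf32 fragments travel as packed 32-bit words, hence the int pointers:
void __mma_bf16_m16n16k16_ld_a(int *dst, const int *src, unsigned ldm,
                               int layout);                 // "vi*iC*UiIi"
void __mma_tf32_m16n16k8_mma_f32(float *dst, const int *a, const int *b,
                                 const float *c, int layout,
                                 int satf);                 // "vf*iC*iC*fC*IiIi"
```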
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 2f2d5e6c83d77..3fc9ba414397e 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -16402,6 +16402,34 @@ static NVPTXMmaLdstInfo getNVPTXMmaLdstInfo(unsigned BuiltinID) {
case NVPTX::BI__bmma_m8n8k128_ld_c:
return MMA_LDST(2, m8n8k128_load_c_s32);
+ // Double MMA loads
+ case NVPTX::BI__dmma_m8n8k4_ld_a:
+ return MMA_LDST(1, m8n8k4_load_a_f64);
+ case NVPTX::BI__dmma_m8n8k4_ld_b:
+ return MMA_LDST(1, m8n8k4_load_b_f64);
+ case NVPTX::BI__dmma_m8n8k4_ld_c:
+ return MMA_LDST(2, m8n8k4_load_c_f64);
+
+ // Alternate float MMA loads
+ case NVPTX::BI__mma_bf16_m16n16k16_ld_a:
+ return MMA_LDST(4, m16n16k16_load_a_bf16);
+ case NVPTX::BI__mma_bf16_m16n16k16_ld_b:
+ return MMA_LDST(4, m16n16k16_load_b_bf16);
+ case NVPTX::BI__mma_bf16_m8n32k16_ld_a:
+ return MMA_LDST(2, m8n32k16_load_a_bf16);
+ case NVPTX::BI__mma_bf16_m8n32k16_ld_b:
+ return MMA_LDST(8, m8n32k16_load_b_bf16);
+ case NVPTX::BI__mma_bf16_m32n8k16_ld_a:
+ return MMA_LDST(8, m32n8k16_load_a_bf16);
+ case NVPTX::BI__mma_bf16_m32n8k16_ld_b:
+ return MMA_LDST(2, m32n8k16_load_b_bf16);
+ case NVPTX::BI__mma_tf32_m16n16k8_ld_a:
+ return MMA_LDST(4, m16n16k8_load_a_tf32);
+ case NVPTX::BI__mma_tf32_m16n16k8_ld_b:
+ return MMA_LDST(2, m16n16k8_load_b_tf32);
+ case NVPTX::BI__mma_tf32_m16n16k8_ld_c:
+ return MMA_LDST(8, m16n16k8_load_c_f32);
+
// NOTE: We need to follow inconsitent naming scheme used by NVCC. Unlike
// PTX and LLVM IR where stores always use fragment D, NVCC builtins always
// use fragment C for both loads and stores.
@@ -16433,6 +16461,14 @@ static NVPTXMmaLdstInfo getNVPTXMmaLdstInfo(unsigned BuiltinID) {
case NVPTX::BI__bmma_m8n8k128_st_c_i32:
return MMA_LDST(2, m8n8k128_store_d_s32);
+ // Double MMA store
+ case NVPTX::BI__dmma_m8n8k4_st_c_f64:
+ return MMA_LDST(2, m8n8k4_store_d_f64);
+
+ // Alternate float MMA store
+ case NVPTX::BI__mma_m16n16k8_st_c_f32:
+ return MMA_LDST(8, m16n16k8_store_d_f32);
+
default:
llvm_unreachable("Unknown MMA builtin");
}
@@ -16446,10 +16482,14 @@ struct NVPTXMmaInfo {
unsigned NumEltsB;
unsigned NumEltsC;
unsigned NumEltsD;
+
+ // Variants are ordered by layout-A/layout-B/satf, where 'row' has priority
+ // over 'col' for layout. The index of non-satf variants is expected to match
+ // the undocumented layout constants used by CUDA's mma.hpp.
std::array<unsigned, 8> Variants;
unsigned getMMAIntrinsic(int Layout, bool Satf) {
- unsigned Index = Layout * 2 + Satf;
+ unsigned Index = Layout + 4 * Satf;
if (Index >= Variants.size())
return 0;
return Variants[Index];
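To make the new indexing concrete, a small sketch of the lookup it implies (assuming the 0-3 layout encoding used by the builtins' layout argument; the real table is the std::array<unsigned, 8> filled in by the macros below):

```
#include <array>

// Slots 0-3 hold row.row, row.col, col.row, col.col; slots 4-7 hold their
// .satfinite counterparts, or 0 where a combination is unsupported.
unsigned pickVariant(const std::array<unsigned, 8> &Variants, int Layout,
                     bool Satf) {
  unsigned Index = Layout + 4 * Satf; // 0=row.row 1=row.col 2=col.row 3=col.col
  return Index < Variants.size() ? Variants[Index] : 0; // 0 => unsupported
}
```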
@@ -16460,93 +16500,107 @@ struct NVPTXMmaInfo {
// Layout and Satf, 0 otherwise.
static NVPTXMmaInfo getNVPTXMmaInfo(unsigned BuiltinID) {
// clang-format off
-#define MMA_VARIANTS(geom, type) {{ \
+#define MMA_VARIANTS(geom, type) \
Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type, \
- Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \
Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \
- Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type
+#define MMA_SATF_VARIANTS(geom, type) \
+ MMA_VARIANTS(geom, type), \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type##_satfinite, \
- Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type, \
- Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite \
- }}
+ Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite
// Sub-integer MMA only supports row.col layout.
-#define MMA_VARIANTS_I4(geom, type) {{ \
- 0, \
+#define MMA_VARIANTS_I4(geom, type) \
0, \
Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \
- Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
0, \
0, \
0, \
- 0 \
- }}
-// b1 MMA does not support .satfinite.
-#define MMA_VARIANTS_B1(geom, type) {{ \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
0, \
+ 0
+// b1 MMA does not support .satfinite.
+#define MMA_VARIANTS_B1(geom, type) \
0, \
Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \
0, \
0, \
0, \
0, \
- 0 \
- }}
- // clang-format on
- switch (BuiltinID) {
- // FP MMA
- // Note that 'type' argument of MMA_VARIANT uses D_C notation, while
- // NumEltsN of return value are ordered as A,B,C,D.
- case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
- return {8, 8, 4, 4, MMA_VARIANTS(m16n16k16, f16_f16)};
- case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
- return {8, 8, 4, 8, MMA_VARIANTS(m16n16k16, f32_f16)};
- case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
- return {8, 8, 8, 4, MMA_VARIANTS(m16n16k16, f16_f32)};
- case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
- return {8, 8, 8, 8, MMA_VARIANTS(m16n16k16, f32_f32)};
- case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
- return {8, 8, 4, 4, MMA_VARIANTS(m32n8k16, f16_f16)};
- case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
- return {8, 8, 4, 8, MMA_VARIANTS(m32n8k16, f32_f16)};
- case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
- return {8, 8, 8, 4, MMA_VARIANTS(m32n8k16, f16_f32)};
- case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
- return {8, 8, 8, 8, MMA_VARIANTS(m32n8k16, f32_f32)};
- case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
- return {8, 8, 4, 4, MMA_VARIANTS(m8n32k16, f16_f16)};
- case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
- return {8, 8, 4, 8, MMA_VARIANTS(m8n32k16, f32_f16)};
- case NVPTX::BI__hmma_m8n32k16_mma_f16f32:
- return {8, 8, 8, 4, MMA_VARIANTS(m8n32k16, f16_f32)};
- case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
- return {8, 8, 8, 8, MMA_VARIANTS(m8n32k16, f32_f32)};
-
- // Integer MMA
- case NVPTX::BI__imma_m16n16k16_mma_s8:
- return {2, 2, 8, 8, MMA_VARIANTS(m16n16k16, s8)};
- case NVPTX::BI__imma_m16n16k16_mma_u8:
- return {2, 2, 8, 8, MMA_VARIANTS(m16n16k16, u8)};
- case NVPTX::BI__imma_m32n8k16_mma_s8:
- return {4, 1, 8, 8, MMA_VARIANTS(m32n8k16, s8)};
- case NVPTX::BI__imma_m32n8k16_mma_u8:
- return {4, 1, 8, 8, MMA_VARIANTS(m32n8k16, u8)};
- case NVPTX::BI__imma_m8n32k16_mma_s8:
- return {1, 4, 8, 8, MMA_VARIANTS(m8n32k16, s8)};
- case NVPTX::BI__imma_m8n32k16_mma_u8:
- return {1, 4, 8, 8, MMA_VARIANTS(m8n32k16, u8)};
-
- // Sub-integer MMA
- case NVPTX::BI__imma_m8n8k32_mma_s4:
- return {1, 1, 2, 2, MMA_VARIANTS_I4(m8n8k32, s4)};
- case NVPTX::BI__imma_m8n8k32_mma_u4:
- return {1, 1, 2, 2, MMA_VARIANTS_I4(m8n8k32, u4)};
- case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
- return {1, 1, 2, 2, MMA_VARIANTS_B1(m8n8k128, b1)};
- default:
- llvm_unreachable("Unexpected builtin ID.");
- }
+ 0, \
+ 0
+ // clang-format on
+ switch (BuiltinID) {
+ // FP MMA
+ // Note that 'type' argument of MMA_SATF_VARIANTS uses D_C notation, while
+ // NumEltsN of return value are ordered as A,B,C,D.
+ case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
+ return {8, 8, 4, 4, {{MMA_SATF_VARIANTS(m16n16k16, f16_f16)}}};
+ case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
+ return {8, 8, 4, 8, {{MMA_SATF_VARIANTS(m16n16k16, f32_f16)}}};
+ case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
+ return {8, 8, 8, 4, {{MMA_SATF_VARIANTS(m16n16k16, f16_f32)}}};
+ case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
+ return {8, 8, 8, 8, {{MMA_SATF_VARIANTS(m16n16k16, f32_f32)}}};
+ case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
+ return {8, 8, 4, 4, {{MMA_SATF_VARIANTS(m32n8k16, f16_f16)}}};
+ case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
+ return {8, 8, 4, 8, {{MMA_SATF_VARIANTS(m32n8k16, f32_f16)}}};
+ case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
+ return {8, 8, 8, 4, {{MMA_SATF_VARIANTS(m32n8k16, f16_f32)}}};
+ case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
+ return {8, 8, 8, 8, {{MMA_SATF_VARIANTS(m32n8k16, f32_f32)}}};
+ case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
+ return {8, 8, 4, 4, {{MMA_SATF_VARIANTS(m8n32k16, f16_f16)}}};
+ case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
+ return {8, 8, 4, 8, {{MMA_SATF_VARIANTS(m8n32k16, f32_f16)}}};
+ case NVPTX::BI__hmma_m8n32k16_mma_f16f32:
+ return {8, 8, 8, 4, {{MMA_SATF_VARIANTS(m8n32k16, f16_f32)}}};
+ case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
+ return {8, 8, 8, 8, {{MMA_SATF_VARIANTS(m8n32k16, f32_f32)}}};
+
+ // Integer MMA
+ case NVPTX::BI__imma_m16n16k16_mma_s8:
+ return {2, 2, 8, 8, {{MMA_SATF_VARIANTS(m16n16k16, s8)}}};
+ case NVPTX::BI__imma_m16n16k16_mma_u8:
+ return {2, 2, 8, 8, {{MMA_SATF_VARIANTS(m16n16k16, u8)}}};
+ case NVPTX::BI__imma_m32n8k16_mma_s8:
+ return {4, 1, 8, 8, {{MMA_SATF_VARIANTS(m32n8k16, s8)}}};
+ case NVPTX::BI__imma_m32n8k16_mma_u8:
+ return {4, 1, 8, 8, {{MMA_SATF_VARIANTS(m32n8k16, u8)}}};
+ case NVPTX::BI__imma_m8n32k16_mma_s8:
+ return {1, 4, 8, 8, {{MMA_SATF_VARIANTS(m8n32k16, s8)}}};
+ case NVPTX::BI__imma_m8n32k16_mma_u8:
+ return {1, 4, 8, 8, {{MMA_SATF_VARIANTS(m8n32k16, u8)}}};
+
+ // Sub-integer MMA
+ case NVPTX::BI__imma_m8n8k32_mma_s4:
+ return {1, 1, 2, 2, {{MMA_VARIANTS_I4(m8n8k32, s4)}}};
+ case NVPTX::BI__imma_m8n8k32_mma_u4:
+ return {1, 1, 2, 2, {{MMA_VARIANTS_I4(m8n8k32, u4)}}};
+ case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
+ return {1, 1, 2, 2, {{MMA_VARIANTS_B1(m8n8k128, b1)}}};
+
+ // Double MMA
+ case NVPTX::BI__dmma_m8n8k4_mma_f64:
+ return {1, 1, 2, 2, {{MMA_VARIANTS(m8n8k4, f64)}}};
+
+ // Alternate FP MMA
+ case NVPTX::BI__mma_bf16_m16n16k16_mma_f32:
+ return {4, 4, 8, 8, {{MMA_VARIANTS(m16n16k16, bf16)}}};
+ case NVPTX::BI__mma_bf16_m8n32k16_mma_f32:
+ return {2, 8, 8, 8, {{MMA_VARIANTS(m8n32k16, bf16)}}};
+ case NVPTX::BI__mma_bf16_m32n8k16_mma_f32:
+ return {8, 2, 8, 8, {{MMA_VARIANTS(m32n8k16, bf16)}}};
+ case NVPTX::BI__mma_tf32_m16n16k8_mma_f32:
+ return {4, 4, 8, 8, {{MMA_VARIANTS(m16n16k8, tf32)}}};
+ default:
+ llvm_unreachable("Unexpected builtin ID.");
+ }
#undef MMA_VARIANTS
+#undef MMA_SATF_VARIANTS
#undef MMA_VARIANTS_I4
#undef MMA_VARIANTS_B1
}
@@ -16844,7 +16898,20 @@ CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, const CallExpr *E) {
case NVPTX::BI__bmma_m8n8k128_ld_a_b1:
case NVPTX::BI__bmma_m8n8k128_ld_b_b1:
case NVPTX::BI__bmma_m8n8k128_ld_c:
- {
+ // Double MMA loads.
+ case NVPTX::BI__dmma_m8n8k4_ld_a:
+ case NVPTX::BI__dmma_m8n8k4_ld_b:
+ case NVPTX::BI__dmma_m8n8k4_ld_c:
+ // Alternate float MMA loads.
+ case NVPTX::BI__mma_bf16_m16n16k16_ld_a:
+ case NVPTX::BI__mma_bf16_m16n16k16_ld_b:
+ case NVPTX::BI__mma_bf16_m8n32k16_ld_a:
+ case NVPTX::BI__mma_bf16_m8n32k16_ld_b:
+ case NVPTX::BI__mma_bf16_m32n8k16_ld_a:
+ case NVPTX::BI__mma_bf16_m32n8k16_ld_b:
+ case NVPTX::BI__mma_tf32_m16n16k8_ld_a:
+ case NVPTX::BI__mma_tf32_m16n16k8_ld_b:
+ case NVPTX::BI__mma_tf32_m16n16k8_ld_c: {
Address Dst = EmitPointerWithAlignment(E->getArg(0));
Value *Src = EmitScalarExpr(E->getArg(1));
Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -16889,7 +16956,9 @@ CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, const CallExpr *E) {
case NVPTX::BI__imma_m32n8k16_st_c_i32:
case NVPTX::BI__imma_m8n32k16_st_c_i32:
case NVPTX::BI__imma_m8n8k32_st_c_i32:
- case NVPTX::BI__bmma_m8n8k128_st_c_i32: {
+ case NVPTX::BI__bmma_m8n8k128_st_c_i32:
+ case NVPTX::BI__dmma_m8n8k4_st_c_f64:
+ case NVPTX::BI__mma_m16n16k8_st_c_f32: {
Value *Dst = EmitScalarExpr(E->getArg(0));
Address Src = EmitPointerWithAlignment(E->getArg(1));
Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -16941,7 +17010,12 @@ CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, const CallExpr *E) {
case NVPTX::BI__imma_m8n32k16_mma_u8:
case NVPTX::BI__imma_m8n8k32_mma_s4:
case NVPTX::BI__imma_m8n8k32_mma_u4:
- case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1: {
+ case NVPTX::BI__bmma_m8n8k128_mma_xor_popc_b1:
+ case NVPTX::BI__dmma_m8n8k4_mma_f64:
+ case NVPTX::BI__mma_bf16_m16n16k16_mma_f32:
+ case NVPTX::BI__mma_bf16_m8n32k16_mma_f32:
+ case NVPTX::BI__mma_bf16_m32n8k16_mma_f32:
+ case NVPTX::BI__mma_tf32_m16n16k8_mma_f32: {
Address Dst = EmitPointerWithAlignment(E->getArg(0));
Address SrcA = EmitPointerWithAlignment(E->getArg(1));
Address SrcB = EmitPointerWithAlignment(E->getArg(2));
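Analogously to the f64 sketch above, the bf16 builtins only cover the A/B fragments (packed into 32-bit words); the f32 accumulator still goes through the pre-existing `__hmma_*` load/store builtins. A hypothetical example, assuming the existing `__hmma_m16n16k16_ld_c_f32`/`__hmma_m16n16k16_st_c_f32` accumulator builtins:

```
// Hypothetical bf16 16x16x16 tile with an f32 accumulator. Fragment sizes
// (4 ints for A and B, 8 floats for C and D) match the MMA_LDST and
// NVPTXMmaInfo entries added above.
__device__ void bf16_mma_example(float *d, const int *a, const int *b,
                                 const float *c, int ldm) {
  int a_frag[4], b_frag[4];
  float c_frag[8], d_frag[8];
  __mma_bf16_m16n16k16_ld_a(a_frag, a, ldm, 0);                       // row-major A
  __mma_bf16_m16n16k16_ld_b(b_frag, b, ldm, 1);                       // col-major B
  __hmma_m16n16k16_ld_c_f32(c_frag, c, ldm, 0);                       // f32 accumulator
  __mma_bf16_m16n16k16_mma_f32(d_frag, a_frag, b_frag, c_frag, 1, 0); // row.col layout
  __hmma_m16n16k16_st_c_f32(d, d_frag, ldm, 0);
}
```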
diff --git a/clang/test/CodeGen/builtins-nvptx-mma.cu b/clang/test/CodeGen/builtins-nvptx-mma.cu
index cc31f6f4779a5..7e9bac86792d2 100644
--- a/clang/test/CodeGen/builtins-nvptx-mma.cu
+++ b/clang/test/CodeGen/builtins-nvptx-mma.cu
@@ -3,21 +3,20 @@
// *** DO NOT EDIT ***
//
// This test has been automatically generated by
-// builtins-nvtx-mma.py --ptx=63 --gpu-arch=75
+// builtins-nvtx-mma.py --ptx=70 --gpu-arch=80
//
-// Make sure we can handle all builtins available on sm_75 with PTX63
-// RUN: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_75 \
-// RUN: -fcuda-is-device -target-feature +ptx63 \
-// RUN: -DPTX=63 -DSM=75 \
+// Make sure we can handle all builtins available on sm_80 with PTX70
+// RUN: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_80 \
+// RUN: -fcuda-is-device -target-feature +ptx70 \
+// RUN: -DPTX=70 -DSM=80 \
// RUN: -S -emit-llvm -o - -x cuda %s \
-// RUN: | FileCheck -check-prefixes=CHECK_PTX61_SM70,CHECK_PTX63_SM75,CHECK_PTX63_SM72,CHECK_PTX60_SM70 %s
+// RUN: | FileCheck -check-prefixes=CHECK_PTX70_SM80,CHECK_PTX60_SM70,CHECK_PTX63_SM72,CHECK_PTX61_SM70,CHECK_PTX63_SM75 %s
// Verify that all builtins have correct constraints.
// RUN: %clang_cc1 -triple nvptx-unknown-unknown \
// RUN: -target-cpu sm_60 -target-feature +ptx42 \
-// RUN: -DPTX=63 -DSM=75 -fcuda-is-device -S -o /dev/null -x cuda \
+// RUN: -DPTX=70 -DSM=80 -fcuda-is-device -S -o /dev/null -x cuda \
// RUN: -verify %s
-
#if !defined(CUDA_VERSION)
#define __device__ __attribute__((device))
#define __global__ __attribute__((global))
@@ -29,8 +28,8 @@ typedef unsigned long long uint64_t;
// CHECK-LABEL: test_wmma_buitins
__device__ void test_wmma_buitins(int *src, int *dst,
- float *fsrc, float *fdst, int ldm) {
-
+ float *fsrc, float *fdst,
+ double *dsrc, double *ddst, int ldm) {
#if (PTX >= 60) && (SM >= 70)
@@ -751,5 +750,153 @@ __device__ void test_wmma_buitins(int *src, int *dst,
// CHECK_PTX63_SM75: call {{.*}} @llvm.nvvm.wmma.m8n8k32.mma.row.col.u4.satfinite
// expected-error-re at +1 {{'__imma_m8n8k32_mma_u4' needs target feature (sm_75{{.*}},(ptx63{{.*}}}}
__imma_m8n8k32_mma_u4(dst, src, src, src, 1, 1);
-#endif // (PTX >= 63) && (SM >= 75)
+#endif // (PTX >= 63) && (SM >= 75)
+
+#if (PTX >= 70) && (SM >= 80)
+
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m16n16k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m16n16k16_ld_a(dst, src, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m16n16k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m16n16k16_ld_a(dst, src, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m16n16k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m16n16k16_ld_b(dst, src, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m16n16k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m16n16k16_ld_b(dst, src, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32
+ // expected-error-re at +1 {{'__mma_tf32_m16n16k8_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_tf32_m16n16k8_ld_a(dst, src, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32
+ // expected-error-re at +1 {{'__mma_tf32_m16n16k8_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_tf32_m16n16k8_ld_a(dst, src, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32
+ // expected-error-re at +1 {{'__mma_tf32_m16n16k8_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_tf32_m16n16k8_ld_b(dst, src, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.b.row.stride.tf32
+ // expected-error-re at +1 {{'__mma_tf32_m16n16k8_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_tf32_m16n16k8_ld_b(dst, src, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.c.col.stride.f32
+ // expected-error-re at +1 {{'__mma_tf32_m16n16k8_ld_c' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_tf32_m16n16k8_ld_c(fdst, fsrc, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.load.c.row.stride.f32
+ // expected-error-re at +1 {{'__mma_tf32_m16n16k8_ld_c' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_tf32_m16n16k8_ld_c(fdst, fsrc, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.store.d.col.stride.f32
+ // expected-error-re at +1 {{'__mma_m16n16k8_st_c_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_m16n16k8_st_c_f32(fdst, fsrc, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.store.d.row.stride.f32
+ // expected-error-re at +1 {{'__mma_m16n16k8_st_c_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_m16n16k8_st_c_f32(fdst, fsrc, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m32n8k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m32n8k16_ld_a(dst, src, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m32n8k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m32n8k16_ld_a(dst, src, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m32n8k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m32n8k16_ld_b(dst, src, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m32n8k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m32n8k16_ld_b(dst, src, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m8n32k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m8n32k16_ld_a(dst, src, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m8n32k16_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m8n32k16_ld_a(dst, src, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m8n32k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m8n32k16_ld_b(dst, src, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m8n32k16_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m8n32k16_ld_b(dst, src, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.a.col.stride.f64
+ // expected-error-re at +1 {{'__dmma_m8n8k4_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __dmma_m8n8k4_ld_a(ddst, dsrc, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.a.row.stride.f64
+ // expected-error-re at +1 {{'__dmma_m8n8k4_ld_a' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __dmma_m8n8k4_ld_a(ddst, dsrc, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.b.col.stride.f64
+ // expected-error-re at +1 {{'__dmma_m8n8k4_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __dmma_m8n8k4_ld_b(ddst, dsrc, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.b.row.stride.f64
+ // expected-error-re at +1 {{'__dmma_m8n8k4_ld_b' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __dmma_m8n8k4_ld_b(ddst, dsrc, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.c.col.stride.f64
+ // expected-error-re at +1 {{'__dmma_m8n8k4_ld_c' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __dmma_m8n8k4_ld_c(ddst, dsrc, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.load.c.row.stride.f64
+ // expected-error-re at +1 {{'__dmma_m8n8k4_ld_c' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __dmma_m8n8k4_ld_c(ddst, dsrc, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.store.d.col.stride.f64
+ // expected-error-re at +1 {{'__dmma_m8n8k4_st_c_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __dmma_m8n8k4_st_c_f64(ddst, dsrc, ldm, 1);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.store.d.row.stride.f64
+ // expected-error-re at +1 {{'__dmma_m8n8k4_st_c_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __dmma_m8n8k4_st_c_f64(ddst, dsrc, ldm, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m16n16k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m16n16k16_mma_f32(fdst, src, src, fsrc, 3, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.col.row.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m16n16k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m16n16k16_mma_f32(fdst, src, src, fsrc, 2, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.col.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m16n16k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m16n16k16_mma_f32(fdst, src, src, fsrc, 1, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m16n16k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m16n16k16_mma_f32(fdst, src, src, fsrc, 0, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32
+ // expected-error-re at +1 {{'__mma_tf32_m16n16k8_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_tf32_m16n16k8_mma_f32(fdst, src, src, fsrc, 3, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.mma.col.row.tf32
+ // expected-error-re at +1 {{'__mma_tf32_m16n16k8_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_tf32_m16n16k8_mma_f32(fdst, src, src, fsrc, 2, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.mma.row.col.tf32
+ // expected-error-re at +1 {{'__mma_tf32_m16n16k8_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_tf32_m16n16k8_mma_f32(fdst, src, src, fsrc, 1, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m16n16k8.mma.row.row.tf32
+ // expected-error-re at +1 {{'__mma_tf32_m16n16k8_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_tf32_m16n16k8_mma_f32(fdst, src, src, fsrc, 0, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m32n8k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m32n8k16_mma_f32(fdst, src, src, fsrc, 3, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.mma.col.row.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m32n8k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m32n8k16_mma_f32(fdst, src, src, fsrc, 2, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.mma.row.col.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m32n8k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m32n8k16_mma_f32(fdst, src, src, fsrc, 1, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m32n8k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m32n8k16_mma_f32(fdst, src, src, fsrc, 0, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m8n32k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m8n32k16_mma_f32(fdst, src, src, fsrc, 3, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.mma.col.row.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m8n32k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m8n32k16_mma_f32(fdst, src, src, fsrc, 2, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.mma.row.col.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m8n32k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m8n32k16_mma_f32(fdst, src, src, fsrc, 1, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16
+ // expected-error-re at +1 {{'__mma_bf16_m8n32k16_mma_f32' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __mma_bf16_m8n32k16_mma_f32(fdst, src, src, fsrc, 0, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.col.col.f64
+ // expected-error-re at +1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 3, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.col.row.f64
+ // expected-error-re at +1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 2, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.row.col.f64
+ // expected-error-re at +1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 1, 0);
+ // CHECK_PTX70_SM80: call {{.*}} @llvm.nvvm.wmma.m8n8k4.mma.row.row.f64
+ // expected-error-re at +1 {{'__dmma_m8n8k4_mma_f64' needs target feature (sm_80{{.*}},(ptx70{{.*}}}}
+ __dmma_m8n8k4_mma_f64(ddst, dsrc, dsrc, dsrc, 0, 0);
+#endif // (PTX >= 70) && (SM >= 80)
}
diff --git a/clang/test/CodeGen/builtins-nvptx-mma.py b/clang/test/CodeGen/builtins-nvptx-mma.py
index 1b395fc4f33b1..2ffc21b12fb06 100644
--- a/clang/test/CodeGen/builtins-nvptx-mma.py
+++ b/clang/test/CodeGen/builtins-nvptx-mma.py
@@ -47,7 +47,13 @@ def make_ldst_ops(geoms, frags, types):
in product(geoms, frags, types)]
def get_mma_ops():
- return (make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+ return (make_mma_ops(["m16n16k8"],
+ ["tf32"], [], ["f32"], []) +
+ make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+ ["bf16"], [], ["f32"], []) +
+ make_mma_ops(["m8n8k4"],
+ ["f64"], [], ["f64"], []) +
+ make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
["s8", "u8"], [], ["s32"], []) +
@@ -55,14 +61,18 @@ def get_mma_ops():
["s4", "u4"], [], ["s32"], []) +
make_mma_ops(["m8n8k128"],
["b1"], [], ["s32"], []))
+
def get_ldst_ops():
return (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
- ["a", "b"], ["f16", "u8", "s8"]) +
+ ["a", "b"], ["f16", "u8", "s8", "bf16"]) +
make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
["c", "d"], ["f16", "f32", "s32"]) +
make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) +
make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
- make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]))
+ make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]) +
+ make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) +
+ make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) +
+ make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
def is_geom_supported(geom):
# geometries for FP and ints.
@@ -73,6 +83,8 @@ def is_geom_supported(geom):
return ptx_version >= 63 and gpu_arch >= 75
if geom == "m16n16k16":
return ptx_version >= 60
+ if geom in ["m16n16k8", "m8n8k4"]:
+ return ptx_version >= 70 and gpu_arch >= 80
assert(False) # Unexpected geometry.
def is_type_supported(ptx_type):
@@ -80,16 +92,24 @@ def is_type_supported(ptx_type):
return ptx_version >= 63 and gpu_arch >= 72
if ptx_type in ["s4", "u4", "b1"]:
return ptx_version >= 63 and gpu_arch >= 75
+ if ptx_type in ["bf16", "tf32", "f64"]:
+ return ptx_version >= 70 and gpu_arch >= 80
return ptx_version >= 60 and gpu_arch >= 70
+def is_rnd_supported(op):
+ # rnd is only supported for FP64 WMMA
+ return op.a.ptx_type == "f64"
+
def is_mma_variant_supported(op, layout_a, layout_b, satf):
if not (is_type_supported(op.a.ptx_type)
and is_geom_supported(op.a.geom)):
return False
- # sub-integer require row/col layout, and no satf.
+
+ if satf and not op.a.ptx_type in ["f16", "s8", "u8", "s4", "u4"]:
+ return False
+
+ # sub-integer types require row/col layout.
if op.a.ptx_type in ["s4", "u4", "b1"]:
- if op.a.ptx_type == "b1" and satf:
- return False
return layout_a == "row" and layout_b == "col"
return True
@@ -98,7 +118,7 @@ def is_ldst_variant_supported(frag, layout):
and is_geom_supported(frag.geom)):
return False
if frag.ptx_type in ["s4", "u4", "b1"]:
- # sub-integer require sm_75 and ptx63, row/col layout for a/b.
+ # sub-integer types require sm_75 and ptx63, row/col layout for a/b.
return ((frag.frag == "a" and layout == "row")
or (frag.frag == "b" and layout == "col")
or frag.frag in ["c", "d"])
@@ -109,12 +129,21 @@ def get_builtin_prefix(frag):
if frag.geom in ["m16n16k16", "m32n8k16", "m8n32k16"]:
if frag.ptx_type in ["f16", "f32"]:
prefix = "__hmma"
+ elif frag.ptx_type == "bf16":
+ prefix = "__mma_bf16"
else:
prefix = "__imma"
elif frag.geom == "m8n8k32":
prefix = "__imma" # sub-integers
elif frag.geom == "m8n8k128":
prefix = "__bmma"
+ elif frag.geom == "m8n8k4":
+ prefix = "__dmma"
+ elif frag.geom == "m16n16k8":
+ if frag.ptx_type == "f32":
+ prefix = "__mma"
+ else:
+ prefix = "__mma_tf32"
assert prefix
return prefix
@@ -123,10 +152,13 @@ def get_ldst_builtin_name(frag):
if prefix == "__hmma":
suffix = "" if frag.frag in ["a","b"] else frag.ptx_type
- elif prefix in ["__imma", "__bmma"]:
- suffix = "" if frag.frag in ["c"] else frag.ptx_type
+ elif prefix in ["__dmma", "__mma_bf16", "__mma_tf32"]:
+ suffix = "" if frag.frag in ["a","b","c"] else frag.ptx_type
+ else:
+ suffix = "" if frag.frag == "c" else frag.ptx_type
if suffix == "s32":
suffix = "i32"
+
if frag.frag == "d":
ifrag = "c"
op = "st"
@@ -143,6 +175,8 @@ def get_mma_builtin_name(op):
if prefix == "__hmma":
suffix = op.d.ptx_type + op.c.ptx_type
+ elif prefix in ["__mma_bf16", "__mma_tf32"]:
+ suffix = op.d.ptx_type
else:
suffix = op.a.ptx_type
@@ -151,8 +185,9 @@ def get_mma_builtin_name(op):
suffix)
return name
-
def get_required_sm(frag):
+ if frag.ptx_type in ["f64", "bf16", "tf32"]:
+ return 80
if frag.ptx_type in ["u4", "s4", "b1"]:
return 75
if frag.ptx_type in ["s8", "u8"]:
@@ -163,18 +198,34 @@ def get_required_sm(frag):
else: # s8/u8
return 72
if frag.ptx_type in ["f16", "f32"]:
- return 70
+ if frag.geom == "m16n16k8":
+ return 80
+ else:
+ return 70
assert(False)
def get_required_ptx(frag):
+ if frag.ptx_type in ["f64", "bf16", "tf32"]:
+ return 70
if frag.ptx_type in ["f16", "f32"]:
- return 60 if frag.geom == "m16n16k16" else 61
+ if frag.geom == "m16n16k16":
+ return 60
+ if frag.geom == "m16n16k8":
+ return 70
+ return 61
return 63
+def get_src_dst_prefix(ptx_type):
+ if ptx_type == "f32":
+ return "f"
+ if ptx_type == "f64":
+ return "d"
+ return ""
+
def gen_wmma_ldst_tests(results):
load_template = """
// CHECK${check_suffix}: call {{.*}} @${intrinsic}
- // expected-error-re at +1 {{'${builtin}' needs target feature sm_${min_sm}{{.*}},ptx${min_ptx}{{.*}}}}
+ // expected-error-re at +1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
${builtin}(${dst}, ${src}, ldm, ${blayout});
""".rstrip()
intrinsic_template = "llvm.nvvm.wmma.${geom}.${op}.${frag}.${ilayout}.stride.${itype}"
@@ -184,7 +235,7 @@ def gen_wmma_ldst_tests(results):
if not is_ldst_variant_supported(frag, layout):
continue
- is_fp = frag.ptx_type == "f32"
+ src_dst_prefix = get_src_dst_prefix(frag.ptx_type)
min_sm = get_required_sm(frag)
min_ptx = get_required_ptx(frag)
params = {
@@ -192,8 +243,8 @@ def gen_wmma_ldst_tests(results):
"builtin" : get_ldst_builtin_name(frag),
"min_ptx" : min_ptx,
"min_sm" : min_sm,
- "dst": "fdst" if is_fp else "dst",
- "src": "fsrc" if is_fp else "src",
+ "dst": src_dst_prefix + "dst",
+ "src": src_dst_prefix + "src",
"blayout" : 0 if layout == "row" else 1,
"intrinsic" : Template(intrinsic_template).substitute({
"frag" : frag.frag,
@@ -208,12 +259,12 @@ def gen_wmma_ldst_tests(results):
return results
def mma_signature(op):
- if op.a.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
- # int and sub-int ops are identified by input type.
- return op.a.ptx_type
- else:
- # the rest are FP ops identified by accumulator & result type.
+ if op.a.ptx_type == "f16":
+ # FP16 ops identified by accumulator & result type.
return "%s.%s" % (op.d.ptx_type, op.c.ptx_type)
+ else:
+ # other ops are identified by input type.
+ return op.a.ptx_type
# Get numeric value for rowcol parameter of the builtin
# AFAICT it uses the encoding accepted by NVVM intrinsics:
@@ -229,8 +280,8 @@ def get_ilayout(a, b):
def gen_wmma_mma_tests(results):
mma_template = """
// CHECK${check_suffix}: call {{.*}} @${intrinsic}
- // expected-error-re at +1 {{'${builtin}' needs target feature sm_${min_sm}{{.*}},ptx${min_ptx}{{.*}}}}
- ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_isatf});
+ // expected-error-re at +1 {{'${builtin}' needs target feature (sm_${min_sm}{{.*}},(ptx${min_ptx}{{.*}}}}
+ ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_satf});
""".rstrip()
intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
@@ -243,9 +294,9 @@ def gen_wmma_mma_tests(results):
if not is_mma_variant_supported(op, alayout, blayout, satf):
continue
- a_is_fp = op.a.ptx_type == "f32"
- c_is_fp = op.c.ptx_type == "f32"
- d_is_fp = op.d.ptx_type == "f32"
+ asrc_prefix = get_src_dst_prefix(op.a.ptx_type)
+ csrc_prefix = get_src_dst_prefix(op.c.ptx_type)
+ ddst_prefix = get_src_dst_prefix(op.d.ptx_type)
min_sm = get_required_sm(op.a)
min_ptx = get_required_ptx(op.a)
if op.a.ptx_type == "b1": # .b1 MMA has no satf argument.
@@ -257,11 +308,11 @@ def gen_wmma_mma_tests(results):
"builtin" : get_mma_builtin_name(op),
"min_ptx" : min_ptx,
"min_sm" : min_sm,
- "dst": "fdst" if d_is_fp else "dst",
- "asrc": "fsrc" if a_is_fp else "src",
- "csrc": "fsrc" if c_is_fp else "src",
+ "dst": ddst_prefix + "dst",
+ "asrc": asrc_prefix + "src",
+ "csrc": csrc_prefix + "src",
"ilayout" : get_ilayout(alayout, blayout),
- "maybe_isatf" : isatf_arg,
+ "maybe_satf" : isatf_arg,
"intrinsic" : Template(intrinsic_template).substitute({
"geom" : op.a.geom,
"alayout" : alayout,
@@ -322,7 +373,8 @@ def supported_variants(ptx, sm, results):
// CHECK-LABEL: test_wmma_buitins
__device__ void test_wmma_buitins(int *src, int *dst,
- float *fsrc, float *fdst, int ldm) {
+ float *fsrc, float *fdst,
+ double *dsrc, double *ddst, int ldm) {
""");
for (ptx, sm), tests in sorted(results.items()):
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 71e31b14f4c3b..3ce9dfb1bb807 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -52,13 +52,27 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
string gft = Geom#":"#Frag#":"#ptx_elt_type;
string ft = frag#":"#ptx_elt_type;
list<LLVMType> regs = !cond(
- // mma.sync.m8n8k4 uses smaller a/b fragments than wmma fp ops
+ // mma fp ops use smaller fragments than wmma fp ops
!eq(gft,"m8n8k4:a:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m8n8k4:b:f16") : !listsplat(llvm_v2f16_ty, 2),
-
- // fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
- // All currently supported geometries use the same fragment format,
- // so we only need to consider {fragment, type}.
+ !eq(gft,"m16n8k8:a:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k8:b:f16") : [llvm_v2f16_ty],
+ !eq(gft,"m16n8k8:c:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k8:d:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k8:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k8:d:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k16:a:f16") : !listsplat(llvm_v2f16_ty, 4),
+ !eq(gft,"m16n8k16:b:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k16:c:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k16:d:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k16:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4),
+
+ // wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
+ // All other supported geometries use the same fragment format for f32 and
+ // f16, so we only need to consider {fragment, type}.
!eq(ft,"a:f16") : !listsplat(llvm_v2f16_ty, 8),
!eq(ft,"b:f16") : !listsplat(llvm_v2f16_ty, 8),
!eq(ft,"c:f16") : !listsplat(llvm_v2f16_ty, 4),
@@ -66,7 +80,36 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(ft,"c:f32") : !listsplat(llvm_float_ty, 8),
!eq(ft,"d:f32") : !listsplat(llvm_float_ty, 8),
- // u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
+ // wmma tf32 -> s32 @ m16n16k8
+ !eq(gft,"m16n16k8:a:tf32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n16k8:b:tf32") : !listsplat(llvm_i32_ty, 4),
+
+ // mma tf32 -> s32 @ m16n16k8/m16n8k8
+ !eq(gft,"m16n8k4:a:tf32") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k4:b:tf32") : [llvm_i32_ty],
+ !eq(gft,"m16n8k8:a:tf32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k8:b:tf32") : !listsplat(llvm_i32_ty, 2),
+
+ !eq(gft,"m8n8k4:a:f64") : [llvm_double_ty],
+ !eq(gft,"m8n8k4:b:f64") : [llvm_double_ty],
+ !eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2),
+ !eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2),
+
+ // wmma bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
+ !eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m8n32k16:a:bf16") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m8n32k16:b:bf16") : !listsplat(llvm_i32_ty, 8),
+ !eq(gft,"m32n8k16:a:bf16") : !listsplat(llvm_i32_ty, 8),
+ !eq(gft,"m32n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
+
+ // mma bf16 -> s32 @ m16n8k16/m16n8k8
+ !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],
+
+ // wmma u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
!eq(gft,"m16n16k16:a:u8") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n16k16:a:s8") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n16k16:b:u8") : !listsplat(llvm_i32_ty, 2),
@@ -88,17 +131,65 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(gft,"m32n8k16:c:s32") : !listsplat(llvm_i32_ty, 8),
!eq(gft,"m32n8k16:d:s32") : !listsplat(llvm_i32_ty, 8),
- // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
- !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
+ // mma u8/s8 -> s32 @ m8n8k16/m16n8k16/m16n8k32
+ !eq(gft,"m8n8k16:a:u8") : [llvm_i32_ty],
+ !eq(gft,"m8n8k16:a:s8") : [llvm_i32_ty],
+ !eq(gft,"m8n8k16:b:u8") : [llvm_i32_ty],
+ !eq(gft,"m8n8k16:b:s8") : [llvm_i32_ty],
+ !eq(gft,"m8n8k16:c:s32") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m8n8k16:d:s32") : !listsplat(llvm_i32_ty, 2),
+
+ !eq(gft,"m16n8k16:a:u8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k16:a:s8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k16:b:u8") : [llvm_i32_ty],
+ !eq(gft,"m16n8k16:b:s8") : [llvm_i32_ty],
+ !eq(gft,"m16n8k16:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k16:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ !eq(gft,"m16n8k32:a:u8") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:a:s8") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:b:u8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:b:s8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ // wmma/mma u4/s4 -> s32 @ m8n8k32 (u4/s4)
!eq(gft,"m8n8k32:a:u4") : [llvm_i32_ty],
!eq(gft,"m8n8k32:a:s4") : [llvm_i32_ty],
- !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
!eq(gft,"m8n8k32:b:u4") : [llvm_i32_ty],
!eq(gft,"m8n8k32:b:s4") : [llvm_i32_ty],
- !eq(gft,"m8n8k128:c:s32") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m8n8k128:d:s32") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m8n8k32:c:s32") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m8n8k32:d:s32") : !listsplat(llvm_i32_ty, 2),
+
+ !eq(gft,"m16n8k32:a:u4") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:a:s4") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:b:u4") : [llvm_i32_ty],
+ !eq(gft,"m16n8k32:b:s4") : [llvm_i32_ty],
+ !eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ !eq(gft,"m16n8k64:a:u4") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:a:s4") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:b:u4") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k64:b:s4") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ // wmma/mma b1 -> s32 @ m8n8k128(b1)
+ !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
+ !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
+ !eq(gft,"m8n8k128:c:s32") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m8n8k128:d:s32") : !listsplat(llvm_i32_ty, 2),
+
+ !eq(gft,"m16n8k128:a:b1") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k128:b:b1") : [llvm_i32_ty],
+ !eq(gft,"m16n8k128:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k128:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ !eq(gft,"m16n8k256:a:b1") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k256:b:b1") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k256:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
);
}
@@ -125,35 +216,40 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
list<WMMA_REGS> id_frags = !cond(
- // int and sub-int ops are identified by input type.
- !eq(A.ptx_elt_type, "s8") : [A],
- !eq(A.ptx_elt_type, "u8") : [A],
- !eq(A.ptx_elt_type, "s4") : [A],
- !eq(A.ptx_elt_type, "u4") : [A],
- !eq(A.ptx_elt_type, "b1") : [A],
- // the rest are FP ops identified by accumulator & result type.
- true: [D, C]
+ // FP16 ops are identified by accumulator & result type.
+ !eq(A.ptx_elt_type, "f16") : [D, C],
+ // other ops are identified by input types.
+ !ne(A.ptx_elt_type, B.ptx_elt_type): [A, B],
+ true: [A]
);
string ret = !foldl("", id_frags, a, b, !strconcat(a, ".", b.ptx_elt_type));
}
-class WMMA_NAME_MMA<string ALayout, string BLayout, int Satfinite,
- WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
- string llvm = !if(
- !eq(A.geom, "m8n8k4"),
- "llvm.nvvm.mma.m8n8k4"
- # "." # ALayout
- # "." # BLayout
- # signature,
- "llvm.nvvm.wmma."
- # A.geom
- # ".mma"
- # "." # ALayout
- # "." # BLayout
- # signature
- # !if(Satfinite, ".satfinite", ""));
+ string llvm = "llvm.nvvm.wmma."
+ # A.geom
+ # ".mma"
+ # "." # ALayout
+ # "." # BLayout
+ # !if(!ne(Rnd, ""), !strconcat(".", Rnd), "")
+ # signature
+ # !if(Satfinite, ".satfinite", "");
+
+ string record = !subst(".", "_",
+ !subst("llvm.", "int_", llvm));
+}
+class MMA_NAME<string ALayout, string BLayout, int Satfinite,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string llvm = "llvm.nvvm.mma."
+ # A.geom
+ # "." # ALayout
+ # "." # BLayout
+ # !if(Satfinite, ".satfinite", "")
+ # signature;
string record = !subst(".", "_",
!subst("llvm.", "int_", llvm));
}
@@ -188,14 +284,18 @@ class MMA_LDST_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
list<string> ops = !foreach(x, ret, x.gft);
}
-
-
// Creates list of valid combinations of fragments. This is the master list that
// drives generation of corresponding intrinsics and instructions.
class NVVM_MMA_OPS<int _ = 0> {
- list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
+ list<list<WMMA_REGS>> tf32_wmma_ops = MMA_OPS<
+ ["m16n16k8"],
+ ["tf32"], [], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> bf16_wmma_ops = MMA_OPS<
+ ["m16n16k16", "m32n8k16", "m8n32k16"],
+ ["bf16"], [], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> f64_wmma_ops = MMA_OPS<
["m8n8k4"],
- ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+ ["f64"], [], ["f64"], []>.ret;
list<list<WMMA_REGS>> fp_wmma_ops = MMA_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
@@ -208,16 +308,50 @@ class NVVM_MMA_OPS<int _ = 0> {
list<list<WMMA_REGS>> bit_wmma_ops = MMA_OPS<
["m8n8k128"],
["b1"], [], ["s32"], []>.ret;
+ list<list<WMMA_REGS>> all_wmma_ops = !listconcat(
+ tf32_wmma_ops, bf16_wmma_ops, f64_wmma_ops,
+ fp_wmma_ops, int_wmma_ops, subint_wmma_ops, bit_wmma_ops);
+
+ list<list<WMMA_REGS>> tf32_mma_ops = MMA_OPS<
+ ["m16n8k4", "m16n8k8"],
+ ["tf32"], [], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> bf16_mma_ops = MMA_OPS<
+ ["m16n8k16", "m16n8k8"],
+ ["bf16"], [], ["f32"], []>.ret;
+ list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
+ ["m8n8k4"],
+ ["f64"], [], ["f64"], []>.ret;
+ list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
+ ["m8n8k4", "m16n8k8", "m16n8k16"],
+ ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+ list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
+ ["m8n8k16", "m16n8k16", "m16n8k32"],
+ ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
+ list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
+ ["m8n8k32", "m16n8k32", "m16n8k64"],
+ ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
+ list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS<
+ ["m8n8k128", "m16n8k128", "m16n8k256"],
+ ["b1"], [], ["s32"], []>.ret;
list<list<WMMA_REGS>> all_mma_ops = !listconcat(
- fp_mma_ops, fp_wmma_ops, int_wmma_ops,
- subint_wmma_ops, bit_wmma_ops);
+ tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
+ fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
- ["a", "b"], ["f16", "u8", "s8"]>.ret;
+ ["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
["c", "d"], ["f16", "f32", "s32"]>.ret;
+ list<WMMA_REGS> ldst_tf32_ab_ops = MMA_LDST_OPS<
+ ["m16n16k8"],
+ ["a", "b"], ["tf32"]>.ret;
+ list<WMMA_REGS> ldst_tf32_cd_ops = MMA_LDST_OPS<
+ ["m16n16k8"],
+ ["c", "d"], ["f32"]>.ret;
+ list<WMMA_REGS> ldst_f64_abcd_ops = MMA_LDST_OPS<
+ ["m8n8k4"],
+ ["a", "b", "c", "d"], ["f64"]>.ret;
list<WMMA_REGS> ldst_subint_ab_ops = MMA_LDST_OPS<
["m8n8k32"], ["a", "b"], ["s4","u4"]>.ret;
list<WMMA_REGS> ldst_bit_ab_ops = MMA_LDST_OPS<
@@ -225,6 +359,9 @@ class NVVM_MMA_OPS<int _ = 0> {
list<WMMA_REGS> ldst_subint_cd_ops = MMA_LDST_OPS<
["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]>.ret;
list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops,
+ ldst_tf32_ab_ops,
+ ldst_tf32_cd_ops,
+ ldst_f64_abcd_ops,
ldst_subint_ab_ops,
ldst_bit_ab_ops,
ldst_subint_cd_ops);
@@ -235,69 +372,110 @@ class NVVM_MMA_OPS<int _ = 0> {
def NVVM_MMA_OPS : NVVM_MMA_OPS;
-// Returns true if this combination of layout/satf is supported; false otherwise.
-// MMA ops must provide all parameters. Loads and stores -- only frags and layout_a.
-// The class is used to prevent generation of records for the unsupported variants.
+
+// Returns true if this combination of fragment and layout for WMMA load/store
+// ops is supported; false otherwise.
+// E.g.
+// if NVVM_WMMA_LDST_SUPPORTED<...>.ret then
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_WMMA_LDST_SUPPORTED<WMMA_REGS frag, string layout> {
+ string f = frag.frag;
+ string t = frag.ptx_elt_type;
+
+ bit ret = !cond(
+ // Sub-int load and store requires A fragment to be of row layout and B
+ // fragments to be of column layout.
+ !and(!or(!eq(t, "b1"),
+ !eq(t, "u4"),
+ !eq(t, "s4")),
+ !or(!and(!eq(f, "a"),
+ !ne(layout, "row")),
+ !and(!eq(f, "b"),
+ !ne(layout, "col")))) : false,
+ true: true
+ );
+}
+
+// Returns true if this combination of layout/satf/rnd for WMMA ops is
+// supported; false otherwise.
+// E.g.
+// if NVVM_WMMA_SUPPORTED<...>.ret then
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_WMMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf, string rnd> {
+ // WMMA ops check both layouts.
+ string layout = layout_a # ":" # layout_b;
+ string t = frags[0].ptx_elt_type;
+
+ bit ret = !cond(
+ // only f64 wmma functions support rnd options
+ // any non f64 type that uses a rnd value is invalid
+ !and(!ne(t, "f64"), !ne(rnd, "")) : false,
+
+ // satf is only valid for select types
+ !and(!eq(satf, 1),
+ !ne(t, "s8"),
+ !ne(t, "u8"),
+ !ne(t, "s4"),
+ !ne(t, "u4"),
+ !ne(t, "f16")): false,
+
+ // Sub-int wmma requires row/column layout
+ !and(!or(!eq(t, "s4"),
+ !eq(t, "u4"),
+ !eq(t, "b1")),
+ !ne(layout, "row:col")) : false,
+ true: true
+ );
+}
+
+// Returns true if this combination of layout/satf for MMA ops is supported;
+// false otherwise.
// E.g.
// if NVVM_MMA_SUPPORTED<...>.ret then
// def : FOO<>; // The record will only be defined for supported ops.
//
-class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b="-", int satf=-1> {
+class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
// MMA ops check both layouts.
- string mma = frags[0].ptx_elt_type
- # ":" # layout_a
- # ":" # layout_b;
- // Load ops only need type/fragment/layout.
- string ld = frags[0].ptx_elt_type
- # ":" # frags[0].frag
- # ":" # layout_a
- ;
- string ldf = frags[0].ptx_elt_type
- # ":" # frags[0].frag
- ;
- string t = frags[0].ptx_elt_type;
+ string layout = layout_a # ":" # layout_b;
+ string a_type = frags[0].ptx_elt_type;
+ string b_type = frags[1].ptx_elt_type;
+ string c_type = frags[2].ptx_elt_type;
+ string d_type = frags[3].ptx_elt_type;
+ string geom = frags[0].geom;
// gcd is a shortcut used to identify instructions that depend on
- // geom+frag_c+frag_d. Not all instances of this class have all fragments
- // specified. If there are not enough fragments, the tail evaluates to '?'.
- string gcd = frags[0].geom
- # ":"
- # !if(!eq(!size(frags), 4),
- frags[2].ptx_elt_type # frags[3].ptx_elt_type,
- "?");
+ // geom+frag_c+frag_d.
+ string gcd = geom # ":" # c_type # d_type;
bit ret = !cond(
- // Sub-int MMA only supports fixed A/B layout.
- // b1 does not support .satf.
- !eq(mma#":"#satf, "b1:row:col:0") : true,
- // mma.m8n8k4 has no .satf modifier.
- !and(!eq(frags[0].geom, "m8n8k4"),
- !ne(satf, 0)): false,
-
- // mma.m8n8k4 has no C=f32 D=f16 variant.
+
+ // Limit satf to valid types
+ !and(!eq(satf, 1),
+ !ne(a_type, "s8"),
+ !ne(a_type, "u8"),
+ !ne(a_type, "s4"),
+ !ne(a_type, "u4")): false,
+
+ // m8n8k4 has no C=f32 D=f16 variant.
!eq(gcd, "m8n8k4:f32f16"): false,
- !eq(mma, "s4:row:col") : true,
- !eq(mma, "u4:row:col") : true,
- !eq(mma, "s4:row:col") : true,
- !eq(mma, "u4:row:col") : true,
- // Sub-int load/stores have fixed layout for A and B.
- !and(!eq(layout_b, "-"), // It's a Load or Store op
- !or(!eq(ld, "b1:a:row"),
- !eq(ld, "b1:b:col"),
- !eq(ldf, "b1:c"),
- !eq(ldf, "b1:d"),
- !eq(ld, "s4:a:row"),
- !eq(ld, "s4:b:col"),
- !eq(ldf, "s4:c"),
- !eq(ldf, "s4:d"),
- !eq(ld, "u4:a:row"),
- !eq(ld, "u4:b:col"),
- !eq(ldf, "u4:c"),
- !eq(ldf, "u4:d"))) : true,
- // All other sub-int ops are not supported.
- !eq(t, "b1") : false,
- !eq(t, "s4") : false,
- !eq(t, "u4") : false,
- // All other (non sub-int) are OK.
+
+ // only m8n8k4 for f16 does not require row:col layout
+ !and(!ne(layout, "row:col"),
+ !or(!ne(geom, "m8n8k4"),
+ !ne(a_type, "f16"))) : false,
+
+ // m16n8k8 requires A and B to be the same type and C and D to be the same
+ // type.
+ !and(!eq(geom, "m16n8k8"),
+ !or(!ne(a_type, b_type),
+ !ne(c_type, d_type))): false,
+
+ // m16n8k8 requires C and D to be the same type.
+ !and(!eq(geom, "m16n8k8"),
+ !ne(c_type, d_type)): false,
+
+ // All other are OK.
true: true
);
}
@@ -4271,36 +4449,59 @@ class NVVM_WMMA_ST<WMMA_REGS Frag, string Layout, int WithStride>
foreach layout = ["row", "col"] in {
foreach stride = [0, 1] in {
foreach frag = NVVM_MMA_OPS.all_ld_ops in
- if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
+ if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
def WMMA_NAME_LDST<"load", frag, layout, stride>.record
: NVVM_WMMA_LD<frag, layout, stride>;
foreach frag = NVVM_MMA_OPS.all_st_ops in
- if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
+ if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
def WMMA_NAME_LDST<"store", frag, layout, stride>.record
: NVVM_WMMA_ST<frag, layout, stride>;
}
}
// WMMA.MMA
-class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite,
+class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite, string rnd,
WMMA_REGS A, WMMA_REGS B,
WMMA_REGS C, WMMA_REGS D>
: Intrinsic<D.regs,
!listconcat(A.regs, B.regs, C.regs),
[IntrNoMem],
- WMMA_NAME_MMA<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
+ WMMA_NAME<ALayout, BLayout, Satfinite, rnd, A, B, C, D>.llvm>;
+
+foreach layout_a = ["row", "col"] in {
+ foreach layout_b = ["row", "col"] in {
+ foreach satf = [0, 1] in {
+ foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
+ foreach op = NVVM_MMA_OPS.all_wmma_ops in {
+ if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
+ def WMMA_NAME<layout_a, layout_b, satf, rnd,
+ op[0], op[1], op[2], op[3]>.record
+ : NVVM_WMMA_MMA<layout_a, layout_b, satf, rnd,
+ op[0], op[1], op[2], op[3]>;
+ }
+ } // op
+ } // rnd
+ } // satf
+ } // layout_b
+} // layout_a
+
+// MMA
+class NVVM_MMA<string ALayout, string BLayout, int Satfinite,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+ : Intrinsic<D.regs,
+ !listconcat(A.regs, B.regs, C.regs),
+ [IntrNoMem],
+ MMA_NAME<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
foreach layout_a = ["row", "col"] in {
foreach layout_b = ["row", "col"] in {
foreach satf = [0, 1] in {
foreach op = NVVM_MMA_OPS.all_mma_ops in {
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
- def WMMA_NAME_MMA<layout_a, layout_b, satf,
- op[0], op[1], op[2], op[3]>.record
- : NVVM_WMMA_MMA<layout_a, layout_b, satf,
- op[0], op[1], op[2], op[3]>;
+ def MMA_NAME<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>.record
+ : NVVM_MMA<layout_a, layout_b, satf, op[0], op[1], op[2], op[3]>;
}
- }
+ } // op
} // satf
} // layout_b
} // layout_a
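
The nested foreach loops above enumerate every layout/modifier combination before NVVM_WMMA_SUPPORTED and NVVM_MMA_SUPPORTED prune the invalid ones. A quick way to see the size of that candidate space per op (a standalone sketch, not part of the patch):

    from itertools import product

    # WMMA: 2 layouts for A x 2 for B x 2 satf values x 5 rounding modes = 40
    # candidate variants per op; MMA omits the rounding dimension, so 8 per op.
    wmma_combos = list(product(["row", "col"], ["row", "col"], [0, 1],
                               ["", "rn", "rz", "rm", "rp"]))
    mma_combos = list(product(["row", "col"], ["row", "col"], [0, 1]))
    print(len(wmma_combos), len(mma_combos))  # 40 8
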
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d431f20d066f6..d4842c953ce7a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3490,6 +3490,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_bf16_row_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
@@ -3497,7 +3501,11 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
- case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row: {
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_bf16_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v2i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3515,6 +3523,14 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_col_stride:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_a_bf16_row_stride:
+ case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col:
+ case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_col_stride:
+ case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row:
+ case Intrinsic::nvvm_wmma_m16n16k8_load_a_tf32_row_stride:
case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
@@ -3523,7 +3539,15 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
- case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row: {
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_col_stride:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_bf16_row_stride:
+ case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
+ case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
+ case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
+ case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3603,7 +3627,11 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col:
case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row:
case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride:
- case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride: {
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride:
+ case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col:
+ case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row:
+ case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_col_stride:
+ case Intrinsic::nvvm_wmma_m16n16k8_load_c_f32_row_stride: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
@@ -3613,6 +3641,16 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_bf16_row_stride:
+
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_bf16_row_stride:
+
case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
@@ -3651,6 +3689,37 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
+ case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col:
+ case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_col_stride:
+ case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row:
+ case Intrinsic::nvvm_wmma_m8n8k4_load_a_f64_row_stride:
+
+ case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col:
+ case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_col_stride:
+ case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row:
+ case Intrinsic::nvvm_wmma_m8n8k4_load_b_f64_row_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::f64;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOLoad;
+ Info.align = Align(8);
+ return true;
+ }
+
+ case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col:
+ case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_col_stride:
+ case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row:
+ case Intrinsic::nvvm_wmma_m8n8k4_load_c_f64_row_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v2f64;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOLoad;
+ Info.align = Align(16);
+ return true;
+ }
+
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
@@ -3683,7 +3752,11 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col:
case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row:
case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride:
- case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride: {
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride:
+ case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col:
+ case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row:
+ case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_col_stride:
+ case Intrinsic::nvvm_wmma_m16n16k8_store_d_f32_row_stride: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v8f32;
Info.ptrVal = I.getArgOperand(0);
@@ -3731,6 +3804,19 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
+ case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col:
+ case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_col_stride:
+ case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row:
+ case Intrinsic::nvvm_wmma_m8n8k4_store_d_f64_row_stride: {
+ Info.opc = ISD::INTRINSIC_VOID;
+ Info.memVT = MVT::v2f64;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOStore;
+ Info.align = Align(16);
+ return true;
+ }
+
case Intrinsic::nvvm_atomic_load_inc_32:
case Intrinsic::nvvm_atomic_load_dec_32:
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 5622d5a6fdac5..ab93bf16d4919 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -144,6 +144,7 @@ def hasPTX60 : Predicate<"Subtarget->getPTXVersion() >= 60">;
def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">;
def hasPTX63 : Predicate<"Subtarget->getPTXVersion() >= 63">;
def hasPTX64 : Predicate<"Subtarget->getPTXVersion() >= 64">;
+def hasPTX65 : Predicate<"Subtarget->getPTXVersion() >= 65">;
def hasPTX70 : Predicate<"Subtarget->getPTXVersion() >= 70">;
def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 1aaa9f0dd127d..798538410b104 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1943,21 +1943,21 @@ multiclass VLDU_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
!strconcat("ldu.global.", TyStr), []>;
}
-multiclass VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
+multiclass VLDU_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins Int32Regs:$src),
+ regclass:$dst4), (ins Int32Regs:$src),
!strconcat("ldu.global.", TyStr), []>;
def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins Int64Regs:$src),
+ regclass:$dst4), (ins Int64Regs:$src),
!strconcat("ldu.global.", TyStr), []>;
def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins MEMri:$src),
+ regclass:$dst4), (ins MEMri:$src),
!strconcat("ldu.global.", TyStr), []>;
def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins MEMri64:$src),
+ regclass:$dst4), (ins MEMri64:$src),
!strconcat("ldu.global.", TyStr), []>;
def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins imemAny:$src),
+ regclass:$dst4), (ins imemAny:$src),
!strconcat("ldu.global.", TyStr), []>;
}
@@ -1997,7 +1997,7 @@ defm INT_PTX_LDU_G_v4f32_ELE
//-----------------------------------
-// Support for ldg on sm_35 or later
+// Support for ldg on sm_35 or later
//-----------------------------------
// Don't annotate ld.global.nc as mayLoad, because these loads go through the
@@ -2045,7 +2045,7 @@ defm INT_PTX_LDG_GLOBAL_p64
// vector
-// Elementized vector ldg
+// Elementized vector ldg
multiclass VLDG_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2),
(ins Int32Regs:$src),
@@ -2064,21 +2064,21 @@ multiclass VLDG_G_ELE_V2<string TyStr, NVPTXRegClass regclass> {
!strconcat("ld.global.nc.", TyStr), []>;
}
-multiclass VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
+multiclass VLDG_G_ELE_V4<string TyStr, NVPTXRegClass regclass> {
def _areg32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins Int32Regs:$src),
+ regclass:$dst4), (ins Int32Regs:$src),
!strconcat("ld.global.nc.", TyStr), []>;
def _areg64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins Int64Regs:$src),
+ regclass:$dst4), (ins Int64Regs:$src),
!strconcat("ld.global.nc.", TyStr), []>;
def _ari32: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins MEMri:$src),
+ regclass:$dst4), (ins MEMri:$src),
!strconcat("ld.global.nc.", TyStr), []>;
def _ari64: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins MEMri64:$src),
+ regclass:$dst4), (ins MEMri64:$src),
!strconcat("ld.global.nc.", TyStr), []>;
def _avar: NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3,
- regclass:$dst4), (ins imemAny:$src),
+ regclass:$dst4), (ins imemAny:$src),
!strconcat("ld.global.nc.", TyStr), []>;
}
@@ -7568,12 +7568,15 @@ def INT_PTX_SREG_WARPSIZE :
// In addition to target-independent fields provided by WMMA_REGS, it adds
// the fields commonly used to implement specific PTX instruction -- register
// types and names, constraints, parts of assembly, etc.
-class WMMA_REGINFO<WMMA_REGS r>
+class WMMA_REGINFO<WMMA_REGS r, string op>
: WMMA_REGS<r.geom, r.frag, r.ptx_elt_type> {
// NVPTX register types used to carry fragment data.
NVPTXRegClass regclass = !cond(
!eq(ptx_elt_type, "f16") : Float16x2Regs,
!eq(ptx_elt_type, "f32") : Float32Regs,
+ !eq(ptx_elt_type, "f64") : Float64Regs,
+ !eq(ptx_elt_type, "bf16") : Int32Regs,
+ !eq(ptx_elt_type, "tf32") : Int32Regs,
!eq(ptx_elt_type, "s32") : Int32Regs,
!eq(ptx_elt_type, "s8") : Int32Regs,
!eq(ptx_elt_type, "u8") : Int32Regs,
@@ -7602,6 +7605,9 @@ class WMMA_REGINFO<WMMA_REGS r>
!or(!eq(ptx_elt_type, "f16"),
!eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX60],
+ !and(!eq(geom,"m8n8k4"),
+ !eq(ptx_elt_type, "f64")) : [hasSM80, hasPTX70],
+
// fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
!and(!or(!eq(geom, "m8n32k16"),
!eq(geom, "m32n8k16")),
@@ -7616,11 +7622,46 @@ class WMMA_REGINFO<WMMA_REGS r>
!eq(ptx_elt_type, "s8"),
!eq(ptx_elt_type, "s32"))) : [hasSM72, hasPTX63],
- // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
- !or(!eq(geom,"m8n8k128"),
- !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63],
+ !and(!or(!eq(geom,"m16n16k16"),
+ !eq(geom,"m8n32k16"),
+ !eq(geom,"m32n8k16")),
+ !eq(ptx_elt_type, "bf16")) : [hasSM80, hasPTX70],
+
+ !and(!eq(geom,"m16n16k8"),
+ !eq(ptx_elt_type, "tf32")) : [hasSM80, hasPTX70],
+
+ !and(!eq(geom,"m16n16k8"),
+ !eq(ptx_elt_type, "f32")) : [hasSM80, hasPTX70],
+
+ // b1 -> s32 @ m8n8k128(b1)
+ !and(!ne(op,"mma"),
+ !eq(geom,"m8n8k128")) : [hasSM75, hasPTX63],
+
+ // u4/s4 -> s32 @ m8n8k32 (u4/s4)
+ !and(!ne(op,"mma"),
+ !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63],
+
+ !or(!eq(geom,"m16n8k8"),
+ !eq(geom,"m8n8k16")) : [hasSM75, hasPTX65],
- !eq(geom, "m8n8k4") : [hasSM70, hasPTX64]);
+ !and(!ne(ptx_elt_type,"f64"),
+ !eq(geom, "m8n8k4")) : [hasSM70, hasPTX64],
+
+ // mma m8n8k32 requires higher PTX version
+ !and(!eq(op,"mma"),
+ !eq(geom,"m8n8k32")) : [hasSM75, hasPTX65],
+
+ !and(!eq(ptx_elt_type,"f64"),
+ !eq(geom, "m8n8k4")) : [hasSM80, hasPTX70],
+
+ !and(!eq(op,"mma"),
+ !or(!eq(geom, "m16n8k16"),
+ !eq(geom, "m16n8k4"),
+ !eq(geom, "m16n8k32"),
+ !eq(geom, "m16n8k64"),
+ !eq(geom, "m8n8k128"),
+ !eq(geom, "m16n8k128"),
+ !eq(geom, "m16n8k256"))) : [hasSM80, hasPTX70]);
// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7744,11 +7785,11 @@ defset list<WMMA_INSTR> MMA_LDSTs = {
foreach space = [".global", ".shared", ""] in {
foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
foreach frag = NVVM_MMA_OPS.all_ld_ops in
- if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
- def : WMMA_LOAD<WMMA_REGINFO<frag>, layout, space, stride, addr>;
+ if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
+ def : WMMA_LOAD<WMMA_REGINFO<frag, "load">, layout, space, stride, addr>;
foreach frag = NVVM_MMA_OPS.all_st_ops in
- if NVVM_MMA_SUPPORTED<[frag], layout>.ret then
- def : WMMA_STORE_D<WMMA_REGINFO<frag>, layout, space, stride, addr>;
+ if NVVM_WMMA_LDST_SUPPORTED<frag, layout>.ret then
+ def : WMMA_STORE_D<WMMA_REGINFO<frag, "store">, layout, space, stride, addr>;
} // addr
} // space
} // stride
@@ -7758,46 +7799,84 @@ defset list<WMMA_INSTR> MMA_LDSTs = {
// WMMA.MMA
class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
- string ALayout, string BLayout, int Satfinite>
- : WMMA_INSTR<WMMA_NAME_MMA<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
- [FragA.Ins, FragB.Ins, FragC.Ins]>,
+ string ALayout, string BLayout, int Satfinite, string rnd>
+ : WMMA_INSTR<WMMA_NAME<ALayout, BLayout, Satfinite, rnd, FragA, FragB, FragC, FragD>.record,
+ [FragA.Ins, FragB.Ins, FragC.Ins]>,
// Requires does not seem to have effect on Instruction w/o Patterns.
// We set it here anyways and propagate to the Pat<> we construct below.
Requires<FragA.Predicates> {
let OutOperandList = FragD.Outs;
let InOperandList = !con(Args, (ins MmaCode:$ptx));
string TypeList = !cond(
- !eq(FragD.geom, "m8n8k4") : "." # FragD.ptx_elt_type
- # ".f16.f16."
- # FragC.ptx_elt_type,
- !eq(FragD.ptx_elt_type, "s32") : ".s32"
- # "." # FragA.ptx_elt_type
- # "." # FragB.ptx_elt_type
- # ".s32",
- 1: "." # FragD.ptx_elt_type # "." # FragC.ptx_elt_type,
+ !eq(FragA.ptx_elt_type, "f16") : "." # FragD.ptx_elt_type
+ # "." # FragC.ptx_elt_type,
+ 1: "." # FragD.ptx_elt_type
+ # "." # FragA.ptx_elt_type
+ # "." # FragB.ptx_elt_type
+ # "." # FragC.ptx_elt_type,
);
- let AsmString = !if(!eq(FragA.geom, "m8n8k4"),
- "mma.sync.aligned.m8n8k4"
- # "." # ALayout
- # "." # BLayout
- # TypeList # "\n\t\t"
- # FragD.regstring # ",\n\t\t"
- # FragA.regstring # ",\n\t\t"
- # FragB.regstring # ",\n\t\t"
- # FragC.regstring # ";",
- "wmma.mma"
- # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
- # ".sync"
- # "${ptx:aligned}"
- # "." # ALayout
- # "." # BLayout
- # "." # FragA.geom
- # TypeList
- # !if(Satfinite, ".satfinite", "") # "\n\t\t"
- # FragD.regstring # ",\n\t\t"
- # FragA.regstring # ",\n\t\t"
- # FragB.regstring # ",\n\t\t"
- # FragC.regstring # ";");
+ let AsmString = "wmma.mma"
+ # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
+ # ".sync"
+ # "${ptx:aligned}"
+ # "." # ALayout
+ # "." # BLayout
+ # "." # FragA.geom
+ # !if(!ne(rnd, ""), !strconcat(".", rnd), "")
+ # TypeList
+ # !if(Satfinite, ".satfinite", "") # "\n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ";";
+}
+
+defset list<WMMA_INSTR> WMMAs = {
+ foreach layout_a = ["row", "col"] in {
+ foreach layout_b = ["row", "col"] in {
+ foreach satf = [0, 1] in {
+ foreach rnd = ["", "rn", "rz", "rm", "rp"] in {
+ foreach op = NVVM_MMA_OPS.all_wmma_ops in {
+ if NVVM_WMMA_SUPPORTED<op, layout_a, layout_b, satf, rnd>.ret then {
+ def : WMMA_MMA<WMMA_REGINFO<op[0], "wmma.mma">,
+ WMMA_REGINFO<op[1], "wmma.mma">,
+ WMMA_REGINFO<op[2], "wmma.mma">,
+ WMMA_REGINFO<op[3], "wmma.mma">,
+ layout_a, layout_b, satf, rnd>;
+ }
+ } // op
+ } // rnd
+ } // satf
+ } // layout_b
+ } // layout_a
+} // defset
+
+// MMA
+class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string ALayout, string BLayout, int Satfinite>
+ : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
+ [FragA.Ins, FragB.Ins, FragC.Ins]>,
+ // Requires does not seem to have effect on Instruction w/o Patterns.
+ // We set it here anyways and propagate to the Pat<> we construct below.
+ Requires<FragA.Predicates> {
+ let OutOperandList = FragD.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ string TypeList = "." # FragD.ptx_elt_type
+ # "." # FragA.ptx_elt_type
+ # "." # FragB.ptx_elt_type
+ # "." # FragC.ptx_elt_type;
+ let AsmString = "mma.sync.aligned."
+ # FragA.geom
+ # "." # ALayout
+ # "." # BLayout
+ # !if(Satfinite, ".satfinite", "")
+ # TypeList
+ # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "") # "\n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ";";
}
defset list<WMMA_INSTR> MMAs = {
@@ -7806,11 +7885,11 @@ defset list<WMMA_INSTR> MMAs = {
foreach satf = [0, 1] in {
foreach op = NVVM_MMA_OPS.all_mma_ops in {
if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
- def : WMMA_MMA<WMMA_REGINFO<op[0]>,
- WMMA_REGINFO<op[1]>,
- WMMA_REGINFO<op[2]>,
- WMMA_REGINFO<op[3]>,
- layout_a, layout_b, satf>;
+ def : MMA<WMMA_REGINFO<op[0], "mma">,
+ WMMA_REGINFO<op[1], "mma">,
+ WMMA_REGINFO<op[2], "mma">,
+ WMMA_REGINFO<op[3], "mma">,
+ layout_a, layout_b, satf>;
}
} // op
} // satf
@@ -7822,12 +7901,12 @@ defset list<WMMA_INSTR> MMAs = {
// Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
// dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
// the instruction record.
-class WMMA_PAT<WMMA_INSTR wi>
+class MMA_PAT<WMMA_INSTR wi>
: Pat<wi.IntrinsicPattern,
!con(!foreach(tmp, wi.Args, !subst(ins, wi, tmp)),
(wi ptx.version))>,
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, MMA_LDSTs) in
- def : WMMA_PAT<mma>;
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs) in
+ def : MMA_PAT<mma>;
diff --git a/llvm/test/CodeGen/NVPTX/lit.local.cfg b/llvm/test/CodeGen/NVPTX/lit.local.cfg
index 2cb98eb371b21..8354800109ebd 100644
--- a/llvm/test/CodeGen/NVPTX/lit.local.cfg
+++ b/llvm/test/CodeGen/NVPTX/lit.local.cfg
@@ -1,2 +1,3 @@
if not 'NVPTX' in config.root.targets:
config.unsupported = True
+config.suffixes.add('.py')
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 8c140c4d93108..8a808bd377eb9 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -6,7 +6,7 @@
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
-# RUN: --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA
+# RUN: --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
# RUN: | FileCheck %t-ptx60-sm_70.ll
@@ -15,7 +15,7 @@
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
-# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA
+# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
# RUN: | FileCheck %t-ptx61-sm_70.ll
@@ -24,7 +24,7 @@
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
-# RUN: --check-prefixes=INTRINSICS,NOSUBINT,NOMMA
+# RUN: --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
# RUN: | FileCheck %t-ptx63-sm_72.ll
@@ -33,7 +33,7 @@
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
-# RUN: --check-prefixes=INTRINSICS,NOMMA
+# RUN: --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
# RUN: | FileCheck %t-ptx63-sm_75.ll
@@ -42,10 +42,28 @@
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
-# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT
+# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT
# RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
# RUN: | FileCheck %t-ptx64-sm_70.ll
+# Check all variants of instructions supported by PTX65 on SM75+
+# RUN: python %s --ptx=65 --gpu-arch=75 > %t-ptx65-sm_75.ll
+# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
+# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA
+# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
+# RUN: --check-prefixes=INTRINSICS
+# RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
+# RUN: | FileCheck %t-ptx65-sm_75.ll
+
+# Check all variants of instructions supported by PTX70 on SM80+
+# RUN: python %s --ptx=70 --gpu-arch=80 > %t-ptx70-sm_80.ll
+# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
+# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX70MMA
+# RUN: FileCheck %t-ptx70-sm_80.ll < %t-ptx70-sm_80.ll \
+# RUN: --check-prefixes=INTRINSICS
+# RUN: llc < %t-ptx70-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 \
+# RUN: | FileCheck %t-ptx70-sm_80.ll
+
from __future__ import print_function
import argparse
@@ -56,19 +74,23 @@ class MMAType:
def __init__(self, ptx_type):
self.ptx_type = ptx_type
self.llvm_type = {
- "f16" : "<2 x half>",
- "f32" : "float",
- "s32" : "i32",
- "s8" : "i32",
- "u8" : "i32",
- "s4" : "i32",
- "u4" : "i32",
- "b1" : "i32",
+ "f16" : "<2 x half>",
+ "f32" : "float",
+ "f64" : "double",
+ "s32" : "i32",
+ "s8" : "i32",
+ "u8" : "i32",
+ "s4" : "i32",
+ "u4" : "i32",
+ "b1" : "i32",
+ "bf16" : "i32",
+ "tf32" : "i32",
}[ptx_type];
self.ptx_reg_pattern = {
"f16" : "%hh[0-9]+",
"f32" : "%f[0-9]+",
+ "f64" : "%fd[0-9]+",
}.get(ptx_type, "%r[0-9]+")
def __repr__(self):
@@ -78,16 +100,8 @@ class MMAFrag:
def __init__(self, geom, frag, ptx_elt_type):
self.geom = geom
self.frag = frag
- self.is_mma = True if geom == "m8n8k4" else False;
self.mma_type = MMAType(ptx_elt_type);
self.nregs = {
- "a:f16" : 2 if self.is_mma else 8,
- "b:f16" : 2 if self.is_mma else 8,
- "c:f16" : 4,
- "d:f16" : 4,
- "c:f32" : 8,
- "d:f32" : 8,
- }.get("%s:%s" % (frag, ptx_elt_type), {
# u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
"m16n16k16:a:u8" : 2,
"m16n16k16:a:s8" : 2,
@@ -110,18 +124,123 @@ def __init__(self, geom, frag, ptx_elt_type):
"m32n8k16:c:s32" : 8,
"m32n8k16:d:s32" : 8,
- # u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
- "m8n8k128:a:b1" : 1,
+ "m8n8k16:a:u8": 1,
+ "m8n8k16:a:s8": 1,
+ "m8n8k16:b:u8": 1,
+ "m8n8k16:b:s8": 1,
+ "m8n8k16:c:s32": 2,
+ "m8n8k16:d:s32": 2,
+
+ "m16n8k16:a:u8": 2,
+ "m16n8k16:a:s8": 2,
+ "m16n8k16:b:u8": 1,
+ "m16n8k16:b:s8": 1,
+ "m16n8k16:c:s32": 4,
+ "m16n8k16:d:s32": 4,
+
+ "m16n8k32:a:u8": 4,
+ "m16n8k32:a:s8": 4,
+ "m16n8k32:b:u8": 2,
+ "m16n8k32:b:s8": 2,
+ "m16n8k32:c:s32": 4,
+ "m16n8k32:d:s32": 4,
+
+ # u4/s4 -> s32 @ m8n8k32 (u4/s4)
"m8n8k32:a:u4" : 1,
"m8n8k32:a:s4" : 1,
- "m8n8k128:b:b1" : 1,
"m8n8k32:b:u4" : 1,
"m8n8k32:b:s4" : 1,
- "m8n8k128:c:s32" : 2,
- "m8n8k128:d:s32" : 2,
"m8n8k32:c:s32" : 2,
"m8n8k32:d:s32" : 2,
- }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), None));
+
+ "m16n8k32:a:u4" : 2,
+ "m16n8k32:a:s4" : 2,
+ "m16n8k32:b:u4" : 1,
+ "m16n8k32:b:s4" : 1,
+ "m16n8k32:c:s32" : 4,
+ "m16n8k32:d:s32" : 4,
+
+ "m16n8k64:a:u4" : 4,
+ "m16n8k64:a:s4" : 4,
+ "m16n8k64:b:u4" : 2,
+ "m16n8k64:b:s4" : 2,
+ "m16n8k64:c:s32" : 4,
+ "m16n8k64:d:s32" : 4,
+
+ # b1 -> s32 @ m8n8k128(b1)
+ "m8n8k128:a:b1" : 1,
+ "m8n8k128:b:b1" : 1,
+ "m8n8k128:c:s32" : 2,
+ "m8n8k128:d:s32" : 2,
+
+ "m16n8k128:a:b1" : 2,
+ "m16n8k128:b:b1" : 1,
+ "m16n8k128:c:s32" : 4,
+ "m16n8k128:d:s32" : 4,
+
+ "m16n8k256:a:b1" : 4,
+ "m16n8k256:b:b1" : 2,
+ "m16n8k256:c:s32" : 4,
+ "m16n8k256:d:s32" : 4,
+
+ # bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
+ "m16n16k16:a:bf16" : 4,
+ "m16n16k16:b:bf16" : 4,
+ "m8n32k16:a:bf16" : 2,
+ "m8n32k16:b:bf16" : 8,
+ "m32n8k16:a:bf16" : 8,
+ "m32n8k16:b:bf16" : 2,
+
+ "m16n8k16:a:bf16" : 4,
+ "m16n8k16:b:bf16" : 2,
+ "m16n8k16:c:f32" : 4,
+ "m16n8k16:d:f32" : 4,
+ "m16n8k8:a:bf16" : 2,
+ "m16n8k8:b:bf16" : 1,
+ "m16n8k8:c:f32" : 4,
+ "m16n8k8:d:f32" : 4,
+
+ "m8n8k4:a:f64" : 1,
+ "m8n8k4:b:f64" : 1,
+ "m8n8k4:c:f64" : 2,
+ "m8n8k4:d:f64" : 2,
+
+ # tf32 -> s32 @ m16n16k8
+ "m16n16k8:a:tf32" : 4,
+ "m16n16k8:b:tf32" : 4,
+
+ "m16n8k4:a:tf32" : 2,
+ "m16n8k4:b:tf32" : 1,
+ "m16n8k4:c:f32" : 4,
+ "m16n8k4:d:f32" : 4,
+ "m16n8k8:a:tf32" : 4,
+ "m16n8k8:b:tf32" : 2,
+ "m16n8k8:c:f32" : 4,
+ "m16n8k8:d:f32" : 4,
+
+ "m8n8k4:a:f16": 2,
+ "m8n8k4:b:f16": 2,
+ "m16n8k8:a:f16": 2,
+ "m16n8k8:b:f16": 1,
+ "m16n8k8:c:f16": 2,
+ "m16n8k8:d:f16": 2,
+ "m16n8k8:c:f32": 4,
+ "m16n8k8:d:f32": 4,
+ "m16n8k16:a:f16": 4,
+ "m16n8k16:b:f16": 2,
+ "m16n8k16:c:f16": 2,
+ "m16n8k16:d:f16": 2,
+ "m16n8k16:c:f32": 4,
+ "m16n8k16:d:f32": 4,
+ }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), {
+ # All other FP shape/fragment/type combinations have the same size
+ "a:f16" : 8,
+ "b:f16" : 8,
+ "c:f16" : 4,
+ "d:f16" : 4,
+ "c:f32" : 8,
+ "d:f32" : 8,
+ }.get("%s:%s" % (frag, ptx_elt_type), None))
assert(self.nregs);
def __repr__(self):
@@ -153,9 +272,13 @@ def make_ldst_ops(geoms, frags, types):
return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
in product(geoms, frags, types)]
-def get_mma_ops():
- return (make_mma_ops(["m8n8k4"],
- ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
+def get_wmma_ops():
+ return (make_mma_ops(["m16n16k8"],
+ ["tf32"], [], ["f32"], []) +
+ make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+ ["bf16"], [], ["f32"], []) +
+ make_mma_ops(["m8n8k4"],
+ ["f64"], [], ["f64"], []) +
make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
@@ -164,20 +287,38 @@ def get_mma_ops():
["s4", "u4"], [], ["s32"], []) +
make_mma_ops(["m8n8k128"],
["b1"], [], ["s32"], []))
+
+def get_mma_ops():
+ return (make_mma_ops(["m8n8k4"],
+ ["f64"], [], ["f64"], []) +
+ make_mma_ops(["m16n8k4", "m16n8k8"],
+ ["tf32"], [], ["f32"], []) +
+ make_mma_ops(["m16n8k16", "m16n8k8"],
+ ["bf16"], [], ["f32"], []) +
+ make_mma_ops(["m8n8k4", "m16n8k8", "m16n8k16"],
+ ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
+ make_mma_ops(["m8n8k16", "m16n8k16", "m16n8k32"],
+ ["s8", "u8"], ["s8", "u8"], ["s32"], []) +
+ make_mma_ops(["m8n8k32", "m16n8k32", "m16n8k64"],
+ ["s4", "u4"], ["s4", "u4"], ["s32"], []) +
+ make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"],
+ ["b1"], [], ["s32"], []))
+
def get_ldst_ops(kind):
ldst_ops = (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
- ["a", "b"], ["f16", "u8", "s8"]) +
+ ["a", "b"], ["f16", "u8", "s8", "bf16"]) +
make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
["c", "d"], ["f16", "f32", "s32"]) +
make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) +
make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
- make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]))
+ make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]) +
+ make_ldst_ops(["m8n8k4"], ["a", "b", "c", "d"], ["f64"]) +
+ make_ldst_ops(["m16n16k8"], ["a", "b"], ["tf32"]) +
+ make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
return [ x for x in ldst_ops if (x.frag == "d") == (kind == "store")]
-def is_geom_supported(geom):
+def is_wmma_geom_supported(geom):
# geometries for FP and ints.
- if geom == "m8n8k4":
- return ptx_version >= 64
if geom in ["m8n32k16", "m32n8k16"]:
return ptx_version >= 61
# geometries for sub-ints.
@@ -185,6 +326,21 @@ def is_geom_supported(geom):
return ptx_version >= 63 and gpu_arch >= 75
if geom == "m16n16k16":
return ptx_version >= 60
+ if geom == "m16n8k8":
+ return ptx_version >= 65
+ if geom in ["m16n16k8", "m8n8k4"]:
+ return ptx_version >= 70
+ assert(False) # Unexpected geometry.
+
+def is_mma_geom_supported(geom):
+ # geometries for FP and ints.
+ if geom == "m8n8k4":
+ return ptx_version >= 64
+ if geom in ["m16n8k8", "m8n8k16", "m8n8k32"]:
+ return ptx_version >= 65
+ if geom in ["m16n8k16", "m16n8k4", "m16n8k32", "m16n8k64", "m8n8k128",
+ "m16n8k128", "m16n8k256"]:
+ return ptx_version >= 70
assert(False) # Unexpected geometry.
def is_type_supported(ptx_type):
@@ -192,30 +348,63 @@ def is_type_supported(ptx_type):
return ptx_version >= 63 and gpu_arch >= 72
if ptx_type in ["s4", "u4", "b1"]:
return ptx_version >= 63 and gpu_arch >= 75
+ if ptx_type in ["bf16", "tf32", "f64"]:
+ return ptx_version >= 70
return ptx_version >= 60 and gpu_arch >= 70
+def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
+ if not (is_type_supported(op.a.mma_type.ptx_type)
+ and is_wmma_geom_supported(op.a.geom)):
+ return False
+
+ # rnd is only supported for FP64 WMMA
+ if rnd and op.a.mma_type.ptx_type != "f64":
+ return False
+
+ if satf:
+ # satfinite for floating points was removed in PTX 6.5
+ if op.a.mma_type.ptx_type == "f16" and ptx_version >= 65:
+ return False
+ if not op.a.mma_type.ptx_type in ["f16", "s8", "u8", "s4", "u4"]:
+ return False
+
+ # sub-integer require row/col layout.
+ if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
+ return layout_a == "row" and layout_b == "col"
+ return True
def is_mma_variant_supported(op, layout_a, layout_b, satf):
if not (is_type_supported(op.a.mma_type.ptx_type)
- and is_geom_supported(op.a.geom)):
+ and is_mma_geom_supported(op.a.geom)):
+ return False
+
+ if satf and not op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]:
+ return False
+
+ # If the type of C is f32 then so must the type of D
+ if (op.a.geom == "m8n8k4" and op.c.mma_type.ptx_type == "f32"
+ and op.d.mma_type.ptx_type != "f32"):
return False
- if op.a.geom == "m8n8k4":
- if satf:
+
+ # A and B type must be the same. C and D type must be the same
+ if (op.a.geom == "m16n8k8"
+ and (op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
+ or op.c.mma_type.ptx_type != op.d.mma_type.ptx_type)):
return False
- if op.c.mma_type.ptx_type == "f32":
- # If C is f32, D must be, too.
- return op.d.mma_type.ptx_type == "f32"
- # sub-integer require row/col layout, and no satf.
- if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
- if op.a.mma_type.ptx_type == "b1" and satf:
+ # C and D type must be the same
+ if (op.a.geom == "m16n8k16"
+ and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type):
return False
+
+ # Require row/col layout for all MMA except m8n8k4 on FP16
+ if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
return layout_a == "row" and layout_b == "col"
return True
def is_ldst_variant_supported(frag, layout):
if not (is_type_supported(frag.mma_type.ptx_type)
- and is_geom_supported(frag.geom)):
+ and is_wmma_geom_supported(frag.geom)):
return False
if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
# sub-integer require sm_75 and ptx63, row/col layout for a/b.
@@ -396,24 +585,37 @@ def gen_wmma_store_tests():
return generated_items
def mma_signature(op):
- if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
- # int and sub-int ops are identified by input type.
- return op.a.mma_type.ptx_type
- else:
- # the rest are FP ops identified by accumulator & result type.
+ if op.a.mma_type.ptx_type == "f16":
+ # FP16 ops identified by accumulator & result type.
return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
+ elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
+ # other ops are identified by input types.
+ return "%s.%s" % (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
+ else:
+ # if input types are the same, it only appears once.
+ return op.a.mma_type.ptx_type
def mma_ptx_signature(op):
- if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
- # int and sub-int instructions encode all four types as D.A.B.C
- return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
- if op.a.geom == "m8n8k4":
- return "%s.f16.f16.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
+ # Encode all four types as D.A.B.C
+ return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
+
+def wmma_signature(op):
+ if op.a.mma_type.ptx_type == "f16":
+ # FP16 ops identified by accumulator & result type.
+ return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
else:
- # the rest are FP instructions use D.C
+ # other ops are identified by input type.
+ return op.a.mma_type.ptx_type
+
+def wmma_ptx_signature(op):
+ if op.a.mma_type.ptx_type == "f16":
+ # FP16 instructions use D.C
return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
+ else:
+ # other instructions encode all four types as D.A.B.C
+ return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
-def gen_wmma_mma_tests():
+def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
mma_template = """
declare ${ret_ty} @${intrinsic}(
${args});
@@ -431,10 +633,61 @@ def gen_wmma_mma_tests():
ret ${ret_ty} %r;
}
"""
- wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
- wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}.${ptx_signature}${satf}"
- mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}.${intrinsic_signature}"
- mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}.${ptx_signature}"
+
+ test_params = params
+ test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+ test_params["function"] = test_params["intrinsic"].replace(".", "_")
+ test_params["instruction"] = Template(instruction_template).substitute(params)
+ test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
+ test_params["check_a"] = check_pattern(op.a)
+ test_params["check_b"] = check_pattern(op.b)
+ test_params["check_c"] = check_pattern(op.c)
+ test_params["check_d"] = check_pattern(op.d)
+ args = ",\n ".join(make_wmma_slice_args(frag)
+ for frag in (op.a, op.b, op.c))
+ test_params["args"] = args
+ print(Template(mma_template).substitute(test_params))
+ return (test_params["intrinsic"], test_params["instruction"])
+
+def gen_wmma_mma_tests():
+ wmma_intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}${rnd}.${intrinsic_signature}${satf}"
+ wmma_instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}${rnd}.${ptx_signature}${satf}"
+
+ generated_items=[]
+
+ for op, alayout, blayout, rnd, satf in product(
+ get_wmma_ops(),
+ ["row","col"],
+ ["row","col"],
+ [".rn", ".rz", ".rm", ".rp", ""],
+ [".satfinite", ""]):
+
+ if not is_wmma_variant_supported(op, alayout, blayout, rnd, satf):
+ continue
+
+ params = {
+ "aligned" : ".aligned" if ptx_version >= 63 else "",
+ "alayout" : alayout,
+ "blayout" : blayout,
+ "intrinsic_signature" : wmma_signature(op),
+ "ptx_signature" : wmma_ptx_signature(op),
+ "satf" : satf,
+ "rnd" : rnd,
+ "geom" : op.a.geom,
+ "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
+ }
+
+ intrinsic_template = wmma_intrinsic_template
+ instruction_template = wmma_instruction_template
+
+ generated_items.append(common_mma_test_gen(params, op,
+ intrinsic_template, instruction_template))
+
+ return generated_items
+
+def gen_mma_tests():
+ mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
+ mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${mma_variant}"
generated_items=[]
@@ -458,28 +711,11 @@ def gen_wmma_mma_tests():
"mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
}
- if op.a.geom == "m8n8k4":
- intrinsic_template = mma_intrinsic_template
- instruction_template = mma_instruction_template
- else:
- intrinsic_template = wmma_intrinsic_template
- instruction_template = wmma_instruction_template
+ intrinsic_template = mma_intrinsic_template
+ instruction_template = mma_instruction_template
- test_params = params
- test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
- test_params["function"] = test_params["intrinsic"].replace(".", "_")
- test_params["instruction"] = Template(instruction_template).substitute(params)
- test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
- test_params["check_a"] = check_pattern(op.a)
- test_params["check_b"] = check_pattern(op.b)
- test_params["check_c"] = check_pattern(op.c)
- test_params["check_d"] = check_pattern(op.d)
- args = ",\n ".join(make_wmma_slice_args(frag)
- for frag in (op.a, op.b, op.c))
- test_params["args"] = args
- print(Template(mma_template).substitute(test_params))
- generated_items.append((test_params["intrinsic"],
- test_params["instruction"]))
+ generated_items.append(common_mma_test_gen(params, op,
+ intrinsic_template, instruction_template))
return generated_items
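
To see how the two template strings in gen_mma_tests() expand, here is one hand-picked PTX 7.0 variant run through string.Template (a standalone sketch; the parameter values are chosen by hand rather than produced by the generator's product() loop):

    from string import Template

    # Same template strings as gen_mma_tests() above.
    mma_intrinsic_template = "llvm.nvvm.mma.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
    mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${mma_variant}"

    params = {
        "aligned": ".aligned",
        "geom": "m16n8k16",
        "alayout": "row",
        "blayout": "col",
        "satf": "",
        "intrinsic_signature": "f16.f16",      # mma_signature(): D type . C type for f16 ops
        "ptx_signature": "f16.f16.f16.f16",    # mma_ptx_signature(): D.A.B.C
        "mma_variant": "",
    }

    print(Template(mma_intrinsic_template).substitute(params))
    # -> llvm.nvvm.mma.m16n8k16.row.col.f16.f16
    print(Template(mma_instruction_template).substitute(params))
    # -> mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16
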
@@ -497,6 +733,8 @@ def gen_check_unsupported_ops(items):
; NOINT-NOT: .{{s32|s8}}
; NOSUBINT-NOT: {{s4|u4|b1}}
; NOMMA-NOT: .m8n8k4.
+; NOALTFLOAT-NOT: .{{bf16|tf32}}
+; NODOUBLE-NOT: .f64
; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -543,10 +781,61 @@ def gen_check_unsupported_ops(items):
; SUBINT-DAG: m8n8k32.mma.{{.*}}.s4
; SUBINT-DAG: m8n8k128.mma.{{.*}}.b1
+; ALTFLOAT-DAG: m16n16k16.load.{{[ab].*}}.bf16.p
+; ALTFLOAT-DAG: m8n32k16.load.{{[ab].*}}.bf16.p
+; ALTFLOAT-DAG: m32n8k16.load.{{[ab].*}}.bf16.p
+; ALTFLOAT-DAG: m16n16k8.load.{{[ab].*}}.tf32.p
+; ALTFLOAT-DAG: m16n16k16.mma.{{.*}}.bf16
+; ALTFLOAT-DAG: m8n32k16.mma.{{.*}}.bf16
+; ALTFLOAT-DAG: m32n8k16.mma.{{.*}}.bf16
+; ALTFLOAT-DAG: m16n16k8.mma.{{.*}}.tf32
+
+; DOUBLE-DAG: m8n8k4.load.{{[abc].*}}.f64.p
+; DOUBLE-DAG: m8n8k4.store.d.{{.*}}.f64.p
+; DOUBLE-DAG: m8n8k4.mma.{{.*}}.f64
+
; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f32
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f16.f16
; MMA-DAG: mma.m8n8k4.{{.*}}.f32.f32
+
+; PTX65MMA-DAG: mma.m16n8k8.row.col.f16.f16
+; PTX65MMA-DAG: mma.m16n8k8.row.col.f32.f32
+; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.u8
+; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.s8
+; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.s8.u8
+; PTX65MMA-DAG: mma.m8n8k16.row.col{{.*}}.u8.s8
+; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.u4
+; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.s4
+; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
+; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4
+
+; PTX70MMA-DAG: mma.m8n8k4.row.col.f64
+; PTX70MMA-DAG: mma.m16n8k4.row.col.tf32
+; PTX70MMA-DAG: mma.m16n8k8.row.col.tf32
+; PTX70MMA-DAG: mma.m16n8k16.row.col.bf16
+; PTX70MMA-DAG: mma.m16n8k8.row.col.bf16
+; PTX70MMA-DAG: mma.m16n8k16.row.col.f16.f16
+; PTX70MMA-DAG: mma.m16n8k16.row.col.f32.f32
+; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.u8
+; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.s8
+; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.s8.u8
+; PTX70MMA-DAG: mma.m16n8k16.row.col{{.*}}.u8.s8
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.u8
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.s8
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s8.u8
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u8.s8
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.u4
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.s4
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.s4.u4
+; PTX70MMA-DAG: mma.m16n8k32.row.col{{.*}}.u4.s4
+; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.u4
+; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.s4
+; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.s4.u4
+; PTX70MMA-DAG: mma.m16n8k64.row.col{{.*}}.u4.s4
+; PTX70MMA-DAG: mma.m8n8k128.row.col.b1
+; PTX70MMA-DAG: mma.m16n8k128.row.col.b1
+; PTX70MMA-DAG: mma.m16n8k256.row.col.b1
;
""")
@@ -561,6 +850,7 @@ def gen_tests():
items = gen_wmma_load_tests()
items += gen_wmma_store_tests()
items += gen_wmma_mma_tests()
+ items += gen_mma_tests()
gen_check_unsupported_ops(items)
parser = argparse.ArgumentParser()