[llvm] r315601 - [NVPTX] Implemented wmma intrinsics and instructions.
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 12 11:27:55 PDT 2017
Author: tra
Date: Thu Oct 12 11:27:55 2017
New Revision: 315601
URL: http://llvm.org/viewvc/llvm-project?rev=315601&view=rev
Log:
[NVPTX] Implemented wmma intrinsics and instructions.
WMMA = "Warp Level Matrix Multiply-Accumulate".
These are the new instructions introduced in PTX6.0 and available
on sm_70 GPUs.
Differential Revision: https://reviews.llvm.org/D38645
Added:
llvm/trunk/test/CodeGen/NVPTX/wmma.py
Modified:
llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
Modified: llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td?rev=315601&r1=315600&r2=315601&view=diff
==============================================================================
--- llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td (original)
+++ llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td Thu Oct 12 11:27:55 2017
@@ -3869,4 +3869,150 @@ def int_nvvm_match_all_sync_i64p :
Intrinsic<[llvm_i64_ty, llvm_i1_ty], [llvm_i32_ty, llvm_i64_ty],
[IntrNoMem, IntrConvergent], "llvm.nvvm.match.all.sync.i64p">;
+//
+// WMMA instructions
+//
+
+// WMMA.LOAD
+class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Space,
+ string Type, LLVMType regty, int WithStride>
+ : Intrinsic<!if(!eq(Abc#Type,"cf16"),
+ [regty, regty, regty, regty],
+ [regty, regty, regty, regty,
+ regty, regty, regty, regty]),
+ !if(WithStride, [llvm_ptr_ty, llvm_i32_ty], [llvm_ptr_ty]),
+ [], // Properties must be set during instantiation.
+ "llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
+ #Space
+ #!if(WithStride,".stride","")
+ #"."#Type>;
+
+multiclass NVVM_WMMA_LD_ALST<string Abc, string Layout, string Space,
+ string Type, LLVMType regty> {
+ def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 1>;
+ def NAME : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 0>;
+}
+
+multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout,
+ string Type, LLVMType regty> {
+ defm _global: NVVM_WMMA_LD_ALST<Abc, Layout, ".global", Type, regty>;
+ defm _shared: NVVM_WMMA_LD_ALST<Abc, Layout, ".shared", Type, regty>;
+ defm NAME: NVVM_WMMA_LD_ALST<Abc, Layout, "", Type, regty>;
+}
+
+multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
+ defm _row: NVVM_WMMA_LD_ALT<Abc, "row", Type, regty>;
+ defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
+}
+
+// For some reason ReadOnly<N> and NoCapture<N> confuses tblgen if they are
+// passed to Intrinsic<> form inside of a multiclass. Setting them globally
+// outside of the multiclass works.
+let IntrProperties = [IntrReadMem, IntrArgMemOnly,
+ ReadOnly<0>, NoCapture<0>] in {
+ defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
+ defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
+ defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
+ defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
+}
+
+// WMMA.STORE.D
+class NVVM_WMMA_STD_LSTS<string Layout, string Space,
+ string Type, LLVMType regty, int WithStride,
+ // This is only used to create a typed empty array we
+ // need to pass to !if below.
+ list<LLVMType>Empty=[]>
+ : Intrinsic<[],
+ !listconcat(
+ [llvm_ptr_ty],
+ !if(!eq(Type,"f16"),
+ [regty, regty, regty, regty],
+ [regty, regty, regty, regty,
+ regty, regty, regty, regty]),
+ !if(WithStride, [llvm_i32_ty], Empty)),
+ [], // Properties must be set during instantiation.
+ "llvm.nvvm.wmma.store.d.sync."#Layout
+ #".m16n16k16"#Space
+ #!if(WithStride,".stride","")
+ #"."#Type>;
+
+multiclass NVVM_WMMA_STD_LST<string Layout, string Space,
+ string Type, LLVMType regty> {
+ def _stride: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 1>;
+ def NAME: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 0>;
+}
+
+multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
+ defm _global: NVVM_WMMA_STD_LST<Layout, ".global", Type, regty>;
+ defm _shared: NVVM_WMMA_STD_LST<Layout, ".shared", Type, regty>;
+ defm NAME: NVVM_WMMA_STD_LST<Layout, "", Type, regty>;
+}
+
+multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
+ defm _row: NVVM_WMMA_STD_LT<"row", Type, regty>;
+ defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
+}
+
+let IntrProperties = [IntrWriteMem, IntrArgMemOnly,
+ WriteOnly<0>, NoCapture<0>] in {
+ defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
+ defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
+}
+
+// WMMA.MMA
+class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,
+ string DType, LLVMType d_regty,
+ string CType, LLVMType c_regty,
+ string Satfinite = "">
+ : Intrinsic<!if(!eq(DType,"f16"),
+ [d_regty, d_regty, d_regty, d_regty],
+ [d_regty, d_regty, d_regty, d_regty,
+ d_regty, d_regty, d_regty, d_regty]),
+ !listconcat(
+ [// A
+ llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
+ llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
+ // B
+ llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
+ llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty],
+ !if(!eq(CType,"f16"),
+ [c_regty, c_regty, c_regty, c_regty],
+ [c_regty, c_regty, c_regty, c_regty,
+ c_regty, c_regty, c_regty, c_regty])),
+ [IntrNoMem],
+ "llvm.nvvm.wmma.mma.sync."#ALayout#"."#BLayout
+ #".m16n16k16."#DType#"."#CType#Satfinite>;
+
+multiclass NVVM_WMMA_MMA_ABDC<string ALayout, string BLayout,
+ string DType, LLVMType d_regty,
+ string CType, LLVMType c_regty> {
+ def NAME : NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
+ DType, d_regty,
+ CType, c_regty>;
+ def _satfinite: NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
+ DType, d_regty,
+ CType, c_regty,".satfinite">;
+}
+
+multiclass NVVM_WMMA_MMA_ABD<string ALayout, string BLayout,
+ string DType, LLVMType d_regty> {
+ defm _f16: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
+ "f16", llvm_v2f16_ty>;
+ defm _f32: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
+ "f32", llvm_float_ty>;
+}
+
+multiclass NVVM_WMMA_MMA_AB<string ALayout, string BLayout> {
+ defm _f16: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f16", llvm_v2f16_ty>;
+ defm _f32: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f32", llvm_float_ty>;
+}
+
+multiclass NVVM_WMMA_MMA_A<string ALayout> {
+ defm _col: NVVM_WMMA_MMA_AB<ALayout, "col">;
+ defm _row: NVVM_WMMA_MMA_AB<ALayout, "row">;
+}
+
+defm int_nvvm_wmma_mma_sync_col: NVVM_WMMA_MMA_A<"col">;
+defm int_nvvm_wmma_mma_sync_row: NVVM_WMMA_MMA_A<"row">;
+
} // let TargetPrefix = "nvvm"
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp?rev=315601&r1=315600&r2=315601&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp Thu Oct 12 11:27:55 2017
@@ -496,8 +496,318 @@ void NVPTXDAGToDAGISel::Select(SDNode *N
SelectCode(N);
}
+// Each instruction has four addressing variants. WMMA_VARIANTS() macro below
+// constructs an array indexed by WmmaVariant which getWmmaLdVariant() uses to
+// look up the intrinsic ID of particular variant.
+enum WmmaVariant {
+ WMMA_VARIANT_ARI64,
+ WMMA_VARIANT_ARI64_STRIDE,
+ WMMA_VARIANT_AVAR,
+ WMMA_VARIANT_AVAR_STRIDE,
+};
+
+// clang-format off
+#define WMMA_VARIANTS(base) \
+ {{ base##_ari64, base##_ari64_stride, base##_avar, base##_avar_stride }}
+// clang-format on
+
+static unsigned getWmmaLdVariant(WmmaVariant Variant, bool Stride,
+ const std::array<unsigned, 4> Variants) {
+ if (Stride) {
+ if (Variant == WMMA_VARIANT_ARI64)
+ Variant = WMMA_VARIANT_ARI64_STRIDE;
+ else if (Variant == WMMA_VARIANT_AVAR)
+ Variant = WMMA_VARIANT_AVAR_STRIDE;
+ }
+ return Variants[Variant];
+}
+
+static Optional<unsigned>
+getWmmaLdStOpcode(unsigned IntrinsicID,
+ WmmaVariant Variant = WMMA_VARIANT_ARI64) {
+ switch (IntrinsicID) {
+ default:
+ return None;
+ //
+ // WMMA_LOAD_A f16
+ //
+ case Intrinsic::nvvm_wmma_load_a_f16_col:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col));
+ case Intrinsic::nvvm_wmma_load_a_f16_row:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row));
+ case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col));
+ case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row));
+ case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared));
+ case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared));
+ case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared));
+ case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared));
+ case Intrinsic::nvvm_wmma_load_a_f16_col_global:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global));
+ case Intrinsic::nvvm_wmma_load_a_f16_row_global:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global));
+ case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global));
+ case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global));
+
+ //
+ // WMMA_LOAD_B f16
+ //
+ case Intrinsic::nvvm_wmma_load_b_f16_col:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col));
+ case Intrinsic::nvvm_wmma_load_b_f16_row:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row));
+ case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col));
+ case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row));
+ case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared));
+ case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared));
+ case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared));
+ case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared));
+ case Intrinsic::nvvm_wmma_load_b_f16_col_global:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global));
+ case Intrinsic::nvvm_wmma_load_b_f16_row_global:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global));
+ case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global));
+ case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global));
+
+ //
+ // WMMA_LOAD_C f16
+ //
+ case Intrinsic::nvvm_wmma_load_c_f16_col:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col));
+ case Intrinsic::nvvm_wmma_load_c_f16_row:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row));
+ case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col));
+ case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row));
+ case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared));
+ case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared));
+ case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared));
+ case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared));
+ case Intrinsic::nvvm_wmma_load_c_f16_col_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global));
+ case Intrinsic::nvvm_wmma_load_c_f16_row_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global));
+ case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global));
+ case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global));
+
+ //
+ // WMMA_LOAD_C f32
+ //
+ case Intrinsic::nvvm_wmma_load_c_f32_col:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col));
+ case Intrinsic::nvvm_wmma_load_c_f32_row:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row));
+ case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col));
+ case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row));
+ case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared));
+ case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared));
+ case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared));
+ case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared));
+ case Intrinsic::nvvm_wmma_load_c_f32_col_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global));
+ case Intrinsic::nvvm_wmma_load_c_f32_row_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global));
+ case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global));
+ case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global));
+
+ //
+ // WMMA_STORE_D f16
+ //
+ case Intrinsic::nvvm_wmma_store_d_f16_col:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col));
+ case Intrinsic::nvvm_wmma_store_d_f16_row:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row));
+ case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col));
+ case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row));
+ case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared));
+ case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared));
+ case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared));
+ case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared));
+ case Intrinsic::nvvm_wmma_store_d_f16_col_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global));
+ case Intrinsic::nvvm_wmma_store_d_f16_row_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global));
+ case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global));
+ case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global));
+
+ //
+ // WMMA_STORE_D f32
+ //
+ case Intrinsic::nvvm_wmma_store_d_f32_col:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col));
+ case Intrinsic::nvvm_wmma_store_d_f32_row:
+ return getWmmaLdVariant(Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row));
+ case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col));
+ case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
+ return getWmmaLdVariant(Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row));
+ case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared));
+ case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared));
+ case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared));
+ case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared));
+ case Intrinsic::nvvm_wmma_store_d_f32_col_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global));
+ case Intrinsic::nvvm_wmma_store_d_f32_row_global:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/false,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global));
+ case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global));
+ case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride:
+ return getWmmaLdVariant(
+ Variant, /*Stride=*/true,
+ WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global));
+ }
+}
+#undef WMMA_VARIANTS
+
bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
+ if (getWmmaLdStOpcode(IID))
+ return tryWMMA_LDST(N);
+
switch (IID) {
default:
return false;
@@ -719,6 +1029,39 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoCh
case Intrinsic::nvvm_match_all_sync_i64p:
SelectMatchAll(N);
return true;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32:
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite:
+ return tryWMMA_MMA(N);
}
}
@@ -3725,3 +4068,172 @@ unsigned NVPTXDAGToDAGISel::GetConvertOp
}
}
}
+
+bool NVPTXDAGToDAGISel::tryWMMA_LDST(SDNode *N) {
+ SDValue Chain = N->getOperand(0);
+ unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
+ SDValue Op1 = N->getOperand(2);
+ SDValue Addr, Offset, Base;
+ Optional<unsigned> Opcode;
+ SDLoc DL(N);
+ MemSDNode *MemSD = cast<MemIntrinsicSDNode>(N);
+ WmmaVariant Variant;
+ SmallVector<SDValue, 12> Ops;
+ bool isStore = N->getNumValues() == 1; // Store ops only return a chain.
+
+ if (SelectDirectAddr(Op1, Addr)) {
+ Variant = WMMA_VARIANT_AVAR;
+ Ops.push_back(Addr);
+ } else if (SelectADDRsi64(Op1.getNode(), Op1, Base, Offset) ||
+ SelectADDRri64(Op1.getNode(), Op1, Base, Offset)) {
+ Variant = WMMA_VARIANT_ARI64;
+ Ops.push_back(Base);
+ Ops.push_back(Offset);
+ } else {
+ Variant = WMMA_VARIANT_AVAR;
+ Ops.push_back(Op1);
+ }
+ unsigned NumOps = N->getNumOperands();
+ // Pass through the rest of the operands to the machine node.
+ for (unsigned i = 3; i < NumOps; ++i)
+ Ops.push_back(N->getOperand(i));
+ Ops.push_back(Chain);
+
+ Opcode = getWmmaLdStOpcode(IID, Variant);
+ if (!Opcode) {
+ llvm::errs() << "tryWMMALD - no Opcode.\n";
+ return false;
+ }
+
+ EVT MemVT = MemSD->getMemoryVT();
+ assert(MemVT.isVector() && "Expected vector return type.");
+
+ SDNode *MN;
+ if (isStore) {
+ MN = CurDAG->getMachineNode(Opcode.getValue(), DL, MVT::Other, Ops);
+ } else {
+ SmallVector<EVT, 9> InstVTs(MemVT.getVectorNumElements(),
+ MemSD->getValueType(0));
+ InstVTs.push_back(MVT::Other);
+ MN = CurDAG->getMachineNode(Opcode.getValue(), DL, InstVTs, Ops);
+ }
+
+ ReplaceNode(N, MN);
+ return true;
+}
+
+bool NVPTXDAGToDAGISel::tryWMMA_MMA(SDNode *N) {
+ unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
+ SDLoc DL(N);
+ unsigned Opc;
+
+ switch (IID) {
+ default:
+ return false;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16_satfinite;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32;
+ break;
+ case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite:
+ Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32_satfinite;
+ break;
+ }
+
+ SmallVector<SDValue, 24> Ops;
+ // Pass through operands and return value types to the machine node.
+ for (unsigned i = 1; i < N->getNumOperands(); ++i)
+ Ops.push_back(N->getOperand(i));
+ SmallVector<EVT, 8> InstVTs(N->getNumValues(), N->getValueType(0));
+ SDNode *MN = CurDAG->getMachineNode(Opc, DL, InstVTs, Ops);
+ ReplaceNode(N, MN);
+ return true;
+}
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h?rev=315601&r1=315600&r2=315601&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h Thu Oct 12 11:27:55 2017
@@ -74,6 +74,8 @@ private:
bool tryConstantFP16(SDNode *N);
bool SelectSETP_F16X2(SDNode *N);
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
+ bool tryWMMA_LDST(SDNode *N);
+ bool tryWMMA_MMA(SDNode *N);
inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp?rev=315601&r1=315600&r2=315601&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp Thu Oct 12 11:27:55 2017
@@ -3321,6 +3321,132 @@ bool NVPTXTargetLowering::getTgtMemIntri
switch (Intrinsic) {
default:
return false;
+ case Intrinsic::nvvm_wmma_load_a_f16_col:
+ case Intrinsic::nvvm_wmma_load_a_f16_row:
+ case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
+ case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
+ case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
+ case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
+ case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
+ case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
+ case Intrinsic::nvvm_wmma_load_a_f16_col_global:
+ case Intrinsic::nvvm_wmma_load_a_f16_row_global:
+ case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
+ case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
+ case Intrinsic::nvvm_wmma_load_b_f16_col:
+ case Intrinsic::nvvm_wmma_load_b_f16_row:
+ case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
+ case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
+ case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
+ case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
+ case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
+ case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
+ case Intrinsic::nvvm_wmma_load_b_f16_col_global:
+ case Intrinsic::nvvm_wmma_load_b_f16_row_global:
+ case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
+ case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v8f16;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.vol = false;
+ Info.readMem = true;
+ Info.writeMem = false;
+ Info.align = 16;
+ return true;
+ }
+
+ case Intrinsic::nvvm_wmma_load_c_f16_col:
+ case Intrinsic::nvvm_wmma_load_c_f16_row:
+ case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
+ case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
+ case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
+ case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
+ case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
+ case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
+ case Intrinsic::nvvm_wmma_load_c_f16_col_global:
+ case Intrinsic::nvvm_wmma_load_c_f16_row_global:
+ case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
+ case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v4f16;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.vol = false;
+ Info.readMem = true;
+ Info.writeMem = false;
+ Info.align = 16;
+ return true;
+ }
+
+ case Intrinsic::nvvm_wmma_load_c_f32_col:
+ case Intrinsic::nvvm_wmma_load_c_f32_row:
+ case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
+ case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
+ case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
+ case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
+ case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
+ case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
+ case Intrinsic::nvvm_wmma_load_c_f32_col_global:
+ case Intrinsic::nvvm_wmma_load_c_f32_row_global:
+ case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
+ case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v8f32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.vol = false;
+ Info.readMem = true;
+ Info.writeMem = false;
+ Info.align = 16;
+ return true;
+ }
+
+ case Intrinsic::nvvm_wmma_store_d_f16_col:
+ case Intrinsic::nvvm_wmma_store_d_f16_row:
+ case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
+ case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
+ case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
+ case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
+ case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
+ case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
+ case Intrinsic::nvvm_wmma_store_d_f16_col_global:
+ case Intrinsic::nvvm_wmma_store_d_f16_row_global:
+ case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
+ case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v4f16;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.vol = false;
+ Info.readMem = false;
+ Info.writeMem = true;
+ Info.align = 16;
+ return true;
+ }
+
+ case Intrinsic::nvvm_wmma_store_d_f32_col:
+ case Intrinsic::nvvm_wmma_store_d_f32_row:
+ case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
+ case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
+ case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
+ case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
+ case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
+ case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
+ case Intrinsic::nvvm_wmma_store_d_f32_col_global:
+ case Intrinsic::nvvm_wmma_store_d_f32_row_global:
+ case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
+ case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v8f32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.vol = false;
+ Info.readMem = false;
+ Info.writeMem = true;
+ Info.align = 16;
+ return true;
+ }
case Intrinsic::nvvm_atomic_load_add_f32:
case Intrinsic::nvvm_atomic_load_inc_32:
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td?rev=315601&r1=315600&r2=315601&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td Thu Oct 12 11:27:55 2017
@@ -7368,3 +7368,208 @@ def INT_PTX_SREG_PM3 : PTX_READ_SREG_R32
def INT_PTX_SREG_WARPSIZE :
NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;",
[(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>;
+
+//
+// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
+//
+class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
+ string Type, NVPTXRegClass regclass,
+ Operand SrcOp, int WithOffset, int WithStride>
+ : NVPTXInst<!if(!eq(Abc#Type,"cf16"),
+ (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3),
+ (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)),
+ !if(WithStride,
+ !if(WithOffset,
+ (ins SrcOp:$src, i32imm:$offset, Int32Regs:$ldm),
+ (ins SrcOp:$src, Int32Regs:$ldm)),
+ !if(WithOffset,
+ (ins SrcOp:$src, i32imm:$offset),
+ (ins SrcOp:$src))),
+ "wmma.load."#Abc#".sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
+ #!if(!eq(Abc#Type,"cf16"),
+ "{{$r0, $r1, $r2, $r3}}",
+ "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
+ #", "
+ #!if(WithOffset,"[$src+$offset]", "[$src]")
+ #!if(WithStride, ", $ldm", "")
+ #";",
+ []>,
+ Requires<[hasPTX60, hasSM70]>;
+
+multiclass WMMA_LOAD_ALSTO<string Abc, string Layout, string Space,
+ string Type, NVPTXRegClass regclass,
+ Operand SrcOp, int WithOffset = 0> {
+ def _stride: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp,
+ WithOffset, 1>;
+ def NAME: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp,
+ WithOffset, 0>;
+}
+
+multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space,
+ string Type, NVPTXRegClass regclass> {
+ defm _avar: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 0>;
+ defm _ari64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 1>;
+}
+
+multiclass WMMA_LOAD_ALT<string Abc, string Layout,
+ string Type, NVPTXRegClass regclass> {
+ defm _global: WMMA_LOAD_ALST<Abc, Layout, ".global", Type, regclass>;
+ defm _shared: WMMA_LOAD_ALST<Abc, Layout, ".shared", Type, regclass>;
+ defm NAME: WMMA_LOAD_ALST<Abc, Layout, "", Type, regclass>;
+}
+
+multiclass WMMA_LOAD_AT<string Abc, string Type, NVPTXRegClass regclass> {
+ defm _row: WMMA_LOAD_ALT<Abc, "row", Type, regclass>;
+ defm _col: WMMA_LOAD_ALT<Abc, "col", Type, regclass>;
+}
+
+defm INT_WMMA_LOAD_A: WMMA_LOAD_AT<"a", "f16", Float16x2Regs>;
+defm INT_WMMA_LOAD_B: WMMA_LOAD_AT<"b", "f16", Float16x2Regs>;
+defm INT_WMMA_LOAD_C_f16: WMMA_LOAD_AT<"c", "f16", Float16x2Regs>;
+defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>;
+
+//
+// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
+//
+class WMMA_STORE_D_LSTOS<string Layout, string Space,
+ string Type, NVPTXRegClass regclass,
+ Operand DstOp, int WithOffset, int WithStride>
+ : NVPTXInst<(outs),
+ !if(!eq(Type,"f16"),
+ !if(WithStride,
+ !if(WithOffset,
+ (ins DstOp:$src, i32imm:$offset,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ Int32Regs:$ldm),
+ (ins DstOp:$src,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ Int32Regs:$ldm)),
+ !if(WithOffset,
+ (ins DstOp:$src, i32imm:$offset,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3),
+ (ins DstOp:$src,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3))),
+ !if(WithStride,
+ !if(WithOffset,
+ (ins DstOp:$src, i32imm:$offset,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7,
+ Int32Regs:$ldm),
+ (ins DstOp:$src,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7,
+ Int32Regs:$ldm)),
+ !if(WithOffset,
+ (ins DstOp:$src, i32imm:$offset,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7),
+ (ins DstOp:$src,
+ regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+ regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)))),
+ "wmma.store.d.sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
+ #!if(WithOffset,"[$src+$offset], ", "[$src], ")
+ #!if(!eq(Type,"f16"),
+ "{{$r0, $r1, $r2, $r3}}",
+ "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
+ #!if(WithStride, ", $ldm", "")
+ #";",
+ []>,
+ Requires<[hasPTX60, hasSM70]>;
+
+multiclass WMMA_STORE_D_LSTO<string Layout, string Space,
+ string Type, NVPTXRegClass regclass,
+ Operand DstOp, int WithOffset = 0> {
+ def _stride: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp,
+ WithOffset, 1>;
+ def NAME: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp,
+ WithOffset, 0>;
+}
+
+multiclass WMMA_STORE_D_LST<string Layout, string Space,
+ string Type, NVPTXRegClass regclass> {
+ defm _avar: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 0>;
+ defm _ari64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 1>;
+}
+
+multiclass WMMA_STORE_D_LT<string Layout,
+ string Type, NVPTXRegClass regclass> {
+ defm _global: WMMA_STORE_D_LST<Layout, ".global", Type, regclass>;
+ defm _shared: WMMA_STORE_D_LST<Layout, ".shared", Type, regclass>;
+ defm NAME: WMMA_STORE_D_LST<Layout, "", Type, regclass>;
+}
+
+multiclass WMMA_STORE_D_T<string Type, NVPTXRegClass regclass> {
+ defm _row: WMMA_STORE_D_LT<"row", Type, regclass>;
+ defm _col: WMMA_STORE_D_LT<"col", Type, regclass>;
+}
+
+defm INT_WMMA_STORE_D_f16: WMMA_STORE_D_T<"f16", Float16x2Regs>;
+defm INT_WMMA_STORE_D_f32: WMMA_STORE_D_T<"f32", Float32Regs>;
+
+// WMMA.MMA
+class WMMA_MMA_ABDCS<string ALayout, string BLayout,
+ string DType, NVPTXRegClass d_reg,
+ string CType, NVPTXRegClass c_reg,
+ NVPTXRegClass ab_reg,
+ string Satfinite = "">
+ : NVPTXInst<!if(!eq(DType,"f16"),
+ (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
+ (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
+ d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7)),
+ !if(!eq(CType,"f16"),
+ (ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
+ ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
+ ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
+ ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
+ c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3),
+ (ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
+ ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
+ ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
+ ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
+ c_reg:$c0, c_reg:$c1, c_reg:$c2, c_reg:$c3,
+ c_reg:$c4, c_reg:$c5, c_reg:$c6, c_reg:$c7)),
+ "wmma.mma.sync."#ALayout#"."#BLayout#".m16n16k16."#
+ #DType#"."#CType#Satfinite
+ #"\n\t\t"
+ #!if(!eq(DType,"f16"),
+ "{{$d0, $d1, $d2, $d3}}, \n\t\t",
+ "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t")
+ #"{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t"
+ #"{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t"
+ #!if(!eq(CType,"f16"),
+ "{{$c0, $c1, $c2, $c3}};",
+ "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};"),
+ []>,
+ Requires<[hasPTX60, hasSM70]>;
+
+multiclass WMMA_MMA_ABDC<string ALayout, string BLayout,
+ string DType, NVPTXRegClass d_reg,
+ string CType, NVPTXRegClass c_reg> {
+ def _satfinite: WMMA_MMA_ABDCS<ALayout, BLayout,
+ DType, d_reg, CType, c_reg,
+ Float16x2Regs, ".satfinite">;
+ def NAME: WMMA_MMA_ABDCS<ALayout, BLayout,
+ DType, d_reg, CType, c_reg,
+ Float16x2Regs>;
+}
+
+multiclass WMMA_MMA_ABD<string ALayout, string BLayout,
+ string DType, NVPTXRegClass d_reg> {
+ defm _f16: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f16", Float16x2Regs>;
+ defm _f32: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f32", Float32Regs>;
+}
+
+multiclass WMMA_MMA_AB<string ALayout, string BLayout> {
+ defm _f16: WMMA_MMA_ABD<ALayout, BLayout, "f16", Float16x2Regs>;
+ defm _f32: WMMA_MMA_ABD<ALayout, BLayout, "f32", Float32Regs>;
+}
+
+multiclass WMMA_MMA_A<string ALayout> {
+ defm _col: WMMA_MMA_AB<ALayout, "col">;
+ defm _row: WMMA_MMA_AB<ALayout, "row">;
+}
+
+defm INT_WMMA_MMA_col: WMMA_MMA_A<"col">;
+defm INT_WMMA_MMA_row: WMMA_MMA_A<"row">;
+
Added: llvm/trunk/test/CodeGen/NVPTX/wmma.py
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/wmma.py?rev=315601&view=auto
==============================================================================
--- llvm/trunk/test/CodeGen/NVPTX/wmma.py (added)
+++ llvm/trunk/test/CodeGen/NVPTX/wmma.py Thu Oct 12 11:27:55 2017
@@ -0,0 +1,201 @@
+# This test generates all variants of wmma intrinsics and verifies that LLVM
+# generates correct instructions for them.
+
+# RUN: python %s > %t.ll
+# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 | FileCheck %t.ll
+
+from itertools import product
+from string import Template
+
+def make_wmma_slice_ty(abcd, itype):
+ elt_ty = "<2 x half>" if itype == "f16" else "float"
+ num_elts = 4 if abcd in "cd" and itype == "f16" else 8;
+ return [elt_ty] * num_elts
+
+def make_wmma_ld_ret_ty(abc, itype):
+ return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))
+
+# Convenient test patterns.
+check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
+check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
+check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
+
+def gen_wmma_load_tests():
+ load_template = """
+declare ${ret_ty} @llvm.nvvm.wmma.load.$intrinsic_suffix(i8* %src ${extra_args});
+
+; CHECK-LABEL: .func {{.*}}test_wmma_load_${function_suffix}(
+define ${ret_ty} @test_wmma_load_${function_suffix}(i8* %src ${extra_args}) {
+; CHECK: wmma.load.${instruction_suffix}
+; CHECK: {${check_result}}
+; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
+ %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src ${extra_args});
+ ret ${ret_ty} %v0;
+}
+
+; CHECK-LABEL: .func{{.*}}test_wmma_load_${function_suffix}_o(
+define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8* %src ${extra_args}) {
+; CHECK: wmma.load.${instruction_suffix}
+; CHECK: {${check_result}}
+; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
+ %src1 = getelementptr i8, i8* %src, i32 128;
+ %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src1 ${extra_args});
+ ret ${ret_ty} %v0;
+}
+"""
+ suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
+ instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
+
+ for abc, layout, space, stride, itype in product(
+ "abc",
+ ["row","col"],
+ ["",".shared",".global"],
+ ["", ".stride"],
+ ["f16", "f32"]):
+
+ params = {
+ "abc" : abc,
+ "layout" : layout,
+ "space" : space,
+ "stride" : stride,
+ "itype" : itype
+ }
+
+ if itype == "f32" and abc != "c":
+ continue
+
+ test_params = params
+ test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
+ test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_")
+ test_params["instruction_suffix"] = Template(instruction_template).substitute(params)
+ test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
+ if abc == "c" :
+ test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8
+ else:
+ test_params["check_result"] = check_f16_8
+
+ if stride:
+ test_params["extra_args"] = ", i32 %stride";
+ test_params["stride_pattern"] = ", %r{{[0-9]+}}"
+ else:
+ test_params["extra_args"] = ""
+ test_params["stride_pattern"] = ""
+
+ print(Template(load_template).substitute(test_params))
+
+def make_wmma_slice_args(itype, abcd, prefix="v"):
+ return ", ".join(["%s %%%s%d" % (t, prefix, i) for i,t
+ in enumerate(make_wmma_slice_ty(abcd, itype))])
+
+def gen_wmma_store_tests():
+ store_template = """
+declare void @llvm.nvvm.wmma.store.$intrinsic_suffix(i8* %src, ${args}${extra_args});
+
+; CHECK-LABEL: .func {{.*}}test_wmma_store_${function_suffix}(
+define void @test_wmma_store_${function_suffix}(i8* %src, ${args}${extra_args}) {
+; CHECK: wmma.store.${instruction_suffix} {{.*}}[%rd{{[0-9]+}}]
+; CHECK: {${check_args}}
+; CHECK: ${stride_pattern}
+ call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src, ${args} ${extra_args});
+ ret void
+}
+
+; CHECK-LABEL: .func{{.*}}test_wmma_store_${function_suffix}_o(
+define void @test_wmma_store_${function_suffix}_o(i8* %src, ${args}${extra_args}) {
+; CHECK: wmma.store.${instruction_suffix} {{.*}}[%rd{{[0-9]+}}+128]
+; CHECK: ${check_args}
+; CHECK: ${stride_pattern}
+ %src1 = getelementptr i8, i8* %src, i32 128;
+ call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src1, ${args}${extra_args});
+ ret void
+}
+"""
+ suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
+ instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
+
+ for abc, layout, space, stride, itype in product(
+ "d",
+ ["row","col"],
+ ["",".shared",".global"],
+ ["", ".stride"],
+ ["f16", "f32"]):
+
+ params = {
+ "abc" : abc,
+ "layout" : layout,
+ "space" : space,
+ "stride" : stride,
+ "itype" : itype
+ }
+
+ test_params = params
+ test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
+ test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_")
+ test_params["instruction_suffix"] = Template(instruction_template).substitute(params)
+ test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
+ test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
+ if stride:
+ test_params["extra_args"] = ", i32 %stride";
+ test_params["stride_pattern"] = ", %r{{[0-9]+}};"
+ else:
+ test_params["extra_args"] = ""
+ test_params["stride_pattern"] = ";"
+ test_params["args"] = make_wmma_slice_args(itype, "d");
+
+ print(Template(store_template).substitute(test_params))
+
+def gen_wmma_mma_tests():
+ mma_template = """
+declare ${ret_ty} @llvm.nvvm.wmma.mma.sync.$intrinsic_suffix(
+ ${args});
+
+; CHECK-LABEL: .func {{.*}}test_wmma_mma_${function_suffix}(
+define ${ret_ty} @test_wmma_mma_${function_suffix}(
+ ${args}) {
+; CHECK: wmma.mma.sync.${intrinsic_suffix}
+; CHECK: ${check_d}
+; CHECK: ${check_ab}
+; CHECK: ${check_ab}
+; CHECK: ${check_c}
+ %r = call ${ret_ty} @llvm.nvvm.wmma.mma.sync.${intrinsic_suffix}(
+ ${args});
+ ret ${ret_ty} %r;
+}
+"""
+ suffix_template = "${alayout}.${blayout}.m16n16k16.${dtype}.${ctype}${satf}"
+
+ for alayout, blayout, ctype, dtype, satf in product(
+ ["row","col"],
+ ["row","col"],
+ ["f16", "f32"],
+ ["f16", "f32"],
+ [".satfinite", ""]):
+
+ params = {
+ "alayout" : alayout,
+ "blayout" : blayout,
+ "ctype" : ctype,
+ "dtype" : dtype,
+ "satf" : satf
+ }
+
+ test_params = params
+ test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
+ test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".", "_")
+ test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
+ test_params["check_ab"] = check_f16_8
+ test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8
+ test_params["check_d"] = check_f16_4 if dtype == "f16" else check_f32_8
+ args = ",\n ".join(make_wmma_slice_args(t, abcd, prefix=abcd)
+ for abcd, t in (("a", "f16"),
+ ("b", "f16"),
+ ("c", ctype)))
+ test_params["args"] = args
+ print(Template(mma_template).substitute(test_params))
+
+def main():
+ gen_wmma_load_tests()
+ gen_wmma_store_tests()
+ gen_wmma_mma_tests()
+
+main()
More information about the llvm-commits
mailing list