[llvm] r315601 - [NVPTX] Implemented wmma intrinsics and instructions.

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 12 11:27:55 PDT 2017


Author: tra
Date: Thu Oct 12 11:27:55 2017
New Revision: 315601

URL: http://llvm.org/viewvc/llvm-project?rev=315601&view=rev
Log:
[NVPTX] Implemented wmma intrinsics and instructions.

WMMA = "Warp-Level Matrix Multiply-Accumulate".
These are new instructions introduced in PTX 6.0 and available
on sm_70 GPUs.

Differential Revision: https://reviews.llvm.org/D38645

Added:
    llvm/trunk/test/CodeGen/NVPTX/wmma.py
Modified:
    llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
    llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
    llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
    llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td

Modified: llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td?rev=315601&r1=315600&r2=315601&view=diff
==============================================================================
--- llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td (original)
+++ llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td Thu Oct 12 11:27:55 2017
@@ -3869,4 +3869,150 @@ def int_nvvm_match_all_sync_i64p :
   Intrinsic<[llvm_i64_ty, llvm_i1_ty], [llvm_i32_ty, llvm_i64_ty],
             [IntrNoMem, IntrConvergent], "llvm.nvvm.match.all.sync.i64p">;
 
+//
+// WMMA instructions
+//
+
+// WMMA.LOAD
+//
+// Defines one wmma.load intrinsic variant.
+//   Abc        - fragment being loaded: "a", "b" or "c".
+//   Layout     - "row" or "col".
+//   Space      - address-space suffix: ".global", ".shared" or "" (generic).
+//   Type       - element type string: "f16" or "f32".
+//   regty      - LLVM type of each returned fragment register.
+//   WithStride - when non-zero, the intrinsic takes an extra i32 operand after
+//                the pointer and its name gets a ".stride" suffix.
+// An f16 C fragment is returned in 4 registers; every other fragment in 8.
+// Intrinsic name:
+//   llvm.nvvm.wmma.load.<abc>.sync.<layout>.m16n16k16[<space>][.stride].<type>
+class NVVM_WMMA_LD_ALSTS<string Abc, string Layout, string Space,
+                         string Type, LLVMType regty, int WithStride>
+  : Intrinsic<!if(!eq(Abc#Type,"cf16"),
+                  [regty, regty, regty, regty],
+                  [regty, regty, regty, regty,
+                   regty, regty, regty, regty]),
+              !if(WithStride, [llvm_ptr_ty, llvm_i32_ty], [llvm_ptr_ty]),
+              [], // Properties must be set during instantiation.
+              "llvm.nvvm.wmma.load."#Abc#".sync."#Layout#".m16n16k16"
+                #Space
+                #!if(WithStride,".stride","")
+                #"."#Type>;
+
+// Defines, for one (fragment, layout, space, type) combination, the plain
+// load intrinsic (NAME) and its "_stride" twin that takes an extra i32 operand.
+multiclass NVVM_WMMA_LD_ALST<string Abc, string Layout, string Space,
+                           string Type, LLVMType regty> {
+  def _stride: NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 1>;
+  def NAME   : NVVM_WMMA_LD_ALSTS<Abc, Layout, Space, Type, regty, 0>;
+}
+
+// Instantiates the "_global", "_shared" and generic (no suffix) address-space
+// variants of a load.
+multiclass NVVM_WMMA_LD_ALT<string Abc, string Layout,
+                        string Type, LLVMType regty> {
+  defm _global: NVVM_WMMA_LD_ALST<Abc, Layout, ".global", Type, regty>;
+  defm _shared: NVVM_WMMA_LD_ALST<Abc, Layout, ".shared", Type, regty>;
+  defm NAME:    NVVM_WMMA_LD_ALST<Abc, Layout,        "", Type, regty>;
+}
+
+// Instantiates the "_row" and "_col" layout variants of a load.
+multiclass NVVM_WMMA_LD_AT<string Abc, string Type, LLVMType regty> {
+  defm _row: NVVM_WMMA_LD_ALT<Abc, "row", Type, regty>;
+  defm _col: NVVM_WMMA_LD_ALT<Abc, "col", Type, regty>;
+}
+
+// For some reason ReadOnly<N> and NoCapture<N> confuse tblgen if they are
+// passed to Intrinsic<> from inside a multiclass. Setting them globally
+// outside of the multiclass works.
+//
+// Every load variant only reads memory, only through its pointer argument
+// (operand 0), and does not capture that pointer. The resulting records are
+// named int_nvvm_wmma_load_<abc>_<type>_<layout>[_<space>][_stride].
+let IntrProperties = [IntrReadMem, IntrArgMemOnly,
+                      ReadOnly<0>, NoCapture<0>] in {
+  defm int_nvvm_wmma_load_a_f16: NVVM_WMMA_LD_AT<"a", "f16", llvm_v2f16_ty>;
+  defm int_nvvm_wmma_load_b_f16: NVVM_WMMA_LD_AT<"b", "f16", llvm_v2f16_ty>;
+  defm int_nvvm_wmma_load_c_f16: NVVM_WMMA_LD_AT<"c", "f16", llvm_v2f16_ty>;
+  defm int_nvvm_wmma_load_c_f32: NVVM_WMMA_LD_AT<"c", "f32", llvm_float_ty>;
+}
+
+// WMMA.STORE.D
+//
+// Defines one wmma.store.d intrinsic variant. It returns nothing. Its operands
+// are, in order:
+//   - the destination pointer;
+//   - the D fragment registers (4 when Type is "f16", 8 otherwise);
+//   - a trailing i32, only when WithStride is non-zero.
+// Intrinsic name:
+//   llvm.nvvm.wmma.store.d.sync.<layout>.m16n16k16[<space>][.stride].<type>
+class NVVM_WMMA_STD_LSTS<string Layout, string Space,
+                         string Type, LLVMType regty, int WithStride,
+                         // This is only used to create a typed empty array we
+                         // need to pass to !if below.
+                         list<LLVMType>Empty=[]>
+  : Intrinsic<[],
+              !listconcat(
+                [llvm_ptr_ty],
+                !if(!eq(Type,"f16"),
+                    [regty, regty, regty, regty],
+                    [regty, regty, regty, regty,
+                     regty, regty, regty, regty]),
+                !if(WithStride, [llvm_i32_ty], Empty)),
+              [], // Properties must be set during instantiation.
+              "llvm.nvvm.wmma.store.d.sync."#Layout
+                   #".m16n16k16"#Space
+                   #!if(WithStride,".stride","")
+                   #"."#Type>;
+
+// Defines, for one (layout, space, type) combination, the plain store
+// intrinsic (NAME) and its "_stride" twin that takes a trailing i32 operand.
+multiclass NVVM_WMMA_STD_LST<string Layout, string Space,
+                            string Type, LLVMType regty> {
+  def _stride: NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 1>;
+  def NAME:    NVVM_WMMA_STD_LSTS<Layout, Space, Type, regty, 0>;
+}
+
+// Instantiates the "_global", "_shared" and generic (no suffix) address-space
+// variants of a store.
+multiclass NVVM_WMMA_STD_LT<string Layout, string Type, LLVMType regty> {
+  defm _global: NVVM_WMMA_STD_LST<Layout, ".global", Type, regty>;
+  defm _shared: NVVM_WMMA_STD_LST<Layout, ".shared", Type, regty>;
+  defm    NAME: NVVM_WMMA_STD_LST<Layout,        "", Type, regty>;
+}
+
+// Instantiates the "_row" and "_col" layout variants of a store.
+multiclass NVVM_WMMA_STD_T<string Type, LLVMType regty> {
+  defm _row: NVVM_WMMA_STD_LT<"row", Type, regty>;
+  defm _col: NVVM_WMMA_STD_LT<"col", Type, regty>;
+}
+
+// Every store variant only writes memory, only through its pointer argument
+// (operand 0), and does not capture that pointer. Properties are set here
+// rather than inside the multiclasses for the same tblgen reason as the loads.
+let IntrProperties = [IntrWriteMem, IntrArgMemOnly,
+                      WriteOnly<0>, NoCapture<0>] in {
+  defm int_nvvm_wmma_store_d_f16: NVVM_WMMA_STD_T<"f16", llvm_v2f16_ty>;
+  defm int_nvvm_wmma_store_d_f32: NVVM_WMMA_STD_T<"f32", llvm_float_ty>;
+}
+
+// WMMA.MMA
+//
+// Defines one wmma.mma intrinsic variant. Its operands are, in order:
+//   - the A fragment (8 x v2f16);
+//   - the B fragment (8 x v2f16);
+//   - the C fragment.
+// It returns the D fragment. C and D each use 4 registers when their type
+// string is "f16" and 8 otherwise. Satfinite is either "" or ".satfinite" and
+// is appended verbatim to the name. The intrinsic is marked IntrNoMem.
+// Intrinsic name:
+//   llvm.nvvm.wmma.mma.sync.<alayout>.<blayout>.m16n16k16.<dtype>.<ctype>
+//     [.satfinite]
+class NVVM_WMMA_MMA_ABDCS<string ALayout, string BLayout,
+                          string DType, LLVMType d_regty,
+                          string CType, LLVMType c_regty,
+                          string Satfinite = "">
+  : Intrinsic<!if(!eq(DType,"f16"),
+                      [d_regty, d_regty, d_regty, d_regty],
+                      [d_regty, d_regty, d_regty, d_regty,
+                       d_regty, d_regty, d_regty, d_regty]),
+              !listconcat(
+                [// A
+                llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
+                llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
+                // B
+                llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty,
+                llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty, llvm_v2f16_ty],
+                !if(!eq(CType,"f16"),
+                      [c_regty, c_regty, c_regty, c_regty],
+                      [c_regty, c_regty, c_regty, c_regty,
+                       c_regty, c_regty, c_regty, c_regty])),
+              [IntrNoMem],
+              "llvm.nvvm.wmma.mma.sync."#ALayout#"."#BLayout
+                 #".m16n16k16."#DType#"."#CType#Satfinite>;
+
+// Defines the plain MMA intrinsic (NAME) and its "_satfinite" twin, whose
+// intrinsic name ends in ".satfinite".
+multiclass NVVM_WMMA_MMA_ABDC<string ALayout, string BLayout,
+                              string DType, LLVMType d_regty,
+                              string CType, LLVMType c_regty> {
+  def NAME : NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
+                                 DType, d_regty,
+                                 CType, c_regty>;
+  def _satfinite: NVVM_WMMA_MMA_ABDCS<ALayout, BLayout,
+                                      DType, d_regty,
+                                      CType, c_regty,".satfinite">;
+}
+
+// Instantiates both C fragment types: "_f16" (v2f16 registers) and "_f32"
+// (float registers).
+multiclass NVVM_WMMA_MMA_ABD<string ALayout, string BLayout,
+                              string DType, LLVMType d_regty> {
+  defm _f16: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
+                                "f16", llvm_v2f16_ty>;
+  defm _f32: NVVM_WMMA_MMA_ABDC<ALayout, BLayout, DType, d_regty,
+                                "f32", llvm_float_ty>;
+}
+
+// Instantiates both D fragment types: "_f16" (v2f16 registers) and "_f32"
+// (float registers).
+multiclass NVVM_WMMA_MMA_AB<string ALayout, string BLayout> {
+  defm _f16: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f16", llvm_v2f16_ty>;
+  defm _f32: NVVM_WMMA_MMA_ABD<ALayout, BLayout, "f32", llvm_float_ty>;
+}
+
+// Instantiates both B layouts, "_col" and "_row".
+multiclass NVVM_WMMA_MMA_A<string ALayout> {
+  defm _col: NVVM_WMMA_MMA_AB<ALayout, "col">;
+  defm _row: NVVM_WMMA_MMA_AB<ALayout, "row">;
+}
+
+// Both A layouts. This yields 32 records named
+// int_nvvm_wmma_mma_sync_<alayout>_<blayout>_<dtype>_<ctype>[_satfinite].
+defm int_nvvm_wmma_mma_sync_col: NVVM_WMMA_MMA_A<"col">;
+defm int_nvvm_wmma_mma_sync_row: NVVM_WMMA_MMA_A<"row">;
+
 } // let TargetPrefix = "nvvm"

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp?rev=315601&r1=315600&r2=315601&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp Thu Oct 12 11:27:55 2017
@@ -496,8 +496,318 @@ void NVPTXDAGToDAGISel::Select(SDNode *N
   SelectCode(N);
 }
 
+// WMMA loads and stores each come in four addressing variants. The address is
+// either a base plus an offset (ARI64) or a single address operand (AVAR), and
+// each form exists with and without an explicit stride operand. Enumerator
+// values index the four-entry opcode table produced by WMMA_VARIANTS(), so
+// they must stay in the same order as that macro's initializers.
+enum WmmaVariant {
+  WMMA_VARIANT_ARI64 = 0,
+  WMMA_VARIANT_ARI64_STRIDE = 1,
+  WMMA_VARIANT_AVAR = 2,
+  WMMA_VARIANT_AVAR_STRIDE = 3,
+};
+
+// Expands to a braced initializer for std::array<unsigned, 4>. It lists the
+// four addressing variants of instruction `base` in WmmaVariant order.
+// clang-format off
+#define WMMA_VARIANTS(base) \
+  {{ base##_ari64, base##_ari64_stride, base##_avar, base##_avar_stride }}
+// clang-format on
+
+// Picks the opcode for addressing mode `Variant` out of the four-entry table
+// `Variants` (indexed by WmmaVariant). When `Stride` is set, a non-strided
+// mode (ARI64 or AVAR) is promoted to its strided counterpart. A mode that is
+// already strided is used as-is. Despite the name, this is used for both WMMA
+// loads and stores.
+//
+// The table is taken by const reference. Callers pass a braced temporary from
+// WMMA_VARIANTS(), so there is no reason to copy it into a const by-value
+// parameter.
+static unsigned getWmmaLdVariant(WmmaVariant Variant, bool Stride,
+                                 const std::array<unsigned, 4> &Variants) {
+  if (Stride) {
+    if (Variant == WMMA_VARIANT_ARI64)
+      Variant = WMMA_VARIANT_ARI64_STRIDE;
+    else if (Variant == WMMA_VARIANT_AVAR)
+      Variant = WMMA_VARIANT_AVAR_STRIDE;
+  }
+  return Variants[Variant];
+}
+
+static Optional<unsigned>
+getWmmaLdStOpcode(unsigned IntrinsicID,
+                  WmmaVariant Variant = WMMA_VARIANT_ARI64) {
+  switch (IntrinsicID) {
+  default:
+    return None;
+  //
+  // WMMA_LOAD_A f16
+  //
+  case Intrinsic::nvvm_wmma_load_a_f16_col:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col));
+  case Intrinsic::nvvm_wmma_load_a_f16_row:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row));
+  case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col));
+  case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row));
+  case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared));
+  case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared));
+  case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_shared));
+  case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_shared));
+  case Intrinsic::nvvm_wmma_load_a_f16_col_global:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global));
+  case Intrinsic::nvvm_wmma_load_a_f16_row_global:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global));
+  case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_col_global));
+  case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_A_row_global));
+
+  //
+  // WMMA_LOAD_B f16
+  //
+  case Intrinsic::nvvm_wmma_load_b_f16_col:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col));
+  case Intrinsic::nvvm_wmma_load_b_f16_row:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row));
+  case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col));
+  case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row));
+  case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared));
+  case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared));
+  case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_shared));
+  case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_shared));
+  case Intrinsic::nvvm_wmma_load_b_f16_col_global:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global));
+  case Intrinsic::nvvm_wmma_load_b_f16_row_global:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global));
+  case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_col_global));
+  case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_B_row_global));
+
+  //
+  // WMMA_LOAD_C f16
+  //
+  case Intrinsic::nvvm_wmma_load_c_f16_col:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col));
+  case Intrinsic::nvvm_wmma_load_c_f16_row:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row));
+  case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col));
+  case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row));
+  case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared));
+  case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared));
+  case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_shared));
+  case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_shared));
+  case Intrinsic::nvvm_wmma_load_c_f16_col_global:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global));
+  case Intrinsic::nvvm_wmma_load_c_f16_row_global:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global));
+  case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_col_global));
+  case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f16_row_global));
+
+  //
+  // WMMA_LOAD_C f32
+  //
+  case Intrinsic::nvvm_wmma_load_c_f32_col:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col));
+  case Intrinsic::nvvm_wmma_load_c_f32_row:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row));
+  case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col));
+  case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row));
+  case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared));
+  case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared));
+  case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_shared));
+  case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_shared));
+  case Intrinsic::nvvm_wmma_load_c_f32_col_global:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global));
+  case Intrinsic::nvvm_wmma_load_c_f32_row_global:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global));
+  case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_col_global));
+  case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_LOAD_C_f32_row_global));
+
+  //
+  // WMMA_STORE_D f16
+  //
+  case Intrinsic::nvvm_wmma_store_d_f16_col:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col));
+  case Intrinsic::nvvm_wmma_store_d_f16_row:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row));
+  case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col));
+  case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row));
+  case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared));
+  case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared));
+  case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_shared));
+  case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_shared));
+  case Intrinsic::nvvm_wmma_store_d_f16_col_global:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global));
+  case Intrinsic::nvvm_wmma_store_d_f16_row_global:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global));
+  case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_col_global));
+  case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f16_row_global));
+
+  //
+  // WMMA_STORE_D f32
+  //
+  case Intrinsic::nvvm_wmma_store_d_f32_col:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col));
+  case Intrinsic::nvvm_wmma_store_d_f32_row:
+    return getWmmaLdVariant(Variant, /*Stride=*/false,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row));
+  case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col));
+  case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
+    return getWmmaLdVariant(Variant, /*Stride=*/true,
+                            WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row));
+  case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared));
+  case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared));
+  case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_shared));
+  case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_shared));
+  case Intrinsic::nvvm_wmma_store_d_f32_col_global:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global));
+  case Intrinsic::nvvm_wmma_store_d_f32_row_global:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/false,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global));
+  case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_col_global));
+  case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride:
+    return getWmmaLdVariant(
+        Variant, /*Stride=*/true,
+        WMMA_VARIANTS(NVPTX::INT_WMMA_STORE_D_f32_row_global));
+  }
+}
+#undef WMMA_VARIANTS
+
 bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
   unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
+  if (getWmmaLdStOpcode(IID))
+    return tryWMMA_LDST(N);
+
   switch (IID) {
   default:
     return false;
@@ -719,6 +1029,39 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoCh
   case Intrinsic::nvvm_match_all_sync_i64p:
     SelectMatchAll(N);
     return true;
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16:
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32:
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16:
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32:
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16:
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32:
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16:
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32:
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16:
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32:
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16:
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32:
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16:
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32:
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16:
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite:
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32:
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite:
+    return tryWMMA_MMA(N);
   }
 }
 
@@ -3725,3 +4068,172 @@ unsigned NVPTXDAGToDAGISel::GetConvertOp
     }
   }
 }
+
+// Selects a WMMA load or store intrinsic node into its NVPTX machine
+// instruction.
+//
+// N is expected to be a MemIntrinsicSDNode with these operands:
+//   0   - chain;
+//   1   - intrinsic ID;
+//   2   - address;
+//   3.. - remaining intrinsic arguments: fragment registers for stores and,
+//         for "_stride" intrinsics, the i32 stride. These are forwarded
+//         unchanged.
+// The address becomes a single direct-address operand (AVAR), a base+offset
+// pair (ARI64), or, failing both, is used as-is (AVAR). The chain is appended
+// as the last operand. The machine node takes N's own result types: the
+// loaded fragment values plus a chain for loads, only a chain for stores.
+//
+// Returns false, leaving N unselected here, if IID is not a WMMA load/store.
+bool NVPTXDAGToDAGISel::tryWMMA_LDST(SDNode *N) {
+  assert(isa<MemIntrinsicSDNode>(N) &&
+         "WMMA load/store must be a memory intrinsic node.");
+  SDValue Chain = N->getOperand(0);
+  unsigned IID = cast<ConstantSDNode>(N->getOperand(1))->getZExtValue();
+  SDValue Op1 = N->getOperand(2);
+  SDValue Addr, Offset, Base;
+  SDLoc DL(N);
+  WmmaVariant Variant;
+  SmallVector<SDValue, 12> Ops;
+
+  if (SelectDirectAddr(Op1, Addr)) {
+    Variant = WMMA_VARIANT_AVAR;
+    Ops.push_back(Addr);
+  } else if (SelectADDRsi64(Op1.getNode(), Op1, Base, Offset) ||
+             SelectADDRri64(Op1.getNode(), Op1, Base, Offset)) {
+    Variant = WMMA_VARIANT_ARI64;
+    Ops.push_back(Base);
+    Ops.push_back(Offset);
+  } else {
+    Variant = WMMA_VARIANT_AVAR;
+    Ops.push_back(Op1);
+  }
+  // Pass through the rest of the operands to the machine node.
+  for (unsigned i = 3, e = N->getNumOperands(); i < e; ++i)
+    Ops.push_back(N->getOperand(i));
+  Ops.push_back(Chain);
+
+  Optional<unsigned> Opcode = getWmmaLdStOpcode(IID, Variant);
+  // Not a WMMA load/store intrinsic: decline quietly. Instruction selection
+  // is library code and must not write debug text to stderr on a recoverable
+  // path.
+  if (!Opcode)
+    return false;
+
+  // Reuse N's value types rather than rebuilding them from the intrinsic's
+  // memory VT. This keeps the machine node's results in one-to-one
+  // correspondence with N's uses, which ReplaceNode relies on. It covers
+  // stores as well, since their only result is the chain.
+  SDNode *MN =
+      CurDAG->getMachineNode(Opcode.getValue(), DL, N->getVTList(), Ops);
+  ReplaceNode(N, MN);
+  return true;
+}
+
+bool NVPTXDAGToDAGISel::tryWMMA_MMA(SDNode *N) {
+  unsigned IID = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
+  SDLoc DL(N);
+  unsigned Opc;
+
+  switch (IID) {
+  default:
+    return false;
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16:
+    Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f16_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f16_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32:
+    Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f16_f32_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_col_col_f16_f32_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16:
+    Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f16_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f16_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32:
+    Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_col_f32_f32_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_col_col_f32_f32_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16:
+    Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f16_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f16_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32:
+    Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f16_f32_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_col_row_f16_f32_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16:
+    Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f16_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f16_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32:
+    Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_col_row_f32_f32_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_col_row_f32_f32_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16:
+    Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f16_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f16_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32:
+    Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f16_f32_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_row_col_f16_f32_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16:
+    Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f16_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f16_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32:
+    Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_col_f32_f32_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_row_col_f32_f32_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16:
+    Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f16_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f16_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32:
+    Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f16_f32_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_row_row_f16_f32_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16:
+    Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f16_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f16_satfinite;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32:
+    Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32;
+    break;
+  case Intrinsic::nvvm_wmma_mma_sync_row_row_f32_f32_satfinite:
+    Opc = NVPTX::INT_WMMA_MMA_row_row_f32_f32_satfinite;
+    break;
+  }
+
+  SmallVector<SDValue, 24> Ops;
+  // Pass through operands and return value types to the machine node.
+  for (unsigned i = 1; i < N->getNumOperands(); ++i)
+    Ops.push_back(N->getOperand(i));
+  SmallVector<EVT, 8> InstVTs(N->getNumValues(), N->getValueType(0));
+  SDNode *MN = CurDAG->getMachineNode(Opc, DL, InstVTs, Ops);
+  ReplaceNode(N, MN);
+  return true;
+}

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h?rev=315601&r1=315600&r2=315601&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelDAGToDAG.h Thu Oct 12 11:27:55 2017
@@ -74,6 +74,8 @@ private:
   bool tryConstantFP16(SDNode *N);
   bool SelectSETP_F16X2(SDNode *N);
   bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
+  bool tryWMMA_LDST(SDNode *N);
+  bool tryWMMA_MMA(SDNode *N);
 
   inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
     return CurDAG->getTargetConstant(Imm, DL, MVT::i32);

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp?rev=315601&r1=315600&r2=315601&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp Thu Oct 12 11:27:55 2017
@@ -3321,6 +3321,132 @@ bool NVPTXTargetLowering::getTgtMemIntri
   switch (Intrinsic) {
   default:
     return false;
+  case Intrinsic::nvvm_wmma_load_a_f16_col:
+  case Intrinsic::nvvm_wmma_load_a_f16_row:
+  case Intrinsic::nvvm_wmma_load_a_f16_col_stride:
+  case Intrinsic::nvvm_wmma_load_a_f16_row_stride:
+  case Intrinsic::nvvm_wmma_load_a_f16_col_shared:
+  case Intrinsic::nvvm_wmma_load_a_f16_row_shared:
+  case Intrinsic::nvvm_wmma_load_a_f16_col_shared_stride:
+  case Intrinsic::nvvm_wmma_load_a_f16_row_shared_stride:
+  case Intrinsic::nvvm_wmma_load_a_f16_col_global:
+  case Intrinsic::nvvm_wmma_load_a_f16_row_global:
+  case Intrinsic::nvvm_wmma_load_a_f16_col_global_stride:
+  case Intrinsic::nvvm_wmma_load_a_f16_row_global_stride:
+  case Intrinsic::nvvm_wmma_load_b_f16_col:
+  case Intrinsic::nvvm_wmma_load_b_f16_row:
+  case Intrinsic::nvvm_wmma_load_b_f16_col_stride:
+  case Intrinsic::nvvm_wmma_load_b_f16_row_stride:
+  case Intrinsic::nvvm_wmma_load_b_f16_col_shared:
+  case Intrinsic::nvvm_wmma_load_b_f16_row_shared:
+  case Intrinsic::nvvm_wmma_load_b_f16_col_shared_stride:
+  case Intrinsic::nvvm_wmma_load_b_f16_row_shared_stride:
+  case Intrinsic::nvvm_wmma_load_b_f16_col_global:
+  case Intrinsic::nvvm_wmma_load_b_f16_row_global:
+  case Intrinsic::nvvm_wmma_load_b_f16_col_global_stride:
+  case Intrinsic::nvvm_wmma_load_b_f16_row_global_stride: {
+    Info.opc = ISD::INTRINSIC_W_CHAIN;
+    Info.memVT = MVT::v8f16;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.vol = false;
+    Info.readMem = true;
+    Info.writeMem = false;
+    Info.align = 16;
+    return true;
+  }
+
+  case Intrinsic::nvvm_wmma_load_c_f16_col:
+  case Intrinsic::nvvm_wmma_load_c_f16_row:
+  case Intrinsic::nvvm_wmma_load_c_f16_col_stride:
+  case Intrinsic::nvvm_wmma_load_c_f16_row_stride:
+  case Intrinsic::nvvm_wmma_load_c_f16_col_shared:
+  case Intrinsic::nvvm_wmma_load_c_f16_row_shared:
+  case Intrinsic::nvvm_wmma_load_c_f16_col_shared_stride:
+  case Intrinsic::nvvm_wmma_load_c_f16_row_shared_stride:
+  case Intrinsic::nvvm_wmma_load_c_f16_col_global:
+  case Intrinsic::nvvm_wmma_load_c_f16_row_global:
+  case Intrinsic::nvvm_wmma_load_c_f16_col_global_stride:
+  case Intrinsic::nvvm_wmma_load_c_f16_row_global_stride: {
+    Info.opc = ISD::INTRINSIC_W_CHAIN;
+    Info.memVT = MVT::v4f16;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.vol = false;
+    Info.readMem = true;
+    Info.writeMem = false;
+    Info.align = 16;
+    return true;
+  }
+
+  case Intrinsic::nvvm_wmma_load_c_f32_col:
+  case Intrinsic::nvvm_wmma_load_c_f32_row:
+  case Intrinsic::nvvm_wmma_load_c_f32_col_stride:
+  case Intrinsic::nvvm_wmma_load_c_f32_row_stride:
+  case Intrinsic::nvvm_wmma_load_c_f32_col_shared:
+  case Intrinsic::nvvm_wmma_load_c_f32_row_shared:
+  case Intrinsic::nvvm_wmma_load_c_f32_col_shared_stride:
+  case Intrinsic::nvvm_wmma_load_c_f32_row_shared_stride:
+  case Intrinsic::nvvm_wmma_load_c_f32_col_global:
+  case Intrinsic::nvvm_wmma_load_c_f32_row_global:
+  case Intrinsic::nvvm_wmma_load_c_f32_col_global_stride:
+  case Intrinsic::nvvm_wmma_load_c_f32_row_global_stride: {
+    Info.opc = ISD::INTRINSIC_W_CHAIN;
+    Info.memVT = MVT::v8f32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.vol = false;
+    Info.readMem = true;
+    Info.writeMem = false;
+    Info.align = 16;
+    return true;
+  }
+
+  case Intrinsic::nvvm_wmma_store_d_f16_col:
+  case Intrinsic::nvvm_wmma_store_d_f16_row:
+  case Intrinsic::nvvm_wmma_store_d_f16_col_stride:
+  case Intrinsic::nvvm_wmma_store_d_f16_row_stride:
+  case Intrinsic::nvvm_wmma_store_d_f16_col_shared:
+  case Intrinsic::nvvm_wmma_store_d_f16_row_shared:
+  case Intrinsic::nvvm_wmma_store_d_f16_col_shared_stride:
+  case Intrinsic::nvvm_wmma_store_d_f16_row_shared_stride:
+  case Intrinsic::nvvm_wmma_store_d_f16_col_global:
+  case Intrinsic::nvvm_wmma_store_d_f16_row_global:
+  case Intrinsic::nvvm_wmma_store_d_f16_col_global_stride:
+  case Intrinsic::nvvm_wmma_store_d_f16_row_global_stride: {
+    Info.opc = ISD::INTRINSIC_W_CHAIN;
+    Info.memVT = MVT::v4f16;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.vol = false;
+    Info.readMem = false;
+    Info.writeMem = true;
+    Info.align = 16;
+    return true;
+  }
+
+  case Intrinsic::nvvm_wmma_store_d_f32_col:
+  case Intrinsic::nvvm_wmma_store_d_f32_row:
+  case Intrinsic::nvvm_wmma_store_d_f32_col_stride:
+  case Intrinsic::nvvm_wmma_store_d_f32_row_stride:
+  case Intrinsic::nvvm_wmma_store_d_f32_col_shared:
+  case Intrinsic::nvvm_wmma_store_d_f32_row_shared:
+  case Intrinsic::nvvm_wmma_store_d_f32_col_shared_stride:
+  case Intrinsic::nvvm_wmma_store_d_f32_row_shared_stride:
+  case Intrinsic::nvvm_wmma_store_d_f32_col_global:
+  case Intrinsic::nvvm_wmma_store_d_f32_row_global:
+  case Intrinsic::nvvm_wmma_store_d_f32_col_global_stride:
+  case Intrinsic::nvvm_wmma_store_d_f32_row_global_stride: {
+    Info.opc = ISD::INTRINSIC_W_CHAIN;
+    Info.memVT = MVT::v8f32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.vol = false;
+    Info.readMem = false;
+    Info.writeMem = true;
+    Info.align = 16;
+    return true;
+  }
 
   case Intrinsic::nvvm_atomic_load_add_f32:
   case Intrinsic::nvvm_atomic_load_inc_32:

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td?rev=315601&r1=315600&r2=315601&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td Thu Oct 12 11:27:55 2017
@@ -7368,3 +7368,208 @@ def INT_PTX_SREG_PM3 : PTX_READ_SREG_R32
 def INT_PTX_SREG_WARPSIZE :
     NVPTXInst<(outs Int32Regs:$dst), (ins), "mov.u32 \t$dst, WARP_SZ;",
               [(set Int32Regs:$dst, (int_nvvm_read_ptx_sreg_warpsize))]>;
+
+//
+// wmma.load.[a|b|c].sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
+//
+// Defines one wmma.load instruction variant.
+//   Abc        - fragment being loaded: "a", "b" or "c".
+//   Layout     - "row" or "col".
+//   Space      - state-space suffix pasted after the shape: "", ".global"
+//                or ".shared".
+//   Type       - element type string, "f16" or "f32".
+//   regclass   - register class of every result register.
+//   SrcOp      - operand type of the source address ($src).
+//   WithOffset - when set, adds an immediate $offset operand and prints the
+//                address as [$src+$offset] instead of [$src].
+//   WithStride - when set, adds a trailing Int32Regs $ldm operand and prints
+//                it after the address.
+// A "c" fragment of type "f16" yields 4 result registers; every other
+// combination yields 8.
+// The selection-pattern list is empty, so this instruction is not matched by
+// TableGen patterns. NOTE(review): presumably it is selected manually in C++
+// (tryWMMA_LDST) -- confirm in NVPTXISelDAGToDAG.cpp.
+class WMMA_LOAD_ALSTOS<string Abc, string Layout, string Space,
+                           string Type, NVPTXRegClass regclass,
+                           Operand SrcOp, int WithOffset, int WithStride>
+  : NVPTXInst<!if(!eq(Abc#Type,"cf16"),
+                  (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3),
+                  (outs regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+                         regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)),
+              !if(WithStride,
+                  !if(WithOffset,
+                      (ins SrcOp:$src, i32imm:$offset, Int32Regs:$ldm),
+                      (ins SrcOp:$src, Int32Regs:$ldm)),
+                  !if(WithOffset,
+                      (ins SrcOp:$src, i32imm:$offset),
+                      (ins SrcOp:$src))),
+              "wmma.load."#Abc#".sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
+                 #!if(!eq(Abc#Type,"cf16"),
+                      "{{$r0, $r1, $r2, $r3}}",
+                      "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
+                 #", "
+                 #!if(WithOffset,"[$src+$offset]", "[$src]")
+                 #!if(WithStride, ", $ldm", "")
+                 #";",
+              []>,
+    Requires<[hasPTX60, hasSM70]>;
+
+// Instantiates the strided variant (record suffix "_stride") and the
+// non-strided variant (no extra suffix) of one load.
+multiclass WMMA_LOAD_ALSTO<string Abc, string Layout, string Space,
+                           string Type, NVPTXRegClass regclass,
+                           Operand SrcOp, int WithOffset = 0> {
+  def _stride: WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp,
+                                WithOffset, 1>;
+  def NAME:    WMMA_LOAD_ALSTOS<Abc, Layout, Space, Type, regclass, SrcOp,
+                                WithOffset, 0>;
+}
+
+// Instantiates the plain-address form ("_avar", WithOffset = 0) and the
+// address-plus-immediate form ("_ari64", WithOffset = 1). Both use imemAny
+// as the address operand type.
+multiclass WMMA_LOAD_ALST<string Abc, string Layout, string Space,
+                          string Type, NVPTXRegClass regclass> {
+  defm _avar:  WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 0>;
+  defm _ari64: WMMA_LOAD_ALSTO<Abc, Layout, Space, Type, regclass, imemAny, 1>;
+}
+
+// Instantiates the ".global", ".shared" and no-state-space-suffix variants.
+multiclass WMMA_LOAD_ALT<string Abc, string Layout,
+                        string Type, NVPTXRegClass regclass> {
+  defm _global: WMMA_LOAD_ALST<Abc, Layout, ".global", Type, regclass>;
+  defm _shared: WMMA_LOAD_ALST<Abc, Layout, ".shared", Type, regclass>;
+  defm NAME:    WMMA_LOAD_ALST<Abc, Layout,        "", Type, regclass>;
+}
+
+// Instantiates the "row" and "col" layouts.
+multiclass WMMA_LOAD_AT<string Abc, string Type, NVPTXRegClass regclass> {
+  defm _row: WMMA_LOAD_ALT<Abc, "row", Type, regclass>;
+  defm _col: WMMA_LOAD_ALT<Abc, "col", Type, regclass>;
+}
+
+// Top-level load definitions. Each nesting level above appends its suffix to
+// the record name. A and B fragments are defined only for f16. C is defined
+// for both f16 (Float16x2Regs) and f32 (Float32Regs).
+defm INT_WMMA_LOAD_A: WMMA_LOAD_AT<"a", "f16", Float16x2Regs>;
+defm INT_WMMA_LOAD_B: WMMA_LOAD_AT<"b", "f16", Float16x2Regs>;
+defm INT_WMMA_LOAD_C_f16: WMMA_LOAD_AT<"c", "f16", Float16x2Regs>;
+defm INT_WMMA_LOAD_C_f32: WMMA_LOAD_AT<"c", "f32", Float32Regs>;
+
+//
+// wmma.store.d.sync.[row|col].m16n16k16[|.global|.shared].[f16|f32]
+//
+// Defines one wmma.store.d instruction variant. It has no results; all data
+// is passed as inputs.
+//   Layout     - "row" or "col".
+//   Space      - state-space suffix: "", ".global" or ".shared".
+//   Type       - "f16" (4 data registers) or anything else (8 registers).
+//   regclass   - register class of every data register.
+//   DstOp      - operand type of the destination address. Note that the
+//                operand is named $src even though it is the store target
+//                (it is printed as the bracketed address).
+//   WithOffset - when set, adds an immediate $offset and prints the
+//                address as [$src+$offset].
+//   WithStride - when set, adds a trailing Int32Regs $ldm operand.
+// Operand order: address, [offset], data registers, [ldm].
+// The selection-pattern list is empty. NOTE(review): presumably selected
+// manually in C++ (tryWMMA_LDST) -- confirm.
+class WMMA_STORE_D_LSTOS<string Layout, string Space,
+                         string Type, NVPTXRegClass regclass,
+                         Operand DstOp, int WithOffset, int WithStride>
+  : NVPTXInst<(outs),
+              !if(!eq(Type,"f16"),
+                !if(WithStride,
+                  !if(WithOffset,
+                      (ins DstOp:$src, i32imm:$offset,
+                           regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+                           Int32Regs:$ldm),
+                      (ins DstOp:$src,
+                           regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+                           Int32Regs:$ldm)),
+                  !if(WithOffset,
+                      (ins DstOp:$src, i32imm:$offset,
+                           regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3),
+                      (ins DstOp:$src,
+                           regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3))),
+                !if(WithStride,
+                  !if(WithOffset,
+                      (ins DstOp:$src, i32imm:$offset,
+                           regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+                           regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7,
+                           Int32Regs:$ldm),
+                      (ins DstOp:$src,
+                           regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+                           regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7,
+                           Int32Regs:$ldm)),
+                  !if(WithOffset,
+                      (ins DstOp:$src, i32imm:$offset,
+                           regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+                            regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7),
+                      (ins DstOp:$src,
+                           regclass:$r0, regclass:$r1, regclass:$r2, regclass:$r3,
+                           regclass:$r4, regclass:$r5, regclass:$r6, regclass:$r7)))),
+              "wmma.store.d.sync."#Layout#".m16n16k16"#Space#"." #Type# " \t"
+                 #!if(WithOffset,"[$src+$offset], ", "[$src], ")
+                 #!if(!eq(Type,"f16"),
+                      "{{$r0, $r1, $r2, $r3}}",
+                      "{{$r0, $r1, $r2, $r3, $r4, $r5, $r6, $r7}}")
+                 #!if(WithStride, ", $ldm", "")
+                 #";",
+              []>,
+    Requires<[hasPTX60, hasSM70]>;
+
+// Instantiates the strided ("_stride") and non-strided variants of one store.
+multiclass WMMA_STORE_D_LSTO<string Layout, string Space,
+                             string Type, NVPTXRegClass regclass,
+                             Operand DstOp, int WithOffset = 0> {
+  def _stride: WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp,
+                                  WithOffset, 1>;
+  def NAME:    WMMA_STORE_D_LSTOS<Layout, Space, Type, regclass, DstOp,
+                                  WithOffset, 0>;
+}
+
+// Instantiates the plain-address ("_avar", WithOffset = 0) and
+// address-plus-immediate ("_ari64", WithOffset = 1) forms, both with imemAny.
+multiclass WMMA_STORE_D_LST<string Layout, string Space,
+                            string Type, NVPTXRegClass regclass> {
+  defm _avar:  WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 0>;
+  defm _ari64: WMMA_STORE_D_LSTO<Layout, Space, Type, regclass, imemAny, 1>;
+}
+
+// Instantiates the ".global", ".shared" and no-state-space-suffix variants.
+multiclass WMMA_STORE_D_LT<string Layout,
+                           string Type, NVPTXRegClass regclass> {
+  defm _global: WMMA_STORE_D_LST<Layout, ".global", Type, regclass>;
+  defm _shared: WMMA_STORE_D_LST<Layout, ".shared", Type, regclass>;
+  defm NAME:    WMMA_STORE_D_LST<Layout,        "", Type, regclass>;
+}
+
+// Instantiates the "row" and "col" layouts.
+multiclass WMMA_STORE_D_T<string Type, NVPTXRegClass regclass> {
+  defm _row: WMMA_STORE_D_LT<"row", Type, regclass>;
+  defm _col: WMMA_STORE_D_LT<"col", Type, regclass>;
+}
+
+// Top-level store definitions for f16 (Float16x2Regs) and f32 (Float32Regs)
+// D fragments.
+defm INT_WMMA_STORE_D_f16: WMMA_STORE_D_T<"f16", Float16x2Regs>;
+defm INT_WMMA_STORE_D_f32: WMMA_STORE_D_T<"f32", Float32Regs>;
+
+// WMMA.MMA
+// Defines one wmma.mma.sync instruction variant computing D from A, B and C.
+//   ALayout, BLayout - "row" or "col" for the A and B fragments.
+//   DType, d_reg     - result type string and register class. "f16" gives
+//                      4 result registers, anything else gives 8.
+//   CType, c_reg     - accumulator type string and register class. "f16"
+//                      gives 4 input registers, anything else gives 8.
+//   ab_reg           - register class of the 8 A and 8 B input registers.
+//   Satfinite        - "" or ".satfinite", appended to the mnemonic.
+// The printed mnemonic is
+//   wmma.mma.sync.<ALayout>.<BLayout>.m16n16k16.<DType>.<CType><Satfinite>
+// followed by the D, A, B and C register lists, each on its own line.
+// The selection-pattern list is empty. The C++ selector maps intrinsics to
+// these records by name (see tryWMMA_MMA).
+class WMMA_MMA_ABDCS<string ALayout, string BLayout,
+                     string DType, NVPTXRegClass d_reg,
+                     string CType, NVPTXRegClass c_reg,
+                     NVPTXRegClass ab_reg,
+                     string Satfinite = "">
+  : NVPTXInst<!if(!eq(DType,"f16"),
+                  (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3),
+                  (outs d_reg:$d0, d_reg:$d1, d_reg:$d2, d_reg:$d3,
+                        d_reg:$d4, d_reg:$d5, d_reg:$d6, d_reg:$d7)),
+              !if(!eq(CType,"f16"),
+                  (ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
+                       ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
+                       ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
+                       ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
+                        c_reg:$c0,  c_reg:$c1,  c_reg:$c2,  c_reg:$c3),
+                  (ins ab_reg:$a0, ab_reg:$a1, ab_reg:$a2, ab_reg:$a3,
+                       ab_reg:$a4, ab_reg:$a5, ab_reg:$a6, ab_reg:$a7,
+                       ab_reg:$b0, ab_reg:$b1, ab_reg:$b2, ab_reg:$b3,
+                       ab_reg:$b4, ab_reg:$b5, ab_reg:$b6, ab_reg:$b7,
+                        c_reg:$c0,  c_reg:$c1,  c_reg:$c2,  c_reg:$c3,
+                        c_reg:$c4,  c_reg:$c5,  c_reg:$c6,  c_reg:$c7)),
+              "wmma.mma.sync."#ALayout#"."#BLayout#".m16n16k16."#
+                 #DType#"."#CType#Satfinite
+                 #"\n\t\t"
+                 #!if(!eq(DType,"f16"),
+                      "{{$d0, $d1, $d2, $d3}}, \n\t\t",
+                      "{{$d0, $d1, $d2, $d3, $d4, $d5, $d6, $d7}},\n\t\t")
+                 #"{{$a0, $a1, $a2, $a3, $a4, $a5, $a6, $a7}},\n\t\t"
+                 #"{{$b0, $b1, $b2, $b3, $b4, $b5, $b6, $b7}},\n\t\t"
+                 #!if(!eq(CType,"f16"),
+                      "{{$c0, $c1, $c2, $c3}};",
+                      "{{$c0, $c1, $c2, $c3, $c4, $c5, $c6, $c7}};"),
+              []>,
+    Requires<[hasPTX60, hasSM70]>;
+
+// Instantiates the ".satfinite" ("_satfinite" record suffix) and plain
+// variants. A and B always use Float16x2Regs.
+multiclass WMMA_MMA_ABDC<string ALayout, string BLayout,
+                         string DType, NVPTXRegClass d_reg,
+                         string CType, NVPTXRegClass c_reg> {
+  def _satfinite: WMMA_MMA_ABDCS<ALayout, BLayout,
+                                 DType, d_reg, CType, c_reg,
+                                 Float16x2Regs, ".satfinite">;
+  def NAME:       WMMA_MMA_ABDCS<ALayout, BLayout,
+                                 DType, d_reg, CType, c_reg,
+                                 Float16x2Regs>;
+}
+
+// Instantiates both accumulator (C) types; the C type is the second type
+// suffix in the record name.
+multiclass WMMA_MMA_ABD<string ALayout, string BLayout,
+                        string DType, NVPTXRegClass d_reg> {
+  defm _f16: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f16", Float16x2Regs>;
+  defm _f32: WMMA_MMA_ABDC<ALayout, BLayout, DType, d_reg, "f32", Float32Regs>;
+}
+
+// Instantiates both result (D) types; the D type is the first type suffix
+// in the record name.
+multiclass WMMA_MMA_AB<string ALayout, string BLayout> {
+  defm _f16: WMMA_MMA_ABD<ALayout, BLayout, "f16", Float16x2Regs>;
+  defm _f32: WMMA_MMA_ABD<ALayout, BLayout, "f32", Float32Regs>;
+}
+
+// Instantiates both B layouts for a given A layout.
+multiclass WMMA_MMA_A<string ALayout> {
+  defm _col: WMMA_MMA_AB<ALayout, "col">;
+  defm _row: WMMA_MMA_AB<ALayout, "row">;
+}
+
+// Top-level mma definitions. The resulting records are named
+// INT_WMMA_MMA_<ALayout>_<BLayout>_<DType>_<CType>[_satfinite], e.g.
+// INT_WMMA_MMA_row_col_f32_f32_satfinite, which the C++ selector
+// references directly.
+defm INT_WMMA_MMA_col: WMMA_MMA_A<"col">;
+defm INT_WMMA_MMA_row: WMMA_MMA_A<"row">;
+

Added: llvm/trunk/test/CodeGen/NVPTX/wmma.py
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/wmma.py?rev=315601&view=auto
==============================================================================
--- llvm/trunk/test/CodeGen/NVPTX/wmma.py (added)
+++ llvm/trunk/test/CodeGen/NVPTX/wmma.py Thu Oct 12 11:27:55 2017
@@ -0,0 +1,201 @@
+# This test generates all variants of wmma intrinsics and verifies that LLVM
+# generates correct instructions for them.
+
+# RUN: python %s > %t.ll
+# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 | FileCheck %t.ll
+
+from itertools import product
+from string import Template
+
+def make_wmma_slice_ty(abcd, itype):
+  elt_ty = "<2 x half>" if itype == "f16" else "float"
+  num_elts = 4 if abcd in "cd" and itype == "f16" else 8;
+  return [elt_ty] * num_elts
+
+def make_wmma_ld_ret_ty(abc, itype):
+  return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))
+
+# Convenient test patterns.
+check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
+check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
+check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
+
+def gen_wmma_load_tests():
+  load_template = """
+declare ${ret_ty} @llvm.nvvm.wmma.load.$intrinsic_suffix(i8* %src ${extra_args});
+
+; CHECK-LABEL: .func {{.*}}test_wmma_load_${function_suffix}(
+define ${ret_ty} @test_wmma_load_${function_suffix}(i8* %src ${extra_args}) {
+; CHECK wmma.load.${intrinsic_suffix}
+; CHECK: {${check_result}}
+; CHECK: [%rd{{[0-9]+}}]${stride_pattern}
+  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src ${extra_args});
+  ret ${ret_ty} %v0;
+}
+
+; CHECK-LABEL: .func{{.*}}test_wmma_load_${function_suffix}_o(
+define ${ret_ty} @test_wmma_load_${function_suffix}_o(i8* %src ${extra_args}) {
+; CHECK wmma.load.${intrinsic_suffix}
+; CHECK: {${check_result}}
+; CHECK: [%rd{{[0-9]+}}+128]${stride_pattern}
+  %src1 = getelementptr i8, i8* %src, i32 128;
+  %v0 = call ${ret_ty} @llvm.nvvm.wmma.load.${intrinsic_suffix}(i8* %src1 ${extra_args});
+  ret ${ret_ty} %v0;
+}
+"""
+  suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
+  instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
+
+  for abc, layout, space, stride, itype in product(
+      "abc",
+      ["row","col"],
+      ["",".shared",".global"],
+      ["", ".stride"],
+      ["f16", "f32"]):
+
+    params = {
+        "abc" : abc,
+        "layout" : layout,
+        "space" : space,
+        "stride" : stride,
+        "itype" : itype
+    }
+
+    if itype == "f32" and abc != "c":
+      continue
+
+    test_params = params
+    test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
+    test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_")
+    test_params["instruction_suffix"] = Template(instruction_template).substitute(params)
+    test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
+    if abc == "c" :
+      test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8
+    else:
+      test_params["check_result"] = check_f16_8
+
+    if stride:
+      test_params["extra_args"] = ", i32 %stride";
+      test_params["stride_pattern"] = ", %r{{[0-9]+}}"
+    else:
+      test_params["extra_args"] = ""
+      test_params["stride_pattern"] = ""
+
+    print(Template(load_template).substitute(test_params))
+
+def make_wmma_slice_args(itype, abcd, prefix="v"):
+  return ", ".join(["%s %%%s%d" % (t, prefix, i) for i,t
+                  in enumerate(make_wmma_slice_ty(abcd, itype))])
+
+def gen_wmma_store_tests():
+  store_template = """
+declare void @llvm.nvvm.wmma.store.$intrinsic_suffix(i8* %src, ${args}${extra_args});
+
+; CHECK-LABEL: .func {{.*}}test_wmma_store_${function_suffix}(
+define void @test_wmma_store_${function_suffix}(i8* %src, ${args}${extra_args}) {
+; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
+; CHECK: {${check_args}}
+; CHECK: ${stride_pattern}
+  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src, ${args} ${extra_args});
+  ret void
+}
+
+; CHECK-LABEL: .func{{.*}}test_wmma_store_${function_suffix}_o(
+define void @test_wmma_store_${function_suffix}_o(i8* %src, ${args}${extra_args}) {
+; CHECK wmma.store.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}+128]
+; CHECK: ${check_args}
+; CHECK: ${stride_pattern}
+  %src1 = getelementptr i8, i8* %src, i32 128;
+  call void @llvm.nvvm.wmma.store.${intrinsic_suffix}(i8* %src1, ${args}${extra_args});
+  ret void
+}
+"""
+  suffix_template = "${abc}.sync.${layout}.m16n16k16${space}${stride}.${itype}"
+  instruction_template = "${abc}.sync.${layout}.m16n16k16${space}.${itype}"
+
+  for abc, layout, space, stride, itype in product(
+      "d",
+      ["row","col"],
+      ["",".shared",".global"],
+      ["", ".stride"],
+      ["f16", "f32"]):
+
+    params = {
+        "abc" : abc,
+        "layout" : layout,
+        "space" : space,
+        "stride" : stride,
+        "itype" : itype
+    }
+
+    test_params = params
+    test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
+    test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".","_")
+    test_params["instruction_suffix"] = Template(instruction_template).substitute(params)
+    test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
+    test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
+    if stride:
+      test_params["extra_args"] = ", i32 %stride";
+      test_params["stride_pattern"] = ", %r{{[0-9]+}};"
+    else:
+      test_params["extra_args"] = ""
+      test_params["stride_pattern"] = ";"
+    test_params["args"] = make_wmma_slice_args(itype, "d");
+
+    print(Template(store_template).substitute(test_params))
+
+def gen_wmma_mma_tests():
+  mma_template = """
+declare ${ret_ty} @llvm.nvvm.wmma.mma.sync.$intrinsic_suffix(
+        ${args});
+
+; CHECK-LABEL: .func {{.*}}test_wmma_mma_${function_suffix}(
+define ${ret_ty} @test_wmma_mma_${function_suffix}(
+        ${args}) {
+; CHECK wmma.mma.${intrinsic_suffix} {{.*}}[%rd{{[0-9+]}}
+; CHECK ${check_d}
+; CHECK ${check_ab}
+; CHECK ${check_ab}
+; CHECK ${check_c}
+  %r = call ${ret_ty} @llvm.nvvm.wmma.mma.sync.${intrinsic_suffix}(
+        ${args});
+  ret ${ret_ty} %r;
+}
+"""
+  suffix_template = "${alayout}.${blayout}.m16n16k16.${dtype}.${ctype}${satf}"
+
+  for alayout, blayout, ctype, dtype, satf in product(
+      ["row","col"],
+      ["row","col"],
+      ["f16", "f32"],
+      ["f16", "f32"],
+      [".satfinite", ""]):
+
+    params = {
+        "alayout" : alayout,
+        "blayout" : blayout,
+        "ctype" : ctype,
+        "dtype" : dtype,
+        "satf"  : satf
+    }
+
+    test_params = params
+    test_params["intrinsic_suffix"] = Template(suffix_template).substitute(params)
+    test_params["function_suffix"] = test_params["intrinsic_suffix"].replace(".", "_")
+    test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
+    test_params["check_ab"] = check_f16_8
+    test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8
+    test_params["check_d"] = check_f16_4 if dtype == "f16" else check_f32_8
+    args = ",\n        ".join(make_wmma_slice_args(t, abcd, prefix=abcd)
+                              for abcd, t in (("a", "f16"),
+                                              ("b", "f16"),
+                                              ("c", ctype)))
+    test_params["args"] = args
+    print(Template(mma_template).substitute(test_params))
+
+def main():
+  gen_wmma_load_tests()
+  gen_wmma_store_tests()
+  gen_wmma_mma_tests()
+
+main()




More information about the llvm-commits mailing list