[llvm] r359247 - PTX 6.3 extends `wmma` instruction to support s8/u8/s4/u4/b1 -> s32.

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 25 15:27:58 PDT 2019


Author: tra
Date: Thu Apr 25 15:27:57 2019
New Revision: 359247

URL: http://llvm.org/viewvc/llvm-project?rev=359247&view=rev
Log:
PTX 6.3 extends `wmma` instruction to support s8/u8/s4/u4/b1 -> s32.

All of the new instructions are still mostly handled by TableGen. I've slightly
refactored the code to drive intrinsic/instruction generation from a master
list of supported variants, so all irregularities need to be handled in one place only.

The test generation script wmma.py has been refactored in a similar way.

Differential Revision: https://reviews.llvm.org/D60015

Modified:
    llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
    llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
    llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
    llvm/trunk/test/CodeGen/NVPTX/wmma.py

Modified: llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td?rev=359247&r1=359246&r2=359247&view=diff
==============================================================================
--- llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td (original)
+++ llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td Thu Apr 25 15:27:57 2019
@@ -50,6 +50,7 @@ class WMMA_REGS<string Geom, string Frag
   string geom = Geom;
   string frag = Frag;
   string ptx_elt_type = PtxEltType;
+  string gft = Geom#":"#Frag#":"#ptx_elt_type;
   string ft = frag#":"#ptx_elt_type;
   list<LLVMType> regs = !cond(
     // fp16 -> fp16/fp32 @  m16n16k16/m8n32k16/m32n8k16
@@ -60,7 +61,42 @@ class WMMA_REGS<string Geom, string Frag
     !eq(ft,"c:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
     !eq(ft,"d:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
     !eq(ft,"c:f32") : RepLLVMType<8, llvm_float_ty>.ret,
-    !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret);
+    !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret,
+
+    // u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
+    !eq(gft,"m16n16k16:a:u8") : RepLLVMType<2, llvm_i32_ty>.ret,
+    !eq(gft,"m16n16k16:a:s8") : RepLLVMType<2, llvm_i32_ty>.ret,
+    !eq(gft,"m16n16k16:b:u8") : RepLLVMType<2, llvm_i32_ty>.ret,
+    !eq(gft,"m16n16k16:b:s8") : RepLLVMType<2, llvm_i32_ty>.ret,
+    !eq(gft,"m16n16k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+    !eq(gft,"m16n16k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+
+    !eq(gft,"m8n32k16:a:u8") : [llvm_i32_ty],
+    !eq(gft,"m8n32k16:a:s8") : [llvm_i32_ty],
+    !eq(gft,"m8n32k16:b:u8") : RepLLVMType<4, llvm_i32_ty>.ret,
+    !eq(gft,"m8n32k16:b:s8") : RepLLVMType<4, llvm_i32_ty>.ret,
+    !eq(gft,"m8n32k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+    !eq(gft,"m8n32k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+
+    !eq(gft,"m32n8k16:a:u8") : RepLLVMType<4, llvm_i32_ty>.ret,
+    !eq(gft,"m32n8k16:a:s8") : RepLLVMType<4, llvm_i32_ty>.ret,
+    !eq(gft,"m32n8k16:b:u8") : [llvm_i32_ty],
+    !eq(gft,"m32n8k16:b:s8") : [llvm_i32_ty],
+    !eq(gft,"m32n8k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+    !eq(gft,"m32n8k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+
+    // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
+    !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
+    !eq(gft,"m8n8k32:a:u4") : [llvm_i32_ty],
+    !eq(gft,"m8n8k32:a:s4") : [llvm_i32_ty],
+    !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
+    !eq(gft,"m8n8k32:b:u4") : [llvm_i32_ty],
+    !eq(gft,"m8n8k32:b:s4") : [llvm_i32_ty],
+    !eq(gft,"m8n8k128:c:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
+    !eq(gft,"m8n8k128:d:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
+    !eq(gft,"m8n8k32:c:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
+    !eq(gft,"m8n8k32:d:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
+  );
 }
 
 class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
@@ -84,22 +120,162 @@ class WMMA_NAME_LDST<string Op, WMMA_REG
                 # !if(WithStride, "_stride", "");
 }
 
-class WMMA_NAME_MMA<string ALayout, string BLayout,
-                    WMMA_REGS C, WMMA_REGS D,
-                    int Satfinite> {
+class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> { // Type suffix of an MMA intrinsic name (B is not consulted).
+  list<WMMA_REGS> id_frags = !cond( // Fragments whose element types make up the suffix.
+     // int and sub-int ops are identified by input type.
+     !eq(A.ptx_elt_type, "s8") : [A],
+     !eq(A.ptx_elt_type, "u8") : [A],
+     !eq(A.ptx_elt_type, "s4") : [A],
+     !eq(A.ptx_elt_type, "u4") : [A],
+     !eq(A.ptx_elt_type, "b1") : [A],
+     // the rest are FP ops identified by accumulator & result type.
+     1: [D, C] // D comes first, e.g. D=f32, C=f16 yields ".f32.f16".
+     );
+   string ret = !foldl("", id_frags, a, b, !strconcat(a, ".", b.ptx_elt_type)); // "." # element type per fragment, e.g. ".s8".
+}
+
+class WMMA_NAME_MMA<string ALayout, string BLayout, int Satfinite,
+                    WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
   string llvm = "llvm.nvvm.wmma."
-                # C.geom
+                # A.geom
                 # ".mma"
                 # "." # ALayout
                 # "." # BLayout
-                # "." # D.ptx_elt_type  // Intrinsic encodes 'd' first.
-                # "." # C.ptx_elt_type
+                # signature
                 # !if(Satfinite, ".satfinite", "");
 
   string record = !subst(".", "_",
                   !subst("llvm.", "int_", llvm));
 }
 
+// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
+//   Geom: list of supported geometries.
+//   TypeN: PTX type of the corresponding fragment's element.
+//   TypeB (TypeD) may be empty, in which case B (D) uses the type of A (C).
+class MMA_OPS<list<string> Geom, list<string> TypeA, list<string> TypeB,
+            list<string> TypeC, list<string> TypeD> {
+  list<list<WMMA_REGS>> ret =
+     !foldl([]<list<WMMA_REGS>>, Geom, t1, geom, !listconcat(t1,
+     !foldl([]<list<WMMA_REGS>>, TypeA, t2, type_a, !listconcat(t2,
+     !foldl([]<list<WMMA_REGS>>, !if(!size(TypeB), TypeB, [type_a]), t3, type_b, !listconcat(t3,
+     !foldl([]<list<WMMA_REGS>>, TypeC, t4, type_c, !listconcat(t4,
+     !foldl([]<list<WMMA_REGS>>, !if(!size(TypeD), TypeD, [type_c]), t5, type_d, !listconcat(t5,
+            [[WMMA_REGS<geom, "a", type_a>,
+              WMMA_REGS<geom, "b", type_b>,
+              WMMA_REGS<geom, "c", type_c>,
+              WMMA_REGS<geom, "d", type_d>]]))))))))));
+   // Debugging aid for readable representation of the list above.
+   list<list<string>> ops = !foreach(x, ret, [x[0].gft, x[1].gft, x[2].gft, x[3].gft]);
+}
+
+class MMA_LDST_OPS<list<string> Geom, list<string> Frags, list<string> Types> { // One WMMA_REGS per Geom x Frags x Types combination.
+  list<WMMA_REGS> ret = // Ordered by geom, then frag, then type.
+     !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
+     !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
+     !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
+            [WMMA_REGS<geom, frag, type>]))))));
+   // Debugging aid for readable representation of the list above.
+   list<string> ops = !foreach(x, ret, x.gft);
+}
+
+
+
+// Creates list of valid combinations of fragments. This is the master list that
+// drives generation of corresponding intrinsics and instructions.
+class NVVM_MMA_OPS<int _ = 0> {
+  list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS< // f16 A and B; every f16/f32 C x D pairing.
+            ["m16n16k16", "m32n8k16", "m8n32k16"],
+            ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+  list<list<WMMA_REGS>> int_mma_ops = MMA_OPS< // s8/u8 A (B same as A); s32 C and D.
+            ["m16n16k16", "m32n8k16", "m8n32k16"],
+            ["s8", "u8"], [], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS< // s4/u4 A (B same as A); s32 C and D.
+            ["m8n8k32"],
+            ["s4", "u4"], [], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS< // b1 A and B; s32 C and D.
+            ["m8n8k128"],
+            ["b1"], [], ["s32"], []>.ret;
+  list<list<WMMA_REGS>> all_mma_ops = !listconcat(fp_mma_ops, int_mma_ops,
+                                                  subint_mma_ops, bit_mma_ops);
+
+  list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
+            ["m16n16k16", "m32n8k16", "m8n32k16"],
+            ["a", "b"], ["f16", "u8", "s8"]>.ret;
+  list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
+            ["m16n16k16", "m32n8k16", "m8n32k16"],
+            ["c", "d"], ["f16", "f32", "s32"]>.ret;
+  list<WMMA_REGS> ldst_subint_ab_ops = MMA_LDST_OPS<
+            ["m8n8k32"], ["a", "b"], ["s4","u4"]>.ret;
+  list<WMMA_REGS> ldst_bit_ab_ops = MMA_LDST_OPS<
+            ["m8n8k128"], ["a", "b"], ["b1"]>.ret;
+  list<WMMA_REGS> ldst_subint_cd_ops = MMA_LDST_OPS< // s4/u4 and b1 ops share s32 C/D fragments.
+            ["m8n8k32", "m8n8k128"],  ["c", "d"], ["s32"]>.ret;
+  list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops,
+                                             ldst_subint_ab_ops,
+                                             ldst_bit_ab_ops,
+                                             ldst_subint_cd_ops);
+  // Separate A/B/C fragments (loads) from D (stores).
+  list<WMMA_REGS> all_ld_ops = !foldl([]<WMMA_REGS>, all_ldst_ops, a, b,
+                                      !listconcat(a, !if(!eq(b.frag,"d"), [],[b])));
+  list<WMMA_REGS> all_st_ops = !foldl([]<WMMA_REGS>, all_ldst_ops, a, b,
+                                      !listconcat(a, !if(!eq(b.frag,"d"), [b],[])));
+}
+
+def NVVM_MMA_OPS : NVVM_MMA_OPS; // Instance whose lists drive the foreach loops below.
+
+// Returns [1] if this combination of layout/satf is supported, [] otherwise.
+// MMA ops must provide all parameters. Loads and stores -- only frags and layout_a.
+// The class is used to prevent generation of records for the unsupported variants.
+// E.g.
+// foreach _ = NVVM_MMA_SUPPORTED<...>.ret in
+//   def : FOO<>; // The record will only be defined for supported ops.
+//
+//   frags:    fragments of the op; only frags[0] is examined (A for MMA ops).
+//   layout_a: layout of A (MMA) or of the loaded/stored fragment.
+//   layout_b: layout of B; the default "-" marks a load or store op.
+//   satf:     .satfinite flag of an MMA op; -1 (default) for loads/stores.
+class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b="-", int satf=-1> {
+  // MMA ops check both layouts.
+  string mma = frags[0].ptx_elt_type
+               # ":" # layout_a
+               # ":" # layout_b;
+  // Load ops only need type/fragment/layout.
+  string ld = frags[0].ptx_elt_type
+               # ":" # frags[0].frag
+               # ":" # layout_a
+               ;
+  // Element type that decides whether sub-int restrictions apply.
+  string t = frags[0].ptx_elt_type;
+  list<int> ret = !cond(
+    // Sub-int MMA only supports fixed A/B layout (A row, B col).
+    // b1 does not support .satf, so it is only accepted with satf == 0.
+    !eq(mma#":"#satf, "b1:row:col:0") : [1],
+    !eq(mma, "s4:row:col") : [1],
+    !eq(mma, "u4:row:col") : [1],
+    // Any other sub-int MMA variant (other layouts, or b1 with .satf) falls
+    // through to the "not supported" entries below.
+    // Sub-int load/stores have fixed layout for A and B.
+    !and(!eq(layout_b, "-"), // It's a Load or Store op
+         !or(!eq(ld, "b1:a:row"),
+             !eq(ld, "b1:b:col"),
+             !eq(ld, "s4:a:row"),
+             !eq(ld, "s4:b:col"),
+             !eq(ld, "u4:a:row"),
+             !eq(ld, "u4:b:col"))) : [1],
+    // C and D fragments of sub-int ops carry s32 elements (see
+    // ldst_subint_cd_ops), so they never match the b1/s4/u4 entries and are
+    // accepted with either layout by the final catch-all entry.
+    //
+    // All other sub-int ops are not supported.
+    !eq(t, "b1") : [],
+    !eq(t, "s4") : [],
+    !eq(t, "u4") : [],
+    // All other (non sub-int) are OK.
+    1: [1]
+  );
+}
+
 let TargetPrefix = "nvvm" in {
   def int_nvvm_prmt : GCCBuiltin<"__nvvm_prmt">,
       Intrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
@@ -3970,51 +4146,41 @@ class NVVM_WMMA_ST<WMMA_REGS Frag, strin
               WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.intr>;
 
 // Create all load/store variants 
-foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
-  foreach layout = ["row", "col"] in {
-    foreach stride = [0, 1] in {
-      foreach frag = [WMMA_REGS<geom, "a", "f16">,
-                      WMMA_REGS<geom, "b", "f16">,
-                      WMMA_REGS<geom, "c", "f16">,
-                      WMMA_REGS<geom, "c", "f32">] in {
-          def WMMA_NAME_LDST<"load", frag, layout, stride>.record
+foreach layout = ["row", "col"] in {
+  foreach stride = [0, 1] in {
+    foreach frag = NVVM_MMA_OPS.all_ld_ops in
+      foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in
+        def WMMA_NAME_LDST<"load", frag, layout, stride>.record
              : NVVM_WMMA_LD<frag, layout, stride>;
-      }
-      foreach frag = [WMMA_REGS<geom, "d", "f16">,
-                      WMMA_REGS<geom, "d", "f32">] in {
-          def WMMA_NAME_LDST<"store", frag, layout, stride>.record
+    foreach frag = NVVM_MMA_OPS.all_st_ops in
+      foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in
+        def WMMA_NAME_LDST<"store", frag, layout, stride>.record
              : NVVM_WMMA_ST<frag, layout, stride>;
-      }
-    }
   }
 }
 
 // WMMA.MMA
-class NVVM_WMMA_MMA<string ALayout, string BLayout,
-                    WMMA_REGS C, WMMA_REGS D, int Satfinite>
+class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite,
+                    WMMA_REGS A, WMMA_REGS B,
+                    WMMA_REGS C, WMMA_REGS D>
   : Intrinsic<D.regs,
-              !listconcat(
-                WMMA_REGS<C.geom, "a", "f16">.regs,
-                WMMA_REGS<C.geom, "b", "f16">.regs,
-                C.regs),
+              !listconcat(A.regs, B.regs, C.regs),
               [IntrNoMem],
-              WMMA_NAME_MMA<ALayout, BLayout, C, D, Satfinite>.llvm>;
+              WMMA_NAME_MMA<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
 
-foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
-  foreach layout_a = ["row", "col"] in {
-    foreach layout_b = ["row", "col"] in {
-      foreach frag_c = [WMMA_REGS<geom, "c", "f16">,
-                        WMMA_REGS<geom, "c", "f32">] in {
-        foreach frag_d = [WMMA_REGS<geom, "d", "f16">,
-                          WMMA_REGS<geom, "d", "f32">] in {
-          foreach satf = [0, 1] in {
-            def WMMA_NAME_MMA<layout_a, layout_b, frag_c, frag_d, satf>.record
-             : NVVM_WMMA_MMA<layout_a, layout_b, frag_c, frag_d, satf>;
-          }
+foreach layout_a = ["row", "col"] in {
+  foreach layout_b = ["row", "col"] in {
+    foreach satf = [0, 1] in {
+      foreach op = NVVM_MMA_OPS.all_mma_ops in {
+        foreach _ = NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret in {
+          def WMMA_NAME_MMA<layout_a, layout_b, satf,
+                            op[0], op[1], op[2], op[3]>.record
+            : NVVM_WMMA_MMA<layout_a, layout_b, satf,
+                            op[0], op[1], op[2], op[3]>;
         }
       }
-    }
-  }
-}
+    } // satf
+  } // layout_b
+} // layout_a
 
 } // let TargetPrefix = "nvvm"

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp?rev=359247&r1=359246&r2=359247&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp Thu Apr 25 15:27:57 2019
@@ -3500,6 +3500,94 @@ bool NVPTXTargetLowering::getTgtMemIntri
     Info.align = 16;
     return true;
   }
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row: {
+    Info.opc = ISD::INTRINSIC_W_CHAIN;
+    Info.memVT = MVT::v2i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOLoad;
+    Info.align = 8;
+    return true;
+  }
+
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
+
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row: {
+    Info.opc = ISD::INTRINSIC_W_CHAIN;
+    Info.memVT = MVT::v4i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOLoad;
+    Info.align = 16;
+    return true;
+  }
+
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
+
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
+  case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
+  case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
+  case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
+  case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
+  case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
+  case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
+  case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
+  case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
+  case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
+  case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
+  case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
+  case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col: {
+    Info.opc = ISD::INTRINSIC_W_CHAIN;
+    Info.memVT = MVT::i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOLoad;
+    Info.align = 4;
+    return true;
+  }
 
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
   case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
@@ -3543,6 +3631,44 @@ bool NVPTXTargetLowering::getTgtMemIntri
     return true;
   }
 
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
+    Info.opc = ISD::INTRINSIC_W_CHAIN;
+    Info.memVT = MVT::v8i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOLoad;
+    Info.align = 16;
+    return true;
+  }
+
+  case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
+  case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
+  case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
+  case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
+  case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
+  case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
+  case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
+  case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride: {
+    Info.opc = ISD::INTRINSIC_W_CHAIN;
+    Info.memVT = MVT::v2i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOLoad;
+    Info.align = 8;
+    return true;
+  }
+
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
   case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
@@ -3585,6 +3711,44 @@ bool NVPTXTargetLowering::getTgtMemIntri
     return true;
   }
 
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
+  case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
+  case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
+  case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::v8i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align = 16;
+    return true;
+  }
+
+  case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
+  case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
+  case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
+  case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
+  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
+  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
+  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
+  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::v2i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align = 8;
+    return true;
+  }
+
   case Intrinsic::nvvm_atomic_load_add_f32:
   case Intrinsic::nvvm_atomic_load_add_f64:
   case Intrinsic::nvvm_atomic_load_inc_32:

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td?rev=359247&r1=359246&r2=359247&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td Thu Apr 25 15:27:57 2019
@@ -142,9 +142,12 @@ def true : Predicate<"true">;
 def hasPTX31 : Predicate<"Subtarget->getPTXVersion() >= 31">;
 def hasPTX60 : Predicate<"Subtarget->getPTXVersion() >= 60">;
 def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">;
+def hasPTX63 : Predicate<"Subtarget->getPTXVersion() >= 63">;
 
 def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
 def hasSM70 : Predicate<"Subtarget->getSmVersion() >= 70">;
+def hasSM72 : Predicate<"Subtarget->getSmVersion() >= 72">;
+def hasSM75 : Predicate<"Subtarget->getSmVersion() >= 75">;
 
 def useShortPtr : Predicate<"useShortPointers()">;
 def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;

Modified: llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td?rev=359247&r1=359246&r2=359247&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td Thu Apr 25 15:27:57 2019
@@ -7406,12 +7406,18 @@ def INT_PTX_SREG_WARPSIZE :
 // In addition to target-independent fields provided by WMMA_REGS, it adds
 // the fields commonly used to implement specific PTX instruction -- register
 // types and names, constraints, parts of assembly, etc.
-class WMMA_REGINFO<string Geom, string Frag, string PtxEltType>
-      : WMMA_REGS<Geom, Frag, PtxEltType> {
+class WMMA_REGINFO<WMMA_REGS r>
+      : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type> {
   // NVPTX register types used to carry fragment data.
   NVPTXRegClass regclass = !cond(
-    !eq(PtxEltType, "f16") : Float16x2Regs,
-    !eq(PtxEltType, "f32") : Float32Regs);
+    !eq(ptx_elt_type, "f16") : Float16x2Regs,
+    !eq(ptx_elt_type, "f32") : Float32Regs,
+    !eq(ptx_elt_type, "s32") : Int32Regs,
+    !eq(ptx_elt_type, "s8") : Int32Regs,
+    !eq(ptx_elt_type, "u8") : Int32Regs,
+    !eq(ptx_elt_type, "s4") : Int32Regs,
+    !eq(ptx_elt_type, "u4") : Int32Regs,
+    !eq(ptx_elt_type, "b1") : Int32Regs);
 
   // Instruction input/output arguments for the fragment.
   list<NVPTXRegClass> ptx_regs = !foreach(tmp, regs, regclass);
@@ -7433,15 +7439,27 @@ class WMMA_REGINFO<string Geom, string F
   // all fragments of the instruction are viable.
   list<Predicate> Predicates = !cond(
     // fp16 -> fp16/fp32 @ m16n16k16
-    !and(!eq(Geom, "m16n16k16"),
-         !or(!eq(PtxEltType, "f16"),
-             !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX60],
+    !and(!eq(geom, "m16n16k16"),
+         !or(!eq(ptx_elt_type, "f16"),
+             !eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX60],
 
     // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
-    !and(!or(!eq(Geom, "m8n32k16"),
-             !eq(Geom, "m32n8k16")),
-         !or(!eq(PtxEltType, "f16"),
-             !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX61]);
+    !and(!or(!eq(geom, "m8n32k16"),
+             !eq(geom, "m32n8k16")),
+         !or(!eq(ptx_elt_type, "f16"),
+             !eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX61],
+
+    // u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
+    !and(!or(!eq(geom,"m16n16k16"),
+             !eq(geom,"m8n32k16"),
+             !eq(geom,"m32n8k16")),
+         !or(!eq(ptx_elt_type, "u8"),
+             !eq(ptx_elt_type, "s8"),
+             !eq(ptx_elt_type, "s32"))) : [hasSM72, hasPTX63],
+
+    // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
+    !or(!eq(geom,"m8n8k128"),
+        !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63]);
 
   // template DAGs for instruction inputs/output.
   dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7559,44 +7577,48 @@ class WMMA_STORE_D<WMMA_REGINFO Frag, st
 
 // Create all load/store variants
 defset list<WMMA_INSTR> MMA_LDSTs  = {
-  foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
-    foreach layout = ["row", "col"] in {
-      foreach stride = [0, 1] in {
-        foreach space = [".global", ".shared", ""] in {
-          foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
-            foreach frag = [WMMA_REGINFO<geom, "a", "f16">,
-                            WMMA_REGINFO<geom, "b", "f16">,
-                            WMMA_REGINFO<geom, "c", "f16">,
-                            WMMA_REGINFO<geom, "c", "f32">] in {
-                def : WMMA_LOAD<frag, layout, space, stride, addr>;
-            }
-            foreach frag = [WMMA_REGINFO<geom, "d", "f16">,
-                            WMMA_REGINFO<geom, "d", "f32">] in {
-                def : WMMA_STORE_D<frag, layout, space, stride, addr>;
-            }
-          } // addr
-        } // space
-      } // stride
-    } // layout
-  } // geom
+  foreach layout = ["row", "col"] in {
+    foreach stride = [0, 1] in {
+      foreach space = [".global", ".shared", ""] in {
+        foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
+          foreach frag = NVVM_MMA_OPS.all_ld_ops in
+            foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in
+              def : WMMA_LOAD<WMMA_REGINFO<frag>, layout, space, stride, addr>;
+          foreach frag = NVVM_MMA_OPS.all_st_ops in
+            foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in
+              def : WMMA_STORE_D<WMMA_REGINFO<frag>, layout, space, stride, addr>;
+        } // addr
+      } // space
+    } // stride
+  } // layout
 } // defset
 
 // WMMA.MMA
 class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                WMMA_REGINFO FragC, WMMA_REGINFO FragD,
                string ALayout, string BLayout, int Satfinite>
-  : WMMA_INSTR<WMMA_NAME_MMA<ALayout, BLayout, FragC, FragD, Satfinite>.record,
+  : WMMA_INSTR<WMMA_NAME_MMA<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
                              [FragA.Ins, FragB.Ins, FragC.Ins]>,
-    Requires<FragC.Predicates> {
+    // Requires does not seem to have effect on Instruction w/o Patterns.
+    // We set it here anyways and propagate to the Pat<> we construct below.
+    Requires<FragA.Predicates> {
   let OutOperandList = FragD.Outs;
   let InOperandList  = !con(Args, (ins MmaCode:$ptx));
-  let AsmString = "wmma.mma.sync"
+  string TypeList = !cond(
+    !eq(FragD.ptx_elt_type, "s32") : ".s32"
+                                     # "." # FragA.ptx_elt_type
+                                     # "." # FragB.ptx_elt_type
+                                     # ".s32",
+    1: "." # FragD.ptx_elt_type # "." # FragC.ptx_elt_type,
+  );
+  let AsmString = "wmma.mma"
+                  # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
+                  # ".sync"
                   # "${ptx:aligned}"
                   # "." # ALayout
                   # "." # BLayout
                   # "." # FragA.geom
-                  # "." # FragD.ptx_elt_type
-                  # "." # FragC.ptx_elt_type
+                  # TypeList
                   # !if(Satfinite, ".satfinite", "") # "\n\t\t"
                   # FragD.regstring # ",\n\t\t"
                   # FragA.regstring # ",\n\t\t"
@@ -7605,32 +7627,32 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_
 }
 
 defset list<WMMA_INSTR> MMAs  = {
-  foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
-    foreach layout_a = ["row", "col"] in {
-      foreach layout_b = ["row", "col"] in {
-        foreach frag_c = [WMMA_REGINFO<geom, "c", "f16">,
-                          WMMA_REGINFO<geom, "c", "f32">] in {
-          foreach frag_d = [WMMA_REGINFO<geom, "d", "f16">,
-                            WMMA_REGINFO<geom, "d", "f32">] in {
-            foreach satf = [0, 1] in {
-              def : WMMA_MMA<WMMA_REGINFO<geom, "a", "f16">,
-                             WMMA_REGINFO<geom, "b", "f16">,
-                             frag_c, frag_d, layout_a, layout_b, satf>;
-            } // satf
-          } // frag_d
-        } // frag_c
-      } // layout_b
-    } // layout_a
-  } // geom
+  foreach layout_a = ["row", "col"] in {
+    foreach layout_b = ["row", "col"] in {
+      foreach satf = [0, 1] in {
+        foreach op = NVVM_MMA_OPS.all_mma_ops in {
+          foreach _ = NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret in {
+            def : WMMA_MMA<WMMA_REGINFO<op[0]>,
+                           WMMA_REGINFO<op[1]>,
+                           WMMA_REGINFO<op[2]>,
+                           WMMA_REGINFO<op[3]>,
+                           layout_a, layout_b, satf>;
+          }
+        } // op
+      } // satf
+    } // layout_b
+  } // layout_a
 } // defset
 
+
 // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
 // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
 // the instruction record.
 class WMMA_PAT<WMMA_INSTR wi>
       : Pat<wi.IntrinsicPattern,
             !con(!foreach(tmp, wi.Args, !subst(ins, wi, tmp)),
-                 (wi ptx.version))>;
+                 (wi ptx.version))>,
+        Requires<wi.Predicates>;
 
 // Build intrinsic->instruction patterns for all MMA instructions.
 foreach mma = !listconcat(MMAs, MMA_LDSTs) in

Modified: llvm/trunk/test/CodeGen/NVPTX/wmma.py
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/wmma.py?rev=359247&r1=359246&r2=359247&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/NVPTX/wmma.py (original)
+++ llvm/trunk/test/CodeGen/NVPTX/wmma.py Thu Apr 25 15:27:57 2019
@@ -1,10 +1,42 @@
 # This test generates all variants of wmma intrinsics and verifies that LLVM
 # generates correct instructions for them.
 
-# RUN: python %s > %t.ll
-# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll
-# RUN: python %s --ptx=63 > %t-ptx63.ll
-# RUN: llc < %t-ptx63.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx63 | FileCheck %t-ptx63.ll
+# Check all variants of instructions supported by PTX60 on SM70
+# RUN: python %s --ptx=60 --gpu-arch=70 > %t-ptx60-sm_70.ll
+# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
+# RUN:           --check-prefixes=INTRINSICS,PTX60,SM70
+# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
+# RUN:           --check-prefixes=INTRINSICS,PTX60U,SM70U
+# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
+# RUN:           | FileCheck %t-ptx60-sm_70.ll
+
+# Check all variants of instructions supported by PTX61 on SM70
+# RUN: python %s --ptx=61 --gpu-arch=70 > %t-ptx61-sm_70.ll
+# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
+# RUN:           --check-prefixes=INTRINSICS,PTX60,PTX61,SM70
+# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
+# RUN:           --check-prefixes=INTRINSICS,PTX61U,SM70U
+# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
+# RUN:           | FileCheck %t-ptx61-sm_70.ll
+
+# Check all variants of instructions supported by PTX63 on SM72
+# RUN: python %s --ptx=63 --gpu-arch=72 > %t-ptx63-sm_72.ll
+# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
+# RUN:           --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72
+# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
+# RUN:           --check-prefixes=INTRINSICS,PTX63U,SM72U
+# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
+# RUN:           | FileCheck %t-ptx63-sm_72.ll
+
+# Check all variants of instructions supported by PTX63 on SM75
+# RUN: python %s --ptx=63 --gpu-arch=75 > %t-ptx63-sm_75.ll
+# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
+# RUN:           --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72,SM75
+# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
+# RUN:           --check-prefixes=INTRINSICS,PTX63U,SM75U
+# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
+# RUN:           | FileCheck %t-ptx63-sm_75.ll
+
 
 from __future__ import print_function
 
@@ -12,13 +44,174 @@ import argparse
 from itertools import product
 from string import Template
 
-def make_wmma_slice_ty(abcd, itype):
-  elt_ty = "<2 x half>" if itype == "f16" else "float"
-  num_elts = 4 if abcd in "cd" and itype == "f16" else 8;
-  return [elt_ty] * num_elts
-
-def make_wmma_ld_ret_ty(abc, itype):
-  return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))
+class MMAType:
+  def __init__(self, ptx_type):
+    self.ptx_type = ptx_type
+    self.llvm_type = {
+        "f16" : "<2 x half>",
+        "f32" : "float",
+        "s32" : "i32",
+        "s8"  : "i32",
+        "u8"  : "i32",
+        "s4"  : "i32",
+        "u4"  : "i32",
+        "b1"  : "i32",
+    }[ptx_type];
+
+    self.ptx_reg_pattern = {
+        "f16" : "%hh[0-9]+",
+        "f32" : "%f[0-9]+",
+    }.get(ptx_type, "%r[0-9]+")
+
+  def __repr__(self):
+    return "%s/%s" % (self.ptx_type, self.llvm_type)
+
+class MMAFrag:
+  def __init__(self, geom, frag, ptx_elt_type):
+    self.geom = geom
+    self.frag = frag
+    self.mma_type = MMAType(ptx_elt_type);
+    self.nregs = {
+        "a:f16" : 8,
+        "b:f16" : 8,
+        "c:f16" : 4,
+        "d:f16" : 4,
+        "c:f32" : 8,
+        "d:f32" : 8,
+    }.get("%s:%s" % (frag, ptx_elt_type), {
+        # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
+        "m16n16k16:a:u8" : 2,
+        "m16n16k16:a:s8" : 2,
+        "m16n16k16:b:u8" : 2,
+        "m16n16k16:b:s8" : 2,
+        "m16n16k16:c:s32" : 8,
+        "m16n16k16:d:s32" : 8,
+
+        "m8n32k16:a:u8" : 1,
+        "m8n32k16:a:s8" : 1,
+        "m8n32k16:b:u8" : 4,
+        "m8n32k16:b:s8" : 4,
+        "m8n32k16:c:s32" : 8,
+        "m8n32k16:d:s32" : 8,
+
+        "m32n8k16:a:u8" : 4,
+        "m32n8k16:a:s8" : 4,
+        "m32n8k16:b:u8" : 1,
+        "m32n8k16:b:s8" : 1,
+        "m32n8k16:c:s32" : 8,
+        "m32n8k16:d:s32" : 8,
+
+        # u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
+        "m8n8k128:a:b1" : 1,
+        "m8n8k32:a:u4" : 1,
+        "m8n8k32:a:s4" : 1,
+        "m8n8k128:b:b1" : 1,
+        "m8n8k32:b:u4" : 1,
+        "m8n8k32:b:s4" : 1,
+        "m8n8k128:c:s32" : 2,
+        "m8n8k128:d:s32" : 2,
+        "m8n8k32:c:s32" : 2,
+        "m8n8k32:d:s32" : 2,
+    }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), None));
+    assert(self.nregs);
+
+  def __repr__(self):
+    return "%s:%s:%s%s" % (self.geom, self.frag, self.mma_type,
+                           "" if self.nregs == 1 else ("*%d" % self.nregs))
+
+class MMAOp:
+  def __init__(self, a, b, c, d):
+    self.a = a
+    self.b = b
+    self.c = c
+    self.d = d
+
+  def __repr__(self):
+    return ("{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d ))
+
+def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
+  ops = []
+  for geom, type_a, type_c in product( geoms,  types_a, types_c):
+    for type_b, type_d in product(types_b if types_b else [type_a],
+                                  types_d if types_d else [type_c]):
+      ops.append(MMAOp(MMAFrag(geom, "a", type_a),
+                       MMAFrag(geom, "b", type_b),
+                       MMAFrag(geom, "c", type_c),
+                       MMAFrag(geom, "d", type_d)))
+  return ops
+
+def make_ldst_ops(geoms, frags, types):
+  return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
+          in product(geoms, frags, types)]
+
+def get_mma_ops():
+  return (make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+                       ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
+          make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+                       ["s8", "u8"], [], ["s32"], []) +
+          make_mma_ops(["m8n8k32"],
+                       ["s4", "u4"], [], ["s32"], []) +
+          make_mma_ops(["m8n8k128"],
+                       ["b1"], [], ["s32"], []))
+def get_ldst_ops(kind):
+  ldst_ops = (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+                            ["a", "b"], ["f16", "u8", "s8"]) +
+              make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+                            ["c", "d"], ["f16", "f32", "s32"]) +
+              make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) +
+              make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
+              make_ldst_ops(["m8n8k32", "m8n8k128"],  ["c", "d"], ["s32"]))
+  return [ x for x in ldst_ops if (x.frag == "d") == (kind == "store")]
+
+def is_geom_supported(geom):
+  # geometries for FP and ints.
+  if geom in ["m8n32k16", "m32n8k16"]:
+    return ptx_version >= 61
+  # geometries for sub-ints.
+  if geom in ["m8n8k32", "m8n8k128"]:
+    return ptx_version >= 63 and gpu_arch >= 75
+  if geom == "m16n16k16":
+    return ptx_version >= 60
+  assert(False) # Unexpected geometry.
+
+def is_type_supported(ptx_type):
+  if ptx_type in ["s8", "u8", "s32"]:
+    return ptx_version >= 63 and gpu_arch >= 72
+  if ptx_type in ["s4", "u4", "b1"]:
+    return ptx_version >= 63 and gpu_arch >= 75
+  return ptx_version >= 60 and gpu_arch >= 70
+
+
+def is_mma_variant_supported(op, layout_a, layout_b, satf):
+  if not (is_type_supported(op.a.mma_type.ptx_type)
+          and is_geom_supported(op.a.geom)):
+    return False
+  # sub-integer ops require a:row, b:col layout; b1 additionally forbids satf.
+  if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
+    if op.a.mma_type.ptx_type == "b1" and satf:
+      return False
+    return layout_a == "row" and layout_b == "col"
+  return True
+
+def is_ldst_variant_supported(frag, layout):
+  if not (is_type_supported(frag.mma_type.ptx_type)
+          and is_geom_supported(frag.geom)):
+    return False
+  if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
+    # sub-integer types (sm_75/ptx63, checked above): a must be row, b col.
+    return ((frag.frag == "a" and layout == "row")
+            or (frag.frag == "b" and layout == "col")
+            or frag.frag in ["c", "d"])
+  return True
+
+def make_wmma_slice_ty(frag):
+  return [frag.mma_type.llvm_type] * frag.nregs
+
+def make_wmma_ld_ret_ty(frag):
+  results = make_wmma_slice_ty(frag)
+  if len(results) == 1:
+    return "%s" % results[0]
+  return "{%s}" % ", ".join(results)
 
 # returns address space
 def get_aspace(space):
@@ -36,10 +229,8 @@ def get_aspace(space):
 def get_pspace(space):
   return "p%di8" % get_aspace(space);
 
-# Convenient test patterns.
-check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
-check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
-check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
+def check_pattern(frag):
+   return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs)
 
 known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]
 
@@ -69,38 +260,35 @@ define ${ret_ty} @test_${function}_o(i8
   intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
   instruction_template = "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
 
-  for geom, abc, layout, space, stride, itype in product(
-      known_geoms,
-      "abc",
+  generated_items = []
+
+  for frag, layout, space, stride in product(
+      get_ldst_ops("load"),
       ["row","col"],
       ["",".shared",".global"],
       ["", ".stride"],
-      ["f16", "f32"]):
+      ):
+    if not is_ldst_variant_supported(frag, layout):
+      continue
 
     params = {
-        "abc" : abc,
+        "abc" : frag.frag,
         "aligned" : ".aligned" if ptx_version >= 63 else "",
         "layout" : layout,
         "space" : space,
         "stride" : stride,
-        "itype" : itype,
+        "itype" : frag.mma_type.ptx_type,
         "pspace" : get_pspace(space),
         "as"     : "addrspace(%d)" % get_aspace(space),
-        "geom"   : geom,
+        "geom"   : frag.geom,
     }
 
-    if itype == "f32" and abc != "c":
-      continue
-
     test_params = params
     test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
     test_params["function"] = test_params["intrinsic"].replace(".","_")
     test_params["instruction"] = Template(instruction_template).substitute(params)
-    test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
-    if abc == "c" :
-      test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8
-    else:
-      test_params["check_result"] = check_f16_8
+    test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
+    test_params["check_result"] = check_pattern(frag)
 
     if stride:
       test_params["extra_args"] = ", i32 %stride";
@@ -111,9 +299,14 @@ define ${ret_ty} @test_${function}_o(i8
 
     print(Template(load_template).substitute(test_params))
 
-def make_wmma_slice_args(itype, abcd, prefix="v"):
-  return ", ".join(["%s %%%s%d" % (t, prefix, i) for i,t
-                  in enumerate(make_wmma_slice_ty(abcd, itype))])
+    generated_items.append((test_params["intrinsic"],
+                            test_params["instruction"]))
+
+  return generated_items
+
+def make_wmma_slice_args(frag):
+  return ", ".join(["%s %%%s%d" % (t, frag.frag, i) for i,t
+                  in enumerate(make_wmma_slice_ty(frag))])
 
 def gen_wmma_store_tests():
   store_template = """
@@ -141,41 +334,64 @@ define void @test_${function}_o(i8 ${as}
   intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
   instruction_template = "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
 
-  for geom, abc, layout, space, stride, itype in product(
-      known_geoms,
-      "d",
+  generated_items = []
+
+  for frag, layout, space, stride in product(
+      get_ldst_ops("store"),
       ["row","col"],
       ["",".shared",".global"],
-      ["", ".stride"],
-      ["f16", "f32"]):
+      ["", ".stride"]):
+
+    if not is_ldst_variant_supported(frag, layout):
+      continue
 
     params = {
-        "abc" : abc,
+        "abc" : frag.frag,
         "aligned" : ".aligned" if ptx_version >= 63 else "",
         "layout" : layout,
         "space" : space,
         "stride" : stride,
-        "itype" : itype,
+        "itype" : frag.mma_type.ptx_type,
         "pspace" : get_pspace(space),
         "as"     : "addrspace(%d)" % get_aspace(space),
-        "geom"   : geom,
+        "geom"   : frag.geom,
     }
 
     test_params = params
     test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
     test_params["function"] = test_params["intrinsic"].replace(".","_")
     test_params["instruction"] = Template(instruction_template).substitute(params)
-    test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
-    test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
+    test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
+    test_params["check_args"] = check_pattern(frag)
     if stride:
       test_params["extra_args"] = ", i32 %stride";
       test_params["stride_pattern"] = ", %r{{[0-9]+}};"
     else:
       test_params["extra_args"] = ""
       test_params["stride_pattern"] = ";"
-    test_params["args"] = make_wmma_slice_args(itype, "d");
+    test_params["args"] = make_wmma_slice_args(frag);
 
     print(Template(store_template).substitute(test_params))
+    generated_items.append((test_params["intrinsic"],
+                            test_params["instruction"]))
+
+  return generated_items
+
+def mma_signature(op):
+  if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
+    # int and sub-int ops are identified by input type.
+    return op.a.mma_type.ptx_type
+  else:
+    # the rest are FP ops identified by accumulator & result type.
+    return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
+
+def mma_ptx_signature(op):
+  if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
+    # int and sub-int instructions encode all four types as D.A.B.C
+    return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
+  else:
+    # the rest are FP instructions use D.C
+    return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
 
 def gen_wmma_mma_tests():
   mma_template = """
@@ -187,58 +403,129 @@ define ${ret_ty} @test_${function}(
         ${args}) {
 ; CHECK: ${instruction}
 ; CHECK-NEXT: ${check_d}
-; CHECK-NEXT: ${check_ab}
-; CHECK-NEXT: ${check_ab}
+; CHECK-NEXT: ${check_a}
+; CHECK-NEXT: ${check_b}
 ; CHECK-NEXT: ${check_c}
   %r = call ${ret_ty} @${intrinsic}(
         ${args});
   ret ${ret_ty} %r;
 }
 """
-  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
-  instruction_template = "wmma.mma.sync${aligned}.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"
+  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
+  instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}.${ptx_signature}${satf}"
 
-  for geom, alayout, blayout, ctype, dtype, satf in product(
-      known_geoms,
+  generated_items=[]
+
+  for op, alayout, blayout, satf in product(
+      get_mma_ops(),
       ["row","col"],
       ["row","col"],
-      ["f16", "f32"],
-      ["f16", "f32"],
       [".satfinite", ""]):
 
+    if not is_mma_variant_supported(op, alayout, blayout, satf):
+      continue
+
     params = {
         "aligned" : ".aligned" if ptx_version >= 63 else "",
         "alayout" : alayout,
         "blayout" : blayout,
-        "ctype" : ctype,
-        "dtype" : dtype,
+        "intrinsic_signature" : mma_signature(op),
+        "ptx_signature" : mma_ptx_signature(op),
         "satf"  : satf,
-        "geom"  : geom,
+        "geom"  : op.a.geom,
+        "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
     }
 
     test_params = params
     test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
     test_params["function"] = test_params["intrinsic"].replace(".", "_")
     test_params["instruction"] = Template(instruction_template).substitute(params)
-    test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
-    test_params["check_ab"] = check_f16_8
-    test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8
-    test_params["check_d"] = check_f16_4 if dtype == "f16" else check_f32_8
-    args = ",\n        ".join(make_wmma_slice_args(t, abcd, prefix=abcd)
-                              for abcd, t in (("a", "f16"),
-                                              ("b", "f16"),
-                                              ("c", ctype)))
+    test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
+    test_params["check_a"] = check_pattern(op.a)
+    test_params["check_b"] = check_pattern(op.b)
+    test_params["check_c"] = check_pattern(op.c)
+    test_params["check_d"] = check_pattern(op.d)
+    args = ",\n        ".join(make_wmma_slice_args(frag)
+                              for frag in (op.a, op.b, op.c))
     test_params["args"] = args
     print(Template(mma_template).substitute(test_params))
+    generated_items.append((test_params["intrinsic"],
+                            test_params["instruction"]))
+
+  return generated_items
 
-def main():
-  gen_wmma_load_tests()
-  gen_wmma_store_tests()
-  gen_wmma_mma_tests()
+# Append the complete list of intrinsics and instructions we've generated tests
+# for, plus a set of checks verifying that we generated a sensible set of
+# tests for the given combination of PTX and SM variants.
+#
+# PTX<N>: verifies that we did generate tests for correct classes of intrinsics.
+# PTX<N>U: verifies that we did not generate intrinsics unsupported by
+#          the PTX version.
+# SM<N>: verifies that we did generate correct classes of instructions for the SM.
+# SM<N>U: verifies that we did not generate instructions unsupported by the SM.
+#
+# Note that SM/PTX constraints overlap, but DAG checks do not allow overlapping
+# matches. We implicitly rely on generating multiple variants of most of the
+# instructions, which usually gives enough input data to find more than one
+# match of the same kind, if necessary. When that's not possible (e.g. there's
+# only one m8n8k128.mma.row.col.b1), we may need to match the PTX instruction instead.
+def gen_check_unsupported_ops(items):
+  print("; Complete list of intrinsics supported by PTX%d on sm_%d"
+        % (ptx_version, gpu_arch))
+  print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
+  print("""
+; PTX60-DAG: m16n16k16.load.{{[ab].*}}.f16.p
+; PTX60-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
+; PTX60U-NOT: m32n8k16
+; PTX60U-NOT: m8n32k16
+; PTX60U-NOT: .{{s32|s[48]|u[48]|b1}}
+
+; All features of PTX60, plus m32n8k16/m8n32k16 geometries.
+; PTX61-DAG: m32n8k16.load.{{[ab].*}}.f16.p
+; PTX61-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
+; PTX61-DAG: m8n32k16.load.{{[ab].*}}.f16.p
+; PTX61-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
+; PTX61U-NOT: .{{s32|s[48]|u[48]|b1}}
+
+; SM70U-NOT: .{{s32|s[48]|u[48]|b1}}
+
+; PTX63 supports all features of PTX60+PTX61, plus support for integers.
+; Alas we can"t just use PTX<N> checks for that as available instructions
+; depend on SM integers need sm72+ and subinteger ops need sm75, so we
+; transition to SM<N> checks
+; SM72-DAG: m16n16k16.load.{{[ab].*}}.s8.p
+; SM72-DAG: m8n32k16.load.{{[ab].*}}.s8.p
+; SM72-DAG: m32n8k16.load.{{[ab].*}}.s8.p
+; SM72-DAG: m16n16k16.load.{{[ab].*}}.u8.p
+; SM72-DAG: m8n32k16.load.{{[ab].*}}.u8.p
+; SM72-DAG: m32n8k16.load.{{[ab].*}}.u8.p
+; SM72-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
+; SM72U-NOT: .{{s4|u4|b1}}
+
+; SM75-DAG: m8n8k128.load.{{[ab].*}}.b1.p
+; SM75-DAG: m8n8k32.load.{{[ab].*}}.s4.p
+; SM75-DAG: m8n8k32.load.{{[ab].*}}.u4.p
+; SM75-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
+; SM75-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
+""")
+
+  print("; INTRINSICS_LIST_BEGIN")
+  for intrinsic, instruction in sorted(items):
+    print("; ", intrinsic, " -> ", instruction,"")
+  print("; INTRINSICS_LIST_END")
+  print("; INTRINSICS: ; INTRINSICS_LIST_END")
+
+def gen_tests():
+  items = gen_wmma_load_tests()
+  items += gen_wmma_store_tests()
+  items += gen_wmma_mma_tests()
+  gen_check_unsupported_ops(items)
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--ptx', type=int, default=60)
+parser.add_argument("--ptx", type=int, default=60)
+parser.add_argument("--gpu-arch", type=int, default=70)
 args = parser.parse_args()
 ptx_version = args.ptx
+gpu_arch = args.gpu_arch
 
-main()
+gen_tests()




More information about the llvm-commits mailing list