[llvm] r359247 - PTX 6.3 extends `wmma` instruction to support s8/u8/s4/u4/b1 -> s32.
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 25 15:27:58 PDT 2019
Author: tra
Date: Thu Apr 25 15:27:57 2019
New Revision: 359247
URL: http://llvm.org/viewvc/llvm-project?rev=359247&view=rev
Log:
PTX 6.3 extends `wmma` instruction to support s8/u8/s4/u4/b1 -> s32.
All of the new instructions are still handled mostly by TableGen. I've slightly
refactored the code to drive intrinsic/instruction generation from a master
list of supported variants, so any irregularities need to be handled in only one place.
The test generation script wmma.py has been refactored in a similar way.
Differential Revision: https://reviews.llvm.org/D60015
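For illustration only (hand-written here, not excerpted from the patch or from the
generated test), the naming scheme below produces intrinsics like the following for
one of the new integer variants. The register counts come from WMMA_REGS: the
m16n16k16 s8 op carries the A and B fragments in 2 x i32 each and the C accumulator
and D result in 8 x i32.

  declare {i32, i32, i32, i32, i32, i32, i32, i32}
      @llvm.nvvm.wmma.m16n16k16.mma.row.col.s8(
          i32, i32,                                ; A fragment (2 regs)
          i32, i32,                                ; B fragment (2 regs)
          i32, i32, i32, i32, i32, i32, i32, i32)  ; C accumulator (8 regs)

With PTX 6.3 this is selected to
wmma.mma.sync.aligned.row.col.m16n16k16.s32.s8.s8.s32.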
Modified:
llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
llvm/trunk/test/CodeGen/NVPTX/wmma.py
Modified: llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td?rev=359247&r1=359246&r2=359247&view=diff
==============================================================================
--- llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td (original)
+++ llvm/trunk/include/llvm/IR/IntrinsicsNVVM.td Thu Apr 25 15:27:57 2019
@@ -50,6 +50,7 @@ class WMMA_REGS<string Geom, string Frag
string geom = Geom;
string frag = Frag;
string ptx_elt_type = PtxEltType;
+ string gft = Geom#":"#Frag#":"#ptx_elt_type;
string ft = frag#":"#ptx_elt_type;
list<LLVMType> regs = !cond(
// fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
@@ -60,7 +61,42 @@ class WMMA_REGS<string Geom, string Frag
!eq(ft,"c:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
!eq(ft,"d:f16") : RepLLVMType<4, llvm_v2f16_ty>.ret,
!eq(ft,"c:f32") : RepLLVMType<8, llvm_float_ty>.ret,
- !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret);
+ !eq(ft,"d:f32") : RepLLVMType<8, llvm_float_ty>.ret,
+
+ // u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
+ !eq(gft,"m16n16k16:a:u8") : RepLLVMType<2, llvm_i32_ty>.ret,
+ !eq(gft,"m16n16k16:a:s8") : RepLLVMType<2, llvm_i32_ty>.ret,
+ !eq(gft,"m16n16k16:b:u8") : RepLLVMType<2, llvm_i32_ty>.ret,
+ !eq(gft,"m16n16k16:b:s8") : RepLLVMType<2, llvm_i32_ty>.ret,
+ !eq(gft,"m16n16k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+ !eq(gft,"m16n16k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+
+ !eq(gft,"m8n32k16:a:u8") : [llvm_i32_ty],
+ !eq(gft,"m8n32k16:a:s8") : [llvm_i32_ty],
+ !eq(gft,"m8n32k16:b:u8") : RepLLVMType<4, llvm_i32_ty>.ret,
+ !eq(gft,"m8n32k16:b:s8") : RepLLVMType<4, llvm_i32_ty>.ret,
+ !eq(gft,"m8n32k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+ !eq(gft,"m8n32k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+
+ !eq(gft,"m32n8k16:a:u8") : RepLLVMType<4, llvm_i32_ty>.ret,
+ !eq(gft,"m32n8k16:a:s8") : RepLLVMType<4, llvm_i32_ty>.ret,
+ !eq(gft,"m32n8k16:b:u8") : [llvm_i32_ty],
+ !eq(gft,"m32n8k16:b:s8") : [llvm_i32_ty],
+ !eq(gft,"m32n8k16:c:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+ !eq(gft,"m32n8k16:d:s32") : RepLLVMType<8, llvm_i32_ty>.ret,
+
+ // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
+ !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
+ !eq(gft,"m8n8k32:a:u4") : [llvm_i32_ty],
+ !eq(gft,"m8n8k32:a:s4") : [llvm_i32_ty],
+ !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
+ !eq(gft,"m8n8k32:b:u4") : [llvm_i32_ty],
+ !eq(gft,"m8n8k32:b:s4") : [llvm_i32_ty],
+ !eq(gft,"m8n8k128:c:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
+ !eq(gft,"m8n8k128:d:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
+ !eq(gft,"m8n8k32:c:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
+ !eq(gft,"m8n8k32:d:s32") : RepLLVMType<2, llvm_i32_ty>.ret,
+ );
}
class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
@@ -84,22 +120,162 @@ class WMMA_NAME_LDST<string Op, WMMA_REG
# !if(WithStride, "_stride", "");
}
-class WMMA_NAME_MMA<string ALayout, string BLayout,
- WMMA_REGS C, WMMA_REGS D,
- int Satfinite> {
+class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ list<WMMA_REGS> id_frags = !cond(
+ // int and sub-int ops are identified by input type.
+ !eq(A.ptx_elt_type, "s8") : [A],
+ !eq(A.ptx_elt_type, "u8") : [A],
+ !eq(A.ptx_elt_type, "s4") : [A],
+ !eq(A.ptx_elt_type, "u4") : [A],
+ !eq(A.ptx_elt_type, "b1") : [A],
+ // the rest are FP ops identified by accumulator & result type.
+ 1: [D, C]
+ );
+ string ret = !foldl("", id_frags, a, b, !strconcat(a, ".", b.ptx_elt_type));
+}
+
+class WMMA_NAME_MMA<string ALayout, string BLayout, int Satfinite,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
string llvm = "llvm.nvvm.wmma."
- # C.geom
+ # A.geom
# ".mma"
# "." # ALayout
# "." # BLayout
- # "." # D.ptx_elt_type // Intrinsic encodes 'd' first.
- # "." # C.ptx_elt_type
+ # signature
# !if(Satfinite, ".satfinite", "");
string record = !subst(".", "_",
!subst("llvm.", "int_", llvm));
}
+// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
+// Geom: list of supported geometries.
+// TypeN: PTX type of the corresponding fragment's element.
+// TypeB and TypeD may be empty; if so, they default to TypeA and TypeC respectively.
+class MMA_OPS<list<string> Geom, list<string> TypeA, list<string> TypeB,
+ list<string> TypeC, list<string> TypeD> {
+ list<list<WMMA_REGS>> ret =
+ !foldl([]<list<WMMA_REGS>>, Geom, t1, geom, !listconcat(t1,
+ !foldl([]<list<WMMA_REGS>>, TypeA, t2, type_a, !listconcat(t2,
+ !foldl([]<list<WMMA_REGS>>, !if(!size(TypeB), TypeB, [type_a]), t3, type_b, !listconcat(t3,
+ !foldl([]<list<WMMA_REGS>>, TypeC, t4, type_c, !listconcat(t4,
+ !foldl([]<list<WMMA_REGS>>, !if(!size(TypeD), TypeD, [type_c]), t5, type_d, !listconcat(t5,
+ [[WMMA_REGS<geom, "a", type_a>,
+ WMMA_REGS<geom, "b", type_b>,
+ WMMA_REGS<geom, "c", type_c>,
+ WMMA_REGS<geom, "d", type_d>]]))))))))));
+ // Debugging aid for readable representation of the list above.
+ list<list<string>> ops = !foreach(x, ret, [x[0].gft, x[1].gft, x[2].gft, x[3].gft]);
+}
+
+class MMA_LDST_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
+ list<WMMA_REGS> ret =
+ !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
+ !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
+ !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
+ [WMMA_REGS<geom, frag, type>]))))));
+ // Debugging aid for readable representation of the list above.
+ list<string> ops = !foreach(x, ret, x.gft);
+}
+
+
+
+// Creates list of valid combinations of fragments. This is the master list that
+// drives generation of corresponding intrinsics and instructions.
+class NVVM_MMA_OPS<int _ = 0> {
+ list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
+ ["m16n16k16", "m32n8k16", "m8n32k16"],
+ ["f16"], [], ["f16", "f32"], ["f16", "f32"]>.ret;
+ list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
+ ["m16n16k16", "m32n8k16", "m8n32k16"],
+ ["s8", "u8"], [], ["s32"], []>.ret;
+ list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
+ ["m8n8k32"],
+ ["s4", "u4"], [], ["s32"], []>.ret;
+ list<list<WMMA_REGS>> bit_mma_ops = MMA_OPS<
+ ["m8n8k128"],
+ ["b1"], [], ["s32"], []>.ret;
+ list<list<WMMA_REGS>> all_mma_ops = !listconcat(fp_mma_ops, int_mma_ops,
+ subint_mma_ops, bit_mma_ops);
+
+ list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
+ ["m16n16k16", "m32n8k16", "m8n32k16"],
+ ["a", "b"], ["f16", "u8", "s8"]>.ret;
+ list<WMMA_REGS> ldst_cd_ops = MMA_LDST_OPS<
+ ["m16n16k16", "m32n8k16", "m8n32k16"],
+ ["c", "d"], ["f16", "f32", "s32"]>.ret;
+ list<WMMA_REGS> ldst_subint_ab_ops = MMA_LDST_OPS<
+ ["m8n8k32"], ["a", "b"], ["s4","u4"]>.ret;
+ list<WMMA_REGS> ldst_bit_ab_ops = MMA_LDST_OPS<
+ ["m8n8k128"], ["a", "b"], ["b1"]>.ret;
+ list<WMMA_REGS> ldst_subint_cd_ops = MMA_LDST_OPS<
+ ["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]>.ret;
+ list<WMMA_REGS> all_ldst_ops = !listconcat(ldst_ab_ops, ldst_cd_ops,
+ ldst_subint_ab_ops,
+ ldst_bit_ab_ops,
+ ldst_subint_cd_ops);
+ // Separate A/B/C fragments (loads) from D (stores).
+ list<WMMA_REGS> all_ld_ops = !foldl([]<WMMA_REGS>, all_ldst_ops, a, b,
+ !listconcat(a, !if(!eq(b.frag,"d"), [],[b])));
+ list<WMMA_REGS> all_st_ops = !foldl([]<WMMA_REGS>, all_ldst_ops, a, b,
+ !listconcat(a, !if(!eq(b.frag,"d"), [b],[])));
+}
+
+def NVVM_MMA_OPS : NVVM_MMA_OPS;
+
+// Returns [1] if this combination of layout/satf is supported, [] otherwise.
+// MMA ops must provide all parameters. Loads and stores -- only frags and layout_a.
+// The class is used to prevent generation of records for the unsupported variants.
+// E.g.
+// foreach _ = NVVM_MMA_SUPPORTED<...>.ret in
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b="-", int satf=-1> {
+ // MMA ops check both layouts.
+ string mma = frags[0].ptx_elt_type
+ # ":" # layout_a
+ # ":" # layout_b;
+ // Load ops only need type/fragment/layout.
+ string ld = frags[0].ptx_elt_type
+ # ":" # frags[0].frag
+ # ":" # layout_a
+ ;
+ string ldf = frags[0].ptx_elt_type
+ # ":" # frags[0].frag
+ ;
+ string t = frags[0].ptx_elt_type;
+ list<int> ret = !cond(
+ // Sub-int MMA only supports fixed A/B layout.
+ // b1 does not support .satf.
+ !eq(mma#":"#satf, "b1:row:col:0") : [1],
+ !eq(mma, "s4:row:col") : [1],
+ !eq(mma, "u4:row:col") : [1],
+ !eq(mma, "s4:row:col") : [1],
+ !eq(mma, "u4:row:col") : [1],
+ // Sub-int load/stores have fixed layout for A and B.
+ !and(!eq(layout_b, "-"), // It's a Load or Store op
+ !or(!eq(ld, "b1:a:row"),
+ !eq(ld, "b1:b:col"),
+ !eq(ldf, "b1:c"),
+ !eq(ldf, "b1:d"),
+ !eq(ld, "s4:a:row"),
+ !eq(ld, "s4:b:col"),
+ !eq(ldf, "s4:c"),
+ !eq(ldf, "s4:d"),
+ !eq(ld, "u4:a:row"),
+ !eq(ld, "u4:b:col"),
+ !eq(ldf, "u4:c"),
+ !eq(ldf, "u4:d"))) : [1],
+ // All other sub-int ops are not supported.
+ !eq(t, "b1") : [],
+ !eq(t, "s4") : [],
+ !eq(t, "u4") : [],
+ // All other (non sub-int) are OK.
+ 1: [1]
+ );
+}
+
let TargetPrefix = "nvvm" in {
def int_nvvm_prmt : GCCBuiltin<"__nvvm_prmt">,
Intrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
@@ -3970,51 +4146,41 @@ class NVVM_WMMA_ST<WMMA_REGS Frag, strin
WMMA_NAME_LDST<"store", Frag, Layout, WithStride>.intr>;
// Create all load/store variants
-foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
- foreach layout = ["row", "col"] in {
- foreach stride = [0, 1] in {
- foreach frag = [WMMA_REGS<geom, "a", "f16">,
- WMMA_REGS<geom, "b", "f16">,
- WMMA_REGS<geom, "c", "f16">,
- WMMA_REGS<geom, "c", "f32">] in {
- def WMMA_NAME_LDST<"load", frag, layout, stride>.record
+foreach layout = ["row", "col"] in {
+ foreach stride = [0, 1] in {
+ foreach frag = NVVM_MMA_OPS.all_ld_ops in
+ foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in
+ def WMMA_NAME_LDST<"load", frag, layout, stride>.record
: NVVM_WMMA_LD<frag, layout, stride>;
- }
- foreach frag = [WMMA_REGS<geom, "d", "f16">,
- WMMA_REGS<geom, "d", "f32">] in {
- def WMMA_NAME_LDST<"store", frag, layout, stride>.record
+ foreach frag = NVVM_MMA_OPS.all_st_ops in
+ foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in
+ def WMMA_NAME_LDST<"store", frag, layout, stride>.record
: NVVM_WMMA_ST<frag, layout, stride>;
- }
- }
}
}
// WMMA.MMA
-class NVVM_WMMA_MMA<string ALayout, string BLayout,
- WMMA_REGS C, WMMA_REGS D, int Satfinite>
+class NVVM_WMMA_MMA<string ALayout, string BLayout, int Satfinite,
+ WMMA_REGS A, WMMA_REGS B,
+ WMMA_REGS C, WMMA_REGS D>
: Intrinsic<D.regs,
- !listconcat(
- WMMA_REGS<C.geom, "a", "f16">.regs,
- WMMA_REGS<C.geom, "b", "f16">.regs,
- C.regs),
+ !listconcat(A.regs, B.regs, C.regs),
[IntrNoMem],
- WMMA_NAME_MMA<ALayout, BLayout, C, D, Satfinite>.llvm>;
+ WMMA_NAME_MMA<ALayout, BLayout, Satfinite, A, B, C, D>.llvm>;
-foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
- foreach layout_a = ["row", "col"] in {
- foreach layout_b = ["row", "col"] in {
- foreach frag_c = [WMMA_REGS<geom, "c", "f16">,
- WMMA_REGS<geom, "c", "f32">] in {
- foreach frag_d = [WMMA_REGS<geom, "d", "f16">,
- WMMA_REGS<geom, "d", "f32">] in {
- foreach satf = [0, 1] in {
- def WMMA_NAME_MMA<layout_a, layout_b, frag_c, frag_d, satf>.record
- : NVVM_WMMA_MMA<layout_a, layout_b, frag_c, frag_d, satf>;
- }
+foreach layout_a = ["row", "col"] in {
+ foreach layout_b = ["row", "col"] in {
+ foreach satf = [0, 1] in {
+ foreach op = NVVM_MMA_OPS.all_mma_ops in {
+ foreach _ = NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret in {
+ def WMMA_NAME_MMA<layout_a, layout_b, satf,
+ op[0], op[1], op[2], op[3]>.record
+ : NVVM_WMMA_MMA<layout_a, layout_b, satf,
+ op[0], op[1], op[2], op[3]>;
}
}
- }
- }
-}
+ } // satf
+ } // layout_b
+} // layout_a
} // let TargetPrefix = "nvvm"
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp?rev=359247&r1=359246&r2=359247&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXISelLowering.cpp Thu Apr 25 15:27:57 2019
@@ -3500,6 +3500,94 @@ bool NVPTXTargetLowering::getTgtMemIntri
Info.align = 16;
return true;
}
+ case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_col_stride:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col_stride:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_col:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_a_s8_row_stride:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row_stride:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_a_u8_row:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_col_stride:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col_stride:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_col:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_s8_row_stride:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row_stride:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_b_u8_row: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v2i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOLoad;
+ Info.align = 8;
+ return true;
+ }
+
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_s8_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_a_u8_row:
+
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_s8_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_b_u8_row: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v4i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOLoad;
+ Info.align = 16;
+ return true;
+ }
+
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_s8_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_b_u8_row:
+
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_s8_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_a_u8_row:
+ case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row:
+ case Intrinsic::nvvm_wmma_m8n8k128_load_a_b1_row_stride:
+ case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col:
+ case Intrinsic::nvvm_wmma_m8n8k128_load_b_b1_col_stride:
+ case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row:
+ case Intrinsic::nvvm_wmma_m8n8k32_load_a_s4_row_stride:
+ case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row_stride:
+ case Intrinsic::nvvm_wmma_m8n8k32_load_a_u4_row:
+ case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
+ case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
+ case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
+ case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOLoad;
+ Info.align = 4;
+ return true;
+ }
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row:
@@ -3543,6 +3631,44 @@ bool NVPTXTargetLowering::getTgtMemIntri
return true;
}
+ case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_col_stride:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row:
+ case Intrinsic::nvvm_wmma_m16n16k16_load_c_s32_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_load_c_s32_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_load_c_s32_row_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v8i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOLoad;
+ Info.align = 16;
+ return true;
+ }
+
+ case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col:
+ case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_col_stride:
+ case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row:
+ case Intrinsic::nvvm_wmma_m8n8k128_load_c_s32_row_stride:
+ case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
+ case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
+ case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
+ case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride: {
+ Info.opc = ISD::INTRINSIC_W_CHAIN;
+ Info.memVT = MVT::v2i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOLoad;
+ Info.align = 8;
+ return true;
+ }
+
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row:
case Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride:
@@ -3585,6 +3711,44 @@ bool NVPTXTargetLowering::getTgtMemIntri
return true;
}
+ case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col:
+ case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_col_stride:
+ case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row:
+ case Intrinsic::nvvm_wmma_m16n16k16_store_d_s32_row_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_col_stride:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row:
+ case Intrinsic::nvvm_wmma_m32n8k16_store_d_s32_row_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_col_stride:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row:
+ case Intrinsic::nvvm_wmma_m8n32k16_store_d_s32_row_stride: {
+ Info.opc = ISD::INTRINSIC_VOID;
+ Info.memVT = MVT::v8i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOStore;
+ Info.align = 16;
+ return true;
+ }
+
+ case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col:
+ case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_col_stride:
+ case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row:
+ case Intrinsic::nvvm_wmma_m8n8k128_store_d_s32_row_stride:
+ case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
+ case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
+ case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
+ case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
+ Info.opc = ISD::INTRINSIC_VOID;
+ Info.memVT = MVT::v2i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOStore;
+ Info.align = 8;
+ return true;
+ }
+
case Intrinsic::nvvm_atomic_load_add_f32:
case Intrinsic::nvvm_atomic_load_add_f64:
case Intrinsic::nvvm_atomic_load_inc_32:
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td?rev=359247&r1=359246&r2=359247&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXInstrInfo.td Thu Apr 25 15:27:57 2019
@@ -142,9 +142,12 @@ def true : Predicate<"true">;
def hasPTX31 : Predicate<"Subtarget->getPTXVersion() >= 31">;
def hasPTX60 : Predicate<"Subtarget->getPTXVersion() >= 60">;
def hasPTX61 : Predicate<"Subtarget->getPTXVersion() >= 61">;
+def hasPTX63 : Predicate<"Subtarget->getPTXVersion() >= 63">;
def hasSM30 : Predicate<"Subtarget->getSmVersion() >= 30">;
def hasSM70 : Predicate<"Subtarget->getSmVersion() >= 70">;
+def hasSM72 : Predicate<"Subtarget->getSmVersion() >= 72">;
+def hasSM75 : Predicate<"Subtarget->getSmVersion() >= 75">;
def useShortPtr : Predicate<"useShortPointers()">;
def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
Modified: llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td?rev=359247&r1=359246&r2=359247&view=diff
==============================================================================
--- llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td (original)
+++ llvm/trunk/lib/Target/NVPTX/NVPTXIntrinsics.td Thu Apr 25 15:27:57 2019
@@ -7406,12 +7406,18 @@ def INT_PTX_SREG_WARPSIZE :
// In addition to target-independent fields provided by WMMA_REGS, it adds
// the fields commonly used to implement specific PTX instruction -- register
// types and names, constraints, parts of assembly, etc.
-class WMMA_REGINFO<string Geom, string Frag, string PtxEltType>
- : WMMA_REGS<Geom, Frag, PtxEltType> {
+class WMMA_REGINFO<WMMA_REGS r>
+ : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type> {
// NVPTX register types used to carry fragment data.
NVPTXRegClass regclass = !cond(
- !eq(PtxEltType, "f16") : Float16x2Regs,
- !eq(PtxEltType, "f32") : Float32Regs);
+ !eq(ptx_elt_type, "f16") : Float16x2Regs,
+ !eq(ptx_elt_type, "f32") : Float32Regs,
+ !eq(ptx_elt_type, "s32") : Int32Regs,
+ !eq(ptx_elt_type, "s8") : Int32Regs,
+ !eq(ptx_elt_type, "u8") : Int32Regs,
+ !eq(ptx_elt_type, "s4") : Int32Regs,
+ !eq(ptx_elt_type, "u4") : Int32Regs,
+ !eq(ptx_elt_type, "b1") : Int32Regs);
// Instruction input/output arguments for the fragment.
list<NVPTXRegClass> ptx_regs = !foreach(tmp, regs, regclass);
@@ -7433,15 +7439,27 @@ class WMMA_REGINFO<string Geom, string F
// all fragments of the instruction are viable.
list<Predicate> Predicates = !cond(
// fp16 -> fp16/fp32 @ m16n16k16
- !and(!eq(Geom, "m16n16k16"),
- !or(!eq(PtxEltType, "f16"),
- !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX60],
+ !and(!eq(geom, "m16n16k16"),
+ !or(!eq(ptx_elt_type, "f16"),
+ !eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX60],
// fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
- !and(!or(!eq(Geom, "m8n32k16"),
- !eq(Geom, "m32n8k16")),
- !or(!eq(PtxEltType, "f16"),
- !eq(PtxEltType, "f32"))) : [hasSM70, hasPTX61]);
+ !and(!or(!eq(geom, "m8n32k16"),
+ !eq(geom, "m32n8k16")),
+ !or(!eq(ptx_elt_type, "f16"),
+ !eq(ptx_elt_type, "f32"))) : [hasSM70, hasPTX61],
+
+ // u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
+ !and(!or(!eq(geom,"m16n16k16"),
+ !eq(geom,"m8n32k16"),
+ !eq(geom,"m32n8k16")),
+ !or(!eq(ptx_elt_type, "u8"),
+ !eq(ptx_elt_type, "s8"),
+ !eq(ptx_elt_type, "s32"))) : [hasSM72, hasPTX63],
+
+ // u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
+ !or(!eq(geom,"m8n8k128"),
+ !eq(geom,"m8n8k32")) : [hasSM75, hasPTX63]);
// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7559,44 +7577,48 @@ class WMMA_STORE_D<WMMA_REGINFO Frag, st
// Create all load/store variants
defset list<WMMA_INSTR> MMA_LDSTs = {
- foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
- foreach layout = ["row", "col"] in {
- foreach stride = [0, 1] in {
- foreach space = [".global", ".shared", ""] in {
- foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
- foreach frag = [WMMA_REGINFO<geom, "a", "f16">,
- WMMA_REGINFO<geom, "b", "f16">,
- WMMA_REGINFO<geom, "c", "f16">,
- WMMA_REGINFO<geom, "c", "f32">] in {
- def : WMMA_LOAD<frag, layout, space, stride, addr>;
- }
- foreach frag = [WMMA_REGINFO<geom, "d", "f16">,
- WMMA_REGINFO<geom, "d", "f32">] in {
- def : WMMA_STORE_D<frag, layout, space, stride, addr>;
- }
- } // addr
- } // space
- } // stride
- } // layout
- } // geom
+ foreach layout = ["row", "col"] in {
+ foreach stride = [0, 1] in {
+ foreach space = [".global", ".shared", ""] in {
+ foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
+ foreach frag = NVVM_MMA_OPS.all_ld_ops in
+ foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in
+ def : WMMA_LOAD<WMMA_REGINFO<frag>, layout, space, stride, addr>;
+ foreach frag = NVVM_MMA_OPS.all_st_ops in
+ foreach _ = NVVM_MMA_SUPPORTED<[frag], layout>.ret in
+ def : WMMA_STORE_D<WMMA_REGINFO<frag>, layout, space, stride, addr>;
+ } // addr
+ } // space
+ } // stride
+ } // layout
} // defset
// WMMA.MMA
class WMMA_MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
string ALayout, string BLayout, int Satfinite>
- : WMMA_INSTR<WMMA_NAME_MMA<ALayout, BLayout, FragC, FragD, Satfinite>.record,
+ : WMMA_INSTR<WMMA_NAME_MMA<ALayout, BLayout, Satfinite, FragA, FragB, FragC, FragD>.record,
[FragA.Ins, FragB.Ins, FragC.Ins]>,
- Requires<FragC.Predicates> {
+ // Requires does not seem to have an effect on an Instruction w/o Patterns.
+ // We set it here anyway and propagate it to the Pat<> we construct below.
+ Requires<FragA.Predicates> {
let OutOperandList = FragD.Outs;
let InOperandList = !con(Args, (ins MmaCode:$ptx));
- let AsmString = "wmma.mma.sync"
+ string TypeList = !cond(
+ !eq(FragD.ptx_elt_type, "s32") : ".s32"
+ # "." # FragA.ptx_elt_type
+ # "." # FragB.ptx_elt_type
+ # ".s32",
+ 1: "." # FragD.ptx_elt_type # "." # FragC.ptx_elt_type,
+ );
+ let AsmString = "wmma.mma"
+ # !if(!eq(FragA.ptx_elt_type, "b1"), ".xor.popc", "")
+ # ".sync"
# "${ptx:aligned}"
# "." # ALayout
# "." # BLayout
# "." # FragA.geom
- # "." # FragD.ptx_elt_type
- # "." # FragC.ptx_elt_type
+ # TypeList
# !if(Satfinite, ".satfinite", "") # "\n\t\t"
# FragD.regstring # ",\n\t\t"
# FragA.regstring # ",\n\t\t"
@@ -7605,32 +7627,32 @@ class WMMA_MMA<WMMA_REGINFO FragA, WMMA_
}
defset list<WMMA_INSTR> MMAs = {
- foreach geom = ["m16n16k16", "m32n8k16", "m8n32k16" ] in {
- foreach layout_a = ["row", "col"] in {
- foreach layout_b = ["row", "col"] in {
- foreach frag_c = [WMMA_REGINFO<geom, "c", "f16">,
- WMMA_REGINFO<geom, "c", "f32">] in {
- foreach frag_d = [WMMA_REGINFO<geom, "d", "f16">,
- WMMA_REGINFO<geom, "d", "f32">] in {
- foreach satf = [0, 1] in {
- def : WMMA_MMA<WMMA_REGINFO<geom, "a", "f16">,
- WMMA_REGINFO<geom, "b", "f16">,
- frag_c, frag_d, layout_a, layout_b, satf>;
- } // satf
- } // frag_d
- } // frag_c
- } // layout_b
- } // layout_a
- } // geom
+ foreach layout_a = ["row", "col"] in {
+ foreach layout_b = ["row", "col"] in {
+ foreach satf = [0, 1] in {
+ foreach op = NVVM_MMA_OPS.all_mma_ops in {
+ foreach _ = NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret in {
+ def : WMMA_MMA<WMMA_REGINFO<op[0]>,
+ WMMA_REGINFO<op[1]>,
+ WMMA_REGINFO<op[2]>,
+ WMMA_REGINFO<op[3]>,
+ layout_a, layout_b, satf>;
+ }
+ } // op
+ } // satf
+ } // layout_b
+ } // layout_a
} // defset
+
// Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
// dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
// the instruction record.
class WMMA_PAT<WMMA_INSTR wi>
: Pat<wi.IntrinsicPattern,
!con(!foreach(tmp, wi.Args, !subst(ins, wi, tmp)),
- (wi ptx.version))>;
+ (wi ptx.version))>,
+ Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
foreach mma = !listconcat(MMAs, MMA_LDSTs) in
Modified: llvm/trunk/test/CodeGen/NVPTX/wmma.py
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/CodeGen/NVPTX/wmma.py?rev=359247&r1=359246&r2=359247&view=diff
==============================================================================
--- llvm/trunk/test/CodeGen/NVPTX/wmma.py (original)
+++ llvm/trunk/test/CodeGen/NVPTX/wmma.py Thu Apr 25 15:27:57 2019
@@ -1,10 +1,42 @@
# This test generates all variants of wmma intrinsics and verifies that LLVM
# generates correct instructions for them.
-# RUN: python %s > %t.ll
-# RUN: llc < %t.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 | FileCheck %t.ll
-# RUN: python %s --ptx=63 > %t-ptx63.ll
-# RUN: llc < %t-ptx63.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx63 | FileCheck %t-ptx63.ll
+# Check all variants of instructions supported by PTX60 on SM70
+# RUN: python %s --ptx=60 --gpu-arch=70 > %t-ptx60-sm_70.ll
+# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
+# RUN: --check-prefixes=INTRINSICS,PTX60,SM70
+# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
+# RUN: --check-prefixes=INTRINSICS,PTX60U,SM70U
+# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
+# RUN: | FileCheck %t-ptx60-sm_70.ll
+
+# Check all variants of instructions supported by PTX61 on SM70
+# RUN: python %s --ptx=61 --gpu-arch=70 > %t-ptx61-sm_70.ll
+# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
+# RUN: --check-prefixes=INTRINSICS,PTX60,PTX61,SM70
+# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
+# RUN: --check-prefixes=INTRINSICS,PTX61U,SM70U
+# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
+# RUN: | FileCheck %t-ptx61-sm_70.ll
+
+# Check all variants of instructions supported by PTX63 on SM72
+# RUN: python %s --ptx=63 --gpu-arch=72 > %t-ptx63-sm_72.ll
+# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
+# RUN: --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72
+# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
+# RUN: --check-prefixes=INTRINSICS,PTX63U,SM72U
+# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
+# RUN: | FileCheck %t-ptx63-sm_72.ll
+
+# Check all variants of instructions supported by PTX63 on SM75
+# RUN: python %s --ptx=63 --gpu-arch=75 > %t-ptx63-sm_75.ll
+# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
+# RUN: --check-prefixes=INTRINSICS,PTX60,PTX61,PTX63,SM70,SM72,SM75
+# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
+# RUN: --check-prefixes=INTRINSICS,PTX63U,SM75U
+# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
+# RUN: | FileCheck %t-ptx63-sm_75.ll
+
from __future__ import print_function
@@ -12,13 +44,174 @@ import argparse
from itertools import product
from string import Template
-def make_wmma_slice_ty(abcd, itype):
- elt_ty = "<2 x half>" if itype == "f16" else "float"
- num_elts = 4 if abcd in "cd" and itype == "f16" else 8;
- return [elt_ty] * num_elts
-
-def make_wmma_ld_ret_ty(abc, itype):
- return "{%s}" % ", ".join(make_wmma_slice_ty(abc, itype))
+class MMAType:
+ def __init__(self, ptx_type):
+ self.ptx_type = ptx_type
+ self.llvm_type = {
+ "f16" : "<2 x half>",
+ "f32" : "float",
+ "s32" : "i32",
+ "s8" : "i32",
+ "u8" : "i32",
+ "s4" : "i32",
+ "u4" : "i32",
+ "b1" : "i32",
+ }[ptx_type];
+
+ self.ptx_reg_pattern = {
+ "f16" : "%hh[0-9]+",
+ "f32" : "%f[0-9]+",
+ }.get(ptx_type, "%r[0-9]+")
+
+ def __repr__(self):
+ return "%s/%s" % (self.ptx_type, self.llvm_type)
+
+class MMAFrag:
+ def __init__(self, geom, frag, ptx_elt_type):
+ self.geom = geom
+ self.frag = frag
+ self.mma_type = MMAType(ptx_elt_type);
+ self.nregs = {
+ "a:f16" : 8,
+ "b:f16" : 8,
+ "c:f16" : 4,
+ "d:f16" : 4,
+ "c:f32" : 8,
+ "d:f32" : 8,
+ }.get("%s:%s" % (frag, ptx_elt_type), {
+ # u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
+ "m16n16k16:a:u8" : 2,
+ "m16n16k16:a:s8" : 2,
+ "m16n16k16:b:u8" : 2,
+ "m16n16k16:b:s8" : 2,
+ "m16n16k16:c:s32" : 8,
+ "m16n16k16:d:s32" : 8,
+
+ "m8n32k16:a:u8" : 1,
+ "m8n32k16:a:s8" : 1,
+ "m8n32k16:b:u8" : 4,
+ "m8n32k16:b:s8" : 4,
+ "m8n32k16:c:s32" : 8,
+ "m8n32k16:d:s32" : 8,
+
+ "m32n8k16:a:u8" : 4,
+ "m32n8k16:a:s8" : 4,
+ "m32n8k16:b:u8" : 1,
+ "m32n8k16:b:s8" : 1,
+ "m32n8k16:c:s32" : 8,
+ "m32n8k16:d:s32" : 8,
+
+ # u4/s4/b1 -> s32 @ m8n8k32 (u4/s4), m8n8k128(b1)
+ "m8n8k128:a:b1" : 1,
+ "m8n8k32:a:u4" : 1,
+ "m8n8k32:a:s4" : 1,
+ "m8n8k128:b:b1" : 1,
+ "m8n8k32:b:u4" : 1,
+ "m8n8k32:b:s4" : 1,
+ "m8n8k128:c:s32" : 2,
+ "m8n8k128:d:s32" : 2,
+ "m8n8k32:c:s32" : 2,
+ "m8n8k32:d:s32" : 2,
+ }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), None));
+ assert(self.nregs);
+
+ def __repr__(self):
+ return "%s:%s:%s%s" % (self.geom, self.frag, self.mma_type,
+ "" if self.nregs == 1 else ("*%d" % self.nregs))
+
+class MMAOp:
+ def __init__(self, a, b, c, d):
+ self.a = a
+ self.b = b
+ self.c = c
+ self.d = d
+
+ def __repr__(self):
+ return ("{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d ))
+
+def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
+ ops = []
+ for geom, type_a, type_c in product( geoms, types_a, types_c):
+ for type_b, type_d in product(types_b if types_b else [type_a],
+ types_d if types_d else [type_c]):
+ ops.append(MMAOp(MMAFrag(geom, "a", type_a),
+ MMAFrag(geom, "b", type_b),
+ MMAFrag(geom, "c", type_c),
+ MMAFrag(geom, "d", type_d)))
+ return ops
+
+def make_ldst_ops(geoms, frags, types):
+ return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
+ in product(geoms, frags, types)]
+
+def get_mma_ops():
+ return (make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+ ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
+ make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+ ["s8", "u8"], [], ["s32"], []) +
+ make_mma_ops(["m8n8k32"],
+ ["s4", "u4"], [], ["s32"], []) +
+ make_mma_ops(["m8n8k128"],
+ ["b1"], [], ["s32"], []))
+def get_ldst_ops(kind):
+ ldst_ops = (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+ ["a", "b"], ["f16", "u8", "s8"]) +
+ make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
+ ["c", "d"], ["f16", "f32", "s32"]) +
+ make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4","u4"]) +
+ make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
+ make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]))
+ return [ x for x in ldst_ops if (x.frag == "d") == (kind == "store")]
+
+def is_geom_supported(geom):
+ # geometries for FP and ints.
+ if geom in ["m8n32k16", "m32n8k16"]:
+ return ptx_version >= 61
+ # geometries for sub-ints.
+ if geom in ["m8n8k32", "m8n8k128"]:
+ return ptx_version >= 63 and gpu_arch >= 75
+ if geom == "m16n16k16":
+ return ptx_version >= 60
+ assert(False) # Unexpected geometry.
+
+def is_type_supported(ptx_type):
+ if ptx_type in ["s8", "u8", "s32"]:
+ return ptx_version >= 63 and gpu_arch >= 72
+ if ptx_type in ["s4", "u4", "b1"]:
+ return ptx_version >= 63 and gpu_arch >= 75
+ return ptx_version >= 60 and gpu_arch >= 70
+
+
+def is_mma_variant_supported(op, layout_a, layout_b, satf):
+ if not (is_type_supported(op.a.mma_type.ptx_type)
+ and is_geom_supported(op.a.geom)):
+ return False
+ # sub-integer ops require row/col layout and do not support satf.
+ if op.a.mma_type.ptx_type in ["s4", "u4", "b1"]:
+ if op.a.mma_type.ptx_type == "b1" and satf:
+ return False
+ return layout_a == "row" and layout_b == "col"
+ return True
+
+def is_ldst_variant_supported(frag, layout):
+ if not (is_type_supported(frag.mma_type.ptx_type)
+ and is_geom_supported(frag.geom)):
+ return False
+ if frag.mma_type.ptx_type in ["s4", "u4", "b1"]:
+ # sub-integer ops require sm_75 and ptx63, and row/col layout for a/b.
+ return ((frag.frag == "a" and layout == "row")
+ or (frag.frag == "b" and layout == "col")
+ or frag.frag in ["c", "d"])
+ return True
+
+def make_wmma_slice_ty(frag):
+ return [frag.mma_type.llvm_type] * frag.nregs
+
+def make_wmma_ld_ret_ty(frag):
+ results = make_wmma_slice_ty(frag)
+ if len(results) == 1:
+ return "%s" % results[0]
+ return "{%s}" % ", ".join(results)
# returns address space
def get_aspace(space):
@@ -36,10 +229,8 @@ def get_aspace(space):
def get_pspace(space):
return "p%di8" % get_aspace(space);
-# Convenient test patterns.
-check_f16_8 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 8)
-check_f16_4 = "{{%s}}" % ", *".join(["%hh[0-9]+"] * 4)
-check_f32_8 = "{{%s}}" % ", *".join(["%f[0-9]+"] * 8)
+def check_pattern(frag):
+ return "{{%s}}" % ", *".join([frag.mma_type.ptx_reg_pattern] * frag.nregs)
known_geoms = ["m16n16k16", "m8n32k16", "m32n8k16"]
@@ -69,38 +260,35 @@ define ${ret_ty} @test_${function}_o(i8
intrinsic_template = "llvm.nvvm.wmma.${geom}.load.${abc}.${layout}${stride}.${itype}.${pspace}"
instruction_template = "wmma.load.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
- for geom, abc, layout, space, stride, itype in product(
- known_geoms,
- "abc",
+ generated_items = []
+
+ for frag, layout, space, stride in product(
+ get_ldst_ops("load"),
["row","col"],
["",".shared",".global"],
["", ".stride"],
- ["f16", "f32"]):
+ ):
+ if not is_ldst_variant_supported(frag, layout):
+ continue
params = {
- "abc" : abc,
+ "abc" : frag.frag,
"aligned" : ".aligned" if ptx_version >= 63 else "",
"layout" : layout,
"space" : space,
"stride" : stride,
- "itype" : itype,
+ "itype" : frag.mma_type.ptx_type,
"pspace" : get_pspace(space),
"as" : "addrspace(%d)" % get_aspace(space),
- "geom" : geom,
+ "geom" : frag.geom,
}
- if itype == "f32" and abc != "c":
- continue
-
test_params = params
test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
test_params["function"] = test_params["intrinsic"].replace(".","_")
test_params["instruction"] = Template(instruction_template).substitute(params)
- test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
- if abc == "c" :
- test_params["check_result"] = check_f16_4 if itype == "f16" else check_f32_8
- else:
- test_params["check_result"] = check_f16_8
+ test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
+ test_params["check_result"] = check_pattern(frag)
if stride:
test_params["extra_args"] = ", i32 %stride";
@@ -111,9 +299,14 @@ define ${ret_ty} @test_${function}_o(i8
print(Template(load_template).substitute(test_params))
-def make_wmma_slice_args(itype, abcd, prefix="v"):
- return ", ".join(["%s %%%s%d" % (t, prefix, i) for i,t
- in enumerate(make_wmma_slice_ty(abcd, itype))])
+ generated_items.append((test_params["intrinsic"],
+ test_params["instruction"]))
+
+ return generated_items
+
+def make_wmma_slice_args(frag):
+ return ", ".join(["%s %%%s%d" % (t, frag.frag, i) for i,t
+ in enumerate(make_wmma_slice_ty(frag))])
def gen_wmma_store_tests():
store_template = """
@@ -141,41 +334,64 @@ define void @test_${function}_o(i8 ${as}
intrinsic_template = "llvm.nvvm.wmma.${geom}.store.${abc}.${layout}${stride}.${itype}.${pspace}"
instruction_template = "wmma.store.${abc}.sync${aligned}.${layout}.${geom}${space}.${itype}"
- for geom, abc, layout, space, stride, itype in product(
- known_geoms,
- "d",
+ generated_items = []
+
+ for frag, layout, space, stride in product(
+ get_ldst_ops("store"),
["row","col"],
["",".shared",".global"],
- ["", ".stride"],
- ["f16", "f32"]):
+ ["", ".stride"]):
+
+ if not is_ldst_variant_supported(frag, layout):
+ continue
params = {
- "abc" : abc,
+ "abc" : frag.frag,
"aligned" : ".aligned" if ptx_version >= 63 else "",
"layout" : layout,
"space" : space,
"stride" : stride,
- "itype" : itype,
+ "itype" : frag.mma_type.ptx_type,
"pspace" : get_pspace(space),
"as" : "addrspace(%d)" % get_aspace(space),
- "geom" : geom,
+ "geom" : frag.geom,
}
test_params = params
test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
test_params["function"] = test_params["intrinsic"].replace(".","_")
test_params["instruction"] = Template(instruction_template).substitute(params)
- test_params["ret_ty"] = make_wmma_ld_ret_ty(abc, itype)
- test_params["check_args"] = check_f16_4 if itype == "f16" else check_f32_8
+ test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
+ test_params["check_args"] = check_pattern(frag)
if stride:
test_params["extra_args"] = ", i32 %stride";
test_params["stride_pattern"] = ", %r{{[0-9]+}};"
else:
test_params["extra_args"] = ""
test_params["stride_pattern"] = ";"
- test_params["args"] = make_wmma_slice_args(itype, "d");
+ test_params["args"] = make_wmma_slice_args(frag);
print(Template(store_template).substitute(test_params))
+ generated_items.append((test_params["intrinsic"],
+ test_params["instruction"]))
+
+ return generated_items
+
+def mma_signature(op):
+ if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
+ # int and sub-int ops are identified by input type.
+ return op.a.mma_type.ptx_type
+ else:
+ # the rest are FP ops identified by accumulator & result type.
+ return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
+
+def mma_ptx_signature(op):
+ if op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
+ # int and sub-int instructions encode all four types as D.A.B.C
+ return ".".join(x.mma_type.ptx_type for x in (op.d, op.a, op.b, op.c))
+ else:
+ # the rest are FP instructions, which use D.C
+ return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
def gen_wmma_mma_tests():
mma_template = """
@@ -187,58 +403,129 @@ define ${ret_ty} @test_${function}(
${args}) {
; CHECK: ${instruction}
; CHECK-NEXT: ${check_d}
-; CHECK-NEXT: ${check_ab}
-; CHECK-NEXT: ${check_ab}
+; CHECK-NEXT: ${check_a}
+; CHECK-NEXT: ${check_b}
; CHECK-NEXT: ${check_c}
%r = call ${ret_ty} @${intrinsic}(
${args});
ret ${ret_ty} %r;
}
"""
- intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${dtype}.${ctype}${satf}"
- instruction_template = "wmma.mma.sync${aligned}.${alayout}.${blayout}.${geom}.${dtype}.${ctype}${satf}"
+ intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"
+ instruction_template = "wmma.mma${mma_variant}.sync${aligned}.${alayout}.${blayout}.${geom}.${ptx_signature}${satf}"
- for geom, alayout, blayout, ctype, dtype, satf in product(
- known_geoms,
+ generated_items=[]
+
+ for op, alayout, blayout, satf in product(
+ get_mma_ops(),
["row","col"],
["row","col"],
- ["f16", "f32"],
- ["f16", "f32"],
[".satfinite", ""]):
+ if not is_mma_variant_supported(op, alayout, blayout, satf):
+ continue
+
params = {
"aligned" : ".aligned" if ptx_version >= 63 else "",
"alayout" : alayout,
"blayout" : blayout,
- "ctype" : ctype,
- "dtype" : dtype,
+ "intrinsic_signature" : mma_signature(op),
+ "ptx_signature" : mma_ptx_signature(op),
"satf" : satf,
- "geom" : geom,
+ "geom" : op.a.geom,
+ "mma_variant" : ".xor.popc" if op.a.mma_type.ptx_type == "b1" else "",
}
test_params = params
test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
test_params["function"] = test_params["intrinsic"].replace(".", "_")
test_params["instruction"] = Template(instruction_template).substitute(params)
- test_params["ret_ty"] = make_wmma_ld_ret_ty("d", dtype)
- test_params["check_ab"] = check_f16_8
- test_params["check_c"] = check_f16_4 if ctype == "f16" else check_f32_8
- test_params["check_d"] = check_f16_4 if dtype == "f16" else check_f32_8
- args = ",\n ".join(make_wmma_slice_args(t, abcd, prefix=abcd)
- for abcd, t in (("a", "f16"),
- ("b", "f16"),
- ("c", ctype)))
+ test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
+ test_params["check_a"] = check_pattern(op.a)
+ test_params["check_b"] = check_pattern(op.b)
+ test_params["check_c"] = check_pattern(op.c)
+ test_params["check_d"] = check_pattern(op.d)
+ args = ",\n ".join(make_wmma_slice_args(frag)
+ for frag in (op.a, op.b, op.c))
test_params["args"] = args
print(Template(mma_template).substitute(test_params))
+ generated_items.append((test_params["intrinsic"],
+ test_params["instruction"]))
+
+ return generated_items
-def main():
- gen_wmma_load_tests()
- gen_wmma_store_tests()
- gen_wmma_mma_tests()
+# Append complete list of intrinsics and instructions we've generated tests for.
+# Generate a set of checks to verify that we did generate a sensible set of
+# tests for the given combination of PTX and SM variants.
+#
+# PTX<N>: verifies that we did generate tests for correct classes of intrinsics.
+# PTX<N>U: verifies that we did not generate intrinsics unsupported by
+# the PTX version.
+# SM<N>: verifies that we did generate correct classes of instructions for the SM.
+# SM<N>U: verifies that we did not generate instructions unsupported by the SM
+#
+# Note that SM/PTX constraints overlap, but DAG checks do not allow overlapping
+# matches. We implicitly rely on the fact that we generate multiple variants of
+# most instructions and usually have enough input data to find more than one
+# match of the same kind, if necessary. When that is not possible (e.g. there is
+# only one m8n8k128.mma.row.col.b1), we may need to match the PTX instruction
+# instead.
+def gen_check_unsupported_ops(items):
+ print("; Complete list of intrinsics supported by PTX%d on sm_%d"
+ % (ptx_version, gpu_arch))
+ print("; INTRINSICS: {{^; INTRINSICS_LIST_BEGIN}}")
+ print("""
+; PTX60-DAG: m16n16k16.load.{{[ab].*}}.f16.p
+; PTX60-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
+; PTX60U-NOT: m32n8k16
+; PTX60U-NOT: m8n32k16
+; PTX60U-NOT: .{{s32|s[48]|u[48]|b1}}
+
+; All features of PTX60, plus m32n8k16/m8n32k16 geometries.
+; PTX61-DAG: m32n8k16.load.{{[ab].*}}.f16.p
+; PTX61-DAG: m32n8k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
+; PTX61-DAG: m8n32k16.load.{{[ab].*}}.f16.p
+; PTX61-DAG: m8n32k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
+; PTX61U-NOT: .{{s32|s[48]|u[48]|b1}}
+
+; SM70U-NOT: .{{s32|s[48]|u[48]|b1}}
+
+; PTX63 supports all features of PTX60+PTX61, plus support for integers.
+; Alas we can"t just use PTX<N> checks for that as available instructions
+; depend on SM integers need sm72+ and subinteger ops need sm75, so we
+; transition to SM<N> checks
+; SM72-DAG: m16n16k16.load.{{[ab].*}}.s8.p
+; SM72-DAG: m8n32k16.load.{{[ab].*}}.s8.p
+; SM72-DAG: m32n8k16.load.{{[ab].*}}.s8.p
+; SM72-DAG: m16n16k16.load.{{[ab].*}}.u8.p
+; SM72-DAG: m8n32k16.load.{{[ab].*}}.u8.p
+; SM72-DAG: m32n8k16.load.{{[ab].*}}.u8.p
+; SM72-DAG: m32n8k16.{{load|store}}.{{[cd].*\.s32}}.p
+; SM72U-NOT: .{{s4|u4|b1}}
+
+; SM75-DAG: m8n8k128.load.{{[ab].*}}.b1.p
+; SM75-DAG: m8n8k32.load.{{[ab].*}}.s4.p
+; SM75-DAG: m8n8k32.load.{{[ab].*}}.u4.p
+; SM75-DAG: m8n8k128.{{load|store}}.{{[cd].*\.s32}}.p
+; SM75-DAG: m8n8k32.{{load|store}}.{{[cd].*\.s32}}.p
+""")
+
+ print("; INTRINSICS_LIST_BEGIN")
+ for intrinsic, instruction in sorted(items):
+ print("; ", intrinsic, " -> ", instruction,"")
+ print("; INTRINSICS_LIST_END")
+ print("; INTRINSICS: ; INTRINSICS_LIST_END")
+
+def gen_tests():
+ items = gen_wmma_load_tests()
+ items += gen_wmma_store_tests()
+ items += gen_wmma_mma_tests()
+ gen_check_unsupported_ops(items)
parser = argparse.ArgumentParser()
-parser.add_argument('--ptx', type=int, default=60)
+parser.add_argument("--ptx", type=int, default=60)
+parser.add_argument("--gpu-arch", type=int, default=70)
args = parser.parse_args()
ptx_version = args.ptx
+gpu_arch = args.gpu_arch
-main()
+gen_tests()