[llvm] 1b4c85f - [NVPTX] Add NVPTX intrinsics for CUDA PTX 6.5 ldmatrix instructions
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 6 16:14:17 PDT 2021
Author: Steffen Larsen
Date: 2021-08-06T16:13:35-07:00
New Revision: 1b4c85fc02cc87b4abcd794c98e6ff91a3d3766b
URL: https://github.com/llvm/llvm-project/commit/1b4c85fc02cc87b4abcd794c98e6ff91a3d3766b
DIFF: https://github.com/llvm/llvm-project/commit/1b4c85fc02cc87b4abcd794c98e6ff91a3d3766b.diff
LOG: [NVPTX] Add NVPTX intrinsics for CUDA PTX 6.5 ldmatrix instructions
Adds NVPTX intrinsics for the CUDA PTX `ldmatrix.sync.aligned` instructions introduced in PTX ISA 6.5.
PTX ISA description of `ldmatrix.sync.aligned`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
Authored-by: Steffen Larsen <steffen.larsen at codeplay.com>
Reviewed By: tra
Differential Revision: https://reviews.llvm.org/D107046
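
For context, a minimal IR-level sketch of how the new entry points are used. The `.p3i8` overload suffix and the struct return type follow the test template added to wmma.py below; treat the exact declaration as illustrative rather than normative:

  ; Load four 8x8 b16 tiles from shared memory with a single ldmatrix.
  declare {i32, i32, i32, i32}
      @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3i8(i8 addrspace(3)*)

  define {i32, i32, i32, i32} @load_four_tiles(i8 addrspace(3)* %src) {
    %v = call {i32, i32, i32, i32}
        @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p3i8(i8 addrspace(3)* %src)
    ret {i32, i32, i32, i32} %v
  }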
Added:
Modified:
llvm/include/llvm/IR/IntrinsicsNVVM.td
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
llvm/test/CodeGen/NVPTX/wmma.py
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index cc43d23bec1c..6676303c2fef 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -43,7 +43,7 @@ def llvm_shared_i64ptr_ty : LLVMQualPointerType<llvm_i64_ty, 3>; // (shared)i64*
// Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
// Geom: m<M>n<N>k<K>. E.g. m8n32k16
-// Frag: [abcd]
+// Frag: [a|b|c|d] ([x1|x2|x4] for ldmatrix)
// PtxEltType: PTX type for the element.
class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
string geom = Geom;
@@ -190,6 +190,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(gft,"m16n8k256:b:b1") : !listsplat(llvm_i32_ty, 2),
!eq(gft,"m16n8k256:c:s32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ // ldmatrix b16 -> s32 @ m8n8
+ !eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1),
+ !eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4),
);
}
@@ -256,6 +261,17 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
!subst("llvm.", "int_", llvm));
}
+class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
+ string intr = "llvm.nvvm.ldmatrix.sync.aligned"
+ # "." # Frag.geom
+ # "." # Frag.frag
+ # !if(Trans, ".trans", "")
+ # "." # Frag.ptx_elt_type
+ ;
+ string record = !subst(".", "_",
+ !subst("llvm.", "int_", intr));
+}
+
// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
// Geom: list of supported geometries.
// TypeN: PTX type of the corresponding fragment's element.
@@ -286,6 +302,16 @@ class MMA_LDST_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
list<string> ops = !foreach(x, ret, x.gft);
}
+class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
+ list<WMMA_REGS> ret =
+ !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
+ !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
+ !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
+ [WMMA_REGS<geom, frag, type>]))))));
+ // Debugging aid for readable representation of the list above.
+ list<string> ops = !foreach(x, ret, x.gft);
+}
+
// Creates list of valid combinations of fragments. This is the master list that
// drives generation of corresponding intrinsics and instructions.
class NVVM_MMA_OPS<int _ = 0> {
@@ -370,11 +396,14 @@ class NVVM_MMA_OPS<int _ = 0> {
// Separate A/B/C fragments (loads) from D (stores).
list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d"));
list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d"));
+
+ list<WMMA_REGS> ldmatrix_b16_ops = LDMATRIX_OPS<
+ ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
+ list<WMMA_REGS> all_ldmatrix_ops = ldmatrix_b16_ops;
}
def NVVM_MMA_OPS : NVVM_MMA_OPS;
-
// Returns true if this combination of fragment and layout for WMMA load/store
// ops is supported; false otherwise.
// E.g.
@@ -489,6 +518,23 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
);
}
+// Returns true if the given fragment is valid for ldmatrix ops;
+// false otherwise.
+// E.g.
+// if NVVM_LDMATRIX_SUPPORTED<...>.ret then
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag> {
+ string g = frag.geom;
+ string t = frag.ptx_elt_type;
+
+ bit ret = !cond(
+ // Currently only m8n8 with b16 elements is supported
+ !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+ true: false
+ );
+}
+
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
string Suffix = !if(sync, "sync_", "")
# mode # "_"
@@ -4519,4 +4565,20 @@ foreach layout_a = ["row", "col"] in {
} // layout_b
} // layout_a
+// LDMATRIX
+class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
+ : Intrinsic<Frag.regs, [llvm_anyptr_ty],
+ [IntrReadMem, IntrArgMemOnly, ReadOnly<ArgIndex<0>>,
+ NoCapture<ArgIndex<0>>],
+ LDMATRIX_NAME<Frag, Transposed>.intr>;
+
+foreach transposed = [0, 1] in {
+ foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
+ if NVVM_LDMATRIX_SUPPORTED<frag>.ret then {
+ def LDMATRIX_NAME<frag, transposed>.record
+ : NVVM_LDMATRIX<frag, transposed>;
+ }
+ }
+}
+
} // let TargetPrefix = "nvvm"
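
Taken together, the foreach above instantiates six intrinsic records: the x1, x2 and x4 fragments, each with and without .trans, whose register counts match the !listsplat entries added to WMMA_REGS. Shown as IR declarations for a generic pointer (the .p0i8 overload suffix is an assumption; the actual suffix is derived from the pointer argument's type), the new surface is:

  declare i32                  @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p0i8(i8*)
  declare i32                  @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.trans.b16.p0i8(i8*)
  declare {i32, i32}           @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.b16.p0i8(i8*)
  declare {i32, i32}           @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p0i8(i8*)
  declare {i32, i32, i32, i32} @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.b16.p0i8(i8*)
  declare {i32, i32, i32, i32} @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x4.trans.b16.p0i8(i8*)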
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d4842c953ce7..0922eac58fe9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3547,7 +3547,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
- case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride: {
+ case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3585,7 +3587,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
- case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col: {
+ case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3679,7 +3683,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
- case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride: {
+ case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v2i32;
Info.ptrVal = I.getArgOperand(0);
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index de4bf2ef3055..138f32bd2bd2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7578,6 +7578,7 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
!eq(ptx_elt_type, "bf16") : Int32Regs,
!eq(ptx_elt_type, "tf32") : Int32Regs,
!eq(ptx_elt_type, "s32") : Int32Regs,
+ !eq(ptx_elt_type, "b16") : Int32Regs,
!eq(ptx_elt_type, "s8") : Int32Regs,
!eq(ptx_elt_type, "u8") : Int32Regs,
!eq(ptx_elt_type, "s4") : Int32Regs,
@@ -7661,7 +7662,11 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
!eq(geom, "m16n8k64"),
!eq(geom, "m8n8k128"),
!eq(geom, "m16n8k128"),
- !eq(geom, "m16n8k256"))) : [hasSM80, hasPTX70]);
+ !eq(geom, "m16n8k256"))) : [hasSM80, hasPTX70],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b16"),
+ !eq(geom, "m8n8")) : [hasSM75, hasPTX65]);
// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7910,6 +7915,44 @@ defset list<WMMA_INSTR> MMAs = {
} // layout_a
} // defset
+//
+// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
+//
+class LDMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space,
+ DAGOperand SrcOp>
+ : WMMA_INSTR<LDMATRIX_NAME<Frag, Transposed>.record, [(ins SrcOp:$src)]>,
+ Requires<Frag.Predicates> {
+ // Build PatFrag that only matches particular address space.
+ PatFrag IntrFrag = PatFrag<(ops node:$src), (Intr node:$src),
+ !cond(!eq(Space, ".shared"): AS_match.shared,
+ true: AS_match.generic)>;
+ // Build AS-constrained pattern.
+ let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret;
+
+ let OutOperandList = Frag.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ let AsmString = "ldmatrix.sync.aligned."
+ # Frag.geom
+ # "." # Frag.frag
+ # !if(Transposed, ".trans", "")
+ # Space
+ # "." # Frag.ptx_elt_type
+ # " " # Frag.regstring # ", [$src];";
+}
+
+// Create all ldmatrix variants
+defset list<WMMA_INSTR> LDMATRIXs = {
+ foreach transposed = [false, true] in {
+ foreach space = [".shared", ""] in {
+ foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
+ foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in
+ if NVVM_LDMATRIX_SUPPORTED<frag>.ret then
+ def : LDMATRIX<WMMA_REGINFO<frag, "ldmatrix">, transposed, space,
+ addr>;
+ } // addr
+ } // space
+ } // transposed
+} // defset
// Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
// dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
@@ -7921,5 +7964,5 @@ class MMA_PAT<WMMA_INSTR wi>
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs) in
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in
def : MMA_PAT<mma>;
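
Because of the Space foreach, every intrinsic gets both a generic and a .shared instruction form, chosen at selection time by the address space of the pointer operand via the AS-constrained PatFrag. A minimal sketch of the expected lowering (the PTX in the comments is illustrative; register numbering depends on the surrounding code):

  declare i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3i8(i8 addrspace(3)*)
  declare i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p0i8(i8*)

  define i32 @one_tile_shared(i8 addrspace(3)* %src) {
    ; Expected: ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%r1}, [%rd1];
    %v = call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3i8(i8 addrspace(3)* %src)
    ret i32 %v
  }

  define i32 @one_tile_generic(i8* %src) {
    ; Expected: ldmatrix.sync.aligned.m8n8.x1.b16 {%r1}, [%rd1];
    %v = call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p0i8(i8* %src)
    ret i32 %v
  }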
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 785e48ce75a2..3b3d10947cac 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -6,7 +6,7 @@
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16
# RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
-# RUN: --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
+# RUN: --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
# RUN: | FileCheck %t-ptx60-sm_70.ll
@@ -15,7 +15,7 @@
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM
# RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
-# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
+# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
# RUN: | FileCheck %t-ptx61-sm_70.ll
@@ -24,7 +24,7 @@
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT
# RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
-# RUN: --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
+# RUN: --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
# RUN: | FileCheck %t-ptx63-sm_72.ll
@@ -33,7 +33,7 @@
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT
# RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
-# RUN: --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT
+# RUN: --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
# RUN: | FileCheck %t-ptx63-sm_75.ll
@@ -42,14 +42,14 @@
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA
# RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
-# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT
+# RUN: --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT,NOLDMATRIX
# RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
# RUN: | FileCheck %t-ptx64-sm_70.ll
# Check all variants of instructions supported by PTX65 on SM75+
# RUN: %python %s --ptx=65 --gpu-arch=75 > %t-ptx65-sm_75.ll
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
-# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA
+# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA,PTX65LDMATRIX
# RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
# RUN: --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
@@ -58,7 +58,7 @@
# Check all variants of instructions supported by PTX71 on SM80+
# RUN: %python %s --ptx=71 --gpu-arch=80 > %t-ptx71-sm_80.ll
# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
-# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX71MMA
+# RUN: --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX65LDMATRIX,PTX71MMA
# RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
# RUN: --check-prefixes=INTRINSICS
# RUN: llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
@@ -78,6 +78,7 @@ def __init__(self, ptx_type):
"f32" : "float",
"f64" : "double",
"s32" : "i32",
+ "b16" : "i32",
"s8" : "i32",
"u8" : "i32",
"s4" : "i32",
@@ -232,6 +233,11 @@ def __init__(self, geom, frag, ptx_elt_type):
"m16n8k16:d:f16": 2,
"m16n8k16:c:f32": 4,
"m16n8k16:d:f32": 4,
+
+ # ldmatrix
+ "m8n8:x1:b16": 1,
+ "m8n8:x2:b16": 2,
+ "m8n8:x4:b16": 4,
}.get("%s:%s:%s" % (geom, frag, ptx_elt_type), {
# All other FP shape/fragment/type combinations have the same size
"a:f16" : 8,
@@ -272,6 +278,10 @@ def make_ldst_ops(geoms, frags, types):
return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
in product(geoms, frags, types)]
+def make_ldmatrix_ops(geoms, frags, types):
+ return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
+ in product(geoms, frags, types)]
+
def get_wmma_ops():
return (make_mma_ops(["m16n16k8"],
["tf32"], [], ["f32"], []) +
@@ -317,6 +327,9 @@ def get_ldst_ops(kind):
make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
return [ x for x in ldst_ops if (x.frag == "d") == (kind == "store")]
+def get_ldmatrix_ops():
+ return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+
def is_wmma_geom_supported(geom):
# geometries for FP and ints.
if geom in ["m8n32k16", "m32n8k16"]:
@@ -343,11 +356,18 @@ def is_mma_geom_supported(geom):
return ptx_version >= 70
assert(False) # Unexpected geometry.
+def is_ldmatrix_geom_supported(geom):
+ if geom in ["m8n8"]:
+ return ptx_version >= 65 and gpu_arch >= 75
+ assert(False) # Unexpected geometry.
+
def is_type_supported(ptx_type):
if ptx_type in ["s8", "u8", "s32"]:
return ptx_version >= 63 and gpu_arch >= 72
if ptx_type in ["s4", "u4", "b1"]:
return ptx_version >= 63 and gpu_arch >= 75
+ if ptx_type == "b16":
+ return ptx_version >= 65 and gpu_arch >= 75
if ptx_type in ["bf16", "tf32", "f64"]:
return ptx_version >= 70
return ptx_version >= 60 and gpu_arch >= 70
@@ -413,6 +433,12 @@ def is_ldst_variant_supported(frag, layout):
or frag.frag in ["c", "d"])
return True
+def is_ldmatrix_variant_supported(frag):
+ if not (is_type_supported(frag.mma_type.ptx_type)
+ and is_ldmatrix_geom_supported(frag.geom)):
+ return False
+ return frag.frag in ["x1", "x2", "x4"]
+
def make_wmma_slice_ty(frag):
return [frag.mma_type.llvm_type] * frag.nregs
@@ -584,6 +610,66 @@ def gen_wmma_store_tests():
return generated_items
+def gen_ldmatrix_tests():
+ ldmatrix_template = """
+declare ${ret_ty} @${intrinsic}(i8 ${as}* %src);
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define ${ret_ty} @test_${function}(i8 ${as}* %src) {
+; CHECK: ${instruction}
+; CHECK: {${check_result}}
+; CHECK: [%rd{{[0-9]+}}]
+ %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src);
+ ret ${ret_ty} %v0;
+}
+
+; CHECK-LABEL: .func{{.*}}test_${function}_o(
+define ${ret_ty} @test_${function}_o(i8 ${as}* %src) {
+; CHECK: ${instruction}
+; CHECK: {${check_result}}
+; CHECK: [%rd{{[0-9]+}}+128]
+ %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
+ %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1);
+ ret ${ret_ty} %v0;
+}
+"""
+ intrinsic_template = "llvm.nvvm.ldmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
+ instruction_template = "ldmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
+
+ generated_items = []
+
+ for frag, space, trans in product(
+ get_ldmatrix_ops(),
+ ["",".shared"],
+ ["",".trans"],
+ ):
+ if not is_ldmatrix_variant_supported(frag):
+ continue
+
+ params = {
+ "frag" : frag.frag,
+ "space" : space,
+ "trans" : trans,
+ "itype" : frag.mma_type.ptx_type,
+ "pspace" : get_pspace(space),
+ "as" : "addrspace(%d)" % get_aspace(space),
+ "geom" : frag.geom,
+ }
+
+ test_params = params
+ test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+ test_params["function"] = test_params["intrinsic"].replace(".","_")
+ test_params["instruction"] = Template(instruction_template).substitute(params)
+ test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
+ test_params["check_result"] = check_pattern(frag)
+
+ print(Template(ldmatrix_template).substitute(test_params))
+
+ generated_items.append((test_params["intrinsic"],
+ test_params["instruction"]))
+
+ return generated_items
+
def mma_signature(op):
if op.a.mma_type.ptx_type == "f16":
# FP16 ops identified by accumulator & result type.
@@ -744,6 +830,7 @@ def gen_check_unsupported_ops(items):
; NOMMA-NOT: .m8n8k4.
; NOALTFLOAT-NOT: .{{bf16|tf32}}
; NODOUBLE-NOT: .f64
+; NOLDMATRIX-NOT: ldmatrix.sync.aligned
; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -819,6 +906,19 @@ def gen_check_unsupported_ops(items):
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.shared.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.shared.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+
; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -861,6 +961,7 @@ def gen_check_unsupported_ops(items):
def gen_tests():
items = gen_wmma_load_tests()
items += gen_wmma_store_tests()
+ items += gen_ldmatrix_tests()
items += gen_wmma_mma_tests()
items += gen_mma_tests()
gen_check_unsupported_ops(items)
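
For reference, one instantiation of ldmatrix_template after substitution looks roughly like the following; the function name, ret_ty, and CHECK patterns are derived from the helpers above, so the exact generated text may differ slightly:

  declare {i32, i32} @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3i8(i8 addrspace(3)* %src);

  ; CHECK-LABEL: .func {{.*}}test_llvm_nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16_p3i8(
  define {i32, i32} @test_llvm_nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16_p3i8(i8 addrspace(3)* %src) {
  ; CHECK: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
  ; CHECK: {%r{{[0-9]+}}, %r{{[0-9]+}}}
  ; CHECK: [%rd{{[0-9]+}}]
    %v0 = call {i32, i32} @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x2.trans.b16.p3i8(i8 addrspace(3)* %src);
    ret {i32, i32} %v0;
  }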