[llvm] [LLVM][NVPTX] Add support for ldmatrix extensions introduced in PTX 8.6 (PR #124899)
Pradeep Kumar via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 29 06:53:31 PST 2025
https://github.com/schwarzschild-radius updated https://github.com/llvm/llvm-project/pull/124899
>From 52f361974bd4bc8c437a681962ca46350a039a8a Mon Sep 17 00:00:00 2001
From: pradeepku <pradeepku at nvidia.com>
Date: Tue, 28 Jan 2025 14:55:39 +0530
Subject: [PATCH] [LLVM][NVPTX] Add support for ldmatrix extensions introduced
in PTX 8.6
This commit adds support for the following ldmatrix extensions introduced in PTX 8.6
- Support for m16n16 with b8 type with mandatory transpose
- Support for m8n16 with source and destination formats
The above extensions are only supported on sm_100a, sm_101a, sm_120a
Please refer to the PTX ISA for more information:
https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 39 ++++++++++---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 18 +++++-
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 1 +
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 27 ++++++++-
llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py | 16 ++++++
llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py | 16 ++++++
llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py | 16 ++++++
llvm/test/CodeGen/NVPTX/wmma.py | 59 +++++++++++++++++++-
8 files changed, 177 insertions(+), 15 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 9a2f38d760e659..f3aac47e4c4033 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -62,6 +62,7 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
string frag = Frag;
string ptx_elt_type = PtxEltType;
string gft = Geom#":"#Frag#":"#ptx_elt_type;
+ string gf = Geom#":"#Frag;
string ft = frag#":"#ptx_elt_type;
list<LLVMType> regs = !cond(
// mma fp ops use smaller fragments than wmma fp ops
@@ -204,9 +205,19 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
// ldmatrix b16 -> s32 @ m8n8
- !eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1),
- !eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4),
+ !eq(gf,"m8n8:x1") : !listsplat(llvm_i32_ty, 1),
+ !eq(gf,"m8n8:x2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m8n8:x4") : !listsplat(llvm_i32_ty, 4),
+
+ // ldmatrix b8, b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m16n16
+ !eq(gf,"m16n16:x1") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m16n16:x2") : !listsplat(llvm_i32_ty, 4),
+
+ // ldmatrix b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m8n16
+ !eq(gf,"m8n16:x1") : !listsplat(llvm_i32_ty, 1),
+ !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
+
);
}
@@ -411,7 +422,16 @@ class NVVM_MMA_OPS {
list<WMMA_REGS> ldmatrix_b16_ops = LDMATRIX_OPS<
["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
- list<WMMA_REGS> all_ldmatrix_ops = ldmatrix_b16_ops;
+
+ list<WMMA_REGS> ldmatrix_geom_m16n16_ops = LDMATRIX_OPS<
+ ["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+ list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
+ ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+ list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
+ ldmatrix_geom_m16n16_ops,
+ ldmatrix_geom_m8n16_ops);
}
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -536,13 +556,18 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
// if NVVM_LDMATRIX_SUPPORTED<...>.ret then
// def : FOO<>; // The record will only be defined for supported ops.
//
-class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag> {
+class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
string g = frag.geom;
string t = frag.ptx_elt_type;
bit ret = !cond(
- // Only currently support m8n8 and b16
!and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+ !and(!eq(g, "m16n16"), !eq(t, "b8"), !eq(trans, 1)): true,
+ !and(!eq(g, "m16n16"), !eq(t, "b8x16.b6x16_p32")): true,
+ !and(!eq(g, "m16n16"), !eq(t, "b8x16.b4x16_p64")): true,
+ !and(!eq(g, "m8n16"), !eq(t, "b8"), !eq(trans, 0)): true,
+ !and(!eq(g, "m8n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 0)): true,
+ !and(!eq(g, "m8n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 0)): true,
true: false
);
}
@@ -4932,7 +4957,7 @@ class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
foreach transposed = [0, 1] in {
foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
- if NVVM_LDMATRIX_SUPPORTED<frag>.ret then {
+ if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then {
def LDMATRIX_NAME<frag, transposed>.record
: NVVM_LDMATRIX<frag, transposed>;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 773c97f7b4dc0f..4c1c5c10bfcc8b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3552,7 +3552,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
- case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3592,7 +3597,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
- case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3688,7 +3695,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
- case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v2i32;
Info.ptrVal = I.getArgOperand(0);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 633a99d0fc1be3..d0a625643e2129 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -170,6 +170,7 @@ def False : Predicate<"false">;
class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
+def hasAAFeatures : Predicate<"Subtarget->hasAAFeatures()">;
// Explicit records for arch-accelerated SM versions
def hasSM90a : Predicate<"Subtarget->getFullSmVersion() == 901">;
def hasSM100a : Predicate<"Subtarget->getFullSmVersion() == 1001">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 56d8b734bf01df..b2cf22b255f1d0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7107,6 +7107,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
!eq(ptx_elt_type, "tf32") : Int32Regs,
!eq(ptx_elt_type, "s32") : Int32Regs,
!eq(ptx_elt_type, "b16") : Int32Regs,
+ !eq(ptx_elt_type, "b8") : Int32Regs,
+ !eq(ptx_elt_type, "b8x16.b6x16_p32") : Int32Regs,
+ !eq(ptx_elt_type, "b8x16.b4x16_p64") : Int32Regs,
!eq(ptx_elt_type, "s8") : Int32Regs,
!eq(ptx_elt_type, "u8") : Int32Regs,
!eq(ptx_elt_type, "s4") : Int32Regs,
@@ -7194,7 +7197,27 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
!and(!eq(op,"ldmatrix"),
!eq(ptx_elt_type,"b16"),
- !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>]);
+ !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8"),
+ !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+ !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+ !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+ !eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+ !eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>]);
// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7478,7 +7501,7 @@ defset list<WMMA_INSTR> LDMATRIXs = {
foreach space = [".shared", ""] in {
foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in
- if NVVM_LDMATRIX_SUPPORTED<frag>.ret then
+ if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then
def : LDMATRIX<WMMA_REGINFO<frag, "ldmatrix">, transposed, space,
addr>;
} // addr
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
new file mode 100644
index 00000000000000..6ad0a2a5865c41
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM100a
+# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN: | FileCheck %t-ptx86-sm_100a.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN: | %ptxas-verify -arch=sm_100a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
new file mode 100644
index 00000000000000..7d9953484da7d0
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM101a
+# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN: | FileCheck %t-ptx86-sm_101a.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN: | %ptxas-verify -arch=sm_101a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
new file mode 100644
index 00000000000000..7bddf0b6fbb785
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM120a
+# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN: | FileCheck %t-ptx86-sm_120a.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN: | %ptxas-verify -arch=sm_120a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index e1e46f0b8cab34..ce275c9b712825 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -19,6 +19,9 @@ def __init__(self, ptx_type):
"f64": "double",
"s32": "i32",
"b16": "i32",
+ "b8": "i32",
+ "b8x16.b6x16_p32": "i32",
+ "b8x16.b4x16_p64": "i32",
"s8": "i32",
"u8": "i32",
"s4": "i32",
@@ -161,6 +164,18 @@ def __init__(self, geom, frag, ptx_elt_type):
"m8n8:x1:b16": 1,
"m8n8:x2:b16": 2,
"m8n8:x4:b16": 4,
+ "m16n16:x1:b8": 2,
+ "m16n16:x2:b8": 4,
+ "m16n16:x1:b8x16.b6x16_p32": 2,
+ "m16n16:x2:b8x16.b6x16_p32": 4,
+ "m16n16:x1:b8x16.b4x16_p64": 2,
+ "m16n16:x2:b8x16.b4x16_p64": 4,
+ "m8n16:x1:b8x16.b6x16_p32": 1,
+ "m8n16:x2:b8x16.b6x16_p32": 2,
+ "m8n16:x4:b8x16.b6x16_p32": 4,
+ "m8n16:x1:b8x16.b4x16_p64": 1,
+ "m8n16:x2:b8x16.b4x16_p64": 2,
+ "m8n16:x4:b8x16.b4x16_p64": 4,
}.get(
"%s:%s:%s" % (geom, frag, ptx_elt_type),
{
@@ -289,7 +304,15 @@ def get_ldst_ops(kind):
def get_ldmatrix_ops():
- return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+ return (
+ make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+ + make_ldmatrix_ops(
+ ["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]
+ )
+ + make_ldmatrix_ops(
+ ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]
+ )
+ )
def is_wmma_geom_supported(geom):
@@ -330,9 +353,22 @@ def is_mma_geom_supported(geom):
def is_ldmatrix_geom_supported(geom):
if geom in ["m8n8"]:
return ptx_version >= 65 and gpu_arch >= 75
+ elif geom in ["m16n16"]:
+ return ptx_version >= 86 and gpu_arch >= 100 and aa
+ elif geom in ["m8n16"]:
+ return ptx_version >= 86 and gpu_arch >= 100 and aa
assert False # Unexpected geometry.
+def is_ldmatrix_trans_supported(geom, trans):
+ if geom in ["m8n8"]:
+ return True
+ elif geom in ["m16n16"]:
+ return trans == ".trans"
+ elif geom in ["m8n16"]:
+ return trans == ""
+ assert False # Unexpected geometry.
+
def is_type_supported(ptx_type):
if ptx_type in ["s8", "u8", "s32"]:
return ptx_version >= 63 and gpu_arch >= 72
@@ -417,10 +453,11 @@ def is_ldst_variant_supported(frag, layout):
return True
-def is_ldmatrix_variant_supported(frag):
+def is_ldmatrix_variant_supported(frag, trans):
if not (
is_type_supported(frag.mma_type.ptx_type)
and is_ldmatrix_geom_supported(frag.geom)
+ and is_ldmatrix_trans_supported(frag.geom, trans)
):
return False
return frag.frag in ["x1", "x2", "x4"]
@@ -653,7 +690,7 @@ def gen_ldmatrix_tests():
["", ".shared"],
["", ".trans"],
):
- if not is_ldmatrix_variant_supported(frag):
+ if not is_ldmatrix_variant_supported(frag, trans):
continue
params = {
@@ -944,6 +981,19 @@ def gen_check_unsupported_ops(items):
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
+
; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -997,13 +1047,16 @@ def gen_tests():
def main():
global ptx_version
global gpu_arch
+ global aa
parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
+ parser.add_argument("--aa", action="store_true")
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch
+ aa = args.aa
gen_tests()
More information about the llvm-commits
mailing list