[llvm] 52e7ca9 - [LLVM][NVPTX] Add support for ldmatrix extensions introduced in PTX 8.6 (#124899)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 17 08:14:55 PDT 2025


Author: Pradeep Kumar
Date: 2025-03-17T20:44:52+05:30
New Revision: 52e7ca9279b4cbe30cacca67548347ef5f96b120

URL: https://github.com/llvm/llvm-project/commit/52e7ca9279b4cbe30cacca67548347ef5f96b120
DIFF: https://github.com/llvm/llvm-project/commit/52e7ca9279b4cbe30cacca67548347ef5f96b120.diff

LOG: [LLVM][NVPTX] Add support for ldmatrix extensions introduced in PTX 8.6 (#124899)

This commit adds support for the following ldmatrix extensions
introduced in PTX 8.6
- Support for m16n16 with b8 type with mandatory transpose
- Support for m16n16 and m8n16 with source and destination formats

The above extensions are only supported on sm_100a, sm_101a, sm_120a

Please refer to the PTX ISA for more information:
https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix

Added: 
    llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
    llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
    llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py

Modified: 
    llvm/include/llvm/IR/IntrinsicsNVVM.td
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
    llvm/test/CodeGen/NVPTX/wmma.py

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index ea58985cbebda..665db3025903e 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -72,6 +72,7 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
   string frag = Frag;
   string ptx_elt_type = PtxEltType;
   string gft = Geom#":"#Frag#":"#ptx_elt_type;
+  string gf = Geom#":"#Frag;
   string ft = frag#":"#ptx_elt_type;
   list<LLVMType> regs = !cond(
     // mma fp ops use smaller fragments than wmma fp ops
@@ -214,9 +215,19 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
     !eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
 
     // ldmatrix b16 -> s32 @ m8n8
-    !eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1),
-    !eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2),
-    !eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4),
+    !eq(gf,"m8n8:x1") : !listsplat(llvm_i32_ty, 1),
+    !eq(gf,"m8n8:x2") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m8n8:x4") : !listsplat(llvm_i32_ty, 4),
+
+    // ldmatrix b8, b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m16n16
+    !eq(gf,"m16n16:x1") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m16n16:x2") : !listsplat(llvm_i32_ty, 4),
+
+    // ldmatrix b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m8n16
+    !eq(gf,"m8n16:x1") : !listsplat(llvm_i32_ty, 1),
+    !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
+
   );
 }
 
@@ -421,7 +432,16 @@ class NVVM_MMA_OPS {
 
   list<WMMA_REGS> ldmatrix_b16_ops = LDMATRIX_OPS<
     ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
-  list<WMMA_REGS> all_ldmatrix_ops = ldmatrix_b16_ops;
+
+  list<WMMA_REGS> ldmatrix_geom_m16n16_ops = LDMATRIX_OPS<
+    ["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+  list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
+    ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+  list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
+                                                 ldmatrix_geom_m16n16_ops,
+                                                 ldmatrix_geom_m8n16_ops);
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -546,13 +566,18 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
 // if NVVM_LDMATRIX_SUPPORTED<...>.ret then
 //   def : FOO<>; // The record will only be defined for supported ops.
 //
-class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag> {
+class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
   string g = frag.geom;
   string t = frag.ptx_elt_type;
 
   bit ret = !cond(
-    // Only currently support m8n8 and b16
     !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+    !and(!eq(g, "m16n16"), !eq(t, "b8"), !eq(trans, 1)): true,
+    !and(!eq(g, "m16n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 1)): true,
+    !and(!eq(g, "m16n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 1)): true,
+    !and(!eq(g, "m8n16"), !eq(t, "b8"), !eq(trans, 0)): true,
+    !and(!eq(g, "m8n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 0)): true,
+    !and(!eq(g, "m8n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 0)): true,
     true: false
   );
 }
@@ -4983,7 +5008,7 @@ class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
 
 foreach transposed = [0, 1] in {
   foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
-    if NVVM_LDMATRIX_SUPPORTED<frag>.ret then {
+    if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then {
       def LDMATRIX_NAME<frag, transposed>.record
         : NVVM_LDMATRIX<frag, transposed>;
     }

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b768725b04256..18ec5c5384488 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3681,7 +3681,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
-  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v4i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3721,7 +3726,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
   case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
-  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3817,7 +3824,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
-  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v2i32;
     Info.ptrVal = I.getArgOperand(0);

diff  --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index f6150ee9db26e..90f56a421b19b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7052,6 +7052,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
     !eq(ptx_elt_type, "tf32") : Int32Regs,
     !eq(ptx_elt_type, "s32") : Int32Regs,
     !eq(ptx_elt_type, "b16") : Int32Regs,
+    !eq(ptx_elt_type, "b8") : Int32Regs,
+    !eq(ptx_elt_type, "b8x16.b6x16_p32") : Int32Regs,
+    !eq(ptx_elt_type, "b8x16.b4x16_p64") : Int32Regs,
     !eq(ptx_elt_type, "s8") : Int32Regs,
     !eq(ptx_elt_type, "u8") : Int32Regs,
     !eq(ptx_elt_type, "s4") : Int32Regs,
@@ -7139,7 +7142,27 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
 
     !and(!eq(op,"ldmatrix"),
          !eq(ptx_elt_type,"b16"),
-         !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>]);
+         !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8"),
+         !eq(geom, "m16n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+         !eq(geom, "m16n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+         !eq(geom, "m16n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+         !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+         !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
 
   // template DAGs for instruction inputs/output.
   dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7414,7 +7437,7 @@ defset list<WMMA_INSTR> LDMATRIXs  = {
   foreach transposed = [false, true] in {
     foreach space = [".shared", ""] in {
       foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in
-        if NVVM_LDMATRIX_SUPPORTED<frag>.ret then
+        if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then
           def : LDMATRIX<WMMA_REGINFO<frag, "ldmatrix">, transposed, space>;
     } // space
   } // transposed

diff  --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
new file mode 100644
index 0000000000000..6ad0a2a5865c4
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM100a
+# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN:           | FileCheck %t-ptx86-sm_100a.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN:           | %ptxas-verify -arch=sm_100a                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()

diff  --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
new file mode 100644
index 0000000000000..7d9953484da7d
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM101a
+# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN:           | FileCheck %t-ptx86-sm_101a.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN:           | %ptxas-verify -arch=sm_101a                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()

diff  --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
new file mode 100644
index 0000000000000..7bddf0b6fbb78
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM120a
+# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN:           | FileCheck %t-ptx86-sm_120a.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN:           | %ptxas-verify -arch=sm_120a                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()

diff  --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index e1e46f0b8cab3..ce275c9b71282 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -19,6 +19,9 @@ def __init__(self, ptx_type):
             "f64": "double",
             "s32": "i32",
             "b16": "i32",
+            "b8": "i32",
+            "b8x16.b6x16_p32": "i32",
+            "b8x16.b4x16_p64": "i32",
             "s8": "i32",
             "u8": "i32",
             "s4": "i32",
@@ -161,6 +164,18 @@ def __init__(self, geom, frag, ptx_elt_type):
             "m8n8:x1:b16": 1,
             "m8n8:x2:b16": 2,
             "m8n8:x4:b16": 4,
+            "m16n16:x1:b8": 2,
+            "m16n16:x2:b8": 4,
+            "m16n16:x1:b8x16.b6x16_p32": 2,
+            "m16n16:x2:b8x16.b6x16_p32": 4,
+            "m16n16:x1:b8x16.b4x16_p64": 2,
+            "m16n16:x2:b8x16.b4x16_p64": 4,
+            "m8n16:x1:b8x16.b6x16_p32": 1,
+            "m8n16:x2:b8x16.b6x16_p32": 2,
+            "m8n16:x4:b8x16.b6x16_p32": 4,
+            "m8n16:x1:b8x16.b4x16_p64": 1,
+            "m8n16:x2:b8x16.b4x16_p64": 2,
+            "m8n16:x4:b8x16.b4x16_p64": 4,
         }.get(
             "%s:%s:%s" % (geom, frag, ptx_elt_type),
             {
@@ -289,7 +304,15 @@ def get_ldst_ops(kind):
 
 
 def get_ldmatrix_ops():
-    return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+    return (
+        make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+        + make_ldmatrix_ops(
+            ["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]
+        )
+        + make_ldmatrix_ops(
+            ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]
+        )
+    )
 
 
 def is_wmma_geom_supported(geom):
@@ -330,9 +353,22 @@ def is_mma_geom_supported(geom):
 def is_ldmatrix_geom_supported(geom):
     if geom in ["m8n8"]:
         return ptx_version >= 65 and gpu_arch >= 75
+    elif geom in ["m16n16"]:
+        return ptx_version >= 86 and gpu_arch >= 100 and aa
+    elif geom in ["m8n16"]:
+        return ptx_version >= 86 and gpu_arch >= 100 and aa
     assert False  # Unexpected geometry.
 
 
+def is_ldmatrix_trans_supported(geom, trans):
+    if geom in ["m8n8"]:
+        return True
+    elif geom in ["m16n16"]:
+        return trans == ".trans"
+    elif geom in ["m8n16"]:
+        return trans == ""
+    assert False  # Unexpected geometry.
+
 def is_type_supported(ptx_type):
     if ptx_type in ["s8", "u8", "s32"]:
         return ptx_version >= 63 and gpu_arch >= 72
@@ -417,10 +453,11 @@ def is_ldst_variant_supported(frag, layout):
     return True
 
 
-def is_ldmatrix_variant_supported(frag):
+def is_ldmatrix_variant_supported(frag, trans):
     if not (
         is_type_supported(frag.mma_type.ptx_type)
         and is_ldmatrix_geom_supported(frag.geom)
+        and is_ldmatrix_trans_supported(frag.geom, trans)
     ):
         return False
     return frag.frag in ["x1", "x2", "x4"]
@@ -653,7 +690,7 @@ def gen_ldmatrix_tests():
         ["", ".shared"],
         ["", ".trans"],
     ):
-        if not is_ldmatrix_variant_supported(frag):
+        if not is_ldmatrix_variant_supported(frag, trans):
             continue
 
         params = {
@@ -944,6 +981,19 @@ def gen_check_unsupported_ops(items):
 ; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
 ; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16
 
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
+
 ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -997,13 +1047,16 @@ def gen_tests():
 def main():
     global ptx_version
     global gpu_arch
+    global aa
     parser = argparse.ArgumentParser()
     parser.add_argument("--ptx", type=int, default=60)
     parser.add_argument("--gpu-arch", type=int, default=70)
+    parser.add_argument("--aa", action="store_true")
     args = parser.parse_args()
 
     ptx_version = args.ptx
     gpu_arch = args.gpu_arch
+    aa = args.aa
 
     gen_tests()
 


        


More information about the llvm-commits mailing list