[llvm] [LLVM][NVPTX] Add support for ldmatrix extensions introduced in PTX 8.6 (PR #124899)

Pradeep Kumar via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 29 06:45:54 PST 2025


https://github.com/schwarzschild-radius updated https://github.com/llvm/llvm-project/pull/124899

>From cfabf9c3b5bbbf082334d393e733837a3e8007df Mon Sep 17 00:00:00 2001
From: pradeepku <pradeepku at nvidia.com>
Date: Tue, 28 Jan 2025 14:55:39 +0530
Subject: [PATCH] [LLVM][NVPTX] Add support for ldmatrix extensions introduced
 in PTX 8.6

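This commit adds support for the following ldmatrix extensions introduced in PTX 8.6:
- Support for the m16n16 shape with the .b8 type; .trans is mandatory
- Support for the m16n16 and m8n16 shapes with the .b8x16.b6x16_p32 and
  .b8x16.b4x16_p64 destination/source formats (.trans is mandatory for
  m16n16 and not supported for m8n16)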

These extensions are supported only on sm_100a, sm_101a and sm_120a.

Please refer to the PTX ISA for more information:
https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
---
 llvm/include/llvm/IR/IntrinsicsNVVM.td       | 39 ++++++++++---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp  | 18 +++++-
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td      |  1 +
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td     | 27 ++++++++-
 llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py | 16 ++++++
 llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py | 16 ++++++
 llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py | 16 ++++++
 llvm/test/CodeGen/NVPTX/wmma.py              | 58 +++++++++++++++++++-
 8 files changed, 176 insertions(+), 15 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
 create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
 create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py

diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 9a2f38d760e659..f3aac47e4c4033 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -62,6 +62,7 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
   string frag = Frag;
   string ptx_elt_type = PtxEltType;
   string gft = Geom#":"#Frag#":"#ptx_elt_type;
+  string gf = Geom#":"#Frag;
   string ft = frag#":"#ptx_elt_type;
   list<LLVMType> regs = !cond(
     // mma fp ops use smaller fragments than wmma fp ops
@@ -204,9 +205,19 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
     !eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
 
     // ldmatrix b16 -> s32 @ m8n8
-    !eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1),
-    !eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2),
-    !eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4),
+    !eq(gf,"m8n8:x1") : !listsplat(llvm_i32_ty, 1),
+    !eq(gf,"m8n8:x2") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m8n8:x4") : !listsplat(llvm_i32_ty, 4),
+
+    // ldmatrix b8, b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m16n16
+    !eq(gf,"m16n16:x1") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m16n16:x2") : !listsplat(llvm_i32_ty, 4),
+
+    // ldmatrix b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m8n16
+    !eq(gf,"m8n16:x1") : !listsplat(llvm_i32_ty, 1),
+    !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
+
   );
 }
 
@@ -411,7 +422,16 @@ class NVVM_MMA_OPS {
 
   list<WMMA_REGS> ldmatrix_b16_ops = LDMATRIX_OPS<
     ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
-  list<WMMA_REGS> all_ldmatrix_ops = ldmatrix_b16_ops;
+
+  list<WMMA_REGS> ldmatrix_geom_m16n16_ops = LDMATRIX_OPS<
+    ["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+  list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
+    ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+  list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
+                                                 ldmatrix_geom_m16n16_ops,
+                                                 ldmatrix_geom_m8n16_ops);
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -536,13 +556,18 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
 // if NVVM_LDMATRIX_SUPPORTED<...>.ret then
 //   def : FOO<>; // The record will only be defined for supported ops.
 //
-class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag> {
+class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
   string g = frag.geom;
   string t = frag.ptx_elt_type;
 
   bit ret = !cond(
-    // Only currently support m8n8 and b16
     !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+    // m16n16 requires .trans for every type; m8n16 does not allow .trans.
+    !and(!eq(g, "m16n16"), !eq(t, "b8"), !eq(trans, 1)): true,
+    !and(!eq(g, "m16n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 1)): true,
+    !and(!eq(g, "m16n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 1)): true,
+    !and(!eq(g, "m8n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 0)): true,
+    !and(!eq(g, "m8n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 0)): true,
     true: false
   );
 }
@@ -4932,7 +4957,7 @@ class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
 
 foreach transposed = [0, 1] in {
   foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
-    if NVVM_LDMATRIX_SUPPORTED<frag>.ret then {
+    if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then {
       def LDMATRIX_NAME<frag, transposed>.record
         : NVVM_LDMATRIX<frag, transposed>;
     }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 773c97f7b4dc0f..4c1c5c10bfcc8b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3552,7 +3552,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
-  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v4i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3592,7 +3597,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
   case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
-  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3688,7 +3695,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
-  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v2i32;
     Info.ptrVal = I.getArgOperand(0);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 633a99d0fc1be3..d0a625643e2129 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -170,6 +170,7 @@ def False : Predicate<"false">;
 class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
 class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
 
+def hasAAFeatures : Predicate<"Subtarget->hasAAFeatures()">;
 // Explicit records for arch-accelerated SM versions
 def hasSM90a : Predicate<"Subtarget->getFullSmVersion() == 901">;
 def hasSM100a : Predicate<"Subtarget->getFullSmVersion() == 1001">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 56d8b734bf01df..b2cf22b255f1d0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7107,6 +7107,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
     !eq(ptx_elt_type, "tf32") : Int32Regs,
     !eq(ptx_elt_type, "s32") : Int32Regs,
     !eq(ptx_elt_type, "b16") : Int32Regs,
+    !eq(ptx_elt_type, "b8") : Int32Regs,
+    !eq(ptx_elt_type, "b8x16.b6x16_p32") : Int32Regs,
+    !eq(ptx_elt_type, "b8x16.b4x16_p64") : Int32Regs,
     !eq(ptx_elt_type, "s8") : Int32Regs,
     !eq(ptx_elt_type, "u8") : Int32Regs,
     !eq(ptx_elt_type, "s4") : Int32Regs,
@@ -7194,7 +7197,27 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
 
     !and(!eq(op,"ldmatrix"),
          !eq(ptx_elt_type,"b16"),
-         !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>]);
+         !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8"),
+         !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+         !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+         !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+         !eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+         !eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>]);
 
   // template DAGs for instruction inputs/output.
   dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7478,7 +7501,7 @@ defset list<WMMA_INSTR> LDMATRIXs  = {
     foreach space = [".shared", ""] in {
       foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
         foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in
-          if NVVM_LDMATRIX_SUPPORTED<frag>.ret then
+          if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then
             def : LDMATRIX<WMMA_REGINFO<frag, "ldmatrix">, transposed, space,
                             addr>;
       } // addr
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
new file mode 100644
index 00000000000000..6ad0a2a5865c41
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM100a
+# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN:           --check-prefixes=INTRINSICS,PTX86LDMATRIX
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN:           --check-prefixes=INTRINSICS
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN:           | FileCheck %t-ptx86-sm_100a.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN:           | %ptxas-verify -arch=sm_100a                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
new file mode 100644
index 00000000000000..7d9953484da7d0
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM101a
+# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN:           --check-prefixes=INTRINSICS,PTX86LDMATRIX
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN:           --check-prefixes=INTRINSICS
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN:           | FileCheck %t-ptx86-sm_101a.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN:           | %ptxas-verify -arch=sm_101a                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
new file mode 100644
index 00000000000000..7bddf0b6fbb785
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM120a
+# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN:           --check-prefixes=INTRINSICS,PTX86LDMATRIX
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN:           --check-prefixes=INTRINSICS
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN:           | FileCheck %t-ptx86-sm_120a.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN:           | %ptxas-verify -arch=sm_120a                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index e1e46f0b8cab34..8f4dde747eaec7 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -19,6 +19,9 @@ def __init__(self, ptx_type):
             "f64": "double",
             "s32": "i32",
             "b16": "i32",
+            "b8": "i32",
+            "b8x16.b6x16_p32": "i32",
+            "b8x16.b4x16_p64": "i32",
             "s8": "i32",
             "u8": "i32",
             "s4": "i32",
@@ -161,6 +164,18 @@ def __init__(self, geom, frag, ptx_elt_type):
             "m8n8:x1:b16": 1,
             "m8n8:x2:b16": 2,
             "m8n8:x4:b16": 4,
+            "m16n16:x1:b8": 2,
+            "m16n16:x2:b8": 4,
+            "m16n16:x1:b8x16.b6x16_p32": 2,
+            "m16n16:x2:b8x16.b6x16_p32": 4,
+            "m16n16:x1:b8x16.b4x16_p64": 2,
+            "m16n16:x2:b8x16.b4x16_p64": 4,
+            "m8n16:x1:b8x16.b6x16_p32": 1,
+            "m8n16:x2:b8x16.b6x16_p32": 2,
+            "m8n16:x4:b8x16.b6x16_p32": 4,
+            "m8n16:x1:b8x16.b4x16_p64": 1,
+            "m8n16:x2:b8x16.b4x16_p64": 2,
+            "m8n16:x4:b8x16.b4x16_p64": 4,
         }.get(
             "%s:%s:%s" % (geom, frag, ptx_elt_type),
             {
@@ -289,7 +304,15 @@ def get_ldst_ops(kind):
 
 
 def get_ldmatrix_ops():
-    return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+    return (
+        make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+        + make_ldmatrix_ops(
+            ["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]
+        )
+        + make_ldmatrix_ops(
+            ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]
+        )
+    )
 
 
 def is_wmma_geom_supported(geom):
@@ -330,8 +353,20 @@ def is_mma_geom_supported(geom):
 def is_ldmatrix_geom_supported(geom):
     if geom in ["m8n8"]:
         return ptx_version >= 65 and gpu_arch >= 75
+    elif geom in ["m16n16", "m8n16"]:
+        return ptx_version >= 86 and gpu_arch >= 100 and aa
     assert False  # Unexpected geometry.
 
+
+def is_ldmatrix_trans_supported(geom, trans):
+    if geom in ["m8n8"]:
+        return True
+    elif geom in ["m16n16"]:
+        return trans == ".trans"  # .trans is mandatory for m16n16.
+    elif geom in ["m8n16"]:
+        return trans == ""  # .trans is not supported for m8n16.
+    assert False  # Unexpected geometry.
+
 
 def is_type_supported(ptx_type):
     if ptx_type in ["s8", "u8", "s32"]:
@@ -417,10 +452,11 @@ def is_ldst_variant_supported(frag, layout):
     return True
 
 
-def is_ldmatrix_variant_supported(frag):
+def is_ldmatrix_variant_supported(frag, trans):
     if not (
         is_type_supported(frag.mma_type.ptx_type)
         and is_ldmatrix_geom_supported(frag.geom)
+        and is_ldmatrix_trans_supported(frag.geom, trans)
     ):
         return False
     return frag.frag in ["x1", "x2", "x4"]
@@ -653,7 +689,7 @@ def gen_ldmatrix_tests():
         ["", ".shared"],
         ["", ".trans"],
     ):
-        if not is_ldmatrix_variant_supported(frag):
+        if not is_ldmatrix_variant_supported(frag, trans):
             continue
 
         params = {
@@ -944,6 +980,19 @@ def gen_check_unsupported_ops(items):
 ; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
 ; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16
 
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
+
 ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -997,13 +1046,16 @@ def gen_tests():
 def main():
     global ptx_version
     global gpu_arch
+    global aa
     parser = argparse.ArgumentParser()
     parser.add_argument("--ptx", type=int, default=60)
     parser.add_argument("--gpu-arch", type=int, default=70)
+    parser.add_argument("--aa", action="store_true")
     args = parser.parse_args()
 
     ptx_version = args.ptx
     gpu_arch = args.gpu_arch
+    aa = args.aa
 
     gen_tests()
 
 



More information about the llvm-commits mailing list