[llvm] 1b4c85f - [NVPTX] Add NVPTX intrinsics for CUDA PTX 6.5 ldmatrix instructions

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 6 16:14:17 PDT 2021


Author: Steffen Larsen
Date: 2021-08-06T16:13:35-07:00
New Revision: 1b4c85fc02cc87b4abcd794c98e6ff91a3d3766b

URL: https://github.com/llvm/llvm-project/commit/1b4c85fc02cc87b4abcd794c98e6ff91a3d3766b
DIFF: https://github.com/llvm/llvm-project/commit/1b4c85fc02cc87b4abcd794c98e6ff91a3d3766b.diff

LOG: [NVPTX] Add NVPTX intrinsics for CUDA PTX 6.5 ldmatrix instructions

Adds NVPTX intrinsics for the CUDA PTX `ldmatrix.sync.aligned` instructions added in PTX 6.5.

PTX ISA description of `ldmatrix.sync.aligned`: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix

Authored-by: Steffen Larsen <steffen.larsen at codeplay.com>

Reviewed By: tra

Differential Revision: https://reviews.llvm.org/D107046

Added: 
    

Modified: 
    llvm/include/llvm/IR/IntrinsicsNVVM.td
    llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
    llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
    llvm/test/CodeGen/NVPTX/wmma.py

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index cc43d23bec1c..6676303c2fef 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -43,7 +43,7 @@ def llvm_shared_i64ptr_ty : LLVMQualPointerType<llvm_i64_ty, 3>; // (shared)i64*
 
 // Helper class that represents a 'fragment' of an NVPTX *MMA instruction.
 // Geom: m<M>n<N>k<K>. E.g. m8n32k16
-// Frag: [abcd]
+// Frag: [a|b|c|d] ([x1|x2|x4] for ldmatrix)
 // PtxEltType: PTX type for the element.
 class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
   string geom = Geom;
@@ -190,6 +190,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
     !eq(gft,"m16n8k256:b:b1") : !listsplat(llvm_i32_ty, 2),
     !eq(gft,"m16n8k256:c:s32") : !listsplat(llvm_i32_ty, 4),
     !eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+    // ldmatrix b16 -> s32 @ m8n8
+    !eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1),
+    !eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2),
+    !eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4),
   );
 }
 
@@ -256,6 +261,17 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
                   !subst("llvm.", "int_", llvm));
 }
 
+class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
+  string intr = "llvm.nvvm.ldmatrix.sync.aligned"
+                # "." # Frag.geom
+                # "." # Frag.frag
+                # !if(Trans, ".trans", "")
+                # "." # Frag.ptx_elt_type
+                ;
+  string record = !subst(".", "_",
+                  !subst("llvm.", "int_", intr));
+}
+
 // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
 //   Geom: list of supported geometries.
 //   TypeN: PTX type of the corresponding fragment's element.
@@ -286,6 +302,16 @@ class MMA_LDST_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
    list<string> ops = !foreach(x, ret, x.gft);
 }
 
+class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
+  list<WMMA_REGS> ret =
+     !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
+     !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
+     !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
+            [WMMA_REGS<geom, frag, type>]))))));
+   // Debugging aid for readable representation of the list above.
+   list<string> ops = !foreach(x, ret, x.gft);
+}
+
 // Creates list of valid combinations of fragments. This is the master list that
 // drives generation of corresponding intrinsics and instructions.
 class NVVM_MMA_OPS<int _ = 0> {
@@ -370,11 +396,14 @@ class NVVM_MMA_OPS<int _ = 0> {
   // Separate A/B/C fragments (loads) from D (stores).
   list<WMMA_REGS> all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d"));
   list<WMMA_REGS> all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d"));
+
+  list<WMMA_REGS> ldmatrix_b16_ops = LDMATRIX_OPS<
+    ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
+  list<WMMA_REGS> all_ldmatrix_ops = ldmatrix_b16_ops;
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
 
-
 // Returns true if this combination of fragment and layout for WMMA load/store
 // ops is supported; false otherwise.
 // E.g.
@@ -489,6 +518,23 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
   );
 }
 
+// Returns true if the fragment is supported by ldmatrix ops;
+// false otherwise.
+// E.g.
+// if NVVM_LDMATRIX_SUPPORTED<...>.ret then
+//   def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag> {
+  string g = frag.geom;
+  string t = frag.ptx_elt_type;
+
+  bit ret = !cond(
+    // Currently only m8n8 geometry with b16 elements is supported.
+    !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+    true: false
+  );
+}
+
 class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
   string Suffix = !if(sync, "sync_", "")
                   # mode # "_"
@@ -4519,4 +4565,20 @@ foreach layout_a = ["row", "col"] in {
   } // layout_b
 } // layout_a
 
+// LDMATRIX
+class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
+  : Intrinsic<Frag.regs, [llvm_anyptr_ty],
+              [IntrReadMem, IntrArgMemOnly, ReadOnly<ArgIndex<0>>,
+               NoCapture<ArgIndex<0>>],
+              LDMATRIX_NAME<Frag, Transposed>.intr>;
+
+foreach transposed = [0, 1] in {
+  foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
+    if NVVM_LDMATRIX_SUPPORTED<frag>.ret then {
+      def LDMATRIX_NAME<frag, transposed>.record
+        : NVVM_LDMATRIX<frag, transposed>;
+    }
+  }
+}
+
 } // let TargetPrefix = "nvvm"

diff  --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d4842c953ce7..0922eac58fe9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3547,7 +3547,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col:
   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_col_stride:
   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
-  case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride: {
+  case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v4i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3585,7 +3587,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col:
   case Intrinsic::nvvm_wmma_m8n8k32_load_b_s4_col_stride:
   case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
-  case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col: {
+  case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3679,7 +3683,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col:
   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_col_stride:
   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
-  case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride: {
+  case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v2i32;
     Info.ptrVal = I.getArgOperand(0);

diff  --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index de4bf2ef3055..138f32bd2bd2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7578,6 +7578,7 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
     !eq(ptx_elt_type, "bf16") : Int32Regs,
     !eq(ptx_elt_type, "tf32") : Int32Regs,
     !eq(ptx_elt_type, "s32") : Int32Regs,
+    !eq(ptx_elt_type, "b16") : Int32Regs,
     !eq(ptx_elt_type, "s8") : Int32Regs,
     !eq(ptx_elt_type, "u8") : Int32Regs,
     !eq(ptx_elt_type, "s4") : Int32Regs,
@@ -7661,7 +7662,11 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
              !eq(geom, "m16n8k64"),
              !eq(geom, "m8n8k128"),
              !eq(geom, "m16n8k128"),
-             !eq(geom, "m16n8k256"))) : [hasSM80, hasPTX70]);
+             !eq(geom, "m16n8k256"))) : [hasSM80, hasPTX70],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b16"),
+         !eq(geom, "m8n8")) : [hasSM75, hasPTX65]);
 
   // template DAGs for instruction inputs/output.
   dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7910,6 +7915,44 @@ defset list<WMMA_INSTR> MMAs  = {
   } // layout_a
 } // defset
 
+//
+// ldmatrix.sync.aligned.m8n8.[x1|x2|x4][|.trans][|.shared].b16
+//
+class LDMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space,
+               DAGOperand SrcOp>
+  : WMMA_INSTR<LDMATRIX_NAME<Frag, Transposed>.record, [(ins SrcOp:$src)]>,
+    Requires<Frag.Predicates> {
+  // Build PatFrag that only matches particular address space.
+  PatFrag IntrFrag = PatFrag<(ops node:$src), (Intr node:$src),
+                             !cond(!eq(Space, ".shared"): AS_match.shared,
+                                   true: AS_match.generic)>;
+  // Build AS-constrained pattern.
+  let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret;
+
+  let OutOperandList = Frag.Outs;
+  let InOperandList = !con(Args, (ins MmaCode:$ptx));
+  let AsmString = "ldmatrix.sync.aligned."
+                  # Frag.geom
+                  # "." # Frag.frag
+                  # !if(Transposed, ".trans", "")
+                  # Space
+                  # "." # Frag.ptx_elt_type
+                  # " " # Frag.regstring # ", [$src];";
+}
+
+// Create all ldmatrix variants
+defset list<WMMA_INSTR> LDMATRIXs  = {
+  foreach transposed = [false, true] in {
+    foreach space = [".shared", ""] in {
+      foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
+        foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in
+          if NVVM_LDMATRIX_SUPPORTED<frag>.ret then
+            def : LDMATRIX<WMMA_REGINFO<frag, "ldmatrix">, transposed, space,
+                            addr>;
+      } // addr
+    } // space
+  } // transposed
+} // defset
 
 // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
 // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
@@ -7921,5 +7964,5 @@ class MMA_PAT<WMMA_INSTR wi>
         Requires<wi.Predicates>;
 
 // Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs) in
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in
   def : MMA_PAT<mma>;

diff  --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 785e48ce75a2..3b3d10947cac 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -6,7 +6,7 @@
 # RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
 # RUN:           --check-prefixes=INTRINSICS,M16N16
 # RUN: FileCheck %t-ptx60-sm_70.ll < %t-ptx60-sm_70.ll \
-# RUN:           --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
+# RUN:           --check-prefixes=INTRINSICS,NOEXTGEOM,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
 # RUN: llc < %t-ptx60-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx60 \
 # RUN:           | FileCheck %t-ptx60-sm_70.ll
 
@@ -15,7 +15,7 @@
 # RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
 # RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM
 # RUN: FileCheck %t-ptx61-sm_70.ll < %t-ptx61-sm_70.ll \
-# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
+# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
 # RUN: llc < %t-ptx61-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx61 \
 # RUN:           | FileCheck %t-ptx61-sm_70.ll
 
@@ -24,7 +24,7 @@
 # RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
 # RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT
 # RUN: FileCheck %t-ptx63-sm_72.ll < %t-ptx63-sm_72.ll \
-# RUN:           --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT
+# RUN:           --check-prefixes=INTRINSICS,NOSUBINT,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
 # RUN: llc < %t-ptx63-sm_72.ll -march=nvptx64 -mcpu=sm_72 -mattr=+ptx63 \
 # RUN:           | FileCheck %t-ptx63-sm_72.ll
 
@@ -33,7 +33,7 @@
 # RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
 # RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT
 # RUN: FileCheck %t-ptx63-sm_75.ll < %t-ptx63-sm_75.ll \
-# RUN:           --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT
+# RUN:           --check-prefixes=INTRINSICS,NOMMA,NODOUBLE,NOALTFLOAT,NOLDMATRIX
 # RUN: llc < %t-ptx63-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx63 \
 # RUN:           | FileCheck %t-ptx63-sm_75.ll
 
@@ -42,14 +42,14 @@
 # RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
 # RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,MMA
 # RUN: FileCheck %t-ptx64-sm_70.ll < %t-ptx64-sm_70.ll \
-# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT
+# RUN:           --check-prefixes=INTRINSICS,NOINT,NOSUBINT,NODOUBLE,NOALTFLOAT,NOLDMATRIX
 # RUN: llc < %t-ptx64-sm_70.ll -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 \
 # RUN:           | FileCheck %t-ptx64-sm_70.ll
 
 # Check all variants of instructions supported by PTX65 on SM75+
 # RUN: %python %s --ptx=65 --gpu-arch=75 > %t-ptx65-sm_75.ll
 # RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
-# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA
+# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,PTX65MMA,PTX65LDMATRIX
 # RUN: FileCheck %t-ptx65-sm_75.ll < %t-ptx65-sm_75.ll \
 # RUN:           --check-prefixes=INTRINSICS
 # RUN: llc < %t-ptx65-sm_75.ll -march=nvptx64 -mcpu=sm_75 -mattr=+ptx65 \
@@ -58,7 +58,7 @@
 # Check all variants of instructions supported by PTX71 on SM80+
 # RUN: %python %s --ptx=71 --gpu-arch=80 > %t-ptx71-sm_80.ll
 # RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
-# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX71MMA
+# RUN:           --check-prefixes=INTRINSICS,M16N16,EXTGEOM,INT,SUBINT,MMA,ALTFLOAT,DOUBLE,PTX65MMA,PTX65LDMATRIX,PTX71MMA
 # RUN: FileCheck %t-ptx71-sm_80.ll < %t-ptx71-sm_80.ll \
 # RUN:           --check-prefixes=INTRINSICS
 # RUN: llc < %t-ptx71-sm_80.ll -march=nvptx64 -mcpu=sm_80 -mattr=+ptx71 \
@@ -78,6 +78,7 @@ def __init__(self, ptx_type):
         "f32"  : "float",
         "f64"  : "double",
         "s32"  : "i32",
+        "b16"  : "i32",
         "s8"   : "i32",
         "u8"   : "i32",
         "s4"   : "i32",
@@ -232,6 +233,11 @@ def __init__(self, geom, frag, ptx_elt_type):
         "m16n8k16:d:f16": 2,
         "m16n8k16:c:f32": 4,
         "m16n8k16:d:f32": 4,
+
+        # ldmatrix
+        "m8n8:x1:b16": 1,
+        "m8n8:x2:b16": 2,
+        "m8n8:x4:b16": 4,
     }.get("%s:%s:%s" % (geom, frag, ptx_elt_type), {
         # All other FP shape/fragment/type combinations have the same size
         "a:f16" : 8,
@@ -272,6 +278,10 @@ def make_ldst_ops(geoms, frags, types):
   return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
           in product(geoms, frags, types)]
 
+def make_ldmatrix_ops(geoms, frags, types):
+  return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
+          in product(geoms, frags, types)]
+
 def get_wmma_ops():
   return (make_mma_ops(["m16n16k8"],
                        ["tf32"], [], ["f32"], []) +
@@ -317,6 +327,9 @@ def get_ldst_ops(kind):
               make_ldst_ops(["m16n16k8"], ["c", "d"], ["f32"]))
   return [ x for x in ldst_ops if (x.frag == "d") == (kind == "store")]
 
+def get_ldmatrix_ops():
+  return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+
 def is_wmma_geom_supported(geom):
   # geometries for FP and ints.
   if geom in ["m8n32k16", "m32n8k16"]:
@@ -343,11 +356,18 @@ def is_mma_geom_supported(geom):
     return ptx_version >= 70
   assert(False) # Unexpected geometry.
 
+def is_ldmatrix_geom_supported(geom):
+  if geom in ["m8n8"]:
+    return ptx_version >= 65 and gpu_arch >= 75
+  assert(False) # Unexpected geometry.
+
 def is_type_supported(ptx_type):
   if ptx_type in ["s8", "u8", "s32"]:
     return ptx_version >= 63 and gpu_arch >= 72
   if ptx_type in ["s4", "u4", "b1"]:
     return ptx_version >= 63 and gpu_arch >= 75
+  if ptx_type == "b16":
+    return ptx_version >= 65 and gpu_arch >= 75
   if ptx_type in ["bf16", "tf32", "f64"]:
     return ptx_version >= 70
   return ptx_version >= 60 and gpu_arch >= 70
@@ -413,6 +433,12 @@ def is_ldst_variant_supported(frag, layout):
             or frag.frag in ["c", "d"])
   return True
 
+def is_ldmatrix_variant_supported(frag):
+  if not (is_type_supported(frag.mma_type.ptx_type)
+          and is_ldmatrix_geom_supported(frag.geom)):
+    return False
+  return frag.frag in ["x1", "x2", "x4"]
+
 def make_wmma_slice_ty(frag):
   return [frag.mma_type.llvm_type] * frag.nregs
 
@@ -584,6 +610,66 @@ def gen_wmma_store_tests():
 
   return generated_items
 
+def gen_ldmatrix_tests():
+  ldmatrix_template = """
+declare ${ret_ty} @${intrinsic}(i8 ${as}* %src);
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define ${ret_ty} @test_${function}(i8 ${as}* %src) {
+; CHECK: ${instruction}
+; CHECK: {${check_result}}
+; CHECK: [%rd{{[0-9]+}}]
+  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src);
+  ret ${ret_ty} %v0;
+}
+
+; CHECK-LABEL: .func{{.*}}test_${function}_o(
+define ${ret_ty} @test_${function}_o(i8 ${as}* %src) {
+; CHECK: ${instruction}
+; CHECK: {${check_result}}
+; CHECK: [%rd{{[0-9]+}}+128]
+  %src1 = getelementptr i8, i8 ${as}* %src, i32 128;
+  %v0 = call ${ret_ty} @${intrinsic}(i8 ${as}* %src1);
+  ret ${ret_ty} %v0;
+}
+"""
+  intrinsic_template = "llvm.nvvm.ldmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
+  instruction_template = "ldmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
+
+  generated_items = []
+
+  for frag, space, trans in product(
+      get_ldmatrix_ops(),
+      ["",".shared"],
+      ["",".trans"],
+      ):
+    if not is_ldmatrix_variant_supported(frag):
+      continue
+
+    params = {
+        "frag" : frag.frag,
+        "space" : space,
+        "trans" : trans,
+        "itype" : frag.mma_type.ptx_type,
+        "pspace" : get_pspace(space),
+        "as"     : "addrspace(%d)" % get_aspace(space),
+        "geom"   : frag.geom,
+    }
+
+    test_params = params
+    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+    test_params["function"] = test_params["intrinsic"].replace(".","_")
+    test_params["instruction"] = Template(instruction_template).substitute(params)
+    test_params["ret_ty"] = make_wmma_ld_ret_ty(frag)
+    test_params["check_result"] = check_pattern(frag)
+
+    print(Template(ldmatrix_template).substitute(test_params))
+
+    generated_items.append((test_params["intrinsic"],
+                            test_params["instruction"]))
+
+  return generated_items
+
 def mma_signature(op):
   if op.a.mma_type.ptx_type == "f16":
     # FP16 ops identified by accumulator & result type.
@@ -744,6 +830,7 @@ def gen_check_unsupported_ops(items):
 ; NOMMA-NOT: .m8n8k4.
 ; NOALTFLOAT-NOT: .{{bf16|tf32}}
 ; NODOUBLE-NOT: .f64
+; NOLDMATRIX-NOT: ldmatrix.sync.aligned
 
 ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
 ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -819,6 +906,19 @@ def gen_check_unsupported_ops(items):
 ; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.s4.u4
 ; PTX65MMA-DAG: mma.m8n8k32.row.col{{.*}}.u4.s4
 
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.shared.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.shared.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.shared.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
+; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+
 ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -861,6 +961,7 @@ def gen_check_unsupported_ops(items):
 def gen_tests():
   items = gen_wmma_load_tests()
   items += gen_wmma_store_tests()
+  items += gen_ldmatrix_tests()
   items += gen_wmma_mma_tests()
   items += gen_mma_tests()
   gen_check_unsupported_ops(items)


        


More information about the llvm-commits mailing list