[llvm] [mlir] [MLIR][NVVM][NVGPU] Support stmatrix intrinsics (PR #148377)

via llvm-commits llvm-commits at lists.llvm.org
Sat Jul 12 08:22:32 PDT 2025


llvmbot wrote:


@llvm/pr-subscribers-llvm-ir

Author: Pecco (Pecco-314)

Add support for the `@llvm.nvvm.stmatrix` family of intrinsics. These correspond to the PTX `stmatrix` warp-level matrix-store instructions, as documented in the [PTX ISA reference](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix).
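For illustration, here is a minimal IR sketch of calling one of the new intrinsics (not taken from the patch; the function name and the `.p3` overload suffix for a shared-memory pointer are assumptions on my part, and the PTX in the comment is the lowering the generated tests below expect):

```llvm
; Overloaded on the pointer's address space; .p3 = shared memory (assumed suffix).
declare void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3), i32, i32)

define void @store_two_8x8_b16_tiles(ptr addrspace(3) %dst, i32 %r0, i32 %r1) {
  ; Each i32 carries two b16 elements of one 8x8 tile; the whole warp participates.
  ; Expected lowering (sm_90, +ptx78):
  ;   stmatrix.sync.aligned.m8n8.x2.shared.b16 [%dst], {%r0, %r1};
  call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %dst, i32 %r0, i32 %r1)
  ret void
}
```

The operand count tracks the `xN` fragment suffix: per the `WMMA_REGS` mapping in the patch, `x1`/`x2`/`x4` take one, two, and four `i32` registers respectively.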

---

Patch is 29.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148377.diff


12 Files Affected:

- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+65) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+28-1) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+43-2) 
- (added) llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py (+14) 
- (modified) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py (+1-3) 
- (modified) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py (+1-3) 
- (modified) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py (+1-3) 
- (modified) llvm/test/CodeGen/NVPTX/wmma.py (+125) 
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+21-18) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp (+43) 
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (-24) 
- (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+23) 


``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 0375f29ad8906..aad21fd4cba1c 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -331,6 +331,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
     !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
     !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
 
+    // stmatrix b8 -> s32 @ m16n8
+    !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
+    !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),
+
   );
 }
 
@@ -403,6 +408,17 @@ class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
                   !subst("llvm.", "int_", intr));
 }
 
+class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
+  string intr = "llvm.nvvm.stmatrix.sync.aligned"
+                # "." # Frag.geom
+                # "." # Frag.frag
+                # !if(Trans, ".trans", "")
+                # "." # Frag.ptx_elt_type
+                ;
+  string record = !subst(".", "_",
+                  !subst("llvm.", "int_", intr));
+}
+
 // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
 //   Geom: list of supported geometries.
 //   TypeN: PTX type of the corresponding fragment's element.
@@ -443,6 +459,16 @@ class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
    list<string> ops = !foreach(x, ret, x.gft);
 }
 
+class STMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
+  list<WMMA_REGS> ret =
+     !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
+     !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
+     !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
+            [WMMA_REGS<geom, frag, type>]))))));
+   // Debugging aid for readable representation of the list above.
+   list<string> ops = !foreach(x, ret, x.gft);
+}
+
 // Creates list of valid combinations of fragments. This is the main list that
 // drives generation of corresponding intrinsics and instructions.
 class NVVM_MMA_OPS {
@@ -537,9 +563,18 @@ class NVVM_MMA_OPS {
   list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
     ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
 
+  list<WMMA_REGS> stmatrix_b16_ops = STMATRIX_OPS<
+    ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
+
+  list<WMMA_REGS> stmatrix_b8_ops = STMATRIX_OPS<
+    ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret;
+
   list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
                                                  ldmatrix_geom_m16n16_ops,
                                                  ldmatrix_geom_m8n16_ops);
+
+  list<WMMA_REGS> all_stmatrix_ops = !listconcat(stmatrix_b16_ops,
+                                                 stmatrix_b8_ops);
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
   );
 }
 
+// Returns true if stmatrix is supported for the given fragment and transposition;
+// false otherwise.
+class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
+  string g = frag.geom;
+  string t = frag.ptx_elt_type;
+
+  bit ret = !cond(
+    !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+    !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true,
+    true: false
+  );
+}
+
 class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
   string Suffix = !if(sync, "sync_", "")
                   # mode # "_"
@@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in {
   }
 }
 
+// STMATRIX
+class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed>
+  : Intrinsic<[],
+          !listconcat([llvm_anyptr_ty], Frag.regs),
+          [IntrWriteMem, IntrArgMemOnly, IntrNoCallback,
+           WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
+          STMATRIX_NAME<Frag, Transposed>.intr>;
+
+foreach transposed = [0, 1] in {
+  foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in {
+    if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then {
+      def STMATRIX_NAME<frag, transposed>.record
+        : NVVM_STMATRIX<frag, transposed>;
+    }
+  }
+}
+
 // MAPA
 let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture<ArgIndex<0>>] in {
   def int_nvvm_mapa
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 3d010e04824c5..d94be492b0c02 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3952,7 +3952,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
-  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
+  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v2i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3975,6 +3978,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     return true;
   }
 
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align = Align(4);
+    return true;
+  }
+
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::v4i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align = Align(16);
+    return true;
+  }
+
   case Intrinsic::nvvm_atomic_add_gen_f_cta:
   case Intrinsic::nvvm_atomic_add_gen_f_sys:
   case Intrinsic::nvvm_atomic_add_gen_i_cta:
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 93827be5c2811..1e24bf8ab99e1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4597,7 +4597,14 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
 
     !and(!eq(op, "ldmatrix"),
          !eq(ptx_elt_type, "b8x16.b4x16_p64"),
-         !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
+         !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
+
+    !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"),
+         !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>],
+
+    !and(!eq(op, "stmatrix"),
+         !eq(ptx_elt_type, "b8"),
+         !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
 
   // template DAGs for instruction inputs/output.
   dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -4878,6 +4885,40 @@ defset list<WMMA_INSTR> LDMATRIXs  = {
   } // transposed
 } // defset
 
+//
+// stmatrix.sync.aligned.[m8n8|m16n8].[x1|x2|x4][|.trans][|.shared].[b16|b8]
+//
+class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space>
+  : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>,
+    Requires<Frag.Predicates> {
+  // Build PatFrag that only matches particular address space.
+  dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names));
+  PatFrag IntrFrag = PatFrag<PFOperands, !foreach(tmp, PFOperands, !subst(ops, Intr, tmp)),
+                             !cond(!eq(Space, ".shared"): AS_match.shared,
+                                   true: AS_match.generic)>;
+  // Build AS-constrained pattern.
+  let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret;
+  let OutOperandList = (outs);
+  let InOperandList = !con(Args, (ins MmaCode:$ptx));
+  let AsmString = "stmatrix.sync.aligned."
+                  # Frag.geom
+                  # "." # Frag.frag
+                  # !if(Transposed, ".trans", "")
+                  # Space
+                  # "." # Frag.ptx_elt_type
+                  # " [$dst], " # Frag.regstring # ";";
+}
+
+// Create all stmatrix variants
+defset list<WMMA_INSTR> STMATRIXs = {
+  foreach transposed = [false, true] in {
+    foreach space = [".shared", ""] in {
+      foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in
+        if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then
+          def : STMATRIX<WMMA_REGINFO<frag, "stmatrix">, transposed, space>;
+    } // space
+  } // transposed
+} // defset
+
 // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
 // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
 // the instruction record.
@@ -4888,7 +4929,7 @@ class MMA_PAT<WMMA_INSTR wi>
         Requires<wi.Predicates>;
 
 // Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
   def : MMA_PAT<mma>;
 
 multiclass MAPA<string suffix, Intrinsic Intr> {
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
new file mode 100644
index 0000000000000..8f502065345c1
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
@@ -0,0 +1,14 @@
+# Check all variants of instructions supported by PTX78 on SM90
+# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll
+# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \
+# RUN:           --check-prefixes=PTX78STMATRIX-DAG
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN:           | FileCheck %t-ptx78-sm_90.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN:           | %ptxas-verify -arch=sm_90                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
index 6ad0a2a5865c4..5c14a54601ed9 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM100a
 # RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
 # RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_100a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
index 7d9953484da7d..a77f9adddff9c 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM101a
 # RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
 # RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_101a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
index 7bddf0b6fbb78..8126e64d6cc85 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM120a
 # RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
 # RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_120a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 2ee489670e9e4..3888e9b6b1b8d 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -10,6 +10,7 @@
 from itertools import product
 from string import Template
 
+
 class MMAType:
     def __init__(self, ptx_type):
         self.ptx_type = ptx_type
@@ -176,6 +177,13 @@ def __init__(self, geom, frag, ptx_elt_type):
             "m8n16:x1:b8x16.b4x16_p64": 1,
             "m8n16:x2:b8x16.b4x16_p64": 2,
             "m8n16:x4:b8x16.b4x16_p64": 4,
+            # stmatrix
+            "m8n8:x1:b16": 1,
+            "m8n8:x2:b16": 2,
+            "m8n8:x4:b16": 4,
+            "m16n8:x1:b8": 1,
+            "m16n8:x2:b8": 2,
+            "m16n8:x4:b8": 4,
         }.get(
             "%s:%s:%s" % (geom, frag, ptx_elt_type),
             {
@@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types):
     ]
 
 
+def make_stmatrix_ops(geoms, frags, types):
+    return [
+        MMAFrag(geom, frag, ptx_type)
+        for (geom, frag, ptx_type) in product(geoms, frags, types)
+    ]
+
+
 def get_wmma_ops():
     return (
         make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
@@ -315,6 +330,12 @@ def get_ldmatrix_ops():
     )
 
 
+def get_stmatrix_ops():
+    return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops(
+        ["m16n8"], ["x1", "x2", "x4"], ["b8"]
+    )
+
+
 def is_wmma_geom_supported(geom):
     # geometries for FP and ints.
     if geom in ["m8n32k16", "m32n8k16"]:
@@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom):
     assert False  # Unexpected geometry.
 
 
+def is_stmatrix_geom_supported(geom):
+    if geom in ["m8n8"]:
+        return ptx_version >= 78 and gpu_arch >= 90
+    elif geom in ["m16n8"]:
+        return ptx_version >= 86 and gpu_arch >= 100 and aa
+    assert False  # Unexpected geometry.
+
+
 def is_ldmatrix_trans_supported(geom, trans):
     if geom in ["m8n8"]:
         return True
@@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans):
         return trans == ""
     assert False  # Unexpected geometry.
 
+
+def is_stmatrix_trans_supported(geom, trans):
+    if geom in ["m8n8"]:
+        return True
+    elif geom in ["m16n8"]:
+        return trans == ".trans"
+    assert False  # Unexpected geometry.
+
+
 def is_type_supported(ptx_type):
     if ptx_type in ["s8", "u8", "s32"]:
         return ptx_version >= 63 and gpu_arch >= 72
@@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans):
     return frag.frag in ["x1", "x2", "x4"]
 
 
+def is_stmatrix_variant_supported(frag, trans):
+    if not (
+        is_type_supported(frag.mma_type.ptx_type)
+        and is_stmatrix_geom_supported(frag.geom)
+        and is_stmatrix_trans_supported(frag.geom, trans)
+    ):
+        return False
+    return frag.frag in ["x1", "x2", "x4"]
+
+
 def make_wmma_slice_ty(frag):
     return [frag.mma_type.llvm_type] * frag.nregs
 
@@ -716,6 +764,61 @@ def gen_ldmatrix_tests():
 
     return generated_items
 
+
+def gen_stmatrix_tests():
+    stmatrix_template = """
+declare void @${intrinsic}(i8 ${as}* %dst, ${args});
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define void @test_${function}(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}]
+; CHECK: {${check_args}}
+  call void @${intrinsic}(i8 ${as}* %dst, ${args});
+  ret void
+}
+
+; CHECK-LABEL: .func {{.*}}test_${function}_o(
+define void @test_${function}_o(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}+128],
+; CHECK: {${check_args}}
+  %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128;
+  call void @${intrinsic}(i8 ${as}* %dst1, ${args});
+  ret void
+}
+"""
+    intrinsic_template = (
+        "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
+    )
+    instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
+    )
+    generated_items = []
+
+    for frag, space, trans in product(
+        get_stmatrix_ops(),
+        ["", ".shared"],
+        ["", ".trans"],
+    ):
+        if not is_stmatrix_variant_supported(frag, trans):
+            continue
+
+        params = {
+            "frag": frag.frag,
+            "space": space,"trans": trans,
+            "itype": frag.mma_type.ptx_type,
+            "pspace": get_pspace(space),
+            "as": "addrspace(%d)" % get_aspace(space),
+            "geom": frag.geom,
+        }
+
+        test_params = params
+        test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+        test_params["function"] = test_params["intrinsic"].replace(".", "_")
+        test_params["instruction"] = Template(instruction_template).substitute(params)
+        test_params["args"] = make_wmma_slice_args(frag)
+        test_params["check_args"] = check_pattern(frag)
+
+        print(Template(stmatrix_template).substitute(test_params))
+        generated_items.append((test_params["intrinsic"], test_params["instruction"]))
+
+    return generated_items
+
 
 def mma_signature(op):
     if op.a.mma_type.ptx_type == "f16":
@@ -893,6 +996,7 @@ def gen_check_unsupported_ops(items):
 ; NOALTFLOAT-NOT: .{{bf16|tf32}}
 ; NODOUBLE-NOT: .f64
 ; NOLDMATRIX-NOT: ldmatrix.sync.aligned
+; NOSTMATRIX-NOT: stmatrix.sync.aligned
 
 ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
 ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -994,6 +1098,26 @@ def gen_check_unsupported_ops(items):
 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
 
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8
+
 ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -1039,6 +1163,7 @@ def gen_tests():
     items = gen_wmma_load_tests()
     items += gen_wmma_store_tests()
     items += gen_ldmatrix_tests()
+    items += gen_stmatrix_tests()
     items += gen_wmma_mma_tests()
     items += gen_mma_tests()
     gen_check_unsupported_ops(items)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 45a8904375e2b..8de5932aaf2e3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1990,10 +1990,22 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
-def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, 
-  Ar...
[truncated]

``````````
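To round out the picture, here is a similar hedged sketch for the new m16n8 b8 geometry, which the predicates in the patch gate on sm_100a with PTX 8.6 and which is only defined in its transposed form (again, the function name and `.p3` overload suffix are assumptions; the expected PTX follows the PTX86STMATRIX test patterns above):

```llvm
declare void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3), i32)

define void @store_one_16x8_b8_tile(ptr addrspace(3) %dst, i32 %r) {
  ; Expected lowering (sm_100a, +ptx86):
  ;   stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 [%dst], {%r};
  call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %dst, i32 %r)
  ret void
}
```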



https://github.com/llvm/llvm-project/pull/148377

