[Mlir-commits] [llvm] [mlir] [MLIR][NVVM][NVGPU] Support intrinsics about stmatrix (PR #148377)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jul 12 08:21:48 PDT 2025


https://github.com/Pecco-314 created https://github.com/llvm/llvm-project/pull/148377

Add support for the `@llvm.nvvm.stmatrix` intrinsic series. These correspond to PTX stmatrix operations, as documented in the [PTX ISA reference](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix).

>From 5aed821f5dab515a59905342ec09b83bc6df336d Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Sat, 12 Jul 2025 23:17:41 +0800
Subject: [PATCH] [MLIR][NVVM][NVGPU] Support intrinsics about stmatrix

---
 llvm/include/llvm/IR/IntrinsicsNVVM.td        |  65 +++++++++
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |  29 +++-
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td      |  45 ++++++-
 llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py    |  14 ++
 llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py  |   4 +-
 llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py  |   4 +-
 llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py  |   4 +-
 llvm/test/CodeGen/NVPTX/wmma.py               | 125 ++++++++++++++++++
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   |  39 +++---
 .../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp  |  43 ++++++
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   |  24 ----
 mlir/test/Target/LLVMIR/nvvmir.mlir           |  23 ++++
 12 files changed, 365 insertions(+), 54 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py

diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 0375f29ad8906..aad21fd4cba1c 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -331,6 +331,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
     !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
     !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
 
+    // stmatrix b8 -> s32 @ m16n8
+    !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
+    !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),
+
   );
 }
 
@@ -403,6 +408,17 @@ class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
                   !subst("llvm.", "int_", intr));
 }
 
+class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
+  string intr = "llvm.nvvm.stmatrix.sync.aligned"
+                # "." # Frag.geom
+                # "." # Frag.frag
+                # !if(Trans, ".trans", "")
+                # "." # Frag.ptx_elt_type
+                ;
+  string record = !subst(".", "_",
+                  !subst("llvm.", "int_", intr));
+}
+
 // Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
 //   Geom: list of supported geometries.
 //   TypeN: PTX type of the corresponding fragment's element.
@@ -443,6 +459,16 @@ class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
    list<string> ops = !foreach(x, ret, x.gft);
 }
 
+class STMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
+  list<WMMA_REGS> ret =
+     !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
+     !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
+     !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
+            [WMMA_REGS<geom, frag, type>]))))));
+   // Debugging aid for readable representation of the list above.
+   list<string> ops = !foreach(x, ret, x.gft);
+}
+
 // Creates list of valid combinations of fragments. This is the main list that
 // drives generation of corresponding intrinsics and instructions.
 class NVVM_MMA_OPS {
@@ -537,9 +563,18 @@ class NVVM_MMA_OPS {
   list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
     ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
 
+  list<WMMA_REGS> stmatrix_b16_ops = STMATRIX_OPS<
+    ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
+
+  list<WMMA_REGS> stmatrix_b8_ops = STMATRIX_OPS<
+    ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret;
+
   list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
                                                  ldmatrix_geom_m16n16_ops,
                                                  ldmatrix_geom_m8n16_ops);
+
+  list<WMMA_REGS> all_stmatrix_ops = !listconcat(stmatrix_b16_ops,
+                                                 stmatrix_b8_ops);
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
   );
 }
 
+// Returns true if stmatrix supports the given fragment/transpose combination;
+// false otherwise.
+class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
+  string g = frag.geom;
+  string t = frag.ptx_elt_type;
+
+  bit ret = !cond(
+    !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+    !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true,
+    true: false
+  );
+}
+
 class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
   string Suffix = !if(sync, "sync_", "")
                   # mode # "_"
@@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in {
   }
 }
 
+// STMATRIX
+class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed>
+  : Intrinsic<[],
+          !listconcat([llvm_anyptr_ty], Frag.regs),
+          [IntrWriteMem, IntrArgMemOnly, IntrNoCallback,
+           WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
+          STMATRIX_NAME<Frag, Transposed>.intr>;
+
+foreach transposed = [0, 1] in {
+  foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in {
+    if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then {
+      def STMATRIX_NAME<frag, transposed>.record
+        : NVVM_STMATRIX<frag, transposed>;
+    }
+  }
+}
+
 // MAPA
 let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture<ArgIndex<0>>] in {
   def int_nvvm_mapa
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 3d010e04824c5..d94be492b0c02 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3952,7 +3952,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
   case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
-  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
+  case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
     Info.opc = ISD::INTRINSIC_VOID;
     Info.memVT = MVT::v2i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3975,6 +3978,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
     return true;
   }
 
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align = Align(4);
+    return true;
+  }
+
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
+  case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::v4i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align = Align(16);
+    return true;
+  }
+
   case Intrinsic::nvvm_atomic_add_gen_f_cta:
   case Intrinsic::nvvm_atomic_add_gen_f_sys:
   case Intrinsic::nvvm_atomic_add_gen_i_cta:
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 93827be5c2811..1e24bf8ab99e1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4597,7 +4597,14 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
 
     !and(!eq(op, "ldmatrix"),
          !eq(ptx_elt_type, "b8x16.b4x16_p64"),
-         !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
+         !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
+
+    !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"),
+         !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>],
+
+    !and(!eq(op, "stmatrix"),
+         !eq(ptx_elt_type, "b8"),
+         !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
 
   // template DAGs for instruction inputs/output.
   dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -4878,6 +4885,40 @@ defset list<WMMA_INSTR> LDMATRIXs  = {
   } // transposed
 } // defset
 
+//
+// stmatrix.sync.aligned.{m8n8|m16n8}.{x1|x2|x4}[|.trans][|.shared].{b16|b8}
+//
+class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space>
+  : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>,
+    Requires<Frag.Predicates> {
+  // Build PatFrag that only matches particular address space.
+  dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names));
+  PatFrag IntrFrag = PatFrag<PFOperands, !foreach(tmp, PFOperands, !subst(ops, Intr, tmp)), 
+                             !cond(!eq(Space, ".shared"): AS_match.shared,
+                                   true: AS_match.generic)>;
+  // Build AS-constrained pattern.
+  let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret;
+  let OutOperandList = (outs);
+  let InOperandList = !con(Args, (ins MmaCode:$ptx));
+  let AsmString = "stmatrix.sync.aligned."
+                  # Frag.geom
+                  # "." # Frag.frag
+                  # !if(Transposed, ".trans", "")
+                  # Space
+                  # "." # Frag.ptx_elt_type
+                  # " [$dst], " # Frag.regstring # ";";
+}
+
+// Create all stmatrix variants
+defset list<WMMA_INSTR> STMATRIXs = {
+  foreach transposed = [false, true] in {foreach space = [".shared", ""] in {
+      foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in
+        if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then
+          def : STMATRIX<WMMA_REGINFO<frag, "stmatrix">, transposed, space>;
+    } // space
+  } // transposed
+} // defset
+
 // Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
 // dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
 // the instruction record.
@@ -4888,7 +4929,7 @@ class MMA_PAT<WMMA_INSTR wi>
         Requires<wi.Predicates>;
 
 // Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
   def : MMA_PAT<mma>;
 
 multiclass MAPA<string suffix, Intrinsic Intr> {
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
new file mode 100644
index 0000000000000..8f502065345c1
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
@@ -0,0 +1,14 @@
+# Check all variants of instructions supported by PTX78 on SM90
+# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll
+# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \
+# RUN:           --check-prefixes=PTX78STMATRIX-DAG
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN:           | FileCheck %t-ptx78-sm_90.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN:           | %ptxas-verify -arch=sm_90                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
index 6ad0a2a5865c4..5c14a54601ed9 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM100a
 # RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
 # RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_100a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
index 7d9953484da7d..a77f9adddff9c 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM101a
 # RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
 # RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_101a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
index 7bddf0b6fbb78..8126e64d6cc85 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -1,9 +1,7 @@
 # Check all variants of instructions supported by PTX86 on SM120a
 # RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
 # RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
 # RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
 # RUN:           | FileCheck %t-ptx86-sm_120a.ll
 # RUN: %if ptxas-12.7 %{                                                  \
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 2ee489670e9e4..3888e9b6b1b8d 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -10,6 +10,7 @@
 from itertools import product
 from string import Template
 
+
 class MMAType:
     def __init__(self, ptx_type):
         self.ptx_type = ptx_type
@@ -176,6 +177,13 @@ def __init__(self, geom, frag, ptx_elt_type):
             "m8n16:x1:b8x16.b4x16_p64": 1,
             "m8n16:x2:b8x16.b4x16_p64": 2,
             "m8n16:x4:b8x16.b4x16_p64": 4,
+            # stmatrix
+            "m8n8:x1:b16": 1,
+            "m8n8:x2:b16": 2,
+            "m8n8:x4:b16": 4,
+            "m16n8:x1:b8": 1,
+            "m16n8:x2:b8": 2,
+            "m16n8:x4:b8": 4,
         }.get(
             "%s:%s:%s" % (geom, frag, ptx_elt_type),
             {
@@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types):
     ]
 
 
+def make_stmatrix_ops(geoms, frags, types):
+    return [
+        MMAFrag(geom, frag, ptx_type)
+        for (geom, frag, ptx_type) in product(geoms, frags, types)
+    ]
+
+
 def get_wmma_ops():
     return (
         make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
@@ -315,6 +330,12 @@ def get_ldmatrix_ops():
     )
 
 
+def get_stmatrix_ops():
+    b16_ops = make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+    b8_ops = make_stmatrix_ops(["m16n8"], ["x1", "x2", "x4"], ["b8"])
+    return b16_ops + b8_ops
+
+
 def is_wmma_geom_supported(geom):
     # geometries for FP and ints.
     if geom in ["m8n32k16", "m32n8k16"]:
@@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom):
     assert False  # Unexpected geometry.
 
 
+def is_stmatrix_geom_supported(geom):
+    if geom in ["m8n8"]:
+        return ptx_version >= 78 and gpu_arch >= 90
+    elif geom in ["m16n8"]:
+        return ptx_version >= 86 and gpu_arch >= 100 and aa
+    assert False  # Unexpected geometry.
+
+
 def is_ldmatrix_trans_supported(geom, trans):
     if geom in ["m8n8"]:
         return True
@@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans):
         return trans == ""
     assert False  # Unexpected geometry.
 
+
+def is_stmatrix_trans_supported(geom, trans):
+    if geom in ["m8n8"]:
+        return True
+    elif geom in ["m16n8"]:
+        return trans == ".trans"
+    assert False  # Unexpected geometry.
+
+
 def is_type_supported(ptx_type):
     if ptx_type in ["s8", "u8", "s32"]:
         return ptx_version >= 63 and gpu_arch >= 72
@@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans):
     return frag.frag in ["x1", "x2", "x4"]
 
 
+def is_stmatrix_variant_supported(frag, trans):
+    if not (
+        is_type_supported(frag.mma_type.ptx_type)
+        and is_stmatrix_geom_supported(frag.geom)
+        and is_stmatrix_trans_supported(frag.geom, trans)
+    ):
+        return False
+    return frag.frag in ["x1", "x2", "x4"]
+
+
 def make_wmma_slice_ty(frag):
     return [frag.mma_type.llvm_type] * frag.nregs
 
@@ -716,6 +764,61 @@ def gen_ldmatrix_tests():
 
     return generated_items
 
+def gen_stmatrix_tests():
+    # Print a base-address and a +128-offset test for each supported stmatrix
+    # variant and return the (intrinsic, instruction) pairs that were emitted.
+    stmatrix_template = """
+declare void @${intrinsic}(i8 ${as}* %dst, ${args});
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define void @test_${function}(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}]
+; CHECK: {${check_args}}
+  call void @${intrinsic}(i8 ${as}* %dst, ${args});
+  ret void
+}
+
+; CHECK-LABEL: .func{{.*}}test_${function}_o(
+define void @test_${function}_o(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9]+}}+128],
+; CHECK: {${check_args}}
+  %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128;
+  call void @${intrinsic}(i8 ${as}* %dst1, ${args});
+  ret void
+}
+"""
+    intrinsic_template = (
+        "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
+    )
+    instruction_template = (
+        "stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
+    )
+    generated_items = []
+
+    for frag, space, trans in product(
+        get_stmatrix_ops(), ["", ".shared"], ["", ".trans"]
+    ):
+        if not is_stmatrix_variant_supported(frag, trans):
+            continue
+
+        params = {
+            "frag": frag.frag,
+            "space": space,
+            "trans": trans,
+            "itype": frag.mma_type.ptx_type,
+            "pspace": get_pspace(space),
+            "as": "addrspace(%d)" % get_aspace(space),
+            "geom": frag.geom,
+        }
+        params["intrinsic"] = Template(intrinsic_template).substitute(params)
+        params["function"] = params["intrinsic"].replace(".", "_")
+        params["instruction"] = Template(instruction_template).substitute(params)
+        params["args"] = make_wmma_slice_args(frag)
+        params["check_args"] = check_pattern(frag)
+        print(Template(stmatrix_template).substitute(params))
+        generated_items.append((params["intrinsic"], params["instruction"]))
+
+    return generated_items
 
 def mma_signature(op):
     if op.a.mma_type.ptx_type == "f16":
@@ -893,6 +996,7 @@ def gen_check_unsupported_ops(items):
 ; NOALTFLOAT-NOT: .{{bf16|tf32}}
 ; NODOUBLE-NOT: .f64
 ; NOLDMATRIX-NOT: ldmatrix.sync.aligned
+; NOSTMATRIX-NOT: stmatrix.sync.aligned
 
 ; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
 ; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -994,6 +1098,26 @@ def gen_check_unsupported_ops(items):
 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
 ; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
 
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8
+
 ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -1039,6 +1163,7 @@ def gen_tests():
     items = gen_wmma_load_tests()
     items += gen_wmma_store_tests()
     items += gen_ldmatrix_tests()
+    items += gen_stmatrix_tests()
     items += gen_wmma_mma_tests()
     items += gen_mma_tests()
     gen_check_unsupported_ops(items)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 45a8904375e2b..8de5932aaf2e3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1990,10 +1990,22 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
-def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, 
-  Arguments<(ins LLVM_PointerShared:$ptr, 
-                 Variadic<I32>:$sources, 
-                 MMALayoutAttr:$layout)> {
+def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">;
+def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">;
+def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">;
+def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">;
+
+def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix",
+  [LdStMatrixShapeM8N8, LdStMatrixShapeM8N16, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def LdStMatrixShapeAttr : EnumAttr<NVVM_Dialect, LdStMatrixShape, "ld_st_matrix_shape"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
+  Arguments<(ins LLVM_AnyPointer: $ptr, Variadic<I32>:$sources, MMALayoutAttr:$layout, LdStMatrixShapeAttr:$shape)> {
   let summary = "cooperative matrix store";
   let description = [{
     Collectively store one or more matrices across all threads in a warp to the
@@ -2001,21 +2013,12 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
     
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
   }];
-  
-  let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() {
-      int d = getSources().size();
-      std::string ptx = "stmatrix.sync.aligned";
-      ptx += ".x" + std::to_string(d);
-      if (getLayout() == NVVM::MMALayout::col)
-        ptx += ".trans";
-      if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};";
-      if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};";
-      if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
-      return ptx;
-    }
+  string llvmBuilder = [{
+      auto operands = moduleTranslation.lookupValues(opInst.getOperands());
+      auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape);
+      createIntrinsicCall(builder, intId, operands, operands[0]->getType());
   }];
+  let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index eecca64c4bf81..d03242f402ec5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -163,6 +163,49 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
   }
 }
 
+/// Return the intrinsic ID associated with stmatrix for the given parameters.
+static llvm::Intrinsic::ID getStMatrixIntrinsicId(NVVM::MMALayout layout,
+                                                  int32_t num,
+                                                  NVVM::LdStMatrixShape shape) {
+  if (shape == NVVM::LdStMatrixShape::M8N8 && layout == NVVM::MMALayout::row) {
+    switch (num) {
+    case 1:
+      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16;
+    case 2:
+      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16;
+    case 4:
+      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16;
+    default:
+      llvm_unreachable("unsupported number of matrix");
+    }
+  }
+  if (shape == NVVM::LdStMatrixShape::M8N8) {
+    switch (num) {
+    case 1:
+      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
+    case 2:
+      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
+    case 4:
+      return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
+    default:
+      llvm_unreachable("unsupported number of matrix");
+    }
+  }
+  // For 16x8 matrices .trans is mandatory, so `layout` is not consulted.
+  if (shape != NVVM::LdStMatrixShape::M16N8)
+    llvm_unreachable("unsupported stmatrix shape");
+  switch (num) {
+  case 1:
+    return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
+  case 2:
+    return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
+  case 4:
+    return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
+  default:
+    llvm_unreachable("unsupported number of matrix");
+  }
+}
+
 /// Return the intrinsic ID associated with st.bulk for the given address type.
 static llvm::Intrinsic::ID
 getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) {
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 8d720ce62a91b..580b09d70c480 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -580,30 +580,6 @@ func.func @elect_one_leader_sync() {
 
 // -----
 
-// CHECK-LABEL: @stmatrix(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>, 
-// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
-llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
-  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
-  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
-  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
-  nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
-  nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
-  nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
-  llvm.return 
-}
-
-// -----
-
 // CHECK-LABEL: @init_mbarrier_arrive_expect_tx
 llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
   //CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index f86a04186f512..3be35faf091e2 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -573,6 +573,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
   llvm.return
 }
 
+// CHECK-LABEL: @st_matrix
+llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
+  // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+  nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32
+  // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+  nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32
+  // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+  nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n8>} : !llvm.ptr<3>, i32
+  // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32, i32
+  // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32, i32  
+  // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n8>} : !llvm.ptr<3>, i32, i32
+  // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32, i32, i32, i32
+  // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32, i32, i32, i32
+  // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n8>} : !llvm.ptr<3>, i32, i32, i32, i32
+  llvm.return
+}
+
 // This function has the "kernel" attribute attached and should appear in the
 // NVVM annotations after conversion.
 llvm.func @kernel_func() attributes {nvvm.kernel} {



More information about the Mlir-commits mailing list