[Mlir-commits] [llvm] [mlir] [MLIR][NVVM][NVGPU] Support intrinsics about stmatrix (PR #148377)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jul 12 08:21:48 PDT 2025
https://github.com/Pecco-314 created https://github.com/llvm/llvm-project/pull/148377
Add support for the `@llvm.nvvm.stmatrix` intrinsic series. These correspond to PTX stmatrix operations, as documented in the [PTX ISA reference](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix).
>From 5aed821f5dab515a59905342ec09b83bc6df336d Mon Sep 17 00:00:00 2001
From: Gao Yanfeng <gaoyanfeng at linux.alibaba.com>
Date: Sat, 12 Jul 2025 23:17:41 +0800
Subject: [PATCH] [MLIR][NVVM][NVGPU] Support intrinsics about stmatrix
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 65 +++++++++
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 29 +++-
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 45 ++++++-
llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py | 14 ++
llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py | 4 +-
llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py | 4 +-
llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py | 4 +-
llvm/test/CodeGen/NVPTX/wmma.py | 125 ++++++++++++++++++
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 39 +++---
.../Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 43 ++++++
.../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 24 ----
mlir/test/Target/LLVMIR/nvvmir.mlir | 23 ++++
12 files changed, 365 insertions(+), 54 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 0375f29ad8906..aad21fd4cba1c 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -331,6 +331,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
!eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
+ // stmatrix b8 -> s32 @ m16n8
+ !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
+ !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),
+
);
}
@@ -403,6 +408,17 @@ class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
!subst("llvm.", "int_", intr));
}
+class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
+ string intr = "llvm.nvvm.stmatrix.sync.aligned"
+ # "." # Frag.geom
+ # "." # Frag.frag
+ # !if(Trans, ".trans", "")
+ # "." # Frag.ptx_elt_type
+ ;
+ string record = !subst(".", "_",
+ !subst("llvm.", "int_", intr));
+}
+
// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
// Geom: list of supported geometries.
// TypeN: PTX type of the corresponding fragment's element.
@@ -443,6 +459,16 @@ class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
list<string> ops = !foreach(x, ret, x.gft);
}
+class STMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
+ list<WMMA_REGS> ret =
+ !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
+ !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
+ !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
+ [WMMA_REGS<geom, frag, type>]))))));
+ // Debugging aid for readable representation of the list above.
+ list<string> ops = !foreach(x, ret, x.gft);
+}
+
// Creates list of valid combinations of fragments. This is the main list that
// drives generation of corresponding intrinsics and instructions.
class NVVM_MMA_OPS {
@@ -537,9 +563,18 @@ class NVVM_MMA_OPS {
list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+ list<WMMA_REGS> stmatrix_b16_ops = STMATRIX_OPS<
+ ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
+
+ list<WMMA_REGS> stmatrix_b8_ops = STMATRIX_OPS<
+ ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret;
+
list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
ldmatrix_geom_m16n16_ops,
ldmatrix_geom_m8n16_ops);
+
+ list<WMMA_REGS> all_stmatrix_ops = !listconcat(stmatrix_b16_ops,
+ stmatrix_b8_ops);
}
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
);
}
+// Returns true if the fragment/transpose combination is supported for
+// stmatrix ops; false otherwise.
+class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
+ string g = frag.geom;
+ string t = frag.ptx_elt_type;
+
+ bit ret = !cond(
+ !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+ !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true,
+ true: false
+ );
+}
+
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
string Suffix = !if(sync, "sync_", "")
# mode # "_"
@@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in {
}
}
+// STMATRIX
+class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed>
+ : Intrinsic<[],
+ !listconcat([llvm_anyptr_ty], Frag.regs),
+ [IntrWriteMem, IntrArgMemOnly, IntrNoCallback,
+ WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
+ STMATRIX_NAME<Frag, Transposed>.intr>;
+
+foreach transposed = [0, 1] in {
+ foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in {
+ if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then {
+ def STMATRIX_NAME<frag, transposed>.record
+ : NVVM_STMATRIX<frag, transposed>;
+ }
+ }
+}
+
// MAPA
let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture<ArgIndex<0>>] in {
def int_nvvm_mapa
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 3d010e04824c5..d94be492b0c02 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3952,7 +3952,10 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col:
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_col_stride:
case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row:
- case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride: {
+ case Intrinsic::nvvm_wmma_m8n8k32_store_d_s32_row_stride:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8: {
Info.opc = ISD::INTRINSIC_VOID;
Info.memVT = MVT::v2i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3975,6 +3978,30 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8: {
+ Info.opc = ISD::INTRINSIC_VOID;
+ Info.memVT = MVT::i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOStore;
+ Info.align = Align(4);
+ return true;
+ }
+
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16:
+ case Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8: {
+ Info.opc = ISD::INTRINSIC_VOID;
+ Info.memVT = MVT::v4i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOStore;
+ Info.align = Align(16);
+ return true;
+ }
+
case Intrinsic::nvvm_atomic_add_gen_f_cta:
case Intrinsic::nvvm_atomic_add_gen_f_sys:
case Intrinsic::nvvm_atomic_add_gen_i_cta:
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 93827be5c2811..1e24bf8ab99e1 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4597,7 +4597,14 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
!and(!eq(op, "ldmatrix"),
!eq(ptx_elt_type, "b8x16.b4x16_p64"),
- !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
+ !eq(geom, "m8n16")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>],
+
+ !and(!eq(op, "stmatrix"),!eq(ptx_elt_type, "b16"),
+ !eq(geom, "m8n8")) : [hasSM<90>, hasPTX<78>],
+
+ !and(!eq(op, "stmatrix"),
+ !eq(ptx_elt_type, "b8"),
+ !eq(geom, "m16n8")) : [hasSM<100>, hasArchAccelFeatures, hasPTX<86>]);
// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -4878,6 +4885,40 @@ defset list<WMMA_INSTR> LDMATRIXs = {
} // transposed
} // defset
+//
+// stmatrix.sync.aligned.{m8n8|m16n8}.{x1|x2|x4}[|.trans][|.shared].{b16|b8}
+//
+class STMATRIX<WMMA_REGINFO Frag, bit Transposed, string Space>
+ : WMMA_INSTR<STMATRIX_NAME<Frag, Transposed>.record, [!con((ins ADDR:$dst), Frag.Ins)]>,
+ Requires<Frag.Predicates> {
+ // Build PatFrag that only matches particular address space.
+ dag PFOperands = !con((ops node:$dst), !dag(ops, !listsplat(node, !size(Frag.regs)), Frag.reg_names));
+ PatFrag IntrFrag = PatFrag<PFOperands, !foreach(tmp, PFOperands, !subst(ops, Intr, tmp)),
+ !cond(!eq(Space, ".shared"): AS_match.shared,
+ true: AS_match.generic)>;
+ // Build AS-constrained pattern.
+ let IntrinsicPattern = BuildPatternPF<IntrFrag, Args>.ret;
+ let OutOperandList = (outs);
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ let AsmString = "stmatrix.sync.aligned."
+ # Frag.geom
+ # "." # Frag.frag
+ # !if(Transposed, ".trans", "")
+ # Space
+ # "." # Frag.ptx_elt_type
+ # " [$dst], " # Frag.regstring # ";";
+}
+
+// Create all stmatrix variants
+defset list<WMMA_INSTR> STMATRIXs = {
+ foreach transposed = [false, true] in {foreach space = [".shared", ""] in {
+ foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in
+ if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then
+ def : STMATRIX<WMMA_REGINFO<frag, "stmatrix">, transposed, space>;
+ } // space
+ } // transposed
+} // defset
+
// Constructing non-flat DAGs is still a pain. I can't !subst a dag node with a
// dag, so the ptx.version must be appended *after* foreach replaces 'ins' with
// the instruction record.
@@ -4888,7 +4929,7 @@ class MMA_PAT<WMMA_INSTR wi>
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs) in
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
def : MMA_PAT<mma>;
multiclass MAPA<string suffix, Intrinsic Intr> {
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
new file mode 100644
index 0000000000000..8f502065345c1
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx78-sm90.py
@@ -0,0 +1,14 @@
+# Check all variants of instructions supported by PTX78 on SM90
+# RUN: %python %s --ptx=78 --gpu-arch=90 --aa > %t-ptx78-sm_90.ll
+# RUN: FileCheck %t-ptx78-sm_90.ll < %t-ptx78-sm_90.ll \
+# RUN: --check-prefixes=PTX78STMATRIX-DAG
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN: | FileCheck %t-ptx78-sm_90.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx78-sm_90.ll -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 \
+# RUN: | %ptxas-verify -arch=sm_90 \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
index 6ad0a2a5865c4..5c14a54601ed9 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -1,9 +1,7 @@
# Check all variants of instructions supported by PTX86 on SM100a
# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN: --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
-# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
# RUN: | FileCheck %t-ptx86-sm_100a.ll
# RUN: %if ptxas-12.7 %{ \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
index 7d9953484da7d..a77f9adddff9c 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -1,9 +1,7 @@
# Check all variants of instructions supported by PTX86 on SM101a
# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN: --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
-# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
# RUN: | FileCheck %t-ptx86-sm_101a.ll
# RUN: %if ptxas-12.7 %{ \
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
index 7bddf0b6fbb78..8126e64d6cc85 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -1,9 +1,7 @@
# Check all variants of instructions supported by PTX86 on SM120a
# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN: --check-prefixes=PTX86LDMATRIX-DAG
-# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
-# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG,PTX86STMATRIX-DAG
# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
# RUN: | FileCheck %t-ptx86-sm_120a.ll
# RUN: %if ptxas-12.7 %{ \
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 2ee489670e9e4..3888e9b6b1b8d 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -10,6 +10,7 @@
from itertools import product
from string import Template
+
class MMAType:
def __init__(self, ptx_type):
self.ptx_type = ptx_type
@@ -176,6 +177,13 @@ def __init__(self, geom, frag, ptx_elt_type):
"m8n16:x1:b8x16.b4x16_p64": 1,
"m8n16:x2:b8x16.b4x16_p64": 2,
"m8n16:x4:b8x16.b4x16_p64": 4,
+ # stmatrix
+ "m8n8:x1:b16": 1,
+ "m8n8:x2:b16": 2,
+ "m8n8:x4:b16": 4,
+ "m16n8:x1:b8": 1,
+ "m16n8:x2:b8": 2,
+ "m16n8:x4:b8": 4,
}.get(
"%s:%s:%s" % (geom, frag, ptx_elt_type),
{
@@ -241,6 +249,13 @@ def make_ldmatrix_ops(geoms, frags, types):
]
+def make_stmatrix_ops(geoms, frags, types):
+ return [
+ MMAFrag(geom, frag, ptx_type)
+ for (geom, frag, ptx_type) in product(geoms, frags, types)
+ ]
+
+
def get_wmma_ops():
return (
make_mma_ops(["m16n16k8"], ["tf32"], [], ["f32"], [])
@@ -315,6 +330,12 @@ def get_ldmatrix_ops():
)
+def get_stmatrix_ops():
+ return make_stmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"]) + make_stmatrix_ops(
+ ["m16n8"], ["x1", "x2", "x4"], ["b8"]
+ )
+
+
def is_wmma_geom_supported(geom):
# geometries for FP and ints.
if geom in ["m8n32k16", "m32n8k16"]:
@@ -360,6 +381,14 @@ def is_ldmatrix_geom_supported(geom):
assert False # Unexpected geometry.
+def is_stmatrix_geom_supported(geom):
+ if geom in ["m8n8"]:
+ return ptx_version >= 78 and gpu_arch >= 90
+ elif geom in ["m16n8"]:
+ return ptx_version >= 86 and gpu_arch >= 100 and aa
+ assert False # Unexpected geometry.
+
+
def is_ldmatrix_trans_supported(geom, trans):
if geom in ["m8n8"]:
return True
@@ -369,6 +398,15 @@ def is_ldmatrix_trans_supported(geom, trans):
return trans == ""
assert False # Unexpected geometry.
+
+def is_stmatrix_trans_supported(geom, trans):
+ if geom in ["m8n8"]:
+ return True
+ elif geom in ["m16n8"]:
+ return trans == ".trans"
+ assert False # Unexpected geometry.
+
+
def is_type_supported(ptx_type):
if ptx_type in ["s8", "u8", "s32"]:
return ptx_version >= 63 and gpu_arch >= 72
@@ -463,6 +501,16 @@ def is_ldmatrix_variant_supported(frag, trans):
return frag.frag in ["x1", "x2", "x4"]
+def is_stmatrix_variant_supported(frag, trans):
+ if not (
+ is_type_supported(frag.mma_type.ptx_type)
+ and is_stmatrix_geom_supported(frag.geom)
+ and is_stmatrix_trans_supported(frag.geom, trans)
+ ):
+ return False
+ return frag.frag in ["x1", "x2", "x4"]
+
+
def make_wmma_slice_ty(frag):
return [frag.mma_type.llvm_type] * frag.nregs
@@ -716,6 +764,61 @@ def gen_ldmatrix_tests():
return generated_items
+def gen_stmatrix_tests():
+ stmatrix_template = """
+declare void @${intrinsic}(i8 ${as}* %dst, ${args});
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define void @test_${function}(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}]
+; CHECK: {${check_args}}
+ call void @${intrinsic}(i8${as}* %dst, ${args});
+ ret void
+}
+
+; CHECK-LABEL: .func{{.*}}test_${function}_o(
+define void @test_${function}_o(i8 ${as}* %dst, ${args}) {
+; CHECK: ${instruction} {{.*}}[%rd{{[0-9+]}}+128],
+; CHECK: {${check_args}}
+ %dst1 = getelementptr i8, i8 ${as}* %dst, i32 128;
+ call void @${intrinsic}(i8 ${as}* %dst1, ${args});
+ ret void
+}
+"""
+ intrinsic_template = (
+ "llvm.nvvm.stmatrix.sync.aligned.${geom}.${frag}${trans}.${itype}.${pspace}"
+ )
+ instruction_template = ("stmatrix.sync.aligned.${geom}.${frag}${trans}${space}.${itype}"
+ )
+ generated_items = []
+
+ for frag, space, trans in product(get_stmatrix_ops(),
+ ["", ".shared"],
+ ["", ".trans"],
+ ):
+ if not is_stmatrix_variant_supported(frag, trans):
+ continue
+
+ params = {
+ "frag": frag.frag,
+ "space": space,"trans": trans,
+ "itype": frag.mma_type.ptx_type,
+ "pspace": get_pspace(space),
+ "as": "addrspace(%d)" % get_aspace(space),
+ "geom": frag.geom,
+ }
+
+ test_params = params
+ test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+ test_params["function"] = test_params["intrinsic"].replace(".", "_")
+ test_params["instruction"] = Template(instruction_template).substitute(params)
+ test_params["args"] = make_wmma_slice_args(frag)
+ test_params["check_args"] = check_pattern(frag)
+
+ print(Template(stmatrix_template).substitute(test_params))
+ generated_items.append((test_params["intrinsic"], test_params["instruction"]))
+
+ return generated_items
def mma_signature(op):
if op.a.mma_type.ptx_type == "f16":
@@ -893,6 +996,7 @@ def gen_check_unsupported_ops(items):
; NOALTFLOAT-NOT: .{{bf16|tf32}}
; NODOUBLE-NOT: .f64
; NOLDMATRIX-NOT: ldmatrix.sync.aligned
+; NOSTMATRIX-NOT: stmatrix.sync.aligned
; M16N16-DAG: m16n16k16.load.{{[ab].*}}.f16.p
; M16N16-DAG: m16n16k16.{{load|store}}.{{[cd].*\.(f16|f32)}}.p
@@ -994,6 +1098,26 @@ def gen_check_unsupported_ops(items):
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x1.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x2.trans.shared.b16
+; PTX78STMATRIX-DAG: stmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x1.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x2.trans.shared.b8
+; PTX86STMATRIX-DAG: stmatrix.sync.aligned.m16n8.x4.trans.shared.b8
+
; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -1039,6 +1163,7 @@ def gen_tests():
items = gen_wmma_load_tests()
items += gen_wmma_store_tests()
items += gen_ldmatrix_tests()
+ items += gen_stmatrix_tests()
items += gen_wmma_mma_tests()
items += gen_mma_tests()
gen_check_unsupported_ops(items)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 45a8904375e2b..8de5932aaf2e3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1990,10 +1990,22 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
let hasVerifier = 1;
}
-def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
- Arguments<(ins LLVM_PointerShared:$ptr,
- Variadic<I32>:$sources,
- MMALayoutAttr:$layout)> {
+def LdStMatrixShapeM8N8 : I32EnumAttrCase<"M8N8", 0, "m8n8">;
+def LdStMatrixShapeM8N16 : I32EnumAttrCase<"M8N16", 1, "m8n16">;
+def LdStMatrixShapeM16N8 : I32EnumAttrCase<"M16N8", 2, "m16n8">;
+def LdStMatrixShapeM16N16 : I32EnumAttrCase<"M16N16", 3, "m16n16">;
+
+def LdStMatrixShape : I32EnumAttr<"LdStMatrixShape", "Matrix shape for ldmatrix and stmatrix",
+ [LdStMatrixShapeM8N8, LdStMatrixShapeM16N8, LdStMatrixShapeM16N16]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def LdStMatrixShapeAttr : EnumAttr<NVVM_Dialect, LdStMatrixShape, "ld_st_matrix_shape"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
+ Arguments<(ins LLVM_AnyPointer: $ptr, Variadic<I32>:$sources, MMALayoutAttr:$layout, LdStMatrixShapeAttr:$shape)> {
let summary = "cooperative matrix store";
let description = [{
Collectively store one or more matrices across all threads in a warp to the
@@ -2001,21 +2013,12 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
}];
-
- let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- int d = getSources().size();
- std::string ptx = "stmatrix.sync.aligned";
- ptx += ".x" + std::to_string(d);
- if (getLayout() == NVVM::MMALayout::col)
- ptx += ".trans";
- if(d == 1) ptx += ".m8n8.shared.b16 [%0], {%1};";
- if(d == 2) ptx += ".m8n8.shared.b16 [%0], {%1, %2};";
- if(d == 4) ptx += ".m8n8.shared.b16 [%0], {%1, %2, %3, %4};";
- return ptx;
- }
+ string llvmBuilder = [{
+ auto operands = moduleTranslation.lookupValues(opInst.getOperands());
+ auto intId = getStMatrixIntrinsicId($layout, $sources.size(), $shape);
+ createIntrinsicCall(builder, intId, operands, operands[0]->getType());
}];
+ let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index eecca64c4bf81..d03242f402ec5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -163,6 +163,49 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
}
}
+/// Return the intrinsic ID associated with stmatrix for the given parameters.
+static llvm::Intrinsic::ID getStMatrixIntrinsicId(NVVM::MMALayout layout,
+ int32_t num,
+ NVVM::LdStMatrixShape shape) {
+ if (shape == NVVM::LdStMatrixShape::M8N8) {
+ if (layout == NVVM::MMALayout::row) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16;
+ case 2:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16;
+ case 4:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16;
+ default:
+ llvm_unreachable("unsupported number of matrix");
+ }
+ } else {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
+ case 2:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
+ case 4:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
+ default:
+ llvm_unreachable("unsupported number of matrix");
+ }
+ }
+ } else {
+ // for 16x8 matrices, .trans is mandatory
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
+ case 2:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
+ case 4:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
+ default:
+ llvm_unreachable("unsupported number of matrix");
+ }
+ }
+}
+
/// Return the intrinsic ID associated with st.bulk for the given address type.
static llvm::Intrinsic::ID
getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) {
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 8d720ce62a91b..580b09d70c480 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -580,30 +580,6 @@ func.func @elect_one_leader_sync() {
// -----
-// CHECK-LABEL: @stmatrix(
-// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !llvm.ptr<3>,
-// CHECK-SAME: %[[arg1:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg2:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg3:[a-zA-Z0-9_]+]]: i32,
-// CHECK-SAME: %[[arg4:[a-zA-Z0-9_]+]]: i32)
-llvm.func @stmatrix(%arg0 : !llvm.ptr<3>, %m1 : i32, %m2 : i32, %m3 : i32, %m4 : i32) {
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [$0], {$1};", "r,r" %[[arg0]], %[[arg1]] : (!llvm.ptr<3>, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [$0], {$1, $2};", "r,r,r" %[[arg0]], %[[arg1]], %[[arg2]] : (!llvm.ptr<3>, i32, i32) -> ()
-// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [$0], {$1, $2, $3, $4};", "r,r,r,r,r" %[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]], %[[arg4]] : (!llvm.ptr<3>, i32, i32, i32, i32) -> ()
- nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32
- nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32
- nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<row>} : !llvm.ptr<3>, i32, i32, i32, i32
- nvvm.stmatrix %arg0, %m1 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32
- nvvm.stmatrix %arg0, %m1, %m2 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32
- nvvm.stmatrix %arg0, %m1, %m2, %m3, %m4 {layout = #nvvm.mma_layout<col>} : !llvm.ptr<3>, i32, i32, i32, i32
- llvm.return
-}
-
-// -----
-
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index f86a04186f512..3be35faf091e2 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -573,6 +573,29 @@ llvm.func @ld_matrix(%arg0: !llvm.ptr<3>) {
llvm.return
}
+// CHECK-LABEL: @st_matrix
+llvm.func @st_matrix(%arg0: !llvm.ptr<3>, %r1: i32, %r2: i32, %r3: i32, %r4: i32) {
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x1.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x1.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n8>} : !llvm.ptr<3>, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x2.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x2.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1, %r2 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n8>} : !llvm.ptr<3>, i32, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<row>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32, i32, i32, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m8n8.x4.trans.b16.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m8n8>} : !llvm.ptr<3>, i32, i32, i32, i32
+ // CHECK: call void @llvm.nvvm.stmatrix.sync.aligned.m16n8.x4.trans.b8.p3(ptr addrspace(3) %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ nvvm.stmatrix %arg0, %r1, %r2, %r3, %r4 {layout = #nvvm.mma_layout<col>, shape = #nvvm.ld_st_matrix_shape<m16n8>} : !llvm.ptr<3>, i32, i32, i32, i32
+ llvm.return
+}
+
// This function has the "kernel" attribute attached and should appear in the
// NVVM annotations after conversion.
llvm.func @kernel_func() attributes {nvvm.kernel} {
More information about the Mlir-commits
mailing list