[llvm] [NVPTX] Add sparse MMA intrinsics (PR #150950)
Kirill Vedernikov via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 28 06:30:43 PDT 2025
https://github.com/kvederni updated https://github.com/llvm/llvm-project/pull/150950
>From 7489d1b1af2a52968bfb5cb891e870a933ef02dc Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Mon, 28 Jul 2025 15:09:58 +0200
Subject: [PATCH] [NVPTX] Add sparse MMA intrinsics
This change adds intrinsics for sparse MMA (mma.sp) operations. The
implementation is based on PTX ISA version 8.8.
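
As an illustration (not copied from the generated tests), here is a sketch of
what one of the new intrinsics is expected to look like at the IR level,
inferred from the TableGen definitions below for the m16n8k16 f16/f16 variant
with sp::ordered_metadata; the function name @example is purely illustrative:

  ; 2 x A, 2 x B and 2 x C <2 x half> fragments, followed by the i32
  ; sparsity metadata word and the immediate sparsity selector.
  declare { <2 x half>, <2 x half> }
    @llvm.nvvm.mma.sp.ordered.metadata.m16n8k16.row.col.f16.f16(
      <2 x half>, <2 x half>, <2 x half>, <2 x half>,
      <2 x half>, <2 x half>, i32, i32)

  define { <2 x half>, <2 x half> } @example(
      <2 x half> %a0, <2 x half> %a1, <2 x half> %b0, <2 x half> %b1,
      <2 x half> %c0, <2 x half> %c1, i32 %meta) {
    ; The selector must be an immediate (ImmArg) in the range [0, 4).
    %r = call { <2 x half>, <2 x half> }
        @llvm.nvvm.mma.sp.ordered.metadata.m16n8k16.row.col.f16.f16(
          <2 x half> %a0, <2 x half> %a1, <2 x half> %b0, <2 x half> %b1,
          <2 x half> %c0, <2 x half> %c1, i32 %meta, i32 0)
    ret { <2 x half>, <2 x half> } %r
  }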
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 189 ++++++++++++-
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 82 +++++-
llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py | 12 +
llvm/test/CodeGen/NVPTX/wmma.py | 267 +++++++++++++++++--
4 files changed, 525 insertions(+), 25 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 967d1663f237b..c4f3e1b394c8e 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -170,7 +170,7 @@ class StrJoin<string sep, list<string> str_list> {
// Geom: m<M>n<N>k<K>. E.g. m8n32k16
// Frag: [a|b|c|d] ([x1|x2|x4] for ldmatrix)
// PtxEltType: PTX type for the element.
-class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
+class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = false> {
string geom = Geom;
string frag = Frag;
string ptx_elt_type = PtxEltType;
@@ -178,6 +178,54 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
string gf = Geom#":"#Frag;
string ft = frag#":"#ptx_elt_type;
list<LLVMType> regs = !cond(
+ // sparse mma ops use different fragments than dense mma ops for some arguments
+ !and(!eq(gft, "m16n8k16:a:bf16"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k16:a:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k32:a:bf16"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k32:a:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 4),
+ !and(!eq(gft, "m16n8k32:b:bf16"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k32:b:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 4),
+ !and(!eq(gft, "m16n8k32:c:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k32:c:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k32:d:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k32:d:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k16:a:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k16:b:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k16:c:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k16:d:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k8:a:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k32:a:u8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k32:a:s8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k64:a:u8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:s8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e4m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e5m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e3m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e2m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:u8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:s8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e4m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e5m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e3m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e2m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:c:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k64:c:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k64:d:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k64:d:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:u4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k64:a:s4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k128:a:u4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:a:s4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:a:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:b:u4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:b:s4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:b:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:c:s32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:c:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k128:d:s32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:d:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
// mma fp ops use smaller fragments than wmma fp ops
!eq(gft,"m8n8k4:a:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m8n8k4:b:f16") : !listsplat(llvm_v2f16_ty, 2),
@@ -362,6 +410,12 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
list<WMMA_REGS> id_frags = !cond(
+ // FP8/F8F6F4 ops are identified by A/B inputs & accumulator & result type.
+ !or(!eq(A.ptx_elt_type, "e4m3"),
+ !eq(A.ptx_elt_type, "e5m2"),
+ !eq(A.ptx_elt_type, "e3m2"),
+ !eq(A.ptx_elt_type, "e2m3"),
+ !eq(A.ptx_elt_type, "e2m1")): [D, A, B, C],
// FP16 ops are identified by accumulator & result type.
!eq(A.ptx_elt_type, "f16") : [D, C],
// other ops are identified by input types.
@@ -397,6 +451,19 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
# signature;
}
+class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
+ WMMA_REGS A, WMMA_REGS B,
+ WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string record = "int_nvvm_mma"
+ # "_" # !subst("::", "_", Metadata)
+ # "_" # A.geom
+ # "_row_col"
+ # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
+ # !if(Satfinite, "_satfinite", "")
+ # signature;
+}
+
class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
string intr = "llvm.nvvm.ldmatrix.sync.aligned"
# "." # Frag.geom
@@ -424,21 +491,22 @@ class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
// TypeN: PTX type of the corresponding fragment's element.
// TypeB and TypeD may be empty if it must match that of TypeA or TypeC.
class MMA_OPS<list<string> Geom, list<string> TypeA, list<string> TypeB,
- list<string> TypeC, list<string> TypeD> {
+ list<string> TypeC, list<string> TypeD, bit IsSparse = false> {
list<list<WMMA_REGS>> ret =
!foldl([]<list<WMMA_REGS>>, Geom, t1, geom, !listconcat(t1,
!foldl([]<list<WMMA_REGS>>, TypeA, t2, type_a, !listconcat(t2,
!foldl([]<list<WMMA_REGS>>, !if(!size(TypeB), TypeB, [type_a]), t3, type_b, !listconcat(t3,
!foldl([]<list<WMMA_REGS>>, TypeC, t4, type_c, !listconcat(t4,
!foldl([]<list<WMMA_REGS>>, !if(!size(TypeD), TypeD, [type_c]), t5, type_d, !listconcat(t5,
- [[WMMA_REGS<geom, "a", type_a>,
- WMMA_REGS<geom, "b", type_b>,
- WMMA_REGS<geom, "c", type_c>,
- WMMA_REGS<geom, "d", type_d>]]))))))))));
+ [[WMMA_REGS<geom, "a", type_a, IsSparse>,
+ WMMA_REGS<geom, "b", type_b, IsSparse>,
+ WMMA_REGS<geom, "c", type_c, IsSparse>,
+ WMMA_REGS<geom, "d", type_d, IsSparse>]]))))))))));
// Debugging aid for readable representation of the list above.
list<list<string>> ops = !foreach(x, ret, [x[0].gft, x[1].gft, x[2].gft, x[3].gft]);
}
+
class MMA_LDST_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
list<WMMA_REGS> ret =
!foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
@@ -522,6 +590,30 @@ class NVVM_MMA_OPS {
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
+ list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
+ ["m16n8k16", "m16n8k32"],
+ ["bf16"], [], ["f32"], [], true>.ret;
+ list<list<WMMA_REGS>> tf32_mma_sp_ops = MMA_OPS<
+ ["m16n8k8", "m16n8k16"],
+ ["tf32"], [], ["f32"], [], true>.ret;
+ list<list<WMMA_REGS>> fp_mma_sp_ops = MMA_OPS<
+ ["m16n8k16", "m16n8k32"],
+ ["f16"], [], ["f16", "f32"], ["f16", "f32"], true>.ret;
+ list<list<WMMA_REGS>> fp8_mma_sp_ops = MMA_OPS<
+ ["m16n8k64"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f16", "f32"], ["f16", "f32"], true>.ret;
+ list<list<WMMA_REGS>> subint_mma_sp_ops = MMA_OPS<
+ ["m16n8k64", "m16n8k128"],
+ ["s4", "u4"], ["s4", "u4"], ["s32"], [], true>.ret;
+ list<list<WMMA_REGS>> int_mma_sp_ops = MMA_OPS<
+ ["m16n8k32", "m16n8k64"],
+ ["s8", "u8"], ["s8", "u8"], ["s32"], [], true>.ret;
+ list<list<WMMA_REGS>> all_mma_sp_ops = !listconcat(
+ bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
+ subint_mma_sp_ops, int_mma_sp_ops);
+
list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
@@ -728,6 +820,68 @@ class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
);
}
+
+// Returns true if this combination of metadata/kind/satf for MMA.SP ops is supported;
+// false otherwise.
+// E.g.
+// if NVVM_MMA_SP_SUPPORTED<...>.ret then
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
+ string kind, int satf> {
+ // MMA.SP ops use a fixed row.col layout; only types and geometry vary.
+ string a_type = frags[0].ptx_elt_type;
+ string b_type = frags[1].ptx_elt_type;
+ string c_type = frags[2].ptx_elt_type;
+ string d_type = frags[3].ptx_elt_type;
+ string geom = frags[0].geom;
+
+ bit is_int = !or(!eq(a_type, "s8"),
+ !eq(a_type, "u8"),
+ !eq(a_type, "s4"),
+ !eq(a_type, "u4"));
+
+ bit ret = !cond(
+
+ // Limit satf to valid types
+ !and(!eq(satf, 1),
+ !eq(is_int, 0)): false,
+
+ // f16/bf16/tf32 require A and B to be the same type.
+ !and(!or(!eq(a_type, "f16"),
+ !eq(a_type, "bf16"),
+ !eq(a_type, "tf32")),
+ !ne(a_type, b_type)): false,
+
+ // m16n8k16 and m16n8k32 require C and D to be the same type.
+ !and(!or(!eq(geom, "m16n8k16"),
+ !eq(geom, "m16n8k32")),
+ !ne(c_type, d_type)): false,
+
+ !and(!eq(kind, ""),
+ !or(!eq(a_type, "e3m2"),
+ !eq(a_type, "e2m3"),
+ !eq(a_type, "e2m1"),
+ !eq(b_type, "e3m2"),
+ !eq(b_type, "e2m3"),
+ !eq(b_type, "e2m1"))): false,
+
+ !and(!eq(kind, ""),
+ !eq(geom, "m16n8k64"),
+ !or(!eq(c_type, "f16"),
+ !eq(d_type, "f16"))): false,
+
+ !and(!ne(kind, ""),
+ !or(!eq(metadata, "sp"),
+ !ne(geom, "m16n8k64"),
+ !eq(is_int, 1))): false,
+
+ // All others are OK.
+ true: true
+ );
+}
+
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
string Suffix = !if(sync, "sync_", "")
# mode # "_"
@@ -2001,6 +2155,29 @@ foreach layout_a = ["row", "col"] in {
} // layout_b
} // layout_a
+// MMA.SP
+class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+ : Intrinsic<D.regs,
+ !listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty], [llvm_i32_ty])> {
+ int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
+ let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
+ Range<ArgIndex<pos>, 0, 4>];
+}
+
+foreach metadata = ["sp", "sp::ordered_metadata"] in {
+ foreach kind = ["", "kind::f8f6f4"] in {
+ foreach satf = [0, 1] in {
+ foreach op = NVVM_MMA_OPS.all_mma_sp_ops in {
+ if NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret then {
+ def MMA_SP_NAME<metadata, kind, satf,
+ op[0], op[1], op[2], op[3]>.record
+ : NVVM_MMA_SP<op[0], op[1], op[2], op[3]>;
+ }
+ } // op
+ } // satf
+ } // kind
+} // metadata
+
// LDMATRIX
class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
: Intrinsic<Frag.regs, [llvm_anyptr_ty],
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 0a00220d94289..a2b29a17537e9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4637,10 +4637,15 @@ def INT_PTX_SREG_WARPSIZE :
// In addition to target-independent fields provided by WMMA_REGS, it adds
// the fields commonly used to implement specific PTX instruction -- register
// types and names, constraints, parts of assembly, etc.
-class WMMA_REGINFO<WMMA_REGS r, string op>
- : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type> {
+class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "">
+ : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type, !eq(op, "mma.sp")> {
// NVPTX register types used to carry fragment data.
NVPTXRegClass regclass = !cond(
+ !eq(ptx_elt_type, "e4m3") : B32,
+ !eq(ptx_elt_type, "e5m2") : B32,
+ !eq(ptx_elt_type, "e3m2") : B32,
+ !eq(ptx_elt_type, "e2m3") : B32,
+ !eq(ptx_elt_type, "e2m1") : B32,
!eq(ptx_elt_type, "f16") : B32,
!eq(ptx_elt_type, "f32") : B32,
!eq(ptx_elt_type, "f64") : B64,
@@ -4673,6 +4678,18 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
// longer the case, we can concat all per-fragment predicates to enforce that
// all fragments of the instruction are viable.
list<Predicate> Predicates = !cond(
+ !or(!eq(ptx_elt_type, "e3m2"),
+ !eq(ptx_elt_type, "e2m3"),
+ !eq(ptx_elt_type, "e2m1"),
+ !ne(kind, "")) : [hasSM120a, hasPTX<87>],
+
+ !or(!eq(ptx_elt_type, "e4m3"),
+ !eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
+
+ !and(!eq(op, "mma.sp"),
+ !ne(metadata, "sp")) : [hasSM<80>, hasPTX<85>],
+ !eq(op, "mma.sp") : [hasSM<80>, hasPTX<71>],
+
// fp16 -> fp16/fp32 @ m16n16k16
!and(!eq(geom, "m16n16k16"),
!or(!eq(ptx_elt_type, "f16"),
@@ -4777,7 +4794,8 @@ class BuildPatternI<Intrinsic Intr, dag Ins> {
// Build a dag pattern that matches the intrinsic call.
dag ret = !foreach(tmp, Ins,
!subst(ADDR, addr,
- !subst(ins, Intr, tmp)));
+ !subst(ins, Intr,
+ !subst(i32imm, timm, tmp))));
}
// Same as above, but uses PatFrag instead of an Intrinsic.
@@ -5011,6 +5029,62 @@ defset list<WMMA_INSTR> MMAs = {
} // defset
}
+// MMA SP
+class MMA_SP<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string Metadata, string Kind, int Satfinite>
+ : WMMA_INSTR<MMA_SP_NAME<Metadata, Kind, Satfinite,
+ FragA, FragB, FragC, FragD>.record,
+ [FragA.Ins, FragB.Ins, FragC.Ins,
+ (ins B32:$metadata, i32imm:$selector)]>,
+ // Requires does not seem to have an effect on an Instruction w/o Patterns.
+ // We set it here anyway and propagate it to the Pat<> we construct below.
+ Requires<!listconcat(FragA.Predicates,
+ FragB.Predicates,
+ FragC.Predicates,
+ FragD.Predicates)> {
+ let OutOperandList = FragD.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ string TypeList = "." # FragD.ptx_elt_type
+ # "." # FragA.ptx_elt_type
+ # "." # FragB.ptx_elt_type
+ # "." # FragC.ptx_elt_type;
+ let AsmString = "mma"
+ # "." # Metadata
+ # ".sync.aligned."
+ # FragA.geom
+ # ".row.col"
+ # !if(!ne(Kind, ""), "." # Kind, "")
+ # !if(Satfinite, ".satfinite", "")
+ # TypeList # "\n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ",\n\t\t"
+ # "$metadata" # ",\n\t\t"
+ # "$selector" # ";";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_SPs = {
+ foreach metadata = ["sp", "sp::ordered_metadata"] in {
+ foreach kind = ["", "kind::f8f6f4"] in {
+ foreach satf = [0, 1] in {
+ foreach op = NVVM_MMA_OPS.all_mma_sp_ops in {
+ if NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret then {
+ def : MMA_SP<WMMA_REGINFO<op[0], "mma.sp", metadata, kind>,
+ WMMA_REGINFO<op[1], "mma.sp", metadata, kind>,
+ WMMA_REGINFO<op[2], "mma.sp", metadata, kind>,
+ WMMA_REGINFO<op[3], "mma.sp", metadata, kind>,
+ metadata, kind, satf>;
+ }
+ } // op
+ } // satf
+ } // kind
+ } // metadata
+} // defset
+}
+
//
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
//
@@ -5092,7 +5166,7 @@ class MMA_PAT<WMMA_INSTR wi>
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs, MMA_SPs) in
def : MMA_PAT<mma>;
multiclass MAPA<string suffix, Intrinsic Intr> {
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
new file mode 100644
index 0000000000000..ae781df0116fd
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
@@ -0,0 +1,12 @@
+# Check all variants of instructions supported by PTX87 on SM120a
+# RUN: %python %s --ptx=87 --gpu-arch=120 --aa > %t-ptx87-sm_120a.ll
+# RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
+# RUN: | FileCheck %t-ptx87-sm_120a.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
+# RUN: | %ptxas-verify -arch=sm_120a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 2eb3c3dbb4c39..283c94714282b 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -15,6 +15,11 @@ class MMAType:
def __init__(self, ptx_type):
self.ptx_type = ptx_type
self.llvm_type = {
+ "e4m3" : "i32",
+ "e5m2" : "i32",
+ "e3m2" : "i32",
+ "e2m3" : "i32",
+ "e2m1" : "i32",
"f16": "<2 x half>",
"f32": "float",
"f64": "double",
@@ -43,7 +48,7 @@ def __repr__(self):
class MMAFrag:
- def __init__(self, geom, frag, ptx_elt_type):
+ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse = False):
self.geom = geom
self.frag = frag
self.mma_type = MMAType(ptx_elt_type)
@@ -79,12 +84,53 @@ def __init__(self, geom, frag, ptx_elt_type):
"m16n8k16:b:s8": 1,
"m16n8k16:c:s32": 4,
"m16n8k16:d:s32": 4,
- "m16n8k32:a:u8": 4,
- "m16n8k32:a:s8": 4,
+ "m16n8k32:a:u8": 2 if is_mma_sparse else 4,
+ "m16n8k32:a:s8": 2 if is_mma_sparse else 4,
"m16n8k32:b:u8": 2,
"m16n8k32:b:s8": 2,
"m16n8k32:c:s32": 4,
"m16n8k32:d:s32": 4,
+ # mma sp
+ "m16n8k32:a:bf16": 4,
+ "m16n8k32:a:f16": 4,
+ "m16n8k32:b:bf16": 4,
+ "m16n8k32:b:f16": 4,
+ "m16n8k32:c:f16": 2,
+ "m16n8k32:c:f32": 4 if is_mma_sparse else 8,
+ "m16n8k32:d:f16": 2,
+ "m16n8k32:d:f32": 4 if is_mma_sparse else 8,
+ "m16n8k16:a:tf32": 4,
+ "m16n8k16:b:tf32": 4,
+ "m16n8k16:c:tf32": 4,
+ "m16n8k16:d:tf32": 4,
+ "m16n8k64:a:u8": 4,
+ "m16n8k64:a:s8": 4,
+ "m16n8k64:a:e4m3": 4,
+ "m16n8k64:a:e5m2": 4,
+ "m16n8k64:a:e3m2": 4,
+ "m16n8k64:a:e2m3": 4,
+ "m16n8k64:a:e2m1": 4,
+ "m16n8k64:b:u8": 4,
+ "m16n8k64:b:s8": 4,
+ "m16n8k64:b:e4m3": 4,
+ "m16n8k64:b:e5m2": 4,
+ "m16n8k64:b:e3m2": 4,
+ "m16n8k64:b:e2m3": 4,
+ "m16n8k64:b:e2m1": 4,
+ "m16n8k64:c:f16": 2,
+ "m16n8k64:c:f32": 4,
+ "m16n8k64:d:f16": 2,
+ "m16n8k64:d:f32": 4,
+ "m16n8k128:a:u4": 4,
+ "m16n8k128:a:s4": 4,
+ "m16n8k128:a:e2m1": 4,
+ "m16n8k128:b:u4": 4,
+ "m16n8k128:b:s4": 4,
+ "m16n8k128:b:e2m1": 4,
+ "m16n8k128:c:s32": 4,
+ "m16n8k128:c:f32": 4,
+ "m16n8k128:d:s32": 4,
+ "m16n8k128:d:f32": 4,
# u4/s4 -> s32 @ m8n8k32 (u4/s4)
"m8n8k32:a:u4": 1,
"m8n8k32:a:s4": 1,
@@ -98,8 +144,8 @@ def __init__(self, geom, frag, ptx_elt_type):
"m16n8k32:b:s4": 1,
"m16n8k32:c:s32": 4,
"m16n8k32:d:s32": 4,
- "m16n8k64:a:u4": 4,
- "m16n8k64:a:s4": 4,
+ "m16n8k64:a:u4": 2 if is_mma_sparse else 4,
+ "m16n8k64:a:s4": 2 if is_mma_sparse else 4,
"m16n8k64:b:u4": 2,
"m16n8k64:b:s4": 2,
"m16n8k64:c:s32": 4,
@@ -124,7 +170,7 @@ def __init__(self, geom, frag, ptx_elt_type):
"m8n32k16:b:bf16": 8,
"m32n8k16:a:bf16": 8,
"m32n8k16:b:bf16": 2,
- "m16n8k16:a:bf16": 4,
+ "m16n8k16:a:bf16": 2 if is_mma_sparse else 4,
"m16n8k16:b:bf16": 2,
"m16n8k16:c:f32": 4,
"m16n8k16:d:f32": 4,
@@ -143,7 +189,7 @@ def __init__(self, geom, frag, ptx_elt_type):
"m16n8k4:b:tf32": 1,
"m16n8k4:c:f32": 4,
"m16n8k4:d:f32": 4,
- "m16n8k8:a:tf32": 4,
+ "m16n8k8:a:tf32": 2 if is_mma_sparse else 4,
"m16n8k8:b:tf32": 2,
"m16n8k8:c:f32": 4,
"m16n8k8:d:f32": 4,
@@ -155,7 +201,7 @@ def __init__(self, geom, frag, ptx_elt_type):
"m16n8k8:d:f16": 2,
"m16n8k8:c:f32": 4,
"m16n8k8:d:f32": 4,
- "m16n8k16:a:f16": 4,
+ "m16n8k16:a:f16": 2 if is_mma_sparse else 4,
"m16n8k16:b:f16": 2,
"m16n8k16:c:f16": 2,
"m16n8k16:d:f16": 2,
@@ -218,7 +264,7 @@ def __repr__(self):
return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d)
-def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
+def make_mma_ops(geoms, types_a, types_b, types_c, types_d, is_mma_sparse = False):
ops = []
for geom, type_a, type_c in product(geoms, types_a, types_c):
for type_b, type_d in product(
@@ -226,10 +272,10 @@ def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
):
ops.append(
MMAOp(
- MMAFrag(geom, "a", type_a),
- MMAFrag(geom, "b", type_b),
- MMAFrag(geom, "c", type_c),
- MMAFrag(geom, "d", type_d),
+ MMAFrag(geom, "a", type_a, is_mma_sparse),
+ MMAFrag(geom, "b", type_b, is_mma_sparse),
+ MMAFrag(geom, "c", type_c, is_mma_sparse),
+ MMAFrag(geom, "d", type_d, is_mma_sparse),
)
)
return ops
@@ -416,6 +462,10 @@ def is_type_supported(ptx_type):
return ptx_version >= 65 and gpu_arch >= 75
if ptx_type in ["bf16", "tf32", "f64"]:
return ptx_version >= 70
+ if ptx_type in ["e4m3", "e5m2"]:
+ return ptx_version >= 84 and gpu_arch >= 89
+ if ptx_type in ["e3m2", "e2m3", "e2m1"]:
+ return ptx_version >= 87 and gpu_arch >= 120 and aa
return ptx_version >= 60 and gpu_arch >= 70
@@ -448,7 +498,7 @@ def is_mma_variant_supported(op, layout_a, layout_b, satf):
):
return False
- if satf and not op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]:
+ if satf and op.a.mma_type.ptx_type not in ["s8", "u8", "s4", "u4"]:
return False
# If the type of C is f32 then so must the type of D
@@ -825,7 +875,11 @@ def gen_stmatrix_tests():
return generated_items
def mma_signature(op):
- if op.a.mma_type.ptx_type == "f16":
+ if op.a.mma_type.ptx_type in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]:
+ # FP8/F8F6F4 ops identified by inputs, accumulator & result types.
+ return "%s.%s.%s.%s" % (op.d.mma_type.ptx_type, op.a.mma_type.ptx_type, \
+ op.b.mma_type.ptx_type, op.c.mma_type.ptx_type)
+ elif op.a.mma_type.ptx_type == "f16":
# FP16 ops identified by accumulator & result type.
return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
@@ -980,6 +1034,188 @@ def gen_mma_tests():
return generated_items
+def get_mma_sp_ops():
+ return (
+ make_mma_ops(["m16n8k16", "m16n8k32"], ["bf16"], [], ["f32"], [], True)
+ + make_mma_ops(["m16n8k8", "m16n8k16"], ["tf32"], [], ["f32"], [], True)
+ + make_mma_ops(["m16n8k16", "m16n8k32"], ["f16"], [], ["f16", "f32"], ["f16", "f32"], True)
+ + make_mma_ops(["m16n8k64", "m16n8k128"], ["s4", "u4"], ["s4", "u4"], ["s32"], [], True)
+ + make_mma_ops(["m16n8k32", "m16n8k64"], ["s8", "u8"], ["s8", "u8"], ["s32"], [], True)
+ + make_mma_ops(["m16n8k64"], ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], ["f16", "f32"], ["f16", "f32"], True)
+ )
+
+
+def is_mma_sp_geom_supported(geom):
+ # geometries for FP and ints.
+ if geom in [
+ "m16n8k16",
+ "m16n8k32",
+ "m16n8k8",
+ "m16n8k64",
+ "m16n8k128",
+ ]:
+ return ptx_version >= 71
+ raise ValueError(f"Unexpected sparse MMA geometry: {geom}")
+
+
+def is_mma_sp_variant_supported(op, metadata, kind, satf):
+ if metadata != "sp" and (ptx_version < 85 or gpu_arch < 80):
+ return False
+
+ if kind != "" and (ptx_version < 87 or gpu_arch < 120 or not aa):
+ return False
+
+ if not (
+ is_type_supported(op.a.mma_type.ptx_type) and is_mma_sp_geom_supported(op.a.geom)
+ ):
+ return False
+
+ is_int = op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]
+
+ if satf and not is_int:
+ return False
+
+ # A and B type must be the same
+ if (
+ op.a.mma_type.ptx_type in ["f16", "bf16", "tf32"]
+ and op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
+ ):
+ return False
+
+ # C and D type must be the same for m16n8k16/m16n8k32
+ if (
+ op.a.geom in ["m16n8k16", "m16n8k32"]
+ and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
+ ):
+ return False
+
+ if kind == "" and (op.a.mma_type.ptx_type in ["e3m2", "e2m3", "e2m1"] or \
+ op.b.mma_type.ptx_type in ["e3m2", "e2m3", "e2m1"]):
+ return False
+
+ if kind == "" and op.a.geom == "m16n8k64" and \
+ (op.c.mma_type.ptx_type == "f16" or op.d.mma_type.ptx_type == "f16"):
+ return False
+
+ if kind != "" and (metadata == "sp" or op.a.geom != "m16n8k64" or is_int):
+ return False
+
+ return True
+
+
+def sp_selector_gen(op):
+ # (geom, type) -> allowed selector range
+ range_01 = {
+ ("m16n8k32", "bf16"),
+ ("m16n8k16", "tf32"),
+ ("m16n8k32", "u8"),
+ ("m16n8k32", "s8"),
+ ("m16n8k64", "u4"),
+ ("m16n8k64", "s4"),
+ }
+
+ if (op.a.geom, op.a.mma_type.ptx_type) in range_01:
+ return range(2)
+ if (
+ op.a.geom == "m16n8k64"
+ and op.a.mma_type.ptx_type in ["u8", "s8", "e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
+ ):
+ return range(1)
+ return range(4)
+
+
+def common_mma_sp_test_gen(params, op, intrinsic_template, instruction_template):
+ mma_sp_decl_template = """
+declare ${ret_ty} @${intrinsic}(
+ ${args});
+"""
+
+ mma_sp_test_template = """
+; CHECK-LABEL: .func {{.*}}test_${function}_${selector}(
+define ${ret_ty} @test_${function}_${selector}(
+ ${args}) {
+; CHECK: ${instruction}
+; CHECK-NEXT: ${check_d}
+; CHECK-NEXT: ${check_a}
+; CHECK-NEXT: ${check_b}
+; CHECK-NEXT: ${check_c}
+; CHECK-NEXT: ${check_metadata}
+; CHECK-NEXT: ${check_selector}
+ %r = call ${ret_ty} @${intrinsic}(
+ ${call_args});
+ ret ${ret_ty} %r;
+}
+"""
+
+ test_params = params
+ test_params["intrinsic"] = Template(intrinsic_template).substitute(params).replace("::", ".").replace("_", ".")
+ test_params["function"] = test_params["intrinsic"].replace(".", "_")
+ test_params["instruction"] = Template(instruction_template).substitute(params)
+ test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
+ test_params["check_a"] = check_pattern(op.a)
+ test_params["check_b"] = check_pattern(op.b)
+ test_params["check_c"] = check_pattern(op.c)
+ test_params["check_d"] = check_pattern(op.d)
+ test_params["check_metadata"] = "{{%r[0-9]+}}"
+ args = ",\n ".join(
+ list(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c))
+ + ["i32 %metadata", "i32 %selector"]
+ )
+ test_params["args"] = args
+
+ print(Template(mma_sp_decl_template).substitute(test_params))
+
+ for selector in [str(r) for r in sp_selector_gen(op)]:
+ test_params["selector"] = selector
+ test_params["check_selector"] = "{{" + test_params["selector"] + "}}"
+ test_params["call_args"] = test_params["args"].replace("%selector", test_params["selector"])
+
+ print(Template(mma_sp_test_template).substitute(test_params))
+
+ return (test_params["intrinsic"], test_params["instruction"])
+
+
+def gen_mma_sp_tests():
+ if ptx_version < 71 or gpu_arch < 80:
+ return []
+
+ mma_sp_intrinsic_template = "llvm.nvvm.mma.${metadata}.${geom}.row.col${kind}${satf}.${intrinsic_signature}"
+ mma_sp_instruction_template = "mma.${metadata}.sync.aligned.${geom}.row.col${kind}${satf}.${ptx_signature}"
+
+ generated_items = []
+
+ for op, metadata, kind, satf in product(
+ get_mma_sp_ops(),
+ ["sp::ordered_metadata", "sp"],
+ ["", ".kind::f8f6f4"],
+ [".satfinite", ""]
+ ):
+
+ if not is_mma_sp_variant_supported(op, metadata, kind, satf):
+ continue
+
+ params = {
+ "intrinsic_signature": mma_signature(op),
+ "ptx_signature": mma_ptx_signature(op),
+ "satf": satf,
+ "geom": op.a.geom,
+ "metadata": metadata,
+ "kind": kind,
+ }
+
+ intrinsic_template = mma_sp_intrinsic_template
+ instruction_template = mma_sp_instruction_template
+
+ generated_items.append(
+ common_mma_sp_test_gen(
+ params, op, intrinsic_template, instruction_template
+ )
+ )
+
+ return generated_items
+
+
# Append complete list of intrinsics and instructions we've generated tests for.
# Generate set of checks to verify that that we did generate sensible set of
# tests for the given combination of PTX and SM variants.
@@ -1170,6 +1406,7 @@ def gen_tests():
items += gen_stmatrix_tests()
items += gen_wmma_mma_tests()
items += gen_mma_tests()
+ items += gen_mma_sp_tests()
gen_check_unsupported_ops(items)