[llvm] [NVPTX] Add sparse MMA intrinsics (PR #150950)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 28 06:25:57 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-nvptx
Author: Kirill Vedernikov (kvederni)
Changes:
This change adds intrinsics for sparse MMA (mma.sp) operations. The implementation is based on PTX ISA version 8.8.
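For illustration, a call to one of the new intrinsics could look roughly like the sketch below. This is a sketch only: the mangled name, fragment shapes, and operand order for this f16 m16n8k16 variant are inferred from the MMA_SP_NAME and WMMA_REGS definitions in the patch, not copied from an existing test.

```llvm
; Hypothetical use of a sparse MMA intrinsic (m16n8k16, f16 accumulate,
; "sp" metadata ordering). Name and signature are inferred from the patch.
declare { <2 x half>, <2 x half> }
  @llvm.nvvm.mma.sp.m16n8k16.row.col.f16.f16(
    <2 x half>, <2 x half>,   ; A fragment (halved vs. dense mma)
    <2 x half>, <2 x half>,   ; B fragment
    <2 x half>, <2 x half>,   ; C accumulator
    i32,                      ; sparsity metadata
    i32)                      ; sparsity selector (immediate; Range<> limits it to [0, 4))

define { <2 x half>, <2 x half> } @sparse_mma_f16(
    <2 x half> %a0, <2 x half> %a1,
    <2 x half> %b0, <2 x half> %b1,
    <2 x half> %c0, <2 x half> %c1,
    i32 %meta) {
  %d = call { <2 x half>, <2 x half> }
      @llvm.nvvm.mma.sp.m16n8k16.row.col.f16.f16(
        <2 x half> %a0, <2 x half> %a1,
        <2 x half> %b0, <2 x half> %b1,
        <2 x half> %c0, <2 x half> %c1,
        i32 %meta, i32 0)
  ret { <2 x half>, <2 x half> } %d
}
```

Per the AsmString added in NVPTXIntrinsics.td, a call like this should lower to an mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 instruction taking the D, A, B, C fragments plus the metadata register and the selector immediate.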
---
Patch is 31.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150950.diff
4 Files Affected:
- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+183-6)
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+78-4)
- (added) llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py (+12)
- (modified) llvm/test/CodeGen/NVPTX/wmma.py (+252-15)
``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 967d1663f237b..c4f3e1b394c8e 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -170,7 +170,7 @@ class StrJoin<string sep, list<string> str_list> {
// Geom: m<M>n<N>k<K>. E.g. m8n32k16
// Frag: [a|b|c|d] ([x1|x2|x4] for ldmatrix)
// PtxEltType: PTX type for the element.
-class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
+class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = false> {
string geom = Geom;
string frag = Frag;
string ptx_elt_type = PtxEltType;
@@ -178,6 +178,54 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
string gf = Geom#":"#Frag;
string ft = frag#":"#ptx_elt_type;
list<LLVMType> regs = !cond(
+ // mma sparse ops use different fragments for some arguments
+ !and(!eq(gft, "m16n8k16:a:bf16"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k16:a:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k32:a:bf16"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k32:a:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 4),
+ !and(!eq(gft, "m16n8k32:b:bf16"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k32:b:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 4),
+ !and(!eq(gft, "m16n8k32:c:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k32:c:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k32:d:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k32:d:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k16:a:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k16:b:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k16:c:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k16:d:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k8:a:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k32:a:u8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k32:a:s8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k64:a:u8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:s8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e4m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e5m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e3m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e2m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:u8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:s8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e4m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e5m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e3m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e2m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:c:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k64:c:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k64:d:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k64:d:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:u4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k64:a:s4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k128:a:u4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:a:s4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:a:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:b:u4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:b:s4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:b:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:c:s32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:c:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k128:d:s32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:d:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
// mma fp ops use smaller fragments than wmma fp ops
!eq(gft,"m8n8k4:a:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m8n8k4:b:f16") : !listsplat(llvm_v2f16_ty, 2),
@@ -362,6 +410,12 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
list<WMMA_REGS> id_frags = !cond(
+ // FP8/F8F6F4 ops are identified by A, B inputs & accumulator & result type.
+ !or(!eq(A.ptx_elt_type, "e4m3"),
+ !eq(A.ptx_elt_type, "e5m2"),
+ !eq(A.ptx_elt_type, "e3m2"),
+ !eq(A.ptx_elt_type, "e2m3"),
+ !eq(A.ptx_elt_type, "e2m1")): [D, A, B, C],
// FP16 ops are identified by accumulator & result type.
!eq(A.ptx_elt_type, "f16") : [D, C],
// other ops are identified by input types.
@@ -397,6 +451,19 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
# signature;
}
+class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
+ WMMA_REGS A, WMMA_REGS B,
+ WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string record = "int_nvvm_mma"
+ # "_" # !subst("::", "_", Metadata)
+ # "_" # A.geom
+ # "_row_col"
+ # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
+ # !if(Satfinite, "_satfinite", "")
+ # signature;
+}
+
class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
string intr = "llvm.nvvm.ldmatrix.sync.aligned"
# "." # Frag.geom
@@ -424,21 +491,22 @@ class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
// TypeN: PTX type of the corresponding fragment's element.
// TypeB and TypeD may be empty if it must match that of TypeA or TypeC.
class MMA_OPS<list<string> Geom, list<string> TypeA, list<string> TypeB,
- list<string> TypeC, list<string> TypeD> {
+ list<string> TypeC, list<string> TypeD, bit IsSparse = false> {
list<list<WMMA_REGS>> ret =
!foldl([]<list<WMMA_REGS>>, Geom, t1, geom, !listconcat(t1,
!foldl([]<list<WMMA_REGS>>, TypeA, t2, type_a, !listconcat(t2,
!foldl([]<list<WMMA_REGS>>, !if(!size(TypeB), TypeB, [type_a]), t3, type_b, !listconcat(t3,
!foldl([]<list<WMMA_REGS>>, TypeC, t4, type_c, !listconcat(t4,
!foldl([]<list<WMMA_REGS>>, !if(!size(TypeD), TypeD, [type_c]), t5, type_d, !listconcat(t5,
- [[WMMA_REGS<geom, "a", type_a>,
- WMMA_REGS<geom, "b", type_b>,
- WMMA_REGS<geom, "c", type_c>,
- WMMA_REGS<geom, "d", type_d>]]))))))))));
+ [[WMMA_REGS<geom, "a", type_a, IsSparse>,
+ WMMA_REGS<geom, "b", type_b, IsSparse>,
+ WMMA_REGS<geom, "c", type_c, IsSparse>,
+ WMMA_REGS<geom, "d", type_d, IsSparse>]]))))))))));
// Debugging aid for readable representation of the list above.
list<list<string>> ops = !foreach(x, ret, [x[0].gft, x[1].gft, x[2].gft, x[3].gft]);
}
+
class MMA_LDST_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
list<WMMA_REGS> ret =
!foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
@@ -522,6 +590,30 @@ class NVVM_MMA_OPS {
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
+ list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
+ ["m16n8k16", "m16n8k32"],
+ ["bf16"], [], ["f32"], [], true>.ret;
+ list<list<WMMA_REGS>> tf32_mma_sp_ops = MMA_OPS<
+ ["m16n8k8", "m16n8k16"],
+ ["tf32"], [], ["f32"], [], true>.ret;
+ list<list<WMMA_REGS>> fp_mma_sp_ops = MMA_OPS<
+ ["m16n8k16", "m16n8k32"],
+ ["f16"], [], ["f16", "f32"], ["f16", "f32"], true>.ret;
+ list<list<WMMA_REGS>> fp8_mma_sp_ops = MMA_OPS<
+ ["m16n8k64"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f16", "f32"], ["f16", "f32"], true>.ret;
+ list<list<WMMA_REGS>> subint_mma_sp_ops = MMA_OPS<
+ ["m16n8k64", "m16n8k128"],
+ ["s4", "u4"], ["s4", "u4"], ["s32"], [], true>.ret;
+ list<list<WMMA_REGS>> int_mma_sp_ops = MMA_OPS<
+ ["m16n8k32", "m16n8k64"],
+ ["s8", "u8"], ["s8", "u8"], ["s32"], [], true>.ret;
+ list<list<WMMA_REGS>> all_mma_sp_ops = !listconcat(
+ bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
+ subint_mma_sp_ops, int_mma_sp_ops);
+
list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
@@ -728,6 +820,68 @@ class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
);
}
+
+// Returns true if this combination of metadata/kind/satf for MMA.SP ops is supported;
+// false otherwise.
+// E.g.
+// if NVVM_MMA_SP_SUPPORTED<...>.ret then
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
+ string kind, int satf> {
+ // MMA.SP ops check both layouts.
+ string a_type = frags[0].ptx_elt_type;
+ string b_type = frags[1].ptx_elt_type;
+ string c_type = frags[2].ptx_elt_type;
+ string d_type = frags[3].ptx_elt_type;
+ string geom = frags[0].geom;
+
+ bit is_int = !or(!eq(a_type, "s8"),
+ !eq(a_type, "u8"),
+ !eq(a_type, "s4"),
+ !eq(a_type, "u4"));
+
+ bit ret = !cond(
+
+ // Limit satf to valid types
+ !and(!eq(satf, 1),
+ !eq(is_int, 0)): false,
+
+ // f16/bf16/tf32 require A and B to be the same type.
+ !and(!or(!eq(a_type, "f16"),
+ !eq(a_type, "bf16"),
+ !eq(a_type, "tf32")),
+ !ne(a_type, b_type)): false,
+
+ // m16n8k16 and m16n8k32 require C and D to be the same type.
+ !and(!or(!eq(geom, "m16n8k16"),
+ !eq(geom, "m16n8k32")),
+ !ne(c_type, d_type)): false,
+
+ !and(!eq(kind, ""),
+ !or(!eq(a_type, "e3m2"),
+ !eq(a_type, "e2m3"),
+ !eq(a_type, "e2m1"),
+ !eq(b_type, "e3m2"),
+ !eq(b_type, "e2m3"),
+ !eq(b_type, "e2m1"))): false,
+
+ !and(!eq(kind, ""),
+ !eq(geom, "m16n8k64"),
+ !or(!eq(c_type, "f16"),
+ !eq(d_type, "f16"))): false,
+
+ !and(!ne(kind, ""),
+ !or(!eq(metadata, "sp"),
+ !ne(geom, "m16n8k64"),
+ !eq(is_int, 1))): false,
+
+ // All others are OK.
+ true: true
+ );
+}
+
+
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
string Suffix = !if(sync, "sync_", "")
# mode # "_"
@@ -2001,6 +2155,29 @@ foreach layout_a = ["row", "col"] in {
} // layout_b
} // layout_a
+// MMA.SP
+class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+ : Intrinsic<D.regs,
+ !listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty], [llvm_i32_ty])> {
+ int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
+ let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
+ Range<ArgIndex<pos>, 0, 4>];
+}
+
+foreach metadata = ["sp", "sp::ordered_metadata"] in {
+ foreach kind = ["", "kind::f8f6f4"] in {
+ foreach satf = [0, 1] in {
+ foreach op = NVVM_MMA_OPS.all_mma_sp_ops in {
+ if NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret then {
+ def MMA_SP_NAME<metadata, kind, satf,
+ op[0], op[1], op[2], op[3]>.record
+ : NVVM_MMA_SP<op[0], op[1], op[2], op[3]>;
+ }
+ } // op
+ } // satf
+ } // kind
+} // metadata
+
// LDMATRIX
class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
: Intrinsic<Frag.regs, [llvm_anyptr_ty],
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 0a00220d94289..a2b29a17537e9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4637,10 +4637,15 @@ def INT_PTX_SREG_WARPSIZE :
// In addition to target-independent fields provided by WMMA_REGS, it adds
// the fields commonly used to implement specific PTX instruction -- register
// types and names, constraints, parts of assembly, etc.
-class WMMA_REGINFO<WMMA_REGS r, string op>
- : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type> {
+class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "">
+ : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type, !eq(op, "mma.sp")> {
// NVPTX register types used to carry fragment data.
NVPTXRegClass regclass = !cond(
+ !eq(ptx_elt_type, "e4m3") : B32,
+ !eq(ptx_elt_type, "e5m2") : B32,
+ !eq(ptx_elt_type, "e3m2") : B32,
+ !eq(ptx_elt_type, "e2m3") : B32,
+ !eq(ptx_elt_type, "e2m1") : B32,
!eq(ptx_elt_type, "f16") : B32,
!eq(ptx_elt_type, "f32") : B32,
!eq(ptx_elt_type, "f64") : B64,
@@ -4673,6 +4678,18 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
// longer the case, we can concat all per-fragment predicates to enforce that
// all fragments of the instruction are viable.
list<Predicate> Predicates = !cond(
+ !or(!eq(ptx_elt_type, "e3m2"),
+ !eq(ptx_elt_type, "e2m3"),
+ !eq(ptx_elt_type, "e2m1"),
+ !ne(kind, "")) : [hasSM120a, hasPTX<87>],
+
+ !or(!eq(ptx_elt_type, "e4m3"),
+ !eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
+
+ !and(!eq(op, "mma.sp"),
+ !ne(metadata, "sp")) : [hasSM<80>, hasPTX<85>],
+ !eq(op, "mma.sp") : [hasSM<80>, hasPTX<71>],
+
// fp16 -> fp16/fp32 @ m16n16k16
!and(!eq(geom, "m16n16k16"),
!or(!eq(ptx_elt_type, "f16"),
@@ -4777,7 +4794,8 @@ class BuildPatternI<Intrinsic Intr, dag Ins> {
// Build a dag pattern that matches the intrinsic call.
dag ret = !foreach(tmp, Ins,
!subst(ADDR, addr,
- !subst(ins, Intr, tmp)));
+ !subst(ins, Intr,
+ !subst(i32imm, timm, tmp))));
}
// Same as above, but uses PatFrag instead of an Intrinsic.
@@ -5011,6 +5029,62 @@ defset list<WMMA_INSTR> MMAs = {
} // defset
}
+// MMA SP
+class MMA_SP<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string Metadata, string Kind, int Satfinite>
+ : WMMA_INSTR<MMA_SP_NAME<Metadata, Kind, Satfinite,
+ FragA, FragB, FragC, FragD>.record,
+ [FragA.Ins, FragB.Ins, FragC.Ins,
+ (ins B32:$metadata, i32imm:$selector)]>,
+ // Requires does not seem to have an effect on an Instruction w/o Patterns.
+ // We set it here anyway and propagate it to the Pat<> we construct below.
+ Requires<!listconcat(FragA.Predicates,
+ FragB.Predicates,
+ FragC.Predicates,
+ FragD.Predicates)> {
+ let OutOperandList = FragD.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ string TypeList = "." # FragD.ptx_elt_type
+ # "." # FragA.ptx_elt_type
+ # "." # FragB.ptx_elt_type
+ # "." # FragC.ptx_elt_type;
+ let AsmString = "mma"
+ # "." # Metadata
+ # ".sync.aligned."
+ # FragA.geom
+ # ".row.col"
+ # !if(!ne(Kind, ""), "." # Kind, "")
+ # !if(Satfinite, ".satfinite", "")
+ # TypeList # "\n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ",\n\t\t"
+ # "$metadata" # ",\n\t\t"
+ # "$selector" # ";";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_SPs = {
+ foreach metadata = ["sp", "sp::ordered_metadata"] in {
+ foreach kind = ["", "kind::f8f6f4"] in {
+ foreach satf = [0, 1] in {
+ foreach op = NVVM_MMA_OPS.all_mma_sp_ops in {
+ if NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret then {
+ def : MMA_SP<WMMA_REGINFO<op[0], "mma.sp", metadata, kind>,
+ WMMA_REGINFO<op[1], "mma.sp", metadata, kind>,
+ WMMA_REGINFO<op[2], "mma.sp", metadata, kind>,
+ WMMA_REGINFO<op[3], "mma.sp", metadata, kind>,
+ metadata, kind, satf>;
+ }
+ } // op
+ } // satf
+ } // kind
+ } // metadata
+} // defset
+}
+
//
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
//
@@ -5092,7 +5166,7 @@ class MMA_PAT<WMMA_INSTR wi>
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs, MMA_SPs) in
def : MMA_PAT<mma>;
multiclass MAPA<string suffix, Intrinsic Intr> {
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
new file mode 100644
index 0000000000000..ae781df0116fd
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
@@ -0,0 +1,12 @@
+# Check all variants of instructions supported by PTX87 on SM120a
+# RUN: %python %s --ptx=87 --gpu-arch=120 --aa > %t-ptx87-sm_120a.ll
+# RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
+# RUN: | FileCheck %t-ptx87-sm_120a.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
+# RUN: | %ptxas-verify -arch=sm_120a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 2eb3c3dbb4c39..283c94714282b 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -15,6 +15,11 @@ class MMAType:
def __init__(self, ptx_type):
self.ptx_type = ptx_type
self.llvm_type = {
+ "e4m3" : "i32",
+ "e5m2" : "i32",
+ "e3m2" : "i32",
+ "e2m3" : "i32",
+ "e2m1" : "i32",
"f16": "<2 x half>",
"f32": "float",
"f64": "double",
@@ -43,7 +48,7 @@ def __repr__(self):
class MMAFrag:
- def __init__(self, geom, frag, ptx_elt_type):
+ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
self.geom = geom
self.frag = frag
self.mma_type = MMAType(ptx_elt_type)
@@ -79,12 +84,53 @@ def __init__(self, geom, frag, ptx_elt_type):
"m16n8k16:b:s8": 1,
"m16n8k16:c:s32": 4,
"m16n8k16:d:s32": 4,
- "m16n8k32:a:u8": 4,
- "m16n8k32:a:s8": 4,
+ "m16n8k32:a:u8": 2 if is_mma_sparse else 4,
+ "m16n8k32:a:s8": 2 if is_mma_sparse else 4,
"m16n8k32:b:u8": 2,
"m16n8k32:b:s8": 2,
"m16n8k32:c:s32": 4,
"m16n8k32:d:s32": 4,
+ # mma sp
+ "m16n8k32:a:bf16"...
[truncated]
``````````
https://github.com/llvm/llvm-project/pull/150950