[llvm] [NVPTX] Add sparse MMA intrinsics (PR #150950)
Kirill Vedernikov via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 30 09:36:13 PDT 2025
https://github.com/kvederni updated https://github.com/llvm/llvm-project/pull/150950
>From 7489d1b1af2a52968bfb5cb891e870a933ef02dc Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Mon, 28 Jul 2025 15:09:58 +0200
Subject: [PATCH 1/5] [NVPTX] Add sparse MMA intrinsics
This change adds intrinsics for sparse MMA (mma.sp) operations. The implementation is based
on PTX ISA version 8.8.
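
For reference, the generated intrinsic names concatenate the metadata ordering, the geometry, the fixed row.col layout, the optional kind/satfinite modifiers, and the type signature. The following minimal Python sketch (mirroring MMA_SP_NAME and the templates in gen_mma_sp_tests() below; the helper name is invented for illustration and is not part of the patch) shows how one such name is assembled:

# Sketch only: mirrors the intrinsic-name template used by the test generator.
from string import Template

def sparse_mma_intrinsic_name(metadata, geom, kind, satf, signature):
    # metadata:  "sp" or "sp::ordered_metadata"
    # kind:      "" or ".kind::f8f6f4"
    # satf:      "" or ".satfinite"
    # signature: e.g. "f16.f32" (D and C types for f16 ops)
    tmpl = "llvm.nvvm.mma.${metadata}.${geom}.row.col${kind}${satf}.${signature}"
    name = Template(tmpl).substitute(metadata=metadata, geom=geom, kind=kind,
                                     satf=satf, signature=signature)
    # The generator lowers "::" and "_" to "." for the intrinsic spelling.
    return name.replace("::", ".").replace("_", ".")

print(sparse_mma_intrinsic_name("sp::ordered_metadata", "m16n8k16", "", "", "f16.f32"))
# -> llvm.nvvm.mma.sp.ordered.metadata.m16n8k16.row.col.f16.f32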
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 189 ++++++++++++-
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 82 +++++-
llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py | 12 +
llvm/test/CodeGen/NVPTX/wmma.py | 267 +++++++++++++++++--
4 files changed, 525 insertions(+), 25 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 967d1663f237b..c4f3e1b394c8e 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -170,7 +170,7 @@ class StrJoin<string sep, list<string> str_list> {
// Geom: m<M>n<N>k<K>. E.g. m8n32k16
// Frag: [a|b|c|d] ([x1|x2|x4] for ldmatrix)
// PtxEltType: PTX type for the element.
-class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
+class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = false> {
string geom = Geom;
string frag = Frag;
string ptx_elt_type = PtxEltType;
@@ -178,6 +178,54 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
string gf = Geom#":"#Frag;
string ft = frag#":"#ptx_elt_type;
list<LLVMType> regs = !cond(
+ // mma sparse ops use other fragments for some arguments
+ !and(!eq(gft, "m16n8k16:a:bf16"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k16:a:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k32:a:bf16"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k32:a:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 4),
+ !and(!eq(gft, "m16n8k32:b:bf16"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k32:b:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 4),
+ !and(!eq(gft, "m16n8k32:c:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k32:c:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k32:d:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k32:d:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k16:a:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k16:b:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k16:c:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k16:d:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k8:a:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k32:a:u8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k32:a:s8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k64:a:u8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:s8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e4m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e5m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e3m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e2m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:u8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:s8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e4m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e5m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e3m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e2m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:b:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k64:c:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k64:c:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k64:d:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
+ !and(!eq(gft, "m16n8k64:d:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k64:a:u4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k64:a:s4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
+ !and(!eq(gft, "m16n8k128:a:u4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:a:s4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:a:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:b:u4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:b:s4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:b:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:c:s32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:c:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
+ !and(!eq(gft, "m16n8k128:d:s32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
+ !and(!eq(gft, "m16n8k128:d:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
// mma fp ops use smaller fragments than wmma fp ops
!eq(gft,"m8n8k4:a:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m8n8k4:b:f16") : !listsplat(llvm_v2f16_ty, 2),
@@ -362,6 +410,12 @@ class WMMA_NAME_LDST<string Op, WMMA_REGS Frag, string Layout, int WithStride> {
class MMA_SIGNATURE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
list<WMMA_REGS> id_frags = !cond(
+ // FP8/F8F6F4 ops are identified by A, B inputs & accumulator & result type.
+ !or(!eq(A.ptx_elt_type, "e4m3"),
+ !eq(A.ptx_elt_type, "e5m2"),
+ !eq(A.ptx_elt_type, "e3m2"),
+ !eq(A.ptx_elt_type, "e2m3"),
+ !eq(A.ptx_elt_type, "e2m1")): [D, A, B, C],
// FP16 ops are identified by accumulator & result type.
!eq(A.ptx_elt_type, "f16") : [D, C],
// other ops are identified by input types.
@@ -397,6 +451,19 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
# signature;
}
+class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
+ WMMA_REGS A, WMMA_REGS B,
+ WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string record = "int_nvvm_mma"
+ # "_" # !subst("::", "_", Metadata)
+ # "_" # A.geom
+ # "_row_col"
+ # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
+ # !if(Satfinite, "_satfinite", "")
+ # signature;
+}
+
class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
string intr = "llvm.nvvm.ldmatrix.sync.aligned"
# "." # Frag.geom
@@ -424,21 +491,22 @@ class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
// TypeN: PTX type of the corresponding fragment's element.
// TypeB and TypeD may be empty if it must match that of TypeA or TypeC.
class MMA_OPS<list<string> Geom, list<string> TypeA, list<string> TypeB,
- list<string> TypeC, list<string> TypeD> {
+ list<string> TypeC, list<string> TypeD, bit IsSparse = false> {
list<list<WMMA_REGS>> ret =
!foldl([]<list<WMMA_REGS>>, Geom, t1, geom, !listconcat(t1,
!foldl([]<list<WMMA_REGS>>, TypeA, t2, type_a, !listconcat(t2,
!foldl([]<list<WMMA_REGS>>, !if(!size(TypeB), TypeB, [type_a]), t3, type_b, !listconcat(t3,
!foldl([]<list<WMMA_REGS>>, TypeC, t4, type_c, !listconcat(t4,
!foldl([]<list<WMMA_REGS>>, !if(!size(TypeD), TypeD, [type_c]), t5, type_d, !listconcat(t5,
- [[WMMA_REGS<geom, "a", type_a>,
- WMMA_REGS<geom, "b", type_b>,
- WMMA_REGS<geom, "c", type_c>,
- WMMA_REGS<geom, "d", type_d>]]))))))))));
+ [[WMMA_REGS<geom, "a", type_a, IsSparse>,
+ WMMA_REGS<geom, "b", type_b, IsSparse>,
+ WMMA_REGS<geom, "c", type_c, IsSparse>,
+ WMMA_REGS<geom, "d", type_d, IsSparse>]]))))))))));
// Debugging aid for readable representation of the list above.
list<list<string>> ops = !foreach(x, ret, [x[0].gft, x[1].gft, x[2].gft, x[3].gft]);
}
+
class MMA_LDST_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
list<WMMA_REGS> ret =
!foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
@@ -522,6 +590,30 @@ class NVVM_MMA_OPS {
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
+ list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
+ ["m16n8k16", "m16n8k32"],
+ ["bf16"], [], ["f32"], [], true>.ret;
+ list<list<WMMA_REGS>> tf32_mma_sp_ops = MMA_OPS<
+ ["m16n8k8", "m16n8k16"],
+ ["tf32"], [], ["f32"], [], true>.ret;
+ list<list<WMMA_REGS>> fp_mma_sp_ops = MMA_OPS<
+ ["m16n8k16", "m16n8k32"],
+ ["f16"], [], ["f16", "f32"], ["f16", "f32"], true>.ret;
+ list<list<WMMA_REGS>> fp8_mma_sp_ops = MMA_OPS<
+ ["m16n8k64"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f16", "f32"], ["f16", "f32"], true>.ret;
+ list<list<WMMA_REGS>> subint_mma_sp_ops = MMA_OPS<
+ ["m16n8k64", "m16n8k128"],
+ ["s4", "u4"], ["s4", "u4"], ["s32"], [], true>.ret;
+ list<list<WMMA_REGS>> int_mma_sp_ops = MMA_OPS<
+ ["m16n8k32", "m16n8k64"],
+ ["s8", "u8"], ["s8", "u8"], ["s32"], [], true>.ret;
+ list<list<WMMA_REGS>> all_mma_sp_ops = !listconcat(
+ bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
+ subint_mma_sp_ops, int_mma_sp_ops);
+
list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
@@ -728,6 +820,68 @@ class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
);
}
+
+// Returns true if this combination of metadata/kind/satf for MMA.SP ops is supported;
+// false otherwise.
+// E.g.
+// if NVVM_MMA_SP_SUPPORTED<...>.ret then
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
+ string kind, int satf> {
+ // MMA.SP ops use a fixed row.col layout; only metadata, kind, and satf vary.
+ string a_type = frags[0].ptx_elt_type;
+ string b_type = frags[1].ptx_elt_type;
+ string c_type = frags[2].ptx_elt_type;
+ string d_type = frags[3].ptx_elt_type;
+ string geom = frags[0].geom;
+
+ bit is_int = !or(!eq(a_type, "s8"),
+ !eq(a_type, "u8"),
+ !eq(a_type, "s4"),
+ !eq(a_type, "u4"));
+
+ bit ret = !cond(
+
+ // Limit satf to valid types
+ !and(!eq(satf, 1),
+ !eq(is_int, 0)): false,
+
+ // f16/bf16/tf32 requires A and B to be the same type.
+ !and(!or(!eq(a_type, "f16"),
+ !eq(a_type, "bf16"),
+ !eq(a_type, "tf32")),
+ !ne(a_type, b_type)): false,
+
+ // m16n8k16 and m16n8k32 require C and D to be the same type.
+ !and(!or(!eq(geom, "m16n8k16"),
+ !eq(geom, "m16n8k32")),
+ !ne(c_type, d_type)): false,
+
+ !and(!eq(kind, ""),
+ !or(!eq(a_type, "e3m2"),
+ !eq(a_type, "e2m3"),
+ !eq(a_type, "e2m1"),
+ !eq(b_type, "e3m2"),
+ !eq(b_type, "e2m3"),
+ !eq(b_type, "e2m1"))): false,
+
+ !and(!eq(kind, ""),
+ !eq(geom, "m16n8k64"),
+ !or(!eq(c_type, "f16"),
+ !eq(d_type, "f16"))): false,
+
+ !and(!ne(kind, ""),
+ !or(!eq(metadata, "sp"),
+ !ne(geom, "m16n8k64"),
+ !eq(is_int, 1))): false,
+
+ // All others are OK.
+ true: true
+ );
+}
+
+
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
string Suffix = !if(sync, "sync_", "")
# mode # "_"
@@ -2001,6 +2155,29 @@ foreach layout_a = ["row", "col"] in {
} // layout_b
} // layout_a
+// MMA.SP
+class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+ : Intrinsic<D.regs,
+ !listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty], [llvm_i32_ty])> {
+ int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
+ let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
+ Range<ArgIndex<pos>, 0, 4>];
+}
+
+foreach metadata = ["sp", "sp::ordered_metadata"] in {
+ foreach kind = ["", "kind::f8f6f4"] in {
+ foreach satf = [0, 1] in {
+ foreach op = NVVM_MMA_OPS.all_mma_sp_ops in {
+ if NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret then {
+ def MMA_SP_NAME<metadata, kind, satf,
+ op[0], op[1], op[2], op[3]>.record
+ : NVVM_MMA_SP<op[0], op[1], op[2], op[3]>;
+ }
+ } // op
+ } // satf
+ } // kind
+} // metadata
+
// LDMATRIX
class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
: Intrinsic<Frag.regs, [llvm_anyptr_ty],
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 0a00220d94289..a2b29a17537e9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4637,10 +4637,15 @@ def INT_PTX_SREG_WARPSIZE :
// In addition to target-independent fields provided by WMMA_REGS, it adds
// the fields commonly used to implement specific PTX instruction -- register
// types and names, constraints, parts of assembly, etc.
-class WMMA_REGINFO<WMMA_REGS r, string op>
- : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type> {
+class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "">
+ : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type, !eq(op, "mma.sp")> {
// NVPTX register types used to carry fragment data.
NVPTXRegClass regclass = !cond(
+ !eq(ptx_elt_type, "e4m3") : B32,
+ !eq(ptx_elt_type, "e5m2") : B32,
+ !eq(ptx_elt_type, "e3m2") : B32,
+ !eq(ptx_elt_type, "e2m3") : B32,
+ !eq(ptx_elt_type, "e2m1") : B32,
!eq(ptx_elt_type, "f16") : B32,
!eq(ptx_elt_type, "f32") : B32,
!eq(ptx_elt_type, "f64") : B64,
@@ -4673,6 +4678,18 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
// longer the case, we can concat all per-fragment predicates to enforce that
// all fragments of the instruction are viable.
list<Predicate> Predicates = !cond(
+ !or(!eq(ptx_elt_type, "e3m2"),
+ !eq(ptx_elt_type, "e2m3"),
+ !eq(ptx_elt_type, "e2m1"),
+ !ne(kind, "")) : [hasSM120a, hasPTX<87>],
+
+ !or(!eq(ptx_elt_type, "e4m3"),
+ !eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
+
+ !and(!eq(op, "mma.sp"),
+ !ne(metadata, "sp")) : [hasSM<80>, hasPTX<85>],
+ !eq(op, "mma.sp") : [hasSM<80>, hasPTX<71>],
+
// fp16 -> fp16/fp32 @ m16n16k16
!and(!eq(geom, "m16n16k16"),
!or(!eq(ptx_elt_type, "f16"),
@@ -4777,7 +4794,8 @@ class BuildPatternI<Intrinsic Intr, dag Ins> {
// Build a dag pattern that matches the intrinsic call.
dag ret = !foreach(tmp, Ins,
!subst(ADDR, addr,
- !subst(ins, Intr, tmp)));
+ !subst(ins, Intr,
+ !subst(i32imm, timm, tmp))));
}
// Same as above, but uses PatFrag instead of an Intrinsic.
@@ -5011,6 +5029,62 @@ defset list<WMMA_INSTR> MMAs = {
} // defset
}
+// MMA SP
+class MMA_SP<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string Metadata, string Kind, int Satfinite>
+ : WMMA_INSTR<MMA_SP_NAME<Metadata, Kind, Satfinite,
+ FragA, FragB, FragC, FragD>.record,
+ [FragA.Ins, FragB.Ins, FragC.Ins,
+ (ins B32:$metadata, i32imm:$selector)]>,
+ // Requires does not seem to have an effect on an Instruction without Patterns.
+ // We set it here anyway and propagate it to the Pat<> we construct below.
+ Requires<!listconcat(FragA.Predicates,
+ FragB.Predicates,
+ FragC.Predicates,
+ FragD.Predicates)> {
+ let OutOperandList = FragD.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ string TypeList = "." # FragD.ptx_elt_type
+ # "." # FragA.ptx_elt_type
+ # "." # FragB.ptx_elt_type
+ # "." # FragC.ptx_elt_type;
+ let AsmString = "mma"
+ # "." # Metadata
+ # ".sync.aligned."
+ # FragA.geom
+ # ".row.col"
+ # !if(!ne(Kind, ""), "." # Kind, "")
+ # !if(Satfinite, ".satfinite", "")
+ # TypeList # "\n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ",\n\t\t"
+ # "$metadata" # ",\n\t\t"
+ # "$selector" # ";";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_SPs = {
+ foreach metadata = ["sp", "sp::ordered_metadata"] in {
+ foreach kind = ["", "kind::f8f6f4"] in {
+ foreach satf = [0, 1] in {
+ foreach op = NVVM_MMA_OPS.all_mma_sp_ops in {
+ if NVVM_MMA_SP_SUPPORTED<op, metadata, kind, satf>.ret then {
+ def : MMA_SP<WMMA_REGINFO<op[0], "mma.sp", metadata, kind>,
+ WMMA_REGINFO<op[1], "mma.sp", metadata, kind>,
+ WMMA_REGINFO<op[2], "mma.sp", metadata, kind>,
+ WMMA_REGINFO<op[3], "mma.sp", metadata, kind>,
+ metadata, kind, satf>;
+ }
+ } // op
+ } // satf
+ } // kind
+ } // metadata
+} // defset
+}
+
//
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
//
@@ -5092,7 +5166,7 @@ class MMA_PAT<WMMA_INSTR wi>
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs) in
+foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs, MMA_SPs) in
def : MMA_PAT<mma>;
multiclass MAPA<string suffix, Intrinsic Intr> {
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
new file mode 100644
index 0000000000000..ae781df0116fd
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
@@ -0,0 +1,12 @@
+# Check all variants of instructions supported by PTX87 on SM120a
+# RUN: %python %s --ptx=87 --gpu-arch=120 --aa > %t-ptx87-sm_120a.ll
+# RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
+# RUN: | FileCheck %t-ptx87-sm_120a.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
+# RUN: | %ptxas-verify -arch=sm_120a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 2eb3c3dbb4c39..283c94714282b 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -15,6 +15,11 @@ class MMAType:
def __init__(self, ptx_type):
self.ptx_type = ptx_type
self.llvm_type = {
+ "e4m3" : "i32",
+ "e5m2" : "i32",
+ "e3m2" : "i32",
+ "e2m3" : "i32",
+ "e2m1" : "i32",
"f16": "<2 x half>",
"f32": "float",
"f64": "double",
@@ -43,7 +48,7 @@ def __repr__(self):
class MMAFrag:
- def __init__(self, geom, frag, ptx_elt_type):
+ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse = False):
self.geom = geom
self.frag = frag
self.mma_type = MMAType(ptx_elt_type)
@@ -79,12 +84,53 @@ def __init__(self, geom, frag, ptx_elt_type):
"m16n8k16:b:s8": 1,
"m16n8k16:c:s32": 4,
"m16n8k16:d:s32": 4,
- "m16n8k32:a:u8": 4,
- "m16n8k32:a:s8": 4,
+ "m16n8k32:a:u8": 2 if is_mma_sparse else 4,
+ "m16n8k32:a:s8": 2 if is_mma_sparse else 4,
"m16n8k32:b:u8": 2,
"m16n8k32:b:s8": 2,
"m16n8k32:c:s32": 4,
"m16n8k32:d:s32": 4,
+ # mma sp
+ "m16n8k32:a:bf16": 4,
+ "m16n8k32:a:f16": 4,
+ "m16n8k32:b:bf16": 4,
+ "m16n8k32:b:f16": 4,
+ "m16n8k32:c:f16": 2,
+ "m16n8k32:c:f32": 4 if is_mma_sparse else 8,
+ "m16n8k32:d:f16": 2,
+ "m16n8k32:d:f32": 4 if is_mma_sparse else 8,
+ "m16n8k16:a:tf32": 4,
+ "m16n8k16:b:tf32": 4,
+ "m16n8k16:c:tf32": 4,
+ "m16n8k16:d:tf32": 4,
+ "m16n8k64:a:u8": 4,
+ "m16n8k64:a:s8": 4,
+ "m16n8k64:a:e4m3": 4,
+ "m16n8k64:a:e5m2": 4,
+ "m16n8k64:a:e3m2": 4,
+ "m16n8k64:a:e2m3": 4,
+ "m16n8k64:a:e2m1": 4,
+ "m16n8k64:b:u8": 4,
+ "m16n8k64:b:s8": 4,
+ "m16n8k64:b:e4m3": 4,
+ "m16n8k64:b:e5m2": 4,
+ "m16n8k64:b:e3m2": 4,
+ "m16n8k64:b:e2m3": 4,
+ "m16n8k64:b:e2m1": 4,
+ "m16n8k64:c:f16": 2,
+ "m16n8k64:c:f32": 4,
+ "m16n8k64:d:f16": 2,
+ "m16n8k64:d:f32": 4,
+ "m16n8k128:a:u4": 4,
+ "m16n8k128:a:s4": 4,
+ "m16n8k128:a:e2m1": 4,
+ "m16n8k128:b:u4": 4,
+ "m16n8k128:b:s4": 4,
+ "m16n8k128:b:e2m1": 4,
+ "m16n8k128:c:s32": 4,
+ "m16n8k128:c:f32": 4,
+ "m16n8k128:d:s32": 4,
+ "m16n8k128:d:f32": 4,
# u4/s4 -> s32 @ m8n8k32 (u4/s4)
"m8n8k32:a:u4": 1,
"m8n8k32:a:s4": 1,
@@ -98,8 +144,8 @@ def __init__(self, geom, frag, ptx_elt_type):
"m16n8k32:b:s4": 1,
"m16n8k32:c:s32": 4,
"m16n8k32:d:s32": 4,
- "m16n8k64:a:u4": 4,
- "m16n8k64:a:s4": 4,
+ "m16n8k64:a:u4": 2 if is_mma_sparse else 4,
+ "m16n8k64:a:s4": 2 if is_mma_sparse else 4,
"m16n8k64:b:u4": 2,
"m16n8k64:b:s4": 2,
"m16n8k64:c:s32": 4,
@@ -124,7 +170,7 @@ def __init__(self, geom, frag, ptx_elt_type):
"m8n32k16:b:bf16": 8,
"m32n8k16:a:bf16": 8,
"m32n8k16:b:bf16": 2,
- "m16n8k16:a:bf16": 4,
+ "m16n8k16:a:bf16": 2 if is_mma_sparse else 4,
"m16n8k16:b:bf16": 2,
"m16n8k16:c:f32": 4,
"m16n8k16:d:f32": 4,
@@ -143,7 +189,7 @@ def __init__(self, geom, frag, ptx_elt_type):
"m16n8k4:b:tf32": 1,
"m16n8k4:c:f32": 4,
"m16n8k4:d:f32": 4,
- "m16n8k8:a:tf32": 4,
+ "m16n8k8:a:tf32": 2 if is_mma_sparse else 4,
"m16n8k8:b:tf32": 2,
"m16n8k8:c:f32": 4,
"m16n8k8:d:f32": 4,
@@ -155,7 +201,7 @@ def __init__(self, geom, frag, ptx_elt_type):
"m16n8k8:d:f16": 2,
"m16n8k8:c:f32": 4,
"m16n8k8:d:f32": 4,
- "m16n8k16:a:f16": 4,
+ "m16n8k16:a:f16": 2 if is_mma_sparse else 4,
"m16n8k16:b:f16": 2,
"m16n8k16:c:f16": 2,
"m16n8k16:d:f16": 2,
@@ -218,7 +264,7 @@ def __repr__(self):
return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d)
-def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
+def make_mma_ops(geoms, types_a, types_b, types_c, types_d, is_mma_sparse = False):
ops = []
for geom, type_a, type_c in product(geoms, types_a, types_c):
for type_b, type_d in product(
@@ -226,10 +272,10 @@ def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
):
ops.append(
MMAOp(
- MMAFrag(geom, "a", type_a),
- MMAFrag(geom, "b", type_b),
- MMAFrag(geom, "c", type_c),
- MMAFrag(geom, "d", type_d),
+ MMAFrag(geom, "a", type_a, is_mma_sparse),
+ MMAFrag(geom, "b", type_b, is_mma_sparse),
+ MMAFrag(geom, "c", type_c, is_mma_sparse),
+ MMAFrag(geom, "d", type_d, is_mma_sparse),
)
)
return ops
@@ -416,6 +462,10 @@ def is_type_supported(ptx_type):
return ptx_version >= 65 and gpu_arch >= 75
if ptx_type in ["bf16", "tf32", "f64"]:
return ptx_version >= 70
+ if ptx_type in ["e4m3", "e5m2"]:
+ return ptx_version >= 84 and gpu_arch >= 89
+ if ptx_type in ["e3m2", "e2m3", "e2m1"]:
+ return ptx_version >= 87 and gpu_arch >= 120 and aa
return ptx_version >= 60 and gpu_arch >= 70
@@ -448,7 +498,7 @@ def is_mma_variant_supported(op, layout_a, layout_b, satf):
):
return False
- if satf and not op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]:
+ if satf and op.a.mma_type.ptx_type not in ["s8", "u8", "s4", "u4"]:
return False
# If the type of C is f32 then so must the type of D
@@ -825,7 +875,11 @@ def gen_stmatrix_tests():
return generated_items
def mma_signature(op):
- if op.a.mma_type.ptx_type == "f16":
+ if op.a.mma_type.ptx_type in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]:
+ # FP8/F8F6F4 ops identified by inputs, accumulator & result types.
+ return "%s.%s.%s.%s" % (op.d.mma_type.ptx_type, op.a.mma_type.ptx_type, \
+ op.b.mma_type.ptx_type, op.c.mma_type.ptx_type)
+ elif op.a.mma_type.ptx_type == "f16":
# FP16 ops identified by accumulator & result type.
return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
elif op.a.mma_type.ptx_type != op.b.mma_type.ptx_type:
@@ -980,6 +1034,188 @@ def gen_mma_tests():
return generated_items
+def get_mma_sp_ops():
+ return (
+ make_mma_ops(["m16n8k16", "m16n8k32"], ["bf16"], [], ["f32"], [], True)
+ + make_mma_ops(["m16n8k8", "m16n8k16"], ["tf32"], [], ["f32"], [], True)
+ + make_mma_ops(["m16n8k16", "m16n8k32"], ["f16"], [], ["f16", "f32"], ["f16", "f32"], True)
+ + make_mma_ops(["m16n8k64", "m16n8k128"], ["s4", "u4"], ["s4", "u4"], ["s32"], [], True)
+ + make_mma_ops(["m16n8k32", "m16n8k64"], ["s8", "u8"], ["s8", "u8"], ["s32"], [], True)
+ + make_mma_ops(["m16n8k64"], ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], ["f16", "f32"], ["f16", "f32"], True)
+ )
+
+
+def is_mma_sp_geom_supported(geom):
+ # geometries for FP and ints.
+ if geom in [
+ "m16n8k16",
+ "m16n8k32",
+ "m16n8k8",
+ "m16n8k64",
+ "m16n8k128",
+ ]:
+ return ptx_version >= 71
+ raise ValueError(f"Unexpected sparse MMA geometry: {geom}")
+
+
+def is_mma_sp_variant_supported(op, metadata, kind, satf):
+ if metadata != "sp" and (ptx_version < 85 or gpu_arch < 80):
+ return False
+
+ if kind != "" and (ptx_version < 87 or gpu_arch < 120 or not aa):
+ return False
+
+ if not (
+ is_type_supported(op.a.mma_type.ptx_type) and is_mma_sp_geom_supported(op.a.geom)
+ ):
+ return False
+
+ is_int = op.a.mma_type.ptx_type in ["s8", "u8", "s4", "u4"]
+
+ if satf and not is_int:
+ return False
+
+ # A and B type must be the same
+ if (
+ op.a.mma_type.ptx_type in ["f16", "bf16", "tf32"]
+ and op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
+ ):
+ return False
+
+ # C and D type must be the same for m16n8k16/m16n8k32
+ if (
+ op.a.geom in ["m16n8k16", "m16n8k32"]
+ and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
+ ):
+ return False
+
+ if kind == "" and (op.a.mma_type.ptx_type in ["e3m2", "e2m3", "e2m1"] or \
+ op.b.mma_type.ptx_type in ["e3m2", "e2m3", "e2m1"]):
+ return False
+
+ if kind == "" and op.a.geom == "m16n8k64" and \
+ (op.c.mma_type.ptx_type == "f16" or op.d.mma_type.ptx_type == "f16"):
+ return False
+
+ if kind != "" and (metadata == "sp" or op.a.geom != "m16n8k64" or is_int):
+ return False
+
+ return True
+
+
+def sp_selector_gen(op):
+ # (geom, type) -> allowed selector range
+ range_01 = {
+ ("m16n8k32", "bf16"),
+ ("m16n8k16", "tf32"),
+ ("m16n8k32", "u8"),
+ ("m16n8k32", "s8"),
+ ("m16n8k64", "u4"),
+ ("m16n8k64", "s4"),
+ }
+
+ if (op.a.geom, op.a.mma_type.ptx_type) in range_01:
+ return range(2)
+ if (
+ op.a.geom == "m16n8k64"
+ and op.a.mma_type.ptx_type in ["u8", "s8", "e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
+ ):
+ return range(1)
+ return range(4)
+
+
+def common_mma_sp_test_gen(params, op, intrinsic_template, instruction_template):
+ mma_sp_decl_template = """
+declare ${ret_ty} @${intrinsic}(
+ ${args});
+"""
+
+ mma_sp_test_template = """
+; CHECK-LABEL: .func {{.*}}test_${function}_${selector}(
+define ${ret_ty} @test_${function}_${selector}(
+ ${args}) {
+; CHECK: ${instruction}
+; CHECK-NEXT: ${check_d}
+; CHECK-NEXT: ${check_a}
+; CHECK-NEXT: ${check_b}
+; CHECK-NEXT: ${check_c}
+; CHECK-NEXT: ${check_metadata}
+; CHECK-NEXT: ${check_selector}
+ %r = call ${ret_ty} @${intrinsic}(
+ ${call_args});
+ ret ${ret_ty} %r;
+}
+"""
+
+ test_params = params
+ test_params["intrinsic"] = Template(intrinsic_template).substitute(params).replace("::", ".").replace("_", ".")
+ test_params["function"] = test_params["intrinsic"].replace(".", "_")
+ test_params["instruction"] = Template(instruction_template).substitute(params)
+ test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
+ test_params["check_a"] = check_pattern(op.a)
+ test_params["check_b"] = check_pattern(op.b)
+ test_params["check_c"] = check_pattern(op.c)
+ test_params["check_d"] = check_pattern(op.d)
+ test_params["check_metadata"] = "{{%r[0-9]+}}"
+ args = ",\n ".join(
+ list(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c))
+ + ["i32 %metadata", "i32 %selector"]
+ )
+ test_params["args"] = args
+
+ print(Template(mma_sp_decl_template).substitute(test_params))
+
+ for selector in [str(r) for r in sp_selector_gen(op)]:
+ test_params["selector"] = selector
+ test_params["check_selector"] = "{{" + test_params["selector"] + "}}"
+ test_params["call_args"] = test_params["args"].replace("%selector", test_params["selector"])
+
+ print(Template(mma_sp_test_template).substitute(test_params))
+
+ return (test_params["intrinsic"], test_params["instruction"])
+
+
+def gen_mma_sp_tests():
+ if ptx_version < 71 or gpu_arch < 80:
+ return []
+
+ mma_sp_intrinsic_template = "llvm.nvvm.mma.${metadata}.${geom}.row.col${kind}${satf}.${intrinsic_signature}"
+ mma_sp_instruction_template = "mma.${metadata}.sync.aligned.${geom}.row.col${kind}${satf}.${ptx_signature}"
+
+ generated_items = []
+
+ for op, metadata, kind, satf in product(
+ get_mma_sp_ops(),
+ ["sp::ordered_metadata", "sp"],
+ ["", ".kind::f8f6f4"],
+ [".satfinite", ""]
+ ):
+
+ if not is_mma_sp_variant_supported(op, metadata, kind, satf):
+ continue
+
+ params = {
+ "intrinsic_signature": mma_signature(op),
+ "ptx_signature": mma_ptx_signature(op),
+ "satf": satf,
+ "geom": op.a.geom,
+ "metadata": metadata,
+ "kind": kind,
+ }
+
+ intrinsic_template = mma_sp_intrinsic_template
+ instruction_template = mma_sp_instruction_template
+
+ generated_items.append(
+ common_mma_sp_test_gen(
+ params, op, intrinsic_template, instruction_template
+ )
+ )
+
+ return generated_items
+
+
# Append complete list of intrinsics and instructions we've generated tests for.
# Generate a set of checks to verify that we generated a sensible set of
# tests for the given combination of PTX and SM variants.
@@ -1170,6 +1406,7 @@ def gen_tests():
items += gen_stmatrix_tests()
items += gen_wmma_mma_tests()
items += gen_mma_tests()
+ items += gen_mma_sp_tests()
gen_check_unsupported_ops(items)
>From 7c705a909d23145334f3303e1f596628ac4a7d94 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Tue, 29 Jul 2025 17:51:35 +0200
Subject: [PATCH 2/5] [NVPTX] Code formatting issues were fixed for PR150950.
---
llvm/test/CodeGen/NVPTX/wmma.py | 102 +++++++++++++++++++++-----------
1 file changed, 69 insertions(+), 33 deletions(-)
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 283c94714282b..3b3f70a31beff 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -15,11 +15,11 @@ class MMAType:
def __init__(self, ptx_type):
self.ptx_type = ptx_type
self.llvm_type = {
- "e4m3" : "i32",
- "e5m2" : "i32",
- "e3m2" : "i32",
- "e2m3" : "i32",
- "e2m1" : "i32",
+ "e4m3": "i32",
+ "e5m2": "i32",
+ "e3m2": "i32",
+ "e2m3": "i32",
+ "e2m1": "i32",
"f16": "<2 x half>",
"f32": "float",
"f64": "double",
@@ -48,7 +48,7 @@ def __repr__(self):
class MMAFrag:
- def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse = False):
+ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
self.geom = geom
self.frag = frag
self.mma_type = MMAType(ptx_elt_type)
@@ -264,7 +264,7 @@ def __repr__(self):
return "{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d)
-def make_mma_ops(geoms, types_a, types_b, types_c, types_d, is_mma_sparse = False):
+def make_mma_ops(geoms, types_a, types_b, types_c, types_d, is_mma_sparse=False):
ops = []
for geom, type_a, type_c in product(geoms, types_a, types_c):
for type_b, type_d in product(
@@ -877,8 +877,12 @@ def gen_stmatrix_tests():
def mma_signature(op):
if op.a.mma_type.ptx_type in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]:
# FP8/F8F6F4 ops identified by inputs, accumulator & result types.
- return "%s.%s.%s.%s" % (op.d.mma_type.ptx_type, op.a.mma_type.ptx_type, \
- op.b.mma_type.ptx_type, op.c.mma_type.ptx_type)
+ return "%s.%s.%s.%s" % (
+ op.d.mma_type.ptx_type,
+ op.a.mma_type.ptx_type,
+ op.b.mma_type.ptx_type,
+ op.c.mma_type.ptx_type,
+ )
elif op.a.mma_type.ptx_type == "f16":
# FP16 ops identified by accumulator & result type.
return "%s.%s" % (op.d.mma_type.ptx_type, op.c.mma_type.ptx_type)
@@ -1038,11 +1042,23 @@ def get_mma_sp_ops():
return (
make_mma_ops(["m16n8k16", "m16n8k32"], ["bf16"], [], ["f32"], [], True)
+ make_mma_ops(["m16n8k8", "m16n8k16"], ["tf32"], [], ["f32"], [], True)
- + make_mma_ops(["m16n8k16", "m16n8k32"], ["f16"], [], ["f16", "f32"], ["f16", "f32"], True)
- + make_mma_ops(["m16n8k64", "m16n8k128"], ["s4", "u4"], ["s4", "u4"], ["s32"], [], True)
- + make_mma_ops(["m16n8k32", "m16n8k64"], ["s8", "u8"], ["s8", "u8"], ["s32"], [], True)
- + make_mma_ops(["m16n8k64"], ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
- ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"], ["f16", "f32"], ["f16", "f32"], True)
+ + make_mma_ops(
+ ["m16n8k16", "m16n8k32"], ["f16"], [], ["f16", "f32"], ["f16", "f32"], True
+ )
+ + make_mma_ops(
+ ["m16n8k64", "m16n8k128"], ["s4", "u4"], ["s4", "u4"], ["s32"], [], True
+ )
+ + make_mma_ops(
+ ["m16n8k32", "m16n8k64"], ["s8", "u8"], ["s8", "u8"], ["s32"], [], True
+ )
+ + make_mma_ops(
+ ["m16n8k64"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f16", "f32"],
+ ["f16", "f32"],
+ True,
+ )
)
@@ -1067,7 +1083,8 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf):
return False
if not (
- is_type_supported(op.a.mma_type.ptx_type) and is_mma_sp_geom_supported(op.a.geom)
+ is_type_supported(op.a.mma_type.ptx_type)
+ and is_mma_sp_geom_supported(op.a.geom)
):
return False
@@ -1080,22 +1097,27 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf):
if (
op.a.mma_type.ptx_type in ["f16", "bf16", "tf32"]
and op.a.mma_type.ptx_type != op.b.mma_type.ptx_type
- ):
+ ):
return False
# C and D type must be the same for m16n8k16/m16n8k32
if (
op.a.geom in ["m16n8k16", "m16n8k32"]
and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
- ):
+ ):
return False
- if kind == "" and (op.a.mma_type.ptx_type in ["e3m2", "e2m3", "e2m1"] or \
- op.b.mma_type.ptx_type in ["e3m2", "e2m3", "e2m1"]):
+ if kind == "" and (
+ op.a.mma_type.ptx_type in ["e3m2", "e2m3", "e2m1"]
+ or op.b.mma_type.ptx_type in ["e3m2", "e2m3", "e2m1"]
+ ):
return False
- if kind == "" and op.a.geom == "m16n8k64" and \
- (op.c.mma_type.ptx_type == "f16" or op.d.mma_type.ptx_type == "f16"):
+ if (
+ kind == ""
+ and op.a.geom == "m16n8k64"
+ and (op.c.mma_type.ptx_type == "f16" or op.d.mma_type.ptx_type == "f16")
+ ):
return False
if kind != "" and (metadata == "sp" or op.a.geom != "m16n8k64" or is_int):
@@ -1117,10 +1139,15 @@ def sp_selector_gen(op):
if (op.a.geom, op.a.mma_type.ptx_type) in range_01:
return range(2)
- if (
- op.a.geom == "m16n8k64"
- and op.a.mma_type.ptx_type in ["u8", "s8", "e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
- ):
+ if op.a.geom == "m16n8k64" and op.a.mma_type.ptx_type in [
+ "u8",
+ "s8",
+ "e4m3",
+ "e5m2",
+ "e3m2",
+ "e2m3",
+ "e2m1",
+ ]:
return range(1)
return range(4)
@@ -1149,7 +1176,12 @@ def common_mma_sp_test_gen(params, op, intrinsic_template, instruction_template)
"""
test_params = params
- test_params["intrinsic"] = Template(intrinsic_template).substitute(params).replace("::", ".").replace("_", ".")
+ test_params["intrinsic"] = (
+ Template(intrinsic_template)
+ .substitute(params)
+ .replace("::", ".")
+ .replace("_", ".")
+ )
test_params["function"] = test_params["intrinsic"].replace(".", "_")
test_params["instruction"] = Template(instruction_template).substitute(params)
test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
@@ -1169,7 +1201,9 @@ def common_mma_sp_test_gen(params, op, intrinsic_template, instruction_template)
for selector in [str(r) for r in sp_selector_gen(op)]:
test_params["selector"] = selector
test_params["check_selector"] = "{{" + test_params["selector"] + "}}"
- test_params["call_args"] = test_params["args"].replace("%selector", test_params["selector"])
+ test_params["call_args"] = test_params["args"].replace(
+ "%selector", test_params["selector"]
+ )
print(Template(mma_sp_test_template).substitute(test_params))
@@ -1180,8 +1214,12 @@ def gen_mma_sp_tests():
if ptx_version < 71 or gpu_arch < 80:
return []
- mma_sp_intrinsic_template = "llvm.nvvm.mma.${metadata}.${geom}.row.col${kind}${satf}.${intrinsic_signature}"
- mma_sp_instruction_template = "mma.${metadata}.sync.aligned.${geom}.row.col${kind}${satf}.${ptx_signature}"
+ mma_sp_intrinsic_template = (
+ "llvm.nvvm.mma.${metadata}.${geom}.row.col${kind}${satf}.${intrinsic_signature}"
+ )
+ mma_sp_instruction_template = (
+ "mma.${metadata}.sync.aligned.${geom}.row.col${kind}${satf}.${ptx_signature}"
+ )
generated_items = []
@@ -1189,7 +1227,7 @@ def gen_mma_sp_tests():
get_mma_sp_ops(),
["sp::ordered_metadata", "sp"],
["", ".kind::f8f6f4"],
- [".satfinite", ""]
+ [".satfinite", ""],
):
if not is_mma_sp_variant_supported(op, metadata, kind, satf):
@@ -1208,9 +1246,7 @@ def gen_mma_sp_tests():
instruction_template = mma_sp_instruction_template
generated_items.append(
- common_mma_sp_test_gen(
- params, op, intrinsic_template, instruction_template
- )
+ common_mma_sp_test_gen(params, op, intrinsic_template, instruction_template)
)
return generated_items
>From ef4afb27b2a111d331db01cc38a1f97248235e3b Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 30 Jul 2025 15:27:55 +0200
Subject: [PATCH 3/5] [NVPTX] Check IsSparse for regs once only in WMMA_REGS.
PR150950.
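
The restructuring below checks IsSparse a single time and then selects between a sparse and a dense fragment table, instead of repeating !eq(IsSparse, true) in every !cond arm. A rough Python analogue of that shape (illustrative only; the two table entries are copied from the fragment tables in the patch):

# Sketch: one sparse-flag check selecting between two lookup tables.
SPARSE_REGS = {"m16n8k16:a:f16": ("<2 x half>", 2), "m16n8k16:a:bf16": ("i32", 2)}
DENSE_REGS  = {"m16n8k16:a:f16": ("<2 x half>", 4), "m16n8k16:a:bf16": ("i32", 4)}

def frag_regs(geom, frag, ptx_elt_type, is_sparse=False):
    key = f"{geom}:{frag}:{ptx_elt_type}"
    table = SPARSE_REGS if is_sparse else DENSE_REGS  # IsSparse checked once
    return table[key]

assert frag_regs("m16n8k16", "a", "f16", is_sparse=True) == ("<2 x half>", 2)
assert frag_regs("m16n8k16", "a", "f16") == ("<2 x half>", 4)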
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 442 +++++++++++++------------
1 file changed, 235 insertions(+), 207 deletions(-)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index c4f3e1b394c8e..1d4edd078e01f 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -177,213 +177,241 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
string gft = Geom#":"#Frag#":"#ptx_elt_type;
string gf = Geom#":"#Frag;
string ft = frag#":"#ptx_elt_type;
- list<LLVMType> regs = !cond(
- // mma sparse ops use other fragments for some arguments
- !and(!eq(gft, "m16n8k16:a:bf16"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
- !and(!eq(gft, "m16n8k16:a:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
- !and(!eq(gft, "m16n8k32:a:bf16"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k32:a:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 4),
- !and(!eq(gft, "m16n8k32:b:bf16"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k32:b:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 4),
- !and(!eq(gft, "m16n8k32:c:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
- !and(!eq(gft, "m16n8k32:c:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
- !and(!eq(gft, "m16n8k32:d:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
- !and(!eq(gft, "m16n8k32:d:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
- !and(!eq(gft, "m16n8k16:a:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k16:b:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k16:c:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
- !and(!eq(gft, "m16n8k16:d:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
- !and(!eq(gft, "m16n8k8:a:tf32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
- !and(!eq(gft, "m16n8k32:a:u8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
- !and(!eq(gft, "m16n8k32:a:s8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
- !and(!eq(gft, "m16n8k64:a:u8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:a:s8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:a:e4m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:a:e5m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:a:e3m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:a:e2m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:a:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:b:u8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:b:s8"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:b:e4m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:b:e5m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:b:e3m2"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:b:e2m3"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:b:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k64:c:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
- !and(!eq(gft, "m16n8k64:c:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
- !and(!eq(gft, "m16n8k64:d:f16"), !eq(IsSparse, true)) : !listsplat(llvm_v2f16_ty, 2),
- !and(!eq(gft, "m16n8k64:d:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
- !and(!eq(gft, "m16n8k64:a:u4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
- !and(!eq(gft, "m16n8k64:a:s4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 2),
- !and(!eq(gft, "m16n8k128:a:u4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k128:a:s4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k128:a:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k128:b:u4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k128:b:s4"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k128:b:e2m1"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k128:c:s32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k128:c:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
- !and(!eq(gft, "m16n8k128:d:s32"), !eq(IsSparse, true)) : !listsplat(llvm_i32_ty, 4),
- !and(!eq(gft, "m16n8k128:d:f32"), !eq(IsSparse, true)) : !listsplat(llvm_float_ty, 4),
- // mma fp ops use smaller fragments than wmma fp ops
- !eq(gft,"m8n8k4:a:f16") : !listsplat(llvm_v2f16_ty, 2),
- !eq(gft,"m8n8k4:b:f16") : !listsplat(llvm_v2f16_ty, 2),
- !eq(gft,"m16n8k8:a:f16") : !listsplat(llvm_v2f16_ty, 2),
- !eq(gft,"m16n8k8:b:f16") : [llvm_v2f16_ty],
- !eq(gft,"m16n8k8:c:f16") : !listsplat(llvm_v2f16_ty, 2),
- !eq(gft,"m16n8k8:d:f16") : !listsplat(llvm_v2f16_ty, 2),
- !eq(gft,"m16n8k8:c:f32") : !listsplat(llvm_float_ty, 4),
- !eq(gft,"m16n8k8:d:f32") : !listsplat(llvm_float_ty, 4),
- !eq(gft,"m16n8k16:a:f16") : !listsplat(llvm_v2f16_ty, 4),
- !eq(gft,"m16n8k16:b:f16") : !listsplat(llvm_v2f16_ty, 2),
- !eq(gft,"m16n8k16:c:f16") : !listsplat(llvm_v2f16_ty, 2),
- !eq(gft,"m16n8k16:d:f16") : !listsplat(llvm_v2f16_ty, 2),
- !eq(gft,"m16n8k16:c:f32") : !listsplat(llvm_float_ty, 4),
- !eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
- !eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4),
- !eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4),
-
- // wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
- // All other supported geometries use the same fragment format for f32 and
- // f16, so we only need to consider {fragment, type}.
- !eq(ft,"a:f16") : !listsplat(llvm_v2f16_ty, 8),
- !eq(ft,"b:f16") : !listsplat(llvm_v2f16_ty, 8),
- !eq(ft,"c:f16") : !listsplat(llvm_v2f16_ty, 4),
- !eq(ft,"d:f16") : !listsplat(llvm_v2f16_ty, 4),
- !eq(ft,"c:f32") : !listsplat(llvm_float_ty, 8),
- !eq(ft,"d:f32") : !listsplat(llvm_float_ty, 8),
-
- // wmma tf32 -> s32 @ m16n16k8
- !eq(gft,"m16n16k8:a:tf32") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n16k8:b:tf32") : !listsplat(llvm_i32_ty, 4),
-
- // mma tf32 -> s32 @ m16n16k8/m16n8k8
- !eq(gft,"m16n8k4:a:tf32") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n8k4:b:tf32") : [llvm_i32_ty],
- !eq(gft,"m16n8k8:a:tf32") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n8k8:b:tf32") : !listsplat(llvm_i32_ty, 2),
-
- !eq(gft,"m8n8k4:a:f64") : [llvm_double_ty],
- !eq(gft,"m8n8k4:b:f64") : [llvm_double_ty],
- !eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2),
- !eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2),
-
- // wmma bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
- !eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m8n32k16:a:bf16") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m8n32k16:b:bf16") : !listsplat(llvm_i32_ty, 8),
- !eq(gft,"m32n8k16:a:bf16") : !listsplat(llvm_i32_ty, 8),
- !eq(gft,"m32n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
-
- // mma bf16 -> s32 @ m16n8k16/m16n8k8
- !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],
-
- // wmma u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
- !eq(gft,"m16n16k16:a:u8") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n16k16:a:s8") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n16k16:b:u8") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n16k16:b:s8") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n16k16:c:s32") : !listsplat(llvm_i32_ty, 8),
- !eq(gft,"m16n16k16:d:s32") : !listsplat(llvm_i32_ty, 8),
-
- !eq(gft,"m8n32k16:a:u8") : [llvm_i32_ty],
- !eq(gft,"m8n32k16:a:s8") : [llvm_i32_ty],
- !eq(gft,"m8n32k16:b:u8") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m8n32k16:b:s8") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m8n32k16:c:s32") : !listsplat(llvm_i32_ty, 8),
- !eq(gft,"m8n32k16:d:s32") : !listsplat(llvm_i32_ty, 8),
-
- !eq(gft,"m32n8k16:a:u8") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m32n8k16:a:s8") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m32n8k16:b:u8") : [llvm_i32_ty],
- !eq(gft,"m32n8k16:b:s8") : [llvm_i32_ty],
- !eq(gft,"m32n8k16:c:s32") : !listsplat(llvm_i32_ty, 8),
- !eq(gft,"m32n8k16:d:s32") : !listsplat(llvm_i32_ty, 8),
-
- // mma u8/s8 -> s32 @ m8n8k16/m16n8k16/m16n8k32
- !eq(gft,"m8n8k16:a:u8") : [llvm_i32_ty],
- !eq(gft,"m8n8k16:a:s8") : [llvm_i32_ty],
- !eq(gft,"m8n8k16:b:u8") : [llvm_i32_ty],
- !eq(gft,"m8n8k16:b:s8") : [llvm_i32_ty],
- !eq(gft,"m8n8k16:c:s32") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m8n8k16:d:s32") : !listsplat(llvm_i32_ty, 2),
-
- !eq(gft,"m16n8k16:a:u8") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n8k16:a:s8") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n8k16:b:u8") : [llvm_i32_ty],
- !eq(gft,"m16n8k16:b:s8") : [llvm_i32_ty],
- !eq(gft,"m16n8k16:c:s32") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n8k16:d:s32") : !listsplat(llvm_i32_ty, 4),
-
- !eq(gft,"m16n8k32:a:u8") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n8k32:a:s8") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n8k32:b:u8") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n8k32:b:s8") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4),
-
- // wmma/mma u4/s4 -> s32 @ m8n8k32 (u4/s4)
- !eq(gft,"m8n8k32:a:u4") : [llvm_i32_ty],
- !eq(gft,"m8n8k32:a:s4") : [llvm_i32_ty],
- !eq(gft,"m8n8k32:b:u4") : [llvm_i32_ty],
- !eq(gft,"m8n8k32:b:s4") : [llvm_i32_ty],
- !eq(gft,"m8n8k32:c:s32") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m8n8k32:d:s32") : !listsplat(llvm_i32_ty, 2),
-
- !eq(gft,"m16n8k32:a:u4") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n8k32:a:s4") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n8k32:b:u4") : [llvm_i32_ty],
- !eq(gft,"m16n8k32:b:s4") : [llvm_i32_ty],
- !eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4),
-
- !eq(gft,"m16n8k64:a:u4") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n8k64:a:s4") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n8k64:b:u4") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n8k64:b:s4") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),
-
- // wmma/mma b1 -> s32 @ m8n8k128(b1)
- !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
- !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
- !eq(gft,"m8n8k128:c:s32") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m8n8k128:d:s32") : !listsplat(llvm_i32_ty, 2),
-
- !eq(gft,"m16n8k128:a:b1") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n8k128:b:b1") : [llvm_i32_ty],
- !eq(gft,"m16n8k128:c:s32") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n8k128:d:s32") : !listsplat(llvm_i32_ty, 4),
-
- !eq(gft,"m16n8k256:a:b1") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n8k256:b:b1") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m16n8k256:c:s32") : !listsplat(llvm_i32_ty, 4),
- !eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
-
- // ldmatrix b16 -> s32 @ m8n8
- !eq(gf,"m8n8:x1") : !listsplat(llvm_i32_ty, 1),
- !eq(gf,"m8n8:x2") : !listsplat(llvm_i32_ty, 2),
- !eq(gf,"m8n8:x4") : !listsplat(llvm_i32_ty, 4),
-
- // ldmatrix b8, b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m16n16
- !eq(gf,"m16n16:x1") : !listsplat(llvm_i32_ty, 2),
- !eq(gf,"m16n16:x2") : !listsplat(llvm_i32_ty, 4),
-
- // ldmatrix b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m8n16
- !eq(gf,"m8n16:x1") : !listsplat(llvm_i32_ty, 1),
- !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
- !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
-
- // stmatrix b8 -> s32 @ m16n8
- !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
- !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
- !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),
-
+ list<LLVMType> regs = !if(!eq(IsSparse, true),
+ !cond(
+ // mma sparse ops use other fragments for some arguments
+ !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k16:a:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k16:b:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k16:c:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k16:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k16:d:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
+
+ !eq(gft,"m16n8k32:a:bf16") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:a:f16") : !listsplat(llvm_v2f16_ty, 4),
+ !eq(gft,"m16n8k32:b:bf16") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:b:f16") : !listsplat(llvm_v2f16_ty, 4),
+ !eq(gft,"m16n8k32:c:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k32:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
+
+ !eq(gft,"m16n8k16:a:tf32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k16:b:tf32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k16:c:tf32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k16:d:tf32") : !listsplat(llvm_float_ty, 4),
+
+ !eq(gft,"m16n8k8:a:tf32") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k8:b:tf32") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k8:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k8:d:f32") : !listsplat(llvm_float_ty, 4),
+
+ !eq(gft,"m16n8k32:a:u8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:a:s8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:b:u8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:b:s8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ !eq(gft,"m16n8k64:a:u8") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:a:s8") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:a:e4m3") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:a:e5m2") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:a:e3m2") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:a:e2m3") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:a:e2m1") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:b:u8") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:b:s8") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:b:e4m3") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:b:e5m2") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:b:e3m2") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:b:e2m3") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:b:e2m1") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:c:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k64:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k64:d:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k64:d:f32") : !listsplat(llvm_float_ty, 4),
+
+ !eq(gft,"m16n8k64:a:u4") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k64:a:s4") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k64:b:u4") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k64:b:s4") : !listsplat(llvm_i32_ty, 2),
+
+ !eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ !eq(gft,"m16n8k128:a:u4") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k128:a:s4") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k128:a:e2m1") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k128:b:u4") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k128:b:s4") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k128:b:e2m1") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k128:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k128:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k128:d:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k128:d:f32") : !listsplat(llvm_float_ty, 4),
+ ),
+ !cond(
+ // mma fp ops use smaller fragments than wmma fp ops
+ !eq(gft,"m8n8k4:a:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m8n8k4:b:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k8:a:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k8:b:f16") : [llvm_v2f16_ty],
+ !eq(gft,"m16n8k8:c:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k8:d:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k8:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k8:d:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k16:a:f16") : !listsplat(llvm_v2f16_ty, 4),
+ !eq(gft,"m16n8k16:b:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k16:c:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k16:d:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k16:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4),
+
+ // wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
+ // All other supported geometries use the same fragment format for f32 and
+ // f16, so we only need to consider {fragment, type}.
+ !eq(ft,"a:f16") : !listsplat(llvm_v2f16_ty, 8),
+ !eq(ft,"b:f16") : !listsplat(llvm_v2f16_ty, 8),
+ !eq(ft,"c:f16") : !listsplat(llvm_v2f16_ty, 4),
+ !eq(ft,"d:f16") : !listsplat(llvm_v2f16_ty, 4),
+ !eq(ft,"c:f32") : !listsplat(llvm_float_ty, 8),
+ !eq(ft,"d:f32") : !listsplat(llvm_float_ty, 8),
+
+ // wmma tf32 -> s32 @ m16n16k8
+ !eq(gft,"m16n16k8:a:tf32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n16k8:b:tf32") : !listsplat(llvm_i32_ty, 4),
+
+ // mma tf32 -> s32 @ m16n16k8/m16n8k8
+ !eq(gft,"m16n8k4:a:tf32") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k4:b:tf32") : [llvm_i32_ty],
+ !eq(gft,"m16n8k8:a:tf32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k8:b:tf32") : !listsplat(llvm_i32_ty, 2),
+
+ !eq(gft,"m8n8k4:a:f64") : [llvm_double_ty],
+ !eq(gft,"m8n8k4:b:f64") : [llvm_double_ty],
+ !eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2),
+ !eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2),
+
+ // wmma bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
+ !eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m8n32k16:a:bf16") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m8n32k16:b:bf16") : !listsplat(llvm_i32_ty, 8),
+ !eq(gft,"m32n8k16:a:bf16") : !listsplat(llvm_i32_ty, 8),
+ !eq(gft,"m32n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
+
+ // mma bf16 -> s32 @ m16n8k16/m16n8k8
+ !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k16:b:bf16") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k8:a:bf16") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k8:b:bf16") : [llvm_i32_ty],
+
+ // wmma u8/s8 -> s32 @ m16n16k16/m8n32k16/m32n8k16
+ !eq(gft,"m16n16k16:a:u8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n16k16:a:s8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n16k16:b:u8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n16k16:b:s8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n16k16:c:s32") : !listsplat(llvm_i32_ty, 8),
+ !eq(gft,"m16n16k16:d:s32") : !listsplat(llvm_i32_ty, 8),
+
+ !eq(gft,"m8n32k16:a:u8") : [llvm_i32_ty],
+ !eq(gft,"m8n32k16:a:s8") : [llvm_i32_ty],
+ !eq(gft,"m8n32k16:b:u8") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m8n32k16:b:s8") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m8n32k16:c:s32") : !listsplat(llvm_i32_ty, 8),
+ !eq(gft,"m8n32k16:d:s32") : !listsplat(llvm_i32_ty, 8),
+
+ !eq(gft,"m32n8k16:a:u8") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m32n8k16:a:s8") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m32n8k16:b:u8") : [llvm_i32_ty],
+ !eq(gft,"m32n8k16:b:s8") : [llvm_i32_ty],
+ !eq(gft,"m32n8k16:c:s32") : !listsplat(llvm_i32_ty, 8),
+ !eq(gft,"m32n8k16:d:s32") : !listsplat(llvm_i32_ty, 8),
+
+ // mma u8/s8 -> s32 @ m8n8k16/m16n8k16/m16n8k32
+ !eq(gft,"m8n8k16:a:u8") : [llvm_i32_ty],
+ !eq(gft,"m8n8k16:a:s8") : [llvm_i32_ty],
+ !eq(gft,"m8n8k16:b:u8") : [llvm_i32_ty],
+ !eq(gft,"m8n8k16:b:s8") : [llvm_i32_ty],
+ !eq(gft,"m8n8k16:c:s32") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m8n8k16:d:s32") : !listsplat(llvm_i32_ty, 2),
+
+ !eq(gft,"m16n8k16:a:u8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k16:a:s8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k16:b:u8") : [llvm_i32_ty],
+ !eq(gft,"m16n8k16:b:s8") : [llvm_i32_ty],
+ !eq(gft,"m16n8k16:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k16:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ !eq(gft,"m16n8k32:a:u8") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:a:s8") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:b:u8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:b:s8") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ // wmma/mma u4/s4 -> s32 @ m8n8k32 (u4/s4)
+ !eq(gft,"m8n8k32:a:u4") : [llvm_i32_ty],
+ !eq(gft,"m8n8k32:a:s4") : [llvm_i32_ty],
+ !eq(gft,"m8n8k32:b:u4") : [llvm_i32_ty],
+ !eq(gft,"m8n8k32:b:s4") : [llvm_i32_ty],
+ !eq(gft,"m8n8k32:c:s32") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m8n8k32:d:s32") : !listsplat(llvm_i32_ty, 2),
+
+ !eq(gft,"m16n8k32:a:u4") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:a:s4") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:b:u4") : [llvm_i32_ty],
+ !eq(gft,"m16n8k32:b:s4") : [llvm_i32_ty],
+ !eq(gft,"m16n8k32:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ !eq(gft,"m16n8k64:a:u4") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:a:s4") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:b:u4") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k64:b:s4") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ // wmma/mma b1 -> s32 @ m8n8k128(b1)
+ !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
+ !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
+ !eq(gft,"m8n8k128:c:s32") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m8n8k128:d:s32") : !listsplat(llvm_i32_ty, 2),
+
+ !eq(gft,"m16n8k128:a:b1") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k128:b:b1") : [llvm_i32_ty],
+ !eq(gft,"m16n8k128:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k128:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ !eq(gft,"m16n8k256:a:b1") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k256:b:b1") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k256:c:s32") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
+
+ // ldmatrix b16 -> s32 @ m8n8
+ !eq(gf,"m8n8:x1") : !listsplat(llvm_i32_ty, 1),
+ !eq(gf,"m8n8:x2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m8n8:x4") : !listsplat(llvm_i32_ty, 4),
+
+ // ldmatrix b8, b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m16n16
+ !eq(gf,"m16n16:x1") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m16n16:x2") : !listsplat(llvm_i32_ty, 4),
+
+ // ldmatrix b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m8n16
+ !eq(gf,"m8n16:x1") : !listsplat(llvm_i32_ty, 1),
+ !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
+
+ // stmatrix b8 -> s32 @ m16n8
+ !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
+ !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),
+ )
);
}
>From e9c32c01c8f192f40d383019da15d558a049204a Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 30 Jul 2025 16:59:19 +0200
Subject: [PATCH 4/5] [NVPTX] The sparsity selector range depends on the shape
and the type. PR150950.
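
The rule matches the one encoded by sp_selector_gen() in the test generator: most shapes accept selectors 0-3, a handful of (shape, type) pairs accept only 0-1, and m16n8k64 with 8-bit or FP8/F6/F4 A operands accepts only 0. An illustrative Python mirror (the function name is invented and not part of the patch):

def sparsity_selector_bound(geom, a_type):
    # Valid sparsity selectors lie in [0, bound).
    if (geom, a_type) in {
        ("m16n8k32", "bf16"), ("m16n8k16", "tf32"),
        ("m16n8k32", "u8"), ("m16n8k32", "s8"),
        ("m16n8k64", "u4"), ("m16n8k64", "s4"),
    }:
        return 2
    if geom == "m16n8k64" and a_type in (
            "u8", "s8", "e4m3", "e5m2", "e3m2", "e2m3", "e2m1"):
        return 1
    return 4

assert sparsity_selector_bound("m16n8k16", "f16") == 4
assert sparsity_selector_bound("m16n8k64", "e4m3") == 1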
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 21 ++++++++++++++++++++-
1 file changed, 20 insertions(+), 1 deletion(-)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 1d4edd078e01f..953de3de2ba70 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -2188,8 +2188,27 @@ class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
: Intrinsic<D.regs,
!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty], [llvm_i32_ty])> {
int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
+
+ // The range [0;range) is for the sparsity selector that indicates the threads
+ // which contribute metadata.
+ int range = !if(!or(!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "bf16")),
+ !and(!eq(A.geom, "m16n8k16"), !eq(A.ptx_elt_type, "tf32")),
+ !and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "u8")),
+ !and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "s8")),
+ !and(!eq(A.geom, "m16n8k64"), !eq(A.ptx_elt_type, "u4")),
+ !and(!eq(A.geom, "m16n8k64"), !eq(A.ptx_elt_type, "s4"))),
+ 2,
+ !if(!and(!eq(A.geom, "m16n8k64"),
+ !or(!eq(A.ptx_elt_type, "u8"),
+ !eq(A.ptx_elt_type, "s8"),
+ !eq(A.ptx_elt_type, "e4m3"),
+ !eq(A.ptx_elt_type, "e5m2"),
+ !eq(A.ptx_elt_type, "e3m2"),
+ !eq(A.ptx_elt_type, "e2m3"),
+ !eq(A.ptx_elt_type, "e2m1"))),
+ 1, 4));
let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
- Range<ArgIndex<pos>, 0, 4>];
+ Range<ArgIndex<pos>, 0, range>];
}
foreach metadata = ["sp", "sp::ordered_metadata"] in {
>From eb8e14bb6c2508cd8019ef217894b99eb77304a4 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 30 Jul 2025 18:32:25 +0200
Subject: [PATCH 5/5] [NVPTX] Renamed a variable for the sparsity selector.
PR150950.
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 36 +++++++++++++-------------
1 file changed, 18 insertions(+), 18 deletions(-)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 953de3de2ba70..9f56c3ab436e0 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -2189,26 +2189,26 @@ class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty], [llvm_i32_ty])> {
int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
- // The range [0;range) is for the sparsity selector that indicates the threads
+ // The range [0;num_threads) is for the sparsity selector that indicates the threads
// which contribute metadata.
- int range = !if(!or(!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "bf16")),
- !and(!eq(A.geom, "m16n8k16"), !eq(A.ptx_elt_type, "tf32")),
- !and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "u8")),
- !and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "s8")),
- !and(!eq(A.geom, "m16n8k64"), !eq(A.ptx_elt_type, "u4")),
- !and(!eq(A.geom, "m16n8k64"), !eq(A.ptx_elt_type, "s4"))),
- 2,
- !if(!and(!eq(A.geom, "m16n8k64"),
- !or(!eq(A.ptx_elt_type, "u8"),
- !eq(A.ptx_elt_type, "s8"),
- !eq(A.ptx_elt_type, "e4m3"),
- !eq(A.ptx_elt_type, "e5m2"),
- !eq(A.ptx_elt_type, "e3m2"),
- !eq(A.ptx_elt_type, "e2m3"),
- !eq(A.ptx_elt_type, "e2m1"))),
- 1, 4));
+ int num_threads = !if(!or(!and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "bf16")),
+ !and(!eq(A.geom, "m16n8k16"), !eq(A.ptx_elt_type, "tf32")),
+ !and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "u8")),
+ !and(!eq(A.geom, "m16n8k32"), !eq(A.ptx_elt_type, "s8")),
+ !and(!eq(A.geom, "m16n8k64"), !eq(A.ptx_elt_type, "u4")),
+ !and(!eq(A.geom, "m16n8k64"), !eq(A.ptx_elt_type, "s4"))),
+ 2,
+ !if(!and(!eq(A.geom, "m16n8k64"),
+ !or(!eq(A.ptx_elt_type, "u8"),
+ !eq(A.ptx_elt_type, "s8"),
+ !eq(A.ptx_elt_type, "e4m3"),
+ !eq(A.ptx_elt_type, "e5m2"),
+ !eq(A.ptx_elt_type, "e3m2"),
+ !eq(A.ptx_elt_type, "e2m3"),
+ !eq(A.ptx_elt_type, "e2m1"))),
+ 1, 4));
let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
- Range<ArgIndex<pos>, 0, range>];
+ Range<ArgIndex<pos>, 0, num_threads>];
}
foreach metadata = ["sp", "sp::ordered_metadata"] in {