[llvm] 2f627c1 - [NVPTX] Support for dense and sparse MMA intrinsics with block scaling. (#163561)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 21 04:13:57 PST 2025
Author: Kirill Vedernikov
Date: 2025-11-21T17:43:52+05:30
New Revision: 2f627c1878a3dba594c872773107c556992af3a1
URL: https://github.com/llvm/llvm-project/commit/2f627c1878a3dba594c872773107c556992af3a1
DIFF: https://github.com/llvm/llvm-project/commit/2f627c1878a3dba594c872773107c556992af3a1.diff
LOG: [NVPTX] Support for dense and sparse MMA intrinsics with block scaling. (#163561)
This change adds dense and sparse MMA intrinsics with block scaling. The
implementation is based on [PTX ISA version
9.0](https://docs.nvidia.com/cuda/parallel-thread-execution/). Tests for
new intrinsics are added for PTX 8.8 and SM 120a and are generated by
`llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py`. The tests have been
verified with ptxas from CUDA-13.0 release.
Dense MMA intrinsics with block scaling were supported by
@schwarzschild-radius.
Added:
llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py
Modified:
llvm/include/llvm/IR/IntrinsicsNVVM.td
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
llvm/test/CodeGen/NVPTX/wmma.py
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index be4f99aaaa241..c71f37f671539 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -178,7 +178,8 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
string gft = Geom#":"#Frag#":"#ptx_elt_type;
string gf = Geom#":"#Frag;
string ft = frag#":"#ptx_elt_type;
- list<LLVMType> regs = !if(!eq(IsSparse, true),
+ bit isSparse = IsSparse;
+ list<LLVMType> regs = !if(!eq(isSparse, true),
!cond(
// mma sparse ops use other fragments for some arguments
!eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 2),
@@ -277,6 +278,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
!eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
+ // mma.block_scale e2m1 (mxf4, mxf4nvf4) -> f32 @ m16n8k64
+ !eq(gft,"m16n8k64:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k64:d:f32") : !listsplat(llvm_float_ty, 4),
+
// wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
// All other supported geometries use the same fragment format for f32 and
// f16, so we only need to consider {fragment, type}.
@@ -520,6 +525,18 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, strin
# signature;
}
+class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string record_name = "int_nvvm_mma_block_scale"
+ # "_" # A.geom
+ # "_row_col"
+ # "_" # Kind
+ # !subst(".", "_", ScaleVecSize)
+ # signature
+ # "_" # SType;
+}
+
class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
WMMA_REGS A, WMMA_REGS B,
WMMA_REGS C, WMMA_REGS D> {
@@ -533,6 +550,19 @@ class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
# signature;
}
+class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+ WMMA_REGS A, WMMA_REGS B,
+ WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string record_name = "int_nvvm_mma_sp_ordered_metadata_block_scale"
+ # "_" # A.geom
+ # "_row_col"
+ # "_" # Kind
+ # !subst(".", "_", ScaleVecSize)
+ # signature
+ # "_" # SType;
+}
+
// Helper class that takes an intrinsic name and construct a record name.
// Additionally, sets `intr_name` to be non-empty if the default name assigned
// to this intrinsic will not match the name given.
@@ -683,6 +713,18 @@ class NVVM_MMA_OPS {
fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
int_mma_ops, subint_mma_ops, bit_mma_ops);
+ list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
+ ["m16n8k64"], ["e2m1"], ["e2m1"], ["f32"], ["f32"]
+ >.ret;
+
+ list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
+ ["m16n8k32"], ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+ ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], ["f32"], ["f32"]
+ >.ret;
+
+ list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
+ mxf4_mma_ops, mxf8f6f4_mma_ops);
+
list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
["m16n8k16", "m16n8k32"],
["bf16"], [], ["f32"], [], true>.ret;
@@ -707,6 +749,18 @@ class NVVM_MMA_OPS {
bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
subint_mma_sp_ops, int_mma_sp_ops);
+ // combines available geoms and types for mxf4 and mxf4nvf4 kinds
+ list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
+ ["m16n8k128"],
+ ["e2m1"], ["e2m1"], ["f32"], [], true>.ret;
+ list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
+ ["m16n8k64"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f32"], [], true>.ret;
+ list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
+ mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
+
list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
@@ -900,6 +954,32 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
);
}
+class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind, string stype, string scale_vec_size> {
+ string geom = frags[0].geom;
+
+ bit ret = !cond(
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_2x")),
+ !eq(stype, "ue8m0")) : true,
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(scale_vec_size, ".scale_2x"),
+ !eq(stype, "ue8m0")) : true,
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(scale_vec_size, ".scale_4x"),
+ !eq(stype, "ue4m3")) : true,
+ !and(!eq(geom, "m16n8k32"),
+ !eq(kind, "mxf8f6f4"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_1x")),
+ !eq(stype, "ue8m0")) : true,
+ true: false
+ );
+}
+
// Returns true if the fragment is valid for ldmatrix ops is supported;
// false otherwise.
// E.g.
@@ -998,6 +1078,51 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
}
+// Returns true if this combination of kind/scale_vec_size/stype
+// for MMA.SP ops is supported;
+// false otherwise.
+// E.g.
+// if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<...>.ret then
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+ string stype, string scale_vec_size> {
+ // Element types of the A/B/C/D fragments and the geometry of the op.
+ string a_type = frags[0].ptx_elt_type;
+ string b_type = frags[1].ptx_elt_type;
+ string c_type = frags[2].ptx_elt_type;
+ string d_type = frags[3].ptx_elt_type;
+ string geom = frags[0].geom;
+
+ bit ret = !cond(
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4"),
+ !eq(stype, "ue8m0"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_2x"))): true,
+
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue8m0"),
+ !eq(scale_vec_size, ".scale_2x")): true,
+
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue4m3"),
+ !eq(scale_vec_size, ".scale_4x")): true,
+
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf8f6f4"),
+ !eq(stype, "ue8m0"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_1x"))): true,
+
+ // All other are NOT OK.
+ true: false
+ );
+}
+
+
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
string Suffix = !if(sync, "sync_", "")
# mode # "_"
@@ -2452,6 +2577,31 @@ foreach layout_a = ["row", "col"] in {
} // layout_b
} // layout_a
+class NVVM_MMA_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+ : Intrinsic<D.regs,
+ !listconcat(A.regs, B.regs, C.regs,
+ [
+ llvm_i32_ty, // scale-a-data
+ llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
+ llvm_i32_ty, // scale-b-data,
+ llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
+ ]),
+ [IntrNoMem, IntrNoCallback]>;
+
+foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
+ if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
+ op[0], op[1], op[2], op[3]>.record_name
+ : NVVM_MMA_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+} // kind
+
// MMA.SP
class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
: Intrinsic<D.regs,
@@ -2499,6 +2649,45 @@ foreach metadata = ["sp", "sp::ordered_metadata"] in {
} // kind
} // metadata
+// MMA.SP BLOCK SCALE
+class NVVM_MMA_SP_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+ : Intrinsic<D.regs,
+ !listconcat(A.regs, B.regs, C.regs,
+ [
+ llvm_i32_ty, // metadata
+ llvm_i32_ty, // sparsity selector
+ llvm_i32_ty, // scale-a-data
+ llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
+ llvm_i32_ty, // scale-b-data
+ llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
+ ])> {
+ int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
+
+ // The range [0;num_threads) is for the sparsity selector that indicates the threads
+ // which contribute metadata.
+ // According to PTX ISA 9.0, the sparsity selector is always 0
+ // for sparse MMA block scale instructions
+ int num_threads = 1;
+ let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
+ Range<ArgIndex<pos>, 0, num_threads>];
+}
+
+// According to PTX ISA 9.0
+// a_layout = ["row"], b_layout = ["col"], spvariant = ["sp::ordered_metadata"]
+foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
+ if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
+ op[0], op[1], op[2], op[3]>.record_name
+ : NVVM_MMA_SP_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+} // kind
+
// LDMATRIX
class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
: Intrinsic<Frag.regs, [llvm_anyptr_ty],
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index b54cce4781b8d..8501d4d7bb86f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4684,7 +4684,8 @@ def INT_PTX_SREG_WARPSIZE :
// the fields commonly used to implement specific PTX instruction -- register
// types and names, constraints, parts of assembly, etc.
class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "">
- : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type, !eq(op, "mma.sp")> {
+ : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type,
+ !or(!eq(op, "mma.sp"), !eq(op, "mma.sp.block_scale"))> {
// NVPTX register types used to carry fragment data.
NVPTXRegClass regclass = !cond(
!eq(ptx_elt_type, "e4m3") : B32,
@@ -4724,6 +4725,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
// longer the case, we can concat all per-fragment predicates to enforce that
// all fragments of the instruction are viable.
list<Predicate> Predicates = !cond(
+ !or(!eq(op, "mma.block_scale"),
+ !eq(op, "mma.sp.block_scale")) : [hasSM120a, hasPTX<88>],
+
!or(!eq(ptx_elt_type, "e3m2"),
!eq(ptx_elt_type, "e2m3"),
!eq(ptx_elt_type, "e2m1"),
@@ -4736,9 +4740,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
!or(!eq(ptx_elt_type, "e4m3"),
!eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
- !and(!eq(op, "mma.sp"),
+ !and(isSparse,
!ne(metadata, "sp")) : [hasSM<80>, hasPTX<85>],
- !eq(op, "mma.sp") : [hasSM<80>, hasPTX<71>],
+ isSparse : [hasSM<80>, hasPTX<71>],
// fp16 -> fp16/fp32 @ m16n16k16
!and(!eq(geom, "m16n16k16"),
@@ -5087,6 +5091,67 @@ defset list<WMMA_INSTR> MMAs = {
} // defset
}
+// MMA.block_scale
+class MMA_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string Kind, string SType, string ScaleVecSize>
+ : WMMA_INSTR<MMA_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+ FragA, FragB, FragC, FragD>.record_name,
+ [FragA.Ins, FragB.Ins, FragC.Ins,
+ (ins B32:$scale_a, B16:$byte_id_a,
+ B16:$thread_id_a, B32:$scale_b,
+ B16:$byte_id_b, B16:$thread_id_b)]>,
+ // Requires does not seem to have effect on Instruction w/o Patterns.
+ // We set it here anyways and propagate to the Pat<> we construct below.
+ Requires<FragA.Predicates> {
+ let OutOperandList = FragD.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ string TypeList = !interleave([FragD.ptx_elt_type,
+ FragA.ptx_elt_type,
+ FragB.ptx_elt_type,
+ FragC.ptx_elt_type], ".");
+ string ScaleVecSizeStr = !cond(
+ !eq(ScaleVecSize, "") : "",
+ !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+ !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+ !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+ );
+ let AsmString = "mma.sync.aligned."
+ # FragA.geom
+ # ".row.col"
+ # ".kind::" # Kind
+ # ".block_scale"
+ # ScaleVecSizeStr
+ # "." # TypeList
+ # "." # SType # " \n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ",\n\t\t"
+ # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+ # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_BLOCK_SCALEs = {
+ foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
+ if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def : MMA_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.block_scale", "", kind>,
+ WMMA_REGINFO<op[1], "mma.block_scale", "", kind>,
+ WMMA_REGINFO<op[2], "mma.block_scale", "", kind>,
+ WMMA_REGINFO<op[3], "mma.block_scale", "", kind>,
+ kind, stype, scale_vec_size>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+ } // kind
+} // defset
+}
+
// MMA SP
class MMA_SP<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
@@ -5143,6 +5208,72 @@ defset list<WMMA_INSTR> MMA_SPs = {
} // defset
}
+// MMA SP BLOCK SCALE
+class MMA_SP_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string Kind, string SType, string ScaleVecSize>
+ : WMMA_INSTR<MMA_SP_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+ FragA, FragB, FragC, FragD>.record_name,
+ [FragA.Ins, FragB.Ins, FragC.Ins,
+ (ins B32:$metadata, i32imm:$selector,
+ B32:$scale_a, B16:$byte_id_a, B16:$thread_id_a,
+ B32:$scale_b, B16:$byte_id_b, B16:$thread_id_b)]>,
+ // Requires does not seem to have effect on Instruction w/o Patterns.
+ // We set it here anyways and propagate to the Pat<> we construct below.
+ Requires<!listconcat(FragA.Predicates,
+ FragB.Predicates,
+ FragC.Predicates,
+ FragD.Predicates)> {
+ let OutOperandList = FragD.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ string TypeList = "." # FragD.ptx_elt_type
+ # "." # FragA.ptx_elt_type
+ # "." # FragB.ptx_elt_type
+ # "." # FragC.ptx_elt_type;
+ string ScaleVecSizeStr = !cond(
+ !eq(ScaleVecSize, "") : "",
+ !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+ !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+ !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+ );
+ let AsmString = "mma.sp::ordered_metadata.sync.aligned."
+ # FragA.geom
+ # ".row.col"
+ # ".kind::" # Kind
+ # ".block_scale"
+ # ScaleVecSizeStr
+ # TypeList
+ # "." # SType # "\n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ",\n\t\t"
+ # "$metadata" # ",\n\t\t"
+ # "$selector" # ",\n\t\t"
+ # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+ # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_SP_BLOCK_SCALEs = {
+ foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
+ if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[1], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[2], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[3], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+ kind, stype, scale_vec_size>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+ } // kind
+} // defset
+}
+
//
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
//
@@ -5224,7 +5355,8 @@ class MMA_PAT<WMMA_INSTR wi>
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs, MMA_SPs) in
+foreach mma = !listconcat(MMAs, MMA_BLOCK_SCALEs, WMMAs, MMA_LDSTs, LDMATRIXs,
+ STMATRIXs, MMA_SPs, MMA_SP_BLOCK_SCALEs) in
def : MMA_PAT<mma>;
multiclass MAPA<string suffix, Intrinsic Intr> {
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py
new file mode 100644
index 0000000000000..f1666dbc5f30f
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py
@@ -0,0 +1,12 @@
+# Check all variants of instructions supported by PTX88 on SM120a
+# RUN: %python %s --ptx=88 --gpu-arch=120 --aa > %t-ptx88-sm_120a.ll
+# RUN: llc < %t-ptx88-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx88 \
+# RUN: | FileCheck %t-ptx88-sm_120a.ll
+# RUN: %if ptxas-sm_120a && ptxas-isa-8.8 %{ \
+# RUN: llc < %t-ptx88-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx88 \
+# RUN: | %ptxas-verify -arch=sm_120a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 8427ae4ad72da..817665a68d7a5 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -131,7 +131,7 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
"m16n8k64:b:e5m2": 4,
"m16n8k64:b:e3m2": 4,
"m16n8k64:b:e2m3": 4,
- "m16n8k64:b:e2m1": 4,
+ "m16n8k64:b:e2m1": 4 if is_mma_sparse else 2,
"m16n8k64:c:f16": 2,
"m16n8k64:c:f32": 4,
"m16n8k64:d:f16": 2,
@@ -1131,6 +1131,160 @@ def gen_mma_tests():
return generated_items
+def get_mma_block_scale_ops():
+ return make_mma_ops(["m16n8k64"], ["e2m1"], [], ["f32"], []) + make_mma_ops(
+ ["m16n8k32"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f32"],
+ [],
+ )
+
+
+def is_mma_block_scale_geom_supported(geom):
+ # geometries for FP.
+ if geom in [
+ "m16n8k32",
+ "m16n8k64",
+ ]:
+ return True
+ raise ValueError(f"Unexpected MMA block scale geometry: {geom}")
+
+
+def is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+ if not (
+ is_type_supported(op.a.mma_type.ptx_type)
+ and is_mma_block_scale_geom_supported(op.a.geom)
+ ):
+ return False
+
+ if (
+ op.a.geom == "m16n8k64"
+ and kind == "mxf4"
+ and stype == "ue8m0"
+ and scale_vec_size in ["", ".scale_vec::2X"]
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k64"
+ and kind == "mxf4nvf4"
+ and stype == "ue8m0"
+ and scale_vec_size == ".scale_vec::2X"
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k64"
+ and kind == "mxf4nvf4"
+ and stype == "ue4m3"
+ and scale_vec_size == ".scale_vec::4X"
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k32"
+ and kind == "mxf8f6f4"
+ and stype == "ue8m0"
+ and scale_vec_size in ["", ".scale_vec::1X"]
+ ):
+ return True
+
+ return False
+
+
+def common_mma_block_scale_test_gen(
+ params, op, intrinsic_template, instruction_template
+):
+ mma_block_scale_template = """
+declare ${ret_ty} @${intrinsic}(
+ ${args});
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define ${ret_ty} @test_${function}(
+ ${args}) {
+; CHECK: ${instruction}
+; CHECK-NEXT: ${check_d}
+; CHECK-NEXT: ${check_a}
+; CHECK-NEXT: ${check_b}
+; CHECK-NEXT: ${check_c}
+; CHECK-NEXT: ${check_scale_a_data}
+; CHECK-NEXT: ${check_byte_id_a}
+; CHECK-NEXT: ${check_thread_id_a}
+; CHECK-NEXT: ${check_scale_b_data}
+; CHECK-NEXT: ${check_byte_id_b}
+; CHECK-NEXT: ${check_thread_id_b}
+ %r = call ${ret_ty} @${intrinsic}(
+ ${args});
+ ret ${ret_ty} %r;
+}
+"""
+
+ test_params = params
+ test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+ test_params["function"] = test_params["intrinsic"].replace(".", "_")
+ test_params["instruction"] = Template(instruction_template).substitute(params)
+ test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
+ test_params["check_a"] = check_pattern(op.a)
+ test_params["check_b"] = check_pattern(op.b)
+ test_params["check_c"] = check_pattern(op.c)
+ test_params["check_d"] = check_pattern(op.d)
+ test_params["check_scale_a_data"] = "{{%r[0-9]+}}"
+ test_params["check_byte_id_a"] = "{{%r[0-9]+}}"
+ test_params["check_thread_id_a"] = "{{%r[0-9]+}}"
+ test_params["check_scale_b_data"] = "{{%r[0-9]+}}"
+ test_params["check_byte_id_b"] = "{{%r[0-9]+}}"
+ test_params["check_thread_id_b"] = "{{%r[0-9]+}}"
+ args = ",\n ".join(
+ list(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c))
+ + ["i32 %scale_a_data", "i16 %byte_id_a, i16 %thread_id_a"]
+ + ["i32 %scale_b_data", "i16 %byte_id_b, i16 %thread_id_b"]
+ )
+ test_params["args"] = args
+ print(Template(mma_block_scale_template).substitute(test_params))
+ return (test_params["intrinsic"], test_params["instruction"])
+
+
+def gen_mma_block_scale_tests():
+ if not (ptx_version >= 88 and gpu_arch >= 120 and aa):
+ return []
+
+ mma_block_scale_intrinsic_template = "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
+ mma_block_scale_instruction_template = "mma.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
+
+ generated_items = []
+
+ for op, kind, scale_vec_size, stype in product(
+ get_mma_block_scale_ops(),
+ ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+ ["", ".scale_vec::1X", ".scale_vec::2X", ".scale_vec::4X"],
+ ["ue8m0", "ue4m3"],
+ ):
+ if not is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+ continue
+
+ params = {
+ "intrinsic_signature": mma_signature(op),
+ "ptx_signature": mma_ptx_signature(op),
+ "geom": op.a.geom,
+ "kind": kind,
+ "scale_vec_size": scale_vec_size,
+ "scale": scale_vec_size.replace("_vec::", ".").lower(),
+ "stype": stype,
+ }
+
+ intrinsic_template = mma_block_scale_intrinsic_template
+ instruction_template = mma_block_scale_instruction_template
+
+ generated_items.append(
+ common_mma_block_scale_test_gen(
+ params, op, intrinsic_template, instruction_template
+ )
+ )
+
+ return generated_items
+
+
def get_mma_sp_ops():
return (
make_mma_ops(["m16n8k16", "m16n8k32"], ["bf16"], [], ["f32"], [], True)
@@ -1224,7 +1378,11 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf):
return True
-def sp_selector_gen(op):
+def sp_selector_gen(op, block_scale=False):
+ if block_scale:
+ # PTX ISA 9.0 has the sparsity selector equal to 0 only
+ return range(1)
+
# (geom, type) -> allowed selector range
range_01 = {
("m16n8k32", "bf16"),
@@ -1355,6 +1513,178 @@ def gen_mma_sp_tests():
return generated_items
+def get_mma_sp_block_scale_ops():
+ return make_mma_ops(["m16n8k128"], ["e2m1"], [], ["f32"], [], True) + make_mma_ops(
+ ["m16n8k64"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f32"],
+ [],
+ True,
+ )
+
+
+def is_mma_sp_block_scale_geom_supported(geom):
+ # geometries for FP.
+ if geom in [
+ "m16n8k64",
+ "m16n8k128",
+ ]:
+ return True
+ raise ValueError(f"Unexpected sparse MMA block scale geometry: {geom}")
+
+
+def is_mma_sp_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+ if not (
+ is_type_supported(op.a.mma_type.ptx_type)
+ and is_mma_sp_block_scale_geom_supported(op.a.geom)
+ ):
+ return False
+
+ if (
+ op.a.geom == "m16n8k128"
+ and kind == "mxf4"
+ and stype == "ue8m0"
+ and scale_vec_size in ["", ".scale_vec::2X"]
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k128"
+ and kind == "mxf4nvf4"
+ and stype == "ue8m0"
+ and scale_vec_size == ".scale_vec::2X"
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k128"
+ and kind == "mxf4nvf4"
+ and stype == "ue4m3"
+ and scale_vec_size == ".scale_vec::4X"
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k64"
+ and kind == "mxf8f6f4"
+ and stype == "ue8m0"
+ and scale_vec_size in ["", ".scale_vec::1X"]
+ ):
+ return True
+
+ return False
+
+
+def common_mma_sp_block_scale_test_gen(
+ params, op, intrinsic_template, instruction_template
+):
+ mma_sp_block_scale_decl_template = """
+declare ${ret_ty} @${intrinsic}(
+ ${args});
+"""
+
+ mma_sp_block_scale_test_template = """
+; CHECK-LABEL: .func {{.*}}test_${function}_${selector}(
+define ${ret_ty} @test_${function}_${selector}(
+ ${args}) {
+; CHECK: ${instruction}
+; CHECK-NEXT: ${check_d}
+; CHECK-NEXT: ${check_a}
+; CHECK-NEXT: ${check_b}
+; CHECK-NEXT: ${check_c}
+; CHECK-NEXT: ${check_metadata}
+; CHECK-NEXT: ${check_selector}
+; CHECK-NEXT: ${check_scale_a_data}
+; CHECK-NEXT: ${check_byte_id_a}
+; CHECK-NEXT: ${check_thread_id_a}
+; CHECK-NEXT: ${check_scale_b_data}
+; CHECK-NEXT: ${check_byte_id_b}
+; CHECK-NEXT: ${check_thread_id_b}
+ %r = call ${ret_ty} @${intrinsic}(
+ ${call_args});
+ ret ${ret_ty} %r;
+}
+"""
+
+ test_params = params
+ test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+ test_params["function"] = test_params["intrinsic"].replace(".", "_")
+ test_params["instruction"] = Template(instruction_template).substitute(params)
+ test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
+ test_params["check_a"] = check_pattern(op.a)
+ test_params["check_b"] = check_pattern(op.b)
+ test_params["check_c"] = check_pattern(op.c)
+ test_params["check_d"] = check_pattern(op.d)
+ test_params["check_metadata"] = "{{%r[0-9]+}}"
+ test_params["check_scale_a_data"] = "{{%r[0-9]+}}"
+ test_params["check_byte_id_a"] = "{{%r[0-9]+}}"
+ test_params["check_thread_id_a"] = "{{%r[0-9]+}}"
+ test_params["check_scale_b_data"] = "{{%r[0-9]+}}"
+ test_params["check_byte_id_b"] = "{{%r[0-9]+}}"
+ test_params["check_thread_id_b"] = "{{%r[0-9]+}}"
+ args = ",\n ".join(
+ list(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c))
+ + ["i32 %metadata", "i32 %selector"]
+ + ["i32 %scale_a_data", "i16 %byte_id_a, i16 %thread_id_a"]
+ + ["i32 %scale_b_data", "i16 %byte_id_b, i16 %thread_id_b"]
+ )
+ test_params["args"] = args
+
+ print(Template(mma_sp_block_scale_decl_template).substitute(test_params))
+
+ for selector in [str(r) for r in sp_selector_gen(op, True)]:
+ test_params["selector"] = selector
+ test_params["check_selector"] = "{{" + test_params["selector"] + "}}"
+ test_params["call_args"] = test_params["args"].replace(
+ "%selector", test_params["selector"]
+ )
+
+ print(Template(mma_sp_block_scale_test_template).substitute(test_params))
+
+ return (test_params["intrinsic"], test_params["instruction"])
+
+
+def gen_mma_sp_block_scale_tests():
+ if not (ptx_version >= 88 and gpu_arch >= 120 and aa):
+ return []
+
+ mma_sp_block_scale_intrinsic_template = "llvm.nvvm.mma.sp.ordered.metadata.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
+ mma_sp_block_scale_instruction_template = "mma.sp::ordered_metadata.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
+
+ generated_items = []
+
+ for op, kind, scale_vec_size, stype in product(
+ get_mma_sp_block_scale_ops(),
+ ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+ ["", ".scale_vec::1X", ".scale_vec::2X", ".scale_vec::4X"],
+ ["ue8m0", "ue4m3"],
+ ):
+ if not is_mma_sp_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+ continue
+
+ params = {
+ "intrinsic_signature": mma_signature(op),
+ "ptx_signature": mma_ptx_signature(op),
+ "geom": op.a.geom,
+ "kind": kind,
+ "scale_vec_size": scale_vec_size,
+ "scale": scale_vec_size.replace("_vec::", ".").lower(),
+ "stype": stype,
+ }
+
+ intrinsic_template = mma_sp_block_scale_intrinsic_template
+ instruction_template = mma_sp_block_scale_instruction_template
+
+ generated_items.append(
+ common_mma_sp_block_scale_test_gen(
+ params, op, intrinsic_template, instruction_template
+ )
+ )
+
+ return generated_items
+
+
# Append complete list of intrinsics and instructions we've generated tests for.
# Generate set of checks to verify that that we did generate sensible set of
# tests for the given combination of PTX and SM variants.
@@ -1545,7 +1875,9 @@ def gen_tests():
items += gen_stmatrix_tests()
items += gen_wmma_mma_tests()
items += gen_mma_tests()
+ items += gen_mma_block_scale_tests()
items += gen_mma_sp_tests()
+ items += gen_mma_sp_block_scale_tests()
gen_check_unsupported_ops(items)
More information about the llvm-commits
mailing list