[llvm] [NVPTX] Support for dense and sparse MMA intrinsics with block scaling. (PR #163561)
Kirill Vedernikov via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 18 22:27:33 PST 2025
https://github.com/kvederni updated https://github.com/llvm/llvm-project/pull/163561
>From 961cc0fa37b93af3e5b0a5c23f787d728b527860 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 15 Oct 2025 15:22:10 +0200
Subject: [PATCH 1/3] [NVPTX] Support for dense and sparse MMA intrinsics with
block scaling.
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 188 +++++++++++++
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 130 ++++++++-
llvm/test/CodeGen/NVPTX/wmma.py | 342 ++++++++++++++++++++++-
3 files changed, 657 insertions(+), 3 deletions(-)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 1b485dc8ccd1e..2a8d310b94065 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -277,6 +277,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
!eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
!eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
+ // mma.block_scale e2m1 (mxf4, mxf4nvf4) -> f32 @ m16n8k64
+ !eq(gft,"m16n8k64:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k64:d:f32") : !listsplat(llvm_float_ty, 4),
+
// wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
// All other supported geometries use the same fragment format for f32 and
// f16, so we only need to consider {fragment, type}.
@@ -520,6 +524,18 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, strin
# signature;
}
+class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+ WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string record_name = "int_nvvm_mma_block_scale"
+ # "_" # A.geom
+ # "_row_col"
+ # "_" # Kind
+ # !subst(".", "_", ScaleVecSize)
+ # signature
+ # "_" # SType;
+}
+
class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
WMMA_REGS A, WMMA_REGS B,
WMMA_REGS C, WMMA_REGS D> {
@@ -533,6 +549,19 @@ class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
# signature;
}
+class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+ WMMA_REGS A, WMMA_REGS B,
+ WMMA_REGS C, WMMA_REGS D> {
+ string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+ string record_name = "int_nvvm_mma_sp_ordered_metadata_block_scale"
+ # "_" # A.geom
+ # "_row_col"
+ # "_" # Kind
+ # !subst(".", "_", ScaleVecSize)
+ # signature
+ # "_" # SType;
+}
+
// Helper class that takes an intrinsic name and constructs a record name.
// Additionally, sets `intr_name` to be non-empty if the default name assigned
// to this intrinsic will not match the name given.
@@ -683,6 +712,18 @@ class NVVM_MMA_OPS {
fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
int_mma_ops, subint_mma_ops, bit_mma_ops);
+ list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
+ ["m16n8k64"], ["e2m1"], ["e2m1"], ["f32"], ["f32"]
+ >.ret;
+
+ list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
+ ["m16n8k32"], ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+ ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], ["f32"], ["f32"]
+ >.ret;
+
+ list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
+ mxf4_mma_ops, mxf8f6f4_mma_ops);
+
list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
["m16n8k16", "m16n8k32"],
["bf16"], [], ["f32"], [], true>.ret;
@@ -707,6 +748,18 @@ class NVVM_MMA_OPS {
bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
subint_mma_sp_ops, int_mma_sp_ops);
+ // Combines the available geometries and types for the mxf4 and mxf4nvf4 kinds.
+ list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
+ ["m16n8k128"],
+ ["e2m1"], ["e2m1"], ["f32"], [], true>.ret;
+ list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
+ ["m16n8k64"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f32"], [], true>.ret;
+ list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
+ mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
+
list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
["m16n16k16", "m32n8k16", "m8n32k16"],
["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
@@ -900,6 +953,32 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
);
}
+class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind, string stype, string scale_vec_size> {
+ string geom = frags[0].geom;
+
+ bit ret = !cond(
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_2x")),
+ !eq(stype, "ue8m0")) : true,
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(scale_vec_size, ".scale_2x"),
+ !eq(stype, "ue8m0")) : true,
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(scale_vec_size, ".scale_4x"),
+ !eq(stype, "ue4m3")) : true,
+ !and(!eq(geom, "m16n8k32"),
+ !eq(kind, "mxf8f6f4"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_1x")),
+ !eq(stype, "ue8m0")) : true,
+ true: false
+ );
+}
+
// Returns true if the fragment is valid for ldmatrix ops; false otherwise.
// E.g.
@@ -998,6 +1077,51 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
}
+// Returns true if this combination of kind/scale_vec_size/stype is supported
+// for MMA.SP block scale ops; false otherwise.
+// E.g.
+// if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<...>.ret then
+// def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+ string stype, string scale_vec_size> {
+ // MMA.SP ops check both layouts.
+ string a_type = frags[0].ptx_elt_type;
+ string b_type = frags[1].ptx_elt_type;
+ string c_type = frags[2].ptx_elt_type;
+ string d_type = frags[3].ptx_elt_type;
+ string geom = frags[0].geom;
+
+ bit ret = !cond(
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4"),
+ !eq(stype, "ue8m0"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_2x"))): true,
+
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue8m0"),
+ !eq(scale_vec_size, ".scale_2x")): true,
+
+ !and(!eq(geom, "m16n8k128"),
+ !eq(kind, "mxf4nvf4"),
+ !eq(stype, "ue4m3"),
+ !eq(scale_vec_size, ".scale_4x")): true,
+
+ !and(!eq(geom, "m16n8k64"),
+ !eq(kind, "mxf8f6f4"),
+ !eq(stype, "ue8m0"),
+ !or(!eq(scale_vec_size, ""),
+ !eq(scale_vec_size, ".scale_1x"))): true,
+
+ // All other combinations are not supported.
+ true: false
+ );
+}
+
+
class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
string Suffix = !if(sync, "sync_", "")
# mode # "_"
@@ -2415,6 +2539,31 @@ foreach layout_a = ["row", "col"] in {
} // layout_b
} // layout_a
+class NVVM_MMA_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+ : Intrinsic<D.regs,
+ !listconcat(A.regs, B.regs, C.regs,
+ [
+ llvm_i32_ty, // scale-a-data
+ llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
+ llvm_i32_ty, // scale-b-data,
+ llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
+ ]),
+ [IntrNoMem, IntrNoCallback]>;
+
+foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
+ if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
+ op[0], op[1], op[2], op[3]>.record_name
+ : NVVM_MMA_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+} // kind
+
// MMA.SP
class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
: Intrinsic<D.regs,
@@ -2462,6 +2611,45 @@ foreach metadata = ["sp", "sp::ordered_metadata"] in {
} // kind
} // metadata
+// MMA.SP BLOCK SCALE
+class NVVM_MMA_SP_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+ : Intrinsic<D.regs,
+ !listconcat(A.regs, B.regs, C.regs,
+ [
+ llvm_i32_ty, // metadata
+ llvm_i32_ty, // sparsity selector
+ llvm_i32_ty, // scale-a-data
+ llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
+ llvm_i32_ty, // scale-b-data
+ llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
+ ])> {
+ int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
+
+ // The range [0, num_threads) constrains the sparsity selector, which indicates
+ // the threads that contribute metadata.
+ // According to PTX ISA 9.0, the sparsity selector is always 0
+ // for sparse MMA block scale instructions.
+ int num_threads = 1;
+ let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
+ Range<ArgIndex<pos>, 0, num_threads>];
+}
+
+// According to PTX ISA 9.0
+// a_layout = ["row"], b_layout = ["col"], spvariant = ["sp::ordered_metadata"]
+foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
+ if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
+ op[0], op[1], op[2], op[3]>.record_name
+ : NVVM_MMA_SP_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+} // kind
+
// LDMATRIX
class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
: Intrinsic<Frag.regs, [llvm_anyptr_ty],
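For reference, the combinations accepted by NVVM_MMA_BLOCK_SCALE_SUPPORTED and
NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED above collapse to a small table. The Python
sketch below restates that table; the helper name and dict layout are invented
for this note, only the combinations themselves come from the TableGen
predicates above.

# Supported (geom, kind, stype) -> scale_vec_size combinations. Dense block-scale
# MMA uses m16n8k64 (mxf4/mxf4nvf4) and m16n8k32 (mxf8f6f4); the sparse
# sp::ordered_metadata variants use m16n8k128 and m16n8k64 respectively.
SUPPORTED_BLOCK_SCALE = {
    # dense
    ("m16n8k64", "mxf4", "ue8m0"): {"", ".scale_2x"},
    ("m16n8k64", "mxf4nvf4", "ue8m0"): {".scale_2x"},
    ("m16n8k64", "mxf4nvf4", "ue4m3"): {".scale_4x"},
    ("m16n8k32", "mxf8f6f4", "ue8m0"): {"", ".scale_1x"},
    # sparse (sp::ordered_metadata)
    ("m16n8k128", "mxf4", "ue8m0"): {"", ".scale_2x"},
    ("m16n8k128", "mxf4nvf4", "ue8m0"): {".scale_2x"},
    ("m16n8k128", "mxf4nvf4", "ue4m3"): {".scale_4x"},
    ("m16n8k64", "mxf8f6f4", "ue8m0"): {"", ".scale_1x"},
}

def block_scale_supported(geom, kind, stype, scale_vec_size):
    # Everything not listed above is rejected, mirroring the `true: false` arm.
    return scale_vec_size in SUPPORTED_BLOCK_SCALE.get((geom, kind, stype), set())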
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index ea69a54e6db37..5ef13b3be6162 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -5022,6 +5022,67 @@ defset list<WMMA_INSTR> MMAs = {
} // defset
}
+// MMA.block_scale
+class MMA_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string Kind, string SType, string ScaleVecSize>
+ : WMMA_INSTR<MMA_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+ FragA, FragB, FragC, FragD>.record,
+ [FragA.Ins, FragB.Ins, FragC.Ins,
+ (ins B32:$scale_a, B16:$byte_id_a,
+ B16:$thread_id_a, B32:$scale_b,
+ B16:$byte_id_b, B16:$thread_id_b)]>,
+ // Requires does not seem to have an effect on an Instruction without Patterns.
+ // We set it here anyway and propagate it to the Pat<> we construct below.
+ Requires<FragA.Predicates> {
+ let OutOperandList = FragD.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ string TypeList = !interleave([FragD.ptx_elt_type,
+ FragA.ptx_elt_type,
+ FragB.ptx_elt_type,
+ FragC.ptx_elt_type], ".");
+ string ScaleVecSizeStr = !cond(
+ !eq(ScaleVecSize, "") : "",
+ !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+ !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+ !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+ );
+ let AsmString = "mma.sync.aligned."
+ # FragA.geom
+ # ".row.col"
+ # ".kind::" # Kind
+ # ".block_scale"
+ # ScaleVecSizeStr
+ # "." # TypeList
+ # "." # SType # "\n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ",\n\t\t"
+ # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+ # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_BLOCK_SCALEs = {
+ foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
+ if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def : MMA_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.block_scale", "", kind>,
+ WMMA_REGINFO<op[1], "mma.block_scale", "", kind>,
+ WMMA_REGINFO<op[2], "mma.block_scale", "", kind>,
+ WMMA_REGINFO<op[3], "mma.block_scale", "", kind>,
+ kind, stype, scale_vec_size>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+ } // kind
+} // defset
+}
+
// MMA SP
class MMA_SP<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
@@ -5078,6 +5139,72 @@ defset list<WMMA_INSTR> MMA_SPs = {
} // defset
}
+// MMA SP BLOCK SCALE
+class MMA_SP_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+ WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+ string Kind, string SType, string ScaleVecSize>
+ : WMMA_INSTR<MMA_SP_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+ FragA, FragB, FragC, FragD>.record,
+ [FragA.Ins, FragB.Ins, FragC.Ins,
+ (ins B32:$metadata, i32imm:$selector,
+ B32:$scale_a, B16:$byte_id_a, B16:$thread_id_a,
+ B32:$scale_b, B16:$byte_id_b, B16:$thread_id_b)]>,
+ // Requires does not seem to have an effect on an Instruction without Patterns.
+ // We set it here anyway and propagate it to the Pat<> we construct below.
+ Requires<!listconcat(FragA.Predicates,
+ FragB.Predicates,
+ FragC.Predicates,
+ FragD.Predicates)> {
+ let OutOperandList = FragD.Outs;
+ let InOperandList = !con(Args, (ins MmaCode:$ptx));
+ string TypeList = "." # FragD.ptx_elt_type
+ # "." # FragA.ptx_elt_type
+ # "." # FragB.ptx_elt_type
+ # "." # FragC.ptx_elt_type;
+ string ScaleVecSizeStr = !cond(
+ !eq(ScaleVecSize, "") : "",
+ !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+ !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+ !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+ );
+ let AsmString = "mma.sp::ordered_metadata.sync.aligned."
+ # FragA.geom
+ # ".row.col"
+ # ".kind::" # Kind
+ # ".block_scale"
+ # ScaleVecSizeStr
+ # TypeList
+ # "." # SType # "\n\t\t"
+ # FragD.regstring # ",\n\t\t"
+ # FragA.regstring # ",\n\t\t"
+ # FragB.regstring # ",\n\t\t"
+ # FragC.regstring # ",\n\t\t"
+ # "$metadata" # ",\n\t\t"
+ # "$selector" # ",\n\t\t"
+ # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+ # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_SP_BLOCK_SCALEs = {
+ foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+ foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+ foreach stype = ["ue8m0", "ue4m3"] in {
+ foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
+ if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+ def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[1], "mma.sp", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[2], "mma.sp", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[3], "mma.sp", "sp::ordered_metadata", kind>,
+ kind, stype, scale_vec_size>;
+ }
+ } // op
+ } // stype
+ } // scale_vec_size
+ } // kind
+} // defset
+}
+
//
// ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
//
@@ -5159,7 +5286,8 @@ class MMA_PAT<WMMA_INSTR wi>
Requires<wi.Predicates>;
// Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs, MMA_SPs) in
+foreach mma = !listconcat(MMAs, MMA_BLOCK_SCALEs, WMMAs, MMA_LDSTs, LDMATRIXs,
+ STMATRIXs, MMA_SPs, MMA_SP_BLOCK_SCALEs) in
def : MMA_PAT<mma>;
multiclass MAPA<string suffix, Intrinsic Intr> {
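The scale_vec_size token used in the record names above maps onto the PTX
.scale_vec:: modifier when the mnemonic is assembled. A minimal Python sketch of
that assembly, mirroring the AsmString construction in MMA_BLOCK_SCALE above
(the function name is invented for this note):

SCALE_VEC_MODIFIER = {
    "": "",
    ".scale_1x": ".scale_vec::1X",
    ".scale_2x": ".scale_vec::2X",
    ".scale_4x": ".scale_vec::4X",
}

def dense_block_scale_mnemonic(geom, kind, scale_vec_size, type_list, stype):
    # type_list is "<d>.<a>.<b>.<c>", e.g. "f32.e2m1.e2m1.f32".
    return ("mma.sync.aligned." + geom + ".row.col.kind::" + kind
            + ".block_scale" + SCALE_VEC_MODIFIER[scale_vec_size]
            + "." + type_list + "." + stype)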
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 8427ae4ad72da..81c78219075f3 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -131,7 +131,7 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
"m16n8k64:b:e5m2": 4,
"m16n8k64:b:e3m2": 4,
"m16n8k64:b:e2m3": 4,
- "m16n8k64:b:e2m1": 4,
+ "m16n8k64:b:e2m1": 4 if is_mma_sparse else 2,
"m16n8k64:c:f16": 2,
"m16n8k64:c:f32": 4,
"m16n8k64:d:f16": 2,
@@ -1131,6 +1131,163 @@ def gen_mma_tests():
return generated_items
+def get_mma_block_scale_ops():
+ return (
+ make_mma_ops(["m16n8k64"], ["e2m1"], [], ["f32"], [])
+ + make_mma_ops(
+ ["m16n8k32"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f32"],
+ [],
+ )
+ )
+
+
+def is_mma_block_scale_geom_supported(geom):
+ # geometries for FP.
+ if geom in [
+ "m16n8k32",
+ "m16n8k64",
+ ]:
+ return True
+ raise ValueError(f"Unexpected MMA block scale geometry: {geom}")
+
+
+def is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+ if not (
+ is_type_supported(op.a.mma_type.ptx_type)
+ and is_mma_block_scale_geom_supported(op.a.geom)
+ ):
+ return False
+
+ if (
+ op.a.geom == "m16n8k64"
+ and kind == "mxf4"
+ and stype == "ue8m0"
+ and scale_vec_size in ["", ".scale_vec::2X"]
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k64"
+ and kind == "mxf4nvf4"
+ and stype == "ue8m0"
+ and scale_vec_size == ".scale_vec::2X"
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k64"
+ and kind == "mxf4nvf4"
+ and stype == "ue4m3"
+ and scale_vec_size == ".scale_vec::4X"
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k32"
+ and kind == "mxf8f6f4"
+ and stype == "ue8m0"
+ and scale_vec_size in ["", ".scale_vec::1X"]
+ ):
+ return True
+
+ return False
+
+
+def common_mma_block_scale_test_gen(params, op, intrinsic_template, instruction_template):
+ mma_block_scale_template = """
+declare ${ret_ty} @${intrinsic}(
+ ${args});
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define ${ret_ty} @test_${function}(
+ ${args}) {
+; CHECK: ${instruction}
+; CHECK-NEXT: ${check_d}
+; CHECK-NEXT: ${check_a}
+; CHECK-NEXT: ${check_b}
+; CHECK-NEXT: ${check_c}
+; CHECK-NEXT: ${check_scale_a_data}
+; CHECK-NEXT: ${check_byte_id_a}
+; CHECK-NEXT: ${check_thread_id_a}
+; CHECK-NEXT: ${check_scale_b_data}
+; CHECK-NEXT: ${check_byte_id_b}
+; CHECK-NEXT: ${check_thread_id_b}
+ %r = call ${ret_ty} @${intrinsic}(
+ ${args});
+ ret ${ret_ty} %r;
+}
+"""
+
+ test_params = params
+ test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+ test_params["function"] = test_params["intrinsic"].replace(".", "_")
+ test_params["instruction"] = Template(instruction_template).substitute(params)
+ test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
+ test_params["check_a"] = check_pattern(op.a)
+ test_params["check_b"] = check_pattern(op.b)
+ test_params["check_c"] = check_pattern(op.c)
+ test_params["check_d"] = check_pattern(op.d)
+ test_params["check_scale_a_data"] = "{{%r[0-9]+}}"
+ test_params["check_byte_id_a"] = "{{%r[0-9]+}}"
+ test_params["check_thread_id_a"] = "{{%r[0-9]+}}"
+ test_params["check_scale_b_data"] = "{{%r[0-9]+}}"
+ test_params["check_byte_id_b"] = "{{%r[0-9]+}}"
+ test_params["check_thread_id_b"] = "{{%r[0-9]+}}"
+ args = ",\n ".join(
+ list(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c))
+ + ["i32 %scale_a_data", "i16 %byte_id_a, i16 %thread_id_a"]
+ + ["i32 %scale_b_data", "i16 %byte_id_b, i16 %thread_id_b"]
+ )
+ test_params["args"] = args
+ print(Template(mma_block_scale_template).substitute(test_params))
+ return (test_params["intrinsic"], test_params["instruction"])
+
+
+def gen_mma_block_scale_tests():
+ if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
+ return []
+
+ mma_block_scale_intrinsic_template = (
+ "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
+ )
+ mma_block_scale_instruction_template = (
+ "mma.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
+ )
+
+ generated_items = []
+
+ for op, kind, scale_vec_size, stype in product(
+ get_mma_block_scale_ops(),
+ ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+ ["", ".scale_vec::1X", ".scale_vec::2X", ".scale_vec::4X"],
+ ["ue8m0", "ue4m3"],
+ ):
+ if not is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+ continue
+
+ params = {
+ "intrinsic_signature": mma_signature(op),
+ "ptx_signature": mma_ptx_signature(op),
+ "geom": op.a.geom,
+ "kind": kind,
+ "scale_vec_size": scale_vec_size,
+ "scale": scale_vec_size.replace("_vec::", ".").lower(),
+ "stype": stype,
+ }
+
+ intrinsic_template = mma_block_scale_intrinsic_template
+ instruction_template = mma_block_scale_instruction_template
+
+ generated_items.append(
+ common_mma_block_scale_test_gen(params, op, intrinsic_template, instruction_template)
+ )
+
+ return generated_items
+
+
def get_mma_sp_ops():
return (
make_mma_ops(["m16n8k16", "m16n8k32"], ["bf16"], [], ["f32"], [], True)
@@ -1224,7 +1381,11 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf):
return True
-def sp_selector_gen(op):
+def sp_selector_gen(op, block_scale = False):
+ if block_scale:
+ # PTX ISA 9.0 only allows a sparsity selector of 0 for block-scale sparse MMA.
+ return range(1)
+
# (geom, type) -> allowed selector range
range_01 = {
("m16n8k32", "bf16"),
@@ -1355,6 +1516,181 @@ def gen_mma_sp_tests():
return generated_items
+def get_mma_sp_block_scale_ops():
+ return (
+ make_mma_ops(["m16n8k128"], ["e2m1"], [], ["f32"], [], True)
+ + make_mma_ops(
+ ["m16n8k64"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f32"],
+ [],
+ True,
+ )
+ )
+
+
+def is_mma_sp_block_scale_geom_supported(geom):
+ # geometries for FP.
+ if geom in [
+ "m16n8k64",
+ "m16n8k128",
+ ]:
+ return True
+ raise ValueError(f"Unexpected sparse MMA block scale geometry: {geom}")
+
+
+def is_mma_sp_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+ if not (
+ is_type_supported(op.a.mma_type.ptx_type)
+ and is_mma_sp_block_scale_geom_supported(op.a.geom)
+ ):
+ return False
+
+ if (
+ op.a.geom == "m16n8k128"
+ and kind == "mxf4"
+ and stype == "ue8m0"
+ and scale_vec_size in ["", ".scale_vec::2X"]
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k128"
+ and kind == "mxf4nvf4"
+ and stype == "ue8m0"
+ and scale_vec_size == ".scale_vec::2X"
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k128"
+ and kind == "mxf4nvf4"
+ and stype == "ue4m3"
+ and scale_vec_size == ".scale_vec::4X"
+ ):
+ return True
+
+ if (
+ op.a.geom == "m16n8k64"
+ and kind == "mxf8f6f4"
+ and stype == "ue8m0"
+ and scale_vec_size in ["", ".scale_vec::1X"]
+ ):
+ return True
+
+ return False
+
+
+def common_mma_sp_block_scale_test_gen(params, op, intrinsic_template, instruction_template):
+ mma_sp_block_scale_decl_template = """
+declare ${ret_ty} @${intrinsic}(
+ ${args});
+"""
+
+ mma_sp_block_scale_test_template = """
+; CHECK-LABEL: .func {{.*}}test_${function}_${selector}(
+define ${ret_ty} @test_${function}_${selector}(
+ ${args}) {
+; CHECK: ${instruction}
+; CHECK-NEXT: ${check_d}
+; CHECK-NEXT: ${check_a}
+; CHECK-NEXT: ${check_b}
+; CHECK-NEXT: ${check_c}
+; CHECK-NEXT: ${check_metadata}
+; CHECK-NEXT: ${check_selector}
+; CHECK-NEXT: ${check_scale_a_data}
+; CHECK-NEXT: ${check_byte_id_a}
+; CHECK-NEXT: ${check_thread_id_a}
+; CHECK-NEXT: ${check_scale_b_data}
+; CHECK-NEXT: ${check_byte_id_b}
+; CHECK-NEXT: ${check_thread_id_b}
+ %r = call ${ret_ty} @${intrinsic}(
+ ${call_args});
+ ret ${ret_ty} %r;
+}
+"""
+
+ test_params = params
+ test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+ test_params["function"] = test_params["intrinsic"].replace(".", "_")
+ test_params["instruction"] = Template(instruction_template).substitute(params)
+ test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
+ test_params["check_a"] = check_pattern(op.a)
+ test_params["check_b"] = check_pattern(op.b)
+ test_params["check_c"] = check_pattern(op.c)
+ test_params["check_d"] = check_pattern(op.d)
+ test_params["check_metadata"] = "{{%r[0-9]+}}"
+ test_params["check_scale_a_data"] = "{{%r[0-9]+}}"
+ test_params["check_byte_id_a"] = "{{%r[0-9]+}}"
+ test_params["check_thread_id_a"] = "{{%r[0-9]+}}"
+ test_params["check_scale_b_data"] = "{{%r[0-9]+}}"
+ test_params["check_byte_id_b"] = "{{%r[0-9]+}}"
+ test_params["check_thread_id_b"] = "{{%r[0-9]+}}"
+ args = ",\n ".join(
+ list(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c))
+ + ["i32 %metadata", "i32 %selector"]
+ + ["i32 %scale_a_data", "i16 %byte_id_a, i16 %thread_id_a"]
+ + ["i32 %scale_b_data", "i16 %byte_id_b, i16 %thread_id_b"]
+ )
+ test_params["args"] = args
+
+ print(Template(mma_sp_block_scale_decl_template).substitute(test_params))
+
+ for selector in [str(r) for r in sp_selector_gen(op, True)]:
+ test_params["selector"] = selector
+ test_params["check_selector"] = "{{" + test_params["selector"] + "}}"
+ test_params["call_args"] = test_params["args"].replace(
+ "%selector", test_params["selector"]
+ )
+
+ print(Template(mma_sp_block_scale_test_template).substitute(test_params))
+
+ return (test_params["intrinsic"], test_params["instruction"])
+
+
+def gen_mma_sp_block_scale_tests():
+ if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
+ return []
+
+ mma_sp_block_scale_intrinsic_template = (
+ "llvm.nvvm.mma.sp.ordered.metadata.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
+ )
+ mma_sp_block_scale_instruction_template = (
+ "mma.sp::ordered_metadata.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
+ )
+
+ generated_items = []
+
+ for op, kind, scale_vec_size, stype in product(
+ get_mma_sp_block_scale_ops(),
+ ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+ ["", ".scale_vec::1X", ".scale_vec::2X", ".scale_vec::4X"],
+ ["ue8m0", "ue4m3"],
+ ):
+ if not is_mma_sp_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+ continue
+
+ params = {
+ "intrinsic_signature": mma_signature(op),
+ "ptx_signature": mma_ptx_signature(op),
+ "geom": op.a.geom,
+ "kind": kind,
+ "scale_vec_size": scale_vec_size,
+ "scale": scale_vec_size.replace("_vec::", ".").lower(),
+ "stype": stype,
+ }
+
+ intrinsic_template = mma_sp_block_scale_intrinsic_template
+ instruction_template = mma_sp_block_scale_instruction_template
+
+ generated_items.append(
+ common_mma_sp_block_scale_test_gen(params, op, intrinsic_template, instruction_template)
+ )
+
+ return generated_items
+
+
# Append the complete list of intrinsics and instructions we've generated tests for.
# Generate a set of checks to verify that we did generate a sensible set of
# tests for the given combination of PTX and SM variants.
@@ -1545,7 +1881,9 @@ def gen_tests():
items += gen_stmatrix_tests()
items += gen_wmma_mma_tests()
items += gen_mma_tests()
+ items += gen_mma_block_scale_tests()
items += gen_mma_sp_tests()
+ items += gen_mma_sp_block_scale_tests()
gen_check_unsupported_ops(items)
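Taken together, for the dense mxf4 case the generator above should emit
intrinsic/instruction pairs along the lines of the sketch below. The signature
fields come from the pre-existing mma_signature()/mma_ptx_signature() helpers,
which are not part of this patch, so the values used here are assumed.

from string import Template

intrinsic_template = (
    "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}"
    ".${intrinsic_signature}.${stype}"
)
instruction_template = (
    "mma.sync.aligned.${geom}.row.col.kind::${kind}.block_scale"
    "${scale_vec_size}.${ptx_signature}.${stype}"
)

params = {
    "geom": "m16n8k64",
    "kind": "mxf4",
    "scale_vec_size": ".scale_vec::2X",
    "scale": ".scale.2x",  # scale_vec_size with "_vec::" -> "." and lower-cased
    "stype": "ue8m0",
    # Assumed outputs of mma_signature()/mma_ptx_signature() for e2m1 inputs
    # and f32 accumulators.
    "intrinsic_signature": "e2m1",
    "ptx_signature": "f32.e2m1.e2m1.f32",
}

# llvm.nvvm.mma.block.scale.m16n8k64.row.col.mxf4.scale.2x.e2m1.ue8m0
print(Template(intrinsic_template).substitute(params))
# mma.sync.aligned.m16n8k64.row.col.kind::mxf4.block_scale.scale_vec::2X.f32.e2m1.e2m1.f32.ue8m0
print(Template(instruction_template).substitute(params))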
>From 7fb0570f71c605c8a1f14aa5c2bc62b27cb7552f Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 15 Oct 2025 16:01:23 +0200
Subject: [PATCH 2/3] [NVPTX] Fixes for code formatting. PR163561.
---
llvm/test/CodeGen/NVPTX/wmma.py | 66 +++++++++++++++------------------
1 file changed, 30 insertions(+), 36 deletions(-)
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 81c78219075f3..c2e69dd48da99 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -1132,15 +1132,12 @@ def gen_mma_tests():
def get_mma_block_scale_ops():
- return (
- make_mma_ops(["m16n8k64"], ["e2m1"], [], ["f32"], [])
- + make_mma_ops(
- ["m16n8k32"],
- ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
- ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
- ["f32"],
- [],
- )
+ return make_mma_ops(["m16n8k64"], ["e2m1"], [], ["f32"], []) + make_mma_ops(
+ ["m16n8k32"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f32"],
+ [],
)
@@ -1196,7 +1193,9 @@ def is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype):
return False
-def common_mma_block_scale_test_gen(params, op, intrinsic_template, instruction_template):
+def common_mma_block_scale_test_gen(
+ params, op, intrinsic_template, instruction_template
+):
mma_block_scale_template = """
declare ${ret_ty} @${intrinsic}(
${args});
@@ -1250,12 +1249,8 @@ def gen_mma_block_scale_tests():
if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
return []
- mma_block_scale_intrinsic_template = (
- "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
- )
- mma_block_scale_instruction_template = (
- "mma.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
- )
+ mma_block_scale_intrinsic_template = "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
+ mma_block_scale_instruction_template = "mma.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
generated_items = []
@@ -1282,7 +1277,9 @@ def gen_mma_block_scale_tests():
instruction_template = mma_block_scale_instruction_template
generated_items.append(
- common_mma_block_scale_test_gen(params, op, intrinsic_template, instruction_template)
+ common_mma_block_scale_test_gen(
+ params, op, intrinsic_template, instruction_template
+ )
)
return generated_items
@@ -1381,7 +1378,7 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf):
return True
-def sp_selector_gen(op, block_scale = False):
+def sp_selector_gen(op, block_scale=False):
if block_scale:
# PTX ISA 9.0 only allows a sparsity selector of 0 for block-scale sparse MMA.
return range(1)
@@ -1517,16 +1514,13 @@ def gen_mma_sp_tests():
def get_mma_sp_block_scale_ops():
- return (
- make_mma_ops(["m16n8k128"], ["e2m1"], [], ["f32"], [], True)
- + make_mma_ops(
- ["m16n8k64"],
- ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
- ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
- ["f32"],
- [],
- True,
- )
+ return make_mma_ops(["m16n8k128"], ["e2m1"], [], ["f32"], [], True) + make_mma_ops(
+ ["m16n8k64"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f32"],
+ [],
+ True,
)
@@ -1582,7 +1576,9 @@ def is_mma_sp_block_scale_variant_supported(op, kind, scale_vec_size, stype):
return False
-def common_mma_sp_block_scale_test_gen(params, op, intrinsic_template, instruction_template):
+def common_mma_sp_block_scale_test_gen(
+ params, op, intrinsic_template, instruction_template
+):
mma_sp_block_scale_decl_template = """
declare ${ret_ty} @${intrinsic}(
${args});
@@ -1653,12 +1649,8 @@ def gen_mma_sp_block_scale_tests():
if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
return []
- mma_sp_block_scale_intrinsic_template = (
- "llvm.nvvm.mma.sp.ordered.metadata.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
- )
- mma_sp_block_scale_instruction_template = (
- "mma.sp::ordered_metadata.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
- )
+ mma_sp_block_scale_intrinsic_template = "llvm.nvvm.mma.sp.ordered.metadata.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
+ mma_sp_block_scale_instruction_template = "mma.sp::ordered_metadata.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
generated_items = []
@@ -1685,7 +1677,9 @@ def gen_mma_sp_block_scale_tests():
instruction_template = mma_sp_block_scale_instruction_template
generated_items.append(
- common_mma_sp_block_scale_test_gen(params, op, intrinsic_template, instruction_template)
+ common_mma_sp_block_scale_test_gen(
+ params, op, intrinsic_template, instruction_template
+ )
)
return generated_items
>From 93994eddbcb9079a4d32a1df0cce801b0f54bb11 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Tue, 18 Nov 2025 20:04:01 +0100
Subject: [PATCH 3/3] [NVPTX] Resolved merge conflicts + updated check for PTX
version
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 3 ++-
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 22 ++++++++++++--------
llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py | 12 +++++++++++
llvm/test/CodeGen/NVPTX/wmma.py | 4 ++--
4 files changed, 29 insertions(+), 12 deletions(-)
create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 2a8d310b94065..ff4581b76bfcd 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -178,7 +178,8 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
string gft = Geom#":"#Frag#":"#ptx_elt_type;
string gf = Geom#":"#Frag;
string ft = frag#":"#ptx_elt_type;
- list<LLVMType> regs = !if(!eq(IsSparse, true),
+ bit isSparse = IsSparse;
+ list<LLVMType> regs = !if(!eq(isSparse, true),
!cond(
// mma sparse ops use other fragments for some arguments
!eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 2),
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 5ef13b3be6162..34d1e312299c8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4619,7 +4619,8 @@ def INT_PTX_SREG_WARPSIZE :
// the fields commonly used to implement specific PTX instruction -- register
// types and names, constraints, parts of assembly, etc.
class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "">
- : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type, !eq(op, "mma.sp")> {
+ : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type,
+ !or(!eq(op, "mma.sp"), !eq(op, "mma.sp.block_scale"))> {
// NVPTX register types used to carry fragment data.
NVPTXRegClass regclass = !cond(
!eq(ptx_elt_type, "e4m3") : B32,
@@ -4659,6 +4660,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
// longer the case, we can concat all per-fragment predicates to enforce that
// all fragments of the instruction are viable.
list<Predicate> Predicates = !cond(
+ !or(!eq(op, "mma.block_scale"),
+ !eq(op, "mma.sp.block_scale")) : [hasSM120a, hasPTX<88>],
+
!or(!eq(ptx_elt_type, "e3m2"),
!eq(ptx_elt_type, "e2m3"),
!eq(ptx_elt_type, "e2m1"),
@@ -4671,9 +4675,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
!or(!eq(ptx_elt_type, "e4m3"),
!eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
- !and(!eq(op, "mma.sp"),
+ !and(isSparse,
!ne(metadata, "sp")) : [hasSM<80>, hasPTX<85>],
- !eq(op, "mma.sp") : [hasSM<80>, hasPTX<71>],
+ isSparse : [hasSM<80>, hasPTX<71>],
// fp16 -> fp16/fp32 @ m16n16k16
!and(!eq(geom, "m16n16k16"),
@@ -5027,7 +5031,7 @@ class MMA_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
string Kind, string SType, string ScaleVecSize>
: WMMA_INSTR<MMA_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
- FragA, FragB, FragC, FragD>.record,
+ FragA, FragB, FragC, FragD>.record_name,
[FragA.Ins, FragB.Ins, FragC.Ins,
(ins B32:$scale_a, B16:$byte_id_a,
B16:$thread_id_a, B32:$scale_b,
@@ -5144,7 +5148,7 @@ class MMA_SP_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
string Kind, string SType, string ScaleVecSize>
: WMMA_INSTR<MMA_SP_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
- FragA, FragB, FragC, FragD>.record,
+ FragA, FragB, FragC, FragD>.record_name,
[FragA.Ins, FragB.Ins, FragC.Ins,
(ins B32:$metadata, i32imm:$selector,
B32:$scale_a, B16:$byte_id_a, B16:$thread_id_a,
@@ -5192,10 +5196,10 @@ defset list<WMMA_INSTR> MMA_SP_BLOCK_SCALEs = {
foreach stype = ["ue8m0", "ue4m3"] in {
foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
- def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp", "sp::ordered_metadata", kind>,
- WMMA_REGINFO<op[1], "mma.sp", "sp::ordered_metadata", kind>,
- WMMA_REGINFO<op[2], "mma.sp", "sp::ordered_metadata", kind>,
- WMMA_REGINFO<op[3], "mma.sp", "sp::ordered_metadata", kind>,
+ def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[1], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[2], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+ WMMA_REGINFO<op[3], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
kind, stype, scale_vec_size>;
}
} // op
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py
new file mode 100644
index 0000000000000..f1666dbc5f30f
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py
@@ -0,0 +1,12 @@
+# Check all variants of instructions supported by PTX88 on SM120a
+# RUN: %python %s --ptx=88 --gpu-arch=120 --aa > %t-ptx88-sm_120a.ll
+# RUN: llc < %t-ptx88-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx88 \
+# RUN: | FileCheck %t-ptx88-sm_120a.ll
+# RUN: %if ptxas-sm_120a && ptxas-isa-8.8 %{ \
+# RUN: llc < %t-ptx88-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx88 \
+# RUN: | %ptxas-verify -arch=sm_120a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index c2e69dd48da99..817665a68d7a5 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -1246,7 +1246,7 @@ def common_mma_block_scale_test_gen(
def gen_mma_block_scale_tests():
- if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
+ if not (ptx_version >= 88 and gpu_arch >= 120 and aa):
return []
mma_block_scale_intrinsic_template = "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
@@ -1646,7 +1646,7 @@ def common_mma_sp_block_scale_test_gen(
def gen_mma_sp_block_scale_tests():
- if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
+ if not (ptx_version >= 88 and gpu_arch >= 120 and aa):
return []
mma_sp_block_scale_intrinsic_template = "llvm.nvvm.mma.sp.ordered.metadata.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
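The new block-scale paths are gated on sm_120a together with PTX ISA 8.8,
matching the [hasSM120a, hasPTX<88>] predicates added to WMMA_REGINFO. A minimal
sketch of the equivalent guard on the test-generator side (parameter names
follow the globals in wmma.py):

def block_scale_tests_enabled(ptx_version, gpu_arch, aa):
    # Mirrors gen_mma_block_scale_tests()/gen_mma_sp_block_scale_tests():
    # PTX >= 8.8, sm_120 or newer, and the architecture-accelerated ("a") variant.
    return ptx_version >= 88 and gpu_arch >= 120 and aa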