[Mlir-commits] [llvm] [mlir] [NVPTX] Added more MMA intrinsics for F8F6F4 and FP64 types. (PR #156040)
Kirill Vedernikov
llvmlistbot@llvm.org
Fri Sep 26 05:01:50 PDT 2025
https://github.com/kvederni updated https://github.com/llvm/llvm-project/pull/156040
From 2ef12326f55dee4d283277b8655a6057329ef0ab Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov@nvidia.com>
Date: Fri, 29 Aug 2025 16:44:59 +0200
Subject: [PATCH 1/5] [NVPTX] Added more MMA intrinsics for F8F6F4 and FP64
types. [NVPTX] Added restrictions for dtype/ctype combinations. [MLIR]
Aligned MMA restrictions with NVVM IR.
The MMA instructions are described in PTX ISA 9.0 at https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma
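As a concrete illustration, here is a hedged sketch of one of the new FP64 intrinsics at the LLVM IR level. The name assumes the m16n8k16 variant follows the same convention as the existing llvm.nvvm.mma.m8n8k4.row.col.f64, and the operand counts follow the fragment table added below (A = 8, B = 4, C = 4 doubles):

; Sketch only: prototype inferred from the new WMMA_REGS entries.
; Expected PTX: mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64
declare { double, double, double, double }
  @llvm.nvvm.mma.m16n8k16.row.col.f64(
    double, double, double, double, double, double, double, double, ; A fragment (8)
    double, double, double, double,                                 ; B fragment (4)
    double, double, double, double)                                 ; C fragment (4)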
---
llvm/include/llvm/IR/IntrinsicsNVVM.td | 102 +++++++++++++++--
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 30 +++--
llvm/test/CodeGen/NVPTX/wmma.py | 115 ++++++++++++++++++--
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 5 +-
mlir/test/Target/LLVMIR/nvvmir.mlir | 26 -----
5 files changed, 218 insertions(+), 60 deletions(-)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 7b40841e45d0d..9015245f99983 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -272,6 +272,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
!eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4),
!eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k32:c:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k32:c:f32") : !listsplat(llvm_float_ty, 4),
+ !eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
+ !eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
// wmma fp16 -> fp16/fp32 @ m16n16k16/m8n32k16/m32n8k16
// All other supported geometries use the same fragment format for f32 and
@@ -298,6 +302,21 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
!eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2),
!eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2),
+ !eq(gft,"m16n8k4:a:f64") : !listsplat(llvm_double_ty, 2),
+ !eq(gft,"m16n8k4:b:f64") : [llvm_double_ty],
+ !eq(gft,"m16n8k4:c:f64") : !listsplat(llvm_double_ty, 4),
+ !eq(gft,"m16n8k4:d:f64") : !listsplat(llvm_double_ty, 4),
+
+ !eq(gft,"m16n8k8:a:f64") : !listsplat(llvm_double_ty, 4),
+ !eq(gft,"m16n8k8:b:f64") : !listsplat(llvm_double_ty, 2),
+ !eq(gft,"m16n8k8:c:f64") : !listsplat(llvm_double_ty, 4),
+ !eq(gft,"m16n8k8:d:f64") : !listsplat(llvm_double_ty, 4),
+
+ !eq(gft,"m16n8k16:a:f64") : !listsplat(llvm_double_ty, 8),
+ !eq(gft,"m16n8k16:b:f64") : !listsplat(llvm_double_ty, 4),
+ !eq(gft,"m16n8k16:c:f64") : !listsplat(llvm_double_ty, 4),
+ !eq(gft,"m16n8k16:d:f64") : !listsplat(llvm_double_ty, 4),
+
// wmma bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
!eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4),
@@ -378,6 +397,26 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
!eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
!eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),
+ // mma e4m3/e5m2 -> f16/f32 @ m16n8k16
+ !eq(gft,"m16n8k16:a:e4m3") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k16:a:e5m2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k16:b:e4m3") : [llvm_i32_ty],
+ !eq(gft,"m16n8k16:b:e5m2") : [llvm_i32_ty],
+ // mma e4m3/e5m2/e3m2/e2m3/e2m1 -> f16/f32 @ m16n8k32
+ !eq(gft,"m16n8k32:a:e4m3") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:a:e5m2") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:a:e3m2") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:a:e2m3") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:a:e2m1") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k32:b:e4m3") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:b:e5m2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:b:e3m2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:b:e2m3") : !listsplat(llvm_i32_ty, 2),
+ !eq(gft,"m16n8k32:b:e2m1") : !listsplat(llvm_i32_ty, 2),
+ // mma e2m1 -> f32 @ m16n8k64
+ !eq(gft,"m16n8k64:a:e2m1") : !listsplat(llvm_i32_ty, 4),
+ !eq(gft,"m16n8k64:b:e2m1") : !listsplat(llvm_i32_ty, 2),
+
// wmma/mma b1 -> s32 @ m8n8k128(b1)
!eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
!eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
@@ -468,7 +507,7 @@ class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd, strin
# !if(Satfinite, "_satfinite", "");
}
-class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
+class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, string Kind,
WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
string signature = MMA_SIGNATURE<A, B, C, D>.ret;
string record = "int_nvvm_mma"
@@ -476,6 +515,7 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
# "_" # A.geom
# "_" # ALayout
# "_" # BLayout
+ # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
# !if(Satfinite, "_satfinite", "")
# signature;
}
@@ -601,7 +641,7 @@ class NVVM_MMA_OPS {
["m16n8k16", "m16n8k8"],
["bf16"], [], ["f32"], []>.ret;
list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
- ["m8n8k4"],
+ ["m8n8k4", "m16n8k4", "m16n8k8", "m16n8k16"],
["f64"], [], ["f64"], []>.ret;
list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
["m8n8k4", "m16n8k8", "m16n8k16"],
@@ -609,6 +649,18 @@ class NVVM_MMA_OPS {
list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
["m8n8k16", "m16n8k16", "m16n8k32"],
["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
+ // The m16n8k32 fp8 variants overlap with the f8f6f4 variants
+ // and are handled there.
+ list<list<WMMA_REGS>> fp8_mma_ops = MMA_OPS<
+ ["m16n8k16"],
+ ["e4m3", "e5m2"], ["e4m3", "e5m2"],
+ ["f16", "f32"], ["f16", "f32"]>.ret;
+ // This list also includes the e4m3/e5m2 types from the fp8 variants.
+ list<list<WMMA_REGS>> f8f6f4_mma_ops = MMA_OPS<
+ ["m16n8k32"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f16", "f32"], ["f16", "f32"]>.ret;
list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
["m8n8k32", "m16n8k32", "m16n8k64"],
["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
@@ -617,7 +669,8 @@ class NVVM_MMA_OPS {
["b1"], [], ["s32"], []>.ret;
list<list<WMMA_REGS>> all_mma_ops = !listconcat(
tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
- fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
+ fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
+ int_mma_ops, subint_mma_ops, bit_mma_ops);
list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
["m16n8k16", "m16n8k32"],
@@ -770,7 +823,8 @@ class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
// if NVVM_MMA_SUPPORTED<...>.ret then
// def : FOO<>; // The record will only be defined for supported ops.
//
-class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
+class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b,
+ string kind, int satf> {
// MMA ops check both layouts.
string layout = layout_a # ":" # layout_b;
string a_type = frags[0].ptx_elt_type;
@@ -805,10 +859,31 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
!or(!ne(a_type, b_type),
!ne(c_type, d_type))): false,
- // m16n8k8 requires C and D to be the same type.
- !and(!eq(geom, "m16n8k8"),
+ // m16n8k16/m16n8k32 require C and D to be the same type.
+ !and(!or(!eq(geom, "m16n8k16"),
+ !eq(geom, "m16n8k32")),
!ne(c_type, d_type)): false,
+ // A non-empty kind is only valid for m16n8k32 with fp8/fp6/fp4 A types.
+ !and(!ne(kind, ""),
+ !or(!ne(geom, "m16n8k32"),
+ !and(!ne(a_type, "e4m3"),
+ !ne(a_type, "e5m2"),
+ !ne(a_type, "e3m2"),
+ !ne(a_type, "e2m3"),
+ !ne(a_type, "e2m1")))): false,
+
+ // Without a kind qualifier, m16n8k16/m16n8k32 must not use the fp6/fp4 types.
+ !and(!eq(kind, ""),
+ !or(!eq(geom, "m16n8k16"),
+ !eq(geom, "m16n8k32")),
+ !or(!eq(a_type, "e3m2"),
+ !eq(a_type, "e2m3"),
+ !eq(a_type, "e2m1"),
+ !eq(b_type, "e3m2"),
+ !eq(b_type, "e2m3"),
+ !eq(b_type, "e2m1"))): false,
+
// All others are OK.
true: true
);
@@ -882,9 +957,10 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
!eq(a_type, "tf32")),
!ne(a_type, b_type)): false,
- // m16n8k16 and m16n8k32 requires C and D to be the same type.
+ // m16n8k16, m16n8k32, and m16n8k64 require C and D to be the same type.
!and(!or(!eq(geom, "m16n8k16"),
- !eq(geom, "m16n8k32")),
+ !eq(geom, "m16n8k32"),
+ !eq(geom, "m16n8k64")),
!ne(c_type, d_type)): false,
!and(!eq(kind, ""),
@@ -2143,10 +2219,12 @@ foreach layout_a = ["row", "col"] in {
foreach satf = [0, 1] in {
foreach op = NVVM_MMA_OPS.all_mma_ops in {
foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
- if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
- def MMA_NAME<layout_a, layout_b, satf, b1op, op[0], op[1], op[2], op[3]>.record
- : NVVM_MMA<op[0], op[1], op[2], op[3]>;
- }
+ foreach kind = ["", "kind::f8f6f4"] in {
+ if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
+ def MMA_NAME<layout_a, layout_b, satf, b1op, kind, op[0], op[1], op[2], op[3]>.record
+ : NVVM_MMA<op[0], op[1], op[2], op[3]>;
+ }
+ } // kind
} // b1op
} // op
} // satf
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index c544911bdf1e3..8f58c31d7e1c7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4461,6 +4461,10 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
!eq(ptx_elt_type, "e2m1"),
!ne(kind, "")) : [hasSM120a, hasPTX<87>],
+ !and(!or(!eq(ptx_elt_type,"e4m3"),
+ !eq(ptx_elt_type,"e5m2")),
+ !eq(geom, "m16n8k16")) : [hasSM<89>, hasPTX<87>],
+
!or(!eq(ptx_elt_type, "e4m3"),
!eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
@@ -4476,6 +4480,11 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
!and(!eq(geom, "m8n8k4"),
!eq(ptx_elt_type, "f64")) : [hasSM<80>, hasPTX<70>],
+ !and(!or(!eq(geom, "m16n8k4"),
+ !eq(geom, "m16n8k8"),
+ !eq(geom, "m16n8k16")),
+ !eq(ptx_elt_type, "f64")) : [hasSM<90>, hasPTX<78>],
+
// fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
!and(!or(!eq(geom, "m8n32k16"),
!eq(geom, "m32n8k16")),
@@ -4760,8 +4769,8 @@ defset list<WMMA_INSTR> WMMAs = {
// MMA
class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
WMMA_REGINFO FragC, WMMA_REGINFO FragD,
- string ALayout, string BLayout, int Satfinite, string b1op>
- : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record,
+ string ALayout, string BLayout, int Satfinite, string b1op, string Kind>
+ : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record,
[FragA.Ins, FragB.Ins, FragC.Ins]>,
// Requires does not seem to have an effect on Instruction w/o Patterns.
// We set it here anyways and propagate to the Pat<> we construct below.
@@ -4776,6 +4785,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
# FragA.geom
# "." # ALayout
# "." # BLayout
+ # !if(!ne(Kind, ""), "." # Kind, "")
# !if(Satfinite, ".satfinite", "")
# TypeList
# b1op # "\n\t\t"
@@ -4792,13 +4802,15 @@ defset list<WMMA_INSTR> MMAs = {
foreach satf = [0, 1] in {
foreach op = NVVM_MMA_OPS.all_mma_ops in {
foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
- if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
- def : MMA<WMMA_REGINFO<op[0], "mma">,
- WMMA_REGINFO<op[1], "mma">,
- WMMA_REGINFO<op[2], "mma">,
- WMMA_REGINFO<op[3], "mma">,
- layout_a, layout_b, satf, b1op>;
- }
+ foreach kind = ["", "kind::f8f6f4"] in {
+ if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
+ def : MMA<WMMA_REGINFO<op[0], "mma", "", kind>,
+ WMMA_REGINFO<op[1], "mma", "", kind>,
+ WMMA_REGINFO<op[2], "mma", "", kind>,
+ WMMA_REGINFO<op[3], "mma", "", kind>,
+ layout_a, layout_b, satf, b1op, kind>;
+ }
+ } // kind
} // b1op
} // op
} // satf
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 6d73bce46da7c..1c32856c1ce20 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -90,6 +90,21 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
"m16n8k32:b:s8": 2,
"m16n8k32:c:s32": 4,
"m16n8k32:d:s32": 4,
+ # e4m3/e5m2/e3m2/e2m3/e2m1 -> f16/f32 @ m16n8k16/m16n8k32
+ "m16n8k16:a:e4m3": 2,
+ "m16n8k16:a:e5m2": 2,
+ "m16n8k32:a:e4m3": 4,
+ "m16n8k32:a:e5m2": 4,
+ "m16n8k32:a:e3m2": 4,
+ "m16n8k32:a:e2m3": 4,
+ "m16n8k32:a:e2m1": 4,
+ "m16n8k16:b:e4m3": 1,
+ "m16n8k16:b:e5m2": 1,
+ "m16n8k32:b:e4m3": 2,
+ "m16n8k32:b:e5m2": 2,
+ "m16n8k32:b:e3m2": 2,
+ "m16n8k32:b:e2m3": 2,
+ "m16n8k32:b:e2m1": 2,
# mma sp
"m16n8k32:a:bf16": 4,
"m16n8k32:a:f16": 4,
@@ -182,6 +197,18 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
"m8n8k4:b:f64": 1,
"m8n8k4:c:f64": 2,
"m8n8k4:d:f64": 2,
+ "m16n8k4:a:f64": 2,
+ "m16n8k4:b:f64": 1,
+ "m16n8k4:c:f64": 4,
+ "m16n8k4:d:f64": 4,
+ "m16n8k8:a:f64": 4,
+ "m16n8k8:b:f64": 2,
+ "m16n8k8:c:f64": 4,
+ "m16n8k8:d:f64": 4,
+ "m16n8k16:a:f64": 8,
+ "m16n8k16:b:f64": 4,
+ "m16n8k16:c:f64": 4,
+ "m16n8k16:d:f64": 4,
# tf32 -> s32 @ m16n16k8
"m16n16k8:a:tf32": 4,
"m16n16k8:b:tf32": 4,
@@ -324,7 +351,9 @@ def get_wmma_ops():
def get_mma_ops():
return (
- make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
+ make_mma_ops(
+ ["m8n8k4", "m16n8k4", "m16n8k8", "m16n8k16"], ["f64"], [], ["f64"], []
+ )
+ make_mma_ops(["m16n8k4", "m16n8k8"], ["tf32"], [], ["f32"], [])
+ make_mma_ops(["m16n8k16", "m16n8k8"], ["bf16"], [], ["f32"], [])
+ make_mma_ops(
@@ -341,6 +370,20 @@ def get_mma_ops():
["m8n8k32", "m16n8k32", "m16n8k64"], ["s4", "u4"], ["s4", "u4"], ["s32"], []
)
+ make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"], ["b1"], [], ["s32"], [])
+ + make_mma_ops(
+ ["m16n8k16"],
+ ["e4m3", "e5m2"],
+ ["e4m3", "e5m2"],
+ ["f16", "f32"],
+ ["f16", "f32"],
+ )
+ + make_mma_ops(
+ ["m16n8k32"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+ ["f16", "f32"],
+ ["f16", "f32"],
+ )
)
@@ -492,7 +535,7 @@ def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
return True
-def is_mma_variant_supported(op, layout_a, layout_b, satf):
+def is_mma_variant_supported(op, layout_a, layout_b, kind, satf):
if not (
is_type_supported(op.a.mma_type.ptx_type) and is_mma_geom_supported(op.a.geom)
):
@@ -516,13 +559,49 @@ def is_mma_variant_supported(op, layout_a, layout_b, satf):
):
return False
+ if (
+ op.a.geom != "m8n8k4"
+ and op.a.mma_type.ptx_type == "f64"
+ and (ptx_version < 78 or gpu_arch < 90)
+ ):
+ return False
+
# C and D type must be the same
- if op.a.geom == "m16n8k16" and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type:
+ if (
+ op.a.geom in ["m16n8k16", "m16n8k32"]
+ and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
+ ):
+ return False
+
+ if (
+ op.a.geom in ["m16n8k16", "m16n8k32"]
+ and any(x in ["e4m3", "e5m2"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+ and ptx_version < 87
+ ):
+ return False
+
+ if kind != "" and (ptx_version < 87 or gpu_arch < 120 or not aa):
+ return False
+
+ if (
+ kind != ""
+ and (
+ op.a.geom != "m16n8k32"
+ or op.a.mma_type.ptx_type not in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
+ )
+ ):
+ return False
+
+ if (kind == ""
+ and op.a.geom in ["m16n8k16", "m16n8k32"]
+ and any(x in ["e3m2", "e2m3", "e2m1"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+ ):
return False
# Require row/col layout for all MMA except m8n8k4 on FP16
if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
return layout_a == "row" and layout_b == "col"
+
return True
@@ -937,7 +1016,12 @@ def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
"""
test_params = params
- test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+ test_params["intrinsic"] = (
+ Template(intrinsic_template)
+ .substitute(params)
+ .replace("::", ".")
+ .replace("_", ".")
+ )
test_params["function"] = test_params["intrinsic"].replace(".", "_")
test_params["instruction"] = Template(instruction_template).substitute(params)
test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
@@ -1002,16 +1086,24 @@ def gen_wmma_mma_tests():
def gen_mma_tests():
- mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
- mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"
+ mma_intrinsic_template = (
+ "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${kind}${satf}.${intrinsic_signature}"
+ )
+ mma_instruction_template = (
+ "mma.sync${aligned}.${geom}.${alayout}.${blayout}${kind}${satf}.${ptx_signature}${b1op}"
+ )
generated_items = []
- for op, alayout, blayout, satf in product(
- get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]
+ for op, alayout, blayout, kind, satf in product(
+ get_mma_ops(),
+ ["row", "col"],
+ ["row", "col"],
+ ["", ".kind::f8f6f4"],
+ [".satfinite", ""],
):
- if not is_mma_variant_supported(op, alayout, blayout, satf):
+ if not is_mma_variant_supported(op, alayout, blayout, kind, satf):
continue
for b1op in get_b1_ops(op.a.mma_type.ptx_type):
@@ -1024,6 +1116,7 @@ def gen_mma_tests():
"satf": satf,
"geom": op.a.geom,
"b1op": b1op,
+ "kind": kind,
}
intrinsic_template = mma_intrinsic_template
@@ -1105,9 +1198,9 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf):
):
return False
- # C and D type must be the same for m16n8k16/m16n8k32
+ # C and D types must be the same for m16n8k16/m16n8k32/m16n8k64
if (
- op.a.geom in ["m16n8k16", "m16n8k32"]
+ op.a.geom in ["m16n8k16", "m16n8k32", "m16n8k64"]
and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
):
return False
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9528da05c9fd6..c1da1cf5d0c28 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1763,8 +1763,9 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
!or(!ne(a_type, b_type),
!ne(c_type, d_type))): false,
- // m16n8k8 requires C and D to be the same type.
- !and(!eq(geom, "m16n8k8"),
+ // m16n8k16/m16n8k32 require C and D to be the same type.
+ !and(!or(!eq(geom, "m16n8k16"),
+ !eq(geom, "m16n8k32")),
!ne(c_type, d_type)): false,
// All others are OK.
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 62aeb071c5786..00a479d1f877d 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -302,32 +302,6 @@ llvm.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i3
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
}
-// f32 return type, f16 accumulate type
-// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
-llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
- %a2 : vector<2xf16>, %a3 : vector<2xf16>,
- %b0 : vector<2xf16>, %b1 : vector<2xf16>,
- %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> {
- // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f16
- %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
- {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
- shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
- llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
-}
-
-// f16 return type, f32 accumulate type
-// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f32
-llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
- %a2 : vector<2xf16>, %a3 : vector<2xf16>,
- %b0 : vector<2xf16>, %b1 : vector<2xf16>,
- %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
- // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f32
- %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
- {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
- shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
- llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
-}
-
// f32 return type, f32 accumulate type
// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f32
llvm.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
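Two notes on the diff above. First, the naming transformation for the new kind qualifier is visible directly in the patch: the intrinsic record name gains "_kind_f8f6f4" between the layouts and the satfinite/type suffix (via !subst("::", "_", Kind)), while the emitted PTX instruction gains ".kind::f8f6f4" in the same position. Second, to show how the new FP64 fragment shapes surface in use, here is a minimal usage sketch of the m16n8k16 intrinsic declared earlier (same naming assumption as before; the wrapper function name is hypothetical):

; One warp-synchronous 16x8x16 FP64 MMA: D = A * B + C.
define { double, double, double, double } @mma_f64_m16n8k16(
    double %a0, double %a1, double %a2, double %a3,
    double %a4, double %a5, double %a6, double %a7,
    double %b0, double %b1, double %b2, double %b3,
    double %c0, double %c1, double %c2, double %c3) {
  %d = call { double, double, double, double }
      @llvm.nvvm.mma.m16n8k16.row.col.f64(
        double %a0, double %a1, double %a2, double %a3,
        double %a4, double %a5, double %a6, double %a7,
        double %b0, double %b1, double %b2, double %b3,
        double %c0, double %c1, double %c2, double %c3)
  ret { double, double, double, double } %d
}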
From 34dfa53d8faf76f2c0a5da67a217668f7c5ba2dc Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov@nvidia.com>
Date: Mon, 1 Sep 2025 10:59:06 +0200
Subject: [PATCH 2/5] [NVPTX] Code formatting issues were fixed for PR156040.
---
llvm/test/CodeGen/NVPTX/wmma.py | 30 +++++++++++++++---------------
1 file changed, 15 insertions(+), 15 deletions(-)
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 1c32856c1ce20..aeddda812432d 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -575,7 +575,10 @@ def is_mma_variant_supported(op, layout_a, layout_b, kind, satf):
if (
op.a.geom in ["m16n8k16", "m16n8k32"]
- and any(x in ["e4m3", "e5m2"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+ and any(
+ x in ["e4m3", "e5m2"]
+ for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
+ )
and ptx_version < 87
):
return False
@@ -583,18 +586,19 @@ def is_mma_variant_supported(op, layout_a, layout_b, kind, satf):
if kind != "" and (ptx_version < 87 or gpu_arch < 120 or not aa):
return False
- if (
- kind != ""
- and (
- op.a.geom != "m16n8k32"
- or op.a.mma_type.ptx_type not in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
- )
+ if kind != "" and (
+ op.a.geom != "m16n8k32"
+ or op.a.mma_type.ptx_type not in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
):
return False
- if (kind == ""
+ if (
+ kind == ""
and op.a.geom in ["m16n8k16", "m16n8k32"]
- and any(x in ["e3m2", "e2m3", "e2m1"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+ and any(
+ x in ["e3m2", "e2m3", "e2m1"]
+ for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
+ )
):
return False
@@ -1086,12 +1090,8 @@ def gen_wmma_mma_tests():
def gen_mma_tests():
- mma_intrinsic_template = (
- "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${kind}${satf}.${intrinsic_signature}"
- )
- mma_instruction_template = (
- "mma.sync${aligned}.${geom}.${alayout}.${blayout}${kind}${satf}.${ptx_signature}${b1op}"
- )
+ mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${kind}${satf}.${intrinsic_signature}"
+ mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${kind}${satf}.${ptx_signature}${b1op}"
generated_items = []
From 2f65fef05ad58777d79d8e355fe794c3b4fe390b Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov@nvidia.com>
Date: Thu, 4 Sep 2025 15:37:45 +0200
Subject: [PATCH 3/5] [NVPTX] Updated a check for ptx and sm versions.
PR156040.
---
llvm/test/CodeGen/NVPTX/wmma.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index aeddda812432d..8427ae4ad72da 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -583,7 +583,7 @@ def is_mma_variant_supported(op, layout_a, layout_b, kind, satf):
):
return False
- if kind != "" and (ptx_version < 87 or gpu_arch < 120 or not aa):
+ if kind != "" and not (ptx_version >= 87 and gpu_arch >= 120 and aa):
return False
if kind != "" and (
From 2b34832920b6cdf36479f5b2a021450445f6e528 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov@nvidia.com>
Date: Tue, 16 Sep 2025 13:05:40 +0200
Subject: [PATCH 4/5] [NVPTX] Aligned the test's ptxas feature checks with the latest ones. PR156040.
---
llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
index ae781df0116fd..40055ae519fc4 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
@@ -2,7 +2,7 @@
# RUN: %python %s --ptx=87 --gpu-arch=120 --aa > %t-ptx87-sm_120a.ll
# RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
# RUN: | FileCheck %t-ptx87-sm_120a.ll
-# RUN: %if ptxas-12.7 %{ \
+# RUN: %if ptxas-sm_120a && ptxas-isa-8.7 %{ \
# RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
# RUN: | %ptxas-verify -arch=sm_120a \
# RUN: %}
From 3e6c7f8d777994141f64954a4169b04aede3dc67 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov@nvidia.com>
Date: Fri, 26 Sep 2025 13:55:20 +0200
Subject: [PATCH 5/5] [NVPTX] Moved unsupported MLIR MMA tests to invalid.mlir.
PR156040.
---
mlir/test/Dialect/LLVMIR/invalid.mlir | 28 +++++++++++++++++++++++++++
1 file changed, 28 insertions(+)
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 4394786db5a5d..5f741ed775891 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -743,6 +743,34 @@ func.func @nvvm_invalid_mma_8(%a0 : i32, %a1 : i32,
// -----
+// f32 return type, f16 accumulate type
+llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> {
+ // C and D should have the same type according to PTX ISA
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
+ llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// -----
+
+// f16 return type, f32 accumulate type
+llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+ %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+ %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+ %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+ // C and D should have the same type according to PTX ISA
+ %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+ {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+ shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+ llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// -----
+
func.func @atomicrmw_mismatched_operands(%f32_ptr : !llvm.ptr, %f32 : f32) {
// expected-error@+1 {{op failed to verify that result #0 and operand #1 have the same type}}
%0 = "llvm.atomicrmw"(%f32_ptr, %f32) {bin_op=11, ordering=1} : (!llvm.ptr, f32) -> i32