[Mlir-commits] [llvm] [mlir] [NVPTX] Added more MMA intrinsics for F8F6F4 and FP64 types. (PR #156040)

Kirill Vedernikov llvmlistbot at llvm.org
Fri Sep 26 05:01:50 PDT 2025


https://github.com/kvederni updated https://github.com/llvm/llvm-project/pull/156040

>From 2ef12326f55dee4d283277b8655a6057329ef0ab Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Fri, 29 Aug 2025 16:44:59 +0200
Subject: [PATCH 1/5] [NVPTX] Added more MMA intrinsics for F8F6F4 and FP64
 types. [NVPTX] Added restrictions for dtype/ctype combinations. [MLIR]
 Aligned MMA restrictions with NVVM IR.

MMA description in PTX ISA 9.0 is at https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-mma
---
 llvm/include/llvm/IR/IntrinsicsNVVM.td      | 102 +++++++++++++++--
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td    |  30 +++--
 llvm/test/CodeGen/NVPTX/wmma.py             | 115 ++++++++++++++++++--
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td |   5 +-
 mlir/test/Target/LLVMIR/nvvmir.mlir         |  26 -----
 5 files changed, 218 insertions(+), 60 deletions(-)

diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 7b40841e45d0d..9015245f99983 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -272,6 +272,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
       !eq(gft,"m16n8k16:d:f32") : !listsplat(llvm_float_ty, 4),
       !eq(gft,"m16n8k4:c:f32") : !listsplat(llvm_float_ty, 4),
       !eq(gft,"m16n8k4:d:f32") : !listsplat(llvm_float_ty, 4),
+      !eq(gft,"m16n8k32:c:f16") : !listsplat(llvm_v2f16_ty, 2),
+      !eq(gft,"m16n8k32:c:f32") : !listsplat(llvm_float_ty, 4),
+      !eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
+      !eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
 
       // wmma fp16 -> fp16/fp32 @  m16n16k16/m8n32k16/m32n8k16
       // All other supported geometries use the same fragment format for f32 and
@@ -298,6 +302,21 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
       !eq(gft,"m8n8k4:c:f64") : !listsplat(llvm_double_ty, 2),
       !eq(gft,"m8n8k4:d:f64") : !listsplat(llvm_double_ty, 2),
 
+      !eq(gft,"m16n8k4:a:f64") : !listsplat(llvm_double_ty, 2),
+      !eq(gft,"m16n8k4:b:f64") : [llvm_double_ty],
+      !eq(gft,"m16n8k4:c:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k4:d:f64") : !listsplat(llvm_double_ty, 4),
+
+      !eq(gft,"m16n8k8:a:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k8:b:f64") : !listsplat(llvm_double_ty, 2),
+      !eq(gft,"m16n8k8:c:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k8:d:f64") : !listsplat(llvm_double_ty, 4),
+
+      !eq(gft,"m16n8k16:a:f64") : !listsplat(llvm_double_ty, 8),
+      !eq(gft,"m16n8k16:b:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k16:c:f64") : !listsplat(llvm_double_ty, 4),
+      !eq(gft,"m16n8k16:d:f64") : !listsplat(llvm_double_ty, 4),
+
       // wmma bf16 -> s32 @ m16n16k16/m8n32k16/m32n8k16
       !eq(gft,"m16n16k16:a:bf16") : !listsplat(llvm_i32_ty, 4),
       !eq(gft,"m16n16k16:b:bf16") : !listsplat(llvm_i32_ty, 4),
@@ -378,6 +397,26 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
       !eq(gft,"m16n8k64:c:s32") : !listsplat(llvm_i32_ty, 4),
       !eq(gft,"m16n8k64:d:s32") : !listsplat(llvm_i32_ty, 4),
 
+      // mma e4m3/e5m2 -> f16/f32 @ m16n8k16
+      !eq(gft,"m16n8k16:a:e4m3") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k16:a:e5m2") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k16:b:e4m3") : [llvm_i32_ty],
+      !eq(gft,"m16n8k16:b:e5m2") : [llvm_i32_ty],
+      // mma e4m3/e5m2/e3m2/e2m3/e2m1 -> f32 @ m16n8k32
+      !eq(gft,"m16n8k32:a:e4m3") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:a:e5m2") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:a:e3m2") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:a:e2m3") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:a:e2m1") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k32:b:e4m3") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k32:b:e5m2") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k32:b:e3m2") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k32:b:e2m3") : !listsplat(llvm_i32_ty, 2),
+      !eq(gft,"m16n8k32:b:e2m1") : !listsplat(llvm_i32_ty, 2),
+      // mma e2m1 -> f32 @ m16n8k64
+      !eq(gft,"m16n8k64:a:e2m1") : !listsplat(llvm_i32_ty, 4),
+      !eq(gft,"m16n8k64:b:e2m1") : !listsplat(llvm_i32_ty, 2),
+
       // wmma/mma b1 -> s32 @ m8n8k128(b1)
       !eq(gft,"m8n8k128:a:b1") : [llvm_i32_ty],
       !eq(gft,"m8n8k128:b:b1") : [llvm_i32_ty],
@@ -468,7 +507,7 @@ class WMMA_NAME<string ALayout, string BLayout, int Satfinite, string Rnd, strin
                   # !if(Satfinite, "_satfinite", "");
 }
 
-class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
+class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, string Kind,
                WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
   string signature = MMA_SIGNATURE<A, B, C, D>.ret;
   string record = "int_nvvm_mma"
@@ -476,6 +515,7 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op,
                   # "_" # A.geom
                   # "_" # ALayout
                   # "_" # BLayout
+                  # !if(!ne(Kind, ""), !strconcat("_", !subst("::", "_", Kind)), "")
                   # !if(Satfinite, "_satfinite", "")
                   # signature;
 }
@@ -601,7 +641,7 @@ class NVVM_MMA_OPS {
             ["m16n8k16", "m16n8k8"],
             ["bf16"], [], ["f32"], []>.ret;
   list<list<WMMA_REGS>> f64_mma_ops = MMA_OPS<
-            ["m8n8k4"],
+            ["m8n8k4", "m16n8k4", "m16n8k8", "m16n8k16"],
             ["f64"], [], ["f64"], []>.ret;
   list<list<WMMA_REGS>> fp_mma_ops = MMA_OPS<
             ["m8n8k4", "m16n8k8", "m16n8k16"],
@@ -609,6 +649,18 @@ class NVVM_MMA_OPS {
   list<list<WMMA_REGS>> int_mma_ops = MMA_OPS<
             ["m8n8k16", "m16n8k16", "m16n8k32"],
             ["s8", "u8"], ["s8", "u8"], ["s32"], []>.ret;
+  // m16n8k32 fp8 variants overlap with the f8f6f4 variants
+  // and are generated by f8f6f4_mma_ops below.
+  list<list<WMMA_REGS>> fp8_mma_ops = MMA_OPS<
+            ["m16n8k16"],
+            ["e4m3", "e5m2"], ["e4m3", "e5m2"],
+            ["f16", "f32"], ["f16", "f32"]>.ret;
+  // Also covers the m16n8k32 e4m3/e5m2 (fp8) variants.
+  list<list<WMMA_REGS>> f8f6f4_mma_ops = MMA_OPS<
+            ["m16n8k32"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f16", "f32"], ["f16", "f32"]>.ret;
   list<list<WMMA_REGS>> subint_mma_ops = MMA_OPS<
             ["m8n8k32", "m16n8k32", "m16n8k64"],
             ["s4", "u4"], ["s4", "u4"], ["s32"], []>.ret;
@@ -617,7 +669,8 @@ class NVVM_MMA_OPS {
             ["b1"], [], ["s32"], []>.ret;
   list<list<WMMA_REGS>> all_mma_ops = !listconcat(
             tf32_mma_ops, bf16_mma_ops, f64_mma_ops,
-            fp_mma_ops, int_mma_ops, subint_mma_ops, bit_mma_ops);
+            fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
+            int_mma_ops, subint_mma_ops, bit_mma_ops);
 
   list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
             ["m16n8k16", "m16n8k32"],
@@ -770,7 +823,8 @@ class NVVM_MMA_B1OPS<list<WMMA_REGS> frags> {
 // if NVVM_MMA_SUPPORTED<...>.ret then
 //   def : FOO<>; // The record will only be defined for supported ops.
 //
-class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b, int satf> {
+class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b,
+                         string kind, int satf> {
   // MMA ops check both layouts.
   string layout = layout_a # ":" # layout_b;
   string a_type = frags[0].ptx_elt_type;
@@ -805,10 +859,31 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
          !or(!ne(a_type, b_type),
              !ne(c_type, d_type))): false,
 
-    // m16n8k8 requires C and D to be the same type.
-    !and(!eq(geom, "m16n8k8"),
+    // m16n8k16/m16n8k32 require C and D to be the same type.
+    !and(!or(!eq(geom, "m16n8k16"),
+             !eq(geom, "m16n8k32")),
          !ne(c_type, d_type)): false,
 
+    // A kind qualifier is only valid for m16n8k32 with an f8f6f4 A type.
+    !and(!ne(kind, ""),
+         !or(!ne(geom, "m16n8k32"),
+             !and(!ne(a_type, "e4m3"),
+                  !ne(a_type, "e5m2"),
+                  !ne(a_type, "e3m2"),
+                  !ne(a_type, "e2m3"),
+                  !ne(a_type, "e2m1")))): false,
+
+    // Without a kind, m16n8k16/m16n8k32 do not accept e3m2/e2m3/e2m1 A/B types.
+    !and(!eq(kind, ""),
+         !or(!eq(geom, "m16n8k16"),
+             !eq(geom, "m16n8k32")),
+             !or(!eq(a_type, "e3m2"),
+                 !eq(a_type, "e2m3"),
+                 !eq(a_type, "e2m1"),
+                 !eq(b_type, "e3m2"),
+                 !eq(b_type, "e2m3"),
+                 !eq(b_type, "e2m1"))): false,
+
     // All other are OK.
     true: true
   );
@@ -882,9 +957,10 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
              !eq(a_type, "tf32")),
          !ne(a_type, b_type)): false,
 
-    // m16n8k16 and m16n8k32 requires C and D to be the same type.
+    // m16n8k16, m16n8k32 and m16n8k64 require C and D to be the same type.
     !and(!or(!eq(geom, "m16n8k16"),
-             !eq(geom, "m16n8k32")),
+             !eq(geom, "m16n8k32"),
+             !eq(geom, "m16n8k64")),
          !ne(c_type, d_type)): false,
 
     !and(!eq(kind, ""),
@@ -2143,10 +2219,12 @@ foreach layout_a = ["row", "col"] in {
     foreach satf = [0, 1] in {
       foreach op = NVVM_MMA_OPS.all_mma_ops in {
         foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
-          if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-            def MMA_NAME<layout_a, layout_b, satf, b1op, op[0], op[1], op[2], op[3]>.record
-              : NVVM_MMA<op[0], op[1], op[2], op[3]>;
-          }
+          foreach kind = ["", "kind::f8f6f4"] in {
+            if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
+                def MMA_NAME<layout_a, layout_b, satf, b1op, kind, op[0], op[1], op[2], op[3]>.record
+                : NVVM_MMA<op[0], op[1], op[2], op[3]>;
+            }
+          } // kind
         } // b1op
       } // op
     } // satf
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index c544911bdf1e3..8f58c31d7e1c7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4461,6 +4461,10 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
         !eq(ptx_elt_type, "e2m1"),
         !ne(kind, "")) : [hasSM120a, hasPTX<87>],
 
+    !and(!or(!eq(ptx_elt_type,"e4m3"),
+             !eq(ptx_elt_type,"e5m2")),
+         !eq(geom, "m16n8k16")) : [hasSM<89>, hasPTX<87>],
+
     !or(!eq(ptx_elt_type, "e4m3"),
         !eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
 
@@ -4476,6 +4480,11 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
     !and(!eq(geom, "m8n8k4"),
          !eq(ptx_elt_type, "f64")) : [hasSM<80>, hasPTX<70>],
 
+    !and(!or(!eq(geom, "m16n8k4"),
+             !eq(geom, "m16n8k8"),
+             !eq(geom, "m16n8k16")),
+         !eq(ptx_elt_type, "f64")) : [hasSM<90>, hasPTX<78>],
+
     // fp16 -> fp16/fp32 @ m8n32k16/m32n8k16
     !and(!or(!eq(geom, "m8n32k16"),
              !eq(geom, "m32n8k16")),
@@ -4760,8 +4769,8 @@ defset list<WMMA_INSTR> WMMAs  = {
 // MMA
 class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                WMMA_REGINFO FragC, WMMA_REGINFO FragD,
-               string ALayout, string BLayout, int Satfinite, string b1op>
-  : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, FragA, FragB, FragC, FragD>.record,
+               string ALayout, string BLayout, int Satfinite, string b1op, string Kind>
+  : WMMA_INSTR<MMA_NAME<ALayout, BLayout, Satfinite, b1op, Kind, FragA, FragB, FragC, FragD>.record,
                         [FragA.Ins, FragB.Ins, FragC.Ins]>,
     // Requires does not seem to have effect on Instruction w/o Patterns.
     // We set it here anyways and propagate to the Pat<> we construct below.
@@ -4776,6 +4785,7 @@ class MMA<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                   # FragA.geom
                   # "." # ALayout
                   # "." # BLayout
+                  # !if(!ne(Kind, ""), "." # Kind, "")
                   # !if(Satfinite, ".satfinite", "")
                   # TypeList
                   # b1op # "\n\t\t"
@@ -4792,13 +4802,15 @@ defset list<WMMA_INSTR> MMAs  = {
       foreach satf = [0, 1] in {
         foreach op = NVVM_MMA_OPS.all_mma_ops in {
           foreach b1op = NVVM_MMA_B1OPS<op>.ret in {
-            if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, satf>.ret then {
-              def : MMA<WMMA_REGINFO<op[0], "mma">,
-                        WMMA_REGINFO<op[1], "mma">,
-                        WMMA_REGINFO<op[2], "mma">,
-                        WMMA_REGINFO<op[3], "mma">,
-                        layout_a, layout_b, satf, b1op>;
-            }
+            foreach kind = ["", "kind::f8f6f4"] in {
+              if NVVM_MMA_SUPPORTED<op, layout_a, layout_b, kind, satf>.ret then {
+                def : MMA<WMMA_REGINFO<op[0], "mma", "", kind>,
+                          WMMA_REGINFO<op[1], "mma", "", kind>,
+                          WMMA_REGINFO<op[2], "mma", "", kind>,
+                          WMMA_REGINFO<op[3], "mma", "", kind>,
+                          layout_a, layout_b, satf, b1op, kind>;
+              }
+            } // kind
           } // b1op
         } // op
       } // satf
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 6d73bce46da7c..1c32856c1ce20 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -90,6 +90,21 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
             "m16n8k32:b:s8": 2,
             "m16n8k32:c:s32": 4,
             "m16n8k32:d:s32": 4,
+            # e4m3/e5m2/e3m2/e2m3/e2m1 -> f16/f32 @ m16n8k16/m16n8k32
+            "m16n8k16:a:e4m3": 2,
+            "m16n8k16:a:e5m2": 2,
+            "m16n8k32:a:e4m3": 4,
+            "m16n8k32:a:e5m2": 4,
+            "m16n8k32:a:e3m2": 4,
+            "m16n8k32:a:e2m3": 4,
+            "m16n8k32:a:e2m1": 4,
+            "m16n8k16:b:e4m3": 1,
+            "m16n8k16:b:e5m2": 1,
+            "m16n8k32:b:e4m3": 2,
+            "m16n8k32:b:e5m2": 2,
+            "m16n8k32:b:e3m2": 2,
+            "m16n8k32:b:e2m3": 2,
+            "m16n8k32:b:e2m1": 2,
             # mma sp
             "m16n8k32:a:bf16": 4,
             "m16n8k32:a:f16": 4,
@@ -182,6 +197,18 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
             "m8n8k4:b:f64": 1,
             "m8n8k4:c:f64": 2,
             "m8n8k4:d:f64": 2,
+            "m16n8k4:a:f64": 2,
+            "m16n8k4:b:f64": 1,
+            "m16n8k4:c:f64": 4,
+            "m16n8k4:d:f64": 4,
+            "m16n8k8:a:f64": 4,
+            "m16n8k8:b:f64": 2,
+            "m16n8k8:c:f64": 4,
+            "m16n8k8:d:f64": 4,
+            "m16n8k16:a:f64": 8,
+            "m16n8k16:b:f64": 4,
+            "m16n8k16:c:f64": 4,
+            "m16n8k16:d:f64": 4,
             # tf32 -> s32 @ m16n16k8
             "m16n16k8:a:tf32": 4,
             "m16n16k8:b:tf32": 4,
@@ -324,7 +351,9 @@ def get_wmma_ops():
 
 def get_mma_ops():
     return (
-        make_mma_ops(["m8n8k4"], ["f64"], [], ["f64"], [])
+        make_mma_ops(
+            ["m8n8k4", "m16n8k4", "m16n8k8", "m16n8k16"], ["f64"], [], ["f64"], []
+        )
         + make_mma_ops(["m16n8k4", "m16n8k8"], ["tf32"], [], ["f32"], [])
         + make_mma_ops(["m16n8k16", "m16n8k8"], ["bf16"], [], ["f32"], [])
         + make_mma_ops(
@@ -341,6 +370,20 @@ def get_mma_ops():
             ["m8n8k32", "m16n8k32", "m16n8k64"], ["s4", "u4"], ["s4", "u4"], ["s32"], []
         )
         + make_mma_ops(["m8n8k128", "m16n8k128", "m16n8k256"], ["b1"], [], ["s32"], [])
+        + make_mma_ops(
+            ["m16n8k16"],
+            ["e4m3", "e5m2"],
+            ["e4m3", "e5m2"],
+            ["f16", "f32"],
+            ["f16", "f32"],
+        )
+        + make_mma_ops(
+            ["m16n8k32"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f16", "f32"],
+            ["f16", "f32"],
+        )
     )
 
 
@@ -492,7 +535,7 @@ def is_wmma_variant_supported(op, layout_a, layout_b, rnd, satf):
     return True
 
 
-def is_mma_variant_supported(op, layout_a, layout_b, satf):
+def is_mma_variant_supported(op, layout_a, layout_b, kind, satf):
     if not (
         is_type_supported(op.a.mma_type.ptx_type) and is_mma_geom_supported(op.a.geom)
     ):
@@ -516,13 +559,49 @@ def is_mma_variant_supported(op, layout_a, layout_b, satf):
     ):
         return False
 
+    if (
+        op.a.geom != "m8n8k4"
+        and op.a.mma_type.ptx_type == "f64"
+        and (ptx_version < 78 or gpu_arch < 90)
+    ):
+        return False
+
     # C and D type must be the same
-    if op.a.geom == "m16n8k16" and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type:
+    if (
+        op.a.geom in ["m16n8k16", "m16n8k32"]
+        and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
+    ):
+        return False
+
+    if (
+        op.a.geom in ["m16n8k16", "m16n8k32"]
+        and any(x in ["e4m3", "e5m2"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+        and ptx_version < 87
+    ):
+        return False
+
+    if kind != "" and (ptx_version < 87 or gpu_arch < 120 or not aa):
+        return False
+
+    if (
+        kind != ""
+        and (
+            op.a.geom != "m16n8k32"
+            or op.a.mma_type.ptx_type not in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
+        )
+    ):
+        return False
+
+    if (kind == ""
+        and op.a.geom in ["m16n8k16", "m16n8k32"]
+        and any(x in ["e3m2", "e2m3", "e2m1"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+    ):
         return False
 
     # Require row/col layout for all MMA except m8n8k4 on FP16
     if not (op.a.geom == "m8n8k4" and op.a.mma_type.ptx_type == "f16"):
         return layout_a == "row" and layout_b == "col"
+
     return True
 
 
@@ -937,7 +1016,12 @@ def common_mma_test_gen(params, op, intrinsic_template, instruction_template):
 """
 
     test_params = params
-    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+    test_params["intrinsic"] = (
+        Template(intrinsic_template)
+        .substitute(params)
+        .replace("::", ".")
+        .replace("_", ".")
+    )
     test_params["function"] = test_params["intrinsic"].replace(".", "_")
     test_params["instruction"] = Template(instruction_template).substitute(params)
     test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
@@ -1002,16 +1086,24 @@ def gen_wmma_mma_tests():
 
 
 def gen_mma_tests():
-    mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${satf}.${intrinsic_signature}"
-    mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${satf}.${ptx_signature}${b1op}"
+    mma_intrinsic_template = (
+        "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${kind}${satf}.${intrinsic_signature}"
+    )
+    mma_instruction_template = (
+        "mma.sync${aligned}.${geom}.${alayout}.${blayout}${kind}${satf}.${ptx_signature}${b1op}"
+    )
 
     generated_items = []
 
-    for op, alayout, blayout, satf in product(
-        get_mma_ops(), ["row", "col"], ["row", "col"], [".satfinite", ""]
+    for op, alayout, blayout, kind, satf in product(
+        get_mma_ops(),
+        ["row", "col"],
+        ["row", "col"],
+        ["", ".kind::f8f6f4"],
+        [".satfinite", ""],
     ):
 
-        if not is_mma_variant_supported(op, alayout, blayout, satf):
+        if not is_mma_variant_supported(op, alayout, blayout, kind, satf):
             continue
 
         for b1op in get_b1_ops(op.a.mma_type.ptx_type):
@@ -1024,6 +1116,7 @@ def gen_mma_tests():
                 "satf": satf,
                 "geom": op.a.geom,
                 "b1op": b1op,
+                "kind": kind,
             }
 
             intrinsic_template = mma_intrinsic_template
@@ -1105,9 +1198,9 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf):
     ):
         return False
 
-    # C and D type must be the same for m16n8k16/m16n8k32
+    # C and D types must be the same for m16n8k16/m16n8k32/m16n8k64
     if (
-        op.a.geom in ["m16n8k16", "m16n8k32"]
+        op.a.geom in ["m16n8k16", "m16n8k32", "m16n8k64"]
         and op.c.mma_type.ptx_type != op.d.mma_type.ptx_type
     ):
         return False
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9528da05c9fd6..c1da1cf5d0c28 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1763,8 +1763,9 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
          !or(!ne(a_type, b_type),
              !ne(c_type, d_type))): false,
 
-    // m16n8k8 requires C and D to be the same type.
-    !and(!eq(geom, "m16n8k8"),
+    // m16n8k16/m16n8k32 require C and D to be the same type.
+    !and(!or(!eq(geom, "m16n8k16"),
+             !eq(geom, "m16n8k32")),
          !ne(c_type, d_type)): false,
 
     // All other are OK.
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 62aeb071c5786..00a479d1f877d 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -302,32 +302,6 @@ llvm.func @nvvm_mma_m16n8k16_bf16_bf16(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i3
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
-// f32 return type, f16 accumulate type
-// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
-llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
-                                %a2 : vector<2xf16>, %a3 : vector<2xf16>,
-                                %b0 : vector<2xf16>, %b1 : vector<2xf16>,
-                                %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> {
-  // CHECK: call { float, float, float, float } @llvm.nvvm.mma.m16n8k16.row.col.f32.f16
-  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
-    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
-  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
-}
-
-// f16 return type, f32 accumulate type
-// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f32
-llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
-                                %a2 : vector<2xf16>, %a3 : vector<2xf16>,
-                                %b0 : vector<2xf16>, %b1 : vector<2xf16>,
-                                %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
-  // CHECK: call { <2 x half>, <2 x half> } @llvm.nvvm.mma.m16n8k16.row.col.f16.f32
-  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
-    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
-     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
-  llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
-}
-
 // f32 return type, f32 accumulate type
 // CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f32
 llvm.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,

>From 34dfa53d8faf76f2c0a5da67a217668f7c5ba2dc Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Mon, 1 Sep 2025 10:59:06 +0200
Subject: [PATCH 2/5] [NVPTX] Code formatting issues were fixed for PR156040.

---
 llvm/test/CodeGen/NVPTX/wmma.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 1c32856c1ce20..aeddda812432d 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -575,7 +575,10 @@ def is_mma_variant_supported(op, layout_a, layout_b, kind, satf):
 
     if (
         op.a.geom in ["m16n8k16", "m16n8k32"]
-        and any(x in ["e4m3", "e5m2"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+        and any(
+            x in ["e4m3", "e5m2"]
+            for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
+        )
         and ptx_version < 87
     ):
         return False
@@ -583,18 +586,19 @@ def is_mma_variant_supported(op, layout_a, layout_b, kind, satf):
     if kind != "" and (ptx_version < 87 or gpu_arch < 120 or not aa):
         return False
 
-    if (
-        kind != ""
-        and (
-            op.a.geom != "m16n8k32"
-            or op.a.mma_type.ptx_type not in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
-        )
+    if kind != "" and (
+        op.a.geom != "m16n8k32"
+        or op.a.mma_type.ptx_type not in ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"]
     ):
         return False
 
-    if (kind == ""
+    if (
+        kind == ""
         and op.a.geom in ["m16n8k16", "m16n8k32"]
-        and any(x in ["e3m2", "e2m3", "e2m1"] for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type))
+        and any(
+            x in ["e3m2", "e2m3", "e2m1"]
+            for x in (op.a.mma_type.ptx_type, op.b.mma_type.ptx_type)
+        )
     ):
         return False
 
@@ -1086,12 +1090,8 @@ def gen_wmma_mma_tests():
 
 
 def gen_mma_tests():
-    mma_intrinsic_template = (
-        "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${kind}${satf}.${intrinsic_signature}"
-    )
-    mma_instruction_template = (
-        "mma.sync${aligned}.${geom}.${alayout}.${blayout}${kind}${satf}.${ptx_signature}${b1op}"
-    )
+    mma_intrinsic_template = "llvm.nvvm.mma${b1op}.${geom}.${alayout}.${blayout}${kind}${satf}.${intrinsic_signature}"
+    mma_instruction_template = "mma.sync${aligned}.${geom}.${alayout}.${blayout}${kind}${satf}.${ptx_signature}${b1op}"
 
     generated_items = []
 

>From 2f65fef05ad58777d79d8e355fe794c3b4fe390b Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Thu, 4 Sep 2025 15:37:45 +0200
Subject: [PATCH 3/5] [NVPTX] Updated a check for ptx and sm versions.
 PR156040.

---
 llvm/test/CodeGen/NVPTX/wmma.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index aeddda812432d..8427ae4ad72da 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -583,7 +583,7 @@ def is_mma_variant_supported(op, layout_a, layout_b, kind, satf):
     ):
         return False
 
-    if kind != "" and (ptx_version < 87 or gpu_arch < 120 or not aa):
+    if kind != "" and not (ptx_version >= 87 and gpu_arch >= 120 and aa):
         return False
 
     if kind != "" and (

>From 2b34832920b6cdf36479f5b2a021450445f6e528 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Tue, 16 Sep 2025 13:05:40 +0200
Subject: [PATCH 4/5] [NVPTX] ptxas features have been aligned with the latest
 ones. PR156040.

---
 llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
index ae781df0116fd..40055ae519fc4 100644
--- a/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx87-sm120a.py
@@ -2,7 +2,7 @@
 # RUN: %python %s --ptx=87 --gpu-arch=120 --aa > %t-ptx87-sm_120a.ll
 # RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
 # RUN:           | FileCheck %t-ptx87-sm_120a.ll
-# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: %if ptxas-sm_120a && ptxas-isa-8.7 %{                                  \
 # RUN: llc < %t-ptx87-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx87 \
 # RUN:           | %ptxas-verify -arch=sm_120a                              \
 # RUN: %}

>From 3e6c7f8d777994141f64954a4169b04aede3dc67 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Fri, 26 Sep 2025 13:55:20 +0200
Subject: [PATCH 5/5] [NVPTX] Moved unsupported MLIR MMA tests to invalid.mlir.
 PR156040.

---
 mlir/test/Dialect/LLVMIR/invalid.mlir | 28 +++++++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index 4394786db5a5d..5f741ed775891 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -743,6 +743,34 @@ func.func @nvvm_invalid_mma_8(%a0 : i32, %a1 : i32,
 
 // -----
 
+// f32 return type, f16 accumulate type
+llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+                                     %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+                                     %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+                                     %c0 : vector<2xf16>, %c1 : vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)> {
+  // C and D should have the same type according to PTX ISA
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(f32, f32, f32, f32)>
+  llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
+}
+
+// -----
+
+// f16 return type, f32 accumulate type
+llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+                                     %a2 : vector<2xf16>, %a3 : vector<2xf16>,
+                                     %b0 : vector<2xf16>, %b1 : vector<2xf16>,
+                                     %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> {
+  // C and D should have the same type according to PTX ISA
+  %0 = nvvm.mma.sync A[%a0, %a1, %a2, %a3] B[%b0, %b1] C[%c0, %c1, %c2, %c3]
+    {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
+     shape = #nvvm.shape<m = 16, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+}
+
+// -----
+
 func.func @atomicrmw_mismatched_operands(%f32_ptr : !llvm.ptr, %f32 : f32) {
   // expected-error at +1 {{op failed to verify that result #0 and operand #1 have the same type}}
   %0 = "llvm.atomicrmw"(%f32_ptr, %f32) {bin_op=11, ordering=1} : (!llvm.ptr, f32) -> i32



More information about the Mlir-commits mailing list