[llvm] [NVPTX] Support for dense and sparse MMA intrinsics with block scaling. (PR #163561)

Kirill Vedernikov via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 18 22:27:33 PST 2025


https://github.com/kvederni updated https://github.com/llvm/llvm-project/pull/163561

From 961cc0fa37b93af3e5b0a5c23f787d728b527860 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 15 Oct 2025 15:22:10 +0200
Subject: [PATCH 1/3] [NVPTX] Support for dense and sparse MMA intrinsics with
 block scaling.

---
 llvm/include/llvm/IR/IntrinsicsNVVM.td   | 188 +++++++++++++
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 130 ++++++++-
 llvm/test/CodeGen/NVPTX/wmma.py          | 342 ++++++++++++++++++++++-
 3 files changed, 657 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 1b485dc8ccd1e..2a8d310b94065 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -277,6 +277,10 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
       !eq(gft,"m16n8k32:d:f16") : !listsplat(llvm_v2f16_ty, 2),
       !eq(gft,"m16n8k32:d:f32") : !listsplat(llvm_float_ty, 4),
 
+      // mma.block_scale e2m1 (mxf4, mxf4nvf4) -> f32 @ m16n8k64
+      !eq(gft,"m16n8k64:c:f32") : !listsplat(llvm_float_ty, 4),
+      !eq(gft,"m16n8k64:d:f32") : !listsplat(llvm_float_ty, 4),
+
       // wmma fp16 -> fp16/fp32 @  m16n16k16/m8n32k16/m32n8k16
       // All other supported geometries use the same fragment format for f32 and
       // f16, so we only need to consider {fragment, type}.
@@ -520,6 +524,18 @@ class MMA_NAME<string ALayout, string BLayout, int Satfinite, string b1op, strin
                   # signature;
 }
 
+class MMA_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+                           WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string record_name = "int_nvvm_mma_block_scale"
+                  # "_" # A.geom
+                  # "_row_col"
+                  # "_" # Kind
+                  # !subst(".", "_", ScaleVecSize)
+                  # signature
+                  # "_" # SType;
+}
+
 class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
                   WMMA_REGS A, WMMA_REGS B,
                   WMMA_REGS C, WMMA_REGS D> {
@@ -533,6 +549,19 @@ class MMA_SP_NAME<string Metadata, string Kind, int Satfinite,
                   # signature;
 }
 
+class MMA_SP_BLOCK_SCALE_NAME<string Kind, string SType, string ScaleVecSize,
+                              WMMA_REGS A, WMMA_REGS B,
+                              WMMA_REGS C, WMMA_REGS D> {
+  string signature = MMA_SIGNATURE<A, B, C, D>.ret;
+  string record_name = "int_nvvm_mma_sp_ordered_metadata_block_scale"
+                  # "_" # A.geom
+                  # "_row_col"
+                  # "_" # Kind
+                  # !subst(".", "_", ScaleVecSize)
+                  # signature
+                  # "_" # SType;
+}
+
 // Helper class that takes an intrinsic name and construct a record name.
 // Additionally, sets `intr_name` to be non-empty if the default name assigned
 // to this intrinsic will not match the name given.
@@ -683,6 +712,18 @@ class NVVM_MMA_OPS {
             fp_mma_ops, fp8_mma_ops, f8f6f4_mma_ops,
             int_mma_ops, subint_mma_ops, bit_mma_ops);
 
+  list<list<WMMA_REGS>> mxf4_mma_ops = MMA_OPS<
+    ["m16n8k64"], ["e2m1"], ["e2m1"], ["f32"], ["f32"]
+  >.ret;
+
+  list<list<WMMA_REGS>> mxf8f6f4_mma_ops = MMA_OPS<
+    ["m16n8k32"], ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"],
+                  ["e2m1", "e2m3", "e3m2", "e5m2", "e4m3"], ["f32"], ["f32"]
+  >.ret;
+
+  list<list<WMMA_REGS>> all_mma_block_scale_ops = !listconcat(
+            mxf4_mma_ops, mxf8f6f4_mma_ops);
+
   list<list<WMMA_REGS>> bf16_mma_sp_ops = MMA_OPS<
             ["m16n8k16", "m16n8k32"],
             ["bf16"], [], ["f32"], [], true>.ret;
@@ -707,6 +748,18 @@ class NVVM_MMA_OPS {
             bf16_mma_sp_ops, tf32_mma_sp_ops, fp_mma_sp_ops, fp8_mma_sp_ops,
             subint_mma_sp_ops, int_mma_sp_ops);
 
+  // Combines the available geometries and types for the mxf4 and mxf4nvf4 kinds.
+  list<list<WMMA_REGS>> mxf4xx_mma_sp_ops = MMA_OPS<
+            ["m16n8k128"],
+            ["e2m1"], ["e2m1"], ["f32"], [], true>.ret;
+  list<list<WMMA_REGS>> mxf8f6f4_mma_sp_ops = MMA_OPS<
+            ["m16n8k64"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f32"], [], true>.ret;
+  list<list<WMMA_REGS>> all_mma_sp_block_scale_ops = !listconcat(
+            mxf4xx_mma_sp_ops, mxf8f6f4_mma_sp_ops);
+
   list<WMMA_REGS> ldst_ab_ops = MMA_LDST_OPS<
             ["m16n16k16", "m32n8k16", "m8n32k16"],
             ["a", "b"], ["f16", "u8", "s8", "bf16"]>.ret;
@@ -900,6 +953,32 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
   );
 }
 
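+// Returns true if this combination of kind/scale_vec_size/stype
+// is supported for dense MMA block-scale ops; false otherwise.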
+class NVVM_MMA_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind, string stype, string scale_vec_size> {
+  string geom = frags[0].geom;
+
+  bit ret = !cond(
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf4"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_2x")),
+         !eq(stype, "ue8m0")) : true,
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(scale_vec_size, ".scale_2x"),
+         !eq(stype, "ue8m0")) : true,
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(scale_vec_size, ".scale_4x"),
+         !eq(stype, "ue4m3")) : true,
+    !and(!eq(geom, "m16n8k32"),
+         !eq(kind, "mxf8f6f4"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_1x")),
+         !eq(stype, "ue8m0")) : true,
+    true: false
+  );
+}
+
 // Returns true if the fragment is valid for ldmatrix ops;
 // false otherwise.
 // E.g.
@@ -998,6 +1077,51 @@ class NVVM_MMA_SP_SUPPORTED<list<WMMA_REGS> frags, string metadata,
 }
 
 
+// Returns true if this combination of kind/scale_vec_size/stype
+// is supported for sparse MMA (MMA.SP) block-scale ops;
+// false otherwise.
+// E.g.
+// if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<...>.ret then
+//   def : FOO<>; // The record will only be defined for supported ops.
+//
+class NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<list<WMMA_REGS> frags, string kind,
+                                        string stype, string scale_vec_size> {
+  // Element types and geometry of the fragments under test.
+  string a_type = frags[0].ptx_elt_type;
+  string b_type = frags[1].ptx_elt_type;
+  string c_type = frags[2].ptx_elt_type;
+  string d_type = frags[3].ptx_elt_type;
+  string geom = frags[0].geom;
+
+  bit ret = !cond(
+    !and(!eq(geom, "m16n8k128"),
+         !eq(kind, "mxf4"),
+         !eq(stype, "ue8m0"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_2x"))): true,
+
+    !and(!eq(geom, "m16n8k128"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(stype, "ue8m0"),
+         !eq(scale_vec_size, ".scale_2x")): true,
+
+    !and(!eq(geom, "m16n8k128"),
+         !eq(kind, "mxf4nvf4"),
+         !eq(stype, "ue4m3"),
+         !eq(scale_vec_size, ".scale_4x")): true,
+
+    !and(!eq(geom, "m16n8k64"),
+         !eq(kind, "mxf8f6f4"),
+         !eq(stype, "ue8m0"),
+         !or(!eq(scale_vec_size, ""),
+             !eq(scale_vec_size, ".scale_1x"))): true,
+
+    // All other combinations are not supported.
+    true: false
+  );
+}
+
+
 class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
   string Suffix = !if(sync, "sync_", "")
                   # mode # "_"
@@ -2415,6 +2539,31 @@ foreach layout_a = ["row", "col"] in {
   } // layout_b
 } // layout_a
 
+class NVVM_MMA_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+  : Intrinsic<D.regs,
+              !listconcat(A.regs, B.regs, C.regs,
+              [
+                llvm_i32_ty,              // scale-a-data
+                llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
+                llvm_i32_ty,              // scale-b-data,
+                llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
+              ]),
+              [IntrNoMem, IntrNoCallback]>;
+
+foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+  foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+    foreach stype = ["ue8m0", "ue4m3"] in {
+      foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
+        if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+          def MMA_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
+                                   op[0], op[1], op[2], op[3]>.record_name
+            : NVVM_MMA_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
+        }
+      } // op
+    } // stype
+  } // scale_vec_size
+} // kind
+
 // MMA.SP
 class NVVM_MMA_SP<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
   : Intrinsic<D.regs,
@@ -2462,6 +2611,45 @@ foreach metadata = ["sp", "sp::ordered_metadata"] in {
   } // kind
 } // metadata
 
+// MMA.SP BLOCK SCALE
+class NVVM_MMA_SP_BLOCK_SCALE<WMMA_REGS A, WMMA_REGS B, WMMA_REGS C, WMMA_REGS D>
+  : Intrinsic<D.regs,
+              !listconcat(A.regs, B.regs, C.regs,
+              [
+                llvm_i32_ty, // metadata
+                llvm_i32_ty, // sparsity selector
+                llvm_i32_ty, // scale-a-data
+                llvm_i16_ty, llvm_i16_ty, // byte-id-a, thread-id-a
+                llvm_i32_ty, // scale-b-data
+                llvm_i16_ty, llvm_i16_ty, // byte-id-b, thread-id-b
+              ])> {
+    int pos = !size(!listconcat(A.regs, B.regs, C.regs, [llvm_i32_ty]));
+
+    // The range [0, num_threads) constrains the sparsity selector, which
+    // indicates the threads that contribute metadata.
+    // According to PTX ISA 9.0, the sparsity selector is always 0 for
+    // sparse MMA block-scale instructions.
+    int num_threads = 1;
+    let IntrProperties = [IntrNoMem, IntrNoCallback, ImmArg<ArgIndex<pos>>,
+                          Range<ArgIndex<pos>, 0, num_threads>];
+}
+
+// According to PTX ISA 9.0, only a_layout = ["row"], b_layout = ["col"], and
+// spvariant = ["sp::ordered_metadata"] are supported.
+foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+  foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+    foreach stype = ["ue8m0", "ue4m3"] in {
+      foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
+        if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+          def MMA_SP_BLOCK_SCALE_NAME<kind, stype, scale_vec_size,
+                                      op[0], op[1], op[2], op[3]>.record_name
+            : NVVM_MMA_SP_BLOCK_SCALE<op[0], op[1], op[2], op[3]>;
+        }
+      } // op
+    } // stype
+  } // scale_vec_size
+} // kind
+
 // LDMATRIX
 class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
   : Intrinsic<Frag.regs, [llvm_anyptr_ty],
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index ea69a54e6db37..5ef13b3be6162 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -5022,6 +5022,67 @@ defset list<WMMA_INSTR> MMAs  = {
 } // defset
 }
 
+// MMA.block_scale
+class MMA_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+                      WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+                      string Kind, string SType, string ScaleVecSize>
+  : WMMA_INSTR<MMA_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+                                    FragA, FragB, FragC, FragD>.record,
+                                    [FragA.Ins, FragB.Ins, FragC.Ins,
+                                     (ins B32:$scale_a, B16:$byte_id_a,
+                                          B16:$thread_id_a, B32:$scale_b,
+                                          B16:$byte_id_b, B16:$thread_id_b)]>,
+    // Requires does not seem to have an effect on an Instruction without Patterns.
+    // We set it here anyway and propagate it to the Pat<> we construct below.
+  Requires<FragA.Predicates> {
+  let OutOperandList = FragD.Outs;
+  let InOperandList  = !con(Args, (ins MmaCode:$ptx));
+  string TypeList = !interleave([FragD.ptx_elt_type,
+                                 FragA.ptx_elt_type,
+                                 FragB.ptx_elt_type,
+                                 FragC.ptx_elt_type], ".");
+  string ScaleVecSizeStr = !cond(
+    !eq(ScaleVecSize, "") : "",
+    !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+    !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+    !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+  );
+  let AsmString = "mma.sync.aligned."
+                  # FragA.geom
+                  # ".row.col"
+                  # ".kind::" # Kind
+                  # ".block_scale"
+                  # ScaleVecSizeStr
+                  # "." # TypeList
+                  # "." # SType # " \n\t\t"
+                  # FragD.regstring # ",\n\t\t"
+                  # FragA.regstring # ",\n\t\t"
+                  # FragB.regstring # ",\n\t\t"
+                  # FragC.regstring # ",\n\t\t"
+                  # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+                  # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_BLOCK_SCALEs  = {
+  foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+    foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+      foreach stype = ["ue8m0", "ue4m3"] in {
+        foreach op = NVVM_MMA_OPS.all_mma_block_scale_ops in {
+          if NVVM_MMA_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+            def : MMA_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.block_scale", "", kind>,
+                                  WMMA_REGINFO<op[1], "mma.block_scale", "", kind>,
+                                  WMMA_REGINFO<op[2], "mma.block_scale", "", kind>,
+                                  WMMA_REGINFO<op[3], "mma.block_scale", "", kind>,
+                                  kind, stype, scale_vec_size>;
+          }
+        } // op
+      } // stype
+    } // scale_vec_size
+  } // kind
+} // defset
+}
+
 // MMA SP
 class MMA_SP<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
              WMMA_REGINFO FragC, WMMA_REGINFO FragD,
@@ -5078,6 +5139,72 @@ defset list<WMMA_INSTR> MMA_SPs = {
 } // defset
 }
 
+// MMA SP BLOCK SCALE
+class MMA_SP_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
+                         WMMA_REGINFO FragC, WMMA_REGINFO FragD,
+                         string Kind, string SType, string ScaleVecSize>
+  : WMMA_INSTR<MMA_SP_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
+                                       FragA, FragB, FragC, FragD>.record,
+               [FragA.Ins, FragB.Ins, FragC.Ins,
+                (ins B32:$metadata, i32imm:$selector,
+                     B32:$scale_a, B16:$byte_id_a, B16:$thread_id_a,
+                     B32:$scale_b, B16:$byte_id_b, B16:$thread_id_b)]>,
+    // Requires does not seem to have an effect on an Instruction without Patterns.
+    // We set it here anyway and propagate it to the Pat<> we construct below.
+    Requires<!listconcat(FragA.Predicates,
+                         FragB.Predicates,
+                         FragC.Predicates,
+                         FragD.Predicates)> {
+  let OutOperandList = FragD.Outs;
+  let InOperandList = !con(Args, (ins MmaCode:$ptx));
+  string TypeList = "." # FragD.ptx_elt_type
+                    # "." # FragA.ptx_elt_type
+                    # "." # FragB.ptx_elt_type
+                    # "." # FragC.ptx_elt_type;
+  string ScaleVecSizeStr = !cond(
+    !eq(ScaleVecSize, "") : "",
+    !eq(ScaleVecSize, ".scale_1x") : ".scale_vec::1X",
+    !eq(ScaleVecSize, ".scale_2x") : ".scale_vec::2X",
+    !eq(ScaleVecSize, ".scale_4x") : ".scale_vec::4X"
+  );
+  let AsmString = "mma.sp::ordered_metadata.sync.aligned."
+                  # FragA.geom
+                  # ".row.col"
+                  # ".kind::" # Kind
+                  # ".block_scale"
+                  # ScaleVecSizeStr
+                  # TypeList
+                  # "." # SType # "\n\t\t"
+                  # FragD.regstring # ",\n\t\t"
+                  # FragA.regstring # ",\n\t\t"
+                  # FragB.regstring # ",\n\t\t"
+                  # FragC.regstring # ",\n\t\t"
+                  # "$metadata" # ",\n\t\t"
+                  # "$selector" # ",\n\t\t"
+                  # "$scale_a, {{$byte_id_a, $thread_id_a}}" # ",\n\t\t"
+                  # "$scale_b, {{$byte_id_b, $thread_id_b}};";
+}
+
+let isConvergent = true in {
+defset list<WMMA_INSTR> MMA_SP_BLOCK_SCALEs = {
+  foreach kind = ["mxf4", "mxf4nvf4", "mxf8f6f4"] in {
+    foreach scale_vec_size = ["", ".scale_1x", ".scale_2x", ".scale_4x"] in {
+      foreach stype = ["ue8m0", "ue4m3"] in {
+      foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
+        if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
+          def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp", "sp::ordered_metadata", kind>,
+                       WMMA_REGINFO<op[1], "mma.sp", "sp::ordered_metadata", kind>,
+                       WMMA_REGINFO<op[2], "mma.sp", "sp::ordered_metadata", kind>,
+                       WMMA_REGINFO<op[3], "mma.sp", "sp::ordered_metadata", kind>,
+                       kind, stype, scale_vec_size>;
+        }
+      } // op
+      } // stype
+    } // scale_vec_size
+  } // kind
+} // defset
+}
+
 //
 // ldmatrix.sync.aligned.m8n8[|.trans][|.shared].b16
 //
@@ -5159,7 +5286,8 @@ class MMA_PAT<WMMA_INSTR wi>
         Requires<wi.Predicates>;
 
 // Build intrinsic->instruction patterns for all MMA instructions.
-foreach mma = !listconcat(MMAs, WMMAs, MMA_LDSTs, LDMATRIXs, STMATRIXs, MMA_SPs) in
+foreach mma = !listconcat(MMAs, MMA_BLOCK_SCALEs, WMMAs, MMA_LDSTs, LDMATRIXs,
+                          STMATRIXs, MMA_SPs, MMA_SP_BLOCK_SCALEs) in
   def : MMA_PAT<mma>;
 
 multiclass MAPA<string suffix, Intrinsic Intr> {
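By analogy with the operand list of NVVM_MMA_SP_BLOCK_SCALE above, one sparse
instantiation (m16n8k128, kind mxf4, .scale_2x, stype ue8m0) should look
roughly like the declaration below. Again a sketch, not generated output: the
fragment register counts are assumptions, and the sparsity selector must be
the immediate 0 per the ImmArg/Range properties set on the intrinsic.

  declare { float, float, float, float }
    @llvm.nvvm.mma.sp.ordered.metadata.block.scale.m16n8k128.row.col.mxf4.scale.2x.e2m1.ue8m0(
      i32, i32, i32, i32,          ; a fragment (2:4-compressed e2m1; count assumed)
      i32, i32, i32, i32,          ; b fragment (packed e2m1; count assumed)
      float, float, float, float,  ; c fragment
      i32,                         ; metadata
      i32,                         ; sparsity selector (immediate; must be 0)
      i32, i16, i16,               ; scale-a-data, byte-id-a, thread-id-a
      i32, i16, i16)               ; scale-b-data, byte-id-b, thread-id-b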
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 8427ae4ad72da..81c78219075f3 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -131,7 +131,7 @@ def __init__(self, geom, frag, ptx_elt_type, is_mma_sparse=False):
             "m16n8k64:b:e5m2": 4,
             "m16n8k64:b:e3m2": 4,
             "m16n8k64:b:e2m3": 4,
-            "m16n8k64:b:e2m1": 4,
+            "m16n8k64:b:e2m1": 4 if is_mma_sparse else 2,
             "m16n8k64:c:f16": 2,
             "m16n8k64:c:f32": 4,
             "m16n8k64:d:f16": 2,
@@ -1131,6 +1131,163 @@ def gen_mma_tests():
     return generated_items
 
 
+def get_mma_block_scale_ops():
+    return (
+        make_mma_ops(["m16n8k64"], ["e2m1"], [], ["f32"], [])
+        + make_mma_ops(
+            ["m16n8k32"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f32"],
+            [],
+        )
+    )
+
+
+def is_mma_block_scale_geom_supported(geom):
+    # geometries for FP.
+    if geom in [
+        "m16n8k32",
+        "m16n8k64",
+    ]:
+        return True
+    raise ValueError(f"Unexpected MMA block scale geometry: {geom}")
+
+
+def is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+    if not (
+        is_type_supported(op.a.mma_type.ptx_type)
+        and is_mma_block_scale_geom_supported(op.a.geom)
+    ):
+        return False
+
+    if (
+        op.a.geom == "m16n8k64"
+        and kind == "mxf4"
+        and stype == "ue8m0"
+        and scale_vec_size in ["", ".scale_vec::2X"]
+    ):
+        return True
+
+    if (
+        op.a.geom == "m16n8k64"
+        and kind == "mxf4nvf4"
+        and stype == "ue8m0"
+        and scale_vec_size == ".scale_vec::2X"
+    ):
+        return True
+
+    if (
+        op.a.geom == "m16n8k64"
+        and kind == "mxf4nvf4"
+        and stype == "ue4m3"
+        and scale_vec_size == ".scale_vec::4X"
+    ):
+        return True
+
+    if (
+        op.a.geom == "m16n8k32"
+        and kind == "mxf8f6f4"
+        and stype == "ue8m0"
+        and scale_vec_size in ["", ".scale_vec::1X"]
+    ):
+        return True
+
+    return False
+
+
+def common_mma_block_scale_test_gen(params, op, intrinsic_template, instruction_template):
+    mma_block_scale_template = """
+declare ${ret_ty} @${intrinsic}(
+        ${args});
+
+; CHECK-LABEL: .func {{.*}}test_${function}(
+define ${ret_ty} @test_${function}(
+        ${args}) {
+; CHECK: ${instruction}
+; CHECK-NEXT: ${check_d}
+; CHECK-NEXT: ${check_a}
+; CHECK-NEXT: ${check_b}
+; CHECK-NEXT: ${check_c}
+; CHECK-NEXT: ${check_scale_a_data}
+; CHECK-NEXT: ${check_byte_id_a}
+; CHECK-NEXT: ${check_thread_id_a}
+; CHECK-NEXT: ${check_scale_b_data}
+; CHECK-NEXT: ${check_byte_id_b}
+; CHECK-NEXT: ${check_thread_id_b}
+  %r = call ${ret_ty} @${intrinsic}(
+        ${args});
+  ret ${ret_ty} %r;
+}
+"""
+
+    test_params = params
+    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+    test_params["function"] = test_params["intrinsic"].replace(".", "_")
+    test_params["instruction"] = Template(instruction_template).substitute(params)
+    test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
+    test_params["check_a"] = check_pattern(op.a)
+    test_params["check_b"] = check_pattern(op.b)
+    test_params["check_c"] = check_pattern(op.c)
+    test_params["check_d"] = check_pattern(op.d)
+    test_params["check_scale_a_data"] = "{{%r[0-9]+}}"
+    test_params["check_byte_id_a"] = "{{%r[0-9]+}}"
+    test_params["check_thread_id_a"] = "{{%r[0-9]+}}"
+    test_params["check_scale_b_data"] = "{{%r[0-9]+}}"
+    test_params["check_byte_id_b"] = "{{%r[0-9]+}}"
+    test_params["check_thread_id_b"] = "{{%r[0-9]+}}"
+    args = ",\n        ".join(
+        list(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c))
+        + ["i32 %scale_a_data", "i16 %byte_id_a, i16 %thread_id_a"]
+        + ["i32 %scale_b_data", "i16 %byte_id_b, i16 %thread_id_b"]
+    )
+    test_params["args"] = args
+    print(Template(mma_block_scale_template).substitute(test_params))
+    return (test_params["intrinsic"], test_params["instruction"])
+
+
+def gen_mma_block_scale_tests():
+    if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
+        return []
+
+    mma_block_scale_intrinsic_template = (
+        "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
+    )
+    mma_block_scale_instruction_template = (
+        "mma.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
+    )
+
+    generated_items = []
+
+    for op, kind, scale_vec_size, stype in product(
+        get_mma_block_scale_ops(),
+        ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+        ["", ".scale_vec::1X", ".scale_vec::2X", ".scale_vec::4X"],
+        ["ue8m0", "ue4m3"],
+    ):
+        if not is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+            continue
+
+        params = {
+            "intrinsic_signature": mma_signature(op),
+            "ptx_signature": mma_ptx_signature(op),
+            "geom": op.a.geom,
+            "kind": kind,
+            "scale_vec_size": scale_vec_size,
+            "scale": scale_vec_size.replace("_vec::", ".").lower(),
+            "stype": stype,
+        }
+
+        intrinsic_template = mma_block_scale_intrinsic_template
+        instruction_template = mma_block_scale_instruction_template
+
+        generated_items.append(
+            common_mma_block_scale_test_gen(params, op, intrinsic_template, instruction_template)
+        )
+
+    return generated_items
+
+
 def get_mma_sp_ops():
     return (
         make_mma_ops(["m16n8k16", "m16n8k32"], ["bf16"], [], ["f32"], [], True)
@@ -1224,7 +1381,11 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf):
     return True
 
 
-def sp_selector_gen(op):
+def sp_selector_gen(op, block_scale = False):
+    if block_scale:
+        # Per PTX ISA 9.0, the sparsity selector must always be 0.
+        return range(1)
+
     # (geom, type) -> allowed selector range
     range_01 = {
         ("m16n8k32", "bf16"),
@@ -1355,6 +1516,181 @@ def gen_mma_sp_tests():
     return generated_items
 
 
+def get_mma_sp_block_scale_ops():
+    return (
+        make_mma_ops(["m16n8k128"], ["e2m1"], [], ["f32"], [], True)
+        + make_mma_ops(
+            ["m16n8k64"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+            ["f32"],
+            [],
+            True,
+        )
+    )
+
+
+def is_mma_sp_block_scale_geom_supported(geom):
+    # geometries for FP.
+    if geom in [
+        "m16n8k64",
+        "m16n8k128",
+    ]:
+        return True
+    raise ValueError(f"Unexpected sparse MMA block scale geometry: {geom}")
+
+
+def is_mma_sp_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+    if not (
+        is_type_supported(op.a.mma_type.ptx_type)
+        and is_mma_sp_block_scale_geom_supported(op.a.geom)
+    ):
+        return False
+
+    if (
+        op.a.geom == "m16n8k128"
+        and kind == "mxf4"
+        and stype == "ue8m0"
+        and scale_vec_size in ["", ".scale_vec::2X"]
+    ):
+        return True
+
+    if (
+        op.a.geom == "m16n8k128"
+        and kind == "mxf4nvf4"
+        and stype == "ue8m0"
+        and scale_vec_size == ".scale_vec::2X"
+    ):
+        return True
+
+    if (
+        op.a.geom == "m16n8k128"
+        and kind == "mxf4nvf4"
+        and stype == "ue4m3"
+        and scale_vec_size == ".scale_vec::4X"
+    ):
+        return True
+
+    if (
+        op.a.geom == "m16n8k64"
+        and kind == "mxf8f6f4"
+        and stype == "ue8m0"
+        and scale_vec_size in ["", ".scale_vec::1X"]
+    ):
+        return True
+
+    return False
+
+
+def common_mma_sp_block_scale_test_gen(params, op, intrinsic_template, instruction_template):
+    mma_sp_block_scale_decl_template = """
+declare ${ret_ty} @${intrinsic}(
+        ${args});
+"""
+
+    mma_sp_block_scale_test_template = """
+; CHECK-LABEL: .func {{.*}}test_${function}_${selector}(
+define ${ret_ty} @test_${function}_${selector}(
+        ${args}) {
+; CHECK: ${instruction}
+; CHECK-NEXT: ${check_d}
+; CHECK-NEXT: ${check_a}
+; CHECK-NEXT: ${check_b}
+; CHECK-NEXT: ${check_c}
+; CHECK-NEXT: ${check_metadata}
+; CHECK-NEXT: ${check_selector}
+; CHECK-NEXT: ${check_scale_a_data}
+; CHECK-NEXT: ${check_byte_id_a}
+; CHECK-NEXT: ${check_thread_id_a}
+; CHECK-NEXT: ${check_scale_b_data}
+; CHECK-NEXT: ${check_byte_id_b}
+; CHECK-NEXT: ${check_thread_id_b}
+  %r = call ${ret_ty} @${intrinsic}(
+        ${call_args});
+  ret ${ret_ty} %r;
+}
+"""
+
+    test_params = params
+    test_params["intrinsic"] = Template(intrinsic_template).substitute(params)
+    test_params["function"] = test_params["intrinsic"].replace(".", "_")
+    test_params["instruction"] = Template(instruction_template).substitute(params)
+    test_params["ret_ty"] = make_wmma_ld_ret_ty(op.d)
+    test_params["check_a"] = check_pattern(op.a)
+    test_params["check_b"] = check_pattern(op.b)
+    test_params["check_c"] = check_pattern(op.c)
+    test_params["check_d"] = check_pattern(op.d)
+    test_params["check_metadata"] = "{{%r[0-9]+}}"
+    test_params["check_scale_a_data"] = "{{%r[0-9]+}}"
+    test_params["check_byte_id_a"] = "{{%r[0-9]+}}"
+    test_params["check_thread_id_a"] = "{{%r[0-9]+}}"
+    test_params["check_scale_b_data"] = "{{%r[0-9]+}}"
+    test_params["check_byte_id_b"] = "{{%r[0-9]+}}"
+    test_params["check_thread_id_b"] = "{{%r[0-9]+}}"
+    args = ",\n        ".join(
+        list(make_wmma_slice_args(frag) for frag in (op.a, op.b, op.c))
+        + ["i32 %metadata", "i32 %selector"]
+        + ["i32 %scale_a_data", "i16 %byte_id_a, i16 %thread_id_a"]
+        + ["i32 %scale_b_data", "i16 %byte_id_b, i16 %thread_id_b"]
+    )
+    test_params["args"] = args
+
+    print(Template(mma_sp_block_scale_decl_template).substitute(test_params))
+
+    for selector in [str(r) for r in sp_selector_gen(op, True)]:
+        test_params["selector"] = selector
+        test_params["check_selector"] = "{{" + test_params["selector"] + "}}"
+        test_params["call_args"] = test_params["args"].replace(
+            "%selector", test_params["selector"]
+        )
+
+        print(Template(mma_sp_block_scale_test_template).substitute(test_params))
+
+    return (test_params["intrinsic"], test_params["instruction"])
+
+
+def gen_mma_sp_block_scale_tests():
+    if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
+        return []
+
+    mma_sp_block_scale_intrinsic_template = (
+        "llvm.nvvm.mma.sp.ordered.metadata.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
+    )
+    mma_sp_block_scale_instruction_template = (
+        "mma.sp::ordered_metadata.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
+    )
+
+    generated_items = []
+
+    for op, kind, scale_vec_size, stype in product(
+        get_mma_sp_block_scale_ops(),
+        ["mxf4", "mxf4nvf4", "mxf8f6f4"],
+        ["", ".scale_vec::1X", ".scale_vec::2X", ".scale_vec::4X"],
+        ["ue8m0", "ue4m3"],
+    ):
+        if not is_mma_sp_block_scale_variant_supported(op, kind, scale_vec_size, stype):
+            continue
+
+        params = {
+            "intrinsic_signature": mma_signature(op),
+            "ptx_signature": mma_ptx_signature(op),
+            "geom": op.a.geom,
+            "kind": kind,
+            "scale_vec_size": scale_vec_size,
+            "scale": scale_vec_size.replace("_vec::", ".").lower(),
+            "stype": stype,
+        }
+
+        intrinsic_template = mma_sp_block_scale_intrinsic_template
+        instruction_template = mma_sp_block_scale_instruction_template
+
+        generated_items.append(
+            common_mma_sp_block_scale_test_gen(params, op, intrinsic_template, instruction_template)
+        )
+
+    return generated_items
+
+
 # Append complete list of intrinsics and instructions we've generated tests for.
# Generate a set of checks to verify that we did generate a sensible set of
 # tests for the given combination of PTX and SM variants.
@@ -1545,7 +1881,9 @@ def gen_tests():
     items += gen_stmatrix_tests()
     items += gen_wmma_mma_tests()
     items += gen_mma_tests()
+    items += gen_mma_block_scale_tests()
     items += gen_mma_sp_tests()
+    items += gen_mma_sp_block_scale_tests()
     gen_check_unsupported_ops(items)
 
 
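End to end, the generator changes above should then emit tests of roughly the
following shape for the dense mxf4/ue8m0 case. This is hand-assembled from the
templates and the TableGen AsmString, not actual generator output; the
argument names and fragment register counts are illustrative.

  declare { float, float, float, float }
    @llvm.nvvm.mma.block.scale.m16n8k64.row.col.mxf4.e2m1.ue8m0(
      i32, i32, i32, i32, i32, i32, float, float, float, float,
      i32, i16, i16, i32, i16, i16)

  ; CHECK-LABEL: .func {{.*}}test_llvm_nvvm_mma_block_scale_m16n8k64_row_col_mxf4_e2m1_ue8m0(
  define { float, float, float, float }
      @test_llvm_nvvm_mma_block_scale_m16n8k64_row_col_mxf4_e2m1_ue8m0(
        i32 %a0, i32 %a1, i32 %a2, i32 %a3, i32 %b0, i32 %b1,
        float %c0, float %c1, float %c2, float %c3,
        i32 %scale_a_data, i16 %byte_id_a, i16 %thread_id_a,
        i32 %scale_b_data, i16 %byte_id_b, i16 %thread_id_b) {
  ; CHECK: mma.sync.aligned.m16n8k64.row.col.kind::mxf4.block_scale.f32.e2m1.e2m1.f32.ue8m0
    %r = call { float, float, float, float }
        @llvm.nvvm.mma.block.scale.m16n8k64.row.col.mxf4.e2m1.ue8m0(
          i32 %a0, i32 %a1, i32 %a2, i32 %a3, i32 %b0, i32 %b1,
          float %c0, float %c1, float %c2, float %c3,
          i32 %scale_a_data, i16 %byte_id_a, i16 %thread_id_a,
          i32 %scale_b_data, i16 %byte_id_b, i16 %thread_id_b)
    ret { float, float, float, float } %r
  }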

From 7fb0570f71c605c8a1f14aa5c2bc62b27cb7552f Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Wed, 15 Oct 2025 16:01:23 +0200
Subject: [PATCH 2/3] [NVPTX] Fixes for code formatting. PR163561.

---
 llvm/test/CodeGen/NVPTX/wmma.py | 66 +++++++++++++++------------------
 1 file changed, 30 insertions(+), 36 deletions(-)

diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index 81c78219075f3..c2e69dd48da99 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -1132,15 +1132,12 @@ def gen_mma_tests():
 
 
 def get_mma_block_scale_ops():
-    return (
-        make_mma_ops(["m16n8k64"], ["e2m1"], [], ["f32"], [])
-        + make_mma_ops(
-            ["m16n8k32"],
-            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
-            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
-            ["f32"],
-            [],
-        )
+    return make_mma_ops(["m16n8k64"], ["e2m1"], [], ["f32"], []) + make_mma_ops(
+        ["m16n8k32"],
+        ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+        ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+        ["f32"],
+        [],
     )
 
 
@@ -1196,7 +1193,9 @@ def is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype):
     return False
 
 
-def common_mma_block_scale_test_gen(params, op, intrinsic_template, instruction_template):
+def common_mma_block_scale_test_gen(
+    params, op, intrinsic_template, instruction_template
+):
     mma_block_scale_template = """
 declare ${ret_ty} @${intrinsic}(
         ${args});
@@ -1250,12 +1249,8 @@ def gen_mma_block_scale_tests():
     if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
         return []
 
-    mma_block_scale_intrinsic_template = (
-        "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
-    )
-    mma_block_scale_instruction_template = (
-        "mma.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
-    )
+    mma_block_scale_intrinsic_template = "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
+    mma_block_scale_instruction_template = "mma.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
 
     generated_items = []
 
@@ -1282,7 +1277,9 @@ def gen_mma_block_scale_tests():
         instruction_template = mma_block_scale_instruction_template
 
         generated_items.append(
-            common_mma_block_scale_test_gen(params, op, intrinsic_template, instruction_template)
+            common_mma_block_scale_test_gen(
+                params, op, intrinsic_template, instruction_template
+            )
         )
 
     return generated_items
@@ -1381,7 +1378,7 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf):
     return True
 
 
-def sp_selector_gen(op, block_scale = False):
+def sp_selector_gen(op, block_scale=False):
     if block_scale:
        # Per PTX ISA 9.0, the sparsity selector must always be 0.
         return range(1)
@@ -1517,16 +1514,13 @@ def gen_mma_sp_tests():
 
 
 def get_mma_sp_block_scale_ops():
-    return (
-        make_mma_ops(["m16n8k128"], ["e2m1"], [], ["f32"], [], True)
-        + make_mma_ops(
-            ["m16n8k64"],
-            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
-            ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
-            ["f32"],
-            [],
-            True,
-        )
+    return make_mma_ops(["m16n8k128"], ["e2m1"], [], ["f32"], [], True) + make_mma_ops(
+        ["m16n8k64"],
+        ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+        ["e4m3", "e5m2", "e3m2", "e2m3", "e2m1"],
+        ["f32"],
+        [],
+        True,
     )
 
 
@@ -1582,7 +1576,9 @@ def is_mma_sp_block_scale_variant_supported(op, kind, scale_vec_size, stype):
     return False
 
 
-def common_mma_sp_block_scale_test_gen(params, op, intrinsic_template, instruction_template):
+def common_mma_sp_block_scale_test_gen(
+    params, op, intrinsic_template, instruction_template
+):
     mma_sp_block_scale_decl_template = """
 declare ${ret_ty} @${intrinsic}(
         ${args});
@@ -1653,12 +1649,8 @@ def gen_mma_sp_block_scale_tests():
     if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
         return []
 
-    mma_sp_block_scale_intrinsic_template = (
-        "llvm.nvvm.mma.sp.ordered.metadata.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
-    )
-    mma_sp_block_scale_instruction_template = (
-        "mma.sp::ordered_metadata.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
-    )
+    mma_sp_block_scale_intrinsic_template = "llvm.nvvm.mma.sp.ordered.metadata.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
+    mma_sp_block_scale_instruction_template = "mma.sp::ordered_metadata.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
 
     generated_items = []
 
@@ -1685,7 +1677,9 @@ def gen_mma_sp_block_scale_tests():
         instruction_template = mma_sp_block_scale_instruction_template
 
         generated_items.append(
-            common_mma_sp_block_scale_test_gen(params, op, intrinsic_template, instruction_template)
+            common_mma_sp_block_scale_test_gen(
+                params, op, intrinsic_template, instruction_template
+            )
         )
 
     return generated_items

From 93994eddbcb9079a4d32a1df0cce801b0f54bb11 Mon Sep 17 00:00:00 2001
From: Kirill Vedernikov <kvedernikov at nvidia.com>
Date: Tue, 18 Nov 2025 20:04:01 +0100
Subject: [PATCH 3/3] [NVPTX] Resolved merge conflicts + updated check for PTX
 version

---
 llvm/include/llvm/IR/IntrinsicsNVVM.td       |  3 ++-
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td     | 22 ++++++++++++--------
 llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py | 12 +++++++++++
 llvm/test/CodeGen/NVPTX/wmma.py              |  4 ++--
 4 files changed, 29 insertions(+), 12 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py

diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 2a8d310b94065..ff4581b76bfcd 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -178,7 +178,8 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType, bit IsSparse = fals
   string gft = Geom#":"#Frag#":"#ptx_elt_type;
   string gf = Geom#":"#Frag;
   string ft = frag#":"#ptx_elt_type;
-  list<LLVMType> regs = !if(!eq(IsSparse, true),
+  bit isSparse = IsSparse;
+  list<LLVMType> regs = !if(!eq(isSparse, true),
     !cond(
       // mma sparse ops use other fragments for some arguments
       !eq(gft,"m16n8k16:a:bf16") : !listsplat(llvm_i32_ty, 2),
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 5ef13b3be6162..34d1e312299c8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -4619,7 +4619,8 @@ def INT_PTX_SREG_WARPSIZE :
 // the fields commonly used to implement specific PTX instruction -- register
 // types and names, constraints, parts of assembly, etc.
 class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "">
-      : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type, !eq(op, "mma.sp")> {
+      : WMMA_REGS<r.geom, r.frag, r.ptx_elt_type,
+                  !or(!eq(op, "mma.sp"), !eq(op, "mma.sp.block_scale"))> {
   // NVPTX register types used to carry fragment data.
   NVPTXRegClass regclass = !cond(
     !eq(ptx_elt_type, "e4m3") : B32,
@@ -4659,6 +4660,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
   // longer the case, we can concat all per-fragment predicates to enforce that
   // all fragments of the instruction are viable.
   list<Predicate> Predicates = !cond(
+    !or(!eq(op, "mma.block_scale"),
+        !eq(op, "mma.sp.block_scale")) : [hasSM120a, hasPTX<88>],
+
     !or(!eq(ptx_elt_type, "e3m2"),
         !eq(ptx_elt_type, "e2m3"),
         !eq(ptx_elt_type, "e2m1"),
@@ -4671,9 +4675,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op, string metadata = "", string kind = "
     !or(!eq(ptx_elt_type, "e4m3"),
         !eq(ptx_elt_type, "e5m2")) : [hasSM<89>, hasPTX<84>],
 
-    !and(!eq(op, "mma.sp"),
+    !and(isSparse,
          !ne(metadata, "sp")) : [hasSM<80>, hasPTX<85>],
-    !eq(op, "mma.sp") : [hasSM<80>, hasPTX<71>],
+    isSparse : [hasSM<80>, hasPTX<71>],
 
     // fp16 -> fp16/fp32 @ m16n16k16
     !and(!eq(geom, "m16n16k16"),
@@ -5027,7 +5031,7 @@ class MMA_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                       WMMA_REGINFO FragC, WMMA_REGINFO FragD,
                       string Kind, string SType, string ScaleVecSize>
   : WMMA_INSTR<MMA_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
-                                    FragA, FragB, FragC, FragD>.record,
+                                    FragA, FragB, FragC, FragD>.record_name,
                                     [FragA.Ins, FragB.Ins, FragC.Ins,
                                      (ins B32:$scale_a, B16:$byte_id_a,
                                           B16:$thread_id_a, B32:$scale_b,
@@ -5144,7 +5148,7 @@ class MMA_SP_BLOCK_SCALE<WMMA_REGINFO FragA, WMMA_REGINFO FragB,
                          WMMA_REGINFO FragC, WMMA_REGINFO FragD,
                          string Kind, string SType, string ScaleVecSize>
   : WMMA_INSTR<MMA_SP_BLOCK_SCALE_NAME<Kind, SType, ScaleVecSize,
-                                       FragA, FragB, FragC, FragD>.record,
+                                       FragA, FragB, FragC, FragD>.record_name,
                [FragA.Ins, FragB.Ins, FragC.Ins,
                 (ins B32:$metadata, i32imm:$selector,
                      B32:$scale_a, B16:$byte_id_a, B16:$thread_id_a,
@@ -5192,10 +5196,10 @@ defset list<WMMA_INSTR> MMA_SP_BLOCK_SCALEs = {
       foreach stype = ["ue8m0", "ue4m3"] in {
       foreach op = NVVM_MMA_OPS.all_mma_sp_block_scale_ops in {
         if NVVM_MMA_SP_BLOCK_SCALE_SUPPORTED<op, kind, stype, scale_vec_size>.ret then {
-          def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp", "sp::ordered_metadata", kind>,
-                       WMMA_REGINFO<op[1], "mma.sp", "sp::ordered_metadata", kind>,
-                       WMMA_REGINFO<op[2], "mma.sp", "sp::ordered_metadata", kind>,
-                       WMMA_REGINFO<op[3], "mma.sp", "sp::ordered_metadata", kind>,
+          def : MMA_SP_BLOCK_SCALE<WMMA_REGINFO<op[0], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+                       WMMA_REGINFO<op[1], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+                       WMMA_REGINFO<op[2], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
+                       WMMA_REGINFO<op[3], "mma.sp.block_scale", "sp::ordered_metadata", kind>,
                        kind, stype, scale_vec_size>;
         }
       } // op
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py
new file mode 100644
index 0000000000000..f1666dbc5f30f
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx88-sm120a.py
@@ -0,0 +1,12 @@
+# Check all variants of instructions supported by PTX88 on SM120a
+# RUN: %python %s --ptx=88 --gpu-arch=120 --aa > %t-ptx88-sm_120a.ll
+# RUN: llc < %t-ptx88-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx88 \
+# RUN:           | FileCheck %t-ptx88-sm_120a.ll
+# RUN: %if ptxas-sm_120a && ptxas-isa-8.8 %{                                  \
+# RUN: llc < %t-ptx88-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx88 \
+# RUN:           | %ptxas-verify -arch=sm_120a                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index c2e69dd48da99..817665a68d7a5 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -1246,7 +1246,7 @@ def common_mma_block_scale_test_gen(
 
 
 def gen_mma_block_scale_tests():
-    if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
+    if not (ptx_version >= 88 and gpu_arch >= 120 and aa):
         return []
 
     mma_block_scale_intrinsic_template = "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
@@ -1646,7 +1646,7 @@ def common_mma_sp_block_scale_test_gen(
 
 
 def gen_mma_sp_block_scale_tests():
-    if not (ptx_version >= 87 and gpu_arch >= 120 and aa):
+    if not (ptx_version >= 88 and gpu_arch >= 120 and aa):
         return []
 
     mma_sp_block_scale_intrinsic_template = "llvm.nvvm.mma.sp.ordered.metadata.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"


