[Mlir-commits] [mlir] [MLIR][NVVM] unify usage of tcgen05_mma_kind attr for tcgen05.mma Ops (PR #184433)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 3 12:12:44 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Kirill Vedernikov (kvederni)

<details>
<summary>Changes</summary>

This change unifies the use of the `tcgen05_mma_kind` attribute for tcgen05.mma Ops in MLIR.

Before this change, two different kind attributes were used for tcgen05.mma Ops. One was `MMABlockScaleKindAttr`, with the values `mxf8f6f4`, `mxf4` and `mxf4nvf4`, used for `tcgen05.mma.block_scale` and `tcgen05.mma.sp.block_scale`. The other was `Tcgen05MMAKindAttr`, with the values `f16`, `tf32`, `f8f6f4` and `i8`, used for `tcgen05.mma`, `tcgen05.mma.sp`, `tcgen05.mma.ws` and `tcgen05.mma.ws.sp`.

`Tcgen05MMAKindAttr` has been extended with the values from `MMABlockScaleKindAttr`. Now a single `tcgen05_mma_kind` attribute is used for all `tcgen05.mma` Ops in MLIR.

Backward compatibility is not supported. Existing tests and scripts should be updated to use the `tcgen05_mma_kind` attribute instead of `block_scale_kind` for all tcgen05.mma MLIR Ops.

---

Patch is 182.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/184433.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+41-18) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+9-10) 
- (modified) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir (+49-48) 
- (modified) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir (+49-48) 
- (modified) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir (+64-4) 
- (modified) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir (+49-48) 
- (modified) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir (+49-48) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 40f631fa0bb2c..251eb51d1eeb1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -5746,23 +5746,28 @@ def NVVM_ClusterLaunchControlQueryCancelOp
 // NVVM tcgen05.mma Ops
 //===----------------------------------------------------------------------===//
 
-def Tcgen05MMAKindF16          : I32EnumAttrCase<"F16",    0, "f16">;
-def Tcgen05MMAKindTF32         : I32EnumAttrCase<"TF32",   1, "tf32">;
-def Tcgen05MMAKindF8F6F4       : I32EnumAttrCase<"F8F6F4", 2, "f8f6f4">;
-def Tcgen05MMAKindINT8         : I32EnumAttrCase<"I8",     3, "i8">;
+def Tcgen05MMAKindF16          : I32EnumAttrCase<"F16",      0, "f16">;
+def Tcgen05MMAKindTF32         : I32EnumAttrCase<"TF32",     1, "tf32">;
+def Tcgen05MMAKindF8F6F4       : I32EnumAttrCase<"F8F6F4",   2, "f8f6f4">;
+def Tcgen05MMAKindI8           : I32EnumAttrCase<"I8",       3, "i8">;
+def Tcgen05MMAKindMXF8F6F4     : I32EnumAttrCase<"MXF8F6F4", 4, "mxf8f6f4">;
+def Tcgen05MMAKindMXF4         : I32EnumAttrCase<"MXF4",     5, "mxf4">;
+def Tcgen05MMAKindMXF4NVF4     : I32EnumAttrCase<"MXF4NVF4", 6, "mxf4nvf4">;
 
 def Tcgen05MMAKind : I32EnumAttr<
   "Tcgen05MMAKind",
   "tcgen05 MMA Supported Types",
-  [Tcgen05MMAKindF8F6F4, Tcgen05MMAKindINT8, Tcgen05MMAKindF16,
-   Tcgen05MMAKindTF32]> {
+  [Tcgen05MMAKindF16, Tcgen05MMAKindTF32, Tcgen05MMAKindF8F6F4,
+   Tcgen05MMAKindI8, Tcgen05MMAKindMXF8F6F4, Tcgen05MMAKindMXF4,
+   Tcgen05MMAKindMXF4NVF4]> {
     let cppNamespace = "::mlir::NVVM";
     let genSpecializedAttr = 0;
 }
 
 def Tcgen05MMAKindAttr : EnumAttr<NVVM_Dialect, Tcgen05MMAKind, "tcgen05_mma_kind"> {
   let description = [{
-    The Tcgen05MMAKind attribute describes the allowed set of types for matrix A and B in the tcgen05.mma.{sp} Op. The following are supported types for each kind:
+    The Tcgen05MMAKind attribute describes the allowed set of types for matrix A and B in the tcgen05.mma Ops.
+    The following are supported types for each kind:
 
     ```
     +-------------+--------------------------------------------+
@@ -5772,6 +5777,9 @@ def Tcgen05MMAKindAttr : EnumAttr<NVVM_Dialect, Tcgen05MMAKind, "tcgen05_mma_kin
     | tf32        | tf32                                       |
     | f8f6f4      | e4m3, e5m2, e2m3, e3m2, e2m1               |
     | i8          | unsigned 8b, signed 8b                     |
+    | mxf8f6f4    | e4m3, e5m2, e2m3, e3m2, e2m1               |
+    | mxf4        | e2m1                                       |
+    | mxf4nvf4    | e2m1                                       |
     +-------------+--------------------------------------------+
     ```
   }];
@@ -5806,6 +5814,21 @@ def Tcgen05MMACollectorOpAttr : EnumAttr<NVVM_Dialect, Tcgen05MMACollectorOp, "t
   let assemblyFormat = "`<` $value `>`";
 }
 
+defvar Tcgen05MMANonBlockScaleKindCases =
+  [Tcgen05MMAKindF16, Tcgen05MMAKindTF32,
+   Tcgen05MMAKindF8F6F4, Tcgen05MMAKindI8];
+
+defvar Tcgen05MMABlockScaleKindCases =
+  [Tcgen05MMAKindMXF8F6F4, Tcgen05MMAKindMXF4, Tcgen05MMAKindMXF4NVF4];
+
+defvar Tcgen05MMANonBlockScaleKindAttr =
+  ConfinedAttr<Tcgen05MMAKindAttr,
+    [EnumAttrIsOneOf<Tcgen05MMAKindAttr, Tcgen05MMANonBlockScaleKindCases>]>;
+
+defvar Tcgen05MMABlockScaleKindAttr =
+  ConfinedAttr<Tcgen05MMAKindAttr,
+    [EnumAttrIsOneOf<Tcgen05MMAKindAttr, Tcgen05MMABlockScaleKindCases>]>;
+
 def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma",
                           [AttrSizedOperandSegments,
                            NVVMRequiresSMa<[100, 110]>]> {
@@ -5856,7 +5879,7 @@ def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma",
   }];
 
   let arguments = (ins
-      Tcgen05MMAKindAttr:$kind,
+      Tcgen05MMANonBlockScaleKindAttr:$kind,
       CTAGroupKindAttr:$ctaGroup,
       DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
                         "Tcgen05MMACollectorOp::DISCARD">:$collectorOp,
@@ -5920,7 +5943,7 @@ def NVVM_Tcgen05MMASparseOp : NVVM_Op<"tcgen05.mma.sp",
   }];
 
   let arguments = (ins
-    Tcgen05MMAKindAttr:$kind,
+    Tcgen05MMANonBlockScaleKindAttr:$kind,
     CTAGroupKindAttr:$ctaGroup,
     DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
                       "Tcgen05MMACollectorOp::DISCARD">:$collectorOp,
@@ -5998,7 +6021,7 @@ def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
     - `idesc` is a 32 bit value representing the [Instruction Descriptor](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor)
 
     Required Attributes:
-    - `kind` is a MMABlockScaleKind attribute
+    - `kind` is a Tcgen05MMAKind attribute restricted to mxf8f6f4, mxf4, or mxf4nvf4
 
     - `ctaGroup` specifies CTA group configuration
       * cta_1: MMA will be performed on the current thread's CTA
@@ -6011,7 +6034,7 @@ def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
   }];
 
   let arguments = (ins
-      MMABlockScaleKindAttr:$kind,
+      Tcgen05MMABlockScaleKindAttr:$kind,
       CTAGroupKindAttr:$ctaGroup,
       DefaultValuedAttr<Tcgen05MMABlockScaleAttr,
                       "Tcgen05MMABlockScale::DEFAULT">:$blockScale,
@@ -6026,8 +6049,8 @@ def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
     );
 
   let assemblyFormat = [{
-    $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD `,` $scaleA `,` $scaleB
-    attr-dict `:` `(` type(operands) `)`
+    $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD `,`
+    $scaleA `,` $scaleB attr-dict `:` `(` type(operands) `)`
   }];
 
   let hasVerifier = true;
@@ -6073,7 +6096,7 @@ def NVVM_Tcgen05MMASparseBlockScaleOp : NVVM_Op<"tcgen05.mma.sp.block_scale",
   }];
 
   let arguments = (ins
-    MMABlockScaleKindAttr:$kind,
+    Tcgen05MMABlockScaleKindAttr:$kind,
     CTAGroupKindAttr:$ctaGroup,
     DefaultValuedAttr<Tcgen05MMABlockScaleAttr,
                       "Tcgen05MMABlockScale::DEFAULT">:$blockScale,
@@ -6090,8 +6113,8 @@ def NVVM_Tcgen05MMASparseBlockScaleOp : NVVM_Op<"tcgen05.mma.sp.block_scale",
   );
 
   let assemblyFormat = [{
-    $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD `,` $sparseMetadata `,`  $scaleA `,`  $scaleB
-    attr-dict `:` `(` type(operands) `)`
+    $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD `,`
+    $sparseMetadata `,` $scaleA `,` $scaleB attr-dict `:` `(` type(operands) `)`
   }];
 
   let hasVerifier = true;
@@ -6166,7 +6189,7 @@ def NVVM_Tcgen05MMAWsOp : NVVM_Op<"tcgen05.mma.ws",
   }];
 
   let arguments = (ins
-    Tcgen05MMAKindAttr:$kind,
+    Tcgen05MMANonBlockScaleKindAttr:$kind,
     DefaultValuedAttr<Tcgen05MMACollectorBBufferAttr,
                       "Tcgen05MMACollectorBBuffer::B0">:$collectorBBuffer,
     DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
@@ -6226,7 +6249,7 @@ def NVVM_Tcgen05MMAWsSparseOp : NVVM_Op<"tcgen05.mma.ws.sp",
   }];
 
   let arguments = (ins
-    Tcgen05MMAKindAttr:$kind,
+    Tcgen05MMANonBlockScaleKindAttr:$kind,
     DefaultValuedAttr<Tcgen05MMACollectorBBufferAttr,
                       "Tcgen05MMACollectorBBuffer::B0">:$collectorBBuffer,
     DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 10e23ef21b219..f0d22d896d88a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -5240,7 +5240,7 @@ mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
   auto kind = thisOp.getKind();
   auto blockScale = thisOp.getBlockScale();
   llvm::Intrinsic::ID ID = [&]() {
-    if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
+    if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
       if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
         return isATensor ? llvm::Intrinsic::
                                nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
@@ -5253,7 +5253,7 @@ mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
                    : llvm::Intrinsic::
                          nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
       }
-    } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
+    } else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
       if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
         return isATensor
                    ? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
@@ -5264,7 +5264,7 @@ mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
                          : llvm::Intrinsic::
                                nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
       }
-    } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
+    } else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
       if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
         return isATensor
                    ? llvm::Intrinsic::
@@ -5287,15 +5287,14 @@ mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
 }
 
 static LogicalResult verifyTcgen05MMABlockScaleOp(
-    NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::MMABlockScaleKind kind,
+    NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind,
     NVVM::Tcgen05MMABlockScale blockScale, Location loc) {
-
   if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
-      kind == MMABlockScaleKind::MXF4NVF4)
+      kind == NVVM::Tcgen05MMAKind::MXF4NVF4)
     return emitError(loc, "mxf4nvf4 requires block scale attribute");
 
   if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
-      kind != MMABlockScaleKind::MXF4NVF4)
+      kind != NVVM::Tcgen05MMAKind::MXF4NVF4)
     return emitError(loc,
                      llvm::formatv("{} kind does not support block16 attribute",
                                    stringifyEnum(kind)));
@@ -5338,7 +5337,7 @@ mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
   auto kind = thisOp.getKind();
   auto blockScale = thisOp.getBlockScale();
   llvm::Intrinsic::ID ID = [&]() {
-    if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
+    if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
       if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
         return isATensor ? llvm::Intrinsic::
                                nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
@@ -5351,7 +5350,7 @@ mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
                    : llvm::Intrinsic::
                          nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
       }
-    } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
+    } else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
       if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
         return isATensor ? llvm::Intrinsic::
                                nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
@@ -5364,7 +5363,7 @@ mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
                    : llvm::Intrinsic::
                          nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
       }
-    } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
+    } else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
       if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
         return isATensor
                    ? llvm::Intrinsic::
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir
index 9f7dd3ed4b6b4..99e3fe45a51b7 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir
@@ -1,39 +1,40 @@
 // RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
 
+
 // CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1
 llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
 
   // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
   nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
-  {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+  {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
 
   // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
   nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
-  {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+  {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
 
   // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
   nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
-  {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+  {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
 
   // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
   nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
-  {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+  {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
 
   // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
   nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
-  {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+  {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
 
   // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
   nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
-  {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+  {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
 
   // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
   nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
-  {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+  {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
 
   // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
   nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
-  {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+  {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
 
   llvm.return
 }
@@ -43,35 +44,35 @@ llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %
 
   // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
   nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
-  {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+  {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
 
   // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
   nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
-  {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+  {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
 
   // CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
   nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
-  {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+  {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
 
   // CHECK: call void...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/184433


More information about the Mlir-commits mailing list