[Mlir-commits] [mlir] [MLIR][NVVM] unify usage of tcgen05_mma_kind attr for tcgen05.mma Ops (PR #184433)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 3 12:12:44 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Kirill Vedernikov (kvederni)
<details>
<summary>Changes</summary>
This change unifies the use of the `tcgen05_mma_kind` attribute across all tcgen05.mma Ops in MLIR.
Before this change, two separate kind attributes were used for tcgen05.mma Ops. One was `MMABlockScaleKindAttr`, with the values `mxf8f6f4`, `mxf4` and `mxf4nvf4`, used for `tcgen05.mma.block_scale` and `tcgen05.mma.sp.block_scale`. The other was `Tcgen05MMAKindAttr`, with the values `f16`, `tf32`, `f8f6f4` and `i8`, used for `tcgen05.mma`, `tcgen05.mma.sp`, `tcgen05.mma.ws` and `tcgen05.mma.ws.sp`.
`Tcgen05MMAKindAttr` has been extended with the values from `MMABlockScaleKindAttr`. There is now a single `tcgen05_mma_kind` attribute for all `tcgen05.mma` Ops in MLIR.
Backward compatibility is not preserved. Existing tests and scripts must be updated to use the `tcgen05_mma_kind` attribute instead of `block_scale_kind` for all tcgen05.mma MLIR Ops.
---
Patch is 182.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/184433.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+41-18)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+9-10)
- (modified) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir (+49-48)
- (modified) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-tensor.mlir (+49-48)
- (modified) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-invalid.mlir (+64-4)
- (modified) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-shared.mlir (+49-48)
- (modified) mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-sp-block-scale-tensor.mlir (+49-48)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 40f631fa0bb2c..251eb51d1eeb1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -5746,23 +5746,28 @@ def NVVM_ClusterLaunchControlQueryCancelOp
// NVVM tcgen05.mma Ops
//===----------------------------------------------------------------------===//
-def Tcgen05MMAKindF16 : I32EnumAttrCase<"F16", 0, "f16">;
-def Tcgen05MMAKindTF32 : I32EnumAttrCase<"TF32", 1, "tf32">;
-def Tcgen05MMAKindF8F6F4 : I32EnumAttrCase<"F8F6F4", 2, "f8f6f4">;
-def Tcgen05MMAKindINT8 : I32EnumAttrCase<"I8", 3, "i8">;
+def Tcgen05MMAKindF16 : I32EnumAttrCase<"F16", 0, "f16">;
+def Tcgen05MMAKindTF32 : I32EnumAttrCase<"TF32", 1, "tf32">;
+def Tcgen05MMAKindF8F6F4 : I32EnumAttrCase<"F8F6F4", 2, "f8f6f4">;
+def Tcgen05MMAKindI8 : I32EnumAttrCase<"I8", 3, "i8">;
+def Tcgen05MMAKindMXF8F6F4 : I32EnumAttrCase<"MXF8F6F4", 4, "mxf8f6f4">;
+def Tcgen05MMAKindMXF4 : I32EnumAttrCase<"MXF4", 5, "mxf4">;
+def Tcgen05MMAKindMXF4NVF4 : I32EnumAttrCase<"MXF4NVF4", 6, "mxf4nvf4">;
def Tcgen05MMAKind : I32EnumAttr<
"Tcgen05MMAKind",
"tcgen05 MMA Supported Types",
- [Tcgen05MMAKindF8F6F4, Tcgen05MMAKindINT8, Tcgen05MMAKindF16,
- Tcgen05MMAKindTF32]> {
+ [Tcgen05MMAKindF16, Tcgen05MMAKindTF32, Tcgen05MMAKindF8F6F4,
+ Tcgen05MMAKindI8, Tcgen05MMAKindMXF8F6F4, Tcgen05MMAKindMXF4,
+ Tcgen05MMAKindMXF4NVF4]> {
let cppNamespace = "::mlir::NVVM";
let genSpecializedAttr = 0;
}
def Tcgen05MMAKindAttr : EnumAttr<NVVM_Dialect, Tcgen05MMAKind, "tcgen05_mma_kind"> {
let description = [{
- The Tcgen05MMAKind attribute describes the allowed set of types for matrix A and B in the tcgen05.mma.{sp} Op. The following are supported types for each kind:
+ The Tcgen05MMAKind attribute describes the allowed set of types for matrix A and B in the tcgen05.mma Ops.
+ The following are supported types for each kind:
```
+-------------+--------------------------------------------+
@@ -5772,6 +5777,9 @@ def Tcgen05MMAKindAttr : EnumAttr<NVVM_Dialect, Tcgen05MMAKind, "tcgen05_mma_kin
| tf32 | tf32 |
| f8f6f4 | e4m3, e5m2, e2m3, e3m2, e2m1 |
| i8 | unsigned 8b, signed 8b |
+ | mxf8f6f4 | e4m3, e5m2, e2m3, e3m2, e2m1 |
+ | mxf4 | e2m1 |
+ | mxf4nvf4 | e2m1 |
+-------------+--------------------------------------------+
```
}];
@@ -5806,6 +5814,21 @@ def Tcgen05MMACollectorOpAttr : EnumAttr<NVVM_Dialect, Tcgen05MMACollectorOp, "t
let assemblyFormat = "`<` $value `>`";
}
+defvar Tcgen05MMANonBlockScaleKindCases =
+ [Tcgen05MMAKindF16, Tcgen05MMAKindTF32,
+ Tcgen05MMAKindF8F6F4, Tcgen05MMAKindI8];
+
+defvar Tcgen05MMABlockScaleKindCases =
+ [Tcgen05MMAKindMXF8F6F4, Tcgen05MMAKindMXF4, Tcgen05MMAKindMXF4NVF4];
+
+defvar Tcgen05MMANonBlockScaleKindAttr =
+ ConfinedAttr<Tcgen05MMAKindAttr,
+ [EnumAttrIsOneOf<Tcgen05MMAKindAttr, Tcgen05MMANonBlockScaleKindCases>]>;
+
+defvar Tcgen05MMABlockScaleKindAttr =
+ ConfinedAttr<Tcgen05MMAKindAttr,
+ [EnumAttrIsOneOf<Tcgen05MMAKindAttr, Tcgen05MMABlockScaleKindCases>]>;
+
def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma",
[AttrSizedOperandSegments,
NVVMRequiresSMa<[100, 110]>]> {
@@ -5856,7 +5879,7 @@ def NVVM_Tcgen05MMAOp : NVVM_Op<"tcgen05.mma",
}];
let arguments = (ins
- Tcgen05MMAKindAttr:$kind,
+ Tcgen05MMANonBlockScaleKindAttr:$kind,
CTAGroupKindAttr:$ctaGroup,
DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
"Tcgen05MMACollectorOp::DISCARD">:$collectorOp,
@@ -5920,7 +5943,7 @@ def NVVM_Tcgen05MMASparseOp : NVVM_Op<"tcgen05.mma.sp",
}];
let arguments = (ins
- Tcgen05MMAKindAttr:$kind,
+ Tcgen05MMANonBlockScaleKindAttr:$kind,
CTAGroupKindAttr:$ctaGroup,
DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
"Tcgen05MMACollectorOp::DISCARD">:$collectorOp,
@@ -5998,7 +6021,7 @@ def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
- `idesc` is a 32 bit value representing the [Instruction Descriptor](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor)
Required Attributes:
- - `kind` is a MMABlockScaleKind attribute
+ - `kind` is a Tcgen05MMAKind attribute restricted to mxf8f6f4, mxf4, or mxf4nvf4
- `ctaGroup` specifies CTA group configuration
* cta_1: MMA will be performed on the current thread's CTA
@@ -6011,7 +6034,7 @@ def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
}];
let arguments = (ins
- MMABlockScaleKindAttr:$kind,
+ Tcgen05MMABlockScaleKindAttr:$kind,
CTAGroupKindAttr:$ctaGroup,
DefaultValuedAttr<Tcgen05MMABlockScaleAttr,
"Tcgen05MMABlockScale::DEFAULT">:$blockScale,
@@ -6026,8 +6049,8 @@ def NVVM_Tcgen05MMABlockScaleOp : NVVM_Op<"tcgen05.mma.block_scale",
);
let assemblyFormat = [{
- $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD `,` $scaleA `,` $scaleB
- attr-dict `:` `(` type(operands) `)`
+ $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD `,`
+ $scaleA `,` $scaleB attr-dict `:` `(` type(operands) `)`
}];
let hasVerifier = true;
@@ -6073,7 +6096,7 @@ def NVVM_Tcgen05MMASparseBlockScaleOp : NVVM_Op<"tcgen05.mma.sp.block_scale",
}];
let arguments = (ins
- MMABlockScaleKindAttr:$kind,
+ Tcgen05MMABlockScaleKindAttr:$kind,
CTAGroupKindAttr:$ctaGroup,
DefaultValuedAttr<Tcgen05MMABlockScaleAttr,
"Tcgen05MMABlockScale::DEFAULT">:$blockScale,
@@ -6090,8 +6113,8 @@ def NVVM_Tcgen05MMASparseBlockScaleOp : NVVM_Op<"tcgen05.mma.sp.block_scale",
);
let assemblyFormat = [{
- $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD `,` $sparseMetadata `,` $scaleA `,` $scaleB
- attr-dict `:` `(` type(operands) `)`
+ $matrixD `,` $matrixA `,` $matrixB `,` $idesc `,` $enableInputD `,`
+ $sparseMetadata `,` $scaleA `,` $scaleB attr-dict `:` `(` type(operands) `)`
}];
let hasVerifier = true;
@@ -6166,7 +6189,7 @@ def NVVM_Tcgen05MMAWsOp : NVVM_Op<"tcgen05.mma.ws",
}];
let arguments = (ins
- Tcgen05MMAKindAttr:$kind,
+ Tcgen05MMANonBlockScaleKindAttr:$kind,
DefaultValuedAttr<Tcgen05MMACollectorBBufferAttr,
"Tcgen05MMACollectorBBuffer::B0">:$collectorBBuffer,
DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
@@ -6226,7 +6249,7 @@ def NVVM_Tcgen05MMAWsSparseOp : NVVM_Op<"tcgen05.mma.ws.sp",
}];
let arguments = (ins
- Tcgen05MMAKindAttr:$kind,
+ Tcgen05MMANonBlockScaleKindAttr:$kind,
DefaultValuedAttr<Tcgen05MMACollectorBBufferAttr,
"Tcgen05MMACollectorBBuffer::B0">:$collectorBBuffer,
DefaultValuedAttr<Tcgen05MMACollectorOpAttr,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 10e23ef21b219..f0d22d896d88a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -5240,7 +5240,7 @@ mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
auto kind = thisOp.getKind();
auto blockScale = thisOp.getBlockScale();
llvm::Intrinsic::ID ID = [&]() {
- if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
+ if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
return isATensor ? llvm::Intrinsic::
nvvm_tcgen05_mma_tensor_mxf8f6f4_block_scale
@@ -5253,7 +5253,7 @@ mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
: llvm::Intrinsic::
nvvm_tcgen05_mma_shared_mxf8f6f4_block_scale_block32;
}
- } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
+ } else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
return isATensor
? llvm::Intrinsic::nvvm_tcgen05_mma_tensor_mxf4_block_scale
@@ -5264,7 +5264,7 @@ mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
: llvm::Intrinsic::
nvvm_tcgen05_mma_shared_mxf4_block_scale_block32;
}
- } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
+ } else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
return isATensor
? llvm::Intrinsic::
@@ -5287,15 +5287,14 @@ mlir::NVVM::IDArgPair Tcgen05MMABlockScaleOp::getIntrinsicIDAndArgs(
}
static LogicalResult verifyTcgen05MMABlockScaleOp(
- NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::MMABlockScaleKind kind,
+ NVVM::Tcgen05MMACollectorOp collectorOp, NVVM::Tcgen05MMAKind kind,
NVVM::Tcgen05MMABlockScale blockScale, Location loc) {
-
if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT &&
- kind == MMABlockScaleKind::MXF4NVF4)
+ kind == NVVM::Tcgen05MMAKind::MXF4NVF4)
return emitError(loc, "mxf4nvf4 requires block scale attribute");
if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK16 &&
- kind != MMABlockScaleKind::MXF4NVF4)
+ kind != NVVM::Tcgen05MMAKind::MXF4NVF4)
return emitError(loc,
llvm::formatv("{} kind does not support block16 attribute",
stringifyEnum(kind)));
@@ -5338,7 +5337,7 @@ mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
auto kind = thisOp.getKind();
auto blockScale = thisOp.getBlockScale();
llvm::Intrinsic::ID ID = [&]() {
- if (kind == NVVM::MMABlockScaleKind::MXF8F6F4) {
+ if (kind == NVVM::Tcgen05MMAKind::MXF8F6F4) {
if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
return isATensor ? llvm::Intrinsic::
nvvm_tcgen05_mma_sp_tensor_mxf8f6f4_block_scale
@@ -5351,7 +5350,7 @@ mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
: llvm::Intrinsic::
nvvm_tcgen05_mma_sp_shared_mxf8f6f4_block_scale_block32;
}
- } else if (kind == NVVM::MMABlockScaleKind::MXF4) {
+ } else if (kind == NVVM::Tcgen05MMAKind::MXF4) {
if (blockScale == NVVM::Tcgen05MMABlockScale::DEFAULT) {
return isATensor ? llvm::Intrinsic::
nvvm_tcgen05_mma_sp_tensor_mxf4_block_scale
@@ -5364,7 +5363,7 @@ mlir::NVVM::IDArgPair Tcgen05MMASparseBlockScaleOp::getIntrinsicIDAndArgs(
: llvm::Intrinsic::
nvvm_tcgen05_mma_sp_shared_mxf4_block_scale_block32;
}
- } else if (kind == NVVM::MMABlockScaleKind::MXF4NVF4) {
+ } else if (kind == NVVM::Tcgen05MMAKind::MXF4NVF4) {
if (blockScale == NVVM::Tcgen05MMABlockScale::BLOCK32) {
return isATensor
? llvm::Intrinsic::
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir
index 9f7dd3ed4b6b4..99e3fe45a51b7 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-mma-block-scale-shared.mlir
@@ -1,39 +1,40 @@
// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s
+
// CHECK-LABEL: @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1
llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_1(%d_tmem : !llvm.ptr<6>, %a_desc: i64, %adesc: i64, %b_desc: i64, %idesc: i32, %enable_input_d: i1, %scale_a: !llvm.ptr<6>, %scale_b : !llvm.ptr<6>) {
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 2)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<fill>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 1, i32 3)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_1>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>, collectorOp = #nvvm.tcgen05_mma_collectorop<use>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
llvm.return
}
@@ -43,35 +44,35 @@ llvm.func @nvvm_tcgen05_mma_mxf8f6f4_block_scale_cta_2(%d_tmem : !llvm.ptr<6>, %
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale.block32(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 0)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, blockScale = #nvvm.tcgen05_mma_block_scale<block32>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) {{%[0-9]+}}, i64 {{%[0-9]+}}, i64 {{%[0-9]+}}, i32 {{%[0-9]+}}, i1 {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, ptr addrspace(6) {{%[0-9]+}}, i32 2, i32 1)
nvvm.tcgen05.mma.block_scale %d_tmem, %a_desc, %b_desc, %idesc, %enable_input_d, %scale_a, %scale_b
- {kind = #nvvm.block_scale_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
+ {kind = #nvvm.tcgen05_mma_kind<mxf8f6f4>, ctaGroup = #nvvm.cta_group<cta_2>, collectorOp = #nvvm.tcgen05_mma_collectorop<lastuse>} : (!llvm.ptr<6>, i64, i64, i32, i1, !llvm.ptr<6>, !llvm.ptr<6>)
// CHECK: call void...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/184433
More information about the Mlir-commits
mailing list