[llvm] [NVVM][NVPTX] Add support for tcgen05.mma (PR #151949)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 4 05:16:33 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-nvptx
Author: Pradeep Kumar (schwarzschild-radius)
<details>
<summary>Changes</summary>
This commit adds support for tcgen05.mma instructions in NVPTX, with tests under CodeGen/NVPTX/tcgen05-mma*. The tcgen05.mma instructions are modeled as intrinsics with multiple flag arguments that encode cta_group, mma kind, collector usage, etc. The rationale for the design is documented in the NVPTXUsage.rst file. For more details, please refer to the [PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/?a#tcgen05-mma-instructions).
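
As a rough illustration of the flag-operand modeling, the sketch below mirrors the flag values listed in the NVPTXUsage.rst hunk of this patch (`%kind_flag`, `%cta_group`, `%collector_usage_a_op_flag`) and the documented rule that `.ashift` cannot be combined with `USE` or `FILL`. The class names and the `check_mma_flags` helper are hypothetical and exist only for this example; they are not part of the patch or of any LLVM API, and the real range checks live in the backend.

```python
from enum import IntEnum


class Kind(IntEnum):
    """%kind_flag values for the non-block-scale intrinsics (from the doc table)."""
    F16 = 0
    TF32 = 1
    F8F6F4 = 2
    I8 = 3


class CtaGroup(IntEnum):
    """%cta_group values (from the doc table)."""
    CG1 = 1
    CG2 = 2


class CollectorUsage(IntEnum):
    """%collector_usage_a_op_flag values (from the doc table)."""
    DISCARD = 0
    LASTUSE = 1
    USE = 2
    FILL = 3


def check_mma_flags(kind, cta_group, collector_usage, ashift=False):
    """Hypothetical validator: return the three flags as ints or raise ValueError.

    Constructing an IntEnum from an out-of-range value raises ValueError,
    which stands in for the range checks mentioned in the documentation.
    """
    k = Kind(kind)
    cg = CtaGroup(cta_group)
    cu = CollectorUsage(collector_usage)
    # Documented restriction: USE and FILL are illegal together with .ashift.
    if ashift and cu in (CollectorUsage.USE, CollectorUsage.FILL):
        raise ValueError(".ashift cannot be combined with collector usage USE or FILL")
    return int(k), int(cg), int(cu)
```

For instance, `check_mma_flags(Kind.F16, CtaGroup.CG1, CollectorUsage.DISCARD)` yields the trailing operands `(0, 1, 0)`, while an out-of-range kind such as `4` is rejected.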
---
Patch is 386.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/151949.diff
13 Files Affected:
- (modified) llvm/docs/NVPTXUsage.rst (+385-3)
- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+429-1)
- (modified) llvm/include/llvm/IR/NVVMIntrinsicUtils.h (+14)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+291)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+38-1)
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+478-4)
- (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+1-1)
- (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-block-scale-ptx88.ll (+526)
- (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-block-scale.ll (+291)
- (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-disable-output-lane.ll (+677)
- (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-scale-d.ll (+412)
- (added) llvm/test/CodeGen/NVPTX/tcgen05-mma-ws.ll (+569)
- (added) llvm/test/CodeGen/NVPTX/tcgen05-mma.ll (+601)
``````````diff
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index d28eb6860c33a..1b61df2cf5254 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -1945,6 +1945,388 @@ The last argument `i1 %unpack` is a compile-time constant which when set, indica
For more information, refer to the
`PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
+tcgen05.mma Intrinsics
+----------------------
+
+One of the key instructions introduced in the Blackwell architecture is the tcgen05.mma family, which carries out matrix multiply-accumulate operations using the 5th generation Tensor Core unit. The `tcgen05.mma` instruction supports a broad range of capabilities, including sparsity, block scaling, and weight-stationary convolutions. Accurately modeling these through intrinsics is highly complex, and the following table outlines the large number of intrinsics required to fully support the tcgen05.mma instruction set.
+
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| variant | Configuration | Total Variants |
++====================================+===================================================================================================+================+
+| tcgen05.mma.shared | 2 (space) x 2 (sp) x 4 (kind) x 2 (cta_group) x 4 (collector_usage) | 128 |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.tensor.ashift | 2 (sp) x 4 (kind) x 2 (cta_group) x 2 (collector_usage) | 32 |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.scale_d | 2 (space) x 2 (sp) x 2 (kind) x 2 (cta_group) x 4 (collector_usage) | 128 |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.scale_d.tensor.ashift | 2 (sp) x 2 (kind) x 2 (cta_group) x 2 (collector_usage) | 16 |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.disable_output_lane | 2 (space) x 2 (sp) x 4 (kind) x 2 (cta_group) x 4 (collector_usage) | 128 |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.disable_output_lane... | 2 (sp) x 4 (kind) x 2 (cta_group) x 2 (collector_usage) | 32 |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.block_scale | 2 (space) x 1 (mxf4nvf4) x 2 (cta_group) x 2 (scale_vec_size) x 4 (collector_usage) | 32 |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.block_scale | 2 (space) x 1 (mxf4) x 2 (cta_group) x 2 (scale_vec_size) x 4 (collector_usage) | 32 |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.block_scale | 2 (space) x 1 (mxf8f6f4) x 2 (cta_group) x 2 (scale_vec_size) x 4 (collector_usage) | 32 |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.ws | 2 (space) x 2 (sp) x 4 (kind) x 2 (zero_col_mask) x 4 (collector_usage_op) x 4 (collector_buffer) | 256 |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| Total | | 816 |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+
+To reduce the number of possible intrinsic variations, we've modeled the tcgen05.mma instructions using flag operands. We've added range checks to these flags to prevent invalid values. We also expanded some flags back into intrinsic modifiers to avoid supporting invalid combinations of features.
+
+'``llvm.nvvm.tcgen05.mma.*``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+ declare void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+
+ ; .sp variants
+ declare void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group, i32 %collector_usage_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+
+ ; .scale_d variants
+ declare void @llvm.nvvm.tcgen05.mma.shared.f16.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.shared.tf32.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor.f16.scale_d(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor.f16.scale_d.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor.tf32.scale_d(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor.tf32.scale_d.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+
+ ; sp.scale_d variants
+ declare void @llvm.nvvm.tcgen05.mma.sp.shared.f16.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.shared.tf32.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.f16.scale_d(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.f16.scale_d.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.tf32.scale_d(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.tf32.scale_d.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group, i32 %collector_usage_a_op_flag)
+
+Overview:
+"""""""""
+
+`nvvm.tcgen05.mma` is an asynchronous intrinsic which initiates an `MxNxK` matrix multiply and accumulate operation, `D = A * B + D`, where the `A` matrix is `M x K`, the `B` matrix is `K x N`, and the `D` matrix is `M x N`. An operation of the form `D = A * B` is issued when the input predicate argument `%enable_inp_d` is false. The optional immediate argument `%scale_d_imm` can be specified to scale the input matrix `D` as follows: `D = A * B + D * (2 ^ - %scale_d_imm)`. The valid range of values for argument `%scale_d_imm` is `[0, 15]`. The 32-bit register operand `%idesc` is the instruction descriptor as described in `Instruction descriptor <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor>`__.
+
+`nvvm.tcgen05.mma` has single-thread semantics, unlike the collective instructions `nvvm.mma.sync` or the PTX `wgmma.mma_async` instruction. A single thread issuing `nvvm.tcgen05.mma` therefore initiates the whole matrix multiply and accumulate operation.
+
+When `.sp` is specified, the dimension of the `A` matrix is `M x (K/2)`, and an additional `%spmetadata` argument must be specified.
+
+`.ashift` shifts the rows of the A matrix down by one row, except for the last row in the Tensor Memory. `.ashift` is only allowed with M = 128 or M = 256.
+
+The `%collector_usage_a_op_flag` flag specifies the usage of the collector buffer for matrix `A`. It is illegal to specify either `USE` or `FILL` for `%collector_usage_a_op_flag` along with `.ashift`.
+
+For more information, refer to the
+`PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__
+
+The following tables describe the possible values of the flag arguments:
+
+`%kind_flag` flag:
+
+============= ==========
+ `kind_flag` value
+============= ==========
+ F16 0
+ TF32 1
+ F8F6F4 2
+ I8 3
+============= ==========
+
+`%cta_group` flag:
+
+============= ==========
+ `cta_group` value
+============= ==========
+ CG1 1
+ CG2 2
+============= ==========
+
+`%collector_usage_a_op_flag` flag:
+
+============================= ==========
+ `collector_usage_a_op_flag` value
+============================= ==========
+ DISCARD 0
+ LASTUSE 1
+ USE 2
+ FILL 3
+============================= ==========
+
+'``llvm.nvvm.tcgen05.mma.block_scale*``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+ ; mxf8f6f4
+ declare void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) %d, addrspace(3) %a, addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block32.block_scale(ptr addrspace(6) %d, addrspace(3) %a, addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor.mxf8f6f4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block_scale(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.shared.mxf8f6f4.block32.block_scale(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf8f6f4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, addrspace(6) %spmetadata, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+
+ ; mxf4
+ declare void @llvm.nvvm.tcgen05.mma.shared.mxf4.block_scale(ptr addrspace(6) %d, addrspace(3) %a, addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.shared.mxf4.block_scale(ptr addrspace(6) %d, addrspace(3) %a, addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+
+ ; mxf4nvf4
+ declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block16.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block16.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.tensor.mxf4nvf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block16.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block16.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+ declare void @llvm.nvvm.tcgen05.mma.sp.tensor.mxf4nvf4.block32.block_scale(ptr addrspace(6) %d, addrspace(6) %a, i64 %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %spmetadata, addrspace(6) %scale_a, addrspace(6) %scale_b, i32 %kind_flag, i32 cta_group_flag, i32 %collector_usage_a_op_flag)
+
+Overview:
+"""""""""
+`nvvm.tcgen05.mma.block_scale` is an asynchronous intrinsic which initiates an `MxNxK` matrix multiply and accumulate operation, `D = (A * scale_a) * (B * scale_b) + D`, where the `A` matrix is `M x K`, the `B` matrix is `K x N`, and the `D` matrix is `M x N`. The matrices `A` and `B` are scaled with the `%scale_a` and `%scale_b` matrices respectively before performing the matrix multiply and accumulate operation. An operation of the form `D = A * B` is issued when the input predicate argument `%enable_inp_d` is false. The 32-bit register operand `%idesc` is the instruction descriptor as described in `Instruction descriptor <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor>`__.
+
+`nvvm.tcgen05.mma.block_scale` has single-thread semantics, unlike the collective instructions `nvvm.mma.sync` or the PTX `wgmma.mma_async` instruction. A single thread issuing `nvvm.tcgen05.mma.block_scale` therefore initiates the whole matrix multiply and accumulate operation.
+
+When `.sp` is specified, the dimension of the `A` matrix is `M x (K/2)`, and an additional `%spmetadata` argument must be specified.
+
+The `%collector_usage_a_op_flag` flag specifies the usage of the collector buffer for matrix `A`.
+
+For more information, refer to the
+`PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__
+
+The following tables describe the possible values of the flag arguments:
+
+`%kind_flag` flag:
+
+============= ==========
+ `kind_flag` value
+============= ==========
+ MXF8F6F4 0
+ MXF4 1
+ MXF4NVF4 2
+============= ==========
+
+`%cta_group` flag:
+
+============= ==========
+ `cta_group` value
+============= ==========
+ CG1 1
+ CG2 2
+============= ==========
+
+`%collector_usage_a_op_flag` flag:
+
+========================...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/151949