[llvm] [NVVM][NVPTX] Add support for tcgen05.mma (PR #151949)

Artem Belevich via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 15 13:05:51 PDT 2025


================
@@ -1945,6 +1945,388 @@ The last argument `i1 %unpack` is a compile-time constant which when set, indica
 For more information, refer to the
 `PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-st>`__.
 
+tcgen05.mma Intrinsics
+----------------------
+
+One of the key instructions introduced in the Blackwell architecture is the `tcgen05.mma` family, which performs matrix multiply-accumulate operations on the 5th-generation Tensor Core. The `tcgen05.mma` instruction supports a broad range of features, including sparsity, block scaling, and weight-stationary convolutions. Modeling every combination of these features as a separate intrinsic is impractical: the following table counts the intrinsic variants that would be needed to cover the full `tcgen05.mma` instruction set.
+
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| Variant                            | Configuration                                                                                     | Total Variants |
++====================================+===================================================================================================+================+
+| tcgen05.mma.shared                 | 2 (space) x 2 (sp) x 4 (kind) x 2 (cta_group) x 4 (collector_usage)                               | 128            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.tensor.ashift          | 2 (sp) x 4 (kind) x 2 (cta_group) x 2 (collector_usage)                                           | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.scale_d                | 2 (space) x 2 (sp) x 2 (kind) x 2 (cta_group) x 4 (collector_usage)                               | 64             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.scale_d.tensor.ashift  | 2 (sp) x 2 (kind) x 2 (cta_group) x 2 (collector_usage)                                           | 16             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.disable_output_lane    | 2 (space) x 2 (sp) x 4 (kind) x 2 (cta_group) x 4 (collector_usage)                               | 128            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.disable_output_lane... | 2 (sp) x 4 (kind) x 2 (cta_group) x 2 (collector_usage)                                           | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.block_scale            | 2 (space) x 1 (mxf4nvf4) x 2 (cta_group) x 2 (scale_vec_size) x 4 (collector_usage)               | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.block_scale            | 2 (space) x 1 (mxf4) x 2 (cta_group) x 2 (scale_vec_size) x 4 (collector_usage)                   | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.block_scale            | 2 (space) x 1 (mxf8f6f4) x 2 (cta_group) x 2 (scale_vec_size) x 4 (collector_usage)               | 32             |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| tcgen05.mma.ws                     | 2 (space) x 2 (sp) x 4 (kind) x 2 (zero_col_mask) x 4 (collector_usage_op) x 4 (collector_buffer) | 256            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+| Total                              |                                                                                                   | 752            |
++------------------------------------+---------------------------------------------------------------------------------------------------+----------------+
+
+To reduce the number of intrinsic variants, we model the `tcgen05.mma` instructions using flag operands, with range checks on each flag to reject invalid values. We also expand some flags back into intrinsic-name modifiers so that invalid combinations of features cannot be expressed.
+
+'``llvm.nvvm.tcgen05.mma.*``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+  declare void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+
+  ; .sp variants
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+
+  ; .scale_d variants
+  declare void @llvm.nvvm.tcgen05.mma.shared.f16.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.shared.tf32.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.f16.scale_d(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.f16.scale_d.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.tf32.scale_d(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.tensor.tf32.scale_d.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+
+  ; .sp.scale_d variants
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared.f16.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared.tf32.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.f16.scale_d(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.f16.scale_d.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.tf32.scale_d(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+  declare void @llvm.nvvm.tcgen05.mma.sp.tensor.tf32.scale_d.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 %enable_inp_d, i64 %scale_d_imm, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
+
+Overview:
+"""""""""
+
+`nvvm.tcgen05.mma` is an asynchronous intrinsic that initiates an `MxNxK` matrix multiply-accumulate operation, `D = A * B + D`, where `A` is an `M x K` matrix, `B` is a `K x N` matrix, and `D` is an `M x N` matrix. When the input predicate argument `%enable_inp_d` is false, the operation `D = A * B` is issued instead. The optional immediate argument `%scale_d_imm` scales the input matrix `D` as follows: `D = A * B + D * (2 ^ - %scale_d_imm)`. The valid range of `%scale_d_imm` is `[0, 15]`. The 32-bit register operand `%idesc` is the instruction descriptor, as described in `Instruction descriptor <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instruction-descriptor>`__.
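+
+For example, the following minimal sketch (assuming the `.shared.f16.scale_d` signature declared above; the function name `@mma_f16_scaled` is hypothetical) sets `%enable_inp_d` to true and `%scale_d_imm` to 4, so the previous contents of `D` are scaled by `2 ^ -4 = 1/16` before `A * B` is added. The cta_group and collector-usage values are taken from the flag tables in this section.
+
+.. code-block:: llvm
+
+  declare void @llvm.nvvm.tcgen05.mma.shared.f16.scale_d(ptr addrspace(6), i64, i64, i32, i1, i64, i32, i32)
+
+  define void @mma_f16_scaled(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc) {
+    ; D = A * B + D * (2 ^ -4): %enable_inp_d = true, %scale_d_imm = 4 (valid range [0, 15])
+    ; cta_group = CG1 (1), collector_usage_a_op = DISCARD (0)
+    call void @llvm.nvvm.tcgen05.mma.shared.f16.scale_d(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 true, i64 4, i32 1, i32 0)
+    ret void
+  }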
+
+`nvvm.tcgen05.mma` has single-thread semantics, unlike the collective instructions `nvvm.mma.sync` and the PTX `wgmma.mma_async` instruction: a single thread issuing `nvvm.tcgen05.mma` initiates the entire matrix multiply-accumulate operation.
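+
+To illustrate the single-thread semantics, here is a minimal sketch in which only one thread issues the MMA. It assumes a 1-D thread block (so that exactly one thread has `%tid.x == 0`) and the `.shared` signature declared above; the function name is hypothetical, and the synchronization needed before reading `D` is omitted.
+
+.. code-block:: llvm
+
+  declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  declare void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6), i64, i64, i32, i1, i32, i32, i32)
+
+  define void @mma_from_one_thread(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc) {
+  entry:
+    ; Select a single issuing thread (assumes a 1-D thread block).
+    %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+    %is_issuer = icmp eq i32 %tid, 0
+    br i1 %is_issuer, label %issue, label %done
+
+  issue:
+    ; This one call initiates the whole MMA: kind = F16 (0), cta_group = CG1 (1),
+    ; collector_usage_a_op = DISCARD (0)
+    call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 true, i32 0, i32 1, i32 0)
+    br label %done
+
+  done:
+    ret void
+  }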
+
+When `.sp` is specified, the `A` matrix has dimensions `M x (K/2)`, and an additional `%spmetadata` argument is required.
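+
+A minimal sketch of a sparse call, assuming the `.sp.shared` signature declared above and that `%spmetadata` is a `ptr addrspace(6)` Tensor Memory address (the function name is hypothetical):
+
+.. code-block:: llvm
+
+  declare void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6), i64, i64, i32, ptr addrspace(6), i1, i32, i32, i32)
+
+  define void @mma_sparse_a(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata) {
+    ; A is M x (K/2); the sparsity metadata operand follows %idesc.
+    ; kind = F16 (0), cta_group = CG1 (1), collector_usage_a_op = DISCARD (0)
+    call void @llvm.nvvm.tcgen05.mma.sp.shared(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, ptr addrspace(6) %spmetadata, i1 true, i32 0, i32 1, i32 0)
+    ret void
+  }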
+
+`.ashift` shifts the rows of the `A` matrix down by one row, except for the last row in Tensor Memory. `.ashift` is only allowed with `M = 128` or `M = 256`.
+
+The `%collector_usage_a_op_flag` flag specifies how the collector buffer is used for matrix `A`. It is illegal to specify either `USE` or `FILL` for `%collector_usage_a_op_flag` together with `.ashift`.
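+
+Putting the `.ashift` restrictions together, here is a minimal sketch assuming the `.tensor.ashift` signature declared above (the function name is hypothetical): `%idesc` must encode `M = 128` or `M = 256`, and the collector-usage flag is limited to `DISCARD` or `LASTUSE`.
+
+.. code-block:: llvm
+
+  declare void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6), ptr addrspace(6), i64, i32, i1, i32, i32, i32)
+
+  define void @mma_ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc) {
+    ; %idesc is assumed to encode M = 128 or M = 256, as .ashift requires.
+    ; collector_usage_a_op = LASTUSE (1); USE (2) and FILL (3) are illegal with .ashift.
+    ; kind = F16 (0), cta_group = CG1 (1)
+    call void @llvm.nvvm.tcgen05.mma.tensor.ashift(ptr addrspace(6) %d, ptr addrspace(6) %atensor, i64 %b, i32 %idesc, i1 true, i32 0, i32 1, i32 1)
+    ret void
+  }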
+
+For more information, refer to the
+`PTX ISA <https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-instructions-mma>`__.
+
+The following tables describe the possible values of the flag arguments:
+
+`%kind_flag` flag:
+
+============= ==========
+  `kind_flag`   value
+============= ==========
+     F16          0
+     TF32         1
+     F8F6F4       2
+     I8           3
+============= ==========
+
+`%cta_group_flag` flag:
+
+============= ==========
+ `cta_group`    value
+============= ==========
+     CG1          1
+     CG2          2
+============= ==========
+
+`%collector_usage_a_op_flag` flag:
+
+============================= ==========
+ `collector_usage_a_op_flag`    value
+============================= ==========
+     DISCARD                      0
+     LASTUSE                      1
+     USE                          2
+     FILL                         3
+============================= ==========
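+
+As a sketch of how the table values map onto operands (assuming the `.shared` signature declared above; the function name is hypothetical), the call below selects `I8`, `CG2`, and `FILL`, and clears `%enable_inp_d` so that `D = A * B`. Given the range checks on these flags, a value outside the tables, such as `i32 4` for `%kind_flag`, is expected to be rejected.
+
+.. code-block:: llvm
+
+  declare void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6), i64, i64, i32, i1, i32, i32, i32)
+
+  define void @mma_i8_cg2(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc) {
+    ; kind = I8 (3), cta_group = CG2 (2), collector_usage_a_op = FILL (3)
+    ; %enable_inp_d = false, so D = A * B
+    call void @llvm.nvvm.tcgen05.mma.shared(ptr addrspace(6) %d, i64 %a, i64 %b, i32 %idesc, i1 false, i32 3, i32 2, i32 3)
+    ret void
+  }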
+
+'``llvm.nvvm.tcgen05.mma.block_scale*``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+  ; mxf8f6f4
+  declare void @llvm.nvvm.tcgen05.mma.shared.mxf8f6f4.block_scale(ptr addrspace(6) %d, ptr addrspace(3) %a, ptr addrspace(3) %b, i32 %idesc, i1 %enable_inp_d, ptr addrspace(6) %scale_a, ptr addrspace(6) %scale_b, i32 %kind_flag, i32 %cta_group_flag, i32 %collector_usage_a_op_flag)
----------------
Artem-B wrote:

Isn't `%kind_flag` redundant here, considering that the intrinsic already has `mxf8f6f4` in its name?

Applies to other intrinsics below.

https://github.com/llvm/llvm-project/pull/151949

