[llvm] [NVVM][NVPTX] Add support for tcgen05.mma (PR #151949)
Pradeep Kumar via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 8 08:59:03 PDT 2025
================
@@ -2464,4 +2504,392 @@ def int_nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_ # dim
"llvm.nvvm.clusterlaunchcontrol.query_cancel.get_first_ctaid." # dim>;
}
-} // let TargetPrefix = "nvvm"
+//
+// tcgen05.mma Intrinsics
+//
+
+foreach space = ["tensor", "shared"] in {
+ foreach ashiftid = !if(!eq(space, "tensor"), [0, 1], [0]) in {
+ defvar a_operand_type = !if(!eq(space, "tensor"), llvm_tmem_ptr_ty,
+ llvm_i64_ty);
+
+ defvar ashift = !if(!eq(ashiftid, 1), ".ashift", "");
+ defvar collector_usage_a_range = !if(!eq(ashiftid, 1), 2, 4);
+
+ def NVVM_TCGEN05_MMA_NAME<"", space, ashift, "">.record:
+ DefaultAttrsIntrinsic<[],
+ !listconcat([llvm_tmem_ptr_ty, // 0. dtmem
+ a_operand_type, // 1. a
+ llvm_i64_ty, // 2. b
+ llvm_i32_ty, // 3. idesc
+ llvm_i1_ty], // 4. enable_inp_d
+ // flags
+ [llvm_i32_ty, // 5. kind
+ llvm_i32_ty, // 6. cta_group
+ llvm_i32_ty]), // 7. collector_usage_a
+ !listconcat([IntrArgMemOnly,
+ WriteOnly<ArgIndex<0>>],
+ !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+ [ImmArg<ArgIndex<5>>, Range<ArgIndex<5>, 0, 4>,
+ ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 1, 3>,
+ ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 0, collector_usage_a_range>]),
+ NVVM_TCGEN05_MMA_NAME<"", space, ashift, "">.intr>;
+
+ def NVVM_TCGEN05_MMA_NAME<".sp", space, ashift, "">.record:
+ DefaultAttrsIntrinsic<[],
+ !listconcat([llvm_tmem_ptr_ty, // 0. dtmem
+ a_operand_type, // 1. a
+ llvm_i64_ty, // 2. b
+ llvm_i32_ty, // 3. idesc
+ llvm_i1_ty, // 4. enable_inp_d
+ llvm_tmem_ptr_ty], // 5. spmetadata
+ // flags
+ [llvm_i32_ty, // 6. kind
+ llvm_i32_ty, // 7. cta_group
+ llvm_i32_ty]), // 8. collector_usage_a
+ !listconcat([IntrArgMemOnly,
+ WriteOnly<ArgIndex<0>>],
+ !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+ [ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 0, 4>,
+ ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 1, 3>,
+ ImmArg<ArgIndex<8>>, Range<ArgIndex<8>, 0, collector_usage_a_range>]),
+ NVVM_TCGEN05_MMA_NAME<".sp", space, ashift, "">.intr>;
+
+ // scale_d
+ foreach kind = ["f16", "tf32"] in {
+ def NVVM_TCGEN05_MMA_NAME<"", space, ashift, "." # kind # ".scale_d">.record:
+ DefaultAttrsIntrinsic<[],
+ !listconcat([llvm_tmem_ptr_ty, // 0. dtmem
+ a_operand_type, // 1. a
+ llvm_i64_ty, // 2. b
+ llvm_i32_ty, // 3. idesc
+ llvm_i1_ty, // 4. enable_inp_d
+ llvm_i64_ty], // 5. scale_d_imm
+ // flags
+ [llvm_i32_ty, // 6. cta_group
+ llvm_i32_ty]), // 7. collector_usage_a
+ !listconcat([IntrArgMemOnly,
+ WriteOnly<ArgIndex<0>>],
+ !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+ [ImmArg<ArgIndex<5>>, Range<ArgIndex<5>, 0, 16>,
+ ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 1, 3>,
+ ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 0, collector_usage_a_range>]),
+ NVVM_TCGEN05_MMA_NAME<"", space, ashift, "." # kind # ".scale_d">.intr>;
+
+ def NVVM_TCGEN05_MMA_NAME<".sp", space, ashift, "." # kind # ".scale_d">.record:
+ DefaultAttrsIntrinsic<[],
+ !listconcat([llvm_tmem_ptr_ty, // 0. dtmem
+ a_operand_type, // 1. a
+ llvm_i64_ty, // 2. b
+ llvm_i32_ty, // 3. idesc
+ llvm_i1_ty, // 4. enable_inp_d
+ llvm_tmem_ptr_ty, // 5. spmetadata
+ llvm_i64_ty], // 6. scale_d_imm
+ // flags
+ [llvm_i32_ty, // 7. cta_group
+ llvm_i32_ty]), // 8. collector_usage_a
+ !listconcat([IntrArgMemOnly,
+ WriteOnly<ArgIndex<0>>],
+ !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+ [ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 0, 16>,
+ ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 1, 3>,
+ ImmArg<ArgIndex<8>>, Range<ArgIndex<8>, 0, collector_usage_a_range>]),
+ NVVM_TCGEN05_MMA_NAME<".sp", space, ashift, "." # kind # ".scale_d">.intr>;
+ }
+ }
+}
+
+//
+// tcgen05.mma disable_output_lane intrinsics
+//
+foreach space = ["tensor", "shared"] in {
+ foreach ashiftid = !if(!eq(space, "tensor"), [0, 1], [0]) in {
+ defvar a_operand_type = !if(!eq(space, "tensor"),
+ llvm_tmem_ptr_ty,
+ llvm_i64_ty);
+ defvar ashift = !if(!eq(ashiftid, 1), ".ashift", "");
+ defvar collector_usage_a_range = !if(!eq(ashiftid, 1), 2, 4);
+
+ def NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<"", space, 1, ashift>.record:
+ DefaultAttrsIntrinsic<[],
+ !listconcat([llvm_tmem_ptr_ty, // 0. dtmem
+ a_operand_type, // 1. a
+ llvm_i64_ty, // 2. b
+ llvm_i32_ty, // 3. idesc
+ llvm_i1_ty, // 4. enable_inp_d
+ llvm_v4i32_ty], // 5. disable output lane
+ // flags
+ [llvm_i32_ty, // 6. kind
+ llvm_i32_ty]), // 7. collector_usage_a
+ !listconcat([IntrArgMemOnly,
+ WriteOnly<ArgIndex<0>>],
+ !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+ [ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 0, 4>,
+ ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 0, collector_usage_a_range>]),
+ NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<"", space, 1, ashift>.intr>;
+
+ def NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<"", space, 2, ashift>.record:
+ DefaultAttrsIntrinsic<[],
+ !listconcat([llvm_tmem_ptr_ty, // 0. dtmem
+ a_operand_type, // 1. a
+ llvm_i64_ty, // 2. b
+ llvm_i32_ty, // 3. idesc
+ llvm_i1_ty, // 4. enable_inp_d
+ llvm_v8i32_ty], // 5. disable output lane
+ // flags
+ [llvm_i32_ty, // 6. kind
+ llvm_i32_ty]), // 7. collector_usage_a
+ !listconcat([IntrArgMemOnly,
+ WriteOnly<ArgIndex<0>>],
+ !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+ [ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 0, 4>,
+ ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 0, collector_usage_a_range>]),
+ NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<"", space, 2, ashift>.intr>;
+
+ def NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<".sp", space, 1, ashift>.record:
+ DefaultAttrsIntrinsic<[],
+ !listconcat([llvm_tmem_ptr_ty, // 0. dtmem
+ a_operand_type, // 1. a
+ llvm_i64_ty, // 2. b
+ llvm_i32_ty, // 3. idesc
+ llvm_i1_ty, // 4. enable_inp_d
+ llvm_tmem_ptr_ty, // 5. spmetadata
+ llvm_v4i32_ty], // 6. disable output lane
+ // flags
+ [llvm_i32_ty, // 7. kind
+ llvm_i32_ty]), // 8. collector_usage_a
+ !listconcat([IntrArgMemOnly,
+ WriteOnly<ArgIndex<0>>],
+ !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+ [ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 0, 4>,
+ ImmArg<ArgIndex<8>>, Range<ArgIndex<8>, 0, collector_usage_a_range>]),
+ NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<".sp", space, 1, ashift>.intr>;
+
+ def NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<".sp", space, 2, ashift>.record:
+ DefaultAttrsIntrinsic<[],
+ !listconcat([llvm_tmem_ptr_ty, // 0. dtmem
+ a_operand_type, // 1. a
+ llvm_i64_ty, // 2. b
+ llvm_i32_ty, // 3. idesc
+ llvm_i1_ty, // 4. enable_inp_d
+ llvm_tmem_ptr_ty, // 5. spmetadata
+ llvm_v8i32_ty], // 6. disable output lane
----------------
schwarzschild-radius wrote:
The PTX spec specifies disable_output_lane as a vector, and CUTLASS (https://github.com/NVIDIA/cutlass/blob/76c96b0be35cb263debe3e3d8418b80911a544ab/include/cute/arch/mma_sm100_umma.hpp#L149) also uses a vector to model it. The problem with scalarizing is that if the user chooses to use a vector, they will have to insert a series of extractelement and insertelement instructions in and around the tcgen05.mma code to conform to the interface, which can lead to an increase in compile time. As discussed offline, we should enable native support for the vector type in the backend to work around the boilerplate code.
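A minimal LLVM IR sketch of the point above. The `@tcgen05_mma_dol_*` declarations, their signatures, and the addrspace(6) tmem pointers are illustrative stand-ins, not the intrinsic names defined by this patch: with a vector-typed operand the caller passes its `<4 x i32>` mask straight through, while a scalarized interface would force an extractelement per mask word at every call site.

```llvm
; Placeholder declarations standing in for the vector-typed variant and a
; hypothetical scalarized variant of the disable_output_lane intrinsic.
declare void @tcgen05_mma_dol_vec(ptr addrspace(6), ptr addrspace(6), i64, i32,
                                  i1, <4 x i32>, i32, i32)
declare void @tcgen05_mma_dol_scalar(ptr addrspace(6), ptr addrspace(6), i64, i32,
                                     i1, i32, i32, i32, i32, i32, i32)

; Vector interface: the <4 x i32> lane mask is passed through unchanged.
define void @use_vector(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b,
                        i32 %idesc, <4 x i32> %mask) {
  call void @tcgen05_mma_dol_vec(ptr addrspace(6) %d, ptr addrspace(6) %a,
                                 i64 %b, i32 %idesc, i1 true,
                                 <4 x i32> %mask, i32 0, i32 0)
  ret void
}

; Scalarized interface: every call site must first unpack the mask, adding
; four extractelement instructions of boilerplate around each mma.
define void @use_scalarized(ptr addrspace(6) %d, ptr addrspace(6) %a, i64 %b,
                            i32 %idesc, <4 x i32> %mask) {
  %m0 = extractelement <4 x i32> %mask, i32 0
  %m1 = extractelement <4 x i32> %mask, i32 1
  %m2 = extractelement <4 x i32> %mask, i32 2
  %m3 = extractelement <4 x i32> %mask, i32 3
  call void @tcgen05_mma_dol_scalar(ptr addrspace(6) %d, ptr addrspace(6) %a,
                                    i64 %b, i32 %idesc, i1 true,
                                    i32 %m0, i32 %m1, i32 %m2, i32 %m3,
                                    i32 0, i32 0)
  ret void
}
```

The second function shows the per-call-site unpacking being discussed; with the vector form that boilerplate disappears entirely.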
https://github.com/llvm/llvm-project/pull/151949