[llvm] [NVVM][NVPTX] Add support for tcgen05.mma (PR #151949)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 15 13:29:33 PDT 2025
================
@@ -2464,4 +2504,392 @@ def int_nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_ # dim
"llvm.nvvm.clusterlaunchcontrol.query_cancel.get_first_ctaid." # dim>;
}
-} // let TargetPrefix = "nvvm"
+//
+// tcgen05.mma Intrinsics
+//
+
+foreach space = ["tensor", "shared"] in {
+  // The ".ashift" flavour is only generated for the tensor-memory form of A.
+  foreach ashiftid = !if(!eq(space, "tensor"), [0, 1], [0]) in {
+    // Operand A is a tensor-memory pointer for ".tensor" and an i64 otherwise
+    // (presumably a shared-memory matrix descriptor -- confirm against PTX).
+    defvar a_operand_type = !if(!eq(space, "tensor"), llvm_tmem_ptr_ty,
+                                llvm_i64_ty);
+    defvar ashift = !if(!eq(ashiftid, 1), ".ashift", "");
+    // Upper bound handed to Range<> for collector_usage_a: narrower when
+    // .ashift is present.
+    defvar collector_usage_a_range = !if(!eq(ashiftid, 1), 2, 4);
+
+    // Operands 0-4, identical for every variant in this loop:
+    //   0. dtmem  1. a  2. b  3. idesc  4. enable_inp_d
+    defvar mma_common_args = [llvm_tmem_ptr_ty, a_operand_type, llvm_i64_ty,
+                              llvm_i32_ty, llvm_i1_ty];
+    // Memory-effect properties shared by every variant; A is only marked
+    // ReadOnly when it is a pointer (the ".tensor" form).
+    // NOTE(review): when enable_inp_d is set the accumulator behind dtmem is
+    // presumably read as well -- confirm WriteOnly<ArgIndex<0>> is intended.
+    defvar mma_common_props = !listconcat(
+        [IntrArgMemOnly, WriteOnly<ArgIndex<0>>],
+        !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []));
+
+    // Dense form. Trailing immediate flags: kind, cta_group,
+    // collector_usage_a.
+    def NVVM_TCGEN05_MMA_NAME<"", space, ashift, "">.record:
+      DefaultAttrsIntrinsic<[],
+        !listconcat(mma_common_args,
+                    [llvm_i32_ty,     // 5. kind
+                     llvm_i32_ty,     // 6. cta_group
+                     llvm_i32_ty]),   // 7. collector_usage_a
+        !listconcat(mma_common_props,
+                    [ImmArg<ArgIndex<5>>, Range<ArgIndex<5>, 0, 4>,
+                     ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 1, 3>,
+                     ImmArg<ArgIndex<7>>,
+                     Range<ArgIndex<7>, 0, collector_usage_a_range>]),
+        NVVM_TCGEN05_MMA_NAME<"", space, ashift, "">.intr>;
+
+    // Sparse form: inserts the sparsity-metadata tmem pointer at index 5,
+    // shifting the immediate flags to indices 6-8.
+    def NVVM_TCGEN05_MMA_NAME<".sp", space, ashift, "">.record:
+      DefaultAttrsIntrinsic<[],
+        !listconcat(mma_common_args,
+                    [llvm_tmem_ptr_ty,  // 5. spmetadata
+                     llvm_i32_ty,       // 6. kind
+                     llvm_i32_ty,       // 7. cta_group
+                     llvm_i32_ty]),     // 8. collector_usage_a
+        !listconcat(mma_common_props,
+                    [ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 0, 4>,
+                     ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 1, 3>,
+                     ImmArg<ArgIndex<8>>,
+                     Range<ArgIndex<8>, 0, collector_usage_a_range>]),
+        NVVM_TCGEN05_MMA_NAME<".sp", space, ashift, "">.intr>;
+
+    // scale_d forms: the kind is encoded in the intrinsic name rather than
+    // passed as a flag, and an immediate i64 scale_d_imm operand is added.
+    foreach kind = ["f16", "tf32"] in {
+      defvar scale_d_suffix = "." # kind # ".scale_d";
+
+      def NVVM_TCGEN05_MMA_NAME<"", space, ashift, scale_d_suffix>.record:
+        DefaultAttrsIntrinsic<[],
+          !listconcat(mma_common_args,
+                      [llvm_i64_ty,    // 5. scale_d_imm
+                       llvm_i32_ty,    // 6. cta_group
+                       llvm_i32_ty]),  // 7. collector_usage_a
+          !listconcat(mma_common_props,
+                      [ImmArg<ArgIndex<5>>, Range<ArgIndex<5>, 0, 16>,
+                       ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 1, 3>,
+                       ImmArg<ArgIndex<7>>,
+                       Range<ArgIndex<7>, 0, collector_usage_a_range>]),
+          NVVM_TCGEN05_MMA_NAME<"", space, ashift, scale_d_suffix>.intr>;
+
+      def NVVM_TCGEN05_MMA_NAME<".sp", space, ashift, scale_d_suffix>.record:
+        DefaultAttrsIntrinsic<[],
+          !listconcat(mma_common_args,
+                      [llvm_tmem_ptr_ty,  // 5. spmetadata
+                       llvm_i64_ty,       // 6. scale_d_imm
+                       llvm_i32_ty,       // 7. cta_group
+                       llvm_i32_ty]),     // 8. collector_usage_a
+          !listconcat(mma_common_props,
+                      [ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 0, 16>,
+                       ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 1, 3>,
+                       ImmArg<ArgIndex<8>>,
+                       Range<ArgIndex<8>, 0, collector_usage_a_range>]),
+          NVVM_TCGEN05_MMA_NAME<".sp", space, ashift, scale_d_suffix>.intr>;
+    }
+  }
+}
+
+//
+// tcgen05.mma disable_output_lane intrinsics
+//
+foreach space = ["tensor", "shared"] in {
+ foreach ashiftid = !if(!eq(space, "tensor"), [0, 1], [0]) in {
+ defvar a_operand_type = !if(!eq(space, "tensor"),
+ llvm_tmem_ptr_ty,
+ llvm_i64_ty);
+ defvar ashift = !if(!eq(ashiftid, 1), ".ashift", "");
+ defvar collector_usage_a_range = !if(!eq(ashiftid, 1), 2, 4);
+
+ def NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<"", space, 1, ashift>.record:
+ DefaultAttrsIntrinsic<[],
+ !listconcat([llvm_tmem_ptr_ty, // 0. dtmem
+ a_operand_type, // 1. a
+ llvm_i64_ty, // 2. b
+ llvm_i32_ty, // 3. idesc
+ llvm_i1_ty, // 4. enable_inp_d
+ llvm_v4i32_ty], // 5. disable output lane
+ // flags
+ [llvm_i32_ty, // 6. kind
+ llvm_i32_ty]), // 7. collector_usage_a
+ !listconcat([IntrArgMemOnly,
+ WriteOnly<ArgIndex<0>>],
+ !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+ [ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 0, 4>,
+ ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 0, collector_usage_a_range>]),
+ NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<"", space, 1, ashift>.intr>;
+
+ def NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<"", space, 2, ashift>.record:
+ DefaultAttrsIntrinsic<[],
+ !listconcat([llvm_tmem_ptr_ty, // 0. dtmem
+ a_operand_type, // 1. a
+ llvm_i64_ty, // 2. b
+ llvm_i32_ty, // 3. idesc
+ llvm_i1_ty, // 4. enable_inp_d
+ llvm_v8i32_ty], // 5. disable output lane
+ // flags
+ [llvm_i32_ty, // 6. kind
+ llvm_i32_ty]), // 7. collector_usage_a
+ !listconcat([IntrArgMemOnly,
+ WriteOnly<ArgIndex<0>>],
+ !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+ [ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 0, 4>,
+ ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 0, collector_usage_a_range>]),
+ NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<"", space, 2, ashift>.intr>;
+
+ def NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<".sp", space, 1, ashift>.record:
+ DefaultAttrsIntrinsic<[],
+ !listconcat([llvm_tmem_ptr_ty, // 0. dtmem
+ a_operand_type, // 1. a
+ llvm_i64_ty, // 2. b
+ llvm_i32_ty, // 3. idesc
+ llvm_i1_ty, // 4. enable_inp_d
+ llvm_tmem_ptr_ty, // 5. spmetadata
+ llvm_v4i32_ty], // 6. disable output lane
+ // flags
+ [llvm_i32_ty, // 7. kind
+ llvm_i32_ty]), // 8. collector_usage_a
+ !listconcat([IntrArgMemOnly,
+ WriteOnly<ArgIndex<0>>],
+ !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+ [ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 0, 4>,
+ ImmArg<ArgIndex<8>>, Range<ArgIndex<8>, 0, collector_usage_a_range>]),
+ NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<".sp", space, 1, ashift>.intr>;
+
+ def NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<".sp", space, 2, ashift>.record:
+ DefaultAttrsIntrinsic<[],
+ !listconcat([llvm_tmem_ptr_ty, // 0. dtmem
+ a_operand_type, // 1. a
+ llvm_i64_ty, // 2. b
+ llvm_i32_ty, // 3. idesc
+ llvm_i1_ty, // 4. enable_inp_d
+ llvm_tmem_ptr_ty, // 5. spmetadata
+ llvm_v8i32_ty], // 6. disable output lane
----------------
AlexMaclean wrote:
It seems like using illegal types in these intrinsics has forced you to add a lot of additional logic in NVPTX isel lowering. Is there a reason we need to use these types? Could these vectors be flattened into a series of arguments?
https://github.com/llvm/llvm-project/pull/151949
More information about the llvm-commits
mailing list