[llvm] [NVVM][NVPTX] Add support for tcgen05.mma (PR #151949)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 15 13:29:33 PDT 2025


================
@@ -2464,4 +2504,392 @@ def int_nvvm_clusterlaunchcontrol_query_cancel_get_first_ctaid_ # dim
                             "llvm.nvvm.clusterlaunchcontrol.query_cancel.get_first_ctaid." # dim>;
 }
 
-} // let TargetPrefix = "nvvm"
+//
+// tcgen05.mma Intrinsics
+//
+
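+// The records below are instantiated once for an A operand held in tensor
+// memory ("tensor", passed as a tmem pointer) and once for an A operand
+// described by a 64-bit shared-memory descriptor ("shared"). The ".ashift"
+// variants exist only for tensor-memory A operands and restrict the valid
+// range of the collector_usage_a flag.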
+foreach space = ["tensor", "shared"] in {
+  foreach ashiftid = !if(!eq(space, "tensor"), [0, 1], [0]) in {
+    defvar a_operand_type = !if(!eq(space, "tensor"), llvm_tmem_ptr_ty,
+                                                      llvm_i64_ty);
+
+    defvar ashift = !if(!eq(ashiftid, 1), ".ashift", "");
+    defvar collector_usage_a_range = !if(!eq(ashiftid, 1), 2, 4);
+
+    def NVVM_TCGEN05_MMA_NAME<"", space, ashift, "">.record:
+          DefaultAttrsIntrinsic<[],
+            !listconcat([llvm_tmem_ptr_ty,   // 0. dtmem
+                         a_operand_type,     // 1. a
+                         llvm_i64_ty,        // 2. b
+                         llvm_i32_ty,        // 3. idesc
+                         llvm_i1_ty],        // 4. enable_inp_d
+                        // flags
+                        [llvm_i32_ty,        // 5. kind
+                         llvm_i32_ty,        // 6. cta_group
+                         llvm_i32_ty]),      // 7. collector_usage_a
+            !listconcat([IntrArgMemOnly,
+                         WriteOnly<ArgIndex<0>>],
+                        !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+                        [ImmArg<ArgIndex<5>>, Range<ArgIndex<5>, 0, 4>,
+                         ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 1, 3>,
+                         ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 0, collector_usage_a_range>]),
+            NVVM_TCGEN05_MMA_NAME<"", space, ashift, "">.intr>;
+
+    def NVVM_TCGEN05_MMA_NAME<".sp", space, ashift, "">.record:
+          DefaultAttrsIntrinsic<[],
+            !listconcat([llvm_tmem_ptr_ty,   // 0. dtmem
+                         a_operand_type,     // 1. a
+                         llvm_i64_ty,        // 2. b
+                         llvm_i32_ty,        // 3. idesc
+                         llvm_i1_ty,         // 4. enable_inp_d
+                         llvm_tmem_ptr_ty],  // 5. spmetadata
+                        // flags
+                        [llvm_i32_ty,        // 6. kind
+                         llvm_i32_ty,        // 7. cta_group
+                         llvm_i32_ty]),      // 8. collector_usage_a
+            !listconcat([IntrArgMemOnly,
+                         WriteOnly<ArgIndex<0>>],
+                        !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+                        [ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 0, 4>,
+                         ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 1, 3>,
+                         ImmArg<ArgIndex<8>>, Range<ArgIndex<8>, 0, collector_usage_a_range>]),
+            NVVM_TCGEN05_MMA_NAME<".sp", space, ashift, "">.intr>;
+
+    // scale_d
+    foreach kind = ["f16", "tf32"] in {
+      def NVVM_TCGEN05_MMA_NAME<"", space, ashift, "." # kind # ".scale_d">.record: 
+            DefaultAttrsIntrinsic<[],
+              !listconcat([llvm_tmem_ptr_ty,   // 0. dtmem
+                           a_operand_type,     // 1. a
+                           llvm_i64_ty,        // 2. b
+                           llvm_i32_ty,        // 3. idesc
+                           llvm_i1_ty,         // 4. enable_inp_d
+                           llvm_i64_ty],       // 5. scale_d_imm
+                          // flags
+                          [llvm_i32_ty,        // 6. cta_group
+                           llvm_i32_ty]),      // 7. collector_usage_a
+              !listconcat([IntrArgMemOnly,
+                           WriteOnly<ArgIndex<0>>],
+                          !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+                          [ImmArg<ArgIndex<5>>, Range<ArgIndex<5>, 0, 16>,
+                           ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 1, 3>,
+                           ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 0, collector_usage_a_range>]),
+              NVVM_TCGEN05_MMA_NAME<"", space, ashift, "." # kind # ".scale_d">.intr>;
+
+      def NVVM_TCGEN05_MMA_NAME<".sp", space, ashift, "." # kind # ".scale_d">.record: 
+            DefaultAttrsIntrinsic<[],
+              !listconcat([llvm_tmem_ptr_ty,   // 0. dtmem
+                           a_operand_type,     // 1. a
+                           llvm_i64_ty,        // 2. b
+                           llvm_i32_ty,        // 3. idesc
+                           llvm_i1_ty,         // 4. enable_inp_d
+                           llvm_tmem_ptr_ty,   // 5. spmetadata
+                           llvm_i64_ty],       // 6. scale_d_imm
+                          // flags
+                          [llvm_i32_ty,        // 7. cta_group
+                           llvm_i32_ty]),      // 8. collector_usage_a
+              !listconcat([IntrArgMemOnly,
+                           WriteOnly<ArgIndex<0>>],
+                          !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+                          [ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 0, 16>,
+                           ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 1, 3>,
+                           ImmArg<ArgIndex<8>>, Range<ArgIndex<8>, 0, collector_usage_a_range>]),
+              NVVM_TCGEN05_MMA_NAME<".sp", space, ashift, "." # kind # ".scale_d">.intr>;
+    }
+  }
+}
+
+//
+// tcgen05.mma disable_output_lane intrinsics
+//
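+// The cta_group::1 variants take the disable-output-lane mask as a v4i32
+// operand; the cta_group::2 variants take it as a v8i32 operand.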
+foreach space = ["tensor", "shared"] in {
+  foreach ashiftid = !if(!eq(space, "tensor"), [0, 1], [0]) in {
+  defvar a_operand_type = !if(!eq(space, "tensor"),
+                                llvm_tmem_ptr_ty,
+                                llvm_i64_ty);
+  defvar ashift = !if(!eq(ashiftid, 1), ".ashift", "");
+  defvar collector_usage_a_range = !if(!eq(ashiftid, 1), 2, 4);
+
+  def NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<"", space, 1, ashift>.record:
+        DefaultAttrsIntrinsic<[],
+          !listconcat([llvm_tmem_ptr_ty,   // 0. dtmem
+                       a_operand_type,     // 1. a
+                       llvm_i64_ty,        // 2. b
+                       llvm_i32_ty,        // 3. idesc
+                       llvm_i1_ty,         // 4. enable_inp_d
+                       llvm_v4i32_ty],     // 5. disable output lane
+                      // flags
+                      [llvm_i32_ty,        // 6. kind
+                       llvm_i32_ty]),      // 7. collector_usage_a
+          !listconcat([IntrArgMemOnly,
+                       WriteOnly<ArgIndex<0>>],
+                      !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+                      [ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 0, 4>,
+                       ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 0, collector_usage_a_range>]),
+          NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<"", space, 1, ashift>.intr>;
+
+  def NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<"", space, 2, ashift>.record:
+        DefaultAttrsIntrinsic<[],
+          !listconcat([llvm_tmem_ptr_ty,    // 0. dtmem
+                       a_operand_type,      // 1. a
+                       llvm_i64_ty,         // 2. b
+                       llvm_i32_ty,         // 3. idesc
+                       llvm_i1_ty,          // 4. enable_inp_d
+                       llvm_v8i32_ty],      // 5. disable output lane
+                      // flags
+                      [llvm_i32_ty,         // 6. kind
+                       llvm_i32_ty]),       // 7. collector_usage_a
+          !listconcat([IntrArgMemOnly,
+                       WriteOnly<ArgIndex<0>>],
+                      !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+                      [ImmArg<ArgIndex<6>>, Range<ArgIndex<6>, 0, 4>,
+                       ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 0, collector_usage_a_range>]),
+          NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<"", space, 2, ashift>.intr>;
+
+  def NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<".sp", space, 1, ashift>.record:
+        DefaultAttrsIntrinsic<[],
+          !listconcat([llvm_tmem_ptr_ty,   // 0. dtmem
+                       a_operand_type,     // 1. a
+                       llvm_i64_ty,        // 2. b
+                       llvm_i32_ty,        // 3. idesc
+                       llvm_i1_ty,         // 4. enable_inp_d
+                       llvm_tmem_ptr_ty,   // 5. spmetadata
+                       llvm_v4i32_ty],     // 6. disable output lane
+                      // flags
+                      [llvm_i32_ty,        // 7. kind
+                       llvm_i32_ty]),      // 8. collector_usage_a
+          !listconcat([IntrArgMemOnly,
+                       WriteOnly<ArgIndex<0>>],
+                      !if(!eq(space, "tensor"), [ReadOnly<ArgIndex<1>>], []),
+                      [ImmArg<ArgIndex<7>>, Range<ArgIndex<7>, 0, 4>,
+                       ImmArg<ArgIndex<8>>, Range<ArgIndex<8>, 0, collector_usage_a_range>]),
+          NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<".sp", space, 1, ashift>.intr>;
+
+  def NVVM_TCGEN05_MMA_DISABLE_OUTPUT_LANE_NAME<".sp", space, 2, ashift>.record:
+        DefaultAttrsIntrinsic<[],
+          !listconcat([llvm_tmem_ptr_ty,   // 0. dtmem
+                       a_operand_type,     // 1. a
+                       llvm_i64_ty,        // 2. b
+                       llvm_i32_ty,        // 3. idesc
+                       llvm_i1_ty,         // 4. enable_inp_d
+                       llvm_tmem_ptr_ty,   // 5. spmetadata
+                       llvm_v8i32_ty],     // 6. disable output lane
----------------
AlexMaclean wrote:

It seems like using illegal vector types (v4i32/v8i32) in these intrinsics has forced you to add a lot of additional logic to NVPTX ISel lowering. Is there a reason we need to use these types? Could the vectors be flattened into a series of scalar arguments?
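
For illustration, a flattened form could look something like the sketch below (the record name and exact operand order are invented here, not taken from the patch; the attributes just mirror the existing tensor-A, cta_group::1 variant):

  // Hypothetical flattened variant -- not part of this patch. The v4i32
  // disable-output-lane mask becomes four scalar i32 operands (5-8), so no
  // illegal vector type has to survive until NVPTX ISel.
  def int_nvvm_tcgen05_mma_tensor_disable_output_lane_cg1_flat :
        DefaultAttrsIntrinsic<[],
          [llvm_tmem_ptr_ty,   // 0. dtmem
           llvm_tmem_ptr_ty,   // 1. a
           llvm_i64_ty,        // 2. b
           llvm_i32_ty,        // 3. idesc
           llvm_i1_ty,         // 4. enable_inp_d
           llvm_i32_ty,        // 5. disable output lane, word 0
           llvm_i32_ty,        // 6. disable output lane, word 1
           llvm_i32_ty,        // 7. disable output lane, word 2
           llvm_i32_ty,        // 8. disable output lane, word 3
           llvm_i32_ty,        // 9. kind
           llvm_i32_ty],       // 10. collector_usage_a
          [IntrArgMemOnly, WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
           ImmArg<ArgIndex<9>>, Range<ArgIndex<9>, 0, 4>,
           ImmArg<ArgIndex<10>>, Range<ArgIndex<10>, 0, 4>]>;

The cta_group::2 form would take eight i32 operands instead of a v8i32, and ISel would assemble the operand sequence from the scalars directly.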

https://github.com/llvm/llvm-project/pull/151949

