[llvm] [NVPTX] Move TMA G2S lowering to Tablegen (PR #165710)

Durgadoss R via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 30 07:00:21 PDT 2025


================
@@ -599,75 +599,15 @@ class TMA_IM2COL_UTIL<int dim, string mode> {
   string base_str = !interleave(!foreach(i, !range(offsets), "$im2col" # i), ", ");
 }
 
-// From Global to Shared memory (G2S)
-class G2S_STRINGS<int dim, string mode, bit mc, bit ch, bit is_shared32 = 0> {
-  string prefix = "cp.async.bulk.tensor";
-  string dir = "shared::cluster.global";
-  string completion = "mbarrier::complete_tx::bytes";
-  string inst_name = prefix
-                     # "." # dim # "d"
-                     # "." # dir
-                     # "." # mode
-                     # "." # completion
-                     # !if(mc, ".multicast::cluster", "")
-                     # !if(ch, ".L2::cache_hint", "");
-  string intr_name = "CP_ASYNC_BULK_TENSOR_G2S_"
-                     # dim # "D"
-                     # !if(is_shared32, "_SHARED32", "")
-                     # !if(!eq(mode, "tile"), "_TILE", "_IM2COL");
-}
-
 def CTAGroupFlags : Operand<i32> {
   let PrintMethod = "printCTAGroup";
 }
 
-multiclass CP_ASYNC_BULK_TENSOR_G2S_INTR<int dim, bit is_shared32, string mode> {
-  defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;
-  defvar dims_str = TMA_DIMS_UTIL<dim>.base_str;
-  defvar asm_str_default = "$cg [$dst], [$tmap, {{" # dims_str # "}}], [$mbar]";
-  defvar rc = !if(is_shared32, B32, B64);
-
-  defvar num_im2col = !if(!ge(dim, 3), !add(dim, -2), 0);
-  defvar im2col_dag = !if(!eq(mode, "im2col"),
-    !dag(ins, !listsplat(B16, num_im2col), !foreach(i, !range(num_im2col), "im2col" # i)),
-    (ins));
-  defvar im2col_str = !interleave(!foreach(i, !range(num_im2col), "$im2col" # i), ", ");
-  defvar im2col_asm_str = ", {{" # im2col_str # "}}";
-
-  defvar asm_str = !if(!eq(mode, "im2col"),
-    !strconcat(asm_str_default, im2col_asm_str), asm_str_default);
+def tma_cta_group_imm0 : TImmLeaf<i32, [{return Imm == 0;}]>;
+def tma_cta_group_imm_any : TImmLeaf<i32, [{return Imm >= 0;}]>;
 
-  def "" : NVPTXInst<(outs),
-            !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag, (ins CTAGroupFlags:$cg)),
-            !strconcat(G2S_STRINGS<dim, mode, 0, 0>.inst_name, asm_str, ";")>,
-            Requires<[hasPTX<80>, hasSM<90>]>;
-  def _MC : NVPTXInst<(outs),
-                  !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag,
-                       (ins B16:$mc, CTAGroupFlags:$cg)),
-                  !strconcat(G2S_STRINGS<dim, mode, 1, 0>.inst_name, asm_str, ", $mc;")>,
-                  Requires<[hasPTX<80>, hasSM<90>]>;
-  def _CH : NVPTXInst<(outs),
-                  !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag,
-                       (ins B64:$ch, CTAGroupFlags:$cg)),
-                  !strconcat(G2S_STRINGS<dim, mode, 0, 1>.inst_name, asm_str, ", $ch;")>,
-                  Requires<[hasPTX<80>, hasSM<90>]>;
-  def _MC_CH : NVPTXInst<(outs),
-                     !con((ins rc:$dst, rc:$mbar, B64:$tmap), dims_dag, im2col_dag,
-                          (ins B16:$mc, B64:$ch, CTAGroupFlags:$cg)),
-                     !strconcat(G2S_STRINGS<dim, mode, 1, 1>.inst_name, asm_str, ", $mc, $ch;")>,
-                     Requires<[hasPTX<80>, hasSM<90>]>;
-}
-
-foreach dim = [1, 2, 3, 4, 5] in {
-  foreach shared32 = [true, false] in {
-    foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in {
-      defm G2S_STRINGS<dim, mode, 0, 0, shared32>.intr_name :
-        CP_ASYNC_BULK_TENSOR_G2S_INTR<dim, shared32, mode>;
-    }
-  }
-}
-
-multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = []> {
+multiclass TMA_TENSOR_G2S_INTR<int dim, string mode, list<Predicate> pred = [],
----------------
durga4github wrote:

Every use below passes `pred` explicitly, so we can drop its default value (the empty list) here.

https://github.com/llvm/llvm-project/pull/165710


More information about the llvm-commits mailing list