[llvm] [NVPTX][NFC] Move more TMA intrinsics lowering to tablegen (PR #147576)

Durgadoss R via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 9 05:05:45 PDT 2025


================
@@ -628,39 +649,44 @@ foreach dim = [1, 2, 3, 4, 5] in {
   }
 }
 
-// From Shared to Global memory (S2G)
-class S2G_STRINGS<int dim, string mode, bit ch,
-                  bit is_shared32 = 0, bit is_reduce = 0> {
-  string dir = "global.shared::cta";
-  string completion = "bulk_group";
-  string inst_name = !if(is_reduce, "cp.reduce", "cp")
-                     # ".async.bulk.tensor"
-                     # "." # dim # "d"
-                     # "." # dir
-                     # "." # mode
-                     # "." # completion
-                     # !if(ch, ".L2::cache_hint", "");
-  string intr_name = "CP_ASYNC_BULK_TENSOR_"
-                     # !if(is_reduce, "RED_", "S2G_")
-                     # dim # "D"
-                     # !if(is_shared32, "_SHARED32", "")
-                     # !if(!eq(mode, "tile"), "_TILE", "_IM2COL");
-}
-
-multiclass CP_ASYNC_BULK_TENSOR_S2G_INTR<int dim, bit shared32, string mode> {
-  defvar dims_dag = !dag(ins, !listsplat(B32, dim), !foreach(i, !range(dim), "d" # i));
-  defvar dims_str = !interleave(!foreach(i, !range(dim), "$d" # i), ", ");
+multiclass TMA_TENSOR_S2G_INTR<int dim, string mode,
+                               list<Predicate> pred = [hasPTX<80>, hasSM<90>]> {
+  defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;
+  defvar dims_str = TMA_DIMS_UTIL<dim>.base_str;
   defvar asm_str = " [$tmap, {{" # dims_str # "}}], [$src]";
-  defvar rc = !if(shared32, B32, B64);
+
+  defvar intr = !cast<Intrinsic>(
+                  "int_nvvm_cp_async_bulk_tensor_s2g_" # mode # "_" # dim # d);
+  defvar intr_dag = !con(
+                      (intr addr:$src, B64:$tmap),
+                      !setdagop(dims_dag, intr),
+                      (intr B64:$ch));
+
+  // For im2col mode, the actual asm_str is "im2col_no_offs"
+  defvar mode_asm_str = !if(!eq(mode, "im2col"),
+                            "im2col_no_offs", mode);
+  defvar prefix = "cp.async.bulk.tensor"
+                  # "." # dim # "d"
+                  # ".global.shared::cta"
+                  # "." # mode_asm_str
+                  # ".bulk_group";
 
   def "" : NVPTXInst<(outs),
-            !con((ins rc:$src, B64:$tmap), dims_dag),
-            !strconcat(S2G_STRINGS<dim, mode, 0>.inst_name, asm_str, ";"), []>,
-            Requires<[hasPTX<80>, hasSM<90>]>;
+             !con((ins ADDR:$src, B64:$tmap), dims_dag, (ins B64:$ch)),
+             prefix # asm_str # ";",
+             [!con(intr_dag, (intr 0))]>,
----------------
durga4github wrote:

Agreed. I have created a per-variant pattern with an explicit name for each case and am using those here, so there are no inline `!con`s anymore.

Resolving this.

https://github.com/llvm/llvm-project/pull/147576


More information about the llvm-commits mailing list