[llvm] [NVPTX][NFC] Move more TMA intrinsics lowering to tablegen (PR #147576)
Durgadoss R via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 9 05:05:45 PDT 2025
================
@@ -628,39 +649,44 @@ foreach dim = [1, 2, 3, 4, 5] in {
}
}
-// From Shared to Global memory (S2G)
-class S2G_STRINGS<int dim, string mode, bit ch,
- bit is_shared32 = 0, bit is_reduce = 0> {
- string dir = "global.shared::cta";
- string completion = "bulk_group";
- string inst_name = !if(is_reduce, "cp.reduce", "cp")
- # ".async.bulk.tensor"
- # "." # dim # "d"
- # "." # dir
- # "." # mode
- # "." # completion
- # !if(ch, ".L2::cache_hint", "");
- string intr_name = "CP_ASYNC_BULK_TENSOR_"
- # !if(is_reduce, "RED_", "S2G_")
- # dim # "D"
- # !if(is_shared32, "_SHARED32", "")
- # !if(!eq(mode, "tile"), "_TILE", "_IM2COL");
-}
-
-multiclass CP_ASYNC_BULK_TENSOR_S2G_INTR<int dim, bit shared32, string mode> {
- defvar dims_dag = !dag(ins, !listsplat(B32, dim), !foreach(i, !range(dim), "d" # i));
- defvar dims_str = !interleave(!foreach(i, !range(dim), "$d" # i), ", ");
+multiclass TMA_TENSOR_S2G_INTR<int dim, string mode,
+ list<Predicate> pred = [hasPTX<80>, hasSM<90>]> {
+ defvar dims_dag = TMA_DIMS_UTIL<dim>.ins_dag;
+ defvar dims_str = TMA_DIMS_UTIL<dim>.base_str;
defvar asm_str = " [$tmap, {{" # dims_str # "}}], [$src]";
- defvar rc = !if(shared32, B32, B64);
+
+ defvar intr = !cast<Intrinsic>(
+ "int_nvvm_cp_async_bulk_tensor_s2g_" # mode # "_" # dim # d);
+ defvar intr_dag = !con(
+ (intr addr:$src, B64:$tmap),
+ !setdagop(dims_dag, intr),
+ (intr B64:$ch));
+
+ // For im2col mode, the actual asm_str is "im2col_no_offs"
+ defvar mode_asm_str = !if(!eq(mode, "im2col"),
+ "im2col_no_offs", mode);
+ defvar prefix = "cp.async.bulk.tensor"
+ # "." # dim # "d"
+ # ".global.shared::cta"
+ # "." # mode_asm_str
+ # ".bulk_group";
def "" : NVPTXInst<(outs),
- !con((ins rc:$src, B64:$tmap), dims_dag),
- !strconcat(S2G_STRINGS<dim, mode, 0>.inst_name, asm_str, ";"), []>,
- Requires<[hasPTX<80>, hasSM<90>]>;
+ !con((ins ADDR:$src, B64:$tmap), dims_dag, (ins B64:$ch)),
+ prefix # asm_str # ";",
+ [!con(intr_dag, (intr 0))]>,
----------------
durga4github wrote:
Agreed. So, I created per-variant patterns with explicit names and am using them here. No inline `!con`s anymore.
Resolving this.
https://github.com/llvm/llvm-project/pull/147576
More information about the llvm-commits
mailing list