[llvm] [NVPTX] Add NVPTX intrinsics for TMA copies (PR #95289)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 12 11:43:51 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-ir

Author: Adam Paszke (apaszke)

<details>
<summary>Changes</summary>

This change is needed so that TMA descriptors can be passed through `byval` kernel parameters without `NVPTXLowerArgs` inserting an extra copy. Although descriptors can also be passed through global memory, passing them as kernel parameters is the recommended approach, and it is the one CUTLASS uses.

I believe the code in this PR is ready, but it still lacks tests; I'd welcome pointers on where to add them. So far I have tested it with my own compiler, and it has worked well.

---
Full diff: https://github.com/llvm/llvm-project/pull/95289.diff


3 Files Affected:

- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+24) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+28) 
- (modified) llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp (+78-13) 


``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 0a9139e0062ba..a210a208d01c0 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1448,6 +1448,26 @@ defm int_nvvm_cp_async_ca_shared_global_8 : CP_ASYNC_SHARED_GLOBAL<"8", "ca">;
 defm int_nvvm_cp_async_ca_shared_global_16 : CP_ASYNC_SHARED_GLOBAL<"16", "ca">;
 defm int_nvvm_cp_async_cg_shared_global_16 : CP_ASYNC_SHARED_GLOBAL<"16", "cg">;
 
+// TODO(apaszke): Multicast TMA loads
+foreach dim = [1, 2, 3, 4, 5] in {
+  def int_nvvm_cp_async_bulk_tensor_ # dim # d_shared_cluster_global_tile_mbarrier_complete_tx_bytes :
+    Intrinsic<
+      [],
+      [llvm_shared_ptr_ty, llvm_anyptr_ty] # !listsplat(llvm_i32_ty, dim) # [llvm_anyptr_ty],
+      [IntrArgMemOnly, IntrNoCallback,
+       NoAlias<ArgIndex<0>>, NoAlias<ArgIndex<1>>, NoAlias<ArgIndex<!add(2,  dim)>>,
+       WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>],
+      "llvm.nvvm.cp.async.bulk.tensor." # dim # "d.shared_cluster.global.tile.mbarrier_complete_tx_bytes">;
+  def int_nvvm_cp_async_bulk_tensor_ # dim # d_global_shared_cta_tile_bulk_group :
+    Intrinsic<
+      [],
+      [llvm_anyptr_ty] # !listsplat(llvm_i32_ty, dim) # [llvm_shared_ptr_ty],
+      [IntrNoCallback,
+       NoAlias<ArgIndex<0>>, NoAlias<ArgIndex<!add(1, dim)>>,
+       ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<!add(1, dim)>>],
+      "llvm.nvvm.cp.async.bulk.tensor." # dim # "d.global.shared_cta.tile.bulk_group">;
+}
+
 def int_nvvm_cp_async_commit_group :
     ClangBuiltin<"__nvvm_cp_async_commit_group">,
     Intrinsic<[],[],[]>;
@@ -1595,6 +1615,10 @@ def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty],
                                      [llvm_anyptr_ty],
                                    [IntrNoMem, IntrSpeculatable, IntrNoCallback],
                                    "llvm.nvvm.ptr.gen.to.param">;
+def int_nvvm_ptr_param_to_gen: Intrinsic<[llvm_anyptr_ty],
+                                     [llvm_anyptr_ty],
+                                   [IntrNoMem, IntrSpeculatable, IntrNoCallback],
+                                   "llvm.nvvm.ptr.param.to.gen">;
 
 // Move intrinsics, used in nvvm internally
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 440af085cb8e9..e2a565defb95b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -403,6 +403,33 @@ defm CP_ASYNC_CG_SHARED_GLOBAL_16 :
   CP_ASYNC_SHARED_GLOBAL_I<"cg", "16", int_nvvm_cp_async_cg_shared_global_16,
                                        int_nvvm_cp_async_cg_shared_global_16_s>;
 
+foreach dim = [1, 2, 3, 4, 5] in {
+  defvar idx_ptx = !interleave(!foreach(i, !range(dim), "$idx" # i), ", ");
+  defvar idx_dag = !dag(ins, !listsplat(Int32Regs, dim), !foreach(i, !range(dim), "idx" # i));
+  defvar intrinsic_g2s = !cast<Intrinsic>("int_nvvm_cp_async_bulk_tensor_" # dim # "d_shared_cluster_global_tile_mbarrier_complete_tx_bytes");
+  def CP_ASYNC_BULK_TENSOR_ # dim # D_SHARED_CLUSTER_GLOBAL_TILE_MBARRIER_COMPLETE_TX_BYTES_64 :
+    NVPTXInst<
+      (outs),
+      !con((ins Int64Regs:$dst, Int64Regs:$desc), idx_dag, (ins Int64Regs:$mbar)),
+      "cp.async.bulk.tensor." # dim # "d.shared::cluster.global.tile.mbarrier::complete_tx::bytes [$dst], [$desc, {{" # idx_ptx # "}}], [$mbar];",
+      [!con((intrinsic_g2s Int64Regs:$dst, Int64Regs:$desc),
+            !setdagop(idx_dag, intrinsic_g2s),
+            (intrinsic_g2s Int64Regs:$mbar))]
+    >,
+    Requires<[hasPTX<80>, hasSM<90>]>;
+  defvar intrinsic_s2g = !cast<Intrinsic>("int_nvvm_cp_async_bulk_tensor_" # dim # "d_global_shared_cta_tile_bulk_group");
+  def CP_ASYNC_BULK_TENSOR_ # dim # D_GLOBAL_SHARED_CTA_TILE_BULK_GROUP_64 :
+    NVPTXInst<
+      (outs),
+      !con((ins Int64Regs:$desc), idx_dag, (ins Int64Regs:$dst)),
+      "cp.async.bulk.tensor." # dim # "d.global.shared::cta.tile.bulk_group [$desc, {{" # idx_ptx # "}}], [$dst];",
+      [!con((intrinsic_s2g Int64Regs:$desc),
+            !setdagop(idx_dag, intrinsic_s2g),
+            (intrinsic_s2g Int64Regs:$dst))]
+    >,
+    Requires<[hasPTX<80>, hasSM<90>]>;
+}
+
 def CP_ASYNC_COMMIT_GROUP :
   NVPTXInst<(outs), (ins), "cp.async.commit_group;", [(int_nvvm_cp_async_commit_group)]>,
   Requires<[hasPTX<70>, hasSM<80>]>;
@@ -2475,6 +2502,7 @@ defm cvta_local  : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>
 defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>;
 defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>;
 defm cvta_const  : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>;
+defm cvta_param  : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>;
 
 defm cvta_to_local  : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>;
 defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index cde02c25c4834..06eb2ba848762 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -94,12 +94,17 @@
 #include "NVPTXUtilities.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include <cassert>
 #include <numeric>
 #include <queue>
 
@@ -146,6 +151,28 @@ INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
 INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
                     "Lower arguments (NVPTX)", false, false)
 
+static std::optional<int> tmaDescriptorOperandIndex(Instruction *I) {
+  if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+    switch (II->getIntrinsicID()) {
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_1d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_2d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_3d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_4d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_5d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+      return 1;
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_1d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_2d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_3d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_4d_global_shared_cta_tile_bulk_group:
+    case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_5d_global_shared_cta_tile_bulk_group:
+      return 0;
+    default:
+      return std::nullopt;
+    }
+  }
+  return std::nullopt;
+}
+
 // =============================================================================
 // If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
 // and we can't guarantee that the only accesses are loads,
@@ -166,14 +193,15 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
 
 // Replaces the \p OldUser instruction with the same in parameter AS.
 // Only Load and GEP are supported.
-static void convertToParamAS(Value *OldUser, Value *Param) {
+static void convertToParamAS(Value *OldUser, Value *OldParam, Value *NewParam) {
   Instruction *I = dyn_cast<Instruction>(OldUser);
   assert(I && "OldUser must be an instruction");
   struct IP {
     Instruction *OldInstruction;
+    Value *OldParam;
     Value *NewParam;
   };
-  SmallVector<IP> ItemsToConvert = {{I, Param}};
+  SmallVector<IP> ItemsToConvert = {{I, OldParam, NewParam}};
   SmallVector<Instruction *> InstructionsToDelete;
 
   auto CloneInstInParamAS = [](const IP &I) -> Value * {
@@ -200,6 +228,28 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
       // Just pass through the argument, the old ASC is no longer needed.
       return I.NewParam;
     }
+    if (auto *II = dyn_cast<IntrinsicInst>(I.OldInstruction)) {
+      // Assert that this is a TMA intrinsic.
+      assert(tmaDescriptorOperandIndex(II).has_value());
+      assert(I.OldInstruction->getOperand(*tmaDescriptorOperandIndex(II)) ==
+             I.OldParam);
+      // TMA descriptors can remain in param memory space, but need to be passed
+      // in the generic address space.
+      Type *ParamPtr = PointerType::get(II->getContext(), ADDRESS_SPACE_PARAM);
+      Type *GenericPtr =
+          PointerType::get(II->getContext(), ADDRESS_SPACE_GENERIC);
+      FunctionType *cast_func_ty =
+          FunctionType::get(GenericPtr, {ParamPtr}, false);
+      Module *M = I.OldInstruction->getModule();
+      FunctionCallee func =
+          M->getOrInsertFunction(getName(llvm::Intrinsic::nvvm_ptr_param_to_gen,
+                                         {GenericPtr, ParamPtr}, M),
+                                 cast_func_ty);
+      Instruction *NewInGeneric =
+          CallInst::Create(func, {I.NewParam}, "", II->getIterator());
+      II->replaceUsesOfWith(I.OldParam, NewInGeneric);
+      return II;
+    }
     llvm_unreachable("Unsupported instruction");
   };
 
@@ -212,7 +262,8 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
       // be converted and the instruction itself to be deleted. We can't delete
       // the old instruction yet, because it's still in use by a load somewhere.
       for (Value *V : I.OldInstruction->users())
-        ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
+        ItemsToConvert.push_back(
+            {cast<Instruction>(V), I.OldInstruction, NewInst});
 
       InstructionsToDelete.push_back(I.OldInstruction);
     }
@@ -300,9 +351,13 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
         Worklist.push({I, Ctx.Offset + Offset});
         continue;
       }
+      if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) {
+        assert(tmaDescriptorOperandIndex(II).has_value());
+        continue;
+      }
 
       llvm_unreachable("All users must be one of: load, "
-                       "bitcast, getelementptr.");
+                       "bitcast, getelementptr, TMA intrinsic.");
     }
   }
 
@@ -321,8 +376,11 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
   assert(StructType && "Missing byval type");
 
   auto IsALoadChain = [&](Value *Start) {
-    SmallVector<Value *, 16> ValuesToCheck = {Start};
-    auto IsALoadChainInstr = [](Value *V) -> bool {
+    SmallVector<Use*, 16> UsesToCheck;
+    for (Use& u : Start->uses())
+      UsesToCheck.push_back(&u);
+    auto IsSupportedUse = [](Use *U) -> bool {
+      Value *V = U->get();
       if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
         return true;
       // ASC to param space are OK, too -- we'll just strip them.
@@ -330,19 +388,26 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
         if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
           return true;
       }
+      // TMA descriptors passed to TMA intrinsics are OK, too.
+      if (auto *II = dyn_cast<IntrinsicInst>(V)) {
+        auto OI = tmaDescriptorOperandIndex(II);
+        return OI.has_value() && *OI == U->getOperandNo();
+      }
       return false;
     };
 
-    while (!ValuesToCheck.empty()) {
-      Value *V = ValuesToCheck.pop_back_val();
-      if (!IsALoadChainInstr(V)) {
-        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
+    while (!UsesToCheck.empty()) {
+      Use* U = UsesToCheck.pop_back_val();
+      if (!IsSupportedUse(U)) {
+        LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << U
                           << "\n");
         (void)Arg;
         return false;
       }
-      if (!isa<LoadInst>(V))
-        llvm::append_range(ValuesToCheck, V->users());
+      if (!isa<LoadInst>(U)) {
+        for (Use& u : U->getUser()->uses())
+          UsesToCheck.push_back(&u);
+      }
     }
     return true;
   };
@@ -355,7 +420,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
         Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
         FirstInst);
     for (Value *V : UsersToUpdate)
-      convertToParamAS(V, ArgInParamAS);
+      convertToParamAS(V, Arg, ArgInParamAS);
     LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
 
     const auto *TLI =

``````````

</details>


https://github.com/llvm/llvm-project/pull/95289


More information about the llvm-commits mailing list