[llvm] [NVPTX] Add NVPTX intrinsics for TMA copies (PR #95289)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 12 11:43:51 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-llvm-ir
Author: Adam Paszke (apaszke)
Changes:
This is necessary to be able to pass TMA descriptors to kernels through `byval` parameters without having `NVPTXLowerArgs` insert an extra copy. While descriptors can also be passed in through global memory, passing them as kernel parameters is the recommended approach and the one CUTLASS uses.
I think the code in this PR should be ready, but it's obviously missing tests; I'd welcome pointers to where I should add those. So far I have tested the code with my own compiler, and it has worked great.
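To make the intended usage concrete, here's a minimal sketch of such a kernel in LLVM IR. This is my own illustration rather than code from this PR: the kernel, the value names, and the 128-byte `%struct.CUtensorMap` layout are assumptions, and `.p0.p3` is just the usual overload mangling for a generic descriptor pointer plus a shared-memory mbarrier pointer.

```llvm
; Illustrative sketch only (not from this PR): a kernel that receives a TMA
; descriptor as a byval parameter and hands it directly to the new 2D
; global -> shared::cluster load intrinsic, with no extra local copy.
%struct.CUtensorMap = type { [128 x i8] }

declare void @llvm.nvvm.cp.async.bulk.tensor.2d.shared_cluster.global.tile.mbarrier_complete_tx_bytes.p0.p3(
    ptr addrspace(3), ptr, i32, i32, ptr addrspace(3))

define ptx_kernel void @tma_load_2d(ptr byval(%struct.CUtensorMap) align 64 %desc,
                                    ptr addrspace(3) %dst, ptr addrspace(3) %mbar) {
  ; Operands: shared destination, descriptor, one i32 index per dimension,
  ; and the mbarrier that signals transfer completion.
  call void @llvm.nvvm.cp.async.bulk.tensor.2d.shared_cluster.global.tile.mbarrier_complete_tx_bytes.p0.p3(
      ptr addrspace(3) %dst, ptr %desc, i32 0, i32 0, ptr addrspace(3) %mbar)
  ret void
}
```

The descriptor operand here is the `byval` pointer itself; the point of the `NVPTXLowerArgs` change is that a call like this no longer forces the pass to materialize a local copy of `%desc`.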
---
Full diff: https://github.com/llvm/llvm-project/pull/95289.diff
3 Files Affected:
- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+24)
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+28)
- (modified) llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp (+78-13)
``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 0a9139e0062ba..a210a208d01c0 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1448,6 +1448,26 @@ defm int_nvvm_cp_async_ca_shared_global_8 : CP_ASYNC_SHARED_GLOBAL<"8", "ca">;
defm int_nvvm_cp_async_ca_shared_global_16 : CP_ASYNC_SHARED_GLOBAL<"16", "ca">;
defm int_nvvm_cp_async_cg_shared_global_16 : CP_ASYNC_SHARED_GLOBAL<"16", "cg">;
+// TODO(apaszke): Multicast TMA loads
+foreach dim = [1, 2, 3, 4, 5] in {
+ def int_nvvm_cp_async_bulk_tensor_ # dim # d_shared_cluster_global_tile_mbarrier_complete_tx_bytes :
+ Intrinsic<
+ [],
+ [llvm_shared_ptr_ty, llvm_anyptr_ty] # !listsplat(llvm_i32_ty, dim) # [llvm_anyptr_ty],
+ [IntrArgMemOnly, IntrNoCallback,
+ NoAlias<ArgIndex<0>>, NoAlias<ArgIndex<1>>, NoAlias<ArgIndex<!add(2, dim)>>,
+ WriteOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>],
+ "llvm.nvvm.cp.async.bulk.tensor." # dim # "d.shared_cluster.global.tile.mbarrier_complete_tx_bytes">;
+ def int_nvvm_cp_async_bulk_tensor_ # dim # d_global_shared_cta_tile_bulk_group :
+ Intrinsic<
+ [],
+ [llvm_anyptr_ty] # !listsplat(llvm_i32_ty, dim) # [llvm_shared_ptr_ty],
+ [IntrNoCallback,
+ NoAlias<ArgIndex<0>>, NoAlias<ArgIndex<!add(1, dim)>>,
+ ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<!add(1, dim)>>],
+ "llvm.nvvm.cp.async.bulk.tensor." # dim # "d.global.shared_cta.tile.bulk_group">;
+}
+
def int_nvvm_cp_async_commit_group :
ClangBuiltin<"__nvvm_cp_async_commit_group">,
Intrinsic<[],[],[]>;
@@ -1595,6 +1615,10 @@ def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty],
[llvm_anyptr_ty],
[IntrNoMem, IntrSpeculatable, IntrNoCallback],
"llvm.nvvm.ptr.gen.to.param">;
+def int_nvvm_ptr_param_to_gen: Intrinsic<[llvm_anyptr_ty],
+ [llvm_anyptr_ty],
+ [IntrNoMem, IntrSpeculatable, IntrNoCallback],
+ "llvm.nvvm.ptr.param.to.gen">;
// Move intrinsics, used in nvvm internally
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 440af085cb8e9..e2a565defb95b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -403,6 +403,33 @@ defm CP_ASYNC_CG_SHARED_GLOBAL_16 :
CP_ASYNC_SHARED_GLOBAL_I<"cg", "16", int_nvvm_cp_async_cg_shared_global_16,
int_nvvm_cp_async_cg_shared_global_16_s>;
+foreach dim = [1, 2, 3, 4, 5] in {
+ defvar idx_ptx = !interleave(!foreach(i, !range(dim), "$idx" # i), ", ");
+ defvar idx_dag = !dag(ins, !listsplat(Int32Regs, dim), !foreach(i, !range(dim), "idx" # i));
+ defvar intrinsic_g2s = !cast<Intrinsic>("int_nvvm_cp_async_bulk_tensor_" # dim # "d_shared_cluster_global_tile_mbarrier_complete_tx_bytes");
+ def CP_ASYNC_BULK_TENSOR_ # dim # D_SHARED_CLUSTER_GLOBAL_TILE_MBARRIER_COMPLETE_TX_BYTES_64 :
+ NVPTXInst<
+ (outs),
+ !con((ins Int64Regs:$dst, Int64Regs:$desc), idx_dag, (ins Int64Regs:$mbar)),
+ "cp.async.bulk.tensor." # dim # "d.shared::cluster.global.tile.mbarrier::complete_tx::bytes [$dst], [$desc, {{" # idx_ptx # "}}], [$mbar];",
+ [!con((intrinsic_g2s Int64Regs:$dst, Int64Regs:$desc),
+ !setdagop(idx_dag, intrinsic_g2s),
+ (intrinsic_g2s Int64Regs:$mbar))]
+ >,
+ Requires<[hasPTX<80>, hasSM<90>]>;
+ defvar intrinsic_s2g = !cast<Intrinsic>("int_nvvm_cp_async_bulk_tensor_" # dim # "d_global_shared_cta_tile_bulk_group");
+ def CP_ASYNC_BULK_TENSOR_ # dim # D_GLOBAL_SHARED_CTA_TILE_BULK_GROUP_64 :
+ NVPTXInst<
+ (outs),
+ !con((ins Int64Regs:$desc), idx_dag, (ins Int64Regs:$src)),
+ "cp.async.bulk.tensor." # dim # "d.global.shared::cta.tile.bulk_group [$desc, {{" # idx_ptx # "}}], [$src];",
+ [!con((intrinsic_s2g Int64Regs:$desc),
+ !setdagop(idx_dag, intrinsic_s2g),
+ (intrinsic_s2g Int64Regs:$src))]
+ >,
+ Requires<[hasPTX<80>, hasSM<90>]>;
+}
+
def CP_ASYNC_COMMIT_GROUP :
NVPTXInst<(outs), (ins), "cp.async.commit_group;", [(int_nvvm_cp_async_commit_group)]>,
Requires<[hasPTX<70>, hasSM<80>]>;
@@ -2475,6 +2502,7 @@ defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>
defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>;
defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>;
defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>;
+defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>;
defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>;
defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index cde02c25c4834..06eb2ba848762 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -94,12 +94,17 @@
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
+#include <cassert>
#include <numeric>
#include <queue>
@@ -146,6 +151,28 @@ INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
"Lower arguments (NVPTX)", false, false)
+static std::optional<int> tmaDescriptorOperandIndex(Instruction *I) {
+ if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+ switch (II->getIntrinsicID()) {
+ case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_1d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+ case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_2d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+ case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_3d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+ case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_4d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+ case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_5d_shared_cluster_global_tile_mbarrier_complete_tx_bytes:
+ return 1;
+ case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_1d_global_shared_cta_tile_bulk_group:
+ case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_2d_global_shared_cta_tile_bulk_group:
+ case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_3d_global_shared_cta_tile_bulk_group:
+ case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_4d_global_shared_cta_tile_bulk_group:
+ case llvm::Intrinsic::nvvm_cp_async_bulk_tensor_5d_global_shared_cta_tile_bulk_group:
+ return 0;
+ default:
+ return std::nullopt;
+ }
+ }
+ return std::nullopt;
+}
+
// =============================================================================
// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
// and we can't guarantee that the only accesses are loads,
@@ -166,14 +193,15 @@ INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
// Replaces the \p OldUser instruction with the same in parameter AS.
// Only Load and GEP are supported.
-static void convertToParamAS(Value *OldUser, Value *Param) {
+static void convertToParamAS(Value *OldUser, Value *OldParam, Value *NewParam) {
Instruction *I = dyn_cast<Instruction>(OldUser);
assert(I && "OldUser must be an instruction");
struct IP {
Instruction *OldInstruction;
+ Value *OldParam;
Value *NewParam;
};
- SmallVector<IP> ItemsToConvert = {{I, Param}};
+ SmallVector<IP> ItemsToConvert = {{I, OldParam, NewParam}};
SmallVector<Instruction *> InstructionsToDelete;
auto CloneInstInParamAS = [](const IP &I) -> Value * {
@@ -200,6 +228,28 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
// Just pass through the argument, the old ASC is no longer needed.
return I.NewParam;
}
+ if (auto *II = dyn_cast<IntrinsicInst>(I.OldInstruction)) {
+ // Assert that this is a TMA intrinsic.
+ assert(tmaDescriptorOperandIndex(II).has_value());
+ assert(I.OldInstruction->getOperand(*tmaDescriptorOperandIndex(II)) ==
+ I.OldParam);
+ // TMA descriptors can remain in param memory space, but need to be passed
+ // in the generic address space.
+ Type *ParamPtr = PointerType::get(II->getContext(), ADDRESS_SPACE_PARAM);
+ Type *GenericPtr =
+ PointerType::get(II->getContext(), ADDRESS_SPACE_GENERIC);
+ FunctionType *CastFnTy =
+ FunctionType::get(GenericPtr, {ParamPtr}, false);
+ Module *M = I.OldInstruction->getModule();
+ FunctionCallee CastFn =
+ M->getOrInsertFunction(getName(llvm::Intrinsic::nvvm_ptr_param_to_gen,
+ {GenericPtr, ParamPtr}, M),
+ CastFnTy);
+ Instruction *NewInGeneric =
+ CallInst::Create(CastFn, {I.NewParam}, "", II->getIterator());
+ II->replaceUsesOfWith(I.OldParam, NewInGeneric);
+ return II;
+ }
llvm_unreachable("Unsupported instruction");
};
@@ -212,7 +262,8 @@ static void convertToParamAS(Value *OldUser, Value *Param) {
// be converted and the instruction itself to be deleted. We can't delete
// the old instruction yet, because it's still in use by a load somewhere.
for (Value *V : I.OldInstruction->users())
- ItemsToConvert.push_back({cast<Instruction>(V), NewInst});
+ ItemsToConvert.push_back(
+ {cast<Instruction>(V), I.OldInstruction, NewInst});
InstructionsToDelete.push_back(I.OldInstruction);
}
@@ -300,9 +351,13 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
Worklist.push({I, Ctx.Offset + Offset});
continue;
}
+ if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) {
+ assert(tmaDescriptorOperandIndex(II).has_value());
+ continue;
+ }
llvm_unreachable("All users must be one of: load, "
- "bitcast, getelementptr.");
+ "bitcast, getelementptr, TMA intrinsic.");
}
}
@@ -321,8 +376,11 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
assert(StructType && "Missing byval type");
auto IsALoadChain = [&](Value *Start) {
- SmallVector<Value *, 16> ValuesToCheck = {Start};
- auto IsALoadChainInstr = [](Value *V) -> bool {
+ SmallVector<Use *, 16> UsesToCheck;
+ for (Use &U : Start->uses())
+ UsesToCheck.push_back(&U);
+ auto IsSupportedUse = [](Use *U) -> bool {
+ Value *V = U->getUser();
if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
return true;
// ASC to param space are OK, too -- we'll just strip them.
@@ -330,19 +388,26 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
return true;
}
+ // TMA descriptors passed to TMA intrinsics are OK, too.
+ if (auto *II = dyn_cast<IntrinsicInst>(V)) {
+ auto OI = tmaDescriptorOperandIndex(II);
+ return OI.has_value() && *OI == U->getOperandNo();
+ }
return false;
};
- while (!ValuesToCheck.empty()) {
- Value *V = ValuesToCheck.pop_back_val();
- if (!IsALoadChainInstr(V)) {
- LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *V
+ while (!UsesToCheck.empty()) {
+ Use *U = UsesToCheck.pop_back_val();
+ if (!IsSupportedUse(U)) {
+ LLVM_DEBUG(dbgs() << "Need a copy of " << *Arg << " because of " << *U->getUser()
<< "\n");
(void)Arg;
return false;
}
- if (!isa<LoadInst>(V))
- llvm::append_range(ValuesToCheck, V->users());
+ if (!isa<LoadInst>(U->getUser())) {
+ for (Use &NextU : U->getUser()->uses())
+ UsesToCheck.push_back(&NextU);
+ }
}
return true;
};
@@ -355,7 +420,7 @@ void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
FirstInst);
for (Value *V : UsersToUpdate)
- convertToParamAS(V, ArgInParamAS);
+ convertToParamAS(V, Arg, ArgInParamAS);
LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
const auto *TLI =
``````````
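For the record, here's my mental model of what `NVPTXLowerArgs` now produces for the kernel sketched in the description above. It's hand-written to illustrate the intent, not actual pass output, so treat the details as assumptions:

```llvm
; Hand-written sketch of the expected result (not actual pass output).
%struct.CUtensorMap = type { [128 x i8] }

declare ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101))
declare void @llvm.nvvm.cp.async.bulk.tensor.2d.shared_cluster.global.tile.mbarrier_complete_tx_bytes.p0.p3(
    ptr addrspace(3), ptr, i32, i32, ptr addrspace(3))

define ptx_kernel void @tma_load_2d(ptr byval(%struct.CUtensorMap) align 64 %desc,
                                    ptr addrspace(3) %dst, ptr addrspace(3) %mbar) {
  ; The byval pointer is rebased into the param address space (101)...
  %desc.param = addrspacecast ptr %desc to ptr addrspace(101)
  ; ...and converted back to a generic pointer for the intrinsic call, so no
  ; 128-byte shadow copy of the descriptor is ever materialized.
  %desc.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %desc.param)
  call void @llvm.nvvm.cp.async.bulk.tensor.2d.shared_cluster.global.tile.mbarrier_complete_tx_bytes.p0.p3(
      ptr addrspace(3) %dst, ptr %desc.gen, i32 0, i32 0, ptr addrspace(3) %mbar)
  ret void
}
```

From there, the new `cvta_param` pattern should select the cast to a PTX `cvta.param` instruction, and the call should become a `cp.async.bulk.tensor.2d.shared::cluster.global.tile.mbarrier::complete_tx::bytes` instruction, with the descriptor still read directly out of the kernel parameter space.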
https://github.com/llvm/llvm-project/pull/95289