[llvm] [NVPTX] Add intrinsics and codegen for tensormap.replace (PR #172458)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 16 03:17:37 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-backend-nvptx
Author: Srinivasa Ravi (Wolfram70)
Changes:
This change adds NVVM intrinsics and NVPTX codegen for the
`tensormap.replace` PTX instructions.
Tests are added in `tensormap_replace.ll`, `tensormap_replace_sm_100a.ll`,
and `tensormap_replace_sm_103a.ll`, and verified with `ptxas-13.0`.
PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-tensormap-replace
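For reference, a minimal sketch of how the new intrinsics are invoked from LLVM IR, distilled from the tests in this patch (the function name is illustrative; the intrinsic names and signatures follow the definitions in the diff below):

```llvm
; Sketch: updating fields of a 1024-bit tensor-map object held in global memory.
define void @update_tensormap(ptr addrspace(1) %tmap, i64 %addr, i64 %stride) {
  ; Replace the global-address field (plain register operand).
  call void @llvm.nvvm.tensormap.replace.global.address.p1(ptr addrspace(1) %tmap, i64 %addr)
  ; Replace the global stride for ordinal 0; the ordinal is an ImmArg and must
  ; be a compile-time constant.
  call void @llvm.nvvm.tensormap.replace.global.stride.p1(ptr addrspace(1) %tmap, i32 0, i64 %stride)
  ; Replace the element type; the immediate must lie in [0, 16), and 7 selects
  ; f32 per the TensormapElemType enum added in this patch.
  call void @llvm.nvvm.tensormap.replace.elemtype.p1(ptr addrspace(1) %tmap, i32 7)
  ret void
}
```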
---
Patch is 40.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172458.diff
9 Files Affected:
- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+52)
- (modified) llvm/include/llvm/IR/NVVMIntrinsicUtils.h (+51)
- (modified) llvm/lib/IR/NVVMIntrinsicUtils.cpp (+103)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+91)
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+64)
- (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+29)
- (added) llvm/test/CodeGen/NVPTX/tensormap_replace.ll (+263)
- (added) llvm/test/CodeGen/NVPTX/tensormap_replace_sm_100a.ll (+60)
- (added) llvm/test/CodeGen/NVPTX/tensormap_replace_sm_103a.ll (+19)
``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index aab85c2a86373..a82af450b35b3 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -3312,4 +3312,56 @@ foreach sp = [0, 1] in {
}
}
+//
+// tensormap.replace intrinsics
+//
+
+let IntrProperties = [IntrArgMemOnly, IntrWriteMem, NoCapture<ArgIndex<0>>] in {
+ def int_nvvm_tensormap_replace_global_address :
+ DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i64_ty], []>;
+ def int_nvvm_tensormap_replace_rank :
+ DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i32_ty], []>;
+}
+
+let IntrProperties = [IntrArgMemOnly, ImmArg<ArgIndex<1>>, IntrWriteMem,
+ NoCapture<ArgIndex<0>>] in {
+ def int_nvvm_tensormap_replace_global_stride :
+ DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i32_ty, llvm_i64_ty], []>;
+ foreach tmap_field = ["box_dim", "global_dim", "element_stride"] in {
+ def int_nvvm_tensormap_replace_ # tmap_field :
+ DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i32_ty, llvm_i32_ty], []>;
+ }
+}
+
+def int_nvvm_tensormap_replace_elemtype :
+ DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i32_ty],
+ [IntrArgMemOnly, IntrWriteMem, ImmArg<ArgIndex<1>>, NoCapture<ArgIndex<0>>,
+ Range<ArgIndex<1>, 0, 16>,
+ ArgInfo<ArgIndex<1>, [ArgName<"elemtype">,
+ ImmArgPrinter<"printTensormapElemType">]>]>;
+def int_nvvm_tensormap_replace_interleave_layout :
+ DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i32_ty],
+ [IntrArgMemOnly, IntrWriteMem, ImmArg<ArgIndex<1>>, NoCapture<ArgIndex<0>>,
+ Range<ArgIndex<1>, 0, 3>,
+ ArgInfo<ArgIndex<1>, [ArgName<"interleave_layout">,
+ ImmArgPrinter<"printTensormapInterleaveLayout">]>]>;
+def int_nvvm_tensormap_replace_swizzle_mode :
+ DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i32_ty],
+ [IntrArgMemOnly, IntrWriteMem, ImmArg<ArgIndex<1>>, NoCapture<ArgIndex<0>>,
+ Range<ArgIndex<1>, 0, 5>,
+ ArgInfo<ArgIndex<1>, [ArgName<"swizzle_mode">,
+ ImmArgPrinter<"printTensormapSwizzleMode">]>]>;
+def int_nvvm_tensormap_replace_swizzle_atomicity :
+ DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i32_ty],
+ [IntrArgMemOnly, IntrWriteMem, ImmArg<ArgIndex<1>>, NoCapture<ArgIndex<0>>,
+ Range<ArgIndex<1>, 0, 4>,
+ ArgInfo<ArgIndex<1>, [ArgName<"swizzle_atomicity">,
+ ImmArgPrinter<"printTensormapSwizzleAtomicity">]>]>;
+def int_nvvm_tensormap_replace_fill_mode :
+ DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i32_ty],
+ [IntrArgMemOnly, IntrWriteMem, ImmArg<ArgIndex<1>>, NoCapture<ArgIndex<0>>,
+ Range<ArgIndex<1>, 0, 2>,
+ ArgInfo<ArgIndex<1>, [ArgName<"fill_mode">,
+ ImmArgPrinter<"printTensormapFillMode">]>]>;
+
} // let TargetPrefix = "nvvm"
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
index 62f2a249b1357..067290e57245a 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
@@ -59,10 +59,61 @@ enum class Tcgen05CollectorUsageOp : uint8_t {
USE = 3,
};
+enum class TensormapElemType : uint8_t {
+ U8 = 0,
+ U16 = 1,
+ U32 = 2,
+ S32 = 3,
+ U64 = 4,
+ S64 = 5,
+ F16 = 6,
+ F32 = 7,
+ F32_FTZ = 8,
+ F64 = 9,
+ BF16 = 10,
+ TF32 = 11,
+ TF32_FTZ = 12,
+ B4x16 = 13,
+ B4x16_p64 = 14,
+ B6x16_p32 = 15,
+};
+
+enum class TensormapInterleaveLayout : uint8_t {
+ NO_INTERLEAVE = 0,
+ INTERLEAVE_16B = 1,
+ INTERLEAVE_32B = 2,
+};
+
+enum class TensormapSwizzleMode : uint8_t {
+ NO_SWIZZLE = 0,
+ SWIZZLE_32B = 1,
+ SWIZZLE_64B = 2,
+ SWIZZLE_128B = 3,
+ SWIZZLE_96B = 4,
+};
+
+enum class TensormapSwizzleAtomicity : uint8_t {
+ SWIZZLE_ATOMICITY_16B = 0,
+ SWIZZLE_ATOMICITY_32B = 1,
+ SWIZZLE_ATOMICITY_32B_FLIP_8B = 2,
+ SWIZZLE_ATOMICITY_64B = 3,
+};
+
+enum class TensormapFillMode : uint8_t {
+ ZERO_FILL = 0,
+ OOB_NAN_FILL = 1,
+};
+
void printTcgen05MMAKind(raw_ostream &OS, const Constant *ImmArgVal);
void printTcgen05CollectorUsageOp(raw_ostream &OS, const Constant *ImmArgVal);
+void printTensormapElemType(raw_ostream &OS, const Constant *ImmArgVal);
+void printTensormapInterleaveLayout(raw_ostream &OS, const Constant *ImmArgVal);
+void printTensormapSwizzleMode(raw_ostream &OS, const Constant *ImmArgVal);
+void printTensormapSwizzleAtomicity(raw_ostream &OS, const Constant *ImmArgVal);
+void printTensormapFillMode(raw_ostream &OS, const Constant *ImmArgVal);
+
inline bool FPToIntegerIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
switch (IntrinsicID) {
case Intrinsic::nvvm_f2i_rm_ftz:
diff --git a/llvm/lib/IR/NVVMIntrinsicUtils.cpp b/llvm/lib/IR/NVVMIntrinsicUtils.cpp
index 4389fa38ad3af..2c939ff0ca08a 100644
--- a/llvm/lib/IR/NVVMIntrinsicUtils.cpp
+++ b/llvm/lib/IR/NVVMIntrinsicUtils.cpp
@@ -10,6 +10,7 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/ADT/StringRef.h"
#include "llvm/IR/NVVMIntrinsicUtils.h"
using namespace llvm;
@@ -59,3 +60,105 @@ void nvvm::printTcgen05CollectorUsageOp(raw_ostream &OS,
llvm_unreachable("printTcgen05CollectorUsageOp called with invalid value for "
"immediate argument");
}
+
+void nvvm::printTensormapElemType(raw_ostream &OS, const Constant *ImmArgVal) {
+ static constexpr StringRef TensormapElemTypes[] = {
+ "u8", "u16", "u32", "s32", "u64", "s64",
+ "f16", "f32", "f32.ftz", "f64", "bf16", "tf32",
+ "tf32.ftz", "b4x16", "b4x16_p64", "b6x16_p32"};
+ if (const auto *CI = dyn_cast<ConstantInt>(ImmArgVal)) {
+ uint64_t Val = CI->getZExtValue();
+ if (Val <= static_cast<uint64_t>(nvvm::TensormapElemType::B6x16_p32)) {
+ OS << TensormapElemTypes[Val];
+ return;
+ }
+ }
+ llvm_unreachable("printTensormapElemType called with invalid value for "
+ "immediate argument");
+}
+
+void nvvm::printTensormapInterleaveLayout(raw_ostream &OS,
+ const Constant *ImmArgVal) {
+ if (const auto *CI = dyn_cast<ConstantInt>(ImmArgVal)) {
+ uint64_t Val = CI->getZExtValue();
+ switch (static_cast<TensormapInterleaveLayout>(Val)) {
+ case TensormapInterleaveLayout::NO_INTERLEAVE:
+ OS << "No interleave";
+ return;
+ case TensormapInterleaveLayout::INTERLEAVE_16B:
+ OS << "16B interleave";
+ return;
+ case TensormapInterleaveLayout::INTERLEAVE_32B:
+ OS << "32B interleave";
+ return;
+ }
+ }
+ llvm_unreachable(
+ "printTensormapInterleaveLayout called with invalid value for "
+ "immediate argument");
+}
+
+void nvvm::printTensormapSwizzleMode(raw_ostream &OS,
+ const Constant *ImmArgVal) {
+ if (const auto *CI = dyn_cast<ConstantInt>(ImmArgVal)) {
+ uint64_t Val = CI->getZExtValue();
+ switch (static_cast<TensormapSwizzleMode>(Val)) {
+ case TensormapSwizzleMode::NO_SWIZZLE:
+ OS << "No swizzling";
+ return;
+ case TensormapSwizzleMode::SWIZZLE_32B:
+ OS << "32B swizzling";
+ return;
+ case TensormapSwizzleMode::SWIZZLE_64B:
+ OS << "64B swizzling";
+ return;
+ case TensormapSwizzleMode::SWIZZLE_128B:
+ OS << "128B swizzling";
+ return;
+ case TensormapSwizzleMode::SWIZZLE_96B:
+ OS << "96B swizzling";
+ return;
+ }
+ }
+ llvm_unreachable("printTensormapSwizzleMode called with invalid value for "
+ "immediate argument");
+}
+
+void nvvm::printTensormapSwizzleAtomicity(raw_ostream &OS,
+ const Constant *ImmArgVal) {
+ if (const auto *CI = dyn_cast<ConstantInt>(ImmArgVal)) {
+ uint64_t Val = CI->getZExtValue();
+ switch (static_cast<TensormapSwizzleAtomicity>(Val)) {
+ case TensormapSwizzleAtomicity::SWIZZLE_ATOMICITY_16B:
+ OS << "16B";
+ return;
+ case TensormapSwizzleAtomicity::SWIZZLE_ATOMICITY_32B:
+ OS << "32B";
+ return;
+ case TensormapSwizzleAtomicity::SWIZZLE_ATOMICITY_32B_FLIP_8B:
+ OS << "32B + 8B flip";
+ return;
+ case TensormapSwizzleAtomicity::SWIZZLE_ATOMICITY_64B:
+ OS << "64B";
+ return;
+ }
+ }
+ llvm_unreachable(
+ "printTensormapSwizzleAtomicity called with invalid value for "
+ "immediate argument");
+}
+
+void nvvm::printTensormapFillMode(raw_ostream &OS, const Constant *ImmArgVal) {
+ if (const auto *CI = dyn_cast<ConstantInt>(ImmArgVal)) {
+ uint64_t Val = CI->getZExtValue();
+ if (Val == static_cast<uint64_t>(TensormapFillMode::ZERO_FILL)) {
+ OS << "Zero fill";
+ return;
+ } else if (Val == static_cast<uint64_t>(TensormapFillMode::OOB_NAN_FILL)) {
+ OS << "OOB-NaN fill";
+ return;
+ }
+ }
+ llvm_unreachable("printTensormapFillMode called with invalid value for "
+ "immediate argument");
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 92f3865818530..a8d5da3407a67 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2746,6 +2746,64 @@ lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset = false) {
return {{BuildVector, Chain}};
}
+static SDValue lowerTensormapReplaceElemtype(SDValue Op, SelectionDAG &DAG) {
+ SDNode *N = Op.getNode();
+ SDLoc DL(N);
+ unsigned Val = N->getConstantOperandVal(3);
+
+ if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceElemtypeSupport(
+ Val)) {
+ const Function &Fn = DAG.getMachineFunction().getFunction();
+
+ unsigned AS = 0;
+ if (auto *MemN = dyn_cast<MemIntrinsicSDNode>(N)) {
+ AS = MemN->getAddressSpace();
+ }
+ Type *PtrTy = PointerType::get(*DAG.getContext(), AS);
+ Module *M = DAG.getMachineFunction().getFunction().getParent();
+
+ DAG.getContext()->diagnose(DiagnosticInfoUnsupported(
+ Fn,
+ "Intrinsic " +
+ Intrinsic::getName(N->getConstantOperandVal(1), {PtrTy}, M) +
+ " with elemtype " + Twine(Val) +
+ " is not supported on the given target.",
+ DL.getDebugLoc()));
+ return Op.getOperand(0);
+ }
+
+ return Op;
+}
+
+static SDValue lowerTensormapReplaceSwizzleMode(SDValue Op, SelectionDAG &DAG) {
+ SDNode *N = Op.getNode();
+ SDLoc DL(N);
+ unsigned Val = N->getConstantOperandVal(3);
+
+ if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceSwizzleModeSupport(
+ Val)) {
+ const Function &Fn = DAG.getMachineFunction().getFunction();
+
+ unsigned AS = 0;
+ if (auto *MemN = dyn_cast<MemIntrinsicSDNode>(N)) {
+ AS = MemN->getAddressSpace();
+ }
+ Type *PtrTy = PointerType::get(*DAG.getContext(), AS);
+ Module *M = DAG.getMachineFunction().getFunction().getParent();
+
+ DAG.getContext()->diagnose(DiagnosticInfoUnsupported(
+ Fn,
+ "Intrinsic " +
+ Intrinsic::getName(N->getConstantOperandVal(1), {PtrTy}, M) +
+ " with swizzle mode " + Twine(Val) +
+ " is not supported on the given target.",
+ DL.getDebugLoc()));
+ return Op.getOperand(0);
+ }
+
+ return Op;
+}
+
static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
SDNode *N = Op.getNode();
SDValue Intrin = N->getOperand(1);
@@ -2822,6 +2880,10 @@ static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
case Intrinsic::
nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
return LowerTcgen05MMADisableOutputLane(Op, DAG);
+ case Intrinsic::nvvm_tensormap_replace_elemtype:
+ return lowerTensormapReplaceElemtype(Op, DAG);
+ case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
+ return lowerTensormapReplaceSwizzleMode(Op, DAG);
}
return Op;
}
@@ -4526,6 +4588,35 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
return true;
}
+ case Intrinsic::nvvm_tensormap_replace_global_address:
+  case Intrinsic::nvvm_tensormap_replace_global_stride: {
+ Info.opc = ISD::INTRINSIC_VOID;
+ Info.memVT = MVT::i64;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOStore;
+ Info.align.reset();
+ return true;
+ }
+
+ case Intrinsic::nvvm_tensormap_replace_rank:
+ case Intrinsic::nvvm_tensormap_replace_box_dim:
+ case Intrinsic::nvvm_tensormap_replace_global_dim:
+ case Intrinsic::nvvm_tensormap_replace_element_stride:
+ case Intrinsic::nvvm_tensormap_replace_elemtype:
+ case Intrinsic::nvvm_tensormap_replace_interleave_layout:
+ case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
+ case Intrinsic::nvvm_tensormap_replace_swizzle_atomicity:
+ case Intrinsic::nvvm_tensormap_replace_fill_mode: {
+ Info.opc = ISD::INTRINSIC_VOID;
+ Info.memVT = MVT::i32;
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags = MachineMemOperand::MOStore;
+ Info.align.reset();
+ return true;
+ }
+
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_f:
case Intrinsic::nvvm_ldu_global_p: {
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 817006c367379..b15d1210ded32 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6233,3 +6233,67 @@ foreach sp = [0, 1] in {
}
}
+//
+// tensormap.replace Instructions
+//
+
+class TensormapReplaceInst_2<string state_space, string field_name,
+ string regclass_name, NVPTXRegClass val_RC, ValueType ValTy, Intrinsic Intrin,
+ code predicate> :
+ BasicNVPTXInst<(outs),
+ (ins ADDR:$addr, val_RC:$val),
+ "tensormap.replace.tile." # field_name # "." # state_space # ".b1024." # regclass_name,
+ [(PatFrag<(ops node:$addr, node:$val),
+ (Intrin node:$addr, node:$val), predicate>
+ addr:$addr, ValTy:$val)]>;
+
+class TensormapReplaceInst_3<string state_space, string field_name,
+ string regclass_name, NVPTXRegClass val_RC, ValueType ValTy, Intrinsic Intrin,
+ code predicate> :
+ BasicNVPTXInst<(outs),
+ (ins ADDR:$addr, B32:$ord, val_RC:$val),
+ "tensormap.replace.tile." # field_name # "." # state_space # ".b1024." # regclass_name,
+ [(PatFrag<(ops node:$addr, node:$ord, node:$val),
+ (Intrin node:$addr, node:$ord, node:$val), predicate>
+ addr:$addr, i32:$ord, ValTy:$val)]>;
+
+foreach state_space = ["GLOBAL", "SHARED_CTA"] in {
+ defvar pred = !if(!eq(state_space, "GLOBAL"), AS_match.global, AS_match.shared);
+ defvar ss_ptx = !tolower(!subst("_", "::", state_space));
+ let Predicates = [callSubtarget<"hasTensormapReplaceSupport">] in {
+ def TMAP_REPLACE_TILE_GLOBAL_ADDRESS_ # state_space :
+ TensormapReplaceInst_2<ss_ptx, "global_address", "b64", B64, i64,
+ int_nvvm_tensormap_replace_global_address, pred>;
+
+ foreach field_name = ["INTERLEAVE_LAYOUT", "FILL_MODE", "RANK"] in {
+ defvar intrin = !cast<Intrinsic>("int_nvvm_tensormap_replace_" # !tolower(field_name));
+ def TMAP_REPLACE_TILE_ # field_name # _ # state_space :
+ TensormapReplaceInst_2<ss_ptx, !tolower(field_name), "b32", B32, i32,
+ intrin, pred>;
+ } // field_name
+
+ def TMAP_REPLACE_TILE_GLOBAL_STRIDE_ # state_space :
+ TensormapReplaceInst_3<ss_ptx, "global_stride", "b64", B64, i64,
+ int_nvvm_tensormap_replace_global_stride, pred>;
+
+ foreach field_name = ["BOX_DIM", "GLOBAL_DIM", "ELEMENT_STRIDE"] in {
+ defvar intrin = !cast<Intrinsic>("int_nvvm_tensormap_replace_" # !tolower(field_name));
+ def TMAP_REPLACE_TILE_ # field_name # _ # state_space :
+ TensormapReplaceInst_3<ss_ptx, !tolower(field_name), "b32", B32, i32,
+ intrin, pred>;
+ } // field_name
+ } // hasTensormapReplaceSupport
+
+ def TMAP_REPLACE_TILE_ELEMTYPE_ # state_space :
+ TensormapReplaceInst_2<ss_ptx, "elemtype", "b32", B32, i32,
+ int_nvvm_tensormap_replace_elemtype, pred>;
+
+ def TMAP_REPLACE_SWIZZLE_ATOMICITY_ # state_space :
+ TensormapReplaceInst_2<ss_ptx, "swizzle_atomicity", "b32", B32, i32,
+ int_nvvm_tensormap_replace_swizzle_atomicity, pred>,
+ Requires<[callSubtarget<"hasTensormapReplaceSwizzleAtomicitySupport">]>;
+
+ def TMAP_REPLACE_SWIZZLE_MODE_ # state_space :
+ TensormapReplaceInst_2<ss_ptx, "swizzle_mode", "b32", B32, i32,
+ int_nvvm_tensormap_replace_swizzle_mode, pred>;
+} // state_space
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 6f6057b3689e6..ccf2be1835722 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -20,6 +20,7 @@
#include "NVPTXRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/NVVMIntrinsicUtils.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include <string>
@@ -202,6 +203,34 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
hasPTXWithAccelSMs(86, {100, 101, 120});
}
+  bool hasTensormapReplaceSupport() const {
+ return hasPTXWithFamilySMs(90, {90, 100, 110, 120}) ||
+ hasPTXWithFamilySMs(88, {90, 100, 101, 120}) ||
+ hasPTXWithAccelSMs(83, {90, 100, 101, 120});
+ }
+
+ bool hasTensormapReplaceElemtypeSupport(unsigned value) const {
+ if (value >= static_cast<unsigned>(nvvm::TensormapElemType::B4x16))
+ return hasPTXWithFamilySMs(90, {100, 110, 120}) ||
+ hasPTXWithFamilySMs(88, {100, 101, 120}) ||
+ hasPTXWithAccelSMs(87, {100, 101, 120});
+
+ return hasTensormapReplaceSupport();
+ }
+
+ bool hasTensormapReplaceSwizzleAtomicitySupport() const {
+ return hasPTXWithFamilySMs(90, {100, 110, 120}) ||
+ hasPTXWithFamilySMs(88, {100, 101, 120}) ||
+ hasPTXWithAccelSMs(87, {100, 101, 120});
+ }
+
+ bool hasTensormapReplaceSwizzleModeSupport(unsigned value) const {
+ if (value == static_cast<unsigned>(nvvm::TensormapSwizzleMode::SWIZZLE_96B))
+ return hasPTXWithAccelSMs(88, {103});
+
+ return hasTensormapReplaceSupport();
+ }
+
// Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
// terminates a basic block. Instead, it would assume that control flow
// continued to the next instruction. The next instruction could be in the
diff --git a/llvm/test/CodeGen/NVPTX/tensormap_replace.ll b/llvm/test/CodeGen/NVPTX/tensormap_replace.ll
new file mode 100644
index 0000000000000..e1be5f9adbce7
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/tensormap_replace.ll
@@ -0,0 +1,263 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90a -mattr=+ptx83 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90a -mattr=+ptx83 | %ptxas-verify -arch=sm_90a %}
+
+define void @tensormap_replace_global_address(ptr addrspace(1) %global_addr, ptr addrspace(3) %shared_addr, i64 %value) {
+; CHECK-LABEL: tensormap_replace_global_address(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [tensormap_replace_global_address_param_0];
+; CHECK-NEXT: ld.param.b64 %rd2, [tensormap_replace_global_address_param_2];
+; CHECK-NEXT: tensormap.replace.tile.global_address.global.b1024.b64 [%rd1], %rd2;
+; CHECK-NEXT: ld.param.b64 %rd3, [tensormap_replace_global_address_param_1];
+; CHECK-NEXT: tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%rd3], %rd2;
+; CHECK-NEXT: ret;
+ call void @llvm.nvvm.tensormap.replace.global.address.p1(ptr addrspace(1) %global_addr, i64 %value)
+ call void @llvm.nvvm.tensormap.replace.global.address.p3(ptr addrspace(3) %shared_addr, i64 %value)
+ ret void
+}
+
+define void @tensormap_replace_rank(ptr addrspace(1) %global_addr, ptr addrspace(3) %shared_addr, i32 %value) {
+; CHECK-LABEL: tensormap_replace_rank(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [tensormap_replace_rank_param_0];
+; CHECK-NEXT: ld.param.b32 %r1, [tensormap_replace_rank_param_2];
+; CHECK-NEXT: tensormap.replace.tile.rank.global.b1024.b32 [%rd1], %r1;
+; CHECK-NEXT: ld.param.b64 %rd2, [tensormap_replace_rank_para...
[truncated]
``````````
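A note on the subtarget gating above: `hasTensormapReplaceSupport` requires sm_90 or newer (family/accel variants with PTX 8.3+), and a few immediate values are gated further, e.g. elemtype values from `B4x16` upward require sm_100-class targets, and the 96B swizzle mode is only accepted on sm_103a with PTX 8.8. Unsupported combinations are reported through `DiagnosticInfoUnsupported` in the new lowering hooks rather than being silently accepted. A minimal sketch of IR that exercises the swizzle-mode gate (the RUN line is illustrative, mirroring the style of the added tests):

```llvm
; Requires sm_103a + ptx88 per hasTensormapReplaceSwizzleModeSupport; on older
; targets lowerTensormapReplaceSwizzleMode emits a DiagnosticInfoUnsupported
; diagnostic instead of selecting the instruction.
;   llc < %s -mtriple=nvptx64 -mcpu=sm_103a -mattr=+ptx88
define void @swizzle_96b(ptr addrspace(1) %tmap) {
  ; Immediate 4 = SWIZZLE_96B in the TensormapSwizzleMode enum.
  call void @llvm.nvvm.tensormap.replace.swizzle.mode.p1(ptr addrspace(1) %tmap, i32 4)
  ret void
}
```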
https://github.com/llvm/llvm-project/pull/172458