[llvm] [NVPTX] Add intrinsics and codegen for tensormap.replace (PR #172458)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 16 03:17:37 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Srinivasa Ravi (Wolfram70)

<details>
<summary>Changes</summary>

This change adds NVVM intrinsics and NVPTX codegen for the
`tensormap.replace` PTX instructions.
Tests are added in `tensormap_replace.ll`, `tensormap_replace_sm_100a.ll`,
and `tensormap_replace_sm_103a.ll`, and the generated PTX is verified with `ptxas` 13.0.

PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-tensormap-replace

---

Patch is 40.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/172458.diff


9 Files Affected:

- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+52) 
- (modified) llvm/include/llvm/IR/NVVMIntrinsicUtils.h (+51) 
- (modified) llvm/lib/IR/NVVMIntrinsicUtils.cpp (+103) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+91) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+64) 
- (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+29) 
- (added) llvm/test/CodeGen/NVPTX/tensormap_replace.ll (+263) 
- (added) llvm/test/CodeGen/NVPTX/tensormap_replace_sm_100a.ll (+60) 
- (added) llvm/test/CodeGen/NVPTX/tensormap_replace_sm_103a.ll (+19) 


``````````diff
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index aab85c2a86373..a82af450b35b3 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -3312,4 +3312,56 @@ foreach sp = [0, 1] in {
   }
 }
 
+//
+// tensormap.replace intrinsics
+//
+
+// Intrinsics that overwrite a single, non-indexed field of the tensormap
+// pointed to by operand 0. The pointer operand is overloaded (llvm_anyptr_ty),
+// so each intrinsic is instantiated per address space (e.g. .p1 and .p3 in
+// the tests).
+let IntrProperties = [IntrArgMemOnly, IntrWriteMem, NoCapture<ArgIndex<0>>] in {
+  // New value for the global_address field is an i64.
+  def int_nvvm_tensormap_replace_global_address :
+    DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i64_ty], []>;
+  // New value for the rank field is an i32.
+  def int_nvvm_tensormap_replace_rank :
+    DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i32_ty], []>;
+}
+
+// Intrinsics that overwrite one per-dimension entry of a tensormap field.
+// Operand 1 is the dimension ordinal and must be an immediate (ImmArg); it
+// becomes the `$ord` operand of the NVPTX instruction. Operand 2 is the new
+// value: i64 for global_stride, i32 for the other fields.
+let IntrProperties = [IntrArgMemOnly, ImmArg<ArgIndex<1>>, IntrWriteMem,
+                      NoCapture<ArgIndex<0>>] in {
+  def int_nvvm_tensormap_replace_global_stride :
+    DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i32_ty, llvm_i64_ty], []>;
+  foreach tmap_field = ["box_dim", "global_dim", "element_stride"] in {
+    def int_nvvm_tensormap_replace_ # tmap_field :
+      DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i32_ty, llvm_i32_ty], []>;
+  }
+}
+
+// Intrinsics for tensormap fields whose new value is an enumerated immediate.
+// Operand 1 must be a constant in [0, num_values) (declared via the Range
+// property). It is pretty-printed through the named ImmArgPrinter from
+// NVVMIntrinsicUtils. Factoring the property list into one class keeps the
+// five intrinsics from drifting apart.
+class NVVM_TENSORMAP_REPLACE_IMM_FIELD<string arg_name, int num_values,
+                                       string printer>
+  : DefaultAttrsIntrinsic<[], [llvm_anyptr_ty, llvm_i32_ty],
+      [IntrArgMemOnly, IntrWriteMem, ImmArg<ArgIndex<1>>,
+       NoCapture<ArgIndex<0>>, Range<ArgIndex<1>, 0, num_values>,
+       ArgInfo<ArgIndex<1>, [ArgName<arg_name>, ImmArgPrinter<printer>]>]>;
+
+// Each num_values equals the enumerator count of the matching
+// nvvm::Tensormap* enum in NVVMIntrinsicUtils.h.
+def int_nvvm_tensormap_replace_elemtype :
+  NVVM_TENSORMAP_REPLACE_IMM_FIELD<"elemtype", 16, "printTensormapElemType">;
+def int_nvvm_tensormap_replace_interleave_layout :
+  NVVM_TENSORMAP_REPLACE_IMM_FIELD<"interleave_layout", 3,
+                                   "printTensormapInterleaveLayout">;
+def int_nvvm_tensormap_replace_swizzle_mode :
+  NVVM_TENSORMAP_REPLACE_IMM_FIELD<"swizzle_mode", 5,
+                                   "printTensormapSwizzleMode">;
+def int_nvvm_tensormap_replace_swizzle_atomicity :
+  NVVM_TENSORMAP_REPLACE_IMM_FIELD<"swizzle_atomicity", 4,
+                                   "printTensormapSwizzleAtomicity">;
+def int_nvvm_tensormap_replace_fill_mode :
+  NVVM_TENSORMAP_REPLACE_IMM_FIELD<"fill_mode", 2, "printTensormapFillMode">;
+
 } // let TargetPrefix = "nvvm"
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
index 62f2a249b1357..067290e57245a 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
@@ -59,10 +59,61 @@ enum class Tcgen05CollectorUsageOp : uint8_t {
   USE = 3,
 };
 
+/// Immediate operand of llvm.nvvm.tensormap.replace.elemtype. The intrinsic
+/// restricts it to [0, 16). printTensormapElemType prints the PTX-style
+/// spelling (e.g. F32_FTZ -> "f32.ftz").
+/// NOTE(review): the numbering presumably mirrors the PTX .elemtype encoding;
+/// confirm against the PTX ISA before reordering.
+enum class TensormapElemType : uint8_t {
+  U8 = 0,
+  U16 = 1,
+  U32 = 2,
+  S32 = 3,
+  U64 = 4,
+  S64 = 5,
+  F16 = 6,
+  F32 = 7,
+  F32_FTZ = 8,
+  F64 = 9,
+  BF16 = 10,
+  TF32 = 11,
+  TF32_FTZ = 12,
+  // B4x16 and above are gated separately by
+  // NVPTXSubtarget::hasTensormapReplaceElemtypeSupport.
+  B4x16 = 13,
+  B4x16_p64 = 14,
+  B6x16_p32 = 15,
+};
+
+/// Immediate operand of llvm.nvvm.tensormap.replace.interleave_layout
+/// (restricted to [0, 3) on the intrinsic).
+enum class TensormapInterleaveLayout : uint8_t {
+  NO_INTERLEAVE = 0,
+  INTERLEAVE_16B = 1,
+  INTERLEAVE_32B = 2,
+};
+
+/// Immediate operand of llvm.nvvm.tensormap.replace.swizzle_mode (restricted
+/// to [0, 5) on the intrinsic). SWIZZLE_96B is additionally gated by
+/// NVPTXSubtarget::hasTensormapReplaceSwizzleModeSupport.
+enum class TensormapSwizzleMode : uint8_t {
+  NO_SWIZZLE = 0,
+  SWIZZLE_32B = 1,
+  SWIZZLE_64B = 2,
+  SWIZZLE_128B = 3,
+  SWIZZLE_96B = 4,
+};
+
+/// Immediate operand of llvm.nvvm.tensormap.replace.swizzle_atomicity
+/// (restricted to [0, 4) on the intrinsic).
+enum class TensormapSwizzleAtomicity : uint8_t {
+  SWIZZLE_ATOMICITY_16B = 0,
+  SWIZZLE_ATOMICITY_32B = 1,
+  SWIZZLE_ATOMICITY_32B_FLIP_8B = 2,
+  SWIZZLE_ATOMICITY_64B = 3,
+};
+
+/// Immediate operand of llvm.nvvm.tensormap.replace.fill_mode (restricted to
+/// [0, 2) on the intrinsic).
+enum class TensormapFillMode : uint8_t {
+  ZERO_FILL = 0,
+  OOB_NAN_FILL = 1,
+};
+
 void printTcgen05MMAKind(raw_ostream &OS, const Constant *ImmArgVal);
 
 void printTcgen05CollectorUsageOp(raw_ostream &OS, const Constant *ImmArgVal);
 
+/// Printers for the immediate operands of the llvm.nvvm.tensormap.replace.*
+/// intrinsics. They are referenced by ImmArgPrinter<> in IntrinsicsNVVM.td
+/// (presumably invoked when printing textual IR). Each expects a ConstantInt
+/// holding a valid enumerator of the matching nvvm::Tensormap* enum and
+/// reaches llvm_unreachable otherwise.
+void printTensormapElemType(raw_ostream &OS, const Constant *ImmArgVal);
+void printTensormapInterleaveLayout(raw_ostream &OS, const Constant *ImmArgVal);
+void printTensormapSwizzleMode(raw_ostream &OS, const Constant *ImmArgVal);
+void printTensormapSwizzleAtomicity(raw_ostream &OS, const Constant *ImmArgVal);
+void printTensormapFillMode(raw_ostream &OS, const Constant *ImmArgVal);
+
 inline bool FPToIntegerIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
   switch (IntrinsicID) {
   case Intrinsic::nvvm_f2i_rm_ftz:
diff --git a/llvm/lib/IR/NVVMIntrinsicUtils.cpp b/llvm/lib/IR/NVVMIntrinsicUtils.cpp
index 4389fa38ad3af..2c939ff0ca08a 100644
--- a/llvm/lib/IR/NVVMIntrinsicUtils.cpp
+++ b/llvm/lib/IR/NVVMIntrinsicUtils.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/StringRef.h"
 #include "llvm/IR/NVVMIntrinsicUtils.h"
 
 using namespace llvm;
@@ -59,3 +60,105 @@ void nvvm::printTcgen05CollectorUsageOp(raw_ostream &OS,
   llvm_unreachable("printTcgen05CollectorUsageOp called with invalid value for "
                    "immediate argument");
 }
+
+/// Prints the PTX spelling of a TensormapElemType immediate (e.g. "f32.ftz").
+/// \p ImmArgVal must be a ConstantInt holding a valid enumerator.
+void nvvm::printTensormapElemType(raw_ostream &OS, const Constant *ImmArgVal) {
+  // Indexed by the numeric value of nvvm::TensormapElemType.
+  static constexpr StringRef TensormapElemTypes[] = {
+      "u8",       "u16",   "u32",       "s32",      "u64",  "s64",
+      "f16",      "f32",   "f32.ftz",   "f64",      "bf16", "tf32",
+      "tf32.ftz", "b4x16", "b4x16_p64", "b6x16_p32"};
+  // Catch table/enum drift at compile time instead of printing the wrong name
+  // or indexing past the end of the table.
+  static_assert(std::size(TensormapElemTypes) ==
+                    static_cast<size_t>(nvvm::TensormapElemType::B6x16_p32) + 1,
+                "TensormapElemTypes needs one entry per TensormapElemType");
+  if (const auto *CI = dyn_cast<ConstantInt>(ImmArgVal)) {
+    uint64_t Val = CI->getZExtValue();
+    if (Val <= static_cast<uint64_t>(nvvm::TensormapElemType::B6x16_p32)) {
+      OS << TensormapElemTypes[Val];
+      return;
+    }
+  }
+  llvm_unreachable("printTensormapElemType called with invalid value for "
+                   "immediate argument");
+}
+
+/// Prints a human-readable description of a TensormapInterleaveLayout
+/// immediate. Anything that is not a ConstantInt naming an enumerator is a
+/// caller bug.
+void nvvm::printTensormapInterleaveLayout(raw_ostream &OS,
+                                          const Constant *ImmArgVal) {
+  // Stays empty unless the immediate maps onto one of the enumerators.
+  StringRef Desc;
+  if (const auto *Imm = dyn_cast<ConstantInt>(ImmArgVal)) {
+    switch (static_cast<TensormapInterleaveLayout>(Imm->getZExtValue())) {
+    case TensormapInterleaveLayout::NO_INTERLEAVE:
+      Desc = "No interleave";
+      break;
+    case TensormapInterleaveLayout::INTERLEAVE_16B:
+      Desc = "16B interleave";
+      break;
+    case TensormapInterleaveLayout::INTERLEAVE_32B:
+      Desc = "32B interleave";
+      break;
+    }
+  }
+  if (Desc.empty())
+    llvm_unreachable(
+        "printTensormapInterleaveLayout called with invalid value for "
+        "immediate argument");
+  OS << Desc;
+}
+
+/// Prints a human-readable description of a TensormapSwizzleMode immediate.
+/// Anything that is not a ConstantInt naming an enumerator is a caller bug.
+void nvvm::printTensormapSwizzleMode(raw_ostream &OS,
+                                     const Constant *ImmArgVal) {
+  // Stays empty unless the immediate maps onto one of the enumerators.
+  StringRef Desc;
+  if (const auto *Imm = dyn_cast<ConstantInt>(ImmArgVal)) {
+    switch (static_cast<TensormapSwizzleMode>(Imm->getZExtValue())) {
+    case TensormapSwizzleMode::NO_SWIZZLE:
+      Desc = "No swizzling";
+      break;
+    case TensormapSwizzleMode::SWIZZLE_32B:
+      Desc = "32B swizzling";
+      break;
+    case TensormapSwizzleMode::SWIZZLE_64B:
+      Desc = "64B swizzling";
+      break;
+    case TensormapSwizzleMode::SWIZZLE_128B:
+      Desc = "128B swizzling";
+      break;
+    case TensormapSwizzleMode::SWIZZLE_96B:
+      Desc = "96B swizzling";
+      break;
+    }
+  }
+  if (Desc.empty())
+    llvm_unreachable("printTensormapSwizzleMode called with invalid value for "
+                     "immediate argument");
+  OS << Desc;
+}
+
+/// Prints a human-readable description of a TensormapSwizzleAtomicity
+/// immediate. Anything that is not a ConstantInt naming an enumerator is a
+/// caller bug.
+void nvvm::printTensormapSwizzleAtomicity(raw_ostream &OS,
+                                          const Constant *ImmArgVal) {
+  // Stays empty unless the immediate maps onto one of the enumerators.
+  StringRef Desc;
+  if (const auto *Imm = dyn_cast<ConstantInt>(ImmArgVal)) {
+    switch (static_cast<TensormapSwizzleAtomicity>(Imm->getZExtValue())) {
+    case TensormapSwizzleAtomicity::SWIZZLE_ATOMICITY_16B:
+      Desc = "16B";
+      break;
+    case TensormapSwizzleAtomicity::SWIZZLE_ATOMICITY_32B:
+      Desc = "32B";
+      break;
+    case TensormapSwizzleAtomicity::SWIZZLE_ATOMICITY_32B_FLIP_8B:
+      Desc = "32B + 8B flip";
+      break;
+    case TensormapSwizzleAtomicity::SWIZZLE_ATOMICITY_64B:
+      Desc = "64B";
+      break;
+    }
+  }
+  if (Desc.empty())
+    llvm_unreachable(
+        "printTensormapSwizzleAtomicity called with invalid value for "
+        "immediate argument");
+  OS << Desc;
+}
+
+/// Prints a human-readable description of a TensormapFillMode immediate.
+/// \p ImmArgVal must be a ConstantInt holding a valid enumerator.
+void nvvm::printTensormapFillMode(raw_ostream &OS, const Constant *ImmArgVal) {
+  if (const auto *CI = dyn_cast<ConstantInt>(ImmArgVal)) {
+    // Compare the full 64-bit value so that an out-of-range constant cannot
+    // alias an enumerator after truncation to the uint8_t-backed enum.
+    uint64_t Val = CI->getZExtValue();
+    if (Val == static_cast<uint64_t>(TensormapFillMode::ZERO_FILL)) {
+      OS << "Zero fill";
+      return;
+    }
+    if (Val == static_cast<uint64_t>(TensormapFillMode::OOB_NAN_FILL)) {
+      OS << "OOB-NaN fill";
+      return;
+    }
+  }
+  llvm_unreachable("printTensormapFillMode called with invalid value for "
+                   "immediate argument");
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 92f3865818530..a8d5da3407a67 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2746,6 +2746,64 @@ lowerTcgen05Ld(SDNode *N, SelectionDAG &DAG, bool HasOffset = false) {
   return {{BuildVector, Chain}};
 }
 
+/// Reports that a tensormap.replace intrinsic cannot be lowered because its
+/// immediate operand \p Val (described as \p FieldDesc in the message) is not
+/// supported by the current subtarget. Returns the incoming chain so the
+/// intrinsic node is dropped from the DAG.
+static SDValue diagnoseUnsupportedTensormapReplace(SDValue Op,
+                                                   SelectionDAG &DAG,
+                                                   const char *FieldDesc,
+                                                   unsigned Val) {
+  SDNode *N = Op.getNode();
+  SDLoc DL(N);
+  Function &Fn = DAG.getMachineFunction().getFunction();
+
+  // The intrinsic is overloaded on its pointer operand. Recover the address
+  // space from the memory node so the mangled name in the message matches
+  // the call in the IR.
+  unsigned AS = 0;
+  if (const auto *MemN = dyn_cast<MemIntrinsicSDNode>(N))
+    AS = MemN->getAddressSpace();
+  Type *PtrTy = PointerType::get(*DAG.getContext(), AS);
+
+  // Operand 1 of an INTRINSIC_VOID node is the intrinsic ID.
+  DAG.getContext()->diagnose(DiagnosticInfoUnsupported(
+      Fn,
+      "Intrinsic " +
+          Intrinsic::getName(N->getConstantOperandVal(1), {PtrTy},
+                             Fn.getParent()) +
+          " with " + FieldDesc + " " + Twine(Val) +
+          " is not supported on the given target.",
+      DL.getDebugLoc()));
+  return Op.getOperand(0);
+}
+
+/// Rejects elemtype values the subtarget cannot encode; otherwise leaves the
+/// node for pattern-based selection.
+static SDValue lowerTensormapReplaceElemtype(SDValue Op, SelectionDAG &DAG) {
+  // Operand 3 is the immediate new value (after chain, ID and pointer).
+  unsigned Val = Op.getNode()->getConstantOperandVal(3);
+  if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceElemtypeSupport(
+          Val))
+    return diagnoseUnsupportedTensormapReplace(Op, DAG, "elemtype", Val);
+  return Op;
+}
+
+/// Rejects swizzle modes the subtarget cannot encode; otherwise leaves the
+/// node for pattern-based selection.
+static SDValue lowerTensormapReplaceSwizzleMode(SDValue Op, SelectionDAG &DAG) {
+  // Operand 3 is the immediate new value (after chain, ID and pointer).
+  unsigned Val = Op.getNode()->getConstantOperandVal(3);
+  if (!DAG.getSubtarget<NVPTXSubtarget>().hasTensormapReplaceSwizzleModeSupport(
+          Val))
+    return diagnoseUnsupportedTensormapReplace(Op, DAG, "swizzle mode", Val);
+  return Op;
+}
+
 static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
   SDNode *N = Op.getNode();
   SDValue Intrin = N->getOperand(1);
@@ -2822,6 +2880,10 @@ static SDValue lowerIntrinsicVoid(SDValue Op, SelectionDAG &DAG) {
   case Intrinsic::
       nvvm_tcgen05_mma_sp_tensor_scale_d_disable_output_lane_cg2_ashift:
     return LowerTcgen05MMADisableOutputLane(Op, DAG);
+  case Intrinsic::nvvm_tensormap_replace_elemtype:
+    return lowerTensormapReplaceElemtype(Op, DAG);
+  case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
+    return lowerTensormapReplaceSwizzleMode(Op, DAG);
   }
   return Op;
 }
@@ -4526,6 +4588,35 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
     return true;
   }
 
+  case Intrinsic::nvvm_tensormap_replace_global_address:
+  case Intrinsic::nvvm_tensormap_replace_global_stride:{
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::i64;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align.reset();
+    return true;
+  }
+
+  case Intrinsic::nvvm_tensormap_replace_rank:
+  case Intrinsic::nvvm_tensormap_replace_box_dim:
+  case Intrinsic::nvvm_tensormap_replace_global_dim:
+  case Intrinsic::nvvm_tensormap_replace_element_stride:
+  case Intrinsic::nvvm_tensormap_replace_elemtype:
+  case Intrinsic::nvvm_tensormap_replace_interleave_layout:
+  case Intrinsic::nvvm_tensormap_replace_swizzle_mode:
+  case Intrinsic::nvvm_tensormap_replace_swizzle_atomicity:
+  case Intrinsic::nvvm_tensormap_replace_fill_mode: {
+    Info.opc = ISD::INTRINSIC_VOID;
+    Info.memVT = MVT::i32;
+    Info.ptrVal = I.getArgOperand(0);
+    Info.offset = 0;
+    Info.flags = MachineMemOperand::MOStore;
+    Info.align.reset();
+    return true;
+  }
+
   case Intrinsic::nvvm_ldu_global_i:
   case Intrinsic::nvvm_ldu_global_f:
   case Intrinsic::nvvm_ldu_global_p: {
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 817006c367379..b15d1210ded32 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -6233,3 +6233,67 @@ foreach sp = [0, 1] in {
   }
 }
 
+//
+// tensormap.replace Instructions
+//
+
+// Instruction for tensormap.replace fields that take only a new value, e.g.
+//   tensormap.replace.tile.global_address.global.b1024.b64 [%rd1], %rd2;
+// The intrinsic is wrapped in a PatFrag carrying `predicate`. The
+// instantiations below pass an AS_match predicate to select on the pointer's
+// address space.
+// NOTE(review): for elemtype/interleave_layout/swizzle_mode/
+// swizzle_atomicity/fill_mode the intrinsic's value operand is ImmArg, but it
+// is modeled here as a register (val_RC:$val, matched as ValTy:$val). Confirm
+// this selects and prints the immediate as intended; an i32imm/timm operand
+// may be more appropriate.
+class TensormapReplaceInst_2<string state_space, string field_name, 
+  string regclass_name, NVPTXRegClass val_RC, ValueType ValTy, Intrinsic Intrin,
+  code predicate> :
+  BasicNVPTXInst<(outs), 
+    (ins ADDR:$addr, val_RC:$val), 
+    "tensormap.replace.tile." # field_name # "." # state_space # ".b1024." # regclass_name,
+    [(PatFrag<(ops node:$addr, node:$val),
+       (Intrin node:$addr, node:$val), predicate>
+      addr:$addr, ValTy:$val)]>;
+
+// Instruction for per-dimension tensormap.replace fields (box_dim,
+// global_dim, global_stride, element_stride). These take a dimension ordinal
+// before the new value, presumably printed as [addr], ord, val.
+// NOTE(review): the intrinsics mark the ordinal as ImmArg (ImmArg<ArgIndex<1>>),
+// yet it is modeled here as a B32 register matched with i32:$ord. Confirm the
+// immediate is selected and printed as intended.
+class TensormapReplaceInst_3<string state_space, string field_name, 
+  string regclass_name, NVPTXRegClass val_RC, ValueType ValTy, Intrinsic Intrin,
+  code predicate> :
+  BasicNVPTXInst<(outs), 
+    (ins ADDR:$addr, B32:$ord, val_RC:$val), 
+    "tensormap.replace.tile." # field_name # "." # state_space # ".b1024." # regclass_name,
+    [(PatFrag<(ops node:$addr, node:$ord, node:$val),
+       (Intrin node:$addr, node:$ord, node:$val), predicate>
+      addr:$addr, i32:$ord, ValTy:$val)]>;
+
+// One instruction per (field, state space). "GLOBAL" -> ".global" and
+// "SHARED_CTA" -> ".shared::cta" in the PTX string. The pointer's address
+// space is checked through the AS_match predicate passed to the pattern.
+foreach state_space = ["GLOBAL", "SHARED_CTA"] in {
+  defvar pred = !if(!eq(state_space, "GLOBAL"), AS_match.global, AS_match.shared);
+  defvar ss_ptx = !tolower(!subst("_", "::", state_space));
+  // Fields available wherever tensormap.replace itself is supported.
+  let Predicates = [callSubtarget<"hasTensormapReplaceSupport">] in {
+    def TMAP_REPLACE_TILE_GLOBAL_ADDRESS_ # state_space : 
+      TensormapReplaceInst_2<ss_ptx, "global_address", "b64", B64, i64,
+        int_nvvm_tensormap_replace_global_address, pred>;
+
+    foreach field_name = ["INTERLEAVE_LAYOUT", "FILL_MODE", "RANK"] in {
+      defvar intrin = !cast<Intrinsic>("int_nvvm_tensormap_replace_" # !tolower(field_name));
+      def TMAP_REPLACE_TILE_ # field_name # _ # state_space : 
+        TensormapReplaceInst_2<ss_ptx, !tolower(field_name), "b32", B32, i32,
+          intrin, pred>;
+    } // field_name
+
+    def TMAP_REPLACE_TILE_GLOBAL_STRIDE_ # state_space : 
+      TensormapReplaceInst_3<ss_ptx, "global_stride", "b64", B64, i64, 
+        int_nvvm_tensormap_replace_global_stride, pred>;
+
+    foreach field_name = ["BOX_DIM", "GLOBAL_DIM", "ELEMENT_STRIDE"] in {
+      defvar intrin = !cast<Intrinsic>("int_nvvm_tensormap_replace_" # !tolower(field_name));
+      def TMAP_REPLACE_TILE_ # field_name # _ # state_space : 
+        TensormapReplaceInst_3<ss_ptx, !tolower(field_name), "b32", B32, i32, 
+          intrin, pred>;
+    } // field_name
+  } // hasTensormapReplaceSupport
+
+  // No Predicates here: subtarget support depends on the immediate value, so
+  // it is checked (with a diagnostic) in lowerTensormapReplaceElemtype in
+  // NVPTXISelLowering.cpp.
+  def TMAP_REPLACE_TILE_ELEMTYPE_ # state_space : 
+    TensormapReplaceInst_2<ss_ptx, "elemtype", "b32", B32, i32, 
+      int_nvvm_tensormap_replace_elemtype, pred>;
+
+  // NOTE(review): unlike the other defs this name lacks the TILE_ infix;
+  // consider TMAP_REPLACE_TILE_SWIZZLE_ATOMICITY_ for consistency. Also, an
+  // unsupported target fails instruction selection here rather than getting
+  // the diagnostic that elemtype/swizzle_mode produce.
+  def TMAP_REPLACE_SWIZZLE_ATOMICITY_ # state_space : 
+    TensormapReplaceInst_2<ss_ptx, "swizzle_atomicity", "b32", B32, i32, 
+      int_nvvm_tensormap_replace_swizzle_atomicity, pred>,
+    Requires<[callSubtarget<"hasTensormapReplaceSwizzleAtomicitySupport">]>;
+
+  // No Predicates here: support is checked per value in
+  // lowerTensormapReplaceSwizzleMode (SWIZZLE_96B needs a newer target).
+  def TMAP_REPLACE_SWIZZLE_MODE_ # state_space : 
+    TensormapReplaceInst_2<ss_ptx, "swizzle_mode", "b32", B32, i32, 
+      int_nvvm_tensormap_replace_swizzle_mode, pred>;
+} // state_space
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 6f6057b3689e6..ccf2be1835722 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -20,6 +20,7 @@
 #include "NVPTXRegisterInfo.h"
 #include "llvm/CodeGen/TargetSubtargetInfo.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/IR/NVVMIntrinsicUtils.h"
 #include "llvm/Support/NVPTXAddrSpace.h"
 #include <string>
 
@@ -202,6 +203,34 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
            hasPTXWithAccelSMs(86, {100, 101, 120});
   }
 
+  /// Baseline support for tensormap.replace.tile. This covers every field
+  /// except the value-dependent cases gated below.
+  /// NOTE(review): the PTX-version/SM lists encode the PTX ISA target
+  /// requirements; re-check them against the spec when updating.
+  bool hasTensormapReplaceSupport() const {
+    return hasPTXWithFamilySMs(90, {90, 100, 110, 120}) ||
+           hasPTXWithFamilySMs(88, {90, 100, 101, 120}) ||
+           hasPTXWithAccelSMs(83, {90, 100, 101, 120});
+  }
+
+  /// Requirements shared by the .swizzle_atomicity field and the b4x16,
+  /// b4x16_p64 and b6x16_p32 element types. They are kept in one place so the
+  /// two checks cannot drift apart.
+  bool hasTensormapReplaceExtendedSupport() const {
+    return hasPTXWithFamilySMs(90, {100, 110, 120}) ||
+           hasPTXWithFamilySMs(88, {100, 101, 120}) ||
+           hasPTXWithAccelSMs(87, {100, 101, 120});
+  }
+
+  /// \p Value is the immediate nvvm::TensormapElemType operand.
+  bool hasTensormapReplaceElemtypeSupport(unsigned Value) const {
+    if (Value >= static_cast<unsigned>(nvvm::TensormapElemType::B4x16))
+      return hasTensormapReplaceExtendedSupport();
+    return hasTensormapReplaceSupport();
+  }
+
+  bool hasTensormapReplaceSwizzleAtomicitySupport() const {
+    return hasTensormapReplaceExtendedSupport();
+  }
+
+  /// \p Value is the immediate nvvm::TensormapSwizzleMode operand. SWIZZLE_96B
+  /// is only accepted under hasPTXWithAccelSMs(88, {103}).
+  bool hasTensormapReplaceSwizzleModeSupport(unsigned Value) const {
+    if (Value == static_cast<unsigned>(nvvm::TensormapSwizzleMode::SWIZZLE_96B))
+      return hasPTXWithAccelSMs(88, {103});
+    return hasTensormapReplaceSupport();
+  }
+
   // Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
   // terminates a basic block. Instead, it would assume that control flow
   // continued to the next instruction. The next instruction could be in the
diff --git a/llvm/test/CodeGen/NVPTX/tensormap_replace.ll b/llvm/test/CodeGen/NVPTX/tensormap_replace.ll
new file mode 100644
index 0000000000000..e1be5f9adbce7
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/tensormap_replace.ll
@@ -0,0 +1,263 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90a -mattr=+ptx83 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90a -mattr=+ptx83 | %ptxas-verify -arch=sm_90a %}
+
+; Replace the global_address field of a tensormap in addrspace(1) (.global)
+; and in addrspace(3) (.shared::cta), using the same i64 value.
+define void @tensormap_replace_global_address(ptr addrspace(1) %global_addr, ptr addrspace(3) %shared_addr, i64 %value) {
+; CHECK-LABEL: tensormap_replace_global_address(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [tensormap_replace_global_address_param_0];
+; CHECK-NEXT:    ld.param.b64 %rd2, [tensormap_replace_global_address_param_2];
+; CHECK-NEXT:    tensormap.replace.tile.global_address.global.b1024.b64 [%rd1], %rd2;
+; CHECK-NEXT:    ld.param.b64 %rd3, [tensormap_replace_global_address_param_1];
+; CHECK-NEXT:    tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%rd3], %rd2;
+; CHECK-NEXT:    ret;
+  call void @llvm.nvvm.tensormap.replace.global.address.p1(ptr addrspace(1) %global_addr, i64 %value)
+  call void @llvm.nvvm.tensormap.replace.global.address.p3(ptr addrspace(3) %shared_addr, i64 %value)
+  ret void
+}
+
+define void @tensormap_replace_rank(ptr addrspace(1) %global_addr, ptr addrspace(3) %shared_addr, i32 %value) {
+; CHECK-LABEL: tensormap_replace_rank(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [tensormap_replace_rank_param_0];
+; CHECK-NEXT:    ld.param.b32 %r1, [tensormap_replace_rank_param_2];
+; CHECK-NEXT:    tensormap.replace.tile.rank.global.b1024.b32 [%rd1], %r1;
+; CHECK-NEXT:    ld.param.b64 %rd2, [tensormap_replace_rank_para...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/172458


More information about the llvm-commits mailing list