[llvm] [mlir] [NVPTX] Split Param address space into EntryParam and DeviceParam (NFC) (PR #186636)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Sat Mar 14 20:13:46 PDT 2026
https://github.com/AlexMaclean created https://github.com/llvm/llvm-project/pull/186636
This change begins clarifying and cleaning up some oddities around the param address-space in NVPTX. PTX supports ".param" loads and stores referring to both entry (kernel) and device parameters; however, these spaces are actually quite different. Entry param space supports pointers and addrspace-casting to generic, while device parameter space can only be referenced by a parameter plus an immediate offset. This change accounts for this fact with the following refactors:
- Rename `ADDRESS_SPACE_PARAM` -> `ADDRESS_SPACE_ENTRY_PARAM`. This reflects the fact that only entry parameter space can be meaningfully modeled in LLVM IR and that pointers with this AS in LLVM IR are always referring to entry parameters.
- Add `NVPTX::AddressSpace::DeviceParam` for NVPTX MIR instructions. This is used in NVPTX MIR instructions to signify that they load/store device parameters. It has a distinct value from `NVPTX::AddressSpace::EntryParam` so that in the future we can print these differently on supported PTX versions.
>From 909cc13eabc9229ed09f9d18d06c8e92401aed01 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Sat, 14 Mar 2026 20:22:41 +0000
Subject: [PATCH] [NVPTX] Split Param address space into EntryParam and
DeviceParam (NFC)
---
llvm/include/llvm/Support/NVPTXAddrSpace.h | 2 +-
.../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp | 3 +-
llvm/lib/Target/NVPTX/NVPTX.h | 25 ++++++-----
llvm/lib/Target/NVPTX/NVPTXAliasAnalysis.cpp | 2 +-
llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp | 2 +-
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 42 ++++++++-----------
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 35 +++++++++-------
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 2 +-
llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp | 17 ++++----
.../Target/NVPTX/NVPTXTargetTransformInfo.cpp | 4 +-
.../Target/NVPTX/NVPTXTargetTransformInfo.h | 3 +-
llvm/lib/Target/NVPTX/NVPTXUtilities.h | 3 +-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 4 +-
13 files changed, 73 insertions(+), 71 deletions(-)
diff --git a/llvm/include/llvm/Support/NVPTXAddrSpace.h b/llvm/include/llvm/Support/NVPTXAddrSpace.h
index 04f74c34787cc..12c493b6568fb 100644
--- a/llvm/include/llvm/Support/NVPTXAddrSpace.h
+++ b/llvm/include/llvm/Support/NVPTXAddrSpace.h
@@ -27,7 +27,7 @@ enum AddressSpace : unsigned {
ADDRESS_SPACE_TENSOR = 6,
ADDRESS_SPACE_SHARED_CLUSTER = 7,
- ADDRESS_SPACE_PARAM = 101,
+ ADDRESS_SPACE_ENTRY_PARAM = 101,
};
// According to official PTX Writer's Guide, DWARF debug information should
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 5a5793bc7bc13..2abff556e11bb 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -334,7 +334,8 @@ void NVPTXInstPrinter::printAtomicCode(const MCInst *MI, int OpNum,
case NVPTX::AddressSpace::Const:
case NVPTX::AddressSpace::Shared:
case NVPTX::AddressSpace::SharedCluster:
- case NVPTX::AddressSpace::Param:
+ case NVPTX::AddressSpace::EntryParam:
+ case NVPTX::AddressSpace::DeviceParam:
case NVPTX::AddressSpace::Local:
O << "." << A;
return;
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 883efadc27963..f524b7373c2ea 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -14,12 +14,13 @@
#ifndef LLVM_LIB_TARGET_NVPTX_NVPTX_H
#define LLVM_LIB_TARGET_NVPTX_NVPTX_H
-#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/CodeGen.h"
+#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Target/TargetMachine.h"
+
namespace llvm {
class FunctionPass;
class MachineFunctionPass;
@@ -191,15 +192,19 @@ enum Scope : ScopeUnderlyingType {
using AddressSpaceUnderlyingType = unsigned int;
enum AddressSpace : AddressSpaceUnderlyingType {
- Generic = 0,
- Global = 1,
- Shared = 3,
- Const = 4,
- Local = 5,
- SharedCluster = 7,
-
- // NVPTX Backend Private:
- Param = 101
+ Generic = NVPTXAS::ADDRESS_SPACE_GENERIC,
+ Global = NVPTXAS::ADDRESS_SPACE_GLOBAL,
+ Shared = NVPTXAS::ADDRESS_SPACE_SHARED,
+ Const = NVPTXAS::ADDRESS_SPACE_CONST,
+ Local = NVPTXAS::ADDRESS_SPACE_LOCAL,
+ SharedCluster = NVPTXAS::ADDRESS_SPACE_SHARED_CLUSTER,
+ EntryParam = NVPTXAS::ADDRESS_SPACE_ENTRY_PARAM,
+
+ // DeviceParam is not a real address space, as it does not support pointers
+ // and instead can only be referenced by param+offset. For this reason it is
+ // only used in MIR as an instruction modifier and should not be used in LLVM
+ // IR.
+ DeviceParam
};
namespace PTXLdStInstCode {
diff --git a/llvm/lib/Target/NVPTX/NVPTXAliasAnalysis.cpp b/llvm/lib/Target/NVPTX/NVPTXAliasAnalysis.cpp
index a579783802aa2..f86b5975add74 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAliasAnalysis.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAliasAnalysis.cpp
@@ -109,7 +109,7 @@ AliasResult NVPTXAAResult::alias(const MemoryLocation &Loc1,
// allow any writes to .param pointers.
static bool isConstOrParam(unsigned AS) {
return AS == AddressSpace::ADDRESS_SPACE_CONST ||
- AS == AddressSpace::ADDRESS_SPACE_PARAM;
+ AS == AddressSpace::ADDRESS_SPACE_ENTRY_PARAM;
}
ModRefInfo NVPTXAAResult::getModRefInfoMask(const MemoryLocation &Loc,
diff --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
index c8b53571c1e59..2b59286295d2b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp
@@ -102,7 +102,7 @@ static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
(LI->uses().begin() + LDInstBasePtrOpIdx)
->ChangeToES(ParamSymbol->getSymbolName());
(LI->uses().begin() + LDInstAddrSpaceOpIdx)
- ->ChangeToImmediate(NVPTX::AddressSpace::Param);
+ ->ChangeToImmediate(NVPTX::AddressSpace::DeviceParam);
}
return true;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 99982ff3181b3..de3a968ddc2cb 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -496,30 +496,21 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
return true;
}
-static std::optional<NVPTX::AddressSpace> convertAS(unsigned AS) {
+NVPTX::AddressSpace NVPTXDAGToDAGISel::getAddrSpace(const MemSDNode *N) {
+ auto AS =
+ static_cast<NVPTX::AddressSpace>(N->getMemOperand()->getAddrSpace());
switch (AS) {
- case llvm::ADDRESS_SPACE_LOCAL:
- return NVPTX::AddressSpace::Local;
- case llvm::ADDRESS_SPACE_GLOBAL:
- return NVPTX::AddressSpace::Global;
- case llvm::ADDRESS_SPACE_SHARED:
- return NVPTX::AddressSpace::Shared;
- case llvm::ADDRESS_SPACE_SHARED_CLUSTER:
- return NVPTX::AddressSpace::SharedCluster;
- case llvm::ADDRESS_SPACE_GENERIC:
- return NVPTX::AddressSpace::Generic;
- case llvm::ADDRESS_SPACE_PARAM:
- return NVPTX::AddressSpace::Param;
- case llvm::ADDRESS_SPACE_CONST:
- return NVPTX::AddressSpace::Const;
- default:
- return std::nullopt;
+ case NVPTX::AddressSpace::Generic:
+ case NVPTX::AddressSpace::Global:
+ case NVPTX::AddressSpace::Shared:
+ case NVPTX::AddressSpace::Const:
+ case NVPTX::AddressSpace::Local:
+ case NVPTX::AddressSpace::SharedCluster:
+ case NVPTX::AddressSpace::EntryParam:
+ case NVPTX::AddressSpace::DeviceParam:
+ return AS;
}
-}
-
-NVPTX::AddressSpace NVPTXDAGToDAGISel::getAddrSpace(const MemSDNode *N) {
- return convertAS(N->getMemOperand()->getAddrSpace())
- .value_or(NVPTX::AddressSpace::Generic);
+ llvm_unreachable("Unexpected address space");
}
NVPTX::Ordering NVPTXDAGToDAGISel::getMemOrder(const MemSDNode *N) const {
@@ -655,7 +646,8 @@ getOperationOrderings(MemSDNode *N, const NVPTXSubtarget *Subtarget) {
// a dead dummy volatile load.
if (CodeAddrSpace == NVPTX::AddressSpace::Local ||
CodeAddrSpace == NVPTX::AddressSpace::Const ||
- CodeAddrSpace == NVPTX::AddressSpace::Param) {
+ CodeAddrSpace == NVPTX::AddressSpace::EntryParam ||
+ CodeAddrSpace == NVPTX::AddressSpace::DeviceParam) {
return NVPTX::Ordering::NotAtomic;
}
@@ -967,7 +959,7 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
case ADDRESS_SPACE_LOCAL:
Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local;
break;
- case ADDRESS_SPACE_PARAM:
+ case ADDRESS_SPACE_ENTRY_PARAM:
Opc = TM.is64Bit() ? NVPTX::cvta_param_64 : NVPTX::cvta_param;
break;
}
@@ -998,7 +990,7 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
case ADDRESS_SPACE_LOCAL:
Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local;
break;
- case ADDRESS_SPACE_PARAM:
+ case ADDRESS_SPACE_ENTRY_PARAM:
Opc = TM.is64Bit() ? NVPTX::cvta_to_param_64 : NVPTX::cvta_to_param;
break;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index e807e8d55f6dc..2a7518a61b3d4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1548,9 +1548,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Align ParamAlign = commonAlignment(ArgAlign, ParamOffset);
SDValue ParamAddr =
DAG.getObjectPtrOffset(dl, ParamSymbol, ParamOffset);
- SDValue StoreParam =
- DAG.getStore(ArgDeclare, dl, SrcLoad, ParamAddr,
- MachinePointerInfo(ADDRESS_SPACE_PARAM), ParamAlign);
+ SDValue StoreParam = DAG.getStore(
+ ArgDeclare, dl, SrcLoad, ParamAddr,
+ MachinePointerInfo(NVPTX::AddressSpace::DeviceParam), ParamAlign);
CallPrereqs.push_back(StoreParam);
J += NumElts;
@@ -1620,9 +1620,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
return GetStoredValue(J + K);
});
- SDValue StoreParam =
- DAG.getStore(ArgDeclare, dl, Val, Ptr,
- MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
+ SDValue StoreParam = DAG.getStore(
+ ArgDeclare, dl, Val, Ptr,
+ MachinePointerInfo(NVPTX::AddressSpace::DeviceParam), CurrentAlign);
CallPrereqs.push_back(StoreParam);
J += NumElts;
@@ -1737,9 +1737,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SDValue Ptr =
DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
- SDValue R =
- DAG.getLoad(VecVT, dl, Call, Ptr,
- MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
+ SDValue R = DAG.getLoad(
+ VecVT, dl, Call, Ptr,
+ MachinePointerInfo(NVPTX::AddressSpace::DeviceParam), CurrentAlign);
LoadChains.push_back(R.getValue(1));
for (const unsigned J : llvm::seq(NumElts))
@@ -4057,6 +4057,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
auto PtrVT = getPointerTy(DAG.getDataLayout());
const Function &F = DAG.getMachineFunction().getFunction();
+ const bool IsKernel = isKernelFunction(F);
SDValue Root = DAG.getRoot();
SmallVector<SDValue, 16> OutChains;
@@ -4112,7 +4113,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
assert(ByvalIn.VT == PtrVT && "ByVal argument must be a pointer");
SDValue P;
- if (isKernelFunction(F)) {
+ if (IsKernel) {
assert(isParamGridConstant(Arg) && "ByVal argument must be lowered to "
"grid_constant by NVPTXLowerArgs");
P = ArgSymbol;
@@ -4145,11 +4146,12 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
dl, ArgSymbol, TypeSize::getFixed(Offsets[I]));
const Align PartAlign = commonAlignment(ArgAlign, Offsets[I]);
- SDValue P =
- DAG.getLoad(VecVT, dl, Root, VecAddr,
- MachinePointerInfo(ADDRESS_SPACE_PARAM), PartAlign,
- MachineMemOperand::MODereferenceable |
- MachineMemOperand::MOInvariant);
+ const unsigned AS = IsKernel ? NVPTX::AddressSpace::EntryParam
+ : NVPTX::AddressSpace::DeviceParam;
+ SDValue P = DAG.getLoad(VecVT, dl, Root, VecAddr,
+ MachinePointerInfo(AS), PartAlign,
+ MachineMemOperand::MODereferenceable |
+ MachineMemOperand::MOInvariant);
P.getNode()->setIROrder(Arg.getArgNo() + 1);
for (const unsigned J : llvm::seq(NumElts)) {
SDValue Elt = getExtractVectorizedValue(P, J, LoadVT, dl, DAG);
@@ -4226,7 +4228,8 @@ NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
DAG.getObjectPtrOffset(dl, RetSymbol, TypeSize::getFixed(Offsets[I]));
Chain = DAG.getStore(Chain, dl, Val, Ptr,
- MachinePointerInfo(ADDRESS_SPACE_PARAM), CurrentAlign);
+ MachinePointerInfo(NVPTX::AddressSpace::DeviceParam),
+ CurrentAlign);
I += NumElts;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 26f5f3f5160f5..84550330cd75d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -23,7 +23,7 @@ def AS_match {
return cast<MemSDNode>(N)->getAddressSpace() == llvm::ADDRESS_SPACE_CONST;
}];
code param = [{
- return cast<MemSDNode>(N)->getAddressSpace() == llvm::ADDRESS_SPACE_PARAM;
+ return cast<MemSDNode>(N)->getAddressSpace() == llvm::ADDRESS_SPACE_ENTRY_PARAM;
}];
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
index 04c2b86de5a0a..835d0ca9d2a4d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
@@ -225,12 +225,13 @@ static void convertToParamAS(ArrayRef<Use *> OldUses, Value *Param) {
return NewGEP;
}
if (auto *BC = dyn_cast<BitCastInst>(OldInst)) {
- auto *NewBCType = PointerType::get(BC->getContext(), ADDRESS_SPACE_PARAM);
+ auto *NewBCType =
+ PointerType::get(BC->getContext(), ADDRESS_SPACE_ENTRY_PARAM);
return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
BC->getName(), BC->getIterator());
}
if (auto *ASC = dyn_cast<AddrSpaceCastInst>(OldInst)) {
- assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
+ assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_ENTRY_PARAM);
(void)ASC;
// Just pass through the argument, the old ASC is no longer needed.
return I.NewParam;
@@ -339,7 +340,7 @@ static void propagateAlignmentToLoads(Value *Val, Align NewAlign,
Worklist.push({cast<Instruction>(CurUser), Ctx.Offset});
else if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
APInt OffsetAccumulated =
- APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));
+ APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_ENTRY_PARAM));
if (!I->accumulateConstantOffset(DL, OffsetAccumulated))
continue;
@@ -364,10 +365,10 @@ static void propagateAlignmentToLoads(Value *Val, Align NewAlign,
// alignment of the return value based on the alignment of the argument.
static CallInst *createNVVMInternalAddrspaceWrap(IRBuilder<> &IRB,
Argument &Arg) {
- CallInst *ArgInParam =
- IRB.CreateIntrinsic(Intrinsic::nvvm_internal_addrspace_wrap,
- {IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg.getType()},
- &Arg, {}, Arg.getName() + ".param");
+ CallInst *ArgInParam = IRB.CreateIntrinsic(
+ Intrinsic::nvvm_internal_addrspace_wrap,
+ {IRB.getPtrTy(ADDRESS_SPACE_ENTRY_PARAM), Arg.getType()}, &Arg, {},
+ Arg.getName() + ".param");
if (MaybeAlign ParamAlign = Arg.getParamAlign())
ArgInParam->addRetAttr(
@@ -429,7 +430,7 @@ struct ArgUseChecker : PtrUseVisitor<ArgUseChecker> {
void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
// ASC to param space are no-ops and do not need a copy
- if (ASC.getDestAddressSpace() != ADDRESS_SPACE_PARAM)
+ if (ASC.getDestAddressSpace() != ADDRESS_SPACE_ENTRY_PARAM)
return PI.setEscapedAndAborted(&ASC);
Base::visitAddrSpaceCastInst(ASC);
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index c1fe9300785a3..65b6069bb7495 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -403,7 +403,7 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC,
// Returns true/false when we know the answer, nullopt otherwise.
static std::optional<bool> evaluateIsSpace(Intrinsic::ID IID, unsigned AS) {
if (AS == NVPTXAS::ADDRESS_SPACE_GENERIC ||
- AS == NVPTXAS::ADDRESS_SPACE_PARAM)
+ AS == NVPTXAS::ADDRESS_SPACE_ENTRY_PARAM)
return std::nullopt; // Got to check at run-time.
switch (IID) {
case Intrinsic::nvvm_isspacep_global:
@@ -579,7 +579,7 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
IRBuilder<> Builder(II);
const unsigned NewAS = NewV->getType()->getPointerAddressSpace();
if (NewAS == NVPTXAS::ADDRESS_SPACE_CONST ||
- NewAS == NVPTXAS::ADDRESS_SPACE_PARAM)
+ NewAS == NVPTXAS::ADDRESS_SPACE_ENTRY_PARAM)
return Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_prefetch_tensormap,
NewV);
return nullptr;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index cbb73511c5a08..899249db54574 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -57,7 +57,8 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
bool
canHaveNonUndefGlobalInitializerInAddressSpace(unsigned AS) const override {
return AS != AddressSpace::ADDRESS_SPACE_SHARED &&
- AS != AddressSpace::ADDRESS_SPACE_LOCAL && AS != ADDRESS_SPACE_PARAM;
+ AS != AddressSpace::ADDRESS_SPACE_LOCAL &&
+ AS != AddressSpace::ADDRESS_SPACE_ENTRY_PARAM;
}
std::optional<Instruction *>
diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
index f0428e281b081..b96a83236ac23 100644
--- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h
+++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h
@@ -207,7 +207,8 @@ inline std::string AddressSpaceToString(AddressSpace A) {
return "shared";
case AddressSpace::SharedCluster:
return "shared::cluster";
- case AddressSpace::Param:
+ case AddressSpace::EntryParam:
+ case AddressSpace::DeviceParam:
return "param";
case AddressSpace::Local:
return "local";
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7d49aa3878ebe..528e709629ebf 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -4796,9 +4796,7 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
static llvm::Value *getParamCastedAddr(llvm::Value *addr,
llvm::IRBuilderBase &builder) {
return builder.CreateAddrSpaceCast(
- addr,
- llvm::PointerType::get(builder.getContext(),
- llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
+ addr, builder.getPtrTy(llvm::NVPTXAS::ADDRESS_SPACE_ENTRY_PARAM));
}
NVVM::IDArgPair
More information about the llvm-commits
mailing list