[llvm] [NVPTX] Custom lower ADDRSPACECAST (PR #125607)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 3 16:25:27 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Justin Fargnoli (justinfargnoli)

<details>
<summary>Changes</summary>

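Avoid [crashing](https://godbolt.org/z/8T58vcM68) when lowering `addrspacecast ptr addrspace(<non-zero>) %ptr to ptr addrspace(<non-zero>)`. A cast between two non-generic address spaces is now custom-lowered into two casts: first from the source address space to the generic address space, then from the generic address space to the destination address space. Casts that already have the generic address space as source or destination are left unchanged.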

---
Full diff: https://github.com/llvm/llvm-project/pull/125607.diff


3 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+24) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+1) 
- (modified) llvm/test/CodeGen/NVPTX/addrspacecast.ll (+14) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 773c97f7b4dc0f..962c971e6970cd 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -989,6 +989,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     setOperationAction(ISD::FLOG2, {MVT::v2f16, MVT::v2bf16}, Expand);
   }
 
+  setOperationAction(ISD::ADDRSPACECAST, {MVT::i32, MVT::i64}, Custom);
+
   // No FPOW or FREM in PTX.
 
   // Now deduce the information based on the above mentioned
@@ -2652,6 +2654,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return SDValue();
   case ISD::FRAMEADDR:
     return SDValue();
+  case ISD::ADDRSPACECAST:
+    return LowerADDRSPACECAST(Op, DAG);
   case ISD::GlobalAddress:
     return LowerGlobalAddress(Op, DAG);
   case ISD::INTRINSIC_W_CHAIN:
@@ -2767,6 +2771,26 @@ unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
   return MachineJumpTableInfo::EK_Inline;
 }
 
+SDValue NVPTXTargetLowering::LowerADDRSPACECAST(SDValue Op,
+                                                SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  AddrSpaceCastSDNode *N = cast<AddrSpaceCastSDNode>(Op.getNode());
+
+  EVT OperandVT = Op.getOperand(0).getValueType();
+  unsigned SrcAS = N->getSrcAddressSpace();
+  EVT ResultVT = Op.getValueType();
+  unsigned DestAS = N->getDestAddressSpace();
+
+  if (SrcAS == llvm::ADDRESS_SPACE_GENERIC ||
+      DestAS == llvm::ADDRESS_SPACE_GENERIC)
+    return Op;
+
+  SDValue ToGeneric = DAG.getAddrSpaceCast(DL, OperandVT, Op.getOperand(0),
+                                           SrcAS, llvm::ADDRESS_SPACE_GENERIC);
+  return DAG.getAddrSpaceCast(DL, ResultVT, ToGeneric,
+                              llvm::ADDRESS_SPACE_GENERIC, DestAS);
+}
+
 // This function is almost a copy of SelectionDAG::expandVAArg().
 // The only diff is that this one produces loads from local address space.
 SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 5adf69d621552f..74ec14ba5f8e32 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -264,6 +264,7 @@ class NVPTXTargetLowering : public TargetLowering {
   const NVPTXSubtarget &STI; // cache the subtarget here
   SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const;
 
+  SDValue LowerADDRSPACECAST(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;
 
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/NVPTX/addrspacecast.ll b/llvm/test/CodeGen/NVPTX/addrspacecast.ll
index 23428b3728674e..e3ebb2f458d46a 100644
--- a/llvm/test/CodeGen/NVPTX/addrspacecast.ll
+++ b/llvm/test/CodeGen/NVPTX/addrspacecast.ll
@@ -99,6 +99,20 @@ define i32 @conv8(ptr %ptr) {
   ret i32 %val
 }
 
+; ALL-LABEL: conv9
+define i32 @conv9(ptr addrspace(1) %ptr) {
+; CLS32: cvta.global.u32
+; CLS32: cvta.to.shared.u32
+; CLS64: cvta.global.u64
+; CLS64: cvta.to.shared.u64
+; PTRCONV: cvt.u32.u64
+; NOPTRCONV-NOT: cvt.u32.u64
+; ALL: ld.shared.u32
+  %specptr = addrspacecast ptr addrspace(1) %ptr to ptr addrspace(3)
+  %val = load i32, ptr addrspace(3) %specptr
+  ret i32 %val
+}
+
 ; Check that we support addrspacecast when splitting the vector
 ; result (<2 x ptr> => 2 x <1 x ptr>).
 ; This also checks that scalarization works for addrspacecast

``````````

</details>


https://github.com/llvm/llvm-project/pull/125607


More information about the llvm-commits mailing list