[llvm] [SDISel] Teach the type legalizer about ADDRSPACECAST (PR #90969)

via llvm-commits llvm-commits at lists.llvm.org
Fri May 3 06:41:51 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Quentin Colombet (qcolombet)

<details>
<summary>Changes</summary>

Vectorized ADDRSPACECASTs were not supported by the type legalizer.

This patch adds support for:
- splitting the vector result: <2 x ptr> => 2 x <1 x ptr>
- scalarization: <1 x ptr> => ptr
- widening: <3 x ptr> => <4 x ptr>

This is all exercised by the added NVPTX tests.

@jholewinski, please double-check the NVPTX test results.

---
Full diff: https://github.com/llvm/llvm-project/pull/90969.diff


3 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+3) 
- (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp (+66) 
- (modified) llvm/test/CodeGen/NVPTX/addrspacecast.ll (+92) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 0252e3d6febca9..dcd547d231c70b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -785,6 +785,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue ScalarizeVecRes_InregOp(SDNode *N);
   SDValue ScalarizeVecRes_VecInregOp(SDNode *N);
 
+  SDValue ScalarizeVecRes_ADDRSPACECAST(SDNode *N);
   SDValue ScalarizeVecRes_BITCAST(SDNode *N);
   SDValue ScalarizeVecRes_BUILD_VECTOR(SDNode *N);
   SDValue ScalarizeVecRes_EXTRACT_SUBVECTOR(SDNode *N);
@@ -852,6 +853,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SplitVecRes_BinOp(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_TernaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo, SDValue &Hi);
+  void SplitVecRes_ADDRSPACECAST(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_FFREXP(SDNode *N, unsigned ResNo, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_ExtendOp(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_InregOp(SDNode *N, SDValue &Lo, SDValue &Hi);
@@ -955,6 +957,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   // Widen Vector Result Promotion.
   void WidenVectorResult(SDNode *N, unsigned ResNo);
   SDValue WidenVecRes_MERGE_VALUES(SDNode* N, unsigned ResNo);
+  SDValue WidenVecRes_ADDRSPACECAST(SDNode *N);
   SDValue WidenVecRes_AssertZext(SDNode* N);
   SDValue WidenVecRes_BITCAST(SDNode* N);
   SDValue WidenVecRes_BUILD_VECTOR(SDNode* N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index cab4dc5f3c1565..14501e5d01d568 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -23,6 +23,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/TypeSize.h"
@@ -116,6 +117,9 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::FCANONICALIZE:
     R = ScalarizeVecRes_UnaryOp(N);
     break;
+  case ISD::ADDRSPACECAST:
+    R = ScalarizeVecRes_ADDRSPACECAST(N);
+    break;
   case ISD::FFREXP:
     R = ScalarizeVecRes_FFREXP(N, ResNo);
     break;
@@ -475,6 +479,31 @@ SDValue DAGTypeLegalizer::ScalarizeVecRes_VecInregOp(SDNode *N) {
   llvm_unreachable("Illegal extend_vector_inreg opcode");
 }
 
+SDValue DAGTypeLegalizer::ScalarizeVecRes_ADDRSPACECAST(SDNode *N) {
+  EVT DestVT = N->getValueType(0).getVectorElementType();
+  SDValue Op = N->getOperand(0);
+  EVT OpVT = Op.getValueType();
+  SDLoc DL(N);
+  // Scalarize a <1 x ptr> ADDRSPACECAST result into a scalar ADDRSPACECAST.
+  // The result needs scalarizing, but it's not a given that the source does:
+  // on some targets the source type is legal or legalized differently. E.g.
+  // on AArch64 v1i1 is illegal but v1i{8,16,32} are widened to v8i8, v4i16,
+  // and v2i32, which are legal, because v1i64 is legal and was not scalarized.
+  // In that case extract element 0 from the vector operand instead.
+  // See the similar logic in ScalarizeVecRes_SETCC.
+  if (getTypeAction(OpVT) == TargetLowering::TypeScalarizeVector) {
+    Op = GetScalarizedVector(Op);
+  } else {
+    EVT VT = OpVT.getVectorElementType();
+    Op = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op,
+                     DAG.getVectorIdxConstant(0, DL));
+  }
+  auto *AddrSpaceCastN = cast<AddrSpaceCastSDNode>(N);
+  unsigned SrcAS = AddrSpaceCastN->getSrcAddressSpace();
+  unsigned DestAS = AddrSpaceCastN->getDestAddressSpace();
+  return DAG.getAddrSpaceCast(DL, DestVT, Op, SrcAS, DestAS);
+}
+
 SDValue DAGTypeLegalizer::ScalarizeVecRes_SCALAR_TO_VECTOR(SDNode *N) {
   // If the operand is wider than the vector element type then it is implicitly
   // truncated.  Make that explicit here.
@@ -1122,6 +1151,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::FCANONICALIZE:
     SplitVecRes_UnaryOp(N, Lo, Hi);
     break;
+  case ISD::ADDRSPACECAST:
+    SplitVecRes_ADDRSPACECAST(N, Lo, Hi);
+    break;
   case ISD::FFREXP:
     SplitVecRes_FFREXP(N, ResNo, Lo, Hi);
     break;
@@ -2353,6 +2385,27 @@ void DAGTypeLegalizer::SplitVecRes_UnaryOp(SDNode *N, SDValue &Lo,
   Hi = DAG.getNode(Opcode, dl, HiVT, {Hi, MaskHi, EVLHi}, Flags);
 }
 
+void DAGTypeLegalizer::SplitVecRes_ADDRSPACECAST(SDNode *N, SDValue &Lo,
+                                                 SDValue &Hi) {
+  EVT LoVT, HiVT;
+  SDLoc dl(N);
+  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0));
+  // Split a vector ADDRSPACECAST result; Lo and Hi receive the two halves.
+  // If the input also splits, reuse its halves directly for a compile-time
+  // speedup. Otherwise split the operand by hand.
+  EVT InVT = N->getOperand(0).getValueType();
+  if (getTypeAction(InVT) == TargetLowering::TypeSplitVector)
+    GetSplitVector(N->getOperand(0), Lo, Hi);
+  else
+    std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
+  // Cast each half separately, keeping N's source/destination address spaces.
+  auto *AddrSpaceCastN = cast<AddrSpaceCastSDNode>(N);
+  unsigned SrcAS = AddrSpaceCastN->getSrcAddressSpace();
+  unsigned DestAS = AddrSpaceCastN->getDestAddressSpace();
+  Lo = DAG.getAddrSpaceCast(dl, LoVT, Lo, SrcAS, DestAS);
+  Hi = DAG.getAddrSpaceCast(dl, HiVT, Hi, SrcAS, DestAS);
+}
+
 void DAGTypeLegalizer::SplitVecRes_FFREXP(SDNode *N, unsigned ResNo,
                                           SDValue &Lo, SDValue &Hi) {
   SDLoc dl(N);
@@ -4121,6 +4174,9 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
     report_fatal_error("Do not know how to widen the result of this operator!");
 
   case ISD::MERGE_VALUES:      Res = WidenVecRes_MERGE_VALUES(N, ResNo); break;
+  case ISD::ADDRSPACECAST:
+    Res = WidenVecRes_ADDRSPACECAST(N);
+    break;
   case ISD::AssertZext:        Res = WidenVecRes_AssertZext(N); break;
   case ISD::BITCAST:           Res = WidenVecRes_BITCAST(N); break;
   case ISD::BUILD_VECTOR:      Res = WidenVecRes_BUILD_VECTOR(N); break;
@@ -5086,6 +5142,16 @@ SDValue DAGTypeLegalizer::WidenVecRes_MERGE_VALUES(SDNode *N, unsigned ResNo) {
   return GetWidenedVector(WidenVec);
 }
 
+SDValue DAGTypeLegalizer::WidenVecRes_ADDRSPACECAST(SDNode *N) {
+  EVT WidenVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  SDValue InOp = GetWidenedVector(N->getOperand(0));
+  auto *AddrSpaceCastN = cast<AddrSpaceCastSDNode>(N);
+  // Recreate the cast on the widened type (operand presumably widens too).
+  return DAG.getAddrSpaceCast(SDLoc(N), WidenVT, InOp,
+                              AddrSpaceCastN->getSrcAddressSpace(),
+                              AddrSpaceCastN->getDestAddressSpace());
+}
+
 SDValue DAGTypeLegalizer::WidenVecRes_BITCAST(SDNode *N) {
   SDValue InOp = N->getOperand(0);
   EVT InVT = InOp.getValueType();
diff --git a/llvm/test/CodeGen/NVPTX/addrspacecast.ll b/llvm/test/CodeGen/NVPTX/addrspacecast.ll
index b680490ac5b124..85752bb95eb31f 100644
--- a/llvm/test/CodeGen/NVPTX/addrspacecast.ll
+++ b/llvm/test/CodeGen/NVPTX/addrspacecast.ll
@@ -98,3 +98,95 @@ define i32 @conv8(ptr %ptr) {
   %val = load i32, ptr addrspace(5) %specptr
   ret i32 %val
 }
+
+; Check that we support addrspacecast when splitting the vector
+; result (<2 x ptr> => 2 x <1 x ptr>).
+; This also checks that scalarization works for addrspacecast
+; (when going from <1 x ptr> to ptr).
+; ALL-LABEL: split1To0
+define void @split1To0(ptr nocapture noundef readonly %xs) {
+; CLS32: cvta.global.u32
+; CLS32: cvta.global.u32
+; CLS64: cvta.global.u64
+; CLS64: cvta.global.u64
+; ALL: st.u32
+; ALL: st.u32
+  %vec_addr = load <2 x ptr addrspace(1)>, ptr %xs, align 16
+  %addrspacecast = addrspacecast <2 x ptr addrspace(1)> %vec_addr to <2 x ptr>
+  %extractelement0 = extractelement <2 x ptr> %addrspacecast, i64 0
+  store float 0.5, ptr %extractelement0, align 4
+  %extractelement1 = extractelement <2 x ptr> %addrspacecast, i64 1
+  store float 1.0, ptr %extractelement1, align 4
+  ret void
+}
+
+; Same as split1To0 but from 0 to 1, to make sure the addrspacecast preserves
+; the source and destination addrspaces properly.
+; ALL-LABEL: split0To1
+define void @split0To1(ptr nocapture noundef readonly %xs) {
+; CLS32: cvta.to.global.u32
+; CLS32: cvta.to.global.u32
+; CLS64: cvta.to.global.u64
+; CLS64: cvta.to.global.u64
+; ALL: st.global.u32
+; ALL: st.global.u32
+  %vec_addr = load <2 x ptr>, ptr %xs, align 16
+  %addrspacecast = addrspacecast <2 x ptr> %vec_addr to <2 x ptr addrspace(1)>
+  %extractelement0 = extractelement <2 x ptr addrspace(1)> %addrspacecast, i64 0
+  store float 0.5, ptr addrspace(1) %extractelement0, align 4
+  %extractelement1 = extractelement <2 x ptr addrspace(1)> %addrspacecast, i64 1
+  store float 1.0, ptr addrspace(1) %extractelement1, align 4
+  ret void
+}
+
+; Check that we support addrspacecast when widening is required
+; (<3 x ptr> => <4 x ptr>).
+; ALL-LABEL: widen1To0
+define void @widen1To0(ptr nocapture noundef readonly %xs) {
+; CLS32: cvta.global.u32
+; CLS32: cvta.global.u32
+; CLS32: cvta.global.u32
+
+; CLS64: cvta.global.u64
+; CLS64: cvta.global.u64
+; CLS64: cvta.global.u64
+
+; ALL: st.u32
+; ALL: st.u32
+; ALL: st.u32
+  %vec_addr = load <3 x ptr addrspace(1)>, ptr %xs, align 16
+  %addrspacecast = addrspacecast <3 x ptr addrspace(1)> %vec_addr to <3 x ptr>
+  %extractelement0 = extractelement <3 x ptr> %addrspacecast, i64 0
+  store float 0.5, ptr %extractelement0, align 4
+  %extractelement1 = extractelement <3 x ptr> %addrspacecast, i64 1
+  store float 1.0, ptr %extractelement1, align 4
+  %extractelement2 = extractelement <3 x ptr> %addrspacecast, i64 2
+  store float 1.5, ptr %extractelement2, align 4
+  ret void
+}
+
+; Same as widen1To0 but from 0 to 1, to make sure the addrspacecast preserves
+; the source and destination addrspaces properly.
+; ALL-LABEL: widen0To1
+define void @widen0To1(ptr nocapture noundef readonly %xs) {
+; CLS32: cvta.to.global.u32
+; CLS32: cvta.to.global.u32
+; CLS32: cvta.to.global.u32
+
+; CLS64: cvta.to.global.u64
+; CLS64: cvta.to.global.u64
+; CLS64: cvta.to.global.u64
+
+; ALL: st.global.u32
+; ALL: st.global.u32
+; ALL: st.global.u32
+  %vec_addr = load <3 x ptr>, ptr %xs, align 16
+  %addrspacecast = addrspacecast <3 x ptr> %vec_addr to <3 x ptr addrspace(1)>
+  %extractelement0 = extractelement <3 x ptr addrspace(1)> %addrspacecast, i64 0
+  store float 0.5, ptr addrspace(1) %extractelement0, align 4
+  %extractelement1 = extractelement <3 x ptr addrspace(1)> %addrspacecast, i64 1
+  store float 1.0, ptr addrspace(1) %extractelement1, align 4
+  %extractelement2 = extractelement <3 x ptr addrspace(1)> %addrspacecast, i64 2
+  store float 1.5, ptr addrspace(1) %extractelement2, align 4
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/90969


More information about the llvm-commits mailing list