[llvm] Nvptx port LowerBITCAST to SelectionDAG (PR #120903)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 23 09:32:24 PST 2024


https://github.com/GrumpyPigSkin updated https://github.com/llvm/llvm-project/pull/120903

>From e1b6fce91b52484f6cf72690acfe38c20a5de5ef Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Sun, 22 Dec 2024 15:08:25 +0000
Subject: [PATCH 1/6] Ported LowerBITCAST from NVPTXISelLowering.cpp to
 SelectionDAG/LegalizeTypes.cpp.

---
 .../SelectionDAG/LegalizeIntegerTypes.cpp     |  9 +++--
 .../CodeGen/SelectionDAG/LegalizeTypes.cpp    | 33 +++++++++++++++++++
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h |  1 +
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 26 +--------------
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h     |  2 --
 5 files changed, 42 insertions(+), 29 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index be7521f3416850..8a6bfc0c66cd82 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2174,8 +2174,13 @@ SDValue DAGTypeLegalizer::PromoteIntOp_ATOMIC_STORE(AtomicSDNode *N) {
 }
 
 SDValue DAGTypeLegalizer::PromoteIntOp_BITCAST(SDNode *N) {
-  // This should only occur in unusual situations like bitcasting to an
-  // x86_fp80, so just turn it into a store+load
+
+  // Use the custom lowering.
+  if (const auto Res = LowerBitcast(N)) {
+    return Res;
+  }
+
+  // If it fails fall back to the default method
   return CreateStackStoreLoad(N->getOperand(0), N->getValueType(0));
 }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index b6abad830c371e..8df3e5ec163e8f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -910,6 +910,39 @@ SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
   return DAG.getLoad(DestVT, dl, Store, StackPtr, MachinePointerInfo(), Align);
 }
 
+static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT,
+                            SDValue Value) {
+  if (Value->getValueType(0) == VT)
+    return Value;
+  return DAG.getNode(ISD::BITCAST, DL, VT, Value);
+}
+
+SDValue DAGTypeLegalizer::LowerBitcast(SDNode *Node) const {
+  assert(Node->getOpcode() == ISD::BITCAST ||
+         Node->getOpcode() == ISD::FP_ROUND && "Unexpected opcode!");
+  // Handle bitcasting from v2i8 without hitting the default promotion
+  // strategy which goes through stack memory.
+  EVT FromVT = Node->getOperand(0)->getValueType(0);
+  if (FromVT != MVT::v2i8) {
+    return SDValue();
+  }
+
+  // Pack vector elements into i16 and bitcast to final type
+  SDLoc DL(Node);
+  SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
+                             Node->getOperand(0), DAG.getIntPtrConstant(0, DL));
+  SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
+                             Node->getOperand(0), DAG.getIntPtrConstant(1, DL));
+  SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
+  SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
+  SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
+  SDValue AsInt = DAG.getNode(
+      ISD::OR, DL, MVT::i16,
+      {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
+  EVT ToVT = Node->getValueType(0);
+  return MaybeBitcast(DAG, DL, ToVT, AsInt);
+}
+
 /// Replace the node's results with custom code provided by the target and
 /// return "true", or do nothing and return "false".
 /// The last parameter is FALSE if we are dealing with a node with legal
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 571a710cc92a34..30951112069ed5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -216,6 +216,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue BitConvertToInteger(SDValue Op);
   SDValue BitConvertVectorToIntegerVector(SDValue Op);
   SDValue CreateStackStoreLoad(SDValue Op, EVT DestVT);
+  SDValue LowerBitcast(SDNode *N) const;
   bool CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult);
   bool CustomWidenLowerNode(SDNode *N, EVT VT);
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 5c1f717694a4c7..2eaeb624004730 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2086,30 +2086,6 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
-SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
-  // Handle bitcasting from v2i8 without hitting the default promotion
-  // strategy which goes through stack memory.
-  EVT FromVT = Op->getOperand(0)->getValueType(0);
-  if (FromVT != MVT::v2i8) {
-    return Op;
-  }
-
-  // Pack vector elements into i16 and bitcast to final type
-  SDLoc DL(Op);
-  SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
-                             Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
-  SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
-                             Op->getOperand(0), DAG.getIntPtrConstant(1, DL));
-  SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
-  SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
-  SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
-  SDValue AsInt = DAG.getNode(
-      ISD::OR, DL, MVT::i16,
-      {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
-  EVT ToVT = Op->getValueType(0);
-  return MaybeBitcast(DAG, DL, ToVT, AsInt);
-}
-
 // We can init constant f16x2/v2i16/v4i8 with a single .b32 move.  Normally it
 // would get lowered as two constant loads and vector-packing move.
 // Instead we want just a constant move:
@@ -2619,7 +2595,7 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   case ISD::BUILD_VECTOR:
     return LowerBUILD_VECTOR(Op, DAG);
   case ISD::BITCAST:
-    return LowerBITCAST(Op, DAG);
+    return SDValue();
   case ISD::EXTRACT_SUBVECTOR:
     return Op;
   case ISD::EXTRACT_VECTOR_ELT:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 4a98fe21b81dc6..446ff1536d36cf 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -265,8 +265,6 @@ class NVPTXTargetLowering : public TargetLowering {
   const NVPTXSubtarget &STI; // cache the subtarget here
   SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const;
 
-  SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;
-
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;

>From 3320d5585b2b37df05f4dcd54cf9ae11aba42e00 Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Sun, 22 Dec 2024 15:12:47 +0000
Subject: [PATCH 2/6] Removed redundant assert check

---
 llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp | 1 -
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp        | 3 +--
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 8a6bfc0c66cd82..bcb59e3c2aef3e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2174,7 +2174,6 @@ SDValue DAGTypeLegalizer::PromoteIntOp_ATOMIC_STORE(AtomicSDNode *N) {
 }
 
 SDValue DAGTypeLegalizer::PromoteIntOp_BITCAST(SDNode *N) {
-
   // Use the custom lowering.
   if (const auto Res = LowerBitcast(N)) {
     return Res;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index 8df3e5ec163e8f..4aecf667b2cee1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -918,8 +918,7 @@ static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT,
 }
 
 SDValue DAGTypeLegalizer::LowerBitcast(SDNode *Node) const {
-  assert(Node->getOpcode() == ISD::BITCAST ||
-         Node->getOpcode() == ISD::FP_ROUND && "Unexpected opcode!");
+  assert(Node->getOpcode() == ISD::BITCAST && "Unexpected opcode!");
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
   EVT FromVT = Node->getOperand(0)->getValueType(0);

>From d7cb1339321d41b6f489450f16ea529eac194889 Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Sun, 22 Dec 2024 16:54:58 +0000
Subject: [PATCH 3/6] Addressed Most Code Review Comments

---
 .../SelectionDAG/LegalizeIntegerTypes.cpp     |  3 +--
 .../CodeGen/SelectionDAG/LegalizeTypes.cpp    | 22 ++++++++-----------
 2 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index bcb59e3c2aef3e..05cbcf3297ac3d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2175,9 +2175,8 @@ SDValue DAGTypeLegalizer::PromoteIntOp_ATOMIC_STORE(AtomicSDNode *N) {
 
 SDValue DAGTypeLegalizer::PromoteIntOp_BITCAST(SDNode *N) {
   // Use the custom lowering.
-  if (const auto Res = LowerBitcast(N)) {
+  if (SDValue Res = LowerBitcast(N))
     return Res;
-  }
 
   // If it fails fall back to the default method
   return CreateStackStoreLoad(N->getOperand(0), N->getValueType(0));
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index 4aecf667b2cee1..b91530d6f0bb69 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -910,21 +910,13 @@ SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
   return DAG.getLoad(DestVT, dl, Store, StackPtr, MachinePointerInfo(), Align);
 }
 
-static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT,
-                            SDValue Value) {
-  if (Value->getValueType(0) == VT)
-    return Value;
-  return DAG.getNode(ISD::BITCAST, DL, VT, Value);
-}
-
 SDValue DAGTypeLegalizer::LowerBitcast(SDNode *Node) const {
   assert(Node->getOpcode() == ISD::BITCAST && "Unexpected opcode!");
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
   EVT FromVT = Node->getOperand(0)->getValueType(0);
-  if (FromVT != MVT::v2i8) {
+  if (FromVT != MVT::v2i8)
     return SDValue();
-  }
 
   // Pack vector elements into i16 and bitcast to final type
   SDLoc DL(Node);
@@ -932,14 +924,18 @@ SDValue DAGTypeLegalizer::LowerBitcast(SDNode *Node) const {
                              Node->getOperand(0), DAG.getIntPtrConstant(0, DL));
   SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
                              Node->getOperand(0), DAG.getIntPtrConstant(1, DL));
+  
   SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
   SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
-  SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
+  
+  EVT ShiftAmtTy = TLI.getShiftAmountTy(Extend1.getValueType(), DAG.getDataLayout());
+  SDValue ShiftConst = DAG.getShiftAmountConstant(8, ShiftAmtTy, DL);
   SDValue AsInt = DAG.getNode(
-      ISD::OR, DL, MVT::i16,
-      {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
+      ISD::OR, DL, MVT::i16, Extend0,
+      DAG.getNode(ISD::SHL, DL, Extend1.getValueType(), Extend1, ShiftConst));
   EVT ToVT = Node->getValueType(0);
-  return MaybeBitcast(DAG, DL, ToVT, AsInt);
+  
+  return DAG.getBitcast( ToVT, AsInt);
 }
 
 /// Replace the node's results with custom code provided by the target and

>From b67448702e3fde5b94ceaeed6d3a78b75cf248da Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Sun, 22 Dec 2024 17:00:01 +0000
Subject: [PATCH 4/6] Applied code formatting

---
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index b91530d6f0bb69..8f42877bcb8b66 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -924,18 +924,19 @@ SDValue DAGTypeLegalizer::LowerBitcast(SDNode *Node) const {
                              Node->getOperand(0), DAG.getIntPtrConstant(0, DL));
   SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
                              Node->getOperand(0), DAG.getIntPtrConstant(1, DL));
-  
+
   SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
   SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
-  
-  EVT ShiftAmtTy = TLI.getShiftAmountTy(Extend1.getValueType(), DAG.getDataLayout());
+
+  EVT ShiftAmtTy =
+      TLI.getShiftAmountTy(Extend1.getValueType(), DAG.getDataLayout());
   SDValue ShiftConst = DAG.getShiftAmountConstant(8, ShiftAmtTy, DL);
   SDValue AsInt = DAG.getNode(
       ISD::OR, DL, MVT::i16, Extend0,
       DAG.getNode(ISD::SHL, DL, Extend1.getValueType(), Extend1, ShiftConst));
   EVT ToVT = Node->getValueType(0);
-  
-  return DAG.getBitcast( ToVT, AsInt);
+
+  return DAG.getBitcast(ToVT, AsInt);
 }
 
 /// Replace the node's results with custom code provided by the target and

>From f8dadde2fa135391b5bd60f8629546f3a639a228 Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Mon, 23 Dec 2024 17:23:33 +0000
Subject: [PATCH 5/6] Generalised bit packing and unpacking

---
 .../SelectionDAG/LegalizeIntegerTypes.cpp     |  14 ++-
 .../CodeGen/SelectionDAG/LegalizeTypes.cpp    | 112 ++++++++++++++----
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h |   4 +-
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |  37 ------
 4 files changed, 100 insertions(+), 67 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 05cbcf3297ac3d..0eaf2a5dc44f1c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -474,7 +474,13 @@ SDValue DAGTypeLegalizer::PromoteIntRes_BITCAST(SDNode *N) {
 
   switch (getTypeAction(InVT)) {
   case TargetLowering::TypeLegal:
-    break;
+    // Try and use in-register bitcast
+     if (SDValue Res = LowerBitcastInRegister(N))
+       return DAG.getNode(ISD::ANY_EXTEND, dl, NOutVT,
+                          Res);
+     // Fallback to stack load store
+     break;
+    
   case TargetLowering::TypePromoteInteger:
     if (NOutVT.bitsEq(NInVT) && !NOutVT.isVector() && !NInVT.isVector())
       // The input promotes to the same size.  Convert the promoted value.
@@ -2174,11 +2180,11 @@ SDValue DAGTypeLegalizer::PromoteIntOp_ATOMIC_STORE(AtomicSDNode *N) {
 }
 
 SDValue DAGTypeLegalizer::PromoteIntOp_BITCAST(SDNode *N) {
-  // Use the custom lowering.
-  if (SDValue Res = LowerBitcast(N))
+  // Try to use an in-register bitcast.
+  if (SDValue Res = LowerBitcastInRegister(N))
     return Res;
 
-  // If it fails fall back to the default method
+  // Fallback
   return CreateStackStoreLoad(N->getOperand(0), N->getValueType(0));
 }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index 8f42877bcb8b66..8220b9a9ffc9fd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -910,33 +910,95 @@ SDValue DAGTypeLegalizer::CreateStackStoreLoad(SDValue Op,
   return DAG.getLoad(DestVT, dl, Store, StackPtr, MachinePointerInfo(), Align);
 }
 
-SDValue DAGTypeLegalizer::LowerBitcast(SDNode *Node) const {
-  assert(Node->getOpcode() == ISD::BITCAST && "Unexpected opcode!");
-  // Handle bitcasting from v2i8 without hitting the default promotion
-  // strategy which goes through stack memory.
-  EVT FromVT = Node->getOperand(0)->getValueType(0);
-  if (FromVT != MVT::v2i8)
+SDValue DAGTypeLegalizer::PackBitcastInRegister(SDNode *N) const {
+  assert(N->getOpcode() == ISD::BITCAST && "Unexpected opcode!");
+
+  EVT FromVT = N->getOperand(0)->getValueType(0);
+  EVT ToVT = N->getValueType(0);
+
+  if (!FromVT.isVector() || !ToVT.isInteger())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  // Get the number of elements we need to pack into the integer
+  unsigned NumElems = FromVT.getVectorNumElements();
+  EVT ElemVT = FromVT.getVectorElementType();
+  unsigned ElemBits = ElemVT.getSizeInBits();
+
+  EVT PackVT = EVT::getIntegerVT(*DAG.getContext(), ElemBits * NumElems);
+  SDValue Packed = DAG.getConstant(0, DL, PackVT);
+
+  // Determine endianness
+  bool IsBigEndian = DAG.getDataLayout().isBigEndian();
+
+  for (unsigned I = 0; I < NumElems; ++I) {
+    unsigned ElementIndex = IsBigEndian ? (NumElems - 1 - I) : I;
+    SDValue Elem =
+        DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ElemVT, N->getOperand(0),
+                    DAG.getIntPtrConstant(ElementIndex, DL));
+    SDValue ExtElem = DAG.getNode(ISD::ZERO_EXTEND, DL, PackVT, Elem);
+    SDValue ShiftAmount = DAG.getShiftAmountConstant(ElemBits * I, PackVT, DL);
+    SDValue ShiftedElem =
+        DAG.getNode(ISD::SHL, DL, PackVT, ExtElem, ShiftAmount);
+
+    Packed = DAG.getNode(ISD::OR, DL, PackVT, Packed, ShiftedElem);
+  }
+
+  return DAG.getBitcast(ToVT, Packed);
+}
+
+
+SDValue DAGTypeLegalizer::UnpackBitcastInRegister(SDNode *N) const {
+  assert(N->getOpcode() == ISD::BITCAST && "Unexpected opcode!");
+  EVT FromVT = N->getOperand(0)->getValueType(0);
+  EVT ToVT = N->getValueType(0);
+
+  if (!FromVT.isInteger() || !ToVT.isVector())
     return SDValue();
 
-  // Pack vector elements into i16 and bitcast to final type
-  SDLoc DL(Node);
-  SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
-                             Node->getOperand(0), DAG.getIntPtrConstant(0, DL));
-  SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
-                             Node->getOperand(0), DAG.getIntPtrConstant(1, DL));
-
-  SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
-  SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
-
-  EVT ShiftAmtTy =
-      TLI.getShiftAmountTy(Extend1.getValueType(), DAG.getDataLayout());
-  SDValue ShiftConst = DAG.getShiftAmountConstant(8, ShiftAmtTy, DL);
-  SDValue AsInt = DAG.getNode(
-      ISD::OR, DL, MVT::i16, Extend0,
-      DAG.getNode(ISD::SHL, DL, Extend1.getValueType(), Extend1, ShiftConst));
-  EVT ToVT = Node->getValueType(0);
-
-  return DAG.getBitcast(ToVT, AsInt);
+  SDLoc DL(N);
+
+  unsigned NumElems = ToVT.getVectorNumElements();
+  EVT ElemVT = ToVT.getVectorElementType();
+  unsigned ElemBits = ElemVT.getSizeInBits();
+
+  // Ensure the integer has enough bits
+  unsigned PackedBits = FromVT.getSizeInBits();
+  assert(PackedBits >= ElemBits * NumElems &&
+         "Packed type does not have enough bits to represent the vector!");
+
+  // Determine endianness
+  bool IsBigEndian = DAG.getDataLayout().isBigEndian();
+
+  // Hold all the vector elements
+  SmallVector<SDValue, 8> Elements;
+  Elements.reserve(NumElems);
+
+  for (unsigned I = 0; I < NumElems; ++I) {
+    unsigned ElementIndex = IsBigEndian ? (NumElems - 1 - I) : I;
+    unsigned ShiftAmountVal = ElemBits * ElementIndex;
+
+    SDValue ShiftAmount =
+        DAG.getShiftAmountConstant(ShiftAmountVal, FromVT, DL);
+    SDValue Shifted =
+        DAG.getNode(ISD::SRL, DL, FromVT, N->getOperand(0), ShiftAmount);
+    SDValue Element = DAG.getNode(ISD::TRUNCATE, DL, ElemVT, Shifted);
+    Elements.push_back(Element);
+  }
+
+  return DAG.getBuildVector(ToVT, DL, Elements);
+}
+
+
+SDValue DAGTypeLegalizer::LowerBitcastInRegister(SDNode *N) const {
+  // Try the pack, if we aren't going from vector -> scalar it will backout immediately.
+  if (SDValue Res = PackBitcastInRegister(N)) {
+    return Res;
+  }
+
+  // If we get here then try and unpack the bitcast
+  return UnpackBitcastInRegister(N);
 }
 
 /// Replace the node's results with custom code provided by the target and
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 30951112069ed5..dd45b1e2f10896 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -216,7 +216,9 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue BitConvertToInteger(SDValue Op);
   SDValue BitConvertVectorToIntegerVector(SDValue Op);
   SDValue CreateStackStoreLoad(SDValue Op, EVT DestVT);
-  SDValue LowerBitcast(SDNode *N) const;
+  SDValue PackBitcastInRegister(SDNode *N) const;
+  SDValue UnpackBitcastInRegister(SDNode *N) const;
+  SDValue LowerBitcastInRegister(SDNode *N) const;
   bool CustomLowerNode(SDNode *N, EVT VT, bool LegalizeResult);
   bool CustomWidenLowerNode(SDNode *N, EVT VT);
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 2eaeb624004730..7d06139120d712 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -472,13 +472,6 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
   return VectorInfo;
 }
 
-static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT,
-                            SDValue Value) {
-  if (Value->getValueType(0) == VT)
-    return Value;
-  return DAG.getNode(ISD::BITCAST, DL, VT, Value);
-}
-
 // NVPTXTargetLowering Constructor.
 NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
                                          const NVPTXSubtarget &STI)
@@ -622,9 +615,6 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
 
-  // Custom conversions to/from v2i8.
-  setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
-
   // Only logical ops can be done on v4i8 directly, others must be done
   // elementwise.
   setOperationAction(
@@ -2594,8 +2584,6 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return Op;
   case ISD::BUILD_VECTOR:
     return LowerBUILD_VECTOR(Op, DAG);
-  case ISD::BITCAST:
-    return SDValue();
   case ISD::EXTRACT_SUBVECTOR:
     return Op;
   case ISD::EXTRACT_VECTOR_ELT:
@@ -5178,28 +5166,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
   return SDValue();
 }
 
-static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
-                           SmallVectorImpl<SDValue> &Results) {
-  // Handle bitcasting to v2i8 without hitting the default promotion
-  // strategy which goes through stack memory.
-  SDValue Op(Node, 0);
-  EVT ToVT = Op->getValueType(0);
-  if (ToVT != MVT::v2i8) {
-    return;
-  }
-
-  // Bitcast to i16 and unpack elements into a vector
-  SDLoc DL(Node);
-  SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0));
-  SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
-  SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
-  SDValue Vec1 =
-      DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
-                  DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8}));
-  Results.push_back(
-      DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
-}
-
 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
                               SmallVectorImpl<SDValue> &Results) {
@@ -5435,9 +5401,6 @@ void NVPTXTargetLowering::ReplaceNodeResults(
   switch (N->getOpcode()) {
   default:
     report_fatal_error("Unhandled custom legalization");
-  case ISD::BITCAST:
-    ReplaceBITCAST(N, DAG, Results);
-    return;
   case ISD::LOAD:
     ReplaceLoadVector(N, DAG, Results);
     return;

>From a6ca08a5a571dd29c3465f2e2b8ece82853ab07b Mon Sep 17 00:00:00 2001
From: GrumpyPigSkin <oliver61 at live.co.uk>
Date: Mon, 23 Dec 2024 17:25:29 +0000
Subject: [PATCH 6/6] Formatting

---
 .../lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp | 11 +++++------
 llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp       |  5 ++---
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 0eaf2a5dc44f1c..15aca9e5a9d48e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -475,12 +475,11 @@ SDValue DAGTypeLegalizer::PromoteIntRes_BITCAST(SDNode *N) {
   switch (getTypeAction(InVT)) {
   case TargetLowering::TypeLegal:
     // Try and use in-register bitcast
-     if (SDValue Res = LowerBitcastInRegister(N))
-       return DAG.getNode(ISD::ANY_EXTEND, dl, NOutVT,
-                          Res);
-     // Fallback to stack load store
-     break;
-    
+    if (SDValue Res = LowerBitcastInRegister(N))
+      return DAG.getNode(ISD::ANY_EXTEND, dl, NOutVT, Res);
+    // Fall back to a stack store/load.
+    break;
+
   case TargetLowering::TypePromoteInteger:
     if (NOutVT.bitsEq(NInVT) && !NOutVT.isVector() && !NInVT.isVector())
       // The input promotes to the same size.  Convert the promoted value.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
index 8220b9a9ffc9fd..27393907fd36a0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.cpp
@@ -948,7 +948,6 @@ SDValue DAGTypeLegalizer::PackBitcastInRegister(SDNode *N) const {
   return DAG.getBitcast(ToVT, Packed);
 }
 
-
 SDValue DAGTypeLegalizer::UnpackBitcastInRegister(SDNode *N) const {
   assert(N->getOpcode() == ISD::BITCAST && "Unexpected opcode!");
   EVT FromVT = N->getOperand(0)->getValueType(0);
@@ -990,9 +989,9 @@ SDValue DAGTypeLegalizer::UnpackBitcastInRegister(SDNode *N) const {
   return DAG.getBuildVector(ToVT, DL, Elements);
 }
 
-
 SDValue DAGTypeLegalizer::LowerBitcastInRegister(SDNode *N) const {
-  // Try the pack, if we aren't going from vector -> scalar it will backout immediately.
+  // Try the pack; if we aren't going from vector -> scalar it will back out
+  // immediately.
   if (SDValue Res = PackBitcastInRegister(N)) {
     return Res;
   }



More information about the llvm-commits mailing list