[llvm] [AMDGPU] Accept arbitrary sized sources in CalculateByteProvider (PR #70240)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 25 11:39:19 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-amdgpu

Author: Jeffrey Byrnes (jrbyrnes)

<details>
<summary>Changes</summary>

Reland the original patch with an additional commit containing fixes for two issues:

1. Attempting to bitcast using MVTs with no corresponding LLVM type. getDWordFromOffset now works directly with the original vector to get the corresponding elements given the DWordOffset.
2. Improper bit tracking in CalculateByteProvider for vector types with certain ops. Previously, bit tracking for some ops (e.g. ISD::TRUNCATE) assumed scalar operands, which is incorrect because these ops have different semantics for vector and scalar types. CalculateByteProvider now bails out on vector-typed operands of these ops, and CalculateSrcByte stops recursing at vector-typed values and treats them as the byte source; fuller vector handling is left as a TODO.

The patch containing the fixes has not been reviewed yet, so I've kept the commits separate to make review easier. They will be landed atomically.

---

Patch is 95.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/70240.diff


6 Files Affected:

- (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+168-87) 
- (modified) llvm/test/CodeGen/AMDGPU/idot4u.ll (+1087) 
- (modified) llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2i16.ll (+6-9) 
- (modified) llvm/test/CodeGen/AMDGPU/load-hi16.ll (+18-18) 
- (modified) llvm/test/CodeGen/AMDGPU/permute.ll (+3-1) 
- (modified) llvm/test/CodeGen/AMDGPU/permute_i8.ll (+497) 


``````````diff
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index ff5d0e27277267b..fd0ffe47b440b0a 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -10834,10 +10834,12 @@ calculateSrcByte(const SDValue Op, uint64_t DestByte, uint64_t SrcIndex = 0,
   if (Depth >= 6)
     return std::nullopt;
 
-  auto ValueSize = Op.getValueSizeInBits();
-  if (ValueSize != 8 && ValueSize != 16 && ValueSize != 32)
+  if (Op.getValueSizeInBits() < 8)
     return std::nullopt;
 
+  if (Op.getValueType().isVector())
+    return ByteProvider<SDValue>::getSrc(Op, DestByte, SrcIndex, IsSigned);
+
   switch (Op->getOpcode()) {
   case ISD::TRUNCATE: {
     return calculateSrcByte(Op->getOperand(0), DestByte, SrcIndex, IsSigned,
@@ -10923,8 +10925,10 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
   if (Index > BitWidth / 8 - 1)
     return std::nullopt;
 
+  bool IsVec = Op.getValueType().isVector();
   switch (Op.getOpcode()) {
   case ISD::OR: {
+    if (IsVec) return std::nullopt;
     auto RHS = calculateByteProvider(Op.getOperand(1), Index, Depth + 1,
                                      StartingIndex, IsSigned);
     if (!RHS)
@@ -10945,6 +10949,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
   }
 
   case ISD::AND: {
+    if (IsVec) return std::nullopt;
     auto BitMaskOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
     if (!BitMaskOp)
       return std::nullopt;
@@ -10965,6 +10970,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
   }
 
   case ISD::FSHR: {
+    if (IsVec) return std::nullopt;
     // fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW))
     auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(2));
     if (!ShiftOp || Op.getValueType().isVector())
@@ -10990,6 +10996,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
 
   case ISD::SRA:
   case ISD::SRL: {
+    if (IsVec) return std::nullopt;
     auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
     if (!ShiftOp)
       return std::nullopt;
@@ -11015,6 +11022,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
   }
 
   case ISD::SHL: {
+    if (IsVec) return std::nullopt;
     auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
     if (!ShiftOp)
       return std::nullopt;
@@ -11039,6 +11047,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
   case ISD::SIGN_EXTEND_INREG:
   case ISD::AssertZext:
   case ISD::AssertSext: {
+    if (IsVec) return std::nullopt;
     SDValue NarrowOp = Op->getOperand(0);
     unsigned NarrowBitWidth = NarrowOp.getValueSizeInBits();
     if (Op->getOpcode() == ISD::SIGN_EXTEND_INREG ||
@@ -11069,6 +11078,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
   }
 
   case ISD::TRUNCATE: {
+    if (IsVec) return std::nullopt;
     uint64_t NarrowByteWidth = BitWidth / 8;
 
     if (NarrowByteWidth >= Index) {
@@ -11115,9 +11125,11 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     return std::nullopt;
   }
 
-  case ISD::BSWAP:
+  case ISD::BSWAP: {
+    if (IsVec) return std::nullopt;
     return calculateByteProvider(Op->getOperand(0), BitWidth / 8 - Index - 1,
                                  Depth + 1, StartingIndex, IsSigned);
+  }
 
   case ISD::EXTRACT_VECTOR_ELT: {
     auto IdxOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
@@ -11126,8 +11138,6 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
     auto VecIdx = IdxOp->getZExtValue();
     auto ScalarSize = Op.getScalarValueSizeInBits();
     if (ScalarSize != 32) {
-      if ((VecIdx + 1) * ScalarSize > 32)
-        return std::nullopt;
       Index = ScalarSize == 8 ? VecIdx : VecIdx * 2 + Index;
     }
 
@@ -11136,6 +11146,7 @@ calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth,
   }
 
   case AMDGPUISD::PERM: {
+    if (IsVec) return std::nullopt;
     auto PermMask = dyn_cast<ConstantSDNode>(Op->getOperand(2));
     if (!PermMask)
       return std::nullopt;
@@ -11213,9 +11224,6 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
   int Low16 = PermMask & 0xffff;
   int Hi16 = (PermMask & 0xffff0000) >> 16;
 
-  assert(Op.getValueType().isByteSized());
-  assert(OtherOp.getValueType().isByteSized());
-
   auto TempOp = peekThroughBitcasts(Op);
   auto TempOtherOp = peekThroughBitcasts(OtherOp);
 
@@ -11233,15 +11241,66 @@ static bool hasNon16BitAccesses(uint64_t PermMask, SDValue &Op,
   return !addresses16Bits(Low16) || !addresses16Bits(Hi16);
 }
 
+static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,
+                                  unsigned DWordOffset) {
+  // Return bits [32 * DWordOffset, 32 * DWordOffset + 32) of Src as an i32.
+  auto ValueSize = Src.getValueSizeInBits().getFixedValue();
+  // ByteProvider must be at least 8 bits
+  assert(!(ValueSize % 8));
+
+  if (ValueSize <= 32)
+    return DAG.getBitcastedAnyExtOrTrunc(Src, SL, MVT::i32);
+
+  LLVMContext &Ctx = *DAG.getContext();
+  if (Src.getValueType().isVector()) {
+    EVT EltVT = Src.getValueType().getVectorElementType();
+    auto BaseSize = Src.getScalarValueSizeInBits();
+    if (BaseSize >= 32) {
+      // Extract the element holding the DWord with its real type (EXTRACT
+      // cannot narrow), then view it as an integer so it can be shifted.
+      SDValue Elt = DAG.getNode(
+          ISD::EXTRACT_VECTOR_ELT, SL, EltVT, Src,
+          DAG.getConstant(DWordOffset / (BaseSize / 32), SL, MVT::i32));
+      Elt = DAG.getBitcast(EVT::getIntegerVT(Ctx, BaseSize), Elt);
+      auto ShiftVal = 32 * (DWordOffset % (BaseSize / 32));
+      if (ShiftVal)
+        Elt = DAG.getNode(ISD::SRL, SL, Elt.getValueType(), Elt,
+                          DAG.getConstant(ShiftVal, SL, MVT::i32));
+      return DAG.getBitcastedAnyExtOrTrunc(Elt, SL, MVT::i32);
+    }
+
+    // Sub-DWord elements: pack the elements that make up this DWord. The
+    // last DWord may be partial when the vector is not a multiple of 32 bits.
+    auto NumElements = ValueSize / BaseSize;
+    auto Trunc32Elements = (BaseSize * NumElements) / 32;
+    auto NormalizedTrunc = Trunc32Elements * 32 / BaseSize;
+    auto NumElementsIn32 = 32 / BaseSize;
+    auto NumAvailElements = DWordOffset < Trunc32Elements
+                                ? NumElementsIn32
+                                : NumElements - NormalizedTrunc;
+    SmallVector<SDValue, 4> VecSrcs;
+    DAG.ExtractVectorElements(Src, VecSrcs, DWordOffset * NumElementsIn32,
+                              NumAvailElements);
+    // Keep the source element type: BUILD_VECTOR operands must match it.
+    SDValue Packed = DAG.getBuildVector(
+        EVT::getVectorVT(Ctx, EltVT, NumAvailElements), SL, VecSrcs);
+    return DAG.getBitcastedAnyExtOrTrunc(Packed, SL, MVT::i32);
+  }
+
+  // Scalar: shift in an integer type of the same width (SRL needs integers).
+  SDValue IntSrc = DAG.getBitcast(EVT::getIntegerVT(Ctx, ValueSize), Src);
+  SDValue Ret = DAG.getNode(ISD::SRL, SL, IntSrc.getValueType(), IntSrc,
+                            DAG.getConstant(32 * DWordOffset, SL, MVT::i32));
+  return DAG.getBitcastedAnyExtOrTrunc(Ret, SL, MVT::i32);
+}
+
 static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   SelectionDAG &DAG = DCI.DAG;
   EVT VT = N->getValueType(0);
-
-  if (VT != MVT::i32)
-    return SDValue();
+  SmallVector<ByteProvider<SDValue>, 8> PermNodes;
 
   // VT is known to be MVT::i32, so we need to provide 4 bytes.
-  SmallVector<ByteProvider<SDValue>, 8> PermNodes;
+  assert(VT == MVT::i32);
   for (int i = 0; i < 4; i++) {
     // Find the ByteProvider that provides the ith byte of the result of OR
     std::optional<ByteProvider<SDValue>> P =
@@ -11249,14 +11308,14 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     // TODO support constantZero
     if (!P || P->isConstantZero())
       return SDValue();
-
+    
     PermNodes.push_back(*P);
   }
   if (PermNodes.size() != 4)
     return SDValue();
 
-  int FirstSrc = 0;
-  std::optional<int> SecondSrc;
+  std::pair<unsigned, unsigned> FirstSrc(0, PermNodes[0].SrcOffset / 4);
+  std::optional<std::pair<unsigned, unsigned>> SecondSrc;
   uint64_t PermMask = 0x00000000;
   for (size_t i = 0; i < PermNodes.size(); i++) {
     auto PermOp = PermNodes[i];
@@ -11264,33 +11323,31 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     // by sizeof(Src2) = 4
     int SrcByteAdjust = 4;
 
-    if (!PermOp.hasSameSrc(PermNodes[FirstSrc])) {
-      if (SecondSrc.has_value())
-        if (!PermOp.hasSameSrc(PermNodes[*SecondSrc]))
+    // If the Src uses a byte from a different DWORD, then it corresponds
+    // with a difference source
+    if (!PermOp.hasSameSrc(PermNodes[FirstSrc.first]) ||
+        ((PermOp.SrcOffset / 4) != FirstSrc.second)) {
+      if (SecondSrc)
+        if (!PermOp.hasSameSrc(PermNodes[SecondSrc->first]) ||
+            ((PermOp.SrcOffset / 4) != SecondSrc->second))
           return SDValue();
 
       // Set the index of the second distinct Src node
-      SecondSrc = i;
-      assert(!(PermNodes[*SecondSrc].Src->getValueSizeInBits() % 8));
+      SecondSrc = {i, PermNodes[i].SrcOffset / 4};
+      assert(!(PermNodes[SecondSrc->first].Src->getValueSizeInBits() % 8));
       SrcByteAdjust = 0;
     }
-    assert(PermOp.SrcOffset + SrcByteAdjust < 8);
+    assert((PermOp.SrcOffset % 4) + SrcByteAdjust < 8);
     assert(!DAG.getDataLayout().isBigEndian());
-    PermMask |= (PermOp.SrcOffset + SrcByteAdjust) << (i * 8);
+    PermMask |= ((PermOp.SrcOffset % 4) + SrcByteAdjust) << (i * 8);
   }
-
-  SDValue Op = *PermNodes[FirstSrc].Src;
-  SDValue OtherOp = SecondSrc.has_value() ? *PermNodes[*SecondSrc].Src
-                                          : *PermNodes[FirstSrc].Src;
-
-  // Check that we haven't just recreated the same FSHR node.
-  if (N->getOpcode() == ISD::FSHR &&
-      (N->getOperand(0) == Op || N->getOperand(0) == OtherOp) &&
-      (N->getOperand(1) == Op || N->getOperand(1) == OtherOp))
-    return SDValue();
+  SDLoc DL(N);
+  SDValue Op = *PermNodes[FirstSrc.first].Src;
+  Op = getDWordFromOffset(DAG, DL, Op, FirstSrc.second);
+  assert(Op.getValueSizeInBits() == 32);
 
   // Check that we are not just extracting the bytes in order from an op
-  if (Op == OtherOp && Op.getValueSizeInBits() == 32) {
+  if (!SecondSrc) {
     int Low16 = PermMask & 0xffff;
     int Hi16 = (PermMask & 0xffff0000) >> 16;
 
@@ -11302,8 +11359,16 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
       return DAG.getBitcast(MVT::getIntegerVT(32), Op);
   }
 
+  SDValue OtherOp =
+      SecondSrc.has_value() ? *PermNodes[SecondSrc->first].Src : Op;
+
+  if (SecondSrc) {
+    OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
+    assert(OtherOp.getValueSizeInBits() == 32);
+  }
+
   if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {
-    SDLoc DL(N);
+
     assert(Op.getValueType().isByteSized() &&
            OtherOp.getValueType().isByteSized());
 
@@ -11318,7 +11383,6 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     return DAG.getNode(AMDGPUISD::PERM, DL, MVT::i32, Op, OtherOp,
                        DAG.getConstant(PermMask, DL, MVT::i32));
   }
-
   return SDValue();
 }
 
@@ -12794,17 +12858,24 @@ static unsigned addPermMasks(unsigned First, unsigned Second) {
   return (FirstNoCs | SecondNoCs) | (FirstCs & SecondCs);
 }
 
+struct DotSrc { // One v_dot4 operand source and how its bytes are selected.
+  SDValue SrcOp;       // Source value; may be wider than 32 bits.
+  int64_t PermMask;    // Byte-select mask applied to the chosen DWord.
+  int64_t DWordOffset; // Index of the 32-bit DWord of SrcOp being permuted.
+};
+
 static void placeSources(ByteProvider<SDValue> &Src0,
                          ByteProvider<SDValue> &Src1,
-                         SmallVectorImpl<std::pair<SDValue, unsigned>> &Src0s,
-                         SmallVectorImpl<std::pair<SDValue, unsigned>> &Src1s,
-                         int Step) {
+                         SmallVectorImpl<DotSrc> &Src0s,
+                         SmallVectorImpl<DotSrc> &Src1s, int Step) {
 
   assert(Src0.Src.has_value() && Src1.Src.has_value());
   // Src0s and Src1s are empty, just place arbitrarily.
   if (Step == 0) {
-    Src0s.push_back({*Src0.Src, (Src0.SrcOffset << 24) + 0x0c0c0c});
-    Src1s.push_back({*Src1.Src, (Src1.SrcOffset << 24) + 0x0c0c0c});
+    Src0s.push_back({*Src0.Src, ((Src0.SrcOffset % 4) << 24) + 0x0c0c0c,
+                     Src0.SrcOffset / 4});
+    Src1s.push_back({*Src1.Src, ((Src1.SrcOffset % 4) << 24) + 0x0c0c0c,
+                     Src1.SrcOffset / 4});
     return;
   }
 
@@ -12817,38 +12888,38 @@ static void placeSources(ByteProvider<SDValue> &Src0,
     unsigned FMask = 0xFF << (8 * (3 - Step));
 
     unsigned FirstMask =
-        BPP.first.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
+        (BPP.first.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
     unsigned SecondMask =
-        BPP.second.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
+        (BPP.second.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask);
     // Attempt to find Src vector which contains our SDValue, if so, add our
     // perm mask to the existing one. If we are unable to find a match for the
     // first SDValue, attempt to find match for the second.
     int FirstGroup = -1;
     for (int I = 0; I < 2; I++) {
-      SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
-          I == 0 ? Src0s : Src1s;
-      auto MatchesFirst = [&BPP](std::pair<SDValue, unsigned> IterElt) {
-        return IterElt.first == *BPP.first.Src;
+      SmallVectorImpl<DotSrc> &Srcs = I == 0 ? Src0s : Src1s;
+      auto MatchesFirst = [&BPP](DotSrc &IterElt) {
+        return IterElt.SrcOp == *BPP.first.Src &&
+               (IterElt.DWordOffset == (BPP.first.SrcOffset / 4));
       };
 
       auto Match = llvm::find_if(Srcs, MatchesFirst);
       if (Match != Srcs.end()) {
-        Match->second = addPermMasks(FirstMask, Match->second);
+        Match->PermMask = addPermMasks(FirstMask, Match->PermMask);
         FirstGroup = I;
         break;
       }
     }
     if (FirstGroup != -1) {
-      SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
-          FirstGroup == 1 ? Src0s : Src1s;
-      auto MatchesSecond = [&BPP](std::pair<SDValue, unsigned> IterElt) {
-        return IterElt.first == *BPP.second.Src;
+      SmallVectorImpl<DotSrc> &Srcs = FirstGroup == 1 ? Src0s : Src1s;
+      auto MatchesSecond = [&BPP](DotSrc &IterElt) {
+        return IterElt.SrcOp == *BPP.second.Src &&
+               (IterElt.DWordOffset == (BPP.second.SrcOffset / 4));
       };
       auto Match = llvm::find_if(Srcs, MatchesSecond);
       if (Match != Srcs.end()) {
-        Match->second = addPermMasks(SecondMask, Match->second);
+        Match->PermMask = addPermMasks(SecondMask, Match->PermMask);
       } else
-        Srcs.push_back({*BPP.second.Src, SecondMask});
+        Srcs.push_back({*BPP.second.Src, SecondMask, BPP.second.SrcOffset / 4});
       return;
     }
   }
@@ -12860,29 +12931,32 @@ static void placeSources(ByteProvider<SDValue> &Src0,
   unsigned FMask = 0xFF << (8 * (3 - Step));

   Src0s.push_back(
-      {*Src0.Src, (Src0.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
+      {*Src0.Src,
+       ((Src0.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
+       Src0.SrcOffset / 4}); // DWord index within Src0.
   Src1s.push_back(
-      {*Src1.Src, (Src1.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
+      {*Src1.Src,
+       ((Src1.SrcOffset % 4) << (8 * (3 - Step)) | (ZeroMask & ~FMask)),
+       Src1.SrcOffset / 4}); // DWord index within Src1.

   return;
 }
 
-static SDValue
-resolveSources(SelectionDAG &DAG, SDLoc SL,
-               SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
-               bool IsSigned, bool IsAny) {
+static SDValue resolveSources(SelectionDAG &DAG, SDLoc SL,
+                              SmallVectorImpl<DotSrc> &Srcs, bool IsSigned,
+                              bool IsAny) {
 
   // If we just have one source, just permute it accordingly.
   if (Srcs.size() == 1) {
     auto Elt = Srcs.begin();
-    auto EltVal = DAG.getBitcastedAnyExtOrTrunc(Elt->first, SL, MVT::i32);
+    auto EltOp = getDWordFromOffset(DAG, SL, Elt->SrcOp, Elt->DWordOffset);
 
-    // v_perm will produce the original value.
-    if (Elt->second == 0x3020100)
-      return EltVal;
+    // v_perm will produce the original value
+    if (Elt->PermMask == 0x3020100)
+      return EltOp;
 
-    return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
-                       DAG.getConstant(Elt->second, SL, MVT::i32));
+    return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp,
+                       DAG.getConstant(Elt->PermMask, SL, MVT::i32));
   }
 
   auto FirstElt = Srcs.begin();
@@ -12893,8 +12967,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
   // If we have multiple sources in the chain, combine them via perms (using
   // calculated perm mask) and Ors.
   while (true) {
-    auto FirstMask = FirstElt->second;
-    auto SecondMask = SecondElt->second;
+    auto FirstMask = FirstElt->PermMask;
+    auto SecondMask = SecondElt->PermMask;
 
     unsigned FirstCs = FirstMask & 0x0c0c0c0c;
     unsigned FirstPlusFour = FirstMask | 0x04040404;
@@ -12904,9 +12978,9 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
 
     auto PermMask = addPermMasks(FirstMask, SecondMask);
     auto FirstVal =
-        DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
+        getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
     auto SecondVal =
-        DAG.getBitcastedAnyExtOrTrunc(SecondElt->first, SL, MVT::i32);
+        getDWordFromOffset(DAG, SL, SecondElt->SrcOp, SecondElt->DWordOffset);
 
     Perms.push_back(DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, FirstVal,
                                 SecondVal,
@@ -12920,12 +12994,12 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
     // If we only have a FirstElt, then just combine that into the cumulative
     // source node.
     if (SecondElt == Srcs.end()) {
-      auto EltVal =
-          DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
+      auto EltOp =
+          getDWordFromOffset(DAG, SL, FirstElt->SrcOp, FirstElt->DWordOffset);
 
       Perms.push_back(
-          DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
-                      DAG.getConstant(FirstElt->second, SL, MVT::i32)));
+          DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltOp, EltOp,
+                      DAG.getConstant(FirstElt->PermMask, SL, MVT::i32)));
       break;
     }
   }
@@ -12936,9 +13010,8 @@ resolveSources(SelectionDAG &DAG, SDLoc SL,
              : Perms[0];
 }
 
-static void fixMasks(SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
-                     unsigned ChainLength) {
-  for (auto &[EntryVal, EntryMask] : Srcs) {
+static void fixMasks(SmallVectorImpl<DotSrc> &Srcs, unsigned ChainLength) {
+  for (auto &[EntryVal, EntryMask, EntryOffset] : Srcs) {
     EntryMask = EntryMask >> ((4 - ChainLength) * 8);
     auto ZeroMask = ChainLength == 2 ? 0x0c0c0000 : 0x0c000000;
     EntryMask += ZeroMask;
@@ -13003,8 +13076,8 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
       (Subtarget->hasDot1Insts() || Subtarget->hasDot8Insts())) {
     SDValue TempNode(N, 0);
     std::optional<bool> IsSigned;
-    SmallVector<std::pair<SDValue, unsigned>, 4> Src0s;
-    SmallVector<std::pair<SDValue, unsigned>, 4> Src1s;
+    SmallVector<DotSrc, 4> Src0s;
+    SmallVector<DotSrc, 4> Src1s;
     SmallVector<SDValue, 4> Src2s;
 
     // Match the v_dot4 tree, while collecting src nodes.
@@ -13082,11 +13155,11 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
     // (commutation).
     bool UseOriginalSrc = false;
     if (ChainLength == 4 && Src0s.size() == 1 && Src1s.size() == 1 &&
-        Src0s.begin()->second == Src1s.begin()->second &&
-        Src0s.begin()->first.getValueSizeInBits() == 32 &&
-        Src1s.begin()->first.getValueSizeInBits() == 32) {
+        Src0s.begin()->PermMask == Src1s.begin()->PermMask &&
+        Src0s.begin()->SrcOp.getValueSizeInBits() >= 32 &&
+        Src1s.begin()->SrcOp.getValueSizeInBits() >= 32) {
       SmallVector<unsigned, 4> SrcBytes;
-      auto Src0Mask = Src0s.begin()->second;
+      auto Src0Mask = Src0s.begin()->PermMask;
       SrcBytes.push_back(Src0Mask & 0xFF000000);
       bool UniqueEntries = true;
       for (auto I = 1; I < 4; I++) {
@@ -13101,11 +13174,19 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
 
       if (UniqueEntries) {
         UseOriginalSrc = true;
-        // Must be 32 bits to enter above conditional.
-        assert(Src0s.begin()->first.getValueSizeInBits() == 32);
-        assert(Src1s.begin()->first.getValueSizeInBits() == 32);
-        Src0 = DAG.getBitcast(MVT::getIntegerVT(32), Src0s.begin()->first);
-        Src1 = DAG.getBitcast(MVT::getIntegerVT(32)...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/70240


More information about the llvm-commits mailing list