[llvm] 182a65a - [RISCV] Refactor performCONCAT_VECTORSCombine. NFC (#69068)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 16 09:56:37 PDT 2023


Author: Luke Lau
Date: 2023-10-16T12:56:32-04:00
New Revision: 182a65adcf8af922246cac80ea6f3fdb159cd89e

URL: https://github.com/llvm/llvm-project/commit/182a65adcf8af922246cac80ea6f3fdb159cd89e
DIFF: https://github.com/llvm/llvm-project/commit/182a65adcf8af922246cac80ea6f3fdb159cd89e.diff

LOG: [RISCV] Refactor performCONCAT_VECTORSCombine. NFC (#69068)

Instead of doing a forward pass for positive strides and a reverse pass for
negative strides, we can just do one pass by negating the offset if the
pointers do happen to be in reverse order.

We can extend getPtrDiff later in #68726 to handle more constant offset
sequences.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index ed1f7b6c50a4d12..6eb253cc5146635 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13785,11 +13785,10 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   EVT BaseLdVT = BaseLd->getValueType(0);
-  SDValue BasePtr = BaseLd->getBasePtr();
 
   // Go through the loads and check that they're strided
-  SmallVector<SDValue> Ptrs;
-  Ptrs.push_back(BasePtr);
+  SmallVector<LoadSDNode *> Lds;
+  Lds.push_back(BaseLd);
   Align Align = BaseLd->getAlign();
   for (SDValue Op : N->ops().drop_front()) {
     auto *Ld = dyn_cast<LoadSDNode>(Op);
@@ -13798,60 +13797,38 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
         Ld->getValueType(0) != BaseLdVT)
       return SDValue();
 
-    Ptrs.push_back(Ld->getBasePtr());
+    Lds.push_back(Ld);
 
     // The common alignment is the most restrictive (smallest) of all the loads
     Align = std::min(Align, Ld->getAlign());
   }
 
-  auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == 0)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add LastPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()-1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
-  };
-  auto matchReverseStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == Ptrs.size() - 1)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add NextPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()+1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
+  using PtrDiff = std::pair<SDValue, bool>;
+  auto GetPtrDiff = [](LoadSDNode *Ld1,
+                       LoadSDNode *Ld2) -> std::optional<PtrDiff> {
+    SDValue P1 = Ld1->getBasePtr();
+    SDValue P2 = Ld2->getBasePtr();
+    if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1)
+      return {{P2.getOperand(1), false}};
+    if (P1.getOpcode() == ISD::ADD && P1.getOperand(0) == P2)
+      return {{P1.getOperand(1), true}};
+
+    return std::nullopt;
   };
 
-  bool Reversed = false;
-  SDValue Stride = matchForwardStrided(Ptrs);
-  if (!Stride) {
-    Stride = matchReverseStrided(Ptrs);
-    Reversed = true;
-    // TODO: At this point, we've successfully matched a generalized gather
-    // load.  Maybe we should emit that, and then move the specialized
-    // matchers above and below into a DAG combine?
-    if (!Stride)
+  // Get the distance between the first and second loads
+  auto BaseDiff = GetPtrDiff(Lds[0], Lds[1]);
+  if (!BaseDiff)
+    return SDValue();
+
+  // Check all the loads are the same distance apart
+  for (auto *It = Lds.begin() + 1; It != Lds.end() - 1; It++)
+    if (GetPtrDiff(*It, *std::next(It)) != BaseDiff)
       return SDValue();
-  }
+
+  // TODO: At this point, we've successfully matched a generalized gather
+  // load.  Maybe we should emit that, and then move the specialized
+  // matchers above and below into a DAG combine?
 
   // Get the widened scalar type, e.g. v4i8 -> i64
   unsigned WideScalarBitWidth =
@@ -13867,26 +13844,25 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   if (!TLI.isLegalStridedLoadStore(WideVecVT, Align))
     return SDValue();
 
+  auto [Stride, MustNegateStride] = *BaseDiff;
+  if (MustNegateStride)
+    Stride = DAG.getNegative(Stride, DL, Stride.getValueType());
+
   SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});
   SDValue IntID =
     DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
                           Subtarget.getXLenVT());
-  if (Reversed)
-    Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+
   SDValue AllOneMask =
     DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
                  DAG.getConstant(1, DL, MVT::i1));
 
-  SDValue Ops[] = {BaseLd->getChain(),
-                   IntID,
-                   DAG.getUNDEF(WideVecVT),
-                   BasePtr,
-                   Stride,
-                   AllOneMask};
+  SDValue Ops[] = {BaseLd->getChain(),   IntID,  DAG.getUNDEF(WideVecVT),
+                   BaseLd->getBasePtr(), Stride, AllOneMask};
 
   uint64_t MemSize;
   if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
-      ConstStride && !Reversed && ConstStride->getSExtValue() >= 0)
+      ConstStride && ConstStride->getSExtValue() >= 0)
     // total size = (elsize * n) + (stride - elsize) * (n-1)
     //            = elsize + stride * (n-1)
     MemSize = WideScalarVT.getSizeInBits() +


        


More information about the llvm-commits mailing list