[llvm] [RISCV] Refactor performCONCAT_VECTORSCombine. NFC (PR #69068)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Sat Oct 14 10:18:59 PDT 2023


https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/69068

Instead of doing a forward pass for positive strides and a reverse pass for
negative strides, we can do a single pass and negate the offset when the
pointers happen to be in reverse order.

We can extend getPtrDiff later in #68726 to handle more constant offset
sequences.
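
Roughly, the single-pass matching amounts to the following standalone sketch,
which models the base pointers as plain integer addresses rather than
SDValues (matchConstantStride and PtrAddrs are illustrative names only, not
part of the patch):

#include <cstdint>
#include <optional>
#include <vector>

// Walk the pointer list once: the difference between each adjacent pair of
// addresses must equal the first difference seen. A reversed sequence simply
// produces a negative stride, so no second (reverse) pass is needed.
static std::optional<int64_t>
matchConstantStride(const std::vector<int64_t> &PtrAddrs) {
  std::optional<int64_t> Stride;
  for (size_t I = 1; I < PtrAddrs.size(); ++I) {
    int64_t Diff = PtrAddrs[I] - PtrAddrs[I - 1]; // may be negative
    if (!Stride)
      Stride = Diff;       // first pair fixes the stride
    else if (Diff != *Stride)
      return std::nullopt; // not a constant stride
  }
  return Stride;
}

The actual combine below does the same walk over LoadSDNodes, with getPtrDiff
computing the (possibly negated) offset between adjacent base pointers.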


From 4e24ef483e739a314f046f8a6797091cfd2d11c6 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Sat, 14 Oct 2023 13:14:48 -0400
Subject: [PATCH] [RISCV] Refactor performCONCAT_VECTORSCombine. NFC

Instead of doing a forward pass for positive strides and a reverse pass for
negative strides, we can do a single pass and negate the offset when the
pointers happen to be in reverse order.

We can extend getPtrDiff later in #68726 to handle more constant offset
sequences.
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 81 +++++++--------------
 1 file changed, 25 insertions(+), 56 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d7552317fd8bc69..9912f19c9a50191 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13785,11 +13785,10 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   EVT BaseLdVT = BaseLd->getValueType(0);
-  SDValue BasePtr = BaseLd->getBasePtr();
 
   // Go through the loads and check that they're strided
-  SmallVector<SDValue> Ptrs;
-  Ptrs.push_back(BasePtr);
+  SmallVector<LoadSDNode *> Lds;
+  Lds.push_back(BaseLd);
   Align Align = BaseLd->getAlign();
   for (SDValue Op : N->ops().drop_front()) {
     auto *Ld = dyn_cast<LoadSDNode>(Op);
@@ -13798,58 +13797,33 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
         Ld->getValueType(0) != BaseLdVT)
       return SDValue();
 
-    Ptrs.push_back(Ld->getBasePtr());
+    Lds.push_back(Ld);
 
     // The common alignment is the most restrictive (smallest) of all the loads
     Align = std::min(Align, Ld->getAlign());
   }
 
-  auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == 0)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add LastPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()-1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
-  };
-  auto matchReverseStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == Ptrs.size() - 1)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add NextPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()+1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
+  auto getPtrDiff = [&DAG, &DL](LoadSDNode *Ld1, LoadSDNode *Ld2) {
+    SDValue P1 = Ld1->getBasePtr();
+    SDValue P2 = Ld2->getBasePtr();
+    if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1)
+      return P2.getOperand(1);
+    if (P1.getOpcode() == ISD::ADD && P1.getOperand(0) == P2)
+      return DAG.getNegative(P1.getOperand(1), DL,
+                             P1.getOperand(1).getValueType());
+    return SDValue();
   };
 
-  bool Reversed = false;
-  SDValue Stride = matchForwardStrided(Ptrs);
-  if (!Stride) {
-    Stride = matchReverseStrided(Ptrs);
-    Reversed = true;
-    // TODO: At this point, we've successfully matched a generalized gather
-    // load.  Maybe we should emit that, and then move the specialized
-    // matchers above and below into a DAG combine?
+  SDValue Stride;
+  for (auto [Idx, Ld] : enumerate(Lds)) {
+    if (Idx == 0)
+      continue;
+    SDValue Offset = getPtrDiff(Lds[Idx - 1], Ld);
+    if (!Offset)
+      return SDValue();
     if (!Stride)
+      Stride = Offset;
+    else if (Offset != Stride)
       return SDValue();
   }
 
@@ -13871,22 +13845,17 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   SDValue IntID =
     DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
                           Subtarget.getXLenVT());
-  if (Reversed)
-    Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+
   SDValue AllOneMask =
     DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
                  DAG.getConstant(1, DL, MVT::i1));
 
-  SDValue Ops[] = {BaseLd->getChain(),
-                   IntID,
-                   DAG.getUNDEF(WideVecVT),
-                   BasePtr,
-                   Stride,
-                   AllOneMask};
+  SDValue Ops[] = {BaseLd->getChain(),   IntID,  DAG.getUNDEF(WideVecVT),
+                   BaseLd->getBasePtr(), Stride, AllOneMask};
 
   uint64_t MemSize;
   if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
-      ConstStride && !Reversed && ConstStride->getSExtValue() >= 0)
+      ConstStride && ConstStride->getSExtValue() >= 0)
     // total size = (elsize * n) + (stride - elsize) * (n-1)
     //            = elsize + stride * (n-1)
     MemSize = WideScalarVT.getSizeInBits() +


