[llvm] c3b48ec - [RISCV] Match strided loads with reversed indexing sequences

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 22 07:59:57 PDT 2023


Author: Philip Reames
Date: 2023-08-22T07:59:49-07:00
New Revision: c3b48ec6ff92c47cd6136ee25360d31952a3adf2

URL: https://github.com/llvm/llvm-project/commit/c3b48ec6ff92c47cd6136ee25360d31952a3adf2
DIFF: https://github.com/llvm/llvm-project/commit/c3b48ec6ff92c47cd6136ee25360d31952a3adf2.diff

LOG: [RISCV] Match strided loads with reversed indexing sequences

This extends the transform which converts a concat_vector of loads into a strided load so that it also handles a reversed indexing pattern. The previous code expected each load's pointer to be the previous load's pointer plus a common stride S, i.e. addresses of the form (a0, a0+S, a0+2S, ...). However, we can also see the reversed form, where each load's pointer is the next load's pointer plus S, i.e. (aN + S*(n-1), ..., aN + 2*S, aN + S, aN). This form is a strided load starting at address aN + S*(n-1) with stride -S.
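To make the two pointer-chain shapes concrete, here is a minimal standalone sketch in plain C++ of the forward and reversed matching logic described above. It is illustrative only: it models each load pointer as an (add OtherPtr, Constant) record rather than using the SelectionDAG API, and the PtrExpr/matchForward/matchReverse names are made up for this example.

#include <cstdint>
#include <optional>
#include <vector>

struct PtrExpr {
  int AddOf = -1;      // index of the pointer this one is an (add X, C) of; -1 if none
  int64_t Offset = 0;  // the constant C
};

// Forward form: Ptrs[i] == (add Ptrs[i-1], S) for one common S.
static std::optional<int64_t> matchForward(const std::vector<PtrExpr> &Ptrs) {
  std::optional<int64_t> Stride;
  for (size_t I = 1; I < Ptrs.size(); ++I) {
    if (Ptrs[I].AddOf != static_cast<int>(I) - 1)
      return std::nullopt;
    if (!Stride)
      Stride = Ptrs[I].Offset;
    else if (*Stride != Ptrs[I].Offset)
      return std::nullopt;
  }
  return Stride;
}

// Reversed form: Ptrs[i] == (add Ptrs[i+1], S); the combined strided load then
// starts at Ptrs[0] (the highest address) with stride -S.
static std::optional<int64_t> matchReverse(const std::vector<PtrExpr> &Ptrs) {
  std::optional<int64_t> Stride;
  for (size_t I = 0; I + 1 < Ptrs.size(); ++I) {
    if (Ptrs[I].AddOf != static_cast<int>(I) + 1)
      return std::nullopt;
    if (!Stride)
      Stride = Ptrs[I].Offset;
    else if (*Stride != Ptrs[I].Offset)
      return std::nullopt;
  }
  return Stride;
}

For instance, in the reverse_strided_constant_pos_4xv2f32 test below, the concatenated loads read from x+192, x+128, x+64 and x, and each pointer other than the last is the next one plus 64; the reversed matcher accepts this with S=64, so the combination becomes a strided load with base x+192 and stride -64, which is exactly what the updated CHECK lines show (addi a0, a0, 192; li a2, -64).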

Note that this also fixes what looks to be a bug in the memory location reasoning for the forward strided case. A strided load with a negative stride accesses eltsize bytes past the base pointer, and then bytes *before* the base pointer. (That is, the accessed range should extend from before the base pointer to just after it.)
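As a sanity check on that range reasoning, here is a small self-contained C++ helper, illustrative only and not part of the patch (the stridedLoadRange name is made up), that computes the byte range a strided load touches relative to its base pointer:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>

// Byte range [Low, High) touched by a strided load of N elements of EltSize
// bytes each, where element i is read from offset i*Stride relative to the
// base pointer.
static std::pair<int64_t, int64_t> stridedLoadRange(int64_t Stride,
                                                    int64_t EltSize,
                                                    int64_t N) {
  assert(N >= 1 && EltSize >= 1);
  int64_t LastOff = Stride * (N - 1);           // offset of the last element
  int64_t Low = std::min<int64_t>(0, LastOff);  // < 0 for a negative stride
  int64_t High = std::max<int64_t>(0, LastOff) + EltSize;
  return {Low, High};
}

With a non-negative stride this gives Low == 0 and a size of EltSize + Stride*(N-1), matching the "total size" comment in the patch. With a negative stride, Low is negative: e.g. stridedLoadRange(-64, 8, 4) returns {-192, 8}, i.e. the access reaches 192 bytes below the base pointer, which is why the precise MemSize computation is now restricted to the non-reversed, non-negative constant-stride case.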

Differential Revision: https://reviews.llvm.org/D157886

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 536cbcda6da28d..8eaaa4d06136b2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12894,10 +12894,9 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   SDValue BasePtr = BaseLd->getBasePtr();
 
   // Go through the loads and check that they're strided
-  SDValue CurPtr = BasePtr;
-  SDValue Stride;
+  SmallVector<SDValue> Ptrs;
+  Ptrs.push_back(BasePtr);
   Align Align = BaseLd->getAlign();
-
   for (SDValue Op : N->ops().drop_front()) {
     auto *Ld = dyn_cast<LoadSDNode>(Op);
     if (!Ld || !Ld->isSimple() || !Op.hasOneUse() ||
@@ -12905,27 +12904,66 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
         Ld->getValueType(0) != BaseLdVT)
       return SDValue();
 
-    SDValue Ptr = Ld->getBasePtr();
-    // Check that each load's pointer is (add CurPtr, Stride)
-    if (Ptr.getOpcode() != ISD::ADD || Ptr.getOperand(0) != CurPtr)
-      return SDValue();
-    SDValue Offset = Ptr.getOperand(1);
-    if (!Stride)
-      Stride = Offset;
-    else if (Offset != Stride)
-      return SDValue();
+    Ptrs.push_back(Ld->getBasePtr());
 
     // The common alignment is the most restrictive (smallest) of all the loads
     Align = std::min(Align, Ld->getAlign());
+  }
 
-    CurPtr = Ptr;
+  auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
+    SDValue Stride;
+    for (auto Idx : enumerate(Ptrs)) {
+      if (Idx.index() == 0)
+        continue;
+      SDValue Ptr = Idx.value();
+      // Check that each load's pointer is (add LastPtr, Stride)
+      if (Ptr.getOpcode() != ISD::ADD ||
+          Ptr.getOperand(0) != Ptrs[Idx.index()-1])
+        return SDValue();
+      SDValue Offset = Ptr.getOperand(1);
+      if (!Stride)
+        Stride = Offset;
+      else if (Offset != Stride)
+        return SDValue();
+    }
+    return Stride;
+  };
+  auto matchReverseStrided = [](ArrayRef<SDValue> Ptrs) {
+    SDValue Stride;
+    for (auto Idx : enumerate(Ptrs)) {
+      if (Idx.index() == Ptrs.size() - 1)
+        continue;
+      SDValue Ptr = Idx.value();
+      // Check that each load's pointer is (add NextPtr, Stride)
+      if (Ptr.getOpcode() != ISD::ADD ||
+          Ptr.getOperand(0) != Ptrs[Idx.index()+1])
+        return SDValue();
+      SDValue Offset = Ptr.getOperand(1);
+      if (!Stride)
+        Stride = Offset;
+      else if (Offset != Stride)
+        return SDValue();
+    }
+    return Stride;
+  };
+
+  bool Reversed = false;
+  SDValue Stride = matchForwardStrided(Ptrs);
+  if (!Stride) {
+    Stride = matchReverseStrided(Ptrs);
+    Reversed = true;
+    // TODO: At this point, we've successfully matched a generalized gather
+    // load.  Maybe we should emit that, and then move the specialized
+    // matchers above and below into a DAG combine?
+    if (!Stride)
+      return SDValue();
   }
 
   // A special case is if the stride is exactly the width of one of the loads,
   // in which case it's contiguous and can be combined into a regular vle
   // without changing the element size
   if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
-      ConstStride &&
+      ConstStride && !Reversed &&
       ConstStride->getZExtValue() == BaseLdVT.getFixedSizeInBits() / 8) {
     MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
         BaseLd->getPointerInfo(), BaseLd->getMemOperand()->getFlags(),
@@ -12962,6 +13000,8 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
   SDValue IntID =
       DAG.getTargetConstant(Intrinsic::riscv_vlse, DL, Subtarget.getXLenVT());
+  if (Reversed)
+    Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
   SDValue Ops[] = {BaseLd->getChain(),
                    IntID,
                    DAG.getUNDEF(ContainerVT),
@@ -12970,7 +13010,8 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
                    VL};
 
   uint64_t MemSize;
-  if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride))
+  if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
+      ConstStride && !Reversed && ConstStride->getSExtValue() >= 0)
     // total size = (elsize * n) + (stride - elsize) * (n-1)
     //            = elsize + stride * (n-1)
     MemSize = WideScalarVT.getSizeInBits() +

diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll
index decc992533b553..f52ba6f51d5c89 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll
@@ -494,25 +494,15 @@ define void @strided_constant_neg_4xv2f32(ptr %x, ptr %z, i64 %s) {
   ret void
 }
 
-; TODO: This is a strided load with a negative stride
+; This is a strided load with a negative stride
 define void @reverse_strided_constant_pos_4xv2f32(ptr %x, ptr %z, i64 %s) {
 ; CHECK-LABEL: reverse_strided_constant_pos_4xv2f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, 64
-; CHECK-NEXT:    addi a3, a0, 128
-; CHECK-NEXT:    addi a4, a0, 192
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
-; CHECK-NEXT:    vle32.v v8, (a4)
-; CHECK-NEXT:    vle32.v v10, (a3)
-; CHECK-NEXT:    vle32.v v12, (a2)
-; CHECK-NEXT:    vle32.v v14, (a0)
-; CHECK-NEXT:    vsetivli zero, 4, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 2
-; CHECK-NEXT:    vsetivli zero, 6, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v12, 4
-; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v14, 6
-; CHECK-NEXT:    vse32.v v8, (a1)
+; CHECK-NEXT:    addi a0, a0, 192
+; CHECK-NEXT:    li a2, -64
+; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %x.1 = getelementptr i8, ptr %x, i64 64
   %x.2 = getelementptr i8, ptr %x.1, i64 64
@@ -531,21 +521,11 @@ define void @reverse_strided_constant_pos_4xv2f32(ptr %x, ptr %z, i64 %s) {
 define void @reverse_strided_constant_neg_4xv2f32(ptr %x, ptr %z, i64 %s) {
 ; CHECK-LABEL: reverse_strided_constant_neg_4xv2f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, -64
-; CHECK-NEXT:    addi a3, a0, -128
-; CHECK-NEXT:    addi a4, a0, -192
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
-; CHECK-NEXT:    vle32.v v8, (a4)
-; CHECK-NEXT:    vle32.v v10, (a3)
-; CHECK-NEXT:    vle32.v v12, (a2)
-; CHECK-NEXT:    vle32.v v14, (a0)
-; CHECK-NEXT:    vsetivli zero, 4, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 2
-; CHECK-NEXT:    vsetivli zero, 6, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v12, 4
-; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v14, 6
-; CHECK-NEXT:    vse32.v v8, (a1)
+; CHECK-NEXT:    addi a0, a0, -192
+; CHECK-NEXT:    li a2, 64
+; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %x.1 = getelementptr i8, ptr %x, i64 -64
   %x.2 = getelementptr i8, ptr %x.1, i64 -64
@@ -561,25 +541,17 @@ define void @reverse_strided_constant_neg_4xv2f32(ptr %x, ptr %z, i64 %s) {
   ret void
 }
 
-; TODO: This is a strided load with a negative stride
+; This is a strided load with a negative stride
 define void @reverse_strided_runtime_4xv2f32(ptr %x, ptr %z, i64 %s) {
 ; CHECK-LABEL: reverse_strided_runtime_4xv2f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    add a3, a0, a2
-; CHECK-NEXT:    add a4, a3, a2
-; CHECK-NEXT:    add a2, a4, a2
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
-; CHECK-NEXT:    vle32.v v8, (a2)
-; CHECK-NEXT:    vle32.v v10, (a4)
-; CHECK-NEXT:    vle32.v v12, (a3)
-; CHECK-NEXT:    vle32.v v14, (a0)
-; CHECK-NEXT:    vsetivli zero, 4, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 2
-; CHECK-NEXT:    vsetivli zero, 6, e32, m2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v12, 4
-; CHECK-NEXT:    vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v14, 6
-; CHECK-NEXT:    vse32.v v8, (a1)
+; CHECK-NEXT:    add a0, a0, a2
+; CHECK-NEXT:    add a3, a2, a2
+; CHECK-NEXT:    add a0, a0, a3
+; CHECK-NEXT:    neg a2, a2
+; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT:    vlse64.v v8, (a0), a2
+; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    ret
   %x.1 = getelementptr i8, ptr %x, i64 %s
   %x.2 = getelementptr i8, ptr %x.1, i64 %s


        

