[llvm] [RISCV] Improve performCONCAT_VECTORSCombine stride matching (PR #68726)

Michael Maitland via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 16 08:13:03 PDT 2023


https://github.com/michaelmaitland updated https://github.com/llvm/llvm-project/pull/68726

From 4e24ef483e739a314f046f8a6797091cfd2d11c6 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Sat, 14 Oct 2023 13:14:48 -0400
Subject: [PATCH 1/4] [RISCV] Refactor performCONCAT_VECTORSCombine. NFC

Instead of doing a forward pass for positive strides and a separate reverse
pass for negative strides, we can do a single pass by negating the offset
when the pointers turn out to be in reverse order.

We can extend getPtrDiff later in #68726 to handle more constant offset
sequences.
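
For reference, here is a simplified standalone sketch of the one-pass
matching, modeling each load as a plain integer byte offset rather than an
SDValue chain (the names and types below are illustrative, not LLVM API):

#include <cstdint>
#include <optional>
#include <vector>

// Simplified model: each "load" is represented only by its byte offset
// from a common base. The difference between consecutive offsets stands
// in for getPtrDiff; it is positive for forward strides and negative for
// reverse ones, so a single loop handles both directions.
std::optional<int64_t> matchConstantStride(const std::vector<int64_t> &Offsets) {
  std::optional<int64_t> Stride;
  for (size_t I = 1; I < Offsets.size(); ++I) {
    int64_t Diff = Offsets[I] - Offsets[I - 1];
    if (!Stride)
      Stride = Diff;       // the first gap fixes the candidate stride
    else if (Diff != *Stride)
      return std::nullopt; // gaps disagree: not a strided access
  }
  return Stride;
}

// matchConstantStride({0, 16, 32, 48}) yields 16;
// matchConstantStride({48, 32, 16, 0}) yields -16.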
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 81 +++++++--------------
 1 file changed, 25 insertions(+), 56 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d7552317fd8bc69..9912f19c9a50191 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13785,11 +13785,10 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   EVT BaseLdVT = BaseLd->getValueType(0);
-  SDValue BasePtr = BaseLd->getBasePtr();
 
   // Go through the loads and check that they're strided
-  SmallVector<SDValue> Ptrs;
-  Ptrs.push_back(BasePtr);
+  SmallVector<LoadSDNode *> Lds;
+  Lds.push_back(BaseLd);
   Align Align = BaseLd->getAlign();
   for (SDValue Op : N->ops().drop_front()) {
     auto *Ld = dyn_cast<LoadSDNode>(Op);
@@ -13798,58 +13797,33 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
         Ld->getValueType(0) != BaseLdVT)
       return SDValue();
 
-    Ptrs.push_back(Ld->getBasePtr());
+    Lds.push_back(Ld);
 
     // The common alignment is the most restrictive (smallest) of all the loads
     Align = std::min(Align, Ld->getAlign());
   }
 
-  auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == 0)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add LastPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()-1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
-  };
-  auto matchReverseStrided = [](ArrayRef<SDValue> Ptrs) {
-    SDValue Stride;
-    for (auto Idx : enumerate(Ptrs)) {
-      if (Idx.index() == Ptrs.size() - 1)
-        continue;
-      SDValue Ptr = Idx.value();
-      // Check that each load's pointer is (add NextPtr, Stride)
-      if (Ptr.getOpcode() != ISD::ADD ||
-          Ptr.getOperand(0) != Ptrs[Idx.index()+1])
-        return SDValue();
-      SDValue Offset = Ptr.getOperand(1);
-      if (!Stride)
-        Stride = Offset;
-      else if (Offset != Stride)
-        return SDValue();
-    }
-    return Stride;
+  auto getPtrDiff = [&DAG, &DL](LoadSDNode *Ld1, LoadSDNode *Ld2) {
+    SDValue P1 = Ld1->getBasePtr();
+    SDValue P2 = Ld2->getBasePtr();
+    if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1)
+      return P2.getOperand(1);
+    if (P1.getOpcode() == ISD::ADD && P1.getOperand(0) == P2)
+      return DAG.getNegative(P1.getOperand(1), DL,
+                             P1.getOperand(1).getValueType());
+    return SDValue();
   };
 
-  bool Reversed = false;
-  SDValue Stride = matchForwardStrided(Ptrs);
-  if (!Stride) {
-    Stride = matchReverseStrided(Ptrs);
-    Reversed = true;
-    // TODO: At this point, we've successfully matched a generalized gather
-    // load.  Maybe we should emit that, and then move the specialized
-    // matchers above and below into a DAG combine?
+  SDValue Stride;
+  for (auto [Idx, Ld] : enumerate(Lds)) {
+    if (Idx == 0)
+      continue;
+    SDValue Offset = getPtrDiff(Lds[Idx - 1], Ld);
+    if (!Offset)
+      return SDValue();
     if (!Stride)
+      Stride = Offset;
+    else if (Offset != Stride)
       return SDValue();
   }
 
@@ -13871,22 +13845,17 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   SDValue IntID =
     DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
                           Subtarget.getXLenVT());
-  if (Reversed)
-    Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+
   SDValue AllOneMask =
     DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
                  DAG.getConstant(1, DL, MVT::i1));
 
-  SDValue Ops[] = {BaseLd->getChain(),
-                   IntID,
-                   DAG.getUNDEF(WideVecVT),
-                   BasePtr,
-                   Stride,
-                   AllOneMask};
+  SDValue Ops[] = {BaseLd->getChain(),   IntID,  DAG.getUNDEF(WideVecVT),
+                   BaseLd->getBasePtr(), Stride, AllOneMask};
 
   uint64_t MemSize;
   if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
-      ConstStride && !Reversed && ConstStride->getSExtValue() >= 0)
+      ConstStride && ConstStride->getSExtValue() >= 0)
     // total size = (elsize * n) + (stride - elsize) * (n-1)
     //            = elsize + stride * (n-1)
     MemSize = WideScalarVT.getSizeInBits() +
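
As a sanity check on the MemSize comment above (an illustrative example, not
part of the patch): for four <2 x i8> loads (elsize = 16 bits) with a 16-byte
(128-bit) stride, elsize + stride * (n-1) = 16 + 128 * 3 = 400 bits, i.e. the
access spans bytes 0 through 49 of the base, which matches loads of 2 bytes
each at byte offsets 0, 16, 32, and 48.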

From 7fb86f69e3ebe3fe1d38accc8bee88dc54a7421e Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Sun, 15 Oct 2023 08:49:46 -0700
Subject: [PATCH 2/4] [RISCV] Pre-commit concat-vectors-constant-stride.ll

This patch pre-commits tests that can be optimized by improving
performCONCAT_VECTORSCombine to do a better job of decomposing the base
pointer and recognizing a constant offset.
---
 .../rvv/concat-vectors-constant-stride.ll     | 231 ++++++++++++++++++
 1 file changed, 231 insertions(+)
 create mode 100644 llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll

diff --git a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
new file mode 100644
index 000000000000000..611270ab98ebdaf
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
@@ -0,0 +1,231 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+v,+unaligned-vector-mem -target-abi=ilp32 \
+; RUN:     -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV32
+; RUN: llc -mtriple=riscv64 -mattr=+v,+unaligned-vector-mem -target-abi=lp64 \
+; RUN:     -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64
+
+define void @constant_forward_stride(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_forward_stride:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, 16
+; CHECK-NEXT:    addi a3, a0, 32
+; CHECK-NEXT:    addi a4, a0, 48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 16
+  %2 = getelementptr inbounds i8, ptr %s, i64 32
+  %3 = getelementptr inbounds i8, ptr %s, i64 48
+  %4 = load <2 x i8>, ptr %s, align 1
+  %5 = load <2 x i8>, ptr %1, align 1
+  %6 = load <2 x i8>, ptr %2, align 1
+  %7 = load <2 x i8>, ptr %3, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_forward_stride2(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_forward_stride2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, -16
+; CHECK-NEXT:    addi a3, a0, -32
+; CHECK-NEXT:    addi a4, a0, -48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a4)
+; CHECK-NEXT:    vle8.v v9, (a3)
+; CHECK-NEXT:    vle8.v v10, (a2)
+; CHECK-NEXT:    vle8.v v11, (a0)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 -16
+  %2 = getelementptr inbounds i8, ptr %s, i64 -32
+  %3 = getelementptr inbounds i8, ptr %s, i64 -48
+  %4 = load <2 x i8>, ptr %3, align 1
+  %5 = load <2 x i8>, ptr %2, align 1
+  %6 = load <2 x i8>, ptr %1, align 1
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_forward_stride3(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_forward_stride3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, 16
+; CHECK-NEXT:    addi a3, a0, 32
+; CHECK-NEXT:    addi a4, a0, 48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 16
+  %2 = getelementptr inbounds i8, ptr %s, i64 32
+  %3 = getelementptr inbounds i8, ptr %s, i64 48
+  %4 = getelementptr inbounds i8, ptr %1, i64 0
+  %5 = getelementptr inbounds i8, ptr %2, i64 0
+  %6 = getelementptr inbounds i8, ptr %3, i64 0
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = load <2 x i8>, ptr %4, align 1
+  %9 = load <2 x i8>, ptr %5, align 1
+  %10 = load <2 x i8>, ptr %6, align 1
+  %11 = shufflevector <2 x i8> %7, <2 x i8> %8, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %12 = shufflevector <2 x i8> %9, <2 x i8> %10, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %13 = shufflevector <4 x i8> %11, <4 x i8> %12, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %13, ptr %d, align 1
+  ret void
+}
+
+define void @constant_back_stride(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_back_stride:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, -16
+; CHECK-NEXT:    addi a3, a0, -32
+; CHECK-NEXT:    addi a4, a0, -48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 -16
+  %2 = getelementptr inbounds i8, ptr %s, i64 -32
+  %3 = getelementptr inbounds i8, ptr %s, i64 -48
+  %4 = load <2 x i8>, ptr %s, align 1
+  %5 = load <2 x i8>, ptr %1, align 1
+  %6 = load <2 x i8>, ptr %2, align 1
+  %7 = load <2 x i8>, ptr %3, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_back_stride2(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_back_stride2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, 16
+; CHECK-NEXT:    addi a3, a0, 32
+; CHECK-NEXT:    addi a4, a0, 48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a4)
+; CHECK-NEXT:    vle8.v v9, (a3)
+; CHECK-NEXT:    vle8.v v10, (a2)
+; CHECK-NEXT:    vle8.v v11, (a0)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 16
+  %2 = getelementptr inbounds i8, ptr %s, i64 32
+  %3 = getelementptr inbounds i8, ptr %s, i64 48
+  %4 = load <2 x i8>, ptr %3, align 1
+  %5 = load <2 x i8>, ptr %2, align 1
+  %6 = load <2 x i8>, ptr %1, align 1
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_back_stride3(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_back_stride3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, -16
+; CHECK-NEXT:    addi a3, a0, -32
+; CHECK-NEXT:    addi a4, a0, -48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 -16
+  %2 = getelementptr inbounds i8, ptr %s, i64 -32
+  %3 = getelementptr inbounds i8, ptr %s, i64 -48
+  %4 = getelementptr inbounds i8, ptr %1, i64 0
+  %5 = getelementptr inbounds i8, ptr %2, i64 0
+  %6 = getelementptr inbounds i8, ptr %3, i64 0
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = load <2 x i8>, ptr %4, align 1
+  %9 = load <2 x i8>, ptr %5, align 1
+  %10 = load <2 x i8>, ptr %6, align 1
+  %11 = shufflevector <2 x i8> %7, <2 x i8> %8, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %12 = shufflevector <2 x i8> %9, <2 x i8> %10, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %13 = shufflevector <4 x i8> %11, <4 x i8> %12, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %13, ptr %d, align 1
+  ret void
+}
+
+define void @constant_zero_stride(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_zero_stride:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; CHECK-NEXT:    vmv1r.v v9, v8
+; CHECK-NEXT:    vslideup.vi v9, v8, 2
+; CHECK-NEXT:    vse8.v v9, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 0
+  %2 = load <2 x i8>, ptr %s, align 1
+  %3 = load <2 x i8>, ptr %1, align 1
+  %4 = shufflevector <2 x i8> %2, <2 x i8> %3, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  store <4 x i8> %4, ptr %d, align 1
+  ret void
+}
+
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; RV32: {{.*}}
+; RV64: {{.*}}

From 100702c5861328019ae412026e40f00649a4e441 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Sun, 15 Oct 2023 09:00:04 -0700
Subject: [PATCH 3/4] [RISCV] Improve performCONCAT_VECTORSCombine stride
 matching

If the load pointers can be decomposed into a common (Base + Index) plus
constant offsets, return the offset difference as the constant stride.
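
The decomposition can be pictured with a simplified standalone sketch
(BaseIndexOffset and equalBaseIndex are real SelectionDAG helpers; the toy
types below are assumptions made for illustration only):

#include <cstdint>
#include <optional>

// Toy stand-in for SelectionDAG's BaseIndexOffset: a pointer viewed as
// Base + Index + Offset, where Base and Index are opaque node ids and
// Offset is a constant number of bytes.
struct ToyBaseIndexOffset {
  int Base;       // id of the base pointer node
  int Index;      // id of the variable index node (0 if absent)
  int64_t Offset; // constant byte offset
  bool equalBaseIndex(const ToyBaseIndexOffset &O) const {
    return Base == O.Base && Index == O.Index;
  }
};

// If two loads share the same (Base + Index), the difference of their
// constant offsets is the stride, no matter how the GEPs were written
// (e.g. the redundant zero-index GEPs in the *_stride3 tests below).
std::optional<int64_t> getConstDiff(const ToyBaseIndexOffset &A,
                                    const ToyBaseIndexOffset &B) {
  if (A.equalBaseIndex(B))
    return B.Offset - A.Offset;
  return std::nullopt; // fall back to matching (add Ptr, Stride) directly
}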
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   |  11 ++
 .../rvv/concat-vectors-constant-stride.ll     | 116 ++++--------------
 2 files changed, 37 insertions(+), 90 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 9912f19c9a50191..408dcb91578b905 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -27,6 +27,7 @@
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -13804,6 +13805,16 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   }
 
   auto getPtrDiff = [&DAG, &DL](LoadSDNode *Ld1, LoadSDNode *Ld2) {
+    // If the load ptrs can be decomposed into a common (Base + Index) with a
+    // common constant stride, then return the constant stride.
+    BaseIndexOffset BIO1 = BaseIndexOffset::match(Ld1, DAG);
+    BaseIndexOffset BIO2 = BaseIndexOffset::match(Ld2, DAG);
+    if (BIO1.hasValidOffset() && BIO2.hasValidOffset() &&
+        BIO1.equalBaseIndex(BIO2, DAG))
+      return DAG.getConstant(BIO2.getOffset() - BIO1.getOffset(), DL,
+                             Ld1->getOffset().getValueType());
+
+    // Otherwise try to match (add LastPtr, Stride) or (add NextPtr, Stride)
     SDValue P1 = Ld1->getBasePtr();
     SDValue P2 = Ld2->getBasePtr();
     if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1)
diff --git a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
index 611270ab98ebdaf..ff35043dbd7e75e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
@@ -7,21 +7,10 @@
 define void @constant_forward_stride(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_forward_stride:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, 16
-; CHECK-NEXT:    addi a3, a0, 32
-; CHECK-NEXT:    addi a4, a0, 48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a0)
-; CHECK-NEXT:    vle8.v v9, (a2)
-; CHECK-NEXT:    vle8.v v10, (a3)
-; CHECK-NEXT:    vle8.v v11, (a4)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    li a2, 16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 16
   %2 = getelementptr inbounds i8, ptr %s, i64 32
@@ -40,21 +29,11 @@ define void @constant_forward_stride(ptr %s, ptr %d) {
 define void @constant_forward_stride2(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_forward_stride2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, -16
-; CHECK-NEXT:    addi a3, a0, -32
-; CHECK-NEXT:    addi a4, a0, -48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a4)
-; CHECK-NEXT:    vle8.v v9, (a3)
-; CHECK-NEXT:    vle8.v v10, (a2)
-; CHECK-NEXT:    vle8.v v11, (a0)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    addi a0, a0, -48
+; CHECK-NEXT:    li a2, 16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 -16
   %2 = getelementptr inbounds i8, ptr %s, i64 -32
@@ -73,21 +52,10 @@ define void @constant_forward_stride2(ptr %s, ptr %d) {
 define void @constant_forward_stride3(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_forward_stride3:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, 16
-; CHECK-NEXT:    addi a3, a0, 32
-; CHECK-NEXT:    addi a4, a0, 48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a0)
-; CHECK-NEXT:    vle8.v v9, (a2)
-; CHECK-NEXT:    vle8.v v10, (a3)
-; CHECK-NEXT:    vle8.v v11, (a4)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    li a2, 16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 16
   %2 = getelementptr inbounds i8, ptr %s, i64 32
@@ -109,21 +77,10 @@ define void @constant_forward_stride3(ptr %s, ptr %d) {
 define void @constant_back_stride(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_back_stride:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, -16
-; CHECK-NEXT:    addi a3, a0, -32
-; CHECK-NEXT:    addi a4, a0, -48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a0)
-; CHECK-NEXT:    vle8.v v9, (a2)
-; CHECK-NEXT:    vle8.v v10, (a3)
-; CHECK-NEXT:    vle8.v v11, (a4)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    li a2, -16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 -16
   %2 = getelementptr inbounds i8, ptr %s, i64 -32
@@ -142,21 +99,11 @@ define void @constant_back_stride(ptr %s, ptr %d) {
 define void @constant_back_stride2(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_back_stride2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, 16
-; CHECK-NEXT:    addi a3, a0, 32
-; CHECK-NEXT:    addi a4, a0, 48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a4)
-; CHECK-NEXT:    vle8.v v9, (a3)
-; CHECK-NEXT:    vle8.v v10, (a2)
-; CHECK-NEXT:    vle8.v v11, (a0)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    addi a0, a0, 48
+; CHECK-NEXT:    li a2, -16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 16
   %2 = getelementptr inbounds i8, ptr %s, i64 32
@@ -175,21 +122,10 @@ define void @constant_back_stride2(ptr %s, ptr %d) {
 define void @constant_back_stride3(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_back_stride3:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, -16
-; CHECK-NEXT:    addi a3, a0, -32
-; CHECK-NEXT:    addi a4, a0, -48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a0)
-; CHECK-NEXT:    vle8.v v9, (a2)
-; CHECK-NEXT:    vle8.v v10, (a3)
-; CHECK-NEXT:    vle8.v v11, (a4)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    li a2, -16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 -16
   %2 = getelementptr inbounds i8, ptr %s, i64 -32

From 7eb3b59014c5e7de4eb4af9ebe0e15278151abc3 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Mon, 16 Oct 2023 08:12:45 -0700
Subject: [PATCH 4/4] remove hasValidOffset check

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 408dcb91578b905..8d78d6d381468df 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13809,8 +13809,7 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     // common constant stride, then return the constant stride.
     BaseIndexOffset BIO1 = BaseIndexOffset::match(Ld1, DAG);
     BaseIndexOffset BIO2 = BaseIndexOffset::match(Ld2, DAG);
-    if (BIO1.hasValidOffset() && BIO2.hasValidOffset() &&
-        BIO1.equalBaseIndex(BIO2, DAG))
+    if (BIO1.equalBaseIndex(BIO2, DAG))
       return DAG.getConstant(BIO2.getOffset() - BIO1.getOffset(), DL,
                              Ld1->getOffset().getValueType());
 


