[llvm] [RISCV] Improve performCONCAT_VECTORSCombine stride matching (PR #68726)

Michael Maitland via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 10 17:39:41 PDT 2023


https://github.com/michaelmaitland updated https://github.com/llvm/llvm-project/pull/68726

From 35da9430bd4eec0488e3a8c13734ea059f4beb29 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Thu, 5 Oct 2023 19:16:08 -0700
Subject: [PATCH 1/4] [RISCV] Pre-commit concat-vectors-constant-stride.ll
 tests

This patch adds tests that can be optimized by improving
performCONCAT_VECTORSCombine to do a better job of decomposing the base
pointer and recognizing a constant offset.
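
For illustration, the access pattern these tests exercise is equivalent to
the following scalar routine (a minimal sketch using the element width,
count, and stride of the first test; the reference function below is ours,
not part of the patch):

  // Scalar equivalent of @constant_forward_stride: gather four 2-byte
  // chunks spaced 16 bytes apart and store them contiguously.
  #include <cstdint>
  #include <cstring>

  void constant_forward_stride_ref(const std::uint8_t *s, std::uint8_t *d) {
    for (int i = 0; i < 4; ++i)
      std::memcpy(d + 2 * i, s + 16 * i, 2); // stride 16, element size 2
  }

The improved combine should lower this pattern to a single strided vector
load (vlse) rather than four unit loads glued together with vslideup.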
---
 .../rvv/concat-vectors-constant-stride.ll     | 231 ++++++++++++++++++
 1 file changed, 231 insertions(+)
 create mode 100644 llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll

diff --git a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
new file mode 100644
index 000000000000000..611270ab98ebdaf
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
@@ -0,0 +1,231 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+v,+unaligned-vector-mem -target-abi=ilp32 \
+; RUN:     -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV32
+; RUN: llc -mtriple=riscv64 -mattr=+v,+unaligned-vector-mem -target-abi=lp64 \
+; RUN:     -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64
+
+define void @constant_forward_stride(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_forward_stride:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, 16
+; CHECK-NEXT:    addi a3, a0, 32
+; CHECK-NEXT:    addi a4, a0, 48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 16
+  %2 = getelementptr inbounds i8, ptr %s, i64 32
+  %3 = getelementptr inbounds i8, ptr %s, i64 48
+  %4 = load <2 x i8>, ptr %s, align 1
+  %5 = load <2 x i8>, ptr %1, align 1
+  %6 = load <2 x i8>, ptr %2, align 1
+  %7 = load <2 x i8>, ptr %3, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_forward_stride2(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_forward_stride2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, -16
+; CHECK-NEXT:    addi a3, a0, -32
+; CHECK-NEXT:    addi a4, a0, -48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a4)
+; CHECK-NEXT:    vle8.v v9, (a3)
+; CHECK-NEXT:    vle8.v v10, (a2)
+; CHECK-NEXT:    vle8.v v11, (a0)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 -16
+  %2 = getelementptr inbounds i8, ptr %s, i64 -32
+  %3 = getelementptr inbounds i8, ptr %s, i64 -48
+  %4 = load <2 x i8>, ptr %3, align 1
+  %5 = load <2 x i8>, ptr %2, align 1
+  %6 = load <2 x i8>, ptr %1, align 1
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_forward_stride3(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_forward_stride3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, 16
+; CHECK-NEXT:    addi a3, a0, 32
+; CHECK-NEXT:    addi a4, a0, 48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 16
+  %2 = getelementptr inbounds i8, ptr %s, i64 32
+  %3 = getelementptr inbounds i8, ptr %s, i64 48
+  %4 = getelementptr inbounds i8, ptr %1, i64 0
+  %5 = getelementptr inbounds i8, ptr %2, i64 0
+  %6 = getelementptr inbounds i8, ptr %3, i64 0
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = load <2 x i8>, ptr %4, align 1
+  %9 = load <2 x i8>, ptr %5, align 1
+  %10 = load <2 x i8>, ptr %6, align 1
+  %11 = shufflevector <2 x i8> %7, <2 x i8> %8, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %12 = shufflevector <2 x i8> %9, <2 x i8> %10, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %13 = shufflevector <4 x i8> %11, <4 x i8> %12, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %13, ptr %d, align 1
+  ret void
+}
+
+define void @constant_back_stride(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_back_stride:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, -16
+; CHECK-NEXT:    addi a3, a0, -32
+; CHECK-NEXT:    addi a4, a0, -48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 -16
+  %2 = getelementptr inbounds i8, ptr %s, i64 -32
+  %3 = getelementptr inbounds i8, ptr %s, i64 -48
+  %4 = load <2 x i8>, ptr %s, align 1
+  %5 = load <2 x i8>, ptr %1, align 1
+  %6 = load <2 x i8>, ptr %2, align 1
+  %7 = load <2 x i8>, ptr %3, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_back_stride2(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_back_stride2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, 16
+; CHECK-NEXT:    addi a3, a0, 32
+; CHECK-NEXT:    addi a4, a0, 48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a4)
+; CHECK-NEXT:    vle8.v v9, (a3)
+; CHECK-NEXT:    vle8.v v10, (a2)
+; CHECK-NEXT:    vle8.v v11, (a0)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 16
+  %2 = getelementptr inbounds i8, ptr %s, i64 32
+  %3 = getelementptr inbounds i8, ptr %s, i64 48
+  %4 = load <2 x i8>, ptr %3, align 1
+  %5 = load <2 x i8>, ptr %2, align 1
+  %6 = load <2 x i8>, ptr %1, align 1
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_back_stride3(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_back_stride3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, -16
+; CHECK-NEXT:    addi a3, a0, -32
+; CHECK-NEXT:    addi a4, a0, -48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 -16
+  %2 = getelementptr inbounds i8, ptr %s, i64 -32
+  %3 = getelementptr inbounds i8, ptr %s, i64 -48
+  %4 = getelementptr inbounds i8, ptr %1, i64 0
+  %5 = getelementptr inbounds i8, ptr %2, i64 0
+  %6 = getelementptr inbounds i8, ptr %3, i64 0
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = load <2 x i8>, ptr %4, align 1
+  %9 = load <2 x i8>, ptr %5, align 1
+  %10 = load <2 x i8>, ptr %6, align 1
+  %11 = shufflevector <2 x i8> %7, <2 x i8> %8, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %12 = shufflevector <2 x i8> %9, <2 x i8> %10, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %13 = shufflevector <4 x i8> %11, <4 x i8> %12, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %13, ptr %d, align 1
+  ret void
+}
+
+define void @constant_zero_stride(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_zero_stride:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; CHECK-NEXT:    vmv1r.v v9, v8
+; CHECK-NEXT:    vslideup.vi v9, v8, 2
+; CHECK-NEXT:    vse8.v v9, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 0
+  %2 = load <2 x i8>, ptr %s, align 1
+  %3 = load <2 x i8>, ptr %1, align 1
+  %4 = shufflevector <2 x i8> %2, <2 x i8> %3, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  store <4 x i8> %4, ptr %d, align 1
+  ret void
+}
+
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; RV32: {{.*}}
+; RV64: {{.*}}

From 9e87a89723a83f589485d3e99939e949903bb5ce Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Mon, 9 Oct 2023 13:24:25 -0700
Subject: [PATCH 2/4] [RISCV] Improve performCONCAT_VECTORSCombine stride
 matching

If the load pointers can be decomposed into a common (Base + Index) with a
common constant stride, then return the constant stride. This matcher
enables some additional optimization because BaseIndexOffset can decompose
the load pointers to (add (add Base, Index), Stride), instead of the
(add LastPtr, Stride) or (add NextPtr, Stride) forms that
matchForwardStrided and matchReverseStrided match, respectively.
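
As a standalone model of the check (hypothetical code written only to
illustrate the idea, not the SelectionDAG implementation), the matcher
boils down to:

  #include <cstddef>
  #include <cstdint>
  #include <optional>
  #include <vector>

  // Each load pointer decomposes into a common base plus a constant offset,
  // mirroring what BaseIndexOffset::match computes.
  struct BaseOffset {
    const void *Base;    // the common (Base + Index) part
    std::int64_t Offset; // constant byte offset from the base
  };

  // Succeeds iff all bases agree and consecutive offsets differ by the same
  // constant; that constant is the stride.
  std::optional<std::int64_t>
  matchConstantStride(const std::vector<BaseOffset> &Ptrs) {
    if (Ptrs.size() < 2)
      return std::nullopt;
    std::int64_t Stride = Ptrs[1].Offset - Ptrs[0].Offset;
    for (std::size_t I = 1; I < Ptrs.size(); ++I)
      if (Ptrs[I].Base != Ptrs[0].Base ||
          Ptrs[I].Offset - Ptrs[I - 1].Offset != Stride)
        return std::nullopt;
    return Stride;
  }

A negative result corresponds to a reverse stride, which is why the caller
below checks the sign of the matched constant.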
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   |  80 ++++++++++--
 .../rvv/concat-vectors-constant-stride.ll     | 116 ++++--------------
 2 files changed, 95 insertions(+), 101 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6be3fa71479be5c..62719e4946a9ff6 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -27,6 +27,7 @@
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -13821,6 +13822,58 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     Align = std::min(Align, Ld->getAlign());
   }
 
+  // If the load pointers can be decomposed into a common (Base + Index) with
+  // a common constant stride, then return the constant stride. This matcher
+  // enables some additional optimization because BaseIndexOffset can
+  // decompose the load pointers to (add (add Base, Index), Stride), instead
+  // of the (add LastPtr, Stride) or (add NextPtr, Stride) forms that
+  // matchForwardStrided and matchReverseStrided match, respectively.
+  auto matchConstantStride = [&DAG, &N](ArrayRef<SDUse> Loads) {
+    // Initialize match constraints based on the first load. Initialize
+    // ConstStride by taking the difference between the offsets of the first
+    // two loads.
+    if (Loads.size() < 2)
+      return SDValue();
+    BaseIndexOffset BaseLdBIO =
+        BaseIndexOffset::match(cast<LoadSDNode>(Loads[0]), DAG);
+    BaseIndexOffset LastLdBIO =
+        BaseIndexOffset::match(cast<LoadSDNode>(Loads[1]), DAG);
+    bool AllValidOffset =
+        BaseLdBIO.hasValidOffset() && LastLdBIO.hasValidOffset();
+    if (!AllValidOffset)
+      return SDValue();
+    bool BaseIndexMatch = BaseLdBIO.equalBaseIndex(LastLdBIO, DAG);
+    if (!BaseIndexMatch)
+      return SDValue();
+    int64_t ConstStride = LastLdBIO.getOffset() - BaseLdBIO.getOffset();
+
+    // Check that the constraints hold for all subsequent loads and that
+    // ConstStride is the same.
+    for (auto Idx : enumerate(Loads.drop_front(2))) {
+      auto *Ld = cast<LoadSDNode>(Idx.value());
+      BaseIndexOffset BIO = BaseIndexOffset::match(Ld, DAG);
+      AllValidOffset &= BIO.hasValidOffset();
+      if (!AllValidOffset)
+        return SDValue();
+      BaseIndexMatch |= BaseLdBIO.equalBaseIndex(BIO, DAG);
+      // Add 3 to index because the first two loads have been processed before
+      // the loop.
+      bool StrideMatches =
+          ConstStride == BIO.getOffset() - LastLdBIO.getOffset();
+      if (!BaseIndexMatch || !StrideMatches)
+        return SDValue();
+      LastLdBIO = BIO;
+    }
+
+    // The match is a success if all the constraints hold.
+    if (BaseIndexMatch && AllValidOffset)
+      return DAG.getConstant(
+          ConstStride, SDLoc(N),
+          cast<LoadSDNode>(N->getOperand(0))->getOffset().getValueType());
+
+    // The match failed.
+    return SDValue();
+  };
   auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
     SDValue Stride;
     for (auto Idx : enumerate(Ptrs)) {
@@ -13862,13 +13915,21 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   SDValue Stride = matchForwardStrided(Ptrs);
   if (!Stride) {
     Stride = matchReverseStrided(Ptrs);
-    Reversed = true;
-    // TODO: At this point, we've successfully matched a generalized gather
-    // load.  Maybe we should emit that, and then move the specialized
-    // matchers above and below into a DAG combine?
-    if (!Stride)
-      return SDValue();
+    if (Stride) {
+      Reversed = true;
+      Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+    } else {
+      Stride = matchConstantStride(N->ops());
+      if (Stride) {
+        Reversed = cast<ConstantSDNode>(Stride)->getSExtValue() < 0;
+      } else {
+        return SDValue();
+      }
+    }
   }
+  // TODO: At this point, we've successfully matched a generalized gather
+  // load.  Maybe we should emit that, and then move the specialized
+  // matchers above and below into a DAG combine?
 
   // Get the widened scalar type, e.g. v4i8 -> i64
   unsigned WideScalarBitWidth =
@@ -13885,11 +13946,8 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});
-  SDValue IntID =
-    DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
-                          Subtarget.getXLenVT());
-  if (Reversed)
-    Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+  SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load,
+                                        DL, Subtarget.getXLenVT());
   SDValue AllOneMask =
     DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
                  DAG.getConstant(1, DL, MVT::i1));
diff --git a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
index 611270ab98ebdaf..ff35043dbd7e75e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
@@ -7,21 +7,10 @@
 define void @constant_forward_stride(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_forward_stride:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, 16
-; CHECK-NEXT:    addi a3, a0, 32
-; CHECK-NEXT:    addi a4, a0, 48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a0)
-; CHECK-NEXT:    vle8.v v9, (a2)
-; CHECK-NEXT:    vle8.v v10, (a3)
-; CHECK-NEXT:    vle8.v v11, (a4)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    li a2, 16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 16
   %2 = getelementptr inbounds i8, ptr %s, i64 32
@@ -40,21 +29,11 @@ define void @constant_forward_stride(ptr %s, ptr %d) {
 define void @constant_forward_stride2(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_forward_stride2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, -16
-; CHECK-NEXT:    addi a3, a0, -32
-; CHECK-NEXT:    addi a4, a0, -48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a4)
-; CHECK-NEXT:    vle8.v v9, (a3)
-; CHECK-NEXT:    vle8.v v10, (a2)
-; CHECK-NEXT:    vle8.v v11, (a0)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    addi a0, a0, -48
+; CHECK-NEXT:    li a2, 16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 -16
   %2 = getelementptr inbounds i8, ptr %s, i64 -32
@@ -73,21 +52,10 @@ define void @constant_forward_stride2(ptr %s, ptr %d) {
 define void @constant_forward_stride3(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_forward_stride3:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, 16
-; CHECK-NEXT:    addi a3, a0, 32
-; CHECK-NEXT:    addi a4, a0, 48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a0)
-; CHECK-NEXT:    vle8.v v9, (a2)
-; CHECK-NEXT:    vle8.v v10, (a3)
-; CHECK-NEXT:    vle8.v v11, (a4)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    li a2, 16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 16
   %2 = getelementptr inbounds i8, ptr %s, i64 32
@@ -109,21 +77,10 @@ define void @constant_forward_stride3(ptr %s, ptr %d) {
 define void @constant_back_stride(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_back_stride:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, -16
-; CHECK-NEXT:    addi a3, a0, -32
-; CHECK-NEXT:    addi a4, a0, -48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a0)
-; CHECK-NEXT:    vle8.v v9, (a2)
-; CHECK-NEXT:    vle8.v v10, (a3)
-; CHECK-NEXT:    vle8.v v11, (a4)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    li a2, -16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 -16
   %2 = getelementptr inbounds i8, ptr %s, i64 -32
@@ -142,21 +99,11 @@ define void @constant_back_stride(ptr %s, ptr %d) {
 define void @constant_back_stride2(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_back_stride2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, 16
-; CHECK-NEXT:    addi a3, a0, 32
-; CHECK-NEXT:    addi a4, a0, 48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a4)
-; CHECK-NEXT:    vle8.v v9, (a3)
-; CHECK-NEXT:    vle8.v v10, (a2)
-; CHECK-NEXT:    vle8.v v11, (a0)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    addi a0, a0, 48
+; CHECK-NEXT:    li a2, -16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 16
   %2 = getelementptr inbounds i8, ptr %s, i64 32
@@ -175,21 +122,10 @@ define void @constant_back_stride2(ptr %s, ptr %d) {
 define void @constant_back_stride3(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_back_stride3:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, -16
-; CHECK-NEXT:    addi a3, a0, -32
-; CHECK-NEXT:    addi a4, a0, -48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a0)
-; CHECK-NEXT:    vle8.v v9, (a2)
-; CHECK-NEXT:    vle8.v v10, (a3)
-; CHECK-NEXT:    vle8.v v11, (a4)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    li a2, -16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 -16
   %2 = getelementptr inbounds i8, ptr %s, i64 -32

From 7af5d8ceca3358215ab06109526f5b76af003a81 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Tue, 10 Oct 2023 13:29:09 -0700
Subject: [PATCH 3/4] Remove stale comment; Remove enumerate

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 62719e4946a9ff6..e0bff8ef2907d33 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13849,15 +13849,13 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
 
     // Check that the constraints hold for all subsequent loads and that
     // ConstStride is the same.
-    for (auto Idx : enumerate(Loads.drop_front(2))) {
-      auto *Ld = cast<LoadSDNode>(Idx.value());
+    for (auto &Use : Loads.drop_front(2)) {
+      auto *Ld = cast<LoadSDNode>(Use);
       BaseIndexOffset BIO = BaseIndexOffset::match(Ld, DAG);
       AllValidOffset &= BIO.hasValidOffset();
       if (!AllValidOffset)
         return SDValue();
       BaseIndexMatch |= BaseLdBIO.equalBaseIndex(BIO, DAG);
-      // Add 3 to index because the first two loads have been processed before
-      // the loop.
       bool StrideMatches =
           ConstStride == BIO.getOffset() - LastLdBIO.getOffset();
       if (!BaseIndexMatch || !StrideMatches)

From 7b1867ed8fe0735f7d1d21f0c33fce323c34ead1 Mon Sep 17 00:00:00 2001
From: Michael Maitland <michaeltmaitland at gmail.com>
Date: Tue, 10 Oct 2023 17:39:10 -0700
Subject: [PATCH 4/4] Defer getNegative until we know we're going to do the
 optimization; remove unnecessary constraint variables

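The motivation for deferring the negation is that DAG.getNegative creates a
node; doing it eagerly leaves a dead node behind whenever a later check
bails out. A sketch of the before/after shape (variable names match the
patch; surrounding code elided):

  // Before: the negated node is created even on paths that later fail.
  //   Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
  //   ...
  //   if (<a later check fails>)
  //     return SDValue(); // the negated node is now dead
  //
  // After: only a flag is set up front; the node is created once every
  // check has passed.
  //   NeedToNegateStride = true;
  //   ...
  //   if (NeedToNegateStride)
  //     Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));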
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 40 +++++++--------------
 1 file changed, 13 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e0bff8ef2907d33..234df430bdda480 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13838,12 +13838,9 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
         BaseIndexOffset::match(cast<LoadSDNode>(Loads[0]), DAG);
     BaseIndexOffset LastLdBIO =
         BaseIndexOffset::match(cast<LoadSDNode>(Loads[1]), DAG);
-    bool AllValidOffset =
-        BaseLdBIO.hasValidOffset() && LastLdBIO.hasValidOffset();
-    if (!AllValidOffset)
+    if (!BaseLdBIO.hasValidOffset() || !LastLdBIO.hasValidOffset())
       return SDValue();
-    bool BaseIndexMatch = BaseLdBIO.equalBaseIndex(LastLdBIO, DAG);
-    if (!BaseIndexMatch)
+    if (!BaseLdBIO.equalBaseIndex(LastLdBIO, DAG))
       return SDValue();
     int64_t ConstStride = LastLdBIO.getOffset() - BaseLdBIO.getOffset();
 
@@ -13852,25 +13849,16 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     for (auto &Use : Loads.drop_front(2)) {
       auto *Ld = cast<LoadSDNode>(Use);
       BaseIndexOffset BIO = BaseIndexOffset::match(Ld, DAG);
-      AllValidOffset &= BIO.hasValidOffset();
-      if (!AllValidOffset)
-        return SDValue();
-      BaseIndexMatch |= BaseLdBIO.equalBaseIndex(BIO, DAG);
-      bool StrideMatches =
-          ConstStride == BIO.getOffset() - LastLdBIO.getOffset();
-      if (!BaseIndexMatch || !StrideMatches)
+      if (!BIO.hasValidOffset() || !BaseLdBIO.equalBaseIndex(BIO, DAG) ||
+          ConstStride != BIO.getOffset() - LastLdBIO.getOffset())
         return SDValue();
       LastLdBIO = BIO;
     }
 
     // The match is a success if all the constraints hold.
-    if (BaseIndexMatch && AllValidOffset)
-      return DAG.getConstant(
-          ConstStride, SDLoc(N),
-          cast<LoadSDNode>(N->getOperand(0))->getOffset().getValueType());
-
-    // The match failed.
-    return SDValue();
+    return DAG.getConstant(
+        ConstStride, SDLoc(N),
+        cast<LoadSDNode>(N->getOperand(0))->getOffset().getValueType());
   };
   auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
     SDValue Stride;
@@ -13909,20 +13897,16 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     return Stride;
   };
 
-  bool Reversed = false;
+  bool NeedToNegateStride = false;
   SDValue Stride = matchForwardStrided(Ptrs);
   if (!Stride) {
     Stride = matchReverseStrided(Ptrs);
     if (Stride) {
-      Reversed = true;
-      Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+      NeedToNegateStride = true;
     } else {
       Stride = matchConstantStride(N->ops());
-      if (Stride) {
-        Reversed = cast<ConstantSDNode>(Stride)->getSExtValue() < 0;
-      } else {
+      if (!Stride)
         return SDValue();
-      }
     }
   }
   // TODO: At this point, we've successfully matched a generalized gather
@@ -13946,6 +13930,8 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});
   SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load,
                                         DL, Subtarget.getXLenVT());
+  if (NeedToNegateStride)
+    Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
   SDValue AllOneMask =
     DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
                  DAG.getConstant(1, DL, MVT::i1));
@@ -13959,7 +13945,7 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
 
   uint64_t MemSize;
   if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
-      ConstStride && !Reversed && ConstStride->getSExtValue() >= 0)
+      ConstStride && !NeedToNegateStride && ConstStride->getSExtValue() >= 0)
     // total size = (elsize * n) + (stride - elsize) * (n-1)
     //            = elsize + stride * (n-1)
     MemSize = WideScalarVT.getSizeInBits() +


