[llvm] Concat vectors combine (PR #68726)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 10 10:01:36 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-risc-v

Author: Michael Maitland (michaelmaitland)

<details>
<summary>Changes</summary>

I intend to land the two commits in this patch separately:

    Pre-commit concat-vectors-constant-stride.ll tests

    This patch commits tests that can be optimized by improving
    performCONCAT_VECTORSCombine to do a better job of decomposing the base
    pointer and recognizing a constant offset.

    Improve performCONCAT_VECTORSCombine stride matching

    If the load pointers can be decomposed into a common (Base + Index) with
    a common constant stride, return that constant stride. This matcher
    enables some additional optimization because BaseIndexOffset can
    decompose the load pointers into (add (add Base, Index), Stride) instead
    of the (add LastPtr, Stride) or (add NextPtr, Stride) forms that
    matchForwardStrided and matchReverseStrided use, respectively.
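
A minimal LLVM IR sketch of the pointer pattern the new matcher is meant to
catch (it mirrors the added tests; the function name is illustrative): every
load pointer is a GEP off the same base, so BaseIndexOffset decomposes them to
a common (Base + Index) with offsets 0, 16, 32 and 48, i.e. a constant stride
of 16.

```llvm
; Four adjacent <2 x i8> loads, each 16 bytes past the previous one, glued
; together with shuffles: a candidate for a single strided load of stride 16.
define void @stride16_sketch(ptr %s, ptr %d) {
  %p1 = getelementptr inbounds i8, ptr %s, i64 16
  %p2 = getelementptr inbounds i8, ptr %s, i64 32
  %p3 = getelementptr inbounds i8, ptr %s, i64 48
  %v0 = load <2 x i8>, ptr %s, align 1
  %v1 = load <2 x i8>, ptr %p1, align 1
  %v2 = load <2 x i8>, ptr %p2, align 1
  %v3 = load <2 x i8>, ptr %p3, align 1
  %lo = shufflevector <2 x i8> %v0, <2 x i8> %v1, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
  %hi = shufflevector <2 x i8> %v2, <2 x i8> %v3, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
  %vec = shufflevector <4 x i8> %lo, <4 x i8> %hi, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
  store <8 x i8> %vec, ptr %d, align 1
  ret void
}
```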

---
Full diff: https://github.com/llvm/llvm-project/pull/68726.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+67-8) 
- (added) llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll (+231) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6be3fa71479be5c..955674bd13e1c79 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -27,6 +27,7 @@
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -13821,6 +13822,58 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     Align = std::min(Align, Ld->getAlign());
   }
 
+  // If the load ptrs can be decomposed into a common (Base + Index) with a
+  // common constant stride, then return the constant stride. This matcher
+  // enables some additional optimization since BaseIndexOffset is capable of
+  // decomposing the load ptrs to (add (add Base, Index), Stride) instead of
+  // (add LastPtr, Stride) or (add NextPtr, Stride) that matchForwardStrided and
+  // matchReverseStrided use respectively.
+  auto matchConstantStride = [&DAG, &N](ArrayRef<SDUse> Loads) {
+    // Initialize match constraints based on the first load. Initialize
+    // ConstStride by taking the difference between the offset of the first two
+    // loads.
+    if (Loads.size() < 2)
+      return SDValue();
+    BaseIndexOffset BaseLdBIO =
+        BaseIndexOffset::match(cast<LoadSDNode>(Loads[0]), DAG);
+    BaseIndexOffset LastLdBIO =
+        BaseIndexOffset::match(cast<LoadSDNode>(Loads[1]), DAG);
+    bool AllValidOffset =
+        BaseLdBIO.hasValidOffset() && LastLdBIO.hasValidOffset();
+    if (!AllValidOffset)
+      return SDValue();
+    bool BaseIndexMatch = BaseLdBIO.equalBaseIndex(LastLdBIO, DAG);
+    if (!BaseIndexMatch)
+      return SDValue();
+    int64_t ConstStride = LastLdBIO.getOffset() - BaseLdBIO.getOffset();
+
+    // Check that constraints hold for all subsequent loads and the ConstStride
+    // is the same.
+    for (auto Idx : enumerate(Loads.drop_front(2))) {
+      auto *Ld = cast<LoadSDNode>(Idx.value());
+      BaseIndexOffset BIO = BaseIndexOffset::match(Ld, DAG);
+      AllValidOffset &= BIO.hasValidOffset();
+      if (!AllValidOffset)
+        return SDValue();
+      BaseIndexMatch &= BaseLdBIO.equalBaseIndex(BIO, DAG);
+      // The offset between this load and the previous one must match the
+      // stride established by the first two loads.
+      bool StrideMatches =
+          ConstStride == BIO.getOffset() - LastLdBIO.getOffset();
+      if (!BaseIndexMatch || !StrideMatches)
+        return SDValue();
+      LastLdBIO = BIO;
+    }
+
+    // The match is a success if all the constraints hold.
+    if (BaseIndexMatch && AllValidOffset)
+      return DAG.getConstant(
+          ConstStride, SDLoc(N),
+          cast<LoadSDNode>(N->getOperand(0))->getOffset().getValueType());
+
+    // The match failed.
+    return SDValue();
+  };
   auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
     SDValue Stride;
     for (auto Idx : enumerate(Ptrs)) {
@@ -13862,13 +13915,21 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   SDValue Stride = matchForwardStrided(Ptrs);
   if (!Stride) {
     Stride = matchReverseStrided(Ptrs);
-    Reversed = true;
-    // TODO: At this point, we've successfully matched a generalized gather
-    // load.  Maybe we should emit that, and then move the specialized
-    // matchers above and below into a DAG combine?
-    if (!Stride)
-      return SDValue();
+    if (Stride) {
+      Reversed = true;
+      Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+    } else {
+      Stride = matchConstantStride(N->ops());
+      if (Stride) {
+        Reversed = cast<ConstantSDNode>(Stride)->getSExtValue() < 0;
+      } else {
+        return SDValue();
+      }
+    }
   }
+  // TODO: At this point, we've successfully matched a generalized gather
+  // load.  Maybe we should emit that, and then move the specialized
+  // matchers above and below into a DAG combine?
 
   // Get the widened scalar type, e.g. v4i8 -> i64
   unsigned WideScalarBitWidth =
@@ -13888,8 +13949,6 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   SDValue IntID =
     DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
                           Subtarget.getXLenVT());
-  if (Reversed)
-    Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
   SDValue AllOneMask =
     DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
                  DAG.getConstant(1, DL, MVT::i1));
diff --git a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
new file mode 100644
index 000000000000000..611270ab98ebdaf
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
@@ -0,0 +1,231 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+v,+unaligned-vector-mem -target-abi=ilp32 \
+; RUN:     -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV32
+; RUN: llc -mtriple=riscv64 -mattr=+v,+unaligned-vector-mem -target-abi=lp64 \
+; RUN:     -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK,RV64
+
+define void @constant_forward_stride(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_forward_stride:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, 16
+; CHECK-NEXT:    addi a3, a0, 32
+; CHECK-NEXT:    addi a4, a0, 48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 16
+  %2 = getelementptr inbounds i8, ptr %s, i64 32
+  %3 = getelementptr inbounds i8, ptr %s, i64 48
+  %4 = load <2 x i8>, ptr %s, align 1
+  %5 = load <2 x i8>, ptr %1, align 1
+  %6 = load <2 x i8>, ptr %2, align 1
+  %7 = load <2 x i8>, ptr %3, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_forward_stride2(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_forward_stride2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, -16
+; CHECK-NEXT:    addi a3, a0, -32
+; CHECK-NEXT:    addi a4, a0, -48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a4)
+; CHECK-NEXT:    vle8.v v9, (a3)
+; CHECK-NEXT:    vle8.v v10, (a2)
+; CHECK-NEXT:    vle8.v v11, (a0)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 -16
+  %2 = getelementptr inbounds i8, ptr %s, i64 -32
+  %3 = getelementptr inbounds i8, ptr %s, i64 -48
+  %4 = load <2 x i8>, ptr %3, align 1
+  %5 = load <2 x i8>, ptr %2, align 1
+  %6 = load <2 x i8>, ptr %1, align 1
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_forward_stride3(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_forward_stride3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, 16
+; CHECK-NEXT:    addi a3, a0, 32
+; CHECK-NEXT:    addi a4, a0, 48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 16
+  %2 = getelementptr inbounds i8, ptr %s, i64 32
+  %3 = getelementptr inbounds i8, ptr %s, i64 48
+  %4 = getelementptr inbounds i8, ptr %1, i64 0
+  %5 = getelementptr inbounds i8, ptr %2, i64 0
+  %6 = getelementptr inbounds i8, ptr %3, i64 0
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = load <2 x i8>, ptr %4, align 1
+  %9 = load <2 x i8>, ptr %5, align 1
+  %10 = load <2 x i8>, ptr %6, align 1
+  %11 = shufflevector <2 x i8> %7, <2 x i8> %8, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %12 = shufflevector <2 x i8> %9, <2 x i8> %10, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %13 = shufflevector <4 x i8> %11, <4 x i8> %12, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %13, ptr %d, align 1
+  ret void
+}
+
+define void @constant_back_stride(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_back_stride:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, -16
+; CHECK-NEXT:    addi a3, a0, -32
+; CHECK-NEXT:    addi a4, a0, -48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 -16
+  %2 = getelementptr inbounds i8, ptr %s, i64 -32
+  %3 = getelementptr inbounds i8, ptr %s, i64 -48
+  %4 = load <2 x i8>, ptr %s, align 1
+  %5 = load <2 x i8>, ptr %1, align 1
+  %6 = load <2 x i8>, ptr %2, align 1
+  %7 = load <2 x i8>, ptr %3, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_back_stride2(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_back_stride2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, 16
+; CHECK-NEXT:    addi a3, a0, 32
+; CHECK-NEXT:    addi a4, a0, 48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a4)
+; CHECK-NEXT:    vle8.v v9, (a3)
+; CHECK-NEXT:    vle8.v v10, (a2)
+; CHECK-NEXT:    vle8.v v11, (a0)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 16
+  %2 = getelementptr inbounds i8, ptr %s, i64 32
+  %3 = getelementptr inbounds i8, ptr %s, i64 48
+  %4 = load <2 x i8>, ptr %3, align 1
+  %5 = load <2 x i8>, ptr %2, align 1
+  %6 = load <2 x i8>, ptr %1, align 1
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = shufflevector <2 x i8> %4, <2 x i8> %5, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %9 = shufflevector <2 x i8> %6, <2 x i8> %7, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %10 = shufflevector <4 x i8> %8, <4 x i8> %9, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %10, ptr %d, align 1
+  ret void
+}
+
+define void @constant_back_stride3(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_back_stride3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a2, a0, -16
+; CHECK-NEXT:    addi a3, a0, -32
+; CHECK-NEXT:    addi a4, a0, -48
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vle8.v v9, (a2)
+; CHECK-NEXT:    vle8.v v10, (a3)
+; CHECK-NEXT:    vle8.v v11, (a4)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v9, 2
+; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
+; CHECK-NEXT:    vslideup.vi v8, v10, 4
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vslideup.vi v8, v11, 6
+; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 -16
+  %2 = getelementptr inbounds i8, ptr %s, i64 -32
+  %3 = getelementptr inbounds i8, ptr %s, i64 -48
+  %4 = getelementptr inbounds i8, ptr %1, i64 0
+  %5 = getelementptr inbounds i8, ptr %2, i64 0
+  %6 = getelementptr inbounds i8, ptr %3, i64 0
+  %7 = load <2 x i8>, ptr %s, align 1
+  %8 = load <2 x i8>, ptr %4, align 1
+  %9 = load <2 x i8>, ptr %5, align 1
+  %10 = load <2 x i8>, ptr %6, align 1
+  %11 = shufflevector <2 x i8> %7, <2 x i8> %8, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %12 = shufflevector <2 x i8> %9, <2 x i8> %10, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %13 = shufflevector <4 x i8> %11, <4 x i8> %12, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  store <8 x i8> %13, ptr %d, align 1
+  ret void
+}
+
+define void @constant_zero_stride(ptr %s, ptr %d) {
+; CHECK-LABEL: constant_zero_stride:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
+; CHECK-NEXT:    vle8.v v8, (a0)
+; CHECK-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
+; CHECK-NEXT:    vmv1r.v v9, v8
+; CHECK-NEXT:    vslideup.vi v9, v8, 2
+; CHECK-NEXT:    vse8.v v9, (a1)
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, ptr %s, i64 0
+  %2 = load <2 x i8>, ptr %s, align 1
+  %3 = load <2 x i8>, ptr %1, align 1
+  %4 = shufflevector <2 x i8> %2, <2 x i8> %3, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  store <4 x i8> %4, ptr %d, align 1
+  ret void
+}
+
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; RV32: {{.*}}
+; RV64: {{.*}}

``````````

</details>


https://github.com/llvm/llvm-project/pull/68726

