[llvm] c319c74 - [RISCV] Improve performCONCAT_VECTORCombine stride matching

Michael Maitland via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 16 16:49:26 PDT 2023


Author: Michael Maitland
Date: 2023-10-16T16:45:26-07:00
New Revision: c319c741463a039c2323825b149df70cbe535c67

URL: https://github.com/llvm/llvm-project/commit/c319c741463a039c2323825b149df70cbe535c67
DIFF: https://github.com/llvm/llvm-project/commit/c319c741463a039c2323825b149df70cbe535c67.diff

LOG: [RISCV] Improve performCONCAT_VECTORSCombine stride matching

If the load pointers can be decomposed into a common (Base + Index) with a
common constant stride between consecutive loads, then return that constant
stride.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6eb253cc5146635..4dc3f6137e3061a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -27,6 +27,7 @@
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -13803,9 +13804,17 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
     Align = std::min(Align, Ld->getAlign());
   }
 
-  using PtrDiff = std::pair<SDValue, bool>;
-  auto GetPtrDiff = [](LoadSDNode *Ld1,
-                       LoadSDNode *Ld2) -> std::optional<PtrDiff> {
+  using PtrDiff = std::pair<std::variant<int64_t, SDValue>, bool>;
+  auto GetPtrDiff = [&DAG](LoadSDNode *Ld1,
+                           LoadSDNode *Ld2) -> std::optional<PtrDiff> {
+    // If the load ptrs can be decomposed into a common (Base + Index) with a
+    // common constant stride, then return the constant stride.
+    BaseIndexOffset BIO1 = BaseIndexOffset::match(Ld1, DAG);
+    BaseIndexOffset BIO2 = BaseIndexOffset::match(Ld2, DAG);
+    if (BIO1.equalBaseIndex(BIO2, DAG))
+      return {{BIO2.getOffset() - BIO1.getOffset(), false}};
+
+    // Otherwise try to match (add LastPtr, Stride) or (add NextPtr, Stride)
     SDValue P1 = Ld1->getBasePtr();
     SDValue P2 = Ld2->getBasePtr();
     if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1)
@@ -13844,7 +13853,11 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   if (!TLI.isLegalStridedLoadStore(WideVecVT, Align))
     return SDValue();
 
-  auto [Stride, MustNegateStride] = *BaseDiff;
+  auto [StrideVariant, MustNegateStride] = *BaseDiff;
+  SDValue Stride = std::holds_alternative<SDValue>(StrideVariant)
+                       ? std::get<SDValue>(StrideVariant)
+                       : DAG.getConstant(std::get<int64_t>(StrideVariant), DL,
+                                         Lds[0]->getOffset().getValueType());
   if (MustNegateStride)
     Stride = DAG.getNegative(Stride, DL, Stride.getValueType());
 

diff  --git a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
index 611270ab98ebdaf..ff35043dbd7e75e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
@@ -7,21 +7,10 @@
 define void @constant_forward_stride(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_forward_stride:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, 16
-; CHECK-NEXT:    addi a3, a0, 32
-; CHECK-NEXT:    addi a4, a0, 48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a0)
-; CHECK-NEXT:    vle8.v v9, (a2)
-; CHECK-NEXT:    vle8.v v10, (a3)
-; CHECK-NEXT:    vle8.v v11, (a4)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    li a2, 16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 16
   %2 = getelementptr inbounds i8, ptr %s, i64 32
@@ -40,21 +29,11 @@ define void @constant_forward_stride(ptr %s, ptr %d) {
 define void @constant_forward_stride2(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_forward_stride2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, -16
-; CHECK-NEXT:    addi a3, a0, -32
-; CHECK-NEXT:    addi a4, a0, -48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a4)
-; CHECK-NEXT:    vle8.v v9, (a3)
-; CHECK-NEXT:    vle8.v v10, (a2)
-; CHECK-NEXT:    vle8.v v11, (a0)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    addi a0, a0, -48
+; CHECK-NEXT:    li a2, 16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 -16
   %2 = getelementptr inbounds i8, ptr %s, i64 -32
@@ -73,21 +52,10 @@ define void @constant_forward_stride2(ptr %s, ptr %d) {
 define void @constant_forward_stride3(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_forward_stride3:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, 16
-; CHECK-NEXT:    addi a3, a0, 32
-; CHECK-NEXT:    addi a4, a0, 48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a0)
-; CHECK-NEXT:    vle8.v v9, (a2)
-; CHECK-NEXT:    vle8.v v10, (a3)
-; CHECK-NEXT:    vle8.v v11, (a4)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    li a2, 16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 16
   %2 = getelementptr inbounds i8, ptr %s, i64 32
@@ -109,21 +77,10 @@ define void @constant_forward_stride3(ptr %s, ptr %d) {
 define void @constant_back_stride(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_back_stride:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, -16
-; CHECK-NEXT:    addi a3, a0, -32
-; CHECK-NEXT:    addi a4, a0, -48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a0)
-; CHECK-NEXT:    vle8.v v9, (a2)
-; CHECK-NEXT:    vle8.v v10, (a3)
-; CHECK-NEXT:    vle8.v v11, (a4)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    li a2, -16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 -16
   %2 = getelementptr inbounds i8, ptr %s, i64 -32
@@ -142,21 +99,11 @@ define void @constant_back_stride(ptr %s, ptr %d) {
 define void @constant_back_stride2(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_back_stride2:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, 16
-; CHECK-NEXT:    addi a3, a0, 32
-; CHECK-NEXT:    addi a4, a0, 48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a4)
-; CHECK-NEXT:    vle8.v v9, (a3)
-; CHECK-NEXT:    vle8.v v10, (a2)
-; CHECK-NEXT:    vle8.v v11, (a0)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    addi a0, a0, 48
+; CHECK-NEXT:    li a2, -16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 16
   %2 = getelementptr inbounds i8, ptr %s, i64 32
@@ -175,21 +122,10 @@ define void @constant_back_stride2(ptr %s, ptr %d) {
 define void @constant_back_stride3(ptr %s, ptr %d) {
 ; CHECK-LABEL: constant_back_stride3:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    addi a2, a0, -16
-; CHECK-NEXT:    addi a3, a0, -32
-; CHECK-NEXT:    addi a4, a0, -48
-; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT:    vle8.v v8, (a0)
-; CHECK-NEXT:    vle8.v v9, (a2)
-; CHECK-NEXT:    vle8.v v10, (a3)
-; CHECK-NEXT:    vle8.v v11, (a4)
-; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v9, 2
-; CHECK-NEXT:    vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT:    vslideup.vi v8, v10, 4
-; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT:    vslideup.vi v8, v11, 6
-; CHECK-NEXT:    vse8.v v8, (a1)
+; CHECK-NEXT:    li a2, -16
+; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT:    vlse16.v v8, (a0), a2
+; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, ptr %s, i64 -16
   %2 = getelementptr inbounds i8, ptr %s, i64 -32


        


More information about the llvm-commits mailing list