[llvm] c319c74 - [RISCV] Improve performCONCAT_VECTORCombine stride matching
Michael Maitland via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 16 16:49:26 PDT 2023
Author: Michael Maitland
Date: 2023-10-16T16:45:26-07:00
New Revision: c319c741463a039c2323825b149df70cbe535c67
URL: https://github.com/llvm/llvm-project/commit/c319c741463a039c2323825b149df70cbe535c67
DIFF: https://github.com/llvm/llvm-project/commit/c319c741463a039c2323825b149df70cbe535c67.diff
LOG: [RISCV] Improve performCONCAT_VECTORCombine stride matching
If the load pointers can be decomposed into a common (Base + Index) whose
offsets differ by a common constant stride, then return that constant stride.
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 6eb253cc5146635..4dc3f6137e3061a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -27,6 +27,7 @@
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/SelectionDAGAddressAnalysis.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
#include "llvm/CodeGen/ValueTypes.h"
#include "llvm/IR/DiagnosticInfo.h"
@@ -13803,9 +13804,17 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
Align = std::min(Align, Ld->getAlign());
}
- using PtrDiff = std::pair<SDValue, bool>;
- auto GetPtrDiff = [](LoadSDNode *Ld1,
- LoadSDNode *Ld2) -> std::optional<PtrDiff> {
+ using PtrDiff = std::pair<std::variant<int64_t, SDValue>, bool>;
+ auto GetPtrDiff = [&DAG](LoadSDNode *Ld1,
+ LoadSDNode *Ld2) -> std::optional<PtrDiff> {
+ // If the load ptrs can be decomposed into a common (Base + Index) with a
+ // common constant stride, then return the constant stride.
+ BaseIndexOffset BIO1 = BaseIndexOffset::match(Ld1, DAG);
+ BaseIndexOffset BIO2 = BaseIndexOffset::match(Ld2, DAG);
+ if (BIO1.equalBaseIndex(BIO2, DAG))
+ return {{BIO2.getOffset() - BIO1.getOffset(), false}};
+
+ // Otherwise try to match (add LastPtr, Stride) or (add NextPtr, Stride)
SDValue P1 = Ld1->getBasePtr();
SDValue P2 = Ld2->getBasePtr();
if (P2.getOpcode() == ISD::ADD && P2.getOperand(0) == P1)
@@ -13844,7 +13853,11 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
if (!TLI.isLegalStridedLoadStore(WideVecVT, Align))
return SDValue();
- auto [Stride, MustNegateStride] = *BaseDiff;
+ auto [StrideVariant, MustNegateStride] = *BaseDiff;
+ SDValue Stride = std::holds_alternative<SDValue>(StrideVariant)
+ ? std::get<SDValue>(StrideVariant)
+ : DAG.getConstant(std::get<int64_t>(StrideVariant), DL,
+ Lds[0]->getOffset().getValueType());
if (MustNegateStride)
Stride = DAG.getNegative(Stride, DL, Stride.getValueType());
diff --git a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
index 611270ab98ebdaf..ff35043dbd7e75e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/concat-vectors-constant-stride.ll
@@ -7,21 +7,10 @@
define void @constant_forward_stride(ptr %s, ptr %d) {
; CHECK-LABEL: constant_forward_stride:
; CHECK: # %bb.0:
-; CHECK-NEXT: addi a2, a0, 16
-; CHECK-NEXT: addi a3, a0, 32
-; CHECK-NEXT: addi a4, a0, 48
-; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT: vle8.v v8, (a0)
-; CHECK-NEXT: vle8.v v9, (a2)
-; CHECK-NEXT: vle8.v v10, (a3)
-; CHECK-NEXT: vle8.v v11, (a4)
-; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v9, 2
-; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v10, 4
-; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT: vslideup.vi v8, v11, 6
-; CHECK-NEXT: vse8.v v8, (a1)
+; CHECK-NEXT: li a2, 16
+; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT: vlse16.v v8, (a0), a2
+; CHECK-NEXT: vse16.v v8, (a1)
; CHECK-NEXT: ret
%1 = getelementptr inbounds i8, ptr %s, i64 16
%2 = getelementptr inbounds i8, ptr %s, i64 32
@@ -40,21 +29,11 @@ define void @constant_forward_stride(ptr %s, ptr %d) {
define void @constant_forward_stride2(ptr %s, ptr %d) {
; CHECK-LABEL: constant_forward_stride2:
; CHECK: # %bb.0:
-; CHECK-NEXT: addi a2, a0, -16
-; CHECK-NEXT: addi a3, a0, -32
-; CHECK-NEXT: addi a4, a0, -48
-; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT: vle8.v v8, (a4)
-; CHECK-NEXT: vle8.v v9, (a3)
-; CHECK-NEXT: vle8.v v10, (a2)
-; CHECK-NEXT: vle8.v v11, (a0)
-; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v9, 2
-; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v10, 4
-; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT: vslideup.vi v8, v11, 6
-; CHECK-NEXT: vse8.v v8, (a1)
+; CHECK-NEXT: addi a0, a0, -48
+; CHECK-NEXT: li a2, 16
+; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT: vlse16.v v8, (a0), a2
+; CHECK-NEXT: vse16.v v8, (a1)
; CHECK-NEXT: ret
%1 = getelementptr inbounds i8, ptr %s, i64 -16
%2 = getelementptr inbounds i8, ptr %s, i64 -32
@@ -73,21 +52,10 @@ define void @constant_forward_stride2(ptr %s, ptr %d) {
define void @constant_forward_stride3(ptr %s, ptr %d) {
; CHECK-LABEL: constant_forward_stride3:
; CHECK: # %bb.0:
-; CHECK-NEXT: addi a2, a0, 16
-; CHECK-NEXT: addi a3, a0, 32
-; CHECK-NEXT: addi a4, a0, 48
-; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT: vle8.v v8, (a0)
-; CHECK-NEXT: vle8.v v9, (a2)
-; CHECK-NEXT: vle8.v v10, (a3)
-; CHECK-NEXT: vle8.v v11, (a4)
-; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v9, 2
-; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v10, 4
-; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT: vslideup.vi v8, v11, 6
-; CHECK-NEXT: vse8.v v8, (a1)
+; CHECK-NEXT: li a2, 16
+; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT: vlse16.v v8, (a0), a2
+; CHECK-NEXT: vse16.v v8, (a1)
; CHECK-NEXT: ret
%1 = getelementptr inbounds i8, ptr %s, i64 16
%2 = getelementptr inbounds i8, ptr %s, i64 32
@@ -109,21 +77,10 @@ define void @constant_forward_stride3(ptr %s, ptr %d) {
define void @constant_back_stride(ptr %s, ptr %d) {
; CHECK-LABEL: constant_back_stride:
; CHECK: # %bb.0:
-; CHECK-NEXT: addi a2, a0, -16
-; CHECK-NEXT: addi a3, a0, -32
-; CHECK-NEXT: addi a4, a0, -48
-; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT: vle8.v v8, (a0)
-; CHECK-NEXT: vle8.v v9, (a2)
-; CHECK-NEXT: vle8.v v10, (a3)
-; CHECK-NEXT: vle8.v v11, (a4)
-; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v9, 2
-; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v10, 4
-; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT: vslideup.vi v8, v11, 6
-; CHECK-NEXT: vse8.v v8, (a1)
+; CHECK-NEXT: li a2, -16
+; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT: vlse16.v v8, (a0), a2
+; CHECK-NEXT: vse16.v v8, (a1)
; CHECK-NEXT: ret
%1 = getelementptr inbounds i8, ptr %s, i64 -16
%2 = getelementptr inbounds i8, ptr %s, i64 -32
@@ -142,21 +99,11 @@ define void @constant_back_stride(ptr %s, ptr %d) {
define void @constant_back_stride2(ptr %s, ptr %d) {
; CHECK-LABEL: constant_back_stride2:
; CHECK: # %bb.0:
-; CHECK-NEXT: addi a2, a0, 16
-; CHECK-NEXT: addi a3, a0, 32
-; CHECK-NEXT: addi a4, a0, 48
-; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT: vle8.v v8, (a4)
-; CHECK-NEXT: vle8.v v9, (a3)
-; CHECK-NEXT: vle8.v v10, (a2)
-; CHECK-NEXT: vle8.v v11, (a0)
-; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v9, 2
-; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v10, 4
-; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT: vslideup.vi v8, v11, 6
-; CHECK-NEXT: vse8.v v8, (a1)
+; CHECK-NEXT: addi a0, a0, 48
+; CHECK-NEXT: li a2, -16
+; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT: vlse16.v v8, (a0), a2
+; CHECK-NEXT: vse16.v v8, (a1)
; CHECK-NEXT: ret
%1 = getelementptr inbounds i8, ptr %s, i64 16
%2 = getelementptr inbounds i8, ptr %s, i64 32
@@ -175,21 +122,10 @@ define void @constant_back_stride2(ptr %s, ptr %d) {
define void @constant_back_stride3(ptr %s, ptr %d) {
; CHECK-LABEL: constant_back_stride3:
; CHECK: # %bb.0:
-; CHECK-NEXT: addi a2, a0, -16
-; CHECK-NEXT: addi a3, a0, -32
-; CHECK-NEXT: addi a4, a0, -48
-; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
-; CHECK-NEXT: vle8.v v8, (a0)
-; CHECK-NEXT: vle8.v v9, (a2)
-; CHECK-NEXT: vle8.v v10, (a3)
-; CHECK-NEXT: vle8.v v11, (a4)
-; CHECK-NEXT: vsetivli zero, 4, e8, mf2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v9, 2
-; CHECK-NEXT: vsetivli zero, 6, e8, mf2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v10, 4
-; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
-; CHECK-NEXT: vslideup.vi v8, v11, 6
-; CHECK-NEXT: vse8.v v8, (a1)
+; CHECK-NEXT: li a2, -16
+; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
+; CHECK-NEXT: vlse16.v v8, (a0), a2
+; CHECK-NEXT: vse16.v v8, (a1)
; CHECK-NEXT: ret
%1 = getelementptr inbounds i8, ptr %s, i64 -16
%2 = getelementptr inbounds i8, ptr %s, i64 -32
More information about the llvm-commits
mailing list