[llvm] c3b48ec - [RISCV] Match strided loads with reversed indexing sequences
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 22 07:59:57 PDT 2023
Author: Philip Reames
Date: 2023-08-22T07:59:49-07:00
New Revision: c3b48ec6ff92c47cd6136ee25360d31952a3adf2
URL: https://github.com/llvm/llvm-project/commit/c3b48ec6ff92c47cd6136ee25360d31952a3adf2
DIFF: https://github.com/llvm/llvm-project/commit/c3b48ec6ff92c47cd6136ee25360d31952a3adf2.diff
LOG: [RISCV] Match strided loads with reversed indexing sequences
This extends the concat_vector of loads to strided_load transform to handle reversed index pattern. The previous code expected indexing of the form (a0, a1+S, a2+S,...). However, we can also see indexing of the form (a1+S, a2+S, a3+S, .., aS). This form is a strided load starting at address aN + S*(n-1) with stride -S.
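For illustration, here is a minimal standalone sketch of that reversed-pattern check over plain integer addresses (the helper name and the use of raw int64_t addresses are hypothetical; the patch itself matches ISD::ADD nodes on SDValues, as in the diff below):

#include <cstdint>
#include <optional>
#include <vector>

// Ptrs holds the byte addresses of the loads feeding the concat_vectors, in
// operand order. The reversed pattern holds when every pointer equals the
// *next* pointer plus one common offset S; the whole sequence can then be
// loaded starting from Ptrs.front() with stride -S.
std::optional<int64_t> matchReverseStride(const std::vector<int64_t> &Ptrs) {
  std::optional<int64_t> Stride;
  for (size_t I = 0; I + 1 < Ptrs.size(); ++I) {
    int64_t Offset = Ptrs[I] - Ptrs[I + 1];
    if (!Stride)
      Stride = Offset;
    else if (Offset != *Stride)
      return std::nullopt;
  }
  return Stride;
}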
Note that this also fixes what looks to be a bug in the memory location reasoning for the existing forward strided case. A strided load with a negative stride accesses eltsize bytes starting at the base ptr, and then bytes *before* the base ptr. (That is, the accessed range extends from before the base ptr to just past it, not forward from it.)
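As a worked illustration of that range (a hedged sketch, not code from the patch; the helper name and signature are made up):

#include <algorithm>
#include <cstdint>
#include <utility>

// Byte range [Lo, Hi) touched by N loads of EltSize bytes each, starting at
// Base with a constant byte Stride that may be negative. For Stride >= 0 this
// collapses to [Base, Base + EltSize + Stride * (N - 1)), matching the
// existing MemSize formula; for Stride < 0 the range extends *backwards*
// from Base instead.
std::pair<int64_t, int64_t> stridedByteRange(int64_t Base, int64_t Stride,
                                             int64_t EltSize, int64_t N) {
  int64_t First = Base;
  int64_t Last = Base + Stride * (N - 1);
  return {std::min(First, Last), std::max(First, Last) + EltSize};
}

For example, Base = 0, Stride = -64, EltSize = 8, N = 4 gives [-192, 8), which is why the precise MemSize computation below is now restricted to non-negative constant strides in the non-reversed case.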
Differential Revision: https://reviews.llvm.org/D157886
Added:
Modified:
llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 536cbcda6da28d..8eaaa4d06136b2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12894,10 +12894,9 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
SDValue BasePtr = BaseLd->getBasePtr();
// Go through the loads and check that they're strided
- SDValue CurPtr = BasePtr;
- SDValue Stride;
+ SmallVector<SDValue> Ptrs;
+ Ptrs.push_back(BasePtr);
Align Align = BaseLd->getAlign();
-
for (SDValue Op : N->ops().drop_front()) {
auto *Ld = dyn_cast<LoadSDNode>(Op);
if (!Ld || !Ld->isSimple() || !Op.hasOneUse() ||
@@ -12905,27 +12904,66 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
Ld->getValueType(0) != BaseLdVT)
return SDValue();
- SDValue Ptr = Ld->getBasePtr();
- // Check that each load's pointer is (add CurPtr, Stride)
- if (Ptr.getOpcode() != ISD::ADD || Ptr.getOperand(0) != CurPtr)
- return SDValue();
- SDValue Offset = Ptr.getOperand(1);
- if (!Stride)
- Stride = Offset;
- else if (Offset != Stride)
- return SDValue();
+ Ptrs.push_back(Ld->getBasePtr());
// The common alignment is the most restrictive (smallest) of all the loads
Align = std::min(Align, Ld->getAlign());
+ }
- CurPtr = Ptr;
+ auto matchForwardStrided = [](ArrayRef<SDValue> Ptrs) {
+ SDValue Stride;
+ for (auto Idx : enumerate(Ptrs)) {
+ if (Idx.index() == 0)
+ continue;
+ SDValue Ptr = Idx.value();
+ // Check that each load's pointer is (add LastPtr, Stride)
+ if (Ptr.getOpcode() != ISD::ADD ||
+ Ptr.getOperand(0) != Ptrs[Idx.index()-1])
+ return SDValue();
+ SDValue Offset = Ptr.getOperand(1);
+ if (!Stride)
+ Stride = Offset;
+ else if (Offset != Stride)
+ return SDValue();
+ }
+ return Stride;
+ };
+ auto matchReverseStrided = [](ArrayRef<SDValue> Ptrs) {
+ SDValue Stride;
+ for (auto Idx : enumerate(Ptrs)) {
+ if (Idx.index() == Ptrs.size() - 1)
+ continue;
+ SDValue Ptr = Idx.value();
+ // Check that each load's pointer is (add NextPtr, Stride)
+ if (Ptr.getOpcode() != ISD::ADD ||
+ Ptr.getOperand(0) != Ptrs[Idx.index()+1])
+ return SDValue();
+ SDValue Offset = Ptr.getOperand(1);
+ if (!Stride)
+ Stride = Offset;
+ else if (Offset != Stride)
+ return SDValue();
+ }
+ return Stride;
+ };
+
+ bool Reversed = false;
+ SDValue Stride = matchForwardStrided(Ptrs);
+ if (!Stride) {
+ Stride = matchReverseStrided(Ptrs);
+ Reversed = true;
+ // TODO: At this point, we've successfully matched a generalized gather
+ // load. Maybe we should emit that, and then move the specialized
+ // matchers above and below into a DAG combine?
+ if (!Stride)
+ return SDValue();
}
// A special case is if the stride is exactly the width of one of the loads,
// in which case it's contiguous and can be combined into a regular vle
// without changing the element size
if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
- ConstStride &&
+ ConstStride && !Reversed &&
ConstStride->getZExtValue() == BaseLdVT.getFixedSizeInBits() / 8) {
MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
BaseLd->getPointerInfo(), BaseLd->getMemOperand()->getFlags(),
@@ -12962,6 +13000,8 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
SDValue IntID =
DAG.getTargetConstant(Intrinsic::riscv_vlse, DL, Subtarget.getXLenVT());
+ if (Reversed)
+ Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
SDValue Ops[] = {BaseLd->getChain(),
IntID,
DAG.getUNDEF(ContainerVT),
@@ -12970,7 +13010,8 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
VL};
uint64_t MemSize;
- if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride))
+ if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
+ ConstStride && !Reversed && ConstStride->getSExtValue() >= 0)
// total size = (elsize * n) + (stride - elsize) * (n-1)
// = elsize + stride * (n-1)
MemSize = WideScalarVT.getSizeInBits() +
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll
index decc992533b553..f52ba6f51d5c89 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-strided-load-combine.ll
@@ -494,25 +494,15 @@ define void @strided_constant_neg_4xv2f32(ptr %x, ptr %z, i64 %s) {
ret void
}
-; TODO: This is a strided load with a negative stride
+; This is a strided load with a negative stride
define void @reverse_strided_constant_pos_4xv2f32(ptr %x, ptr %z, i64 %s) {
; CHECK-LABEL: reverse_strided_constant_pos_4xv2f32:
; CHECK: # %bb.0:
-; CHECK-NEXT: addi a2, a0, 64
-; CHECK-NEXT: addi a3, a0, 128
-; CHECK-NEXT: addi a4, a0, 192
-; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
-; CHECK-NEXT: vle32.v v8, (a4)
-; CHECK-NEXT: vle32.v v10, (a3)
-; CHECK-NEXT: vle32.v v12, (a2)
-; CHECK-NEXT: vle32.v v14, (a0)
-; CHECK-NEXT: vsetivli zero, 4, e32, m2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v10, 2
-; CHECK-NEXT: vsetivli zero, 6, e32, m2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v12, 4
-; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT: vslideup.vi v8, v14, 6
-; CHECK-NEXT: vse32.v v8, (a1)
+; CHECK-NEXT: addi a0, a0, 192
+; CHECK-NEXT: li a2, -64
+; CHECK-NEXT: vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT: vlse64.v v8, (a0), a2
+; CHECK-NEXT: vse64.v v8, (a1)
; CHECK-NEXT: ret
%x.1 = getelementptr i8, ptr %x, i64 64
%x.2 = getelementptr i8, ptr %x.1, i64 64
@@ -531,21 +521,11 @@ define void @reverse_strided_constant_pos_4xv2f32(ptr %x, ptr %z, i64 %s) {
define void @reverse_strided_constant_neg_4xv2f32(ptr %x, ptr %z, i64 %s) {
; CHECK-LABEL: reverse_strided_constant_neg_4xv2f32:
; CHECK: # %bb.0:
-; CHECK-NEXT: addi a2, a0, -64
-; CHECK-NEXT: addi a3, a0, -128
-; CHECK-NEXT: addi a4, a0, -192
-; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
-; CHECK-NEXT: vle32.v v8, (a4)
-; CHECK-NEXT: vle32.v v10, (a3)
-; CHECK-NEXT: vle32.v v12, (a2)
-; CHECK-NEXT: vle32.v v14, (a0)
-; CHECK-NEXT: vsetivli zero, 4, e32, m2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v10, 2
-; CHECK-NEXT: vsetivli zero, 6, e32, m2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v12, 4
-; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT: vslideup.vi v8, v14, 6
-; CHECK-NEXT: vse32.v v8, (a1)
+; CHECK-NEXT: addi a0, a0, -192
+; CHECK-NEXT: li a2, 64
+; CHECK-NEXT: vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT: vlse64.v v8, (a0), a2
+; CHECK-NEXT: vse64.v v8, (a1)
; CHECK-NEXT: ret
%x.1 = getelementptr i8, ptr %x, i64 -64
%x.2 = getelementptr i8, ptr %x.1, i64 -64
@@ -561,25 +541,17 @@ define void @reverse_strided_constant_neg_4xv2f32(ptr %x, ptr %z, i64 %s) {
ret void
}
-; TODO: This is a strided load with a negative stride
+; This is a strided load with a negative stride
define void @reverse_strided_runtime_4xv2f32(ptr %x, ptr %z, i64 %s) {
; CHECK-LABEL: reverse_strided_runtime_4xv2f32:
; CHECK: # %bb.0:
-; CHECK-NEXT: add a3, a0, a2
-; CHECK-NEXT: add a4, a3, a2
-; CHECK-NEXT: add a2, a4, a2
-; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
-; CHECK-NEXT: vle32.v v8, (a2)
-; CHECK-NEXT: vle32.v v10, (a4)
-; CHECK-NEXT: vle32.v v12, (a3)
-; CHECK-NEXT: vle32.v v14, (a0)
-; CHECK-NEXT: vsetivli zero, 4, e32, m2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v10, 2
-; CHECK-NEXT: vsetivli zero, 6, e32, m2, tu, ma
-; CHECK-NEXT: vslideup.vi v8, v12, 4
-; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
-; CHECK-NEXT: vslideup.vi v8, v14, 6
-; CHECK-NEXT: vse32.v v8, (a1)
+; CHECK-NEXT: add a0, a0, a2
+; CHECK-NEXT: add a3, a2, a2
+; CHECK-NEXT: add a0, a0, a3
+; CHECK-NEXT: neg a2, a2
+; CHECK-NEXT: vsetivli zero, 4, e64, m2, ta, ma
+; CHECK-NEXT: vlse64.v v8, (a0), a2
+; CHECK-NEXT: vse64.v v8, (a1)
; CHECK-NEXT: ret
%x.1 = getelementptr i8, ptr %x, i64 %s
%x.2 = getelementptr i8, ptr %x.1, i64 %s