[llvm] [RISCV] Fold vp.reverse(vp.load(ADDR, MASK)) -> vp.strided.load(ADDR, -1, MASK). (PR #123115)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 15 11:52:42 PST 2025
https://github.com/topperc created https://github.com/llvm/llvm-project/pull/123115
This was extracted from our downstream with only a quick re-review. It was originally written 1.5 years ago so there might be existing helper functions added since then that could simplify it.
>From c0bbaf5171506f4d32cb914d90d216a0571522e2 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 15 Jan 2025 11:44:37 -0800
Subject: [PATCH] [RISCV] Fold vp.reverse(vp.load(ADDR, MASK)) ->
vp.strided.load(ADDR, -1, MASK).
This was extracted from our downstream with only a quick review.
It was originally written 1.5 years ago so there might be existing
helper functions added since then that could simplify it.
Co-authored-by: Brandon Wu <brandon.wu at sifive.com>
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 65 ++++++++++
.../RISCV/rvv/vp-combine-reverse-load.ll | 114 ++++++++++++++++++
2 files changed, 179 insertions(+)
create mode 100644 llvm/test/CodeGen/RISCV/rvv/vp-combine-reverse-load.ll
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b25cb128bce9fb..8b879893b5960b 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16229,6 +16229,69 @@ static SDValue performBITREVERSECombine(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(RISCVISD::BREV8, DL, VT, Src.getOperand(0));
}
+/// Try to fold a vp.reverse of a vp.load into one negative-stride load:
+///   vp.reverse(vp.load(ADDR, MASK)) -> vp.strided.load(ADDR, -1, MASK)
+/// The strided load starts at the last element covered by the original load
+/// and steps backwards by one element per lane.
+///
+/// \param N          The EXPERIMENTAL_VP_REVERSE node being combined.
+/// \param DAG        The DAG used to build the replacement nodes.
+/// \param Subtarget  Supplies XLenVT for the address arithmetic.
+/// \returns The replacement strided load, or an empty SDValue if \p N does
+///          not match the pattern.
+static SDValue performVP_REVERSECombine(SDNode *N, SelectionDAG &DAG,
+                                        const RISCVSubtarget &Subtarget) {
+  // The reversed value must come from a vp.load.
+  auto *VPLoad = dyn_cast<VPLoadSDNode>(N->getOperand(0));
+  if (!VPLoad)
+    return SDValue();
+
+  EVT LoadVT = VPLoad->getValueType(0);
+  SDValue EVL = VPLoad->getVectorLength();
+  // Bail out for non-byte-sized elements (there is no strided load for mask
+  // vectors), when the vp.reverse and the vp.load use different EVLs, or when
+  // the loaded value has other users.
+  if (!LoadVT.getVectorElementType().isByteSized() ||
+      N->getOperand(2) != EVL || !N->getOperand(0).hasOneUse())
+    return SDValue();
+
+  // The outer vp.reverse must be unmasked, i.e. its mask is all ones.
+  if (!isOneOrOneSplat(N->getOperand(1)))
+    return SDValue();
+
+  SDValue LoadMask = VPLoad->getMask();
+  // An all-ones load mask can be reused as is. Otherwise the fold is only
+  // done when the load mask is itself an unmasked vp.reverse with the same
+  // EVL; the strided load then uses that reverse's source as its mask.
+  if (!isOneOrOneSplat(LoadMask)) {
+    if (LoadMask.getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE ||
+        !isOneOrOneSplat(LoadMask.getOperand(1)) ||
+        LoadMask.getOperand(2) != EVL)
+      return SDValue();
+    LoadMask = LoadMask.getOperand(0);
+  }
+
+  // Base = LoadAddr + (EVL - 1) * ElemWidthByte, i.e. the address of the last
+  // element covered by the original load.
+  SDLoc DL(N);
+  MVT XLenVT = Subtarget.getXLenVT();
+  uint64_t ElemWidthByte = LoadVT.getScalarSizeInBits() / 8;
+
+  SDValue Temp1 = DAG.getNode(ISD::SUB, DL, XLenVT, EVL,
+                              DAG.getConstant(1, DL, XLenVT));
+  SDValue Temp2 = DAG.getNode(ISD::MUL, DL, XLenVT, Temp1,
+                              DAG.getConstant(ElemWidthByte, DL, XLenVT));
+  SDValue Base = DAG.getNode(ISD::ADD, DL, XLenVT, VPLoad->getBasePtr(), Temp2);
+  // The stride is negative, so build it with getSignedConstant. getConstant
+  // takes an unsigned 64-bit value, and the wrapped-around result of
+  // "0 - ElemWidthByte" does not fit XLenVT as an unsigned number on RV32.
+  SDValue Stride =
+      DAG.getSignedConstant(-static_cast<int64_t>(ElemWidthByte), DL, XLenVT);
+
+  // Build a fresh memory operand: only the address space, flags and alignment
+  // of the original load are carried over, and the accessed size is given as
+  // LocationSize::beforeOrAfterPointer() relative to the new base.
+  MachineFunction &MF = DAG.getMachineFunction();
+  MachinePointerInfo PtrInfo(VPLoad->getAddressSpace());
+  MachineMemOperand *MMO = MF.getMachineMemOperand(
+      PtrInfo, VPLoad->getMemOperand()->getFlags(),
+      LocationSize::beforeOrAfterPointer(), VPLoad->getAlign());
+
+  SDValue Ret = DAG.getStridedLoadVP(
+      LoadVT, DL, VPLoad->getChain(), Base, Stride, LoadMask, EVL, MMO,
+      VPLoad->isExpandingLoad());
+
+  // Redirect users of the old load's chain result to the new load's chain.
+  DAG.ReplaceAllUsesOfValueWith(SDValue(VPLoad, 1), Ret.getValue(1));
+
+  return Ret;
+}
+
// Convert from one FMA opcode to another based on whether we are negating the
// multiply result and/or the accumulator.
// NOTE: Only supports RVV operations with VL.
@@ -18372,6 +18435,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
}
}
+ case ISD::EXPERIMENTAL_VP_REVERSE:
+ return performVP_REVERSECombine(N, DAG, Subtarget);
case ISD::BITCAST: {
assert(Subtarget.useRVVForFixedLengthVectors());
SDValue N0 = N->getOperand(0);
diff --git a/llvm/test/CodeGen/RISCV/rvv/vp-combine-reverse-load.ll b/llvm/test/CodeGen/RISCV/rvv/vp-combine-reverse-load.ll
new file mode 100644
index 00000000000000..b0604d51f93b31
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vp-combine-reverse-load.ll
@@ -0,0 +1,114 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+f,+v -verify-machineinstrs < %s | FileCheck %s
+
+; Positive test: the vp.load and the vp.reverse both use the %allones mask and
+; the same %evl. The expected output is a single unmasked vlse32.v whose base is
+; %ptr + %evl * 4 - 4 and whose stride register holds -4.
+define <vscale x 2 x float> @test_reverse_load_combiner(<vscale x 2 x float>* %ptr, i32 zeroext %evl) {
+; CHECK-LABEL: test_reverse_load_combiner:
+; CHECK: # %bb.0:
+; CHECK-NEXT: slli a2, a1, 2
+; CHECK-NEXT: add a0, a2, a0
+; CHECK-NEXT: addi a0, a0, -4
+; CHECK-NEXT: li a2, -4
+; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT: vlse32.v v8, (a0), a2
+; CHECK-NEXT: ret
+ %head = insertelement <vscale x 2 x i1> undef, i1 1, i32 0
+ %allones = shufflevector <vscale x 2 x i1> %head, <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer
+
+ %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %allones, i32 %evl)
+ %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> %allones, i32 %evl)
+ ret <vscale x 2 x float> %rev
+}
+
+; Positive test: the vp.load mask is a vp.reverse of %head that uses the
+; %allones mask and the same %evl as the load. The expected output is a masked
+; vlse32.v with stride -4 and no vrgather on either the mask or the data.
+define <vscale x 2 x float> @test_load_mask_is_vp_reverse(<vscale x 2 x float>* %ptr, i32 zeroext %evl) {
+; CHECK-LABEL: test_load_mask_is_vp_reverse:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a2, zero, e8, mf4, ta, ma
+; CHECK-NEXT: vmv.v.i v8, 0
+; CHECK-NEXT: li a2, 1
+; CHECK-NEXT: vsetvli zero, zero, e8, mf4, tu, ma
+; CHECK-NEXT: vmv.s.x v8, a2
+; CHECK-NEXT: slli a2, a1, 2
+; CHECK-NEXT: add a0, a2, a0
+; CHECK-NEXT: li a2, -4
+; CHECK-NEXT: vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT: vand.vi v8, v8, 1
+; CHECK-NEXT: vmsne.vi v0, v8, 0
+; CHECK-NEXT: addi a0, a0, -4
+; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT: vlse32.v v8, (a0), a2, v0.t
+; CHECK-NEXT: ret
+ %head = insertelement <vscale x 2 x i1> undef, i1 1, i32 0
+ %allones = shufflevector <vscale x 2 x i1> %head, <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer
+
+ %loadmask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %head, <vscale x 2 x i1> %allones, i32 %evl)
+ %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %loadmask, i32 %evl)
+ %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> %allones, i32 %evl)
+ ret <vscale x 2 x float> %rev
+}
+
+; Negative test: the vp.load and the vp.reverse both use %notallones as their
+; mask, and the expected output keeps a masked vle32.v followed by a masked
+; vrgather.vv instead of a strided load.
+; NOTE(review): %head sets lane 1, but the shufflevector splats lane 0, so
+; %notallones is a splat of an undef lane (the expected output materializes it
+; with vmclr.m). Confirm that relying on undef here is intended.
+define <vscale x 2 x float> @test_load_mask_not_all_one(<vscale x 2 x float>* %ptr, i32 zeroext %evl) {
+; CHECK-LABEL: test_load_mask_not_all_one:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a2, zero, e8, mf4, ta, ma
+; CHECK-NEXT: vmclr.m v0
+; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT: vle32.v v9, (a0), v0.t
+; CHECK-NEXT: vid.v v8, v0.t
+; CHECK-NEXT: addi a1, a1, -1
+; CHECK-NEXT: vrsub.vx v10, v8, a1, v0.t
+; CHECK-NEXT: vrgather.vv v8, v9, v10, v0.t
+; CHECK-NEXT: ret
+ %head = insertelement <vscale x 2 x i1> undef, i1 1, i32 1
+ %notallones = shufflevector <vscale x 2 x i1> %head, <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer
+
+ %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %notallones, i32 %evl)
+ %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> %notallones, i32 %evl)
+ ret <vscale x 2 x float> %rev
+}
+
+; Negative test, named for mismatched EVLs: the vp.reverse that produces the
+; load mask uses %evl1, while the vp.load and the outer vp.reverse use %evl2.
+; The expected output keeps a masked vle32.v followed by a masked vrgather.vv.
+; NOTE(review): %head sets lane 1, but the shufflevector splats lane 0, so the
+; value named %allones is a splat of an undef lane (the expected output
+; materializes it with vmclr.m) rather than all ones. The fold may therefore be
+; rejected by a mask check instead of by the EVL mismatch this test is named
+; for; consider inserting at index 0 and regenerating the assertions.
+; NOTE(review): "differnet" in the function name looks like a typo for
+; "different".
+define <vscale x 2 x float> @test_differnet_evl(<vscale x 2 x float>* %ptr, i32 zeroext %evl1, i32 zeroext %evl2) {
+; CHECK-LABEL: test_differnet_evl:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 2, e8, mf4, ta, ma
+; CHECK-NEXT: vmv.v.i v9, 1
+; CHECK-NEXT: vsetvli a3, zero, e8, mf4, ta, ma
+; CHECK-NEXT: vmv.v.i v10, 0
+; CHECK-NEXT: vmclr.m v8
+; CHECK-NEXT: addi a3, a1, -1
+; CHECK-NEXT: vsetivli zero, 2, e8, mf4, tu, ma
+; CHECK-NEXT: vslideup.vi v10, v9, 1
+; CHECK-NEXT: vmv1r.v v0, v8
+; CHECK-NEXT: vsetvli zero, a1, e16, mf2, ta, ma
+; CHECK-NEXT: vid.v v9, v0.t
+; CHECK-NEXT: vsetvli a4, zero, e8, mf4, ta, ma
+; CHECK-NEXT: vand.vi v10, v10, 1
+; CHECK-NEXT: vsetvli zero, a1, e16, mf2, ta, ma
+; CHECK-NEXT: vrsub.vx v9, v9, a3, v0.t
+; CHECK-NEXT: vsetvli a3, zero, e8, mf4, ta, ma
+; CHECK-NEXT: vmsne.vi v0, v10, 0
+; CHECK-NEXT: vsetvli zero, a1, e8, mf4, ta, ma
+; CHECK-NEXT: vmv.v.i v10, 0
+; CHECK-NEXT: vmerge.vim v10, v10, 1, v0
+; CHECK-NEXT: vmv1r.v v0, v8
+; CHECK-NEXT: vrgatherei16.vv v11, v10, v9, v0.t
+; CHECK-NEXT: vmsne.vi v0, v11, 0, v0.t
+; CHECK-NEXT: vsetvli zero, a2, e32, m1, ta, ma
+; CHECK-NEXT: vle32.v v9, (a0), v0.t
+; CHECK-NEXT: vmv1r.v v0, v8
+; CHECK-NEXT: vid.v v10, v0.t
+; CHECK-NEXT: addi a2, a2, -1
+; CHECK-NEXT: vrsub.vx v10, v10, a2, v0.t
+; CHECK-NEXT: vrgather.vv v8, v9, v10, v0.t
+; CHECK-NEXT: ret
+ %head = insertelement <vscale x 2 x i1> undef, i1 1, i32 1
+ %allones = shufflevector <vscale x 2 x i1> %head, <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer
+
+ %loadmask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %head, <vscale x 2 x i1> %allones, i32 %evl1)
+ %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* %ptr, <vscale x 2 x i1> %loadmask, i32 %evl2)
+ %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> %allones, i32 %evl2)
+ ret <vscale x 2 x float> %rev
+}
+
+declare <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(<vscale x 2 x float>* nocapture, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1>, <vscale x 2 x i1>, i32)
More information about the llvm-commits
mailing list