[llvm] [RISCV] Fold vp.store(vp.reverse(VAL), ADDR, MASK) -> vp.strided.store(VAL, NEW_ADDR, -1, MASK) (PR #123123)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 15 13:50:02 PST 2025
https://github.com/topperc created https://github.com/llvm/llvm-project/pull/123123
This was extracted from our downstream with only a quick re-review. It was originally written 1.5 years ago, so there may be helper functions added since then that could simplify it.
Co-authored-by: Brandon Wu <brandon.wu at sifive.com>
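To illustrate the fold, this is the IR shape it matches, lifted from the first test case added in this patch:

  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> splat (i1 true), i32 %evl)
  call void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float> %rev, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> splat (i1 true), i32 %evl)

With the combine, the store of the reversed value becomes a single strided store of %val with base %ptr + (%evl - 1) * 4 and stride -4, which selects to the vsse32.v checked in the new test.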
From 8367247e1906a53ad9ca1667cf54423a61b9cb23 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 15 Jan 2025 13:15:55 -0800
Subject: [PATCH] [RISCV] Fold vp.store(vp.reverse(VAL), ADDR, MASK) ->
vp.strided.store(VAL, NEW_ADDR, -1, MASK)
This was extracted from our downstream with only a quick re-review. It was originally written 1.5 years ago, so there may be helper functions added since then that could simplify it.
Co-authored-by: Brandon Wu <brandon.wu at sifive.com>
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 78 ++++++++++++++++--
.../RISCV/rvv/vp-combine-store-reverse.ll | 81 +++++++++++++++++++
2 files changed, 153 insertions(+), 6 deletions(-)
create mode 100644 llvm/test/CodeGen/RISCV/rvv/vp-combine-store-reverse.ll
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b25cb128bce9fb..fc2f7781eae443 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1524,13 +1524,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setTargetDAGCombine({ISD::ZERO_EXTEND, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
if (Subtarget.hasVInstructions())
- setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
- ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL,
- ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR,
+ setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER,
+ ISD::MSCATTER, ISD::VP_GATHER,
+ ISD::VP_SCATTER, ISD::SRA,
+ ISD::SRL, ISD::SHL,
+ ISD::STORE, ISD::SPLAT_VECTOR,
ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
- ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
- ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM,
- ISD::INSERT_VECTOR_ELT, ISD::ABS, ISD::CTPOP});
+ ISD::VP_STORE, ISD::EXPERIMENTAL_VP_REVERSE,
+ ISD::MUL, ISD::SDIV,
+ ISD::UDIV, ISD::SREM,
+ ISD::UREM, ISD::INSERT_VECTOR_ELT,
+ ISD::ABS, ISD::CTPOP});
if (Subtarget.hasVendorXTHeadMemPair())
setTargetDAGCombine({ISD::LOAD, ISD::STORE});
if (Subtarget.useRVVForFixedLengthVectors())
@@ -16229,6 +16233,66 @@ static SDValue performBITREVERSECombine(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(RISCVISD::BREV8, DL, VT, Src.getOperand(0));
}
+static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ // Fold:
+ // vp.store(vp.reverse(VAL), ADDR, MASK) -> vp.strided.store(VAL, NEW_ADDR,
+ // -1, MASK)
+ auto *VPStore = cast<VPStoreSDNode>(N);
+
+ if (VPStore->getValue().getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE)
+ return SDValue();
+
+ SDValue VPReverse = VPStore->getValue();
+ EVT ReverseVT = VPReverse->getValueType(0);
+
+ // We do not have a strided_store version for masks, and the evl of vp.reverse
+ // and vp.store should always be the same.
+ if (!ReverseVT.getVectorElementType().isByteSized() ||
+ VPStore->getVectorLength() != VPReverse.getOperand(2) ||
+ !VPReverse.hasOneUse())
+ return SDValue();
+
+ SDValue StoreMask = VPStore->getMask();
+ // If Mask is not all ones, try to replace it: if its opcode is
+ // EXPERIMENTAL_VP_REVERSE, its operand can be used directly.
+ if (!isOneOrOneSplat(StoreMask)) {
+ // Check that the mask of the vp.reverse inside the vp.store is all
+ // ones and that its EVL matches the store's EVL.
+ if (StoreMask.getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE ||
+ !isOneOrOneSplat(StoreMask.getOperand(1)) ||
+ StoreMask.getOperand(2) != VPStore->getVectorLength())
+ return SDValue();
+ StoreMask = StoreMask.getOperand(0);
+ }
+
+ // Base = StoreAddr + (NumElem - 1) * ElemWidthByte
+ SDLoc DL(N);
+ MVT XLenVT = Subtarget.getXLenVT();
+ SDValue NumElem = VPStore->getVectorLength();
+ uint64_t ElemWidthByte = VPReverse.getValueType().getScalarSizeInBits() / 8;
+
+ SDValue Temp1 = DAG.getNode(ISD::SUB, DL, XLenVT, NumElem,
+ DAG.getConstant(1, DL, XLenVT));
+ SDValue Temp2 = DAG.getNode(ISD::MUL, DL, XLenVT, Temp1,
+ DAG.getConstant(ElemWidthByte, DL, XLenVT));
+ SDValue Base =
+ DAG.getNode(ISD::ADD, DL, XLenVT, VPStore->getBasePtr(), Temp2);
+ SDValue Stride = DAG.getConstant(0 - ElemWidthByte, DL, XLenVT);
+
+ MachineFunction &MF = DAG.getMachineFunction();
+ MachinePointerInfo PtrInfo(VPStore->getAddressSpace());
+ MachineMemOperand *MMO = MF.getMachineMemOperand(
+ PtrInfo, VPStore->getMemOperand()->getFlags(),
+ LocationSize::beforeOrAfterPointer(), VPStore->getAlign());
+
+ return DAG.getStridedStoreVP(
+ VPStore->getChain(), DL, VPReverse.getOperand(0), Base,
+ VPStore->getOffset(), Stride, StoreMask, VPStore->getVectorLength(),
+ VPStore->getMemoryVT(), MMO, VPStore->getAddressingMode(),
+ VPStore->isTruncatingStore(), VPStore->isCompressingStore());
+}
+
// Convert from one FMA opcode to another based on whether we are negating the
// multiply result and/or the accumulator.
// NOTE: Only supports RVV operations with VL.
@@ -18372,6 +18436,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
}
}
}
+ case ISD::VP_STORE:
+ return performVP_STORECombine(N, DAG, Subtarget);
case ISD::BITCAST: {
assert(Subtarget.useRVVForFixedLengthVectors());
SDValue N0 = N->getOperand(0);
diff --git a/llvm/test/CodeGen/RISCV/rvv/vp-combine-store-reverse.ll b/llvm/test/CodeGen/RISCV/rvv/vp-combine-store-reverse.ll
new file mode 100644
index 00000000000000..4896a1367935ac
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vp-combine-store-reverse.ll
@@ -0,0 +1,81 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+f,+v -verify-machineinstrs < %s | FileCheck %s
+
+define void @test_store_reverse_combiner(<vscale x 2 x float> %val, <vscale x 2 x float>* %ptr, i32 zeroext %evl) {
+; CHECK-LABEL: test_store_reverse_combiner:
+; CHECK: # %bb.0:
+; CHECK-NEXT: slli a2, a1, 2
+; CHECK-NEXT: add a0, a2, a0
+; CHECK-NEXT: addi a0, a0, -4
+; CHECK-NEXT: li a2, -4
+; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT: vsse32.v v8, (a0), a2
+; CHECK-NEXT: ret
+ %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+ call void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float> %rev, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+ ret void
+}
+
+define void @test_store_mask_is_vp_reverse(<vscale x 2 x float> %val, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %mask, i32 zeroext %evl) {
+; CHECK-LABEL: test_store_mask_is_vp_reverse:
+; CHECK: # %bb.0:
+; CHECK-NEXT: slli a2, a1, 2
+; CHECK-NEXT: add a0, a2, a0
+; CHECK-NEXT: addi a0, a0, -4
+; CHECK-NEXT: li a2, -4
+; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT: vsse32.v v8, (a0), a2, v0.t
+; CHECK-NEXT: ret
+ %storemask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+ %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+ call void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float> %rev, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %storemask, i32 %evl)
+ ret void
+}
+
+define void @test_store_mask_not_all_one(<vscale x 2 x float> %val, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %notallones, i32 zeroext %evl) {
+; CHECK-LABEL: test_store_mask_not_all_one:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT: vid.v v9, v0.t
+; CHECK-NEXT: addi a1, a1, -1
+; CHECK-NEXT: vrsub.vx v9, v9, a1, v0.t
+; CHECK-NEXT: vrgather.vv v10, v8, v9, v0.t
+; CHECK-NEXT: vse32.v v10, (a0), v0.t
+; CHECK-NEXT: ret
+ %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> %notallones, i32 %evl)
+ call void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float> %rev, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %notallones, i32 %evl)
+ ret void
+}
+
+define void @test_different_evl(<vscale x 2 x float> %val, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %mask, i32 zeroext %evl1, i32 zeroext %evl2) {
+; CHECK-LABEL: test_different_evl:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a1, e16, mf2, ta, ma
+; CHECK-NEXT: vid.v v9
+; CHECK-NEXT: addi a1, a1, -1
+; CHECK-NEXT: vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT: vmv.v.i v10, 0
+; CHECK-NEXT: vmerge.vim v10, v10, 1, v0
+; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT: vid.v v11
+; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, ma
+; CHECK-NEXT: vrsub.vx v9, v9, a1
+; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT: vrsub.vx v11, v11, a1
+; CHECK-NEXT: vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT: vrgatherei16.vv v12, v10, v9
+; CHECK-NEXT: vmsne.vi v0, v12, 0
+; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT: vrgather.vv v9, v8, v11
+; CHECK-NEXT: vsetvli zero, a2, e32, m1, ta, ma
+; CHECK-NEXT: vse32.v v9, (a0), v0.t
+; CHECK-NEXT: ret
+ %storemask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> splat (i1 true), i32 %evl1)
+ %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> splat (i1 true), i32 %evl1)
+ call void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float> %rev, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %storemask, i32 %evl2)
+ ret void
+}
+
+declare <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1>, <vscale x 2 x i1>, i32)
+declare void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float>, <vscale x 2 x float>* nocapture, <vscale x 2 x i1>, i32)
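Working through the address math for the first test above (4-byte f32 elements): the combine computes Base = %ptr + (%evl - 1) * 4 and Stride = -4, which corresponds to the slli/add/addi -4 sequence and the li a2, -4 in the CHECK lines.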