[llvm] [RISCV] Fold vp.reverse(vp.load(ADDR, MASK)) -> vp.strided.load(ADDR, -1, MASK). (PR #123115)
Pengcheng Wang via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 15 19:51:19 PST 2025
================
@@ -16229,6 +16229,69 @@ static SDValue performBITREVERSECombine(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(RISCVISD::BREV8, DL, VT, Src.getOperand(0));
}
+static SDValue performVP_REVERSECombine(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ // Fold:
+ // vp.reverse(vp.load(ADDR, MASK)) -> vp.strided.load(ADDR, -1, MASK)
+
+ // Check if its first operand is a vp.load.
+ auto *VPLoad = dyn_cast<VPLoadSDNode>(N->getOperand(0));
+ if (!VPLoad)
+ return SDValue();
+
+ EVT LoadVT = VPLoad->getValueType(0);
+ // We do not have a strided_load version for masks, and the evl of vp.reverse
+ // and vp.load should always be the same.
+ if (!LoadVT.getVectorElementType().isByteSized() ||
+ N->getOperand(2) != VPLoad->getVectorLength() ||
+ !N->getOperand(0).hasOneUse())
+ return SDValue();
+
+ // Check if the mask of outer vp.reverse are all 1's.
+ if (!isOneOrOneSplat(N->getOperand(1)))
+ return SDValue();
+
+ SDValue LoadMask = VPLoad->getMask();
+ // If Mask is not all 1's, try to replace the mask if its opcode
+ // is EXPERIMENTAL_VP_REVERSE and its operand can be directly extracted.
+ if (!isOneOrOneSplat(LoadMask)) {
+ // Check if the mask of vp.reverse in vp.load are all 1's and
+ // the length of mask is same as evl.
+ if (LoadMask.getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE ||
+ !isOneOrOneSplat(LoadMask.getOperand(1)) ||
+ LoadMask.getOperand(2) != VPLoad->getVectorLength())
+ return SDValue();
+ LoadMask = LoadMask.getOperand(0);
+ }
+
+ // Base = LoadAddr + (NumElem - 1) * ElemWidthByte
+ SDLoc DL(N);
+ MVT XLenVT = Subtarget.getXLenVT();
+ SDValue NumElem = VPLoad->getVectorLength();
+ uint64_t ElemWidthByte = VPLoad->getValueType(0).getScalarSizeInBits() / 8;
+
+ SDValue Temp1 = DAG.getNode(ISD::SUB, DL, XLenVT, NumElem,
+ DAG.getConstant(1, DL, XLenVT));
+ SDValue Temp2 = DAG.getNode(ISD::MUL, DL, XLenVT, Temp1,
+ DAG.getConstant(ElemWidthByte, DL, XLenVT));
+ SDValue Base = DAG.getNode(ISD::ADD, DL, XLenVT, VPLoad->getBasePtr(), Temp2);
+ SDValue Stride = DAG.getConstant(0 - ElemWidthByte, DL, XLenVT);
----------------
wangpc-pp wrote:
No need for the `0 -` in `0 - ElemWidthByte`; writing `-ElemWidthByte` is sufficient.
https://github.com/llvm/llvm-project/pull/123115
More information about the llvm-commits mailing list