[llvm] [RISCV] Fold vp.reverse(vp.load(ADDR, MASK)) -> vp.strided.load(ADDR, -1, MASK). (PR #123115)

Pengcheng Wang via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 15 19:51:19 PST 2025


================
@@ -16229,6 +16229,69 @@ static SDValue performBITREVERSECombine(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(RISCVISD::BREV8, DL, VT, Src.getOperand(0));
 }
 
+static SDValue performVP_REVERSECombine(SDNode *N, SelectionDAG &DAG,
+                                        const RISCVSubtarget &Subtarget) {
+  // Fold:
+  //    vp.reverse(vp.load(ADDR, MASK)) -> vp.strided.load(ADDR, -1, MASK)
+
+  // Check if its first operand is a vp.load.
+  auto *VPLoad = dyn_cast<VPLoadSDNode>(N->getOperand(0));
+  if (!VPLoad)
+    return SDValue();
+
+  EVT LoadVT = VPLoad->getValueType(0);
+  // We do not have a strided_load version for masks, and the evl of vp.reverse
+  // and vp.load should always be the same.
+  if (!LoadVT.getVectorElementType().isByteSized() ||
+      N->getOperand(2) != VPLoad->getVectorLength() ||
+      !N->getOperand(0).hasOneUse())
+    return SDValue();
+
+  // Check if the mask of the outer vp.reverse is all 1's.
+  if (!isOneOrOneSplat(N->getOperand(1)))
+    return SDValue();
+
+  SDValue LoadMask = VPLoad->getMask();
+  // If Mask is not all 1's, try to replace the mask if its opcode
+  // is EXPERIMENTAL_VP_REVERSE and its operand can be directly extracted.
+  if (!isOneOrOneSplat(LoadMask)) {
+    // Check that the vp.reverse producing the load mask has an all-1's mask
+    // and the same evl as the vp.load.
+    if (LoadMask.getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE ||
+        !isOneOrOneSplat(LoadMask.getOperand(1)) ||
+        LoadMask.getOperand(2) != VPLoad->getVectorLength())
+      return SDValue();
+    LoadMask = LoadMask.getOperand(0);
+  }
+
+  // Base = LoadAddr + (NumElem - 1) * ElemWidthByte
+  SDLoc DL(N);
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue NumElem = VPLoad->getVectorLength();
+  uint64_t ElemWidthByte = VPLoad->getValueType(0).getScalarSizeInBits() / 8;
+
+  SDValue Temp1 = DAG.getNode(ISD::SUB, DL, XLenVT, NumElem,
+                              DAG.getConstant(1, DL, XLenVT));
+  SDValue Temp2 = DAG.getNode(ISD::MUL, DL, XLenVT, Temp1,
+                              DAG.getConstant(ElemWidthByte, DL, XLenVT));
+  SDValue Base = DAG.getNode(ISD::ADD, DL, XLenVT, VPLoad->getBasePtr(), Temp2);
+  SDValue Stride = DAG.getConstant(0 - ElemWidthByte, DL, XLenVT);
----------------
wangpc-pp wrote:

No need for the `0` in `0 - ElemWidthByte`.

https://github.com/llvm/llvm-project/pull/123115


More information about the llvm-commits mailing list