[llvm] [RISCV] Fold vp.store(vp.reverse(VAL), ADDR, MASK) -> vp.strided.store(VAL, NEW_ADDR, -1, MASK) (PR #123123)

Pengcheng Wang via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 15 19:45:48 PST 2025


================
@@ -16229,6 +16233,66 @@ static SDValue performBITREVERSECombine(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(RISCVISD::BREV8, DL, VT, Src.getOperand(0));
 }
 
+static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
+                                      const RISCVSubtarget &Subtarget) {
+  // Fold:
+  //    vp.store(vp.reverse(VAL), ADDR, MASK) -> vp.strided.store(VAL, NEW_ADDR,
+  //    -1, MASK)
+  auto *VPStore = cast<VPStoreSDNode>(N);
+
+  if (VPStore->getValue().getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE)
+    return SDValue();
+
+  SDValue VPReverse = VPStore->getValue();
+  EVT ReverseVT = VPReverse->getValueType(0);
+
+  // We do not have a strided_store version for masks, and the evl of vp.reverse
+  // and vp.store should always be the same.
+  if (!ReverseVT.getVectorElementType().isByteSized() ||
+      VPStore->getVectorLength() != VPReverse.getOperand(2) ||
+      !VPReverse.hasOneUse())
+    return SDValue();
+
+  SDValue StoreMask = VPStore->getMask();
+  // If Mask is not all 1's, try to replace the mask if its opcode
+  // is EXPERIMENTAL_VP_REVERSE and its operand can be directly extracted.
+  if (!isOneOrOneSplat(StoreMask)) {
+    // Check that the mask of the vp.reverse producing the store mask is all
+    // 1's and that its evl is the same as the evl of vp.store.
+    if (StoreMask.getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE ||
+        !isOneOrOneSplat(StoreMask.getOperand(1)) ||
+        StoreMask.getOperand(2) != VPStore->getVectorLength())
+      return SDValue();
+    StoreMask = StoreMask.getOperand(0);
+  }
+
+  // Base = StoreAddr + (NumElem - 1) * ElemWidthByte
+  SDLoc DL(N);
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue NumElem = VPStore->getVectorLength();
+  uint64_t ElemWidthByte = VPReverse.getValueType().getScalarSizeInBits() / 8;
+
+  SDValue Temp1 = DAG.getNode(ISD::SUB, DL, XLenVT, NumElem,
+                              DAG.getConstant(1, DL, XLenVT));
+  SDValue Temp2 = DAG.getNode(ISD::MUL, DL, XLenVT, Temp1,
+                              DAG.getConstant(ElemWidthByte, DL, XLenVT));
+  SDValue Base =
+      DAG.getNode(ISD::ADD, DL, XLenVT, VPStore->getBasePtr(), Temp2);
+  SDValue Stride = DAG.getConstant(0 - ElemWidthByte, DL, XLenVT);
----------------
wangpc-pp wrote:

I don't think the leading `0` is needed here (i.e. in `0 - ElemWidthByte`).

https://github.com/llvm/llvm-project/pull/123123


More information about the llvm-commits mailing list