[llvm] [RISCV] Fold vector shift of sext/zext to widening multiply (PR #121563)

Piotr Fusik via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 13 04:56:11 PST 2025


================
@@ -17341,6 +17341,95 @@ static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
   return DAG.getZExtOrTrunc(Pop, DL, VT);
 }
 
+static SDValue combineSHL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                          const RISCVSubtarget &Subtarget) {
+  // (shl (zext x), y) -> (vwsll   x, y)
+  if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
+    return V;
+
+  // (shl (sext x), C) -> (vwmulsu x, 1u << C)
+  // (shl (zext x), C) -> (vwmulu  x, 1u << C)
+
+  if (!DCI.isAfterLegalizeDAG())
+    return SDValue();
+
+  SDValue LHS = N->getOperand(0);
+  if (!LHS.hasOneUse())
+    return SDValue();
+  unsigned Opcode;
+  switch (LHS.getOpcode()) {
+  case ISD::SIGN_EXTEND:
+  case RISCVISD::VSEXT_VL:
+    Opcode = RISCVISD::VWMULSU_VL;
+    break;
+  case ISD::ZERO_EXTEND:
+  case RISCVISD::VZEXT_VL:
+    Opcode = RISCVISD::VWMULU_VL;
+    break;
+  default:
+    return SDValue();
+  }
+
+  SDValue RHS = N->getOperand(1);
+  APInt ShAmt;
+  uint64_t ShAmtInt;
+  if (ISD::isConstantSplatVector(RHS.getNode(), ShAmt))
+    ShAmtInt = ShAmt.getZExtValue();
+  else if (RHS.getOpcode() == RISCVISD::VMV_V_X_VL &&
+           RHS.getOperand(1).getOpcode() == ISD::Constant)
+    ShAmtInt = RHS.getConstantOperandVal(1);
+  else
+    return SDValue();
+
+  // Better foldings:
+  // (shl (sext x), 1) -> (vwadd  x, x)
+  // (shl (zext x), 1) -> (vwaddu x, x)
+  if (ShAmtInt <= 1)
+    return SDValue();
+
+  SDValue NarrowOp = LHS.getOperand(0);
+  MVT NarrowVT = NarrowOp.getSimpleValueType();
+  uint64_t NarrowBits = NarrowVT.getScalarSizeInBits();
+  if (ShAmtInt >= NarrowBits)
+    return SDValue();
+  MVT VT = N->getSimpleValueType(0);
+  if (NarrowBits * 2 != VT.getScalarSizeInBits())
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  MVT NarrowContainerVT = NarrowVT;
+  MVT ContainerVT = VT;
+  SDLoc DL(N);
+  SDValue Passthru, Mask, VL;
+  switch (N->getOpcode()) {
+  case ISD::SHL:
+    if (VT.isFixedLengthVector()) {
+      NarrowContainerVT =
+          getContainerForFixedLengthVector(DAG, NarrowVT, Subtarget);
+      NarrowOp =
+          convertToScalableVector(NarrowContainerVT, NarrowOp, DAG, Subtarget);
+      ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+    }
+    Passthru = DAG.getUNDEF(VT);
----------------
pfusik wrote:

You are right. This code was dead. Removed.

https://github.com/llvm/llvm-project/pull/121563


More information about the llvm-commits mailing list