[llvm] [RISCV] Fold vector shift of sext/zext to widening multiply (PR #121563)

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 3 05:41:14 PST 2025


================
@@ -17341,6 +17341,78 @@ static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
   return DAG.getZExtOrTrunc(Pop, DL, VT);
 }
 
+static SDValue combineSHL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                          const RISCVSubtarget &Subtarget) {
+  if (DCI.isBeforeLegalize())
+    return SDValue();
+
+  // (shl (zext x), y) -> (vwsll   x, y)
+  if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
+    return V;
+
+  // (shl (sext x), C) -> (vwmulsu x, 1u << C)
+  // (shl (zext x), C) -> (vwmulu  x, 1u << C)
+
+  SDValue LHS = N->getOperand(0);
+  if (!LHS.hasOneUse())
+    return SDValue();
+  unsigned Opcode;
+  switch (LHS.getOpcode()) {
+  case ISD::SIGN_EXTEND:
+    Opcode = RISCVISD::VWMULSU_VL;
+    break;
+  case ISD::ZERO_EXTEND:
+    Opcode = RISCVISD::VWMULU_VL;
+    break;
+  default:
+    return SDValue();
+  }
+
+  SDValue RHS = N->getOperand(1);
+  APInt ShAmt;
+  if (!ISD::isConstantSplatVector(RHS.getNode(), ShAmt))
+    return SDValue();
+
+  // Better foldings:
+  // (shl (sext x), 1) -> (vwadd  x, x)
+  // (shl (zext x), 1) -> (vwaddu x, x)
+  uint64_t ShAmtInt = ShAmt.getZExtValue();
+  if (ShAmtInt <= 1)
+    return SDValue();
+
+  SDValue NarrowOp = LHS.getOperand(0);
+  EVT NarrowVT = NarrowOp.getValueType();
+  uint64_t NarrowBits = NarrowVT.getScalarSizeInBits();
+  if (ShAmtInt >= NarrowBits)
+    return SDValue();
+  EVT VT = N->getValueType(0);
+  if (NarrowBits * 2 != VT.getScalarSizeInBits())
+    return SDValue();
+
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(N);
+  SDValue Passthru, Mask, VL;
+  switch (N->getOpcode()) {
+  case ISD::SHL:
+    if (!VT.isScalableVector())
+      return SDValue();
----------------
lukel97 wrote:

It might be worthwhile to leave a TODO to handle fixed-length vectors later. Handling them would require the `ContainerVT = getContainerForFixedLengthVector(...)` pattern that we use elsewhere.

https://github.com/llvm/llvm-project/pull/121563


More information about the llvm-commits mailing list