[llvm] [RISCV] Fold vector shift of sext/zext to widening multiply (PR #121563)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Mon Jan 13 04:22:03 PST 2025
================
@@ -17341,6 +17341,95 @@ static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
return DAG.getZExtOrTrunc(Pop, DL, VT);
}
+static SDValue combineSHL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ const RISCVSubtarget &Subtarget) {
+ // (shl (zext x), y) -> (vwsll x, y)
+ if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
+ return V;
+
+ // (shl (sext x), C) -> (vwmulsu x, 1u << C)
+ // (shl (zext x), C) -> (vwmulu x, 1u << C)
+
+ if (!DCI.isAfterLegalizeDAG())
+ return SDValue();
+
+ SDValue LHS = N->getOperand(0);
+ if (!LHS.hasOneUse())
+ return SDValue();
+ unsigned Opcode;
+ switch (LHS.getOpcode()) {
+ case ISD::SIGN_EXTEND:
+ case RISCVISD::VSEXT_VL:
+ Opcode = RISCVISD::VWMULSU_VL;
+ break;
+ case ISD::ZERO_EXTEND:
+ case RISCVISD::VZEXT_VL:
+ Opcode = RISCVISD::VWMULU_VL;
+ break;
+ default:
+ return SDValue();
+ }
+
+ SDValue RHS = N->getOperand(1);
+ APInt ShAmt;
+ uint64_t ShAmtInt;
+ if (ISD::isConstantSplatVector(RHS.getNode(), ShAmt))
+ ShAmtInt = ShAmt.getZExtValue();
+ else if (RHS.getOpcode() == RISCVISD::VMV_V_X_VL &&
+ RHS.getOperand(1).getOpcode() == ISD::Constant)
+ ShAmtInt = RHS.getConstantOperandVal(1);
+ else
+ return SDValue();
+
+ // Better foldings:
+ // (shl (sext x), 1) -> (vwadd x, x)
+ // (shl (zext x), 1) -> (vwaddu x, x)
+ if (ShAmtInt <= 1)
+ return SDValue();
+
+ SDValue NarrowOp = LHS.getOperand(0);
+ MVT NarrowVT = NarrowOp.getSimpleValueType();
+ uint64_t NarrowBits = NarrowVT.getScalarSizeInBits();
+ if (ShAmtInt >= NarrowBits)
+ return SDValue();
+ MVT VT = N->getSimpleValueType(0);
+ if (NarrowBits * 2 != VT.getScalarSizeInBits())
+ return SDValue();
+
+ SelectionDAG &DAG = DCI.DAG;
+ MVT NarrowContainerVT = NarrowVT;
+ MVT ContainerVT = VT;
+ SDLoc DL(N);
+ SDValue Passthru, Mask, VL;
+ switch (N->getOpcode()) {
+ case ISD::SHL:
+ if (VT.isFixedLengthVector()) {
+ NarrowContainerVT =
+ getContainerForFixedLengthVector(DAG, NarrowVT, Subtarget);
+ NarrowOp =
+ convertToScalableVector(NarrowContainerVT, NarrowOp, DAG, Subtarget);
+ ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
+ }
+ Passthru = DAG.getUNDEF(VT);
----------------
lukel97 wrote:
Doesn't the passthru need to be of ContainerVT type?
Either way, if you moved this to after DAG legalization, there shouldn't be any ISD::SHL nodes of fixed-length vector type; I think they should all be scalable. Does it still work if you remove the fixed-length handling here?
https://github.com/llvm/llvm-project/pull/121563
More information about the llvm-commits mailing list