[llvm] [AMDGPU] Convert more 64-bit lshr to 32-bit if shift amt>=32 (PR #138204)
via llvm-commits
llvm-commits at lists.llvm.org
Tue May 6 08:14:55 PDT 2025
================
@@ -4176,50 +4176,106 @@ SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
- auto *RHS = dyn_cast<ConstantSDNode>(N->getOperand(1));
- if (!RHS)
- return SDValue();
-
+ SDValue RHS = N->getOperand(1);
+ ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
EVT VT = N->getValueType(0);
SDValue LHS = N->getOperand(0);
- unsigned ShiftAmt = RHS->getZExtValue();
SelectionDAG &DAG = DCI.DAG;
SDLoc SL(N);
+ unsigned RHSVal;
- // fold (srl (and x, c1 << c2), c2) -> (and (srl(x, c2), c1)
- // this improves the ability to match BFE patterns in isel.
- if (LHS.getOpcode() == ISD::AND) {
- if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
- unsigned MaskIdx, MaskLen;
- if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
- MaskIdx == ShiftAmt) {
- return DAG.getNode(
- ISD::AND, SL, VT,
- DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0), N->getOperand(1)),
- DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(1), N->getOperand(1)));
+ if (CRHS) {
+ RHSVal = CRHS->getZExtValue();
+
+ // fold (srl (and x, c1 << c2), c2) -> (and (srl(x, c2), c1)
+ // this improves the ability to match BFE patterns in isel.
+ if (LHS.getOpcode() == ISD::AND) {
+ if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
+ unsigned MaskIdx, MaskLen;
+ if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
+ MaskIdx == RHSVal) {
+ return DAG.getNode(ISD::AND, SL, VT,
+ DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0),
+ N->getOperand(1)),
+ DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(1),
+ N->getOperand(1)));
+ }
}
}
}
- if (VT != MVT::i64)
+ if (VT.getScalarType() != MVT::i64)
return SDValue();
- if (ShiftAmt < 32)
+ // for C >= 32
+ // i64 (srl x, C) -> (build_pair (srl hi_32(x), C -32), 0)
+
+ // On some subtargets, 64-bit shift is a quarter rate instruction. In the
+ // common case, splitting this into a move and a 32-bit shift is faster and
+ // the same code size.
+ KnownBits Known = DAG.computeKnownBits(RHS);
+
+ EVT ElementType = VT.getScalarType();
+ EVT TargetScalarType = ElementType.getHalfSizedIntegerVT(*DAG.getContext());
+ EVT TargetType = VT.isVector() ? VT.changeVectorElementType(TargetScalarType)
+ : TargetScalarType;
+
+ if (Known.getMinValue().getZExtValue() < TargetScalarType.getSizeInBits())
return SDValue();
- // srl i64:x, C for C >= 32
- // =>
- // build_pair (srl hi_32(x), C - 32), 0
- SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
+ SDValue ShiftAmt;
+ if (CRHS) {
+ ShiftAmt = DAG.getConstant(RHSVal - TargetScalarType.getSizeInBits(), SL,
+ TargetType);
+ } else {
+ SDValue truncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
+ const SDValue ShiftMask =
+ DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
+ // This AND instruction will clamp out of bounds shift values.
+ // It will also be removed during later instruction selection.
+ ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
+ }
- SDValue Hi = getHiHalf64(LHS, DAG);
+ const SDValue Zero = DAG.getConstant(0, SL, TargetScalarType);
+ EVT ConcatType;
+ SDValue Hi;
+ SDLoc LHSSL(LHS);
+ // Bitcast LHS into ConcatType so hi-half of source can be extracted into Hi
+ if (VT.isVector()) {
+ unsigned NElts = TargetType.getVectorNumElements();
+ ConcatType = TargetType.getDoubleNumVectorElementsVT(*DAG.getContext());
+ SDValue SplitLHS = DAG.getNode(ISD::BITCAST, LHSSL, ConcatType, LHS);
+ SmallVector<SDValue, 8> HiOps(NElts);
+ SmallVector<SDValue, 16> HiAndLoOps;
+
+ DAG.ExtractVectorElements(SplitLHS, HiAndLoOps, 0, NElts * 2);
+ for (unsigned I = 0; I != NElts; ++I) {
+ HiOps[I] = HiAndLoOps[2 * I + 1];
+ }
+ Hi = DAG.getNode(ISD::BUILD_VECTOR, LHSSL, TargetType, HiOps);
+ } else {
+ const SDValue One = DAG.getConstant(1, LHSSL, TargetScalarType);
+ ConcatType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
+ SDValue SplitLHS = DAG.getNode(ISD::BITCAST, LHSSL, ConcatType, LHS);
+ Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, LHSSL, TargetType, SplitLHS, One);
+ }
----------------
LU-JOHN wrote:
I could not come up with a simpler way to code this. In the SHL case, the low half can be extracted simply with a TRUNCATE node. EXTRACT_ELEMENT works for extracting the high half, but it does not work for vector types.
https://github.com/llvm/llvm-project/pull/138204
More information about the llvm-commits
mailing list