[llvm] [AMDGPU] Convert more 64-bit lshr to 32-bit if shift amt>=32 (PR #138204)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 6 08:14:55 PDT 2025


================
@@ -4176,50 +4176,106 @@ SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
 
 SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
                                                 DAGCombinerInfo &DCI) const {
-  auto *RHS = dyn_cast<ConstantSDNode>(N->getOperand(1));
-  if (!RHS)
-    return SDValue();
-
+  SDValue RHS = N->getOperand(1);
+  ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
   EVT VT = N->getValueType(0);
   SDValue LHS = N->getOperand(0);
-  unsigned ShiftAmt = RHS->getZExtValue();
   SelectionDAG &DAG = DCI.DAG;
   SDLoc SL(N);
+  unsigned RHSVal;
 
-  // fold (srl (and x, c1 << c2), c2) -> (and (srl(x, c2), c1)
-  // this improves the ability to match BFE patterns in isel.
-  if (LHS.getOpcode() == ISD::AND) {
-    if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
-      unsigned MaskIdx, MaskLen;
-      if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
-          MaskIdx == ShiftAmt) {
-        return DAG.getNode(
-            ISD::AND, SL, VT,
-            DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0), N->getOperand(1)),
-            DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(1), N->getOperand(1)));
+  if (CRHS) {
+    RHSVal = CRHS->getZExtValue();
+
+    // fold (srl (and x, c1 << c2), c2) -> (and (srl(x, c2), c1)
+    // this improves the ability to match BFE patterns in isel.
+    if (LHS.getOpcode() == ISD::AND) {
+      if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
+        unsigned MaskIdx, MaskLen;
+        if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
+            MaskIdx == RHSVal) {
+          return DAG.getNode(ISD::AND, SL, VT,
+                             DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0),
+                                         N->getOperand(1)),
+                             DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(1),
+                                         N->getOperand(1)));
+        }
       }
     }
   }
 
-  if (VT != MVT::i64)
+  if (VT.getScalarType() != MVT::i64)
     return SDValue();
 
-  if (ShiftAmt < 32)
+  // for C >= 32
+  // i64 (srl x, C) -> (build_pair (srl hi_32(x), C -32), 0)
+
+  // On some subtargets, 64-bit shift is a quarter rate instruction. In the
+  // common case, splitting this into a move and a 32-bit shift is faster and
+  // the same code size.
+  KnownBits Known = DAG.computeKnownBits(RHS);
+
+  EVT ElementType = VT.getScalarType();
+  EVT TargetScalarType = ElementType.getHalfSizedIntegerVT(*DAG.getContext());
+  EVT TargetType = VT.isVector() ? VT.changeVectorElementType(TargetScalarType)
+                                 : TargetScalarType;
+
+  if (Known.getMinValue().getZExtValue() < TargetScalarType.getSizeInBits())
     return SDValue();
 
-  // srl i64:x, C for C >= 32
-  // =>
-  //   build_pair (srl hi_32(x), C - 32), 0
-  SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
+  SDValue ShiftAmt;
+  if (CRHS) {
+    ShiftAmt = DAG.getConstant(RHSVal - TargetScalarType.getSizeInBits(), SL,
+                               TargetType);
+  } else {
+    SDValue truncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
+    const SDValue ShiftMask =
+        DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
+    // This AND instruction will clamp out of bounds shift values.
+    // It will also be removed during later instruction selection.
+    ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
+  }
 
-  SDValue Hi = getHiHalf64(LHS, DAG);
+  const SDValue Zero = DAG.getConstant(0, SL, TargetScalarType);
+  EVT ConcatType;
+  SDValue Hi;
+  SDLoc LHSSL(LHS);
+  // Bitcast LHS into ConcatType so hi-half of source can be extracted into Hi
+  if (VT.isVector()) {
+    unsigned NElts = TargetType.getVectorNumElements();
+    ConcatType = TargetType.getDoubleNumVectorElementsVT(*DAG.getContext());
+    SDValue SplitLHS = DAG.getNode(ISD::BITCAST, LHSSL, ConcatType, LHS);
+    SmallVector<SDValue, 8> HiOps(NElts);
+    SmallVector<SDValue, 16> HiAndLoOps;
+
+    DAG.ExtractVectorElements(SplitLHS, HiAndLoOps, 0, NElts * 2);
+    for (unsigned I = 0; I != NElts; ++I) {
+      HiOps[I] = HiAndLoOps[2 * I + 1];
+    }
+    Hi = DAG.getNode(ISD::BUILD_VECTOR, LHSSL, TargetType, HiOps);
+  } else {
+    const SDValue One = DAG.getConstant(1, LHSSL, TargetScalarType);
+    ConcatType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
+    SDValue SplitLHS = DAG.getNode(ISD::BITCAST, LHSSL, ConcatType, LHS);
+    Hi = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, LHSSL, TargetType, SplitLHS, One);
+  }
----------------
LU-JOHN wrote:

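I could not come up with a simpler way to code this. In the SHL case, the low half can simply be extracted with a TRUNCATE node. For the high half, EXTRACT_ELEMENT works on scalar integers, but it does not work for vectors, so the vector case has to bitcast to a vector of half-width elements and pick out every odd element.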

https://github.com/llvm/llvm-project/pull/138204


More information about the llvm-commits mailing list