[llvm] [AMDGPU] Convert 64-bit sra to 32-bit if shift amt >= 32 (PR #144421)

Shilei Tian via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 20 06:01:15 PDT 2025


================
@@ -4151,32 +4151,97 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
 
 SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
                                                 DAGCombinerInfo &DCI) const {
-  if (N->getValueType(0) != MVT::i64)
+  SDValue RHS = N->getOperand(1);
+  ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
+  EVT VT = N->getValueType(0);
+  SDValue LHS = N->getOperand(0);
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc SL(N);
+
+  if (VT.getScalarType() != MVT::i64)
     return SDValue();
 
-  const ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(N->getOperand(1));
-  if (!RHS)
+  // for C >= 32
+  // i64 (sra x, C) -> (build_pair (sra hi_32(x), C - 32), (sra hi_32(x), 31))
+
+  // On some subtargets, 64-bit shift is a quarter rate instruction. In the
+  // common case, splitting this into a move and a 32-bit shift is faster and
+  // the same code size.
+  KnownBits Known = DAG.computeKnownBits(RHS);
+
+  EVT ElementType = VT.getScalarType();
+  EVT TargetScalarType = ElementType.getHalfSizedIntegerVT(*DAG.getContext());
+  EVT TargetType = VT.isVector() ? VT.changeVectorElementType(TargetScalarType)
+                                 : TargetScalarType;
+
+  if (Known.getMinValue().getZExtValue() < TargetScalarType.getSizeInBits())
     return SDValue();
 
-  SelectionDAG &DAG = DCI.DAG;
-  SDLoc SL(N);
-  unsigned RHSVal = RHS->getZExtValue();
+  SDValue ShiftFullAmt =
+      DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
+  SDValue ShiftAmt;
+  if (CRHS) {
+    unsigned RHSVal = CRHS->getZExtValue();
+
+    ShiftAmt = DAG.getConstant(RHSVal - TargetScalarType.getSizeInBits(), SL,
+                               TargetType);
----------------
shiltian wrote:

nit: drop the blank line between the `RHSVal` declaration and the `ShiftAmt` assignment:
```suggestion
    unsigned RHSVal = CRHS->getZExtValue();
    ShiftAmt = DAG.getConstant(RHSVal - TargetScalarType.getSizeInBits(), SL,
                               TargetType);
```

https://github.com/llvm/llvm-project/pull/144421


More information about the llvm-commits mailing list