[llvm] [AMDGPU] Convert more 64-bit lshr to 32-bit if shift amt>=32 (PR #138204)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu May 1 23:20:35 PDT 2025


================
@@ -4176,50 +4176,110 @@ SDValue AMDGPUTargetLowering::performSraCombine(SDNode *N,
 
 SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
                                                 DAGCombinerInfo &DCI) const {
-  auto *RHS = dyn_cast<ConstantSDNode>(N->getOperand(1));
-  if (!RHS)
-    return SDValue();
-
+  SDValue RHS = N->getOperand(1);
+  ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(RHS);
   EVT VT = N->getValueType(0);
   SDValue LHS = N->getOperand(0);
-  unsigned ShiftAmt = RHS->getZExtValue();
   SelectionDAG &DAG = DCI.DAG;
   SDLoc SL(N);
+  unsigned RHSVal;
+
+  if (CRHS) {
+    RHSVal = CRHS->getZExtValue();
 
-  // fold (srl (and x, c1 << c2), c2) -> (and (srl(x, c2), c1)
-  // this improves the ability to match BFE patterns in isel.
-  if (LHS.getOpcode() == ISD::AND) {
-    if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
-      unsigned MaskIdx, MaskLen;
-      if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
-          MaskIdx == ShiftAmt) {
-        return DAG.getNode(
-            ISD::AND, SL, VT,
-            DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0), N->getOperand(1)),
-            DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(1), N->getOperand(1)));
+    // fold (srl (and x, c1 << c2), c2) -> (and (srl(x, c2), c1)
+    // this improves the ability to match BFE patterns in isel.
+    if (LHS.getOpcode() == ISD::AND) {
+      if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
+        unsigned MaskIdx, MaskLen;
+        if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
+            MaskIdx == RHSVal) {
+          return DAG.getNode(ISD::AND, SL, VT,
+                             DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0),
+                                         N->getOperand(1)),
+                             DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(1),
+                                         N->getOperand(1)));
+        }
       }
     }
   }
 
-  if (VT != MVT::i64)
+  // If the shift is exact, the shifted out bits matter.
+  if (N->getFlags().hasExact())
     return SDValue();
----------------
arsenm wrote:

This looks backwards? With exact, the result is poison if any shifted-out bits are nonzero, so exact allows you to be more aggressive.

https://github.com/llvm/llvm-project/pull/138204


More information about the llvm-commits mailing list