[llvm] [SelectionDAG][X86] Fold `sub(x, mul(divrem(x,y)[0], y))` to `divrem(x, y)[1]` (PR #136565)

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 22 03:51:02 PDT 2025


================
@@ -3867,6 +3867,61 @@ static SDValue foldSubCtlzNot(SDNode *N, SelectionDAG &DAG) {
   return Matcher.getNode(ISD::CTLZ_ZERO_UNDEF, DL, VT, Not);
 }
 
+// Fold sub(x, mul(divrem(x,y)[0], y)) to divrem(x, y)[1]
+static SDValue foldRemainderIdiom(SDNode *N, SelectionDAG &DAG, SDLoc &DL) {
+  assert(N->getOpcode() == ISD::SUB && "Node must be a SUB");
+  SDValue Sub0 = N->getOperand(0);
+  SDValue Sub1 = N->getOperand(1);
+
+  auto CheckAndFoldMulCase = [&](SDValue DivRem, SDValue MaybeY) -> SDValue {
+    if ((DivRem.getOpcode() == ISD::SDIVREM ||
+         DivRem.getOpcode() == ISD::UDIVREM) &&
+        DivRem.getResNo() == 0 && DivRem.getOperand(0) == Sub0 &&
+        DivRem.getOperand(1) == MaybeY) {
+      return SDValue(DivRem.getNode(), 1);
+    }
+    return SDValue();
+  };
+
+  if (Sub1.getOpcode() == ISD::MUL) {
+    // (sub x, (mul divrem(x,y)[0], y))
+    SDValue Mul0 = Sub1.getOperand(0);
+    SDValue Mul1 = Sub1.getOperand(1);
+
+    SDValue Res = CheckAndFoldMulCase(Mul0, Mul1);
+    if (Res)
+      return Res;
+
+    Res = CheckAndFoldMulCase(Mul1, Mul0);
+    if (Res)
+      return Res;
+
+  } else if (Sub1.getOpcode() == ISD::SHL) {
+    // Handle (sub x, (shl divrem(x,y)[0], C)) where y = 1 << C
+    SDValue Shl0 = Sub1.getOperand(0);
+    SDValue Shl1 = Sub1.getOperand(1);
+    // Check if Shl0 is divrem(x, Y)[0]
+    if ((Shl0.getOpcode() == ISD::SDIVREM ||
+         Shl0.getOpcode() == ISD::UDIVREM) &&
+        Shl0.getResNo() == 0 && Shl0.getOperand(0) == Sub0) {
+
+      SDValue Divisor = Shl0.getOperand(1);
+
+      ConstantSDNode *DivC = isConstOrConstSplat(Divisor);
+      ConstantSDNode *ShC = isConstOrConstSplat(Shl1);
+      if (!DivC || !ShC) {
+        return SDValue();
+      }
+
+      if (DivC->getAPIntValue().isPowerOf2() &&
+          DivC->getAPIntValue().logBase2() == ShC->getAPIntValue()) {
+        return SDValue(Shl0.getNode(), 1);
+      }
----------------
RKSimon wrote:

(style) drop the braces around single-statement `if` bodies (per the LLVM coding standards)

https://github.com/llvm/llvm-project/pull/136565


More information about the llvm-commits mailing list