[llvm] [SelectionDAG][X86] Fold `sub(x, mul(divrem(x,y)[0], y))` to `divrem(x, y)[1]` (PR #136565)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 21 06:24:32 PDT 2025


================
@@ -3867,6 +3867,60 @@ static SDValue foldSubCtlzNot(SDNode *N, SelectionDAG &DAG) {
   return Matcher.getNode(ISD::CTLZ_ZERO_UNDEF, DL, VT, Not);
 }
 
+// Fold sub(x, mul(divrem(x,y)[0], y)) to divrem(x, y)[1]
+static SDValue foldSubOfQuotientToRem(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::SUB && "Node must be a SUB");
+  SDValue Sub0 = N->getOperand(0);
+  SDValue Sub1 = N->getOperand(1);
+  SDLoc DL(N);
+
+  auto CheckAndFoldMulCase = [&](SDValue DivRem, SDValue MaybeY) -> SDValue {
+    if ((DivRem.getOpcode() == ISD::SDIVREM ||
+         DivRem.getOpcode() == ISD::UDIVREM) &&
+        DivRem.getResNo() == 0 && DivRem.getOperand(0) == Sub0 &&
+        DivRem.getOperand(1) == MaybeY) {
+      return SDValue(DivRem.getNode(), 1);
+    }
+    return SDValue();
+  };
+
+  if (Sub1.getOpcode() == ISD::MUL) {
+    // (sub x, (mul divrem(x,y)[0], y))
+    SDValue Mul0 = Sub1.getOperand(0);
+    SDValue Mul1 = Sub1.getOperand(1);
+
+    SDValue Res = CheckAndFoldMulCase(Mul0, Mul1);
+    if (Res.getNode())
+      return Res;
+
+    Res = CheckAndFoldMulCase(Mul1, Mul0);
+    if (Res.getNode())
+      return Res;
+
+  } else if (Sub1.getOpcode() == ISD::SHL) {
+    // Handle (sub x, (shl divrem(x,y)[0], C)) where y = 1 << C
+    SDValue Shl0 = Sub1.getOperand(0);
+    SDValue Shl1 = Sub1.getOperand(1);
+    // Check if Shl0 is divrem(x, Y)[0]
+    if ((Shl0.getOpcode() == ISD::SDIVREM ||
+         Shl0.getOpcode() == ISD::UDIVREM) &&
+        Shl0.getResNo() == 0 && Shl0.getOperand(0) == Sub0) {
+
+      SDValue Divisor = Shl0.getOperand(1);
+
+      // Check if DivRemDivisor is a constant power of 2
+      auto *C = dyn_cast<ConstantSDNode>(Divisor);
+      if (C && C->getAPIntValue().isPowerOf2() && !C->isZero()) {
+        auto *Shamt = dyn_cast<ConstantSDNode>(Shl1);
+        if (Shamt && Shamt->getAPIntValue() == C->getAPIntValue().logBase2()) {
----------------
arsenm wrote:

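This doesn't handle the vector case: `dyn_cast<ConstantSDNode>` on the divisor and the shift amount won't match splat-constant vector operands.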

https://github.com/llvm/llvm-project/pull/136565


More information about the llvm-commits mailing list