[llvm] [DAGCombiner, NVPTX] Port 'rem' custom combine from NVPTX to generic combiner (PR #167147)

via llvm-commits llvm-commits at lists.llvm.org
Sat Nov 8 07:30:53 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: None (VindhyaP312)

<details>
<summary>Changes</summary>

Fixes #<!-- -->116695 

This patch ports the custom `rem` (remainder) DAG combine from the NVPTX backend (`NVPTXISelLowering.cpp`) into the generic `DAGCombiner`. The optimization is a CSE pattern that folds `A % B` into `A - (A / B) * B` if the quotient `(A / B)` is already computed.

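Moving the combine into the generic `DAGCombiner` lets all targets benefit from the optimization and removes the target-specific copy from the NVPTX backend. The generic version additionally bails out when the target's `isIntDivCheap` hook reports integer division as cheap for the value type, so targets that prefer to keep cheap divisions as-is are left alone.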

---
Full diff: https://github.com/llvm/llvm-project/pull/167147.diff


2 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+38) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (-34) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index f144f17d5a8f2..867f30985ad4e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -900,6 +900,62 @@ namespace {
                          ISD::NodeType ExtType);
   };
 
+/// Generic remainder optimization: fold (rem A, B) into
+///   (sub A, (mul (div A, B), B))
+/// when a matching (div A, B) node already exists, so the quotient computed
+/// for the division is reused instead of being recomputed for the remainder.
+///
+/// The fold is skipped when:
+///  - optimizing below -O2;
+///  - the target can compute quotient and remainder in a single operation
+///    (SDIVREM/UDIVREM is legal or custom), since expanding the remainder
+///    there would only add a MUL and a SUB;
+///  - the target reports integer division as cheap (isIntDivCheap).
+///
+/// NOTE(review): a target may still pair div/rem through a divmod libcall
+/// while SDIVREM/UDIVREM is expanded; that case is not detected here. Confirm
+/// whether this fold should instead run after visitREM forms DIVREM nodes.
+static SDValue PerformREMCombineGeneric(SDNode *N, DAGCombiner &DC,
+                                        CodeGenOptLevel OptLevel) {
+  assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM);
+
+  // Don't do anything at less than -O2.
+  if (OptLevel < CodeGenOptLevel::Default)
+    return SDValue();
+
+  SelectionDAG &DAG = DC.getDAG();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  EVT VT = N->getValueType(0);
+  bool IsSigned = N->getOpcode() == ISD::SREM;
+  unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
+  unsigned DivRemOpc = IsSigned ? ISD::SDIVREM : ISD::UDIVREM;
+
+  // The target yields both results from one operation; don't expand.
+  if (TLI.isOperationLegalOrCustom(DivRemOpc, VT))
+    return SDValue();
+
+  // Cheap division leaves no quotient worth sharing.
+  AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
+  if (TLI.isIntDivCheap(VT, Attr))
+    return SDValue();
+
+  SDValue Num = N->getOperand(0);
+  SDValue Den = N->getOperand(1);
+
+  for (const SDNode *U : Num->users()) {
+    if (U->getOpcode() == DivOpc && U->getOperand(0) == Num &&
+        U->getOperand(1) == Den) {
+      // Num % Den -> Num - (Num / Den) * Den. getNode CSEs the division to the
+      // existing node, so the quotient is shared rather than recomputed.
+      SDLoc DL(N);
+      SDValue Quot = DAG.getNode(DivOpc, DL, VT, Num, Den);
+      SDValue Prod = DAG.getNode(ISD::MUL, DL, VT, Quot, Den);
+      return DAG.getNode(ISD::SUB, DL, VT, Num, Prod);
+    }
+  }
+  return SDValue();
+}
+
 /// This class is a DAGUpdateListener that removes any deleted
 /// nodes from the worklist.
 class WorklistRemover : public SelectionDAG::DAGUpdateListener {
@@ -5400,6 +5435,10 @@ SDValue DAGCombiner::visitREM(SDNode *N) {
   if (SDValue NewSel = foldBinOpIntoSelect(N))
     return NewSel;
 
+  // Reuse an already-computed quotient: A % B -> A - (A / B) * B.
+  if (SDValue V = PerformREMCombineGeneric(N, *this, OptLevel))
+    return V;
+
   if (isSigned) {
     // If we know the sign bits of both operands are zero, strength reduce to a
     // urem instead.  Handles (X & 0x0FFFFFFF) %s 16 -> X&15
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index a3deb36074e68..a3cbb09297f24 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5726,37 +5726,6 @@ static SDValue PerformFMinMaxCombine(SDNode *N,
   return SDValue();
 }
 
-static SDValue PerformREMCombine(SDNode *N,
-                                 TargetLowering::DAGCombinerInfo &DCI,
-                                 CodeGenOptLevel OptLevel) {
-  assert(N->getOpcode() == ISD::SREM || N->getOpcode() == ISD::UREM);
-
-  // Don't do anything at less than -O2.
-  if (OptLevel < CodeGenOptLevel::Default)
-    return SDValue();
-
-  SelectionDAG &DAG = DCI.DAG;
-  SDLoc DL(N);
-  EVT VT = N->getValueType(0);
-  bool IsSigned = N->getOpcode() == ISD::SREM;
-  unsigned DivOpc = IsSigned ? ISD::SDIV : ISD::UDIV;
-
-  const SDValue &Num = N->getOperand(0);
-  const SDValue &Den = N->getOperand(1);
-
-  for (const SDNode *U : Num->users()) {
-    if (U->getOpcode() == DivOpc && U->getOperand(0) == Num &&
-        U->getOperand(1) == Den) {
-      // Num % Den -> Num - (Num / Den) * Den
-      return DAG.getNode(ISD::SUB, DL, VT, Num,
-                         DAG.getNode(ISD::MUL, DL, VT,
-                                     DAG.getNode(DivOpc, DL, VT, Num, Den),
-                                     Den));
-    }
-  }
-  return SDValue();
-}
-
 // (sign_extend|zero_extend (mul|shl) x, y) -> (mul.wide x, y)
 static SDValue combineMulWide(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                               CodeGenOptLevel OptLevel) {
@@ -6428,9 +6397,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformSETCCCombine(N, DCI, STI.getSmVersion());
   case ISD::SHL:
     return PerformSHLCombine(N, DCI, OptLevel);
-  case ISD::SREM:
-  case ISD::UREM:
-    return PerformREMCombine(N, DCI, OptLevel);
   case ISD::STORE:
   case NVPTXISD::StoreV2:
   case NVPTXISD::StoreV4:

``````````

</details>


https://github.com/llvm/llvm-project/pull/167147


More information about the llvm-commits mailing list