[llvm] Implement foldICmpRemConstant in InstCombineCompares (PR #77410)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jan 8 21:03:15 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: None (Baxi-codes)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/77410.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+46-1) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+3-1) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 7c1aff445524de..0add51b8175555 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2572,6 +2572,49 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
   return new ICmpInst(ICmpInst::ICMP_UGT, And, ConstantInt::get(Ty, SignMask));
 }
 
+/// Fold icmp eq/ne ((X * C1) urem (Y * C2)), C when C1 is a multiple of C2.
+///
+/// If both multiplies are nuw and C1 == K * C2 with C2 != 0, nothing wraps
+/// and (X * C1) urem (Y * C2) == C2 * ((X * K) urem Y). Therefore:
+///   icmp eq/ne (X * C1) urem (Y * C2), C
+///     --> icmp eq/ne (X * K) urem Y, C / C2   if C2 divides C
+///     --> false / true                       otherwise
+/// srem and relational predicates are not handled here: they need extra
+/// overflow / rounding reasoning that this fold does not perform.
+Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
+                                                   BinaryOperator *Rem,
+                                                   const APInt &C) {
+  const ICmpInst::Predicate Pred = Cmp.getPredicate();
+  // Only eq/ne are handled; require a single use so the rewrite does not
+  // leave the old urem alive next to the new one.
+  if (!Cmp.isEquality() || !Rem->hasOneUse())
+    return nullptr;
+
+  Value *X, *Y;
+  const APInt *C1, *C2;
+  if (!match(Rem, m_URem(m_NUWMul(m_Value(X), m_APInt(C1)),
+                         m_NUWMul(m_Value(Y), m_APInt(C2)))))
+    return nullptr;
+
+  // C2 == 0 makes the divisor zero, and would also trip the APInt division
+  // asserts below. C1 must be an exact multiple of C2.
+  if (C2->isZero() || !C1->urem(*C2).isZero())
+    return nullptr;
+
+  // C2 * ((X * K) urem Y) < C2 * Y cannot wrap (Y * C2 is nuw), so it can
+  // only equal C when C2 divides C; otherwise the compare is constant.
+  if (!C.urem(*C2).isZero())
+    return replaceInstUsesWith(
+        Cmp, ConstantInt::getBool(Cmp.getType(), Pred == ICmpInst::ICMP_NE));
+
+  // X * K <= X * C1 because C2 >= 1, so the new multiply is nuw as well.
+  APInt K = C1->udiv(*C2);
+  Value *NewMul = Builder.CreateNUWMul(X, ConstantInt::get(X->getType(), K));
+  Value *NewRem = Builder.CreateURem(NewMul, Y);
+  return new ICmpInst(Pred, NewRem,
+                      ConstantInt::get(Rem->getType(), C.udiv(*C2)));
+}
+
 /// Fold icmp (udiv X, Y), C.
 Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp,
                                                     BinaryOperator *UDiv,
@@ -2963,7 +3003,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
 
   // Fold icmp pred (add X, C2), C.
   Type *Ty = Add->getType();
-
+  
   // If the add does not wrap, we can always adjust the compare by subtracting
   // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE
   // are canonicalized to SGT/SLT/UGT/ULT.
@@ -3708,7 +3748,12 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
   case Instruction::SRem:
     if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C))
       return I;
+   [[fallthrough]]; 
+  case Instruction::URem:
+    if (Instruction *I = foldICmpRemConstant(Cmp, BO, C)) 
+      return I; 
     break;
+
   case Instruction::UDiv:
     if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C))
       return I;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 21c61bd990184d..748fe04c470e46 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -670,7 +670,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                                    const APInt &C);
   Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr,
                                    const APInt &C);
-  Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
+  Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *SRem,
+                                    const APInt &C);
+  Instruction *foldICmpRemConstant(ICmpInst &Cmp, BinaryOperator *Rem,
                                     const APInt &C);
   Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
                                     const APInt &C);

``````````

</details>


https://github.com/llvm/llvm-project/pull/77410


More information about the llvm-commits mailing list