[llvm] Perf/fold icmp rem constant (PR #79383)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 24 14:48:34 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: None (Baxi-codes)

<details>
<summary>Changes</summary>

issue: https://github.com/llvm/llvm-project/issues/76585
alive2: https://alive2.llvm.org/ce/z/fj2ACE

I couldn't figure out how to check for both the nsw and nuw flags at once, since there are only separate `m_NUWMul` and `m_NSWMul` matchers, but the urem case works fine. I also had some trouble updating the CHECK lines, so they have not yet been regenerated with the new output.

---
Full diff: https://github.com/llvm/llvm-project/pull/79383.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+72) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+3-1) 
- (modified) llvm/test/Transforms/InstCombine/icmp-div-constant.ll (+52-4) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 8c0fd662255130..9c86962ac3c16a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2575,6 +2575,73 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
   return new ICmpInst(ICmpInst::ICMP_UGT, And, ConstantInt::get(Ty, SignMask));
 }
 
+/// Fold an equality comparison of a remainder against zero when both
+/// remainder operands are multiplications by constants:
+///   icmp eq/ne (rem (mul A, C1), (mul B, C2)), 0
+///     --> icmp eq/ne (rem (mul A, K), B), 0      where C1 == K * C2
+/// For urem both multiplies must carry 'nuw'; for srem both must carry
+/// 'nuw' and 'nsw'. Returns nullptr when the pattern does not apply.
+///
+/// NOTE(review): foldICmpBinOpWithConstant must let its SRem case fall
+/// through to this fold; a `break` placed before `[[fallthrough]]` leaves
+/// the srem path unreachable.
+Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
+                                                   BinaryOperator *Rem,
+                                                   const APInt &C) {
+  assert((Rem->getOpcode() == Instruction::SRem ||
+          Rem->getOpcode() == Instruction::URem) &&
+         "Only for srem/urem!");
+  const ICmpInst::Predicate Pred = Cmp.getPredicate();
+  // Only `rem ==/!= 0` is handled: divisibility survives dividing both
+  // operands by C2, but the remainder value itself is scaled by C2.
+  if (!ICmpInst::isEquality(Pred) || !C.isZero())
+    return nullptr;
+
+  const bool IsSigned = Rem->getOpcode() == Instruction::SRem;
+
+  // Match `V = mul Op, CI` carrying the wrap flags this fold relies on:
+  // nuw always, plus nsw for the signed form.
+  auto MatchScaledOperand = [IsSigned](Value *V, Value *&Op,
+                                       const APInt *&CI) {
+    if (!match(V, m_NUWMul(m_Value(Op), m_APInt(CI))))
+      return false;
+    return !IsSigned || cast<OverflowingBinaryOperator>(V)->hasNoSignedWrap();
+  };
+
+  Value *A, *B;
+  const APInt *C1, *C2;
+  if (!MatchScaledOperand(Rem->getOperand(0), A, C1) ||
+      !MatchScaledOperand(Rem->getOperand(1), B, C2))
+    return nullptr;
+  if (C2->isZero())
+    return nullptr;
+  // INT_MIN / -1 is not representable; bail rather than use a wrapped K.
+  if (IsSigned && C1->isMinSignedValue() && C2->isAllOnes())
+    return nullptr;
+
+  // C1 must be an exact multiple of C2; K = C1 / C2.
+  APInt K, Remainder;
+  if (IsSigned)
+    APInt::sdivrem(*C1, *C2, K, Remainder);
+  else
+    APInt::udivrem(*C1, *C2, K, Remainder);
+  if (!Remainder.isZero())
+    return nullptr;
+
+  Value *NewLHS = A;
+  if (!K.isOne()) {
+    // |A * K| <= |A * C1|, so nsw carries over for srem and nuw for urem.
+    // For srem with a negative K the unsigned product A * K can wrap even
+    // though A * C1 does not (e.g. i8: C1 = 2, C2 = -1, A = 2), so nuw is
+    // only kept when K is non-negative.
+    const bool HasNUW = !IsSigned || K.isNonNegative();
+    NewLHS = Builder.CreateMul(A, ConstantInt::get(A->getType(), K), "",
+                               HasNUW, /*HasNSW=*/IsSigned);
+  }
+  // The new remainder must keep the signedness of the original one.
+  Value *NewRem = IsSigned ? Builder.CreateSRem(NewLHS, B)
+                           : Builder.CreateURem(NewLHS, B);
+  return new ICmpInst(Pred, NewRem, ConstantInt::get(Rem->getType(), C));
+}
+
 /// Fold icmp (udiv X, Y), C.
 Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp,
                                                     BinaryOperator *UDiv,
@@ -3712,6 +3779,11 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
     if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C))
       return I;
     break;
+    [[fallthrough]];
+  case Instruction::URem:
+    if (Instruction *I = foldICmpRemConstant(Cmp, BO, C))
+      return I;
+    break;
   case Instruction::UDiv:
     if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C))
       return I;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index c24b6e3a5b33c0..a0453f5b650a0d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -674,8 +674,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                                    const APInt &C);
   Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr,
                                    const APInt &C);
-  Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
+  Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *URem,
                                     const APInt &C);
+  Instruction *foldICmpRemConstant(ICmpInst &Cmp, BinaryOperator *Rem,
+                                   const APInt &C);
   Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
                                     const APInt &C);
   Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div,
diff --git a/llvm/test/Transforms/InstCombine/icmp-div-constant.ll b/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
index 8dcb96284685ff..86ab7afa22b381 100644
--- a/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
@@ -107,6 +107,54 @@ define i1 @is_rem16_something_i8(i8 %x) {
   ret i1 %r
 }
 
+; The tests below contain remainder-against-zero comparisons that can be folded:
+; (a * c1) % (b * c2) ==/!= 0  =>  (a * k) % b ==/!= 0, if c1 % c2 == 0 and k = c1 / c2
+
+; urem case: c1 = 9, c2 = 3, so k = 3; the expected fold is
+; icmp eq (urem (mul nuw %0, 3), %1), 0.
+; NOTE(review): the CHECK lines below still show the unfolded input and need
+; to be regenerated (presumably with utils/update_test_checks.py).
+define  i1 @icmp_urem_constant(i8 noundef %0, i8 noundef %1) {
+; CHECK-LABEL: @icmp_urem_constant(
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nuw i8 [[TMP0:%.*]], 9
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw i8 [[TMP1:%.*]], 3
+; CHECK-NEXT:    [[TMP5:%.*]] = urem i8 [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq i8 [[TMP5]], 0
+; CHECK-NEXT:    ret i1 [[TMP6]]
+;
+  %3 = mul nuw i8 9, %0
+  %4 = mul nuw i8 3, %1
+  %5 = urem i8 %3, %4
+  %6 = icmp eq i8 %5, 0
+  ret i1 %6
+}
+
+; srem case with a negative multiplier: c1 = -64, c2 = 8, so k = -8.
+; NOTE(review): the existing CHECK output shows %4 canonicalized to
+; `shl nuw nsw %1, 3`; if that runs before the icmp fold, a mul matcher will
+; not see it and this test may not fold -- confirm. The CHECK lines are also
+; stale and need to be regenerated.
+define  i1 @icmp_srem_constant(i8 noundef %0, i8 noundef %1) {
+; CHECK-LABEL: @icmp_srem_constant(
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nuw nsw i8 [[TMP0:%.*]], -64
+; CHECK-NEXT:    [[TMP4:%.*]] = shl nuw nsw i8 [[TMP1:%.*]], 3
+; CHECK-NEXT:    [[TMP5:%.*]] = srem i8 [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq i8 [[TMP5]], 0
+; CHECK-NEXT:    ret i1 [[TMP6]]
+;
+  %3 = mul nsw nuw i8 -64, %0
+  %4 = mul nsw nuw i8 8, %1
+  %5 = srem i8 %3, %4
+  %6 = icmp eq i8 %5, 0
+  ret i1 %6
+}
+
+; srem case with equal multipliers: c1 = c2 = 9, so k = 1 and the expected
+; fold needs no new multiply: icmp eq (srem %0, %1), 0.
+; NOTE(review): the CHECK lines below still show the unfolded input and need
+; to be regenerated.
+define  i1 @icmp_srem_constant2(i8 noundef %0, i8 noundef %1) {
+; CHECK-LABEL: @icmp_srem_constant2(
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nuw nsw i8 [[TMP0:%.*]], 9
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw i8 [[TMP1:%.*]], 9
+; CHECK-NEXT:    [[TMP5:%.*]] = srem i8 [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq i8 [[TMP5]], 0
+; CHECK-NEXT:    ret i1 [[TMP6]]
+;
+  %3 = mul nsw nuw i8 9, %0
+  %4 = mul nsw nuw i8 9, %1
+  %5 = srem i8 %3, %4
+  %6 = icmp eq i8 %5, 0
+  ret i1 %6
+}
+
 ; PR30281 - https://llvm.org/bugs/show_bug.cgi?id=30281
 
 ; All of these tests contain foldable division-by-constant instructions, but we
@@ -118,8 +166,8 @@ define i32 @icmp_div(i16 %a, i16 %c) {
 ; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp eq i16 [[A:%.*]], 0
 ; CHECK-NEXT:    br i1 [[TOBOOL]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[C:%.*]], 0
-; CHECK-NEXT:    [[TMP0:%.*]] = sext i1 [[CMP]] to i32
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i16 [[C:%.*]], 0
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i1 [[CMP_NOT]] to i32
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ -1, [[ENTRY:%.*]] ], [ [[TMP0]], [[THEN]] ]
@@ -173,8 +221,8 @@ define i32 @icmp_div3(i16 %a, i16 %c) {
 ; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp eq i16 [[A:%.*]], 0
 ; CHECK-NEXT:    br i1 [[TOBOOL]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[C:%.*]], 0
-; CHECK-NEXT:    [[TMP0:%.*]] = sext i1 [[CMP]] to i32
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i16 [[C:%.*]], 0
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i1 [[CMP_NOT]] to i32
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ -1, [[ENTRY:%.*]] ], [ [[TMP0]], [[THEN]] ]

``````````

</details>


https://github.com/llvm/llvm-project/pull/79383


More information about the llvm-commits mailing list