[llvm] Perf/fold icmp rem constant (PR #79383)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jan 24 14:47:45 PST 2024


https://github.com/Baxi-codes created https://github.com/llvm/llvm-project/pull/79383

issue: https://github.com/llvm/llvm-project/issues/76585
alive2: https://alive2.llvm.org/ce/z/fj2ACE

I can't figure out how to check for both the nsw and nuw flags at once, as there are only the separate m_NUWMul and m_NSWMul matchers; the urem case works fine, though. I am also having some trouble updating the CHECK lines, so they have not been regenerated with the new output yet.

>From d4a973790dffd4c904ddad57bb4b44af46e2ba3e Mon Sep 17 00:00:00 2001
From: Dhairya <baxidhairya2312 at gmail.com>
Date: Tue, 23 Jan 2024 23:34:19 +0530
Subject: [PATCH 1/3] [InstCombine] Add pre-commit tests. NFC

---
 .../InstCombine/icmp-div-constant.ll          | 56 +++++++++++++++++--
 1 file changed, 52 insertions(+), 4 deletions(-)

diff --git a/llvm/test/Transforms/InstCombine/icmp-div-constant.ll b/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
index 8dcb96284685ff7..86ab7afa22b3810 100644
--- a/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
@@ -107,6 +107,54 @@ define i1 @is_rem16_something_i8(i8 %x) {
   ret i1 %r
 }
 
+; Tests below contain foldable remainder comparison with constant instructions
+; (a * c1) % (b * c2) ==/!= 0 => (a * k) % b ==/!= 0 if c1 % c2 is 0 and k = c1 / c2
+
+define  i1 @icmp_urem_constant(i8 noundef %0, i8 noundef %1) {
+; CHECK-LABEL: @icmp_urem_constant(
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nuw i8 [[TMP0:%.*]], 9
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw i8 [[TMP1:%.*]], 3
+; CHECK-NEXT:    [[TMP5:%.*]] = urem i8 [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq i8 [[TMP5]], 0
+; CHECK-NEXT:    ret i1 [[TMP6]]
+;
+  %3 = mul nuw i8 9, %0
+  %4 = mul nuw i8 3, %1
+  %5 = urem i8 %3, %4
+  %6 = icmp eq i8 %5, 0
+  ret i1 %6
+}
+
+define  i1 @icmp_srem_constant(i8 noundef %0, i8 noundef %1) {
+; CHECK-LABEL: @icmp_srem_constant(
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nuw nsw i8 [[TMP0:%.*]], -64
+; CHECK-NEXT:    [[TMP4:%.*]] = shl nuw nsw i8 [[TMP1:%.*]], 3
+; CHECK-NEXT:    [[TMP5:%.*]] = srem i8 [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq i8 [[TMP5]], 0
+; CHECK-NEXT:    ret i1 [[TMP6]]
+;
+  %3 = mul nsw nuw i8 -64, %0
+  %4 = mul nsw nuw i8 8, %1
+  %5 = srem i8 %3, %4
+  %6 = icmp eq i8 %5, 0
+  ret i1 %6
+}
+
+define  i1 @icmp_srem_constant2(i8 noundef %0, i8 noundef %1) {
+; CHECK-LABEL: @icmp_srem_constant2(
+; CHECK-NEXT:    [[TMP3:%.*]] = mul nuw nsw i8 [[TMP0:%.*]], 9
+; CHECK-NEXT:    [[TMP4:%.*]] = mul nuw nsw i8 [[TMP1:%.*]], 9
+; CHECK-NEXT:    [[TMP5:%.*]] = srem i8 [[TMP3]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp eq i8 [[TMP5]], 0
+; CHECK-NEXT:    ret i1 [[TMP6]]
+;
+  %3 = mul nsw nuw i8 9, %0
+  %4 = mul nsw nuw i8 9, %1
+  %5 = srem i8 %3, %4
+  %6 = icmp eq i8 %5, 0
+  ret i1 %6
+}
+
 ; PR30281 - https://llvm.org/bugs/show_bug.cgi?id=30281
 
 ; All of these tests contain foldable division-by-constant instructions, but we
@@ -118,8 +166,8 @@ define i32 @icmp_div(i16 %a, i16 %c) {
 ; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp eq i16 [[A:%.*]], 0
 ; CHECK-NEXT:    br i1 [[TOBOOL]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[C:%.*]], 0
-; CHECK-NEXT:    [[TMP0:%.*]] = sext i1 [[CMP]] to i32
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i16 [[C:%.*]], 0
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i1 [[CMP_NOT]] to i32
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ -1, [[ENTRY:%.*]] ], [ [[TMP0]], [[THEN]] ]
@@ -173,8 +221,8 @@ define i32 @icmp_div3(i16 %a, i16 %c) {
 ; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp eq i16 [[A:%.*]], 0
 ; CHECK-NEXT:    br i1 [[TOBOOL]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[C:%.*]], 0
-; CHECK-NEXT:    [[TMP0:%.*]] = sext i1 [[CMP]] to i32
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i16 [[C:%.*]], 0
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i1 [[CMP_NOT]] to i32
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ -1, [[ENTRY:%.*]] ], [ [[TMP0]], [[THEN]] ]

>From fd4a471d2afadb2409d2c11ad97dbb255c552b34 Mon Sep 17 00:00:00 2001
From: Dhairya <baxidhairya2312 at gmail.com>
Date: Thu, 25 Jan 2024 04:04:22 +0530
Subject: [PATCH 2/3] [InstCombine] Implement the transform for urem

---
 .../InstCombine/InstCombineCompares.cpp       | 65 +++++++++++++++++++
 .../InstCombine/InstCombineInternal.h         |  4 +-
 2 files changed, 68 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 8c0fd6622551306..92df48f1ecf8519 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2575,6 +2575,66 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
   return new ICmpInst(ICmpInst::ICMP_UGT, And, ConstantInt::get(Ty, SignMask));
 }
 
+Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
+                                                   BinaryOperator *Rem,
+                                                   const APInt &C) {
+  assert((Rem->getOpcode() == Instruction::SRem || Rem->getOpcode() == Instruction::URem) && "Only for srem/urem!");
+  const ICmpInst::Predicate Pred = Cmp.getPredicate();
+  // Check for ==/!= 0
+  if (!ICmpInst::isEquality(Pred) || !C.isZero())
+    return nullptr;
+
+  Value *X = Rem->getOperand(0);
+  Value *Y = Rem->getOperand(1);
+
+  Value *A, *B;
+  const APInt *C1, *C2;
+  Value *NewRem;
+  APInt K;
+  Type *Ty;
+
+  if (Rem->getOpcode() == Instruction::SRem) {
+    // Check if both NSW and NUW flags are on
+    if (!match(X, m_NSWMul(m_Value(A), m_APInt(C1))))
+      return nullptr;
+    if (!match(Y, m_NSWMul(m_Value(B), m_APInt(C2))))
+      return nullptr;
+    if (!match(X, m_NUWMul(m_Value(A), m_APInt(C1))))
+      return nullptr;
+    if (!match(Y, m_NUWMul(m_Value(B), m_APInt(C2))))
+      return nullptr;
+    if (C2->isZero())
+      return nullptr;
+    if (!C1->srem(*C2).isZero())
+      return nullptr;
+    // Compute the new constant k = c1 / c2.
+    K = C1->sdiv(*C2);
+    Ty = Rem->getType();
+    if (K == 1)
+      NewRem = Builder.CreateSRem(A, B);
+    else
+      NewRem = Builder.CreateSRem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K), "", true, true), B);
+  }
+  else {
+    if (!match(X, m_NUWMul(m_Value(A), m_APInt(C1))))
+      return nullptr;
+    if (!match(Y, m_NUWMul(m_Value(B), m_APInt(C2))))
+      return nullptr;
+    if (C2->isZero())
+      return nullptr;
+    if (!C1->urem(*C2).isZero())
+      return nullptr;
+    // Compute the new constant k = c1 / c2.
+    K = C1->udiv(*C2);
+    Ty = Rem->getType();
+    if (K == 1)
+      NewRem = Builder.CreateSRem(A, B);
+    else
+      NewRem = Builder.CreateSRem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K), "", true, false), B);
+  }
+  return new ICmpInst(Pred, NewRem, ConstantInt::get(Ty, C));
+}
+
 /// Fold icmp (udiv X, Y), C.
 Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp,
                                                     BinaryOperator *UDiv,
@@ -3712,6 +3772,11 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
     if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C))
       return I;
     break;
+   [[fallthrough]]; 
+  case Instruction::URem:
+    if (Instruction *I = foldICmpRemConstant(Cmp, BO, C)) 
+      return I; 
+    break;
   case Instruction::UDiv:
     if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C))
       return I;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index c24b6e3a5b33c0b..17bcc29c9730058 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -674,7 +674,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                                    const APInt &C);
   Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr,
                                    const APInt &C);
-  Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
+  Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *URem,
+                                    const APInt &C);
+  Instruction *foldICmpRemConstant(ICmpInst &Cmp, BinaryOperator *Rem,
                                     const APInt &C);
   Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
                                     const APInt &C);

>From e794cc13226e73b6e5f41951213a2516e8edd0aa Mon Sep 17 00:00:00 2001
From: Dhairya <baxidhairya2312 at gmail.com>
Date: Thu, 25 Jan 2024 04:13:44 +0530
Subject: [PATCH 3/3] Fixed clang-format

---
 .../InstCombine/InstCombineCompares.cpp       | 23 ++++++++++++-------
 .../InstCombine/InstCombineInternal.h         |  2 +-
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 92df48f1ecf8519..9c86962ac3c16a6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2578,7 +2578,9 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
 Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
                                                    BinaryOperator *Rem,
                                                    const APInt &C) {
-  assert((Rem->getOpcode() == Instruction::SRem || Rem->getOpcode() == Instruction::URem) && "Only for srem/urem!");
+  assert((Rem->getOpcode() == Instruction::SRem ||
+          Rem->getOpcode() == Instruction::URem) &&
+         "Only for srem/urem!");
   const ICmpInst::Predicate Pred = Cmp.getPredicate();
   // Check for ==/!= 0
   if (!ICmpInst::isEquality(Pred) || !C.isZero())
@@ -2613,9 +2615,11 @@ Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
     if (K == 1)
       NewRem = Builder.CreateSRem(A, B);
     else
-      NewRem = Builder.CreateSRem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K), "", true, true), B);
-  }
-  else {
+      NewRem = Builder.CreateSRem(
+          Builder.CreateMul(A, ConstantInt::get(A->getType(), K), "", true,
+                            true),
+          B);
+  } else {
     if (!match(X, m_NUWMul(m_Value(A), m_APInt(C1))))
       return nullptr;
     if (!match(Y, m_NUWMul(m_Value(B), m_APInt(C2))))
@@ -2630,7 +2634,10 @@ Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
     if (K == 1)
       NewRem = Builder.CreateSRem(A, B);
     else
-      NewRem = Builder.CreateSRem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K), "", true, false), B);
+      NewRem = Builder.CreateSRem(
+          Builder.CreateMul(A, ConstantInt::get(A->getType(), K), "", true,
+                            false),
+          B);
   }
   return new ICmpInst(Pred, NewRem, ConstantInt::get(Ty, C));
 }
@@ -3772,10 +3779,10 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
     if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C))
       return I;
     break;
-   [[fallthrough]]; 
+    [[fallthrough]];
   case Instruction::URem:
-    if (Instruction *I = foldICmpRemConstant(Cmp, BO, C)) 
-      return I; 
+    if (Instruction *I = foldICmpRemConstant(Cmp, BO, C))
+      return I;
     break;
   case Instruction::UDiv:
     if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 17bcc29c9730058..a0453f5b650a0d3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -677,7 +677,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *URem,
                                     const APInt &C);
   Instruction *foldICmpRemConstant(ICmpInst &Cmp, BinaryOperator *Rem,
-                                    const APInt &C);
+                                   const APInt &C);
   Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
                                     const APInt &C);
   Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div,



More information about the llvm-commits mailing list