[llvm] Perf/fold icmp rem constant (PR #79383)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 24 14:47:45 PST 2024
https://github.com/Baxi-codes created https://github.com/llvm/llvm-project/pull/79383
issue: https://github.com/llvm/llvm-project/issues/76585
alive2: https://alive2.llvm.org/ce/z/fj2ACE
I can't figure out how to check for both the nsw and nuw flags at once, as there are only the separate m_NUWMul and m_NSWMul matchers, but the urem case works fine. I'm also having some trouble updating the CHECK lines, so they haven't been updated with the new output yet.
>From d4a973790dffd4c904ddad57bb4b44af46e2ba3e Mon Sep 17 00:00:00 2001
From: Dhairya <baxidhairya2312 at gmail.com>
Date: Tue, 23 Jan 2024 23:34:19 +0530
Subject: [PATCH 1/3] [InstCombine] Add pre-commit tests. NFC
---
.../InstCombine/icmp-div-constant.ll | 56 +++++++++++++++++--
1 file changed, 52 insertions(+), 4 deletions(-)
diff --git a/llvm/test/Transforms/InstCombine/icmp-div-constant.ll b/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
index 8dcb96284685ff7..86ab7afa22b3810 100644
--- a/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
@@ -107,6 +107,54 @@ define i1 @is_rem16_something_i8(i8 %x) {
ret i1 %r
}
+; Tests below contain foldable remainder comparison with constant instructions
+; (a * c1) % (b * c2) ==/!= 0 => (a * k) % b ==/!= 0 if c1 % c2 is 0 and k = c1 / c2
+
+define i1 @icmp_urem_constant(i8 noundef %0, i8 noundef %1) {
+; CHECK-LABEL: @icmp_urem_constant(
+; CHECK-NEXT: [[TMP3:%.*]] = mul nuw i8 [[TMP0:%.*]], 9
+; CHECK-NEXT: [[TMP4:%.*]] = mul nuw i8 [[TMP1:%.*]], 3
+; CHECK-NEXT: [[TMP5:%.*]] = urem i8 [[TMP3]], [[TMP4]]
+; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i8 [[TMP5]], 0
+; CHECK-NEXT: ret i1 [[TMP6]]
+;
+ %3 = mul nuw i8 9, %0
+ %4 = mul nuw i8 3, %1
+ %5 = urem i8 %3, %4
+ %6 = icmp eq i8 %5, 0
+ ret i1 %6
+}
+
+define i1 @icmp_srem_constant(i8 noundef %0, i8 noundef %1) {
+; CHECK-LABEL: @icmp_srem_constant(
+; CHECK-NEXT: [[TMP3:%.*]] = mul nuw nsw i8 [[TMP0:%.*]], -64
+; CHECK-NEXT: [[TMP4:%.*]] = shl nuw nsw i8 [[TMP1:%.*]], 3
+; CHECK-NEXT: [[TMP5:%.*]] = srem i8 [[TMP3]], [[TMP4]]
+; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i8 [[TMP5]], 0
+; CHECK-NEXT: ret i1 [[TMP6]]
+;
+ %3 = mul nsw nuw i8 -64, %0
+ %4 = mul nsw nuw i8 8, %1
+ %5 = srem i8 %3, %4
+ %6 = icmp eq i8 %5, 0
+ ret i1 %6
+}
+
+define i1 @icmp_srem_constant2(i8 noundef %0, i8 noundef %1) {
+; CHECK-LABEL: @icmp_srem_constant2(
+; CHECK-NEXT: [[TMP3:%.*]] = mul nuw nsw i8 [[TMP0:%.*]], 9
+; CHECK-NEXT: [[TMP4:%.*]] = mul nuw nsw i8 [[TMP1:%.*]], 9
+; CHECK-NEXT: [[TMP5:%.*]] = srem i8 [[TMP3]], [[TMP4]]
+; CHECK-NEXT: [[TMP6:%.*]] = icmp eq i8 [[TMP5]], 0
+; CHECK-NEXT: ret i1 [[TMP6]]
+;
+ %3 = mul nsw nuw i8 9, %0
+ %4 = mul nsw nuw i8 9, %1
+ %5 = srem i8 %3, %4
+ %6 = icmp eq i8 %5, 0
+ ret i1 %6
+}
+
; PR30281 - https://llvm.org/bugs/show_bug.cgi?id=30281
; All of these tests contain foldable division-by-constant instructions, but we
@@ -118,8 +166,8 @@ define i32 @icmp_div(i16 %a, i16 %c) {
; CHECK-NEXT: [[TOBOOL:%.*]] = icmp eq i16 [[A:%.*]], 0
; CHECK-NEXT: br i1 [[TOBOOL]], label [[THEN:%.*]], label [[EXIT:%.*]]
; CHECK: then:
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[C:%.*]], 0
-; CHECK-NEXT: [[TMP0:%.*]] = sext i1 [[CMP]] to i32
+; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i16 [[C:%.*]], 0
+; CHECK-NEXT: [[TMP0:%.*]] = sext i1 [[CMP_NOT]] to i32
; CHECK-NEXT: br label [[EXIT]]
; CHECK: exit:
; CHECK-NEXT: [[PHI:%.*]] = phi i32 [ -1, [[ENTRY:%.*]] ], [ [[TMP0]], [[THEN]] ]
@@ -173,8 +221,8 @@ define i32 @icmp_div3(i16 %a, i16 %c) {
; CHECK-NEXT: [[TOBOOL:%.*]] = icmp eq i16 [[A:%.*]], 0
; CHECK-NEXT: br i1 [[TOBOOL]], label [[THEN:%.*]], label [[EXIT:%.*]]
; CHECK: then:
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i16 [[C:%.*]], 0
-; CHECK-NEXT: [[TMP0:%.*]] = sext i1 [[CMP]] to i32
+; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i16 [[C:%.*]], 0
+; CHECK-NEXT: [[TMP0:%.*]] = sext i1 [[CMP_NOT]] to i32
; CHECK-NEXT: br label [[EXIT]]
; CHECK: exit:
; CHECK-NEXT: [[PHI:%.*]] = phi i32 [ -1, [[ENTRY:%.*]] ], [ [[TMP0]], [[THEN]] ]
>From fd4a471d2afadb2409d2c11ad97dbb255c552b34 Mon Sep 17 00:00:00 2001
From: Dhairya <baxidhairya2312 at gmail.com>
Date: Thu, 25 Jan 2024 04:04:22 +0530
Subject: [PATCH 2/3] [InstCombine] Implement the transform for urem
---
.../InstCombine/InstCombineCompares.cpp | 65 +++++++++++++++++++
.../InstCombine/InstCombineInternal.h | 4 +-
2 files changed, 68 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 8c0fd6622551306..92df48f1ecf8519 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2575,6 +2575,66 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
return new ICmpInst(ICmpInst::ICMP_UGT, And, ConstantInt::get(Ty, SignMask));
}
+Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
+ BinaryOperator *Rem,
+ const APInt &C) {
+ assert((Rem->getOpcode() == Instruction::SRem || Rem->getOpcode() == Instruction::URem) && "Only for srem/urem!");
+ const ICmpInst::Predicate Pred = Cmp.getPredicate();
+ // Check for ==/!= 0
+ if (!ICmpInst::isEquality(Pred) || !C.isZero())
+ return nullptr;
+
+ Value *X = Rem->getOperand(0);
+ Value *Y = Rem->getOperand(1);
+
+ Value *A, *B;
+ const APInt *C1, *C2;
+ Value *NewRem;
+ APInt K;
+ Type *Ty;
+
+ if (Rem->getOpcode() == Instruction::SRem) {
+ // Check if both NSW and NUW flags are on
+ if (!match(X, m_NSWMul(m_Value(A), m_APInt(C1))))
+ return nullptr;
+ if (!match(Y, m_NSWMul(m_Value(B), m_APInt(C2))))
+ return nullptr;
+ if (!match(X, m_NUWMul(m_Value(A), m_APInt(C1))))
+ return nullptr;
+ if (!match(Y, m_NUWMul(m_Value(B), m_APInt(C2))))
+ return nullptr;
+ if (C2->isZero())
+ return nullptr;
+ if (!C1->srem(*C2).isZero())
+ return nullptr;
+ // Compute the new constant k = c1 / c2.
+ K = C1->sdiv(*C2);
+ Ty = Rem->getType();
+ if (K == 1)
+ NewRem = Builder.CreateSRem(A, B);
+ else
+ NewRem = Builder.CreateSRem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K), "", true, true), B);
+ }
+ else {
+ if (!match(X, m_NUWMul(m_Value(A), m_APInt(C1))))
+ return nullptr;
+ if (!match(Y, m_NUWMul(m_Value(B), m_APInt(C2))))
+ return nullptr;
+ if (C2->isZero())
+ return nullptr;
+ if (!C1->urem(*C2).isZero())
+ return nullptr;
+ // Compute the new constant k = c1 / c2.
+ K = C1->udiv(*C2);
+ Ty = Rem->getType();
+ if (K == 1)
+ NewRem = Builder.CreateSRem(A, B);
+ else
+ NewRem = Builder.CreateSRem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K), "", true, false), B);
+ }
+ return new ICmpInst(Pred, NewRem, ConstantInt::get(Ty, C));
+}
+
/// Fold icmp (udiv X, Y), C.
Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp,
BinaryOperator *UDiv,
@@ -3712,6 +3772,11 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C))
return I;
break;
+ [[fallthrough]];
+ case Instruction::URem:
+ if (Instruction *I = foldICmpRemConstant(Cmp, BO, C))
+ return I;
+ break;
case Instruction::UDiv:
if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C))
return I;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index c24b6e3a5b33c0b..17bcc29c9730058 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -674,7 +674,9 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
const APInt &C);
Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr,
const APInt &C);
- Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
+ Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *URem,
+ const APInt &C);
+ Instruction *foldICmpRemConstant(ICmpInst &Cmp, BinaryOperator *Rem,
const APInt &C);
Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
const APInt &C);
>From e794cc13226e73b6e5f41951213a2516e8edd0aa Mon Sep 17 00:00:00 2001
From: Dhairya <baxidhairya2312 at gmail.com>
Date: Thu, 25 Jan 2024 04:13:44 +0530
Subject: [PATCH 3/3] Fixed clang-format
---
.../InstCombine/InstCombineCompares.cpp | 23 ++++++++++++-------
.../InstCombine/InstCombineInternal.h | 2 +-
2 files changed, 16 insertions(+), 9 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 92df48f1ecf8519..9c86962ac3c16a6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2578,7 +2578,9 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
BinaryOperator *Rem,
const APInt &C) {
- assert((Rem->getOpcode() == Instruction::SRem || Rem->getOpcode() == Instruction::URem) && "Only for srem/urem!");
+ assert((Rem->getOpcode() == Instruction::SRem ||
+ Rem->getOpcode() == Instruction::URem) &&
+ "Only for srem/urem!");
const ICmpInst::Predicate Pred = Cmp.getPredicate();
// Check for ==/!= 0
if (!ICmpInst::isEquality(Pred) || !C.isZero())
@@ -2613,9 +2615,11 @@ Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
if (K == 1)
NewRem = Builder.CreateSRem(A, B);
else
- NewRem = Builder.CreateSRem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K), "", true, true), B);
- }
- else {
+ NewRem = Builder.CreateSRem(
+ Builder.CreateMul(A, ConstantInt::get(A->getType(), K), "", true,
+ true),
+ B);
+ } else {
if (!match(X, m_NUWMul(m_Value(A), m_APInt(C1))))
return nullptr;
if (!match(Y, m_NUWMul(m_Value(B), m_APInt(C2))))
@@ -2630,7 +2634,10 @@ Instruction *InstCombinerImpl::foldICmpRemConstant(ICmpInst &Cmp,
if (K == 1)
NewRem = Builder.CreateSRem(A, B);
else
- NewRem = Builder.CreateSRem(Builder.CreateMul(A, ConstantInt::get(A->getType(), K), "", true, false), B);
+ NewRem = Builder.CreateSRem(
+ Builder.CreateMul(A, ConstantInt::get(A->getType(), K), "", true,
+ false),
+ B);
}
return new ICmpInst(Pred, NewRem, ConstantInt::get(Ty, C));
}
@@ -3772,10 +3779,10 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
if (Instruction *I = foldICmpSRemConstant(Cmp, BO, C))
return I;
break;
- [[fallthrough]];
+ [[fallthrough]];
case Instruction::URem:
- if (Instruction *I = foldICmpRemConstant(Cmp, BO, C))
- return I;
+ if (Instruction *I = foldICmpRemConstant(Cmp, BO, C))
+ return I;
break;
case Instruction::UDiv:
if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 17bcc29c9730058..a0453f5b650a0d3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -677,7 +677,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *URem,
const APInt &C);
Instruction *foldICmpRemConstant(ICmpInst &Cmp, BinaryOperator *Rem,
- const APInt &C);
+ const APInt &C);
Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv,
const APInt &C);
Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div,
More information about the llvm-commits
mailing list