[llvm] [InstCombine] Handle `ICMP_EQ` when flooring by constant two (PR #73706)
Antonio Frighetto via llvm-commits
llvm-commits at lists.llvm.org
Wed Nov 29 03:38:13 PST 2023
https://github.com/antoniofrighetto updated https://github.com/llvm/llvm-project/pull/73706
From ad492c0fdd3687aed9a150662f6c84c67ba2b77e Mon Sep 17 00:00:00 2001
From: Antonio Frighetto <me at antoniofrighetto.com>
Date: Tue, 28 Nov 2023 21:04:56 +0100
Subject: [PATCH] [InstCombine] Handle equality comparison when flooring by
constant 2
Support `icmp eq` when reducing signed divisions by a power of 2 to an
arithmetic shift right, as `icmp ugt` may have been canonicalized into
`icmp eq` by the time the add is folded into the `ashr`.
Fixes: https://github.com/llvm/llvm-project/issues/73622.
Proof: https://alive2.llvm.org/ce/z/8-eUdb.
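For illustration, the i8 form of the newly-handled pattern and the fold
result (this mirrors the floor_sdiv_by_2 test added below):
  %div  = sdiv i8 %x, 2
  %and  = and i8 %x, -127        ; -127 = SMIN + 1
  %icmp = icmp eq i8 %and, -127
  %sext = sext i1 %icmp to i8
  %rv   = add nsw i8 %div, %sext
  ; folds to:
  %rv   = ashr i8 %x, 1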
---
.../InstCombine/InstCombineAddSub.cpp | 22 ++++++++++----
llvm/test/Transforms/InstCombine/add.ll | 30 +++++++++++++++++++
2 files changed, 46 insertions(+), 6 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 24a166906f1f46d..3604abb8e5277b5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1234,18 +1234,28 @@ static Instruction *foldAddToAshr(BinaryOperator &Add) {
return nullptr;
// Rounding is done by adding -1 if the dividend (X) is negative and has any
- // low bits set. The canonical pattern for that is an "ugt" compare with SMIN:
- // sext (icmp ugt (X & (DivC - 1)), SMIN)
- const APInt *MaskC;
+ // low bits set. The fold recognizes two canonical patterns:
+ // 1. For an 'ugt' cmp with the signed minimum value (SMIN), the
+ // pattern is: sext (icmp ugt (X & (SMIN | (DivC - 1))), SMIN).
+ // 2. For an 'eq' cmp, the pattern is: sext (icmp eq (X & (SMIN + 1)), SMIN + 1).
+ // Note that, by the time we get here, ugt may already have been
+ // canonicalized into eq.
+ const APInt *MaskC, *MaskCCmp;
ICmpInst::Predicate Pred;
if (!match(Add.getOperand(1),
m_SExt(m_ICmp(Pred, m_And(m_Specific(X), m_APInt(MaskC)),
- m_SignMask()))) ||
- Pred != ICmpInst::ICMP_UGT)
+ m_APInt(MaskCCmp)))))
+ return nullptr;
+
+ if ((Pred != ICmpInst::ICMP_UGT || !MaskCCmp->isSignMask()) &&
+ (Pred != ICmpInst::ICMP_EQ || *MaskCCmp != *MaskC))
return nullptr;
APInt SMin = APInt::getSignedMinValue(Add.getType()->getScalarSizeInBits());
- if (*MaskC != (SMin | (*DivC - 1)))
+ bool IsMaskValid = Pred == ICmpInst::ICMP_UGT
+ ? (*MaskC == (SMin | (*DivC - 1)))
+ : (*DivC == 2 && *MaskC == SMin + 1);
+ if (!IsMaskValid)
return nullptr;
// (X / DivC) + sext ((X & (SMin | (DivC - 1))) >u SMin) --> X >>s log2(DivC)
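A worked i8 check of the new 'eq' arm (illustrative, not part of the patch):
SMIN + 1 is -127 (0b10000001), i.e. the sign bit plus the lowest bit. For
X = -5 (0b11111011):
  X & -127    =  0b10000001  =  -127   ; X negative and odd -> icmp true
  sext        =  -1
  sdiv -5, 2  =  -2                    ; sdiv truncates toward zero
  -2 + -1     =  -3  =  ashr -5, 1     ; ashr floors toward -infinity
For even or non-negative X the compare is false and sdiv already matches
ashr. The 'eq' form requires *all* masked bits to be set, while 'ugt' with
SMIN requires the sign bit plus *any* low bit; the two coincide only when
there is a single low bit to test, hence the *DivC == 2 restriction above.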
diff --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll
index c35d2af42a4beae..84ca5b2e243912c 100644
--- a/llvm/test/Transforms/InstCombine/add.ll
+++ b/llvm/test/Transforms/InstCombine/add.ll
@@ -2700,6 +2700,36 @@ define i32 @floor_sdiv(i32 %x) {
ret i32 %r
}
+define i8 @floor_sdiv_by_2(i8 %x) {
+; CHECK-LABEL: @floor_sdiv_by_2(
+; CHECK-NEXT: [[RV:%.*]] = ashr i8 [[X:%.*]], 1
+; CHECK-NEXT: ret i8 [[RV]]
+;
+ %div = sdiv i8 %x, 2
+ %and = and i8 %x, -127
+ %icmp = icmp eq i8 %and, -127
+ %sext = sext i1 %icmp to i8
+ %rv = add nsw i8 %div, %sext
+ ret i8 %rv
+}
+
+define i8 @floor_sdiv_by_2_wrong_mask(i8 %x) {
+; CHECK-LABEL: @floor_sdiv_by_2_wrong_mask(
+; CHECK-NEXT: [[DIV:%.*]] = sdiv i8 [[X:%.*]], 2
+; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 127
+; CHECK-NEXT: [[ICMP:%.*]] = icmp eq i8 [[AND]], 127
+; CHECK-NEXT: [[SEXT:%.*]] = sext i1 [[ICMP]] to i8
+; CHECK-NEXT: [[RV:%.*]] = add nsw i8 [[DIV]], [[SEXT]]
+; CHECK-NEXT: ret i8 [[RV]]
+;
+ %div = sdiv i8 %x, 2
+ %and = and i8 %x, 127
+ %icmp = icmp eq i8 %and, 127
+ %sext = sext i1 %icmp to i8
+ %rv = add nsw i8 %div, %sext
+ ret i8 %rv
+}
+
; vectors work too and commute is handled by complexity-based canonicalization
define <2 x i32> @floor_sdiv_vec_commute(<2 x i32> %x) {
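To see the fold locally (standard LLVM workflow; paths are illustrative),
run instcombine over the updated test; the CHECK lines above are in the
autogenerated format produced by llvm/utils/update_test_checks.py:
  $ opt -passes=instcombine -S llvm/test/Transforms/InstCombine/add.ll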