[llvm] [InstCombine] handle trunc to i1 in foldSelectICmpAndBinOp (PR #127390)
Andreas Jonson via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 19 09:16:33 PST 2025
https://github.com/andjo403 updated https://github.com/llvm/llvm-project/pull/127390
>From 8d61bc5669e094208a486a285a42daa5084681dd Mon Sep 17 00:00:00 2001
From: Andreas Jonson <andjo403 at hotmail.com>
Date: Sun, 16 Feb 2025 12:09:20 +0100
Subject: [PATCH] [InstCombine] handle trunc to i1 in foldSelectICmpAndBinOp
---
.../InstCombine/InstCombineSelect.cpp | 62 +++++++++++--------
.../InstCombine/select-with-bitwise-ops.ll | 29 +++++----
2 files changed, 49 insertions(+), 42 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index cf38fc5f058f2..0dfdd9209b40e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -742,39 +742,47 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal,
/// 1. The icmp predicate is inverted
/// 2. The select operands are reversed
/// 3. The magnitude of C2 and C1 are flipped
-static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal,
- Value *FalseVal,
- InstCombiner::BuilderTy &Builder) {
+static Value *foldSelectICmpAndBinOp(Value *CondVal, Value *TrueVal,
+ Value *FalseVal,
+ InstCombiner::BuilderTy &Builder) {
// Only handle integer compares. Also, if this is a vector select, we need a
// vector compare.
if (!TrueVal->getType()->isIntOrIntVectorTy() ||
- TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy())
+ TrueVal->getType()->isVectorTy() != CondVal->getType()->isVectorTy())
return nullptr;
- Value *CmpLHS = IC->getOperand(0);
- Value *CmpRHS = IC->getOperand(1);
-
unsigned C1Log;
bool NeedAnd = false;
- CmpInst::Predicate Pred = IC->getPredicate();
- if (IC->isEquality()) {
- if (!match(CmpRHS, m_Zero()))
- return nullptr;
+ CmpPredicate Pred;
+ Value *CmpLHS, *CmpRHS;
- const APInt *C1;
- if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
- return nullptr;
+ if (match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) {
+ if (ICmpInst::isEquality(Pred)) {
+ if (!match(CmpRHS, m_Zero()))
+ return nullptr;
- C1Log = C1->logBase2();
- } else {
- auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
- if (!Res || !Res->Mask.isPowerOf2())
- return nullptr;
+ const APInt *C1;
+ if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1))))
+ return nullptr;
- CmpLHS = Res->X;
- Pred = Res->Pred;
- C1Log = Res->Mask.logBase2();
- NeedAnd = true;
+ C1Log = C1->logBase2();
+ } else {
+ auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+ if (!Res || !Res->Mask.isPowerOf2())
+ return nullptr;
+
+ CmpLHS = Res->X;
+ Pred = Res->Pred;
+ C1Log = Res->Mask.logBase2();
+ NeedAnd = true;
+ }
+ } else if (auto *Trunc = dyn_cast<TruncInst>(CondVal)) {
+ CmpLHS = Trunc->getOperand(0);
+ C1Log = 0;
+ Pred = ICmpInst::ICMP_NE;
+ NeedAnd = !Trunc->hasNoUnsignedWrap();
+ } else {
+ return nullptr;
}
Value *Y, *V = CmpLHS;
@@ -808,7 +816,7 @@ static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal,
// Make sure we don't create more instructions than we save.
if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) >
- (IC->hasOneUse() + BinOp->hasOneUse()))
+ (CondVal->hasOneUse() + BinOp->hasOneUse()))
return nullptr;
if (NeedAnd) {
@@ -1986,9 +1994,6 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Instruction *V = foldSelectZeroOrOnes(ICI, TrueVal, FalseVal, Builder))
return V;
- if (Value *V = foldSelectICmpAndBinOp(ICI, TrueVal, FalseVal, Builder))
- return replaceInstUsesWith(SI, V);
-
if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);
@@ -3946,6 +3951,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Instruction *Result = foldSelectInstWithICmp(SI, ICI))
return Result;
+ if (Value *V = foldSelectICmpAndBinOp(CondVal, TrueVal, FalseVal, Builder))
+ return replaceInstUsesWith(SI, V);
+
if (Instruction *Add = foldAddSubSelect(SI, Builder))
return Add;
if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder))
diff --git a/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll b/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
index 67dec9178eeca..ca2e23c1d082e 100644
--- a/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
+++ b/llvm/test/Transforms/InstCombine/select-with-bitwise-ops.ll
@@ -1754,9 +1754,9 @@ define i8 @select_icmp_eq_and_1_0_lshr_tv(i8 %x, i8 %y) {
define i8 @select_trunc_or_2(i8 %x, i8 %y) {
; CHECK-LABEL: @select_trunc_or_2(
-; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1
-; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
+; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
+; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], 2
+; CHECK-NEXT: [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
; CHECK-NEXT: ret i8 [[SELECT]]
;
%trunc = trunc i8 %x to i1
@@ -1767,9 +1767,9 @@ define i8 @select_trunc_or_2(i8 %x, i8 %y) {
define i8 @select_not_trunc_or_2(i8 %x, i8 %y) {
; CHECK-LABEL: @select_not_trunc_or_2(
-; CHECK-NEXT: [[TRUNC:%.*]] = trunc i8 [[X:%.*]] to i1
-; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
+; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
+; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], 2
+; CHECK-NEXT: [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
; CHECK-NEXT: ret i8 [[SELECT]]
;
%trunc = trunc i8 %x to i1
@@ -1781,9 +1781,8 @@ define i8 @select_not_trunc_or_2(i8 %x, i8 %y) {
define i8 @select_trunc_nuw_or_2(i8 %x, i8 %y) {
; CHECK-LABEL: @select_trunc_nuw_or_2(
-; CHECK-NEXT: [[TRUNC:%.*]] = trunc nuw i8 [[X:%.*]] to i1
-; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
+; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
+; CHECK-NEXT: [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP1]]
; CHECK-NEXT: ret i8 [[SELECT]]
;
%trunc = trunc nuw i8 %x to i1
@@ -1794,9 +1793,9 @@ define i8 @select_trunc_nuw_or_2(i8 %x, i8 %y) {
define i8 @select_trunc_nsw_or_2(i8 %x, i8 %y) {
; CHECK-LABEL: @select_trunc_nsw_or_2(
-; CHECK-NEXT: [[TRUNC:%.*]] = trunc nsw i8 [[X:%.*]] to i1
-; CHECK-NEXT: [[OR:%.*]] = or i8 [[Y:%.*]], 2
-; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[TRUNC]], i8 [[OR]], i8 [[Y]]
+; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 1
+; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], 2
+; CHECK-NEXT: [[SELECT:%.*]] = or i8 [[Y:%.*]], [[TMP2]]
; CHECK-NEXT: ret i8 [[SELECT]]
;
%trunc = trunc nsw i8 %x to i1
@@ -1807,9 +1806,9 @@ define i8 @select_trunc_nsw_or_2(i8 %x, i8 %y) {
define <2 x i8> @select_trunc_or_2_vec(<2 x i8> %x, <2 x i8> %y) {
; CHECK-LABEL: @select_trunc_or_2_vec(
-; CHECK-NEXT: [[TRUNC:%.*]] = trunc <2 x i8> [[X:%.*]] to <2 x i1>
-; CHECK-NEXT: [[OR:%.*]] = or <2 x i8> [[Y:%.*]], splat (i8 2)
-; CHECK-NEXT: [[SELECT:%.*]] = select <2 x i1> [[TRUNC]], <2 x i8> [[OR]], <2 x i8> [[Y]]
+; CHECK-NEXT: [[TMP1:%.*]] = shl <2 x i8> [[X:%.*]], splat (i8 1)
+; CHECK-NEXT: [[TMP2:%.*]] = and <2 x i8> [[TMP1]], splat (i8 2)
+; CHECK-NEXT: [[SELECT:%.*]] = or <2 x i8> [[Y:%.*]], [[TMP2]]
; CHECK-NEXT: ret <2 x i8> [[SELECT]]
;
%trunc = trunc <2 x i8> %x to <2 x i1>
More information about the llvm-commits mailing list