[llvm] [InstCombine] Generalize select equiv fold for plain condition (PR #85663)
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 18 09:56:25 PDT 2024
https://github.com/nikic created https://github.com/llvm/llvm-project/pull/85663
The select equivalence fold takes a select like "X == Y ? A : B" and then tries to simplify A based on the known equality.
This patch also uses it for the case where we have just "C ? A : B" by treating the condition as either "C == 1" (to simplify A) or "C != 0" (to simplify B).
This is intended as an alternative to #83405
for fixing https://github.com/llvm/llvm-project/issues/83225.
>From b1fab40db94184e9f45cc70d8654074699c103bf Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Tue, 20 Sep 2022 16:53:46 +0200
Subject: [PATCH] [InstCombine] Generalize select equiv fold for plain
condition
The select equivalence fold takes a select like "X == Y ? A : B"
and then tries to simplify A based on the known equality.
This patch also uses it for the case where we have just "C ? A : B"
by treating the condition as either "C == 1" (to simplify A) or "C != 0" (to simplify B).
This is intended as an alternative to #83405
for fixing https://github.com/llvm/llvm-project/issues/83225.
---
.../InstCombine/InstCombineInternal.h | 4 +-
.../InstCombine/InstCombineSelect.cpp | 41 ++++++++-----------
llvm/test/Transforms/InstCombine/select.ll | 15 +++----
3 files changed, 25 insertions(+), 35 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index e2b744ba66f2a9..a8353092d72db1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -735,9 +735,11 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Value *A, Value *B, Instruction &Outer,
SelectPatternFlavor SPF2, Value *C);
Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI);
- Instruction *foldSelectValueEquivalence(SelectInst &SI, ICmpInst &ICI);
bool replaceInInstruction(Value *V, Value *Old, Value *New,
unsigned Depth = 0);
+ Instruction *foldSelectValueEquivalence(SelectInst &Sel,
+ ICmpInst::Predicate Pred,
+ Value *CmpLHS, Value *CmpRHS);
Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,
bool isSigned, bool Inside);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index ee76a6294428b3..3d52661d3c20d8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1262,27 +1262,23 @@ bool InstCombinerImpl::replaceInInstruction(Value *V, Value *Old, Value *New,
///
/// We can't replace %sel with %add unless we strip away the flags.
/// TODO: Wrapping flags could be preserved in some cases with better analysis.
-Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
- ICmpInst &Cmp) {
- if (!Cmp.isEquality())
+Instruction *InstCombinerImpl::foldSelectValueEquivalence(
+ SelectInst &Sel, ICmpInst::Predicate Pred, Value *CmpLHS, Value *CmpRHS) {
+ if (!ICmpInst::isEquality(Pred))
return nullptr;
// Canonicalize the pattern to ICMP_EQ by swapping the select operands.
Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue();
bool Swapped = false;
- if (Cmp.getPredicate() == ICmpInst::ICMP_NE) {
+ if (Pred == ICmpInst::ICMP_NE) {
std::swap(TrueVal, FalseVal);
Swapped = true;
}
// In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand.
// Make sure Y cannot be undef though, as we might pick different values for
- // undef in the icmp and in f(Y). Additionally, take care to avoid replacing
- // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite
- // replacement cycle.
- Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
- if (TrueVal != CmpLHS &&
- isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) {
+ // undef in the icmp and in f(Y).
+ if (isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) {
if (Value *V = simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ,
/* AllowRefinement */ true))
// Require either the replacement or the simplification result to be a
@@ -1299,7 +1295,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
// profitability is not clear for other cases.
// FIXME: Support vectors.
if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) &&
- !Cmp.getType()->isVectorTy())
+ !CmpLHS->getType()->isVectorTy())
if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS))
return &Sel;
}
@@ -1680,7 +1676,8 @@ static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI,
/// Visit a SelectInst that has an ICmpInst as its first operand.
Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
ICmpInst *ICI) {
- if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI))
+ if (Instruction *NewSel = foldSelectValueEquivalence(
+ SI, ICI->getPredicate(), ICI->getOperand(0), ICI->getOperand(1)))
return NewSel;
if (Value *V =
@@ -3376,21 +3373,15 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, *this))
return I;
- // If the type of select is not an integer type or if the condition and
- // the selection type are not both scalar nor both vector types, there is no
- // point in attempting to match these patterns.
Type *CondType = CondVal->getType();
- if (!isa<Constant>(CondVal) && SelType->isIntOrIntVectorTy() &&
- CondType->isVectorTy() == SelType->isVectorTy()) {
- if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal,
- ConstantInt::getTrue(CondType), SQ,
- /* AllowRefinement */ true))
- return replaceOperand(SI, 1, S);
+ if (!isa<Constant>(CondVal)) {
+ if (Instruction *I = foldSelectValueEquivalence(
+ SI, ICmpInst::ICMP_EQ, CondVal, ConstantInt::getTrue(CondType)))
+ return I;
- if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal,
- ConstantInt::getFalse(CondType), SQ,
- /* AllowRefinement */ true))
- return replaceOperand(SI, 2, S);
+ if (Instruction *I = foldSelectValueEquivalence(
+ SI, ICmpInst::ICMP_NE, CondVal, ConstantInt::getFalse(CondType)))
+ return I;
}
if (Instruction *R = foldSelectOfBools(SI))
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index 278cabdff9ed3e..53392fcd8340d2 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -3709,9 +3709,8 @@ define i32 @src_select_xxory_eq0_xorxy_y(i32 %x, i32 %y) {
define i32 @sequence_select_with_same_cond_false(i1 %c1, i1 %c2){
; CHECK-LABEL: @sequence_select_with_same_cond_false(
-; CHECK-NEXT: [[S1:%.*]] = select i1 [[C1:%.*]], i32 23, i32 45
-; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], i32 666, i32 [[S1]]
-; CHECK-NEXT: [[S3:%.*]] = select i1 [[C1]], i32 789, i32 [[S2]]
+; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], i32 666, i32 45
+; CHECK-NEXT: [[S3:%.*]] = select i1 [[C1:%.*]], i32 789, i32 [[S2]]
; CHECK-NEXT: ret i32 [[S3]]
;
%s1 = select i1 %c1, i32 23, i32 45
@@ -3722,9 +3721,8 @@ define i32 @sequence_select_with_same_cond_false(i1 %c1, i1 %c2){
define i32 @sequence_select_with_same_cond_true(i1 %c1, i1 %c2){
; CHECK-LABEL: @sequence_select_with_same_cond_true(
-; CHECK-NEXT: [[S1:%.*]] = select i1 [[C1:%.*]], i32 45, i32 23
-; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], i32 [[S1]], i32 666
-; CHECK-NEXT: [[S3:%.*]] = select i1 [[C1]], i32 [[S2]], i32 789
+; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], i32 45, i32 666
+; CHECK-NEXT: [[S3:%.*]] = select i1 [[C1:%.*]], i32 [[S2]], i32 789
; CHECK-NEXT: ret i32 [[S3]]
;
%s1 = select i1 %c1, i32 45, i32 23
@@ -3735,9 +3733,8 @@ define i32 @sequence_select_with_same_cond_true(i1 %c1, i1 %c2){
define double @sequence_select_with_same_cond_double(double %a, i1 %c1, i1 %c2, double %r1, double %r2){
; CHECK-LABEL: @sequence_select_with_same_cond_double(
-; CHECK-NEXT: [[S1:%.*]] = select i1 [[C1:%.*]], double 1.000000e+00, double 0.000000e+00
-; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], double [[S1]], double 2.000000e+00
-; CHECK-NEXT: [[S3:%.*]] = select i1 [[C1]], double [[S2]], double 3.000000e+00
+; CHECK-NEXT: [[S2:%.*]] = select i1 [[C2:%.*]], double 1.000000e+00, double 2.000000e+00
+; CHECK-NEXT: [[S3:%.*]] = select i1 [[C1:%.*]], double [[S2]], double 3.000000e+00
; CHECK-NEXT: ret double [[S3]]
;
%s1 = select i1 %c1, double 1.0, double 0.0
More information about the llvm-commits
mailing list