[llvm] [InstCombine] Generalize select equiv fold for plain condition (PR #85663)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 18 09:56:25 PDT 2024


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/85663

The select equivalence fold takes a select like "X == Y ? A : B" and then tries to simplify A based on the known equality.

This patch also uses it for the case where we have just "C ? A : B" by treating the condition as either "C == 1" or "C != 0".

This is intended as an alternative to #83405
for fixing https://github.com/llvm/llvm-project/issues/83225.

>From b1fab40db94184e9f45cc70d8654074699c103bf Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Tue, 20 Sep 2022 16:53:46 +0200
Subject: [PATCH] [InstCombine] Generalize select equiv fold for plain
 condition

The select equivalence fold takes a select like "X == Y ? A : B"
and then tries to simplify A based on the known equality.

This patch also uses it for the case where we have just "C ? A : B"
by treating the condition as either "C == 1" or "C != 0".

This is intended as an alternative to #83405
for fixing https://github.com/llvm/llvm-project/issues/83225.
---
 .../InstCombine/InstCombineInternal.h         |  4 +-
 .../InstCombine/InstCombineSelect.cpp         | 41 ++++++++-----------
 llvm/test/Transforms/InstCombine/select.ll    | 15 +++----
 3 files changed, 25 insertions(+), 35 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index e2b744ba66f2a9..a8353092d72db1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -735,9 +735,11 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                             Value *A, Value *B, Instruction &Outer,
                             SelectPatternFlavor SPF2, Value *C);
   Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI);
-  Instruction *foldSelectValueEquivalence(SelectInst &SI, ICmpInst &ICI);
   bool replaceInInstruction(Value *V, Value *Old, Value *New,
                             unsigned Depth = 0);
+  Instruction *foldSelectValueEquivalence(SelectInst &Sel,
+                                          ICmpInst::Predicate Pred,
+                                          Value *CmpLHS, Value *CmpRHS);
 
   Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,
                          bool isSigned, bool Inside);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index ee76a6294428b3..3d52661d3c20d8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1262,27 +1262,23 @@ bool InstCombinerImpl::replaceInInstruction(Value *V, Value *Old, Value *New,
 ///
 /// We can't replace %sel with %add unless we strip away the flags.
 /// TODO: Wrapping flags could be preserved in some cases with better analysis.
-Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
-                                                          ICmpInst &Cmp) {
-  if (!Cmp.isEquality())
+Instruction *InstCombinerImpl::foldSelectValueEquivalence(
+    SelectInst &Sel, ICmpInst::Predicate Pred, Value *CmpLHS, Value *CmpRHS) {
+  if (!ICmpInst::isEquality(Pred))
     return nullptr;
 
   // Canonicalize the pattern to ICMP_EQ by swapping the select operands.
   Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue();
   bool Swapped = false;
-  if (Cmp.getPredicate() == ICmpInst::ICMP_NE) {
+  if (Pred == ICmpInst::ICMP_NE) {
     std::swap(TrueVal, FalseVal);
     Swapped = true;
   }
 
   // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand.
   // Make sure Y cannot be undef though, as we might pick different values for
-  // undef in the icmp and in f(Y). Additionally, take care to avoid replacing
-  // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite
-  // replacement cycle.
-  Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
-  if (TrueVal != CmpLHS &&
-      isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) {
+  // undef in the icmp and in f(Y).
+  if (isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) {
     if (Value *V = simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ,
                                           /* AllowRefinement */ true))
       // Require either the replacement or the simplification result to be a
@@ -1299,7 +1295,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     // profitability is not clear for other cases.
     // FIXME: Support vectors.
     if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) &&
-        !Cmp.getType()->isVectorTy())
+        !CmpLHS->getType()->isVectorTy())
       if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS))
         return &Sel;
   }
@@ -1680,7 +1676,8 @@ static Value *foldSelectInstWithICmpConst(SelectInst &SI, ICmpInst *ICI,
 /// Visit a SelectInst that has an ICmpInst as its first operand.
 Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
                                                       ICmpInst *ICI) {
-  if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI))
+  if (Instruction *NewSel = foldSelectValueEquivalence(
+          SI, ICI->getPredicate(), ICI->getOperand(0), ICI->getOperand(1)))
     return NewSel;
 
   if (Value *V =
@@ -3376,21 +3373,15 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, *this))
     return I;
 
-  // If the type of select is not an integer type or if the condition and
-  // the selection type are not both scalar nor both vector types, there is no
-  // point in attempting to match these patterns.
   Type *CondType = CondVal->getType();
-  if (!isa<Constant>(CondVal) && SelType->isIntOrIntVectorTy() &&
-      CondType->isVectorTy() == SelType->isVectorTy()) {
-    if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal,
-                                          ConstantInt::getTrue(CondType), SQ,
-                                          /* AllowRefinement */ true))
-      return replaceOperand(SI, 1, S);
+  if (!isa<Constant>(CondVal)) {
+    if (Instruction *I = foldSelectValueEquivalence(
+            SI, ICmpInst::ICMP_EQ, CondVal, ConstantInt::getTrue(CondType)))
+      return I;
 
-    if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal,
-                                          ConstantInt::getFalse(CondType), SQ,
-                                          /* AllowRefinement */ true))
-      return replaceOperand(SI, 2, S);
+    if (Instruction *I = foldSelectValueEquivalence(
+            SI, ICmpInst::ICMP_NE, CondVal, ConstantInt::getFalse(CondType)))
+      return I;
   }
 
   if (Instruction *R = foldSelectOfBools(SI))
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index 278cabdff9ed3e..53392fcd8340d2 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -3709,9 +3709,8 @@ define i32 @src_select_xxory_eq0_xorxy_y(i32 %x, i32 %y) {
 
 define i32 @sequence_select_with_same_cond_false(i1 %c1, i1 %c2){
 ; CHECK-LABEL: @sequence_select_with_same_cond_false(
-; CHECK-NEXT:    [[S1:%.*]] = select i1 [[C1:%.*]], i32 23, i32 45
-; CHECK-NEXT:    [[S2:%.*]] = select i1 [[C2:%.*]], i32 666, i32 [[S1]]
-; CHECK-NEXT:    [[S3:%.*]] = select i1 [[C1]], i32 789, i32 [[S2]]
+; CHECK-NEXT:    [[S2:%.*]] = select i1 [[C2:%.*]], i32 666, i32 45
+; CHECK-NEXT:    [[S3:%.*]] = select i1 [[C1:%.*]], i32 789, i32 [[S2]]
 ; CHECK-NEXT:    ret i32 [[S3]]
 ;
   %s1 = select i1 %c1, i32 23, i32 45
@@ -3722,9 +3721,8 @@ define i32 @sequence_select_with_same_cond_false(i1 %c1, i1 %c2){
 
 define i32 @sequence_select_with_same_cond_true(i1 %c1, i1 %c2){
 ; CHECK-LABEL: @sequence_select_with_same_cond_true(
-; CHECK-NEXT:    [[S1:%.*]] = select i1 [[C1:%.*]], i32 45, i32 23
-; CHECK-NEXT:    [[S2:%.*]] = select i1 [[C2:%.*]], i32 [[S1]], i32 666
-; CHECK-NEXT:    [[S3:%.*]] = select i1 [[C1]], i32 [[S2]], i32 789
+; CHECK-NEXT:    [[S2:%.*]] = select i1 [[C2:%.*]], i32 45, i32 666
+; CHECK-NEXT:    [[S3:%.*]] = select i1 [[C1:%.*]], i32 [[S2]], i32 789
 ; CHECK-NEXT:    ret i32 [[S3]]
 ;
   %s1 = select i1 %c1, i32 45, i32 23
@@ -3735,9 +3733,8 @@ define i32 @sequence_select_with_same_cond_true(i1 %c1, i1 %c2){
 
 define double @sequence_select_with_same_cond_double(double %a, i1 %c1, i1 %c2, double %r1, double %r2){
 ; CHECK-LABEL: @sequence_select_with_same_cond_double(
-; CHECK-NEXT:    [[S1:%.*]] = select i1 [[C1:%.*]], double 1.000000e+00, double 0.000000e+00
-; CHECK-NEXT:    [[S2:%.*]] = select i1 [[C2:%.*]], double [[S1]], double 2.000000e+00
-; CHECK-NEXT:    [[S3:%.*]] = select i1 [[C1]], double [[S2]], double 3.000000e+00
+; CHECK-NEXT:    [[S2:%.*]] = select i1 [[C2:%.*]], double 1.000000e+00, double 2.000000e+00
+; CHECK-NEXT:    [[S3:%.*]] = select i1 [[C1:%.*]], double [[S2]], double 3.000000e+00
 ; CHECK-NEXT:    ret double [[S3]]
 ;
   %s1 = select i1 %c1, double 1.0, double 0.0



More information about the llvm-commits mailing list