[llvm] f12a556 - [InstCombine] Fold binop of `select` and cast of `select` condition

Antonio Frighetto via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 20 12:43:48 PDT 2023


Author: Antonio Frighetto
Date: 2023-07-20T19:42:58Z
New Revision: f12a5561b2cbfae384c9a31293938ee2acea79fd

URL: https://github.com/llvm/llvm-project/commit/f12a5561b2cbfae384c9a31293938ee2acea79fd
DIFF: https://github.com/llvm/llvm-project/commit/f12a5561b2cbfae384c9a31293938ee2acea79fd.diff

LOG: [InstCombine] Fold binop of `select` and cast of `select` condition

Simplify binary operations whose operands are a `select`
instruction and a zext/sext of the `select` condition (or of its
negation). Specifically, the binop is canonicalized into a `select`
with folded arguments as follows:

(Binop (zext C), (select C, T, F))
  -> (select C, (binop 1, T), (binop 0, F))

(Binop (sext C), (select C, T, F))
  -> (select C, (binop -1, T), (binop 0, F))

Proofs: https://alive2.llvm.org/ce/z/c_JwwM

Differential Revision: https://reviews.llvm.org/D153963

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
    llvm/test/Transforms/InstCombine/binop-select-cast-of-select-cond.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index cd47a8b5330300..91ca44e0f11e87 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1618,6 +1618,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
     return Res;
 
+  if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+    return Res;
+
   return Changed ? &I : nullptr;
 }
 
@@ -2466,6 +2469,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
     }
   }
 
+  if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+    return Res;
+
   return TryToNarrowDeduceFlags();
 }
 

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index faa1b06bde4f9e..701579e1de4830 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -454,6 +454,12 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   //    -> (BinOp (logic_shift (BinOp X, Y)), Mask)
   Instruction *foldBinOpShiftWithShift(BinaryOperator &I);
 
+  /// Tries to simplify binops of select and zext/sext of the select condition.
+  ///
+  /// (Binop (zext/sext C), (select C, T, F))
+  ///    -> (select C, (binop 1/-1, T), (binop 0, F))
+  Instruction *foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I);
+
   /// This tries to simplify binary operations by factorizing out common terms
   /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
   Value *tryFactorizationFolds(BinaryOperator &I);

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index a66c2071cce5bd..50458e2773e63d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -474,6 +474,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
   if (Instruction *Ext = narrowMathIfNoOverflow(I))
     return Ext;
 
+  if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+    return Res;
+
   // min(X, Y) * max(X, Y) => X * Y.
   if (match(&I, m_CombineOr(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)),
                                     m_c_SMin(m_Deferred(X), m_Deferred(Y))),

diff  --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 3fdd9195ed86cb..fbf7d7be81c058 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -870,6 +870,71 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) {
   return MatchBinOp(1);
 }
 
+// (Binop (zext C), (select C, T, F))
+//    -> (select C, (binop 1, T), (binop 0, F))
+//
+// (Binop (sext C), (select C, T, F))
+//    -> (select C, (binop -1, T), (binop 0, F))
+//
+// Attempt to simplify binary operations into a select with folded args, when
+// one operand of the binop is a select instruction and the other operand is a
+// zext/sext of an i1 (or vector of i1) that is the select condition. If the
+// cast operand is (not C) instead, the constants of the two arms are swapped.
+Instruction *
+InstCombinerImpl::foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I) {
+  // TODO: this simplification may be extended to any speculatable instruction,
+  // not just binops, and would possibly be handled better in FoldOpIntoSelect.
+  Instruction::BinaryOps Opc = I.getOpcode();
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  Value *A, *CondVal, *TrueVal, *FalseVal;
+  Value *CastOp;
+
+  auto MatchSelectAndCast = [&](Value *MaybeCast, Value *MaybeSelect) {
+    return match(MaybeCast, m_ZExtOrSExt(m_Value(A))) &&
+           A->getType()->getScalarSizeInBits() == 1 &&
+           match(MaybeSelect, m_Select(m_Value(CondVal), m_Value(TrueVal),
+                                       m_Value(FalseVal)));
+  };
+
+  // Make sure one side of the binop is a select instruction, and the other is a
+  // zero/sign extension operating on a i1.
+  if (MatchSelectAndCast(LHS, RHS))
+    CastOp = LHS;
+  else if (MatchSelectAndCast(RHS, LHS))
+    CastOp = RHS;
+  else
+    return nullptr;
+
+  // The zext/sext operand must be the select condition or its negation.
+  bool IsCondNegated = match(A, m_Not(m_Specific(CondVal)));
+  if (A != CondVal && !IsCondNegated)
+    return nullptr;
+
+  // Rebuild the binop for one arm, replacing the cast with the constant it
+  // evaluates to on that arm: 0 if CastIsZero, else 1 (zext) or -1 (sext).
+  bool IsCastOpRHS = (CastOp == RHS);
+  bool IsZExt = isa<ZExtInst>(CastOp);
+  auto NewFoldedConst = [&](bool CastIsZero, Value *V) {
+    Type *Ty = V->getType();
+    Constant *C;
+    if (CastIsZero)
+      C = Constant::getNullValue(Ty);
+    else if (IsZExt)
+      C = ConstantInt::get(Ty, 1); // Splat of 1 for vector types.
+    else
+      C = Constant::getAllOnesValue(Ty);
+    return IsCastOpRHS ? Builder.CreateBinOp(Opc, V, C)
+                       : Builder.CreateBinOp(Opc, C, V);
+  };
+
+  // Create the arms in separate statements: C++ leaves the evaluation order
+  // of call arguments unspecified, which would make the order of the new
+  // instructions depend on the host compiler.
+  Value *NewTrueVal = NewFoldedConst(/*CastIsZero=*/IsCondNegated, TrueVal);
+  Value *NewFalseVal = NewFoldedConst(/*CastIsZero=*/!IsCondNegated, FalseVal);
+  return SelectInst::Create(CondVal, NewTrueVal, NewFalseVal);
+}
+
 Value *InstCombinerImpl::tryFactorizationFolds(BinaryOperator &I) {
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);

diff  --git a/llvm/test/Transforms/InstCombine/binop-select-cast-of-select-cond.ll b/llvm/test/Transforms/InstCombine/binop-select-cast-of-select-cond.ll
index 000a89849625af..d1cec11878616d 100644
--- a/llvm/test/Transforms/InstCombine/binop-select-cast-of-select-cond.ll
+++ b/llvm/test/Transforms/InstCombine/binop-select-cast-of-select-cond.ll
@@ -4,9 +4,7 @@
 define i64 @add_select_zext(i1 %c) {
 ; CHECK-LABEL: define i64 @add_select_zext
 ; CHECK-SAME: (i1 [[C:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i64 64, i64 1
-; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[C]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[SEL]], [[EXT]]
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[C]], i64 65, i64 1
 ; CHECK-NEXT:    ret i64 [[ADD]]
 ;
   %sel = select i1 %c, i64 64, i64 1
@@ -18,9 +16,7 @@ define i64 @add_select_zext(i1 %c) {
 define i64 @add_select_sext(i1 %c) {
 ; CHECK-LABEL: define i64 @add_select_sext
 ; CHECK-SAME: (i1 [[C:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i64 64, i64 1
-; CHECK-NEXT:    [[EXT:%.*]] = sext i1 [[C]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[SEL]], [[EXT]]
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[C]], i64 63, i64 1
 ; CHECK-NEXT:    ret i64 [[ADD]]
 ;
   %sel = select i1 %c, i64 64, i64 1
@@ -32,10 +28,7 @@ define i64 @add_select_sext(i1 %c) {
 define i64 @add_select_not_zext(i1 %c) {
 ; CHECK-LABEL: define i64 @add_select_not_zext
 ; CHECK-SAME: (i1 [[C:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i64 64, i64 1
-; CHECK-NEXT:    [[NOT_C:%.*]] = xor i1 [[C]], true
-; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[NOT_C]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[SEL]], [[EXT]]
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[C]], i64 64, i64 2
 ; CHECK-NEXT:    ret i64 [[ADD]]
 ;
   %sel = select i1 %c, i64 64, i64 1
@@ -48,10 +41,7 @@ define i64 @add_select_not_zext(i1 %c) {
 define i64 @add_select_not_sext(i1 %c) {
 ; CHECK-LABEL: define i64 @add_select_not_sext
 ; CHECK-SAME: (i1 [[C:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i64 64, i64 1
-; CHECK-NEXT:    [[NOT_C:%.*]] = xor i1 [[C]], true
-; CHECK-NEXT:    [[EXT:%.*]] = sext i1 [[NOT_C]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[SEL]], [[EXT]]
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[C]], i64 64, i64 0
 ; CHECK-NEXT:    ret i64 [[ADD]]
 ;
   %sel = select i1 %c, i64 64, i64 1
@@ -64,9 +54,7 @@ define i64 @add_select_not_sext(i1 %c) {
 define i64 @sub_select_sext(i1 %c, i64 %arg) {
 ; CHECK-LABEL: define i64 @sub_select_sext
 ; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i64 64, i64 [[ARG]]
-; CHECK-NEXT:    [[EXT_NEG:%.*]] = zext i1 [[C]] to i64
-; CHECK-NEXT:    [[SUB:%.*]] = add i64 [[SEL]], [[EXT_NEG]]
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C]], i64 65, i64 [[ARG]]
 ; CHECK-NEXT:    ret i64 [[SUB]]
 ;
   %sel = select i1 %c, i64 64, i64 %arg
@@ -78,10 +66,7 @@ define i64 @sub_select_sext(i1 %c, i64 %arg) {
 define i64 @sub_select_not_zext(i1 %c, i64 %arg) {
 ; CHECK-LABEL: define i64 @sub_select_not_zext
 ; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i64 [[ARG]], i64 64
-; CHECK-NEXT:    [[NOT_C:%.*]] = xor i1 [[C]], true
-; CHECK-NEXT:    [[EXT_NEG:%.*]] = sext i1 [[NOT_C]] to i64
-; CHECK-NEXT:    [[SUB:%.*]] = add i64 [[SEL]], [[EXT_NEG]]
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C]], i64 [[ARG]], i64 63
 ; CHECK-NEXT:    ret i64 [[SUB]]
 ;
   %sel = select i1 %c, i64 %arg, i64 64
@@ -94,10 +79,7 @@ define i64 @sub_select_not_zext(i1 %c, i64 %arg) {
 define i64 @sub_select_not_sext(i1 %c, i64 %arg) {
 ; CHECK-LABEL: define i64 @sub_select_not_sext
 ; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i64 [[ARG]], i64 64
-; CHECK-NEXT:    [[NOT_C:%.*]] = xor i1 [[C]], true
-; CHECK-NEXT:    [[EXT_NEG:%.*]] = zext i1 [[NOT_C]] to i64
-; CHECK-NEXT:    [[SUB:%.*]] = add i64 [[SEL]], [[EXT_NEG]]
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C]], i64 [[ARG]], i64 65
 ; CHECK-NEXT:    ret i64 [[SUB]]
 ;
   %sel = select i1 %c, i64 %arg, i64 64
@@ -122,9 +104,7 @@ define i64 @mul_select_zext(i1 %c, i64 %arg) {
 define i64 @mul_select_sext(i1 %c) {
 ; CHECK-LABEL: define i64 @mul_select_sext
 ; CHECK-SAME: (i1 [[C:%.*]]) {
-; CHECK-NEXT:    [[EXT:%.*]] = sext i1 [[C]] to i64
-; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[C]], i64 6, i64 0
-; CHECK-NEXT:    [[MUL:%.*]] = shl i64 [[EXT]], [[TMP1]]
+; CHECK-NEXT:    [[MUL:%.*]] = select i1 [[C]], i64 -64, i64 0
 ; CHECK-NEXT:    ret i64 [[MUL]]
 ;
   %sel = select i1 %c, i64 64, i64 1
@@ -168,10 +148,7 @@ define <2 x i64> @vector_test(i1 %c) {
 define i64 @multiuse_add(i1 %c) {
 ; CHECK-LABEL: define i64 @multiuse_add
 ; CHECK-SAME: (i1 [[C:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i64 64, i64 1
-; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[C]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[SEL]], [[EXT]]
-; CHECK-NEXT:    [[ADD2:%.*]] = add nuw nsw i64 [[ADD]], 1
+; CHECK-NEXT:    [[ADD2:%.*]] = select i1 [[C]], i64 66, i64 2
 ; CHECK-NEXT:    ret i64 [[ADD2]]
 ;
   %sel = select i1 %c, i64 64, i64 1
@@ -184,10 +161,7 @@ define i64 @multiuse_add(i1 %c) {
 define i64 @multiuse_select(i1 %c) {
 ; CHECK-LABEL: define i64 @multiuse_select
 ; CHECK-SAME: (i1 [[C:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i64 64, i64 0
-; CHECK-NEXT:    [[EXT_NEG:%.*]] = sext i1 [[C]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[SEL]], [[EXT_NEG]]
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i64 [[SEL]], [[ADD]]
+; CHECK-NEXT:    [[MUL:%.*]] = select i1 [[C]], i64 4032, i64 0
 ; CHECK-NEXT:    ret i64 [[MUL]]
 ;
   %sel = select i1 %c, i64 64, i64 0
@@ -200,9 +174,8 @@ define i64 @multiuse_select(i1 %c) {
 define i64 @select_non_const_sides(i1 %c, i64 %arg1, i64 %arg2) {
 ; CHECK-LABEL: define i64 @select_non_const_sides
 ; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG1:%.*]], i64 [[ARG2:%.*]]) {
-; CHECK-NEXT:    [[EXT_NEG:%.*]] = sext i1 [[C]] to i64
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i64 [[ARG1]], i64 [[ARG2]]
-; CHECK-NEXT:    [[SUB:%.*]] = add i64 [[SEL]], [[EXT_NEG]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[ARG1]], -1
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C]], i64 [[TMP1]], i64 [[ARG2]]
 ; CHECK-NEXT:    ret i64 [[SUB]]
 ;
   %ext = zext i1 %c to i64
@@ -214,9 +187,9 @@ define i64 @select_non_const_sides(i1 %c, i64 %arg1, i64 %arg2) {
 define i6 @sub_select_sext_op_swapped_non_const_args(i1 %c, i6 %argT, i6 %argF) {
 ; CHECK-LABEL: define i6 @sub_select_sext_op_swapped_non_const_args
 ; CHECK-SAME: (i1 [[C:%.*]], i6 [[ARGT:%.*]], i6 [[ARGF:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i6 [[ARGT]], i6 [[ARGF]]
-; CHECK-NEXT:    [[EXT:%.*]] = sext i1 [[C]] to i6
-; CHECK-NEXT:    [[SUB:%.*]] = sub i6 [[EXT]], [[SEL]]
+; CHECK-DAG:     [[TMP1:%.*]] = xor i6 [[ARGT]], -1
+; CHECK-DAG:     [[TMP2:%.*]] = sub i6 0, [[ARGF]]
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C]], i6 [[TMP1]], i6 [[TMP2]]
 ; CHECK-NEXT:    ret i6 [[SUB]]
 ;
   %sel = select i1 %c, i6 %argT, i6 %argF
@@ -228,9 +201,9 @@ define i6 @sub_select_sext_op_swapped_non_const_args(i1 %c, i6 %argT, i6 %argF)
 define i6 @sub_select_zext_op_swapped_non_const_args(i1 %c, i6 %argT, i6 %argF) {
 ; CHECK-LABEL: define i6 @sub_select_zext_op_swapped_non_const_args
 ; CHECK-SAME: (i1 [[C:%.*]], i6 [[ARGT:%.*]], i6 [[ARGF:%.*]]) {
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i6 [[ARGT]], i6 [[ARGF]]
-; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[C]] to i6
-; CHECK-NEXT:    [[SUB:%.*]] = sub i6 [[EXT]], [[SEL]]
+; CHECK-DAG:     [[TMP1:%.*]] = sub i6 1, [[ARGT]]
+; CHECK-DAG:     [[TMP2:%.*]] = sub i6 0, [[ARGF]]
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C]], i6 [[TMP1]], i6 [[TMP2]]
 ; CHECK-NEXT:    ret i6 [[SUB]]
 ;
   %sel = select i1 %c, i6 %argT, i6 %argF


        


More information about the llvm-commits mailing list