[llvm] [InstCombine][WIP] Fold `(binop VarTwoPossibleVals, C)` to `select` (PR #101731)

via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 2 11:32:40 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: None (goldsteinn)

<details>
<summary>Changes</summary>

That is, fold:
    `(binop X, C)` where `X` is known to equal either `C0` or `C1`, to:
    `(select (icmp eq X, C0), (binop C0, C), (binop C1, C))`

We currently handle the opposite direction for `add`, `or`, `xor`, and
`sub` in `foldSelectICmpAnd`. For the remaining ops, which are not handled
there, this patch deliberately canonicalizes toward the `select` form.


---
Full diff: https://github.com/llvm/llvm-project/pull/101731.diff


10 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+3) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+2) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+16) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+9) 
- (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+70) 
- (modified) llvm/test/Transforms/InstCombine/binop-select.ll (+1-1) 
- (modified) llvm/test/Transforms/InstCombine/pr72433.ll (+1-2) 
- (modified) llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-pr49778.ll (+1-3) 
- (modified) llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll (+17-12) 
- (modified) llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll (+2-3) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 2db05c669145b..12eae825c85cf 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2836,6 +2836,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
                                       /*SimplifyOnly*/ false, *this))
     return BinaryOperator::CreateAnd(Op0, V);
 
+  if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+    return R;
+
   return nullptr;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 64fbcc80e0edf..19034c8bcda82 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -674,6 +674,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *foldSignBitTest(ICmpInst &I);
   Instruction *foldICmpWithZero(ICmpInst &Cmp);
 
+  Instruction *foldOpWithTwoPossibleValuesToSelect(BinaryOperator &I);
+
   Value *foldMultiplicationOverflowCheck(ICmpInst &Cmp);
 
   Instruction *foldICmpBinOpWithConstant(ICmpInst &Cmp, BinaryOperator *BO,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f4f3644acfe5e..4914c8a7daea7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -598,6 +598,9 @@ Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) {
     return replaceInstUsesWith(I, Fabs);
   }
 
+  if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+    return R;
+
   return nullptr;
 }
 
@@ -1577,6 +1580,9 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
         I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
   }
 
+  if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+    return R;
+
   return nullptr;
 }
 
@@ -1716,6 +1722,10 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
     return SelectInst::Create(Cond, ConstantInt::get(Ty, 1),
                               ConstantInt::getAllOnesValue(Ty));
   }
+
+  if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+    return R;
+
   return nullptr;
 }
 
@@ -2230,6 +2240,9 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) {
     }
   }
 
+  if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+    return R;
+
   return nullptr;
 }
 
@@ -2302,6 +2315,9 @@ Instruction *InstCombinerImpl::visitSRem(BinaryOperator &I) {
     }
   }
 
+  if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+    return R;
+
   return nullptr;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 38f8a41214b68..c17fa66e7deba 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1250,6 +1250,9 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
     }
   }
 
+  if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+    return R;
+
   return nullptr;
 }
 
@@ -1592,6 +1595,9 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
   if (Instruction *Overflow = foldLShrOverflowBit(I))
     return Overflow;
 
+  if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+    return R;
+
   return nullptr;
 }
 
@@ -1795,5 +1801,8 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
     return BinaryOperator::CreateNot(NewAShr);
   }
 
+  if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+    return R;
+
   return nullptr;
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 0fb8b639c97b9..6b32e35527801 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -4872,6 +4872,76 @@ void InstCombinerImpl::tryToSinkInstructionDbgValues(
   }
 }
 
+// Fold (binop C, X) or (binop X, C), where X is known to take only two
+// possible values (exactly one unknown bit), X0 = max and X1 = min, to:
+//  (select (icmp eq X, X0), (binop C, X0), (binop C, X1))
+// The select arms are constant folded with the original operand order, e.g.
+// (binop X0, C) / (binop X1, C) when the constant is the second operand.
+//
+// We don't do this for add/or/sub/xor: `foldSelectICmpAnd` does the inverse
+// for those, so doing both could make the two folds undo each other. Shifts
+// are only folded as (shift C, X), i.e. when X is the shift amount. Only
+// `ConstantInt` operands are matched, so FP binops are never folded here.
+Instruction *
+InstCombinerImpl::foldOpWithTwoPossibleValuesToSelect(BinaryOperator &I) {
+  switch (I.getOpcode()) {
+    // Handled in `foldSelectICmpAnd` where we go the other way.
+  case Instruction::Add:
+  case Instruction::Or:
+  case Instruction::Sub:
+  case Instruction::Xor:
+    return nullptr;
+  default:
+    break;
+  }
+
+  // Shifts only try the constant as operand 0. This depends only on the
+  // opcode, so decide it once instead of on every loop iteration.
+  const unsigned NumOpsToTry = I.isShift() ? 1 : 2;
+  for (unsigned OpIdx = 0; OpIdx < NumOpsToTry; ++OpIdx) {
+    ConstantInt *C;
+    if (!match(I.getOperand(OpIdx), m_ConstantInt(C)))
+      continue;
+
+    Value *Other = I.getOperand(1 - OpIdx);
+    KnownBits OtherKnown = computeKnownBits(Other, /*Depth=*/0, &I);
+
+    // Exactly one unknown bit means `Other` has exactly two possible values.
+    if ((OtherKnown.One | OtherKnown.Zero).popcount() !=
+        OtherKnown.getBitWidth() - 1)
+      continue;
+
+    // The two possible values: unknown bit set (max) and clear (min).
+    Constant *OtherC0 =
+        ConstantInt::get(C->getType(), OtherKnown.getMaxValue());
+    Constant *OtherC1 =
+        ConstantInt::get(C->getType(), OtherKnown.getMinValue());
+    assert(OtherC0 != OtherC1 && "One unknown bit implies distinct values!");
+
+    // Constant fold the binop for each possible value, keeping operand order.
+    Constant *SelTC, *SelFC;
+    if (OpIdx == 0) {
+      SelTC = ConstantFoldBinaryOpOperands(I.getOpcode(), C, OtherC0, DL);
+      SelFC = ConstantFoldBinaryOpOperands(I.getOpcode(), C, OtherC1, DL);
+    } else {
+      SelTC = ConstantFoldBinaryOpOperands(I.getOpcode(), OtherC0, C, DL);
+      SelFC = ConstantFoldBinaryOpOperands(I.getOpcode(), OtherC1, C, DL);
+    }
+    // Give up on this operand order if either arm fails to constant fold.
+    if (!SelTC || !SelFC)
+      continue;
+
+    // Both possible values give the same result: no select/icmp is needed.
+    if (SelTC == SelFC)
+      return replaceInstUsesWith(I, SelTC);
+
+    Value *SelCmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Other, OtherC0);
+    return SelectInst::Create(SelCmp, SelTC, SelFC);
+  }
+
+  return nullptr;
+}
+
 void InstCombinerImpl::tryToSinkInstructionDbgVariableRecords(
     Instruction *I, BasicBlock::iterator InsertPos, BasicBlock *SrcBlock,
     BasicBlock *DestBlock,
diff --git a/llvm/test/Transforms/InstCombine/binop-select.ll b/llvm/test/Transforms/InstCombine/binop-select.ll
index 6cd4132eadd77..e10efe41048ea 100644
--- a/llvm/test/Transforms/InstCombine/binop-select.ll
+++ b/llvm/test/Transforms/InstCombine/binop-select.ll
@@ -395,7 +395,7 @@ define i32 @ashr_sel_op1_use(i1 %b) {
 ; CHECK-LABEL: @ashr_sel_op1_use(
 ; CHECK-NEXT:    [[S:%.*]] = select i1 [[B:%.*]], i32 2, i32 0
 ; CHECK-NEXT:    call void @use(i32 [[S]])
-; CHECK-NEXT:    [[R:%.*]] = ashr i32 -2, [[S]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[B]], i32 -1, i32 -2
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %s = select i1 %b, i32 2, i32 0
diff --git a/llvm/test/Transforms/InstCombine/pr72433.ll b/llvm/test/Transforms/InstCombine/pr72433.ll
index c6e74582a13d3..1633885075e87 100644
--- a/llvm/test/Transforms/InstCombine/pr72433.ll
+++ b/llvm/test/Transforms/InstCombine/pr72433.ll
@@ -6,8 +6,7 @@ define i32 @widget(i32 %arg, i32 %arg1) {
 ; CHECK-SAME: i32 [[ARG:%.*]], i32 [[ARG1:%.*]]) {
 ; CHECK-NEXT:  bb:
 ; CHECK-NEXT:    [[ICMP:%.*]] = icmp ne i32 [[ARG]], 0
-; CHECK-NEXT:    [[TMP0:%.*]] = zext i1 [[ICMP]] to i32
-; CHECK-NEXT:    [[MUL:%.*]] = shl nuw nsw i32 20, [[TMP0]]
+; CHECK-NEXT:    [[MUL:%.*]] = select i1 [[ICMP]], i32 40, i32 20
 ; CHECK-NEXT:    [[XOR:%.*]] = zext i1 [[ICMP]] to i32
 ; CHECK-NEXT:    [[ADD9:%.*]] = or disjoint i32 [[MUL]], [[XOR]]
 ; CHECK-NEXT:    [[TMP1:%.*]] = zext i1 [[ICMP]] to i32
diff --git a/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-pr49778.ll b/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-pr49778.ll
index b06a90e2cd99b..8bffdf5a76fdb 100644
--- a/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-pr49778.ll
+++ b/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-pr49778.ll
@@ -5,9 +5,7 @@
 define i32 @src(i1 %x2) {
 ; CHECK-LABEL: @src(
 ; CHECK-NEXT:    [[X13:%.*]] = zext i1 [[X2:%.*]] to i32
-; CHECK-NEXT:    [[_7:%.*]] = shl nsw i32 -1, [[X13]]
-; CHECK-NEXT:    [[MASK:%.*]] = xor i32 [[_7]], -1
-; CHECK-NEXT:    [[_8:%.*]] = and i32 [[MASK]], [[X13]]
+; CHECK-NEXT:    [[_8:%.*]] = zext i1 [[X2]] to i32
 ; CHECK-NEXT:    [[_9:%.*]] = shl nuw nsw i32 [[_8]], [[X13]]
 ; CHECK-NEXT:    ret i32 [[_9]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll b/llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll
index b992460d0be69..b2c4a5505f004 100644
--- a/llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll
+++ b/llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll
@@ -117,9 +117,10 @@ define i64 @narrow_source_matching_signbits(i32 %x) {
 
 define i64 @narrow_source_not_matching_signbits(i32 %x) {
 ; CHECK-LABEL: @narrow_source_not_matching_signbits(
-; CHECK-NEXT:    [[M:%.*]] = and i32 [[X:%.*]], 8
-; CHECK-NEXT:    [[A:%.*]] = shl nsw i32 -1, [[M]]
-; CHECK-NEXT:    [[B:%.*]] = trunc i32 [[A]] to i8
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i8 [[TMP1]], 3
+; CHECK-NEXT:    [[TMP3:%.*]] = and i8 [[TMP2]], 1
+; CHECK-NEXT:    [[B:%.*]] = add nsw i8 [[TMP3]], -1
 ; CHECK-NEXT:    [[C:%.*]] = sext i8 [[B]] to i64
 ; CHECK-NEXT:    ret i64 [[C]]
 ;
@@ -148,9 +149,10 @@ define i24 @wide_source_matching_signbits(i32 %x) {
 
 define i24 @wide_source_not_matching_signbits(i32 %x) {
 ; CHECK-LABEL: @wide_source_not_matching_signbits(
-; CHECK-NEXT:    [[M2:%.*]] = and i32 [[X:%.*]], 8
-; CHECK-NEXT:    [[A:%.*]] = shl nsw i32 -1, [[M2]]
-; CHECK-NEXT:    [[B:%.*]] = trunc i32 [[A]] to i8
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i8 [[TMP1]], 3
+; CHECK-NEXT:    [[TMP3:%.*]] = and i8 [[TMP2]], 1
+; CHECK-NEXT:    [[B:%.*]] = add nsw i8 [[TMP3]], -1
 ; CHECK-NEXT:    [[C:%.*]] = sext i8 [[B]] to i24
 ; CHECK-NEXT:    ret i24 [[C]]
 ;
@@ -178,9 +180,11 @@ define i32 @same_source_matching_signbits(i32 %x) {
 
 define i32 @same_source_not_matching_signbits(i32 %x) {
 ; CHECK-LABEL: @same_source_not_matching_signbits(
-; CHECK-NEXT:    [[M2:%.*]] = and i32 [[X:%.*]], 8
-; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 -16777216, [[M2]]
-; CHECK-NEXT:    [[C:%.*]] = ashr exact i32 [[TMP1]], 24
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i8 [[TMP1]], 3
+; CHECK-NEXT:    [[TMP3:%.*]] = and i8 [[TMP2]], 1
+; CHECK-NEXT:    [[B:%.*]] = add nsw i8 [[TMP3]], -1
+; CHECK-NEXT:    [[C:%.*]] = sext i8 [[B]] to i32
 ; CHECK-NEXT:    ret i32 [[C]]
 ;
   %m2 = and i32 %x, 8
@@ -208,9 +212,10 @@ define i32 @same_source_matching_signbits_extra_use(i32 %x) {
 
 define i32 @same_source_not_matching_signbits_extra_use(i32 %x) {
 ; CHECK-LABEL: @same_source_not_matching_signbits_extra_use(
-; CHECK-NEXT:    [[M2:%.*]] = and i32 [[X:%.*]], 8
-; CHECK-NEXT:    [[A:%.*]] = shl nsw i32 -1, [[M2]]
-; CHECK-NEXT:    [[B:%.*]] = trunc i32 [[A]] to i8
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i8 [[TMP1]], 3
+; CHECK-NEXT:    [[TMP3:%.*]] = and i8 [[TMP2]], 1
+; CHECK-NEXT:    [[B:%.*]] = add nsw i8 [[TMP3]], -1
 ; CHECK-NEXT:    call void @use8(i8 [[B]])
 ; CHECK-NEXT:    [[C:%.*]] = sext i8 [[B]] to i32
 ; CHECK-NEXT:    ret i32 [[C]]
diff --git a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
index 00a19e4962e6c..954e3e4ae47e2 100644
--- a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
+++ b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
@@ -675,9 +675,8 @@ define i1 @constantexpr() {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = load i16, ptr @f.a, align 2
 ; CHECK-NEXT:    [[SHR:%.*]] = lshr i16 [[TMP0]], 1
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i16 ptrtoint (ptr @f.a to i16), 1
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i1 [[CMP]] to i16
-; CHECK-NEXT:    [[TMP1:%.*]] = shl nuw nsw i16 1, [[ZEXT]]
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i16 ptrtoint (ptr @f.a to i16), 1
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[CMP_NOT]], i16 1, i16 2
 ; CHECK-NEXT:    [[TMP2:%.*]] = and i16 [[SHR]], [[TMP1]]
 ; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp ne i16 [[TMP2]], 0
 ; CHECK-NEXT:    ret i1 [[TOBOOL]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/101731


More information about the llvm-commits mailing list