[llvm] [InstCombine][WIP] Fold `(binop VarTwoPossibleVals, C)` to `select` (PR #101731)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 2 11:32:40 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: None (goldsteinn)
<details>
<summary>Changes</summary>
I.e., fold
`(binop X, C)`, where `X` is known to be either `C0` or `C1`, to:
`(select (icmp eq X, C0), (binop C0, C), (binop C1, C))`
We currently handle the opposite direction (select -> binop) for `add`, `or`,
`xor`, and `sub` in `foldSelectICmpAnd`. For the ops not handled there, this
patch deliberately canonicalizes toward the `select` form.
---
Full diff: https://github.com/llvm/llvm-project/pull/101731.diff
10 Files Affected:
- (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+3)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+2)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+16)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+9)
- (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+70)
- (modified) llvm/test/Transforms/InstCombine/binop-select.ll (+1-1)
- (modified) llvm/test/Transforms/InstCombine/pr72433.ll (+1-2)
- (modified) llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-pr49778.ll (+1-3)
- (modified) llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll (+17-12)
- (modified) llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll (+2-3)
``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 2db05c669145b..12eae825c85cf 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2836,6 +2836,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
/*SimplifyOnly*/ false, *this))
return BinaryOperator::CreateAnd(Op0, V);
+ if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+ return R;
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 64fbcc80e0edf..19034c8bcda82 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -674,6 +674,8 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
Instruction *foldSignBitTest(ICmpInst &I);
Instruction *foldICmpWithZero(ICmpInst &Cmp);
+ Instruction *foldOpWithTwoPossibleValuesToSelect(BinaryOperator &I);
+
Value *foldMultiplicationOverflowCheck(ICmpInst &Cmp);
Instruction *foldICmpBinOpWithConstant(ICmpInst &Cmp, BinaryOperator *BO,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index f4f3644acfe5e..4914c8a7daea7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -598,6 +598,9 @@ Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) {
return replaceInstUsesWith(I, Fabs);
}
+ if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+ return R;
+
return nullptr;
}
@@ -1577,6 +1580,9 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
}
+ if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+ return R;
+
return nullptr;
}
@@ -1716,6 +1722,10 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
return SelectInst::Create(Cond, ConstantInt::get(Ty, 1),
ConstantInt::getAllOnesValue(Ty));
}
+
+ if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+ return R;
+
return nullptr;
}
@@ -2230,6 +2240,9 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) {
}
}
+ if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+ return R;
+
return nullptr;
}
@@ -2302,6 +2315,9 @@ Instruction *InstCombinerImpl::visitSRem(BinaryOperator &I) {
}
}
+ if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+ return R;
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 38f8a41214b68..c17fa66e7deba 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1250,6 +1250,9 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
}
}
+ if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+ return R;
+
return nullptr;
}
@@ -1592,6 +1595,9 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
if (Instruction *Overflow = foldLShrOverflowBit(I))
return Overflow;
+ if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+ return R;
+
return nullptr;
}
@@ -1795,5 +1801,8 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
return BinaryOperator::CreateNot(NewAShr);
}
+ if (Instruction *R = foldOpWithTwoPossibleValuesToSelect(I))
+ return R;
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 0fb8b639c97b9..6b32e35527801 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -4872,6 +4872,76 @@ void InstCombinerImpl::tryToSinkInstructionDbgValues(
}
}
+// Fold a binop whose operands are a ConstantInt C and a value X that can only
+// take two values, C0 or C1, into a select between the two folded results:
+//   (binop X, C) --> (select (icmp eq X, C0), (binop C0, C), (binop C1, C))
+//   (binop C, X) --> (select (icmp eq X, C0), (binop C, C0), (binop C, C1))
+// X counts as two-valued when computeKnownBits leaves exactly one bit unknown;
+// C0 is the value with that bit set (max) and C1 the value with it clear (min).
+// For shifts only the second form (constant shifted by X) is attempted.
+//
+// The constants are folded with the bare opcode, so poison-generating flags on
+// I (nuw/nsw/exact) are not applied; the select is at least as defined as I.
+//
+// add/or/sub/xor are skipped because `foldSelectICmpAnd` performs the inverse
+// transform for them.
+//
+// Returns a new (not yet inserted) select, with its icmp created through
+// Builder, or nullptr if the fold does not apply.
+Instruction *
+InstCombinerImpl::foldOpWithTwoPossibleValuesToSelect(BinaryOperator &I) {
+  // Handled in `foldSelectICmpAnd`, which goes the other way; folding here
+  // would fight that transform.
+  switch (I.getOpcode()) {
+  case Instruction::Add:
+  case Instruction::Or:
+  case Instruction::Sub:
+  case Instruction::Xor:
+    return nullptr;
+  default:
+    break;
+  }
+
+  // The reasoning below relies on integer known bits; bail out early for
+  // anything else (e.g. floating-point binops).
+  if (!I.getType()->isIntOrIntVectorTy())
+    return nullptr;
+
+  // Shifts: only try a constant shifted by a two-valued amount (OpIdx == 0);
+  // a two-valued value shifted by a constant amount is left alone.
+  const unsigned NumOpsToTry = I.isShift() ? 1 : 2;
+  for (unsigned OpIdx = 0; OpIdx != NumOpsToTry; ++OpIdx) {
+    ConstantInt *C;
+    if (!match(I.getOperand(OpIdx), m_ConstantInt(C)))
+      continue;
+
+    Value *Other = I.getOperand(1 - OpIdx);
+    KnownBits OtherKnown = computeKnownBits(Other, /*Depth=*/0, &I);
+
+    // Other has exactly two possible values iff exactly one bit is unknown.
+    if ((OtherKnown.Zero | OtherKnown.One).popcount() !=
+        OtherKnown.getBitWidth() - 1)
+      continue;
+
+    // The two candidates: unknown bit set (max) and unknown bit clear (min).
+    Constant *OtherC0 =
+        ConstantInt::get(C->getType(), OtherKnown.getMaxValue());
+    Constant *OtherC1 =
+        ConstantInt::get(C->getType(), OtherKnown.getMinValue());
+    assert(OtherC0 != OtherC1 &&
+           "One unknown bit must yield two distinct values");
+
+    // Fold the binop for each candidate, preserving the original operand
+    // order (it matters for non-commutative ops such as div/rem/shift).
+    auto FoldWith = [&](Constant *OtherC) -> Constant * {
+      return OpIdx == 0
+                 ? ConstantFoldBinaryOpOperands(I.getOpcode(), C, OtherC, DL)
+                 : ConstantFoldBinaryOpOperands(I.getOpcode(), OtherC, C, DL);
+    };
+    Constant *SelTC = FoldWith(OtherC0);
+    Constant *SelFC = FoldWith(OtherC1);
+    if (!SelTC || !SelFC)
+      continue;
+
+    Value *SelCmp = Builder.CreateICmpEQ(Other, OtherC0);
+    return SelectInst::Create(SelCmp, SelTC, SelFC);
+  }
+
+  return nullptr;
+}
+
void InstCombinerImpl::tryToSinkInstructionDbgVariableRecords(
Instruction *I, BasicBlock::iterator InsertPos, BasicBlock *SrcBlock,
BasicBlock *DestBlock,
diff --git a/llvm/test/Transforms/InstCombine/binop-select.ll b/llvm/test/Transforms/InstCombine/binop-select.ll
index 6cd4132eadd77..e10efe41048ea 100644
--- a/llvm/test/Transforms/InstCombine/binop-select.ll
+++ b/llvm/test/Transforms/InstCombine/binop-select.ll
@@ -395,7 +395,7 @@ define i32 @ashr_sel_op1_use(i1 %b) {
; CHECK-LABEL: @ashr_sel_op1_use(
; CHECK-NEXT: [[S:%.*]] = select i1 [[B:%.*]], i32 2, i32 0
; CHECK-NEXT: call void @use(i32 [[S]])
-; CHECK-NEXT: [[R:%.*]] = ashr i32 -2, [[S]]
+; CHECK-NEXT: [[R:%.*]] = select i1 [[B]], i32 -1, i32 -2
; CHECK-NEXT: ret i32 [[R]]
;
%s = select i1 %b, i32 2, i32 0
diff --git a/llvm/test/Transforms/InstCombine/pr72433.ll b/llvm/test/Transforms/InstCombine/pr72433.ll
index c6e74582a13d3..1633885075e87 100644
--- a/llvm/test/Transforms/InstCombine/pr72433.ll
+++ b/llvm/test/Transforms/InstCombine/pr72433.ll
@@ -6,8 +6,7 @@ define i32 @widget(i32 %arg, i32 %arg1) {
; CHECK-SAME: i32 [[ARG:%.*]], i32 [[ARG1:%.*]]) {
; CHECK-NEXT: bb:
; CHECK-NEXT: [[ICMP:%.*]] = icmp ne i32 [[ARG]], 0
-; CHECK-NEXT: [[TMP0:%.*]] = zext i1 [[ICMP]] to i32
-; CHECK-NEXT: [[MUL:%.*]] = shl nuw nsw i32 20, [[TMP0]]
+; CHECK-NEXT: [[MUL:%.*]] = select i1 [[ICMP]], i32 40, i32 20
; CHECK-NEXT: [[XOR:%.*]] = zext i1 [[ICMP]] to i32
; CHECK-NEXT: [[ADD9:%.*]] = or disjoint i32 [[MUL]], [[XOR]]
; CHECK-NEXT: [[TMP1:%.*]] = zext i1 [[ICMP]] to i32
diff --git a/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-pr49778.ll b/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-pr49778.ll
index b06a90e2cd99b..8bffdf5a76fdb 100644
--- a/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-pr49778.ll
+++ b/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-pr49778.ll
@@ -5,9 +5,7 @@
define i32 @src(i1 %x2) {
; CHECK-LABEL: @src(
; CHECK-NEXT: [[X13:%.*]] = zext i1 [[X2:%.*]] to i32
-; CHECK-NEXT: [[_7:%.*]] = shl nsw i32 -1, [[X13]]
-; CHECK-NEXT: [[MASK:%.*]] = xor i32 [[_7]], -1
-; CHECK-NEXT: [[_8:%.*]] = and i32 [[MASK]], [[X13]]
+; CHECK-NEXT: [[_8:%.*]] = zext i1 [[X2]] to i32
; CHECK-NEXT: [[_9:%.*]] = shl nuw nsw i32 [[_8]], [[X13]]
; CHECK-NEXT: ret i32 [[_9]]
;
diff --git a/llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll b/llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll
index b992460d0be69..b2c4a5505f004 100644
--- a/llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll
+++ b/llvm/test/Transforms/InstCombine/sext-of-trunc-nsw.ll
@@ -117,9 +117,10 @@ define i64 @narrow_source_matching_signbits(i32 %x) {
define i64 @narrow_source_not_matching_signbits(i32 %x) {
; CHECK-LABEL: @narrow_source_not_matching_signbits(
-; CHECK-NEXT: [[M:%.*]] = and i32 [[X:%.*]], 8
-; CHECK-NEXT: [[A:%.*]] = shl nsw i32 -1, [[M]]
-; CHECK-NEXT: [[B:%.*]] = trunc i32 [[A]] to i8
+; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i8 [[TMP1]], 3
+; CHECK-NEXT: [[TMP3:%.*]] = and i8 [[TMP2]], 1
+; CHECK-NEXT: [[B:%.*]] = add nsw i8 [[TMP3]], -1
; CHECK-NEXT: [[C:%.*]] = sext i8 [[B]] to i64
; CHECK-NEXT: ret i64 [[C]]
;
@@ -148,9 +149,10 @@ define i24 @wide_source_matching_signbits(i32 %x) {
define i24 @wide_source_not_matching_signbits(i32 %x) {
; CHECK-LABEL: @wide_source_not_matching_signbits(
-; CHECK-NEXT: [[M2:%.*]] = and i32 [[X:%.*]], 8
-; CHECK-NEXT: [[A:%.*]] = shl nsw i32 -1, [[M2]]
-; CHECK-NEXT: [[B:%.*]] = trunc i32 [[A]] to i8
+; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i8 [[TMP1]], 3
+; CHECK-NEXT: [[TMP3:%.*]] = and i8 [[TMP2]], 1
+; CHECK-NEXT: [[B:%.*]] = add nsw i8 [[TMP3]], -1
; CHECK-NEXT: [[C:%.*]] = sext i8 [[B]] to i24
; CHECK-NEXT: ret i24 [[C]]
;
@@ -178,9 +180,11 @@ define i32 @same_source_matching_signbits(i32 %x) {
define i32 @same_source_not_matching_signbits(i32 %x) {
; CHECK-LABEL: @same_source_not_matching_signbits(
-; CHECK-NEXT: [[M2:%.*]] = and i32 [[X:%.*]], 8
-; CHECK-NEXT: [[TMP1:%.*]] = shl i32 -16777216, [[M2]]
-; CHECK-NEXT: [[C:%.*]] = ashr exact i32 [[TMP1]], 24
+; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i8 [[TMP1]], 3
+; CHECK-NEXT: [[TMP3:%.*]] = and i8 [[TMP2]], 1
+; CHECK-NEXT: [[B:%.*]] = add nsw i8 [[TMP3]], -1
+; CHECK-NEXT: [[C:%.*]] = sext i8 [[B]] to i32
; CHECK-NEXT: ret i32 [[C]]
;
%m2 = and i32 %x, 8
@@ -208,9 +212,10 @@ define i32 @same_source_matching_signbits_extra_use(i32 %x) {
define i32 @same_source_not_matching_signbits_extra_use(i32 %x) {
; CHECK-LABEL: @same_source_not_matching_signbits_extra_use(
-; CHECK-NEXT: [[M2:%.*]] = and i32 [[X:%.*]], 8
-; CHECK-NEXT: [[A:%.*]] = shl nsw i32 -1, [[M2]]
-; CHECK-NEXT: [[B:%.*]] = trunc i32 [[A]] to i8
+; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[X:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = lshr i8 [[TMP1]], 3
+; CHECK-NEXT: [[TMP3:%.*]] = and i8 [[TMP2]], 1
+; CHECK-NEXT: [[B:%.*]] = add nsw i8 [[TMP3]], -1
; CHECK-NEXT: call void @use8(i8 [[B]])
; CHECK-NEXT: [[C:%.*]] = sext i8 [[B]] to i32
; CHECK-NEXT: ret i32 [[C]]
diff --git a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
index 00a19e4962e6c..954e3e4ae47e2 100644
--- a/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
+++ b/llvm/test/Transforms/InstCombine/shift-amount-reassociation-in-bittest.ll
@@ -675,9 +675,8 @@ define i1 @constantexpr() {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = load i16, ptr @f.a, align 2
; CHECK-NEXT: [[SHR:%.*]] = lshr i16 [[TMP0]], 1
-; CHECK-NEXT: [[CMP:%.*]] = icmp ne i16 ptrtoint (ptr @f.a to i16), 1
-; CHECK-NEXT: [[ZEXT:%.*]] = zext i1 [[CMP]] to i16
-; CHECK-NEXT: [[TMP1:%.*]] = shl nuw nsw i16 1, [[ZEXT]]
+; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp eq i16 ptrtoint (ptr @f.a to i16), 1
+; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[CMP_NOT]], i16 1, i16 2
; CHECK-NEXT: [[TMP2:%.*]] = and i16 [[SHR]], [[TMP1]]
; CHECK-NEXT: [[TOBOOL:%.*]] = icmp ne i16 [[TMP2]], 0
; CHECK-NEXT: ret i1 [[TOBOOL]]
``````````
</details>
https://github.com/llvm/llvm-project/pull/101731
More information about the llvm-commits
mailing list