[llvm] goldsteinn/simplify multiple ops (PR #121708)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Jan 5 13:55:47 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: None (goldsteinn)
<details>
<summary>Changes</summary>
- **[InstSimplify] Refactor `simplifyWithOpsReplaced` to allow multiple replacements; NFC**
- **[InstSimplify] Use multi-op replacement when simplifying `select`**
---
Full diff: https://github.com/llvm/llvm-project/pull/121708.diff
2 Files Affected:
- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+75-48)
- (modified) llvm/test/Transforms/InstCombine/select.ll (+10-25)
``````````diff
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 8567a0504f54e1..aba71c5355ad46 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -4275,25 +4275,23 @@ Value *llvm::simplifyFCmpInst(CmpPredicate Predicate, Value *LHS, Value *RHS,
return ::simplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
}
-static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
- const SimplifyQuery &Q,
- bool AllowRefinement,
- SmallVectorImpl<Instruction *> *DropFlags,
- unsigned MaxRecurse) {
+static Value *simplifyWithOpsReplaced(Value *V,
+ ArrayRef<std::pair<Value *, Value *>> Ops,
+ const SimplifyQuery &Q,
+ bool AllowRefinement,
+ SmallVectorImpl<Instruction *> *DropFlags,
+ unsigned MaxRecurse) {
assert((AllowRefinement || !Q.CanUseUndef) &&
"If AllowRefinement=false then CanUseUndef=false");
-
- // Trivial replacement.
- if (V == Op)
- return RepOp;
+ for (const auto &OpAndRepOp : Ops) {
+ // Trivial replacement.
+ if (V == OpAndRepOp.first)
+ return OpAndRepOp.second;
+ }
if (!MaxRecurse--)
return nullptr;
- // We cannot replace a constant, and shouldn't even try.
- if (isa<Constant>(Op))
- return nullptr;
-
auto *I = dyn_cast<Instruction>(V);
if (!I)
return nullptr;
@@ -4303,11 +4301,6 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
if (isa<PHINode>(I))
return nullptr;
- // For vector types, the simplification must hold per-lane, so forbid
- // potentially cross-lane operations like shufflevector.
- if (Op->getType()->isVectorTy() && !isNotCrossLaneOperation(I))
- return nullptr;
-
// Don't fold away llvm.is.constant checks based on assumptions.
if (match(I, m_Intrinsic<Intrinsic::is_constant>()))
return nullptr;
@@ -4316,12 +4309,30 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
if (isa<FreezeInst>(I))
return nullptr;
+ SmallVector<std::pair<Value *, Value *>> ValidReplacements{};
+ for (const auto &OpAndRepOp : Ops) {
+ // We cannot replace a constant, and shouldn't even try.
+ if (isa<Constant>(OpAndRepOp.first))
+ return nullptr;
+
+ // For vector types, the simplification must hold per-lane, so forbid
+ // potentially cross-lane operations like shufflevector.
+ if (OpAndRepOp.first->getType()->isVectorTy() &&
+ !isNotCrossLaneOperation(I))
+ continue;
+ ValidReplacements.emplace_back(OpAndRepOp);
+ }
+
+ if (ValidReplacements.empty())
+ return nullptr;
+
// Replace Op with RepOp in instruction operands.
SmallVector<Value *, 8> NewOps;
bool AnyReplaced = false;
for (Value *InstOp : I->operands()) {
- if (Value *NewInstOp = simplifyWithOpReplaced(
- InstOp, Op, RepOp, Q, AllowRefinement, DropFlags, MaxRecurse)) {
+ if (Value *NewInstOp =
+ simplifyWithOpsReplaced(InstOp, ValidReplacements, Q,
+ AllowRefinement, DropFlags, MaxRecurse)) {
NewOps.push_back(NewInstOp);
AnyReplaced = InstOp != NewInstOp;
} else {
@@ -4372,7 +4383,9 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
// by assumption and this case never wraps, so nowrap flags can be
// ignored.
if ((Opcode == Instruction::Sub || Opcode == Instruction::Xor) &&
- NewOps[0] == RepOp && NewOps[1] == RepOp)
+ any_of(ValidReplacements, [=](const auto &Rep) {
+ return NewOps[0] == NewOps[1] && NewOps[0] == Rep.second;
+ }))
return Constant::getNullValue(I->getType());
// If we are substituting an absorber constant into a binop and extra
@@ -4382,10 +4395,10 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
// (Op == 0) ? 0 : (Op & -Op) --> Op & -Op
// (Op == 0) ? 0 : (Op * (binop Op, C)) --> Op * (binop Op, C)
// (Op == -1) ? -1 : (Op | (binop C, Op) --> Op | (binop C, Op)
- Constant *Absorber =
- ConstantExpr::getBinOpAbsorber(Opcode, I->getType());
+ Constant *Absorber = ConstantExpr::getBinOpAbsorber(Opcode, I->getType());
if ((NewOps[0] == Absorber || NewOps[1] == Absorber) &&
- impliesPoison(BO, Op))
+ any_of(ValidReplacements,
+ [=](const auto &Rep) { return impliesPoison(BO, Rep.first); }))
return Absorber;
}
@@ -4453,6 +4466,15 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
/*AllowNonDeterministic=*/false);
}
+static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+ const SimplifyQuery &Q,
+ bool AllowRefinement,
+ SmallVectorImpl<Instruction *> *DropFlags,
+ unsigned MaxRecurse) {
+ return simplifyWithOpsReplaced(V, {{Op, RepOp}}, Q, AllowRefinement,
+ DropFlags, MaxRecurse);
+}
+
Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
const SimplifyQuery &Q,
bool AllowRefinement,
@@ -4595,17 +4617,16 @@ static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS,
/// Try to simplify a select instruction when its condition operand is an
/// integer equality or floating-point equivalence comparison.
-static Value *simplifySelectWithEquivalence(Value *CmpLHS, Value *CmpRHS,
- Value *TrueVal, Value *FalseVal,
- const SimplifyQuery &Q,
- unsigned MaxRecurse) {
- if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q.getWithoutUndef(),
- /* AllowRefinement */ false,
- /* DropFlags */ nullptr, MaxRecurse) == TrueVal)
+static Value *simplifySelectWithEquivalence(
+ ArrayRef<std::pair<Value *, Value *>> Replacements, Value *TrueVal,
+ Value *FalseVal, const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (simplifyWithOpsReplaced(FalseVal, Replacements, Q.getWithoutUndef(),
+ /* AllowRefinement */ false,
+ /* DropFlags */ nullptr, MaxRecurse) == TrueVal)
return FalseVal;
- if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
- /* AllowRefinement */ true,
- /* DropFlags */ nullptr, MaxRecurse) == FalseVal)
+ if (simplifyWithOpsReplaced(TrueVal, Replacements, Q,
+ /* AllowRefinement */ true,
+ /* DropFlags */ nullptr, MaxRecurse) == FalseVal)
return FalseVal;
return nullptr;
@@ -4699,10 +4720,10 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
// the arms of the select. See if substituting this value into the arm and
// simplifying the result yields the same value as the other arm.
if (Pred == ICmpInst::ICMP_EQ) {
- if (Value *V = simplifySelectWithEquivalence(CmpLHS, CmpRHS, TrueVal,
+ if (Value *V = simplifySelectWithEquivalence({{CmpLHS, CmpRHS}}, TrueVal,
FalseVal, Q, MaxRecurse))
return V;
- if (Value *V = simplifySelectWithEquivalence(CmpRHS, CmpLHS, TrueVal,
+ if (Value *V = simplifySelectWithEquivalence({{CmpRHS, CmpLHS}}, TrueVal,
FalseVal, Q, MaxRecurse))
return V;
@@ -4712,11 +4733,14 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
if (match(CmpLHS, m_Or(m_Value(X), m_Value(Y))) &&
match(CmpRHS, m_Zero())) {
// (X | Y) == 0 implies X == 0 and Y == 0.
- if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal,
- Q, MaxRecurse))
+ if (Value *V = simplifySelectWithEquivalence(
+ {{X, CmpRHS}, {Y, CmpRHS}}, TrueVal, FalseVal, Q, MaxRecurse))
+ return V;
+ if (Value *V = simplifySelectWithEquivalence({{X, CmpRHS}}, TrueVal,
+ FalseVal, Q, MaxRecurse))
return V;
- if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal,
- Q, MaxRecurse))
+ if (Value *V = simplifySelectWithEquivalence({{Y, CmpRHS}}, TrueVal,
+ FalseVal, Q, MaxRecurse))
return V;
}
@@ -4724,11 +4748,14 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
if (match(CmpLHS, m_And(m_Value(X), m_Value(Y))) &&
match(CmpRHS, m_AllOnes())) {
// (X & Y) == -1 implies X == -1 and Y == -1.
- if (Value *V = simplifySelectWithEquivalence(X, CmpRHS, TrueVal, FalseVal,
- Q, MaxRecurse))
+ if (Value *V = simplifySelectWithEquivalence(
+ {{X, CmpRHS}, {Y, CmpRHS}}, TrueVal, FalseVal, Q, MaxRecurse))
return V;
- if (Value *V = simplifySelectWithEquivalence(Y, CmpRHS, TrueVal, FalseVal,
- Q, MaxRecurse))
+ if (Value *V = simplifySelectWithEquivalence({{X, CmpRHS}}, TrueVal,
+ FalseVal, Q, MaxRecurse))
+ return V;
+ if (Value *V = simplifySelectWithEquivalence({{Y, CmpRHS}}, TrueVal,
+ FalseVal, Q, MaxRecurse))
return V;
}
}
@@ -4757,11 +4784,11 @@ static Value *simplifySelectWithFCmp(Value *Cond, Value *T, Value *F,
// This transforms is safe if at least one operand is known to not be zero.
// Otherwise, the select can change the sign of a zero operand.
if (IsEquiv) {
- if (Value *V =
- simplifySelectWithEquivalence(CmpLHS, CmpRHS, T, F, Q, MaxRecurse))
+ if (Value *V = simplifySelectWithEquivalence({{CmpLHS, CmpRHS}}, T, F, Q,
+ MaxRecurse))
return V;
- if (Value *V =
- simplifySelectWithEquivalence(CmpRHS, CmpLHS, T, F, Q, MaxRecurse))
+ if (Value *V = simplifySelectWithEquivalence({{CmpRHS, CmpLHS}}, T, F, Q,
+ MaxRecurse))
return V;
}
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index 0168a804239a89..85637b27d8adcf 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -3937,11 +3937,8 @@ entry:
define i32 @src_or_eq_0_and_xor(i32 %x, i32 %y) {
; CHECK-LABEL: @src_or_eq_0_and_xor(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y]], [[X]]
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[XOR]]
-; CHECK-NEXT: ret i32 [[COND]]
+; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT: ret i32 [[XOR]]
;
entry:
%or = or i32 %y, %x
@@ -3956,11 +3953,8 @@ entry:
define i32 @src_or_eq_0_xor_and(i32 %x, i32 %y) {
; CHECK-LABEL: @src_or_eq_0_xor_and(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[OR:%.*]] = or i32 [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y]], [[X]]
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP]], i32 0, i32 [[AND]]
-; CHECK-NEXT: ret i32 [[COND]]
+; CHECK-NEXT: [[AND:%.*]] = and i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT: ret i32 [[AND]]
;
entry:
%or = or i32 %y, %x
@@ -4438,11 +4432,8 @@ define i32 @src_no_trans_select_and_eq0_xor_and(i32 %x, i32 %y) {
define i32 @src_no_trans_select_or_eq0_or_and(i32 %x, i32 %y) {
; CHECK-LABEL: @src_no_trans_select_or_eq0_or_and(
-; CHECK-NEXT: [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT: [[OR0:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT: [[AND:%.*]] = and i32 [[X]], [[Y]]
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[OR0]], i32 0, i32 [[AND]]
-; CHECK-NEXT: ret i32 [[COND]]
+; CHECK-NEXT: [[AND:%.*]] = and i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: ret i32 [[AND]]
;
%or = or i32 %x, %y
%or0 = icmp eq i32 %or, 0
@@ -4453,11 +4444,8 @@ define i32 @src_no_trans_select_or_eq0_or_and(i32 %x, i32 %y) {
define i32 @src_no_trans_select_or_eq0_or_xor(i32 %x, i32 %y) {
; CHECK-LABEL: @src_no_trans_select_or_eq0_or_xor(
-; CHECK-NEXT: [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT: [[OR0:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X]], [[Y]]
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[OR0]], i32 0, i32 [[XOR]]
-; CHECK-NEXT: ret i32 [[COND]]
+; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: ret i32 [[XOR]]
;
%or = or i32 %x, %y
%or0 = icmp eq i32 %or, 0
@@ -4492,11 +4480,8 @@ define i32 @src_no_trans_select_or_eq0_xor_or(i32 %x, i32 %y) {
define i32 @src_no_trans_select_and_ne0_xor_or(i32 %x, i32 %y) {
; CHECK-LABEL: @src_no_trans_select_and_ne0_xor_or(
-; CHECK-NEXT: [[OR:%.*]] = or i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT: [[OR0_NOT:%.*]] = icmp eq i32 [[OR]], 0
-; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X]], [[Y]]
-; CHECK-NEXT: [[COND:%.*]] = select i1 [[OR0_NOT]], i32 0, i32 [[XOR]]
-; CHECK-NEXT: ret i32 [[COND]]
+; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: ret i32 [[XOR]]
;
%or = or i32 %x, %y
%or0 = icmp ne i32 %or, 0
``````````
</details>
https://github.com/llvm/llvm-project/pull/121708
More information about the llvm-commits
mailing list