[llvm] [Reassociate] Preserve NUW flags after expr tree rewriting (PR #72360)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 14 23:43:46 PST 2023
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Yingwei Zheng (dtcxzyw)
<details>
<summary>Changes</summary>
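Alive2: https://alive2.llvm.org/ce/z/38KiC_
This missed optimization was discovered with the help of https://github.com/AliveToolkit/alive2/pull/962.
Rationale: when every `add` in the linearized expression tree carries `nuw`, each intermediate sum in the rewritten tree adds up a subset of the same unsigned operands. Each such sum is therefore bounded above by the final sum, which is known not to wrap, so `nuw` can be kept on the rewritten `add`s. The same bound does not hold for `mul` unless every operand is known to be non-zero: `(0 * b) * c` never wraps, but the reassociated `(b * c) * 0` may wrap in `b * c`.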
---
Full diff: https://github.com/llvm/llvm-project/pull/72360.diff
3 Files Affected:
- (modified) llvm/include/llvm/Transforms/Scalar/Reassociate.h (+2-1)
- (modified) llvm/lib/Transforms/Scalar/Reassociate.cpp (+19-9)
- (modified) llvm/test/Transforms/Reassociate/local-cse.ll (+20-20)
``````````diff
diff --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
index 28794d27325adec..7e47f8ae5d81e96 100644
--- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h
+++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
@@ -102,7 +102,8 @@ class ReassociatePass : public PassInfoMixin<ReassociatePass> {
void canonicalizeOperands(Instruction *I);
void ReassociateExpression(BinaryOperator *I);
void RewriteExprTree(BinaryOperator *I,
- SmallVectorImpl<reassociate::ValueEntry> &Ops);
+ SmallVectorImpl<reassociate::ValueEntry> &Ops,
+ bool HasNUW);
Value *OptimizeExpression(BinaryOperator *I,
SmallVectorImpl<reassociate::ValueEntry> &Ops);
Value *OptimizeAdd(Instruction *I,
diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index 9c4a344d4295f8a..07e8f1b24d8c759 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -466,7 +466,8 @@ using RepeatedValue = std::pair<Value*, APInt>;
/// type and thus make the expression bigger.
static bool LinearizeExprTree(Instruction *I,
SmallVectorImpl<RepeatedValue> &Ops,
- ReassociatePass::OrderedSet &ToRedo) {
+ ReassociatePass::OrderedSet &ToRedo,
+ bool &HasNUW) {
assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) &&
"Expected a UnaryOperator or BinaryOperator!");
LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
@@ -515,6 +516,9 @@ static bool LinearizeExprTree(Instruction *I,
std::pair<Instruction*, APInt> P = Worklist.pop_back_val();
I = P.first; // We examine the operands of this binary operator.
+ if (isa<OverflowingBinaryOperator>(I))
+ HasNUW &= I->hasNoUnsignedWrap();
+
for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands.
Value *Op = I->getOperand(OpIdx);
APInt Weight = P.second; // Number of paths to this operand.
@@ -657,7 +661,8 @@ static bool LinearizeExprTree(Instruction *I,
/// Now that the operands for this expression tree are
/// linearized and optimized, emit them in-order.
void ReassociatePass::RewriteExprTree(BinaryOperator *I,
- SmallVectorImpl<ValueEntry> &Ops) {
+ SmallVectorImpl<ValueEntry> &Ops,
+ bool HasNUW) {
assert(Ops.size() > 1 && "Single values should be used directly!");
// Since our optimizations should never increase the number of operations, the
@@ -814,14 +819,17 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
if (ExpressionChangedStart) {
bool ClearFlags = true;
do {
- // Preserve FastMathFlags.
+ // Preserve flags.
if (ClearFlags) {
if (isa<FPMathOperator>(I)) {
FastMathFlags Flags = I->getFastMathFlags();
ExpressionChangedStart->clearSubclassOptionalData();
ExpressionChangedStart->setFastMathFlags(Flags);
- } else
+ } else {
ExpressionChangedStart->clearSubclassOptionalData();
+ if (HasNUW && isa<OverflowingBinaryOperator>(ExpressionChangedStart))
+ ExpressionChangedStart->setHasNoUnsignedWrap();
+ }
}
if (ExpressionChangedStart == ExpressionChangedEnd)
@@ -1171,7 +1179,8 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
return nullptr;
SmallVector<RepeatedValue, 8> Tree;
- MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts);
+ bool HasNUW = true;
+ MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, HasNUW);
SmallVector<ValueEntry, 8> Factors;
Factors.reserve(Tree.size());
for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
@@ -1213,7 +1222,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
if (!FoundFactor) {
// Make sure to restore the operands to the expression tree.
- RewriteExprTree(BO, Factors);
+ RewriteExprTree(BO, Factors, HasNUW);
return nullptr;
}
@@ -1225,7 +1234,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
RedoInsts.insert(BO);
V = Factors[0].Op;
} else {
- RewriteExprTree(BO, Factors);
+ RewriteExprTree(BO, Factors, HasNUW);
V = BO;
}
@@ -2349,7 +2358,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
// First, walk the expression tree, linearizing the tree, collecting the
// operand information.
SmallVector<RepeatedValue, 8> Tree;
- MadeChange |= LinearizeExprTree(I, Tree, RedoInsts);
+ bool HasNUW = true;
+ MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, HasNUW);
SmallVector<ValueEntry, 8> Ops;
Ops.reserve(Tree.size());
for (const RepeatedValue &E : Tree)
@@ -2542,7 +2552,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
dbgs() << '\n');
// Now that we ordered and optimized the expressions, splat them back into
// the expression tree, removing any unneeded nodes.
- RewriteExprTree(I, Ops);
+ RewriteExprTree(I, Ops, HasNUW);
}
void
diff --git a/llvm/test/Transforms/Reassociate/local-cse.ll b/llvm/test/Transforms/Reassociate/local-cse.ll
index 1609cb1b36fd93e..4d0467e263f5538 100644
--- a/llvm/test/Transforms/Reassociate/local-cse.ll
+++ b/llvm/test/Transforms/Reassociate/local-cse.ll
@@ -26,16 +26,16 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
; LOCAL_CSE-LABEL: define void @chain_spanning_several_blocks
; LOCAL_CSE-SAME: (i64 [[INV1:%.*]], i64 [[INV2:%.*]], i64 [[INV3:%.*]], i64 [[INV4:%.*]], i64 [[INV5:%.*]]) {
; LOCAL_CSE-NEXT: bb1:
-; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add i64 [[INV2]], [[INV1]]
+; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[INV2]], [[INV1]]
; LOCAL_CSE-NEXT: br label [[BB2:%.*]]
; LOCAL_CSE: bb2:
; LOCAL_CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV4]]
-; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add i64 [[CHAIN_A1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add i64 [[CHAIN_A0]], [[INV5]]
-; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add i64 [[CHAIN_B1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add i64 [[INV3]], [[INV1]]
-; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add i64 [[CHAIN_C0]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4]]
+; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5]]
+; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw i64 [[INV3]], [[INV1]]
+; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[VAL_BB2]]
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
@@ -47,11 +47,11 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
; CSE-NEXT: br label [[BB2:%.*]]
; CSE: bb2:
; CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
-; CSE-NEXT: [[CHAIN_A0:%.*]] = add i64 [[VAL_BB2]], [[INV1]]
-; CSE-NEXT: [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV2]]
+; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1]]
+; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2]]
; CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4]]
; CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5]]
-; CSE-NEXT: [[CHAIN_C1:%.*]] = add i64 [[CHAIN_A0]], [[INV3]]
+; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3]]
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
@@ -90,19 +90,19 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
; LOCAL_CSE-NEXT: br label [[BB1:%.*]]
; LOCAL_CSE: bb1:
; LOCAL_CSE-NEXT: [[INV1_BB1:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add i64 [[INV1_BB1]], [[INV2_BB0]]
+; LOCAL_CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[INV1_BB1]], [[INV2_BB0]]
; LOCAL_CSE-NEXT: br label [[BB2:%.*]]
; LOCAL_CSE: bb2:
; LOCAL_CSE-NEXT: [[INV3_BB2:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[INV4_BB2:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[INV5_BB2:%.*]] = call i64 @get_val()
; LOCAL_CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV4_BB2]]
-; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add i64 [[CHAIN_A1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add i64 [[CHAIN_A0]], [[INV5_BB2]]
-; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add i64 [[CHAIN_B1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add i64 [[VAL_BB2]], [[INV1_BB1]]
-; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add i64 [[CHAIN_C0]], [[INV3_BB2]]
+; LOCAL_CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4_BB2]]
+; LOCAL_CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT: [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5_BB2]]
+; LOCAL_CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT: [[CHAIN_C0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
+; LOCAL_CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[INV3_BB2]]
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
; LOCAL_CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
@@ -120,11 +120,11 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
; CSE-NEXT: [[INV4_BB2:%.*]] = call i64 @get_val()
; CSE-NEXT: [[INV5_BB2:%.*]] = call i64 @get_val()
; CSE-NEXT: [[VAL_BB2:%.*]] = call i64 @get_val()
-; CSE-NEXT: [[CHAIN_A0:%.*]] = add i64 [[VAL_BB2]], [[INV1_BB1]]
-; CSE-NEXT: [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV2_BB0]]
+; CSE-NEXT: [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
+; CSE-NEXT: [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2_BB0]]
; CSE-NEXT: [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4_BB2]]
; CSE-NEXT: [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5_BB2]]
-; CSE-NEXT: [[CHAIN_C1:%.*]] = add i64 [[CHAIN_A0]], [[INV3_BB2]]
+; CSE-NEXT: [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3_BB2]]
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_A2]])
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_B2]])
; CSE-NEXT: call void @keep_alive(i64 [[CHAIN_C1]])
``````````
</details>
https://github.com/llvm/llvm-project/pull/72360
More information about the llvm-commits mailing list