[llvm] [Reassociate] Preserve NSW flags after expr tree rewriting (PR #93105)

via llvm-commits llvm-commits at lists.llvm.org
Wed May 22 15:49:43 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Akshay Deodhar (akshayrdeodhar)

<details>
<summary>Changes</summary>

We can guarantee NSW on all adds in a reassociated add expression tree when:

- All adds in the add operator tree are NSW, AND either
  - All add operands are guaranteed to be nonnegative, OR
  - All adds are also NUW

- Alive2:
  - Nonnegative operands:
    - 3 operands: https://alive2.llvm.org/ce/z/G4XW6Q
    - 4 operands: https://alive2.llvm.org/ce/z/FWcZ6D
  - NUW NSW adds: https://alive2.llvm.org/ce/z/vRUxeC

---
Full diff: https://github.com/llvm/llvm-project/pull/93105.diff


4 Files Affected:

- (modified) llvm/include/llvm/Transforms/Scalar/Reassociate.h (+8-1) 
- (modified) llvm/lib/Transforms/Scalar/Reassociate.cpp (+23-13) 
- (modified) llvm/test/Transforms/Reassociate/local-cse.ll (+20-20) 
- (added) llvm/test/Transforms/Reassociate/reassoc-add-nsw.ll (+61) 


``````````diff
diff --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
index f3a2e0f4380eb..48a8f4083eed2 100644
--- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h
+++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
@@ -63,6 +63,13 @@ struct Factor {
   Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {}
 };
 
+struct OverflowTracking {
+  bool HasNUW;
+  bool HasNSW;
+  bool AllKnownNonNegative;
+  OverflowTracking(void) : HasNUW(true), HasNSW(true), AllKnownNonNegative(true) {}
+};
+
 class XorOpnd;
 
 } // end namespace reassociate
@@ -103,7 +110,7 @@ class ReassociatePass : public PassInfoMixin<ReassociatePass> {
   void ReassociateExpression(BinaryOperator *I);
   void RewriteExprTree(BinaryOperator *I,
                        SmallVectorImpl<reassociate::ValueEntry> &Ops,
-                       bool HasNUW);
+                       reassociate::OverflowTracking Flags);
   Value *OptimizeExpression(BinaryOperator *I,
                             SmallVectorImpl<reassociate::ValueEntry> &Ops);
   Value *OptimizeAdd(Instruction *I,
diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index d91320863e241..f7a476eaaa266 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -471,7 +471,7 @@ using RepeatedValue = std::pair<Value*, APInt>;
 static bool LinearizeExprTree(Instruction *I,
                               SmallVectorImpl<RepeatedValue> &Ops,
                               ReassociatePass::OrderedSet &ToRedo,
-                              bool &HasNUW) {
+                              reassociate::OverflowTracking &Flags) {
   assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) &&
          "Expected a UnaryOperator or BinaryOperator!");
   LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
@@ -512,6 +512,7 @@ static bool LinearizeExprTree(Instruction *I,
   using LeafMap = DenseMap<Value *, APInt>;
   LeafMap Leaves; // Leaf -> Total weight so far.
   SmallVector<Value *, 8> LeafOrder; // Ensure deterministic leaf output order.
+  const DataLayout DL = I->getModule()->getDataLayout();
 
 #ifndef NDEBUG
   SmallPtrSet<Value *, 8> Visited; // For checking the iteration scheme.
@@ -520,8 +521,10 @@ static bool LinearizeExprTree(Instruction *I,
     std::pair<Instruction*, APInt> P = Worklist.pop_back_val();
     I = P.first; // We examine the operands of this binary operator.
 
-    if (isa<OverflowingBinaryOperator>(I))
-      HasNUW &= I->hasNoUnsignedWrap();
+    if (isa<OverflowingBinaryOperator>(I)) {
+      Flags.HasNUW &= I->hasNoUnsignedWrap();
+      Flags.HasNSW &= I->hasNoSignedWrap();
+    }
 
     for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands.
       Value *Op = I->getOperand(OpIdx);
@@ -648,6 +651,9 @@ static bool LinearizeExprTree(Instruction *I,
     // Ensure the leaf is only output once.
     It->second = 0;
     Ops.push_back(std::make_pair(V, Weight));
+    if (Opcode == Instruction::Add && Flags.AllKnownNonNegative) {
+      Flags.AllKnownNonNegative &= isKnownNonNegative(V, SimplifyQuery(DL));
+    }
   }
 
   // For nilpotent operations or addition there may be no operands, for example
@@ -666,7 +672,7 @@ static bool LinearizeExprTree(Instruction *I,
 /// linearized and optimized, emit them in-order.
 void ReassociatePass::RewriteExprTree(BinaryOperator *I,
                                       SmallVectorImpl<ValueEntry> &Ops,
-                                      bool HasNUW) {
+                                      OverflowTracking Flags) {
   assert(Ops.size() > 1 && "Single values should be used directly!");
 
   // Since our optimizations should never increase the number of operations, the
@@ -834,8 +840,12 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
           // Note that it doesn't hold for mul if one of the operands is zero.
           // TODO: We can preserve NUW flag if we prove that all mul operands
           // are non-zero.
-          if (HasNUW && ExpressionChangedStart->getOpcode() == Instruction::Add)
-            ExpressionChangedStart->setHasNoUnsignedWrap();
+          if (ExpressionChangedStart->getOpcode() == Instruction::Add) {
+            if (Flags.HasNUW)
+              ExpressionChangedStart->setHasNoUnsignedWrap();
+            if (Flags.HasNSW && (Flags.AllKnownNonNegative || Flags.HasNUW))
+              ExpressionChangedStart->setHasNoSignedWrap();
+          }
         }
       }
 
@@ -1192,8 +1202,8 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
     return nullptr;
 
   SmallVector<RepeatedValue, 8> Tree;
-  bool HasNUW = true;
-  MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, HasNUW);
+  OverflowTracking Flags;
+  MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, Flags);
   SmallVector<ValueEntry, 8> Factors;
   Factors.reserve(Tree.size());
   for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
@@ -1235,7 +1245,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
 
   if (!FoundFactor) {
     // Make sure to restore the operands to the expression tree.
-    RewriteExprTree(BO, Factors, HasNUW);
+    RewriteExprTree(BO, Factors, Flags);
     return nullptr;
   }
 
@@ -1247,7 +1257,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
     RedoInsts.insert(BO);
     V = Factors[0].Op;
   } else {
-    RewriteExprTree(BO, Factors, HasNUW);
+    RewriteExprTree(BO, Factors, Flags);
     V = BO;
   }
 
@@ -2373,8 +2383,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
   // First, walk the expression tree, linearizing the tree, collecting the
   // operand information.
   SmallVector<RepeatedValue, 8> Tree;
-  bool HasNUW = true;
-  MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, HasNUW);
+  OverflowTracking Flags;
+  MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, Flags);
   SmallVector<ValueEntry, 8> Ops;
   Ops.reserve(Tree.size());
   for (const RepeatedValue &E : Tree)
@@ -2567,7 +2577,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
              dbgs() << '\n');
   // Now that we ordered and optimized the expressions, splat them back into
   // the expression tree, removing any unneeded nodes.
-  RewriteExprTree(I, Ops, HasNUW);
+  RewriteExprTree(I, Ops, Flags);
 }
 
 void
diff --git a/llvm/test/Transforms/Reassociate/local-cse.ll b/llvm/test/Transforms/Reassociate/local-cse.ll
index 4d0467e263f55..d0d609f022b46 100644
--- a/llvm/test/Transforms/Reassociate/local-cse.ll
+++ b/llvm/test/Transforms/Reassociate/local-cse.ll
@@ -26,16 +26,16 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
 ; LOCAL_CSE-LABEL: define void @chain_spanning_several_blocks
 ; LOCAL_CSE-SAME: (i64 [[INV1:%.*]], i64 [[INV2:%.*]], i64 [[INV3:%.*]], i64 [[INV4:%.*]], i64 [[INV5:%.*]]) {
 ; LOCAL_CSE-NEXT:  bb1:
-; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[INV2]], [[INV1]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw nsw i64 [[INV2]], [[INV1]]
 ; LOCAL_CSE-NEXT:    br label [[BB2:%.*]]
 ; LOCAL_CSE:       bb2:
 ; LOCAL_CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4]]
-; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add nuw i64 [[INV3]], [[INV1]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV4]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV5]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_B1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add nuw nsw i64 [[INV3]], [[INV1]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_C0]], [[VAL_BB2]]
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
@@ -47,11 +47,11 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
 ; CSE-NEXT:    br label [[BB2:%.*]]
 ; CSE:       bb2:
 ; CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1]]
-; CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2]]
+; CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw nsw i64 [[VAL_BB2]], [[INV1]]
+; CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV2]]
 ; CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4]]
 ; CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5]]
-; CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3]]
+; CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV3]]
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
@@ -90,19 +90,19 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
 ; LOCAL_CSE-NEXT:    br label [[BB1:%.*]]
 ; LOCAL_CSE:       bb1:
 ; LOCAL_CSE-NEXT:    [[INV1_BB1:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[INV1_BB1]], [[INV2_BB0]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw nsw i64 [[INV1_BB1]], [[INV2_BB0]]
 ; LOCAL_CSE-NEXT:    br label [[BB2:%.*]]
 ; LOCAL_CSE:       bb2:
 ; LOCAL_CSE-NEXT:    [[INV3_BB2:%.*]] = call i64 @get_val()
 ; LOCAL_CSE-NEXT:    [[INV4_BB2:%.*]] = call i64 @get_val()
 ; LOCAL_CSE-NEXT:    [[INV5_BB2:%.*]] = call i64 @get_val()
 ; LOCAL_CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[INV3_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV4_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV5_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_B1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add nuw nsw i64 [[VAL_BB2]], [[INV1_BB1]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_C0]], [[INV3_BB2]]
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
@@ -120,11 +120,11 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
 ; CSE-NEXT:    [[INV4_BB2:%.*]] = call i64 @get_val()
 ; CSE-NEXT:    [[INV5_BB2:%.*]] = call i64 @get_val()
 ; CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
-; CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2_BB0]]
+; CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw nsw i64 [[VAL_BB2]], [[INV1_BB1]]
+; CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV2_BB0]]
 ; CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4_BB2]]
 ; CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5_BB2]]
-; CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3_BB2]]
+; CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw nsw i64 [[CHAIN_A0]], [[INV3_BB2]]
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
diff --git a/llvm/test/Transforms/Reassociate/reassoc-add-nsw.ll b/llvm/test/Transforms/Reassociate/reassoc-add-nsw.ll
new file mode 100644
index 0000000000000..61f6537632f41
--- /dev/null
+++ b/llvm/test/Transforms/Reassociate/reassoc-add-nsw.ll
@@ -0,0 +1,61 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=reassociate -S | FileCheck %s
+define i32 @nsw_preserve_nonnegative(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
+; CHECK-LABEL: define i32 @nsw_preserve_nonnegative(
+; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
+; CHECK-NEXT:    [[V0:%.*]] = load i32, ptr [[PTR0]], align 4, !range [[RNG0:![0-9]+]]
+; CHECK-NEXT:    [[V1:%.*]] = load i32, ptr [[PTR1]], align 4, !range [[RNG0]]
+; CHECK-NEXT:    [[V2:%.*]] = load i32, ptr [[PTR2]], align 4, !range [[RNG0]]
+; CHECK-NEXT:    [[ADD0:%.*]] = add nsw i32 [[V1]], [[V0]]
+; CHECK-NEXT:    [[ADD1:%.*]] = add nsw i32 [[ADD0]], [[V2]]
+; CHECK-NEXT:    ret i32 [[ADD1]]
+;
+  %v0 = load i32, ptr %ptr0, !range !1
+  %v1 = load i32, ptr %ptr1, !range !1
+  %v2 = load i32, ptr %ptr2, !range !1
+  %add0 = add nsw i32 %v1, %v2
+  %add1 = add nsw i32 %add0, %v0
+  ret i32 %add1
+}
+
+define i32 @nsw_preserve_nuw_nsw(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
+; CHECK-LABEL: define i32 @nsw_preserve_nuw_nsw(
+; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
+; CHECK-NEXT:    [[V0:%.*]] = load i32, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[V1:%.*]] = load i32, ptr [[PTR1]], align 4
+; CHECK-NEXT:    [[V2:%.*]] = load i32, ptr [[PTR2]], align 4
+; CHECK-NEXT:    [[ADD0:%.*]] = add nuw nsw i32 [[V1]], [[V0]]
+; CHECK-NEXT:    [[ADD1:%.*]] = add nuw nsw i32 [[ADD0]], [[V2]]
+; CHECK-NEXT:    ret i32 [[ADD1]]
+;
+  %v0 = load i32, ptr %ptr0
+  %v1 = load i32, ptr %ptr1
+  %v2 = load i32, ptr %ptr2
+  %add0 = add nuw nsw i32 %v1, %v2
+  %add1 = add nuw nsw i32 %add0, %v0
+  ret i32 %add1
+}
+
+define i32 @nsw_dont_preserve_negative(ptr %ptr0, ptr %ptr1, ptr %ptr2) {
+; CHECK-LABEL: define i32 @nsw_dont_preserve_negative(
+; CHECK-SAME: ptr [[PTR0:%.*]], ptr [[PTR1:%.*]], ptr [[PTR2:%.*]]) {
+; CHECK-NEXT:    [[V0:%.*]] = load i32, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[V1:%.*]] = load i32, ptr [[PTR1]], align 4, !range [[RNG0]]
+; CHECK-NEXT:    [[V2:%.*]] = load i32, ptr [[PTR2]], align 4, !range [[RNG0]]
+; CHECK-NEXT:    [[ADD0:%.*]] = add i32 [[V1]], [[V0]]
+; CHECK-NEXT:    [[ADD1:%.*]] = add i32 [[ADD0]], [[V2]]
+; CHECK-NEXT:    ret i32 [[ADD1]]
+;
+  %v0 = load i32, ptr %ptr0
+  %v1 = load i32, ptr %ptr1, !range !1
+  %v2 = load i32, ptr %ptr2, !range !1
+  %add0 = add nsw i32 %v1, %v2
+  %add1 = add nsw i32 %add0, %v0
+  ret i32 %add1
+}
+
+; Positive 32 bit integers
+!1 = !{i32 0, i32 2147483648}
+;.
+; CHECK: [[RNG0]] = !{i32 0, i32 -2147483648}
+;.

``````````

</details>


https://github.com/llvm/llvm-project/pull/93105


More information about the llvm-commits mailing list