[llvm] 312cb34 - [Reassociate] Preserve NUW flags after expr tree rewriting (#72360)

via llvm-commits llvm-commits at lists.llvm.org
Sat Dec 9 00:45:52 PST 2023


Author: Yingwei Zheng
Date: 2023-12-09T16:45:48+08:00
New Revision: 312cb34da6a5529fbfaa1be62f1aa9bbb26ce506

URL: https://github.com/llvm/llvm-project/commit/312cb34da6a5529fbfaa1be62f1aa9bbb26ce506
DIFF: https://github.com/llvm/llvm-project/commit/312cb34da6a5529fbfaa1be62f1aa9bbb26ce506.diff

LOG: [Reassociate] Preserve NUW flags after expr tree rewriting (#72360)

Alive2: https://alive2.llvm.org/ce/z/38KiC_

Added: 
    llvm/test/Transforms/Reassociate/reassoc-mul-nuw.ll

Modified: 
    llvm/include/llvm/Transforms/Scalar/Reassociate.h
    llvm/lib/Transforms/Scalar/Reassociate.cpp
    llvm/test/Transforms/Reassociate/local-cse.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
index 28794d27325ade..7e47f8ae5d81e9 100644
--- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h
+++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
@@ -102,7 +102,8 @@ class ReassociatePass : public PassInfoMixin<ReassociatePass> {
   void canonicalizeOperands(Instruction *I);
   void ReassociateExpression(BinaryOperator *I);
   void RewriteExprTree(BinaryOperator *I,
-                       SmallVectorImpl<reassociate::ValueEntry> &Ops);
+                       SmallVectorImpl<reassociate::ValueEntry> &Ops,
+                       bool HasNUW);
   Value *OptimizeExpression(BinaryOperator *I,
                             SmallVectorImpl<reassociate::ValueEntry> &Ops);
   Value *OptimizeAdd(Instruction *I,

diff  --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index 42e979db24d29b..818c7b40d489ef 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -466,7 +466,8 @@ using RepeatedValue = std::pair<Value*, APInt>;
 /// type and thus make the expression bigger.
 static bool LinearizeExprTree(Instruction *I,
                               SmallVectorImpl<RepeatedValue> &Ops,
-                              ReassociatePass::OrderedSet &ToRedo) {
+                              ReassociatePass::OrderedSet &ToRedo,
+                              bool &HasNUW) {
   assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) &&
          "Expected a UnaryOperator or BinaryOperator!");
   LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
@@ -515,6 +516,9 @@ static bool LinearizeExprTree(Instruction *I,
     std::pair<Instruction*, APInt> P = Worklist.pop_back_val();
     I = P.first; // We examine the operands of this binary operator.
 
+    if (isa<OverflowingBinaryOperator>(I))
+      HasNUW &= I->hasNoUnsignedWrap();
+
     for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands.
       Value *Op = I->getOperand(OpIdx);
       APInt Weight = P.second; // Number of paths to this operand.
@@ -657,7 +661,8 @@ static bool LinearizeExprTree(Instruction *I,
 /// Now that the operands for this expression tree are
 /// linearized and optimized, emit them in-order.
 void ReassociatePass::RewriteExprTree(BinaryOperator *I,
-                                      SmallVectorImpl<ValueEntry> &Ops) {
+                                      SmallVectorImpl<ValueEntry> &Ops,
+                                      bool HasNUW) {
   assert(Ops.size() > 1 && "Single values should be used directly!");
 
   // Since our optimizations should never increase the number of operations, the
@@ -814,14 +819,20 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
   if (ExpressionChangedStart) {
     bool ClearFlags = true;
     do {
-      // Preserve FastMathFlags.
+      // Preserve flags.
       if (ClearFlags) {
         if (isa<FPMathOperator>(I)) {
           FastMathFlags Flags = I->getFastMathFlags();
           ExpressionChangedStart->clearSubclassOptionalData();
           ExpressionChangedStart->setFastMathFlags(Flags);
-        } else
+        } else {
           ExpressionChangedStart->clearSubclassOptionalData();
+          // NUW is reapplied only to add. It doesn't hold for mul if one of
+          // the operands is zero. TODO: We can preserve NUW flag if we prove
+          // that all mul operands are non-zero.
+          if (HasNUW && ExpressionChangedStart->getOpcode() == Instruction::Add)
+            ExpressionChangedStart->setHasNoUnsignedWrap();
+        }
       }
 
       if (ExpressionChangedStart == ExpressionChangedEnd)
@@ -1175,7 +1186,8 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
     return nullptr;
 
   SmallVector<RepeatedValue, 8> Tree;
-  MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts);
+  bool HasNUW = true;
+  MadeChange |= LinearizeExprTree(BO, Tree, RedoInsts, HasNUW);
   SmallVector<ValueEntry, 8> Factors;
   Factors.reserve(Tree.size());
   for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
@@ -1217,7 +1229,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
 
   if (!FoundFactor) {
     // Make sure to restore the operands to the expression tree.
-    RewriteExprTree(BO, Factors);
+    RewriteExprTree(BO, Factors, HasNUW);
     return nullptr;
   }
 
@@ -1229,7 +1241,7 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
     RedoInsts.insert(BO);
     V = Factors[0].Op;
   } else {
-    RewriteExprTree(BO, Factors);
+    RewriteExprTree(BO, Factors, HasNUW);
     V = BO;
   }
 
@@ -2354,7 +2366,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
   // First, walk the expression tree, linearizing the tree, collecting the
   // operand information.
   SmallVector<RepeatedValue, 8> Tree;
-  MadeChange |= LinearizeExprTree(I, Tree, RedoInsts);
+  bool HasNUW = true;
+  MadeChange |= LinearizeExprTree(I, Tree, RedoInsts, HasNUW);
   SmallVector<ValueEntry, 8> Ops;
   Ops.reserve(Tree.size());
   for (const RepeatedValue &E : Tree)
@@ -2547,7 +2560,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
              dbgs() << '\n');
   // Now that we ordered and optimized the expressions, splat them back into
   // the expression tree, removing any unneeded nodes.
-  RewriteExprTree(I, Ops);
+  RewriteExprTree(I, Ops, HasNUW);
 }
 
 void

diff  --git a/llvm/test/Transforms/Reassociate/local-cse.ll b/llvm/test/Transforms/Reassociate/local-cse.ll
index 1609cb1b36fd93..4d0467e263f553 100644
--- a/llvm/test/Transforms/Reassociate/local-cse.ll
+++ b/llvm/test/Transforms/Reassociate/local-cse.ll
@@ -26,16 +26,16 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
 ; LOCAL_CSE-LABEL: define void @chain_spanning_several_blocks
 ; LOCAL_CSE-SAME: (i64 [[INV1:%.*]], i64 [[INV2:%.*]], i64 [[INV3:%.*]], i64 [[INV4:%.*]], i64 [[INV5:%.*]]) {
 ; LOCAL_CSE-NEXT:  bb1:
-; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add i64 [[INV2]], [[INV1]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[INV2]], [[INV1]]
 ; LOCAL_CSE-NEXT:    br label [[BB2:%.*]]
 ; LOCAL_CSE:       bb2:
 ; LOCAL_CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV4]]
-; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add i64 [[CHAIN_A1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add i64 [[CHAIN_A0]], [[INV5]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add i64 [[CHAIN_B1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add i64 [[INV3]], [[INV1]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add i64 [[CHAIN_C0]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add nuw i64 [[INV3]], [[INV1]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[VAL_BB2]]
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
@@ -47,11 +47,11 @@ define void @chain_spanning_several_blocks(i64 %inv1, i64 %inv2, i64 %inv3, i64
 ; CSE-NEXT:    br label [[BB2:%.*]]
 ; CSE:       bb2:
 ; CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; CSE-NEXT:    [[CHAIN_A0:%.*]] = add i64 [[VAL_BB2]], [[INV1]]
-; CSE-NEXT:    [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV2]]
+; CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1]]
+; CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2]]
 ; CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4]]
 ; CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5]]
-; CSE-NEXT:    [[CHAIN_C1:%.*]] = add i64 [[CHAIN_A0]], [[INV3]]
+; CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3]]
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
@@ -90,19 +90,19 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
 ; LOCAL_CSE-NEXT:    br label [[BB1:%.*]]
 ; LOCAL_CSE:       bb1:
 ; LOCAL_CSE-NEXT:    [[INV1_BB1:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add i64 [[INV1_BB1]], [[INV2_BB0]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[INV1_BB1]], [[INV2_BB0]]
 ; LOCAL_CSE-NEXT:    br label [[BB2:%.*]]
 ; LOCAL_CSE:       bb2:
 ; LOCAL_CSE-NEXT:    [[INV3_BB2:%.*]] = call i64 @get_val()
 ; LOCAL_CSE-NEXT:    [[INV4_BB2:%.*]] = call i64 @get_val()
 ; LOCAL_CSE-NEXT:    [[INV5_BB2:%.*]] = call i64 @get_val()
 ; LOCAL_CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV4_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add i64 [[CHAIN_A1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add i64 [[CHAIN_A0]], [[INV5_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add i64 [[CHAIN_B1]], [[VAL_BB2]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add i64 [[VAL_BB2]], [[INV1_BB1]]
-; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add i64 [[CHAIN_C0]], [[INV3_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV4_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw i64 [[CHAIN_A1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV5_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw i64 [[CHAIN_B1]], [[VAL_BB2]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
+; LOCAL_CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_C0]], [[INV3_BB2]]
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; LOCAL_CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])
@@ -120,11 +120,11 @@ define void @chain_spanning_several_blocks_no_entry_anchor() {
 ; CSE-NEXT:    [[INV4_BB2:%.*]] = call i64 @get_val()
 ; CSE-NEXT:    [[INV5_BB2:%.*]] = call i64 @get_val()
 ; CSE-NEXT:    [[VAL_BB2:%.*]] = call i64 @get_val()
-; CSE-NEXT:    [[CHAIN_A0:%.*]] = add i64 [[VAL_BB2]], [[INV1_BB1]]
-; CSE-NEXT:    [[CHAIN_A1:%.*]] = add i64 [[CHAIN_A0]], [[INV2_BB0]]
+; CSE-NEXT:    [[CHAIN_A0:%.*]] = add nuw i64 [[VAL_BB2]], [[INV1_BB1]]
+; CSE-NEXT:    [[CHAIN_A1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV2_BB0]]
 ; CSE-NEXT:    [[CHAIN_A2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV4_BB2]]
 ; CSE-NEXT:    [[CHAIN_B2:%.*]] = add nuw nsw i64 [[CHAIN_A1]], [[INV5_BB2]]
-; CSE-NEXT:    [[CHAIN_C1:%.*]] = add i64 [[CHAIN_A0]], [[INV3_BB2]]
+; CSE-NEXT:    [[CHAIN_C1:%.*]] = add nuw i64 [[CHAIN_A0]], [[INV3_BB2]]
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_A2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_B2]])
 ; CSE-NEXT:    call void @keep_alive(i64 [[CHAIN_C1]])

diff  --git a/llvm/test/Transforms/Reassociate/reassoc-mul-nuw.ll b/llvm/test/Transforms/Reassociate/reassoc-mul-nuw.ll
new file mode 100644
index 00000000000000..682fad8d222b7c
--- /dev/null
+++ b/llvm/test/Transforms/Reassociate/reassoc-mul-nuw.ll
@@ -0,0 +1,34 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=reassociate -S | FileCheck %s
+
+; We cannot preserve nuw flags for mul
+define i4 @nuw_preserve_negative(i4 %a, i4 %b, i4 %c) {
+; CHECK-LABEL: define i4 @nuw_preserve_negative(
+; CHECK-SAME: i4 [[A:%.*]], i4 [[B:%.*]], i4 [[C:%.*]]) {
+; CHECK-NEXT:    [[V0:%.*]] = mul i4 [[B]], [[A]]
+; CHECK-NEXT:    [[V1:%.*]] = mul i4 [[V0]], [[C]]
+; CHECK-NEXT:    ret i4 [[V1]]
+;
+  %v0 = mul nuw i4 %a, %c
+  %v1 = mul nuw i4 %v0, %b
+  ret i4 %v1
+}
+
+; TODO: we can add nuw flags if we know all operands are non-zero.
+define i4 @nuw_preserve_non_zero(i4 %a, i4 %b, i4 %c) {
+; CHECK-LABEL: define i4 @nuw_preserve_non_zero(
+; CHECK-SAME: i4 [[A:%.*]], i4 [[B:%.*]], i4 [[C:%.*]]) {
+; CHECK-NEXT:    [[A0:%.*]] = add nuw i4 [[A]], 1
+; CHECK-NEXT:    [[B0:%.*]] = add nuw i4 [[B]], 1
+; CHECK-NEXT:    [[C0:%.*]] = add nuw i4 [[C]], 1
+; CHECK-NEXT:    [[V0:%.*]] = mul i4 [[B0]], [[A0]]
+; CHECK-NEXT:    [[V1:%.*]] = mul i4 [[V0]], [[C0]]
+; CHECK-NEXT:    ret i4 [[V1]]
+;
+  %a0 = add nuw i4 %a, 1
+  %b0 = add nuw i4 %b, 1
+  %c0 = add nuw i4 %c, 1
+  %v0 = mul nuw i4 %a0, %c0
+  %v1 = mul nuw i4 %v0, %b0
+  ret i4 %v1
+}


        


More information about the llvm-commits mailing list