[llvm] [Reassociate] Conserve nsw/nuw flags when factoring out (PR #125773)

via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 4 14:32:38 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: None (joe-rivos)

<details>
<summary>Changes</summary>

When transforming, e.g., A*A+A*B*C+D into A*(A+B*C)+D, we can set nuw/nsw
on the factored-out mul iff all of the original adds and muls had the
corresponding flag set.

---
Full diff: https://github.com/llvm/llvm-project/pull/125773.diff


3 Files Affected:

- (modified) llvm/include/llvm/Transforms/Scalar/Reassociate.h (+4-2) 
- (modified) llvm/lib/Transforms/Scalar/Reassociate.cpp (+42-11) 
- (modified) llvm/test/Transforms/Reassociate/basictest.ll (+37) 


``````````diff
diff --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
index 3b2d2b83ced623..8023385c2ad903 100644
--- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h
+++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
@@ -119,9 +119,11 @@ class ReassociatePass : public PassInfoMixin<ReassociatePass> {
                        SmallVectorImpl<reassociate::ValueEntry> &Ops,
                        reassociate::OverflowTracking Flags);
   Value *OptimizeExpression(BinaryOperator *I,
-                            SmallVectorImpl<reassociate::ValueEntry> &Ops);
+                            SmallVectorImpl<reassociate::ValueEntry> &Ops,
+                            reassociate::OverflowTracking Flags);
   Value *OptimizeAdd(Instruction *I,
-                     SmallVectorImpl<reassociate::ValueEntry> &Ops);
+                     SmallVectorImpl<reassociate::ValueEntry> &Ops,
+                     reassociate::OverflowTracking Flags);
   Value *OptimizeXor(Instruction *I,
                      SmallVectorImpl<reassociate::ValueEntry> &Ops);
   bool CombineXorOpnd(BasicBlock::iterator It, reassociate::XorOpnd *Opnd1,
diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index 7cb9bace47bf44..ceae4e1957cf27 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -1174,16 +1174,21 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
 ///
 /// Ops is the top-level list of add operands we're trying to factor.
 static void FindSingleUseMultiplyFactors(Value *V,
-                                         SmallVectorImpl<Value*> &Factors) {
+                                         SmallVectorImpl<Value *> &Factors,
+                                         bool &AllNUW, bool &AllNSW) {
   BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul);
   if (!BO) {
     Factors.push_back(V);
     return;
   }
 
+  if (isa<OverflowingBinaryOperator>(BO)) {
+    AllNUW &= BO->hasNoUnsignedWrap();
+    AllNSW &= BO->hasNoSignedWrap();
+  }
   // Otherwise, add the LHS and RHS to the list of factors.
-  FindSingleUseMultiplyFactors(BO->getOperand(1), Factors);
-  FindSingleUseMultiplyFactors(BO->getOperand(0), Factors);
+  FindSingleUseMultiplyFactors(BO->getOperand(1), Factors, AllNUW, AllNSW);
+  FindSingleUseMultiplyFactors(BO->getOperand(0), Factors, AllNUW, AllNSW);
 }
 
 /// Optimize a series of operands to an 'and', 'or', or 'xor' instruction.
@@ -1492,7 +1497,9 @@ Value *ReassociatePass::OptimizeXor(Instruction *I,
 /// optimizes based on identities.  If it can be reduced to a single Value, it
 /// is returned, otherwise the Ops list is mutated as necessary.
 Value *ReassociatePass::OptimizeAdd(Instruction *I,
-                                    SmallVectorImpl<ValueEntry> &Ops) {
+                                    SmallVectorImpl<ValueEntry> &Ops,
+                                    OverflowTracking Flags) {
+
   // Scan the operand lists looking for X and -X pairs.  If we find any, we
   // can simplify expressions like X+-X == 0 and X+~X ==-1.  While we're at it,
   // scan for any
@@ -1586,8 +1593,11 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
 
   // Keep track of each multiply we see, to avoid triggering on (X*4)+(X*4)
   // where they are actually the same multiply.
+  // Also track whether every use of this factor has nuw/nsw set. This allows
+  // these flags to be preserved on the factored value.
   unsigned MaxOcc = 0;
   Value *MaxOccVal = nullptr;
+  bool MaxOccAllNUW, MaxOccAllNSW = false;
   for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
     BinaryOperator *BOp =
         isReassociableOp(Ops[i].Op, Instruction::Mul, Instruction::FMul);
@@ -1596,7 +1606,9 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
 
     // Compute all of the factors of this added value.
     SmallVector<Value*, 8> Factors;
-    FindSingleUseMultiplyFactors(BOp, Factors);
+    bool AllNUW = Flags.HasNUW;
+    bool AllNSW = Flags.HasNSW;
+    FindSingleUseMultiplyFactors(BOp, Factors, AllNUW, AllNSW);
     assert(Factors.size() > 1 && "Bad linearize!");
 
     // Add one to FactorOccurrences for each unique factor in this op.
@@ -1608,7 +1620,16 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
       unsigned Occ = ++FactorOccurrences[Factor];
       if (Occ > MaxOcc) {
         MaxOcc = Occ;
-        MaxOccVal = Factor;
+        if (MaxOccVal != Factor) {
+          MaxOccVal = Factor;
+          if (Occ == 1) {
+            MaxOccAllNUW = AllNUW;
+            MaxOccAllNSW = AllNSW;
+          } else {
+            MaxOccAllNUW &= AllNUW;
+            MaxOccAllNSW &= AllNSW;
+          }
+        }
       }
 
       // If Factor is a negative constant, add the negated value as a factor
@@ -1690,11 +1711,20 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
     // A*A*B + A*A*C   -->   A*(A*B+A*C)   -->   A*(A*(B+C))
     assert(NumAddedValues > 1 && "Each occurrence should contribute a value");
     (void)NumAddedValues;
-    if (Instruction *VI = dyn_cast<Instruction>(V))
+    if (Instruction *VI = dyn_cast<Instruction>(V)) {
+      if (isa<OverflowingBinaryOperator>(VI)) {
+        VI->setHasNoUnsignedWrap(MaxOccAllNUW);
+        VI->setHasNoSignedWrap(MaxOccAllNSW);
+      }
       RedoInsts.insert(VI);
+    }
 
     // Create the multiply.
     Instruction *V2 = CreateMul(V, MaxOccVal, "reass.mul", I->getIterator(), I);
+    if (isa<OverflowingBinaryOperator>(V2)) {
+      V2->setHasNoUnsignedWrap(MaxOccAllNUW);
+      V2->setHasNoSignedWrap(MaxOccAllNSW);
+    }
 
     // Rerun associate on the multiply in case the inner expression turned into
     // a multiply.  We want to make sure that we keep things in canonical form.
@@ -1890,7 +1920,8 @@ Value *ReassociatePass::OptimizeMul(BinaryOperator *I,
 }
 
 Value *ReassociatePass::OptimizeExpression(BinaryOperator *I,
-                                           SmallVectorImpl<ValueEntry> &Ops) {
+                                           SmallVectorImpl<ValueEntry> &Ops,
+                                           OverflowTracking Flags) {
   // Now that we have the linearized expression tree, try to optimize it.
   // Start by folding any constants that we found.
   const DataLayout &DL = I->getDataLayout();
@@ -1944,7 +1975,7 @@ Value *ReassociatePass::OptimizeExpression(BinaryOperator *I,
 
   case Instruction::Add:
   case Instruction::FAdd:
-    if (Value *Result = OptimizeAdd(I, Ops))
+    if (Value *Result = OptimizeAdd(I, Ops, Flags))
       return Result;
     break;
 
@@ -1956,7 +1987,7 @@ Value *ReassociatePass::OptimizeExpression(BinaryOperator *I,
   }
 
   if (Ops.size() != NumOps)
-    return OptimizeExpression(I, Ops);
+    return OptimizeExpression(I, Ops, Flags);
   return nullptr;
 }
 
@@ -2305,7 +2336,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
 
   // Now that we have the expression tree in a convenient
   // sorted form, optimize it globally if possible.
-  if (Value *V = OptimizeExpression(I, Ops)) {
+  if (Value *V = OptimizeExpression(I, Ops, Flags)) {
     if (V == I)
       // Self-referential expression in unreachable code.
       return;
diff --git a/llvm/test/Transforms/Reassociate/basictest.ll b/llvm/test/Transforms/Reassociate/basictest.ll
index 3f4057dd14e7e1..b5f87a8202185c 100644
--- a/llvm/test/Transforms/Reassociate/basictest.ll
+++ b/llvm/test/Transforms/Reassociate/basictest.ll
@@ -293,3 +293,40 @@ define i32 @test17(i32 %X1, i32 %X2, i32 %X3, i32 %X4) {
   ret i32 %E
 }
 
+define i32 @test18(i32 %X1, i32 %X2) {
+; CHECK-LABEL: @test18(
+; CHECK-NEXT:    [[REASS_ADD:%.*]] = add nsw i32 [[X2:%.*]], [[X1:%.*]]
+; CHECK-NEXT:    [[REASS_MUL:%.*]] = mul nsw i32 [[REASS_ADD]], 47
+; CHECK-NEXT:    ret i32 [[REASS_MUL]]
+;
+  %B = mul nsw i32 %X1, 47
+  %C = mul nsw i32 %X2, 47
+  %D = add nsw i32 %B, %C
+  ret i32 %D
+}
+
+define i32 @test19(i32 %X1, i32 %X2) {
+; CHECK-LABEL: @test19(
+; CHECK-NEXT:    [[REASS_ADD:%.*]] = add i32 [[X1:%.*]], 67
+; CHECK-NEXT:    [[REASS_MUL:%.*]] = mul nsw i32 [[REASS_ADD]], [[X2:%.*]]
+; CHECK-NEXT:    ret i32 [[REASS_MUL]]
+;
+  %A = add i32 %X1, 20
+  %B = mul nsw i32 %X2, 47
+  %C = mul nsw i32 %X2, %A
+  %D = add nsw i32 %B, %C
+  ret i32 %D
+}
+
+define i32 @test20(i32 %X1, i32 %X2) {
+; CHECK-LABEL: @test20(
+; CHECK-NEXT:    [[REASS_ADD:%.*]] = add i32 [[X1:%.*]], 67
+; CHECK-NEXT:    [[REASS_MUL:%.*]] = mul i32 [[REASS_ADD]], [[X2:%.*]]
+; CHECK-NEXT:    ret i32 [[REASS_MUL]]
+;
+  %A = add i32 %X1, 20
+  %B = mul nuw i32 %X2, 47
+  %C = mul nsw i32 %X2, %A
+  %D = add nsw i32 %B, %C
+  ret i32 %D
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/125773


More information about the llvm-commits mailing list