[llvm] [Reassociate] Conserve nsw/nuw flags when factoring out (PR #125773)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 4 14:32:38 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: None (joe-rivos)
<details>
<summary>Changes</summary>
When transforming, e.g., A*A+A*B*C+D into A*(A+B*C)+D, we can set the
factored-out mul to have nuw/nsw iff all of the adds and muls had this
flag set.
---
Full diff: https://github.com/llvm/llvm-project/pull/125773.diff
3 Files Affected:
- (modified) llvm/include/llvm/Transforms/Scalar/Reassociate.h (+4-2)
- (modified) llvm/lib/Transforms/Scalar/Reassociate.cpp (+42-11)
- (modified) llvm/test/Transforms/Reassociate/basictest.ll (+37)
``````````diff
diff --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
index 3b2d2b83ced623..8023385c2ad903 100644
--- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h
+++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
@@ -119,9 +119,11 @@ class ReassociatePass : public PassInfoMixin<ReassociatePass> {
SmallVectorImpl<reassociate::ValueEntry> &Ops,
reassociate::OverflowTracking Flags);
Value *OptimizeExpression(BinaryOperator *I,
- SmallVectorImpl<reassociate::ValueEntry> &Ops);
+ SmallVectorImpl<reassociate::ValueEntry> &Ops,
+ reassociate::OverflowTracking Flags);
Value *OptimizeAdd(Instruction *I,
- SmallVectorImpl<reassociate::ValueEntry> &Ops);
+ SmallVectorImpl<reassociate::ValueEntry> &Ops,
+ reassociate::OverflowTracking Flags);
Value *OptimizeXor(Instruction *I,
SmallVectorImpl<reassociate::ValueEntry> &Ops);
bool CombineXorOpnd(BasicBlock::iterator It, reassociate::XorOpnd *Opnd1,
diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index 7cb9bace47bf44..ceae4e1957cf27 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -1174,16 +1174,21 @@ Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) {
///
/// Ops is the top-level list of add operands we're trying to factor.
static void FindSingleUseMultiplyFactors(Value *V,
- SmallVectorImpl<Value*> &Factors) {
+ SmallVectorImpl<Value *> &Factors,
+ bool &AllNUW, bool &AllNSW) {
BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul);
if (!BO) {
Factors.push_back(V);
return;
}
+ if (isa<OverflowingBinaryOperator>(BO)) {
+ AllNUW &= BO->hasNoUnsignedWrap();
+ AllNSW &= BO->hasNoSignedWrap();
+ }
// Otherwise, add the LHS and RHS to the list of factors.
- FindSingleUseMultiplyFactors(BO->getOperand(1), Factors);
- FindSingleUseMultiplyFactors(BO->getOperand(0), Factors);
+ FindSingleUseMultiplyFactors(BO->getOperand(1), Factors, AllNUW, AllNSW);
+ FindSingleUseMultiplyFactors(BO->getOperand(0), Factors, AllNUW, AllNSW);
}
/// Optimize a series of operands to an 'and', 'or', or 'xor' instruction.
@@ -1492,7 +1497,9 @@ Value *ReassociatePass::OptimizeXor(Instruction *I,
/// optimizes based on identities. If it can be reduced to a single Value, it
/// is returned, otherwise the Ops list is mutated as necessary.
Value *ReassociatePass::OptimizeAdd(Instruction *I,
- SmallVectorImpl<ValueEntry> &Ops) {
+ SmallVectorImpl<ValueEntry> &Ops,
+ OverflowTracking Flags) {
+
// Scan the operand lists looking for X and -X pairs. If we find any, we
// can simplify expressions like X+-X == 0 and X+~X ==-1. While we're at it,
// scan for any
@@ -1586,8 +1593,17 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
// Keep track of each multiply we see, to avoid triggering on (X*4)+(X*4)
// where they are actually the same multiply.
unsigned MaxOcc = 0;
Value *MaxOccVal = nullptr;
+ // Also record, for every factor, how many of its occurrences were counted
+ // directly (rather than through a negated constant) and whether all of
+ // those multiply trees, together with the add tree, carried nuw/nsw.
+ struct FactorFlagInfo {
+ unsigned NumDirect = 0;
+ bool AllNUW = true;
+ bool AllNSW = true;
+ };
+ DenseMap<Value *, FactorFlagInfo> FactorFlags;
for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
BinaryOperator *BOp =
isReassociableOp(Ops[i].Op, Instruction::Mul, Instruction::FMul);
@@ -1596,7 +1612,9 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
// Compute all of the factors of this added value.
SmallVector<Value*, 8> Factors;
- FindSingleUseMultiplyFactors(BOp, Factors);
+ bool AllNUW = Flags.HasNUW;
+ bool AllNSW = Flags.HasNSW;
+ FindSingleUseMultiplyFactors(BOp, Factors, AllNUW, AllNSW);
assert(Factors.size() > 1 && "Bad linearize!");
// Add one to FactorOccurrences for each unique factor in this op.
@@ -1608,4 +1626,9 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
unsigned Occ = ++FactorOccurrences[Factor];
+ // Record the flags of this direct occurrence of Factor.
+ FactorFlagInfo &Info = FactorFlags[Factor];
+ ++Info.NumDirect;
+ Info.AllNUW &= AllNUW;
+ Info.AllNSW &= AllNSW;
if (Occ > MaxOcc) {
MaxOcc = Occ;
MaxOccVal = Factor;
@@ -1690,11 +1713,43 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I,
// A*A*B + A*A*C --> A*(A*B+A*C) --> A*(A*(B+C))
assert(NumAddedValues > 1 && "Each occurrence should contribute a value");
(void)NumAddedValues;
- if (Instruction *VI = dyn_cast<Instruction>(V))
+ // Only reuse the recorded flags if every counted occurrence of MaxOccVal
+ // was a direct factor; occurrences found through the negated-constant path
+ // above were not recorded and carry no flag information.
+ bool FactorNUW = false;
+ bool FactorNSW = false;
+ auto FlagIt = FactorFlags.find(MaxOccVal);
+ if (FlagIt != FactorFlags.end() && FlagIt->second.NumDirect == MaxOcc) {
+ FactorNUW = FlagIt->second.AllNUW;
+ FactorNSW = FlagIt->second.AllNSW;
+ }
+ // A*(B+C) may only inherit flags that hold for every value of the factor A:
+ // - if A is 0, B+C can overflow even though A*B+A*C cannot;
+ // - if A is -1 and A*B+A*C is INT_MIN, both B+C and (B+C)*A overflow in
+ // the signed sense.
+ // Hence the flags on the new add, and nsw on the new mul, are restricted to
+ // suitable constant factors.
+ const auto *FactorC = dyn_cast<ConstantInt>(MaxOccVal);
+ bool FactorNonZero = FactorC && !FactorC->isZero();
+ bool FactorNotMinusOne = FactorC && !FactorC->isMinusOne();
+ if (Instruction *VI = dyn_cast<Instruction>(V)) {
+ if (isa<OverflowingBinaryOperator>(VI)) {
+ VI->setHasNoUnsignedWrap(FactorNUW && FactorNonZero);
+ // The terms were reordered, so with more than two of them a signed
+ // partial sum may overflow even though the total does not.
+ VI->setHasNoSignedWrap(FactorNSW && FactorNonZero && FactorNotMinusOne &&
+ NumAddedValues == 2);
+ }
RedoInsts.insert(VI);
+ }
// Create the multiply.
Instruction *V2 = CreateMul(V, MaxOccVal, "reass.mul", I->getIterator(), I);
+ if (isa<OverflowingBinaryOperator>(V2)) {
+ // nuw is valid for any factor; nsw needs a constant factor other than -1.
+ V2->setHasNoUnsignedWrap(FactorNUW);
+ V2->setHasNoSignedWrap(FactorNSW && FactorNotMinusOne);
+ }
// Rerun associate on the multiply in case the inner expression turned into
// a multiply. We want to make sure that we keep things in canonical form.
@@ -1890,7 +1945,8 @@ Value *ReassociatePass::OptimizeMul(BinaryOperator *I,
}
Value *ReassociatePass::OptimizeExpression(BinaryOperator *I,
- SmallVectorImpl<ValueEntry> &Ops) {
+ SmallVectorImpl<ValueEntry> &Ops,
+ OverflowTracking Flags) {
// Now that we have the linearized expression tree, try to optimize it.
// Start by folding any constants that we found.
const DataLayout &DL = I->getDataLayout();
@@ -1944,7 +2000,7 @@ Value *ReassociatePass::OptimizeExpression(BinaryOperator *I,
case Instruction::Add:
case Instruction::FAdd:
- if (Value *Result = OptimizeAdd(I, Ops))
+ if (Value *Result = OptimizeAdd(I, Ops, Flags))
return Result;
break;
@@ -1956,7 +2012,7 @@ Value *ReassociatePass::OptimizeExpression(BinaryOperator *I,
}
if (Ops.size() != NumOps)
- return OptimizeExpression(I, Ops);
+ return OptimizeExpression(I, Ops, Flags);
return nullptr;
}
@@ -2305,7 +2361,7 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) {
// Now that we have the expression tree in a convenient
// sorted form, optimize it globally if possible.
- if (Value *V = OptimizeExpression(I, Ops)) {
+ if (Value *V = OptimizeExpression(I, Ops, Flags)) {
if (V == I)
// Self-referential expression in unreachable code.
return;
diff --git a/llvm/test/Transforms/Reassociate/basictest.ll b/llvm/test/Transforms/Reassociate/basictest.ll
index 3f4057dd14e7e1..b5f87a8202185c 100644
--- a/llvm/test/Transforms/Reassociate/basictest.ll
+++ b/llvm/test/Transforms/Reassociate/basictest.ll
@@ -293,3 +293,42 @@ define i32 @test17(i32 %X1, i32 %X2, i32 %X3, i32 %X4) {
ret i32 %E
}
+define i32 @test18(i32 %X1, i32 %X2) {
+; CHECK-LABEL: @test18(
+; CHECK-NEXT: [[REASS_ADD:%.*]] = add nsw i32 [[X2:%.*]], [[X1:%.*]]
+; CHECK-NEXT: [[REASS_MUL:%.*]] = mul nsw i32 [[REASS_ADD]], 47
+; CHECK-NEXT: ret i32 [[REASS_MUL]]
+;
+ %B = mul nsw i32 %X1, 47
+ %C = mul nsw i32 %X2, 47
+ %D = add nsw i32 %B, %C
+ ret i32 %D
+}
+
+; The common factor %X2 may be -1, so the factored multiply must not inherit
+; nsw even though %B, %C and %D all carry it.
+define i32 @test19(i32 %X1, i32 %X2) {
+; CHECK-LABEL: @test19(
+; CHECK-NEXT: [[REASS_ADD:%.*]] = add i32 [[X1:%.*]], 67
+; CHECK-NEXT: [[REASS_MUL:%.*]] = mul i32 [[REASS_ADD]], [[X2:%.*]]
+; CHECK-NEXT: ret i32 [[REASS_MUL]]
+;
+ %A = add i32 %X1, 20
+ %B = mul nsw i32 %X2, 47
+ %C = mul nsw i32 %X2, %A
+ %D = add nsw i32 %B, %C
+ ret i32 %D
+}
+
+define i32 @test20(i32 %X1, i32 %X2) {
+; CHECK-LABEL: @test20(
+; CHECK-NEXT: [[REASS_ADD:%.*]] = add i32 [[X1:%.*]], 67
+; CHECK-NEXT: [[REASS_MUL:%.*]] = mul i32 [[REASS_ADD]], [[X2:%.*]]
+; CHECK-NEXT: ret i32 [[REASS_MUL]]
+;
+ %A = add i32 %X1, 20
+ %B = mul nuw i32 %X2, 47
+ %C = mul nsw i32 %X2, %A
+ %D = add nsw i32 %B, %C
+ ret i32 %D
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/125773
More information about the llvm-commits
mailing list