[llvm] [Reassociate] Move Disjoint flag handling to OverflowTracking. (PR #140406)
via llvm-commits
llvm-commits at lists.llvm.org
Sat May 17 14:25:06 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Florian Hahn (fhahn)
<details>
<summary>Changes</summary>
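Move disjoint flag tracking to OverflowTracking, which now lives in
llvm/Transforms/Utils/Local.h. This enables preserving disjoint flags in
Reassociate. LICM's hoistBOAssociation also switches to OverflowTracking,
which additionally lets it keep nsw on reassociated adds that are nuw nsw.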
Depends on https://github.com/llvm/llvm-project/pull/140404
---
Full diff: https://github.com/llvm/llvm-project/pull/140406.diff
7 Files Affected:
- (modified) llvm/include/llvm/Transforms/Scalar/Reassociate.h (+2-12)
- (modified) llvm/include/llvm/Transforms/Utils/Local.h (+22)
- (modified) llvm/lib/Transforms/Scalar/LICM.cpp (+10-14)
- (modified) llvm/lib/Transforms/Scalar/Reassociate.cpp (+3-14)
- (modified) llvm/lib/Transforms/Utils/Local.cpp (+22)
- (modified) llvm/test/Transforms/LICM/hoist-binop.ll (+2-2)
- (modified) llvm/test/Transforms/Reassociate/or-disjoint.ll (+2-2)
``````````diff
diff --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
index 23b70164d96a4..a5d137661e11e 100644
--- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h
+++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h
@@ -39,6 +39,7 @@ class Function;
class Instruction;
class IRBuilderBase;
class Value;
+struct OverflowTracking;
/// A private "module" namespace for types and utilities used by Reassociate.
/// These are implementation details and should not be used by clients.
@@ -64,17 +65,6 @@ struct Factor {
Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {}
};
-struct OverflowTracking {
- bool HasNUW = true;
- bool HasNSW = true;
- bool AllKnownNonNegative = true;
- bool AllKnownNonZero = true;
- // Note: AllKnownNonNegative can be true in a case where one of the operands
- // is negative, but one the operators is not NSW. AllKnownNonNegative should
- // not be used independently of HasNSW
- OverflowTracking() = default;
-};
-
class XorOpnd;
} // end namespace reassociate
@@ -115,7 +105,7 @@ class ReassociatePass : public PassInfoMixin<ReassociatePass> {
void ReassociateExpression(BinaryOperator *I);
void RewriteExprTree(BinaryOperator *I,
SmallVectorImpl<reassociate::ValueEntry> &Ops,
- reassociate::OverflowTracking Flags);
+ OverflowTracking Flags);
Value *OptimizeExpression(BinaryOperator *I,
SmallVectorImpl<reassociate::ValueEntry> &Ops);
Value *OptimizeAdd(Instruction *I,
diff --git a/llvm/include/llvm/Transforms/Utils/Local.h b/llvm/include/llvm/Transforms/Utils/Local.h
index db064e1f41f02..6de8be3bccaa2 100644
--- a/llvm/include/llvm/Transforms/Utils/Local.h
+++ b/llvm/include/llvm/Transforms/Utils/Local.h
@@ -556,6 +556,28 @@ Value *invertCondition(Value *Condition);
/// function, explicitly materialize the maximal set in the IR.
bool inferAttributesFromOthers(Function &F);
+//===----------------------------------------------------------------------===//
+// Helpers to track and update flags on instructions.
+//
+
+struct OverflowTracking {
+ bool HasNUW = true; // Cleared by mergeFlags on an overflowing op without nuw.
+ bool HasNSW = true; // Cleared by mergeFlags on an overflowing op without nsw.
+ bool IsDisjoint = true; // Cleared by mergeFlags on a non-disjoint disjoint-able op.
+ bool AllKnownNonNegative = true; // Not updated by mergeFlags; callers set it.
+ bool AllKnownNonZero = true; // Not updated by mergeFlags; callers set it.
+ // Note: AllKnownNonNegative can be true in a case where one of the operands
+ // is negative, but one of the operators is not NSW. AllKnownNonNegative
+ // should not be used independently of HasNSW.
+ OverflowTracking() = default;
+
+ /// Merge the nuw/nsw and disjoint flags of \p I into the tracked state.
+ void mergeFlags(Instruction &I);
+
+ /// Clear all optional flags on \p I, then set the tracked ones valid for it.
+ void applyFlags(Instruction &I);
+};
+
} // end namespace llvm
#endif // LLVM_TRANSFORMS_UTILS_LOCAL_H
diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp
index 7d89a13fa3bab..006a09b38bc71 100644
--- a/llvm/lib/Transforms/Scalar/LICM.cpp
+++ b/llvm/lib/Transforms/Scalar/LICM.cpp
@@ -2864,26 +2864,22 @@ static bool hoistBOAssociation(Instruction &I, Loop &L,
auto *NewBO = BinaryOperator::Create(
Opcode, LV, Inv, BO->getName() + ".reass", BO->getIterator());
- // Copy NUW for ADDs if both instructions have it.
- if (Opcode == Instruction::Add && BO->hasNoUnsignedWrap() &&
- BO0->hasNoUnsignedWrap()) {
- // If `Inv` was not constant-folded, a new Instruction has been created.
- if (auto *I = dyn_cast<Instruction>(Inv))
- I->setHasNoUnsignedWrap(true);
- NewBO->setHasNoUnsignedWrap(true);
- } else if (Opcode == Instruction::FAdd || Opcode == Instruction::FMul) {
+ if (Opcode == Instruction::FAdd || Opcode == Instruction::FMul) {
// Intersect FMF flags for FADD and FMUL.
FastMathFlags Intersect = BO->getFastMathFlags() & BO0->getFastMathFlags();
if (auto *I = dyn_cast<Instruction>(Inv))
I->setFastMathFlags(Intersect);
NewBO->setFastMathFlags(Intersect);
- } else if (Opcode == Instruction::Or) {
- bool Disjoint = cast<PossiblyDisjointInst>(BO)->isDisjoint() &&
- cast<PossiblyDisjointInst>(BO0)->isDisjoint();
+ } else {
+ OverflowTracking Flags;
+ Flags.AllKnownNonNegative = false;
+ Flags.AllKnownNonZero = false;
+ Flags.mergeFlags(*BO);
+ Flags.mergeFlags(*BO0);
// If `Inv` was not constant-folded, a new Instruction has been created.
- if (auto *I = dyn_cast<PossiblyDisjointInst>(Inv))
- I->setIsDisjoint(Disjoint);
- cast<PossiblyDisjointInst>(NewBO)->setIsDisjoint(Disjoint);
+ if (auto *I = dyn_cast<Instruction>(Inv))
+ Flags.applyFlags(*I);
+ Flags.applyFlags(*NewBO);
}
BO->replaceAllUsesWith(NewBO);
diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp
index cb7a9ef9b6711..778a6a012556b 100644
--- a/llvm/lib/Transforms/Scalar/Reassociate.cpp
+++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp
@@ -382,7 +382,7 @@ using RepeatedValue = std::pair<Value *, uint64_t>;
static bool LinearizeExprTree(Instruction *I,
SmallVectorImpl<RepeatedValue> &Ops,
ReassociatePass::OrderedSet &ToRedo,
- reassociate::OverflowTracking &Flags) {
+ OverflowTracking &Flags) {
assert((isa<UnaryOperator>(I) || isa<BinaryOperator>(I)) &&
"Expected a UnaryOperator or BinaryOperator!");
LLVM_DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
@@ -431,10 +431,7 @@ static bool LinearizeExprTree(Instruction *I,
// We examine the operands of this binary operator.
auto [I, Weight] = Worklist.pop_back_val();
- if (isa<OverflowingBinaryOperator>(I)) {
- Flags.HasNUW &= I->hasNoUnsignedWrap();
- Flags.HasNSW &= I->hasNoSignedWrap();
- }
+ Flags.mergeFlags(*I);
for (unsigned OpIdx = 0; OpIdx < I->getNumOperands(); ++OpIdx) { // Visit operands.
Value *Op = I->getOperand(OpIdx);
@@ -734,15 +731,7 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I,
ExpressionChangedStart->clearSubclassOptionalData();
ExpressionChangedStart->setFastMathFlags(Flags);
} else {
- ExpressionChangedStart->clearSubclassOptionalData();
- if (ExpressionChangedStart->getOpcode() == Instruction::Add ||
- (ExpressionChangedStart->getOpcode() == Instruction::Mul &&
- Flags.AllKnownNonZero)) {
- if (Flags.HasNUW)
- ExpressionChangedStart->setHasNoUnsignedWrap();
- if (Flags.HasNSW && (Flags.AllKnownNonNegative || Flags.HasNUW))
- ExpressionChangedStart->setHasNoSignedWrap();
- }
+ Flags.applyFlags(*ExpressionChangedStart);
}
}
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
index 3dbd605e19c3a..69dcd30d1af99 100644
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -4362,3 +4362,25 @@ bool llvm::inferAttributesFromOthers(Function &F) {
return Changed;
}
+
+void OverflowTracking::mergeFlags(Instruction &I) { // Flags only go true -> false.
+ if (isa<OverflowingBinaryOperator>(&I)) { // Other ops leave HasNUW/HasNSW as-is.
+ HasNUW &= I.hasNoUnsignedWrap();
+ HasNSW &= I.hasNoSignedWrap();
+ }
+ if (auto *DisjointOp = dyn_cast<PossiblyDisjointInst>(&I)) // Else IsDisjoint as-is.
+ IsDisjoint &= DisjointOp->isDisjoint();
+}
+
+void OverflowTracking::applyFlags(Instruction &I) { // Overwrites, not intersects.
+ I.clearSubclassOptionalData(); // NOTE(review): also drops any flag not tracked here.
+ if (I.getOpcode() == Instruction::Add ||
+ (I.getOpcode() == Instruction::Mul && AllKnownNonZero)) { // mul: operands non-zero.
+ if (HasNUW)
+ I.setHasNoUnsignedWrap();
+ if (HasNSW && (AllKnownNonNegative || HasNUW)) // HasNSW alone is not enough.
+ I.setHasNoSignedWrap();
+ }
+ if (auto *DisjointOp = dyn_cast<PossiblyDisjointInst>(&I))
+ DisjointOp->setIsDisjoint(IsDisjoint);
+}
diff --git a/llvm/test/Transforms/LICM/hoist-binop.ll b/llvm/test/Transforms/LICM/hoist-binop.ll
index 33161090a8ccf..1b1347776fb9e 100644
--- a/llvm/test/Transforms/LICM/hoist-binop.ll
+++ b/llvm/test/Transforms/LICM/hoist-binop.ll
@@ -371,13 +371,13 @@ loop:
define void @add_nuw_nsw(i64 %c1, i64 %c2) {
; CHECK-LABEL: @add_nuw_nsw(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[INVARIANT_OP:%.*]] = add nuw i64 [[C1:%.*]], [[C2:%.*]]
+; CHECK-NEXT: [[INVARIANT_OP:%.*]] = add nuw nsw i64 [[C1:%.*]], [[C2:%.*]]
; CHECK-NEXT: br label [[LOOP:%.*]]
; CHECK: loop:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT_REASS:%.*]], [[LOOP]] ]
; CHECK-NEXT: [[STEP_ADD:%.*]] = add nuw nsw i64 [[INDEX]], [[C1]]
; CHECK-NEXT: call void @use(i64 [[STEP_ADD]])
-; CHECK-NEXT: [[INDEX_NEXT_REASS]] = add nuw i64 [[INDEX]], [[INVARIANT_OP]]
+; CHECK-NEXT: [[INDEX_NEXT_REASS]] = add nuw nsw i64 [[INDEX]], [[INVARIANT_OP]]
; CHECK-NEXT: br label [[LOOP]]
;
entry:
diff --git a/llvm/test/Transforms/Reassociate/or-disjoint.ll b/llvm/test/Transforms/Reassociate/or-disjoint.ll
index 777836ed98152..b060b94e01d69 100644
--- a/llvm/test/Transforms/Reassociate/or-disjoint.ll
+++ b/llvm/test/Transforms/Reassociate/or-disjoint.ll
@@ -4,8 +4,8 @@
define i16 @or_disjoint_both(i16 %a, i16 %b) {
; CHECK-LABEL: @or_disjoint_both(
-; CHECK-NEXT: [[OR_1:%.*]] = or i16 [[A:%.*]], 1
-; CHECK-NEXT: [[OR_2:%.*]] = or i16 [[OR_1]], [[B:%.*]]
+; CHECK-NEXT: [[OR_1:%.*]] = or disjoint i16 [[A:%.*]], 1
+; CHECK-NEXT: [[OR_2:%.*]] = or disjoint i16 [[OR_1]], [[B:%.*]]
; CHECK-NEXT: ret i16 [[OR_2]]
;
%or.1 = or disjoint i16 %b, %a
``````````
</details>
https://github.com/llvm/llvm-project/pull/140406
More information about the llvm-commits mailing list