[llvm] [InstCombine] Support nested GEPs in OptimizePointerDifference (PR #142958)
Nikita Popov via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 6 01:53:12 PDT 2025
================
@@ -2068,71 +2068,118 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
return nullptr;
}
+struct CommonBase {
+ /// Common base pointer.
+ Value *Ptr = nullptr;
+ /// LHS GEPs until common base.
+ SmallVector<GEPOperator *> LHSGEPs;
+ /// RHS GEPs until common base.
+ SmallVector<GEPOperator *> RHSGEPs;
+ /// LHS GEP NoWrapFlags until common base.
+ GEPNoWrapFlags LHSNW = GEPNoWrapFlags::all();
+ /// RHS GEP NoWrapFlags until common base.
+ GEPNoWrapFlags RHSNW = GEPNoWrapFlags::all();
+};
+
+static CommonBase computeCommonBase(Value *LHS, Value *RHS) {
+ CommonBase Base;
+
+ if (LHS->getType() != RHS->getType())
+ return Base;
+
+ // Collect all base pointers of LHS.
+ SmallPtrSet<Value *, 16> Ptrs;
+ Value *Ptr = LHS;
+ while (true) {
+ Ptrs.insert(Ptr);
+ if (auto *GEP = dyn_cast<GEPOperator>(Ptr))
+ Ptr = GEP->getPointerOperand();
+ else
+ break;
+ }
+
+ // Find common base and collect RHS GEPs.
+ while (true) {
+ if (Ptrs.contains(RHS)) {
+ if (LHS->getType() != RHS->getType())
+ return Base;
+ Base.Ptr = RHS;
+ break;
+ }
+
+ if (auto *GEP = dyn_cast<GEPOperator>(RHS)) {
+ Base.RHSGEPs.push_back(GEP);
+ Base.RHSNW &= GEP->getNoWrapFlags();
+ RHS = GEP->getPointerOperand();
+ } else {
+ // No common base.
+ return Base;
+ }
+ }
+
+ // Collect LHS GEPs.
+ while (true) {
+ if (LHS == Base.Ptr)
+ break;
+
+ auto *GEP = cast<GEPOperator>(LHS);
+ Base.LHSGEPs.push_back(GEP);
+ Base.LHSNW &= GEP->getNoWrapFlags();
+ LHS = GEP->getPointerOperand();
+ }
+
+ return Base;
+}
+
/// Optimize pointer differences into the same array into a size. Consider:
/// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer
/// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS,
Type *Ty, bool IsNUW) {
- // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
- // this.
- bool Swapped = false;
- GEPOperator *GEP1 = nullptr, *GEP2 = nullptr;
- if (!isa<GEPOperator>(LHS) && isa<GEPOperator>(RHS)) {
- std::swap(LHS, RHS);
- Swapped = true;
- }
-
- // Require at least one GEP with a common base pointer on both sides.
- if (auto *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
- // (gep X, ...) - X
- if (LHSGEP->getOperand(0)->stripPointerCasts() ==
- RHS->stripPointerCasts()) {
- GEP1 = LHSGEP;
- } else if (auto *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
- // (gep X, ...) - (gep X, ...)
- if (LHSGEP->getOperand(0)->stripPointerCasts() ==
- RHSGEP->getOperand(0)->stripPointerCasts()) {
- GEP1 = LHSGEP;
- GEP2 = RHSGEP;
- }
- }
- }
-
- if (!GEP1)
+ CommonBase Base = computeCommonBase(LHS, RHS);
+ if (!Base.Ptr)
return nullptr;
// To avoid duplicating the offset arithmetic, rewrite the GEP to use the
- // computed offset. This may erase the original GEP, so be sure to cache the
- // nowrap flags before emitting the offset.
+ // computed offset.
// TODO: We should probably do this even if there is only one GEP.
- bool RewriteGEPs = GEP2 != nullptr;
+ bool RewriteGEPs = !Base.LHSGEPs.empty() && !Base.RHSGEPs.empty();
+
+ Type *IdxTy = DL.getIndexType(Base.Ptr->getType());
+ auto EmitOffsetFromBase = [&](ArrayRef<GEPOperator *> GEPs) -> Value * {
+ Value *Sum = nullptr;
+ for (GEPOperator *GEP : reverse(GEPs)) {
+ Value *Offset = EmitGEPOffset(GEP, RewriteGEPs);
+ if (Sum)
+ Sum = Builder.CreateAdd(Sum, Offset);
+ else
+ Sum = Offset;
+ }
+ if (!Sum)
+ return Constant::getNullValue(IdxTy);
+ return Sum;
+ };
- // Emit the offset of the GEP and an intptr_t.
- GEPNoWrapFlags GEP1NW = GEP1->getNoWrapFlags();
- Value *Result = EmitGEPOffset(GEP1, RewriteGEPs);
+ Value *Result = EmitOffsetFromBase(Base.LHSGEPs);
+ Value *Offset2 = EmitOffsetFromBase(Base.RHSGEPs);
// If this is a single inbounds GEP and the original sub was nuw,
// then the final multiplication is also nuw.
if (auto *I = dyn_cast<Instruction>(Result))
- if (IsNUW && !GEP2 && !Swapped && GEP1NW.isInBounds() &&
+ if (IsNUW && match(Offset2, m_Zero()) && Base.LHSNW.isInBounds() &&
----------------
nikic wrote:
(Based on the proof, I assume this is commenting on the isInBounds use below, not this one.)
I think it's easier to understand if you consider the `(ptradd(p, a) - ptradd(p, b))` case. With nusw, if `p` is sitting in the middle of the address space, you could have `a` as a large positive value and `b` as a large negative one, so that computing `a - b` overflows. With inbounds, both resulting pointers have to stay within the same allocated object, so it is guaranteed that the distance between `a` and `b` cannot exceed half the address space.
https://github.com/llvm/llvm-project/pull/142958
More information about the llvm-commits
mailing list