[llvm] 20e8de9 - [InstCombine] Support nested GEPs in OptimizePointerDifference (#142958)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 10 00:28:18 PDT 2025
Author: Nikita Popov
Date: 2025-06-10T09:28:15+02:00
New Revision: 20e8de9c8f4cf54c6c57535428c987d921861034
URL: https://github.com/llvm/llvm-project/commit/20e8de9c8f4cf54c6c57535428c987d921861034
DIFF: https://github.com/llvm/llvm-project/commit/20e8de9c8f4cf54c6c57535428c987d921861034.diff
LOG: [InstCombine] Support nested GEPs in OptimizePointerDifference (#142958)
Currently OptimizePointerDifference() only handles single GEPs with a
common base, not GEP chains. This patch generalizes the support to
nested GEPs with a common base.
Finding the common base is a bit annoying because we want to stop as
soon as possible and not recurse into common GEP prefixes.
This helps avoid regressions from
https://github.com/llvm/llvm-project/pull/137297.
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
llvm/test/Transforms/InstCombine/icmp.ll
llvm/test/Transforms/InstCombine/sub-gep.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index a9ac5ff9b9c89..4b6958618557f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -2068,71 +2068,118 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
return nullptr;
}
+struct CommonBase {
+ /// Common base pointer.
+ Value *Ptr = nullptr;
+ /// LHS GEPs until common base.
+ SmallVector<GEPOperator *> LHSGEPs;
+ /// RHS GEPs until common base.
+ SmallVector<GEPOperator *> RHSGEPs;
+ /// LHS GEP NoWrapFlags until common base.
+ GEPNoWrapFlags LHSNW = GEPNoWrapFlags::all();
+ /// RHS GEP NoWrapFlags until common base.
+ GEPNoWrapFlags RHSNW = GEPNoWrapFlags::all();
+};
+
+static CommonBase computeCommonBase(Value *LHS, Value *RHS) {
+ CommonBase Base;
+
+ if (LHS->getType() != RHS->getType())
+ return Base;
+
+ // Collect all base pointers of LHS.
+ SmallPtrSet<Value *, 16> Ptrs;
+ Value *Ptr = LHS;
+ while (true) {
+ Ptrs.insert(Ptr);
+ if (auto *GEP = dyn_cast<GEPOperator>(Ptr))
+ Ptr = GEP->getPointerOperand();
+ else
+ break;
+ }
+
+ // Find common base and collect RHS GEPs.
+ while (true) {
+ if (Ptrs.contains(RHS)) {
+ if (LHS->getType() != RHS->getType())
+ return Base;
+ Base.Ptr = RHS;
+ break;
+ }
+
+ if (auto *GEP = dyn_cast<GEPOperator>(RHS)) {
+ Base.RHSGEPs.push_back(GEP);
+ Base.RHSNW &= GEP->getNoWrapFlags();
+ RHS = GEP->getPointerOperand();
+ } else {
+ // No common base.
+ return Base;
+ }
+ }
+
+ // Collect LHS GEPs.
+ while (true) {
+ if (LHS == Base.Ptr)
+ break;
+
+ auto *GEP = cast<GEPOperator>(LHS);
+ Base.LHSGEPs.push_back(GEP);
+ Base.LHSNW &= GEP->getNoWrapFlags();
+ LHS = GEP->getPointerOperand();
+ }
+
+ return Base;
+}
+
/// Optimize pointer differences into the same array into a size. Consider:
/// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer
/// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS,
Type *Ty, bool IsNUW) {
- // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
- // this.
- bool Swapped = false;
- GEPOperator *GEP1 = nullptr, *GEP2 = nullptr;
- if (!isa<GEPOperator>(LHS) && isa<GEPOperator>(RHS)) {
- std::swap(LHS, RHS);
- Swapped = true;
- }
-
- // Require at least one GEP with a common base pointer on both sides.
- if (auto *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
- // (gep X, ...) - X
- if (LHSGEP->getOperand(0)->stripPointerCasts() ==
- RHS->stripPointerCasts()) {
- GEP1 = LHSGEP;
- } else if (auto *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
- // (gep X, ...) - (gep X, ...)
- if (LHSGEP->getOperand(0)->stripPointerCasts() ==
- RHSGEP->getOperand(0)->stripPointerCasts()) {
- GEP1 = LHSGEP;
- GEP2 = RHSGEP;
- }
- }
- }
-
- if (!GEP1)
+ CommonBase Base = computeCommonBase(LHS, RHS);
+ if (!Base.Ptr)
return nullptr;
// To avoid duplicating the offset arithmetic, rewrite the GEP to use the
- // computed offset. This may erase the original GEP, so be sure to cache the
- // nowrap flags before emitting the offset.
+ // computed offset.
// TODO: We should probably do this even if there is only one GEP.
- bool RewriteGEPs = GEP2 != nullptr;
+ bool RewriteGEPs = !Base.LHSGEPs.empty() && !Base.RHSGEPs.empty();
+
+ Type *IdxTy = DL.getIndexType(Base.Ptr->getType());
+ auto EmitOffsetFromBase = [&](ArrayRef<GEPOperator *> GEPs) -> Value * {
+ Value *Sum = nullptr;
+ for (GEPOperator *GEP : reverse(GEPs)) {
+ Value *Offset = EmitGEPOffset(GEP, RewriteGEPs);
+ if (Sum)
+ Sum = Builder.CreateAdd(Sum, Offset);
+ else
+ Sum = Offset;
+ }
+ if (!Sum)
+ return Constant::getNullValue(IdxTy);
+ return Sum;
+ };
- // Emit the offset of the GEP and an intptr_t.
- GEPNoWrapFlags GEP1NW = GEP1->getNoWrapFlags();
- Value *Result = EmitGEPOffset(GEP1, RewriteGEPs);
+ Value *Result = EmitOffsetFromBase(Base.LHSGEPs);
+ Value *Offset2 = EmitOffsetFromBase(Base.RHSGEPs);
// If this is a single inbounds GEP and the original sub was nuw,
// then the final multiplication is also nuw.
if (auto *I = dyn_cast<Instruction>(Result))
- if (IsNUW && !GEP2 && !Swapped && GEP1NW.isInBounds() &&
+ if (IsNUW && match(Offset2, m_Zero()) && Base.LHSNW.isInBounds() &&
I->getOpcode() == Instruction::Mul)
I->setHasNoUnsignedWrap();
// If we have a 2nd GEP of the same base pointer, subtract the offsets.
// If both GEPs are inbounds, then the subtract does not have signed overflow.
// If both GEPs are nuw and the original sub is nuw, the new sub is also nuw.
- if (GEP2) {
- GEPNoWrapFlags GEP2NW = GEP2->getNoWrapFlags();
- Value *Offset = EmitGEPOffset(GEP2, RewriteGEPs);
- Result = Builder.CreateSub(Result, Offset, "gepdiff",
- IsNUW && GEP1NW.hasNoUnsignedWrap() &&
- GEP2NW.hasNoUnsignedWrap(),
- GEP1NW.isInBounds() && GEP2NW.isInBounds());
- }
-
- // If we have p - gep(p, ...) then we have to negate the result.
- if (Swapped)
- Result = Builder.CreateNeg(Result, "diff.neg");
+ if (!match(Offset2, m_Zero())) {
+ Result =
+ Builder.CreateSub(Result, Offset2, "gepdiff",
+ IsNUW && Base.LHSNW.hasNoUnsignedWrap() &&
+ Base.RHSNW.hasNoUnsignedWrap(),
+ Base.LHSNW.isInBounds() && Base.RHSNW.isInBounds());
+ }
return Builder.CreateIntCast(Result, Ty, true);
}
diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index f5df8573d6304..365c17b35a468 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -506,8 +506,7 @@ define <2 x i1> @test23vec(<2 x i32> %x) {
; unsigned overflow does not happen during offset computation
define i1 @test24_neg_offs(ptr %p, i64 %offs) {
; CHECK-LABEL: @test24_neg_offs(
-; CHECK-NEXT: [[P1_IDX_NEG:%.*]] = mul i64 [[OFFS:%.*]], -4
-; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[P1_IDX_NEG]], 8
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[OFFS:%.*]], -2
; CHECK-NEXT: ret i1 [[CMP]]
;
%p1 = getelementptr inbounds i32, ptr %p, i64 %offs
diff --git a/llvm/test/Transforms/InstCombine/sub-gep.ll b/llvm/test/Transforms/InstCombine/sub-gep.ll
index c86a1a37bd7ad..4a27b04f724d8 100644
--- a/llvm/test/Transforms/InstCombine/sub-gep.ll
+++ b/llvm/test/Transforms/InstCombine/sub-gep.ll
@@ -80,7 +80,7 @@ define i32 @test_inbounds_nuw_trunc(ptr %base, i64 %idx) {
define i64 @test_inbounds_nuw_swapped(ptr %base, i64 %idx) {
; CHECK-LABEL: @test_inbounds_nuw_swapped(
-; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul i64 [[IDX:%.*]], -4
+; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul nsw i64 [[IDX:%.*]], -4
; CHECK-NEXT: ret i64 [[P2_IDX_NEG]]
;
%p2 = getelementptr inbounds [0 x i32], ptr %base, i64 0, i64 %idx
@@ -104,7 +104,7 @@ define i64 @test_inbounds1_nuw_swapped(ptr %base, i64 %idx) {
define i64 @test_inbounds2_nuw_swapped(ptr %base, i64 %idx) {
; CHECK-LABEL: @test_inbounds2_nuw_swapped(
-; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul i64 [[IDX:%.*]], -4
+; CHECK-NEXT: [[P2_IDX_NEG:%.*]] = mul nsw i64 [[IDX:%.*]], -4
; CHECK-NEXT: ret i64 [[P2_IDX_NEG]]
;
%p2 = getelementptr inbounds [0 x i32], ptr %base, i64 0, i64 %idx
@@ -279,8 +279,8 @@ define i16 @test24_as1(ptr addrspace(1) %P, i16 %A) {
define i64 @test24a(ptr %P, i64 %A){
; CHECK-LABEL: @test24a(
-; CHECK-NEXT: [[DIFF_NEG:%.*]] = sub i64 0, [[A:%.*]]
-; CHECK-NEXT: ret i64 [[DIFF_NEG]]
+; CHECK-NEXT: [[GEPDIFF:%.*]] = sub nsw i64 0, [[A:%.*]]
+; CHECK-NEXT: ret i64 [[GEPDIFF]]
;
%B = getelementptr inbounds i8, ptr %P, i64 %A
%C = ptrtoint ptr %B to i64
@@ -291,8 +291,8 @@ define i64 @test24a(ptr %P, i64 %A){
define i16 @test24a_as1(ptr addrspace(1) %P, i16 %A) {
; CHECK-LABEL: @test24a_as1(
-; CHECK-NEXT: [[DIFF_NEG:%.*]] = sub i16 0, [[A:%.*]]
-; CHECK-NEXT: ret i16 [[DIFF_NEG]]
+; CHECK-NEXT: [[GEPDIFF:%.*]] = sub nsw i16 0, [[A:%.*]]
+; CHECK-NEXT: ret i16 [[GEPDIFF]]
;
%B = getelementptr inbounds i8, ptr addrspace(1) %P, i16 %A
%C = ptrtoint ptr addrspace(1) %B to i16
@@ -860,3 +860,85 @@ _Z3fooPKc.exit:
%tobool = icmp eq i64 %2, 0
ret i1 %tobool
}
+
+define i64 @multiple_geps_one_chain(ptr %base, i64 %idx, i64 %idx2) {
+; CHECK-LABEL: @multiple_geps_one_chain(
+; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
+; CHECK-NEXT: [[D:%.*]] = shl i64 [[P2_IDX1]], 2
+; CHECK-NEXT: ret i64 [[D]]
+;
+ %p2 = getelementptr inbounds i32, ptr %base, i64 %idx
+ %p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
+ %i1 = ptrtoint ptr %base to i64
+ %i2 = ptrtoint ptr %p3 to i64
+ %d = sub i64 %i2, %i1
+ ret i64 %d
+}
+
+define i64 @multiple_geps_one_chain_commuted(ptr %base, i64 %idx, i64 %idx2) {
+; CHECK-LABEL: @multiple_geps_one_chain_commuted(
+; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
+; CHECK-NEXT: [[DOTNEG:%.*]] = mul i64 [[P2_IDX1]], -4
+; CHECK-NEXT: ret i64 [[DOTNEG]]
+;
+ %p2 = getelementptr inbounds i32, ptr %base, i64 %idx
+ %p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
+ %i1 = ptrtoint ptr %base to i64
+ %i2 = ptrtoint ptr %p3 to i64
+ %d = sub i64 %i1, %i2
+ ret i64 %d
+}
+
+define i64 @multiple_geps_two_chains(ptr %base, i64 %idx, i64 %idx2, i64 %idx3) {
+; CHECK-LABEL: @multiple_geps_two_chains(
+; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[P2_IDX1]], [[IDX3:%.*]]
+; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
+; CHECK-NEXT: ret i64 [[GEPDIFF]]
+;
+ %p2 = getelementptr inbounds i32, ptr %base, i64 %idx
+ %p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
+ %p4 = getelementptr inbounds i32, ptr %base, i64 %idx3
+ %i1 = ptrtoint ptr %p4 to i64
+ %i2 = ptrtoint ptr %p3 to i64
+ %d = sub i64 %i2, %i1
+ ret i64 %d
+}
+
+define i64 @multiple_geps_two_chains_commuted(ptr %base, i64 %idx, i64 %idx2, i64 %idx3) {
+; CHECK-LABEL: @multiple_geps_two_chains_commuted(
+; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[IDX3:%.*]], [[P2_IDX1]]
+; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
+; CHECK-NEXT: ret i64 [[GEPDIFF]]
+;
+ %p2 = getelementptr inbounds i32, ptr %base, i64 %idx
+ %p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
+ %p4 = getelementptr inbounds i32, ptr %base, i64 %idx3
+ %i1 = ptrtoint ptr %p4 to i64
+ %i2 = ptrtoint ptr %p3 to i64
+ %d = sub i64 %i1, %i2
+ ret i64 %d
+}
+
+declare void @use(ptr)
+
+define i64 @multiple_geps_two_chains_gep_base(ptr %base, i64 %base.idx, i64 %idx, i64 %idx2, i64 %idx3) {
+; CHECK-LABEL: @multiple_geps_two_chains_gep_base(
+; CHECK-NEXT: [[GEP_BASE:%.*]] = getelementptr inbounds i32, ptr [[BASE:%.*]], i64 [[BASE_IDX:%.*]]
+; CHECK-NEXT: call void @use(ptr [[GEP_BASE]])
+; CHECK-NEXT: [[P2_IDX1:%.*]] = add i64 [[IDX:%.*]], [[IDX2:%.*]]
+; CHECK-NEXT: [[TMP1:%.*]] = sub i64 [[P2_IDX1]], [[IDX3:%.*]]
+; CHECK-NEXT: [[GEPDIFF:%.*]] = shl i64 [[TMP1]], 2
+; CHECK-NEXT: ret i64 [[GEPDIFF]]
+;
+ %gep.base = getelementptr inbounds i32, ptr %base, i64 %base.idx
+ call void @use(ptr %gep.base)
+ %p2 = getelementptr inbounds i32, ptr %gep.base, i64 %idx
+ %p3 = getelementptr inbounds i32, ptr %p2, i64 %idx2
+ %p4 = getelementptr inbounds i32, ptr %gep.base, i64 %idx3
+ %i1 = ptrtoint ptr %p4 to i64
+ %i2 = ptrtoint ptr %p3 to i64
+ %d = sub i64 %i2, %i1
+ ret i64 %d
+}
More information about the llvm-commits
mailing list