[llvm] [InstCombine] Support nested GEPs in OptimizePointerDifference (PR #142958)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 6 01:53:12 PDT 2025


================
@@ -2068,71 +2068,118 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
   return nullptr;
 }
 
+namespace {
+/// Result of computeCommonBase(): a pointer that both operands of a pointer
+/// difference are GEP-based on, together with the GEP chains leading from
+/// each operand down to it. Ptr is null when no common base was found.
+///
+/// Kept in an anonymous namespace: this is a file-local helper type, and a
+/// global-scope name would risk an ODR clash (via its implicitly-defined
+/// inline special members) with an unrelated CommonBase in another TU.
+struct CommonBase {
+  /// Common base pointer, or null if none was found.
+  Value *Ptr = nullptr;
+  /// GEPs walked from the LHS operand (first element) down to, but
+  /// excluding, the common base.
+  SmallVector<GEPOperator *> LHSGEPs;
+  /// GEPs walked from the RHS operand (first element) down to, but
+  /// excluding, the common base.
+  SmallVector<GEPOperator *> RHSGEPs;
+  /// NoWrapFlags shared by all LHSGEPs (AND-ed together; all() if empty).
+  GEPNoWrapFlags LHSNW = GEPNoWrapFlags::all();
+  /// NoWrapFlags shared by all RHSGEPs (AND-ed together; all() if empty).
+  GEPNoWrapFlags RHSNW = GEPNoWrapFlags::all();
+};
+} // namespace
+
+/// Find the first pointer on RHS's GEP chain (RHS itself, its GEP pointer
+/// operand, and so on) that also lies on LHS's GEP chain.
+///
+/// On success the returned CommonBase has a non-null Ptr. LHSGEPs/RHSGEPs hold
+/// the GEPs passed on each side, ordered from the operand towards the base.
+/// LHSNW/RHSNW hold the AND of those GEPs' nowrap flags. If the operand types
+/// differ, the chains never meet, or the meeting point has a different type
+/// than the operands, Ptr is left null and the other fields are not meaningful.
+static CommonBase computeCommonBase(Value *LHS, Value *RHS) {
+  CommonBase Common;
+
+  // Bail out early if the two pointers are not even of the same type.
+  if (LHS->getType() != RHS->getType())
+    return Common;
+
+  // Remember LHS and every pointer reachable from it through GEP pointer
+  // operands.
+  SmallPtrSet<Value *, 16> LHSChain;
+  for (Value *Walk = LHS;;) {
+    LHSChain.insert(Walk);
+    auto *WalkGEP = dyn_cast<GEPOperator>(Walk);
+    if (!WalkGEP)
+      break;
+    Walk = WalkGEP->getPointerOperand();
+  }
+
+  // Climb the RHS chain until it meets the LHS chain, recording every RHS GEP
+  // passed on the way.
+  Value *Meet = RHS;
+  while (!LHSChain.contains(Meet)) {
+    auto *MeetGEP = dyn_cast<GEPOperator>(Meet);
+    if (!MeetGEP)
+      return Common; // Chains never meet: no common base.
+    Common.RHSGEPs.push_back(MeetGEP);
+    Common.RHSNW &= MeetGEP->getNoWrapFlags();
+    Meet = MeetGEP->getPointerOperand();
+  }
+
+  // A GEP's result type may differ from its pointer operand's type (e.g. a
+  // scalar base with vector indices), so reject a meeting point whose type
+  // does not match the operands.
+  if (Meet->getType() != LHS->getType())
+    return Common;
+  Common.Ptr = Meet;
+
+  // Record the LHS GEPs between LHS and the base. The base was found on the
+  // LHS chain, so every step before reaching it is a GEP.
+  for (Value *Step = LHS; Step != Common.Ptr;) {
+    auto *StepGEP = cast<GEPOperator>(Step);
+    Common.LHSGEPs.push_back(StepGEP);
+    Common.LHSNW &= StepGEP->getNoWrapFlags();
+    Step = StepGEP->getPointerOperand();
+  }
+
+  return Common;
+}
+
 /// Optimize pointer differences into the same array into a size.  Consider:
 ///  &A[10] - &A[0]: we should compile this to "10".  LHS/RHS are the pointer
 /// operands to the ptrtoint instructions for the LHS/RHS of the subtract.
 Value *InstCombinerImpl::OptimizePointerDifference(Value *LHS, Value *RHS,
                                                    Type *Ty, bool IsNUW) {
-  // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize
-  // this.
-  bool Swapped = false;
-  GEPOperator *GEP1 = nullptr, *GEP2 = nullptr;
-  if (!isa<GEPOperator>(LHS) && isa<GEPOperator>(RHS)) {
-    std::swap(LHS, RHS);
-    Swapped = true;
-  }
-
-  // Require at least one GEP with a common base pointer on both sides.
-  if (auto *LHSGEP = dyn_cast<GEPOperator>(LHS)) {
-    // (gep X, ...) - X
-    if (LHSGEP->getOperand(0)->stripPointerCasts() ==
-        RHS->stripPointerCasts()) {
-      GEP1 = LHSGEP;
-    } else if (auto *RHSGEP = dyn_cast<GEPOperator>(RHS)) {
-      // (gep X, ...) - (gep X, ...)
-      if (LHSGEP->getOperand(0)->stripPointerCasts() ==
-          RHSGEP->getOperand(0)->stripPointerCasts()) {
-        GEP1 = LHSGEP;
-        GEP2 = RHSGEP;
-      }
-    }
-  }
-
-  if (!GEP1)
+  CommonBase Base = computeCommonBase(LHS, RHS);
+  if (!Base.Ptr)
     return nullptr;
 
   // To avoid duplicating the offset arithmetic, rewrite the GEP to use the
-  // computed offset. This may erase the original GEP, so be sure to cache the
-  // nowrap flags before emitting the offset.
+  // computed offset.
   // TODO: We should probably do this even if there is only one GEP.
-  bool RewriteGEPs = GEP2 != nullptr;
+  bool RewriteGEPs = !Base.LHSGEPs.empty() && !Base.RHSGEPs.empty();
+
+  Type *IdxTy = DL.getIndexType(Base.Ptr->getType());
+  auto EmitOffsetFromBase = [&](ArrayRef<GEPOperator *> GEPs) -> Value * {
+    Value *Sum = nullptr;
+    for (GEPOperator *GEP : reverse(GEPs)) {
+      Value *Offset = EmitGEPOffset(GEP, RewriteGEPs);
+      if (Sum)
+        Sum = Builder.CreateAdd(Sum, Offset);
+      else
+        Sum = Offset;
+    }
+    if (!Sum)
+      return Constant::getNullValue(IdxTy);
+    return Sum;
+  };
 
-  // Emit the offset of the GEP and an intptr_t.
-  GEPNoWrapFlags GEP1NW = GEP1->getNoWrapFlags();
-  Value *Result = EmitGEPOffset(GEP1, RewriteGEPs);
+  Value *Result = EmitOffsetFromBase(Base.LHSGEPs);
+  Value *Offset2 = EmitOffsetFromBase(Base.RHSGEPs);
 
   // If this is a single inbounds GEP and the original sub was nuw,
   // then the final multiplication is also nuw.
   if (auto *I = dyn_cast<Instruction>(Result))
-    if (IsNUW && !GEP2 && !Swapped && GEP1NW.isInBounds() &&
+    if (IsNUW && match(Offset2, m_Zero()) && Base.LHSNW.isInBounds() &&
----------------
nikic wrote:

(Based on the proof, I assume this is commenting on the isInBounds use below, not this one.)

I think it's easier to understand if you consider the `(ptradd(p, a) - ptradd(p, b))` case. With nusw, if p is sitting in the middle of the address space, you could have a as a large positive value and b as a large negative one, with overflow if you subtract them. With inbounds, it is guaranteed that the distance between a and b cannot exceed half the address space.

https://github.com/llvm/llvm-project/pull/142958


More information about the llvm-commits mailing list