[llvm] Feat/sink gep constant offset (PR #140027)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Thu May 15 13:19:21 PDT 2025


================
@@ -1344,6 +1362,133 @@ bool SeparateConstOffsetFromGEP::reuniteExts(Function &F) {
   return Changed;
 }
 
+bool SeparateConstOffsetFromGEP::sinkGEPConstantOffset(Value *Ptr,
+                                                       bool &Changed) {
+  // Sink the constant offsets of a chain of single-index GEPs to the tail of
+  // the chain, i.e. into the GEP closest to the user of the address.
+  //
+  // The chain is walked recursively: the base pointer of a GEP is processed
+  // first, and on the way back each GEP whose base ends in a constant index
+  // either folds that constant into its own constant index or swaps places
+  // with it, moving the constant one step further down the chain.
+  // A simple example is given:
+  /// %gep0 = getelementptr half, ptr addrspace(3) %ptr, i32 512
+  /// %gep1 = getelementptr half, ptr addrspace(3) %gep0, i32 %ofst0
+  /// %gep2 = getelementptr half, ptr addrspace(3) %gep1, i32 %ofst1
+  /// %data = load half, ptr addrspace(3) %gep2, align 2
+  /// ==>
+  /// %gep0 = getelementptr half, ptr addrspace(3) %ptr, i32 %ofst0
+  /// %gep1 = getelementptr half, ptr addrspace(3) %gep0, i32 %ofst1
+  /// %gep2 = getelementptr half, ptr addrspace(3) %gep1, i32 512
+  /// %data = load half, ptr addrspace(3) %gep2, align 2
+  //
+  // \p Ptr is the address to process; values that are not GEPs are ignored.
+  // \p Changed is set to true whenever the IR is modified; it is never reset.
+  // Returns true when the chain now ends in a single-index GEP with a constant
+  // index (either \p Ptr already was one, or it has been rewritten into one).
+  // A rewrite erases the original GEP, so \p Ptr must not be used afterwards.
+  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!GEP)
+    return false;
+
+  // Process the base first. This may replace GEP's pointer operand, so the
+  // operand is re-read below instead of being cached here.
+  bool BaseEndsInConstant =
+      sinkGEPConstantOffset(GEP->getPointerOperand(), Changed);
+
+  if (GEP->getNumIndices() != 1)
+    return false;
+
+  Value *Idx = GEP->getOperand(1);
+  auto *ConstIdx = dyn_cast<ConstantInt>(Idx);
+
+  if (!BaseEndsInConstant)
+    return ConstIdx != nullptr;
+
+  // A true result for the base is only produced for a single-index GEP, so
+  // the cast cannot fail and operand 1 is its only index.
+  auto *BaseGEP = cast<GetElementPtrInst>(GEP->getPointerOperand());
+  Value *BaseIdx = BaseGEP->getOperand(1);
+  Type *ResTy = GEP->getResultElementType();
+  Type *BaseResTy = BaseGEP->getResultElementType();
+
+  if (ConstIdx) {
+    // %gep0 = getelementptr half, ptr addrspace(3) %ptr, i32 8
+    // %gep1 = getelementptr half, ptr addrspace(3) %gep0, i32 4
+    // as:
+    // %gep1 = getelementptr half, ptr addrspace(3) %ptr, i32 12
+    LLVMContext &Ctx = GEP->getContext();
+    int64_t BaseIdxValue = cast<ConstantInt>(BaseIdx)->getSExtValue();
+    int64_t IdxValue = ConstIdx->getSExtValue();
+
+    // Same element type: the indices can simply be added. Otherwise fall back
+    // to a byte-wise (i8) GEP using the sum of the two byte offsets.
+    Type *NewResTy = ResTy;
+    int64_t NewIdxValue = BaseIdxValue + IdxValue;
+    if (ResTy != BaseResTy) {
+      TypeSize BaseSize = DL->getTypeAllocSize(BaseResTy);
+      TypeSize Size = DL->getTypeAllocSize(ResTy);
+      // The byte offset of a scalable type is not a compile-time constant, so
+      // the two GEPs cannot be folded. GEP still ends in a constant index.
+      if (BaseSize.isScalable() || Size.isScalable())
+        return true;
+      NewResTy = Type::getInt8Ty(Ctx);
+      NewIdxValue =
+          BaseIdxValue * static_cast<int64_t>(BaseSize.getFixedValue()) +
+          IdxValue * static_cast<int64_t>(Size.getFixedValue());
+    }
+
+    // Use an i32 index when the value fits and an i64 index otherwise.
+    bool FitsInI32 = NewIdxValue >= std::numeric_limits<int32_t>::min() &&
+                     NewIdxValue <= std::numeric_limits<int32_t>::max();
+    Type *NewIdxType =
+        FitsInI32 ? Type::getInt32Ty(Ctx) : Type::getInt64Ty(Ctx);
+    // The offset may be negative, so it has to be created as a signed value.
+    Constant *NewIdx =
+        ConstantInt::get(NewIdxType, NewIdxValue, /*IsSigned=*/true);
+
+    auto *NewGEP = GetElementPtrInst::Create(
+        NewResTy, BaseGEP->getPointerOperand(), NewIdx);
+    // The folded GEP starts at BaseGEP's base pointer. That pointer is only
+    // known to be in bounds if BaseGEP was inbounds, and the final address
+    // only if GEP was, so both flags are required.
+    NewGEP->setIsInBounds(GEP->isInBounds() && BaseGEP->isInBounds());
+    NewGEP->insertBefore(GEP->getIterator());
+    NewGEP->takeName(GEP);
+
+    GEP->replaceAllUsesWith(NewGEP);
+    GEP->eraseFromParent();
+
+    Changed = true;
+    return true;
+  }
+
+  // %gep0 = getelementptr half, ptr addrspace(3) %ptr, i32 8
+  // %gep1 = getelementptr half, ptr addrspace(3) %gep0, i32 %idx
+  // as:
+  // %gepx0 = getelementptr half, ptr addrspace(3) %ptr, i32 %idx
+  // %gepx1 = getelementptr half, ptr addrspace(3) %gepx0, i32 8
+  //
+  // The new intermediate address %ptr + %idx is not computed by the original
+  // code and may lie outside the underlying object even when %ptr + 8 and
+  // %ptr + 8 + %idx do not, so inbounds is not carried over to the new GEPs.
+  auto *GEPX0 =
+      GetElementPtrInst::Create(ResTy, BaseGEP->getPointerOperand(), Idx);
+  GEPX0->insertBefore(GEP->getIterator());
+  auto *GEPX1 = GetElementPtrInst::Create(BaseResTy, GEPX0, BaseIdx);
+  GEPX1->insertBefore(GEP->getIterator());
+  GEPX1->takeName(GEP);
+
+  GEP->replaceAllUsesWith(GEPX1);
+  GEP->eraseFromParent();
+
+  Changed = true;
+  return true;
+}
+
+bool SeparateConstOffsetFromGEP::sinkGEPConstantOffset(Function &F) {
+  bool Changed = false;
+  SmallVector<Value *, 4> Candidates;
+  for (BasicBlock &B : F) {
+    for (Instruction &I : B) {
+      Value *Ptr = nullptr;
+      if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
+        Ptr = LI->getPointerOperand();
+      } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
+        Ptr = SI->getPointerOperand();
+      }
+      if (Ptr)
----------------
arsenm wrote:

Missing atomics and intrinsics, and all the usual pointer users. Do you really need to consider the use types, or can you keep this to purely operating with pointer expressions?

https://github.com/llvm/llvm-project/pull/140027


More information about the llvm-commits mailing list