[llvm] Add a pass "SinkGEPConstOffset" (PR #140657)
via llvm-commits
llvm-commits at lists.llvm.org
Mon May 26 20:40:00 PDT 2025
================
@@ -0,0 +1,219 @@
+//===- SinkGEPConstOffset.cpp -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/SinkGEPConstOffset.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <cassert>
+#include <cstdint>
+#include <string>
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+namespace {
+
+/// A pass that tries to sink constant offsets in a GEP chain to the tail of
+/// the chain. It is a FunctionPass because searching for the constant offset
+/// may inspect other basic blocks.
+class SinkGEPConstOffsetLegacyPass : public FunctionPass {
+public:
+ // Pass identifier handed to the FunctionPass base constructor below; it is
+ // defined (initialized to 0) right after the anonymous namespace.
+ static char ID;
+
+ // Constructs the pass and runs its registration hook against the global
+ // PassRegistry.
+ SinkGEPConstOffsetLegacyPass() : FunctionPass(ID) {
+ initializeSinkGEPConstOffsetLegacyPassPass(
+ *PassRegistry::getPassRegistry());
+ }
+
+ // Declares the analyses this pass needs: the dominator tree is required,
+ // and the pass reports that it preserves the CFG.
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.setPreservesCFG();
+ }
+
+ // Per-function entry point; only declared here, the definition is outside
+ // this class body. NOTE(review): presumably returns true when the function
+ // was modified, per the usual FunctionPass contract -- confirm.
+ bool runOnFunction(Function &F) override;
+};
+
+} // end anonymous namespace
+
+char SinkGEPConstOffsetLegacyPass::ID = 0;
+
+// Registers the legacy pass under the name "sink-gep-const-offset".
+// The dependency list mirrors getAnalysisUsage(): only
+// DominatorTreeWrapperPass is required, so no other analysis is listed here.
+// NOTE(review): the two trailing `false` arguments are presumably the
+// "CFG-only" and "is-analysis" flags of the INITIALIZE_PASS macros -- confirm.
+INITIALIZE_PASS_BEGIN(
+ SinkGEPConstOffsetLegacyPass, "sink-gep-const-offset",
+ "Sink const offsets down the GEP chain to the tail for reduction of "
+ "register usage", false, false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_END(
+ SinkGEPConstOffsetLegacyPass, "sink-gep-const-offset",
+ "Sink const offsets down the GEP chain to the tail for reduction of "
+ "register usage", false, false)
+
+/// Factory for the legacy pass: returns a newly heap-allocated
+/// SinkGEPConstOffsetLegacyPass.
+/// NOTE(review): ownership of the returned pointer is presumably taken by the
+/// caller (typically a legacy pass manager) -- confirm.
+FunctionPass *llvm::createSinkGEPConstOffsetPass() {
+ return new SinkGEPConstOffsetLegacyPass();
+}
+
+/// Sinks the constant offset of the base GEP into the current GEP, so that
+/// the constant ends up at the tail of the GEP chain.
+///
+/// A simple example is given:
+///
+/// %gep0 = getelementptr half, ptr addrspace(3) %ptr, i32 512
+/// %gep1 = getelementptr half, ptr addrspace(3) %gep0, i32 %ofst0
+/// %data = load half, ptr addrspace(3) %gep1, align 2
+/// ==>
+/// %gep0 = getelementptr half, ptr addrspace(3) %ptr, i32 %ofst0
+/// %gep1 = getelementptr half, ptr addrspace(3) %gep0, i32 512
+/// %data = load half, ptr addrspace(3) %gep1, align 2
+static bool sinkGEPConstantOffset(Value *Ptr, const DataLayout *DL) {
+ GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+ if (!GEP)
+ return false;
+
+ if (GEP->getNumIndices() != 1)
+ return false;
+
+ GetElementPtrInst *BaseGEP =
+ dyn_cast<GetElementPtrInst>(GEP->getPointerOperand());
+ if (!BaseGEP)
+ return false;
+
+ if (BaseGEP->getNumIndices() != 1)
+ return false;
+
+ Value *Idx = GEP->getOperand(1);
+ Value *BaseIdx = BaseGEP->getOperand(1);
+
+ ConstantInt *C = nullptr;
+ if (!match(BaseIdx, m_ConstantInt(C)))
+ return false;
+
+ Type *ResTy = GEP->getResultElementType();
+ Type *BaseResTy = BaseGEP->getResultElementType();
+
+ if (match(Idx, m_ConstantInt(C))) {
+ // %gep0 = getelementptr half, ptr addrspace(3) %ptr, i32 8
+ // %gep1 = getelementptr half, ptr addrspace(3) %gep0, i32 4
+ // as:
+ // %gep1 = getelementptr half, ptr addrspace(3) %ptr, i32 12
+ Type *NewResTy = nullptr;
+ int64_t NewIdxValue = 0;
+ if (ResTy == BaseResTy) {
+ NewResTy = ResTy;
+ NewIdxValue = cast<ConstantInt>(BaseIdx)->getSExtValue() +
+ cast<ConstantInt>(Idx)->getSExtValue();
+ } else {
+ NewResTy = Type::getInt8Ty(GEP->getContext());
+ NewIdxValue = (cast<ConstantInt>(BaseIdx)->getSExtValue() *
+ DL->getTypeAllocSize(BaseResTy)) +
+ (cast<ConstantInt>(Idx)->getSExtValue() *
+ DL->getTypeAllocSize(ResTy));
+ }
+ assert(NewResTy);
+ Type *NewIdxType = (Idx->getType()->getPrimitiveSizeInBits() >
+ BaseIdx->getType()->getPrimitiveSizeInBits())
+ ? Idx->getType() : BaseIdx->getType();
+ Constant *NewIdx = ConstantInt::get(NewIdxType, NewIdxValue);
+ auto *NewGEP = GetElementPtrInst::Create(
+ NewResTy, BaseGEP->getPointerOperand(), NewIdx);
+ NewGEP->setIsInBounds(GEP->isInBounds());
+ NewGEP->insertBefore(GEP->getIterator());
+ NewGEP->takeName(GEP);
+
+ GEP->replaceAllUsesWith(NewGEP);
+ RecursivelyDeleteTriviallyDeadInstructions(GEP);
+
+ return true;
+ }
+
+ // %gep0 = getelementptr half, ptr addrspace(3) %ptr, i32 8
+ // %gep1 = getelementptr half, ptr addrspace(3) %gep0, i32 %idx
+ // as:
+ // %gepx0 = getelementptr half, ptr addrspace(3) %ptr, i32 %idx
+ // %gepx1 = getelementptr half, ptr addrspace(3) %gepx0, i32 8
+ auto *GEPX0 =
+ GetElementPtrInst::Create(ResTy, BaseGEP->getPointerOperand(), Idx);
+ GEPX0->setIsInBounds(BaseGEP->isInBounds());
----------------
StevenYangCC wrote:
I have made changes based on your suggestions; please verify the results.
https://github.com/llvm/llvm-project/pull/140657
More information about the llvm-commits
mailing list