[llvm-commits] [llvm] r47776 - in /llvm/trunk: include/llvm/Transforms/IPO.h lib/Transforms/IPO/StructRetPromotion.cpp

Devang Patel dpatel at apple.com
Fri Feb 29 15:34:12 PST 2008


Author: dpatel
Date: Fri Feb 29 17:34:08 2008
New Revision: 47776

URL: http://llvm.org/viewvc/llvm-project?rev=47776&view=rev
Log:
Add pass to promote sret.
This pass transforms 

  %struct._Point = type { i32, i32, i32, i32, i32, i32 }
  define internal void @foo(%struct._Point* sret  %agg.result)

into

  %struct._Point = type { i32, i32, i32, i32, i32, i32 }
  define internal %struct._Point @foo()

This pass updates the callers of foo() to use the
getresult instruction to extract the return values.

This pass is not yet ready for prime time.

Added:
    llvm/trunk/lib/Transforms/IPO/StructRetPromotion.cpp
Modified:
    llvm/trunk/include/llvm/Transforms/IPO.h

Modified: llvm/trunk/include/llvm/Transforms/IPO.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Transforms/IPO.h?rev=47776&r1=47775&r2=47776&view=diff

==============================================================================
--- llvm/trunk/include/llvm/Transforms/IPO.h (original)
+++ llvm/trunk/include/llvm/Transforms/IPO.h Fri Feb 29 17:34:08 2008
@@ -125,6 +125,7 @@
 /// be passed by value.
 ///
 Pass *createArgumentPromotionPass();
+/// createStructRetPromotionPass - This pass promotes "sret" pointer
+/// arguments of internal functions into first-class (multiple value)
+/// return values, and updates the callers accordingly.
+///
+Pass *createStructRetPromotionPass();
 
 //===----------------------------------------------------------------------===//
 /// createIPConstantPropagationPass - This pass propagates constants from call

Added: llvm/trunk/lib/Transforms/IPO/StructRetPromotion.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/IPO/StructRetPromotion.cpp?rev=47776&view=auto

==============================================================================
--- llvm/trunk/lib/Transforms/IPO/StructRetPromotion.cpp (added)
+++ llvm/trunk/lib/Transforms/IPO/StructRetPromotion.cpp Fri Feb 29 17:34:08 2008
@@ -0,0 +1,292 @@
+//===-- StructRetPromotion.cpp - Promote sret arguments -------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass promotes "sret" (struct return) arguments into first-class
+// return values.  It looks for internal functions whose first argument is a
+// pointer marked sret, and rewrites them so that they return the
+// structure's fields as multiple return values instead of storing them
+// through the pointer.  For example,
+//
+//   define internal void @foo(%struct._Point* sret %agg.result)
+//
+// becomes
+//
+//   define internal %struct._Point @foo()
+//
+// Call sites are rewritten to use getresult instructions in place of the
+// loads that previously read the result out of the caller's temporary.  A
+// function is only promoted when every caller passes a local alloca as the
+// sret argument and otherwise only loads fields from it.
+//
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "sretpromotion"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Constants.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Module.h"
+#include "llvm/CallGraphSCCPass.h"
+#include "llvm/Instructions.h"
+#include "llvm/Analysis/CallGraph.h"
+#include "llvm/Support/CallSite.h"
+#include "llvm/Support/CFG.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Compiler.h"
+using namespace llvm;
+
+namespace {
+  /// SRETPromotion - Call-graph SCC pass that rewrites internal functions
+  /// taking an "sret" pointer as their first argument so that they return
+  /// the pointed-to structure as multiple return values instead, and then
+  /// rewrites their call sites to match.
+  ///
+  struct VISIBILITY_HIDDEN SRETPromotion : public CallGraphSCCPass {
+    static char ID; // Pass identification, replacement for typeid
+    SRETPromotion() : CallGraphSCCPass((intptr_t)&ID) {}
+
+    /// runOnSCC - Try to promote every function in the given SCC.
+    virtual bool runOnSCC(const std::vector<CallGraphNode *> &SCC);
+
+    /// getAnalysisUsage - Requires nothing beyond what the CallGraphSCCPass
+    /// base class requests.
+    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+      CallGraphSCCPass::getAnalysisUsage(AU);
+    }
+
+  private:
+    /// PromoteReturn - Rewrite the function of one call graph node.
+    bool PromoteReturn(CallGraphNode *CGN);
+    /// isSafeToUpdateAllCallers - Check every caller can be rewritten.
+    bool isSafeToUpdateAllCallers(Function *F);
+    /// cloneFunctionBody - Build the sret-less replacement of F.
+    Function *cloneFunctionBody(Function *F, const StructType *STy);
+    /// updateCallSites - Redirect calls of F to NF.
+    void updateCallSites(Function *F, Function *NF);
+  };
+
+  char SRETPromotion::ID = 0;
+  RegisterPass<SRETPromotion> X("sretpromotion",
+                               "Promote sret arguments to multiple ret values");
+} // end anonymous namespace
+
+/// createStructRetPromotionPass - Public factory (declared in IPO.h) that
+/// hands out a fresh instance of the sret promotion pass.
+Pass *llvm::createStructRetPromotionPass() {
+  SRETPromotion *Promoter = new SRETPromotion();
+  return Promoter;
+}
+
+/// runOnSCC - Attempt sret promotion on each node of the SCC.  Every node is
+/// visited even after a successful promotion; the result is true if any
+/// function was changed.
+bool SRETPromotion::runOnSCC(const std::vector<CallGraphNode *> &SCC) {
+  bool Changed = false;
+  for (std::vector<CallGraphNode *>::const_iterator NI = SCC.begin(),
+         NE = SCC.end(); NI != NE; ++NI)
+    if (PromoteReturn(*NI))
+      Changed = true;
+  return Changed;
+}
+
+/// PromoteReturn - Promote a function that returns its result through a
+/// StructRet (sret) pointer parameter into a function that returns the
+/// structure's fields as multiple return values.  Returns true if the
+/// function was rewritten (F is erased and replaced by a new function).
+bool SRETPromotion::PromoteReturn(CallGraphNode *CGN) {
+  Function *F = CGN->getFunction();
+
+  // Make sure that it is local to this module, so every caller is visible.
+  if (!F || !F->hasInternalLinkage())
+    return false;
+
+  // Make sure that function returns struct.
+  if (F->arg_size() == 0 || !F->isStructReturn() || F->doesNotReturn())
+    return false;
+
+  assert (F->getReturnType() == Type::VoidTy && "Invalid function return type");
+  Function::arg_iterator AI = F->arg_begin();
+  const llvm::PointerType *FArgType = dyn_cast<PointerType>(AI->getType());
+  assert (FArgType && "Invalid sret paramater type");
+  const llvm::StructType *STy = 
+    dyn_cast<StructType>(FArgType->getElementType());
+  assert (STy && "Invalid sret parameter element type");
+
+  // An empty struct has no values to return; building the new ReturnInst
+  // below would take &RetVals[0] of an empty vector.
+  if (STy->getNumElements() == 0)
+    return false;
+
+  // Check if it is ok to perform this promotion.
+  if (isSafeToUpdateAllCallers(F) == false)
+    return false;
+
+  // [1] Replace uses of the sret parameter with a local alloca.
+  AllocaInst *TheAlloca = new AllocaInst (STy, NULL, "mrv", F->getEntryBlock().begin());
+  Value *NFirstArg = F->arg_begin();
+  NFirstArg->replaceAllUsesWith(TheAlloca);
+
+  // [2] Rewrite every 'ret' to load each field of the alloca and return the
+  // loaded values.  BI is advanced before I is erased, and the new
+  // instructions are inserted before I, so none of them is revisited.
+  SmallVector<Value *,4> RetVals;
+  for (Function::iterator FI = F->begin(), FE = F->end();  FI != FE; ++FI) 
+    for(BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; ) {
+      Instruction *I = BI;
+      ++BI;
+      if (isa<ReturnInst>(I)) {
+        RetVals.clear();
+        for (unsigned idx = 0; idx < STy->getNumElements(); ++idx) {
+          SmallVector<Value*, 2> GEPIdx;
+          GEPIdx.push_back(ConstantInt::get(Type::Int32Ty, 0));
+          GEPIdx.push_back(ConstantInt::get(Type::Int32Ty, idx));
+          Value *NGEPI = new GetElementPtrInst(TheAlloca, GEPIdx.begin(), GEPIdx.end(),
+                                               "mrv.gep", I);
+          Value *NV = new LoadInst(NGEPI, "mrv.ld", I);
+          RetVals.push_back(NV);
+        }
+    
+        ReturnInst *NR = new ReturnInst(&RetVals[0], RetVals.size(), I);
+        I->replaceAllUsesWith(NR);
+        I->eraseFromParent();
+      }
+    }
+
+  // Create the new function body and insert it into the module.
+  Function *NF = cloneFunctionBody(F, STy);
+
+  // Update all call sites to use new function
+  updateCallSites(F, NF);
+
+  // Re-point F's call graph node at NF *before* F is deleted, so the call
+  // graph is never handed a pointer to a freed Function.
+  getAnalysis<CallGraph>().changeFunction(F, NF);
+  F->eraseFromParent();
+  return true;
+}
+
+/// isSafeToUpdateAllCallers - Return true if every use of F is a direct call
+/// (not an invoke) whose sret argument is a local alloca.  That alloca may be
+/// used only by the call itself and by single-field GEPs of the form
+/// "gep P, 0, <constant N>" whose only users are loads.  These are exactly
+/// the patterns updateCallSites knows how to rewrite into getresult
+/// instructions.
+bool SRETPromotion::isSafeToUpdateAllCallers(Function *F) {
+
+  if (F->use_empty())
+    // No users. OK to modify signature.
+    return true;
+
+  for (Value::use_iterator FnUseI = F->use_begin(), FnUseE = F->use_end();
+       FnUseI != FnUseE; ++FnUseI) {
+
+    // The use must be a direct call of F.  Other uses (address stored,
+    // passed as an argument, ...) yield either a null CallSite or one whose
+    // callee is not F, and cannot be rewritten.
+    CallSite CS = CallSite::get(*FnUseI);
+    Instruction *Call = CS.getInstruction();
+    if (!Call || CS.getCalledFunction() != F)
+      return false;
+
+    // Only plain call sites are handled; invoke sites are not supported yet.
+    if (!isa<CallInst>(Call))
+      return false;
+
+    CallSite::arg_iterator AI = CS.arg_begin();
+    Value *FirstArg = *AI;
+
+    if (!isa<AllocaInst>(FirstArg))
+      return false;
+
+    // Check FirstArg's users.
+    for (Value::use_iterator ArgI = FirstArg->use_begin(), 
+           ArgE = FirstArg->use_end(); ArgI != ArgE; ++ArgI) {
+
+      // If FirstArg user is a CallInst that does not correspond to current
+      // call site then this function F is not suitable for sret promotion.
+      if (CallInst *CI = dyn_cast<CallInst>(ArgI)) {
+        if (CI != Call)
+          return false;
+      }
+      // A GEP user must address exactly one top-level field
+      // ("gep P, 0, <const N>"), because updateCallSites maps it to
+      // getresult field N, and all of its users must be loads.
+      else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(ArgI)) {
+        if (GEP->getNumOperands() != 3)
+          return false;
+        ConstantInt *FirstIdx = dyn_cast<ConstantInt>(GEP->getOperand(1));
+        ConstantInt *FieldIdx = dyn_cast<ConstantInt>(GEP->getOperand(2));
+        if (!FirstIdx || !FirstIdx->isNullValue() || !FieldIdx)
+          return false;
+
+        for (Value::use_iterator GEPI = GEP->use_begin(), GEPE = GEP->use_end();
+             GEPI != GEPE; ++GEPI) 
+          if (!isa<LoadInst>(GEPI))
+            return false;
+      } 
+      // Any other FirstArg users make this function unsuitable for sret 
+      // promotion.
+      else
+        return false;
+    }
+  }
+
+  return true;
+}
+
+/// cloneFunctionBody - Create a new function based on F and insert it into
+/// the module right before F.  The new function drops F's first (sret)
+/// argument and uses STy as its return type.  F's basic blocks are moved
+/// (not copied) into the new function, uses of F's remaining arguments are
+/// redirected to the new arguments, and the new function takes over F's
+/// name.  F itself is left in place with an empty body; the caller is
+/// expected to erase it.
+Function *SRETPromotion::cloneFunctionBody(Function *F, 
+                                           const StructType *STy) {
+
+  // FIXME : Do not drop param attributes on the floor.
+  const FunctionType *FTy = F->getFunctionType();
+  std::vector<const Type*> Params;
+
+  // Skip first argument.
+  Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
+  ++I;
+  while (I != E) {
+    Params.push_back(I->getType());
+    ++I;
+  }
+
+  FunctionType *NFTy = FunctionType::get(STy, Params, FTy->isVarArg());
+  // Create NF unnamed: F still owns its name here, so giving NF the same
+  // name would make the module rename NF.  takeName moves the name over
+  // exactly.
+  Function *NF = new Function(NFTy, F->getLinkage());
+  NF->setCallingConv(F->getCallingConv());
+  F->getParent()->getFunctionList().insert(F, NF);
+  NF->takeName(F);
+  NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());
+
+  // Replace arguments, skipping F's sret argument (already rewritten by the
+  // caller).
+  I = F->arg_begin();
+  E = F->arg_end();
+  Function::arg_iterator NI = NF->arg_begin();
+  ++I;
+  while (I != E) {
+      I->replaceAllUsesWith(NI);
+      NI->takeName(I);
+      ++I;
+      ++NI;
+  }
+
+  return NF;
+}
+
+/// updateCallSites - Update all sites that call F to use NF.  For each call,
+/// the sret argument is dropped and a new call/invoke of NF is built.  Every
+/// "load (gep sret, 0, N)" in the caller is then replaced by
+/// "getresult NewCall, N".  The old call, the GEPs and the loads are erased.
+void SRETPromotion::updateCallSites(Function *F, Function *NF) {
+
+  // FIXME : Handle parameter attributes
+  SmallVector<Value*, 16> Args;
+
+  for (Value::use_iterator FUI = F->use_begin(), FUE = F->use_end(); FUI != FUE;) {
+    CallSite CS = CallSite::get(*FUI);
+    ++FUI;
+    Instruction *Call = CS.getInstruction();
+
+    // Copy arguments, however skip first one.
+    CallSite::arg_iterator AI = CS.arg_begin(), AE = CS.arg_end();
+    Value *FirstCArg = *AI;
+    ++AI;
+    while (AI != AE) {
+      Args.push_back(*AI); 
+      ++AI;
+    }
+
+    // Build new call instruction.
+    Instruction *New;
+    if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) {
+      New = new InvokeInst(NF, II->getNormalDest(), II->getUnwindDest(),
+                           Args.begin(), Args.end(), "", Call);
+      cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv());
+    } else {
+      New = new CallInst(NF, Args.begin(), Args.end(), "", Call);
+      cast<CallInst>(New)->setCallingConv(CS.getCallingConv());
+      if (cast<CallInst>(Call)->isTailCall())
+        cast<CallInst>(New)->setTailCall();
+    }
+    Args.clear();
+    New->takeName(Call);
+
+    // Update all users of sret parameter to extract value using getresult.
+    for (Value::use_iterator UI = FirstCArg->use_begin(), 
+           UE = FirstCArg->use_end(); UI != UE; ) {
+      User *U2 = *UI++;
+      CallInst *C2 = dyn_cast<CallInst>(U2);
+      if (C2 && (C2 == Call))
+        continue;
+      else if (GetElementPtrInst *UGEP = dyn_cast<GetElementPtrInst>(U2)) {
+        // The GEP is "gep sret, 0, N": extract field N (its second index)
+        // from the new call's results rather than a fixed element.
+        assert (UGEP->getNumOperands() == 3 &&
+                "Unexpected getelementptr shape!");
+        ConstantInt *Idx = dyn_cast<ConstantInt>(UGEP->getOperand(2));
+        assert (Idx && "Unexpected getelementptr index!");
+        // NOTE(review): GR is inserted before the GEP, which is only valid
+        // if the GEP comes after the call; confirm with callers' IR shape.
+        Value *GR = new GetResultInst(New, Idx->getZExtValue(), "xxx", UGEP);
+        // Erasing a load removes it from UGEP's use list, so re-read the
+        // head of the list each time instead of advancing an iterator that
+        // would point into the erased use.  isSafeToUpdateAllCallers has
+        // already checked that every GEP user is a LoadInst.
+        while (!UGEP->use_empty()) {
+          LoadInst *L = cast<LoadInst>(*UGEP->use_begin());
+          L->replaceAllUsesWith(GR);
+          L->eraseFromParent();
+        }
+        UGEP->eraseFromParent();
+      }
+      else assert( 0 && "Unexpected sret parameter use");
+    }
+    Call->eraseFromParent();
+  }
+}





More information about the llvm-commits mailing list