[vmkit-commits] [vmkit] r200162 - New inline pass able to resolve symbols in any compilation unit.

Gael Thomas gael.thomas at lip6.fr
Sun Jan 26 15:07:34 PST 2014


Author: gthomas
Date: Sun Jan 26 17:07:33 2014
New Revision: 200162

URL: http://llvm.org/viewvc/llvm-project?rev=200162&view=rev
Log:
New inline pass able to resolve symbols in any compilation unit.


Modified:
    vmkit/branches/mcjit/include/j3/j3method.h
    vmkit/branches/mcjit/include/vmkit/compiler.h
    vmkit/branches/mcjit/lib/j3/vm/j3method.cc
    vmkit/branches/mcjit/lib/vmkit/inliner.cc

Modified: vmkit/branches/mcjit/include/j3/j3method.h
URL: http://llvm.org/viewvc/llvm-project/vmkit/branches/mcjit/include/j3/j3method.h?rev=200162&r1=200161&r2=200162&view=diff
==============================================================================
--- vmkit/branches/mcjit/include/j3/j3method.h (original)
+++ vmkit/branches/mcjit/include/j3/j3method.h Sun Jan 26 17:07:33 2014
@@ -51,57 +51,59 @@ namespace j3 {
 	public:
 		J3Method(uint16_t access, J3Class* cl, const vmkit::Name* name, J3Signature* signature);
 
-		uint32_t            slot() { return _slot; }
+		vmkit::CompilationUnit* unit();
 
-		static J3Method*    nativeMethod(J3ObjectHandle* handle);
-		J3ObjectHandle*     javaMethod();
+		uint32_t                slot() { return _slot; }
 
-		void*               nativeFnPtr() { return _nativeFnPtr; }
+		static J3Method*        nativeMethod(J3ObjectHandle* handle);
+		J3ObjectHandle*         javaMethod();
 
-		void                markCompiled(llvm::Function* llvmFunction, void* fnPtr);
+		void*                   nativeFnPtr() { return _nativeFnPtr; }
 
-		uint32_t            interfaceIndex();
+		void                    markCompiled(llvm::Function* llvmFunction, void* fnPtr);
 
-		void*               getSymbolAddress();
+		uint32_t                interfaceIndex();
 
-		char*               llvmFunctionName(J3Class* from=0);
-		char*               llvmStubName(J3Class* from=0);
+		void*                   getSymbolAddress();
 
-		void                postInitialise(uint32_t access, J3Attributes* attributes);
-		void                setIndex(uint32_t index); 
+		char*                   llvmFunctionName(J3Class* from=0);
+		char*                   llvmStubName(J3Class* from=0);
 
-		J3Method*           resolve(J3ObjectHandle* obj);
+		void                    postInitialise(uint32_t access, J3Attributes* attributes);
+		void                    setIndex(uint32_t index); 
 
-		uint32_t            index();
+		J3Method*               resolve(J3ObjectHandle* obj);
 
-		J3Attributes*       attributes() const { return _attributes; }
-		uint16_t            access() const { return _access; }
-		J3Class*            cl()     const { return _cl; }
-		const vmkit::Name*  name()   const { return _name; }
-		J3Signature*        signature() const { return _signature; }
+		uint32_t                index();
 
-		void                registerNative(void* ptr);
+		J3Attributes*           attributes() const { return _attributes; }
+		uint16_t                access() const { return _access; }
+		J3Class*                cl()     const { return _cl; }
+		const vmkit::Name*      name()   const { return _name; }
+		J3Signature*            signature() const { return _signature; }
 
-		J3Value             invokeStatic(...);
-		J3Value             invokeStatic(J3Value* args);
-		J3Value             invokeStatic(va_list va);
-		J3Value             invokeSpecial(J3ObjectHandle* obj, ...);
-		J3Value             invokeSpecial(J3ObjectHandle* obj, J3Value* args);
-		J3Value             invokeSpecial(J3ObjectHandle* obj, va_list va);
-		J3Value             invokeVirtual(J3ObjectHandle* obj, ...);
-		J3Value             invokeVirtual(J3ObjectHandle* obj, J3Value* args);
-		J3Value             invokeVirtual(J3ObjectHandle* obj, va_list va);
+		void                    registerNative(void* ptr);
 
-		void                aotSnapshot(llvm::Linker* linker);
-		void                ensureCompiled(uint32_t mode);
+		J3Value                 invokeStatic(...);
+		J3Value                 invokeStatic(J3Value* args);
+		J3Value                 invokeStatic(va_list va);
+		J3Value                 invokeSpecial(J3ObjectHandle* obj, ...);
+		J3Value                 invokeSpecial(J3ObjectHandle* obj, J3Value* args);
+		J3Value                 invokeSpecial(J3ObjectHandle* obj, va_list va);
+		J3Value                 invokeVirtual(J3ObjectHandle* obj, ...);
+		J3Value                 invokeVirtual(J3ObjectHandle* obj, J3Value* args);
+		J3Value                 invokeVirtual(J3ObjectHandle* obj, va_list va);
+
+		void                    aotSnapshot(llvm::Linker* linker);
+		void                    ensureCompiled(uint32_t mode);
 		J3Signature::function_t cxxCaller();
-		void*               fnPtr();
-		llvm::Function*     llvmFunction() { return _llvmFunction; } /* overwrite vmkit::Symbol */
-		uint64_t            inlineWeight();
-		void*               functionPointerOrStaticTrampoline();
-		void*               functionPointerOrVirtualTrampoline();
+		void*                   fnPtr();
+		llvm::Function*         llvmFunction() { return _llvmFunction; } /* overwrite vmkit::Symbol */
+		uint64_t                inlineWeight();
+		void*                   functionPointerOrStaticTrampoline();
+		void*                   functionPointerOrVirtualTrampoline();
 
-		void                dump();
+		void                    dump();
 	};
 }
 

Modified: vmkit/branches/mcjit/include/vmkit/compiler.h
URL: http://llvm.org/viewvc/llvm-project/vmkit/branches/mcjit/include/vmkit/compiler.h?rev=200162&r1=200161&r2=200162&view=diff
==============================================================================
--- vmkit/branches/mcjit/include/vmkit/compiler.h (original)
+++ vmkit/branches/mcjit/include/vmkit/compiler.h Sun Jan 26 17:07:33 2014
@@ -21,23 +21,28 @@ namespace llvm {
 
 namespace vmkit {
 	class VMKit;
+	class CompilationUnit;
 
 	class Symbol : public PermanentObject {
 		uint64_t                cachedWeight;
 	public:
-		virtual void*           getSymbolAddress();
-		virtual llvm::Function* llvmFunction() { return 0; }
-		virtual uint64_t        inlineWeight();
+		virtual void*            getSymbolAddress();
+		virtual llvm::Function*  llvmFunction() { return 0; } /* base implementation: no llvm function */
+		virtual uint64_t         inlineWeight();
+		virtual CompilationUnit* unit() { return 0; } /* base implementation: no compilation unit */
 	};
 
 	class NativeSymbol : public Symbol {
-		llvm::Function* original;
-		void*           addr;
+		llvm::Function*  original; /* returned by llvmFunction() */
+		void*            addr; /* returned by getSymbolAddress() */
 	public:
-		NativeSymbol(llvm::Function* _original, void* _addr) { original = _original; addr = _addr; }
+		NativeSymbol(llvm::Function* _original, void* _addr) { /* records the llvm function and the symbol address */
+			original = _original; 
+			addr = _addr; 
+		}
 
-		llvm::Function* llvmFunction() { return original; }
-		void*           getSymbolAddress() { return addr; }
+		llvm::Function*  llvmFunction() { return original; }
+		void*            getSymbolAddress() { return addr; }
 	};
 
 	class CompilationUnit  : public llvm::SectionMemoryManager {

Modified: vmkit/branches/mcjit/lib/j3/vm/j3method.cc
URL: http://llvm.org/viewvc/llvm-project/vmkit/branches/mcjit/lib/j3/vm/j3method.cc?rev=200162&r1=200161&r2=200162&view=diff
==============================================================================
--- vmkit/branches/mcjit/lib/j3/vm/j3method.cc (original)
+++ vmkit/branches/mcjit/lib/j3/vm/j3method.cc Sun Jan 26 17:07:33 2014
@@ -31,6 +31,10 @@ J3Method::J3Method(uint16_t access, J3Cl
 	_index = -1;
 }
 
+vmkit::CompilationUnit* J3Method::unit() { /* compilation unit of this method: the loader of the class returned by cl() */
+	return cl()->loader();
+}
+
 uint64_t J3Method::inlineWeight() {
 	if(J3Thread::get()->vm()->options()->enableInlining)
 		return vmkit::Symbol::inlineWeight();

Modified: vmkit/branches/mcjit/lib/vmkit/inliner.cc
URL: http://llvm.org/viewvc/llvm-project/vmkit/branches/mcjit/lib/vmkit/inliner.cc?rev=200162&r1=200161&r2=200162&view=diff
==============================================================================
--- vmkit/branches/mcjit/lib/vmkit/inliner.cc (original)
+++ vmkit/branches/mcjit/lib/vmkit/inliner.cc Sun Jan 26 17:07:33 2014
@@ -14,156 +14,325 @@
 #include <dlfcn.h>
 
 namespace vmkit {
-  class FunctionInliner : public llvm::FunctionPass {
-  public:
-    static char ID;
-
-		CompilationUnit*         compiler;
-		llvm::InlineCostAnalysis costAnalysis;
-		unsigned int             inlineThreshold; 		// 225 in llvm
-		bool                     onlyAlwaysInline;
-
-		//FunctionInliner() : FunctionPass(ID) {}
-    FunctionInliner(CompilationUnit* _compiler, unsigned int _inlineThreshold, bool _onlyAlwaysInline) : 
-			FunctionPass(ID) { 
-			compiler = _compiler;
-			inlineThreshold = _inlineThreshold; 
+	class FunctionInliner {
+	public:
+		llvm::Function*                          function;
+		llvm::SmallPtrSet<llvm::BasicBlock*, 32> visited;
+		llvm::SmallVector<std::pair<Symbol*, llvm::BasicBlock*>, 8>  visitStack;
+		CompilationUnit*                         originalUnit;
+		Symbol*                                  curSymbol;
+		bool                                     onlyAlwaysInline;
+		uint64_t                                 inlineThreshold;
+
+		FunctionInliner(CompilationUnit* unit, llvm::Function* _function, uint64_t inlineThreshold, bool _onlyAlwaysInline) { /* NOTE(review): the inlineThreshold parameter shadows the member of the same name and is never copied into it; curSymbol is not set here either, it is first assigned by pop() */
+			function = _function;
+			originalUnit = unit;
 			onlyAlwaysInline = _onlyAlwaysInline;
+			push(0, &function->getEntryBlock()); /* seed the worklist with the entry block and no symbol */
 		}
 
-    virtual const char* getPassName() const {
-      return "Simple inliner";
-    }
-
-		bool ensureLocal(llvm::Function* function, llvm::Function* callee) {
-			/* prevent exernal references because some llvm passes don't like that */
-			if(callee->getParent() != function->getParent()) {
-				//fprintf(stderr, "       rewrite local\n");
-				llvm::Function* local = (llvm::Function*)function->getParent()->getOrInsertFunction(callee->getName().data(), 
-																																														callee->getFunctionType());
-				callee->replaceAllUsesWith(local);
-				callee = local;
-				return 1;
-			} else
-				return 0;
+		void push(Symbol* symbol, llvm::BasicBlock* bb) { /* schedules bb for a visit, tagged with symbol */
+			if(visited.insert(bb)) /* the result of insert() decides whether bb is queued; presumably false for a block already in the set */
+				visitStack.push_back(std::make_pair(symbol, bb));
 		}
-		
-		Symbol* tryInline(llvm::Function* function, llvm::Function* callee) {
+
+		llvm::BasicBlock* pop() { /* removes the last entry of visitStack and returns its block */
+			std::pair<Symbol*, llvm::BasicBlock*> top = visitStack.pop_back_val();
+			curSymbol = top.first; /* the symbol stored with the block becomes the current one */
+			return top.second;
+		}
+
+		Symbol* tryInline(llvm::Function* callee) { /* decides whether callee may be inlined: returns its Symbol, or 0 to leave the call alone */
 			if(callee->isIntrinsic())
 				return 0;
 
 			const char*     id = callee->getName().data();
-			Symbol*         symbol = compiler->getSymbol(id, 0);
+			CompilationUnit* unit = curSymbol ? curSymbol->unit() : originalUnit; /* resolve names in the unit of the current symbol when there is one */
+			if(!unit)
+				unit = originalUnit; /* the symbol reported no unit: fall back to the unit given to the constructor */
+			Symbol*         symbol = unit->getSymbol(id, 0);
 			llvm::Function* bc;
 
 			//fprintf(stderr, "   processing: %s => %p\n", id, symbol);
-
+			
 			if(symbol) {
 				bc = symbol->llvmFunction();
 				if(!bc)
 					return 0;
 			} else {
 				bc = callee;
-
+				
 				if(callee->isDeclaration() && callee->isMaterializable())
 					callee->Materialize();
-
+				
 				if(callee->isDeclaration())
 					return 0;
-					
+				
 				uint8_t* addr = (uint8_t*)dlsym(SELF_HANDLE, id);
-				symbol = new(compiler->allocator()) NativeSymbol(callee, addr);
-				compiler->addSymbol(id, symbol);
+				symbol = new(unit->allocator()) NativeSymbol(callee, addr); /* no symbol known: wrap callee itself with the address dlsym returns for its name */
+				unit->addSymbol(id, symbol); /* registered in the unit that was just searched */
 			}
-
-			//fprintf(stderr, "       weight: %lld\n", symbol->inlineWeight());
+		//fprintf(stderr, "       weight: %lld\n", symbol->inlineWeight());
 
 			return (!bc->hasFnAttribute(llvm::Attribute::NoInline)
 							&& (bc->hasFnAttribute(llvm::Attribute::AlwaysInline) || 
-									(!onlyAlwaysInline && (uint64_t)(symbol->inlineWeight()-1) < inlineThreshold))) ? symbol : 0;
+									(0 && !onlyAlwaysInline && (uint64_t)(symbol->inlineWeight()-1) < inlineThreshold))) ? symbol : 0; /* the leading 0 && short-circuits the weight test, so only AlwaysInline functions without NoInline are accepted */
 		}
 
-		//llvm::SmallPtrSet<const Function*, 16> NeverInline;
+		bool visitBB(llvm::BasicBlock* bb) { /* visits one block: re-declares foreign globals locally, then tries to inline each direct call/invoke; returns whether a call was inlined */
+			bool changed = 0;
+			bool takeNext = 0;
+
+			//fprintf(stderr, "    visit basic block: %s\n", bb->getName().data());
+
+			for(llvm::BasicBlock::iterator it=bb->begin(), prev=0; it!=bb->end(); takeNext && (prev=it++)) { /* the iterator advances only while takeNext is set; prev is the instruction before it */
+				llvm::Instruction *insn = it;
+				takeNext = 1;
+
+				//fprintf(stderr, "        visit insn: ");
+				//insn->dump();
+
+				//fprintf(stderr, "             %d operands\n", insn->getNumOperands());
+				for(unsigned i=0; i<insn->getNumOperands(); i++) {
+					llvm::Value* op = insn->getOperand(i);
+						
+					//fprintf(stderr, " ----> ");
+					//op->dump();
+					//fprintf(stderr, "     => %s\n", llvm::isa<llvm::GlobalValue>(op) ? "global" : "not global");
+
+					if(llvm::isa<llvm::GlobalValue>(op)) {
+						llvm::GlobalValue* gv = llvm::cast<llvm::GlobalValue>(op);
+						if(gv->getParent() != function->getParent()) { /* operand belongs to another module than the function being processed */
+							llvm::Value* copy =
+								llvm::isa<llvm::Function>(gv) ?
+								function->getParent()->getOrInsertFunction(gv->getName().data(), 
+																													 llvm::cast<llvm::Function>(gv)->getFunctionType()) :
+								function->getParent()->getOrInsertGlobal(gv->getName().data(), gv->getType()->getContainedType(0));
 
-		bool runOnFunction(llvm::Function& function) {
-			bool     changed = false;
-			
-			//fprintf(stderr, "Analyzing: %s\n", function.getName().data());
-			
-		restart:
-			for (llvm::Function::iterator bit=function.begin(); bit!=function.end(); bit++) { 
-				llvm::BasicBlock* bb = bit; 
-				uint32_t prof = 0;
+							//fprintf(stderr, "<<<reimporting>>>: %s\n", gv->getName().data());
+							gv->replaceAllUsesWith(copy); /* swaps gv for the same-named entity of this module; NOTE(review): changed is not updated by this rewrite */
+						}
+					}
+				}
 
-				for(llvm::BasicBlock::iterator it=bb->begin(), prev=0; it!=bb->end() && prof<42; prev=it++) {
-					llvm::Instruction *insn = it;
+				if(insn->getOpcode() != llvm::Instruction::Call &&
+					 insn->getOpcode() != llvm::Instruction::Invoke) {
+					continue;
+				}
 
-					//fprintf(stderr, "  process: ");
-					//insn->dump();
+				llvm::CallSite  call(insn);
+				llvm::Function* callee = call.getCalledFunction();
+				
+				if(!callee) /* no statically known callee: skip */
+					continue;
+				
+				Symbol* symbol = tryInline(callee);
+				
+				if(symbol) {
+					llvm::Function* bc = symbol->llvmFunction();
 
-#if 0
-					if(insn->getOpcode() == llvm::Instruction::LandingPad) {
-						llvm::LandingPadInst* lp = (llvm::LandingPadInst*)insn;
-						ensureLocal(&function, (llvm::Function*)lp->getPersonalityFn());
-						continue;
-					}
-#endif
+					if(bc != callee) /* the symbol carries a different function object: point the uses of callee at it before inlining */
+						callee->replaceAllUsesWith(bc);
+					
+					fprintf(stderr, "            inlining %s in %s\n", bc->getName().data(), function->getName().data()); /* NOTE(review): unconditional stderr trace, unlike the other traces which are commented out */
 
-					if (insn->getOpcode() != llvm::Instruction::Call &&
-							insn->getOpcode() != llvm::Instruction::Invoke) {
-						continue;
+					if(llvm::isa<llvm::TerminatorInst>(insn)) { /* the call is itself a terminator: queue its successors under the callee's symbol */
+						llvm::TerminatorInst* terminator = llvm::cast<llvm::TerminatorInst>(insn);
+						for(unsigned i=0; i<terminator->getNumSuccessors(); i++)
+							push(symbol, terminator->getSuccessor(i));
+					} else {
+						size_t len = strlen(bc->getName().data());
+						char buf[len + 16]; /* variable-length array, a compiler extension in C++ */
+						memcpy(buf, bc->getName().data(), len);
+						memcpy(buf+len, ".after-inline", 14); /* 13 characters plus the terminating NUL */
+						push(symbol, bb->splitBasicBlock(insn->getNextNode(), buf)); /* split right after the call and queue the remainder; NOTE(review): the remainder is queued under the callee's symbol, before the inline is known to succeed - confirm intended */
 					}
+					
+					llvm::InlineFunctionInfo ifi(0);
+					bool isInlined = llvm::InlineFunction(call, ifi, false);
+					//fprintf(stderr, " ok?: %d\n", isInlined);
+					changed |= isInlined;
+
+					if(isInlined) {
+						curSymbol = symbol; /* the rest of this visit resolves names through the inlined symbol */
+						if(prev)
+							it = prev; /* the loop step then moves to the instruction following prev */
+						else {
+							takeNext = 0; /* no predecessor: restart at the first instruction without stepping */
+							it = bb->begin();
+						}
+					} else if(bc != callee)
+						bc->replaceAllUsesWith(callee); /* inline refused: point bc's uses back at callee; NOTE(review): this rewrites every use of bc, not only the ones redirected above */
+				}
+			}
+
+			return changed;
+		}
 
-					llvm::CallSite  call(insn);
-					llvm::Function* callee = call.getCalledFunction();
+		bool proceed() { /* drains the worklist seeded by the constructor; returns whether any visit inlined something */
+			bool changed = 0;
 
-					if(!callee)
-						continue;
+			//fprintf(stderr, "visit function: %s\n", function->getName().data());
 
-					Symbol* symbol = tryInline(&function, callee);
-					llvm::Function* bc;
+			while(visitStack.size()) {
+				llvm::BasicBlock* bb = pop(); /* also makes the symbol stored with bb the current one */
 
-					if(symbol && (bc = symbol->llvmFunction()) != &function) {
-						if(bc != callee)
-							callee->replaceAllUsesWith(bc);
-
-						//fprintf(stderr, "   inlining: %s\n", bc->getName().data());
-						//bc->dump();
-						llvm::InlineFunctionInfo ifi(0);
-						bool isInlined = llvm::InlineFunction(call, ifi, false);
-						changed |= isInlined;
-
-						if(isInlined) {
-							prof++;
-							//it = prev ? prev : bb->begin();
-							//continue;
-							//fprintf(stderr, "... restart ....\n");
-							goto restart;
-						}
-					} else
-						changed |= ensureLocal(&function, callee);
+				changed |= visitBB(bb);
+
+				llvm::TerminatorInst* terminator = bb->getTerminator(); /* read after the visit, which may have split bb */
+				if(terminator) {
+					for(unsigned i=0; i<terminator->getNumSuccessors(); i++)
+						push(curSymbol, terminator->getSuccessor(i)); /* successors are queued under whatever curSymbol is once the visit ends */
 				}
 			}
 
-			//#if 0
 			if(changed) {
-				//function.dump();
-				//abort();
+				//function->dump();
 			}
-			//#endif
 
 			return changed;
 		}
 	};
 
-  char FunctionInliner::ID = 0;
+  class FunctionInlinerPass : public llvm::FunctionPass { /* llvm function pass that delegates the work to a FunctionInliner */
+  public:
+    static char ID;
+
+		CompilationUnit*         unit;
+		llvm::InlineCostAnalysis costAnalysis; /* not referenced by any code visible in this diff */
+		unsigned int             inlineThreshold; 		// 225 in llvm
+		bool                     onlyAlwaysInline;
+
+		FunctionInlinerPass(CompilationUnit* _unit, unsigned int _inlineThreshold, bool _onlyAlwaysInline) : 
+			FunctionPass(ID) { 
+			unit = _unit;
+			inlineThreshold = _inlineThreshold; 
+			onlyAlwaysInline = _onlyAlwaysInline;
+		}
+
+    virtual const char* getPassName() const { return "VMKit inliner"; }
+		bool                ensureLocal(llvm::Function* function, llvm::Function* callee);
+		Symbol*             tryInline(llvm::Function* function, llvm::Function* callee); /* NOTE(review): used by runOnFunction0, but no definition of this two-argument form is visible in this diff */
+		bool                runOnBB(llvm::BasicBlock* bb);
+		bool                runOnFunction0(llvm::Function& function);
+		bool                runOnFunction(llvm::Function& function) { /* pass entry point: the #if 0 compiles out runOnFunction0 in favour of a fresh FunctionInliner per function */
+#if 0
+			return runOnFunction0(function);
+#else
+			FunctionInliner inliner(unit, &function, inlineThreshold, onlyAlwaysInline);
+			return inliner.proceed();
+#endif
+		}
+	};
+
+  char FunctionInlinerPass::ID = 0;
+
+#if 0
+	llvm::RegisterPass<FunctionInlinerPass> X("FunctionInlinerPass",
+																				"Inlining Pass that inlines evaluator's functions.");
+#endif
+
+	//FunctionInlinerPass() : FunctionPass(ID) {}
+
+	bool FunctionInlinerPass::ensureLocal(llvm::Function* function, llvm::Function* callee) { /* returns 1 when callee lived in another module and its uses were redirected to a declaration in function's module, 0 otherwise */
+		/* prevent external references because some llvm passes don't like that */
+		if(callee->getParent() != function->getParent()) {
+			//fprintf(stderr, "       rewrite local\n");
+			llvm::Function* local = (llvm::Function*)function->getParent()->getOrInsertFunction(callee->getName().data(), 
+																																													callee->getFunctionType());
+			callee->replaceAllUsesWith(local);
+			callee = local; /* only rebinds the by-value parameter: the caller's pointer is unchanged */
+			return 1;
+		} else
+			return 0;
+	}
+		
+	//llvm::SmallPtrSet<const Function*, 16> NeverInline;
+
+	bool FunctionInlinerPass::runOnBB(llvm::BasicBlock* bb) { /* stub: prints the block name on stderr and always reports no change */
+		fprintf(stderr, " process basic block %s\n", bb->getName().data());
+
+			//SmallPtrSet<const BasicBlock*, 8> Visited;
+
+		return 0;
+	}
+
+	bool FunctionInlinerPass::runOnFunction0(llvm::Function& function) { /* earlier inliner body, now out of line: scans every call/invoke and rescans the whole function after each successful inline */
+		bool     changed = false;
+			
+		//fprintf(stderr, "Analyzing: %s\n", function.getName().data());
+			
+	restart: /* jumped to after every successful inline */
+		for (llvm::Function::iterator bit=function.begin(); bit!=function.end(); bit++) { 
+			llvm::BasicBlock* bb = bit; 
+			uint32_t prof = 0; /* only incremented right before goto restart, after which it is 0 again: the prof<42 bound below never stops the loop */
+
+			for(llvm::BasicBlock::iterator it=bb->begin(), prev=0; it!=bb->end() && prof<42; prev=it++) {
+				llvm::Instruction *insn = it;
+
+				//fprintf(stderr, "  process: ");
+				//insn->dump();
+
+#if 0
+				if(insn->getOpcode() == llvm::Instruction::LandingPad) {
+					llvm::LandingPadInst* lp = (llvm::LandingPadInst*)insn;
+					ensureLocal(&function, (llvm::Function*)lp->getPersonalityFn());
+					continue;
+				}
+#endif
+
+				if (insn->getOpcode() != llvm::Instruction::Call &&
+						insn->getOpcode() != llvm::Instruction::Invoke) {
+					continue;
+				}
+				
+				llvm::CallSite  call(insn);
+				llvm::Function* callee = call.getCalledFunction();
+				
+				if(!callee)
+					continue;
+				
+				Symbol* symbol = tryInline(&function, callee);
+				llvm::Function* bc;
+				
+				if(symbol && (bc = symbol->llvmFunction()) != &function) { /* a call whose symbol resolves to the function itself is not inlined */
+					if(bc != callee)
+						callee->replaceAllUsesWith(bc);
+					
+					//fprintf(stderr, "   inlining: %s\n", bc->getName().data());
+					//bc->dump();
+					llvm::InlineFunctionInfo ifi(0);
+					bool isInlined = llvm::InlineFunction(call, ifi, false);
+					changed |= isInlined;
+					
+					if(isInlined) {
+						prof++;
+						//it = prev ? prev : bb->begin();
+						//continue;
+						//fprintf(stderr, "... restart ....\n");
+						goto restart;
+					}
+				} else
+					changed |= ensureLocal(&function, callee); /* not inlined: at least make sure the callee is declared in this module */
+			}
+		}
+
+		//#if 0
+		if(changed) {
+			//function.dump();
+			//abort();
+		}
+		//#endif
+
+		return changed;
+	}
 
 #if 0
-	llvm::RegisterPass<FunctionInliner> X("FunctionInliner",
+	llvm::RegisterPass<FunctionInlinerPass> X("FunctionInlinerPass",
 																				"Inlining Pass that inlines evaluator's functions.");
 #endif
 
 	llvm::FunctionPass* createFunctionInlinerPass(CompilationUnit* compiler, bool onlyAlwaysInline) {
-		return new FunctionInliner(compiler, 2000, onlyAlwaysInline);
+		return new FunctionInlinerPass(compiler, 2000, onlyAlwaysInline); /* 2000 is passed as the inline threshold */
 	}
 }





More information about the vmkit-commits mailing list