[llvm-commits] [vector_llvm] CVS: llvm/lib/Transforms/Vector/Alloca2Realloc.cpp AltiVec.cpp LowerVectors.cpp RaiseVectors.cpp SSE.cpp
Robert Bocchino
bocchino at cs.uiuc.edu
Tue Oct 18 12:37:14 PDT 2005
Changes in directory llvm/lib/Transforms/Vector:
Alloca2Realloc.cpp added (r1.1.2.1)
AltiVec.cpp added (r1.1.2.1)
LowerVectors.cpp added (r1.1.2.1)
RaiseVectors.cpp added (r1.1.2.1)
SSE.cpp added (r1.1.2.1)
---
Log message:
Initial commit of Vector LLVM.
---
Diffs of the changes: (+2978 -0)
Alloca2Realloc.cpp | 212 ++++++++++
AltiVec.cpp | 530 ++++++++++++++++++++++++++
LowerVectors.cpp | 1059 +++++++++++++++++++++++++++++++++++++++++++++++++++++
RaiseVectors.cpp | 594 +++++++++++++++++++++++++++++
SSE.cpp | 583 +++++++++++++++++++++++++++++
5 files changed, 2978 insertions(+)
Index: llvm/lib/Transforms/Vector/Alloca2Realloc.cpp
diff -c /dev/null llvm/lib/Transforms/Vector/Alloca2Realloc.cpp:1.1.2.1
*** /dev/null Tue Oct 18 14:37:13 2005
--- llvm/lib/Transforms/Vector/Alloca2Realloc.cpp Tue Oct 18 14:37:03 2005
***************
*** 0 ****
--- 1,212 ----
+ //===- Alloca2Realloc.cpp - Replace allocas with realloc inside loops -----===//
+ //
+ // The LLVM Compiler Infrastructure
+ //
+ // This file was developed by the LLVM research group and is distributed under
+ // the University of Illinois Open Source License. See LICENSE.TXT for details.
+ //
+ //===----------------------------------------------------------------------===//
+ //
+ // This file replaces alloca instructions with calls to the C library
+ // function realloc inside loops, in order to conserve memory and
+ // prevent stack overflow. The replacement is done only if the
+ // pointer resulting from the alloca is never stored to memory or
+ // passed as an argument to a function. In this case, the replacement
+ // is safe: because LLVM is in SSA form, any subsequent execution of
+ // the same alloca must overwrite the old pointer.
+ //
+ //===----------------------------------------------------------------------===//
+
+ #define DEBUG_TYPE "alloca2realloc"
+
+ #include "llvm/Analysis/LoopInfo.h"
+ #include "llvm/BasicBlock.h"
+ #include "llvm/Constants.h"
+ #include "llvm/Function.h"
+ #include "llvm/Instructions.h"
+ #include "llvm/Module.h"
+ #include "llvm/Pass.h"
+ #include "llvm/DerivedTypes.h"
+ #include "llvm/Support/Debug.h"
+
+ using namespace llvm;
+
+ namespace {
+
+
+ //===----------------------------------------------------------------------===//
+ // Class definitions
+ //===----------------------------------------------------------------------===//
+
+ /// Function pass that replaces allocas found inside natural loops with
+ /// calls to realloc (see the file header for when this is done).
+ ///
+ class Alloca2Realloc : public FunctionPass {
+
+ // Declaration of realloc in the current module; set by doInitialization.
+ Function *ReallocFunc;
+ // True once runOnFunction has modified the current function.
+ bool changed;
+ // The function currently being processed by runOnFunction.
+ Function *function;
+ // Loop information for the current function.
+ LoopInfo *LI;
+
+ public:
+ /// This transformation requires natural loop information. It adds and
+ /// removes instructions but never changes the CFG.
+ ///
+ virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.setPreservesCFG();
+ AU.addRequired<LoopInfo>();
+ }
+ virtual bool doInitialization(Module &M);
+ virtual bool runOnFunction(Function &F);
+
+ private:
+ // Visit subloops first, then rewrite allocas in blocks owned by this loop.
+ void visitLoop(Loop*);
+ // Rewrite one alloca; returns true if it was replaced (caller erases it).
+ bool processAlloca(AllocaInst*);
+ // Returns true if the value never reaches a store or call.
+ bool isSafe(Value*);
+ };
+
+ // Register the pass with the pass registry under the name "alloca2realloc".
+ RegisterOpt<Alloca2Realloc> X("alloca2realloc",
+ "Replace alloca with realloc inside loops");
+
+
+ //===----------------------------------------------------------------------===//
+ // Alloca2Realloc implementation
+ //===----------------------------------------------------------------------===//
+
+ /// Find the module's "realloc" function, or declare it as
+ /// "sbyte *realloc(sbyte*, uint)" if the module has none, so that
+ /// runOnFunction can emit calls to it.
+ ///
+ bool Alloca2Realloc::doInitialization(Module &M) {
+   ReallocFunc = M.getNamedFunction("realloc");
+   if (!ReallocFunc) {
+     // No existing declaration: create one with the signature we expect.
+     const Type *BytePtrTy = PointerType::get(Type::SByteTy);
+     ReallocFunc = M.getOrInsertFunction("realloc", BytePtrTy, BytePtrTy,
+                                         Type::UIntTy, 0);
+   }
+   return true;
+ }
+
+ /// Rewrite eligible allocas in every loop of F. Returns true if F changed.
+ ///
+ bool Alloca2Realloc::runOnFunction(Function &F) {
+   changed = false;
+   function = &F;
+   assert(ReallocFunc && "Pass not initialized!");
+
+   LI = &getAnalysis<LoopInfo>();
+
+   // Only top-level loops are walked here; visitLoop recurses into subloops.
+   LoopInfo::iterator LoopIt = LI->begin(), LoopEnd = LI->end();
+   while (LoopIt != LoopEnd) {
+     visitLoop(*LoopIt);
+     ++LoopIt;
+   }
+   return changed;
+ }
+
+ /// Process loop L. Its subloops are handled first, recursively. Then every
+ /// alloca in a block whose innermost loop is L (not a subloop) is offered
+ /// to processAlloca, and each alloca it replaces is erased.
+ ///
+ void Alloca2Realloc::visitLoop(Loop *L) {
+ // Recurse through all subloops before we process this loop...
+ //
+ for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) {
+ visitLoop(*I);
+ }
+
+ // Now do this loop
+ //
+ for (std::vector<BasicBlock*>::const_iterator BI = L->getBlocks().begin(),
+ BE = L->getBlocks().end(); BI != BE; ++BI) {
+ BasicBlock *BB = *BI;
+ if (LI->getLoopFor(BB) == L) { // Ignore blocks in subloops...
+ BasicBlock::InstListType &BBIL = BB->getInstList();
+ for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE; ++II) {
+ if (AllocaInst *AI = dyn_cast<AllocaInst>(II)) {
+ if (processAlloca(AI)) {
+ // processAlloca inserted its replacement code before AI, so AI is
+ // not the first instruction and stepping back one is valid; the
+ // loop's ++II then resumes at the instruction after the alloca.
+ II = --BBIL.erase(II); // remove and delete the alloca instruction
+ changed = true;
+ }
+ }
+ }
+ }
+ }
+
+ }
+
+ // Process an alloca instruction. Return true if the alloca was
+ // replaced with a realloc, false otherwise.
+ //
+ // The rewrite creates a stack slot in the entry block. The slot holds the
+ // most recent realloc result and starts out null. At the alloca's position
+ // the slot is loaded and passed to realloc together with the byte size.
+ // The result is stored back into the slot, so the next execution resizes
+ // the same buffer instead of allocating a new one. A free of the slot's
+ // contents is inserted before every return.
+ //
+ bool Alloca2Realloc::processAlloca(AllocaInst *AI) {
+   DEBUG(std::cerr << "Processing " << *AI);
+   if (!isSafe(AI)) {
+     DEBUG(std::cerr << "Instruction is not safe to replace!\n");
+     return false;
+   }
+
+   // getPrimitiveSize() is 0 for types without a primitive size (e.g.
+   // aggregates). That would request realloc(p, 0), which does not provide
+   // usable storage, so leave such allocas alone. The check runs before any
+   // IR is created.
+   //
+   unsigned elementSize = AI->getAllocatedType()->getPrimitiveSize();
+   if (elementSize == 0) {
+     DEBUG(std::cerr << "Cannot compute allocation size; not replacing\n");
+     return false;
+   }
+
+   DEBUG(std::cerr << "Replacing instruction\n");
+   Instruction *before = &(function->getEntryBlock().front());
+   Value *loadStorePtr = new AllocaInst(AI->getType(), 0, "loadstore_ptr", before);
+   new StoreInst(Constant::getNullValue(AI->getType()), loadStorePtr, before);
+
+   const FunctionType *ReallocFTy = ReallocFunc->getFunctionType();
+
+   // Create the vector of arguments to realloc
+   //
+   Value *reallocPtr = new LoadInst(loadStorePtr, "realloc_ptr", AI);
+   if (reallocPtr->getType() != ReallocFTy->getParamType(0)) {
+     reallocPtr = new CastInst(reallocPtr, ReallocFTy->getParamType(0), "cast", AI);
+   }
+
+   Value *reallocSize =
+     BinaryOperator::create(Instruction::Mul, AI->getArraySize(),
+                            ConstantUInt::get(Type::UIntTy, elementSize),
+                            "size", AI);
+   if (reallocSize->getType() != ReallocFTy->getParamType(1)) {
+     reallocSize = new CastInst(reallocSize, ReallocFTy->getParamType(1), "cast", AI);
+   }
+
+   std::vector<Value*> ReallocArgs;
+   ReallocArgs.push_back(reallocPtr);
+   ReallocArgs.push_back(reallocSize);
+
+   // Create the call to realloc
+   //
+   CallInst *call = new CallInst(ReallocFunc, ReallocArgs, AI->getName(), AI);
+
+   // Create a cast instruction to convert to the right type...
+   //
+   Value *newVal = call;
+   if (call->getType() == Type::VoidTy)
+     newVal = Constant::getNullValue(AI->getType());
+   else if (call->getType() != AI->getType())
+     newVal = new CastInst(call, AI->getType(), "cast", AI);
+
+   // Store the (possibly moved) buffer back into the slot. Without this the
+   // slot stays null, each execution would realloc(NULL, ...) and leak the
+   // previous buffer, and the frees below would only free NULL.
+   //
+   new StoreInst(newVal, loadStorePtr, AI);
+
+   // Replace all uses of the old alloca inst with the new value
+   //
+   AI->replaceAllUsesWith(newVal);
+
+   // Insert a free instruction before every return
+   //
+   for (Function::iterator FI = function->begin(), FE = function->end();
+        FI != FE; ++FI) {
+     if (ReturnInst *RI = dyn_cast<ReturnInst>(FI->getTerminator())) {
+       LoadInst *freePtr = new LoadInst(loadStorePtr, "free_ptr", RI);
+       new FreeInst(freePtr, RI);
+     }
+   }
+
+   return true;
+ }
+
+ // Return true if the given value never escapes, false otherwise. It
+ // escapes if it, or any value computed from it (other than by a load),
+ // is passed to a call or invoke, stored to memory, or merged by a PHI
+ // node.
+ //
+ // PHI nodes are rejected for two reasons:
+ //  - A PHI can carry the pointer from a previous loop iteration. Once the
+ //    alloca becomes a realloc, that old pointer may dangle.
+ //  - In reachable SSA code, use chains can only form cycles through PHI
+ //    nodes, so stopping at them keeps this recursion finite.
+ //
+ bool Alloca2Realloc::isSafe(Value *V) {
+   for (Value::use_iterator I = V->use_begin(), E = V->use_end();
+        I != E; ++I) {
+     if (isa<CallInst>(*I) || isa<InvokeInst>(*I) || isa<PHINode>(*I)) {
+       DEBUG(std::cerr << "Followed chain of uses to " << **I);
+       return false;
+     }
+     StoreInst *SI = dyn_cast<StoreInst>(*I);
+     if (SI && V == SI->getOperand(0)) {
+       DEBUG(std::cerr << "Followed chain of uses to " << *SI);
+       return false;
+     }
+     if (!isa<LoadInst>(*I) && !isSafe(*I))
+       return false;
+   }
+   return true;
+ }
+
+ }
Index: llvm/lib/Transforms/Vector/AltiVec.cpp
diff -c /dev/null llvm/lib/Transforms/Vector/AltiVec.cpp:1.1.2.1
*** /dev/null Tue Oct 18 14:37:14 2005
--- llvm/lib/Transforms/Vector/AltiVec.cpp Tue Oct 18 14:37:03 2005
***************
*** 0 ****
--- 1,530 ----
+ //===- AltiVec.cpp - Prepare Vector-LLVM code for the AltiVec C Backend -===//
+ //
+ // The LLVM Compiler Infrastructure
+ //
+ // This file was developed by the LLVM research group and is distributed under
+ // the University of Illinois Open Source License. See LICENSE.TXT for details.
+ //
+ //===----------------------------------------------------------------------===//
+ //
+ // This file takes blocked Vector-LLVM code and puts it in a form that
+ // can be passed to the AltiVec C Backend.
+ //
+ //===----------------------------------------------------------------------===//
+
+ #define DEBUG_TYPE "altivec"
+
+ #include "VectorLLVM/Utils.h"
+ #include "llvm/Constants.h"
+ #include "llvm/DerivedTypes.h"
+ #include "llvm/Function.h"
+ #include "llvm/Instructions.h"
+ #include "llvm/Pass.h"
+ #include "llvm/Type.h"
+ #include "llvm/Support/Debug.h"
+ #include "llvm/ADT/hash_map"
+ #include "llvm/ADT/hash_set"
+ #include "llvm/ADT/STLExtras.h"
+ #include "llvm/Support/InstVisitor.h"
+ #include "VectorLLVM/Utils.h"
+
+ using namespace llvm;
+
+ namespace {
+
+
+ //===----------------------------------------------------------------------===//
+ // Class definitions
+ //===----------------------------------------------------------------------===//
+
+ /// Function pass that rewrites blocked Vector-LLVM instructions into calls
+ /// to altivec_* functions for the AltiVec C Backend. Instructions made dead
+ /// by a rewrite are only recorded during the visit. They are erased after
+ /// the whole function has been walked, so the walk in runOnFunction is not
+ /// disturbed.
+ ///
+ class AltiVec : public FunctionPass, public InstVisitor<AltiVec> {
+
+ public:
+ bool runOnFunction(Function &F);
+ void visitCastInst(CastInst &);
+ void visitVImmInst(VImmInst &);
+ void visitExtractInst(ExtractInst &);
+ void visitCombineInst(CombineInst &);
+ void visitVSelectInst(VSelectInst &);
+ void visitShiftInst(ShiftInst &);
+ void visitMul(BinaryOperator &BO);
+ void visitSub(BinaryOperator &BO);
+ void visitSetCondInst(SetCondInst &BO);
+ void visitCallInst(CallInst &);
+ // Instructions without a specific handler are left unchanged.
+ void visitInstruction(Instruction& I) {}
+
+ private:
+ // True if any instruction of the current function was rewritten.
+ bool changed;
+ // Replaced instructions, erased by deleteInstructions() after the walk.
+ hash_set<Instruction*> instructionsToDelete;
+
+ // Erase every recorded instruction. All operand references are dropped in
+ // a first pass so that recorded instructions that use one another can be
+ // erased in any order in the second pass.
+ void deleteInstructions() {
+ for (hash_set<Instruction*>::iterator I = instructionsToDelete.begin(),
+ E = instructionsToDelete.end(); I != E; ++I) {
+ (*I)->dropAllReferences();
+ }
+ for (hash_set<Instruction*>::iterator I = instructionsToDelete.begin(),
+ E = instructionsToDelete.end(); I != E; ++I) {
+ (*I)->getParent()->getInstList().erase(*I);
+ }
+ }
+ };
+
+ // Register the pass with the pass registry under the name "altivec".
+ RegisterOpt<AltiVec> X("altivec",
+ "AltiVec code generation pre-pass");
+
+ //===----------------------------------------------------------------------===//
+ // Helper functions
+ //===----------------------------------------------------------------------===//
+
+ /// Return true if VT is a fixed-length vector with one of the integer
+ /// layouts handled here: 4 x (u)int, 8 x (u)short or 16 x (s/u)byte.
+ /// Any other vector type is not handled by this pass and must be lowered
+ /// later.
+ ///
+ bool isProperType(const VectorType *VT) {
+   const FixedVectorType *Fixed = dyn_cast<FixedVectorType>(VT);
+   if (Fixed == 0)
+     return false;  // only fixed-length vectors qualify
+
+   unsigned wantedLength = 0;
+   unsigned elementID = VT->getElementType()->getTypeID();
+   if (elementID == Type::IntTyID || elementID == Type::UIntTyID)
+     wantedLength = 4;
+   else if (elementID == Type::ShortTyID || elementID == Type::UShortTyID)
+     wantedLength = 8;
+   else if (elementID == Type::SByteTyID || elementID == Type::UByteTyID)
+     wantedLength = 16;
+   else
+     return false;
+
+   return Fixed->getNumElements() == wantedLength;
+ }
+
+ /// Build the name of the AltiVec helper for operation baseName on vectors
+ /// with VT's element type: "altivec_" + baseName + "_" + element type
+ /// description.
+ ///
+ std::string getAltiVecName(const std::string& baseName, const VectorType *VT) {
+   std::string name("altivec_");
+   name += baseName;
+   name += '_';
+   name += VT->getElementType()->getDescription();
+   return name;
+ }
+
+
+ //===----------------------------------------------------------------------===//
+ // AltiVec implementation
+ //===----------------------------------------------------------------------===//
+
+ /// Entry point called by the PassManager. Visits every instruction of F
+ /// that has not already been scheduled for deletion, then erases the
+ /// scheduled instructions. Returns true if F was modified.
+ ///
+ bool AltiVec::runOnFunction(Function &F) {
+   changed = false;
+   instructionsToDelete.clear();
+   for (Function::iterator Block = F.begin(), BlockEnd = F.end();
+        Block != BlockEnd; ++Block) {
+     for (BasicBlock::iterator Inst = Block->begin(), InstEnd = Block->end();
+          Inst != InstEnd; ++Inst) {
+       // Skip instructions already consumed by an earlier rewrite.
+       if (instructionsToDelete.count(Inst))
+         continue;
+       DEBUG(std::cerr << "Visiting instruction " << *Inst);
+       visit(*Inst);
+     }
+   }
+   if (changed)
+     deleteInstructions();
+   return changed;
+ }
+
+ /// Rewrite a vimm of a non-constant scalar x whose single use is a cast to
+ /// an AltiVec-compatible vector type:
+ ///  1. store x into element 0 of a fresh stack vector;
+ ///  2. load that vector;
+ ///  3. replace the cast with altivec_splat(vector, 0).
+ /// Both the vimm and the cast are scheduled for deletion. Other cases are
+ /// left for the AltiVec C Backend to handle directly.
+ ///
+ void AltiVec::visitVImmInst(VImmInst &VL) {
+   if (!VL.hasOneUse())
+     return;
+   CastInst *onlyUse = dyn_cast<CastInst>(*VL.use_begin());
+   if (onlyUse == 0)
+     return;
+   const VectorType *VT = dyn_cast<VectorType>(onlyUse->getType());
+   if (VT == 0 || !isProperType(VT))
+     return;
+   Value *scalar = VL.getOperand(0);
+   if (isa<Constant>(scalar))
+     return;
+
+   // Materialize the scalar in memory, then splat element 0 into a register.
+   AllocaInst *slot = new AllocaInst(VT, 0, "alloca", &VL);
+   const Type *elemTy = VT->getElementType();
+   if (scalar->getType() != elemTy)
+     scalar = new CastInst(scalar, elemTy, "cast", &VL);
+   CastInst *elemPtr = new CastInst(slot, PointerType::get(elemTy), "cast", &VL);
+   new StoreInst(scalar, elemPtr, &VL);
+   LoadInst *loaded = new LoadInst(slot, "load", &VL);
+   CallInst *splat = VectorUtils::getCallInst(VT, getAltiVecName("splat", VT),
+                                              loaded,
+                                              ConstantUInt::get(Type::UByteTy, 0),
+                                              "splat", &VL);
+   onlyUse->replaceAllUsesWith(splat);
+   instructionsToDelete.insert(onlyUse);
+   instructionsToDelete.insert(&VL);
+   changed = true;
+ }
+
+ /// Rewrite a cast from a non-constant scalar to an AltiVec-compatible
+ /// vector type. The scalar is stored into element 0 of a fresh stack
+ /// vector, that vector is loaded, and altivec_splat(vector, 0) replaces
+ /// the cast. Vector-to-vector casts and constant operands are left for the
+ /// AltiVec C Backend to handle directly.
+ ///
+ void AltiVec::visitCastInst(CastInst &CI) {
+   const VectorType *VT = dyn_cast<VectorType>(CI.getType());
+   if (VT == 0 || !isProperType(VT))
+     return;
+   Value *scalar = CI.getOperand(0);
+   if (isa<VectorType>(scalar->getType()) || isa<Constant>(scalar))
+     return;
+
+   // Materialize the scalar in memory, then splat element 0 into a register.
+   AllocaInst *slot = new AllocaInst(VT, 0, "alloca", &CI);
+   const Type *elemTy = VT->getElementType();
+   if (scalar->getType() != elemTy)
+     scalar = new CastInst(scalar, elemTy, "cast", &CI);
+   CastInst *elemPtr = new CastInst(slot, PointerType::get(elemTy), "cast", &CI);
+   new StoreInst(scalar, elemPtr, &CI);
+   LoadInst *loaded = new LoadInst(slot, "load", &CI);
+   CallInst *splat = VectorUtils::getCallInst(VT, getAltiVecName("splat", VT),
+                                              loaded,
+                                              ConstantUInt::get(Type::UByteTy, 0),
+                                              "splat", &CI);
+   CI.replaceAllUsesWith(splat);
+   instructionsToDelete.insert(&CI);
+   changed = true;
+ }
+
+ // Check whether an extract instruction should be turned into
+ // altivec_unpack
+ //
+ // Matched pattern:
+ //  - EI = extract(v, start, 1, 8) with constant operands;
+ //  - start == 0 (-> unpackh) or start == 8 (-> unpackl);
+ //  - v of type <16 x ubyte>;
+ //  - EI has a single use, which is a cast.
+ // Rewrite: v is cast to <16 x sbyte> and passed to
+ // altivec_unpackh_sbyte / altivec_unpackl_sbyte, which returns
+ // <8 x short>. The result is ANDed with 0xFF in every lane, presumably
+ // to zero-extend the unsigned bytes that a signed unpack would
+ // sign-extend. Other extracts are left unchanged.
+ //
+ void AltiVec::visitExtractInst(ExtractInst &EI) {
+ // Operands: source vector, start index, stride, length.
+ Value *v = EI.getOperand(0);
+ ConstantUInt *start = dyn_cast<ConstantUInt>(EI.getOperand(1));
+ ConstantUInt *stride = dyn_cast<ConstantUInt>(EI.getOperand(2));
+ ConstantUInt *len = dyn_cast<ConstantUInt>(EI.getOperand(3));
+ if (!start || !stride || !len) return;
+ if (stride->getValue() != 1 || len->getValue() != 8) return;
+ std::string baseName;
+ if (start->getValue() == 0)
+ baseName = "unpackh";
+ else if (start->getValue() == 8)
+ baseName = "unpackl";
+ else return;
+ const FixedVectorType *VT = dyn_cast<FixedVectorType>(v->getType());
+ if (VT == FixedVectorType::get(Type::UByteTy, 16)) {
+ if (!EI.hasOneUse()) return;
+ CastInst *cast = dyn_cast<CastInst>(*EI.use_begin());
+ if (!cast) return;
+ const FixedVectorType *argTy = FixedVectorType::get(Type::SByteTy, 16);
+ const FixedVectorType *retTy = FixedVectorType::get(Type::ShortTy, 8);
+ // The replacement has type <8 x short>. If the cast produces some other
+ // type, insert an intermediate cast EI -> <8 x short> feeding it, and
+ // make that intermediate cast the value to be replaced instead.
+ if (cast->getType() != FixedVectorType::get(Type::ShortTy, 8)) {
+ CastInst *cast2 = new CastInst(&EI, FixedVectorType::get(Type::ShortTy, 8), "cast", cast);
+ cast->setOperand(0, cast2);
+ cast = cast2;
+ }
+ CastInst *arg = new CastInst(v, argTy, "cast", &EI);
+ CallInst *call = VectorUtils::getCallInst(retTy, getAltiVecName(baseName, argTy),
+ arg, "unpack", &EI);
+
+ // Mask each 16-bit lane down to its low byte.
+ BinaryOperator *andInst =
+ BinaryOperator::create(Instruction::And, call,
+ ConstantExpr::getCast(ConstantSInt::get(Type::ShortTy, 0xFF), retTy),
+ "and", &EI);
+ // 'cast' now has type <8 x short>; redirect its users to the masked result.
+ cast->replaceAllUsesWith(andInst);
+ instructionsToDelete.insert(cast);
+ instructionsToDelete.insert(&EI);
+ changed = true;
+ }
+ }
+
+ /// Match the chain combine2 = combine(CI, b), where CI = combine(v1, a),
+ /// v1 has 2*N elements, and a and b have N elements each. Rewrite it as
+ /// follows:
+ ///
+ ///  - If combine2 has one use that is a cast, a "saturate" call or a
+ ///    "permute" call producing an AltiVec-compatible type, the chain
+ ///    becomes altivec_pack, altivec_packsu or altivec_perm respectively,
+ ///    applied to a and b.
+ ///  - If combine2 has exactly two uses, both extracts with a constant
+ ///    start index, they become altivec_mergeh / altivec_mergel of a and b.
+ ///
+ /// Any other shape is left unchanged.
+ ///
+ void AltiVec::visitCombineInst(CombineInst &CI) {
+   Value *v1 = CI.getOperand(0);
+   Value *v2 = CI.getOperand(1);
+   // If operand 0 is itself a combine, this is not the first combine in
+   // the series; the first combine handles the whole pattern.
+   //
+   if (isa<CombineInst>(v1))
+     return;
+   // We must have two fixed-vector operands, the first twice as long as
+   // the second.
+   //
+   const FixedVectorType *VT1 = dyn_cast<FixedVectorType>(v1->getType());
+   if (!VT1) return;
+   const FixedVectorType *VT2 = dyn_cast<FixedVectorType>(v2->getType());
+   if (!VT2) return;
+   if (VT1->getNumElements() != 2*VT2->getNumElements())
+     return;
+   // This combine must have exactly one use, and it must be a
+   // combine whose operand 0 is this combine. The types must work
+   // out properly.
+   //
+   if (!CI.hasOneUse()) return;
+   CombineInst *combine2 = dyn_cast<CombineInst>(*CI.use_begin());
+   if (!combine2) return;
+   if (&CI != combine2->getOperand(0)) return;
+   if (combine2->getOperand(1)->getType() != VT2) return;
+   if (combine2->hasOneUse()) {
+     // Check for a vec_pack or vec_perm pattern
+     //
+     Instruction *use = dyn_cast<Instruction>(*combine2->use_begin());
+     if (!use) return;
+     const FixedVectorType *VT = dyn_cast<FixedVectorType>(use->getType());
+     if (!VT || !isProperType(VT)) return;
+     // Insert the replacement before 'use', not before CI. Its operands
+     // (v2, combine2's operand 1 and, for permute, the mask operand of
+     // 'use') all dominate 'use'. combine2's operand 1 need not dominate CI.
+     //
+     CallInst *call = 0;
+     if (isa<CastInst>(use)) {
+       call = VectorUtils::getCallInst(VT, getAltiVecName("pack", VT1),
+                                       v2, combine2->getOperand(1),
+                                       "pack", use);
+     }
+     else if (VectorUtils::isFunctionContaining(use, "saturate")) {
+       call = VectorUtils::getCallInst(VT, getAltiVecName("packsu", VT1),
+                                       v2, combine2->getOperand(1),
+                                       "pack", use);
+     }
+     else if (VectorUtils::isFunctionContaining(use, "permute")) {
+       call = VectorUtils::getCallInst(VT, getAltiVecName("perm", VT),
+                                       v2, combine2->getOperand(1),
+                                       cast<CallInst>(use)->getOperand(2),
+                                       "perm", use);
+     }
+     else {
+       return;
+     }
+
+     use->replaceAllUsesWith(call);
+     instructionsToDelete.insert(use);
+     instructionsToDelete.insert(combine2);
+     instructionsToDelete.insert(&CI);
+     // Operand 0 is not an input to the generated call. Erase it too, but
+     // only when CI is its sole user; erasing a value that other
+     // instructions still use would leave them with dangling operands.
+     //
+     Instruction *op0 = dyn_cast<Instruction>(v1);
+     if (op0 && op0->hasOneUse())
+       instructionsToDelete.insert(op0);
+     changed = true;
+   } else if (combine2->hasNUses(2)) {
+     Value::use_iterator I = combine2->use_begin();
+     ExtractInst *extract0 = dyn_cast<ExtractInst>(*I++);
+     ExtractInst *extract1 = dyn_cast<ExtractInst>(*I);
+     // Both uses must be extracts with a constant start index; otherwise
+     // leave the code alone.
+     //
+     if (!extract0 || !extract1) return;
+     ConstantUInt *start0 = dyn_cast<ConstantUInt>(extract0->getOperand(1));
+     if (!start0) return;
+     // Insert the merges before combine2. Their operands (operand 1 of each
+     // combine) dominate combine2, and combine2 dominates both extracts.
+     // So the merges dominate every use they take over, whichever order
+     // the two extracts appear in.
+     //
+     CallInst *mergeh = VectorUtils::getCallInst(VT2, getAltiVecName("mergeh", VT2),
+                                                 v2, combine2->getOperand(1),
+                                                 "mergeh", combine2);
+     CallInst *mergel = VectorUtils::getCallInst(VT2, getAltiVecName("mergel", VT2),
+                                                 v2, combine2->getOperand(1),
+                                                 "mergel", combine2);
+     if (start0->getValue() == 0) {
+       extract0->replaceAllUsesWith(mergeh);
+       extract1->replaceAllUsesWith(mergel);
+     } else {
+       extract0->replaceAllUsesWith(mergel);
+       extract1->replaceAllUsesWith(mergeh);
+     }
+     instructionsToDelete.insert(&CI);
+     instructionsToDelete.insert(combine2);
+     instructionsToDelete.insert(extract0);
+     instructionsToDelete.insert(extract1);
+     changed = true;
+   }
+ }
+
+ /// Replace a vselect producing an AltiVec-compatible vector type with a
+ /// call to altivec_sel. The vselect's operands are passed in reverse order
+ /// (operand 2, operand 1, operand 0).
+ ///
+ void AltiVec::visitVSelectInst(VSelectInst &VI) {
+   const FixedVectorType *VT = dyn_cast<FixedVectorType>(VI.getType());
+   if (VT == 0 || !isProperType(VT))
+     return;
+   Value *first = VI.getOperand(2);
+   Value *second = VI.getOperand(1);
+   Value *third = VI.getOperand(0);
+   CallInst *sel = VectorUtils::getCallInst(VT, getAltiVecName("sel", VT),
+                                            first, second, third, "sel", &VI);
+   VI.replaceAllUsesWith(sel);
+   instructionsToDelete.insert(&VI);
+   changed = true;
+ }
+
+ /// Replace a comparison of AltiVec-compatible vectors with the matching
+ /// altivec_cmp* call. Only VSetGT (-> "cmpgt") is supported. Any other
+ /// opcode triggers an assertion; with assertions disabled the instruction
+ /// is left untouched.
+ ///
+ void AltiVec::visitSetCondInst(SetCondInst &BO) {
+   const FixedVectorType *VT =
+     dyn_cast<FixedVectorType>(BO.getOperand(0)->getType());
+   if (!VT || !isProperType(VT)) return;
+   std::string name;
+   switch(BO.getOpcode()) {
+   case Instruction::VSetGT:
+     name = "cmpgt";
+     break;
+   default:
+     assert(0 && "Unknown VSetCC opcode!");
+     // In builds without assertions, do not go on to emit a call to a
+     // function whose name lacks the operation (e.g. "altivec__int").
+     return;
+   }
+   CallInst *call = VectorUtils::getCallInst(BO.getType(), getAltiVecName(name, VT),
+                                             BO.getOperand(0), BO.getOperand(1),
+                                             "cmp", &BO);
+   BO.replaceAllUsesWith(call);
+   instructionsToDelete.insert(&BO);
+   changed = true;
+ }
+
+ /// Replace a subtraction of fixed-length vectors with a call to
+ /// altivec_sub_<element type>.
+ ///
+ /// NOTE(review): unlike visitVSelectInst/visitSetCondInst, this does not
+ /// check isProperType, so any fixed-length vector type is rewritten;
+ /// confirm this is intended.
+ ///
+ void AltiVec::visitSub(BinaryOperator &BO) {
+   Value *lhs = BO.getOperand(0);
+   const FixedVectorType *VT = dyn_cast<FixedVectorType>(lhs->getType());
+   if (VT == 0)
+     return;
+   Value *rhs = BO.getOperand(1);
+   CallInst *difference = VectorUtils::getCallInst(VT, getAltiVecName("sub", VT),
+                                                   lhs, rhs, "sub", &BO);
+   BO.replaceAllUsesWith(difference);
+   instructionsToDelete.insert(&BO);
+   changed = true;
+ }
+
+ /// Replace a shift of fixed-length integer vectors with one of three calls:
+ ///  - altivec_sra: right shift of signed elements;
+ ///  - altivec_srl: right shift of unsigned elements;
+ ///  - altivec_sll: left shift.
+ /// The shift amount is cast to a vector with the same element count as the
+ /// shifted value and the unsigned version of its element type. That is
+ /// the shift-count operand type of the AltiVec shift operations, e.g.
+ /// <8 x ushort> for <8 x short> and <4 x uint> for <4 x int>. A fixed
+ /// <8 x ushort> would only be right for 16-bit elements.
+ ///
+ void AltiVec::visitShiftInst(ShiftInst &SI) {
+   const FixedVectorType *VT =
+     dyn_cast<FixedVectorType>(SI.getOperand(0)->getType());
+   if (!VT) return;
+   std::string shortName;
+   if (SI.getOpcode() == Instruction::Shr) {
+     if (VT->getElementType()->isSigned())
+       shortName = "sra";
+     else
+       shortName = "srl";
+   } else {
+     shortName = "sll";
+   }
+   const FixedVectorType *amountTy =
+     FixedVectorType::get(VT->getElementType()->getUnsignedVersion(),
+                          VT->getNumElements());
+   CastInst *amount = new CastInst(SI.getOperand(1), amountTy, "cast", &SI);
+   CallInst *shift = VectorUtils::getCallInst(VT, getAltiVecName(shortName, VT),
+                                              SI.getOperand(0), amount,
+                                              "shift", &SI);
+   SI.replaceAllUsesWith(shift);
+   instructionsToDelete.insert(&SI);
+   changed = true;
+ }
+
+ /// Add V to ToDelete if V is an instruction with exactly one use and that
+ /// user is already in ToDelete, i.e. V becomes dead once ToDelete is
+ /// erased. Returns true if V was added.
+ ///
+ static bool scheduleIfDead(Value *V, hash_set<Instruction*> &ToDelete) {
+   if (!V) return false;
+   Instruction *I = dyn_cast<Instruction>(V);
+   if (!I || !I->hasOneUse())
+     return false;
+   Instruction *User = dyn_cast<Instruction>(*I->use_begin());
+   if (!User || !ToDelete.count(User))
+     return false;
+   ToDelete.insert(I);
+   return true;
+ }
+
+ /// Recognize multiply-based vector idioms and replace them with AltiVec
+ /// calls. BO must have exactly one use, and the action depends on it:
+ ///
+ ///  - Add (mradds): the product plus cast(vimm 16384), right-shifted by 15,
+ ///    cast, and fed to a vllvm_adds* call. The whole chain becomes
+ ///    altivec_mradds on the sources of the multiply's casts.
+ ///  - Add (fallback): if the mradds chain does not match, the mul/add pair
+ ///    becomes altivec_mladd.
+ ///  - Shr (madds): cast(a) * cast(b), right-shifted by 15 and cast back,
+ ///    becomes altivec_madds(a, b, 0).
+ ///
+ /// The casts and the vimm feeding the pattern are erased only when the
+ /// pattern was their sole user.
+ ///
+ void AltiVec::visitMul(BinaryOperator &BO) {
+   const FixedVectorType *VT =
+     dyn_cast<FixedVectorType>(BO.getOperand(0)->getType());
+   if (!VT || !BO.hasOneUse())
+     return;
+   Instruction *use = dyn_cast<Instruction>(*BO.use_begin());
+   if (!use) return;
+   switch (use->getOpcode()) {
+   case Instruction::Add: {
+     // Check for mradds pattern
+     //
+     BinaryOperator *add = cast<BinaryOperator>(use);
+     CastInst *mulCast0 = dyn_cast<CastInst>(BO.getOperand(0));
+     CastInst *mulCast1 = dyn_cast<CastInst>(BO.getOperand(1));
+     CastInst *addCast0 = 0, *shrCast = 0;
+     VImmInst *VImm = 0;
+     ShiftInst *shr = 0;
+     CallInst *adds = 0;
+     unsigned offset = 0, shamt = 0;
+     if (&BO == add->getOperand(0))
+       addCast0 = dyn_cast<CastInst>(add->getOperand(1));
+     else
+       addCast0 = dyn_cast<CastInst>(add->getOperand(0));
+     if (addCast0)
+       VImm = dyn_cast<VImmInst>(addCast0->getOperand(0));
+     if (VImm) {
+       if (ConstantSInt *C = dyn_cast<ConstantSInt>(VImm->getOperand(0)))
+         offset = C->getValue();
+       else if (ConstantUInt *C = dyn_cast<ConstantUInt>(VImm->getOperand(0)))
+         offset = C->getValue();
+     }
+     if (add->hasOneUse())
+       shr = dyn_cast<ShiftInst>(*add->use_begin());
+     // Only a right shift fits the mradds rounding pattern.
+     if (shr && shr->getOpcode() != Instruction::Shr)
+       shr = 0;
+     if (shr && shr->hasOneUse()) {
+       if (ConstantUInt *C = dyn_cast<ConstantUInt>(shr->getOperand(1)))
+         shamt = C->getValue();
+       shrCast = dyn_cast<CastInst>(*shr->use_begin());
+     }
+     if (shrCast && shrCast->hasOneUse()) {
+       // The shifted, cast value must feed a vllvm_adds* intrinsic call.
+       adds = dyn_cast<CallInst>(*shrCast->use_begin());
+       if (adds) {
+         Function *F = adds->getCalledFunction();
+         if (!F || F->getName().substr(0, 10) != "vllvm_adds")
+           adds = 0;
+       }
+     }
+     if (mulCast0 && mulCast1 && addCast0 && VImm &&
+         offset == 16384 && shrCast && shr && shamt == 15 &&
+         adds) {
+       VT = cast<FixedVectorType>(adds->getType());
+       CastInst *factor0 = new CastInst(mulCast0->getOperand(0), VT, "cast", adds);
+       CastInst *factor1 = new CastInst(mulCast1->getOperand(0), VT, "cast", adds);
+       Value *addend = (shrCast == adds->getOperand(1)) ? adds->getOperand(2)
+                                                        : adds->getOperand(1);
+       CallInst *mradds = VectorUtils::getCallInst(VT, getAltiVecName("mradds", VT),
+                                                   factor0, factor1, addend,
+                                                   "mradds", adds);
+       adds->replaceAllUsesWith(mradds);
+       instructionsToDelete.insert(&BO);
+       instructionsToDelete.insert(add);
+       instructionsToDelete.insert(shr);
+       instructionsToDelete.insert(shrCast);
+       instructionsToDelete.insert(adds);
+       // The casts and the vimm may have other users; erase them only when
+       // the instructions scheduled above were their sole users. The order
+       // matters: VImm's user is addCast0.
+       scheduleIfDead(mulCast0, instructionsToDelete);
+       scheduleIfDead(mulCast1, instructionsToDelete);
+       scheduleIfDead(addCast0, instructionsToDelete);
+       scheduleIfDead(VImm, instructionsToDelete);
+       changed = true;
+       return;
+     }
+     // Otherwise fall back to the mladd pattern: mul followed by add.
+     //
+     CallInst *call = VectorUtils::getCallInst(VT, getAltiVecName("mladd", VT),
+                                               BO.getOperand(0), BO.getOperand(1),
+                                               (use->getOperand(0) == &BO) ? use->getOperand(1) : use->getOperand(0),
+                                               "mladd", use);
+     use->replaceAllUsesWith(call);
+     instructionsToDelete.insert(&BO);
+     instructionsToDelete.insert(use);
+     changed = true;
+     break;
+   }
+   case Instruction::Shr: {
+     // Check for madds pattern
+     //
+     if (!use->hasOneUse())
+       return;
+     ConstantUInt *shamt = dyn_cast<ConstantUInt>(use->getOperand(1));
+     if (!shamt || (shamt->getValue() != 15))
+       return;
+     CastInst *op0 = dyn_cast<CastInst>(BO.getOperand(0));
+     CastInst *op1 = dyn_cast<CastInst>(BO.getOperand(1));
+     CastInst *use2 = dyn_cast<CastInst>(*use->use_begin());
+     // Check op0 before reading its source type.
+     if (!op0 || !op1 || !use2)
+       return;
+     VT = dyn_cast<FixedVectorType>(op0->getOperand(0)->getType());
+     // The madds result replaces use2, so it must have use2's type.
+     if (!VT || use2->getType() != VT)
+       return;
+     CallInst *madds =
+       VectorUtils::getCallInst(VT, getAltiVecName("madds", VT),
+                                op0->getOperand(0), op1->getOperand(0),
+                                ConstantExpr::getCast(ConstantSInt::get(Type::IntTy, 0), VT),
+                                "madds", &BO);
+     use2->replaceAllUsesWith(madds);
+     instructionsToDelete.insert(&BO);
+     instructionsToDelete.insert(use);
+     instructionsToDelete.insert(use2);
+     scheduleIfDead(op0, instructionsToDelete);
+     scheduleIfDead(op1, instructionsToDelete);
+     changed = true;
+     break;
+   }
+   default:
+     break;
+   }
+ }
+
+ /// Replace a call to a Vector-LLVM intrinsic "vllvm_<op>_..." that returns
+ /// an AltiVec-compatible vector with a call to "altivec_<op>_<element
+ /// type>" taking the same arguments. Calls to vllvm_saturate* are left
+ /// alone; visitCombineInst matches them as part of the packsu pattern.
+ /// A vllvm_ name with no '_' after the prefix is a fatal error.
+ ///
+ void AltiVec::visitCallInst(CallInst &CI) {
+   const FixedVectorType *VT =
+     dyn_cast<FixedVectorType>(CI.getType());
+   if (!VT || !isProperType(VT)) return;
+   Function *callee = CI.getCalledFunction();
+   if (!callee) return;
+   std::string calleeName = callee->getName();
+   std::string prefix = calleeName.substr(0, 6);
+   if (prefix != "vllvm_") return;
+   // Keep find()'s result as size_type. Storing it in 'unsigned' truncates
+   // npos on LP64 hosts, so the npos check below would never fire.
+   std::string::size_type pos = calleeName.find("_", 6);
+   if (pos == std::string::npos) {
+     std::cerr << "Bad syntax for Vector-LLVM intrinsic " << calleeName << "\n";
+     exit(1);
+   }
+   std::string shortName = calleeName.substr(6, pos-6);
+   if (shortName == "saturate")
+     return;
+   std::vector<Value*> args;
+   for (unsigned i = 1; i < CI.getNumOperands(); ++i)
+     args.push_back(CI.getOperand(i));
+   CallInst *call = VectorUtils::getCallInst(VT, getAltiVecName(shortName, VT),
+                                             args, "vec_func", &CI);
+   CI.replaceAllUsesWith(call);
+   instructionsToDelete.insert(&CI);
+   changed = true;
+ }
+
+ }
Index: llvm/lib/Transforms/Vector/LowerVectors.cpp
diff -c /dev/null llvm/lib/Transforms/Vector/LowerVectors.cpp:1.1.2.1
*** /dev/null Tue Oct 18 14:37:14 2005
--- llvm/lib/Transforms/Vector/LowerVectors.cpp Tue Oct 18 14:37:03 2005
***************
*** 0 ****
--- 1,1059 ----
+ //===- LowerVectors.cpp - Lower vector operations -------------------------===//
+ //
+ // The LLVM Compiler Infrastructure
+ //
+ // This file was developed by the LLVM research group and is distributed under
+ // the University of Illinois Open Source License. See LICENSE.TXT for details.
+ //
+ //===----------------------------------------------------------------------===//
+ //
+ // This file lowers vector operations (such as vgather, vscatter, and
+ // vector arithmetic) to iterated scalar operations. This pass does
+ // NOT generate efficient code; it is intended for testing and
+ // debugging of Vector-LLVM.
+ //
+ //===----------------------------------------------------------------------===//
+
+ #define DEBUG_TYPE "lowervectors"
+
+ #include <sstream>
+
+ #include "llvm/Constants.h"
+ #include "llvm/DerivedTypes.h"
+ #include "llvm/Function.h"
+ #include "llvm/Instructions.h"
+ #include "llvm/Module.h"
+ #include "llvm/Pass.h"
+ #include "llvm/Type.h"
+ #include "llvm/Support/CFG.h"
+ #include "llvm/Support/Debug.h"
+ #include "llvm/Support/InstVisitor.h"
+ #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+ #include "llvm/ADT/hash_map"
+ #include "llvm/ADT/hash_set"
+ #include "llvm/Transforms/Scalar.h"
+ #include "VectorLLVM/Utils.h"
+
+ using namespace llvm;
+
+ namespace {
+
+ //===----------------------------------------------------------------------===//
+ // Class definitions
+ //===----------------------------------------------------------------------===//
+
+ /// Function pass that lowers Vector-LLVM operations to iterated scalar
+ /// operations (see the file header). There is one visit method per
+ /// handled instruction kind; reaching visitInstruction is a fatal error.
+ ///
+ class LowerVectors : public FunctionPass, public InstVisitor<LowerVectors> {
+
+ public:
+ bool doInitialization(Module &M);
+ bool runOnFunction(Function &F);
+ void visitVGatherInst(VGatherInst&);
+ void visitVScatterInst(VScatterInst&);
+ void visitCastInst(CastInst&);
+ void visitBinaryOperator(BinaryOperator&);
+ void visitVImmInst(VImmInst&);
+ void visitShiftInst(ShiftInst&);
+ void visitVSelectInst(VSelectInst&);
+ void visitExtractInst(ExtractInst&);
+ void visitCombineInst(CombineInst&);
+ void visitExtractElementInst(ExtractElementInst&);
+ void visitCombineElementInst(CombineElementInst&);
+ void visitPHINode(PHINode&);
+ void visitMallocInst(MallocInst&);
+ void visitFreeInst(FreeInst&);
+ void visitStoreInst(StoreInst&);
+ void visitLoadInst(LoadInst&);
+ // Fallback for instructions without a handler: report and exit.
+ void visitInstruction(Instruction& I) {
+ std::cerr << "LowerVectors class can't handle instruction " << I << "!\n";
+ exit(1);
+ }
+
+ // Shared realloc declaration. NOTE(review): presumably set up by
+ // doInitialization, whose definition is not visible here; confirm.
+ static Function *ReallocFunc;
+
+ private:
+ void lowerInstruction(Instruction&);
+ //void lowerInstructionToLoop(Instruction*,Value*);
+ };
+
+ // Register the pass with the pass registry under the name "lowervectors".
+ RegisterOpt<LowerVectors> X("lowervectors",
+ "Lower vector operations to iterated scalar operations");
+
+ /// Common base for lowering vector memory instructions (VMemoryInst) into
+ /// explicit loops. Subclasses supply the loop body via constructLoopBody.
+ ///
+ class VMemoryInstLowering {
+ public:
+   /// Virtual because the class has virtual member functions: makes
+   /// destroying a lowering object through a base pointer well-defined.
+   virtual ~VMemoryInstLowering() {}
+
+ protected:
+   BasicBlock *constructLoop(VMemoryInst*,BasicBlock*,std::vector<Value*>,Value*);
+   void constructInnerLoop(VMemoryInst*,BasicBlock*,BasicBlock*,BasicBlock*,
+                           BasicBlock*,std::vector<Value*>,Value*);
+   virtual void constructLoopBody(VMemoryInst*,BasicBlock*,
+                                  std::vector<Value*>,Value*) = 0;
+ };
+
+ /// Lowers a single vgather instruction. Constructing the object performs
+ /// the lowering (the constructor calls lowerVGather).
+ ///
+ class VGatherLowering : public VMemoryInstLowering {
+ public:
+ VGatherLowering(VGatherInst *VL) { lowerVGather(VL); }
+ void lowerVGather(VGatherInst*);
+
+ private:
+ // Emits the per-element body of the loop built by constructLoop.
+ void constructLoopBody(VMemoryInst*,BasicBlock*,
+ std::vector<Value*>,Value*);
+ };
+
+ /// Lowers a single vscatter instruction. Constructing the object performs
+ /// the lowering (the constructor calls lowerVScatter).
+ ///
+ class VScatterLowering : public VMemoryInstLowering {
+ public:
+ VScatterLowering(VScatterInst *VS) { lowerVScatter(VS); }
+ void lowerVScatter(VScatterInst*);
+
+ private:
+ // Emits the per-element body of the loop built by constructLoop.
+ void constructLoopBody(VMemoryInst*,BasicBlock*,
+ std::vector<Value*>,Value*);
+ };
+
+ /// Visitor that lowers one vector instruction. Constructing the object
+ /// performs the lowering (the constructor calls lowerInstruction). Any
+ /// instruction kind without a visit method is a fatal error.
+ ///
+ class InstructionLowering : public InstVisitor<InstructionLowering> {
+ friend class LowerVectors;
+
+ // NOTE(review): per-lowering state; the exact roles of these members are
+ // defined by code outside this chunk. Presumably 'vector' is the
+ // instruction being lowered, 'vectorIndex' the element index, 'idx' GEP
+ // indices and 'body' the loop body block; confirm.
+ Instruction *vector;
+ Value *result, *vectorIndex;
+ std::vector<Value*> idx;
+ BasicBlock *body;
+
+ public:
+ // 'length' and the optional 'ptr' are forwarded to lowerInstruction.
+ InstructionLowering(Instruction *I, Value *length, Value *ptr=0)
+ { lowerInstruction (I, length, ptr); }
+ void lowerInstruction(Instruction*,Value*,Value*);
+ void visitBinaryOperator(BinaryOperator&);
+ void visitVSelectInst(VSelectInst&);
+ void visitCastInst(CastInst&);
+ void visitShiftInst(ShiftInst&);
+ void visitExtractInst(ExtractInst&);
+ void visitCombineInst(CombineInst&);
+ void visitVImmInst(VImmInst&);
+ void visitLoadInst(LoadInst&);
+ void visitStoreInst(StoreInst&);
+ // Fallback for instructions without a handler: report and exit.
+ void visitInstruction(Instruction& I) {
+ std::cerr << "InstructionLowering class can't handle instruction " << I << "!\n";
+ exit(1);
+ }
+ };
+
+
+ //===----------------------------------------------------------------------===//
+ // Global data for this module
+ //===----------------------------------------------------------------------===//
+
+ /// Map from each vector value to the in-memory array that has been
+ /// allocated to hold the vector. setLoweredValue replaces (RAUW) and
+ /// deletes an existing entry when a new one is recorded, so an entry may
+ /// be a temporary placeholder.
+ ///
+ hash_map<Value*,Value*> loweringMap;
+
+ /// Map from each vector value to its length. Note that getLength inserts
+ /// a placeholder for a value whose length has not been recorded yet, so
+ /// presence in this map does not by itself mean the value has been
+ /// lowered. setLength replaces such a placeholder with the real length.
+ ///
+ /// NOTE(review): both maps are namespace-scope and are not cleared in the
+ /// code visible here; confirm they are reset between functions.
+ ///
+ hash_map<Value*,Value*> lengthMap;
+
+
+ //===----------------------------------------------------------------------===//
+ // Helper functions for managing vectors
+ //===----------------------------------------------------------------------===//
+
+ /// Null-initialize the lowered form of a freshly allocated value, emitting
+ /// code before 'before'. If the original allocated type was a vector, ptr
+ /// points at its lowered {T*, uint} pair and both fields are set to their
+ /// null value. Pointer types need no initialization. Any other type is a
+ /// fatal error.
+ ///
+ void initializeVectors(Value *ptr, const Type* originalAllocatedTy,
+                        Instruction *before) {
+   if (isa<PointerType>(originalAllocatedTy))
+     return;
+   if (!isa<VectorType>(originalAllocatedTy)) {
+     std::cerr << "Can't yet handle this type!\n";
+     exit(1);
+   }
+   // Zero field 0 (data pointer) and field 1 (length) of the pair.
+   for (unsigned field = 0; field != 2; ++field) {
+     std::vector<Value*> indices;
+     indices.push_back(ConstantUInt::get(Type::UIntTy, 0));
+     indices.push_back(ConstantUInt::get(Type::UIntTy, field));
+     GetElementPtrInst *fieldPtr =
+       new GetElementPtrInst(ptr, indices, "gep", before);
+     const Type *fieldTy =
+       cast<PointerType>(fieldPtr->getType())->getElementType();
+     new StoreInst(Constant::getNullValue(fieldTy), fieldPtr, before);
+   }
+ }
+
+ /// Initialize all {T*,uint} instances in a lowered type
+ /// corresponding to a [vector of T] that would be created by a
+ /// malloc or alloca with nulls
+ ///
+ /// Free the element array owned by the {T*,uint} descriptor that
+ /// replaces a [vector of T] object. visitFreeInst() calls this before
+ /// it frees the descriptor memory itself.
+ ///
+ /// \param ptr                 the lowered pointer being freed
+ /// \param originalAllocatedTy the pointee type before lowering
+ /// \param before              insertion point for the new instructions
+ ///
+ /// For a vector type, loads field 0 (the element pointer) of the
+ /// descriptor at index 0 of ptr and frees it. Pointer types own
+ /// nothing and are skipped; any other type aborts.
+ ///
+ void freeVectors(Value *ptr, const Type* originalAllocatedTy,
+                  Instruction *before) {
+   if (isa<VectorType>(originalAllocatedTy)) {
+     std::vector<Value*> Idx;
+     Idx.push_back(ConstantUInt::get(Type::UIntTy, 0));
+     Idx.push_back(ConstantUInt::get(Type::UIntTy, 0));
+     GetElementPtrInst *GEP =
+       new GetElementPtrInst(ptr, Idx, "gep", before);
+     LoadInst *elements = new LoadInst(GEP, "load", before);
+     new FreeInst(elements, before);
+   } else if (isa<PointerType>(originalAllocatedTy)) {
+     return;
+   } else {
+     std::cerr << "Can't yet handle this type!\n";
+     exit(1);
+   }
+ }
+
+ /// Add the specified <vector, length> pair to the lengthMap,
+ /// replacing dummy uses if necessary
+ ///
+ /// Record newValue as the length of vector key. Whatever value was
+ /// stored for key before (normally a placeholder handed out by
+ /// getLength()) has all of its uses redirected to newValue and is
+ /// then deleted.
+ ///
+ void setLength(Value *key, Value *newValue) {
+   Value *&slot = lengthMap[key];
+   if (Value *previous = slot) {
+     previous->replaceAllUsesWith(newValue);
+     delete previous;
+   }
+   slot = newValue;
+ }
+
+ /// Given a vector type value, look in the lengthMap to get its length.
+ ///
+ /// Return the length recorded for vector key. If none is recorded
+ /// yet, create and remember a placeholder uint Argument, to be
+ /// replaced later by setLength().
+ ///
+ Value *getLength(Value *key) {
+   Value *&slot = lengthMap[key];
+   if (slot)
+     return slot;
+   slot = new Argument(Type::UIntTy);
+   DEBUG(std::cerr << "Creating dummy length " << *slot << "\n");
+   return slot;
+ }
+
+ /// Add the specified pair to the loweringMap, replacing dummy uses
+ /// if necessary
+ ///
+ /// Record newValue as the lowered form of key, redirecting and then
+ /// deleting any previous entry (normally a placeholder from
+ /// getLoweredValue()). When length is given, it is recorded for key
+ /// through setLength().
+ ///
+ void setLoweredValue(Value *key, Value *newValue, Value *length = 0) {
+   Value *&slot = loweringMap[key];
+   if (Value *previous = slot) {
+     previous->replaceAllUsesWith(newValue);
+     delete previous;
+   }
+   slot = newValue;
+   if (length != 0)
+     setLength(key, length);
+ }
+
+ /// Lower an arbitary type of a pointer or an object stored in
+ /// memory. Lower a vector-derived type by replacing all instances
+ /// of [vector of T] with {T*,uint}, with some special rules for
+ /// vector-derived function types. "Lower" a non-vector-derived
+ /// type to the same type.
+ ///
+ /// Lower a type as it appears in memory (or as a pointee).
+ /// [vector of T] becomes the descriptor struct {T*, uint}; a pointer
+ /// type is lowered by lowering its pointee. Primitive types, and for
+ /// now every other type, are returned unchanged.
+ ///
+ const Type* getLoweredMemoryType(const Type* Ty) {
+   if (const VectorType *VT = dyn_cast<VectorType>(Ty)) {
+     std::vector<const Type*> fields;
+     fields.push_back(PointerType::get(VT->getElementType()));
+     fields.push_back(Type::UIntTy);
+     return StructType::get(fields);
+   }
+   if (Ty->isPrimitiveType())
+     return Ty;
+   if (const PointerType *PT = dyn_cast<PointerType>(Ty))
+     return PointerType::get(getLoweredMemoryType(PT->getElementType()));
+   return Ty; // For now
+ }
+
+ /// Lower the type of a first-class object stored in a virtual
+ /// register. Lower [vector of T] to T*, but lower [vector of T]*
+ /// to {T*,uint}*. "Lower" a non-vector-derived type to the same
+ /// type.
+ ///
+ /// Lower the type of a first-class value held in a virtual register.
+ /// [vector of T] becomes T* (a pointer to its element array), while a
+ /// pointer type is lowered with the memory rules, so [vector of T]*
+ /// becomes {T*,uint}*. All other types are unchanged.
+ ///
+ const Type* getLoweredRegisterType(const Type* Ty) {
+   assert (Ty->isFirstClassType() &&
+           "getLoweredRegisterType() should be called only on first-class types!");
+   const VectorType *VT = dyn_cast<VectorType>(Ty);
+   if (VT)
+     return PointerType::get(VT->getElementType());
+   return isa<PointerType>(Ty) ? getLoweredMemoryType(Ty) : Ty;
+ }
+
+ /// Given a value, return the corresponding lowered value.
+ /// Otherwise, look in the loweringMap to get the corresponding
+ /// scalar value. If there is no corresponding value, we create a
+ /// dummy value that will be filled in later when the operand is
+ /// lowered.
+ ///
+ /// Return the lowered counterpart of key recorded in loweringMap.
+ /// If key has not been lowered yet, create and remember a
+ /// placeholder Argument of key's lowered register type;
+ /// setLoweredValue() replaces it once key itself is lowered.
+ ///
+ Value *getLoweredValue(Value *key) {
+   Value*& value = loweringMap[key];
+   if (!value) {
+     value = new Argument(getLoweredRegisterType(key->getType()));
+     DEBUG(std::cerr << "Creating dummy lowered value " << *value << "\n");
+   }
+   return value;
+ }
+
+ /// Get a single lowered operand of an instruction. Each operand is
+ /// a value loaded from the specified position of the array that
+ /// stores the vector elements.
+ ///
+ /// Load one element of operand n of I: index the operand's lowered
+ /// element array with idx and load from the resulting address. Both
+ /// new instructions go before insertBefore.
+ ///
+ Value* getOp(Instruction* I, unsigned n, std::vector<Value*>& idx,
+              Instruction *insertBefore) {
+   Value *array = getLoweredValue(I->getOperand(n));
+   Value *address = new GetElementPtrInst(array, idx, "ptr", insertBefore);
+   return new LoadInst(address, "load", insertBefore);
+ }
+
+ /// Get the first n lowered operands of an instruction.
+ ///
+ /// Fill ops[0..n-1] with the elements of operands 0..n-1 of I at
+ /// position idx (see getOp()).
+ ///
+ void getOps(Instruction* I, Value* ops[], unsigned n,
+             std::vector<Value*>& idx, Instruction *insertBefore) {
+   unsigned opNo = 0;
+   while (opNo != n) {
+     ops[opNo] = getOp(I, opNo, idx, insertBefore);
+     ++opNo;
+   }
+ }
+
+ /// Allocate a vector. We use realloc in case the allocation
+ /// happens inside a loop; because LLVM is in SSA form, this is
+ /// guaranteed to be correct. For each possible allocation, we
+ /// store a null pointer on the stack. When the allocation is
+ /// actually done, we store the pointer there. At each exit point
+ /// of the function, for each allocation, we call free on the
+ /// (possibly null) pointer.
+ ///
+ /// In the case of a store instruction, we use the
+ /// previously-allocated pointer, so we don't need to create a new
+ /// pointer, and we don't need to add the free instruction.
+ ///
+ /// Emit a realloc of ArraySize elements of type Ty before
+ /// InsertBefore and return the resulting Ty* buffer.
+ ///
+ /// The previous buffer pointer lives in a slot. If the caller passes
+ /// ptr (a store reusing an existing buffer), that slot is used.
+ /// Otherwise a new stack slot is created at the top of the entry
+ /// block and set to null there.
+ ///
+ /// Reallocating through the slot means a re-executed instruction
+ /// (e.g. inside a loop) reuses its buffer. For a new slot, the buffer
+ /// is freed before every return present in the function at this
+ /// point.
+ ///
+ /// NOTE(review): the byte count uses Ty->getPrimitiveSize(), so Ty
+ /// is assumed to be a primitive type.
+ ///
+ Instruction *allocateVector(const Type* Ty, Value* ArraySize,
+                             const std::string& Name, Instruction *InsertBefore,
+                             Value *ptr=0) {
+   Function *F = InsertBefore->getParent()->getParent();
+   const Type *PointerTy = PointerType::get(Ty);
+
+   Value *slot = ptr;
+   if (!slot) {
+     Instruction *entryFront = &F->getEntryBlock().front();
+     slot = new AllocaInst(PointerTy, 0, "loadstore_ptr", entryFront);
+     new StoreInst(Constant::getNullValue(PointerTy), slot, entryFront);
+   }
+
+   const FunctionType *reallocTy = LowerVectors::ReallocFunc->getFunctionType();
+   const Type *bufferParamTy = reallocTy->getParamType(0);
+   const Type *sizeParamTy = reallocTy->getParamType(1);
+
+   // realloc argument 0: the previous buffer (null the first time).
+   Value *oldBuffer = new LoadInst(slot, "realloc_ptr", InsertBefore);
+   if (oldBuffer->getType() != bufferParamTy)
+     oldBuffer = new CastInst(oldBuffer, bufferParamTy, "cast", InsertBefore);
+
+   // realloc argument 1: element count times element size.
+   Value *elementSize = ConstantUInt::get(Type::UIntTy, Ty->getPrimitiveSize());
+   Value *byteCount = BinaryOperator::create(Instruction::Mul, ArraySize,
+                                             elementSize, "size", InsertBefore);
+   if (byteCount->getType() != sizeParamTy)
+     byteCount = new CastInst(byteCount, sizeParamTy, "cast", InsertBefore);
+
+   std::vector<Value*> args;
+   args.push_back(oldBuffer);
+   args.push_back(byteCount);
+   CallInst *call = new CallInst(LowerVectors::ReallocFunc, args, Name,
+                                 InsertBefore);
+
+   // A slot created here is released before each return.
+   if (!ptr)
+     for (Function::iterator BB = F->begin(), BE = F->end(); BB != BE; ++BB)
+       if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator()))
+         new FreeInst(new LoadInst(slot, "free_ptr", Ret), Ret);
+
+   // Hand back the buffer with the Ty* type the caller expects, and
+   // remember it in the slot for the next execution.
+   Instruction *buffer = call;
+   if (call->getType() != PointerTy)
+     buffer = new CastInst(call, PointerTy, "cast", InsertBefore);
+   new StoreInst(buffer, slot, InsertBefore);
+   return buffer;
+ }
+
+ /// Test whether a given instruction is a vector instruction that we
+ /// need to lower
+ ///
+ /// True if I must be lowered. That is the case for a vgather/vscatter,
+ /// and for any other instruction whose relevant type (operand 0 for
+ /// free and store, the result type otherwise) changes under
+ /// getLoweredMemoryType().
+ ///
+ bool isVectorInstruction(Instruction *I) {
+   if (isa<VMemoryInst>(I))
+     return true;
+   bool usesOperandType = isa<FreeInst>(I) || isa<StoreInst>(I);
+   const Type *Ty = usesOperandType ? I->getOperand(0)->getType()
+                                    : I->getType();
+   return getLoweredMemoryType(Ty) != Ty;
+ }
+
+
+ //===----------------------------------------------------------------------===//
+ // LowerVectors implementation
+ //===----------------------------------------------------------------------===//
+
+ Function *LowerVectors::ReallocFunc = 0;
+
+ /// Cache the module's realloc in ReallocFunc, declaring it as
+ /// sbyte* (sbyte*, uint) if the module does not already contain a
+ /// function of that name. Always returns false.
+ ///
+ bool LowerVectors::doInitialization(Module &M) {
+   const Type *SBPTy = PointerType::get(Type::SByteTy);
+   ReallocFunc = M.getNamedFunction("realloc");
+
+   // getOrInsertFunction reads its variadic parameter types with
+   // va_arg(..., const Type*) until it sees a null pointer. Pass a real
+   // null pointer: a literal 0 is an int, and reading it as a pointer
+   // is undefined and yields garbage on LP64 targets.
+   if (ReallocFunc == 0)
+     ReallocFunc = M.getOrInsertFunction("realloc", SBPTy, SBPTy,
+                                         Type::UIntTy, (Type *)0);
+
+   return false;
+ }
+
+ /// Lower every instruction of F that involves a vector, then delete
+ /// the originals. Returns true iff anything was lowered.
+ ///
+ bool LowerVectors::runOnFunction(Function &F) {
+   DEBUG(std::cerr << "\nrunOnFunction(" << F.getName() << ")\n");
+
+   loweringMap.clear();
+   lengthMap.clear();
+
+   // Collect first, so that the blocks and instructions created during
+   // lowering are not themselves visited.
+   std::vector<Instruction*> worklist;
+   for (Function::iterator BB = F.begin(), BBE = F.end(); BB != BBE; ++BB)
+     for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE; ++II)
+       if (VectorUtils::containsVector(II))
+         worklist.push_back(II);
+
+   typedef std::vector<Instruction*>::const_iterator WorkIter;
+
+   for (WorkIter W = worklist.begin(), WE = worklist.end(); W != WE; ++W)
+     lowerInstruction(**W);
+
+   // Drop every reference before erasing anything, because the original
+   // instructions may still use one another.
+   for (WorkIter W = worklist.begin(), WE = worklist.end(); W != WE; ++W)
+     (*W)->dropAllReferences();
+
+   for (WorkIter W = worklist.begin(), WE = worklist.end(); W != WE; ++W) {
+     DEBUG(std::cerr << "Removing instruction " << **W);
+     (*W)->getParent()->getInstList().erase(*W);
+   }
+
+   return !worklist.empty();
+ }
+
+ /// Lower a single instruction
+ ///
+ /// Lower a single instruction by dispatching to the matching visit*
+ /// method. In debug builds, first check that every vector-typed
+ /// operand is an Instruction; a correct Vector-LLVM program always
+ /// satisfies this.
+ ///
+ void LowerVectors::lowerInstruction(Instruction &I) {
+   for (User::const_op_iterator OI = I.op_begin(), OE = I.op_end();
+        OI != OE; ++OI)
+     assert((!isa<VectorType>((*OI)->getType()) || isa<Instruction>(*OI)) &&
+            "All operands of vector instructions must be instructions for this pass to work!");
+
+   DEBUG(std::cerr << "Lowering instruction " << I);
+   visit(I);
+ }
+
+ /// vgather/vscatter: all the work is done by the helper objects
+ /// (presumably in their constructors, as InstructionLowering does).
+ ///
+ void LowerVectors::visitVGatherInst(VGatherInst &VL) {
+   VGatherLowering gatherLowering(&VL);
+ }
+
+ void LowerVectors::visitVScatterInst(VScatterInst &VS) {
+   VScatterLowering scatterLowering(&VS);
+ }
+
+ /// vimm: the result length is given by operand 1.
+ ///
+ void LowerVectors::visitVImmInst(VImmInst &VL) {
+   Value *length = VL.getOperand(1);
+   InstructionLowering lowering(&VL, length);
+ }
+
+ /// cast: the result has the length of the source vector.
+ ///
+ void LowerVectors::visitCastInst(CastInst &CI) {
+   Value *length = getLength(CI.getOperand(0));
+   InstructionLowering lowering(&CI, length);
+ }
+
+ /// Binary operator: both operands must have equal lengths, and the
+ /// result takes the length of the left operand.
+ ///
+ void LowerVectors::visitBinaryOperator(BinaryOperator &BO) {
+   Value *lhsLength = getLength(BO.getOperand(0));
+   Value *rhsLength = getLength(BO.getOperand(1));
+   VectorUtils::ensureEquality(&BO, lhsLength, rhsLength);
+   InstructionLowering lowering(&BO, lhsLength);
+ }
+
+ /// Shift: the result has the length of the shifted vector
+ /// (operand 0).
+ ///
+ void LowerVectors::visitShiftInst(ShiftInst &SI) {
+   InstructionLowering lowering(&SI, getLength(SI.getOperand(0)));
+ }
+
+ /// vselect: request the lowered form of all three operands (creating
+ /// placeholders where needed), require that their lengths agree, and
+ /// lower with the length of operand 0.
+ ///
+ void LowerVectors::visitVSelectInst(VSelectInst &SI) {
+   for (unsigned i = 0; i != 3; ++i)
+     getLoweredValue(SI.getOperand(i));
+   Value *lengths[3];
+   for (unsigned i = 0; i != 3; ++i)
+     lengths[i] = getLength(SI.getOperand(i));
+   VectorUtils::ensureEquality(&SI, lengths[0], lengths[1]);
+   VectorUtils::ensureEquality(&SI, lengths[1], lengths[2]);
+   InstructionLowering lowering(&SI, lengths[0]);
+ }
+
+ /// extract: the result length is given explicitly by operand 3.
+ ///
+ void LowerVectors::visitExtractInst(ExtractInst &EI) {
+   InstructionLowering lowering(&EI, EI.getOperand(3));
+ }
+
+ /// combine: the result has the length of operand 0.
+ ///
+ void LowerVectors::visitCombineInst(CombineInst &CI) {
+   InstructionLowering lowering(&CI, getLength(CI.getOperand(0)));
+ }
+
+ /// extractelement: load element operand(1) directly out of the
+ /// lowered array of operand 0 and use it in place of EI.
+ ///
+ void LowerVectors::visitExtractElementInst(ExtractElementInst &EI) {
+   std::vector<Value*> position(1, EI.getOperand(1));
+   Value *element = getOp(&EI, 0, position, &EI);
+   EI.replaceAllUsesWith(element);
+ }
+
+ /// combineelement: store operand 1 at position operand 2 of the
+ /// source vector's array and reuse that array (and its length) as
+ /// the result.
+ ///
+ void LowerVectors::visitCombineElementInst(CombineElementInst &CI) {
+   Value *array = getLoweredValue(CI.getOperand(0));
+   std::vector<Value*> position(1, CI.getOperand(2));
+   Value *slot = new GetElementPtrInst(array, position, "ptr", &CI);
+   // Under the relaxed semantics the source storage may be updated in
+   // place.
+   new StoreInst(CI.getOperand(1), slot, &CI);
+   Value *length = getLength(CI.getOperand(0));
+   setLoweredValue(&CI, array, length);
+ }
+
+ /// Vector PHI: replace it with two PHIs, one over the lowered
+ /// arrays and one over the lengths, with the same incoming blocks.
+ ///
+ void LowerVectors::visitPHINode(PHINode &PN) {
+   assert(isa<VectorType>(PN.getType()) && "Instruction must be of vector type!");
+   const Type *arrayTy =
+     getLoweredRegisterType(PN.getIncomingValue(0)->getType());
+   PHINode *arrayPHI = new PHINode(arrayTy, "phi", &PN);
+   PHINode *lengthPHI = new PHINode(Type::UIntTy, "phi", arrayPHI);
+   for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) {
+     Value *incoming = PN.getIncomingValue(i);
+     BasicBlock *pred = PN.getIncomingBlock(i);
+     arrayPHI->addIncoming(getLoweredValue(incoming), pred);
+     lengthPHI->addIncoming(getLength(incoming), pred);
+   }
+   setLoweredValue(&PN, arrayPHI, lengthPHI);
+ }
+
+ /// malloc: allocate the lowered type instead, null-initialize any
+ /// vector descriptor in it, and record the new malloc.
+ ///
+ void LowerVectors::visitMallocInst(MallocInst &MI) {
+   const Type *originalTy = MI.getAllocatedType();
+   MallocInst *replacement =
+     new MallocInst(getLoweredMemoryType(originalTy), MI.getArraySize(),
+                    "malloc", &MI);
+   initializeVectors(replacement, originalTy, &MI);
+   setLoweredValue(&MI, replacement);
+ }
+
+ /// free: release the element arrays of any vector descriptor first,
+ /// then free the lowered pointer itself.
+ ///
+ void LowerVectors::visitFreeInst(FreeInst &FI) {
+   Value *original = FI.getOperand(0);
+   Value *lowered = getLoweredValue(original);
+   const Type *pointeeTy =
+     cast<PointerType>(original->getType())->getElementType();
+   freeVectors(lowered, pointeeTy, &FI);
+   setLoweredValue(&FI, new FreeInst(lowered, &FI));
+ }
+
+ /// store: a non-vector store is re-emitted over the lowered
+ /// operands. A vector store copies the elements into the buffer
+ /// reachable through field 0 of the destination descriptor, then
+ /// writes the length into field 1.
+ ///
+ void LowerVectors::visitStoreInst(StoreInst &SI) {
+   Value *stored = SI.getOperand(0);
+   if (!isa<VectorType>(stored->getType())) {
+     Value *val = getLoweredValue(stored);
+     Value *dest = getLoweredValue(SI.getOperand(1));
+     setLoweredValue(&SI, new StoreInst(val, dest, &SI));
+     return;
+   }
+
+   std::vector<Value*> fieldIdx;
+   fieldIdx.push_back(ConstantUInt::get(Type::UIntTy, 0));
+   fieldIdx.push_back(ConstantUInt::get(Type::UIntTy, 0));
+   Value *descriptor = getLoweredValue(SI.getOperand(1));
+   GetElementPtrInst *bufferField =
+     new GetElementPtrInst(descriptor, fieldIdx, "gep", &SI);
+   InstructionLowering lowering(&SI, getLength(stored), bufferField);
+
+   fieldIdx.back() = ConstantUInt::get(Type::UIntTy, 1);
+   GetElementPtrInst *lengthField =
+     new GetElementPtrInst(descriptor, fieldIdx, "gep", &SI);
+   new StoreInst(getLength(stored), lengthField, &SI);
+ }
+
+ /// load: a non-vector load is re-emitted over the lowered pointer.
+ /// A vector load reads the length from field 1 of the source
+ /// descriptor, and the InstructionLowering loop copies the elements.
+ ///
+ void LowerVectors::visitLoadInst(LoadInst &LI) {
+   Value *src = LI.getOperand(0);
+   const Type *pointeeTy = cast<PointerType>(src->getType())->getElementType();
+   if (!isa<VectorType>(pointeeTy)) {
+     setLoweredValue(&LI, new LoadInst(getLoweredValue(src), "load", &LI));
+     return;
+   }
+
+   std::vector<Value*> fieldIdx;
+   fieldIdx.push_back(ConstantUInt::get(Type::UIntTy, 0));
+   fieldIdx.push_back(ConstantUInt::get(Type::UIntTy, 1));
+   Value *descriptor = getLoweredValue(src);
+   GetElementPtrInst *lengthField =
+     new GetElementPtrInst(descriptor, fieldIdx, "gep", &LI);
+   LoadInst *length = new LoadInst(lengthField, "load", &LI);
+   InstructionLowering lowering(&LI, length);
+ }
+
+
+ //===----------------------------------------------------------------------===//
+ // InstructionLowering implementation
+ //===----------------------------------------------------------------------===//
+
+ /// Lower an instruction to a loop
+ ///
+ /// Lower the vector instruction I into an explicit loop over its
+ /// elements.
+ ///
+ /// \param I      the vector instruction; for a store, the stored
+ ///               operand carries the vector type
+ /// \param length number of elements to produce
+ /// \param ptr    if non-null, an existing slot holding the buffer
+ ///               pointer, which allocateVector() reallocates in place
+ ///
+ /// I's block is split into predecessor -> loop_header -> loop_body
+ /// -> loop_exit, with I left at the start of loop_exit. In the body,
+ /// the visit* method for I computes `result` for the current index,
+ /// and that value is stored into vector[index].
+ ///
+ void InstructionLowering::lowerInstruction(Instruction *I, Value *length,
+                                            Value *ptr = 0) {
+   const Type *Ty =
+     isa<StoreInst>(I) ? I->getOperand(0)->getType() : I->getType();
+   const VectorType *VectorTy = dyn_cast<VectorType>(Ty);
+   assert(VectorTy && "Instruction must be of vector type!");
+   const Type *elementType = VectorTy->getElementType();
+   vector = allocateVector(elementType, length, "vector", I, ptr);
+   setLoweredValue(I, vector, length);
+
+   // Set up the loop index. The stack slot is created in the entry
+   // block, like allocateVector()'s own slot, so it is allocated once
+   // per call even when I executes inside a loop. An alloca at I would
+   // grow the stack on every iteration. Placing it in the entry block
+   // also lets mem2reg, which only considers entry-block allocas,
+   // promote it. It is reset to zero right before the loop on every
+   // execution.
+   //
+   Instruction *entryFront =
+     &I->getParent()->getParent()->getEntryBlock().front();
+   Instruction *vectorIndexPtr =
+     new AllocaInst(Type::UIntTy, ConstantUInt::get(Type::UIntTy, 1),
+                    vector->getName() + ".index", entryFront);
+   new StoreInst(Constant::getNullValue(Type::UIntTy), vectorIndexPtr, I);
+
+   // Create the basic blocks of the loop
+   //
+   BasicBlock *header = I->getParent()->splitBasicBlock(I, "loop_header");
+   body = header->splitBasicBlock(I, "loop_body");
+   BasicBlock *exitBlock = body->splitBasicBlock(I, "loop_exit");
+
+   // The body loops back to the header, and the header enters the
+   // body while index < length.
+   //
+   body->getTerminator()->setSuccessor(0, header);
+   vectorIndex = new LoadInst(vectorIndexPtr, "index",
+                              header->getTerminator());
+   Instruction *cond =
+     BinaryOperator::create(Instruction::SetLT, vectorIndex, length,
+                            "setlt", header->getTerminator());
+   BasicBlock::iterator oldInst(header->getTerminator());
+   Instruction *newInst = new BranchInst(body, exitBlock, cond);
+   ReplaceInstWithInst(header->getInstList(), oldInst, newInst);
+
+   // Perform the operation in the body of the loop
+   //
+   idx.clear();
+   idx.push_back(vectorIndex);
+   visit(I);
+   Instruction *GEP = new GetElementPtrInst(vector, idx, "GEP",
+                                            body->getTerminator());
+   new StoreInst(result, GEP, body->getTerminator());
+
+   // Increment the loop counter
+   //
+   Value *incr = BinaryOperator::create(Instruction::Add, vectorIndex,
+                                        ConstantUInt::get(Type::UIntTy, 1),
+                                        "incr", body->getTerminator());
+   new StoreInst(incr, vectorIndexPtr, body->getTerminator());
+ }
+
+ /// Per-element body for a binary operator. A vector comparison is
+ /// emitted with its scalar opcode.
+ ///
+ void InstructionLowering::visitBinaryOperator(BinaryOperator &BO) {
+   Instruction *insertPt = body->getTerminator();
+   Value *operands[2];
+   getOps(&BO, operands, 2, idx, insertPt);
+   Instruction::BinaryOps opcode = BO.getOpcode();
+   SetCondInst *cmp = dyn_cast<SetCondInst>(&BO);
+   if (cmp)
+     opcode = cmp->getScalarOpcode();
+   result = BinaryOperator::create(opcode, operands[0], operands[1],
+                                   "binop", insertPt);
+ }
+
+ /// Per-element body for vselect: a scalar select over the three
+ /// operand elements (condition, true value, false value).
+ ///
+ void InstructionLowering::visitVSelectInst(VSelectInst &SI) {
+   Instruction *insertPt = body->getTerminator();
+   Value *operands[3];
+   getOps(&SI, operands, 3, idx, insertPt);
+   result = new SelectInst(operands[0], operands[1], operands[2],
+                           "select", insertPt);
+ }
+
+ /// Per-element body for a vector cast: cast the source element to
+ /// the destination vector's element type.
+ ///
+ void InstructionLowering::visitCastInst(CastInst &CI) {
+   const VectorType *resultTy = dyn_cast<VectorType>(CI.getType());
+   assert(resultTy && "Cast instruction must have vector type!");
+   Instruction *insertPt = body->getTerminator();
+   Value *element = getOp(&CI, 0, idx, insertPt);
+   result = new CastInst(element, resultTy->getElementType(), "cast",
+                         insertPt);
+ }
+
+ /// Per-element body for a shift. The shift amount (operand 1) is
+ /// used unchanged for every element.
+ ///
+ void InstructionLowering::visitShiftInst(ShiftInst &SI) {
+   Instruction *insertPt = body->getTerminator();
+   Value *element = getOp(&SI, 0, idx, insertPt);
+   result = new ShiftInst(SI.getOpcode(), element, SI.getOperand(1),
+                          "shift", insertPt);
+ }
+
+ /// Per-element body for extract: read the source vector (operand 0)
+ /// at position operand1 + index * operand2 (presumably start and
+ /// stride).
+ ///
+ void InstructionLowering::visitExtractInst(ExtractInst &EI) {
+   Instruction *insertPt = body->getTerminator();
+   Value *scaled = BinaryOperator::create(Instruction::Mul, vectorIndex,
+                                          EI.getOperand(2), "mul", insertPt);
+   Value *srcPos = BinaryOperator::create(Instruction::Add, EI.getOperand(1),
+                                          scaled, "add", insertPt);
+   std::vector<Value*> srcIdx(1, srcPos);
+   result = getOp(&EI, 0, srcIdx, insertPt);
+ }
+
+ /// Per-element body for combine: produce, for the current
+ /// vectorIndex, either the element of v1 (operand 0) at vectorIndex
+ /// or an element of v2 (operand 1). Operand 2 is used as the start
+ /// position and operand 3 as the stride.
+ ///
+ void InstructionLowering::visitCombineInst(CombineInst &CI) {
+ // Here we are generating code for
+ //
+ // %tmp = combine v1, v2, start, stride
+ //
+ // Result element i is v2[(i - start) / stride] when
+ // start <= i < start + stride * getLength(v2) and (i - start) is a
+ // multiple of stride; otherwise it is v1[i].
+ //
+ // First we compute secondIndex, the index into v2. If start <=
+ // vectorIndex < start + stride * getLength(v2), then
+ // secondIndex = (vectorIndex - start) / stride. Otherwise, we
+ // won't use the value from the second vector, so we set
+ // secondIndex = 0 (a safe value that won't cause an illegal
+ // load). When vectorIndex < start the unsigned subtraction below
+ // wraps, but that result is only used where inRange is false, so it
+ // is discarded.
+ //
+ // NOTE(review): a stride of 0 makes the Div/Rem below divide by
+ // zero; callers presumably guarantee a nonzero stride.
+ //
+ Value *secondLength =
+ getLength(CI.getOperand(1));
+ Instruction *mul =
+ BinaryOperator::create(Instruction::Mul, CI.getOperand(3),
+ secondLength,
+ "mul", body->getTerminator());
+ Instruction *add =
+ BinaryOperator::create(Instruction::Add, CI.getOperand(2),
+ mul,
+ "add", body->getTerminator());
+ Instruction *compare1 =
+ BinaryOperator::create(Instruction::SetGE, vectorIndex,
+ CI.getOperand(2),
+ "compare", body->getTerminator());
+ Instruction *compare2 =
+ BinaryOperator::create(Instruction::SetLT, vectorIndex,
+ add,
+ "compare", body->getTerminator());
+ Instruction *inRange =
+ BinaryOperator::create(Instruction::And, compare1,
+ compare2,
+ "inRange", body->getTerminator());
+ Instruction *sub =
+ BinaryOperator::create(Instruction::Sub, vectorIndex,
+ CI.getOperand(2),
+ "secondIndex", body->getTerminator());
+ Instruction *div =
+ BinaryOperator::create(Instruction::Div, sub,
+ CI.getOperand(3),
+ "secondIndex", body->getTerminator());
+ Value *secondIndex =
+ new SelectInst(inRange, div,
+ ConstantUInt::get(Type::UIntTy, 0),
+ "select", body->getTerminator());
+
+ // Get the value out of v1 at vectorIndex and the value out of
+ // v2 at secondIndex
+ //
+ Value *firstValue = getOp(&CI, 0, idx, body->getTerminator());
+ std::vector<Value*> idx2;
+ idx2.push_back(secondIndex);
+ Value *secondValue = getOp(&CI, 1, idx2, body->getTerminator());
+
+ // Use the second value if vectorIndex is in range and
+ // (vectorIndex - start) is 0 mod the stride; otherwise use the
+ // first value
+ //
+ Value *select =
+ new SelectInst(inRange, secondValue, firstValue,
+ "select", body->getTerminator());
+ Value *rem =
+ BinaryOperator::create(Instruction::Rem, sub,
+ CI.getOperand(3), "rem",
+ body->getTerminator());
+ Value *compare =
+ BinaryOperator::create(Instruction::SetEQ, rem,
+ ConstantUInt::get(Type::UIntTy, 0),
+ "compare", body->getTerminator());
+ result =
+ new SelectInst(compare, select, firstValue,
+ "select", body->getTerminator());
+ }
+
+ /// Per-element body for vimm: every element is operand 0.
+ ///
+ void InstructionLowering::visitVImmInst(VImmInst &VL) {
+   Value *immediate = VL.getOperand(0);
+   result = immediate;
+ }
+
+ /// Per-element body for a vector load: fetch the buffer pointer from
+ /// field 0 of the source descriptor, then load the element at idx.
+ ///
+ void InstructionLowering::visitLoadInst(LoadInst &LI) {
+   Instruction *insertPt = body->getTerminator();
+   Value *descriptor = getLoweredValue(LI.getOperand(0));
+   std::vector<Value*> bufferField(2, ConstantUInt::get(Type::UIntTy, 0));
+   GetElementPtrInst *bufferAddr =
+     new GetElementPtrInst(descriptor, bufferField, "gep", insertPt);
+   LoadInst *buffer = new LoadInst(bufferAddr, "load", insertPt);
+   GetElementPtrInst *elementAddr =
+     new GetElementPtrInst(buffer, idx, "gep", insertPt);
+   result = new LoadInst(elementAddr, "load", insertPt);
+ }
+
+ /// Per-element body for a vector store: the value written is element
+ /// idx of the stored vector (operand 0).
+ ///
+ void InstructionLowering::visitStoreInst(StoreInst &SI) {
+   Instruction *insertPt = body->getTerminator();
+   result = getOp(&SI, 0, idx, insertPt);
+ }
+
+
+ //===----------------------------------------------------------------------===//
+ // VMemoryInstLowering implementation
+ //===----------------------------------------------------------------------===//
+
+ /// Construct a loop over the array elements defined by the indices
+ /// of a vgather or vscatter instruction. The resulting loop has
+ /// one nest for each set of four indices. This function calls the
+ /// template function constructLoopBody(), which is defined
+ /// differently for vgather and vscatter.
+ ///
+ /// Build the loop nest for a vgather/vscatter, one level per group of
+ /// four indices of VI. arrayIndices holds the index PHIs of the
+ /// levels already built.
+ ///
+ /// At the innermost level, the successor of oldHeader becomes the
+ /// loop body. It is filled by constructLoopBody() and, when
+ /// vectorIndex is given, followed by an increment of *vectorIndex.
+ /// Otherwise the successor is split into header/body/exit and
+ /// constructInnerLoop() recurses.
+ ///
+ /// Returns the body at the innermost level, or the new loop_exit
+ /// block otherwise.
+ ///
+ BasicBlock* VMemoryInstLowering::constructLoop(VMemoryInst *VI, BasicBlock *oldHeader,
+                                                std::vector<Value*> arrayIndices,
+                                                Value* vectorIndex) {
+   unsigned loopLevel = arrayIndices.size();
+   BasicBlock *next = oldHeader->getTerminator()->getSuccessor(0);
+
+   if (VI->getNumIndices() != 4 * loopLevel) {
+     next->setName("loop_header");
+     BasicBlock *bodyBegin = next->splitBasicBlock(next->begin());
+     BasicBlock *exitBB =
+       bodyBegin->splitBasicBlock(bodyBegin->begin(), "loop_exit");
+     bodyBegin->getTerminator()->setSuccessor(0, next);
+     constructInnerLoop(VI, oldHeader, next, bodyBegin, exitBB,
+                        arrayIndices, vectorIndex);
+     return exitBB;
+   }
+
+   next->setName("loop_body");
+   constructLoopBody(VI, next, arrayIndices, vectorIndex);
+   if (vectorIndex) {
+     Instruction *insertPt = next->getTerminator();
+     Value *current = new LoadInst(vectorIndex, "load", insertPt);
+     Value *advanced =
+       BinaryOperator::create(Instruction::Add, current,
+                              ConstantSInt::get(Type::LongTy, 1),
+                              "add", insertPt);
+     new StoreInst(advanced, vectorIndex, insertPt);
+   }
+   return next;
+ }
+
+ /// Construct an inner loop, recursively calling constructLoop
+ ///
+ /// Build one level of a vgather/vscatter loop nest. A PHI in header
+ /// starts at VI's lower bound for this level (entering from
+ /// oldHeader) and advances by VI's stride at the end of the body. The
+ /// body is filled in recursively by constructLoop().
+ ///
+ /// The loop continues while the index is <= the upper bound for a
+ /// non-negative stride, and while it is >= the upper bound for a
+ /// negative stride.
+ ///
+ void VMemoryInstLowering::constructInnerLoop(VMemoryInst *VI, BasicBlock *oldHeader,
+                                              BasicBlock *header, BasicBlock *bodyBegin,
+                                              BasicBlock *exit, std::vector<Value*> arrayIndices,
+                                              Value *vectorIndex) {
+   // Fill in the loop header
+   //
+   unsigned loopLevel = arrayIndices.size();
+   PHINode *arrayPHI = new PHINode(Type::LongTy, "phi",
+                                   header->getTerminator());
+   arrayPHI->addIncoming(VI->getLowerBound(loopLevel), oldHeader);
+
+   // Fill in the body or split it to make a new loop
+   //
+   arrayIndices.push_back(arrayPHI);
+   BasicBlock *bodyEnd = constructLoop(VI, header, arrayIndices, vectorIndex);
+
+   // Fill in the increment and branch instructions at the end of the
+   // loop
+   //
+   Instruction *arrayIncr =
+     BinaryOperator::create(Instruction::Add, arrayPHI, VI->getStride(loopLevel),
+                            "add", bodyEnd->getTerminator());
+   arrayPHI->addIncoming(arrayIncr, bodyEnd);
+   Instruction *cond1 =
+     BinaryOperator::create(Instruction::SetLE, arrayPHI, VI->getUpperBound(loopLevel),
+                            "setle", header->getTerminator());
+   Instruction *cond2 =
+     BinaryOperator::create(Instruction::SetGE, arrayPHI, VI->getUpperBound(loopLevel),
+                            "setge", header->getTerminator());
+   Instruction *positiveStride =
+     BinaryOperator::create(Instruction::SetGE, VI->getStride(loopLevel),
+                            ConstantSInt::get(Type::LongTy, 0),
+                            "setge", header->getTerminator());
+   Instruction *select =
+     new SelectInst(positiveStride, cond1, cond2, "select", header->getTerminator());
+   BasicBlock::iterator oldInst(header->getTerminator());
+   Instruction *newInst = new BranchInst(bodyBegin, exit, select);
+   ReplaceInstWithInst(header->getInstList(), oldInst, newInst);
+ }
+
+
+ //===----------------------------------------------------------------------===//
+ // VGatherLowering implementation
+ //===----------------------------------------------------------------------===//
+
+ /// Lower a vgather instruction: Create an array to hold the vector
+ /// contents, then copy the indexed memory locations to that array
+ ///
+ /// Lower a vgather instruction. Allocate an array for the vector
+ /// contents, then build a loop nest that copies the indexed memory
+ /// locations into it.
+ ///
+ void VGatherLowering::lowerVGather(VGatherInst *VL) {
+
+   // Allocate the array. VL is renamed first, and the realloc
+   // created below receives VL's original name.
+   //
+   std::string name = VL->getName();
+   Value *length =
+     VectorUtils::computeIndexedLength(VL, VL, name + std::string(".length"));
+   VL->setName(name + std::string(".vector"));
+   Instruction *vector = allocateVector(VL->getElementType(), length,
+                                        name, VL);
+   setLoweredValue(VL, vector, length);
+
+   // Generate a loop to copy the contents. The running vector index
+   // lives in a stack slot created in the entry block, so it is
+   // allocated once per call even when VL executes inside a loop (and
+   // mem2reg can promote it). It is reset to zero before every copy.
+   //
+   BasicBlock *header = VL->getParent();
+   Instruction *entryFront = &header->getParent()->getEntryBlock().front();
+   Value *vectorIndex = new AllocaInst(Type::LongTy, ConstantUInt::get(Type::UIntTy, 1),
+                                       name + ".index", entryFront);
+   new StoreInst(Constant::getNullValue(Type::LongTy), vectorIndex, VL);
+   header->splitBasicBlock(VL);
+   std::vector<Value*> arrayIndices;
+   constructLoop(VL, header, arrayIndices, vectorIndex);
+
+ }
+
+ /// Generate code to take elements from the indexed memory locations
+ /// and put them into the array that holds the vector
+ ///
+ /// Innermost vgather body: load the indexed source element and
+ /// store it at position *vectorIndexPtr of the vector's array.
+ ///
+ void VGatherLowering::constructLoopBody(VMemoryInst *VI, BasicBlock *body,
+                                         std::vector<Value*> arrayIndices,
+                                         Value* vectorIndexPtr) {
+   Instruction *insertPt = body->getTerminator();
+   Instruction *srcAddr =
+     VectorUtils::computeFlattenedPointer(VI, arrayIndices, insertPt);
+   Instruction *element = new LoadInst(srcAddr, "load", insertPt);
+
+   Value *pos = new LoadInst(vectorIndexPtr, "index", insertPt);
+   std::vector<Value*> dstIdx(1, pos);
+   Instruction *dstAddr =
+     new GetElementPtrInst(getLoweredValue(VI), dstIdx, "ptr", insertPt);
+   new StoreInst(element, dstAddr, insertPt);
+ }
+
+
+ //===----------------------------------------------------------------------===//
+ // VScatterLowering implementation
+ //===----------------------------------------------------------------------===//
+
+ /// Lower a vstore instruction. Find the array that holds the
+ /// vector contents, then copy that array to the indexed memory
+ /// locations.
+ ///
+ /// Lower a vscatter instruction. Find the array that holds the
+ /// vector contents, check that its length matches the indexed slice,
+ /// and build a loop nest that copies the array to the indexed memory
+ /// locations.
+ ///
+ void VScatterLowering::lowerVScatter(VScatterInst *VS) {
+   Value *V = getLoweredValue(VS->getOperand(0));
+   setLoweredValue(VS, V);
+
+   // Make sure the vector length agrees with the indexed array slice
+   //
+   Value *indexedLength =
+     VectorUtils::computeIndexedLength(VS, VS,
+                                       VS->getPointerOperand()->getName()
+                                       + std::string(".length"));
+   Value *operandLength = getLength(VS->getOperand(0));
+   VectorUtils::ensureEquality(VS, indexedLength, operandLength);
+
+   std::vector<Value*> arrayIndices;
+   BasicBlock *header = VS->getParent();
+
+   // Create a vector index to track the indexed position in the
+   // vector. Its stack slot is created in the entry block, so it is
+   // allocated once per call even when VS executes inside a loop (and
+   // mem2reg can promote it). It is reset to zero before every copy.
+   //
+   Instruction *entryFront = &header->getParent()->getEntryBlock().front();
+   Value *vectorIndex = new AllocaInst(Type::LongTy, ConstantUInt::get(Type::UIntTy, 1),
+                                       V->getName() + ".index", entryFront);
+   new StoreInst(Constant::getNullValue(Type::LongTy), vectorIndex, VS);
+
+   header->splitBasicBlock(VS);
+   constructLoop(VS, header, arrayIndices, vectorIndex);
+ }
+
+ /// Generate code to take elements from the array allocated for the
+ /// vector and put them into the indexed memory locations
+ ///
+ /// Innermost vscatter body: load the element at position
+ /// *vectorIndexPtr of the vector's array and store it to the indexed
+ /// destination location.
+ ///
+ void VScatterLowering::constructLoopBody(VMemoryInst *VI, BasicBlock *body,
+                                          std::vector<Value*> arrayIndices,
+                                          Value* vectorIndexPtr) {
+   Instruction *insertPt = body->getTerminator();
+   Value *pos = new LoadInst(vectorIndexPtr, "index", insertPt);
+   std::vector<Value*> srcIdx(1, pos);
+   Instruction *srcAddr =
+     new GetElementPtrInst(getLoweredValue(VI), srcIdx, "ptr", insertPt);
+   Value *element = new LoadInst(srcAddr, "load", insertPt);
+
+   Instruction *dstAddr =
+     VectorUtils::computeFlattenedPointer(VI, arrayIndices, insertPt);
+   new StoreInst(element, dstAddr, insertPt);
+ }
+
+ }
Index: llvm/lib/Transforms/Vector/RaiseVectors.cpp
diff -c /dev/null llvm/lib/Transforms/Vector/RaiseVectors.cpp:1.1.2.1
*** /dev/null Tue Oct 18 14:37:14 2005
--- llvm/lib/Transforms/Vector/RaiseVectors.cpp Tue Oct 18 14:37:03 2005
***************
*** 0 ****
--- 1,594 ----
+ //===- RaiseVectors.cpp - Raise significant functions to Vector-LLVM ------===//
+ //
+ // The LLVM Compiler Infrastructure
+ //
+ // This file was developed by the LLVM research group and is distributed under
+ // the University of Illinois Open Source License. See LICENSE.TXT for details.
+ //
+ //===----------------------------------------------------------------------===//
+ //
+ // This file raises the Vector-C significant functions to Vector-LLVM.
+ //
+ //===----------------------------------------------------------------------===//
+
+ #define DEBUG_TYPE "raisevectors"
+
+ #include "llvm/Constants.h"
+ #include "llvm/DerivedTypes.h"
+ #include "llvm/Function.h"
+ #include "llvm/Instructions.h"
+ #include "llvm/Pass.h"
+ #include "llvm/Type.h"
+ #include "llvm/Support/Debug.h"
+ #include "llvm/ADT/hash_map"
+ #include "llvm/ADT/hash_set"
+ #include "llvm/ADT/STLExtras.h"
+ #include "llvm/Support/InstVisitor.h"
+ #include "VectorLLVM/VectorSignificantFunctions.h"
+ #include "VectorLLVM/Utils.h"
+
+ using namespace llvm;
+
+ namespace {
+
+
+ //===----------------------------------------------------------------------===//
+ // Class definitions
+ //===----------------------------------------------------------------------===//
+
+ /// RaiseVectors - Function pass that replaces calls to Vector-C
+ /// significant functions, and the scalar instructions that consume
+ /// their results, with Vector-LLVM instructions. Raising starts from
+ /// the defining calls found by addDefsToWorklist and spreads to their
+ /// users through a worklist.
+ ///
+ class RaiseVectors : public FunctionPass, public InstVisitor<RaiseVectors> {
+
+ public:
+ /// Checks that every significant function is properly declared.
+ virtual bool doInitialization(Module &M);
+ /// Raises one function; returns true iff it was modified.
+ virtual bool runOnFunction(Function &F);
+ // InstVisitor callbacks: each raises one kind of instruction.
+ void visitCallInst(CallInst&);
+ void visitBinaryOperator(BinaryOperator&);
+ void visitCastInst(CastInst&);
+ void visitShiftInst(ShiftInst&);
+ void visitSelectInst(SelectInst&);
+ void visitPHINode(PHINode&);
+ /// Fallback for any instruction kind without a visitor above:
+ /// report the instruction and terminate the process.
+ void visitInstruction(Instruction& I) {
+ std::cerr << "RaiseVectors: Unhandled instruction " << I;
+ exit(1);
+ }
+
+ private:
+ /// Map from original value to vector instruction. Until a value is
+ /// actually raised, its entry may hold a placeholder Argument created
+ /// by getRaisedValue; setRaisedValue replaces the placeholder.
+ ///
+ hash_map<Value*,Value*> raisingMap;
+
+ /// Worklist of instructions to raise
+ ///
+ std::vector<Instruction*> workList;
+
+ /// Set of instructions that we have raised (each is visited at most
+ /// once and all of them are erased by deleteRaisedInstructions)
+ ///
+ hash_set<Instruction*> raisedInstructions;
+
+ // Helpers; see the definitions for details.
+ unsigned getVectorLength(Instruction*);
+ const Type *getRaisedType(const Type*,unsigned);
+ void setRaisedValue(Instruction*,Value*);
+ Value *getRaisedValue(Value*,unsigned);
+ Value *getRaisedOperand(Instruction*,unsigned);
+ bool addDefsToWorklist(Function&);
+ void addUsesToWorklist(Instruction*);
+ void raiseInstructions();
+ void deleteRaisedInstructions();
+
+ };
+
+ // Register the pass under the name "raisevectors".
+ RegisterOpt<RaiseVectors> X("raisevectors",
+ "Raise Vector-C significant functions to Vector-LLVM");
+
+
+ //===----------------------------------------------------------------------===//
+ // RaiseVectors implementation
+ //===----------------------------------------------------------------------===//
+
+ /// Before any function is processed, require that every Vector-C
+ /// significant function in the module is properly declared.
+ /// llvm-gcc accepts calls to undeclared functions (treating them as
+ /// int(...)), so this cannot be taken for granted. A violation is a
+ /// fatal error.
+ ///
+ /// The module is never modified, so the result is always false.
+ ///
+ bool RaiseVectors::doInitialization(Module &M) {
+   Module::iterator Fn = M.begin();
+   Module::iterator FnEnd = M.end();
+   while (Fn != FnEnd) {
+     if (VectorSignificantFunctions::isProperlyDeclared(Fn)) {
+       ++Fn;
+       continue;
+     }
+     std::cerr << "Significant function " << Fn->getName()
+               << " was not declared or was improperly declared.\n";
+     exit(1);
+   }
+   return false;
+ }
+
+ /// Entry point invoked by the PassManager for each function.
+ ///
+ /// Resets all per-function state and seeds the worklist with the
+ /// significant-function calls that define vectors. Only if at least
+ /// one such call exists does it raise everything reachable from them
+ /// and then remove the original instructions.
+ ///
+ /// Returns true iff the function was modified.
+ ///
+ bool RaiseVectors::runOnFunction(Function &F) {
+   DEBUG(std::cerr << "\nrunOnFunction(" << F.getName() << ")\n");
+
+   raisingMap.clear();
+   workList.clear();
+   raisedInstructions.clear();
+
+   if (!addDefsToWorklist(F))
+     return false;
+
+   raiseInstructions();
+   deleteRaisedInstructions();
+   return true;
+ }
+
+ /// Seed the worklist with every direct call in F to a significant
+ /// function that defines a vector value: vload, vgather, vloadi,
+ /// vimm, fixed_vimm, load or constant.
+ ///
+ /// Returns true iff at least one such call was found.
+ ///
+ bool RaiseVectors::addDefsToWorklist(Function& F) {
+   bool foundDef = false;
+   for (Function::iterator BB = F.begin(), BBEnd = F.end(); BB != BBEnd; ++BB) {
+     for (BasicBlock::iterator Inst = BB->begin(), InstEnd = BB->end();
+          Inst != InstEnd; ++Inst) {
+       CallInst *Call = dyn_cast<CallInst>(Inst);
+       if (!Call)
+         continue;
+       Function *Callee = Call->getCalledFunction();
+       if (!Callee)
+         continue;
+       switch (VectorSignificantFunctions::getID(Callee->getName())) {
+       case VectorSignificantFunctions::vload:
+       case VectorSignificantFunctions::vgather:
+       case VectorSignificantFunctions::vloadi:
+       case VectorSignificantFunctions::vimm:
+       case VectorSignificantFunctions::fixed_vimm:
+       case VectorSignificantFunctions::load:
+       case VectorSignificantFunctions::constant:
+         workList.push_back(Call);
+         foundDef = true;
+         break;
+       default:
+         break;
+       }
+     }
+   }
+   return foundDef;
+ }
+
+ /// Drain the worklist, visiting each instruction at most once. The
+ /// visitors push the users of whatever they raise, so processing
+ /// continues until everything reachable from the seeded definitions
+ /// has been handled.
+ ///
+ /// Afterwards, any raisingMap entry that still holds a placeholder
+ /// Argument (created by getRaisedValue and never replaced) means a
+ /// scalar value was combined directly with a vector. That is reported
+ /// as a fatal error.
+ ///
+ void RaiseVectors::raiseInstructions() {
+   while (!workList.empty()) {
+     Instruction *Inst = workList.back();
+     workList.pop_back();
+     // Already raised: skip.
+     if (!raisedInstructions.insert(Inst).second)
+       continue;
+     DEBUG(std::cerr << "Raising " << *Inst);
+     visit(*Inst);
+     DEBUG(if (raisingMap[Inst]) {std::cerr << "Raised value is " << *raisingMap[Inst];});
+   }
+
+   for (hash_map<Value*,Value*>::iterator Entry = raisingMap.begin(),
+          EntryEnd = raisingMap.end(); Entry != EntryEnd; ++Entry) {
+     Value *raised = Entry->second;
+     if (!raised || !isa<Argument>(raised))
+       continue;
+     std::cerr << "Value was never raised!\n";
+     std::cerr << *(Entry->first) << "\n";
+     std::cerr << "This is because you used a scalar value in a vector operation.\n";
+     std::cerr << "Use vimm to promote scalars to vectors "
+               << "before combining them with vectors.\n";
+     exit(1);
+   }
+ }
+
+ /// Erase every instruction recorded in raisedInstructions.
+ ///
+ /// This happens in two passes. First every instruction drops its
+ /// operand references, and only then are the instructions erased.
+ /// Raised instructions may use one another, so this ordering
+ /// guarantees none of them is erased while another still uses it.
+ ///
+ void RaiseVectors::deleteRaisedInstructions() {
+   typedef hash_set<Instruction*>::iterator SetIter;
+   SetIter Begin = raisedInstructions.begin();
+   SetIter End = raisedInstructions.end();
+
+   // Pass 1: break the def-use edges among the doomed instructions.
+   for (SetIter It = Begin; It != End; ++It) {
+     Instruction *Doomed = *It;
+     DEBUG(std::cerr << "Dropping all references from " << *Doomed);
+     Doomed->dropAllReferences();
+   }
+
+   // Pass 2: unlink each one from its basic block and delete it.
+   for (SetIter It = Begin; It != End; ++It) {
+     Instruction *Doomed = *It;
+     Doomed->getParent()->getInstList().erase(Doomed);
+   }
+ }
+
+ /// Raise a call to a Vector-C significant function.
+ ///
+ /// Each recognised significant function is replaced by Vector-LLVM
+ /// code inserted immediately before the call:
+ ///   vload / vgather        -> VGatherInst (indices cast to long)
+ ///   load                   -> pointer cast + GEP + LoadInst
+ ///   vimm / vloadi          -> VImmInst, cast to a VectorType if needed
+ ///   constant               -> ConstantVector of the constant operands
+ ///   fixed_vimm             -> VImmInst, cast to a FixedVectorType if needed
+ ///   vstore / vscatter      -> VScatterInst
+ ///   store                  -> pointer cast + GEP + StoreInst
+ ///   vselect                -> VSelectInst
+ ///   extract                -> ExtractInst
+ ///   combine, fixed_combine -> CombineInst
+ ///   fixed_permute          -> call to vllvm_permute_<element type>
+ ///   extractelement         -> ExtractElementInst; the call's uses are
+ ///                             rewired to it directly
+ ///   combineelement         -> CombineElementInst
+ /// Any other callee whose name starts with "vectorc" or "vllvm" is
+ /// replaced by a call to a vector version of that function
+ /// ("vectorc<x>" -> "vllvm<x>", "vllvm<x>" -> "vllvm<x>_vector").
+ /// Indirect calls and unrecognised callees are fatal errors.
+ ///
+ /// Except for the store-like calls and extractelement, the users of
+ /// the call are pushed onto the worklist so that they are raised too.
+ ///
+ void RaiseVectors::visitCallInst(CallInst &CI) {
+   Function *F = CI.getCalledFunction();
+   if (!F) {
+     std::cerr << "Can't handle indirect function call " << CI;
+     exit(1);
+   }
+   std::string name = F->getName();
+   // Every case assigns raisedValue or exits. It is initialised anyway
+   // so that no path can read an indeterminate pointer.
+   Value *raisedValue = 0;
+   switch(VectorSignificantFunctions::getID(name)) {
+   case VectorSignificantFunctions::vload:
+   case VectorSignificantFunctions::vgather: {
+     // Gather indices are normalised to long.
+     std::vector<Value*> idx;
+     for (unsigned i = 2; i < CI.getNumOperands(); ++i) {
+       CastInst *castInst =
+         new CastInst(CI.getOperand(i), Type::LongTy, "cast", &CI);
+       idx.push_back(castInst);
+     }
+     raisedValue = new VGatherInst(CI.getOperand(1),
+                                   idx, "vgather", &CI);
+     addUsesToWorklist(&CI);
+     break;
+   }
+   case VectorSignificantFunctions::load: {
+     ConstantUInt *UIntVal = dyn_cast<ConstantUInt>(CI.getOperand(2));
+     assert(UIntVal && "Vector length must be a constant UInt!");
+     const PointerType *PointerTy =
+       dyn_cast<PointerType>(CI.getOperand(1)->getType());
+     assert(PointerTy && "Pointer operand must be pointer type!");
+     // View the scalar pointer as a pointer to fixed-length vectors and
+     // index it with operand 3.
+     CastInst *vecPtr =
+       new CastInst(CI.getOperand(1),
+                    PointerType::get(FixedVectorType::get(PointerTy->getElementType(),
+                                                          UIntVal->getValue())),
+                    "cast", &CI);
+     std::vector<Value*> Idx;
+     Idx.push_back(CI.getOperand(3));
+     GetElementPtrInst *GEP =
+       new GetElementPtrInst(vecPtr, Idx, "gep", &CI);
+     raisedValue = new LoadInst(GEP, "load", &CI);
+     addUsesToWorklist(&CI);
+     break;
+   }
+   case VectorSignificantFunctions::vimm:
+   case VectorSignificantFunctions::vloadi: {
+     raisedValue = new VImmInst(CI.getOperand(1), CI.getOperand(2), false,
+                                "vimm", &CI);
+     const VectorType *VT = VectorType::get(CI.getOperand(1)->getType());
+     if (raisedValue->getType() != VT)
+       raisedValue = new CastInst(raisedValue, VT, "cast", &CI);
+     addUsesToWorklist(&CI);
+     break;
+   }
+   case VectorSignificantFunctions::constant: {
+     std::vector<Constant*> elements;
+     for (unsigned i = 1; i < CI.getNumOperands(); ++i) {
+       Constant *C = dyn_cast<Constant>(CI.getOperand(i));
+       assert(C && "Operands of constant must be constants!");
+       elements.push_back(ConstantExpr::getCast(C, CI.getType()));
+     }
+     raisedValue = ConstantVector::get(elements);
+     addUsesToWorklist(&CI);
+     break;
+   }
+   case VectorSignificantFunctions::fixed_vimm: {
+     ConstantUInt *UIntVal = dyn_cast<ConstantUInt>(CI.getOperand(2));
+     assert(UIntVal && "Vector length must be a constant UInt!");
+     raisedValue = new VImmInst(CI.getOperand(1), CI.getOperand(2),
+                                true, "vimm", &CI);
+     const FixedVectorType *VT = FixedVectorType::get(CI.getType(), UIntVal->getValue());
+     if (raisedValue->getType() != VT)
+       raisedValue = new CastInst(raisedValue, VT, "cast", &CI);
+     addUsesToWorklist(&CI);
+     break;
+   }
+   case VectorSignificantFunctions::vstore:
+   case VectorSignificantFunctions::vscatter: {
+     std::vector<Value*> idx;
+     for (unsigned i = 3; i < CI.getNumOperands(); ++i) {
+       CastInst *castInst =
+         new CastInst(CI.getOperand(i), Type::LongTy, "cast", &CI);
+       idx.push_back(castInst);
+     }
+     Value *raisedOp = getRaisedValue(CI.getOperand(1), (unsigned) 0);
+     Value *ptr = CI.getOperand(2);
+     raisedValue = new VScatterInst(raisedOp, ptr, idx, &CI);
+     break;
+   }
+   case VectorSignificantFunctions::store: {
+     unsigned length = getVectorLength(&CI);
+     Value *op1 = getRaisedValue(CI.getOperand(1), length);
+     const PointerType *PointerTy =
+       dyn_cast<PointerType>(CI.getOperand(2)->getType());
+     assert(PointerTy && "Pointer operand must be pointer type!");
+     CastInst *vecPtr =
+       new CastInst(CI.getOperand(2),
+                    PointerType::get(FixedVectorType::get(PointerTy->getElementType(),
+                                                          length)),
+                    "cast", &CI);
+     std::vector<Value*> Idx;
+     Idx.push_back(CI.getOperand(3));
+     GetElementPtrInst *GEP =
+       new GetElementPtrInst(vecPtr, Idx, "gep", &CI);
+     raisedValue = new StoreInst(op1, GEP, &CI);
+     break;
+   }
+   case VectorSignificantFunctions::vselect: {
+     // numArgs is const so that raisedArgs below is an ordinary array.
+     // A non-const bound would make it a variable-length array, which
+     // is a compiler extension rather than standard C++.
+     const unsigned numArgs = 3;
+     unsigned length = getVectorLength(&CI);
+     assert((CI.getNumOperands() == numArgs+1) &&
+            "Wrong number of arguments to _select!");
+     // The selector reaches the call as a cast of a bool, and
+     // visitCastInst deliberately leaves that cast unraised. Raise the
+     // underlying bool instead.
+     CastInst *boolCast = dyn_cast<CastInst>(CI.getOperand(1));
+     assert(boolCast && "First operand of vselect must be cast!");
+     assert(boolCast->getOperand(0)->getType() == Type::BoolTy &&
+            "First operand of vselect must be cast of bool to int!");
+     Value *raisedArgs[numArgs];
+     raisedArgs[0] = getRaisedValue(boolCast->getOperand(0), length);
+     for (unsigned i = 1; i < numArgs; ++i) {
+       Value *arg = CI.getOperand(i+1);
+       raisedArgs[i] = getRaisedValue(arg, length);
+     }
+     raisedValue =
+       new VSelectInst(raisedArgs[0], raisedArgs[1], raisedArgs[2], "vselect", &CI);
+     addUsesToWorklist(&CI);
+     break;
+   }
+   case VectorSignificantFunctions::extract: {
+     Value *raisedOp = getRaisedOperand(&CI, 1);
+     raisedValue = new ExtractInst(raisedOp, CI.getOperand(2),
+                                   CI.getOperand(3), CI.getOperand(4),
+                                   "extract", &CI);
+     addUsesToWorklist(&CI);
+     break;
+   }
+   case VectorSignificantFunctions::combine: {
+     unsigned length = getVectorLength(&CI);
+     Value *raisedOp1 = getRaisedValue(CI.getOperand(1), length);
+     Value *raisedOp2 = getRaisedValue(CI.getOperand(2), length);
+     raisedValue = new CombineInst(raisedOp1, raisedOp2,
+                                   CI.getOperand(3), CI.getOperand(4),
+                                   "combine", &CI);
+     addUsesToWorklist(&CI);
+     break;
+   }
+   case VectorSignificantFunctions::fixed_combine: {
+     ConstantUInt *op2 = dyn_cast<ConstantUInt>(CI.getOperand(2));
+     ConstantUInt *op4 = dyn_cast<ConstantUInt>(CI.getOperand(4));
+     assert((op2 && op4) && "Vector length operands to fixed_combine must be constant uints!");
+     unsigned length1 = op2->getValue();
+     unsigned length2 = op4->getValue();
+     Value *raisedOp1 = getRaisedValue(CI.getOperand(1), length1);
+     Value *raisedOp2 = getRaisedValue(CI.getOperand(3), length2);
+     raisedValue = new CombineInst(raisedOp1, raisedOp2,
+                                   CI.getOperand(5), CI.getOperand(6),
+                                   "combine", &CI);
+     addUsesToWorklist(&CI);
+     break;
+   }
+   case VectorSignificantFunctions::fixed_permute: {
+     ConstantUInt *op2 = dyn_cast<ConstantUInt>(CI.getOperand(2));
+     ConstantUInt *op4 = dyn_cast<ConstantUInt>(CI.getOperand(4));
+     assert((op2 && op4) && "Vector length operands to fixed_permute must be constant uints!");
+     unsigned length1 = op2->getValue();
+     unsigned length2 = op4->getValue();
+     Value *raisedOp1 = getRaisedValue(CI.getOperand(1), length1);
+     Value *raisedOp2 = getRaisedValue(CI.getOperand(3), length2);
+     raisedValue = VectorUtils::getCallInst(raisedOp2->getType(), "vllvm_permute_" +
+                                            cast<FixedVectorType>(raisedOp2->getType())->getElementType()->getDescription(),
+                                            raisedOp1, raisedOp2, "permute", &CI);
+     addUsesToWorklist(&CI);
+     break;
+   }
+   case VectorSignificantFunctions::extractelement: {
+     // The result is a scalar, so the call's uses are rewired to it
+     // directly instead of being raised.
+     Value *raisedOp = getRaisedOperand(&CI, 1);
+     raisedValue =
+       new ExtractElementInst(raisedOp, CI.getOperand(2),
+                              "extractelement", &CI);
+     CI.replaceAllUsesWith(raisedValue);
+     break;
+   }
+   case VectorSignificantFunctions::combineelement: {
+     Value *raisedOp1 = getRaisedOperand(&CI, 1);
+     raisedValue = new CombineElementInst(raisedOp1, CI.getOperand(2),
+                                          CI.getOperand(3), "combineelement", &CI);
+     addUsesToWorklist(&CI);
+     break;
+   }
+   default: {
+     // Map the scalar library function onto its vector counterpart.
+     if (name.substr(0, 7) == "vectorc") {
+       name.erase(0, 7);
+       name = "vllvm" + name;
+     }
+     else if (name.substr(0, 5) == "vllvm") {
+       name += "_vector";
+     } else {
+       std::cerr << "Can't handle instruction " << CI;
+       exit(1);
+     }
+     unsigned length = getVectorLength(&CI);
+     std::vector<const Type*> formalArgs;
+     std::vector<Value*> args;
+     for (unsigned i = 1; i < CI.getNumOperands(); ++i) {
+       Value *op = CI.getOperand(i);
+       // Constant arguments are passed through unchanged; everything
+       // else is replaced by its raised counterpart.
+       if (isa<Constant>(op)) {
+         formalArgs.push_back(op->getType());
+         args.push_back(op);
+       } else {
+         formalArgs.push_back(getRaisedType(CI.getOperand(i)->getType(), length));
+         args.push_back(getRaisedOperand(&CI, i));
+       }
+     }
+     FunctionType *FType =
+       FunctionType::get(getRaisedType(F->getReturnType(), length), formalArgs, false);
+     Module *M = CI.getParent()->getParent()->getParent();
+     Function *func = M->getOrInsertFunction(name, FType);
+     raisedValue = new CallInst(func, args, "func", &CI);
+     addUsesToWorklist(&CI);
+     break;
+   }
+   }
+   setRaisedValue(&CI, raisedValue);
+ }
+
+ /// Raise a binary operator by rebuilding it on the raised operands.
+ /// For a comparison (SetCondInst), getVectorOpcode() supplies the
+ /// opcode; any other operator keeps its own opcode.
+ ///
+ void RaiseVectors::visitBinaryOperator(BinaryOperator &BO) {
+   const unsigned length = getVectorLength(&BO);
+   Value *lhs = getRaisedValue(BO.getOperand(0), length);
+   Value *rhs = getRaisedValue(BO.getOperand(1), length);
+   Instruction::BinaryOps opcode = BO.getOpcode();
+   if (SetCondInst *cmp = dyn_cast<SetCondInst>(&BO))
+     opcode = cmp->getVectorOpcode();
+   Instruction *raisedValue =
+     BinaryOperator::create(opcode, lhs, rhs, "binop", &BO);
+   setRaisedValue(&BO, raisedValue);
+   addUsesToWorklist(&BO);
+ }
+
+ /// Raise a cast to a cast between the corresponding raised types.
+ ///
+ /// Exception: a cast whose only user is a direct call to the vselect
+ /// significant function is left alone. visitCallInst reads that
+ /// cast's operand (the bool selector) directly.
+ ///
+ void RaiseVectors::visitCastInst(CastInst &CI) {
+   bool feedsVSelect = false;
+   if (CI.hasOneUse())
+     if (CallInst *Call = dyn_cast<CallInst>(*CI.use_begin()))
+       if (Function *Callee = Call->getCalledFunction())
+         feedsVSelect = VectorSignificantFunctions::getID(Callee->getName()) ==
+                        VectorSignificantFunctions::vselect;
+   if (feedsVSelect)
+     return;
+
+   const unsigned length = getVectorLength(&CI);
+   Value *raisedOp = getRaisedValue(CI.getOperand(0), length);
+   const Type *raisedTy = getRaisedType(CI.getType(), length);
+   Instruction *raisedValue = new CastInst(raisedOp, raisedTy, "cast", &CI);
+   setRaisedValue(&CI, raisedValue);
+   addUsesToWorklist(&CI);
+ }
+
+ /// Raise a shift. Only the shifted value is raised; the shift amount
+ /// (operand 1) is reused unchanged.
+ ///
+ void RaiseVectors::visitShiftInst(ShiftInst &SI) {
+   Value *shiftAmount = SI.getOperand(1);
+   Value *raisedOp = getRaisedOperand(&SI, 0);
+   Instruction *raisedValue =
+     new ShiftInst(SI.getOpcode(), raisedOp, shiftAmount, "shift", &SI);
+   setRaisedValue(&SI, raisedValue);
+   addUsesToWorklist(&SI);
+ }
+
+ /// Raise a select. The true and false values (operands 1 and 2) are
+ /// raised; the condition (operand 0) is reused unchanged.
+ ///
+ void RaiseVectors::visitSelectInst(SelectInst &SI) {
+   const unsigned length = getVectorLength(&SI);
+   Value *trueVal = getRaisedValue(SI.getOperand(1), length);
+   Value *falseVal = getRaisedValue(SI.getOperand(2), length);
+   Instruction *raisedValue =
+     new SelectInst(SI.getOperand(0), trueVal, falseVal, "select", &SI);
+   setRaisedValue(&SI, raisedValue);
+   addUsesToWorklist(&SI);
+ }
+
+ /// Raise a phi node into a phi over the raised incoming values.
+ ///
+ /// The new phi's type comes from getRaisedType with the same vector
+ /// length used to raise the incoming values. It therefore matches
+ /// them, including the placeholders getRaisedValue creates. For
+ /// fixed-length code this is a FixedVectorType; previously the phi
+ /// was always given a non-fixed VectorType, which mismatched.
+ ///
+ void RaiseVectors::visitPHINode(PHINode &PN) {
+   unsigned length = getVectorLength(&PN);
+   PHINode *raisedValue =
+     new PHINode(getRaisedType(PN.getType(), length), "phi", &PN);
+   for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i)
+     raisedValue->addIncoming(getRaisedValue(PN.getIncomingValue(i), length),
+                              PN.getIncomingBlock(i));
+   setRaisedValue(&PN, raisedValue);
+   addUsesToWorklist(&PN);
+ }
+
+ /// Map a scalar type, or a pointer (possibly multi-level) to a scalar
+ /// type, to its vector counterpart, keeping the pointer levels:
+ ///   T -> vector of T,  T* -> (vector of T)*,  and so on.
+ /// A length of 0 selects a non-fixed VectorType; any other length
+ /// selects a FixedVectorType with that many elements.
+ ///
+ const Type *RaiseVectors::getRaisedType(const Type* Ty, unsigned length) {
+   if (const PointerType *PtrTy = dyn_cast<PointerType>(Ty))
+     return PointerType::get(getRaisedType(PtrTy->getElementType(), length));
+   if (length == 0)
+     return VectorType::get(Ty);
+   return FixedVectorType::get(Ty, length);
+ }
+
+ /// Return the vector length of an instruction being raised.
+ ///
+ /// The length comes from the first operand that has a non-null entry
+ /// in raisingMap (a raised value or a placeholder). After stripping
+ /// any pointer levels from that entry's type, the result is the
+ /// element count of a FixedVectorType, or 0 for a non-fixed vector.
+ ///
+ /// Asserts if no operand has been raised. With assertions disabled
+ /// it returns 0 in that case instead of falling off the end of the
+ /// function, which would be undefined behavior.
+ ///
+ unsigned RaiseVectors::getVectorLength(Instruction *I) {
+   for (User::op_iterator OI = I->op_begin(), OE = I->op_end();
+        OI != OE; ++OI) {
+     Value *Op = *OI;
+     // Use find() rather than operator[] so that operands which were
+     // never raised do not get null entries inserted into raisingMap.
+     hash_map<Value*,Value*>::iterator Entry = raisingMap.find(Op);
+     if (Entry == raisingMap.end() || !Entry->second)
+       continue;
+     const Type *Ty = Entry->second->getType();
+     while (isa<PointerType>(Ty))
+       Ty = cast<PointerType>(Ty)->getElementType();
+     if (const FixedVectorType *VT = dyn_cast<FixedVectorType>(Ty))
+       return VT->getNumElements();
+     return 0;
+   }
+   assert(0 && "Instruction has no raised operands!");
+   return 0;
+ }
+
+ /// Record newValue as the raised version of key. A null newValue is
+ /// ignored.
+ ///
+ /// If getRaisedValue already handed out a placeholder for key, every
+ /// use of that placeholder is redirected and the placeholder is
+ /// deleted. Its users were built against the placeholder's type, so
+ /// if the types differ, newValue is cast to the placeholder's type
+ /// (the cast is inserted before key) and the cast is what replaces
+ /// the placeholder. The map entry itself always holds newValue.
+ ///
+ /// Casting the placeholder instead, as before, made the new cast use
+ /// itself once the placeholder's uses were replaced, and it dropped
+ /// newValue entirely.
+ ///
+ void RaiseVectors::setRaisedValue(Instruction *key, Value *newValue) {
+   if (!newValue)
+     return;
+   Value*& entry = raisingMap[key];
+   if (Value *placeholder = entry) {
+     Value *replacement = newValue;
+     if (placeholder->getType() != newValue->getType())
+       replacement = new CastInst(newValue, placeholder->getType(), "cast", key);
+     placeholder->replaceAllUsesWith(replacement);
+     delete placeholder;
+   }
+   entry = newValue;
+ }
+
+ /// Return the raised value recorded for key.
+ ///
+ /// If key has not been raised yet, create a dummy placeholder (an
+ /// Argument of type getRaisedType(key's type, length)), record it,
+ /// and return it. setRaisedValue replaces the placeholder once key is
+ /// actually raised. Placeholders left over at the end are reported by
+ /// raiseInstructions.
+ ///
+ Value *RaiseVectors::getRaisedValue(Value *key, unsigned length) {
+   Value*& Val = raisingMap[key];
+   if (!Val) {
+     // Compute the type only when a placeholder is actually needed.
+     // getRaisedType never returns null, so no fallback type is needed.
+     Val = new Argument(getRaisedType(key->getType(), length));
+     DEBUG(std::cerr << "Created dummy value " << *Val << "\n");
+   }
+   return Val;
+ }
+
+ /// Return the raised form of operand i of I, using the vector length
+ /// implied by I's already-raised operands (see getVectorLength).
+ ///
+ Value *RaiseVectors::getRaisedOperand(Instruction *I, unsigned i) {
+   const unsigned length = getVectorLength(I);
+   return getRaisedValue(I->getOperand(i), length);
+ }
+
+ /// Push every instruction that uses I onto the worklist so that it is
+ /// raised as well. Users that are not instructions are ignored.
+ ///
+ void RaiseVectors::addUsesToWorklist(Instruction *I) {
+   DEBUG(std::cerr << "Adding uses of " << *I);
+   Value::use_iterator UseIt = I->use_begin();
+   Value::use_iterator UseEnd = I->use_end();
+   for (; UseIt != UseEnd; ++UseIt) {
+     Instruction *UserInst = dyn_cast<Instruction>(*UseIt);
+     if (!UserInst)
+       continue;
+     DEBUG(std::cerr << "Adding " << *UserInst);
+     workList.push_back(UserInst);
+   }
+ }
+
+ }
Index: llvm/lib/Transforms/Vector/SSE.cpp
diff -c /dev/null llvm/lib/Transforms/Vector/SSE.cpp:1.1.2.1
*** /dev/null Tue Oct 18 14:37:14 2005
--- llvm/lib/Transforms/Vector/SSE.cpp Tue Oct 18 14:37:03 2005
***************
*** 0 ****
--- 1,583 ----
+ //===- SSE.cpp - Prepare Vector-LLVM code for the SSE C Backend -------===//
+ //
+ // The LLVM Compiler Infrastructure
+ //
+ // This file was developed by the LLVM research group and is distributed under
+ // the University of Illinois Open Source License. See LICENSE.TXT for details.
+ //
+ //===----------------------------------------------------------------------===//
+ //
+ // This file takes blocked Vector-LLVM code and puts it in a form that
+ // can be passed to the SSE C Backend.
+ //
+ //===----------------------------------------------------------------------===//
+
+ #define DEBUG_TYPE "SSE"
+
+ #include <sstream>
+ #include "VectorLLVM/Utils.h"
+ #include "llvm/Constants.h"
+ #include "llvm/DerivedTypes.h"
+ #include "llvm/Function.h"
+ #include "llvm/Instructions.h"
+ #include "llvm/Pass.h"
+ #include "llvm/Type.h"
+ #include "llvm/Support/Debug.h"
+ #include "llvm/ADT/hash_map"
+ #include "llvm/ADT/hash_set"
+ #include "llvm/ADT/STLExtras.h"
+ #include "llvm/Support/InstVisitor.h"
+
+ using namespace llvm;
+
+ namespace {
+
+
+ //===----------------------------------------------------------------------===//
+ // Class definitions
+ //===----------------------------------------------------------------------===//
+
+ /// SSE - Function pass that rewrites blocked Vector-LLVM instructions
+ /// into calls to SSE-style "_mm_*" functions so that the code can be
+ /// handed to the SSE C Backend. Rewritten instructions are collected
+ /// in instructionsToDelete and erased at the end of runOnFunction.
+ ///
+ class SSE : public FunctionPass, public InstVisitor<SSE> {
+
+ public:
+ bool runOnFunction(Function &F);
+ // InstVisitor callbacks for the instruction kinds this pass rewrites.
+ void visitCastInst(CastInst &);
+ void visitVImmInst(VImmInst &);
+ void visitExtractInst(ExtractInst &);
+ void visitCombineInst(CombineInst &);
+ void visitVSelectInst(VSelectInst &);
+ void visitAdd(BinaryOperator &);
+ void visitMul(BinaryOperator &);
+ void visitSetCondInst(SetCondInst &);
+ void visitSub(BinaryOperator &);
+ void visitCallInst(CallInst &);
+ void visitShiftInst(ShiftInst &);
+ // Every other instruction is left unchanged.
+ void visitInstruction(Instruction& I) {}
+
+ private:
+ // Set by the visitors whenever they modify the function; reset at the
+ // start of each runOnFunction.
+ bool changed;
+ // Instructions scheduled for removal; runOnFunction skips visiting them.
+ hash_set<Instruction*> instructionsToDelete;
+
+ /// Erase every scheduled instruction: first drop all operand
+ /// references, then unlink the instructions from their blocks, so no
+ /// instruction is erased while another doomed one still uses it.
+ void deleteInstructions() {
+ for (hash_set<Instruction*>::iterator I = instructionsToDelete.begin(),
+ E = instructionsToDelete.end(); I != E; ++I) {
+ (*I)->dropAllReferences();
+ }
+ for (hash_set<Instruction*>::iterator I = instructionsToDelete.begin(),
+ E = instructionsToDelete.end(); I != E; ++I) {
+ (*I)->getParent()->getInstList().erase(*I);
+ }
+ }
+ void addComposeConstant(BinaryOperator&,Value*,Value*);
+ };
+
+ // Register the pass under the name "sse".
+ RegisterOpt<SSE> X("sse",
+ "SSE code generation pre-pass");
+
+
+ //===----------------------------------------------------------------------===//
+ // Helper functions
+ //===----------------------------------------------------------------------===//
+
+ /// Number of elements of type Ty that fit in one 128-bit SSE register.
+ /// Returns 0 when Ty reports no primitive size. The old code divided
+ /// by zero in that case, which is undefined behavior.
+ ///
+ static unsigned getVectorSize(const Type* Ty) {
+   unsigned bytes = Ty->getPrimitiveSize();
+   if (bytes == 0)
+     return 0;
+   return 128 / (8 * bytes);
+ }
+
+ /// Return true if SSE can handle VT directly. That requires a
+ /// FixedVectorType whose element count exactly fills one 128-bit
+ /// register (see getVectorSize). Any other vector type must be
+ /// lowered later.
+ ///
+ bool isProperType(const VectorType *VT) {
+   if (const FixedVectorType *FVT = dyn_cast<FixedVectorType>(VT))
+     return FVT->getNumElements() == getVectorSize(FVT->getElementType());
+   return false;
+ }
+
+ /// Build the SSE intrinsic suffix for VecTy's element type:
+ /// "epi<bits>" for signed integers and "epu<bits>" for unsigned ones.
+ /// A non-integral element type is a fatal error.
+ ///
+ static std::string getSSESuffix(const FixedVectorType *VecTy) {
+   const Type *ElTy = VecTy->getElementType();
+   if (!ElTy->isIntegral()) {
+     std::cerr << "Can't yet handle this type!\n"
+               << VecTy->getDescription() << "\n";
+     exit(1);
+   }
+   std::ostringstream os;
+   os << (ElTy->isSigned() ? "epi" : "epu") << 8*ElTy->getPrimitiveSize();
+   return os.str();
+ }
+
+ /// Name of the intrinsic "_mm_<baseName>_<suffix>", where the suffix
+ /// is derived from VecTy by getSSESuffix.
+ ///
+ static std::string getSSEName(std::string baseName, const FixedVectorType *VecTy) {
+   std::string result("_mm_");
+   result += baseName;
+   result += '_';
+   result += getSSESuffix(VecTy);
+   return result;
+ }
+
+ /// Return true if V is a direct call to a function whose name starts
+ /// with the given prefix. A null V yields false.
+ ///
+ static bool isCall(Value *V, const std::string &name) {
+   CallInst *CI = V ? dyn_cast<CallInst>(V) : 0;
+   if (!CI)
+     return false;
+   Function *Callee = CI->getCalledFunction();
+   return Callee && Callee->getName().substr(0, name.length()) == name;
+ }
+
+ /// Return true if V is a direct call to a function whose name starts
+ /// with "_mm_<name>".
+ ///
+ static bool isMMCall(Value *V, const std::string &name) {
+   std::string prefix("_mm_");
+   prefix += name;
+   return isCall(V, prefix);
+ }
+
+ /// Return true if V is a direct call to a function whose name starts
+ /// with "compose". This delegates to isCall, which also makes a null
+ /// V yield false instead of being passed to dyn_cast.
+ ///
+ static bool isComposeIntrinsic(Value *V) {
+   return isCall(V, "compose");
+ }
+
+ /// Return true if CI is a direct call to a function whose name starts
+ /// with "fullCompose". This delegates to isCall, which makes a null CI
+ /// yield false and removes the redundant non-null re-check of the
+ /// callee.
+ ///
+ static bool isFullCompose(CallInst *CI) {
+   return isCall(CI, "fullCompose");
+ }
+
+
+ //===----------------------------------------------------------------------===//
+ // SSE implementation
+ //===----------------------------------------------------------------------===//
+
+ /// Entry point called by the PassManager for each function.
+ ///
+ /// Visits every instruction that has not already been scheduled for
+ /// deletion; a visitor may claim instructions further down the
+ /// function. If anything changed, it then erases all scheduled
+ /// instructions. Returns true iff F was modified.
+ ///
+ bool SSE::runOnFunction(Function &F) {
+   changed = false;
+   instructionsToDelete.clear();
+   for (Function::iterator BB = F.begin(), BBEnd = F.end(); BB != BBEnd; ++BB) {
+     for (BasicBlock::iterator Inst = BB->begin(), InstEnd = BB->end();
+          Inst != InstEnd; ++Inst) {
+       if (instructionsToDelete.count(Inst))
+         continue;
+       DEBUG(std::cerr << "Visiting instruction " << *Inst);
+       visit(*Inst);
+     }
+   }
+   if (changed)
+     deleteInstructions();
+   return changed;
+ }
+
+ /// Replace a VImmInst of fixed vector type with a call to the
+ /// matching "_mm_splat_<suffix>" function applied to its operand 0.
+ ///
+ void SSE::visitVImmInst(VImmInst &VL) {
+   const FixedVectorType *VT = dyn_cast<FixedVectorType>(VL.getType());
+   assert(VT && "Vimm must be fixed vector type!\n");
+   const std::string splatName = getSSEName("splat", VT);
+   CallInst *splat = VectorUtils::getCallInst(VT, splatName, VL.getOperand(0),
+                                              "splat", &VL);
+   VL.replaceAllUsesWith(splat);
+   instructionsToDelete.insert(&VL);
+   changed = true;
+ }
+
+ /// Lower a cast whose result has a proper SSE vector type
+ /// (isProperType). Three sources are handled:
+ ///  - a compose call, when the cast's type equals the type of its
+ ///    first operand: becomes _mm_or_si128 of the two operands;
+ ///  - a fullCompose call of fixed vector type: becomes the matching
+ ///    "_mm_pack_<suffix>" call on its two operands;
+ ///  - a scalar: becomes a "_mm_splat_<suffix>" call producing a full
+ ///    128-bit vector, cast to the original type if that differs.
+ /// Other casts are left untouched. A rewritten cast, and in the
+ /// compose cases the compose call, are scheduled for deletion.
+ ///
+ void SSE::visitCastInst(CastInst &CI) {
+   const VectorType *VT = dyn_cast<VectorType>(CI.getType());
+   if (!VT || !isProperType(VT))
+     return;
+   if (isa<VectorType>(CI.getOperand(0)->getType())) {
+     // Vector-to-vector cast: look for a compose / fullCompose source.
+     CallInst *op0 = dyn_cast<CallInst>(CI.getOperand(0));
+     if (op0 && isComposeIntrinsic(op0) && (CI.getType() == op0->getOperand(1)->getType())) {
+       CallInst *Or = VectorUtils::getCallInst(VT, "_mm_or_si128", op0->getOperand(1),
+                                               op0->getOperand(2), "or", &CI);
+       CI.replaceAllUsesWith(Or);
+       instructionsToDelete.insert(&CI);
+       instructionsToDelete.insert(op0);
+       changed = true;
+     } else if (op0 && isFullCompose(op0)) {
+       if (const FixedVectorType *LongVT = dyn_cast<FixedVectorType>(op0->getType())) {
+         CallInst *Pack = VectorUtils::getCallInst(VT, getSSEName("pack", LongVT),
+                                                   op0->getOperand(1), op0->getOperand(2),
+                                                   "pack", &CI);
+         CI.replaceAllUsesWith(Pack);
+         instructionsToDelete.insert(&CI);
+         instructionsToDelete.insert(op0);
+         changed = true;
+       }
+     }
+   } else {
+     // Scalar-to-vector cast: splat the scalar across a full register
+     // with a _mm_splat call, casting the result if its type differs.
+     const Type *Ty = CI.getOperand(0)->getType();
+     unsigned vectorSize = getVectorSize(Ty);
+     const FixedVectorType *RetTy = FixedVectorType::get(Ty, vectorSize);
+     CallInst *call = VectorUtils::getCallInst(RetTy, getSSEName("splat", RetTy),
+                                               CI.getOperand(0), "splat", &CI);
+     if (RetTy != CI.getType())
+       CI.replaceAllUsesWith(new CastInst(call, CI.getType(), "cast", &CI));
+     else
+       CI.replaceAllUsesWith(call);
+     instructionsToDelete.insert(&CI);
+     changed = true;
+   }
+ }
+
+ // Try to turn an extract instruction into an SSE unpack call.
+ //
+ // Pattern: an extract of 8 elements with stride 1, starting at
+ // element 0 or 8, from a vector of 16 ubytes, whose only user is a
+ // cast to a vector of 8 shorts. The source is cast to 16 sbytes and
+ // passed to SSE_unpackh_short (start 0) or SSE_unpackl_short
+ // (start 8). The result is ANDed with 0xFF and replaces the cast. The
+ // extract and the cast are scheduled for deletion.
+ //
+ // FIXME: This code doesn't work
+ //
+ void SSE::visitExtractInst(ExtractInst &EI) {
+   Value *src = EI.getOperand(0);
+   ConstantUInt *start = dyn_cast<ConstantUInt>(EI.getOperand(1));
+   ConstantUInt *stride = dyn_cast<ConstantUInt>(EI.getOperand(2));
+   ConstantUInt *len = dyn_cast<ConstantUInt>(EI.getOperand(3));
+   if (!start || !stride || !len)
+     return;
+   if (stride->getValue() != 1 || len->getValue() != 8)
+     return;
+
+   std::string funcName;
+   switch (start->getValue()) {
+   case 0:
+     funcName = "SSE_unpackh_short";
+     break;
+   case 8:
+     funcName = "SSE_unpackl_short";
+     break;
+   default:
+     return;
+   }
+
+   const FixedVectorType *srcTy = dyn_cast<FixedVectorType>(src->getType());
+   if (srcTy != FixedVectorType::get(Type::UByteTy, 16))
+     return;
+   if (!EI.hasOneUse())
+     return;
+   CastInst *userCast = dyn_cast<CastInst>(*EI.use_begin());
+   if (!userCast)
+     return;
+   const FixedVectorType *argTy = FixedVectorType::get(Type::SByteTy, 16);
+   const FixedVectorType *retTy = FixedVectorType::get(Type::ShortTy, 8);
+   if (userCast->getType() != retTy)
+     return;
+
+   CastInst *arg = new CastInst(src, argTy, "cast", &EI);
+   std::vector<const Type*> formalArgs(1, argTy);
+   std::vector<Value*> args(1, arg);
+   FunctionType *FType = FunctionType::get(retTy, formalArgs, false);
+   Module *M = EI.getParent()->getParent()->getParent();
+   Function *unpack = M->getOrInsertFunction(funcName, FType);
+   CallInst *call = new CallInst(unpack, args, "unpack", &EI);
+   Constant *lowByteMask =
+     ConstantExpr::getCast(ConstantSInt::get(Type::ShortTy, 0xFF), retTy);
+   BinaryOperator *andInst =
+     BinaryOperator::create(Instruction::And, call, lowByteMask, "and", &EI);
+   userCast->replaceAllUsesWith(andInst);
+   instructionsToDelete.insert(userCast);
+   instructionsToDelete.insert(&EI);
+   changed = true;
+ }
+
+ /// Pattern-match a chain of two CombineInsts starting at CI and lower
+ /// it to SSE calls:
+ ///
+ ///   CI       = combine v1, v2, ...   (v1 is not itself a combine)
+ ///   combine2 = combine CI, w, ...    (combine2 is CI's only use)
+ ///
+ /// Here v1 and v2 are fixed vectors, v1 has twice as many elements as
+ /// v2, and w has v2's type. Two shapes are recognised:
+ ///  - combine2 has one use, a cast to a fixed vector with as many
+ ///    elements as v1. Then v2 and w are each shifted left and then
+ ///    arithmetically right by 8 ("slli"/"srai") and packed with
+ ///    "_mm_packs_<suffix>". This currently only handles signed values.
+ ///  - combine2 has two uses, both ExtractInsts with a constant start.
+ ///    They are replaced by "unpackhi"/"unpacklo" calls on v2 and w. If
+ ///    the first extract's start is 1 it gets unpackhi and the other
+ ///    unpacklo; otherwise the other way round.
+ /// The matched instructions, and v1 if it is an instruction, are
+ /// scheduled for deletion. IR that does not fit is left untouched.
+ ///
+ void SSE::visitCombineInst(CombineInst &CI) {
+   Value *v1 = CI.getOperand(0);
+   Value *v2 = CI.getOperand(1);
+   // If the destination is a combine instruction, do nothing; if
+   // necessary, the first combine in the series handles the chain.
+   if (isa<CombineInst>(v1))
+     return;
+   // Both operands must be fixed vectors and the first must have twice
+   // as many elements as the second. Only the second needs a proper
+   // type.
+   const FixedVectorType *VT1 = dyn_cast<FixedVectorType>(v1->getType());
+   if (!VT1) return;
+   const FixedVectorType *VT2 = dyn_cast<FixedVectorType>(v2->getType());
+   if (!VT2) return;
+   if (VT1->getNumElements() != 2*VT2->getNumElements()) return;
+   // CI must have exactly one use: a combine whose operand 0 is CI and
+   // whose operand 1 has v2's type.
+   if (!CI.hasOneUse()) return;
+   CombineInst* combine2 = dyn_cast<CombineInst>(*CI.use_begin());
+   if (!combine2) return;
+   if (&CI != combine2->getOperand(0)) return;
+   if (combine2->getOperand(1)->getType() != VT2) return;
+   if (combine2->hasOneUse()) {
+     // _mm_packs pattern: combine2's only use must be a cast to a fixed
+     // vector with as many elements as v1.
+     Instruction *use = dyn_cast<Instruction>(*combine2->use_begin());
+     if (!use) return;
+     const FixedVectorType *VT = dyn_cast<FixedVectorType>(use->getType());
+     if (!VT) return;
+     if (VT->getNumElements() != VT1->getNumElements()) return;
+     if (!isa<CastInst>(use))
+       return;
+     Value *op1 = v2;
+     Value *op2 = combine2->getOperand(1);
+     // Right now this only works for signed values
+     //
+     Value *eight = ConstantUInt::get(Type::UByteTy, 8);
+     op1 = VectorUtils::getCallInst(VT2, getSSEName("slli", VT2),
+                                    op1, eight, "slli", &CI);
+     op1 = VectorUtils::getCallInst(VT2, getSSEName("srai", VT2),
+                                    op1, eight, "srai", &CI);
+     op2 = VectorUtils::getCallInst(VT2, getSSEName("slli", VT2),
+                                    op2, eight, "slli", &CI);
+     op2 = VectorUtils::getCallInst(VT2, getSSEName("srai", VT2),
+                                    op2, eight, "srai", &CI);
+     use->replaceAllUsesWith(VectorUtils::getCallInst(VT, getSSEName("packs", VT1),
+                                                      op1, op2, "packs", &CI));
+     instructionsToDelete.insert(use);
+     instructionsToDelete.insert(combine2);
+     instructionsToDelete.insert(&CI);
+     if (Instruction *op0 = dyn_cast<Instruction>(v1))
+       instructionsToDelete.insert(op0);
+     changed = true;
+   } else if (combine2->hasNUses(2)) {
+     Value::use_iterator UI = combine2->use_begin();
+     ExtractInst *extract0 = dyn_cast<ExtractInst>(*UI++);
+     ExtractInst *extract1 = dyn_cast<ExtractInst>(*UI);
+     // Both uses must be extracts, and the first one's start must be a
+     // constant. Bail out on anything else instead of crashing.
+     if (!extract0 || !extract1) return;
+     ConstantUInt *start0 = dyn_cast<ConstantUInt>(extract0->getOperand(1));
+     if (!start0) return;
+     // Insert the unpacks right before combine2. Their operands (v2 and
+     // combine2's operand 1) dominate combine2, and combine2 dominates
+     // both extracts whatever their order or block. Inserting before
+     // extract0 does not guarantee the unpacks dominate extract1's users.
+     CallInst *unpackhi = VectorUtils::getCallInst(VT2, getSSEName("unpackhi", VT2),
+                                                   v2, combine2->getOperand(1),
+                                                   "unpackhi", combine2);
+     CallInst *unpacklo = VectorUtils::getCallInst(VT2, getSSEName("unpacklo", VT2),
+                                                   v2, combine2->getOperand(1),
+                                                   "unpacklo", combine2);
+     if (start0->getValue() == 1) {
+       extract0->replaceAllUsesWith(unpackhi);
+       extract1->replaceAllUsesWith(unpacklo);
+     } else {
+       extract0->replaceAllUsesWith(unpacklo);
+       extract1->replaceAllUsesWith(unpackhi);
+     }
+     instructionsToDelete.insert(&CI);
+     instructionsToDelete.insert(combine2);
+     instructionsToDelete.insert(extract0);
+     instructionsToDelete.insert(extract1);
+     if (Instruction *op0 = dyn_cast<Instruction>(v1))
+       instructionsToDelete.insert(op0);
+     changed = true;
+   }
+ }
+
+ /// Lower a VSelectInst of proper SSE type to the bitwise select
+ ///   (op1 AND mask) OR (NOT mask AND op2)
+ /// built from _mm_and_si128, _mm_andnot_si128 and _mm_or_si128, where
+ /// the mask is operand 0.
+ ///
+ void SSE::visitVSelectInst(VSelectInst &VI) {
+   const FixedVectorType *VT = dyn_cast<FixedVectorType>(VI.getType());
+   if (!VT || !isProperType(VT))
+     return;
+   Value *mask = VI.getOperand(0);
+   Value *ifTrue = VI.getOperand(1);
+   Value *ifFalse = VI.getOperand(2);
+   CallInst *keepTrue = VectorUtils::getCallInst(VT, "_mm_and_si128", ifTrue,
+                                                 mask, "and", &VI);
+   CallInst *keepFalse = VectorUtils::getCallInst(VT, "_mm_andnot_si128", mask,
+                                                  ifFalse, "andnot", &VI);
+   CallInst *merged = VectorUtils::getCallInst(VT, "_mm_or_si128", keepTrue,
+                                               keepFalse, "or", &VI);
+   VI.replaceAllUsesWith(merged);
+   instructionsToDelete.insert(&VI);
+   changed = true;
+ }
+
+ /// Lower a shift of a fixed vector to SSE shift calls.
+ ///  - Shift of a fullCompose call: each composed half is shifted with
+ ///    "slli" (shl), "srai" (shr, signed elements) or "srli" (shr,
+ ///    unsigned elements), and the halves are recomposed with the same
+ ///    fullCompose function.
+ ///  - Shift of a compose call: operand 1 is shifted right ("srli") by
+ ///    the amount and operand 2 left ("slli") by 16 minus the amount,
+ ///    then the two are recomposed. This path does not look at the
+ ///    shift opcode.
+ ///  - Any other shift becomes a single "slli"/"srai"/"srli" call.
+ /// The shift, and any compose call it consumed, is scheduled for
+ /// deletion.
+ ///
+ void SSE::visitShiftInst(ShiftInst &SI) {
+   const FixedVectorType *VT = dyn_cast<FixedVectorType>(SI.getType());
+   if (!VT)
+     return;
+   CallInst *src = dyn_cast<CallInst>(SI.getOperand(0));
+   Value *amount = SI.getOperand(1);
+
+   if (src && isFullCompose(src)) {
+     Value *lowHalf = src->getOperand(1);
+     Value *highHalf = src->getOperand(2);
+     const FixedVectorType *halfTy = cast<FixedVectorType>(lowHalf->getType());
+     std::string shortName;
+     if (SI.getOpcode() == Instruction::Shl)
+       shortName = "slli";
+     else
+       shortName = halfTy->getElementType()->isSigned() ? "srai" : "srli";
+     const std::string intrinsic = getSSEName(shortName, halfTy);
+     CallInst *shiftLo = VectorUtils::getCallInst(halfTy, intrinsic,
+                                                  lowHalf, amount, "shiftLo", &SI);
+     CallInst *shiftHi = VectorUtils::getCallInst(halfTy, intrinsic,
+                                                  highHalf, amount, "shiftHi", &SI);
+     CallInst *recomposed = new CallInst(src->getCalledFunction(), shiftLo,
+                                         shiftHi, "fullCompose", &SI);
+     SI.replaceAllUsesWith(recomposed);
+     instructionsToDelete.insert(&SI);
+     instructionsToDelete.insert(src);
+     changed = true;
+     return;
+   }
+
+   if (src && isComposeIntrinsic(src)) {
+     const FixedVectorType* Ty = cast<FixedVectorType>(src->getOperand(1)->getType());
+     // NOTE(review): the constant 16 presumably equals the element width
+     // in bits, i.e. 16-bit elements are assumed. Confirm before relying
+     // on this for other element sizes.
+     Instruction *complement =
+       BinaryOperator::create(Instruction::Sub, ConstantUInt::get(Type::UByteTy, 16),
+                              amount, "sub", &SI);
+     Instruction *lo = VectorUtils::getCallInst(Ty, getSSEName("srli", Ty),
+                                                src->getOperand(1), amount, "lo", &SI);
+     Instruction *hi = VectorUtils::getCallInst(Ty, getSSEName("slli", Ty),
+                                                src->getOperand(2), complement, "hi", &SI);
+     Instruction *recomposed = new CallInst(src->getCalledFunction(), lo, hi, "compose", &SI);
+     SI.replaceAllUsesWith(recomposed);
+     instructionsToDelete.insert(&SI);
+     instructionsToDelete.insert(src);
+     changed = true;
+     return;
+   }
+
+   // Plain vector shift.
+   std::string shortName = "slli";
+   if (SI.getOpcode() == Instruction::Shr)
+     shortName = VT->getElementType()->isSigned() ? "srai" : "srli";
+   CallInst *shift = VectorUtils::getCallInst(VT, getSSEName(shortName, VT),
+                                              SI.getOperand(0), amount,
+                                              "shift", &SI);
+   SI.replaceAllUsesWith(shift);
+   instructionsToDelete.insert(&SI);
+   changed = true;
+ }
+
+ /// getSignedType - Return the signed integer type with the same width as Ty.
+ /// Signed types are returned unchanged, and every unsigned integer type maps
+ /// to its signed counterpart.  Any other type (e.g. floating point) gets a
+ /// diagnostic on stderr and a null return, so callers must check the result.
+ ///
+ static const Type *getSignedType(const Type *Ty) {
+   if (Ty->isSigned())
+     return Ty;
+   switch (Ty->getTypeID()) {
+   case Type::UByteTyID:
+     return Type::SByteTy;
+   case Type::UShortTyID:
+     return Type::ShortTy;
+   case Type::UIntTyID:
+     return Type::IntTy;
+   case Type::ULongTyID:
+     return Type::LongTy;
+   default:
+     std::cerr << "Can't handle type " << Ty->getDescription() << "\n";
+     break;
+   }
+   return 0;
+ }
+
+ /// addComposeConstant - Rewrite Add when it has the shape
+ ///   add(compose(op1, op2), cast(splat(C)))
+ /// where arg1 is the compose call and arg2 is the cast.  The compose operands
+ /// are unpacked (unpacklo / unpackhi) into two half-length vectors with the
+ /// signed element type, the splatted constant C is added to each half, and
+ /// the halves are recombined with a fullCompose_<elt> call.  The rewrite is
+ /// skipped if any part of the pattern does not match.
+ ///
+ void SSE::addComposeConstant(BinaryOperator &Add,
+                              Value *arg1, Value *arg2) {
+   CallInst *compose = dyn_cast<CallInst>(arg1);
+   if (!compose || !isComposeIntrinsic(compose) || !compose->hasOneUse()) return;
+   Value *op1 = compose->getOperand(1);
+   Value *op2 = compose->getOperand(2);
+   CastInst *addCast0 = dyn_cast<CastInst>(arg2);
+   if (!addCast0) return;
+   // The cast's operand need not be a call at all; reject that before
+   // handing it to isMMCall.
+   CallInst *splat = dyn_cast<CallInst>(addCast0->getOperand(0));
+   if (!splat || !isMMCall(splat, "splat")) return;
+   Constant *C = dyn_cast<Constant>(splat->getOperand(1));
+   if (!C) return;
+   const FixedVectorType *LongVT = dyn_cast<FixedVectorType>(Add.getType());
+   const FixedVectorType *ShortVT = dyn_cast<FixedVectorType>(op1->getType());
+   if (!LongVT || !ShortVT) return;
+   // getSignedType returns null for element types it cannot map; bail out
+   // rather than build a vector type from a null element type.
+   const Type *HalfEltTy = getSignedType(LongVT->getElementType());
+   if (!HalfEltTy) return;
+   const FixedVectorType *HalfVT = FixedVectorType::get(HalfEltTy,
+                                                        LongVT->getNumElements() / 2);
+   CallInst *splat2 = VectorUtils::getCallInst(HalfVT, getSSEName("splat", HalfVT),
+                                               C, "splat", &Add);
+   CallInst *unpackLo = VectorUtils::getCallInst(HalfVT, getSSEName("unpacklo", ShortVT),
+                                                 op1, op2, "unpackLo", &Add);
+   CallInst *unpackHi = VectorUtils::getCallInst(HalfVT, getSSEName("unpackhi", ShortVT),
+                                                 op1, op2, "unpackHi", &Add);
+   CallInst *addLo = VectorUtils::getCallInst(HalfVT, getSSEName("add", HalfVT),
+                                              unpackLo, splat2, "addLo", &Add);
+   CallInst *addHi = VectorUtils::getCallInst(HalfVT, getSSEName("add", HalfVT),
+                                              unpackHi, splat2, "addHi", &Add);
+   CallInst *fullCompose = VectorUtils::getCallInst(LongVT, "fullCompose_" + HalfVT->getElementType()->getDescription(),
+                                                    addLo, addHi, "fullCompose", &Add);
+   Add.replaceAllUsesWith(fullCompose);
+   instructionsToDelete.insert(&Add);
+   instructionsToDelete.insert(compose);
+   // The cast and the splat may also feed other instructions.  Delete the
+   // cast only if Add is its sole user, and the splat only if it is being
+   // deleted along with its sole user, the cast.
+   if (addCast0->hasOneUse()) {
+     instructionsToDelete.insert(addCast0);
+     if (splat->hasOneUse())
+       instructionsToDelete.insert(splat);
+   }
+   changed = true;
+ }
+
+ /// visitAdd - Try the add(compose(a, b), cast(splat(C))) rewrite done by
+ /// addComposeConstant.  The compose is tried first as operand 0, then as
+ /// operand 1, so either operand order is recognized.
+ ///
+ /// FIXME: Only that one specialized shape is recognized. Generalize this!!!
+ ///
+ void SSE::visitAdd(BinaryOperator &Add) {
+   for (unsigned composeIdx = 0; composeIdx != 2; ++composeIdx)
+     addComposeConstant(Add, Add.getOperand(composeIdx),
+                        Add.getOperand(1 - composeIdx));
+ }
+
+ /// usedOnlyBy - Return true if every use of V is an operand of U (trivially
+ /// true when V has no uses).
+ ///
+ static bool usedOnlyBy(Value *V, User *U) {
+   for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI != E; ++UI)
+     if (*UI != U)
+       return false;
+   return true;
+ }
+
+ /// visitMul - Rewrite mul(cast(x), cast(y)) on vectors into SSE mulhi and
+ /// mullo calls on the uncast operands x and y.  The results are combined
+ /// with compose_<elt>(lo, hi), which produces the wider type of the casts.
+ /// x and y must have the same vector type because both are passed to one
+ /// intrinsic.
+ ///
+ void SSE::visitMul(BinaryOperator &BO) {
+   CastInst *op0 = dyn_cast<CastInst>(BO.getOperand(0));
+   CastInst *op1 = dyn_cast<CastInst>(BO.getOperand(1));
+   if (!op0 || !op1) return;
+   const FixedVectorType *Ty = dyn_cast<FixedVectorType>(op0->getType());
+   const FixedVectorType *RetTy = dyn_cast<FixedVectorType>(op0->getOperand(0)->getType());
+   if (!Ty || !RetTy) return;
+   Value *lhs = op0->getOperand(0);
+   Value *rhs = op1->getOperand(0);
+   // mulhi/mullo take two operands of one type; mixed sources would build
+   // ill-typed intrinsic calls.
+   if (rhs->getType() != RetTy) return;
+   Instruction *hi = VectorUtils::getCallInst(RetTy, "_mm_mulhi_" + getSSESuffix(RetTy),
+                                              lhs, rhs, "mul", &BO);
+   Instruction *lo = VectorUtils::getCallInst(RetTy, "_mm_mullo_" + getSSESuffix(RetTy),
+                                              lhs, rhs, "mul", &BO);
+   Instruction *result = VectorUtils::getCallInst(Ty, "compose_" + RetTy->getElementType()->getDescription(),
+                                                  lo, hi, "compose", &BO);
+   BO.replaceAllUsesWith(result);
+   instructionsToDelete.insert(&BO);
+   // A cast that still feeds other instructions must survive.  usedOnlyBy
+   // also covers x*x, where both uses of the shared cast belong to BO.
+   if (usedOnlyBy(op0, &BO))
+     instructionsToDelete.insert(op0);
+   if (usedOnlyBy(op1, &BO))
+     instructionsToDelete.insert(op1);
+   changed = true;
+ }
+
+ /// visitSetCondInst - Lower a vector comparison to the SSE compare intrinsic
+ /// _mm_cmpgt_<suffix>, where the suffix comes from the operand type.  Only
+ /// VSetGT is supported; any other opcode on a vector comparison prints the
+ /// instruction and terminates the process.
+ ///
+ void SSE::visitSetCondInst(SetCondInst &BO) {
+   const FixedVectorType *VT =
+     dyn_cast<FixedVectorType>(BO.getOperand(0)->getType());
+   if (VT == 0)
+     return;
+   if (BO.getOpcode() != Instruction::VSetGT) {
+     std::cerr << "Can't handle instruction " << BO;
+     exit(1);
+   }
+   std::string intrinsicName = "_mm_cmpgt_" + getSSESuffix(VT);
+   CallInst *cmp = VectorUtils::getCallInst(BO.getType(), intrinsicName,
+                                            BO.getOperand(0), BO.getOperand(1),
+                                            "cmp", &BO);
+   BO.replaceAllUsesWith(cmp);
+   instructionsToDelete.insert(&BO);
+   changed = true;
+ }
+
+ /// visitSub - Lower a vector subtraction to the SSE "sub" intrinsic that
+ /// getSSEName selects for the operand type.  Scalar subtractions are left
+ /// untouched.
+ ///
+ void SSE::visitSub(BinaryOperator &BO) {
+   Value *lhs = BO.getOperand(0);
+   Value *rhs = BO.getOperand(1);
+   if (const FixedVectorType *VT = dyn_cast<FixedVectorType>(lhs->getType())) {
+     CallInst *diff = VectorUtils::getCallInst(VT, getSSEName("sub", VT),
+                                               lhs, rhs, "sub", &BO);
+     BO.replaceAllUsesWith(diff);
+     instructionsToDelete.insert(&BO);
+     changed = true;
+   }
+ }
+
+ /// visitCallInst - Lower a call to a Vector-LLVM intrinsic named
+ /// vllvm_<op>_<rest> into a call to _mm_<op>_<suffix>, where <suffix> is the
+ /// SSE suffix of the result type.  The call's arguments are passed through
+ /// unchanged.  vllvm_saturate calls are left alone.  A vllvm_ name with no
+ /// '_' after the prefix prints an error and terminates the process.
+ ///
+ void SSE::visitCallInst(CallInst &CI) {
+   const FixedVectorType *VT =
+     dyn_cast<FixedVectorType>(CI.getType());
+   if (!VT || !isProperType(VT)) return;
+   Function *callee = CI.getCalledFunction();
+   if (!callee) return;
+   std::string calleeName = callee->getName();
+   const std::string::size_type prefixLen = 6;  // length of "vllvm_"
+   if (calleeName.compare(0, prefixLen, "vllvm_") != 0) return;
+   // pos must be std::string::size_type.  Storing find()'s result in an
+   // 'unsigned' truncates npos on LP64 hosts, so the check below would never
+   // fire and substr would silently take the rest of the name.
+   std::string::size_type pos = calleeName.find('_', prefixLen);
+   if (pos == std::string::npos) {
+     std::cerr << "Bad syntax for Vector-LLVM intrinsic " << calleeName << "\n";
+     exit(1);
+   }
+   std::string shortName = calleeName.substr(prefixLen, pos - prefixLen);
+   if (shortName == "saturate")
+     return;
+   std::string fullName = "_mm_" + shortName + "_" + getSSESuffix(VT);
+   std::vector<Value*> args;
+   args.reserve(CI.getNumOperands());
+   // The call's arguments start at operand 1 (operand 0 is the callee).
+   for (unsigned i = 1, e = CI.getNumOperands(); i < e; ++i)
+     args.push_back(CI.getOperand(i));
+   CallInst *call = VectorUtils::getCallInst(VT, fullName, args, "mm_call", &CI);
+   CI.replaceAllUsesWith(call);
+   instructionsToDelete.insert(&CI);
+   changed = true;
+ }
+
+ }
More information about the llvm-commits
mailing list