[llvm-commits] [vector_llvm] CVS: llvm/lib/Transforms/Vector/AltiVec.cpp LowerVectors.cpp RaiseVectors.cpp SSE.cpp

Robert L. Bocchino Jr. bocchino at persephone.cs.uiuc.edu
Tue Nov 15 12:16:03 PST 2005



Changes in directory llvm/lib/Transforms/Vector:

AltiVec.cpp updated: 1.1.2.1 -> 1.1.2.2
LowerVectors.cpp updated: 1.1.2.1 -> 1.1.2.2
RaiseVectors.cpp updated: 1.1.2.1 -> 1.1.2.2
SSE.cpp updated: 1.1.2.1 -> 1.1.2.2
---
Log message:

Improved AltiVec code generation support.

Changed the signature of extract.

Made LowerVectors work with fixed-length vectors.


---
Diffs of the changes:  (+108 -131)

 AltiVec.cpp      |  103 +++++++++++++++++++++++++------------------------------
 LowerVectors.cpp |   62 ++++++++++++++++++++++++++++-----
 RaiseVectors.cpp |    9 +++-
 SSE.cpp          |   65 ----------------------------------
 4 files changed, 108 insertions(+), 131 deletions(-)


Index: llvm/lib/Transforms/Vector/AltiVec.cpp
diff -u llvm/lib/Transforms/Vector/AltiVec.cpp:1.1.2.1 llvm/lib/Transforms/Vector/AltiVec.cpp:1.1.2.2
--- llvm/lib/Transforms/Vector/AltiVec.cpp:1.1.2.1	Tue Oct 18 14:37:03 2005
+++ llvm/lib/Transforms/Vector/AltiVec.cpp	Tue Nov 15 14:15:33 2005
@@ -41,7 +41,6 @@
 
   public:
     bool runOnFunction(Function &F);
-    void visitCastInst(CastInst &);
     void visitVImmInst(VImmInst &);
     void visitExtractInst(ExtractInst &);
     void visitCombineInst(CombineInst &);
@@ -103,7 +102,6 @@
     return "altivec_" + baseName + "_" + VT->getElementType()->getDescription();
   }
 
-
   //===----------------------------------------------------------------------===//
   //                     AltiVec implementation
   //===----------------------------------------------------------------------===//
@@ -114,13 +112,15 @@
     instructionsToDelete.clear();
     changed = false;
     for (Function::iterator FI = F.begin(), FE = F.end(); 
-	 FI != FE; ++FI)
+	 FI != FE; ++FI) {
       for (BasicBlock::iterator BI = FI->begin(), BE = FI->end(); 
-	   BI != BE; ++BI)
+	   BI != BE; ++BI) {
 	if (!instructionsToDelete.count(BI)) {
 	  DEBUG(std::cerr << "Visiting instruction " << *BI);
 	  visit(*BI);
 	}
+      }
+    }
     if (changed) deleteInstructions();
     return changed;
   }
@@ -157,35 +157,6 @@
     changed = true;
   }
 
-  void AltiVec::visitCastInst(CastInst &CI) {
-    // We need only worry about a cast of a non-constant scalar to a
-    // vector; the AltiVec C Backend can handle the other cases
-    // directly.
-    //
-    const VectorType *VT = dyn_cast<VectorType>(CI.getType());
-    if (!VT || !isProperType(VT) ||
-	isa<VectorType>(CI.getOperand(0)->getType()) ||
-	isa<Constant>(CI.getOperand(0)))
-      return;
-    // We need to create a new vector on the stack, store the scalar
-    // value into it, and splat the value into a vector register
-    //
-    AllocaInst *vectorPtr = new AllocaInst(VT, 0, "alloca", &CI);
-    Value *element = CI.getOperand(0); 
-    if (element->getType() != VT->getElementType())
-      element = new CastInst(element, VT->getElementType(), "cast", &CI);
-    CastInst *scalarPtr = new CastInst(vectorPtr, PointerType::get(VT->getElementType()),
-				       "cast", &CI);
-    StoreInst *store = new StoreInst(element, scalarPtr, &CI);
-    LoadInst *vector = new LoadInst(vectorPtr, "load", &CI);
-    CallInst *call = VectorUtils::getCallInst(VT, getAltiVecName("splat", VT),
-					      vector, ConstantUInt::get(Type::UByteTy, 0),
-					      "splat", &CI);
-    CI.replaceAllUsesWith(call);
-    instructionsToDelete.insert(&CI);
-    changed = true;
-  }
-  
   // Check whether an extract instruction should be turned into
   // altivec_unpack
   //
@@ -403,7 +374,7 @@
       CastInst *addCast0 = 0, *addCast2 = 0, *shrCast = 0;
       VImmInst *VImm = 0;
       ShiftInst *shr = 0;
-      CallInst *adds;
+      CallInst *adds = 0;
       unsigned offset = 0, shamt = 0;
       if (&BO == add->getOperand(0))
 	addCast0 = dyn_cast<CastInst>(add->getOperand(1));
@@ -427,30 +398,52 @@
       }
       if (shrCast && shrCast->hasOneUse()) {
 	adds = dyn_cast<CallInst>(*shrCast->use_begin());
-	Function *F = adds->getCalledFunction();
-	if (!F || F->getName().substr(0, 10) != "vllvm_adds")
+	if (adds) {
+	  Function *F = adds->getCalledFunction();
+	  if (!F || F->getName().substr(0, 10) != "vllvm_adds")
 	    adds = 0;
+	}
       }
       if (mulCast0 && mulCast1 && addCast0 && VImm &&
-	  offset == 16384 && shrCast && shr && shamt == 15 && 
-	  adds) {
-	VT = cast<FixedVectorType>(adds->getType());
-	CallInst *mradds = VectorUtils::getCallInst(VT, getAltiVecName("mradds", VT),
-						    new CastInst(mulCast0->getOperand(0), VT, "cast", adds), 
-						    new CastInst(mulCast1->getOperand(0), VT, "cast", adds),
-						    (shrCast == adds->getOperand(1)) ? adds->getOperand(2) : adds->getOperand(1),
-						    "mradds", adds);
-	adds->replaceAllUsesWith(mradds);
-	instructionsToDelete.insert(&BO);
-	instructionsToDelete.insert(add);
-	instructionsToDelete.insert(shr);
-	instructionsToDelete.insert(mulCast0);
-	instructionsToDelete.insert(mulCast1);
-	instructionsToDelete.insert(addCast0);
-	instructionsToDelete.insert(VImm);
-	instructionsToDelete.insert(shrCast);
-	instructionsToDelete.insert(adds);
-	changed = true;
+	  offset == 16384 && shrCast && shr && shamt == 15) {
+	if (adds) {
+	  VT = cast<FixedVectorType>(adds->getType());
+	  CallInst *mradds = VectorUtils::getCallInst(VT, getAltiVecName("mradds", VT),
+						      new CastInst(mulCast0->getOperand(0), VT, "cast", adds), 
+						      new CastInst(mulCast1->getOperand(0), VT, "cast", adds),
+						      (shrCast == adds->getOperand(1)) ? adds->getOperand(2) : adds->getOperand(1),
+						      "mradds", adds);
+	  adds->replaceAllUsesWith(mradds);
+	  instructionsToDelete.insert(&BO);
+	  instructionsToDelete.insert(add);
+	  instructionsToDelete.insert(shr);
+	  instructionsToDelete.insert(mulCast0);
+	  instructionsToDelete.insert(mulCast1);
+	  instructionsToDelete.insert(addCast0);
+	  instructionsToDelete.insert(VImm);
+	  instructionsToDelete.insert(shrCast);
+	  instructionsToDelete.insert(adds);
+	  changed = true;
+	} else {
+	  VT = cast<FixedVectorType>(shrCast->getType());
+	  CallInst *mradds = VectorUtils::getCallInst(VT, getAltiVecName("mradds", VT),
+						      new CastInst(mulCast0->getOperand(0), VT, "cast", shrCast), 
+						      new CastInst(mulCast1->getOperand(0), VT, "cast", shrCast),
+						      new VImmInst(Constant::getNullValue(VT->getElementType()),
+								   ConstantUInt::get(Type::UIntTy, VT->getNumElements()),
+								   true, "vimm", shrCast),
+						      "mradds", shrCast);
+	  shrCast->replaceAllUsesWith(mradds);
+	  instructionsToDelete.insert(&BO);
+	  instructionsToDelete.insert(add);
+	  instructionsToDelete.insert(shr);
+	  instructionsToDelete.insert(mulCast0);
+	  instructionsToDelete.insert(mulCast1);
+	  instructionsToDelete.insert(addCast0);
+	  instructionsToDelete.insert(VImm);
+	  instructionsToDelete.insert(shrCast);
+	  changed = true;
+	}
 	return;
       }    
       // Check for mladd pattern


Index: llvm/lib/Transforms/Vector/LowerVectors.cpp
diff -u llvm/lib/Transforms/Vector/LowerVectors.cpp:1.1.2.1 llvm/lib/Transforms/Vector/LowerVectors.cpp:1.1.2.2
--- llvm/lib/Transforms/Vector/LowerVectors.cpp:1.1.2.1	Tue Oct 18 14:37:03 2005
+++ llvm/lib/Transforms/Vector/LowerVectors.cpp	Tue Nov 15 14:15:33 2005
@@ -63,6 +63,7 @@
     void visitFreeInst(FreeInst&);
     void visitStoreInst(StoreInst&);
     void visitLoadInst(LoadInst&);
+    void visitGetElementPtrInst(GetElementPtrInst&);
     void visitInstruction(Instruction& I) {
       std::cerr << "LowerVectors class can't handle instruction " << I << "!\n";
       exit(1);
@@ -110,7 +111,7 @@
   class InstructionLowering : public InstVisitor<InstructionLowering> {
     friend class LowerVectors;
 
-    Instruction *vector;
+    Value *vector;
     Value *result, *vectorIndex;
     std::vector<Value*> idx;
     BasicBlock *body;
@@ -219,6 +220,8 @@
   /// Given a vector type value, look in the lengthMap to get its length.
   ///
   Value *getLength(Value *key) {
+    if (const FixedVectorType *VT = dyn_cast<FixedVectorType>(key->getType()))
+      return ConstantUInt::get(Type::UIntTy, VT->getNumElements());
     Value*& V = lengthMap[key];
     if (!V) {
       V = new Argument(Type::UIntTy);
@@ -248,6 +251,9 @@
   /// type to the same type.
   ///
   const Type* getLoweredMemoryType(const Type* Ty) {
+    if (const FixedVectorType *VT = dyn_cast<FixedVectorType>(Ty)) {
+      return VT->getElementType();
+    }
     if (const VectorType *VectorTy = dyn_cast<VectorType>(Ty)) {
       std::vector<const Type*> Params;
       Params.push_back(PointerType::get(VectorTy->getElementType()));
@@ -284,7 +290,10 @@
   /// lowered.
   ///
   Value *getLoweredValue(Value *key) {
-    const VectorType *VectorTy = dyn_cast<VectorType>(key->getType());
+    if (!VectorUtils::containsVector(key->getType()) &&
+	!isa<VScatterInst>(key)) {
+      return key;
+    }
     Value*& value = loweringMap[key];
     if (!value) {
       value = new Argument(getLoweredRegisterType(key->getType()));
@@ -501,9 +510,29 @@
   }
 
   void LowerVectors::visitCastInst(CastInst &CI) {
-    InstructionLowering lower(&CI, getLength(CI.getOperand(0)));
+    if (isa<PointerType>(CI.getType())) {
+      CastInst *Cast = 
+	new CastInst(getLoweredValue(CI.getOperand(0)),
+		     getLoweredMemoryType(CI.getType()),
+		     "cast", &CI);
+      setLoweredValue(&CI, Cast);
+    } else {
+      InstructionLowering lower(&CI, getLength(CI.getOperand(0)));
+    }
   }
 
+  void LowerVectors::visitGetElementPtrInst(GetElementPtrInst &GEP) {
+    std::vector<Value*> idx;
+    for (User::op_iterator I = GEP.idx_begin(), E = GEP.idx_end();
+	 I != E; ++I) {
+      idx.push_back(*I);
+    }
+    GetElementPtrInst *NewGEP = 
+      new GetElementPtrInst(getLoweredValue(GEP.getPointerOperand()),
+			    idx, "gep", &GEP);
+    setLoweredValue(&GEP, NewGEP);
+  }
+
   void LowerVectors::visitBinaryOperator(BinaryOperator &BO) {
     Value *op0 = BO.getOperand(0);
     Value *op1 = BO.getOperand(1);
@@ -601,7 +630,12 @@
   }
 
   void LowerVectors::visitStoreInst(StoreInst &SI) {
-    if (isa<VectorType>(SI.getOperand(0)->getType())) {
+    const Type *Ty = SI.getOperand(0)->getType();
+    if (const FixedVectorType *VT = dyn_cast<FixedVectorType>(SI.getOperand(0)->getType())) {
+      Value *length = getLength(SI.getOperand(0));//ConstantUInt::get(Type::UIntTy, VT->getNumElements());
+      Value *ptr = getLoweredValue(SI.getOperand(1));
+      InstructionLowering lower(&SI, length, ptr);
+    } else if (isa<VectorType>(SI.getOperand(0)->getType())) {
       std::vector<Value*> Idx;
       Idx.push_back(ConstantUInt::get(Type::UIntTy, 0));
       Idx.push_back(ConstantUInt::get(Type::UIntTy, 0));
@@ -623,7 +657,10 @@
   }
 
   void LowerVectors::visitLoadInst(LoadInst &LI) {
-    if (isa<VectorType>(cast<PointerType>(LI.getOperand(0)->getType())->getElementType())) {
+    const Type *Ty = cast<PointerType>(LI.getOperand(0)->getType())->getElementType();
+    if (isa<FixedVectorType>(Ty)) {
+      setLoweredValue(&LI, getLoweredValue(LI.getOperand(0)));
+    } else if (isa<VectorType>(Ty)) {
       std::vector<Value*> Idx;
       Idx.push_back(ConstantUInt::get(Type::UIntTy, 0));
       Idx.push_back(ConstantUInt::get(Type::UIntTy, 1));
@@ -654,7 +691,16 @@
     const VectorType *VectorTy = dyn_cast<VectorType>(Ty);
     assert(VectorTy && "Instruction must be of vector type!");
     const Type *elementType = VectorTy->getElementType();
-    vector = allocateVector(elementType, length, "vector", I, ptr);
+    if (isa<FixedVectorType>(Ty)) {
+      if (isa<StoreInst>(I)) {
+	vector = ptr;
+      } else {
+	vector = new AllocaInst(VectorTy->getElementType(), length, "vector", 
+				&(I->getParent()->getParent()->getEntryBlock().front()));
+      }
+    } else {
+      vector = allocateVector(elementType, length, "vector", I, ptr);
+    }
     setLoweredValue(I, vector, length);
 
     // Set up the loop index
@@ -751,7 +797,7 @@
   void InstructionLowering::visitCombineInst(CombineInst &CI) {
     // Here we are generating code for
     //
-    //   %tmp = extract v1, v2, start, stride
+    //   %tmp = combine v1, v2, start, stride
     //
     // First we compute secondIndex, the index into v2.  If start <=
     // vectorIndex < start + stride * getLength(v2), then
@@ -1001,7 +1047,7 @@
   //                     VScatterLowering implementation
   //===----------------------------------------------------------------------===//
 
-  /// Lower a vstore instruction.  Find the array that holds the
+  /// Lower a vscatter instruction.  Find the array that holds the
   /// vector contents, then copy that array to the indexed memory
   /// locations.
   ///


Index: llvm/lib/Transforms/Vector/RaiseVectors.cpp
diff -u llvm/lib/Transforms/Vector/RaiseVectors.cpp:1.1.2.1 llvm/lib/Transforms/Vector/RaiseVectors.cpp:1.1.2.2
--- llvm/lib/Transforms/Vector/RaiseVectors.cpp:1.1.2.1	Tue Oct 18 14:37:03 2005
+++ llvm/lib/Transforms/Vector/RaiseVectors.cpp	Tue Nov 15 14:15:33 2005
@@ -492,10 +492,13 @@
   /// Raise a phi node
   ///
   void RaiseVectors::visitPHINode(PHINode &PN) {
-    PHINode *raisedValue =
-      new PHINode(VectorType::get(PN.getType()), "phi",
-		  &PN);
     unsigned length = getVectorLength(&PN);
+    const VectorType *VT = 0;
+    if (length)
+      VT = FixedVectorType::get(PN.getType(), length);
+    else
+      VT = VectorType::get(PN.getType());
+    PHINode *raisedValue = new PHINode(VT, "phi", &PN);
     for (unsigned i = 0; i < PN.getNumIncomingValues(); ++i)
       raisedValue->addIncoming(getRaisedValue(PN.getIncomingValue(i), length),
 			       PN.getIncomingBlock(i));


Index: llvm/lib/Transforms/Vector/SSE.cpp
diff -u llvm/lib/Transforms/Vector/SSE.cpp:1.1.2.1 llvm/lib/Transforms/Vector/SSE.cpp:1.1.2.2
--- llvm/lib/Transforms/Vector/SSE.cpp:1.1.2.1	Tue Oct 18 14:37:03 2005
+++ llvm/lib/Transforms/Vector/SSE.cpp	Tue Nov 15 14:15:33 2005
@@ -43,7 +43,6 @@
     bool runOnFunction(Function &F);
     void visitCastInst(CastInst &);
     void visitVImmInst(VImmInst &);
-    void visitExtractInst(ExtractInst &);
     void visitCombineInst(CombineInst &);
     void visitVSelectInst(VSelectInst &);
     void visitAdd(BinaryOperator &);
@@ -165,7 +164,6 @@
 	  DEBUG(std::cerr << "Visiting instruction " << *BI);
 	  visit(*BI);
 	}
-    //visit(F);
     if (changed) deleteInstructions();
     return changed;
   }
@@ -204,72 +202,9 @@
 	  changed = true;
 	}
       }
-    } else {
-      // We need to use a _mm_set instruction
-      //
-      const Type *Ty = CI.getOperand(0)->getType();
-      unsigned primitiveSize = Ty->getPrimitiveSize();
-      unsigned vectorSize = getVectorSize(Ty);
-      const FixedVectorType *RetTy = FixedVectorType::get(Ty, vectorSize);
-      CallInst *call = VectorUtils::getCallInst(RetTy, getSSEName("splat", RetTy),
-						 CI.getOperand(0), "splat", &CI);
-      if (RetTy != CI.getType())
-	CI.replaceAllUsesWith(new CastInst(call, CI.getType(), "cast", &CI));
-      else 
-	CI.replaceAllUsesWith(call);
-      instructionsToDelete.insert(&CI);
-      changed = true;
     }
   }
   
-  // Check whether an extract instruction should be turned into
-  // SSE_unpack
-  //
-  // FIXME:  This code doesn't work
-  //
-  void SSE::visitExtractInst(ExtractInst &EI) {
-    Value *v = EI.getOperand(0);
-    ConstantUInt *start = dyn_cast<ConstantUInt>(EI.getOperand(1));
-    ConstantUInt *stride = dyn_cast<ConstantUInt>(EI.getOperand(2));
-    ConstantUInt *len = dyn_cast<ConstantUInt>(EI.getOperand(3));
-    if (!start || !stride || !len) return;
-    if (stride->getValue() != 1 || len->getValue() != 8) return;
-    std::string funcName;
-    if (start->getValue() == 0)
-      funcName = "SSE_unpackh_short";
-    else if (start->getValue() == 8)
-      funcName = "SSE_unpackl_short";
-    else return;
-    const FixedVectorType *VT = dyn_cast<FixedVectorType>(v->getType());
-    if (VT == FixedVectorType::get(Type::UByteTy, 16)) {
-      if (!EI.hasOneUse()) return;
-      CastInst *cast = dyn_cast<CastInst>(*EI.use_begin());
-      if (!cast) return;
-      const FixedVectorType *argTy = FixedVectorType::get(Type::SByteTy, 16);
-      const FixedVectorType *retTy = FixedVectorType::get(Type::ShortTy, 8);
-      if (cast->getType() != FixedVectorType::get(Type::ShortTy, 8))
-	return;
-      CastInst *arg = new CastInst(v, argTy, "cast", &EI);
-      std::vector<const Type*> formalArgs;
-      formalArgs.push_back(argTy);
-      std::vector<Value*> args;
-      args.push_back(arg);
-      FunctionType *FType = FunctionType::get(retTy, formalArgs, false);
-      Module *M = EI.getParent()->getParent()->getParent();
-      Function *unpack = 
-	M->getOrInsertFunction(funcName, FType);
-      CallInst *call = new CallInst(unpack, args, "unpack", &EI);
-      BinaryOperator *andInst = 
-	BinaryOperator::create(Instruction::And, call,
-			       ConstantExpr::getCast(ConstantSInt::get(Type::ShortTy, 0xFF), retTy), 
-			       "and", &EI);
-      cast->replaceAllUsesWith(andInst);
-      instructionsToDelete.insert(cast);
-      instructionsToDelete.insert(&EI);
-      changed = true;
-    }
-  }
-
   void SSE::visitCombineInst(CombineInst &CI) {
     Instruction *combine1 = cast<CombineInst>(&CI);
     Value *v1 = CI.getOperand(0);






More information about the llvm-commits mailing list