[llvm-commits] CVS: llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp

Chris Lattner lattner at cs.uiuc.edu
Fri Apr 14 14:42:53 PDT 2006



Changes in directory llvm/lib/Transforms/Scalar:

ScalarReplAggregates.cpp updated: 1.37 -> 1.38
---
Log message:

Teach scalarrepl to promote unions of vectors and their element types (e.g. a
vector of floats and a float), producing insertelement/extractelement
operations.  This implements the Transforms/ScalarRepl/vector_promote.ll test.


---
Diffs of the changes:  (+101 -46)

 ScalarReplAggregates.cpp |  147 ++++++++++++++++++++++++++++++++---------------
 1 files changed, 101 insertions(+), 46 deletions(-)


Index: llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp
diff -u llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp:1.37 llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp:1.38
--- llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp:1.37	Tue Mar  7 19:05:29 2006
+++ llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp	Fri Apr 14 16:42:41 2006
@@ -415,16 +415,30 @@
 /// MergeInType - Add the 'In' type to the accumulated type so far.  If the
 /// types are incompatible, return true, otherwise update Accum and return
 /// false.
+///
+/// There are two cases we handle here:
+///   1) An effectively integer union, where the pieces are stored into as
+///      smaller integers (common with byte swap and other idioms).
+///   2) A union of a vector and its elements.  Here we turn element accesses
+///      into insert/extract element operations.
 static bool MergeInType(const Type *In, const Type *&Accum) {
-  if (!In->isIntegral()) return true;
-  
   // If this is our first type, just use it.
-  if (Accum == Type::VoidTy) {
+  const PackedType *PTy;   // Bound by the dyn_casts in the vector cases below.
+  if (Accum == Type::VoidTy || In == Accum) {   // First type, or same type again.
     Accum = In;
-  } else {
+  } else if (In->isIntegral() && Accum->isIntegral()) {   // integer union; "larger" below = higher TypeID.
     // Otherwise pick whichever type is larger.
     if (In->getTypeID() > Accum->getTypeID())
       Accum = In;
+  } else if ((PTy = dyn_cast<PackedType>(Accum)) && 
+             PTy->getElementType() == In) {
+    // Accum is a vector and In is its element type: ok, keep the vector.
+  } else if ((PTy = dyn_cast<PackedType>(In)) && 
+             PTy->getElementType() == Accum) {
+    // In is a vector and Accum is its element type: ok, remember the vector.
+    Accum = In;
+  } else {
+    return true;   // Incompatible: callers give up on scalar promotion.
   }
   return false;
 }
@@ -462,7 +476,7 @@
       // Storing the pointer, not the into the value?
       if (SI->getOperand(0) == V) return 0;
       
-      // NOTE: We could handle storing of FP imms here!
+      // NOTE: We could handle storing of FP imms into integers here!
       
       if (MergeInType(SI->getOperand(0)->getType(), UsedType))
         return 0;
@@ -482,7 +496,7 @@
         IsNotTrivial = true;
         const Type *SubElt = CanConvertToScalar(GEP, IsNotTrivial);
         if (SubElt == 0) return 0;
-        if (SubElt != Type::VoidTy) {
+        if (SubElt != Type::VoidTy && SubElt->isInteger()) {
           const Type *NewTy = 
             getUIntAtLeastAsBitAs(SubElt->getPrimitiveSizeInBits()+BitOffset);
           if (NewTy == 0 || MergeInType(NewTy, UsedType)) return 0;
@@ -499,8 +513,23 @@
         
         if (const ArrayType *ATy = dyn_cast<ArrayType>(AggTy)) {
           if (Idx >= ATy->getNumElements()) return 0;  // Out of range.
-        } else if (const PackedType *PTy = dyn_cast<PackedType>(AggTy)) {
-          if (Idx >= PTy->getNumElements()) return 0;  // Out of range.
+        } else if (const PackedType *PackedTy = dyn_cast<PackedType>(AggTy)) {
+          // Getting an element of the packed vector.
+          if (Idx >= PackedTy->getNumElements()) return 0;  // Out of range.
+
+          // Merge in the packed type.
+          if (MergeInType(PackedTy, UsedType)) return 0;
+          
+          const Type *SubTy = CanConvertToScalar(GEP, IsNotTrivial);
+          if (SubTy == 0) return 0;
+          
+          if (SubTy != Type::VoidTy && MergeInType(SubTy, UsedType))
+            return 0;
+
+          // We'll need to change this to an insert/extract element operation.
+          IsNotTrivial = true;
+          continue;    // Everything looks ok
+          
         } else if (isa<StructType>(AggTy)) {
           // Structs are always ok.
         } else {
@@ -537,31 +566,47 @@
          "Not in the entry block!");
   EntryBlock->getInstList().remove(AI);  // Take the alloca out of the program.
   
+  if (ActualTy->isInteger())
+    ActualTy = ActualTy->getUnsignedVersion();
+  
   // Create and insert the alloca.
-  AllocaInst *NewAI = new AllocaInst(ActualTy->getUnsignedVersion(), 0,
-                                     AI->getName(), EntryBlock->begin());
+  AllocaInst *NewAI = new AllocaInst(ActualTy, 0, AI->getName(),
+                                     EntryBlock->begin());
   ConvertUsesToScalar(AI, NewAI, 0);
   delete AI;
 }
 
 
 /// ConvertUsesToScalar - Convert all of the users of Ptr to use the new alloca
-/// directly.  Offset is an offset from the original alloca, in bits that need
-/// to be shifted to the right.  By the end of this, there should be no uses of
-/// Ptr.
+/// directly.  This happens when we are converting an "integer union" to a
+/// single integer scalar, or when we are converting a "vector union" to a
+/// vector with insert/extractelement instructions.
+///
+/// Offset is an offset from the original alloca, in bits that need to be
+/// shifted to the right.  By the end of this, there should be no uses of Ptr.
 void SROA::ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, unsigned Offset) {
+  bool isVectorInsert = isa<PackedType>(NewAI->getType()->getElementType());
   while (!Ptr->use_empty()) {
     Instruction *User = cast<Instruction>(Ptr->use_back());
     
     if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
       // The load is a bit extract from NewAI shifted right by Offset bits.
       Value *NV = new LoadInst(NewAI, LI->getName(), LI);
-      if (Offset && Offset < NV->getType()->getPrimitiveSizeInBits())
-        NV = new ShiftInst(Instruction::Shr, NV,
-                           ConstantUInt::get(Type::UByteTy, Offset),
-                           LI->getName(), LI);
-      if (NV->getType() != LI->getType())
-        NV = new CastInst(NV, LI->getType(), LI->getName(), LI);
+      if (NV->getType() != LI->getType()) {
+        if (const PackedType *PTy = dyn_cast<PackedType>(NV->getType())) {
+          // Must be an element access.
+          unsigned Elt = Offset/PTy->getElementType()->getPrimitiveSizeInBits();
+          NV = new ExtractElementInst(NV, ConstantUInt::get(Type::UIntTy, Elt),
+                                      "tmp", LI);
+        } else {
+          assert(NV->getType()->isInteger() && "Unknown promotion!");
+          if (Offset && Offset < NV->getType()->getPrimitiveSizeInBits())
+            NV = new ShiftInst(Instruction::Shr, NV,
+                               ConstantUInt::get(Type::UByteTy, Offset),
+                               LI->getName(), LI);
+          NV = new CastInst(NV, LI->getType(), LI->getName(), LI);
+        }
+      }
       LI->replaceAllUsesWith(NV);
       LI->eraseFromParent();
     } else if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
@@ -570,31 +615,41 @@
       // Convert the stored type to the actual type, shift it left to insert
       // then 'or' into place.
       Value *SV = SI->getOperand(0);
-      if (SV->getType() != NewAI->getType()->getElementType() || Offset != 0) {
+      const Type *AllocaType = NewAI->getType()->getElementType();
+      if (SV->getType() != AllocaType) {
         Value *Old = new LoadInst(NewAI, NewAI->getName()+".in", SI);
-        // If SV is signed, convert it to unsigned, so that the next cast zero
-        // extends the value.
-        if (SV->getType()->isSigned())
-          SV = new CastInst(SV, SV->getType()->getUnsignedVersion(),
-                            SV->getName(), SI);
-        SV = new CastInst(SV, Old->getType(), SV->getName(), SI);
-        if (Offset && Offset < SV->getType()->getPrimitiveSizeInBits())
-          SV = new ShiftInst(Instruction::Shl, SV,
-                             ConstantUInt::get(Type::UByteTy, Offset),
-                             SV->getName()+".adj", SI);
-        // Mask out the bits we are about to insert from the old value.
-        unsigned TotalBits = SV->getType()->getPrimitiveSizeInBits();
-        unsigned InsertBits =
-          SI->getOperand(0)->getType()->getPrimitiveSizeInBits();
-        if (TotalBits != InsertBits) {
-          assert(TotalBits > InsertBits);
-          uint64_t Mask = ~(((1ULL << InsertBits)-1) << Offset);
-          if (TotalBits != 64)
-            Mask = Mask & ((1ULL << TotalBits)-1);
-          Old = BinaryOperator::createAnd(Old,
+        
+        if (const PackedType *PTy = dyn_cast<PackedType>(AllocaType)) {
+          // Must be an element insertion.
+          unsigned Elt = Offset/PTy->getElementType()->getPrimitiveSizeInBits();
+          SV = new InsertElementInst(Old, SV,
+                                     ConstantUInt::get(Type::UIntTy, Elt),
+                                     "tmp", SI);
+        } else {
+          // If SV is signed, convert it to unsigned, so that the next cast zero
+          // extends the value.
+          if (SV->getType()->isSigned())
+            SV = new CastInst(SV, SV->getType()->getUnsignedVersion(),
+                              SV->getName(), SI);
+          SV = new CastInst(SV, Old->getType(), SV->getName(), SI);
+          if (Offset && Offset < SV->getType()->getPrimitiveSizeInBits())
+            SV = new ShiftInst(Instruction::Shl, SV,
+                               ConstantUInt::get(Type::UByteTy, Offset),
+                               SV->getName()+".adj", SI);
+          // Mask out the bits we are about to insert from the old value.
+          unsigned TotalBits = SV->getType()->getPrimitiveSizeInBits();
+          unsigned InsertBits =
+            SI->getOperand(0)->getType()->getPrimitiveSizeInBits();
+          if (TotalBits != InsertBits) {
+            assert(TotalBits > InsertBits);
+            uint64_t Mask = ~(((1ULL << InsertBits)-1) << Offset);
+            if (TotalBits != 64)
+              Mask = Mask & ((1ULL << TotalBits)-1);
+            Old = BinaryOperator::createAnd(Old,
                                         ConstantUInt::get(Old->getType(), Mask),
-                                          Old->getName()+".mask", SI);
-          SV = BinaryOperator::createOr(Old, SV, SV->getName()+".ins", SI);
+                                            Old->getName()+".mask", SI);
+            SV = BinaryOperator::createOr(Old, SV, SV->getName()+".ins", SI);
+          }
         }
       }
       new StoreInst(SV, NewAI, SI);
@@ -603,7 +658,7 @@
     } else if (CastInst *CI = dyn_cast<CastInst>(User)) {
       unsigned NewOff = Offset;
       const TargetData &TD = getAnalysis<TargetData>();
-      if (TD.isBigEndian()) {
+      if (TD.isBigEndian() && !isVectorInsert) {
         // Adjust the pointer.  For example, storing 16-bits into a 32-bit
         // alloca with just a cast makes it modify the top 16-bits.
         const Type *SrcTy = cast<PointerType>(Ptr->getType())->getElementType();
@@ -625,7 +680,7 @@
         unsigned Idx = cast<ConstantInt>(GEP->getOperand(1))->getRawValue();
         unsigned BitOffset = Idx*AggSizeInBits;
         
-        if (TD.isLittleEndian())
+        if (TD.isLittleEndian() || isVectorInsert)
           NewOffset += BitOffset;
         else
           NewOffset -= BitOffset;
@@ -637,14 +692,14 @@
         if (const SequentialType *SeqTy = dyn_cast<SequentialType>(AggTy)) {
           unsigned ElSizeBits = TD.getTypeSize(SeqTy->getElementType())*8;
 
-          if (TD.isLittleEndian())
+          if (TD.isLittleEndian() || isVectorInsert)
             NewOffset += ElSizeBits*Idx;
           else
             NewOffset += AggSizeInBits-ElSizeBits*(Idx+1);
         } else if (const StructType *STy = dyn_cast<StructType>(AggTy)) {
           unsigned EltBitOffset = TD.getStructLayout(STy)->MemberOffsets[Idx]*8;
           
-          if (TD.isLittleEndian())
+          if (TD.isLittleEndian() || isVectorInsert)
             NewOffset += EltBitOffset;
           else {
             const PointerType *ElPtrTy = cast<PointerType>(GEP->getType());






More information about the llvm-commits mailing list