[llvm-commits] CVS: llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp

Chris Lattner sabre at nondot.org
Sun Dec 10 15:57:05 PST 2006



Changes in directory llvm/lib/Transforms/Scalar:

ScalarReplAggregates.cpp updated: 1.52 -> 1.53
---
Log message:

 * Eliminate calls to CastInst::createInferredCast.
 * Add support for promoting unions with fp values in them.  This produces
   our new int<->fp bitcast instructions, implementing 
   Transforms/ScalarRepl/union-fp-int.ll

As an example, this change allows us to compile the following:

union intfloat { int i; float f; };
float invsqrt(const float arg_x) {
    union intfloat x = { .f = arg_x };
    const float xhalf = arg_x * 0.5f;
    x.i = 0x5f3759df - (x.i >> 1);
    return x.f * (1.5f - xhalf * x.f * x.f);
}

into:

_invsqrt:
        movss 4(%esp), %xmm0
        movd %xmm0, %eax
        sarl %eax
        movl $1597463007, %ecx
        subl %eax, %ecx
        movd %ecx, %xmm1
        mulss LCPI1_0, %xmm0
        mulss %xmm1, %xmm0
        movss LCPI1_1, %xmm2
        mulss %xmm1, %xmm0
        subss %xmm0, %xmm2
        movl 8(%esp), %eax
        mulss %xmm2, %xmm1
        movss %xmm1, (%eax)
        ret

instead of:

_invsqrt:
        subl $4, %esp
        movss 8(%esp), %xmm0
        movss %xmm0, (%esp)
        movl (%esp), %eax
        movl $1597463007, %ecx
        sarl %eax
        subl %eax, %ecx
        movl %ecx, (%esp)
        mulss LCPI1_0, %xmm0
        movss (%esp), %xmm1
        mulss %xmm1, %xmm0
        mulss %xmm1, %xmm0
        movss LCPI1_1, %xmm2
        subss %xmm0, %xmm2
        mulss %xmm2, %xmm1
        movl 12(%esp), %eax
        movss %xmm1, (%eax)
        addl $4, %esp
        ret



---
Diffs of the changes:  (+93 -40)

 ScalarReplAggregates.cpp |  133 ++++++++++++++++++++++++++++++++---------------
 1 files changed, 93 insertions(+), 40 deletions(-)


Index: llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp
diff -u llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp:1.52 llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp:1.53
--- llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp:1.52	Wed Dec  6 11:46:33 2006
+++ llvm/lib/Transforms/Scalar/ScalarReplAggregates.cpp	Sun Dec 10 17:56:50 2006
@@ -419,11 +419,14 @@
 /// types are incompatible, return true, otherwise update Accum and return
 /// false.
 ///
-/// There are two cases we handle here:
+/// There are three cases we handle here:
 ///   1) An effectively integer union, where the pieces are stored into as
 ///      smaller integers (common with byte swap and other idioms).
 ///   2) A union of a vector and its elements.  Here we turn element accesses
 ///      into insert/extract element operations.
+///   3) A union of scalar types, such as int/float or int/pointer.  Here we
+///      merge together into integers, allowing the xform to work with #1 as
+///      well.
 static bool MergeInType(const Type *In, const Type *&Accum,
                         const TargetData &TD) {
   // If this is our first type, just use it.
@@ -436,22 +439,38 @@
       Accum = In;
   } else if (isa<PointerType>(In) && isa<PointerType>(Accum)) {
     // Pointer unions just stay as one of the pointers.
-  } else if ((PTy = dyn_cast<PackedType>(Accum)) && 
-             PTy->getElementType() == In) {
-    // Accum is a vector, and we are accessing an element: ok.
-  } else if ((PTy = dyn_cast<PackedType>(In)) && 
-             PTy->getElementType() == Accum) {
-    // In is a vector, and accum is an element: ok, remember In.
-    Accum = In;
-  } else if (isa<PointerType>(In) && Accum->isIntegral()) {
-    // Pointer/Integer unions merge together as integers.
-    return MergeInType(TD.getIntPtrType(), Accum, TD);
-  } else if (isa<PointerType>(Accum) && In->isIntegral()) {
-    // Pointer/Integer unions merge together as integers.
-    Accum = TD.getIntPtrType();
-    return MergeInType(In, Accum, TD);
+  } else if (isa<PackedType>(In) || isa<PackedType>(Accum)) {
+    if ((PTy = dyn_cast<PackedType>(Accum)) && 
+        PTy->getElementType() == In) {
+      // Accum is a vector, and we are accessing an element: ok.
+    } else if ((PTy = dyn_cast<PackedType>(In)) && 
+               PTy->getElementType() == Accum) {
+      // In is a vector, and accum is an element: ok, remember In.
+      Accum = In;
+    } else {
+      // FIXME: Handle packed->packed.
+      return true;
+    }
   } else {
-    return true;
+    // Pointer/FP/Integer unions merge together as integers.
+    switch (Accum->getTypeID()) {
+    case Type::PointerTyID: Accum = TD.getIntPtrType(); break;
+    case Type::FloatTyID:   Accum = Type::UIntTy; break;
+    case Type::DoubleTyID:  Accum = Type::ULongTy; break;
+    default:
+      assert(Accum->isIntegral() && "Unknown FP type!");
+      break;
+    }
+    
+    switch (In->getTypeID()) {
+    case Type::PointerTyID: In = TD.getIntPtrType(); break;
+    case Type::FloatTyID:   In = Type::UIntTy; break;
+    case Type::DoubleTyID:  In = Type::ULongTy; break;
+    default:
+      assert(In->isIntegral() && "Unknown FP type!");
+      break;
+    }
+    return MergeInType(In, Accum, TD);
   }
   return false;
 }
@@ -612,20 +631,35 @@
           unsigned Elt = Offset/(TD.getTypeSize(PTy->getElementType())*8);
           NV = new ExtractElementInst(NV, ConstantInt::get(Type::UIntTy, Elt),
                                       "tmp", LI);
+        } else if (isa<PointerType>(NV->getType())) {
+          assert(isa<PointerType>(LI->getType()));
+          // Must be ptr->ptr cast.  Anything else would result in NV being
+          // an integer.
+          NV = new BitCastInst(NV, LI->getType(), LI->getName(), LI);
         } else {
-          if (Offset) {
-            assert(NV->getType()->isInteger() && "Unknown promotion!");
-            if (Offset < TD.getTypeSize(NV->getType())*8) {
-              NV = new ShiftInst(Instruction::LShr, NV, 
-                                 ConstantInt::get(Type::UByteTy, Offset), 
-                                 LI->getName(), LI);
-            }
+          assert(NV->getType()->isInteger() && "Unknown promotion!");
+          if (Offset && Offset < TD.getTypeSize(NV->getType())*8) {
+            NV = new ShiftInst(Instruction::LShr, NV, 
+                               ConstantInt::get(Type::UByteTy, Offset), 
+                               LI->getName(), LI);
+          }
+          
+          // If the result is an integer, this is a trunc or bitcast.
+          if (LI->getType()->isIntegral()) {
+            NV = CastInst::createTruncOrBitCast(NV, LI->getType(),
+                                                LI->getName(), LI);
+          } else if (LI->getType()->isFloatingPoint()) {
+            // If needed, truncate the integer to the appropriate size.
+            if (NV->getType()->getPrimitiveSize() > 
+                  LI->getType()->getPrimitiveSize())
+              NV = new TruncInst(NV, LI->getType(), LI->getName(), LI);
+            
+            // Then do a bitcast.
+            NV = new BitCastInst(NV, LI->getType(), LI->getName(), LI);
           } else {
-            assert((NV->getType()->isInteger() ||
-                    isa<PointerType>(NV->getType())) && "Unknown promotion!");
+            // Otherwise must be a pointer.
+            NV = new IntToPtrInst(NV, LI->getType(), LI->getName(), LI);
           }
-          NV = CastInst::createInferredCast(NV, LI->getType(), LI->getName(), 
-                                            LI);
         }
       }
       LI->replaceAllUsesWith(NV);
@@ -647,24 +681,43 @@
                                      ConstantInt::get(Type::UIntTy, Elt),
                                      "tmp", SI);
         } else {
-          // Always zero extend the value.
-          if (SV->getType()->isSigned())
-            SV = CastInst::createInferredCast(SV, 
-                SV->getType()->getUnsignedVersion(), SV->getName(), SI);
-          SV = CastInst::createInferredCast(SV, Old->getType(), SV->getName(), 
-                                            SI);
-          if (Offset && Offset < TD.getTypeSize(SV->getType())*8)
+          // If SV is a float, convert it to the appropriate integer type.
+          // If it is a pointer, do the same, and also handle ptr->ptr casts
+          // here.
+          switch (SV->getType()->getTypeID()) {
+          default:
+            assert(!SV->getType()->isFloatingPoint() && "Unknown FP type!");
+            break;
+          case Type::FloatTyID:
+            SV = new BitCastInst(SV, Type::UIntTy, SV->getName(), SI);
+            break;
+          case Type::DoubleTyID:
+            SV = new BitCastInst(SV, Type::ULongTy, SV->getName(), SI);
+            break;
+          case Type::PointerTyID:
+            if (isa<PointerType>(AllocaType))
+              SV = new BitCastInst(SV, AllocaType, SV->getName(), SI);
+            else
+              SV = new PtrToIntInst(SV, TD.getIntPtrType(), SV->getName(), SI);
+            break;
+          }
+
+          unsigned SrcSize = TD.getTypeSize(SV->getType())*8;
+
+          // Always zero extend the value if needed.
+          if (SV->getType() != AllocaType)
+            SV = CastInst::createZExtOrBitCast(SV, AllocaType,
+                                               SV->getName(), SI);
+          if (Offset && Offset < AllocaType->getPrimitiveSizeInBits())
             SV = new ShiftInst(Instruction::Shl, SV,
                                ConstantInt::get(Type::UByteTy, Offset),
                                SV->getName()+".adj", SI);
           // Mask out the bits we are about to insert from the old value.
           unsigned TotalBits = TD.getTypeSize(SV->getType())*8;
-          unsigned InsertBits = TD.getTypeSize(SI->getOperand(0)->getType())*8;
-          if (TotalBits != InsertBits) {
-            assert(TotalBits > InsertBits);
-            uint64_t Mask = ~(((1ULL << InsertBits)-1) << Offset);
-            if (TotalBits != 64)
-              Mask = Mask & ((1ULL << TotalBits)-1);
+          if (TotalBits != SrcSize) {
+            assert(TotalBits > SrcSize);
+            uint64_t Mask = ~(((1ULL << SrcSize)-1) << Offset);
+            Mask = Mask & SV->getType()->getIntegralTypeMask();
             Old = BinaryOperator::createAnd(Old,
                                         ConstantInt::get(Old->getType(), Mask),
                                             Old->getName()+".mask", SI);






More information about the llvm-commits mailing list