[llvm-commits] [llvm] r51889 - in /llvm/trunk: lib/VMCore/ConstantFold.cpp lib/VMCore/ConstantFold.h lib/VMCore/Constants.cpp test/Assembler/insertextractvalue.ll

Dan Gohman gohman at apple.com
Mon Jun 2 17:15:21 PDT 2008


Author: djg
Date: Mon Jun  2 19:15:20 2008
New Revision: 51889

URL: http://llvm.org/viewvc/llvm-project?rev=51889&view=rev
Log:
Constant folding for insertvalue and extractvalue.

Modified:
    llvm/trunk/lib/VMCore/ConstantFold.cpp
    llvm/trunk/lib/VMCore/ConstantFold.h
    llvm/trunk/lib/VMCore/Constants.cpp
    llvm/trunk/test/Assembler/insertextractvalue.ll

Modified: llvm/trunk/lib/VMCore/ConstantFold.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/VMCore/ConstantFold.cpp?rev=51889&r1=51888&r2=51889&view=diff

==============================================================================
--- llvm/trunk/lib/VMCore/ConstantFold.cpp (original)
+++ llvm/trunk/lib/VMCore/ConstantFold.cpp Mon Jun  2 19:15:20 2008
@@ -394,6 +394,7 @@
     }
     return ConstantVector::get(Ops);
   }
+
   return 0;
 }
 
@@ -447,18 +448,112 @@
   return ConstantVector::get(&Result[0], Result.size());
 }
 
-Constant *llvm::ConstantFoldExtractValue(const Constant *Agg,
-                                         Constant* const *Idxs,
-                                         unsigned NumIdx) {
-  // FIXME: implement some constant folds
-  return 0;
+Constant *llvm::ConstantFoldExtractValueInstruction(const Constant *Agg,
+                                                    const unsigned *Idxs,
+                                                    unsigned NumIdx) {
+  // Fold extractvalue(Agg, Idxs...), or return 0 if it can't be folded.
+  if (NumIdx == 0)  // Base case: no indices, so return the entire value.
+    return const_cast<Constant *>(Agg);
+
+  if (isa<UndefValue>(Agg) || isa<ConstantAggregateZero>(Agg)) {
+    // ev(undef, x) -> undef and ev(0, x) -> 0, of the indexed member's type.
+    const Type *EltTy = ExtractValueInst::getIndexedType(Agg->getType(), Idxs,
+                                                         Idxs + NumIdx);
+    if (isa<UndefValue>(Agg))
+      return UndefValue::get(EltTy);
+    return Constant::getNullValue(EltTy);
+  }
+  // Recurse only into literal aggregates, whose operands are their members.
+  // Other constants (e.g. a ConstantExpr of aggregate type) are not folded.
+  if (!isa<ConstantStruct>(Agg) && !isa<ConstantArray>(Agg))
+    return 0;
+  return ConstantFoldExtractValueInstruction(Agg->getOperand(*Idxs),
+                                             Idxs+1, NumIdx-1);
 }
 
-Constant *llvm::ConstantFoldInsertValue(const Constant *Agg,
-                                        const Constant *Val,
-                                        Constant* const *Idxs,
-                                        unsigned NumIdx) {
-  // FIXME: implement some constant folds
+Constant *llvm::ConstantFoldInsertValueInstruction(const Constant *Agg,
+                                                   const Constant *Val,
+                                                   const unsigned *Idxs,
+                                                   unsigned NumIdx) {
+  // Fold insertvalue(Agg, Val, Idxs...), or return 0 if it can't be folded.
+  if (NumIdx == 0)  // Base case: no indices, so replace the entire value.
+    return const_cast<Constant *>(Val);
+
+  if (isa<UndefValue>(Agg)) {
+    // Insertion into an undef aggregate.  Inserting undef changes nothing,
+    // so the aggregate is returned unchanged.
+    if (isa<UndefValue>(Val))
+      return const_cast<Constant*>(Agg);
+    // Otherwise expand the aggregate into per-member undefs and insert Val
+    // into the member selected by the first index.
+    const CompositeType *AggTy = cast<CompositeType>(Agg->getType());
+    unsigned numOps;
+    if (const ArrayType *AR = dyn_cast<ArrayType>(AggTy))
+      numOps = AR->getNumElements();
+    else
+      numOps = cast<StructType>(AggTy)->getNumElements();
+    std::vector<Constant*> Ops(numOps);
+    for (unsigned i = 0; i < numOps; ++i) {
+      const Type *MemberTy = AggTy->getTypeAtIndex(i);
+      const Constant *Op =
+        (*Idxs == i) ?
+        ConstantFoldInsertValueInstruction(UndefValue::get(MemberTy),
+                                           Val, Idxs+1, NumIdx-1) :
+        UndefValue::get(MemberTy);
+      Ops[i] = const_cast<Constant*>(Op);
+    }
+    if (const StructType *ST = dyn_cast<StructType>(AggTy))
+      return ConstantStruct::get(Ops, ST->isPacked());  // Same packing as Agg.
+    else
+      return ConstantArray::get(cast<ArrayType>(AggTy), Ops);
+  }
+  if (isa<ConstantAggregateZero>(Agg)) {
+    // Insertion into a zero aggregate.  Inserting a null value changes
+    // nothing, so the aggregate is returned unchanged.
+    if (Val->isNullValue())
+      return const_cast<Constant*>(Agg);
+    // Otherwise expand the aggregate into per-member zeros and insert Val
+    // into the member selected by the first index.
+    const CompositeType *AggTy = cast<CompositeType>(Agg->getType());
+    unsigned numOps;
+    if (const ArrayType *AR = dyn_cast<ArrayType>(AggTy))
+      numOps = AR->getNumElements();
+    else
+      numOps = cast<StructType>(AggTy)->getNumElements();
+    std::vector<Constant*> Ops(numOps);
+    for (unsigned i = 0; i < numOps; ++i) {
+      const Type *MemberTy = AggTy->getTypeAtIndex(i);
+      const Constant *Op =
+        (*Idxs == i) ?
+        ConstantFoldInsertValueInstruction(Constant::getNullValue(MemberTy),
+                                           Val, Idxs+1, NumIdx-1) :
+        Constant::getNullValue(MemberTy);
+      Ops[i] = const_cast<Constant*>(Op);
+    }
+    if (const StructType *ST = dyn_cast<StructType>(AggTy))
+      return ConstantStruct::get(Ops, ST->isPacked());  // Same packing as Agg.
+    else
+      return ConstantArray::get(cast<ArrayType>(AggTy), Ops);
+  }
+  if (isa<ConstantStruct>(Agg) || isa<ConstantArray>(Agg)) {
+    // Insertion into a literal aggregate: copy its members, replacing the
+    // one selected by the first index.
+    std::vector<Constant*> Ops(Agg->getNumOperands());
+    for (unsigned i = 0; i < Agg->getNumOperands(); ++i) {
+      const Constant *Op =
+        (*Idxs == i) ?
+        ConstantFoldInsertValueInstruction(Agg->getOperand(i),
+                                           Val, Idxs+1, NumIdx-1) :
+        Agg->getOperand(i);
+      if (!Op)  // Nested member (e.g. a ConstantExpr) could not be folded.
+        return 0;
+      Ops[i] = const_cast<Constant*>(Op);
+    }
+    if (const StructType *ST = dyn_cast<StructType>(Agg->getType()))
+      return ConstantStruct::get(Ops, ST->isPacked());  // Same packing as Agg.
+    return ConstantArray::get(cast<ArrayType>(Agg->getType()), Ops);
+  }
+
   return 0;
 }
 

Modified: llvm/trunk/lib/VMCore/ConstantFold.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/VMCore/ConstantFold.h?rev=51889&r1=51888&r2=51889&view=diff

==============================================================================
--- llvm/trunk/lib/VMCore/ConstantFold.h (original)
+++ llvm/trunk/lib/VMCore/ConstantFold.h Mon Jun  2 19:15:20 2008
@@ -41,10 +41,13 @@
   Constant *ConstantFoldShuffleVectorInstruction(const Constant *V1,
                                                  const Constant *V2,
                                                  const Constant *Mask);
-  Constant *ConstantFoldExtractValue(const Constant *Agg,
-                                     Constant* const *Idxs, unsigned NumIdx);
-  Constant *ConstantFoldInsertValue(const Constant *Agg, const Constant *Val,
-                                    Constant* const *Idxs, unsigned NumIdx);
+  Constant *ConstantFoldExtractValueInstruction(const Constant *Agg,
+                                                const unsigned *Idxs,
+                                                unsigned NumIdx);
+  Constant *ConstantFoldInsertValueInstruction(const Constant *Agg,
+                                               const Constant *Val,
+                                               const unsigned *Idxs,
+                                               unsigned NumIdx);
   Constant *ConstantFoldBinaryInstruction(unsigned Opcode, const Constant *V1,
                                           const Constant *V2);
   Constant *ConstantFoldCompareInstruction(unsigned short predicate, 

Modified: llvm/trunk/lib/VMCore/Constants.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/VMCore/Constants.cpp?rev=51889&r1=51888&r2=51889&view=diff

==============================================================================
--- llvm/trunk/lib/VMCore/Constants.cpp (original)
+++ llvm/trunk/lib/VMCore/Constants.cpp Mon Jun  2 19:15:20 2008
@@ -2305,9 +2305,10 @@
          "insertvalue indices invalid!");
   assert(Agg->getType() == ReqTy &&
          "insertvalue type invalid!");
-
   assert(Agg->getType()->isFirstClassType() &&
          "Non-first-class type for constant InsertValue expression");
+  if (Constant *FC = ConstantFoldInsertValueInstruction(Agg, Val, Idxs, NumIdx))
+    return FC;          // Fold a few common cases...
   // Look up the constant in the table first to ensure uniqueness
   std::vector<Constant*> ArgVec;
   ArgVec.push_back(Agg);
@@ -2336,6 +2337,8 @@
          "extractvalue indices invalid!");
   assert(Agg->getType()->isFirstClassType() &&
          "Non-first-class type for constant extractvalue expression");
+  if (Constant *FC = ConstantFoldExtractValueInstruction(Agg, Idxs, NumIdx))
+    return FC;          // Fold a few common cases...
   // Look up the constant in the table first to ensure uniqueness
   std::vector<Constant*> ArgVec;
   ArgVec.push_back(Agg);

Modified: llvm/trunk/test/Assembler/insertextractvalue.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Assembler/insertextractvalue.ll?rev=51889&r1=51888&r2=51889&view=diff

==============================================================================
--- llvm/trunk/test/Assembler/insertextractvalue.ll (original)
+++ llvm/trunk/test/Assembler/insertextractvalue.ll Mon Jun  2 19:15:20 2008
@@ -1,4 +1,6 @@
-; RUN: llvm-as < %s
+; RUN: llvm-as < %s | llvm-dis > %t
+; RUN: grep insertvalue %t | count 1
+; RUN: grep extractvalue %t | count 1
 
 define float @foo({{i32},{float, double}}* %p) {
   %t = load {{i32},{float, double}}* %p
@@ -11,3 +13,11 @@
   store {{i32},{float, double}} insertvalue ({{i32},{float, double}}{{i32}{i32 4},{float, double}{float 4.0, double 5.0}}, double 20.0, 1, 1), {{i32},{float, double}}* %p
   ret float extractvalue ({{i32},{float, double}}{{i32}{i32 3},{float, double}{float 7.0, double 9.0}}, 1, 0)
 }
+define float @car({{i32},{float, double}}* %p) {  ; undef aggregate operands
+  store {{i32},{float, double}} insertvalue ({{i32},{float, double}} undef, double 20.0, 1, 1), {{i32},{float, double}}* %p  ; expected to fold away
+  ret float extractvalue ({{i32},{float, double}} undef, 1, 0)  ; expected to fold to float undef
+}
+define float @dar({{i32},{float, double}}* %p) {  ; zeroinitializer aggregate operands
+  store {{i32},{float, double}} insertvalue ({{i32},{float, double}} zeroinitializer, double 20.0, 1, 1), {{i32},{float, double}}* %p  ; expected to fold away
+  ret float extractvalue ({{i32},{float, double}} zeroinitializer, 1, 0)  ; expected to fold to float 0.0
+}





More information about the llvm-commits mailing list