[llvm] r242008 - [InstSimplify] Teach InstSimplify how to simplify extractelement

David Majnemer david.majnemer at gmail.com
Sun Jul 12 18:15:54 PDT 2015


Author: majnemer
Date: Sun Jul 12 20:15:53 2015
New Revision: 242008

URL: http://llvm.org/viewvc/llvm-project?rev=242008&view=rev
Log:
[InstSimplify] Teach InstSimplify how to simplify extractelement

Modified:
    llvm/trunk/include/llvm/Analysis/ConstantFolding.h
    llvm/trunk/include/llvm/Analysis/InstructionSimplify.h
    llvm/trunk/include/llvm/Analysis/VectorUtils.h
    llvm/trunk/lib/Analysis/InstructionSimplify.cpp
    llvm/trunk/lib/Analysis/VectorUtils.cpp
    llvm/trunk/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
    llvm/trunk/test/Transforms/InstSimplify/undef.ll

Modified: llvm/trunk/include/llvm/Analysis/ConstantFolding.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/ConstantFolding.h?rev=242008&r1=242007&r2=242008&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Analysis/ConstantFolding.h (original)
+++ llvm/trunk/include/llvm/Analysis/ConstantFolding.h Sun Jul 12 20:15:53 2015
@@ -78,6 +78,11 @@ Constant *ConstantFoldInsertValueInstruc
 Constant *ConstantFoldExtractValueInstruction(Constant *Agg,
                                               ArrayRef<unsigned> Idxs);
 
+/// \brief Attempt to constant fold an extractelement instruction with the
+/// specified vector and index.  The constant result is returned if
+/// successful; if not, null is returned.
+Constant *ConstantFoldExtractElementInstruction(Constant *Val, Constant *Idx);
+
 /// ConstantFoldLoadFromConstPtr - Return the value that a load from C would
 /// produce if it is constant and determinable.  If this is not determinable,
 /// return null.

Modified: llvm/trunk/include/llvm/Analysis/InstructionSimplify.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/InstructionSimplify.h?rev=242008&r1=242007&r2=242008&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Analysis/InstructionSimplify.h (original)
+++ llvm/trunk/include/llvm/Analysis/InstructionSimplify.h Sun Jul 12 20:15:53 2015
@@ -253,6 +253,15 @@ namespace llvm {
                                   AssumptionCache *AC = nullptr,
                                   const Instruction *CxtI = nullptr);
 
+  /// \brief Given the vector and index operands of an ExtractElementInst, see
+  /// if we can fold the result.  If not, this returns null.
+  Value *SimplifyExtractElementInst(Value *Vec, Value *Idx,
+                                    const DataLayout &DL,
+                                    const TargetLibraryInfo *TLI = nullptr,
+                                    const DominatorTree *DT = nullptr,
+                                    AssumptionCache *AC = nullptr,
+                                    const Instruction *CxtI = nullptr);
+
   /// SimplifyTruncInst - Given operands for an TruncInst, see if we can fold
   /// the result.  If not, this returns null.
   Value *SimplifyTruncInst(Value *Op, Type *Ty, const DataLayout &DL,

Modified: llvm/trunk/include/llvm/Analysis/VectorUtils.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/VectorUtils.h?rev=242008&r1=242007&r2=242008&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Analysis/VectorUtils.h (original)
+++ llvm/trunk/include/llvm/Analysis/VectorUtils.h Sun Jul 12 20:15:53 2015
@@ -74,6 +74,11 @@ Value *getUniqueCastUse(Value *Ptr, Loop
 /// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
 Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp);
 
+/// \brief Given a vector and an element number, see if the scalar value is
+/// already available, e.g. if it was inserted into the vector.  Returns undef
+/// for an out-of-range \p EltNo and null if the element cannot be determined.
+Value *findScalarElement(Value *V, unsigned EltNo);
+
 } // llvm namespace
 
 #endif

Modified: llvm/trunk/lib/Analysis/InstructionSimplify.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/InstructionSimplify.cpp?rev=242008&r1=242007&r2=242008&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/InstructionSimplify.cpp (original)
+++ llvm/trunk/lib/Analysis/InstructionSimplify.cpp Sun Jul 12 20:15:53 2015
@@ -24,6 +24,7 @@
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
@@ -3555,6 +3556,47 @@ Value *llvm::SimplifyExtractValueInst(Va
                                     RecursionLimit);
 }
 
+/// SimplifyExtractElementInst - Given operands for an ExtractElementInst, see
+/// if we can fold the result.  If not, this returns null.
+static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const Query &,
+                                         unsigned) {
+  if (auto *CVec = dyn_cast<Constant>(Vec)) {
+    // Fold outright when possible; if that fails (e.g. a constant-expression
+    // index), fall through so the splat rule below can still apply.
+    if (auto *CIdx = dyn_cast<Constant>(Idx))
+      if (Constant *C = ConstantFoldExtractElementInstruction(CVec, CIdx))
+        return C;
+
+    // The index is not relevant if our vector is a splat.
+    if (auto *Splat = CVec->getSplatValue())
+      return Splat;
+
+    if (isa<UndefValue>(Vec))
+      return UndefValue::get(Vec->getType()->getVectorElementType());
+  }
+
+  // If extracting a specified index from the vector, see if we can recursively
+  // find a previously computed scalar that was inserted into the vector.
+  if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) {
+    // An out-of-range index yields undef.  Use getLimitedValue(): it saturates
+    // where getZExtValue() would assert on an index wider than 64 bits.
+    if (IdxC->getLimitedValue() >= Vec->getType()->getVectorNumElements())
+      return UndefValue::get(Vec->getType()->getVectorElementType());
+
+    if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue()))
+      return Elt;
+  }
+
+  return nullptr;
+}
+
+Value *llvm::SimplifyExtractElementInst(
+    Value *Vec, Value *Idx, const DataLayout &DL, const TargetLibraryInfo *TLI,
+    const DominatorTree *DT, AssumptionCache *AC, const Instruction *CxtI) {
+  return ::SimplifyExtractElementInst(Vec, Idx, Query(DL, TLI, DT, AC, CxtI),
+                                      RecursionLimit);
+}
+
 /// SimplifyPHINode - See if we can fold the given phi.  If not, returns null.
 static Value *SimplifyPHINode(PHINode *PN, const Query &Q) {
   // If all of the PHI's incoming values are the same then replace the PHI node
@@ -3970,6 +4012,12 @@ Value *llvm::SimplifyInstruction(Instruc
                                       EVI->getIndices(), DL, TLI, DT, AC, I);
     break;
   }
+  case Instruction::ExtractElement: {
+    auto *EEI = cast<ExtractElementInst>(I);
+    Result = SimplifyExtractElementInst(
+        EEI->getVectorOperand(), EEI->getIndexOperand(), DL, TLI, DT, AC, I);
+    break;
+  }
   case Instruction::PHI:
     Result = SimplifyPHINode(cast<PHINode>(I), Query(DL, TLI, DT, AC, I));
     break;

Modified: llvm/trunk/lib/Analysis/VectorUtils.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/VectorUtils.cpp?rev=242008&r1=242007&r2=242008&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/VectorUtils.cpp (original)
+++ llvm/trunk/lib/Analysis/VectorUtils.cpp Sun Jul 12 20:15:53 2015
@@ -357,3 +357,79 @@ llvm::Value *llvm::getStrideFromPointer(
 
   return Stride;
 }
+
+/// \brief Given a vector and an element number, see if the scalar value is
+/// already around as a register, for example if it were inserted then extracted
+/// from the vector.  Returns undef for an out-of-range \p EltNo, and null when
+/// the element cannot be determined.
+llvm::Value *llvm::findScalarElement(llvm::Value *V, unsigned EltNo) {
+  assert(V->getType()->isVectorTy() && "Not looking at a vector?");
+  VectorType *VTy = cast<VectorType>(V->getType());
+  unsigned Width = VTy->getNumElements();
+  if (EltNo >= Width)  // Out of range access.
+    return UndefValue::get(VTy->getElementType());
+
+  if (Constant *C = dyn_cast<Constant>(V))
+    return C->getAggregateElement(EltNo);
+
+  if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) {
+    // If this is an insert to a variable element, we don't know what it is.
+    auto *IIIdx = dyn_cast<ConstantInt>(III->getOperand(2));
+    if (!IIIdx)
+      return nullptr;
+
+    // Give up on an out-of-range insert index.  getLimitedValue() saturates
+    // where getZExtValue() would assert on an index wider than 64 bits.
+    uint64_t IIElt = IIIdx->getLimitedValue();
+    if (IIElt >= Width)
+      return nullptr;
+
+    // If this is an insert to the element we are looking for, return the
+    // inserted value.
+    if (EltNo == IIElt)
+      return III->getOperand(1);
+
+    // Unreachable blocks may contain an insertelement that uses itself as its
+    // vector operand; don't recurse forever on it.
+    if (III->getOperand(0) == III)
+      return nullptr;
+
+    // Otherwise, the insertelement doesn't modify the value, recurse on its
+    // vector input.
+    return findScalarElement(III->getOperand(0), EltNo);
+  }
+
+  if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) {
+    unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
+    int InEl = SVI->getMaskValue(EltNo);
+    if (InEl < 0)
+      return UndefValue::get(VTy->getElementType());
+    Value *Src = SVI->getOperand(0);
+    if (InEl >= (int)LHSWidth) {
+      Src = SVI->getOperand(1);
+      InEl -= LHSWidth;
+    }
+    // A shuffle can use itself as an operand in unreachable code; don't
+    // recurse forever on it.
+    if (Src == SVI)
+      return nullptr;
+    return findScalarElement(Src, InEl);
+  }
+
+  // Extract a value from a vector add operation with a constant zero.
+  Value *Val = nullptr;
+  Constant *Con = nullptr;
+  if (match(V,
+            llvm::PatternMatch::m_Add(llvm::PatternMatch::m_Value(Val),
+                                      llvm::PatternMatch::m_Constant(Con)))) {
+    // getAggregateElement() returns null when the lane is unknown (e.g. for a
+    // constant expression), so check it before calling isNullValue().  Also
+    // skip an add that uses itself, which can happen in unreachable code.
+    Constant *Elt = Con->getAggregateElement(EltNo);
+    if (Elt && Elt->isNullValue() && Val != V)
+      return findScalarElement(Val, EltNo);
+  }
+
+  // Otherwise, we don't know.
+  return nullptr;
+}

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineVectorOps.cpp?rev=242008&r1=242007&r2=242008&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineVectorOps.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineVectorOps.cpp Sun Jul 12 20:15:53 2015
@@ -14,6 +14,8 @@
 
 #include "InstCombineInternal.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/PatternMatch.h"
 using namespace llvm;
 using namespace PatternMatch;
@@ -60,56 +62,6 @@ static bool CheapToScalarize(Value *V, b
   return false;
 }
 
-/// FindScalarElement - Given a vector and an element number, see if the scalar
-/// value is already around as a register, for example if it were inserted then
-/// extracted from the vector.
-static Value *FindScalarElement(Value *V, unsigned EltNo) {
-  assert(V->getType()->isVectorTy() && "Not looking at a vector?");
-  VectorType *VTy = cast<VectorType>(V->getType());
-  unsigned Width = VTy->getNumElements();
-  if (EltNo >= Width)  // Out of range access.
-    return UndefValue::get(VTy->getElementType());
-
-  if (Constant *C = dyn_cast<Constant>(V))
-    return C->getAggregateElement(EltNo);
-
-  if (InsertElementInst *III = dyn_cast<InsertElementInst>(V)) {
-    // If this is an insert to a variable element, we don't know what it is.
-    if (!isa<ConstantInt>(III->getOperand(2)))
-      return nullptr;
-    unsigned IIElt = cast<ConstantInt>(III->getOperand(2))->getZExtValue();
-
-    // If this is an insert to the element we are looking for, return the
-    // inserted value.
-    if (EltNo == IIElt)
-      return III->getOperand(1);
-
-    // Otherwise, the insertelement doesn't modify the value, recurse on its
-    // vector input.
-    return FindScalarElement(III->getOperand(0), EltNo);
-  }
-
-  if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(V)) {
-    unsigned LHSWidth = SVI->getOperand(0)->getType()->getVectorNumElements();
-    int InEl = SVI->getMaskValue(EltNo);
-    if (InEl < 0)
-      return UndefValue::get(VTy->getElementType());
-    if (InEl < (int)LHSWidth)
-      return FindScalarElement(SVI->getOperand(0), InEl);
-    return FindScalarElement(SVI->getOperand(1), InEl - LHSWidth);
-  }
-
-  // Extract a value from a vector add operation with a constant zero.
-  Value *Val = nullptr; Constant *Con = nullptr;
-  if (match(V, m_Add(m_Value(Val), m_Constant(Con)))) {
-    if (Con->getAggregateElement(EltNo)->isNullValue())
-      return FindScalarElement(Val, EltNo);
-  }
-
-  // Otherwise, we don't know.
-  return nullptr;
-}
-
 // If we have a PHI node with a vector type that has only 2 uses: feed
 // itself and be an operand of extractelement at a constant location,
 // try to replace the PHI of the vector type with a PHI of a scalar type.
@@ -178,6 +130,10 @@ Instruction *InstCombiner::scalarizePHI(
 }
 
 Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
+  if (Value *V = SimplifyExtractElementInst(
+          EI.getVectorOperand(), EI.getIndexOperand(), DL, TLI, DT, AC))
+    return ReplaceInstUsesWith(EI, V);
+
   // If vector val is constant with all elements the same, replace EI with
   // that element.  We handle a known element # below.
   if (Constant *C = dyn_cast<Constant>(EI.getOperand(0)))
@@ -190,10 +146,8 @@ Instruction *InstCombiner::visitExtractE
     unsigned IndexVal = IdxC->getZExtValue();
     unsigned VectorWidth = EI.getVectorOperandType()->getNumElements();
 
-    // If this is extracting an invalid index, turn this into undef, to avoid
-    // crashing the code below.
-    if (IndexVal >= VectorWidth)
-      return ReplaceInstUsesWith(EI, UndefValue::get(EI.getType()));
+    // InstSimplify handles cases where the index is invalid.
+    assert(IndexVal < VectorWidth);
 
     // This instruction only demands the single element from the input vector.
     // If the input vector has a single use, simplify it based on this use
@@ -209,16 +163,13 @@ Instruction *InstCombiner::visitExtractE
       }
     }
 
-    if (Value *Elt = FindScalarElement(EI.getOperand(0), IndexVal))
-      return ReplaceInstUsesWith(EI, Elt);
-
     // If the this extractelement is directly using a bitcast from a vector of
     // the same number of elements, see if we can find the source element from
     // it.  In this case, we will end up needing to bitcast the scalars.
     if (BitCastInst *BCI = dyn_cast<BitCastInst>(EI.getOperand(0))) {
       if (VectorType *VT = dyn_cast<VectorType>(BCI->getOperand(0)->getType()))
         if (VT->getNumElements() == VectorWidth)
-          if (Value *Elt = FindScalarElement(BCI->getOperand(0), IndexVal))
+          if (Value *Elt = findScalarElement(BCI->getOperand(0), IndexVal))
             return new BitCastInst(Elt, EI.getType());
     }
 

Modified: llvm/trunk/test/Transforms/InstSimplify/undef.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstSimplify/undef.ll?rev=242008&r1=242007&r2=242008&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstSimplify/undef.ll (original)
+++ llvm/trunk/test/Transforms/InstSimplify/undef.ll Sun Jul 12 20:15:53 2015
@@ -265,3 +265,17 @@ define i32 @test34(i32 %a) {
   %b = lshr i32 undef, 0
   ret i32 %b
 }
+
+; CHECK-LABEL: @test35
+; CHECK: ret i32 undef
+define i32 @test35(<4 x i32> %V) {
+  %b = extractelement <4 x i32> %V, i32 4
+  ret i32 %b
+}
+
+; CHECK-LABEL: @test36
+; CHECK: ret i32 undef
+define i32 @test36(i32 %V) {
+  %b = extractelement <4 x i32> undef, i32 %V
+  ret i32 %b
+}





More information about the llvm-commits mailing list