[llvm] 3297e9b - Clean up usages of asserting vector getters in Type

Christopher Tetreault via llvm-commits llvm-commits at lists.llvm.org
Mon Apr 13 12:30:02 PDT 2020


Author: Christopher Tetreault
Date: 2020-04-13T12:29:43-07:00
New Revision: 3297e9b7c3dee210279754aa31bbd8b747656426

URL: https://github.com/llvm/llvm-project/commit/3297e9b7c3dee210279754aa31bbd8b747656426
DIFF: https://github.com/llvm/llvm-project/commit/3297e9b7c3dee210279754aa31bbd8b747656426.diff

LOG: Clean up usages of asserting vector getters in Type

Summary:
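Remove usages of the asserting vector getters in Type
(Type::getVectorNumElements() and Type::getVectorElementType()) in
preparation for the VectorType refactor, replacing them with an explicit
cast<VectorType>(...) followed by getNumElements()/getElementType(). The
existence of these functions complicates the refactor while adding little
value.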

Reviewers: rriddle, sdesmalen, efriedma

Reviewed By: efriedma

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D77259

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 09b42c25a0da..8596666db872 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1913,8 +1913,8 @@ void InnerLoopVectorizer::widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc) {
 Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step,
                                           Instruction::BinaryOps BinOp) {
   // Create and check the types.
-  assert(Val->getType()->isVectorTy() && "Must be a vector");
-  int VLen = Val->getType()->getVectorNumElements();
+  auto *ValVTy = cast<VectorType>(Val->getType());
+  int VLen = ValVTy->getNumElements();
 
   Type *STy = Val->getType()->getScalarType();
   assert((STy->isIntegerTy() || STy->isFloatingPointTy()) &&
@@ -3318,13 +3318,14 @@ unsigned LoopVectorizationCostModel::getVectorIntrinsicCost(CallInst *CI,
 }
 
 static Type *smallestIntegerVectorType(Type *T1, Type *T2) {
-  auto *I1 = cast<IntegerType>(T1->getVectorElementType());
-  auto *I2 = cast<IntegerType>(T2->getVectorElementType());
+  auto *I1 = cast<IntegerType>(cast<VectorType>(T1)->getElementType());
+  auto *I2 = cast<IntegerType>(cast<VectorType>(T2)->getElementType());
   return I1->getBitWidth() < I2->getBitWidth() ? T1 : T2;
 }
+
 static Type *largestIntegerVectorType(Type *T1, Type *T2) {
-  auto *I1 = cast<IntegerType>(T1->getVectorElementType());
-  auto *I2 = cast<IntegerType>(T2->getVectorElementType());
+  auto *I1 = cast<IntegerType>(cast<VectorType>(T1)->getElementType());
+  auto *I2 = cast<IntegerType>(cast<VectorType>(T2)->getElementType());
   return I1->getBitWidth() > I2->getBitWidth() ? T1 : T2;
 }
 
@@ -3347,8 +3348,8 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() {
       Type *OriginalTy = I->getType();
       Type *ScalarTruncatedTy =
           IntegerType::get(OriginalTy->getContext(), KV.second);
-      Type *TruncatedTy = VectorType::get(ScalarTruncatedTy,
-                                          OriginalTy->getVectorNumElements());
+      Type *TruncatedTy = VectorType::get(
+          ScalarTruncatedTy, cast<VectorType>(OriginalTy)->getNumElements());
       if (TruncatedTy == OriginalTy)
         continue;
 
@@ -3398,10 +3399,12 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() {
           break;
         }
       } else if (auto *SI = dyn_cast<ShuffleVectorInst>(I)) {
-        auto Elements0 = SI->getOperand(0)->getType()->getVectorNumElements();
+        auto Elements0 =
+            cast<VectorType>(SI->getOperand(0)->getType())->getNumElements();
         auto *O0 = B.CreateZExtOrTrunc(
             SI->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements0));
-        auto Elements1 = SI->getOperand(1)->getType()->getVectorNumElements();
+        auto Elements1 =
+            cast<VectorType>(SI->getOperand(1)->getType())->getNumElements();
         auto *O1 = B.CreateZExtOrTrunc(
             SI->getOperand(1), VectorType::get(ScalarTruncatedTy, Elements1));
 
@@ -3410,13 +3413,15 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() {
         // Don't do anything with the operands, just extend the result.
         continue;
       } else if (auto *IE = dyn_cast<InsertElementInst>(I)) {
-        auto Elements = IE->getOperand(0)->getType()->getVectorNumElements();
+        auto Elements =
+            cast<VectorType>(IE->getOperand(0)->getType())->getNumElements();
         auto *O0 = B.CreateZExtOrTrunc(
             IE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements));
         auto *O1 = B.CreateZExtOrTrunc(IE->getOperand(1), ScalarTruncatedTy);
         NewI = B.CreateInsertElement(O0, O1, IE->getOperand(2));
       } else if (auto *EE = dyn_cast<ExtractElementInst>(I)) {
-        auto Elements = EE->getOperand(0)->getType()->getVectorNumElements();
+        auto Elements =
+            cast<VectorType>(EE->getOperand(0)->getType())->getNumElements();
         auto *O0 = B.CreateZExtOrTrunc(
             EE->getOperand(0), VectorType::get(ScalarTruncatedTy, Elements));
         NewI = B.CreateExtractElement(O0, EE->getOperand(2));

diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index b80e56c5b4d2..e14f8ea10e77 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -285,7 +285,7 @@ static bool isCommutative(Instruction *I) {
 static Optional<TargetTransformInfo::ShuffleKind>
 isShuffle(ArrayRef<Value *> VL) {
   auto *EI0 = cast<ExtractElementInst>(VL[0]);
-  unsigned Size = EI0->getVectorOperandType()->getVectorNumElements();
+  unsigned Size = EI0->getVectorOperandType()->getNumElements();
   Value *Vec1 = nullptr;
   Value *Vec2 = nullptr;
   enum ShuffleMode { Unknown, Select, Permute };
@@ -294,7 +294,7 @@ isShuffle(ArrayRef<Value *> VL) {
     auto *EI = cast<ExtractElementInst>(VL[I]);
     auto *Vec = EI->getVectorOperand();
     // All vector operands must have the same number of vector elements.
-    if (Vec->getType()->getVectorNumElements() != Size)
+    if (cast<VectorType>(Vec->getType())->getNumElements() != Size)
       return None;
     auto *Idx = dyn_cast<ConstantInt>(EI->getIndexOperand());
     if (!Idx)
@@ -3182,7 +3182,7 @@ bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue,
     if (!LI || !LI->isSimple() || !LI->hasNUses(VL.size()))
       return false;
   } else {
-    NElts = Vec->getType()->getVectorNumElements();
+    NElts = cast<VectorType>(Vec->getType())->getNumElements();
   }
 
   if (NElts != VL.size())

diff  --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 5c73e414ab8b..9e7cbe8c345e 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -236,10 +236,10 @@ static bool foldExtractExtract(Instruction &I, const TargetTransformInfo &TTI) {
     // ShufMask = { 2, undef, undef, undef }
     uint64_t SplatIndex = ConvertToShuffle == Ext0 ? C0 : C1;
     uint64_t CheapExtIndex = ConvertToShuffle == Ext0 ? C1 : C0;
-    Type *VecTy = V0->getType();
+    auto *VecTy = cast<VectorType>(V0->getType());
     Type *I32Ty = IntegerType::getInt32Ty(I.getContext());
     UndefValue *Undef = UndefValue::get(I32Ty);
-    SmallVector<Constant *, 32> ShufMask(VecTy->getVectorNumElements(), Undef);
+    SmallVector<Constant *, 32> ShufMask(VecTy->getNumElements(), Undef);
     ShufMask[CheapExtIndex] = ConstantInt::get(I32Ty, SplatIndex);
     IRBuilder<> Builder(ConvertToShuffle);
 
@@ -272,15 +272,14 @@ static bool foldBitcastShuf(Instruction &I, const TargetTransformInfo &TTI) {
                                                     m_Mask(Mask))))))
     return false;
 
-  Type *DestTy = I.getType();
-  Type *SrcTy = V->getType();
-  if (!DestTy->isVectorTy() || I.getOperand(0)->getType() != SrcTy)
+  auto *DestTy = dyn_cast<VectorType>(I.getType());
+  auto *SrcTy = cast<VectorType>(V->getType());
+  if (!DestTy || I.getOperand(0)->getType() != SrcTy)
     return false;
 
   // TODO: Handle bitcast from narrow element type to wide element type.
-  assert(SrcTy->isVectorTy() && "Shuffle of non-vector type?");
-  unsigned DestNumElts = DestTy->getVectorNumElements();
-  unsigned SrcNumElts = SrcTy->getVectorNumElements();
+  unsigned DestNumElts = DestTy->getNumElements();
+  unsigned SrcNumElts = SrcTy->getNumElements();
   if (SrcNumElts > DestNumElts)
     return false;
 


        


More information about the llvm-commits mailing list