[llvm] 84aa6cf - [Transforms][SROA] Promote allocas with mem2reg for scalable types

Cullen Rhodes via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 1 03:34:49 PDT 2020


Author: Cullen Rhodes
Date: 2020-04-01T10:34:11Z
New Revision: 84aa6cf1a9fe7c2d1c35b27ba6fbf1ee36a09a71

URL: https://github.com/llvm/llvm-project/commit/84aa6cf1a9fe7c2d1c35b27ba6fbf1ee36a09a71
DIFF: https://github.com/llvm/llvm-project/commit/84aa6cf1a9fe7c2d1c35b27ba6fbf1ee36a09a71.diff

LOG: [Transforms][SROA] Promote allocas with mem2reg for scalable types

Summary:
Aggregate types containing scalable vectors aren't supported, and as far
as I can tell this pass is mostly concerned with optimisations on
aggregate types, so the majority of this pass isn't very useful for
scalable vectors.

This patch modifies SROA so that mem2reg is run on promotable allocas
with scalable types, but nothing else, such as slicing, is done.

The use of TypeSize in this pass has also been updated to be explicitly
fixed size. When invoking the following methods in DataLayout:

    * getTypeSizeInBits
    * getTypeStoreSize
    * getTypeStoreSizeInBits
    * getTypeAllocSize

we now call getFixedSize on the resultant TypeSize. This is quite an
extensive change, touching around 50 calls to these functions, and as
far as I'm aware it is also the first change of this kind (being
explicit about fixed vs scalable size), so feedback is welcome.
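
As a minimal before/after illustration of the pattern applied at those
call sites (the function names here are just for exposition, not from
the patch):

    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/Type.h"
    #include <cstdint>

    // Before: relied on the implicit TypeSize -> uint64_t conversion.
    uint64_t storeSizeBefore(const llvm::DataLayout &DL, llvm::Type *Ty) {
      return DL.getTypeStoreSize(Ty);
    }

    // After: explicitly request the fixed size; getFixedSize() asserts
    // that the type is not a scalable vector.
    uint64_t storeSizeAfter(const llvm::DataLayout &DL, llvm::Type *Ty) {
      return DL.getTypeStoreSize(Ty).getFixedSize();
    }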

A test is included containing IR with scalable vectors that this pass is
able to optimise.

Reviewed By: efriedma

Differential Revision: https://reviews.llvm.org/D76720

Added: 
    llvm/test/Transforms/SROA/scalable-vectors.ll

Modified: 
    llvm/lib/Transforms/Scalar/SROA.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index 0c51fd5ff423..05025747db8e 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -662,7 +662,8 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
 public:
   SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS)
       : PtrUseVisitor<SliceBuilder>(DL),
-        AllocSize(DL.getTypeAllocSize(AI.getAllocatedType())), AS(AS) {}
+        AllocSize(DL.getTypeAllocSize(AI.getAllocatedType()).getFixedSize()),
+        AS(AS) {}
 
 private:
   void markAsDead(Instruction &I) {
@@ -751,8 +752,10 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
           // For array or vector indices, scale the index by the size of the
           // type.
           APInt Index = OpC->getValue().sextOrTrunc(Offset.getBitWidth());
-          GEPOffset += Index * APInt(Offset.getBitWidth(),
-                                     DL.getTypeAllocSize(GTI.getIndexedType()));
+          GEPOffset +=
+              Index *
+              APInt(Offset.getBitWidth(),
+                    DL.getTypeAllocSize(GTI.getIndexedType()).getFixedSize());
         }
 
         // If this index has computed an intermediate pointer which is not
@@ -787,7 +790,7 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
         LI.getPointerAddressSpace() != DL.getAllocaAddrSpace())
       return PI.setAborted(&LI);
 
-    uint64_t Size = DL.getTypeStoreSize(LI.getType());
+    uint64_t Size = DL.getTypeStoreSize(LI.getType()).getFixedSize();
     return handleLoadOrStore(LI.getType(), LI, Offset, Size, LI.isVolatile());
   }
 
@@ -802,7 +805,7 @@ class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
         SI.getPointerAddressSpace() != DL.getAllocaAddrSpace())
       return PI.setAborted(&SI);
 
-    uint64_t Size = DL.getTypeStoreSize(ValOp->getType());
+    uint64_t Size = DL.getTypeStoreSize(ValOp->getType()).getFixedSize();
 
     // If this memory access can be shown to *statically* extend outside the
     // bounds of the allocation, it's behavior is undefined, so simply
@@ -1220,7 +1223,7 @@ static bool isSafePHIToSpeculate(PHINode &PN) {
       if (BBI->mayWriteToMemory())
         return false;
 
-    uint64_t Size = DL.getTypeStoreSize(LI->getType());
+    uint64_t Size = DL.getTypeStoreSize(LI->getType()).getFixedSize();
     MaxAlign = std::max(MaxAlign, MaybeAlign(LI->getAlignment()));
     MaxSize = MaxSize.ult(Size) ? APInt(APWidth, Size) : MaxSize;
     HaveLoad = true;
@@ -1478,7 +1481,8 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL,
   // extremely poorly defined currently. The long-term goal is to remove GEPing
   // over a vector from the IR completely.
   if (VectorType *VecTy = dyn_cast<VectorType>(Ty)) {
-    unsigned ElementSizeInBits = DL.getTypeSizeInBits(VecTy->getScalarType());
+    unsigned ElementSizeInBits =
+        DL.getTypeSizeInBits(VecTy->getScalarType()).getFixedSize();
     if (ElementSizeInBits % 8 != 0) {
       // GEPs over non-multiple of 8 size vector elements are invalid.
       return nullptr;
@@ -1495,7 +1499,8 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL,
 
   if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) {
     Type *ElementTy = ArrTy->getElementType();
-    APInt ElementSize(Offset.getBitWidth(), DL.getTypeAllocSize(ElementTy));
+    APInt ElementSize(Offset.getBitWidth(),
+                      DL.getTypeAllocSize(ElementTy).getFixedSize());
     APInt NumSkippedElements = Offset.sdiv(ElementSize);
     if (NumSkippedElements.ugt(ArrTy->getNumElements()))
       return nullptr;
@@ -1517,7 +1522,7 @@ static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL,
   unsigned Index = SL->getElementContainingOffset(StructOffset);
   Offset -= APInt(Offset.getBitWidth(), SL->getElementOffset(Index));
   Type *ElementTy = STy->getElementType(Index);
-  if (Offset.uge(DL.getTypeAllocSize(ElementTy)))
+  if (Offset.uge(DL.getTypeAllocSize(ElementTy).getFixedSize()))
     return nullptr; // The offset points into alignment padding.
 
   Indices.push_back(IRB.getInt32(Index));
@@ -1549,7 +1554,8 @@ static Value *getNaturalGEPWithOffset(IRBuilderTy &IRB, const DataLayout &DL,
   Type *ElementTy = Ty->getElementType();
   if (!ElementTy->isSized())
     return nullptr; // We can't GEP through an unsized element.
-  APInt ElementSize(Offset.getBitWidth(), DL.getTypeAllocSize(ElementTy));
+  APInt ElementSize(Offset.getBitWidth(),
+                    DL.getTypeAllocSize(ElementTy).getFixedSize());
   if (ElementSize == 0)
     return nullptr; // Zero-length arrays can't help us build a natural GEP.
   APInt NumSkippedElements = Offset.sdiv(ElementSize);
@@ -1716,7 +1722,8 @@ static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) {
     return false;
   }
 
-  if (DL.getTypeSizeInBits(NewTy) != DL.getTypeSizeInBits(OldTy))
+  if (DL.getTypeSizeInBits(NewTy).getFixedSize() !=
+      DL.getTypeSizeInBits(OldTy).getFixedSize())
     return false;
   if (!NewTy->isSingleValueType() || !OldTy->isSingleValueType())
     return false;
@@ -1889,7 +1896,8 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
       // Return if bitcast to vectors is different for total size in bits.
       if (!CandidateTys.empty()) {
         VectorType *V = CandidateTys[0];
-        if (DL.getTypeSizeInBits(VTy) != DL.getTypeSizeInBits(V)) {
+        if (DL.getTypeSizeInBits(VTy).getFixedSize() !=
+            DL.getTypeSizeInBits(V).getFixedSize()) {
           CandidateTys.clear();
           return;
         }
@@ -1935,7 +1943,8 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
     // they're all integer vectors. We sort by ascending number of elements.
     auto RankVectorTypes = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
       (void)DL;
-      assert(DL.getTypeSizeInBits(RHSTy) == DL.getTypeSizeInBits(LHSTy) &&
+      assert(DL.getTypeSizeInBits(RHSTy).getFixedSize() ==
+                 DL.getTypeSizeInBits(LHSTy).getFixedSize() &&
              "Cannot have vector types of 
diff erent sizes!");
       assert(RHSTy->getElementType()->isIntegerTy() &&
              "All non-integer types eliminated!");
@@ -1963,13 +1972,14 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
 
   // Try each vector type, and return the one which works.
   auto CheckVectorTypeForPromotion = [&](VectorType *VTy) {
-    uint64_t ElementSize = DL.getTypeSizeInBits(VTy->getElementType());
+    uint64_t ElementSize =
+        DL.getTypeSizeInBits(VTy->getElementType()).getFixedSize();
 
     // While the definition of LLVM vectors is bitpacked, we don't support sizes
     // that aren't byte sized.
     if (ElementSize % 8)
       return false;
-    assert((DL.getTypeSizeInBits(VTy) % 8) == 0 &&
+    assert((DL.getTypeSizeInBits(VTy).getFixedSize() % 8) == 0 &&
            "vector size not a multiple of element size?");
     ElementSize /= 8;
 
@@ -1999,7 +2009,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,
                                             Type *AllocaTy,
                                             const DataLayout &DL,
                                             bool &WholeAllocaOp) {
-  uint64_t Size = DL.getTypeStoreSize(AllocaTy);
+  uint64_t Size = DL.getTypeStoreSize(AllocaTy).getFixedSize();
 
   uint64_t RelBegin = S.beginOffset() - AllocBeginOffset;
   uint64_t RelEnd = S.endOffset() - AllocBeginOffset;
@@ -2015,7 +2025,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,
     if (LI->isVolatile())
       return false;
     // We can't handle loads that extend past the allocated memory.
-    if (DL.getTypeStoreSize(LI->getType()) > Size)
+    if (DL.getTypeStoreSize(LI->getType()).getFixedSize() > Size)
       return false;
     // So far, AllocaSliceRewriter does not support widening split slice tails
     // in rewriteIntegerLoad.
@@ -2027,7 +2037,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,
     if (!isa<VectorType>(LI->getType()) && RelBegin == 0 && RelEnd == Size)
       WholeAllocaOp = true;
     if (IntegerType *ITy = dyn_cast<IntegerType>(LI->getType())) {
-      if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy))
+      if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedSize())
         return false;
     } else if (RelBegin != 0 || RelEnd != Size ||
                !canConvertValue(DL, AllocaTy, LI->getType())) {
@@ -2040,7 +2050,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,
     if (SI->isVolatile())
       return false;
     // We can't handle stores that extend past the allocated memory.
-    if (DL.getTypeStoreSize(ValueTy) > Size)
+    if (DL.getTypeStoreSize(ValueTy).getFixedSize() > Size)
       return false;
     // So far, AllocaSliceRewriter does not support widening split slice tails
     // in rewriteIntegerStore.
@@ -2052,7 +2062,7 @@ static bool isIntegerWideningViableForSlice(const Slice &S,
     if (!isa<VectorType>(ValueTy) && RelBegin == 0 && RelEnd == Size)
       WholeAllocaOp = true;
     if (IntegerType *ITy = dyn_cast<IntegerType>(ValueTy)) {
-      if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy))
+      if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy).getFixedSize())
         return false;
     } else if (RelBegin != 0 || RelEnd != Size ||
                !canConvertValue(DL, ValueTy, AllocaTy)) {
@@ -2083,13 +2093,13 @@ static bool isIntegerWideningViableForSlice(const Slice &S,
 /// promote the resulting alloca.
 static bool isIntegerWideningViable(Partition &P, Type *AllocaTy,
                                     const DataLayout &DL) {
-  uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy);
+  uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy).getFixedSize();
   // Don't create integer types larger than the maximum bitwidth.
   if (SizeInBits > IntegerType::MAX_INT_BITS)
     return false;
 
   // Don't try to handle allocas with bit-padding.
-  if (SizeInBits != DL.getTypeStoreSizeInBits(AllocaTy))
+  if (SizeInBits != DL.getTypeStoreSizeInBits(AllocaTy).getFixedSize())
     return false;
 
   // We need to ensure that an integer type with the appropriate bitwidth can
@@ -2128,11 +2138,13 @@ static Value *extractInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *V,
                              const Twine &Name) {
   LLVM_DEBUG(dbgs() << "       start: " << *V << "\n");
   IntegerType *IntTy = cast<IntegerType>(V->getType());
-  assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) &&
+  assert(DL.getTypeStoreSize(Ty).getFixedSize() + Offset <=
+             DL.getTypeStoreSize(IntTy).getFixedSize() &&
          "Element extends past full value");
   uint64_t ShAmt = 8 * Offset;
   if (DL.isBigEndian())
-    ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset);
+    ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedSize() -
+                 DL.getTypeStoreSize(Ty).getFixedSize() - Offset);
   if (ShAmt) {
     V = IRB.CreateLShr(V, ShAmt, Name + ".shift");
     LLVM_DEBUG(dbgs() << "     shifted: " << *V << "\n");
@@ -2157,11 +2169,13 @@ static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old,
     V = IRB.CreateZExt(V, IntTy, Name + ".ext");
     LLVM_DEBUG(dbgs() << "    extended: " << *V << "\n");
   }
-  assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) &&
+  assert(DL.getTypeStoreSize(Ty).getFixedSize() + Offset <=
+             DL.getTypeStoreSize(IntTy).getFixedSize() &&
          "Element store outside of alloca store");
   uint64_t ShAmt = 8 * Offset;
   if (DL.isBigEndian())
-    ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset);
+    ShAmt = 8 * (DL.getTypeStoreSize(IntTy).getFixedSize() -
+                 DL.getTypeStoreSize(Ty).getFixedSize() - Offset);
   if (ShAmt) {
     V = IRB.CreateShl(V, ShAmt, Name + ".shift");
     LLVM_DEBUG(dbgs() << "     shifted: " << *V << "\n");
@@ -2324,18 +2338,20 @@ class llvm::sroa::AllocaSliceRewriter
         NewAllocaBeginOffset(NewAllocaBeginOffset),
         NewAllocaEndOffset(NewAllocaEndOffset),
         NewAllocaTy(NewAI.getAllocatedType()),
-        IntTy(IsIntegerPromotable
-                  ? Type::getIntNTy(
-                        NewAI.getContext(),
-                        DL.getTypeSizeInBits(NewAI.getAllocatedType()))
-                  : nullptr),
+        IntTy(
+            IsIntegerPromotable
+                ? Type::getIntNTy(NewAI.getContext(),
+                                  DL.getTypeSizeInBits(NewAI.getAllocatedType())
+                                      .getFixedSize())
+                : nullptr),
         VecTy(PromotableVecTy),
         ElementTy(VecTy ? VecTy->getElementType() : nullptr),
-        ElementSize(VecTy ? DL.getTypeSizeInBits(ElementTy) / 8 : 0),
+        ElementSize(VecTy ? DL.getTypeSizeInBits(ElementTy).getFixedSize() / 8
+                          : 0),
         PHIUsers(PHIUsers), SelectUsers(SelectUsers),
         IRB(NewAI.getContext(), ConstantFolder()) {
     if (VecTy) {
-      assert((DL.getTypeSizeInBits(ElementTy) % 8) == 0 &&
+      assert((DL.getTypeSizeInBits(ElementTy).getFixedSize() % 8) == 0 &&
              "Only multiple-of-8 sized vector elements are viable");
       ++NumVectorized;
     }
@@ -2500,7 +2516,8 @@ class llvm::sroa::AllocaSliceRewriter
 
     Type *TargetTy = IsSplit ? Type::getIntNTy(LI.getContext(), SliceSize * 8)
                              : LI.getType();
-    const bool IsLoadPastEnd = DL.getTypeStoreSize(TargetTy) > SliceSize;
+    const bool IsLoadPastEnd =
+        DL.getTypeStoreSize(TargetTy).getFixedSize() > SliceSize;
     bool IsPtrAdjusted = false;
     Value *V;
     if (VecTy) {
@@ -2568,7 +2585,7 @@ class llvm::sroa::AllocaSliceRewriter
       assert(!LI.isVolatile());
       assert(LI.getType()->isIntegerTy() &&
              "Only integer type loads and stores are split");
-      assert(SliceSize < DL.getTypeStoreSize(LI.getType()) &&
+      assert(SliceSize < DL.getTypeStoreSize(LI.getType()).getFixedSize() &&
              "Split load isn't smaller than original load");
       assert(DL.typeSizeEqualsStoreSize(LI.getType()) &&
              "Non-byte-multiple bit width");
@@ -2626,7 +2643,8 @@ class llvm::sroa::AllocaSliceRewriter
   bool rewriteIntegerStore(Value *V, StoreInst &SI, AAMDNodes AATags) {
     assert(IntTy && "We cannot extract an integer from the alloca");
     assert(!SI.isVolatile());
-    if (DL.getTypeSizeInBits(V->getType()) != IntTy->getBitWidth()) {
+    if (DL.getTypeSizeInBits(V->getType()).getFixedSize() !=
+        IntTy->getBitWidth()) {
       Value *Old = IRB.CreateAlignedLoad(NewAI.getAllocatedType(), &NewAI,
                                          NewAI.getAlign(), "oldload");
       Old = convertValue(DL, IRB, Old, IntTy);
@@ -2661,7 +2679,7 @@ class llvm::sroa::AllocaSliceRewriter
       if (AllocaInst *AI = dyn_cast<AllocaInst>(V->stripInBoundsOffsets()))
         Pass.PostPromotionWorklist.insert(AI);
 
-    if (SliceSize < DL.getTypeStoreSize(V->getType())) {
+    if (SliceSize < DL.getTypeStoreSize(V->getType()).getFixedSize()) {
       assert(!SI.isVolatile());
       assert(V->getType()->isIntegerTy() &&
              "Only integer type loads and stores are split");
@@ -2677,7 +2695,8 @@ class llvm::sroa::AllocaSliceRewriter
     if (IntTy && V->getType()->isIntegerTy())
       return rewriteIntegerStore(V, SI, AATags);
 
-    const bool IsStorePastEnd = DL.getTypeStoreSize(V->getType()) > SliceSize;
+    const bool IsStorePastEnd =
+        DL.getTypeStoreSize(V->getType()).getFixedSize() > SliceSize;
     StoreInst *NewSI;
     if (NewBeginOffset == NewAllocaBeginOffset &&
         NewEndOffset == NewAllocaEndOffset &&
@@ -2792,7 +2811,7 @@ class llvm::sroa::AllocaSliceRewriter
       auto *Int8Ty = IntegerType::getInt8Ty(NewAI.getContext());
       auto *SrcTy = VectorType::get(Int8Ty, Len);
       return canConvertValue(DL, SrcTy, AllocaTy) &&
-        DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy));
+             DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy).getFixedSize());
     }();
 
     // If this doesn't map cleanly onto the alloca type, and that type isn't
@@ -2826,8 +2845,8 @@ class llvm::sroa::AllocaSliceRewriter
       unsigned NumElements = EndIndex - BeginIndex;
       assert(NumElements <= VecTy->getNumElements() && "Too many elements!");
 
-      Value *Splat =
-          getIntegerSplat(II.getValue(), DL.getTypeSizeInBits(ElementTy) / 8);
+      Value *Splat = getIntegerSplat(
+          II.getValue(), DL.getTypeSizeInBits(ElementTy).getFixedSize() / 8);
       Splat = convertValue(DL, IRB, Splat, ElementTy);
       if (NumElements > 1)
         Splat = getVectorSplat(Splat, NumElements);
@@ -2860,7 +2879,8 @@ class llvm::sroa::AllocaSliceRewriter
       assert(NewBeginOffset == NewAllocaBeginOffset);
       assert(NewEndOffset == NewAllocaEndOffset);
 
-      V = getIntegerSplat(II.getValue(), DL.getTypeSizeInBits(ScalarTy) / 8);
+      V = getIntegerSplat(II.getValue(),
+                          DL.getTypeSizeInBits(ScalarTy).getFixedSize() / 8);
       if (VectorType *AllocaVecTy = dyn_cast<VectorType>(AllocaTy))
         V = getVectorSplat(V, AllocaVecTy->getNumElements());
 
@@ -2923,7 +2943,8 @@ class llvm::sroa::AllocaSliceRewriter
     bool EmitMemCpy =
         !VecTy && !IntTy &&
         (BeginOffset > NewAllocaBeginOffset || EndOffset < NewAllocaEndOffset ||
-         SliceSize != DL.getTypeStoreSize(NewAI.getAllocatedType()) ||
+         SliceSize !=
+             DL.getTypeStoreSize(NewAI.getAllocatedType()).getFixedSize() ||
          !NewAI.getAllocatedType()->isSingleValueType());
 
     // If we're just going to emit a memcpy, the alloca hasn't changed, and the
@@ -3469,8 +3490,8 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) {
   if (Ty->isSingleValueType())
     return Ty;
 
-  uint64_t AllocSize = DL.getTypeAllocSize(Ty);
-  uint64_t TypeSize = DL.getTypeSizeInBits(Ty);
+  uint64_t AllocSize = DL.getTypeAllocSize(Ty).getFixedSize();
+  uint64_t TypeSize = DL.getTypeSizeInBits(Ty).getFixedSize();
 
   Type *InnerTy;
   if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) {
@@ -3483,8 +3504,8 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) {
     return Ty;
   }
 
-  if (AllocSize > DL.getTypeAllocSize(InnerTy) ||
-      TypeSize > DL.getTypeSizeInBits(InnerTy))
+  if (AllocSize > DL.getTypeAllocSize(InnerTy).getFixedSize() ||
+      TypeSize > DL.getTypeSizeInBits(InnerTy).getFixedSize())
     return Ty;
 
   return stripAggregateTypeWrapping(DL, InnerTy);
@@ -3505,15 +3526,15 @@ static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) {
 /// return a type if necessary.
 static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset,
                               uint64_t Size) {
-  if (Offset == 0 && DL.getTypeAllocSize(Ty) == Size)
+  if (Offset == 0 && DL.getTypeAllocSize(Ty).getFixedSize() == Size)
     return stripAggregateTypeWrapping(DL, Ty);
-  if (Offset > DL.getTypeAllocSize(Ty) ||
-      (DL.getTypeAllocSize(Ty) - Offset) < Size)
+  if (Offset > DL.getTypeAllocSize(Ty).getFixedSize() ||
+      (DL.getTypeAllocSize(Ty).getFixedSize() - Offset) < Size)
     return nullptr;
 
   if (SequentialType *SeqTy = dyn_cast<SequentialType>(Ty)) {
     Type *ElementTy = SeqTy->getElementType();
-    uint64_t ElementSize = DL.getTypeAllocSize(ElementTy);
+    uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedSize();
     uint64_t NumSkippedElements = Offset / ElementSize;
     if (NumSkippedElements >= SeqTy->getNumElements())
       return nullptr;
@@ -3553,7 +3574,7 @@ static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset,
   Offset -= SL->getElementOffset(Index);
 
   Type *ElementTy = STy->getElementType(Index);
-  uint64_t ElementSize = DL.getTypeAllocSize(ElementTy);
+  uint64_t ElementSize = DL.getTypeAllocSize(ElementTy).getFixedSize();
   if (Offset >= ElementSize)
     return nullptr; // The offset points into alignment padding.
 
@@ -4121,7 +4142,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
   Type *SliceTy = nullptr;
   const DataLayout &DL = AI.getModule()->getDataLayout();
   if (Type *CommonUseTy = findCommonType(P.begin(), P.end(), P.endOffset()))
-    if (DL.getTypeAllocSize(CommonUseTy) >= P.size())
+    if (DL.getTypeAllocSize(CommonUseTy).getFixedSize() >= P.size())
       SliceTy = CommonUseTy;
   if (!SliceTy)
     if (Type *TypePartitionTy = getTypePartition(DL, AI.getAllocatedType(),
@@ -4133,7 +4154,7 @@ AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS,
     SliceTy = Type::getIntNTy(*C, P.size() * 8);
   if (!SliceTy)
     SliceTy = ArrayType::get(Type::getInt8Ty(*C), P.size());
-  assert(DL.getTypeAllocSize(SliceTy) >= P.size());
+  assert(DL.getTypeAllocSize(SliceTy).getFixedSize() >= P.size());
 
   bool IsIntegerPromotable = isIntegerWideningViable(P, SliceTy, DL);
 
@@ -4274,7 +4295,8 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
   // to be rewritten into a partition.
   bool IsSorted = true;
 
-  uint64_t AllocaSize = DL.getTypeAllocSize(AI.getAllocatedType());
+  uint64_t AllocaSize =
+      DL.getTypeAllocSize(AI.getAllocatedType()).getFixedSize();
   const uint64_t MaxBitVectorSize = 1024;
   if (AllocaSize <= MaxBitVectorSize) {
     // If a byte boundary is included in any load or store, a slice starting or
@@ -4338,7 +4360,8 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
       Changed = true;
       if (NewAI != &AI) {
         uint64_t SizeOfByte = 8;
-        uint64_t AllocaSize = DL.getTypeSizeInBits(NewAI->getAllocatedType());
+        uint64_t AllocaSize =
+            DL.getTypeSizeInBits(NewAI->getAllocatedType()).getFixedSize();
         // Don't include any padding.
         uint64_t Size = std::min(AllocaSize, P.size() * SizeOfByte);
         Fragments.push_back(Fragment(NewAI, P.beginOffset() * SizeOfByte, Size));
@@ -4358,7 +4381,8 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
     auto *Expr = DbgDeclares.front()->getExpression();
     auto VarSize = Var->getSizeInBits();
     DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false);
-    uint64_t AllocaSize = DL.getTypeSizeInBits(AI.getAllocatedType());
+    uint64_t AllocaSize =
+        DL.getTypeSizeInBits(AI.getAllocatedType()).getFixedSize();
     for (auto Fragment : Fragments) {
       // Create a fragment expression describing the new partition or reuse AI's
       // expression if there is only one partition.
@@ -4446,8 +4470,10 @@ bool SROA::runOnAlloca(AllocaInst &AI) {
   const DataLayout &DL = AI.getModule()->getDataLayout();
 
   // Skip alloca forms that this analysis can't handle.
-  if (AI.isArrayAllocation() || !AI.getAllocatedType()->isSized() ||
-      DL.getTypeAllocSize(AI.getAllocatedType()) == 0)
+  auto *AT = AI.getAllocatedType();
+  if (AI.isArrayAllocation() || !AT->isSized() ||
+      (isa<VectorType>(AT) && cast<VectorType>(AT)->isScalable()) ||
+      DL.getTypeAllocSize(AT).getFixedSize() == 0)
     return false;
 
   bool Changed = false;
@@ -4567,8 +4593,15 @@ PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT,
   BasicBlock &EntryBB = F.getEntryBlock();
   for (BasicBlock::iterator I = EntryBB.begin(), E = std::prev(EntryBB.end());
        I != E; ++I) {
-    if (AllocaInst *AI = dyn_cast<AllocaInst>(I))
-      Worklist.insert(AI);
+    if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) {
+      if (isa<VectorType>(AI->getAllocatedType()) &&
+          cast<VectorType>(AI->getAllocatedType())->isScalable()) {
+        if (isAllocaPromotable(AI))
+          PromotableAllocas.push_back(AI);
+      } else {
+        Worklist.insert(AI);
+      }
+    }
   }
 
   bool Changed = false;

diff --git a/llvm/test/Transforms/SROA/scalable-vectors.ll b/llvm/test/Transforms/SROA/scalable-vectors.ll
new file mode 100644
index 000000000000..bda54e25b945
--- /dev/null
+++ b/llvm/test/Transforms/SROA/scalable-vectors.ll
@@ -0,0 +1,36 @@
+; RUN: opt < %s -sroa -S | FileCheck %s
+; RUN: opt < %s -passes=sroa -S | FileCheck %s
+
+; This test checks that SROA runs mem2reg on scalable vectors.
+
+define <vscale x 16 x i1> @alloca_nxv16i1(<vscale x 16 x i1> %pg) {
+; CHECK-LABEL: alloca_nxv16i1
+; CHECK-NEXT: ret <vscale x 16 x i1> %pg
+  %pg.addr = alloca <vscale x 16 x i1>
+  store <vscale x 16 x i1> %pg, <vscale x 16 x i1>* %pg.addr
+  %1 = load <vscale x 16 x i1>, <vscale x 16 x i1>* %pg.addr
+  ret <vscale x 16 x i1> %1
+}
+
+define <vscale x 16 x i8> @alloca_nxv16i8(<vscale x 16 x i8> %vec) {
+; CHECK-LABEL: alloca_nxv16i8
+; CHECK-NEXT: ret <vscale x 16 x i8> %vec
+  %vec.addr = alloca <vscale x 16 x i8>
+  store <vscale x 16 x i8> %vec, <vscale x 16 x i8>* %vec.addr
+  %1 = load <vscale x 16 x i8>, <vscale x 16 x i8>* %vec.addr
+  ret <vscale x 16 x i8> %1
+}
+
+; Test scalable alloca that can't be promoted. Mem2Reg only considers
+; non-volatile loads and stores for promotion.
+define <vscale x 16 x i8> @unpromotable_alloca(<vscale x 16 x i8> %vec) {
+; CHECK-LABEL: unpromotable_alloca
+; CHECK-NEXT: %vec.addr = alloca <vscale x 16 x i8>
+; CHECK-NEXT: store volatile <vscale x 16 x i8> %vec, <vscale x 16 x i8>* %vec.addr
+; CHECK-NEXT: %1 = load volatile <vscale x 16 x i8>, <vscale x 16 x i8>* %vec.addr
+; CHECK-NEXT: ret <vscale x 16 x i8> %1
+  %vec.addr = alloca <vscale x 16 x i8>
+  store volatile <vscale x 16 x i8> %vec, <vscale x 16 x i8>* %vec.addr
+  %1 = load volatile <vscale x 16 x i8>, <vscale x 16 x i8>* %vec.addr
+  ret <vscale x 16 x i8> %1
+}


        

