[llvm] 26aa1bb - [NFCI] [LoopIdiom] Let processLoopStridedStore take StoreSize as SCEV instead of unsigned

via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 4 22:21:52 PDT 2021


Author: eopXD
Date: 2021-08-05T13:21:48+08:00
New Revision: 26aa1bbe97a3a1566633abbcf754046a6bffb155

URL: https://github.com/llvm/llvm-project/commit/26aa1bbe97a3a1566633abbcf754046a6bffb155
DIFF: https://github.com/llvm/llvm-project/commit/26aa1bbe97a3a1566633abbcf754046a6bffb155.diff

LOG: [NFCI] [LoopIdiom] Let processLoopStridedStore take StoreSize as SCEV instead of unsigned

Taking a SCEV allows the function to be extended later to optimize loops
whose StoreSize / Stride is only determined at runtime.

This is a preparatory patch for D107353.
The big picture is to let LoopIdiom deal with runtime-determined sizes.

Reviewed By: Whitney, lebedev.ri

Differential Revision: https://reviews.llvm.org/D104595

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 3d60e205b002..13b2703d5109 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -217,7 +217,7 @@ class LoopIdiomRecognize {
   bool processLoopMemCpy(MemCpyInst *MCI, const SCEV *BECount);
   bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount);
 
-  bool processLoopStridedStore(Value *DestPtr, unsigned StoreSize,
+  bool processLoopStridedStore(Value *DestPtr, const SCEV *StoreSizeSCEV,
                                MaybeAlign StoreAlignment, Value *StoredVal,
                                Instruction *TheStore,
                                SmallPtrSetImpl<Instruction *> &Stores,
@@ -786,7 +786,8 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL,
 
     bool NegStride = StoreSize == -Stride;
 
-    if (processLoopStridedStore(StorePtr, StoreSize,
+    const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
+    if (processLoopStridedStore(StorePtr, StoreSizeSCEV,
                                 MaybeAlign(HeadStore->getAlignment()),
                                 StoredVal, HeadStore, AdjacentStores, StoreEv,
                                 BECount, NegStride)) {
@@ -936,9 +937,10 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
   SmallPtrSet<Instruction *, 1> MSIs;
   MSIs.insert(MSI);
   bool NegStride = SizeInBytes == -Stride;
-  return processLoopStridedStore(
-      Pointer, (unsigned)SizeInBytes, MaybeAlign(MSI->getDestAlignment()),
-      SplatValue, MSI, MSIs, Ev, BECount, NegStride, /*IsLoopMemset=*/true);
+  return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()),
+                                 MaybeAlign(MSI->getDestAlignment()),
+                                 SplatValue, MSI, MSIs, Ev, BECount, NegStride,
+                                 /*IsLoopMemset=*/true);
 }
 
 /// mayLoopAccessLocation - Return true if the specified loop might access the
@@ -946,7 +948,7 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
 /// argument specifies what the verboten forms of access are (read or write).
 static bool
 mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
-                      const SCEV *BECount, unsigned StoreSize,
+                      const SCEV *BECount, const SCEV *StoreSizeSCEV,
                       AliasAnalysis &AA,
                       SmallPtrSetImpl<Instruction *> &IgnoredStores) {
   // Get the location that may be stored across the loop.  Since the access is
@@ -956,9 +958,11 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
 
   // If the loop iterates a fixed number of times, we can refine the access size
   // to be exactly the size of the memset, which is (BECount+1)*StoreSize
-  if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
+  const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
+  const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
+  if (BECst && ConstSize)
     AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) *
-                                       StoreSize);
+                                       ConstSize->getValue()->getZExtValue());
 
   // TODO: For this to be really effective, we have to dive into the pointer
   // operand in the store.  Store to &A[i] of 100 will always return may alias
@@ -973,7 +977,6 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
           isModOrRefSet(
               intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access)))
         return true;
-
   return false;
 }
 
@@ -981,54 +984,78 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
 // we're trying to memset.  Therefore, we need to recompute the base pointer,
 // which is just Start - BECount*Size.
 static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount,
-                                        Type *IntPtr, unsigned StoreSize,
+                                        Type *IntPtr, const SCEV *StoreSizeSCEV,
                                         ScalarEvolution *SE) {
   const SCEV *Index = SE->getTruncateOrZeroExtend(BECount, IntPtr);
-  if (StoreSize != 1)
-    Index = SE->getMulExpr(Index, SE->getConstant(IntPtr, StoreSize),
+  if (!StoreSizeSCEV->isOne()) {
+    // index = back edge count * store size
+    Index = SE->getMulExpr(Index,
+                           SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
                            SCEV::FlagNUW);
+  }
+  // base pointer = start - back edge count * store size (i.e. start - index)
   return SE->getMinusSCEV(Start, Index);
 }
 
-/// Compute the number of bytes as a SCEV from the backedge taken count.
-///
-/// This also maps the SCEV into the provided type and tries to handle the
-/// computation in a way that will fold cleanly.
-static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
-                               unsigned StoreSize, Loop *CurLoop,
-                               const DataLayout *DL, ScalarEvolution *SE) {
-  const SCEV *NumBytesS;
-  // The # stored bytes is (BECount+1)*Size.  Expand the trip count out to
+/// Compute trip count from the backedge taken count.
+static const SCEV *getTripCount(const SCEV *BECount, Type *IntPtr,
+                                Loop *CurLoop, const DataLayout *DL,
+                                ScalarEvolution *SE) {
+  const SCEV *TripCountS = nullptr;
+  // The trip count is (BECount+1).  Expand the trip count out to
   // pointer size if it isn't already.
   //
   // If we're going to need to zero extend the BE count, check if we can add
   // one to it prior to zero extending without overflow. Provided this is safe,
   // it allows better simplification of the +1.
-  if (DL->getTypeSizeInBits(BECount->getType()).getFixedSize() <
-          DL->getTypeSizeInBits(IntPtr).getFixedSize() &&
+  if (DL->getTypeSizeInBits(BECount->getType()) <
+          DL->getTypeSizeInBits(IntPtr) &&
       SE->isLoopEntryGuardedByCond(
           CurLoop, ICmpInst::ICMP_NE, BECount,
           SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
-    NumBytesS = SE->getZeroExtendExpr(
+    TripCountS = SE->getZeroExtendExpr(
         SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
         IntPtr);
   } else {
-    NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
-                               SE->getOne(IntPtr), SCEV::FlagNUW);
+    TripCountS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
+                                SE->getOne(IntPtr), SCEV::FlagNUW);
   }
 
+  return TripCountS;
+}
+
+/// Compute the number of bytes as a SCEV from the backedge taken count.
+///
+/// This also maps the SCEV into the provided type and tries to handle the
+/// computation in a way that will fold cleanly.
+static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
+                               unsigned StoreSize, Loop *CurLoop,
+                               const DataLayout *DL, ScalarEvolution *SE) {
+  const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
+
   // And scale it based on the store size.
   if (StoreSize != 1) {
-    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
-                               SCEV::FlagNUW);
+    return SE->getMulExpr(TripCountSCEV, SE->getConstant(IntPtr, StoreSize),
+                          SCEV::FlagNUW);
   }
-  return NumBytesS;
+  return TripCountSCEV;
+}
+
+/// getNumBytes that takes StoreSize as a SCEV
+static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
+                               const SCEV *StoreSizeSCEV, Loop *CurLoop,
+                               const DataLayout *DL, ScalarEvolution *SE) {
+  const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE);
+
+  return SE->getMulExpr(TripCountSCEV,
+                        SE->getTruncateOrZeroExtend(StoreSizeSCEV, IntPtr),
+                        SCEV::FlagNUW);
 }
 
 /// processLoopStridedStore - We see a strided store of some value.  If we can
 /// transform this into a memset or memset_pattern in the loop preheader, do so.
 bool LoopIdiomRecognize::processLoopStridedStore(
-    Value *DestPtr, unsigned StoreSize, MaybeAlign StoreAlignment,
+    Value *DestPtr, const SCEV *StoreSizeSCEV, MaybeAlign StoreAlignment,
     Value *StoredVal, Instruction *TheStore,
     SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev,
     const SCEV *BECount, bool NegStride, bool IsLoopMemset) {
@@ -1057,7 +1084,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
   const SCEV *Start = Ev->getStart();
   // Handle negative strided loops.
   if (NegStride)
-    Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSize, SE);
+    Start = getStartForNegStride(Start, BECount, IntIdxTy, StoreSizeSCEV, SE);
 
   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -1082,7 +1109,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
   Changed = true;
 
   if (mayLoopAccessLocation(BasePtr, ModRefInfo::ModRef, CurLoop, BECount,
-                            StoreSize, *AA, Stores))
+                            StoreSizeSCEV, *AA, Stores))
     return Changed;
 
   if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset))
@@ -1091,7 +1118,7 @@ bool LoopIdiomRecognize::processLoopStridedStore(
   // Okay, everything looks good, insert the memset.
 
   const SCEV *NumBytesS =
-      getNumBytes(BECount, IntIdxTy, StoreSize, CurLoop, DL, SE);
+      getNumBytes(BECount, IntIdxTy, StoreSizeSCEV, CurLoop, DL, SE);
 
   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -1215,9 +1242,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
   APInt Stride = getStoreStride(StoreEv);
   bool NegStride = StoreSize == -Stride;
 
+  const SCEV *StoreSizeSCEV = SE->getConstant(BECount->getType(), StoreSize);
   // Handle negative strided loops.
   if (NegStride)
-    StrStart = getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSize, SE);
+    StrStart =
+        getStartForNegStride(StrStart, BECount, IntIdxTy, StoreSizeSCEV, SE);
 
   // Okay, we have a strided store "p[i]" of a loaded value.  We can turn
   // this into a memcpy in the loop preheader now if we want.  However, this
@@ -1245,11 +1274,11 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
 
   bool UseMemMove =
       mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount,
-                            StoreSize, *AA, Stores);
+                            StoreSizeSCEV, *AA, Stores);
   if (UseMemMove) {
     Stores.insert(TheLoad);
     if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop,
-                              BECount, StoreSize, *AA, Stores)) {
+                              BECount, StoreSizeSCEV, *AA, Stores)) {
       ORE.emit([&]() {
         return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessStore",
                                         TheStore)
@@ -1268,7 +1297,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
 
   // Handle negative strided loops.
   if (NegStride)
-    LdStart = getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSize, SE);
+    LdStart =
+        getStartForNegStride(LdStart, BECount, IntIdxTy, StoreSizeSCEV, SE);
 
   // For a memcpy, we have to make sure that the input array is not being
   // mutated by the loop.
@@ -1280,7 +1310,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(
   if (IsMemCpy)
     Stores.erase(TheStore);
   if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount,
-                            StoreSize, *AA, Stores)) {
+                            StoreSizeSCEV, *AA, Stores)) {
     ORE.emit([&]() {
       return OptimizationRemarkMissed(DEBUG_TYPE, "LoopMayAccessLoad", TheLoad)
              << ore::NV("Inst", InstRemark) << " in "


        


More information about the llvm-commits mailing list