[llvm] r308968 - [LIR] Teach LIR to avoid extending the BE count prior to adding one to it when safe

Chandler Carruth via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 25 03:48:33 PDT 2017


Author: chandlerc
Date: Tue Jul 25 03:48:32 2017
New Revision: 308968

URL: http://llvm.org/viewvc/llvm-project?rev=308968&view=rev
Log:
[LIR] Teach LIR to avoid extending the BE count prior to adding one to
it when safe.

Very often the backedge-taken (BE) count is the trip count minus one, and
the plus one added here should fold away against that minus one. But
because the BE count might in theory be the all-ones value of its type
(e.g. UINT_MAX for i32), adding one before we extend could wrap to zero
and produce a wrong size once we scale it.

This patch checks whether it is safe to add one before extending, i.e.
whether the one problematic case (BE count == -1) is ruled out by a
guard that is checked before entering the preheader. This should handle
essentially all of the common loop idioms coming out of C/C++ code once
canonicalized by LLVM.

Before this patch, both forms of loop in the added test cases ended up
subtracting one from the size, extending it, scaling it up by 8 and then
adding 8 back onto it. This is really silly, and it turns out this
pattern very often made it all the way into generated code, so this is a
surprisingly important cleanup to do.

Many thanks to Sanjoy for showing me how to do this with SCEV.

Differential Revision: https://reviews.llvm.org/D35758

Modified:
    llvm/trunk/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
    llvm/trunk/test/Transforms/LoopIdiom/basic.ll

Modified: llvm/trunk/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Scalar/LoopIdiomRecognize.cpp?rev=308968&r1=308967&r2=308968&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Scalar/LoopIdiomRecognize.cpp (original)
+++ llvm/trunk/lib/Transforms/Scalar/LoopIdiomRecognize.cpp Tue Jul 25 03:48:32 2017
@@ -780,6 +780,41 @@ static const SCEV *getStartForNegStride(
   return SE->getMinusSCEV(Start, Index);
 }
 
+/// Compute the number of bytes stored, (BECount + 1) * StoreSize, as a SCEV.
+///
+/// The result has type IntPtr (BECount is zero-extended or truncated to it),
+/// and the +1 is applied before extension when that cannot wrap, so it folds.
+static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
+                               unsigned StoreSize, Loop *CurLoop,
+                               const DataLayout *DL, ScalarEvolution *SE) {
+  const SCEV *NumBytesS;
+  // The trip count is BECount + 1, and the result must be computed in the
+  // pointer-sized type IntPtr.
+  //
+  // When BECount is narrower than IntPtr, adding one before zero-extending
+  // can only wrap if BECount is all-ones; if loop entry is guarded by
+  // BECount != -1, add first so the +1 can fold with a "- 1" in BECount.
+  if (DL->getTypeSizeInBits(BECount->getType()) <
+          DL->getTypeSizeInBits(IntPtr) &&
+      SE->isLoopEntryGuardedByCond(
+          CurLoop, ICmpInst::ICMP_NE, BECount,
+          SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
+    NumBytesS = SE->getZeroExtendExpr(
+        SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
+        IntPtr);
+  } else {
+    NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
+                               SE->getOne(IntPtr), SCEV::FlagNUW);
+  }
+
+  // Multiply by the bytes stored per iteration (no-op when StoreSize is 1).
+  if (StoreSize != 1) {
+    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
+                               SCEV::FlagNUW);
+  }
+  return NumBytesS;
+}
+
 /// processLoopStridedStore - We see a strided store of some value.  If we can
 /// transform this into a memset or memset_pattern in the loop preheader, do so.
 bool LoopIdiomRecognize::processLoopStridedStore(
@@ -837,16 +872,8 @@ bool LoopIdiomRecognize::processLoopStri
 
   // Okay, everything looks good, insert the memset.
 
-  // The # stored bytes is (BECount+1)*Size.  Expand the trip count out to
-  // pointer size if it isn't already.
-  BECount = SE->getTruncateOrZeroExtend(BECount, IntPtr);
-
   const SCEV *NumBytesS =
-      SE->getAddExpr(BECount, SE->getOne(IntPtr), SCEV::FlagNUW);
-  if (StoreSize != 1) {
-    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
-                               SCEV::FlagNUW);
-  }
+      getNumBytes(BECount, IntPtr, StoreSize, CurLoop, DL, SE);
 
   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -976,16 +1003,8 @@ bool LoopIdiomRecognize::processLoopStor
 
   // Okay, everything is safe, we can transform this!
 
-  // The # stored bytes is (BECount+1)*Size.  Expand the trip count out to
-  // pointer size if it isn't already.
-  BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy);
-
   const SCEV *NumBytesS =
-      SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW);
-
-  if (StoreSize != 1)
-    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize),
-                               SCEV::FlagNUW);
+      getNumBytes(BECount, IntPtrTy, StoreSize, CurLoop, DL, SE);
 
   Value *NumBytes =
       Expander.expandCodeFor(NumBytesS, IntPtrTy, Preheader->getTerminator());

Modified: llvm/trunk/test/Transforms/LoopIdiom/basic.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/LoopIdiom/basic.ll?rev=308968&r1=308967&r2=308968&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/LoopIdiom/basic.ll (original)
+++ llvm/trunk/test/Transforms/LoopIdiom/basic.ll Tue Jul 25 03:48:32 2017
@@ -563,6 +563,75 @@ for.end6:
 ; CHECK: ret void
 }
 
+; Form a memset when the trip count is a narrow (i32) value that must be
+; zero-extended to i64; the CHECKs expect zext(%size) << 3, with no -1/+1.
+define void @form_memset_narrow_size(i64* %ptr, i32 %size) {
+; CHECK-LABEL: @form_memset_narrow_size(
+entry:
+  %cmp1 = icmp sgt i32 %size, 0
+  br i1 %cmp1, label %loop.ph, label %exit
+; CHECK:       entry:
+; CHECK:         %[[C1:.*]] = icmp sgt i32 %size, 0
+; CHECK-NEXT:    br i1 %[[C1]], label %loop.ph, label %exit
+
+loop.ph:
+  br label %loop.body
+; CHECK:       loop.ph:
+; CHECK-NEXT:    %[[ZEXT_SIZE:.*]] = zext i32 %size to i64
+; CHECK-NEXT:    %[[SCALED_SIZE:.*]] = shl i64 %[[ZEXT_SIZE]], 3
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* %{{.*}}, i8 0, i64 %[[SCALED_SIZE]], i32 8, i1 false)
+
+loop.body:
+  %storemerge4 = phi i32 [ 0, %loop.ph ], [ %inc, %loop.body ]
+  %idxprom = sext i32 %storemerge4 to i64
+  %arrayidx = getelementptr inbounds i64, i64* %ptr, i64 %idxprom
+  store i64 0, i64* %arrayidx, align 8
+  %inc = add nsw i32 %storemerge4, 1
+  %cmp2 = icmp slt i32 %inc, %size
+  br i1 %cmp2, label %loop.body, label %loop.exit
+
+loop.exit:
+  br label %exit
+
+exit:
+  ret void
+}
+; Narrow (i32) trip count again, but for a load/store pair forming a memcpy.
+define void @form_memcpy_narrow_size(i64* noalias %dst, i64* noalias %src, i32 %size) {
+; CHECK-LABEL: @form_memcpy_narrow_size(
+entry:
+  %cmp1 = icmp sgt i32 %size, 0
+  br i1 %cmp1, label %loop.ph, label %exit
+; CHECK:       entry:
+; CHECK:         %[[C1:.*]] = icmp sgt i32 %size, 0
+; CHECK-NEXT:    br i1 %[[C1]], label %loop.ph, label %exit
+
+loop.ph:
+  br label %loop.body
+; CHECK:       loop.ph:
+; CHECK-NEXT:    %[[ZEXT_SIZE:.*]] = zext i32 %size to i64
+; CHECK-NEXT:    %[[SCALED_SIZE:.*]] = shl i64 %[[ZEXT_SIZE]], 3
+; CHECK-NEXT:    call void @llvm.memcpy.p0i8.p0i8.i64(i8* %{{.*}}, i8* %{{.*}}, i64 %[[SCALED_SIZE]], i32 8, i1 false)
+
+loop.body:
+  %storemerge4 = phi i32 [ 0, %loop.ph ], [ %inc, %loop.body ]
+  %idxprom1 = sext i32 %storemerge4 to i64
+  %arrayidx1 = getelementptr inbounds i64, i64* %src, i64 %idxprom1
+  %v = load i64, i64* %arrayidx1, align 8
+  %idxprom2 = sext i32 %storemerge4 to i64
+  %arrayidx2 = getelementptr inbounds i64, i64* %dst, i64 %idxprom2
+  store i64 %v, i64* %arrayidx2, align 8
+  %inc = add nsw i32 %storemerge4, 1
+  %cmp2 = icmp slt i32 %inc, %size
+  br i1 %cmp2, label %loop.body, label %loop.exit
+
+loop.exit:
+  br label %exit
+
+exit:
+  ret void
+}
+
 ; Validate that "memset_pattern" has the proper attributes.
 ; CHECK: declare void @memset_pattern16(i8* nocapture, i8* nocapture readonly, i64) [[ATTRS:#[0-9]+]]
 ; CHECK: [[ATTRS]] = { argmemonly }




More information about the llvm-commits mailing list