[llvm] bc17d32 - [LoopIdiom] Let LIR fold memset pointer / stride SCEV regarding loop guards

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Wed Dec 15 09:08:34 PST 2021


On 12/14/21 10:34 PM, eop Chen wrote:
> Hi Philip,
>
> Sounds good. Let me try and apply your advice.
>
> On the other hand, I saw the thread of “profit driven LICM” today.
> Would you like to share your thoughts in the LoopOpt working group meeting in the near future?
I have shared my thoughts in the llvm-dev thread.  I'm happy to verbally 
discuss either in the LWG, or private discussion.  We didn't get to that 
today, so if you want to chat before Jan we should probably arrange a 
separate call.
>
> Best Regards,
>
> Yueh-Ting (Eop) Chen 陳約廷
>
>> Philip Reames <listmail at philipreames.com> 於 2021年12月15日 上午12:47 寫道:
>>
>> Any reason you need the custom rewriter instead of using applyLoopGuards?
>>
>> Philip
>>
>> On 12/13/21 9:37 AM, via llvm-commits wrote:
>>> Author: eopXD
>>> Date: 2021-12-13T09:36:58-08:00
>>> New Revision: bc17d32a5f71b161186423c200554bddb6fb7e43
>>>
>>> URL: https://github.com/llvm/llvm-project/commit/bc17d32a5f71b161186423c200554bddb6fb7e43
>>> DIFF: https://github.com/llvm/llvm-project/commit/bc17d32a5f71b161186423c200554bddb6fb7e43.diff
>>>
>>> LOG: [LoopIdiom] Let LIR fold memset pointer / stride SCEV regarding loop guards
>>>
>>> Expressions guarded by the loop entry can be folded prior to comparison. This patch
>>> follows up on D107353 and makes LIR able to deal with nested for-loops.
>>>
>>> Reviewed By: qianzhen, bmahjour
>>>
>>> Differential Revision: https://reviews.llvm.org/D108112
>>>
>>> Added:
>>>      
>>> Modified:
>>>      llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
>>>      llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll
>>>      llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll
>>>      llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
>>>
>>> Removed:
>>>      
>>>
>>> ################################################################################
>>> diff  --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
>>> index 42da86a9ecf50..e6efb422e7831 100644
>>> --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
>>> +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
>>> @@ -307,6 +307,25 @@ class LoopIdiomRecognizeLegacyPass : public LoopPass {
>>>     }
>>>   };
>>>   +// The Folder will fold expressions that are guarded by the loop entry.
>>> +class SCEVSignToZeroExtentionRewriter
>>> +    : public SCEVRewriteVisitor<SCEVSignToZeroExtentionRewriter> {
>>> +public:
>>> +  ScalarEvolution &SE;
>>> +  const Loop *CurLoop;
>>> +  SCEVSignToZeroExtentionRewriter(ScalarEvolution &SE, const Loop *CurLoop)
>>> +      : SCEVRewriteVisitor(SE), SE(SE), CurLoop(CurLoop) {}
>>> +
>>> +  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
>>> +    // If expression is guarded by CurLoop to be greater or equal to zero
>>> +    // then convert sext to zext. Otherwise return the original expression.
>>> +    if (SE.isLoopEntryGuardedByCond(CurLoop, ICmpInst::ICMP_SGE, Expr,
>>> +                                    SE.getZero(Expr->getType())))
>>> +      return SE.getZeroExtendExpr(visit(Expr->getOperand()), Expr->getType());
>>> +    return Expr;
>>> +  }
>>> +};
>>> +
>>>   } // end anonymous namespace
>>>     char LoopIdiomRecognizeLegacyPass::ID = 0;
>>> @@ -967,12 +986,22 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
>>>                         << "\n");
>>>         if (PositiveStrideSCEV != MemsetSizeSCEV) {
>>> -      // TODO: folding can be done to the SCEVs
>>> -      // The folding is to fold expressions that is covered by the loop guard
>>> -      // at loop entry. After the folding, compare again and proceed
>>> -      // optimization if equal.
>>> -      LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
>>> -      return false;
>>> +      // The folding is to fold an expression that is covered by the loop guard
>>> +      // at loop entry. After the folding, compare again and proceed with
>>> +      // optimization, if equal.
>>> +      SCEVSignToZeroExtentionRewriter Folder(*SE, CurLoop);
>>> +      const SCEV *FoldedPositiveStride = Folder.visit(PositiveStrideSCEV);
>>> +      const SCEV *FoldedMemsetSize = Folder.visit(MemsetSizeSCEV);
>>> +
>>> +      LLVM_DEBUG(dbgs() << "  Try to fold SCEV based on loop guard\n"
>>> +                        << "    FoldedMemsetSize: " << *FoldedMemsetSize << "\n"
>>> +                        << "    FoldedPositiveStride: " << *FoldedPositiveStride
>>> +                        << "\n");
>>> +
>>> +      if (FoldedPositiveStride != FoldedMemsetSize) {
>>> +        LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
>>> +        return false;
>>> +      }
>>>       }
>>>     }
>>>   
>>> diff  --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll
>>> index fe6ad07a1feba..69e96d53f2202 100644
>>> --- a/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll
>>> +++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll
>>> @@ -369,4 +369,52 @@ for.end17:                                        ; preds = %for.end17.loopexit,
>>>     ret void
>>>   }
>>>   +; void NegStart(int n, int m, int *ar) {
>>> +;   for (int i = -100; i < n; i++) {
>>> +;     int *arr = ar + (i + 100) * m;
>>> +;     memset(arr, 0, m * sizeof(int));
>>> +;   }
>>> +; }
>>> +define void @NegStart(i32 %n, i32 %m, i32* %ar) {
>>> +; CHECK-LABEL: @NegStart(
>>> +; CHECK-NEXT:  entry:
>>> +; CHECK-NEXT:    [[AR1:%.*]] = bitcast i32* [[AR:%.*]] to i8*
>>> +; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i32 -100, [[N:%.*]]
>>> +; CHECK-NEXT:    br i1 [[CMP1]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END:%.*]]
>>> +; CHECK:       for.body.lr.ph:
>>> +; CHECK-NEXT:    [[MUL1:%.*]] = mul i32 [[M:%.*]], 4
>>> +; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[N]], 100
>>> +; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[M]], [[TMP0]]
>>> +; CHECK-NEXT:    [[TMP2:%.*]] = shl i32 [[TMP1]], 2
>>> +; CHECK-NEXT:    call void @llvm.memset.p0i8.i32(i8* align 4 [[AR1]], i8 0, i32 [[TMP2]], i1 false)
>>> +; CHECK-NEXT:    br label [[FOR_END]]
>>> +; CHECK:       for.end:
>>> +; CHECK-NEXT:    ret void
>>> +;
>>> +entry:
>>> +  %cmp1 = icmp slt i32 -100, %n
>>> +  br i1 %cmp1, label %for.body.lr.ph, label %for.end
>>> +
>>> +for.body.lr.ph:                                   ; preds = %entry
>>> +  %mul1 = mul i32 %m, 4
>>> +  br label %for.body
>>> +
>>> +for.body:                                         ; preds = %for.body.lr.ph, %for.body
>>> +  %i.02 = phi i32 [ -100, %for.body.lr.ph ], [ %inc, %for.body ]
>>> +  %add = add nsw i32 %i.02, 100
>>> +  %mul = mul nsw i32 %add, %m
>>> +  %add.ptr = getelementptr inbounds i32, i32* %ar, i32 %mul
>>> +  %0 = bitcast i32* %add.ptr to i8*
>>> +  call void @llvm.memset.p0i8.i32(i8* align 4 %0, i8 0, i32 %mul1, i1 false)
>>> +  %inc = add nsw i32 %i.02, 1
>>> +  %exitcond = icmp ne i32 %inc, %n
>>> +  br i1 %exitcond, label %for.body, label %for.end.loopexit
>>> +
>>> +for.end.loopexit:                                 ; preds = %for.body
>>> +  br label %for.end
>>> +
>>> +for.end:                                          ; preds = %for.end.loopexit, %entry
>>> +  ret void
>>> +}
>>> +
>>>   declare void @llvm.memset.p0i8.i32(i8* nocapture writeonly, i8, i32, i1 immarg)
>>>
>>> diff  --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll
>>> index f6d2441a85cc1..5d485b82ada79 100644
>>> --- a/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll
>>> +++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll
>>> @@ -268,6 +268,12 @@ for.body:                                         ; preds = %for.body.lr.ph, %fo
>>>   for.end:                                          ; preds = %for.body, %entry
>>>     ret void
>>>   }
>>> +; This case requires SCEVFolder in LoopIdiomRecognize.cpp to fold SCEV prior to comparison.
>>> +; For the inner-loop, SCEVFolder is not needed, however the promoted memset size would be based
>>> +; on the trip count of inner-loop (which is an unsigned integer).
>>> +; Then in the outer loop, the pointer stride SCEV for memset needs to be converted based on the
>>> +; loop guard for it to be equal to the memset size SCEV. The loop guard guarantees that m >= 0
>>> +; inside the loop, so m can be converted from sext to zext, making the two SCEV-s equal.
>>>   ; void NestedFor32(int *ar, int n, int m, int o)
>>>   ; {
>>>   ;   int i, j;
>>> @@ -281,6 +287,7 @@ for.end:                                          ; preds = %for.body, %entry
>>>   define void @NestedFor32(i32* %ar, i32 %n, i32 %m, i32 %o) {
>>>   ; CHECK-LABEL: @NestedFor32(
>>>   ; CHECK-NEXT:  entry:
>>> +; CHECK-NEXT:    [[AR2:%.*]] = bitcast i32* [[AR:%.*]] to i8*
>>>   ; CHECK-NEXT:    [[CMP3:%.*]] = icmp slt i32 0, [[N:%.*]]
>>>   ; CHECK-NEXT:    br i1 [[CMP3]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END11:%.*]]
>>>   ; CHECK:       for.body.lr.ph:
>>> @@ -296,17 +303,10 @@ define void @NestedFor32(i32* %ar, i32 %n, i32 %m, i32 %o) {
>>>   ; CHECK-NEXT:    [[TMP3:%.*]] = zext i32 [[M]] to i64
>>>   ; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP0]], [[TMP3]]
>>>   ; CHECK-NEXT:    [[TMP5:%.*]] = shl i64 [[TMP4]], 2
>>> -; CHECK-NEXT:    br label [[FOR_BODY_US:%.*]]
>>> -; CHECK:       for.body.us:
>>> -; CHECK-NEXT:    [[INDVARS_IV6:%.*]] = phi i64 [ 0, [[FOR_BODY_US_PREHEADER]] ], [ [[INDVARS_IV_NEXT7:%.*]], [[FOR_BODY_US]] ]
>>> -; CHECK-NEXT:    [[TMP6:%.*]] = mul i64 [[TMP2]], [[INDVARS_IV6]]
>>> -; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i32, i32* [[AR:%.*]], i64 [[TMP6]]
>>> -; CHECK-NEXT:    [[SCEVGEP1:%.*]] = bitcast i32* [[SCEVGEP]] to i8*
>>> -; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[M]] to i64
>>> -; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[SCEVGEP1]], i8 0, i64 [[TMP5]], i1 false)
>>> -; CHECK-NEXT:    [[INDVARS_IV_NEXT7]] = add nuw nsw i64 [[INDVARS_IV6]], 1
>>> -; CHECK-NEXT:    [[EXITCOND11:%.*]] = icmp ne i64 [[INDVARS_IV_NEXT7]], [[WIDE_TRIP_COUNT10]]
>>> -; CHECK-NEXT:    br i1 [[EXITCOND11]], label [[FOR_BODY_US]], label [[FOR_END11]]
>>> +; CHECK-NEXT:    [[TMP6:%.*]] = mul i64 [[TMP4]], [[WIDE_TRIP_COUNT10]]
>>> +; CHECK-NEXT:    [[TMP7:%.*]] = shl i64 [[TMP6]], 2
>>> +; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[AR2]], i8 0, i64 [[TMP7]], i1 false)
>>> +; CHECK-NEXT:    br label [[FOR_END11]]
>>>   ; CHECK:       for.end11:
>>>   ; CHECK-NEXT:    ret void
>>>   ;
>>> @@ -357,4 +357,58 @@ for.end11:                                        ; preds = %for.end11.loopexit,
>>>     ret void
>>>   }
>>>   +; void NegStart(int n, int m, int *ar) {
>>> +;   for (int i = -100; i < n; i++) {
>>> +;     int *arr = ar + (i + 100) * m;
>>> +;     memset(arr, 0, m * sizeof(int));
>>> +;   }
>>> +; }
>>> +define void @NegStart(i32 %n, i32 %m, i32* %ar) {
>>> +; CHECK-LABEL: @NegStart(
>>> +; CHECK-NEXT:  entry:
>>> +; CHECK-NEXT:    [[AR1:%.*]] = bitcast i32* [[AR:%.*]] to i8*
>>> +; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i32 -100, [[N:%.*]]
>>> +; CHECK-NEXT:    br i1 [[CMP1]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END:%.*]]
>>> +; CHECK:       for.body.lr.ph:
>>> +; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[M:%.*]] to i64
>>> +; CHECK-NEXT:    [[MUL1:%.*]] = mul i64 [[CONV]], 4
>>> +; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[M]] to i64
>>> +; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = sext i32 [[N]] to i64
>>> +; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i64 [[WIDE_TRIP_COUNT]], 100
>>> +; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], [[TMP0]]
>>> +; CHECK-NEXT:    [[TMP3:%.*]] = shl i64 [[TMP2]], 2
>>> +; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[AR1]], i8 0, i64 [[TMP3]], i1 false)
>>> +; CHECK-NEXT:    br label [[FOR_END]]
>>> +; CHECK:       for.end:
>>> +; CHECK-NEXT:    ret void
>>> +;
>>> +entry:
>>> +  %cmp1 = icmp slt i32 -100, %n
>>> +  br i1 %cmp1, label %for.body.lr.ph, label %for.end
>>> +
>>> +for.body.lr.ph:                                   ; preds = %entry
>>> +  %conv = sext i32 %m to i64
>>> +  %mul1 = mul i64 %conv, 4
>>> +  %0 = sext i32 %m to i64
>>> +  %wide.trip.count = sext i32 %n to i64
>>> +  br label %for.body
>>> +
>>> +for.body:                                         ; preds = %for.body.lr.ph, %for.body
>>> +  %indvars.iv = phi i64 [ -100, %for.body.lr.ph ], [ %indvars.iv.next, %for.body ]
>>> +  %1 = add nsw i64 %indvars.iv, 100
>>> +  %2 = mul nsw i64 %1, %0
>>> +  %add.ptr = getelementptr inbounds i32, i32* %ar, i64 %2
>>> +  %3 = bitcast i32* %add.ptr to i8*
>>> +  call void @llvm.memset.p0i8.i64(i8* align 4 %3, i8 0, i64 %mul1, i1 false)
>>> +  %indvars.iv.next = add nsw i64 %indvars.iv, 1
>>> +  %exitcond = icmp ne i64 %indvars.iv.next, %wide.trip.count
>>> +  br i1 %exitcond, label %for.body, label %for.end.loopexit
>>> +
>>> +for.end.loopexit:                                 ; preds = %for.body
>>> +  br label %for.end
>>> +
>>> +for.end:                                          ; preds = %for.end.loopexit, %entry
>>> +  ret void
>>> +}
>>> +
>>>   declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg)
>>>
>>> diff  --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
>>> index 8ee554eb6d25e..95f9c969087ef 100644
>>> --- a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
>>> +++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
>>> @@ -19,6 +19,9 @@
>>>   ; CHECK-NEXT: memset size is non-constant
>>>   ; CHECK-NEXT: MemsetSizeSCEV: (4 * (sext i32 %m to i64))<nsw>
>>>   ; CHECK-NEXT: PositiveStrideSCEV: (4 + (4 * (sext i32 %m to i64))<nsw>)<nsw>
>>> +; CHECK-NEXT: Try to fold SCEV based on loop guard
>>> +; CHECK-NEXT: FoldedMemsetSize: (4 * (sext i32 %m to i64))<nsw>
>>> +; CHECK-NEXT: FoldedPositiveStride: (4 + (4 * (sext i32 %m to i64))<nsw>)<nsw>
>>>   ; CHECK-NEXT: SCEV don't match, abort
>>>   ; CHECK: loop-idiom Scanning: F[NonZeroAddressSpace] Countable Loop %for.cond1.preheader
>>>   ; CHECK-NEXT: memset size is non-constant
>>>
>>>
>>>          _______________________________________________
>>> llvm-commits mailing list
>>> llvm-commits at lists.llvm.org
>>> https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-commits


More information about the llvm-commits mailing list