[llvm] bc17d32 - [LoopIdiom] Let LIR fold memset pointer / stride SCEV regarding loop guards

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 14 08:47:59 PST 2021


Any reason you need the custom rewriter instead of using applyLoopGuards?

Philip

On 12/13/21 9:37 AM, via llvm-commits wrote:
> Author: eopXD
> Date: 2021-12-13T09:36:58-08:00
> New Revision: bc17d32a5f71b161186423c200554bddb6fb7e43
>
> URL: https://github.com/llvm/llvm-project/commit/bc17d32a5f71b161186423c200554bddb6fb7e43
> DIFF: https://github.com/llvm/llvm-project/commit/bc17d32a5f71b161186423c200554bddb6fb7e43.diff
>
> LOG: [LoopIdiom] Let LIR fold memset pointer / stride SCEV regarding loop guards
>
> Expressions guarded at loop entry can be folded prior to comparison. This patch
> follows up on D107353 and makes LIR able to deal with nested for-loops.
>
> Reviewed By: qianzhen, bmahjour
>
> Differential Revision: https://reviews.llvm.org/D108112
>
> Added:
>      
>
> Modified:
>      llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
>      llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll
>      llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll
>      llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
>
> Removed:
>      
>
>
> ################################################################################
> diff  --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
> index 42da86a9ecf50..e6efb422e7831 100644
> --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
> +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
> @@ -307,6 +307,25 @@ class LoopIdiomRecognizeLegacyPass : public LoopPass {
>     }
>   };
>   
> +// The Folder will fold expressions that are guarded by the loop entry.
> +class SCEVSignToZeroExtentionRewriter
> +    : public SCEVRewriteVisitor<SCEVSignToZeroExtentionRewriter> {
> +public:
> +  ScalarEvolution &SE;
> +  const Loop *CurLoop;
> +  SCEVSignToZeroExtentionRewriter(ScalarEvolution &SE, const Loop *CurLoop)
> +      : SCEVRewriteVisitor(SE), SE(SE), CurLoop(CurLoop) {}
> +
> +  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
> +    // If expression is guarded by CurLoop to be greater or equal to zero
> +    // then convert sext to zext. Otherwise return the original expression.
> +    if (SE.isLoopEntryGuardedByCond(CurLoop, ICmpInst::ICMP_SGE, Expr,
> +                                    SE.getZero(Expr->getType())))
> +      return SE.getZeroExtendExpr(visit(Expr->getOperand()), Expr->getType());
> +    return Expr;
> +  }
> +};
> +
>   } // end anonymous namespace
>   
>   char LoopIdiomRecognizeLegacyPass::ID = 0;
> @@ -967,12 +986,22 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
>                         << "\n");
>   
>       if (PositiveStrideSCEV != MemsetSizeSCEV) {
> -      // TODO: folding can be done to the SCEVs
> -      // The folding is to fold expressions that is covered by the loop guard
> -      // at loop entry. After the folding, compare again and proceed
> -      // optimization if equal.
> -      LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
> -      return false;
> +      // The folding is to fold an expression that is covered by the loop guard
> +      // at loop entry. After the folding, compare again and proceed with
> +      // optimization, if equal.
> +      SCEVSignToZeroExtentionRewriter Folder(*SE, CurLoop);
> +      const SCEV *FoldedPositiveStride = Folder.visit(PositiveStrideSCEV);
> +      const SCEV *FoldedMemsetSize = Folder.visit(MemsetSizeSCEV);
> +
> +      LLVM_DEBUG(dbgs() << "  Try to fold SCEV based on loop guard\n"
> +                        << "    FoldedMemsetSize: " << *FoldedMemsetSize << "\n"
> +                        << "    FoldedPositiveStride: " << *FoldedPositiveStride
> +                        << "\n");
> +
> +      if (FoldedPositiveStride != FoldedMemsetSize) {
> +        LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
> +        return false;
> +      }
>       }
>     }
>   
>
> diff  --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll
> index fe6ad07a1feba..69e96d53f2202 100644
> --- a/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll
> +++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll
> @@ -369,4 +369,52 @@ for.end17:                                        ; preds = %for.end17.loopexit,
>     ret void
>   }
>   
> +; void NegStart(int n, int m, int *ar) {
> +;   for (int i = -100; i < n; i++) {
> +;     int *arr = ar + (i + 100) * m;
> +;     memset(arr, 0, m * sizeof(int));
> +;   }
> +; }
> +define void @NegStart(i32 %n, i32 %m, i32* %ar) {
> +; CHECK-LABEL: @NegStart(
> +; CHECK-NEXT:  entry:
> +; CHECK-NEXT:    [[AR1:%.*]] = bitcast i32* [[AR:%.*]] to i8*
> +; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i32 -100, [[N:%.*]]
> +; CHECK-NEXT:    br i1 [[CMP1]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END:%.*]]
> +; CHECK:       for.body.lr.ph:
> +; CHECK-NEXT:    [[MUL1:%.*]] = mul i32 [[M:%.*]], 4
> +; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[N]], 100
> +; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[M]], [[TMP0]]
> +; CHECK-NEXT:    [[TMP2:%.*]] = shl i32 [[TMP1]], 2
> +; CHECK-NEXT:    call void @llvm.memset.p0i8.i32(i8* align 4 [[AR1]], i8 0, i32 [[TMP2]], i1 false)
> +; CHECK-NEXT:    br label [[FOR_END]]
> +; CHECK:       for.end:
> +; CHECK-NEXT:    ret void
> +;
> +entry:
> +  %cmp1 = icmp slt i32 -100, %n
> +  br i1 %cmp1, label %for.body.lr.ph, label %for.end
> +
> +for.body.lr.ph:                                   ; preds = %entry
> +  %mul1 = mul i32 %m, 4
> +  br label %for.body
> +
> +for.body:                                         ; preds = %for.body.lr.ph, %for.body
> +  %i.02 = phi i32 [ -100, %for.body.lr.ph ], [ %inc, %for.body ]
> +  %add = add nsw i32 %i.02, 100
> +  %mul = mul nsw i32 %add, %m
> +  %add.ptr = getelementptr inbounds i32, i32* %ar, i32 %mul
> +  %0 = bitcast i32* %add.ptr to i8*
> +  call void @llvm.memset.p0i8.i32(i8* align 4 %0, i8 0, i32 %mul1, i1 false)
> +  %inc = add nsw i32 %i.02, 1
> +  %exitcond = icmp ne i32 %inc, %n
> +  br i1 %exitcond, label %for.body, label %for.end.loopexit
> +
> +for.end.loopexit:                                 ; preds = %for.body
> +  br label %for.end
> +
> +for.end:                                          ; preds = %for.end.loopexit, %entry
> +  ret void
> +}
> +
>   declare void @llvm.memset.p0i8.i32(i8* nocapture writeonly, i8, i32, i1 immarg)
>
> diff  --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll
> index f6d2441a85cc1..5d485b82ada79 100644
> --- a/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll
> +++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll
> @@ -268,6 +268,12 @@ for.body:                                         ; preds = %for.body.lr.ph, %fo
>   for.end:                                          ; preds = %for.body, %entry
>     ret void
>   }
> +; This case requires SCEVFolder in LoopIdiomRecognize.cpp to fold SCEV prior to comparison.
> +; For the inner-loop, SCEVFolder is not needed, however the promoted memset size would be based
> +; on the trip count of inner-loop (which is an unsigned integer).
> +; Then in the outer loop, the pointer stride SCEV for memset needs to be converted based on the
> +; loop guard for it to be equal to the memset size SCEV. The loop guard guarantees that m >= 0
> +; inside the loop, so m can be converted from sext to zext, making the two SCEV-s equal.
>   ; void NestedFor32(int *ar, int n, int m, int o)
>   ; {
>   ;   int i, j;
> @@ -281,6 +287,7 @@ for.end:                                          ; preds = %for.body, %entry
>   define void @NestedFor32(i32* %ar, i32 %n, i32 %m, i32 %o) {
>   ; CHECK-LABEL: @NestedFor32(
>   ; CHECK-NEXT:  entry:
> +; CHECK-NEXT:    [[AR2:%.*]] = bitcast i32* [[AR:%.*]] to i8*
>   ; CHECK-NEXT:    [[CMP3:%.*]] = icmp slt i32 0, [[N:%.*]]
>   ; CHECK-NEXT:    br i1 [[CMP3]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END11:%.*]]
>   ; CHECK:       for.body.lr.ph:
> @@ -296,17 +303,10 @@ define void @NestedFor32(i32* %ar, i32 %n, i32 %m, i32 %o) {
>   ; CHECK-NEXT:    [[TMP3:%.*]] = zext i32 [[M]] to i64
>   ; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP0]], [[TMP3]]
>   ; CHECK-NEXT:    [[TMP5:%.*]] = shl i64 [[TMP4]], 2
> -; CHECK-NEXT:    br label [[FOR_BODY_US:%.*]]
> -; CHECK:       for.body.us:
> -; CHECK-NEXT:    [[INDVARS_IV6:%.*]] = phi i64 [ 0, [[FOR_BODY_US_PREHEADER]] ], [ [[INDVARS_IV_NEXT7:%.*]], [[FOR_BODY_US]] ]
> -; CHECK-NEXT:    [[TMP6:%.*]] = mul i64 [[TMP2]], [[INDVARS_IV6]]
> -; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i32, i32* [[AR:%.*]], i64 [[TMP6]]
> -; CHECK-NEXT:    [[SCEVGEP1:%.*]] = bitcast i32* [[SCEVGEP]] to i8*
> -; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[M]] to i64
> -; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[SCEVGEP1]], i8 0, i64 [[TMP5]], i1 false)
> -; CHECK-NEXT:    [[INDVARS_IV_NEXT7]] = add nuw nsw i64 [[INDVARS_IV6]], 1
> -; CHECK-NEXT:    [[EXITCOND11:%.*]] = icmp ne i64 [[INDVARS_IV_NEXT7]], [[WIDE_TRIP_COUNT10]]
> -; CHECK-NEXT:    br i1 [[EXITCOND11]], label [[FOR_BODY_US]], label [[FOR_END11]]
> +; CHECK-NEXT:    [[TMP6:%.*]] = mul i64 [[TMP4]], [[WIDE_TRIP_COUNT10]]
> +; CHECK-NEXT:    [[TMP7:%.*]] = shl i64 [[TMP6]], 2
> +; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[AR2]], i8 0, i64 [[TMP7]], i1 false)
> +; CHECK-NEXT:    br label [[FOR_END11]]
>   ; CHECK:       for.end11:
>   ; CHECK-NEXT:    ret void
>   ;
> @@ -357,4 +357,58 @@ for.end11:                                        ; preds = %for.end11.loopexit,
>     ret void
>   }
>   
> +; void NegStart(int n, int m, int *ar) {
> +;   for (int i = -100; i < n; i++) {
> +;     int *arr = ar + (i + 100) * m;
> +;     memset(arr, 0, m * sizeof(int));
> +;   }
> +; }
> +define void @NegStart(i32 %n, i32 %m, i32* %ar) {
> +; CHECK-LABEL: @NegStart(
> +; CHECK-NEXT:  entry:
> +; CHECK-NEXT:    [[AR1:%.*]] = bitcast i32* [[AR:%.*]] to i8*
> +; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i32 -100, [[N:%.*]]
> +; CHECK-NEXT:    br i1 [[CMP1]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END:%.*]]
> +; CHECK:       for.body.lr.ph:
> +; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[M:%.*]] to i64
> +; CHECK-NEXT:    [[MUL1:%.*]] = mul i64 [[CONV]], 4
> +; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[M]] to i64
> +; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = sext i32 [[N]] to i64
> +; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i64 [[WIDE_TRIP_COUNT]], 100
> +; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], [[TMP0]]
> +; CHECK-NEXT:    [[TMP3:%.*]] = shl i64 [[TMP2]], 2
> +; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[AR1]], i8 0, i64 [[TMP3]], i1 false)
> +; CHECK-NEXT:    br label [[FOR_END]]
> +; CHECK:       for.end:
> +; CHECK-NEXT:    ret void
> +;
> +entry:
> +  %cmp1 = icmp slt i32 -100, %n
> +  br i1 %cmp1, label %for.body.lr.ph, label %for.end
> +
> +for.body.lr.ph:                                   ; preds = %entry
> +  %conv = sext i32 %m to i64
> +  %mul1 = mul i64 %conv, 4
> +  %0 = sext i32 %m to i64
> +  %wide.trip.count = sext i32 %n to i64
> +  br label %for.body
> +
> +for.body:                                         ; preds = %for.body.lr.ph, %for.body
> +  %indvars.iv = phi i64 [ -100, %for.body.lr.ph ], [ %indvars.iv.next, %for.body ]
> +  %1 = add nsw i64 %indvars.iv, 100
> +  %2 = mul nsw i64 %1, %0
> +  %add.ptr = getelementptr inbounds i32, i32* %ar, i64 %2
> +  %3 = bitcast i32* %add.ptr to i8*
> +  call void @llvm.memset.p0i8.i64(i8* align 4 %3, i8 0, i64 %mul1, i1 false)
> +  %indvars.iv.next = add nsw i64 %indvars.iv, 1
> +  %exitcond = icmp ne i64 %indvars.iv.next, %wide.trip.count
> +  br i1 %exitcond, label %for.body, label %for.end.loopexit
> +
> +for.end.loopexit:                                 ; preds = %for.body
> +  br label %for.end
> +
> +for.end:                                          ; preds = %for.end.loopexit, %entry
> +  ret void
> +}
> +
>   declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg)
>
> diff  --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
> index 8ee554eb6d25e..95f9c969087ef 100644
> --- a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
> +++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
> @@ -19,6 +19,9 @@
>   ; CHECK-NEXT: memset size is non-constant
>   ; CHECK-NEXT: MemsetSizeSCEV: (4 * (sext i32 %m to i64))<nsw>
>   ; CHECK-NEXT: PositiveStrideSCEV: (4 + (4 * (sext i32 %m to i64))<nsw>)<nsw>
> +; CHECK-NEXT: Try to fold SCEV based on loop guard
> +; CHECK-NEXT: FoldedMemsetSize: (4 * (sext i32 %m to i64))<nsw>
> +; CHECK-NEXT: FoldedPositiveStride: (4 + (4 * (sext i32 %m to i64))<nsw>)<nsw>
>   ; CHECK-NEXT: SCEV don't match, abort
>   ; CHECK: loop-idiom Scanning: F[NonZeroAddressSpace] Countable Loop %for.cond1.preheader
>   ; CHECK-NEXT: memset size is non-constant
>
>
>          
> _______________________________________________
> llvm-commits mailing list
> llvm-commits at lists.llvm.org
> https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-commits


More information about the llvm-commits mailing list