[llvm] bc17d32 - [LoopIdiom] Let LIR fold memset pointer / stride SCEV regarding loop guards

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 13 09:37:03 PST 2021


Author: eopXD
Date: 2021-12-13T09:36:58-08:00
New Revision: bc17d32a5f71b161186423c200554bddb6fb7e43

URL: https://github.com/llvm/llvm-project/commit/bc17d32a5f71b161186423c200554bddb6fb7e43
DIFF: https://github.com/llvm/llvm-project/commit/bc17d32a5f71b161186423c200554bddb6fb7e43.diff

LOG: [LoopIdiom] Let LIR fold memset pointer / stride SCEV regarding loop guards

Expressions guarded at loop entry can be folded prior to comparison. This patch
follows D107353 and makes LIR able to deal with nested for-loops.

Reviewed By: qianzhen, bmahjour

Differential Revision: https://reviews.llvm.org/D108112

Added: 
    

Modified: 
    llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
    llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll
    llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll
    llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 42da86a9ecf50..e6efb422e7831 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -307,6 +307,25 @@ class LoopIdiomRecognizeLegacyPass : public LoopPass {
   }
 };
 
+// Rewrites sext to zext for expressions the loop entry guard proves are >= 0.
+class SCEVSignToZeroExtentionRewriter
+    : public SCEVRewriteVisitor<SCEVSignToZeroExtentionRewriter> {
+public:
+  ScalarEvolution &SE;
+  const Loop *CurLoop;
+  SCEVSignToZeroExtentionRewriter(ScalarEvolution &SE, const Loop *CurLoop)
+      : SCEVRewriteVisitor(SE), SE(SE), CurLoop(CurLoop) {}
+
+  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+    // If expression is guarded by CurLoop to be greater or equal to zero
+    // then convert sext to zext. Otherwise return the original expression.
+    if (SE.isLoopEntryGuardedByCond(CurLoop, ICmpInst::ICMP_SGE, Expr,
+                                    SE.getZero(Expr->getType())))
+      return SE.getZeroExtendExpr(visit(Expr->getOperand()), Expr->getType());
+    return Expr;
+  }
+};
+
 } // end anonymous namespace
 
 char LoopIdiomRecognizeLegacyPass::ID = 0;
@@ -967,12 +986,22 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
                       << "\n");
 
     if (PositiveStrideSCEV != MemsetSizeSCEV) {
-      // TODO: folding can be done to the SCEVs
-      // The folding is to fold expressions that is covered by the loop guard
-      // at loop entry. After the folding, compare again and proceed
-      // optimization if equal.
-      LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
-      return false;
+      // The folding is to fold an expression that is covered by the loop guard
+      // at loop entry. After the folding, compare again and proceed with
+      // optimization, if equal.
+      SCEVSignToZeroExtentionRewriter Folder(*SE, CurLoop);
+      const SCEV *FoldedPositiveStride = Folder.visit(PositiveStrideSCEV);
+      const SCEV *FoldedMemsetSize = Folder.visit(MemsetSizeSCEV);
+
+      LLVM_DEBUG(dbgs() << "  Try to fold SCEV based on loop guard\n"
+                        << "    FoldedMemsetSize: " << *FoldedMemsetSize << "\n"
+                        << "    FoldedPositiveStride: " << *FoldedPositiveStride
+                        << "\n");
+
+      if (FoldedPositiveStride != FoldedMemsetSize) {
+        LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
+        return false;
+      }
     }
   }
 

diff  --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll
index fe6ad07a1feba..69e96d53f2202 100644
--- a/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll
@@ -369,4 +369,52 @@ for.end17:                                        ; preds = %for.end17.loopexit,
   ret void
 }
 
+; void NegStart(int n, int m, int *ar) {
+;   for (int i = -100; i < n; i++) {
+;     int *arr = ar + (i + 100) * m;
+;     memset(arr, 0, m * sizeof(int));
+;   }
+; }
+define void @NegStart(i32 %n, i32 %m, i32* %ar) {
+; CHECK-LABEL: @NegStart(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AR1:%.*]] = bitcast i32* [[AR:%.*]] to i8*
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i32 -100, [[N:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END:%.*]]
+; CHECK:       for.body.lr.ph:
+; CHECK-NEXT:    [[MUL1:%.*]] = mul i32 [[M:%.*]], 4
+; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[N]], 100
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[M]], [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i32 [[TMP1]], 2
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i32(i8* align 4 [[AR1]], i8 0, i32 [[TMP2]], i1 false)
+; CHECK-NEXT:    br label [[FOR_END]]
+; CHECK:       for.end:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %cmp1 = icmp slt i32 -100, %n
+  br i1 %cmp1, label %for.body.lr.ph, label %for.end
+
+for.body.lr.ph:                                   ; preds = %entry
+  %mul1 = mul i32 %m, 4
+  br label %for.body
+
+for.body:                                         ; preds = %for.body.lr.ph, %for.body
+  %i.02 = phi i32 [ -100, %for.body.lr.ph ], [ %inc, %for.body ]
+  %add = add nsw i32 %i.02, 100
+  %mul = mul nsw i32 %add, %m
+  %add.ptr = getelementptr inbounds i32, i32* %ar, i32 %mul
+  %0 = bitcast i32* %add.ptr to i8*
+  call void @llvm.memset.p0i8.i32(i8* align 4 %0, i8 0, i32 %mul1, i1 false)
+  %inc = add nsw i32 %i.02, 1
+  %exitcond = icmp ne i32 %inc, %n
+  br i1 %exitcond, label %for.body, label %for.end.loopexit
+
+for.end.loopexit:                                 ; preds = %for.body
+  br label %for.end
+
+for.end:                                          ; preds = %for.end.loopexit, %entry
+  ret void
+}
+
 declare void @llvm.memset.p0i8.i32(i8* nocapture writeonly, i8, i32, i1 immarg)

diff  --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll
index f6d2441a85cc1..5d485b82ada79 100644
--- a/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll
@@ -268,6 +268,12 @@ for.body:                                         ; preds = %for.body.lr.ph, %fo
 for.end:                                          ; preds = %for.body, %entry
   ret void
 }
+; This case requires the SCEV folder (SCEVSignToZeroExtentionRewriter) in LoopIdiomRecognize.cpp to
+; fold the SCEVs prior to comparison. For the inner loop the folder is not needed; however, the
+; promoted memset size is based on the trip count of the inner loop (which is an unsigned integer).
+; Then, in the outer loop, the pointer stride SCEV for the memset needs to be converted based on the
+; loop guard for it to be equal to the memset size SCEV. The loop guard guarantees that m >= 0
+; inside the loop, so m can be converted from sext to zext, making the two SCEVs equal.
 ; void NestedFor32(int *ar, int n, int m, int o)
 ; {
 ;   int i, j;
@@ -281,6 +287,7 @@ for.end:                                          ; preds = %for.body, %entry
 define void @NestedFor32(i32* %ar, i32 %n, i32 %m, i32 %o) {
 ; CHECK-LABEL: @NestedFor32(
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AR2:%.*]] = bitcast i32* [[AR:%.*]] to i8*
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp slt i32 0, [[N:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP3]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END11:%.*]]
 ; CHECK:       for.body.lr.ph:
@@ -296,17 +303,10 @@ define void @NestedFor32(i32* %ar, i32 %n, i32 %m, i32 %o) {
 ; CHECK-NEXT:    [[TMP3:%.*]] = zext i32 [[M]] to i64
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP0]], [[TMP3]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = shl i64 [[TMP4]], 2
-; CHECK-NEXT:    br label [[FOR_BODY_US:%.*]]
-; CHECK:       for.body.us:
-; CHECK-NEXT:    [[INDVARS_IV6:%.*]] = phi i64 [ 0, [[FOR_BODY_US_PREHEADER]] ], [ [[INDVARS_IV_NEXT7:%.*]], [[FOR_BODY_US]] ]
-; CHECK-NEXT:    [[TMP6:%.*]] = mul i64 [[TMP2]], [[INDVARS_IV6]]
-; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i32, i32* [[AR:%.*]], i64 [[TMP6]]
-; CHECK-NEXT:    [[SCEVGEP1:%.*]] = bitcast i32* [[SCEVGEP]] to i8*
-; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[M]] to i64
-; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[SCEVGEP1]], i8 0, i64 [[TMP5]], i1 false)
-; CHECK-NEXT:    [[INDVARS_IV_NEXT7]] = add nuw nsw i64 [[INDVARS_IV6]], 1
-; CHECK-NEXT:    [[EXITCOND11:%.*]] = icmp ne i64 [[INDVARS_IV_NEXT7]], [[WIDE_TRIP_COUNT10]]
-; CHECK-NEXT:    br i1 [[EXITCOND11]], label [[FOR_BODY_US]], label [[FOR_END11]]
+; CHECK-NEXT:    [[TMP6:%.*]] = mul i64 [[TMP4]], [[WIDE_TRIP_COUNT10]]
+; CHECK-NEXT:    [[TMP7:%.*]] = shl i64 [[TMP6]], 2
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[AR2]], i8 0, i64 [[TMP7]], i1 false)
+; CHECK-NEXT:    br label [[FOR_END11]]
 ; CHECK:       for.end11:
 ; CHECK-NEXT:    ret void
 ;
@@ -357,4 +357,58 @@ for.end11:                                        ; preds = %for.end11.loopexit,
   ret void
 }
 
+; void NegStart(int n, int m, int *ar) {
+;   for (int i = -100; i < n; i++) {
+;     int *arr = ar + (i + 100) * m;
+;     memset(arr, 0, m * sizeof(int));
+;   }
+; }
+define void @NegStart(i32 %n, i32 %m, i32* %ar) {
+; CHECK-LABEL: @NegStart(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AR1:%.*]] = bitcast i32* [[AR:%.*]] to i8*
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i32 -100, [[N:%.*]]
+; CHECK-NEXT:    br i1 [[CMP1]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END:%.*]]
+; CHECK:       for.body.lr.ph:
+; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[M:%.*]] to i64
+; CHECK-NEXT:    [[MUL1:%.*]] = mul i64 [[CONV]], 4
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[M]] to i64
+; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = sext i32 [[N]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i64 [[WIDE_TRIP_COUNT]], 100
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 [[TMP1]], [[TMP0]]
+; CHECK-NEXT:    [[TMP3:%.*]] = shl i64 [[TMP2]], 2
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[AR1]], i8 0, i64 [[TMP3]], i1 false)
+; CHECK-NEXT:    br label [[FOR_END]]
+; CHECK:       for.end:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %cmp1 = icmp slt i32 -100, %n
+  br i1 %cmp1, label %for.body.lr.ph, label %for.end
+
+for.body.lr.ph:                                   ; preds = %entry
+  %conv = sext i32 %m to i64
+  %mul1 = mul i64 %conv, 4
+  %0 = sext i32 %m to i64
+  %wide.trip.count = sext i32 %n to i64
+  br label %for.body
+
+for.body:                                         ; preds = %for.body.lr.ph, %for.body
+  %indvars.iv = phi i64 [ -100, %for.body.lr.ph ], [ %indvars.iv.next, %for.body ]
+  %1 = add nsw i64 %indvars.iv, 100
+  %2 = mul nsw i64 %1, %0
+  %add.ptr = getelementptr inbounds i32, i32* %ar, i64 %2
+  %3 = bitcast i32* %add.ptr to i8*
+  call void @llvm.memset.p0i8.i64(i8* align 4 %3, i8 0, i64 %mul1, i1 false)
+  %indvars.iv.next = add nsw i64 %indvars.iv, 1
+  %exitcond = icmp ne i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond, label %for.body, label %for.end.loopexit
+
+for.end.loopexit:                                 ; preds = %for.body
+  br label %for.end
+
+for.end:                                          ; preds = %for.end.loopexit, %entry
+  ret void
+}
+
 declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg)

diff  --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
index 8ee554eb6d25e..95f9c969087ef 100644
--- a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
@@ -19,6 +19,9 @@
 ; CHECK-NEXT: memset size is non-constant
 ; CHECK-NEXT: MemsetSizeSCEV: (4 * (sext i32 %m to i64))<nsw>
 ; CHECK-NEXT: PositiveStrideSCEV: (4 + (4 * (sext i32 %m to i64))<nsw>)<nsw>
+; CHECK-NEXT: Try to fold SCEV based on loop guard
+; CHECK-NEXT: FoldedMemsetSize: (4 * (sext i32 %m to i64))<nsw>
+; CHECK-NEXT: FoldedPositiveStride: (4 + (4 * (sext i32 %m to i64))<nsw>)<nsw>
 ; CHECK-NEXT: SCEV don't match, abort
 ; CHECK: loop-idiom Scanning: F[NonZeroAddressSpace] Countable Loop %for.cond1.preheader
 ; CHECK-NEXT: memset size is non-constant


        


More information about the llvm-commits mailing list