[llvm] 9ee8e38 - [VPlan] Also propagate versioned strides to users via sext/zext.

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 26 13:30:10 PDT 2024


Author: Florian Hahn
Date: 2024-04-26T21:29:43+01:00
New Revision: 9ee8e38cdcc6925a4127d44a0360dc8de23dfb5f

URL: https://github.com/llvm/llvm-project/commit/9ee8e38cdcc6925a4127d44a0360dc8de23dfb5f
DIFF: https://github.com/llvm/llvm-project/commit/9ee8e38cdcc6925a4127d44a0360dc8de23dfb5f.diff

LOG: [VPlan] Also propagate versioned strides to users via sext/zext.

The versioned value may not be used in the loop directly but through a
sext/zext. Add new live-ins in those cases.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
    llvm/lib/Transforms/Vectorize/VPlan.h
    llvm/test/Transforms/LoopVectorize/runtime-checks-hoist.ll
    llvm/test/Transforms/LoopVectorize/version-stride-with-integer-casts.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 33c4decd58a6c2..74ceb9eecf2130 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8817,12 +8817,24 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
     // Only handle constant strides for now.
     if (!ScevStride)
       continue;
-    Constant *CI = ConstantInt::get(Stride->getType(), ScevStride->getAPInt());
 
-    auto *ConstVPV = Plan->getOrAddLiveIn(CI);
-    // The versioned value may not be used in the loop directly, so just add a
-    // new live-in in those cases.
-    Plan->getOrAddLiveIn(StrideV)->replaceAllUsesWith(ConstVPV);
+    auto *CI = Plan->getOrAddLiveIn(
+        ConstantInt::get(Stride->getType(), ScevStride->getAPInt()));
+    if (VPValue *StrideVPV = Plan->getLiveIn(StrideV))
+      StrideVPV->replaceAllUsesWith(CI);
+
+    // The versioned value may not be used in the loop directly but through a
+    // sext/zext. Add new live-ins in those cases.
+    for (Value *U : StrideV->users()) {
+      if (!isa<SExtInst, ZExtInst>(U))
+        continue;
+      VPValue *StrideVPV = Plan->getLiveIn(U);
+      if (!StrideVPV)
+        continue;
+      VPValue *CI = Plan->getOrAddLiveIn(ConstantInt::get(
+          U->getType(), ScevStride->getAPInt().getSExtValue()));
+      StrideVPV->replaceAllUsesWith(CI);
+    }
   }
 
   VPlanTransforms::dropPoisonGeneratingRecipes(*Plan, [this](BasicBlock *BB) {

diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 21b088cd238660..71387bf5b7e929 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -3215,6 +3215,9 @@ class VPlan {
     return Value2VPValue[V];
   }
 
+  /// Return the live-in VPValue for \p V, if there is one or nullptr otherwise.
+  VPValue *getLiveIn(Value *V) const { return Value2VPValue.lookup(V); }
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print the live-ins of this VPlan to \p O.
   void printLiveIns(raw_ostream &O) const;

diff --git a/llvm/test/Transforms/LoopVectorize/runtime-checks-hoist.ll b/llvm/test/Transforms/LoopVectorize/runtime-checks-hoist.ll
index 0b9b592627c657..c4f9c404a9265b 100644
--- a/llvm/test/Transforms/LoopVectorize/runtime-checks-hoist.ll
+++ b/llvm/test/Transforms/LoopVectorize/runtime-checks-hoist.ll
@@ -1328,13 +1328,11 @@ define void @unknown_inner_stride(ptr nocapture noundef %dst, ptr nocapture noun
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[TMP14:%.*]] = add i64 [[INDEX]], 0
-; CHECK-NEXT:    [[TMP15:%.*]] = mul nsw i64 [[TMP14]], [[TMP0]]
-; CHECK-NEXT:    [[TMP16:%.*]] = add nsw i64 [[TMP15]], [[TMP11]]
+; CHECK-NEXT:    [[TMP16:%.*]] = add nsw i64 [[TMP14]], [[TMP11]]
 ; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr inbounds i32, ptr [[SRC]], i64 [[TMP16]]
 ; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i32, ptr [[TMP17]], i32 0
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i32>, ptr [[TMP18]], align 4, !alias.scope [[META60:![0-9]+]]
-; CHECK-NEXT:    [[TMP19:%.*]] = mul nsw i64 [[TMP14]], [[TMP1]]
-; CHECK-NEXT:    [[TMP20:%.*]] = add nsw i64 [[TMP19]], [[TMP12]]
+; CHECK-NEXT:    [[TMP20:%.*]] = add nsw i64 [[TMP14]], [[TMP12]]
 ; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr inbounds i32, ptr [[DST]], i64 [[TMP20]]
 ; CHECK-NEXT:    [[TMP22:%.*]] = getelementptr inbounds i32, ptr [[TMP21]], i32 0
 ; CHECK-NEXT:    [[WIDE_LOAD3:%.*]] = load <4 x i32>, ptr [[TMP22]], align 4, !alias.scope [[META63:![0-9]+]], !noalias [[META60]]

diff --git a/llvm/test/Transforms/LoopVectorize/version-stride-with-integer-casts.ll b/llvm/test/Transforms/LoopVectorize/version-stride-with-integer-casts.ll
index 693a8b2876153c..d09066fa2d7042 100644
--- a/llvm/test/Transforms/LoopVectorize/version-stride-with-integer-casts.ll
+++ b/llvm/test/Transforms/LoopVectorize/version-stride-with-integer-casts.ll
@@ -34,7 +34,7 @@ define void @test_versioned_with_sext_use(i32 %offset, ptr %dst) {
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr i32, ptr [[DST]], i64 [[TMP3]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr i32, ptr [[TMP4]], i32 0
 ; CHECK-NEXT:    store <4 x i32> zeroinitializer, ptr [[TMP5]], align 8
-; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[TMP3]], [[OFFSET_EXT]]
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[TMP3]], 1
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], 200
 ; CHECK-NEXT:    br i1 [[TMP7]], label [[MIDDLE_BLOCK]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
@@ -110,7 +110,7 @@ define void @test_versioned_with_zext_use(i32 %offset, ptr %dst) {
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr i32, ptr [[DST]], i64 [[TMP3]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr i32, ptr [[TMP4]], i32 0
 ; CHECK-NEXT:    store <4 x i32> zeroinitializer, ptr [[TMP5]], align 8
-; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[TMP3]], [[OFFSET_EXT]]
+; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[TMP3]], 1
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], 200
 ; CHECK-NEXT:    br i1 [[TMP7]], label [[MIDDLE_BLOCK]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
@@ -165,38 +165,23 @@ define void @versioned_sext_use_in_gep(i32 %scale, ptr %dst, i64 %scale.2) {
 ; CHECK-NEXT:    [[IDENT_CHECK:%.*]] = icmp ne i32 [[SCALE]], 1
 ; CHECK-NEXT:    br i1 [[IDENT_CHECK]], label [[SCALAR_PH]], label [[VECTOR_PH:%.*]]
 ; CHECK:       vector.ph:
-; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x i64> poison, i64 [[SCALE_EXT]], i64 0
-; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i64> [[BROADCAST_SPLATINSERT]], <4 x i64> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <4 x i64> poison, i64 [[SCALE_2]], i64 0
-; CHECK-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <4 x i64> [[BROADCAST_SPLATINSERT1]], <4 x i64> poison, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = mul <4 x i64> [[VEC_IND]], [[BROADCAST_SPLAT]]
-; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x i64> [[TMP0]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i8, ptr [[DST]], i64 [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x i64> [[TMP0]], i32 1
-; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr i8, ptr [[DST]], i64 [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = extractelement <4 x i64> [[TMP0]], i32 2
-; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[DST]], i64 [[TMP5]]
-; CHECK-NEXT:    [[TMP7:%.*]] = extractelement <4 x i64> [[TMP0]], i32 3
-; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[DST]], i64 [[TMP7]]
-; CHECK-NEXT:    [[TMP9:%.*]] = mul <4 x i64> [[BROADCAST_SPLAT]], [[BROADCAST_SPLAT2]]
-; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <4 x i64> [[TMP9]], i32 0
+; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[INDEX]], 0
+; CHECK-NEXT:    [[TMP12:%.*]] = add i64 [[INDEX]], 1
+; CHECK-NEXT:    [[TMP14:%.*]] = add i64 [[INDEX]], 2
+; CHECK-NEXT:    [[TMP16:%.*]] = add i64 [[INDEX]], 3
 ; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[DST]], i64 [[TMP10]]
-; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <4 x i64> [[TMP9]], i32 1
 ; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[DST]], i64 [[TMP12]]
-; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <4 x i64> [[TMP9]], i32 2
 ; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr i8, ptr [[DST]], i64 [[TMP14]]
-; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <4 x i64> [[TMP9]], i32 3
 ; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[DST]], i64 [[TMP16]]
-; CHECK-NEXT:    store ptr [[TMP11]], ptr [[TMP2]], align 8
-; CHECK-NEXT:    store ptr [[TMP13]], ptr [[TMP4]], align 8
-; CHECK-NEXT:    store ptr [[TMP15]], ptr [[TMP6]], align 8
-; CHECK-NEXT:    store ptr [[TMP17]], ptr [[TMP8]], align 8
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[DST]], i64 [[SCALE_2]]
+; CHECK-NEXT:    store ptr [[TMP8]], ptr [[TMP11]], align 8
+; CHECK-NEXT:    store ptr [[TMP8]], ptr [[TMP13]], align 8
+; CHECK-NEXT:    store ptr [[TMP8]], ptr [[TMP15]], align 8
+; CHECK-NEXT:    store ptr [[TMP8]], ptr [[TMP17]], align 8
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
-; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <4 x i64> [[VEC_IND]], <i64 4, i64 4, i64 4, i64 4>
 ; CHECK-NEXT:    [[TMP18:%.*]] = icmp eq i64 [[INDEX_NEXT]], 256
 ; CHECK-NEXT:    br i1 [[TMP18]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP6:![0-9]+]]
 ; CHECK:       middle.block:
@@ -282,7 +267,7 @@ define void @test_versioned_with_different_uses(i32 %offset, ptr noalias %dst.1,
 ; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr i32, ptr [[DST_2]], i64 [[TMP3]]
 ; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i32, ptr [[TMP12]], i32 0
 ; CHECK-NEXT:    store <4 x i32> zeroinitializer, ptr [[TMP13]], align 8
-; CHECK-NEXT:    [[TMP14:%.*]] = add i64 [[TMP3]], [[OFFSET_EXT]]
+; CHECK-NEXT:    [[TMP14:%.*]] = add i64 [[TMP3]], 1
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
 ; CHECK-NEXT:    [[TMP15:%.*]] = icmp eq i64 [[INDEX_NEXT]], 200
 ; CHECK-NEXT:    br i1 [[TMP15]], label [[MIDDLE_BLOCK]], label [[VECTOR_BODY]], !llvm.loop [[LOOP8:![0-9]+]]
@@ -376,8 +361,7 @@ define void @test_versioned_with_non_ex_use(i32 %offset, ptr noalias %dst.1, ptr
 ; CHECK-NEXT:    store i32 0, ptr [[TMP14]], align 8
 ; CHECK-NEXT:    store i32 0, ptr [[TMP16]], align 8
 ; CHECK-NEXT:    store i32 0, ptr [[TMP18]], align 8
-; CHECK-NEXT:    [[TMP19:%.*]] = mul i64 [[TMP9]], [[OFFSET_EXT]]
-; CHECK-NEXT:    [[TMP20:%.*]] = getelementptr i32, ptr [[DST_2]], i64 [[TMP19]]
+; CHECK-NEXT:    [[TMP20:%.*]] = getelementptr i32, ptr [[DST_2]], i64 [[TMP9]]
 ; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i32, ptr [[TMP20]], i32 0
 ; CHECK-NEXT:    store <4 x i32> zeroinitializer, ptr [[TMP21]], align 8
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4


        


More information about the llvm-commits mailing list