[llvm] [SLP] no need to generate extract for in-tree uses for original scala… (PR #76077)

via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 26 18:12:58 PST 2023


https://github.com/Enna1 updated https://github.com/llvm/llvm-project/pull/76077

>From 335b5b8fa274164cb84e35ecabbd5e1c38b278c7 Mon Sep 17 00:00:00 2001
From: "xumingjie.enna1" <xumingjie.enna1 at bytedance.com>
Date: Thu, 21 Dec 2023 15:23:00 +0800
Subject: [PATCH 1/3] [SLP] no need to generate extract for in-tree uses for
 original scalar instruction.

Before https://github.com/llvm/llvm-project/commit/77a609b55636dc540090ef9105c60a99cfdbd1dd,
we always skipped in-tree uses of the vectorized scalars in `buildExternalUses()`.
That commit handles the case where the in-tree use is a scalar operand in a vectorized instruction;
in that case we need to generate an extract for the in-tree use.

In-tree uses that remain scalar in vectorized instructions fall into 3 cases:
- The pointer operand of vectorized LoadInst uses an in-tree scalar
- The pointer operand of vectorized StoreInst uses an in-tree scalar
- The scalar argument of vector form intrinsic uses an in-tree scalar

Generating extracts for in-tree uses in vectorized instructions is implemented in `BoUpSLP::vectorizeTree()`:
- https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp#L11497-L11506
- https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp#L11542-L11551
- https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp#L11657-L11667

However, https://github.com/llvm/llvm-project/commit/77a609b55636dc540090ef9105c60a99cfdbd1dd
not only generates extracts for vectorized instructions,
but also generates extracts for the original scalar instructions.
There is no need to generate extracts for the original scalar instructions,
as these scalar instructions will be replaced by vector instructions and erased later.

This patch replaces extracts for original scalar instructions with extracts for the corresponding vectorized instructions,
and removes the extracts generated at:
- https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp#L11497-L11506
- https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp#L11542-L11551
- https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp#L11657-L11667
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 85 +++++++------------
 .../SLPVectorizer/X86/extract_in_tree_user.ll | 34 ++++----
 .../X86/reorder-reused-masked-gather2.ll      | 18 ++--
 3 files changed, 57 insertions(+), 80 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 32913b3f55697e..01eb624ab3bff9 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -4925,36 +4925,34 @@ void BoUpSLP::buildExternalUses(
         LLVM_DEBUG(dbgs() << "SLP: Checking user:" << *U << ".\n");
 
         Instruction *UserInst = dyn_cast<Instruction>(U);
-        if (!UserInst)
+        if (!UserInst || isDeleted(UserInst))
           continue;
 
-        if (isDeleted(UserInst))
+        // Ignore users in the user ignore list.
+        if (UserIgnoreList && UserIgnoreList->contains(UserInst))
           continue;
 
         // Skip in-tree scalars that become vectors
         if (TreeEntry *UseEntry = getTreeEntry(U)) {
-          Value *UseScalar = UseEntry->Scalars[0];
           // Some in-tree scalars will remain as scalar in vectorized
-          // instructions. If that is the case, the one in Lane 0 will
+          // instructions. If that is the case, the one in FoundLane will
           // be used.
-          if (UseScalar != U ||
-              UseEntry->State == TreeEntry::ScatterVectorize ||
+          if (UseEntry->State == TreeEntry::ScatterVectorize ||
               UseEntry->State == TreeEntry::PossibleStridedVectorize ||
-              !doesInTreeUserNeedToExtract(Scalar, UserInst, TLI)) {
+              !doesInTreeUserNeedToExtract(
+                  Scalar, cast<Instruction>(UseEntry->Scalars.front()), TLI)) {
             LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U
                               << ".\n");
             assert(UseEntry->State != TreeEntry::NeedToGather && "Bad state");
             continue;
           }
+          U = nullptr;
         }
 
-        // Ignore users in the user ignore list.
-        if (UserIgnoreList && UserIgnoreList->contains(UserInst))
-          continue;
-
-        LLVM_DEBUG(dbgs() << "SLP: Need to extract:" << *U << " from lane "
-                          << Lane << " from " << *Scalar << ".\n");
-        ExternalUses.push_back(ExternalUser(Scalar, U, FoundLane));
+        LLVM_DEBUG(dbgs() << "SLP: Need to extract:" << *UserInst
+                          << " from lane " << Lane << " from " << *Scalar
+                          << ".\n");
+        ExternalUses.emplace_back(Scalar, U, FoundLane);
       }
     }
   }
@@ -11493,17 +11491,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       Value *PO = LI->getPointerOperand();
       if (E->State == TreeEntry::Vectorize) {
         NewLI = Builder.CreateAlignedLoad(VecTy, PO, LI->getAlign());
-
-        // The pointer operand uses an in-tree scalar so we add the new
-        // LoadInst to ExternalUses list to make sure that an extract will
-        // be generated in the future.
-        if (isa<Instruction>(PO)) {
-          if (TreeEntry *Entry = getTreeEntry(PO)) {
-            // Find which lane we need to extract.
-            unsigned FoundLane = Entry->findLaneForValue(PO);
-            ExternalUses.emplace_back(PO, NewLI, FoundLane);
-          }
-        }
       } else {
         assert((E->State == TreeEntry::ScatterVectorize ||
                 E->State == TreeEntry::PossibleStridedVectorize) &&
@@ -11539,17 +11526,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       StoreInst *ST =
           Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign());
 
-      // The pointer operand uses an in-tree scalar, so add the new StoreInst to
-      // ExternalUses to make sure that an extract will be generated in the
-      // future.
-      if (isa<Instruction>(Ptr)) {
-        if (TreeEntry *Entry = getTreeEntry(Ptr)) {
-          // Find which lane we need to extract.
-          unsigned FoundLane = Entry->findLaneForValue(Ptr);
-          ExternalUses.push_back(ExternalUser(Ptr, ST, FoundLane));
-        }
-      }
-
       Value *V = propagateMetadata(ST, E->Scalars);
 
       E->VectorizedValue = V;
@@ -11654,18 +11630,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       CI->getOperandBundlesAsDefs(OpBundles);
       Value *V = Builder.CreateCall(CF, OpVecs, OpBundles);
 
-      // The scalar argument uses an in-tree scalar so we add the new vectorized
-      // call to ExternalUses list to make sure that an extract will be
-      // generated in the future.
-      if (isa_and_present<Instruction>(ScalarArg)) {
-        if (TreeEntry *Entry = getTreeEntry(ScalarArg)) {
-          // Find which lane we need to extract.
-          unsigned FoundLane = Entry->findLaneForValue(ScalarArg);
-          ExternalUses.push_back(
-              ExternalUser(ScalarArg, cast<User>(V), FoundLane));
-        }
-      }
-
       propagateIRFlags(V, E->Scalars, VL0);
       V = FinalShuffle(V, E, VecTy, IsSigned);
 
@@ -11877,6 +11841,7 @@ Value *BoUpSLP::vectorizeTree(
   DenseMap<Value *, DenseMap<BasicBlock *, Instruction *>> ScalarToEEs;
   SmallDenseSet<Value *, 4> UsedInserts;
   DenseMap<Value *, Value *> VectorCasts;
+  SmallDenseSet<Value *, 4> ScalarsWithNullptrUser;
   // Extract all of the elements with the external uses.
   for (const auto &ExternalUse : ExternalUses) {
     Value *Scalar = ExternalUse.Scalar;
@@ -11947,13 +11912,25 @@ Value *BoUpSLP::vectorizeTree(
       VectorToInsertElement.try_emplace(Vec, IE);
       return Vec;
     };
-    // If User == nullptr, the Scalar is used as extra arg. Generate
-    // ExtractElement instruction and update the record for this scalar in
-    // ExternallyUsedValues.
+    // If User == nullptr, the Scalar remains as scalar in vectorized
+    // instructions or is used as extra arg. Generate ExtractElement instruction
+    // and update the record for this scalar in ExternallyUsedValues.
     if (!User) {
-      assert(ExternallyUsedValues.count(Scalar) &&
-             "Scalar with nullptr as an external user must be registered in "
-             "ExternallyUsedValues map");
+      if (!ScalarsWithNullptrUser.insert(Scalar).second)
+        continue;
+      assert((ExternallyUsedValues.count(Scalar) ||
+              any_of(Scalar->users(),
+                     [this, Scalar](llvm::User *U) {
+                       TreeEntry *UseEntry = getTreeEntry(U);
+                       return UseEntry &&
+                              doesInTreeUserNeedToExtract(
+                                  Scalar,
+                                  cast<Instruction>(UseEntry->Scalars.front()),
+                                  TLI);
+                     })) &&
+             "Scalar with nullptr User must be registered in "
+             "ExternallyUsedValues map or remain as scalar in vectorized "
+             "instructions");
       if (auto *VecI = dyn_cast<Instruction>(Vec)) {
         if (auto *PHI = dyn_cast<PHINode>(VecI))
           Builder.SetInsertPoint(PHI->getParent(),
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/extract_in_tree_user.ll b/llvm/test/Transforms/SLPVectorizer/X86/extract_in_tree_user.ll
index be6b0bc47c0253..096f57d100a50f 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/extract_in_tree_user.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/extract_in_tree_user.ll
@@ -11,11 +11,11 @@ define i32 @fn1() {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = load ptr, ptr @a, align 8
 ; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x ptr> poison, ptr [[TMP0]], i32 0
-; CHECK-NEXT:    [[SHUFFLE:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i64, <2 x ptr> [[SHUFFLE]], <2 x i64> <i64 11, i64 56>
-; CHECK-NEXT:    [[TMP3:%.*]] = ptrtoint <2 x ptr> [[TMP2]] to <2 x i64>
-; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x ptr> [[TMP2]], i32 0
-; CHECK-NEXT:    store <2 x i64> [[TMP3]], ptr [[TMP4]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i64, <2 x ptr> [[TMP2]], <2 x i64> <i64 11, i64 56>
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x ptr> [[TMP3]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = ptrtoint <2 x ptr> [[TMP3]] to <2 x i64>
+; CHECK-NEXT:    store <2 x i64> [[TMP5]], ptr [[TMP4]], align 8
 ; CHECK-NEXT:    ret i32 undef
 ;
 entry:
@@ -34,13 +34,13 @@ declare float @llvm.powi.f32.i32(float, i32)
 define void @fn2(ptr %a, ptr %b, ptr %c) {
 ; CHECK-LABEL: @fn2(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i32>, ptr [[A:%.*]], align 4
-; CHECK-NEXT:    [[TMP3:%.*]] = load <4 x i32>, ptr [[B:%.*]], align 4
-; CHECK-NEXT:    [[TMP4:%.*]] = add <4 x i32> [[TMP1]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = sitofp <4 x i32> [[TMP4]] to <4 x float>
-; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x i32> [[TMP4]], i32 0
-; CHECK-NEXT:    [[TMP7:%.*]] = call <4 x float> @llvm.powi.v4f32.i32(<4 x float> [[TMP5]], i32 [[TMP6]])
-; CHECK-NEXT:    store <4 x float> [[TMP7]], ptr [[C:%.*]], align 4
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[A:%.*]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i32>, ptr [[B:%.*]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = add <4 x i32> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x i32> [[TMP2]], i32 0
+; CHECK-NEXT:    [[TMP4:%.*]] = sitofp <4 x i32> [[TMP2]] to <4 x float>
+; CHECK-NEXT:    [[TMP5:%.*]] = call <4 x float> @llvm.powi.v4f32.i32(<4 x float> [[TMP4]], i32 [[TMP3]])
+; CHECK-NEXT:    store <4 x float> [[TMP5]], ptr [[C:%.*]], align 4
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -90,12 +90,12 @@ define void @externally_used_ptrs() {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = load ptr, ptr @a, align 8
 ; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x ptr> poison, ptr [[TMP0]], i32 0
-; CHECK-NEXT:    [[SHUFFLE:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i64, <2 x ptr> [[SHUFFLE]], <2 x i64> <i64 56, i64 11>
-; CHECK-NEXT:    [[TMP3:%.*]] = ptrtoint <2 x ptr> [[TMP2]] to <2 x i64>
-; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x ptr> [[TMP2]], i32 1
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i64, <2 x ptr> [[TMP2]], <2 x i64> <i64 56, i64 11>
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x ptr> [[TMP3]], i32 1
+; CHECK-NEXT:    [[TMP5:%.*]] = ptrtoint <2 x ptr> [[TMP3]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP6:%.*]] = load <2 x i64>, ptr [[TMP4]], align 8
-; CHECK-NEXT:    [[TMP7:%.*]] = add <2 x i64> [[TMP3]], [[TMP6]]
+; CHECK-NEXT:    [[TMP7:%.*]] = add <2 x i64> [[TMP5]], [[TMP6]]
 ; CHECK-NEXT:    store <2 x i64> [[TMP7]], ptr [[TMP4]], align 8
 ; CHECK-NEXT:    ret void
 ;
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder-reused-masked-gather2.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder-reused-masked-gather2.ll
index 75431c13a7703a..ddc2a1b819041f 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder-reused-masked-gather2.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder-reused-masked-gather2.ll
@@ -9,15 +9,15 @@ define void @"foo"(ptr addrspace(1) %0, ptr addrspace(1) %1) #0 {
 ; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x ptr addrspace(1)> poison, ptr addrspace(1) [[TMP0:%.*]], i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <4 x ptr addrspace(1)> [[TMP3]], <4 x ptr addrspace(1)> poison, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr i8, <4 x ptr addrspace(1)> [[TMP4]], <4 x i64> <i64 8, i64 12, i64 28, i64 24>
-; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds i8, ptr addrspace(1) [[TMP1:%.*]], i64 8
-; CHECK-NEXT:    [[TMP7:%.*]] = call <4 x float> @llvm.masked.gather.v4f32.v4p1(<4 x ptr addrspace(1)> [[TMP5]], i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x float> poison)
-; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x float> [[TMP7]], <4 x float> poison, <8 x i32> <i32 0, i32 3, i32 0, i32 3, i32 2, i32 1, i32 2, i32 1>
-; CHECK-NEXT:    [[TMP9:%.*]] = load <8 x float>, ptr addrspace(1) [[TMP6]], align 4
-; CHECK-NEXT:    [[TMP10:%.*]] = fmul <8 x float> [[TMP8]], [[TMP9]]
-; CHECK-NEXT:    [[TMP11:%.*]] = fadd <8 x float> [[TMP10]], zeroinitializer
-; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <8 x float> [[TMP11]], <8 x float> poison, <8 x i32> <i32 0, i32 5, i32 2, i32 7, i32 4, i32 1, i32 6, i32 3>
-; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <4 x ptr addrspace(1)> [[TMP5]], i32 0
-; CHECK-NEXT:    store <8 x float> [[TMP12]], ptr addrspace(1) [[TMP13]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x ptr addrspace(1)> [[TMP5]], i32 0
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i8, ptr addrspace(1) [[TMP1:%.*]], i64 8
+; CHECK-NEXT:    [[TMP8:%.*]] = call <4 x float> @llvm.masked.gather.v4f32.v4p1(<4 x ptr addrspace(1)> [[TMP5]], i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x float> poison)
+; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <4 x float> [[TMP8]], <4 x float> poison, <8 x i32> <i32 0, i32 3, i32 0, i32 3, i32 2, i32 1, i32 2, i32 1>
+; CHECK-NEXT:    [[TMP10:%.*]] = load <8 x float>, ptr addrspace(1) [[TMP7]], align 4
+; CHECK-NEXT:    [[TMP11:%.*]] = fmul <8 x float> [[TMP9]], [[TMP10]]
+; CHECK-NEXT:    [[TMP12:%.*]] = fadd <8 x float> [[TMP11]], zeroinitializer
+; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <8 x float> [[TMP12]], <8 x float> poison, <8 x i32> <i32 0, i32 5, i32 2, i32 7, i32 4, i32 1, i32 6, i32 3>
+; CHECK-NEXT:    store <8 x float> [[TMP13]], ptr addrspace(1) [[TMP6]], align 4
 ; CHECK-NEXT:    ret void
 ;
   %3 = getelementptr inbounds i8, ptr addrspace(1) %0, i64 8

>From 721914da49ffa2111ca4e7f9199b8038c6907260 Mon Sep 17 00:00:00 2001
From: "xumingjie.enna1" <xumingjie.enna1 at bytedance.com>
Date: Wed, 27 Dec 2023 10:07:48 +0800
Subject: [PATCH 2/3] update from comment

---
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 01eb624ab3bff9..eecaf5d9710656 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -11920,9 +11920,9 @@ Value *BoUpSLP::vectorizeTree(
         continue;
       assert((ExternallyUsedValues.count(Scalar) ||
               any_of(Scalar->users(),
-                     [this, Scalar](llvm::User *U) {
+                     [&](llvm::User *U) {
                        TreeEntry *UseEntry = getTreeEntry(U);
-                       return UseEntry &&
+                       return UseEntry && UseEntry->State == TreeEntry::Vectorize &&
                               doesInTreeUserNeedToExtract(
                                   Scalar,
                                   cast<Instruction>(UseEntry->Scalars.front()),

>From 6533e84296b9338a54cddb48bbc10b67f00a5c47 Mon Sep 17 00:00:00 2001
From: "xumingjie.enna1" <xumingjie.enna1 at bytedance.com>
Date: Wed, 27 Dec 2023 10:12:39 +0800
Subject: [PATCH 3/3] clang-format

---
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index eecaf5d9710656..01085314bcbc5c 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -11922,7 +11922,8 @@ Value *BoUpSLP::vectorizeTree(
               any_of(Scalar->users(),
                      [&](llvm::User *U) {
                        TreeEntry *UseEntry = getTreeEntry(U);
-                       return UseEntry && UseEntry->State == TreeEntry::Vectorize &&
+                       return UseEntry &&
+                              UseEntry->State == TreeEntry::Vectorize &&
                               doesInTreeUserNeedToExtract(
                                   Scalar,
                                   cast<Instruction>(UseEntry->Scalars.front()),



More information about the llvm-commits mailing list