[llvm] [SLP] no need to generate extract for in-tree uses for original scala… (PR #76077)

via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 25 07:56:54 PST 2023


https://github.com/Enna1 updated https://github.com/llvm/llvm-project/pull/76077

>From dd0dd255e65df5713847de74eee7c5c15a6f8a69 Mon Sep 17 00:00:00 2001
From: "xumingjie.enna1" <xumingjie.enna1 at bytedance.com>
Date: Thu, 21 Dec 2023 15:23:00 +0800
Subject: [PATCH] [SLP] no need to generate extract for in-tree uses for
 original scalar instruction.

Before https://github.com/llvm/llvm-project/commit/77a609b55636dc540090ef9105c60a99cfdbd1dd,
we always skipped in-tree uses of the vectorized scalars in `buildExternalUses()`.
That commit handles the case where the in-tree use is a scalar operand of a vectorized instruction,
in which case we need to generate an extract for that in-tree use.

In-tree uses that remain scalar in vectorized instructions fall into 3 cases:
- The pointer operand of vectorized LoadInst uses an in-tree scalar
- The pointer operand of vectorized StoreInst uses an in-tree scalar
- The scalar argument of vector form intrinsic uses an in-tree scalar

Generating extracts for in-tree uses in vectorized instructions is implemented in `BoUpSLP::vectorizeTree()`:
- https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp#L11497-L11506
- https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp#L11542-L11551
- https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp#L11657-L11667

However, https://github.com/llvm/llvm-project/commit/77a609b55636dc540090ef9105c60a99cfdbd1dd
not only generates extracts for vectorized instructions,
but also generates extracts for the original scalar instructions.
There is no need to generate extracts for the original scalar instructions,
as these scalar instructions will be replaced by vector instructions and erased later.

This patch replaces extracts for original scalar instructions with corresponding vectorized instructions,
and removes the extracts generated at:
- https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp#L11497-L11506
- https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp#L11542-L11551
- https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp#L11657-L11667
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 86 +++++++------------
 .../SLPVectorizer/X86/extract_in_tree_user.ll | 34 ++++----
 .../X86/reorder-reused-masked-gather2.ll      | 18 ++--
 3 files changed, 58 insertions(+), 80 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 32913b3f55697e..bf397e41b2cc28 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -4925,36 +4925,38 @@ void BoUpSLP::buildExternalUses(
         LLVM_DEBUG(dbgs() << "SLP: Checking user:" << *U << ".\n");
 
         Instruction *UserInst = dyn_cast<Instruction>(U);
-        if (!UserInst)
+        if (!UserInst || isDeleted(UserInst))
           continue;
 
-        if (isDeleted(UserInst))
+        // Ignore users in the user ignore list.
+        if (UserIgnoreList && UserIgnoreList->contains(UserInst))
           continue;
 
         // Skip in-tree scalars that become vectors
         if (TreeEntry *UseEntry = getTreeEntry(U)) {
-          Value *UseScalar = UseEntry->Scalars[0];
           // Some in-tree scalars will remain as scalar in vectorized
-          // instructions. If that is the case, the one in Lane 0 will
+          // instructions. If that is the case, the one in FoundLane will
           // be used.
-          if (UseScalar != U ||
-              UseEntry->State == TreeEntry::ScatterVectorize ||
+          if (UseEntry->State == TreeEntry::ScatterVectorize ||
               UseEntry->State == TreeEntry::PossibleStridedVectorize ||
-              !doesInTreeUserNeedToExtract(Scalar, UserInst, TLI)) {
+              !doesInTreeUserNeedToExtract(
+                  Scalar, cast<Instruction>(UseEntry->Scalars.front()), TLI)) {
             LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U
                               << ".\n");
             assert(UseEntry->State != TreeEntry::NeedToGather && "Bad state");
             continue;
           }
+          U = nullptr;
         }
 
-        // Ignore users in the user ignore list.
-        if (UserIgnoreList && UserIgnoreList->contains(UserInst))
-          continue;
-
-        LLVM_DEBUG(dbgs() << "SLP: Need to extract:" << *U << " from lane "
-                          << Lane << " from " << *Scalar << ".\n");
-        ExternalUses.push_back(ExternalUser(Scalar, U, FoundLane));
+        LLVM_DEBUG(dbgs() << "SLP: Need to extract:" << *UserInst
+                          << " from lane " << Lane << " from " << *Scalar
+                          << ".\n");
+        ExternalUses.emplace_back(Scalar, U, FoundLane);
+        // If U == nullptr, all uses of this Scalar will be automatically
+        // replaced by extractelement.
+        if (!U)
+          break;
       }
     }
   }
@@ -11493,17 +11495,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       Value *PO = LI->getPointerOperand();
       if (E->State == TreeEntry::Vectorize) {
         NewLI = Builder.CreateAlignedLoad(VecTy, PO, LI->getAlign());
-
-        // The pointer operand uses an in-tree scalar so we add the new
-        // LoadInst to ExternalUses list to make sure that an extract will
-        // be generated in the future.
-        if (isa<Instruction>(PO)) {
-          if (TreeEntry *Entry = getTreeEntry(PO)) {
-            // Find which lane we need to extract.
-            unsigned FoundLane = Entry->findLaneForValue(PO);
-            ExternalUses.emplace_back(PO, NewLI, FoundLane);
-          }
-        }
       } else {
         assert((E->State == TreeEntry::ScatterVectorize ||
                 E->State == TreeEntry::PossibleStridedVectorize) &&
@@ -11539,17 +11530,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       StoreInst *ST =
           Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign());
 
-      // The pointer operand uses an in-tree scalar, so add the new StoreInst to
-      // ExternalUses to make sure that an extract will be generated in the
-      // future.
-      if (isa<Instruction>(Ptr)) {
-        if (TreeEntry *Entry = getTreeEntry(Ptr)) {
-          // Find which lane we need to extract.
-          unsigned FoundLane = Entry->findLaneForValue(Ptr);
-          ExternalUses.push_back(ExternalUser(Ptr, ST, FoundLane));
-        }
-      }
-
       Value *V = propagateMetadata(ST, E->Scalars);
 
       E->VectorizedValue = V;
@@ -11654,18 +11634,6 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       CI->getOperandBundlesAsDefs(OpBundles);
       Value *V = Builder.CreateCall(CF, OpVecs, OpBundles);
 
-      // The scalar argument uses an in-tree scalar so we add the new vectorized
-      // call to ExternalUses list to make sure that an extract will be
-      // generated in the future.
-      if (isa_and_present<Instruction>(ScalarArg)) {
-        if (TreeEntry *Entry = getTreeEntry(ScalarArg)) {
-          // Find which lane we need to extract.
-          unsigned FoundLane = Entry->findLaneForValue(ScalarArg);
-          ExternalUses.push_back(
-              ExternalUser(ScalarArg, cast<User>(V), FoundLane));
-        }
-      }
-
       propagateIRFlags(V, E->Scalars, VL0);
       V = FinalShuffle(V, E, VecTy, IsSigned);
 
@@ -11947,13 +11915,23 @@ Value *BoUpSLP::vectorizeTree(
       VectorToInsertElement.try_emplace(Vec, IE);
       return Vec;
     };
-    // If User == nullptr, the Scalar is used as extra arg. Generate
-    // ExtractElement instruction and update the record for this scalar in
-    // ExternallyUsedValues.
+    // If User == nullptr, the Scalar remains as scalar in vectorized
+    // instructions or is used as extra arg. Generate ExtractElement instruction
+    // and update the record for this scalar in ExternallyUsedValues.
     if (!User) {
-      assert(ExternallyUsedValues.count(Scalar) &&
-             "Scalar with nullptr as an external user must be registered in "
-             "ExternallyUsedValues map");
+      assert((ExternallyUsedValues.count(Scalar) ||
+              any_of(Scalar->users(),
+                     [this, Scalar](llvm::User *U) {
+                       TreeEntry *UseEntry = getTreeEntry(U);
+                       return UseEntry &&
+                              doesInTreeUserNeedToExtract(
+                                  Scalar,
+                                  cast<Instruction>(UseEntry->Scalars.front()),
+                                  TLI);
+                     })) &&
+             "Scalar with nullptr must be registered in "
+             "ExternallyUsedValues map or remain as scalar in vectorized "
+             "instructions");
       if (auto *VecI = dyn_cast<Instruction>(Vec)) {
         if (auto *PHI = dyn_cast<PHINode>(VecI))
           Builder.SetInsertPoint(PHI->getParent(),
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/extract_in_tree_user.ll b/llvm/test/Transforms/SLPVectorizer/X86/extract_in_tree_user.ll
index be6b0bc47c0253..096f57d100a50f 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/extract_in_tree_user.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/extract_in_tree_user.ll
@@ -11,11 +11,11 @@ define i32 @fn1() {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = load ptr, ptr @a, align 8
 ; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x ptr> poison, ptr [[TMP0]], i32 0
-; CHECK-NEXT:    [[SHUFFLE:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i64, <2 x ptr> [[SHUFFLE]], <2 x i64> <i64 11, i64 56>
-; CHECK-NEXT:    [[TMP3:%.*]] = ptrtoint <2 x ptr> [[TMP2]] to <2 x i64>
-; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x ptr> [[TMP2]], i32 0
-; CHECK-NEXT:    store <2 x i64> [[TMP3]], ptr [[TMP4]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i64, <2 x ptr> [[TMP2]], <2 x i64> <i64 11, i64 56>
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x ptr> [[TMP3]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = ptrtoint <2 x ptr> [[TMP3]] to <2 x i64>
+; CHECK-NEXT:    store <2 x i64> [[TMP5]], ptr [[TMP4]], align 8
 ; CHECK-NEXT:    ret i32 undef
 ;
 entry:
@@ -34,13 +34,13 @@ declare float @llvm.powi.f32.i32(float, i32)
 define void @fn2(ptr %a, ptr %b, ptr %c) {
 ; CHECK-LABEL: @fn2(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i32>, ptr [[A:%.*]], align 4
-; CHECK-NEXT:    [[TMP3:%.*]] = load <4 x i32>, ptr [[B:%.*]], align 4
-; CHECK-NEXT:    [[TMP4:%.*]] = add <4 x i32> [[TMP1]], [[TMP3]]
-; CHECK-NEXT:    [[TMP5:%.*]] = sitofp <4 x i32> [[TMP4]] to <4 x float>
-; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x i32> [[TMP4]], i32 0
-; CHECK-NEXT:    [[TMP7:%.*]] = call <4 x float> @llvm.powi.v4f32.i32(<4 x float> [[TMP5]], i32 [[TMP6]])
-; CHECK-NEXT:    store <4 x float> [[TMP7]], ptr [[C:%.*]], align 4
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[A:%.*]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i32>, ptr [[B:%.*]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = add <4 x i32> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x i32> [[TMP2]], i32 0
+; CHECK-NEXT:    [[TMP4:%.*]] = sitofp <4 x i32> [[TMP2]] to <4 x float>
+; CHECK-NEXT:    [[TMP5:%.*]] = call <4 x float> @llvm.powi.v4f32.i32(<4 x float> [[TMP4]], i32 [[TMP3]])
+; CHECK-NEXT:    store <4 x float> [[TMP5]], ptr [[C:%.*]], align 4
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -90,12 +90,12 @@ define void @externally_used_ptrs() {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = load ptr, ptr @a, align 8
 ; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x ptr> poison, ptr [[TMP0]], i32 0
-; CHECK-NEXT:    [[SHUFFLE:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i64, <2 x ptr> [[SHUFFLE]], <2 x i64> <i64 56, i64 11>
-; CHECK-NEXT:    [[TMP3:%.*]] = ptrtoint <2 x ptr> [[TMP2]] to <2 x i64>
-; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x ptr> [[TMP2]], i32 1
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i64, <2 x ptr> [[TMP2]], <2 x i64> <i64 56, i64 11>
+; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <2 x ptr> [[TMP3]], i32 1
+; CHECK-NEXT:    [[TMP5:%.*]] = ptrtoint <2 x ptr> [[TMP3]] to <2 x i64>
 ; CHECK-NEXT:    [[TMP6:%.*]] = load <2 x i64>, ptr [[TMP4]], align 8
-; CHECK-NEXT:    [[TMP7:%.*]] = add <2 x i64> [[TMP3]], [[TMP6]]
+; CHECK-NEXT:    [[TMP7:%.*]] = add <2 x i64> [[TMP5]], [[TMP6]]
 ; CHECK-NEXT:    store <2 x i64> [[TMP7]], ptr [[TMP4]], align 8
 ; CHECK-NEXT:    ret void
 ;
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/reorder-reused-masked-gather2.ll b/llvm/test/Transforms/SLPVectorizer/X86/reorder-reused-masked-gather2.ll
index 75431c13a7703a..ddc2a1b819041f 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/reorder-reused-masked-gather2.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/reorder-reused-masked-gather2.ll
@@ -9,15 +9,15 @@ define void @"foo"(ptr addrspace(1) %0, ptr addrspace(1) %1) #0 {
 ; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x ptr addrspace(1)> poison, ptr addrspace(1) [[TMP0:%.*]], i32 0
 ; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <4 x ptr addrspace(1)> [[TMP3]], <4 x ptr addrspace(1)> poison, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr i8, <4 x ptr addrspace(1)> [[TMP4]], <4 x i64> <i64 8, i64 12, i64 28, i64 24>
-; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds i8, ptr addrspace(1) [[TMP1:%.*]], i64 8
-; CHECK-NEXT:    [[TMP7:%.*]] = call <4 x float> @llvm.masked.gather.v4f32.v4p1(<4 x ptr addrspace(1)> [[TMP5]], i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x float> poison)
-; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x float> [[TMP7]], <4 x float> poison, <8 x i32> <i32 0, i32 3, i32 0, i32 3, i32 2, i32 1, i32 2, i32 1>
-; CHECK-NEXT:    [[TMP9:%.*]] = load <8 x float>, ptr addrspace(1) [[TMP6]], align 4
-; CHECK-NEXT:    [[TMP10:%.*]] = fmul <8 x float> [[TMP8]], [[TMP9]]
-; CHECK-NEXT:    [[TMP11:%.*]] = fadd <8 x float> [[TMP10]], zeroinitializer
-; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <8 x float> [[TMP11]], <8 x float> poison, <8 x i32> <i32 0, i32 5, i32 2, i32 7, i32 4, i32 1, i32 6, i32 3>
-; CHECK-NEXT:    [[TMP13:%.*]] = extractelement <4 x ptr addrspace(1)> [[TMP5]], i32 0
-; CHECK-NEXT:    store <8 x float> [[TMP12]], ptr addrspace(1) [[TMP13]], align 4
+; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x ptr addrspace(1)> [[TMP5]], i32 0
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr inbounds i8, ptr addrspace(1) [[TMP1:%.*]], i64 8
+; CHECK-NEXT:    [[TMP8:%.*]] = call <4 x float> @llvm.masked.gather.v4f32.v4p1(<4 x ptr addrspace(1)> [[TMP5]], i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x float> poison)
+; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <4 x float> [[TMP8]], <4 x float> poison, <8 x i32> <i32 0, i32 3, i32 0, i32 3, i32 2, i32 1, i32 2, i32 1>
+; CHECK-NEXT:    [[TMP10:%.*]] = load <8 x float>, ptr addrspace(1) [[TMP7]], align 4
+; CHECK-NEXT:    [[TMP11:%.*]] = fmul <8 x float> [[TMP9]], [[TMP10]]
+; CHECK-NEXT:    [[TMP12:%.*]] = fadd <8 x float> [[TMP11]], zeroinitializer
+; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <8 x float> [[TMP12]], <8 x float> poison, <8 x i32> <i32 0, i32 5, i32 2, i32 7, i32 4, i32 1, i32 6, i32 3>
+; CHECK-NEXT:    store <8 x float> [[TMP13]], ptr addrspace(1) [[TMP6]], align 4
 ; CHECK-NEXT:    ret void
 ;
   %3 = getelementptr inbounds i8, ptr addrspace(1) %0, i64 8



More information about the llvm-commits mailing list