[llvm] 40105a9 - [SLP]Find reused scalars in buildvector sequences, if any.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 5 09:39:30 PDT 2023


Author: Alexey Bataev
Date: 2023-04-05T09:37:05-07:00
New Revision: 40105a993399699fe351789c7eb2a0e6d36f440a

URL: https://github.com/llvm/llvm-project/commit/40105a993399699fe351789c7eb2a0e6d36f440a
DIFF: https://github.com/llvm/llvm-project/commit/40105a993399699fe351789c7eb2a0e6d36f440a.diff

LOG: [SLP]Find reused scalars in buildvector sequences, if any.

Patch generalizes analysis of scalars. The main part is outlined into a
lambda, which can be used to find reused inserted scalars and emit a
shuffle for them instead of multiple insertelement instructions, if the
permutation is already found. I.e. some scalars are transformed by the
permutation of previously vectorized nodes, and some are inserted
directly.

Reworked part of D110978

Differential Revision: https://reviews.llvm.org/D146564

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/X86/jumbled-load-multiuse.ll
    llvm/test/Transforms/SLPVectorizer/X86/matched-shuffled-entries.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 7f3e6abe4c829..2689f5a8f2d07 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -9411,8 +9411,11 @@ Value *BoUpSLP::createBuildVector(const TreeEntry *E) {
         continue;
       auto *EI = cast<ExtractElementInst>(E->Scalars[I]);
       VecBase = EI->getVectorOperand();
-      // TODO: EI can be erased, if all its users are vectorized. But need to
-      // emit shuffles for such extractelement instructions.
+      // If all users are vectorized - can delete the extractelement itself.
+      if (any_of(EI->users(),
+                 [&](User *U) { return !ScalarToTreeEntry.count(U); }))
+        continue;
+      eraseInstruction(EI);
     }
     return VecBase;
   };
@@ -9471,7 +9474,6 @@ Value *BoUpSLP::createBuildVector(const TreeEntry *E) {
   Value *Vec = nullptr;
   SmallVector<int> Mask;
   SmallVector<int> ExtractMask;
-  SmallVector<int> ReuseMask;
   std::optional<TargetTransformInfo::ShuffleKind> ExtractShuffle;
   std::optional<TargetTransformInfo::ShuffleKind> GatherShuffle;
   SmallVector<const TreeEntry *> Entries;
@@ -9522,6 +9524,95 @@ Value *BoUpSLP::createBuildVector(const TreeEntry *E) {
       }
     }
   }
+  auto TryPackScalars = [&](SmallVectorImpl<Value *> &Scalars,
+                            SmallVectorImpl<int> &ReuseMask,
+                            bool IsRootPoison) {
+    // For splats with can emit broadcasts instead of gathers, so try to find
+    // such sequences.
+    bool IsSplat = IsRootPoison && isSplat(Scalars) &&
+                   (Scalars.size() > 2 || Scalars.front() == Scalars.back());
+    Scalars.append(VF - Scalars.size(), PoisonValue::get(ScalarTy));
+    SmallVector<int> UndefPos;
+    DenseMap<Value *, unsigned> UniquePositions;
+    // Gather unique non-const values and all constant values.
+    // For repeated values, just shuffle them.
+    int NumNonConsts = 0;
+    int SinglePos = 0;
+    for (auto [I, V] : enumerate(Scalars)) {
+      if (isa<UndefValue>(V)) {
+        if (!isa<PoisonValue>(V)) {
+          ReuseMask[I] = I;
+          UndefPos.push_back(I);
+        }
+        continue;
+      }
+      if (isConstant(V)) {
+        ReuseMask[I] = I;
+        continue;
+      }
+      ++NumNonConsts;
+      SinglePos = I;
+      Value *OrigV = V;
+      Scalars[I] = PoisonValue::get(ScalarTy);
+      if (IsSplat) {
+        Scalars.front() = OrigV;
+        ReuseMask[I] = 0;
+      } else {
+        const auto Res = UniquePositions.try_emplace(OrigV, I);
+        Scalars[Res.first->second] = OrigV;
+        ReuseMask[I] = Res.first->second;
+      }
+    }
+    if (NumNonConsts == 1) {
+      // Restore single insert element.
+      if (IsSplat) {
+        ReuseMask.assign(VF, UndefMaskElem);
+        std::swap(Scalars.front(), Scalars[SinglePos]);
+        if (!UndefPos.empty() && UndefPos.front() == 0)
+          Scalars.front() = UndefValue::get(ScalarTy);
+      }
+      ReuseMask[SinglePos] = SinglePos;
+    } else if (!UndefPos.empty() && IsSplat) {
+      // For undef values, try to replace them with the simple broadcast.
+      // We can do it if the broadcasted value is guaranteed to be
+      // non-poisonous, or by freezing the incoming scalar value first.
+      auto *It = find_if(Scalars, [this, E](Value *V) {
+        return !isa<UndefValue>(V) &&
+               (getTreeEntry(V) || isGuaranteedNotToBePoison(V) ||
+                (E->UserTreeIndices.size() == 1 &&
+                 any_of(V->uses(), [E](const Use &U) {
+                   // Check if the value already used in the same operation in
+                   // one of the nodes already.
+                   return E->UserTreeIndices.front().EdgeIdx !=
+                              U.getOperandNo() &&
+                          is_contained(
+                              E->UserTreeIndices.front().UserTE->Scalars,
+                              U.getUser());
+                 })));
+      });
+      if (It != Scalars.end()) {
+        // Replace undefs by the non-poisoned scalars and emit broadcast.
+        int Pos = std::distance(Scalars.begin(), It);
+        for_each(UndefPos, [&](int I) {
+          // Set the undef position to the non-poisoned scalar.
+          ReuseMask[I] = Pos;
+          // Replace the undef by the poison, in the mask it is replaced by
+          // non-poisoned scalar already.
+          if (I != Pos)
+            Scalars[I] = PoisonValue::get(ScalarTy);
+        });
+      } else {
+        // Replace undefs by the poisons, emit broadcast and then emit
+        // freeze.
+        for_each(UndefPos, [&](int I) {
+          ReuseMask[I] = UndefMaskElem;
+          if (isa<UndefValue>(Scalars[I]))
+            Scalars[I] = PoisonValue::get(ScalarTy);
+        });
+        NeedFreeze = true;
+      }
+    }
+  };
   if (ExtractShuffle || GatherShuffle) {
     bool IsNonPoisoned = true;
     bool IsUsedInExpr = false;
@@ -9549,6 +9640,8 @@ Value *BoUpSLP::createBuildVector(const TreeEntry *E) {
         }
       }
       if (Vec2) {
+        IsNonPoisoned &=
+            isGuaranteedNotToBePoison(Vec1) && isGuaranteedNotToBePoison(Vec2);
         ShuffleBuilder.add(Vec1, Vec2, ExtractMask);
       } else if (Vec1) {
         ShuffleBuilder.add(Vec1, ExtractMask);
@@ -9569,6 +9662,9 @@ Value *BoUpSLP::createBuildVector(const TreeEntry *E) {
       } else {
         ShuffleBuilder.add(Entries.front()->VectorizedValue,
                            Entries.back()->VectorizedValue, Mask);
+        IsNonPoisoned &=
+            isGuaranteedNotToBePoison(Entries.front()->VectorizedValue) &&
+            isGuaranteedNotToBePoison(Entries.back()->VectorizedValue);
       }
     }
     // Try to figure out best way to combine values: build a shuffle and insert
@@ -9617,11 +9713,8 @@ Value *BoUpSLP::createBuildVector(const TreeEntry *E) {
     // Generate constants for final shuffle and build a mask for them.
     if (!all_of(GatheredScalars, PoisonValue::classof)) {
       SmallVector<int> BVMask(GatheredScalars.size(), UndefMaskElem);
+      TryPackScalars(GatheredScalars, BVMask, /*IsRootPoison=*/true);
       Value *BV = gather(GatheredScalars);
-      for (int I = 0, Sz = GatheredScalars.size(); I < Sz; ++I) {
-        if (!isa<PoisonValue>(GatheredScalars[I]))
-          BVMask[I] = I;
-      }
       ShuffleBuilder.add(BV, BVMask);
     }
     if (all_of(NonConstants, [=](Value *V) {
@@ -9634,111 +9727,25 @@ Value *BoUpSLP::createBuildVector(const TreeEntry *E) {
       Vec = ShuffleBuilder.finalize(
           E->ReuseShuffleIndices, E->Scalars.size(),
           [&](Value *&Vec, SmallVectorImpl<int> &Mask) {
+            TryPackScalars(NonConstants, Mask, /*IsRootPoison=*/false);
             Vec = gather(NonConstants, Vec);
-            for (unsigned I = 0, Sz = Mask.size(); I < Sz; ++I)
-              if (!isa<PoisonValue>(NonConstants[I]))
-                Mask[I] = I;
           });
-  } else if (!allConstant(E->Scalars)) {
-    // TODO: remove this code once able to combine shuffled vectors and build
-    // vector elements.
-    copy(E->Scalars, GatheredScalars.begin());
-    // For splats with can emit broadcasts instead of gathers, so try to find
-    // such sequences.
-    bool IsSplat = isSplat(GatheredScalars) &&
-                   (GatheredScalars.size() > 2 ||
-                    GatheredScalars.front() == GatheredScalars.back());
-    GatheredScalars.append(VF - GatheredScalars.size(),
-                           PoisonValue::get(ScalarTy));
-    ReuseMask.assign(VF, UndefMaskElem);
-    SmallVector<int> UndefPos;
-    DenseMap<Value *, unsigned> UniquePositions;
-    // Gather unique non-const values and all constant values.
-    // For repeated values, just shuffle them.
-    int NumNonConsts = 0;
-    int SinglePos = 0;
-    for (auto [I, V] : enumerate(GatheredScalars)) {
-      if (isa<UndefValue>(V)) {
-        if (!isa<PoisonValue>(V)) {
-          ReuseMask[I] = I;
-          UndefPos.push_back(I);
-        }
-        continue;
-      }
-      if (isConstant(V)) {
-        ReuseMask[I] = I;
-        continue;
-      }
-      ++NumNonConsts;
-      SinglePos = I;
-      Value *OrigV = V;
-      GatheredScalars[I] = PoisonValue::get(ScalarTy);
-      if (IsSplat) {
-        GatheredScalars.front() = OrigV;
-        ReuseMask[I] = 0;
-      } else {
-        const auto Res = UniquePositions.try_emplace(OrigV, I);
-        GatheredScalars[Res.first->second] = OrigV;
-        ReuseMask[I] = Res.first->second;
-      }
-    }
-    if (NumNonConsts == 1) {
-      // Restore single insert element.
-      if (IsSplat) {
-        ReuseMask.assign(VF, UndefMaskElem);
-        std::swap(GatheredScalars.front(), GatheredScalars[SinglePos]);
-        if (!UndefPos.empty() && UndefPos.front() == 0)
-          GatheredScalars.front() = UndefValue::get(ScalarTy);
-      }
-      ReuseMask[SinglePos] = SinglePos;
-    } else if (!UndefPos.empty() && IsSplat) {
-      // For undef values, try to replace them with the simple broadcast.
-      // We can do it if the broadcasted value is guaranteed to be
-      // non-poisonous, or by freezing the incoming scalar value first.
-      auto *It = find_if(GatheredScalars, [this, E](Value *V) {
-        return !isa<UndefValue>(V) &&
-               (getTreeEntry(V) || isGuaranteedNotToBePoison(V) ||
-                (E->UserTreeIndices.size() == 1 &&
-                 any_of(V->uses(), [E](const Use &U) {
-                   // Check if the value already used in the same operation in
-                   // one of the nodes already.
-                   return E->UserTreeIndices.front().EdgeIdx !=
-                              U.getOperandNo() &&
-                          is_contained(
-                              E->UserTreeIndices.front().UserTE->Scalars,
-                              U.getUser());
-                 })));
-      });
-      if (It != GatheredScalars.end()) {
-        // Replace undefs by the non-poisoned scalars and emit broadcast.
-        int Pos = std::distance(GatheredScalars.begin(), It);
-        for_each(UndefPos, [&](int I) {
-          // Set the undef position to the non-poisoned scalar.
-          ReuseMask[I] = Pos;
-          // Replace the undef by the poison, in the mask it is replaced by
-          // non-poisoned scalar already.
-          if (I != Pos)
-            GatheredScalars[I] = PoisonValue::get(ScalarTy);
-        });
-      } else {
-        // Replace undefs by the poisons, emit broadcast and then emit
-        // freeze.
-        for_each(UndefPos, [&](int I) {
-          ReuseMask[I] = UndefMaskElem;
-          if (isa<UndefValue>(GatheredScalars[I]))
-            GatheredScalars[I] = PoisonValue::get(ScalarTy);
-        });
-        NeedFreeze = true;
-      }
-    }
+  } else if (!allConstant(GatheredScalars)) {
     // Gather unique scalars and all constants.
+    SmallVector<int> ReuseMask(GatheredScalars.size(), UndefMaskElem);
+    TryPackScalars(GatheredScalars, ReuseMask, /*IsRootPoison=*/true);
     Vec = gather(GatheredScalars);
     ShuffleBuilder.add(Vec, ReuseMask);
     Vec = ShuffleBuilder.finalize(E->ReuseShuffleIndices);
   } else {
     // Gather all constants.
+    SmallVector<int> Mask(E->Scalars.size(), UndefMaskElem);
+    for (auto [I, V] : enumerate(E->Scalars)) {
+      if (!isa<PoisonValue>(V))
+        Mask[I] = I;
+    }
     Vec = gather(E->Scalars);
-    ShuffleBuilder.add(Vec, ReuseMask);
+    ShuffleBuilder.add(Vec, Mask);
     Vec = ShuffleBuilder.finalize(E->ReuseShuffleIndices);
   }
 

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/jumbled-load-multiuse.ll b/llvm/test/Transforms/SLPVectorizer/X86/jumbled-load-multiuse.ll
index 6adcf21a4bc63..e687f440e728b 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/jumbled-load-multiuse.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/jumbled-load-multiuse.ll
@@ -9,7 +9,7 @@ define i32 @fn1() {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr @b, align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = icmp sgt <4 x i32> [[TMP0]], zeroinitializer
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP0]], <4 x i32> <i32 8, i32 poison, i32 ptrtoint (ptr @fn1 to i32), i32 ptrtoint (ptr @fn1 to i32)>, <4 x i32> <i32 4, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP0]], <4 x i32> <i32 8, i32 poison, i32 ptrtoint (ptr @fn1 to i32), i32 poison>, <4 x i32> <i32 4, i32 1, i32 6, i32 6>
 ; CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[TMP1]], <4 x i32> [[TMP2]], <4 x i32> <i32 0, i32 6, i32 0, i32 0>
 ; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <4 x i32> [[TMP3]], <4 x i32> poison, <4 x i32> <i32 1, i32 2, i32 3, i32 0>
 ; CHECK-NEXT:    store <4 x i32> [[TMP4]], ptr @a, align 4

diff  --git a/llvm/test/Transforms/SLPVectorizer/X86/matched-shuffled-entries.ll b/llvm/test/Transforms/SLPVectorizer/X86/matched-shuffled-entries.ll
index a117fdc75bfdc..584eeb3710907 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/matched-shuffled-entries.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/matched-shuffled-entries.ll
@@ -18,7 +18,7 @@ define i32 @bar() local_unnamed_addr {
 ; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <16 x i32> [[TMP4]], <16 x i32> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 9, i32 11, i32 12, i32 13, i32 14, i32 15>
 ; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <16 x i32> [[TMP4]], <16 x i32> poison, <16 x i32> <i32 undef, i32 undef, i32 undef, i32 undef, i32 7, i32 6, i32 5, i32 4, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
 ; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <16 x i32> [[TMP6]], i32 [[SUB102_3]], i32 12
-; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <16 x i32> [[TMP7]], i32 [[SUB102_3]], i32 15
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <16 x i32> [[TMP7]], <16 x i32> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 12>
 ; CHECK-NEXT:    [[TMP9:%.*]] = add nsw <16 x i32> [[TMP5]], [[TMP8]]
 ; CHECK-NEXT:    [[TMP10:%.*]] = sub nsw <16 x i32> [[TMP5]], [[TMP8]]
 ; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <16 x i32> [[TMP9]], <16 x i32> [[TMP10]], <16 x i32> <i32 0, i32 1, i32 18, i32 19, i32 4, i32 5, i32 22, i32 23, i32 8, i32 9, i32 26, i32 27, i32 12, i32 13, i32 30, i32 31>


        


More information about the llvm-commits mailing list