[llvm] [SLP][REVEC] Make ShuffleCostEstimator and ShuffleInstructionBuilder can vectorize vector instructions. (PR #99499)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jul 18 07:29:31 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Han-Kuan Chen (HanKuanChen)

<details>
<summary>Changes</summary>

This PR tries to make ShuffleCostEstimator and ShuffleInstructionBuilder able to vectorize vector instructions.
Since the mask indices assume that the source elements are of scalar type, we need to transform the mask indices into a form that can be used when REVEC is enabled. The transform is applied only right before the call to CreateShuffleVector.

In addition, when REVEC is enabled, CreateInsertVector and CreateExtractVector are used in place of CreateInsertElement and CreateExtractElement, because the "scalar" type may itself be a FixedVectorType.

---
Full diff: https://github.com/llvm/llvm-project/pull/99499.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+79-11) 
- (modified) llvm/test/Transforms/SLPVectorizer/revec.ll (+36) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index d8c3bae06e932..aefec86d332fe 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -253,6 +253,21 @@ static FixedVectorType *getWidenedType(Type *ScalarTy, unsigned VF) {
                               VF * getNumElements(ScalarTy));
 }
 
+static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
+                                                   SmallVectorImpl<int> &Mask) {
+  // ShuffleBuilder uses shufflevector to splat an "element", but an element
+  // is a scalar for SLP and a vector of VecTyNumElements lanes for REVEC.
+  // Expand each non-poison index Mask[I] into the VecTyNumElements consecutive
+  // lane indices starting at Mask[I] * VecTyNumElements; poison stays poison.
+  SmallVector<int> NewMask(Mask.size() * VecTyNumElements);
+  for (size_t I = 0, E = Mask.size(); I != E; ++I)
+    for (unsigned J = 0; J != VecTyNumElements; ++J)
+      NewMask[I * VecTyNumElements + J] = Mask[I] == PoisonMaskElem
+                                              ? PoisonMaskElem
+                                              : Mask[I] * VecTyNumElements + J;
+  Mask.swap(NewMask);
+}
+
 /// \returns True if the value is a constant (but not globals/constant
 /// expressions).
 static bool isConstant(Value *V) {
@@ -8800,6 +8815,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
       // Shuffle single vector.
       ExtraCost += GetValueMinBWAffectedCost(V1);
       CommonVF = cast<FixedVectorType>(V1->getType())->getNumElements();
+      if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+        CommonVF /= VecTy->getNumElements();
       assert(
           all_of(Mask,
                  [=](int Idx) { return Idx < static_cast<int>(CommonVF); }) &&
@@ -8807,6 +8824,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     } else if (V1 && !V2) {
       // Shuffle vector and tree node.
       unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
+      if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+        VF /= VecTy->getNumElements();
       const TreeEntry *E2 = P2.get<const TreeEntry *>();
       CommonVF = std::max(VF, E2->getVectorFactor());
       assert(all_of(Mask,
@@ -8833,6 +8852,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     } else if (!V1 && V2) {
       // Shuffle vector and tree node.
       unsigned VF = cast<FixedVectorType>(V2->getType())->getNumElements();
+      if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+        VF /= VecTy->getNumElements();
       const TreeEntry *E1 = P1.get<const TreeEntry *>();
       CommonVF = std::max(VF, E1->getVectorFactor());
       assert(all_of(Mask,
@@ -8863,6 +8884,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
       unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
       CommonVF =
           std::max(VF, cast<FixedVectorType>(V2->getType())->getNumElements());
+      if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+        CommonVF /= VecTy->getNumElements();
       assert(all_of(Mask,
                     [=](int Idx) {
                       return Idx < 2 * static_cast<int>(CommonVF);
@@ -8880,6 +8903,9 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
           V2 = getAllOnesValue(*R.DL, getWidenedType(ScalarTy, CommonVF));
       }
     }
+    if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+      transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
+                                             CommonMask);
     InVectors.front() =
         Constant::getNullValue(getWidenedType(ScalarTy, CommonMask.size()));
     if (InVectors.size() == 2)
@@ -9098,6 +9124,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     assert(!InVectors.empty() && !CommonMask.empty() &&
            "Expected only tree entries from extracts/reused buildvectors.");
     unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
+    if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+      VF /= VecTy->getNumElements();
     if (InVectors.size() == 2) {
       Cost += createShuffle(InVectors.front(), InVectors.back(), CommonMask);
       transformMaskAfterShuffle(CommonMask, CommonMask);
@@ -9125,18 +9153,27 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
       if (MaskVF != 0)
         VF = std::min(VF, MaskVF);
       for (Value *V : VL.take_front(VF)) {
-        if (isa<UndefValue>(V)) {
-          Vals.push_back(cast<Constant>(V));
-          continue;
+        Type *Ty = V->getType();
+        Type *ScalarTy = Ty->getScalarType();
+        unsigned VNumElements = getNumElements(Ty);
+        for (unsigned I = 0; I != VNumElements; ++I) {
+          if (isa<PoisonValue>(V)) {
+            Vals.push_back(PoisonValue::get(ScalarTy));
+            continue;
+          }
+          if (isa<UndefValue>(V)) {
+            Vals.push_back(UndefValue::get(ScalarTy));
+            continue;
+          }
+          Vals.push_back(Constant::getNullValue(ScalarTy));
         }
-        Vals.push_back(Constant::getNullValue(V->getType()));
       }
       return ConstantVector::get(Vals);
     }
     return ConstantVector::getSplat(
         ElementCount::getFixed(
             cast<FixedVectorType>(Root->getType())->getNumElements()),
-        getAllOnesValue(*R.DL, ScalarTy));
+        getAllOnesValue(*R.DL, ScalarTy->getScalarType()));
   }
   InstructionCost createFreeze(InstructionCost Cost) { return Cost; }
   /// Finalize emission of the shuffles.
@@ -11618,7 +11655,8 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) {
                                       Type *Ty) {
     Value *Scalar = V;
     if (Scalar->getType() != Ty) {
-      assert(Scalar->getType()->isIntegerTy() && Ty->isIntegerTy() &&
+      assert(Scalar->getType()->getScalarType()->isIntegerTy() &&
+             Ty->getScalarType()->isIntegerTy() &&
              "Expected integer types only.");
       Value *V = Scalar;
       if (auto *CI = dyn_cast<CastInst>(Scalar);
@@ -11632,10 +11670,20 @@ Value *BoUpSLP::gather(ArrayRef<Value *> VL, Value *Root, Type *ScalarTy) {
           V, Ty, !isKnownNonNegative(Scalar, SimplifyQuery(*DL)));
     }
 
-    Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos));
-    auto *InsElt = dyn_cast<InsertElementInst>(Vec);
-    if (!InsElt)
-      return Vec;
+    Instruction *InsElt;
+    if (auto *VecTy = dyn_cast<FixedVectorType>(Scalar->getType())) {
+      Vec = InsElt = Builder.CreateInsertVector(
+          Vec->getType(), Vec, V,
+          Builder.getInt64(Pos * VecTy->getNumElements()));
+      auto *II = dyn_cast<IntrinsicInst>(InsElt);
+      if (!(II && II->getIntrinsicID() == Intrinsic::vector_insert))
+        return Vec;
+    } else {
+      Vec = Builder.CreateInsertElement(Vec, Scalar, Builder.getInt32(Pos));
+      InsElt = dyn_cast<InsertElementInst>(Vec);
+      if (!InsElt)
+        return Vec;
+    }
     GatherShuffleExtractSeq.insert(InsElt);
     CSEBlocks.insert(InsElt->getParent());
     // Add to our 'need-to-extract' list.
@@ -12123,6 +12171,8 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
     int VF = CommonMask.size();
     if (auto *FTy = dyn_cast<FixedVectorType>(V1->getType()))
       VF = FTy->getNumElements();
+    if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy))
+      VF /= VecTy->getNumElements();
     for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
       if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem)
         CommonMask[Idx] = Mask[Idx] + (It == InVectors.begin() ? 0 : VF);
@@ -12145,6 +12195,14 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
   finalize(ArrayRef<int> ExtMask, unsigned VF = 0,
            function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
     IsFinalized = true;
+    SmallVector<int> NewExtMask(ExtMask);
+    if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
+      transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
+                                             CommonMask);
+      transformScalarShuffleIndiciesToVector(VecTy->getNumElements(),
+                                             NewExtMask);
+      ExtMask = NewExtMask;
+    }
     if (Action) {
       Value *Vec = InVectors.front();
       if (InVectors.size() == 2) {
@@ -13906,7 +13964,17 @@ Value *BoUpSLP::vectorizeTree(
               CloneGEP->takeName(GEP);
             Ex = CloneGEP;
           } else {
-            Ex = Builder.CreateExtractElement(Vec, Lane);
+            if (auto *VecTy = dyn_cast<FixedVectorType>(Scalar->getType())) {
+              unsigned VecTyNumElements = VecTy->getNumElements();
+              // When REVEC is enabled, we need to extract a vector.
+              // Note: The element size of Scalar may be different from the
+              // element size of Vec.
+              Ex = Builder.CreateExtractVector(
+                  FixedVectorType::get(Vec->getType()->getScalarType(),
+                                       VecTyNumElements),
+                  Vec, Builder.getInt64(ExternalUse.Lane * VecTyNumElements));
+            } else
+              Ex = Builder.CreateExtractElement(Vec, Lane);
           }
           // If necessary, sign-extend or zero-extend ScalarRoot
           // to the larger type.
diff --git a/llvm/test/Transforms/SLPVectorizer/revec.ll b/llvm/test/Transforms/SLPVectorizer/revec.ll
index c2dc6d0ab73b7..84426ce6e96bf 100644
--- a/llvm/test/Transforms/SLPVectorizer/revec.ll
+++ b/llvm/test/Transforms/SLPVectorizer/revec.ll
@@ -58,3 +58,39 @@ entry:
   store <8 x i16> %4, ptr %5, align 2
   ret void
 }
+
+define void @test3(ptr %in, ptr %out) {
+; CHECK-LABEL: @test3(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = load <8 x float>, ptr [[IN:%.*]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> poison, <8 x float> poison, i64 8)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP1]], <8 x float> [[TMP0]], i64 0)
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <16 x float> [[TMP2]], <16 x float> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP4:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> poison, <8 x float> zeroinitializer, i64 0)
+; CHECK-NEXT:    [[TMP5:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP4]], <8 x float> zeroinitializer, i64 8)
+; CHECK-NEXT:    [[TMP6:%.*]] = fmul <16 x float> [[TMP3]], [[TMP5]]
+; CHECK-NEXT:    [[TMP7:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> poison, <8 x float> poison, i64 0)
+; CHECK-NEXT:    [[TMP8:%.*]] = call <16 x float> @llvm.vector.insert.v16f32.v8f32(<16 x float> [[TMP7]], <8 x float> zeroinitializer, i64 8)
+; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <16 x float> [[TMP2]], <16 x float> [[TMP8]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+; CHECK-NEXT:    [[TMP10:%.*]] = fadd <16 x float> [[TMP9]], [[TMP6]]
+; CHECK-NEXT:    [[TMP11:%.*]] = fcmp ogt <16 x float> [[TMP10]], [[TMP5]]
+; CHECK-NEXT:    [[TMP12:%.*]] = getelementptr i1, ptr [[OUT:%.*]], i64 8
+; CHECK-NEXT:    [[TMP13:%.*]] = call <8 x i1> @llvm.vector.extract.v8i1.v16i1(<16 x i1> [[TMP11]], i64 8)
+; CHECK-NEXT:    store <8 x i1> [[TMP13]], ptr [[OUT]], align 1
+; CHECK-NEXT:    [[TMP14:%.*]] = call <8 x i1> @llvm.vector.extract.v8i1.v16i1(<16 x i1> [[TMP11]], i64 0)
+; CHECK-NEXT:    store <8 x i1> [[TMP14]], ptr [[TMP12]], align 1
+; CHECK-NEXT:    ret void
+;
+entry:
+  %0 = load <8 x float>, ptr %in, align 4
+  %1 = fmul <8 x float> %0, zeroinitializer
+  %2 = fmul <8 x float> %0, zeroinitializer
+  %3 = fadd <8 x float> zeroinitializer, %1
+  %4 = fadd <8 x float> %0, %2
+  %5 = fcmp ogt <8 x float> %3, zeroinitializer
+  %6 = fcmp ogt <8 x float> %4, zeroinitializer
+  %7 = getelementptr i1, ptr %out, i64 8
+  store <8 x i1> %5, ptr %out, align 1
+  store <8 x i1> %6, ptr %7, align 1
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/99499


More information about the llvm-commits mailing list