[llvm] [SLP]Check if masked gather can be emitted as a series of loads/insert subvector. (PR #83481)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 1 10:49:30 PST 2024

https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/83481

>From 44757efeb7372c5520a1412165feb70eea6940cc Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Thu, 29 Feb 2024 21:12:43 +0000
Subject: [PATCH 1/2] =?UTF-8?q?[=F0=9D=98=80=F0=9D=97=BD=F0=9D=97=BF]=20in?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.5
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 89 +++++++++++++++++--
 .../X86/scatter-vectorize-reused-pointer.ll   | 22 +++--
 2 files changed, 94 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 94b7c4952f055e..f817d1a304b5b8 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3997,12 +3997,14 @@ static bool isReverseOrder(ArrayRef<unsigned> Order) {
 /// Checks if the given array of loads can be represented as a vectorized,
 /// scatter or just simple gather.
-static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
+static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
+                                    const Value *VL0,
                                     const TargetTransformInfo &TTI,
                                     const DataLayout &DL, ScalarEvolution &SE,
                                     LoopInfo &LI, const TargetLibraryInfo &TLI,
                                     SmallVectorImpl<unsigned> &Order,
-                                    SmallVectorImpl<Value *> &PointerOps) {
+                                    SmallVectorImpl<Value *> &PointerOps,
+                                    bool TryRecursiveCheck = true) {
   // Check that a vectorized load would load the same memory as a scalar
   // load. For example, we don't want to vectorize loads that are smaller
   // than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM
@@ -4095,6 +4097,68 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
+    auto CheckForShuffledLoads = [&](Align CommonAlignment) {
+      unsigned Sz = DL.getTypeSizeInBits(ScalarTy);
+      unsigned MinVF = R.getMinVF(Sz);
+      unsigned MaxVF = std::max<unsigned>(bit_floor(VL.size() / 2), MinVF);
+      MaxVF = std::min(R.getMaximumVF(Sz, Instruction::Load), MaxVF);
+      for (unsigned VF = MaxVF; VF >= MinVF; VF /= 2) {
+        unsigned VectorizedCnt = 0;
+        SmallVector<LoadsState> States;
+        for (unsigned Cnt = 0, End = VL.size(); Cnt + VF <= End;
+             Cnt += VF, ++VectorizedCnt) {
+          ArrayRef<Value *> Slice = VL.slice(Cnt, VF);
+          SmallVector<unsigned> Order;
+          SmallVector<Value *> PointerOps;
+          LoadsState LS =
+              canVectorizeLoads(R, Slice, Slice.front(), TTI, DL, SE, LI, TLI,
+                                Order, PointerOps, /*TryRecursiveCheck=*/false);
+          // Check that the sorted loads are consecutive.
+          if (LS != LoadsState::Vectorize && LS != LoadsState::StridedVectorize)
+            break;
+          States.push_back(LS);
+        }
+        // Can be vectorized later as a series of loads/insertelements.
+        if (VectorizedCnt == VL.size() / VF) {
+          // Compare masked gather cost and loads + insertsubvector costs.
+          TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+          InstructionCost MaskedGatherCost = TTI.getGatherScatterOpCost(
+              Instruction::Load, VecTy,
+              cast<LoadInst>(VL0)->getPointerOperand(),
+              /*VariableMask=*/false, CommonAlignment, CostKind);
+          InstructionCost VecLdCost = 0;
+          auto *SubVecTy = FixedVectorType::get(ScalarTy, VF);
+          for (auto [I, LS] : enumerate(States)) {
+            auto *LI0 = cast<LoadInst>(VL[I * VF]);
+            switch (LS) {
+            case LoadsState::Vectorize:
+              VecLdCost += TTI.getMemoryOpCost(
+                  Instruction::Load, SubVecTy, LI0->getAlign(),
+                  LI0->getPointerAddressSpace(), CostKind,
+                  TTI::OperandValueInfo());
+              break;
+            case LoadsState::StridedVectorize:
+              VecLdCost += TTI.getStridedMemoryOpCost(
+                  Instruction::Load, SubVecTy, LI0->getPointerOperand(),
+                  /*VariableMask=*/false, CommonAlignment, CostKind);
+              break;
+            case LoadsState::ScatterVectorize:
+            case LoadsState::Gather:
+              llvm_unreachable("Expected only consecutive or strided loads.");
+            }
+            VecLdCost +=
+                TTI.getShuffleCost(TTI ::SK_InsertSubvector, VecTy,
+                                   std::nullopt, CostKind, I * VF, SubVecTy);
+          }
+          // If masked gather cost is higher - better to vectorize, so
+          // consider it as a gather node. It will be better estimated
+          // later.
+          if (MaskedGatherCost > VecLdCost)
+            return true;
+        }
+      }
+      return false;
+    };
     // TODO: need to improve analysis of the pointers, if not all of them are
     // GEPs or have > 2 operands, we end up with a gather node, which just
     // increases the cost.
@@ -4111,8 +4175,17 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
         })) {
       Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
       if (TTI.isLegalMaskedGather(VecTy, CommonAlignment) &&
-          !TTI.forceScalarizeMaskedGather(VecTy, CommonAlignment))
+          !TTI.forceScalarizeMaskedGather(VecTy, CommonAlignment)) {
+        // Check if potential masked gather can be represented as series
+        // of loads + insertsubvectors.
+        if (TryRecursiveCheck && CheckForShuffledLoads(CommonAlignment)) {
+          // If masked gather cost is higher - better to vectorize, so
+          // consider it as a gather node. It will be better estimated
+          // later.
+          return LoadsState::Gather;
+        }
         return LoadsState::ScatterVectorize;
+      }
@@ -5560,8 +5633,8 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
     // treats loading/storing it as an i8 struct. If we vectorize loads/stores
     // from such a struct, we read/write packed bits disagreeing with the
     // unvectorized version.
-    switch (canVectorizeLoads(VL, VL0, *TTI, *DL, *SE, *LI, *TLI, CurrentOrder,
-                              PointerOps)) {
+    switch (canVectorizeLoads(*this, VL, VL0, *TTI, *DL, *SE, *LI, *TLI,
+                              CurrentOrder, PointerOps)) {
     case LoadsState::Vectorize:
       return TreeEntry::Vectorize;
     case LoadsState::ScatterVectorize:
@@ -7341,9 +7414,9 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
               !VectorizedLoads.count(Slice.back()) && allSameBlock(Slice)) {
             SmallVector<Value *> PointerOps;
             OrdersType CurrentOrder;
-            LoadsState LS =
-                canVectorizeLoads(Slice, Slice.front(), TTI, *R.DL, *R.SE,
-                                  *R.LI, *R.TLI, CurrentOrder, PointerOps);
+            LoadsState LS = canVectorizeLoads(R, Slice, Slice.front(), TTI,
+                                              *R.DL, *R.SE, *R.LI, *R.TLI,
+                                              CurrentOrder, PointerOps);
             switch (LS) {
             case LoadsState::Vectorize:
             case LoadsState::ScatterVectorize:
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/scatter-vectorize-reused-pointer.ll b/llvm/test/Transforms/SLPVectorizer/X86/scatter-vectorize-reused-pointer.ll
index bb16b52f44ecf7..dadf5992ba288d 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/scatter-vectorize-reused-pointer.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/scatter-vectorize-reused-pointer.ll
@@ -5,19 +5,23 @@ define void @test(i1 %c, ptr %arg) {
 ; CHECK-LABEL: @test(
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <4 x ptr> poison, ptr [[ARG:%.*]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x ptr> [[TMP1]], <4 x ptr> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, <4 x ptr> [[TMP2]], <4 x i64> <i64 32, i64 24, i64 8, i64 0>
-; CHECK-NEXT:    [[TMP4:%.*]] = call <4 x i64> @llvm.masked.gather.v4i64.v4p0(<4 x ptr> [[TMP3]], i32 8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i64> poison)
+; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x i64>, ptr [[ARG:%.*]], align 8
+; CHECK-NEXT:    [[ARG2_2:%.*]] = getelementptr inbounds i8, ptr [[ARG]], i64 24
+; CHECK-NEXT:    [[TMP2:%.*]] = load <2 x i64>, ptr [[ARG2_2]], align 8
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <2 x i64> [[TMP2]], <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> [[TMP4]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
 ; CHECK-NEXT:    br label [[JOIN:%.*]]
 ; CHECK:       else:
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <4 x ptr> poison, ptr [[ARG]], i32 0
-; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <4 x ptr> [[TMP5]], <4 x ptr> poison, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, <4 x ptr> [[TMP6]], <4 x i64> <i64 32, i64 24, i64 8, i64 0>
-; CHECK-NEXT:    [[TMP8:%.*]] = call <4 x i64> @llvm.masked.gather.v4i64.v4p0(<4 x ptr> [[TMP7]], i32 8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i64> poison)
+; CHECK-NEXT:    [[TMP6:%.*]] = load <2 x i64>, ptr [[ARG]], align 8
+; CHECK-NEXT:    [[ARG_2:%.*]] = getelementptr inbounds i8, ptr [[ARG]], i64 24
+; CHECK-NEXT:    [[TMP7:%.*]] = load <2 x i64>, ptr [[ARG_2]], align 8
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <2 x i64> [[TMP7]], <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <2 x i64> [[TMP6]], <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
+; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <4 x i64> [[TMP8]], <4 x i64> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
 ; CHECK-NEXT:    br label [[JOIN]]
 ; CHECK:       join:
-; CHECK-NEXT:    [[TMP9:%.*]] = phi <4 x i64> [ [[TMP4]], [[IF]] ], [ [[TMP8]], [[ELSE]] ]
+; CHECK-NEXT:    [[TMP11:%.*]] = phi <4 x i64> [ [[TMP5]], [[IF]] ], [ [[TMP10]], [[ELSE]] ]
 ; CHECK-NEXT:    ret void
   br i1 %c, label %if, label %else

>From 520654af88b993da013e2473da2566e28254ddfb Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Thu, 29 Feb 2024 21:16:11 +0000
Subject: [PATCH 2/2] Fix formatting

Created using spr 1.3.5
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f817d1a304b5b8..b417b476f6c815 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7414,9 +7414,9 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
               !VectorizedLoads.count(Slice.back()) && allSameBlock(Slice)) {
             SmallVector<Value *> PointerOps;
             OrdersType CurrentOrder;
-            LoadsState LS = canVectorizeLoads(R, Slice, Slice.front(), TTI,
-                                              *R.DL, *R.SE, *R.LI, *R.TLI,
-                                              CurrentOrder, PointerOps);
+            LoadsState LS =
+                canVectorizeLoads(R, Slice, Slice.front(), TTI, *R.DL, *R.SE,
+                                  *R.LI, *R.TLI, CurrentOrder, PointerOps);
             switch (LS) {
             case LoadsState::Vectorize:
             case LoadsState::ScatterVectorize:

More information about the llvm-commits mailing list