[llvm] [SandboxVec][Legality] Implement ShuffleMask (PR #123404)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 17 13:46:33 PST 2025


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/123404

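This patch implements ShuffleMask, a helper data structure that describes how vector elements are shuffled across lanes. It is used by the new DiamondReuseWithShuffle legality result, which reuses an existing vector through a shufflevector when the bundle's values are all found in that vector but not in lane order.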

>From d3a34d7fd8d5d96ba329f345bb14f1794c8297d9 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Tue, 19 Nov 2024 09:29:16 -0800
Subject: [PATCH] [SandboxVec][Legality] Implement ShuffleMask

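This patch implements ShuffleMask, a helper data structure that
describes how vector elements are shuffled across lanes. It is used by
the new DiamondReuseWithShuffle legality result, which reuses an
existing vector through a shufflevector when the bundle's values are
all found in that vector but not in lane order.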
---
 .../Vectorize/SandboxVectorizer/Legality.h    | 89 +++++++++++++++--
 .../SandboxVectorizer/Passes/BottomUpVec.h    |  2 +
 .../Vectorize/SandboxVectorizer/Legality.cpp  | 14 ++-
 .../SandboxVectorizer/Passes/BottomUpVec.cpp  | 13 +++
 .../SandboxVectorizer/bottomup_basic.ll       | 21 ++++
 .../SandboxVectorizer/LegalityTest.cpp        | 97 ++++++++++++++++++-
 6 files changed, 219 insertions(+), 17 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index c03e7a10397ad2..4858ebaf0770aa 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -25,10 +25,73 @@ class LegalityAnalysis;
 class Value;
 class InstrMaps;
 
+/// Describes a shuffle of vector elements across lanes: element I of the mask
+/// is the index of the input element that ends up in lane I of the result.
+class ShuffleMask {
+public:
+  using IndicesVecT = SmallVector<int, 8>;
+
+private:
+  IndicesVecT Indices;
+
+public:
+  /// Creates a mask that takes ownership of \p Indices.
+  ShuffleMask(SmallVectorImpl<int> &&Indices) : Indices(std::move(Indices)) {}
+  /// Creates a mask from a braced list, e.g. ShuffleMask({1, 0}).
+  ShuffleMask(std::initializer_list<int> Indices) : Indices(Indices) {}
+  /// Creates a mask by copying \p Indices.
+  explicit ShuffleMask(ArrayRef<int> Indices) : Indices(Indices) {}
+  /// Implicit conversion, so a mask can be passed where ArrayRef<int> is used.
+  operator ArrayRef<int>() const { return Indices; }
+  /// Creates and returns an identity shuffle mask of size \p Sz.
+  /// For example if Sz == 4 the returned mask is {0, 1, 2, 3}.
+  static ShuffleMask getIdentity(unsigned Sz) {
+    IndicesVecT Indices;
+    Indices.reserve(Sz);
+    for (auto Idx : seq<int>(0, static_cast<int>(Sz)))
+      Indices.push_back(Idx);
+    return ShuffleMask(std::move(Indices));
+  }
+  /// \Returns true if the mask is a perfect identity mask with consecutive
+  /// indices, i.e., performs no lane shuffling, like 0,1,2,3...
+  /// Note that an empty mask is considered an identity mask.
+  bool isIdentity() const {
+    for (auto [Idx, Elm] : enumerate(Indices)) {
+      if (static_cast<int>(Idx) != Elm)
+        return false;
+    }
+    return true;
+  }
+  /// Masks are equal if they hold the same indices in the same order.
+  bool operator==(const ShuffleMask &Other) const {
+    return Indices == Other.Indices;
+  }
+  bool operator!=(const ShuffleMask &Other) const { return !(*this == Other); }
+  /// \Returns the number of indices in the mask.
+  size_t size() const { return Indices.size(); }
+  /// \Returns the mask index at position \p Idx.
+  int operator[](int Idx) const { return Indices[Idx]; }
+  using const_iterator = IndicesVecT::const_iterator;
+  const_iterator begin() const { return Indices.begin(); }
+  const_iterator end() const { return Indices.end(); }
+#ifndef NDEBUG
+  friend raw_ostream &operator<<(raw_ostream &OS, const ShuffleMask &Mask) {
+    Mask.print(OS);
+    return OS;
+  }
+  /// Prints the indices separated by commas, e.g. "0,1,2,3".
+  void print(raw_ostream &OS) const {
+    interleave(Indices, OS, [&OS](auto Elm) { OS << Elm; }, ",");
+  }
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
 enum class LegalityResultID {
-  Pack,         ///> Collect scalar values.
-  Widen,        ///> Vectorize by combining scalars to a vector.
-  DiamondReuse, ///> Don't generate new code, reuse existing vector.
+  Pack,                    ///< Collect scalar values.
+  Widen,                   ///< Vectorize by combining scalars to a vector.
+  DiamondReuse,            ///< Don't generate new code, reuse existing vector.
+  DiamondReuseWithShuffle, ///< Reuse the existing vector but add a shuffle.
 };
 
 /// The reason for vectorizing or not vectorizing.
@@ -54,6 +106,8 @@ struct ToStr {
       return "Widen";
     case LegalityResultID::DiamondReuse:
       return "DiamondReuse";
+    case LegalityResultID::DiamondReuseWithShuffle:
+      return "DiamondReuseWithShuffle";
     }
     llvm_unreachable("Unknown LegalityResultID enum");
   }
@@ -154,6 +208,29 @@ class DiamondReuse final : public LegalityResult {
   Value *getVector() const { return Vec; }
 };
 
+/// Legality result for a bundle whose values are all extracted from a single
+/// existing vector, but not in lane order: the vector can be reused once its
+/// lanes are shuffled according to the stored mask.
+class DiamondReuseWithShuffle final : public LegalityResult {
+  friend class LegalityAnalysis;
+  /// The existing vector that holds the bundle's values.
+  Value *Vec;
+  /// Element I holds the lane of Vec whose value belongs in lane I.
+  ShuffleMask Mask;
+  DiamondReuseWithShuffle(Value *Vec, ShuffleMask Mask)
+      : LegalityResult(LegalityResultID::DiamondReuseWithShuffle), Vec(Vec),
+        Mask(std::move(Mask)) {}
+
+public:
+  static bool classof(const LegalityResult *From) {
+    return From->getSubclassID() == LegalityResultID::DiamondReuseWithShuffle;
+  }
+  /// \Returns the existing vector to be shuffled.
+  Value *getVector() const { return Vec; }
+  /// \Returns the mask that reorders the lanes of the vector.
+  const ShuffleMask &getMask() const { return Mask; }
+};
+
 class Pack final : public LegalityResultWithReason {
   Pack(ResultReason Reason)
       : LegalityResultWithReason(LegalityResultID::Pack, Reason) {}
@@ -192,23 +262,28 @@ class CollectDescr {
   CollectDescr(SmallVectorImpl<ExtractElementDescr> &&Descrs)
       : Descrs(std::move(Descrs)) {}
   /// If all elements come from a single vector input, then return that vector
-  /// and whether we need a shuffle to get them in order.
-  std::optional<std::pair<Value *, bool>> getSingleInput() const {
+  /// and also the shuffle mask required to get them in order: mask index I is
+  /// the extract index (lane of that vector) of the I-th element.
+  /// Returns std::nullopt if there are no elements, if any element does not
+  /// need an extract, or if the elements come from different vectors.
+  std::optional<std::pair<Value *, ShuffleMask>> getSingleInput() const {
+    if (Descrs.empty())
+      return std::nullopt;
     const auto &Descr0 = *Descrs.begin();
     Value *V0 = Descr0.getValue();
     if (!Descr0.needsExtract())
       return std::nullopt;
-    bool NeedsShuffle = Descr0.getExtractIdx() != 0;
-    int Lane = 1;
+    ShuffleMask::IndicesVecT MaskIndices;
+    MaskIndices.reserve(Descrs.size());
+    MaskIndices.push_back(Descr0.getExtractIdx());
     for (const auto &Descr : drop_begin(Descrs)) {
       if (!Descr.needsExtract())
         return std::nullopt;
       if (Descr.getValue() != V0)
         return std::nullopt;
-      if (Descr.getExtractIdx() != Lane++)
-        NeedsShuffle = true;
+      MaskIndices.push_back(Descr.getExtractIdx());
     }
-    return std::make_pair(V0, NeedsShuffle);
+    return std::make_pair(V0, ShuffleMask(std::move(MaskIndices)));
   }
   bool hasVectorInputs() const {
     return any_of(Descrs, [](const auto &D) { return D.needsExtract(); });
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index dd3012f7c9b556..ac051c3b6570ff 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -36,6 +36,8 @@ class BottomUpVec final : public FunctionPass {
   /// Erases all dead instructions from the dead instruction candidates
   /// collected during vectorization.
   void tryEraseDeadInstrs();
+  /// Creates a shuffle instruction that shuffles \p VecOp according to \p Mask.
+  Value *createShuffle(Value *VecOp, const ShuffleMask &Mask);
   /// Packs all elements of \p ToPack into a vector and returns that vector.
   Value *createPack(ArrayRef<Value *> ToPack);
   void collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl);
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index f8149c5bc66363..ad3e38e2f1d923 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -20,6 +20,13 @@ namespace llvm::sandboxir {
 #define DEBUG_TYPE "SBVec:Legality"
 
 #ifndef NDEBUG
+/// Debugging helper: prints the mask to dbgs(), followed by a newline.
+void ShuffleMask::dump() const {
+  raw_ostream &OS = dbgs();
+  print(OS);
+  OS << "\n";
+}
+
 void LegalityResult::dump() const {
   print(dbgs());
   dbgs() << "\n";
@@ -213,13 +218,12 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
   auto CollectDescrs = getHowToCollectValues(Bndl);
   if (CollectDescrs.hasVectorInputs()) {
     if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
-      auto [Vec, NeedsShuffle] = *ValueShuffleOpt;
-      if (!NeedsShuffle)
+      auto [Vec, Mask] = *ValueShuffleOpt;
+      if (Mask.isIdentity())
         return createLegalityResult<DiamondReuse>(Vec);
-      llvm_unreachable("TODO: Unimplemented");
-    } else {
-      llvm_unreachable("TODO: Unimplemented");
+      return createLegalityResult<DiamondReuseWithShuffle>(Vec, Mask);
     }
+    llvm_unreachable("TODO: Unimplemented");
   }
 
   if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index b8e2697839a3c2..d62023ea018846 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -179,6 +179,15 @@ void BottomUpVec::tryEraseDeadInstrs() {
   DeadInstrCandidates.clear();
 }
 
+// Shuffles the lanes of VecOp according to Mask. VecOp is passed as both
+// shufflevector operands and the new instruction is inserted at the position
+// returned by getInsertPointAfterInstrs({VecOp}).
+Value *BottomUpVec::createShuffle(Value *VecOp, const ShuffleMask &Mask) {
+  auto WhereIt = getInsertPointAfterInstrs({VecOp});
+  auto &Ctx = VecOp->getContext();
+  return ShuffleVectorInst::create(VecOp, VecOp, Mask, WhereIt, Ctx, "VShuf");
+}
+
 Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
   BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack);
 
@@ -295,6 +301,13 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
     NewVec = cast<DiamondReuse>(LegalityRes).getVector();
     break;
   }
+  case LegalityResultID::DiamondReuseWithShuffle: {
+    auto *VecOp = cast<DiamondReuseWithShuffle>(LegalityRes).getVector();
+    const ShuffleMask &Mask =
+        cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
+    NewVec = createShuffle(VecOp, Mask);
+    break;
+  }
   case LegalityResultID::Pack: {
     // If we can't vectorize the seeds then just return.
     if (Depth == 0)
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index 7bc6e5ac3d7605..a3798af8399087 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -221,3 +221,26 @@ define void @diamond(ptr %ptr) {
   store float %sub1, ptr %ptr1
   ret void
 }
+
+; The fsubs' second operands (%ld1, %ld0) are the vectorized loads in reverse
+; lane order, so they are obtained by shuffling the load with mask <1, 0>.
+define void @diamondWithShuffle(ptr %ptr) {
+; CHECK-LABEL: define void @diamondWithShuffle(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[VSHUF:%.*]] = shufflevector <2 x float> [[VECL]], <2 x float> [[VECL]], <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT:    [[VEC:%.*]] = fsub <2 x float> [[VECL]], [[VSHUF]]
+; CHECK-NEXT:    store <2 x float> [[VEC]], ptr [[PTR0]], align 4
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ld0 = load float, ptr %ptr0
+  %ld1 = load float, ptr %ptr1
+  %sub0 = fsub float %ld0, %ld1
+  %sub1 = fsub float %ld1, %ld0
+  store float %sub0, ptr %ptr0
+  store float %sub1, ptr %ptr1
+  ret void
+}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 069bfdba0a7cdb..b421d08bc6b020 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -19,6 +19,7 @@
 #include "llvm/SandboxIR/Instruction.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -321,7 +322,7 @@ define void @foo(ptr %ptr) {
     sandboxir::CollectDescr CD(std::move(Descrs));
     EXPECT_TRUE(CD.getSingleInput());
     EXPECT_EQ(CD.getSingleInput()->first, VLd);
-    EXPECT_EQ(CD.getSingleInput()->second, false);
+    EXPECT_THAT(CD.getSingleInput()->second, testing::ElementsAre(0, 1));
     EXPECT_TRUE(CD.hasVectorInputs());
   }
   {
@@ -331,7 +332,7 @@ define void @foo(ptr %ptr) {
     sandboxir::CollectDescr CD(std::move(Descrs));
     EXPECT_TRUE(CD.getSingleInput());
     EXPECT_EQ(CD.getSingleInput()->first, VLd);
-    EXPECT_EQ(CD.getSingleInput()->second, true);
+    EXPECT_THAT(CD.getSingleInput()->second, testing::ElementsAre(1, 0));
     EXPECT_TRUE(CD.hasVectorInputs());
   }
   {
@@ -352,3 +353,95 @@ define void @foo(ptr %ptr) {
     EXPECT_FALSE(CD.hasVectorInputs());
   }
 }
+
+TEST_F(LegalityTest, ShuffleMask) {
+  {
+    // Check SmallVector constructor.
+    SmallVector<int> Indices({0, 1, 2, 3});
+    sandboxir::ShuffleMask Mask(std::move(Indices));
+    EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
+  }
+  {
+    // Check initializer_list constructor.
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
+  }
+  {
+    // Check ArrayRef constructor.
+    sandboxir::ShuffleMask Mask(ArrayRef<int>({0, 1, 2, 3}));
+    EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
+  }
+  {
+    // Check operator ArrayRef<int>().
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    ArrayRef<int> Array = Mask;
+    EXPECT_THAT(Array, testing::ElementsAre(0, 1, 2, 3));
+  }
+  {
+    // Check getIdentity().
+    auto IdentityMask = sandboxir::ShuffleMask::getIdentity(4);
+    EXPECT_THAT(IdentityMask, testing::ElementsAre(0, 1, 2, 3));
+    EXPECT_TRUE(IdentityMask.isIdentity());
+  }
+  {
+    // Check isIdentity().
+    sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
+    EXPECT_TRUE(Mask1.isIdentity());
+    sandboxir::ShuffleMask Mask2({1, 2, 3, 4});
+    EXPECT_FALSE(Mask2.isIdentity());
+  }
+  {
+    // Check operator==().
+    sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
+    sandboxir::ShuffleMask Mask2({0, 1, 2, 3});
+    EXPECT_TRUE(Mask1 == Mask2);
+    EXPECT_FALSE(Mask1 != Mask2);
+  }
+  {
+    // Check operator!=().
+    sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
+    sandboxir::ShuffleMask Mask2({0, 1, 2, 4});
+    EXPECT_TRUE(Mask1 != Mask2);
+    EXPECT_FALSE(Mask1 == Mask2);
+  }
+  {
+    // Check size().
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    EXPECT_EQ(Mask.size(), 4u);
+  }
+  {
+    // Check operator[].
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    for (auto [Idx, Elm] : enumerate(Mask)) {
+      EXPECT_EQ(Elm, Mask[Idx]);
+    }
+  }
+  {
+    // Check begin(), end(): the loop must visit every element.
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    sandboxir::ShuffleMask::const_iterator Begin = Mask.begin();
+    sandboxir::ShuffleMask::const_iterator End = Mask.end();
+    int Idx = 0;
+    for (auto It = Begin; It != End; ++It)
+      EXPECT_EQ(*It, Mask[Idx++]);
+    EXPECT_EQ(Idx, 4);
+  }
+#ifndef NDEBUG
+  {
+    // Check print(OS).
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    std::string Str;
+    raw_string_ostream OS(Str);
+    Mask.print(OS);
+    EXPECT_EQ(Str, "0,1,2,3");
+  }
+  {
+    // Check operator<<().
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    std::string Str;
+    raw_string_ostream OS(Str);
+    OS << Mask;
+    EXPECT_EQ(Str, "0,1,2,3");
+  }
+#endif // NDEBUG
+}



More information about the llvm-commits mailing list