[llvm] 87e4b68 - [SandboxVec][Legality] Implement ShuffleMask (#123404)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 17 15:48:28 PST 2025
Author: vporpo
Date: 2025-01-17T15:48:24-08:00
New Revision: 87e4b68195adc81fae40a4fa27e33458a9586fe5
URL: https://github.com/llvm/llvm-project/commit/87e4b68195adc81fae40a4fa27e33458a9586fe5
DIFF: https://github.com/llvm/llvm-project/commit/87e4b68195adc81fae40a4fa27e33458a9586fe5.diff
LOG: [SandboxVec][Legality] Implement ShuffleMask (#123404)
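This patch implements a helper ShuffleMask data structure that helps
describe shuffles of elements across lanes. It uses the mask to add a new
DiamondReuseWithShuffle legality result. With it, BottomUpVec can reuse an
existing vector through a shufflevector when the bundle's lanes are extracted
from that vector out of order.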
Added:
Modified:
llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index c03e7a10397ad2..4858ebaf0770aa 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -25,10 +25,62 @@ class LegalityAnalysis;
class Value;
class InstrMaps;
+/// Describes a permutation of vector lanes: entry I of the mask is the index of
+/// the source lane that ends up in lane I of the result (e.g. {1, 0} swaps the
+/// two lanes of a 2-wide vector).
+class ShuffleMask {
+public:
+  using IndicesVecT = SmallVector<int, 8>;
+
+private:
+  /// One source-lane index per result lane.
+  IndicesVecT Indices;
+
+public:
+  ShuffleMask(SmallVectorImpl<int> &&Indices) : Indices(std::move(Indices)) {}
+  ShuffleMask(std::initializer_list<int> Indices) : Indices(Indices) {}
+  explicit ShuffleMask(ArrayRef<int> Indices) : Indices(Indices) {}
+  /// Implicitly exposes the indices as an ArrayRef<int>.
+  operator ArrayRef<int>() const { return Indices; }
+  /// Creates and returns an identity shuffle mask of size \p Sz.
+  /// For example if Sz == 4 the returned mask is {0, 1, 2, 3}.
+  static ShuffleMask getIdentity(unsigned Sz) {
+    IndicesVecT Lanes;
+    Lanes.reserve(Sz);
+    for (int Lane : seq<int>(0, (int)Sz))
+      Lanes.push_back(Lane);
+    return ShuffleMask(std::move(Lanes));
+  }
+  /// \Returns true if the mask is a perfect identity mask with consecutive
+  /// indices, i.e., performs no lane shuffling, like 0,1,2,3...
+  /// An empty mask is considered an identity.
+  bool isIdentity() const {
+    int ExpectedLane = 0;
+    for (int Elm : Indices)
+      if (Elm != ExpectedLane++)
+        return false;
+    return true;
+  }
+  bool operator==(const ShuffleMask &Other) const {
+    return Indices == Other.Indices;
+  }
+  bool operator!=(const ShuffleMask &Other) const {
+    return Indices != Other.Indices;
+  }
+  /// \Returns the number of lanes described by the mask.
+  size_t size() const { return Indices.size(); }
+  /// \Returns the source lane used for result lane \p Idx.
+  int operator[](int Idx) const { return Indices[Idx]; }
+  using const_iterator = IndicesVecT::const_iterator;
+  const_iterator begin() const { return Indices.begin(); }
+  const_iterator end() const { return Indices.end(); }
+#ifndef NDEBUG
+  friend raw_ostream &operator<<(raw_ostream &OS, const ShuffleMask &Mask) {
+    Mask.print(OS);
+    return OS;
+  }
+  /// Prints the indices separated by commas, e.g. "0,1,2,3".
+  void print(raw_ostream &OS) const { interleave(Indices, OS, ","); }
+  LLVM_DUMP_METHOD void dump() const;
+#endif
+};
+
enum class LegalityResultID {
- Pack, ///> Collect scalar values.
- Widen, ///> Vectorize by combining scalars to a vector.
- DiamondReuse, ///> Don't generate new code, reuse existing vector.
+  Pack,                    ///< Collect scalar values.
+  Widen,                   ///< Vectorize by combining scalars to a vector.
+  DiamondReuse,            ///< Don't generate new code, reuse existing vector.
+  DiamondReuseWithShuffle, ///< Reuse the existing vector but add a shuffle.
};
/// The reason for vectorizing or not vectorizing.
@@ -54,6 +106,8 @@ struct ToStr {
return "Widen";
case LegalityResultID::DiamondReuse:
return "DiamondReuse";
+ case LegalityResultID::DiamondReuseWithShuffle:
+ return "DiamondReuseWithShuffle";
}
llvm_unreachable("Unknown LegalityResultID enum");
}
@@ -154,6 +208,22 @@ class DiamondReuse final : public LegalityResult {
Value *getVector() const { return Vec; }
};
+/// Legality result: the bundle can be produced by reusing an existing vector
+/// after permuting its lanes with a shuffle described by a ShuffleMask.
+class DiamondReuseWithShuffle final : public LegalityResult {
+  // Only LegalityAnalysis may construct this result.
+  friend class LegalityAnalysis;
+  /// The vector being reused.
+  Value *Vec;
+  /// The lane permutation to apply to Vec.
+  ShuffleMask Mask;
+  DiamondReuseWithShuffle(Value *V, const ShuffleMask &M)
+      : LegalityResult(LegalityResultID::DiamondReuseWithShuffle), Vec(V),
+        Mask(M) {}
+
+public:
+  /// LLVM-style RTTI support for isa<>, cast<> and dyn_cast<>.
+  static bool classof(const LegalityResult *R) {
+    return R->getSubclassID() == LegalityResultID::DiamondReuseWithShuffle;
+  }
+  /// \Returns the vector to be reused.
+  Value *getVector() const { return Vec; }
+  /// \Returns the shuffle mask to apply to the reused vector.
+  const ShuffleMask &getMask() const { return Mask; }
+};
+
class Pack final : public LegalityResultWithReason {
Pack(ResultReason Reason)
: LegalityResultWithReason(LegalityResultID::Pack, Reason) {}
@@ -192,23 +262,22 @@ class CollectDescr {
CollectDescr(SmallVectorImpl<ExtractElementDescr> &&Descrs)
: Descrs(std::move(Descrs)) {}
/// If all elements come from a single vector input, then return that vector
- /// and whether we need a shuffle to get them in order.
- std::optional<std::pair<Value *, bool>> getSingleInput() const {
+  /// and also the shuffle mask required to get them in order: entry I of the
+  /// mask is the lane of that vector that element I is extracted from.
+  /// \Returns std::nullopt if there are no elements, if any element does not
+  /// need an extract, or if the elements come from different vectors.
+  std::optional<std::pair<Value *, ShuffleMask>> getSingleInput() const {
+    // Guard the dereference of Descrs.begin() below.
+    if (Descrs.empty())
+      return std::nullopt;
const auto &Descr0 = *Descrs.begin();
Value *V0 = Descr0.getValue();
if (!Descr0.needsExtract())
return std::nullopt;
- bool NeedsShuffle = Descr0.getExtractIdx() != 0;
- int Lane = 1;
+    ShuffleMask::IndicesVecT MaskIndices;
+    MaskIndices.reserve(Descrs.size());
+    MaskIndices.push_back(Descr0.getExtractIdx());
for (const auto &Descr : drop_begin(Descrs)) {
if (!Descr.needsExtract())
return std::nullopt;
if (Descr.getValue() != V0)
return std::nullopt;
- if (Descr.getExtractIdx() != Lane++)
- NeedsShuffle = true;
+      MaskIndices.push_back(Descr.getExtractIdx());
}
- return std::make_pair(V0, NeedsShuffle);
+    return std::make_pair(V0, ShuffleMask(std::move(MaskIndices)));
}
bool hasVectorInputs() const {
return any_of(Descrs, [](const auto &D) { return D.needsExtract(); });
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index dd3012f7c9b556..ac051c3b6570ff 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -36,6 +36,8 @@ class BottomUpVec final : public FunctionPass {
/// Erases all dead instructions from the dead instruction candidates
/// collected during vectorization.
void tryEraseDeadInstrs();
+ /// Creates a shuffle instruction that shuffles \p VecOp according to \p Mask.
+ /// \p VecOp is used as both operands of the shuffle, so each mask entry
+ /// selects a lane of \p VecOp. \Returns the newly created shuffle.
+ Value *createShuffle(Value *VecOp, const ShuffleMask &Mask);
/// Packs all elements of \p ToPack into a vector and returns that vector.
Value *createPack(ArrayRef<Value *> ToPack);
void collectPotentiallyDeadInstrs(ArrayRef<Value *> Bndl);
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index f8149c5bc66363..ad3e38e2f1d923 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -20,6 +20,11 @@ namespace llvm::sandboxir {
#define DEBUG_TYPE "SBVec:Legality"
#ifndef NDEBUG
+/// Prints the mask to dbgs(), followed by a newline.
+void ShuffleMask::dump() const { dbgs() << *this << "\n"; }
+
void LegalityResult::dump() const {
print(dbgs());
dbgs() << "\n";
@@ -213,13 +218,12 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
auto CollectDescrs = getHowToCollectValues(Bndl);
if (CollectDescrs.hasVectorInputs()) {
if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
- auto [Vec, NeedsShuffle] = *ValueShuffleOpt;
- if (!NeedsShuffle)
+ auto [Vec, Mask] = *ValueShuffleOpt;
+ if (Mask.isIdentity())
return createLegalityResult<DiamondReuse>(Vec);
- llvm_unreachable("TODO: Unimplemented");
- } else {
- llvm_unreachable("TODO: Unimplemented");
+ return createLegalityResult<DiamondReuseWithShuffle>(Vec, Mask);
}
+ llvm_unreachable("TODO: Unimplemented");
}
if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index b8e2697839a3c2..d62023ea018846 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -179,6 +179,12 @@ void BottomUpVec::tryEraseDeadInstrs() {
DeadInstrCandidates.clear();
}
+/// Emits a shufflevector named "VShuf" that permutes the lanes of \p VecOp
+/// according to \p Mask, using \p VecOp as both shuffle operands.
+Value *BottomUpVec::createShuffle(Value *VecOp, const ShuffleMask &Mask) {
+  // The insertion point is the one getInsertPointAfterInstrs() picks for VecOp.
+  BasicBlock::iterator InsertPt = getInsertPointAfterInstrs({VecOp});
+  auto *Shuffle = ShuffleVectorInst::create(VecOp, VecOp, Mask, InsertPt,
+                                            VecOp->getContext(), "VShuf");
+  return Shuffle;
+}
+
Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack);
@@ -295,6 +301,13 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
NewVec = cast<DiamondReuse>(LegalityRes).getVector();
break;
}
+ case LegalityResultID::DiamondReuseWithShuffle: {
+ auto *VecOp = cast<DiamondReuseWithShuffle>(LegalityRes).getVector();
+ const ShuffleMask &Mask =
+ cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
+ NewVec = createShuffle(VecOp, Mask);
+ break;
+ }
case LegalityResultID::Pack: {
// If we can't vectorize the seeds then just return.
if (Depth == 0)
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index 7bc6e5ac3d7605..a3798af8399087 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -221,3 +221,24 @@ define void @diamond(ptr %ptr) {
store float %sub1, ptr %ptr1
ret void
}
+
+; The second fsub operands {%ld1, %ld0} are the lanes of the {%ld0, %ld1}
+; load vector in reverse order, so they are expected to be reused via a
+; <1, 0> shufflevector of that vector rather than being repacked.
+define void @diamondWithShuffle(ptr %ptr) {
+; CHECK-LABEL: define void @diamondWithShuffle(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT: [[VSHUF:%.*]] = shufflevector <2 x float> [[VECL]], <2 x float> [[VECL]], <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT: [[VEC:%.*]] = fsub <2 x float> [[VECL]], [[VSHUF]]
+; CHECK-NEXT: store <2 x float> [[VEC]], ptr [[PTR0]], align 4
+; CHECK-NEXT: ret void
+;
+ %ptr0 = getelementptr float, ptr %ptr, i32 0
+ %ptr1 = getelementptr float, ptr %ptr, i32 1
+ %ld0 = load float, ptr %ptr0
+ %ld1 = load float, ptr %ptr1
+ %sub0 = fsub float %ld0, %ld1
+ %sub1 = fsub float %ld1, %ld0
+ store float %sub0, ptr %ptr0
+ store float %sub1, ptr %ptr1
+ ret void
+}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 069bfdba0a7cdb..b421d08bc6b020 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -19,6 +19,7 @@
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
+#include "gmock/gmock.h"
#include "gtest/gtest.h"
using namespace llvm;
@@ -321,7 +322,7 @@ define void @foo(ptr %ptr) {
sandboxir::CollectDescr CD(std::move(Descrs));
EXPECT_TRUE(CD.getSingleInput());
EXPECT_EQ(CD.getSingleInput()->first, VLd);
- EXPECT_EQ(CD.getSingleInput()->second, false);
+ EXPECT_THAT(CD.getSingleInput()->second, testing::ElementsAre(0, 1));
EXPECT_TRUE(CD.hasVectorInputs());
}
{
@@ -331,7 +332,7 @@ define void @foo(ptr %ptr) {
sandboxir::CollectDescr CD(std::move(Descrs));
EXPECT_TRUE(CD.getSingleInput());
EXPECT_EQ(CD.getSingleInput()->first, VLd);
- EXPECT_EQ(CD.getSingleInput()->second, true);
+ EXPECT_THAT(CD.getSingleInput()->second, testing::ElementsAre(1, 0));
EXPECT_TRUE(CD.hasVectorInputs());
}
{
@@ -352,3 +353,95 @@ define void @foo(ptr %ptr) {
EXPECT_FALSE(CD.hasVectorInputs());
}
}
+
+TEST_F(LegalityTest, ShuffleMask) {
+  {
+    // Check SmallVector constructor.
+    SmallVector<int> Indices({0, 1, 2, 3});
+    sandboxir::ShuffleMask Mask(std::move(Indices));
+    EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
+  }
+  {
+    // Check initializer_list constructor.
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
+  }
+  {
+    // Check ArrayRef constructor.
+    sandboxir::ShuffleMask Mask(ArrayRef<int>({0, 1, 2, 3}));
+    EXPECT_THAT(Mask, testing::ElementsAre(0, 1, 2, 3));
+  }
+  {
+    // Check operator ArrayRef<int>().
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    ArrayRef<int> Array = Mask;
+    EXPECT_THAT(Array, testing::ElementsAre(0, 1, 2, 3));
+  }
+  {
+    // Check getIdentity().
+    auto IdentityMask = sandboxir::ShuffleMask::getIdentity(4);
+    EXPECT_THAT(IdentityMask, testing::ElementsAre(0, 1, 2, 3));
+    EXPECT_TRUE(IdentityMask.isIdentity());
+  }
+  {
+    // Check getIdentity() with a zero size: the empty mask is an identity.
+    auto EmptyMask = sandboxir::ShuffleMask::getIdentity(0);
+    EXPECT_EQ(EmptyMask.size(), 0u);
+    EXPECT_TRUE(EmptyMask.isIdentity());
+  }
+  {
+    // Check isIdentity().
+    sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
+    EXPECT_TRUE(Mask1.isIdentity());
+    sandboxir::ShuffleMask Mask2({1, 2, 3, 4});
+    EXPECT_FALSE(Mask2.isIdentity());
+  }
+  {
+    // Check operator==().
+    sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
+    sandboxir::ShuffleMask Mask2({0, 1, 2, 3});
+    EXPECT_TRUE(Mask1 == Mask2);
+    EXPECT_FALSE(Mask1 != Mask2);
+  }
+  {
+    // Check operator!=().
+    sandboxir::ShuffleMask Mask1({0, 1, 2, 3});
+    sandboxir::ShuffleMask Mask2({0, 1, 2, 4});
+    EXPECT_TRUE(Mask1 != Mask2);
+    EXPECT_FALSE(Mask1 == Mask2);
+  }
+  {
+    // Check size().
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    EXPECT_EQ(Mask.size(), 4u);
+  }
+  {
+    // Check operator[].
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    for (auto [Idx, Elm] : enumerate(Mask)) {
+      EXPECT_EQ(Elm, Mask[Idx]);
+    }
+  }
+  {
+    // Check begin(), end().
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    sandboxir::ShuffleMask::const_iterator Begin = Mask.begin();
+    sandboxir::ShuffleMask::const_iterator End = Mask.end();
+    int Idx = 0;
+    for (auto It = Begin; It != End; ++It) {
+      EXPECT_EQ(*It, Mask[Idx++]);
+    }
+    // Make sure the loop actually visited every element.
+    EXPECT_EQ(Idx, (int)Mask.size());
+  }
+#ifndef NDEBUG
+  {
+    // Check print(OS).
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    std::string Str;
+    raw_string_ostream OS(Str);
+    Mask.print(OS);
+    EXPECT_EQ(Str, "0,1,2,3");
+  }
+  {
+    // Check operator<<().
+    sandboxir::ShuffleMask Mask({0, 1, 2, 3});
+    std::string Str;
+    raw_string_ostream OS(Str);
+    OS << Mask;
+    EXPECT_EQ(Str, "0,1,2,3");
+  }
+#endif // NDEBUG
+}
More information about the llvm-commits
mailing list