[llvm] [SandboxVec][Legality] Fix mask on diamond reuse with shuffle (PR #126963)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 12 11:55:12 PST 2025
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/126963
This patch fixes a bug in the creation of shuffle masks when vectorizing vectors in case of a diamond reuse with shuffle. The mask needs to enumerate all elements of a vector, not treat the original vector value as a single element. That is: if vectorizing two <2 x float> vectors into a <4 x float> the mask needs to have 4 indices, not just 2.
>From 3408853f7fa716bc97f837f9b7a2855aa3b78000 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Wed, 29 Jan 2025 14:37:38 -0800
Subject: [PATCH] [SandboxVec][Legality] Fix mask on diamond reuse with shuffle
This patch fixes a bug in the creation of shuffle masks when vectorizing vectors
in case of a diamond reuse with shuffle. The mask needs to enumerate all
elements of a vector, not treat the original vector value as a single element.
That is: if vectorizing two <2 x float> vectors into a <4 x float> the mask
needs to have 4 indices, not just 2.
---
.../Vectorize/SandboxVectorizer/InstrMaps.h | 5 +++-
.../Vectorize/SandboxVectorizer/Legality.cpp | 10 +++++--
.../SandboxVectorizer/Passes/BottomUpVec.cpp | 2 ++
.../SandboxVectorizer/bottomup_basic.ll | 22 +++++++++++++++
.../SandboxVectorizer/InstrMapsTest.cpp | 27 +++++++++++++++++++
5 files changed, 63 insertions(+), 3 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
index c931319d3b002..9bdf940fc77b7 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
@@ -18,6 +18,7 @@
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
#include <algorithm>
namespace llvm::sandboxir {
@@ -85,11 +86,13 @@ class InstrMaps {
/// Update the map to reflect that \p Origs got vectorized into \p Vec.
void registerVector(ArrayRef<Value *> Origs, Value *Vec) {
auto &OrigToLaneMap = VectorToOrigLaneMap[Vec];
- for (auto [Lane, Orig] : enumerate(Origs)) {
+ unsigned Lane = 0; // First lane of Vec occupied by the next Orig.
+ for (Value *Orig : Origs) {
auto Pair = OrigToVectorMap.try_emplace(Orig, Vec);
assert(Pair.second && "Orig already exists in the map!");
(void)Pair;
OrigToLaneMap[Orig] = Lane;
+ Lane += VecUtils::getNumLanes(Orig); // Advance by Orig's lane count, not by 1.
}
}
void clear() {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index c9329c24e1f4c..366243231379f 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -202,14 +202,20 @@ CollectDescr
LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
Vec.reserve(Bndl.size());
- for (auto [Lane, V] : enumerate(Bndl)) {
+ uint32_t LaneAccum = 0; // Zero-init: the `+=` below must not read garbage.
+ for (auto [Elm, V] : enumerate(Bndl)) {
+ uint32_t VLanes = VecUtils::getNumLanes(V);
if (auto *VecOp = IMaps.getVectorForOrig(V)) {
// If there is a vector containing `V`, then get the lane it came from.
std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
- Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt : -1);
+ // `V` may itself be a vector, like <2 x float>, in which case the mask
+ // needs one index per lane of `V`, starting at its first lane in `VecOp`.
+ for (int Ln = 0, NumLns = static_cast<int>(VLanes); Ln != NumLns; ++Ln)
+ Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1);
} else {
Vec.emplace_back(V);
}
+ LaneAccum += VLanes; // NOTE(review): running total is never read here.
}
return CollectDescr(std::move(Vec));
}
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 507d163240127..4fb029d3344b8 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -328,6 +328,8 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
const ShuffleMask &Mask =
cast<DiamondReuseWithShuffle>(LegalityRes).getMask();
NewVec = createShuffle(VecOp, Mask, UserBB);
+ assert(NewVec->getType() == VecOp->getType() &&
+ "Expected same type! Bad mask ?");
break;
}
case LegalityResultID::DiamondReuseMultiInput: {
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index 45b937dc1b1b6..301d6649669f4 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -243,6 +243,28 @@ define void @diamondWithShuffle(ptr %ptr) {
ret void
}
+; Same as @diamondWithShuffle, but with <2 x float> elements instead of scalars.
+define void @diamondWithShuffleFromVec(ptr %ptr) {
+; CHECK-LABEL: define void @diamondWithShuffleFromVec(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x float>, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[VECL:%.*]] = load <4 x float>, ptr [[PTR0]], align 8
+; CHECK-NEXT: [[VSHUF:%.*]] = shufflevector <4 x float> [[VECL]], <4 x float> [[VECL]], <4 x i32> <i32 2, i32 3, i32 0, i32 1>
+; CHECK-NEXT: [[VEC:%.*]] = fsub <4 x float> [[VECL]], [[VSHUF]]
+; CHECK-NEXT: store <4 x float> [[VEC]], ptr [[PTR0]], align 8
+; CHECK-NEXT: ret void
+;
+ %ptr0 = getelementptr <2 x float>, ptr %ptr, i32 0
+ %ptr1 = getelementptr <2 x float>, ptr %ptr, i32 1
+ %ld0 = load <2 x float>, ptr %ptr0
+ %ld1 = load <2 x float>, ptr %ptr1
+ %sub0 = fsub <2 x float> %ld0, %ld1
+ %sub1 = fsub <2 x float> %ld1, %ld0
+ store <2 x float> %sub0, ptr %ptr0
+ store <2 x float> %sub1, ptr %ptr1
+ ret void
+}
+
define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
; CHECK-LABEL: define void @diamondMultiInput(
; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
index 1d7c8f9cdde04..5b033f0edcb02 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
@@ -85,3 +85,30 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
}
+
+TEST_F(InstrMapsTest, VectorLanes) {
+ parseIR(C, R"IR(
+define void @foo(<2 x i8> %v0, <2 x i8> %v1, <4 x i8> %v2, <4 x i8> %v3) {
+ %vadd0 = add <2 x i8> %v0, %v1
+ %vadd1 = add <2 x i8> %v0, %v1
+ %vadd2 = add <4 x i8> %v2, %v3
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = &*M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ auto *BB = &*F->begin();
+ auto It = BB->begin();
+
+ auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
+ auto *VAdd1 = cast<sandboxir::BinaryOperator>(&*It++);
+ auto *VAdd2 = cast<sandboxir::BinaryOperator>(&*It++);
+
+ sandboxir::InstrMaps IMaps(Ctx);
+
+ // Lanes count vector elements: each <2 x i8> orig spans 2 lanes of VAdd2.
+ IMaps.registerVector({VAdd0, VAdd1}, VAdd2);
+ EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd0), 0U);
+ EXPECT_EQ(*IMaps.getOrigLane(VAdd2, VAdd1), 2U);
+}
More information about the llvm-commits
mailing list