[llvm] [SandboxVec][BottomUpVec] Fix diamond shuffle with multiple vector inputs (PR #126965)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 12 12:33:35 PST 2025
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/126965
When the operand comes from multiple inputs, we need additional packing code. When the operands are scalars, a single InsertElementInst per operand is enough. But when the operands are vectors, we need a chain of ExtractElementInst and InsertElementInst instructions to insert each element of the vector value into the destination vector. This is what this patch implements.
From 69a5bf571754cbbe7433a1e5f132655632d7b982 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Wed, 29 Jan 2025 16:23:21 -0800
Subject: [PATCH] [SandboxVec][BottomUpVec] Fix diamond shuffle with multiple
vector inputs
When the operand comes from multiple inputs, we need additional packing
code. When the operands are scalars, a single InsertElementInst per operand
is enough. But when the operands are vectors, we need a chain of
ExtractElementInst and InsertElementInst instructions to insert each element
of the vector value into the destination vector. This is what this patch
implements.
---
.../Vectorize/SandboxVectorizer/Legality.cpp | 5 +--
.../SandboxVectorizer/Passes/BottomUpVec.cpp | 35 +++++++++++++++----
.../SandboxVectorizer/bottomup_basic.ll | 33 +++++++++++++++++
3 files changed, 63 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index 366243231379f..e8331933594da 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -202,20 +202,17 @@ CollectDescr
LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
Vec.reserve(Bndl.size());
- uint32_t LaneAccum;
for (auto [Elm, V] : enumerate(Bndl)) {
- uint32_t VLanes = VecUtils::getNumLanes(V);
if (auto *VecOp = IMaps.getVectorForOrig(V)) {
// If there is a vector containing `V`, then get the lane it came from.
std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
// This could be a vector, like <2 x float> in which case the mask needs
// to enumerate all lanes.
- for (int Ln = 0; Ln != VLanes; ++Ln)
+ for (uint32_t Ln = 0, Lanes = VecUtils::getNumLanes(V); Ln != Lanes; ++Ln)
Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1);
} else {
Vec.emplace_back(V);
}
- LaneAccum += VLanes;
}
return CollectDescr(std::move(Vec));
}
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 4fb029d3344b8..0ccef5aecd28b 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -335,7 +335,7 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
case LegalityResultID::DiamondReuseMultiInput: {
const auto &Descr =
cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
- Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());
+ Type *ResTy = VecUtils::getWideType(Bndl[0]->getType(), Bndl.size());
// TODO: Try to get WhereIt without creating a vector.
SmallVector<Value *, 4> DescrInstrs;
@@ -347,7 +347,8 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
getInsertPointAfterInstrs(DescrInstrs, UserBB);
Value *LastV = PoisonValue::get(ResTy);
- for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
+ unsigned Lane = 0;
+ for (const auto &ElmDescr : Descr.getDescrs()) {
Value *VecOp = ElmDescr.getValue();
Context &Ctx = VecOp->getContext();
Value *ValueToInsert;
@@ -359,10 +360,32 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
} else {
ValueToInsert = VecOp;
}
- ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
- Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC,
- WhereIt, Ctx, "VIns");
- LastV = Ins;
+ auto NumLanesToInsert = VecUtils::getNumLanes(ValueToInsert);
+ if (NumLanesToInsert == 1) {
+ // If we are inserting a scalar element then we need a single insert.
+ // %VIns = insert %DstVec, %SrcScalar, Lane
+ ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
+ LastV = InsertElementInst::create(LastV, ValueToInsert, LaneC, WhereIt,
+ Ctx, "VIns");
+ } else {
+ // If we are inserting a vector element then we need to extract and
+ // insert each vector element one by one with a chain of extracts and
+ // inserts, for example:
+ // %VExt0 = extract %SrcVec, 0
+ // %VIns0 = insert %DstVec, %VExt0, Lane + 0
+ // %VExt1 = extract %SrcVec, 1
+ // %VIns1 = insert %VIns0, %VExt1, Lane + 1
+ for (unsigned LnCnt = 0; LnCnt != NumLanesToInsert; ++LnCnt) {
+ auto *ExtrIdxC = ConstantInt::get(Type::getInt32Ty(Ctx), LnCnt);
+ auto *ExtrI = ExtractElementInst::create(ValueToInsert, ExtrIdxC,
+ WhereIt, Ctx, "VExt");
+ unsigned InsLane = Lane + LnCnt;
+ auto *InsLaneC = ConstantInt::get(Type::getInt32Ty(Ctx), InsLane);
+ LastV = InsertElementInst::create(LastV, ExtrI, InsLaneC, WhereIt,
+ Ctx, "VIns");
+ }
+ }
+ Lane += NumLanesToInsert;
}
NewVec = LastV;
break;
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index 301d6649669f4..6b18d4069e0ae 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -292,6 +292,39 @@ define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
ret void
}
+; Same but vectorizing <2 x float> vectors instead of scalars.
+define void @diamondMultiInputVector(ptr %ptr, ptr %ptrX) {
+; CHECK-LABEL: define void @diamondMultiInputVector(
+; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {
+; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x float>, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[LDX:%.*]] = load <2 x float>, ptr [[PTRX]], align 8
+; CHECK-NEXT: [[VECL:%.*]] = load <4 x float>, ptr [[PTR0]], align 8
+; CHECK-NEXT: [[VEXT:%.*]] = extractelement <2 x float> [[LDX]], i32 0
+; CHECK-NEXT: [[INSI:%.*]] = insertelement <4 x float> poison, float [[VEXT]], i32 0
+; CHECK-NEXT: [[VEXT1:%.*]] = extractelement <2 x float> [[LDX]], i32 1
+; CHECK-NEXT: [[INSI2:%.*]] = insertelement <4 x float> [[INSI]], float [[VEXT1]], i32 1
+; CHECK-NEXT: [[VEXT3:%.*]] = extractelement <4 x float> [[VECL]], i32 0
+; CHECK-NEXT: [[VINS4:%.*]] = insertelement <4 x float> [[INSI2]], float [[VEXT3]], i32 2
+; CHECK-NEXT: [[VEXT4:%.*]] = extractelement <4 x float> [[VECL]], i32 1
+; CHECK-NEXT: [[VINS5:%.*]] = insertelement <4 x float> [[VINS4]], float [[VEXT4]], i32 3
+; CHECK-NEXT: [[VEC:%.*]] = fsub <4 x float> [[VECL]], [[VINS5]]
+; CHECK-NEXT: store <4 x float> [[VEC]], ptr [[PTR0]], align 8
+; CHECK-NEXT: ret void
+;
+ %ptr0 = getelementptr <2 x float>, ptr %ptr, i32 0
+ %ptr1 = getelementptr <2 x float>, ptr %ptr, i32 1
+ %ld0 = load <2 x float>, ptr %ptr0
+ %ld1 = load <2 x float>, ptr %ptr1
+
+ %ldX = load <2 x float>, ptr %ptrX
+
+ %sub0 = fsub <2 x float> %ld0, %ldX
+ %sub1 = fsub <2 x float> %ld1, %ld0
+ store <2 x float> %sub0, ptr %ptr0
+ store <2 x float> %sub1, ptr %ptr1
+ ret void
+}
+
define void @diamondWithConstantVector(ptr %ptr) {
; CHECK-LABEL: define void @diamondWithConstantVector(
; CHECK-SAME: ptr [[PTR:%.*]]) {
More information about the llvm-commits
mailing list