[llvm] [SandboxVec][BottomUpVec] Implement pack of scalars (PR #115549)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 8 14:01:19 PST 2024
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/115549
This patch implements packing of scalar operands when the vectorizer decides to stop vectorizing. Packing is implemented with a sequence of InsertElement instructions.
Packing vectors requires different instructions so it's implemented in a follow-up patch.
>From 56e533bbad9de2b14ee63ad3c1698d02f6fb1b29 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Tue, 5 Nov 2024 10:17:59 -0800
Subject: [PATCH] [SandboxVec][BottomUpVec] Implement pack of scalars
This patch implements packing of scalar operands when the vectorizer
decides to stop vectorizing. Packing is implemented with a sequence of
InsertElement instructions.
Packing vectors requires different instructions so it's implemented
in a follow-up patch.
---
.../SandboxVectorizer/Passes/BottomUpVec.h | 3 +-
.../Vectorize/SandboxVectorizer/VecUtils.h | 25 ++++++++++
.../SandboxVectorizer/Passes/BottomUpVec.cpp | 47 ++++++++++++++++---
.../SandboxVectorizer/bottomup_basic.ll | 44 +++++++++++++++++
.../SandboxVectorizer/VecUtilsTest.cpp | 33 +++++++++++++
5 files changed, 145 insertions(+), 7 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 02cd7650ad8a5a..6109db71611018 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -31,7 +31,8 @@ class BottomUpVec final : public FunctionPass {
/// \p Bndl. \p Operands are the already vectorized operands.
Value *createVectorInstr(ArrayRef<Value *> Bndl, ArrayRef<Value *> Operands);
void tryEraseDeadInstrs();
- Value *vectorizeRec(ArrayRef<Value *> Bndl);
+ Value *createPack(ArrayRef<Value *> ToPack);
+ Value *vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth);
bool tryVectorize(ArrayRef<Value *> Seeds);
// The PM containing the pipeline of region passes.
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
index d44c845bfbf4e9..fc9d67fcfcdec4 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -108,6 +108,31 @@ class VecUtils {
}
return LowestI;
}
+ /// If every value in \p Bndl has the same element type (taken from its
+ /// expected type) return that type, else nullptr. \p Bndl must be non-empty.
+ static Type *tryGetCommonScalarType(ArrayRef<Value *> Bndl) {
+ Value *V0 = Bndl[0]; // Read unconditionally: needs at least one value.
+ Type *Ty0 = Utils::getExpectedType(V0);
+ Type *ScalarTy = VecUtils::getElementType(Ty0);
+ for (auto *V : drop_begin(Bndl)) { // V0 is the reference, so skip it.
+ Type *NTy = Utils::getExpectedType(V);
+ Type *NScalarTy = VecUtils::getElementType(NTy);
+ if (NScalarTy != ScalarTy)
+ return nullptr;
+ }
+ return ScalarTy;
+ }
+
+ /// Like tryGetCommonScalarType(), but only asserts that a common type exists
+ /// and returns the element type of \p Bndl[0]'s expected type. When the
+ /// assert is compiled out the remaining values are never inspected.
+ static Type *getCommonScalarType(ArrayRef<Value *> Bndl) {
+ Value *V0 = Bndl[0]; // Read unconditionally: needs at least one value.
+ Type *Ty0 = Utils::getExpectedType(V0);
+ Type *ScalarTy = VecUtils::getElementType(Ty0);
+ assert(tryGetCommonScalarType(Bndl) && "Expected common scalar type!");
+ return ScalarTy;
+ }
};
} // namespace llvm::sandboxir
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 3617d369776418..4b82fea2aa2b22 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -164,7 +164,39 @@ void BottomUpVec::tryEraseDeadInstrs() {
DeadInstrCandidates.clear();
}
-Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
+Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
+ BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack);
+
+ Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
+ unsigned Lanes = VecUtils::getNumLanes(ToPack);
+ Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);
+
+ // Build the vector lane by lane with insertelement, starting from poison.
+ Value *LastInsert = PoisonValue::get(VecTy);
+
+ Context &Ctx = ToPack[0]->getContext();
+
+ unsigned InsertIdx = 0; // Lane that the next scalar element is written to.
+ for (Value *Elm : ToPack) {
+ // An element can be either scalar or vector, and each needs different IR.
+ // Only scalars are handled here; a vector element hits the unreachable.
+ if (Elm->getType()->isVectorTy()) {
+ llvm_unreachable("Unimplemented");
+ } else {
+ Constant *InsertLaneC =
+ ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
+ // This may be folded into a Constant if LastInsert is a Constant, in which
+ // case there is no new instruction and the insert point stays where it is.
+ LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
+ WhereIt, Ctx, "Pack");
+ if (auto *NewI = dyn_cast<Instruction>(LastInsert))
+ WhereIt = std::next(NewI->getIterator()); // Keep the inserts in order.
+ }
+ }
+ return LastInsert;
+}
+
+Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
Value *NewVec = nullptr;
const auto &LegalityRes = Legality->canVectorize(Bndl);
switch (LegalityRes.getSubclassID()) {
@@ -178,7 +210,7 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
break;
case Instruction::Opcode::Store: {
// Don't recurse towards the pointer operand.
- auto *VecOp = vectorizeRec(getOperand(Bndl, 0));
+ auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Depth + 1);
VecOperands.push_back(VecOp);
VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
break;
@@ -186,7 +218,7 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
default:
// Visit all operands.
for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
- auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx));
+ auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Depth + 1);
VecOperands.push_back(VecOp);
}
break;
@@ -201,8 +233,11 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
break;
}
case LegalityResultID::Pack: {
- // TODO: Unimplemented
- llvm_unreachable("Unimplemented");
+ // If we can't vectorize the seeds then just return.
+ if (Depth == 0)
+ return nullptr;
+ NewVec = createPack(Bndl);
+ break;
}
}
return NewVec;
@@ -210,7 +245,7 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
DeadInstrCandidates.clear();
- vectorizeRec(Bndl);
+ vectorizeRec(Bndl, /*Depth=*/0);
tryEraseDeadInstrs();
return Change;
}
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index 49aeea9f8a8491..dff27d06d8ed22 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -143,3 +143,47 @@ define float @scalars_with_external_uses_not_dead(ptr %ptr) {
ret float %ld0
}
+define void @pack_scalars(ptr %ptr, ptr %ptr2) {
+; CHECK-LABEL: define void @pack_scalars(
+; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTR2:%.*]]) {
+; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
+; CHECK-NEXT: [[LD0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT: [[LD1:%.*]] = load float, ptr [[PTR2]], align 4
+; CHECK-NEXT: [[PACK:%.*]] = insertelement <2 x float> poison, float [[LD0]], i32 0
+; CHECK-NEXT: [[PACK1:%.*]] = insertelement <2 x float> [[PACK]], float [[LD1]], i32 1
+; CHECK-NEXT: store <2 x float> [[PACK1]], ptr [[PTR0]], align 4
+; CHECK-NEXT: ret void
+;
+ %ptr0 = getelementptr float, ptr %ptr, i32 0
+ %ptr1 = getelementptr float, ptr %ptr, i32 1
+ %ld0 = load float, ptr %ptr0
+ %ld1 = load float, ptr %ptr2
+ store float %ld0, ptr %ptr0
+ store float %ld1, ptr %ptr1
+ ret void
+}
+
+declare void @foo()
+define void @cant_vectorize_seeds(ptr %ptr) {
+; CHECK-LABEL: define void @cant_vectorize_seeds(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
+; CHECK-NEXT: [[LD0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT: [[LD1:%.*]] = load float, ptr [[PTR1]], align 4
+; CHECK-NEXT: store float [[LD1]], ptr [[PTR1]], align 4
+; CHECK-NEXT: call void @foo()
+; CHECK-NEXT: store float [[LD1]], ptr [[PTR1]], align 4
+; CHECK-NEXT: ret void
+;
+ %ptr0 = getelementptr float, ptr %ptr, i32 0
+ %ptr1 = getelementptr float, ptr %ptr, i32 1
+ %ld0 = load float, ptr %ptr0
+ %ld1 = load float, ptr %ptr1
+ store float %ld1, ptr %ptr1
+ call void @foo() ; This call blocks scheduling of the store seeds.
+ store float %ld1, ptr %ptr1
+ ret void
+}
+
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
index 835b9285c9d9ff..cf7b6cbc7e55cb 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
@@ -439,3 +439,36 @@ define void @foo(i8 %v) {
SmallVector<sandboxir::Instruction *> CBA({IC, IB, IA});
EXPECT_EQ(sandboxir::VecUtils::getLowest(CBA), IC);
}
+
+TEST_F(VecUtilsTest, GetCommonScalarType) {
+ parseIR(R"IR(
+define void @foo(i8 %v, ptr %ptr) {
+bb0:
+ %add0 = add i8 %v, %v
+ store i8 %v, ptr %ptr
+ ret void
+}
+)IR");
+ Function &LLVMF = *M->getFunction("foo");
+
+ sandboxir::Context Ctx(C);
+ auto &F = *Ctx.createFunction(&LLVMF);
+ auto &BB = *F.begin();
+ auto It = BB.begin();
+ auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
+ auto *Store = cast<sandboxir::StoreInst>(&*It++);
+ auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+ {
+ SmallVector<sandboxir::Value *> Vec = {Add0, Store}; // i8 add, i8 store.
+ EXPECT_EQ(sandboxir::VecUtils::tryGetCommonScalarType(Vec),
+ Add0->getType());
+ EXPECT_EQ(sandboxir::VecUtils::getCommonScalarType(Vec), Add0->getType());
+ }
+ {
+ SmallVector<sandboxir::Value *> Vec = {Add0, Ret}; // i8 add, `ret void`.
+ EXPECT_EQ(sandboxir::VecUtils::tryGetCommonScalarType(Vec), nullptr);
+#ifndef NDEBUG
+ EXPECT_DEATH(sandboxir::VecUtils::getCommonScalarType(Vec), ".*common.*");
+#endif // NDEBUG
+ }
+}
More information about the llvm-commits
mailing list