[llvm] [SandboxVec][BottomUpVec] Implement pack of scalars (PR #115549)

via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 8 14:01:19 PST 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/115549

This patch implements packing of scalar operands when the vectorizer decides to stop vectorizing. Packing is implemented with a sequence of InsertElement instructions.

Packing vectors requires different instructions, so it will be implemented in a follow-up patch.

From 56e533bbad9de2b14ee63ad3c1698d02f6fb1b29 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Tue, 5 Nov 2024 10:17:59 -0800
Subject: [PATCH] [SandboxVec][BottomUpVec] Implement pack of scalars

This patch implements packing of scalar operands when the vectorizer
decides to stop vectorizing. Packing is implemented with a sequence of
InsertElement instructions.

Packing vectors requires different instructions so it's implemented
in a follow-up patch.
---
 .../SandboxVectorizer/Passes/BottomUpVec.h    |  3 +-
 .../Vectorize/SandboxVectorizer/VecUtils.h    | 25 ++++++++++
 .../SandboxVectorizer/Passes/BottomUpVec.cpp  | 47 ++++++++++++++++---
 .../SandboxVectorizer/bottomup_basic.ll       | 44 +++++++++++++++++
 .../SandboxVectorizer/VecUtilsTest.cpp        | 33 +++++++++++++
 5 files changed, 145 insertions(+), 7 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 02cd7650ad8a5a..6109db71611018 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -31,7 +31,8 @@ class BottomUpVec final : public FunctionPass {
   /// \p Bndl. \p Operands are the already vectorized operands.
   Value *createVectorInstr(ArrayRef<Value *> Bndl, ArrayRef<Value *> Operands);
   void tryEraseDeadInstrs();
-  Value *vectorizeRec(ArrayRef<Value *> Bndl);
+  Value *createPack(ArrayRef<Value *> ToPack);
+  Value *vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth);
   bool tryVectorize(ArrayRef<Value *> Seeds);
 
   // The PM containing the pipeline of region passes.
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
index d44c845bfbf4e9..fc9d67fcfcdec4 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -108,6 +108,31 @@ class VecUtils {
     }
     return LowestI;
   }
+  /// If all values in \p Bndl are of the same scalar type then return it,
+  /// otherwise return nullptr.
+  static Type *tryGetCommonScalarType(ArrayRef<Value *> Bndl) {
+    Value *V0 = Bndl[0];
+    Type *Ty0 = Utils::getExpectedType(V0);
+    Type *ScalarTy = VecUtils::getElementType(Ty0);
+    for (auto *V : drop_begin(Bndl)) {
+      Type *NTy = Utils::getExpectedType(V);
+      Type *NScalarTy = VecUtils::getElementType(NTy);
+      if (NScalarTy != ScalarTy)
+        return nullptr;
+    }
+    return ScalarTy;
+  }
+
+  /// Similar to tryGetCommonScalarType() but will assert that there is a common
+  /// type. So this is faster in release builds as it won't iterate through the
+  /// values.
+  static Type *getCommonScalarType(ArrayRef<Value *> Bndl) {
+    Value *V0 = Bndl[0];
+    Type *Ty0 = Utils::getExpectedType(V0);
+    Type *ScalarTy = VecUtils::getElementType(Ty0);
+    assert(tryGetCommonScalarType(Bndl) && "Expected common scalar type!");
+    return ScalarTy;
+  }
 };
 
 } // namespace llvm::sandboxir
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 3617d369776418..4b82fea2aa2b22 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -164,7 +164,39 @@ void BottomUpVec::tryEraseDeadInstrs() {
   DeadInstrCandidates.clear();
 }
 
-Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
+Value *BottomUpVec::createPack(ArrayRef<Value *> ToPack) {
+  BasicBlock::iterator WhereIt = getInsertPointAfterInstrs(ToPack);
+
+  Type *ScalarTy = VecUtils::getCommonScalarType(ToPack);
+  unsigned Lanes = VecUtils::getNumLanes(ToPack);
+  Type *VecTy = VecUtils::getWideType(ScalarTy, Lanes);
+
+  // Create a series of pack instructions.
+  Value *LastInsert = PoisonValue::get(VecTy);
+
+  Context &Ctx = ToPack[0]->getContext();
+
+  unsigned InsertIdx = 0;
+  for (Value *Elm : ToPack) {
+    // An element can be either scalar or vector. We need to generate different
+    // IR for each case.
+    if (Elm->getType()->isVectorTy()) {
+      llvm_unreachable("Unimplemented");
+    } else {
+      Constant *InsertLaneC =
+          ConstantInt::getSigned(Type::getInt32Ty(Ctx), InsertIdx++);
+      // This may be folded into a Constant if LastInsert is a Constant. In that
+      // case we only collect the last constant.
+      LastInsert = InsertElementInst::create(LastInsert, Elm, InsertLaneC,
+                                             WhereIt, Ctx, "Pack");
+      if (auto *NewI = dyn_cast<Instruction>(LastInsert))
+        WhereIt = std::next(NewI->getIterator());
+    }
+  }
+  return LastInsert;
+}
+
+Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {
   Value *NewVec = nullptr;
   const auto &LegalityRes = Legality->canVectorize(Bndl);
   switch (LegalityRes.getSubclassID()) {
@@ -178,7 +210,7 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
       break;
     case Instruction::Opcode::Store: {
       // Don't recurse towards the pointer operand.
-      auto *VecOp = vectorizeRec(getOperand(Bndl, 0));
+      auto *VecOp = vectorizeRec(getOperand(Bndl, 0), Depth + 1);
       VecOperands.push_back(VecOp);
       VecOperands.push_back(cast<StoreInst>(I)->getPointerOperand());
       break;
@@ -186,7 +218,7 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
     default:
       // Visit all operands.
       for (auto OpIdx : seq<unsigned>(I->getNumOperands())) {
-        auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx));
+        auto *VecOp = vectorizeRec(getOperand(Bndl, OpIdx), Depth + 1);
         VecOperands.push_back(VecOp);
       }
       break;
@@ -201,8 +233,11 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
     break;
   }
   case LegalityResultID::Pack: {
-    // TODO: Unimplemented
-    llvm_unreachable("Unimplemented");
+    // If we can't vectorize the seeds then just return.
+    if (Depth == 0)
+      return nullptr;
+    NewVec = createPack(Bndl);
+    break;
   }
   }
   return NewVec;
@@ -210,7 +245,7 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
 
 bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
   DeadInstrCandidates.clear();
-  vectorizeRec(Bndl);
+  vectorizeRec(Bndl, /*Depth=*/0);
   tryEraseDeadInstrs();
   return Change;
 }
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index 49aeea9f8a8491..dff27d06d8ed22 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -143,3 +143,47 @@ define float @scalars_with_external_uses_not_dead(ptr %ptr) {
   ret float %ld0
 }
 
+define void @pack_scalars(ptr %ptr, ptr %ptr2) {
+; CHECK-LABEL: define void @pack_scalars(
+; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTR2:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
+; CHECK-NEXT:    [[LD0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LD1:%.*]] = load float, ptr [[PTR2]], align 4
+; CHECK-NEXT:    [[PACK:%.*]] = insertelement <2 x float> poison, float [[LD0]], i32 0
+; CHECK-NEXT:    [[PACK1:%.*]] = insertelement <2 x float> [[PACK]], float [[LD1]], i32 1
+; CHECK-NEXT:    store <2 x float> [[PACK1]], ptr [[PTR0]], align 4
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ld0 = load float, ptr %ptr0
+  %ld1 = load float, ptr %ptr2
+  store float %ld0, ptr %ptr0
+  store float %ld1, ptr %ptr1
+  ret void
+}
+
+declare void @foo()
+define void @cant_vectorize_seeds(ptr %ptr) {
+; CHECK-LABEL: define void @cant_vectorize_seeds(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
+; CHECK-NEXT:    [[LD0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LD1:%.*]] = load float, ptr [[PTR1]], align 4
+; CHECK-NEXT:    store float [[LD1]], ptr [[PTR1]], align 4
+; CHECK-NEXT:    call void @foo()
+; CHECK-NEXT:    store float [[LD1]], ptr [[PTR1]], align 4
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ld0 = load float, ptr %ptr0
+  %ld1 = load float, ptr %ptr1
+  store float %ld1, ptr %ptr1
+  call void @foo() ; This call blocks scheduling of the store seeds.
+  store float %ld1, ptr %ptr1
+  ret void
+}
+
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
index 835b9285c9d9ff..cf7b6cbc7e55cb 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
@@ -439,3 +439,36 @@ define void @foo(i8 %v) {
   SmallVector<sandboxir::Instruction *> CBA({IC, IB, IA});
   EXPECT_EQ(sandboxir::VecUtils::getLowest(CBA), IC);
 }
+
+TEST_F(VecUtilsTest, GetCommonScalarType) {
+  parseIR(R"IR(
+define void @foo(i8 %v, ptr %ptr) {
+bb0:
+  %add0 = add i8 %v, %v
+  store i8 %v, ptr %ptr
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  auto It = BB.begin();
+  auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *Store = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+  {
+    SmallVector<sandboxir::Value *> Vec = {Add0, Store};
+    EXPECT_EQ(sandboxir::VecUtils::tryGetCommonScalarType(Vec),
+              Add0->getType());
+    EXPECT_EQ(sandboxir::VecUtils::getCommonScalarType(Vec), Add0->getType());
+  }
+  {
+    SmallVector<sandboxir::Value *> Vec = {Add0, Ret};
+    EXPECT_EQ(sandboxir::VecUtils::tryGetCommonScalarType(Vec), nullptr);
+#ifndef NDEBUG
+    EXPECT_DEATH(sandboxir::VecUtils::getCommonScalarType(Vec), ".*common.*");
+#endif // NDEBUG
+  }
+}



More information about the llvm-commits mailing list