[llvm] 2dc1c95 - [SandboxVec][VecUtils] Implement VecUtils::getLowest() (#124024)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 22 16:08:19 PST 2025
Author: vporpo
Date: 2025-01-22T16:08:15-08:00
New Revision: 2dc1c95595e409c74a8a3d743afb7898e1af3255
URL: https://github.com/llvm/llvm-project/commit/2dc1c95595e409c74a8a3d743afb7898e1af3255
DIFF: https://github.com/llvm/llvm-project/commit/2dc1c95595e409c74a8a3d743afb7898e1af3255.diff
LOG: [SandboxVec][VecUtils] Implement VecUtils::getLowest() (#124024)
VecUtils::getLowest(Vals) returns the lowest instruction in the BB among Vals.
If the instructions are not in the same BB, or if none of them is an
instruction, it returns nullptr.
Added:
Modified:
llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
index 6cbbb396ea823f..4e3ca2bccfe6fd 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -100,6 +100,8 @@ class VecUtils {
}
return FixedVectorType::get(ElemTy, NumElts);
}
+ /// \Returns the instruction in \p Instrs that is lowest in the BB. Expects
+ /// that \p Instrs is non-empty and all its instructions are in one BB.
static Instruction *getLowest(ArrayRef<Instruction *> Instrs) {
Instruction *LowestI = Instrs.front();
for (auto *I : drop_begin(Instrs)) {
@@ -108,6 +110,33 @@ class VecUtils {
}
return LowestI;
}
+ /// \Returns the lowest instruction in \p Vals (ignoring non-instructions),
+ /// or nullptr if there are none or if they are not all in the same BB.
+ static Instruction *getLowest(ArrayRef<Value *> Vals) {
+ // Find the first Instruction in Vals.
+ auto It = find_if(Vals, [](Value *V) { return isa<Instruction>(V); });
+ // If we couldn't find an instruction return nullptr.
+ if (It == Vals.end())
+ return nullptr;
+ Instruction *FirstI = cast<Instruction>(*It);
+ // Now look for the lowest instruction in Vals starting from one position
+ // after FirstI.
+ Instruction *LowestI = FirstI;
+ auto *LowestBB = LowestI->getParent();
+ for (auto *V : make_range(std::next(It), Vals.end())) {
+ auto *I = dyn_cast<Instruction>(V);
+ // Skip non-instructions.
+ if (I == nullptr)
+ continue;
+ // If the instructions are in different BBs return nullptr.
+ if (I->getParent() != LowestBB)
+ return nullptr;
+ // If `LowestI` comes before `I` then `I` is the new lowest.
+ if (LowestI->comesBefore(I))
+ LowestI = I;
+ }
+ return LowestI;
+ }
/// If all values in \p Bndl are of the same scalar type then return it,
/// otherwise return nullptr.
static Type *tryGetCommonScalarType(ArrayRef<Value *> Bndl) {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index c6ab3c1942c330..8432b4c6c469ae 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -45,11 +45,10 @@ static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
static BasicBlock::iterator
getInsertPointAfterInstrs(ArrayRef<Value *> Instrs) {
- // TODO: Use the VecUtils function for getting the bottom instr once it lands.
- auto *BotI = cast<Instruction>(
- *std::max_element(Instrs.begin(), Instrs.end(), [](auto *V1, auto *V2) {
- return cast<Instruction>(V1)->comesBefore(cast<Instruction>(V2));
- }));
- // If Bndl contains Arguments or Constants, use the beginning of the BB.
+ // getLowest() returns nullptr if Instrs contains no instruction or if its
+ // instructions span several BBs. There is no BB to fall back to here, so a
+ // valid bottom instruction is required before we dereference it.
+ auto *BotI = VecUtils::getLowest(Instrs);
+ assert(BotI && "Expected at least one instruction, all in one BB!");
return std::next(BotI->getIterator());
}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
index 8661dcd5067c0a..b69172738d36a5 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
@@ -50,6 +50,15 @@ struct VecUtilsTest : public testing::Test {
}
};
+/// \Returns the basic block of \p F named \p Name, which must exist.
+static sandboxir::BasicBlock &getBasicBlockByName(sandboxir::Function &F,
+ StringRef Name) {
+ for (sandboxir::BasicBlock &BB : F)
+ if (BB.getName() == Name)
+ return BB;
+ llvm_unreachable("Expected to find basic block!");
+}
+
TEST_F(VecUtilsTest, GetNumElements) {
sandboxir::Context Ctx(C);
auto *ElemTy = sandboxir::Type::getInt32Ty(Ctx);
@@ -415,9 +423,11 @@ TEST_F(VecUtilsTest, GetLowest) {
parseIR(R"IR(
define void @foo(i8 %v) {
bb0:
- %A = add i8 %v, %v
- %B = add i8 %v, %v
- %C = add i8 %v, %v
+ br label %bb1
+bb1:
+ %A = add i8 %v, 1
+ %B = add i8 %v, 2
+ %C = add i8 %v, 3
ret void
}
)IR");
@@ -425,11 +435,21 @@ define void @foo(i8 %v) {
sandboxir::Context Ctx(C);
auto &F = *Ctx.createFunction(&LLVMF);
- auto &BB = *F.begin();
- auto It = BB.begin();
- auto *IA = &*It++;
- auto *IB = &*It++;
- auto *IC = &*It++;
+ auto &BB0 = getBasicBlockByName(F, "bb0");
+ auto It = BB0.begin();
+ auto *BB0I = cast<sandboxir::BranchInst>(&*It++);
+
+ auto &BB = getBasicBlockByName(F, "bb1");
+ It = BB.begin();
+ auto *IA = cast<sandboxir::Instruction>(&*It++);
+ auto *C1 = cast<sandboxir::Constant>(IA->getOperand(1));
+ auto *IB = cast<sandboxir::Instruction>(&*It++);
+ auto *C2 = cast<sandboxir::Constant>(IB->getOperand(1));
+ auto *IC = cast<sandboxir::Instruction>(&*It++);
+ auto *C3 = cast<sandboxir::Constant>(IC->getOperand(1));
+ // Check getLowest(ArrayRef<Instruction *>)
+ SmallVector<sandboxir::Instruction *> A({IA});
+ EXPECT_EQ(sandboxir::VecUtils::getLowest(A), IA);
SmallVector<sandboxir::Instruction *> ABC({IA, IB, IC});
EXPECT_EQ(sandboxir::VecUtils::getLowest(ABC), IC);
SmallVector<sandboxir::Instruction *> ACB({IA, IC, IB});
@@ -438,6 +458,27 @@ define void @foo(i8 %v) {
EXPECT_EQ(sandboxir::VecUtils::getLowest(CAB), IC);
SmallVector<sandboxir::Instruction *> CBA({IC, IB, IA});
EXPECT_EQ(sandboxir::VecUtils::getLowest(CBA), IC);
+
+ // Check getLowest(ArrayRef<Value *>)
+ SmallVector<sandboxir::Value *> C1Only({C1});
+ EXPECT_EQ(sandboxir::VecUtils::getLowest(C1Only), nullptr);
+ SmallVector<sandboxir::Value *> AOnly({IA});
+ EXPECT_EQ(sandboxir::VecUtils::getLowest(AOnly), IA);
+ SmallVector<sandboxir::Value *> AC1({IA, C1});
+ EXPECT_EQ(sandboxir::VecUtils::getLowest(AC1), IA);
+ SmallVector<sandboxir::Value *> C1A({C1, IA});
+ EXPECT_EQ(sandboxir::VecUtils::getLowest(C1A), IA);
+ SmallVector<sandboxir::Value *> AC1B({IA, C1, IB});
+ EXPECT_EQ(sandboxir::VecUtils::getLowest(AC1B), IB);
+ SmallVector<sandboxir::Value *> ABC1({IA, IB, C1});
+ EXPECT_EQ(sandboxir::VecUtils::getLowest(ABC1), IB);
+ SmallVector<sandboxir::Value *> AC1C2({IA, C1, C2});
+ EXPECT_EQ(sandboxir::VecUtils::getLowest(AC1C2), IA);
+ SmallVector<sandboxir::Value *> C1C2C3({C1, C2, C3});
+ EXPECT_EQ(sandboxir::VecUtils::getLowest(C1C2C3), nullptr);
+
+ SmallVector<sandboxir::Value *> DiffBBs({BB0I, IA});
+ EXPECT_EQ(sandboxir::VecUtils::getLowest(DiffBBs), nullptr);
}
TEST_F(VecUtilsTest, GetCommonScalarType) {
More information about the llvm-commits
mailing list