[llvm] [SandboxVec][Legality] Query the scheduler for legality (PR #114616)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 5 14:48:02 PST 2024
https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/114616
From ac4d978d4d7d65cb8e8c541e322aef6ba3311cb7 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Wed, 30 Oct 2024 08:58:51 -0700
Subject: [PATCH] [SandboxVec][Legality] Query the scheduler for legality
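This patch adds a legality check that verifies whether the candidate
instructions can be scheduled together, by querying a Scheduler object
owned by LegalityAnalysis. If the scheduler rejects the bundle,
canVectorize() now returns a Pack result with the new
ResultReason::CantSchedule. Because the Scheduler is constructed from
alias-analysis results, AAResults is now passed through
sandboxir::Analyses to LegalityAnalysis.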
---
llvm/include/llvm/SandboxIR/Pass.h | 5 +-
.../Vectorize/SandboxVectorizer/Legality.h | 14 ++-
.../SandboxVectorizer/SandboxVectorizer.h | 2 +
.../Vectorize/SandboxVectorizer/Legality.cpp | 13 ++-
.../SandboxVectorizer/Passes/BottomUpVec.cpp | 4 +-
.../SandboxVectorizer/SandboxVectorizer.cpp | 3 +-
.../SandboxVectorizer/LegalityTest.cpp | 102 ++++++++++++++----
7 files changed, 114 insertions(+), 29 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/Pass.h b/llvm/include/llvm/SandboxIR/Pass.h
index fee6bd9e779fda..4f4eae87cd3ff7 100644
--- a/llvm/include/llvm/SandboxIR/Pass.h
+++ b/llvm/include/llvm/SandboxIR/Pass.h
@@ -14,6 +14,7 @@
namespace llvm {
+class AAResults;
class ScalarEvolution;
namespace sandboxir {
@@ -22,14 +23,16 @@ class Function;
class Region;
class Analyses {
+ AAResults *AA = nullptr;
ScalarEvolution *SE = nullptr;
Analyses() = default;
public:
- Analyses(ScalarEvolution &SE) : SE(&SE) {}
+ Analyses(AAResults &AA, ScalarEvolution &SE) : AA(&AA), SE(&SE) {}
public:
+ AAResults &getAA() const { return *AA; }
ScalarEvolution &getScalarEvolution() const { return *SE; }
/// For use by unit tests.
static Analyses emptyForTesting() { return Analyses(); }
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index f43e033e3cc7e3..58dcb2eeadbc2d 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -17,6 +17,7 @@
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
namespace llvm::sandboxir {
@@ -36,6 +37,7 @@ enum class ResultReason {
DiffMathFlags,
DiffWrapFlags,
NotConsecutive,
+ CantSchedule,
Unimplemented,
Infeasible,
};
@@ -66,6 +68,8 @@ struct ToStr {
return "DiffWrapFlags";
case ResultReason::NotConsecutive:
return "NotConsecutive";
+ case ResultReason::CantSchedule:
+ return "CantSchedule";
case ResultReason::Unimplemented:
return "Unimplemented";
case ResultReason::Infeasible:
@@ -146,6 +150,7 @@ class Pack final : public LegalityResultWithReason {
/// Performs the legality analysis and returns a LegalityResult object.
class LegalityAnalysis {
+ Scheduler Sched;
/// Owns the legality result objects created by createLegalityResult().
SmallVector<std::unique_ptr<LegalityResult>> ResultPool;
/// Checks opcodes, types and other IR-specifics and returns a ResultReason
@@ -157,8 +162,8 @@ class LegalityAnalysis {
const DataLayout &DL;
public:
- LegalityAnalysis(ScalarEvolution &SE, const DataLayout &DL)
- : SE(SE), DL(DL) {}
+ LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL)
+ : Sched(AA), SE(SE), DL(DL) {}
/// A LegalityResult factory.
template <typename ResultT, typename... ArgsT>
ResultT &createLegalityResult(ArgsT... Args) {
@@ -167,7 +172,10 @@ class LegalityAnalysis {
}
/// Checks if it's legal to vectorize the instructions in \p Bndl.
/// \Returns a LegalityResult object owned by LegalityAnalysis.
- const LegalityResult &canVectorize(ArrayRef<Value *> Bndl);
+ /// \p SkipScheduling skips the scheduler check and is only meant for testing.
+ // TODO: Try to remove the SkipScheduling argument by refactoring the tests.
+ const LegalityResult &canVectorize(ArrayRef<Value *> Bndl,
+ bool SkipScheduling = false);
};
} // namespace llvm::sandboxir
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h
index 03867df3d98084..46b953ff9b7f49 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h
@@ -10,6 +10,7 @@
#include <memory>
+#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/PassManager.h"
#include "llvm/SandboxIR/PassManager.h"
@@ -20,6 +21,7 @@ class TargetTransformInfo;
class SandboxVectorizerPass : public PassInfoMixin<SandboxVectorizerPass> {
TargetTransformInfo *TTI = nullptr;
+ AAResults *AA = nullptr;
ScalarEvolution *SE = nullptr;
// A pipeline of SandboxIR function passes run by the vectorizer.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index 1efd178778b9f6..8c6deeb7df249d 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -184,7 +184,8 @@ static void dumpBndl(ArrayRef<Value *> Bndl) {
}
#endif // NDEBUG
-const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl) {
+const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
+ bool SkipScheduling) {
// If Bndl contains values other than instructions, we need to Pack.
if (any_of(Bndl, [](auto *V) { return !isa<Instruction>(V); })) {
LLVM_DEBUG(dbgs() << "Not vectorizing: Not Instructions:\n";
@@ -197,7 +198,15 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl) {
// TODO: Check for existing vectors containing values in Bndl.
- // TODO: Check with scheduler.
+ if (!SkipScheduling) {
+ // TODO: Try to remove the IBndl vector.
+ SmallVector<Instruction *, 8> IBndl;
+ IBndl.reserve(Bndl.size());
+ for (auto *V : Bndl)
+ IBndl.push_back(cast<Instruction>(V));
+ if (!Sched.trySchedule(IBndl))
+ return createLegalityResult<Pack>(ResultReason::CantSchedule);
+ }
return createLegalityResult<Widen>();
}
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 339330c64f0caa..005d2241430ff1 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -61,8 +61,8 @@ void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { vectorizeRec(Bndl); }
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
- Legality = std::make_unique<LegalityAnalysis>(A.getScalarEvolution(),
- F.getParent()->getDataLayout());
+ Legality = std::make_unique<LegalityAnalysis>(
+ A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout());
Change = false;
// TODO: Start from innermost BBs first
for (auto &BB : F) {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp
index 96d825ed852fb2..790bee4a4d7f39 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp
@@ -51,6 +51,7 @@ SandboxVectorizerPass::~SandboxVectorizerPass() = default;
PreservedAnalyses SandboxVectorizerPass::run(Function &F,
FunctionAnalysisManager &AM) {
TTI = &AM.getResult<TargetIRAnalysis>(F);
+ AA = &AM.getResult<AAManager>(F);
SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
bool Changed = runImpl(F);
@@ -83,6 +84,6 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) {
// Create SandboxIR for LLVMF and run BottomUpVec on it.
sandboxir::Context Ctx(LLVMF.getContext());
sandboxir::Function &F = *Ctx.createFunction(&LLVMF);
- sandboxir::Analyses A(*SE);
+ sandboxir::Analyses A(*AA, *SE);
return FPM.runOnFunction(F, A);
}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 68557cb8b129f2..51e7a14013299b 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -8,6 +8,7 @@
#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
@@ -30,15 +31,20 @@ struct LegalityTest : public testing::Test {
std::unique_ptr<AssumptionCache> AC;
std::unique_ptr<LoopInfo> LI;
std::unique_ptr<ScalarEvolution> SE;
+ std::unique_ptr<BasicAAResult> BAA;
+ std::unique_ptr<AAResults> AA;
- ScalarEvolution &getSE(llvm::Function &LLVMF) {
+ void getAnalyses(llvm::Function &LLVMF) {
DT = std::make_unique<DominatorTree>(LLVMF);
TLII = std::make_unique<TargetLibraryInfoImpl>();
TLI = std::make_unique<TargetLibraryInfo>(*TLII);
AC = std::make_unique<AssumptionCache>(LLVMF);
LI = std::make_unique<LoopInfo>(*DT);
SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI);
- return *SE;
+ BAA = std::make_unique<BasicAAResult>(LLVMF.getParent()->getDataLayout(),
+ LLVMF, *TLI, *AC, DT.get());
+ AA = std::make_unique<AAResults>(*TLI);
+ AA->addAAResult(*BAA);
}
void parseIR(LLVMContext &C, const char *IR) {
@@ -49,7 +55,7 @@ struct LegalityTest : public testing::Test {
}
};
-TEST_F(LegalityTest, Legality) {
+TEST_F(LegalityTest, LegalitySkipSchedule) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
%gep0 = getelementptr float, ptr %ptr, i32 0
@@ -76,7 +82,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
- auto &SE = getSE(*LLVMF);
+ getAnalyses(*LLVMF);
const auto &DL = M->getDataLayout();
sandboxir::Context Ctx(C);
@@ -104,83 +110,139 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
- sandboxir::LegalityAnalysis Legality(SE, DL);
- const auto &Result = Legality.canVectorize({St0, St1});
+ sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
+ const auto &Result =
+ Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Widen>(Result));
{
// Check NotInstructions
- auto &Result = Legality.canVectorize({F, St0});
+ auto &Result = Legality.canVectorize({F, St0}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::NotInstructions);
}
{
// Check DiffOpcodes
- const auto &Result = Legality.canVectorize({St0, Ld0});
+ const auto &Result =
+ Legality.canVectorize({St0, Ld0}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffOpcodes);
}
{
// Check DiffTypes
- EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({St0, StVec2})));
- EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({StVec2, StVec3})));
+ EXPECT_TRUE(isa<sandboxir::Widen>(
+ Legality.canVectorize({St0, StVec2}, /*SkipScheduling=*/true)));
+ EXPECT_TRUE(isa<sandboxir::Widen>(
+ Legality.canVectorize({StVec2, StVec3}, /*SkipScheduling=*/true)));
- const auto &Result = Legality.canVectorize({St0, StI8});
+ const auto &Result =
+ Legality.canVectorize({St0, StI8}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffTypes);
}
{
// Check DiffMathFlags
- const auto &Result = Legality.canVectorize({FAdd0, FAdd1});
+ const auto &Result =
+ Legality.canVectorize({FAdd0, FAdd1}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffMathFlags);
}
{
// Check DiffWrapFlags
- const auto &Result = Legality.canVectorize({Trunc0, Trunc1});
+ const auto &Result =
+ Legality.canVectorize({Trunc0, Trunc1}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffWrapFlags);
}
{
// Check DiffTypes for unary operands that have a different type.
- const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8});
+ const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8},
+ /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffTypes);
}
{
// Check DiffOpcodes for CMPs with different predicates.
- const auto &Result = Legality.canVectorize({CmpSLT, CmpSGT});
+ const auto &Result =
+ Legality.canVectorize({CmpSLT, CmpSGT}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::DiffOpcodes);
}
{
// Check NotConsecutive Ld0,Ld0b
- const auto &Result = Legality.canVectorize({Ld0, Ld0b});
+ const auto &Result =
+ Legality.canVectorize({Ld0, Ld0b}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::NotConsecutive);
}
{
// Check NotConsecutive Ld0,Ld3
- const auto &Result = Legality.canVectorize({Ld0, Ld3});
+ const auto &Result =
+ Legality.canVectorize({Ld0, Ld3}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Pack>(Result));
EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
sandboxir::ResultReason::NotConsecutive);
}
{
// Check Widen Ld0,Ld1
- const auto &Result = Legality.canVectorize({Ld0, Ld1});
+ const auto &Result =
+ Legality.canVectorize({Ld0, Ld1}, /*SkipScheduling=*/true);
EXPECT_TRUE(isa<sandboxir::Widen>(Result));
}
}
+TEST_F(LegalityTest, LegalitySchedule) {
+ parseIR(C, R"IR(
+define void @foo(ptr %ptr) {
+ %gep0 = getelementptr float, ptr %ptr, i32 0
+ %gep1 = getelementptr float, ptr %ptr, i32 1
+ %ld0 = load float, ptr %gep0
+ store float %ld0, ptr %gep1
+ %ld1 = load float, ptr %gep1
+ store float %ld0, ptr %gep0
+ store float %ld1, ptr %gep1
+ ret void
+}
+)IR");
+ llvm::Function *LLVMF = &*M->getFunction("foo");
+ getAnalyses(*LLVMF);
+ const auto &DL = M->getDataLayout();
+
+ sandboxir::Context Ctx(C);
+ auto *F = Ctx.createFunction(LLVMF);
+ auto *BB = &*F->begin();
+ auto It = BB->begin();
+ [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
+ [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
+ auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
+ [[maybe_unused]] auto *ConflictingSt = cast<sandboxir::StoreInst>(&*It++);
+ auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
+ auto *St0 = cast<sandboxir::StoreInst>(&*It++);
+ auto *St1 = cast<sandboxir::StoreInst>(&*It++);
+
+ sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
+ {
+ // Can vectorize St0,St1.
+ const auto &Result = Legality.canVectorize({St0, St1});
+ EXPECT_TRUE(isa<sandboxir::Widen>(Result));
+ }
+ {
+ // Can't vectorize Ld0,Ld1 because of conflicting store.
+ auto &Result = Legality.canVectorize({Ld0, Ld1});
+ EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+ EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+ sandboxir::ResultReason::CantSchedule);
+ }
+}
+
#ifndef NDEBUG
TEST_F(LegalityTest, LegalityResultDump) {
parseIR(C, R"IR(
@@ -189,7 +251,7 @@ define void @foo() {
}
)IR");
llvm::Function *LLVMF = &*M->getFunction("foo");
- auto &SE = getSE(*LLVMF);
+ getAnalyses(*LLVMF);
const auto &DL = M->getDataLayout();
auto Matches = [](const sandboxir::LegalityResult &Result,
@@ -200,7 +262,7 @@ define void @foo() {
return Buff == ExpectedStr;
};
- sandboxir::LegalityAnalysis Legality(SE, DL);
+ sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
EXPECT_TRUE(
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(
More information about the llvm-commits mailing list