[llvm] [SandboxVec][Legality] Query the scheduler for legality (PR #114616)

via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 5 14:48:02 PST 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/114616

From ac4d978d4d7d65cb8e8c541e322aef6ba3311cb7 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Wed, 30 Oct 2024 08:58:51 -0700
Subject: [PATCH] [SandboxVec][Legality] Query the scheduler for legality

This patch adds a legality check that uses a Scheduler object to test
whether the candidate instructions can be scheduled together. If they
cannot, canVectorize() returns Pack with ResultReason::CantSchedule.
---
 llvm/include/llvm/SandboxIR/Pass.h            |   5 +-
 .../Vectorize/SandboxVectorizer/Legality.h    |  14 ++-
 .../SandboxVectorizer/SandboxVectorizer.h     |   2 +
 .../Vectorize/SandboxVectorizer/Legality.cpp  |  13 ++-
 .../SandboxVectorizer/Passes/BottomUpVec.cpp  |   4 +-
 .../SandboxVectorizer/SandboxVectorizer.cpp   |   3 +-
 .../SandboxVectorizer/LegalityTest.cpp        | 102 ++++++++++++++----
 7 files changed, 114 insertions(+), 29 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Pass.h b/llvm/include/llvm/SandboxIR/Pass.h
index fee6bd9e779fda..4f4eae87cd3ff7 100644
--- a/llvm/include/llvm/SandboxIR/Pass.h
+++ b/llvm/include/llvm/SandboxIR/Pass.h
@@ -14,6 +14,7 @@
 
 namespace llvm {
 
+class AAResults;
 class ScalarEvolution;
 
 namespace sandboxir {
@@ -22,14 +23,16 @@ class Function;
 class Region;
 
 class Analyses {
+  AAResults *AA = nullptr;
   ScalarEvolution *SE = nullptr;
 
   Analyses() = default;
 
 public:
-  Analyses(ScalarEvolution &SE) : SE(&SE) {}
+  Analyses(AAResults &AA, ScalarEvolution &SE) : AA(&AA), SE(&SE) {}
 
 public:
+  AAResults &getAA() const { return *AA; }
   ScalarEvolution &getScalarEvolution() const { return *SE; }
   /// For use by unit tests.
   static Analyses emptyForTesting() { return Analyses(); }
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index f43e033e3cc7e3..58dcb2eeadbc2d 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -17,6 +17,7 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
 
 namespace llvm::sandboxir {
 
@@ -36,6 +37,7 @@ enum class ResultReason {
   DiffMathFlags,
   DiffWrapFlags,
   NotConsecutive,
+  CantSchedule,
   Unimplemented,
   Infeasible,
 };
@@ -66,6 +68,8 @@ struct ToStr {
       return "DiffWrapFlags";
     case ResultReason::NotConsecutive:
       return "NotConsecutive";
+    case ResultReason::CantSchedule:
+      return "CantSchedule";
     case ResultReason::Unimplemented:
       return "Unimplemented";
     case ResultReason::Infeasible:
@@ -146,6 +150,7 @@ class Pack final : public LegalityResultWithReason {
 
 /// Performs the legality analysis and returns a LegalityResult object.
 class LegalityAnalysis {
+  Scheduler Sched;
   /// Owns the legality result objects created by createLegalityResult().
   SmallVector<std::unique_ptr<LegalityResult>> ResultPool;
   /// Checks opcodes, types and other IR-specifics and returns a ResultReason
@@ -157,8 +162,8 @@ class LegalityAnalysis {
   const DataLayout &DL;
 
 public:
-  LegalityAnalysis(ScalarEvolution &SE, const DataLayout &DL)
-      : SE(SE), DL(DL) {}
+  LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL)
+      : Sched(AA), SE(SE), DL(DL) {}
   /// A LegalityResult factory.
   template <typename ResultT, typename... ArgsT>
   ResultT &createLegalityResult(ArgsT... Args) {
@@ -167,7 +172,10 @@ class LegalityAnalysis {
   }
   /// Checks if it's legal to vectorize the instructions in \p Bndl.
   /// \Returns a LegalityResult object owned by LegalityAnalysis.
-  const LegalityResult &canVectorize(ArrayRef<Value *> Bndl);
+  /// \p SkipScheduling skips the scheduler check and is only meant for testing.
+  // TODO: Try to remove the SkipScheduling argument by refactoring the tests.
+  const LegalityResult &canVectorize(ArrayRef<Value *> Bndl,
+                                     bool SkipScheduling = false);
 };
 
 } // namespace llvm::sandboxir
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h
index 03867df3d98084..46b953ff9b7f49 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h
@@ -10,6 +10,7 @@
 
 #include <memory>
 
+#include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/SandboxIR/PassManager.h"
@@ -20,6 +21,7 @@ class TargetTransformInfo;
 
 class SandboxVectorizerPass : public PassInfoMixin<SandboxVectorizerPass> {
   TargetTransformInfo *TTI = nullptr;
+  AAResults *AA = nullptr;
   ScalarEvolution *SE = nullptr;
 
   // A pipeline of SandboxIR function passes run by the vectorizer.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
index 1efd178778b9f6..8c6deeb7df249d 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp
@@ -184,7 +184,8 @@ static void dumpBndl(ArrayRef<Value *> Bndl) {
 }
 #endif // NDEBUG
 
-const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl) {
+const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
+                                                     bool SkipScheduling) {
   // If Bndl contains values other than instructions, we need to Pack.
   if (any_of(Bndl, [](auto *V) { return !isa<Instruction>(V); })) {
     LLVM_DEBUG(dbgs() << "Not vectorizing: Not Instructions:\n";
@@ -197,7 +198,15 @@ const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl) {
 
   // TODO: Check for existing vectors containing values in Bndl.
 
-  // TODO: Check with scheduler.
+  if (!SkipScheduling) {
+    // TODO: Try to remove the IBndl vector.
+    SmallVector<Instruction *, 8> IBndl;
+    IBndl.reserve(Bndl.size());
+    for (auto *V : Bndl)
+      IBndl.push_back(cast<Instruction>(V));
+    if (!Sched.trySchedule(IBndl))
+      return createLegalityResult<Pack>(ResultReason::CantSchedule);
+  }
 
   return createLegalityResult<Widen>();
 }
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 339330c64f0caa..005d2241430ff1 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -61,8 +61,8 @@ void BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
 void BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) { vectorizeRec(Bndl); }
 
 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
-  Legality = std::make_unique<LegalityAnalysis>(A.getScalarEvolution(),
-                                                F.getParent()->getDataLayout());
+  Legality = std::make_unique<LegalityAnalysis>(
+      A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout());
   Change = false;
   // TODO: Start from innermost BBs first
   for (auto &BB : F) {
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp
index 96d825ed852fb2..790bee4a4d7f39 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp
@@ -51,6 +51,7 @@ SandboxVectorizerPass::~SandboxVectorizerPass() = default;
 PreservedAnalyses SandboxVectorizerPass::run(Function &F,
                                              FunctionAnalysisManager &AM) {
   TTI = &AM.getResult<TargetIRAnalysis>(F);
+  AA = &AM.getResult<AAManager>(F);
   SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
 
   bool Changed = runImpl(F);
@@ -83,6 +84,6 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) {
   // Create SandboxIR for LLVMF and run BottomUpVec on it.
   sandboxir::Context Ctx(LLVMF.getContext());
   sandboxir::Function &F = *Ctx.createFunction(&LLVMF);
-  sandboxir::Analyses A(*SE);
+  sandboxir::Analyses A(*AA, *SE);
   return FPM.runOnFunction(F, A);
 }
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 68557cb8b129f2..51e7a14013299b 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
@@ -30,15 +31,20 @@ struct LegalityTest : public testing::Test {
   std::unique_ptr<AssumptionCache> AC;
   std::unique_ptr<LoopInfo> LI;
   std::unique_ptr<ScalarEvolution> SE;
+  std::unique_ptr<BasicAAResult> BAA;
+  std::unique_ptr<AAResults> AA;
 
-  ScalarEvolution &getSE(llvm::Function &LLVMF) {
+  void getAnalyses(llvm::Function &LLVMF) {
     DT = std::make_unique<DominatorTree>(LLVMF);
     TLII = std::make_unique<TargetLibraryInfoImpl>();
     TLI = std::make_unique<TargetLibraryInfo>(*TLII);
     AC = std::make_unique<AssumptionCache>(LLVMF);
     LI = std::make_unique<LoopInfo>(*DT);
     SE = std::make_unique<ScalarEvolution>(LLVMF, *TLI, *AC, *DT, *LI);
-    return *SE;
+    BAA = std::make_unique<BasicAAResult>(LLVMF.getParent()->getDataLayout(),
+                                          LLVMF, *TLI, *AC, DT.get());
+    AA = std::make_unique<AAResults>(*TLI);
+    AA->addAAResult(*BAA);
   }
 
   void parseIR(LLVMContext &C, const char *IR) {
@@ -49,7 +55,7 @@ struct LegalityTest : public testing::Test {
   }
 };
 
-TEST_F(LegalityTest, Legality) {
+TEST_F(LegalityTest, LegalitySkipSchedule) {
   parseIR(C, R"IR(
 define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float %farg0, float %farg1, i64 %v0, i64 %v1, i32 %v2) {
   %gep0 = getelementptr float, ptr %ptr, i32 0
@@ -76,7 +82,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
 }
 )IR");
   llvm::Function *LLVMF = &*M->getFunction("foo");
-  auto &SE = getSE(*LLVMF);
+  getAnalyses(*LLVMF);
   const auto &DL = M->getDataLayout();
 
   sandboxir::Context Ctx(C);
@@ -104,83 +110,139 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
   auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
 
-  sandboxir::LegalityAnalysis Legality(SE, DL);
-  const auto &Result = Legality.canVectorize({St0, St1});
+  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
+  const auto &Result =
+      Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
   EXPECT_TRUE(isa<sandboxir::Widen>(Result));
 
   {
     // Check NotInstructions
-    auto &Result = Legality.canVectorize({F, St0});
+    auto &Result = Legality.canVectorize({F, St0}, /*SkipScheduling=*/true);
     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::NotInstructions);
   }
   {
     // Check DiffOpcodes
-    const auto &Result = Legality.canVectorize({St0, Ld0});
+    const auto &Result =
+        Legality.canVectorize({St0, Ld0}, /*SkipScheduling=*/true);
     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::DiffOpcodes);
   }
   {
     // Check DiffTypes
-    EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({St0, StVec2})));
-    EXPECT_TRUE(isa<sandboxir::Widen>(Legality.canVectorize({StVec2, StVec3})));
+    EXPECT_TRUE(isa<sandboxir::Widen>(
+        Legality.canVectorize({St0, StVec2}, /*SkipScheduling=*/true)));
+    EXPECT_TRUE(isa<sandboxir::Widen>(
+        Legality.canVectorize({StVec2, StVec3}, /*SkipScheduling=*/true)));
 
-    const auto &Result = Legality.canVectorize({St0, StI8});
+    const auto &Result =
+        Legality.canVectorize({St0, StI8}, /*SkipScheduling=*/true);
     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::DiffTypes);
   }
   {
     // Check DiffMathFlags
-    const auto &Result = Legality.canVectorize({FAdd0, FAdd1});
+    const auto &Result =
+        Legality.canVectorize({FAdd0, FAdd1}, /*SkipScheduling=*/true);
     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::DiffMathFlags);
   }
   {
     // Check DiffWrapFlags
-    const auto &Result = Legality.canVectorize({Trunc0, Trunc1});
+    const auto &Result =
+        Legality.canVectorize({Trunc0, Trunc1}, /*SkipScheduling=*/true);
     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::DiffWrapFlags);
   }
   {
     // Check DiffTypes for unary operands that have a different type.
-    const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8});
+    const auto &Result = Legality.canVectorize({Trunc64to8, Trunc32to8},
+                                               /*SkipScheduling=*/true);
     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::DiffTypes);
   }
   {
     // Check DiffOpcodes for CMPs with different predicates.
-    const auto &Result = Legality.canVectorize({CmpSLT, CmpSGT});
+    const auto &Result =
+        Legality.canVectorize({CmpSLT, CmpSGT}, /*SkipScheduling=*/true);
     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::DiffOpcodes);
   }
   {
     // Check NotConsecutive Ld0,Ld0b
-    const auto &Result = Legality.canVectorize({Ld0, Ld0b});
+    const auto &Result =
+        Legality.canVectorize({Ld0, Ld0b}, /*SkipScheduling=*/true);
     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::NotConsecutive);
   }
   {
     // Check NotConsecutive Ld0,Ld3
-    const auto &Result = Legality.canVectorize({Ld0, Ld3});
+    const auto &Result =
+        Legality.canVectorize({Ld0, Ld3}, /*SkipScheduling=*/true);
     EXPECT_TRUE(isa<sandboxir::Pack>(Result));
     EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
               sandboxir::ResultReason::NotConsecutive);
   }
   {
     // Check Widen Ld0,Ld1
-    const auto &Result = Legality.canVectorize({Ld0, Ld1});
+    const auto &Result =
+        Legality.canVectorize({Ld0, Ld1}, /*SkipScheduling=*/true);
     EXPECT_TRUE(isa<sandboxir::Widen>(Result));
   }
 }
 
+TEST_F(LegalityTest, LegalitySchedule) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr) {
+  %gep0 = getelementptr float, ptr %ptr, i32 0
+  %gep1 = getelementptr float, ptr %ptr, i32 1
+  %ld0 = load float, ptr %gep0
+  store float %ld0, ptr %gep1
+  %ld1 = load float, ptr %gep1
+  store float %ld0, ptr %gep0
+  store float %ld1, ptr %gep1
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  getAnalyses(*LLVMF);
+  const auto &DL = M->getDataLayout();
+
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  [[maybe_unused]] auto *Gep0 = cast<sandboxir::GetElementPtrInst>(&*It++);
+  [[maybe_unused]] auto *Gep1 = cast<sandboxir::GetElementPtrInst>(&*It++);
+  auto *Ld0 = cast<sandboxir::LoadInst>(&*It++);
+  [[maybe_unused]] auto *ConflictingSt = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ld1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *St0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *St1 = cast<sandboxir::StoreInst>(&*It++);
+
+  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
+  {
+    // Scheduler check enabled: St0,St1 can be scheduled together -> Widen.
+    const auto &Result = Legality.canVectorize({St0, St1});
+    EXPECT_TRUE(isa<sandboxir::Widen>(Result));
+  }
+  {
+    // Can't schedule Ld0,Ld1: ConflictingSt writes %gep1 between them.
+    auto &Result = Legality.canVectorize({Ld0, Ld1});
+    EXPECT_TRUE(isa<sandboxir::Pack>(Result));
+    EXPECT_EQ(cast<sandboxir::Pack>(Result).getReason(),
+              sandboxir::ResultReason::CantSchedule);
+  }
+}
+
 #ifndef NDEBUG
 TEST_F(LegalityTest, LegalityResultDump) {
   parseIR(C, R"IR(
@@ -189,7 +251,7 @@ define void @foo() {
 }
 )IR");
   llvm::Function *LLVMF = &*M->getFunction("foo");
-  auto &SE = getSE(*LLVMF);
+  getAnalyses(*LLVMF);
   const auto &DL = M->getDataLayout();
 
   auto Matches = [](const sandboxir::LegalityResult &Result,
@@ -200,7 +262,7 @@ define void @foo() {
     return Buff == ExpectedStr;
   };
 
-  sandboxir::LegalityAnalysis Legality(SE, DL);
+  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
   EXPECT_TRUE(
       Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(



More information about the llvm-commits mailing list