[llvm] [SandboxVec][Scheduler] Implement rescheduling (PR #115220)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 6 15:11:30 PST 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/115220

>From 4b6d875b6fd1da39ec77bfe7893cd5e5300bd55c Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 4 Nov 2024 13:30:11 -0800
Subject: [PATCH] [SandboxVec][Scheduler] Implement rescheduling

This patch adds support for re-scheduling already scheduled instructions.
For now this will clear and rebuild the DAG, and will reschedule the code
using the new DAG.
---
 .../SandboxVectorizer/DependencyGraph.h       | 13 +++
 .../Vectorize/SandboxVectorizer/Scheduler.h   | 27 +++++-
 .../Vectorize/SandboxVectorizer/VecUtils.h    |  8 ++
 .../Vectorize/SandboxVectorizer/Scheduler.cpp | 96 +++++++++++++++----
 .../SandboxVectorizer/bottomup_basic.ll       | 32 ++++++-
 .../SandboxVectorizer/SchedulerTest.cpp       | 41 +++++++-
 .../SandboxVectorizer/VecUtilsTest.cpp        | 29 ++++++
 7 files changed, 222 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index b498e0f189465c..5211c7922ea2fd 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -33,6 +33,7 @@ namespace llvm::sandboxir {
 
 class DependencyGraph;
 class MemDGNode;
+class SchedBundle;
 
 /// SubclassIDs for isa/dyn_cast etc.
 enum class DGNodeID {
@@ -100,6 +101,12 @@ class DGNode {
   unsigned UnscheduledSuccs = 0;
   /// This is true if this node has been scheduled.
   bool Scheduled = false;
+  /// The scheduler bundle that this node belongs to.
+  SchedBundle *SB = nullptr;
+
+  void setSchedBundle(SchedBundle &SB) { this->SB = &SB; }
+  void clearSchedBundle() { this->SB = nullptr; }
+  friend class SchedBundle; // For setSchedBundle(), clearSchedBundle().
 
   DGNode(Instruction *I, DGNodeID ID) : I(I), SubclassID(ID) {}
   friend class MemDGNode;       // For constructor.
@@ -122,6 +129,8 @@ class DGNode {
   /// \Returns true if this node has been scheduled.
   bool scheduled() const { return Scheduled; }
   void setScheduled(bool NewVal) { Scheduled = NewVal; }
+  /// \Returns the scheduling bundle that this node belongs to, or nullptr.
+  SchedBundle *getSchedBundle() const { return SB; }
   /// \Returns true if this is before \p Other in program order.
   bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
   using iterator = PredIterator;
@@ -350,6 +359,10 @@ class DependencyGraph {
     getOrCreateNode(I);
     // TODO: Update the dependencies for the new node.
   }
+  void clear() {
+    InstrToNodeMap.clear(); // Forget every instruction-to-node mapping.
+    DAGInterval = {};       // Reset the DAG interval to its default state.
+  }
 #ifndef NDEBUG
   void print(raw_ostream &OS) const;
   LLVM_DUMP_METHOD void dump() const;
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
index 0e4eea3880efbd..2d6b4035b67408 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
@@ -53,6 +53,7 @@ class ReadyListContainer {
     return Back;
   }
   bool empty() const { return List.empty(); }
+  void clear() { List = {}; }
 #ifndef NDEBUG
   void dump(raw_ostream &OS) const;
   LLVM_DUMP_METHOD void dump() const;
@@ -70,7 +71,16 @@ class SchedBundle {
 
 public:
   SchedBundle() = default;
-  SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {}
+  SchedBundle(ContainerTy &&Nodes) : Nodes(std::move(Nodes)) {
+    for (DGNode *Member : this->Nodes)
+      Member->setSchedBundle(*this); // Back-link the node to this bundle.
+  }
+  ~SchedBundle() {
+    for (DGNode *Member : this->Nodes)
+      Member->clearSchedBundle(); // Leave no pointer to the dead bundle.
+  }
+  bool empty() const { return Nodes.empty(); }
+  DGNode *back() const { return Nodes.back(); }
   using iterator = ContainerTy::iterator;
   using const_iterator = ContainerTy::const_iterator;
   iterator begin() { return Nodes.begin(); }
@@ -94,19 +104,30 @@ class Scheduler {
   ReadyListContainer ReadyList;
   DependencyGraph DAG;
   std::optional<BasicBlock::iterator> ScheduleTopItOpt;
-  SmallVector<std::unique_ptr<SchedBundle>> Bndls;
+  // TODO: This is wasting memory in exchange for fast removal using a raw ptr.
+  DenseMap<SchedBundle *, std::unique_ptr<SchedBundle>> Bndls;
   Context &Ctx;
   Context::CallbackID CreateInstrCB;
 
   /// \Returns a scheduling bundle containing \p Instrs.
   SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
+  void eraseBundle(SchedBundle *SB);
   /// Schedule nodes until we can schedule \p Instrs back-to-back.
   bool tryScheduleUntil(ArrayRef<Instruction *> Instrs);
   /// Schedules all nodes in \p Bndl, marks them as scheduled, updates the
   /// UnscheduledSuccs counter of all dependency predecessors, and adds any of
   /// them that become ready to the ready list.
   void scheduleAndUpdateReadyList(SchedBundle &Bndl);
-
+  /// The scheduling state of the instructions in the bundle.
+  enum class BndlSchedState {
+    NoneScheduled,
+    PartiallyOrDifferentlyScheduled,
+    FullyScheduled,
+  };
+  /// \Returns whether none/some/all of \p Instrs have been scheduled.
+  BndlSchedState getBndlSchedState(ArrayRef<Instruction *> Instrs) const;
+  /// Destroy the top-most part of the schedule that includes \p Instrs.
+  void trimSchedule(ArrayRef<Instruction *> Instrs);
   /// Disable copies.
   Scheduler(const Scheduler &) = delete;
   Scheduler &operator=(const Scheduler &) = delete;
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
index 85229150de2b6c..d44c845bfbf4e9 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h
@@ -100,6 +100,14 @@ class VecUtils {
     }
     return FixedVectorType::get(ElemTy, NumElts);
   }
+  /// \Returns the instruction of \p Instrs that comes last in program order.
+  /// \p Instrs must be non-empty: its first element seeds the search.
+  static Instruction *getLowest(ArrayRef<Instruction *> Instrs) {
+    Instruction *BottomI = Instrs.front();
+    for (Instruction *Candidate : drop_begin(Instrs))
+      BottomI = BottomI->comesBefore(Candidate) ? Candidate : BottomI;
+    return BottomI;
+  }
 };
 
 } // namespace llvm::sandboxir
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp
index 6140c2a8dcec82..1fbf201f74a66b 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Scheduler.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
+#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
 
 namespace llvm::sandboxir {
 
@@ -95,10 +96,12 @@ SchedBundle *Scheduler::createBundle(ArrayRef<Instruction *> Instrs) {
     Nodes.push_back(DAG.getNode(I));
   auto BndlPtr = std::make_unique<SchedBundle>(std::move(Nodes));
   auto *Bndl = BndlPtr.get();
-  Bndls.push_back(std::move(BndlPtr));
+  Bndls[Bndl] = std::move(BndlPtr);
   return Bndl;
 }
 
+void Scheduler::eraseBundle(SchedBundle *SB) { Bndls.erase(SB); }
+
 bool Scheduler::tryScheduleUntil(ArrayRef<Instruction *> Instrs) {
   // Use a set of instructions, instead of `Instrs` for fast lookups.
   DenseSet<Instruction *> InstrsToDefer(Instrs.begin(), Instrs.end());
@@ -133,29 +136,88 @@ bool Scheduler::tryScheduleUntil(ArrayRef<Instruction *> Instrs) {
   return false;
 }
 
+/// Classifies \p Instrs by how many of them already have a scheduled DAG node
+/// and, when all of them do, by whether they share a single SchedBundle.
+Scheduler::BndlSchedState
+Scheduler::getBndlSchedState(ArrayRef<Instruction *> Instrs) const {
+  assert(!Instrs.empty() && "Expected non-empty bundle");
+  // An instruction counts as scheduled only if it has a DAG node and that
+  // node is marked as scheduled.
+  unsigned NumScheduled = 0;
+  for (Instruction *Instr : Instrs) {
+    const DGNode *Node = DAG.getNode(Instr);
+    if (Node != nullptr && Node->scheduled())
+      ++NumScheduled;
+  }
+  if (NumScheduled == 0)
+    return BndlSchedState::NoneScheduled;
+  if (NumScheduled != Instrs.size())
+    return BndlSchedState::PartiallyOrDifferentlyScheduled;
+  // All instrs are scheduled. If they are not all in the same SchedBundle
+  // this still counts as partially-scheduled, because we will need to
+  // re-schedule.
+  SchedBundle *FirstSB = DAG.getNode(Instrs.front())->getSchedBundle();
+  assert(FirstSB != nullptr && "FullyScheduled assumes that there is an SB!");
+  bool InOtherSB = any_of(drop_begin(Instrs), [this, FirstSB](auto *Instr) {
+    return DAG.getNode(Instr)->getSchedBundle() != FirstSB;
+  });
+  return InOtherSB ? BndlSchedState::PartiallyOrDifferentlyScheduled
+                   : BndlSchedState::FullyScheduled;
+}
+
+/// Erases the bundles from the lowest of \p Instrs up to the schedule top and
+/// then clears all bundles, the DAG and the ready list.
+void Scheduler::trimSchedule(ArrayRef<Instruction *> Instrs) {
+  Instruction *TopI = &*ScheduleTopItOpt.value();
+  Instruction *LowestI = VecUtils::getLowest(Instrs);
+  for (auto *I = LowestI, *E = TopI->getPrevNode(); I != E;
+       I = I->getPrevNode()) {
+    // getNode() may return null for instrs the DAG does not cover; skip them.
+    auto *N = DAG.getNode(I);
+    if (auto *SB = N != nullptr ? N->getSchedBundle() : nullptr)
+      eraseBundle(SB);
+  }
+  // TODO: For now we clear the DAG. Trim view once it gets implemented.
+  Bndls.clear();
+  DAG.clear();
+  // The nodes in the ready list may not be ready after clearing the DAG.
+  ReadyList.clear();
+}
+
 bool Scheduler::trySchedule(ArrayRef<Instruction *> Instrs) {
   assert(all_of(drop_begin(Instrs),
                 [Instrs](Instruction *I) {
                   return I->getParent() == (*Instrs.begin())->getParent();
                 }) &&
          "Instrs not in the same BB!");
-  // Extend the DAG to include Instrs.
-  Interval<Instruction> Extension = DAG.extend(Instrs);
-  // TODO: Set the window of the DAG that we are interested in.
-  // We start scheduling at the bottom instr of Instrs.
-  auto getBottomI = [](ArrayRef<Instruction *> Instrs) -> Instruction * {
-    return *min_element(Instrs,
-                        [](auto *I1, auto *I2) { return I1->comesBefore(I2); });
-  };
-  ScheduleTopItOpt = std::next(getBottomI(Instrs)->getIterator());
-  // Add nodes to ready list.
-  for (auto &I : Extension) {
-    auto *N = DAG.getNode(&I);
-    if (N->ready())
-      ReadyList.insert(N);
+  auto SchedState = getBndlSchedState(Instrs);
+  switch (SchedState) {
+  case BndlSchedState::FullyScheduled:
+    // Nothing to do.
+    return true;
+  case BndlSchedState::PartiallyOrDifferentlyScheduled:
+    // If one or more instrs are already scheduled we need to destroy the
+    // top-most part of the schedule that includes the instrs in the bundle and
+    // re-schedule.
+    trimSchedule(Instrs);
+    [[fallthrough]];
+  case BndlSchedState::NoneScheduled: {
+    // TODO: Set the window of the DAG that we are interested in.
+    // We start scheduling at the bottom instr of Instrs.
+    ScheduleTopItOpt = std::next(VecUtils::getLowest(Instrs)->getIterator());
+
+    // Extend the DAG to include Instrs.
+    Interval<Instruction> Extension = DAG.extend(Instrs);
+    // Add nodes to ready list.
+    for (auto &I : Extension) {
+      auto *N = DAG.getNode(&I);
+      if (N->ready())
+        ReadyList.insert(N);
+    }
+    // Try schedule all nodes until we can schedule Instrs back-to-back.
+    return tryScheduleUntil(Instrs);
+  }
   }
-  // Try schedule all nodes until we can schedule Instrs back-to-back.
-  return tryScheduleUntil(Instrs);
 }
 
 #ifndef NDEBUG
diff --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index 45c701a18fd9bf..e56dbd75963f7a 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -96,7 +96,37 @@ define void @store_fcmp_zext_load(ptr %ptr) {
   ret void
 }
 
-; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check
+define void @store_fadd_load(ptr %ptr) {
+; CHECK-LABEL: define void @store_fadd_load(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
+; CHECK-NEXT:    [[LDA0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LDA1:%.*]] = load float, ptr [[PTR1]], align 4
+; CHECK-NEXT:    [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LDB0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LDB1:%.*]] = load float, ptr [[PTR1]], align 4
+; CHECK-NEXT:    [[VECL1:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[FADD0:%.*]] = fadd float [[LDA0]], [[LDB0]]
+; CHECK-NEXT:    [[FADD1:%.*]] = fadd float [[LDA1]], [[LDB1]]
+; CHECK-NEXT:    [[VEC:%.*]] = fadd <2 x float> [[VECL]], [[VECL1]]
+; CHECK-NEXT:    store float [[FADD0]], ptr [[PTR0]], align 4
+; CHECK-NEXT:    store float [[FADD1]], ptr [[PTR1]], align 4
+; CHECK-NEXT:    store <2 x float> [[VEC]], ptr [[PTR0]], align 4
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ldA0 = load float, ptr %ptr0
+  %ldA1 = load float, ptr %ptr1
+  %ldB0 = load float, ptr %ptr0
+  %ldB1 = load float, ptr %ptr1
+  %fadd0 = fadd float %ldA0, %ldB0
+  %fadd1 = fadd float %ldA1, %ldB1
+  store float %fadd0, ptr %ptr0
+  store float %fadd1, ptr %ptr1
+  ret void
+}
 
 define void @store_fneg_load(ptr %ptr) {
 ; CHECK-LABEL: define void @store_fneg_load(
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
index 4a8b0ba1d7c12b..94a57914429748 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
@@ -168,11 +168,10 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
     EXPECT_TRUE(Sched.trySchedule({S0}));
   }
   {
-    // Try invalid scheduling
+    // Try invalid scheduling. Dependency S0->S1.
     sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
     EXPECT_TRUE(Sched.trySchedule({Ret}));
-    EXPECT_TRUE(Sched.trySchedule({S0}));
-    EXPECT_FALSE(Sched.trySchedule({S1}));
+    EXPECT_FALSE(Sched.trySchedule({S0, S1}));
   }
 }
 
@@ -202,3 +201,39 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
   EXPECT_TRUE(Sched.trySchedule({S0, S1}));
   EXPECT_TRUE(Sched.trySchedule({L0, L1}));
 }
+
+TEST_F(SchedulerTest, RescheduleAlreadyScheduled) {
+  parseIR(C, R"IR(
+define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
+  %ld0 = load i8, ptr %ptr0
+  %ld1 = load i8, ptr %ptr1
+  %add0 = add i8 %ld0, %ld0
+  %add1 = add i8 %ld1, %ld1
+  store i8 %add0, ptr %ptr0
+  store i8 %add1, ptr %ptr1
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *L0 = cast<sandboxir::LoadInst>(&*It++);
+  auto *L1 = cast<sandboxir::LoadInst>(&*It++);
+  auto *Add0 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *Add1 = cast<sandboxir::BinaryOperator>(&*It++);
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
+  EXPECT_TRUE(Sched.trySchedule({Ret}));
+  EXPECT_TRUE(Sched.trySchedule({S0, S1}));
+  EXPECT_TRUE(Sched.trySchedule({L0, L1}));
+  // Scheduling {L0, L1} should have scheduled Add0 and Add1 individually, as
+  // single-instruction bundles. Asking for them as one bundle now exercises
+  // rescheduling, after which {L0, L1} must still be schedulable.
+  EXPECT_TRUE(Sched.trySchedule({Add0, Add1}));
+  EXPECT_TRUE(Sched.trySchedule({L0, L1}));
+}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
index 6d1ab95ce31440..835b9285c9d9ff 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/VecUtilsTest.cpp
@@ -410,3 +410,32 @@ TEST_F(VecUtilsTest, GetWideType) {
   auto *Int32X8Ty = sandboxir::FixedVectorType::get(Int32Ty, 8);
   EXPECT_EQ(sandboxir::VecUtils::getWideType(Int32X4Ty, 2), Int32X8Ty);
 }
+
+TEST_F(VecUtilsTest, GetLowest) {
+  parseIR(R"IR(
+define void @foo(i8 %v) {
+bb0:
+  %A = add i8 %v, %v
+  %B = add i8 %v, %v
+  %C = add i8 %v, %v
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+
+  sandboxir::Context Ctx(C);
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  auto It = BB.begin();
+  auto *IA = &*It++;
+  auto *IB = &*It++;
+  auto *IC = &*It++;
+  SmallVector<sandboxir::Instruction *> ABC({IA, IB, IC});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(ABC), IC);
+  SmallVector<sandboxir::Instruction *> ACB({IA, IC, IB});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(ACB), IC);
+  SmallVector<sandboxir::Instruction *> CAB({IC, IA, IB});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(CAB), IC);
+  SmallVector<sandboxir::Instruction *> CBA({IC, IB, IA});
+  EXPECT_EQ(sandboxir::VecUtils::getLowest(CBA), IC);
+}



More information about the llvm-commits mailing list