[llvm] 5942a99 - [SandboxVec] Notify scheduler about new instructions (#115102)

via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 6 13:26:17 PST 2024


Author: vporpo
Date: 2024-11-06T13:26:14-08:00
New Revision: 5942a99f8b7dd361c35eb1c9c32b2475dce2c0b2

URL: https://github.com/llvm/llvm-project/commit/5942a99f8b7dd361c35eb1c9c32b2475dce2c0b2
DIFF: https://github.com/llvm/llvm-project/commit/5942a99f8b7dd361c35eb1c9c32b2475dce2c0b2.diff

LOG: [SandboxVec] Notify scheduler about new instructions (#115102)

This patch registers a "createInstr" callback that notifies the
scheduler about newly created instructions. This guarantees that every
newly created instruction has a corresponding DAG node. Without it, the
pass crashes when the scheduler encounters the newly created vector
instructions.

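This patch also changes the lifetime of the sandboxir::Context (Ctx)
used by the SandboxVectorizer pass. Ctx is now a member of the pass,
declared before the pass pipeline (FPM), so it is destroyed after the
passes are. With Ctx as a local in runImpl(), it would already have been
freed by the time components like the Scheduler are destroyed. The
Scheduler's destructor, which unregisters its callback from Ctx, would
then access freed memory.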

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
    llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
    llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h
    llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
    llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
    llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp
    llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
    llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
    llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 5be05bc80c4925..b498e0f189465c 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -345,6 +345,11 @@ class DependencyGraph {
   Interval<Instruction> extend(ArrayRef<Instruction *> Instrs);
   /// \Returns the range of instructions included in the DAG.
   Interval<Instruction> getInterval() const { return DAGInterval; }
+  /// Called by the scheduler when a new instruction \p I has been created.
+  void notifyCreateInstr(Instruction *I) {
+    getOrCreateNode(I);
+    // TODO: Update the dependencies for the new node.
+  }
 #ifndef NDEBUG
   void print(raw_ostream &OS) const;
   LLVM_DUMP_METHOD void dump() const;

diff  --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
index 58dcb2eeadbc2d..63d6ef31c86453 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h
@@ -162,8 +162,9 @@ class LegalityAnalysis {
   const DataLayout &DL;
 
 public:
-  LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL)
-      : Sched(AA), SE(SE), DL(DL) {}
+  LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL,
+                   Context &Ctx)
+      : Sched(AA, Ctx), SE(SE), DL(DL) {}
   /// A LegalityResult factory.
   template <typename ResultT, typename... ArgsT>
   ResultT &createLegalityResult(ArgsT... Args) {

diff  --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h
index 46b953ff9b7f49..09369dbb496fce 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h
@@ -13,6 +13,7 @@
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/SandboxIR/Context.h"
 #include "llvm/SandboxIR/PassManager.h"
 
 namespace llvm {
@@ -24,6 +25,8 @@ class SandboxVectorizerPass : public PassInfoMixin<SandboxVectorizerPass> {
   AAResults *AA = nullptr;
   ScalarEvolution *SE = nullptr;
 
+  std::unique_ptr<sandboxir::Context> Ctx;
+
   // A pipeline of SandboxIR function passes run by the vectorizer.
   sandboxir::FunctionPassManager FPM;
 

diff  --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
index 08972d460b406e..0e4eea3880efbd 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h
@@ -95,6 +95,8 @@ class Scheduler {
   DependencyGraph DAG;
   std::optional<BasicBlock::iterator> ScheduleTopItOpt;
   SmallVector<std::unique_ptr<SchedBundle>> Bndls;
+  Context &Ctx;
+  Context::CallbackID CreateInstrCB;
 
   /// \Returns a scheduling bundle containing \p Instrs.
   SchedBundle *createBundle(ArrayRef<Instruction *> Instrs);
@@ -110,8 +112,11 @@ class Scheduler {
   Scheduler &operator=(const Scheduler &) = delete;
 
 public:
-  Scheduler(AAResults &AA) : DAG(AA) {}
-  ~Scheduler() {}
+  Scheduler(AAResults &AA, Context &Ctx) : DAG(AA), Ctx(Ctx) {
+    CreateInstrCB = Ctx.registerCreateInstrCallback(
+        [this](Instruction *I) { DAG.notifyCreateInstr(I); });
+  }
+  ~Scheduler() { Ctx.unregisterCreateInstrCallback(CreateInstrCB); }
 
   bool trySchedule(ArrayRef<Instruction *> Instrs);
 

diff  --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 37713e7da6432d..0a930d30aeab58 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -182,8 +182,6 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl) {
     }
     NewVec = createVectorInstr(Bndl, VecOperands);
 
-    // TODO: Notify DAG/Scheduler about new instruction
-
     // TODO: Collect potentially dead instructions.
     break;
   }
@@ -202,7 +200,8 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
 
 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
   Legality = std::make_unique<LegalityAnalysis>(
-      A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout());
+      A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
+      F.getContext());
   Change = false;
   // TODO: Start from innermost BBs first
   for (auto &BB : F) {

diff  --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp
index 790bee4a4d7f39..c22eb01d74a1cb 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.cpp
@@ -64,6 +64,9 @@ PreservedAnalyses SandboxVectorizerPass::run(Function &F,
 }
 
 bool SandboxVectorizerPass::runImpl(Function &LLVMF) {
+  if (Ctx == nullptr)
+    Ctx = std::make_unique<sandboxir::Context>(LLVMF.getContext());
+
   if (PrintPassPipeline) {
     FPM.printPipeline(outs());
     return false;
@@ -82,8 +85,7 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) {
   }
 
   // Create SandboxIR for LLVMF and run BottomUpVec on it.
-  sandboxir::Context Ctx(LLVMF.getContext());
-  sandboxir::Function &F = *Ctx.createFunction(&LLVMF);
+  sandboxir::Function &F = *Ctx->createFunction(&LLVMF);
   sandboxir::Analyses A(*AA, *SE);
   return FPM.runOnFunction(F, A);
 }

diff  --git a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
index 2b9aac93b74851..45c701a18fd9bf 100644
--- a/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
+++ b/llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
@@ -55,7 +55,46 @@ define void @store_fpext_load(ptr %ptr) {
   ret void
 }
 
-; TODO: Test store_zext_fcmp_load once we implement scheduler callbacks and legality diamond check
+define void @store_fcmp_zext_load(ptr %ptr) {
+; CHECK-LABEL: define void @store_fcmp_zext_load(
+; CHECK-SAME: ptr [[PTR:%.*]]) {
+; CHECK-NEXT:    [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
+; CHECK-NEXT:    [[PTRB0:%.*]] = getelementptr i32, ptr [[PTR]], i32 0
+; CHECK-NEXT:    [[PTRB1:%.*]] = getelementptr i32, ptr [[PTR]], i32 1
+; CHECK-NEXT:    [[LDB0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LDB1:%.*]] = load float, ptr [[PTR1]], align 4
+; CHECK-NEXT:    [[VECL1:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LDA0:%.*]] = load float, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[LDA1:%.*]] = load float, ptr [[PTR1]], align 4
+; CHECK-NEXT:    [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
+; CHECK-NEXT:    [[FCMP0:%.*]] = fcmp ogt float [[LDA0]], [[LDB0]]
+; CHECK-NEXT:    [[FCMP1:%.*]] = fcmp ogt float [[LDA1]], [[LDB1]]
+; CHECK-NEXT:    [[VCMP:%.*]] = fcmp ogt <2 x float> [[VECL]], [[VECL1]]
+; CHECK-NEXT:    [[ZEXT0:%.*]] = zext i1 [[FCMP0]] to i32
+; CHECK-NEXT:    [[ZEXT1:%.*]] = zext i1 [[FCMP1]] to i32
+; CHECK-NEXT:    [[VCAST:%.*]] = zext <2 x i1> [[VCMP]] to <2 x i32>
+; CHECK-NEXT:    store i32 [[ZEXT0]], ptr [[PTRB0]], align 4
+; CHECK-NEXT:    store i32 [[ZEXT1]], ptr [[PTRB1]], align 4
+; CHECK-NEXT:    store <2 x i32> [[VCAST]], ptr [[PTRB0]], align 4
+; CHECK-NEXT:    ret void
+;
+  %ptr0 = getelementptr float, ptr %ptr, i32 0
+  %ptr1 = getelementptr float, ptr %ptr, i32 1
+  %ptrb0 = getelementptr i32, ptr %ptr, i32 0
+  %ptrb1 = getelementptr i32, ptr %ptr, i32 1
+  %ldB0 = load float, ptr %ptr0
+  %ldB1 = load float, ptr %ptr1
+  %ldA0 = load float, ptr %ptr0
+  %ldA1 = load float, ptr %ptr1
+  %fcmp0 = fcmp ogt float %ldA0, %ldB0
+  %fcmp1 = fcmp ogt float %ldA1, %ldB1
+  %zext0 = zext i1 %fcmp0 to i32
+  %zext1 = zext i1 %fcmp1 to i32
+  store i32 %zext0, ptr %ptrb0
+  store i32 %zext1, ptr %ptrb1
+  ret void
+}
 
 ; TODO: Test store_fadd_load once we implement scheduler callbacks and legality diamond check
 

diff  --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 51e7a14013299b..b5e2c302f5901e 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -110,7 +110,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
   auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
 
-  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
+  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
   const auto &Result =
       Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
   EXPECT_TRUE(isa<sandboxir::Widen>(Result));
@@ -228,7 +228,7 @@ define void @foo(ptr %ptr) {
   auto *St0 = cast<sandboxir::StoreInst>(&*It++);
   auto *St1 = cast<sandboxir::StoreInst>(&*It++);
 
-  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
+  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
   {
     // Can vectorize St0,St1.
     const auto &Result = Legality.canVectorize({St0, St1});
@@ -262,7 +262,8 @@ define void @foo() {
     return Buff == ExpectedStr;
   };
 
-  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL);
+  sandboxir::Context Ctx(C);
+  sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx);
   EXPECT_TRUE(
       Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
   EXPECT_TRUE(Matches(Legality.createLegalityResult<sandboxir::Pack>(

diff  --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
index 92e767e55fbddb..4a8b0ba1d7c12b 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/SchedulerTest.cpp
@@ -156,20 +156,20 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
 
   {
     // Schedule all instructions in sequence.
-    sandboxir::Scheduler Sched(getAA(*LLVMF));
+    sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
     EXPECT_TRUE(Sched.trySchedule({Ret}));
     EXPECT_TRUE(Sched.trySchedule({S1}));
     EXPECT_TRUE(Sched.trySchedule({S0}));
   }
   {
     // Skip instructions.
-    sandboxir::Scheduler Sched(getAA(*LLVMF));
+    sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
     EXPECT_TRUE(Sched.trySchedule({Ret}));
     EXPECT_TRUE(Sched.trySchedule({S0}));
   }
   {
     // Try invalid scheduling
-    sandboxir::Scheduler Sched(getAA(*LLVMF));
+    sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
     EXPECT_TRUE(Sched.trySchedule({Ret}));
     EXPECT_TRUE(Sched.trySchedule({S0}));
     EXPECT_FALSE(Sched.trySchedule({S1}));
@@ -197,7 +197,7 @@ define void @foo(ptr noalias %ptr0, ptr noalias %ptr1) {
   auto *S1 = cast<sandboxir::StoreInst>(&*It++);
   auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
-  sandboxir::Scheduler Sched(getAA(*LLVMF));
+  sandboxir::Scheduler Sched(getAA(*LLVMF), Ctx);
   EXPECT_TRUE(Sched.trySchedule({Ret}));
   EXPECT_TRUE(Sched.trySchedule({S0, S1}));
   EXPECT_TRUE(Sched.trySchedule({L0, L1}));


        


More information about the llvm-commits mailing list