[llvm] [SandboxVec][InstrMaps] EraseInstr callback (PR #123256)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 16 15:33:37 PST 2025


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/123256

This patch hooks InstrMaps up to the Sandbox IR erase-instruction callback, so that its maps are updated whenever an instruction is erased.

>From 5dcd170ec3ef9afadf254f62d593fb249b4871ee Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Fri, 15 Nov 2024 12:53:37 -0800
Subject: [PATCH] [SandboxVec][InstrMaps] EraseInstr callback

This patch hooks up InstrMaps to the Sandbox IR callbacks such that it gets
updated when instructions get erased.
---
 .../Vectorize/SandboxVectorizer/InstrMaps.h   | 32 +++++++++++++++++++
 .../SandboxVectorizer/Passes/BottomUpVec.h    |  2 +-
 .../SandboxVectorizer/Passes/BottomUpVec.cpp  |  6 ++--
 .../SandboxVectorizer/InstrMapsTest.cpp       | 11 ++++++-
 .../SandboxVectorizer/LegalityTest.cpp        |  6 ++--
 5 files changed, 49 insertions(+), 8 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
index 2c4ba30f6fd052..999fbb0aad9405 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
@@ -13,9 +13,12 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/SandboxIR/Context.h"
+#include "llvm/SandboxIR/Instruction.h"
 #include "llvm/SandboxIR/Value.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
+#include <algorithm>
 
 namespace llvm::sandboxir {
 
@@ -30,8 +33,37 @@ class InstrMaps {
   /// with the same lane, as they may be coming from vectorizing different
   /// original values.
   DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;
+  Context &Ctx;
+  std::optional<Context::CallbackID> EraseInstrCB;
+
+private:
+  void notifyEraseInstr(Value *V) {
+    // Erase callback: V may be an original value, a vector value, or untracked.
+    auto It = OrigToVectorMap.find(V);
+    if (It != OrigToVectorMap.end()) {
+      // V is an original value: drop its lane from its vector's lane map.
+      // find() (not operator[]) so a missing entry is never default-inserted.
+      auto VecIt = VectorToOrigLaneMap.find(It->second);
+      if (VecIt != VectorToOrigLaneMap.end())
+        VecIt->second.erase(V);
+      OrigToVectorMap.erase(It);
+      return;
+    }
+    // V is not an original: it is either a vector value or untracked (no-op).
+    auto VecIt = VectorToOrigLaneMap.find(V);
+    if (VecIt == VectorToOrigLaneMap.end())
+      return;
+    for (const auto &OrigAndLane : VecIt->second)
+      OrigToVectorMap.erase(OrigAndLane.first);
+    VectorToOrigLaneMap.erase(VecIt);
+  }
 
 public:
+  InstrMaps(Context &Ctx) : Ctx(Ctx) {
+    EraseInstrCB = Ctx.registerEraseInstrCallback(
+        [this](Instruction *I) { notifyEraseInstr(I); });
+  }
+  ~InstrMaps() { Ctx.unregisterEraseInstrCallback(*EraseInstrCB); }
   /// \Returns the vector value that we got from vectorizing \p Orig, or
   /// nullptr if not found.
   Value *getVectorForOrig(Value *Orig) const {
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 69cea3c4c7b53b..dd3012f7c9b556 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -28,7 +28,7 @@ class BottomUpVec final : public FunctionPass {
   std::unique_ptr<LegalityAnalysis> Legality;
   DenseSet<Instruction *> DeadInstrCandidates;
   /// Maps scalars to vectors.
-  InstrMaps IMaps;
+  std::unique_ptr<InstrMaps> IMaps;
 
   /// Creates and returns a vector instruction that replaces the instructions in
   /// \p Bndl. \p Operands are the already vectorized operands.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 6b2032be535603..b8e2697839a3c2 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -161,7 +161,7 @@ Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
   auto *VecI = CreateVectorInstr(Bndl, Operands);
   if (VecI != nullptr) {
     Change = true;
-    IMaps.registerVector(Bndl, VecI);
+    IMaps->registerVector(Bndl, VecI);
   }
   return VecI;
 }
@@ -315,10 +315,10 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
 }
 
 bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
-  IMaps.clear();
+  IMaps = std::make_unique<InstrMaps>(F.getContext());
   Legality = std::make_unique<LegalityAnalysis>(
       A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
-      F.getContext(), IMaps);
+      F.getContext(), *IMaps);
   Change = false;
   const auto &DL = F.getParent()->getDataLayout();
   unsigned VecRegBits =
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
index bcfb8db7f86741..11831b881ca7a8 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
@@ -53,7 +53,7 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
   auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
   [[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
 
-  sandboxir::InstrMaps IMaps;
+  sandboxir::InstrMaps IMaps(Ctx);
   // Check with empty IMaps.
   EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
   EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
@@ -75,4 +75,13 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
 #ifndef NDEBUG
   EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*");
 #endif // NDEBUG
+  // Check callbacks: erase original instr.
+  Add0->eraseFromParent();
+  EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add0));
+  EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1);
+  EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
+  // Check callbacks: erase vector instr.
+  VAdd0->eraseFromParent();
+  EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
+  EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
 }
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 2e90462a633c17..069bfdba0a7cdb 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -111,7 +111,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
   auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
   auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
 
-  llvm::sandboxir::InstrMaps IMaps;
+  llvm::sandboxir::InstrMaps IMaps(Ctx);
   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
   const auto &Result =
       Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
@@ -230,7 +230,7 @@ define void @foo(ptr %ptr) {
   auto *St0 = cast<sandboxir::StoreInst>(&*It++);
   auto *St1 = cast<sandboxir::StoreInst>(&*It++);
 
-  llvm::sandboxir::InstrMaps IMaps;
+  llvm::sandboxir::InstrMaps IMaps(Ctx);
   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
   {
     // Can vectorize St0,St1.
@@ -266,7 +266,7 @@ define void @foo() {
   };
 
   sandboxir::Context Ctx(C);
-  llvm::sandboxir::InstrMaps IMaps;
+  llvm::sandboxir::InstrMaps IMaps(Ctx);
   sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
   EXPECT_TRUE(
       Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));



More information about the llvm-commits mailing list