[llvm] [SandboxVec][InstrMaps] EraseInstr callback (PR #123256)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 16 15:33:37 PST 2025
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/123256
This patch hooks InstrMaps up to the Sandbox IR callbacks so that its maps are updated whenever an instruction is erased.
From 5dcd170ec3ef9afadf254f62d593fb249b4871ee Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Fri, 15 Nov 2024 12:53:37 -0800
Subject: [PATCH] [SandboxVec][InstrMaps] EraseInstr callback
This patch hooks InstrMaps up to the Sandbox IR callbacks so that its maps are
updated whenever an instruction is erased.
---
.../Vectorize/SandboxVectorizer/InstrMaps.h | 32 +++++++++++++++++++
.../SandboxVectorizer/Passes/BottomUpVec.h | 2 +-
.../SandboxVectorizer/Passes/BottomUpVec.cpp | 6 ++--
.../SandboxVectorizer/InstrMapsTest.cpp | 11 ++++++-
.../SandboxVectorizer/LegalityTest.cpp | 6 ++--
5 files changed, 49 insertions(+), 8 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
index 2c4ba30f6fd052..999fbb0aad9405 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h
@@ -13,9 +13,12 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/SandboxIR/Context.h"
+#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
namespace llvm::sandboxir {
@@ -30,8 +33,37 @@ class InstrMaps {
/// with the same lane, as they may be coming from vectorizing different
/// original values.
DenseMap<Value *, DenseMap<Value *, unsigned>> VectorToOrigLaneMap;
+  /// The Sandbox IR context that the erase-instruction callback is registered
+  /// with. It is used again in the destructor, so it must outlive this object.
+  Context &Ctx;
+  /// ID of the registered erase-instruction callback, kept so that the
+  /// destructor can unregister it.
+  std::optional<Context::CallbackID> EraseInstrCB;
+
+private:
+  /// Erase-instruction callback: removes every mapping that refers to \p V so
+  /// that neither OrigToVectorMap nor VectorToOrigLaneMap keeps a pointer to
+  /// an erased instruction. \p V may be an original value, a vector value, or
+  /// an instruction that InstrMaps does not track at all.
+  void notifyEraseInstr(Value *V) {
+    // We don't know if V is an original or a vector value, so check the
+    // original-value map first.
+    auto OrigIt = OrigToVectorMap.find(V);
+    if (OrigIt != OrigToVectorMap.end()) {
+      // V is an original value: drop its lane entry from the vector it was
+      // packed into, then drop V itself.
+      auto VecIt = VectorToOrigLaneMap.find(OrigIt->second);
+      if (VecIt != VectorToOrigLaneMap.end())
+        VecIt->second.erase(V);
+      OrigToVectorMap.erase(OrigIt);
+      return;
+    }
+    // Use find() rather than operator[]: most erased instructions are not
+    // tracked here, and operator[] would insert (then erase) an empty entry.
+    auto VecIt = VectorToOrigLaneMap.find(V);
+    if (VecIt == VectorToOrigLaneMap.end())
+      return;
+    // V is a vector value: forget every original value packed into it.
+    for (const auto &OrigAndLane : VecIt->second)
+      OrigToVectorMap.erase(OrigAndLane.first);
+    VectorToOrigLaneMap.erase(VecIt);
+  }
public:
+  /// Registers an erase-instruction callback with \p Ctx so that the maps are
+  /// updated when instructions get erased. \p Ctx must outlive this object.
+  explicit InstrMaps(Context &Ctx) : Ctx(Ctx) {
+    EraseInstrCB = Ctx.registerEraseInstrCallback(
+        [this](Instruction *I) { notifyEraseInstr(I); });
+  }
+  // The registered callback captures `this` and the destructor unregisters
+  // EraseInstrCB. A copy would therefore leave its callback pointing at the
+  // original object and unregister the same callback ID twice. Deleting the
+  // copy operations also suppresses the implicit move operations.
+  InstrMaps(const InstrMaps &) = delete;
+  InstrMaps &operator=(const InstrMaps &) = delete;
+  /// Unregisters the callback installed by the constructor.
+  ~InstrMaps() { Ctx.unregisterEraseInstrCallback(*EraseInstrCB); }
/// \Returns the vector value that we got from vectorizing \p Orig, or
/// nullptr if not found.
Value *getVectorForOrig(Value *Orig) const {
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
index 69cea3c4c7b53b..dd3012f7c9b556 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h
@@ -28,7 +28,7 @@ class BottomUpVec final : public FunctionPass {
std::unique_ptr<LegalityAnalysis> Legality;
DenseSet<Instruction *> DeadInstrCandidates;
/// Maps scalars to vectors.
- InstrMaps IMaps;
+  /// (Re)created in runOnFunction() for each function, because InstrMaps is
+  /// bound to that function's Context.
+  std::unique_ptr<InstrMaps> IMaps;
/// Creates and returns a vector instruction that replaces the instructions in
/// \p Bndl. \p Operands are the already vectorized operands.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
index 6b2032be535603..b8e2697839a3c2 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp
@@ -161,7 +161,7 @@ Value *BottomUpVec::createVectorInstr(ArrayRef<Value *> Bndl,
auto *VecI = CreateVectorInstr(Bndl, Operands);
if (VecI != nullptr) {
Change = true;
- IMaps.registerVector(Bndl, VecI);
+ IMaps->registerVector(Bndl, VecI);
}
return VecI;
}
@@ -315,10 +315,10 @@ bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
}
bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
- IMaps.clear();
+ IMaps = std::make_unique<InstrMaps>(F.getContext());
Legality = std::make_unique<LegalityAnalysis>(
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
- F.getContext(), IMaps);
+ F.getContext(), *IMaps);
Change = false;
const auto &DL = F.getParent()->getDataLayout();
unsigned VecRegBits =
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
index bcfb8db7f86741..11831b881ca7a8 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/InstrMapsTest.cpp
@@ -53,7 +53,7 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
auto *VAdd0 = cast<sandboxir::BinaryOperator>(&*It++);
[[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
- sandboxir::InstrMaps IMaps;
+ sandboxir::InstrMaps IMaps(Ctx);
// Check with empty IMaps.
EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
@@ -75,4 +75,13 @@ define void @foo(i8 %v0, i8 %v1, i8 %v2, i8 %v3, <2 x i8> %vec) {
#ifndef NDEBUG
EXPECT_DEATH(IMaps.registerVector({Add1, Add0}, VAdd0), ".*exists.*");
#endif // NDEBUG
+ // Check callbacks: erase original instr.
+ Add0->eraseFromParent();
+ EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add0));
+ EXPECT_EQ(*IMaps.getOrigLane(VAdd0, Add1), 1);
+ EXPECT_EQ(IMaps.getVectorForOrig(Add0), nullptr);
+ // Check callbacks: erase vector instr.
+ VAdd0->eraseFromParent();
+ EXPECT_FALSE(IMaps.getOrigLane(VAdd0, Add1));
+ EXPECT_EQ(IMaps.getVectorForOrig(Add1), nullptr);
}
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
index 2e90462a633c17..069bfdba0a7cdb 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/LegalityTest.cpp
@@ -111,7 +111,7 @@ define void @foo(ptr %ptr, <2 x float> %vec2, <3 x float> %vec3, i8 %arg, float
auto *CmpSLT = cast<sandboxir::CmpInst>(&*It++);
auto *CmpSGT = cast<sandboxir::CmpInst>(&*It++);
- llvm::sandboxir::InstrMaps IMaps;
+ llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
const auto &Result =
Legality.canVectorize({St0, St1}, /*SkipScheduling=*/true);
@@ -230,7 +230,7 @@ define void @foo(ptr %ptr) {
auto *St0 = cast<sandboxir::StoreInst>(&*It++);
auto *St1 = cast<sandboxir::StoreInst>(&*It++);
- llvm::sandboxir::InstrMaps IMaps;
+ llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
{
// Can vectorize St0,St1.
@@ -266,7 +266,7 @@ define void @foo() {
};
sandboxir::Context Ctx(C);
- llvm::sandboxir::InstrMaps IMaps;
+ llvm::sandboxir::InstrMaps IMaps(Ctx);
sandboxir::LegalityAnalysis Legality(*AA, *SE, DL, Ctx, IMaps);
EXPECT_TRUE(
Matches(Legality.createLegalityResult<sandboxir::Widen>(), "Widen"));
More information about the llvm-commits
mailing list