[llvm] 747d8f3 - [SandboxVec][DAG] Implement PredIterator (#111604)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 10 12:02:00 PDT 2024
Author: vporpo
Date: 2024-10-10T12:01:56-07:00
New Revision: 747d8f3fc93d912183059142631a343fb20bd07f
URL: https://github.com/llvm/llvm-project/commit/747d8f3fc93d912183059142631a343fb20bd07f
DIFF: https://github.com/llvm/llvm-project/commit/747d8f3fc93d912183059142631a343fb20bd07f.diff
LOG: [SandboxVec][DAG] Implement PredIterator (#111604)
This patch implements PredIterator, an iterator over the predecessors of a
DGNode: its use-def (operand) dependencies and, for MemDGNodes, also its mem
dependencies.
Added:
Modified:
llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
index 134adc4b21ab12..eba6d7562e41de 100644
--- a/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
+++ b/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h
@@ -40,6 +40,54 @@ enum class DGNodeID {
MemDGNode,
};
+class DGNode;
+class MemDGNode;
+class DependencyGraph;
+
+/// While OpIt points to a Value that is not an Instruction keep incrementing
+/// it. \Returns the first iterator in [OpIt, OpItE) that points to an
+/// Instruction, or OpItE if there is none.
+///
+/// This lives in a header, so it is `inline` rather than `static`: a `static`
+/// non-inline definition gives every including TU a private copy and triggers
+/// -Wunused-function (GCC) in TUs that never call it.
+[[nodiscard]] inline User::op_iterator skipNonInstr(User::op_iterator OpIt,
+                                                    User::op_iterator OpItE) {
+  while (OpIt != OpItE && !isa<Instruction>((*OpIt).get()))
+    ++OpIt;
+  return OpIt;
+}
+
+/// Iterates over the predecessors of a DGNode: first its use-def dependencies
+/// (the operands that are Instructions) and then, only for a MemDGNode, its
+/// memory dependencies.
+class PredIterator {
+  /// Current and end positions in the node's operand list.
+  User::op_iterator OpIt;
+  User::op_iterator OpItE;
+  /// Current position in the memory predecessors. Only used when N is a
+  /// MemDGNode; otherwise it is left default-constructed.
+  DenseSet<MemDGNode *>::iterator MemIt;
+  /// The node whose predecessors are being visited.
+  DGNode *N = nullptr;
+  /// Used to map operand Instructions to their DAG nodes on dereference.
+  DependencyGraph *DAG = nullptr;
+
+  /// Used by MemDGNode: visits the operands, then the memory predecessors.
+  PredIterator(const User::op_iterator &OpIt, const User::op_iterator &OpItE,
+               const DenseSet<MemDGNode *>::iterator &MemIt, DGNode *N,
+               DependencyGraph &DAG)
+      : OpIt(OpIt), OpItE(OpItE), MemIt(MemIt), N(N), DAG(&DAG) {}
+  /// Used by DGNode: visits the operands only.
+  PredIterator(const User::op_iterator &OpIt, const User::op_iterator &OpItE,
+               DGNode *N, DependencyGraph &DAG)
+      : OpIt(OpIt), OpItE(OpItE), N(N), DAG(&DAG) {}
+  friend class DGNode;    // For constructor
+  friend class MemDGNode; // For constructor
+
+public:
+  using difference_type = std::ptrdiff_t;
+  using value_type = DGNode *;
+  using pointer = value_type *;
+  using reference = value_type &;
+  using iterator_category = std::input_iterator_tag;
+  /// \Returns the DAG node of the current predecessor.
+  value_type operator*();
+  PredIterator &operator++();
+  PredIterator operator++(int) {
+    auto Copy = *this;
+    ++(*this);
+    return Copy;
+  }
+  /// Only iterators of the same node and DAG may be compared.
+  bool operator==(const PredIterator &Other) const;
+  bool operator!=(const PredIterator &Other) const { return !(*this == Other); }
+};
+
/// A DependencyGraph Node that points to an Instruction and contains memory
/// dependency edges.
class DGNode {
@@ -63,6 +111,23 @@ class DGNode {
virtual ~DGNode() = default;
/// \Returns true if this is before \p Other in program order.
bool comesBefore(const DGNode *Other) { return I->comesBefore(Other->I); }
+  using iterator = PredIterator;
+  /// \Returns an iterator to this node's first predecessor, i.e. its first
+  /// operand that is an Instruction. Virtual so that MemDGNode can also visit
+  /// its memory predecessors.
+  virtual iterator preds_begin(DependencyGraph &DAG) {
+    auto OpEnd = I->op_end();
+    auto FirstInstrOp = skipNonInstr(I->op_begin(), OpEnd);
+    return PredIterator(FirstInstrOp, OpEnd, this, DAG);
+  }
+  /// \Returns the past-the-end predecessor iterator.
+  virtual iterator preds_end(DependencyGraph &DAG) {
+    auto OpEnd = I->op_end();
+    return PredIterator(OpEnd, OpEnd, this, DAG);
+  }
+  // The const overloads forward to the virtual ones above, which only build
+  // an iterator and do not modify the node.
+  iterator preds_begin(DependencyGraph &DAG) const {
+    auto *Self = const_cast<DGNode *>(this);
+    return Self->preds_begin(DAG);
+  }
+  iterator preds_end(DependencyGraph &DAG) const {
+    auto *Self = const_cast<DGNode *>(this);
+    return Self->preds_end(DAG);
+  }
+  /// \Returns the range of all predecessors of this node.
+  iterator_range<iterator> preds(DependencyGraph &DAG) const {
+    iterator Begin = preds_begin(DAG);
+    iterator End = preds_end(DAG);
+    return make_range(Begin, End);
+  }
static bool isStackSaveOrRestoreIntrinsic(Instruction *I) {
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
@@ -145,6 +210,14 @@ class MemDGNode final : public DGNode {
static bool classof(const DGNode *Other) {
return Other->SubclassID == DGNodeID::MemDGNode;
}
+  // The overrides below would otherwise hide DGNode's const
+  // preds_begin()/preds_end() overloads in MemDGNode's scope.
+  using DGNode::preds_begin;
+  using DGNode::preds_end;
+  /// \Returns an iterator to the first predecessor: the operands that are
+  /// Instructions are visited first, followed by the memory predecessors.
+  iterator preds_begin(DependencyGraph &DAG) override {
+    auto OpEndIt = I->op_end();
+    return PredIterator(skipNonInstr(I->op_begin(), OpEndIt), OpEndIt,
+                        MemPreds.begin(), this, DAG);
+  }
+  /// \Returns the iterator past both the operands and the memory
+  /// predecessors.
+  iterator preds_end(DependencyGraph &DAG) override {
+    auto OpEndIt = I->op_end();
+    return PredIterator(OpEndIt, OpEndIt, MemPreds.end(), this, DAG);
+  }
/// \Returns the previous Mem DGNode in instruction order.
MemDGNode *getPrevNode() const { return PrevMemN; }
/// \Returns the next Mem DGNode in instruction order.
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 82f253d4c63231..7aea466ed6d8db 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -8,10 +8,54 @@
#include "llvm/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Utils.h"
namespace llvm::sandboxir {
+/// \Returns the current predecessor. While operands remain, this is the DAG
+/// node of the current Instruction operand. After that (MemDGNodes only) it
+/// is the current memory predecessor.
+PredIterator::value_type PredIterator::operator*() {
+  // Only a MemDGNode has anything left to visit once its operands are done.
+  if (isa<MemDGNode>(N) && OpIt == OpItE) {
+    assert(MemIt != cast<MemDGNode>(N)->memPreds().end() &&
+           "Can't dereference end iterator!");
+    return *MemIt;
+  }
+  assert(OpIt != OpItE && "Can't dereference end iterator!");
+  // skipNonInstr() keeps OpIt on Instruction operands, so the cast holds.
+  return DAG->getNode(cast<Instruction>((*OpIt).get()));
+}
+
+/// Advances to the next predecessor. It walks the Instruction operands first
+/// and then, for MemDGNodes only, the memory predecessors.
+PredIterator &PredIterator::operator++() {
+  // With the operands exhausted, only a MemDGNode still has predecessors to
+  // visit.
+  if (isa<MemDGNode>(N) && OpIt == OpItE) {
+    assert(MemIt != cast<MemDGNode>(N)->memPreds().end() && "Already at end!");
+    ++MemIt;
+    return *this;
+  }
+  // Step to the next operand, skipping operands that are not Instructions.
+  assert(OpIt != OpItE && "Already at end!");
+  ++OpIt;
+  OpIt = skipNonInstr(OpIt, OpItE);
+  return *this;
+}
+
+/// Two iterators are equal when both their operand and memory-predecessor
+/// positions match. Comparing iterators of different nodes or DAGs is a bug.
+bool PredIterator::operator==(const PredIterator &Other) const {
+  assert(DAG == Other.DAG && "Iterators of different DAGs!");
+  assert(N == Other.N && "Iterators of different nodes!");
+  return OpIt == Other.OpIt && MemIt == Other.MemIt;
+}
+
#ifndef NDEBUG
void DGNode::print(raw_ostream &OS, bool PrintDeps) const {
I->dumpOS(OS);
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index e2f16919a5cddd..6b3d9cc77c9955 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -240,12 +240,53 @@ define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
EXPECT_TRUE(N1->hasMemPred(N0));
EXPECT_FALSE(N0->hasMemPred(N1));
+ // Check preds().
+ EXPECT_TRUE(N0->preds(DAG).empty());
+ EXPECT_THAT(N1->preds(DAG), testing::ElementsAre(N0));
+
// Check memPreds().
EXPECT_TRUE(N0->memPreds().empty());
EXPECT_THAT(N1->memPreds(), testing::ElementsAre(N0));
EXPECT_TRUE(N2->memPreds().empty());
}
+TEST_F(DependencyGraphTest, Preds) {
+  parseIR(C, R"IR(
+declare ptr @bar(i8)
+define i8 @foo(i8 %v0, i8 %v1) {
+ %add0 = add i8 %v0, %v0
+ %add1 = add i8 %v1, %v1
+ %add2 = add i8 %add0, %add1
+ %ptr = call ptr @bar(i8 %add1)
+ store i8 %add2, ptr %ptr
+ ret i8 %add2
+}
+)IR");
+  llvm::Function *LLVMF = M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF));
+  DAG.extend({&*BB->begin(), BB->getTerminator()});
+
+  // Look up the DAG node of each instruction, in program order.
+  auto InstIt = BB->begin();
+  auto *AddN0 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*InstIt++));
+  auto *AddN1 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*InstIt++));
+  auto *AddN2 = DAG.getNode(cast<sandboxir::BinaryOperator>(&*InstIt++));
+  auto *CallN = DAG.getNode(cast<sandboxir::CallInst>(&*InstIt++));
+  auto *StN = DAG.getNode(cast<sandboxir::StoreInst>(&*InstIt++));
+  auto *RetN = DAG.getNode(cast<sandboxir::ReturnInst>(&*InstIt++));
+
+  // Function arguments are not instructions, so they are not predecessors.
+  EXPECT_TRUE(AddN0->preds(DAG).empty());
+  EXPECT_TRUE(AddN1->preds(DAG).empty());
+  EXPECT_THAT(AddN2->preds(DAG), testing::ElementsAre(AddN0, AddN1));
+  EXPECT_THAT(CallN->preds(DAG), testing::ElementsAre(AddN1));
+  // The store's operands yield AddN2 and CallN. The second CallN must come
+  // from the store's memory predecessors.
+  EXPECT_THAT(StN->preds(DAG),
+              testing::UnorderedElementsAre(CallN, CallN, AddN2));
+  EXPECT_THAT(RetN->preds(DAG), testing::ElementsAre(AddN2));
+}
+
TEST_F(DependencyGraphTest, MemDGNode_getPrevNode_getNextNode) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %v0, i8 %v1) {
More information about the llvm-commits
mailing list