[Mlir-commits] [mlir] 56a6985 - [mlir][Pass][NFC] Merge OpToOpPassAdaptor and OpToOpPassAdaptorParallel
River Riddle
llvmlistbot at llvm.org
Wed Apr 29 15:23:22 PDT 2020
Author: River Riddle
Date: 2020-04-29T15:23:10-07:00
New Revision: 56a698510faef5bf3ef224c229a049bb1e376a56
URL: https://github.com/llvm/llvm-project/commit/56a698510faef5bf3ef224c229a049bb1e376a56
DIFF: https://github.com/llvm/llvm-project/commit/56a698510faef5bf3ef224c229a049bb1e376a56.diff
LOG: [mlir][Pass][NFC] Merge OpToOpPassAdaptor and OpToOpPassAdaptorParallel
This moves the threading check to runOnOperation. This produces a much cleaner interface for the adaptor pass, and will allow for the ability to enable/disable threading in a much cleaner way in the future.
Differential Revision: https://reviews.llvm.org/D78313
Added:
Modified:
mlir/lib/Pass/IRPrinting.cpp
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassDetail.h
mlir/lib/Pass/PassStatistics.cpp
mlir/lib/Pass/PassTiming.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp
index 679a9ec27ead..ba9ff989ea79 100644
--- a/mlir/lib/Pass/IRPrinting.cpp
+++ b/mlir/lib/Pass/IRPrinting.cpp
@@ -99,7 +99,7 @@ class IRPrinterInstrumentation : public PassInstrumentation {
/// Returns true if the given pass is hidden from IR printing.
static bool isHiddenPass(Pass *pass) {
- return isAdaptorPass(pass) || isa<VerifierPass>(pass);
+ return isa<OpToOpPassAdaptor>(pass) || isa<VerifierPass>(pass);
}
static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
@@ -173,7 +173,7 @@ void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
}
void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
- if (isAdaptorPass(pass))
+ if (isa<OpToOpPassAdaptor>(pass))
return;
if (config->shouldPrintAfterOnlyOnChange())
beforePassFingerPrints.erase(pass);
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 53ccd4f005a4..b6bef48cb3ec 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -51,7 +51,7 @@ void Pass::copyOptionValuesFrom(const Pass *other) {
/// an adaptor pass, print with the op_name(sub_pass,...) format.
void Pass::printAsTextualPipeline(raw_ostream &os) {
// Special case for adaptors to use the 'op_name(sub_passes)' format.
- if (auto *adaptor = getAdaptorPassBase(this)) {
+ if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(this)) {
llvm::interleaveComma(adaptor->getPassManagers(), os,
[&](OpPassManager &pm) {
os << pm.getOpName() << "(";
@@ -152,15 +152,15 @@ struct OpPassManagerImpl {
void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() {
// Bail out early if there are no adaptor passes.
if (llvm::none_of(passes, [](std::unique_ptr<Pass> &pass) {
- return isAdaptorPass(pass.get());
+ return isa<OpToOpPassAdaptor>(pass.get());
}))
return;
// Walk the pass list and merge adjacent adaptors.
- OpToOpPassAdaptorBase *lastAdaptor = nullptr;
+ OpToOpPassAdaptor *lastAdaptor = nullptr;
for (auto it = passes.begin(), e = passes.end(); it != e; ++it) {
// Check to see if this pass is an adaptor.
- if (auto *currentAdaptor = getAdaptorPassBase(it->get())) {
+ if (auto *currentAdaptor = dyn_cast<OpToOpPassAdaptor>(it->get())) {
// If it is the first adaptor in a possible chain, remember it and
// continue.
if (!lastAdaptor) {
@@ -243,16 +243,7 @@ LogicalResult OpPassManager::run(Operation *op, AnalysisManager am) {
/// pass manager.
OpPassManager &OpPassManager::nest(const OperationName &nestedName) {
OpPassManager nested(nestedName, impl->disableThreads, impl->verifyPasses);
-
- /// Create an adaptor for this pass. If multi-threading is disabled, then
- /// create a synchronous adaptor.
- if (impl->disableThreads || !llvm::llvm_is_multithreaded()) {
- auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
- addPass(std::unique_ptr<Pass>(adaptor));
- return adaptor->getPassManagers().front();
- }
-
- auto *adaptor = new OpToOpPassAdaptorParallel(std::move(nested));
+ auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
addPass(std::unique_ptr<Pass>(adaptor));
return adaptor->getPassManagers().front();
}
@@ -330,12 +321,12 @@ static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
return it == mgrs.end() ? nullptr : &*it;
}
-OpToOpPassAdaptorBase::OpToOpPassAdaptorBase(OpPassManager &&mgr) {
+OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
mgrs.emplace_back(std::move(mgr));
}
/// Merge the current pass adaptor into given 'rhs'.
-void OpToOpPassAdaptorBase::mergeInto(OpToOpPassAdaptorBase &rhs) {
+void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
for (auto &pm : mgrs) {
// If an existing pass manager exists, then merge the given pass manager
// into it.
@@ -357,7 +348,7 @@ void OpToOpPassAdaptorBase::mergeInto(OpToOpPassAdaptorBase &rhs) {
}
/// Returns the adaptor pass name.
-std::string OpToOpPassAdaptorBase::getName() {
+std::string OpToOpPassAdaptor::getAdaptorName() {
std::string name = "Pipeline Collection : [";
llvm::raw_string_ostream os(name);
llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) {
@@ -367,11 +358,16 @@ std::string OpToOpPassAdaptorBase::getName() {
return os.str();
}
-OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr)
- : OpToOpPassAdaptorBase(std::move(mgr)) {}
-
/// Run the held pipeline over all nested operations.
void OpToOpPassAdaptor::runOnOperation() {
+ if (mgrs.front().getImpl().disableThreads || !llvm::llvm_is_multithreaded())
+ runOnOperationImpl();
+ else
+ runOnOperationAsyncImpl();
+}
+
+/// Run this pass adaptor synchronously.
+void OpToOpPassAdaptor::runOnOperationImpl() {
auto am = getAnalysisManager();
PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
this};
@@ -397,9 +393,6 @@ void OpToOpPassAdaptor::runOnOperation() {
}
}
-OpToOpPassAdaptorParallel::OpToOpPassAdaptorParallel(OpPassManager &&mgr)
- : OpToOpPassAdaptorBase(std::move(mgr)) {}
-
/// Utility functor that checks if the two ranges of pass managers have a size
/// mismatch.
static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
@@ -409,8 +402,8 @@ static bool hasSizeMismatch(ArrayRef<OpPassManager> lhs,
[&](size_t i) { return lhs[i].size() != rhs[i].size(); });
}
-// Run the held pipeline asynchronously across the functions within the module.
-void OpToOpPassAdaptorParallel::runOnOperation() {
+/// Run this pass adaptor asynchronously.
+void OpToOpPassAdaptor::runOnOperationAsyncImpl() {
AnalysisManager am = getAnalysisManager();
// Create the async executors if they haven't been created, or if the main
@@ -491,16 +484,6 @@ void OpToOpPassAdaptorParallel::runOnOperation() {
signalPassFailure();
}
-/// Utility function to convert the given class to the base adaptor it is an
-/// adaptor pass, returns nullptr otherwise.
-OpToOpPassAdaptorBase *mlir::detail::getAdaptorPassBase(Pass *pass) {
- if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
- return adaptor;
- if (auto *adaptor = dyn_cast<OpToOpPassAdaptorParallel>(pass))
- return adaptor;
- return nullptr;
-}
-
//===----------------------------------------------------------------------===//
// PassCrashReproducer
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index 59d9a7a0576f..2342a1a7af97 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -27,69 +27,44 @@ class VerifierPass : public PassWrapper<VerifierPass, OperationPass<>> {
// OpToOpPassAdaptor
//===----------------------------------------------------------------------===//
-/// A base class for Op-to-Op adaptor passes.
-class OpToOpPassAdaptorBase {
-public:
- OpToOpPassAdaptorBase(OpPassManager &&mgr);
- OpToOpPassAdaptorBase(const OpToOpPassAdaptorBase &rhs) = default;
-
- /// Merge the current pass adaptor into given 'rhs'.
- void mergeInto(OpToOpPassAdaptorBase &rhs);
-
- /// Returns the pass managers held by this adaptor.
- MutableArrayRef<OpPassManager> getPassManagers() { return mgrs; }
-
- /// Returns the adaptor pass name.
- std::string getName();
-
-protected:
- // A set of adaptors to run.
- SmallVector<OpPassManager, 1> mgrs;
-};
-
-/// An adaptor pass used to run operation passes over nested operations
-/// synchronously on a single thread.
+/// An adaptor pass used to run operation passes over nested operations.
class OpToOpPassAdaptor
- : public PassWrapper<OpToOpPassAdaptor, OperationPass<>>,
- public OpToOpPassAdaptorBase {
+ : public PassWrapper<OpToOpPassAdaptor, OperationPass<>> {
public:
OpToOpPassAdaptor(OpPassManager &&mgr);
+ OpToOpPassAdaptor(const OpToOpPassAdaptor &rhs) = default;
/// Run the held pipeline over all operations.
void runOnOperation() override;
-};
-/// An adaptor pass used to run operation passes over nested operations
-/// asynchronously across multiple threads.
-class OpToOpPassAdaptorParallel
- : public PassWrapper<OpToOpPassAdaptorParallel, OperationPass<>>,
- public OpToOpPassAdaptorBase {
-public:
- OpToOpPassAdaptorParallel(OpPassManager &&mgr);
+ /// Merge the current pass adaptor into given 'rhs'.
+ void mergeInto(OpToOpPassAdaptor &rhs);
- /// Run the held pipeline over all operations.
- void runOnOperation() override;
+ /// Returns the pass managers held by this adaptor.
+ MutableArrayRef<OpPassManager> getPassManagers() { return mgrs; }
/// Return the async pass managers held by this parallel adaptor.
MutableArrayRef<SmallVector<OpPassManager, 1>> getParallelPassManagers() {
return asyncExecutors;
}
+ /// Returns the adaptor pass name.
+ std::string getAdaptorName();
+
private:
- // A set of executors, cloned from the main executor, that run asynchronously
- // on different threads.
- SmallVector<SmallVector<OpPassManager, 1>, 8> asyncExecutors;
-};
+ /// Run this pass adaptor synchronously.
+ void runOnOperationImpl();
+
+ /// Run this pass adaptor asynchronously.
+ void runOnOperationAsyncImpl();
-/// Utility function to convert the given class to the base adaptor it is an
-/// adaptor pass, returns nullptr otherwise.
-OpToOpPassAdaptorBase *getAdaptorPassBase(Pass *pass);
+ /// A set of adaptors to run.
+ SmallVector<OpPassManager, 1> mgrs;
-/// Utility function to return if a pass refers to an adaptor pass. Adaptor
-/// passes are those that internally execute a pipeline.
-inline bool isAdaptorPass(Pass *pass) {
- return isa<OpToOpPassAdaptorParallel>(pass) || isa<OpToOpPassAdaptor>(pass);
-}
+ /// A set of executors, cloned from the main executor, that run asynchronously
+ /// on different threads. This is used when threading is enabled.
+ SmallVector<SmallVector<OpPassManager, 1>, 8> asyncExecutors;
+};
} // end namespace detail
} // end namespace mlir
diff --git a/mlir/lib/Pass/PassStatistics.cpp b/mlir/lib/Pass/PassStatistics.cpp
index 7ac54f7cf1af..6ef0d3bbea6a 100644
--- a/mlir/lib/Pass/PassStatistics.cpp
+++ b/mlir/lib/Pass/PassStatistics.cpp
@@ -60,7 +60,7 @@ static void printPassEntry(raw_ostream &os, unsigned indent, StringRef pass,
static void printResultsAsList(raw_ostream &os, OpPassManager &pm) {
llvm::StringMap<std::vector<Statistic>> mergedStats;
std::function<void(Pass *)> addStats = [&](Pass *pass) {
- auto *adaptor = getAdaptorPassBase(pass);
+ auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass);
// If this is not an adaptor, add the stats to the list if there are any.
if (!adaptor) {
@@ -105,13 +105,12 @@ static void printResultsAsList(raw_ostream &os, OpPassManager &pm) {
static void printResultsAsPipeline(raw_ostream &os, OpPassManager &pm) {
std::function<void(unsigned, Pass *)> printPass = [&](unsigned indent,
Pass *pass) {
- // Handle the case of an adaptor pass.
- if (auto *adaptor = getAdaptorPassBase(pass)) {
+ if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass)) {
// If this adaptor has more than one internal pipeline, print an entry for
// it.
auto mgrs = adaptor->getPassManagers();
if (mgrs.size() > 1) {
- printPassEntry(os, indent, adaptor->getName());
+ printPassEntry(os, indent, adaptor->getAdaptorName());
indent += 2;
}
@@ -195,8 +194,8 @@ void OpPassManager::mergeStatisticsInto(OpPassManager &other) {
Pass &pass = std::get<0>(passPair), &otherPass = std::get<1>(passPair);
// If this is an adaptor, then recursively merge the pass managers.
- if (auto *adaptorPass = getAdaptorPassBase(&pass)) {
- auto *otherAdaptorPass = getAdaptorPassBase(&otherPass);
+ if (auto *adaptorPass = dyn_cast<OpToOpPassAdaptor>(&pass)) {
+ auto *otherAdaptorPass = cast<OpToOpPassAdaptor>(&otherPass);
for (auto mgrs : llvm::zip(adaptorPass->getPassManagers(),
otherAdaptorPass->getPassManagers()))
std::get<0>(mgrs).mergeStatisticsInto(std::get<1>(mgrs));
@@ -217,18 +216,16 @@ void OpPassManager::mergeStatisticsInto(OpPassManager &other) {
/// consumption(e.g. dumping).
static void prepareStatistics(OpPassManager &pm) {
for (Pass &pass : pm.getPasses()) {
- OpToOpPassAdaptorBase *adaptor = getAdaptorPassBase(&pass);
+ OpToOpPassAdaptor *adaptor = dyn_cast<OpToOpPassAdaptor>(&pass);
if (!adaptor)
continue;
MutableArrayRef<OpPassManager> nestedPms = adaptor->getPassManagers();
- // If this is a parallel adaptor, merge the statistics from the async
- // pass managers into the main nested pass managers.
- if (auto *parallelAdaptor = dyn_cast<OpToOpPassAdaptorParallel>(&pass)) {
- for (auto &asyncPM : parallelAdaptor->getParallelPassManagers()) {
- for (unsigned i = 0, e = asyncPM.size(); i != e; ++i)
- asyncPM[i].mergeStatisticsInto(nestedPms[i]);
- }
+ // Merge the statistics from the async pass managers into the main nested
+ // pass managers.
+ for (auto &asyncPM : adaptor->getParallelPassManagers()) {
+ for (unsigned i = 0, e = asyncPM.size(); i != e; ++i)
+ asyncPM[i].mergeStatisticsInto(nestedPms[i]);
}
// Prepare the statistics of each of the nested passes.
diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp
index c8f0ad8afa50..71bf822a864b 100644
--- a/mlir/lib/Pass/PassTiming.cpp
+++ b/mlir/lib/Pass/PassTiming.cpp
@@ -277,17 +277,17 @@ void PassTiming::runAfterPipeline(const OperationName &name,
/// Start a new timer for the given pass.
void PassTiming::startPassTimer(Pass *pass) {
- auto kind = isAdaptorPass(pass) ? TimerKind::PipelineCollection
- : TimerKind::PassOrAnalysis;
+ auto kind = isa<OpToOpPassAdaptor>(pass) ? TimerKind::PipelineCollection
+ : TimerKind::PassOrAnalysis;
Timer *timer = getTimer(pass, kind, [pass]() -> std::string {
- if (auto *adaptor = getAdaptorPassBase(pass))
- return adaptor->getName();
+ if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
+ return adaptor->getAdaptorName();
return std::string(pass->getName());
});
// We don't actually want to time the adaptor passes, they gather their total
// from their held passes.
- if (!isAdaptorPass(pass))
+ if (!isa<OpToOpPassAdaptor>(pass))
timer->start();
}
@@ -302,9 +302,9 @@ void PassTiming::startAnalysisTimer(StringRef name, TypeID id) {
void PassTiming::runAfterPass(Pass *pass, Operation *) {
Timer *timer = popLastActiveTimer();
- // If this is an OpToOpPassAdaptorParallel, then we need to merge in the
- // timing data for the pipelines running on other threads.
- if (isa<OpToOpPassAdaptorParallel>(pass)) {
+ // If this is a pass adaptor, then we need to merge in the timing data for the
+ // pipelines running on other threads.
+ if (isa<OpToOpPassAdaptor>(pass)) {
auto toMerge = pipelinesToMerge.find({llvm::get_threadid(), pass});
if (toMerge != pipelinesToMerge.end()) {
for (auto &it : toMerge->second)
@@ -314,10 +314,7 @@ void PassTiming::runAfterPass(Pass *pass, Operation *) {
return;
}
- // Adaptor passes aren't timed directly, so we don't need to stop their
- // timers.
- if (!isAdaptorPass(pass))
- timer->stop();
+ timer->stop();
}
/// Stop a timer.
More information about the Mlir-commits
mailing list