[Mlir-commits] [mlir] d7eba20 - [mlir][Inliner] Refactor the inliner to use nested pass pipelines instead of just canonicalization
River Riddle
llvmlistbot at llvm.org
Mon Dec 14 18:28:29 PST 2020
Author: River Riddle
Date: 2020-12-14T18:09:47-08:00
New Revision: d7eba2005267aa4a8f46f73f208c7cc23e6c6a1a
URL: https://github.com/llvm/llvm-project/commit/d7eba2005267aa4a8f46f73f208c7cc23e6c6a1a
DIFF: https://github.com/llvm/llvm-project/commit/d7eba2005267aa4a8f46f73f208c7cc23e6c6a1a.diff
LOG: [mlir][Inliner] Refactor the inliner to use nested pass pipelines instead of just canonicalization
Now that passes have support for running nested pipelines, the inliner can now allow for users to provide proper nested pipelines to use for optimization during inlining. This revision also changes the behavior of optimization during inlining to optimize before attempting to inline, which should lead to a more accurate cost model and prevents the need for users to schedule additional duplicate cleanup passes before/after the inliner that would already be run during inlining.
Differential Revision: https://reviews.llvm.org/D91211
Added:
Modified:
llvm/include/llvm/ADT/Sequence.h
mlir/include/mlir/Pass/AnalysisManager.h
mlir/include/mlir/Pass/Pass.h
mlir/include/mlir/Pass/PassManager.h
mlir/include/mlir/Transforms/Passes.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassDetail.h
mlir/lib/Pass/PassRegistry.cpp
mlir/lib/Pass/PassTiming.cpp
mlir/lib/Transforms/Inliner.cpp
mlir/test/Dialect/Affine/inlining.mlir
mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
mlir/test/Pass/dynamic-pipeline-nested.mlir
mlir/test/Transforms/inlining.mlir
mlir/test/lib/Transforms/TestDynamicPipeline.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/ADT/Sequence.h b/llvm/include/llvm/ADT/Sequence.h
index 8c505f2010dd..8a695d75f77a 100644
--- a/llvm/include/llvm/ADT/Sequence.h
+++ b/llvm/include/llvm/ADT/Sequence.h
@@ -42,6 +42,10 @@ class value_sequence_iterator
value_sequence_iterator(const value_sequence_iterator &) = default;
value_sequence_iterator(value_sequence_iterator &&Arg)
: Value(std::move(Arg.Value)) {}
+ value_sequence_iterator &operator=(const value_sequence_iterator &Arg) {
+ Value = Arg.Value;
+ return *this;
+ }
template <typename U, typename Enabler = decltype(ValueT(std::declval<U>()))>
value_sequence_iterator(U &&Value) : Value(std::forward<U>(Value)) {}
diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h
index ec6b7696ce60..5da0c95d78dc 100644
--- a/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/mlir/include/mlir/Pass/AnalysisManager.h
@@ -98,7 +98,7 @@ struct AnalysisConcept {
/// A derived analysis model used to hold a specific analysis object.
template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
template <typename... Args>
- explicit AnalysisModel(Args &&... args)
+ explicit AnalysisModel(Args &&...args)
: analysis(std::forward<Args>(args)...) {}
/// A hook used to query analyses for invalidation.
@@ -198,7 +198,10 @@ class AnalysisMap {
/// An analysis map that contains a map for the current operation, and a set of
/// maps for any child operations.
struct NestedAnalysisMap {
- NestedAnalysisMap(Operation *op) : analyses(op) {}
+ NestedAnalysisMap(Operation *op, PassInstrumentor *instrumentor)
+ : analyses(op), parentOrInstrumentor(instrumentor) {}
+ NestedAnalysisMap(Operation *op, NestedAnalysisMap *parent)
+ : analyses(op), parentOrInstrumentor(parent) {}
/// Get the operation for this analysis map.
Operation *getOperation() const { return analyses.getOperation(); }
@@ -206,11 +209,34 @@ struct NestedAnalysisMap {
/// Invalidate any non preserved analyses.
void invalidate(const PreservedAnalyses &pa);
+ /// Returns the parent analysis map for this analysis map, or null if this is
+ /// the top-level map.
+ const NestedAnalysisMap *getParent() const {
+ return parentOrInstrumentor.dyn_cast<NestedAnalysisMap *>();
+ }
+
+ /// Returns a pass instrumentation object for the current operation. This
+ /// value may be null.
+ PassInstrumentor *getPassInstrumentor() const {
+ if (auto *parent = getParent())
+ return parent->getPassInstrumentor();
+ return parentOrInstrumentor.get<PassInstrumentor *>();
+ }
+
/// The cached analyses for nested operations.
DenseMap<Operation *, std::unique_ptr<NestedAnalysisMap>> childAnalyses;
- /// The analyses for the owning module.
+ /// The analyses for the owning operation.
detail::AnalysisMap analyses;
+
+ /// This value has three possible states:
+ /// NestedAnalysisMap*: A pointer to the parent analysis map.
+ /// PassInstrumentor*: This analysis map is the top-level map, and this
+ /// pointer is the optional pass instrumentor for the
+ /// current compilation.
+ /// nullptr: This analysis map is the top-level map, and there is nop pass
+ /// instrumentor.
+ PointerUnion<NestedAnalysisMap *, PassInstrumentor *> parentOrInstrumentor;
};
} // namespace detail
@@ -236,11 +262,11 @@ class AnalysisManager {
template <typename AnalysisT>
Optional<std::reference_wrapper<AnalysisT>>
getCachedParentAnalysis(Operation *parentOp) const {
- ParentPointerT curParent = parent;
- while (auto *parentAM = curParent.dyn_cast<const AnalysisManager *>()) {
- if (parentAM->impl->getOperation() == parentOp)
- return parentAM->getCachedAnalysis<AnalysisT>();
- curParent = parentAM->parent;
+ const detail::NestedAnalysisMap *curParent = impl;
+ while (auto *parentAM = curParent->getParent()) {
+ if (parentAM->getOperation() == parentOp)
+ return parentAM->analyses.getCachedAnalysis<AnalysisT>();
+ curParent = parentAM;
}
return None;
}
@@ -286,7 +312,8 @@ class AnalysisManager {
return it->second->analyses.getCachedAnalysis<AnalysisT>();
}
- /// Get an analysis manager for the given child operation.
+ /// Get an analysis manager for the given operation, which must be a proper
+ /// descendant of the current operation represented by this analysis manager.
AnalysisManager nest(Operation *op);
/// Invalidate any non preserved analyses,
@@ -300,19 +327,15 @@ class AnalysisManager {
/// Returns a pass instrumentation object for the current operation. This
/// value may be null.
- PassInstrumentor *getPassInstrumentor() const;
+ PassInstrumentor *getPassInstrumentor() const {
+ return impl->getPassInstrumentor();
+ }
private:
- AnalysisManager(const AnalysisManager *parent,
- detail::NestedAnalysisMap *impl)
- : parent(parent), impl(impl) {}
- AnalysisManager(const ModuleAnalysisManager *parent,
- detail::NestedAnalysisMap *impl)
- : parent(parent), impl(impl) {}
+ AnalysisManager(detail::NestedAnalysisMap *impl) : impl(impl) {}
- /// A reference to the parent analysis manager, or the top-level module
- /// analysis manager.
- ParentPointerT parent;
+ /// Get an analysis manager for the given immediately nested child operation.
+ AnalysisManager nestImmediate(Operation *op);
/// A reference to the impl analysis map within the parent analysis manager.
detail::NestedAnalysisMap *impl;
@@ -328,23 +351,16 @@ class AnalysisManager {
class ModuleAnalysisManager {
public:
ModuleAnalysisManager(Operation *op, PassInstrumentor *passInstrumentor)
- : analyses(op), passInstrumentor(passInstrumentor) {}
+ : analyses(op, passInstrumentor) {}
ModuleAnalysisManager(const ModuleAnalysisManager &) = delete;
ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete;
- /// Returns a pass instrumentation object for the current module. This value
- /// may be null.
- PassInstrumentor *getPassInstrumentor() const { return passInstrumentor; }
-
/// Returns an analysis manager for the current top-level module.
- operator AnalysisManager() { return AnalysisManager(this, &analyses); }
+ operator AnalysisManager() { return AnalysisManager(&analyses); }
private:
/// The analyses for the owning module.
detail::NestedAnalysisMap analyses;
-
- /// An optional instrumentation object.
- PassInstrumentor *passInstrumentor;
};
} // end namespace mlir
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 87f4f6be5ab0..7a9523714293 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -95,7 +95,7 @@ class Pass {
typename OptionParser = detail::PassOptions::OptionParser<DataType>>
struct Option : public detail::PassOptions::Option<DataType, OptionParser> {
template <typename... Args>
- Option(Pass &parent, StringRef arg, Args &&... args)
+ Option(Pass &parent, StringRef arg, Args &&...args)
: detail::PassOptions::Option<DataType, OptionParser>(
parent.passOptions, arg, std::forward<Args>(args)...) {}
using detail::PassOptions::Option<DataType, OptionParser>::operator=;
@@ -107,14 +107,17 @@ class Pass {
struct ListOption
: public detail::PassOptions::ListOption<DataType, OptionParser> {
template <typename... Args>
- ListOption(Pass &parent, StringRef arg, Args &&... args)
+ ListOption(Pass &parent, StringRef arg, Args &&...args)
: detail::PassOptions::ListOption<DataType, OptionParser>(
parent.passOptions, arg, std::forward<Args>(args)...) {}
using detail::PassOptions::ListOption<DataType, OptionParser>::operator=;
};
/// Attempt to initialize the options of this pass from the given string.
- LogicalResult initializeOptions(StringRef options);
+ /// Derived classes may override this method to hook into the point at which
+ /// options are initialized, but should generally always invoke this base
+ /// class variant.
+ virtual LogicalResult initializeOptions(StringRef options);
/// Prints out the pass in the textual representation of pipelines. If this is
/// an adaptor pass, print with the op_name(sub_pass,...) format.
@@ -265,7 +268,6 @@ class Pass {
void copyOptionValuesFrom(const Pass *other);
private:
-
/// Out of line virtual method to ensure vtables and metadata are emitted to a
/// single .o file.
virtual void anchor();
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 5e9c9a790d29..2715ebd05cac 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -48,8 +48,8 @@ struct PassExecutionState;
class OpPassManager {
public:
enum class Nesting { Implicit, Explicit };
- OpPassManager(Identifier name, Nesting nesting);
- OpPassManager(StringRef name, Nesting nesting);
+ OpPassManager(Identifier name, Nesting nesting = Nesting::Explicit);
+ OpPassManager(StringRef name, Nesting nesting = Nesting::Explicit);
OpPassManager(OpPassManager &&rhs);
OpPassManager(const OpPassManager &rhs);
~OpPassManager();
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index c092d0120b60..6fe66601b21b 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -107,6 +107,19 @@ std::unique_ptr<Pass> createPrintOpStatsPass();
/// Creates a pass which inlines calls and callable operations as defined by
/// the CallGraph.
std::unique_ptr<Pass> createInlinerPass();
+/// Creates an instance of the inliner pass, and use the provided pass managers
+/// when optimizing callable operations with names matching the key type.
+/// Callable operations with a name not within the provided map will use the
+/// default inliner pipeline during optimization.
+std::unique_ptr<Pass>
+createInlinerPass(llvm::StringMap<OpPassManager> opPipelines);
+/// Creates an instance of the inliner pass, and use the provided pass managers
+/// when optimizing callable operations with names matching the key type.
+/// Callable operations with a name not within the provided map will use the
+/// provided default pipeline builder.
+std::unique_ptr<Pass>
+createInlinerPass(llvm::StringMap<OpPassManager> opPipelines,
+ std::function<void(OpPassManager &)> defaultPipelineBuilder);
/// Creates a pass which performs sparse conditional constant propagation over
/// nested operations.
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index afad7cd5852f..438a468673b5 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -285,9 +285,12 @@ def Inliner : Pass<"inline"> {
let summary = "Inline function calls";
let constructor = "mlir::createInlinerPass()";
let options = [
- Option<"disableCanonicalization", "disable-simplify", "bool",
- /*default=*/"false",
- "Disable running simplifications during inlining">,
+ Option<"defaultPipelineStr", "default-pipeline", "std::string",
+ /*default=*/"", "The default optimizer pipeline used for callables">,
+ ListOption<"opPipelineStrs", "op-pipelines", "std::string",
+ "Callable operation specific optimizer pipelines (in the form "
+ "of `dialect.op(pipeline)`)",
+ "llvm::cl::MiscFlags::CommaSeparated">,
Option<"maxInliningIterations", "max-iterations", "unsigned",
/*default=*/"4",
"Maximum number of iterations when inlining within an SCC">,
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index f53a087fac47..d9046bef1469 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -340,22 +340,25 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
// Initialize the pass state with a callback for the pass to dynamically
// execute a pipeline on the currently visited operation.
- auto dynamic_pipeline_callback =
- [op, &am, verifyPasses](OpPassManager &pipeline,
- Operation *root) -> LogicalResult {
+ PassInstrumentor *pi = am.getPassInstrumentor();
+ PassInstrumentation::PipelineParentInfo parentInfo = {llvm::get_threadid(),
+ pass};
+ auto dynamic_pipeline_callback = [&](OpPassManager &pipeline,
+ Operation *root) -> LogicalResult {
if (!op->isAncestor(root))
return root->emitOpError()
<< "Trying to schedule a dynamic pipeline on an "
"operation that isn't "
"nested under the current operation the pass is processing";
+ assert(pipeline.getOpName() == root->getName().getStringRef());
- AnalysisManager nestedAm = am.nest(root);
+ AnalysisManager nestedAm = root == op ? am : am.nest(root);
return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm,
- verifyPasses);
+ verifyPasses, pi, &parentInfo);
};
pass->passState.emplace(op, am, dynamic_pipeline_callback);
+
// Instrument before the pass has run.
- PassInstrumentor *pi = am.getPassInstrumentor();
if (pi)
pi->runBeforePass(pass, op);
@@ -388,7 +391,10 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
/// Run the given operation and analysis manager on a provided op pass manager.
LogicalResult OpToOpPassAdaptor::runPipeline(
iterator_range<OpPassManager::pass_iterator> passes, Operation *op,
- AnalysisManager am, bool verifyPasses) {
+ AnalysisManager am, bool verifyPasses, PassInstrumentor *instrumentor,
+ const PassInstrumentation::PipelineParentInfo *parentInfo) {
+ assert((!instrumentor || parentInfo) &&
+ "expected parent info if instrumentor is provided");
auto scope_exit = llvm::make_scope_exit([&] {
// Clear out any computed operation analyses. These analyses won't be used
// any more in this pipeline, and this helps reduce the current working set
@@ -398,10 +404,13 @@ LogicalResult OpToOpPassAdaptor::runPipeline(
});
// Run the pipeline over the provided operation.
+ if (instrumentor)
+ instrumentor->runBeforePipeline(op->getName().getIdentifier(), *parentInfo);
for (Pass &pass : passes)
if (failed(run(&pass, op, am, verifyPasses)))
return failure();
-
+ if (instrumentor)
+ instrumentor->runAfterPipeline(op->getName().getIdentifier(), *parentInfo);
return success();
}
@@ -491,17 +500,10 @@ void OpToOpPassAdaptor::runOnOperationImpl(bool verifyPasses) {
*op.getContext());
if (!mgr)
continue;
- Identifier opName = mgr->getOpName(*getOperation()->getContext());
// Run the held pipeline over the current operation.
- if (instrumentor)
- instrumentor->runBeforePipeline(opName, parentInfo);
- LogicalResult result =
- runPipeline(mgr->getPasses(), &op, am.nest(&op), verifyPasses);
- if (instrumentor)
- instrumentor->runAfterPipeline(opName, parentInfo);
-
- if (failed(result))
+ if (failed(runPipeline(mgr->getPasses(), &op, am.nest(&op),
+ verifyPasses, instrumentor, &parentInfo)))
return signalPassFailure();
}
}
@@ -576,13 +578,9 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
pms, it.first->getName().getIdentifier(), getContext());
assert(pm && "expected valid pass manager for operation");
- Identifier opName = pm->getOpName(*getOperation()->getContext());
- if (instrumentor)
- instrumentor->runBeforePipeline(opName, parentInfo);
- auto pipelineResult =
- runPipeline(pm->getPasses(), it.first, it.second, verifyPasses);
- if (instrumentor)
- instrumentor->runAfterPipeline(opName, parentInfo);
+ LogicalResult pipelineResult =
+ runPipeline(pm->getPasses(), it.first, it.second, verifyPasses,
+ instrumentor, &parentInfo);
// Drop this thread from being tracked by the diagnostic handler.
// After this task has finished, the thread may be used outside of
@@ -848,22 +846,41 @@ void PassManager::addInstrumentation(std::unique_ptr<PassInstrumentation> pi) {
// AnalysisManager
//===----------------------------------------------------------------------===//
-/// Returns a pass instrumentation object for the current operation.
-PassInstrumentor *AnalysisManager::getPassInstrumentor() const {
- ParentPointerT curParent = parent;
- while (auto *parentAM = curParent.dyn_cast<const AnalysisManager *>())
- curParent = parentAM->parent;
- return curParent.get<const ModuleAnalysisManager *>()->getPassInstrumentor();
+/// Get an analysis manager for the given operation, which must be a proper
+/// descendant of the current operation represented by this analysis manager.
+AnalysisManager AnalysisManager::nest(Operation *op) {
+ Operation *currentOp = impl->getOperation();
+ assert(currentOp->isProperAncestor(op) &&
+ "expected valid descendant operation");
+
+ // Check for the base case where the provided operation is immediately nested.
+ if (currentOp == op->getParentOp())
+ return nestImmediate(op);
+
+ // Otherwise, we need to collect all ancestors up to the current operation.
+ SmallVector<Operation *, 4> opAncestors;
+ do {
+ opAncestors.push_back(op);
+ op = op->getParentOp();
+ } while (op != currentOp);
+
+ AnalysisManager result = *this;
+ for (Operation *op : llvm::reverse(opAncestors))
+ result = result.nestImmediate(op);
+ return result;
}
-/// Get an analysis manager for the given child operation.
-AnalysisManager AnalysisManager::nest(Operation *op) {
+/// Get an analysis manager for the given immediately nested child operation.
+AnalysisManager AnalysisManager::nestImmediate(Operation *op) {
+ assert(impl->getOperation() == op->getParentOp() &&
+ "expected immediate child operation");
+
auto it = impl->childAnalyses.find(op);
if (it == impl->childAnalyses.end())
it = impl->childAnalyses
- .try_emplace(op, std::make_unique<NestedAnalysisMap>(op))
+ .try_emplace(op, std::make_unique<NestedAnalysisMap>(op, impl))
.first;
- return {this, it->second.get()};
+ return {it->second.get()};
}
/// Invalidate any non preserved analyses.
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index d888d570854e..2533d877fc00 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -60,9 +60,11 @@ class OpToOpPassAdaptor
/// Run the given operation and analysis manager on a provided op pass
/// manager.
- static LogicalResult
- runPipeline(iterator_range<OpPassManager::pass_iterator> passes,
- Operation *op, AnalysisManager am, bool verifyPasses);
+ static LogicalResult runPipeline(
+ iterator_range<OpPassManager::pass_iterator> passes, Operation *op,
+ AnalysisManager am, bool verifyPasses,
+ PassInstrumentor *instrumentor = nullptr,
+ const PassInstrumentation::PipelineParentInfo *parentInfo = nullptr);
/// A set of adaptors to run.
SmallVector<OpPassManager, 1> mgrs;
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
index 78e40d5b0aa7..50cbee8dc12d 100644
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -291,11 +291,15 @@ class TextualPipeline {
/// given to enable accurate error reporting.
LogicalResult TextualPipeline::initialize(StringRef text,
raw_ostream &errorStream) {
+ if (text.empty())
+ return success();
+
// Build a source manager to use for error reporting.
llvm::SourceMgr pipelineMgr;
- pipelineMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(
- text, "MLIR Textual PassPipeline Parser"),
- llvm::SMLoc());
+ pipelineMgr.AddNewSourceBuffer(
+ llvm::MemoryBuffer::getMemBuffer(text, "MLIR Textual PassPipeline Parser",
+ /*RequiresNullTerminator=*/false),
+ llvm::SMLoc());
auto errorHandler = [&](const char *rawLoc, Twine msg) {
pipelineMgr.PrintMessage(errorStream, llvm::SMLoc::getFromPointer(rawLoc),
llvm::SourceMgr::DK_Error, msg);
@@ -327,7 +331,7 @@ LogicalResult TextualPipeline::parsePipelineText(StringRef text,
pipeline.emplace_back(/*name=*/text.substr(0, pos).trim());
// If we have a single terminating name, we're done.
- if (pos == text.npos)
+ if (pos == StringRef::npos)
break;
text = text.substr(pos);
@@ -338,9 +342,19 @@ LogicalResult TextualPipeline::parsePipelineText(StringRef text,
text = text.substr(1);
// Skip over everything until the closing '}' and store as options.
- size_t close = text.find('}');
+ size_t close = StringRef::npos;
+ for (unsigned i = 0, e = text.size(), braceCount = 1; i < e; ++i) {
+ if (text[i] == '{') {
+ ++braceCount;
+ continue;
+ }
+ if (text[i] == '}' && --braceCount == 0) {
+ close = i;
+ break;
+ }
+ }
- // TODO: Handle skipping over quoted sub-strings.
+ // Check to see if a closing options brace was found.
if (close == StringRef::npos) {
return errorHandler(
/*rawLoc=*/text.data() - 1,
diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp
index e3978751c11c..499887559595 100644
--- a/mlir/lib/Pass/PassTiming.cpp
+++ b/mlir/lib/Pass/PassTiming.cpp
@@ -302,16 +302,13 @@ void PassTiming::startAnalysisTimer(StringRef name, TypeID id) {
void PassTiming::runAfterPass(Pass *pass, Operation *) {
Timer *timer = popLastActiveTimer();
- // If this is a pass adaptor, then we need to merge in the timing data for the
- // pipelines running on other threads.
- if (isa<OpToOpPassAdaptor>(pass)) {
- auto toMerge = pipelinesToMerge.find({llvm::get_threadid(), pass});
- if (toMerge != pipelinesToMerge.end()) {
- for (auto &it : toMerge->second)
- timer->mergeChild(std::move(it));
- pipelinesToMerge.erase(toMerge);
- }
- return;
+ // Check to see if we need to merge in the timing data for the pipelines
+ // running on other threads.
+ auto toMerge = pipelinesToMerge.find({llvm::get_threadid(), pass});
+ if (toMerge != pipelinesToMerge.end()) {
+ for (auto &it : toMerge->second)
+ timer->mergeChild(std::move(it));
+ pipelinesToMerge.erase(toMerge);
}
timer->stop();
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 64c7ca86dc1e..364af20c0695 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -15,9 +15,8 @@
#include "PassDetail.h"
#include "mlir/Analysis/CallGraph.h"
-#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
@@ -28,6 +27,11 @@
using namespace mlir;
+/// This function implements the default inliner optimization pipeline.
+static void defaultInlinerOptPipeline(OpPassManager &pm) {
+ pm.addPass(createCanonicalizerPass());
+}
+
//===----------------------------------------------------------------------===//
// Symbol Use Tracking
//===----------------------------------------------------------------------===//
@@ -279,9 +283,9 @@ class CallGraphSCC {
/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
-static void
-runTransformOnCGSCCs(const CallGraph &cg,
- function_ref<void(CallGraphSCC &)> sccTransformer) {
+static LogicalResult runTransformOnCGSCCs(
+ const CallGraph &cg,
+ function_ref<LogicalResult(CallGraphSCC &)> sccTransformer) {
llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
CallGraphSCC currentSCC(cgi);
while (!cgi.isAtEnd()) {
@@ -289,8 +293,10 @@ runTransformOnCGSCCs(const CallGraph &cg,
// SCC without invalidating our iterator.
currentSCC.reset(*cgi);
++cgi;
- sccTransformer(currentSCC);
+ if (failed(sccTransformer(currentSCC)))
+ return failure();
}
+ return success();
}
namespace {
@@ -499,85 +505,94 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
return success(inlinedAnyCalls);
}
-/// Canonicalize the nodes within the given SCC with the given set of
-/// canonicalization patterns.
-static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
- CallGraphSCC &currentSCC, MLIRContext *context,
- const FrozenRewritePatternList &canonPatterns) {
- // Collect the sets of nodes to canonicalize.
- SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
- for (auto *node : currentSCC) {
- // Don't canonicalize the external node, it has no valid callable region.
- if (node->isExternal())
- continue;
-
- // Don't canonicalize nodes with children. Nodes with children
- // require special handling as we may remove the node during
- // canonicalization. In the future, we should be able to handle this
- // case with proper node deletion tracking.
- if (node->hasChildren())
- continue;
-
- // We also won't apply canonicalizations for nodes that are not
- // isolated. This avoids potentially mutating the regions of nodes defined
- // above, this is also a stipulation of the 'applyPatternsAndFoldGreedily'
- // driver.
- auto *region = node->getCallableRegion();
- if (!region->getParentOp()->isKnownIsolatedFromAbove())
- continue;
- nodesToCanonicalize.push_back(node);
- }
- if (nodesToCanonicalize.empty())
- return;
-
- // Canonicalize each of the nodes within the SCC in parallel.
- // NOTE: This is simple now, because we don't enable canonicalizing nodes
- // within children. When we remove this restriction, this logic will need to
- // be reworked.
- if (context->isMultithreadingEnabled()) {
- ParallelDiagnosticHandler canonicalizationHandler(context);
- llvm::parallelForEachN(
- /*Begin=*/0, /*End=*/nodesToCanonicalize.size(), [&](size_t index) {
- // Set the order for this thread so that diagnostics will be properly
- // ordered.
- canonicalizationHandler.setOrderIDForThread(index);
-
- // Apply the canonicalization patterns to this region.
- auto *node = nodesToCanonicalize[index];
- applyPatternsAndFoldGreedily(*node->getCallableRegion(),
- canonPatterns);
-
- // Make sure to reset the order ID for the diagnostic handler, as this
- // thread may be used in a different context.
- canonicalizationHandler.eraseOrderIDForThread();
- });
- } else {
- for (CallGraphNode *node : nodesToCanonicalize)
- applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns);
- }
-
- // Recompute the uses held by each of the nodes.
- for (CallGraphNode *node : nodesToCanonicalize)
- useList.recomputeUses(node, cg);
-}
-
//===----------------------------------------------------------------------===//
// InlinerPass
//===----------------------------------------------------------------------===//
namespace {
-struct InlinerPass : public InlinerBase<InlinerPass> {
+class InlinerPass : public InlinerBase<InlinerPass> {
+public:
+ InlinerPass();
+ InlinerPass(const InlinerPass &) = default;
+ InlinerPass(std::function<void(OpPassManager &)> defaultPipeline);
+ InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
+ llvm::StringMap<OpPassManager> opPipelines);
void runOnOperation() override;
- /// Attempt to inline calls within the given scc, and run canonicalizations
- /// with the given patterns, until a fixed point is reached. This allows for
- /// the inlining of newly devirtualized calls.
- void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC,
- MLIRContext *context,
- const FrozenRewritePatternList &canonPatterns);
+private:
+ /// Attempt to inline calls within the given scc, and run simplifications,
+ /// until a fixed point is reached. This allows for the inlining of newly
+ /// devirtualized calls. Returns failure if there was a fatal error during
+ /// inlining.
+ LogicalResult inlineSCC(Inliner &inliner, CGUseList &useList,
+ CallGraphSCC &currentSCC, MLIRContext *context);
+
+ /// Optimize the nodes within the given SCC with one of the held optimization
+ /// pass pipelines. Returns failure if an error occurred during the
+ /// optimization of the SCC, success otherwise.
+ LogicalResult optimizeSCC(CallGraph &cg, CGUseList &useList,
+ CallGraphSCC &currentSCC, MLIRContext *context);
+
+ /// Optimize the nodes within the given SCC in parallel. Returns failure if an
+ /// error occurred during the optimization of the SCC, success otherwise.
+ LogicalResult optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
+ MLIRContext *context);
+
+ /// Optimize the given callable node with one of the pass managers provided
+ /// with `pipelines`, or the default pipeline. Returns failure if an error
+ /// occurred during the optimization of the callable, success otherwise.
+ LogicalResult optimizeCallable(CallGraphNode *node,
+ llvm::StringMap<OpPassManager> &pipelines);
+
+ /// Attempt to initialize the options of this pass from the given string.
+ /// Derived classes may override this method to hook into the point at which
+ /// options are initialized, but should generally always invoke this base
+ /// class variant.
+ LogicalResult initializeOptions(StringRef options) override;
+
+ /// An optional function that constructs a default optimization pipeline for
+ /// a given operation.
+ std::function<void(OpPassManager &)> defaultPipeline;
+ /// A map of operation names to pass pipelines to use when optimizing
+ /// callable operations of these types. This provides a specialized pipeline
+ /// instead of the default. The vector size is the number of threads used
+ /// during optimization.
+ SmallVector<llvm::StringMap<OpPassManager>, 8> opPipelines;
};
} // end anonymous namespace
+InlinerPass::InlinerPass() : InlinerPass(defaultInlinerOptPipeline) {}
+InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline)
+ : defaultPipeline(defaultPipeline) {
+ opPipelines.push_back({});
+
+ // Initialize the pass options with the provided arguments.
+ if (defaultPipeline) {
+ OpPassManager fakePM("__mlir_fake_pm_op");
+ defaultPipeline(fakePM);
+ llvm::raw_string_ostream strStream(defaultPipelineStr);
+ fakePM.printAsTextualPipeline(strStream);
+ }
+}
+
+InlinerPass::InlinerPass(std::function<void(OpPassManager &)> defaultPipeline,
+ llvm::StringMap<OpPassManager> opPipelines)
+ : InlinerPass(std::move(defaultPipeline)) {
+ if (opPipelines.empty())
+ return;
+
+ // Update the option for the op specific optimization pipelines.
+ for (auto &it : opPipelines) {
+ std::string pipeline;
+ llvm::raw_string_ostream pipelineOS(pipeline);
+ pipelineOS << it.getKey() << "(";
+ it.second.printAsTextualPipeline(pipelineOS);
+ pipelineOS << ")";
+ opPipelineStrs.addValue(pipeline);
+ }
+ this->opPipelines.emplace_back(std::move(opPipelines));
+}
+
void InlinerPass::runOnOperation() {
CallGraph &cg = getAnalysis<CallGraph>();
auto *context = &getContext();
@@ -591,42 +606,190 @@ void InlinerPass::runOnOperation() {
return signalPassFailure();
}
- // Collect a set of canonicalization patterns to use when simplifying
- // callable regions within an SCC.
- OwningRewritePatternList canonPatterns;
- for (auto *op : context->getRegisteredOperations())
- op->getCanonicalizationPatterns(canonPatterns, context);
- FrozenRewritePatternList frozenCanonPatterns(std::move(canonPatterns));
-
// Run the inline transform in post-order over the SCCs in the callgraph.
SymbolTableCollection symbolTable;
Inliner inliner(context, cg, symbolTable);
CGUseList useList(getOperation(), cg, symbolTable);
- runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
- inlineSCC(inliner, useList, scc, context, frozenCanonPatterns);
+ LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
+ return inlineSCC(inliner, useList, scc, context);
});
+ if (failed(result))
+ return signalPassFailure();
// After inlining, make sure to erase any callables proven to be dead.
inliner.eraseDeadCallables();
}
-void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
- CallGraphSCC &currentSCC, MLIRContext *context,
- const FrozenRewritePatternList &canonPatterns) {
- // If we successfully inlined any calls, run some simplifications on the
- // nodes of the scc. Continue attempting to inline until we reach a fixed
- // point, or a maximum iteration count. We canonicalize here as it may
- // devirtualize new calls, as well as give us a better cost model.
+LogicalResult InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
+ CallGraphSCC &currentSCC,
+ MLIRContext *context) {
+ // Continuously simplify and inline until we either reach a fixed point, or
+ // hit the maximum iteration count. Simplifying early helps to refine the cost
+ // model, and in future iterations may devirtualize new calls.
+ //
+ // Returns failure only if optimizing the SCC fails; not being able to inline
+ // any further calls simply ends the loop.
unsigned iterationCount = 0;
- while (succeeded(inlineCallsInSCC(inliner, useList, currentSCC))) {
- // If we aren't allowing simplifications or the max iteration count was
- // reached, then bail out early.
- if (disableCanonicalization || ++iterationCount >= maxInliningIterations)
+ do {
+ if (failed(optimizeSCC(inliner.cg, useList, currentSCC, context)))
+ return failure();
+ // No calls were inlined: a fixed point has been reached.
+ if (failed(inlineCallsInSCC(inliner, useList, currentSCC)))
break;
- canonicalizeSCC(inliner.cg, useList, currentSCC, context, canonPatterns);
+ } while (++iterationCount < maxInliningIterations);
+ return success();
+}
+
+/// Run the optimization pipelines over the nodes of `currentSCC` that can be
+/// safely optimized, then recompute the uses recorded for those nodes in
+/// `useList`. Returns failure if any pipeline fails.
+LogicalResult InlinerPass::optimizeSCC(CallGraph &cg, CGUseList &useList,
+                                       CallGraphSCC &currentSCC,
+                                       MLIRContext *context) {
+  // Collect the sets of nodes to simplify.
+  SmallVector<CallGraphNode *, 4> nodesToVisit;
+  for (auto *node : currentSCC) {
+    if (node->isExternal())
+      continue;
+
+    // Don't simplify nodes with children. Nodes with children require special
+    // handling as we may remove the node during simplification. In the future,
+    // we should be able to handle this case with proper node deletion tracking.
+    if (node->hasChildren())
+      continue;
+
+    // We also won't apply simplifications to nodes that can't have passes
+    // scheduled on them.
+    auto *region = node->getCallableRegion();
+    if (!region->getParentOp()->isKnownIsolatedFromAbove())
+      continue;
+    nodesToVisit.push_back(node);
+  }
+  if (nodesToVisit.empty())
+    return success();
+
+  // Optimize each of the nodes within the SCC, in parallel when the context
+  // allows multithreading.
+  // NOTE: This is simple now, because we don't enable optimizing nodes within
+  // children. When we remove this restriction, this logic will need to be
+  // reworked.
+  if (context->isMultithreadingEnabled()) {
+    if (failed(optimizeSCCAsync(nodesToVisit, context)))
+      return failure();
+  } else {
+    // Otherwise, we are optimizing within a single thread, using the first
+    // map of op-specific pipelines.
+    for (CallGraphNode *node : nodesToVisit) {
+      if (failed(optimizeCallable(node, opPipelines[0])))
+        return failure();
+    }
+  }
+
+  // Recompute the uses held by each of the nodes.
+  for (CallGraphNode *node : nodesToVisit)
+    useList.recomputeUses(node, cg);
+  return success();
+}
+
+/// Optimize `nodesToVisit` concurrently. One work item is created per thread
+/// reported by `llvm::hardware_concurrency()`, and work item `index` only ever
+/// uses `opPipelines[index]`, so the per-thread pipeline maps are not shared
+/// between work items. Returns failure if optimizing any node failed.
+LogicalResult
+InlinerPass::optimizeSCCAsync(MutableArrayRef<CallGraphNode *> nodesToVisit,
+ MLIRContext *context) {
+ // Ensure that there are enough pipeline maps for the optimizer to run in
+ // parallel.
+ size_t numThreads = llvm::hardware_concurrency().compute_thread_count();
+ if (opPipelines.size() != numThreads) {
+ // Reserve before resizing so that we can use a reference to the first
+ // element.
+ opPipelines.reserve(numThreads);
+ opPipelines.resize(numThreads, opPipelines.front());
+ }
+
+ // Ensure an analysis manager has been constructed for each of the nodes.
+ // This prevents thread races when running the nested pipelines.
+ for (CallGraphNode *node : nodesToVisit)
+ getAnalysisManager().nest(node->getCallableRegion()->getParentOp());
+
+ // An index for the current node to optimize. It is shared by all work items;
+ // each one atomically claims the next unclaimed node.
+ std::atomic<unsigned> nodeIt(0);
+
+ // Optimize the nodes of the SCC in parallel.
+ ParallelDiagnosticHandler optimizerHandler(context);
+ return llvm::parallelTransformReduce(
+ llvm::seq<size_t>(0, numThreads), success(),
+ // Combine per-work-item results: fail if any work item failed.
+ [](LogicalResult lhs, LogicalResult rhs) {
+ return success(succeeded(lhs) && succeeded(rhs));
+ },
+ [&](size_t index) {
+ LogicalResult result = success();
+ // `nodeIt < e` is only an early exit; the claimed index is re-checked
+ // below because other work items may advance `nodeIt` concurrently.
+ // A work item stops claiming nodes after its first failure.
+ for (auto e = nodesToVisit.size(); nodeIt < e && succeeded(result);) {
+ // Get the next available operation index.
+ unsigned nextID = nodeIt++;
+ if (nextID >= e)
+ break;
+
+ // Set the order for this thread so that diagnostics will be
+ // properly ordered, and reset after optimization has finished.
+ optimizerHandler.setOrderIDForThread(nextID);
+ result = optimizeCallable(nodesToVisit[nextID], opPipelines[index]);
+ optimizerHandler.eraseOrderIDForThread();
+ }
+ return result;
+ });
+}
+
+/// Run the pipeline registered in `pipelines` for the operation name of the
+/// callable held by `node`. When no pipeline is registered, one is built from
+/// the default pipeline builder and cached in `pipelines`; without a default
+/// builder the callable is left untouched and success is returned.
+LogicalResult
+InlinerPass::optimizeCallable(CallGraphNode *node,
+                              llvm::StringMap<OpPassManager> &pipelines) {
+  Operation *callableOp = node->getCallableRegion()->getParentOp();
+  StringRef name = callableOp->getName().getStringRef();
+
+  // Fast path: a specialized pipeline (or a previously cached default) exists.
+  auto entry = pipelines.find(name);
+  if (entry != pipelines.end())
+    return runPipeline(entry->second, callableOp);
+
+  // No pipeline is available for this operation, and nothing to build one from.
+  if (!defaultPipeline)
+    return success();
+
+  // Build the default pipeline for this operation name and cache it.
+  OpPassManager builtPM(name);
+  defaultPipeline(builtPM);
+  entry = pipelines.try_emplace(name, std::move(builtPM)).first;
+  return runPipeline(entry->second, callableOp);
}
+
+/// Initialize the pass from a textual option string. This rebuilds the default
+/// pipeline builder from `default-pipeline` and the op-specific pipelines from
+/// `op-pipelines`. Returns failure if the options or any pipeline fail to parse.
+LogicalResult InlinerPass::initializeOptions(StringRef options) {
+  if (failed(Pass::initializeOptions(options)))
+    return failure();
+
+  // Initialize the default pipeline builder to use the option string. An
+  // explicitly provided empty string clears the builder, which disables
+  // optimization of callables without an op-specific pipeline.
+  if (!defaultPipelineStr.empty()) {
+    std::string defaultPipelineCopy = defaultPipelineStr;
+    defaultPipeline = [=](OpPassManager &pm) {
+      parsePassPipeline(defaultPipelineCopy, pm);
+    };
+  } else if (defaultPipelineStr.getNumOccurrences()) {
+    defaultPipeline = nullptr;
+  }
+
+  // Initialize the op specific pass pipelines.
+  llvm::StringMap<OpPassManager> pipelines;
+  for (StringRef pipeline : opPipelineStrs) {
+    // Pipelines are expected to be of the form `<op-name>(<pipeline>)`.
+    StringRef fullPipeline = pipeline;
+    size_t pipelineStart = pipeline.find('(');
+    if (pipelineStart == 0 || pipelineStart == StringRef::npos ||
+        !pipeline.consume_back(")")) {
+      llvm::errs() << "expected op pipeline of the form "
+                      "`<op-name>(<pipeline>)`, but got: '"
+                   << fullPipeline << "'\n";
+      return failure();
+    }
+    StringRef opName = pipeline.take_front(pipelineStart);
+    OpPassManager pm(opName);
+    if (failed(parsePassPipeline(pipeline.drop_front(1 + pipelineStart), pm)))
+      return failure();
+    pipelines.try_emplace(opName, std::move(pm));
+  }
+
+  // Reset to a single map of pipelines. Move it in directly: `assign({...})`
+  // goes through an initializer_list, which would deep-copy every pipeline.
+  opPipelines.clear();
+  opPipelines.push_back(std::move(pipelines));
+
+  return success();
}
std::unique_ptr<Pass> mlir::createInlinerPass() {
return std::make_unique<InlinerPass>();
}
+/// Create an inliner that uses `opPipelines` for callables whose operation name
+/// matches a key, and the default optimization pipeline for all others.
+std::unique_ptr<Pass>
+mlir::createInlinerPass(llvm::StringMap<OpPassManager> opPipelines) {
+  return std::make_unique<InlinerPass>(defaultInlinerOptPipeline,
+                                       std::move(opPipelines));
+}
+/// Create an inliner that uses `opPipelines` for callables whose operation name
+/// matches a key, and `defaultPipelineBuilder` for all others.
+///
+/// Qualified with `mlir::` like the other overloads: an unqualified definition
+/// would introduce an unrelated global function and leave the overload
+/// declared in the `mlir` namespace (presumably in mlir/Transforms/Passes.h)
+/// without a definition.
+std::unique_ptr<Pass> mlir::createInlinerPass(
+    llvm::StringMap<OpPassManager> opPipelines,
+    std::function<void(OpPassManager &)> defaultPipelineBuilder) {
+  return std::make_unique<InlinerPass>(std::move(defaultPipelineBuilder),
+                                       std::move(opPipelines));
+}
diff --git a/mlir/test/Dialect/Affine/inlining.mlir b/mlir/test/Dialect/Affine/inlining.mlir
index 5879acdeaedb..173e48cc19e5 100644
--- a/mlir/test/Dialect/Affine/inlining.mlir
+++ b/mlir/test/Dialect/Affine/inlining.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -inline="disable-simplify" | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -inline="default-pipeline=''" | FileCheck %s
// Basic test that functions within affine operations are inlined.
func @func_with_affine_ops(%N: index) {
diff --git a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
index 36b1f8fd8a31..983fc2611223 100644
--- a/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/inlining.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -pass-pipeline='spv.module(inline{disable-simplify})' | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -pass-pipeline='spv.module(inline{default-pipeline=''})' | FileCheck %s
spv.module Logical GLSL450 {
spv.func @callee() "None" {
diff --git a/mlir/test/Pass/dynamic-pipeline-nested.mlir b/mlir/test/Pass/dynamic-pipeline-nested.mlir
index 9e0945b28e06..a1ba9ccaac47 100644
--- a/mlir/test/Pass/dynamic-pipeline-nested.mlir
+++ b/mlir/test/Pass/dynamic-pipeline-nested.mlir
@@ -20,9 +20,9 @@ module @inner_mod1 {
// CHECK: Dump Before CSE
// NOTNESTED-NEXT: @inner_mod1
// NESTED-NEXT: @foo
- func private @foo()
+ module @foo {}
// Only in the nested case we have a second run of the pass here.
// NESTED: Dump Before CSE
// NESTED-NEXT: @baz
- func private @baz()
+ module @baz {}
}
diff --git a/mlir/test/Transforms/inlining.mlir b/mlir/test/Transforms/inlining.mlir
index be9aa9cfc55b..d568be0429a9 100644
--- a/mlir/test/Transforms/inlining.mlir
+++ b/mlir/test/Transforms/inlining.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -inline="disable-simplify" | FileCheck %s
-// RUN: mlir-opt %s -inline="disable-simplify" -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s --check-prefix INLINE-LOC
+// RUN: mlir-opt %s -inline='default-pipeline=''' | FileCheck %s
+// RUN: mlir-opt %s -inline='default-pipeline=''' -mlir-print-debuginfo -mlir-print-local-scope | FileCheck %s --check-prefix INLINE-LOC
// RUN: mlir-opt %s -inline | FileCheck %s --check-prefix INLINE_SIMPLIFY
// Inline a function that takes an argument.
diff --git a/mlir/test/lib/Transforms/TestDynamicPipeline.cpp b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp
index 57c5a598dbe4..a6a83dd9b369 100644
--- a/mlir/test/lib/Transforms/TestDynamicPipeline.cpp
+++ b/mlir/test/lib/Transforms/TestDynamicPipeline.cpp
@@ -35,15 +35,17 @@ class TestDynamicPipelinePass
TestDynamicPipelinePass(const TestDynamicPipelinePass &) {}
void runOnOperation() override {
+ Operation *currentOp = getOperation();
+
llvm::errs() << "Dynamic execute '" << pipeline << "' on "
- << getOperation()->getName() << "\n";
+ << currentOp->getName() << "\n";
if (pipeline.empty()) {
llvm::errs() << "Empty pipeline\n";
return;
}
- auto symbolOp = dyn_cast<SymbolOpInterface>(getOperation());
+ auto symbolOp = dyn_cast<SymbolOpInterface>(currentOp);
if (!symbolOp) {
- getOperation()->emitWarning()
+ currentOp->emitWarning()
<< "Ignoring because not implementing SymbolOpInterface\n";
return;
}
@@ -54,24 +56,24 @@ class TestDynamicPipelinePass
return;
}
if (!pm) {
- pm = std::make_unique<OpPassManager>(
- getOperation()->getName().getIdentifier(),
- OpPassManager::Nesting::Implicit);
+ pm = std::make_unique<OpPassManager>(currentOp->getName().getIdentifier(),
+ OpPassManager::Nesting::Implicit);
parsePassPipeline(pipeline, *pm, llvm::errs());
}
// Check that running on the parent operation always immediately fails.
if (runOnParent) {
- if (getOperation()->getParentOp())
- if (!failed(runPipeline(*pm, getOperation()->getParentOp())))
+ if (currentOp->getParentOp())
+ if (!failed(runPipeline(*pm, currentOp->getParentOp())))
signalPassFailure();
return;
}
if (runOnNestedOp) {
llvm::errs() << "Run on nested op\n";
- getOperation()->walk([&](Operation *op) {
- if (op == getOperation() || !op->isKnownIsolatedFromAbove())
+ currentOp->walk([&](Operation *op) {
+ if (op == currentOp || !op->isKnownIsolatedFromAbove() ||
+ op->getName() != currentOp->getName())
return;
llvm::errs() << "Run on " << *op << "\n";
// Run on the current operation
@@ -80,7 +82,7 @@ class TestDynamicPipelinePass
});
} else {
// Run on the current operation
- if (failed(runPipeline(*pm, getOperation())))
+ if (failed(runPipeline(*pm, currentOp)))
signalPassFailure();
}
}
More information about the Mlir-commits
mailing list