[Mlir-commits] [mlir] f406adf - Add capture of "IRUnits" as context for an MLIR Action
Mehdi Amini
llvmlistbot at llvm.org
Mon Mar 20 05:47:31 PDT 2023
Author: Mehdi Amini
Date: 2023-03-20T13:40:55+01:00
New Revision: f406adf134c2f81747bbc653b1399656268fe17a
URL: https://github.com/llvm/llvm-project/commit/f406adf134c2f81747bbc653b1399656268fe17a
DIFF: https://github.com/llvm/llvm-project/commit/f406adf134c2f81747bbc653b1399656268fe17a.diff
LOG: Add capture of "IRUnits" as context for an MLIR Action
IRUnit is defined as:
using IRUnit = PointerUnion<Operation *, Region *, Block *, Value>;
The tracing::Action is extended to take an ArrayRef<IRUnit> as context to
describe an Action. It is demonstrated in the "ActionLogging" observer.
Reviewed By: rriddle, Mogball
Differential Revision: https://reviews.llvm.org/D144814
Added:
mlir/include/mlir/IR/Unit.h
mlir/lib/IR/Unit.cpp
Modified:
mlir/include/mlir/Debug/Observers/ActionLogging.h
mlir/include/mlir/IR/Action.h
mlir/include/mlir/IR/MLIRContext.h
mlir/lib/Debug/Observers/ActionLogging.cpp
mlir/lib/IR/CMakeLists.txt
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassDetail.h
mlir/test/Pass/action-logging.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Debug/Observers/ActionLogging.h b/mlir/include/mlir/Debug/Observers/ActionLogging.h
index ff280c59da9ce..bd1d56538906a 100644
--- a/mlir/include/mlir/Debug/Observers/ActionLogging.h
+++ b/mlir/include/mlir/Debug/Observers/ActionLogging.h
@@ -22,9 +22,9 @@ namespace tracing {
/// on the provided stream.
struct ActionLogger : public ExecutionContext::Observer {
ActionLogger(raw_ostream &os, bool printActions = true,
- bool printBreakpoints = true)
- : os(os), printActions(printActions), printBreakpoints(printBreakpoints) {
- }
+ bool printBreakpoints = true, bool printIRUnits = true)
+ : os(os), printActions(printActions), printBreakpoints(printBreakpoints),
+ printIRUnits(printIRUnits) {}
void beforeExecute(const ActionActiveStack *action, Breakpoint *breakpoint,
bool willExecute) override;
@@ -34,6 +34,7 @@ struct ActionLogger : public ExecutionContext::Observer {
raw_ostream &os;
bool printActions;
bool printBreakpoints;
+ bool printIRUnits;
};
} // namespace tracing
diff --git a/mlir/include/mlir/IR/Action.h b/mlir/include/mlir/IR/Action.h
index 569d4288f2086..9359324dd6090 100644
--- a/mlir/include/mlir/IR/Action.h
+++ b/mlir/include/mlir/IR/Action.h
@@ -15,6 +15,7 @@
#ifndef MLIR_IR_ACTION_H
#define MLIR_IR_ACTION_H
+#include "mlir/IR/Unit.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/TypeID.h"
#include "llvm/ADT/ArrayRef.h"
@@ -51,11 +52,19 @@ class Action {
os << "Action \"" << getTag() << "\"";
}
+ /// Return the set of IR units that are associated with this action.
+ virtual ArrayRef<IRUnit> getContextIRUnits() const { return irUnits; }
+
protected:
- Action(TypeID actionID) : actionID(actionID) {}
+ Action(TypeID actionID, ArrayRef<IRUnit> irUnits)
+ : actionID(actionID), irUnits(irUnits) {}
/// The type of the derived action class, used for `isa`/`dyn_cast`.
TypeID actionID;
+
+ /// Set of IR units (operations, regions, blocks, values) that are associated
+ /// with this action.
+ ArrayRef<IRUnit> irUnits;
};
/// CRTP Implementation of an action. This class provides a base class for
@@ -67,7 +76,8 @@ class Action {
template <typename Derived>
class ActionImpl : public Action {
public:
- ActionImpl() : Action(TypeID::get<Derived>()) {}
+ ActionImpl(ArrayRef<IRUnit> irUnits = {})
+ : Action(TypeID::get<Derived>(), irUnits) {}
/// Provide classof to allow casting between action types.
static bool classof(const Action *action) {
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index cc13447d9d584..d9e140bd75f72 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -11,6 +11,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
+#include "llvm/ADT/ArrayRef.h"
#include <functional>
#include <memory>
#include <vector>
@@ -265,9 +266,10 @@ class MLIRContext {
/// Dispatch the provided action to the handler if any, or just execute it.
template <typename ActionTy, typename... Args>
- void executeAction(function_ref<void()> actionFn, Args &&...args) {
+ void executeAction(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
+ Args &&...args) {
if (LLVM_UNLIKELY(hasActionHandler()))
- executeActionInternal<ActionTy, Args...>(actionFn,
+ executeActionInternal<ActionTy, Args...>(actionFn, irUnits,
std::forward<Args>(args)...);
else
actionFn();
@@ -286,8 +288,10 @@ class MLIRContext {
/// avoid calling the ctor for the Action unnecessarily.
template <typename ActionTy, typename... Args>
LLVM_ATTRIBUTE_NOINLINE void
- executeActionInternal(function_ref<void()> actionFn, Args &&...args) {
- executeActionInternal(actionFn, ActionTy(std::forward<Args>(args)...));
+ executeActionInternal(function_ref<void()> actionFn, ArrayRef<IRUnit> irUnits,
+ Args &&...args) {
+ executeActionInternal(actionFn,
+ ActionTy(irUnits, std::forward<Args>(args)...));
}
const std::unique_ptr<MLIRContextImpl> impl;
diff --git a/mlir/include/mlir/IR/Unit.h b/mlir/include/mlir/IR/Unit.h
new file mode 100644
index 0000000000000..033dab5974516
--- /dev/null
+++ b/mlir/include/mlir/IR/Unit.h
@@ -0,0 +1,42 @@
+//===- Unit.h - IR Unit definition--------------------*- C++ -*-=============//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_UNIT_H
+#define MLIR_IR_UNIT_H
+
+#include "mlir/IR/OperationSupport.h"
+#include "llvm/ADT/PointerUnion.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace llvm {
+class raw_ostream;
+} // namespace llvm
+namespace mlir {
+class Operation;
+class Region;
+class Block;
+class Value;
+
+/// IRUnit is a union of the different types of IR objects that consistute the
+/// IR structure (other than Type and Attribute), that is Operation, Region, and
+/// Block.
+class IRUnit : public PointerUnion<Operation *, Region *, Block *, Value> {
+public:
+ using PointerUnion::PointerUnion;
+
+ /// Print the IRUnit to the given stream.
+ void print(raw_ostream &os,
+ OpPrintingFlags flags =
+ OpPrintingFlags().skipRegions().useLocalScope()) const;
+};
+
+raw_ostream &operator<<(raw_ostream &os, const IRUnit &unit);
+
+} // end namespace mlir
+
+#endif // MLIR_IR_UNIT_H
diff --git a/mlir/lib/Debug/Observers/ActionLogging.cpp b/mlir/lib/Debug/Observers/ActionLogging.cpp
index 9826adf33ee16..7e7c5acaaee1f 100644
--- a/mlir/lib/Debug/Observers/ActionLogging.cpp
+++ b/mlir/lib/Debug/Observers/ActionLogging.cpp
@@ -7,9 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Debug/Observers/ActionLogging.h"
+#include "mlir/IR/Action.h"
#include "llvm/Support/Threading.h"
-#include <sstream>
-#include <thread>
+#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::tracing;
@@ -22,6 +22,10 @@ void ActionLogger::beforeExecute(const ActionActiveStack *action,
Breakpoint *breakpoint, bool willExecute) {
SmallVector<char> name;
llvm::get_thread_name(name);
+ if (name.empty()) {
+ llvm::raw_svector_ostream os(name);
+ os << llvm::get_threadid();
+ }
os << "[thread " << name << "] ";
if (willExecute)
os << "begins ";
@@ -29,21 +33,30 @@ void ActionLogger::beforeExecute(const ActionActiveStack *action,
os << "skipping ";
if (printBreakpoints) {
if (breakpoint)
- os << " (on breakpoint: " << *breakpoint << ") ";
+ os << "(on breakpoint: " << *breakpoint << ") ";
else
- os << " (no breakpoint) ";
+ os << "(no breakpoint) ";
}
os << "Action ";
if (printActions)
action->getAction().print(os);
else
os << action->getAction().getTag();
+ if (printIRUnits) {
+ os << " (";
+ interleaveComma(action->getAction().getContextIRUnits(), os);
+ os << ")";
+ }
os << "`\n";
}
void ActionLogger::afterExecute(const ActionActiveStack *action) {
SmallVector<char> name;
llvm::get_thread_name(name);
+ if (name.empty()) {
+ llvm::raw_svector_ostream os(name);
+ os << llvm::get_threadid();
+ }
os << "[thread " << name << "] completed `" << action->getAction().getTag()
<< "`\n";
}
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
index 8b4fb42e03eab..4377ebe160554 100644
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -32,6 +32,7 @@ add_mlir_library(MLIRIR
Types.cpp
TypeRange.cpp
TypeUtilities.cpp
+ Unit.cpp
Value.cpp
ValueRange.cpp
Verifier.cpp
diff --git a/mlir/lib/IR/Unit.cpp b/mlir/lib/IR/Unit.cpp
new file mode 100644
index 0000000000000..7da714fe7d539
--- /dev/null
+++ b/mlir/lib/IR/Unit.cpp
@@ -0,0 +1,63 @@
+//===- Unit.cpp - Support for manipulating IR Unit ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Unit.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/Region.h"
+#include "llvm/Support/raw_ostream.h"
+#include <iterator>
+#include <sstream>
+
+using namespace mlir;
+
+static void printOp(llvm::raw_ostream &os, Operation *op,
+ OpPrintingFlags &flags) {
+ if (!op) {
+ os << "<Operation:nullptr>";
+ return;
+ }
+ op->print(os, flags);
+}
+
+static void printRegion(llvm::raw_ostream &os, Region *region,
+ OpPrintingFlags &flags) {
+ if (!region) {
+ os << "<Region:nullptr>";
+ return;
+ }
+ os << "Region #" << region->getRegionNumber() << " for op ";
+ printOp(os, region->getParentOp(), flags);
+}
+
+static void printBlock(llvm::raw_ostream &os, Block *block,
+ OpPrintingFlags &flags) {
+ Region *region = block->getParent();
+ Block *entry = &region->front();
+ int blockId = std::distance(entry->getIterator(), block->getIterator());
+ os << "Block #" << blockId << " for ";
+ bool shouldSkipRegions = flags.shouldSkipRegions();
+ printRegion(os, region, flags.skipRegions());
+ if (!shouldSkipRegions)
+ block->print(os);
+}
+
+void mlir::IRUnit::print(llvm::raw_ostream &os, OpPrintingFlags flags) const {
+ if (auto *op = this->dyn_cast<Operation *>())
+ return printOp(os, op, flags);
+ if (auto *region = this->dyn_cast<Region *>())
+ return printRegion(os, region, flags);
+ if (auto *block = this->dyn_cast<Block *>())
+ return printBlock(os, block, flags);
+ llvm_unreachable("unknown IRUnit");
+}
+
+llvm::raw_ostream &mlir::operator<<(llvm::raw_ostream &os, const IRUnit &unit) {
+ unit.print(os);
+ return os;
+}
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 2b07898c8200f..e496a29e9fbe5 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -482,7 +482,7 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
pass->runOnOperation();
passFailed = pass->passState->irAndPassFailed.getInt();
},
- *pass, op);
+ {op}, *pass);
// Invalidate any non preserved analyses.
am.invalidate(pass->passState->preservedAnalyses);
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index 6fc46aff35835..ca60cf2fa5894 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -11,17 +11,23 @@
#include "mlir/IR/Action.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/FormatVariadic.h"
namespace mlir {
/// Encapsulate the "action" of executing a single pass, used for the MLIR
/// tracing infrastructure.
struct PassExecutionAction : public tracing::ActionImpl<PassExecutionAction> {
- PassExecutionAction(const Pass &pass, Operation *op) : pass(pass), op(op) {}
+ using Base = tracing::ActionImpl<PassExecutionAction>;
+ PassExecutionAction(ArrayRef<IRUnit> irUnits, const Pass &pass)
+ : Base(irUnits), pass(pass) {}
static constexpr StringLiteral tag = "pass-execution-action";
void print(raw_ostream &os) const override;
const Pass &getPass() const { return pass; }
- Operation *getOp() const { return op; }
+ Operation *getOp() const {
+ ArrayRef<IRUnit> irUnits = getContextIRUnits();
+ return irUnits.empty() ? nullptr : irUnits[0].dyn_cast<Operation *>();
+ }
public:
const Pass &pass;
diff --git a/mlir/test/Pass/action-logging.mlir b/mlir/test/Pass/action-logging.mlir
index 943f05a2968fe..d10c64c2af2ed 100644
--- a/mlir/test/Pass/action-logging.mlir
+++ b/mlir/test/Pass/action-logging.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s --log-actions-to=- -canonicalize -test-module-pass | FileCheck %s
-// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running `Canonicalizer` on Operation `builtin.module`
-// CHECK: [thread {{.*}}] completed `pass-execution-action`
-// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running `(anonymous namespace)::TestModulePass` on Operation `builtin.module`
-// CHECK: [thread {{.*}}] completed `pass-execution-action`
+// CHECK: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running `Canonicalizer` on Operation `builtin.module` (module {...}
+// CHECK-NEXT: [thread {{.*}}] completed `pass-execution-action`
+// CHECK-NEXT: [thread {{.*}}] begins (no breakpoint) Action `pass-execution-action` running `{{.*}}TestModulePass` on Operation `builtin.module` (module {...}
+// CHECK-NEXT: [thread {{.*}}] completed `pass-execution-action`
+// CHECK-NOT: Action
More information about the Mlir-commits
mailing list