[Mlir-commits] [mlir] [mlir][IR][NFC] Improve listener layering in dialect conversion (PR #80825)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 6 03:09:20 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Context: Conversion patterns are handed a `ConversionPatternRewriter` that they use to modify the IR. `ConversionPatternRewriter` provides the public API, but most function calls are forwarded to and handled by `ConversionPatternRewriterImpl`. The dialect conversion uses the listener infrastructure to get notified about op/block insertions.
In the current design, `ConversionPatternRewriter` inherits from both `PatternRewriter` and `Listener`, and the conversion rewriter registers itself as its own listener. This is problematic because listener functions such as `notifyOperationInserted` thereby become part of the rewriter's public API and can be called directly from conversion patterns; doing so would bring the dialect conversion into an inconsistent state.
With this commit, `ConversionPatternRewriter` no longer inherits from `Listener`; instead, `ConversionPatternRewriterImpl` inherits from `Listener` and is registered as the rewriter's listener. This removes the problematic public API and also simplifies the code a bit: block/op insertion notifications previously had to be forwarded from `ConversionPatternRewriter` to `ConversionPatternRewriterImpl`, and this forwarding is no longer needed.
Depends on #80704. Review only the top commit.
---
Full diff: https://github.com/llvm/llvm-project/pull/80825.diff
7 Files Affected:
- (modified) mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h (+1-1)
- (modified) mlir/include/mlir/IR/PatternMatch.h (+6-13)
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+1-15)
- (modified) mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp (+4-2)
- (modified) mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp (+1-2)
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+26-32)
- (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+3-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index c2e3cde8ebc69..2e096e1f55292 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -992,7 +992,7 @@ class TrackingListener : public RewriterBase::Listener,
/// Notify the listener that the pattern failed to match the given operation,
/// and provide a callback to populate a diagnostic with the reason why the
/// failure occurred.
- LogicalResult
+ void
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 61da27825e870..78dcfe7f6fc3d 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -437,11 +437,9 @@ class RewriterBase : public OpBuilder {
/// reason why the failure occurred. This method allows for derived
/// listeners to optionally hook into the reason why a rewrite failed, and
/// display it to users.
- virtual LogicalResult
+ virtual void
notifyMatchFailure(Location loc,
- function_ref<void(Diagnostic &)> reasonCallback) {
- return failure();
- }
+ function_ref<void(Diagnostic &)> reasonCallback) {}
static bool classof(const OpBuilder::Listener *base);
};
@@ -480,12 +478,11 @@ class RewriterBase : public OpBuilder {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
rewriteListener->notifyOperationRemoved(op);
}
- LogicalResult notifyMatchFailure(
+ void notifyMatchFailure(
Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
- return rewriteListener->notifyMatchFailure(loc, reasonCallback);
- return failure();
+ rewriteListener->notifyMatchFailure(loc, reasonCallback);
}
private:
@@ -688,20 +685,16 @@ class RewriterBase : public OpBuilder {
template <typename CallbackT>
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
-#ifndef NDEBUG
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
- return rewriteListener->notifyMatchFailure(
+ rewriteListener->notifyMatchFailure(
loc, function_ref<void(Diagnostic &)>(reasonCallback));
return failure();
-#else
- return failure();
-#endif
}
template <typename CallbackT>
std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
- return rewriteListener->notifyMatchFailure(
+ rewriteListener->notifyMatchFailure(
op->getLoc(), function_ref<void(Diagnostic &)>(reasonCallback));
return failure();
}
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 51e3e413b516f..dd1d1a3f707ed 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -632,8 +632,7 @@ struct ConversionPatternRewriterImpl;
/// This class implements a pattern rewriter for use with ConversionPatterns. It
/// extends the base PatternRewriter and provides special conversion specific
/// hooks.
-class ConversionPatternRewriter final : public PatternRewriter,
- public RewriterBase::Listener {
+class ConversionPatternRewriter final : public PatternRewriter {
public:
explicit ConversionPatternRewriter(MLIRContext *ctx);
~ConversionPatternRewriter() override;
@@ -712,10 +711,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// implemented for dialect conversion.
void eraseBlock(Block *block) override;
- /// PatternRewriter hook creating a new block.
- void notifyBlockInserted(Block *block, Region *previous,
- Region::iterator previousIt) override;
-
/// PatternRewriter hook for splitting a block into two parts.
Block *splitBlock(Block *block, Block::iterator before) override;
@@ -724,9 +719,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
ValueRange argValues = std::nullopt) override;
using PatternRewriter::inlineBlockBefore;
- /// PatternRewriter hook for inserting a new operation.
- void notifyOperationInserted(Operation *op, InsertPoint previous) override;
-
/// PatternRewriter hook for updating the given operation in-place.
/// Note: These methods only track updates to the given operation itself,
/// and not nested regions. Updates to regions will still require notification
@@ -739,12 +731,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
/// PatternRewriter hook for updating the given operation in-place.
void cancelOpModification(Operation *op) override;
- /// PatternRewriter hook for notifying match failure reasons.
- LogicalResult
- notifyMatchFailure(Location loc,
- function_ref<void(Diagnostic &)> reasonCallback) override;
- using PatternRewriter::notifyMatchFailure;
-
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 828f53c16d8f8..6dc6a8bc8ccc7 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -582,7 +582,8 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
// Inside regular functions we use the blocking wait operation to wait for
// the async object (token, value or group) to become available.
if (!isInCoroutine) {
- ImplicitLocOpBuilder builder(loc, op, &rewriter);
+ ImplicitLocOpBuilder builder(loc, rewriter);
+ builder.setInsertionPoint(op);
builder.create<RuntimeAwaitOp>(loc, operand);
// Assert that the awaited operands is not in the error state.
@@ -601,7 +602,8 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
CoroMachinery &coro = funcCoro->getSecond();
Block *suspended = op->getBlock();
- ImplicitLocOpBuilder builder(loc, op, &rewriter);
+ ImplicitLocOpBuilder builder(loc, rewriter);
+ builder.setInsertionPoint(op);
MLIRContext *ctx = op->getContext();
// Save the coroutine state and resume on a runtime managed thread when
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index 371ad904dcae5..a964c205b62e8 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -1265,14 +1265,13 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
return diag;
}
-LogicalResult transform::TrackingListener::notifyMatchFailure(
+void transform::TrackingListener::notifyMatchFailure(
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
LLVM_DEBUG({
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
DBGS() << "Match Failure : " << diag.str() << "\n";
});
- return failure();
}
void transform::TrackingListener::notifyOperationRemoved(Operation *op) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 346135fb44722..e41231d7cbd39 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -825,7 +825,7 @@ void ArgConverter::insertConversion(Block *newBlock,
//===----------------------------------------------------------------------===//
namespace mlir {
namespace detail {
-struct ConversionPatternRewriterImpl {
+struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
: argConverter(rewriter, unresolvedMaterializations),
notifyCallback(nullptr) {}
@@ -903,15 +903,19 @@ struct ConversionPatternRewriterImpl {
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
- /// PatternRewriter hook for replacing the results of an operation.
+ /// Notifies that an op was inserted.
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override;
+
+ /// Notifies that an op is about to be replaced with the given values.
void notifyOpReplaced(Operation *op, ValueRange newValues);
/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);
- /// Notifies that a block was created.
- void notifyInsertedBlock(Block *block, Region *previous,
- Region::iterator previousIt);
+ /// Notifies that a block was inserted.
+ void notifyBlockInserted(Block *block, Region *previous,
+ Region::iterator previousIt) override;
/// Notifies that a block was split.
void notifySplitBlock(Block *block, Block *continuation);
@@ -921,9 +925,9 @@ struct ConversionPatternRewriterImpl {
Block::iterator before);
/// Notifies that a pattern match failed for the given reason.
- LogicalResult
+ void
notifyMatchFailure(Location loc,
- function_ref<void(Diagnostic &)> reasonCallback);
+ function_ref<void(Diagnostic &)> reasonCallback) override;
//===--------------------------------------------------------------------===//
// State
@@ -1236,10 +1240,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
legalTypes.clear();
if (failed(currentTypeConverter->convertType(origType, legalTypes))) {
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
- return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
+ notifyMatchFailure(operandLoc, [=](Diagnostic &diag) {
diag << "unable to convert type for " << valueDiagTag << " #"
<< it.index() << ", type was " << origType;
});
+ return failure();
}
// TODO: There currently isn't any mechanism to do 1->N type conversion
// via the PatternRewriter replacement API, so for now we just ignore it.
@@ -1363,6 +1368,16 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
+void ConversionPatternRewriterImpl::notifyOperationInserted(
+ Operation *op, OpBuilder::InsertPoint previous) {
+ assert(!previous.isSet() && "expected newly created op");
+ LLVM_DEBUG({
+ logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
+ << ")\n";
+ });
+ createdOps.push_back(op);
+}
+
void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
@@ -1398,7 +1413,7 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
}
-void ConversionPatternRewriterImpl::notifyInsertedBlock(
+void ConversionPatternRewriterImpl::notifyBlockInserted(
Block *block, Region *previous, Region::iterator previousIt) {
if (!previous) {
// This is a newly created block.
@@ -1419,7 +1434,7 @@ void ConversionPatternRewriterImpl::notifyBlockBeingInlined(
blockActions.push_back(BlockAction::getInline(block, srcBlock, before));
}
-LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
+void ConversionPatternRewriterImpl::notifyMatchFailure(
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
LLVM_DEBUG({
Diagnostic diag(loc, DiagnosticSeverity::Remark);
@@ -1428,7 +1443,6 @@ LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
if (notifyCallback)
notifyCallback(diag);
});
- return failure();
}
//===----------------------------------------------------------------------===//
@@ -1438,7 +1452,7 @@ LogicalResult ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
: PatternRewriter(ctx),
impl(new detail::ConversionPatternRewriterImpl(*this)) {
- setListener(this);
+ setListener(impl.get());
}
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
@@ -1541,11 +1555,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
results);
}
-void ConversionPatternRewriter::notifyBlockInserted(
- Block *block, Region *previous, Region::iterator previousIt) {
- impl->notifyInsertedBlock(block, previous, previousIt);
-}
-
Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {
auto *continuation = block->splitBlock(before);
@@ -1573,16 +1582,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
eraseBlock(source);
}
-void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
- InsertPoint previous) {
- assert(!previous.isSet() && "expected newly created op");
- LLVM_DEBUG({
- impl->logger.startLine()
- << "** Insert : '" << op->getName() << "'(" << op << ")\n";
- });
- impl->createdOps.push_back(op);
-}
-
void ConversionPatternRewriter::startOpModification(Operation *op) {
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
@@ -1615,11 +1614,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
rootUpdates.erase(rootUpdates.begin() + updateIdx);
}
-LogicalResult ConversionPatternRewriter::notifyMatchFailure(
- Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
- return impl->notifyMatchFailure(loc, reasonCallback);
-}
-
void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
Block::iterator iterator) {
llvm_unreachable(
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index d5395045af434..bde8c290e774b 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -387,7 +387,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
void notifyBlockRemoved(Block *block) override;
/// For debugging only: Notify the driver of a pattern match failure.
- LogicalResult
+ void
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;
@@ -726,7 +726,7 @@ void GreedyPatternRewriteDriver::notifyOperationReplaced(
config.listener->notifyOperationReplaced(op, replacement);
}
-LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
+void GreedyPatternRewriteDriver::notifyMatchFailure(
Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
LLVM_DEBUG({
Diagnostic diag(loc, DiagnosticSeverity::Remark);
@@ -734,8 +734,7 @@ LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
logger.startLine() << "** Failure : " << diag.str() << "\n";
});
if (config.listener)
- return config.listener->notifyMatchFailure(loc, reasonCallback);
- return failure();
+ config.listener->notifyMatchFailure(loc, reasonCallback);
}
//===----------------------------------------------------------------------===//
``````````
</details>
https://github.com/llvm/llvm-project/pull/80825
More information about the Mlir-commits
mailing list