[llvm-branch-commits] [mlir] [mlir][Transforms] Encapsulate dialect conversion options in `ConversionConfig` (PR #82250)
Matthias Springer via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Feb 19 06:03:40 PST 2024
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/82250
This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver.
A few existing options are moved to this object, simplifying the dialect conversion API.
>From 819e5f95ed0857e88972501cb10b9931b1e91a1c Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 19 Feb 2024 14:02:14 +0000
Subject: [PATCH] [mlir][Transforms] Encapsulate dialect conversion options in
`ConversionConfig`
This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver.
A few existing options are moved to this object, simplifying the dialect conversion API.
---
.../mlir/Transforms/DialectConversion.h | 75 ++++++----
.../Transforms/Utils/DialectConversion.cpp | 129 +++++++++---------
mlir/test/lib/Dialect/Test/TestPatterns.cpp | 14 +-
3 files changed, 118 insertions(+), 100 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5c91a9498b35d4..8da5dcb0be3fd0 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -24,6 +24,7 @@ namespace mlir {
// Forward declarations.
class Attribute;
class Block;
+struct ConversionConfig;
class ConversionPatternRewriter;
class MLIRContext;
class Operation;
@@ -770,7 +771,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Conversion pattern rewriters must not be used outside of dialect
/// conversions. They apply some IR rewrites in a delayed fashion and could
/// bring the IR into an inconsistent state when used standalone.
- explicit ConversionPatternRewriter(MLIRContext *ctx);
+ explicit ConversionPatternRewriter(MLIRContext *ctx,
+ const ConversionConfig &config);
// Hide unsupported pattern rewriter API.
using OpBuilder::setListener;
@@ -1070,6 +1072,30 @@ class PDLConversionConfig final {
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+//===----------------------------------------------------------------------===//
+// ConversionConfig
+//===----------------------------------------------------------------------===//
+
+/// Dialect conversion configuration.
+struct ConversionConfig {
+ /// An optional callback used to notify about match failure diagnostics during
+ /// the conversion. Diagnostics reported to this callback may only be
+ /// available in debug mode.
+ function_ref<void(Diagnostic &)> notifyCallback = nullptr;
+
+ /// Partial conversion only. All operations that are found not to be
+ /// legalizable are placed in this set. (Note that if there is an op
+ /// explicitly marked as illegal, the conversion terminates and the set will
+ /// not necessarily be complete.)
+ DenseSet<Operation *> *unlegalizedOps = nullptr;
+
+ /// Analysis conversion only. All operations that are found to be legalizable
+ /// are placed in this set. Note that no actual rewrites are applied to the
+ /// IR during an analysis conversion and only pre-existing operations are
+ /// added to the set.
+ DenseSet<Operation *> *legalizableOps = nullptr;
+};
+
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
@@ -1083,19 +1109,16 @@ class PDLConversionConfig final {
/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there ops explicitly marked as illegal. If an
-/// `unconvertedOps` set is provided, all operations that are found not to be
-/// legalizable to the given `target` are placed within that set. (Note that if
-/// there is an op explicitly marked as illegal, the conversion terminates and
-/// the `unconvertedOps` set will not necessarily be complete.)
+/// returns failure if there are ops explicitly marked as illegal.
LogicalResult
-applyPartialConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
+applyPartialConversion(ArrayRef<Operation *> ops,
+ const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> *unconvertedOps = nullptr);
+ ConversionConfig config = ConversionConfig());
LogicalResult
applyPartialConversion(Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> *unconvertedOps = nullptr);
+ ConversionConfig config = ConversionConfig());
/// Apply a complete conversion on the given operations, and all nested
/// operations. This method returns failure if the conversion of any operation
@@ -1103,31 +1126,27 @@ applyPartialConversion(Operation *op, const ConversionTarget &target,
/// within 'ops'.
LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns);
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config = ConversionConfig());
LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns);
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config = ConversionConfig());
/// Apply an analysis conversion on the given operations, and all nested
/// operations. This method analyzes which operations would be successfully
/// converted to the target if a conversion was applied. All operations that
/// were found to be legalizable to the given 'target' are placed within the
-/// provided 'convertedOps' set; note that no actual rewrites are applied to the
-/// operations on success and only pre-existing operations are added to the set.
-/// This method only returns failure if there are unreachable blocks in any of
-/// the regions nested within 'ops'. There's an additional argument
-/// `notifyCallback` which is used for collecting match failure diagnostics
-/// generated during the conversion. Diagnostics are only reported to this
-/// callback may only be available in debug mode.
-LogicalResult applyAnalysisConversion(
- ArrayRef<Operation *> ops, ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps,
- function_ref<void(Diagnostic &)> notifyCallback = nullptr);
-LogicalResult applyAnalysisConversion(
- Operation *op, ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps,
- function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+/// provided 'config.legalizableOps' set; note that no actual rewrites are
+/// applied to the operations on success. This method only returns failure if
+/// there are unreachable blocks in any of the regions nested within 'ops'.
+LogicalResult
+applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config = ConversionConfig());
+LogicalResult
+applyAnalysisConversion(Operation *op, ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config = ConversionConfig());
} // namespace mlir
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 6cf178e149be7f..30fc2298b3deb3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -224,6 +224,8 @@ class IRRewrite {
/// Erase the given block (unless it was already erased).
void eraseBlock(Block *block);
+ const ConversionConfig &getConfig() const;
+
const Kind kind;
ConversionPatternRewriterImpl &rewriterImpl;
};
@@ -723,9 +725,10 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
- explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
+ explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter,
+ const ConversionConfig &config)
: rewriter(rewriter), eraseRewriter(rewriter.getContext()),
- notifyCallback(nullptr) {}
+ config(config) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -931,10 +934,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// converting the arguments of blocks within that region.
DenseMap<Region *, const TypeConverter *> regionToConverter;
- /// This allows the user to collect the match failure message.
- function_ref<void(Diagnostic &)> notifyCallback;
-
- DenseSet<Operation *> *trackedOps = nullptr;
+ /// Dialect conversion configuration.
+ const ConversionConfig &config;
#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
@@ -957,6 +958,10 @@ void IRRewrite::eraseBlock(Block *block) {
rewriterImpl.eraseRewriter.eraseBlock(block);
}
+const ConversionConfig &IRRewrite::getConfig() const {
+ return rewriterImpl.config;
+}
+
void BlockTypeConversionRewrite::commit() {
// Process the remapping for each of the original arguments.
for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
@@ -1074,8 +1079,8 @@ void ReplaceOperationRewrite::commit() {
if (Value newValue =
rewriterImpl.mapping.lookupOrNull(result, result.getType()))
result.replaceAllUsesWith(newValue);
- if (rewriterImpl.trackedOps)
- rewriterImpl.trackedOps->erase(op);
+ if (getConfig().unlegalizedOps)
+ getConfig().unlegalizedOps->erase(op);
// Do not erase the operation yet. It may still be referenced in `mapping`.
op->getBlock()->getOperations().remove(op);
}
@@ -1510,8 +1515,8 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
logger.startLine() << "** Failure : " << diag.str() << "\n";
- if (notifyCallback)
- notifyCallback(diag);
+ if (config.notifyCallback)
+ config.notifyCallback(diag);
});
}
@@ -1519,9 +1524,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
// ConversionPatternRewriter
//===----------------------------------------------------------------------===//
-ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
+ConversionPatternRewriter::ConversionPatternRewriter(
+ MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(*this)) {
+ impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
setListener(impl.get());
}
@@ -1972,12 +1978,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
- if (rewriterImpl.notifyCallback) {
+ if (rewriterImpl.config.notifyCallback) {
Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
diag << "Failed to apply pattern \"" << pattern.getDebugName()
<< "\" on op:\n"
<< *op;
- rewriterImpl.notifyCallback(diag);
+ rewriterImpl.config.notifyCallback(diag);
}
});
rewriterImpl.resetState(curState);
@@ -2365,14 +2371,12 @@ namespace mlir {
struct OperationConverter {
explicit OperationConverter(const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- OpConversionMode mode,
- DenseSet<Operation *> *trackedOps = nullptr)
- : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
+ const ConversionConfig &config,
+ OpConversionMode mode)
+ : opLegalizer(target, patterns), config(config), mode(mode) {}
/// Converts the given operations to the conversion target.
- LogicalResult
- convertOperations(ArrayRef<Operation *> ops,
- function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+ LogicalResult convertOperations(ArrayRef<Operation *> ops);
private:
/// Converts an operation with the given rewriter.
@@ -2409,14 +2413,11 @@ struct OperationConverter {
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
+ /// Dialect conversion configuration.
+ ConversionConfig config;
+
/// The conversion mode to use when legalizing operations.
OpConversionMode mode;
-
- /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
- /// this is populated with ops found to be legalizable to the target.
- /// When mode == OpConversionMode::Partial, this is populated with ops found
- /// *not* to be legalizable to the target.
- DenseSet<Operation *> *trackedOps;
};
} // namespace mlir
@@ -2430,28 +2431,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
return op->emitError()
<< "failed to legalize operation '" << op->getName() << "'";
// Partial conversions allow conversions to fail iff the operation was not
- // explicitly marked as illegal. If the user provided a nonlegalizableOps
- // set, non-legalizable ops are included.
+ // explicitly marked as illegal. If the user provided an `unlegalizedOps`
+ // set, non-legalizable ops are added to that set.
if (mode == OpConversionMode::Partial) {
if (opLegalizer.isIllegal(op))
return op->emitError()
<< "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
- if (trackedOps)
- trackedOps->insert(op);
+ if (config.unlegalizedOps)
+ config.unlegalizedOps->insert(op);
}
} else if (mode == OpConversionMode::Analysis) {
// Analysis conversions don't fail if any operations fail to legalize,
// they are only interested in the operations that were successfully
// legalized.
- trackedOps->insert(op);
+ if (config.legalizableOps)
+ config.legalizableOps->insert(op);
}
return success();
}
-LogicalResult OperationConverter::convertOperations(
- ArrayRef<Operation *> ops,
- function_ref<void(Diagnostic &)> notifyCallback) {
+LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (ops.empty())
return success();
const ConversionTarget &target = opLegalizer.getTarget();
@@ -2472,10 +2472,8 @@ LogicalResult OperationConverter::convertOperations(
}
// Convert each operation and discard rewrites on failure.
- ConversionPatternRewriter rewriter(ops.front()->getContext());
+ ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
- rewriterImpl.notifyCallback = notifyCallback;
- rewriterImpl.trackedOps = trackedOps;
for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
@@ -3461,56 +3459,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
//===----------------------------------------------------------------------===//
// Partial Conversion
-LogicalResult
-mlir::applyPartialConversion(ArrayRef<Operation *> ops,
- const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> *unconvertedOps) {
- OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
- unconvertedOps);
+LogicalResult mlir::applyPartialConversion(
+ ArrayRef<Operation *> ops, const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
+ OperationConverter opConverter(target, patterns, config,
+ OpConversionMode::Partial);
return opConverter.convertOperations(ops);
}
LogicalResult
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> *unconvertedOps) {
- return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
- unconvertedOps);
+ ConversionConfig config) {
+ return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
}
//===----------------------------------------------------------------------===//
// Full Conversion
-LogicalResult
-mlir::applyFullConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns) {
- OperationConverter opConverter(target, patterns, OpConversionMode::Full);
+LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
+ const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config) {
+ OperationConverter opConverter(target, patterns, config,
+ OpConversionMode::Full);
return opConverter.convertOperations(ops);
}
-LogicalResult
-mlir::applyFullConversion(Operation *op, const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns) {
- return applyFullConversion(llvm::ArrayRef(op), target, patterns);
+LogicalResult mlir::applyFullConversion(Operation *op,
+ const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config) {
+ return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
}
//===----------------------------------------------------------------------===//
// Analysis Conversion
-LogicalResult
-mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
- ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps,
- function_ref<void(Diagnostic &)> notifyCallback) {
- OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
- &convertedOps);
- return opConverter.convertOperations(ops, notifyCallback);
+LogicalResult mlir::applyAnalysisConversion(
+ ArrayRef<Operation *> ops, ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
+ OperationConverter opConverter(target, patterns, config,
+ OpConversionMode::Analysis);
+ return opConverter.convertOperations(ops);
}
LogicalResult
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps,
- function_ref<void(Diagnostic &)> notifyCallback) {
- return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
- convertedOps, notifyCallback);
+ ConversionConfig config) {
+ return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 1c02232b8adbb1..c04d5cc80446f4 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1135,8 +1135,10 @@ struct TestLegalizePatternDriver
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;
- if (failed(applyPartialConversion(
- getOperation(), target, std::move(patterns), &unlegalizedOps))) {
+ ConversionConfig config;
+ config.unlegalizedOps = &unlegalizedOps;
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns), config))) {
getOperation()->emitRemark() << "applyPartialConversion failed";
}
// Emit remarks for each legalizable operation.
@@ -1164,8 +1166,10 @@ struct TestLegalizePatternDriver
// Analyze the convertible operations.
DenseSet<Operation *> legalizedOps;
+ ConversionConfig config;
+ config.legalizableOps = &legalizedOps;
if (failed(applyAnalysisConversion(getOperation(), target,
- std::move(patterns), legalizedOps)))
+ std::move(patterns), config)))
return signalPassFailure();
// Emit remarks for each legalizable operation.
@@ -1789,8 +1793,10 @@ struct TestMergeBlocksPatternDriver
});
DenseSet<Operation *> unlegalizedOps;
+ ConversionConfig config;
+ config.unlegalizedOps = &unlegalizedOps;
(void)applyPartialConversion(getOperation(), target, std::move(patterns),
- &unlegalizedOps);
+ config);
for (auto *op : unlegalizedOps)
op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
}
More information about the llvm-branch-commits
mailing list