[Mlir-commits] [mlir] 60fbd60 - Revert "[mlir][Transforms] Encapsulate dialect conversion options in `ConversionConfig` (#83662)"
llvmlistbot at llvm.org
Sat Mar 2 14:41:44 PST 2024
Author: Mehdi Amini
Date: 2024-03-02T14:41:40-08:00
New Revision: 60fbd6050107875956960c3ce35cf94b202d8675
URL: https://github.com/llvm/llvm-project/commit/60fbd6050107875956960c3ce35cf94b202d8675
DIFF: https://github.com/llvm/llvm-project/commit/60fbd6050107875956960c3ce35cf94b202d8675.diff
LOG: Revert "[mlir][Transforms] Encapsulate dialect conversion options in `ConversionConfig` (#83662)"
This reverts commit 5f1319bb385342c7ef4124b05b83b89ef8588ee8.
A FIR test is broken on Windows
Added:
Modified:
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 84396529eb7c2e..88eefa69a8003f 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -24,7 +24,6 @@ namespace mlir {
// Forward declarations.
class Attribute;
class Block;
-struct ConversionConfig;
class ConversionPatternRewriter;
class MLIRContext;
class Operation;
@@ -768,8 +767,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Conversion pattern rewriters must not be used outside of dialect
/// conversions. They apply some IR rewrites in a delayed fashion and could
/// bring the IR into an inconsistent state when used standalone.
- explicit ConversionPatternRewriter(MLIRContext *ctx,
- const ConversionConfig &config);
+ explicit ConversionPatternRewriter(MLIRContext *ctx);
// Hide unsupported pattern rewriter API.
using OpBuilder::setListener;
@@ -1069,30 +1067,6 @@ class PDLConversionConfig final {
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
-//===----------------------------------------------------------------------===//
-// ConversionConfig
-//===----------------------------------------------------------------------===//
-
-/// Dialect conversion configuration.
-struct ConversionConfig {
- /// An optional callback used to notify about match failure diagnostics during
- /// the conversion. Diagnostics reported to this callback may only be
- /// available in debug mode.
- function_ref<void(Diagnostic &)> notifyCallback = nullptr;
-
- /// Partial conversion only. All operations that are found not to be
- /// legalizable are placed in this set. (Note that if there is an op
- /// explicitly marked as illegal, the conversion terminates and the set will
- /// not necessarily be complete.)
- DenseSet<Operation *> *unlegalizedOps = nullptr;
-
- /// Analysis conversion only. All operations that are found to be legalizable
- /// are placed in this set. Note that no actual rewrites are applied to the
- /// IR during an analysis conversion and only pre-existing operations are
- /// added to the set.
- DenseSet<Operation *> *legalizableOps = nullptr;
-};
-
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
@@ -1106,16 +1080,20 @@ struct ConversionConfig {
/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there ops explicitly marked as illegal.
+/// returns failure if there ops explicitly marked as illegal. If an
+/// `unconvertedOps` set is provided, all operations that are found not to be
+/// legalizable to the given `target` are placed within that set. (Note that if
+/// there is an op explicitly marked as illegal, the conversion terminates and
+/// the `unconvertedOps` set will not necessarily be complete.)
LogicalResult
applyPartialConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- ConversionConfig config = ConversionConfig());
+ DenseSet<Operation *> *unconvertedOps = nullptr);
LogicalResult
applyPartialConversion(Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- ConversionConfig config = ConversionConfig());
+ DenseSet<Operation *> *unconvertedOps = nullptr);
/// Apply a complete conversion on the given operations, and all nested
/// operations. This method returns failure if the conversion of any operation
@@ -1123,27 +1101,31 @@ applyPartialConversion(Operation *op, const ConversionTarget &target,
/// within 'ops'.
LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- ConversionConfig config = ConversionConfig());
+ const FrozenRewritePatternSet &patterns);
LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- ConversionConfig config = ConversionConfig());
+ const FrozenRewritePatternSet &patterns);
/// Apply an analysis conversion on the given operations, and all nested
/// operations. This method analyzes which operations would be successfully
/// converted to the target if a conversion was applied. All operations that
/// were found to be legalizable to the given 'target' are placed within the
-/// provided 'config.legalizableOps' set; note that no actual rewrites are
-/// applied to the operations on success. This method only returns failure if
-/// there are unreachable blocks in any of the regions nested within 'ops'.
-LogicalResult
-applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- ConversionConfig config = ConversionConfig());
-LogicalResult
-applyAnalysisConversion(Operation *op, ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- ConversionConfig config = ConversionConfig());
+/// provided 'convertedOps' set; note that no actual rewrites are applied to the
+/// operations on success and only pre-existing operations are added to the set.
+/// This method only returns failure if there are unreachable blocks in any of
+/// the regions nested within 'ops'. There's an additional argument
+/// `notifyCallback` which is used for collecting match failure diagnostics
+/// generated during the conversion. Diagnostics are only reported to this
+/// callback may only be available in debug mode.
+LogicalResult applyAnalysisConversion(
+ ArrayRef<Operation *> ops, ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ DenseSet<Operation *> &convertedOps,
+ function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+LogicalResult applyAnalysisConversion(
+ Operation *op, ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ DenseSet<Operation *> &convertedOps,
+ function_ref<void(Diagnostic &)> notifyCallback = nullptr);
} // namespace mlir
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 26899301eb742e..ffdb442033d323 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -230,8 +230,6 @@ class IRRewrite {
/// Erase the given block (unless it was already erased).
void eraseBlock(Block *block);
- const ConversionConfig &getConfig() const;
-
const Kind kind;
ConversionPatternRewriterImpl &rewriterImpl;
};
@@ -734,9 +732,8 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
- explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
- const ConversionConfig &config)
- : eraseRewriter(ctx), config(config) {}
+ explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
+ : eraseRewriter(rewriter.getContext()) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -936,8 +933,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// converting the arguments of blocks within that region.
DenseMap<Region *, const TypeConverter *> regionToConverter;
- /// Dialect conversion configuration.
- const ConversionConfig &config;
+ /// This allows the user to collect the match failure message.
+ function_ref<void(Diagnostic &)> notifyCallback;
+
+ /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
+ /// this is populated with ops found to be legalizable to the target.
+ /// When mode == OpConversionMode::Partial, this is populated with ops found
+ /// *not* to be legalizable to the target.
+ DenseSet<Operation *> *trackedOps = nullptr;
#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
@@ -960,10 +963,6 @@ void IRRewrite::eraseBlock(Block *block) {
rewriterImpl.eraseRewriter.eraseBlock(block);
}
-const ConversionConfig &IRRewrite::getConfig() const {
- return rewriterImpl.config;
-}
-
void BlockTypeConversionRewrite::commit() {
// Process the remapping for each of the original arguments.
for (auto [origArg, info] :
@@ -1081,8 +1080,8 @@ void ReplaceOperationRewrite::commit() {
if (Value newValue =
rewriterImpl.mapping.lookupOrNull(result, result.getType()))
result.replaceAllUsesWith(newValue);
- if (getConfig().unlegalizedOps)
- getConfig().unlegalizedOps->erase(op);
+ if (rewriterImpl.trackedOps)
+ rewriterImpl.trackedOps->erase(op);
// Do not erase the operation yet. It may still be referenced in `mapping`.
op->getBlock()->getOperations().remove(op);
}
@@ -1505,8 +1504,8 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
logger.startLine() << "** Failure : " << diag.str() << "\n";
- if (config.notifyCallback)
- config.notifyCallback(diag);
+ if (notifyCallback)
+ notifyCallback(diag);
});
}
@@ -1514,10 +1513,9 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
// ConversionPatternRewriter
//===----------------------------------------------------------------------===//
-ConversionPatternRewriter::ConversionPatternRewriter(
- MLIRContext *ctx, const ConversionConfig &config)
+ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
+ impl(new detail::ConversionPatternRewriterImpl(*this)) {
setListener(impl.get());
}
@@ -1986,12 +1984,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
- if (rewriterImpl.config.notifyCallback) {
+ if (rewriterImpl.notifyCallback) {
Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
diag << "Failed to apply pattern \"" << pattern.getDebugName()
<< "\" on op:\n"
<< *op;
- rewriterImpl.config.notifyCallback(diag);
+ rewriterImpl.notifyCallback(diag);
}
});
rewriterImpl.resetState(curState);
@@ -2379,12 +2377,14 @@ namespace mlir {
struct OperationConverter {
explicit OperationConverter(const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- const ConversionConfig &config,
- OpConversionMode mode)
- : opLegalizer(target, patterns), config(config), mode(mode) {}
+ OpConversionMode mode,
+ DenseSet<Operation *> *trackedOps = nullptr)
+ : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
/// Converts the given operations to the conversion target.
- LogicalResult convertOperations(ArrayRef<Operation *> ops);
+ LogicalResult
+ convertOperations(ArrayRef<Operation *> ops,
+ function_ref<void(Diagnostic &)> notifyCallback = nullptr);
private:
/// Converts an operation with the given rewriter.
@@ -2421,11 +2421,14 @@ struct OperationConverter {
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
- /// Dialect conversion configuration.
- ConversionConfig config;
-
/// The conversion mode to use when legalizing operations.
OpConversionMode mode;
+
+ /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
+ /// this is populated with ops found to be legalizable to the target.
+ /// When mode == OpConversionMode::Partial, this is populated with ops found
+ /// *not* to be legalizable to the target.
+ DenseSet<Operation *> *trackedOps;
};
} // namespace mlir
@@ -2439,27 +2442,28 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
return op->emitError()
<< "failed to legalize operation '" << op->getName() << "'";
// Partial conversions allow conversions to fail iff the operation was not
- // explicitly marked as illegal. If the user provided a `unlegalizedOps`
- // set, non-legalizable ops are added to that set.
+ // explicitly marked as illegal. If the user provided a nonlegalizableOps
+ // set, non-legalizable ops are included.
if (mode == OpConversionMode::Partial) {
if (opLegalizer.isIllegal(op))
return op->emitError()
<< "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
- if (config.unlegalizedOps)
- config.unlegalizedOps->insert(op);
+ if (trackedOps)
+ trackedOps->insert(op);
}
} else if (mode == OpConversionMode::Analysis) {
// Analysis conversions don't fail if any operations fail to legalize,
// they are only interested in the operations that were successfully
// legalized.
- if (config.legalizableOps)
- config.legalizableOps->insert(op);
+ trackedOps->insert(op);
}
return success();
}
-LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
+LogicalResult OperationConverter::convertOperations(
+ ArrayRef<Operation *> ops,
+ function_ref<void(Diagnostic &)> notifyCallback) {
if (ops.empty())
return success();
const ConversionTarget &target = opLegalizer.getTarget();
@@ -2480,8 +2484,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
}
// Convert each operation and discard rewrites on failure.
- ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
+ ConversionPatternRewriter rewriter(ops.front()->getContext());
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+ rewriterImpl.notifyCallback = notifyCallback;
+ rewriterImpl.trackedOps = trackedOps;
for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
@@ -3468,51 +3474,57 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
//===----------------------------------------------------------------------===//
// Partial Conversion
-LogicalResult mlir::applyPartialConversion(
- ArrayRef<Operation *> ops, const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns, ConversionConfig config) {
- OperationConverter opConverter(target, patterns, config,
- OpConversionMode::Partial);
+LogicalResult
+mlir::applyPartialConversion(ArrayRef<Operation *> ops,
+ const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ DenseSet<Operation *> *unconvertedOps) {
+ OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
+ unconvertedOps);
return opConverter.convertOperations(ops);
}
LogicalResult
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- ConversionConfig config) {
- return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
+ DenseSet<Operation *> *unconvertedOps) {
+ return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
+ unconvertedOps);
}
//===----------------------------------------------------------------------===//
// Full Conversion
-LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
- const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- ConversionConfig config) {
- OperationConverter opConverter(target, patterns, config,
- OpConversionMode::Full);
+LogicalResult
+mlir::applyFullConversion(ArrayRef<Operation *> ops,
+ const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns) {
+ OperationConverter opConverter(target, patterns, OpConversionMode::Full);
return opConverter.convertOperations(ops);
}
-LogicalResult mlir::applyFullConversion(Operation *op,
- const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- ConversionConfig config) {
- return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
+LogicalResult
+mlir::applyFullConversion(Operation *op, const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns) {
+ return applyFullConversion(llvm::ArrayRef(op), target, patterns);
}
//===----------------------------------------------------------------------===//
// Analysis Conversion
-LogicalResult mlir::applyAnalysisConversion(
- ArrayRef<Operation *> ops, ConversionTarget &target,
- const FrozenRewritePatternSet &patterns, ConversionConfig config) {
- OperationConverter opConverter(target, patterns, config,
- OpConversionMode::Analysis);
- return opConverter.convertOperations(ops);
+LogicalResult
+mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
+ ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ DenseSet<Operation *> &convertedOps,
+ function_ref<void(Diagnostic &)> notifyCallback) {
+ OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
+ &convertedOps);
+ return opConverter.convertOperations(ops, notifyCallback);
}
LogicalResult
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- ConversionConfig config) {
- return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
+ DenseSet<Operation *> &convertedOps,
+ function_ref<void(Diagnostic &)> notifyCallback) {
+ return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
+ convertedOps, notifyCallback);
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index abc0e43c7b7f2d..157bfcc1eb23be 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1152,10 +1152,8 @@ struct TestLegalizePatternDriver
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;
- ConversionConfig config;
- config.unlegalizedOps = &unlegalizedOps;
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns), config))) {
+ if (failed(applyPartialConversion(
+ getOperation(), target, std::move(patterns), &unlegalizedOps))) {
getOperation()->emitRemark() << "applyPartialConversion failed";
}
// Emit remarks for each legalizable operation.
@@ -1183,10 +1181,8 @@ struct TestLegalizePatternDriver
// Analyze the convertible operations.
DenseSet<Operation *> legalizedOps;
- ConversionConfig config;
- config.legalizableOps = &legalizedOps;
if (failed(applyAnalysisConversion(getOperation(), target,
- std::move(patterns), config)))
+ std::move(patterns), legalizedOps)))
return signalPassFailure();
// Emit remarks for each legalizable operation.
@@ -1809,10 +1805,8 @@ struct TestMergeBlocksPatternDriver
});
DenseSet<Operation *> unlegalizedOps;
- ConversionConfig config;
- config.unlegalizedOps = &unlegalizedOps;
(void)applyPartialConversion(getOperation(), target, std::move(patterns),
- config);
+ &unlegalizedOps);
for (auto *op : unlegalizedOps)
op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
}
More information about the Mlir-commits mailing list