[Mlir-commits] [mlir] 01b55f1 - [NFC] Tidy up DialectConversion.cpp
River Riddle
llvmlistbot at llvm.org
Tue Oct 26 19:12:30 PDT 2021
Author: River Riddle
Date: 2021-10-27T02:00:50Z
New Revision: 01b55f163a40073914bf92f5b710319e7cae2e02
URL: https://github.com/llvm/llvm-project/commit/01b55f163a40073914bf92f5b710319e7cae2e02
DIFF: https://github.com/llvm/llvm-project/commit/01b55f163a40073914bf92f5b710319e7cae2e02.diff
LOG: [NFC] Tidy up DialectConversion.cpp
This file has gotten a bit crusty over the years, and has outdated stylistic decisions.
Added:
Modified:
mlir/lib/Transforms/Utils/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7e937825845e9..da4f125a18a56 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -163,6 +163,166 @@ Value ConversionValueMapping::lookupOrNull(Value from) const {
return result == from ? nullptr : result;
}
+//===----------------------------------------------------------------------===//
+// Rewriter and Translation State
+//===----------------------------------------------------------------------===//
+namespace {
+/// This class contains a snapshot of the current conversion rewriter state.
+/// This is useful when saving and undoing a set of rewrites.
+struct RewriterState {
+ RewriterState(unsigned numCreatedOps, unsigned numReplacements,
+ unsigned numArgReplacements, unsigned numBlockActions,
+ unsigned numIgnoredOperations, unsigned numRootUpdates)
+ : numCreatedOps(numCreatedOps), numReplacements(numReplacements),
+ numArgReplacements(numArgReplacements),
+ numBlockActions(numBlockActions),
+ numIgnoredOperations(numIgnoredOperations),
+ numRootUpdates(numRootUpdates) {}
+
+ /// The current number of created operations.
+ unsigned numCreatedOps;
+
+ /// The current number of replacements queued.
+ unsigned numReplacements;
+
+ /// The current number of argument replacements queued.
+ unsigned numArgReplacements;
+
+ /// The current number of block actions performed.
+ unsigned numBlockActions;
+
+ /// The current number of ignored operations.
+ unsigned numIgnoredOperations;
+
+ /// The current number of operations that were updated in place.
+ unsigned numRootUpdates;
+};
+
+//===----------------------------------------------------------------------===//
+// OperationTransactionState
+
+/// The state of an operation that was updated by a pattern in-place. This
+/// contains all of the necessary information to reconstruct an operation that
+/// was updated in place.
+class OperationTransactionState {
+public:
+ OperationTransactionState() = default;
+ OperationTransactionState(Operation *op)
+ : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
+ operands(op->operand_begin(), op->operand_end()),
+ successors(op->successor_begin(), op->successor_end()) {}
+
+ /// Discard the transaction state and reset the state of the original
+ /// operation.
+ void resetOperation() const {
+ op->setLoc(loc);
+ op->setAttrs(attrs);
+ op->setOperands(operands);
+ for (auto it : llvm::enumerate(successors))
+ op->setSuccessor(it.value(), it.index());
+ }
+
+ /// Return the original operation of this state.
+ Operation *getOperation() const { return op; }
+
+private:
+ Operation *op;
+ LocationAttr loc;
+ DictionaryAttr attrs;
+ SmallVector<Value, 8> operands;
+ SmallVector<Block *, 2> successors;
+};
+
+//===----------------------------------------------------------------------===//
+// OpReplacement
+
+/// This class represents one requested operation replacement via 'replaceOp' or
+/// 'eraseOp`.
+struct OpReplacement {
+ OpReplacement(TypeConverter *converter = nullptr) : converter(converter) {}
+
+ /// An optional type converter that can be used to materialize conversions
+ /// between the new and old values if necessary.
+ TypeConverter *converter;
+};
+
+//===----------------------------------------------------------------------===//
+// BlockAction
+
+/// The kind of the block action performed during the rewrite. Actions can be
+/// undone if the conversion fails.
+enum class BlockActionKind {
+ Create,
+ Erase,
+ Merge,
+ Move,
+ Split,
+ TypeConversion
+};
+
+/// Original position of the given block in its parent region. During undo
+/// actions, the block needs to be placed after `insertAfterBlock`.
+struct BlockPosition {
+ Region *region;
+ Block *insertAfterBlock;
+};
+
+/// Information needed to undo the merge actions.
+/// - the source block, and
+/// - the Operation that was the last operation in the dest block before the
+/// merge (could be null if the dest block was empty).
+struct MergeInfo {
+ Block *sourceBlock;
+ Operation *destBlockLastInst;
+};
+
+/// The storage class for an undoable block action (one of BlockActionKind),
+/// contains the information necessary to undo this action.
+struct BlockAction {
+ static BlockAction getCreate(Block *block) {
+ return {BlockActionKind::Create, block, {}};
+ }
+ static BlockAction getErase(Block *block, BlockPosition originalPosition) {
+ return {BlockActionKind::Erase, block, {originalPosition}};
+ }
+ static BlockAction getMerge(Block *block, Block *sourceBlock) {
+ BlockAction action{BlockActionKind::Merge, block, {}};
+ action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
+ return action;
+ }
+ static BlockAction getMove(Block *block, BlockPosition originalPosition) {
+ return {BlockActionKind::Move, block, {originalPosition}};
+ }
+ static BlockAction getSplit(Block *block, Block *originalBlock) {
+ BlockAction action{BlockActionKind::Split, block, {}};
+ action.originalBlock = originalBlock;
+ return action;
+ }
+ static BlockAction getTypeConversion(Block *block) {
+ return BlockAction{BlockActionKind::TypeConversion, block, {}};
+ }
+
+ // The action kind.
+ BlockActionKind kind;
+
+ // A pointer to the block that was created by the action.
+ Block *block;
+
+ union {
+ // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
+ // contains a pointer to the region that originally contained the block as
+ // well as the position of the block in that region.
+ BlockPosition originalPosition;
+ // In use if kind == BlockActionKind::Split and contains a pointer to the
+ // block that was split into two parts.
+ Block *originalBlock;
+ // In use if kind == BlockActionKind::Merge, and contains the information
+ // needed to undo the merge.
+ MergeInfo mergeInfo;
+ };
+};
+} // end anonymous namespace
+
//===----------------------------------------------------------------------===//
// ArgConverter
//===----------------------------------------------------------------------===//
@@ -499,7 +659,7 @@ Block *ArgConverter::applySignatureConversion(
// match (i.e. it's an identity map), then the argument is mapped to its
// original type.
if (replArgs.size() == 1 && replArgs[0].getType() == origArg.getType())
- newArg = replArgs[0];
+ newArg = replArgs.front();
else
newArg = converter.materializeArgumentConversion(
rewriter, origArg.getLoc(), origArg.getType(), replArgs);
@@ -535,158 +695,6 @@ void ArgConverter::insertConversion(Block *newBlock,
conversionInfo.insert({newBlock, std::move(info)});
}
-//===----------------------------------------------------------------------===//
-// Rewriter and Translation State
-//===----------------------------------------------------------------------===//
-namespace {
-/// This class contains a snapshot of the current conversion rewriter state.
-/// This is useful when saving and undoing a set of rewrites.
-struct RewriterState {
- RewriterState(unsigned numCreatedOps, unsigned numReplacements,
- unsigned numArgReplacements, unsigned numBlockActions,
- unsigned numIgnoredOperations, unsigned numRootUpdates)
- : numCreatedOps(numCreatedOps), numReplacements(numReplacements),
- numArgReplacements(numArgReplacements),
- numBlockActions(numBlockActions),
- numIgnoredOperations(numIgnoredOperations),
- numRootUpdates(numRootUpdates) {}
-
- /// The current number of created operations.
- unsigned numCreatedOps;
-
- /// The current number of replacements queued.
- unsigned numReplacements;
-
- /// The current number of argument replacements queued.
- unsigned numArgReplacements;
-
- /// The current number of block actions performed.
- unsigned numBlockActions;
-
- /// The current number of ignored operations.
- unsigned numIgnoredOperations;
-
- /// The current number of operations that were updated in place.
- unsigned numRootUpdates;
-};
-
-/// The state of an operation that was updated by a pattern in-place. This
-/// contains all of the necessary information to reconstruct an operation that
-/// was updated in place.
-class OperationTransactionState {
-public:
- OperationTransactionState() = default;
- OperationTransactionState(Operation *op)
- : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
- operands(op->operand_begin(), op->operand_end()),
- successors(op->successor_begin(), op->successor_end()) {}
-
- /// Discard the transaction state and reset the state of the original
- /// operation.
- void resetOperation() const {
- op->setLoc(loc);
- op->setAttrs(attrs);
- op->setOperands(operands);
- for (auto it : llvm::enumerate(successors))
- op->setSuccessor(it.value(), it.index());
- }
-
- /// Return the original operation of this state.
- Operation *getOperation() const { return op; }
-
-private:
- Operation *op;
- LocationAttr loc;
- DictionaryAttr attrs;
- SmallVector<Value, 8> operands;
- SmallVector<Block *, 2> successors;
-};
-
-/// This class represents one requested operation replacement via 'replaceOp' or
-/// 'eraseOp`.
-struct OpReplacement {
- OpReplacement() = default;
- OpReplacement(TypeConverter *converter) : converter(converter) {}
-
- /// An optional type converter that can be used to materialize conversions
- /// between the new and old values if necessary.
- TypeConverter *converter = nullptr;
-};
-
-/// The kind of the block action performed during the rewrite. Actions can be
-/// undone if the conversion fails.
-enum class BlockActionKind {
- Create,
- Erase,
- Merge,
- Move,
- Split,
- TypeConversion
-};
-
-/// Original position of the given block in its parent region. During undo
-/// actions, the block needs to be placed after `insertAfterBlock`.
-struct BlockPosition {
- Region *region;
- Block *insertAfterBlock;
-};
-
-/// Information needed to undo the merge actions.
-/// - the source block, and
-/// - the Operation that was the last operation in the dest block before the
-/// merge (could be null if the dest block was empty).
-struct MergeInfo {
- Block *sourceBlock;
- Operation *destBlockLastInst;
-};
-
-/// The storage class for an undoable block action (one of BlockActionKind),
-/// contains the information necessary to undo this action.
-struct BlockAction {
- static BlockAction getCreate(Block *block) {
- return {BlockActionKind::Create, block, {}};
- }
- static BlockAction getErase(Block *block, BlockPosition originalPosition) {
- return {BlockActionKind::Erase, block, {originalPosition}};
- }
- static BlockAction getMerge(Block *block, Block *sourceBlock) {
- BlockAction action{BlockActionKind::Merge, block, {}};
- action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
- return action;
- }
- static BlockAction getMove(Block *block, BlockPosition originalPosition) {
- return {BlockActionKind::Move, block, {originalPosition}};
- }
- static BlockAction getSplit(Block *block, Block *originalBlock) {
- BlockAction action{BlockActionKind::Split, block, {}};
- action.originalBlock = originalBlock;
- return action;
- }
- static BlockAction getTypeConversion(Block *block) {
- return BlockAction{BlockActionKind::TypeConversion, block, {}};
- }
-
- // The action kind.
- BlockActionKind kind;
-
- // A pointer to the block that was created by the action.
- Block *block;
-
- union {
- // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
- // contains a pointer to the region that originally contained the block as
- // well as the position of the block in that region.
- BlockPosition originalPosition;
- // In use if kind == BlockActionKind::Split and contains a pointer to the
- // block that was split into two parts.
- Block *originalBlock;
- // In use if kind == BlockActionKind::Merge, and contains the information
- // needed to undo the merge.
- MergeInfo mergeInfo;
- };
-};
-} // end anonymous namespace
-
//===----------------------------------------------------------------------===//
// ConversionPatternRewriterImpl
//===----------------------------------------------------------------------===//
@@ -843,9 +851,9 @@ struct ConversionPatternRewriterImpl {
/// explicitly provided.
TypeConverter defaultTypeConverter;
- /// The current conversion pattern that is being rewritten, or nullptr if
- /// called from outside of a conversion pattern rewrite.
- const ConversionPattern *currentConversionPattern = nullptr;
+ /// The current type converter, or nullptr if no type converter is currently
+ /// active.
+ TypeConverter *currentTypeConverter = nullptr;
#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
@@ -1237,10 +1245,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
operationsWithChangedResults.push_back(replacements.size());
// Record the requested operation replacement.
- TypeConverter *converter = nullptr;
- if (currentConversionPattern)
- converter = currentConversionPattern->getTypeConverter();
- replacements.insert(std::make_pair(op, OpReplacement(converter)));
+ replacements.insert(std::make_pair(op, OpReplacement(currentTypeConverter)));
// Mark this operation as recursively ignored so that we don't need to
// convert any nested operations.
@@ -1313,8 +1318,6 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
impl(new detail::ConversionPatternRewriterImpl(*this)) {}
ConversionPatternRewriter::~ConversionPatternRewriter() {}
-/// PatternRewriter hook for replacing the results of an operation when the
-/// given functor returns true.
void ConversionPatternRewriter::replaceOpWithIf(
Operation *op, ValueRange newValues, bool *allUsesReplaced,
llvm::unique_function<bool(OpOperand &) const> functor) {
@@ -1328,7 +1331,6 @@ void ConversionPatternRewriter::replaceOpWithIf(
"replaceOpWithIf is currently not supported by DialectConversion");
}
-/// PatternRewriter hook for replacing the results of an operation.
void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
LLVM_DEBUG({
impl->logger.startLine()
@@ -1337,9 +1339,6 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
impl->notifyOpReplaced(op, newValues);
}
-/// PatternRewriter hook for erasing a dead operation. The uses of this
-/// operation *must* be made dead by the end of the conversion process,
-/// otherwise an assert will be issued.
void ConversionPatternRewriter::eraseOp(Operation *op) {
LLVM_DEBUG({
impl->logger.startLine()
@@ -1393,18 +1392,14 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}
-/// Return the converted value that replaces 'key'. Return 'key' if there is
-/// no such a converted value.
Value ConversionPatternRewriter::getRemappedValue(Value key) {
return impl->mapping.lookupOrDefault(key);
}
-/// PatternRewriter hook for creating a new block with the given arguments.
void ConversionPatternRewriter::notifyBlockCreated(Block *block) {
impl->notifyCreatedBlock(block);
}
-/// PatternRewriter hook for splitting a block into two parts.
Block *ConversionPatternRewriter::splitBlock(Block *block,
Block::iterator before) {
auto *continuation = PatternRewriter::splitBlock(block, before);
@@ -1412,7 +1407,6 @@ Block *ConversionPatternRewriter::splitBlock(Block *block,
return continuation;
}
-/// PatternRewriter hook for merging a block into another.
void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
ValueRange argValues) {
impl->notifyBlocksBeingMerged(dest, source);
@@ -1427,7 +1421,6 @@ void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
eraseBlock(source);
}
-/// PatternRewriter hook for moving blocks out of a region.
void ConversionPatternRewriter::inlineRegionBefore(Region &region,
Region &parent,
Region::iterator before) {
@@ -1435,7 +1428,6 @@ void ConversionPatternRewriter::inlineRegionBefore(Region &region,
PatternRewriter::inlineRegionBefore(region, parent, before);
}
-/// PatternRewriter hook for cloning blocks of one region into another.
void ConversionPatternRewriter::cloneRegionBefore(
Region &region, Region &parent, Region::iterator before,
BlockAndValueMapping &mapping) {
@@ -1449,7 +1441,6 @@ void ConversionPatternRewriter::cloneRegionBefore(
impl->notifyRegionWasClonedBefore(clonedBlocks, region.getLoc());
}
-/// PatternRewriter hook for creating a new operation.
void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
LLVM_DEBUG({
impl->logger.startLine()
@@ -1458,7 +1449,6 @@ void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
impl->createdOps.push_back(op);
}
-/// PatternRewriter hook for updating the root operation in-place.
void ConversionPatternRewriter::startRootUpdate(Operation *op) {
#ifndef NDEBUG
impl->pendingRootUpdates.insert(op);
@@ -1466,7 +1456,6 @@ void ConversionPatternRewriter::startRootUpdate(Operation *op) {
impl->rootUpdates.emplace_back(op);
}
-/// PatternRewriter hook for updating the root operation in-place.
void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
// There is nothing to do here, we only need to track the operation at the
// start of the update.
@@ -1476,7 +1465,6 @@ void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
#endif
}
-/// PatternRewriter hook for updating the root operation in-place.
void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
#ifndef NDEBUG
assert(impl->pendingRootUpdates.erase(op) &&
@@ -1492,13 +1480,11 @@ void ConversionPatternRewriter::cancelRootUpdate(Operation *op) {
rootUpdates.erase(rootUpdates.begin() + updateIdx);
}
-/// PatternRewriter hook for notifying match failure reasons.
LogicalResult ConversionPatternRewriter::notifyMatchFailure(
Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
return impl->notifyMatchFailure(op->getLoc(), reasonCallback);
}
-/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
return *impl;
}
@@ -1507,18 +1493,15 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
// ConversionPattern
//===----------------------------------------------------------------------===//
-/// Attempt to match and rewrite the IR root at the specified operation.
LogicalResult
ConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
auto &rewriterImpl = dialectRewriter.getImpl();
- // Track the current conversion pattern in the rewriter.
- assert(!rewriterImpl.currentConversionPattern &&
- "already inside of a pattern rewrite");
- llvm::SaveAndRestore<const ConversionPattern *> currentPatternGuard(
- rewriterImpl.currentConversionPattern, this);
+ // Track the current conversion pattern type converter in the rewriter.
+ llvm::SaveAndRestore<TypeConverter *> currentConverterGuard(
+ rewriterImpl.currentTypeConverter, getTypeConverter());
// Remap the operands of the operation.
SmallVector<Value, 4> operands;
@@ -1669,20 +1652,19 @@ OperationLegalizer::legalize(Operation *op,
const char *logLineComment =
"//===-------------------------------------------===//\n";
- auto &rewriterImpl = rewriter.getImpl();
+ auto &logger = rewriter.getImpl().logger;
#endif
LLVM_DEBUG({
- auto &os = rewriterImpl.logger;
- os.getOStream() << "\n";
- os.startLine() << logLineComment;
- os.startLine() << "Legalizing operation : '" << op->getName() << "'(" << op
- << ") {\n";
- os.indent();
+ logger.getOStream() << "\n";
+ logger.startLine() << logLineComment;
+ logger.startLine() << "Legalizing operation : '" << op->getName() << "'("
+ << op << ") {\n";
+ logger.indent();
// If the operation has no regions, just print it here.
if (op->getNumRegions() == 0) {
- op->print(os.startLine(), OpPrintingFlags().printGenericOpForm());
- os.getOStream() << "\n\n";
+ op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm());
+ logger.getOStream() << "\n\n";
}
});
@@ -1690,11 +1672,11 @@ OperationLegalizer::legalize(Operation *op,
if (auto legalityInfo = target.isLegal(op)) {
LLVM_DEBUG({
logSuccess(
- rewriterImpl.logger, "operation marked legal by the target{0}",
+ logger, "operation marked legal by the target{0}",
legalityInfo->isRecursivelyLegal
? "; NOTE: operation is recursively legal; skipping internals"
: "");
- rewriterImpl.logger.startLine() << logLineComment;
+ logger.startLine() << logLineComment;
});
// If this operation is recursively legal, mark its children as ignored so
@@ -1707,9 +1689,8 @@ OperationLegalizer::legalize(Operation *op,
// Check to see if the operation is ignored and doesn't need to be converted.
if (rewriter.getImpl().isOpIgnored(op)) {
LLVM_DEBUG({
- logSuccess(rewriterImpl.logger,
- "operation marked 'ignored' during conversion");
- rewriterImpl.logger.startLine() << logLineComment;
+ logSuccess(logger, "operation marked 'ignored' during conversion");
+ logger.startLine() << logLineComment;
});
return success();
}
@@ -1719,8 +1700,8 @@ OperationLegalizer::legalize(Operation *op,
// already legal?
if (succeeded(legalizeWithFold(op, rewriter))) {
LLVM_DEBUG({
- logSuccess(rewriterImpl.logger, "operation was folded");
- rewriterImpl.logger.startLine() << logLineComment;
+ logSuccess(logger, "operation was folded");
+ logger.startLine() << logLineComment;
});
return success();
}
@@ -1728,15 +1709,15 @@ OperationLegalizer::legalize(Operation *op,
// Otherwise, we need to apply a legalization pattern to this operation.
if (succeeded(legalizeWithPattern(op, rewriter))) {
LLVM_DEBUG({
- logSuccess(rewriterImpl.logger, "");
- rewriterImpl.logger.startLine() << logLineComment;
+ logSuccess(logger, "");
+ logger.startLine() << logLineComment;
});
return success();
}
LLVM_DEBUG({
- logFailure(rewriterImpl.logger, "no matched legalization pattern");
- rewriterImpl.logger.startLine() << logLineComment;
+ logFailure(logger, "no matched legalization pattern");
+ logger.startLine() << logLineComment;
});
return failure();
}
@@ -1927,6 +1908,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions(
}
return success();
}
+
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) {
@@ -1941,6 +1923,7 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
}
return success();
}
+
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
RewriterState &state, RewriterState &newState) {
@@ -2151,16 +2134,16 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
//===----------------------------------------------------------------------===//
namespace {
enum OpConversionMode {
- // In this mode, the conversion will ignore failed conversions to allow
- // illegal operations to co-exist in the IR.
+ /// In this mode, the conversion will ignore failed conversions to allow
+ /// illegal operations to co-exist in the IR.
Partial,
- // In this mode, all operations must be legal for the given target for the
- // conversion to succeed.
+ /// In this mode, all operations must be legal for the given target for the
+ /// conversion to succeed.
Full,
- // In this mode, operations are analyzed for legality. No actual rewrites are
- // applied to the operations on success.
+ /// In this mode, operations are analyzed for legality. No actual rewrites are
+ /// applied to the operations on success.
Analysis,
};
@@ -2274,11 +2257,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// legalized.
if (failed(finalize(rewriter)))
return rewriterImpl.discardRewrites(), failure();
+
// After a successful conversion, apply rewrites if this is not an analysis
// conversion.
- if (mode == OpConversionMode::Analysis)
+ if (mode == OpConversionMode::Analysis) {
rewriterImpl.discardRewrites();
- else {
+ } else {
rewriterImpl.applyRewrites();
// It is possible for a later pattern to erase an op that was originally
@@ -2465,8 +2449,6 @@ LogicalResult OperationConverter::legalizeChangedResultType(
// Type Conversion
//===----------------------------------------------------------------------===//
-/// Remap an input of the original signature with a new set of types. The
-/// new types are appended to the new signature conversion.
void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
ArrayRef<Type> types) {
assert(!types.empty() && "expected valid types");
@@ -2474,16 +2456,12 @@ void TypeConverter::SignatureConversion::addInputs(unsigned origInputNo,
addInputs(types);
}
-/// Append new input types to the signature conversion, this should only be
-/// used if the new types are not intended to remap an existing input.
void TypeConverter::SignatureConversion::addInputs(ArrayRef<Type> types) {
assert(!types.empty() &&
"1->0 type remappings don't need to be added explicitly");
argTypes.append(types.begin(), types.end());
}
-/// Remap an input of the original signature with a range of types in the
-/// new signature.
void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
unsigned newInputNo,
unsigned newInputCount) {
@@ -2493,8 +2471,6 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
InputMapping{newInputNo, newInputCount, /*replacementValue=*/nullptr};
}
-/// Remap an input of the original signature to another `replacementValue`
-/// value. This would make the signature converter drop this argument.
void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
Value replacementValue) {
assert(!remappedInputs[origInputNo] && "input has already been remapped");
@@ -2502,7 +2478,6 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
InputMapping{origInputNo, /*size=*/0, replacementValue};
}
-/// This hooks allows for converting a type.
LogicalResult TypeConverter::convertType(Type t,
SmallVectorImpl<Type> &results) {
auto existingIt = cachedDirectConversions.find(t);
@@ -2537,8 +2512,6 @@ LogicalResult TypeConverter::convertType(Type t,
return failure();
}
-/// This hook simplifies defining 1-1 type conversions. This function returns
-/// the type to convert to on success, and a null type on failure.
Type TypeConverter::convertType(Type t) {
// Use the multi-type result version to convert the type.
SmallVector<Type, 1> results;
@@ -2549,9 +2522,6 @@ Type TypeConverter::convertType(Type t) {
return results.size() == 1 ? results.front() : nullptr;
}
-/// Convert the given set of types, filling 'results' as necessary. This
-/// returns failure if the conversion of any of the types fails, success
-/// otherwise.
LogicalResult TypeConverter::convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) {
for (Type type : types)
@@ -2560,28 +2530,21 @@ LogicalResult TypeConverter::convertTypes(TypeRange types,
return success();
}
-/// Return true if the given type is legal for this type converter, i.e. the
-/// type converts to itself.
bool TypeConverter::isLegal(Type type) { return convertType(type) == type; }
-/// Return true if the given operation has legal operand and result types.
bool TypeConverter::isLegal(Operation *op) {
return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
}
-/// Return true if the types of block arguments within the region are legal.
bool TypeConverter::isLegal(Region *region) {
return llvm::all_of(*region, [this](Block &block) {
return isLegal(block.getArgumentTypes());
});
}
-/// Return true if the inputs and outputs of the given function type are
-/// legal.
bool TypeConverter::isSignatureLegal(FunctionType ty) {
return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
}
-/// This hook allows for converting a specific argument of a signature.
LogicalResult TypeConverter::convertSignatureArg(unsigned inputNo, Type type,
SignatureConversion &result) {
// Try to convert the given input type.
@@ -2615,9 +2578,6 @@ Value TypeConverter::materializeConversion(
return nullptr;
}
-/// This function converts the type signature of the given block, by invoking
-/// 'convertSignatureArg' for each argument. This function should return a valid
-/// conversion for the signature on success, None otherwise.
auto TypeConverter::convertBlockSignature(Block *block)
-> Optional<SignatureConversion> {
SignatureConversion conversion(block->getNumArguments());
@@ -2626,6 +2586,10 @@ auto TypeConverter::convertBlockSignature(Block *block)
return conversion;
}
+//===----------------------------------------------------------------------===//
+// FunctionLikeSignatureConversion
+//===----------------------------------------------------------------------===//
+
/// Create a default conversion pattern that rewrites the type signature of a
/// FunctionLike op. This only supports FunctionLike ops which use FunctionType
/// to represent their type.
@@ -2678,29 +2642,23 @@ void mlir::populateFuncOpTypeConversionPattern(RewritePatternSet &patterns,
// ConversionTarget
//===----------------------------------------------------------------------===//
-/// Register a legality action for the given operation.
void ConversionTarget::setOpAction(OperationName op,
LegalizationAction action) {
legalOperations[op].action = action;
}
-/// Register a legality action for the given dialects.
void ConversionTarget::setDialectAction(ArrayRef<StringRef> dialectNames,
LegalizationAction action) {
for (StringRef dialect : dialectNames)
legalDialects[dialect] = action;
}
-/// Get the legality action for the given operation.
auto ConversionTarget::getOpAction(OperationName op) const
-> Optional<LegalizationAction> {
Optional<LegalizationInfo> info = getOpInfo(op);
return info ? info->action : Optional<LegalizationAction>();
}
-/// If the given operation instance is legal on this target, a structure
-/// containing legality information is returned. If the operation is not legal,
-/// None is returned.
auto ConversionTarget::isLegal(Operation *op) const
-> Optional<LegalOpDetails> {
Optional<LegalizationInfo> info = getOpInfo(op->getName());
@@ -2752,7 +2710,6 @@ static ConversionTarget::DynamicLegalityCallbackFn composeLegalityCallbacks(
return chain;
}
-/// Set the dynamic legality callback for the given operation.
void ConversionTarget::setLegalityCallback(
OperationName name, const DynamicLegalityCallbackFn &callback) {
assert(callback && "expected valid legality callback");
@@ -2764,8 +2721,6 @@ void ConversionTarget::setLegalityCallback(
composeLegalityCallbacks(std::move(infoIt->second.legalityFn), callback);
}
-/// Set the recursive legality callback for the given operation and mark the
-/// operation as recursively legal.
void ConversionTarget::markOpRecursivelyLegal(
OperationName name, const DynamicLegalityCallbackFn &callback) {
auto infoIt = legalOperations.find(name);
@@ -2780,7 +2735,6 @@ void ConversionTarget::markOpRecursivelyLegal(
opRecursiveLegalityFns.erase(name);
}
-/// Set the dynamic legality callback for the given dialects.
void ConversionTarget::setLegalityCallback(
ArrayRef<StringRef> dialects, const DynamicLegalityCallbackFn &callback) {
assert(callback && "expected valid legality callback");
@@ -2789,14 +2743,12 @@ void ConversionTarget::setLegalityCallback(
std::move(dialectLegalityFns[dialect]), callback);
}
-/// Set the dynamic legality callback for the unknown ops.
void ConversionTarget::setLegalityCallback(
const DynamicLegalityCallbackFn &callback) {
assert(callback && "expected valid legality callback");
unknownLegalityFn = composeLegalityCallbacks(unknownLegalityFn, callback);
}
-/// Get the legalization information for the given operation.
auto ConversionTarget::getOpInfo(OperationName op) const
-> Optional<LegalizationInfo> {
// Check for info for this specific operation.
@@ -2824,14 +2776,9 @@ auto ConversionTarget::getOpInfo(OperationName op) const
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
-/// Apply a partial conversion on the given operations and all nested
-/// operations. This method converts as many operations to the target as
-/// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there ops explicitly marked as illegal.
-/// If an `unconvertedOps` set is provided, all operations that are found not
-/// to be legalizable to the given `target` are placed within that set. (Note
-/// that if there is an op explicitly marked as illegal, the conversion
-/// terminates and the `unconvertedOps` set will not necessarily be complete.)
+//===----------------------------------------------------------------------===//
+// Partial Conversion
+
LogicalResult
mlir::applyPartialConversion(ArrayRef<Operation *> ops,
ConversionTarget &target,
@@ -2849,9 +2796,9 @@ mlir::applyPartialConversion(Operation *op, ConversionTarget &target,
unconvertedOps);
}
-/// Apply a complete conversion on the given operations, and all nested
-/// operations. This method will return failure if the conversion of any
-/// operation fails.
+//===----------------------------------------------------------------------===//
+// Full Conversion
+
LogicalResult
mlir::applyFullConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
const FrozenRewritePatternSet &patterns) {
@@ -2864,12 +2811,9 @@ mlir::applyFullConversion(Operation *op, ConversionTarget &target,
return applyFullConversion(llvm::makeArrayRef(op), target, patterns);
}
-/// Apply an analysis conversion on the given operations, and all nested
-/// operations. This method analyzes which operations would be successfully
-/// converted to the target if a conversion was applied. All operations that
-/// were found to be legalizable to the given 'target' are placed within the
-/// provided 'convertedOps' set; note that no actual rewrites are applied to the
-/// operations on success and only pre-existing operations are added to the set.
+//===----------------------------------------------------------------------===//
+// Analysis Conversion
+
LogicalResult
mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
ConversionTarget &target,
More information about the Mlir-commits mailing list