[Mlir-commits] [mlir] e6a343e - [mlir][DialectConversion][NFC] Add comment blocks and organize a bit of the code
River Riddle
llvmlistbot at llvm.org
Wed Jun 24 17:42:17 PDT 2020
Author: River Riddle
Date: 2020-06-24T17:42:10-07:00
New Revision: e6a343e491d4ee52b4085bf2b2c24669f1f9a6ce
URL: https://github.com/llvm/llvm-project/commit/e6a343e491d4ee52b4085bf2b2c24669f1f9a6ce
DIFF: https://github.com/llvm/llvm-project/commit/e6a343e491d4ee52b4085bf2b2c24669f1f9a6ce.diff
LOG: [mlir][DialectConversion][NFC] Add comment blocks and organize a bit of the code
This helps improve the readability when scrolling through the many functions of ConversionPatternRewriterImpl.
Added:
Modified:
mlir/lib/Transforms/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index ecebe61d025f..60c9e78b7a69 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -450,7 +450,7 @@ void ArgConverter::insertConversion(Block *newBlock,
}
//===----------------------------------------------------------------------===//
-// ConversionPatternRewriterImpl
+// Rewriter and Transation State
//===----------------------------------------------------------------------===//
namespace {
/// This class contains a snapshot of the current conversion rewriter state.
@@ -515,74 +515,89 @@ class OperationTransactionState {
SmallVector<Value, 8> operands;
SmallVector<Block *, 2> successors;
};
-} // end anonymous namespace
-namespace mlir {
-namespace detail {
-struct ConversionPatternRewriterImpl {
- /// This class represents one requested operation replacement via 'replaceOp'.
- struct OpReplacement {
- OpReplacement() = default;
- OpReplacement(Operation *op, ValueRange newValues)
- : op(op), newValues(newValues.begin(), newValues.end()) {}
-
- Operation *op;
- SmallVector<Value, 2> newValues;
- };
+/// This class represents one requested operation replacement via 'replaceOp'.
+struct OpReplacement {
+ OpReplacement() = default;
+ OpReplacement(Operation *op, ValueRange newValues)
+ : op(op), newValues(newValues.begin(), newValues.end()) {}
- /// The kind of the block action performed during the rewrite. Actions can be
- /// undone if the conversion fails.
- enum class BlockActionKind { Create, Erase, Move, Split, TypeConversion };
+ Operation *op;
+ SmallVector<Value, 2> newValues;
+};
- /// Original position of the given block in its parent region. We cannot use
- /// a region iterator because it could have been invalidated by other region
- /// operations since the position was stored.
- struct BlockPosition {
- Region *region;
- Region::iterator::difference_type position;
- };
+/// The kind of the block action performed during the rewrite. Actions can be
+/// undone if the conversion fails.
+enum class BlockActionKind { Create, Erase, Move, Split, TypeConversion };
- /// The storage class for an undoable block action (one of BlockActionKind),
- /// contains the information necessary to undo this action.
- struct BlockAction {
- static BlockAction getCreate(Block *block) {
- return {BlockActionKind::Create, block, {}};
- }
- static BlockAction getErase(Block *block, BlockPosition originalPos) {
- return {BlockActionKind::Erase, block, {originalPos}};
- }
- static BlockAction getMove(Block *block, BlockPosition originalPos) {
- return {BlockActionKind::Move, block, {originalPos}};
- }
- static BlockAction getSplit(Block *block, Block *originalBlock) {
- BlockAction action{BlockActionKind::Split, block, {}};
- action.originalBlock = originalBlock;
- return action;
- }
- static BlockAction getTypeConversion(Block *block) {
- return BlockAction{BlockActionKind::TypeConversion, block, {}};
- }
+/// Original position of the given block in its parent region. We cannot use
+/// a region iterator because it could have been invalidated by other region
+/// operations since the position was stored.
+struct BlockPosition {
+ Region *region;
+ Region::iterator::difference_type position;
+};
- // The action kind.
- BlockActionKind kind;
-
- // A pointer to the block that was created by the action.
- Block *block;
-
- union {
- // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
- // contains a pointer to the region that originally contained the block as
- // well as the position of the block in that region.
- BlockPosition originalPosition;
- // In use if kind == BlockActionKind::Split and contains a pointer to the
- // block that was split into two parts.
- Block *originalBlock;
- };
+/// The storage class for an undoable block action (one of BlockActionKind),
+/// contains the information necessary to undo this action.
+struct BlockAction {
+ static BlockAction getCreate(Block *block) {
+ return {BlockActionKind::Create, block, {}};
+ }
+ static BlockAction getErase(Block *block, BlockPosition originalPos) {
+ return {BlockActionKind::Erase, block, {originalPos}};
+ }
+ static BlockAction getMove(Block *block, BlockPosition originalPos) {
+ return {BlockActionKind::Move, block, {originalPos}};
+ }
+ static BlockAction getSplit(Block *block, Block *originalBlock) {
+ BlockAction action{BlockActionKind::Split, block, {}};
+ action.originalBlock = originalBlock;
+ return action;
+ }
+ static BlockAction getTypeConversion(Block *block) {
+ return BlockAction{BlockActionKind::TypeConversion, block, {}};
+ }
+
+ // The action kind.
+ BlockActionKind kind;
+
+ // A pointer to the block that was created by the action.
+ Block *block;
+
+ union {
+ // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
+ // contains a pointer to the region that originally contained the block as
+ // well as the position of the block in that region.
+ BlockPosition originalPosition;
+ // In use if kind == BlockActionKind::Split and contains a pointer to the
+ // block that was split into two parts.
+ Block *originalBlock;
};
+};
+} // end anonymous namespace
+//===----------------------------------------------------------------------===//
+// ConversionPatternRewriterImpl
+//===----------------------------------------------------------------------===//
+namespace mlir {
+namespace detail {
+struct ConversionPatternRewriterImpl {
ConversionPatternRewriterImpl(PatternRewriter &rewriter)
: argConverter(rewriter) {}
+ /// Cleanup and destroy any generated rewrite operations. This method is
+ /// invoked when the conversion process fails.
+ void discardRewrites();
+
+ /// Apply all requested operation rewrites. This method is invoked when the
+ /// conversion process succeeds.
+ void applyRewrites();
+
+ //===--------------------------------------------------------------------===//
+ // State Management
+ //===--------------------------------------------------------------------===//
+
/// Return the current state of the rewriter.
RewriterState getCurrentState();
@@ -597,13 +612,21 @@ struct ConversionPatternRewriterImpl {
/// "numActionsToKeep" actions remains.
void undoBlockActions(unsigned numActionsToKeep = 0);
- /// Cleanup and destroy any generated rewrite operations. This method is
- /// invoked when the conversion process fails.
- void discardRewrites();
+ /// Remap the given operands to those with potentially different types.
+ void remapValues(Operation::operand_range operands,
+ SmallVectorImpl<Value> &remapped);
- /// Apply all requested operation rewrites. This method is invoked when the
- /// conversion process succeeds.
- void applyRewrites();
+ /// Returns true if the given operation is ignored, and does not need to be
+ /// converted.
+ bool isOpIgnored(Operation *op) const;
+
+ /// Recursively marks the nested operations under 'op' as ignored. This
+ /// removes them from being considered for legalization.
+ void markNestedOpsIgnored(Operation *op);
+
+ //===--------------------------------------------------------------------===//
+ // Type Conversion
+ //===--------------------------------------------------------------------===//
/// Convert the signature of the given block.
FailureOr<Block *> convertBlockSignature(
@@ -620,8 +643,12 @@ struct ConversionPatternRewriterImpl {
convertRegionTypes(Region *region, TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);
+ //===--------------------------------------------------------------------===//
+ // Rewriter Notification Hooks
+ //===--------------------------------------------------------------------===//
+
/// PatternRewriter hook for replacing the results of an operation.
- void replaceOp(Operation *op, ValueRange newValues);
+ void notifyOpReplaced(Operation *op, ValueRange newValues);
/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);
@@ -640,17 +667,9 @@ struct ConversionPatternRewriterImpl {
void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
Location origRegionLoc);
- /// Remap the given operands to those with potentially different types.
- void remapValues(Operation::operand_range operands,
- SmallVectorImpl<Value> &remapped);
-
- /// Returns true if the given operation is ignored, and does not need to be
- /// converted.
- bool isOpIgnored(Operation *op) const;
-
- /// Recursively marks the nested operations under 'op' as ignored. This
- /// removes them from being considered for legalization.
- void markNestedOpsIgnored(Operation *op);
+ //===--------------------------------------------------------------------===//
+ // State
+ //===--------------------------------------------------------------------===//
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
@@ -700,12 +719,6 @@ struct ConversionPatternRewriterImpl {
} // end namespace detail
} // end namespace mlir
-RewriterState ConversionPatternRewriterImpl::getCurrentState() {
- return RewriterState(createdOps.size(), replacements.size(),
- argReplacements.size(), blockActions.size(),
- ignoredOps.size(), rootUpdates.size());
-}
-
/// Detach any operations nested in the given operation from their parent
/// blocks, and erase the given operation. This can be used when the nested
/// operations are scheduled for erasure themselves, so deleting the regions of
@@ -722,6 +735,73 @@ static void detachNestedAndErase(Operation *op) {
op->erase();
}
+void ConversionPatternRewriterImpl::discardRewrites() {
+ // Reset any operations that were updated in place.
+ for (auto &state : rootUpdates)
+ state.resetOperation();
+
+ undoBlockActions();
+
+ // Remove any newly created ops.
+ for (auto *op : llvm::reverse(createdOps))
+ detachNestedAndErase(op);
+}
+
+void ConversionPatternRewriterImpl::applyRewrites() {
+ // Apply all of the rewrites replacements requested during conversion.
+ for (auto &repl : replacements) {
+ for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) {
+ if (auto newValue = repl.newValues[i])
+ repl.op->getResult(i).replaceAllUsesWith(
+ mapping.lookupOrDefault(newValue));
+ }
+
+ // If this operation defines any regions, drop any pending argument
+ // rewrites.
+ if (repl.op->getNumRegions())
+ argConverter.notifyOpRemoved(repl.op);
+ }
+
+ // Apply all of the requested argument replacements.
+ for (BlockArgument arg : argReplacements) {
+ Value repl = mapping.lookupOrDefault(arg);
+ if (repl.isa<BlockArgument>()) {
+ arg.replaceAllUsesWith(repl);
+ continue;
+ }
+
+ // If the replacement value is an operation, we check to make sure that we
+ // don't replace uses that are within the parent operation of the
+ // replacement value.
+ Operation *replOp = repl.cast<OpResult>().getOwner();
+ Block *replBlock = replOp->getBlock();
+ arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+ Operation *user = operand.getOwner();
+ return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
+ });
+ }
+
+ // In a second pass, erase all of the replaced operations in reverse. This
+ // allows processing nested operations before their parent region is
+ // destroyed.
+ for (auto &repl : llvm::reverse(replacements))
+ repl.op->erase();
+
+ argConverter.applyRewrites(mapping);
+
+ // Now that the ops have been erased, also erase dangling blocks.
+ eraseDanglingBlocks();
+}
+
+//===----------------------------------------------------------------------===//
+// State Management
+
+RewriterState ConversionPatternRewriterImpl::getCurrentState() {
+ return RewriterState(createdOps.size(), replacements.size(),
+ argReplacements.size(), blockActions.size(),
+ ignoredOps.size(), rootUpdates.size());
+}
+
void ConversionPatternRewriterImpl::resetState(RewriterState state) {
// Reset any operations that were updated in place.
for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
@@ -810,64 +890,34 @@ void ConversionPatternRewriterImpl::undoBlockActions(
blockActions.resize(numActionsToKeep);
}
-void ConversionPatternRewriterImpl::discardRewrites() {
- // Reset any operations that were updated in place.
- for (auto &state : rootUpdates)
- state.resetOperation();
-
- undoBlockActions();
-
- // Remove any newly created ops.
- for (auto *op : llvm::reverse(createdOps))
- detachNestedAndErase(op);
+void ConversionPatternRewriterImpl::remapValues(
+ Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
+ remapped.reserve(llvm::size(operands));
+ for (Value operand : operands)
+ remapped.push_back(mapping.lookupOrDefault(operand));
}
-void ConversionPatternRewriterImpl::applyRewrites() {
- // Apply all of the rewrites replacements requested during conversion.
- for (auto &repl : replacements) {
- for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) {
- if (auto newValue = repl.newValues[i])
- repl.op->getResult(i).replaceAllUsesWith(
- mapping.lookupOrDefault(newValue));
- }
-
- // If this operation defines any regions, drop any pending argument
- // rewrites.
- if (repl.op->getNumRegions())
- argConverter.notifyOpRemoved(repl.op);
- }
-
- // Apply all of the requested argument replacements.
- for (BlockArgument arg : argReplacements) {
- Value repl = mapping.lookupOrDefault(arg);
- if (repl.isa<BlockArgument>()) {
- arg.replaceAllUsesWith(repl);
- continue;
- }
-
- // If the replacement value is an operation, we check to make sure that we
- // don't replace uses that are within the parent operation of the
- // replacement value.
- Operation *replOp = repl.cast<OpResult>().getOwner();
- Block *replBlock = replOp->getBlock();
- arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
- Operation *user = operand.getOwner();
- return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
- });
- }
-
- // In a second pass, erase all of the replaced operations in reverse. This
- // allows processing nested operations before their parent region is
- // destroyed.
- for (auto &repl : llvm::reverse(replacements))
- repl.op->erase();
-
- argConverter.applyRewrites(mapping);
+bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
+ // Check to see if this operation or its parent were ignored.
+ return ignoredOps.count(op) || ignoredOps.count(op->getParentOp());
+}
- // Now that the ops have been erased, also erase dangling blocks.
- eraseDanglingBlocks();
+void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
+ // Walk this operation and collect nested operations that define non-empty
+ // regions. We mark such operations as 'ignored' so that we know we don't have
+ // to convert them, or their nested ops.
+ if (op->getNumRegions() == 0)
+ return;
+ op->walk([&](Operation *op) {
+ if (llvm::any_of(op->getRegions(),
+ [](Region &region) { return !region.empty(); }))
+ ignoredOps.insert(op);
+ });
}
+//===----------------------------------------------------------------------===//
+// Type Conversion
+
FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
Block *block, TypeConverter &converter,
TypeConverter::SignatureConversion *conversion) {
@@ -907,8 +957,11 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
return newEntry;
}
-void ConversionPatternRewriterImpl::replaceOp(Operation *op,
- ValueRange newValues) {
+//===----------------------------------------------------------------------===//
+// Rewriter Notification Hooks
+
+void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
+ ValueRange newValues) {
assert(newValues.size() == op->getNumResults());
// Create mappings for each of the new result values.
@@ -962,31 +1015,6 @@ void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
assert(succeeded(result) && "expected region to have no unreachable blocks");
}
-void ConversionPatternRewriterImpl::remapValues(
- Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
- remapped.reserve(llvm::size(operands));
- for (Value operand : operands)
- remapped.push_back(mapping.lookupOrDefault(operand));
-}
-
-bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
- // Check to see if this operation or its parent were ignored.
- return ignoredOps.count(op) || ignoredOps.count(op->getParentOp());
-}
-
-void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
- // Walk this operation and collect nested operations that define non-empty
- // regions. We mark such operations as 'ignored' so that we know we don't have
- // to convert them, or their nested ops.
- if (op->getNumRegions() == 0)
- return;
- op->walk([&](Operation *op) {
- if (llvm::any_of(op->getRegions(),
- [](Region &region) { return !region.empty(); }))
- ignoredOps.insert(op);
- });
-}
-
//===----------------------------------------------------------------------===//
// ConversionPatternRewriter
//===----------------------------------------------------------------------===//
@@ -1002,7 +1030,7 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
- impl->replaceOp(op, newValues);
+ impl->notifyOpReplaced(op, newValues);
}
/// PatternRewriter hook for erasing a dead operation. The uses of this
@@ -1014,7 +1042,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
<< "** Erase : '" << op->getName() << "'(" << op << ")\n";
});
SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
- impl->replaceOp(op, nullRepls);
+ impl->notifyOpReplaced(op, nullRepls);
}
void ConversionPatternRewriter::eraseBlock(Block *block) {
@@ -1160,7 +1188,7 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
}
//===----------------------------------------------------------------------===//
-// Conversion Patterns
+// ConversionPattern
//===----------------------------------------------------------------------===//
/// Attempt to match and rewrite the IR root at the specified operation.
@@ -1234,6 +1262,10 @@ class OperationLegalizer {
RewriterState &state,
RewriterState &newState);
+ //===--------------------------------------------------------------------===//
+ // Cost Model
+ //===--------------------------------------------------------------------===//
+
/// Build an optimistic legalization graph given the provided patterns. This
/// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
/// patterns for operations that are not directly legal, but may be
@@ -1528,9 +1560,8 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions(
for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
++i) {
auto &action = impl.blockActions[i];
- if (action.kind ==
- ConversionPatternRewriterImpl::BlockActionKind::TypeConversion ||
- action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase)
+ if (action.kind == BlockActionKind::TypeConversion ||
+ action.kind == BlockActionKind::Erase)
continue;
// Only check blocks outside of the current operation.
Operation *parentOp = action.block->getParentOp();
@@ -1599,6 +1630,9 @@ LogicalResult OperationLegalizer::legalizePatternRootUpdates(
return success();
}
+//===----------------------------------------------------------------------===//
+// Cost Model
+
void OperationLegalizer::buildLegalizationGraph(
LegalizationPatterns &anyOpLegalizerPatterns,
DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
More information about the Mlir-commits mailing list