[Mlir-commits] [mlir] e6a343e - [mlir][DialectConversion][NFC] Add comment blocks and organize a bit of the code

River Riddle llvmlistbot at llvm.org
Wed Jun 24 17:42:17 PDT 2020


Author: River Riddle
Date: 2020-06-24T17:42:10-07:00
New Revision: e6a343e491d4ee52b4085bf2b2c24669f1f9a6ce

URL: https://github.com/llvm/llvm-project/commit/e6a343e491d4ee52b4085bf2b2c24669f1f9a6ce
DIFF: https://github.com/llvm/llvm-project/commit/e6a343e491d4ee52b4085bf2b2c24669f1f9a6ce.diff

LOG: [mlir][DialectConversion][NFC] Add comment blocks and organize a bit of the code

This helps improve the readability when scrolling through the many functions of ConversionPatternRewriterImpl.

Added: 
    

Modified: 
    mlir/lib/Transforms/DialectConversion.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index ecebe61d025f..60c9e78b7a69 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -450,7 +450,7 @@ void ArgConverter::insertConversion(Block *newBlock,
 }
 
 //===----------------------------------------------------------------------===//
-// ConversionPatternRewriterImpl
+// Rewriter and Transation State
 //===----------------------------------------------------------------------===//
 namespace {
 /// This class contains a snapshot of the current conversion rewriter state.
@@ -515,74 +515,89 @@ class OperationTransactionState {
   SmallVector<Value, 8> operands;
   SmallVector<Block *, 2> successors;
 };
-} // end anonymous namespace
 
-namespace mlir {
-namespace detail {
-struct ConversionPatternRewriterImpl {
-  /// This class represents one requested operation replacement via 'replaceOp'.
-  struct OpReplacement {
-    OpReplacement() = default;
-    OpReplacement(Operation *op, ValueRange newValues)
-        : op(op), newValues(newValues.begin(), newValues.end()) {}
-
-    Operation *op;
-    SmallVector<Value, 2> newValues;
-  };
+/// This class represents one requested operation replacement via 'replaceOp'.
+struct OpReplacement {
+  OpReplacement() = default;
+  OpReplacement(Operation *op, ValueRange newValues)
+      : op(op), newValues(newValues.begin(), newValues.end()) {}
 
-  /// The kind of the block action performed during the rewrite.  Actions can be
-  /// undone if the conversion fails.
-  enum class BlockActionKind { Create, Erase, Move, Split, TypeConversion };
+  Operation *op;
+  SmallVector<Value, 2> newValues;
+};
 
-  /// Original position of the given block in its parent region.  We cannot use
-  /// a region iterator because it could have been invalidated by other region
-  /// operations since the position was stored.
-  struct BlockPosition {
-    Region *region;
-    Region::iterator::difference_type position;
-  };
+/// The kind of the block action performed during the rewrite.  Actions can be
+/// undone if the conversion fails.
+enum class BlockActionKind { Create, Erase, Move, Split, TypeConversion };
 
-  /// The storage class for an undoable block action (one of BlockActionKind),
-  /// contains the information necessary to undo this action.
-  struct BlockAction {
-    static BlockAction getCreate(Block *block) {
-      return {BlockActionKind::Create, block, {}};
-    }
-    static BlockAction getErase(Block *block, BlockPosition originalPos) {
-      return {BlockActionKind::Erase, block, {originalPos}};
-    }
-    static BlockAction getMove(Block *block, BlockPosition originalPos) {
-      return {BlockActionKind::Move, block, {originalPos}};
-    }
-    static BlockAction getSplit(Block *block, Block *originalBlock) {
-      BlockAction action{BlockActionKind::Split, block, {}};
-      action.originalBlock = originalBlock;
-      return action;
-    }
-    static BlockAction getTypeConversion(Block *block) {
-      return BlockAction{BlockActionKind::TypeConversion, block, {}};
-    }
+/// Original position of the given block in its parent region.  We cannot use
+/// a region iterator because it could have been invalidated by other region
+/// operations since the position was stored.
+struct BlockPosition {
+  Region *region;
+  Region::iterator::difference_type position;
+};
 
-    // The action kind.
-    BlockActionKind kind;
-
-    // A pointer to the block that was created by the action.
-    Block *block;
-
-    union {
-      // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
-      // contains a pointer to the region that originally contained the block as
-      // well as the position of the block in that region.
-      BlockPosition originalPosition;
-      // In use if kind == BlockActionKind::Split and contains a pointer to the
-      // block that was split into two parts.
-      Block *originalBlock;
-    };
+/// The storage class for an undoable block action (one of BlockActionKind),
+/// contains the information necessary to undo this action.
+struct BlockAction {
+  static BlockAction getCreate(Block *block) {
+    return {BlockActionKind::Create, block, {}};
+  }
+  static BlockAction getErase(Block *block, BlockPosition originalPos) {
+    return {BlockActionKind::Erase, block, {originalPos}};
+  }
+  static BlockAction getMove(Block *block, BlockPosition originalPos) {
+    return {BlockActionKind::Move, block, {originalPos}};
+  }
+  static BlockAction getSplit(Block *block, Block *originalBlock) {
+    BlockAction action{BlockActionKind::Split, block, {}};
+    action.originalBlock = originalBlock;
+    return action;
+  }
+  static BlockAction getTypeConversion(Block *block) {
+    return BlockAction{BlockActionKind::TypeConversion, block, {}};
+  }
+
+  // The action kind.
+  BlockActionKind kind;
+
+  // A pointer to the block that was created by the action.
+  Block *block;
+
+  union {
+    // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
+    // contains a pointer to the region that originally contained the block as
+    // well as the position of the block in that region.
+    BlockPosition originalPosition;
+    // In use if kind == BlockActionKind::Split and contains a pointer to the
+    // block that was split into two parts.
+    Block *originalBlock;
   };
+};
+} // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// ConversionPatternRewriterImpl
+//===----------------------------------------------------------------------===//
+namespace mlir {
+namespace detail {
+struct ConversionPatternRewriterImpl {
   ConversionPatternRewriterImpl(PatternRewriter &rewriter)
       : argConverter(rewriter) {}
 
+  /// Cleanup and destroy any generated rewrite operations. This method is
+  /// invoked when the conversion process fails.
+  void discardRewrites();
+
+  /// Apply all requested operation rewrites. This method is invoked when the
+  /// conversion process succeeds.
+  void applyRewrites();
+
+  //===--------------------------------------------------------------------===//
+  // State Management
+  //===--------------------------------------------------------------------===//
+
   /// Return the current state of the rewriter.
   RewriterState getCurrentState();
 
@@ -597,13 +612,21 @@ struct ConversionPatternRewriterImpl {
   /// "numActionsToKeep" actions remains.
   void undoBlockActions(unsigned numActionsToKeep = 0);
 
-  /// Cleanup and destroy any generated rewrite operations. This method is
-  /// invoked when the conversion process fails.
-  void discardRewrites();
+  /// Remap the given operands to those with potentially different types.
+  void remapValues(Operation::operand_range operands,
+                   SmallVectorImpl<Value> &remapped);
 
-  /// Apply all requested operation rewrites. This method is invoked when the
-  /// conversion process succeeds.
-  void applyRewrites();
+  /// Returns true if the given operation is ignored, and does not need to be
+  /// converted.
+  bool isOpIgnored(Operation *op) const;
+
+  /// Recursively marks the nested operations under 'op' as ignored. This
+  /// removes them from being considered for legalization.
+  void markNestedOpsIgnored(Operation *op);
+
+  //===--------------------------------------------------------------------===//
+  // Type Conversion
+  //===--------------------------------------------------------------------===//
 
   /// Convert the signature of the given block.
   FailureOr<Block *> convertBlockSignature(
@@ -620,8 +643,12 @@ struct ConversionPatternRewriterImpl {
   convertRegionTypes(Region *region, TypeConverter &converter,
                      TypeConverter::SignatureConversion *entryConversion);
 
+  //===--------------------------------------------------------------------===//
+  // Rewriter Notification Hooks
+  //===--------------------------------------------------------------------===//
+
   /// PatternRewriter hook for replacing the results of an operation.
-  void replaceOp(Operation *op, ValueRange newValues);
+  void notifyOpReplaced(Operation *op, ValueRange newValues);
 
   /// Notifies that a block is about to be erased.
   void notifyBlockIsBeingErased(Block *block);
@@ -640,17 +667,9 @@ struct ConversionPatternRewriterImpl {
   void notifyRegionWasClonedBefore(iterator_range<Region::iterator> &blocks,
                                    Location origRegionLoc);
 
-  /// Remap the given operands to those with potentially different types.
-  void remapValues(Operation::operand_range operands,
-                   SmallVectorImpl<Value> &remapped);
-
-  /// Returns true if the given operation is ignored, and does not need to be
-  /// converted.
-  bool isOpIgnored(Operation *op) const;
-
-  /// Recursively marks the nested operations under 'op' as ignored. This
-  /// removes them from being considered for legalization.
-  void markNestedOpsIgnored(Operation *op);
+  //===--------------------------------------------------------------------===//
+  // State
+  //===--------------------------------------------------------------------===//
 
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
@@ -700,12 +719,6 @@ struct ConversionPatternRewriterImpl {
 } // end namespace detail
 } // end namespace mlir
 
-RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(createdOps.size(), replacements.size(),
-                       argReplacements.size(), blockActions.size(),
-                       ignoredOps.size(), rootUpdates.size());
-}
-
 /// Detach any operations nested in the given operation from their parent
 /// blocks, and erase the given operation. This can be used when the nested
 /// operations are scheduled for erasure themselves, so deleting the regions of
@@ -722,6 +735,73 @@ static void detachNestedAndErase(Operation *op) {
   op->erase();
 }
 
+void ConversionPatternRewriterImpl::discardRewrites() {
+  // Reset any operations that were updated in place.
+  for (auto &state : rootUpdates)
+    state.resetOperation();
+
+  undoBlockActions();
+
+  // Remove any newly created ops.
+  for (auto *op : llvm::reverse(createdOps))
+    detachNestedAndErase(op);
+}
+
+void ConversionPatternRewriterImpl::applyRewrites() {
+  // Apply all of the rewrites replacements requested during conversion.
+  for (auto &repl : replacements) {
+    for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) {
+      if (auto newValue = repl.newValues[i])
+        repl.op->getResult(i).replaceAllUsesWith(
+            mapping.lookupOrDefault(newValue));
+    }
+
+    // If this operation defines any regions, drop any pending argument
+    // rewrites.
+    if (repl.op->getNumRegions())
+      argConverter.notifyOpRemoved(repl.op);
+  }
+
+  // Apply all of the requested argument replacements.
+  for (BlockArgument arg : argReplacements) {
+    Value repl = mapping.lookupOrDefault(arg);
+    if (repl.isa<BlockArgument>()) {
+      arg.replaceAllUsesWith(repl);
+      continue;
+    }
+
+    // If the replacement value is an operation, we check to make sure that we
+    // don't replace uses that are within the parent operation of the
+    // replacement value.
+    Operation *replOp = repl.cast<OpResult>().getOwner();
+    Block *replBlock = replOp->getBlock();
+    arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+      Operation *user = operand.getOwner();
+      return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
+    });
+  }
+
+  // In a second pass, erase all of the replaced operations in reverse. This
+  // allows processing nested operations before their parent region is
+  // destroyed.
+  for (auto &repl : llvm::reverse(replacements))
+    repl.op->erase();
+
+  argConverter.applyRewrites(mapping);
+
+  // Now that the ops have been erased, also erase dangling blocks.
+  eraseDanglingBlocks();
+}
+
+//===----------------------------------------------------------------------===//
+// State Management
+
+RewriterState ConversionPatternRewriterImpl::getCurrentState() {
+  return RewriterState(createdOps.size(), replacements.size(),
+                       argReplacements.size(), blockActions.size(),
+                       ignoredOps.size(), rootUpdates.size());
+}
+
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   // Reset any operations that were updated in place.
   for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
@@ -810,64 +890,34 @@ void ConversionPatternRewriterImpl::undoBlockActions(
   blockActions.resize(numActionsToKeep);
 }
 
-void ConversionPatternRewriterImpl::discardRewrites() {
-  // Reset any operations that were updated in place.
-  for (auto &state : rootUpdates)
-    state.resetOperation();
-
-  undoBlockActions();
-
-  // Remove any newly created ops.
-  for (auto *op : llvm::reverse(createdOps))
-    detachNestedAndErase(op);
+void ConversionPatternRewriterImpl::remapValues(
+    Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
+  remapped.reserve(llvm::size(operands));
+  for (Value operand : operands)
+    remapped.push_back(mapping.lookupOrDefault(operand));
 }
 
-void ConversionPatternRewriterImpl::applyRewrites() {
-  // Apply all of the rewrites replacements requested during conversion.
-  for (auto &repl : replacements) {
-    for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) {
-      if (auto newValue = repl.newValues[i])
-        repl.op->getResult(i).replaceAllUsesWith(
-            mapping.lookupOrDefault(newValue));
-    }
-
-    // If this operation defines any regions, drop any pending argument
-    // rewrites.
-    if (repl.op->getNumRegions())
-      argConverter.notifyOpRemoved(repl.op);
-  }
-
-  // Apply all of the requested argument replacements.
-  for (BlockArgument arg : argReplacements) {
-    Value repl = mapping.lookupOrDefault(arg);
-    if (repl.isa<BlockArgument>()) {
-      arg.replaceAllUsesWith(repl);
-      continue;
-    }
-
-    // If the replacement value is an operation, we check to make sure that we
-    // don't replace uses that are within the parent operation of the
-    // replacement value.
-    Operation *replOp = repl.cast<OpResult>().getOwner();
-    Block *replBlock = replOp->getBlock();
-    arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
-      Operation *user = operand.getOwner();
-      return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
-    });
-  }
-
-  // In a second pass, erase all of the replaced operations in reverse. This
-  // allows processing nested operations before their parent region is
-  // destroyed.
-  for (auto &repl : llvm::reverse(replacements))
-    repl.op->erase();
-
-  argConverter.applyRewrites(mapping);
+bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
+  // Check to see if this operation or its parent were ignored.
+  return ignoredOps.count(op) || ignoredOps.count(op->getParentOp());
+}
 
-  // Now that the ops have been erased, also erase dangling blocks.
-  eraseDanglingBlocks();
+void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
+  // Walk this operation and collect nested operations that define non-empty
+  // regions. We mark such operations as 'ignored' so that we know we don't have
+  // to convert them, or their nested ops.
+  if (op->getNumRegions() == 0)
+    return;
+  op->walk([&](Operation *op) {
+    if (llvm::any_of(op->getRegions(),
+                     [](Region &region) { return !region.empty(); }))
+      ignoredOps.insert(op);
+  });
 }
 
+//===----------------------------------------------------------------------===//
+// Type Conversion
+
 FailureOr<Block *> ConversionPatternRewriterImpl::convertBlockSignature(
     Block *block, TypeConverter &converter,
     TypeConverter::SignatureConversion *conversion) {
@@ -907,8 +957,11 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
   return newEntry;
 }
 
-void ConversionPatternRewriterImpl::replaceOp(Operation *op,
-                                              ValueRange newValues) {
+//===----------------------------------------------------------------------===//
+// Rewriter Notification Hooks
+
+void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
+                                                     ValueRange newValues) {
   assert(newValues.size() == op->getNumResults());
 
   // Create mappings for each of the new result values.
@@ -962,31 +1015,6 @@ void ConversionPatternRewriterImpl::notifyRegionWasClonedBefore(
   assert(succeeded(result) && "expected region to have no unreachable blocks");
 }
 
-void ConversionPatternRewriterImpl::remapValues(
-    Operation::operand_range operands, SmallVectorImpl<Value> &remapped) {
-  remapped.reserve(llvm::size(operands));
-  for (Value operand : operands)
-    remapped.push_back(mapping.lookupOrDefault(operand));
-}
-
-bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
-  // Check to see if this operation or its parent were ignored.
-  return ignoredOps.count(op) || ignoredOps.count(op->getParentOp());
-}
-
-void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
-  // Walk this operation and collect nested operations that define non-empty
-  // regions. We mark such operations as 'ignored' so that we know we don't have
-  // to convert them, or their nested ops.
-  if (op->getNumRegions() == 0)
-    return;
-  op->walk([&](Operation *op) {
-    if (llvm::any_of(op->getRegions(),
-                     [](Region &region) { return !region.empty(); }))
-      ignoredOps.insert(op);
-  });
-}
-
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriter
 //===----------------------------------------------------------------------===//
@@ -1002,7 +1030,7 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
     impl->logger.startLine()
         << "** Replace : '" << op->getName() << "'(" << op << ")\n";
   });
-  impl->replaceOp(op, newValues);
+  impl->notifyOpReplaced(op, newValues);
 }
 
 /// PatternRewriter hook for erasing a dead operation. The uses of this
@@ -1014,7 +1042,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
         << "** Erase   : '" << op->getName() << "'(" << op << ")\n";
   });
   SmallVector<Value, 1> nullRepls(op->getNumResults(), nullptr);
-  impl->replaceOp(op, nullRepls);
+  impl->notifyOpReplaced(op, nullRepls);
 }
 
 void ConversionPatternRewriter::eraseBlock(Block *block) {
@@ -1160,7 +1188,7 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
 }
 
 //===----------------------------------------------------------------------===//
-// Conversion Patterns
+// ConversionPattern
 //===----------------------------------------------------------------------===//
 
 /// Attempt to match and rewrite the IR root at the specified operation.
@@ -1234,6 +1262,10 @@ class OperationLegalizer {
                                            RewriterState &state,
                                            RewriterState &newState);
 
+  //===--------------------------------------------------------------------===//
+  // Cost Model
+  //===--------------------------------------------------------------------===//
+
   /// Build an optimistic legalization graph given the provided patterns. This
   /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
   /// patterns for operations that are not directly legal, but may be
@@ -1528,9 +1560,8 @@ LogicalResult OperationLegalizer::legalizePatternBlockActions(
   for (int i = state.numBlockActions, e = newState.numBlockActions; i != e;
        ++i) {
     auto &action = impl.blockActions[i];
-    if (action.kind ==
-            ConversionPatternRewriterImpl::BlockActionKind::TypeConversion ||
-        action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase)
+    if (action.kind == BlockActionKind::TypeConversion ||
+        action.kind == BlockActionKind::Erase)
       continue;
     // Only check blocks outside of the current operation.
     Operation *parentOp = action.block->getParentOp();
@@ -1599,6 +1630,9 @@ LogicalResult OperationLegalizer::legalizePatternRootUpdates(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Cost Model
+
 void OperationLegalizer::buildLegalizationGraph(
     LegalizationPatterns &anyOpLegalizerPatterns,
     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {


        


More information about the Mlir-commits mailing list