[Mlir-commits] [mlir] 49f39b3 - [mlir][Transforms][NFC] Simplify function signatures (#155997)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 29 06:16:22 PDT 2025


Author: Matthias Springer
Date: 2025-08-29T15:16:18+02:00
New Revision: 49f39b349db181ca516eb0253462105ff0e2c634

URL: https://github.com/llvm/llvm-project/commit/49f39b349db181ca516eb0253462105ff0e2c634
DIFF: https://github.com/llvm/llvm-project/commit/49f39b349db181ca516eb0253462105ff0e2c634.diff

LOG: [mlir][Transforms][NFC] Simplify function signatures (#155997)

Many internal functions take a `ConversionPatternRewriter &` or
`ConversionPatternRewriterImpl &` as a parameter. There's only a single
instance of these classes, so it's better to store the reference in a
field. This commit is in preparation of another PR that will require
access to `ConversionPatternRewriter` in additional helper functions.

Note: Public API does not change.

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b6a216adfdd25..c0685f54731d5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,9 +848,10 @@ static bool hasRewrite(R &&rewrites, Block *block) {
 namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
-  explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
+  explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
                                          const ConversionConfig &config)
-      : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {}
+      : rewriter(rewriter), config(config),
+        notifyingRewriter(rewriter.getContext(), config.listener) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -887,8 +888,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// is the tag used when describing a value within a diagnostic, e.g.
   /// "operand".
   LogicalResult remapValues(StringRef valueDiagTag,
-                            std::optional<Location> inputLoc,
-                            PatternRewriter &rewriter, ValueRange values,
+                            std::optional<Location> inputLoc, ValueRange values,
                             SmallVector<ValueVector> &remapped);
 
   /// Return "true" if the given operation is ignored, and does not need to be
@@ -918,8 +918,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
   /// Convert the types of block arguments within the given region.
   FailureOr<Block *>
-  convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
-                     const TypeConverter &converter,
+  convertRegionTypes(Region *region, const TypeConverter &converter,
                      TypeConverter::SignatureConversion *entryConversion);
 
   /// Apply the given signature conversion on the given block. The new block
@@ -929,8 +928,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// translate between the origin argument types and those specified in the
   /// signature conversion.
   Block *applySignatureConversion(
-      ConversionPatternRewriter &rewriter, Block *block,
-      const TypeConverter *converter,
+      Block *block, const TypeConverter *converter,
       TypeConverter::SignatureConversion &signatureConversion);
 
   /// Replace the results of the given operation with the given values and
@@ -1060,8 +1058,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // State
   //===--------------------------------------------------------------------===//
 
-  /// MLIR context.
-  MLIRContext *context;
+  /// The rewriter that is used to perform the conversion.
+  ConversionPatternRewriter &rewriter;
 
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
@@ -1258,16 +1256,17 @@ void UnresolvedMaterializationRewrite::rollback() {
 }
 
 void ConversionPatternRewriterImpl::applyRewrites() {
-  // Commit all rewrites.
-  IRRewriter rewriter(context, config.listener);
+  // Commit all rewrites. Use a new rewriter, so the modifications are not
+  // tracked for rollback purposes etc.
+  IRRewriter irRewriter(rewriter.getContext(), config.listener);
   // Note: New rewrites may be added during the "commit" phase and the
   // `rewrites` vector may reallocate.
   for (size_t i = 0; i < rewrites.size(); ++i)
-    rewrites[i]->commit(rewriter);
+    rewrites[i]->commit(irRewriter);
 
   // Clean up all rewrites.
   SingleEraseRewriter eraseRewriter(
-      context, /*opErasedCallback=*/[&](Operation *op) {
+      rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
         if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
           unresolvedMaterializations.erase(castOp);
       });
@@ -1412,8 +1411,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
 }
 
 LogicalResult ConversionPatternRewriterImpl::remapValues(
-    StringRef valueDiagTag, std::optional<Location> inputLoc,
-    PatternRewriter &rewriter, ValueRange values,
+    StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
     SmallVector<ValueVector> &remapped) {
   remapped.reserve(llvm::size(values));
 
@@ -1484,8 +1482,7 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
 //===----------------------------------------------------------------------===//
 
 FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
-    ConversionPatternRewriter &rewriter, Region *region,
-    const TypeConverter &converter,
+    Region *region, const TypeConverter &converter,
     TypeConverter::SignatureConversion *entryConversion) {
   regionToConverter[region] = &converter;
   if (region->empty())
@@ -1500,25 +1497,23 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
     if (!conversion)
       return failure();
     // Convert the block with the computed signature.
-    applySignatureConversion(rewriter, &block, &converter, *conversion);
+    applySignatureConversion(&block, &converter, *conversion);
   }
 
   // Convert the entry block. If an entry signature conversion was provided,
   // use that one. Otherwise, compute the signature with the type converter.
   if (entryConversion)
-    return applySignatureConversion(rewriter, &region->front(), &converter,
+    return applySignatureConversion(&region->front(), &converter,
                                     *entryConversion);
   std::optional<TypeConverter::SignatureConversion> conversion =
       converter.convertBlockSignature(&region->front());
   if (!conversion)
     return failure();
-  return applySignatureConversion(rewriter, &region->front(), &converter,
-                                  *conversion);
+  return applySignatureConversion(&region->front(), &converter, *conversion);
 }
 
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
-    ConversionPatternRewriter &rewriter, Block *block,
-    const TypeConverter *converter,
+    Block *block, const TypeConverter *converter,
     TypeConverter::SignatureConversion &signatureConversion) {
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
   // A block cannot be converted multiple times.
@@ -2023,7 +2018,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
 ConversionPatternRewriter::ConversionPatternRewriter(
     MLIRContext *ctx, const ConversionConfig &config)
     : PatternRewriter(ctx),
-      impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
+      impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
   setListener(impl.get());
 }
 
@@ -2100,7 +2095,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
   assert(!impl->wasOpReplaced(block->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->applySignatureConversion(*this, block, converter, conversion);
+  return impl->applySignatureConversion(block, converter, conversion);
 }
 
 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -2109,7 +2104,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
   assert(!impl->wasOpReplaced(region->getParentOp()) &&
          "attempting to apply a signature conversion to a block within a "
          "replaced/erased op");
-  return impl->convertRegionTypes(*this, region, converter, entryConversion);
+  return impl->convertRegionTypes(region, converter, entryConversion);
 }
 
 void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -2128,7 +2123,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
 
 Value ConversionPatternRewriter::getRemappedValue(Value key) {
   SmallVector<ValueVector> remappedValues;
-  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
+  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
                                remappedValues)))
     return nullptr;
   assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
@@ -2141,7 +2136,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
   if (keys.empty())
     return success();
   SmallVector<ValueVector> remapped;
-  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+  if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
                                remapped)))
     return failure();
   for (const auto &values : remapped) {
@@ -2288,7 +2283,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
 
   // Remap the operands of the operation.
   SmallVector<ValueVector> remapped;
-  if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
+  if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
                                       op->getOperands(), remapped))) {
     return failure();
   }
@@ -2310,7 +2305,8 @@ class OperationLegalizer {
 public:
   using LegalizationAction = ConversionTarget::LegalizationAction;
 
-  OperationLegalizer(const ConversionTarget &targetInfo,
+  OperationLegalizer(ConversionPatternRewriter &rewriter,
+                     const ConversionTarget &targetInfo,
                      const FrozenRewritePatternSet &patterns);
 
   /// Returns true if the given operation is known to be illegal on the target.
@@ -2318,29 +2314,25 @@ class OperationLegalizer {
 
   /// Attempt to legalize the given operation. Returns success if the operation
   /// was legalized, failure otherwise.
-  LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
+  LogicalResult legalize(Operation *op);
 
   /// Returns the conversion target in use by the legalizer.
   const ConversionTarget &getTarget() { return target; }
 
 private:
   /// Attempt to legalize the given operation by folding it.
-  LogicalResult legalizeWithFold(Operation *op,
-                                 ConversionPatternRewriter &rewriter);
+  LogicalResult legalizeWithFold(Operation *op);
 
   /// Attempt to legalize the given operation by applying a pattern. Returns
   /// success if the operation was legalized, failure otherwise.
-  LogicalResult legalizeWithPattern(Operation *op,
-                                    ConversionPatternRewriter &rewriter);
+  LogicalResult legalizeWithPattern(Operation *op);
 
   /// Return true if the given pattern may be applied to the given operation,
   /// false otherwise.
-  bool canApplyPattern(Operation *op, const Pattern &pattern,
-                       ConversionPatternRewriter &rewriter);
+  bool canApplyPattern(Operation *op, const Pattern &pattern);
 
   /// Legalize the resultant IR after successfully applying the given pattern.
   LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
-                                      ConversionPatternRewriter &rewriter,
                                       const RewriterState &curState,
                                       const SetVector<Operation *> &newOps,
                                       const SetVector<Operation *> &modifiedOps,
@@ -2349,18 +2341,12 @@ class OperationLegalizer {
   /// Legalizes the actions registered during the execution of a pattern.
   LogicalResult
   legalizePatternBlockRewrites(Operation *op,
-                               ConversionPatternRewriter &rewriter,
-                               ConversionPatternRewriterImpl &impl,
                                const SetVector<Block *> &insertedBlocks,
                                const SetVector<Operation *> &newOps);
   LogicalResult
-  legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
-                                   ConversionPatternRewriterImpl &impl,
-                                   const SetVector<Operation *> &newOps);
+  legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
   LogicalResult
-  legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
-                             ConversionPatternRewriterImpl &impl,
-                             const SetVector<Operation *> &modifiedOps);
+  legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
 
   //===--------------------------------------------------------------------===//
   // Cost Model
@@ -2403,6 +2389,9 @@ class OperationLegalizer {
   /// The current set of patterns that have been applied.
   SmallPtrSet<const Pattern *, 8> appliedPatterns;
 
+  /// The rewriter to use when converting operations.
+  ConversionPatternRewriter &rewriter;
+
   /// The legalization information provided by the target.
   const ConversionTarget &target;
 
@@ -2411,9 +2400,10 @@ class OperationLegalizer {
 };
 } // namespace
 
-OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
+OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
+                                       const ConversionTarget &targetInfo,
                                        const FrozenRewritePatternSet &patterns)
-    : target(targetInfo), applicator(patterns) {
+    : rewriter(rewriter), target(targetInfo), applicator(patterns) {
   // The set of patterns that can be applied to illegal operations to transform
   // them into legal ones.
   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2427,9 +2417,7 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
   return target.isIllegal(op);
 }
 
-LogicalResult
-OperationLegalizer::legalize(Operation *op,
-                             ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalize(Operation *op) {
 #ifndef NDEBUG
   const char *logLineComment =
       "//===-------------------------------------------===//\n";
@@ -2495,7 +2483,7 @@ OperationLegalizer::legalize(Operation *op,
   // is 'BeforePatterns'. 'Never' will skip this.
   const ConversionConfig &config = rewriter.getConfig();
   if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
-    if (succeeded(legalizeWithFold(op, rewriter))) {
+    if (succeeded(legalizeWithFold(op))) {
       LLVM_DEBUG({
         logSuccess(logger, "operation was folded");
         logger.startLine() << logLineComment;
@@ -2505,7 +2493,7 @@ OperationLegalizer::legalize(Operation *op,
   }
 
   // Otherwise, we need to apply a legalization pattern to this operation.
-  if (succeeded(legalizeWithPattern(op, rewriter))) {
+  if (succeeded(legalizeWithPattern(op))) {
     LLVM_DEBUG({
       logSuccess(logger, "");
       logger.startLine() << logLineComment;
@@ -2516,7 +2504,7 @@ OperationLegalizer::legalize(Operation *op,
   // If the operation can't be legalized via patterns, try to fold it in-place
   // if the folding mode is 'AfterPatterns'.
   if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
-    if (succeeded(legalizeWithFold(op, rewriter))) {
+    if (succeeded(legalizeWithFold(op))) {
       LLVM_DEBUG({
         logSuccess(logger, "operation was folded");
         logger.startLine() << logLineComment;
@@ -2541,9 +2529,7 @@ static T moveAndReset(T &obj) {
   return result;
 }
 
-LogicalResult
-OperationLegalizer::legalizeWithFold(Operation *op,
-                                     ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
   auto &rewriterImpl = rewriter.getImpl();
   LLVM_DEBUG({
     rewriterImpl.logger.startLine() << "* Fold {\n";
@@ -2577,14 +2563,14 @@ OperationLegalizer::legalizeWithFold(Operation *op,
   // An empty list of replacement values indicates that the fold was in-place.
   // As the operation changed, a new legalization needs to be attempted.
   if (replacementValues.empty())
-    return legalize(op, rewriter);
+    return legalize(op);
 
   // Insert a replacement for 'op' with the folded replacement values.
   rewriter.replaceOp(op, replacementValues);
 
   // Recursively legalize any new constant operations.
   for (Operation *newOp : newOps) {
-    if (failed(legalize(newOp, rewriter))) {
+    if (failed(legalize(newOp))) {
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "failed to legalize generated constant '{0}'",
                             newOp->getName()));
@@ -2629,9 +2615,7 @@ reportNewIrLegalizationFatalError(const Pattern &pattern,
       llvm::join(insertedBlockNames, ", ") + "}");
 }
 
-LogicalResult
-OperationLegalizer::legalizeWithPattern(Operation *op,
-                                        ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
   auto &rewriterImpl = rewriter.getImpl();
   const ConversionConfig &config = rewriter.getConfig();
 
@@ -2663,7 +2647,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
 
   // Functor that returns if the given pattern may be applied.
   auto canApply = [&](const Pattern &pattern) {
-    bool canApply = canApplyPattern(op, pattern, rewriter);
+    bool canApply = canApplyPattern(op, pattern);
     if (canApply && config.listener)
       config.listener->notifyPatternBegin(pattern, op);
     return canApply;
@@ -2728,7 +2712,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
         moveAndReset(rewriterImpl.patternModifiedOps);
     SetVector<Block *> insertedBlocks =
         moveAndReset(rewriterImpl.patternInsertedBlocks);
-    auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
+    auto result = legalizePatternResult(op, pattern, curState, newOps,
                                         modifiedOps, insertedBlocks);
     appliedPatterns.erase(&pattern);
     if (failed(result)) {
@@ -2747,8 +2731,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
                                     onSuccess);
 }
 
-bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
-                                         ConversionPatternRewriter &rewriter) {
+bool OperationLegalizer::canApplyPattern(Operation *op,
+                                         const Pattern &pattern) {
   LLVM_DEBUG({
     auto &os = rewriter.getImpl().logger;
     os.getOStream() << "\n";
@@ -2770,8 +2754,8 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
 }
 
 LogicalResult OperationLegalizer::legalizePatternResult(
-    Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
-    const RewriterState &curState, const SetVector<Operation *> &newOps,
+    Operation *op, const Pattern &pattern, const RewriterState &curState,
+    const SetVector<Operation *> &newOps,
     const SetVector<Operation *> &modifiedOps,
     const SetVector<Block *> &insertedBlocks) {
   auto &impl = rewriter.getImpl();
@@ -2792,10 +2776,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
 
   // Legalize each of the actions registered during application.
-  if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
-                                          newOps)) ||
-      failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
-      failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
+  if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
+      failed(legalizePatternRootUpdates(modifiedOps)) ||
+      failed(legalizePatternCreatedOperations(newOps))) {
     return failure();
   }
 
@@ -2804,10 +2787,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
 }
 
 LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
-    Operation *op, ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &impl,
-    const SetVector<Block *> &insertedBlocks,
+    Operation *op, const SetVector<Block *> &insertedBlocks,
     const SetVector<Operation *> &newOps) {
+  ConversionPatternRewriterImpl &impl = rewriter.getImpl();
   SmallPtrSet<Operation *, 16> alreadyLegalized;
 
   // If the pattern moved or created any blocks, make sure the types of block
@@ -2831,7 +2813,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
                                            "block"));
         return failure();
       }
-      impl.applySignatureConversion(rewriter, block, converter, *conversion);
+      impl.applySignatureConversion(block, converter, *conversion);
       continue;
     }
 
@@ -2840,7 +2822,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     // operation, and blocks in regions created by this pattern will already be
     // legalized later on.
     if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
-      if (failed(legalize(parentOp, rewriter))) {
+      if (failed(legalize(parentOp))) {
         LLVM_DEBUG(logFailure(
             impl.logger, "operation '{0}'({1}) became illegal after rewrite",
             parentOp->getName(), parentOp));
@@ -2852,11 +2834,10 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
 }
 
 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
-    ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
     const SetVector<Operation *> &newOps) {
   for (Operation *op : newOps) {
-    if (failed(legalize(op, rewriter))) {
-      LLVM_DEBUG(logFailure(impl.logger,
+    if (failed(legalize(op))) {
+      LLVM_DEBUG(logFailure(rewriter.getImpl().logger,
                             "failed to legalize generated operation '{0}'({1})",
                             op->getName(), op));
       return failure();
@@ -2866,13 +2847,13 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
 }
 
 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
-    ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
     const SetVector<Operation *> &modifiedOps) {
   for (Operation *op : modifiedOps) {
-    if (failed(legalize(op, rewriter))) {
-      LLVM_DEBUG(logFailure(
-          impl.logger, "failed to legalize operation updated in-place '{0}'",
-          op->getName()));
+    if (failed(legalize(op))) {
+      LLVM_DEBUG(
+          logFailure(rewriter.getImpl().logger,
+                     "failed to legalize operation updated in-place '{0}'",
+                     op->getName()));
       return failure();
     }
   }
@@ -3092,21 +3073,22 @@ namespace mlir {
 // rewrite patterns. The conversion behaves differently depending on the
 // conversion mode.
 struct OperationConverter {
-  explicit OperationConverter(const ConversionTarget &target,
+  explicit OperationConverter(MLIRContext *ctx, const ConversionTarget &target,
                               const FrozenRewritePatternSet &patterns,
                               const ConversionConfig &config,
                               OpConversionMode mode)
-      : config(config), opLegalizer(target, patterns), mode(mode) {}
+      : rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
+        mode(mode) {}
 
   /// Converts the given operations to the conversion target.
   LogicalResult convertOperations(ArrayRef<Operation *> ops);
 
 private:
   /// Converts an operation with the given rewriter.
-  LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
+  LogicalResult convert(Operation *op);
 
-  /// Dialect conversion configuration.
-  ConversionConfig config;
+  /// The rewriter to use when converting operations.
+  ConversionPatternRewriter rewriter;
 
   /// The legalizer to use when converting operations.
   OperationLegalizer opLegalizer;
@@ -3116,10 +3098,11 @@ struct OperationConverter {
 };
 } // namespace mlir
 
-LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
-                                          Operation *op) {
+LogicalResult OperationConverter::convert(Operation *op) {
+  const ConversionConfig &config = rewriter.getConfig();
+
   // Legalize the given operation.
-  if (failed(opLegalizer.legalize(op, rewriter))) {
+  if (failed(opLegalizer.legalize(op))) {
     // Handle the case of a failed conversion for each of the different modes.
     // Full conversions expect all operations to be converted.
     if (mode == OpConversionMode::Full)
@@ -3195,7 +3178,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
 }
 
 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
-  assert(!ops.empty() && "expected at least one operation");
   const ConversionTarget &target = opLegalizer.getTarget();
 
   // Compute the set of operations and blocks to convert.
@@ -3214,11 +3196,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   }
 
   // Convert each operation and discard rewrites on failure.
-  ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
 
   for (auto *op : toConvert) {
-    if (failed(convert(rewriter, op))) {
+    if (failed(convert(op))) {
       // Dialect conversion failed.
       if (rewriterImpl.config.allowPatternRollback) {
         // Rollback is allowed: restore the original IR.
@@ -3253,13 +3234,16 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
     castOp->removeAttr(kPureTypeConversionMarker);
 
   // Try to legalize all unresolved materializations.
-  if (config.buildMaterializations) {
-    IRRewriter rewriter(rewriterImpl.context, config.listener);
+  if (rewriter.getConfig().buildMaterializations) {
+    // Use a new rewriter, so the modifications are not tracked for rollback
+    // purposes etc.
+    IRRewriter irRewriter(rewriterImpl.rewriter.getContext(),
+                          rewriter.getConfig().listener);
     for (UnrealizedConversionCastOp castOp : remainingCastOps) {
       auto it = materializations.find(castOp);
       assert(it != materializations.end() && "inconsistent state");
-      if (failed(
-              legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
+      if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp,
+                                                   it->second)))
         return failure();
     }
   }
@@ -4001,7 +3985,8 @@ static LogicalResult applyConversion(ArrayRef<Operation *> ops,
   SmallVector<IRUnit> irUnits(ops.begin(), ops.end());
   ctx->executeAction<ApplyConversionAction>(
       [&] {
-        OperationConverter opConverter(target, patterns, config, mode);
+        OperationConverter opConverter(ops.front()->getContext(), target,
+                                       patterns, config, mode);
         status = opConverter.convertOperations(ops);
       },
       irUnits);


        


More information about the Mlir-commits mailing list