[Mlir-commits] [mlir] [mlir][Transforms][NFC] Decouple `ConversionPatternRewriterImpl` from `ConversionPatternRewriter` (PR #82333)

Matthias Springer llvmlistbot at llvm.org
Fri Feb 23 02:41:42 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/82333

>From 70920e854fdb126c16e66841adb3fe128669f238 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Tue, 20 Feb 2024 09:48:57 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Decouple
 `ConversionPatternRewriterImpl` from `ConversionPatternRewriter`

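`ConversionPatternRewriterImpl` no longer maintains a reference to the respective `ConversionPatternRewriter`; an `MLIRContext` is sufficient to construct it. Code that previously used the stored rewriter now builds what it needs from the context or from a locally created `OpBuilder`. This commit simplifies the internal state of `ConversionPatternRewriterImpl`.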
---
 .../Transforms/Utils/DialectConversion.cpp    | 44 +++++++++----------
 1 file changed, 21 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 508ee7416d55d8..d015bd52901233 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -756,10 +756,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
 namespace mlir {
 namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
-  explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter,
+  explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : rewriter(rewriter), eraseRewriter(rewriter.getContext()),
-        config(config) {}
+      : eraseRewriter(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -854,8 +853,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        Type origOutputType,
                                        const TypeConverter *converter);
 
-  Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter,
-                                               Location loc, ValueRange inputs,
+  Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
+                                               ValueRange inputs,
                                                Type origOutputType,
                                                Type outputType,
                                                const TypeConverter *converter);
@@ -934,8 +933,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // State
   //===--------------------------------------------------------------------===//
 
-  PatternRewriter &rewriter;
-
   /// This rewriter must be used for erasing ops/blocks.
   SingleEraseRewriter eraseRewriter;
 
@@ -1037,8 +1034,12 @@ void BlockTypeConversionRewrite::rollback() {
 
 LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
     function_ref<Operation *(Value)> findLiveUser) {
+  auto builder = OpBuilder::atBlockBegin(block, /*listener=*/&rewriterImpl);
+
   // Process the remapping for each of the original arguments.
   for (auto it : llvm::enumerate(origBlock->getArguments())) {
+    OpBuilder::InsertionGuard g(builder);
+
     // If the type of this argument changed and the argument is still live, we
     // need to materialize a conversion.
     BlockArgument origArg = it.value();
@@ -1050,14 +1051,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
 
     Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
     bool isDroppedArg = replacementValue == origArg;
-    if (isDroppedArg)
-      rewriterImpl.rewriter.setInsertionPointToStart(getBlock());
-    else
-      rewriterImpl.rewriter.setInsertionPointAfterValue(replacementValue);
+    if (!isDroppedArg)
+      builder.setInsertionPointAfterValue(replacementValue);
     Value newArg;
     if (converter) {
       newArg = converter->materializeSourceConversion(
-          rewriterImpl.rewriter, origArg.getLoc(), origArg.getType(),
+          builder, origArg.getLoc(), origArg.getType(),
           isDroppedArg ? ValueRange() : ValueRange(replacementValue));
       assert((!newArg || newArg.getType() == origArg.getType()) &&
              "materialization hook did not provide a value of the expected "
@@ -1322,6 +1321,8 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
 Block *ConversionPatternRewriterImpl::applySignatureConversion(
     Block *block, const TypeConverter *converter,
     TypeConverter::SignatureConversion &signatureConversion) {
+  MLIRContext *ctx = block->getParentOp()->getContext();
+
   // If no arguments are being changed or added, there is nothing to do.
   unsigned origArgCount = block->getNumArguments();
   auto convertedTypes = signatureConversion.getConvertedTypes();
@@ -1338,7 +1339,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 
   // Map all new arguments to the location of the argument they originate from.
   SmallVector<Location> newLocs(convertedTypes.size(),
-                                rewriter.getUnknownLoc());
+                                Builder(ctx).getUnknownLoc());
   for (unsigned i = 0; i < origArgCount; ++i) {
     auto inputMap = signatureConversion.getInputMapping(i);
     if (!inputMap || inputMap->replacementValue)
@@ -1357,8 +1358,6 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
   argInfo.resize(origArgCount);
 
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(newBlock);
   for (unsigned i = 0; i != origArgCount; ++i) {
     auto inputMap = signatureConversion.getInputMapping(i);
     if (!inputMap)
@@ -1401,7 +1400,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
         outputType = legalOutputType;
 
       newArg = buildUnresolvedArgumentMaterialization(
-          rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
+          newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
           converter);
     }
 
@@ -1439,12 +1438,11 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   return convertOp.getResult(0);
 }
 Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
-    PatternRewriter &rewriter, Location loc, ValueRange inputs,
-    Type origOutputType, Type outputType, const TypeConverter *converter) {
-  return buildUnresolvedMaterialization(
-      MaterializationKind::Argument, rewriter.getInsertionBlock(),
-      rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
-      converter);
+    Block *block, Location loc, ValueRange inputs, Type origOutputType,
+    Type outputType, const TypeConverter *converter) {
+  return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
+                                        block->begin(), loc, inputs, outputType,
+                                        origOutputType, converter);
 }
 Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
     Location loc, Value input, Type outputType,
@@ -1556,7 +1554,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
 ConversionPatternRewriter::ConversionPatternRewriter(
     MLIRContext *ctx, const ConversionConfig &config)
     : PatternRewriter(ctx),
-      impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
+      impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
   setListener(impl.get());
 }
 



More information about the Mlir-commits mailing list