[Mlir-commits] [mlir] ea2d938 - [mlir][Transforms][NFC] Improve listener layering in dialect conversion (#81236)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Feb 14 07:51:32 PST 2024


Author: Matthias Springer
Date: 2024-02-14T16:51:28+01:00
New Revision: ea2d9383a23ca17b9240ad64c2adc5f2b5a73dc0

URL: https://github.com/llvm/llvm-project/commit/ea2d9383a23ca17b9240ad64c2adc5f2b5a73dc0
DIFF: https://github.com/llvm/llvm-project/commit/ea2d9383a23ca17b9240ad64c2adc5f2b5a73dc0.diff

LOG: [mlir][Transforms][NFC] Improve listener layering in dialect conversion (#81236)

Context: Conversion patterns are given a `ConversionPatternRewriter` with
which they modify the IR. `ConversionPatternRewriter` provides the public
API; most function calls are forwarded to and handled by
`ConversionPatternRewriterImpl`. The dialect conversion uses the listener
infrastructure to get notified about op/block insertions.

In the current design, `ConversionPatternRewriter` inherits from both
`PatternRewriter` and `Listener`, and the conversion rewriter registers
itself as its own listener. This is problematic because listener functions
such as `notifyOperationInserted` thereby become part of the rewriter's
public API and can be called directly from conversion patterns; doing so
would bring the dialect conversion into an inconsistent state.

With this commit, `ConversionPatternRewriter` no longer inherits from
`Listener`. Instead, `ConversionPatternRewriterImpl` inherits from
`Listener` and is registered as the rewriter's listener. This removes the
problematic public API and also simplifies the code a bit: block/op
insertion notifications were previously received by
`ConversionPatternRewriter` and then forwarded to
`ConversionPatternRewriterImpl`. This forwarding is no longer needed.

Added: 
    

Modified: 
    flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index bd8d3d92d480b6..1c4f82e2de818b 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -739,12 +739,12 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
   void notifyOperationInserted(mlir::Operation *op,
                                mlir::OpBuilder::InsertPoint previous) override {
     builder.notifyOperationInserted(op, previous);
-    rewriter.notifyOperationInserted(op, previous);
+    rewriter.getListener()->notifyOperationInserted(op, previous);
   }
   virtual void notifyBlockInserted(mlir::Block *block, mlir::Region *previous,
                                    mlir::Region::iterator previousIt) override {
     builder.notifyBlockInserted(block, previous, previousIt);
-    rewriter.notifyBlockInserted(block, previous, previousIt);
+    rewriter.getListener()->notifyBlockInserted(block, previous, previousIt);
   }
   fir::FirOpBuilder &builder;
   mlir::ConversionPatternRewriter &rewriter;

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 091131651bbf56..851d639ae68a77 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -655,8 +655,7 @@ struct ConversionPatternRewriterImpl;
 /// This class implements a pattern rewriter for use with ConversionPatterns. It
 /// extends the base PatternRewriter and provides special conversion specific
 /// hooks.
-class ConversionPatternRewriter final : public PatternRewriter,
-                                        public RewriterBase::Listener {
+class ConversionPatternRewriter final : public PatternRewriter {
 public:
   explicit ConversionPatternRewriter(MLIRContext *ctx);
   ~ConversionPatternRewriter() override;
@@ -735,10 +734,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
   /// implemented for dialect conversion.
   void eraseBlock(Block *block) override;
 
-  /// PatternRewriter hook creating a new block.
-  void notifyBlockInserted(Block *block, Region *previous,
-                           Region::iterator previousIt) override;
-
   /// PatternRewriter hook for splitting a block into two parts.
   Block *splitBlock(Block *block, Block::iterator before) override;
 
@@ -747,9 +742,6 @@ class ConversionPatternRewriter final : public PatternRewriter,
                          ValueRange argValues = std::nullopt) override;
   using PatternRewriter::inlineBlockBefore;
 
-  /// PatternRewriter hook for inserting a new operation.
-  void notifyOperationInserted(Operation *op, InsertPoint previous) override;
-
   /// PatternRewriter hook for updating the given operation in-place.
   /// Note: These methods only track updates to the given operation itself,
   /// and not nested regions. Updates to regions will still require notification
@@ -762,18 +754,11 @@ class ConversionPatternRewriter final : public PatternRewriter,
   /// PatternRewriter hook for updating the given operation in-place.
   void cancelOpModification(Operation *op) override;
 
-  /// PatternRewriter hook for notifying match failure reasons.
-  void
-  notifyMatchFailure(Location loc,
-                     function_ref<void(Diagnostic &)> reasonCallback) override;
-  using PatternRewriter::notifyMatchFailure;
-
   /// Return a reference to the internal implementation.
   detail::ConversionPatternRewriterImpl &getImpl();
 
 private:
   // Hide unsupported pattern rewriter API.
-  using OpBuilder::getListener;
   using OpBuilder::setListener;
 
   void moveOpBefore(Operation *op, Block *block,

diff  --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index 828f53c16d8f86..31e81107f655c0 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -582,7 +582,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
     // Inside regular functions we use the blocking wait operation to wait for
     // the async object (token, value or group) to become available.
     if (!isInCoroutine) {
-      ImplicitLocOpBuilder builder(loc, op, &rewriter);
+      ImplicitLocOpBuilder builder(loc, rewriter);
       builder.create<RuntimeAwaitOp>(loc, operand);
 
       // Assert that the awaited operands is not in the error state.
@@ -601,7 +601,7 @@ class AwaitOpLoweringBase : public OpConversionPattern<AwaitType> {
       CoroMachinery &coro = funcCoro->getSecond();
       Block *suspended = op->getBlock();
 
-      ImplicitLocOpBuilder builder(loc, op, &rewriter);
+      ImplicitLocOpBuilder builder(loc, rewriter);
       MLIRContext *ctx = op->getContext();
 
       // Save the coroutine state and resume on a runtime managed thread when

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a5a77e00fbfb5f..dbf5bf50d60e7f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -825,7 +825,7 @@ void ArgConverter::insertConversion(Block *newBlock,
 //===----------------------------------------------------------------------===//
 namespace mlir {
 namespace detail {
-struct ConversionPatternRewriterImpl {
+struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
       : argConverter(rewriter, unresolvedMaterializations),
         notifyCallback(nullptr) {}
@@ -903,15 +903,19 @@ struct ConversionPatternRewriterImpl {
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
 
-  /// PatternRewriter hook for replacing the results of an operation.
+  /// Notifies that an op was inserted.
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override;
+
+  /// Notifies that an op is about to be replaced with the given values.
   void notifyOpReplaced(Operation *op, ValueRange newValues);
 
   /// Notifies that a block is about to be erased.
   void notifyBlockIsBeingErased(Block *block);
 
-  /// Notifies that a block was created.
-  void notifyInsertedBlock(Block *block, Region *previous,
-                           Region::iterator previousIt);
+  /// Notifies that a block was inserted.
+  void notifyBlockInserted(Block *block, Region *previous,
+                           Region::iterator previousIt) override;
 
   /// Notifies that a block was split.
   void notifySplitBlock(Block *block, Block *continuation);
@@ -921,8 +925,9 @@ struct ConversionPatternRewriterImpl {
                                Block::iterator before);
 
   /// Notifies that a pattern match failed for the given reason.
-  void notifyMatchFailure(Location loc,
-                          function_ref<void(Diagnostic &)> reasonCallback);
+  void
+  notifyMatchFailure(Location loc,
+                     function_ref<void(Diagnostic &)> reasonCallback) override;
 
   //===--------------------------------------------------------------------===//
   // State
@@ -1363,6 +1368,16 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
+void ConversionPatternRewriterImpl::notifyOperationInserted(
+    Operation *op, OpBuilder::InsertPoint previous) {
+  assert(!previous.isSet() && "expected newly created op");
+  LLVM_DEBUG({
+    logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
+                       << ")\n";
+  });
+  createdOps.push_back(op);
+}
+
 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
                                                      ValueRange newValues) {
   assert(newValues.size() == op->getNumResults());
@@ -1398,7 +1413,7 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
   blockActions.push_back(BlockAction::getErase(block, {region, origNextBlock}));
 }
 
-void ConversionPatternRewriterImpl::notifyInsertedBlock(
+void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
   if (!previous) {
     // This is a newly created block.
@@ -1437,7 +1452,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
 ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
     : PatternRewriter(ctx),
       impl(new detail::ConversionPatternRewriterImpl(*this)) {
-  setListener(this);
+  setListener(impl.get());
 }
 
 ConversionPatternRewriter::~ConversionPatternRewriter() = default;
@@ -1540,11 +1555,6 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
                            results);
 }
 
-void ConversionPatternRewriter::notifyBlockInserted(
-    Block *block, Region *previous, Region::iterator previousIt) {
-  impl->notifyInsertedBlock(block, previous, previousIt);
-}
-
 Block *ConversionPatternRewriter::splitBlock(Block *block,
                                              Block::iterator before) {
   auto *continuation = block->splitBlock(before);
@@ -1572,16 +1582,6 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
   eraseBlock(source);
 }
 
-void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
-                                                        InsertPoint previous) {
-  assert(!previous.isSet() && "expected newly created op");
-  LLVM_DEBUG({
-    impl->logger.startLine()
-        << "** Insert  : '" << op->getName() << "'(" << op << ")\n";
-  });
-  impl->createdOps.push_back(op);
-}
-
 void ConversionPatternRewriter::startOpModification(Operation *op) {
 #ifndef NDEBUG
   impl->pendingRootUpdates.insert(op);
@@ -1614,11 +1614,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
   rootUpdates.erase(rootUpdates.begin() + updateIdx);
 }
 
-void ConversionPatternRewriter::notifyMatchFailure(
-    Location loc, function_ref<void(Diagnostic &)> reasonCallback) {
-  impl->notifyMatchFailure(loc, reasonCallback);
-}
-
 void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
                                              Block::iterator iterator) {
   llvm_unreachable(


        


More information about the Mlir-commits mailing list