[Mlir-commits] [mlir] [mlir][Transforms][NFC] Turn in-place op modification into `IRRewrite` (PR #81245)

Matthias Springer llvmlistbot at llvm.org
Wed Feb 21 07:19:19 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81245

>From 8a8a79da31d7bcef8209e34ddbbef9a6c0343852 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 16 Feb 2024 15:10:21 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Turn in-place op modifications into
 `IRRewrite`s

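This commit simplifies the internal state of the dialect conversion. In-place op modifications are now recorded as `ModifyOperationRewrite` entries in the list of rewrites; each entry stores the previous state of the operation (location, attributes, operands, successors) and restores it on rollback. A separate field (`rootUpdates`) for the previous state of in-place op modifications is therefore no longer needed.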

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 .../mlir/Transforms/DialectConversion.h       |   4 +-
 .../Transforms/Utils/DialectConversion.cpp    | 146 +++++++++---------
 2 files changed, 74 insertions(+), 76 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 15fa39bde104b9..0d7722aa07ee38 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -744,8 +744,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
 
   /// PatternRewriter hook for updating the given operation in-place.
   /// Note: These methods only track updates to the given operation itself,
-  /// and not nested regions. Updates to regions will still require
-  /// notification through other more specific hooks above.
+  /// and not nested regions. Updates to regions will still require notification
+  /// through other more specific hooks above.
   void startOpModification(Operation *op) override;
 
   /// PatternRewriter hook for updating the given operation in-place.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 84597fb7986b07..9e0ca3fc9b3491 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,14 +154,12 @@ namespace {
 struct RewriterState {
   RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
                 unsigned numReplacements, unsigned numArgReplacements,
-                unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numRootUpdates)
+                unsigned numRewrites, unsigned numIgnoredOperations)
       : numCreatedOps(numCreatedOps),
         numUnresolvedMaterializations(numUnresolvedMaterializations),
         numReplacements(numReplacements),
         numArgReplacements(numArgReplacements), numRewrites(numRewrites),
-        numIgnoredOperations(numIgnoredOperations),
-        numRootUpdates(numRootUpdates) {}
+        numIgnoredOperations(numIgnoredOperations) {}
 
   /// The current number of created operations.
   unsigned numCreatedOps;
@@ -180,44 +178,6 @@ struct RewriterState {
 
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
-
-  /// The current number of operations that were updated in place.
-  unsigned numRootUpdates;
-};
-
-//===----------------------------------------------------------------------===//
-// OperationTransactionState
-
-/// The state of an operation that was updated by a pattern in-place. This
-/// contains all of the necessary information to reconstruct an operation that
-/// was updated in place.
-class OperationTransactionState {
-public:
-  OperationTransactionState() = default;
-  OperationTransactionState(Operation *op)
-      : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
-        operands(op->operand_begin(), op->operand_end()),
-        successors(op->successor_begin(), op->successor_end()) {}
-
-  /// Discard the transaction state and reset the state of the original
-  /// operation.
-  void resetOperation() const {
-    op->setLoc(loc);
-    op->setAttrs(attrs);
-    op->setOperands(operands);
-    for (const auto &it : llvm::enumerate(successors))
-      op->setSuccessor(it.value(), it.index());
-  }
-
-  /// Return the original operation of this state.
-  Operation *getOperation() const { return op; }
-
-private:
-  Operation *op;
-  LocationAttr loc;
-  DictionaryAttr attrs;
-  SmallVector<Value, 8> operands;
-  SmallVector<Block *, 2> successors;
 };
 
 //===----------------------------------------------------------------------===//
@@ -754,14 +714,19 @@ namespace {
 class IRRewrite {
 public:
   /// The kind of the rewrite. Rewrites can be undone if the conversion fails.
+  /// Enum values are ordered, so that they can be used in `classof`: first all
+  /// block rewrites, then all operation rewrites.
   enum class Kind {
+    // Block rewrites
     CreateBlock,
     EraseBlock,
     InlineBlock,
     MoveBlock,
     SplitBlock,
     BlockTypeConversion,
-    MoveOperation
+    // Operation rewrites
+    MoveOperation,
+    ModifyOperation
   };
 
   virtual ~IRRewrite() = default;
@@ -992,7 +957,7 @@ class OperationRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::MoveOperation &&
-           rewrite->getKind() <= Kind::MoveOperation;
+           rewrite->getKind() <= Kind::ModifyOperation;
   }
 
 protected:
@@ -1031,8 +996,48 @@ class MoveOperationRewrite : public OperationRewrite {
   // this operation was the only operation in the region.
   Operation *insertBeforeOp;
 };
+
+/// In-place modification of an op. This rewrite is immediately reflected in
+/// the IR. The previous state of the operation is stored in this object.
+class ModifyOperationRewrite : public OperationRewrite {
+public:
+  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                         Operation *op)
+      : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
+        loc(op->getLoc()), attrs(op->getAttrDictionary()),
+        operands(op->operand_begin(), op->operand_end()),
+        successors(op->successor_begin(), op->successor_end()) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::ModifyOperation;
+  }
+
+  void rollback() override {
+    op->setLoc(loc);
+    op->setAttrs(attrs);
+    op->setOperands(operands);
+    for (const auto &it : llvm::enumerate(successors))
+      op->setSuccessor(it.value(), it.index());
+  }
+
+private:
+  LocationAttr loc;
+  DictionaryAttr attrs;
+  SmallVector<Value, 8> operands;
+  SmallVector<Block *, 2> successors;
+};
 } // namespace
 
+/// Return "true" if there is an operation rewrite that matches the specified
+/// rewrite type and operation among the given rewrites.
+template <typename RewriteTy, typename R>
+static bool hasRewrite(R &&rewrites, Operation *op) {
+  return any_of(std::move(rewrites), [&](auto &rewrite) {
+    auto *rewriteTy = dyn_cast<RewriteTy>(rewrite.get());
+    return rewriteTy && rewriteTy->getOperation() == op;
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionPatternRewriterImpl
 //===----------------------------------------------------------------------===//
@@ -1184,9 +1189,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// operation was ignored.
   SetVector<Operation *> ignoredOps;
 
-  /// A transaction state for each of operations that were updated in-place.
-  SmallVector<OperationTransactionState, 4> rootUpdates;
-
   /// A vector of indices into `replacements` of operations that were replaced
   /// with values with different result types than the original operation, e.g.
   /// 1->N conversion of some kind.
@@ -1238,10 +1240,6 @@ static void detachNestedAndErase(Operation *op) {
 }
 
 void ConversionPatternRewriterImpl::discardRewrites() {
-  // Reset any operations that were updated in place.
-  for (auto &state : rootUpdates)
-    state.resetOperation();
-
   undoRewrites();
 
   // Remove any newly created ops.
@@ -1316,15 +1314,10 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
                        replacements.size(), argReplacements.size(),
-                       rewrites.size(), ignoredOps.size(), rootUpdates.size());
+                       rewrites.size(), ignoredOps.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
-  // Reset any operations that were updated in place.
-  for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
-    rootUpdates[i].resetOperation();
-  rootUpdates.resize(state.numRootUpdates);
-
   // Reset any replaced arguments.
   for (BlockArgument replacedArg :
        llvm::drop_begin(argReplacements, state.numArgReplacements))
@@ -1750,7 +1743,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
 #ifndef NDEBUG
   impl->pendingRootUpdates.insert(op);
 #endif
-  impl->rootUpdates.emplace_back(op);
+  impl->appendRewrite<ModifyOperationRewrite>(op);
 }
 
 void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
@@ -1769,13 +1762,15 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
          "operation did not have a pending in-place update");
 #endif
   // Erase the last update for this operation.
-  auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
-  auto &rootUpdates = impl->rootUpdates;
-  auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
-  assert(it != rootUpdates.rend() && "no root update started on op");
-  (*it).resetOperation();
-  int updateIdx = std::prev(rootUpdates.rend()) - it;
-  rootUpdates.erase(rootUpdates.begin() + updateIdx);
+  auto it = llvm::find_if(
+      llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
+        auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
+        return modifyRewrite && modifyRewrite->getOperation() == op;
+      });
+  assert(it != impl->rewrites.rend() && "no root update started on op");
+  (*it)->rollback();
+  int updateIdx = std::prev(impl->rewrites.rend()) - it;
+  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
 }
 
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
@@ -2059,6 +2054,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
   // Functor that cleans up the rewriter state after a pattern failed to match.
   RewriterState curState = rewriterImpl.getCurrentState();
   auto onFailure = [&](const Pattern &pattern) {
+    assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
     LLVM_DEBUG({
       logFailure(rewriterImpl.logger, "pattern failed to match");
       if (rewriterImpl.notifyCallback) {
@@ -2076,6 +2072,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
   // Functor that performs additional legalization when a pattern is
   // successfully applied.
   auto onSuccess = [&](const Pattern &pattern) {
+    assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
     auto result = legalizePatternResult(op, pattern, rewriter, curState);
     appliedPatterns.erase(&pattern);
     if (failed(result))
@@ -2118,7 +2115,6 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
 
 #ifndef NDEBUG
   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
-#endif
 
   // Check that the root was either replaced or updated in place.
   auto replacedRoot = [&] {
@@ -2127,14 +2123,12 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
         [op](auto &it) { return it.first == op; });
   };
   auto updatedRootInPlace = [&] {
-    return llvm::any_of(
-        llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
-        [op](auto &state) { return state.getOperation() == op; });
+    return hasRewrite<ModifyOperationRewrite>(
+        llvm::drop_begin(impl.rewrites, curState.numRewrites), op);
   };
-  (void)replacedRoot;
-  (void)updatedRootInPlace;
   assert((replacedRoot() || updatedRootInPlace()) &&
          "expected pattern to replace the root operation");
+#endif // NDEBUG
 
   // Legalize each of the actions registered during application.
   RewriterState newState = impl.getCurrentState();
@@ -2221,8 +2215,11 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
     RewriterState &state, RewriterState &newState) {
-  for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
-    Operation *op = impl.rootUpdates[i].getOperation();
+  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
+    auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
+    if (!rewrite)
+      continue;
+    Operation *op = rewrite->getOperation();
     if (failed(legalize(op, rewriter))) {
       LLVM_DEBUG(logFailure(
           impl.logger, "failed to legalize operation updated in-place '{0}'",
@@ -3562,7 +3559,8 @@ mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
 // Full Conversion
 
 LogicalResult
-mlir::applyFullConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
+mlir::applyFullConversion(ArrayRef<Operation *> ops,
+                          const ConversionTarget &target,
                           const FrozenRewritePatternSet &patterns) {
   OperationConverter opConverter(target, patterns, OpConversionMode::Full);
   return opConverter.convertOperations(ops);



More information about the Mlir-commits mailing list