[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Turn in-place op modifications into `RewriteAction`s (PR #81245)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Feb 9 03:38:38 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/81245

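This commit simplifies the internal state of the dialect conversion. In-place op modifications are now recorded as `ModifyOperationAction`s in the list of rewrite actions, so the separate `rootUpdates` field that stored the previous state of in-place op modifications is no longer needed.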

>From cdbd92786887ba8f801cb0c4299f708f0a410465 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 9 Feb 2024 11:36:59 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Turn in-place op modifications into
 `RewriteAction`s

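This commit simplifies the internal state of the dialect conversion. In-place op modifications are now recorded as `ModifyOperationAction`s in the list of rewrite actions, so the separate `rootUpdates` field that stored the previous state of in-place op modifications is no longer needed.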
---
 .../Transforms/Utils/DialectConversion.cpp    | 128 ++++++++----------
 1 file changed, 58 insertions(+), 70 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ffdb069f6e9b81..d0114a148cd374 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,15 +154,13 @@ namespace {
 struct RewriterState {
   RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
                 unsigned numReplacements, unsigned numArgReplacements,
-                unsigned numRewriteActions, unsigned numIgnoredOperations,
-                unsigned numRootUpdates)
+                unsigned numRewriteActions, unsigned numIgnoredOperations)
       : numCreatedOps(numCreatedOps),
         numUnresolvedMaterializations(numUnresolvedMaterializations),
         numReplacements(numReplacements),
         numArgReplacements(numArgReplacements),
         numRewriteActions(numRewriteActions),
-        numIgnoredOperations(numIgnoredOperations),
-        numRootUpdates(numRootUpdates) {}
+        numIgnoredOperations(numIgnoredOperations) {}
 
   /// The current number of created operations.
   unsigned numCreatedOps;
@@ -181,44 +179,6 @@ struct RewriterState {
 
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
-
-  /// The current number of operations that were updated in place.
-  unsigned numRootUpdates;
-};
-
-//===----------------------------------------------------------------------===//
-// OperationTransactionState
-
-/// The state of an operation that was updated by a pattern in-place. This
-/// contains all of the necessary information to reconstruct an operation that
-/// was updated in place.
-class OperationTransactionState {
-public:
-  OperationTransactionState() = default;
-  OperationTransactionState(Operation *op)
-      : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
-        operands(op->operand_begin(), op->operand_end()),
-        successors(op->successor_begin(), op->successor_end()) {}
-
-  /// Discard the transaction state and reset the state of the original
-  /// operation.
-  void resetOperation() const {
-    op->setLoc(loc);
-    op->setAttrs(attrs);
-    op->setOperands(operands);
-    for (const auto &it : llvm::enumerate(successors))
-      op->setSuccessor(it.value(), it.index());
-  }
-
-  /// Return the original operation of this state.
-  Operation *getOperation() const { return op; }
-
-private:
-  Operation *op;
-  LocationAttr loc;
-  DictionaryAttr attrs;
-  SmallVector<Value, 8> operands;
-  SmallVector<Block *, 2> successors;
 };
 
 //===----------------------------------------------------------------------===//
@@ -758,7 +718,8 @@ class RewriteAction {
     MoveBlock,
     SplitBlock,
     BlockTypeConversion,
-    MoveOperation
+    MoveOperation,
+    ModifyOperation
   };
 
   virtual ~RewriteAction() = default;
@@ -980,7 +941,7 @@ class OperationAction : public RewriteAction {
 
   static bool classof(const RewriteAction *action) {
     return action->getKind() >= Kind::MoveOperation &&
-           action->getKind() <= Kind::MoveOperation;
+           action->getKind() <= Kind::ModifyOperation;
   }
 
 protected:
@@ -1019,6 +980,42 @@ class MoveOperationAction : public OperationAction {
   // this operation was the only operation in the region.
   Operation *insertBeforeOp;
 };
+
+/// Rewrite action that represents the in-place modification of an operation.
+/// The location, attributes, operands and successors of the operation at the
+/// time the action is created are stored in this action.
+class ModifyOperationAction : public OperationAction {
+public:
+  ModifyOperationAction(ConversionPatternRewriterImpl &rewriterImpl,
+                        Operation *op)
+      : OperationAction(Kind::ModifyOperation, rewriterImpl, op),
+        loc(op->getLoc()), attrs(op->getAttrDictionary()),
+        operands(op->operand_begin(), op->operand_end()),
+        successors(op->successor_begin(), op->successor_end()) {}
+
+  /// Match only `ModifyOperation` actions. The inherited
+  /// `OperationAction::classof` also accepts `MoveOperation`, which would make
+  /// `dyn_cast<ModifyOperationAction>` succeed on a `MoveOperationAction`.
+  static bool classof(const RewriteAction *action) {
+    return action->getKind() == Kind::ModifyOperation;
+  }
+
+  /// Restore the location, attributes, operands and successors recorded when
+  /// this action was created. Other changes (e.g. to regions) are not undone.
+  void rollback() override {
+    op->setLoc(loc);
+    op->setAttrs(attrs);
+    op->setOperands(operands);
+    for (const auto &it : llvm::enumerate(successors))
+      op->setSuccessor(it.value(), it.index());
+  }
+
+private:
+  LocationAttr loc;
+  DictionaryAttr attrs;
+  SmallVector<Value, 8> operands;
+  SmallVector<Block *, 2> successors;
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1172,9 +1169,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// operation was ignored.
   SetVector<Operation *> ignoredOps;
 
-  /// A transaction state for each of operations that were updated in-place.
-  SmallVector<OperationTransactionState, 4> rootUpdates;
-
   /// A vector of indices into `replacements` of operations that were replaced
   /// with values with different result types than the original operation, e.g.
   /// 1->N conversion of some kind.
@@ -1226,10 +1220,6 @@ static void detachNestedAndErase(Operation *op) {
 }
 
 void ConversionPatternRewriterImpl::discardRewrites() {
-  // Reset any operations that were updated in place.
-  for (auto &state : rootUpdates)
-    state.resetOperation();
-
   undoRewriteActions();
 
   // Remove any newly created ops.
@@ -1304,16 +1294,10 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
                        replacements.size(), argReplacements.size(),
-                       rewriteActions.size(), ignoredOps.size(),
-                       rootUpdates.size());
+                       rewriteActions.size(), ignoredOps.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
-  // Reset any operations that were updated in place.
-  for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
-    rootUpdates[i].resetOperation();
-  rootUpdates.resize(state.numRootUpdates);
-
   // Reset any replaced arguments.
   for (BlockArgument replacedArg :
        llvm::drop_begin(argReplacements, state.numArgReplacements))
@@ -1740,7 +1724,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
 #ifndef NDEBUG
   impl->pendingRootUpdates.insert(op);
 #endif
-  impl->rootUpdates.emplace_back(op);
+  impl->appendRewriteAction<ModifyOperationAction>(op);
 }
 
 void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
@@ -1759,13 +1743,18 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
          "operation did not have a pending in-place update");
 #endif
   // Erase the last update for this operation.
-  auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
-  auto &rootUpdates = impl->rootUpdates;
-  auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
-  assert(it != rootUpdates.rend() && "no root update started on op");
-  (*it).resetOperation();
-  int updateIdx = std::prev(rootUpdates.rend()) - it;
-  rootUpdates.erase(rootUpdates.begin() + updateIdx);
+  // Use LLVM-style dyn_cast: LLVM builds normally have C++ RTTI disabled.
+  auto it =
+      llvm::find_if(llvm::reverse(impl->rewriteActions),
+                    [&](std::unique_ptr<RewriteAction> &action) {
+                      auto *modifyAction =
+                          dyn_cast<ModifyOperationAction>(action.get());
+                      return modifyAction && modifyAction->getOperation() == op;
+                    });
+  assert(it != impl->rewriteActions.rend() && "no root update started on op");
+  (*it)->rollback();
+  int updateIdx = std::prev(impl->rewriteActions.rend()) - it;
+  impl->rewriteActions.erase(impl->rewriteActions.begin() + updateIdx);
 }
 
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
@@ -2118,8 +2098,11 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
   };
   auto updatedRootInPlace = [&] {
     return llvm::any_of(
-        llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
-        [op](auto &state) { return state.getOperation() == op; });
+        llvm::drop_begin(impl.rewriteActions, curState.numRewriteActions),
+        [op](auto &action) {
+          auto *modifyAction = dyn_cast<ModifyOperationAction>(action.get());
+          return modifyAction && modifyAction->getOperation() == op;
+        });
   };
   (void)replacedRoot;
   (void)updatedRootInPlace;
@@ -2213,8 +2196,13 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
     RewriterState &state, RewriterState &newState) {
-  for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
-    Operation *op = impl.rootUpdates[i].getOperation();
+  for (int i = state.numRewriteActions, e = newState.numRewriteActions; i != e;
+       ++i) {
+    auto *action =
+        dyn_cast<ModifyOperationAction>(impl.rewriteActions[i].get());
+    if (!action)
+      continue;
+    Operation *op = action->getOperation();
     if (failed(legalize(op, rewriter))) {
       LLVM_DEBUG(logFailure(
           impl.logger, "failed to legalize operation updated in-place '{0}'",



More information about the llvm-branch-commits mailing list