[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Turn in-place op modifications into `RewriteAction`s (PR #81245)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Feb 12 01:12:50 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81245

>From f7010ea9f363c011f4171a6329ad6de30880c3e7 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 12 Feb 2024 09:11:47 +0000
Subject: [PATCH] [mlir][Transforms][NFC] Turn in-place op modifications into
 `RewriteAction`s

This commit simplifies the internal state of the dialect conversion. A separate field for the previous state of in-place op modifications is no longer needed.

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 .../mlir/Transforms/DialectConversion.h       |   4 +-
 .../Transforms/Utils/DialectConversion.cpp    | 123 ++++++++----------
 2 files changed, 56 insertions(+), 71 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index b028d2b71b3762..c0c702a7d34821 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -721,8 +721,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
 
   /// PatternRewriter hook for updating the given operation in-place.
   /// Note: These methods only track updates to the given operation itself,
-  /// and not nested regions. Updates to regions will still require
-  /// notification through other more specific hooks above.
+  /// and not nested regions. Updates to regions will still require notification
+  /// through other more specific hooks above.
   void startOpModification(Operation *op) override;
 
   /// PatternRewriter hook for updating the given operation in-place.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 85b67bb834de7c..489ccd0139c7f2 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -154,14 +154,12 @@ namespace {
 struct RewriterState {
   RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
                 unsigned numReplacements, unsigned numArgReplacements,
-                unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numRootUpdates)
+                unsigned numRewrites, unsigned numIgnoredOperations)
       : numCreatedOps(numCreatedOps),
         numUnresolvedMaterializations(numUnresolvedMaterializations),
         numReplacements(numReplacements),
         numArgReplacements(numArgReplacements), numRewrites(numRewrites),
-        numIgnoredOperations(numIgnoredOperations),
-        numRootUpdates(numRootUpdates) {}
+        numIgnoredOperations(numIgnoredOperations) {}
 
   /// The current number of created operations.
   unsigned numCreatedOps;
@@ -180,44 +178,6 @@ struct RewriterState {
 
   /// The current number of ignored operations.
   unsigned numIgnoredOperations;
-
-  /// The current number of operations that were updated in place.
-  unsigned numRootUpdates;
-};
-
-//===----------------------------------------------------------------------===//
-// OperationTransactionState
-
-/// The state of an operation that was updated by a pattern in-place. This
-/// contains all of the necessary information to reconstruct an operation that
-/// was updated in place.
-class OperationTransactionState {
-public:
-  OperationTransactionState() = default;
-  OperationTransactionState(Operation *op)
-      : op(op), loc(op->getLoc()), attrs(op->getAttrDictionary()),
-        operands(op->operand_begin(), op->operand_end()),
-        successors(op->successor_begin(), op->successor_end()) {}
-
-  /// Discard the transaction state and reset the state of the original
-  /// operation.
-  void resetOperation() const {
-    op->setLoc(loc);
-    op->setAttrs(attrs);
-    op->setOperands(operands);
-    for (const auto &it : llvm::enumerate(successors))
-      op->setSuccessor(it.value(), it.index());
-  }
-
-  /// Return the original operation of this state.
-  Operation *getOperation() const { return op; }
-
-private:
-  Operation *op;
-  LocationAttr loc;
-  DictionaryAttr attrs;
-  SmallVector<Value, 8> operands;
-  SmallVector<Block *, 2> successors;
 };
 
 //===----------------------------------------------------------------------===//
@@ -761,7 +721,8 @@ class IRRewrite {
     MoveBlock,
     SplitBlock,
     BlockTypeConversion,
-    MoveOperation
+    MoveOperation,
+    ModifyOperation
   };
 
   virtual ~IRRewrite() = default;
@@ -992,7 +953,7 @@ class OperationRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::MoveOperation &&
-           rewrite->getKind() <= Kind::MoveOperation;
+           rewrite->getKind() <= Kind::ModifyOperation;
   }
 
 protected:
@@ -1031,6 +992,34 @@ class MoveOperationRewrite : public OperationRewrite {
   // this operation was the only operation in the region.
   Operation *insertBeforeOp;
 };
+
+/// In-place modification of an op. This rewrite is immediately reflected in
+/// the IR. The previous state of the operation is stored in this object.
+class ModifyOperationRewrite : public OperationRewrite {
+public:
+  ModifyOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                         Operation *op)
+      : OperationRewrite(Kind::ModifyOperation, rewriterImpl, op),
+        loc(op->getLoc()), attrs(op->getAttrDictionary()),
+        operands(op->operand_begin(), op->operand_end()),
+        successors(op->successor_begin(), op->successor_end()) {}
+
+  /// Discard the transaction state and reset the state of the original
+  /// operation.
+  void rollback() override {
+    op->setLoc(loc);
+    op->setAttrs(attrs);
+    op->setOperands(operands);
+    for (const auto &it : llvm::enumerate(successors))
+      op->setSuccessor(it.value(), it.index());
+  }
+
+private:
+  LocationAttr loc;
+  DictionaryAttr attrs;
+  SmallVector<Value, 8> operands;
+  SmallVector<Block *, 2> successors;
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1184,9 +1173,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// operation was ignored.
   SetVector<Operation *> ignoredOps;
 
-  /// A transaction state for each of operations that were updated in-place.
-  SmallVector<OperationTransactionState, 4> rootUpdates;
-
   /// A vector of indices into `replacements` of operations that were replaced
   /// with values with different result types than the original operation, e.g.
   /// 1->N conversion of some kind.
@@ -1238,10 +1224,6 @@ static void detachNestedAndErase(Operation *op) {
 }
 
 void ConversionPatternRewriterImpl::discardRewrites() {
-  // Reset any operations that were updated in place.
-  for (auto &state : rootUpdates)
-    state.resetOperation();
-
   undoRewrites();
 
   // Remove any newly created ops.
@@ -1316,15 +1298,10 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
                        replacements.size(), argReplacements.size(),
-                       rewrites.size(), ignoredOps.size(), rootUpdates.size());
+                       rewrites.size(), ignoredOps.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
-  // Reset any operations that were updated in place.
-  for (unsigned i = state.numRootUpdates, e = rootUpdates.size(); i != e; ++i)
-    rootUpdates[i].resetOperation();
-  rootUpdates.resize(state.numRootUpdates);
-
   // Reset any replaced arguments.
   for (BlockArgument replacedArg :
        llvm::drop_begin(argReplacements, state.numArgReplacements))
@@ -1750,7 +1727,7 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
 #ifndef NDEBUG
   impl->pendingRootUpdates.insert(op);
 #endif
-  impl->rootUpdates.emplace_back(op);
+  impl->appendRewrite<ModifyOperationRewrite>(op);
 }
 
 void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
@@ -1769,13 +1746,15 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
          "operation did not have a pending in-place update");
 #endif
   // Erase the last update for this operation.
-  auto stateHasOp = [op](const auto &it) { return it.getOperation() == op; };
-  auto &rootUpdates = impl->rootUpdates;
-  auto it = llvm::find_if(llvm::reverse(rootUpdates), stateHasOp);
-  assert(it != rootUpdates.rend() && "no root update started on op");
-  (*it).resetOperation();
-  int updateIdx = std::prev(rootUpdates.rend()) - it;
-  rootUpdates.erase(rootUpdates.begin() + updateIdx);
+  auto it = llvm::find_if(
+      llvm::reverse(impl->rewrites), [&](std::unique_ptr<IRRewrite> &rewrite) {
+        auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
+        return modifyRewrite && modifyRewrite->getOperation() == op;
+      });
+  assert(it != impl->rewrites.rend() && "no root update started on op");
+  (*it)->rollback();
+  int updateIdx = std::prev(impl->rewrites.rend()) - it;
+  impl->rewrites.erase(impl->rewrites.begin() + updateIdx);
 }
 
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
@@ -2128,8 +2107,11 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
   };
   auto updatedRootInPlace = [&] {
     return llvm::any_of(
-        llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
-        [op](auto &state) { return state.getOperation() == op; });
+        llvm::drop_begin(impl.rewrites, curState.numRewrites),
+        [op](auto &rewrite) {
+          auto *modifyRewrite = dyn_cast<ModifyOperationRewrite>(rewrite.get());
+          return modifyRewrite && modifyRewrite->getOperation() == op;
+        });
   };
   (void)replacedRoot;
   (void)updatedRootInPlace;
@@ -2221,8 +2203,11 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
 LogicalResult OperationLegalizer::legalizePatternRootUpdates(
     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
     RewriterState &state, RewriterState &newState) {
-  for (int i = state.numRootUpdates, e = newState.numRootUpdates; i != e; ++i) {
-    Operation *op = impl.rootUpdates[i].getOperation();
+  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
+    auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
+    if (!rewrite)
+      continue;
+    Operation *op = rewrite->getOperation();
     if (failed(legalize(op, rewriter))) {
       LLVM_DEBUG(logFailure(
           impl.logger, "failed to legalize operation updated in-place '{0}'",



More information about the llvm-branch-commits mailing list