[llvm-branch-commits] [mlir] [mlir][Transforms][NFC] Turn op creation into `IRRewrite` (PR #81759)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Feb 14 08:36:43 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit is a refactoring of the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that can be committed (upon success) or rolled back (upon failure).

Until now, the dialect conversion kept track of "op creation" in separate internal data structures (the `createdOps` vector and the `numCreatedOps` counter in `RewriterState`). This commit turns "op creation" into an `IRRewrite` (`CreateOperationRewrite`) that can be committed and rolled back just like any other rewrite, which simplifies the internal state of the dialect conversion.

---
Full diff: https://github.com/llvm/llvm-project/pull/81759.diff


1 File Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+66-38) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a07c8a56822de5..2bb56d5df1ebd3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,17 +152,12 @@ namespace {
 /// This class contains a snapshot of the current conversion rewriter state.
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
-  RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
-                unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numErased)
-      : numCreatedOps(numCreatedOps),
-        numUnresolvedMaterializations(numUnresolvedMaterializations),
+  RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
+                unsigned numIgnoredOperations, unsigned numErased)
+      : numUnresolvedMaterializations(numUnresolvedMaterializations),
         numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
         numErased(numErased) {}
 
-  /// The current number of created operations.
-  unsigned numCreatedOps;
-
   /// The current number of unresolved materializations.
   unsigned numUnresolvedMaterializations;
 
@@ -299,7 +294,8 @@ class IRRewrite {
     ReplaceBlockArg,
     MoveOperation,
     ModifyOperation,
-    ReplaceOperation
+    ReplaceOperation,
+    CreateOperation
   };
 
   virtual ~IRRewrite() = default;
@@ -372,7 +368,11 @@ class CreateBlockRewrite : public BlockRewrite {
     auto &blockOps = block->getOperations();
     while (!blockOps.empty())
       blockOps.remove(blockOps.begin());
-    eraseBlock(block);
+    if (block->getParent()) {
+      eraseBlock(block);
+    } else {
+      delete block;
+    }
   }
 };
 
@@ -602,7 +602,7 @@ class OperationRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::MoveOperation &&
-           rewrite->getKind() <= Kind::ReplaceOperation;
+           rewrite->getKind() <= Kind::CreateOperation;
   }
 
 protected:
@@ -708,6 +708,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
   /// 1->N conversion of some kind.
   bool changedResults;
 };
+
+class CreateOperationRewrite : public OperationRewrite {
+public:
+  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                         Operation *op)
+      : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::CreateOperation;
+  }
+
+  void rollback() override;
+};
 } // namespace
 
 /// Return "true" if there is an operation rewrite that matches the specified
@@ -925,9 +938,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
 
-  /// Ordered vector of all of the newly created operations during conversion.
-  SmallVector<Operation *> createdOps;
-
   /// Ordered vector of all unresolved type conversion materializations during
   /// conversion.
   SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
@@ -1110,7 +1120,18 @@ void ReplaceOperationRewrite::rollback() {
 
 void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
 
+void CreateOperationRewrite::rollback() {
+  for (Region &region : op->getRegions()) {
+    while (!region.getBlocks().empty())
+      region.getBlocks().remove(region.getBlocks().begin());
+  }
+  op->dropAllUses();
+  eraseOp(op);
+}
+
 void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
+  // if (erasedIR.erasedOps.contains(op)) return;
+
   for (Region &region : op->getRegions()) {
     for (Block &block : region.getBlocks()) {
       while (!block.getOperations().empty())
@@ -1127,8 +1148,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
   // Remove any newly created ops.
   for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
     detachNestedAndErase(materialization.getOp());
-  for (auto *op : llvm::reverse(createdOps))
-    detachNestedAndErase(op);
 }
 
 void ConversionPatternRewriterImpl::applyRewrites() {
@@ -1148,9 +1167,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
-                       rewrites.size(), ignoredOps.size(),
-                       eraseRewriter.erased.size());
+  return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
+                       ignoredOps.size(), eraseRewriter.erased.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1171,12 +1189,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     detachNestedAndErase(op);
   }
 
-  // Pop all of the newly created operations.
-  while (createdOps.size() != state.numCreatedOps) {
-    detachNestedAndErase(createdOps.back());
-    createdOps.pop_back();
-  }
-
   // Pop all of the recorded ignored operations that are no longer valid.
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
@@ -1460,7 +1472,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
   });
   if (!previous.isSet()) {
     // This is a newly created op.
-    createdOps.push_back(op);
+    appendRewrite<CreateOperationRewrite>(op);
     return;
   }
   Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1961,13 +1973,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
   rewriter.replaceOp(op, replacementValues);
 
   // Recursively legalize any new constant operations.
-  for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
+  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
        i != e; ++i) {
-    Operation *cstOp = rewriterImpl.createdOps[i];
-    if (failed(legalize(cstOp, rewriter))) {
+    auto *createOp =
+        dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
+    if (!createOp)
+      continue;
+    if (failed(legalize(createOp->getOperation(), rewriter))) {
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "failed to legalize generated constant '{0}'",
-                            cstOp->getName()));
+                            createOp->getOperation()->getName()));
       rewriterImpl.resetState(curState);
       return failure();
     }
@@ -2112,9 +2127,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     // blocks in regions created by this pattern will already be legalized later
     // on. If we haven't built the set yet, build it now.
     if (operationsToIgnore.empty()) {
-      auto createdOps = ArrayRef<Operation *>(impl.createdOps)
-                            .drop_front(state.numCreatedOps);
-      operationsToIgnore.insert(createdOps.begin(), createdOps.end());
+      for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
+           ++i) {
+        auto *createOp =
+            dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+        if (!createOp)
+          continue;
+        operationsToIgnore.insert(createOp->getOperation());
+      }
     }
 
     // If this operation should be considered for re-legalization, try it.
@@ -2132,8 +2152,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
     RewriterState &state, RewriterState &newState) {
-  for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
-    Operation *op = impl.createdOps[i];
+  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
+    auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+    if (!createOp)
+      continue;
+    Operation *op = createOp->getOperation();
     if (failed(legalize(op, rewriter))) {
       LLVM_DEBUG(logFailure(impl.logger,
                             "failed to legalize generated operation '{0}'({1})",
@@ -2563,10 +2586,15 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
     });
     return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
   };
-  for (auto &r : rewriterImpl.rewrites)
-    if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get()))
-      if (failed(rewrite->materializeLiveConversions(findLiveUser)))
+  // Note: `rewrites` may be reallocated as the loop is running.
+  for (int64_t i = 0; i < rewriterImpl.rewrites.size(); ++i) {
+    auto &rewrite = rewriterImpl.rewrites[i];
+    if (auto *blockTypeConversionRewrite =
+            dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
+      if (failed(blockTypeConversionRewrite->materializeLiveConversions(
+              findLiveUser)))
         return failure();
+  }
   return success();
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/81759


More information about the llvm-branch-commits mailing list