[Mlir-commits] [mlir] 9ca70d7 - [mlir][Transforms][NFC] Turn op creation into `IRRewrite` (#81759)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 23 01:03:29 PST 2024


Author: Matthias Springer
Date: 2024-02-23T10:03:26+01:00
New Revision: 9ca70d72f4f217ff4f6ab337ad4a8e6666860791

URL: https://github.com/llvm/llvm-project/commit/9ca70d72f4f217ff4f6ab337ad4a8e6666860791
DIFF: https://github.com/llvm/llvm-project/commit/9ca70d72f4f217ff4f6ab337ad4a8e6666860791.diff

LOG: [mlir][Transforms][NFC] Turn op creation into `IRRewrite` (#81759)

This commit is a refactoring of the dialect conversion. The dialect
conversion maintains a list of "IR rewrites" that can be committed (upon
success) or rolled back (upon failure).

Until now, the dialect conversion kept track of "op creation" in
separate internal data structures. This commit turns "op creation" into
an `IRRewrite` that can be committed and rolled back just like any other
rewrite. This commit simplifies the internal state of the dialect
conversion.

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index dec68048dc1d30..704597148dfac0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,17 +152,12 @@ namespace {
 /// This class contains a snapshot of the current conversion rewriter state.
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
-  RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations,
-                unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numErased)
-      : numCreatedOps(numCreatedOps),
-        numUnresolvedMaterializations(numUnresolvedMaterializations),
+  RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
+                unsigned numIgnoredOperations, unsigned numErased)
+      : numUnresolvedMaterializations(numUnresolvedMaterializations),
         numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
         numErased(numErased) {}
 
-  /// The current number of created operations.
-  unsigned numCreatedOps;
-
   /// The current number of unresolved materializations.
   unsigned numUnresolvedMaterializations;
 
@@ -303,7 +298,8 @@ class IRRewrite {
     // Operation rewrites
     MoveOperation,
     ModifyOperation,
-    ReplaceOperation
+    ReplaceOperation,
+    CreateOperation
   };
 
   virtual ~IRRewrite() = default;
@@ -376,7 +372,10 @@ class CreateBlockRewrite : public BlockRewrite {
     auto &blockOps = block->getOperations();
     while (!blockOps.empty())
       blockOps.remove(blockOps.begin());
-    eraseBlock(block);
+    if (block->getParent())
+      eraseBlock(block);
+    else
+      delete block;
   }
 };
 
@@ -606,7 +605,7 @@ class OperationRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::MoveOperation &&
-           rewrite->getKind() <= Kind::ReplaceOperation;
+           rewrite->getKind() <= Kind::CreateOperation;
   }
 
 protected:
@@ -740,6 +739,19 @@ class ReplaceOperationRewrite : public OperationRewrite {
   /// A boolean flag that indicates whether result types have changed or not.
   bool changedResults;
 };
+
+class CreateOperationRewrite : public OperationRewrite {
+public:
+  CreateOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                         Operation *op)
+      : OperationRewrite(Kind::CreateOperation, rewriterImpl, op) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::CreateOperation;
+  }
+
+  void rollback() override;
+};
 } // namespace
 
 /// Return "true" if there is an operation rewrite that matches the specified
@@ -957,9 +969,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
 
-  /// Ordered vector of all of the newly created operations during conversion.
-  SmallVector<Operation *> createdOps;
-
   /// Ordered vector of all unresolved type conversion materializations during
   /// conversion.
   SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
@@ -1144,6 +1153,15 @@ void ReplaceOperationRewrite::rollback() {
 
 void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
 
+void CreateOperationRewrite::rollback() {
+  for (Region &region : op->getRegions()) {
+    while (!region.getBlocks().empty())
+      region.getBlocks().remove(region.getBlocks().begin());
+  }
+  op->dropAllUses();
+  eraseOp(op);
+}
+
 void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
   for (Region &region : op->getRegions()) {
     for (Block &block : region.getBlocks()) {
@@ -1161,8 +1179,6 @@ void ConversionPatternRewriterImpl::discardRewrites() {
   // Remove any newly created ops.
   for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
     detachNestedAndErase(materialization.getOp());
-  for (auto *op : llvm::reverse(createdOps))
-    detachNestedAndErase(op);
 }
 
 void ConversionPatternRewriterImpl::applyRewrites() {
@@ -1182,9 +1198,8 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(createdOps.size(), unresolvedMaterializations.size(),
-                       rewrites.size(), ignoredOps.size(),
-                       eraseRewriter.erased.size());
+  return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
+                       ignoredOps.size(), eraseRewriter.erased.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1205,12 +1220,6 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     detachNestedAndErase(op);
   }
 
-  // Pop all of the newly created operations.
-  while (createdOps.size() != state.numCreatedOps) {
-    detachNestedAndErase(createdOps.back());
-    createdOps.pop_back();
-  }
-
   // Pop all of the recorded ignored operations that are no longer valid.
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
@@ -1478,7 +1487,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
   });
   if (!previous.isSet()) {
     // This is a newly created op.
-    createdOps.push_back(op);
+    appendRewrite<CreateOperationRewrite>(op);
     return;
   }
   Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1979,13 +1988,16 @@ OperationLegalizer::legalizeWithFold(Operation *op,
   rewriter.replaceOp(op, replacementValues);
 
   // Recursively legalize any new constant operations.
-  for (unsigned i = curState.numCreatedOps, e = rewriterImpl.createdOps.size();
+  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
        i != e; ++i) {
-    Operation *cstOp = rewriterImpl.createdOps[i];
-    if (failed(legalize(cstOp, rewriter))) {
+    auto *createOp =
+        dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
+    if (!createOp)
+      continue;
+    if (failed(legalize(createOp->getOperation(), rewriter))) {
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "failed to legalize generated constant '{0}'",
-                            cstOp->getName()));
+                            createOp->getOperation()->getName()));
       rewriterImpl.resetState(curState);
       return failure();
     }
@@ -2132,9 +2144,14 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
     // blocks in regions created by this pattern will already be legalized later
     // on. If we haven't built the set yet, build it now.
     if (operationsToIgnore.empty()) {
-      auto createdOps = ArrayRef<Operation *>(impl.createdOps)
-                            .drop_front(state.numCreatedOps);
-      operationsToIgnore.insert(createdOps.begin(), createdOps.end());
+      for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
+           ++i) {
+        auto *createOp =
+            dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+        if (!createOp)
+          continue;
+        operationsToIgnore.insert(createOp->getOperation());
+      }
     }
 
     // If this operation should be considered for re-legalization, try it.
@@ -2152,8 +2169,11 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
 LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
     ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
     RewriterState &state, RewriterState &newState) {
-  for (int i = state.numCreatedOps, e = newState.numCreatedOps; i != e; ++i) {
-    Operation *op = impl.createdOps[i];
+  for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
+    auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
+    if (!createOp)
+      continue;
+    Operation *op = createOp->getOperation();
     if (failed(legalize(op, rewriter))) {
       LLVM_DEBUG(logFailure(impl.logger,
                             "failed to legalize generated operation '{0}'({1})",
@@ -2583,10 +2603,16 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
     });
     return liveUserIt == val.user_end() ? nullptr : *liveUserIt;
   };
-  for (auto &r : rewriterImpl.rewrites)
-    if (auto *rewrite = dyn_cast<BlockTypeConversionRewrite>(r.get()))
-      if (failed(rewrite->materializeLiveConversions(findLiveUser)))
+  // Note: `rewrites` may be reallocated as the loop is running.
+  for (int64_t i = 0; i < static_cast<int64_t>(rewriterImpl.rewrites.size());
+       ++i) {
+    auto &rewrite = rewriterImpl.rewrites[i];
+    if (auto *blockTypeConversionRewrite =
+            dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
+      if (failed(blockTypeConversionRewrite->materializeLiveConversions(
+              findLiveUser)))
         return failure();
+  }
   return success();
 }
 


        


More information about the Mlir-commits mailing list