[llvm-branch-commits] [mlir] [mlir][Transforms][NFC][WIP] Turn unresolved materializations into `IRRewrite`s (PR #81761)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Feb 19 02:33:20 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit refactors the dialect conversion. The dialect conversion maintains a list of "IR rewrites" that are committed when the conversion succeeds and rolled back when it fails.

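This commit turns the creation of unresolved materializations (`unrealized_conversion_cast` ops) into `IRRewrite` objects (`UnresolvedMaterializationRewrite`). As a result, they no longer need separate bookkeeping in `RewriterState`. After this commit, `applyRewrites` consists only of calls to `IRRewrite::commit` and `IRRewrite::cleanup`. `discardRewrites` is removed: a failed conversion or an analysis-mode conversion now calls `undoRewrites`, which rolls back the recorded rewrites via `IRRewrite::rollback`.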


---

Patch is 25.35 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81761.diff


1 Files Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+176-195) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5b7ad4e7b8e281..4ef26a739e4ea1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -152,15 +152,11 @@ namespace {
 /// This class contains a snapshot of the current conversion rewriter state.
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
-  RewriterState(unsigned numUnresolvedMaterializations, unsigned numRewrites,
-                unsigned numIgnoredOperations, unsigned numErased)
-      : numUnresolvedMaterializations(numUnresolvedMaterializations),
-        numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
+  RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
+                unsigned numErased)
+      : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
         numErased(numErased) {}
 
-  /// The current number of unresolved materializations.
-  unsigned numUnresolvedMaterializations;
-
   /// The current number of rewrites performed.
   unsigned numRewrites;
 
@@ -171,109 +167,10 @@ struct RewriterState {
   unsigned numErased;
 };
 
-//===----------------------------------------------------------------------===//
-// UnresolvedMaterialization
-
-/// This class represents an unresolved materialization, i.e. a materialization
-/// that was inserted during conversion that needs to be legalized at the end of
-/// the conversion process.
-class UnresolvedMaterialization {
-public:
-  /// The type of materialization.
-  enum Kind {
-    /// This materialization materializes a conversion for an illegal block
-    /// argument type, to a legal one.
-    Argument,
-
-    /// This materialization materializes a conversion from an illegal type to a
-    /// legal one.
-    Target
-  };
-
-  UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr,
-                            const TypeConverter *converter = nullptr,
-                            Kind kind = Target, Type origOutputType = nullptr)
-      : op(op), converterAndKind(converter, kind),
-        origOutputType(origOutputType) {}
-
-  /// Return the temporary conversion operation inserted for this
-  /// materialization.
-  UnrealizedConversionCastOp getOp() const { return op; }
-
-  /// Return the type converter of this materialization (which may be null).
-  const TypeConverter *getConverter() const {
-    return converterAndKind.getPointer();
-  }
-
-  /// Return the kind of this materialization.
-  Kind getKind() const { return converterAndKind.getInt(); }
-
-  /// Set the kind of this materialization.
-  void setKind(Kind kind) { converterAndKind.setInt(kind); }
-
-  /// Return the original illegal output type of the input values.
-  Type getOrigOutputType() const { return origOutputType; }
-
-private:
-  /// The unresolved materialization operation created during conversion.
-  UnrealizedConversionCastOp op;
-
-  /// The corresponding type converter to use when resolving this
-  /// materialization, and the kind of this materialization.
-  llvm::PointerIntPair<const TypeConverter *, 1, Kind> converterAndKind;
-
-  /// The original output type. This is only used for argument conversions.
-  Type origOutputType;
-};
-} // namespace
-
-/// Build an unresolved materialization operation given an output type and set
-/// of input operands.
-static Value buildUnresolvedMaterialization(
-    UnresolvedMaterialization::Kind kind, Block *insertBlock,
-    Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType,
-    Type origOutputType, const TypeConverter *converter,
-    SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
-  // Avoid materializing an unnecessary cast.
-  if (inputs.size() == 1 && inputs.front().getType() == outputType)
-    return inputs.front();
-
-  // Create an unresolved materialization. We use a new OpBuilder to avoid
-  // tracking the materialization like we do for other operations.
-  OpBuilder builder(insertBlock, insertPt);
-  auto convertOp =
-      builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  unresolvedMaterializations.emplace_back(convertOp, converter, kind,
-                                          origOutputType);
-  return convertOp.getResult(0);
-}
-static Value buildUnresolvedArgumentMaterialization(
-    PatternRewriter &rewriter, Location loc, ValueRange inputs,
-    Type origOutputType, Type outputType, const TypeConverter *converter,
-    SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
-  return buildUnresolvedMaterialization(
-      UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(),
-      rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
-      converter, unresolvedMaterializations);
-}
-static Value buildUnresolvedTargetMaterialization(
-    Location loc, Value input, Type outputType, const TypeConverter *converter,
-    SmallVectorImpl<UnresolvedMaterialization> &unresolvedMaterializations) {
-  Block *insertBlock = input.getParentBlock();
-  Block::iterator insertPt = insertBlock->begin();
-  if (OpResult inputRes = dyn_cast<OpResult>(input))
-    insertPt = ++inputRes.getOwner()->getIterator();
-
-  return buildUnresolvedMaterialization(
-      UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input,
-      outputType, outputType, converter, unresolvedMaterializations);
-}
-
 //===----------------------------------------------------------------------===//
 // IR rewrites
 //===----------------------------------------------------------------------===//
 
-namespace {
 /// An IR rewrite that can be committed (upon success) or rolled back (upon
 /// failure).
 ///
@@ -295,7 +192,8 @@ class IRRewrite {
     MoveOperation,
     ModifyOperation,
     ReplaceOperation,
-    CreateOperation
+    CreateOperation,
+    UnresolvedMaterialization
   };
 
   virtual ~IRRewrite() = default;
@@ -602,7 +500,7 @@ class OperationRewrite : public IRRewrite {
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() >= Kind::MoveOperation &&
-           rewrite->getKind() <= Kind::CreateOperation;
+           rewrite->getKind() <= Kind::UnresolvedMaterialization;
   }
 
 protected:
@@ -721,6 +619,70 @@ class CreateOperationRewrite : public OperationRewrite {
 
   void rollback() override;
 };
+
+/// The type of materialization.
+enum MaterializationKind {
+  /// This materialization materializes a conversion for an illegal block
+  /// argument type, to a legal one.
+  Argument,
+
+  /// This materialization materializes a conversion from an illegal type to a
+  /// legal one.
+  Target
+};
+
+/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
+/// op. Unresolved materializations are erased at the end of the dialect
+/// conversion.
+class UnresolvedMaterializationRewrite : public OperationRewrite {
+public:
+  UnresolvedMaterializationRewrite(
+      ConversionPatternRewriterImpl &rewriterImpl,
+      UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
+      MaterializationKind kind = MaterializationKind::Target,
+      Type origOutputType = nullptr)
+      : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
+        converterAndKind(converter, kind), origOutputType(origOutputType) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::UnresolvedMaterialization;
+  }
+
+  UnrealizedConversionCastOp getOperation() const {
+    return cast<UnrealizedConversionCastOp>(op);
+  }
+
+  void rollback() override;
+
+  void cleanup() override;
+
+  /// Return the type converter of this materialization (which may be null).
+  const TypeConverter *getConverter() const {
+    return converterAndKind.getPointer();
+  }
+
+  /// Return the kind of this materialization.
+  MaterializationKind getMaterializationKind() const {
+    return converterAndKind.getInt();
+  }
+
+  /// Set the kind of this materialization.
+  void setMaterializationKind(MaterializationKind kind) {
+    converterAndKind.setInt(kind);
+  }
+
+  /// Return the original illegal output type of the input values.
+  Type getOrigOutputType() const { return origOutputType; }
+
+private:
+  /// The corresponding type converter to use when resolving this
+  /// materialization, and the kind of this materialization.
+  llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+      converterAndKind;
+
+  /// The original output type. This is only used for argument conversions.
+  Type origOutputType;
+};
 } // namespace
 
 /// Return "true" if there is an operation rewrite that matches the specified
@@ -763,14 +725,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
       : rewriter(rewriter), eraseRewriter(rewriter.getContext()),
         notifyCallback(nullptr) {}
 
-  /// Cleanup and destroy any generated rewrite operations. This method is
-  /// invoked when the conversion process fails.
-  void discardRewrites();
-
-  /// Apply all requested operation rewrites. This method is invoked when the
-  /// conversion process succeeds.
-  void applyRewrites();
-
   //===--------------------------------------------------------------------===//
   // State Management
   //===--------------------------------------------------------------------===//
@@ -778,6 +732,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Return the current state of the rewriter.
   RewriterState getCurrentState();
 
+  /// Apply all requested operation rewrites. This method is invoked when the
+  /// conversion process succeeds.
+  void applyRewrites();
+
   /// Reset the state of the rewriter to a previously saved point.
   void resetState(RewriterState state);
 
@@ -810,17 +768,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// removes them from being considered for legalization.
   void markNestedOpsIgnored(Operation *op);
 
-  /// Detach any operations nested in the given operation from their parent
-  /// blocks, and erase the given operation. This can be used when the nested
-  /// operations are scheduled for erasure themselves, so deleting the regions
-  /// of the given operation together with their content would result in
-  /// double-free. This happens, for example, when rolling back op creation in
-  /// the reverse order and if the nested ops were created before the parent op.
-  /// This function does not need to collect nested ops recursively because it
-  /// is expected to also be called for each nested op when it is about to be
-  /// deleted.
-  void detachNestedAndErase(Operation *op);
-
   //===--------------------------------------------------------------------===//
   // Type Conversion
   //===--------------------------------------------------------------------===//
@@ -859,6 +806,28 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
       Block *block, const TypeConverter *converter,
       TypeConverter::SignatureConversion &signatureConversion);
 
+  //===--------------------------------------------------------------------===//
+  // Materializations
+  //===--------------------------------------------------------------------===//
+  /// Build an unresolved materialization operation given an output type and set
+  /// of input operands.
+  Value buildUnresolvedMaterialization(MaterializationKind kind,
+                                       Block *insertBlock,
+                                       Block::iterator insertPt, Location loc,
+                                       ValueRange inputs, Type outputType,
+                                       Type origOutputType,
+                                       const TypeConverter *converter);
+
+  Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter,
+                                               Location loc, ValueRange inputs,
+                                               Type origOutputType,
+                                               Type outputType,
+                                               const TypeConverter *converter);
+
+  Value buildUnresolvedTargetMaterialization(Location loc, Value input,
+                                             Type outputType,
+                                             const TypeConverter *converter);
+
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -938,10 +907,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
 
-  /// Ordered vector of all unresolved type conversion materializations during
-  /// conversion.
-  SmallVector<UnresolvedMaterialization> unresolvedMaterializations;
-
   /// Ordered list of block operations (creations, splits, motions).
   SmallVector<std::unique_ptr<IRRewrite>> rewrites;
 
@@ -1129,26 +1094,15 @@ void CreateOperationRewrite::rollback() {
   eraseOp(op);
 }
 
-void ConversionPatternRewriterImpl::detachNestedAndErase(Operation *op) {
-  // if (erasedIR.erasedOps.contains(op)) return;
-
-  for (Region &region : op->getRegions()) {
-    for (Block &block : region.getBlocks()) {
-      while (!block.getOperations().empty())
-        block.getOperations().remove(block.getOperations().begin());
-      block.dropAllDefinedValueUses();
-    }
+void UnresolvedMaterializationRewrite::rollback() {
+  if (getMaterializationKind() == MaterializationKind::Target) {
+    for (Value input : op->getOperands())
+      rewriterImpl.mapping.erase(input);
   }
-  eraseRewriter.eraseOp(op);
+  eraseOp(op);
 }
 
-void ConversionPatternRewriterImpl::discardRewrites() {
-  undoRewrites();
-
-  // Remove any newly created ops.
-  for (UnresolvedMaterialization &materialization : unresolvedMaterializations)
-    detachNestedAndErase(materialization.getOp());
-}
+void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
 
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
@@ -1156,39 +1110,20 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrite->commit();
   for (auto &rewrite : rewrites)
     rewrite->cleanup();
-
-  // Drop all of the unresolved materialization operations created during
-  // conversion.
-  for (auto &mat : unresolvedMaterializations)
-    eraseRewriter.eraseOp(mat.getOp());
 }
 
 //===----------------------------------------------------------------------===//
 // State Management
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
-  return RewriterState(unresolvedMaterializations.size(), rewrites.size(),
-                       ignoredOps.size(), eraseRewriter.erased.size());
+  return RewriterState(rewrites.size(), ignoredOps.size(),
+                       eraseRewriter.erased.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   // Undo any rewrites.
   undoRewrites(state.numRewrites);
 
-  // Pop all of the newly inserted materializations.
-  while (unresolvedMaterializations.size() !=
-         state.numUnresolvedMaterializations) {
-    UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val();
-    UnrealizedConversionCastOp op = mat.getOp();
-
-    // If this was a target materialization, drop the mapping that was inserted.
-    if (mat.getKind() == UnresolvedMaterialization::Target) {
-      for (Value input : op->getOperands())
-        mapping.erase(input);
-    }
-    detachNestedAndErase(op);
-  }
-
   // Pop all of the recorded ignored operations that are no longer valid.
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
@@ -1249,8 +1184,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     if (currentTypeConverter && desiredType && newOperandType != desiredType) {
       Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
       Value castValue = buildUnresolvedTargetMaterialization(
-          operandLoc, newOperand, desiredType, currentTypeConverter,
-          unresolvedMaterializations);
+          operandLoc, newOperand, desiredType, currentTypeConverter);
       mapping.map(mapping.lookupOrDefault(newOperand), castValue);
       newOperand = castValue;
     }
@@ -1432,7 +1366,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 
       newArg = buildUnresolvedArgumentMaterialization(
           rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
-          converter, unresolvedMaterializations);
+          converter);
     }
 
     mapping.map(origArg, newArg);
@@ -1445,6 +1379,50 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   return newBlock;
 }
 
+//===----------------------------------------------------------------------===//
+// Materializations
+//===----------------------------------------------------------------------===//
+
+/// Build an unresolved materialization operation given an output type and set
+/// of input operands.
+Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
+    MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
+    Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+    const TypeConverter *converter) {
+  // Avoid materializing an unnecessary cast.
+  if (inputs.size() == 1 && inputs.front().getType() == outputType)
+    return inputs.front();
+
+  // Create an unresolved materialization. We use a new OpBuilder to avoid
+  // tracking the materialization like we do for other operations.
+  OpBuilder builder(insertBlock, insertPt);
+  auto convertOp =
+      builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+                                                  origOutputType);
+  return convertOp.getResult(0);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
+    PatternRewriter &rewriter, Location loc, ValueRange inputs,
+    Type origOutputType, Type outputType, const TypeConverter *converter) {
+  return buildUnresolvedMaterialization(
+      MaterializationKind::Argument, rewriter.getInsertionBlock(),
+      rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
+      converter);
+}
+Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
+    Location loc, Value input, Type outputType,
+    const TypeConverter *converter) {
+  Block *insertBlock = input.getParentBlock();
+  Block::iterator insertPt = insertBlock->begin();
+  if (OpResult inputRes = dyn_cast<OpResult>(input))
+    insertPt = ++inputRes.getOwner()->getIterator();
+
+  return buildUnresolvedMaterialization(MaterializationKind::Target,
+                                        insertBlock, insertPt, loc, input,
+                                        outputType, outputType, converter);
+}
+
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -2497,18 +2475,18 @@ LogicalResult OperationConverter::convertOperations(
 
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
-      return rewriterImpl.discardRewrites(), failure();
+      return rewriterImpl.undoRewrites(), failure();
 
   // Now that all of the operations have been converted, finalize the conversion
   // process to ensure any lingering conversion artifacts are cleaned up and
   // legalized.
   if (failed(finalize(rewriter)))
-    return rewriterImpl.discardRewrites(), failure();
+    return rewriterImpl.undoRewrites(), failure();
 
   // After a successful conversion, apply rewrites if this is not an analysis
   // conversion.
   if (mode == OpConversionMode::Analysis) {
-    rewriterImpl.discardRewrites();
+    rewriterImpl.undoRewrites();
   } else {
     rewriterImpl.applyRewrites();
   }
@@ -2613,11 +2591,12 @@ replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
 /// Compute all of the unresolved materializations that will persist beyond the
 /// conversion process, and require inserting a proper user materialization for.
 static void computeNecessaryMaterializations(
-    DenseMap<Operation *, UnresolvedMaterialization *> &materializati...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/81761


More information about the llvm-branch-commits mailing list