[llvm-branch-commits] [mlir] fe0ac00 - Revert "[mlir][Transforms][NFC] Dialect conversion: Remove "finalize" phase (…"

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Nov 20 17:40:36 PST 2024


Author: Matthias Springer
Date: 2024-11-21T10:40:33+09:00
New Revision: fe0ac007ca9e253e79d2dc0e95ce166efd585a5b

URL: https://github.com/llvm/llvm-project/commit/fe0ac007ca9e253e79d2dc0e95ce166efd585a5b
DIFF: https://github.com/llvm/llvm-project/commit/fe0ac007ca9e253e79d2dc0e95ce166efd585a5b.diff

LOG: Revert "[mlir][Transforms][NFC] Dialect conversion: Remove "finalize" phase (…"

This reverts commit aa65473c9ddcf3cbb80e63c38af842d05346374b.

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 03d483f73f255e..42fe5b925654a1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -75,10 +75,6 @@ namespace {
 /// This class wraps a IRMapping to provide recursive lookup
 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
 struct ConversionValueMapping {
-  /// Return "true" if an SSA value is mapped to the given value. May return
-  /// false positives.
-  bool isMappedTo(Value value) const { return mappedTo.contains(value); }
-
   /// Lookup the most recently mapped value with the desired type in the
   /// mapping.
   ///
@@ -103,18 +99,22 @@ struct ConversionValueMapping {
         assert(it != oldVal && "inserting cyclic mapping");
     });
     mapping.map(oldVal, newVal);
-    mappedTo.insert(newVal);
   }
 
   /// Drop the last mapping for the given value.
   void erase(Value value) { mapping.erase(value); }
 
+  /// Returns the inverse raw value mapping (without recursive query support).
+  DenseMap<Value, SmallVector<Value>> getInverse() const {
+    DenseMap<Value, SmallVector<Value>> inverse;
+    for (auto &it : mapping.getValueMap())
+      inverse[it.second].push_back(it.first);
+    return inverse;
+  }
+
 private:
   /// Current value mappings.
   IRMapping mapping;
-
-  /// All SSA values that are mapped to. May contain false positives.
-  DenseSet<Value> mappedTo;
 };
 } // namespace
 
@@ -434,9 +434,10 @@ class MoveBlockRewrite : public BlockRewrite {
 class BlockTypeConversionRewrite : public BlockRewrite {
 public:
   BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                             Block *block, Block *origBlock)
+                             Block *block, Block *origBlock,
+                             const TypeConverter *converter)
       : BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
-        origBlock(origBlock) {}
+        origBlock(origBlock), converter(converter) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -444,6 +445,8 @@ class BlockTypeConversionRewrite : public BlockRewrite {
 
   Block *getOrigBlock() const { return origBlock; }
 
+  const TypeConverter *getConverter() const { return converter; }
+
   void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
@@ -451,6 +454,9 @@ class BlockTypeConversionRewrite : public BlockRewrite {
 private:
   /// The original block that was requested to have its signature converted.
   Block *origBlock;
+
+  /// The type converter used to convert the arguments.
+  const TypeConverter *converter;
 };
 
 /// Replacing a block argument. This rewrite is not immediately reflected in the
@@ -459,10 +465,8 @@ class BlockTypeConversionRewrite : public BlockRewrite {
 class ReplaceBlockArgRewrite : public BlockRewrite {
 public:
   ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
-                         Block *block, BlockArgument arg,
-                         const TypeConverter *converter)
-      : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
-        converter(converter) {}
+                         Block *block, BlockArgument arg)
+      : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::ReplaceBlockArg;
@@ -474,9 +478,6 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
 
 private:
   BlockArgument arg;
-
-  /// The current type converter when the block argument was replaced.
-  const TypeConverter *converter;
 };
 
 /// An operation rewrite.
@@ -626,6 +627,8 @@ class ReplaceOperationRewrite : public OperationRewrite {
 
   void cleanup(RewriterBase &rewriter) override;
 
+  const TypeConverter *getConverter() const { return converter; }
+
 private:
   /// An optional type converter that can be used to materialize conversions
   /// between the new and old values if necessary.
@@ -822,14 +825,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                  ValueRange replacements, Value originalValue,
                                  const TypeConverter *converter);
 
-  /// Find a replacement value for the given SSA value in the conversion value
-  /// mapping. The replacement value must have the same type as the given SSA
-  /// value. If there is no replacement value with the correct type, find the
-  /// latest replacement value (regardless of the type) and build a source
-  /// materialization.
-  Value findOrBuildReplacementValue(Value value,
-                                    const TypeConverter *converter);
-
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -975,7 +970,7 @@ void BlockTypeConversionRewrite::rollback() {
 }
 
 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
-  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+  Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
   if (!repl)
     return;
 
@@ -1004,7 +999,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   // Compute replacement values.
   SmallVector<Value> replacements =
       llvm::map_to_vector(op->getResults(), [&](OpResult result) {
-        return rewriterImpl.findOrBuildReplacementValue(result, converter);
+        return rewriterImpl.mapping.lookupOrNull(result, result.getType());
       });
 
   // Notify the listener that the operation is about to be replaced.
@@ -1074,10 +1069,8 @@ void UnresolvedMaterializationRewrite::rollback() {
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
   IRRewriter rewriter(context, config.listener);
-  // Note: New rewrites may be added during the "commit" phase and the
-  // `rewrites` vector may reallocate.
-  for (size_t i = 0; i < rewrites.size(); ++i)
-    rewrites[i]->commit(rewriter);
+  for (auto &rewrite : rewrites)
+    rewrite->commit(rewriter);
 
   // Clean up all rewrites.
   for (auto &rewrite : rewrites)
@@ -1282,7 +1275,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
           /*inputs=*/ValueRange(),
           /*outputType=*/origArgType, /*originalType=*/Type(), converter);
       mapping.map(origArg, repl);
-      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+      appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
       continue;
     }
 
@@ -1292,7 +1285,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
              "invalid to provide a replacement value when the argument isn't "
              "dropped");
       mapping.map(origArg, repl);
-      appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+      appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
       continue;
     }
 
@@ -1305,10 +1298,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     insertNTo1Materialization(
         OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
         /*replacements=*/replArgs, /*outputValue=*/origArg, converter);
-    appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+    appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
   }
 
-  appendRewrite<BlockTypeConversionRewrite>(newBlock, block);
+  appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
 
   // Erase the old block. (It is just unlinked for now and will be erased during
   // cleanup.)
@@ -1378,41 +1371,6 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
   }
 }
 
-Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
-    Value value, const TypeConverter *converter) {
-  // Find a replacement value with the same type.
-  Value repl = mapping.lookupOrNull(value, value.getType());
-  if (repl)
-    return repl;
-
-  // Check if the value is dead. No replacement value is needed in that case.
-  // This is an approximate check that may have false negatives but does not
-  // require computing and traversing an inverse mapping. (We may end up
-  // building source materializations that are never used and that fold away.)
-  if (llvm::all_of(value.getUsers(),
-                   [&](Operation *op) { return replacedOps.contains(op); }) &&
-      !mapping.isMappedTo(value))
-    return Value();
-
-  // No replacement value was found. Get the latest replacement value
-  // (regardless of the type) and build a source materialization to the
-  // original type.
-  repl = mapping.lookupOrNull(value);
-  if (!repl) {
-    // No replacement value is registered in the mapping. This means that the
-    // value is dropped and no longer needed. (If the value were still needed,
-    // a source materialization producing a replacement value "out of thin air"
-    // would have already been created during `replaceOp` or
-    // `applySignatureConversion`.)
-    return Value();
-  }
-  Value castValue = buildUnresolvedMaterialization(
-      MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
-      /*inputs=*/repl, /*outputType=*/value.getType(),
-      /*originalType=*/Type(), converter);
-  return castValue;
-}
-
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
 
@@ -1639,8 +1597,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
                              << "'(in region of '" << parentOp->getName()
                              << "'(" << from.getOwner()->getParentOp() << ")\n";
   });
-  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
-                                              impl->currentTypeConverter);
+  impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
   impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
 }
 
@@ -2460,6 +2417,10 @@ struct OperationConverter {
   /// Converts an operation with the given rewriter.
   LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
 
+  /// This method is called after the conversion process to legalize any
+  /// remaining artifacts and complete the conversion.
+  void finalize(ConversionPatternRewriter &rewriter);
+
   /// Dialect conversion configuration.
   ConversionConfig config;
 
@@ -2580,6 +2541,11 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
     if (failed(convert(rewriter, op)))
       return rewriterImpl.undoRewrites(), failure();
 
+  // Now that all of the operations have been converted, finalize the conversion
+  // process to ensure any lingering conversion artifacts are cleaned up and
+  // legalized.
+  finalize(rewriter);
+
   // After a successful conversion, apply rewrites.
   rewriterImpl.applyRewrites();
 
@@ -2613,6 +2579,80 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   return success();
 }
 
+/// Finds a user of the given value, or of any other value that the given value
+/// replaced, that was not replaced in the conversion process.
+static Operation *findLiveUserOfReplaced(
+    Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
+    const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
+  SmallVector<Value> worklist = {initialValue};
+  while (!worklist.empty()) {
+    Value value = worklist.pop_back_val();
+
+    // Walk the users of this value to see if there are any live users that
+    // weren't replaced during conversion.
+    auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
+      return rewriterImpl.isOpIgnored(user);
+    });
+    if (liveUserIt != value.user_end())
+      return *liveUserIt;
+    auto mapIt = inverseMapping.find(value);
+    if (mapIt != inverseMapping.end())
+      worklist.append(mapIt->second);
+  }
+  return nullptr;
+}
+
+/// Helper function that returns the replaced values and the type converter if
+/// the given rewrite object is an "operation replacement" or a "block type
+/// conversion" (which corresponds to a "block replacement"). Otherwise, return
+/// an empty ValueRange and a null type converter pointer.
+static std::pair<ValueRange, const TypeConverter *>
+getReplacedValues(IRRewrite *rewrite) {
+  if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
+    return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
+  if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
+    return {blockRewrite->getOrigBlock()->getArguments(),
+            blockRewrite->getConverter()};
+  return {};
+}
+
+void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
+  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+  DenseMap<Value, SmallVector<Value>> inverseMapping =
+      rewriterImpl.mapping.getInverse();
+
+  // Process requested value replacements.
+  for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
+    ValueRange replacedValues;
+    const TypeConverter *converter;
+    std::tie(replacedValues, converter) =
+        getReplacedValues(rewriterImpl.rewrites[i].get());
+    for (Value originalValue : replacedValues) {
+      // If the type of this value changed and the value is still live, we need
+      // to materialize a conversion.
+      if (rewriterImpl.mapping.lookupOrNull(originalValue,
+                                            originalValue.getType()))
+        continue;
+      Operation *liveUser =
+          findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
+      if (!liveUser)
+        continue;
+
+      // Legalize this value replacement.
+      Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
+      assert(newValue && "replacement value not found");
+      Value castValue = rewriterImpl.buildUnresolvedMaterialization(
+          MaterializationKind::Source, computeInsertPoint(newValue),
+          originalValue.getLoc(),
+          /*inputs=*/newValue, /*outputType=*/originalValue.getType(),
+          /*originalType=*/Type(), converter);
+      rewriterImpl.mapping.map(originalValue, castValue);
+      inverseMapping[castValue].push_back(originalValue);
+      llvm::erase(inverseMapping[newValue], originalValue);
+    }
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Reconcile Unrealized Casts
 //===----------------------------------------------------------------------===//


        


More information about the llvm-branch-commits mailing list