[Mlir-commits] [mlir] 605098d - Revert "[mlir][Transforms][NFC] Dialect Conversion: Move argument materialization logic (#96329)"

Benjamin Kramer llvmlistbot at llvm.org
Thu Jun 27 00:27:22 PDT 2024


Author: Benjamin Kramer
Date: 2024-06-27T09:16:17+02:00
New Revision: 605098dcd4e79b27c86784b1a3d7fc6e3010ce00

URL: https://github.com/llvm/llvm-project/commit/605098dcd4e79b27c86784b1a3d7fc6e3010ce00
DIFF: https://github.com/llvm/llvm-project/commit/605098dcd4e79b27c86784b1a3d7fc6e3010ce00.diff

LOG: Revert "[mlir][Transforms][NFC] Dialect Conversion: Move argument materialization logic (#96329)"

This reverts commit c01ce797619359ee282773dfc4b1e91ff0a30435. It depends
on f1e0657d144f5a3cfef4b625d0f875f4dacd21d1 which breaks SCF lowering.

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 47e03383304af..07ebd687ee2b3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -53,16 +53,6 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
   });
 }
 
-/// Helper function that computes an insertion point where the given value is
-/// defined and can be used without a dominance violation.
-static OpBuilder::InsertPoint computeInsertPoint(Value value) {
-  Block *insertBlock = value.getParentBlock();
-  Block::iterator insertPt = insertBlock->begin();
-  if (OpResult inputRes = dyn_cast<OpResult>(value))
-    insertPt = ++inputRes.getOwner()->getIterator();
-  return OpBuilder::InsertPoint(insertBlock, insertPt);
-}
-
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
@@ -455,9 +445,11 @@ class BlockTypeConversionRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::BlockTypeConversion;
   }
 
-  Block *getOrigBlock() const { return origBlock; }
-
-  const TypeConverter *getConverter() const { return converter; }
+  /// Materialize any necessary conversions for converted arguments that have
+  /// live users, using the provided `findLiveUser` to search for a user that
+  /// survives the conversion process.
+  LogicalResult
+  materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
 
   void commit(RewriterBase &rewriter) override;
 
@@ -849,10 +841,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Build an unresolved materialization operation given an output type and set
   /// of input operands.
   Value buildUnresolvedMaterialization(MaterializationKind kind,
-                                       OpBuilder::InsertPoint ip, Location loc,
+                                       Block *insertBlock,
+                                       Block::iterator insertPt, Location loc,
                                        ValueRange inputs, Type outputType,
                                        Type origOutputType,
                                        const TypeConverter *converter);
+  Value buildUnresolvedTargetMaterialization(Location loc, Value input,
+                                             Type outputType,
+                                             const TypeConverter *converter);
 
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
@@ -985,6 +981,49 @@ void BlockTypeConversionRewrite::rollback() {
   block->replaceAllUsesWith(origBlock);
 }
 
+LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
+    function_ref<Operation *(Value)> findLiveUser) {
+  // Process the remapping for each of the original arguments.
+  for (auto it : llvm::enumerate(origBlock->getArguments())) {
+    BlockArgument origArg = it.value();
+    // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
+    OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
+    builder.setInsertionPointToStart(block);
+
+    // If the type of this argument changed and the argument is still live, we
+    // need to materialize a conversion.
+    if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+      continue;
+    Operation *liveUser = findLiveUser(origArg);
+    if (!liveUser)
+      continue;
+
+    Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
+    assert(replacementValue && "replacement value not found");
+    Value newArg;
+    if (converter) {
+      builder.setInsertionPointAfterValue(replacementValue);
+      newArg = converter->materializeSourceConversion(
+          builder, origArg.getLoc(), origArg.getType(), replacementValue);
+      assert((!newArg || newArg.getType() == origArg.getType()) &&
+             "materialization hook did not provide a value of the expected "
+             "type");
+    }
+    if (!newArg) {
+      InFlightDiagnostic diag =
+          emitError(origArg.getLoc())
+          << "failed to materialize conversion for block argument #"
+          << it.index() << " that remained live after conversion, type was "
+          << origArg.getType();
+      diag.attachNote(liveUser->getLoc())
+          << "see existing live user here: " << *liveUser;
+      return failure();
+    }
+    rewriterImpl.mapping.map(origArg, newArg);
+  }
+  return success();
+}
+
 void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
   if (!repl)
@@ -1157,10 +1196,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     Type newOperandType = newOperand.getType();
     if (currentTypeConverter && desiredType && newOperandType != desiredType) {
       Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
-      Value castValue = buildUnresolvedMaterialization(
-          MaterializationKind::Target, computeInsertPoint(newOperand),
-          operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
-          /*origArgType=*/{}, currentTypeConverter);
+      Value castValue = buildUnresolvedTargetMaterialization(
+          operandLoc, newOperand, desiredType, currentTypeConverter);
       mapping.map(mapping.lookupOrDefault(newOperand), castValue);
       newOperand = castValue;
     }
@@ -1288,9 +1325,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
       // This block argument was dropped and no replacement value was provided.
       // Materialize a replacement value "out of thin air".
       Value repl = buildUnresolvedMaterialization(
-          MaterializationKind::Source,
-          OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-          /*inputs=*/ValueRange(),
+          MaterializationKind::Source, newBlock, newBlock->begin(),
+          origArg.getLoc(), /*inputs=*/ValueRange(),
           /*outputType=*/origArgType, /*origArgType=*/{}, converter);
       mapping.map(origArg, repl);
       appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1315,9 +1351,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     auto replArgs =
         newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
     Value repl = buildUnresolvedMaterialization(
-        MaterializationKind::Argument,
-        OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
-        /*inputs=*/replArgs,
+        MaterializationKind::Argument, newBlock, newBlock->begin(),
+        origArg.getLoc(), /*inputs=*/replArgs,
         /*outputType=*/tryLegalizeType(origArgType), origArgType, converter);
     mapping.map(origArg, repl);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1339,8 +1374,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// Build an unresolved materialization operation given an output type and set
 /// of input operands.
 Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
-    MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    ValueRange inputs, Type outputType, Type origArgType,
+    MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
+    Location loc, ValueRange inputs, Type outputType, Type origArgType,
     const TypeConverter *converter) {
   // Avoid materializing an unnecessary cast.
   if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1348,13 +1383,25 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
 
   // Create an unresolved materialization. We use a new OpBuilder to avoid
   // tracking the materialization like we do for other operations.
-  OpBuilder builder(ip.getBlock(), ip.getPoint());
+  OpBuilder builder(insertBlock, insertPt);
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
   appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
                                                   origArgType);
   return convertOp.getResult(0);
 }
+Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
+    Location loc, Value input, Type outputType,
+    const TypeConverter *converter) {
+  Block *insertBlock = input.getParentBlock();
+  Block::iterator insertPt = insertBlock->begin();
+  if (OpResult inputRes = dyn_cast<OpResult>(input))
+    insertPt = ++inputRes.getOwner()->getIterator();
+
+  return buildUnresolvedMaterialization(
+      MaterializationKind::Target, insertBlock, insertPt, loc, input,
+      outputType, /*origArgType=*/{}, converter);
+}
 
 //===----------------------------------------------------------------------===//
 // Rewriter Notification Hooks
@@ -2468,9 +2515,9 @@ LogicalResult
 OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
   std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
-  if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
-      failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
-                                                inverseMapping)))
+  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
+                                                inverseMapping)) ||
+      failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
     return failure();
 
   // Process requested operation replacements.
@@ -2526,28 +2573,10 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
        ++i) {
     auto &rewrite = rewriterImpl.rewrites[i];
     if (auto *blockTypeConversionRewrite =
-            dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
-      // Process the remapping for each of the original arguments.
-      for (Value origArg :
-           blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
-        // If the type of this argument changed and the argument is still live,
-        // we need to materialize a conversion.
-        if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-          continue;
-        Operation *liveUser = findLiveUser(origArg);
-        if (!liveUser)
-          continue;
-
-        Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
-        assert(replacementValue && "replacement value not found");
-        Value repl = rewriterImpl.buildUnresolvedMaterialization(
-            MaterializationKind::Source, computeInsertPoint(replacementValue),
-            origArg.getLoc(), /*inputs=*/replacementValue,
-            /*outputType=*/origArg.getType(), /*origArgType=*/{},
-            blockTypeConversionRewrite->getConverter());
-        rewriterImpl.mapping.map(origArg, repl);
-      }
-    }
+            dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
+      if (failed(blockTypeConversionRewrite->materializeLiveConversions(
+              findLiveUser)))
+        return failure();
   }
   return success();
 }


        


More information about the Mlir-commits mailing list