[Mlir-commits] [mlir] [mlir][Transforms][NFC] Dialect Conversion: Move argument materialization logic (PR #96329)
Matthias Springer
llvmlistbot at llvm.org
Mon Jun 24 23:44:48 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/96329
>From 50eec6bea675fb47aa36bdb49276632faff4aae7 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 20 Jun 2024 17:40:07 +0200
Subject: [PATCH] remove materializeLiveConversions
---
.../Transforms/Utils/DialectConversion.cpp | 133 +++++++-----------
1 file changed, 52 insertions(+), 81 deletions(-)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 07ebd687ee2b3..47e03383304af 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -53,6 +53,16 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
});
}
+/// Computes an insertion point immediately after `value` is defined (at the
+/// block start for block arguments); uses of `value` there respect dominance.
+static OpBuilder::InsertPoint computeInsertPoint(Value value) {
+ Block *insertBlock = value.getParentBlock();
+ Block::iterator insertPt = insertBlock->begin();
+ if (OpResult inputRes = dyn_cast<OpResult>(value))
+ insertPt = ++inputRes.getOwner()->getIterator();
+ return OpBuilder::InsertPoint(insertBlock, insertPt);
+}
+
//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
@@ -445,11 +455,9 @@ class BlockTypeConversionRewrite : public BlockRewrite {
return rewrite->getKind() == Kind::BlockTypeConversion;
}
- /// Materialize any necessary conversions for converted arguments that have
- /// live users, using the provided `findLiveUser` to search for a user that
- /// survives the conversion process.
- LogicalResult
- materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+ Block *getOrigBlock() const { return origBlock; }
+
+ const TypeConverter *getConverter() const { return converter; }
void commit(RewriterBase &rewriter) override;
@@ -841,14 +849,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
Value buildUnresolvedMaterialization(MaterializationKind kind,
- Block *insertBlock,
- Block::iterator insertPt, Location loc,
+ OpBuilder::InsertPoint ip, Location loc,
ValueRange inputs, Type outputType,
Type origOutputType,
const TypeConverter *converter);
- Value buildUnresolvedTargetMaterialization(Location loc, Value input,
- Type outputType,
- const TypeConverter *converter);
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -981,49 +985,6 @@ void BlockTypeConversionRewrite::rollback() {
block->replaceAllUsesWith(origBlock);
}
-LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
- function_ref<Operation *(Value)> findLiveUser) {
- // Process the remapping for each of the original arguments.
- for (auto it : llvm::enumerate(origBlock->getArguments())) {
- BlockArgument origArg = it.value();
- // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
- OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
- builder.setInsertionPointToStart(block);
-
- // If the type of this argument changed and the argument is still live, we
- // need to materialize a conversion.
- if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
- continue;
- Operation *liveUser = findLiveUser(origArg);
- if (!liveUser)
- continue;
-
- Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
- assert(replacementValue && "replacement value not found");
- Value newArg;
- if (converter) {
- builder.setInsertionPointAfterValue(replacementValue);
- newArg = converter->materializeSourceConversion(
- builder, origArg.getLoc(), origArg.getType(), replacementValue);
- assert((!newArg || newArg.getType() == origArg.getType()) &&
- "materialization hook did not provide a value of the expected "
- "type");
- }
- if (!newArg) {
- InFlightDiagnostic diag =
- emitError(origArg.getLoc())
- << "failed to materialize conversion for block argument #"
- << it.index() << " that remained live after conversion, type was "
- << origArg.getType();
- diag.attachNote(liveUser->getLoc())
- << "see existing live user here: " << *liveUser;
- return failure();
- }
- rewriterImpl.mapping.map(origArg, newArg);
- }
- return success();
-}
-
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
if (!repl)
@@ -1196,8 +1157,10 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Type newOperandType = newOperand.getType();
if (currentTypeConverter && desiredType && newOperandType != desiredType) {
Location operandLoc = inputLoc ? *inputLoc : operand.getLoc();
- Value castValue = buildUnresolvedTargetMaterialization(
- operandLoc, newOperand, desiredType, currentTypeConverter);
+ Value castValue = buildUnresolvedMaterialization(
+ MaterializationKind::Target, computeInsertPoint(newOperand),
+ operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
+ /*origArgType=*/{}, currentTypeConverter);
mapping.map(mapping.lookupOrDefault(newOperand), castValue);
newOperand = castValue;
}
@@ -1325,8 +1288,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
Value repl = buildUnresolvedMaterialization(
- MaterializationKind::Source, newBlock, newBlock->begin(),
- origArg.getLoc(), /*inputs=*/ValueRange(),
+ MaterializationKind::Source,
+ OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+ /*inputs=*/ValueRange(),
/*outputType=*/origArgType, /*origArgType=*/{}, converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1351,8 +1315,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
Value repl = buildUnresolvedMaterialization(
- MaterializationKind::Argument, newBlock, newBlock->begin(),
- origArg.getLoc(), /*inputs=*/replArgs,
+ MaterializationKind::Argument,
+ OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
+ /*inputs=*/replArgs,
/*outputType=*/tryLegalizeType(origArgType), origArgType, converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1374,8 +1339,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
- MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
- Location loc, ValueRange inputs, Type outputType, Type origArgType,
+ MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
+ ValueRange inputs, Type outputType, Type origArgType,
const TypeConverter *converter) {
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1383,25 +1348,13 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
// Create an unresolved materialization. We use a new OpBuilder to avoid
// tracking the materialization like we do for other operations.
- OpBuilder builder(insertBlock, insertPt);
+ OpBuilder builder(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
origArgType);
return convertOp.getResult(0);
}
-Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
- Location loc, Value input, Type outputType,
- const TypeConverter *converter) {
- Block *insertBlock = input.getParentBlock();
- Block::iterator insertPt = insertBlock->begin();
- if (OpResult inputRes = dyn_cast<OpResult>(input))
- insertPt = ++inputRes.getOwner()->getIterator();
-
- return buildUnresolvedMaterialization(
- MaterializationKind::Target, insertBlock, insertPt, loc, input,
- outputType, /*origArgType=*/{}, converter);
-}
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -2515,9 +2468,9 @@ LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
- if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
- inverseMapping)) ||
- failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
+ if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
+ failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
+ inverseMapping)))
return failure();
// Process requested operation replacements.
@@ -2573,10 +2526,28 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
++i) {
auto &rewrite = rewriterImpl.rewrites[i];
if (auto *blockTypeConversionRewrite =
- dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
- if (failed(blockTypeConversionRewrite->materializeLiveConversions(
- findLiveUser)))
- return failure();
+ dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) {
+ // Process the remapping for each of the original arguments.
+ for (Value origArg :
+ blockTypeConversionRewrite->getOrigBlock()->getArguments()) {
+ // If the type of this argument changed and the argument is still live,
+ // we need to materialize a conversion.
+ if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+ continue;
+ Operation *liveUser = findLiveUser(origArg);
+ if (!liveUser)
+ continue;
+
+ Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
+ assert(replacementValue && "replacement value not found");
+ Value repl = rewriterImpl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(replacementValue),
+ origArg.getLoc(), /*inputs=*/replacementValue,
+ /*outputType=*/origArg.getType(), /*origArgType=*/{},
+ blockTypeConversionRewrite->getConverter());
+ rewriterImpl.mapping.map(origArg, repl);
+ }
+ }
}
return success();
}
More information about the Mlir-commits
mailing list