[llvm-branch-commits] [mlir] fe0ac00 - Revert "[mlir][Transforms][NFC] Dialect conversion: Remove "finalize" phase (…"
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Nov 20 17:40:36 PST 2024
Author: Matthias Springer
Date: 2024-11-21T10:40:33+09:00
New Revision: fe0ac007ca9e253e79d2dc0e95ce166efd585a5b
URL: https://github.com/llvm/llvm-project/commit/fe0ac007ca9e253e79d2dc0e95ce166efd585a5b
DIFF: https://github.com/llvm/llvm-project/commit/fe0ac007ca9e253e79d2dc0e95ce166efd585a5b.diff
LOG: Revert "[mlir][Transforms][NFC] Dialect conversion: Remove "finalize" phase (…"
This reverts commit aa65473c9ddcf3cbb80e63c38af842d05346374b.
Added:
Modified:
mlir/lib/Transforms/Utils/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 03d483f73f255e..42fe5b925654a1 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -75,10 +75,6 @@ namespace {
/// This class wraps a IRMapping to provide recursive lookup
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
struct ConversionValueMapping {
- /// Return "true" if an SSA value is mapped to the given value. May return
- /// false positives.
- bool isMappedTo(Value value) const { return mappedTo.contains(value); }
-
/// Lookup the most recently mapped value with the desired type in the
/// mapping.
///
@@ -103,18 +99,22 @@ struct ConversionValueMapping {
assert(it != oldVal && "inserting cyclic mapping");
});
mapping.map(oldVal, newVal);
- mappedTo.insert(newVal);
}
/// Drop the last mapping for the given value.
void erase(Value value) { mapping.erase(value); }
+ /// Returns the inverse raw value mapping (without recursive query support).
+ DenseMap<Value, SmallVector<Value>> getInverse() const {
+ DenseMap<Value, SmallVector<Value>> inverse;
+ for (auto &it : mapping.getValueMap())
+ inverse[it.second].push_back(it.first);
+ return inverse;
+ }
+
private:
/// Current value mappings.
IRMapping mapping;
-
- /// All SSA values that are mapped to. May contain false positives.
- DenseSet<Value> mappedTo;
};
} // namespace
@@ -434,9 +434,10 @@ class MoveBlockRewrite : public BlockRewrite {
class BlockTypeConversionRewrite : public BlockRewrite {
public:
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block, Block *origBlock)
+ Block *block, Block *origBlock,
+ const TypeConverter *converter)
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
- origBlock(origBlock) {}
+ origBlock(origBlock), converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -444,6 +445,8 @@ class BlockTypeConversionRewrite : public BlockRewrite {
Block *getOrigBlock() const { return origBlock; }
+ const TypeConverter *getConverter() const { return converter; }
+
void commit(RewriterBase &rewriter) override;
void rollback() override;
@@ -451,6 +454,9 @@ class BlockTypeConversionRewrite : public BlockRewrite {
private:
/// The original block that was requested to have its signature converted.
Block *origBlock;
+
+ /// The type converter used to convert the arguments.
+ const TypeConverter *converter;
};
/// Replacing a block argument. This rewrite is not immediately reflected in the
@@ -459,10 +465,8 @@ class BlockTypeConversionRewrite : public BlockRewrite {
class ReplaceBlockArgRewrite : public BlockRewrite {
public:
ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block, BlockArgument arg,
- const TypeConverter *converter)
- : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg),
- converter(converter) {}
+ Block *block, BlockArgument arg)
+ : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::ReplaceBlockArg;
@@ -474,9 +478,6 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
private:
BlockArgument arg;
-
- /// The current type converter when the block argument was replaced.
- const TypeConverter *converter;
};
/// An operation rewrite.
@@ -626,6 +627,8 @@ class ReplaceOperationRewrite : public OperationRewrite {
void cleanup(RewriterBase &rewriter) override;
+ const TypeConverter *getConverter() const { return converter; }
+
private:
/// An optional type converter that can be used to materialize conversions
/// between the new and old values if necessary.
@@ -822,14 +825,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ValueRange replacements, Value originalValue,
const TypeConverter *converter);
- /// Find a replacement value for the given SSA value in the conversion value
- /// mapping. The replacement value must have the same type as the given SSA
- /// value. If there is no replacement value with the correct type, find the
- /// latest replacement value (regardless of the type) and build a source
- /// materialization.
- Value findOrBuildReplacementValue(Value value,
- const TypeConverter *converter);
-
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
@@ -975,7 +970,7 @@ void BlockTypeConversionRewrite::rollback() {
}
void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+ Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
if (!repl)
return;
@@ -1004,7 +999,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
// Compute replacement values.
SmallVector<Value> replacements =
llvm::map_to_vector(op->getResults(), [&](OpResult result) {
- return rewriterImpl.findOrBuildReplacementValue(result, converter);
+ return rewriterImpl.mapping.lookupOrNull(result, result.getType());
});
// Notify the listener that the operation is about to be replaced.
@@ -1074,10 +1069,8 @@ void UnresolvedMaterializationRewrite::rollback() {
void ConversionPatternRewriterImpl::applyRewrites() {
// Commit all rewrites.
IRRewriter rewriter(context, config.listener);
- // Note: New rewrites may be added during the "commit" phase and the
- // `rewrites` vector may reallocate.
- for (size_t i = 0; i < rewrites.size(); ++i)
- rewrites[i]->commit(rewriter);
+ for (auto &rewrite : rewrites)
+ rewrite->commit(rewriter);
// Clean up all rewrites.
for (auto &rewrite : rewrites)
@@ -1282,7 +1275,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/*inputs=*/ValueRange(),
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
mapping.map(origArg, repl);
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}
@@ -1292,7 +1285,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, repl);
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}
@@ -1305,10 +1298,10 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
insertNTo1Materialization(
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
}
- appendRewrite<BlockTypeConversionRewrite>(newBlock, block);
+ appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
@@ -1378,41 +1371,6 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
}
}
-Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
- Value value, const TypeConverter *converter) {
- // Find a replacement value with the same type.
- Value repl = mapping.lookupOrNull(value, value.getType());
- if (repl)
- return repl;
-
- // Check if the value is dead. No replacement value is needed in that case.
- // This is an approximate check that may have false negatives but does not
- // require computing and traversing an inverse mapping. (We may end up
- // building source materializations that are never used and that fold away.)
- if (llvm::all_of(value.getUsers(),
- [&](Operation *op) { return replacedOps.contains(op); }) &&
- !mapping.isMappedTo(value))
- return Value();
-
- // No replacement value was found. Get the latest replacement value
- // (regardless of the type) and build a source materialization to the
- // original type.
- repl = mapping.lookupOrNull(value);
- if (!repl) {
- // No replacement value is registered in the mapping. This means that the
- // value is dropped and no longer needed. (If the value were still needed,
- // a source materialization producing a replacement value "out of thin air"
- // would have already been created during `replaceOp` or
- // `applySignatureConversion`.)
- return Value();
- }
- Value castValue = buildUnresolvedMaterialization(
- MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
- /*inputs=*/repl, /*outputType=*/value.getType(),
- /*originalType=*/Type(), converter);
- return castValue;
-}
-
//===----------------------------------------------------------------------===//
// Rewriter Notification Hooks
@@ -1639,8 +1597,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
<< "'(in region of '" << parentOp->getName()
<< "'(" << from.getOwner()->getParentOp() << ")\n";
});
- impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from,
- impl->currentTypeConverter);
+ impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
impl->mapping.map(impl->mapping.lookupOrDefault(from), to);
}
@@ -2460,6 +2417,10 @@ struct OperationConverter {
/// Converts an operation with the given rewriter.
LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
+ /// This method is called after the conversion process to legalize any
+ /// remaining artifacts and complete the conversion.
+ void finalize(ConversionPatternRewriter &rewriter);
+
/// Dialect conversion configuration.
ConversionConfig config;
@@ -2580,6 +2541,11 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (failed(convert(rewriter, op)))
return rewriterImpl.undoRewrites(), failure();
+ // Now that all of the operations have been converted, finalize the conversion
+ // process to ensure any lingering conversion artifacts are cleaned up and
+ // legalized.
+ finalize(rewriter);
+
// After a successful conversion, apply rewrites.
rewriterImpl.applyRewrites();
@@ -2613,6 +2579,80 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
return success();
}
+/// Finds a user of the given value, or of any other value that the given value
+/// replaced, that was not replaced in the conversion process.
+static Operation *findLiveUserOfReplaced(
+ Value initialValue, ConversionPatternRewriterImpl &rewriterImpl,
+ const DenseMap<Value, SmallVector<Value>> &inverseMapping) {
+ SmallVector<Value> worklist = {initialValue};
+ while (!worklist.empty()) {
+ Value value = worklist.pop_back_val();
+
+ // Walk the users of this value to see if there are any live users that
+ // weren't replaced during conversion.
+ auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) {
+ return rewriterImpl.isOpIgnored(user);
+ });
+ if (liveUserIt != value.user_end())
+ return *liveUserIt;
+ auto mapIt = inverseMapping.find(value);
+ if (mapIt != inverseMapping.end())
+ worklist.append(mapIt->second);
+ }
+ return nullptr;
+}
+
+/// Helper function that returns the replaced values and the type converter if
+/// the given rewrite object is an "operation replacement" or a "block type
+/// conversion" (which corresponds to a "block replacement"). Otherwise, return
+/// an empty ValueRange and a null type converter pointer.
+static std::pair<ValueRange, const TypeConverter *>
+getReplacedValues(IRRewrite *rewrite) {
+ if (auto *opRewrite = dyn_cast<ReplaceOperationRewrite>(rewrite))
+ return {opRewrite->getOperation()->getResults(), opRewrite->getConverter()};
+ if (auto *blockRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite))
+ return {blockRewrite->getOrigBlock()->getArguments(),
+ blockRewrite->getConverter()};
+ return {};
+}
+
+void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
+ ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+ DenseMap<Value, SmallVector<Value>> inverseMapping =
+ rewriterImpl.mapping.getInverse();
+
+ // Process requested value replacements.
+ for (unsigned i = 0, e = rewriterImpl.rewrites.size(); i < e; ++i) {
+ ValueRange replacedValues;
+ const TypeConverter *converter;
+ std::tie(replacedValues, converter) =
+ getReplacedValues(rewriterImpl.rewrites[i].get());
+ for (Value originalValue : replacedValues) {
+ // If the type of this value changed and the value is still live, we need
+ // to materialize a conversion.
+ if (rewriterImpl.mapping.lookupOrNull(originalValue,
+ originalValue.getType()))
+ continue;
+ Operation *liveUser =
+ findLiveUserOfReplaced(originalValue, rewriterImpl, inverseMapping);
+ if (!liveUser)
+ continue;
+
+ // Legalize this value replacement.
+ Value newValue = rewriterImpl.mapping.lookupOrNull(originalValue);
+ assert(newValue && "replacement value not found");
+ Value castValue = rewriterImpl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(newValue),
+ originalValue.getLoc(),
+ /*inputs=*/newValue, /*outputType=*/originalValue.getType(),
+ /*originalType=*/Type(), converter);
+ rewriterImpl.mapping.map(originalValue, castValue);
+ inverseMapping[castValue].push_back(originalValue);
+ llvm::erase(inverseMapping[newValue], originalValue);
+ }
+ }
+}
+
//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//
More information about the llvm-branch-commits mailing list