[Mlir-commits] [mlir] f2d500c - [mlir][Transforms] Dialect conversion: Fix bug in `UnresolvedMaterializationRewrite` rollback (#105949)
llvmlistbot at llvm.org
Thu Nov 28 18:12:30 PST 2024
Author: Matthias Springer
Date: 2024-11-29T11:12:27+09:00
New Revision: f2d500c61701fc50f5c0c2cd9660a93e15ecc9b9
URL: https://github.com/llvm/llvm-project/commit/f2d500c61701fc50f5c0c2cd9660a93e15ecc9b9
DIFF: https://github.com/llvm/llvm-project/commit/f2d500c61701fc50f5c0c2cd9660a93e15ecc9b9.diff
LOG: [mlir][Transforms] Dialect conversion: Fix bug in `UnresolvedMaterializationRewrite` rollback (#105949)
When an unresolved materialization (`unrealized_conversion_cast` op) is
rolled back, the corresponding entry in the conversion value mapping
should be rolled back as well, regardless of whether it is a source,
target or argument materialization. Previously, only target
materializations erased their mapping entries on rollback. For all other
kinds, we accumulated pointers to erased IR in the `mapping`. This is
harmless in most cases, but can cause issues when a new operation is
allocated at the same memory location and the pointer is "reused".
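The fix passes the value to be mapped into
`buildUnresolvedMaterialization`. That function records the mapping and
stores the value in the `UnresolvedMaterializationRewrite`, so rollback
erases exactly that entry.
No test case is included: the failure only manifests when a newly
allocated operation reuses the memory of an erased one, and that pointer
reuse cannot be triggered reliably from a test.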
Added:
Modified:
mlir/lib/Transforms/Utils/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 710c976281dc3d..1424c4974f2d43 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -676,7 +676,8 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op,
const TypeConverter *converter,
- MaterializationKind kind, Type originalType);
+ MaterializationKind kind, Type originalType,
+ Value mappedValue);
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -710,6 +711,10 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
/// The original type of the SSA value. Only used for target
/// materializations.
Type originalType;
+
+ /// The value in the conversion value mapping that is being replaced by the
+ /// results of this unresolved materialization.
+ Value mappedValue;
};
} // namespace
@@ -814,10 +819,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Build an unresolved materialization operation given an output type and set
/// of input operands.
+ ///
+ /// If `valueToMap` is set to a non-null Value, then that value is mapped to
+ /// the result of the unresolved materialization in the conversion value
+ /// mapping.
Value buildUnresolvedMaterialization(MaterializationKind kind,
OpBuilder::InsertPoint ip, Location loc,
- ValueRange inputs, Type outputType,
- Type originalType,
+ Value valueToMap, ValueRange inputs,
+ Type outputType, Type originalType,
const TypeConverter *converter);
/// Build an N:1 materialization for the given original value that was
@@ -1068,19 +1077,19 @@ void CreateOperationRewrite::rollback() {
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
- const TypeConverter *converter, MaterializationKind kind, Type originalType)
+ const TypeConverter *converter, MaterializationKind kind, Type originalType,
+ Value mappedValue)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind), originalType(originalType) {
+ converterAndKind(converter, kind), originalType(originalType),
+ mappedValue(mappedValue) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
rewriterImpl.unresolvedMaterializations[op] = this;
}
void UnresolvedMaterializationRewrite::rollback() {
- if (getMaterializationKind() == MaterializationKind::Target) {
- for (Value input : op->getOperands())
- rewriterImpl.mapping.erase(input);
- }
+ if (mappedValue)
+ rewriterImpl.mapping.erase(mappedValue);
rewriterImpl.unresolvedMaterializations.erase(getOperation());
op->erase();
}
@@ -1176,10 +1185,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// source materialization was created yet.
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(newOperand),
- operandLoc,
- /*inputs=*/newOperand, /*outputType=*/desiredType,
- /*originalType=*/origType, currentTypeConverter);
- mapping.map(newOperand, castValue);
+ operandLoc, /*valueToMap=*/newOperand, /*inputs=*/newOperand,
+ /*outputType=*/desiredType, /*originalType=*/origType,
+ currentTypeConverter);
newOperand = castValue;
}
remapped.push_back(newOperand);
@@ -1293,12 +1301,11 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (!inputMap) {
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
- Value repl = buildUnresolvedMaterialization(
+ buildUnresolvedMaterialization(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*inputs=*/ValueRange(),
+ /*valueToMap=*/origArg, /*inputs=*/ValueRange(),
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
- mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
continue;
}
@@ -1342,14 +1349,17 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- ValueRange inputs, Type outputType, Type originalType,
+ Value valueToMap, ValueRange inputs, Type outputType, Type originalType,
const TypeConverter *converter) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
// Avoid materializing an unnecessary cast.
- if (inputs.size() == 1 && inputs.front().getType() == outputType)
+ if (inputs.size() == 1 && inputs.front().getType() == outputType) {
+ if (valueToMap)
+ mapping.map(valueToMap, inputs.front());
return inputs.front();
+ }
// Create an unresolved materialization. We use a new OpBuilder to avoid
// tracking the materialization like we do for other operations.
@@ -1357,8 +1367,10 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
+ if (valueToMap)
+ mapping.map(valueToMap, convertOp.getResult(0));
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
- originalType);
+ originalType, valueToMap);
return convertOp.getResult(0);
}
@@ -1367,11 +1379,10 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
Value originalValue, const TypeConverter *converter) {
// Insert argument materialization back to the original type.
Type originalType = originalValue.getType();
- Value argMat =
- buildUnresolvedMaterialization(MaterializationKind::Argument, ip, loc,
- /*inputs=*/replacements, originalType,
- /*originalType=*/Type(), converter);
- mapping.map(originalValue, argMat);
+ Value argMat = buildUnresolvedMaterialization(
+ MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
+ /*inputs=*/replacements, originalType, /*originalType=*/Type(),
+ converter);
// Insert target materialization to the legalized type.
Type legalOutputType;
@@ -1387,11 +1398,11 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
legalOutputType = replacements[0].getType();
}
if (legalOutputType && legalOutputType != originalType) {
- Value targetMat = buildUnresolvedMaterialization(
- MaterializationKind::Target, computeInsertPoint(argMat), loc,
- /*inputs=*/argMat, /*outputType=*/legalOutputType,
- /*originalType=*/originalType, converter);
- mapping.map(argMat, targetMat);
+ buildUnresolvedMaterialization(MaterializationKind::Target,
+ computeInsertPoint(argMat), loc,
+ /*valueToMap=*/argMat, /*inputs=*/argMat,
+ /*outputType=*/legalOutputType,
+ /*originalType=*/originalType, converter);
}
}
@@ -1425,9 +1436,8 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
}
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),
- /*inputs=*/repl, /*outputType=*/value.getType(),
+ /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(),
/*originalType=*/Type(), converter);
- mapping.map(value, castValue);
return castValue;
}
@@ -1480,7 +1490,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(
// Materialize a replacement value "out of thin air".
Value sourceMat = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result),
- result.getLoc(), /*inputs=*/ValueRange(),
+ result.getLoc(), /*valueToMap=*/Value(), /*inputs=*/ValueRange(),
/*outputType=*/result.getType(), /*originalType=*/Type(),
currentTypeConverter);
repl.push_back(sourceMat);
More information about the Mlir-commits mailing list