[Mlir-commits] [mlir] bbd4af5 - [mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments (#97213)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Jul 20 01:12:17 PDT 2024
Author: Matthias Springer
Date: 2024-07-20T10:12:13+02:00
New Revision: bbd4af5da2b741672a8e6f625eb12ea5c2d6220f
URL: https://github.com/llvm/llvm-project/commit/bbd4af5da2b741672a8e6f625eb12ea5c2d6220f
DIFF: https://github.com/llvm/llvm-project/commit/bbd4af5da2b741672a8e6f625eb12ea5c2d6220f.diff
LOG: [mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments (#97213)
This commit simplifies the handling of dropped arguments and updates
some dialect conversion documentation that is outdated.
When converting a block signature, a `BlockTypeConversionRewrite` object
and potentially multiple `ReplaceBlockArgRewrite` objects are created. During
the "commit" phase, uses of the old block arguments are replaced with
the new block arguments, but the old implementation was written in an
inconsistent way: some block arguments were replaced in
`BlockTypeConversionRewrite::commit` and some were replaced in
`ReplaceBlockArgRewrite::commit`. The new
`BlockTypeConversionRewrite::commit` implementation is much simpler and
no longer modifies any IR; that is done only in `ReplaceBlockArgRewrite`
now. The `ConvertedArgInfo` data structure is no longer needed.
To that end, materializations of dropped arguments are now built in
`applySignatureConversion` instead of `materializeLiveConversions`; the
latter function no longer has to deal with dropped arguments.
Other minor improvements:
- Add more comments to `applySignatureConversion`.
Note: Error messages around failed materializations for dropped basic
block arguments changed slightly. That is because those materializations
are now built in `legalizeUnresolvedMaterialization` instead of
`legalizeConvertedArgumentTypes`.
This commit is in preparation for decoupling argument/source/target
materializations from the dialect conversion.
This is a re-upload of #96207.
Added:
Modified:
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/Transforms/test-legalize-type-conversion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 1e0afee2373a9..0b552a7e1ca3b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
Block *insertBeforeBlock;
};
-/// This structure contains the information pertaining to an argument that has
-/// been converted.
-struct ConvertedArgInfo {
- ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
- Value castValue = nullptr)
- : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
- /// The start index of in the new argument list that contains arguments that
- /// replace the original.
- unsigned newArgIdx;
-
- /// The number of arguments that replaced the original argument.
- unsigned newArgSize;
-
- /// The cast value that was created to cast from the new arguments to the
- /// old. This only used if 'newArgSize' > 1.
- Value castValue;
-};
-
/// Block type conversion. This rewrite is partially reflected in the IR.
class BlockTypeConversionRewrite : public BlockRewrite {
public:
- BlockTypeConversionRewrite(
- ConversionPatternRewriterImpl &rewriterImpl, Block *block,
- Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
- const TypeConverter *converter)
+ BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Block *block, Block *origBlock,
+ const TypeConverter *converter)
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
- origBlock(origBlock), argInfo(argInfo), converter(converter) {}
+ origBlock(origBlock), converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
/// The original block that was requested to have its signature converted.
Block *origBlock;
- /// The conversion information for each of the arguments. The information is
- /// std::nullopt if the argument was dropped during conversion.
- SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
/// The type converter used to convert the arguments.
const TypeConverter *converter;
};
@@ -691,12 +667,16 @@ class CreateOperationRewrite : public OperationRewrite {
/// The type of materialization.
enum MaterializationKind {
/// This materialization materializes a conversion for an illegal block
- /// argument type, to a legal one.
+ /// argument type, to the original one.
Argument,
/// This materialization materializes a conversion from an illegal type to a
/// legal one.
- Target
+ Target,
+
+ /// This materialization materializes a conversion from a legal type back to
+ /// an illegal one.
+ Source
};
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -736,7 +716,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
- llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+ llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
converterAndKind;
};
} // namespace
@@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ValueRange inputs, Type outputType,
const TypeConverter *converter);
- Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
- ValueRange inputs,
- Type outputType,
- const TypeConverter *converter);
-
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
Type outputType,
const TypeConverter *converter);
@@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
for (Operation *op : block->getUsers())
listener->notifyOperationModified(op);
-
- // Process the remapping for each of the original arguments.
- for (auto [origArg, info] :
- llvm::zip_equal(origBlock->getArguments(), argInfo)) {
- // Handle the case of a 1->0 value mapping.
- if (!info) {
- if (Value newArg =
- rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
- rewriter.replaceAllUsesWith(origArg, newArg);
- continue;
- }
-
- // Otherwise this is a 1->1+ value mapping.
- Value castValue = info->castValue;
- assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
- // If the argument is still used, replace it with the generated cast.
- if (!origArg.use_empty()) {
- rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
- castValue, origArg.getType()));
- }
- }
}
void BlockTypeConversionRewrite::rollback() {
@@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
continue;
Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
- bool isDroppedArg = replacementValue == origArg;
- if (!isDroppedArg)
- builder.setInsertionPointAfterValue(replacementValue);
+ assert(replacementValue && "replacement value not found");
Value newArg;
if (converter) {
+ builder.setInsertionPointAfterValue(replacementValue);
newArg = converter->materializeSourceConversion(
- builder, origArg.getLoc(), origArg.getType(),
- isDroppedArg ? ValueRange() : ValueRange(replacementValue));
+ builder, origArg.getLoc(), origArg.getType(), replacementValue);
assert((!newArg || newArg.getType() == origArg.getType()) &&
"materialization hook did not provide a value of the expected "
"type");
@@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
<< "failed to materialize conversion for block argument #"
<< it.index() << " that remained live after conversion, type was "
<< origArg.getType();
- if (!isDroppedArg)
- diag << ", with target type " << replacementValue.getType();
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
@@ -1340,73 +1289,64 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// Replace all uses of the old block with the new block.
block->replaceAllUsesWith(newBlock);
- // Remap each of the original arguments as determined by the signature
- // conversion.
- SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
- argInfo.resize(origArgCount);
-
for (unsigned i = 0; i != origArgCount; ++i) {
- auto inputMap = signatureConversion.getInputMapping(i);
- if (!inputMap)
- continue;
BlockArgument origArg = block->getArgument(i);
+ Type origArgType = origArg.getType();
+
+ std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
+ signatureConversion.getInputMapping(i);
+ if (!inputMap) {
+ // This block argument was dropped and no replacement value was provided.
+ // Materialize a replacement value "out of thin air".
+ Value repl = buildUnresolvedMaterialization(
+ MaterializationKind::Source, newBlock, newBlock->begin(),
+ origArg.getLoc(), /*inputs=*/ValueRange(),
+ /*outputType=*/origArgType, converter);
+ mapping.map(origArg, repl);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+ continue;
+ }
- // If inputMap->replacementValue is not nullptr, then the argument is
- // dropped and a replacement value is provided to be the remappedValue.
- if (inputMap->replacementValue) {
+ if (Value repl = inputMap->replacementValue) {
+ // This block argument was dropped and a replacement value was provided.
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
- mapping.map(origArg, inputMap->replacementValue);
+ mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}
- // Otherwise, this is a 1->1+ mapping.
+ // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
+ // dialect conversion. Therefore, we need an argument materialization to
+ // turn the replacement block arguments into a single SSA value that can be
+ // used as a replacement.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- Value newArg;
+ Value argMat = buildUnresolvedMaterialization(
+ MaterializationKind::Argument, newBlock, newBlock->begin(),
+ origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
+ mapping.map(origArg, argMat);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
- // If this is a 1->1 mapping and the types of new and replacement arguments
- // match (i.e. it's an identity map), then the argument is mapped to its
- // original type.
// FIXME: We simply pass through the replacement argument if there wasn't a
// converter, which isn't great as it allows implicit type conversions to
// appear. We should properly restructure this code to handle cases where a
// converter isn't provided and also to properly handle the case where an
// argument materialization is actually a temporary source materialization
// (e.g. in the case of 1->N).
- if (replArgs.size() == 1 &&
- (!converter || replArgs[0].getType() == origArg.getType())) {
- newArg = replArgs.front();
- mapping.map(origArg, newArg);
- } else {
- // Build argument materialization: new block arguments -> old block
- // argument type.
- Value argMat = buildUnresolvedArgumentMaterialization(
- newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
- mapping.map(origArg, argMat);
-
- // Build target materialization: old block argument type -> legal type.
- // Note: This function returns an "empty" type if no valid conversion to
- // a legal type exists. In that case, we continue the conversion with the
- // original block argument type.
- Type legalOutputType = converter->convertType(origArg.getType());
- if (legalOutputType && legalOutputType != origArg.getType()) {
- newArg = buildUnresolvedTargetMaterialization(
- origArg.getLoc(), argMat, legalOutputType, converter);
- mapping.map(argMat, newArg);
- } else {
- newArg = argMat;
- }
+ Type legalOutputType;
+ if (converter)
+ legalOutputType = converter->convertType(origArgType);
+ if (legalOutputType && legalOutputType != origArgType) {
+ Value targetMat = buildUnresolvedTargetMaterialization(
+ origArg.getLoc(), argMat, legalOutputType, converter);
+ mapping.map(argMat, targetMat);
}
-
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
- argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
- appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
- converter);
+ appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
@@ -1437,13 +1377,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
-Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
- Block *block, Location loc, ValueRange inputs, Type outputType,
- const TypeConverter *converter) {
- return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
- block->begin(), loc, inputs, outputType,
- converter);
-}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
const TypeConverter *converter) {
@@ -2862,6 +2795,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), outputType, inputOperands);
break;
+ case MaterializationKind::Source:
+ newMaterialization = converter->materializeSourceConversion(
+ rewriter, op->getLoc(), outputType, inputOperands);
+ break;
}
if (newMaterialization) {
assert(newMaterialization.getType() == outputType &&
@@ -2874,8 +2811,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
InFlightDiagnostic diag = op->emitError()
<< "failed to legalize unresolved materialization "
- "from "
- << inputOperands.getTypes() << " to " << outputType
+ "from ("
+ << inputOperands.getTypes() << ") to " << outputType
<< " that remained live after conversion";
if (Operation *liveUser = findLiveUser(op->getUsers())) {
diag.attachNote(liveUser->getLoc())
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index b35cda8e724f6..8254be68912c8 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -2,9 +2,8 @@
func.func @test_invalid_arg_materialization(
- // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
+ // expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
%arg0: i16) {
- // expected-note@below {{see existing live user here}}
"foo.return"(%arg0) : (i16) -> ()
}
@@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() {
// Make sure argument type changes aren't implicitly forwarded.
func.func @test_signature_conversion_no_converter() {
"test.signature_conversion_no_converter"() ({
- // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
+ // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
^bb0(%arg0: f32):
- // expected-note@below {{see existing live user here}}
"test.type_consumer"(%arg0) : (f32) -> ()
"test.return"(%arg0) : (f32) -> ()
}) : () -> ()
More information about the Mlir-commits
mailing list