[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments (PR #96207)
Matthias Springer
llvmlistbot at llvm.org
Fri Jun 21 05:00:56 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/96207
From 8c0d89137f66f920fc542d7513b66da8d73d171d Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 20 Jun 2024 11:18:44 +0200
Subject: [PATCH] [mlir][Transforms] Dialect conversion: Simplify handling of
dropped arguments
This commit simplifies the handling of dropped arguments and updates outdated dialect conversion documentation.
When converting a block signature, a `BlockTypeConversionRewrite` object and potentially multiple `ReplaceBlockArgRewrite` objects are created. During the "commit" phase, uses of the old block arguments are replaced with the new block arguments, but the old implementation was inconsistent: some block arguments were replaced in `BlockTypeConversionRewrite::commit` and others in `ReplaceBlockArgRewrite::commit`. The new `BlockTypeConversionRewrite::commit` implementation is much simpler and no longer modifies any IR; all replacements are now done in `ReplaceBlockArgRewrite`. As a result, the `ConvertedArgInfo` data structure is no longer needed.
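The new division of labor during "commit" can be pictured with a small standalone model. This is a hedged sketch, not MLIR code: `ToyIR`, `ToyMapping`, `ToyReplaceBlockArgRewrite`, `ToyBlockTypeConversionRewrite` and `commitAll` are hypothetical names, values are plain strings, and the real `BlockTypeConversionRewrite::commit` still notifies the listener about users of the block, which the model leaves out.

```cpp
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR: each use is a (user op name, used value name) pair.
struct ToyIR {
  std::vector<std::pair<std::string, std::string>> uses;

  void replaceAllUsesWith(const std::string &from, const std::string &to) {
    for (auto &use : uses)
      if (use.second == from)
        use.second = to;
  }
};

// Stand-in for the conversion value mapping: original arg -> replacement.
using ToyMapping = std::map<std::string, std::string>;

// Models the role of ReplaceBlockArgRewrite: its commit() looks up the mapped
// replacement of one original block argument and rewires that argument's uses.
struct ToyReplaceBlockArgRewrite {
  std::string origArg;

  void commit(ToyIR &ir, const ToyMapping &mapping) const {
    auto it = mapping.find(origArg);
    if (it != mapping.end())
      ir.replaceAllUsesWith(origArg, it->second);
  }
};

// Models the simplified BlockTypeConversionRewrite: commit() no longer touches
// any use of the original block arguments (listener notification omitted).
struct ToyBlockTypeConversionRewrite {
  void commit(ToyIR &) const {}
};

// Commits the rewrites in creation order; only the per-argument rewrites
// modify uses.
inline void commitAll(ToyIR &ir, const ToyMapping &mapping,
                      const std::vector<ToyReplaceBlockArgRewrite> &argRewrites,
                      const ToyBlockTypeConversionRewrite &blockRewrite) {
  for (const ToyReplaceBlockArgRewrite &rewrite : argRewrites)
    rewrite.commit(ir, mapping);
  blockRewrite.commit(ir);
}
```

Because every original argument gets its own replace rewrite, the block-level rewrite no longer needs per-argument bookkeeping such as `ConvertedArgInfo`.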
To make this possible, materializations of dropped arguments are now built in `applySignatureConversion` instead of `materializeLiveConversions`, which no longer has to deal with dropped arguments.
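After this change, `applySignatureConversion` handles each original block argument in one of three ways. The sketch below only models that case analysis: `ToyInputMapping` is a hypothetical stand-in for `TypeConverter::SignatureConversion::InputMapping`, and `ToyReplacementKind` and `classifyOrigArg` are illustrative names, not part of the MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>

// Hypothetical stand-in for SignatureConversion::InputMapping: a slice
// [inputNo, inputNo + size) of the new block arguments, or a dropped argument
// (size == 0) together with a caller-provided replacement value.
struct ToyInputMapping {
  std::size_t inputNo = 0;
  std::size_t size = 0;
  std::optional<std::string> replacementValue;
};

// What gets mapped as the replacement of one original block argument.
enum class ToyReplacementKind {
  SourceMaterializationNoInputs, // dropped, no replacement: built "out of thin air"
  ProvidedReplacement,           // dropped, replacement value supplied by the caller
  ArgumentMaterialization        // 1->1+ mapping: new args cast back to one value
};

inline ToyReplacementKind
classifyOrigArg(const std::optional<ToyInputMapping> &inputMap) {
  // No input mapping at all: the argument was dropped without a replacement.
  if (!inputMap)
    return ToyReplacementKind::SourceMaterializationNoInputs;
  // Dropped argument with an explicit replacement value.
  if (inputMap->replacementValue) {
    assert(inputMap->size == 0 &&
           "replacement value is only valid for dropped arguments");
    return ToyReplacementKind::ProvidedReplacement;
  }
  // One original argument replaced by one or more new block arguments.
  return ToyReplacementKind::ArgumentMaterialization;
}
```

In the patch, the first and third cases go through `buildUnresolvedMaterialization`, which skips the cast when there is a single input whose type already equals the output type.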
Other minor improvements:
- Improve variable name: `origOutputType` -> `origArgType`. Add an assertion to check that this field is only used for argument materializations.
- Add more comments to `applySignatureConversion`.
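Adding the `Source` kind means `MaterializationKind` has three enumerators, so the integer field of the `llvm::PointerIntPair` that packs it next to the `TypeConverter` pointer must grow from 1 to 2 bits. The sketch below reproduces that arithmetic and the new `origArgType` invariant in plain C++; `ToyMaterializationKind`, `bitsNeeded` and `isValidOrigArgType` are hypothetical, and the enumerator order (`Argument` first) is an assumption, although any order of three values needs 2 bits.

```cpp
// Hypothetical mirror of MaterializationKind after this patch.
enum class ToyMaterializationKind : unsigned { Argument = 0, Target = 1, Source = 2 };

// Number of bits needed to store every value in [0, maxValue].
constexpr unsigned bitsNeeded(unsigned maxValue) {
  return maxValue <= 1 ? 1u : 1u + bitsNeeded(maxValue >> 1);
}

// Two kinds fit in 1 bit; the third kind forces a 2-bit integer field.
static_assert(bitsNeeded(static_cast<unsigned>(ToyMaterializationKind::Target)) == 1,
              "Argument/Target alone fit in one bit");
static_assert(bitsNeeded(static_cast<unsigned>(ToyMaterializationKind::Source)) == 2,
              "adding Source requires two bits");

// Models the new constructor invariant: an original argument type may only be
// attached to argument materializations. `hasOrigArgType` stands in for a
// non-null mlir::Type.
inline bool isValidOrigArgType(ToyMaterializationKind kind, bool hasOrigArgType) {
  return kind == ToyMaterializationKind::Argument || !hasOrigArgType;
}
```

The pointer half of the pair must also leave enough low bits free; `llvm::PointerIntPair` checks this at compile time.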
Note: Error messages around failed materializations for dropped basic block arguments changed slightly. That is because those materializations are now built in `legalizeUnresolvedMaterialization` instead of `legalizeConvertedArgumentTypes`.
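The updated diagnostic wraps the input types in parentheses, so a materialization with no inputs (a dropped argument) prints as `from ()`. The helper below is a hypothetical sketch that only reproduces the message text expected by the updated tests; the `", "` separator for several input types is an assumption, since the tests only exercise zero or one input type.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper mimicking the message emitted by
// legalizeUnresolvedMaterialization after this patch.
inline std::string materializationError(const std::vector<std::string> &inputTypes,
                                        const std::string &outputType) {
  std::string msg = "failed to legalize unresolved materialization from (";
  for (std::size_t i = 0; i < inputTypes.size(); ++i) {
    if (i != 0)
      msg += ", "; // assumed separator for multiple input types
    msg += "'" + inputTypes[i] + "'";
  }
  msg += ") to '" + outputType + "' that remained live after conversion";
  return msg;
}
```

For example, `materializationError({}, "i16")` yields the text checked in `test_invalid_arg_materialization`, and `materializationError({"f64"}, "f32")` yields the one checked in `test_signature_conversion_no_converter`.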
---
mlir/docs/DialectConversion.md | 37 +++-
.../mlir/Transforms/DialectConversion.h | 10 +-
.../Transforms/Utils/DialectConversion.cpp | 208 +++++++-----------
.../test-legalize-type-conversion.mlir | 6 +-
4 files changed, 111 insertions(+), 150 deletions(-)
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 69781bb868bbf..f722974a9a1e5 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -246,6 +246,13 @@ depending on the situation.
- An argument materialization is used when converting the type of a block
argument during a [signature conversion](#region-signature-conversion).
+ The new block argument types are specified in a `SignatureConversion`
+ object. An original block argument can be converted into multiple
+ block arguments, which is not supported everywhere in the dialect
+ conversion. (E.g., adaptors support only a single replacement value for
+ each original value.) Therefore, an argument materialization is used to
+ convert potentially multiple new block arguments back into a single SSA
+ value.
* Source Materialization
@@ -259,6 +266,9 @@ depending on the situation.
* When a block argument has been converted to a different type, but
the original argument still has users that will remain live after
the conversion process has finished.
+ * When a block argument has been dropped, but the argument still has
+ users that will remain live after the conversion process has
+ finished.
* When the result type of an operation has been converted to a
different type, but the original result still has users that will
remain live after the conversion process is finished.
@@ -330,17 +340,19 @@ class TypeConverter {
/// Register a materialization function, which must be convertible to the
/// following form:
- /// `Optional<Value> (OpBuilder &, T, ValueRange, Location)`,
- /// where `T` is any subclass of `Type`.
- /// This function is responsible for creating an operation, using the
- /// OpBuilder and Location provided, that "converts" a range of values into a
- /// single value of the given type `T`. It must return a Value of the
- /// converted type on success, an `std::nullopt` if it failed but other
- /// materialization can be attempted, and `nullptr` on unrecoverable failure.
- /// It will only be called for (sub)types of `T`.
+ /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
+ /// where `T` is any subclass of `Type`. This function is responsible for
+ /// creating an operation, using the OpBuilder and Location provided, that
+ /// "casts" a range of values into a single value of the given type `T`. It
+ /// must return a Value of the converted type on success, `std::nullopt` if
+ /// it failed but other materializations can be attempted, and `nullptr` on
+ /// unrecoverable failure. It will only be called for (sub)types of `T`.
+ /// Materialization functions must be provided when a type conversion may
+ /// persist after the conversion has finished.
///
/// This method registers a materialization that will be called when
- /// converting an illegal block argument type, to a legal type.
+ /// converting potentially multiple replacement block arguments (of a single
+ /// original block argument), to a single SSA value with a legal type.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
@@ -348,8 +360,9 @@ class TypeConverter {
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
- /// converting a legal type to an illegal source type. This is used when
- /// conversions to an illegal type must persist beyond the main conversion.
+ /// converting a legal replacement value back to an illegal source type.
+ /// This is used when some uses of the original, illegal value must persist
+ /// beyond the main conversion.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addSourceMaterialization(FnT &&callback) {
@@ -357,7 +370,7 @@ class TypeConverter {
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
- /// converting type from an illegal, or source, type to a legal type.
+ /// converting an illegal (source) value to a legal (target) type.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f83f3a3fdf992..87b5dd9a6f340 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -181,7 +181,8 @@ class TypeConverter {
/// persist after the conversion has finished.
///
/// This method registers a materialization that will be called when
- /// converting an illegal block argument type, to a legal type.
+ /// converting potentially multiple replacement block arguments (of a single
+ /// original block argument), to a single SSA value with a legal type.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
@@ -189,8 +190,9 @@ class TypeConverter {
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
- /// converting a legal type to an illegal source type. This is used when
- /// conversions to an illegal type must persist beyond the main conversion.
+ /// converting a legal replacement value back to an illegal source type.
+ /// This is used when some uses of the original, illegal value must persist
+ /// beyond the main conversion.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addSourceMaterialization(FnT &&callback) {
@@ -198,7 +200,7 @@ class TypeConverter {
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
- /// converting type from an illegal, or source, type to a legal type.
+ /// converting an illegal (source) value to a legal (target) type.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e6c0ee2ab2949..07ebd687ee2b3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
Block *insertBeforeBlock;
};
-/// This structure contains the information pertaining to an argument that has
-/// been converted.
-struct ConvertedArgInfo {
- ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
- Value castValue = nullptr)
- : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
-
- /// The start index of in the new argument list that contains arguments that
- /// replace the original.
- unsigned newArgIdx;
-
- /// The number of arguments that replaced the original argument.
- unsigned newArgSize;
-
- /// The cast value that was created to cast from the new arguments to the
- /// old. This only used if 'newArgSize' > 1.
- Value castValue;
-};
-
/// Block type conversion. This rewrite is partially reflected in the IR.
class BlockTypeConversionRewrite : public BlockRewrite {
public:
- BlockTypeConversionRewrite(
- ConversionPatternRewriterImpl &rewriterImpl, Block *block,
- Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
- const TypeConverter *converter)
+ BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ Block *block, Block *origBlock,
+ const TypeConverter *converter)
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
- origBlock(origBlock), argInfo(argInfo), converter(converter) {}
+ origBlock(origBlock), converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
/// The original block that was requested to have its signature converted.
Block *origBlock;
- /// The conversion information for each of the arguments. The information is
- /// std::nullopt if the argument was dropped during conversion.
- SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
-
/// The type converter used to convert the arguments.
const TypeConverter *converter;
};
@@ -696,7 +672,11 @@ enum MaterializationKind {
/// This materialization materializes a conversion from an illegal type to a
/// legal one.
- Target
+ Target,
+
+ /// This materialization materializes a conversion from a legal type back to
+ /// an illegal one.
+ Source
};
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -708,9 +688,13 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
MaterializationKind kind = MaterializationKind::Target,
- Type origOutputType = nullptr)
+ Type origArgType = nullptr)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind), origOutputType(origOutputType) {}
+ converterAndKind(converter, kind), origArgType(origArgType) {
+ assert((kind == MaterializationKind::Argument || !origArgType) &&
+        "original argument type makes sense only for argument "
+        "materializations");
+ }
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -734,17 +718,17 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
return converterAndKind.getInt();
}
- /// Return the original illegal output type of the input values.
- Type getOrigOutputType() const { return origOutputType; }
+ /// Return the original type of the block argument.
+ Type getOrigArgType() const { return origArgType; }
private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
- llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
+ llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
converterAndKind;
/// The original output type. This is only used for argument conversions.
- Type origOutputType;
+ Type origArgType;
};
} // namespace
@@ -862,13 +846,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ValueRange inputs, Type outputType,
Type origOutputType,
const TypeConverter *converter);
-
- Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
- ValueRange inputs,
- Type origOutputType,
- Type outputType,
- const TypeConverter *converter);
-
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
Type outputType,
const TypeConverter *converter);
@@ -998,28 +975,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
for (Operation *op : block->getUsers())
listener->notifyOperationModified(op);
-
- // Process the remapping for each of the original arguments.
- for (auto [origArg, info] :
- llvm::zip_equal(origBlock->getArguments(), argInfo)) {
- // Handle the case of a 1->0 value mapping.
- if (!info) {
- if (Value newArg =
- rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
- rewriter.replaceAllUsesWith(origArg, newArg);
- continue;
- }
-
- // Otherwise this is a 1->1+ value mapping.
- Value castValue = info->castValue;
- assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
-
- // If the argument is still used, replace it with the generated cast.
- if (!origArg.use_empty()) {
- rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
- castValue, origArg.getType()));
- }
- }
}
void BlockTypeConversionRewrite::rollback() {
@@ -1043,15 +998,13 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
if (!liveUser)
continue;
- Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
- bool isDroppedArg = replacementValue == origArg;
- if (!isDroppedArg)
- builder.setInsertionPointAfterValue(replacementValue);
+ Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
+ assert(replacementValue && "replacement value not found");
Value newArg;
if (converter) {
+ builder.setInsertionPointAfterValue(replacementValue);
newArg = converter->materializeSourceConversion(
- builder, origArg.getLoc(), origArg.getType(),
- isDroppedArg ? ValueRange() : ValueRange(replacementValue));
+ builder, origArg.getLoc(), origArg.getType(), replacementValue);
assert((!newArg || newArg.getType() == origArg.getType()) &&
"materialization hook did not provide a value of the expected "
"type");
@@ -1062,8 +1015,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
<< "failed to materialize conversion for block argument #"
<< it.index() << " that remained live after conversion, type was "
<< origArg.getType();
- if (!isDroppedArg)
- diag << ", with target type " << replacementValue.getType();
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
@@ -1349,65 +1300,65 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// Replace all uses of the old block with the new block.
block->replaceAllUsesWith(newBlock);
- // Remap each of the original arguments as determined by the signature
- // conversion.
- SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
- argInfo.resize(origArgCount);
-
for (unsigned i = 0; i != origArgCount; ++i) {
- auto inputMap = signatureConversion.getInputMapping(i);
- if (!inputMap)
- continue;
BlockArgument origArg = block->getArgument(i);
+ Type origArgType = origArg.getType();
- // If inputMap->replacementValue is not nullptr, then the argument is
- // dropped and a replacement value is provided to be the remappedValue.
- if (inputMap->replacementValue) {
- assert(inputMap->size == 0 &&
- "invalid to provide a replacement value when the argument isn't "
- "dropped");
- mapping.map(origArg, inputMap->replacementValue);
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
- continue;
- }
-
- // Otherwise, this is a 1->1+ mapping.
- auto replArgs =
- newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- Value newArg;
-
- // If this is a 1->1 mapping and the types of new and replacement arguments
- // match (i.e. it's an identity map), then the argument is mapped to its
- // original type.
+ // Helper function that tries to legalize the given type. Returns the given
+ // type if it could not be legalized.
// FIXME: We simply pass through the replacement argument if there wasn't a
// converter, which isn't great as it allows implicit type conversions to
// appear. We should properly restructure this code to handle cases where a
// converter isn't provided and also to properly handle the case where an
// argument materialization is actually a temporary source materialization
// (e.g. in the case of 1->N).
- if (replArgs.size() == 1 &&
- (!converter || replArgs[0].getType() == origArg.getType())) {
- newArg = replArgs.front();
- } else {
- Type origOutputType = origArg.getType();
+ auto tryLegalizeType = [&](Type type) {
+ if (converter)
+ if (Type t = converter->convertType(type))
+ return t;
+ return type;
+ };
- // Legalize the argument output type.
- Type outputType = origOutputType;
- if (Type legalOutputType = converter->convertType(outputType))
- outputType = legalOutputType;
+ std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
+ signatureConversion.getInputMapping(i);
+ if (!inputMap) {
+ // This block argument was dropped and no replacement value was provided.
+ // Materialize a replacement value "out of thin air".
+ Value repl = buildUnresolvedMaterialization(
+ MaterializationKind::Source, newBlock, newBlock->begin(),
+ origArg.getLoc(), /*inputs=*/ValueRange(),
+ /*outputType=*/origArgType, /*origArgType=*/{}, converter);
+ mapping.map(origArg, repl);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+ continue;
+ }
- newArg = buildUnresolvedArgumentMaterialization(
- newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
- converter);
+ if (Value repl = inputMap->replacementValue) {
+ // This block argument was dropped and a replacement value was provided.
+ assert(inputMap->size == 0 &&
+ "invalid to provide a replacement value when the argument isn't "
+ "dropped");
+ mapping.map(origArg, repl);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+ continue;
}
- mapping.map(origArg, newArg);
+ // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
+ // dialect conversion. Therefore, we need an argument materialization to
+ // turn the replacement block arguments into a single SSA value that can be
+ // used as a replacement. The type of this SSA value is the legalized
+ // version of the original block argument type.
+ auto replArgs =
+ newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
+ Value repl = buildUnresolvedMaterialization(
+ MaterializationKind::Argument, newBlock, newBlock->begin(),
+ origArg.getLoc(), /*inputs=*/replArgs,
+ /*outputType=*/tryLegalizeType(origArgType), origArgType, converter);
+ mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
- argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
- appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
- converter);
+ appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
@@ -1424,7 +1375,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
- Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+ Location loc, ValueRange inputs, Type outputType, Type origArgType,
const TypeConverter *converter) {
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1436,16 +1387,9 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
- origOutputType);
+ origArgType);
return convertOp.getResult(0);
}
-Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
- Block *block, Location loc, ValueRange inputs, Type origOutputType,
- Type outputType, const TypeConverter *converter) {
- return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
- block->begin(), loc, inputs, outputType,
- origOutputType, converter);
-}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
const TypeConverter *converter) {
@@ -1454,9 +1398,9 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
if (OpResult inputRes = dyn_cast<OpResult>(input))
insertPt = ++inputRes.getOwner()->getIterator();
- return buildUnresolvedMaterialization(MaterializationKind::Target,
- insertBlock, insertPt, loc, input,
- outputType, outputType, converter);
+ return buildUnresolvedMaterialization(
+ MaterializationKind::Target, insertBlock, insertPt, loc, input,
+ outputType, /*origArgType=*/{}, converter);
}
//===----------------------------------------------------------------------===//
@@ -2852,7 +2796,7 @@ static LogicalResult legalizeUnresolvedMaterialization(
// easily misunderstood. We should clean up the argument hooks to better
// represent the desired invariants we actually care about.
newMaterialization = converter->materializeArgumentConversion(
- rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
+ rewriter, op->getLoc(), mat.getOrigArgType(), inputOperands);
if (newMaterialization)
break;
@@ -2863,6 +2807,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), outputType, inputOperands);
break;
+ case MaterializationKind::Source:
+ newMaterialization = converter->materializeSourceConversion(
+ rewriter, op->getLoc(), outputType, inputOperands);
+ break;
}
if (newMaterialization) {
replaceMaterialization(rewriterImpl, opResult, newMaterialization,
@@ -2873,8 +2821,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
InFlightDiagnostic diag = op->emitError()
<< "failed to legalize unresolved materialization "
- "from "
- << inputOperands.getTypes() << " to " << outputType
+ "from ("
+ << inputOperands.getTypes() << ") to " << outputType
<< " that remained live after conversion";
if (Operation *liveUser = findLiveUser(op->getUsers())) {
diag.attachNote(liveUser->getLoc())
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index b35cda8e724f6..8254be68912c8 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -2,9 +2,8 @@
func.func @test_invalid_arg_materialization(
- // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
+ // expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
%arg0: i16) {
- // expected-note@below {{see existing live user here}}
"foo.return"(%arg0) : (i16) -> ()
}
@@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() {
// Make sure argument type changes aren't implicitly forwarded.
func.func @test_signature_conversion_no_converter() {
"test.signature_conversion_no_converter"() ({
- // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
+ // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
^bb0(%arg0: f32):
- // expected-note@below {{see existing live user here}}
"test.type_consumer"(%arg0) : (f32) -> ()
"test.return"(%arg0) : (f32) -> ()
}) : () -> ()