[Mlir-commits] [mlir] 4d46b46 - Revert "[mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments (#96207)"
Benjamin Kramer
llvmlistbot at llvm.org
Thu Jun 27 00:27:24 PDT 2024
Author: Benjamin Kramer
Date: 2024-06-27T09:16:40+02:00
New Revision: 4d46b460f9fe00c33545d9b0b320194d5e4b49b5
URL: https://github.com/llvm/llvm-project/commit/4d46b460f9fe00c33545d9b0b320194d5e4b49b5
DIFF: https://github.com/llvm/llvm-project/commit/4d46b460f9fe00c33545d9b0b320194d5e4b49b5.diff
LOG: Revert "[mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments (#96207)"
This reverts commit f1e0657d144f5a3cfef4b625d0f875f4dacd21d1.
It breaks SCF conversion, see test case on the PR.
Added:
Modified:
mlir/docs/DialectConversion.md
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/Transforms/test-legalize-type-conversion.mlir
Removed:
################################################################################
diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md
index 3002ac004490e..69781bb868bbf 100644
--- a/mlir/docs/DialectConversion.md
+++ b/mlir/docs/DialectConversion.md
@@ -246,13 +246,6 @@ depending on the situation.
- An argument materialization is used when converting the type of a block
argument during a [signature conversion](#region-signature-conversion).
- The new block argument types are specified in a `SignatureConversion`
- object. An original block argument can be converted into multiple
- block arguments, which is not supported everywhere in the dialect
- conversion. (E.g., adaptors support only a single replacement value for
- each original value.) Therefore, an argument materialization is used to
- convert potentially multiple new block arguments back into a single SSA
- value.
* Source Materialization
@@ -266,9 +259,6 @@ depending on the situation.
* When a block argument has been converted to a different type, but
the original argument still has users that will remain live after
the conversion process has finished.
- * When a block argument has been dropped, but the argument still has
- users that will remain live after the conversion process has
- finished.
* When the result type of an operation has been converted to a
different type, but the original result still has users that will
remain live after the conversion process is finished.
@@ -338,22 +328,19 @@ class TypeConverter {
registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
}
- /// All of the following materializations require function objects that are
- /// convertible to the following form:
- /// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
- /// where `T` is any subclass of `Type`. This function is responsible for
- /// creating an operation, using the OpBuilder and Location provided, that
- /// "casts" a range of values into a single value of the given type `T`. It
- /// must return a Value of the converted type on success, an `std::nullopt` if
- /// it failed but other materialization can be attempted, and `nullptr` on
- /// unrecoverable failure. It will only be called for (sub)types of `T`.
- /// Materialization functions must be provided when a type conversion may
- /// persist after the conversion has finished.
-
+ /// Register a materialization function, which must be convertible to the
+ /// following form:
+ /// `Optional<Value> (OpBuilder &, T, ValueRange, Location)`,
+ /// where `T` is any subclass of `Type`.
+ /// This function is responsible for creating an operation, using the
+ /// OpBuilder and Location provided, that "converts" a range of values into a
+ /// single value of the given type `T`. It must return a Value of the
+ /// converted type on success, an `std::nullopt` if it failed but other
+ /// materialization can be attempted, and `nullptr` on unrecoverable failure.
+ /// It will only be called for (sub)types of `T`.
+ ///
/// This method registers a materialization that will be called when
- /// converting (potentially multiple) block arguments that were the result of
- /// a signature conversion of a single block argument, to a single SSA value
- /// of a legal type.
+ /// converting an illegal block argument type, to a legal type.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
@@ -361,9 +348,8 @@ class TypeConverter {
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
- /// converting a legal replacement value back to an illegal source type.
- /// This is used when some uses of the original, illegal value must persist
- /// beyond the main conversion.
+ /// converting a legal type to an illegal source type. This is used when
+ /// conversions to an illegal type must persist beyond the main conversion.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addSourceMaterialization(FnT &&callback) {
@@ -371,7 +357,7 @@ class TypeConverter {
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
- /// converting an illegal (source) value to a legal (target) type.
+ /// converting type from an illegal, or source, type to a legal type.
template <typename FnT,
typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index e17f6f682c794..f83f3a3fdf992 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -168,8 +168,8 @@ class TypeConverter {
registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
}
- /// All of the following materializations require function objects that are
- /// convertible to the following form:
+ /// Register a materialization function, which must be convertible to the
+ /// following form:
/// `std::optional<Value>(OpBuilder &, T, ValueRange, Location)`,
/// where `T` is any subclass of `Type`. This function is responsible for
/// creating an operation, using the OpBuilder and Location provided, that
@@ -179,11 +179,9 @@ class TypeConverter {
/// unrecoverable failure. It will only be called for (sub)types of `T`.
/// Materialization functions must be provided when a type conversion may
/// persist after the conversion has finished.
-
+ ///
/// This method registers a materialization that will be called when
- /// converting (potentially multiple) block arguments that were the result of
- /// a signature conversion of a single block argument, to a single SSA value
- /// of a legal type.
+ /// converting an illegal block argument type, to a legal type.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addArgumentMaterialization(FnT &&callback) {
@@ -191,9 +189,8 @@ class TypeConverter {
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
- /// converting a legal replacement value back to an illegal source type.
- /// This is used when some uses of the original, illegal value must persist
- /// beyond the main conversion.
+ /// converting a legal type to an illegal source type. This is used when
+ /// conversions to an illegal type must persist beyond the main conversion.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addSourceMaterialization(FnT &&callback) {
@@ -201,7 +198,7 @@ class TypeConverter {
wrapMaterialization<T>(std::forward<FnT>(callback)));
}
/// This method registers a materialization that will be called when
- /// converting an illegal (source) value to a legal (target) type.
+ /// converting type from an illegal, or source, type to a legal type.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 07ebd687ee2b3..e6c0ee2ab2949 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -432,14 +432,34 @@ class MoveBlockRewrite : public BlockRewrite {
Block *insertBeforeBlock;
};
+/// This structure contains the information pertaining to an argument that has
+/// been converted.
+struct ConvertedArgInfo {
+ ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
+ Value castValue = nullptr)
+ : newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}
+
+ /// The start index of in the new argument list that contains arguments that
+ /// replace the original.
+ unsigned newArgIdx;
+
+ /// The number of arguments that replaced the original argument.
+ unsigned newArgSize;
+
+ /// The cast value that was created to cast from the new arguments to the
+ /// old. This only used if 'newArgSize' > 1.
+ Value castValue;
+};
+
/// Block type conversion. This rewrite is partially reflected in the IR.
class BlockTypeConversionRewrite : public BlockRewrite {
public:
- BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
- Block *block, Block *origBlock,
- const TypeConverter *converter)
+ BlockTypeConversionRewrite(
+ ConversionPatternRewriterImpl &rewriterImpl, Block *block,
+ Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
+ const TypeConverter *converter)
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
- origBlock(origBlock), converter(converter) {}
+ origBlock(origBlock), argInfo(argInfo), converter(converter) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
@@ -459,6 +479,10 @@ class BlockTypeConversionRewrite : public BlockRewrite {
/// The original block that was requested to have its signature converted.
Block *origBlock;
+ /// The conversion information for each of the arguments. The information is
+ /// std::nullopt if the argument was dropped during conversion.
+ SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+
/// The type converter used to convert the arguments.
const TypeConverter *converter;
};
@@ -672,11 +696,7 @@ enum MaterializationKind {
/// This materialization materializes a conversion from an illegal type to a
/// legal one.
- Target,
-
- /// This materialization materializes a conversion from a legal type back to
- /// an illegal one.
- Source
+ Target
};
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
@@ -688,13 +708,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
MaterializationKind kind = MaterializationKind::Target,
- Type origArgType = nullptr)
+ Type origOutputType = nullptr)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind), origArgType(origArgType) {
- assert(kind == MaterializationKind::Argument ||
- !origArgType && "orginal argument type make sense only for argument "
- "materializations");
- }
+ converterAndKind(converter, kind), origOutputType(origOutputType) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -718,17 +734,17 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
return converterAndKind.getInt();
}
- /// Return the original type of the block argument.
- Type getOrigArgType() const { return origArgType; }
+ /// Return the original illegal output type of the input values.
+ Type getOrigOutputType() const { return origOutputType; }
private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
- llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
+ llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
converterAndKind;
/// The original output type. This is only used for argument conversions.
- Type origArgType;
+ Type origOutputType;
};
} // namespace
@@ -846,6 +862,13 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ValueRange inputs, Type outputType,
Type origOutputType,
const TypeConverter *converter);
+
+ Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
+ ValueRange inputs,
+ Type origOutputType,
+ Type outputType,
+ const TypeConverter *converter);
+
Value buildUnresolvedTargetMaterialization(Location loc, Value input,
Type outputType,
const TypeConverter *converter);
@@ -975,6 +998,28 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
for (Operation *op : block->getUsers())
listener->notifyOperationModified(op);
+
+ // Process the remapping for each of the original arguments.
+ for (auto [origArg, info] :
+ llvm::zip_equal(origBlock->getArguments(), argInfo)) {
+ // Handle the case of a 1->0 value mapping.
+ if (!info) {
+ if (Value newArg =
+ rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+ rewriter.replaceAllUsesWith(origArg, newArg);
+ continue;
+ }
+
+ // Otherwise this is a 1->1+ value mapping.
+ Value castValue = info->castValue;
+ assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");
+
+ // If the argument is still used, replace it with the generated cast.
+ if (!origArg.use_empty()) {
+ rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
+ castValue, origArg.getType()));
+ }
+ }
}
void BlockTypeConversionRewrite::rollback() {
@@ -998,13 +1043,15 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
if (!liveUser)
continue;
- Value replacementValue = rewriterImpl.mapping.lookupOrNull(origArg);
- assert(replacementValue && "replacement value not found");
+ Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
+ bool isDroppedArg = replacementValue == origArg;
+ if (!isDroppedArg)
+ builder.setInsertionPointAfterValue(replacementValue);
Value newArg;
if (converter) {
- builder.setInsertionPointAfterValue(replacementValue);
newArg = converter->materializeSourceConversion(
- builder, origArg.getLoc(), origArg.getType(), replacementValue);
+ builder, origArg.getLoc(), origArg.getType(),
+ isDroppedArg ? ValueRange() : ValueRange(replacementValue));
assert((!newArg || newArg.getType() == origArg.getType()) &&
"materialization hook did not provide a value of the expected "
"type");
@@ -1015,6 +1062,8 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
<< "failed to materialize conversion for block argument #"
<< it.index() << " that remained live after conversion, type was "
<< origArg.getType();
+ if (!isDroppedArg)
+ diag << ", with target type " << replacementValue.getType();
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
@@ -1300,65 +1349,65 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// Replace all uses of the old block with the new block.
block->replaceAllUsesWith(newBlock);
+ // Remap each of the original arguments as determined by the signature
+ // conversion.
+ SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
+ argInfo.resize(origArgCount);
+
for (unsigned i = 0; i != origArgCount; ++i) {
+ auto inputMap = signatureConversion.getInputMapping(i);
+ if (!inputMap)
+ continue;
BlockArgument origArg = block->getArgument(i);
- Type origArgType = origArg.getType();
- // Helper function that tries to legalize the given type. Returns the given
- // type if it could not be legalized.
+ // If inputMap->replacementValue is not nullptr, then the argument is
+ // dropped and a replacement value is provided to be the remappedValue.
+ if (inputMap->replacementValue) {
+ assert(inputMap->size == 0 &&
+ "invalid to provide a replacement value when the argument isn't "
+ "dropped");
+ mapping.map(origArg, inputMap->replacementValue);
+ appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+ continue;
+ }
+
+ // Otherwise, this is a 1->1+ mapping.
+ auto replArgs =
+ newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
+ Value newArg;
+
+ // If this is a 1->1 mapping and the types of new and replacement arguments
+ // match (i.e. it's an identity map), then the argument is mapped to its
+ // original type.
// FIXME: We simply pass through the replacement argument if there wasn't a
// converter, which isn't great as it allows implicit type conversions to
// appear. We should properly restructure this code to handle cases where a
// converter isn't provided and also to properly handle the case where an
// argument materialization is actually a temporary source materialization
// (e.g. in the case of 1->N).
- auto tryLegalizeType = [&](Type type) {
- if (converter)
- if (Type t = converter->convertType(type))
- return t;
- return type;
- };
+ if (replArgs.size() == 1 &&
+ (!converter || replArgs[0].getType() == origArg.getType())) {
+ newArg = replArgs.front();
+ } else {
+ Type origOutputType = origArg.getType();
- std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
- signatureConversion.getInputMapping(i);
- if (!inputMap) {
- // This block argument was dropped and no replacement value was provided.
- // Materialize a replacement value "out of thin air".
- Value repl = buildUnresolvedMaterialization(
- MaterializationKind::Source, newBlock, newBlock->begin(),
- origArg.getLoc(), /*inputs=*/ValueRange(),
- /*outputType=*/origArgType, /*origArgType=*/{}, converter);
- mapping.map(origArg, repl);
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
- continue;
- }
+ // Legalize the argument output type.
+ Type outputType = origOutputType;
+ if (Type legalOutputType = converter->convertType(outputType))
+ outputType = legalOutputType;
- if (Value repl = inputMap->replacementValue) {
- // This block argument was dropped and a replacement value was provided.
- assert(inputMap->size == 0 &&
- "invalid to provide a replacement value when the argument isn't "
- "dropped");
- mapping.map(origArg, repl);
- appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
- continue;
+ newArg = buildUnresolvedArgumentMaterialization(
+ newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
+ converter);
}
- // This is a 1->1+ mapping. 1->N mappings are not fully supported in the
- // dialect conversion. Therefore, we need an argument materialization to
- // turn the replacement block arguments into a single SSA value that can be
- // used as a replacement. The type of this SSA value is the legalized
- // version of the original block argument type.
- auto replArgs =
- newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
- Value repl = buildUnresolvedMaterialization(
- MaterializationKind::Argument, newBlock, newBlock->begin(),
- origArg.getLoc(), /*inputs=*/replArgs,
- /*outputType=*/tryLegalizeType(origArgType), origArgType, converter);
- mapping.map(origArg, repl);
+ mapping.map(origArg, newArg);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
+ argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}
- appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);
+ appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
+ converter);
// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
@@ -1375,7 +1424,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
- Location loc, ValueRange inputs, Type outputType, Type origArgType,
+ Location loc, ValueRange inputs, Type outputType, Type origOutputType,
const TypeConverter *converter) {
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1387,9 +1436,16 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
- origArgType);
+ origOutputType);
return convertOp.getResult(0);
}
+Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
+ Block *block, Location loc, ValueRange inputs, Type origOutputType,
+ Type outputType, const TypeConverter *converter) {
+ return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
+ block->begin(), loc, inputs, outputType,
+ origOutputType, converter);
+}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
const TypeConverter *converter) {
@@ -1398,9 +1454,9 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
if (OpResult inputRes = dyn_cast<OpResult>(input))
insertPt = ++inputRes.getOwner()->getIterator();
- return buildUnresolvedMaterialization(
- MaterializationKind::Target, insertBlock, insertPt, loc, input,
- outputType, /*origArgType=*/{}, converter);
+ return buildUnresolvedMaterialization(MaterializationKind::Target,
+ insertBlock, insertPt, loc, input,
+ outputType, outputType, converter);
}
//===----------------------------------------------------------------------===//
@@ -2796,7 +2852,7 @@ static LogicalResult legalizeUnresolvedMaterialization(
// easily misunderstood. We should clean up the argument hooks to better
// represent the desired invariants we actually care about.
newMaterialization = converter->materializeArgumentConversion(
- rewriter, op->getLoc(), mat.getOrigArgType(), inputOperands);
+ rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
if (newMaterialization)
break;
@@ -2807,10 +2863,6 @@ static LogicalResult legalizeUnresolvedMaterialization(
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), outputType, inputOperands);
break;
- case MaterializationKind::Source:
- newMaterialization = converter->materializeSourceConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
- break;
}
if (newMaterialization) {
replaceMaterialization(rewriterImpl, opResult, newMaterialization,
@@ -2821,8 +2873,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
InFlightDiagnostic diag = op->emitError()
<< "failed to legalize unresolved materialization "
- "from ("
- << inputOperands.getTypes() << ") to " << outputType
+ "from "
+ << inputOperands.getTypes() << " to " << outputType
<< " that remained live after conversion";
if (Operation *liveUser = findLiveUser(op->getUsers())) {
diag.attachNote(liveUser->getLoc())
diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir
index 8254be68912c8..b35cda8e724f6 100644
--- a/mlir/test/Transforms/test-legalize-type-conversion.mlir
+++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir
@@ -2,8 +2,9 @@
func.func @test_invalid_arg_materialization(
- // expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
+ // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
%arg0: i16) {
+ // expected-note@below {{see existing live user here}}
"foo.return"(%arg0) : (i16) -> ()
}
@@ -103,8 +104,9 @@ func.func @test_block_argument_not_converted() {
// Make sure argument type changes aren't implicitly forwarded.
func.func @test_signature_conversion_no_converter() {
"test.signature_conversion_no_converter"() ({
- // expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
+ // expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
^bb0(%arg0: f32):
+ // expected-note@below {{see existing live user here}}
"test.type_consumer"(%arg0) : (f32) -> ()
"test.return"(%arg0) : (f32) -> ()
}) : () -> ()
More information about the Mlir-commits
mailing list