[Mlir-commits] [mlir] 0d906a4 - [mlir][Transforms] Dialect conversion: add `originalType` param to materializations (#112128)
llvmlistbot at llvm.org
Mon Oct 14 23:52:35 PDT 2024
Author: Matthias Springer
Date: 2024-10-15T08:52:32+02:00
New Revision: 0d906a425444e0205be8d19e585abe7caa808ba0
URL: https://github.com/llvm/llvm-project/commit/0d906a425444e0205be8d19e585abe7caa808ba0
DIFF: https://github.com/llvm/llvm-project/commit/0d906a425444e0205be8d19e585abe7caa808ba0.diff
LOG: [mlir][Transforms] Dialect conversion: add `originalType` param to materializations (#112128)
This commit adds an optional `originalType` parameter to target
materialization functions. Without this parameter, target
materializations are underspecified.
Note: `originalType` is only needed for target materializations.
Source/argument materializations do not have it.
Consider the following example: Let's assume that a conversion pattern
"P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then
a different conversion pattern "P2" matches an op that has "v1" as an
operand. Let's furthermore assume that "P2" determines that the
legalized type of "t1" is "t3", which may be different from "t2". In
this example, the target materialization callback will be invoked with:
outputType = "t3", inputs = "v2", originalType = "t1". Note that the
original type "t1" cannot be recovered from just "t3" and "v2"; that's
why the `originalType` parameter is added.
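To make the scenario concrete, here is a minimal standalone C++ model of it. It is a toy sketch, not MLIR code: `ToyType`, `ToyValue`, `TargetMatFn` and `remapOperand` are hypothetical stand-ins for `Type`, `Value`, the target materialization callback and the driver's operand remapping, and the builder and location arguments are omitted. It shows that only the caller still sees the unreplaced operand "v1", so the caller is the only place that can pass "t1" along.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <optional>
#include <string>

// Hypothetical toy stand-ins for mlir::Type and mlir::Value.
using ToyType = std::string;
struct ToyValue {
  std::string name;
  ToyType type;
};

// Hypothetical analogue of a target materialization callback:
// (outputType, input, originalType) -> value of outputType, or nullopt.
using TargetMatFn = std::function<std::optional<ToyValue>(
    const ToyType &, const ToyValue &, const ToyType &)>;

// Hypothetical analogue of the driver remapping an operand. `mapping`
// records replacements made by earlier patterns (e.g. P1: v1 -> v2);
// `desiredType` is the legalized type requested by the current pattern (P2).
std::optional<ToyValue>
remapOperand(const std::map<std::string, ToyValue> &mapping,
             const ToyValue &operand, const ToyType &desiredType,
             const TargetMatFn &materialize) {
  assert(materialize && "a target materialization callback is required");
  auto it = mapping.find(operand.name);
  const ToyValue &current = it == mapping.end() ? operand : it->second;
  // Like the real driver, skip the cast if the types already match.
  if (current.type == desiredType)
    return current;
  // Only here is the unreplaced operand (and thus its type "t1") visible.
  return materialize(desiredType, current, /*originalType=*/operand.type);
}
```

Calling `remapOperand` with the mapping `v1 -> v2`, an operand `v1` of type "t1" and a desired type "t3" invokes the callback with `("t3", v2, "t1")`. Without the third argument, the callback would see only "t3" and a value of type "t2".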
This change is in preparation for merging the 1:1 and 1:N dialect
conversion drivers. As part of that change, argument materializations
will be removed: they are no longer needed, as they were only a
workaround for the missing 1:N support in the dialect conversion.
The new `originalType` parameter is needed when lowering MemRef to LLVM.
During that lowering, MemRef function block arguments are replaced with
the elements that make up a MemRef descriptor. The type converter is set
up in such a way that the legalized type of a MemRef type is an
`!llvm.struct` that represents the MemRef descriptor. When the bare
pointer calling convention is enabled, the function block arguments
consist of just an LLVM pointer. In such a case, a target
materialization will be invoked to construct a MemRef descriptor (output
type = `!llvm.struct<...>`) from just the bare pointer (inputs =
`!llvm.ptr`). The original MemRef type is required to construct the
MemRef descriptor, as static sizes/strides/offset cannot be inferred
from just the bare pointer.
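The bare-pointer case can be sketched the same way. The C++ below is a simplified, hypothetical model rather than the MLIR `MemRefDescriptor` API: `ToyMemRefType`, `ToyDescriptor`, `kDynamic` and `materializeDescriptor` are invented for illustration. It assumes a statically shaped memref with the default row-major (identity) layout, so the offset is 0 and each stride is the product of the inner sizes. Everything except the two pointers has to come from the original type.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical marker for a dynamic dimension in this toy model.
constexpr int64_t kDynamic = -1;

// Hypothetical toy model of a ranked memref type: just its shape; the
// layout is assumed to be the default row-major (identity) layout.
struct ToyMemRefType {
  std::vector<int64_t> shape;
};

// Hypothetical toy model of the fields packed into `!llvm.struct<...>`.
struct ToyDescriptor {
  void *allocated = nullptr;
  void *aligned = nullptr;
  int64_t offset = 0;
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
};

// Toy target materialization: bare pointer -> descriptor. The pointer is the
// only input; sizes, strides and offset must come from `originalType`.
std::optional<ToyDescriptor>
materializeDescriptor(void *barePtr,
                      const std::optional<ToyMemRefType> &originalType) {
  if (!originalType)
    return std::nullopt; // Nothing to recover the static layout from.
  for (int64_t dim : originalType->shape)
    if (dim == kDynamic)
      return std::nullopt; // A bare pointer cannot describe dynamic sizes.

  ToyDescriptor desc;
  desc.allocated = barePtr;
  desc.aligned = barePtr;
  desc.offset = 0; // Identity layout: no offset.
  desc.sizes = originalType->shape;
  // Row-major strides: innermost stride is 1, each outer stride is the
  // product of all inner sizes.
  desc.strides.assign(desc.sizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(desc.sizes.size()) - 2; i >= 0; --i)
    desc.strides[i] = desc.strides[i + 1] * desc.sizes[i + 1];
  assert(desc.sizes.size() == desc.strides.size());
  return desc;
}
```

For a `2x3x4` memref, the sketch yields sizes `[2, 3, 4]` and strides `[12, 4, 1]`. With no original type, or with a dynamic dimension, it returns `std::nullopt`. That is the toy analogue of a materialization that declines so that another one can be tried.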
Added:
Modified:
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/Utils/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e046e886..45ad6f8586daae 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -138,7 +138,8 @@ class TypeConverter {
};
/// Register a conversion function. A conversion function must be convertible
- /// to any of the following forms(where `T` is a class derived from `Type`:
+ /// to any of the following forms (where `T` is a class derived from `Type`):
+ ///
/// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
@@ -151,15 +152,7 @@ class TypeConverter {
/// existing value are expected to be removed during conversion. If
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
- /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &,
- /// ArrayRef<Type>)
- /// - This form represents a 1-N type conversion supporting recursive
- /// types. The first two arguments and the return value are the same as
- /// for the regular 1-N form. The third argument is contains is the
- /// "call stack" of the recursive conversion: it contains the list of
- /// types currently being converted, with the current type being the
- /// last one. If it is present more than once in the list, the
- /// conversion concerns a recursive type.
+ ///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// most recently added conversions will be invoked first.
template <typename FnT, typename T = typename llvm::function_traits<
@@ -178,6 +171,9 @@ class TypeConverter {
/// it failed but other materialization can be attempted, and `nullptr` on
/// unrecoverable failure. Materialization functions must be provided when a
/// type conversion may persist after the conversion has finished.
+ ///
+ /// Note: Target materializations may optionally accept an additional Type
+ /// parameter, which is the original type of the SSA value.
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
@@ -203,11 +199,22 @@ class TypeConverter {
/// This method registers a materialization that will be called when
/// converting an illegal (source) value to a legal (target) type.
+ ///
+ /// Note: For target materializations, users can optionally take the original
+ /// type. This type may be different from the type of the input. For example,
+ /// let's assume that a conversion pattern "P1" replaced an SSA value "v1"
+ /// (type "t1") with "v2" (type "t2"). Then a different conversion pattern
+ /// "P2" matches an op that has "v1" as an operand. Let's furthermore assume
+ /// that "P2" determines that the legalized type of "t1" is "t3", which may
+ /// be different from "t2". In this example, the target materialization
+ /// will be invoked with: outputType = "t3", inputs = "v2",
+ /// originalType = "t1". Note that the original type "t1" cannot be recovered
+ /// from just "t3" and "v2"; that's why the originalType parameter exists.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
targetMaterializations.emplace_back(
- wrapMaterialization<T>(std::forward<FnT>(callback)));
+ wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
}
/// Register a conversion function for attributes within types. Type
@@ -303,21 +310,12 @@ class TypeConverter {
/// `add*Materialization` for more information on the context for these
/// methods.
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
- Type resultType,
- ValueRange inputs) const {
- return materializeConversion(argumentMaterializations, builder, loc,
- resultType, inputs);
- }
+ Type resultType, ValueRange inputs) const;
Value materializeSourceConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) const {
- return materializeConversion(sourceMaterializations, builder, loc,
- resultType, inputs);
- }
+ Type resultType, ValueRange inputs) const;
Value materializeTargetConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) const {
- return materializeConversion(targetMaterializations, builder, loc,
- resultType, inputs);
- }
+ Type resultType, ValueRange inputs,
+ Type originalType = {}) const;
/// Convert an attribute present `attr` from within the type `type` using
/// the registered conversion functions. If no applicable conversion has been
@@ -333,21 +331,23 @@ class TypeConverter {
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
Type, SmallVectorImpl<Type> &)>;
- /// The signature of the callback used to materialize a conversion.
+ /// The signature of the callback used to materialize a source/argument
+ /// conversion.
+ ///
+ /// Arguments: builder, result type, inputs, location
using MaterializationCallbackFn = std::function<std::optional<Value>(
OpBuilder &, Type, ValueRange, Location)>;
+ /// The signature of the callback used to materialize a target conversion.
+ ///
+ /// Arguments: builder, result type, inputs, location, original type
+ using TargetMaterializationCallbackFn = std::function<std::optional<Value>(
+ OpBuilder &, Type, ValueRange, Location, Type)>;
+
/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
std::function<AttributeConversionResult(Type, Attribute)>;
- /// Attempt to materialize a conversion using one of the provided
- /// materialization functions.
- Value
- materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
- OpBuilder &builder, Location loc, Type resultType,
- ValueRange inputs) const;
-
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
/// With callback of form: `std::optional<Type>(T)`
@@ -388,9 +388,10 @@ class TypeConverter {
cachedMultiConversions.clear();
}
- /// Generate a wrapper for the given materialization callback. The callback
- /// may take any subclass of `Type` and the wrapper will check for the target
- /// type to be of the expected class before calling the callback.
+ /// Generate a wrapper for the given argument/source materialization
+ /// callback. The callback may take any subclass of `Type` and the
+ /// wrapper will check for the target type to be of the expected class
+ /// before calling the callback.
template <typename T, typename FnT>
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
@@ -402,6 +403,41 @@ class TypeConverter {
};
}
+ /// Generate a wrapper for the given target materialization callback.
+ /// The callback may take any subclass of `Type` and the wrapper will check
+ /// for the target type to be of the expected class before calling the
+ /// callback.
+ ///
+ /// With callback of form:
+ /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+ template <typename T, typename FnT>
+ std::enable_if_t<
+ std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
+ TargetMaterializationCallbackFn>
+ wrapTargetMaterialization(FnT &&callback) const {
+ return [callback = std::forward<FnT>(callback)](
+ OpBuilder &builder, Type resultType, ValueRange inputs,
+ Location loc, Type originalType) -> std::optional<Value> {
+ if (T derivedType = dyn_cast<T>(resultType))
+ return callback(builder, derivedType, inputs, loc, originalType);
+ return std::nullopt;
+ };
+ }
+ /// With callback of form:
+ /// `Value(OpBuilder &, T, ValueRange, Location)`
+ template <typename T, typename FnT>
+ std::enable_if_t<
+ std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
+ TargetMaterializationCallbackFn>
+ wrapTargetMaterialization(FnT &&callback) const {
+ return wrapTargetMaterialization<T>(
+ [callback = std::forward<FnT>(callback)](
+ OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
+ Type originalType) -> std::optional<Value> {
+ return callback(builder, resultType, inputs, loc);
+ });
+ }
+
/// Generate a wrapper for the given memory space conversion callback. The
/// callback may take any subclass of `Attribute` and the wrapper will check
/// for the target attribute to be of the expected class before calling the
@@ -434,7 +470,7 @@ class TypeConverter {
/// The list of registered materialization functions.
SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
- SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
+ SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
/// The list of registered type attribute conversion functions.
SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 97dd3ab1f48293..1baddd881f6aa2 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -683,10 +683,10 @@ enum MaterializationKind {
/// conversion.
class UnresolvedMaterializationRewrite : public OperationRewrite {
public:
- UnresolvedMaterializationRewrite(
- ConversionPatternRewriterImpl &rewriterImpl,
- UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
- MaterializationKind kind = MaterializationKind::Target);
+ UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ UnrealizedConversionCastOp op,
+ const TypeConverter *converter,
+ MaterializationKind kind, Type originalType);
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,11 +708,18 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
return converterAndKind.getInt();
}
+ /// Return the original type of the SSA value.
+ Type getOriginalType() const { return originalType; }
+
private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
converterAndKind;
+
+ /// The original type of the SSA value. Only used for target
+ /// materializations.
+ Type originalType;
};
} // namespace
@@ -808,6 +815,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Value buildUnresolvedMaterialization(MaterializationKind kind,
OpBuilder::InsertPoint ip, Location loc,
ValueRange inputs, Type outputType,
+ Type originalType,
const TypeConverter *converter);
//===--------------------------------------------------------------------===//
@@ -1034,9 +1042,12 @@ void CreateOperationRewrite::rollback() {
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
- const TypeConverter *converter, MaterializationKind kind)
+ const TypeConverter *converter, MaterializationKind kind, Type originalType)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind) {
+ converterAndKind(converter, kind), originalType(originalType) {
+ assert(!originalType ||
+ kind == MaterializationKind::Target &&
+ "original type is valid only for target materializations");
rewriterImpl.unresolvedMaterializations[op] = this;
}
@@ -1139,7 +1150,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(newOperand),
operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
- currentTypeConverter);
+ /*originalType=*/origType, currentTypeConverter);
mapping.map(newOperand, castValue);
newOperand = castValue;
}
@@ -1255,7 +1266,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*inputs=*/ValueRange(),
- /*outputType=*/origArgType, converter);
+ /*outputType=*/origArgType, /*originalType=*/Type(), converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
@@ -1280,7 +1291,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
Value argMat = buildUnresolvedMaterialization(
MaterializationKind::Argument,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*inputs=*/replArgs, origArgType, converter);
+ /*inputs=*/replArgs, /*outputType=*/origArgType,
+ /*originalType=*/Type(), converter);
mapping.map(origArg, argMat);
Type legalOutputType;
@@ -1299,7 +1311,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (legalOutputType && legalOutputType != origArgType) {
Value targetMat = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(argMat),
- origArg.getLoc(), argMat, legalOutputType, converter);
+ origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType,
+ /*originalType=*/origArgType, converter);
mapping.map(argMat, targetMat);
}
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1322,7 +1335,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- ValueRange inputs, Type outputType, const TypeConverter *converter) {
+ ValueRange inputs, Type outputType, Type originalType,
+ const TypeConverter *converter) {
+ assert(!originalType ||
+ kind == MaterializationKind::Target &&
+ "original type is valid only for target materializations");
+
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
return inputs.front();
@@ -1333,7 +1351,8 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+ originalType);
return convertOp.getResult(0);
}
@@ -1381,7 +1400,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
newValue = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*inputs=*/ValueRange(),
- /*outputType=*/result.getType(), currentTypeConverter);
+ /*outputType=*/result.getType(), /*originalType=*/Type(),
+ currentTypeConverter);
}
// Remap, and check for any result type changes.
@@ -2408,7 +2428,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
[[fallthrough]];
case MaterializationKind::Target:
newMaterialization = converter->materializeTargetConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
+ rewriter, op->getLoc(), outputType, inputOperands,
+ rewrite->getOriginalType());
break;
case MaterializationKind::Source:
newMaterialization = converter->materializeSourceConversion(
@@ -2565,7 +2586,7 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
MaterializationKind::Source, computeInsertPoint(newValue),
originalValue.getLoc(),
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
- converter);
+ /*originalType=*/Type(), converter);
rewriterImpl.mapping.map(originalValue, castValue);
inverseMapping[castValue].push_back(originalValue);
llvm::erase(inverseMapping[newValue], originalValue);
@@ -2787,15 +2808,39 @@ TypeConverter::convertSignatureArgs(TypeRange types,
return success();
}
-Value TypeConverter::materializeConversion(
- ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
- Location loc, Type resultType, ValueRange inputs) const {
- for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
+Value TypeConverter::materializeArgumentConversion(OpBuilder &builder,
+ Location loc,
+ Type resultType,
+ ValueRange inputs) const {
+ for (const MaterializationCallbackFn &fn :
+ llvm::reverse(argumentMaterializations))
+ if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
+ return *result;
+ return nullptr;
+}
+
+Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
+ Location loc, Type resultType,
+ ValueRange inputs) const {
+ for (const MaterializationCallbackFn &fn :
+ llvm::reverse(sourceMaterializations))
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
return *result;
return nullptr;
}
+Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
+ Location loc, Type resultType,
+ ValueRange inputs,
+ Type originalType) const {
+ for (const TargetMaterializationCallbackFn &fn :
+ llvm::reverse(targetMaterializations))
+ if (std::optional<Value> result =
+ fn(builder, resultType, inputs, loc, originalType))
+ return *result;
+ return nullptr;
+}
+
std::optional<TypeConverter::SignatureConversion>
TypeConverter::convertBlockSignature(Block *block) const {
SignatureConversion conversion(block->getNumArguments());
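The two `wrapTargetMaterialization` overloads in the header diff use `std::enable_if_t` with `std::is_invocable_v`, so a callback with or without the trailing original-type parameter is stored as a single `TargetMaterializationCallbackFn`. `materializeTargetConversion` then tries the callbacks in reverse registration order until one succeeds. The standalone sketch below mimics that pattern outside MLIR. `ToyType`, `ToyValue`, `ToyTargetFn`, `wrapTarget` and `ToyConverter` are hypothetical names, the builder and location parameters are dropped, and plain strings stand in for types and values.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

// Hypothetical toy stand-ins: types and values are plain strings.
using ToyType = std::string;
using ToyValue = std::string;

// Uniform stored form, mirroring the extra trailing `Type originalType`.
using ToyTargetFn = std::function<std::optional<ToyValue>(
    const ToyType & /*resultType*/, const ToyValue & /*input*/,
    const ToyType & /*originalType*/)>;

// Overload 1: the callback already accepts the original type.
template <typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, const ToyType &, const ToyValue &,
                                     const ToyType &>,
                 ToyTargetFn>
wrapTarget(FnT &&callback) {
  return std::forward<FnT>(callback);
}

// Overload 2: an older-style callback without it; the wrapper drops it.
template <typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, const ToyType &, const ToyValue &>,
                 ToyTargetFn>
wrapTarget(FnT &&callback) {
  return [callback = std::forward<FnT>(callback)](
             const ToyType &resultType, const ToyValue &input,
             const ToyType & /*originalType*/) -> std::optional<ToyValue> {
    return callback(resultType, input);
  };
}

struct ToyConverter {
  std::vector<ToyTargetFn> targetMaterializations;

  template <typename FnT> void addTargetMaterialization(FnT &&callback) {
    targetMaterializations.push_back(wrapTarget(std::forward<FnT>(callback)));
  }

  // The most recently registered callback is tried first; first success wins.
  std::optional<ToyValue>
  materializeTarget(const ToyType &resultType, const ToyValue &input,
                    const ToyType &originalType = {}) const {
    for (auto it = targetMaterializations.rbegin();
         it != targetMaterializations.rend(); ++it)
      if (std::optional<ToyValue> result =
              (*it)(resultType, input, originalType))
        return result;
    return std::nullopt;
  }
};
```

For ordinary lambdas the two `enable_if` conditions are mutually exclusive, so overload resolution picks exactly one wrapper. A generic lambda taking `auto...` would satisfy both and make the call ambiguous, which this sketch does not try to handle.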