[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: add `originalType` param to materializations (PR #112128)
Matthias Springer
llvmlistbot at llvm.org
Sun Oct 13 05:04:53 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/112128
>From d245912fc11f70e81d1fb98f71841062c5e95b6c Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 13 Oct 2024 13:58:42 +0200
Subject: [PATCH] [mlir][Transforms] Dialect conversion: add originalType param
to materializations v2
---
.../mlir/Transforms/DialectConversion.h | 108 ++++++++++++------
.../Transforms/Utils/DialectConversion.cpp | 83 +++++++++++---
2 files changed, 136 insertions(+), 55 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 65e279e046e886..45ad6f8586daae 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -138,7 +138,8 @@ class TypeConverter {
};
/// Register a conversion function. A conversion function must be convertible
- /// to any of the following forms(where `T` is a class derived from `Type`:
+ /// to any of the following forms (where `T` is a class derived from `Type`):
+ ///
/// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
@@ -151,15 +152,7 @@ class TypeConverter {
/// existing value are expected to be removed during conversion. If
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
- /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &,
- /// ArrayRef<Type>)
- /// - This form represents a 1-N type conversion supporting recursive
- /// types. The first two arguments and the return value are the same as
- /// for the regular 1-N form. The third argument is contains is the
- /// "call stack" of the recursive conversion: it contains the list of
- /// types currently being converted, with the current type being the
- /// last one. If it is present more than once in the list, the
- /// conversion concerns a recursive type.
+ ///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT, typename T = typename llvm::function_traits<
@@ -178,6 +171,9 @@ class TypeConverter {
/// it failed but other materialization can be attempted, and `nullptr` on
/// unrecoverable failure. Materialization functions must be provided when a
/// type conversion may persist after the conversion has finished.
+ ///
+ /// Note: Target materializations may optionally accept an additional Type
+ /// parameter, which is the original type of the SSA value.
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
@@ -203,11 +199,22 @@ class TypeConverter {
/// This method registers a materialization that will be called when
/// converting an illegal (source) value to a legal (target) type.
+ ///
+ /// Note: For target materializations, users can optionally take the original
+ /// type. This type may be different from the type of the input. For example,
+ /// let's assume that a conversion pattern "P1" replaced an SSA value "v1"
+ /// (type "t1") with "v2" (type "t2"). Then a different conversion pattern
+ /// "P2" matches an op that has "v1" as an operand. Let's furthermore assume
+ /// that "P2" determines that the legalized type of "t1" is "t3", which may
+ /// be different from "t2". In this example, the target materialization
+ /// will be invoked with: outputType = "t3", inputs = "v2",
+ /// originalType = "t1". Note that the original type "t1" cannot be recovered
+ /// from just "t3" and "v2"; that's why the originalType parameter exists.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
targetMaterializations.emplace_back(
- wrapMaterialization<T>(std::forward<FnT>(callback)));
+ wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
}
/// Register a conversion function for attributes within types. Type
@@ -303,21 +310,12 @@ class TypeConverter {
/// `add*Materialization` for more information on the context for these
/// methods.
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
- Type resultType,
- ValueRange inputs) const {
- return materializeConversion(argumentMaterializations, builder, loc,
- resultType, inputs);
- }
+ Type resultType, ValueRange inputs) const;
Value materializeSourceConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) const {
- return materializeConversion(sourceMaterializations, builder, loc,
- resultType, inputs);
- }
+ Type resultType, ValueRange inputs) const;
Value materializeTargetConversion(OpBuilder &builder, Location loc,
- Type resultType, ValueRange inputs) const {
- return materializeConversion(targetMaterializations, builder, loc,
- resultType, inputs);
- }
+ Type resultType, ValueRange inputs,
+ Type originalType = {}) const;
/// Convert an attribute present `attr` from within the type `type` using
/// the registered conversion functions. If no applicable conversion has been
@@ -333,21 +331,23 @@ class TypeConverter {
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
Type, SmallVectorImpl<Type> &)>;
- /// The signature of the callback used to materialize a conversion.
+ /// The signature of the callback used to materialize a source/argument
+ /// conversion.
+ ///
+ /// Arguments: builder, result type, inputs, location
using MaterializationCallbackFn = std::function<std::optional<Value>(
OpBuilder &, Type, ValueRange, Location)>;
+ /// The signature of the callback used to materialize a target conversion.
+ ///
+ /// Arguments: builder, result type, inputs, location, original type
+ using TargetMaterializationCallbackFn = std::function<std::optional<Value>(
+ OpBuilder &, Type, ValueRange, Location, Type)>;
+
/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
std::function<AttributeConversionResult(Type, Attribute)>;
- /// Attempt to materialize a conversion using one of the provided
- /// materialization functions.
- Value
- materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
- OpBuilder &builder, Location loc, Type resultType,
- ValueRange inputs) const;
-
/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
/// With callback of form: `std::optional<Type>(T)`
@@ -388,9 +388,10 @@ class TypeConverter {
cachedMultiConversions.clear();
}
- /// Generate a wrapper for the given materialization callback. The callback
- /// may take any subclass of `Type` and the wrapper will check for the target
- /// type to be of the expected class before calling the callback.
+ /// Generate a wrapper for the given argument/source materialization
+ /// callback. The callback may take any subclass of `Type` and the
+ /// wrapper will check for the target type to be of the expected class
+ /// before calling the callback.
template <typename T, typename FnT>
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
@@ -402,6 +403,41 @@ class TypeConverter {
};
}
+ /// Generate a wrapper for the given target materialization callback.
+ /// The callback may take any subclass of `Type` and the wrapper will check
+ /// for the target type to be of the expected class before calling the
+ /// callback.
+ ///
+ /// With callback of form:
+ /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+ template <typename T, typename FnT>
+ std::enable_if_t<
+ std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
+ TargetMaterializationCallbackFn>
+ wrapTargetMaterialization(FnT &&callback) const {
+ return [callback = std::forward<FnT>(callback)](
+ OpBuilder &builder, Type resultType, ValueRange inputs,
+ Location loc, Type originalType) -> std::optional<Value> {
+ if (T derivedType = dyn_cast<T>(resultType))
+ return callback(builder, derivedType, inputs, loc, originalType);
+ return std::nullopt;
+ };
+ }
+ /// With callback of form:
+ /// `Value(OpBuilder &, T, ValueRange, Location)`
+ template <typename T, typename FnT>
+ std::enable_if_t<
+ std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
+ TargetMaterializationCallbackFn>
+ wrapTargetMaterialization(FnT &&callback) const {
+ return wrapTargetMaterialization<T>(
+ [callback = std::forward<FnT>(callback)](
+ OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
+ Type originalType) -> std::optional<Value> {
+ return callback(builder, resultType, inputs, loc);
+ });
+ }
+
/// Generate a wrapper for the given memory space conversion callback. The
/// callback may take any subclass of `Attribute` and the wrapper will check
/// for the target attribute to be of the expected class before calling the
@@ -434,7 +470,7 @@ class TypeConverter {
/// The list of registered materialization functions.
SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
- SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
+ SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
/// The list of registered type attribute conversion functions.
SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 97dd3ab1f48293..df99b5dbc2dd5d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -683,10 +683,10 @@ enum MaterializationKind {
/// conversion.
class UnresolvedMaterializationRewrite : public OperationRewrite {
public:
- UnresolvedMaterializationRewrite(
- ConversionPatternRewriterImpl &rewriterImpl,
- UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
- MaterializationKind kind = MaterializationKind::Target);
+ UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+ UnrealizedConversionCastOp op,
+ const TypeConverter *converter,
+ MaterializationKind kind, Type originalType);
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,11 +708,18 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
return converterAndKind.getInt();
}
+ /// Return the original type of the SSA value.
+ Type getOriginalType() const { return originalType; }
+
private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
converterAndKind;
+
+ /// The original type of the SSA value. Only used for target
+ /// materializations.
+ Type originalType;
};
} // namespace
@@ -808,6 +815,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Value buildUnresolvedMaterialization(MaterializationKind kind,
OpBuilder::InsertPoint ip, Location loc,
ValueRange inputs, Type outputType,
+ Type originalType,
const TypeConverter *converter);
//===--------------------------------------------------------------------===//
@@ -1034,9 +1042,12 @@ void CreateOperationRewrite::rollback() {
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
- const TypeConverter *converter, MaterializationKind kind)
+ const TypeConverter *converter, MaterializationKind kind, Type originalType)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
- converterAndKind(converter, kind) {
+ converterAndKind(converter, kind), originalType(originalType) {
+ assert((!originalType ||
+ kind == MaterializationKind::Target) &&
+ "original type is valid only for target materializations");
rewriterImpl.unresolvedMaterializations[op] = this;
}
@@ -1139,7 +1150,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
Value castValue = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(newOperand),
operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
- currentTypeConverter);
+ /*originalType=*/origType, currentTypeConverter);
mapping.map(newOperand, castValue);
newOperand = castValue;
}
@@ -1255,7 +1266,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
MaterializationKind::Source,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
/*inputs=*/ValueRange(),
- /*outputType=*/origArgType, converter);
+ /*outputType=*/origArgType, /*originalType=*/Type(), converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
@@ -1280,7 +1291,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
Value argMat = buildUnresolvedMaterialization(
MaterializationKind::Argument,
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
- /*inputs=*/replArgs, origArgType, converter);
+ /*inputs=*/replArgs, /*outputType=*/origArgType,
+ /*originalType=*/Type(), converter);
mapping.map(origArg, argMat);
Type legalOutputType;
@@ -1299,7 +1311,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (legalOutputType && legalOutputType != origArgType) {
Value targetMat = buildUnresolvedMaterialization(
MaterializationKind::Target, computeInsertPoint(argMat),
- origArg.getLoc(), argMat, legalOutputType, converter);
+ origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType,
+ /*originalType=*/origArgType, converter);
mapping.map(argMat, targetMat);
}
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1322,7 +1335,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/// of input operands.
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
- ValueRange inputs, Type outputType, const TypeConverter *converter) {
+ ValueRange inputs, Type outputType, Type originalType,
+ const TypeConverter *converter) {
+ assert((!originalType ||
+ kind == MaterializationKind::Target) &&
+ "original type is valid only for target materializations");
+
// Avoid materializing an unnecessary cast.
if (inputs.size() == 1 && inputs.front().getType() == outputType)
return inputs.front();
@@ -1333,7 +1351,8 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+ originalType);
return convertOp.getResult(0);
}
@@ -1381,7 +1400,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
newValue = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*inputs=*/ValueRange(),
- /*outputType=*/result.getType(), currentTypeConverter);
+ /*outputType=*/result.getType(), /*originalType=*/Type(),
+ currentTypeConverter);
}
// Remap, and check for any result type changes.
@@ -2408,7 +2428,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
[[fallthrough]];
case MaterializationKind::Target:
newMaterialization = converter->materializeTargetConversion(
- rewriter, op->getLoc(), outputType, inputOperands);
+ rewriter, op->getLoc(), outputType, inputOperands,
+ rewrite->getOriginalType());
break;
case MaterializationKind::Source:
newMaterialization = converter->materializeSourceConversion(
@@ -2565,7 +2586,7 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
MaterializationKind::Source, computeInsertPoint(newValue),
originalValue.getLoc(),
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
- converter);
+ /*originalType=*/Type(), converter);
rewriterImpl.mapping.map(originalValue, castValue);
inverseMapping[castValue].push_back(originalValue);
llvm::erase(inverseMapping[newValue], originalValue);
@@ -2787,15 +2808,39 @@ TypeConverter::convertSignatureArgs(TypeRange types,
return success();
}
-Value TypeConverter::materializeConversion(
- ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
- Location loc, Type resultType, ValueRange inputs) const {
- for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
+Value TypeConverter::materializeArgumentConversion(OpBuilder &builder,
+ Location loc,
+ Type resultType,
+ ValueRange inputs) const {
+ for (const MaterializationCallbackFn &fn :
+ llvm::reverse(argumentMaterializations))
+ if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
+ return *result;
+ return nullptr;
+}
+
+Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
+ Location loc, Type resultType,
+ ValueRange inputs) const {
+ for (const MaterializationCallbackFn &fn :
+ llvm::reverse(sourceMaterializations))
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
return *result;
return nullptr;
}
+Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
+ Location loc, Type resultType,
+ ValueRange inputs,
+ Type originalType) const {
+ for (const TargetMaterializationCallbackFn &fn :
+ llvm::reverse(targetMaterializations))
+ if (std::optional<Value> result =
+ fn(builder, resultType, inputs, loc, originalType))
+ return *result;
+ return nullptr;
+}
+
std::optional<TypeConverter::SignatureConversion>
TypeConverter::convertBlockSignature(Block *block) const {
SignatureConversion conversion(block->getNumArguments());
More information about the Mlir-commits
mailing list