[Mlir-commits] [mlir] [mlir][Transforms] Merge 1:1 and 1:N type converters (PR #113032)
Matthias Springer
llvmlistbot at llvm.org
Wed Oct 23 09:58:16 PDT 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/113032
From 44db7a9efdbd33f5a40723a96cbe9f66f68e0970 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sat, 19 Oct 2024 12:05:13 +0200
Subject: [PATCH 1/2] [mlir][Transforms] Merge 1:1 and 1:N type converters
---
.../Dialect/SparseTensor/Transforms/Passes.h | 2 +-
.../mlir/Transforms/DialectConversion.h | 56 ++++++++++++++-----
.../mlir/Transforms/OneToNTypeConversion.h | 45 +--------------
.../ArmSME/Transforms/VectorLegalization.cpp | 2 +-
.../Transforms/Utils/DialectConversion.cpp | 24 ++++++--
.../Transforms/Utils/OneToNTypeConversion.cpp | 44 +++++----------
.../TestOneToNTypeConversionPass.cpp | 18 ++++--
7 files changed, 93 insertions(+), 98 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
index 6ccbc40bdd6034..2e9c297f20182a 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
@@ -150,7 +150,7 @@ std::unique_ptr<Pass> createLowerForeachToSCFPass();
//===----------------------------------------------------------------------===//
/// Type converter for iter_space and iterator.
-struct SparseIterationTypeConverter : public OneToNTypeConverter {
+struct SparseIterationTypeConverter : public TypeConverter {
SparseIterationTypeConverter();
};
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5ff36160dd6162..37da03bbe386e9 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -173,7 +173,9 @@ class TypeConverter {
/// conversion has finished.
///
/// Note: Target materializations may optionally accept an additional Type
- /// parameter, which is the original type of the SSA value.
+ /// parameter, which is the original type of the SSA value. Furthermore, `T`
+ /// can be a TypeRange; in that case, the function must return a
+ /// SmallVector<Value>.
/// This method registers a materialization that will be called when
/// converting (potentially multiple) block arguments that were the result of
@@ -210,6 +212,9 @@ class TypeConverter {
/// will be invoked with: outputType = "t3", inputs = "v2",
// originalType = "t1". Note that the original type "t1" cannot be recovered
/// from just "t3" and "v2"; that's why the originalType parameter exists.
+ ///
+ /// Note: During a 1:N conversion, the result types can be a TypeRange. In
+ /// that case the materialization produces a SmallVector<Value>.
template <typename FnT, typename T = typename llvm::function_traits<
std::decay_t<FnT>>::template arg_t<1>>
void addTargetMaterialization(FnT &&callback) {
@@ -316,6 +321,11 @@ class TypeConverter {
Value materializeTargetConversion(OpBuilder &builder, Location loc,
Type resultType, ValueRange inputs,
Type originalType = {}) const;
+ SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
+ Location loc,
+ TypeRange resultType,
+ ValueRange inputs,
+ Type originalType = {}) const;
/// Convert an attribute present `attr` from within the type `type` using
/// the registered conversion functions. If no applicable conversion has been
@@ -340,9 +350,9 @@ class TypeConverter {
/// The signature of the callback used to materialize a target conversion.
///
- /// Arguments: builder, result type, inputs, location, original type
- using TargetMaterializationCallbackFn =
- std::function<Value(OpBuilder &, Type, ValueRange, Location, Type)>;
+ /// Arguments: builder, result types, inputs, location, original type
+ using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
+ OpBuilder &, TypeRange, ValueRange, Location, Type)>;
/// The signature of the callback used to convert a type attribute.
using TypeAttributeConversionCallbackFn =
@@ -409,22 +419,40 @@ class TypeConverter {
/// callback.
///
/// With callback of form:
- /// `Value(OpBuilder &, T, ValueRange, Location, Type)`
+ /// - Value(OpBuilder &, T, ValueRange, Location, Type)
+ /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
TargetMaterializationCallbackFn>
wrapTargetMaterialization(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
- OpBuilder &builder, Type resultType, ValueRange inputs,
- Location loc, Type originalType) -> Value {
- if (T derivedType = dyn_cast<T>(resultType))
- return callback(builder, derivedType, inputs, loc, originalType);
- return Value();
+ OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+ Location loc, Type originalType) -> SmallVector<Value> {
+ SmallVector<Value> result;
+ if constexpr (std::is_same<T, TypeRange>::value) {
+ // This is a 1:N target materialization. Return the produces values
+ // directly.
+ result = callback(builder, resultTypes, inputs, loc, originalType);
+ } else {
+ // This is a 1:1 target materialization. Invoke it only if the result
+ // type class of the callback matches the requested result type.
+ if (T derivedType = dyn_cast<T>(resultTypes.front())) {
+ // 1:1 materializations produce single values, but we store 1:N
+ // target materialization functions in the type converter. Wrap the
+ // result value in a SmallVector<Value>.
+ std::optional<Value> val =
+ callback(builder, derivedType, inputs, loc, originalType);
+ if (val.has_value() && *val)
+ result.push_back(*val);
+ }
+ }
+ return result;
};
}
/// With callback of form:
- /// `Value(OpBuilder &, T, ValueRange, Location)`
+ /// - Value(OpBuilder &, T, ValueRange, Location)
+ /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
template <typename T, typename FnT>
std::enable_if_t<
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
@@ -432,9 +460,9 @@ class TypeConverter {
wrapTargetMaterialization(FnT &&callback) const {
return wrapTargetMaterialization<T>(
[callback = std::forward<FnT>(callback)](
- OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
- Type originalType) -> Value {
- return callback(builder, resultType, inputs, loc);
+ OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
+ Type originalType) {
+ return callback(builder, resultTypes, inputs, loc);
});
}
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index c59a3a52f028f3..7b4dd65cbff7b2 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -33,49 +33,6 @@
namespace mlir {
-/// Extends `TypeConverter` with 1:N target materializations. Such
-/// materializations have to provide the "reverse" of 1:N type conversions,
-/// i.e., they need to materialize N values with target types into one value
-/// with a source type (which isn't possible in the base class currently).
-class OneToNTypeConverter : public TypeConverter {
-public:
- /// Callback that expresses user-provided materialization logic from the given
- /// value to N values of the given types. This is useful for expressing target
- /// materializations for 1:N type conversions, which materialize one value in
- /// a source type as N values in target types.
- using OneToNMaterializationCallbackFn =
- std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
- Value, Location)>;
-
- /// Creates the mapping of the given range of original types to target types
- /// of the conversion and stores that mapping in the given (signature)
- /// conversion. This function simply calls
- /// `TypeConverter::convertSignatureArgs` and exists here with a different
- /// name to reflect the broader semantic.
- LogicalResult computeTypeMapping(TypeRange types,
- SignatureConversion &result) const {
- return convertSignatureArgs(types, result);
- }
-
- /// Applies one of the user-provided 1:N target materializations. If several
- /// exists, they are tried out in the reverse order in which they have been
- /// added until the first one succeeds. If none succeeds, the functions
- /// returns `std::nullopt`.
- std::optional<SmallVector<Value>>
- materializeTargetConversion(OpBuilder &builder, Location loc,
- TypeRange resultTypes, Value input) const;
-
- /// Adds a 1:N target materialization to the converter. Such materializations
- /// build IR that converts N values with target types into 1 value of the
- /// source type.
- void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
- oneToNTargetMaterializations.emplace_back(std::move(callback));
- }
-
-private:
- SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
-};
-
/// Stores a 1:N mapping of types and provides several useful accessors. This
/// class extends `SignatureConversion`, which already supports 1:N type
/// mappings but lacks some accessors into the mapping as well as access to the
@@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern {
/// not fail if some ops or types remain unconverted (i.e., the conversion is
/// only "partial").
LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns);
/// Add a pattern to the given pattern list to convert the signature of a
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4968c4fc463d04..e908a536e6fb27 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -921,7 +921,7 @@ struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
auto *context = &getContext();
- OneToNTypeConverter converter;
+ TypeConverter converter;
RewritePatternSet patterns(context);
converter.addConversion([](Type type) { return type; });
converter.addConversion(
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3cfcaa965f3546..bf969e74e8bfe0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2831,11 +2831,27 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
Location loc, Type resultType,
ValueRange inputs,
Type originalType) const {
+ SmallVector<Value> result = materializeTargetConversion(
+ builder, loc, TypeRange(resultType), inputs, originalType);
+ if (result.empty())
+ return nullptr;
+ assert(result.size() == 1 && "requested 1:1 materialization, but callback "
+ "produced 1:N materialization");
+ return result.front();
+}
+
+SmallVector<Value> TypeConverter::materializeTargetConversion(
+ OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs,
+ Type originalType) const {
for (const TargetMaterializationCallbackFn &fn :
- llvm::reverse(targetMaterializations))
- if (Value result = fn(builder, resultType, inputs, loc, originalType))
- return result;
- return nullptr;
+ llvm::reverse(targetMaterializations)) {
+ SmallVector<Value> result =
+ fn(builder, resultTypes, inputs, loc, originalType);
+ if (result.empty())
+ continue;
+ return result;
+ }
+ return {};
}
std::optional<TypeConverter::SignatureConversion>
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index 19e29d48623e04..c208716891ef1f 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -17,20 +17,6 @@
using namespace llvm;
using namespace mlir;
-std::optional<SmallVector<Value>>
-OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
- Location loc,
- TypeRange resultTypes,
- Value input) const {
- for (const OneToNMaterializationCallbackFn &fn :
- llvm::reverse(oneToNTargetMaterializations)) {
- if (std::optional<SmallVector<Value>> result =
- fn(builder, resultTypes, input, loc))
- return *result;
- }
- return std::nullopt;
-}
-
TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
TypeRange convertedTypes = getConvertedTypes();
if (auto mapping = getInputMapping(originalTypeNo))
@@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion(
LogicalResult
OneToNConversionPattern::matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const {
- auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+ auto *typeConverter = getTypeConverter();
// Construct conversion mapping for results.
Operation::result_type_range originalResultTypes = op->getResultTypes();
OneToNTypeMapping resultMapping(originalResultTypes);
- if (failed(typeConverter->computeTypeMapping(originalResultTypes,
- resultMapping)))
+ if (failed(typeConverter->convertSignatureArgs(originalResultTypes,
+ resultMapping)))
return failure();
// Construct conversion mapping for operands.
Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
OneToNTypeMapping operandMapping(originalOperandTypes);
- if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
- operandMapping)))
+ if (failed(typeConverter->convertSignatureArgs(originalOperandTypes,
+ operandMapping)))
return failure();
// Cast operands to target types.
@@ -318,7 +304,7 @@ namespace mlir {
// inserted by this pass are annotated with a string attribute that also
// documents which kind of the cast (source, argument, or target).
LogicalResult
-applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
+applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns) {
#ifndef NDEBUG
// Remember existing unrealized casts. This data structure is only used in
@@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
// Target materialization.
assert(!areOperandTypesLegal && areResultsTypesLegal &&
operands.size() == 1 && "found unexpected target cast");
- std::optional<SmallVector<Value>> maybeResults =
- typeConverter.materializeTargetConversion(
- rewriter, castOp->getLoc(), resultTypes, operands.front());
- if (!maybeResults) {
+ materializedResults = typeConverter.materializeTargetConversion(
+ rewriter, castOp->getLoc(), resultTypes, operands.front());
+ if (materializedResults.empty()) {
emitError(castOp->getLoc())
<< "failed to create target materialization";
return failure();
}
- materializedResults = maybeResults.value();
} else {
// Source and argument materializations.
assert(areOperandTypesLegal && !areResultsTypesLegal &&
@@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
const OneToNTypeMapping &resultMapping,
ValueRange convertedOperands) const override {
auto funcOp = cast<FunctionOpInterface>(op);
- auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+ auto *typeConverter = getTypeConverter();
// Construct mapping for function arguments.
OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
- if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
- argumentMapping)))
+ if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(),
+ argumentMapping)))
return failure();
// Construct mapping for function results.
OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
- if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
- funcResultMapping)))
+ if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(),
+ funcResultMapping)))
return failure();
// Nothing to do if the op doesn't have any non-identity conversions for its
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
index 5c03ac12d1e58c..b18dfd8bb22cb1 100644
--- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter,
///
/// This function has been copied (with small adaptions) from
/// TestDecomposeCallGraphTypes.cpp.
-static std::optional<SmallVector<Value>>
-buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
- Location loc) {
+static SmallVector<Value> buildGetTupleElementOps(OpBuilder &builder,
+ TypeRange resultTypes,
+ ValueRange inputs,
+ Location loc) {
+ if (inputs.size() != 1)
+ return {};
+ Value input = inputs.front();
+
TupleType inputType = dyn_cast<TupleType>(input.getType());
if (!inputType)
return {};
@@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() {
auto *context = &getContext();
// Assemble type converter.
- OneToNTypeConverter typeConverter;
+ TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion(
@@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() {
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
typeConverter.addSourceMaterialization(buildMakeTupleOp);
typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+ // Test the other target materialization variant that takes the original type
+ // as additional argument. This materialization function always fails.
+ typeConverter.addTargetMaterialization(
+ [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
+ Location loc, Type originalType) -> SmallVector<Value> { return {}; });
// Assemble patterns.
RewritePatternSet patterns(context);
From df341e92122f9c3d3a67b28583a45107a78268a8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 23 Oct 2024 09:36:45 -0700
Subject: [PATCH 2/2] Update mlir/include/mlir/Transforms/DialectConversion.h
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Co-authored-by: Markus Böck <markus.boeck02 at gmail.com>
---
mlir/include/mlir/Transforms/DialectConversion.h | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 37da03bbe386e9..0638dfedd647b0 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -434,18 +434,19 @@ class TypeConverter {
// This is a 1:N target materialization. Return the produces values
// directly.
result = callback(builder, resultTypes, inputs, loc, originalType);
- } else {
+ } else if constexpr (std::is_assignable<Type, T>::value) {
// This is a 1:1 target materialization. Invoke it only if the result
// type class of the callback matches the requested result type.
if (T derivedType = dyn_cast<T>(resultTypes.front())) {
// 1:1 materializations produce single values, but we store 1:N
// target materialization functions in the type converter. Wrap the
// result value in a SmallVector<Value>.
- std::optional<Value> val =
- callback(builder, derivedType, inputs, loc, originalType);
- if (val.has_value() && *val)
- result.push_back(*val);
+ Value val = callback(builder, derivedType, inputs, loc, originalType);
+ if (val)
+ result.push_back(val);
}
+ } else {
+ static_assert(false, "T must be a Type or a TypeRange");
}
return result;
};
More information about the Mlir-commits mailing list