[Mlir-commits] [mlir] [mlir][func] Replace `ValueDecomposer` with target materialization (PR #114192)
llvmlistbot at llvm.org
Wed Oct 30 00:56:16 PDT 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-func
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround for missing 1:N support in the dialect conversion. Since #113032, the dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations, so the `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually until we fully merge the 1:1 and 1:N drivers.)
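To make the new control flow concrete outside MLIR, here is a minimal, self-contained C++ sketch of the fallback chain implemented by the new static `decomposeValue` helper in the diff below. It assumes hypothetical toy stand-ins for `TypeConverter::convertType` and `TypeConverter::materializeTargetConversion` (`ToyType`, `ToyValue`, `ToyTypeConverter`, `toyDecomposeValue` and `makeComplexSplittingConverter` are illustrative only, not MLIR API), and it models only the decision logic, not IR construction.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy stand-ins for mlir::Type and mlir::Value (NOT the MLIR API):
// a type is just a name, and a value is a type plus a printable expression.
using ToyType = std::string;
struct ToyValue {
  ToyType type;
  std::string expr;
};

// Hypothetical converter exposing the two hooks the helper relies on.
struct ToyTypeConverter {
  // 1:N type conversion: returns false on failure, else appends 0..N types.
  std::function<bool(const ToyType &, std::vector<ToyType> &)> convertType;
  // 1:N target materialization: returns an empty vector on failure.
  std::function<std::vector<ToyValue>(const std::vector<ToyType> &,
                                      const ToyValue &)>
      materializeTarget;
};

// Same fallback chain as the static `decomposeValue` helper in the diff.
std::vector<ToyValue> toyDecomposeValue(const ToyValue &value,
                                        const ToyTypeConverter &converter) {
  std::vector<ToyType> convertedTypes;
  if (!converter.convertType(value.type, convertedTypes))
    return {value}; // Unconvertible type: keep the original value.
  if (convertedTypes.empty())
    return {}; // Type converts to nothing: drop the value.
  if (convertedTypes == std::vector<ToyType>{value.type})
    return {value}; // Type is already legal: nothing to materialize.
  std::vector<ToyValue> result =
      converter.materializeTarget(convertedTypes, value);
  if (result.empty())
    return {value}; // Materialization failed: keep the original value.
  return result;
}

// Hypothetical example converter: "complex" splits into two "f64"s, "none" is
// dropped, "opaque" is unconvertible, "i8" converts to "i32" but has no
// matching materialization, and every other type is already legal.
ToyTypeConverter makeComplexSplittingConverter() {
  ToyTypeConverter converter;
  converter.convertType = [](const ToyType &type, std::vector<ToyType> &out) {
    if (type == "opaque")
      return false;
    if (type == "complex") {
      out.push_back("f64");
      out.push_back("f64");
    } else if (type == "i8") {
      out.push_back("i32");
    } else if (type != "none") {
      out.push_back(type);
    }
    return true;
  };
  converter.materializeTarget = [](const std::vector<ToyType> &types,
                                   const ToyValue &input) {
    assert(!types.empty()); // toyDecomposeValue never requests zero results.
    std::vector<ToyValue> parts;
    if (input.type != "complex" || types.size() != 2)
      return parts; // Not applicable: signal failure with an empty vector.
    parts.push_back({types[0], "re(" + input.expr + ")"});
    parts.push_back({types[1], "im(" + input.expr + ")"});
    return parts;
  };
  return converter;
}
```

With this converter, a `complex` value `z` decomposes into `re(z)` and `im(z)`, an `i32` value is returned unchanged, a `none` value is dropped, and both an `opaque` value (unconvertible type) and an `i8` value (no applicable materialization) are kept as-is. Signalling failure with an empty vector mirrors how the diff checks the result of `materializeTargetConversion`.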
---
Full diff: https://github.com/llvm/llvm-project/pull/114192.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h (+1-61)
- (modified) mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp (+53-56)
- (modified) mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp (+37-23)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
index 1d311b37b37a4f..1be406bf3adf92 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
+++ b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
@@ -23,70 +23,10 @@
namespace mlir {
-/// This class provides a hook that expands one Value into multiple Value's,
-/// with a TypeConverter-inspired callback registration mechanism.
-///
-/// For folks that are familiar with the dialect conversion framework /
-/// TypeConverter, this is effectively the inverse of a source/argument
-/// materialization. A target materialization is not what we want here because
-/// it always produces a single Value, but in this case the whole point is to
-/// decompose a Value into multiple Value's.
-///
-/// The reason we need this inverse is easily understood by looking at what we
-/// need to do for decomposing types for a return op. When converting a return
-/// op, the dialect conversion framework will give the list of converted
-/// operands, and will ensure that each converted operand, even if it expanded
-/// into multiple types, is materialized as a single result. We then need to
-/// undo that materialization to a single result, which we do with the
-/// decomposeValue hooks registered on this object.
-///
-/// TODO: Eventually, the type conversion infra should have this hook built-in.
-/// See
-/// https://llvm.discourse.group/t/extending-type-conversion-infrastructure/779/2
-class ValueDecomposer {
-public:
- /// This method tries to decompose a value of a certain type using provided
- /// decompose callback functions. If it is unable to do so, the original value
- /// is returned.
- void decomposeValue(OpBuilder &, Location, Type, Value,
- SmallVectorImpl<Value> &);
-
- /// This method registers a callback function that will be called to decompose
- /// a value of a certain type into 0, 1, or multiple values.
- template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<2>>
- void addDecomposeValueConversion(FnT &&callback) {
- decomposeValueConversions.emplace_back(
- wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
- }
-
-private:
- using DecomposeValueConversionCallFn =
- std::function<std::optional<LogicalResult>(
- OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
-
- /// Generate a wrapper for the given decompose value conversion callback.
- template <typename T, typename FnT>
- DecomposeValueConversionCallFn
- wrapDecomposeValueConversionCallback(FnT &&callback) {
- return
- [callback = std::forward<FnT>(callback)](
- OpBuilder &builder, Location loc, Type type, Value value,
- SmallVectorImpl<Value> &newValues) -> std::optional<LogicalResult> {
- if (T derivedType = dyn_cast<T>(type))
- return callback(builder, loc, derivedType, value, newValues);
- return std::nullopt;
- };
- }
-
- SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
-};
-
/// Populates the patterns needed to drive the conversion process for
-/// decomposing call graph types with the given `ValueDecomposer`.
+/// decomposing call graph types with the given `TypeConverter`.
void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
const TypeConverter &typeConverter,
- ValueDecomposer &decomposer,
RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 357f993710a26a..8800ffd0be96dc 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -14,52 +14,46 @@ using namespace mlir;
using namespace mlir::func;
//===----------------------------------------------------------------------===//
-// ValueDecomposer
+// Helper functions
//===----------------------------------------------------------------------===//
-void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc,
- Type type, Value value,
- SmallVectorImpl<Value> &results) {
- for (auto &conversion : decomposeValueConversions)
- if (conversion(builder, loc, type, value, results))
- return;
- results.push_back(value);
+/// If the given value can be decomposed with the type converter, decompose it.
+/// Otherwise, return the given value.
+static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
+ Value value,
+ const TypeConverter *converter) {
+ // Try to convert the given value's type. If that fails, just return the
+ // given value.
+ SmallVector<Type> convertedTypes;
+ if (failed(converter->convertType(value.getType(), convertedTypes)))
+ return {value};
+ if (convertedTypes.empty())
+ return {};
+
+ // If the given value's type is already legal, just return the given value.
+ TypeRange convertedTypeRange(convertedTypes);
+ if (convertedTypeRange == TypeRange(value.getType()))
+ return {value};
+
+ // Try to materialize a target conversion. If the materialization did not
+ // produce values of the requested type, the materialization failed. Just
+ // return the given value in that case.
+ SmallVector<Value> result = converter->materializeTargetConversion(
+ builder, loc, convertedTypeRange, value);
+ if (result.empty())
+ return {value};
+ return result;
}
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesOpConversionPattern
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Base OpConversionPattern class to make a ValueDecomposer available to
-/// inherited patterns.
-template <typename SourceOp>
-class DecomposeCallGraphTypesOpConversionPattern
- : public OpConversionPattern<SourceOp> {
-public:
- DecomposeCallGraphTypesOpConversionPattern(const TypeConverter &typeConverter,
- MLIRContext *context,
- ValueDecomposer &decomposer,
- PatternBenefit benefit = 1)
- : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
- decomposer(decomposer) {}
-
-protected:
- ValueDecomposer &decomposer;
-};
-} // namespace
-
//===----------------------------------------------------------------------===//
// DecomposeCallGraphTypesForFuncArgs
//===----------------------------------------------------------------------===//
namespace {
-/// Expand function arguments according to the provided TypeConverter and
-/// ValueDecomposer.
+/// Expand function arguments according to the provided TypeConverter.
struct DecomposeCallGraphTypesForFuncArgs
- : public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
- using DecomposeCallGraphTypesOpConversionPattern::
- DecomposeCallGraphTypesOpConversionPattern;
+ : public OpConversionPattern<func::FuncOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
@@ -100,19 +94,22 @@ struct DecomposeCallGraphTypesForFuncArgs
//===----------------------------------------------------------------------===//
namespace {
-/// Expand return operands according to the provided TypeConverter and
-/// ValueDecomposer.
+/// Expand return operands according to the provided TypeConverter.
struct DecomposeCallGraphTypesForReturnOp
- : public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
- using DecomposeCallGraphTypesOpConversionPattern::
- DecomposeCallGraphTypesOpConversionPattern;
+ : public OpConversionPattern<ReturnOp> {
+ using OpConversionPattern::OpConversionPattern;
+
LogicalResult
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
SmallVector<Value, 2> newOperands;
- for (Value operand : adaptor.getOperands())
- decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
- operand, newOperands);
+ for (Value operand : adaptor.getOperands()) {
+ // TODO: We can directly take the values from the adaptor once this is a
+ // 1:N conversion pattern.
+ llvm::append_range(newOperands,
+ decomposeValue(rewriter, operand.getLoc(), operand,
+ getTypeConverter()));
+ }
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
return success();
}
@@ -124,12 +121,9 @@ struct DecomposeCallGraphTypesForReturnOp
//===----------------------------------------------------------------------===//
namespace {
-/// Expand call op operands and results according to the provided TypeConverter
-/// and ValueDecomposer.
-struct DecomposeCallGraphTypesForCallOp
- : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
- using DecomposeCallGraphTypesOpConversionPattern::
- DecomposeCallGraphTypesOpConversionPattern;
+/// Expand call op operands and results according to the provided TypeConverter.
+struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(CallOp op, OpAdaptor adaptor,
@@ -137,9 +131,13 @@ struct DecomposeCallGraphTypesForCallOp
// Create the operands list of the new `CallOp`.
SmallVector<Value, 2> newOperands;
- for (Value operand : adaptor.getOperands())
- decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
- operand, newOperands);
+ for (Value operand : adaptor.getOperands()) {
+ // TODO: We can directly take the values from the adaptor once this is a
+ // 1:N conversion pattern.
+ llvm::append_range(newOperands,
+ decomposeValue(rewriter, operand.getLoc(), operand,
+ getTypeConverter()));
+ }
// Create the new result types for the new `CallOp` and track the indices in
// the new call op's results that correspond to the old call op's results.
@@ -189,9 +187,8 @@ struct DecomposeCallGraphTypesForCallOp
void mlir::populateDecomposeCallGraphTypesPatterns(
MLIRContext *context, const TypeConverter &typeConverter,
- ValueDecomposer &decomposer, RewritePatternSet &patterns) {
+ RewritePatternSet &patterns) {
patterns
.add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
- DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
- decomposer);
+ DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
}
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 92216da9f201e6..de511c58ae6ee0 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -21,23 +21,40 @@ namespace {
/// given tuple value. If some tuple elements are, in turn, tuples, the elements
/// of those are extracted recursively such that the returned values have the
/// same types as `resultTypes.getFlattenedTypes()`.
-static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc,
- TupleType resultType, Value value,
- SmallVectorImpl<Value> &values) {
- for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
- Type elementType = resultType.getType(i);
- Value element = builder.create<test::GetTupleElementOp>(
- loc, elementType, value, builder.getI32IntegerAttr(i));
- if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
- // Recurse if the current element is also a tuple.
- if (failed(buildDecomposeTuple(builder, loc, nestedTupleType, element,
- values)))
- return failure();
- } else {
- values.push_back(element);
+static SmallVector<Value> buildDecomposeTuple(OpBuilder &builder,
+ TypeRange resultTypes,
+ ValueRange inputs, Location loc) {
+ // Skip materialization if the single input value is not a tuple.
+ if (inputs.size() != 1)
+ return {};
+ Value tuple = inputs.front();
+ auto tupleType = dyn_cast<TupleType>(tuple.getType());
+ if (!tupleType)
+ return {};
+ // Skip materialization if the flattened types do not match the requested
+ // result types.
+ SmallVector<Type> flattenedTypes;
+ tupleType.getFlattenedTypes(flattenedTypes);
+ if (TypeRange(resultTypes) != TypeRange(flattenedTypes))
+ return {};
+ // Recursively decompose the tuple.
+ SmallVector<Value> result;
+ std::function<void(Value)> decompose = [&](Value tuple) {
+ auto tupleType = dyn_cast<TupleType>(tuple.getType());
+ if (!tupleType) {
+ // This is not a tuple.
+ result.push_back(tuple);
+ return;
}
- }
- return success();
+ for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
+ Type elementType = tupleType.getType(i);
+ Value element = builder.create<test::GetTupleElementOp>(
+ loc, elementType, tuple, builder.getI32IntegerAttr(i));
+ decompose(element);
+ }
+ };
+ decompose(tuple);
+ return result;
}
/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
@@ -82,8 +99,8 @@ static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,
/// A pass for testing call graph type decomposition.
///
-/// This instantiates the patterns with a TypeConverter and ValueDecomposer
-/// that splits tuple types into their respective element types.
+/// This instantiates the patterns with a TypeConverter that splits tuple types
+/// into their respective element types.
/// For example, `tuple<T1, T2, T3> --> T1, T2, T3`.
struct TestDecomposeCallGraphTypes
: public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> {
@@ -123,12 +140,9 @@ struct TestDecomposeCallGraphTypes
return success();
});
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+ typeConverter.addTargetMaterialization(buildDecomposeTuple);
- ValueDecomposer decomposer;
- decomposer.addDecomposeValueConversion(buildDecomposeTuple);
-
- populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer,
- patterns);
+ populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns);
if (failed(applyPartialConversion(module, target, std::move(patterns))))
return signalPassFailure();
``````````
</details>
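The test pass now registers `buildDecomposeTuple` through `addTargetMaterialization` instead of `ValueDecomposer::addDecomposeValueConversion`. Its recursive flattening can be sketched in plain C++ as below, assuming hypothetical toy types in place of `TupleType` and `test::GetTupleElementOp` (`ToyType`, `ToyValue`, `flattenTypes`, `toyDecomposeTuple` and the `get(...)` expressions are illustrative only, not MLIR API). As in the diff, the function returns an empty vector when it does not apply: more than one input, a non-tuple input, or requested result types that differ from the flattened element types.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy stand-in for mlir::TupleType and friends (NOT the MLIR
// API): a type named "tuple" has element types; any other name is a leaf.
struct ToyType {
  std::string name;
  std::vector<ToyType> elements;

  bool isTuple() const { return name == "tuple"; }
  bool operator==(const ToyType &other) const {
    return name == other.name && elements == other.elements;
  }
};

// A toy value: a type plus a printable expression describing how it was built.
struct ToyValue {
  ToyType type;
  std::string expr;
};

// Appends the leaf types of `type`, recursing into nested tuples (the role
// played by TupleType::getFlattenedTypes in the diff).
void flattenTypes(const ToyType &type, std::vector<ToyType> &out) {
  if (!type.isTuple()) {
    out.push_back(type);
    return;
  }
  for (const ToyType &element : type.elements)
    flattenTypes(element, out);
}

// Toy analogue of the test's target materialization. An empty result means
// "does not apply", so the caller keeps the original value.
std::vector<ToyValue> toyDecomposeTuple(const std::vector<ToyType> &resultTypes,
                                        const std::vector<ToyValue> &inputs) {
  if (inputs.size() != 1 || !inputs.front().type.isTuple())
    return {};
  std::vector<ToyType> flattened;
  flattenTypes(inputs.front().type, flattened);
  if (flattened != resultTypes)
    return {};

  std::vector<ToyValue> result;
  // Recursive lambda via std::function, as in the diff.
  std::function<void(const ToyValue &)> decompose = [&](const ToyValue &tuple) {
    if (!tuple.type.isTuple()) {
      result.push_back(tuple); // Leaf: emit as-is.
      return;
    }
    for (std::size_t i = 0; i < tuple.type.elements.size(); ++i) {
      // Stands in for creating a test.get_tuple_element op.
      ToyValue element{tuple.type.elements[i],
                       "get(" + tuple.expr + ", " + std::to_string(i) + ")"};
      decompose(element);
    }
  };
  decompose(inputs.front());
  // The flattened-types check above guarantees one value per requested type.
  assert(result.size() == resultTypes.size());
  return result;
}
```

For a value `t` of type `tuple<i1, tuple<i32, f32>>` and requested types `i1, i32, f32`, the sketch yields `get(t, 0)`, `get(get(t, 1), 0)` and `get(get(t, 1), 1)`: one extraction per nesting level, in the same depth-first order as the flattened types.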
https://github.com/llvm/llvm-project/pull/114192