[Mlir-commits] [mlir] 6592b01 - [mlir][func] Add support for nested tuples to TestDecomposeCallGraphTypes.
Ingo Müller
llvmlistbot at llvm.org
Wed Feb 8 21:22:06 PST 2023
Author: Ingo Müller
Date: 2023-02-09T05:22:01Z
New Revision: 6592b010c1016f2305ffe4a83c244325c850fda9
URL: https://github.com/llvm/llvm-project/commit/6592b010c1016f2305ffe4a83c244325c850fda9
DIFF: https://github.com/llvm/llvm-project/commit/6592b010c1016f2305ffe4a83c244325c850fda9.diff
LOG: [mlir][func] Add support for nested tuples to TestDecomposeCallGraphTypes.
Nested tuples were only supported in some narrow edge cases (and
potentially only because the test ops like `test.make_tuple` aren't
properly verified). This patch adds a couple of test cases with nested
tuple types and makes them work in the test pass by extending the
argument materialization and decomposition functions.
Reviewed By: silvas
Differential Revision: https://reviews.llvm.org/D143579
Added:
Modified:
mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
mlir/test/Transforms/decompose-call-graph-types.mlir
mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
index 6f27cbb3a59fe..29bab1dec8638 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
+++ b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
@@ -53,8 +53,8 @@ class ValueDecomposer {
/// This method registers a callback function that will be called to decompose
/// a value of a certain type into 0, 1, or multiple values.
- template <typename FnT,
- typename T = typename llvm::function_traits<FnT>::template arg_t<2>>
+ template <typename FnT, typename T = typename llvm::function_traits<
+ std::decay_t<FnT>>::template arg_t<2>>
void addDecomposeValueConversion(FnT &&callback) {
decomposeValueConversions.emplace_back(
wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir
index 5ecbad131504d..604e948afaf6e 100644
--- a/mlir/test/Transforms/decompose-call-graph-types.mlir
+++ b/mlir/test/Transforms/decompose-call-graph-types.mlir
@@ -37,6 +37,29 @@ func.func @recursive_decomposition(%arg0: tuple<tuple<tuple<i1>>>) -> tuple<tupl
// -----
+// Test case: Type that needs to be recursively decomposed at different recursion depths.
+
+// CHECK-LABEL: func @mixed_recursive_decomposition(
+// CHECK-SAME: %[[ARG0:.*]]: i1,
+// CHECK-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) {
+// CHECK: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
+// CHECK: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]]) : (i1) -> tuple<i1>
+// CHECK: %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
+// CHECK: %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple<i2>) -> tuple<tuple<i2>>
+// CHECK: %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple<i1>, tuple<tuple<i2>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
+// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<>
+// CHECK: %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 1 : i32} : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<i1>
+// CHECK: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<i1>) -> i1
+// CHECK: %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 2 : i32} : (tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
+// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
+// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) {index = 0 : i32} : (tuple<i2>) -> i2
+// CHECK: return %[[V7]], %[[V10]] : i1, i2
+// The body only returns its argument, so the CHECK lines above verify that the
+// tuple argument is rebuilt from the decomposed block arguments (the empty
+// `tuple<>` contributes no block argument and is rebuilt from no operands) and
+// is then decomposed again, at every nesting depth, for the return.
+func.func @mixed_recursive_decomposition(%arg0: tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>) -> tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>> {
+ return %arg0 : tuple<tuple<>, tuple<i1>, tuple<tuple<i2>>>
+}
+
+// -----
+
// Test case: Check decomposition of calls.
// CHECK-LABEL: func private @callee(i1, i32) -> (i1, i32)
@@ -89,6 +112,26 @@ func.func @unconverted_op_result() -> tuple<i1, i32> {
// -----
+// Test case: Ensure decompositions are inserted properly around results of
+// unconverted ops in the case of different nesting levels.
+
+// CHECK-LABEL: func @nested_unconverted_op_result(
+// CHECK-SAME: %[[ARG0:.*]]: i1,
+// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) {
+// CHECK: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple<i32>
+// CHECK: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple<i32>) -> tuple<i1, tuple<i32>>
+// CHECK: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>>
+// CHECK: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple<i1, tuple<i32>>) -> i1
+// CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple<i1, tuple<i32>>) -> tuple<i32>
+// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple<i32>) -> i32
+// CHECK: return %[[V3]], %[[V5]] : i1, i32
+// `test.op` is left unconverted by the pass, so its nested-tuple operand has
+// to be rebuilt from the decomposed arguments and its nested-tuple result has
+// to be decomposed again before being returned (see the CHECK lines above).
+func.func @nested_unconverted_op_result(%arg: tuple<i1, tuple<i32>>) -> tuple<i1, tuple<i32>> {
+ %0 = "test.op"(%arg) : (tuple<i1, tuple<i32>>) -> (tuple<i1, tuple<i32>>)
+ return %0 : tuple<i1, tuple<i32>>
+}
+
+// -----
+
// Test case: Check mixed decomposed and non-decomposed args.
// This makes sure to test the cases if 1:0, 1:1, and 1:N decompositions.
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 9492d23b1778f..41e166600c433 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -16,6 +16,70 @@
using namespace mlir;
namespace {
+/// Value decomposition callback: extracts every element of the tuple `value`
+/// (of type `resultType`) with a `test.get_tuple_element` op and appends the
+/// extracted values to `values`. Elements that are themselves tuples are
+/// unpacked recursively instead of being appended, so the appended values have,
+/// in order, the types of `resultType.getFlattenedTypes()`.
+///
+/// Returns success(); failure is only ever propagated from the recursive calls,
+/// which as written never fail either.
+static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc,
+                                         TupleType resultType, Value value,
+                                         SmallVectorImpl<Value> &values) {
+  const unsigned numElements = resultType.size();
+  for (unsigned idx = 0; idx != numElements; ++idx) {
+    Type elemTy = resultType.getType(idx);
+    Value extracted = builder.create<test::GetTupleElementOp>(
+        loc, elemTy, value, builder.getI32IntegerAttr(idx));
+
+    auto innerTuple = elemTy.dyn_cast<TupleType>();
+    if (!innerTuple) {
+      // Leaf element: it is already one of the flattened values.
+      values.push_back(extracted);
+      continue;
+    }
+
+    // Nested tuple: splice its (recursively flattened) elements in place of
+    // the tuple value itself.
+    LogicalResult nested =
+        buildDecomposeTuple(builder, loc, innerTuple, extracted, values);
+    if (failed(nested))
+      return failure();
+  }
+  return success();
+}
+
+/// Recursive worker for `buildMakeTupleOp`: builds a `test.make_tuple` op of
+/// type `tupleType`. It consumes one value from `inputIt` for every non-tuple
+/// element, in depth-first, left-to-right order, and builds each nested tuple
+/// with its own `test.make_tuple` op. `inputIt` is advanced past every
+/// consumed value. The caller must guarantee that enough values of the
+/// matching types remain.
+static Value buildMakeTupleOpImpl(OpBuilder &builder, TupleType tupleType,
+                                  ValueRange::iterator &inputIt,
+                                  Location loc) {
+  SmallVector<Value> elements;
+  elements.reserve(tupleType.size());
+  for (Type elementType : tupleType.getTypes()) {
+    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+      // The recursive call consumes exactly as many inputs as the nested
+      // tuple has leaves, so no up-front counting per level is needed.
+      elements.push_back(
+          buildMakeTupleOpImpl(builder, nestedTupleType, inputIt, loc));
+    } else {
+      // Leaf element: take the next input as is.
+      elements.push_back(*inputIt++);
+    }
+  }
+  return builder.create<test::MakeTupleOp>(loc, tupleType, elements);
+}
+
+/// Argument materialization: creates a `test.make_tuple` op out of the given
+/// inputs, building a tuple of type `resultType`. If that type is nested, each
+/// nested tuple is built recursively with another `test.make_tuple` op.
+///
+/// `inputs` must be the leaf values of `resultType` in order, i.e. their types
+/// must equal `resultType.getFlattenedTypes()`. Otherwise no IR is created and
+/// std::nullopt is returned. Without this check, a short `inputs` range would
+/// be read past its end and mismatched types would produce an ill-typed tuple.
+/// (Per the TypeConverter materialization contract, std::nullopt signals that
+/// this callback does not apply, whereas a null Value would signal failure.)
+static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
+                                             TupleType resultType,
+                                             ValueRange inputs, Location loc) {
+  SmallVector<Type> flattenedTypes;
+  resultType.getFlattenedTypes(flattenedTypes);
+  if (inputs.size() != flattenedTypes.size())
+    return std::nullopt;
+  for (size_t i = 0, e = flattenedTypes.size(); i < e; ++i) {
+    if (inputs[i].getType() != flattenedTypes[i])
+      return std::nullopt;
+  }
+
+  ValueRange::iterator inputIt = inputs.begin();
+  return buildMakeTupleOpImpl(builder, resultType, inputIt, loc);
+}
+
/// A pass for testing call graph type decomposition.
///
/// This instantiates the patterns with a TypeConverter and ValueDecomposer
@@ -39,7 +103,6 @@ struct TestDecomposeCallGraphTypes
auto *context = &getContext();
TypeConverter typeConverter;
ConversionTarget target(*context);
- ValueDecomposer decomposer;
RewritePatternSet patterns(context);
target.addLegalDialect<test::TestDialect>();
@@ -59,27 +122,10 @@ struct TestDecomposeCallGraphTypes
tupleType.getFlattenedTypes(types);
return success();
});
+ typeConverter.addArgumentMaterialization(buildMakeTupleOp);
- decomposer.addDecomposeValueConversion([](OpBuilder &builder, Location loc,
- TupleType resultType, Value value,
- SmallVectorImpl<Value> &values) {
- for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
- Value res = builder.create<test::GetTupleElementOp>(
- loc, resultType.getType(i), value, builder.getI32IntegerAttr(i));
- values.push_back(res);
- }
- return success();
- });
-
- typeConverter.addArgumentMaterialization(
- [](OpBuilder &builder, TupleType resultType, ValueRange inputs,
- Location loc) -> std::optional<Value> {
- if (inputs.size() == 1)
- return std::nullopt;
- TupleType tuple = builder.getTupleType(inputs.getTypes());
- Value value = builder.create<test::MakeTupleOp>(loc, tuple, inputs);
- return value;
- });
+ ValueDecomposer decomposer;
+ decomposer.addDecomposeValueConversion(buildDecomposeTuple);
populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer,
patterns);
More information about the Mlir-commits
mailing list