[Mlir-commits] [mlir] [mlir][Func] Support 1:N result type conversions in `func.call` conversion (PR #117413)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 22 20:48:51 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit adds support for 1:N result type conversions for `func.call` ops. In that case, argument materializations back to the original result types are inserted (via `replaceOpWithMultiple`).
This commit is in preparation for merging the 1:1 and 1:N conversion drivers.
---
Full diff: https://github.com/llvm/llvm-project/pull/117413.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp (+24-11)
- (modified) mlir/test/Transforms/test-legalizer.mlir (+22)
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+12)
``````````diff
diff --git a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
index eb444d665ff260..b1cde6ca5d2fca 100644
--- a/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/FuncConversions.cpp
@@ -23,21 +23,40 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
+ /// Replaces `callOp` with a new call whose results have the converted types.
+ /// Each original result may convert into zero, one or several types and is
+ /// replaced with the matching slice of the new call's results. Fails if any
+ /// result type cannot be converted.
LogicalResult
matchAndRewrite(CallOp callOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // Convert the original function results.
+ // Convert the original function results. Keep track of how many result
+ // types each original result type is converted into.
+ SmallVector<size_t> numResultsReplacements;
SmallVector<Type, 1> convertedResults;
- if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
- convertedResults)))
- return failure();
-
- // If this isn't a one-to-one type mapping, we don't know how to aggregate
- // the results.
- if (callOp->getNumResults() != convertedResults.size())
- return failure();
+ size_t numFlattenedResults = 0;
+ for (Type type : callOp.getResultTypes()) {
+ if (failed(typeConverter->convertTypes(type, convertedResults)))
+ return failure();
+ numResultsReplacements.push_back(convertedResults.size() -
+ numFlattenedResults);
+ numFlattenedResults = convertedResults.size();
+ }
// Substitute with the new result types from the corresponding FuncType
// conversion.
- rewriter.replaceOpWithNewOp<CallOp>(
- callOp, callOp.getCallee(), convertedResults, adaptor.getOperands());
+ auto newCallOp =
+ rewriter.create<CallOp>(callOp.getLoc(), callOp.getCallee(),
+ convertedResults, adaptor.getOperands());
+ // The new call's results are flattened; slice them back into one group per
+ // original result, in order.
+ SmallVector<ValueRange> replacements;
+ size_t offset = 0;
+ for (size_t numResults : numResultsReplacements) {
+ replacements.push_back(
+ newCallOp->getResults().slice(offset, numResults));
+ offset += numResults;
+ }
+ assert(offset == convertedResults.size() &&
+ "expected that all converted results are used");
+ rewriter.replaceOpWithMultiple(callOp, replacements);
return success();
}
};
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index e5503ee8920424..7c6e3c5c3a6c53 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -463,3 +463,25 @@ func.func @circular_mapping() {
%0 = "test.erase_op"() : () -> (i64)
"test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
}
+
+// -----
+
+module {
+// CHECK-LABEL: func.func private @foo() -> (i23, i23)
+func.func private @foo() -> (i22, i24)
+
+// CHECK: func.func @bar()
+func.func @bar() {
+ // Exercises a 1:N result conversion: i22 is converted to (i23, i23).
+ // i24 is converted to () (dropped); both are rebuilt for users via test.cast.
+ // CHECK: %[[call:.*]]:2 = call @foo() : () -> (i23, i23)
+ %0:2 = func.call @foo() : () -> (i22, i24)
+
+ // CHECK: %[[cast1:.*]] = "test.cast"() : () -> i24
+ // CHECK: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (i23, i23) -> i22
+ // CHECK: "test.some_user"(%[[cast0]], %[[cast1]]) : (i22, i24) -> ()
+ // expected-remark @below{{'test.some_user' is not legalizable}}
+ "test.some_user"(%0#0, %0#1) : (i22, i24) -> ()
+ "test.return"() : () -> ()
+}
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 3df6cff3c0a60b..912173f391086e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1215,6 +1215,18 @@ struct TestTypeConverter : public TypeConverter {
return success();
}
+ // Convert I22 to multiple I23.
+ if (t.isInteger(22)) {
+ results.push_back(IntegerType::get(t.getContext(), 23));
+ results.push_back(IntegerType::get(t.getContext(), 23));
+ return success();
+ }
+
+ // Drop I24 types.
+ if (t.isInteger(24)) {
+ return success();
+ }
+
// Otherwise, convert the type directly.
results.push_back(t);
return success();
``````````
</details>
https://github.com/llvm/llvm-project/pull/117413
More information about the Mlir-commits mailing list