[Mlir-commits] [mlir] a2590e0 - [mlir][Transforms] Make 1:N function conversion pattern interface-based (#92395)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 17 04:44:04 PDT 2024
Author: Matthias Springer
Date: 2024-05-17T13:44:00+02:00
New Revision: a2590e0c145c56928a8870d9a6ea76ccbf4fcfeb
URL: https://github.com/llvm/llvm-project/commit/a2590e0c145c56928a8870d9a6ea76ccbf4fcfeb
DIFF: https://github.com/llvm/llvm-project/commit/a2590e0c145c56928a8870d9a6ea76ccbf4fcfeb.diff
LOG: [mlir][Transforms] Make 1:N function conversion pattern interface-based (#92395)
This commit turns the 1:N dialect conversion pattern for function
signatures into a pattern for `FunctionOpInterface`. This is similar to
the interface-based pattern that is provided with the 1:1 dialect
conversion (`populateFunctionOpInterfaceTypeConversionPattern`). No
change in functionality apart from supporting all `FunctionOpInterface`
ops and not just `func::FuncOp`.
Added:
Modified:
mlir/include/mlir/Transforms/OneToNTypeConversion.h
mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index 933961814cbe4..4c689ba219e88 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -297,6 +297,20 @@ LogicalResult
applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
const FrozenRewritePatternSet &patterns);
+/// Add a pattern to the given pattern list to convert the signature of a
+/// FunctionOpInterface op with the given type converter. This only supports
+/// ops which use FunctionType to represent their type. This is intended to be
+/// used with the 1:N dialect conversion.
+void populateOneToNFunctionOpInterfaceTypeConversionPattern(
+ StringRef functionLikeOpName, TypeConverter &converter,
+ RewritePatternSet &patterns);
+template <typename FuncOpT>
+void populateOneToNFunctionOpInterfaceTypeConversionPattern(
+ TypeConverter &converter, RewritePatternSet &patterns) {
+ populateOneToNFunctionOpInterfaceTypeConversionPattern(
+ FuncOpT::getOperationName(), converter, patterns);
+}
+
} // namespace mlir
#endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
diff --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
index a5b88338e6381..8489396da7a2c 100644
--- a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
@@ -49,50 +49,6 @@ class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
}
};
-class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> {
-public:
- using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern;
-
- LogicalResult
- matchAndRewrite(FuncOp op, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
- auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
-
- // Construct mapping for function arguments.
- OneToNTypeMapping argumentMapping(op.getArgumentTypes());
- if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(),
- argumentMapping)))
- return failure();
-
- // Construct mapping for function results.
- OneToNTypeMapping funcResultMapping(op.getResultTypes());
- if (failed(typeConverter->computeTypeMapping(op.getResultTypes(),
- funcResultMapping)))
- return failure();
-
- // Nothing to do if the op doesn't have any non-identity conversions for its
- // operands or results.
- if (!argumentMapping.hasNonIdentityConversion() &&
- !funcResultMapping.hasNonIdentityConversion())
- return failure();
-
- // Update the function signature in-place.
- auto newType = FunctionType::get(rewriter.getContext(),
- argumentMapping.getConvertedTypes(),
- funcResultMapping.getConvertedTypes());
- rewriter.modifyOpInPlace(op, [&] { op.setType(newType); });
-
- // Update block signatures.
- if (!op.isExternal()) {
- Region *region = &op.getBody();
- Block *block = ®ion->front();
- rewriter.applySignatureConversion(block, argumentMapping);
- }
-
- return success();
- }
-};
-
class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
public:
using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
@@ -121,10 +77,11 @@ void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
patterns.add<
// clang-format off
ConvertTypesInFuncCallOp,
- ConvertTypesInFuncFuncOp,
ConvertTypesInFuncReturnOp
// clang-format on
>(typeConverter, patterns.getContext());
+ populateOneToNFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+ typeConverter, patterns);
}
} // namespace mlir
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index fef9d8eb0fef7..f6e8e9e7ad339 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -8,6 +8,7 @@
#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
@@ -412,4 +413,62 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
return success();
}
+namespace {
+class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
+public:
+ FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
+ MLIRContext *ctx,
+ TypeConverter &converter)
+ : OneToNConversionPattern(converter, functionLikeOpName, /*benefit=*/1,
+ ctx) {}
+
+ LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
+ const OneToNTypeMapping &operandMapping,
+ const OneToNTypeMapping &resultMapping,
+ ValueRange convertedOperands) const override {
+ auto funcOp = cast<FunctionOpInterface>(op);
+ auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+
+ // Construct mapping for function arguments.
+ OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
+ if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
+ argumentMapping)))
+ return failure();
+
+ // Construct mapping for function results.
+ OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
+ if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
+ funcResultMapping)))
+ return failure();
+
+ // Nothing to do if the op doesn't have any non-identity conversions for its
+ // operands or results.
+ if (!argumentMapping.hasNonIdentityConversion() &&
+ !funcResultMapping.hasNonIdentityConversion())
+ return failure();
+
+ // Update the function signature in-place.
+ auto newType = FunctionType::get(rewriter.getContext(),
+ argumentMapping.getConvertedTypes(),
+ funcResultMapping.getConvertedTypes());
+ rewriter.modifyOpInPlace(op, [&] { funcOp.setType(newType); });
+
+ // Update block signatures.
+ if (!funcOp.isExternal()) {
+ Region *region = &funcOp.getFunctionBody();
+ Block *block = ®ion->front();
+ rewriter.applySignatureConversion(block, argumentMapping);
+ }
+
+ return success();
+ }
+};
+} // namespace
+
+void populateOneToNFunctionOpInterfaceTypeConversionPattern(
+ StringRef functionLikeOpName, TypeConverter &converter,
+ RewritePatternSet &patterns) {
+ patterns.add<FunctionOpInterfaceSignatureConversion>(
+ functionLikeOpName, patterns.getContext(), converter);
+}
} // namespace mlir
More information about the Mlir-commits
mailing list