[Mlir-commits] [mlir] ed4749f - [mlir] Add `populateFunctionOpInterfaceTypeConversionPattern` version which operates on any `FunctionOpInterface`
Ivan Butygin
llvmlistbot at llvm.org
Sat Nov 5 04:22:12 PDT 2022
Author: Ivan Butygin
Date: 2022-11-05T12:10:36+01:00
New Revision: ed4749f9373d0079a69e947486aa29042d606458
URL: https://github.com/llvm/llvm-project/commit/ed4749f9373d0079a69e947486aa29042d606458
DIFF: https://github.com/llvm/llvm-project/commit/ed4749f9373d0079a69e947486aa29042d606458.diff
LOG: [mlir] Add `populateFunctionOpInterfaceTypeConversionPattern` version which operates on any `FunctionOpInterface`
Existing version is always limited to some specific op.
Differential Revision: https://reviews.llvm.org/D137469
Added:
Modified:
mlir/include/mlir/Transforms/DialectConversion.h
mlir/lib/Transforms/Utils/DialectConversion.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 061edb196f0fc..6045b2237976e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -507,6 +507,9 @@ void populateFunctionOpInterfaceTypeConversionPattern(
patterns, converter);
}
+void populateAnyFunctionOpInterfaceTypeConversionPattern(
+ RewritePatternSet &patterns, TypeConverter &converter);
+
//===----------------------------------------------------------------------===//
// Conversion PatternRewriter
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 505127c459656..61bc4ffbe6f28 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3056,6 +3056,29 @@ auto TypeConverter::convertBlockSignature(Block *block)
// FunctionOpInterfaceSignatureConversion
//===----------------------------------------------------------------------===//
+static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
+ TypeConverter &typeConverter,
+ ConversionPatternRewriter &rewriter) {
+ FunctionType type = funcOp.getFunctionType().cast<FunctionType>();
+
+ // Convert the original function types.
+ TypeConverter::SignatureConversion result(type.getNumInputs());
+ SmallVector<Type, 1> newResults;
+ if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
+ failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
+ failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
+ typeConverter, &result)))
+ return failure();
+
+ // Update the function signature in-place.
+ auto newType = FunctionType::get(rewriter.getContext(),
+ result.getConvertedTypes(), newResults);
+
+ rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); });
+
+ return success();
+}
+
/// Create a default conversion pattern that rewrites the type signature of a
/// FunctionOpInterface op. This only supports ops which use FunctionType to
/// represent their type.
@@ -3067,27 +3090,21 @@ struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
: ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
LogicalResult
- matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
ConversionPatternRewriter &rewriter) const override {
FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
- FunctionType type = funcOp.getFunctionType().cast<FunctionType>();
-
- // Convert the original function types.
- TypeConverter::SignatureConversion result(type.getNumInputs());
- SmallVector<Type, 1> newResults;
- if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
- failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
- failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
- *typeConverter, &result)))
- return failure();
-
- // Update the function signature in-place.
- auto newType = FunctionType::get(rewriter.getContext(),
- result.getConvertedTypes(), newResults);
+ return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
+ }
+};
- rewriter.updateRootInPlace(op, [&] { funcOp.setType(newType); });
+struct AnyFunctionOpInterfaceSignatureConversion
+ : public OpInterfaceConversionPattern<FunctionOpInterface> {
+ using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
- return success();
+ LogicalResult
+ matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
+ ConversionPatternRewriter &rewriter) const override {
+ return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
}
};
} // namespace
@@ -3099,6 +3116,12 @@ void mlir::populateFunctionOpInterfaceTypeConversionPattern(
functionLikeOpName, patterns.getContext(), converter);
}
+void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
+ RewritePatternSet &patterns, TypeConverter &converter) {
+ patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
+ converter, patterns.getContext());
+}
+
//===----------------------------------------------------------------------===//
// ConversionTarget
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 17c8c1f84d35d..12f374777936c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -786,8 +786,8 @@ struct TestLegalizePatternDriver
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
TestCreateUnregisteredOp>(&getContext());
patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
- mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
- patterns, converter);
+ mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
+ converter);
mlir::populateCallOpTypeConversionPattern(patterns, converter);
// Define the conversion target used for the test.
@@ -1313,8 +1313,8 @@ struct TestTypeConversionDriver
TestTestSignatureConversionNoConverter>(converter,
&getContext());
patterns.add<TestTypeConversionAnotherProducer>(&getContext());
- mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
- patterns, converter);
+ mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
+ converter);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
More information about the Mlir-commits
mailing list