[Mlir-commits] [mlir] ed4749f - [mlir] Add `populateFunctionOpInterfaceTypeConversionPattern` version which operates on any `FunctionOpInterface`

Ivan Butygin llvmlistbot at llvm.org
Sat Nov 5 04:22:12 PDT 2022


Author: Ivan Butygin
Date: 2022-11-05T12:10:36+01:00
New Revision: ed4749f9373d0079a69e947486aa29042d606458

URL: https://github.com/llvm/llvm-project/commit/ed4749f9373d0079a69e947486aa29042d606458
DIFF: https://github.com/llvm/llvm-project/commit/ed4749f9373d0079a69e947486aa29042d606458.diff

LOG: [mlir] Add `populateFunctionOpInterfaceTypeConversionPattern` version which operates on any `FunctionOpInterface`

The existing version is always limited to a single, specific op type; the new version matches any op that implements `FunctionOpInterface`.

Differential Revision: https://reviews.llvm.org/D137469

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 061edb196f0fc..6045b2237976e 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -507,6 +507,9 @@ void populateFunctionOpInterfaceTypeConversionPattern(
                                                    patterns, converter);
 }
 
+void populateAnyFunctionOpInterfaceTypeConversionPattern(
+    RewritePatternSet &patterns, TypeConverter &converter);
+
 //===----------------------------------------------------------------------===//
 // Conversion PatternRewriter
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 505127c459656..61bc4ffbe6f28 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3056,6 +3056,29 @@ auto TypeConverter::convertBlockSignature(Block *block)
 // FunctionOpInterfaceSignatureConversion
 //===----------------------------------------------------------------------===//
 
+static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
+                                        TypeConverter &typeConverter,
+                                        ConversionPatternRewriter &rewriter) {
+  FunctionType type = funcOp.getFunctionType().cast<FunctionType>();
+
+  // Convert the original function types.
+  TypeConverter::SignatureConversion result(type.getNumInputs());
+  SmallVector<Type, 1> newResults;
+  if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
+      failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
+      failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
+                                         typeConverter, &result)))
+    return failure();
+
+  // Update the function signature in-place.
+  auto newType = FunctionType::get(rewriter.getContext(),
+                                   result.getConvertedTypes(), newResults);
+
+  rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); });
+
+  return success();
+}
+
 /// Create a default conversion pattern that rewrites the type signature of a
 /// FunctionOpInterface op. This only supports ops which use FunctionType to
 /// represent their type.
@@ -3067,27 +3090,21 @@ struct FunctionOpInterfaceSignatureConversion : public ConversionPattern {
       : ConversionPattern(converter, functionLikeOpName, /*benefit=*/1, ctx) {}
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
                   ConversionPatternRewriter &rewriter) const override {
     FunctionOpInterface funcOp = cast<FunctionOpInterface>(op);
-    FunctionType type = funcOp.getFunctionType().cast<FunctionType>();
-
-    // Convert the original function types.
-    TypeConverter::SignatureConversion result(type.getNumInputs());
-    SmallVector<Type, 1> newResults;
-    if (failed(typeConverter->convertSignatureArgs(type.getInputs(), result)) ||
-        failed(typeConverter->convertTypes(type.getResults(), newResults)) ||
-        failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
-                                           *typeConverter, &result)))
-      return failure();
-
-    // Update the function signature in-place.
-    auto newType = FunctionType::get(rewriter.getContext(),
-                                     result.getConvertedTypes(), newResults);
+    return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
+  }
+};
 
-    rewriter.updateRootInPlace(op, [&] { funcOp.setType(newType); });
+struct AnyFunctionOpInterfaceSignatureConversion
+    : public OpInterfaceConversionPattern<FunctionOpInterface> {
+  using OpInterfaceConversionPattern::OpInterfaceConversionPattern;
 
-    return success();
+  LogicalResult
+  matchAndRewrite(FunctionOpInterface funcOp, ArrayRef<Value> /*operands*/,
+                  ConversionPatternRewriter &rewriter) const override {
+    return convertFuncOpTypes(funcOp, *typeConverter, rewriter);
   }
 };
 } // namespace
@@ -3099,6 +3116,12 @@ void mlir::populateFunctionOpInterfaceTypeConversionPattern(
       functionLikeOpName, patterns.getContext(), converter);
 }
 
+void mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(
+    RewritePatternSet &patterns, TypeConverter &converter) {
+  patterns.add<AnyFunctionOpInterfaceSignatureConversion>(
+      converter, patterns.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // ConversionTarget
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 17c8c1f84d35d..12f374777936c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -786,8 +786,8 @@ struct TestLegalizePatternDriver
              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
              TestCreateUnregisteredOp>(&getContext());
     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
-    mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
-        patterns, converter);
+    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
+                                                              converter);
     mlir::populateCallOpTypeConversionPattern(patterns, converter);
 
     // Define the conversion target used for the test.
@@ -1313,8 +1313,8 @@ struct TestTypeConversionDriver
                  TestTestSignatureConversionNoConverter>(converter,
                                                          &getContext());
     patterns.add<TestTypeConversionAnotherProducer>(&getContext());
-    mlir::populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
-        patterns, converter);
+    mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
+                                                              converter);
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))


        


More information about the Mlir-commits mailing list