[Mlir-commits] [mlir] [mlir][Transforms] Make 1:N function conversion pattern interface-based (PR #92395)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 16 05:43:43 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit turns the 1:N dialect conversion pattern for function signatures into a pattern for `FunctionOpInterface`. This is similar to the interface-based pattern that is provided with the 1:1 dialect conversion (`populateFunctionOpInterfaceTypeConversionPattern`). Apart from supporting all `FunctionOpInterface` ops that use `FunctionType` (rather than only `func::FuncOp`), there is no change in functionality.

---
Full diff: https://github.com/llvm/llvm-project/pull/92395.diff


3 Files Affected:

- (modified) mlir/include/mlir/Transforms/OneToNTypeConversion.h (+14) 
- (modified) mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp (+2-45) 
- (modified) mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp (+59) 


``````````diff
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
index 933961814cbe4..4c689ba219e88 100644
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -297,6 +297,24 @@ LogicalResult
 applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
                              const FrozenRewritePatternSet &patterns);
 
+/// Add a pattern to the given pattern list to convert the signature of a
+/// FunctionOpInterface op with the given type converter. This only supports
+/// ops which use FunctionType to represent their type. This is intended to be
+/// used with the 1:N dialect conversion. The pattern is rooted at ops named
+/// `functionLikeOpName`, which must implement FunctionOpInterface.
+void populateOneToNFunctionOpInterfaceTypeConversionPattern(
+    StringRef functionLikeOpName, TypeConverter &converter,
+    RewritePatternSet &patterns);
+
+/// Convenience overload that roots the pattern at the operation name of
+/// `FuncOpT` (for example, `func::FuncOp`).
+template <typename FuncOpT>
+void populateOneToNFunctionOpInterfaceTypeConversionPattern(
+    TypeConverter &converter, RewritePatternSet &patterns) {
+  populateOneToNFunctionOpInterfaceTypeConversionPattern(
+      FuncOpT::getOperationName(), converter, patterns);
+}
+
 } // namespace mlir
 
 #endif // MLIR_TRANSFORMS_ONETONTYPECONVERSION_H
diff --git a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
index a5b88338e6381..8489396da7a2c 100644
--- a/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/OneToNFuncConversions.cpp
@@ -49,50 +49,6 @@ class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
   }
 };
 
-class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> {
-public:
-  using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(FuncOp op, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
-    auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
-
-    // Construct mapping for function arguments.
-    OneToNTypeMapping argumentMapping(op.getArgumentTypes());
-    if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(),
-                                                 argumentMapping)))
-      return failure();
-
-    // Construct mapping for function results.
-    OneToNTypeMapping funcResultMapping(op.getResultTypes());
-    if (failed(typeConverter->computeTypeMapping(op.getResultTypes(),
-                                                 funcResultMapping)))
-      return failure();
-
-    // Nothing to do if the op doesn't have any non-identity conversions for its
-    // operands or results.
-    if (!argumentMapping.hasNonIdentityConversion() &&
-        !funcResultMapping.hasNonIdentityConversion())
-      return failure();
-
-    // Update the function signature in-place.
-    auto newType = FunctionType::get(rewriter.getContext(),
-                                     argumentMapping.getConvertedTypes(),
-                                     funcResultMapping.getConvertedTypes());
-    rewriter.modifyOpInPlace(op, [&] { op.setType(newType); });
-
-    // Update block signatures.
-    if (!op.isExternal()) {
-      Region *region = &op.getBody();
-      Block *block = &region->front();
-      rewriter.applySignatureConversion(block, argumentMapping);
-    }
-
-    return success();
-  }
-};
-
 class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
 public:
   using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
@@ -121,10 +77,13 @@ void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
   patterns.add<
       // clang-format off
       ConvertTypesInFuncCallOp,
-      ConvertTypesInFuncFuncOp,
       ConvertTypesInFuncReturnOp
       // clang-format on
       >(typeConverter, patterns.getContext());
+  // func::FuncOp signatures are converted by the generic FunctionOpInterface
+  // signature pattern instead of a FuncOp-specific pattern.
+  populateOneToNFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
+      typeConverter, patterns);
 }
 
 } // namespace mlir
diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
index fef9d8eb0fef7..f6e8e9e7ad339 100644
--- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
+++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Transforms/OneToNTypeConversion.h"
 
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/SmallSet.h"
 
@@ -412,4 +413,78 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
   return success();
 }
 
+namespace {
+/// 1:N conversion pattern that converts the signature of a FunctionOpInterface
+/// op in place: its argument and result types are converted with the pattern's
+/// OneToNTypeConverter and the arguments of the entry block (if any) are
+/// remapped accordingly. Only ops whose function type is a FunctionType are
+/// rewritten; the pattern fails to match all other ops.
+class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern {
+public:
+  FunctionOpInterfaceSignatureConversion(StringRef functionLikeOpName,
+                                         MLIRContext *ctx,
+                                         TypeConverter &converter)
+      : OneToNConversionPattern(converter, functionLikeOpName, /*benefit=*/1,
+                                ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
+                                const OneToNTypeMapping &operandMapping,
+                                const OneToNTypeMapping &resultMapping,
+                                ValueRange convertedOperands) const override {
+    // This pattern is rooted at `functionLikeOpName`; that op is expected to
+    // implement FunctionOpInterface (the cast asserts otherwise).
+    auto funcOp = cast<FunctionOpInterface>(op);
+
+    // The new signature is built as a FunctionType below. Ops that model their
+    // type differently cannot be updated this way, so do not match them.
+    if (!isa<FunctionType>(funcOp.getFunctionType()))
+      return failure();
+
+    auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+
+    // Construct mapping for function arguments.
+    OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes());
+    if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(),
+                                                 argumentMapping)))
+      return failure();
+
+    // Construct mapping for function results.
+    OneToNTypeMapping funcResultMapping(funcOp.getResultTypes());
+    if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(),
+                                                 funcResultMapping)))
+      return failure();
+
+    // Nothing to do if the op doesn't have any non-identity conversions for its
+    // arguments or results.
+    if (!argumentMapping.hasNonIdentityConversion() &&
+        !funcResultMapping.hasNonIdentityConversion())
+      return failure();
+
+    // Update the function signature in-place.
+    auto newType = FunctionType::get(rewriter.getContext(),
+                                     argumentMapping.getConvertedTypes(),
+                                     funcResultMapping.getConvertedTypes());
+    rewriter.modifyOpInPlace(op, [&] { funcOp.setType(newType); });
+
+    // Update the entry block signature; only non-external functions have one.
+    if (!funcOp.isExternal()) {
+      Region *region = &funcOp.getFunctionBody();
+      Block *block = &region->front();
+      rewriter.applySignatureConversion(block, argumentMapping);
+    }
+
+    return success();
+  }
+};
+} // namespace
+
+/// Registers FunctionOpInterfaceSignatureConversion for ops named
+/// `functionLikeOpName`, using the given type converter.
+void populateOneToNFunctionOpInterfaceTypeConversionPattern(
+    StringRef functionLikeOpName, TypeConverter &converter,
+    RewritePatternSet &patterns) {
+  patterns.add<FunctionOpInterfaceSignatureConversion>(
+      functionLikeOpName, patterns.getContext(), converter);
+}
+
 } // namespace mlir

``````````

</details>


https://github.com/llvm/llvm-project/pull/92395


More information about the Mlir-commits mailing list