[Mlir-commits] [mlir] [mlir][emitc] Update the `WrapFuncInClassPass` pass (PR #179184)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 2 01:02:29 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-emitc
Author: Andrzej Warzyński (banach-space)
<details>
<summary>Changes</summary>
Update the `WrapFuncInClassPass` pass so that, by default, the generated
method is named `operator()` rather than `execute()`. This makes the
pass more generic, instead of catering to specific users expecting an
`execute()` method.
To preserve the original behaviour, add a new pass option, `func-name`,
that overrides the method name. For example, to keep generating `execute()`:
mlir-opt file.mlir -wrap-emitc-func-in-class=func-name=execute
Additionally, make a couple of small editorial changes:
* Rename `populateFuncPatterns` to `populateWrapFuncInClass` to make it
clear that the corresponding pattern is specific to the
`WrapFuncInClass` pass.
* Remove the `// CHECK: module {` lines (and the matching closing
  `// CHECK: }` lines) to reduce test noise.
For context, this change was proposed on Discourse:
* https://discourse.llvm.org/t/rfc-emitc-support-for-mlgo
---
Full diff: https://github.com/llvm/llvm-project/pull/179184.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td (+9-2)
- (modified) mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h (+5-2)
- (modified) mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp (+12-6)
- (modified) mlir/test/Dialect/EmitC/wrap-func-in-class.mlir (+9-6)
``````````diff
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index 1893c101e735b..0ca9a075bc44e 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -25,7 +25,8 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
let description = [{
This pass transforms `emitc.func` operations into `emitc.class` operations.
Function arguments become fields of the class, and the function body is moved
- to a new `execute` method within the class.
+ to a new member method within the class. By default, this is `operator()`.
+
If the corresponding function argument has attributes (accessed via `argAttrs`),
these attributes are attached to the field operation.
Otherwise, the field is created without additional attributes.
@@ -41,7 +42,7 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
// becomes
emitc.class @modelClass {
emitc.field @input_tensor : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}
- emitc.func @execute() {
+ emitc.func @operator() {
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
%1 = get_field @input_tensor : !emitc.array<1xf32>
%2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
@@ -51,6 +52,12 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
```
}];
let dependentDialects = ["emitc::EmitCDialect"];
+ let options = [
+ Option<"funcName", "func-name", "std::string",
+ /*default=*/[{"operator"}],
+ "The name of the newly generated member function with body "
+ "matching the input function.">
+ ];
}
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index bdf6d0985e6db..fd7c846a8ed39 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -28,8 +28,11 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
/// Populates `patterns` with expression-related patterns.
void populateExpressionPatterns(RewritePatternSet &patterns);
-/// Populates 'patterns' with func-related patterns.
-void populateFuncPatterns(RewritePatternSet &patterns);
+//===----------------------------------------------------------------------===//
+// The WrapFuncInClass pass.
+//===----------------------------------------------------------------------===//
+
+void populateWrapFuncInClass(RewritePatternSet &patterns, std::string &fName);
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index 06d7e07005f8a..d58dd06b9e1e0 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -31,7 +31,7 @@ struct WrapFuncInClassPass
Operation *rootOp = getOperation();
RewritePatternSet patterns(&getContext());
- populateFuncPatterns(patterns);
+ populateWrapFuncInClass(patterns, funcName);
walkAndApplyPatterns(rootOp, std::move(patterns));
}
@@ -43,8 +43,8 @@ struct WrapFuncInClassPass
class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
public:
- WrapFuncInClass(MLIRContext *context)
- : OpRewritePattern<emitc::FuncOp>(context) {}
+ WrapFuncInClass(MLIRContext *context, std::string &fName)
+ : OpRewritePattern<emitc::FuncOp>(context), funcName(fName) {}
LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
PatternRewriter &rewriter) const override {
@@ -76,7 +76,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
FunctionType funcType = funcOp.getFunctionType();
Location loc = funcOp.getLoc();
FuncOp newFuncOp =
- emitc::FuncOp::create(rewriter, loc, ("execute"), funcType);
+ emitc::FuncOp::create(rewriter, loc, (funcName), funcType);
rewriter.createBlock(&newFuncOp.getBody());
newFuncOp.getBody().takeBody(funcOp.getBody());
@@ -102,8 +102,14 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
rewriter.replaceOp(funcOp, newClassOp);
return success();
}
+
+private:
+ /// Name of the newly generated member function with body matching the input
+ /// function.
+ std::string funcName;
};
-void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) {
- patterns.add<WrapFuncInClass>(patterns.getContext());
+void mlir::emitc::populateWrapFuncInClass(RewritePatternSet &patterns,
+ std::string &fName) {
+ patterns.add<WrapFuncInClass>(patterns.getContext(), fName);
}
diff --git a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
index 809febd0267b1..b3a3f6a5ce7b7 100644
--- a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
@@ -1,20 +1,22 @@
// RUN: mlir-opt %s -wrap-emitc-func-in-class -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -wrap-emitc-func-in-class=func-name=execute -split-input-file | FileCheck %s --check-prefixes=EXECUTE
emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
emitc.return
}
-// CHECK: module {
// CHECK: emitc.class @fooClass {
// CHECK: emitc.field @fieldName0 : !emitc.array<1xf32>
-// CHECK: emitc.func @execute() {
+// CHECK: emitc.func @operator() {
// CHECK: %0 = get_field @fieldName0 : !emitc.array<1xf32>
// CHECK: call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> ()
// CHECK: return
// CHECK: }
// CHECK: }
-// CHECK: }
+
+// EXECUTE-NOT: operator()
+// EXECUTE: execute()
// -----
@@ -34,12 +36,11 @@ module attributes { } {
}
}
-// CHECK: module {
// CHECK: emitc.class @modelClass {
// CHECK: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
// CHECK: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"}
// CHECK: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"}
-// CHECK: emitc.func @execute() {
+// CHECK: emitc.func @operator() {
// CHECK: get_field @fieldName0 : !emitc.array<1xf32>
// CHECK: get_field @fieldName1 : !emitc.array<1xf32>
// CHECK: get_field @fieldName2 : !emitc.array<1xf32>
@@ -54,4 +55,6 @@ module attributes { } {
// CHECK: return
// CHECK: }
// CHECK: }
-// CHECK: }
+
+// EXECUTE-NOT: operator()
+// EXECUTE: execute()
``````````
</details>
https://github.com/llvm/llvm-project/pull/179184
More information about the Mlir-commits mailing list