[Mlir-commits] [mlir] [mlir][emitc] Update the `WrapFuncInClassPass` pass (PR #179184)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 2 01:02:29 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-emitc

Author: Andrzej Warzyński (banach-space)

<details>
<summary>Changes</summary>

Update the `WrapFuncInClassPass` pass so that, by default, the generated
method is named `operator()` rather than `execute()`. This makes the
pass more generic, instead of catering to specific users expecting an
`execute()` method.

To preserve the original behaviour, add a new pass option to override
the method name: `func-name`. For example:

  mlir-opt file.mlir -wrap-emitc-func-in-class=func-name=execute

Additionally, make a couple of small editorial changes:
  * Rename `populateFuncPatterns` to `populateWrapFuncInClass` to make it
    clear that the corresponding pattern is specific to the
    `WrapFuncInClass` pass.
  * Remove `// CHECK: module {` to reduce test noise.

For context, this change was proposed on Discourse:
  * https://discourse.llvm.org/t/rfc-emitc-support-for-mlgo


---
Full diff: https://github.com/llvm/llvm-project/pull/179184.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td (+9-2) 
- (modified) mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h (+5-2) 
- (modified) mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp (+12-6) 
- (modified) mlir/test/Dialect/EmitC/wrap-func-in-class.mlir (+9-6) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index 1893c101e735b..0ca9a075bc44e 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -25,7 +25,8 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
   let description = [{
     This pass transforms `emitc.func` operations into `emitc.class` operations.
     Function arguments become fields of the class, and the function body is moved
-    to a new `execute` method within the class.
+    to a new member method within the class. By default, this is `operator()`.
+
     If the corresponding function argument has attributes (accessed via `argAttrs`), 
     these attributes are attached to the field operation. 
     Otherwise, the field is created without additional attributes.
@@ -41,7 +42,7 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
     // becomes 
     emitc.class @modelClass {
       emitc.field @input_tensor : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}
-      emitc.func @execute() {
+      emitc.func @operator() {
         %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
         %1 = get_field @input_tensor : !emitc.array<1xf32>
         %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
@@ -51,6 +52,12 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
     ```
   }];
   let dependentDialects = ["emitc::EmitCDialect"];
+  let options = [
+      Option<"funcName", "func-name", "std::string",
+      /*default=*/[{"operator"}],
+      "The name of the newly generated member function with body "
+      "matching the input function.">
+  ];
 }
 
 #endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index bdf6d0985e6db..fd7c846a8ed39 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -28,8 +28,11 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
 /// Populates `patterns` with expression-related patterns.
 void populateExpressionPatterns(RewritePatternSet &patterns);
 
-/// Populates 'patterns' with func-related patterns.
-void populateFuncPatterns(RewritePatternSet &patterns);
+//===----------------------------------------------------------------------===//
+// The WrapFuncInClass pass.
+//===----------------------------------------------------------------------===//
+
+void populateWrapFuncInClass(RewritePatternSet &patterns, std::string &fName);
 
 } // namespace emitc
 } // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index 06d7e07005f8a..d58dd06b9e1e0 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -31,7 +31,7 @@ struct WrapFuncInClassPass
     Operation *rootOp = getOperation();
 
     RewritePatternSet patterns(&getContext());
-    populateFuncPatterns(patterns);
+    populateWrapFuncInClass(patterns, funcName);
 
     walkAndApplyPatterns(rootOp, std::move(patterns));
   }
@@ -43,8 +43,8 @@ struct WrapFuncInClassPass
 
 class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
 public:
-  WrapFuncInClass(MLIRContext *context)
-      : OpRewritePattern<emitc::FuncOp>(context) {}
+  WrapFuncInClass(MLIRContext *context, std::string &fName)
+      : OpRewritePattern<emitc::FuncOp>(context), funcName(fName) {}
 
   LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
@@ -76,7 +76,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
     FunctionType funcType = funcOp.getFunctionType();
     Location loc = funcOp.getLoc();
     FuncOp newFuncOp =
-        emitc::FuncOp::create(rewriter, loc, ("execute"), funcType);
+        emitc::FuncOp::create(rewriter, loc, (funcName), funcType);
 
     rewriter.createBlock(&newFuncOp.getBody());
     newFuncOp.getBody().takeBody(funcOp.getBody());
@@ -102,8 +102,14 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
     rewriter.replaceOp(funcOp, newClassOp);
     return success();
   }
+
+private:
+  /// Name of the newly generated member function with body matching the input
+  /// function.
+  std::string funcName;
 };
 
-void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) {
-  patterns.add<WrapFuncInClass>(patterns.getContext());
+void mlir::emitc::populateWrapFuncInClass(RewritePatternSet &patterns,
+                                          std::string &fName) {
+  patterns.add<WrapFuncInClass>(patterns.getContext(), fName);
 }
diff --git a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
index 809febd0267b1..b3a3f6a5ce7b7 100644
--- a/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap-func-in-class.mlir
@@ -1,20 +1,22 @@
 // RUN: mlir-opt %s -wrap-emitc-func-in-class -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -wrap-emitc-func-in-class=func-name=execute -split-input-file | FileCheck %s --check-prefixes=EXECUTE
 
 emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
   emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
   emitc.return
 }
 
-// CHECK: module {
 // CHECK:   emitc.class @fooClass {
 // CHECK:     emitc.field @fieldName0 : !emitc.array<1xf32>
-// CHECK:     emitc.func @execute() {
+// CHECK:     emitc.func @operator() {
 // CHECK:       %0 = get_field @fieldName0 : !emitc.array<1xf32>
 // CHECK:       call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> ()
 // CHECK:       return
 // CHECK:     }
 // CHECK:   }
-// CHECK: }
+
+// EXECUTE-NOT: operator()
+// EXECUTE: execute()
 
 // -----
 
@@ -34,12 +36,11 @@ module attributes { } {
   }
 }
 
-// CHECK: module {
 // CHECK:   emitc.class @modelClass {
 // CHECK:     emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
 // CHECK:     emitc.field @fieldName1 : !emitc.array<1xf32>  {emitc.name_hint = "some_feature"}
 // CHECK:     emitc.field @fieldName2 : !emitc.array<1xf32>  {emitc.name_hint = "output_0"}
-// CHECK:     emitc.func @execute() {
+// CHECK:     emitc.func @operator() {
 // CHECK:       get_field @fieldName0 : !emitc.array<1xf32>
 // CHECK:       get_field @fieldName1 : !emitc.array<1xf32>
 // CHECK:       get_field @fieldName2 : !emitc.array<1xf32>
@@ -54,4 +55,6 @@ module attributes { } {
 // CHECK:       return
 // CHECK:     }
 // CHECK:   }
-// CHECK: }
+
+// EXECUTE-NOT: operator()
+// EXECUTE: execute()

``````````

</details>


https://github.com/llvm/llvm-project/pull/179184


More information about the Mlir-commits mailing list