[Mlir-commits] [mlir] [emitC]Pass in `mlir-opt` to wrap a func in class (PR #141158)

Jaden Angella llvmlistbot at llvm.org
Mon Jun 16 11:17:51 PDT 2025


Valentin Clement =?utf-8?b?KOODkOODrOODsw=?=,Jaddyen <ajaden at google.com>,Jaddyen
 <ajaden at google.com>,Jaddyen <ajaden at google.com>,Jaddyen <ajaden at google.com>,Jaddyen
 <ajaden at google.com>,Jaddyen <ajaden at google.com>,Jaden Angella
 <141196890+Jaddyen at users.noreply.github.com>,Jaden Angella
 <141196890+Jaddyen at users.noreply.github.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/141158 at github.com>


================
@@ -0,0 +1,139 @@
+//===- ConvertFuncToClass.cpp - Convert functions to classes -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Rewrite.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/LogicalResult.h"
+
+namespace mlir {
+namespace emitc {
+
+#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
+#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
+
+namespace {
+
+/// Pass that rewrites the root operation with the patterns registered by
+/// `populateFuncPatterns`, applied through the greedy pattern driver.
+struct WrapFuncInClassPass
+    : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
+  using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
+  /// Builds the pattern set and applies it greedily to the root operation;
+  /// signals pass failure if the greedy driver reports failure.
+  void runOnOperation() override {
+    Operation *rootOp = getOperation();
+    MLIRContext *context = rootOp->getContext();
+
+    RewritePatternSet patterns(context);
+    // NOTE(review): `namedAttribute` is not declared in this struct; presumably
+    // it is a pass option provided by the generated base class -- confirm
+    // against the pass definition.
+    populateFuncPatterns(patterns, namedAttribute);
+
+    if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
+      return signalPassFailure();
+  }
+  /// Inserts the EmitC dialect into the given registry.
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<emitc::EmitCDialect>();
+  }
+};
+
+} // namespace
+
+} // namespace emitc
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::emitc;
+
+class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
+private:
+  std::string attributeName;
+
+public:
+  WrapFuncInClass(MLIRContext *context, const std::string &attrName)
+      : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
+
+  LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override {
+    if (funcOp->getParentOfType<emitc::ClassOp>()) {
+      return failure();
+    }
+    auto className = "My" + funcOp.getSymNameAttr().str() + "Class";
+    mlir::emitc::ClassOp newClassOp =
+        rewriter.create<emitc::ClassOp>(funcOp.getLoc(), className);
+
+    SmallVector<std::pair<StringAttr, TypeAttr>> fields;
+    rewriter.createBlock(&newClassOp.getBody());
+    rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
+
+    auto argAttrs = funcOp.getArgAttrs();
+    if (argAttrs) {
+      for (const auto &[arg, val] :
+           llvm::zip(*argAttrs, funcOp.getArguments())) {
+        if (auto namedAttr =
+                dyn_cast<mlir::DictionaryAttr>(arg).getNamed(attributeName)) {
+          Attribute nv = namedAttr->getValue();
+          StringAttr fieldName =
+              cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
+          TypeAttr typeAttr = TypeAttr::get(val.getType());
+          fields.push_back({fieldName, typeAttr});
+
+          rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
+                                          /* attributes*/ arg);
+        }
+      }
+    } else {
+      funcOp->emitOpError("arguments should have attributes so we can "
+                          "initialize class fields.");
+      return failure();
+    }
+
+    rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
+    MLIRContext *funcContext = funcOp.getContext();
+    ArrayRef<Type> inputTypes = funcOp.getFunctionType().getInputs();
+    ArrayRef<Type> results = funcOp.getFunctionType().getResults();
+    FunctionType funcType = FunctionType::get(funcContext, inputTypes, results);
+    Location loc = funcOp.getLoc();
+    FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
+        loc, rewriter.getStringAttr("execute"), funcType);
+
+    rewriter.createBlock(&newFuncOp.getBody());
+    newFuncOp.getBody().takeBody(funcOp.getBody());
+
+    rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
+    std::vector<Value> newArguments;
+    for (auto [fieldName, attr] : fields) {
+      auto arg =
+          rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
+      newArguments.push_back(arg);
+    }
+
+    for (auto [oldArg, newArg] :
+         llvm::zip(newFuncOp.getArguments(), newArguments)) {
+      rewriter.replaceAllUsesWith(oldArg, newArg);
+    }
+
+    while (!newFuncOp.getArguments().empty()) {
+      if (failed(newFuncOp.eraseArgument(0))) {
+        break;
+      }
+    }
+
+    rewriter.replaceOp(funcOp, newClassOp);
+    return funcOp->use_empty() ? success() : failure();
----------------
Jaddyen wrote:

Not quite. It was just a sanity check that once we call `replaceOp`, all uses of the old `funcOp` are dropped.
I can just rewrite this to return success and assume the correctness of `replaceOp`.

https://github.com/llvm/llvm-project/pull/141158


More information about the Mlir-commits mailing list