[Mlir-commits] [mlir] [emitC]Pass in `mlir-opt` to wrap a func in class (PR #141158)

Jaden Angella llvmlistbot at llvm.org
Wed Jun 18 15:02:16 PDT 2025


Valentin Clement =?utf-8?b?KOODkOODrOODsw=?=,Jaddyen <ajaden at google.com>,Jaddyen
 <ajaden at google.com>,Jaddyen <ajaden at google.com>,Jaddyen <ajaden at google.com>,Jaddyen
 <ajaden at google.com>,Jaddyen <ajaden at google.com>,Jaddyen <ajaden at google.com>,
Valentin Clement =?utf-8?b?KOODkOODrOODsw=?=,Jaddyen <ajaden at google.com>,Jaddyen
 <ajaden at google.com>,Jaddyen <ajaden at google.com>,Jaden Angella
 <141196890+Jaddyen at users.noreply.github.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/141158 at github.com>


https://github.com/Jaddyen updated https://github.com/llvm/llvm-project/pull/141158

>From c3867879aaf0273e45793b8b1f0a4d8ca5d81221 Mon Sep 17 00:00:00 2001
From: Jaden Jaden <ajaden at google.com>
Date: Wed, 21 May 2025 23:10:51 +0000
Subject: [PATCH 01/20] Emit a C++ class when --class-name is set

---
 mlir/include/mlir/Target/Cpp/CppEmitter.h     |   2 +-
 mlir/lib/Target/Cpp/TranslateRegistration.cpp |   7 +-
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 168 +++++++++++++++---
 3 files changed, 147 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h
index 7c5747a888261..f87fc67662f55 100644
--- a/mlir/include/mlir/Target/Cpp/CppEmitter.h
+++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h
@@ -28,7 +28,7 @@ namespace emitc {
 /// with matching id are emitted.
 LogicalResult translateToCpp(Operation *op, raw_ostream &os,
                              bool declareVariablesAtTop = false,
-                             StringRef fileId = {});
+                             StringRef fileId = {}, StringRef className = {});
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 2108ffd414c56..d556927d7f904 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -33,13 +33,18 @@ void registerToCppTranslation() {
       "file-id", llvm::cl::desc("Emit emitc.file ops with matching id"),
       llvm::cl::init(""));
 
+  static llvm::cl::opt<std::string> className(
+      "class-name", llvm::cl::desc("Specify the class name for the generated C++ code"),
+      llvm::cl::init(""));
+
   TranslateFromMLIRRegistration reg(
       "mlir-to-cpp", "translate from mlir to cpp",
       [](Operation *op, raw_ostream &output) {
         return emitc::translateToCpp(
             op, output,
             /*declareVariablesAtTop=*/declareVariablesAtTop,
-            /*fileId=*/fileId);
+            /*fileId=*/fileId,
+            /*className=*/className);
       },
       [](DialectRegistry &registry) {
         // clang-format off
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 5abc112ab8c7a..1e94eebae50bf 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -68,6 +69,13 @@ inline LogicalResult interleaveCommaWithError(const Container &c,
   return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
 }
 
+template <typename Container, typename UnaryFunctor>
+inline LogicalResult interleaveWithNewLineWithError(const Container &c,
+                                              raw_ostream &os,
+                                              UnaryFunctor eachFn) {
+  return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << "\n"; });
+}
+
 /// Return the precedence of a operator as an integer, higher values
 /// imply higher precedence.
 static FailureOr<int> getOperatorPrecedence(Operation *operation) {
@@ -116,7 +124,7 @@ namespace {
 /// Emitter that uses dialect specific emitters to emit C++ code.
 struct CppEmitter {
   explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
-                      StringRef fileId);
+                      StringRef fileId, StringRef className);
 
   /// Emits attribute or returns failure.
   LogicalResult emitAttribute(Location loc, Attribute attr);
@@ -233,6 +241,9 @@ struct CppEmitter {
   /// be declared at the beginning of a function.
   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
 
+  // Returns if we should spit out a C++ class
+  std::string shouldSpitClass() { return className; };
+
   /// Returns whether this file op should be emitted
   bool shouldEmitFile(FileOp file) {
     return !fileId.empty() && file.getId() == fileId;
@@ -268,6 +279,9 @@ struct CppEmitter {
   /// Only emit file ops whos id matches this value.
   std::string fileId;
 
+  /// Name of the C++ class we will spit
+  std::string className;
+
   /// Map from value to name of C++ variable that contain the name.
   ValueMapper valueMapper;
 
@@ -1033,6 +1047,32 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
       }));
 }
 
+static LogicalResult printAttributes(CppEmitter &emitter,
+                                       Operation *functionOp,
+                                       ArrayRef<Type> arguments) {
+  raw_indented_ostream &os = emitter.ostream();
+  for (auto arg : arguments) {
+    os << "  ";
+    if (failed(emitter.emitType(functionOp->getLoc(), arg)))
+      return failure();
+    os << " " << arg << ";\n";
+  }
+  return success();
+}
+
+static LogicalResult printAttributes(CppEmitter &emitter,
+                                       Operation *functionOp,
+                                       Region::BlockArgListType arguments) {
+  raw_indented_ostream &os = emitter.ostream();
+  for (auto arg : arguments) {
+    os << "  ";
+    if (failed(emitter.emitType(functionOp->getLoc(), arg.getType())))
+      return failure();
+    os << " " << emitter.getOrCreateName(arg) << ";\n";
+  }
+  return success();
+}
+
 static LogicalResult printFunctionBody(CppEmitter &emitter,
                                        Operation *functionOp,
                                        Region::BlockListType &blocks) {
@@ -1146,35 +1186,107 @@ static LogicalResult printOperation(CppEmitter &emitter,
         "with multiple blocks needs variables declared at top");
   }
 
-  CppEmitter::Scope scope(emitter);
+  CppEmitter::Scope classScope(emitter);
   raw_indented_ostream &os = emitter.ostream();
-  if (functionOp.getSpecifiers()) {
-    for (Attribute specifier : functionOp.getSpecifiersAttr()) {
-      os << cast<StringAttr>(specifier).str() << " ";
+  Operation *operation = functionOp.getOperation();
+  if (!emitter.shouldSpitClass().empty()) {
+    os << "class " << emitter.shouldSpitClass() << " final {\n";
+    os << "public: \n";
+
+    if (functionOp.isExternal()) { 
+      if (failed(printAttributes(emitter, operation,
+                                  functionOp.getArgumentTypes())))
+        return failure();
+      return success();
     }
-  }
+    if (failed(printAttributes(emitter, operation, functionOp.getArguments())))
+      return failure();
+    
+    os << "\n";
+    
+    auto argAttrs = functionOp.getArgAttrs();
+    std::map<std::string, Value> fields; 
+    if (argAttrs)
+      for (auto [a,v] : zip(*argAttrs, functionOp.getArguments())) { 
+        if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
+          auto nv = da.getNamed("tf_saved_model.index_path")->getValue(); 
+          auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str(); 
+          fields[name] = v;
+        }
+      }
 
-  if (failed(emitter.emitTypes(functionOp.getLoc(),
-                               functionOp.getFunctionType().getResults())))
-    return failure();
-  os << " " << functionOp.getName();
+    for (auto & r : functionOp->getRegions()) 
+      for (auto &b : r.getBlocks())
+        for (auto &opt : b.getOperations())  
+          if (auto alloc = dyn_cast<memref::AllocOp>(opt)) { 
+            auto name = emitter.getOrCreateName(alloc).str(); 
+            fields[name] = alloc;
+            if (failed(emitter.emitType(alloc.getLoc(), alloc.getType().getElementType())))
+              return failure();
+            os << " [" << alloc.getType().getNumElements() <<"] "; 
+            os << " " << name << ";\n";
+          }
+    os << "  std::map<std::string, char*> _buffer_map {"; 
+    for (auto &[n,v]:fields)
+      os << "{ \"" << n << "\"" << ", reinterpret_cast<char*>(" << emitter.getOrCreateName(v) << ") },";
+    os << "  };\n";
+    os << "  char* getBufferForName(const std::string& name) const {\n";
+    os << "    auto it = _buffer_map.find(name);\n";
+    os << "    return (it == _buffer_map.end()) ? nullptr : it->second;\n";
+    os << "  }\n\n";
 
-  os << "(";
-  Operation *operation = functionOp.getOperation();
-  if (functionOp.isExternal()) {
-    if (failed(printFunctionArgs(emitter, operation,
-                                 functionOp.getArgumentTypes())))
+    os.indent();
+
+    // Begin defining the main function where we have the actual execution
+    if (functionOp.getSpecifiers()) {
+      for (Attribute specifier : functionOp.getSpecifiersAttr()) {
+        os << cast<StringAttr>(specifier).str() << " ";
+      }
+    }
+
+    if (failed(emitter.emitTypes(functionOp.getLoc(),
+                                functionOp.getFunctionType().getResults())))
       return failure();
-    os << ");";
-    return success();
-  }
-  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
-    return failure();
-  os << ") {\n";
-  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
-    return failure();
-  os << "}\n";
+    os << " " << functionOp.getName();
 
+    //Defining the functionBody
+    os << "() { \n"; // Begin defining the function header (We need the function name and output type to remain without the rest of the function header)
+    os.indent();
+    if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+      return failure();
+    os.unindent();
+    os << "}\n";
+    os.unindent();
+    os << "};\n";
+  } else {
+    if (functionOp.getSpecifiers()) {
+      for (Attribute specifier : functionOp.getSpecifiersAttr()) {
+        os << cast<StringAttr>(specifier).str() << " ";
+      }
+    }
+
+    if (failed(emitter.emitTypes(functionOp.getLoc(),
+                                functionOp.getFunctionType().getResults())))
+      return failure();
+    os << " " << functionOp.getName();
+
+    os << "(";
+    Operation *operation = functionOp.getOperation();
+    if (functionOp.isExternal()) {
+      if (failed(printFunctionArgs(emitter, operation,
+                                  functionOp.getArgumentTypes())))
+        return failure();
+      os << ");";
+      return success();
+    }
+    if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
+      return failure();
+    os << ") {\n";
+    if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+      return failure();
+    os << "}\n";
+  }
+  
   return success();
 }
 
@@ -1210,9 +1322,9 @@ static LogicalResult printOperation(CppEmitter &emitter,
 }
 
 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
-                       StringRef fileId)
+                       StringRef fileId, StringRef className)
     : os(os), declareVariablesAtTop(declareVariablesAtTop),
-      fileId(fileId.str()) {
+      fileId(fileId.str()), className(className.str()) {
   valueInScopeCount.push(0);
   labelInScopeCount.push(0);
 }
@@ -1795,7 +1907,7 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
 
 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
                                     bool declareVariablesAtTop,
-                                    StringRef fileId) {
-  CppEmitter emitter(os, declareVariablesAtTop, fileId);
+                                    StringRef fileId, StringRef className) {
+  CppEmitter emitter(os, declareVariablesAtTop, fileId, className);
   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
 }

>From 7a97f17c3c1270f5371b262c5cd75b9e79e4b633 Mon Sep 17 00:00:00 2001
From: Jaden Jaden <ajaden at google.com>
Date: Thu, 22 May 2025 23:17:00 +0000
Subject: [PATCH 02/20] Refactor class emission into printFunctionHeader/emitClassBody helpers and rename accessors

---
 mlir/lib/Target/Cpp/TranslateRegistration.cpp |   4 +-
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 233 ++++++++++--------
 2 files changed, 128 insertions(+), 109 deletions(-)

diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index d556927d7f904..17fcc058eef04 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -34,7 +34,9 @@ void registerToCppTranslation() {
       llvm::cl::init(""));
 
   static llvm::cl::opt<std::string> className(
-      "class-name", llvm::cl::desc("Specify the class name for the generated C++ code"),
+      "class-name",
+      llvm::cl::desc("Optional class name. If specified, the output will be a "
+                     "class where the function(s) in the module are members."),
       llvm::cl::init(""));
 
   TranslateFromMLIRRegistration reg(
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 1e94eebae50bf..2e8dfca7479b2 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -9,7 +9,6 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -73,7 +72,8 @@ template <typename Container, typename UnaryFunctor>
 inline LogicalResult interleaveWithNewLineWithError(const Container &c,
                                               raw_ostream &os,
                                               UnaryFunctor eachFn) {
-  return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << "\n"; });
+  return interleaveWithError(c.begin(), c.end(), eachFn,
+                             [&]() { os << ";\n"; });
 }
 
 /// Return the precedence of a operator as an integer, higher values
@@ -241,8 +241,11 @@ struct CppEmitter {
   /// be declared at the beginning of a function.
   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
 
-  // Returns if we should spit out a C++ class
-  std::string shouldSpitClass() { return className; };
+  // Returns whether we should emit a C++ class
+  bool shouldPrintClass() { return !className.empty(); };
+
+  // Returns the class name to emit
+  std::string getClassName() { return className; };
 
   /// Returns whether this file op should be emitted
   bool shouldEmitFile(FileOp file) {
@@ -1047,30 +1050,25 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
       }));
 }
 
-static LogicalResult printAttributes(CppEmitter &emitter,
-                                       Operation *functionOp,
-                                       ArrayRef<Type> arguments) {
+static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
+                                 ArrayRef<Type> arguments) {
   raw_indented_ostream &os = emitter.ostream();
-  for (auto arg : arguments) {
-    os << "  ";
-    if (failed(emitter.emitType(functionOp->getLoc(), arg)))
-      return failure();
-    os << " " << arg << ";\n";
-  }
-  return success();
+
+  return (interleaveWithNewLineWithError(
+      arguments, os, [&](Type arg) -> LogicalResult {
+        return emitter.emitType(functionOp->getLoc(), arg);
+      }));
 }
 
-static LogicalResult printAttributes(CppEmitter &emitter,
-                                       Operation *functionOp,
-                                       Region::BlockArgListType arguments) {
+static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
+                                 Region::BlockArgListType arguments) {
   raw_indented_ostream &os = emitter.ostream();
-  for (auto arg : arguments) {
-    os << "  ";
-    if (failed(emitter.emitType(functionOp->getLoc(), arg.getType())))
-      return failure();
-    os << " " << emitter.getOrCreateName(arg) << ";\n";
-  }
-  return success();
+
+  return (interleaveWithNewLineWithError(
+      arguments, os, [&](BlockArgument arg) -> LogicalResult {
+        return emitter.emitVariableDeclaration(
+            functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
+      }));
 }
 
 static LogicalResult printFunctionBody(CppEmitter &emitter,
@@ -1177,116 +1175,135 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return success();
 }
 
-static LogicalResult printOperation(CppEmitter &emitter,
-                                    emitc::FuncOp functionOp) {
-  // We need to declare variables at top if the function has multiple blocks.
-  if (!emitter.shouldDeclareVariablesAtTop() &&
-      functionOp.getBlocks().size() > 1) {
-    return functionOp.emitOpError(
-        "with multiple blocks needs variables declared at top");
-  }
-
-  CppEmitter::Scope classScope(emitter);
+static LogicalResult printFunctionHeader(CppEmitter &emitter,
+                                         emitc::FuncOp functionOp) {
   raw_indented_ostream &os = emitter.ostream();
   Operation *operation = functionOp.getOperation();
-  if (!emitter.shouldSpitClass().empty()) {
-    os << "class " << emitter.shouldSpitClass() << " final {\n";
-    os << "public: \n";
+  if (functionOp.getSpecifiers()) {
+    for (Attribute specifier : functionOp.getSpecifiersAttr()) {
+      os << cast<StringAttr>(specifier).str() << " ";
+    }
+  }
 
-    if (functionOp.isExternal()) { 
-      if (failed(printAttributes(emitter, operation,
-                                  functionOp.getArgumentTypes())))
+  if (failed(emitter.emitTypes(functionOp.getLoc(),
+                               functionOp.getFunctionType().getResults())))
+    return failure();
+  os << " " << functionOp.getName();
+  if (!emitter.shouldPrintClass()) {
+    os << "(";
+    if (functionOp.isExternal()) {
+      if (failed(printFunctionArgs(emitter, operation,
+                                   functionOp.getArgumentTypes())))
         return failure();
+      os << ");";
       return success();
     }
-    if (failed(printAttributes(emitter, operation, functionOp.getArguments())))
+    if (failed(
+            printFunctionArgs(emitter, operation, functionOp.getArguments())))
       return failure();
-    
-    os << "\n";
-    
-    auto argAttrs = functionOp.getArgAttrs();
-    std::map<std::string, Value> fields; 
-    if (argAttrs)
-      for (auto [a,v] : zip(*argAttrs, functionOp.getArguments())) { 
-        if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
-          auto nv = da.getNamed("tf_saved_model.index_path")->getValue(); 
-          auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str(); 
-          fields[name] = v;
-        }
-      }
+    os << ") {\n";
 
-    for (auto & r : functionOp->getRegions()) 
-      for (auto &b : r.getBlocks())
-        for (auto &opt : b.getOperations())  
-          if (auto alloc = dyn_cast<memref::AllocOp>(opt)) { 
-            auto name = emitter.getOrCreateName(alloc).str(); 
-            fields[name] = alloc;
-            if (failed(emitter.emitType(alloc.getLoc(), alloc.getType().getElementType())))
-              return failure();
-            os << " [" << alloc.getType().getNumElements() <<"] "; 
-            os << " " << name << ";\n";
-          }
-    os << "  std::map<std::string, char*> _buffer_map {"; 
-    for (auto &[n,v]:fields)
-      os << "{ \"" << n << "\"" << ", reinterpret_cast<char*>(" << emitter.getOrCreateName(v) << ") },";
-    os << "  };\n";
-    os << "  char* getBufferForName(const std::string& name) const {\n";
-    os << "    auto it = _buffer_map.find(name);\n";
-    os << "    return (it == _buffer_map.end()) ? nullptr : it->second;\n";
-    os << "  }\n\n";
+  } else {
+    os << "() { \n";
+  }
 
-    os.indent();
+  return success();
+}
 
-    // Begin defining the main function where we have the actual execution
-    if (functionOp.getSpecifiers()) {
-      for (Attribute specifier : functionOp.getSpecifiersAttr()) {
-        os << cast<StringAttr>(specifier).str() << " ";
+static LogicalResult emitClassBody(CppEmitter &emitter,
+                                   emitc::FuncOp functionOp) {
+  raw_indented_ostream &os = emitter.ostream();
+  Operation *operation = functionOp.getOperation();
+  auto argAttrs = functionOp.getArgAttrs();
+  std::map<std::string, Value> fields;
+  os << "\nstd::map<std::string, char*> _buffer_map {";
+  if (argAttrs) // We can have no argattrs in the case that the function has no
+                // inputs nor outputs -> procedure
+    for (const auto [a, v] : zip(*argAttrs, functionOp.getArguments())) {
+      if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
+        auto nv =
+            da.getNamed("tf_saved_model.index_path")
+                ->getValue(); // From what I've seen so far, this is the only
+                              // way to have the argAttrs keys. If there is
+                              // another way, I need to run the tests to see and
+                              // see what cases trigger this change in format.
+        auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str();
+        fields[name] = v; // The only way to not have unique names is in the
+                          // case that you have duplicate arguments in your
+                          // tensorflow/python function. By python syntax rules,
+                          // you're not allowed to have that(Current assumption)
+        os << "{ \"" << name << "\"" << ", reinterpret_cast<char*>("
+           << emitter.getOrCreateName(v) << ") },";
       }
     }
+  else
+    return failure();
+
+  os << " };\n";
+  os << "char* getBufferForName(const std::string& name) const {\n";
+  os.indent();
+  os.indent();
+  os << "auto it = _buffer_map.find(name);\n";
+  os << "return (it == _buffer_map.end()) ? nullptr : it->second;\n";
+  os.unindent();
+  os.unindent();
+  os << "}\n\n";
+
+  if (failed(printFunctionHeader(emitter, functionOp)))
+    return failure();
+
+  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+    return failure();
+  os << "}\n";
+  os.unindent();
+  os << "};\n";
+
+  return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::FuncOp functionOp) {
+  // We need to declare variables at top if the function has multiple blocks.
+  if (!emitter.shouldDeclareVariablesAtTop() &&
+      functionOp.getBlocks().size() > 1) {
+    return functionOp.emitOpError(
+        "with multiple blocks needs variables declared at top");
+  }
 
-    if (failed(emitter.emitTypes(functionOp.getLoc(),
-                                functionOp.getFunctionType().getResults())))
+  CppEmitter::Scope classScope(emitter);
+  raw_indented_ostream &os = emitter.ostream();
+  Operation *operation = functionOp.getOperation();
+  if (!emitter.shouldPrintClass()) {
+
+    if (failed(printFunctionHeader(emitter, functionOp)))
       return failure();
-    os << " " << functionOp.getName();
 
-    //Defining the functionBody
-    os << "() { \n"; // Begin defining the function header (We need the function name and output type to remain without the rest of the function header)
-    os.indent();
-    if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+    if (failed(printFunctionBody(
+            emitter, operation,
+            functionOp.getBlocks()))) // This is the only similarity between the
+                                      // function and the class
       return failure();
-    os.unindent();
     os << "}\n";
-    os.unindent();
-    os << "};\n";
-  } else {
-    if (functionOp.getSpecifiers()) {
-      for (Attribute specifier : functionOp.getSpecifiersAttr()) {
-        os << cast<StringAttr>(specifier).str() << " ";
-      }
-    }
 
-    if (failed(emitter.emitTypes(functionOp.getLoc(),
-                                functionOp.getFunctionType().getResults())))
-      return failure();
-    os << " " << functionOp.getName();
+  } else {
+    os << "class " << emitter.getClassName() << " final {\n";
+    os << "public: \n";
+    os.indent();
 
-    os << "(";
-    Operation *operation = functionOp.getOperation();
     if (functionOp.isExternal()) {
-      if (failed(printFunctionArgs(emitter, operation,
-                                  functionOp.getArgumentTypes())))
+      if (failed(
+              printFields(emitter, operation, functionOp.getArgumentTypes())))
         return failure();
-      os << ");";
       return success();
     }
-    if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
+    if (failed(printFields(emitter, operation, functionOp.getArguments())))
       return failure();
-    os << ") {\n";
-    if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+    os << ";\n";
+
+    if (failed(emitClassBody(emitter, functionOp)))
       return failure();
-    os << "}\n";
   }
-  
+
   return success();
 }
 

>From 23e84353689078ec1ae4ae509d952201113eb58b Mon Sep 17 00:00:00 2001
From: Jaden Jaden <ajaden at google.com>
Date: Tue, 27 May 2025 18:39:14 +0000
Subject: [PATCH 03/20] Add --emit-class and --field-name-attribute flags to make class emission explicit

---
 mlir/include/mlir/Target/Cpp/CppEmitter.h     |   4 +-
 mlir/lib/Target/Cpp/TranslateRegistration.cpp |  38 +++-
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 181 ++++++++----------
 mlir/test/mlir-translate/emit-class.mlir      |  49 +++++
 4 files changed, 170 insertions(+), 102 deletions(-)
 create mode 100644 mlir/test/mlir-translate/emit-class.mlir

diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h
index f87fc67662f55..d1a6c1dc12d4c 100644
--- a/mlir/include/mlir/Target/Cpp/CppEmitter.h
+++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h
@@ -28,7 +28,9 @@ namespace emitc {
 /// with matching id are emitted.
 LogicalResult translateToCpp(Operation *op, raw_ostream &os,
                              bool declareVariablesAtTop = false,
-                             StringRef fileId = {}, StringRef className = {});
+                             StringRef fileId = {}, bool emitClass = false,
+                             StringRef className = {},
+                             StringRef fieldNameAttribute = {});
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 17fcc058eef04..69e0ab01bb71d 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -33,20 +33,50 @@ void registerToCppTranslation() {
       "file-id", llvm::cl::desc("Emit emitc.file ops with matching id"),
       llvm::cl::init(""));
 
+  static llvm::cl::opt<bool> emitClass(
+      "emit-class",
+      llvm::cl::desc("If specified, the output will be a class where "
+                     "the function(s) in the module are members. "
+                     "Enables class-related options."),
+      llvm::cl::init(false));
+
   static llvm::cl::opt<std::string> className(
       "class-name",
-      llvm::cl::desc("Optional class name. If specified, the output will be a "
-                     "class where the function(s) in the module are members."),
+      llvm::cl::desc("Mandatory class name if --emit-class is set."),
+      llvm::cl::init(""));
+
+  static llvm::cl::opt<std::string> fieldNameAttribute(
+      "field-name-attribute",
+      llvm::cl::desc("Mandatory name of the attribute to use as field name if "
+                     "--emit-class is set."),
       llvm::cl::init(""));
 
   TranslateFromMLIRRegistration reg(
       "mlir-to-cpp", "translate from mlir to cpp",
       [](Operation *op, raw_ostream &output) {
+        if (emitClass) {
+          if (className.empty()) {
+            llvm::errs() << "Error: --class-name is mandatory when "
+                            "--emit-class is set.\n";
+            return mlir::failure();
+          }
+          if (fieldNameAttribute.empty()) {
+            llvm::errs() << "Error: --field-name-attribute is mandatory when "
+                            "--emit-class is set.\n";
+            return mlir::failure();
+          }
+          return emitc::translateToCpp(
+              op, output,
+              /*declareVariablesAtTop=*/declareVariablesAtTop,
+              /*fileId=*/fileId, /*emitClass=*/emitClass,
+              /*className=*/className,
+              /*fieldNameAttribute=*/fieldNameAttribute);
+        }
         return emitc::translateToCpp(
             op, output,
             /*declareVariablesAtTop=*/declareVariablesAtTop,
-            /*fileId=*/fileId,
-            /*className=*/className);
+            /*fileId=*/fileId, /*emitClass=*/emitClass, /*className=*/className,
+            /*fieldNameAttribute=*/fieldNameAttribute);
       },
       [](DialectRegistry &registry) {
         // clang-format off
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 2e8dfca7479b2..ef6583caf2d46 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <stack>
@@ -124,7 +125,8 @@ namespace {
 /// Emitter that uses dialect specific emitters to emit C++ code.
 struct CppEmitter {
   explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
-                      StringRef fileId, StringRef className);
+                      StringRef fileId, bool emitClass, StringRef className,
+                      StringRef fieldNameAttribute);
 
   /// Emits attribute or returns failure.
   LogicalResult emitAttribute(Location loc, Attribute attr);
@@ -242,11 +244,14 @@ struct CppEmitter {
   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
 
   // Returns whether we should emit a C++ class
-  bool shouldPrintClass() { return !className.empty(); };
+  bool shouldPrintClass() { return emitClass; };
 
   // Returns the class name to emit
   std::string getClassName() { return className; };
 
+  // Returns the field name to use in the map
+  std::string getfieldNameAttribute() { return fieldNameAttribute; };
+
   /// Returns whether this file op should be emitted
   bool shouldEmitFile(FileOp file) {
     return !fileId.empty() && file.getId() == fileId;
@@ -282,9 +287,18 @@ struct CppEmitter {
   /// Only emit file ops whos id matches this value.
   std::string fileId;
 
-  /// Name of the C++ class we will spit
+  /// Controls whether the output should be a C++ class.
+  /// If true, the generated C++ code will be encapsulated within a class,
+  /// and functions from the input module will become its member functions.
+  bool emitClass;
+
+  /// The specified name for the generated C++ class
   std::string className;
 
+  /// Name of the MLIR attribute to use as a field name within the generated
+  /// class
+  std::string fieldNameAttribute;
+
   /// Map from value to name of C++ variable that contain the name.
   ValueMapper valueMapper;
 
@@ -1050,16 +1064,6 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
       }));
 }
 
-static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
-                                 ArrayRef<Type> arguments) {
-  raw_indented_ostream &os = emitter.ostream();
-
-  return (interleaveWithNewLineWithError(
-      arguments, os, [&](Type arg) -> LogicalResult {
-        return emitter.emitType(functionOp->getLoc(), arg);
-      }));
-}
-
 static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
                                  Region::BlockArgListType arguments) {
   raw_indented_ostream &os = emitter.ostream();
@@ -1175,71 +1179,38 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return success();
 }
 
-static LogicalResult printFunctionHeader(CppEmitter &emitter,
-                                         emitc::FuncOp functionOp) {
+static LogicalResult emitClassFields(CppEmitter &emitter,
+                                     emitc::FuncOp functionOp) {
   raw_indented_ostream &os = emitter.ostream();
+  auto argAttrs = functionOp.getArgAttrs();
   Operation *operation = functionOp.getOperation();
-  if (functionOp.getSpecifiers()) {
-    for (Attribute specifier : functionOp.getSpecifiersAttr()) {
-      os << cast<StringAttr>(specifier).str() << " ";
-    }
-  }
-
-  if (failed(emitter.emitTypes(functionOp.getLoc(),
-                               functionOp.getFunctionType().getResults())))
+  if (failed(printFields(emitter, operation, functionOp.getArguments())))
     return failure();
-  os << " " << functionOp.getName();
-  if (!emitter.shouldPrintClass()) {
-    os << "(";
-    if (functionOp.isExternal()) {
-      if (failed(printFunctionArgs(emitter, operation,
-                                   functionOp.getArgumentTypes())))
-        return failure();
-      os << ");";
-      return success();
-    }
-    if (failed(
-            printFunctionArgs(emitter, operation, functionOp.getArguments())))
-      return failure();
-    os << ") {\n";
+  os << ";\n";
 
-  } else {
-    os << "() { \n";
-  }
-
-  return success();
-}
-
-static LogicalResult emitClassBody(CppEmitter &emitter,
-                                   emitc::FuncOp functionOp) {
-  raw_indented_ostream &os = emitter.ostream();
-  Operation *operation = functionOp.getOperation();
-  auto argAttrs = functionOp.getArgAttrs();
   std::map<std::string, Value> fields;
-  os << "\nstd::map<std::string, char*> _buffer_map {";
-  if (argAttrs) // We can have no argattrs in the case that the function has no
-                // inputs nor outputs -> procedure
+  os << "std::map<std::string, char*> _buffer_map {";
+  if (argAttrs) {
+    bool isFirst = true;
     for (const auto [a, v] : zip(*argAttrs, functionOp.getArguments())) {
       if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
-        auto nv =
-            da.getNamed("tf_saved_model.index_path")
-                ->getValue(); // From what I've seen so far, this is the only
-                              // way to have the argAttrs keys. If there is
-                              // another way, I need to run the tests to see and
-                              // see what cases trigger this change in format.
+        auto nv = da.getNamed(emitter.getfieldNameAttribute())->getValue();
         auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str();
-        fields[name] = v; // The only way to not have unique names is in the
-                          // case that you have duplicate arguments in your
-                          // tensorflow/python function. By python syntax rules,
-                          // you're not allowed to have that(Current assumption)
+        auto Ins = fields.insert({name, v});
+        if (!Ins.second)
+          return failure();
+        if (!isFirst) {
+          os << ",";
+        }
         os << "{ \"" << name << "\"" << ", reinterpret_cast<char*>("
-           << emitter.getOrCreateName(v) << ") },";
+           << emitter.getOrCreateName(v) << ") }";
+        isFirst = false;
       }
     }
-  else
+  } else
     return failure();
 
-  os << " };\n";
+  os << "};";
   os << "char* getBufferForName(const std::string& name) const {\n";
   os.indent();
   os.indent();
@@ -1249,15 +1220,6 @@ static LogicalResult emitClassBody(CppEmitter &emitter,
   os.unindent();
   os << "}\n\n";
 
-  if (failed(printFunctionHeader(emitter, functionOp)))
-    return failure();
-
-  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
-    return failure();
-  os << "}\n";
-  os.unindent();
-  os << "};\n";
-
   return success();
 }
 
@@ -1270,38 +1232,58 @@ static LogicalResult printOperation(CppEmitter &emitter,
         "with multiple blocks needs variables declared at top");
   }
 
-  CppEmitter::Scope classScope(emitter);
+  CppEmitter::Scope scope(emitter);
   raw_indented_ostream &os = emitter.ostream();
   Operation *operation = functionOp.getOperation();
-  if (!emitter.shouldPrintClass()) {
-
-    if (failed(printFunctionHeader(emitter, functionOp)))
-      return failure();
-
-    if (failed(printFunctionBody(
-            emitter, operation,
-            functionOp.getBlocks()))) // This is the only similarity between the
-                                      // function and the class
+  if (emitter.shouldPrintClass()) {
+    if (functionOp.isExternal())
       return failure();
-    os << "}\n";
-
-  } else {
     os << "class " << emitter.getClassName() << " final {\n";
     os << "public: \n";
     os.indent();
 
+    if (failed(emitClassFields(emitter, functionOp)))
+      return failure();
+  }
+
+  if (functionOp.getSpecifiers()) {
+    for (Attribute specifier : functionOp.getSpecifiersAttr()) {
+      os << cast<StringAttr>(specifier).str() << " ";
+    }
+  }
+
+  if (failed(emitter.emitTypes(functionOp.getLoc(),
+                               functionOp.getFunctionType().getResults())))
+    return failure();
+  os << " " << functionOp.getName();
+
+  os << "(";
+
+  if (emitter.shouldPrintClass())
+    os << ") { \n";
+  else {
     if (functionOp.isExternal()) {
-      if (failed(
-              printFields(emitter, operation, functionOp.getArgumentTypes())))
+      if (failed(printFunctionArgs(emitter, operation,
+                                   functionOp.getArgumentTypes())))
         return failure();
+      os << ");";
       return success();
     }
-    if (failed(printFields(emitter, operation, functionOp.getArguments())))
+    if (failed(
+            printFunctionArgs(emitter, operation, functionOp.getArguments())))
       return failure();
-    os << ";\n";
+    os << ") {\n";
+  }
 
-    if (failed(emitClassBody(emitter, functionOp)))
-      return failure();
+  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+    return failure();
+
+  if (emitter.shouldPrintClass()) {
+    os << "}\n";
+    os.unindent();
+    os << "};\n";
+  } else {
+    os << "}\n";
   }
 
   return success();
@@ -1339,9 +1321,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
 }
 
 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
-                       StringRef fileId, StringRef className)
+                       StringRef fileId, bool emitClass, StringRef className,
+                       StringRef fieldNameAttribute)
     : os(os), declareVariablesAtTop(declareVariablesAtTop),
-      fileId(fileId.str()), className(className.str()) {
+      fileId(fileId.str()), emitClass(emitClass), className(className.str()),
+      fieldNameAttribute(fieldNameAttribute.str()) {
   valueInScopeCount.push(0);
   labelInScopeCount.push(0);
 }
@@ -1924,7 +1908,10 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
 
 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
                                     bool declareVariablesAtTop,
-                                    StringRef fileId, StringRef className) {
-  CppEmitter emitter(os, declareVariablesAtTop, fileId, className);
+                                    StringRef fileId, bool emitClass,
+                                    StringRef className,
+                                    StringRef fieldNameAttribute) {
+  CppEmitter emitter(os, declareVariablesAtTop, fileId, emitClass, className,
+                     fieldNameAttribute);
   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
 }
diff --git a/mlir/test/mlir-translate/emit-class.mlir b/mlir/test/mlir-translate/emit-class.mlir
new file mode 100644
index 0000000000000..eddc2181a0b19
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path /tmp/model_emitc.mlir | FileCheck %s --check-prefix=ADDER_TEST
+
+// ADDER_TEST-LABEL: class MyAdder final {
+// ADDER_TEST-NEXT: public:
+// ADDER_TEST-DAG: float v1[1];
+// ADDER_TEST-DAG: float v2[1];
+// ADDER_TEST-DAG: float v3[1];      
+// ADDER_TEST-NEXT: std::map<std::string, char*> _buffer_map {{ "another_feature", reinterpret_cast<char*>(v1) },{ "some_feature", reinterpret_cast<char*>(v2) },{ "output_0", reinterpret_cast<char*>(v3) }};
+// ADDER_TEST-NEXT: char* getBufferForName(const std::string& name) const {
+// ADDER_TEST-NEXT:     auto it = _buffer_map.find(name);
+// ADDER_TEST-NEXT:     return (it == _buffer_map.end()) ? nullptr : it->second;
+// ADDER_TEST-NEXT: }
+// ADDER_TEST-NEXT: void main() {
+// ADDER_TEST-NEXT:     size_t v4 = 0;
+// ADDER_TEST-NEXT:     float v5 = v2[v4];
+// ADDER_TEST-NEXT:     float v6 = v1[v4];
+// ADDER_TEST-NEXT:     float v7 = v5 + v6;
+// ADDER_TEST-NEXT:     v3[v4] = v7;
+// ADDER_TEST-NEXT:     return;
+// ADDER_TEST-NEXT: }
+// ADDER_TEST-NEXT: };
+
+// ---
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyMultiOutput --field-name-attribute=tf_saved_model.index_path /tmp/model_multi_out_emitc.mlir | FileCheck %s --check-prefix=MULTI_OUT
+
+// MULTI_OUT-LABEL: class MyMultiOutput final {
+// MULTI_OUT-NEXT: public:
+// MULTI_OUT-DAG: float v1[1];
+// MULTI_OUT-DAG: float v2[1];
+// MULTI_OUT-DAG: float v3[1];
+// MULTI_OUT-DAG: float v4[1];
+// MULTI_OUT: std::map<std::string, char*> _buffer_map {{ "b", reinterpret_cast<char*>(v1) },{ "a", reinterpret_cast<char*>(v2) },{ "output_1", reinterpret_cast<char*>(v3) },{ "output_0", reinterpret_cast<char*>(v4) }, };
+// MULTI_OUT-NEXT: char* getBufferForName(const std::string& name) const {
+// MULTI_OUT-NEXT:     auto it = _buffer_map.find(name);
+// MULTI_OUT-NEXT:     return (it == _buffer_map.end()) ? nullptr : it->second;
+// MULTI_OUT-NEXT: }
+// MULTI_OUT-NEXT: void main() {
+// MULTI_OUT-NEXT:     size_t v5 = 0;
+// MULTI_OUT-NEXT:     float v6 = v2[v5];
+// MULTI_OUT-NEXT:     float v7 = v1[v5];
+// MULTI_OUT-NEXT:     float v8 = v6 + v7;
+// MULTI_OUT-NEXT:     v4[v5] = v8;
+// MULTI_OUT-NEXT:     float v9 = v2[v5];
+// MULTI_OUT-NEXT:     float v10 = v1[v5];
+// MULTI_OUT-NEXT:     float v11 = v9 - v10;
+// MULTI_OUT-NEXT:     v3[v5] = v11;
+// MULTI_OUT-NEXT:     return;
+// MULTI_OUT-NEXT: }
+// MULTI_OUT-NEXT: };

>From 0c55dcf1f7cf314b21b6caf67d6d10000d986c1a Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Wed, 28 May 2025 21:33:44 +0000
Subject: [PATCH 04/20] Add working tests for emit-class

---
 mlir/lib/Target/Cpp/TranslateRegistration.cpp | 10 +--
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 35 ++++----
 .../emit-class-neg-external.mlir              |  8 ++
 .../emit-class-neg-noArgAttrs.mlir            | 15 ++++
 mlir/test/mlir-translate/emit-class.mlir      | 81 ++++++++-----------
 5 files changed, 80 insertions(+), 69 deletions(-)
 create mode 100644 mlir/test/mlir-translate/emit-class-neg-external.mlir
 create mode 100644 mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir

diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 69e0ab01bb71d..9e1533d34f6ea 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -36,20 +36,20 @@ void registerToCppTranslation() {
   static llvm::cl::opt<bool> emitClass(
       "emit-class",
       llvm::cl::desc("If specified, the output will be a class where "
-                     "the function(s) in the module are members. "
-                     "Enables class-related options."),
+                     "the function(s) in the module are methods. "
+                     "Enables class-related options"),
       llvm::cl::init(false));
 
   static llvm::cl::opt<std::string> className(
       "class-name",
-      llvm::cl::desc("Mandatory class name if --emit-class is set."),
+      llvm::cl::desc("Mandatory class name if --emit-class is set"),
       llvm::cl::init(""));
 
   static llvm::cl::opt<std::string> fieldNameAttribute(
       "field-name-attribute",
-      llvm::cl::desc("Mandatory name of the attribute to use as field name if "
-                     "--emit-class is set."),
-      llvm::cl::init(""));
+      llvm::cl::desc("Name of the attribute to use as field name if "
+                     "--emit-class is set (default=tf_saved_model.index_path)"),
+      llvm::cl::init("tf_saved_model.index_path"));
 
   TranslateFromMLIRRegistration reg(
       "mlir-to-cpp", "translate from mlir to cpp",
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index ef6583caf2d46..a819e550ad385 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -22,7 +22,6 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <stack>
@@ -290,14 +289,14 @@ struct CppEmitter {
   /// Controls whether the output should be a C++ class.
   /// If true, the generated C++ code will be encapsulated within a class,
   /// and functions from the input module will become its member functions.
-  bool emitClass;
+  const bool emitClass;
 
   /// The specified name for the generated C++ class
-  std::string className;
+  const std::string className;
 
   /// Name of the MLIR attribute to use as a field name within the generated
   /// class
-  std::string fieldNameAttribute;
+  const std::string fieldNameAttribute;
 
   /// Map from value to name of C++ variable that contain the name.
   ValueMapper valueMapper;
@@ -1189,9 +1188,8 @@ static LogicalResult emitClassFields(CppEmitter &emitter,
   os << ";\n";
 
   std::map<std::string, Value> fields;
-  os << "std::map<std::string, char*> _buffer_map {";
+  os << "\nstd::map<std::string, char*> _buffer_map {";
   if (argAttrs) {
-    bool isFirst = true;
     for (const auto [a, v] : zip(*argAttrs, functionOp.getArguments())) {
       if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
         auto nv = da.getNamed(emitter.getfieldNameAttribute())->getValue();
@@ -1199,18 +1197,14 @@ static LogicalResult emitClassFields(CppEmitter &emitter,
         auto Ins = fields.insert({name, v});
         if (!Ins.second)
           return failure();
-        if (!isFirst) {
-          os << ",";
-        }
-        os << "{ \"" << name << "\"" << ", reinterpret_cast<char*>("
-           << emitter.getOrCreateName(v) << ") }";
-        isFirst = false;
+        os << " { \"" << name << "\"" << ", reinterpret_cast<char*>("
+           << emitter.getOrCreateName(v) << ") }, ";
       }
     }
   } else
     return failure();
 
-  os << "};";
+  os << "};\n";
   os << "char* getBufferForName(const std::string& name) const {\n";
   os.indent();
   os.indent();
@@ -1236,8 +1230,15 @@ static LogicalResult printOperation(CppEmitter &emitter,
   raw_indented_ostream &os = emitter.ostream();
   Operation *operation = functionOp.getOperation();
   if (emitter.shouldPrintClass()) {
-    if (functionOp.isExternal())
+    if (functionOp.isExternal()) {
+      // TODO: Determine the best long-term strategy for external functions.
+      // Currently, we're stopping here to prevent downstream errors.
+      os << "Warning: Cannot process external function '"
+         << functionOp.getName() << "'. "
+         << "It lacks a body, and attempting to continue would lead to errors "
+            "due to missing argument details.\n";
       return failure();
+    }
     os << "class " << emitter.getClassName() << " final {\n";
     os << "public: \n";
     os.indent();
@@ -1259,9 +1260,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
 
   os << "(";
 
-  if (emitter.shouldPrintClass())
-    os << ") { \n";
-  else {
+  if (!emitter.shouldPrintClass()) {
     if (functionOp.isExternal()) {
       if (failed(printFunctionArgs(emitter, operation,
                                    functionOp.getArgumentTypes())))
@@ -1272,8 +1271,8 @@ static LogicalResult printOperation(CppEmitter &emitter,
     if (failed(
             printFunctionArgs(emitter, operation, functionOp.getArguments())))
       return failure();
-    os << ") {\n";
   }
+  os << ") {\n";
 
   if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
     return failure();
diff --git a/mlir/test/mlir-translate/emit-class-neg-external.mlir b/mlir/test/mlir-translate/emit-class-neg-external.mlir
new file mode 100644
index 0000000000000..347992c467759
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class-neg-external.mlir
@@ -0,0 +1,8 @@
+/// An external function - has no body
+// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s 2>&1 | FileCheck %s
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+  emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]}
+}
+
+// CHECK: Warning: Cannot process external function 'extern_func'. It lacks a body, and attempting to continue would lead to errors due to missing argument details.
\ No newline at end of file
diff --git a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
new file mode 100644
index 0000000000000..77e89da2f0a4f
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
@@ -0,0 +1,15 @@
+/// The function has no argument attributes
+// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=ArgAttrs --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+  emitc.func @foo(%arg0 : i32) {
+    emitc.call_opaque "bar" (%arg0) : (i32) -> ()
+    emitc.return
+  }
+}
+
+// CHECK: class ArgAttrs final {
+// CHECK-NEXT: public: 
+// CHECK-NEXT:   int32_t v1;
+// CHECK-EMPTY: 
+// CHECK-NEXT:   std::map<std::string, char*> _buffer_map {
\ No newline at end of file
diff --git a/mlir/test/mlir-translate/emit-class.mlir b/mlir/test/mlir-translate/emit-class.mlir
index eddc2181a0b19..9fa69f652df6c 100644
--- a/mlir/test/mlir-translate/emit-class.mlir
+++ b/mlir/test/mlir-translate/emit-class.mlir
@@ -1,49 +1,38 @@
-// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path /tmp/model_emitc.mlir | FileCheck %s --check-prefix=ADDER_TEST
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s  
 
-// ADDER_TEST-LABEL: class MyAdder final {
-// ADDER_TEST-NEXT: public:
-// ADDER_TEST-DAG: float v1[1];
-// ADDER_TEST-DAG: float v2[1];
-// ADDER_TEST-DAG: float v3[1];      
-// ADDER_TEST-NEXT: std::map<std::string, char*> _buffer_map {{ "another_feature", reinterpret_cast<char*>(v1) },{ "some_feature", reinterpret_cast<char*>(v2) },{ "output_0", reinterpret_cast<char*>(v3) }};
-// ADDER_TEST-NEXT: char* getBufferForName(const std::string& name) const {
-// ADDER_TEST-NEXT:     auto it = _buffer_map.find(name);
-// ADDER_TEST-NEXT:     return (it == _buffer_map.end()) ? nullptr : it->second;
-// ADDER_TEST-NEXT: }
-// ADDER_TEST-NEXT: void main() {
-// ADDER_TEST-NEXT:     size_t v4 = 0;
-// ADDER_TEST-NEXT:     float v5 = v2[v4];
-// ADDER_TEST-NEXT:     float v6 = v1[v4];
-// ADDER_TEST-NEXT:     float v7 = v5 + v6;
-// ADDER_TEST-NEXT:     v3[v4] = v7;
-// ADDER_TEST-NEXT:     return;
-// ADDER_TEST-NEXT: }
-// ADDER_TEST-NEXT: };
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+  emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+    %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+    %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    %2 = load %1 : <f32>
+    %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    %4 = load %3 : <f32>
+    %5 = add %2, %4 : (f32, f32) -> f32
+    %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    assign %5 : f32 to %6 : <f32>
+    return
+  }
+}
 
-// ---
-// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyMultiOutput --field-name-attribute=tf_saved_model.index_path /tmp/model_multi_out_emitc.mlir | FileCheck %s --check-prefix=MULTI_OUT
+// CHECK: class MyAdder final {
+// CHECK-NEXT: public:
+// CHECK-NEXT:   float v1[1];
+// CHECK-NEXT:   float v2[1];
+// CHECK-NEXT:   float v3[1];
+// CHECK-EMPTY:
+// CHECK-NEXT:   std::map<std::string, char*> _buffer_map { { "another_feature", reinterpret_cast<char*>(v1) }, { "some_feature", reinterpret_cast<char*>(v2) }, { "output_0", reinterpret_cast<char*>(v3) }, };
+// CHECK-NEXT:   char* getBufferForName(const std::string& name) const {
+// CHECK-NEXT:      auto it = _buffer_map.find(name);
+// CHECK-NEXT:      return (it == _buffer_map.end()) ? nullptr : it->second;
+// CHECK-NEXT:   }
+// CHECK-EMPTY:
+// CHECK-NEXT: void main() {
+// CHECK-NEXT:     size_t v4 = 0;
+// CHECK-NEXT:     float v5 = v2[v4];
+// CHECK-NEXT:     float v6 = v1[v4];
+// CHECK-NEXT:     float v7 = v5 + v6;
+// CHECK-NEXT:     v3[v4] = v7;
+// CHECK-NEXT:     return;
+// CHECK-NEXT:  }
+// CHECK-NEXT: };
 
-// MULTI_OUT-LABEL: class MyMultiOutput final {
-// MULTI_OUT-NEXT: public:
-// MULTI_OUT-DAG: float v1[1];
-// MULTI_OUT-DAG: float v2[1];
-// MULTI_OUT-DAG: float v3[1];
-// MULTI_OUT-DAG: float v4[1];
-// MULTI_OUT: std::map<std::string, char*> _buffer_map {{ "b", reinterpret_cast<char*>(v1) },{ "a", reinterpret_cast<char*>(v2) },{ "output_1", reinterpret_cast<char*>(v3) },{ "output_0", reinterpret_cast<char*>(v4) }, };
-// MULTI_OUT-NEXT: char* getBufferForName(const std::string& name) const {
-// MULTI_OUT-NEXT:     auto it = _buffer_map.find(name);
-// MULTI_OUT-NEXT:     return (it == _buffer_map.end()) ? nullptr : it->second;
-// MULTI_OUT-NEXT: }
-// MULTI_OUT-NEXT: void main() {
-// MULTI_OUT-NEXT:     size_t v5 = 0;
-// MULTI_OUT-NEXT:     float v6 = v2[v5];
-// MULTI_OUT-NEXT:     float v7 = v1[v5];
-// MULTI_OUT-NEXT:     float v8 = v6 + v7;
-// MULTI_OUT-NEXT:     v4[v5] = v8;
-// MULTI_OUT-NEXT:     float v9 = v2[v5];
-// MULTI_OUT-NEXT:     float v10 = v1[v5];
-// MULTI_OUT-NEXT:     float v11 = v9 - v10;
-// MULTI_OUT-NEXT:     v3[v5] = v11;
-// MULTI_OUT-NEXT:     return;
-// MULTI_OUT-NEXT: }
-// MULTI_OUT-NEXT: };

>From ff237e575769e9235b236b3a4489cac07266d82b Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 29 May 2025 16:09:06 +0000
Subject: [PATCH 05/20] Minor cleanup

---
 mlir/lib/Target/Cpp/TranslateToCpp.cpp                  | 7 +++----
 mlir/test/mlir-translate/emit-class-neg-external.mlir   | 6 +++---
 mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir | 4 ++--
 mlir/test/mlir-translate/emit-class.mlir                | 5 +++--
 4 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index a819e550ad385..e09ed0a142725 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1232,12 +1232,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
   if (emitter.shouldPrintClass()) {
     if (functionOp.isExternal()) {
       // TODO: Determine the best long-term strategy for external functions.
-      // Currently, we're stopping here to prevent downstream errors.
+      // Currently, we're skipping over this functionOp.
       os << "Warning: Cannot process external function '"
          << functionOp.getName() << "'. "
-         << "It lacks a body, and attempting to continue would lead to errors "
-            "due to missing argument details.\n";
-      return failure();
+         << "This functionOp lacks a body so we will skip over it.";
+      return success();
     }
     os << "class " << emitter.getClassName() << " final {\n";
     os << "public: \n";
diff --git a/mlir/test/mlir-translate/emit-class-neg-external.mlir b/mlir/test/mlir-translate/emit-class-neg-external.mlir
index 347992c467759..c34a1652abd3f 100644
--- a/mlir/test/mlir-translate/emit-class-neg-external.mlir
+++ b/mlir/test/mlir-translate/emit-class-neg-external.mlir
@@ -1,8 +1,8 @@
 /// An external function - has no body
-// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s 2>&1 | FileCheck %s
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
 
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
   emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]}
 }
 
-// CHECK: Warning: Cannot process external function 'extern_func'. It lacks a body, and attempting to continue would lead to errors due to missing argument details.
\ No newline at end of file
+// CHECK: Warning: Cannot process external function 'extern_func'. This functionOp lacks a body so we will skip over it.
diff --git a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
index 77e89da2f0a4f..6d43fa953a946 100644
--- a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
+++ b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
@@ -1,7 +1,7 @@
 /// The function has no argument attributes
 // RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=ArgAttrs --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
 
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
   emitc.func @foo(%arg0 : i32) {
     emitc.call_opaque "bar" (%arg0) : (i32) -> ()
     emitc.return
@@ -12,4 +12,4 @@ module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted."
 // CHECK-NEXT: public: 
 // CHECK-NEXT:   int32_t v1;
 // CHECK-EMPTY: 
-// CHECK-NEXT:   std::map<std::string, char*> _buffer_map {
\ No newline at end of file
+// CHECK-NEXT:   std::map<std::string, char*> _buffer_map {
diff --git a/mlir/test/mlir-translate/emit-class.mlir b/mlir/test/mlir-translate/emit-class.mlir
index 9fa69f652df6c..2779cb315ed41 100644
--- a/mlir/test/mlir-translate/emit-class.mlir
+++ b/mlir/test/mlir-translate/emit-class.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s  
 
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
   emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
     %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
     %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
@@ -20,7 +20,8 @@ module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted."
 // CHECK-NEXT:   float v2[1];
 // CHECK-NEXT:   float v3[1];
 // CHECK-EMPTY:
-// CHECK-NEXT:   std::map<std::string, char*> _buffer_map { { "another_feature", reinterpret_cast<char*>(v1) }, { "some_feature", reinterpret_cast<char*>(v2) }, { "output_0", reinterpret_cast<char*>(v3) }, };
+// CHECK-NEXT:   std::map<std::string, char*> _buffer_map { { "another_feature", reinterpret_cast<char*>(v1) }, 
+// CHECK-SAME: { "some_feature", reinterpret_cast<char*>(v2) }, { "output_0", reinterpret_cast<char*>(v3) }, };
 // CHECK-NEXT:   char* getBufferForName(const std::string& name) const {
 // CHECK-NEXT:      auto it = _buffer_map.find(name);
 // CHECK-NEXT:      return (it == _buffer_map.end()) ? nullptr : it->second;

>From 2dc314891040dfa94182f21fc874ebbbee186d1b Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 29 May 2025 17:51:59 +0000
Subject: [PATCH 06/20] Clarifying TODO messages

---
 mlir/lib/Target/Cpp/TranslateToCpp.cpp | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index e09ed0a142725..aa85a03b84885 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1233,6 +1233,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
     if (functionOp.isExternal()) {
       // TODO: Determine the best long-term strategy for external functions.
       // Currently, we're skipping over this functionOp.
+      // We have considered using emitWarning() which would return
+      // InFlightDiagnostic which seems can be automatically converted to LogicalResult since
+      // this is done in emitAttributes where emitError is converted to LogicalResult. However, it requires that we pass in a
+      // location which at first glance we don't have in this scope. Open to
+      // further discussion on this.
       os << "Warning: Cannot process external function '"
          << functionOp.getName() << "'. "
          << "This functionOp lacks a body so we will skip over it.";
@@ -1255,6 +1260,9 @@ static LogicalResult printOperation(CppEmitter &emitter,
   if (failed(emitter.emitTypes(functionOp.getLoc(),
                                functionOp.getFunctionType().getResults())))
     return failure();
+  // TODO: We may wanna consider having the name of the function be execute in
+  // the case that we want to emit a class instead of main. Leaving as is for
+  // now to make the change smaller.
   os << " " << functionOp.getName();
 
   os << "(";

>From bf4f1cd010631adc295210bbb98a32178bad8f34 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 29 May 2025 19:01:55 +0000
Subject: [PATCH 07/20] Formatting issues

---
 mlir/lib/Target/Cpp/TranslateToCpp.cpp | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index aa85a03b84885..fc550d24113e3 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -70,8 +70,8 @@ inline LogicalResult interleaveCommaWithError(const Container &c,
 
 template <typename Container, typename UnaryFunctor>
 inline LogicalResult interleaveWithNewLineWithError(const Container &c,
-                                              raw_ostream &os,
-                                              UnaryFunctor eachFn) {
+                                                    raw_ostream &os,
+                                                    UnaryFunctor eachFn) {
   return interleaveWithError(c.begin(), c.end(), eachFn,
                              [&]() { os << ";\n"; });
 }
@@ -1234,8 +1234,9 @@ static LogicalResult printOperation(CppEmitter &emitter,
       // TODO: Determine the best long-term strategy for external functions.
       // Currently, we're skipping over this functionOp.
       // We have considered using emitWarning() which would return
-      // InFlightDiagnostic which seems can be automatically converted to LogicalResult since
-      // this is done in emitAttributes where emitError is converted to LogicalResult. However, it requires that we pass in a
+      // InFlightDiagnostic which seems can be automatically converted to
+      // LogicalResult since this is done in emitAttributes where emitError is
+      // converted to LogicalResult. However, it requires that we pass in a
       // location which at first glance we don't have in this scope. Open to
       // further discussion on this.
       os << "Warning: Cannot process external function '"

>From e7f208437bc44f211a50c152809499b81ae40c98 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?=
 =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?=
 =?UTF-8?q?=E3=83=B3=29?= <clementval at gmail.com>
Date: Thu, 22 May 2025 08:24:18 -0700
Subject: [PATCH 08/20] [flang][rt] Fix the use of kNoAsyncId -> kNoAsyncObject
 (#141079)


>From 20d9f7a4e416f9940430ef1aa99f0c6e3995543a Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 6 Jun 2025 16:12:33 +0000
Subject: [PATCH 09/20] Adding ClassOp, FieldOp, GetFieldOp to allow for a
 transform from func to class

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   | 116 ++++++++++++++++++
 .../mlir/Dialect/EmitC/Transforms/Passes.h    |   1 +
 .../mlir/Dialect/EmitC/Transforms/Passes.td   |  11 ++
 .../Dialect/EmitC/Transforms/Transforms.h     |   6 +
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           |  73 +++++++++++
 .../Dialect/EmitC/Transforms/CMakeLists.txt   |   1 +
 .../EmitC/Transforms/ConvertFuncToClass.cpp   |  70 +++++++++++
 .../Dialect/EmitC/Transforms/Transforms.cpp   |  77 ++++++++++++
 .../Dialect/EmitC/convert_func_to_class.mlir  |  15 +++
 9 files changed, 370 insertions(+)
 create mode 100644 mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
 create mode 100644 mlir/test/Dialect/EmitC/convert_func_to_class.mlir

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index e53d3e45875d5..ea6af41ee2901 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1572,4 +1572,120 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
   let hasVerifier = 1;
 }
 
+def EmitC_ClassOp : EmitC_Op<"class", [AutomaticAllocationScope,
+                                       IsolatedFromAbove, OpAsmOpInterface]> {
+  let summary =
+      "Represents a C++ class definition, encapsulating fields and methods.";
+
+  let description = [{
+    The `emitc.class` operation defines a C++ class, acting as a container
+    for its data fields (`emitc.variable`) and methods (`emitc.func`).
+    It creates a distinct scope, isolating its contents from the surrounding
+    MLIR region, similar to how C++ classes encapsulate their internals.
+
+    Example:
+    ```mlir
+    emitc.class @MyModelClass {
+        emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
+        emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>>
+        emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
+
+        emitc.func @main() attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+          %c0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+
+          %some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
+          %another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
+          %output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
+
+          %v1 = subscript %some_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
+          %v1_val = load %v1 : !emitc.lvalue<f32> -> f32
+
+          %v2 = subscript %another_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
+          %v2_val = load %v2 : !emitc.lvalue<f32> -> f32
+
+          %v3_val = add %v1_val, %v2_val : (f32, f32) -> f32
+
+          %output_lvalue = subscript %output_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
+          assign %v3_val, %output_lvalue : (f32, !emitc.lvalue<f32>) -> ()
+
+          return
+        }
+      }
+    }
+
+    ```
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name);
+
+  let regions = (region AnyRegion:$body);
+
+  let builders = [];
+
+  let extraClassDeclaration = [{
+    // Returns the body block containing class members and methods.
+    Block &getBlock();
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+
+  let assemblyFormat = "`class` $sym_name attr-dict-with-keyword $body";
+}
+
+def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
+  let summary = "A field within a class";
+  let description = [{
+    The `emitc.field` operation declares a named field within an `emitc.class`
+    operation. The field's type must be an EmitC type. An optional initial value can be provided.
+
+    Example with initial values:
+
+    ```mlir
+    emitc.class @MyModelClass {
+      emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<0.0> : !emitc.f32
+      emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<1.0> : !emitc.f32
+      emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
+    }
+    ```
+    Example without initial value:
+    ```mlir
+    emitc.class @MyModelClass {
+      emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
+    }
+    ```
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
+      OptionalAttr<AnyAttr>:$initial_value);
+
+  let assemblyFormat = "$sym_name `:` $type (`=` $initial_value^)? attr-dict";
+
+  let hasVerifier = 1;
+}
+
+def EmitC_GetFieldOp
+    : EmitC_Op<"get_field", [Pure, DeclareOpInterfaceMethods<
+                                       SymbolUserOpInterface>]> {
+  let summary = "Obtain access to a field within a class instance";
+  let description = [{
+     The `emitc.get_field` operation retrieves the lvalue of a
+     named field from a given class instance.
+
+     Example:
+
+     ```mlir
+     %some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
+     %another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
+     %output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
+     ```
+  }];
+
+  let arguments = (ins AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$base,
+      FlatSymbolRefAttr:$class_name, FlatSymbolRefAttr:$field_name);
+
+  let results = (outs AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$result);
+  let assemblyFormat = "$base `:` type($base) $class_name `,` $field_name `->` "
+                       "type($result) attr-dict";
+}
+
 #endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
index 5a103f181c76b..ad516f93808f8 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
@@ -15,6 +15,7 @@ namespace mlir {
 namespace emitc {
 
 #define GEN_PASS_DECL_FORMEXPRESSIONSPASS
+#define GEN_PASS_DECL_CONVERTFUNCTOCLASSPASS
 #include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index f46b705ca2dfe..d84c3184da777 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -20,4 +20,15 @@ def FormExpressionsPass : Pass<"form-expressions"> {
   let dependentDialects = ["emitc::EmitCDialect"];
 }
 
+def ConvertFuncToClassPass
+    : Pass<"convert-emitc-func-to-class", "mlir::emitc::FuncOp"> {
+  let summary = "Convert functions to classes, using arguments as fields.";
+  let description = [{
+    This pass transforms `emitc.func` operations into `emitc.class` operations.
+    Function arguments become fields of the class, and the function body is moved
+    to a new `execute` method within the class.
+  }];
+  let dependentDialects = ["emitc::EmitCDialect"];
+}
+
 #endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index 2574acd7d48e0..49dec99938a44 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -28,6 +28,12 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
 /// Populates `patterns` with expression-related patterns.
 void populateExpressionPatterns(RewritePatternSet &patterns);
 
+//===----------------------------------------------------------------------===//
+// Convert Func to Class Transform
+//===----------------------------------------------------------------------===//
+
+ClassOp createClass(FuncOp funcOp, OpBuilder &builder);
+
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index f82b20712b8c6..3af0f16be7515 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1400,6 +1400,79 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
       builder.getNamedAttr("id", builder.getStringAttr(id)));
 }
 
+//===----------------------------------------------------------------------===//
+// FieldOp
+//===----------------------------------------------------------------------===//
+LogicalResult FieldOp::verify() {
+  if (!isSupportedEmitCType(getType())) {
+    return emitOpError("expected valid emitc type");
+  }
+
+  if (!getInitialValue().has_value()) {
+    return success();
+  }
+
+  Attribute initValue = getInitialValue().value();
+  // Check that the type of the initial value is compatible with the type of
+  // the global variable.
+  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
+    auto initialValueType = elementsAttr.getType();
+    if (!initialValueType) {
+      return emitOpError("initial value attribute must have a type");
+    }
+    auto fieldType = getType();
+    if (initialValueType != fieldType) {
+      if (auto lvalueType = dyn_cast<LValueType>(fieldType)) {
+        auto innerFieldType = lvalueType.getValueType();
+        if (innerFieldType != initialValueType) {
+          return emitOpError("initial value type ")
+                 << initialValueType << " is not compatible with field type "
+                 << fieldType << " its inner type " << innerFieldType;
+        }
+
+      } else {
+        return emitOpError("initial value type ")
+               << initialValueType << " is not compatible with field type "
+               << fieldType;
+      }
+    }
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetFieldOp
+//===----------------------------------------------------------------------===//
+LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  auto classNameAttr = getClassNameAttr();
+  auto fieldNameAttr = getFieldNameAttr();
+  if (!classNameAttr || !fieldNameAttr) {
+    return emitError("class and field name attributes are mandatory");
+  }
+  StringRef className = classNameAttr.getValue();
+  StringRef fieldName = fieldNameAttr.getValue();
+
+  auto fieldOp =
+      symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, getFieldNameAttr());
+
+  if (!fieldOp) {
+    return emitOpError("field '")
+           << fieldName << "' not found in class '" << className << "'";
+  }
+
+  Type getFieldResultType = getResult().getType();
+  Type fieldType = fieldOp.getType();
+
+  if (fieldType != getFieldResultType) {
+    return emitOpError("result type ")
+           << getFieldResultType << " does not match field '" << fieldName
+           << "' type " << fieldType;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
index 19b80b22bd84b..4c2525c2a5b00 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIREmitCTransforms
   Transforms.cpp
   FormExpressions.cpp
   TypeConversions.cpp
+  ConvertFuncToClass.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms
diff --git a/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
new file mode 100644
index 0000000000000..5beb62f675414
--- /dev/null
+++ b/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
@@ -0,0 +1,70 @@
+//===- ConvertFuncToClass.cpp - Convert functions to classes -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace emitc {
+
+#define GEN_PASS_DEF_CONVERTFUNCTOCLASSPASS
+#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
+
+namespace {
+
+struct ConvertFuncToClassPass
+    : public impl::ConvertFuncToClassPassBase<ConvertFuncToClassPass> {
+  void runOnOperation() override {
+    emitc::FuncOp funcOp = getOperation();
+    MLIRContext *context = funcOp->getContext();
+
+    // Wrap each C operator op with an expression op.
+    OpBuilder builder(context);
+    createClass(funcOp, builder);
+
+    // // Create the new function inside the class
+    // auto funcType = FunctionType::get(funcOp.getContext(),
+    // funcOp.getFunctionType().getInputs(),
+    // funcOp.getFunctionType().getResults()); auto newFuncOp =
+    // builder.create<emitc::FuncOp>(
+    //     funcOp.getLoc(),builder.getStringAttr("execute"), funcType );
+
+    // builder.createBlock(&newFuncOp.getBody());
+    // builder.setInsertionPointToStart(&newFuncOp.getBody().front());
+
+    // // 7. Remap original arguments to field pointers
+    // IRMapping mapper;
+
+    // // 8. move or clone operations from original function
+    // for (Operation &opToClone :
+    // llvm::make_early_inc_range(funcOp.getBody().front())) {
+    //     if (isa<emitc::ConstantOp>(opToClone) ||
+    //         isa<emitc::SubscriptOp>(opToClone) ||
+    //         isa<emitc::LoadOp>(opToClone) ||
+    //         isa<emitc::AddOp>(opToClone) ||
+    //         isa<emitc::AssignOp>(opToClone) ||
+    //         isa<emitc::ReturnOp>(opToClone )) {
+    //     builder.clone(opToClone, mapper);
+    //     } else  {
+    //     opToClone.emitOpError("Unsupported operation found");
+    //     }
+    // }
+    // if (funcOp->use_empty()) funcOp->erase();
+  }
+};
+
+} // namespace
+
+} // namespace emitc
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index 87350ecdceaaa..a9687e4d1d187 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -43,6 +43,83 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
   return expressionOp;
 }
 
+ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
+  builder.setInsertionPoint(funcOp);
+
+  // 2. Create the class
+  auto classOp = builder.create<emitc::ClassOp>(
+      funcOp.getLoc(), builder.getStringAttr("MyModelClass"));
+
+  // Create a block inside the class body and set insertion point
+  builder.createBlock(&classOp.getBody());
+  builder.setInsertionPointToStart(&classOp.getBody().front());
+
+  // 3. Extract input/output names from function arguments
+  SmallVector<std::pair<StringRef, Type>> fields;
+  llvm::SmallDenseMap<Value, Value> argToFieldMap;
+
+  auto argAttrs = funcOp.getArgAttrs();
+  if (argAttrs) {
+    for (const auto [arg, val] : zip(*argAttrs, funcOp.getArguments())) {
+      if (auto da = dyn_cast<mlir::DictionaryAttr>(arg)) {
+        auto nv = da.getNamed("tf_saved_model.index_path")->getValue();
+        auto fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
+        auto fieldType = emitc::LValueType::get(emitc::PointerType::get(
+            dyn_cast_or_null<emitc::ArrayType>(val.getType())
+                .getElementType()));
+        fields.push_back({fieldName.str(), fieldType});
+
+        // 4.Create the class fields
+        auto typeAttr = TypeAttr::get(val.getType());
+        mlir::Attribute emptyAttr = builder.getAttr<mlir::UnitAttr>();
+        auto dictAttr = DictionaryAttr::get(
+            builder.getContext(),
+            {builder.getNamedAttr(fieldName.str(), emptyAttr)});
+        builder.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
+                                       /* attributes*/ dictAttr);
+        // 5. Get the pointers to the class fields
+        auto pointer = emitc::PointerType::get(
+            dyn_cast_or_null<emitc::ArrayType>(val.getType()).getElementType());
+        auto ptr = builder.create<emitc::GetFieldOp>(
+            funcOp.getLoc(), pointer, val, "MyModelClass", fieldName);
+        argToFieldMap[val] = ptr;
+      }
+    }
+  }
+
+  // Create the new function inside the class
+  auto funcContext = funcOp.getContext();
+  auto inputTypes = funcOp.getFunctionType().getInputs();
+  auto results = funcOp.getFunctionType().getResults();
+  auto funcType = FunctionType::get(funcContext, inputTypes, results);
+  auto loc = funcOp.getLoc();
+  auto newFuncOp = builder.create<emitc::FuncOp>(
+      loc, builder.getStringAttr("execute"), funcType);
+
+  builder.createBlock(&newFuncOp.getBody());
+  builder.setInsertionPointToStart(&newFuncOp.getBody().front());
+
+  // 7. Remap original arguments to field pointers
+  IRMapping mapper;
+
+  // 8. move or clone operations from original function
+  auto body = llvm::make_early_inc_range(funcOp.getBody().front());
+  for (Operation &opToClone : body) {
+    if (isa<emitc::ConstantOp>(opToClone) ||
+        isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) ||
+        isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) ||
+        isa<emitc::ReturnOp>(opToClone)) {
+      builder.clone(opToClone, mapper);
+    } else {
+      opToClone.emitOpError("Unsupported operation found");
+    }
+  }
+
+  // if (funcOp->use_empty()) funcOp->erase();
+
+  return classOp;
+}
+
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/test/Dialect/EmitC/convert_func_to_class.mlir b/mlir/test/Dialect/EmitC/convert_func_to_class.mlir
new file mode 100644
index 0000000000000..f9e276f059130
--- /dev/null
+++ b/mlir/test/Dialect/EmitC/convert_func_to_class.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt  %s --emitc-convert-func-to-class 
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+  emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+    %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+    %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    %2 = load %1 : <f32>
+    %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    %4 = load %3 : <f32>
+    %5 = add %2, %4 : (f32, f32) -> f32
+    %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    assign %5 : f32 to %6 : <f32>
+    return
+  }
+}

>From 8ba35d7d97c605f4d98842a667502f24df0220f0 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 6 Jun 2025 17:22:04 +0000
Subject: [PATCH 10/20] Removed unnecessary comments

---
 .../EmitC/Transforms/ConvertFuncToClass.cpp   | 32 +------------------
 .../Dialect/EmitC/Transforms/Transforms.cpp   | 25 +++++++--------
 2 files changed, 12 insertions(+), 45 deletions(-)

diff --git a/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
index 5beb62f675414..fe8a05d39e1df 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
@@ -29,42 +29,12 @@ struct ConvertFuncToClassPass
     emitc::FuncOp funcOp = getOperation();
     MLIRContext *context = funcOp->getContext();
 
-    // Wrap each C operator op with an expression op.
     OpBuilder builder(context);
     createClass(funcOp, builder);
-
-    // // Create the new function inside the class
-    // auto funcType = FunctionType::get(funcOp.getContext(),
-    // funcOp.getFunctionType().getInputs(),
-    // funcOp.getFunctionType().getResults()); auto newFuncOp =
-    // builder.create<emitc::FuncOp>(
-    //     funcOp.getLoc(),builder.getStringAttr("execute"), funcType );
-
-    // builder.createBlock(&newFuncOp.getBody());
-    // builder.setInsertionPointToStart(&newFuncOp.getBody().front());
-
-    // // 7. Remap original arguments to field pointers
-    // IRMapping mapper;
-
-    // // 8. move or clone operations from original function
-    // for (Operation &opToClone :
-    // llvm::make_early_inc_range(funcOp.getBody().front())) {
-    //     if (isa<emitc::ConstantOp>(opToClone) ||
-    //         isa<emitc::SubscriptOp>(opToClone) ||
-    //         isa<emitc::LoadOp>(opToClone) ||
-    //         isa<emitc::AddOp>(opToClone) ||
-    //         isa<emitc::AssignOp>(opToClone) ||
-    //         isa<emitc::ReturnOp>(opToClone )) {
-    //     builder.clone(opToClone, mapper);
-    //     } else  {
-    //     opToClone.emitOpError("Unsupported operation found");
-    //     }
-    // }
-    // if (funcOp->use_empty()) funcOp->erase();
   }
 };
 
 } // namespace
 
 } // namespace emitc
-} // namespace mlir
\ No newline at end of file
+} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index a9687e4d1d187..8471631c6e60b 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -46,15 +46,12 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
 ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
   builder.setInsertionPoint(funcOp);
 
-  // 2. Create the class
   auto classOp = builder.create<emitc::ClassOp>(
       funcOp.getLoc(), builder.getStringAttr("MyModelClass"));
 
-  // Create a block inside the class body and set insertion point
   builder.createBlock(&classOp.getBody());
   builder.setInsertionPointToStart(&classOp.getBody().front());
 
-  // 3. Extract input/output names from function arguments
   SmallVector<std::pair<StringRef, Type>> fields;
   llvm::SmallDenseMap<Value, Value> argToFieldMap;
 
@@ -69,7 +66,6 @@ ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
                 .getElementType()));
         fields.push_back({fieldName.str(), fieldType});
 
-        // 4.Create the class fields
         auto typeAttr = TypeAttr::get(val.getType());
         mlir::Attribute emptyAttr = builder.getAttr<mlir::UnitAttr>();
         auto dictAttr = DictionaryAttr::get(
@@ -77,17 +73,19 @@ ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
             {builder.getNamedAttr(fieldName.str(), emptyAttr)});
         builder.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
                                        /* attributes*/ dictAttr);
-        // 5. Get the pointers to the class fields
-        auto pointer = emitc::PointerType::get(
-            dyn_cast_or_null<emitc::ArrayType>(val.getType()).getElementType());
-        auto ptr = builder.create<emitc::GetFieldOp>(
-            funcOp.getLoc(), pointer, val, "MyModelClass", fieldName);
-        argToFieldMap[val] = ptr;
+
+        // TODO: From my current understanding, we need to instantiate a class
+        // so we can get the pointers from .field but we can't do that in here
+        // so I'm unsure how I can rewrite the following line to ensure
+        // GetFieldOp works correctly. auto pointer =
+        // emitc::PointerType::get(dyn_cast_or_null<emitc::ArrayType>(val.getType()).getElementType());
+        // auto ptr = builder.create<emitc::GetFieldOp>(funcOp.getLoc(),
+        // pointer, val, "MyModelClass", fieldName);
+        argToFieldMap[val] = nullptr;
       }
     }
   }
 
-  // Create the new function inside the class
   auto funcContext = funcOp.getContext();
   auto inputTypes = funcOp.getFunctionType().getInputs();
   auto results = funcOp.getFunctionType().getResults();
@@ -99,10 +97,8 @@ ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
   builder.createBlock(&newFuncOp.getBody());
   builder.setInsertionPointToStart(&newFuncOp.getBody().front());
 
-  // 7. Remap original arguments to field pointers
   IRMapping mapper;
 
-  // 8. move or clone operations from original function
   auto body = llvm::make_early_inc_range(funcOp.getBody().front());
   for (Operation &opToClone : body) {
     if (isa<emitc::ConstantOp>(opToClone) ||
@@ -115,7 +111,8 @@ ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
     }
   }
 
-  // if (funcOp->use_empty()) funcOp->erase();
+  // TODO: Need to erase the funcOp after all this. Using funcOp->erase raises
+  // errors:
 
   return classOp;
 }

>From f3acc4faf11d2109b1e9b5f9ae50705b2e1ffb8d Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 12 Jun 2025 21:36:24 +0000
Subject: [PATCH 11/20] rewritten

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   |  87 +++++-----
 .../mlir/Dialect/EmitC/Transforms/Passes.h    |   2 +-
 .../mlir/Dialect/EmitC/Transforms/Passes.td   |   5 +-
 .../Dialect/EmitC/Transforms/Transforms.h     |   7 +-
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           |  38 +++--
 .../Dialect/EmitC/Transforms/CMakeLists.txt   |   2 +-
 .../EmitC/Transforms/ConvertFuncToClass.cpp   |  40 -----
 .../Dialect/EmitC/Transforms/Transforms.cpp   |  75 ---------
 .../EmitC/Transforms/WrapFuncInClass.cpp      | 150 ++++++++++++++++++
 ...ass.mlir => wrap_emitc_func_in_class.mlir} |   2 +-
 10 files changed, 214 insertions(+), 194 deletions(-)
 delete mode 100644 mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
 create mode 100644 mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
 rename mlir/test/Dialect/EmitC/{convert_func_to_class.mlir => wrap_emitc_func_in_class.mlir} (95%)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index ea6af41ee2901..08235ca701e2a 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1572,46 +1572,45 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
   let hasVerifier = 1;
 }
 
-def EmitC_ClassOp : EmitC_Op<"class", [AutomaticAllocationScope,
-                                       IsolatedFromAbove, OpAsmOpInterface]> {
+def EmitC_ClassOp
+    : EmitC_Op<"class", [AutomaticAllocationScope, IsolatedFromAbove,
+                         OpAsmOpInterface, SymbolTable,
+                         Symbol]#GraphRegionNoTerminator.traits> {
   let summary =
       "Represents a C++ class definition, encapsulating fields and methods.";
 
+  // FIX WORDING
   let description = [{
     The `emitc.class` operation defines a C++ class, acting as a container
     for its data fields (`emitc.variable`) and methods (`emitc.func`).
     It creates a distinct scope, isolating its contents from the surrounding
     MLIR region, similar to how C++ classes encapsulate their internals.
+    All the class members need to be default-initializable.
 
     Example:
     ```mlir
-    emitc.class @MyModelClass {
-        emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
-        emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>>
-        emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
-
-        emitc.func @main() attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
-          %c0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-
-          %some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
-          %another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
-          %output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
-
-          %v1 = subscript %some_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
-          %v1_val = load %v1 : !emitc.lvalue<f32> -> f32
-
-          %v2 = subscript %another_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
-          %v2_val = load %v2 : !emitc.lvalue<f32> -> f32
-
-          %v3_val = add %v1_val, %v2_val : (f32, f32) -> f32
-
-          %output_lvalue = subscript %output_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
-          assign %v3_val, %output_lvalue : (f32, !emitc.lvalue<f32>) -> ()
-
-          return
-        }
+    emitc.class @MymainClass {
+      emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
+      emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
+      emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
+
+      emitc.func @execute() {
+        %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+
+        %1 = get_field @another_feature : !emitc.array<1xf32>
+        %2 = get_field @some_feature : !emitc.array<1xf32>
+        %3 = get_field @output_0 : !emitc.array<1xf32>
+
+        %4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+        %5 = load %4 : <f32>
+        %6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+        %7 = load %6 : <f32>
+        %8 = add %5, %7 : (f32, f32) -> f32
+        %9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+        assign %8 : f32 to %9 : <f32>
+        return
       }
-    }
+  }
 
     ```
   }];
@@ -1629,7 +1628,7 @@ def EmitC_ClassOp : EmitC_Op<"class", [AutomaticAllocationScope,
 
   let hasCustomAssemblyFormat = 1;
 
-  let assemblyFormat = "`class` $sym_name attr-dict-with-keyword $body";
+  let assemblyFormat = [{ $sym_name attr-dict-with-keyword $body }];
 }
 
 def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
@@ -1642,15 +1641,9 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
 
     ```mlir
     emitc.class @MyModelClass {
-      emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<0.0> : !emitc.f32
-      emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<1.0> : !emitc.f32
-      emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
-    }
-    ```
-    Example without initial value:
-    ```mlir
-    emitc.class @MyModelClass {
-      emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
+      emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
+      emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
+      emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
     }
     ```
   }];
@@ -1658,7 +1651,8 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
   let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
       OptionalAttr<AnyAttr>:$initial_value);
 
-  let assemblyFormat = "$sym_name `:` $type (`=` $initial_value^)? attr-dict";
+  let assemblyFormat =
+      [{ $sym_name `:` $type (`=` $initial_value^)? attr-dict}];
 
   let hasVerifier = 1;
 }
@@ -1674,18 +1668,15 @@ def EmitC_GetFieldOp
      Example:
 
      ```mlir
-     %some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
-     %another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
-     %output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
+    %some_ptr = emitc.get_field @some_feature : !emitc.array<1xf32> 
+    %another_ptr = emitc.get_field @another_feature : !emitc.array<1xf32> 
+    %output_ptr = emitc.get_field @output_0 : !emitc.array<1xf32> 
      ```
   }];
 
-  let arguments = (ins AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$base,
-      FlatSymbolRefAttr:$class_name, FlatSymbolRefAttr:$field_name);
-
-  let results = (outs AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$result);
-  let assemblyFormat = "$base `:` type($base) $class_name `,` $field_name `->` "
-                       "type($result) attr-dict";
+  let arguments = (ins FlatSymbolRefAttr:$field_name);
+  let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result);
+  let assemblyFormat = "$field_name `:` type($result) attr-dict";
 }
 
 #endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
index ad516f93808f8..1af4aa06fa811 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
@@ -15,7 +15,7 @@ namespace mlir {
 namespace emitc {
 
 #define GEN_PASS_DECL_FORMEXPRESSIONSPASS
-#define GEN_PASS_DECL_CONVERTFUNCTOCLASSPASS
+#define GEN_PASS_DECL_WRAPFUNCINCLASSPASS
 #include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index d84c3184da777..1aa95b32217c1 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -20,9 +20,8 @@ def FormExpressionsPass : Pass<"form-expressions"> {
   let dependentDialects = ["emitc::EmitCDialect"];
 }
 
-def ConvertFuncToClassPass
-    : Pass<"convert-emitc-func-to-class", "mlir::emitc::FuncOp"> {
-  let summary = "Convert functions to classes, using arguments as fields.";
+def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
+  let summary = "Wrap functions in classes, using arguments as fields.";
   let description = [{
     This pass transforms `emitc.func` operations into `emitc.class` operations.
     Function arguments become fields of the class, and the function body is moved
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index 49dec99938a44..bdf6d0985e6db 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -28,11 +28,8 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
 /// Populates `patterns` with expression-related patterns.
 void populateExpressionPatterns(RewritePatternSet &patterns);
 
-//===----------------------------------------------------------------------===//
-// Convert Func to Class Transform
-//===----------------------------------------------------------------------===//
-
-ClassOp createClass(FuncOp funcOp, OpBuilder &builder);
+/// Populates 'patterns' with func-related patterns.
+void populateFuncPatterns(RewritePatternSet &patterns);
 
 } // namespace emitc
 } // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 3af0f16be7515..695ad3ee3bb0a 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1408,32 +1408,32 @@ LogicalResult FieldOp::verify() {
     return emitOpError("expected valid emitc type");
   }
 
-  if (!getInitialValue().has_value()) {
+  if (!getInitialValue()) {
     return success();
   }
 
-  Attribute initValue = getInitialValue().value();
+  Attribute initValue = *getInitialValue();
   // Check that the type of the initial value is compatible with the type of
   // the global variable.
-  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
-    auto initialValueType = elementsAttr.getType();
+  if (ElementsAttr elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
+    Type initialValueType = elementsAttr.getType();
     if (!initialValueType) {
       return emitOpError("initial value attribute must have a type");
     }
-    auto fieldType = getType();
+    Type fieldType = getType();
     if (initialValueType != fieldType) {
-      if (auto lvalueType = dyn_cast<LValueType>(fieldType)) {
-        auto innerFieldType = lvalueType.getValueType();
+      if (LValueType lvalueType = dyn_cast<LValueType>(fieldType)) {
+        Type innerFieldType = lvalueType.getValueType();
         if (innerFieldType != initialValueType) {
           return emitOpError("initial value type ")
-                 << initialValueType << " is not compatible with field type "
-                 << fieldType << " its inner type " << innerFieldType;
+                 << initialValueType << " is not compatible with field type '"
+                 << fieldType << "' its inner type '" << innerFieldType << "'";
         }
 
       } else {
-        return emitOpError("initial value type ")
-               << initialValueType << " is not compatible with field type "
-               << fieldType;
+        return emitOpError("initial value type '")
+               << initialValueType << "' is not compatible with field type '"
+               << fieldType << "'";
       }
     }
   }
@@ -1445,20 +1445,18 @@ LogicalResult FieldOp::verify() {
 // GetFieldOp
 //===----------------------------------------------------------------------===//
 LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
-  auto classNameAttr = getClassNameAttr();
-  auto fieldNameAttr = getFieldNameAttr();
-  if (!classNameAttr || !fieldNameAttr) {
-    return emitError("class and field name attributes are mandatory");
+  mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
+  if (!fieldNameAttr) {
+    return emitError("field name attribute is mandatory");
   }
-  StringRef className = classNameAttr.getValue();
+
   StringRef fieldName = fieldNameAttr.getValue();
 
-  auto fieldOp =
+  FieldOp fieldOp =
       symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, getFieldNameAttr());
 
   if (!fieldOp) {
-    return emitOpError("field '")
-           << fieldName << "' not found in class '" << className << "'";
+    return emitOpError("field '") << fieldName << "' not found in the class '";
   }
 
   Type getFieldResultType = getResult().getType();
diff --git a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
index 4c2525c2a5b00..baf67afc30072 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
@@ -2,7 +2,7 @@ add_mlir_dialect_library(MLIREmitCTransforms
   Transforms.cpp
   FormExpressions.cpp
   TypeConversions.cpp
-  ConvertFuncToClass.cpp
+  WrapFuncInClass.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms
diff --git a/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
deleted file mode 100644
index fe8a05d39e1df..0000000000000
--- a/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
+++ /dev/null
@@ -1,40 +0,0 @@
-//===- ConvertFuncToClass.cpp - Convert functions to classes -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/EmitC/IR/EmitC.h"
-#include "mlir/Dialect/EmitC/Transforms/Passes.h"
-#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace emitc {
-
-#define GEN_PASS_DEF_CONVERTFUNCTOCLASSPASS
-#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
-
-namespace {
-
-struct ConvertFuncToClassPass
-    : public impl::ConvertFuncToClassPassBase<ConvertFuncToClassPass> {
-  void runOnOperation() override {
-    emitc::FuncOp funcOp = getOperation();
-    MLIRContext *context = funcOp->getContext();
-
-    OpBuilder builder(context);
-    createClass(funcOp, builder);
-  }
-};
-
-} // namespace
-
-} // namespace emitc
-} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index 8471631c6e60b..a252924caeb62 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -10,7 +10,6 @@
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
-#include "llvm/Support/Debug.h"
 
 namespace mlir {
 namespace emitc {
@@ -43,80 +42,6 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
   return expressionOp;
 }
 
-ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
-  builder.setInsertionPoint(funcOp);
-
-  auto classOp = builder.create<emitc::ClassOp>(
-      funcOp.getLoc(), builder.getStringAttr("MyModelClass"));
-
-  builder.createBlock(&classOp.getBody());
-  builder.setInsertionPointToStart(&classOp.getBody().front());
-
-  SmallVector<std::pair<StringRef, Type>> fields;
-  llvm::SmallDenseMap<Value, Value> argToFieldMap;
-
-  auto argAttrs = funcOp.getArgAttrs();
-  if (argAttrs) {
-    for (const auto [arg, val] : zip(*argAttrs, funcOp.getArguments())) {
-      if (auto da = dyn_cast<mlir::DictionaryAttr>(arg)) {
-        auto nv = da.getNamed("tf_saved_model.index_path")->getValue();
-        auto fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
-        auto fieldType = emitc::LValueType::get(emitc::PointerType::get(
-            dyn_cast_or_null<emitc::ArrayType>(val.getType())
-                .getElementType()));
-        fields.push_back({fieldName.str(), fieldType});
-
-        auto typeAttr = TypeAttr::get(val.getType());
-        mlir::Attribute emptyAttr = builder.getAttr<mlir::UnitAttr>();
-        auto dictAttr = DictionaryAttr::get(
-            builder.getContext(),
-            {builder.getNamedAttr(fieldName.str(), emptyAttr)});
-        builder.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
-                                       /* attributes*/ dictAttr);
-
-        // TODO: From my current understanding, we need to instantiate a class
-        // so we can get the pointers from .field but we can't do that in here
-        // so I'm unsure how I can rewrite the following line to ensure
-        // GetFieldOp works correctly. auto pointer =
-        // emitc::PointerType::get(dyn_cast_or_null<emitc::ArrayType>(val.getType()).getElementType());
-        // auto ptr = builder.create<emitc::GetFieldOp>(funcOp.getLoc(),
-        // pointer, val, "MyModelClass", fieldName);
-        argToFieldMap[val] = nullptr;
-      }
-    }
-  }
-
-  auto funcContext = funcOp.getContext();
-  auto inputTypes = funcOp.getFunctionType().getInputs();
-  auto results = funcOp.getFunctionType().getResults();
-  auto funcType = FunctionType::get(funcContext, inputTypes, results);
-  auto loc = funcOp.getLoc();
-  auto newFuncOp = builder.create<emitc::FuncOp>(
-      loc, builder.getStringAttr("execute"), funcType);
-
-  builder.createBlock(&newFuncOp.getBody());
-  builder.setInsertionPointToStart(&newFuncOp.getBody().front());
-
-  IRMapping mapper;
-
-  auto body = llvm::make_early_inc_range(funcOp.getBody().front());
-  for (Operation &opToClone : body) {
-    if (isa<emitc::ConstantOp>(opToClone) ||
-        isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) ||
-        isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) ||
-        isa<emitc::ReturnOp>(opToClone)) {
-      builder.clone(opToClone, mapper);
-    } else {
-      opToClone.emitOpError("Unsupported operation found");
-    }
-  }
-
-  // TODO: Need to erase the funcOp after all this. Using funcOp->erase raises
-  // errors:
-
-  return classOp;
-}
-
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
new file mode 100644
index 0000000000000..a6f333b074a02
--- /dev/null
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -0,0 +1,150 @@
+//===- WrapFuncInClass.cpp - Wrap functions in classes --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Rewrite.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/GraphWriter.h"
+
+namespace mlir {
+namespace emitc {
+
+#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
+#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
+
+namespace {
+
+struct WrapFuncInClassPass
+    : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
+  void runOnOperation() override {
+    Operation *rootOp = getOperation();
+    MLIRContext *context = rootOp->getContext();
+
+    RewritePatternSet patterns(context);
+    populateFuncPatterns(patterns);
+
+    if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
+      return signalPassFailure();
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<emitc::EmitCDialect>();
+  }
+};
+
+} // namespace
+
+} // namespace emitc
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::emitc;
+
+static bool validOp(Operation &opToClone) {
+  return isa<emitc::ConstantOp>(opToClone) ||
+         isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) ||
+         isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) ||
+         isa<emitc::ReturnOp>(opToClone);
+}
+
+class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
+public:
+  using OpRewritePattern<emitc::FuncOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
+                                PatternRewriter &rewriter) const override {
+    if (funcOp->getParentOfType<emitc::ClassOp>()) {
+      return failure();
+    }
+    auto className = "My" + funcOp.getSymNameAttr().str() + "Class";
+    mlir::emitc::ClassOp newClassOp =
+        rewriter.create<emitc::ClassOp>(funcOp.getLoc(), className);
+
+    SmallVector<std::pair<StringAttr, TypeAttr>> fields;
+    rewriter.createBlock(&newClassOp.getBody());
+    rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
+
+    auto argAttrs = funcOp.getArgAttrs();
+
+    for (const auto &[arg, val] : (zip(*argAttrs, funcOp.getArguments()))) {
+      // FIXME:How can we avoid hardcoding this name?
+      // Should we loop through the dictionary and check for each named
+      // attribute if attr.getName().getValue().contains("tf_saved_model")
+      if (auto namedAttr = dyn_cast<mlir::DictionaryAttr>(arg).getNamed(
+              "tf_saved_model.index_path")) {
+        Attribute nv = namedAttr->getValue();
+        StringAttr fieldName =
+            cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
+        TypeAttr typeAttr = TypeAttr::get(val.getType());
+        fields.push_back({fieldName, typeAttr});
+
+        rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
+                                        /* attributes*/ arg);
+      } else
+        funcOp->emitOpError("Only Covers TF models");
+    }
+
+    rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
+    MLIRContext *funcContext = funcOp.getContext();
+    ArrayRef<Type> inputTypes = funcOp.getFunctionType().getInputs();
+    ArrayRef<Type> results = funcOp.getFunctionType().getResults();
+    FunctionType funcType = FunctionType::get(funcContext, inputTypes, results);
+    Location loc = funcOp.getLoc();
+    FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
+        loc, rewriter.getStringAttr("execute"), funcType);
+
+    rewriter.setInsertionPointToStart(newFuncOp.addEntryBlock());
+
+    std::vector<Value> newArguments;
+    for (auto [fieldName, attr] : fields) {
+      auto arg =
+          rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
+      newArguments.push_back(arg);
+    }
+
+    IRMapping mapper;
+    for (auto [oldArg, newArg] :
+         llvm::zip(funcOp.getArguments(), newArguments)) {
+      mapper.map(oldArg, newArg);
+    }
+
+    while (!newFuncOp.getArguments().empty()) {
+      if (failed(newFuncOp.eraseArgument(0))) {
+        break;
+      }
+    }
+
+    // TODO: The mapper is easier to use but cloning is more expensive than
+    // moving the body. Working on changing this portion to move the body
+    // instead
+    auto body = llvm::make_early_inc_range(funcOp.getBody().front());
+    for (Operation &opToClone : body) {
+      if (validOp(opToClone)) {
+        rewriter.clone(opToClone, mapper);
+      } else {
+        opToClone.emitOpError("Unsupported operation found");
+      }
+    }
+
+    rewriter.replaceOp(funcOp, newClassOp);
+    return funcOp->use_empty() ? success() : failure();
+  }
+};
+
+void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) {
+  patterns.add<WrapFuncInClass>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/EmitC/convert_func_to_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
similarity index 95%
rename from mlir/test/Dialect/EmitC/convert_func_to_class.mlir
rename to mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
index f9e276f059130..3a775f969ab17 100644
--- a/mlir/test/Dialect/EmitC/convert_func_to_class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt  %s --emitc-convert-func-to-class 
+// RUN: mlir-opt  %s --wrap-emitc-func-in-class
 
 module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
   emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {

>From f47f2111c669c77c8160abb8243d940b20f12b94 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 13 Jun 2025 23:48:08 +0000
Subject: [PATCH 12/20] Added tests for the wrap-emitc-func-in-class

---
 .../mlir/Dialect/EmitC/Transforms/Passes.td   |  3 +
 .../Dialect/EmitC/Transforms/Transforms.h     |  3 +-
 .../EmitC/Transforms/WrapFuncInClass.cpp      | 81 ++++++++-----------
 .../EmitC/wrap_emitc_func_in_class.mlir       | 26 +++++-
 .../EmitC/wrap_emitc_func_in_class_neg.mlir   |  8 ++
 5 files changed, 72 insertions(+), 49 deletions(-)
 create mode 100644 mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir

diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index 1aa95b32217c1..d8ebf4a613bfd 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -28,6 +28,9 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
     to a new `execute` method within the class.
   }];
   let dependentDialects = ["emitc::EmitCDialect"];
+  let options = [Option<
+      "namedAttribute", "named-attribute", "std::string", "\"\"",
+      "Name of the attribute to look for field names on function arguments">];
 }
 
 #endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index bdf6d0985e6db..11a1ad2ad2ff2 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -29,7 +29,8 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
 void populateExpressionPatterns(RewritePatternSet &patterns);
 
 /// Populates 'patterns' with func-related patterns.
-void populateFuncPatterns(RewritePatternSet &patterns);
+void populateFuncPatterns(RewritePatternSet &patterns,
+                          const std::string &namedAttribute);
 
 } // namespace emitc
 } // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index a6f333b074a02..d87af1379d96a 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -12,14 +12,13 @@
 #include "mlir/Dialect/EmitC/Transforms/Transforms.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeRange.h"
-#include "mlir/IR/ValueRange.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/LogicalResult.h"
 
 namespace mlir {
 namespace emitc {
@@ -31,12 +30,13 @@ namespace {
 
 struct WrapFuncInClassPass
     : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
+  using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
   void runOnOperation() override {
     Operation *rootOp = getOperation();
     MLIRContext *context = rootOp->getContext();
 
     RewritePatternSet patterns(context);
-    populateFuncPatterns(patterns);
+    populateFuncPatterns(patterns, namedAttribute);
 
     if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
       return signalPassFailure();
@@ -54,16 +54,13 @@ struct WrapFuncInClassPass
 using namespace mlir;
 using namespace mlir::emitc;
 
-static bool validOp(Operation &opToClone) {
-  return isa<emitc::ConstantOp>(opToClone) ||
-         isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) ||
-         isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) ||
-         isa<emitc::ReturnOp>(opToClone);
-}
-
 class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
+private:
+  std::string attributeName;
+
 public:
-  using OpRewritePattern<emitc::FuncOp>::OpRewritePattern;
+  WrapFuncInClass(MLIRContext *context, const std::string &attrName)
+      : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
 
   LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
@@ -79,23 +76,25 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
     rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
 
     auto argAttrs = funcOp.getArgAttrs();
-
-    for (const auto &[arg, val] : (zip(*argAttrs, funcOp.getArguments()))) {
-      // FIXME:How can we avoid hardcoding this name?
-      // Should we loop through the dictionary and check for each named
-      // attribute if attr.getName().getValue().contains("tf_saved_model")
-      if (auto namedAttr = dyn_cast<mlir::DictionaryAttr>(arg).getNamed(
-              "tf_saved_model.index_path")) {
-        Attribute nv = namedAttr->getValue();
-        StringAttr fieldName =
-            cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
-        TypeAttr typeAttr = TypeAttr::get(val.getType());
-        fields.push_back({fieldName, typeAttr});
-
-        rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
-                                        /* attributes*/ arg);
-      } else
-        funcOp->emitOpError("Only Covers TF models");
+    if (argAttrs) {
+      for (const auto &[arg, val] :
+           llvm::zip(*argAttrs, funcOp.getArguments())) {
+        if (auto namedAttr =
+                dyn_cast<mlir::DictionaryAttr>(arg).getNamed(attributeName)) {
+          Attribute nv = namedAttr->getValue();
+          StringAttr fieldName =
+              cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
+          TypeAttr typeAttr = TypeAttr::get(val.getType());
+          fields.push_back({fieldName, typeAttr});
+
+          rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
+                                          /* attributes*/ arg);
+        }
+      }
+    } else {
+      funcOp->emitOpError("arguments should have attributes so we can "
+                          "initialize class fields.");
+      return failure();
     }
 
     rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
@@ -107,8 +106,10 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
     FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
         loc, rewriter.getStringAttr("execute"), funcType);
 
-    rewriter.setInsertionPointToStart(newFuncOp.addEntryBlock());
+    rewriter.createBlock(&newFuncOp.getBody());
+    newFuncOp.getBody().takeBody(funcOp.getBody());
 
+    rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
     std::vector<Value> newArguments;
     for (auto [fieldName, attr] : fields) {
       auto arg =
@@ -116,10 +117,9 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
       newArguments.push_back(arg);
     }
 
-    IRMapping mapper;
     for (auto [oldArg, newArg] :
-         llvm::zip(funcOp.getArguments(), newArguments)) {
-      mapper.map(oldArg, newArg);
+         llvm::zip(newFuncOp.getArguments(), newArguments)) {
+      rewriter.replaceAllUsesWith(oldArg, newArg);
     }
 
     while (!newFuncOp.getArguments().empty()) {
@@ -128,23 +128,12 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
       }
     }
 
-    // TODO: The mapper is easier to use but cloning is more expensive than
-    // moving the body. Working on changing this portion to move the body
-    // instead
-    auto body = llvm::make_early_inc_range(funcOp.getBody().front());
-    for (Operation &opToClone : body) {
-      if (validOp(opToClone)) {
-        rewriter.clone(opToClone, mapper);
-      } else {
-        opToClone.emitOpError("Unsupported operation found");
-      }
-    }
-
     rewriter.replaceOp(funcOp, newClassOp);
     return funcOp->use_empty() ? success() : failure();
   }
 };
 
-void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) {
-  patterns.add<WrapFuncInClass>(patterns.getContext());
+void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns,
+                                       const std::string &namedAttribute) {
+  patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute);
 }
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
index 3a775f969ab17..e0fa78a3dc459 100644
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt  %s --wrap-emitc-func-in-class
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=tf_saved_model.index_path' %s | FileCheck %s
 
 module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
-  emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+  emitc.func @Model(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
     %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
     %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
     %2 = load %1 : <f32>
@@ -13,3 +13,25 @@ module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted."
     return
   }
 }
+
+// CHECK: module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+// CHECK:   emitc.class @MyModelClass {
+// CHECK:     emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
+// CHECK:     emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
+// CHECK:     emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
+// CHECK:     emitc.func @execute() {
+// CHECK:       %{{[0-9]+}} = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+// CHECK:       %{{[0-9]+}} = get_field @another_feature : !emitc.array<1xf32>
+// CHECK:       %{{[0-9]+}} = get_field @some_feature : !emitc.array<1xf32>
+// CHECK:       %{{[0-9]+}} = get_field @output_0 : !emitc.array<1xf32>
+// CHECK:       %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK:       %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
+// CHECK:       %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK:       %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
+// CHECK:       %{{[0-9]+}} = add %{{[0-9]+}}, %{{[0-9]+}} : (f32, f32) -> f32
+// CHECK:       %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK:       assign %{{[0-9]+}} : f32 to %{{[0-9]+}} : <f32>
+// CHECK:       return
+// CHECK:     }
+// CHECK:   }
+// CHECK: }
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir
new file mode 100644
index 0000000000000..578e8f6c21ebf
--- /dev/null
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=tf_saved_model.index_path' %s 2>&1 | FileCheck %s
+
+emitc.func @foo(%arg0 : i32) {
+  emitc.call_opaque "bar" (%arg0) : (i32) -> ()
+  emitc.return
+}
+
+// CHECK: error: 'emitc.func' op arguments should have attributes so we can initialize class fields.

>From 49f202cd3d91767f8a1d25481f1a48ba8050fcc1 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Sat, 14 Jun 2025 00:20:52 +0000
Subject: [PATCH 13/20] Remove class-name changes

---
 mlir/include/mlir/Target/Cpp/CppEmitter.h     |   4 +-
 mlir/lib/Target/Cpp/TranslateRegistration.cpp |  41 +----
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 153 ++----------------
 .../emit-class-neg-external.mlir              |   8 -
 .../emit-class-neg-noArgAttrs.mlir            |  15 --
 mlir/test/mlir-translate/emit-class.mlir      |  39 -----
 6 files changed, 18 insertions(+), 242 deletions(-)
 delete mode 100644 mlir/test/mlir-translate/emit-class-neg-external.mlir
 delete mode 100644 mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
 delete mode 100644 mlir/test/mlir-translate/emit-class.mlir

diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h
index d1a6c1dc12d4c..7c5747a888261 100644
--- a/mlir/include/mlir/Target/Cpp/CppEmitter.h
+++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h
@@ -28,9 +28,7 @@ namespace emitc {
 /// with matching id are emitted.
 LogicalResult translateToCpp(Operation *op, raw_ostream &os,
                              bool declareVariablesAtTop = false,
-                             StringRef fileId = {}, bool emitClass = false,
-                             StringRef className = {},
-                             StringRef fieldNameAttribute = {});
+                             StringRef fileId = {});
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 9e1533d34f6ea..f1869ac9e8eda 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -33,50 +33,13 @@ void registerToCppTranslation() {
       "file-id", llvm::cl::desc("Emit emitc.file ops with matching id"),
       llvm::cl::init(""));
 
-  static llvm::cl::opt<bool> emitClass(
-      "emit-class",
-      llvm::cl::desc("If specified, the output will be a class where "
-                     "the function(s) in the module are methods "
-                     "Enables class-related options"),
-      llvm::cl::init(false));
-
-  static llvm::cl::opt<std::string> className(
-      "class-name",
-      llvm::cl::desc("Mandatory class name if --emit-class is set"),
-      llvm::cl::init(""));
-
-  static llvm::cl::opt<std::string> fieldNameAttribute(
-      "field-name-attribute",
-      llvm::cl::desc("Mandatory name of the attribute to use as field name if "
-                     "--emit-class is set(default=tf_saved_model.index_path)"),
-      llvm::cl::init("tf_saved_model.index_path"));
-
   TranslateFromMLIRRegistration reg(
       "mlir-to-cpp", "translate from mlir to cpp",
       [](Operation *op, raw_ostream &output) {
-        if (emitClass) {
-          if (className.empty()) {
-            llvm::errs() << "Error: --class-name is mandatory when "
-                            "--emit-class is set.\n";
-            return mlir::failure();
-          }
-          if (fieldNameAttribute.empty()) {
-            llvm::errs() << "Error: --field-name-attribute is mandatory when "
-                            "--emit-class is set.\n";
-            return mlir::failure();
-          }
-          return emitc::translateToCpp(
-              op, output,
-              /*declareVariablesAtTop=*/declareVariablesAtTop,
-              /*fileId=*/fileId, /*emitClass=*/emitClass,
-              /*className=*/className,
-              /*fieldNameAttribute=*/fieldNameAttribute);
-        }
         return emitc::translateToCpp(
             op, output,
             /*declareVariablesAtTop=*/declareVariablesAtTop,
-            /*fileId=*/fileId, /*emitClass=*/emitClass, /*className=*/className,
-            /*fieldNameAttribute=*/fieldNameAttribute);
+            /*fileId=*/fileId);
       },
       [](DialectRegistry &registry) {
         // clang-format off
@@ -87,4 +50,4 @@ void registerToCppTranslation() {
       });
 }
 
-} // namespace mlir
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index fc550d24113e3..52bfd4d774688 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -68,14 +68,6 @@ inline LogicalResult interleaveCommaWithError(const Container &c,
   return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
 }
 
-template <typename Container, typename UnaryFunctor>
-inline LogicalResult interleaveWithNewLineWithError(const Container &c,
-                                                    raw_ostream &os,
-                                                    UnaryFunctor eachFn) {
-  return interleaveWithError(c.begin(), c.end(), eachFn,
-                             [&]() { os << ";\n"; });
-}
-
 /// Return the precedence of a operator as an integer, higher values
 /// imply higher precedence.
 static FailureOr<int> getOperatorPrecedence(Operation *operation) {
@@ -124,8 +116,7 @@ namespace {
 /// Emitter that uses dialect specific emitters to emit C++ code.
 struct CppEmitter {
   explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
-                      StringRef fileId, bool emitClass, StringRef className,
-                      StringRef fieldNameAttribute);
+                      StringRef fileId);
 
   /// Emits attribute or returns failure.
   LogicalResult emitAttribute(Location loc, Attribute attr);
@@ -242,15 +233,6 @@ struct CppEmitter {
   /// be declared at the beginning of a function.
   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
 
-  // Returns whether we should emit a C++ class
-  bool shouldPrintClass() { return emitClass; };
-
-  // Returns the class name to emit
-  std::string getClassName() { return className; };
-
-  // Returns the field name to use in the map
-  std::string getfieldNameAttribute() { return fieldNameAttribute; };
-
   /// Returns whether this file op should be emitted
   bool shouldEmitFile(FileOp file) {
     return !fileId.empty() && file.getId() == fileId;
@@ -286,18 +268,6 @@ struct CppEmitter {
   /// Only emit file ops whos id matches this value.
   std::string fileId;
 
-  /// Controls whether the output should be a C++ class.
-  /// If true, the generated C++ code will be encapsulated within a class,
-  /// and functions from the input module will become its member functions.
-  const bool emitClass;
-
-  /// The specified name for the generated C++ class
-  const std::string className;
-
-  /// Name of the MLIR attribute to use as a field name within the generated
-  /// class
-  const std::string fieldNameAttribute;
-
   /// Map from value to name of C++ variable that contain the name.
   ValueMapper valueMapper;
 
@@ -1063,17 +1033,6 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
       }));
 }
 
-static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
-                                 Region::BlockArgListType arguments) {
-  raw_indented_ostream &os = emitter.ostream();
-
-  return (interleaveWithNewLineWithError(
-      arguments, os, [&](BlockArgument arg) -> LogicalResult {
-        return emitter.emitVariableDeclaration(
-            functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
-      }));
-}
-
 static LogicalResult printFunctionBody(CppEmitter &emitter,
                                        Operation *functionOp,
                                        Region::BlockListType &blocks) {
@@ -1178,45 +1137,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return success();
 }
 
-static LogicalResult emitClassFields(CppEmitter &emitter,
-                                     emitc::FuncOp functionOp) {
-  raw_indented_ostream &os = emitter.ostream();
-  auto argAttrs = functionOp.getArgAttrs();
-  Operation *operation = functionOp.getOperation();
-  if (failed(printFields(emitter, operation, functionOp.getArguments())))
-    return failure();
-  os << ";\n";
-
-  std::map<std::string, Value> fields;
-  os << "\nstd::map<std::string, char*> _buffer_map {";
-  if (argAttrs) {
-    for (const auto [a, v] : zip(*argAttrs, functionOp.getArguments())) {
-      if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
-        auto nv = da.getNamed(emitter.getfieldNameAttribute())->getValue();
-        auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str();
-        auto Ins = fields.insert({name, v});
-        if (!Ins.second)
-          return failure();
-        os << " { \"" << name << "\"" << ", reinterpret_cast<char*>("
-           << emitter.getOrCreateName(v) << ") }, ";
-      }
-    }
-  } else
-    return failure();
-
-  os << "};\n";
-  os << "char* getBufferForName(const std::string& name) const {\n";
-  os.indent();
-  os.indent();
-  os << "auto it = _buffer_map.find(name);\n";
-  os << "return (it == _buffer_map.end()) ? nullptr : it->second;\n";
-  os.unindent();
-  os.unindent();
-  os << "}\n\n";
-
-  return success();
-}
-
 static LogicalResult printOperation(CppEmitter &emitter,
                                     emitc::FuncOp functionOp) {
   // We need to declare variables at top if the function has multiple blocks.
@@ -1228,30 +1148,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
 
   CppEmitter::Scope scope(emitter);
   raw_indented_ostream &os = emitter.ostream();
-  Operation *operation = functionOp.getOperation();
-  if (emitter.shouldPrintClass()) {
-    if (functionOp.isExternal()) {
-      // TODO: Determine the best long-term strategy for external functions.
-      // Currently, we're skipping over this functionOp.
-      // We have considered using emitWarning() which would return
-      // InFlightDiagnostic which seems can be automatically converted to
-      // LogicalResult since this is done in emitAttributes where emitError is
-      // converted to LogicalResult. However, it requires that we pass in a
-      // location which at first glance we don't have in this scope. Open to
-      // further discussion on this.
-      os << "Warning: Cannot process external function '"
-         << functionOp.getName() << "'. "
-         << "This functionOp lacks a body so we will skip over it.";
-      return success();
-    }
-    os << "class " << emitter.getClassName() << " final {\n";
-    os << "public: \n";
-    os.indent();
-
-    if (failed(emitClassFields(emitter, functionOp)))
-      return failure();
-  }
-
   if (functionOp.getSpecifiers()) {
     for (Attribute specifier : functionOp.getSpecifiersAttr()) {
       os << cast<StringAttr>(specifier).str() << " ";
@@ -1261,37 +1157,23 @@ static LogicalResult printOperation(CppEmitter &emitter,
   if (failed(emitter.emitTypes(functionOp.getLoc(),
                                functionOp.getFunctionType().getResults())))
     return failure();
-  // TODO: We may wanna consider having the name of the function be execute in
-  // the case that we want to emit a class instead of main. Leaving as is for
-  // now to make the change smaller.
   os << " " << functionOp.getName();
 
   os << "(";
-
-  if (!emitter.shouldPrintClass()) {
-    if (functionOp.isExternal()) {
-      if (failed(printFunctionArgs(emitter, operation,
-                                   functionOp.getArgumentTypes())))
-        return failure();
-      os << ");";
-      return success();
-    }
-    if (failed(
-            printFunctionArgs(emitter, operation, functionOp.getArguments())))
+  Operation *operation = functionOp.getOperation();
+  if (functionOp.isExternal()) {
+    if (failed(printFunctionArgs(emitter, operation,
+                                 functionOp.getArgumentTypes())))
       return failure();
+    os << ");";
+    return success();
   }
+  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
+    return failure();
   os << ") {\n";
-
   if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
     return failure();
-
-  if (emitter.shouldPrintClass()) {
-    os << "}\n";
-    os.unindent();
-    os << "};\n";
-  } else {
-    os << "}\n";
-  }
+  os << "}\n";
 
   return success();
 }
@@ -1328,11 +1210,9 @@ static LogicalResult printOperation(CppEmitter &emitter,
 }
 
 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
-                       StringRef fileId, bool emitClass, StringRef className,
-                       StringRef fieldNameAttribute)
+                       StringRef fileId)
     : os(os), declareVariablesAtTop(declareVariablesAtTop),
-      fileId(fileId.str()), emitClass(emitClass), className(className.str()),
-      fieldNameAttribute(fieldNameAttribute.str()) {
+      fileId(fileId.str()) {
   valueInScopeCount.push(0);
   labelInScopeCount.push(0);
 }
@@ -1915,10 +1795,7 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
 
 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
                                     bool declareVariablesAtTop,
-                                    StringRef fileId, bool emitClass,
-                                    StringRef className,
-                                    StringRef fieldNameAttribute) {
-  CppEmitter emitter(os, declareVariablesAtTop, fileId, emitClass, className,
-                     fieldNameAttribute);
+                                    StringRef fileId) {
+  CppEmitter emitter(os, declareVariablesAtTop, fileId);
   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
-}
+}
\ No newline at end of file
diff --git a/mlir/test/mlir-translate/emit-class-neg-external.mlir b/mlir/test/mlir-translate/emit-class-neg-external.mlir
deleted file mode 100644
index c34a1652abd3f..0000000000000
--- a/mlir/test/mlir-translate/emit-class-neg-external.mlir
+++ /dev/null
@@ -1,8 +0,0 @@
-/// An external function - has no body
-// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
-
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
-  emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]}
-}
-
-// CHECK: Warning: Cannot process external function 'extern_func'. This functionOp lacks a body so we will skip over it.
diff --git a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
deleted file mode 100644
index 6d43fa953a946..0000000000000
--- a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-/// The function has no argument attributes
-// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=ArgAttrs --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
-
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
-  emitc.func @foo(%arg0 : i32) {
-    emitc.call_opaque "bar" (%arg0) : (i32) -> ()
-    emitc.return
-  }
-}
-
-// CHECK: class ArgAttrs final {
-// CHECK-NEXT: public: 
-// CHECK-NEXT:   int32_t v1;
-// CHECK-EMPTY: 
-// CHECK-NEXT:   std::map<std::string, char*> _buffer_map {
diff --git a/mlir/test/mlir-translate/emit-class.mlir b/mlir/test/mlir-translate/emit-class.mlir
deleted file mode 100644
index 2779cb315ed41..0000000000000
--- a/mlir/test/mlir-translate/emit-class.mlir
+++ /dev/null
@@ -1,39 +0,0 @@
-// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s  
-
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
-  emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
-    %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-    %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-    %2 = load %1 : <f32>
-    %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-    %4 = load %3 : <f32>
-    %5 = add %2, %4 : (f32, f32) -> f32
-    %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-    assign %5 : f32 to %6 : <f32>
-    return
-  }
-}
-
-// CHECK: class MyAdder final {
-// CHECK-NEXT: public:
-// CHECK-NEXT:   float v1[1];
-// CHECK-NEXT:   float v2[1];
-// CHECK-NEXT:   float v3[1];
-// CHECK-EMPTY:
-// CHECK-NEXT:   std::map<std::string, char*> _buffer_map { { "another_feature", reinterpret_cast<char*>(v1) }, 
-// CHECK-SAME: { "some_feature", reinterpret_cast<char*>(v2) }, { "output_0", reinterpret_cast<char*>(v3) }, };
-// CHECK-NEXT:   char* getBufferForName(const std::string& name) const {
-// CHECK-NEXT:      auto it = _buffer_map.find(name);
-// CHECK-NEXT:      return (it == _buffer_map.end()) ? nullptr : it->second;
-// CHECK-NEXT:   }
-// CHECK-EMPTY:
-// CHECK-NEXT: void main() {
-// CHECK-NEXT:     size_t v4 = 0;
-// CHECK-NEXT:     float v5 = v2[v4];
-// CHECK-NEXT:     float v6 = v1[v4];
-// CHECK-NEXT:     float v7 = v5 + v6;
-// CHECK-NEXT:     v3[v4] = v7;
-// CHECK-NEXT:     return;
-// CHECK-NEXT:  }
-// CHECK-NEXT: };
-

>From d108bf22f602fcdb27425a36626670f4339e9652 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Sat, 14 Jun 2025 00:30:23 +0000
Subject: [PATCH 14/20] Clean up comments

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td      | 3 +--
 mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp | 1 +
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 08235ca701e2a..466c628e6e837 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1579,10 +1579,9 @@ def EmitC_ClassOp
   let summary =
       "Represents a C++ class definition, encapsulating fields and methods.";
 
-  // FIX WORDING
   let description = [{
     The `emitc.class` operation defines a C++ class, acting as a container
-    for its data fields (`emitc.variable`) and methods (`emitc.func`).
+    for its data fields (`emitc.field`) and methods (`emitc.func`).
     It creates a distinct scope, isolating its contents from the surrounding
     MLIR region, similar to how C++ classes encapsulate their internals.
     All the class memebrs need to be default initalizable. 
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index a252924caeb62..87350ecdceaaa 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
 
 namespace mlir {
 namespace emitc {

>From 86a057e12369fb7563930f83e1d398e7fa5f0d7f Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 16 Jun 2025 23:49:42 +0000
Subject: [PATCH 15/20] Remove TF attributes and allow arguments without attributes

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   | 26 ++++----
 .../EmitC/Transforms/WrapFuncInClass.cpp      | 59 +++++++++++--------
 .../EmitC/wrap_emitc_func_in_class.mlir       | 50 ++++++++--------
 .../EmitC/wrap_emitc_func_in_class_neg.mlir   |  8 ---
 .../wrap_emitc_func_in_class_noAttr.mlir      | 17 ++++++
 5 files changed, 94 insertions(+), 66 deletions(-)
 delete mode 100644 mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir
 create mode 100644 mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 466c628e6e837..fc609fac72f3a 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1584,14 +1584,13 @@ def EmitC_ClassOp
     for its data fields (`emitc.field`) and methods (`emitc.func`).
     It creates a distinct scope, isolating its contents from the surrounding
     MLIR region, similar to how C++ classes encapsulate their internals.
-    All the class memebrs need to be default initalizable. 
 
     Example:
     ```mlir
-    emitc.class @MymainClass {
-      emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
-      emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
-      emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
+    emitc.class @mainClass {
+      emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
+      emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
+      emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
 
       emitc.func @execute() {
         %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
@@ -1634,15 +1633,22 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
   let summary = "A field within a class";
   let description = [{
     The `emitc.field` operation declares a named field within an `emitc.class`
-    operation. The field's type must be an EmitC type. An optional initial value can be provided.
+    operation. The field's type must be an EmitC type. The initial value is optional. 
+    If the argument has attributes, these become the initial value, else we end up with no initial value.
 
     Example with initial values:
 
     ```mlir
-    emitc.class @MyModelClass {
-      emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
-      emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
-      emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
+    emitc.class @modelClass {
+      emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
+      emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
+      emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
+    }
+    ```
+    Example with no initial value:
+    ```mlir
+     emitc.class @modelClass {
+      emitc.field @another_feature : !emitc.array<1xf32>
     }
     ```
   }];
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index d87af1379d96a..2a2c65214b6fd 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -12,13 +12,17 @@
 #include "mlir/Dialect/EmitC/Transforms/Transforms.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/GraphWriter.h"
 #include "llvm/Support/LogicalResult.h"
+#include <string>
 
 namespace mlir {
 namespace emitc {
@@ -67,7 +71,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
     if (funcOp->getParentOfType<emitc::ClassOp>()) {
       return failure();
     }
-    auto className = "My" + funcOp.getSymNameAttr().str() + "Class";
+    auto className = funcOp.getSymNameAttr().str() + "Class";
     mlir::emitc::ClassOp newClassOp =
         rewriter.create<emitc::ClassOp>(funcOp.getLoc(), className);
 
@@ -76,25 +80,33 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
     rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
 
     auto argAttrs = funcOp.getArgAttrs();
-    if (argAttrs) {
-      for (const auto &[arg, val] :
-           llvm::zip(*argAttrs, funcOp.getArguments())) {
-        if (auto namedAttr =
-                dyn_cast<mlir::DictionaryAttr>(arg).getNamed(attributeName)) {
-          Attribute nv = namedAttr->getValue();
-          StringAttr fieldName =
-              cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
-          TypeAttr typeAttr = TypeAttr::get(val.getType());
-          fields.push_back({fieldName, typeAttr});
-
-          rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
-                                          /* attributes*/ arg);
+    size_t idx = 0;
+
+    for (const BlockArgument &val : funcOp.getArguments()) {
+      StringAttr fieldName;
+      Attribute argAttr = nullptr;
+
+      if (argAttrs && idx < argAttrs->size()) {
+        if (DictionaryAttr dictAttr =
+                dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx])) {
+          if (auto namedAttr = dictAttr.getNamed(attributeName)) {
+            Attribute nv = namedAttr->getValue();
+            fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
+            argAttr = (*argAttrs)[idx];
+          }
         }
       }
-    } else {
-      funcOp->emitOpError("arguments should have attributes so we can "
-                          "initialize class fields.");
-      return failure();
+
+      if (!fieldName) {
+        fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
+      }
+
+      TypeAttr typeAttr = TypeAttr::get(val.getType());
+      fields.push_back({fieldName, typeAttr});
+      rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
+                                      argAttr);
+
+      ++idx;
     }
 
     rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
@@ -112,7 +124,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
     rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
     std::vector<Value> newArguments;
     for (auto [fieldName, attr] : fields) {
-      auto arg =
+      GetFieldOp arg =
           rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
       newArguments.push_back(arg);
     }
@@ -122,14 +134,13 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
       rewriter.replaceAllUsesWith(oldArg, newArg);
     }
 
-    while (!newFuncOp.getArguments().empty()) {
-      if (failed(newFuncOp.eraseArgument(0))) {
-        break;
-      }
+    llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
+    if (failed(newFuncOp.eraseArguments(argsToErase))) {
+      newFuncOp->emitOpError("Failed to erase all arguments using BitVector.");
     }
 
     rewriter.replaceOp(funcOp, newClassOp);
-    return funcOp->use_empty() ? success() : failure();
+    return success();
   }
 };
 
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
index e0fa78a3dc459..28636811b1f17 100644
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=tf_saved_model.index_path' %s | FileCheck %s
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.opaque' %s | FileCheck %s
 
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
-  emitc.func @Model(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+module attributes { } {
+  emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.opaque = ["another_feature"]}, %arg1: !emitc.array<1xf32> {emitc.opaque = ["some_feature"]}, %arg2: !emitc.array<1xf32> {emitc.opaque = ["output_0"]}) attributes { } {
     %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
     %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
     %2 = load %1 : <f32>
@@ -14,24 +14,26 @@ module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted."
   }
 }
 
-// CHECK: module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
-// CHECK:   emitc.class @MyModelClass {
-// CHECK:     emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
-// CHECK:     emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
-// CHECK:     emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
-// CHECK:     emitc.func @execute() {
-// CHECK:       %{{[0-9]+}} = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-// CHECK:       %{{[0-9]+}} = get_field @another_feature : !emitc.array<1xf32>
-// CHECK:       %{{[0-9]+}} = get_field @some_feature : !emitc.array<1xf32>
-// CHECK:       %{{[0-9]+}} = get_field @output_0 : !emitc.array<1xf32>
-// CHECK:       %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK:       %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
-// CHECK:       %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK:       %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
-// CHECK:       %{{[0-9]+}} = add %{{[0-9]+}}, %{{[0-9]+}} : (f32, f32) -> f32
-// CHECK:       %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK:       assign %{{[0-9]+}} : f32 to %{{[0-9]+}} : <f32>
-// CHECK:       return
-// CHECK:     }
-// CHECK:   }
-// CHECK: }
+
+// CHECK: module {
+// CHECK-NEXT:   emitc.class @modelClass {
+// CHECK-NEXT:     emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
+// CHECK-NEXT:     emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
+// CHECK-NEXT:     emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
+// CHECK-NEXT:     emitc.func @execute() {
+// CHECK-NEXT:       %{{[0-9]+}} = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+// CHECK-NEXT:       %{{[0-9]+}} = get_field @another_feature : !emitc.array<1xf32>
+// CHECK-NEXT:       %{{[0-9]+}} = get_field @some_feature : !emitc.array<1xf32>
+// CHECK-NEXT:       %{{[0-9]+}} = get_field @output_0 : !emitc.array<1xf32>
+// CHECK-NEXT:       %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT:       %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
+// CHECK-NEXT:       %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT:       %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
+// CHECK-NEXT:       %{{[0-9]+}} = add %{{[0-9]+}}, %{{[0-9]+}} : (f32, f32) -> f32
+// CHECK-NEXT:       %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT:       assign %{{[0-9]+}} : f32 to %{{[0-9]+}} : <f32>
+// CHECK-NEXT:       return
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir
deleted file mode 100644
index 578e8f6c21ebf..0000000000000
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir
+++ /dev/null
@@ -1,8 +0,0 @@
-// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=tf_saved_model.index_path' %s 2>&1 | FileCheck %s
-
-emitc.func @foo(%arg0 : i32) {
-  emitc.call_opaque "bar" (%arg0) : (i32) -> ()
-  emitc.return
-}
-
-// CHECK: error: 'emitc.func' op arguments should have attributes so we can initialize class fields.
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
new file mode 100644
index 0000000000000..57155cac693bb
--- /dev/null
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.opaque' %s | FileCheck %s
+
+emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
+  emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
+  emitc.return
+}
+
+// CHECK: module {
+// CHECK-NEXT:   emitc.class @fooClass {
+// CHECK-NEXT:     emitc.field @fieldName0 : !emitc.array<1xf32>
+// CHECK-NEXT:     emitc.func @execute() {
+// CHECK-NEXT:       %0 = get_field @fieldName0 : !emitc.array<1xf32>
+// CHECK-NEXT:       call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> ()
+// CHECK-NEXT:       return
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }

>From b064a9cdcf6cff9941b04587be6034a7f4ecca1c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?=
 =?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?=
 =?UTF-8?q?=E3=83=B3=29?= <clementval at gmail.com>
Date: Thu, 22 May 2025 08:24:18 -0700
Subject: [PATCH 16/20] [flang][rt] Fix the use of kNoAsyncId -> kNoAsyncObject
 (#141079)


>From 4780ab3650038669cf9be2facf5527fa5770ee38 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Tue, 17 Jun 2025 15:04:18 +0000
Subject: [PATCH 17/20] Cleaning up descriptions

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   | 59 ++++++++-----------
 .../mlir/Dialect/EmitC/Transforms/Passes.td   |  6 +-
 .../EmitC/Transforms/WrapFuncInClass.cpp      |  8 +--
 .../EmitC/wrap_emitc_func_in_class.mlir       | 23 ++++----
 .../wrap_emitc_func_in_class_noAttr.mlir      |  2 +-
 5 files changed, 42 insertions(+), 56 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index fc609fac72f3a..926ff301f1807 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1587,29 +1587,21 @@ def EmitC_ClassOp
 
     Example:
     ```mlir
-    emitc.class @mainClass {
-      emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
-      emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
-      emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
-
+    emitc.func @model(%input_data : !emitc.array<1xf32> {emitc.opaque = ["input_tensor"]}) attributes { } {
+      %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+      %1 = subscript %input_data[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+      return
+    }
+    // becomes 
+    emitc.class @modelClass {
+      emitc.field @input_tensor : !emitc.array<1xf32> = {emitc.opaque = ["input_tensor"]}
       emitc.func @execute() {
         %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-
-        %1 = get_field @another_feature : !emitc.array<1xf32>
-        %2 = get_field @some_feature : !emitc.array<1xf32>
-        %3 = get_field @output_0 : !emitc.array<1xf32>
-
-        %4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-        %5 = load %4 : <f32>
-        %6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-        %7 = load %6 : <f32>
-        %8 = add %5, %7 : (f32, f32) -> f32
-        %9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-        assign %8 : f32 to %9 : <f32>
+        %1 = get_field @input_tensor : !emitc.array<1xf32>
+        %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
         return
       }
-  }
-
+    }
     ```
   }];
 
@@ -1633,23 +1625,22 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
   let summary = "A field within a class";
   let description = [{
     The `emitc.field` operation declares a named field within an `emitc.class`
-    operation. The field's type must be an EmitC type. The initial value is optional. 
-    If the argument has attributes, these become the initial value, else we end up with no initial value.
-
-    Example with initial values:
+    operation. The field's type must be an EmitC type. 
+    If the corresponding function argument has attributes (accessed via `argAttrs`), 
+    these attributes are attached to the field operation. 
+    Otherwise, the field is created without additional attributes.
 
+    Example of func argument with attributes:
     ```mlir
-    emitc.class @modelClass {
-      emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
-      emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
-      emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
-    }
+    %arg0: !emitc.array<1xf32> {emitc.opaque = ["another_feature"]} 
+    // becomes
+    emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
     ```
-    Example with no initial value:
+    Example of func argument without attributes:
     ```mlir
-     emitc.class @modelClass {
-      emitc.field @another_feature : !emitc.array<1xf32>
-    }
+    %arg0 : !emitc.array<1xf32>
+    // becomes
+    emitc.field @fieldName0 : !emitc.array<1xf32>
     ```
   }];
 
@@ -1673,9 +1664,7 @@ def EmitC_GetFieldOp
      Example:
 
      ```mlir
-    %some_ptr = emitc.get_field @some_feature : !emitc.array<1xf32> 
-    %another_ptr = emitc.get_field @another_feature : !emitc.array<1xf32> 
-    %output_ptr = emitc.get_field @output_0 : !emitc.array<1xf32> 
+     %0 = get_field @fieldName0 : !emitc.array<1xf32>
      ```
   }];
 
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index d8ebf4a613bfd..09ad644c2a439 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -29,8 +29,10 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
   }];
   let dependentDialects = ["emitc::EmitCDialect"];
   let options = [Option<
-      "namedAttribute", "named-attribute", "std::string", "\"\"",
-      "Name of the attribute to look for field names on function arguments">];
+      "namedAttribute", "named-attribute", "std::string",
+      /*default=*/"",
+      "Attribute key used to extract field names from function argument's "
+      "dictionary attributes">];
 }
 
 #endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index 2a2c65214b6fd..31c013dda6f50 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -80,9 +80,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
     rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
 
     auto argAttrs = funcOp.getArgAttrs();
-    size_t idx = 0;
-
-    for (const BlockArgument &val : funcOp.getArguments()) {
+    for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
       StringAttr fieldName;
       Attribute argAttr = nullptr;
 
@@ -105,8 +103,6 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
       fields.push_back({fieldName, typeAttr});
       rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
                                       argAttr);
-
-      ++idx;
     }
 
     rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
@@ -123,7 +119,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
 
     rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
     std::vector<Value> newArguments;
-    for (auto [fieldName, attr] : fields) {
+    for (auto &[fieldName, attr] : fields) {
       GetFieldOp arg =
           rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
       newArguments.push_back(arg);
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
index 28636811b1f17..480943116dbf1 100644
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -21,19 +21,18 @@ module attributes { } {
 // CHECK-NEXT:     emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
 // CHECK-NEXT:     emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
 // CHECK-NEXT:     emitc.func @execute() {
-// CHECK-NEXT:       %{{[0-9]+}} = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-// CHECK-NEXT:       %{{[0-9]+}} = get_field @another_feature : !emitc.array<1xf32>
-// CHECK-NEXT:       %{{[0-9]+}} = get_field @some_feature : !emitc.array<1xf32>
-// CHECK-NEXT:       %{{[0-9]+}} = get_field @output_0 : !emitc.array<1xf32>
-// CHECK-NEXT:       %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK-NEXT:       %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
-// CHECK-NEXT:       %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK-NEXT:       %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
-// CHECK-NEXT:       %{{[0-9]+}} = add %{{[0-9]+}}, %{{[0-9]+}} : (f32, f32) -> f32
-// CHECK-NEXT:       %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK-NEXT:       assign %{{[0-9]+}} : f32 to %{{[0-9]+}} : <f32>
+// CHECK-NEXT:       "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+// CHECK-NEXT:       get_field @another_feature : !emitc.array<1xf32>
+// CHECK-NEXT:       get_field @some_feature : !emitc.array<1xf32>
+// CHECK-NEXT:       get_field @output_0 : !emitc.array<1xf32>
+// CHECK-NEXT:       subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT:       load {{.*}} : <f32>
+// CHECK-NEXT:       subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT:       load {{.*}} : <f32>
+// CHECK-NEXT:       add {{.*}}, {{.*}} : (f32, f32) -> f32
+// CHECK-NEXT:       subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT:       assign {{.*}} : f32 to {{.*}} : <f32>
 // CHECK-NEXT:       return
 // CHECK-NEXT:     }
 // CHECK-NEXT:   }
 // CHECK-NEXT: }
-
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
index 57155cac693bb..92ed20c4b14e3 100644
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.opaque' %s | FileCheck %s
+// RUN: mlir-opt --wrap-emitc-func-in-class %s | FileCheck %s
 
 emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
   emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()

>From 009f137492c47d4114725d12c1fd4a9bd4bcf419 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Wed, 18 Jun 2025 21:42:04 +0000
Subject: [PATCH 18/20] Avoid unnecessary checks

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   | 35 +++------
 .../mlir/Dialect/EmitC/Transforms/Passes.td   | 23 ++++++
 .../Dialect/EmitC/Transforms/Transforms.h     |  2 +-
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           | 60 ++++-----------
 .../EmitC/Transforms/WrapFuncInClass.cpp      | 74 ++++++-------------
 .../EmitC/wrap_emitc_func_in_class.mlir       | 18 +++--
 6 files changed, 82 insertions(+), 130 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 926ff301f1807..3b7b1a44783da 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1586,18 +1586,13 @@ def EmitC_ClassOp
     MLIR region, similar to how C++ classes encapsulate their internals.
 
     Example:
+
     ```mlir
-    emitc.func @model(%input_data : !emitc.array<1xf32> {emitc.opaque = ["input_tensor"]}) attributes { } {
-      %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-      %1 = subscript %input_data[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-      return
-    }
-    // becomes 
     emitc.class @modelClass {
-      emitc.field @input_tensor : !emitc.array<1xf32> = {emitc.opaque = ["input_tensor"]}
+      emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}
       emitc.func @execute() {
         %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-        %1 = get_field @input_tensor : !emitc.array<1xf32>
+        %1 = get_field @fieldName0 : !emitc.array<1xf32>
         %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
         return
       }
@@ -1609,8 +1604,6 @@ def EmitC_ClassOp
 
   let regions = (region AnyRegion:$body);
 
-  let builders = [];
-
   let extraClassDeclaration = [{
     // Returns the body block containing class members and methods.
     Block &getBlock();
@@ -1626,29 +1619,21 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
   let description = [{
     The `emitc.field` operation declares a named field within an `emitc.class`
     operation. The field's type must be an EmitC type. 
-    If the corresponding function argument has attributes (accessed via `argAttrs`), 
-    these attributes are attached to the field operation. 
-    Otherwise, the field is created without additional attributes.
 
-    Example of func argument with attributes:
-    ```mlir
-    %arg0: !emitc.array<1xf32> {emitc.opaque = ["another_feature"]} 
-    // becomes
-    emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
-    ```
-    Example of func argument without attributes:
+    Example:
+
     ```mlir
-    %arg0 : !emitc.array<1xf32>
-    // becomes
+    // Example with an attribute:
+    emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "another_feature"}
+    // Example with no attribute:
     emitc.field @fieldName0 : !emitc.array<1xf32>
     ```
   }];
 
   let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
-      OptionalAttr<AnyAttr>:$initial_value);
+      OptionalAttr<AnyAttr>:$attrs);
 
-  let assemblyFormat =
-      [{ $sym_name `:` $type (`=` $initial_value^)? attr-dict}];
+  let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}];
 
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index 09ad644c2a439..74c49132b61f6 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -26,6 +26,29 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
     This pass transforms `emitc.func` operations into `emitc.class` operations.
     Function arguments become fields of the class, and the function body is moved
     to a new `execute` method within the class.
+    Each argument becomes a field named `fieldName<N>`, where N is its index.
+    If the argument has an attribute dictionary, it is attached to the field
+    operation; otherwise, the field is created without additional attributes.
+
+    Example:
+    
+    ```mlir
+    emitc.func @model(%input_data : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}) attributes { } {
+      %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+      %1 = subscript %input_data[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+      return
+    }
+    // becomes
+    emitc.class @modelClass {
+      emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}
+      emitc.func @execute() {
+        %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+        %1 = get_field @fieldName0 : !emitc.array<1xf32>
+        %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+        return
+      }
+    }
+    ```
   }];
   let dependentDialects = ["emitc::EmitCDialect"];
   let options = [Option<
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index 11a1ad2ad2ff2..a4e8fe10ff853 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -30,7 +30,7 @@ void populateExpressionPatterns(RewritePatternSet &patterns);
 
 /// Populates 'patterns' with func-related patterns.
 void populateFuncPatterns(RewritePatternSet &patterns,
-                          const std::string &namedAttribute);
+                          StringRef namedAttribute);
 
 } // namespace emitc
 } // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 695ad3ee3bb0a..e99a524fdf8a3 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1404,39 +1404,19 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
 // FieldOp
 //===----------------------------------------------------------------------===//
 LogicalResult FieldOp::verify() {
-  if (!isSupportedEmitCType(getType())) {
+  if (!isSupportedEmitCType(getType()))
     return emitOpError("expected valid emitc type");
-  }
 
-  if (!getInitialValue()) {
-    return success();
-  }
+  Operation *parentOp = getOperation()->getParentOp();
+  if (!parentOp || !isa<emitc::ClassOp>(parentOp))
+    return emitOpError("field must be nested within an emitc.class operation");
 
-  Attribute initValue = *getInitialValue();
-  // Check that the type of the initial value is compatible with the type of
-  // the global variable.
-  if (ElementsAttr elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
-    Type initialValueType = elementsAttr.getType();
-    if (!initialValueType) {
-      return emitOpError("initial value attribute must have a type");
-    }
-    Type fieldType = getType();
-    if (initialValueType != fieldType) {
-      if (LValueType lvalueType = dyn_cast<LValueType>(fieldType)) {
-        Type innerFieldType = lvalueType.getValueType();
-        if (innerFieldType != initialValueType) {
-          return emitOpError("initial value type ")
-                 << initialValueType << " is not compatible with field type '"
-                 << fieldType << "' its inner type '" << innerFieldType << "'";
-        }
-
-      } else {
-        return emitOpError("initial value type '")
-               << initialValueType << "' is not compatible with field type '"
-               << fieldType << "'";
-      }
-    }
-  }
+  StringAttr symName = getSymNameAttr();
+  if (!symName || symName.getValue().empty())
+    return emitOpError("field must have a non-empty symbol name");
+
+  if (!getAttrs())
+    return success();
 
   return success();
 }
@@ -1446,27 +1426,19 @@ LogicalResult FieldOp::verify() {
 //===----------------------------------------------------------------------===//
 LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
-  if (!fieldNameAttr) {
-    return emitError("field name attribute is mandatory");
-  }
-
-  StringRef fieldName = fieldNameAttr.getValue();
-
   FieldOp fieldOp =
-      symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, getFieldNameAttr());
-
-  if (!fieldOp) {
-    return emitOpError("field '") << fieldName << "' not found in the class '";
-  }
+      symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, fieldNameAttr);
+  if (!fieldOp)
+    return emitOpError("field '")
+           << fieldNameAttr << "' not found in the class";
 
   Type getFieldResultType = getResult().getType();
   Type fieldType = fieldOp.getType();
 
-  if (fieldType != getFieldResultType) {
+  if (fieldType != getFieldResultType)
     return emitOpError("result type ")
-           << getFieldResultType << " does not match field '" << fieldName
+           << getFieldResultType << " does not match field '" << fieldNameAttr
            << "' type " << fieldType;
-  }
 
   return success();
 }
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index 31c013dda6f50..4a02eca594c80 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -1,4 +1,4 @@
-//===- ConvertFuncToClass.cpp - Convert functions to classes -------------===//
+//===- WrapFuncInClass.cpp - Wrap EmitC funcs in classes ------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir-c/Rewrite.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/EmitC/Transforms/Passes.h"
 #include "mlir/Dialect/EmitC/Transforms/Transforms.h"
@@ -14,66 +13,44 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeRange.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/GraphWriter.h"
-#include "llvm/Support/LogicalResult.h"
-#include <string>
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace emitc;
 
 namespace mlir {
 namespace emitc {
-
 #define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
 #include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
 
 namespace {
-
 struct WrapFuncInClassPass
     : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
   using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
   void runOnOperation() override {
     Operation *rootOp = getOperation();
-    MLIRContext *context = rootOp->getContext();
 
-    RewritePatternSet patterns(context);
+    RewritePatternSet patterns(&getContext());
     populateFuncPatterns(patterns, namedAttribute);
 
-    if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
-      return signalPassFailure();
-  }
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<emitc::EmitCDialect>();
+    walkAndApplyPatterns(rootOp, std::move(patterns));
   }
 };
 
 } // namespace
-
 } // namespace emitc
 } // namespace mlir
 
-using namespace mlir;
-using namespace mlir::emitc;
-
 class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
-private:
-  std::string attributeName;
-
 public:
-  WrapFuncInClass(MLIRContext *context, const std::string &attrName)
+  WrapFuncInClass(MLIRContext *context, StringRef attrName)
       : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
 
   LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
-    if (funcOp->getParentOfType<emitc::ClassOp>()) {
-      return failure();
-    }
+
     auto className = funcOp.getSymNameAttr().str() + "Class";
-    mlir::emitc::ClassOp newClassOp =
-        rewriter.create<emitc::ClassOp>(funcOp.getLoc(), className);
+    ClassOp newClassOp = rewriter.create<ClassOp>(funcOp.getLoc(), className);
 
     SmallVector<std::pair<StringAttr, TypeAttr>> fields;
     rewriter.createBlock(&newClassOp.getBody());
@@ -84,19 +61,11 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
       StringAttr fieldName;
       Attribute argAttr = nullptr;
 
+      fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
       if (argAttrs && idx < argAttrs->size()) {
         if (DictionaryAttr dictAttr =
-                dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx])) {
-          if (auto namedAttr = dictAttr.getNamed(attributeName)) {
-            Attribute nv = namedAttr->getValue();
-            fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
-            argAttr = (*argAttrs)[idx];
-          }
-        }
-      }
-
-      if (!fieldName) {
-        fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
+                dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx]))
+          argAttr = (*argAttrs)[idx];
       }
 
       TypeAttr typeAttr = TypeAttr::get(val.getType());
@@ -106,19 +75,17 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
     }
 
     rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
-    MLIRContext *funcContext = funcOp.getContext();
-    ArrayRef<Type> inputTypes = funcOp.getFunctionType().getInputs();
-    ArrayRef<Type> results = funcOp.getFunctionType().getResults();
-    FunctionType funcType = FunctionType::get(funcContext, inputTypes, results);
+    FunctionType funcType = funcOp.getFunctionType();
     Location loc = funcOp.getLoc();
-    FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
-        loc, rewriter.getStringAttr("execute"), funcType);
+    FuncOp newFuncOp =
+        rewriter.create<emitc::FuncOp>(loc, ("execute"), funcType);
 
     rewriter.createBlock(&newFuncOp.getBody());
     newFuncOp.getBody().takeBody(funcOp.getBody());
 
     rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
     std::vector<Value> newArguments;
+    newArguments.reserve(fields.size());
     for (auto &[fieldName, attr] : fields) {
       GetFieldOp arg =
           rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
@@ -132,15 +99,18 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
 
     llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
     if (failed(newFuncOp.eraseArguments(argsToErase))) {
-      newFuncOp->emitOpError("Failed to erase all arguments using BitVector.");
+      newFuncOp->emitOpError("failed to erase all arguments using BitVector");
     }
 
     rewriter.replaceOp(funcOp, newClassOp);
     return success();
   }
+
+private:
+  StringRef attributeName;
 };
 
 void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns,
-                                       const std::string &namedAttribute) {
+                                       StringRef namedAttribute) {
   patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute);
 }
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
index 480943116dbf1..c67a0c197fcd9 100644
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -1,7 +1,9 @@
-// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.opaque' %s | FileCheck %s
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.name_hint' %s | FileCheck %s
 
 module attributes { } {
-  emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.opaque = ["another_feature"]}, %arg1: !emitc.array<1xf32> {emitc.opaque = ["some_feature"]}, %arg2: !emitc.array<1xf32> {emitc.opaque = ["output_0"]}) attributes { } {
+  emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.name_hint = "another_feature"}, 
+   %arg1: !emitc.array<1xf32> {emitc.name_hint = "some_feature"},
+   %arg2: !emitc.array<1xf32> {emitc.name_hint = "output_0"}) attributes { } {
     %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
     %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
     %2 = load %1 : <f32>
@@ -17,14 +19,14 @@ module attributes { } {
 
 // CHECK: module {
 // CHECK-NEXT:   emitc.class @modelClass {
-// CHECK-NEXT:     emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
-// CHECK-NEXT:     emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
-// CHECK-NEXT:     emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
+// CHECK-NEXT:     emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
+// CHECK-NEXT:     emitc.field @fieldName1 : !emitc.array<1xf32>  {emitc.name_hint = "some_feature"}
+// CHECK-NEXT:     emitc.field @fieldName2 : !emitc.array<1xf32>  {emitc.name_hint = "output_0"}
 // CHECK-NEXT:     emitc.func @execute() {
+// CHECK-NEXT:       get_field @fieldName0 : !emitc.array<1xf32>
+// CHECK-NEXT:       get_field @fieldName1 : !emitc.array<1xf32>
+// CHECK-NEXT:       get_field @fieldName2 : !emitc.array<1xf32>
 // CHECK-NEXT:       "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-// CHECK-NEXT:       get_field @another_feature : !emitc.array<1xf32>
-// CHECK-NEXT:       get_field @some_feature : !emitc.array<1xf32>
-// CHECK-NEXT:       get_field @output_0 : !emitc.array<1xf32>
 // CHECK-NEXT:       subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
 // CHECK-NEXT:       load {{.*}} : <f32>
 // CHECK-NEXT:       subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>

>From 3bb679968d3e64c71ed56d9140df568b4c56c704 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Wed, 18 Jun 2025 22:00:00 +0000
Subject: [PATCH 19/20] clean up how we use attributes

---
 mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index 4a02eca594c80..ff117f2a91618 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -62,11 +62,8 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
       Attribute argAttr = nullptr;
 
       fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
-      if (argAttrs && idx < argAttrs->size()) {
-        if (DictionaryAttr dictAttr =
-                dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx]))
-          argAttr = (*argAttrs)[idx];
-      }
+      if (argAttrs && idx < argAttrs->size())
+        argAttr = (*argAttrs)[idx];
 
       TypeAttr typeAttr = TypeAttr::get(val.getType());
       fields.push_back({fieldName, typeAttr});

>From b5144f3741131b75f29e06ae9c951e6a23e29c90 Mon Sep 17 00:00:00 2001
From: Jaden Angella <141196890+Jaddyen at users.noreply.github.com>
Date: Wed, 18 Jun 2025 15:01:56 -0700
Subject: [PATCH 20/20] Add missing trailing newline to TranslateRegistration.cpp

---
 mlir/lib/Target/Cpp/TranslateRegistration.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index f1869ac9e8eda..2108ffd414c56 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -50,4 +50,4 @@ void registerToCppTranslation() {
       });
 }
 
-} // namespace mlir
\ No newline at end of file
+} // namespace mlir



More information about the Mlir-commits mailing list