[Mlir-commits] [mlir] Adding mlir models (PR #141158)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 29 09:09:58 PDT 2025


https://github.com/ajaden-codes updated https://github.com/llvm/llvm-project/pull/141158

>From 7f603ca713f7177cf3d2a60c9958ed1c9bfa569a Mon Sep 17 00:00:00 2001
From: Jaden Jaden <ajaden at google.com>
Date: Wed, 21 May 2025 23:10:51 +0000
Subject: [PATCH 1/5] Emit a C++ class when --class-name is set

---
 mlir/include/mlir/Target/Cpp/CppEmitter.h     |   2 +-
 mlir/lib/Target/Cpp/TranslateRegistration.cpp |   7 +-
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 168 +++++++++++++++---
 3 files changed, 147 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h
index 7c5747a888261..f87fc67662f55 100644
--- a/mlir/include/mlir/Target/Cpp/CppEmitter.h
+++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h
@@ -28,7 +28,7 @@ namespace emitc {
 /// with matching id are emitted.
 LogicalResult translateToCpp(Operation *op, raw_ostream &os,
                              bool declareVariablesAtTop = false,
-                             StringRef fileId = {});
+                             StringRef fileId = {}, StringRef className = {});
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 2108ffd414c56..d556927d7f904 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -33,13 +33,18 @@ void registerToCppTranslation() {
       "file-id", llvm::cl::desc("Emit emitc.file ops with matching id"),
       llvm::cl::init(""));
 
+  static llvm::cl::opt<std::string> className(
+      "class-name", llvm::cl::desc("Specify the class name for the generated C++ code"),
+      llvm::cl::init(""));
+
   TranslateFromMLIRRegistration reg(
       "mlir-to-cpp", "translate from mlir to cpp",
       [](Operation *op, raw_ostream &output) {
         return emitc::translateToCpp(
             op, output,
             /*declareVariablesAtTop=*/declareVariablesAtTop,
-            /*fileId=*/fileId);
+            /*fileId=*/fileId,
+            /*className=*/className);
       },
       [](DialectRegistry &registry) {
         // clang-format off
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 0c4975a13d301..4b67233aa7f89 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -68,6 +69,13 @@ inline LogicalResult interleaveCommaWithError(const Container &c,
   return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
 }
 
+template <typename Container, typename UnaryFunctor>
+inline LogicalResult interleaveWithNewLineWithError(const Container &c,
+                                              raw_ostream &os,
+                                              UnaryFunctor eachFn) {
+  return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << "\n"; });
+}
+
 /// Return the precedence of a operator as an integer, higher values
 /// imply higher precedence.
 static FailureOr<int> getOperatorPrecedence(Operation *operation) {
@@ -116,7 +124,7 @@ namespace {
 /// Emitter that uses dialect specific emitters to emit C++ code.
 struct CppEmitter {
   explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
-                      StringRef fileId);
+                      StringRef fileId, StringRef className);
 
   /// Emits attribute or returns failure.
   LogicalResult emitAttribute(Location loc, Attribute attr);
@@ -233,6 +241,9 @@ struct CppEmitter {
   /// be declared at the beginning of a function.
   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
 
+  // Returns if we should spit out a C++ class
+  std::string shouldSpitClass() { return className; };
+
   /// Returns whether this file op should be emitted
   bool shouldEmitFile(FileOp file) {
     return !fileId.empty() && file.getId() == fileId;
@@ -268,6 +279,9 @@ struct CppEmitter {
   /// Only emit file ops whos id matches this value.
   std::string fileId;
 
+  /// Name of the C++ class we will spit
+  std::string className;
+
   /// Map from value to name of C++ variable that contain the name.
   ValueMapper valueMapper;
 
@@ -1025,6 +1039,32 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
       }));
 }
 
+static LogicalResult printAttributes(CppEmitter &emitter,
+                                       Operation *functionOp,
+                                       ArrayRef<Type> arguments) {
+  raw_indented_ostream &os = emitter.ostream();
+  for (auto arg : arguments) {
+    os << "  ";
+    if (failed(emitter.emitType(functionOp->getLoc(), arg)))
+      return failure();
+    os << " " << arg << ";\n";
+  }
+  return success();
+}
+
+static LogicalResult printAttributes(CppEmitter &emitter,
+                                       Operation *functionOp,
+                                       Region::BlockArgListType arguments) {
+  raw_indented_ostream &os = emitter.ostream();
+  for (auto arg : arguments) {
+    os << "  ";
+    if (failed(emitter.emitType(functionOp->getLoc(), arg.getType())))
+      return failure();
+    os << " " << emitter.getOrCreateName(arg) << ";\n";
+  }
+  return success();
+}
+
 static LogicalResult printFunctionBody(CppEmitter &emitter,
                                        Operation *functionOp,
                                        Region::BlockListType &blocks) {
@@ -1138,35 +1178,107 @@ static LogicalResult printOperation(CppEmitter &emitter,
         "with multiple blocks needs variables declared at top");
   }
 
-  CppEmitter::Scope scope(emitter);
+  CppEmitter::Scope classScope(emitter);
   raw_indented_ostream &os = emitter.ostream();
-  if (functionOp.getSpecifiers()) {
-    for (Attribute specifier : functionOp.getSpecifiersAttr()) {
-      os << cast<StringAttr>(specifier).str() << " ";
+  Operation *operation = functionOp.getOperation();
+  if (!emitter.shouldSpitClass().empty()) {
+    os << "class " << emitter.shouldSpitClass() << " final {\n";
+    os << "public: \n";
+
+    if (functionOp.isExternal()) { 
+      if (failed(printAttributes(emitter, operation,
+                                  functionOp.getArgumentTypes())))
+        return failure();
+      return success();
     }
-  }
+    if (failed(printAttributes(emitter, operation, functionOp.getArguments())))
+      return failure();
+    
+    os << "\n";
+    
+    auto argAttrs = functionOp.getArgAttrs();
+    std::map<std::string, Value> fields; 
+    if (argAttrs)
+      for (auto [a,v] : zip(*argAttrs, functionOp.getArguments())) { 
+        if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
+          auto nv = da.getNamed("tf_saved_model.index_path")->getValue(); 
+          auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str(); 
+          fields[name] = v;
+        }
+      }
 
-  if (failed(emitter.emitTypes(functionOp.getLoc(),
-                               functionOp.getFunctionType().getResults())))
-    return failure();
-  os << " " << functionOp.getName();
+    for (auto & r : functionOp->getRegions()) 
+      for (auto &b : r.getBlocks())
+        for (auto &opt : b.getOperations())  
+          if (auto alloc = dyn_cast<memref::AllocOp>(opt)) { 
+            auto name = emitter.getOrCreateName(alloc).str(); 
+            fields[name] = alloc;
+            if (failed(emitter.emitType(alloc.getLoc(), alloc.getType().getElementType())))
+              return failure();
+            os << " [" << alloc.getType().getNumElements() <<"] "; 
+            os << " " << name << ";\n";
+          }
+    os << "  std::map<std::string, char*> _buffer_map {"; 
+    for (auto &[n,v]:fields)
+      os << "{ \"" << n << "\"" << ", reinterpret_cast<char*>(" << emitter.getOrCreateName(v) << ") },";
+    os << "  };\n";
+    os << "  char* getBufferForName(const std::string& name) const {\n";
+    os << "    auto it = _buffer_map.find(name);\n";
+    os << "    return (it == _buffer_map.end()) ? nullptr : it->second;\n";
+    os << "  }\n\n";
 
-  os << "(";
-  Operation *operation = functionOp.getOperation();
-  if (functionOp.isExternal()) {
-    if (failed(printFunctionArgs(emitter, operation,
-                                 functionOp.getArgumentTypes())))
+    os.indent();
+
+    // Begin defining the main function where we have the actual execution
+    if (functionOp.getSpecifiers()) {
+      for (Attribute specifier : functionOp.getSpecifiersAttr()) {
+        os << cast<StringAttr>(specifier).str() << " ";
+      }
+    }
+
+    if (failed(emitter.emitTypes(functionOp.getLoc(),
+                                functionOp.getFunctionType().getResults())))
       return failure();
-    os << ");";
-    return success();
-  }
-  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
-    return failure();
-  os << ") {\n";
-  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
-    return failure();
-  os << "}\n";
+    os << " " << functionOp.getName();
 
+    //Defining the functionBody
+    os << "() { \n"; // Begin defining the function header (We need the function name and output type to remain without the rest of the function header)
+    os.indent();
+    if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+      return failure();
+    os.unindent();
+    os << "}\n";
+    os.unindent();
+    os << "};\n";
+  } else {
+    if (functionOp.getSpecifiers()) {
+      for (Attribute specifier : functionOp.getSpecifiersAttr()) {
+        os << cast<StringAttr>(specifier).str() << " ";
+      }
+    }
+
+    if (failed(emitter.emitTypes(functionOp.getLoc(),
+                                functionOp.getFunctionType().getResults())))
+      return failure();
+    os << " " << functionOp.getName();
+
+    os << "(";
+    Operation *operation = functionOp.getOperation();
+    if (functionOp.isExternal()) {
+      if (failed(printFunctionArgs(emitter, operation,
+                                  functionOp.getArgumentTypes())))
+        return failure();
+      os << ");";
+      return success();
+    }
+    if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
+      return failure();
+    os << ") {\n";
+    if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+      return failure();
+    os << "}\n";
+  }
+  
   return success();
 }
 
@@ -1202,9 +1314,9 @@ static LogicalResult printOperation(CppEmitter &emitter,
 }
 
 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
-                       StringRef fileId)
+                       StringRef fileId, StringRef className)
     : os(os), declareVariablesAtTop(declareVariablesAtTop),
-      fileId(fileId.str()) {
+      fileId(fileId.str()), className(className.str()) {
   valueInScopeCount.push(0);
   labelInScopeCount.push(0);
 }
@@ -1787,7 +1899,7 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
 
 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
                                     bool declareVariablesAtTop,
-                                    StringRef fileId) {
-  CppEmitter emitter(os, declareVariablesAtTop, fileId);
+                                    StringRef fileId, StringRef className) {
+  CppEmitter emitter(os, declareVariablesAtTop, fileId, className);
   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
 }

>From 9d257ce53bdfe3a504e6d28cbed87eeef5a4531a Mon Sep 17 00:00:00 2001
From: Jaden Jaden <ajaden at google.com>
Date: Thu, 22 May 2025 23:17:00 +0000
Subject: [PATCH 2/5] fixes

---
 mlir/lib/Target/Cpp/TranslateRegistration.cpp |   4 +-
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 233 ++++++++++--------
 2 files changed, 128 insertions(+), 109 deletions(-)

diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index d556927d7f904..17fcc058eef04 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -34,7 +34,9 @@ void registerToCppTranslation() {
       llvm::cl::init(""));
 
   static llvm::cl::opt<std::string> className(
-      "class-name", llvm::cl::desc("Specify the class name for the generated C++ code"),
+      "class-name",
+      llvm::cl::desc("Optional class name. If specified, the output will be a "
+                     "class where the function(s) in the module are members."),
       llvm::cl::init(""));
 
   TranslateFromMLIRRegistration reg(
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 4b67233aa7f89..08303428b91af 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -9,7 +9,6 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -73,7 +72,8 @@ template <typename Container, typename UnaryFunctor>
 inline LogicalResult interleaveWithNewLineWithError(const Container &c,
                                               raw_ostream &os,
                                               UnaryFunctor eachFn) {
-  return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << "\n"; });
+  return interleaveWithError(c.begin(), c.end(), eachFn,
+                             [&]() { os << ";\n"; });
 }
 
 /// Return the precedence of a operator as an integer, higher values
@@ -241,8 +241,11 @@ struct CppEmitter {
   /// be declared at the beginning of a function.
   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
 
-  // Returns if we should spit out a C++ class
-  std::string shouldSpitClass() { return className; };
+  // Returns whether we should emit a C++ class
+  bool shouldPrintClass() { return !className.empty(); };
+
+  // Returns the class name to emit
+  std::string getClassName() { return className; };
 
   /// Returns whether this file op should be emitted
   bool shouldEmitFile(FileOp file) {
@@ -1039,30 +1042,25 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
       }));
 }
 
-static LogicalResult printAttributes(CppEmitter &emitter,
-                                       Operation *functionOp,
-                                       ArrayRef<Type> arguments) {
+static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
+                                 ArrayRef<Type> arguments) {
   raw_indented_ostream &os = emitter.ostream();
-  for (auto arg : arguments) {
-    os << "  ";
-    if (failed(emitter.emitType(functionOp->getLoc(), arg)))
-      return failure();
-    os << " " << arg << ";\n";
-  }
-  return success();
+
+  return (interleaveWithNewLineWithError(
+      arguments, os, [&](Type arg) -> LogicalResult {
+        return emitter.emitType(functionOp->getLoc(), arg);
+      }));
 }
 
-static LogicalResult printAttributes(CppEmitter &emitter,
-                                       Operation *functionOp,
-                                       Region::BlockArgListType arguments) {
+static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
+                                 Region::BlockArgListType arguments) {
   raw_indented_ostream &os = emitter.ostream();
-  for (auto arg : arguments) {
-    os << "  ";
-    if (failed(emitter.emitType(functionOp->getLoc(), arg.getType())))
-      return failure();
-    os << " " << emitter.getOrCreateName(arg) << ";\n";
-  }
-  return success();
+
+  return (interleaveWithNewLineWithError(
+      arguments, os, [&](BlockArgument arg) -> LogicalResult {
+        return emitter.emitVariableDeclaration(
+            functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
+      }));
 }
 
 static LogicalResult printFunctionBody(CppEmitter &emitter,
@@ -1169,116 +1167,135 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return success();
 }
 
-static LogicalResult printOperation(CppEmitter &emitter,
-                                    emitc::FuncOp functionOp) {
-  // We need to declare variables at top if the function has multiple blocks.
-  if (!emitter.shouldDeclareVariablesAtTop() &&
-      functionOp.getBlocks().size() > 1) {
-    return functionOp.emitOpError(
-        "with multiple blocks needs variables declared at top");
-  }
-
-  CppEmitter::Scope classScope(emitter);
+static LogicalResult printFunctionHeader(CppEmitter &emitter,
+                                         emitc::FuncOp functionOp) {
   raw_indented_ostream &os = emitter.ostream();
   Operation *operation = functionOp.getOperation();
-  if (!emitter.shouldSpitClass().empty()) {
-    os << "class " << emitter.shouldSpitClass() << " final {\n";
-    os << "public: \n";
+  if (functionOp.getSpecifiers()) {
+    for (Attribute specifier : functionOp.getSpecifiersAttr()) {
+      os << cast<StringAttr>(specifier).str() << " ";
+    }
+  }
 
-    if (functionOp.isExternal()) { 
-      if (failed(printAttributes(emitter, operation,
-                                  functionOp.getArgumentTypes())))
+  if (failed(emitter.emitTypes(functionOp.getLoc(),
+                               functionOp.getFunctionType().getResults())))
+    return failure();
+  os << " " << functionOp.getName();
+  if (!emitter.shouldPrintClass()) {
+    os << "(";
+    if (functionOp.isExternal()) {
+      if (failed(printFunctionArgs(emitter, operation,
+                                   functionOp.getArgumentTypes())))
         return failure();
+      os << ");";
       return success();
     }
-    if (failed(printAttributes(emitter, operation, functionOp.getArguments())))
+    if (failed(
+            printFunctionArgs(emitter, operation, functionOp.getArguments())))
       return failure();
-    
-    os << "\n";
-    
-    auto argAttrs = functionOp.getArgAttrs();
-    std::map<std::string, Value> fields; 
-    if (argAttrs)
-      for (auto [a,v] : zip(*argAttrs, functionOp.getArguments())) { 
-        if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
-          auto nv = da.getNamed("tf_saved_model.index_path")->getValue(); 
-          auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str(); 
-          fields[name] = v;
-        }
-      }
+    os << ") {\n";
 
-    for (auto & r : functionOp->getRegions()) 
-      for (auto &b : r.getBlocks())
-        for (auto &opt : b.getOperations())  
-          if (auto alloc = dyn_cast<memref::AllocOp>(opt)) { 
-            auto name = emitter.getOrCreateName(alloc).str(); 
-            fields[name] = alloc;
-            if (failed(emitter.emitType(alloc.getLoc(), alloc.getType().getElementType())))
-              return failure();
-            os << " [" << alloc.getType().getNumElements() <<"] "; 
-            os << " " << name << ";\n";
-          }
-    os << "  std::map<std::string, char*> _buffer_map {"; 
-    for (auto &[n,v]:fields)
-      os << "{ \"" << n << "\"" << ", reinterpret_cast<char*>(" << emitter.getOrCreateName(v) << ") },";
-    os << "  };\n";
-    os << "  char* getBufferForName(const std::string& name) const {\n";
-    os << "    auto it = _buffer_map.find(name);\n";
-    os << "    return (it == _buffer_map.end()) ? nullptr : it->second;\n";
-    os << "  }\n\n";
+  } else {
+    os << "() { \n";
+  }
 
-    os.indent();
+  return success();
+}
 
-    // Begin defining the main function where we have the actual execution
-    if (functionOp.getSpecifiers()) {
-      for (Attribute specifier : functionOp.getSpecifiersAttr()) {
-        os << cast<StringAttr>(specifier).str() << " ";
+static LogicalResult emitClassBody(CppEmitter &emitter,
+                                   emitc::FuncOp functionOp) {
+  raw_indented_ostream &os = emitter.ostream();
+  Operation *operation = functionOp.getOperation();
+  auto argAttrs = functionOp.getArgAttrs();
+  std::map<std::string, Value> fields;
+  os << "\nstd::map<std::string, char*> _buffer_map {";
+  if (argAttrs) // We can have no argattrs in the case that the function has no
+                // inputs nor outputs -> procedure
+    for (const auto [a, v] : zip(*argAttrs, functionOp.getArguments())) {
+      if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
+        auto nv =
+            da.getNamed("tf_saved_model.index_path")
+                ->getValue(); // From what I've seen so far, this is the only
+                              // way to have the argAttrs keys. If there is
+                              // another way, I need to run the tests to see and
+                              // see what cases trigger this change in format.
+        auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str();
+        fields[name] = v; // The only way to not have unique names is in the
+                          // case that you have duplicate arguments in your
+                          // tensorflow/python function. By python syntax rules,
+                          // you're not allowed to have that(Current assumption)
+        os << "{ \"" << name << "\"" << ", reinterpret_cast<char*>("
+           << emitter.getOrCreateName(v) << ") },";
       }
     }
+  else
+    return failure();
+
+  os << " };\n";
+  os << "char* getBufferForName(const std::string& name) const {\n";
+  os.indent();
+  os.indent();
+  os << "auto it = _buffer_map.find(name);\n";
+  os << "return (it == _buffer_map.end()) ? nullptr : it->second;\n";
+  os.unindent();
+  os.unindent();
+  os << "}\n\n";
+
+  if (failed(printFunctionHeader(emitter, functionOp)))
+    return failure();
+
+  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+    return failure();
+  os << "}\n";
+  os.unindent();
+  os << "};\n";
+
+  return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::FuncOp functionOp) {
+  // We need to declare variables at top if the function has multiple blocks.
+  if (!emitter.shouldDeclareVariablesAtTop() &&
+      functionOp.getBlocks().size() > 1) {
+    return functionOp.emitOpError(
+        "with multiple blocks needs variables declared at top");
+  }
 
-    if (failed(emitter.emitTypes(functionOp.getLoc(),
-                                functionOp.getFunctionType().getResults())))
+  CppEmitter::Scope classScope(emitter);
+  raw_indented_ostream &os = emitter.ostream();
+  Operation *operation = functionOp.getOperation();
+  if (!emitter.shouldPrintClass()) {
+
+    if (failed(printFunctionHeader(emitter, functionOp)))
       return failure();
-    os << " " << functionOp.getName();
 
-    //Defining the functionBody
-    os << "() { \n"; // Begin defining the function header (We need the function name and output type to remain without the rest of the function header)
-    os.indent();
-    if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+    if (failed(printFunctionBody(
+            emitter, operation,
+            functionOp.getBlocks()))) // This is the only similarity between the
+                                      // function and the class
       return failure();
-    os.unindent();
     os << "}\n";
-    os.unindent();
-    os << "};\n";
-  } else {
-    if (functionOp.getSpecifiers()) {
-      for (Attribute specifier : functionOp.getSpecifiersAttr()) {
-        os << cast<StringAttr>(specifier).str() << " ";
-      }
-    }
 
-    if (failed(emitter.emitTypes(functionOp.getLoc(),
-                                functionOp.getFunctionType().getResults())))
-      return failure();
-    os << " " << functionOp.getName();
+  } else {
+    os << "class " << emitter.getClassName() << " final {\n";
+    os << "public: \n";
+    os.indent();
 
-    os << "(";
-    Operation *operation = functionOp.getOperation();
     if (functionOp.isExternal()) {
-      if (failed(printFunctionArgs(emitter, operation,
-                                  functionOp.getArgumentTypes())))
+      if (failed(
+              printFields(emitter, operation, functionOp.getArgumentTypes())))
         return failure();
-      os << ");";
       return success();
     }
-    if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
+    if (failed(printFields(emitter, operation, functionOp.getArguments())))
       return failure();
-    os << ") {\n";
-    if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+    os << ";\n";
+
+    if (failed(emitClassBody(emitter, functionOp)))
       return failure();
-    os << "}\n";
   }
-  
+
   return success();
 }
 

>From 2917dc4191ac1d5c96cbc544a73d14f4b8ecc1f4 Mon Sep 17 00:00:00 2001
From: Jaden Jaden <ajaden at google.com>
Date: Tue, 27 May 2025 18:39:14 +0000
Subject: [PATCH 3/5] Added flags to increase readability

---
 mlir/include/mlir/Target/Cpp/CppEmitter.h     |   4 +-
 mlir/lib/Target/Cpp/TranslateRegistration.cpp |  38 +++-
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 181 ++++++++----------
 mlir/test/mlir-translate/emit-class.mlir      |  49 +++++
 4 files changed, 170 insertions(+), 102 deletions(-)
 create mode 100644 mlir/test/mlir-translate/emit-class.mlir

diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h
index f87fc67662f55..d1a6c1dc12d4c 100644
--- a/mlir/include/mlir/Target/Cpp/CppEmitter.h
+++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h
@@ -28,7 +28,9 @@ namespace emitc {
 /// with matching id are emitted.
 LogicalResult translateToCpp(Operation *op, raw_ostream &os,
                              bool declareVariablesAtTop = false,
-                             StringRef fileId = {}, StringRef className = {});
+                             StringRef fileId = {}, bool emitClass = false,
+                             StringRef className = {},
+                             StringRef fieldNameAttribute = {});
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 17fcc058eef04..69e0ab01bb71d 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -33,20 +33,50 @@ void registerToCppTranslation() {
       "file-id", llvm::cl::desc("Emit emitc.file ops with matching id"),
       llvm::cl::init(""));
 
+  static llvm::cl::opt<bool> emitClass(
+      "emit-class",
+      llvm::cl::desc("If specified, the output will be a class where "
+                     "the function(s) in the module are members. "
+                     "Enables class-related options."),
+      llvm::cl::init(false));
+
   static llvm::cl::opt<std::string> className(
       "class-name",
-      llvm::cl::desc("Optional class name. If specified, the output will be a "
-                     "class where the function(s) in the module are members."),
+      llvm::cl::desc("Mandatory class name if --emit-class is set."),
+      llvm::cl::init(""));
+
+  static llvm::cl::opt<std::string> fieldNameAttribute(
+      "field-name-attribute",
+      llvm::cl::desc("Mandatory name of the attribute to use as field name if "
+                     "--emit-class is set."),
       llvm::cl::init(""));
 
   TranslateFromMLIRRegistration reg(
       "mlir-to-cpp", "translate from mlir to cpp",
       [](Operation *op, raw_ostream &output) {
+        if (emitClass) {
+          if (className.empty()) {
+            llvm::errs() << "Error: --class-name is mandatory when "
+                            "--emit-class is set.\n";
+            return mlir::failure();
+          }
+          if (fieldNameAttribute.empty()) {
+            llvm::errs() << "Error: --field-name-attribute is mandatory when "
+                            "--emit-class is set.\n";
+            return mlir::failure();
+          }
+          return emitc::translateToCpp(
+              op, output,
+              /*declareVariablesAtTop=*/declareVariablesAtTop,
+              /*fileId=*/fileId, /*emitClass=*/emitClass,
+              /*className=*/className,
+              /*fieldNameAttribute=*/fieldNameAttribute);
+        }
         return emitc::translateToCpp(
             op, output,
             /*declareVariablesAtTop=*/declareVariablesAtTop,
-            /*fileId=*/fileId,
-            /*className=*/className);
+            /*fileId=*/fileId, /*emitClass=*/emitClass, /*className=*/className,
+            /*fieldNameAttribute=*/fieldNameAttribute);
       },
       [](DialectRegistry &registry) {
         // clang-format off
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 08303428b91af..4f12e9b5669ad 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -22,6 +22,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <stack>
@@ -124,7 +125,8 @@ namespace {
 /// Emitter that uses dialect specific emitters to emit C++ code.
 struct CppEmitter {
   explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
-                      StringRef fileId, StringRef className);
+                      StringRef fileId, bool emitClass, StringRef className,
+                      StringRef fieldNameAttribute);
 
   /// Emits attribute or returns failure.
   LogicalResult emitAttribute(Location loc, Attribute attr);
@@ -242,11 +244,14 @@ struct CppEmitter {
   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
 
   // Returns whether we should emit a C++ class
-  bool shouldPrintClass() { return !className.empty(); };
+  bool shouldPrintClass() { return emitClass; };
 
   // Returns the class name to emit
   std::string getClassName() { return className; };
 
+  // Returns the field name to use in the map
+  std::string getfieldNameAttribute() { return fieldNameAttribute; };
+
   /// Returns whether this file op should be emitted
   bool shouldEmitFile(FileOp file) {
     return !fileId.empty() && file.getId() == fileId;
@@ -282,9 +287,18 @@ struct CppEmitter {
   /// Only emit file ops whos id matches this value.
   std::string fileId;
 
-  /// Name of the C++ class we will spit
+  /// Controls whether the output should be a C++ class.
+  /// If true, the generated C++ code will be encapsulated within a class,
+  /// and functions from the input module will become its member functions.
+  bool emitClass;
+
+  /// The specified name for the generated C++ class
   std::string className;
 
+  /// Name of the MLIR attribute to use as a field name within the generated
+  /// class
+  std::string fieldNameAttribute;
+
   /// Map from value to name of C++ variable that contain the name.
   ValueMapper valueMapper;
 
@@ -1042,16 +1056,6 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
       }));
 }
 
-static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
-                                 ArrayRef<Type> arguments) {
-  raw_indented_ostream &os = emitter.ostream();
-
-  return (interleaveWithNewLineWithError(
-      arguments, os, [&](Type arg) -> LogicalResult {
-        return emitter.emitType(functionOp->getLoc(), arg);
-      }));
-}
-
 static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
                                  Region::BlockArgListType arguments) {
   raw_indented_ostream &os = emitter.ostream();
@@ -1167,71 +1171,38 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return success();
 }
 
-static LogicalResult printFunctionHeader(CppEmitter &emitter,
-                                         emitc::FuncOp functionOp) {
+static LogicalResult emitClassFields(CppEmitter &emitter,
+                                     emitc::FuncOp functionOp) {
   raw_indented_ostream &os = emitter.ostream();
+  auto argAttrs = functionOp.getArgAttrs();
   Operation *operation = functionOp.getOperation();
-  if (functionOp.getSpecifiers()) {
-    for (Attribute specifier : functionOp.getSpecifiersAttr()) {
-      os << cast<StringAttr>(specifier).str() << " ";
-    }
-  }
-
-  if (failed(emitter.emitTypes(functionOp.getLoc(),
-                               functionOp.getFunctionType().getResults())))
+  if (failed(printFields(emitter, operation, functionOp.getArguments())))
     return failure();
-  os << " " << functionOp.getName();
-  if (!emitter.shouldPrintClass()) {
-    os << "(";
-    if (functionOp.isExternal()) {
-      if (failed(printFunctionArgs(emitter, operation,
-                                   functionOp.getArgumentTypes())))
-        return failure();
-      os << ");";
-      return success();
-    }
-    if (failed(
-            printFunctionArgs(emitter, operation, functionOp.getArguments())))
-      return failure();
-    os << ") {\n";
+  os << ";\n";
 
-  } else {
-    os << "() { \n";
-  }
-
-  return success();
-}
-
-static LogicalResult emitClassBody(CppEmitter &emitter,
-                                   emitc::FuncOp functionOp) {
-  raw_indented_ostream &os = emitter.ostream();
-  Operation *operation = functionOp.getOperation();
-  auto argAttrs = functionOp.getArgAttrs();
   std::map<std::string, Value> fields;
-  os << "\nstd::map<std::string, char*> _buffer_map {";
-  if (argAttrs) // We can have no argattrs in the case that the function has no
-                // inputs nor outputs -> procedure
+  os << "std::map<std::string, char*> _buffer_map {";
+  if (argAttrs) {
+    bool isFirst = true;
     for (const auto [a, v] : zip(*argAttrs, functionOp.getArguments())) {
       if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
-        auto nv =
-            da.getNamed("tf_saved_model.index_path")
-                ->getValue(); // From what I've seen so far, this is the only
-                              // way to have the argAttrs keys. If there is
-                              // another way, I need to run the tests to see and
-                              // see what cases trigger this change in format.
+        auto nv = da.getNamed(emitter.getfieldNameAttribute())->getValue();
         auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str();
-        fields[name] = v; // The only way to not have unique names is in the
-                          // case that you have duplicate arguments in your
-                          // tensorflow/python function. By python syntax rules,
-                          // you're not allowed to have that(Current assumption)
+        auto Ins = fields.insert({name, v});
+        if (!Ins.second)
+          return failure();
+        if (!isFirst) {
+          os << ",";
+        }
         os << "{ \"" << name << "\"" << ", reinterpret_cast<char*>("
-           << emitter.getOrCreateName(v) << ") },";
+           << emitter.getOrCreateName(v) << ") }";
+        isFirst = false;
       }
     }
-  else
+  } else
     return failure();
 
-  os << " };\n";
+  os << "};";
   os << "char* getBufferForName(const std::string& name) const {\n";
   os.indent();
   os.indent();
@@ -1241,15 +1212,6 @@ static LogicalResult emitClassBody(CppEmitter &emitter,
   os.unindent();
   os << "}\n\n";
 
-  if (failed(printFunctionHeader(emitter, functionOp)))
-    return failure();
-
-  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
-    return failure();
-  os << "}\n";
-  os.unindent();
-  os << "};\n";
-
   return success();
 }
 
@@ -1262,38 +1224,58 @@ static LogicalResult printOperation(CppEmitter &emitter,
         "with multiple blocks needs variables declared at top");
   }
 
-  CppEmitter::Scope classScope(emitter);
+  CppEmitter::Scope scope(emitter);
   raw_indented_ostream &os = emitter.ostream();
   Operation *operation = functionOp.getOperation();
-  if (!emitter.shouldPrintClass()) {
-
-    if (failed(printFunctionHeader(emitter, functionOp)))
-      return failure();
-
-    if (failed(printFunctionBody(
-            emitter, operation,
-            functionOp.getBlocks()))) // This is the only similarity between the
-                                      // function and the class
+  if (emitter.shouldPrintClass()) {
+    if (functionOp.isExternal())
       return failure();
-    os << "}\n";
-
-  } else {
     os << "class " << emitter.getClassName() << " final {\n";
     os << "public: \n";
     os.indent();
 
+    if (failed(emitClassFields(emitter, functionOp)))
+      return failure();
+  }
+
+  if (functionOp.getSpecifiers()) {
+    for (Attribute specifier : functionOp.getSpecifiersAttr()) {
+      os << cast<StringAttr>(specifier).str() << " ";
+    }
+  }
+
+  if (failed(emitter.emitTypes(functionOp.getLoc(),
+                               functionOp.getFunctionType().getResults())))
+    return failure();
+  os << " " << functionOp.getName();
+
+  os << "(";
+
+  if (emitter.shouldPrintClass())
+    os << ") { \n";
+  else {
     if (functionOp.isExternal()) {
-      if (failed(
-              printFields(emitter, operation, functionOp.getArgumentTypes())))
+      if (failed(printFunctionArgs(emitter, operation,
+                                   functionOp.getArgumentTypes())))
         return failure();
+      os << ");";
       return success();
     }
-    if (failed(printFields(emitter, operation, functionOp.getArguments())))
+    if (failed(
+            printFunctionArgs(emitter, operation, functionOp.getArguments())))
       return failure();
-    os << ";\n";
+    os << ") {\n";
+  }
 
-    if (failed(emitClassBody(emitter, functionOp)))
-      return failure();
+  if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+    return failure();
+
+  if (emitter.shouldPrintClass()) {
+    os << "}\n";
+    os.unindent();
+    os << "};\n";
+  } else {
+    os << "}\n";
   }
 
   return success();
@@ -1331,9 +1313,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
 }
 
 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
-                       StringRef fileId, StringRef className)
+                       StringRef fileId, bool emitClass, StringRef className,
+                       StringRef fieldNameAttribute)
     : os(os), declareVariablesAtTop(declareVariablesAtTop),
-      fileId(fileId.str()), className(className.str()) {
+      fileId(fileId.str()), emitClass(emitClass), className(className.str()),
+      fieldNameAttribute(fieldNameAttribute.str()) {
   valueInScopeCount.push(0);
   labelInScopeCount.push(0);
 }
@@ -1916,7 +1900,10 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
 
 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
                                     bool declareVariablesAtTop,
-                                    StringRef fileId, StringRef className) {
-  CppEmitter emitter(os, declareVariablesAtTop, fileId, className);
+                                    StringRef fileId, bool emitClass,
+                                    StringRef className,
+                                    StringRef fieldNameAttribute) {
+  CppEmitter emitter(os, declareVariablesAtTop, fileId, emitClass, className,
+                     fieldNameAttribute);
   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
 }
diff --git a/mlir/test/mlir-translate/emit-class.mlir b/mlir/test/mlir-translate/emit-class.mlir
new file mode 100644
index 0000000000000..eddc2181a0b19
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path /tmp/model_emitc.mlir | FileCheck %s --check-prefix=ADDER_TEST
+
+// ADDER_TEST-LABEL: class MyAdder final {
+// ADDER_TEST-NEXT: public:
+// ADDER_TEST-DAG: float v1[1];
+// ADDER_TEST-DAG: float v2[1];
+// ADDER_TEST-DAG: float v3[1];      
+// ADDER_TEST-NEXT: std::map<std::string, char*> _buffer_map {{ "another_feature", reinterpret_cast<char*>(v1) },{ "some_feature", reinterpret_cast<char*>(v2) },{ "output_0", reinterpret_cast<char*>(v3) }};
+// ADDER_TEST-NEXT: char* getBufferForName(const std::string& name) const {
+// ADDER_TEST-NEXT:     auto it = _buffer_map.find(name);
+// ADDER_TEST-NEXT:     return (it == _buffer_map.end()) ? nullptr : it->second;
+// ADDER_TEST-NEXT: }
+// ADDER_TEST-NEXT: void main() {
+// ADDER_TEST-NEXT:     size_t v4 = 0;
+// ADDER_TEST-NEXT:     float v5 = v2[v4];
+// ADDER_TEST-NEXT:     float v6 = v1[v4];
+// ADDER_TEST-NEXT:     float v7 = v5 + v6;
+// ADDER_TEST-NEXT:     v3[v4] = v7;
+// ADDER_TEST-NEXT:     return;
+// ADDER_TEST-NEXT: }
+// ADDER_TEST-NEXT: };
+
+// ---
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyMultiOutput --field-name-attribute=tf_saved_model.index_path /tmp/model_multi_out_emitc.mlir | FileCheck %s --check-prefix=MULTI_OUT
+
+// MULTI_OUT-LABEL: class MyMultiOutput final {
+// MULTI_OUT-NEXT: public:
+// MULTI_OUT-DAG: float v1[1];
+// MULTI_OUT-DAG: float v2[1];
+// MULTI_OUT-DAG: float v3[1];
+// MULTI_OUT-DAG: float v4[1];
+// MULTI_OUT: std::map<std::string, char*> _buffer_map {{ "b", reinterpret_cast<char*>(v1) },{ "a", reinterpret_cast<char*>(v2) },{ "output_1", reinterpret_cast<char*>(v3) },{ "output_0", reinterpret_cast<char*>(v4) }, };
+// MULTI_OUT-NEXT: char* getBufferForName(const std::string& name) const {
+// MULTI_OUT-NEXT:     auto it = _buffer_map.find(name);
+// MULTI_OUT-NEXT:     return (it == _buffer_map.end()) ? nullptr : it->second;
+// MULTI_OUT-NEXT: }
+// MULTI_OUT-NEXT: void main() {
+// MULTI_OUT-NEXT:     size_t v5 = 0;
+// MULTI_OUT-NEXT:     float v6 = v2[v5];
+// MULTI_OUT-NEXT:     float v7 = v1[v5];
+// MULTI_OUT-NEXT:     float v8 = v6 + v7;
+// MULTI_OUT-NEXT:     v4[v5] = v8;
+// MULTI_OUT-NEXT:     float v9 = v2[v5];
+// MULTI_OUT-NEXT:     float v10 = v1[v5];
+// MULTI_OUT-NEXT:     float v11 = v9 - v10;
+// MULTI_OUT-NEXT:     v3[v5] = v11;
+// MULTI_OUT-NEXT:     return;
+// MULTI_OUT-NEXT: }
+// MULTI_OUT-NEXT: };

>From bd3b5bd498a4b98246fe9f2d44fdf217c9f8fe58 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Wed, 28 May 2025 21:33:44 +0000
Subject: [PATCH 4/5] Working tests to emit-class

---
 mlir/lib/Target/Cpp/TranslateRegistration.cpp | 10 +--
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 35 ++++----
 .../emit-class-neg-external.mlir              |  8 ++
 .../emit-class-neg-noArgAttrs.mlir            | 15 ++++
 mlir/test/mlir-translate/emit-class.mlir      | 81 ++++++++-----------
 5 files changed, 80 insertions(+), 69 deletions(-)
 create mode 100644 mlir/test/mlir-translate/emit-class-neg-external.mlir
 create mode 100644 mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir

diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 69e0ab01bb71d..9e1533d34f6ea 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -36,20 +36,20 @@ void registerToCppTranslation() {
   static llvm::cl::opt<bool> emitClass(
       "emit-class",
       llvm::cl::desc("If specified, the output will be a class where "
-                     "the function(s) in the module are members. "
-                     "Enables class-related options."),
+                     "the function(s) in the module are methods "
+                     "Enables class-related options"),
       llvm::cl::init(false));
 
   static llvm::cl::opt<std::string> className(
       "class-name",
-      llvm::cl::desc("Mandatory class name if --emit-class is set."),
+      llvm::cl::desc("Mandatory class name if --emit-class is set"),
       llvm::cl::init(""));
 
   static llvm::cl::opt<std::string> fieldNameAttribute(
       "field-name-attribute",
       llvm::cl::desc("Mandatory name of the attribute to use as field name if "
-                     "--emit-class is set."),
-      llvm::cl::init(""));
+                     "--emit-class is set(default=tf_saved_model.index_path)"),
+      llvm::cl::init("tf_saved_model.index_path"));
 
   TranslateFromMLIRRegistration reg(
       "mlir-to-cpp", "translate from mlir to cpp",
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 4f12e9b5669ad..5b2fa7a624d37 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -22,7 +22,6 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FormatVariadic.h"
 #include <stack>
@@ -290,14 +289,14 @@ struct CppEmitter {
   /// Controls whether the output should be a C++ class.
   /// If true, the generated C++ code will be encapsulated within a class,
   /// and functions from the input module will become its member functions.
-  bool emitClass;
+  const bool emitClass;
 
   /// The specified name for the generated C++ class
-  std::string className;
+  const std::string className;
 
   /// Name of the MLIR attribute to use as a field name within the generated
   /// class
-  std::string fieldNameAttribute;
+  const std::string fieldNameAttribute;
 
   /// Map from value to name of C++ variable that contain the name.
   ValueMapper valueMapper;
@@ -1181,9 +1180,8 @@ static LogicalResult emitClassFields(CppEmitter &emitter,
   os << ";\n";
 
   std::map<std::string, Value> fields;
-  os << "std::map<std::string, char*> _buffer_map {";
+  os << "\nstd::map<std::string, char*> _buffer_map {";
   if (argAttrs) {
-    bool isFirst = true;
     for (const auto [a, v] : zip(*argAttrs, functionOp.getArguments())) {
       if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
         auto nv = da.getNamed(emitter.getfieldNameAttribute())->getValue();
@@ -1191,18 +1189,14 @@ static LogicalResult emitClassFields(CppEmitter &emitter,
         auto Ins = fields.insert({name, v});
         if (!Ins.second)
           return failure();
-        if (!isFirst) {
-          os << ",";
-        }
-        os << "{ \"" << name << "\"" << ", reinterpret_cast<char*>("
-           << emitter.getOrCreateName(v) << ") }";
-        isFirst = false;
+        os << " { \"" << name << "\"" << ", reinterpret_cast<char*>("
+           << emitter.getOrCreateName(v) << ") }, ";
       }
     }
   } else
     return failure();
 
-  os << "};";
+  os << "};\n";
   os << "char* getBufferForName(const std::string& name) const {\n";
   os.indent();
   os.indent();
@@ -1228,8 +1222,15 @@ static LogicalResult printOperation(CppEmitter &emitter,
   raw_indented_ostream &os = emitter.ostream();
   Operation *operation = functionOp.getOperation();
   if (emitter.shouldPrintClass()) {
-    if (functionOp.isExternal())
+    if (functionOp.isExternal()) {
+      // TODO: Determine the best long-term strategy for external functions.
+      // Currently, we're stopping here to prevent downstream errors.
+      os << "Warning: Cannot process external function '"
+         << functionOp.getName() << "'. "
+         << "It lacks a body, and attempting to continue would lead to errors "
+            "due to missing argument details.\n";
       return failure();
+    }
     os << "class " << emitter.getClassName() << " final {\n";
     os << "public: \n";
     os.indent();
@@ -1251,9 +1252,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
 
   os << "(";
 
-  if (emitter.shouldPrintClass())
-    os << ") { \n";
-  else {
+  if (!emitter.shouldPrintClass()) {
     if (functionOp.isExternal()) {
       if (failed(printFunctionArgs(emitter, operation,
                                    functionOp.getArgumentTypes())))
@@ -1264,8 +1263,8 @@ static LogicalResult printOperation(CppEmitter &emitter,
     if (failed(
             printFunctionArgs(emitter, operation, functionOp.getArguments())))
       return failure();
-    os << ") {\n";
   }
+  os << ") {\n";
 
   if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
     return failure();
diff --git a/mlir/test/mlir-translate/emit-class-neg-external.mlir b/mlir/test/mlir-translate/emit-class-neg-external.mlir
new file mode 100644
index 0000000000000..347992c467759
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class-neg-external.mlir
@@ -0,0 +1,8 @@
+/// An external function - has no body
+// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s 2>&1 | FileCheck %s
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+  emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]}
+}
+
+// CHECK: Warning: Cannot process external function 'extern_func'. It lacks a body, and attempting to continue would lead to errors due to missing argument details.
\ No newline at end of file
diff --git a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
new file mode 100644
index 0000000000000..77e89da2f0a4f
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
@@ -0,0 +1,15 @@
+/// The function has no argument attributes
+// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=ArgAttrs --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+  emitc.func @foo(%arg0 : i32) {
+    emitc.call_opaque "bar" (%arg0) : (i32) -> ()
+    emitc.return
+  }
+}
+
+// CHECK: class ArgAttrs final {
+// CHECK-NEXT: public: 
+// CHECK-NEXT:   int32_t v1;
+// CHECK-EMPTY: 
+// CHECK-NEXT:   std::map<std::string, char*> _buffer_map {
\ No newline at end of file
diff --git a/mlir/test/mlir-translate/emit-class.mlir b/mlir/test/mlir-translate/emit-class.mlir
index eddc2181a0b19..9fa69f652df6c 100644
--- a/mlir/test/mlir-translate/emit-class.mlir
+++ b/mlir/test/mlir-translate/emit-class.mlir
@@ -1,49 +1,38 @@
-// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path /tmp/model_emitc.mlir | FileCheck %s --check-prefix=ADDER_TEST
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s  
 
-// ADDER_TEST-LABEL: class MyAdder final {
-// ADDER_TEST-NEXT: public:
-// ADDER_TEST-DAG: float v1[1];
-// ADDER_TEST-DAG: float v2[1];
-// ADDER_TEST-DAG: float v3[1];      
-// ADDER_TEST-NEXT: std::map<std::string, char*> _buffer_map {{ "another_feature", reinterpret_cast<char*>(v1) },{ "some_feature", reinterpret_cast<char*>(v2) },{ "output_0", reinterpret_cast<char*>(v3) }};
-// ADDER_TEST-NEXT: char* getBufferForName(const std::string& name) const {
-// ADDER_TEST-NEXT:     auto it = _buffer_map.find(name);
-// ADDER_TEST-NEXT:     return (it == _buffer_map.end()) ? nullptr : it->second;
-// ADDER_TEST-NEXT: }
-// ADDER_TEST-NEXT: void main() {
-// ADDER_TEST-NEXT:     size_t v4 = 0;
-// ADDER_TEST-NEXT:     float v5 = v2[v4];
-// ADDER_TEST-NEXT:     float v6 = v1[v4];
-// ADDER_TEST-NEXT:     float v7 = v5 + v6;
-// ADDER_TEST-NEXT:     v3[v4] = v7;
-// ADDER_TEST-NEXT:     return;
-// ADDER_TEST-NEXT: }
-// ADDER_TEST-NEXT: };
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+  emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+    %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+    %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    %2 = load %1 : <f32>
+    %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    %4 = load %3 : <f32>
+    %5 = add %2, %4 : (f32, f32) -> f32
+    %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    assign %5 : f32 to %6 : <f32>
+    return
+  }
+}
 
-// ---
-// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyMultiOutput --field-name-attribute=tf_saved_model.index_path /tmp/model_multi_out_emitc.mlir | FileCheck %s --check-prefix=MULTI_OUT
+// CHECK: class MyAdder final {
+// CHECK-NEXT: public:
+// CHECK-NEXT:   float v1[1];
+// CHECK-NEXT:   float v2[1];
+// CHECK-NEXT:   float v3[1];
+// CHECK-EMPTY:
+// CHECK-NEXT:   std::map<std::string, char*> _buffer_map { { "another_feature", reinterpret_cast<char*>(v1) }, { "some_feature", reinterpret_cast<char*>(v2) }, { "output_0", reinterpret_cast<char*>(v3) }, };
+// CHECK-NEXT:   char* getBufferForName(const std::string& name) const {
+// CHECK-NEXT:      auto it = _buffer_map.find(name);
+// CHECK-NEXT:      return (it == _buffer_map.end()) ? nullptr : it->second;
+// CHECK-NEXT:   }
+// CHECK-EMPTY:
+// CHECK-NEXT: void main() {
+// CHECK-NEXT:     size_t v4 = 0;
+// CHECK-NEXT:     float v5 = v2[v4];
+// CHECK-NEXT:     float v6 = v1[v4];
+// CHECK-NEXT:     float v7 = v5 + v6;
+// CHECK-NEXT:     v3[v4] = v7;
+// CHECK-NEXT:     return;
+// CHECK-NEXT:  }
+// CHECK-NEXT: };
 
-// MULTI_OUT-LABEL: class MyMultiOutput final {
-// MULTI_OUT-NEXT: public:
-// MULTI_OUT-DAG: float v1[1];
-// MULTI_OUT-DAG: float v2[1];
-// MULTI_OUT-DAG: float v3[1];
-// MULTI_OUT-DAG: float v4[1];
-// MULTI_OUT: std::map<std::string, char*> _buffer_map {{ "b", reinterpret_cast<char*>(v1) },{ "a", reinterpret_cast<char*>(v2) },{ "output_1", reinterpret_cast<char*>(v3) },{ "output_0", reinterpret_cast<char*>(v4) }, };
-// MULTI_OUT-NEXT: char* getBufferForName(const std::string& name) const {
-// MULTI_OUT-NEXT:     auto it = _buffer_map.find(name);
-// MULTI_OUT-NEXT:     return (it == _buffer_map.end()) ? nullptr : it->second;
-// MULTI_OUT-NEXT: }
-// MULTI_OUT-NEXT: void main() {
-// MULTI_OUT-NEXT:     size_t v5 = 0;
-// MULTI_OUT-NEXT:     float v6 = v2[v5];
-// MULTI_OUT-NEXT:     float v7 = v1[v5];
-// MULTI_OUT-NEXT:     float v8 = v6 + v7;
-// MULTI_OUT-NEXT:     v4[v5] = v8;
-// MULTI_OUT-NEXT:     float v9 = v2[v5];
-// MULTI_OUT-NEXT:     float v10 = v1[v5];
-// MULTI_OUT-NEXT:     float v11 = v9 - v10;
-// MULTI_OUT-NEXT:     v3[v5] = v11;
-// MULTI_OUT-NEXT:     return;
-// MULTI_OUT-NEXT: }
-// MULTI_OUT-NEXT: };

>From b1d039298b66119f0b63a93a3f104684c53c5930 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 29 May 2025 16:09:06 +0000
Subject: [PATCH 5/5] A lil' cleaning up

---
 mlir/lib/Target/Cpp/TranslateToCpp.cpp                  | 7 +++----
 mlir/test/mlir-translate/emit-class-neg-external.mlir   | 6 +++---
 mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir | 4 ++--
 mlir/test/mlir-translate/emit-class.mlir                | 5 +++--
 4 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 5b2fa7a624d37..11abe5a38119d 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1224,12 +1224,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
   if (emitter.shouldPrintClass()) {
     if (functionOp.isExternal()) {
       // TODO: Determine the best long-term strategy for external functions.
-      // Currently, we're stopping here to prevent downstream errors.
+      // Currently, we're skipping over this functionOp.
       os << "Warning: Cannot process external function '"
          << functionOp.getName() << "'. "
-         << "It lacks a body, and attempting to continue would lead to errors "
-            "due to missing argument details.\n";
-      return failure();
+         << "This functionOp lacks a body so we will skip over it.";
+      return success();
     }
     os << "class " << emitter.getClassName() << " final {\n";
     os << "public: \n";
diff --git a/mlir/test/mlir-translate/emit-class-neg-external.mlir b/mlir/test/mlir-translate/emit-class-neg-external.mlir
index 347992c467759..c34a1652abd3f 100644
--- a/mlir/test/mlir-translate/emit-class-neg-external.mlir
+++ b/mlir/test/mlir-translate/emit-class-neg-external.mlir
@@ -1,8 +1,8 @@
 /// An external function - has no body
-// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s 2>&1 | FileCheck %s
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
 
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
   emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]}
 }
 
-// CHECK: Warning: Cannot process external function 'extern_func'. It lacks a body, and attempting to continue would lead to errors due to missing argument details.
\ No newline at end of file
+// CHECK: Warning: Cannot process external function 'extern_func'. This functionOp lacks a body so we will skip over it.
diff --git a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
index 77e89da2f0a4f..6d43fa953a946 100644
--- a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
+++ b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
@@ -1,7 +1,7 @@
 /// The function has no argument attributes
 // RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=ArgAttrs --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
 
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
   emitc.func @foo(%arg0 : i32) {
     emitc.call_opaque "bar" (%arg0) : (i32) -> ()
     emitc.return
@@ -12,4 +12,4 @@ module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted."
 // CHECK-NEXT: public: 
 // CHECK-NEXT:   int32_t v1;
 // CHECK-EMPTY: 
-// CHECK-NEXT:   std::map<std::string, char*> _buffer_map {
\ No newline at end of file
+// CHECK-NEXT:   std::map<std::string, char*> _buffer_map {
diff --git a/mlir/test/mlir-translate/emit-class.mlir b/mlir/test/mlir-translate/emit-class.mlir
index 9fa69f652df6c..2779cb315ed41 100644
--- a/mlir/test/mlir-translate/emit-class.mlir
+++ b/mlir/test/mlir-translate/emit-class.mlir
@@ -1,6 +1,6 @@
 // RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s  
 
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
   emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
     %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
     %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
@@ -20,7 +20,8 @@ module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted."
 // CHECK-NEXT:   float v2[1];
 // CHECK-NEXT:   float v3[1];
 // CHECK-EMPTY:
-// CHECK-NEXT:   std::map<std::string, char*> _buffer_map { { "another_feature", reinterpret_cast<char*>(v1) }, { "some_feature", reinterpret_cast<char*>(v2) }, { "output_0", reinterpret_cast<char*>(v3) }, };
+// CHECK-NEXT:   std::map<std::string, char*> _buffer_map { { "another_feature", reinterpret_cast<char*>(v1) }, 
+// CHECK-SAME: { "some_feature", reinterpret_cast<char*>(v2) }, { "output_0", reinterpret_cast<char*>(v3) }, };
 // CHECK-NEXT:   char* getBufferForName(const std::string& name) const {
 // CHECK-NEXT:      auto it = _buffer_map.find(name);
 // CHECK-NEXT:      return (it == _buffer_map.end()) ? nullptr : it->second;



More information about the Mlir-commits mailing list