[Mlir-commits] [mlir] Adding mlir models (PR #141158)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 29 10:55:07 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-emitc

Author: None (Jaddyen)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/141158.diff


6 Files Affected:

- (modified) mlir/include/mlir/Target/Cpp/CppEmitter.h (+3-1) 
- (modified) mlir/lib/Target/Cpp/TranslateRegistration.cpp (+38-1) 
- (modified) mlir/lib/Target/Cpp/TranslateToCpp.cpp (+136-14) 
- (added) mlir/test/mlir-translate/emit-class-neg-external.mlir (+8) 
- (added) mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir (+15) 
- (added) mlir/test/mlir-translate/emit-class.mlir (+39) 


``````````diff
diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h
index 7c5747a888261..d1a6c1dc12d4c 100644
--- a/mlir/include/mlir/Target/Cpp/CppEmitter.h
+++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h
@@ -28,7 +28,9 @@ namespace emitc {
 /// with matching id are emitted.
+/// If `emitClass` is true, each function with a body is emitted as a C++ class
+/// named `className` instead of a free function: the function's arguments
+/// become class fields, registered in a buffer map under the first string of
+/// the argument attribute named `fieldNameAttribute`.
 LogicalResult translateToCpp(Operation *op, raw_ostream &os,
                              bool declareVariablesAtTop = false,
-                             StringRef fileId = {});
+                             StringRef fileId = {}, bool emitClass = false,
+                             StringRef className = {},
+                             StringRef fieldNameAttribute = {});
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 2108ffd414c56..9e1533d34f6ea 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -33,13 +33,50 @@ void registerToCppTranslation() {
       "file-id", llvm::cl::desc("Emit emitc.file ops with matching id"),
       llvm::cl::init(""));
 
+  // Class-emission options. --class-name and --field-name-attribute only take
+  // effect when --emit-class is set.
+  static llvm::cl::opt<bool> emitClass(
+      "emit-class",
+      llvm::cl::desc("If specified, the output will be a class where "
+                     "the function(s) in the module are methods. "
+                     "Enables class-related options."),
+      llvm::cl::init(false));
+
+  static llvm::cl::opt<std::string> className(
+      "class-name",
+      llvm::cl::desc("Mandatory class name if --emit-class is set"),
+      llvm::cl::init(""));
+
+  // Has a non-empty default, so it only needs to be passed to override it.
+  static llvm::cl::opt<std::string> fieldNameAttribute(
+      "field-name-attribute",
+      llvm::cl::desc("Name of the argument attribute to use as field name if "
+                     "--emit-class is set (default=tf_saved_model.index_path)"),
+      llvm::cl::init("tf_saved_model.index_path"));
+
   TranslateFromMLIRRegistration reg(
       "mlir-to-cpp", "translate from mlir to cpp",
       [](Operation *op, raw_ostream &output) {
+        // Validate the class-related options only when --emit-class is set;
+        // the single translateToCpp call below serves both output modes.
+        if (emitClass) {
+          if (className.empty()) {
+            llvm::errs() << "Error: --class-name is mandatory when "
+                            "--emit-class is set.\n";
+            return mlir::failure();
+          }
+          if (fieldNameAttribute.empty()) {
+            llvm::errs() << "Error: --field-name-attribute is mandatory when "
+                            "--emit-class is set.\n";
+            return mlir::failure();
+          }
+        }
         return emitc::translateToCpp(
             op, output,
             /*declareVariablesAtTop=*/declareVariablesAtTop,
-            /*fileId=*/fileId);
+            /*fileId=*/fileId, /*emitClass=*/emitClass, /*className=*/className,
+            /*fieldNameAttribute=*/fieldNameAttribute);
       },
       [](DialectRegistry &registry) {
         // clang-format off
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 0c4975a13d301..46891d0aca556 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -68,6 +68,14 @@ inline LogicalResult interleaveCommaWithError(const Container &c,
   return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
 }
 
+/// Same as interleaveCommaWithError, except that consecutive elements are
+/// separated by ";\n" so each element appears as its own statement line.
+template <typename Container, typename UnaryFunctor>
+inline LogicalResult interleaveWithNewLineWithError(const Container &c,
+                                                    raw_ostream &os,
+                                                    UnaryFunctor eachFn) {
+  auto emitSeparator = [&]() { os << ";\n"; };
+  return interleaveWithError(c.begin(), c.end(), eachFn, emitSeparator);
+}
+
 /// Return the precedence of a operator as an integer, higher values
 /// imply higher precedence.
 static FailureOr<int> getOperatorPrecedence(Operation *operation) {
@@ -116,7 +124,8 @@ namespace {
 /// Emitter that uses dialect specific emitters to emit C++ code.
 struct CppEmitter {
   explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
-                      StringRef fileId);
+                      StringRef fileId, bool emitClass, StringRef className,
+                      StringRef fieldNameAttribute);
 
   /// Emits attribute or returns failure.
   LogicalResult emitAttribute(Location loc, Attribute attr);
@@ -233,6 +242,15 @@ struct CppEmitter {
   /// be declared at the beginning of a function.
   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
 
+  /// Returns whether each function should be emitted as a C++ class.
+  bool shouldPrintClass() const { return emitClass; }
+
+  /// Returns the name of the class to emit when shouldPrintClass() is true.
+  StringRef getClassName() const { return className; }
+
+  /// Returns the name of the argument attribute whose value supplies each
+  /// field's key in the generated buffer map (not the field name itself).
+  StringRef getfieldNameAttribute() const { return fieldNameAttribute; }
+
   /// Returns whether this file op should be emitted
   bool shouldEmitFile(FileOp file) {
     return !fileId.empty() && file.getId() == fileId;
@@ -268,6 +286,18 @@ struct CppEmitter {
   /// Only emit file ops whos id matches this value.
   std::string fileId;
 
+  /// Controls whether the output should be a C++ class.
+  /// If true, each function with a body is emitted inside its own class: the
+  /// function's arguments become fields and the function becomes a member.
+  /// NOTE(review): every function opens a class with the same `className`, so
+  /// a module with more than one function yields conflicting definitions.
+  const bool emitClass;
+
+  /// The specified name for the generated C++ class.
+  const std::string className;
+
+  /// Name of the argument attribute (an array whose first element is a
+  /// string) that supplies the key under which each argument's field is
+  /// registered in the generated `_buffer_map`.
+  const std::string fieldNameAttribute;
+
   /// Map from value to name of C++ variable that contain the name.
   ValueMapper valueMapper;
 
@@ -1025,6 +1055,17 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
       }));
 }
 
+/// Emits one variable declaration per entry of `arguments`, named after the
+/// argument and typed with the argument's type, passing `functionOp`'s
+/// location to each declaration. Declarations are separated by ";\n".
+static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
+                                 Region::BlockArgListType arguments) {
+  Location loc = functionOp->getLoc();
+  auto emitField = [&](BlockArgument arg) -> LogicalResult {
+    return emitter.emitVariableDeclaration(loc, arg.getType(),
+                                           emitter.getOrCreateName(arg));
+  };
+  return interleaveWithNewLineWithError(arguments, emitter.ostream(),
+                                        emitField);
+}
+
 static LogicalResult printFunctionBody(CppEmitter &emitter,
                                        Operation *functionOp,
                                        Region::BlockListType &blocks) {
@@ -1129,6 +1170,45 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return success();
 }
 
+/// Emits the members of the class generated for `functionOp`:
+///   - one field per function argument,
+///   - a `_buffer_map` from each argument's key (the first string of its
+///     field-name attribute) to that field's storage, and
+///   - a `getBufferForName` lookup helper.
+///
+/// Fails with a diagnostic on `functionOp` in any of these cases:
+///   - the function has no argument attributes;
+///   - an argument's attribute dictionary lacks the field-name attribute;
+///   - that attribute is not a non-empty array whose first element is a
+///     string;
+///   - two arguments map to the same key.
+/// Output emitted before the failure is left in the stream.
+static LogicalResult emitClassFields(CppEmitter &emitter,
+                                     emitc::FuncOp functionOp) {
+  raw_indented_ostream &os = emitter.ostream();
+  auto argAttrs = functionOp.getArgAttrs();
+  Operation *operation = functionOp.getOperation();
+  if (failed(printFields(emitter, operation, functionOp.getArguments())))
+    return failure();
+  os << ";\n";
+
+  // Keyed by field name; used to reject two arguments with the same key.
+  std::map<std::string, Value> fields;
+  os << "\nstd::map<std::string, char*> _buffer_map {";
+  if (!argAttrs)
+    return functionOp.emitOpError()
+           << "requires argument attributes when emitting a class";
+
+  auto fieldAttrName = emitter.getfieldNameAttribute();
+  for (const auto [a, v] : zip(*argAttrs, functionOp.getArguments())) {
+    auto da = dyn_cast<mlir::DictionaryAttr>(a);
+    if (!da)
+      continue;
+
+    // The attribute may be absent from this argument's dictionary; check the
+    // optional before dereferencing it.
+    auto fieldAttr = da.getNamed(fieldAttrName);
+    if (!fieldAttr)
+      return functionOp.emitOpError()
+             << "argument #" << v.getArgNumber() << " is missing the '"
+             << fieldAttrName << "' attribute";
+
+    auto nameList = dyn_cast<mlir::ArrayAttr>(fieldAttr->getValue());
+    auto nameAttr = (nameList && !nameList.empty())
+                        ? dyn_cast<mlir::StringAttr>(nameList[0])
+                        : mlir::StringAttr();
+    if (!nameAttr)
+      return functionOp.emitOpError()
+             << "argument #" << v.getArgNumber() << " attribute '"
+             << fieldAttrName
+             << "' must be a non-empty array whose first element is a string";
+
+    std::string name = nameAttr.str();
+    if (!fields.insert({name, v}).second)
+      return functionOp.emitOpError()
+             << "duplicate field name '" << name << "'";
+
+    os << " { \"" << name << "\"" << ", reinterpret_cast<char*>("
+       << emitter.getOrCreateName(v) << ") }, ";
+  }
+
+  os << "};\n";
+  os << "char* getBufferForName(const std::string& name) const {\n";
+  os.indent();
+  os.indent();
+  os << "auto it = _buffer_map.find(name);\n";
+  os << "return (it == _buffer_map.end()) ? nullptr : it->second;\n";
+  os.unindent();
+  os.unindent();
+  os << "}\n\n";
+
+  return success();
+}
+
 static LogicalResult printOperation(CppEmitter &emitter,
                                     emitc::FuncOp functionOp) {
   // We need to declare variables at top if the function has multiple blocks.
@@ -1140,6 +1220,29 @@ static LogicalResult printOperation(CppEmitter &emitter,
 
   CppEmitter::Scope scope(emitter);
   raw_indented_ostream &os = emitter.ostream();
+  Operation *operation = functionOp.getOperation();
+  if (emitter.shouldPrintClass()) {
+    if (functionOp.isExternal()) {
+      // TODO: Determine the best long-term strategy for external functions;
+      // for now they are skipped.
+      // NOTE(review): the warning below is written to `os`, i.e. into the
+      // generated C++ source, where it is not valid C++ (and no newline
+      // follows it). A location is available here (functionOp.getLoc() is
+      // used further down), so functionOp.emitWarning() could report this
+      // through the diagnostic engine instead.
+      os << "Warning: Cannot process external function '"
+         << functionOp.getName() << "'. "
+         << "This functionOp lacks a body so we will skip over it.";
+      return success();
+    }
+    // Open the class. The fields and buffer map come from emitClassFields();
+    // the closing "};" is printed after the function body.
+    os << "class " << emitter.getClassName() << " final {\n";
+    os << "public: \n";
+    os.indent();
+
+    if (failed(emitClassFields(emitter, functionOp)))
+      return failure();
+  }
+
   if (functionOp.getSpecifiers()) {
     for (Attribute specifier : functionOp.getSpecifiersAttr()) {
       os << cast<StringAttr>(specifier).str() << " ";
@@ -1149,23 +1252,37 @@ static LogicalResult printOperation(CppEmitter &emitter,
   if (failed(emitter.emitTypes(functionOp.getLoc(),
                                functionOp.getFunctionType().getResults())))
     return failure();
+  // TODO: We may wanna consider having the name of the function be execute in
+  // the case that we want to emit a class instead of main. Leaving as is for
+  // now to make the change smaller.
   os << " " << functionOp.getName();
 
   os << "(";
-  Operation *operation = functionOp.getOperation();
-  if (functionOp.isExternal()) {
-    if (failed(printFunctionArgs(emitter, operation,
-                                 functionOp.getArgumentTypes())))
+
+  if (!emitter.shouldPrintClass()) {
+    if (functionOp.isExternal()) {
+      if (failed(printFunctionArgs(emitter, operation,
+                                   functionOp.getArgumentTypes())))
+        return failure();
+      os << ");";
+      return success();
+    }
+    if (failed(
+            printFunctionArgs(emitter, operation, functionOp.getArguments())))
       return failure();
-    os << ");";
-    return success();
   }
-  if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
-    return failure();
   os << ") {\n";
+
   if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
     return failure();
-  os << "}\n";
+
+  if (emitter.shouldPrintClass()) {
+    os << "}\n";
+    os.unindent();
+    os << "};\n";
+  } else {
+    os << "}\n";
+  }
 
   return success();
 }
@@ -1202,9 +1319,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
 }
 
 CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
-                       StringRef fileId)
+                       StringRef fileId, bool emitClass, StringRef className,
+                       StringRef fieldNameAttribute)
     : os(os), declareVariablesAtTop(declareVariablesAtTop),
-      fileId(fileId.str()) {
+      fileId(fileId.str()), emitClass(emitClass), className(className.str()),
+      fieldNameAttribute(fieldNameAttribute.str()) {
+  // The StringRef arguments are copied into owned strings, so the emitter does
+  // not depend on the caller's storage outliving it.
   valueInScopeCount.push(0);
   labelInScopeCount.push(0);
 }
@@ -1787,7 +1906,10 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
 
 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
                                     bool declareVariablesAtTop,
-                                    StringRef fileId) {
-  CppEmitter emitter(os, declareVariablesAtTop, fileId);
+                                    StringRef fileId, bool emitClass,
+                                    StringRef className,
+                                    StringRef fieldNameAttribute) {
+  CppEmitter emitter(os, declareVariablesAtTop, fileId, emitClass, className,
+                     fieldNameAttribute);
+  // The class-related options are forwarded as-is; the emitter only consults
+  // className and fieldNameAttribute when emitClass is set.
   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
 }
diff --git a/mlir/test/mlir-translate/emit-class-neg-external.mlir b/mlir/test/mlir-translate/emit-class-neg-external.mlir
new file mode 100644
index 0000000000000..c34a1652abd3f
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class-neg-external.mlir
@@ -0,0 +1,8 @@
+/// An external function has no body, so --emit-class skips it with a warning.
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+  emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]}
+}
+
+// NOTE(review): only stdout is piped to FileCheck, so this checks that the
+// warning lands in the translated C++ output rather than on the diagnostic
+// stream.
+// CHECK: Warning: Cannot process external function 'extern_func'. This functionOp lacks a body so we will skip over it.
diff --git a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
new file mode 100644
index 0000000000000..6d43fa953a946
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
@@ -0,0 +1,15 @@
+/// The function has no argument attributes
+/// Translation fails (hence `not`), but the class text emitted before the
+/// failure, up to the opening of `_buffer_map`, is still checked.
+// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=ArgAttrs --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+  emitc.func @foo(%arg0 : i32) {
+    emitc.call_opaque "bar" (%arg0) : (i32) -> ()
+    emitc.return
+  }
+}
+
+// CHECK: class ArgAttrs final {
+// CHECK-NEXT: public: 
+// CHECK-NEXT:   int32_t v1;
+// CHECK-EMPTY: 
+// CHECK-NEXT:   std::map<std::string, char*> _buffer_map {
diff --git a/mlir/test/mlir-translate/emit-class.mlir b/mlir/test/mlir-translate/emit-class.mlir
new file mode 100644
index 0000000000000..2779cb315ed41
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class.mlir
@@ -0,0 +1,39 @@
+// Each argument becomes a class field, registered in _buffer_map under the
+// first element of its tf_saved_model.index_path attribute; the function body
+// becomes a member function of the class.
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s  
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+  emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+    %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+    %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    %2 = load %1 : <f32>
+    %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    %4 = load %3 : <f32>
+    %5 = add %2, %4 : (f32, f32) -> f32
+    %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+    assign %5 : f32 to %6 : <f32>
+    return
+  }
+}
+
+// CHECK: class MyAdder final {
+// CHECK-NEXT: public:
+// CHECK-NEXT:   float v1[1];
+// CHECK-NEXT:   float v2[1];
+// CHECK-NEXT:   float v3[1];
+// CHECK-EMPTY:
+// CHECK-NEXT:   std::map<std::string, char*> _buffer_map { { "another_feature", reinterpret_cast<char*>(v1) }, 
+// CHECK-SAME: { "some_feature", reinterpret_cast<char*>(v2) }, { "output_0", reinterpret_cast<char*>(v3) }, };
+// CHECK-NEXT:   char* getBufferForName(const std::string& name) const {
+// CHECK-NEXT:      auto it = _buffer_map.find(name);
+// CHECK-NEXT:      return (it == _buffer_map.end()) ? nullptr : it->second;
+// CHECK-NEXT:   }
+// CHECK-EMPTY:
+// CHECK-NEXT: void main() {
+// CHECK-NEXT:     size_t v4 = 0;
+// CHECK-NEXT:     float v5 = v2[v4];
+// CHECK-NEXT:     float v6 = v1[v4];
+// CHECK-NEXT:     float v7 = v5 + v6;
+// CHECK-NEXT:     v3[v4] = v7;
+// CHECK-NEXT:     return;
+// CHECK-NEXT:  }
+// CHECK-NEXT: };
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/141158


More information about the Mlir-commits mailing list