[Mlir-commits] [mlir] [emitC]Pass in `mlir-opt` to wrap a func in class (PR #141158)
Jaden Angella
llvmlistbot at llvm.org
Wed Jun 18 15:00:36 PDT 2025
Valentin Clement (バレンタイン クレメン),Jaddyen <ajaden at google.com>,Jaddyen
<ajaden at google.com>,Jaddyen <ajaden at google.com>,Jaddyen <ajaden at google.com>,Jaddyen
<ajaden at google.com>,Jaddyen <ajaden at google.com>,Jaddyen <ajaden at google.com>,
Valentin Clement (バレンタイン クレメン),Jaddyen <ajaden at google.com>,Jaddyen
<ajaden at google.com>,Jaddyen <ajaden at google.com>
Message-ID:
In-Reply-To: <llvm.org/llvm/llvm-project/pull/141158 at github.com>
https://github.com/Jaddyen updated https://github.com/llvm/llvm-project/pull/141158
>From c3867879aaf0273e45793b8b1f0a4d8ca5d81221 Mon Sep 17 00:00:00 2001
From: Jaden Jaden <ajaden at google.com>
Date: Wed, 21 May 2025 23:10:51 +0000
Subject: [PATCH 01/19] Able to get class when you set a class-Name
---
mlir/include/mlir/Target/Cpp/CppEmitter.h | 2 +-
mlir/lib/Target/Cpp/TranslateRegistration.cpp | 7 +-
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 168 +++++++++++++++---
3 files changed, 147 insertions(+), 30 deletions(-)
diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h
index 7c5747a888261..f87fc67662f55 100644
--- a/mlir/include/mlir/Target/Cpp/CppEmitter.h
+++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h
@@ -28,7 +28,7 @@ namespace emitc {
/// with matching id are emitted.
LogicalResult translateToCpp(Operation *op, raw_ostream &os,
bool declareVariablesAtTop = false,
- StringRef fileId = {});
+ StringRef fileId = {}, StringRef className = {});
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 2108ffd414c56..d556927d7f904 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -33,13 +33,18 @@ void registerToCppTranslation() {
"file-id", llvm::cl::desc("Emit emitc.file ops with matching id"),
llvm::cl::init(""));
+ static llvm::cl::opt<std::string> className(
+ "class-name", llvm::cl::desc("Specify the class name for the generated C++ code"),
+ llvm::cl::init(""));
+
TranslateFromMLIRRegistration reg(
"mlir-to-cpp", "translate from mlir to cpp",
[](Operation *op, raw_ostream &output) {
return emitc::translateToCpp(
op, output,
/*declareVariablesAtTop=*/declareVariablesAtTop,
- /*fileId=*/fileId);
+ /*fileId=*/fileId,
+ /*className=*/className);
},
[](DialectRegistry &registry) {
// clang-format off
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 5abc112ab8c7a..1e94eebae50bf 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -68,6 +69,13 @@ inline LogicalResult interleaveCommaWithError(const Container &c,
return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
}
+template <typename Container, typename UnaryFunctor>
+inline LogicalResult interleaveWithNewLineWithError(const Container &c,
+ raw_ostream &os,
+ UnaryFunctor eachFn) {
+ return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << "\n"; });
+}
+
/// Return the precedence of a operator as an integer, higher values
/// imply higher precedence.
static FailureOr<int> getOperatorPrecedence(Operation *operation) {
@@ -116,7 +124,7 @@ namespace {
/// Emitter that uses dialect specific emitters to emit C++ code.
struct CppEmitter {
explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
- StringRef fileId);
+ StringRef fileId, StringRef className);
/// Emits attribute or returns failure.
LogicalResult emitAttribute(Location loc, Attribute attr);
@@ -233,6 +241,9 @@ struct CppEmitter {
/// be declared at the beginning of a function.
bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
+ // Returns if we should spit out a C++ class
+ std::string shouldSpitClass() { return className; };
+
/// Returns whether this file op should be emitted
bool shouldEmitFile(FileOp file) {
return !fileId.empty() && file.getId() == fileId;
@@ -268,6 +279,9 @@ struct CppEmitter {
/// Only emit file ops whos id matches this value.
std::string fileId;
+ /// Name of the C++ class we will spit
+ std::string className;
+
/// Map from value to name of C++ variable that contain the name.
ValueMapper valueMapper;
@@ -1033,6 +1047,32 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
}));
}
+static LogicalResult printAttributes(CppEmitter &emitter,
+ Operation *functionOp,
+ ArrayRef<Type> arguments) {
+ raw_indented_ostream &os = emitter.ostream();
+ for (auto arg : arguments) {
+ os << " ";
+ if (failed(emitter.emitType(functionOp->getLoc(), arg)))
+ return failure();
+ os << " " << arg << ";\n";
+ }
+ return success();
+}
+
+static LogicalResult printAttributes(CppEmitter &emitter,
+ Operation *functionOp,
+ Region::BlockArgListType arguments) {
+ raw_indented_ostream &os = emitter.ostream();
+ for (auto arg : arguments) {
+ os << " ";
+ if (failed(emitter.emitType(functionOp->getLoc(), arg.getType())))
+ return failure();
+ os << " " << emitter.getOrCreateName(arg) << ";\n";
+ }
+ return success();
+}
+
static LogicalResult printFunctionBody(CppEmitter &emitter,
Operation *functionOp,
Region::BlockListType &blocks) {
@@ -1146,35 +1186,107 @@ static LogicalResult printOperation(CppEmitter &emitter,
"with multiple blocks needs variables declared at top");
}
- CppEmitter::Scope scope(emitter);
+ CppEmitter::Scope classScope(emitter);
raw_indented_ostream &os = emitter.ostream();
- if (functionOp.getSpecifiers()) {
- for (Attribute specifier : functionOp.getSpecifiersAttr()) {
- os << cast<StringAttr>(specifier).str() << " ";
+ Operation *operation = functionOp.getOperation();
+ if (!emitter.shouldSpitClass().empty()) {
+ os << "class " << emitter.shouldSpitClass() << " final {\n";
+ os << "public: \n";
+
+ if (functionOp.isExternal()) {
+ if (failed(printAttributes(emitter, operation,
+ functionOp.getArgumentTypes())))
+ return failure();
+ return success();
}
- }
+ if (failed(printAttributes(emitter, operation, functionOp.getArguments())))
+ return failure();
+
+ os << "\n";
+
+ auto argAttrs = functionOp.getArgAttrs();
+ std::map<std::string, Value> fields;
+ if (argAttrs)
+ for (auto [a,v] : zip(*argAttrs, functionOp.getArguments())) {
+ if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
+ auto nv = da.getNamed("tf_saved_model.index_path")->getValue();
+ auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str();
+ fields[name] = v;
+ }
+ }
- if (failed(emitter.emitTypes(functionOp.getLoc(),
- functionOp.getFunctionType().getResults())))
- return failure();
- os << " " << functionOp.getName();
+ for (auto & r : functionOp->getRegions())
+ for (auto &b : r.getBlocks())
+ for (auto &opt : b.getOperations())
+ if (auto alloc = dyn_cast<memref::AllocOp>(opt)) {
+ auto name = emitter.getOrCreateName(alloc).str();
+ fields[name] = alloc;
+ if (failed(emitter.emitType(alloc.getLoc(), alloc.getType().getElementType())))
+ return failure();
+ os << " [" << alloc.getType().getNumElements() <<"] ";
+ os << " " << name << ";\n";
+ }
+ os << " std::map<std::string, char*> _buffer_map {";
+ for (auto &[n,v]:fields)
+ os << "{ \"" << n << "\"" << ", reinterpret_cast<char*>(" << emitter.getOrCreateName(v) << ") },";
+ os << " };\n";
+ os << " char* getBufferForName(const std::string& name) const {\n";
+ os << " auto it = _buffer_map.find(name);\n";
+ os << " return (it == _buffer_map.end()) ? nullptr : it->second;\n";
+ os << " }\n\n";
- os << "(";
- Operation *operation = functionOp.getOperation();
- if (functionOp.isExternal()) {
- if (failed(printFunctionArgs(emitter, operation,
- functionOp.getArgumentTypes())))
+ os.indent();
+
+ // Begin defining the main function where we have the actual execution
+ if (functionOp.getSpecifiers()) {
+ for (Attribute specifier : functionOp.getSpecifiersAttr()) {
+ os << cast<StringAttr>(specifier).str() << " ";
+ }
+ }
+
+ if (failed(emitter.emitTypes(functionOp.getLoc(),
+ functionOp.getFunctionType().getResults())))
return failure();
- os << ");";
- return success();
- }
- if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
- return failure();
- os << ") {\n";
- if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
- return failure();
- os << "}\n";
+ os << " " << functionOp.getName();
+ //Defining the functionBody
+ os << "() { \n"; // Begin defining the function header (We need the function name and output type to remain without the rest of the function header)
+ os.indent();
+ if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+ return failure();
+ os.unindent();
+ os << "}\n";
+ os.unindent();
+ os << "};\n";
+ } else {
+ if (functionOp.getSpecifiers()) {
+ for (Attribute specifier : functionOp.getSpecifiersAttr()) {
+ os << cast<StringAttr>(specifier).str() << " ";
+ }
+ }
+
+ if (failed(emitter.emitTypes(functionOp.getLoc(),
+ functionOp.getFunctionType().getResults())))
+ return failure();
+ os << " " << functionOp.getName();
+
+ os << "(";
+ Operation *operation = functionOp.getOperation();
+ if (functionOp.isExternal()) {
+ if (failed(printFunctionArgs(emitter, operation,
+ functionOp.getArgumentTypes())))
+ return failure();
+ os << ");";
+ return success();
+ }
+ if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
+ return failure();
+ os << ") {\n";
+ if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+ return failure();
+ os << "}\n";
+ }
+
return success();
}
@@ -1210,9 +1322,9 @@ static LogicalResult printOperation(CppEmitter &emitter,
}
CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
- StringRef fileId)
+ StringRef fileId, StringRef className)
: os(os), declareVariablesAtTop(declareVariablesAtTop),
- fileId(fileId.str()) {
+ fileId(fileId.str()), className(className.str()) {
valueInScopeCount.push(0);
labelInScopeCount.push(0);
}
@@ -1795,7 +1907,7 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
bool declareVariablesAtTop,
- StringRef fileId) {
- CppEmitter emitter(os, declareVariablesAtTop, fileId);
+ StringRef fileId, StringRef className) {
+ CppEmitter emitter(os, declareVariablesAtTop, fileId, className);
return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
}
>From 7a97f17c3c1270f5371b262c5cd75b9e79e4b633 Mon Sep 17 00:00:00 2001
From: Jaden Jaden <ajaden at google.com>
Date: Thu, 22 May 2025 23:17:00 +0000
Subject: [PATCH 02/19] fixes
---
mlir/lib/Target/Cpp/TranslateRegistration.cpp | 4 +-
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 233 ++++++++++--------
2 files changed, 128 insertions(+), 109 deletions(-)
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index d556927d7f904..17fcc058eef04 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -34,7 +34,9 @@ void registerToCppTranslation() {
llvm::cl::init(""));
static llvm::cl::opt<std::string> className(
- "class-name", llvm::cl::desc("Specify the class name for the generated C++ code"),
+ "class-name",
+ llvm::cl::desc("Optional class name. If specified, the output will be a "
+ "class where the function(s) in the module are members."),
llvm::cl::init(""));
TranslateFromMLIRRegistration reg(
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 1e94eebae50bf..2e8dfca7479b2 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -9,7 +9,6 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -73,7 +72,8 @@ template <typename Container, typename UnaryFunctor>
inline LogicalResult interleaveWithNewLineWithError(const Container &c,
raw_ostream &os,
UnaryFunctor eachFn) {
- return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << "\n"; });
+ return interleaveWithError(c.begin(), c.end(), eachFn,
+ [&]() { os << ";\n"; });
}
/// Return the precedence of a operator as an integer, higher values
@@ -241,8 +241,11 @@ struct CppEmitter {
/// be declared at the beginning of a function.
bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
- // Returns if we should spit out a C++ class
- std::string shouldSpitClass() { return className; };
+ // Returns whether we should emit a C++ class
+ bool shouldPrintClass() { return !className.empty(); };
+
+ // Returns the class name to emit
+ std::string getClassName() { return className; };
/// Returns whether this file op should be emitted
bool shouldEmitFile(FileOp file) {
@@ -1047,30 +1050,25 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
}));
}
-static LogicalResult printAttributes(CppEmitter &emitter,
- Operation *functionOp,
- ArrayRef<Type> arguments) {
+static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
+ ArrayRef<Type> arguments) {
raw_indented_ostream &os = emitter.ostream();
- for (auto arg : arguments) {
- os << " ";
- if (failed(emitter.emitType(functionOp->getLoc(), arg)))
- return failure();
- os << " " << arg << ";\n";
- }
- return success();
+
+ return (interleaveWithNewLineWithError(
+ arguments, os, [&](Type arg) -> LogicalResult {
+ return emitter.emitType(functionOp->getLoc(), arg);
+ }));
}
-static LogicalResult printAttributes(CppEmitter &emitter,
- Operation *functionOp,
- Region::BlockArgListType arguments) {
+static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
+ Region::BlockArgListType arguments) {
raw_indented_ostream &os = emitter.ostream();
- for (auto arg : arguments) {
- os << " ";
- if (failed(emitter.emitType(functionOp->getLoc(), arg.getType())))
- return failure();
- os << " " << emitter.getOrCreateName(arg) << ";\n";
- }
- return success();
+
+ return (interleaveWithNewLineWithError(
+ arguments, os, [&](BlockArgument arg) -> LogicalResult {
+ return emitter.emitVariableDeclaration(
+ functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
+ }));
}
static LogicalResult printFunctionBody(CppEmitter &emitter,
@@ -1177,116 +1175,135 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}
-static LogicalResult printOperation(CppEmitter &emitter,
- emitc::FuncOp functionOp) {
- // We need to declare variables at top if the function has multiple blocks.
- if (!emitter.shouldDeclareVariablesAtTop() &&
- functionOp.getBlocks().size() > 1) {
- return functionOp.emitOpError(
- "with multiple blocks needs variables declared at top");
- }
-
- CppEmitter::Scope classScope(emitter);
+static LogicalResult printFunctionHeader(CppEmitter &emitter,
+ emitc::FuncOp functionOp) {
raw_indented_ostream &os = emitter.ostream();
Operation *operation = functionOp.getOperation();
- if (!emitter.shouldSpitClass().empty()) {
- os << "class " << emitter.shouldSpitClass() << " final {\n";
- os << "public: \n";
+ if (functionOp.getSpecifiers()) {
+ for (Attribute specifier : functionOp.getSpecifiersAttr()) {
+ os << cast<StringAttr>(specifier).str() << " ";
+ }
+ }
- if (functionOp.isExternal()) {
- if (failed(printAttributes(emitter, operation,
- functionOp.getArgumentTypes())))
+ if (failed(emitter.emitTypes(functionOp.getLoc(),
+ functionOp.getFunctionType().getResults())))
+ return failure();
+ os << " " << functionOp.getName();
+ if (!emitter.shouldPrintClass()) {
+ os << "(";
+ if (functionOp.isExternal()) {
+ if (failed(printFunctionArgs(emitter, operation,
+ functionOp.getArgumentTypes())))
return failure();
+ os << ");";
return success();
}
- if (failed(printAttributes(emitter, operation, functionOp.getArguments())))
+ if (failed(
+ printFunctionArgs(emitter, operation, functionOp.getArguments())))
return failure();
-
- os << "\n";
-
- auto argAttrs = functionOp.getArgAttrs();
- std::map<std::string, Value> fields;
- if (argAttrs)
- for (auto [a,v] : zip(*argAttrs, functionOp.getArguments())) {
- if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
- auto nv = da.getNamed("tf_saved_model.index_path")->getValue();
- auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str();
- fields[name] = v;
- }
- }
+ os << ") {\n";
- for (auto & r : functionOp->getRegions())
- for (auto &b : r.getBlocks())
- for (auto &opt : b.getOperations())
- if (auto alloc = dyn_cast<memref::AllocOp>(opt)) {
- auto name = emitter.getOrCreateName(alloc).str();
- fields[name] = alloc;
- if (failed(emitter.emitType(alloc.getLoc(), alloc.getType().getElementType())))
- return failure();
- os << " [" << alloc.getType().getNumElements() <<"] ";
- os << " " << name << ";\n";
- }
- os << " std::map<std::string, char*> _buffer_map {";
- for (auto &[n,v]:fields)
- os << "{ \"" << n << "\"" << ", reinterpret_cast<char*>(" << emitter.getOrCreateName(v) << ") },";
- os << " };\n";
- os << " char* getBufferForName(const std::string& name) const {\n";
- os << " auto it = _buffer_map.find(name);\n";
- os << " return (it == _buffer_map.end()) ? nullptr : it->second;\n";
- os << " }\n\n";
+ } else {
+ os << "() { \n";
+ }
- os.indent();
+ return success();
+}
- // Begin defining the main function where we have the actual execution
- if (functionOp.getSpecifiers()) {
- for (Attribute specifier : functionOp.getSpecifiersAttr()) {
- os << cast<StringAttr>(specifier).str() << " ";
+static LogicalResult emitClassBody(CppEmitter &emitter,
+ emitc::FuncOp functionOp) {
+ raw_indented_ostream &os = emitter.ostream();
+ Operation *operation = functionOp.getOperation();
+ auto argAttrs = functionOp.getArgAttrs();
+ std::map<std::string, Value> fields;
+ os << "\nstd::map<std::string, char*> _buffer_map {";
+ if (argAttrs) // We can have no argattrs in the case that the function has no
+ // inputs nor outputs -> procedure
+ for (const auto [a, v] : zip(*argAttrs, functionOp.getArguments())) {
+ if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
+ auto nv =
+ da.getNamed("tf_saved_model.index_path")
+ ->getValue(); // From what I've seen so far, this is the only
+ // way to have the argAttrs keys. If there is
+ // another way, I need to run the tests to see and
+ // see what cases trigger this change in format.
+ auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str();
+ fields[name] = v; // The only way to not have unique names is in the
+ // case that you have duplicate arguments in your
+ // tensorflow/python function. By python syntax rules,
+ // you're not allowed to have that(Current assumption)
+ os << "{ \"" << name << "\"" << ", reinterpret_cast<char*>("
+ << emitter.getOrCreateName(v) << ") },";
}
}
+ else
+ return failure();
+
+ os << " };\n";
+ os << "char* getBufferForName(const std::string& name) const {\n";
+ os.indent();
+ os.indent();
+ os << "auto it = _buffer_map.find(name);\n";
+ os << "return (it == _buffer_map.end()) ? nullptr : it->second;\n";
+ os.unindent();
+ os.unindent();
+ os << "}\n\n";
+
+ if (failed(printFunctionHeader(emitter, functionOp)))
+ return failure();
+
+ if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+ return failure();
+ os << "}\n";
+ os.unindent();
+ os << "};\n";
+
+ return success();
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+ emitc::FuncOp functionOp) {
+ // We need to declare variables at top if the function has multiple blocks.
+ if (!emitter.shouldDeclareVariablesAtTop() &&
+ functionOp.getBlocks().size() > 1) {
+ return functionOp.emitOpError(
+ "with multiple blocks needs variables declared at top");
+ }
- if (failed(emitter.emitTypes(functionOp.getLoc(),
- functionOp.getFunctionType().getResults())))
+ CppEmitter::Scope classScope(emitter);
+ raw_indented_ostream &os = emitter.ostream();
+ Operation *operation = functionOp.getOperation();
+ if (!emitter.shouldPrintClass()) {
+
+ if (failed(printFunctionHeader(emitter, functionOp)))
return failure();
- os << " " << functionOp.getName();
- //Defining the functionBody
- os << "() { \n"; // Begin defining the function header (We need the function name and output type to remain without the rest of the function header)
- os.indent();
- if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+ if (failed(printFunctionBody(
+ emitter, operation,
+ functionOp.getBlocks()))) // This is the only similarity between the
+ // function and the class
return failure();
- os.unindent();
os << "}\n";
- os.unindent();
- os << "};\n";
- } else {
- if (functionOp.getSpecifiers()) {
- for (Attribute specifier : functionOp.getSpecifiersAttr()) {
- os << cast<StringAttr>(specifier).str() << " ";
- }
- }
- if (failed(emitter.emitTypes(functionOp.getLoc(),
- functionOp.getFunctionType().getResults())))
- return failure();
- os << " " << functionOp.getName();
+ } else {
+ os << "class " << emitter.getClassName() << " final {\n";
+ os << "public: \n";
+ os.indent();
- os << "(";
- Operation *operation = functionOp.getOperation();
if (functionOp.isExternal()) {
- if (failed(printFunctionArgs(emitter, operation,
- functionOp.getArgumentTypes())))
+ if (failed(
+ printFields(emitter, operation, functionOp.getArgumentTypes())))
return failure();
- os << ");";
return success();
}
- if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
+ if (failed(printFields(emitter, operation, functionOp.getArguments())))
return failure();
- os << ") {\n";
- if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+ os << ";\n";
+
+ if (failed(emitClassBody(emitter, functionOp)))
return failure();
- os << "}\n";
}
-
+
return success();
}
>From 23e84353689078ec1ae4ae509d952201113eb58b Mon Sep 17 00:00:00 2001
From: Jaden Jaden <ajaden at google.com>
Date: Tue, 27 May 2025 18:39:14 +0000
Subject: [PATCH 03/19] Added flags to increase readability
---
mlir/include/mlir/Target/Cpp/CppEmitter.h | 4 +-
mlir/lib/Target/Cpp/TranslateRegistration.cpp | 38 +++-
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 181 ++++++++----------
mlir/test/mlir-translate/emit-class.mlir | 49 +++++
4 files changed, 170 insertions(+), 102 deletions(-)
create mode 100644 mlir/test/mlir-translate/emit-class.mlir
diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h
index f87fc67662f55..d1a6c1dc12d4c 100644
--- a/mlir/include/mlir/Target/Cpp/CppEmitter.h
+++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h
@@ -28,7 +28,9 @@ namespace emitc {
/// with matching id are emitted.
LogicalResult translateToCpp(Operation *op, raw_ostream &os,
bool declareVariablesAtTop = false,
- StringRef fileId = {}, StringRef className = {});
+ StringRef fileId = {}, bool emitClass = false,
+ StringRef className = {},
+ StringRef fieldNameAttribute = {});
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 17fcc058eef04..69e0ab01bb71d 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -33,20 +33,50 @@ void registerToCppTranslation() {
"file-id", llvm::cl::desc("Emit emitc.file ops with matching id"),
llvm::cl::init(""));
+ static llvm::cl::opt<bool> emitClass(
+ "emit-class",
+ llvm::cl::desc("If specified, the output will be a class where "
+ "the function(s) in the module are members. "
+ "Enables class-related options."),
+ llvm::cl::init(false));
+
static llvm::cl::opt<std::string> className(
"class-name",
- llvm::cl::desc("Optional class name. If specified, the output will be a "
- "class where the function(s) in the module are members."),
+ llvm::cl::desc("Mandatory class name if --emit-class is set."),
+ llvm::cl::init(""));
+
+ static llvm::cl::opt<std::string> fieldNameAttribute(
+ "field-name-attribute",
+ llvm::cl::desc("Mandatory name of the attribute to use as field name if "
+ "--emit-class is set."),
llvm::cl::init(""));
TranslateFromMLIRRegistration reg(
"mlir-to-cpp", "translate from mlir to cpp",
[](Operation *op, raw_ostream &output) {
+ if (emitClass) {
+ if (className.empty()) {
+ llvm::errs() << "Error: --class-name is mandatory when "
+ "--emit-class is set.\n";
+ return mlir::failure();
+ }
+ if (fieldNameAttribute.empty()) {
+ llvm::errs() << "Error: --field-name-attribute is mandatory when "
+ "--emit-class is set.\n";
+ return mlir::failure();
+ }
+ return emitc::translateToCpp(
+ op, output,
+ /*declareVariablesAtTop=*/declareVariablesAtTop,
+ /*fileId=*/fileId, /*emitClass=*/emitClass,
+ /*className=*/className,
+ /*fieldNameAttribute=*/fieldNameAttribute);
+ }
return emitc::translateToCpp(
op, output,
/*declareVariablesAtTop=*/declareVariablesAtTop,
- /*fileId=*/fileId,
- /*className=*/className);
+ /*fileId=*/fileId, /*emitClass=*/emitClass, /*className=*/className,
+ /*fieldNameAttribute=*/fieldNameAttribute);
},
[](DialectRegistry &registry) {
// clang-format off
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 2e8dfca7479b2..ef6583caf2d46 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -22,6 +22,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <stack>
@@ -124,7 +125,8 @@ namespace {
/// Emitter that uses dialect specific emitters to emit C++ code.
struct CppEmitter {
explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
- StringRef fileId, StringRef className);
+ StringRef fileId, bool emitClass, StringRef className,
+ StringRef fieldNameAttribute);
/// Emits attribute or returns failure.
LogicalResult emitAttribute(Location loc, Attribute attr);
@@ -242,11 +244,14 @@ struct CppEmitter {
bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
// Returns whether we should emit a C++ class
- bool shouldPrintClass() { return !className.empty(); };
+ bool shouldPrintClass() { return emitClass; };
// Returns the class name to emit
std::string getClassName() { return className; };
+ // Returns the field name to use in the map
+ std::string getfieldNameAttribute() { return fieldNameAttribute; };
+
/// Returns whether this file op should be emitted
bool shouldEmitFile(FileOp file) {
return !fileId.empty() && file.getId() == fileId;
@@ -282,9 +287,18 @@ struct CppEmitter {
/// Only emit file ops whos id matches this value.
std::string fileId;
- /// Name of the C++ class we will spit
+ /// Controls whether the output should be a C++ class.
+ /// If true, the generated C++ code will be encapsulated within a class,
+ /// and functions from the input module will become its member functions.
+ bool emitClass;
+
+ /// The specified name for the generated C++ class
std::string className;
+ /// Name of the MLIR attribute to use as a field name within the generated
+ /// class
+ std::string fieldNameAttribute;
+
/// Map from value to name of C++ variable that contain the name.
ValueMapper valueMapper;
@@ -1050,16 +1064,6 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
}));
}
-static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
- ArrayRef<Type> arguments) {
- raw_indented_ostream &os = emitter.ostream();
-
- return (interleaveWithNewLineWithError(
- arguments, os, [&](Type arg) -> LogicalResult {
- return emitter.emitType(functionOp->getLoc(), arg);
- }));
-}
-
static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
Region::BlockArgListType arguments) {
raw_indented_ostream &os = emitter.ostream();
@@ -1175,71 +1179,38 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}
-static LogicalResult printFunctionHeader(CppEmitter &emitter,
- emitc::FuncOp functionOp) {
+static LogicalResult emitClassFields(CppEmitter &emitter,
+ emitc::FuncOp functionOp) {
raw_indented_ostream &os = emitter.ostream();
+ auto argAttrs = functionOp.getArgAttrs();
Operation *operation = functionOp.getOperation();
- if (functionOp.getSpecifiers()) {
- for (Attribute specifier : functionOp.getSpecifiersAttr()) {
- os << cast<StringAttr>(specifier).str() << " ";
- }
- }
-
- if (failed(emitter.emitTypes(functionOp.getLoc(),
- functionOp.getFunctionType().getResults())))
+ if (failed(printFields(emitter, operation, functionOp.getArguments())))
return failure();
- os << " " << functionOp.getName();
- if (!emitter.shouldPrintClass()) {
- os << "(";
- if (functionOp.isExternal()) {
- if (failed(printFunctionArgs(emitter, operation,
- functionOp.getArgumentTypes())))
- return failure();
- os << ");";
- return success();
- }
- if (failed(
- printFunctionArgs(emitter, operation, functionOp.getArguments())))
- return failure();
- os << ") {\n";
+ os << ";\n";
- } else {
- os << "() { \n";
- }
-
- return success();
-}
-
-static LogicalResult emitClassBody(CppEmitter &emitter,
- emitc::FuncOp functionOp) {
- raw_indented_ostream &os = emitter.ostream();
- Operation *operation = functionOp.getOperation();
- auto argAttrs = functionOp.getArgAttrs();
std::map<std::string, Value> fields;
- os << "\nstd::map<std::string, char*> _buffer_map {";
- if (argAttrs) // We can have no argattrs in the case that the function has no
- // inputs nor outputs -> procedure
+ os << "std::map<std::string, char*> _buffer_map {";
+ if (argAttrs) {
+ bool isFirst = true;
for (const auto [a, v] : zip(*argAttrs, functionOp.getArguments())) {
if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
- auto nv =
- da.getNamed("tf_saved_model.index_path")
- ->getValue(); // From what I've seen so far, this is the only
- // way to have the argAttrs keys. If there is
- // another way, I need to run the tests to see and
- // see what cases trigger this change in format.
+ auto nv = da.getNamed(emitter.getfieldNameAttribute())->getValue();
auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str();
- fields[name] = v; // The only way to not have unique names is in the
- // case that you have duplicate arguments in your
- // tensorflow/python function. By python syntax rules,
- // you're not allowed to have that(Current assumption)
+ auto Ins = fields.insert({name, v});
+ if (!Ins.second)
+ return failure();
+ if (!isFirst) {
+ os << ",";
+ }
os << "{ \"" << name << "\"" << ", reinterpret_cast<char*>("
- << emitter.getOrCreateName(v) << ") },";
+ << emitter.getOrCreateName(v) << ") }";
+ isFirst = false;
}
}
- else
+ } else
return failure();
- os << " };\n";
+ os << "};";
os << "char* getBufferForName(const std::string& name) const {\n";
os.indent();
os.indent();
@@ -1249,15 +1220,6 @@ static LogicalResult emitClassBody(CppEmitter &emitter,
os.unindent();
os << "}\n\n";
- if (failed(printFunctionHeader(emitter, functionOp)))
- return failure();
-
- if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
- return failure();
- os << "}\n";
- os.unindent();
- os << "};\n";
-
return success();
}
@@ -1270,38 +1232,58 @@ static LogicalResult printOperation(CppEmitter &emitter,
"with multiple blocks needs variables declared at top");
}
- CppEmitter::Scope classScope(emitter);
+ CppEmitter::Scope scope(emitter);
raw_indented_ostream &os = emitter.ostream();
Operation *operation = functionOp.getOperation();
- if (!emitter.shouldPrintClass()) {
-
- if (failed(printFunctionHeader(emitter, functionOp)))
- return failure();
-
- if (failed(printFunctionBody(
- emitter, operation,
- functionOp.getBlocks()))) // This is the only similarity between the
- // function and the class
+ if (emitter.shouldPrintClass()) {
+ if (functionOp.isExternal())
return failure();
- os << "}\n";
-
- } else {
os << "class " << emitter.getClassName() << " final {\n";
os << "public: \n";
os.indent();
+ if (failed(emitClassFields(emitter, functionOp)))
+ return failure();
+ }
+
+ if (functionOp.getSpecifiers()) {
+ for (Attribute specifier : functionOp.getSpecifiersAttr()) {
+ os << cast<StringAttr>(specifier).str() << " ";
+ }
+ }
+
+ if (failed(emitter.emitTypes(functionOp.getLoc(),
+ functionOp.getFunctionType().getResults())))
+ return failure();
+ os << " " << functionOp.getName();
+
+ os << "(";
+
+ if (emitter.shouldPrintClass())
+ os << ") { \n";
+ else {
if (functionOp.isExternal()) {
- if (failed(
- printFields(emitter, operation, functionOp.getArgumentTypes())))
+ if (failed(printFunctionArgs(emitter, operation,
+ functionOp.getArgumentTypes())))
return failure();
+ os << ");";
return success();
}
- if (failed(printFields(emitter, operation, functionOp.getArguments())))
+ if (failed(
+ printFunctionArgs(emitter, operation, functionOp.getArguments())))
return failure();
- os << ";\n";
+ os << ") {\n";
+ }
- if (failed(emitClassBody(emitter, functionOp)))
- return failure();
+ if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
+ return failure();
+
+ if (emitter.shouldPrintClass()) {
+ os << "}\n";
+ os.unindent();
+ os << "};\n";
+ } else {
+ os << "}\n";
}
return success();
@@ -1339,9 +1321,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
}
CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
- StringRef fileId, StringRef className)
+ StringRef fileId, bool emitClass, StringRef className,
+ StringRef fieldNameAttribute)
: os(os), declareVariablesAtTop(declareVariablesAtTop),
- fileId(fileId.str()), className(className.str()) {
+ fileId(fileId.str()), emitClass(emitClass), className(className.str()),
+ fieldNameAttribute(fieldNameAttribute.str()) {
valueInScopeCount.push(0);
labelInScopeCount.push(0);
}
@@ -1924,7 +1908,10 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
bool declareVariablesAtTop,
- StringRef fileId, StringRef className) {
- CppEmitter emitter(os, declareVariablesAtTop, fileId, className);
+ StringRef fileId, bool emitClass,
+ StringRef className,
+ StringRef fieldNameAttribute) {
+ CppEmitter emitter(os, declareVariablesAtTop, fileId, emitClass, className,
+ fieldNameAttribute);
return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
}
diff --git a/mlir/test/mlir-translate/emit-class.mlir b/mlir/test/mlir-translate/emit-class.mlir
new file mode 100644
index 0000000000000..eddc2181a0b19
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path /tmp/model_emitc.mlir | FileCheck %s --check-prefix=ADDER_TEST
+
+// ADDER_TEST-LABEL: class MyAdder final {
+// ADDER_TEST-NEXT: public:
+// ADDER_TEST-DAG: float v1[1];
+// ADDER_TEST-DAG: float v2[1];
+// ADDER_TEST-DAG: float v3[1];
+// ADDER_TEST-NEXT: std::map<std::string, char*> _buffer_map {{ "another_feature", reinterpret_cast<char*>(v1) },{ "some_feature", reinterpret_cast<char*>(v2) },{ "output_0", reinterpret_cast<char*>(v3) }};
+// ADDER_TEST-NEXT: char* getBufferForName(const std::string& name) const {
+// ADDER_TEST-NEXT: auto it = _buffer_map.find(name);
+// ADDER_TEST-NEXT: return (it == _buffer_map.end()) ? nullptr : it->second;
+// ADDER_TEST-NEXT: }
+// ADDER_TEST-NEXT: void main() {
+// ADDER_TEST-NEXT: size_t v4 = 0;
+// ADDER_TEST-NEXT: float v5 = v2[v4];
+// ADDER_TEST-NEXT: float v6 = v1[v4];
+// ADDER_TEST-NEXT: float v7 = v5 + v6;
+// ADDER_TEST-NEXT: v3[v4] = v7;
+// ADDER_TEST-NEXT: return;
+// ADDER_TEST-NEXT: }
+// ADDER_TEST-NEXT: };
+
+// ---
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyMultiOutput --field-name-attribute=tf_saved_model.index_path /tmp/model_multi_out_emitc.mlir | FileCheck %s --check-prefix=MULTI_OUT
+
+// MULTI_OUT-LABEL: class MyMultiOutput final {
+// MULTI_OUT-NEXT: public:
+// MULTI_OUT-DAG: float v1[1];
+// MULTI_OUT-DAG: float v2[1];
+// MULTI_OUT-DAG: float v3[1];
+// MULTI_OUT-DAG: float v4[1];
+// MULTI_OUT: std::map<std::string, char*> _buffer_map {{ "b", reinterpret_cast<char*>(v1) },{ "a", reinterpret_cast<char*>(v2) },{ "output_1", reinterpret_cast<char*>(v3) },{ "output_0", reinterpret_cast<char*>(v4) }, };
+// MULTI_OUT-NEXT: char* getBufferForName(const std::string& name) const {
+// MULTI_OUT-NEXT: auto it = _buffer_map.find(name);
+// MULTI_OUT-NEXT: return (it == _buffer_map.end()) ? nullptr : it->second;
+// MULTI_OUT-NEXT: }
+// MULTI_OUT-NEXT: void main() {
+// MULTI_OUT-NEXT: size_t v5 = 0;
+// MULTI_OUT-NEXT: float v6 = v2[v5];
+// MULTI_OUT-NEXT: float v7 = v1[v5];
+// MULTI_OUT-NEXT: float v8 = v6 + v7;
+// MULTI_OUT-NEXT: v4[v5] = v8;
+// MULTI_OUT-NEXT: float v9 = v2[v5];
+// MULTI_OUT-NEXT: float v10 = v1[v5];
+// MULTI_OUT-NEXT: float v11 = v9 - v10;
+// MULTI_OUT-NEXT: v3[v5] = v11;
+// MULTI_OUT-NEXT: return;
+// MULTI_OUT-NEXT: }
+// MULTI_OUT-NEXT: };
>From 0c55dcf1f7cf314b21b6caf67d6d10000d986c1a Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Wed, 28 May 2025 21:33:44 +0000
Subject: [PATCH 04/19] Working tests to emit-class
---
mlir/lib/Target/Cpp/TranslateRegistration.cpp | 10 +--
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 35 ++++----
.../emit-class-neg-external.mlir | 8 ++
.../emit-class-neg-noArgAttrs.mlir | 15 ++++
mlir/test/mlir-translate/emit-class.mlir | 81 ++++++++-----------
5 files changed, 80 insertions(+), 69 deletions(-)
create mode 100644 mlir/test/mlir-translate/emit-class-neg-external.mlir
create mode 100644 mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 69e0ab01bb71d..9e1533d34f6ea 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -36,20 +36,20 @@ void registerToCppTranslation() {
static llvm::cl::opt<bool> emitClass(
"emit-class",
llvm::cl::desc("If specified, the output will be a class where "
- "the function(s) in the module are members. "
- "Enables class-related options."),
+ "the function(s) in the module are methods "
+ "Enables class-related options"),
llvm::cl::init(false));
static llvm::cl::opt<std::string> className(
"class-name",
- llvm::cl::desc("Mandatory class name if --emit-class is set."),
+ llvm::cl::desc("Mandatory class name if --emit-class is set"),
llvm::cl::init(""));
static llvm::cl::opt<std::string> fieldNameAttribute(
"field-name-attribute",
llvm::cl::desc("Mandatory name of the attribute to use as field name if "
- "--emit-class is set."),
- llvm::cl::init(""));
+ "--emit-class is set(default=tf_saved_model.index_path)"),
+ llvm::cl::init("tf_saved_model.index_path"));
TranslateFromMLIRRegistration reg(
"mlir-to-cpp", "translate from mlir to cpp",
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index ef6583caf2d46..a819e550ad385 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -22,7 +22,6 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <stack>
@@ -290,14 +289,14 @@ struct CppEmitter {
/// Controls whether the output should be a C++ class.
/// If true, the generated C++ code will be encapsulated within a class,
/// and functions from the input module will become its member functions.
- bool emitClass;
+ const bool emitClass;
/// The specified name for the generated C++ class
- std::string className;
+ const std::string className;
/// Name of the MLIR attribute to use as a field name within the generated
/// class
- std::string fieldNameAttribute;
+ const std::string fieldNameAttribute;
/// Map from value to name of C++ variable that contain the name.
ValueMapper valueMapper;
@@ -1189,9 +1188,8 @@ static LogicalResult emitClassFields(CppEmitter &emitter,
os << ";\n";
std::map<std::string, Value> fields;
- os << "std::map<std::string, char*> _buffer_map {";
+ os << "\nstd::map<std::string, char*> _buffer_map {";
if (argAttrs) {
- bool isFirst = true;
for (const auto [a, v] : zip(*argAttrs, functionOp.getArguments())) {
if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
auto nv = da.getNamed(emitter.getfieldNameAttribute())->getValue();
@@ -1199,18 +1197,14 @@ static LogicalResult emitClassFields(CppEmitter &emitter,
auto Ins = fields.insert({name, v});
if (!Ins.second)
return failure();
- if (!isFirst) {
- os << ",";
- }
- os << "{ \"" << name << "\"" << ", reinterpret_cast<char*>("
- << emitter.getOrCreateName(v) << ") }";
- isFirst = false;
+ os << " { \"" << name << "\"" << ", reinterpret_cast<char*>("
+ << emitter.getOrCreateName(v) << ") }, ";
}
}
} else
return failure();
- os << "};";
+ os << "};\n";
os << "char* getBufferForName(const std::string& name) const {\n";
os.indent();
os.indent();
@@ -1236,8 +1230,15 @@ static LogicalResult printOperation(CppEmitter &emitter,
raw_indented_ostream &os = emitter.ostream();
Operation *operation = functionOp.getOperation();
if (emitter.shouldPrintClass()) {
- if (functionOp.isExternal())
+ if (functionOp.isExternal()) {
+ // TODO: Determine the best long-term strategy for external functions.
+ // Currently, we're stopping here to prevent downstream errors.
+ os << "Warning: Cannot process external function '"
+ << functionOp.getName() << "'. "
+ << "It lacks a body, and attempting to continue would lead to errors "
+ "due to missing argument details.\n";
return failure();
+ }
os << "class " << emitter.getClassName() << " final {\n";
os << "public: \n";
os.indent();
@@ -1259,9 +1260,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
os << "(";
- if (emitter.shouldPrintClass())
- os << ") { \n";
- else {
+ if (!emitter.shouldPrintClass()) {
if (functionOp.isExternal()) {
if (failed(printFunctionArgs(emitter, operation,
functionOp.getArgumentTypes())))
@@ -1272,8 +1271,8 @@ static LogicalResult printOperation(CppEmitter &emitter,
if (failed(
printFunctionArgs(emitter, operation, functionOp.getArguments())))
return failure();
- os << ") {\n";
}
+ os << ") {\n";
if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
return failure();
diff --git a/mlir/test/mlir-translate/emit-class-neg-external.mlir b/mlir/test/mlir-translate/emit-class-neg-external.mlir
new file mode 100644
index 0000000000000..347992c467759
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class-neg-external.mlir
@@ -0,0 +1,8 @@
+/// An external function - has no body
+// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s 2>&1 | FileCheck %s
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+ emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]}
+}
+
+// CHECK: Warning: Cannot process external function 'extern_func'. It lacks a body, and attempting to continue would lead to errors due to missing argument details.
\ No newline at end of file
diff --git a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
new file mode 100644
index 0000000000000..77e89da2f0a4f
--- /dev/null
+++ b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
@@ -0,0 +1,15 @@
+/// The function has no argument attributes
+// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=ArgAttrs --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+ emitc.func @foo(%arg0 : i32) {
+ emitc.call_opaque "bar" (%arg0) : (i32) -> ()
+ emitc.return
+ }
+}
+
+// CHECK: class ArgAttrs final {
+// CHECK-NEXT: public:
+// CHECK-NEXT: int32_t v1;
+// CHECK-EMPTY:
+// CHECK-NEXT: std::map<std::string, char*> _buffer_map {
\ No newline at end of file
diff --git a/mlir/test/mlir-translate/emit-class.mlir b/mlir/test/mlir-translate/emit-class.mlir
index eddc2181a0b19..9fa69f652df6c 100644
--- a/mlir/test/mlir-translate/emit-class.mlir
+++ b/mlir/test/mlir-translate/emit-class.mlir
@@ -1,49 +1,38 @@
-// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path /tmp/model_emitc.mlir | FileCheck %s --check-prefix=ADDER_TEST
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
-// ADDER_TEST-LABEL: class MyAdder final {
-// ADDER_TEST-NEXT: public:
-// ADDER_TEST-DAG: float v1[1];
-// ADDER_TEST-DAG: float v2[1];
-// ADDER_TEST-DAG: float v3[1];
-// ADDER_TEST-NEXT: std::map<std::string, char*> _buffer_map {{ "another_feature", reinterpret_cast<char*>(v1) },{ "some_feature", reinterpret_cast<char*>(v2) },{ "output_0", reinterpret_cast<char*>(v3) }};
-// ADDER_TEST-NEXT: char* getBufferForName(const std::string& name) const {
-// ADDER_TEST-NEXT: auto it = _buffer_map.find(name);
-// ADDER_TEST-NEXT: return (it == _buffer_map.end()) ? nullptr : it->second;
-// ADDER_TEST-NEXT: }
-// ADDER_TEST-NEXT: void main() {
-// ADDER_TEST-NEXT: size_t v4 = 0;
-// ADDER_TEST-NEXT: float v5 = v2[v4];
-// ADDER_TEST-NEXT: float v6 = v1[v4];
-// ADDER_TEST-NEXT: float v7 = v5 + v6;
-// ADDER_TEST-NEXT: v3[v4] = v7;
-// ADDER_TEST-NEXT: return;
-// ADDER_TEST-NEXT: }
-// ADDER_TEST-NEXT: };
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+ emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+ %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %2 = load %1 : <f32>
+ %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %4 = load %3 : <f32>
+ %5 = add %2, %4 : (f32, f32) -> f32
+ %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ assign %5 : f32 to %6 : <f32>
+ return
+ }
+}
-// ---
-// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyMultiOutput --field-name-attribute=tf_saved_model.index_path /tmp/model_multi_out_emitc.mlir | FileCheck %s --check-prefix=MULTI_OUT
+// CHECK: class MyAdder final {
+// CHECK-NEXT: public:
+// CHECK-NEXT: float v1[1];
+// CHECK-NEXT: float v2[1];
+// CHECK-NEXT: float v3[1];
+// CHECK-EMPTY:
+// CHECK-NEXT: std::map<std::string, char*> _buffer_map { { "another_feature", reinterpret_cast<char*>(v1) }, { "some_feature", reinterpret_cast<char*>(v2) }, { "output_0", reinterpret_cast<char*>(v3) }, };
+// CHECK-NEXT: char* getBufferForName(const std::string& name) const {
+// CHECK-NEXT: auto it = _buffer_map.find(name);
+// CHECK-NEXT: return (it == _buffer_map.end()) ? nullptr : it->second;
+// CHECK-NEXT: }
+// CHECK-EMPTY:
+// CHECK-NEXT: void main() {
+// CHECK-NEXT: size_t v4 = 0;
+// CHECK-NEXT: float v5 = v2[v4];
+// CHECK-NEXT: float v6 = v1[v4];
+// CHECK-NEXT: float v7 = v5 + v6;
+// CHECK-NEXT: v3[v4] = v7;
+// CHECK-NEXT: return;
+// CHECK-NEXT: }
+// CHECK-NEXT: };
-// MULTI_OUT-LABEL: class MyMultiOutput final {
-// MULTI_OUT-NEXT: public:
-// MULTI_OUT-DAG: float v1[1];
-// MULTI_OUT-DAG: float v2[1];
-// MULTI_OUT-DAG: float v3[1];
-// MULTI_OUT-DAG: float v4[1];
-// MULTI_OUT: std::map<std::string, char*> _buffer_map {{ "b", reinterpret_cast<char*>(v1) },{ "a", reinterpret_cast<char*>(v2) },{ "output_1", reinterpret_cast<char*>(v3) },{ "output_0", reinterpret_cast<char*>(v4) }, };
-// MULTI_OUT-NEXT: char* getBufferForName(const std::string& name) const {
-// MULTI_OUT-NEXT: auto it = _buffer_map.find(name);
-// MULTI_OUT-NEXT: return (it == _buffer_map.end()) ? nullptr : it->second;
-// MULTI_OUT-NEXT: }
-// MULTI_OUT-NEXT: void main() {
-// MULTI_OUT-NEXT: size_t v5 = 0;
-// MULTI_OUT-NEXT: float v6 = v2[v5];
-// MULTI_OUT-NEXT: float v7 = v1[v5];
-// MULTI_OUT-NEXT: float v8 = v6 + v7;
-// MULTI_OUT-NEXT: v4[v5] = v8;
-// MULTI_OUT-NEXT: float v9 = v2[v5];
-// MULTI_OUT-NEXT: float v10 = v1[v5];
-// MULTI_OUT-NEXT: float v11 = v9 - v10;
-// MULTI_OUT-NEXT: v3[v5] = v11;
-// MULTI_OUT-NEXT: return;
-// MULTI_OUT-NEXT: }
-// MULTI_OUT-NEXT: };
>From ff237e575769e9235b236b3a4489cac07266d82b Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 29 May 2025 16:09:06 +0000
Subject: [PATCH 05/19] A lil' cleaning up
---
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 7 +++----
mlir/test/mlir-translate/emit-class-neg-external.mlir | 6 +++---
mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir | 4 ++--
mlir/test/mlir-translate/emit-class.mlir | 5 +++--
4 files changed, 11 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index a819e550ad385..e09ed0a142725 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1232,12 +1232,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
if (emitter.shouldPrintClass()) {
if (functionOp.isExternal()) {
// TODO: Determine the best long-term strategy for external functions.
- // Currently, we're stopping here to prevent downstream errors.
+ // Currently, we're skipping over this functionOp.
os << "Warning: Cannot process external function '"
<< functionOp.getName() << "'. "
- << "It lacks a body, and attempting to continue would lead to errors "
- "due to missing argument details.\n";
- return failure();
+ << "This functionOp lacks a body so we will skip over it.";
+ return success();
}
os << "class " << emitter.getClassName() << " final {\n";
os << "public: \n";
diff --git a/mlir/test/mlir-translate/emit-class-neg-external.mlir b/mlir/test/mlir-translate/emit-class-neg-external.mlir
index 347992c467759..c34a1652abd3f 100644
--- a/mlir/test/mlir-translate/emit-class-neg-external.mlir
+++ b/mlir/test/mlir-translate/emit-class-neg-external.mlir
@@ -1,8 +1,8 @@
/// An external function - has no body
-// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s 2>&1 | FileCheck %s
+// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]}
}
-// CHECK: Warning: Cannot process external function 'extern_func'. It lacks a body, and attempting to continue would lead to errors due to missing argument details.
\ No newline at end of file
+// CHECK: Warning: Cannot process external function 'extern_func'. This functionOp lacks a body so we will skip over it.
diff --git a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
index 77e89da2f0a4f..6d43fa953a946 100644
--- a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
+++ b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
@@ -1,7 +1,7 @@
/// The function has no argument attributes
// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=ArgAttrs --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
emitc.func @foo(%arg0 : i32) {
emitc.call_opaque "bar" (%arg0) : (i32) -> ()
emitc.return
@@ -12,4 +12,4 @@ module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted."
// CHECK-NEXT: public:
// CHECK-NEXT: int32_t v1;
// CHECK-EMPTY:
-// CHECK-NEXT: std::map<std::string, char*> _buffer_map {
\ No newline at end of file
+// CHECK-NEXT: std::map<std::string, char*> _buffer_map {
diff --git a/mlir/test/mlir-translate/emit-class.mlir b/mlir/test/mlir-translate/emit-class.mlir
index 9fa69f652df6c..2779cb315ed41 100644
--- a/mlir/test/mlir-translate/emit-class.mlir
+++ b/mlir/test/mlir-translate/emit-class.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.metadata = {CONVERSION_METADATA = "\10\00\00\00\00\00\00\00\08\00\0E\00\08\00\04\00\08\00\00\00\10\00\00\00$\00\00\00\00\00\06\00\08\00\04\00\06\00\00\00\04\00\00\00\00\00\00\00\0C\00\18\00\14\00\10\00\0C\00\04\00\0C\00\00\00\A6\03|\7Frm\F2\17\01\00\00\00\02\00\00\00\04\00\00\00\06\00\00\002.19.0\00\00", min_runtime_version = "1.5.0\00\00\00\00\00\00\00\00\00\00\00"}, tfl.schema_version = 3 : i32} {
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
%1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
@@ -20,7 +20,8 @@ module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted."
// CHECK-NEXT: float v2[1];
// CHECK-NEXT: float v3[1];
// CHECK-EMPTY:
-// CHECK-NEXT: std::map<std::string, char*> _buffer_map { { "another_feature", reinterpret_cast<char*>(v1) }, { "some_feature", reinterpret_cast<char*>(v2) }, { "output_0", reinterpret_cast<char*>(v3) }, };
+// CHECK-NEXT: std::map<std::string, char*> _buffer_map { { "another_feature", reinterpret_cast<char*>(v1) },
+// CHECK-SAME: { "some_feature", reinterpret_cast<char*>(v2) }, { "output_0", reinterpret_cast<char*>(v3) }, };
// CHECK-NEXT: char* getBufferForName(const std::string& name) const {
// CHECK-NEXT: auto it = _buffer_map.find(name);
// CHECK-NEXT: return (it == _buffer_map.end()) ? nullptr : it->second;
>From 2dc314891040dfa94182f21fc874ebbbee186d1b Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 29 May 2025 17:51:59 +0000
Subject: [PATCH 06/19] Clarifying TODO messages
---
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index e09ed0a142725..aa85a03b84885 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1233,6 +1233,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
if (functionOp.isExternal()) {
// TODO: Determine the best long-term strategy for external functions.
// Currently, we're skipping over this functionOp.
+ // We have considered using emitWarning() which would return
+ // InFlightDiagnostic which seems can be automatically converted to LogicalResult since
+ // this is done in emitAttributes where emitError is converted to LogicalResult. However, it requires that we pass in a
+ // location which at first glance we don't have in this scope. Open to
+ // further discussion on this.
os << "Warning: Cannot process external function '"
<< functionOp.getName() << "'. "
<< "This functionOp lacks a body so we will skip over it.";
@@ -1255,6 +1260,9 @@ static LogicalResult printOperation(CppEmitter &emitter,
if (failed(emitter.emitTypes(functionOp.getLoc(),
functionOp.getFunctionType().getResults())))
return failure();
+ // TODO: We may wanna consider having the name of the function be execute in
+ // the case that we want to emit a class instead of main. Leaving as is for
+ // now to make the change smaller.
os << " " << functionOp.getName();
os << "(";
>From bf4f1cd010631adc295210bbb98a32178bad8f34 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 29 May 2025 19:01:55 +0000
Subject: [PATCH 07/19] Formatting issues
---
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index aa85a03b84885..fc550d24113e3 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -70,8 +70,8 @@ inline LogicalResult interleaveCommaWithError(const Container &c,
template <typename Container, typename UnaryFunctor>
inline LogicalResult interleaveWithNewLineWithError(const Container &c,
- raw_ostream &os,
- UnaryFunctor eachFn) {
+ raw_ostream &os,
+ UnaryFunctor eachFn) {
return interleaveWithError(c.begin(), c.end(), eachFn,
[&]() { os << ";\n"; });
}
@@ -1234,8 +1234,9 @@ static LogicalResult printOperation(CppEmitter &emitter,
// TODO: Determine the best long-term strategy for external functions.
// Currently, we're skipping over this functionOp.
// We have considered using emitWarning() which would return
- // InFlightDiagnostic which seems can be automatically converted to LogicalResult since
- // this is done in emitAttributes where emitError is converted to LogicalResult. However, it requires that we pass in a
+ // InFlightDiagnostic which seems can be automatically converted to
+ // LogicalResult since this is done in emitAttributes where emitError is
+ // converted to LogicalResult. However, it requires that we pass in a
// location which at first glance we don't have in this scope. Open to
// further discussion on this.
os << "Warning: Cannot process external function '"
>From e7f208437bc44f211a50c152809499b81ae40c98 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?=
=?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?=
=?UTF-8?q?=E3=83=B3=29?= <clementval at gmail.com>
Date: Thu, 22 May 2025 08:24:18 -0700
Subject: [PATCH 08/19] [flang][rt] Fix the use of kNoAsyncId -> kNoAsyncObject
(#141079)
>From 20d9f7a4e416f9940430ef1aa99f0c6e3995543a Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 6 Jun 2025 16:12:33 +0000
Subject: [PATCH 09/19] Adding ClassOp, FieldOp, GetFieldOp to allow for a
transform from func to class
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 116 ++++++++++++++++++
.../mlir/Dialect/EmitC/Transforms/Passes.h | 1 +
.../mlir/Dialect/EmitC/Transforms/Passes.td | 11 ++
.../Dialect/EmitC/Transforms/Transforms.h | 6 +
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 73 +++++++++++
.../Dialect/EmitC/Transforms/CMakeLists.txt | 1 +
.../EmitC/Transforms/ConvertFuncToClass.cpp | 70 +++++++++++
.../Dialect/EmitC/Transforms/Transforms.cpp | 77 ++++++++++++
.../Dialect/EmitC/convert_func_to_class.mlir | 15 +++
9 files changed, 370 insertions(+)
create mode 100644 mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
create mode 100644 mlir/test/Dialect/EmitC/convert_func_to_class.mlir
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index e53d3e45875d5..ea6af41ee2901 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1572,4 +1572,120 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
let hasVerifier = 1;
}
+def EmitC_ClassOp : EmitC_Op<"class", [AutomaticAllocationScope,
+ IsolatedFromAbove, OpAsmOpInterface]> {
+ let summary =
+ "Represents a C++ class definition, encapsulating fields and methods.";
+
+ let description = [{
+ The `emitc.class` operation defines a C++ class, acting as a container
+ for its data fields (`emitc.variable`) and methods (`emitc.func`).
+ It creates a distinct scope, isolating its contents from the surrounding
+ MLIR region, similar to how C++ classes encapsulate their internals.
+
+ Example:
+ ```mlir
+ emitc.class @MyModelClass {
+ emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
+ emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>>
+ emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
+
+ emitc.func @main() attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+ %c0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+
+ %some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
+ %another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
+ %output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
+
+ %v1 = subscript %some_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %v1_val = load %v1 : !emitc.lvalue<f32> -> f32
+
+ %v2 = subscript %another_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %v2_val = load %v2 : !emitc.lvalue<f32> -> f32
+
+ %v3_val = add %v1_val, %v2_val : (f32, f32) -> f32
+
+ %output_lvalue = subscript %output_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ assign %v3_val, %output_lvalue : (f32, !emitc.lvalue<f32>) -> ()
+
+ return
+ }
+ }
+ }
+
+ ```
+ }];
+
+ let arguments = (ins SymbolNameAttr:$sym_name);
+
+ let regions = (region AnyRegion:$body);
+
+ let builders = [];
+
+ let extraClassDeclaration = [{
+ // Returns the body block containing class members and methods.
+ Block &getBlock();
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+
+ let assemblyFormat = "`class` $sym_name attr-dict-with-keyword $body";
+}
+
+def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
+ let summary = "A field within a class";
+ let description = [{
+ The `emitc.field` operation declares a named field within an `emitc.class`
+ operation. The field's type must be an EmitC type. An optional initial value can be provided.
+
+ Example with initial values:
+
+ ```mlir
+ emitc.class @MyModelClass {
+ emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<0.0> : !emitc.f32
+ emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<1.0> : !emitc.f32
+ emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
+ }
+ ```
+ Example without initial value:
+ ```mlir
+ emitc.class @MyModelClass {
+ emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
+ }
+ ```
+ }];
+
+ let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
+ OptionalAttr<AnyAttr>:$initial_value);
+
+ let assemblyFormat = "$sym_name `:` $type (`=` $initial_value^)? attr-dict";
+
+ let hasVerifier = 1;
+}
+
+def EmitC_GetFieldOp
+ : EmitC_Op<"get_field", [Pure, DeclareOpInterfaceMethods<
+ SymbolUserOpInterface>]> {
+ let summary = "Obtain access to a field within a class instance";
+ let description = [{
+ The `emitc.get_field` operation retrieves the lvalue of a
+ named field from a given class instance.
+
+ Example:
+
+ ```mlir
+ %some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
+ %another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
+ %output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
+ ```
+ }];
+
+ let arguments = (ins AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$base,
+ FlatSymbolRefAttr:$class_name, FlatSymbolRefAttr:$field_name);
+
+ let results = (outs AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$result);
+ let assemblyFormat = "$base `:` type($base) $class_name `,` $field_name `->` "
+ "type($result) attr-dict";
+}
+
#endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
index 5a103f181c76b..ad516f93808f8 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
@@ -15,6 +15,7 @@ namespace mlir {
namespace emitc {
#define GEN_PASS_DECL_FORMEXPRESSIONSPASS
+#define GEN_PASS_DECL_CONVERTFUNCTOCLASSPASS
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index f46b705ca2dfe..d84c3184da777 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -20,4 +20,15 @@ def FormExpressionsPass : Pass<"form-expressions"> {
let dependentDialects = ["emitc::EmitCDialect"];
}
+def ConvertFuncToClassPass
+ : Pass<"convert-emitc-func-to-class", "mlir::emitc::FuncOp"> {
+ let summary = "Convert functions to classes, using arguments as fields.";
+ let description = [{
+ This pass transforms `emitc.func` operations into `emitc.class` operations.
+ Function arguments become fields of the class, and the function body is moved
+ to a new `execute` method within the class.
+ }];
+ let dependentDialects = ["emitc::EmitCDialect"];
+}
+
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index 2574acd7d48e0..49dec99938a44 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -28,6 +28,12 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
/// Populates `patterns` with expression-related patterns.
void populateExpressionPatterns(RewritePatternSet &patterns);
+//===----------------------------------------------------------------------===//
+// Convert Func to Class Transform
+//===----------------------------------------------------------------------===//
+
+ClassOp createClass(FuncOp funcOp, OpBuilder &builder);
+
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index f82b20712b8c6..3af0f16be7515 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1400,6 +1400,79 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
builder.getNamedAttr("id", builder.getStringAttr(id)));
}
+//===----------------------------------------------------------------------===//
+// FieldOp
+//===----------------------------------------------------------------------===//
+LogicalResult FieldOp::verify() {
+ if (!isSupportedEmitCType(getType())) {
+ return emitOpError("expected valid emitc type");
+ }
+
+ if (!getInitialValue().has_value()) {
+ return success();
+ }
+
+ Attribute initValue = getInitialValue().value();
+ // Check that the type of the initial value is compatible with the type of
+ // the global variable.
+ if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
+ auto initialValueType = elementsAttr.getType();
+ if (!initialValueType) {
+ return emitOpError("initial value attribute must have a type");
+ }
+ auto fieldType = getType();
+ if (initialValueType != fieldType) {
+ if (auto lvalueType = dyn_cast<LValueType>(fieldType)) {
+ auto innerFieldType = lvalueType.getValueType();
+ if (innerFieldType != initialValueType) {
+ return emitOpError("initial value type ")
+ << initialValueType << " is not compatible with field type "
+ << fieldType << " its inner type " << innerFieldType;
+ }
+
+ } else {
+ return emitOpError("initial value type ")
+ << initialValueType << " is not compatible with field type "
+ << fieldType;
+ }
+ }
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetFieldOp
+//===----------------------------------------------------------------------===//
+LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto classNameAttr = getClassNameAttr();
+ auto fieldNameAttr = getFieldNameAttr();
+ if (!classNameAttr || !fieldNameAttr) {
+ return emitError("class and field name attributes are mandatory");
+ }
+ StringRef className = classNameAttr.getValue();
+ StringRef fieldName = fieldNameAttr.getValue();
+
+ auto fieldOp =
+ symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, getFieldNameAttr());
+
+ if (!fieldOp) {
+ return emitOpError("field '")
+ << fieldName << "' not found in class '" << className << "'";
+ }
+
+ Type getFieldResultType = getResult().getType();
+ Type fieldType = fieldOp.getType();
+
+ if (fieldType != getFieldResultType) {
+ return emitOpError("result type ")
+ << getFieldResultType << " does not match field '" << fieldName
+ << "' type " << fieldType;
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
index 19b80b22bd84b..4c2525c2a5b00 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIREmitCTransforms
Transforms.cpp
FormExpressions.cpp
TypeConversions.cpp
+ ConvertFuncToClass.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms
diff --git a/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
new file mode 100644
index 0000000000000..5beb62f675414
--- /dev/null
+++ b/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
@@ -0,0 +1,70 @@
+//===- ConvertFuncToClass.cpp - Convert functions to classes -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace emitc {
+
+#define GEN_PASS_DEF_CONVERTFUNCTOCLASSPASS
+#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
+
+namespace {
+
+struct ConvertFuncToClassPass
+ : public impl::ConvertFuncToClassPassBase<ConvertFuncToClassPass> {
+ void runOnOperation() override {
+ emitc::FuncOp funcOp = getOperation();
+ MLIRContext *context = funcOp->getContext();
+
+ // Wrap each C operator op with an expression op.
+ OpBuilder builder(context);
+ createClass(funcOp, builder);
+
+ // // Create the new function inside the class
+ // auto funcType = FunctionType::get(funcOp.getContext(),
+ // funcOp.getFunctionType().getInputs(),
+ // funcOp.getFunctionType().getResults()); auto newFuncOp =
+ // builder.create<emitc::FuncOp>(
+ // funcOp.getLoc(),builder.getStringAttr("execute"), funcType );
+
+ // builder.createBlock(&newFuncOp.getBody());
+ // builder.setInsertionPointToStart(&newFuncOp.getBody().front());
+
+ // // 7. Remap original arguments to field pointers
+ // IRMapping mapper;
+
+ // // 8. move or clone operations from original function
+ // for (Operation &opToClone :
+ // llvm::make_early_inc_range(funcOp.getBody().front())) {
+ // if (isa<emitc::ConstantOp>(opToClone) ||
+ // isa<emitc::SubscriptOp>(opToClone) ||
+ // isa<emitc::LoadOp>(opToClone) ||
+ // isa<emitc::AddOp>(opToClone) ||
+ // isa<emitc::AssignOp>(opToClone) ||
+ // isa<emitc::ReturnOp>(opToClone )) {
+ // builder.clone(opToClone, mapper);
+ // } else {
+ // opToClone.emitOpError("Unsupported operation found");
+ // }
+ // }
+ // if (funcOp->use_empty()) funcOp->erase();
+ }
+};
+
+} // namespace
+
+} // namespace emitc
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index 87350ecdceaaa..a9687e4d1d187 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -43,6 +43,83 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
return expressionOp;
}
+ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
+ builder.setInsertionPoint(funcOp);
+
+ // 2. Create the class
+ auto classOp = builder.create<emitc::ClassOp>(
+ funcOp.getLoc(), builder.getStringAttr("MyModelClass"));
+
+ // Create a block inside the class body and set insertion point
+ builder.createBlock(&classOp.getBody());
+ builder.setInsertionPointToStart(&classOp.getBody().front());
+
+ // 3. Extract input/output names from function arguments
+ SmallVector<std::pair<StringRef, Type>> fields;
+ llvm::SmallDenseMap<Value, Value> argToFieldMap;
+
+ auto argAttrs = funcOp.getArgAttrs();
+ if (argAttrs) {
+ for (const auto [arg, val] : zip(*argAttrs, funcOp.getArguments())) {
+ if (auto da = dyn_cast<mlir::DictionaryAttr>(arg)) {
+ auto nv = da.getNamed("tf_saved_model.index_path")->getValue();
+ auto fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
+ auto fieldType = emitc::LValueType::get(emitc::PointerType::get(
+ dyn_cast_or_null<emitc::ArrayType>(val.getType())
+ .getElementType()));
+ fields.push_back({fieldName.str(), fieldType});
+
+ // 4.Create the class fields
+ auto typeAttr = TypeAttr::get(val.getType());
+ mlir::Attribute emptyAttr = builder.getAttr<mlir::UnitAttr>();
+ auto dictAttr = DictionaryAttr::get(
+ builder.getContext(),
+ {builder.getNamedAttr(fieldName.str(), emptyAttr)});
+ builder.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
+ /* attributes*/ dictAttr);
+ // 5. Get the pointers to the class fields
+ auto pointer = emitc::PointerType::get(
+ dyn_cast_or_null<emitc::ArrayType>(val.getType()).getElementType());
+ auto ptr = builder.create<emitc::GetFieldOp>(
+ funcOp.getLoc(), pointer, val, "MyModelClass", fieldName);
+ argToFieldMap[val] = ptr;
+ }
+ }
+ }
+
+ // Create the new function inside the class
+ auto funcContext = funcOp.getContext();
+ auto inputTypes = funcOp.getFunctionType().getInputs();
+ auto results = funcOp.getFunctionType().getResults();
+ auto funcType = FunctionType::get(funcContext, inputTypes, results);
+ auto loc = funcOp.getLoc();
+ auto newFuncOp = builder.create<emitc::FuncOp>(
+ loc, builder.getStringAttr("execute"), funcType);
+
+ builder.createBlock(&newFuncOp.getBody());
+ builder.setInsertionPointToStart(&newFuncOp.getBody().front());
+
+ // 7. Remap original arguments to field pointers
+ IRMapping mapper;
+
+ // 8. move or clone operations from original function
+ auto body = llvm::make_early_inc_range(funcOp.getBody().front());
+ for (Operation &opToClone : body) {
+ if (isa<emitc::ConstantOp>(opToClone) ||
+ isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) ||
+ isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) ||
+ isa<emitc::ReturnOp>(opToClone)) {
+ builder.clone(opToClone, mapper);
+ } else {
+ opToClone.emitOpError("Unsupported operation found");
+ }
+ }
+
+ // if (funcOp->use_empty()) funcOp->erase();
+
+ return classOp;
+}
+
} // namespace emitc
} // namespace mlir
diff --git a/mlir/test/Dialect/EmitC/convert_func_to_class.mlir b/mlir/test/Dialect/EmitC/convert_func_to_class.mlir
new file mode 100644
index 0000000000000..f9e276f059130
--- /dev/null
+++ b/mlir/test/Dialect/EmitC/convert_func_to_class.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s --emitc-convert-func-to-class
+
+module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+ emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+ %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %2 = load %1 : <f32>
+ %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %4 = load %3 : <f32>
+ %5 = add %2, %4 : (f32, f32) -> f32
+ %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ assign %5 : f32 to %6 : <f32>
+ return
+ }
+}
>From 8ba35d7d97c605f4d98842a667502f24df0220f0 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 6 Jun 2025 17:22:04 +0000
Subject: [PATCH 10/19] Removed unnecessary comments
---
.../EmitC/Transforms/ConvertFuncToClass.cpp | 32 +------------------
.../Dialect/EmitC/Transforms/Transforms.cpp | 25 +++++++--------
2 files changed, 12 insertions(+), 45 deletions(-)
diff --git a/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
index 5beb62f675414..fe8a05d39e1df 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
@@ -29,42 +29,12 @@ struct ConvertFuncToClassPass
emitc::FuncOp funcOp = getOperation();
MLIRContext *context = funcOp->getContext();
- // Wrap each C operator op with an expression op.
OpBuilder builder(context);
createClass(funcOp, builder);
-
- // // Create the new function inside the class
- // auto funcType = FunctionType::get(funcOp.getContext(),
- // funcOp.getFunctionType().getInputs(),
- // funcOp.getFunctionType().getResults()); auto newFuncOp =
- // builder.create<emitc::FuncOp>(
- // funcOp.getLoc(),builder.getStringAttr("execute"), funcType );
-
- // builder.createBlock(&newFuncOp.getBody());
- // builder.setInsertionPointToStart(&newFuncOp.getBody().front());
-
- // // 7. Remap original arguments to field pointers
- // IRMapping mapper;
-
- // // 8. move or clone operations from original function
- // for (Operation &opToClone :
- // llvm::make_early_inc_range(funcOp.getBody().front())) {
- // if (isa<emitc::ConstantOp>(opToClone) ||
- // isa<emitc::SubscriptOp>(opToClone) ||
- // isa<emitc::LoadOp>(opToClone) ||
- // isa<emitc::AddOp>(opToClone) ||
- // isa<emitc::AssignOp>(opToClone) ||
- // isa<emitc::ReturnOp>(opToClone )) {
- // builder.clone(opToClone, mapper);
- // } else {
- // opToClone.emitOpError("Unsupported operation found");
- // }
- // }
- // if (funcOp->use_empty()) funcOp->erase();
}
};
} // namespace
} // namespace emitc
-} // namespace mlir
\ No newline at end of file
+} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index a9687e4d1d187..8471631c6e60b 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -46,15 +46,12 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
builder.setInsertionPoint(funcOp);
- // 2. Create the class
auto classOp = builder.create<emitc::ClassOp>(
funcOp.getLoc(), builder.getStringAttr("MyModelClass"));
- // Create a block inside the class body and set insertion point
builder.createBlock(&classOp.getBody());
builder.setInsertionPointToStart(&classOp.getBody().front());
- // 3. Extract input/output names from function arguments
SmallVector<std::pair<StringRef, Type>> fields;
llvm::SmallDenseMap<Value, Value> argToFieldMap;
@@ -69,7 +66,6 @@ ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
.getElementType()));
fields.push_back({fieldName.str(), fieldType});
- // 4.Create the class fields
auto typeAttr = TypeAttr::get(val.getType());
mlir::Attribute emptyAttr = builder.getAttr<mlir::UnitAttr>();
auto dictAttr = DictionaryAttr::get(
@@ -77,17 +73,19 @@ ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
{builder.getNamedAttr(fieldName.str(), emptyAttr)});
builder.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
/* attributes*/ dictAttr);
- // 5. Get the pointers to the class fields
- auto pointer = emitc::PointerType::get(
- dyn_cast_or_null<emitc::ArrayType>(val.getType()).getElementType());
- auto ptr = builder.create<emitc::GetFieldOp>(
- funcOp.getLoc(), pointer, val, "MyModelClass", fieldName);
- argToFieldMap[val] = ptr;
+
+ // TODO: From my current understanding, we need to instantiate a class
+ // so we can get the pointers from .field but we can't do that in here
+ // so I'm unsure how I can rewrite the following line to ensure
+ // GetFieldOp works correctly. auto pointer =
+ // emitc::PointerType::get(dyn_cast_or_null<emitc::ArrayType>(val.getType()).getElementType());
+ // auto ptr = builder.create<emitc::GetFieldOp>(funcOp.getLoc(),
+ // pointer, val, "MyModelClass", fieldName);
+ argToFieldMap[val] = nullptr;
}
}
}
- // Create the new function inside the class
auto funcContext = funcOp.getContext();
auto inputTypes = funcOp.getFunctionType().getInputs();
auto results = funcOp.getFunctionType().getResults();
@@ -99,10 +97,8 @@ ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
builder.createBlock(&newFuncOp.getBody());
builder.setInsertionPointToStart(&newFuncOp.getBody().front());
- // 7. Remap original arguments to field pointers
IRMapping mapper;
- // 8. move or clone operations from original function
auto body = llvm::make_early_inc_range(funcOp.getBody().front());
for (Operation &opToClone : body) {
if (isa<emitc::ConstantOp>(opToClone) ||
@@ -115,7 +111,8 @@ ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
}
}
- // if (funcOp->use_empty()) funcOp->erase();
+ // TODO: Need to erase the funcOp after all this. Using funcOp->erase raises
+ // errors:
return classOp;
}
>From f3acc4faf11d2109b1e9b5f9ae50705b2e1ffb8d Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Thu, 12 Jun 2025 21:36:24 +0000
Subject: [PATCH 11/19] rewritten
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 87 +++++-----
.../mlir/Dialect/EmitC/Transforms/Passes.h | 2 +-
.../mlir/Dialect/EmitC/Transforms/Passes.td | 5 +-
.../Dialect/EmitC/Transforms/Transforms.h | 7 +-
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 38 +++--
.../Dialect/EmitC/Transforms/CMakeLists.txt | 2 +-
.../EmitC/Transforms/ConvertFuncToClass.cpp | 40 -----
.../Dialect/EmitC/Transforms/Transforms.cpp | 75 ---------
.../EmitC/Transforms/WrapFuncInClass.cpp | 150 ++++++++++++++++++
...ass.mlir => wrap_emitc_func_in_class.mlir} | 2 +-
10 files changed, 214 insertions(+), 194 deletions(-)
delete mode 100644 mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
create mode 100644 mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
rename mlir/test/Dialect/EmitC/{convert_func_to_class.mlir => wrap_emitc_func_in_class.mlir} (95%)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index ea6af41ee2901..08235ca701e2a 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1572,46 +1572,45 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
let hasVerifier = 1;
}
-def EmitC_ClassOp : EmitC_Op<"class", [AutomaticAllocationScope,
- IsolatedFromAbove, OpAsmOpInterface]> {
+def EmitC_ClassOp
+ : EmitC_Op<"class", [AutomaticAllocationScope, IsolatedFromAbove,
+ OpAsmOpInterface, SymbolTable,
+ Symbol]#GraphRegionNoTerminator.traits> {
let summary =
"Represents a C++ class definition, encapsulating fields and methods.";
+ // FIX WORDING
let description = [{
The `emitc.class` operation defines a C++ class, acting as a container
for its data fields (`emitc.variable`) and methods (`emitc.func`).
It creates a distinct scope, isolating its contents from the surrounding
MLIR region, similar to how C++ classes encapsulate their internals.
+ All the class members need to be default initializable.
Example:
```mlir
- emitc.class @MyModelClass {
- emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
- emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>>
- emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
-
- emitc.func @main() attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
- %c0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-
- %some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
- %another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
- %output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
-
- %v1 = subscript %some_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
- %v1_val = load %v1 : !emitc.lvalue<f32> -> f32
-
- %v2 = subscript %another_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
- %v2_val = load %v2 : !emitc.lvalue<f32> -> f32
-
- %v3_val = add %v1_val, %v2_val : (f32, f32) -> f32
-
- %output_lvalue = subscript %output_ptr[%c0] : (!emitc.ptr<f32>, !emitc.size_t) -> !emitc.lvalue<f32>
- assign %v3_val, %output_lvalue : (f32, !emitc.lvalue<f32>) -> ()
-
- return
- }
+ emitc.class @MymainClass {
+ emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
+ emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
+ emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
+
+ emitc.func @execute() {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+
+ %1 = get_field @another_feature : !emitc.array<1xf32>
+ %2 = get_field @some_feature : !emitc.array<1xf32>
+ %3 = get_field @output_0 : !emitc.array<1xf32>
+
+ %4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %5 = load %4 : <f32>
+ %6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %7 = load %6 : <f32>
+ %8 = add %5, %7 : (f32, f32) -> f32
+ %9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ assign %8 : f32 to %9 : <f32>
+ return
}
- }
+ }
```
}];
@@ -1629,7 +1628,7 @@ def EmitC_ClassOp : EmitC_Op<"class", [AutomaticAllocationScope,
let hasCustomAssemblyFormat = 1;
- let assemblyFormat = "`class` $sym_name attr-dict-with-keyword $body";
+ let assemblyFormat = [{ $sym_name attr-dict-with-keyword $body }];
}
def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
@@ -1642,15 +1641,9 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
```mlir
emitc.class @MyModelClass {
- emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<0.0> : !emitc.f32
- emitc.field @some_feature : !emitc.lvalue<!emitc.ptr<f32>> = #emitc.value<1.0> : !emitc.f32
- emitc.field @output_0 : !emitc.lvalue<!emitc.ptr<f32>>
- }
- ```
- Example without initial value:
- ```mlir
- emitc.class @MyModelClass {
- emitc.field @another_feature : !emitc.lvalue<!emitc.ptr<f32>>
+ emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
+ emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
+ emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
}
```
}];
@@ -1658,7 +1651,8 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
OptionalAttr<AnyAttr>:$initial_value);
- let assemblyFormat = "$sym_name `:` $type (`=` $initial_value^)? attr-dict";
+ let assemblyFormat =
+ [{ $sym_name `:` $type (`=` $initial_value^)? attr-dict}];
let hasVerifier = 1;
}
@@ -1674,18 +1668,15 @@ def EmitC_GetFieldOp
Example:
```mlir
- %some_ptr = emitc.get_field %self : @MyModelClass, @some_feature -> !emitc.ptr<f32>
- %another_ptr = emitc.get_field %self : @MyModelClass, @another_feature -> !emitc.ptr<f32>
- %output_ptr = emitc.get_field %self : @MyModelClass, @output_0 -> !emitc.ptr<f32>
+ %some_ptr = emitc.get_field @some_feature : !emitc.array<1xf32>
+ %another_ptr = emitc.get_field @another_feature : !emitc.array<1xf32>
+ %output_ptr = emitc.get_field @output_0 : !emitc.array<1xf32>
```
}];
- let arguments = (ins AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$base,
- FlatSymbolRefAttr:$class_name, FlatSymbolRefAttr:$field_name);
-
- let results = (outs AnyTypeOf<[EmitC_LValueType, EmitC_PointerType]>:$result);
- let assemblyFormat = "$base `:` type($base) $class_name `,` $field_name `->` "
- "type($result) attr-dict";
+ let arguments = (ins FlatSymbolRefAttr:$field_name);
+ let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result);
+ let assemblyFormat = "$field_name `:` type($result) attr-dict";
}
#endif // MLIR_DIALECT_EMITC_IR_EMITC
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
index ad516f93808f8..1af4aa06fa811 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
@@ -15,7 +15,7 @@ namespace mlir {
namespace emitc {
#define GEN_PASS_DECL_FORMEXPRESSIONSPASS
-#define GEN_PASS_DECL_CONVERTFUNCTOCLASSPASS
+#define GEN_PASS_DECL_WRAPFUNCINCLASSPASS
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index d84c3184da777..1aa95b32217c1 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -20,9 +20,8 @@ def FormExpressionsPass : Pass<"form-expressions"> {
let dependentDialects = ["emitc::EmitCDialect"];
}
-def ConvertFuncToClassPass
- : Pass<"convert-emitc-func-to-class", "mlir::emitc::FuncOp"> {
- let summary = "Convert functions to classes, using arguments as fields.";
+def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
+ let summary = "Wrap functions in classes, using arguments as fields.";
let description = [{
This pass transforms `emitc.func` operations into `emitc.class` operations.
Function arguments become fields of the class, and the function body is moved
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index 49dec99938a44..bdf6d0985e6db 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -28,11 +28,8 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
/// Populates `patterns` with expression-related patterns.
void populateExpressionPatterns(RewritePatternSet &patterns);
-//===----------------------------------------------------------------------===//
-// Convert Func to Class Transform
-//===----------------------------------------------------------------------===//
-
-ClassOp createClass(FuncOp funcOp, OpBuilder &builder);
+/// Populates 'patterns' with func-related patterns.
+void populateFuncPatterns(RewritePatternSet &patterns);
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 3af0f16be7515..695ad3ee3bb0a 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1408,32 +1408,32 @@ LogicalResult FieldOp::verify() {
return emitOpError("expected valid emitc type");
}
- if (!getInitialValue().has_value()) {
+ if (!getInitialValue()) {
return success();
}
- Attribute initValue = getInitialValue().value();
+ Attribute initValue = *getInitialValue();
// Check that the type of the initial value is compatible with the type of
// the global variable.
- if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
- auto initialValueType = elementsAttr.getType();
+ if (ElementsAttr elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
+ Type initialValueType = elementsAttr.getType();
if (!initialValueType) {
return emitOpError("initial value attribute must have a type");
}
- auto fieldType = getType();
+ Type fieldType = getType();
if (initialValueType != fieldType) {
- if (auto lvalueType = dyn_cast<LValueType>(fieldType)) {
- auto innerFieldType = lvalueType.getValueType();
+ if (LValueType lvalueType = dyn_cast<LValueType>(fieldType)) {
+ Type innerFieldType = lvalueType.getValueType();
if (innerFieldType != initialValueType) {
return emitOpError("initial value type ")
- << initialValueType << " is not compatible with field type "
- << fieldType << " its inner type " << innerFieldType;
+ << initialValueType << " is not compatible with field type '"
+ << fieldType << "' its inner type '" << innerFieldType << "'";
}
} else {
- return emitOpError("initial value type ")
- << initialValueType << " is not compatible with field type "
- << fieldType;
+ return emitOpError("initial value type '")
+ << initialValueType << "' is not compatible with field type '"
+ << fieldType << "'";
}
}
}
@@ -1445,20 +1445,18 @@ LogicalResult FieldOp::verify() {
// GetFieldOp
//===----------------------------------------------------------------------===//
LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
- auto classNameAttr = getClassNameAttr();
- auto fieldNameAttr = getFieldNameAttr();
- if (!classNameAttr || !fieldNameAttr) {
- return emitError("class and field name attributes are mandatory");
+ mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
+ if (!fieldNameAttr) {
+ return emitError("field name attribute is mandatory");
}
- StringRef className = classNameAttr.getValue();
+
StringRef fieldName = fieldNameAttr.getValue();
- auto fieldOp =
+ FieldOp fieldOp =
symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, getFieldNameAttr());
if (!fieldOp) {
- return emitOpError("field '")
- << fieldName << "' not found in class '" << className << "'";
+ return emitOpError("field '") << fieldName << "' not found in the class '";
}
Type getFieldResultType = getResult().getType();
diff --git a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
index 4c2525c2a5b00..baf67afc30072 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
@@ -2,7 +2,7 @@ add_mlir_dialect_library(MLIREmitCTransforms
Transforms.cpp
FormExpressions.cpp
TypeConversions.cpp
- ConvertFuncToClass.cpp
+ WrapFuncInClass.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms
diff --git a/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
deleted file mode 100644
index fe8a05d39e1df..0000000000000
--- a/mlir/lib/Dialect/EmitC/Transforms/ConvertFuncToClass.cpp
+++ /dev/null
@@ -1,40 +0,0 @@
-//===- ConvertFuncToClass.cpp - Convert functions to classes -------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/EmitC/IR/EmitC.h"
-#include "mlir/Dialect/EmitC/Transforms/Passes.h"
-#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace emitc {
-
-#define GEN_PASS_DEF_CONVERTFUNCTOCLASSPASS
-#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
-
-namespace {
-
-struct ConvertFuncToClassPass
- : public impl::ConvertFuncToClassPassBase<ConvertFuncToClassPass> {
- void runOnOperation() override {
- emitc::FuncOp funcOp = getOperation();
- MLIRContext *context = funcOp->getContext();
-
- OpBuilder builder(context);
- createClass(funcOp, builder);
- }
-};
-
-} // namespace
-
-} // namespace emitc
-} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index 8471631c6e60b..a252924caeb62 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -10,7 +10,6 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
-#include "llvm/Support/Debug.h"
namespace mlir {
namespace emitc {
@@ -43,80 +42,6 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
return expressionOp;
}
-ClassOp createClass(FuncOp funcOp, OpBuilder &builder) {
- builder.setInsertionPoint(funcOp);
-
- auto classOp = builder.create<emitc::ClassOp>(
- funcOp.getLoc(), builder.getStringAttr("MyModelClass"));
-
- builder.createBlock(&classOp.getBody());
- builder.setInsertionPointToStart(&classOp.getBody().front());
-
- SmallVector<std::pair<StringRef, Type>> fields;
- llvm::SmallDenseMap<Value, Value> argToFieldMap;
-
- auto argAttrs = funcOp.getArgAttrs();
- if (argAttrs) {
- for (const auto [arg, val] : zip(*argAttrs, funcOp.getArguments())) {
- if (auto da = dyn_cast<mlir::DictionaryAttr>(arg)) {
- auto nv = da.getNamed("tf_saved_model.index_path")->getValue();
- auto fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
- auto fieldType = emitc::LValueType::get(emitc::PointerType::get(
- dyn_cast_or_null<emitc::ArrayType>(val.getType())
- .getElementType()));
- fields.push_back({fieldName.str(), fieldType});
-
- auto typeAttr = TypeAttr::get(val.getType());
- mlir::Attribute emptyAttr = builder.getAttr<mlir::UnitAttr>();
- auto dictAttr = DictionaryAttr::get(
- builder.getContext(),
- {builder.getNamedAttr(fieldName.str(), emptyAttr)});
- builder.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
- /* attributes*/ dictAttr);
-
- // TODO: From my current understanding, we need to instantiate a class
- // so we can get the pointers from .field but we can't do that in here
- // so I'm unsure how I can rewrite the following line to ensure
- // GetFieldOp works correctly. auto pointer =
- // emitc::PointerType::get(dyn_cast_or_null<emitc::ArrayType>(val.getType()).getElementType());
- // auto ptr = builder.create<emitc::GetFieldOp>(funcOp.getLoc(),
- // pointer, val, "MyModelClass", fieldName);
- argToFieldMap[val] = nullptr;
- }
- }
- }
-
- auto funcContext = funcOp.getContext();
- auto inputTypes = funcOp.getFunctionType().getInputs();
- auto results = funcOp.getFunctionType().getResults();
- auto funcType = FunctionType::get(funcContext, inputTypes, results);
- auto loc = funcOp.getLoc();
- auto newFuncOp = builder.create<emitc::FuncOp>(
- loc, builder.getStringAttr("execute"), funcType);
-
- builder.createBlock(&newFuncOp.getBody());
- builder.setInsertionPointToStart(&newFuncOp.getBody().front());
-
- IRMapping mapper;
-
- auto body = llvm::make_early_inc_range(funcOp.getBody().front());
- for (Operation &opToClone : body) {
- if (isa<emitc::ConstantOp>(opToClone) ||
- isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) ||
- isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) ||
- isa<emitc::ReturnOp>(opToClone)) {
- builder.clone(opToClone, mapper);
- } else {
- opToClone.emitOpError("Unsupported operation found");
- }
- }
-
- // TODO: Need to erase the funcOp after all this. Using funcOp->erase raises
- // errors:
-
- return classOp;
-}
-
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
new file mode 100644
index 0000000000000..a6f333b074a02
--- /dev/null
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -0,0 +1,150 @@
+//===- ConvertFuncToClass.cpp - Convert functions to classes -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Rewrite.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/GraphWriter.h"
+
+namespace mlir {
+namespace emitc {
+
+#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
+#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
+
+namespace {
+
+struct WrapFuncInClassPass
+ : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
+ void runOnOperation() override {
+ Operation *rootOp = getOperation();
+ MLIRContext *context = rootOp->getContext();
+
+ RewritePatternSet patterns(context);
+ populateFuncPatterns(patterns);
+
+ if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
+ return signalPassFailure();
+ }
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<emitc::EmitCDialect>();
+ }
+};
+
+} // namespace
+
+} // namespace emitc
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::emitc;
+
+static bool validOp(Operation &opToClone) {
+ return isa<emitc::ConstantOp>(opToClone) ||
+ isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) ||
+ isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) ||
+ isa<emitc::ReturnOp>(opToClone);
+}
+
+class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
+public:
+ using OpRewritePattern<emitc::FuncOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
+ PatternRewriter &rewriter) const override {
+ if (funcOp->getParentOfType<emitc::ClassOp>()) {
+ return failure();
+ }
+ auto className = "My" + funcOp.getSymNameAttr().str() + "Class";
+ mlir::emitc::ClassOp newClassOp =
+ rewriter.create<emitc::ClassOp>(funcOp.getLoc(), className);
+
+ SmallVector<std::pair<StringAttr, TypeAttr>> fields;
+ rewriter.createBlock(&newClassOp.getBody());
+ rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
+
+ auto argAttrs = funcOp.getArgAttrs();
+
+ for (const auto &[arg, val] : (zip(*argAttrs, funcOp.getArguments()))) {
+ // FIXME:How can we avoid hardcoding this name?
+ // Should we loop through the dictionary and check for each named
+ // attribute if attr.getName().getValue().contains("tf_saved_model")
+ if (auto namedAttr = dyn_cast<mlir::DictionaryAttr>(arg).getNamed(
+ "tf_saved_model.index_path")) {
+ Attribute nv = namedAttr->getValue();
+ StringAttr fieldName =
+ cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
+ TypeAttr typeAttr = TypeAttr::get(val.getType());
+ fields.push_back({fieldName, typeAttr});
+
+ rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
+ /* attributes*/ arg);
+ } else
+ funcOp->emitOpError("Only Covers TF models");
+ }
+
+ rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
+ MLIRContext *funcContext = funcOp.getContext();
+ ArrayRef<Type> inputTypes = funcOp.getFunctionType().getInputs();
+ ArrayRef<Type> results = funcOp.getFunctionType().getResults();
+ FunctionType funcType = FunctionType::get(funcContext, inputTypes, results);
+ Location loc = funcOp.getLoc();
+ FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
+ loc, rewriter.getStringAttr("execute"), funcType);
+
+ rewriter.setInsertionPointToStart(newFuncOp.addEntryBlock());
+
+ std::vector<Value> newArguments;
+ for (auto [fieldName, attr] : fields) {
+ auto arg =
+ rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
+ newArguments.push_back(arg);
+ }
+
+ IRMapping mapper;
+ for (auto [oldArg, newArg] :
+ llvm::zip(funcOp.getArguments(), newArguments)) {
+ mapper.map(oldArg, newArg);
+ }
+
+ while (!newFuncOp.getArguments().empty()) {
+ if (failed(newFuncOp.eraseArgument(0))) {
+ break;
+ }
+ }
+
+ // TODO: The mapper is easier to use but cloning is more expensive than
+ // moving the body. Working on changing this portion to move the body
+ // instead
+ auto body = llvm::make_early_inc_range(funcOp.getBody().front());
+ for (Operation &opToClone : body) {
+ if (validOp(opToClone)) {
+ rewriter.clone(opToClone, mapper);
+ } else {
+ opToClone.emitOpError("Unsupported operation found");
+ }
+ }
+
+ rewriter.replaceOp(funcOp, newClassOp);
+ return funcOp->use_empty() ? success() : failure();
+ }
+};
+
+void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) {
+ patterns.add<WrapFuncInClass>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/EmitC/convert_func_to_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
similarity index 95%
rename from mlir/test/Dialect/EmitC/convert_func_to_class.mlir
rename to mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
index f9e276f059130..3a775f969ab17 100644
--- a/mlir/test/Dialect/EmitC/convert_func_to_class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --emitc-convert-func-to-class
+// RUN: mlir-opt %s --wrap-emitc-func-in-class
module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
>From f47f2111c669c77c8160abb8243d940b20f12b94 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 13 Jun 2025 23:48:08 +0000
Subject: [PATCH 12/19] Added tests for the wrap-emitc-func-in-class
---
.../mlir/Dialect/EmitC/Transforms/Passes.td | 3 +
.../Dialect/EmitC/Transforms/Transforms.h | 3 +-
.../EmitC/Transforms/WrapFuncInClass.cpp | 81 ++++++++-----------
.../EmitC/wrap_emitc_func_in_class.mlir | 26 +++++-
.../EmitC/wrap_emitc_func_in_class_neg.mlir | 8 ++
5 files changed, 72 insertions(+), 49 deletions(-)
create mode 100644 mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index 1aa95b32217c1..d8ebf4a613bfd 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -28,6 +28,9 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
to a new `execute` method within the class.
}];
let dependentDialects = ["emitc::EmitCDialect"];
+ let options = [Option<
+ "namedAttribute", "named-attribute", "std::string", "\"\"",
+ "Name of the attribute to look for field names on function arguments">];
}
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index bdf6d0985e6db..11a1ad2ad2ff2 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -29,7 +29,8 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
void populateExpressionPatterns(RewritePatternSet &patterns);
/// Populates 'patterns' with func-related patterns.
-void populateFuncPatterns(RewritePatternSet &patterns);
+void populateFuncPatterns(RewritePatternSet &patterns,
+ const std::string &namedAttribute);
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index a6f333b074a02..d87af1379d96a 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -12,14 +12,13 @@
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
-#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/LogicalResult.h"
namespace mlir {
namespace emitc {
@@ -31,12 +30,13 @@ namespace {
struct WrapFuncInClassPass
: public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
+ using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
void runOnOperation() override {
Operation *rootOp = getOperation();
MLIRContext *context = rootOp->getContext();
RewritePatternSet patterns(context);
- populateFuncPatterns(patterns);
+ populateFuncPatterns(patterns, namedAttribute);
if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
return signalPassFailure();
@@ -54,16 +54,13 @@ struct WrapFuncInClassPass
using namespace mlir;
using namespace mlir::emitc;
-static bool validOp(Operation &opToClone) {
- return isa<emitc::ConstantOp>(opToClone) ||
- isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) ||
- isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) ||
- isa<emitc::ReturnOp>(opToClone);
-}
-
class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
+private:
+ std::string attributeName;
+
public:
- using OpRewritePattern<emitc::FuncOp>::OpRewritePattern;
+ WrapFuncInClass(MLIRContext *context, const std::string &attrName)
+ : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
PatternRewriter &rewriter) const override {
@@ -79,23 +76,25 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
auto argAttrs = funcOp.getArgAttrs();
-
- for (const auto &[arg, val] : (zip(*argAttrs, funcOp.getArguments()))) {
- // FIXME:How can we avoid hardcoding this name?
- // Should we loop through the dictionary and check for each named
- // attribute if attr.getName().getValue().contains("tf_saved_model")
- if (auto namedAttr = dyn_cast<mlir::DictionaryAttr>(arg).getNamed(
- "tf_saved_model.index_path")) {
- Attribute nv = namedAttr->getValue();
- StringAttr fieldName =
- cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
- TypeAttr typeAttr = TypeAttr::get(val.getType());
- fields.push_back({fieldName, typeAttr});
-
- rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
- /* attributes*/ arg);
- } else
- funcOp->emitOpError("Only Covers TF models");
+ if (argAttrs) {
+ for (const auto &[arg, val] :
+ llvm::zip(*argAttrs, funcOp.getArguments())) {
+ if (auto namedAttr =
+ dyn_cast<mlir::DictionaryAttr>(arg).getNamed(attributeName)) {
+ Attribute nv = namedAttr->getValue();
+ StringAttr fieldName =
+ cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
+ TypeAttr typeAttr = TypeAttr::get(val.getType());
+ fields.push_back({fieldName, typeAttr});
+
+ rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
+ /* attributes*/ arg);
+ }
+ }
+ } else {
+ funcOp->emitOpError("arguments should have attributes so we can "
+ "initialize class fields.");
+ return failure();
}
rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
@@ -107,8 +106,10 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
loc, rewriter.getStringAttr("execute"), funcType);
- rewriter.setInsertionPointToStart(newFuncOp.addEntryBlock());
+ rewriter.createBlock(&newFuncOp.getBody());
+ newFuncOp.getBody().takeBody(funcOp.getBody());
+ rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
std::vector<Value> newArguments;
for (auto [fieldName, attr] : fields) {
auto arg =
@@ -116,10 +117,9 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
newArguments.push_back(arg);
}
- IRMapping mapper;
for (auto [oldArg, newArg] :
- llvm::zip(funcOp.getArguments(), newArguments)) {
- mapper.map(oldArg, newArg);
+ llvm::zip(newFuncOp.getArguments(), newArguments)) {
+ rewriter.replaceAllUsesWith(oldArg, newArg);
}
while (!newFuncOp.getArguments().empty()) {
@@ -128,23 +128,12 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
}
}
- // TODO: The mapper is easier to use but cloning is more expensive than
- // moving the body. Working on changing this portion to move the body
- // instead
- auto body = llvm::make_early_inc_range(funcOp.getBody().front());
- for (Operation &opToClone : body) {
- if (validOp(opToClone)) {
- rewriter.clone(opToClone, mapper);
- } else {
- opToClone.emitOpError("Unsupported operation found");
- }
- }
-
rewriter.replaceOp(funcOp, newClassOp);
return funcOp->use_empty() ? success() : failure();
}
};
-void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) {
- patterns.add<WrapFuncInClass>(patterns.getContext());
+void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns,
+ const std::string &namedAttribute) {
+ patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute);
}
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
index 3a775f969ab17..e0fa78a3dc459 100644
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s --wrap-emitc-func-in-class
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=tf_saved_model.index_path' %s | FileCheck %s
module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
- emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+ emitc.func @Model(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
%1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
%2 = load %1 : <f32>
@@ -13,3 +13,25 @@ module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted."
return
}
}
+
+// CHECK: module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+// CHECK: emitc.class @MyModelClass {
+// CHECK: emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
+// CHECK: emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
+// CHECK: emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
+// CHECK: emitc.func @execute() {
+// CHECK: %{{[0-9]+}} = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+// CHECK: %{{[0-9]+}} = get_field @another_feature : !emitc.array<1xf32>
+// CHECK: %{{[0-9]+}} = get_field @some_feature : !emitc.array<1xf32>
+// CHECK: %{{[0-9]+}} = get_field @output_0 : !emitc.array<1xf32>
+// CHECK: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
+// CHECK: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
+// CHECK: %{{[0-9]+}} = add %{{[0-9]+}}, %{{[0-9]+}} : (f32, f32) -> f32
+// CHECK: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK: assign %{{[0-9]+}} : f32 to %{{[0-9]+}} : <f32>
+// CHECK: return
+// CHECK: }
+// CHECK: }
+// CHECK: }
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir
new file mode 100644
index 0000000000000..578e8f6c21ebf
--- /dev/null
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=tf_saved_model.index_path' %s 2>&1 | FileCheck %s
+
+emitc.func @foo(%arg0 : i32) {
+ emitc.call_opaque "bar" (%arg0) : (i32) -> ()
+ emitc.return
+}
+
+// CHECK: error: 'emitc.func' op arguments should have attributes so we can initialize class fields.
>From 49f202cd3d91767f8a1d25481f1a48ba8050fcc1 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Sat, 14 Jun 2025 00:20:52 +0000
Subject: [PATCH 13/19] Remove class-name changes
---
mlir/include/mlir/Target/Cpp/CppEmitter.h | 4 +-
mlir/lib/Target/Cpp/TranslateRegistration.cpp | 41 +----
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 153 ++----------------
.../emit-class-neg-external.mlir | 8 -
.../emit-class-neg-noArgAttrs.mlir | 15 --
mlir/test/mlir-translate/emit-class.mlir | 39 -----
6 files changed, 18 insertions(+), 242 deletions(-)
delete mode 100644 mlir/test/mlir-translate/emit-class-neg-external.mlir
delete mode 100644 mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
delete mode 100644 mlir/test/mlir-translate/emit-class.mlir
diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h
index d1a6c1dc12d4c..7c5747a888261 100644
--- a/mlir/include/mlir/Target/Cpp/CppEmitter.h
+++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h
@@ -28,9 +28,7 @@ namespace emitc {
/// with matching id are emitted.
LogicalResult translateToCpp(Operation *op, raw_ostream &os,
bool declareVariablesAtTop = false,
- StringRef fileId = {}, bool emitClass = false,
- StringRef className = {},
- StringRef fieldNameAttribute = {});
+ StringRef fileId = {});
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
index 9e1533d34f6ea..f1869ac9e8eda 100644
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -33,50 +33,13 @@ void registerToCppTranslation() {
"file-id", llvm::cl::desc("Emit emitc.file ops with matching id"),
llvm::cl::init(""));
- static llvm::cl::opt<bool> emitClass(
- "emit-class",
- llvm::cl::desc("If specified, the output will be a class where "
- "the function(s) in the module are methods "
- "Enables class-related options"),
- llvm::cl::init(false));
-
- static llvm::cl::opt<std::string> className(
- "class-name",
- llvm::cl::desc("Mandatory class name if --emit-class is set"),
- llvm::cl::init(""));
-
- static llvm::cl::opt<std::string> fieldNameAttribute(
- "field-name-attribute",
- llvm::cl::desc("Mandatory name of the attribute to use as field name if "
- "--emit-class is set(default=tf_saved_model.index_path)"),
- llvm::cl::init("tf_saved_model.index_path"));
-
TranslateFromMLIRRegistration reg(
"mlir-to-cpp", "translate from mlir to cpp",
[](Operation *op, raw_ostream &output) {
- if (emitClass) {
- if (className.empty()) {
- llvm::errs() << "Error: --class-name is mandatory when "
- "--emit-class is set.\n";
- return mlir::failure();
- }
- if (fieldNameAttribute.empty()) {
- llvm::errs() << "Error: --field-name-attribute is mandatory when "
- "--emit-class is set.\n";
- return mlir::failure();
- }
- return emitc::translateToCpp(
- op, output,
- /*declareVariablesAtTop=*/declareVariablesAtTop,
- /*fileId=*/fileId, /*emitClass=*/emitClass,
- /*className=*/className,
- /*fieldNameAttribute=*/fieldNameAttribute);
- }
return emitc::translateToCpp(
op, output,
/*declareVariablesAtTop=*/declareVariablesAtTop,
- /*fileId=*/fileId, /*emitClass=*/emitClass, /*className=*/className,
- /*fieldNameAttribute=*/fieldNameAttribute);
+ /*fileId=*/fileId);
},
[](DialectRegistry &registry) {
// clang-format off
@@ -87,4 +50,4 @@ void registerToCppTranslation() {
});
}
-} // namespace mlir
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index fc550d24113e3..52bfd4d774688 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -68,14 +68,6 @@ inline LogicalResult interleaveCommaWithError(const Container &c,
return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
}
-template <typename Container, typename UnaryFunctor>
-inline LogicalResult interleaveWithNewLineWithError(const Container &c,
- raw_ostream &os,
- UnaryFunctor eachFn) {
- return interleaveWithError(c.begin(), c.end(), eachFn,
- [&]() { os << ";\n"; });
-}
-
/// Return the precedence of a operator as an integer, higher values
/// imply higher precedence.
static FailureOr<int> getOperatorPrecedence(Operation *operation) {
@@ -124,8 +116,7 @@ namespace {
/// Emitter that uses dialect specific emitters to emit C++ code.
struct CppEmitter {
explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
- StringRef fileId, bool emitClass, StringRef className,
- StringRef fieldNameAttribute);
+ StringRef fileId);
/// Emits attribute or returns failure.
LogicalResult emitAttribute(Location loc, Attribute attr);
@@ -242,15 +233,6 @@ struct CppEmitter {
/// be declared at the beginning of a function.
bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
- // Returns whether we should emit a C++ class
- bool shouldPrintClass() { return emitClass; };
-
- // Returns the class name to emit
- std::string getClassName() { return className; };
-
- // Returns the field name to use in the map
- std::string getfieldNameAttribute() { return fieldNameAttribute; };
-
/// Returns whether this file op should be emitted
bool shouldEmitFile(FileOp file) {
return !fileId.empty() && file.getId() == fileId;
@@ -286,18 +268,6 @@ struct CppEmitter {
/// Only emit file ops whos id matches this value.
std::string fileId;
- /// Controls whether the output should be a C++ class.
- /// If true, the generated C++ code will be encapsulated within a class,
- /// and functions from the input module will become its member functions.
- const bool emitClass;
-
- /// The specified name for the generated C++ class
- const std::string className;
-
- /// Name of the MLIR attribute to use as a field name within the generated
- /// class
- const std::string fieldNameAttribute;
-
/// Map from value to name of C++ variable that contain the name.
ValueMapper valueMapper;
@@ -1063,17 +1033,6 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
}));
}
-static LogicalResult printFields(CppEmitter &emitter, Operation *functionOp,
- Region::BlockArgListType arguments) {
- raw_indented_ostream &os = emitter.ostream();
-
- return (interleaveWithNewLineWithError(
- arguments, os, [&](BlockArgument arg) -> LogicalResult {
- return emitter.emitVariableDeclaration(
- functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
- }));
-}
-
static LogicalResult printFunctionBody(CppEmitter &emitter,
Operation *functionOp,
Region::BlockListType &blocks) {
@@ -1178,45 +1137,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}
-static LogicalResult emitClassFields(CppEmitter &emitter,
- emitc::FuncOp functionOp) {
- raw_indented_ostream &os = emitter.ostream();
- auto argAttrs = functionOp.getArgAttrs();
- Operation *operation = functionOp.getOperation();
- if (failed(printFields(emitter, operation, functionOp.getArguments())))
- return failure();
- os << ";\n";
-
- std::map<std::string, Value> fields;
- os << "\nstd::map<std::string, char*> _buffer_map {";
- if (argAttrs) {
- for (const auto [a, v] : zip(*argAttrs, functionOp.getArguments())) {
- if (auto da = dyn_cast<mlir::DictionaryAttr>(a)) {
- auto nv = da.getNamed(emitter.getfieldNameAttribute())->getValue();
- auto name = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]).str();
- auto Ins = fields.insert({name, v});
- if (!Ins.second)
- return failure();
- os << " { \"" << name << "\"" << ", reinterpret_cast<char*>("
- << emitter.getOrCreateName(v) << ") }, ";
- }
- }
- } else
- return failure();
-
- os << "};\n";
- os << "char* getBufferForName(const std::string& name) const {\n";
- os.indent();
- os.indent();
- os << "auto it = _buffer_map.find(name);\n";
- os << "return (it == _buffer_map.end()) ? nullptr : it->second;\n";
- os.unindent();
- os.unindent();
- os << "}\n\n";
-
- return success();
-}
-
static LogicalResult printOperation(CppEmitter &emitter,
emitc::FuncOp functionOp) {
// We need to declare variables at top if the function has multiple blocks.
@@ -1228,30 +1148,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
CppEmitter::Scope scope(emitter);
raw_indented_ostream &os = emitter.ostream();
- Operation *operation = functionOp.getOperation();
- if (emitter.shouldPrintClass()) {
- if (functionOp.isExternal()) {
- // TODO: Determine the best long-term strategy for external functions.
- // Currently, we're skipping over this functionOp.
- // We have considered using emitWarning() which would return
- // InFlightDiagnostic which seems can be automatically converted to
- // LogicalResult since this is done in emitAttributes where emitError is
- // converted to LogicalResult. However, it requires that we pass in a
- // location which at first glance we don't have in this scope. Open to
- // further discussion on this.
- os << "Warning: Cannot process external function '"
- << functionOp.getName() << "'. "
- << "This functionOp lacks a body so we will skip over it.";
- return success();
- }
- os << "class " << emitter.getClassName() << " final {\n";
- os << "public: \n";
- os.indent();
-
- if (failed(emitClassFields(emitter, functionOp)))
- return failure();
- }
-
if (functionOp.getSpecifiers()) {
for (Attribute specifier : functionOp.getSpecifiersAttr()) {
os << cast<StringAttr>(specifier).str() << " ";
@@ -1261,37 +1157,23 @@ static LogicalResult printOperation(CppEmitter &emitter,
if (failed(emitter.emitTypes(functionOp.getLoc(),
functionOp.getFunctionType().getResults())))
return failure();
- // TODO: We may wanna consider having the name of the function be execute in
- // the case that we want to emit a class instead of main. Leaving as is for
- // now to make the change smaller.
os << " " << functionOp.getName();
os << "(";
-
- if (!emitter.shouldPrintClass()) {
- if (functionOp.isExternal()) {
- if (failed(printFunctionArgs(emitter, operation,
- functionOp.getArgumentTypes())))
- return failure();
- os << ");";
- return success();
- }
- if (failed(
- printFunctionArgs(emitter, operation, functionOp.getArguments())))
+ Operation *operation = functionOp.getOperation();
+ if (functionOp.isExternal()) {
+ if (failed(printFunctionArgs(emitter, operation,
+ functionOp.getArgumentTypes())))
return failure();
+ os << ");";
+ return success();
}
+ if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
+ return failure();
os << ") {\n";
-
if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
return failure();
-
- if (emitter.shouldPrintClass()) {
- os << "}\n";
- os.unindent();
- os << "};\n";
- } else {
- os << "}\n";
- }
+ os << "}\n";
return success();
}
@@ -1328,11 +1210,9 @@ static LogicalResult printOperation(CppEmitter &emitter,
}
CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
- StringRef fileId, bool emitClass, StringRef className,
- StringRef fieldNameAttribute)
+ StringRef fileId)
: os(os), declareVariablesAtTop(declareVariablesAtTop),
- fileId(fileId.str()), emitClass(emitClass), className(className.str()),
- fieldNameAttribute(fieldNameAttribute.str()) {
+ fileId(fileId.str()) {
valueInScopeCount.push(0);
labelInScopeCount.push(0);
}
@@ -1915,10 +1795,7 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
bool declareVariablesAtTop,
- StringRef fileId, bool emitClass,
- StringRef className,
- StringRef fieldNameAttribute) {
- CppEmitter emitter(os, declareVariablesAtTop, fileId, emitClass, className,
- fieldNameAttribute);
+ StringRef fileId) {
+ CppEmitter emitter(os, declareVariablesAtTop, fileId);
return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
-}
+}
\ No newline at end of file
diff --git a/mlir/test/mlir-translate/emit-class-neg-external.mlir b/mlir/test/mlir-translate/emit-class-neg-external.mlir
deleted file mode 100644
index c34a1652abd3f..0000000000000
--- a/mlir/test/mlir-translate/emit-class-neg-external.mlir
+++ /dev/null
@@ -1,8 +0,0 @@
-/// An external function - has no body
-// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
-
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
- emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]}
-}
-
-// CHECK: Warning: Cannot process external function 'extern_func'. This functionOp lacks a body so we will skip over it.
diff --git a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir b/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
deleted file mode 100644
index 6d43fa953a946..0000000000000
--- a/mlir/test/mlir-translate/emit-class-neg-noArgAttrs.mlir
+++ /dev/null
@@ -1,15 +0,0 @@
-/// The function has no argument attributes
-// RUN: not mlir-translate --mlir-to-cpp --emit-class=true --class-name=ArgAttrs --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
-
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
- emitc.func @foo(%arg0 : i32) {
- emitc.call_opaque "bar" (%arg0) : (i32) -> ()
- emitc.return
- }
-}
-
-// CHECK: class ArgAttrs final {
-// CHECK-NEXT: public:
-// CHECK-NEXT: int32_t v1;
-// CHECK-EMPTY:
-// CHECK-NEXT: std::map<std::string, char*> _buffer_map {
diff --git a/mlir/test/mlir-translate/emit-class.mlir b/mlir/test/mlir-translate/emit-class.mlir
deleted file mode 100644
index 2779cb315ed41..0000000000000
--- a/mlir/test/mlir-translate/emit-class.mlir
+++ /dev/null
@@ -1,39 +0,0 @@
-// RUN: mlir-translate --mlir-to-cpp --emit-class=true --class-name=MyAdder --field-name-attribute=tf_saved_model.index_path %s | FileCheck %s
-
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
- emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
- %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
- %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
- %2 = load %1 : <f32>
- %3 = subscript %arg0[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
- %4 = load %3 : <f32>
- %5 = add %2, %4 : (f32, f32) -> f32
- %6 = subscript %arg2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
- assign %5 : f32 to %6 : <f32>
- return
- }
-}
-
-// CHECK: class MyAdder final {
-// CHECK-NEXT: public:
-// CHECK-NEXT: float v1[1];
-// CHECK-NEXT: float v2[1];
-// CHECK-NEXT: float v3[1];
-// CHECK-EMPTY:
-// CHECK-NEXT: std::map<std::string, char*> _buffer_map { { "another_feature", reinterpret_cast<char*>(v1) },
-// CHECK-SAME: { "some_feature", reinterpret_cast<char*>(v2) }, { "output_0", reinterpret_cast<char*>(v3) }, };
-// CHECK-NEXT: char* getBufferForName(const std::string& name) const {
-// CHECK-NEXT: auto it = _buffer_map.find(name);
-// CHECK-NEXT: return (it == _buffer_map.end()) ? nullptr : it->second;
-// CHECK-NEXT: }
-// CHECK-EMPTY:
-// CHECK-NEXT: void main() {
-// CHECK-NEXT: size_t v4 = 0;
-// CHECK-NEXT: float v5 = v2[v4];
-// CHECK-NEXT: float v6 = v1[v4];
-// CHECK-NEXT: float v7 = v5 + v6;
-// CHECK-NEXT: v3[v4] = v7;
-// CHECK-NEXT: return;
-// CHECK-NEXT: }
-// CHECK-NEXT: };
-
>From d108bf22f602fcdb27425a36626670f4339e9652 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Sat, 14 Jun 2025 00:30:23 +0000
Subject: [PATCH 14/19] cleaning up comments
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 3 +--
mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp | 1 +
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 08235ca701e2a..466c628e6e837 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1579,10 +1579,9 @@ def EmitC_ClassOp
let summary =
"Represents a C++ class definition, encapsulating fields and methods.";
- // FIX WORDING
let description = [{
The `emitc.class` operation defines a C++ class, acting as a container
- for its data fields (`emitc.variable`) and methods (`emitc.func`).
+ for its data fields (`emitc.field`) and methods (`emitc.func`).
It creates a distinct scope, isolating its contents from the surrounding
MLIR region, similar to how C++ classes encapsulate their internals.
All the class memebrs need to be default initalizable.
diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
index a252924caeb62..87350ecdceaaa 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
namespace mlir {
namespace emitc {
>From 86a057e12369fb7563930f83e1d398e7fa5f0d7f Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 16 Jun 2025 23:49:42 +0000
Subject: [PATCH 15/19] Removed tf attributes and allowed for no attributes
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 26 ++++----
.../EmitC/Transforms/WrapFuncInClass.cpp | 59 +++++++++++--------
.../EmitC/wrap_emitc_func_in_class.mlir | 50 ++++++++--------
.../EmitC/wrap_emitc_func_in_class_neg.mlir | 8 ---
.../wrap_emitc_func_in_class_noAttr.mlir | 17 ++++++
5 files changed, 94 insertions(+), 66 deletions(-)
delete mode 100644 mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir
create mode 100644 mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 466c628e6e837..fc609fac72f3a 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1584,14 +1584,13 @@ def EmitC_ClassOp
for its data fields (`emitc.field`) and methods (`emitc.func`).
It creates a distinct scope, isolating its contents from the surrounding
MLIR region, similar to how C++ classes encapsulate their internals.
- All the class memebrs need to be default initalizable.
Example:
```mlir
- emitc.class @MymainClass {
- emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
- emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
- emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
+ emitc.class @mainClass {
+ emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
+ emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
+ emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
emitc.func @execute() {
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
@@ -1634,15 +1633,22 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
let summary = "A field within a class";
let description = [{
The `emitc.field` operation declares a named field within an `emitc.class`
- operation. The field's type must be an EmitC type. An optional initial value can be provided.
+ operation. The field's type must be an EmitC type. The initial value is optional.
+ If the argument has attributes, these become the initial value, else we end up with no initial value.
Example with initial values:
```mlir
- emitc.class @MyModelClass {
- emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
- emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
- emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
+ emitc.class @modelClass {
+ emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
+ emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
+ emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
+ }
+ ```
+ Example with no initial value:
+ ```mlir
+ emitc.class @modelClass {
+ emitc.field @another_feature : !emitc.array<1xf32>
}
```
}];
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index d87af1379d96a..2a2c65214b6fd 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -12,13 +12,17 @@
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/StringRef.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/LogicalResult.h"
+#include <string>
namespace mlir {
namespace emitc {
@@ -67,7 +71,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
if (funcOp->getParentOfType<emitc::ClassOp>()) {
return failure();
}
- auto className = "My" + funcOp.getSymNameAttr().str() + "Class";
+ auto className = funcOp.getSymNameAttr().str() + "Class";
mlir::emitc::ClassOp newClassOp =
rewriter.create<emitc::ClassOp>(funcOp.getLoc(), className);
@@ -76,25 +80,33 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
auto argAttrs = funcOp.getArgAttrs();
- if (argAttrs) {
- for (const auto &[arg, val] :
- llvm::zip(*argAttrs, funcOp.getArguments())) {
- if (auto namedAttr =
- dyn_cast<mlir::DictionaryAttr>(arg).getNamed(attributeName)) {
- Attribute nv = namedAttr->getValue();
- StringAttr fieldName =
- cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
- TypeAttr typeAttr = TypeAttr::get(val.getType());
- fields.push_back({fieldName, typeAttr});
-
- rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
- /* attributes*/ arg);
+ size_t idx = 0;
+
+ for (const BlockArgument &val : funcOp.getArguments()) {
+ StringAttr fieldName;
+ Attribute argAttr = nullptr;
+
+ if (argAttrs && idx < argAttrs->size()) {
+ if (DictionaryAttr dictAttr =
+ dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx])) {
+ if (auto namedAttr = dictAttr.getNamed(attributeName)) {
+ Attribute nv = namedAttr->getValue();
+ fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
+ argAttr = (*argAttrs)[idx];
+ }
}
}
- } else {
- funcOp->emitOpError("arguments should have attributes so we can "
- "initialize class fields.");
- return failure();
+
+ if (!fieldName) {
+ fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
+ }
+
+ TypeAttr typeAttr = TypeAttr::get(val.getType());
+ fields.push_back({fieldName, typeAttr});
+ rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
+ argAttr);
+
+ ++idx;
}
rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
@@ -112,7 +124,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
std::vector<Value> newArguments;
for (auto [fieldName, attr] : fields) {
- auto arg =
+ GetFieldOp arg =
rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
newArguments.push_back(arg);
}
@@ -122,14 +134,13 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
rewriter.replaceAllUsesWith(oldArg, newArg);
}
- while (!newFuncOp.getArguments().empty()) {
- if (failed(newFuncOp.eraseArgument(0))) {
- break;
- }
+ llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
+ if (failed(newFuncOp.eraseArguments(argsToErase))) {
+ newFuncOp->emitOpError("Failed to erase all arguments using BitVector.");
}
rewriter.replaceOp(funcOp, newClassOp);
- return funcOp->use_empty() ? success() : failure();
+ return success();
}
};
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
index e0fa78a3dc459..28636811b1f17 100644
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=tf_saved_model.index_path' %s | FileCheck %s
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.opaque' %s | FileCheck %s
-module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
- emitc.func @Model(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+module attributes { } {
+ emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.opaque = ["another_feature"]}, %arg1: !emitc.array<1xf32> {emitc.opaque = ["some_feature"]}, %arg2: !emitc.array<1xf32> {emitc.opaque = ["output_0"]}) attributes { } {
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
%1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
%2 = load %1 : <f32>
@@ -14,24 +14,26 @@ module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted."
}
}
-// CHECK: module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
-// CHECK: emitc.class @MyModelClass {
-// CHECK: emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
-// CHECK: emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
-// CHECK: emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
-// CHECK: emitc.func @execute() {
-// CHECK: %{{[0-9]+}} = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-// CHECK: %{{[0-9]+}} = get_field @another_feature : !emitc.array<1xf32>
-// CHECK: %{{[0-9]+}} = get_field @some_feature : !emitc.array<1xf32>
-// CHECK: %{{[0-9]+}} = get_field @output_0 : !emitc.array<1xf32>
-// CHECK: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
-// CHECK: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
-// CHECK: %{{[0-9]+}} = add %{{[0-9]+}}, %{{[0-9]+}} : (f32, f32) -> f32
-// CHECK: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK: assign %{{[0-9]+}} : f32 to %{{[0-9]+}} : <f32>
-// CHECK: return
-// CHECK: }
-// CHECK: }
-// CHECK: }
+
+// CHECK: module {
+// CHECK-NEXT: emitc.class @modelClass {
+// CHECK-NEXT: emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
+// CHECK-NEXT: emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
+// CHECK-NEXT: emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
+// CHECK-NEXT: emitc.func @execute() {
+// CHECK-NEXT: %{{[0-9]+}} = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+// CHECK-NEXT: %{{[0-9]+}} = get_field @another_feature : !emitc.array<1xf32>
+// CHECK-NEXT: %{{[0-9]+}} = get_field @some_feature : !emitc.array<1xf32>
+// CHECK-NEXT: %{{[0-9]+}} = get_field @output_0 : !emitc.array<1xf32>
+// CHECK-NEXT: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
+// CHECK-NEXT: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
+// CHECK-NEXT: %{{[0-9]+}} = add %{{[0-9]+}}, %{{[0-9]+}} : (f32, f32) -> f32
+// CHECK-NEXT: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT: assign %{{[0-9]+}} : f32 to %{{[0-9]+}} : <f32>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir
deleted file mode 100644
index 578e8f6c21ebf..0000000000000
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_neg.mlir
+++ /dev/null
@@ -1,8 +0,0 @@
-// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=tf_saved_model.index_path' %s 2>&1 | FileCheck %s
-
-emitc.func @foo(%arg0 : i32) {
- emitc.call_opaque "bar" (%arg0) : (i32) -> ()
- emitc.return
-}
-
-// CHECK: error: 'emitc.func' op arguments should have attributes so we can initialize class fields.
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
new file mode 100644
index 0000000000000..57155cac693bb
--- /dev/null
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.opaque' %s | FileCheck %s
+
+emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
+ emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
+ emitc.return
+}
+
+// CHECK: module {
+// CHECK-NEXT: emitc.class @fooClass {
+// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32>
+// CHECK-NEXT: emitc.func @execute() {
+// CHECK-NEXT: %0 = get_field @fieldName0 : !emitc.array<1xf32>
+// CHECK-NEXT: call_opaque "bar"(%0) : (!emitc.array<1xf32>) -> ()
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
>From b064a9cdcf6cff9941b04587be6034a7f4ecca1c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Valentin=20Clement=20=28=E3=83=90=E3=83=AC=E3=83=B3?=
=?UTF-8?q?=E3=82=BF=E3=82=A4=E3=83=B3=20=E3=82=AF=E3=83=AC=E3=83=A1?=
=?UTF-8?q?=E3=83=B3=29?= <clementval at gmail.com>
Date: Thu, 22 May 2025 08:24:18 -0700
Subject: [PATCH 16/19] [flang][rt] Fix the use of kNoAsyncId -> kNoAsyncObject
(#141079)
>From 4780ab3650038669cf9be2facf5527fa5770ee38 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Tue, 17 Jun 2025 15:04:18 +0000
Subject: [PATCH 17/19] Cleaning up descriptions
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 59 ++++++++-----------
.../mlir/Dialect/EmitC/Transforms/Passes.td | 6 +-
.../EmitC/Transforms/WrapFuncInClass.cpp | 8 +--
.../EmitC/wrap_emitc_func_in_class.mlir | 23 ++++----
.../wrap_emitc_func_in_class_noAttr.mlir | 2 +-
5 files changed, 42 insertions(+), 56 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index fc609fac72f3a..926ff301f1807 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1587,29 +1587,21 @@ def EmitC_ClassOp
Example:
```mlir
- emitc.class @mainClass {
- emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
- emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
- emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
-
+ emitc.func @model(%input_data : !emitc.array<1xf32> {emitc.opaque = ["input_tensor"]}) attributes { } {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+ %1 = subscript %input_data[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ return
+ }
+ // becomes
+ emitc.class @modelClass {
+ emitc.field @input_tensor : !emitc.array<1xf32> = {emitc.opaque = ["input_tensor"]}
emitc.func @execute() {
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-
- %1 = get_field @another_feature : !emitc.array<1xf32>
- %2 = get_field @some_feature : !emitc.array<1xf32>
- %3 = get_field @output_0 : !emitc.array<1xf32>
-
- %4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
- %5 = load %4 : <f32>
- %6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
- %7 = load %6 : <f32>
- %8 = add %5, %7 : (f32, f32) -> f32
- %9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
- assign %8 : f32 to %9 : <f32>
+ %1 = get_field @input_tensor : !emitc.array<1xf32>
+ %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
return
}
- }
-
+ }
```
}];
@@ -1633,23 +1625,22 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
let summary = "A field within a class";
let description = [{
The `emitc.field` operation declares a named field within an `emitc.class`
- operation. The field's type must be an EmitC type. The initial value is optional.
- If the argument has attributes, these become the initial value, else we end up with no initial value.
-
- Example with initial values:
+ operation. The field's type must be an EmitC type.
+ If the corresponding function argument has attributes (accessed via `argAttrs`),
+ these attributes are attached to the field operation.
+ Otherwise, the field is created without additional attributes.
+ Example of func argument with attributes:
```mlir
- emitc.class @modelClass {
- emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
- emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
- emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
- }
+ %arg0: !emitc.array<1xf32> {emitc.opaque = ["another_feature"]}
+ // becomes
+ emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
```
- Example with no initial value:
+ Example of func argument without attributes:
```mlir
- emitc.class @modelClass {
- emitc.field @another_feature : !emitc.array<1xf32>
- }
+ %arg0 : !emitc.array<1xf32>
+ // becomes
+ emitc.field @fieldName0 : !emitc.array<1xf32>
```
}];
@@ -1673,9 +1664,7 @@ def EmitC_GetFieldOp
Example:
```mlir
- %some_ptr = emitc.get_field @some_feature : !emitc.array<1xf32>
- %another_ptr = emitc.get_field @another_feature : !emitc.array<1xf32>
- %output_ptr = emitc.get_field @output_0 : !emitc.array<1xf32>
+ %0 = get_field @fieldName0 : !emitc.array<1xf32>
```
}];
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index d8ebf4a613bfd..09ad644c2a439 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -29,8 +29,10 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
}];
let dependentDialects = ["emitc::EmitCDialect"];
let options = [Option<
- "namedAttribute", "named-attribute", "std::string", "\"\"",
- "Name of the attribute to look for field names on function arguments">];
+ "namedAttribute", "named-attribute", "std::string",
+ /*default=*/"",
+ "Attribute key used to extract field names from function argument's "
+ "dictionary attributes">];
}
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index 2a2c65214b6fd..31c013dda6f50 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -80,9 +80,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
auto argAttrs = funcOp.getArgAttrs();
- size_t idx = 0;
-
- for (const BlockArgument &val : funcOp.getArguments()) {
+ for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
StringAttr fieldName;
Attribute argAttr = nullptr;
@@ -105,8 +103,6 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
fields.push_back({fieldName, typeAttr});
rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
argAttr);
-
- ++idx;
}
rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
@@ -123,7 +119,7 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
std::vector<Value> newArguments;
- for (auto [fieldName, attr] : fields) {
+ for (auto &[fieldName, attr] : fields) {
GetFieldOp arg =
rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
newArguments.push_back(arg);
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
index 28636811b1f17..480943116dbf1 100644
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -21,19 +21,18 @@ module attributes { } {
// CHECK-NEXT: emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
// CHECK-NEXT: emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
// CHECK-NEXT: emitc.func @execute() {
-// CHECK-NEXT: %{{[0-9]+}} = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-// CHECK-NEXT: %{{[0-9]+}} = get_field @another_feature : !emitc.array<1xf32>
-// CHECK-NEXT: %{{[0-9]+}} = get_field @some_feature : !emitc.array<1xf32>
-// CHECK-NEXT: %{{[0-9]+}} = get_field @output_0 : !emitc.array<1xf32>
-// CHECK-NEXT: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
-// CHECK-NEXT: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK-NEXT: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
-// CHECK-NEXT: %{{[0-9]+}} = add %{{[0-9]+}}, %{{[0-9]+}} : (f32, f32) -> f32
-// CHECK-NEXT: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
-// CHECK-NEXT: assign %{{[0-9]+}} : f32 to %{{[0-9]+}} : <f32>
+// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+// CHECK-NEXT: get_field @another_feature : !emitc.array<1xf32>
+// CHECK-NEXT: get_field @some_feature : !emitc.array<1xf32>
+// CHECK-NEXT: get_field @output_0 : !emitc.array<1xf32>
+// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT: load {{.*}} : <f32>
+// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT: load {{.*}} : <f32>
+// CHECK-NEXT: add {{.*}}, {{.*}} : (f32, f32) -> f32
+// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT: assign {{.*}} : f32 to {{.*}} : <f32>
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
-
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
index 57155cac693bb..92ed20c4b14e3 100644
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class_noAttr.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.opaque' %s | FileCheck %s
+// RUN: mlir-opt --wrap-emitc-func-in-class %s | FileCheck %s
emitc.func @foo(%arg0 : !emitc.array<1xf32>) {
emitc.call_opaque "bar" (%arg0) : (!emitc.array<1xf32>) -> ()
>From 009f137492c47d4114725d12c1fd4a9bd4bcf419 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Wed, 18 Jun 2025 21:42:04 +0000
Subject: [PATCH 18/19] Avoid unnecessary checks
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 35 +++------
.../mlir/Dialect/EmitC/Transforms/Passes.td | 23 ++++++
.../Dialect/EmitC/Transforms/Transforms.h | 2 +-
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 60 ++++-----------
.../EmitC/Transforms/WrapFuncInClass.cpp | 74 ++++++-------------
.../EmitC/wrap_emitc_func_in_class.mlir | 18 +++--
6 files changed, 82 insertions(+), 130 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 926ff301f1807..3b7b1a44783da 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1586,18 +1586,13 @@ def EmitC_ClassOp
MLIR region, similar to how C++ classes encapsulate their internals.
Example:
+
```mlir
- emitc.func @model(%input_data : !emitc.array<1xf32> {emitc.opaque = ["input_tensor"]}) attributes { } {
- %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
- %1 = subscript %input_data[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
- return
- }
- // becomes
emitc.class @modelClass {
- emitc.field @input_tensor : !emitc.array<1xf32> = {emitc.opaque = ["input_tensor"]}
+ emitc.field @fieldName0 : !emitc.array<1xf32> = {emitc.opaque = "input_tensor"}
emitc.func @execute() {
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
- %1 = get_field @input_tensor : !emitc.array<1xf32>
+ %1 = get_field @fieldName0 : !emitc.array<1xf32>
%2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
return
}
@@ -1609,8 +1604,6 @@ def EmitC_ClassOp
let regions = (region AnyRegion:$body);
- let builders = [];
-
let extraClassDeclaration = [{
// Returns the body block containing class members and methods.
Block &getBlock();
@@ -1626,29 +1619,21 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
let description = [{
The `emitc.field` operation declares a named field within an `emitc.class`
operation. The field's type must be an EmitC type.
- If the corresponding function argument has attributes (accessed via `argAttrs`),
- these attributes are attached to the field operation.
- Otherwise, the field is created without additional attributes.
- Example of func argument with attributes:
- ```mlir
- %arg0: !emitc.array<1xf32> {emitc.opaque = ["another_feature"]}
- // becomes
- emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
- ```
- Example of func argument without attributes:
+ Example:
+
```mlir
- %arg0 : !emitc.array<1xf32>
- // becomes
+ // Example with an attribute:
+ emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.opaque = "another_feature"}
+ // Example with no attribute:
emitc.field @fieldName0 : !emitc.array<1xf32>
```
}];
let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
- OptionalAttr<AnyAttr>:$initial_value);
+ OptionalAttr<AnyAttr>:$attrs);
- let assemblyFormat =
- [{ $sym_name `:` $type (`=` $initial_value^)? attr-dict}];
+ let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}];
let hasVerifier = 1;
}
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index 09ad644c2a439..74c49132b61f6 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -26,6 +26,29 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
This pass transforms `emitc.func` operations into `emitc.class` operations.
Function arguments become fields of the class, and the function body is moved
to a new `execute` method within the class.
+ If the corresponding function argument has attributes (accessed via `argAttrs`),
+ these attributes are attached to the field operation.
+ Otherwise, the field is created without additional attributes.
+
+ Example:
+
+ ```mlir
+ emitc.func @model(%input_data : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}) attributes { } {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+ %1 = subscript %input_data[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ return
+ }
+ // becomes
+ emitc.class @modelClass {
+ emitc.field @input_tensor : !emitc.array<1xf32> {emitc.opaque = "input_tensor"}
+ emitc.func @execute() {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+ %1 = get_field @input_tensor : !emitc.array<1xf32>
+ %2 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ return
+ }
+ }
+ ```
}];
let dependentDialects = ["emitc::EmitCDialect"];
let options = [Option<
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index 11a1ad2ad2ff2..a4e8fe10ff853 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -30,7 +30,7 @@ void populateExpressionPatterns(RewritePatternSet &patterns);
/// Populates 'patterns' with func-related patterns.
void populateFuncPatterns(RewritePatternSet &patterns,
- const std::string &namedAttribute);
+ StringRef namedAttribute);
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 695ad3ee3bb0a..e99a524fdf8a3 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1404,39 +1404,19 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
// FieldOp
//===----------------------------------------------------------------------===//
LogicalResult FieldOp::verify() {
- if (!isSupportedEmitCType(getType())) {
+ if (!isSupportedEmitCType(getType()))
return emitOpError("expected valid emitc type");
- }
- if (!getInitialValue()) {
- return success();
- }
+ Operation *parentOp = getOperation()->getParentOp();
+ if (!parentOp || !isa<emitc::ClassOp>(parentOp))
+ return emitOpError("field must be nested within an emitc.class operation");
- Attribute initValue = *getInitialValue();
- // Check that the type of the initial value is compatible with the type of
- // the global variable.
- if (ElementsAttr elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
- Type initialValueType = elementsAttr.getType();
- if (!initialValueType) {
- return emitOpError("initial value attribute must have a type");
- }
- Type fieldType = getType();
- if (initialValueType != fieldType) {
- if (LValueType lvalueType = dyn_cast<LValueType>(fieldType)) {
- Type innerFieldType = lvalueType.getValueType();
- if (innerFieldType != initialValueType) {
- return emitOpError("initial value type ")
- << initialValueType << " is not compatible with field type '"
- << fieldType << "' its inner type '" << innerFieldType << "'";
- }
-
- } else {
- return emitOpError("initial value type '")
- << initialValueType << "' is not compatible with field type '"
- << fieldType << "'";
- }
- }
- }
+ StringAttr symName = getSymNameAttr();
+ if (!symName || symName.getValue().empty())
+ return emitOpError("field must have a non-empty symbol name");
+
+ if (!getAttrs())
+ return success();
return success();
}
@@ -1446,27 +1426,19 @@ LogicalResult FieldOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
- if (!fieldNameAttr) {
- return emitError("field name attribute is mandatory");
- }
-
- StringRef fieldName = fieldNameAttr.getValue();
-
FieldOp fieldOp =
- symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, getFieldNameAttr());
-
- if (!fieldOp) {
- return emitOpError("field '") << fieldName << "' not found in the class '";
- }
+ symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, fieldNameAttr);
+ if (!fieldOp)
+ return emitOpError("field '")
+ << fieldNameAttr << "' not found in the class";
Type getFieldResultType = getResult().getType();
Type fieldType = fieldOp.getType();
- if (fieldType != getFieldResultType) {
+ if (fieldType != getFieldResultType)
return emitOpError("result type ")
- << getFieldResultType << " does not match field '" << fieldName
+ << getFieldResultType << " does not match field '" << fieldNameAttr
<< "' type " << fieldType;
- }
return success();
}
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index 31c013dda6f50..4a02eca594c80 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -1,4 +1,4 @@
-//===- ConvertFuncToClass.cpp - Convert functions to classes -------------===//
+//===- WrapFuncInClass.cpp - Wrap Emitc Funcs in classes -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,7 +6,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir-c/Rewrite.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
@@ -14,66 +13,44 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/IR/TypeRange.h"
-#include "mlir/IR/Value.h"
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/GraphWriter.h"
-#include "llvm/Support/LogicalResult.h"
-#include <string>
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace emitc;
namespace mlir {
namespace emitc {
-
#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
namespace {
-
struct WrapFuncInClassPass
: public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
void runOnOperation() override {
Operation *rootOp = getOperation();
- MLIRContext *context = rootOp->getContext();
- RewritePatternSet patterns(context);
+ RewritePatternSet patterns(&getContext());
populateFuncPatterns(patterns, namedAttribute);
- if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
- return signalPassFailure();
- }
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<emitc::EmitCDialect>();
+ walkAndApplyPatterns(rootOp, std::move(patterns));
}
};
} // namespace
-
} // namespace emitc
} // namespace mlir
-using namespace mlir;
-using namespace mlir::emitc;
-
class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
-private:
- std::string attributeName;
-
public:
- WrapFuncInClass(MLIRContext *context, const std::string &attrName)
+ WrapFuncInClass(MLIRContext *context, StringRef attrName)
: OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
PatternRewriter &rewriter) const override {
- if (funcOp->getParentOfType<emitc::ClassOp>()) {
- return failure();
- }
+
auto className = funcOp.getSymNameAttr().str() + "Class";
- mlir::emitc::ClassOp newClassOp =
- rewriter.create<emitc::ClassOp>(funcOp.getLoc(), className);
+ ClassOp newClassOp = rewriter.create<ClassOp>(funcOp.getLoc(), className);
SmallVector<std::pair<StringAttr, TypeAttr>> fields;
rewriter.createBlock(&newClassOp.getBody());
@@ -84,19 +61,11 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
StringAttr fieldName;
Attribute argAttr = nullptr;
+ fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
if (argAttrs && idx < argAttrs->size()) {
if (DictionaryAttr dictAttr =
- dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx])) {
- if (auto namedAttr = dictAttr.getNamed(attributeName)) {
- Attribute nv = namedAttr->getValue();
- fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
- argAttr = (*argAttrs)[idx];
- }
- }
- }
-
- if (!fieldName) {
- fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
+ dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx]))
+ argAttr = (*argAttrs)[idx];
}
TypeAttr typeAttr = TypeAttr::get(val.getType());
@@ -106,19 +75,17 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
}
rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
- MLIRContext *funcContext = funcOp.getContext();
- ArrayRef<Type> inputTypes = funcOp.getFunctionType().getInputs();
- ArrayRef<Type> results = funcOp.getFunctionType().getResults();
- FunctionType funcType = FunctionType::get(funcContext, inputTypes, results);
+ FunctionType funcType = funcOp.getFunctionType();
Location loc = funcOp.getLoc();
- FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
- loc, rewriter.getStringAttr("execute"), funcType);
+ FuncOp newFuncOp =
+ rewriter.create<emitc::FuncOp>(loc, ("execute"), funcType);
rewriter.createBlock(&newFuncOp.getBody());
newFuncOp.getBody().takeBody(funcOp.getBody());
rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
std::vector<Value> newArguments;
+ newArguments.reserve(fields.size());
for (auto &[fieldName, attr] : fields) {
GetFieldOp arg =
rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
@@ -132,15 +99,18 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
llvm::BitVector argsToErase(newFuncOp.getNumArguments(), true);
if (failed(newFuncOp.eraseArguments(argsToErase))) {
- newFuncOp->emitOpError("Failed to erase all arguments using BitVector.");
+ newFuncOp->emitOpError("failed to erase all arguments using BitVector");
}
rewriter.replaceOp(funcOp, newClassOp);
return success();
}
+
+private:
+ StringRef attributeName;
};
void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns,
- const std::string &namedAttribute) {
+ StringRef namedAttribute) {
patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute);
}
diff --git a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
index 480943116dbf1..c67a0c197fcd9 100644
--- a/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
+++ b/mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir
@@ -1,7 +1,9 @@
-// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.opaque' %s | FileCheck %s
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=emitc.name_hint' %s | FileCheck %s
module attributes { } {
- emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.opaque = ["another_feature"]}, %arg1: !emitc.array<1xf32> {emitc.opaque = ["some_feature"]}, %arg2: !emitc.array<1xf32> {emitc.opaque = ["output_0"]}) attributes { } {
+ emitc.func @model(%arg0: !emitc.array<1xf32> {emitc.name_hint = "another_feature"},
+ %arg1: !emitc.array<1xf32> {emitc.name_hint = "some_feature"},
+ %arg2: !emitc.array<1xf32> {emitc.name_hint = "output_0"}) attributes { } {
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
%1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
%2 = load %1 : <f32>
@@ -17,14 +19,14 @@ module attributes { } {
// CHECK: module {
// CHECK-NEXT: emitc.class @modelClass {
-// CHECK-NEXT: emitc.field @another_feature : !emitc.array<1xf32> = {emitc.opaque = ["another_feature"]}
-// CHECK-NEXT: emitc.field @some_feature : !emitc.array<1xf32> = {emitc.opaque = ["some_feature"]}
-// CHECK-NEXT: emitc.field @output_0 : !emitc.array<1xf32> = {emitc.opaque = ["output_0"]}
+// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.name_hint = "another_feature"}
+// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.name_hint = "some_feature"}
+// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.name_hint = "output_0"}
// CHECK-NEXT: emitc.func @execute() {
+// CHECK-NEXT: get_field @fieldName0 : !emitc.array<1xf32>
+// CHECK-NEXT: get_field @fieldName1 : !emitc.array<1xf32>
+// CHECK-NEXT: get_field @fieldName2 : !emitc.array<1xf32>
// CHECK-NEXT: "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
-// CHECK-NEXT: get_field @another_feature : !emitc.array<1xf32>
-// CHECK-NEXT: get_field @some_feature : !emitc.array<1xf32>
-// CHECK-NEXT: get_field @output_0 : !emitc.array<1xf32>
// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
// CHECK-NEXT: load {{.*}} : <f32>
// CHECK-NEXT: subscript {{.*}}[{{.*}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
>From 3bb679968d3e64c71ed56d9140df568b4c56c704 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Wed, 18 Jun 2025 22:00:00 +0000
Subject: [PATCH 19/19] clean up how we use attributes
---
mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp | 7 ++-----
1 file changed, 2 insertions(+), 5 deletions(-)
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index 4a02eca594c80..ff117f2a91618 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -62,11 +62,8 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
Attribute argAttr = nullptr;
fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
- if (argAttrs && idx < argAttrs->size()) {
- if (DictionaryAttr dictAttr =
- dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx]))
- argAttr = (*argAttrs)[idx];
- }
+ if (argAttrs && idx < argAttrs->size())
+ argAttr = (*argAttrs)[idx];
TypeAttr typeAttr = TypeAttr::get(val.getType());
fields.push_back({fieldName, typeAttr});
More information about the Mlir-commits
mailing list