[Mlir-commits] [mlir] 149d4b5 - [mlir][EmitC]Allow Fields to have initial values (#151437)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 1 11:20:16 PDT 2025


Author: Jaden Angella
Date: 2025-08-01T11:20:13-07:00
New Revision: 149d4b503391b4643f3085bd82e19eae69e5e3fb

URL: https://github.com/llvm/llvm-project/commit/149d4b503391b4643f3085bd82e19eae69e5e3fb
DIFF: https://github.com/llvm/llvm-project/commit/149d4b503391b4643f3085bd82e19eae69e5e3fb.diff

LOG: [mlir][EmitC] Allow Fields to have initial values (#151437)

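This change ensures that:
- A `field` of a class can have an initial value.
- The `field` op is emitted correctly, as a member declaration that includes its
  initializer when one is present.
- The `get_field` op is emitted correctly, as a direct reference to the field.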

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
    mlir/lib/Dialect/EmitC/IR/EmitC.cpp
    mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
    mlir/lib/Target/Cpp/TranslateToCpp.cpp
    mlir/test/mlir-translate/emitc_classops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 7fe2da8f7e044..937b34a625628 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1659,13 +1659,22 @@ def EmitC_FieldOp : EmitC_Op<"field", [Symbol]> {
     emitc.field @fieldName0 : !emitc.array<1xf32>  {emitc.opaque = "another_feature"}
     // Example with no attribute:
     emitc.field @fieldName0 : !emitc.array<1xf32>
+    // Example with an initializer:
+    emitc.field @fieldName0 : !emitc.array<1xf32> = dense<0.0>
+    // Example with an initializer and additional attributes:
+    emitc.field @fieldName0 : !emitc.array<1xf32> = dense<0.0>
+        {emitc.opaque = "input_tensor"}
     ```
   }];
 
   let arguments = (ins SymbolNameAttr:$sym_name, TypeAttr:$type,
-      OptionalAttr<AnyAttr>:$attrs);
+      OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value);
 
-  let assemblyFormat = [{ $sym_name `:` $type ($attrs^)? attr-dict}];
+  let assemblyFormat = [{
+    $sym_name
+    `:` custom<EmitCFieldOpTypeAndInitialValue>($type, $initial_value)
+    attr-dict
+  }];
 
   let hasVerifier = 1;
 }
@@ -1686,7 +1695,7 @@ def EmitC_GetFieldOp
   }];
 
   let arguments = (ins FlatSymbolRefAttr:$field_name);
-  let results = (outs AnyTypeOf<[EmitC_ArrayType, EmitC_LValueType]>:$result);
+  let results = (outs EmitCType:$result);
   let assemblyFormat = "$field_name `:` type($result) attr-dict";
 }
 

diff  --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4c0902293cbf9..e6a3154721faa 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1398,6 +1398,55 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
 //===----------------------------------------------------------------------===//
 // FieldOp
 //===----------------------------------------------------------------------===//
+/// Prints the field type, followed by ` = <initial value>` if the field has
+/// one. The value is printed without its type because the parser derives it
+/// from the field type (see getInitializerTypeForField).
+static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op,
+                                                 TypeAttr type,
+                                                 Attribute initialValue) {
+  p << type;
+  if (initialValue) {
+    p << " = ";
+    p.printAttributeWithoutType(initialValue);
+  }
+}
+
+/// Returns the type used to parse a field's initial value: array fields are
+/// initialized from a ranked tensor with the same shape and element type, any
+/// other field from a value of the field type itself.
+static Type getInitializerTypeForField(Type type) {
+  if (auto array = llvm::dyn_cast<ArrayType>(type))
+    return RankedTensorType::get(array.getShape(), array.getElementType());
+  return type;
+}
+
+/// Parses `type` optionally followed by `= initial-value`, where the initial
+/// value must be an integer, float, elements or opaque attribute.
+static ParseResult
+parseEmitCFieldOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
+                                     Attribute &initialValue) {
+  Type type;
+  if (parser.parseType(type))
+    return failure();
+
+  typeAttr = TypeAttr::get(type);
+
+  if (parser.parseOptionalEqual())
+    return success();
+
+  // Report a bad initializer at the initializer itself, not at the op name.
+  llvm::SMLoc initialValueLoc = parser.getCurrentLocation();
+  if (parser.parseAttribute(initialValue, getInitializerTypeForField(type)))
+    return failure();
+
+  if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
+          initialValue))
+    return parser.emitError(initialValueLoc)
+           << "initial value should be an integer, float, elements or opaque "
+              "attribute";
+  return success();
+}
+
 LogicalResult FieldOp::verify() {
   if (!isSupportedEmitCType(getType()))
     return emitOpError("expected valid emitc type");
@@ -1410,9 +1459,6 @@ LogicalResult FieldOp::verify() {
   if (!symName || symName.getValue().empty())
     return emitOpError("field must have a non-empty symbol name");
 
-  if (!getAttrs())
-    return success();
-
   return success();
 }
 

diff  --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index fa05ad8063b99..c55e26e722f33 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -58,17 +58,18 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
 
     auto argAttrs = funcOp.getArgAttrs();
     for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
-      StringAttr fieldName;
-      Attribute argAttr = nullptr;
-
-      fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
-      if (argAttrs && idx < argAttrs->size())
-        argAttr = (*argAttrs)[idx];
+      StringAttr fieldName =
+          rewriter.getStringAttr("fieldName" + std::to_string(idx));

       TypeAttr typeAttr = TypeAttr::get(val.getType());
       fields.push_back({fieldName, typeAttr});
-      emitc::FieldOp::create(rewriter, funcOp.getLoc(), fieldName, typeAttr,
-                             argAttr);
+
+      auto fieldOp = emitc::FieldOp::create(rewriter, funcOp.getLoc(),
+                                            fieldName, typeAttr, nullptr);
+      // Carry the argument's attributes over to the field as discardable
+      // attributes.
+      if (argAttrs && idx < argAttrs->size())
+        fieldOp->setDiscardableAttrs(funcOp.getArgAttrDict(idx));
     }
 
     rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());

diff  --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index dcd2e11e83c6a..8e83e455d1a7f 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -333,7 +333,8 @@ struct CppEmitter {
 /// Determine whether expression \p op should be emitted in a deferred way.
 static bool hasDeferredEmission(Operation *op) {
   return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
-                         emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
+                         emitc::MemberOfPtrOp, emitc::SubscriptOp,
+                         emitc::GetFieldOp>(op);
 }
 
 /// Determine whether expression \p expressionOp should be emitted inline, i.e.
@@ -1049,25 +1050,17 @@ static LogicalResult printOperation(CppEmitter &emitter, ClassOp classOp) {
 
+/// Emits a class field as a variable declaration with an optional initializer.
 static LogicalResult printOperation(CppEmitter &emitter, FieldOp fieldOp) {
   raw_ostream &os = emitter.ostream();
-  if (failed(emitter.emitType(fieldOp->getLoc(), fieldOp.getType())))
+  if (failed(emitter.emitVariableDeclaration(
+          fieldOp->getLoc(), fieldOp.getType(), fieldOp.getSymName())))
     return failure();
-  os << " " << fieldOp.getSymName() << ";";
-  return success();
-}
-
-static LogicalResult printOperation(CppEmitter &emitter,
-                                    GetFieldOp getFieldOp) {
-  raw_indented_ostream &os = emitter.ostream();
-
-  Value result = getFieldOp.getResult();
-  if (failed(emitter.emitType(getFieldOp->getLoc(), result.getType())))
-    return failure();
-  os << " ";
-  if (failed(emitter.emitOperand(result)))
-    return failure();
-  os << " = ";
+  if (std::optional<Attribute> initialValue = fieldOp.getInitialValue()) {
+    os << " = ";
+    if (failed(emitter.emitAttribute(fieldOp->getLoc(), *initialValue)))
+      return failure();
+  }
 
-  os << getFieldOp.getFieldName().str();
+  os << ";";
   return success();
 }
 
@@ -1204,7 +1197,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
   os << ") {\n";
   if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
     return failure();
-  os << "}\n";
+  os << "}";
 
   return success();
 }
@@ -1245,7 +1238,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
   os << ") {\n";
   if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
     return failure();
-  os << "}\n";
+  os << "}";
 
   return success();
 }
@@ -1700,12 +1693,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
                 emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp,
                 emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp,
                 emitc::FieldOp, emitc::FileOp, emitc::ForOp, emitc::FuncOp,
-                emitc::GetFieldOp, emitc::GlobalOp, emitc::IfOp,
-                emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp,
-                emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
-                emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp,
-                emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
-                emitc::VerbatimOp>(
+                emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp,
+                emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
+                emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
+                emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
+                emitc::VariableOp, emitc::VerbatimOp>(
 
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
@@ -1715,6 +1707,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
             cacheDeferredOpResult(op.getResult(), op.getName());
             return success();
           })
+          .Case<emitc::GetFieldOp>([&](auto op) {
+            // Nothing is emitted; the result is cached as the field's name.
+            cacheDeferredOpResult(op.getResult(), op.getFieldName());
+            return success();
+          })
           .Case<emitc::LiteralOp>([&](auto op) {
             cacheDeferredOpResult(op.getResult(), op.getValue());
             return success();

diff  --git a/mlir/test/mlir-translate/emitc_classops.mlir b/mlir/test/mlir-translate/emitc_classops.mlir
index 4b7ddf4630d55..d880f9b16dfc6 100644
--- a/mlir/test/mlir-translate/emitc_classops.mlir
+++ b/mlir/test/mlir-translate/emitc_classops.mlir
@@ -14,15 +14,12 @@ emitc.class @modelClass {
 
 // CHECK-LABEL: class modelClass {
 // CHECK-NEXT: public:
-// CHECK-NEXT:  float[1] fieldName0;
-// CHECK-NEXT:  float[1] fieldName1;
+// CHECK-NEXT:  float fieldName0[1];
+// CHECK-NEXT:  float fieldName1[1];
 // CHECK-NEXT:  void execute() {
 // CHECK-NEXT:    size_t v1 = 0;
-// CHECK-NEXT:    float[1] v2 = fieldName0;
-// CHECK-NEXT:    float[1] v3 = fieldName1;
 // CHECK-NEXT:    return;
 // CHECK-NEXT:  }
-// CHECK-EMPTY:
 // CHECK-NEXT: };
 
 emitc.class final @finalClass {
@@ -39,13 +36,42 @@ emitc.class final @finalClass {
 
 // CHECK-LABEL: class finalClass final {
 // CHECK-NEXT: public:
-// CHECK-NEXT:  float[1] fieldName0;
-// CHECK-NEXT:  float[1] fieldName1;
+// CHECK-NEXT:  float fieldName0[1];
+// CHECK-NEXT:  float fieldName1[1];
 // CHECK-NEXT:  void execute() {
 // CHECK-NEXT:    size_t v1 = 0;
-// CHECK-NEXT:    float[1] v2 = fieldName0;
-// CHECK-NEXT:    float[1] v3 = fieldName1;
 // CHECK-NEXT:    return;
 // CHECK-NEXT:  }
-// CHECK-EMPTY:
 // CHECK-NEXT: };
+
+emitc.class @mainClass {
+  emitc.field @fieldName0 : !emitc.array<2xf32> = dense<0.0> {emitc.name_hint = "another_feature"}
+  emitc.func @get_fieldName0() {
+    %0 = emitc.get_field @fieldName0 : !emitc.array<2xf32>
+    return
+  }
+}
+
+// CHECK-LABEL: class mainClass {
+// CHECK-NEXT: public:
+// CHECK-NEXT:  float fieldName0[2] = {0.0e+00f, 0.0e+00f};
+// CHECK-NEXT:  void get_fieldName0() {
+// CHECK-NEXT:    return;
+// CHECK-NEXT:  }
+// CHECK-NEXT: };
+
+emitc.class @reflectionClass {
+  emitc.field @reflectionMap : !emitc.opaque<"const std::map<std::string, std::string>"> = #emitc.opaque<"{ { \22another_feature\22, \22fieldName0\22 } }">
+  emitc.func @get_reflectionMap() {
+    %0 = emitc.get_field @reflectionMap : !emitc.opaque<"const std::map<std::string, std::string>">
+    return
+  }
+}
+
+// CHECK-LABEL: class reflectionClass {
+// CHECK-NEXT: public:
+// CHECK-NEXT:  const std::map<std::string, std::string> reflectionMap = { { "another_feature", "fieldName0" } };
+// CHECK-NEXT:  void get_reflectionMap() {
+// CHECK-NEXT:    return;
+// CHECK-NEXT:  }
+// CHECK-NEXT: };


        


More information about the Mlir-commits mailing list