[Mlir-commits] [mlir] [mlir][EmitC] Add a Reflection Map to a Class (PR #150572)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jul 28 09:14:04 PDT 2025
https://github.com/ajaden-codes updated https://github.com/llvm/llvm-project/pull/150572
>From 78f669ab757b8f59a9240e510f500e23465b2df0 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 18 Jul 2025 20:35:29 +0000
Subject: [PATCH 1/7] Modeling
---
.../mlir/Dialect/EmitC/Transforms/Passes.h | 1 +
.../mlir/Dialect/EmitC/Transforms/Passes.td | 35 ++++
.../EmitC/Transforms/AddReflectionMap.cpp | 161 ++++++++++++++++++
.../Dialect/EmitC/Transforms/CMakeLists.txt | 1 +
4 files changed, 198 insertions(+)
create mode 100644 mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
index 1af4aa06fa811..259d6c24cd5fc 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.h
@@ -16,6 +16,7 @@ namespace emitc {
#define GEN_PASS_DECL_FORMEXPRESSIONSPASS
#define GEN_PASS_DECL_WRAPFUNCINCLASSPASS
+#define GEN_PASS_DECL_ADDREFLECTIONMAPPASS
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index 1893c101e735b..031b80362718e 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -53,4 +53,39 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
let dependentDialects = ["emitc::EmitCDialect"];
}
+def AddReflectionMapPass : Pass<"add-reflection-map"> {
+ let summary =
+ "Add a reflection map function to EmitC classes for runtime field lookup";
+ let description = [{
+ This pass adds a `getBufferForName` function to EmitC classes that enables
+ runtime lookup of field buffers by their string names.
+ This enables runtime introspection and dynamic access to class fields by name,
+ which is useful for interfacing with external systems that need to access
+ tensors/buffers by their semantic names.
+
+ Example transformation:
+ ```mlir
+ emitc.class @MyClass {
+ emitc.field @fieldName0 : !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}
+ emitc.field @fieldName1 : !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}
+ emitc.func @execute() { ... }
+ }
+ ```
+
+ Becomes:
+ ```mlir
+ emitc.class @MyClass {
+ emitc.field @fieldName0 : !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}
+ emitc.field @fieldName1 : !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}
+ emitc.func @getBufferForName(%name : !emitc.opaque<"std::string_view">) -> !emitc.opaque<"char*"> {
+ %map = "emitc.constant"(){value = #emitc.opaque<"{"another_feature", reinterpret_cast<char*>(&another_feature)}, {"some_feature", reinterpret_cast<char*>(&some_feature)}">} : () -> !emitc.opaque<"std::map<std::string, char*>">
+ return %null : !emitc.opaque<"char*">
+ }
+ emitc.func @execute() { ... }
+ }
+ ```
+ }];
+ let dependentDialects = ["mlir::emitc::EmitCDialect"];
+}
+
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
new file mode 100644
index 0000000000000..c74b03aefcefe
--- /dev/null
+++ b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
@@ -0,0 +1,161 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace emitc;
+
+namespace mlir {
+namespace emitc {
+#define GEN_PASS_DEF_ADDREFLECTIONMAPPASS
+#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
+
+namespace {
+class AddReflectionMapPass
+ : public impl::AddReflectionMapPassBase<AddReflectionMapPass> {
+ using AddReflectionMapPassBase::AddReflectionMapPassBase;
+ void runOnOperation() override {
+ Operation *rootOp = getOperation();
+
+ RewritePatternSet patterns(&getContext());
+ populateAddReflectionMapPatterns(patterns);
+
+ walkAndApplyPatterns(rootOp, std::move(patterns));
+ }
+};
+
+} // namespace
+} // namespace emitc
+} // namespace mlir
+
+class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
+public:
+ AddReflectionMapClass(MLIRContext *context)
+ : OpRewritePattern<emitc::ClassOp>(context) {}
+
+ LogicalResult matchAndRewrite(mlir::emitc::ClassOp classOp,
+ PatternRewriter &rewriter) const override {
+ mlir::MLIRContext *context = rewriter.getContext();
+ emitc::OpaqueType stringViewType =
+ mlir::emitc::OpaqueType::get(rewriter.getContext(), "std::string_view");
+ emitc::OpaqueType charPtrType =
+ mlir::emitc::OpaqueType::get(rewriter.getContext(), "char");
+ emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get(
+ rewriter.getContext(), "const std::map<std::string, char*>");
+
+ FunctionType funcType =
+ rewriter.getFunctionType({stringViewType}, {charPtrType});
+ emitc::FuncOp executeFunc =
+ classOp.lookupSymbol<mlir::emitc::FuncOp>("execute");
+ rewriter.setInsertionPoint(executeFunc);
+
+ emitc::FuncOp getBufferFunc = rewriter.create<mlir::emitc::FuncOp>(
+ classOp.getLoc(), "getBufferForName", funcType);
+
+ Block *funcBody = getBufferFunc.addEntryBlock();
+ rewriter.setInsertionPointToStart(funcBody);
+
+ // Collect all field names
+ SmallVector<std::string> fieldNames;
+ classOp.walk([&](mlir::emitc::FieldOp fieldOp) {
+ if (mlir::Attribute attrsAttr =
+ fieldOp->getAttrDictionary().get("attrs")) {
+ if (DictionaryAttr innerDictAttr =
+ dyn_cast<mlir::DictionaryAttr>(attrsAttr)) {
+ auto indexPathAttr =
+ innerDictAttr.getNamed("tf_saved_model.index_path");
+ ArrayAttr arrayAttr =
+ dyn_cast<mlir::ArrayAttr>(indexPathAttr->getValue());
+ if (!arrayAttr.empty()) {
+ StringAttr stringAttr = dyn_cast<mlir::StringAttr>(arrayAttr[0]);
+ std::string indexPath = stringAttr.getValue().str();
+ fieldNames.push_back(indexPath);
+ }
+ if (arrayAttr.size() > 1) {
+ fieldOp.emitError() << "tf_saved_model.index_path attribute must "
+ "contain at most one value, but found "
+ << arrayAttr.size() << " values.";
+ return;
+ }
+ }
+ }
+ });
+
+ std::string mapInitializer = "{{";
+ for (size_t i = 0; i < fieldNames.size(); ++i) {
+ mapInitializer += "\"" + fieldNames[i] + "\", " +
+ "reinterpret_cast<char*>(&" + fieldNames[i] + ")",
+ mapInitializer += "}";
+ if (i < fieldNames.size() - 1)
+ mapInitializer += ", {";
+ }
+ mapInitializer += "}";
+
+ auto iteratorType = mlir::emitc::OpaqueType::get(
+ context, "std::map<std::string, char*>::const_iterator");
+ auto boolType = rewriter.getI1Type();
+ // 5. Create the constant map
+ auto bufferMap = rewriter.create<emitc::ConstantOp>(
+ classOp.getLoc(), mapType,
+ emitc::OpaqueAttr::get(context, mapInitializer));
+
+ // 6. Get the function argument
+ mlir::Value nameArg = getBufferFunc.getArgument(0);
+
+ // 7. Create the find call
+ auto it = rewriter.create<emitc::CallOpaqueOp>(
+ classOp.getLoc(), iteratorType, rewriter.getStringAttr("find"),
+ mlir::ValueRange{bufferMap.getResult(), nameArg});
+
+ // 8. Create the end call
+ auto endIt = rewriter.create<emitc::CallOpaqueOp>(
+ classOp.getLoc(), iteratorType, rewriter.getStringAttr("end"),
+ bufferMap.getResult());
+
+ // 9. Create the operator== call
+ auto isEnd = rewriter.create<emitc::CallOpaqueOp>(
+ classOp.getLoc(), boolType,
+ "operator==", mlir::ValueRange{it.getResult(0), endIt.getResult(0)});
+
+ // 10. Create the nullptr constant
+ auto nullPtr = rewriter.create<emitc::ConstantOp>(
+ classOp.getLoc(), charPtrType,
+ emitc::OpaqueAttr::get(context, "nullptr"));
+
+ // 11. Create the second call
+ auto second = rewriter.create<emitc::CallOpaqueOp>(
+ classOp.getLoc(), charPtrType, "second", it.getResult(0));
+
+ // 12. Create the conditional
+ auto result = rewriter.create<emitc::ConditionalOp>(
+ classOp.getLoc(), charPtrType, isEnd.getResult(0), nullPtr.getResult(),
+ second.getResult(0));
+
+ // 13. Create return
+ rewriter.create<emitc::ReturnOp>(classOp.getLoc(), result.getResult());
+
+ return success();
+ }
+};
+
+void mlir::emitc::populateAddReflectionMapPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<AddReflectionMapClass>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
index baf67afc30072..dd8f014dc4737 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/EmitC/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIREmitCTransforms
FormExpressions.cpp
TypeConversions.cpp
WrapFuncInClass.cpp
+ AddReflectionMap.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC/Transforms
>From 9be23970d11a0f2ef3dd0afe1c2049386bf49dd3 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Tue, 22 Jul 2025 22:33:45 +0000
Subject: [PATCH 2/7] Add an argument
---
.../mlir/Dialect/EmitC/Transforms/Passes.td | 12 ++-
.../EmitC/Transforms/AddReflectionMap.cpp | 91 ++++++++-----------
.../Dialect/EmitC/add_reflection_map.mlir | 55 +++++++++++
3 files changed, 101 insertions(+), 57 deletions(-)
create mode 100644 mlir/test/Dialect/EmitC/add_reflection_map.mlir
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index 031b80362718e..5facdedbd1b4b 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -66,8 +66,8 @@ def AddReflectionMapPass : Pass<"add-reflection-map"> {
Example transformation:
```mlir
emitc.class @MyClass {
- emitc.field @fieldName0 : !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}
- emitc.field @fieldName1 : !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}
+ emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
+ emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
emitc.func @execute() { ... }
}
```
@@ -75,8 +75,8 @@ def AddReflectionMapPass : Pass<"add-reflection-map"> {
Becomes:
```mlir
emitc.class @MyClass {
- emitc.field @fieldName0 : !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}
- emitc.field @fieldName1 : !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}
+ emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
+ emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
emitc.func @getBufferForName(%name : !emitc.opaque<"std::string_view">) -> !emitc.opaque<"char*"> {
%map = "emitc.constant"(){value = #emitc.opaque<"{"another_feature", reinterpret_cast<char*>(&another_feature)}, {"some_feature", reinterpret_cast<char*>(&some_feature)}">} : () -> !emitc.opaque<"std::map<std::string, char*>">
return %null : !emitc.opaque<"char*">
@@ -86,6 +86,10 @@ def AddReflectionMapPass : Pass<"add-reflection-map"> {
```
}];
let dependentDialects = ["mlir::emitc::EmitCDialect"];
+ let options = [Option<"namedAttribute", "named-attribute", "std::string",
+ /*default=*/"",
+ "Attribute key used to extract field names from fields "
+ "dictionary attributes">];
}
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
index c74b03aefcefe..a9c4ae229d56f 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
@@ -1,15 +1,10 @@
-/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
+//===- AddReflectionMap.cpp - Add a reflection map to a class -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
@@ -35,7 +30,7 @@ class AddReflectionMapPass
Operation *rootOp = getOperation();
RewritePatternSet patterns(&getContext());
- populateAddReflectionMapPatterns(patterns);
+ populateAddReflectionMapPatterns(patterns, namedAttribute);
walkAndApplyPatterns(rootOp, std::move(patterns));
}
@@ -47,8 +42,8 @@ class AddReflectionMapPass
class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
public:
- AddReflectionMapClass(MLIRContext *context)
- : OpRewritePattern<emitc::ClassOp>(context) {}
+ AddReflectionMapClass(MLIRContext *context, StringRef attrName)
+ : OpRewritePattern<emitc::ClassOp>(context), attributeName(attrName) {}
LogicalResult matchAndRewrite(mlir::emitc::ClassOp classOp,
PatternRewriter &rewriter) const override {
@@ -73,23 +68,23 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
rewriter.setInsertionPointToStart(funcBody);
// Collect all field names
- SmallVector<std::string> fieldNames;
+ std::vector<std::pair<std::string, std::string>> fieldNames;
classOp.walk([&](mlir::emitc::FieldOp fieldOp) {
if (mlir::Attribute attrsAttr =
fieldOp->getAttrDictionary().get("attrs")) {
if (DictionaryAttr innerDictAttr =
dyn_cast<mlir::DictionaryAttr>(attrsAttr)) {
- auto indexPathAttr =
- innerDictAttr.getNamed("tf_saved_model.index_path");
+ auto indexPathAttr = innerDictAttr.getNamed(attributeName);
ArrayAttr arrayAttr =
dyn_cast<mlir::ArrayAttr>(indexPathAttr->getValue());
if (!arrayAttr.empty()) {
StringAttr stringAttr = dyn_cast<mlir::StringAttr>(arrayAttr[0]);
std::string indexPath = stringAttr.getValue().str();
- fieldNames.push_back(indexPath);
+ fieldNames.emplace_back(indexPath, fieldOp.getName().str());
}
if (arrayAttr.size() > 1) {
- fieldOp.emitError() << "tf_saved_model.index_path attribute must "
+ fieldOp.emitError() << attributeName
+ << " attribute must "
"contain at most one value, but found "
<< arrayAttr.size() << " values.";
return;
@@ -98,64 +93,54 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
}
});
- std::string mapInitializer = "{{";
+ std::string mapInitializer = "{ ";
for (size_t i = 0; i < fieldNames.size(); ++i) {
- mapInitializer += "\"" + fieldNames[i] + "\", " +
- "reinterpret_cast<char*>(&" + fieldNames[i] + ")",
- mapInitializer += "}";
+ mapInitializer += " { \"" + fieldNames[i].first + "\", " +
+ "reinterpret_cast<char*>(&" + fieldNames[i].second +
+ ")",
+ mapInitializer += " }";
if (i < fieldNames.size() - 1)
- mapInitializer += ", {";
+ mapInitializer += ", ";
}
- mapInitializer += "}";
+ mapInitializer += " }";
- auto iteratorType = mlir::emitc::OpaqueType::get(
+ emitc::OpaqueType iteratorType = mlir::emitc::OpaqueType::get(
context, "std::map<std::string, char*>::const_iterator");
- auto boolType = rewriter.getI1Type();
- // 5. Create the constant map
- auto bufferMap = rewriter.create<emitc::ConstantOp>(
+
+ emitc::ConstantOp bufferMap = rewriter.create<emitc::ConstantOp>(
classOp.getLoc(), mapType,
emitc::OpaqueAttr::get(context, mapInitializer));
- // 6. Get the function argument
mlir::Value nameArg = getBufferFunc.getArgument(0);
-
- // 7. Create the find call
- auto it = rewriter.create<emitc::CallOpaqueOp>(
+ emitc::CallOpaqueOp it = rewriter.create<emitc::CallOpaqueOp>(
classOp.getLoc(), iteratorType, rewriter.getStringAttr("find"),
mlir::ValueRange{bufferMap.getResult(), nameArg});
-
- // 8. Create the end call
- auto endIt = rewriter.create<emitc::CallOpaqueOp>(
+ emitc::CallOpaqueOp endIt = rewriter.create<emitc::CallOpaqueOp>(
classOp.getLoc(), iteratorType, rewriter.getStringAttr("end"),
bufferMap.getResult());
-
- // 9. Create the operator== call
- auto isEnd = rewriter.create<emitc::CallOpaqueOp>(
- classOp.getLoc(), boolType,
+ emitc::CallOpaqueOp isEnd = rewriter.create<emitc::CallOpaqueOp>(
+ classOp.getLoc(), rewriter.getI1Type(),
"operator==", mlir::ValueRange{it.getResult(0), endIt.getResult(0)});
-
- // 10. Create the nullptr constant
- auto nullPtr = rewriter.create<emitc::ConstantOp>(
+ emitc::ConstantOp nullPtr = rewriter.create<emitc::ConstantOp>(
classOp.getLoc(), charPtrType,
emitc::OpaqueAttr::get(context, "nullptr"));
-
- // 11. Create the second call
- auto second = rewriter.create<emitc::CallOpaqueOp>(
+ emitc::CallOpaqueOp second = rewriter.create<emitc::CallOpaqueOp>(
classOp.getLoc(), charPtrType, "second", it.getResult(0));
- // 12. Create the conditional
- auto result = rewriter.create<emitc::ConditionalOp>(
+ emitc::ConditionalOp result = rewriter.create<emitc::ConditionalOp>(
classOp.getLoc(), charPtrType, isEnd.getResult(0), nullPtr.getResult(),
second.getResult(0));
- // 13. Create return
rewriter.create<emitc::ReturnOp>(classOp.getLoc(), result.getResult());
return success();
}
+
+private:
+ StringRef attributeName;
};
-void mlir::emitc::populateAddReflectionMapPatterns(
- RewritePatternSet &patterns) {
- patterns.add<AddReflectionMapClass>(patterns.getContext());
+void mlir::emitc::populateAddReflectionMapPatterns(RewritePatternSet &patterns,
+ StringRef namedAttribute) {
+ patterns.add<AddReflectionMapClass>(patterns.getContext(), namedAttribute);
}
diff --git a/mlir/test/Dialect/EmitC/add_reflection_map.mlir b/mlir/test/Dialect/EmitC/add_reflection_map.mlir
new file mode 100644
index 0000000000000..f61ee639b22fc
--- /dev/null
+++ b/mlir/test/Dialect/EmitC/add_reflection_map.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt --add-reflection-map="named-attribute=emitc.field_ref" %s | FileCheck %s
+
+emitc.class @mainClass {
+ emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
+ emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
+ emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.field_ref = ["output_0"]}
+ emitc.func @execute() {
+ %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+ %1 = get_field @fieldName0 : !emitc.array<1xf32>
+ %2 = get_field @fieldName1 : !emitc.array<1xf32>
+ %3 = get_field @fieldName2 : !emitc.array<1xf32>
+ %4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %5 = load %4 : <f32>
+ %6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ %7 = load %6 : <f32>
+ %8 = add %5, %7 : (f32, f32) -> f32
+ %9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+ assign %8 : f32 to %9 : <f32>
+ return
+ }
+}
+
+// CHECK: module {
+// CHECK-NEXT: emitc.class @mainClass {
+// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
+// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
+// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.field_ref = ["output_0"]}
+// CHECK-NEXT: emitc.func @getBufferForName(%arg0: !emitc.opaque<"std::string_view">) -> !emitc.opaque<"char"> {
+// CHECK-NEXT: %0 = "emitc.constant"() <{value = #emitc.opaque<"{ { \22another_feature\22, reinterpret_cast<char*>(&fieldName0) }, { \22some_feature\22, reinterpret_cast<char*>(&fieldName1) }, { \22output_0\22, reinterpret_cast<char*>(&fieldName2) } }">}> : () -> !emitc.opaque<"const std::map<std::string, char*>">
+// CHECK-NEXT: %1 = call_opaque "find"(%0, %arg0) : (!emitc.opaque<"const std::map<std::string, char*>">, !emitc.opaque<"std::string_view">) -> !emitc.opaque<"std::map<std::string, char*>::const_iterator">
+// CHECK-NEXT: %2 = call_opaque "end"(%0) : (!emitc.opaque<"const std::map<std::string, char*>">) -> !emitc.opaque<"std::map<std::string, char*>::const_iterator">
+// CHECK-NEXT: %3 = call_opaque "operator=="(%1, %2) : (!emitc.opaque<"std::map<std::string, char*>::const_iterator">, !emitc.opaque<"std::map<std::string, char*>::const_iterator">) -> i1
+// CHECK-NEXT: %4 = "emitc.constant"() <{value = #emitc.opaque<"nullptr">}> : () -> !emitc.opaque<"char">
+// CHECK-NEXT: %5 = call_opaque "second"(%1) : (!emitc.opaque<"std::map<std::string, char*>::const_iterator">) -> !emitc.opaque<"char">
+// CHECK-NEXT: %6 = conditional %3, %4, %5 : !emitc.opaque<"char">
+// CHECK-NEXT: return %6 : !emitc.opaque<"char">
+// CHECK-NEXT: }
+// CHECK-NEXT: emitc.func @execute() {
+// CHECK-NEXT: %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+// CHECK-NEXT: %1 = get_field @fieldName0 : !emitc.array<1xf32>
+// CHECK-NEXT: %2 = get_field @fieldName1 : !emitc.array<1xf32>
+// CHECK-NEXT: %3 = get_field @fieldName2 : !emitc.array<1xf32>
+// CHECK-NEXT: %4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %5 = load %4 : <f32>
+// CHECK-NEXT: %6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT: %7 = load %6 : <f32>
+// CHECK-NEXT: %8 = add %5, %7 : (f32, f32) -> f32
+// CHECK-NEXT: %9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK-NEXT: assign %8 : f32 to %9 : <f32>
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+
>From 1b07f074e588ab054fd4af6ebcafcdaf2b3c0ba1 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 25 Jul 2025 05:49:14 +0000
Subject: [PATCH 3/7] Specify the pass reqs
---
.../mlir/Dialect/EmitC/Transforms/Passes.td | 28 +++++----
.../EmitC/Transforms/AddReflectionMap.cpp | 58 +++++++++++++++----
.../Dialect/EmitC/add_reflection_map.mlir | 2 +
3 files changed, 66 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index 5facdedbd1b4b..b591fecbbe6c1 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -53,33 +53,41 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
let dependentDialects = ["emitc::EmitCDialect"];
}
-def AddReflectionMapPass : Pass<"add-reflection-map"> {
+def AddReflectionMapPass : Pass<"add-reflection-map", "ModuleOp"> {
let summary =
"Add a reflection map function to EmitC classes for runtime field lookup";
let description = [{
This pass adds a `getBufferForName` function to EmitC classes that enables
runtime lookup of field buffers by their string names.
- This enables runtime introspection and dynamic access to class fields by name,
- which is useful for interfacing with external systems that need to access
- tensors/buffers by their semantic names.
+ This would require that the class has fields with attributes and a function named `execute`.
+ The `fieldop` attribute is expected to be a dictionary where:
+ - The keys are `namedAttribute`.
+ - The values are arrays containing a single string attribute.
+
+
+ Example:
- Example transformation:
```mlir
emitc.class @MyClass {
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
emitc.func @execute() { ... }
}
- ```
- Becomes:
- ```mlir
+ // becomes:
+
emitc.class @MyClass {
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
emitc.func @getBufferForName(%name : !emitc.opaque<"std::string_view">) -> !emitc.opaque<"char*"> {
- %map = "emitc.constant"(){value = #emitc.opaque<"{"another_feature", reinterpret_cast<char*>(&another_feature)}, {"some_feature", reinterpret_cast<char*>(&some_feature)}">} : () -> !emitc.opaque<"std::map<std::string, char*>">
- return %null : !emitc.opaque<"char*">
+ %0 = "emitc.constant"() <{value = #emitc.opaque<"{ { \22another_feature\22, reinterpret_cast<char*>(&fieldName0) }, { \22some_feature\22, reinterpret_cast<char*>(&fieldName1) } }">}> : () -> !emitc.opaque<"const std::map<std::string, char*>">
+ %1 = call_opaque "find"(%0, %arg0) : (!emitc.opaque<"const std::map<std::string, char*>">, !emitc.opaque<"std::string_view">) -> !emitc.opaque<"std::map<std::string, char*>::const_iterator">
+ %2 = call_opaque "end"(%0) : (!emitc.opaque<"const std::map<std::string, char*>">) -> !emitc.opaque<"std::map<std::string, char*>::const_iterator">
+ %3 = call_opaque "operator=="(%1, %2) : (!emitc.opaque<"std::map<std::string, char*>::const_iterator">, !emitc.opaque<"std::map<std::string, char*>::const_iterator">) -> i1
+ %4 = "emitc.constant"() <{value = #emitc.opaque<"nullptr">}> : () -> !emitc.opaque<"char">
+ %5 = call_opaque "second"(%1) : (!emitc.opaque<"std::map<std::string, char*>::const_iterator">) -> !emitc.opaque<"char">
+ %6 = conditional %3, %4, %5 : !emitc.opaque<"char">
+ return %6 : !emitc.opaque<"char">
}
emitc.func @execute() { ... }
}
diff --git a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
index a9c4ae229d56f..854d9adb4adcc 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
@@ -23,16 +23,48 @@ namespace emitc {
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
namespace {
+constexpr const char *kMapLibraryHeader = "map";
+constexpr const char *kStringLibraryHeader = "string";
class AddReflectionMapPass
: public impl::AddReflectionMapPassBase<AddReflectionMapPass> {
using AddReflectionMapPassBase::AddReflectionMapPassBase;
void runOnOperation() override {
- Operation *rootOp = getOperation();
+ mlir::ModuleOp module = getOperation();
RewritePatternSet patterns(&getContext());
populateAddReflectionMapPatterns(patterns, namedAttribute);
- walkAndApplyPatterns(rootOp, std::move(patterns));
+ walkAndApplyPatterns(module, std::move(patterns));
+ bool hasMap = false;
+ bool hasString = false;
+ for (auto &op : *module.getBody()) {
+ emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
+ if (!includeOp)
+ continue;
+ if (includeOp.getIsStandardInclude()) {
+ if (includeOp.getInclude() == kMapLibraryHeader)
+ hasMap = true;
+ if (includeOp.getInclude() == kStringLibraryHeader)
+ hasString = true;
+ }
+ }
+
+ if (hasMap && hasString)
+ return;
+
+ mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ if (!hasMap) {
+ StringAttr includeAttr = builder.getStringAttr(kMapLibraryHeader);
+ builder.create<mlir::emitc::IncludeOp>(
+ module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+ }
+ if (!hasString) {
+ StringAttr includeAttr = builder.getStringAttr(kStringLibraryHeader);
+ builder.create<emitc::IncludeOp>(
+ module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+ }
}
};
@@ -50,16 +82,20 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
mlir::MLIRContext *context = rewriter.getContext();
emitc::OpaqueType stringViewType =
mlir::emitc::OpaqueType::get(rewriter.getContext(), "std::string_view");
- emitc::OpaqueType charPtrType =
+ emitc::OpaqueType charType =
mlir::emitc::OpaqueType::get(rewriter.getContext(), "char");
emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get(
rewriter.getContext(), "const std::map<std::string, char*>");
FunctionType funcType =
- rewriter.getFunctionType({stringViewType}, {charPtrType});
+ rewriter.getFunctionType({stringViewType}, {charType});
emitc::FuncOp executeFunc =
classOp.lookupSymbol<mlir::emitc::FuncOp>("execute");
- rewriter.setInsertionPoint(executeFunc);
+ if (executeFunc)
+ rewriter.setInsertionPoint(executeFunc);
+ else
+ classOp.emitError() << "ClassOp must contain a function named 'execute' "
+ "to add reflection map";
emitc::FuncOp getBufferFunc = rewriter.create<mlir::emitc::FuncOp>(
classOp.getLoc(), "getBufferForName", funcType);
@@ -74,9 +110,8 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
fieldOp->getAttrDictionary().get("attrs")) {
if (DictionaryAttr innerDictAttr =
dyn_cast<mlir::DictionaryAttr>(attrsAttr)) {
- auto indexPathAttr = innerDictAttr.getNamed(attributeName);
- ArrayAttr arrayAttr =
- dyn_cast<mlir::ArrayAttr>(indexPathAttr->getValue());
+ ArrayAttr arrayAttr = dyn_cast<mlir::ArrayAttr>(
+ innerDictAttr.getNamed(attributeName)->getValue());
if (!arrayAttr.empty()) {
StringAttr stringAttr = dyn_cast<mlir::StringAttr>(arrayAttr[0]);
std::string indexPath = stringAttr.getValue().str();
@@ -122,13 +157,12 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
classOp.getLoc(), rewriter.getI1Type(),
"operator==", mlir::ValueRange{it.getResult(0), endIt.getResult(0)});
emitc::ConstantOp nullPtr = rewriter.create<emitc::ConstantOp>(
- classOp.getLoc(), charPtrType,
- emitc::OpaqueAttr::get(context, "nullptr"));
+ classOp.getLoc(), charType, emitc::OpaqueAttr::get(context, "nullptr"));
emitc::CallOpaqueOp second = rewriter.create<emitc::CallOpaqueOp>(
- classOp.getLoc(), charPtrType, "second", it.getResult(0));
+ classOp.getLoc(), charType, "second", it.getResult(0));
emitc::ConditionalOp result = rewriter.create<emitc::ConditionalOp>(
- classOp.getLoc(), charPtrType, isEnd.getResult(0), nullPtr.getResult(),
+ classOp.getLoc(), charType, isEnd.getResult(0), nullPtr.getResult(),
second.getResult(0));
rewriter.create<emitc::ReturnOp>(classOp.getLoc(), result.getResult());
diff --git a/mlir/test/Dialect/EmitC/add_reflection_map.mlir b/mlir/test/Dialect/EmitC/add_reflection_map.mlir
index f61ee639b22fc..607de348e6f79 100644
--- a/mlir/test/Dialect/EmitC/add_reflection_map.mlir
+++ b/mlir/test/Dialect/EmitC/add_reflection_map.mlir
@@ -21,6 +21,8 @@ emitc.class @mainClass {
}
// CHECK: module {
+// CHECK-NEXT: emitc.include <"map">
+// CHECK-NEXT: emitc.include <"string">
// CHECK-NEXT: emitc.class @mainClass {
// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
>From 608b19b557e8ba2a6b70c0cdd9be0dbc2f2fbae5 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Fri, 25 Jul 2025 21:00:45 +0000
Subject: [PATCH 4/7] small change
---
.../mlir/Dialect/EmitC/Transforms/Passes.td | 2 +-
.../EmitC/Transforms/AddReflectionMap.cpp | 43 ++++++++++---------
2 files changed, 23 insertions(+), 22 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
index b591fecbbe6c1..5815bdc471cb6 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td
@@ -59,7 +59,7 @@ def AddReflectionMapPass : Pass<"add-reflection-map", "ModuleOp"> {
let description = [{
This pass adds a `getBufferForName` function to EmitC classes that enables
runtime lookup of field buffers by their string names.
- This would require that the class has fields with attributes and a function named `execute`.
+ This requires that the class has fields with attributes and a function named `execute`.
The `fieldop` attribute is expected to be a dictionary where:
- The keys are `namedAttribute`.
- The values are arrays containing a single string attribute.
diff --git a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
index 854d9adb4adcc..b25cd6740e5ba 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
@@ -23,8 +23,16 @@ namespace emitc {
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
namespace {
-constexpr const char *kMapLibraryHeader = "map";
-constexpr const char *kStringLibraryHeader = "string";
+constexpr const char *mapLibraryHeader = "map";
+constexpr const char *stringLibraryHeader = "string";
+
+IncludeOp addHeader(OpBuilder &builder, ModuleOp module, StringRef headerName) {
+ StringAttr includeAttr = builder.getStringAttr(headerName);
+ return builder.create<emitc::IncludeOp>(
+ module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+}
+
class AddReflectionMapPass
: public impl::AddReflectionMapPassBase<AddReflectionMapPass> {
using AddReflectionMapPassBase::AddReflectionMapPassBase;
@@ -35,35 +43,28 @@ class AddReflectionMapPass
populateAddReflectionMapPatterns(patterns, namedAttribute);
walkAndApplyPatterns(module, std::move(patterns));
- bool hasMap = false;
- bool hasString = false;
+ bool hasMapHdr = false;
+ bool hasStringHdr = false;
for (auto &op : *module.getBody()) {
emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
if (!includeOp)
continue;
if (includeOp.getIsStandardInclude()) {
- if (includeOp.getInclude() == kMapLibraryHeader)
- hasMap = true;
- if (includeOp.getInclude() == kStringLibraryHeader)
- hasString = true;
+ if (includeOp.getInclude() == mapLibraryHeader)
+ hasMapHdr = true;
+ if (includeOp.getInclude() == stringLibraryHeader)
+ hasStringHdr = true;
}
+ if (hasMapHdr && hasStringHdr)
+ return;
}
- if (hasMap && hasString)
- return;
-
mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
- if (!hasMap) {
- StringAttr includeAttr = builder.getStringAttr(kMapLibraryHeader);
- builder.create<mlir::emitc::IncludeOp>(
- module.getLoc(), includeAttr,
- /*is_standard_include=*/builder.getUnitAttr());
+ if (!hasMapHdr) {
+ addHeader(builder, module, mapLibraryHeader);
}
- if (!hasString) {
- StringAttr includeAttr = builder.getStringAttr(kStringLibraryHeader);
- builder.create<emitc::IncludeOp>(
- module.getLoc(), includeAttr,
- /*is_standard_include=*/builder.getUnitAttr());
+ if (!hasStringHdr) {
+ addHeader(builder, module, stringLibraryHeader);
}
}
};
>From fa627db0354147d169e2dfca745c47fe4efff831 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 28 Jul 2025 15:40:25 +0000
Subject: [PATCH 5/7] avoid re-initialization
---
.../Dialect/EmitC/Transforms/Transforms.h | 4 +
.../EmitC/Transforms/AddReflectionMap.cpp | 74 +++++++------------
2 files changed, 32 insertions(+), 46 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index bdf6d0985e6db..7abc430347dc3 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -31,6 +31,10 @@ void populateExpressionPatterns(RewritePatternSet &patterns);
/// Populates 'patterns' with func-related patterns.
void populateFuncPatterns(RewritePatternSet &patterns);
+/// Populates `patterns` with patterns to add reflection map for EmitC classes.
+void populateAddReflectionMapPatterns(RewritePatternSet &patterns,
+ StringRef namedAttribute);
+
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
index b25cd6740e5ba..eaad5a6b352cb 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
@@ -80,29 +80,10 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
LogicalResult matchAndRewrite(mlir::emitc::ClassOp classOp,
PatternRewriter &rewriter) const override {
- mlir::MLIRContext *context = rewriter.getContext();
- emitc::OpaqueType stringViewType =
- mlir::emitc::OpaqueType::get(rewriter.getContext(), "std::string_view");
- emitc::OpaqueType charType =
- mlir::emitc::OpaqueType::get(rewriter.getContext(), "char");
- emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get(
- rewriter.getContext(), "const std::map<std::string, char*>");
-
- FunctionType funcType =
- rewriter.getFunctionType({stringViewType}, {charType});
- emitc::FuncOp executeFunc =
- classOp.lookupSymbol<mlir::emitc::FuncOp>("execute");
- if (executeFunc)
- rewriter.setInsertionPoint(executeFunc);
- else
- classOp.emitError() << "ClassOp must contain a function named 'execute' "
- "to add reflection map";
+ MLIRContext *context = rewriter.getContext();
- emitc::FuncOp getBufferFunc = rewriter.create<mlir::emitc::FuncOp>(
- classOp.getLoc(), "getBufferForName", funcType);
-
- Block *funcBody = getBufferFunc.addEntryBlock();
- rewriter.setInsertionPointToStart(funcBody);
+ emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get(
+ context, "const std::map<std::string, char*>");
// Collect all field names
std::vector<std::pair<std::string, std::string>> fieldNames;
@@ -129,44 +110,45 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
}
});
+ // Construct the map initializer string
std::string mapInitializer = "{ ";
for (size_t i = 0; i < fieldNames.size(); ++i) {
mapInitializer += " { \"" + fieldNames[i].first + "\", " +
"reinterpret_cast<char*>(&" + fieldNames[i].second +
- ")",
- mapInitializer += " }";
+ ")";
+ mapInitializer += " }";
if (i < fieldNames.size() - 1)
mapInitializer += ", ";
}
mapInitializer += " }";
- emitc::OpaqueType iteratorType = mlir::emitc::OpaqueType::get(
- context, "std::map<std::string, char*>::const_iterator");
+ emitc::OpaqueType returnType = mlir::emitc::OpaqueType::get(
+ context, "const std::map<std::string, char*>");
+
+ emitc::FuncOp executeFunc =
+ classOp.lookupSymbol<mlir::emitc::FuncOp>("execute");
+ if (executeFunc)
+ rewriter.setInsertionPoint(executeFunc);
+ else
+ classOp.emitError() << "ClassOp must contain a function named 'execute' "
+ "to add reflection map";
+
+ // Create the getFeatures function
+ emitc::FuncOp getFeaturesFunc = rewriter.create<mlir::emitc::FuncOp>(
+ classOp.getLoc(), "getFeatures",
+ rewriter.getFunctionType({}, {returnType}));
+
+ // Add the body of the getFeatures function
+ Block *funcBody = getFeaturesFunc.addEntryBlock();
+ rewriter.setInsertionPointToStart(funcBody);
+ // Create the constant map
emitc::ConstantOp bufferMap = rewriter.create<emitc::ConstantOp>(
classOp.getLoc(), mapType,
emitc::OpaqueAttr::get(context, mapInitializer));
- mlir::Value nameArg = getBufferFunc.getArgument(0);
- emitc::CallOpaqueOp it = rewriter.create<emitc::CallOpaqueOp>(
- classOp.getLoc(), iteratorType, rewriter.getStringAttr("find"),
- mlir::ValueRange{bufferMap.getResult(), nameArg});
- emitc::CallOpaqueOp endIt = rewriter.create<emitc::CallOpaqueOp>(
- classOp.getLoc(), iteratorType, rewriter.getStringAttr("end"),
- bufferMap.getResult());
- emitc::CallOpaqueOp isEnd = rewriter.create<emitc::CallOpaqueOp>(
- classOp.getLoc(), rewriter.getI1Type(),
- "operator==", mlir::ValueRange{it.getResult(0), endIt.getResult(0)});
- emitc::ConstantOp nullPtr = rewriter.create<emitc::ConstantOp>(
- classOp.getLoc(), charType, emitc::OpaqueAttr::get(context, "nullptr"));
- emitc::CallOpaqueOp second = rewriter.create<emitc::CallOpaqueOp>(
- classOp.getLoc(), charType, "second", it.getResult(0));
-
- emitc::ConditionalOp result = rewriter.create<emitc::ConditionalOp>(
- classOp.getLoc(), charType, isEnd.getResult(0), nullPtr.getResult(),
- second.getResult(0));
-
- rewriter.create<emitc::ReturnOp>(classOp.getLoc(), result.getResult());
+ rewriter.create<mlir::emitc::ReturnOp>(classOp.getLoc(),
+ bufferMap.getResult());
return success();
}
>From 6f82da956d942b0914557886e3937793f2ca703b Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 28 Jul 2025 16:07:41 +0000
Subject: [PATCH 6/7] Revert "avoid re-initialization"
This reverts commit f2dee0d99bc1fb258b6cef57dc150cb637cc4ab3.
---
.../Dialect/EmitC/Transforms/Transforms.h | 4 -
.../EmitC/Transforms/AddReflectionMap.cpp | 74 ++++++++++++-------
2 files changed, 46 insertions(+), 32 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
index 7abc430347dc3..bdf6d0985e6db 100644
--- a/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h
@@ -31,10 +31,6 @@ void populateExpressionPatterns(RewritePatternSet &patterns);
/// Populates 'patterns' with func-related patterns.
void populateFuncPatterns(RewritePatternSet &patterns);
-/// Populates `patterns` with patterns to add reflection map for EmitC classes.
-void populateAddReflectionMapPatterns(RewritePatternSet &patterns,
- StringRef namedAttribute);
-
} // namespace emitc
} // namespace mlir
diff --git a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
index eaad5a6b352cb..b25cd6740e5ba 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
@@ -80,10 +80,29 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
LogicalResult matchAndRewrite(mlir::emitc::ClassOp classOp,
PatternRewriter &rewriter) const override {
- MLIRContext *context = rewriter.getContext();
-
+ mlir::MLIRContext *context = rewriter.getContext();
+ emitc::OpaqueType stringViewType =
+ mlir::emitc::OpaqueType::get(rewriter.getContext(), "std::string_view");
+ emitc::OpaqueType charType =
+ mlir::emitc::OpaqueType::get(rewriter.getContext(), "char");
emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get(
- context, "const std::map<std::string, char*>");
+ rewriter.getContext(), "const std::map<std::string, char*>");
+
+ FunctionType funcType =
+ rewriter.getFunctionType({stringViewType}, {charType});
+ emitc::FuncOp executeFunc =
+ classOp.lookupSymbol<mlir::emitc::FuncOp>("execute");
+ if (executeFunc)
+ rewriter.setInsertionPoint(executeFunc);
+ else
+ classOp.emitError() << "ClassOp must contain a function named 'execute' "
+ "to add reflection map";
+
+ emitc::FuncOp getBufferFunc = rewriter.create<mlir::emitc::FuncOp>(
+ classOp.getLoc(), "getBufferForName", funcType);
+
+ Block *funcBody = getBufferFunc.addEntryBlock();
+ rewriter.setInsertionPointToStart(funcBody);
// Collect all field names
std::vector<std::pair<std::string, std::string>> fieldNames;
@@ -110,45 +129,44 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
}
});
- // Construct the map initializer string
std::string mapInitializer = "{ ";
for (size_t i = 0; i < fieldNames.size(); ++i) {
mapInitializer += " { \"" + fieldNames[i].first + "\", " +
"reinterpret_cast<char*>(&" + fieldNames[i].second +
- ")";
- mapInitializer += " }";
+ ")",
+ mapInitializer += " }";
if (i < fieldNames.size() - 1)
mapInitializer += ", ";
}
mapInitializer += " }";
- emitc::OpaqueType returnType = mlir::emitc::OpaqueType::get(
- context, "const std::map<std::string, char*>");
-
- emitc::FuncOp executeFunc =
- classOp.lookupSymbol<mlir::emitc::FuncOp>("execute");
- if (executeFunc)
- rewriter.setInsertionPoint(executeFunc);
- else
- classOp.emitError() << "ClassOp must contain a function named 'execute' "
- "to add reflection map";
-
- // Create the getFeatures function
- emitc::FuncOp getFeaturesFunc = rewriter.create<mlir::emitc::FuncOp>(
- classOp.getLoc(), "getFeatures",
- rewriter.getFunctionType({}, {returnType}));
-
- // Add the body of the getFeatures function
- Block *funcBody = getFeaturesFunc.addEntryBlock();
- rewriter.setInsertionPointToStart(funcBody);
+ emitc::OpaqueType iteratorType = mlir::emitc::OpaqueType::get(
+ context, "std::map<std::string, char*>::const_iterator");
- // Create the constant map
emitc::ConstantOp bufferMap = rewriter.create<emitc::ConstantOp>(
classOp.getLoc(), mapType,
emitc::OpaqueAttr::get(context, mapInitializer));
- rewriter.create<mlir::emitc::ReturnOp>(classOp.getLoc(),
- bufferMap.getResult());
+ mlir::Value nameArg = getBufferFunc.getArgument(0);
+ emitc::CallOpaqueOp it = rewriter.create<emitc::CallOpaqueOp>(
+ classOp.getLoc(), iteratorType, rewriter.getStringAttr("find"),
+ mlir::ValueRange{bufferMap.getResult(), nameArg});
+ emitc::CallOpaqueOp endIt = rewriter.create<emitc::CallOpaqueOp>(
+ classOp.getLoc(), iteratorType, rewriter.getStringAttr("end"),
+ bufferMap.getResult());
+ emitc::CallOpaqueOp isEnd = rewriter.create<emitc::CallOpaqueOp>(
+ classOp.getLoc(), rewriter.getI1Type(),
+ "operator==", mlir::ValueRange{it.getResult(0), endIt.getResult(0)});
+ emitc::ConstantOp nullPtr = rewriter.create<emitc::ConstantOp>(
+ classOp.getLoc(), charType, emitc::OpaqueAttr::get(context, "nullptr"));
+ emitc::CallOpaqueOp second = rewriter.create<emitc::CallOpaqueOp>(
+ classOp.getLoc(), charType, "second", it.getResult(0));
+
+ emitc::ConditionalOp result = rewriter.create<emitc::ConditionalOp>(
+ classOp.getLoc(), charType, isEnd.getResult(0), nullPtr.getResult(),
+ second.getResult(0));
+
+ rewriter.create<emitc::ReturnOp>(classOp.getLoc(), result.getResult());
return success();
}
>From 94b0c34027a4f9a86a749f188ac591ede886ee44 Mon Sep 17 00:00:00 2001
From: Jaddyen <ajaden at google.com>
Date: Mon, 28 Jul 2025 16:08:35 +0000
Subject: [PATCH 7/7] Cleaning
---
mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
index b25cd6740e5ba..1bf0747ecf487 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp
@@ -138,7 +138,7 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
if (i < fieldNames.size() - 1)
mapInitializer += ", ";
}
- mapInitializer += " }";
+ mapInitializer += " }}";
emitc::OpaqueType iteratorType = mlir::emitc::OpaqueType::get(
context, "std::map<std::string, char*>::const_iterator");
More information about the Mlir-commits
mailing list