[Mlir-commits] [mlir] [mlir][tblgen] add concrete create methods (PR #147168)

Maksim Levental llvmlistbot at llvm.org
Sat Jul 5 17:34:47 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/147168

>From 29ed7839891f50f83bef92f0a43ff27ff1e752ab Mon Sep 17 00:00:00 2001
From: Maksim Levental <maksim.levental at gmail.com>
Date: Sat, 5 Jul 2025 20:26:55 -0400
Subject: [PATCH] [mlir][tblgen] add concrete create methods

---
 mlir/include/mlir/TableGen/Class.h          |  2 +
 mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 51 ++++++++++++++++++---
 2 files changed, 47 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index f750a34a3b2ba..69cefbbc43e0a 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -71,6 +71,8 @@ class MethodParameter {
   StringRef getName() const { return name; }
   /// Returns true if the parameter has a default value.
   bool hasDefaultValue() const { return !defaultValue.empty(); }
+  StringRef getDefaultValue() const { return defaultValue; }
+  bool isOptional() const { return optional; }
 
 private:
   /// The C++ type.
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 6008ed4673d1b..d90164a8e7377 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -230,6 +230,22 @@ static const char *const opCommentHeader = R"(
 
 )";
 
+// Body of every generated `create` method. Format placeholders:
+//   {0}: name of the `::mlir::OpBuilder &` parameter,
+//   {1}: name of the `::mlir::Location` parameter,
+//   {2}: name of the local `::mlir::OperationState`,
+//   {3}: the forwarded build() arguments, each prefixed with ", " (may be
+//        empty),
+//   {4}: the C++ class name of the op.
+// Names are fully qualified, as elsewhere in the generated code, so the body also compiles when the op's namespace is not nested in `mlir`.
+static const char *const inlineCreateBody = R"(
+  ::mlir::OperationState {2}({1}, getOperationName());
+  build({0}, {2}{3});
+  auto __res__ = ::llvm::dyn_cast<{4}>({0}.create({2}));
+  assert(__res__ && "builder didn't return the right type");
+  return __res__;
+)";
+
 //===----------------------------------------------------------------------===//
 // Utility structs and functions
 //===----------------------------------------------------------------------===//
@@ -665,6 +681,9 @@ class OpEmitter {
   // Generates the build() method that takes each operand/attribute
   // as a stand-alone parameter.
   void genSeparateArgParamBuilder();
+  // Generates a static create() method mirroring the build() method with the
+  // given parameter list.
+  void genInlineCreateBody(const SmallVector<MethodParameter> &paramList);
 
   // Generates the build() method that takes each operand/attribute as a
   // stand-alone parameter. The generated build() method uses first operand's
@@ -2557,6 +2576,58 @@ static bool canInferType(const Operator &op) {
   return op.getTrait("::mlir::InferTypeOpInterface::Trait");
 }
 
+/// Generates a static `create` method mirroring the build() overload whose
+/// parameters are `paramList`. The `OpBuilder &` / `OperationState &`
+/// parameters are replaced by an `OpBuilder &` and a `Location`; all other
+/// parameters (including default values) are kept and forwarded to build().
+/// The generated body creates the op through `OpBuilder::create` and returns
+/// it as the concrete op class.
+void OpEmitter::genInlineCreateBody(
+    const SmallVector<MethodParameter> &paramList) {
+  // Returns `base`, suffixed with '_' until it differs from every build()
+  // parameter name, so that names introduced by create() cannot clash with
+  // op argument names (e.g. an operand called `state` or `location`).
+  auto getUniqueName = [&paramList](llvm::StringRef base) {
+    std::string name = base.str();
+    while (llvm::any_of(paramList, [&name](const MethodParameter &param) {
+      return param.getName() == name;
+    }))
+      name += '_';
+    return name;
+  };
+  std::string builderName = getUniqueName("builder");
+  std::string locationName = getUniqueName("location");
+  std::string stateName = getUniqueName("state");
+
+  SmallVector<MethodParameter> createParamList;
+  createParamList.emplace_back("::mlir::OpBuilder &", builderName);
+  createParamList.emplace_back("::mlir::Location", locationName);
+
+  // Arguments forwarded to build(), each prefixed with ", " so that an op
+  // without further build() parameters yields `build(builder, state)`
+  // instead of an ill-formed `build(builder, state, )`.
+  std::string forwardedArgs;
+  for (const MethodParameter &param : paramList) {
+    // The builder and the state are provided by the generated body itself.
+    if (param.getType() == "::mlir::OpBuilder &" ||
+        param.getType() == "::mlir::OperationState &")
+      continue;
+    createParamList.emplace_back(param.getType(), param.getName(),
+                                 param.getDefaultValue(), param.isOptional());
+    forwardedArgs += ", ";
+    forwardedArgs += param.getName().str();
+  }
+
+  auto *createMethod = opClass.addStaticMethod(opClass.getClassName(),
+                                               "create", createParamList);
+  // As for build(), a redundant method is not added and nullptr is returned.
+  if (!createMethod)
+    return;
+  createMethod->body() << llvm::formatv(inlineCreateBody, builderName,
+                                        locationName, stateName, forwardedArgs,
+                                        opClass.getClassName());
+}
+
 void OpEmitter::genSeparateArgParamBuilder() {
   SmallVector<AttrParamKind, 2> attrBuilderType;
   attrBuilderType.push_back(AttrParamKind::WrappedAttr);
@@ -2573,10 +2605,12 @@ void OpEmitter::genSeparateArgParamBuilder() {
     buildParamList(paramList, inferredAttributes, resultNames, paramKind,
                    attrType);
 
-    auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
+    auto *m = opClass.addStaticMethod("void", "build", paramList);
     // If the builder is redundant, skip generating the method.
     if (!m)
       return;
+    genInlineCreateBody(paramList);
+
     auto &body = m->body();
     genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
                                            /*isRawValueAttr=*/attrType ==
@@ -2701,10 +2735,11 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder(
   if (op.getNumVariadicRegions())
     paramList.emplace_back("unsigned", "numRegions");
 
-  auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
+  auto *m = opClass.addStaticMethod("void", "build", paramList);
   // If the builder is redundant, skip generating the method
   if (!m)
     return;
+  genInlineCreateBody(paramList);
   auto &body = m->body();
 
   // Operands
@@ -2815,10 +2850,11 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder(
   if (op.getNumVariadicRegions())
     paramList.emplace_back("unsigned", "numRegions");
 
-  auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
+  auto *m = opClass.addStaticMethod("void", "build", paramList);
   // If the builder is redundant, skip generating the method
   if (!m)
     return;
+  genInlineCreateBody(paramList);
   auto &body = m->body();
 
   int numResults = op.getNumResults();
@@ -2895,10 +2931,11 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
     buildParamList(paramList, inferredAttributes, resultNames,
                    TypeParamKind::None, attrType);
 
-    auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
+    auto *m = opClass.addStaticMethod("void", "build", paramList);
     // If the builder is redundant, skip generating the method
     if (!m)
       return;
+    genInlineCreateBody(paramList);
     auto &body = m->body();
     genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
                                            /*isRawValueAttr=*/attrType ==
@@ -2937,10 +2974,11 @@ void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(
                                  : "attributes";
   paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
                          attributesName, "{}");
-  auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
+  auto *m = opClass.addStaticMethod("void", "build", paramList);
   // If the builder is redundant, skip generating the method
   if (!m)
     return;
+  genInlineCreateBody(paramList);
 
   auto &body = m->body();
 
@@ -3103,10 +3141,11 @@ void OpEmitter::genCollectiveParamBuilder(CollectiveBuilderKind kind) {
   if (op.getNumVariadicRegions())
     paramList.emplace_back("unsigned", "numRegions");
 
-  auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
+  auto *m = opClass.addStaticMethod("void", "build", paramList);
   // If the builder is redundant, skip generating the method
   if (!m)
     return;
+  genInlineCreateBody(paramList);
   auto &body = m->body();
 
   // Operands



More information about the Mlir-commits mailing list