[Mlir-commits] [mlir] WIP: Add support for regions in irdl-to-cpp (PR #158540)

Jeremy Kun llvmlistbot at llvm.org
Sat Sep 27 12:56:07 PDT 2025


https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/158540

From 397e1945b4f31a175acd5ea12b3451f1c0a31ac1 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sat, 13 Sep 2025 22:46:18 -0700
Subject: [PATCH 1/2] Add support for regions in irdl-to-cpp

---
 mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp       | 124 +++++++++++++++++-
 .../IRDLToCpp/Templates/PerOperationDecl.txt  |  51 +++----
 .../IRDLToCpp/Templates/PerOperationDef.txt   |  12 +-
 ...to_cpp_invalid_unsupported_types.irdl.mlir |  27 +---
 4 files changed, 162 insertions(+), 52 deletions(-)

diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
index d6b8a8a1df426..92a9c59ec9ead 100644
--- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
+++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
@@ -54,6 +54,7 @@ struct OpStrings {
   std::string opCppName;
   SmallVector<std::string> opResultNames;
   SmallVector<std::string> opOperandNames;
+  SmallVector<std::string> opRegionNames;
 };
 
 static std::string joinNameList(llvm::ArrayRef<std::string> names) {
@@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) {
 /// Generates OpStrings from an OperatioOp
 static OpStrings getStrings(irdl::OperationOp op) {
   auto operandOp = op.getOp<irdl::OperandsOp>();
-
   auto resultOp = op.getOp<irdl::ResultsOp>();
+  auto regionsOp = op.getOp<irdl::RegionsOp>();
 
   OpStrings strings;
   strings.opName = op.getSymName();
@@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) {
         }));
   }
 
+  if (regionsOp) {
+    strings.opRegionNames = SmallVector<std::string>(
+        llvm::map_range(regionsOp->getNames(), [](Attribute attr) {
+          return llvm::formatv("{0}", cast<StringAttr>(attr));
+        }));
+  }
+
   return strings;
 }
 
@@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict,
 static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
   const auto operandCount = strings.opOperandNames.size();
   const auto resultCount = strings.opResultNames.size();
+  const auto regionCount = strings.opRegionNames.size();
 
   dict["OP_NAME"] = strings.opName;
   dict["OP_CPP_NAME"] = strings.opCppName;
@@ -131,6 +140,14 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
       operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}";
   dict["OP_RESULT_INITIALIZER_LIST"] =
       resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
+  dict["OP_REGION_COUNT"] = std::to_string(regionCount);
+  dict["OP_ADD_REGIONS"] =
+      regionCount ? std::string(llvm::formatv(
+                        R"(for (unsigned i = 0; i != {0}; ++i)
+  (void)odsState.addRegion();
+)",
+                        regionCount))
+                  : "";
 }
 
 /// Fills a dictionary with values from DialectStrings
@@ -179,6 +196,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
                                          const OpStrings &opStrings) {
   auto opGetters = std::string{};
   auto resGetters = std::string{};
+  auto regionGetters = std::string{};
+  auto regionAdaptorGetters = std::string{};
 
   for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
     const auto op =
@@ -196,8 +215,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
         op, i);
   }
 
+  for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
+    const auto op =
+        llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true);
+    regionAdaptorGetters += llvm::formatv(
+        R"(::mlir::Region &get{0}() { return getRegions()[{1}]; }
+  )",
+        op, i);
+    regionGetters += llvm::formatv(
+        R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
+  )",
+        op, i);
+  }
+
   dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
   dict["OP_RESULT_GETTER_DECLS"] = resGetters;
+  dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters;
+  dict["OP_REGION_GETTER_DECLS"] = regionGetters;
 }
 
 static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
@@ -238,6 +272,21 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
   dict["OP_BUILD_DECLS"] = buildDecls;
 }
 
+// add traits to the dictionary, return true if any were added
+static std::vector<std::string> generateTraits(irdl::detail::dictionary &dict,
+                                               irdl::OperationOp op,
+                                               const OpStrings &strings) {
+  std::vector<std::string> cppTraitNames;
+  if (!strings.opRegionNames.empty()) {
+    cppTraitNames.push_back(
+        llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
+                      strings.opRegionNames.size())
+            .str());
+    cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants");
+  }
+  return cppTraitNames;
+}
+
 static LogicalResult generateOperationInclude(irdl::OperationOp op,
                                               raw_ostream &output,
                                               irdl::detail::dictionary &dict) {
@@ -247,6 +296,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op,
   const auto opStrings = getStrings(op);
   fillDict(dict, opStrings);
 
+  std::vector<std::string> traitNames = generateTraits(dict, op, opStrings);
+  if (traitNames.empty())
+    dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
+  else
+    dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName,
+                                             llvm::join(traitNames, ", "));
+
   generateOpGetterDeclarations(dict, opStrings);
   generateOpBuilderDeclarations(dict, opStrings);
 
@@ -301,6 +357,66 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
   return success();
 }
 
+static void generateVerifiers(irdl::detail::dictionary &dict,
+                              irdl::OperationOp op, const OpStrings &strings) {
+  std::vector<std::string> verifierHelpers;
+  std::vector<std::string> verifierCalls;
+  if (strings.opRegionNames.empty()) {
+    dict["OP_VERIFIER_HELPERS"] = "";
+    dict["OP_VERIFIER"] = "";
+    // Currently IRDL regions are the only reason to generate a verifier,
+    // though this will likely change as
+    // https://github.com/llvm/llvm-project/issues/158040 is implemented
+    return;
+  }
+
+  for (size_t i = 0; i < strings.opRegionNames.size(); ++i) {
+    std::string regionName = strings.opRegionNames[i];
+    std::string helperFnName =
+        llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
+                      strings.opCppName, regionName)
+            .str();
+    std::string condition = "true"; // FIXME: get from irdl region op
+    std::string textualConditionName =
+        "any region"; // FIXME: can use textual irdl for this region?
+
+    verifierHelpers.push_back(llvm::formatv(
+        R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
+    if (!{1}) {{
+      return op->emitOpError("region #") << regionIndex
+          << (regionName.empty() ? " " : " ('" + regionName + "')
+          << "failed to verify constraint: {2}";
+        }}
+    return ::mlir::success();
+        }})",
+        helperFnName, condition, textualConditionName));
+
+    verifierCalls.push_back(llvm::formatv(R"(
+  if (::mlir::failed({0}(*this, (*this)->getRegion({1}), {2}, {1})))
+    return ::mlir::failure();)",
+                                          helperFnName, i, regionName)
+                                .str());
+  }
+
+  // Add an overall verifier that sequences the helper calls
+  std::string verifierDef =
+      llvm::formatv(R"(
+::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
+    {1}
+    return ::mlir::success();
+  }}
+
+::llvm::LogicalResult {0}::verifyInvariants() {{
+  if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify()))
+    return ::mlir::success();
+  return ::mlir::failure();
+}})",
+                    strings.opCppName, llvm::join(verifierCalls, "\n"));
+
+  dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n");
+  dict["OP_VERIFIER"] = verifierDef;
+}
+
 static std::string generateOpDefinition(irdl::detail::dictionary &dict,
                                         irdl::OperationOp op) {
   static const auto perOpDefTemplate = mlir::irdl::detail::Template{
@@ -370,6 +486,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {
 
   dict["OP_BUILD_DEFS"] = buildDefinition;
 
+  generateVerifiers(dict, op, opStrings);
+
   std::string str;
   llvm::raw_string_ostream stream{str};
   perOpDefTemplate.render(stream, dict);
@@ -427,7 +545,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
   dict["TYPE_PARSER"] = llvm::formatv(
       R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
   return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
-    {0}    
+    {0}
     .Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
       *mnemonic = keyword;
       return std::nullopt;
@@ -520,6 +638,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
                   "IRDL C++ translation does not yet support variadic results");
             }))
             .Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
+            .Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
+            .Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
             .Default([](mlir::Operation *op) -> LogicalResult {
               return op->emitError("IRDL C++ translation does not yet support "
                                    "translation of ")
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
index e9068e9488f99..ae900bb818dd3 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
@@ -12,15 +12,15 @@ public:
   struct Properties {
   };
 public:
-  __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op) 
-    : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()), 
-               odsRegions(op->getRegions()) 
+  __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
+    : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
+               odsRegions(op->getRegions())
   {}
 
   /// Return the unstructured operand index of a structured operand along with
   // the amount of unstructured operands it contains.
   std::pair<unsigned, unsigned>
-  getStructuredOperandIndexAndLength (unsigned index, 
+  getStructuredOperandIndexAndLength (unsigned index,
                                       unsigned odsOperandsSize) {
     return {index, 1};
   }
@@ -32,6 +32,12 @@ public:
   ::mlir::DictionaryAttr getAttributes() {
     return odsAttrs;
   }
+
+  __OP_REGION_ADAPTER_GETTER_DECLS__
+
+  ::mlir::RegionRange getRegions() {
+    return odsRegions;
+  }
 protected:
   ::mlir::DictionaryAttr odsAttrs;
   ::std::optional<::mlir::OperationName> odsOpName;
@@ -42,28 +48,28 @@ protected:
 } // namespace detail
 
 template <typename RangeT>
-class __OP_CPP_NAME__GenericAdaptor 
+class __OP_CPP_NAME__GenericAdaptor
   : public detail::__OP_CPP_NAME__GenericAdaptorBase {
   using ValueT = ::llvm::detail::ValueOfRange<RangeT>;
   using Base = detail::__OP_CPP_NAME__GenericAdaptorBase;
 public:
   __OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs,
-                                ::mlir::OpaqueProperties properties, 
-                                ::mlir::RegionRange regions = {}) 
-    : __OP_CPP_NAME__GenericAdaptor(values, attrs, 
-      (properties ? *properties.as<::mlir::EmptyProperties *>() 
+                                ::mlir::OpaqueProperties properties,
+                                ::mlir::RegionRange regions = {})
+    : __OP_CPP_NAME__GenericAdaptor(values, attrs,
+      (properties ? *properties.as<::mlir::EmptyProperties *>()
       : ::mlir::EmptyProperties{}), regions) {}
 
-  __OP_CPP_NAME__GenericAdaptor(RangeT values, 
+  __OP_CPP_NAME__GenericAdaptor(RangeT values,
                                 const __OP_CPP_NAME__GenericAdaptorBase &base)
     : Base(base), odsOperands(values) {}
 
-  // This template parameter allows using __OP_CPP_NAME__ which is declared 
+  // This template parameter allows using __OP_CPP_NAME__ which is declared
   // later.
   template <typename LateInst = __OP_CPP_NAME__,
             typename = std::enable_if_t<
                          std::is_same_v<LateInst, __OP_CPP_NAME__>>>
-  __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op) 
+  __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
     : Base(op), odsOperands(values) {}
 
   /// Return the unstructured operand index of a structured operand along with
@@ -77,7 +83,7 @@ public:
   RangeT getStructuredOperands(unsigned index) {
     auto valueRange = getStructuredOperandIndexAndLength(index);
     return {std::next(odsOperands.begin(), valueRange.first),
-             std::next(odsOperands.begin(), 
+             std::next(odsOperands.begin(),
                        valueRange.first + valueRange.second)};
   }
 
@@ -91,7 +97,7 @@ private:
   RangeT odsOperands;
 };
 
-class __OP_CPP_NAME__Adaptor 
+class __OP_CPP_NAME__Adaptor
   : public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> {
 public:
   using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor;
@@ -100,7 +106,7 @@ public:
   ::llvm::LogicalResult verify(::mlir::Location loc);
 };
 
-class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> {
+class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> {
 public:
   using Op::Op;
   using Op::print;
@@ -147,7 +153,7 @@ public:
   ::mlir::Operation::operand_range getStructuredOperands(unsigned index) {
     auto valueRange = getStructuredOperandIndexAndLength(index);
     return {std::next(getOperation()->operand_begin(), valueRange.first),
-             std::next(getOperation()->operand_begin(), 
+             std::next(getOperation()->operand_begin(),
                        valueRange.first + valueRange.second)};
   }
 
@@ -162,18 +168,19 @@ public:
   ::mlir::Operation::result_range getStructuredResults(unsigned index) {
     auto valueRange = getStructuredResultIndexAndLength(index);
     return {std::next(getOperation()->result_begin(), valueRange.first),
-             std::next(getOperation()->result_begin(), 
+             std::next(getOperation()->result_begin(),
                        valueRange.first + valueRange.second)};
   }
 
   __OP_OPERAND_GETTER_DECLS__
   __OP_RESULT_GETTER_DECLS__
-  
+  __OP_REGION_GETTER_DECLS__
+
   __OP_BUILD_DECLS__
-  static void build(::mlir::OpBuilder &odsBuilder, 
-                    ::mlir::OperationState &odsState, 
-                    ::mlir::TypeRange resultTypes, 
-                    ::mlir::ValueRange operands, 
+  static void build(::mlir::OpBuilder &odsBuilder,
+                    ::mlir::OperationState &odsState,
+                    ::mlir::TypeRange resultTypes,
+                    ::mlir::ValueRange operands,
                     ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
   static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder,
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
index 30ca420d77448..840e497560e6b 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
@@ -6,12 +6,14 @@ R"(
 
 __NAMESPACE_OPEN__
 
+__OP_VERIFIER_HELPERS__
+
 __OP_BUILD_DEFS__
 
-void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, 
-                            ::mlir::OperationState &odsState, 
-                            ::mlir::TypeRange resultTypes, 
-                            ::mlir::ValueRange operands, 
+void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
+                            ::mlir::OperationState &odsState,
+                            ::mlir::TypeRange resultTypes,
+                            ::mlir::ValueRange operands,
                             ::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
 {
   assert(operands.size() == __OP_OPERAND_COUNT__);
@@ -19,6 +21,7 @@ void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
   odsState.addOperands(operands);
   odsState.addAttributes(attributes);
   odsState.addTypes(resultTypes);
+  __OP_ADD_REGIONS__
 }
 
 __OP_CPP_NAME__
@@ -44,6 +47,7 @@ __OP_CPP_NAME__::create(::mlir::ImplicitLocOpBuilder &odsBuilder,
   return create(odsBuilder, odsBuilder.getLoc(), resultTypes, operands, attributes);
 }
 
+__OP_VERIFIER__
 
 __NAMESPACE_CLOSE__
 
diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir
index 403b49235467c..cc2745643db7e 100644
--- a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir
+++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir
@@ -7,7 +7,7 @@ irdl.dialect @test_irdl_to_cpp {
     irdl.results(res: %1)
   }
 }
-// ----- 
+// -----
 
 irdl.dialect @test_irdl_to_cpp {
   irdl.operation @operands_no_any_of {
@@ -42,7 +42,7 @@ irdl.dialect @test_irdl_to_cpp {
 
 irdl.dialect @test_irdl_to_cpp {
   irdl.type @ty {
-    %0 = irdl.any 
+    %0 = irdl.any
     // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.parameters operation}}
     irdl.parameters(ty: %0)
   }
@@ -50,30 +50,9 @@ irdl.dialect @test_irdl_to_cpp {
 
 // -----
 
-irdl.dialect @test_irdl_to_cpp {
-  irdl.operation @test_op {
-    // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.region operation}}
-    %0 = irdl.region()
-    irdl.regions(reg: %0)
-  }
-  
-}
-
-// -----
-
-irdl.dialect @test_irdl_to_cpp {
-  irdl.operation @test_op {
-    // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.regions operation}}
-    irdl.regions()
-  }
-  
-}
-
-// -----
-
 irdl.dialect @test_irdl_to_cpp {
   irdl.type @test_derived {
     // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.base operation}}
     %0 = irdl.base "!builtin.integer"
-  }    
+  }
 }

From 10d6418b13781ef70f41ebea0427a83fcbb4a5ad Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Fri, 26 Sep 2025 22:55:52 -0700
Subject: [PATCH 2/2] Get region constraints from IRDL ops

---
 mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp | 73 ++++++++++++++++++++-----
 1 file changed, 58 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
index 92a9c59ec9ead..628c15cd11eb8 100644
--- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
+++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
@@ -141,13 +141,13 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
   dict["OP_RESULT_INITIALIZER_LIST"] =
       resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
   dict["OP_REGION_COUNT"] = std::to_string(regionCount);
-  dict["OP_ADD_REGIONS"] =
-      regionCount ? std::string(llvm::formatv(
-                        R"(for (unsigned i = 0; i != {0}; ++i)
+  dict["OP_ADD_REGIONS"] = regionCount
+                               ? std::string(llvm::formatv(
+                                     R"(for (unsigned i = 0; i != {0}; ++i)
   (void)odsState.addRegion();
 )",
-                        regionCount))
-                  : "";
+                                     regionCount))
+                               : "";
 }
 
 /// Fills a dictionary with values from DialectStrings
@@ -273,10 +273,10 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
 }
 
 // add traits to the dictionary, return true if any were added
-static std::vector<std::string> generateTraits(irdl::detail::dictionary &dict,
+static SmallVector<std::string> generateTraits(irdl::detail::dictionary &dict,
                                                irdl::OperationOp op,
                                                const OpStrings &strings) {
-  std::vector<std::string> cppTraitNames;
+  SmallVector<std::string> cppTraitNames;
   if (!strings.opRegionNames.empty()) {
     cppTraitNames.push_back(
         llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
@@ -296,7 +296,7 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op,
   const auto opStrings = getStrings(op);
   fillDict(dict, opStrings);
 
-  std::vector<std::string> traitNames = generateTraits(dict, op, opStrings);
+  SmallVector<std::string> traitNames = generateTraits(dict, op, opStrings);
   if (traitNames.empty())
     dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
   else
@@ -359,9 +359,10 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
 
 static void generateVerifiers(irdl::detail::dictionary &dict,
                               irdl::OperationOp op, const OpStrings &strings) {
-  std::vector<std::string> verifierHelpers;
-  std::vector<std::string> verifierCalls;
-  if (strings.opRegionNames.empty()) {
+  SmallVector<std::string> verifierHelpers;
+  SmallVector<std::string> verifierCalls;
+  auto regionsOp = op.getOp<irdl::RegionsOp>();
+  if (strings.opRegionNames.empty() || !regionsOp) {
     dict["OP_VERIFIER_HELPERS"] = "";
     dict["OP_VERIFIER"] = "";
     // Currently IRDL regions are the only reason to generate a verifier,
@@ -376,13 +377,55 @@ static void generateVerifiers(irdl::detail::dictionary &dict,
         llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
                       strings.opCppName, regionName)
             .str();
-    std::string condition = "true"; // FIXME: get from irdl region op
-    std::string textualConditionName =
-        "any region"; // FIXME: can use textual irdl for this region?
+
+    // Extract the actual region constraint from the IRDL RegionOp
+    std::string condition = "true";
+    std::string textualConditionName = "any region";
+
+    if (auto regionDefOp =
+            dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) {
+      // Generate constraint condition based on RegionOp attributes
+      SmallVector<std::string> conditionParts;
+      SmallVector<std::string> descriptionParts;
+
+      // Check number of blocks constraint
+      if (auto blockCount = regionDefOp.getNumberOfBlocks()) {
+        conditionParts.push_back(
+            llvm::formatv("region.getBlocks().size() == {0}",
+                          blockCount.value())
+                .str());
+        descriptionParts.push_back(
+            llvm::formatv("exactly {0} block(s)", blockCount.value()).str());
+      }
+
+      // Check entry block arguments constraint
+      if (regionDefOp.getConstrainedArguments()) {
+        size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size();
+        conditionParts.push_back(
+            llvm::formatv("region.getNumArguments() == {0}", expectedArgCount)
+                .str());
+        descriptionParts.push_back(
+            llvm::formatv("{0} entry block argument(s)", expectedArgCount)
+                .str());
+      }
+
+      // Combine conditions
+      if (!conditionParts.empty()) {
+        condition = llvm::join(conditionParts, " && ");
+      }
+
+      // Generate descriptive error message
+      if (!descriptionParts.empty()) {
+        textualConditionName =
+            llvm::formatv("region with {0}",
+                          llvm::join(descriptionParts, " and "))
+                .str();
+      }
+    }
 
     verifierHelpers.push_back(llvm::formatv(
         R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
-    if (!{1}) {{
+    if (!({1})) {{
       return op->emitOpError("region #") << regionIndex
           << (regionName.empty() ? " " : " ('" + regionName + "')
           << "failed to verify constraint: {2}";



More information about the Mlir-commits mailing list