[Mlir-commits] [mlir] WIP: Add support for regions in irdl-to-cpp (PR #158540)
Jeremy Kun
llvmlistbot at llvm.org
Sat Sep 27 12:56:07 PDT 2025
https://github.com/j2kun updated https://github.com/llvm/llvm-project/pull/158540
>From 397e1945b4f31a175acd5ea12b3451f1c0a31ac1 Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Sat, 13 Sep 2025 22:46:18 -0700
Subject: [PATCH 1/2] Add support for regions in irdl-to-cpp
---
mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp | 124 +++++++++++++++++-
.../IRDLToCpp/Templates/PerOperationDecl.txt | 51 +++----
.../IRDLToCpp/Templates/PerOperationDef.txt | 12 +-
...to_cpp_invalid_unsupported_types.irdl.mlir | 27 +---
4 files changed, 162 insertions(+), 52 deletions(-)
diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
index d6b8a8a1df426..92a9c59ec9ead 100644
--- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
+++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
@@ -54,6 +54,7 @@ struct OpStrings {
std::string opCppName;
SmallVector<std::string> opResultNames;
SmallVector<std::string> opOperandNames;
+ SmallVector<std::string> opRegionNames;
};
static std::string joinNameList(llvm::ArrayRef<std::string> names) {
@@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) {
/// Generates OpStrings from an OperatioOp
static OpStrings getStrings(irdl::OperationOp op) {
auto operandOp = op.getOp<irdl::OperandsOp>();
-
auto resultOp = op.getOp<irdl::ResultsOp>();
+ auto regionsOp = op.getOp<irdl::RegionsOp>();
OpStrings strings;
strings.opName = op.getSymName();
@@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) {
}));
}
+ if (regionsOp) {
+ strings.opRegionNames = SmallVector<std::string>(
+ llvm::map_range(regionsOp->getNames(), [](Attribute attr) {
+ return llvm::formatv("{0}", cast<StringAttr>(attr));
+ }));
+ }
+
return strings;
}
@@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict,
static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
const auto operandCount = strings.opOperandNames.size();
const auto resultCount = strings.opResultNames.size();
+ const auto regionCount = strings.opRegionNames.size();
dict["OP_NAME"] = strings.opName;
dict["OP_CPP_NAME"] = strings.opCppName;
@@ -131,6 +140,14 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}";
dict["OP_RESULT_INITIALIZER_LIST"] =
resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
+ dict["OP_REGION_COUNT"] = std::to_string(regionCount);
+ dict["OP_ADD_REGIONS"] =
+ regionCount ? std::string(llvm::formatv(
+ R"(for (unsigned i = 0; i != {0}; ++i)
+ (void)odsState.addRegion();
+)",
+ regionCount))
+ : "";
}
/// Fills a dictionary with values from DialectStrings
@@ -179,6 +196,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
const OpStrings &opStrings) {
auto opGetters = std::string{};
auto resGetters = std::string{};
+ auto regionGetters = std::string{};
+ auto regionAdaptorGetters = std::string{};
for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
const auto op =
@@ -196,8 +215,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
op, i);
}
+ for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
+ const auto op =
+ llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true);
+ regionAdaptorGetters += llvm::formatv(
+ R"(::mlir::Region &get{0}() { return getRegions()[{1}]; }
+ )",
+ op, i);
+ regionGetters += llvm::formatv(
+ R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
+ )",
+ op, i);
+ }
+
dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
dict["OP_RESULT_GETTER_DECLS"] = resGetters;
+ dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters;
+ dict["OP_REGION_GETTER_DECLS"] = regionGetters;
}
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
@@ -238,6 +272,21 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
dict["OP_BUILD_DECLS"] = buildDecls;
}
+// add traits to the dictionary, return true if any were added
+static std::vector<std::string> generateTraits(irdl::detail::dictionary &dict,
+ irdl::OperationOp op,
+ const OpStrings &strings) {
+ std::vector<std::string> cppTraitNames;
+ if (!strings.opRegionNames.empty()) {
+ cppTraitNames.push_back(
+ llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
+ strings.opRegionNames.size())
+ .str());
+ cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants");
+ }
+ return cppTraitNames;
+}
+
static LogicalResult generateOperationInclude(irdl::OperationOp op,
raw_ostream &output,
irdl::detail::dictionary &dict) {
@@ -247,6 +296,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op,
const auto opStrings = getStrings(op);
fillDict(dict, opStrings);
+ std::vector<std::string> traitNames = generateTraits(dict, op, opStrings);
+ if (traitNames.empty())
+ dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
+ else
+ dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName,
+ llvm::join(traitNames, ", "));
+
generateOpGetterDeclarations(dict, opStrings);
generateOpBuilderDeclarations(dict, opStrings);
@@ -301,6 +357,66 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
return success();
}
+static void generateVerifiers(irdl::detail::dictionary &dict,
+ irdl::OperationOp op, const OpStrings &strings) {
+ std::vector<std::string> verifierHelpers;
+ std::vector<std::string> verifierCalls;
+ if (strings.opRegionNames.empty()) {
+ dict["OP_VERIFIER_HELPERS"] = "";
+ dict["OP_VERIFIER"] = "";
+ // Currently IRDL regions are the only reason to generate a verifier,
+ // though this will likely change as
+ // https://github.com/llvm/llvm-project/issues/158040 is implemented
+ return;
+ }
+
+ for (size_t i = 0; i < strings.opRegionNames.size(); ++i) {
+ std::string regionName = strings.opRegionNames[i];
+ std::string helperFnName =
+ llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
+ strings.opCppName, regionName)
+ .str();
+ std::string condition = "true"; // FIXME: get from irdl region op
+ std::string textualConditionName =
+ "any region"; // FIXME: can use textual irdl for this region?
+
+ verifierHelpers.push_back(llvm::formatv(
R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
+ if (!{1}) {{
+ return op->emitOpError("region #") << regionIndex
+ << (regionName.empty() ? " " : " ('" + regionName + "')
+ << "failed to verify constraint: {2}";
+ }}
+ return ::mlir::success();
+ }})",
+ helperFnName, condition, textualConditionName));
+
+ verifierCalls.push_back(llvm::formatv(R"(
+ if (::mlir::failed({0}(*this, (*this)->getRegion({1}), {2}, {1})))
+ return ::mlir::failure();)",
+ helperFnName, i, regionName)
+ .str());
+ }
+
+ // Add an overall verifier that sequences the helper calls
+ std::string verifierDef =
+ llvm::formatv(R"(
+::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
+ {1}
+ return ::mlir::success();
+ }}
+
+::llvm::LogicalResult {0}::verifyInvariants() {{
+ if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify()))
+ return ::mlir::success();
+ return ::mlir::failure();
+}})",
+ strings.opCppName, llvm::join(verifierCalls, "\n"));
+
+ dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n");
+ dict["OP_VERIFIER"] = verifierDef;
+}
+
static std::string generateOpDefinition(irdl::detail::dictionary &dict,
irdl::OperationOp op) {
static const auto perOpDefTemplate = mlir::irdl::detail::Template{
@@ -370,6 +486,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {
dict["OP_BUILD_DEFS"] = buildDefinition;
+ generateVerifiers(dict, op, opStrings);
+
std::string str;
llvm::raw_string_ostream stream{str};
perOpDefTemplate.render(stream, dict);
@@ -427,7 +545,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
dict["TYPE_PARSER"] = llvm::formatv(
R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
- {0}
+ {0}
.Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
*mnemonic = keyword;
return std::nullopt;
@@ -520,6 +638,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
"IRDL C++ translation does not yet support variadic results");
}))
.Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
+ .Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
+ .Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
.Default([](mlir::Operation *op) -> LogicalResult {
return op->emitError("IRDL C++ translation does not yet support "
"translation of ")
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
index e9068e9488f99..ae900bb818dd3 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
@@ -12,15 +12,15 @@ public:
struct Properties {
};
public:
- __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
- : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
- odsRegions(op->getRegions())
+ __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
+ : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
+ odsRegions(op->getRegions())
{}
/// Return the unstructured operand index of a structured operand along with
// the amount of unstructured operands it contains.
std::pair<unsigned, unsigned>
- getStructuredOperandIndexAndLength (unsigned index,
+ getStructuredOperandIndexAndLength (unsigned index,
unsigned odsOperandsSize) {
return {index, 1};
}
@@ -32,6 +32,12 @@ public:
::mlir::DictionaryAttr getAttributes() {
return odsAttrs;
}
+
+ __OP_REGION_ADAPTER_GETTER_DECLS__
+
+ ::mlir::RegionRange getRegions() {
+ return odsRegions;
+ }
protected:
::mlir::DictionaryAttr odsAttrs;
::std::optional<::mlir::OperationName> odsOpName;
@@ -42,28 +48,28 @@ protected:
} // namespace detail
template <typename RangeT>
-class __OP_CPP_NAME__GenericAdaptor
+class __OP_CPP_NAME__GenericAdaptor
: public detail::__OP_CPP_NAME__GenericAdaptorBase {
using ValueT = ::llvm::detail::ValueOfRange<RangeT>;
using Base = detail::__OP_CPP_NAME__GenericAdaptorBase;
public:
__OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs,
- ::mlir::OpaqueProperties properties,
- ::mlir::RegionRange regions = {})
- : __OP_CPP_NAME__GenericAdaptor(values, attrs,
- (properties ? *properties.as<::mlir::EmptyProperties *>()
+ ::mlir::OpaqueProperties properties,
+ ::mlir::RegionRange regions = {})
+ : __OP_CPP_NAME__GenericAdaptor(values, attrs,
+ (properties ? *properties.as<::mlir::EmptyProperties *>()
: ::mlir::EmptyProperties{}), regions) {}
- __OP_CPP_NAME__GenericAdaptor(RangeT values,
+ __OP_CPP_NAME__GenericAdaptor(RangeT values,
const __OP_CPP_NAME__GenericAdaptorBase &base)
: Base(base), odsOperands(values) {}
- // This template parameter allows using __OP_CPP_NAME__ which is declared
+ // This template parameter allows using __OP_CPP_NAME__ which is declared
// later.
template <typename LateInst = __OP_CPP_NAME__,
typename = std::enable_if_t<
std::is_same_v<LateInst, __OP_CPP_NAME__>>>
- __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
+ __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
: Base(op), odsOperands(values) {}
/// Return the unstructured operand index of a structured operand along with
@@ -77,7 +83,7 @@ public:
RangeT getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(odsOperands.begin(), valueRange.first),
- std::next(odsOperands.begin(),
+ std::next(odsOperands.begin(),
valueRange.first + valueRange.second)};
}
@@ -91,7 +97,7 @@ private:
RangeT odsOperands;
};
-class __OP_CPP_NAME__Adaptor
+class __OP_CPP_NAME__Adaptor
: public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> {
public:
using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor;
@@ -100,7 +106,7 @@ public:
::llvm::LogicalResult verify(::mlir::Location loc);
};
-class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> {
+class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> {
public:
using Op::Op;
using Op::print;
@@ -147,7 +153,7 @@ public:
::mlir::Operation::operand_range getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(getOperation()->operand_begin(), valueRange.first),
- std::next(getOperation()->operand_begin(),
+ std::next(getOperation()->operand_begin(),
valueRange.first + valueRange.second)};
}
@@ -162,18 +168,19 @@ public:
::mlir::Operation::result_range getStructuredResults(unsigned index) {
auto valueRange = getStructuredResultIndexAndLength(index);
return {std::next(getOperation()->result_begin(), valueRange.first),
- std::next(getOperation()->result_begin(),
+ std::next(getOperation()->result_begin(),
valueRange.first + valueRange.second)};
}
__OP_OPERAND_GETTER_DECLS__
__OP_RESULT_GETTER_DECLS__
-
+ __OP_REGION_GETTER_DECLS__
+
__OP_BUILD_DECLS__
- static void build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState,
- ::mlir::TypeRange resultTypes,
- ::mlir::ValueRange operands,
+ static void build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState,
+ ::mlir::TypeRange resultTypes,
+ ::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder,
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
index 30ca420d77448..840e497560e6b 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
@@ -6,12 +6,14 @@ R"(
__NAMESPACE_OPEN__
+__OP_VERIFIER_HELPERS__
+
__OP_BUILD_DEFS__
-void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState,
- ::mlir::TypeRange resultTypes,
- ::mlir::ValueRange operands,
+void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState,
+ ::mlir::TypeRange resultTypes,
+ ::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
{
assert(operands.size() == __OP_OPERAND_COUNT__);
@@ -19,6 +21,7 @@ void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
odsState.addOperands(operands);
odsState.addAttributes(attributes);
odsState.addTypes(resultTypes);
+ __OP_ADD_REGIONS__
}
__OP_CPP_NAME__
@@ -44,6 +47,7 @@ __OP_CPP_NAME__::create(::mlir::ImplicitLocOpBuilder &odsBuilder,
return create(odsBuilder, odsBuilder.getLoc(), resultTypes, operands, attributes);
}
+__OP_VERIFIER__
__NAMESPACE_CLOSE__
diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir
index 403b49235467c..cc2745643db7e 100644
--- a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir
+++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir
@@ -7,7 +7,7 @@ irdl.dialect @test_irdl_to_cpp {
irdl.results(res: %1)
}
}
-// -----
+// -----
irdl.dialect @test_irdl_to_cpp {
irdl.operation @operands_no_any_of {
@@ -42,7 +42,7 @@ irdl.dialect @test_irdl_to_cpp {
irdl.dialect @test_irdl_to_cpp {
irdl.type @ty {
- %0 = irdl.any
+ %0 = irdl.any
// expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.parameters operation}}
irdl.parameters(ty: %0)
}
@@ -50,30 +50,9 @@ irdl.dialect @test_irdl_to_cpp {
// -----
-irdl.dialect @test_irdl_to_cpp {
- irdl.operation @test_op {
- // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.region operation}}
- %0 = irdl.region()
- irdl.regions(reg: %0)
- }
-
-}
-
-// -----
-
-irdl.dialect @test_irdl_to_cpp {
- irdl.operation @test_op {
- // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.regions operation}}
- irdl.regions()
- }
-
-}
-
-// -----
-
irdl.dialect @test_irdl_to_cpp {
irdl.type @test_derived {
// expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.base operation}}
%0 = irdl.base "!builtin.integer"
- }
+ }
}
>From 10d6418b13781ef70f41ebea0427a83fcbb4a5ad Mon Sep 17 00:00:00 2001
From: Jeremy Kun <j2kun at users.noreply.github.com>
Date: Fri, 26 Sep 2025 22:55:52 -0700
Subject: [PATCH 2/2] Get region constraints from IRDL ops
---
mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp | 73 ++++++++++++++++++++-----
1 file changed, 58 insertions(+), 15 deletions(-)
diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
index 92a9c59ec9ead..628c15cd11eb8 100644
--- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
+++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
@@ -141,13 +141,13 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
dict["OP_RESULT_INITIALIZER_LIST"] =
resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
dict["OP_REGION_COUNT"] = std::to_string(regionCount);
- dict["OP_ADD_REGIONS"] =
- regionCount ? std::string(llvm::formatv(
- R"(for (unsigned i = 0; i != {0}; ++i)
+ dict["OP_ADD_REGIONS"] = regionCount
+ ? std::string(llvm::formatv(
+ R"(for (unsigned i = 0; i != {0}; ++i)
(void)odsState.addRegion();
)",
- regionCount))
- : "";
+ regionCount))
+ : "";
}
/// Fills a dictionary with values from DialectStrings
@@ -273,10 +273,10 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
}
// add traits to the dictionary, return true if any were added
-static std::vector<std::string> generateTraits(irdl::detail::dictionary &dict,
+static SmallVector<std::string> generateTraits(irdl::detail::dictionary &dict,
irdl::OperationOp op,
const OpStrings &strings) {
- std::vector<std::string> cppTraitNames;
+ SmallVector<std::string> cppTraitNames;
if (!strings.opRegionNames.empty()) {
cppTraitNames.push_back(
llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
@@ -296,7 +296,7 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op,
const auto opStrings = getStrings(op);
fillDict(dict, opStrings);
- std::vector<std::string> traitNames = generateTraits(dict, op, opStrings);
+ SmallVector<std::string> traitNames = generateTraits(dict, op, opStrings);
if (traitNames.empty())
dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
else
@@ -359,9 +359,10 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
static void generateVerifiers(irdl::detail::dictionary &dict,
irdl::OperationOp op, const OpStrings &strings) {
- std::vector<std::string> verifierHelpers;
- std::vector<std::string> verifierCalls;
- if (strings.opRegionNames.empty()) {
+ SmallVector<std::string> verifierHelpers;
+ SmallVector<std::string> verifierCalls;
+ auto regionsOp = op.getOp<irdl::RegionsOp>();
+ if (strings.opRegionNames.empty() || !regionsOp) {
dict["OP_VERIFIER_HELPERS"] = "";
dict["OP_VERIFIER"] = "";
// Currently IRDL regions are the only reason to generate a verifier,
@@ -376,13 +377,55 @@ static void generateVerifiers(irdl::detail::dictionary &dict,
llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
strings.opCppName, regionName)
.str();
- std::string condition = "true"; // FIXME: get from irdl region op
- std::string textualConditionName =
- "any region"; // FIXME: can use textual irdl for this region?
+
+ // Extract the actual region constraint from the IRDL RegionOp
+ std::string condition = "true";
+ std::string textualConditionName = "any region";
+
+ if (auto regionDefOp =
+ dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) {
+ // Generate constraint condition based on RegionOp attributes
+ SmallVector<std::string> conditionParts;
+ SmallVector<std::string> descriptionParts;
+
+ // Check number of blocks constraint
+ if (auto blockCount = regionDefOp.getNumberOfBlocks()) {
+ conditionParts.push_back(
+ llvm::formatv("region.getBlocks().size() == {0}",
+ blockCount.value())
+ .str());
+ descriptionParts.push_back(
+ llvm::formatv("exactly {0} block(s)", blockCount.value()).str());
+ }
+
+ // Check entry block arguments constraint
+ if (regionDefOp.getConstrainedArguments()) {
+ size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size();
+ conditionParts.push_back(
+ llvm::formatv("region.getNumArguments() == {0}", expectedArgCount)
+ .str());
+ descriptionParts.push_back(
+ llvm::formatv("{0} entry block argument(s)", expectedArgCount)
+ .str());
+ }
+
+ // Combine conditions
+ if (!conditionParts.empty()) {
+ condition = llvm::join(conditionParts, " && ");
+ }
+
+ // Generate descriptive error message
+ if (!descriptionParts.empty()) {
+ textualConditionName =
+ llvm::formatv("region with {0}",
+ llvm::join(descriptionParts, " and "))
+ .str();
+ }
+ }
verifierHelpers.push_back(llvm::formatv(
R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
- if (!{1}) {{
+ if (!({1})) {{
return op->emitOpError("region #") << regionIndex
<< (regionName.empty() ? " " : " ('" + regionName + "')
<< "failed to verify constraint: {2}";
More information about the Mlir-commits
mailing list