[Mlir-commits] [mlir] 99d8590 - [mlir] [irdl] Add support for regions in irdl-to-cpp (#158540)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 2 10:56:44 PDT 2025
Author: Jeremy Kun
Date: 2025-10-02T10:56:40-07:00
New Revision: 99d85906c542c3801a24137ba6d6f2c367308563
URL: https://github.com/llvm/llvm-project/commit/99d85906c542c3801a24137ba6d6f2c367308563
DIFF: https://github.com/llvm/llvm-project/commit/99d85906c542c3801a24137ba6d6f2c367308563.diff
LOG: [mlir] [irdl] Add support for regions in irdl-to-cpp (#158540)
Fixes https://github.com/llvm/llvm-project/issues/158034
For the input
```mlir
irdl.dialect @conditional_dialect {
// A conditional operation with regions
irdl.operation @conditional {
// Create region constraints
%r0 = irdl.region // Unconstrained region
%r1 = irdl.region() // Region with no entry block arguments
%v0 = irdl.any
%r2 = irdl.region(%v0) // Region with one i1 entry block argument
irdl.regions(cond: %r2, then: %r0, else: %r1)
}
}
```
This produces the following cpp:
https://gist.github.com/j2kun/d2095f108efbd8d403576d5c460e0c00
Summary of changes:
- The op class and adaptor get named accessors to the regions `Region
&get<RegionName>()` and `getRegions()`
- The op now gets `OpTrait::NRegions<3>` and `OpInvariants` to trigger
the region verification
- Support for region entry-block argument constraints is added, but it
does not work for all constraints until codegen for `irdl.is` is added
(filed https://github.com/llvm/llvm-project/issues/161018 and left a TODO).
- Helper functions for the individual verification steps are added,
following mlir-tblgen's format (in the above gist,
`__mlir_irdl_local_region_constraint_ConditionalOp_cond` and similar),
and `verifyInvariantsImpl` that calls them.
- Regions are added in the builder
## Questions for the reviewer
### What is the "correct" interface for verification?
I used `mlir-tblgen` on an analogous version of the example
`ConditionalOp` in this PR, and I see an `::mlir::OpTrait::OpInvariants`
trait as well as
```cpp
::llvm::LogicalResult ConditionalOp::verifyInvariantsImpl() {
{
unsigned index = 0; (void)index;
for (auto &region : ::llvm::MutableArrayRef((*this)->getRegion(0)))
if (::mlir::failed(__mlir_ods_local_region_constraint_test1(*this, region, "cond", index++)))
return ::mlir::failure();
for (auto &region : ::llvm::MutableArrayRef((*this)->getRegion(1)))
if (::mlir::failed(__mlir_ods_local_region_constraint_test1(*this, region, "then", index++)))
return ::mlir::failure();
for (auto &region : ::llvm::MutableArrayRef((*this)->getRegion(2)))
if (::mlir::failed(__mlir_ods_local_region_constraint_test1(*this, region, "else", index++)))
return ::mlir::failure();
}
return ::mlir::success();
}
::llvm::LogicalResult ConditionalOp::verifyInvariants() {
if(::mlir::succeeded(verifyInvariantsImpl()) && ::mlir::succeeded(verify()))
return ::mlir::success();
return ::mlir::failure();
}
```
However, `OpInvariants` only seems to need `verifyInvariantsImpl`, so
it's not clear to me what the purpose of the `verifyInvariants`
function is, or, if I leave out `verifyInvariants`, whether I need to call
`verify()` in my implementation of `verifyInvariantsImpl`. In this PR, I
omitted `verifyInvariants` and generated only `verifyInvariantsImpl`, which
calls `verify()` itself before running the region constraint checks.
### Is testing sufficient?
I am not certain I implemented the builders properly, and it's unclear
to me to what extent the existing tests check this (which look like they
compile the generated cpp, but don't actually use it). Did I omit some
standard function or overload?
---------
Co-authored-by: Jeremy Kun <j2kun at users.noreply.github.com>
Added:
Modified:
mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt
mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp
mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir
mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir
mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
index d6b8a8a1df426..e3f075fcc1294 100644
--- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
+++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
@@ -54,6 +54,7 @@ struct OpStrings {
std::string opCppName;
SmallVector<std::string> opResultNames;
SmallVector<std::string> opOperandNames;
+ SmallVector<std::string> opRegionNames;
};
static std::string joinNameList(llvm::ArrayRef<std::string> names) {
@@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) {
/// Generates OpStrings from an OperatioOp
static OpStrings getStrings(irdl::OperationOp op) {
auto operandOp = op.getOp<irdl::OperandsOp>();
-
auto resultOp = op.getOp<irdl::ResultsOp>();
+ auto regionsOp = op.getOp<irdl::RegionsOp>();
OpStrings strings;
strings.opName = op.getSymName();
@@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) {
}));
}
+ if (regionsOp) {
+ strings.opRegionNames = SmallVector<std::string>(
+ llvm::map_range(regionsOp->getNames(), [](Attribute attr) {
+ return llvm::formatv("{0}", cast<StringAttr>(attr));
+ }));
+ }
+
return strings;
}
@@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict,
static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
const auto operandCount = strings.opOperandNames.size();
const auto resultCount = strings.opResultNames.size();
+ const auto regionCount = strings.opRegionNames.size();
dict["OP_NAME"] = strings.opName;
dict["OP_CPP_NAME"] = strings.opCppName;
@@ -131,6 +140,7 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}";
dict["OP_RESULT_INITIALIZER_LIST"] =
resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
+ dict["OP_REGION_COUNT"] = std::to_string(regionCount);
}
/// Fills a dictionary with values from DialectStrings
@@ -179,6 +189,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
const OpStrings &opStrings) {
auto opGetters = std::string{};
auto resGetters = std::string{};
+ auto regionGetters = std::string{};
+ auto regionAdaptorGetters = std::string{};
for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
const auto op =
@@ -196,8 +208,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
op, i);
}
+ for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
+ const auto op =
+ llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true);
+ regionAdaptorGetters += llvm::formatv(
+ R"(::mlir::Region &get{0}() { return *getRegions()[{1}]; }
+ )",
+ op, i);
+ regionGetters += llvm::formatv(
+ R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
+ )",
+ op, i);
+ }
+
dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
dict["OP_RESULT_GETTER_DECLS"] = resGetters;
+ dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters;
+ dict["OP_REGION_GETTER_DECLS"] = regionGetters;
}
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
@@ -238,6 +265,22 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
dict["OP_BUILD_DECLS"] = buildDecls;
}
+// add traits to the dictionary, return true if any were added
+static SmallVector<std::string> generateTraits(irdl::OperationOp op,
+ const OpStrings &strings) {
+ SmallVector<std::string> cppTraitNames;
+ if (!strings.opRegionNames.empty()) {
+ cppTraitNames.push_back(
+ llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
+ strings.opRegionNames.size())
+ .str());
+
+ // Requires verifyInvariantsImpl is implemented on the op
+ cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants");
+ }
+ return cppTraitNames;
+}
+
static LogicalResult generateOperationInclude(irdl::OperationOp op,
raw_ostream &output,
irdl::detail::dictionary &dict) {
@@ -247,6 +290,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op,
const auto opStrings = getStrings(op);
fillDict(dict, opStrings);
+ SmallVector<std::string> traitNames = generateTraits(op, opStrings);
+ if (traitNames.empty())
+ dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
+ else
+ dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName,
+ llvm::join(traitNames, ", "));
+
generateOpGetterDeclarations(dict, opStrings);
generateOpBuilderDeclarations(dict, opStrings);
@@ -301,6 +351,110 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
return success();
}
+static void generateRegionConstraintVerifiers(
+ irdl::detail::dictionary &dict, irdl::OperationOp op,
+ const OpStrings &strings, SmallVectorImpl<std::string> &verifierHelpers,
+ SmallVectorImpl<std::string> &verifierCalls) {
+ auto regionsOp = op.getOp<irdl::RegionsOp>();
+ if (strings.opRegionNames.empty() || !regionsOp)
+ return;
+
+ for (size_t i = 0; i < strings.opRegionNames.size(); ++i) {
+ std::string regionName = strings.opRegionNames[i];
+ std::string helperFnName =
+ llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
+ strings.opCppName, regionName)
+ .str();
+
+ // Extract the actual region constraint from the IRDL RegionOp
+ std::string condition = "true";
+ std::string textualConditionName = "any region";
+
+ if (auto regionDefOp =
+ dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) {
+ // Generate constraint condition based on RegionOp attributes
+ SmallVector<std::string> conditionParts;
+ SmallVector<std::string> descriptionParts;
+
+ // Check number of blocks constraint
+ if (auto blockCount = regionDefOp.getNumberOfBlocks()) {
+ conditionParts.push_back(
+ llvm::formatv("region.getBlocks().size() == {0}",
+ blockCount.value())
+ .str());
+ descriptionParts.push_back(
+ llvm::formatv("exactly {0} block(s)", blockCount.value()).str());
+ }
+
+ // Check entry block arguments constraint
+ if (regionDefOp.getConstrainedArguments()) {
+ size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size();
+ conditionParts.push_back(
+ llvm::formatv("region.getNumArguments() == {0}", expectedArgCount)
+ .str());
+ descriptionParts.push_back(
+ llvm::formatv("{0} entry block argument(s)", expectedArgCount)
+ .str());
+ }
+
+ // Combine conditions
+ if (!conditionParts.empty()) {
+ condition = llvm::join(conditionParts, " && ");
+ }
+
+ // Generate descriptive error message
+ if (!descriptionParts.empty()) {
+ textualConditionName =
+ llvm::formatv("region with {0}",
+ llvm::join(descriptionParts, " and "))
+ .str();
+ }
+ }
+
+ verifierHelpers.push_back(llvm::formatv(
+ R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
+ if (!({1})) {{
+ return op->emitOpError("region #") << regionIndex
+ << (regionName.empty() ? " " : " ('" + regionName + "') ")
+ << "failed to verify constraint: {2}";
+ }
+ return ::mlir::success();
+})",
+ helperFnName, condition, textualConditionName));
+
+ verifierCalls.push_back(llvm::formatv(R"(
+ if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1})))
+ return ::mlir::failure();)",
+ helperFnName, i, regionName)
+ .str());
+ }
+}
+
+static void generateVerifiers(irdl::detail::dictionary &dict,
+ irdl::OperationOp op, const OpStrings &strings) {
+ SmallVector<std::string> verifierHelpers;
+ SmallVector<std::string> verifierCalls;
+
+ generateRegionConstraintVerifiers(dict, op, strings, verifierHelpers,
+ verifierCalls);
+
+ // Add an overall verifier that sequences the helper calls
+ std::string verifierDef =
+ llvm::formatv(R"(
+::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
+ if(::mlir::failed(verify()))
+ return ::mlir::failure();
+
+ {1}
+
+ return ::mlir::success();
+})",
+ strings.opCppName, llvm::join(verifierCalls, "\n"));
+
+ dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n");
+ dict["OP_VERIFIER"] = verifierDef;
+}
+
static std::string generateOpDefinition(irdl::detail::dictionary &dict,
irdl::OperationOp op) {
static const auto perOpDefTemplate = mlir::irdl::detail::Template{
@@ -370,6 +524,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {
dict["OP_BUILD_DEFS"] = buildDefinition;
+ generateVerifiers(dict, op, opStrings);
+
std::string str;
llvm::raw_string_ostream stream{str};
perOpDefTemplate.render(stream, dict);
@@ -427,7 +583,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
dict["TYPE_PARSER"] = llvm::formatv(
R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
- {0}
+ {0}
.Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
*mnemonic = keyword;
return std::nullopt;
@@ -520,6 +676,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
"IRDL C++ translation does not yet support variadic results");
}))
.Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
+ .Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
+ .Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
.Default([](mlir::Operation *op) -> LogicalResult {
return op->emitError("IRDL C++ translation does not yet support "
"translation of ")
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
index e9068e9488f99..93ce0bef1f269 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
@@ -12,15 +12,15 @@ public:
struct Properties {
};
public:
- __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
- : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
- odsRegions(op->getRegions())
+ __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
+ : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
+ odsRegions(op->getRegions())
{}
/// Return the unstructured operand index of a structured operand along with
// the amount of unstructured operands it contains.
std::pair<unsigned, unsigned>
- getStructuredOperandIndexAndLength (unsigned index,
+ getStructuredOperandIndexAndLength (unsigned index,
unsigned odsOperandsSize) {
return {index, 1};
}
@@ -32,6 +32,12 @@ public:
::mlir::DictionaryAttr getAttributes() {
return odsAttrs;
}
+
+ __OP_REGION_ADAPTER_GETTER_DECLS__
+
+ ::mlir::RegionRange getRegions() {
+ return odsRegions;
+ }
protected:
::mlir::DictionaryAttr odsAttrs;
::std::optional<::mlir::OperationName> odsOpName;
@@ -42,28 +48,28 @@ protected:
} // namespace detail
template <typename RangeT>
-class __OP_CPP_NAME__GenericAdaptor
+class __OP_CPP_NAME__GenericAdaptor
: public detail::__OP_CPP_NAME__GenericAdaptorBase {
using ValueT = ::llvm::detail::ValueOfRange<RangeT>;
using Base = detail::__OP_CPP_NAME__GenericAdaptorBase;
public:
__OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs,
- ::mlir::OpaqueProperties properties,
- ::mlir::RegionRange regions = {})
- : __OP_CPP_NAME__GenericAdaptor(values, attrs,
- (properties ? *properties.as<::mlir::EmptyProperties *>()
+ ::mlir::OpaqueProperties properties,
+ ::mlir::RegionRange regions = {})
+ : __OP_CPP_NAME__GenericAdaptor(values, attrs,
+ (properties ? *properties.as<::mlir::EmptyProperties *>()
: ::mlir::EmptyProperties{}), regions) {}
- __OP_CPP_NAME__GenericAdaptor(RangeT values,
+ __OP_CPP_NAME__GenericAdaptor(RangeT values,
const __OP_CPP_NAME__GenericAdaptorBase &base)
: Base(base), odsOperands(values) {}
- // This template parameter allows using __OP_CPP_NAME__ which is declared
+ // This template parameter allows using __OP_CPP_NAME__ which is declared
// later.
template <typename LateInst = __OP_CPP_NAME__,
typename = std::enable_if_t<
std::is_same_v<LateInst, __OP_CPP_NAME__>>>
- __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
+ __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
: Base(op), odsOperands(values) {}
/// Return the unstructured operand index of a structured operand along with
@@ -77,7 +83,7 @@ public:
RangeT getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(odsOperands.begin(), valueRange.first),
- std::next(odsOperands.begin(),
+ std::next(odsOperands.begin(),
valueRange.first + valueRange.second)};
}
@@ -91,7 +97,7 @@ private:
RangeT odsOperands;
};
-class __OP_CPP_NAME__Adaptor
+class __OP_CPP_NAME__Adaptor
: public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> {
public:
using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor;
@@ -100,7 +106,7 @@ public:
::llvm::LogicalResult verify(::mlir::Location loc);
};
-class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> {
+class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> {
public:
using Op::Op;
using Op::print;
@@ -112,6 +118,8 @@ public:
return {};
}
+ ::llvm::LogicalResult verifyInvariantsImpl();
+
static constexpr ::llvm::StringLiteral getOperationName() {
return ::llvm::StringLiteral("__DIALECT_NAME__.__OP_NAME__");
}
@@ -147,7 +155,7 @@ public:
::mlir::Operation::operand_range getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(getOperation()->operand_begin(), valueRange.first),
- std::next(getOperation()->operand_begin(),
+ std::next(getOperation()->operand_begin(),
valueRange.first + valueRange.second)};
}
@@ -162,18 +170,19 @@ public:
::mlir::Operation::result_range getStructuredResults(unsigned index) {
auto valueRange = getStructuredResultIndexAndLength(index);
return {std::next(getOperation()->result_begin(), valueRange.first),
- std::next(getOperation()->result_begin(),
+ std::next(getOperation()->result_begin(),
valueRange.first + valueRange.second)};
}
__OP_OPERAND_GETTER_DECLS__
__OP_RESULT_GETTER_DECLS__
-
+ __OP_REGION_GETTER_DECLS__
+
__OP_BUILD_DECLS__
- static void build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState,
- ::mlir::TypeRange resultTypes,
- ::mlir::ValueRange operands,
+ static void build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState,
+ ::mlir::TypeRange resultTypes,
+ ::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder,
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
index 30ca420d77448..f4a1b7a996263 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
@@ -6,12 +6,14 @@ R"(
__NAMESPACE_OPEN__
+__OP_VERIFIER_HELPERS__
+
__OP_BUILD_DEFS__
-void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState,
- ::mlir::TypeRange resultTypes,
- ::mlir::ValueRange operands,
+void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState,
+ ::mlir::TypeRange resultTypes,
+ ::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
{
assert(operands.size() == __OP_OPERAND_COUNT__);
@@ -19,6 +21,9 @@ void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
odsState.addOperands(operands);
odsState.addAttributes(attributes);
odsState.addTypes(resultTypes);
+ for (unsigned i = 0; i != __OP_REGION_COUNT__; ++i) {
+ (void)odsState.addRegion();
+ }
}
__OP_CPP_NAME__
@@ -44,6 +49,7 @@ __OP_CPP_NAME__::create(::mlir::ImplicitLocOpBuilder &odsBuilder,
return create(odsBuilder, odsBuilder.getLoc(), resultTypes, operands, attributes);
}
+__OP_VERIFIER__
__NAMESPACE_CLOSE__
diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt b/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt
index 103bc94d86920..7d325778f09cb 100644
--- a/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/TestIRDLToCpp/CMakeLists.txt
@@ -12,5 +12,7 @@ add_mlir_library(MLIRTestIRDLToCppDialect
mlir_target_link_libraries(MLIRTestIRDLToCppDialect PUBLIC
MLIRIR
MLIRPass
+ MLIRSCFDialect
MLIRTransforms
+ MLIRTestDialect
)
diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp
index 9550e4c96e547..421db7e4c0094 100644
--- a/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp
+++ b/mlir/test/lib/Dialect/TestIRDLToCpp/TestIRDLToCppDialect.cpp
@@ -13,6 +13,7 @@
// #include "mlir/IR/Dialect.h"
#include "mlir/IR/Region.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
@@ -54,16 +55,34 @@ struct TestOpConversion : public OpConversionPattern<test_irdl_to_cpp::BeefOp> {
}
};
+struct TestRegionConversion
+ : public OpConversionPattern<test_irdl_to_cpp::ConditionalOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(mlir::test_irdl_to_cpp::ConditionalOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Just exercising the C++ API even though these are not enforced in the
+ // dialect definition
+ assert(op.getThen().getBlocks().size() == 1);
+ assert(adaptor.getElse().getBlocks().size() == 1);
+ auto ifOp = scf::IfOp::create(rewriter, op.getLoc(), op.getInput());
+ rewriter.replaceOp(op, ifOp);
+ return success();
+ }
+};
+
struct ConvertTestDialectToSomethingPass
: PassWrapper<ConvertTestDialectToSomethingPass, OperationPass<ModuleOp>> {
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
- patterns.add<TestOpConversion>(ctx);
+ patterns.add<TestOpConversion, TestRegionConversion>(ctx);
ConversionTarget target(getContext());
- target.addIllegalOp<test_irdl_to_cpp::BeefOp>();
- target.addLegalOp<test_irdl_to_cpp::BarOp>();
- target.addLegalOp<test_irdl_to_cpp::HashOp>();
+ target.addIllegalOp<test_irdl_to_cpp::BeefOp,
+ test_irdl_to_cpp::ConditionalOp>();
+ target.addLegalOp<test_irdl_to_cpp::BarOp, test_irdl_to_cpp::HashOp,
+ scf::IfOp, scf::YieldOp>();
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
@@ -73,6 +92,10 @@ struct ConvertTestDialectToSomethingPass
StringRef getDescription() const final {
return "Checks the convertability of an irdl dialect";
}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<scf::SCFDialect>();
+ }
};
void registerIrdlTestDialect(mlir::DialectRegistry &registry) {
diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir
index f6233ee18190a..1915324ccb459 100644
--- a/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir
+++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_conversion.testd.mlir
@@ -1,15 +1,29 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-irdl-conversion-check)" | FileCheck %s
// CHECK-LABEL: module {
module {
- // CHECK: func.func @test() {
+ // CHECK: func.func @test(%[[test_arg:[^ ]*]]: i1) {
// CHECK: %[[v0:[^ ]*]] = "test_irdl_to_cpp.bar"() : () -> i32
// CHECK: %[[v1:[^ ]*]] = "test_irdl_to_cpp.bar"() : () -> i32
// CHECK: %[[v2:[^ ]*]] = "test_irdl_to_cpp.hash"(%[[v0]], %[[v0]]) : (i32, i32) -> i32
+ // CHECK: scf.if %[[test_arg]]
// CHECK: return
// CHECK: }
- func.func @test() {
+ func.func @test(%test_arg: i1) {
%0 = "test_irdl_to_cpp.bar"() : () -> i32
%1 = "test_irdl_to_cpp.beef"(%0, %0) : (i32, i32) -> i32
+ "test_irdl_to_cpp.conditional"(%test_arg) ({
+ ^cond(%test: i1):
+ %3 = "test_irdl_to_cpp.bar"() : () -> i32
+ "test.terminator"() : ()->()
+ }, {
+ ^then(%what: i1, %ever: i32):
+ %4 = "test_irdl_to_cpp.bar"() : () -> i32
+ "test.terminator"() : ()->()
+ }, {
+ ^else():
+ %5 = "test_irdl_to_cpp.bar"() : () -> i32
+ "test.terminator"() : ()->()
+ }) : (i1) -> ()
return
}
diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir
index 42e713e0adecd..85fb8cb15acef 100644
--- a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir
+++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp.irdl.mlir
@@ -2,7 +2,7 @@
// CHECK: class TestIrdlToCpp
irdl.dialect @test_irdl_to_cpp {
-
+
// CHECK: class FooType
irdl.type @foo
@@ -32,4 +32,53 @@ irdl.dialect @test_irdl_to_cpp {
irdl.operands(lhs: %0, rhs: %0)
irdl.results(res: %0)
}
+
+ // CHECK: ConditionalOp declarations
+ // CHECK: ConditionalOpGenericAdaptorBase
+ // CHECK: ::mlir::Region &getCond() { return *getRegions()[0]; }
+ // CHECK: ::mlir::Region &getThen() { return *getRegions()[1]; }
+ // CHECK: ::mlir::Region &getElse() { return *getRegions()[2]; }
+ //
+ // CHECK: class ConditionalOp : public ::mlir::Op<ConditionalOp, ::mlir::OpTrait::NRegions<3>::Impl, ::mlir::OpTrait::OpInvariants>
+ // CHECK: ::mlir::Region &getCond() { return (*this)->getRegion(0); }
+ // CHECK: ::mlir::Region &getThen() { return (*this)->getRegion(1); }
+ // CHECK: ::mlir::Region &getElse() { return (*this)->getRegion(2); }
+
+ // CHECK: ConditionalOp definitions
+ // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_cond
+ // CHECK: if (!(region.getNumArguments() == 1)) {
+ // CHECK: failed to verify constraint: region with 1 entry block argument(s)
+
+ // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_then
+ // CHECK: if (!(true)) {
+
+ // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_else
+ // CHECK: if (!(region.getNumArguments() == 0)) {
+ // CHECK: failed to verify constraint: region with 0 entry block argument(s)
+
+ // CHECK: ConditionalOp::build
+ // CHECK: for (unsigned i = 0; i != 3; ++i)
+ // CHECK-NEXT: (void)odsState.addRegion();
+
+ // CHECK: ConditionalOp::verifyInvariantsImpl
+ // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_cond
+ // CHECK: failure
+ // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_then
+ // CHECK: failure
+ // CHECK: __mlir_irdl_local_region_constraint_ConditionalOp_else
+ // CHECK: failure
+ // CHECK: success
+ irdl.operation @conditional {
+ %r0 = irdl.region // Unconstrained region
+ %r1 = irdl.region() // Region with no entry block arguments
+
+ // TODO(#161018): support irdl.is in irdl-to-cpp
+ // %v0 = irdl.is i1 // Type constraint: i1 (boolean)
+ %v0 = irdl.any
+ %r2 = irdl.region(%v0) // Region with one i1 entry block argument
+ irdl.regions(cond: %r2, then: %r0, else: %r1)
+
+ %0 = irdl.any
+ irdl.operands(input: %0)
+ }
}
diff --git a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir
index 403b49235467c..cc2745643db7e 100644
--- a/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir
+++ b/mlir/test/lib/Dialect/TestIRDLToCpp/test_irdl_to_cpp_invalid_unsupported_types.irdl.mlir
@@ -7,7 +7,7 @@ irdl.dialect @test_irdl_to_cpp {
irdl.results(res: %1)
}
}
-// -----
+// -----
irdl.dialect @test_irdl_to_cpp {
irdl.operation @operands_no_any_of {
@@ -42,7 +42,7 @@ irdl.dialect @test_irdl_to_cpp {
irdl.dialect @test_irdl_to_cpp {
irdl.type @ty {
- %0 = irdl.any
+ %0 = irdl.any
// expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.parameters operation}}
irdl.parameters(ty: %0)
}
@@ -50,30 +50,9 @@ irdl.dialect @test_irdl_to_cpp {
// -----
-irdl.dialect @test_irdl_to_cpp {
- irdl.operation @test_op {
- // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.region operation}}
- %0 = irdl.region()
- irdl.regions(reg: %0)
- }
-
-}
-
-// -----
-
-irdl.dialect @test_irdl_to_cpp {
- irdl.operation @test_op {
- // expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.regions operation}}
- irdl.regions()
- }
-
-}
-
-// -----
-
irdl.dialect @test_irdl_to_cpp {
irdl.type @test_derived {
// expected-error@+1 {{IRDL C++ translation does not yet support translation of irdl.base operation}}
%0 = irdl.base "!builtin.integer"
- }
+ }
}
More information about the Mlir-commits
mailing list