[Mlir-commits] [mlir] 715b025 - [mlir][ods] Simplify signature of `custom` printers and parsers of Attributes and Types in presence of default constructible parameters
Markus Böck
llvmlistbot at llvm.org
Sun Jan 22 07:18:41 PST 2023
Author: Markus Böck
Date: 2023-01-22T16:18:44+01:00
New Revision: 715b0258522ff3c99aa57801d1f4d2b1b7a90ee1
URL: https://github.com/llvm/llvm-project/commit/715b0258522ff3c99aa57801d1f4d2b1b7a90ee1
DIFF: https://github.com/llvm/llvm-project/commit/715b0258522ff3c99aa57801d1f4d2b1b7a90ee1.diff
LOG: [mlir][ods] Simplify signature of `custom` printers and parsers of Attributes and Types in presence of default constructible parameters
The vast majority of C++ types used as parameters for Attributes and Types are likely to be default constructible. Nevertheless, TableGen conservatively generates code for the custom directive that expects signatures using FailureOr<T> for every parameter type T, to accommodate types that may not be default constructible. This, however, hurts the ergonomics of the common case of default constructible parameters.
This patch fixes that issue, while barely changing the code generated by TableGen, by routing every parameter passed to a custom parser method through a helper function. If the type is default constructible, as determined by the C++ compiler, a default-constructed instance is created inside the FailureOr and passed to the parser method by reference. In all other cases the helper is a no-op and the FailureOr is passed as before.
Documentation was also updated to document the new behaviour.
Fixes https://github.com/llvm/llvm-project/issues/60178
Differential Revision: https://reviews.llvm.org/D142301
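The dispatch performed by the helper can be illustrated outside of MLIR. The following is a minimal sketch, not the code from the patch: FailureOrSketch is a hypothetical stand-in for mlir::FailureOr built on std::optional, and parseFoo, parseFizz, NotDefaultConstructible, and demo are illustrative names. Only the body of unwrapForCustomParse mirrors the helper added to AttributeSupport.h.

```cpp
#include <cassert>
#include <optional>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for mlir::FailureOr<T> (assumption, not the real class).
template <class T>
class FailureOrSketch {
public:
  // Constructs the contained value in place and returns a reference to it.
  template <class... Args>
  T &emplace(Args &&...args) {
    return value.emplace(std::forward<Args>(args)...);
  }
  bool hasValue() const { return value.has_value(); }
  T &operator*() { return *value; }

private:
  std::optional<T> value;
};

// Same shape as the helper in the patch: yields T& when T is default
// constructible, otherwise yields the wrapper itself by reference.
template <class T>
decltype(auto) unwrapForCustomParse(FailureOrSketch<T> &failureOr) {
  if constexpr (std::is_default_constructible_v<T>)
    return failureOr.emplace();
  else
    return failureOr;
}

// Illustrative parameter type without a default constructor.
struct NotDefaultConstructible {
  explicit NotDefaultConstructible(int v) : value(v) {}
  int value;
};

// The return type depends on whether T is default constructible.
static_assert(std::is_same_v<
              decltype(unwrapForCustomParse(
                  std::declval<FailureOrSketch<int> &>())),
              int &>);
static_assert(std::is_same_v<
              decltype(unwrapForCustomParse(
                  std::declval<FailureOrSketch<NotDefaultConstructible> &>())),
              FailureOrSketch<NotDefaultConstructible> &>);

// Illustrative custom parsers: the first takes the plain reference, the
// second must still take the wrapper and fill it in itself.
inline bool parseFoo(int &foo) {
  foo = 42;
  return true;
}
inline bool parseFizz(FailureOrSketch<NotDefaultConstructible> &foobar) {
  foobar.emplace(7);
  return true;
}

// Mimics the generated call sites: wrap, unwrap, call, then read back.
inline int demo() {
  FailureOrSketch<int> resultA;
  FailureOrSketch<NotDefaultConstructible> resultB;
  bool ok = parseFoo(unwrapForCustomParse(resultA)) &&
            parseFizz(unwrapForCustomParse(resultB));
  assert(ok && resultA.hasValue() && resultB.hasValue());
  return *resultA + (*resultB).value;
}
```

Because the choice is made with if constexpr on the parameter type, the generated call site looks the same for every parameter and the compiler selects the signature the custom parser must have.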
Added:
Modified:
mlir/docs/DefiningDialects/AttributesAndTypes.md
mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
mlir/include/mlir/IR/AttributeSupport.h
mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
mlir/test/lib/Dialect/Test/TestAttributes.cpp
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/test/mlir-tblgen/attr-or-type-format.td
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/DefiningDialects/AttributesAndTypes.md b/mlir/docs/DefiningDialects/AttributesAndTypes.md
index afceda9f3b501..e9f13e7cecb15 100644
--- a/mlir/docs/DefiningDialects/AttributesAndTypes.md
+++ b/mlir/docs/DefiningDialects/AttributesAndTypes.md
@@ -866,17 +866,33 @@ The `custom` directive `custom<Foo>($foo)` will in the parser and printer
respectively generate calls to:
```c++
-LogicalResult parseFoo(AsmParser &parser, FailureOr<int> &foo);
+LogicalResult parseFoo(AsmParser &parser, int &foo);
void printFoo(AsmPrinter &printer, int foo);
```
+As you can see, by default parameters are passed into the parse function by
+reference. This is only possible if the C++ type is default constructible.
+If the C++ type is not default constructible, the parameter is wrapped in a
+`FailureOr`. Therefore, given the following definition:
+
+```tablegen
+let parameters = (ins "NotDefaultConstructible":$foobar);
+let assemblyFormat = "custom<Fizz>($foobar)";
+```
+
+It will generate calls expecting the following signature for `parseFizz`:
+
+```c++
+LogicalResult parseFizz(AsmParser &parser, FailureOr<NotDefaultConstructible> &foobar);
+```
+
A previously bound variable can be passed as a parameter to a `custom` directive
by wrapping it in a `ref` directive. In the previous example, `$foo` is bound by
the first directive. The second directive references it and expects the
following printer and parser signatures:
```c++
-LogicalResult parseBar(AsmParser &parser, FailureOr<int> &bar, int foo);
+LogicalResult parseBar(AsmParser &parser, int &bar, int foo);
void printBar(AsmPrinter &printer, int bar, int foo);
```
@@ -885,8 +901,7 @@ is that the parameter for the parser must use the storage type of the parameter.
For example, `StringRefParameter` expects the parser and printer signatures as:
```c++
-LogicalResult parseStringParam(AsmParser &parser,
- FailureOr<std::string> &value);
+LogicalResult parseStringParam(AsmParser &parser, std::string &value);
void printStringParam(AsmPrinter &printer, StringRef value);
```
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 3251f2c5ec4f5..fcc35bdb856f9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -219,7 +219,7 @@ void printType(Type type, AsmPrinter &printer);
} // namespace detail
/// Parse any MLIR type or a concise syntax for LLVM types.
-ParseResult parsePrettyLLVMType(AsmParser &p, FailureOr<Type> &type);
+ParseResult parsePrettyLLVMType(AsmParser &p, Type &type);
/// Print any MLIR type or a concise syntax for LLVM types.
void printPrettyLLVMType(AsmPrinter &p, Type type);
diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h
index 691055766ed24..73ede2e4d4818 100644
--- a/mlir/include/mlir/IR/AttributeSupport.h
+++ b/mlir/include/mlir/IR/AttributeSupport.h
@@ -264,6 +264,19 @@ class AttributeUniquer {
static void initializeAttributeStorage(AttributeStorage *storage,
MLIRContext *ctx, TypeID attrID);
};
+
+// Internal function called by ODS generated code.
+// Default initializes the type within a FailureOr<T> if T is default
+// constructible and returns a reference to the instance.
+// Otherwise, returns a reference to the FailureOr<T>.
+template <class T>
+decltype(auto) unwrapForCustomParse(FailureOr<T> &failureOr) {
+ if constexpr (std::is_default_constructible_v<T>)
+ return failureOr.emplace();
+ else
+ return failureOr;
+}
+
} // namespace detail
} // namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index 6054dc426c47b..0c8b97b4de124 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -356,9 +356,8 @@ Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
return type;
}
-ParseResult LLVM::parsePrettyLLVMType(AsmParser &p, FailureOr<Type> &type) {
- type.emplace();
- return dispatchParse(p, *type);
+ParseResult LLVM::parsePrettyLLVMType(AsmParser &p, Type &type) {
+ return dispatchParse(p, type);
}
void LLVM::printPrettyLLVMType(AsmPrinter &p, Type type) {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index 8c27ee974df3c..ce28aa09b8214 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -33,10 +33,8 @@ constexpr const static unsigned kBitsInByte = 8;
// custom<FunctionTypes>
//===----------------------------------------------------------------------===//
-static ParseResult parseFunctionTypes(AsmParser &p,
- FailureOr<SmallVector<Type>> &params,
- FailureOr<bool> &isVarArg) {
- params.emplace();
+static ParseResult parseFunctionTypes(AsmParser &p, SmallVector<Type> &params,
+ bool &isVarArg) {
isVarArg = false;
// `(` `)`
if (succeeded(p.parseOptionalRParen()))
@@ -49,10 +47,10 @@ static ParseResult parseFunctionTypes(AsmParser &p,
}
// type (`,` type)* (`,` `...`)?
- FailureOr<Type> type;
+ Type type;
if (parsePrettyLLVMType(p, type))
return failure();
- params->push_back(*type);
+ params.push_back(type);
while (succeeded(p.parseOptionalComma())) {
if (succeeded(p.parseOptionalEllipsis())) {
isVarArg = true;
@@ -60,7 +58,7 @@ static ParseResult parseFunctionTypes(AsmParser &p,
}
if (parsePrettyLLVMType(p, type))
return failure();
- params->push_back(*type);
+ params.push_back(type);
}
return p.parseRParen();
}
@@ -81,11 +79,10 @@ static void printFunctionTypes(AsmPrinter &p, ArrayRef<Type> params,
// custom<Pointer>
//===----------------------------------------------------------------------===//
-static ParseResult parsePointer(AsmParser &p, FailureOr<Type> &elementType,
- FailureOr<unsigned> &addressSpace) {
- addressSpace = 0;
+static ParseResult parsePointer(AsmParser &p, Type &elementType,
+ unsigned &addressSpace) {
// `<` addressSpace `>`
- OptionalParseResult result = p.parseOptionalInteger(*addressSpace);
+ OptionalParseResult result = p.parseOptionalInteger(addressSpace);
if (result.has_value()) {
if (failed(result.value()))
return failure();
@@ -96,7 +93,7 @@ static ParseResult parsePointer(AsmParser &p, FailureOr<Type> &elementType,
if (parsePrettyLLVMType(p, elementType))
return failure();
if (succeeded(p.parseOptionalComma()))
- return p.parseInteger(*addressSpace);
+ return p.parseInteger(addressSpace);
return success();
}
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 3123c11607eb2..24955ffb1713f 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -164,12 +164,11 @@ ArrayRef<uint64_t> TestExtern1DI64ElementsAttr::getElements() const {
// TestCustomAnchorAttr
//===----------------------------------------------------------------------===//
-static ParseResult parseTrueFalse(AsmParser &p,
- FailureOr<std::optional<int>> &result) {
+static ParseResult parseTrueFalse(AsmParser &p, std::optional<int> &result) {
bool b;
if (p.parseInteger(b))
return failure();
- result = std::optional<int>(b);
+ result = b;
return success();
}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 4d5c9b89957f2..231f69f629ce2 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -89,23 +89,21 @@ static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT
// TestCustomType
//===----------------------------------------------------------------------===//
-static LogicalResult parseCustomTypeA(AsmParser &parser,
- FailureOr<int> &aResult) {
- aResult.emplace();
- return parser.parseInteger(*aResult);
+static LogicalResult parseCustomTypeA(AsmParser &parser, int &aResult) {
+ return parser.parseInteger(aResult);
}
static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
- FailureOr<std::optional<int>> &bResult) {
+ std::optional<int> &bResult) {
if (a < 0)
return success();
for (int i : llvm::seq(0, a))
if (failed(parser.parseInteger(i)))
return failure();
bResult.emplace(0);
- return parser.parseInteger(**bResult);
+ return parser.parseInteger(*bResult);
}
static void printCustomTypeB(AsmPrinter &printer, int a, std::optional<int> b) {
@@ -117,8 +115,7 @@ static void printCustomTypeB(AsmPrinter &printer, int a, std::optional<int> b) {
printer << *b;
}
-static LogicalResult parseFooString(AsmParser &parser,
- FailureOr<std::string> &foo) {
+static LogicalResult parseFooString(AsmParser &parser, std::string &foo) {
std::string result;
if (parser.parseString(&result))
return failure();
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index d4794c2438620..8ab622e6087ab 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -593,7 +593,7 @@ def TypeK : TestType<"TestM"> {
// TYPE-LABEL: ::mlir::Type TestNType::parse
// TYPE: parseFoo(
-// TYPE-NEXT: _result_a,
+// TYPE-NEXT: ::mlir::detail::unwrapForCustomParse(_result_a),
// TYPE-NEXT: 1);
// TYPE-LABEL: void TestNType::print
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 5a4617ba884df..44b78faf9b4b8 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -612,7 +612,8 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
for (FormatElement *arg : el->getArguments()) {
os << ",\n";
if (auto *param = dyn_cast<ParameterElement>(arg))
- os << "_result_" << param->getName();
+ os << "::mlir::detail::unwrapForCustomParse(_result_" << param->getName()
+ << ")";
else if (auto *ref = dyn_cast<RefDirective>(arg))
os << "*_result_" << cast<ParameterElement>(ref->getArg())->getName();
else