[Mlir-commits] [mlir] 761bc83 - [mlir][ods] Default-valued parameters in attribute or type defs
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 15 11:02:15 PST 2022
Author: Mogball
Date: 2022-02-15T19:02:11Z
New Revision: 761bc83af4ee70c6c5156f73ed88947a5d9f013f
URL: https://github.com/llvm/llvm-project/commit/761bc83af4ee70c6c5156f73ed88947a5d9f013f
DIFF: https://github.com/llvm/llvm-project/commit/761bc83af4ee70c6c5156f73ed88947a5d9f013f.diff
LOG: [mlir][ods] Default-valued parameters in attribute or type defs
Optional parameters with `defaultValue` set will be populated with that value if they aren't encountered during parsing. Moreover, parameters equal to their default values are elided when printing.
Depends on D118210
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D118544
Added:
Modified:
mlir/docs/Tutorials/DefiningAttributesAndTypes.md
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/AttrOrTypeDef.h
mlir/lib/TableGen/AttrOrTypeDef.cpp
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
mlir/test/mlir-tblgen/attr-or-type-format.td
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
index 1501b54af5279..3260ac10896f4 100644
--- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
+++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md
@@ -518,6 +518,37 @@ printed as `(5 : i32)`. If it is not present, it will be `x`. Directives that
are used inside optional groups are allowed only if all captured parameters are
also optional.
+#### Default-Valued Parameters
+
+A parameter can be given a default value either by setting `defaultValue` to a
+string containing the C++ default value, or by using `DefaultValuedParameter`.
+Parameters with default values are optional: if no value for the parameter is
+encountered during parsing, it is set to the default value, and if it is equal
+to its default value, it is not printed. Equality is checked with the
+parameter's `comparator` field if set, and with the C++ `==` operator otherwise.
+
+For example:
+
+```
+let parameters = (ins DefaultValuedParameter<"Optional<int>", "5">:$a)
+let mnemonic = "default_valued";
+let assemblyFormat = "(`<` $a^ `>`)?";
+```
+
+This type can then be written as:
+
+```
+!test.default_valued // a = 5
+!test.default_valued<10> // a = 10
+```
+
+The default value can refer to the current MLIR context through `$_ctx`, e.g.
+to construct default values for `Attribute` or `Type` parameters:
+
+```
+DefaultValuedParameter<"IntegerType", "IntegerType::get($_ctx, 32)">
+```
+
### Assembly Format Directives
Attribute and type assembly formats have the following directives:
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index bb108f714a789..5ad8bd45b339e 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -3117,7 +3117,8 @@ class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
// Custom memory allocation code for storage constructor.
code allocator = ?;
- // Custom comparator used to compare two instances for equality.
+ // Comparator used to compare two instances for equality. By default, it uses
+ // the C++ equality operator.
code comparator = ?;
// The C++ type of this parameter.
string cppType = type;
@@ -3143,6 +3144,14 @@ class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
// must be default constructible and be contextually convertible to `bool`.
// Any `Optional<T>` and any attribute type satisfies these requirements.
bit isOptional = 0;
+ // Provide a default value for the parameter. Parameters with default values
+ // are considered optional. If a value was not parsed for the parameter, it
+ // will be set to the default value. Parameters equal to their default values
+ // are elided when printing. Equality is checked using the `comparator` field,
+ // which by default is the C++ equality operator. The current MLIR context is
+ // made available through `$_ctx`, e.g., for constructing default values for
+ // attributes and types.
+ string defaultValue = ?;
}
class AttrParameter<string type, string desc, string accessorType = "">
: AttrOrTypeParameter<type, desc, accessorType>;
@@ -3193,6 +3202,13 @@ class OptionalParameter<string type, string desc = ""> :
let isOptional = 1;
}
+// A parameter with a default value.
+class DefaultValuedParameter<string type, string value, string desc = ""> :
+ AttrOrTypeParameter<type, desc> {
+ let isOptional = 1;
+ let defaultValue = value;
+}
+
// This is a special parameter used for AttrDefs that represents a `mlir::Type`
// that is also used as the value `Type` of the attribute. Only one parameter
// of the attribute may be of this type.
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index dcfdc8ab28a63..557f60b2686eb 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -52,6 +52,9 @@ class AttrOrTypeParameter {
explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index)
: def(def), index(index) {}
+ /// Returns true if the parameter is anonymous (has no name).
+ bool isAnonymous() const;
+
/// Get the parameter name.
StringRef getName() const;
@@ -59,7 +62,7 @@ class AttrOrTypeParameter {
Optional<StringRef> getAllocator() const;
/// If specified, get the custom comparator code for this parameter.
- Optional<StringRef> getComparator() const;
+ StringRef getComparator() const;
/// Get the C++ type of this parameter.
StringRef getCppType() const;
@@ -85,6 +88,9 @@ class AttrOrTypeParameter {
/// Returns true if the parameter is optional.
bool isOptional() const;
+ /// Get the default value of the parameter if it has one.
+ Optional<StringRef> getDefaultValue() const;
+
/// Return the underlying def of this parameter.
llvm::Init *getDef() const;
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 2d92a5837c79f..3c0cfa029c700 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -187,6 +187,10 @@ auto AttrOrTypeParameter::getDefValue(StringRef name) const {
return result;
}
+bool AttrOrTypeParameter::isAnonymous() const {
+ return !def->getArgName(index);
+}
+
StringRef AttrOrTypeParameter::getName() const {
return def->getArgName(index)->getValue();
}
@@ -195,8 +199,9 @@ Optional<StringRef> AttrOrTypeParameter::getAllocator() const {
return getDefValue<llvm::StringInit>("allocator");
}
-Optional<StringRef> AttrOrTypeParameter::getComparator() const {
- return getDefValue<llvm::StringInit>("comparator");
+StringRef AttrOrTypeParameter::getComparator() const {
+ return getDefValue<llvm::StringInit>("comparator")
+ .getValueOr("$_lhs == $_rhs");
}
StringRef AttrOrTypeParameter::getCppType() const {
@@ -239,7 +244,13 @@ StringRef AttrOrTypeParameter::getSyntax() const {
}
bool AttrOrTypeParameter::isOptional() const {
- return getDefValue<llvm::BitInit>("isOptional").getValueOr(false);
+ // Parameters with default values are automatically optional.
+ return getDefValue<llvm::BitInit>("isOptional").getValueOr(false) ||
+ getDefaultValue().hasValue();
+}
+
+Optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {
+ return getDefValue<llvm::StringInit>("defaultValue");
}
llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index b8dce7ec50f93..de4969353ad01 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -311,7 +311,7 @@ def TestTypeOptionalGroup : Test_Type<"TestTypeOptionalGroup"> {
}
def TestTypeOptionalGroupParams : Test_Type<"TestTypeOptionalGroupParams"> {
- let parameters = (ins OptionalParameter<"mlir::Optional<int>">:$a,
+ let parameters = (ins DefaultValuedParameter<"mlir::Optional<int>", "10">:$a,
OptionalParameter<"mlir::Optional<int>">:$b);
let mnemonic = "optional_group_params";
let assemblyFormat = "`<` (`(` params^ `)`) : (`x`)? `>`";
@@ -330,4 +330,37 @@ def TestTypeSpaces : Test_Type<"TestTypeSpaceS"> {
let assemblyFormat = "`<` ` ` $a `\\n` `(` `)` `` `(` `)` $b `>`";
}
+class DefaultValuedAPFloat<string value>
+ : DefaultValuedParameter<"llvm::Optional<llvm::APFloat>",
+ "llvm::Optional<llvm::APFloat>(" # value # ")"> {
+ let comparator = "$_lhs->bitwiseIsEqual(*$_rhs)";
+ let parser = [{ [&]() -> mlir::FailureOr<llvm::Optional<llvm::APFloat>> {
+ mlir::FloatAttr attr;
+ auto result = $_parser.parseOptionalAttribute(attr);
+ if (result.hasValue() && mlir::succeeded(*result))
+ return {attr.getValue()};
+ if (!result.hasValue())
+ return llvm::Optional<llvm::APFloat>();
+ return mlir::failure();
+ }() }];
+ let printer = "$_printer << *$_self";
+}
+
+def TestTypeAPFloat : Test_Type<"TestTypeAPFloat"> {
+ let parameters = (ins
+ DefaultValuedAPFloat<"APFloat::getZero(APFloat::IEEEdouble())">:$a
+ );
+ let mnemonic = "ap_float";
+ let assemblyFormat = "`<` $a `>`";
+}
+
+def TestTypeDefaultValuedType : Test_Type<"TestTypeDefaultValuedType"> {
+ let parameters = (ins
+ DefaultValuedParameter<"mlir::IntegerType",
+ "mlir::IntegerType::get($_ctx, 32)">:$type
+ );
+ let mnemonic = "default_valued_type";
+ let assemblyFormat = "`<` (`(` $type^ `)`)? `>`";
+}
+
#endif // TEST_TYPEDEFS
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
index 47f25cb20406b..b626ca44f36a1 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-roundtrip.mlir
@@ -44,6 +44,11 @@ attributes {
// CHECK: !test.optional_group_struct<(a = 10, b = 5)>
// CHECK: !test.spaces< 5
// CHECK-NEXT: ()() 6>
+// CHECK: !test.ap_float<5.000000e+00>
+// CHECK: !test.ap_float<>
+// CHECK: !test.default_valued_type<(i64)>
+// CHECK: !test.default_valued_type<>
+
func private @test_roundtrip_default_parsers_struct(
!test.no_parser<255, [1, 2, 3, 4, 5], "foobar", 4>
) -> (
@@ -70,5 +75,9 @@ func private @test_roundtrip_default_parsers_struct(
!test.optional_group_struct<x>,
!test.optional_group_struct<(b = 5)>,
!test.optional_group_struct<(b = 5, a = 10)>,
- !test.spaces<5 ()() 6>
+ !test.spaces<5 ()() 6>,
+ !test.ap_float<5.0>,
+ !test.ap_float<>,
+ !test.default_valued_type<(i64)>,
+ !test.default_valued_type<>
)
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 4ed281d488db5..f0c1aac4af446 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -389,10 +389,13 @@ def TypeC : TestType<"TestE"> {
let assemblyFormat = "`{` struct($v0, $v2) `}` `{` struct($v1, $v3) `}`";
}
+// TYPE: ::mlir::Type TestFType::parse(::mlir::AsmParser &odsParser) {
+// TYPE: _result_a.getValueOr(int())
+
// TYPE: void TestFType::print(::mlir::AsmPrinter &odsPrinter) const {
// TYPE if (getA()) {
-// TYPE printer << ' ';
-// TYPE printer.printStrippedAttrOrType(getA());
+// TYPE odsPrinter << ' ';
+// TYPE odsPrinter.printStrippedAttrOrType(getA());
def TypeD : TestType<"TestF"> {
let parameters = (ins OptionalParameter<"int">:$a);
let mnemonic = "type_f";
@@ -406,7 +409,7 @@ def TypeD : TestType<"TestF"> {
// TYPE: if (odsParser.parseComma())
// TYPE: return {};
-// TYPE: if (getA())
+// TYPE: if ((getA()))
// TYPE: odsPrinter.printStrippedAttrOrType(getA());
// TYPE: odsPrinter << ", ";
// TYPE: odsPrinter.printStrippedAttrOrType(getB());
@@ -426,7 +429,7 @@ def TypeE : TestType<"TestG"> {
// TYPE: return {};
// TYPE: void TestHType::print(::mlir::AsmPrinter &odsPrinter) const {
-// TYPE: if (getA()) {
+// TYPE: if ((getA())) {
// TYPE: odsPrinter << "a = ";
// TYPE: odsPrinter.printStrippedAttrOrType(getA());
// TYPE: odsPrinter << ", ";
@@ -469,9 +472,9 @@ def TypeG : TestType<"TestI"> {
// TYPE: return {};
// TYPE: void TestJType::print(::mlir::AsmPrinter &odsPrinter) const {
-// TYPE: if (getB()) {
+// TYPE: if ((getB())) {
// TYPE: odsPrinter << "(";
-// TYPE: if (getB())
+// TYPE: if ((getB()))
// TYPE: odsPrinter.printStrippedAttrOrType(getB());
// TYPE: odsPrinter << ")";
// TYPE: } else {
@@ -484,3 +487,15 @@ def TypeH : TestType<"TestJ"> {
let mnemonic = "type_j";
let assemblyFormat = "(`(` $b^ `)`) : (`x`)? $a";
}
+
+// TYPE: ::mlir::Type TestKType::parse(::mlir::AsmParser &odsParser) {
+// TYPE: _result_a.getValueOr(10)
+
+// TYPE: void TestKType::print(::mlir::AsmPrinter &odsPrinter) const {
+// TYPE: if ((getA() && !(getA() == 10)))
+
+def TypeI : TestType<"TestK"> {
+ let parameters = (ins DefaultValuedParameter<"int", "10">:$a);
+ let mnemonic = "type_k";
+ let assemblyFormat = "$a";
+}
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 7e2b147478e17..c6db1c0418846 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -179,6 +179,11 @@ DefGen::DefGen(const AttrOrTypeDef &def)
: def(def), params(def.getParameters()), defCls(def.getCppClassName()),
valueType(isa<AttrDef>(def) ? "Attribute" : "Type"),
defType(isa<AttrDef>(def) ? "Attr" : "Type") {
+ // Check that all parameters have names.
+ for (const AttrOrTypeParameter ¶m : def.getParameters())
+ if (param.isAnonymous())
+ llvm::PrintFatalError("all parameters must have a name");
+
// If a storage class is needed, create one.
if (def.getNumParameters() > 0)
storageCls.emplace(def.getStorageClassName(), /*isStruct=*/true);
@@ -535,8 +540,7 @@ void DefGen::emitEquals() {
? "getType()"
: it.value().getName()},
{"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}});
- Optional<StringRef> comparator = it.value().getComparator();
- body << tgfmt(comparator ? *comparator : "$_lhs == $_rhs", &ctx);
+ body << tgfmt(it.value().getComparator(), &ctx);
};
llvm::interleave(llvm::enumerate(params), body, eachFn, ") && (");
}
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 77087b6b4f698..35ab19942ba13 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -56,6 +56,23 @@ class ParameterElement
/// Returns the name of the parameter.
StringRef getName() const { return param.getName(); }
+ /// Generate the code to check whether the parameter should be printed.
+ auto genPrintGuard(FmtContext &ctx) const {
+ return [&](raw_ostream &os) -> raw_ostream & {
+ std::string self = getParameterAccessorName(getName()) + "()";
+ ctx.withSelf(self);
+ os << tgfmt("($_self", &ctx);
+ if (llvm::Optional<StringRef> defaultValue =
+ getParam().getDefaultValue()) {
+ // Use the `comparator` field if it exists, else the equality operator.
+ std::string valueStr = tgfmt(*defaultValue, &ctx).str();
+ ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr);
+ os << " && !(" << tgfmt(getParam().getComparator(), &ctx) << ")";
+ }
+ return os << ")";
+ };
+ }
+
private:
bool shouldBeQualifiedFlag = false;
AttrOrTypeParameter param;
@@ -65,6 +82,12 @@ class ParameterElement
static bool paramIsOptional(ParameterElement *el) { return el->isOptional(); }
static bool paramNotOptional(ParameterElement *el) { return !el->isOptional(); }
+/// raw_ostream doesn't have an overload for stream functors. Declare one here.
+template <typename StreamFunctor>
+static raw_ostream &operator<<(raw_ostream &os, StreamFunctor &&fcn) {
+ return fcn(os);
+}
+
/// Base class for a directive that contains references to multiple variables.
template <DirectiveElement::Kind DirectiveKind>
class ParamsDirectiveBase : public DirectiveElementBase<DirectiveKind> {
@@ -231,6 +254,7 @@ class DefFormat {
void DefFormat::genParser(MethodBody &os) {
FmtContext ctx;
ctx.addSubst("_parser", "odsParser");
+ ctx.addSubst("_ctx", "odsParser.getContext()");
if (isa<AttrDef>(def))
ctx.addSubst("_type", "odsType");
os.indent();
@@ -274,11 +298,16 @@ void DefFormat::genParser(MethodBody &os) {
def.getCppClassName());
}
for (const AttrOrTypeParameter ¶m : params) {
- if (param.isOptional())
- os << formatv(",\n _result_{0}.getValueOr({1}())", param.getName(),
- param.getCppStorageType());
- else
+ if (param.isOptional()) {
+ os << formatv(",\n _result_{0}.getValueOr(", param.getName());
+ if (Optional<StringRef> defaultValue = param.getDefaultValue())
+ os << tgfmt(*defaultValue, &ctx);
+ else
+ os << param.getCppStorageType() << "()";
+ os << ")";
+ } else {
os << formatv(",\n *_result_{0}", param.getName());
+ }
}
os << ");";
}
@@ -596,6 +625,7 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
void DefFormat::genPrinter(MethodBody &os) {
FmtContext ctx;
ctx.addSubst("_printer", "odsPrinter");
+ ctx.addSubst("_ctx", "getContext()");
os.indent();
/// Generate printers.
@@ -642,9 +672,10 @@ void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
const AttrOrTypeParameter ¶m = el->getParam();
ctx.withSelf(getParameterAccessorName(param.getName()) + "()");
- // Guard the printer on the presence of optional parameters.
+ // Guard the printer on the presence of optional parameters and that they
+ // aren't equal to their default values (if they have one).
if (el->isOptional() && !skipGuard) {
- os << tgfmt("if ($_self) {\n", &ctx);
+ os << "if (" << el->genPrintGuard(ctx) << ") {\n";
os.indent();
}
@@ -665,23 +696,27 @@ void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx,
os.unindent() << "}\n";
}
+/// Generate code to guard printing on the presence of any optional parameters.
+template <typename ParameterRange>
+static void guardOnAny(FmtContext &ctx, MethodBody &os,
+ ParameterRange &¶ms) {
+ os << "if (";
+ llvm::interleave(
+ params, os,
+ [&](ParameterElement *param) { os << param->genPrintGuard(ctx); },
+ " || ");
+ os << ") {\n";
+ os.indent();
+}
+
void DefFormat::genCommaSeparatedPrinter(
ArrayRef<ParameterElement *> params, FmtContext &ctx, MethodBody &os,
function_ref<void(ParameterElement *)> extra) {
// Emit a space if necessary, but only if the struct is present.
if (shouldEmitSpace || !lastWasPunctuation) {
bool allOptional = llvm::all_of(params, paramIsOptional);
- if (allOptional) {
- os << "if (";
- llvm::interleave(
- params, os,
- [&](ParameterElement *param) {
- os << getParameterAccessorName(param->getName()) << "()";
- },
- " || ");
- os << ") {\n";
- os.indent();
- }
+ if (allOptional)
+ guardOnAny(ctx, os, params);
os << tgfmt("$_printer << ' ';\n", &ctx);
if (allOptional)
os.unindent() << "}\n";
@@ -692,8 +727,7 @@ void DefFormat::genCommaSeparatedPrinter(
os.indent() << "bool _firstPrinted = true;\n";
for (ParameterElement *param : params) {
if (param->isOptional()) {
- os << tgfmt("if ($_self()) {\n",
- &ctx.withSelf(getParameterAccessorName(param->getName())));
+ os << "if (" << param->genPrintGuard(ctx) << ") {\n";
os.indent();
}
os << tgfmt("if (!_firstPrinted) $_printer << \", \";\n", &ctx);
@@ -724,26 +758,14 @@ void DefFormat::genStructPrinter(StructDirective *el, FmtContext &ctx,
void DefFormat::genOptionalGroupPrinter(OptionalElement *el, FmtContext &ctx,
MethodBody &os) {
- // Emit the check on whether the group should be printed.
- const auto guardOn = [&](auto params) {
- os << "if (";
- llvm::interleave(
- params, os,
- [&](ParameterElement *el) {
- os << getParameterAccessorName(el->getName()) << "()";
- },
- " || ");
- os << ") {\n";
- os.indent();
- };
FormatElement *anchor = el->getAnchor();
if (auto *param = dyn_cast<ParameterElement>(anchor)) {
- guardOn(llvm::makeArrayRef(param));
+ guardOnAny(ctx, os, llvm::makeArrayRef(param));
} else if (auto *params = dyn_cast<ParamsDirective>(anchor)) {
- guardOn(params->getParams());
+ guardOnAny(ctx, os, params->getParams());
} else {
- auto *strct = dyn_cast<StructDirective>(anchor);
- guardOn(strct->getParams());
+ auto *strct = cast<StructDirective>(anchor);
+ guardOnAny(ctx, os, strct->getParams());
}
// Generate the printer for the contained elements.
{
More information about the Mlir-commits
mailing list