[Mlir-commits] [mlir] 4792f2a - [mlir][ods] Generalize default/optional parameters

Jeff Niu llvmlistbot at llvm.org
Tue Sep 20 11:08:10 PDT 2022


Author: Jeff Niu
Date: 2022-09-20T11:07:53-07:00
New Revision: 4792f2ab214e2df7875d17d4258bd5eae733e825

URL: https://github.com/llvm/llvm-project/commit/4792f2ab214e2df7875d17d4258bd5eae733e825
DIFF: https://github.com/llvm/llvm-project/commit/4792f2ab214e2df7875d17d4258bd5eae733e825.diff

LOG: [mlir][ods] Generalize default/optional parameters

This patch consolidates the notions of an optional parameter and a
default-valued parameter. An optional parameter is one that may be omitted
when it is equal to its default value; for a "purely optional" parameter,
that default is its "null" (default-constructed) value.

This allows the existing `comparator` and `defaultValue` fields to be
used to enable more complex "optional" parameters, such as empty arrays.

Depends on D133812

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D133816

Added: 
    

Modified: 
    mlir/docs/AttributesAndTypes.md
    mlir/include/mlir/IR/AttrTypeBase.td
    mlir/lib/TableGen/AttrOrTypeDef.cpp
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/mlir-tblgen/attr-or-type-format.td
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md
index 4821e63f515f4..ec168dfc61973 100644
--- a/mlir/docs/AttributesAndTypes.md
+++ b/mlir/docs/AttributesAndTypes.md
@@ -640,23 +640,21 @@ To add a custom conversion between the `cppStorageType` and the C++ type of the
 parameter, parameters can override `convertFromStorage`, which by default is
 `"$_self"` (i.e., it attempts an implicit conversion from `cppStorageType`).
 
-###### Optional Parameters
+###### Optional and Default-Valued Parameters
 
+An optional parameter can be omitted from the assembly format of an attribute or
+a type. An optional parameter is omitted when it is equal to its default value.
 Optional parameters in the assembly format can be indicated by setting
-`isOptional`. The C++ type of an optional parameter is required to satisfy the
-following requirements:
+`defaultValue`, a string of the C++ default value. If a value for the parameter
+was not encountered during parsing, it is set to this default value. If a
+parameter is equal to its default value, it is not printed. The `comparator`
+field of the parameter is used, but if one is not specified, the equality
+operator is used.
 
-- is default-constructible
-- is contextually convertible to `bool`
-- only the default-constructed value is `false`
-
-The parameter parser should return the default-constructed value to indicate "no
-value present". The printer will guard on the presence of a value to print the
-parameter.
-
-If a value was not parsed for an optional parameter, then the parameter will be
-set to its default-constructed C++ value. For example, `Optional<int>` will be
-set to `llvm::None` and `Attribute` will be set to `nullptr`.
+When using `OptionalParameter`, the default value is set to the C++
+default-constructed value for the C++ storage type. For example, `Optional<int>`
+will be set to `llvm::None` and `Attribute` will be set to `nullptr`. The
+presence of these parameters is tested by comparing them to their "null" values.
 
 Only optional parameters or directives that only capture optional parameters can
 be used in optional groups. An optional group is a set of elements optionally
@@ -673,16 +671,9 @@ printed as `(5 : i32)`. If it is not present, it will be `x`. Directives that
 are used inside optional groups are allowed only if all captured parameters are
 also optional.
 
-###### Default-Valued Parameters
-
-Optional parameters can be given default values by setting `defaultValue`, a
-string of the C++ default value, or by using `DefaultValuedParameter`. If a
-value for the parameter was not encountered during parsing, it is set to this
-default value. If a parameter is equal to its default value, it is not printed.
-The `comparator` field of the parameter is used, but if one is not specified,
-the equality operator is used.
-
-For example:
+An optional parameter can also be specified with `DefaultValuedParameter`, which
+specifies that a parameter should be omitted when it is equal to some given
+value.
 
 ```tablegen
 let parameters = (ins DefaultValuedParameter<"Optional<int>", "5">:$a)

diff  --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 8f168064a362f..429624d70ee17 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -299,7 +299,7 @@ class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
   string cppAccessorType = !if(!empty(accessorType), type, accessorType);
   // The C++ storage type of of this parameter if it is a reference, e.g.
   // `std::string` for `StringRef` or `SmallVector` for `ArrayRef`.
-  string cppStorageType = ?;
+  string cppStorageType = cppType;
   // The C++ code to convert from the storage type to the parameter type.
   string convertFromStorage = "$_self";
   // One-line human-readable description of the argument.
@@ -315,10 +315,6 @@ class AttrOrTypeParameter<string type, string desc, string accessorType = ""> {
   // operator of `AsmPrinter` as necessary to print your type. Or you can
   // provide a custom printer.
   string printer = ?;
-  // Mark a parameter as optional. The C++ type of parameters marked as optional
-  // must be default constructible and be contextually convertible to `bool`.
-  // Any `Optional<T>` and any attribute type satisfies these requirements.
-  bit isOptional = 0;
   // Provide a default value for the parameter. Parameters with default values
   // are considered optional. If a value was not parsed for the parameter, it
   // will be set to the default value. Parameters equal to their default values
@@ -374,13 +370,12 @@ class ArrayRefOfSelfAllocationParameter<string arrayOf, string desc> :
 // An optional parameter.
 class OptionalParameter<string type, string desc = ""> :
     AttrOrTypeParameter<type, desc> {
-  let isOptional = 1;
+  let defaultValue = cppStorageType # "()";
 }
 
 // A parameter with a default value.
 class DefaultValuedParameter<string type, string value, string desc = ""> :
     AttrOrTypeParameter<type, desc> {
-  let isOptional = 1;
   let defaultValue = value;
 }
 

diff  --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index bf86218c360b8..38f1823b1437e 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -277,9 +277,7 @@ StringRef AttrOrTypeParameter::getSyntax() const {
 }
 
 bool AttrOrTypeParameter::isOptional() const {
-  // Parameters with default values are automatically optional.
-  return getDefValue<llvm::BitInit>("isOptional").value_or(false) ||
-         getDefaultValue();
+  return getDefaultValue().has_value();
 }
 
 Optional<StringRef> AttrOrTypeParameter::getDefaultValue() const {

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 8272efae32338..81fd154d341e6 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -286,19 +286,17 @@ def TestTypeSpaces : Test_Type<"TestTypeSpaceS"> {
 }
 
 class DefaultValuedAPFloat<string value>
-    : DefaultValuedParameter<"llvm::Optional<llvm::APFloat>",
-                             "llvm::Optional<llvm::APFloat>(" # value # ")"> {
-  let comparator = "$_lhs->bitwiseIsEqual(*$_rhs)";
-  let parser = [{ [&]() -> mlir::FailureOr<llvm::Optional<llvm::APFloat>> {
+    : DefaultValuedParameter<"llvm::APFloat", "llvm::APFloat(" # value # ")"> {
+  let comparator = "$_lhs.bitwiseIsEqual($_rhs)";
+  let parser = [{ [&]() -> mlir::FailureOr<llvm::APFloat> {
     mlir::FloatAttr attr;
     auto result = $_parser.parseOptionalAttribute(attr);
     if (result.has_value() && mlir::succeeded(*result))
-      return {attr.getValue()};
+      return attr.getValue();
     if (!result.has_value())
-      return llvm::Optional<llvm::APFloat>();
+      return llvm::APFloat(}] # value # [{);
     return mlir::failure();
   }() }];
-  let printer = "$_printer << *$_self";
 }
 
 def TestTypeAPFloat : Test_Type<"TestTypeAPFloat"> {

diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 1ecb688ccd29c..ed537c1c8a497 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -421,11 +421,11 @@ def TypeD : TestType<"TestF"> {
 // TYPE: ::mlir::Type TestGType::parse(::mlir::AsmParser &odsParser) {
 // TYPE:   if (::mlir::failed(_result_a))
 // TYPE:     return {};
-// TYPE:   if (::mlir::succeeded(_result_a) && *_result_a)
+// TYPE:   if (::mlir::succeeded(_result_a) && !((*_result_a) == int()))
 // TYPE:     if (odsParser.parseComma())
 // TYPE:       return {};
 
-// TYPE: if ((getA()))
+// TYPE: if (!(getA() == int()))
 // TYPE:   odsPrinter.printStrippedAttrOrType(getA());
 // TYPE: odsPrinter << ", ";
 // TYPE: odsPrinter.printStrippedAttrOrType(getB());
@@ -445,7 +445,7 @@ def TypeE : TestType<"TestG"> {
 // TYPE:     return {};
 
 // TYPE: void TestHType::print(::mlir::AsmPrinter &odsPrinter) const {
-// TYPE:   if ((getA())) {
+// TYPE:   if (!(getA() == int())) {
 // TYPE:     odsPrinter << "a = ";
 // TYPE:     odsPrinter.printStrippedAttrOrType(getA());
 // TYPE:     odsPrinter << ", ";
@@ -488,9 +488,9 @@ def TypeG : TestType<"TestI"> {
 // TYPE:     return {};
 
 // TYPE: void TestJType::print(::mlir::AsmPrinter &odsPrinter) const {
-// TYPE:   if ((getB())) {
+// TYPE:   if (!(getB() == int())) {
 // TYPE:     odsPrinter << "(";
-// TYPE:     if ((getB()))
+// TYPE:     if (!(getB() == int()))
 // TYPE:       odsPrinter.printStrippedAttrOrType(getB());
 // TYPE:     odsPrinter << ")";
 // TYPE:   } else {
@@ -508,7 +508,7 @@ def TypeH : TestType<"TestJ"> {
 // TYPE:   _result_a.value_or(10)
 
 // TYPE: void TestKType::print(::mlir::AsmPrinter &odsPrinter) const {
-// TYPE:   if ((getA() && !(getA() == 10)))
+// TYPE:   if (!(getA() == 10))
 
 def TypeI : TestType<"TestK"> {
   let parameters = (ins DefaultValuedParameter<"int", "10">:$a);
@@ -578,7 +578,7 @@ def TypeL : TestType<"TestN"> {
 // TYPE: else
 
 // TYPE-LABEL: void TestOType::print
-// TYPE: if (!((getA())))
+// TYPE: if (!(!(getA() == int())))
 // TYPE: odsPrinter << ' ' << "?"
 // TYPE: else
 // TYPE: odsPrinter.printStrippedAttrOrType(getA())
@@ -598,7 +598,7 @@ def TypeM : TestType<"TestO"> {
 // TYPE-NEXT: }
 
 // TYPE-LABEL: void TestPType::print
-// TYPE: if (!((getA()) || (getB())))
+// TYPE: if (!(!(getA() == int()) || !(getB() == int())))
 // TYPE-NEXT: odsPrinter << "?"
 
 def TypeN : TestType<"TestP"> {

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 655f40dbe1c47..c8b75b012f33a 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -56,18 +56,19 @@ class ParameterElement
   /// Returns the name of the parameter.
   StringRef getName() const { return param.getName(); }
 
+  /// Return the code to check whether the parameter is present.
+  auto genIsPresent(FmtContext &ctx, const Twine &self) const {
+    assert(isOptional() && "cannot guard on a mandatory parameter");
+    std::string valueStr = tgfmt(*param.getDefaultValue(), &ctx).str();
+    ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr);
+    return tgfmt(getParam().getComparator(), &ctx);
+  }
+
   /// Generate the code to check whether the parameter should be printed.
   MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const {
+    assert(isOptional() && "cannot guard on a mandatory parameter");
     std::string self = param.getAccessorName() + "()";
-    ctx.withSelf(self);
-    os << tgfmt("($_self", &ctx);
-    if (llvm::Optional<StringRef> defaultValue = getParam().getDefaultValue()) {
-      // Use the `comparator` field if it exists, else the equality operator.
-      std::string valueStr = tgfmt(*defaultValue, &ctx).str();
-      ctx.addSubst("_lhs", self).addSubst("_rhs", valueStr);
-      os << " && !(" << tgfmt(getParam().getComparator(), &ctx) << ")";
-    }
-    return os << ")";
+    return os << "!(" << genIsPresent(ctx, self) << ")";
   }
 
 private:
@@ -332,13 +333,9 @@ void DefFormat::genParser(MethodBody &os) {
     os << ",\n    ";
     std::string paramSelfStr;
     llvm::raw_string_ostream selfOs(paramSelfStr);
-    if (param.isOptional()) {
-      selfOs << formatv("(_result_{0}.value_or(", param.getName());
-      if (Optional<StringRef> defaultValue = param.getDefaultValue())
-        selfOs << tgfmt(*defaultValue, &ctx);
-      else
-        selfOs << param.getCppStorageType() << "()";
-      selfOs << "))";
+    if (Optional<StringRef> defaultValue = param.getDefaultValue()) {
+      selfOs << formatv("(_result_{0}.value_or(", param.getName())
+             << tgfmt(*defaultValue, &ctx) << "))";
     } else {
       selfOs << formatv("(*_result_{0})", param.getName());
     }
@@ -447,8 +444,9 @@ void DefFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx,
     ParameterElement *el = *std::prev(it);
     // Parse a comma if the last optional parameter had a value.
     if (el->isOptional()) {
-      os << formatv("if (::mlir::succeeded(_result_{0}) && *_result_{0}) {{\n",
-                    el->getName());
+      os << formatv("if (::mlir::succeeded(_result_{0}) && !({1})) {{\n",
+                    el->getName(),
+                    el->genIsPresent(ctx, "(*_result_" + el->getName() + ")"));
       os.indent();
     }
     if (it <= lastReqIt) {
@@ -522,18 +520,6 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
     }
 )";
 
-  // Optional parameters in a struct must be parsed successfully if the
-  // keyword is present.
-  //
-  // {0}: Name of the parameter.
-  // {1}: Emit error string
-  const char *const checkOptionalParam = R"(
-    if (::mlir::succeeded(_result_{0}) && !*_result_{0}) {{
-      {1}"expected a value for parameter '{0}'");
-      return {{};
-    }
-)";
-
   // First iteration of the loop parsing an optional struct.
   const char *const optionalStructFirst = R"(
   ::llvm::StringRef _paramKey;
@@ -558,11 +544,6 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
                   "  _seen_{0} = true;\n",
                   param->getName());
     genVariableParser(param, ctx, os.indent());
-    if (param->isOptional()) {
-      os.getStream().printReindented(strfmt(checkOptionalParam,
-                                            param->getName(),
-                                            tgfmt(parserErrorStr, &ctx).str()));
-    }
     os.unindent() << "} else ";
     // Print the check for duplicate or unknown parameter.
   }


        


More information about the Mlir-commits mailing list