[Mlir-commits] [mlir] 574e759 - [mlir][ods] Support using custom directives as first optional group element

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 30 20:21:03 PDT 2023


Author: Mogball
Date: 2023-08-31T03:20:54Z
New Revision: 574e7596e5149c3319dea255b2c45e670ca4711b

URL: https://github.com/llvm/llvm-project/commit/574e7596e5149c3319dea255b2c45e670ca4711b
DIFF: https://github.com/llvm/llvm-project/commit/574e7596e5149c3319dea255b2c45e670ca4711b.diff

LOG: [mlir][ods] Support using custom directives as first optional group element

This adds support for using a custom directive as the first optional
group element. The first optional group element guards the parsing of
the rest of the optional group. This can be done for custom directives
by expecting the parse function to return an `OptionalParseResult`
instead of a `ParseResult`.

Depends on D159243

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D159244

Added: 
    

Modified: 
    mlir/test/mlir-tblgen/attr-or-type-format.td
    mlir/test/mlir-tblgen/op-format-invalid.td
    mlir/test/mlir-tblgen/op-format.td
    mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
    mlir/tools/mlir-tblgen/FormatGen.cpp
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
index 2782f55bc966ed..b9041e45f856ad 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -645,8 +645,26 @@ def TypeN : TestType<"TestP"> {
   let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
 }
 
+// TYPE-LABEL: TestQType::parse
+// TYPE: if (auto result = [&]() -> ::mlir::OptionalParseResult {
+// TYPE:     auto odsCustomResult = parseAB(odsParser
+// TYPE:     if (!odsCustomResult) return {};
+// TYPE:     if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();
+// TYPE:   return ::mlir::success();
+// TYPE: }(); result.has_value() && ::mlir::failed(*result)) {
+// TYPE:   return {};
+// TYPE: } else if (result.has_value()) {
+// TYPE:   // Parse literal 'y'
+// TYPE: } else {
+// TYPE:   // Parse literal 'x'
+def TypeO : TestType<"TestQ"> {
+  let parameters = (ins OptionalParameter<"int">:$a);
+  let mnemonic = "type_o";
+  let assemblyFormat = "(custom<AB>($a)^ `x`) : (`y`)?";
+}
+
 // DEFAULT_TYPE_PARSER: TestDialect::parseType(::mlir::DialectAsmParser &parser)
 // DEFAULT_TYPE_PARSER: auto parseResult = parseOptionalDynamicType(mnemonic, parser, genType);
 // DEFAULT_TYPE_PARSER: if (parseResult.has_value()) {
 // DEFAULT_TYPE_PARSER:   if (::mlir::succeeded(parseResult.value()))
-// DEFAULT_TYPE_PARSER:     return genType;
\ No newline at end of file
+// DEFAULT_TYPE_PARSER:     return genType;

diff  --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td
index 210241caef3163..ce91ceea34cee1 100644
--- a/mlir/test/mlir-tblgen/op-format-invalid.td
+++ b/mlir/test/mlir-tblgen/op-format-invalid.td
@@ -357,7 +357,7 @@ def OptionalInvalidB : TestFormat_Op<[{
 def OptionalInvalidC : TestFormat_Op<[{
   ($attr)? attr-dict
 }]>, Arguments<(ins OptionalAttr<I64Attr>:$attr)>;
-// CHECK: error: first parsable element of an optional group must be a literal or variable
+// CHECK: error: first parsable element of an optional group must be a literal, variable, or custom directive
 def OptionalInvalidD : TestFormat_Op<[{
   (type($operand) $operand^)? attr-dict
 }]>, Arguments<(ins Optional<I64>:$operand)>;

diff  --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td
index c098a52f15b3b0..dd756788098c90 100644
--- a/mlir/test/mlir-tblgen/op-format.td
+++ b/mlir/test/mlir-tblgen/op-format.td
@@ -84,9 +84,20 @@ def OptionalGroupC : TestFormat_Op<[{
   ($a^)? attr-dict
 }]>, Arguments<(ins DefaultValuedStrAttr<StrAttr, "default">:$a)>;
 
+// CHECK-LABEL: OptionalGroupD::parse
+// CHECK: if (auto result = [&]() -> ::mlir::OptionalParseResult {
+// CHECK:   auto odsResult = parseCustom(parser, aOperand, bOperand);
+// CHECK:   if (!odsResult) return {};
+// CHECK:   if (::mlir::failed(*odsResult)) return ::mlir::failure();
+// CHECK:   return ::mlir::success();
+// CHECK: }(); result.has_value() && ::mlir::failed(*result)) {
+// CHECK:   return ::mlir::failure();
+// CHECK: } else if (result.has_value()) {
+
 // CHECK-LABEL: OptionalGroupD::print
 // CHECK-NEXT: if (((getA()) || (getB()))) {
-// CHECK-NEXT:   odsPrinter << "("
+// CHECK-NEXT:   odsPrinter << ' ';
+// CHECK-NEXT:   printCustom
 def OptionalGroupD : TestFormat_Op<[{
-  (`(` custom<Custom>($a, $b)^ `)`)? attr-dict
+  (custom<Custom>($a, $b)^)? attr-dict
 }], [AttrSizedOperandSegments]>, Arguments<(ins Optional<I64>:$a, Optional<I64>:$b)>;

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index f14be81467d850..f8e0c83da3c8a6 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -201,7 +201,8 @@ class DefFormat {
   /// Generate the parser code for a `struct` directive.
   void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os);
   /// Generate the parser code for a `custom` directive.
-  void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os);
+  void genCustomParser(CustomDirective *el, FmtContext &ctx, MethodBody &os,
+                       bool isOptional = false);
   /// Generate the parser code for an optional group.
   void genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
                               MethodBody &os);
@@ -598,7 +599,7 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx,
 }
 
 void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
-                                MethodBody &os) {
+                                MethodBody &os, bool isOptional) {
   os << "{\n";
   os.indent();
 
@@ -620,7 +621,12 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
       os << tgfmt(cast<StringElement>(arg)->getValue(), &ctx);
   }
   os.unindent() << ");\n";
-  os << "if (::mlir::failed(odsCustomResult)) return {};\n";
+  if (isOptional) {
+    os << "if (!odsCustomResult) return {};\n";
+    os << "if (::mlir::failed(*odsCustomResult)) return ::mlir::failure();\n";
+  } else {
+    os << "if (::mlir::failed(odsCustomResult)) return {};\n";
+  }
   for (FormatElement *arg : el->getArguments()) {
     if (auto *param = dyn_cast<ParameterElement>(arg)) {
       if (param->isOptional())
@@ -629,7 +635,7 @@ void DefFormat::genCustomParser(CustomDirective *el, FmtContext &ctx,
       os.indent() << tgfmt("$_parser.emitError(odsCustomLoc, ", &ctx)
                   << "\"custom parser failed to parse parameter '"
                   << param->getName() << "'\");\n";
-      os << "return {};\n";
+      os << "return " << (isOptional ? "::mlir::failure()" : "{}") << ";\n";
       os.unindent() << "}\n";
     }
   }
@@ -663,6 +669,17 @@ void DefFormat::genOptionalGroupParser(OptionalElement *el, FmtContext &ctx,
   } else if (auto *params = dyn_cast<ParamsDirective>(first)) {
     genParamsParser(params, ctx, os);
     guardOn(params->getParams());
+  } else if (auto *custom = dyn_cast<CustomDirective>(first)) {
+    os << "if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
+    os.indent();
+    genCustomParser(custom, ctx, os, /*isOptional=*/true);
+    os << "return ::mlir::success();\n";
+    os.unindent();
+    os << "}(); result.has_value() && ::mlir::failed(*result)) {\n";
+    os.indent();
+    os << "return {};\n";
+    os.unindent();
+    os << "} else if (result.has_value()) {\n";
   } else {
     auto *strct = cast<StructDirective>(first);
     genStructParser(strct, ctx, os);

diff  --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index b4f71fb45b3765..d402748b96ad5f 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -383,9 +383,9 @@ FailureOr<FormatElement *> FormatParser::parseOptionalGroup(Context ctx) {
   unsigned thenParseStart = std::distance(thenElements.begin(), thenParseBegin);
   unsigned elseParseStart = std::distance(elseElements.begin(), elseParseBegin);
 
-  if (!isa<LiteralElement, VariableElement>(*thenParseBegin)) {
+  if (!isa<LiteralElement, VariableElement, CustomDirective>(*thenParseBegin)) {
     return emitError(loc, "first parsable element of an optional group must be "
-                          "a literal or variable");
+                          "a literal, variable, or custom directive");
   }
   return create<OptionalElement>(std::move(thenElements),
                                  std::move(elseElements), thenParseStart,

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 3d045044103c14..bdb97866a47fc9 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -953,7 +953,8 @@ static void genCustomParameterParser(FormatElement *param, MethodBody &body) {
 /// Generate the parser for a custom directive.
 static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
                                      bool useProperties,
-                                     StringRef opCppClassName) {
+                                     StringRef opCppClassName,
+                                     bool isOptional = false) {
   body << "  {\n";
 
   // Preprocess the directive variables.
@@ -1011,14 +1012,19 @@ static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body,
     }
   }
 
-  body << "    if (parse" << dir->getName() << "(parser";
+  body << "    auto odsResult = parse" << dir->getName() << "(parser";
   for (FormatElement *param : dir->getArguments()) {
     body << ", ";
     genCustomParameterParser(param, body);
   }
+  body << ");\n";
 
-  body << "))\n"
-       << "      return ::mlir::failure();\n";
+  if (isOptional) {
+    body << "    if (!odsResult) return {};\n"
+         << "    if (::mlir::failed(*odsResult)) return ::mlir::failure();\n";
+  } else {
+    body << "    if (odsResult) return ::mlir::failure();\n";
+  }
 
   // After parsing, add handling for any of the optional constructs.
   for (FormatElement *param : dir->getArguments()) {
@@ -1273,6 +1279,14 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
           body << llvm::formatv(regionEnsureSingleBlockParserCode,
                                 region->name);
       }
+    } else if (auto *custom = dyn_cast<CustomDirective>(firstElement)) {
+      body << "  if (auto result = [&]() -> ::mlir::OptionalParseResult {\n";
+      genCustomDirectiveParser(custom, body, useProperties, opCppClassName,
+                               /*isOptional=*/true);
+      body << "    return ::mlir::success();\n"
+           << "  }(); result.has_value() && ::mlir::failed(*result)) {\n"
+           << "    return ::mlir::failure();\n"
+           << "  } else if (result.has_value()) {\n";
     }
 
     genElementParsers(firstElement, thenElements.drop_front(),


        


More information about the Mlir-commits mailing list