[Mlir-commits] [mlir] [mlir][ODS] Deduplicate `ref` and `qualified` handling (PR #91080)

Markus Böck llvmlistbot at llvm.org
Sat May 4 11:52:03 PDT 2024


https://github.com/zero9178 created https://github.com/llvm/llvm-project/pull/91080

The attribute/type format generator and the op format generator each independently implemented the parsing and verification of the `ref` and `qualified` directives, with little to no difference between the two implementations.

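This PR moves the implementation of these into the common `FormatParser` class to deduplicate the implementations. The only format-specific step, deciding whether an element may be marked as qualified, is delegated to a new pure-virtual `markQualified` hook that each generator overrides. As a side effect, the attribute/type generator now emits the op generator's wording for a misplaced `ref` ("'ref' is only valid within a `custom` directive").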

>From 56d096bb46f42bc7c82503b300faa94a8d00254b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Sat, 4 May 2024 19:51:04 +0100
Subject: [PATCH] [mlir][ODS] Deduplicate `ref` and `qualified` handling

Both the attribute and type format generator and the op format generator independently implemented the parsing and verification of the `ref` and `qualified` directives with little to no differences.

This PR moves the implementation of these into the common `FormatParser` class to deduplicate the implementations.
---
 .../attr-or-type-format-invalid.td            |  2 +-
 .../tools/mlir-tblgen/AttrOrTypeFormatGen.cpp | 52 ++++---------------
 mlir/tools/mlir-tblgen/FormatGen.cpp          | 36 +++++++++++++
 mlir/tools/mlir-tblgen/FormatGen.h            | 10 +++-
 mlir/tools/mlir-tblgen/OpFormatGen.cpp        | 40 ++------------
 5 files changed, 61 insertions(+), 79 deletions(-)

diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
index d3be4d8b8022a0..3a57cbca4d7bb7 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -111,7 +111,7 @@ def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> {
 
 def InvalidTypeO : InvalidType<"InvalidTypeO", "invalid_o"> {
   let parameters = (ins "int":$a);
-  // CHECK: `ref` is only allowed inside custom directives
+  // CHECK: 'ref' is only valid within a `custom` directive
   let assemblyFormat = "$a ref($a)";
 }
 
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 6098808c646f76..abd1fbdaf8c649 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -940,6 +940,8 @@ class DefFormatParser : public FormatParser {
                                             ArrayRef<FormatElement *> elements,
                                             FormatElement *anchor) override;
 
+  LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
+
   /// Parse an attribute or type variable.
   FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
                                                Context ctx) override;
@@ -950,12 +952,8 @@ class DefFormatParser : public FormatParser {
 private:
   /// Parse a `params` directive.
   FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
-  /// Parse a `qualified` directive.
-  FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
   /// Parse a `struct` directive.
   FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
-  /// Parse a `ref` directive.
-  FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx);
 
   /// Attribute or type tablegen def.
   const AttrOrTypeDef &def;
@@ -1060,6 +1058,14 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
   return success();
 }
 
+LogicalResult DefFormatParser::markQualified(SMLoc loc,
+                                             FormatElement *element) {
+  if (!isa<ParameterElement>(element))
+    return emitError(loc, "`qualified` argument list expected a variable");
+  cast<ParameterElement>(element)->setShouldBeQualified();
+  return success();
+}
+
 FailureOr<DefFormat> DefFormatParser::parse() {
   FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
   if (failed(elements))
@@ -1107,33 +1113,11 @@ DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
     return parseParamsDirective(loc, ctx);
   case FormatToken::kw_struct:
     return parseStructDirective(loc, ctx);
-  case FormatToken::kw_ref:
-    return parseRefDirective(loc, ctx);
-  case FormatToken::kw_custom:
-    return parseCustomDirective(loc, ctx);
-
   default:
     return emitError(loc, "unsupported directive kind");
   }
 }
 
-FailureOr<FormatElement *>
-DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
-  if (failed(parseToken(FormatToken::l_paren,
-                        "expected '(' before argument list")))
-    return failure();
-  FailureOr<FormatElement *> var = parseElement(ctx);
-  if (failed(var))
-    return var;
-  if (!isa<ParameterElement>(*var))
-    return emitError(loc, "`qualified` argument list expected a variable");
-  cast<ParameterElement>(*var)->setShouldBeQualified();
-  if (failed(
-          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
-    return failure();
-  return var;
-}
-
 FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
                                                                  Context ctx) {
   // It doesn't make sense to allow references to all parameters in a custom
@@ -1201,22 +1185,6 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
   return create<StructDirective>(std::move(vars));
 }
 
-FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc,
-                                                              Context ctx) {
-  if (ctx != CustomDirectiveContext)
-    return emitError(loc, "`ref` is only allowed inside custom directives");
-
-  // Parse the child parameter element.
-  FailureOr<FormatElement *> child;
-  if (failed(parseToken(FormatToken::l_paren, "expected '('")) ||
-      failed(child = parseElement(RefDirectiveContext)) ||
-      failed(parseToken(FormatToken::r_paren, "expeced ')'")))
-    return failure();
-
-  // Only parameter elements are allowed to be parsed under a `ref` directive.
-  return create<RefDirective>(*child);
-}
-
 //===----------------------------------------------------------------------===//
 // Interface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index d402748b96ad5f..7540e584b8fac5 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -308,6 +308,10 @@ FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) {
 
   if (tok.is(FormatToken::kw_custom))
     return parseCustomDirective(loc, ctx);
+  if (tok.is(FormatToken::kw_ref))
+    return parseRefDirective(loc, ctx);
+  if (tok.is(FormatToken::kw_qualified))
+    return parseQualifiedDirective(loc, ctx);
   return parseDirectiveImpl(loc, tok.getKind(), ctx);
 }
 
@@ -430,6 +434,38 @@ FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
   return create<CustomDirective>(nameTok->getSpelling(), std::move(arguments));
 }
 
+FailureOr<FormatElement *> FormatParser::parseRefDirective(SMLoc loc,
+                                                           Context context) {
+  if (context != CustomDirectiveContext)
+    return emitError(loc, "'ref' is only valid within a `custom` directive");
+
+  FailureOr<FormatElement *> arg;
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before argument list")) ||
+      failed(arg = parseElement(RefDirectiveContext)) ||
+      failed(
+          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
+    return failure();
+
+  return create<RefDirective>(*arg);
+}
+
+FailureOr<FormatElement *> FormatParser::parseQualifiedDirective(SMLoc loc,
+                                                                 Context ctx) {
+  if (failed(parseToken(FormatToken::l_paren,
+                        "expected '(' before argument list")))
+    return failure();
+  FailureOr<FormatElement *> var = parseElement(ctx);
+  if (failed(var))
+    return var;
+  if (failed(markQualified(loc, *var)))
+    return failure();
+  if (failed(
+          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
+    return failure();
+  return var;
+}
+
 //===----------------------------------------------------------------------===//
 // Utility Functions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index 18a410277fc108..b061d4d8ea7f03 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -495,9 +495,12 @@ class FormatParser {
   FailureOr<FormatElement *> parseDirective(Context ctx);
   /// Parse an optional group.
   FailureOr<FormatElement *> parseOptionalGroup(Context ctx);
-
   /// Parse a custom directive.
   FailureOr<FormatElement *> parseCustomDirective(llvm::SMLoc loc, Context ctx);
+  /// Parse a ref directive.
+  FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context context);
+  /// Parse a qualified directive.
+  FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
 
   /// Parse a format-specific variable kind.
   virtual FailureOr<FormatElement *>
@@ -522,6 +525,11 @@ class FormatParser {
                               ArrayRef<FormatElement *> elements,
                               FormatElement *anchor) = 0;
 
+  /// Mark 'element' as qualified. If 'element' cannot be qualified an error
+  /// should be emitted and failure returned.
+  virtual LogicalResult markQualified(llvm::SMLoc loc,
+                                      FormatElement *element) = 0;
+
   //===--------------------------------------------------------------------===//
   // Lexer Utilities
 
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 806991035e6685..f7cc0a292b8c53 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -2547,6 +2547,8 @@ class OpFormatParser : public FormatParser {
   LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
                                            bool isAnchor);
 
+  LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
+
   /// Parse an operation variable.
   FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
                                                Context ctx) override;
@@ -2622,10 +2624,6 @@ class OpFormatParser : public FormatParser {
   FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
   LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
   FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
-  FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc,
-                                                     Context context);
-  FailureOr<FormatElement *> parseReferenceDirective(SMLoc loc,
-                                                     Context context);
   FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
   FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
   FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
@@ -3224,16 +3222,12 @@ OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
     return parseFunctionalTypeDirective(loc, ctx);
   case FormatToken::kw_operands:
     return parseOperandsDirective(loc, ctx);
-  case FormatToken::kw_qualified:
-    return parseQualifiedDirective(loc, ctx);
   case FormatToken::kw_regions:
     return parseRegionsDirective(loc, ctx);
   case FormatToken::kw_results:
     return parseResultsDirective(loc, ctx);
   case FormatToken::kw_successors:
     return parseSuccessorsDirective(loc, ctx);
-  case FormatToken::kw_ref:
-    return parseReferenceDirective(loc, ctx);
   case FormatToken::kw_type:
     return parseTypeDirective(loc, ctx);
   case FormatToken::kw_oilist:
@@ -3338,22 +3332,6 @@ OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) {
   return create<OperandsDirective>();
 }
 
-FailureOr<FormatElement *>
-OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) {
-  if (context != CustomDirectiveContext)
-    return emitError(loc, "'ref' is only valid within a `custom` directive");
-
-  FailureOr<FormatElement *> arg;
-  if (failed(parseToken(FormatToken::l_paren,
-                        "expected '(' before argument list")) ||
-      failed(arg = parseElement(RefDirectiveContext)) ||
-      failed(
-          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
-    return failure();
-
-  return create<RefDirective>(*arg);
-}
-
 FailureOr<FormatElement *>
 OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) {
   if (context == TypeDirectiveContext)
@@ -3495,19 +3473,11 @@ FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
   return create<TypeDirective>(*operand);
 }
 
-FailureOr<FormatElement *>
-OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) {
-  FailureOr<FormatElement *> element;
-  if (failed(parseToken(FormatToken::l_paren,
-                        "expected '(' before argument list")) ||
-      failed(element = parseElement(context)) ||
-      failed(
-          parseToken(FormatToken::r_paren, "expected ')' after argument list")))
-    return failure();
-  return TypeSwitch<FormatElement *, FailureOr<FormatElement *>>(*element)
+LogicalResult OpFormatParser::markQualified(SMLoc loc, FormatElement *element) {
+  return TypeSwitch<FormatElement *, LogicalResult>(element)
       .Case<AttributeVariable, TypeDirective>([](auto *element) {
         element->setShouldBeQualified();
-        return element;
+        return success();
       })
       .Default([&](auto *element) {
         return this->emitError(



More information about the Mlir-commits mailing list