[Mlir-commits] [mlir] [mlir][ODS] Deduplicate `ref` and `qualified` handling (PR #91080)
Markus Böck
llvmlistbot at llvm.org
Sat May 4 11:52:03 PDT 2024
https://github.com/zero9178 created https://github.com/llvm/llvm-project/pull/91080
The attribute/type format generator and the op format generator each implemented parsing and verification of the `ref` and `qualified` directives independently, with little to no difference between the two implementations.
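This PR moves the handling of both directives into the common `FormatParser` base class to remove the duplication. The generator-specific part of `qualified` (deciding which elements may be qualified) is delegated to a new pure virtual hook, `markQualified`. As a side effect, using `ref` outside a `custom` directive now reports the same diagnostic for attributes and types as it does for ops, which is why the test expectation changes.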
From 56d096bb46f42bc7c82503b300faa94a8d00254b Mon Sep 17 00:00:00 2001
From: Markus Böck <markus.boeck02 at gmail.com>
Date: Sat, 4 May 2024 19:51:04 +0100
Subject: [PATCH] [mlir][ODS] Deduplicate `ref` and `qualified` handling
The attribute/type format generator and the op format generator each implemented parsing and verification of the `ref` and `qualified` directives independently, with little to no difference between the two implementations.
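This PR moves the handling of both directives into the common `FormatParser` base class to remove the duplication. The generator-specific part of `qualified` (deciding which elements may be qualified) is delegated to a new pure virtual hook, `markQualified`. As a side effect, using `ref` outside a `custom` directive now reports the same diagnostic for attributes and types as it does for ops, which is why the test expectation changes.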
---
.../attr-or-type-format-invalid.td | 2 +-
.../tools/mlir-tblgen/AttrOrTypeFormatGen.cpp | 52 ++++---------------
mlir/tools/mlir-tblgen/FormatGen.cpp | 36 +++++++++++++
mlir/tools/mlir-tblgen/FormatGen.h | 10 +++-
mlir/tools/mlir-tblgen/OpFormatGen.cpp | 40 ++------------
5 files changed, 61 insertions(+), 79 deletions(-)
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
index d3be4d8b8022a0..3a57cbca4d7bb7 100644
--- a/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format-invalid.td
@@ -111,7 +111,7 @@ def InvalidTypeN : InvalidType<"InvalidTypeN", "invalid_n"> {
def InvalidTypeO : InvalidType<"InvalidTypeO", "invalid_o"> {
let parameters = (ins "int":$a);
- // CHECK: `ref` is only allowed inside custom directives
+ // CHECK: 'ref' is only valid within a `custom` directive
let assemblyFormat = "$a ref($a)";
}
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
index 6098808c646f76..abd1fbdaf8c649 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -940,6 +940,8 @@ class DefFormatParser : public FormatParser {
ArrayRef<FormatElement *> elements,
FormatElement *anchor) override;
+ LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
+
/// Parse an attribute or type variable.
FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
Context ctx) override;
@@ -950,12 +952,8 @@ class DefFormatParser : public FormatParser {
private:
/// Parse a `params` directive.
FailureOr<FormatElement *> parseParamsDirective(SMLoc loc, Context ctx);
- /// Parse a `qualified` directive.
- FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
/// Parse a `struct` directive.
FailureOr<FormatElement *> parseStructDirective(SMLoc loc, Context ctx);
- /// Parse a `ref` directive.
- FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context ctx);
/// Attribute or type tablegen def.
const AttrOrTypeDef &def;
@@ -1060,6 +1058,14 @@ DefFormatParser::verifyOptionalGroupElements(llvm::SMLoc loc,
return success();
}
+LogicalResult DefFormatParser::markQualified(SMLoc loc,
+ FormatElement *element) {
+ if (!isa<ParameterElement>(element))
+ return emitError(loc, "`qualified` argument list expected a variable");
+ cast<ParameterElement>(element)->setShouldBeQualified();
+ return success();
+}
+
FailureOr<DefFormat> DefFormatParser::parse() {
FailureOr<std::vector<FormatElement *>> elements = FormatParser::parse();
if (failed(elements))
@@ -1107,33 +1113,11 @@ DefFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
return parseParamsDirective(loc, ctx);
case FormatToken::kw_struct:
return parseStructDirective(loc, ctx);
- case FormatToken::kw_ref:
- return parseRefDirective(loc, ctx);
- case FormatToken::kw_custom:
- return parseCustomDirective(loc, ctx);
-
default:
return emitError(loc, "unsupported directive kind");
}
}
-FailureOr<FormatElement *>
-DefFormatParser::parseQualifiedDirective(SMLoc loc, Context ctx) {
- if (failed(parseToken(FormatToken::l_paren,
- "expected '(' before argument list")))
- return failure();
- FailureOr<FormatElement *> var = parseElement(ctx);
- if (failed(var))
- return var;
- if (!isa<ParameterElement>(*var))
- return emitError(loc, "`qualified` argument list expected a variable");
- cast<ParameterElement>(*var)->setShouldBeQualified();
- if (failed(
- parseToken(FormatToken::r_paren, "expected ')' after argument list")))
- return failure();
- return var;
-}
-
FailureOr<FormatElement *> DefFormatParser::parseParamsDirective(SMLoc loc,
Context ctx) {
// It doesn't make sense to allow references to all parameters in a custom
@@ -1201,22 +1185,6 @@ FailureOr<FormatElement *> DefFormatParser::parseStructDirective(SMLoc loc,
return create<StructDirective>(std::move(vars));
}
-FailureOr<FormatElement *> DefFormatParser::parseRefDirective(SMLoc loc,
- Context ctx) {
- if (ctx != CustomDirectiveContext)
- return emitError(loc, "`ref` is only allowed inside custom directives");
-
- // Parse the child parameter element.
- FailureOr<FormatElement *> child;
- if (failed(parseToken(FormatToken::l_paren, "expected '('")) ||
- failed(child = parseElement(RefDirectiveContext)) ||
- failed(parseToken(FormatToken::r_paren, "expeced ')'")))
- return failure();
-
- // Only parameter elements are allowed to be parsed under a `ref` directive.
- return create<RefDirective>(*child);
-}
-
//===----------------------------------------------------------------------===//
// Interface
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index d402748b96ad5f..7540e584b8fac5 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -308,6 +308,10 @@ FailureOr<FormatElement *> FormatParser::parseDirective(Context ctx) {
if (tok.is(FormatToken::kw_custom))
return parseCustomDirective(loc, ctx);
+ if (tok.is(FormatToken::kw_ref))
+ return parseRefDirective(loc, ctx);
+ if (tok.is(FormatToken::kw_qualified))
+ return parseQualifiedDirective(loc, ctx);
return parseDirectiveImpl(loc, tok.getKind(), ctx);
}
@@ -430,6 +434,38 @@ FailureOr<FormatElement *> FormatParser::parseCustomDirective(SMLoc loc,
return create<CustomDirective>(nameTok->getSpelling(), std::move(arguments));
}
+FailureOr<FormatElement *> FormatParser::parseRefDirective(SMLoc loc,
+ Context context) {
+ if (context != CustomDirectiveContext)
+ return emitError(loc, "'ref' is only valid within a `custom` directive");
+
+ FailureOr<FormatElement *> arg;
+ if (failed(parseToken(FormatToken::l_paren,
+ "expected '(' before argument list")) ||
+ failed(arg = parseElement(RefDirectiveContext)) ||
+ failed(
+ parseToken(FormatToken::r_paren, "expected ')' after argument list")))
+ return failure();
+
+ return create<RefDirective>(*arg);
+}
+
+FailureOr<FormatElement *> FormatParser::parseQualifiedDirective(SMLoc loc,
+ Context ctx) {
+ if (failed(parseToken(FormatToken::l_paren,
+ "expected '(' before argument list")))
+ return failure();
+ FailureOr<FormatElement *> var = parseElement(ctx);
+ if (failed(var))
+ return var;
+ if (failed(markQualified(loc, *var)))
+ return failure();
+ if (failed(
+ parseToken(FormatToken::r_paren, "expected ')' after argument list")))
+ return failure();
+ return var;
+}
+
//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index 18a410277fc108..b061d4d8ea7f03 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -495,9 +495,12 @@ class FormatParser {
FailureOr<FormatElement *> parseDirective(Context ctx);
/// Parse an optional group.
FailureOr<FormatElement *> parseOptionalGroup(Context ctx);
-
/// Parse a custom directive.
FailureOr<FormatElement *> parseCustomDirective(llvm::SMLoc loc, Context ctx);
+ /// Parse a ref directive.
+ FailureOr<FormatElement *> parseRefDirective(SMLoc loc, Context context);
+ /// Parse a qualified directive.
+ FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc, Context ctx);
/// Parse a format-specific variable kind.
virtual FailureOr<FormatElement *>
@@ -522,6 +525,11 @@ class FormatParser {
ArrayRef<FormatElement *> elements,
FormatElement *anchor) = 0;
+ /// Mark 'element' as qualified. If 'element' cannot be qualified an error
+ /// should be emitted and failure returned.
+ virtual LogicalResult markQualified(llvm::SMLoc loc,
+ FormatElement *element) = 0;
+
//===--------------------------------------------------------------------===//
// Lexer Utilities
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index 806991035e6685..f7cc0a292b8c53 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -2547,6 +2547,8 @@ class OpFormatParser : public FormatParser {
LogicalResult verifyOptionalGroupElement(SMLoc loc, FormatElement *element,
bool isAnchor);
+ LogicalResult markQualified(SMLoc loc, FormatElement *element) override;
+
/// Parse an operation variable.
FailureOr<FormatElement *> parseVariableImpl(SMLoc loc, StringRef name,
Context ctx) override;
@@ -2622,10 +2624,6 @@ class OpFormatParser : public FormatParser {
FailureOr<FormatElement *> parseOIListDirective(SMLoc loc, Context context);
LogicalResult verifyOIListParsingElement(FormatElement *element, SMLoc loc);
FailureOr<FormatElement *> parseOperandsDirective(SMLoc loc, Context context);
- FailureOr<FormatElement *> parseQualifiedDirective(SMLoc loc,
- Context context);
- FailureOr<FormatElement *> parseReferenceDirective(SMLoc loc,
- Context context);
FailureOr<FormatElement *> parseRegionsDirective(SMLoc loc, Context context);
FailureOr<FormatElement *> parseResultsDirective(SMLoc loc, Context context);
FailureOr<FormatElement *> parseSuccessorsDirective(SMLoc loc,
@@ -3224,16 +3222,12 @@ OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
return parseFunctionalTypeDirective(loc, ctx);
case FormatToken::kw_operands:
return parseOperandsDirective(loc, ctx);
- case FormatToken::kw_qualified:
- return parseQualifiedDirective(loc, ctx);
case FormatToken::kw_regions:
return parseRegionsDirective(loc, ctx);
case FormatToken::kw_results:
return parseResultsDirective(loc, ctx);
case FormatToken::kw_successors:
return parseSuccessorsDirective(loc, ctx);
- case FormatToken::kw_ref:
- return parseReferenceDirective(loc, ctx);
case FormatToken::kw_type:
return parseTypeDirective(loc, ctx);
case FormatToken::kw_oilist:
@@ -3338,22 +3332,6 @@ OpFormatParser::parseOperandsDirective(SMLoc loc, Context context) {
return create<OperandsDirective>();
}
-FailureOr<FormatElement *>
-OpFormatParser::parseReferenceDirective(SMLoc loc, Context context) {
- if (context != CustomDirectiveContext)
- return emitError(loc, "'ref' is only valid within a `custom` directive");
-
- FailureOr<FormatElement *> arg;
- if (failed(parseToken(FormatToken::l_paren,
- "expected '(' before argument list")) ||
- failed(arg = parseElement(RefDirectiveContext)) ||
- failed(
- parseToken(FormatToken::r_paren, "expected ')' after argument list")))
- return failure();
-
- return create<RefDirective>(*arg);
-}
-
FailureOr<FormatElement *>
OpFormatParser::parseRegionsDirective(SMLoc loc, Context context) {
if (context == TypeDirectiveContext)
@@ -3495,19 +3473,11 @@ FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
return create<TypeDirective>(*operand);
}
-FailureOr<FormatElement *>
-OpFormatParser::parseQualifiedDirective(SMLoc loc, Context context) {
- FailureOr<FormatElement *> element;
- if (failed(parseToken(FormatToken::l_paren,
- "expected '(' before argument list")) ||
- failed(element = parseElement(context)) ||
- failed(
- parseToken(FormatToken::r_paren, "expected ')' after argument list")))
- return failure();
- return TypeSwitch<FormatElement *, FailureOr<FormatElement *>>(*element)
+LogicalResult OpFormatParser::markQualified(SMLoc loc, FormatElement *element) {
+ return TypeSwitch<FormatElement *, LogicalResult>(element)
.Case<AttributeVariable, TypeDirective>([](auto *element) {
element->setShouldBeQualified();
- return element;
+ return success();
})
.Default([&](auto *element) {
return this->emitError(
More information about the Mlir-commits mailing list