[Mlir-commits] [mlir] 29bb0b5 - [mlir] Generate parser/printers for enums
River Riddle
llvmlistbot at llvm.org
Fri Oct 21 15:32:52 PDT 2022
Author: River Riddle
Date: 2022-10-21T15:32:36-07:00
New Revision: 29bb0b5e1d4d4563a2c7266d16496ffefba24473
URL: https://github.com/llvm/llvm-project/commit/29bb0b5e1d4d4563a2c7266d16496ffefba24473
DIFF: https://github.com/llvm/llvm-project/commit/29bb0b5e1d4d4563a2c7266d16496ffefba24473.diff
LOG: [mlir] Generate parser/printers for enums
This greatly simplifies composing enums in attribute/type printers,
which currently reimplement these functions as needed.
Differential Revision: https://reviews.llvm.org/D136407
Added:
Modified:
mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/func.mlir
mlir/test/mlir-tblgen/enums-gen.td
mlir/tools/mlir-tblgen/EnumsGen.cpp
mlir/tools/mlir-tblgen/FormatGen.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index 95d4a90710f93..4c40060f700f0 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -21,7 +21,7 @@ def LinkageAttr : LLVM_Attr<"Linkage"> {
let parameters = (ins
"linkage::Linkage":$linkage
);
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = "`<` $linkage `>`";
}
// Attribute definition for the LLVM Linkage enum.
@@ -30,7 +30,7 @@ def CConvAttr : LLVM_Attr<"CConv"> {
let parameters = (ins
"CConv":$CallingConv
);
- let hasCustomAssemblyFormat = 1;
+ let assemblyFormat = "`<` $CallingConv `>`";
}
def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 3ec212add15ac..980fc197e9f50 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2797,54 +2797,6 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
}
-void LinkageAttr::print(AsmPrinter &printer) const {
- printer << "<";
- if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage())
- printer << stringifyEnum(getLinkage());
- else
- printer << static_cast<uint64_t>(getLinkage());
- printer << ">";
-}
-
-Attribute LinkageAttr::parse(AsmParser &parser, Type type) {
- StringRef elemName;
- if (parser.parseLess() || parser.parseKeyword(&elemName) ||
- parser.parseGreater())
- return {};
- auto elem = linkage::symbolizeLinkage(elemName);
- if (!elem) {
- parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName;
- return {};
- }
- Linkage linkage = *elem;
- return LinkageAttr::get(parser.getContext(), linkage);
-}
-
-void CConvAttr::print(AsmPrinter &printer) const {
- printer << "<";
- if (static_cast<uint64_t>(getCallingConv()) <= cconv::getMaxEnumValForCConv())
- printer << stringifyEnum(getCallingConv());
- else
- printer << "INVALID_cc_" << static_cast<uint64_t>(getCallingConv());
- printer << ">";
-}
-
-Attribute CConvAttr::parse(AsmParser &parser, Type type) {
- StringRef convName;
-
- if (parser.parseLess() || parser.parseKeyword(&convName) ||
- parser.parseGreater())
- return {};
- auto cconv = cconv::symbolizeCConv(convName);
- if (!cconv) {
- parser.emitError(parser.getNameLoc(), "unknown calling convention: ")
- << convName;
- return {};
- }
- CConv cconvVal = *cconv;
- return CConvAttr::get(parser.getContext(), cconvVal);
-}
-
LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
: options(attr.getOptions().begin(), attr.getOptions().end()) {}
diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir
index da46908782ace..17cc6bf564793 100644
--- a/mlir/test/Dialect/LLVMIR/func.mlir
+++ b/mlir/test/Dialect/LLVMIR/func.mlir
@@ -273,8 +273,9 @@ module {
// -----
module {
- // expected-error@+2 {{unknown calling convention: cc_12}}
"llvm.func"() ({
+ // expected-error @below {{invalid Calling Conventions specification: cc_12}}
+ // expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
}) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
}
diff --git a/mlir/test/mlir-tblgen/enums-gen.td b/mlir/test/mlir-tblgen/enums-gen.td
index ebe126467be60..ed1b8f56c664c 100644
--- a/mlir/test/mlir-tblgen/enums-gen.td
+++ b/mlir/test/mlir-tblgen/enums-gen.td
@@ -28,6 +28,24 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
// DECL: std::string stringifyMyBitEnum(MyBitEnum);
// DECL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(::llvm::StringRef);
+// DECL: struct FieldParser<::MyBitEnum, ::MyBitEnum> {
+// DECL: template <typename ParserT>
+// DECL: static FailureOr<::MyBitEnum> parse(ParserT &parser) {
+// DECL: // Parse the keyword/string containing the enum.
+// DECL: std::string enumKeyword;
+// DECL: auto loc = parser.getCurrentLocation();
+// DECL: if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
+// DECL: return parser.emitError(loc, "expected keyword for An example bit enum");
+// DECL: // Symbolize the keyword.
+// DECL: if (::llvm::Optional<::MyBitEnum> attr = ::symbolizeEnum<::MyBitEnum>(enumKeyword))
+// DECL: return *attr;
+// DECL: return parser.emitError(loc, "invalid An example bit enum specification: ") << enumKeyword;
+// DECL: }
+
+// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyBitEnum value) {
+// DECL: auto valueStr = stringifyEnum(value);
+// DECL: return p << valueStr;
+
// DEF-LABEL: std::string stringifyMyBitEnum
// DEF: auto val = static_cast<uint32_t>
// DEF: if (val == 0) return "None";
@@ -40,3 +58,34 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
// DEF: if (str == "None") return MyBitEnum::None;
// DEF: .Case("tagged", 1)
// DEF: .Case("Bit1", 2)
+
+// Test enum printer generation for non-keyword enums.
+
+def NonKeywordBit: I32BitEnumAttrCaseBit<"Bit0", 0, "tag-ged">;
+def MyMixedNonKeywordBitEnum: I32BitEnumAttr<"MyMixedNonKeywordBitEnum", "An example bit enum", [
+ NonKeywordBit,
+ Bit1
+ ]> {
+ let genSpecializedAttr = 0;
+}
+
+def MyNonKeywordBitEnum: I32BitEnumAttr<"MyNonKeywordBitEnum", "An example bit enum", [
+ NonKeywordBit
+ ]> {
+ let genSpecializedAttr = 0;
+}
+
+// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyMixedNonKeywordBitEnum value) {
+// DECL: auto valueStr = stringifyEnum(value);
+// DECL: switch (value) {
+// DECL: case ::MyMixedNonKeywordBitEnum::Bit1:
+// DECL: break;
+// DECL: default:
+// DECL: return p << '"' << valueStr << '"';
+// DECL: }
+// DECL: return p << valueStr;
+// DECL: }
+
+// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyNonKeywordBitEnum value) {
+// DECL: auto valueStr = stringifyEnum(value);
+// DECL: return p << '"' << valueStr << '"';
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 60dde0677b290..c84995e863b8f 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -10,9 +10,11 @@
//
//===----------------------------------------------------------------------===//
+#include "FormatGen.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
@@ -65,10 +67,92 @@ static void emitEnumClass(const Record &enumDef, StringRef enumName,
os << "};\n\n";
}
-static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
+static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName,
+ StringRef cppNamespace, raw_ostream &os) {
+ if (enumAttr.getUnderlyingType().empty() ||
+ enumAttr.getConstBuilderTemplate().empty())
+ return;
+ auto cases = enumAttr.getAllCases();
+
+ // Check which cases shouldn't be printed using a keyword.
+ llvm::BitVector nonKeywordCases(cases.size());
+ for (auto [index, caseVal] : llvm::enumerate(cases))
+ if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr()))
+ nonKeywordCases.set(index);
+
+ // If this is a bit enum attribute, don't allow cases that may overlap with
+ // other cases. For simplicity's sake, only allow cases with a single bit value.
+ if (enumAttr.isBitEnum()) {
+ for (auto [index, caseVal] : llvm::enumerate(cases)) {
+ int64_t value = caseVal.getValue();
+ if (value < 0 || (value != 0 && !llvm::isPowerOf2_64(value)))
+ nonKeywordCases.set(index);
+ }
+ }
+
+ // Generate the parser and the start of the printer for the enum.
+ const char *parsedAndPrinterStart = R"(
+namespace mlir {
+template <typename T, typename>
+struct FieldParser;
+
+template<>
+struct FieldParser<{0}, {0}> {{
+ template <typename ParserT>
+ static FailureOr<{0}> parse(ParserT &parser) {{
+ // Parse the keyword/string containing the enum.
+ std::string enumKeyword;
+ auto loc = parser.getCurrentLocation();
+ if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
+ return parser.emitError(loc, "expected keyword for {2}");
+
+ // Symbolize the keyword.
+ if (::llvm::Optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
+ return *attr;
+ return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
+ }
+};
+} // namespace mlir
+
+namespace llvm {
+inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
+ auto valueStr = stringifyEnum(value);
+)";
+ os << formatv(parsedAndPrinterStart, qualName, cppNamespace,
+ enumAttr.getSummary());
+
+ // If all cases require a string, always wrap.
+ if (nonKeywordCases.all()) {
+ os << " return p << '\"' << valueStr << '\"';\n"
+ "}\n"
+ "} // namespace llvm\n";
+ return;
+ }
+
+ // If there are any cases that can't be used with a keyword, switch on the
+ // case value to determine when to print in the string form.
+ if (nonKeywordCases.any()) {
+ os << " switch (value) {\n";
+ for (auto &it : llvm::enumerate(cases)) {
+ if (nonKeywordCases.test(it.index()))
+ continue;
+ StringRef symbol = it.value().getSymbol();
+ os << llvm::formatv(" case {0}::{1}:\n", qualName,
+ llvm::isDigit(symbol.front()) ? ("_" + symbol)
+ : symbol);
+ }
+ os << " break;\n"
+ " default:\n"
+ " return p << '\"' << valueStr << '\"';\n"
+ " }\n";
+ }
+ os << " return p << valueStr;\n"
+ "}\n"
+ "} // namespace llvm\n";
+}
+
+static void emitDenseMapInfo(StringRef qualName, std::string underlyingType,
StringRef cppNamespace, raw_ostream &os) {
- std::string qualName =
- std::string(formatv("{0}::{1}", cppNamespace, enumName));
if (underlyingType.empty())
underlyingType =
std::string(formatv("std::underlying_type_t<{0}>", qualName));
@@ -529,8 +613,13 @@ class {1} : public ::mlir::{2} {
for (auto ns : llvm::reverse(namespaces))
os << "} // namespace " << ns << "\n";
+ // Generate a generic parser and printer for the enum.
+ std::string qualName =
+ std::string(formatv("{0}::{1}", cppNamespace, enumName));
+ emitParserPrinter(enumAttr, qualName, cppNamespace, os);
+
// Emit DenseMapInfo for this enum class
- emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
+ emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
}
static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index a756587803889..7d2e03ecfe278 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -444,6 +444,11 @@ bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,
bool mlir::tblgen::canFormatStringAsKeyword(
StringRef value, function_ref<void(Twine)> emitError) {
+ if (value.empty()) {
+ if (emitError)
+ emitError("keywords cannot be empty");
+ return false;
+ }
if (!isalpha(value.front()) && value.front() != '_') {
if (emitError)
emitError("valid keyword starts with a letter or '_'");
More information about the Mlir-commits
mailing list