[Mlir-commits] [mlir] 82c6f10 - [mlir] Better handling for bit groups in enum parser/printer
River Riddle
llvmlistbot at llvm.org
Mon Oct 24 00:00:10 PDT 2022
Author: River Riddle
Date: 2022-10-23T23:59:55-07:00
New Revision: 82c6f10052ac57b8a51b97898f32b40ec806da03
URL: https://github.com/llvm/llvm-project/commit/82c6f10052ac57b8a51b97898f32b40ec806da03
DIFF: https://github.com/llvm/llvm-project/commit/82c6f10052ac57b8a51b97898f32b40ec806da03.diff
LOG: [mlir] Better handling for bit groups in enum parser/printer
We currently wrap all multi-bit cases in a string, but this is
overly restrictive. This commit refactors to use keywords when
we know they are valid, and only degrade to a string when the validity
of the bit group is unknown.
Differential Revision: https://reviews.llvm.org/D136540
Added:
Modified:
mlir/test/mlir-tblgen/enums-gen.td
mlir/tools/mlir-tblgen/EnumsGen.cpp
Removed:
################################################################################
diff --git a/mlir/test/mlir-tblgen/enums-gen.td b/mlir/test/mlir-tblgen/enums-gen.td
index ed1b8f56c664c..977647ccf7b6e 100644
--- a/mlir/test/mlir-tblgen/enums-gen.td
+++ b/mlir/test/mlir-tblgen/enums-gen.td
@@ -10,9 +10,12 @@ def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">;
def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>;
def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>;
def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>;
+def BitGroup: I32BitEnumAttrCaseGroup<"BitGroup", [
+ Bit0, Bit1
+]>;
def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
- [None, Bit0, Bit1, Bit2, Bit3]> {
+ [None, Bit0, Bit1, Bit2, Bit3, BitGroup]> {
let genSpecializedAttr = 0;
}
@@ -44,6 +47,15 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyBitEnum value) {
// DECL: auto valueStr = stringifyEnum(value);
+// DECL: switch (value) {
+// DECL: case ::MyBitEnum::BitGroup:
+// DECL: return p << valueStr;
+// DECL: default:
+// DECL: break;
+// DECL: }
+// DECL: auto underlyingValue = static_cast<std::make_unsigned_t<::MyBitEnum>>(value);
+// DECL: if (underlyingValue && !llvm::has_single_bit(underlyingValue))
+// DECL: return p << '"' << valueStr << '"';
// DECL: return p << valueStr;
// DEF-LABEL: std::string stringifyMyBitEnum
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index c84995e863b8f..e95d5d61a693a 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -80,16 +80,6 @@ static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName,
if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr()))
nonKeywordCases.set(index);
- // If this is a bit enum attribute, don't allow cases that may overlap with
- // other cases. For simplicity sake, only allow cases with a single bit value.
- if (enumAttr.isBitEnum()) {
- for (auto [index, caseVal] : llvm::enumerate(cases)) {
- int64_t value = caseVal.getValue();
- if (value < 0 || (value != 0 && !llvm::isPowerOf2_64(value)))
- nonKeywordCases.set(index);
- }
- }
-
// Generate the parser and the start of the printer for the enum.
const char *parsedAndPrinterStart = R"(
namespace mlir {
@@ -137,7 +127,7 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
if (nonKeywordCases.test(it.index()))
continue;
StringRef symbol = it.value().getSymbol();
- os << llvm::formatv(" case {0}::{1}:\n", qualName,
+ os << llvm::formatv(" case {0}::{1}:\n", qualName,
llvm::isDigit(symbol.front()) ? ("_" + symbol)
: symbol);
}
@@ -145,6 +135,37 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
" default:\n"
" return p << '\"' << valueStr << '\"';\n"
" }\n";
+
+ // If this is a bit enum, conservatively print the string form if the value
+ // is not a power of two (i.e. not a single bit case) and not a known case.
+ } else if (enumAttr.isBitEnum()) {
+ // Process the known multi-bit cases that use valid keywords.
+ llvm::SmallVector<EnumAttrCase *> validMultiBitCases;
+ for (auto [index, caseVal] : llvm::enumerate(cases)) {
+ uint64_t value = caseVal.getValue();
+ if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index))
+ validMultiBitCases.push_back(&caseVal);
+ }
+ if (!validMultiBitCases.empty()) {
+ os << " switch (value) {\n";
+ for (EnumAttrCase *caseVal : validMultiBitCases) {
+ StringRef symbol = caseVal->getSymbol();
+ os << llvm::formatv(" case {0}::{1}:\n", qualName,
+ llvm::isDigit(symbol.front()) ? ("_" + symbol)
+ : symbol);
+ }
+ os << " return p << valueStr;\n"
+ " default:\n"
+ " break;\n"
+ " }\n";
+ }
+
+ // All other multi-bit cases should be printed as strings.
+ os << formatv(" auto underlyingValue = "
+ "static_cast<std::make_unsigned_t<{0}>>(value);\n",
+ qualName);
+ os << " if (underlyingValue && !llvm::has_single_bit(underlyingValue))\n"
+ " return p << '\"' << valueStr << '\"';\n";
}
os << " return p << valueStr;\n"
"}\n"
More information about the Mlir-commits
mailing list