[Mlir-commits] [mlir] 4e5dee2 - [mlir][ods] Add tablegen field for concise printing of BitEnum attributes
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 25 11:48:41 PDT 2022
Author: jfurtek
Date: 2022-04-25T18:48:35Z
New Revision: 4e5dee2f30dd5b3548a268271e844b54a35947c2
URL: https://github.com/llvm/llvm-project/commit/4e5dee2f30dd5b3548a268271e844b54a35947c2
DIFF: https://github.com/llvm/llvm-project/commit/4e5dee2f30dd5b3548a268271e844b54a35947c2.diff
LOG: [mlir][ods] Add tablegen field for concise printing of BitEnum attributes
This diff introduces a tablegen field for bit enum attributes
(`printBitEnumPrimaryGroups`) that controls printing when the enum uses "group"
cases. For example, a fastmath enum might define a `fast` group case as an
alias for the individual fastmath flags. The new field allows such a value to
be printed simply as `fast`, instead of as the more verbose list that includes
`fast` as well as every individual flag (e.g.
`reassoc,nnan,ninf,nsz,arcp,contract,afn,fast`).
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D123871
Added:
Modified:
mlir/include/mlir/IR/EnumAttr.td
mlir/include/mlir/TableGen/Attribute.h
mlir/lib/TableGen/Attribute.cpp
mlir/tools/mlir-tblgen/EnumsGen.cpp
mlir/unittests/TableGen/EnumsGenTest.cpp
mlir/unittests/TableGen/enums.td
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index d5bc51a3aab94..929283e4d48b6 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -243,6 +243,15 @@ class BitEnumAttr<I intType, string name, string summary,
// The delimiter used to separate bit enum cases in strings.
string separator = "|";
+
+ // Print the "primary group" only for bits that are members of case groups
+ // that have all bits present. When the value is 0, printing will display both
+ // individual bit case names AND the names for all groups that the bit is
+ // contained in. When the value is 1, for each bit that is set AND is a member
+ // of a group with all bits set, only the "primary group" (i.e. the first
+ // group with all bits set in reverse declaration order) will be printed (for
+ // conciseness).
+ bit printBitEnumPrimaryGroups = 0;
}
class I32BitEnumAttr<string name, string summary,
diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index b74c66461652d..2c9732ea88f12 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -206,6 +206,7 @@ class EnumAttr : public Attribute {
bool genSpecializedAttr() const;
llvm::Record *getBaseAttrClass() const;
StringRef getSpecializedAttrClassName() const;
+ bool printBitEnumPrimaryGroups() const;
};
class StructFieldAttr {
diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index 1d2b8d3ecd142..9254e4e983c57 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -239,6 +239,10 @@ StringRef EnumAttr::getSpecializedAttrClassName() const {
return def->getValueAsString("specializedAttrClassName");
}
+bool EnumAttr::printBitEnumPrimaryGroups() const {
+ return def->getValueAsBit("printBitEnumPrimaryGroups");
+}
+
StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
assert(def->isSubClassOf("StructFieldAttr") &&
"must be subclass of TableGen 'StructFieldAttr' class");
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 15b17199735ac..1e71cdb16fe03 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -204,12 +204,47 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
allBitsUnsetCase->getSymbol());
}
os << " ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n";
- for (const auto &enumerant : enumerants) {
- // Skip the special enumerant for None.
- if (int64_t val = enumerant.getValue())
- os << formatv(
- " if ({0}u == ({0}u & val)) {{ strs.push_back(\"{1}\"); }\n ", val,
- enumerant.getStr());
+
+ // Add case string if the value has all case bits, and remove them to avoid
+ // printing again. Used only for groups, when printBitEnumPrimaryGroups is 1.
+ const char *const formatCompareRemove = R"(
+ if ({0}u == ({0}u & val)) {{
+ strs.push_back("{1}");
+ val &= ~static_cast<{2}>({0});
+ }
+)";
+ // Add case string if the value has all case bits. Used for individual bit
+ // cases, and for groups when printBitEnumPrimaryGroups is 0.
+ const char *const formatCompare = R"(
+ if ({0}u == ({0}u & val))
+ strs.push_back("{1}");
+)";
+ // Optionally elide bits that are members of groups that will also be printed
+ // for more concise output.
+ if (enumAttr.printBitEnumPrimaryGroups()) {
+ os << " // Print bit enum groups before individual bits\n";
+ // Emit comparisons for group bit cases in reverse tablegen declaration
+ // order, removing bits for groups with all bits present.
+ for (const auto &enumerant : llvm::reverse(enumerants)) {
+ if ((enumerant.getValue() != 0) &&
+ enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup")) {
+ os << formatv(formatCompareRemove, enumerant.getValue(),
+ enumerant.getStr(), enumAttr.getUnderlyingType());
+ }
+ }
+ // Emit comparisons for individual bit cases in tablegen declaration order.
+ for (const auto &enumerant : enumerants) {
+ if ((enumerant.getValue() != 0) &&
+ enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit"))
+ os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
+ }
+ } else {
+ // Emit comparisons for ALL nonzero cases (individual bits and groups) in
+ // tablegen declaration order.
+ for (const auto &enumerant : enumerants) {
+ if (enumerant.getValue() != 0)
+ os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
+ }
}
os << formatv(" return ::llvm::join(strs, \"{0}\");\n", separator);
diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp
index a5819c5a857ca..1b6f23932249f 100644
--- a/mlir/unittests/TableGen/EnumsGenTest.cpp
+++ b/mlir/unittests/TableGen/EnumsGenTest.cpp
@@ -70,6 +70,9 @@ TEST(EnumsGenTest, GeneratedBitEnumDefinition) {
EXPECT_EQ(0u, static_cast<uint32_t>(BitEnumWithNone::None));
EXPECT_EQ(1u, static_cast<uint32_t>(BitEnumWithNone::Bit0));
EXPECT_EQ(8u, static_cast<uint32_t>(BitEnumWithNone::Bit3));
+
+ EXPECT_EQ(2u, static_cast<uint64_t>(BitEnum64_Test::Bit1));
+ EXPECT_EQ(144115188075855872u, static_cast<uint64_t>(BitEnum64_Test::Bit57));
}
TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) {
@@ -79,8 +82,11 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) {
EXPECT_EQ(
stringifyBitEnumWithNone(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3),
"Bit0|Bit3");
- EXPECT_EQ(2u, static_cast<uint64_t>(BitEnum64_Test::Bit1));
- EXPECT_EQ(144115188075855872u, static_cast<uint64_t>(BitEnum64_Test::Bit57));
+
+ EXPECT_EQ(stringifyBitEnum64_Test(BitEnum64_Test::Bit1), "Bit1");
+ EXPECT_EQ(
+ stringifyBitEnum64_Test(BitEnum64_Test::Bit1 | BitEnum64_Test::Bit57),
+ "Bit1|Bit57");
}
TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) {
@@ -116,6 +122,26 @@ TEST(EnumsGenTest, GeneratedStringToSymbolForGroupedBitEnum) {
BitEnumWithGroup::Bit3 | BitEnumWithGroup::Bit0);
}
+TEST(EnumsGenTest, GeneratedSymbolToStringFnForPrimaryGroupBitEnum) {
+ EXPECT_EQ(stringifyBitEnumPrimaryGroup(
+ BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 |
+ BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3),
+ "Bits0To3");
+ EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 |
+ BitEnumPrimaryGroup::Bit2 |
+ BitEnumPrimaryGroup::Bit3),
+ "Bit0,Bit2,Bit3");
+ EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 |
+ BitEnumPrimaryGroup::Bit4 |
+ BitEnumPrimaryGroup::Bit5),
+ "Bits4And5,Bit0");
+ EXPECT_EQ(stringifyBitEnumPrimaryGroup(
+ BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 |
+ BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3 |
+ BitEnumPrimaryGroup::Bit4 | BitEnumPrimaryGroup::Bit5),
+ "Bits0To5");
+}
+
TEST(EnumsGenTest, GeneratedOperator) {
EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
BitEnumWithNone::Bit0));
diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td
index 2baaeb0a50248..5c48b2c770907 100644
--- a/mlir/unittests/TableGen/enums.td
+++ b/mlir/unittests/TableGen/enums.td
@@ -40,10 +40,21 @@ def BitEnumWithoutNone : I32BitEnumAttr<"BitEnumWithoutNone", "A test enum",
def Bits0To3 : I32BitEnumAttrCaseGroup<"Bits0To3",
[Bit0, Bit1, Bit2, Bit3]>;
+def Bits4And5 : I32BitEnumAttrCaseGroup<"Bits4And5",
+ [Bit4, Bit5]>;
+def Bits0To5 : I32BitEnumAttrCaseGroup<"Bits0To5",
+ [Bits0To3, Bits4And5]>;
def BitEnumWithGroup : I32BitEnumAttr<"BitEnumWithGroup", "A test enum",
[Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]>;
+def BitEnumPrimaryGroup : I32BitEnumAttr<"BitEnumPrimaryGroup", "test enum",
+ [Bit0, Bit1, Bit2, Bit3, Bit4, Bit5,
+ Bits0To3, Bits4And5, Bits0To5]> {
+ let separator = ",";
+ let printBitEnumPrimaryGroups = 1;
+}
+
def BitEnum64_None : I64BitEnumAttrCaseNone<"None">;
def BitEnum64_57 : I64BitEnumAttrCaseBit<"Bit57", 57>;
def BitEnum64_1 : I64BitEnumAttrCaseBit<"Bit1", 1>;
More information about the Mlir-commits
mailing list