[Mlir-commits] [mlir] 839b436 - [mlir] Improve BitEnumAttr, update documentation
Krzysztof Drewniak
llvmlistbot at llvm.org
Tue Sep 6 14:36:42 PDT 2022
Author: Krzysztof Drewniak
Date: 2022-09-06T21:36:34Z
New Revision: 839b436c93604e042f74050cf2adadd75f30e898
URL: https://github.com/llvm/llvm-project/commit/839b436c93604e042f74050cf2adadd75f30e898
DIFF: https://github.com/llvm/llvm-project/commit/839b436c93604e042f74050cf2adadd75f30e898.diff
LOG: [mlir] Improve BitEnumAttr, update documentation
- Add new operators to BitEnumAttr, mainly `~` (not), which only inverts
bits that are valid for the attribute, and `^` (xor)
- Add new bit enum utility functions, bitEnumClear(bits, bit) and
bitEnumSet(bits, bit, value=true), since they have come up in code I have
been writing that uses such enums
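- Change bitEnumContains(bits, bit) to return true only when every bit of
`bit` is set in `bits` (previously it returned true if any bit overlapped);
as a consequence, bitEnumContains(bits, None) is now always true
- Use the attribute's validBits mask when rejecting unknown bits in the
generated symbolize function that takes the underlying integer value
- Add rudimentary tests for the enum generator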
- Update the OpDefinitions documentation so that it contains a correct
example and accounts for the changes mentioned above.
Reviewed By: antiagainst
Differential Revision: https://reviews.llvm.org/D133374
Added:
mlir/test/mlir-tblgen/enums-gen.td
Modified:
mlir/docs/OpDefinitions.md
mlir/tools/mlir-tblgen/EnumsGen.cpp
Removed:
################################################################################
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 4415f067f4295..7dd583426f7ab 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1403,11 +1403,11 @@ llvm::Optional<MyIntEnum> symbolizeMyIntEnum(uint32_t value) {
Similarly for the following `BitEnumAttr` definition:
```tablegen
-def None: BitEnumAttrCaseNone<"None">;
-def Bit0: BitEnumAttrCaseBit<"Bit0", 0>;
-def Bit1: BitEnumAttrCaseBit<"Bit1", 1>;
-def Bit2: BitEnumAttrCaseBit<"Bit2", 2>;
-def Bit3: BitEnumAttrCaseBit<"Bit3", 3>;
+def None: I32BitEnumAttrCaseNone<"None">;
+def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">;
+def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>;
+def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>;
+def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>;
def MyBitEnum: BitEnumAttr<"MyBitEnum", "An example bit enum",
[None, Bit0, Bit1, Bit2, Bit3]>;
@@ -1428,14 +1428,37 @@ enum class MyBitEnum : uint32_t {
llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t);
std::string stringifyMyBitEnum(MyBitEnum);
llvm::Optional<MyBitEnum> symbolizeMyBitEnum(llvm::StringRef);
-inline MyBitEnum operator|(MyBitEnum lhs, MyBitEnum rhs) {
- return static_cast<MyBitEnum>(static_cast<uint32_t>(lhs) | static_cast<uint32_t>(rhs));
+
+inline constexpr MyBitEnum operator|(MyBitEnum a, MyBitEnum b) {
+ return static_cast<MyBitEnum>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
+}
+inline constexpr MyBitEnum operator&(MyBitEnum a, MyBitEnum b) {
+ return static_cast<MyBitEnum>(static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
+}
+inline constexpr MyBitEnum operator^(MyBitEnum a, MyBitEnum b) {
+ return static_cast<MyBitEnum>(static_cast<uint32_t>(a) ^ static_cast<uint32_t>(b));
+}
+inline constexpr MyBitEnum operator~(MyBitEnum bits) {
+ // Ensure only bits that can be present in the enum are set
+ return static_cast<MyBitEnum>(~static_cast<uint32_t>(bits) & static_cast<uint32_t>(15u));
+}
+inline constexpr bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) {
+ return (bits & bit) == bit;
}
-inline MyBitEnum operator&(MyBitEnum lhs, MyBitEnum rhs) {
- return static_cast<MyBitEnum>(static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs));
+inline constexpr MyBitEnum bitEnumClear(MyBitEnum bits, MyBitEnum bit) {
+ return bits & ~bit;
}
-inline bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) {
- return (static_cast<uint32_t>(bits) & static_cast<uint32_t>(bit)) != 0;
+
+inline std::string stringifyEnum(MyBitEnum enumValue) {
+ return stringifyMyBitEnum(enumValue);
+}
+
+template <typename EnumType>
+::llvm::Optional<EnumType> symbolizeEnum(::llvm::StringRef);
+
+template <>
+inline ::llvm::Optional<MyBitEnum> symbolizeEnum<MyBitEnum>(::llvm::StringRef str) {
+ return symbolizeMyBitEnum(str);
}
namespace llvm {
@@ -1467,7 +1490,7 @@ std::string stringifyMyBitEnum(MyBitEnum symbol) {
// Special case for all bits unset.
if (val == 0) return "None";
llvm::SmallVector<llvm::StringRef, 2> strs;
- if (1u == (1u & val)) { strs.push_back("Bit0"); }
+ if (1u == (1u & val)) { strs.push_back("tagged"); }
if (2u == (2u & val)) { strs.push_back("Bit1"); }
if (4u == (4u & val)) { strs.push_back("Bit2"); }
if (8u == (8u & val)) { strs.push_back("Bit3"); }
@@ -1485,7 +1508,7 @@ llvm::Optional<MyBitEnum> symbolizeMyBitEnum(llvm::StringRef str) {
uint32_t val = 0;
for (auto symbol : symbols) {
auto bit = llvm::StringSwitch<llvm::Optional<uint32_t>>(symbol)
- .Case("Bit0", 1)
+ .Case("tagged", 1)
.Case("Bit1", 2)
.Case("Bit2", 4)
.Case("Bit3", 8)
@@ -1499,7 +1522,7 @@ llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t value) {
// Special case for all bits unset.
if (value == 0) return MyBitEnum::None;
- if (value & ~(1u | 2u | 4u | 8u)) return llvm::None;
+ if (value & ~static_cast<uint32_t>(15u)) return llvm::None;
return static_cast<MyBitEnum>(value);
}
```
diff --git a/mlir/test/mlir-tblgen/enums-gen.td b/mlir/test/mlir-tblgen/enums-gen.td
new file mode 100644
index 0000000000000..ebe126467be60
--- /dev/null
+++ b/mlir/test/mlir-tblgen/enums-gen.td
@@ -0,0 +1,53 @@
+// RUN: mlir-tblgen -gen-enum-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+// RUN: mlir-tblgen -gen-enum-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
+
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/OpBase.td"
+
+// Test bit enums
+def None: I32BitEnumAttrCaseNone<"None">;
+def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">;
+def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>;
+def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>;
+def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>;
+
+def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
+ [None, Bit0, Bit1, Bit2, Bit3]> {
+ let genSpecializedAttr = 0;
+}
+
+// DECL-LABEL: enum class MyBitEnum : uint32_t
+// DECL: None = 0,
+// DECL: Bit0 = 1,
+// DECL: Bit1 = 2,
+// DECL: Bit2 = 4,
+// DECL: Bit3 = 8,
+// DECL: }
+
+// DECL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t);
+// DECL: std::string stringifyMyBitEnum(MyBitEnum);
+// DECL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(::llvm::StringRef);
+
+// The complement must be masked to the valid bits (0b1111) of the enum.
+// DECL: inline constexpr MyBitEnum operator~(MyBitEnum bits)
+// DECL: static_cast<uint32_t>(15u)
+// DECL: inline constexpr bool bitEnumContains(MyBitEnum bits, MyBitEnum bit)
+// DECL: inline constexpr MyBitEnum bitEnumClear(MyBitEnum bits, MyBitEnum bit)
+// DECL: inline constexpr MyBitEnum bitEnumSet(MyBitEnum bits, MyBitEnum bit
+
+// DEF-LABEL: std::string stringifyMyBitEnum
+// DEF: auto val = static_cast<uint32_t>
+// DEF: if (val == 0) return "None";
+// DEF: if (1u == (1u & val))
+// DEF-NEXT: push_back("tagged")
+// DEF: if (2u == (2u & val))
+// DEF-NEXT: push_back("Bit1")
+
+// DEF-LABEL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(::llvm::StringRef str)
+// DEF: if (str == "None") return MyBitEnum::None;
+// DEF: .Case("tagged", 1)
+// DEF: .Case("Bit1", 2)
+
+// Unknown bits are rejected using the same valid-bits mask.
+// DEF: symbolizeMyBitEnum(uint32_t value)
+// DEF: if (value & ~static_cast<uint32_t>(15u)) return llvm::None;
diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 44a8035c0c021..19dcd31932dab 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -136,28 +136,42 @@ getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
//
// inline constexpr <enum-type> operator|(<enum-type> a, <enum-type> b);
// inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b);
-// inline constexpr bool bitEnumContains(<enum-type> a, <enum-type> b);
+// inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b);
+// inline constexpr <enum-type> operator~(<enum-type> bits);
+// inline constexpr bool bitEnumContains(<enum-type> bits, <enum-type> bit);
+// inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit);
+// inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
+// bool value=true);
static void emitOperators(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
std::string underlyingType = std::string(enumAttr.getUnderlyingType());
- os << formatv("inline constexpr {0} operator|({0} lhs, {0} rhs) {{\n",
- enumName)
- << formatv(" return static_cast<{0}>("
- "static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n",
- enumName, underlyingType)
- << "}\n";
- os << formatv("inline constexpr {0} operator&({0} lhs, {0} rhs) {{\n",
- enumName)
- << formatv(" return static_cast<{0}>("
- "static_cast<{1}>(lhs) & static_cast<{1}>(rhs));\n",
- enumName, underlyingType)
- << "}\n";
- os << formatv(
- "inline constexpr bool bitEnumContains({0} bits, {0} bit) {{\n"
- " return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n",
- enumName, underlyingType)
- << "}\n";
+ uint64_t validBits = enumDef.getValueAsInt("validBits"); // Unsigned: bit 63 must not print as "-Nu".
+ const char *const operators = R"(
+inline constexpr {0} operator|({0} a, {0} b) {{
+ return static_cast<{0}>(static_cast<{1}>(a) | static_cast<{1}>(b));
+}
+inline constexpr {0} operator&({0} a, {0} b) {{
+ return static_cast<{0}>(static_cast<{1}>(a) & static_cast<{1}>(b));
+}
+inline constexpr {0} operator^({0} a, {0} b) {{
+ return static_cast<{0}>(static_cast<{1}>(a) ^ static_cast<{1}>(b));
+}
+inline constexpr {0} operator~({0} bits) {{
+ // Ensure only bits that can be present in the enum are set
+ return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));
+}
+inline constexpr bool bitEnumContains({0} bits, {0} bit) {{
+ return (bits & bit) == bit;
+}
+inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{
+ return bits & ~bit;
+}
+inline constexpr {0} bitEnumSet({0} bits, {0} bit, /*optional*/bool value=true) {{
+ return value ? (bits | bit) : bitEnumClear(bits, bit);
+}
+ )";
+ os << formatv(operators, enumName, underlyingType, validBits);
}
static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
@@ -424,13 +438,9 @@ static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
os << formatv(" if (value == 0) return {0}::{1};\n\n", enumName,
makeIdentifier(allBitsUnsetCase->getSymbol()));
}
- llvm::SmallVector<std::string, 8> values;
- for (const auto &enumerant : enumerants) {
- if (auto val = enumerant.getValue())
- values.push_back(std::string(formatv("{0}u", val)));
- }
- os << formatv(" if (value & ~static_cast<{0}>({1})) return llvm::None;\n",
- underlyingType, llvm::join(values, " | "));
+ uint64_t validBits = enumDef.getValueAsInt("validBits"); // Unsigned: bit 63 must not print as "-Nu".
+ os << formatv(" if (value & ~static_cast<{0}>({1}u)) return llvm::None;\n",
+ underlyingType, validBits);
os << formatv(" return static_cast<{0}>(value);\n", enumName);
os << "}\n";
}
More information about the Mlir-commits
mailing list