[Mlir-commits] [mlir] 839b436 - [mlir] Improve BitEnumAttr, update documentation

Krzysztof Drewniak llvmlistbot at llvm.org
Tue Sep 6 14:36:42 PDT 2022


Author: Krzysztof Drewniak
Date: 2022-09-06T21:36:34Z
New Revision: 839b436c93604e042f74050cf2adadd75f30e898

URL: https://github.com/llvm/llvm-project/commit/839b436c93604e042f74050cf2adadd75f30e898
DIFF: https://github.com/llvm/llvm-project/commit/839b436c93604e042f74050cf2adadd75f30e898.diff

LOG: [mlir] Improve BitEnumAttr, update documentation

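- Add new operators to BitEnumAttr: `~` (not), which only inverts the
bits that are valid for the attribute, and `^` (xor)
- Change bitEnumContains(bits, bit) so it returns true only when every bit
of `bit` is set in `bits` (previously any overlapping bit was enough)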
- Add new bit enum utility functions, bitEnumClear(bits, bit) and
bitEnumSet(bits, bit, value=true), since they have come up in code I've
been writing that makes use of such enums
- Add rudimentary tests for the enum generator
- Update the OpDefinitions documentation so that it contains a correct
example and accounts for the changes mentioned above.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D133374

Added: 
    mlir/test/mlir-tblgen/enums-gen.td

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/tools/mlir-tblgen/EnumsGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 4415f067f4295..7dd583426f7ab 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1403,11 +1403,11 @@ llvm::Optional<MyIntEnum> symbolizeMyIntEnum(uint32_t value) {
 Similarly for the following `BitEnumAttr` definition:
 
 ```tablegen
-def None: BitEnumAttrCaseNone<"None">;
-def Bit0: BitEnumAttrCaseBit<"Bit0", 0>;
-def Bit1: BitEnumAttrCaseBit<"Bit1", 1>;
-def Bit2: BitEnumAttrCaseBit<"Bit2", 2>;
-def Bit3: BitEnumAttrCaseBit<"Bit3", 3>;
+def None: I32BitEnumAttrCaseNone<"None">;
+def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">;
+def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>;
+def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>;
+def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>;
 
 def MyBitEnum: BitEnumAttr<"MyBitEnum", "An example bit enum",
                            [None, Bit0, Bit1, Bit2, Bit3]>;
@@ -1428,14 +1428,37 @@ enum class MyBitEnum : uint32_t {
 llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t);
 std::string stringifyMyBitEnum(MyBitEnum);
 llvm::Optional<MyBitEnum> symbolizeMyBitEnum(llvm::StringRef);
-inline MyBitEnum operator|(MyBitEnum lhs, MyBitEnum rhs) {
-  return static_cast<MyBitEnum>(static_cast<uint32_t>(lhs) | static_cast<uint32_t>(rhs));
+
+inline constexpr MyBitEnum operator|(MyBitEnum a, MyBitEnum b) {
+  return static_cast<MyBitEnum>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
+}
+inline constexpr MyBitEnum operator&(MyBitEnum a, MyBitEnum b) {
+  return static_cast<MyBitEnum>(static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
+}
+inline constexpr MyBitEnum operator^(MyBitEnum a, MyBitEnum b) {
+  return static_cast<MyBitEnum>(static_cast<uint32_t>(a) ^ static_cast<uint32_t>(b));
+}
+inline constexpr MyBitEnum operator~(MyBitEnum bits) {
+  // Ensure only bits that can be present in the enum are set
+  return static_cast<MyBitEnum>(~static_cast<uint32_t>(bits) & static_cast<uint32_t>(15u));
+}
+inline constexpr bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) {
+  return (bits & bit) == bit;
 }
-inline MyBitEnum operator&(MyBitEnum lhs, MyBitEnum rhs) {
-  return static_cast<MyBitEnum>(static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs));
+inline constexpr MyBitEnum bitEnumClear(MyBitEnum bits, MyBitEnum bit) {
+  return bits & ~bit;
 }
-inline bool bitEnumContains(MyBitEnum bits, MyBitEnum bit) {
-  return (static_cast<uint32_t>(bits) & static_cast<uint32_t>(bit)) != 0;
+
+inline std::string stringifyEnum(MyBitEnum enumValue) {
+  return stringifyMyBitEnum(enumValue);
+}
+
+template <typename EnumType>
+::llvm::Optional<EnumType> symbolizeEnum(::llvm::StringRef);
+
+template <>
+inline ::llvm::Optional<MyBitEnum> symbolizeEnum<MyBitEnum>(::llvm::StringRef str) {
+  return symbolizeMyBitEnum(str);
 }
 
 namespace llvm {
@@ -1467,7 +1490,7 @@ std::string stringifyMyBitEnum(MyBitEnum symbol) {
   // Special case for all bits unset.
   if (val == 0) return "None";
   llvm::SmallVector<llvm::StringRef, 2> strs;
-  if (1u == (1u & val)) { strs.push_back("Bit0"); }
+  if (1u == (1u & val)) { strs.push_back("tagged"); }
   if (2u == (2u & val)) { strs.push_back("Bit1"); }
   if (4u == (4u & val)) { strs.push_back("Bit2"); }
   if (8u == (8u & val)) { strs.push_back("Bit3"); }
@@ -1485,7 +1508,7 @@ llvm::Optional<MyBitEnum> symbolizeMyBitEnum(llvm::StringRef str) {
   uint32_t val = 0;
   for (auto symbol : symbols) {
     auto bit = llvm::StringSwitch<llvm::Optional<uint32_t>>(symbol)
-      .Case("Bit0", 1)
+      .Case("tagged", 1)
       .Case("Bit1", 2)
       .Case("Bit2", 4)
       .Case("Bit3", 8)
@@ -1499,7 +1522,7 @@ llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t value) {
   // Special case for all bits unset.
   if (value == 0) return MyBitEnum::None;
 
-  if (value & ~(1u | 2u | 4u | 8u)) return llvm::None;
+  if (value & ~static_cast<uint32_t>(15u)) return llvm::None;
   return static_cast<MyBitEnum>(value);
 }
 ```

diff  --git a/mlir/test/mlir-tblgen/enums-gen.td b/mlir/test/mlir-tblgen/enums-gen.td
new file mode 100644
index 0000000000000..ebe126467be60
--- /dev/null
+++ b/mlir/test/mlir-tblgen/enums-gen.td
@@ -0,0 +1,42 @@
+// RUN: mlir-tblgen -gen-enum-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+// RUN: mlir-tblgen -gen-enum-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
+
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/OpBase.td"
+
+// Test bit enums
+def None: I32BitEnumAttrCaseNone<"None">;
+def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">;
+def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>;
+def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>;
+def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>;
+
+def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
+                           [None, Bit0, Bit1, Bit2, Bit3]> {
+  let genSpecializedAttr = 0;
+}
+
+// DECL-LABEL: enum class MyBitEnum : uint32_t
+// DECL: None = 0,
+// DECL: Bit0 = 1,
+// DECL: Bit1 = 2,
+// DECL: Bit2 = 4,
+// DECL: Bit3 = 8,
+// DECL: }
+
+// DECL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(uint32_t);
+// DECL: std::string stringifyMyBitEnum(MyBitEnum);
+// DECL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(::llvm::StringRef);
+
+// DEF-LABEL: std::string stringifyMyBitEnum
+// DEF: auto val = static_cast<uint32_t>
+// DEF: if (val == 0) return "None";
+// DEF: if (1u == (1u & val))
+// DEF-NEXT: push_back("tagged")
+// DEF: if (2u == (2u & val))
+// DEF-NEXT: push_back("Bit1")
+
+// DEF-LABEL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(::llvm::StringRef str)
+// DEF: if (str == "None") return MyBitEnum::None;
+// DEF: .Case("tagged", 1)
+// DEF: .Case("Bit1", 2)

diff  --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 44a8035c0c021..19dcd31932dab 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -136,28 +136,42 @@ getAllBitsUnsetCase(llvm::ArrayRef<EnumAttrCase> cases) {
 //
 // inline constexpr <enum-type> operator|(<enum-type> a, <enum-type> b);
 // inline constexpr <enum-type> operator&(<enum-type> a, <enum-type> b);
-// inline constexpr bool bitEnumContains(<enum-type> a, <enum-type> b);
+// inline constexpr <enum-type> operator^(<enum-type> a, <enum-type> b);
+// inline constexpr <enum-type> operator~(<enum-type> bits);
+// inline constexpr bool bitEnumContains(<enum-type> bits, <enum-type> bit);
+// inline constexpr <enum-type> bitEnumClear(<enum-type> bits, <enum-type> bit);
+// inline constexpr <enum-type> bitEnumSet(<enum-type> bits, <enum-type> bit,
+// bool value=true);
 static void emitOperators(const Record &enumDef, raw_ostream &os) {
   EnumAttr enumAttr(enumDef);
   StringRef enumName = enumAttr.getEnumClassName();
   std::string underlyingType = std::string(enumAttr.getUnderlyingType());
-  os << formatv("inline constexpr {0} operator|({0} lhs, {0} rhs) {{\n",
-                enumName)
-     << formatv("  return static_cast<{0}>("
-                "static_cast<{1}>(lhs) | static_cast<{1}>(rhs));\n",
-                enumName, underlyingType)
-     << "}\n";
-  os << formatv("inline constexpr {0} operator&({0} lhs, {0} rhs) {{\n",
-                enumName)
-     << formatv("  return static_cast<{0}>("
-                "static_cast<{1}>(lhs) & static_cast<{1}>(rhs));\n",
-                enumName, underlyingType)
-     << "}\n";
-  os << formatv(
-            "inline constexpr bool bitEnumContains({0} bits, {0} bit) {{\n"
-            "  return (static_cast<{1}>(bits) & static_cast<{1}>(bit)) != 0;\n",
-            enumName, underlyingType)
-     << "}\n";
+  int64_t validBits = enumDef.getValueAsInt("validBits");
+  const char *const operators = R"(
+inline constexpr {0} operator|({0} a, {0} b) {{
+  return static_cast<{0}>(static_cast<{1}>(a) | static_cast<{1}>(b));
+}
+inline constexpr {0} operator&({0} a, {0} b) {{
+  return static_cast<{0}>(static_cast<{1}>(a) & static_cast<{1}>(b));
+}
+inline constexpr {0} operator^({0} a, {0} b) {{
+  return static_cast<{0}>(static_cast<{1}>(a) ^ static_cast<{1}>(b));
+}
+inline constexpr {0} operator~({0} bits) {{
+  // Ensure only bits that can be present in the enum are set
+  return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));
+}
+inline constexpr bool bitEnumContains({0} bits, {0} bit) {{
+  return (bits & bit) == bit;
+}
+inline constexpr {0} bitEnumClear({0} bits, {0} bit) {{
+  return bits & ~bit;
+}
+inline constexpr {0} bitEnumSet({0} bits, {0} bit, /*optional*/bool value=true) {{
+  return value ? (bits | bit) : bitEnumClear(bits, bit);
+}
+  )";
+  os << formatv(operators, enumName, underlyingType, validBits);
 }
 
 static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
@@ -424,13 +438,9 @@ static void emitUnderlyingToSymFnForBitEnum(const Record &enumDef,
     os << formatv("  if (value == 0) return {0}::{1};\n\n", enumName,
                   makeIdentifier(allBitsUnsetCase->getSymbol()));
   }
-  llvm::SmallVector<std::string, 8> values;
-  for (const auto &enumerant : enumerants) {
-    if (auto val = enumerant.getValue())
-      values.push_back(std::string(formatv("{0}u", val)));
-  }
-  os << formatv("  if (value & ~static_cast<{0}>({1})) return llvm::None;\n",
-                underlyingType, llvm::join(values, " | "));
+  int64_t validBits = enumDef.getValueAsInt("validBits");
+  os << formatv("  if (value & ~static_cast<{0}>({1}u)) return llvm::None;\n",
+                underlyingType, validBits);
   os << formatv("  return static_cast<{0}>(value);\n", enumName);
   os << "}\n";
 }


        


More information about the Mlir-commits mailing list