[Mlir-commits] [mlir] 4e5dee2 - [mlir][ods] Add tablegen field for concise printing of BitEnum attributes

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 25 11:48:41 PDT 2022


Author: jfurtek
Date: 2022-04-25T18:48:35Z
New Revision: 4e5dee2f30dd5b3548a268271e844b54a35947c2

URL: https://github.com/llvm/llvm-project/commit/4e5dee2f30dd5b3548a268271e844b54a35947c2
DIFF: https://github.com/llvm/llvm-project/commit/4e5dee2f30dd5b3548a268271e844b54a35947c2.diff

LOG: [mlir][ods] Add tablegen field for concise printing of BitEnum attributes

This diff introduces a tablegen field for bit enum attributes
(`printBitEnumPrimaryGroups`) to control printing when the enum uses "group"
cases. An example would be a fastmath enum that uses a `fast` group case as an
alias for all of the individual fastmath flags. The new field allows printing
simply `fast` for that value, instead of the more verbose list that includes
`fast` as well as the individual flags (e.g.
`reassoc,nnan,ninf,nsz,arcp,contract,afn,fast`).

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D123871

Added: 
    

Modified: 
    mlir/include/mlir/IR/EnumAttr.td
    mlir/include/mlir/TableGen/Attribute.h
    mlir/lib/TableGen/Attribute.cpp
    mlir/tools/mlir-tblgen/EnumsGen.cpp
    mlir/unittests/TableGen/EnumsGenTest.cpp
    mlir/unittests/TableGen/enums.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index d5bc51a3aab94..929283e4d48b6 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -243,6 +243,15 @@ class BitEnumAttr<I intType, string name, string summary,
 
   // The delimiter used to separate bit enum cases in strings.
   string separator = "|";
+
+  // Print the "primary group" only for bits that are members of case groups
+  // that have all bits present. When the value is 0, printing will display
+  // both individual bit case names AND the names for all groups that the bit is
+  // contained in. When the value is 1, for each bit that is set AND is a member
+  // of a group with all bits set, only the "primary group" (i.e. the first
+  // group with all bits set in reverse declaration order) will be printed (for
+  // conciseness).
+  bit printBitEnumPrimaryGroups = 0;
 }
 
 class I32BitEnumAttr<string name, string summary,

diff  --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index b74c66461652d..2c9732ea88f12 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -206,6 +206,7 @@ class EnumAttr : public Attribute {
   bool genSpecializedAttr() const;
   llvm::Record *getBaseAttrClass() const;
   StringRef getSpecializedAttrClassName() const;
+  bool printBitEnumPrimaryGroups() const;
 };
 
 class StructFieldAttr {

diff  --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index 1d2b8d3ecd142..9254e4e983c57 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -239,6 +239,10 @@ StringRef EnumAttr::getSpecializedAttrClassName() const {
   return def->getValueAsString("specializedAttrClassName");
 }
 
+bool EnumAttr::printBitEnumPrimaryGroups() const {
+  return def->getValueAsBit("printBitEnumPrimaryGroups");
+}
+
 StructFieldAttr::StructFieldAttr(const llvm::Record *record) : def(record) {
   assert(def->isSubClassOf("StructFieldAttr") &&
          "must be subclass of TableGen 'StructFieldAttr' class");

diff  --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 15b17199735ac..1e71cdb16fe03 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -204,12 +204,47 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
                   allBitsUnsetCase->getSymbol());
   }
   os << "  ::llvm::SmallVector<::llvm::StringRef, 2> strs;\n";
-  for (const auto &enumerant : enumerants) {
-    // Skip the special enumerant for None.
-    if (int64_t val = enumerant.getValue())
-      os << formatv(
-          "  if ({0}u == ({0}u & val)) {{ strs.push_back(\"{1}\"); }\n ", val,
-          enumerant.getStr());
+
+  // Add case string if the value has all case bits, and remove them to avoid
+  // printing again. Used only for groups, when printBitEnumPrimaryGroups is 1.
+  const char *const formatCompareRemove = R"(
+  if ({0}u == ({0}u & val)) {{
+    strs.push_back("{1}");
+    val &= ~static_cast<{2}>({0});
+  }
+)";
+  // Add case string if the value has all case bits. Used for individual bit
+  // cases, and for groups when printBitEnumPrimaryGroups is 0.
+  const char *const formatCompare = R"(
+  if ({0}u == ({0}u & val))
+    strs.push_back("{1}");
+)";
+  // Optionally elide bits that are members of groups that will also be printed
+  // for more concise output.
+  if (enumAttr.printBitEnumPrimaryGroups()) {
+    os << "  // Print bit enum groups before individual bits\n";
+    // Emit comparisons for group bit cases in reverse tablegen declaration
+    // order, removing bits for groups with all bits present.
+    for (const auto &enumerant : llvm::reverse(enumerants)) {
+      if ((enumerant.getValue() != 0) &&
+          enumerant.getDef().isSubClassOf("BitEnumAttrCaseGroup")) {
+        os << formatv(formatCompareRemove, enumerant.getValue(),
+                      enumerant.getStr(), enumAttr.getUnderlyingType());
+      }
+    }
+    // Emit comparisons for individual bit cases in tablegen declaration order.
+    for (const auto &enumerant : enumerants) {
+      if ((enumerant.getValue() != 0) &&
+          enumerant.getDef().isSubClassOf("BitEnumAttrCaseBit"))
+        os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
+    }
+  } else {
+    // Emit comparisons for ALL nonzero cases (individual bits and groups) in
+    // tablegen declaration order.
+    for (const auto &enumerant : enumerants) {
+      if (enumerant.getValue() != 0)
+        os << formatv(formatCompare, enumerant.getValue(), enumerant.getStr());
+    }
   }
   os << formatv("  return ::llvm::join(strs, \"{0}\");\n", separator);
 

diff  --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp
index a5819c5a857ca..1b6f23932249f 100644
--- a/mlir/unittests/TableGen/EnumsGenTest.cpp
+++ b/mlir/unittests/TableGen/EnumsGenTest.cpp
@@ -70,6 +70,9 @@ TEST(EnumsGenTest, GeneratedBitEnumDefinition) {
   EXPECT_EQ(0u, static_cast<uint32_t>(BitEnumWithNone::None));
   EXPECT_EQ(1u, static_cast<uint32_t>(BitEnumWithNone::Bit0));
   EXPECT_EQ(8u, static_cast<uint32_t>(BitEnumWithNone::Bit3));
+
+  EXPECT_EQ(2u, static_cast<uint64_t>(BitEnum64_Test::Bit1));
+  EXPECT_EQ(144115188075855872u, static_cast<uint64_t>(BitEnum64_Test::Bit57));
 }
 
 TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) {
@@ -79,8 +82,11 @@ TEST(EnumsGenTest, GeneratedSymbolToStringFnForBitEnum) {
   EXPECT_EQ(
       stringifyBitEnumWithNone(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3),
       "Bit0|Bit3");
-  EXPECT_EQ(2u, static_cast<uint64_t>(BitEnum64_Test::Bit1));
-  EXPECT_EQ(144115188075855872u, static_cast<uint64_t>(BitEnum64_Test::Bit57));
+
+  EXPECT_EQ(stringifyBitEnum64_Test(BitEnum64_Test::Bit1), "Bit1");
+  EXPECT_EQ(
+      stringifyBitEnum64_Test(BitEnum64_Test::Bit1 | BitEnum64_Test::Bit57),
+      "Bit1|Bit57");
 }
 
 TEST(EnumsGenTest, GeneratedStringToSymbolForBitEnum) {
@@ -116,6 +122,26 @@ TEST(EnumsGenTest, GeneratedStringToSymbolForGroupedBitEnum) {
             BitEnumWithGroup::Bit3 | BitEnumWithGroup::Bit0);
 }
 
+TEST(EnumsGenTest, GeneratedSymbolToStringFnForPrimaryGroupBitEnum) {
+  EXPECT_EQ(stringifyBitEnumPrimaryGroup(
+                BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 |
+                BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3),
+            "Bits0To3");
+  EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 |
+                                         BitEnumPrimaryGroup::Bit2 |
+                                         BitEnumPrimaryGroup::Bit3),
+            "Bit0,Bit2,Bit3");
+  EXPECT_EQ(stringifyBitEnumPrimaryGroup(BitEnumPrimaryGroup::Bit0 |
+                                         BitEnumPrimaryGroup::Bit4 |
+                                         BitEnumPrimaryGroup::Bit5),
+            "Bits4And5,Bit0");
+  EXPECT_EQ(stringifyBitEnumPrimaryGroup(
+                BitEnumPrimaryGroup::Bit0 | BitEnumPrimaryGroup::Bit1 |
+                BitEnumPrimaryGroup::Bit2 | BitEnumPrimaryGroup::Bit3 |
+                BitEnumPrimaryGroup::Bit4 | BitEnumPrimaryGroup::Bit5),
+            "Bits0To5");
+}
+
 TEST(EnumsGenTest, GeneratedOperator) {
   EXPECT_TRUE(bitEnumContains(BitEnumWithNone::Bit0 | BitEnumWithNone::Bit3,
                               BitEnumWithNone::Bit0));

diff  --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td
index 2baaeb0a50248..5c48b2c770907 100644
--- a/mlir/unittests/TableGen/enums.td
+++ b/mlir/unittests/TableGen/enums.td
@@ -40,10 +40,21 @@ def BitEnumWithoutNone : I32BitEnumAttr<"BitEnumWithoutNone", "A test enum",
 
 def Bits0To3 : I32BitEnumAttrCaseGroup<"Bits0To3",
                                        [Bit0, Bit1, Bit2, Bit3]>;
+def Bits4And5 : I32BitEnumAttrCaseGroup<"Bits4And5",
+                                       [Bit4, Bit5]>;
+def Bits0To5 : I32BitEnumAttrCaseGroup<"Bits0To5",
+                                       [Bits0To3, Bits4And5]>;
 
 def BitEnumWithGroup : I32BitEnumAttr<"BitEnumWithGroup", "A test enum",
                                       [Bit0, Bit1, Bit2, Bit3, Bit4, Bits0To3]>;
 
+def BitEnumPrimaryGroup : I32BitEnumAttr<"BitEnumPrimaryGroup", "test enum",
+                                        [Bit0, Bit1, Bit2, Bit3, Bit4, Bit5,
+                                         Bits0To3, Bits4And5, Bits0To5]> {
+  let separator = ",";
+  let printBitEnumPrimaryGroups = 1;
+}
+
 def BitEnum64_None : I64BitEnumAttrCaseNone<"None">;
 def BitEnum64_57   : I64BitEnumAttrCaseBit<"Bit57", 57>;
 def BitEnum64_1    : I64BitEnumAttrCaseBit<"Bit1", 1>;


        


More information about the Mlir-commits mailing list