[Mlir-commits] [mlir] a266a21 - [mlir][ods] Extend the EnumAttr tablegen class to support BitEnum attributes

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 25 12:00:06 PDT 2022


Author: Jeremy Furtek
Date: 2022-04-25T19:00:00Z
New Revision: a266a21000123b66718df2a09fd7ceeea67843c5

URL: https://github.com/llvm/llvm-project/commit/a266a21000123b66718df2a09fd7ceeea67843c5
DIFF: https://github.com/llvm/llvm-project/commit/a266a21000123b66718df2a09fd7ceeea67843c5.diff

LOG: [mlir][ods] Extend the EnumAttr tablegen class to support BitEnum attributes

This diff allows the EnumAttr class to be used for bit enum attributes (in
addition to previously supported integer enum attributes). While integer
and bit enum attributes share many common implementation aspects, parsing
bit enum values requires a separate implementation. This is accomplished
by declaring uninitialized parser and printer strings (`parameterParser` and
`parameterPrinter`) in the EnumAttrInfo record, and having the derived classes
for integer and bit enums override them with an appropriate parser/printer string.

To support existing bit enums that may use a vertical bar separator, the
lexer and parser are extended to recognize the `|` token (via the new
`parseVerticalBar` and `parseOptionalVerticalBar` AsmParser methods).

Tests were added for bit enums alongside integer enums.

Future diffs for fastmath attributes in the arithmetic dialect will use these
changes.

(resubmission of an earlier abandoned diff, updated to reflect subsequent changes
in the repository)

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D123880

Added: 
    

Modified: 
    mlir/include/mlir/IR/EnumAttr.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/Parser/AsmParserImpl.h
    mlir/lib/Parser/Lexer.cpp
    mlir/lib/Parser/TokenKinds.def
    mlir/test/IR/attribute.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
index 929283e4d48b6..66a557ce41ab5 100644
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -53,7 +53,7 @@ class BitEnumAttrCaseBase<I intType, string sym, int val, string str = sym> :
 
 // A bit enum case stored with a 32-bit IntegerAttr. `val` here is *not* the
 // ordinal number of a bit that is set. It is a 32-bit integer value with bits
-// set to match the case. 
+// set to match the case.
 class I32BitEnumAttrCase<string sym, int val, string str = sym>
     : BitEnumAttrCaseBase<I32, sym, val, str>;
 
@@ -182,6 +182,14 @@ class EnumAttrInfo<
     cppNamespace # "::" # specializedAttrClassName # "::get($_builder.getContext(), $0)",
     baseAttrClass.constBuilderCall);
   let valueType = baseAttrClass.valueType;
+
+  // C++ type wrapped by attribute
+  string cppType = cppNamespace # "::" # className;
+
+  // Parser and printer code used by the EnumParameter class, to be provided by
+  // derived classes
+  string parameterParser = ?;
+  string parameterPrinter = ?;
 }
 
 // An enum attribute backed by IntegerAttr.
@@ -202,7 +210,25 @@ class IntEnumAttr<I intType, string name, string summary,
     IntEnumAttrBase<intType, cases,
       !if(!empty(summary), "allowed " # intType.summary # " cases: " #
           !interleave(!foreach(case, cases, case.value), ", "),
-          summary)>>;
+          summary)>> {
+  // Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the
+  // symbol is not valid.
+  let parameterParser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> {
+    auto loc = $_parser.getCurrentLocation();
+    ::llvm::StringRef enumKeyword;
+    if (::mlir::failed($_parser.parseKeyword(&enumKeyword)))
+      return ::mlir::failure();
+    auto maybeEnum = }] # cppNamespace # "::" #
+                          stringToSymbolFnName # [{(enumKeyword);
+    if (maybeEnum)
+      return *maybeEnum;
+    return {(::mlir::LogicalResult)$_parser.emitError(loc, "expected }] #
+    cppType # [{ to be one of: }] #
+    !interleave(!foreach(enum, enumerants, enum.str), ", ") # [{")};
+  }()}];
+  // Print the enum by calling `symbolToString`.
+  let parameterPrinter = "$_printer << " # symbolToStringFnName # "($_self)";
+}
 
 class I32EnumAttr<string name, string summary, list<I32EnumAttrCase> cases> :
     IntEnumAttr<I32, name, summary, cases> {
@@ -244,6 +270,35 @@ class BitEnumAttr<I intType, string name, string summary,
   // The delimiter used to separate bit enum cases in strings.
   string separator = "|";
 
+  // Parsing function that corresponds to the enum separator. Only
+  // "," and "|" are supported by this definition.
+  string parseSeparatorFn = !if(!eq(separator,"|"),"parseOptionalVerticalBar",
+                                                   "parseOptionalComma");
+
+  // Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the
+  // symbol is not valid.
+  let parameterParser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> {
+    }] # cppType # [{ flags = {};
+    auto loc = $_parser.getCurrentLocation();
+    ::llvm::StringRef enumKeyword;
+    do {
+      if (::mlir::failed($_parser.parseKeyword(&enumKeyword)))
+        return ::mlir::failure();
+      auto maybeEnum = }] # cppNamespace # "::" #
+                            stringToSymbolFnName # [{(enumKeyword);
+      if (!maybeEnum) {
+          return {(::mlir::LogicalResult)$_parser.emitError(loc, "expected }] #
+                  cppType # [{ to be one of: }] #
+                 !interleave(!foreach(enum, enumerants, enum.str),
+                             ", ") # [{")};
+      }
+      flags = flags | *maybeEnum;
+    } while(::mlir::succeeded($_parser.}] # parseSeparatorFn # [{()));
+    return flags;
+  }()}];
+  // Print the enum by calling `symbolToString`.
+  let parameterPrinter = "$_printer << " # symbolToStringFnName # "($_self)";
+
   // Print the "primary group" only for bits that are members of case groups
   // that have all bits present. When the value is 0, printing will display both
   // both individual bit case names AND the names for all groups that the bit is
@@ -272,23 +327,8 @@ class I64BitEnumAttr<string name, string summary,
 class EnumParameter<EnumAttrInfo enumInfo>
     : AttrParameter<enumInfo.cppNamespace # "::" # enumInfo.className,
                     "an enum of type " # enumInfo.className> {
-  // Parse a keyword and pass it to `stringToSymbol`. Emit an error if a the
-  // symbol is not valid.
-  let parser = [{[&]() -> ::mlir::FailureOr<}] # cppType # [{> {
-    auto loc = $_parser.getCurrentLocation();
-    ::llvm::StringRef enumKeyword;
-    if (::mlir::failed($_parser.parseKeyword(&enumKeyword)))
-      return ::mlir::failure();
-    auto maybeEnum = }] # enumInfo.cppNamespace # "::" #
-                          enumInfo.stringToSymbolFnName # [{(enumKeyword);
-    if (maybeEnum)
-      return *maybeEnum;
-    return {(::mlir::LogicalResult)$_parser.emitError(loc, "expected }] #
-    cppType # [{ to be one of: }] #
-    !interleave(!foreach(enum, enumInfo.enumerants, enum.str), ", ") # [{")};
-  }()}];
-  // Print the enum by calling `symbolToString`.
-  let printer = "$_printer << " # enumInfo.symbolToStringFnName # "($_self)";
+  let parser = enumInfo.parameterParser;
+  let printer = enumInfo.parameterPrinter;
 }
 
 // An attribute backed by a C++ enum. The attribute contains a single

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 2d0a5a8aed0d6..8bdab5739f30b 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -464,6 +464,12 @@ class AsmParser {
   /// Parse a '*' token if present.
   virtual ParseResult parseOptionalStar() = 0;
 
+  /// Parse a '|' token.
+  virtual ParseResult parseVerticalBar() = 0;
+
+  /// Parse a '|' token if present.
+  virtual ParseResult parseOptionalVerticalBar() = 0;
+
   /// Parse a quoted string token.
   ParseResult parseString(std::string *string) {
     auto loc = getCurrentLocation();

diff  --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h
index 81ed270395a08..7e2acd01634f4 100644
--- a/mlir/lib/Parser/AsmParserImpl.h
+++ b/mlir/lib/Parser/AsmParserImpl.h
@@ -221,6 +221,16 @@ class AsmParserImpl : public BaseT {
     return success(parser.consumeIf(Token::plus));
   }
 
+  /// Parse a '|' token.
+  virtual ParseResult parseVerticalBar() override {
+    return parser.parseToken(Token::vertical_bar, "expected '|'");
+  }
+
+  /// Parse a '|' token if present.
+  virtual ParseResult parseOptionalVerticalBar() override {
+    return success(parser.consumeIf(Token::vertical_bar));
+  }
+
   /// Parses a quoted string token if present.
   ParseResult parseOptionalString(std::string *string) override {
     if (!parser.getToken().is(Token::string))

diff  --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp
index 8d48a60948a85..b9c2a14ebf730 100644
--- a/mlir/lib/Parser/Lexer.cpp
+++ b/mlir/lib/Parser/Lexer.cpp
@@ -127,6 +127,9 @@ Token Lexer::lexToken() {
     case '?':
       return formToken(Token::question, tokStart);
 
+    case '|':
+      return formToken(Token::vertical_bar, tokStart);
+
     case '/':
       if (*curPtr == '/') {
         skipComment();

diff  --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def
index d6c403553f1f6..033f002d8a05f 100644
--- a/mlir/lib/Parser/TokenKinds.def
+++ b/mlir/lib/Parser/TokenKinds.def
@@ -70,6 +70,7 @@ TOK_PUNCTUATION(r_brace, "}")
 TOK_PUNCTUATION(r_paren, ")")
 TOK_PUNCTUATION(r_square, "]")
 TOK_PUNCTUATION(star, "*")
+TOK_PUNCTUATION(vertical_bar, "|")
 
 // Keywords.  These turn "foo" into Token::kw_foo enums.
 

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 76ddf6129ae35..555b1f4744c7f 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -407,6 +407,42 @@ func.func @disallowed_case7_fail() {
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// Test BitEnumAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @allowed_cases_pass
+func @allowed_cases_pass() {
+  // CHECK: test.op_with_bit_enum <read,write>
+  "test.op_with_bit_enum"() {value = #test.bit_enum<read, write>} : () -> ()
+  // CHECK: test.op_with_bit_enum <read,execute>
+  test.op_with_bit_enum <read,execute>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @allowed_cases_pass
+func @allowed_cases_pass() {
+  // CHECK: test.op_with_bit_enum_vbar <user|group>
+  "test.op_with_bit_enum_vbar"() {
+    value = #test.bit_enum_vbar<user|group>
+  } : () -> ()
+  // CHECK: test.op_with_bit_enum_vbar <user|group|other>
+  test.op_with_bit_enum_vbar <user | group | other>
+  return
+}
+
+// -----
+
+func @disallowed_case_sticky_fail() {
+  // expected-error@+2 {{expected test::TestBitEnum to be one of: read, write, execute}}
+  // expected-error@+1 {{failed to parse TestBitEnumAttr}}
+  "test.op_with_bit_enum"() {value = #test.bit_enum<sticky>} : () -> ()
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Test FloatElementsAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 68668596d7cba..eaa56499c0011 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -22,6 +22,7 @@
 #include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 
 // Include this before the using namespace lines below to

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bf41f5f34f68d..8fd6fd6035a6d 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -311,6 +311,57 @@ def : Pat<(OpWithEnum ConstantAttr<TestEnumAttr,
                                    "::test::TestEnum::Second">,
                       ConstantAttr<I32Attr, "1">)>;
 
+//===----------------------------------------------------------------------===//
+// Test Bit Enum Attributes
+//===----------------------------------------------------------------------===//
+
+// Define the C++ enum.
+def TestBitEnum
+    : I32BitEnumAttr<"TestBitEnum", "a test bit enum", [
+        I32BitEnumAttrCaseBit<"Read", 0, "read">,
+        I32BitEnumAttrCaseBit<"Write", 1, "write">,
+        I32BitEnumAttrCaseBit<"Execute", 2, "execute">,
+      ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "test";
+  let separator = ",";
+}
+
+// Define the enum attribute.
+def TestBitEnumAttr : EnumAttr<Test_Dialect, TestBitEnum, "bit_enum"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+// Define an op that contains the enum attribute.
+def OpWithBitEnum : TEST_Op<"op_with_bit_enum"> {
+  let arguments = (ins TestBitEnumAttr:$value, OptionalAttr<AnyAttr>:$tag);
+  let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
+}
+
+// Define an enum with a different separator
+def TestBitEnumVerticalBar
+    : I32BitEnumAttr<"TestBitEnumVerticalBar", "another test bit enum", [
+        I32BitEnumAttrCaseBit<"User", 0, "user">,
+        I32BitEnumAttrCaseBit<"Group", 1, "group">,
+        I32BitEnumAttrCaseBit<"Other", 2, "other">,
+      ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "test";
+  let separator = "|";
+}
+
+def TestBitEnumVerticalBarAttr
+    : EnumAttr<Test_Dialect, TestBitEnumVerticalBar, "bit_enum_vbar"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+// Define an op that contains the enum attribute.
+def OpWithBitEnumVerticalBar : TEST_Op<"op_with_bit_enum_vbar"> {
+  let arguments = (ins TestBitEnumVerticalBarAttr:$value,
+                   OptionalAttr<AnyAttr>:$tag);
+  let assemblyFormat = "$value (`tag` $tag^)? attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // Test Attribute Constraints
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list