[Mlir-commits] [mlir] 60e34f8 - [mlir][ods] Remove StrEnumAttr

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 13 10:49:06 PDT 2022


Author: Mogball
Date: 2022-04-13T17:49:02Z
New Revision: 60e34f8dddb4a3ae5b82e8d55728c021126c4af8

URL: https://github.com/llvm/llvm-project/commit/60e34f8dddb4a3ae5b82e8d55728c021126c4af8
DIFF: https://github.com/llvm/llvm-project/commit/60e34f8dddb4a3ae5b82e8d55728c021126c4af8.diff

LOG: [mlir][ods] Remove StrEnumAttr

StrEnumAttr has been deprecated in favour of EnumAttr, a solution based on AttrDef (https://reviews.llvm.org/D115181). This patch removes StrEnumAttr, along with all the custom ODS logic required to handle it.

See https://discourse.llvm.org/t/psa-stop-using-strenumattr-do-use-enumattr/5710 on how to transition to EnumAttr. In short,

```
// Before
def MyEnumAttr : StrEnumAttr<"MyEnum", "", [
  StrEnumAttrCase<"A">,
  StrEnumAttrCase<"B">
]>;

// After (pick an integer enum of your choice)
def MyEnum : I32EnumAttr<"MyEnum", "", [
  I32EnumAttrCase<"A", 0>,
  I32EnumAttrCase<"B", 1>
]> {
  // Don't generate a C++ class! We want to use the AttrDef
  let genSpecializedAttr = 0;
}
// Define the AttrDef
def MyEnumAttr : EnumAttr<MyDialect, MyEnum, "my_enum">;
```

Reviewed By: rriddle, jpienaar

Differential Revision: https://reviews.llvm.org/D120834

Added: 
    

Modified: 
    mlir/docs/OpDefinitions.md
    mlir/include/mlir/IR/OpBase.td
    mlir/include/mlir/TableGen/Attribute.h
    mlir/lib/TableGen/Attribute.cpp
    mlir/test/IR/attribute.mlir
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/pattern.mlir
    mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp
    mlir/tools/mlir-tblgen/EnumsGen.cpp
    mlir/tools/mlir-tblgen/OpFormatGen.cpp
    mlir/tools/mlir-tblgen/RewriterGen.cpp
    mlir/unittests/TableGen/EnumsGenTest.cpp
    mlir/unittests/TableGen/enums.td

Removed: 
    


################################################################################
diff  --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
index 1f05bf60f7078..b3aadaa3e8ac5 100644
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -1283,10 +1283,8 @@ optionality, default values, etc.:
 
 Some attributes can only take values from a predefined enum, e.g., the
 comparison kind of a comparison op. To define such attributes, ODS provides
-several mechanisms: `StrEnumAttr`, `IntEnumAttr`, and `BitEnumAttr`.
+two mechanisms: `IntEnumAttr` and `BitEnumAttr`.
 
-*   `StrEnumAttr`: each enum case is a string, the attribute is stored as a
-    [`StringAttr`][StringAttr] in the op.
 *   `IntEnumAttr`: each enum case is an integer, the attribute is stored as a
     [`IntegerAttr`][IntegerAttr] in the op.
 *   `BitEnumAttr`: each enum case is a either the empty case, a single bit,

diff  --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 34798030ad72a..d38d16a00710c 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1230,13 +1230,6 @@ class EnumAttrCaseInfo<string sym, int intVal, string strVal> {
   string str = strVal;
 }
 
-// An enum attribute case stored with StringAttr.
-class StrEnumAttrCase<string sym, int val = -1, string str = sym> :
-    EnumAttrCaseInfo<sym, val, str>,
-    StringBasedAttr<
-      CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # str # "\"">,
-      "case " # str>;
-
 // An enum attribute case stored with IntegerAttr, which has an integer value,
 // its representation as a string and a C++ symbol name which may be different.
 class IntEnumAttrCaseBase<I intType, string sym, string strVal, int intVal> :
@@ -1393,22 +1386,6 @@ class EnumAttrInfo<
   let valueType = baseAttrClass.valueType;
 }
 
-// An enum attribute backed by StringAttr.
-//
-// Op attributes of this kind are stored as StringAttr. Extra verification will
-// be generated on the string though: only the symbols of the allowed cases are
-// permitted as the string value.
-class StrEnumAttr<string name, string summary, list<StrEnumAttrCase> cases> :
-  EnumAttrInfo<name, cases,
-    StringBasedAttr<
-      And<[StrAttr.predicate, Or<!foreach(case, cases, case.predicate)>]>,
-      !if(!empty(summary), "allowed string cases: " #
-          !interleave(!foreach(case, cases, "'" # case.symbol # "'"), ", "),
-          summary)>> {
-  // Disable specialized Attribute class for `StringAttr` backend by default.
-  let genSpecializedAttr = 0;
-}
-
 // An enum attribute backed by IntegerAttr.
 //
 // Op attributes of this kind are stored as IntegerAttr. Extra verification will

diff  --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h
index 9e6165a33178a..b74c66461652d 100644
--- a/mlir/include/mlir/TableGen/Attribute.h
+++ b/mlir/include/mlir/TableGen/Attribute.h
@@ -144,9 +144,6 @@ class EnumAttrCase : public Attribute {
   explicit EnumAttrCase(const llvm::Record *record);
   explicit EnumAttrCase(const llvm::DefInit *init);
 
-  // Returns true if this EnumAttrCase is backed by a StringAttr.
-  bool isStrCase() const;
-
   // Returns the symbol of this enum attribute case.
   StringRef getSymbol() const;
 

diff  --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp
index ae9183fb27c8b..1d2b8d3ecd142 100644
--- a/mlir/lib/TableGen/Attribute.cpp
+++ b/mlir/lib/TableGen/Attribute.cpp
@@ -157,8 +157,6 @@ EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
 EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
     : EnumAttrCase(init->getDef()) {}
 
-bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); }
-
 StringRef EnumAttrCase::getSymbol() const {
   return def->getValueAsString("symbol");
 }

diff  --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index 29235df349d25..318168dfaa87e 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -345,29 +345,6 @@ func @string_attr_custom_type() {
 
 // -----
 
-//===----------------------------------------------------------------------===//
-// Test StrEnumAttr
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func @allowed_cases_pass
-func @allowed_cases_pass() {
-  // CHECK: test.str_enum_attr
-  %0 = "test.str_enum_attr"() {attr = "A"} : () -> i32
-  // CHECK: test.str_enum_attr
-  %1 = "test.str_enum_attr"() {attr = "B"} : () -> i32
-  return
-}
-
-// -----
-
-func @disallowed_case_fail() {
-  // expected-error @+1 {{allowed string cases: 'A', 'B'}}
-  %0 = "test.str_enum_attr"() {attr = 7: i32} : () -> i32
-  return
-}
-
-// -----
-
 //===----------------------------------------------------------------------===//
 // Test I32EnumAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index bccca927725e0..85dc1b28cce2f 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -191,17 +191,6 @@ def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
   let assemblyFormat = "$attr attr-dict";
 }
 
-def StrCaseA: StrEnumAttrCase<"A">;
-def StrCaseB: StrEnumAttrCase<"B">;
-
-def SomeStrEnum: StrEnumAttr<
-  "SomeStrEnum", "", [StrCaseA, StrCaseB]>;
-
-def StrEnumAttrOp : TEST_Op<"str_enum_attr"> {
-  let arguments = (ins SomeStrEnum:$attr);
-  let results = (outs I32:$val);
-}
-
 def I32Case5:  I32EnumAttrCase<"case5", 5>;
 def I32Case10: I32EnumAttrCase<"case10", 10>;
 
@@ -1260,8 +1249,6 @@ def : Pat<(OpAttrMatch3 $attr), (OpAttrMatch4 ConstUnitAttr, $attr)>;
 def OpC : TEST_Op<"op_c">, Arguments<(ins I32)>, Results<(outs I32)>;
 def : Pat<(OpC $input), (OpB $input, ConstantAttr<I32Attr, "17">:$attr)>;
 
-// Test string enum attribute in rewrites.
-def : Pat<(StrEnumAttrOp StrCaseA), (StrEnumAttrOp StrCaseB)>;
 // Test integer enum attribute in rewrites.
 def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>;
 def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>;
@@ -1568,11 +1555,8 @@ def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "66">:$attr),
 // Test Legalization
 //===----------------------------------------------------------------------===//
 
-def Test_LegalizerEnum_Success : StrEnumAttrCase<"Success">;
-def Test_LegalizerEnum_Failure : StrEnumAttrCase<"Failure">;
-
-def Test_LegalizerEnum : StrEnumAttr<"Success", "Failure",
-  [Test_LegalizerEnum_Success, Test_LegalizerEnum_Failure]>;
+def Test_LegalizerEnum_Success : ConstantStrAttr<StrAttr, "Success">;
+def Test_LegalizerEnum_Failure : ConstantStrAttr<StrAttr, "Failure">;
 
 def ILLegalOpA : TEST_Op<"illegal_op_a">, Results<(outs I32)>;
 def ILLegalOpB : TEST_Op<"illegal_op_b">, Results<(outs I32)>;
@@ -1582,7 +1566,7 @@ def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>;
 def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>;
 def ILLegalOpG : TEST_Op<"illegal_op_g">, Results<(outs I32)>;
 def LegalOpA : TEST_Op<"legal_op_a">,
-  Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>;
+  Arguments<(ins StrAttr:$status)>, Results<(outs I32)>;
 def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;
 def LegalOpC : TEST_Op<"legal_op_c">,
   Arguments<(ins I32)>, Results<(outs I32)>;

diff  --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
index 9d64865014308..a6525ec3cae6b 100644
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -356,13 +356,6 @@ func @testConstOpMatchNonConst(%arg0 : i32) -> (i32) {
 // Test Enum Attributes
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: verifyStrEnumAttr
-func @verifyStrEnumAttr() -> i32 {
-  // CHECK: "test.str_enum_attr"() {attr = "B"}
-  %0 = "test.str_enum_attr"() {attr = "A"} : () -> i32
-  return %0 : i32
-}
-
 // CHECK-LABEL: verifyI32EnumAttr
 func @verifyI32EnumAttr() -> i32 {
   // CHECK: "test.i32_enum_attr"() {attr = 10 : i32}

diff  --git a/mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp b/mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp
index a5c89344b2d61..337b6a5e5d5bd 100644
--- a/mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp
+++ b/mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp
@@ -35,7 +35,7 @@ using llvm::RecordKeeper;
 // declarations, functions etc.
 //
 // Some OpenMP/OpenACC clauses accept only a fixed set of values as inputs.
-// These can be represented as a String Enum Attribute (StrEnumAttr) in MLIR
+// These can be represented as Enum Attributes (EnumAttrDef) in MLIR
 // ODS. The emitDecls function below currently generates these enumerations. The
 // name of the enumeration is specified in the enumClauseValue field of
 // Clause record in OMP.td. This name can be used to specify the type of the

diff  --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp
index 3365ff02b0df0..15b17199735ac 100644
--- a/mlir/tools/mlir-tblgen/EnumsGen.cpp
+++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp
@@ -314,8 +314,6 @@ static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
 static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
   EnumAttr enumAttr(enumDef);
   StringRef enumName = enumAttr.getEnumClassName();
-  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
-  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
   StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
   llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass();
   Attribute baseAttr(baseAttrDef);
@@ -341,28 +339,22 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
   os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
                 attrClassName, enumName);
 
-  if (enumAttr.isSubClassOf("StrEnumAttr")) {
-    os << formatv("  ::mlir::StringAttr baseAttr = "
-                  "::mlir::StringAttr::get(context, {0}(val));\n",
-                  symToStrFnName);
-  } else {
-    StringRef underlyingType = enumAttr.getUnderlyingType();
-
-    // Assuming that it is IntegerAttr constraint
-    int64_t bitwidth = 64;
-    if (baseAttrDef->getValue("valueType")) {
-      auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType");
-      if (valueTypeDef->getValue("bitwidth"))
-        bitwidth = valueTypeDef->getValueAsInt("bitwidth");
-    }
+  StringRef underlyingType = enumAttr.getUnderlyingType();
 
-    os << formatv("  ::mlir::IntegerType intType = "
-                  "::mlir::IntegerType::get(context, {0});\n",
-                  bitwidth);
-    os << formatv("  ::mlir::IntegerAttr baseAttr = "
-                  "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
-                  underlyingType);
+  // Assuming that it is IntegerAttr constraint
+  int64_t bitwidth = 64;
+  if (baseAttrDef->getValue("valueType")) {
+    auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType");
+    if (valueTypeDef->getValue("bitwidth"))
+      bitwidth = valueTypeDef->getValueAsInt("bitwidth");
   }
+
+  os << formatv("  ::mlir::IntegerType intType = "
+                "::mlir::IntegerType::get(context, {0});\n",
+                bitwidth);
+  os << formatv("  ::mlir::IntegerAttr baseAttr = "
+                "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
+                underlyingType);
   os << formatv("  return baseAttr.cast<{0}>();\n", attrClassName);
 
   os << "}\n";
@@ -371,14 +363,8 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
 
   os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName);
 
-  if (enumAttr.isSubClassOf("StrEnumAttr")) {
-    os << formatv("  const auto res = {0}(::mlir::StringAttr::getValue());\n",
-                  strToSymFnName);
-    os << "  return res.getValue();\n";
-  } else {
-    os << formatv("  return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
-                  enumName);
-  }
+  os << formatv("  return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
+                enumName);
 
   os << "}\n";
 }
@@ -483,8 +469,7 @@ class {1} : public ::mlir::{2} {
 )";
   if (enumAttr.genSpecializedAttr()) {
     StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
-    StringRef baseAttrClassName =
-        enumAttr.isSubClassOf("StrEnumAttr") ? "StringAttr" : "IntegerAttr";
+    StringRef baseAttrClassName = "IntegerAttr";
     os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
   }
 

diff  --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index fb54dcb3d85f8..d6976e0cef0f6 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1797,30 +1797,11 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
   // Get a string containing all of the cases that can't be represented with a
   // keyword.
   BitVector nonKeywordCases(cases.size());
-  bool hasStrCase = false;
   for (auto &it : llvm::enumerate(cases)) {
-    hasStrCase = it.value().isStrCase();
     if (!canFormatStringAsKeyword(it.value().getStr()))
       nonKeywordCases.set(it.index());
   }
 
-  // If this is a string enum, use the case string to determine which cases
-  // need to use the string form.
-  if (hasStrCase) {
-    if (nonKeywordCases.any()) {
-      body << "    if (llvm::is_contained(llvm::ArrayRef<llvm::StringRef>(";
-      llvm::interleaveComma(nonKeywordCases.set_bits(), body, [&](unsigned it) {
-        body << '"' << cases[it].getStr() << '"';
-      });
-      body << ")))\n"
-              "      _odsPrinter << '\"' << caseValueStr << '\"';\n"
-              "    else\n  ";
-    }
-    body << "    _odsPrinter << caseValueStr;\n"
-            "  }\n";
-    return;
-  }
-
   // Otherwise if this is a bit enum attribute, don't allow cases that may
   // overlap with other cases. For simplicity sake, only allow cases with a
   // single bit value.

diff  --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 0159328d6453a..d7dfa7315684d 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1221,8 +1221,6 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
   }
   if (leaf.isEnumAttrCase()) {
     auto enumCase = leaf.getAsEnumAttrCase();
-    if (enumCase.isStrCase())
-      return handleConstantAttr(enumCase, "\"" + enumCase.getSymbol() + "\"");
     // This is an enum case backed by an IntegerAttr. We need to get its value
     // to build the constant.
     std::string val = std::to_string(enumCase.getValue());

diff  --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp
index 82dbe119cb846..a5819c5a857ca 100644
--- a/mlir/unittests/TableGen/EnumsGenTest.cpp
+++ b/mlir/unittests/TableGen/EnumsGenTest.cpp
@@ -27,12 +27,12 @@
 /// Test namespaces and enum class/utility names.
 using Outer::Inner::ConvertToEnum;
 using Outer::Inner::ConvertToString;
-using Outer::Inner::StrEnum;
-using Outer::Inner::StrEnumAttr;
+using Outer::Inner::FooEnum;
+using Outer::Inner::FooEnumAttr;
 
 TEST(EnumsGenTest, GeneratedStrEnumDefinition) {
-  EXPECT_EQ(0u, static_cast<uint64_t>(StrEnum::CaseA));
-  EXPECT_EQ(10u, static_cast<uint64_t>(StrEnum::CaseB));
+  EXPECT_EQ(0u, static_cast<uint64_t>(FooEnum::CaseA));
+  EXPECT_EQ(1u, static_cast<uint64_t>(FooEnum::CaseB));
 }
 
 TEST(EnumsGenTest, GeneratedI32EnumDefinition) {
@@ -41,23 +41,23 @@ TEST(EnumsGenTest, GeneratedI32EnumDefinition) {
 }
 
 TEST(EnumsGenTest, GeneratedDenseMapInfo) {
-  llvm::DenseMap<StrEnum, std::string> myMap;
+  llvm::DenseMap<FooEnum, std::string> myMap;
 
-  myMap[StrEnum::CaseA] = "zero";
-  myMap[StrEnum::CaseB] = "one";
+  myMap[FooEnum::CaseA] = "zero";
+  myMap[FooEnum::CaseB] = "one";
 
-  EXPECT_EQ(myMap[StrEnum::CaseA], "zero");
-  EXPECT_EQ(myMap[StrEnum::CaseB], "one");
+  EXPECT_EQ(myMap[FooEnum::CaseA], "zero");
+  EXPECT_EQ(myMap[FooEnum::CaseB], "one");
 }
 
 TEST(EnumsGenTest, GeneratedSymbolToStringFn) {
-  EXPECT_EQ(ConvertToString(StrEnum::CaseA), "CaseA");
-  EXPECT_EQ(ConvertToString(StrEnum::CaseB), "CaseB");
+  EXPECT_EQ(ConvertToString(FooEnum::CaseA), "CaseA");
+  EXPECT_EQ(ConvertToString(FooEnum::CaseB), "CaseB");
 }
 
 TEST(EnumsGenTest, GeneratedStringToSymbolFn) {
-  EXPECT_EQ(llvm::Optional<StrEnum>(StrEnum::CaseA), ConvertToEnum("CaseA"));
-  EXPECT_EQ(llvm::Optional<StrEnum>(StrEnum::CaseB), ConvertToEnum("CaseB"));
+  EXPECT_EQ(llvm::Optional<FooEnum>(FooEnum::CaseA), ConvertToEnum("CaseA"));
+  EXPECT_EQ(llvm::Optional<FooEnum>(FooEnum::CaseB), ConvertToEnum("CaseB"));
   EXPECT_EQ(llvm::None, ConvertToEnum("X"));
 }
 
@@ -155,19 +155,6 @@ TEST(EnumsGenTest, GeneratedIntAttributeClass) {
   EXPECT_EQ(intAttr, enumAttr);
 }
 
-TEST(EnumsGenTest, GeneratedStringAttributeClass) {
-  mlir::MLIRContext ctx;
-  StrEnum rawVal = StrEnum::CaseA;
-
-  StrEnumAttr enumAttr = StrEnumAttr::get(&ctx, rawVal);
-  EXPECT_NE(enumAttr, nullptr);
-  EXPECT_EQ(enumAttr.getValue(), rawVal);
-
-  mlir::Attribute strAttr = mlir::StringAttr::get(&ctx, "CaseA");
-  EXPECT_TRUE(strAttr.isa<StrEnumAttr>());
-  EXPECT_EQ(strAttr, enumAttr);
-}
-
 TEST(EnumsGenTest, GeneratedBitAttributeClass) {
   mlir::MLIRContext ctx;
 

diff  --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td
index 142f41403ce9f..dcc2313e93fef 100644
--- a/mlir/unittests/TableGen/enums.td
+++ b/mlir/unittests/TableGen/enums.td
@@ -8,10 +8,10 @@
 
 include "mlir/IR/OpBase.td"
 
-def CaseA: StrEnumAttrCase<"CaseA">;
-def CaseB: StrEnumAttrCase<"CaseB", 10>;
+def CaseA: I32EnumAttrCase<"CaseA", 0>;
+def CaseB: I32EnumAttrCase<"CaseB", 1>;
 
-def StrEnum: StrEnumAttr<"StrEnum", "A test enum", [CaseA, CaseB]> {
+def FooEnum: I32EnumAttr<"FooEnum", "A test enum", [CaseA, CaseB]> {
   let cppNamespace = "Outer::Inner";
   let stringToSymbolFnName = "ConvertToEnum";
   let symbolToStringFnName = "ConvertToString";


        


More information about the Mlir-commits mailing list