[PATCH] D72162: [mlir] Enhance classof() checks in StructsGen

Lei Zhang via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 3 08:40:17 PST 2020


antiagainst created this revision.
antiagainst added reviewers: rriddle, jpienaar.
Herald added subscribers: llvm-commits, lucyrfox, mgester, arpith-jacob, nicolasvasilache, shauheen, burmako, mehdi_amini.
Herald added a project: LLVM.

Previously we only checked that each field is of the correct
mlir::Attribute subclass. This commit enhances the check to also
consider the attribute's type, by leveraging the constraints already
encoded in TableGen attribute definitions.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D72162

Files:
  mlir/tools/mlir-tblgen/StructsGen.cpp
  mlir/unittests/TableGen/StructsGenTest.cpp
  mlir/unittests/TableGen/structs.td


Index: mlir/unittests/TableGen/structs.td
===================================================================
--- mlir/unittests/TableGen/structs.td
+++ mlir/unittests/TableGen/structs.td
@@ -15,6 +15,6 @@
 def Test_Struct : StructAttr<"TestStruct", Test_Dialect, [
                 StructFieldAttr<"sample_integer", I32Attr>,
                 StructFieldAttr<"sample_float", F32Attr>,
-                StructFieldAttr<"sample_elements", ElementsAttr>] > {
+                StructFieldAttr<"sample_elements", I32ElementsAttr>] > {
   let description = "Structure for test data";
 }
Index: mlir/unittests/TableGen/StructsGenTest.cpp
===================================================================
--- mlir/unittests/TableGen/StructsGenTest.cpp
+++ mlir/unittests/TableGen/StructsGenTest.cpp
@@ -27,12 +27,12 @@
   auto integerType = mlir::IntegerType::get(32, context);
   auto integerAttr = mlir::IntegerAttr::get(integerType, 127);
 
-  auto floatType = mlir::FloatType::getF16(context);
+  auto floatType = mlir::FloatType::getF32(context);
   auto floatAttr = mlir::FloatAttr::get(floatType, 0.25);
 
   auto elementsType = mlir::RankedTensorType::get({2, 3}, integerType);
   auto elementsAttr =
-      mlir::DenseElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6});
+      mlir::DenseIntElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6});
 
   return test::TestStruct::get(integerAttr, floatAttr, elementsAttr, context);
 }
@@ -88,6 +88,31 @@
   ASSERT_FALSE(test::TestStruct::classof(badDictionary));
 }
 
+// Validates that test::TestStruct::classof fails when a field holds the
+// expected attribute kind but with an incorrect type (i64 instead of i32).
+TEST(StructsGenTest, ClassofBadTypeFalse) {
+  mlir::MLIRContext context;
+  mlir::DictionaryAttr structAttr = getTestStruct(&context);
+  auto expectedValues = structAttr.getValue();
+  ASSERT_EQ(expectedValues.size(), 3u);
+
+  // Copy all fields, locating sample_integer by name rather than by position.
+  llvm::SmallVector<mlir::NamedAttribute, 4> newValues(expectedValues.begin(),
+                                                       expectedValues.end());
+  mlir::NamedAttribute *target = nullptr;
+  for (mlir::NamedAttribute &attr : newValues)
+    if (attr.first.strref() == "sample_integer")
+      target = &attr;
+  ASSERT_NE(target, nullptr);
+
+  // An i64 IntegerAttr passes isa<IntegerAttr> but violates I32Attr.
+  auto i64Type = mlir::IntegerType::get(64, &context);
+  target->second = mlir::IntegerAttr::get(i64Type, 127);
+
+  auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+  ASSERT_FALSE(test::TestStruct::classof(badDictionary));
+}
+
 // Validates that test::TestStruct::classof fails when a NamedAttribute is
 // missing.
 TEST(StructsGenTest, ClassofMissingFalse) {
Index: mlir/tools/mlir-tblgen/StructsGen.cpp
===================================================================
--- mlir/tools/mlir-tblgen/StructsGen.cpp
+++ mlir/tools/mlir-tblgen/StructsGen.cpp
@@ -27,6 +27,7 @@
 using llvm::Record;
 using llvm::RecordKeeper;
 using llvm::StringRef;
+using mlir::tblgen::FmtContext;
 using mlir::tblgen::StructAttr;
 
 static void
@@ -163,16 +164,18 @@
   os << llvm::formatv(classofInfo, structName) << " {";
   os << llvm::formatv(classofInfoHeader, fields.size());
 
+  FmtContext fctx;
   const char *classofArgInfo = R"(
   auto {0} = derived.get("{0}");
-  if (!{0} || !{0}.isa<{1}>())
+  if (!{0} || !({1}))
     return false;
 )";
   for (auto field : fields) {
     auto name = field.getName();
     auto type = field.getType();
-    auto storage = type.getStorageType();
-    os << llvm::formatv(classofArgInfo, name, storage);
+    std::string condition =
+        tgfmt(type.getConditionTemplate(), &fctx.withSelf(name));
+    os << llvm::formatv(classofArgInfo, name, condition);
   }
 
   const char *classofEndInfo = R"(


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D72162.236072.patch
Type: text/x-patch
Size: 3800 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20200103/748a4ee2/attachment.bin>


More information about the llvm-commits mailing list