[PATCH] D72162: [mlir] Enhance classof() checks in StructsGen
Lei Zhang via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 3 12:21:50 PST 2020
This revision was automatically updated to reflect the committed changes.
Closed by commit rG5d5d5838ce07: [mlir] Enhance classof() checks in StructsGen (authored by antiagainst).
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D72162/new/
https://reviews.llvm.org/D72162
Files:
mlir/tools/mlir-tblgen/StructsGen.cpp
mlir/unittests/TableGen/StructsGenTest.cpp
mlir/unittests/TableGen/structs.td
Index: mlir/unittests/TableGen/structs.td
===================================================================
--- mlir/unittests/TableGen/structs.td
+++ mlir/unittests/TableGen/structs.td
@@ -15,6 +15,6 @@
def Test_Struct : StructAttr<"TestStruct", Test_Dialect, [
StructFieldAttr<"sample_integer", I32Attr>,
StructFieldAttr<"sample_float", F32Attr>,
- StructFieldAttr<"sample_elements", ElementsAttr>] > {
+ StructFieldAttr<"sample_elements", I32ElementsAttr>] > {
let description = "Structure for test data";
}
Index: mlir/unittests/TableGen/StructsGenTest.cpp
===================================================================
--- mlir/unittests/TableGen/StructsGenTest.cpp
+++ mlir/unittests/TableGen/StructsGenTest.cpp
@@ -27,12 +27,12 @@
auto integerType = mlir::IntegerType::get(32, context);
auto integerAttr = mlir::IntegerAttr::get(integerType, 127);
- auto floatType = mlir::FloatType::getF16(context);
+ auto floatType = mlir::FloatType::getF32(context);
auto floatAttr = mlir::FloatAttr::get(floatType, 0.25);
auto elementsType = mlir::RankedTensorType::get({2, 3}, integerType);
auto elementsAttr =
- mlir::DenseElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6});
+ mlir::DenseIntElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6});
return test::TestStruct::get(integerAttr, floatAttr, elementsAttr, context);
}
@@ -88,6 +88,31 @@
ASSERT_FALSE(test::TestStruct::classof(badDictionary));
}
+// Validates that test::TestStruct::classof fails when a NamedAttribute is
+// present under the expected name but holds a value of the wrong type.
+TEST(StructsGenTest, ClassofBadTypeFalse) {
+ mlir::MLIRContext context;
+ mlir::DictionaryAttr structAttr = getTestStruct(&context);
+ auto expectedValues = structAttr.getValue();
+ ASSERT_EQ(expectedValues.size(), 3u);
+
+ // Copy every NamedAttribute except the last one.
+ llvm::SmallVector<mlir::NamedAttribute, 4> newValues(
+ expectedValues.begin(), expectedValues.end() - 1);
+ // Re-add the last entry under its own name, but holding an i64 elements
+ // value. NOTE(review): confirm which field back() is (entry ordering).
+ auto i64Type = mlir::IntegerType::get(64, &context);
+ auto elementsType = mlir::RankedTensorType::get({3}, i64Type);
+ auto elementsAttr =
+ mlir::DenseIntElementsAttr::get(elementsType, ArrayRef<int64_t>{1, 2, 3});
+ mlir::Identifier id = expectedValues.back().first;
+ auto wrongAttr = mlir::NamedAttribute(id, elementsAttr);
+ newValues.push_back(wrongAttr);
+
+ auto badDictionary = mlir::DictionaryAttr::get(newValues, &context);
+ ASSERT_FALSE(test::TestStruct::classof(badDictionary));
+}
+
// Validates that test::TestStruct::classof fails when a NamedAttribute is
// missing.
TEST(StructsGenTest, ClassofMissingFalse) {
Index: mlir/tools/mlir-tblgen/StructsGen.cpp
===================================================================
--- mlir/tools/mlir-tblgen/StructsGen.cpp
+++ mlir/tools/mlir-tblgen/StructsGen.cpp
@@ -27,6 +27,7 @@
using llvm::Record;
using llvm::RecordKeeper;
using llvm::StringRef;
+using mlir::tblgen::FmtContext;
using mlir::tblgen::StructAttr;
static void
@@ -163,16 +164,18 @@
os << llvm::formatv(classofInfo, structName) << " {";
os << llvm::formatv(classofInfoHeader, fields.size());
+ FmtContext fctx;
const char *classofArgInfo = R"(
auto {0} = derived.get("{0}");
- if (!{0} || !{0}.isa<{1}>())
+ if (!{0} || !({1}))
return false;
)";
for (auto field : fields) {
auto name = field.getName();
auto type = field.getType();
- auto storage = type.getStorageType();
- os << llvm::formatv(classofArgInfo, name, storage);
+ std::string condition =
+ tgfmt(type.getConditionTemplate(), &fctx.withSelf(name));
+ os << llvm::formatv(classofArgInfo, name, condition);
}
const char *classofEndInfo = R"(
-------------- next part --------------
A non-text attachment was scrubbed...
Name: D72162.236106.patch
Type: text/x-patch
Size: 3800 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20200103/a3519144/attachment.bin>
More information about the llvm-commits mailing list