[Mlir-commits] [mlir] 066a76a - Support OptionalAttr inside a StructAttr

Tamas Berghammer llvmlistbot at llvm.org
Wed Feb 19 04:52:42 PST 2020


Author: Tamas Berghammer
Date: 2020-02-19T12:47:04Z
New Revision: 066a76a234d82562de9bfc0503cac5da3ac0a121

URL: https://github.com/llvm/llvm-project/commit/066a76a234d82562de9bfc0503cac5da3ac0a121
DIFF: https://github.com/llvm/llvm-project/commit/066a76a234d82562de9bfc0503cac5da3ac0a121.diff

LOG: Support OptionalAttr inside a StructAttr

Differential revision: https://reviews.llvm.org/D74768

Added: 
    

Modified: 
    mlir/tools/mlir-tblgen/StructsGen.cpp
    mlir/unittests/TableGen/StructsGenTest.cpp
    mlir/unittests/TableGen/structs.td

Removed: 
    


################################################################################
diff  --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp
index 133f2e9d79b4..005757a62f37 100644
--- a/mlir/tools/mlir-tblgen/StructsGen.cpp
+++ b/mlir/tools/mlir-tblgen/StructsGen.cpp
@@ -135,8 +135,18 @@ static void emitFactoryDef(llvm::StringRef structName,
   fields.emplace_back({0}_id, {0});
 )";
 
+  const char *getFieldInfoOptional = R"(
+  if ({0}) {
+    auto {0}_id = mlir::Identifier::get("{0}", context);
+    fields.emplace_back({0}_id, {0});
+  }
+)";
+
   for (auto field : fields) {
-    os << llvm::formatv(getFieldInfo, field.getName());
+    if (field.getType().isOptional())
+      os << llvm::formatv(getFieldInfoOptional, field.getName());
+    else
+      os << llvm::formatv(getFieldInfo, field.getName());
   }
 
   const char *getEndInfo = R"(
@@ -154,35 +164,46 @@ static void emitClassofDef(llvm::StringRef structName,
 bool {0}::classof(mlir::Attribute attr))";
 
   const char *classofInfoHeader = R"(
-   auto derived = attr.dyn_cast<mlir::DictionaryAttr>();
-   if (!derived)
-     return false;
-   if (derived.size() != {0})
-     return false;
+  if (!attr)
+    return false;
+  auto derived = attr.dyn_cast<mlir::DictionaryAttr>();
+  if (!derived)
+    return false;
+  int empty_optionals = 0;
 )";
 
   os << llvm::formatv(classofInfo, structName) << " {";
-  os << llvm::formatv(classofInfoHeader, fields.size());
+  os << llvm::formatv(classofInfoHeader);
 
   FmtContext fctx;
   const char *classofArgInfo = R"(
   auto {0} = derived.get("{0}");
   if (!{0} || !({1}))
     return false;
+)";
+  const char *classofArgInfoOptional = R"(
+  auto {0} = derived.get("{0}");
+  if (!{0})
+    ++empty_optionals;
+  else if (!({1}))
+    return false;
 )";
   for (auto field : fields) {
     auto name = field.getName();
     auto type = field.getType();
     std::string condition =
         std::string(tgfmt(type.getConditionTemplate(), &fctx.withSelf(name)));
-    os << llvm::formatv(classofArgInfo, name, condition);
+    if (type.isOptional())
+      os << llvm::formatv(classofArgInfoOptional, name, condition);
+    else
+      os << llvm::formatv(classofArgInfo, name, condition);
   }
 
   const char *classofEndInfo = R"(
-  return true;
+  return derived.size() + empty_optionals == {0};
 }
 )";
-  os << classofEndInfo;
+  os << llvm::formatv(classofEndInfo, fields.size());
 }
 
 static void
@@ -197,12 +218,25 @@ emitAccessorDef(llvm::StringRef structName,
   assert({1}.isa<{0}>() && "incorrect Attribute type found.");
   return {1}.cast<{0}>();
 }
+)";
+  const char *fieldInfoOptional = R"(
+{0} {2}::{1}() const {
+  auto derived = this->cast<mlir::DictionaryAttr>();
+  auto {1} = derived.get("{1}");
+  if (!{1})
+    return nullptr;
+  assert({1}.isa<{0}>() && "incorrect Attribute type found.");
+  return {1}.cast<{0}>();
+}
 )";
   for (auto field : fields) {
     auto name = field.getName();
     auto type = field.getType();
     auto storage = type.getStorageType();
-    os << llvm::formatv(fieldInfo, storage, name, structName);
+    if (type.isOptional())
+      os << llvm::formatv(fieldInfoOptional, storage, name, structName);
+    else
+      os << llvm::formatv(fieldInfo, storage, name, structName);
   }
 }
 

diff  --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
index 5fccb18d1728..19aff1c83b0f 100644
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -33,8 +33,10 @@ static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
   auto elementsType = mlir::RankedTensorType::get({2, 3}, integerType);
   auto elementsAttr =
       mlir::DenseIntElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6});
+  auto optionalAttr = nullptr;
 
-  return test::TestStruct::get(integerAttr, floatAttr, elementsAttr, context);
+  return test::TestStruct::get(integerAttr, floatAttr, elementsAttr,
+                               optionalAttr, context);
 }
 
 // Validates that test::TestStruct::classof correctly identifies a valid
@@ -159,4 +161,10 @@ TEST(StructsGenTest, GetElements) {
   }
 }
 
+TEST(StructsGenTest, EmptyOptional) {
+  mlir::MLIRContext context;
+  auto structAttr = getTestStruct(&context);
+  EXPECT_EQ(structAttr.sample_optional_integer(), nullptr);
+}
+
 } // namespace mlir

diff  --git a/mlir/unittests/TableGen/structs.td b/mlir/unittests/TableGen/structs.td
index 9f86981f3dd4..cf5e4f5448f0 100644
--- a/mlir/unittests/TableGen/structs.td
+++ b/mlir/unittests/TableGen/structs.td
@@ -15,6 +15,8 @@ def Test_Dialect : Dialect {
 def Test_Struct : StructAttr<"TestStruct", Test_Dialect, [
                 StructFieldAttr<"sample_integer", I32Attr>,
                 StructFieldAttr<"sample_float", F32Attr>,
-                StructFieldAttr<"sample_elements", I32ElementsAttr>] > {
+                StructFieldAttr<"sample_elements", I32ElementsAttr>,
+                StructFieldAttr<"sample_optional_integer",
+                                OptionalAttr<I32Attr>>] > {
   let description = "Structure for test data";
 }


        


More information about the Mlir-commits mailing list