[Mlir-commits] [mlir] 7939b76 - [mlir] Support default valued attribute in StructsGen

Lei Zhang llvmlistbot at llvm.org
Thu Sep 3 06:50:05 PDT 2020


Author: Lei Zhang
Date: 2020-09-03T09:46:44-04:00
New Revision: 7939b76e2a7b1fbc288f6d700bdbe53c581b58a6

URL: https://github.com/llvm/llvm-project/commit/7939b76e2a7b1fbc288f6d700bdbe53c581b58a6
DIFF: https://github.com/llvm/llvm-project/commit/7939b76e2a7b1fbc288f6d700bdbe53c581b58a6.diff

LOG: [mlir] Support default valued attribute in StructsGen

Default-valued attributes are handled like optional attributes, except
for the getter method, which returns the default value when the
attribute is absent.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D87055

Added: 
    

Modified: 
    mlir/tools/mlir-tblgen/StructsGen.cpp
    mlir/unittests/TableGen/StructsGenTest.cpp
    mlir/unittests/TableGen/structs.td

Removed: 
    


################################################################################
diff  --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp
index cccacc0cad85..2606dfe3696b 100644
--- a/mlir/tools/mlir-tblgen/StructsGen.cpp
+++ b/mlir/tools/mlir-tblgen/StructsGen.cpp
@@ -143,7 +143,7 @@ static void emitFactoryDef(llvm::StringRef structName,
 )";
 
   for (auto field : fields) {
-    if (field.getType().isOptional())
+    if (field.getType().isOptional() || field.getType().hasDefaultValue())
       os << llvm::formatv(getFieldInfoOptional, field.getName());
     else
       os << llvm::formatv(getFieldInfo, field.getName());
@@ -169,7 +169,7 @@ bool {0}::classof(::mlir::Attribute attr))";
   auto derived = attr.dyn_cast<::mlir::DictionaryAttr>();
   if (!derived)
     return false;
-  int empty_optionals = 0;
+  int num_absent_attrs = 0;
 )";
 
   os << llvm::formatv(classofInfo, structName) << " {";
@@ -184,7 +184,7 @@ bool {0}::classof(::mlir::Attribute attr))";
   const char *classofArgInfoOptional = R"(
   auto {0} = derived.get("{0}");
   if (!{0})
-    ++empty_optionals;
+    ++num_absent_attrs;
   else if (!({1}))
     return false;
 )";
@@ -193,14 +193,14 @@ bool {0}::classof(::mlir::Attribute attr))";
     auto type = field.getType();
     std::string condition =
         std::string(tgfmt(type.getConditionTemplate(), &fctx.withSelf(name)));
-    if (type.isOptional())
+    if (type.isOptional() || type.hasDefaultValue())
       os << llvm::formatv(classofArgInfoOptional, name, condition);
     else
       os << llvm::formatv(classofArgInfo, name, condition);
   }
 
   const char *classofEndInfo = R"(
-  return derived.size() + empty_optionals == {0};
+  return derived.size() + num_absent_attrs == {0};
 }
 )";
   os << llvm::formatv(classofEndInfo, fields.size());
@@ -229,14 +229,35 @@ emitAccessorDef(llvm::StringRef structName,
   return {1}.cast<{0}>();
 }
 )";
+  const char *fieldInfoDefaultValued = R"(
+{0} {2}::{1}() const {
+  auto derived = this->cast<::mlir::DictionaryAttr>();
+  auto {1} = derived.get("{1}");
+  if (!{1}) {
+    ::mlir::Builder builder(getContext());
+    return {3};
+  }
+  assert({1}.isa<{0}>() && "incorrect Attribute type found.");
+  return {1}.cast<{0}>();
+}
+)";
+  FmtContext fmtCtx;
+  fmtCtx.withBuilder("builder");
+
   for (auto field : fields) {
     auto name = field.getName();
     auto type = field.getType();
     auto storage = type.getStorageType();
-    if (type.isOptional())
+    if (type.isOptional()) {
       os << llvm::formatv(fieldInfoOptional, storage, name, structName);
-    else
+    } else if (type.hasDefaultValue()) {
+      std::string defaultValue = tgfmt(type.getConstBuilderTemplate(), &fmtCtx,
+                                       type.getDefaultValue());
+      os << llvm::formatv(fieldInfoDefaultValued, storage, name, structName,
+                          defaultValue);
+    } else {
       os << llvm::formatv(fieldInfo, storage, name, structName);
+    }
   }
 }
 

diff  --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
index 14b0abc675bf..d2acb28ebfb1 100644
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/StandardTypes.h"
 #include "llvm/ADT/DenseMap.h"
@@ -34,9 +35,10 @@ static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
   auto elementsAttr =
       mlir::DenseIntElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6});
   auto optionalAttr = nullptr;
+  auto defaultValuedAttr = nullptr;
 
   return test::TestStruct::get(integerAttr, floatAttr, elementsAttr,
-                               optionalAttr, context);
+                               optionalAttr, defaultValuedAttr, context);
 }
 
 /// Validates that test::TestStruct::classof correctly identifies a valid
@@ -167,4 +169,12 @@ TEST(StructsGenTest, EmptyOptional) {
   EXPECT_EQ(structAttr.sample_optional_integer(), nullptr);
 }
 
+TEST(StructsGenTest, GetDefaultValuedAttr) {
+  mlir::MLIRContext context;
+  mlir::Builder builder(&context);
+  auto structAttr = getTestStruct(&context);
+  EXPECT_EQ(structAttr.sample_default_valued_integer(),
+            builder.getI32IntegerAttr(42));
+}
+
 } // namespace mlir

diff  --git a/mlir/unittests/TableGen/structs.td b/mlir/unittests/TableGen/structs.td
index cf5e4f5448f0..06a15e181484 100644
--- a/mlir/unittests/TableGen/structs.td
+++ b/mlir/unittests/TableGen/structs.td
@@ -17,6 +17,8 @@ def Test_Struct : StructAttr<"TestStruct", Test_Dialect, [
                 StructFieldAttr<"sample_float", F32Attr>,
                 StructFieldAttr<"sample_elements", I32ElementsAttr>,
                 StructFieldAttr<"sample_optional_integer",
-                                OptionalAttr<I32Attr>>] > {
+                                OptionalAttr<I32Attr>>,
+                StructFieldAttr<"sample_default_valued_integer",
+                                DefaultValuedAttr<I32Attr, "42">>] > {
   let description = "Structure for test data";
 }


        


More information about the Mlir-commits mailing list