[Mlir-commits] [mlir] 7939b76 - [mlir] Support default valued attribute in StructsGen
Lei Zhang
llvmlistbot at llvm.org
Thu Sep 3 06:50:05 PDT 2020
Author: Lei Zhang
Date: 2020-09-03T09:46:44-04:00
New Revision: 7939b76e2a7b1fbc288f6d700bdbe53c581b58a6
URL: https://github.com/llvm/llvm-project/commit/7939b76e2a7b1fbc288f6d700bdbe53c581b58a6
DIFF: https://github.com/llvm/llvm-project/commit/7939b76e2a7b1fbc288f6d700bdbe53c581b58a6.diff
LOG: [mlir] Support default valued attribute in StructsGen
Default-valued attributes are handled like optional attributes, except for the
getter method, which returns the default value when the attribute is absent.
Reviewed By: rsuderman
Differential Revision: https://reviews.llvm.org/D87055
Added:
Modified:
mlir/tools/mlir-tblgen/StructsGen.cpp
mlir/unittests/TableGen/StructsGenTest.cpp
mlir/unittests/TableGen/structs.td
Removed:
################################################################################
diff --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp
index cccacc0cad85..2606dfe3696b 100644
--- a/mlir/tools/mlir-tblgen/StructsGen.cpp
+++ b/mlir/tools/mlir-tblgen/StructsGen.cpp
@@ -143,7 +143,7 @@ static void emitFactoryDef(llvm::StringRef structName,
)";
for (auto field : fields) {
- if (field.getType().isOptional())
+ if (field.getType().isOptional() || field.getType().hasDefaultValue())
os << llvm::formatv(getFieldInfoOptional, field.getName());
else
os << llvm::formatv(getFieldInfo, field.getName());
@@ -169,7 +169,7 @@ bool {0}::classof(::mlir::Attribute attr))";
auto derived = attr.dyn_cast<::mlir::DictionaryAttr>();
if (!derived)
return false;
- int empty_optionals = 0;
+ int num_absent_attrs = 0;
)";
os << llvm::formatv(classofInfo, structName) << " {";
@@ -184,7 +184,7 @@ bool {0}::classof(::mlir::Attribute attr))";
const char *classofArgInfoOptional = R"(
auto {0} = derived.get("{0}");
if (!{0})
- ++empty_optionals;
+ ++num_absent_attrs;
else if (!({1}))
return false;
)";
@@ -193,14 +193,14 @@ bool {0}::classof(::mlir::Attribute attr))";
auto type = field.getType();
std::string condition =
std::string(tgfmt(type.getConditionTemplate(), &fctx.withSelf(name)));
- if (type.isOptional())
+ if (type.isOptional() || type.hasDefaultValue())
os << llvm::formatv(classofArgInfoOptional, name, condition);
else
os << llvm::formatv(classofArgInfo, name, condition);
}
const char *classofEndInfo = R"(
- return derived.size() + empty_optionals == {0};
+ return derived.size() + num_absent_attrs == {0};
}
)";
os << llvm::formatv(classofEndInfo, fields.size());
@@ -229,14 +229,35 @@ emitAccessorDef(llvm::StringRef structName,
return {1}.cast<{0}>();
}
)";
+ const char *fieldInfoDefaultValued = R"(
+{0} {2}::{1}() const {
+ auto derived = this->cast<::mlir::DictionaryAttr>();
+ auto {1} = derived.get("{1}");
+ if (!{1}) {
+ ::mlir::Builder builder(getContext());
+ return {3};
+ }
+ assert({1}.isa<{0}>() && "incorrect Attribute type found.");
+ return {1}.cast<{0}>();
+}
+)";
+ FmtContext fmtCtx;
+ fmtCtx.withBuilder("builder");
+
for (auto field : fields) {
auto name = field.getName();
auto type = field.getType();
auto storage = type.getStorageType();
- if (type.isOptional())
+ if (type.isOptional()) {
os << llvm::formatv(fieldInfoOptional, storage, name, structName);
- else
+ } else if (type.hasDefaultValue()) {
+ std::string defaultValue = tgfmt(type.getConstBuilderTemplate(), &fmtCtx,
+ type.getDefaultValue());
+ os << llvm::formatv(fieldInfoDefaultValued, storage, name, structName,
+ defaultValue);
+ } else {
os << llvm::formatv(fieldInfo, storage, name, structName);
+ }
}
}
diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
index 14b0abc675bf..d2acb28ebfb1 100644
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/DenseMap.h"
@@ -34,9 +35,10 @@ static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
auto elementsAttr =
mlir::DenseIntElementsAttr::get(elementsType, {1, 2, 3, 4, 5, 6});
auto optionalAttr = nullptr;
+ auto defaultValuedAttr = nullptr;
return test::TestStruct::get(integerAttr, floatAttr, elementsAttr,
- optionalAttr, context);
+ optionalAttr, defaultValuedAttr, context);
}
/// Validates that test::TestStruct::classof correctly identifies a valid
@@ -167,4 +169,12 @@ TEST(StructsGenTest, EmptyOptional) {
EXPECT_EQ(structAttr.sample_optional_integer(), nullptr);
}
+TEST(StructsGenTest, GetDefaultValuedAttr) {
+ mlir::MLIRContext context;
+ mlir::Builder builder(&context);
+ auto structAttr = getTestStruct(&context);
+ EXPECT_EQ(structAttr.sample_default_valued_integer(),
+ builder.getI32IntegerAttr(42));
+}
+
} // namespace mlir
diff --git a/mlir/unittests/TableGen/structs.td b/mlir/unittests/TableGen/structs.td
index cf5e4f5448f0..06a15e181484 100644
--- a/mlir/unittests/TableGen/structs.td
+++ b/mlir/unittests/TableGen/structs.td
@@ -17,6 +17,8 @@ def Test_Struct : StructAttr<"TestStruct", Test_Dialect, [
StructFieldAttr<"sample_float", F32Attr>,
StructFieldAttr<"sample_elements", I32ElementsAttr>,
StructFieldAttr<"sample_optional_integer",
- OptionalAttr<I32Attr>>] > {
+ OptionalAttr<I32Attr>>,
+ StructFieldAttr<"sample_default_valued_integer",
+ DefaultValuedAttr<I32Attr, "42">>] > {
let description = "Structure for test data";
}
More information about the Mlir-commits
mailing list