[Mlir-commits] [mlir] 23e3cbe - [mlir] Refactor how parser/printers are specified for AttrDef/TypeDef
River Riddle
llvmlistbot at llvm.org
Tue Mar 15 00:51:08 PDT 2022
Author: River Riddle
Date: 2022-03-15T00:42:31-07:00
New Revision: 23e3cbe24a51604a03a379664f67ed79e5fef897
URL: https://github.com/llvm/llvm-project/commit/23e3cbe24a51604a03a379664f67ed79e5fef897
DIFF: https://github.com/llvm/llvm-project/commit/23e3cbe24a51604a03a379664f67ed79e5fef897.diff
LOG: [mlir] Refactor how parser/printers are specified for AttrDef/TypeDef
There is currently an awkwardly complex set of rules for how a
parser/printer is generated for AttrDef/TypeDef. It can change depending on
whether a mnemonic was specified, whether there are parameters, whether
`assemblyFormat` is used, whether individual parser/printer code blocks were
specified, etc. This commit refactors
this to make what the attribute/type wants more explicit, and to better align
with how formats are specified for operations.
Firstly, the parser/printer code blocks are removed in favor of a
`hasCustomAssemblyFormat` bit field. This aligns with the operation format
specification (and has the nice side effect of removing code blocks from ODS).
This commit also adds a requirement to explicitly set `assemblyFormat` or
`hasCustomAssemblyFormat` when the mnemonic is set and the attr/type
has parameters. This removes the weird implicit matrix of behavior,
and also encourages the author to make a conscious choice of either C++
or declarative format instead of implicitly opting them into the C++
format (we should be pushing towards declarative when possible).
Differential Revision: https://reviews.llvm.org/D121505
Added:
Modified:
flang/include/flang/Optimizer/Dialect/FIRTypes.td
mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td
mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/include/mlir/IR/AttrTypeBase.td
mlir/include/mlir/TableGen/AttrOrTypeDef.h
mlir/lib/TableGen/AttrOrTypeDef.cpp
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/Dialect/Test/TestTypeDefs.td
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/test/mlir-tblgen/attrdefs.td
mlir/test/mlir-tblgen/testdialect-typedefs.mlir
mlir/test/mlir-tblgen/typedefs.td
mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
mlir/tools/mlir-tblgen/OpDocGen.cpp
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index 0875f4abb8332..4db44bfb9262e 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -39,6 +39,7 @@ def fir_BoxCharType : FIR_Type<"BoxChar", "boxchar"> {
let parameters = (ins "KindTy":$kind);
let genAccessors = 1;
+ let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
using KindTy = unsigned;
@@ -62,6 +63,7 @@ def fir_BoxProcType : FIR_Type<"BoxProc", "boxproc"> {
let parameters = (ins "mlir::Type":$eleTy);
let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
}
def fir_BoxType : FIR_Type<"Box", "box"> {
@@ -91,6 +93,7 @@ def fir_BoxType : FIR_Type<"Box", "box"> {
}];
let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
}
def fir_CharacterType : FIR_Type<"Character", "char"> {
@@ -104,6 +107,7 @@ def fir_CharacterType : FIR_Type<"Character", "char"> {
}];
let parameters = (ins "KindTy":$FKind, "CharacterType::LenType":$len);
+ let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
using KindTy = unsigned;
@@ -143,6 +147,7 @@ def fir_ComplexType : FIR_Type<"Complex", "complex"> {
}];
let parameters = (ins "KindTy":$fKind);
+ let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
using KindTy = unsigned;
@@ -174,6 +179,7 @@ def fir_HeapType : FIR_Type<"Heap", "heap"> {
let parameters = (ins "mlir::Type":$eleTy);
let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
@@ -194,6 +200,7 @@ def fir_IntegerType : FIR_Type<"Integer", "int"> {
}];
let parameters = (ins "KindTy":$fKind);
+ let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
using KindTy = unsigned;
@@ -219,6 +226,7 @@ def fir_LogicalType : FIR_Type<"Logical", "logical"> {
}];
let parameters = (ins "KindTy":$fKind);
+ let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
using KindTy = unsigned;
@@ -259,6 +267,7 @@ def fir_PointerType : FIR_Type<"Pointer", "ptr"> {
let parameters = (ins "mlir::Type":$eleTy);
let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
@@ -283,6 +292,7 @@ def fir_RealType : FIR_Type<"Real", "real"> {
}];
let parameters = (ins "KindTy":$fKind);
+ let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
using KindTy = unsigned;
@@ -304,6 +314,7 @@ def fir_RecordType : FIR_Type<"Record", "type"> {
let genVerifyDecl = 1;
let genStorageClass = 0;
+ let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
using TypePair = std::pair<std::string, mlir::Type>;
@@ -351,6 +362,7 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {
}];
let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
}
def fir_ShapeType : FIR_Type<"Shape", "shape"> {
@@ -363,6 +375,7 @@ def fir_ShapeType : FIR_Type<"Shape", "shape"> {
}];
let parameters = (ins "unsigned":$rank);
+ let hasCustomAssemblyFormat = 1;
}
def fir_ShapeShiftType : FIR_Type<"ShapeShift", "shapeshift"> {
@@ -376,6 +389,7 @@ def fir_ShapeShiftType : FIR_Type<"ShapeShift", "shapeshift"> {
}];
let parameters = (ins "unsigned":$rank);
+ let hasCustomAssemblyFormat = 1;
}
def fir_ShiftType : FIR_Type<"Shift", "shift"> {
@@ -388,6 +402,7 @@ def fir_ShiftType : FIR_Type<"Shift", "shift"> {
}];
let parameters = (ins "unsigned":$rank);
+ let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
using KindTy = unsigned;
@@ -417,6 +432,7 @@ def fir_SequenceType : FIR_Type<"Sequence", "array"> {
);
let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
let builders = [
TypeBuilderWithInferredContext<(ins
@@ -470,6 +486,7 @@ def fir_SliceType : FIR_Type<"Slice", "slice"> {
}];
let parameters = (ins "unsigned":$rank);
+ let hasCustomAssemblyFormat = 1;
}
def fir_TypeDescType : FIR_Type<"TypeDesc", "tdesc"> {
@@ -483,6 +500,7 @@ def fir_TypeDescType : FIR_Type<"TypeDesc", "tdesc"> {
let parameters = (ins "mlir::Type":$ofTy);
let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
@@ -505,6 +523,7 @@ def fir_VectorType : FIR_Type<"Vector", "vector"> {
let parameters = (ins "uint64_t":$len, "mlir::Type":$eleTy);
let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
static bool isValidElementType(mlir::Type t);
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td
index 9d3adf28c8a79..a9dced68aae52 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td
@@ -47,6 +47,7 @@ def Async_ValueType : Async_Type<"Value", "value"> {
return $_get(valueType.getContext(), valueType);
}]>
];
+ let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
index bbe6dc51a2369..d9f724ae5b06a 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
@@ -41,6 +41,8 @@ def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> {
}];
let parameters = (ins StringRefParameter<"the opaque value">:$value);
+
+ let hasCustomAssemblyFormat = 1;
}
#endif // MLIR_DIALECT_EMITC_IR_EMITCATTRIBUTES
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
index 1604f7dd37808..9a73f6ff3b1dc 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
@@ -42,6 +42,7 @@ def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
}];
let parameters = (ins StringRefParameter<"the opaque value">:$value);
+ let hasCustomAssemblyFormat = 1;
}
def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index f29ea92db3b69..f5b97002b5a15 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -23,6 +23,7 @@ def FastmathFlagsAttr : LLVM_Attr<"FMF"> {
let parameters = (ins
"FastmathFlags":$flags
);
+ let hasCustomAssemblyFormat = 1;
}
// Attribute definition for the LLVM Linkage enum.
@@ -31,6 +32,7 @@ def LinkageAttr : LLVM_Attr<"Linkage"> {
let parameters = (ins
"linkage::Linkage":$linkage
);
+ let hasCustomAssemblyFormat = 1;
}
def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
@@ -63,6 +65,7 @@ def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
AttrBuilder<(ins "ArrayRef<std::pair<LoopOptionCase, int64_t>>":$sortedOptions)>,
AttrBuilder<(ins "LoopOptionsAttrBuilder &":$optionBuilders)>
];
+ let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
index cfe3dd0b50edb..d690fca937a9f 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
@@ -67,6 +67,7 @@ def PDL_Range : PDL_Type<"Range", "range"> {
}]>,
];
let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 94ca03808e0ba..2db552f5fe039 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -80,6 +80,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
);
let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
let extraClassDeclaration = [{
// Dimension level types that define sparse tensors:
diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 8e263b646697b..ead971f01bbc1 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -188,20 +188,20 @@ class AttrOrTypeDef<string valueType, string name, list<Trait> defTraits,
// Use the lowercased name as the keyword for parsing/printing. Specify only
// if you want tblgen to generate declarations and/or definitions of
- // the printer/parser.
+ // the printer/parser. If specified and the Attribute or Type contains
+ // parameters, `assemblyFormat` or `hasCustomAssemblyFormat` must also be
+ // specified.
string mnemonic = ?;
- // If 'mnemonic' specified,
- // If null, generate just the declarations.
- // If a non-empty code block, just use that code as the definition code.
- // Error if an empty code block.
- code printer = ?;
- code parser = ?;
-
// Custom assembly format. Requires 'mnemonic' to be specified. Cannot be
- // specified at the same time as either 'printer' or 'parser'. The generated
+ // specified at the same time as 'hasCustomAssemblyFormat'. The generated
// printer requires 'genAccessors' to be true.
string assemblyFormat = ?;
+ /// This field indicates that the attribute or type has a custom assembly format
+ /// implemented in C++. When set to `1` a `parse` and `print` method are generated
+ /// on the generated class. The attribute or type should implement these methods to
+ /// support the custom format.
+ bit hasCustomAssemblyFormat = 0;
// If set, generate accessors for each parameter.
bit genAccessors = 1;
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 557f60b2686eb..c2aafca0831b0 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -175,30 +175,13 @@ class AttrOrTypeDef {
/// supposed to auto-generate them.
Optional<StringRef> getMnemonic() const;
- /// Returns the code to use as the types printer method. If not specified,
- /// return a non-value. Otherwise, return the contents of that code block.
- Optional<StringRef> getPrinterCode() const;
-
- /// Returns the code to use as the parser method. If not specified, returns
- /// None. Otherwise, returns the contents of that code block.
- Optional<StringRef> getParserCode() const;
+ /// Returns if the attribute or type has a custom assembly format implemented
+ /// in C++. Corresponds to the `hasCustomAssemblyFormat` field.
+ bool hasCustomAssemblyFormat() const;
/// Returns the custom assembly format, if one was specified.
Optional<StringRef> getAssemblyFormat() const;
- /// An attribute or type with parameters needs a parser.
- bool needsParserPrinter() const { return getNumParameters() != 0; }
-
- /// Returns true if this attribute or type has a generated parser.
- bool hasGeneratedParser() const {
- return getParserCode() || getAssemblyFormat();
- }
-
- /// Returns true if this attribute or type has a generated printer.
- bool hasGeneratedPrinter() const {
- return getPrinterCode() || getAssemblyFormat();
- }
-
/// Returns true if the accessors based on the parameters should be generated.
bool genAccessors() const;
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 3c0cfa029c700..444db742bd32e 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -62,6 +62,30 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
for (unsigned i = 0, e = parametersDag->getNumArgs(); i < e; ++i)
parameters.push_back(AttrOrTypeParameter(parametersDag, i));
}
+
+ // Verify the use of the mnemonic field.
+ bool hasCppFormat = hasCustomAssemblyFormat();
+ bool hasDeclarativeFormat = getAssemblyFormat().hasValue();
+ if (getMnemonic()) {
+ if (hasCppFormat && hasDeclarativeFormat) {
+ PrintFatalError(getLoc(), "cannot specify both 'assemblyFormat' "
+ "and 'hasCustomAssemblyFormat'");
+ }
+ if (!parameters.empty() && !hasCppFormat && !hasDeclarativeFormat) {
+ PrintFatalError(getLoc(),
+ "must specify either 'assemblyFormat' or "
+ "'hasCustomAssemblyFormat' when 'mnemonic' is set");
+ }
+ } else if (hasCppFormat || hasDeclarativeFormat) {
+ PrintFatalError(getLoc(),
+ "'assemblyFormat' or 'hasCustomAssemblyFormat' can only be "
+ "used when 'mnemonic' is set");
+ }
+ // Assembly format requires accessors to be generated.
+ if (hasDeclarativeFormat && !genAccessors()) {
+ PrintFatalError(getLoc(),
+ "'assemblyFormat' requires 'genAccessors' to be true");
+ }
}
Dialect AttrOrTypeDef::getDialect() const {
@@ -122,12 +146,8 @@ Optional<StringRef> AttrOrTypeDef::getMnemonic() const {
return def->getValueAsOptionalString("mnemonic");
}
-Optional<StringRef> AttrOrTypeDef::getPrinterCode() const {
- return def->getValueAsOptionalString("printer");
-}
-
-Optional<StringRef> AttrOrTypeDef::getParserCode() const {
- return def->getValueAsOptionalString("parser");
+bool AttrOrTypeDef::hasCustomAssemblyFormat() const {
+ return def->getValueAsBit("hasCustomAssemblyFormat");
}
Optional<StringRef> AttrOrTypeDef::getAssemblyFormat() const {
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 99ad93533c4ca..fc2ce63a22b67 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -43,6 +43,7 @@ def CompoundAttrA : Test_Attr<"CompoundA"> {
"An example of an array of ints" // Parameter description.
>: $arrayOfInts
);
+ let hasCustomAssemblyFormat = 1;
}
def CompoundAttrNested : Test_Attr<"CompoundAttrNested"> {
let mnemonic = "cmpnd_nested";
@@ -54,6 +55,7 @@ def CompoundAttrNested : Test_Attr<"CompoundAttrNested"> {
def AttrWithSelfTypeParam : Test_Attr<"AttrWithSelfTypeParam"> {
let mnemonic = "attr_with_self_type_param";
let parameters = (ins AttributeSelfTypeParameter<"">:$type);
+ let hasCustomAssemblyFormat = 1;
}
// An attribute testing AttributeSelfTypeParameter.
@@ -61,6 +63,7 @@ def AttrWithTypeBuilder : Test_Attr<"AttrWithTypeBuilder"> {
let mnemonic = "attr_with_type_builder";
let parameters = (ins "::mlir::IntegerAttr":$attr);
let typeBuilder = "$_attr.getType()";
+ let hasCustomAssemblyFormat = 1;
}
def TestAttrTrait : NativeAttrTrait<"TestAttrTrait">;
@@ -68,7 +71,6 @@ def TestAttrTrait : NativeAttrTrait<"TestAttrTrait">;
// The definition of a singleton attribute that has a trait.
def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
let mnemonic = "attr_with_trait";
- let parameters = (ins );
}
// Test support for ElementsAttrInterface.
@@ -106,6 +108,7 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
}
}];
let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
}
def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
@@ -120,6 +123,7 @@ def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
"::mlir::Attribute":$second,
"::mlir::Attribute":$third
);
+ let hasCustomAssemblyFormat = 1;
}
// A more complex parameterized attribute with multiple level of nesting.
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 75cccdaf8fd45..fd50dd3e0e90b 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -48,6 +48,7 @@ def CompoundTypeA : Test_Type<"CompoundA"> {
let extraClassDeclaration = [{
struct SomeCppStruct {};
}];
+ let hasCustomAssemblyFormat = 1;
}
// A more complex and nested parameterized type.
@@ -92,12 +93,8 @@ def IntegerType : Test_Type<"TestInteger"> {
"::test::TestIntegerType::SignednessSemantics":$signedness
);
- // We define the printer inline.
- let printer = [{
- $_printer << "<";
- printSignedness($_printer, getImpl()->signedness);
- $_printer << ", " << getImpl()->width << ">";
- }];
+ // Indicate we use a custom format.
+ let hasCustomAssemblyFormat = 1;
// Define custom builder methods.
let builders = [
@@ -108,19 +105,6 @@ def IntegerType : Test_Type<"TestInteger"> {
];
let skipDefaultBuilders = 1;
- // The parser is defined here also.
- let parser = [{
- if ($_parser.parseLess()) return Type();
- SignednessSemantics signedness;
- if (parseSignedness($_parser, signedness)) return Type();
- if ($_parser.parseComma()) return Type();
- int width;
- if ($_parser.parseInteger(width)) return Type();
- if ($_parser.parseGreater()) return Type();
- Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
- return getChecked(loc, loc.getContext(), width, signedness);
- }];
-
// Any extra code one wants in the type's class declaration.
let extraClassDeclaration = [{
/// Signedness semantics.
@@ -150,37 +134,7 @@ class FieldInfo_Type<string name> : Test_Type<name> {
"::test::FieldInfo", // FieldInfo is defined/declared in TestTypes.h.
"Models struct fields">: $fields
);
-
- // Prints the type in this format:
- // struct<[{field1Name, field1Type}, {field2Name, field2Type}]
- let printer = [{
- $_printer << "<";
- for (size_t i=0, e = getImpl()->fields.size(); i < e; i++) {
- const auto& field = getImpl()->fields[i];
- $_printer << "{" << field.name << "," << field.type << "}";
- if (i < getImpl()->fields.size() - 1)
- $_printer << ",";
- }
- $_printer << ">";
- }];
-
- // Parses the above format
- let parser = [{
- llvm::SmallVector<FieldInfo, 4> parameters;
- if ($_parser.parseLess()) return Type();
- while (mlir::succeeded($_parser.parseOptionalLBrace())) {
- llvm::StringRef name;
- if ($_parser.parseKeyword(&name)) return Type();
- if ($_parser.parseComma()) return Type();
- Type type;
- if ($_parser.parseType(type)) return Type();
- if ($_parser.parseRBrace()) return Type();
- parameters.push_back(FieldInfo {name, type});
- if ($_parser.parseOptionalComma()) break;
- }
- if ($_parser.parseGreater()) return Type();
- return get($_ctxt, parameters);
- }];
+ let hasCustomAssemblyFormat = 1;
}
def StructType : FieldInfo_Type<"Struct"> {
@@ -208,6 +162,7 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
public:
}];
+ let hasCustomAssemblyFormat = 1;
}
def TestMemRefElementType : Test_Type<"TestMemRefElementType",
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 5bcf62a3fa272..833e90de2162a 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -83,6 +83,65 @@ static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT
return llvm::hash_combine(fi.name, fi.type);
}
+//===----------------------------------------------------------------------===//
+// TestCustomType
+//===----------------------------------------------------------------------===//
+
+static LogicalResult parseCustomTypeA(AsmParser &parser,
+ FailureOr<int> &a_result) {
+ a_result.emplace();
+ return parser.parseInteger(*a_result);
+}
+
+static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
+
+static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
+ FailureOr<Optional<int>> &b_result) {
+ if (a < 0)
+ return success();
+ for (int i : llvm::seq(0, a))
+ if (failed(parser.parseInteger(i)))
+ return failure();
+ b_result.emplace(0);
+ return parser.parseInteger(**b_result);
+}
+
+static void printCustomTypeB(AsmPrinter &printer, int a, Optional<int> b) {
+ if (a < 0)
+ return;
+ printer << ' ';
+ for (int i : llvm::seq(0, a))
+ printer << i << ' ';
+ printer << *b;
+}
+
+static LogicalResult parseFooString(AsmParser &parser,
+ FailureOr<std::string> &foo) {
+ std::string result;
+ if (parser.parseString(&result))
+ return failure();
+ foo = std::move(result);
+ return success();
+}
+
+static void printFooString(AsmPrinter &printer, StringRef foo) {
+ printer << '"' << foo << '"';
+}
+
+static LogicalResult parseBarString(AsmParser &parser, StringRef foo) {
+ return parser.parseKeyword(foo);
+}
+
+static void printBarString(AsmPrinter &printer, StringRef foo) {
+ printer << ' ' << foo;
+}
+//===----------------------------------------------------------------------===//
+// Tablegen Generated Definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "TestTypeDefs.cpp.inc"
+
//===----------------------------------------------------------------------===//
// CompoundAType
//===----------------------------------------------------------------------===//
@@ -129,6 +188,54 @@ TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+Type TestIntegerType::parse(AsmParser &parser) {
+ SignednessSemantics signedness;
+ int width;
+ if (parser.parseLess() || parseSignedness(parser, signedness) ||
+ parser.parseComma() || parser.parseInteger(width) ||
+ parser.parseGreater())
+ return Type();
+ Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
+ return getChecked(loc, loc.getContext(), width, signedness);
+}
+
+void TestIntegerType::print(AsmPrinter &p) const {
+ p << "<";
+ printSignedness(p, getSignedness());
+ p << ", " << getWidth() << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// TestStructType
+//===----------------------------------------------------------------------===//
+
+Type StructType::parse(AsmParser &p) {
+ SmallVector<FieldInfo, 4> parameters;
+ if (p.parseLess())
+ return Type();
+ while (succeeded(p.parseOptionalLBrace())) {
+ Type type;
+ StringRef name;
+ if (p.parseKeyword(&name) || p.parseComma() || p.parseType(type) ||
+ p.parseRBrace())
+ return Type();
+ parameters.push_back(FieldInfo{name, type});
+ if (p.parseOptionalComma())
+ break;
+ }
+ if (p.parseGreater())
+ return Type();
+ return get(p.getContext(), parameters);
+}
+
+void StructType::print(AsmPrinter &p) const {
+ p << "<";
+ llvm::interleaveComma(getFields(), p, [&](const FieldInfo &field) {
+ p << "{" << field.name << "," << field.type << "}";
+ });
+ p << ">";
+}
+
//===----------------------------------------------------------------------===//
// TestType
//===----------------------------------------------------------------------===//
@@ -208,66 +315,6 @@ unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
return 1;
}
-//===----------------------------------------------------------------------===//
-// TestCustomType
-//===----------------------------------------------------------------------===//
-
-static LogicalResult parseCustomTypeA(AsmParser &parser,
- FailureOr<int> &a_result) {
- a_result.emplace();
- return parser.parseInteger(*a_result);
-}
-
-static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
-
-static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
- FailureOr<Optional<int>> &b_result) {
- if (a < 0)
- return success();
- for (int i : llvm::seq(0, a))
- if (failed(parser.parseInteger(i)))
- return failure();
- b_result.emplace(0);
- return parser.parseInteger(**b_result);
-}
-
-static void printCustomTypeB(AsmPrinter &printer, int a, Optional<int> b) {
- if (a < 0)
- return;
- printer << ' ';
- for (int i : llvm::seq(0, a))
- printer << i << ' ';
- printer << *b;
-}
-
-static LogicalResult parseFooString(AsmParser &parser,
- FailureOr<std::string> &foo) {
- std::string result;
- if (parser.parseString(&result))
- return failure();
- foo = std::move(result);
- return success();
-}
-
-static void printFooString(AsmPrinter &printer, StringRef foo) {
- printer << '"' << foo << '"';
-}
-
-static LogicalResult parseBarString(AsmParser &parser, StringRef foo) {
- return parser.parseKeyword(foo);
-}
-
-static void printBarString(AsmPrinter &printer, StringRef foo) {
- printer << ' ' << foo;
-}
-
-//===----------------------------------------------------------------------===//
-// Tablegen Generated Definitions
-//===----------------------------------------------------------------------===//
-
-#define GET_TYPEDEF_CLASSES
-#include "TestTypeDefs.cpp.inc"
-
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 70a021ebb385b..3737ce77948e2 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -60,6 +60,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
);
let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
// DECL-LABEL: class CompoundAAttr : public ::mlir::Attribute
// DECL: static CompoundAAttr getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
@@ -102,6 +103,7 @@ def C_IndexAttr : TestAttr<"Index"> {
ins
StringRefParameter<"Label for index">:$label
);
+ let hasCustomAssemblyFormat = 1;
// DECL-LABEL: class IndexAttr : public ::mlir::Attribute
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
@@ -127,6 +129,7 @@ def E_AttrWithTypeBuilder : TestAttr<"AttrWithTypeBuilder"> {
let mnemonic = "attr_with_type_builder";
let parameters = (ins "::mlir::IntegerAttr":$attr);
let typeBuilder = "$_attr.getType()";
+ let hasCustomAssemblyFormat = 1;
}
// DEF-LABEL: struct AttrWithTypeBuilderAttrStorage
diff --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
index dc8aea52559b4..1b568f197f1da 100644
--- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
@@ -38,7 +38,7 @@ func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<
return
}
-// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int<none, 3>}>)
+// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla}, {field2,!test.int<none, 3>}>)
func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int<none, 3>} > ) {
return
}
diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index af8b70a6dfec4..833195530f9b0 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -35,8 +35,8 @@ include "mlir/IR/OpBase.td"
def Test_Dialect: Dialect {
// DECL-NOT: TestDialect
- let name = "TestDialect";
- let cppNamespace = "::test";
+ let name = "TestDialect";
+ let cppNamespace = "::test";
}
class TestType<string name> : TypeDef<Test_Dialect, name> { }
@@ -54,16 +54,16 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
let summary = "A more complex parameterized type";
let description = "This type is to test a reasonably complex type";
let mnemonic = "cmpnd_a";
- let parameters = (
- ins
- "int":$widthOfSomething,
- "::test::SimpleTypeA": $exampleTdType,
- "SomeCppStruct": $exampleCppType,
- ArrayRefParameter<"int", "Matrix dimensions">:$dims,
- RTLValueType:$inner
+ let parameters = (ins
+ "int":$widthOfSomething,
+ "::test::SimpleTypeA": $exampleTdType,
+ "SomeCppStruct": $exampleCppType,
+ ArrayRefParameter<"int", "Matrix dimensions">:$dims,
+ RTLValueType:$inner
);
let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
// DECL-LABEL: class CompoundAType : public ::mlir::Type
// DECL: static CompoundAType getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
@@ -79,12 +79,12 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
}
def C_IndexType : TestType<"Index"> {
- let mnemonic = "index";
+ let mnemonic = "index";
- let parameters = (
- ins
- StringRefParameter<"Label for index">:$label
- );
+ let parameters = (ins
+ StringRefParameter<"Label for index">:$label
+ );
+ let hasCustomAssemblyFormat = 1;
// DECL-LABEL: class IndexType : public ::mlir::Type
// DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
@@ -95,8 +95,7 @@ def C_IndexType : TestType<"Index"> {
}
def D_SingleParameterType : TestType<"SingleParameter"> {
- let parameters = (
- ins
+ let parameters = (ins
"int": $num
);
// DECL-LABEL: struct SingleParameterTypeStorage;
@@ -105,17 +104,17 @@ def D_SingleParameterType : TestType<"SingleParameter"> {
}
def E_IntegerType : TestType<"Integer"> {
- let mnemonic = "int";
- let genVerifyDecl = 1;
- let parameters = (
- ins
- "SignednessSemantics":$signedness,
- TypeParameter<"unsigned", "Bitwidth of integer">:$width
- );
+ let mnemonic = "int";
+ let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
+ let parameters = (ins
+ "SignednessSemantics":$signedness,
+ TypeParameter<"unsigned", "Bitwidth of integer">:$width
+ );
// DECL-LABEL: IntegerType : public ::mlir::Type
- let extraClassDeclaration = [{
+ let extraClassDeclaration = [{
/// Signedness semantics.
enum SignednessSemantics {
Signless, /// No signedness semantics
@@ -132,7 +131,7 @@ def E_IntegerType : TestType<"Integer"> {
bool isSigned() const { return getSignedness() == Signed; }
/// Return true if this is an unsigned integer type.
bool isUnsigned() const { return getSignedness() == Unsigned; }
- }];
+ }];
// DECL: /// Signedness semantics.
// DECL-NEXT: enum SignednessSemantics {
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 5153c6e507851..c9a512ae18e58 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -121,10 +121,6 @@ class DefGen {
/// Emit a checked custom builder.
void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);
- //===--------------------------------------------------------------------===//
- // Parser and Printer Emission
- void emitParserPrinterBody(MethodBody &parser, MethodBody &printer);
-
//===--------------------------------------------------------------------===//
// Interface Method Emission
@@ -264,9 +260,10 @@ void DefGen::emitParserPrinter() {
auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
"::llvm::StringLiteral", "getMnemonic");
mnemonic->body().indent() << strfmt("return {\"{0}\"};", *def.getMnemonic());
+
// Declare the parser and printer, if needed.
- if (!def.needsParserPrinter() && !def.hasGeneratedParser() &&
- !def.hasGeneratedPrinter())
+ bool hasAssemblyFormat = def.getAssemblyFormat().hasValue();
+ if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat)
return;
// Declare the parser.
@@ -274,18 +271,18 @@ void DefGen::emitParserPrinter() {
parserParams.emplace_back("::mlir::AsmParser &", "odsParser");
if (isa<AttrDef>(&def))
parserParams.emplace_back("::mlir::Type", "odsType");
- auto *parser = defCls.addMethod(
- strfmt("::mlir::{0}", valueType), "parse",
- def.hasGeneratedParser() ? Method::Static : Method::StaticDeclaration,
- std::move(parserParams));
+ auto *parser = defCls.addMethod(strfmt("::mlir::{0}", valueType), "parse",
+ hasAssemblyFormat ? Method::Static
+ : Method::StaticDeclaration,
+ std::move(parserParams));
// Declare the printer.
- auto props =
- def.hasGeneratedPrinter() ? Method::Const : Method::ConstDeclaration;
+ auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration;
Method *printer =
defCls.addMethod("void", "print", props,
MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
- // Emit the bodies.
- emitParserPrinterBody(parser->body(), printer->body());
+ // Emit the bodies if we are using the declarative format.
+ if (hasAssemblyFormat)
+ return generateAttrOrTypeFormat(def, parser->body(), printer->body());
}
void DefGen::emitAccessors() {
@@ -406,50 +403,6 @@ void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
m->body().indent().getStream().printReindented(bodyStr);
}
-//===----------------------------------------------------------------------===//
-// Parser and Printer Emission
-
-void DefGen::emitParserPrinterBody(MethodBody &parser, MethodBody &printer) {
- Optional<StringRef> parserCode = def.getParserCode();
- Optional<StringRef> printerCode = def.getPrinterCode();
- Optional<StringRef> asmFormat = def.getAssemblyFormat();
- // Verify the parser-printer specification first.
- if (asmFormat && (parserCode || printerCode)) {
- PrintFatalError(def.getLoc(),
- def.getName() + ": assembly format cannot be specified at "
- "the same time as printer or parser code");
- }
- // Specified code cannot be empty.
- if (parserCode && parserCode->empty())
- PrintFatalError(def.getLoc(), def.getName() + ": parser cannot be empty");
- if (printerCode && printerCode->empty())
- PrintFatalError(def.getLoc(), def.getName() + ": printer cannot be empty");
- // Assembly format requires accessors to be generated.
- if (asmFormat && !def.genAccessors()) {
- PrintFatalError(def.getLoc(),
- def.getName() +
- ": the generated printer from 'assemblyFormat' "
- "requires 'genAccessors' to be true");
- }
-
- // Generate the parser and printer bodies.
- if (asmFormat)
- return generateAttrOrTypeFormat(def, parser, printer);
-
- FmtContext ctx = FmtContext({{"_parser", "odsParser"},
- {"_printer", "odsPrinter"},
- {"_type", "odsType"}});
- if (parserCode) {
- ctx.addSubst("_ctxt", "odsParser.getContext()");
- parser.indent().getStream().printReindented(tgfmt(*parserCode, &ctx).str());
- }
- if (printerCode) {
- ctx.addSubst("_ctxt", "odsPrinter.getContext()");
- printer.indent().getStream().printReindented(
- tgfmt(*printerCode, &ctx).str());
- }
-}
-
//===----------------------------------------------------------------------===//
// Interface Method Emission
@@ -829,18 +782,21 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
for (auto &def : defs) {
if (!def.getMnemonic())
continue;
+ bool hasParserPrinterDecl =
+ def.hasCustomAssemblyFormat() || def.getAssemblyFormat();
std::string defClass = strfmt(
"{0}::{1}", def.getDialect().getCppNamespace(), def.getCppClassName());
+
// If the def has no parameters or parser code, invoke a normal `get`.
std::string parseOrGet =
- def.needsParserPrinter() || def.hasGeneratedParser()
+ hasParserPrinterDecl
? strfmt("parse(parser{0})", isAttrGenerator ? ", type" : "")
: "get(parser.getContext())";
parse.body() << llvm::formatv(getValueForMnemonic, defClass, parseOrGet);
// If the def has no parameters and no printer, just print the mnemonic.
StringRef printDef = "";
- if (def.needsParserPrinter() || def.hasGeneratedPrinter())
+ if (hasParserPrinterDecl)
printDef = "\nt.print(printer);";
printer.body() << llvm::formatv(printValue, defClass, printDef);
}
diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index b3be48176d454..9b9c64884a360 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -253,8 +253,7 @@ static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) {
os << "\n" << def.getSummary() << "\n";
// Emit the syntax if present.
- if (def.getMnemonic() && def.getPrinterCode() == StringRef() &&
- def.getParserCode() == StringRef())
+ if (def.getMnemonic() && !def.hasCustomAssemblyFormat())
emitAttrOrTypeDefAssemblyFormat(def, os);
// Emit the description if present.
More information about the Mlir-commits mailing list