[Mlir-commits] [mlir] 23e3cbe - [mlir] Refactor how parser/printers are specified for AttrDef/TypeDef

River Riddle llvmlistbot at llvm.org
Tue Mar 15 00:51:08 PDT 2022


Author: River Riddle
Date: 2022-03-15T00:42:31-07:00
New Revision: 23e3cbe24a51604a03a379664f67ed79e5fef897

URL: https://github.com/llvm/llvm-project/commit/23e3cbe24a51604a03a379664f67ed79e5fef897
DIFF: https://github.com/llvm/llvm-project/commit/23e3cbe24a51604a03a379664f67ed79e5fef897.diff

LOG: [mlir] Refactor how parser/printers are specified for AttrDef/TypeDef

There is currently an awkwardly complex set of rules for how a
parser/printer is generated for AttrDef/TypeDef. It can change depending on
whether a mnemonic was specified, whether there are parameters, whether the
assemblyFormat is used, whether individual parser/printer code blocks were
specified, etc. This commit refactors this to make what the attribute/type
wants more explicit, and to better align with how formats are specified for
operations.

Firstly, the parser/printer code blocks are removed in favor of a
`hasCustomAssemblyFormat` bit field. This aligns with the operation format
specification (and it is nice to remove code blocks from ODS).

This commit also adds a requirement to explicitly set `assemblyFormat` or
`hasCustomAssemblyFormat` when the mnemonic is set and the attr/type
has parameters (an attr/type with a mnemonic but no parameters needs
neither). This removes the weird implicit matrix of behavior,
and also encourages the author to make a conscious choice of either C++
or declarative format instead of implicitly opting them into the C++
format (we should be pushing towards declarative when possible).

Differential Revision: https://reviews.llvm.org/D121505

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/FIRTypes.td
    mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td
    mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
    mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
    mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
    mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/include/mlir/IR/AttrTypeBase.td
    mlir/include/mlir/TableGen/AttrOrTypeDef.h
    mlir/lib/TableGen/AttrOrTypeDef.cpp
    mlir/test/lib/Dialect/Test/TestAttrDefs.td
    mlir/test/lib/Dialect/Test/TestTypeDefs.td
    mlir/test/lib/Dialect/Test/TestTypes.cpp
    mlir/test/mlir-tblgen/attrdefs.td
    mlir/test/mlir-tblgen/testdialect-typedefs.mlir
    mlir/test/mlir-tblgen/typedefs.td
    mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
    mlir/tools/mlir-tblgen/OpDocGen.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/FIRTypes.td b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
index 0875f4abb8332..4db44bfb9262e 100644
--- a/flang/include/flang/Optimizer/Dialect/FIRTypes.td
+++ b/flang/include/flang/Optimizer/Dialect/FIRTypes.td
@@ -39,6 +39,7 @@ def fir_BoxCharType : FIR_Type<"BoxChar", "boxchar"> {
   let parameters = (ins "KindTy":$kind);
 
   let genAccessors = 1;
+  let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
     using KindTy = unsigned;
@@ -62,6 +63,7 @@ def fir_BoxProcType : FIR_Type<"BoxProc", "boxproc"> {
   let parameters = (ins "mlir::Type":$eleTy);
 
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
 }
 
 def fir_BoxType : FIR_Type<"Box", "box"> {
@@ -91,6 +93,7 @@ def fir_BoxType : FIR_Type<"Box", "box"> {
   }];
 
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
 }
 
 def fir_CharacterType : FIR_Type<"Character", "char"> {
@@ -104,6 +107,7 @@ def fir_CharacterType : FIR_Type<"Character", "char"> {
   }];
 
   let parameters = (ins "KindTy":$FKind, "CharacterType::LenType":$len);
+  let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
     using KindTy = unsigned;
@@ -143,6 +147,7 @@ def fir_ComplexType : FIR_Type<"Complex", "complex"> {
   }];
 
   let parameters = (ins "KindTy":$fKind);
+  let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
     using KindTy = unsigned;
@@ -174,6 +179,7 @@ def fir_HeapType : FIR_Type<"Heap", "heap"> {
   let parameters = (ins "mlir::Type":$eleTy);
 
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
 
   let skipDefaultBuilders = 1;
 
@@ -194,6 +200,7 @@ def fir_IntegerType : FIR_Type<"Integer", "int"> {
   }];
 
   let parameters = (ins "KindTy":$fKind);
+  let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
     using KindTy = unsigned;
@@ -219,6 +226,7 @@ def fir_LogicalType : FIR_Type<"Logical", "logical"> {
   }];
 
   let parameters = (ins "KindTy":$fKind);
+  let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
     using KindTy = unsigned;
@@ -259,6 +267,7 @@ def fir_PointerType : FIR_Type<"Pointer", "ptr"> {
   let parameters = (ins "mlir::Type":$eleTy);
 
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
 
   let skipDefaultBuilders = 1;
 
@@ -283,6 +292,7 @@ def fir_RealType : FIR_Type<"Real", "real"> {
   }];
 
   let parameters = (ins "KindTy":$fKind);
+  let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
     using KindTy = unsigned;
@@ -304,6 +314,7 @@ def fir_RecordType : FIR_Type<"Record", "type"> {
 
   let genVerifyDecl = 1;
   let genStorageClass = 0;
+  let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
     using TypePair = std::pair<std::string, mlir::Type>;
@@ -351,6 +362,7 @@ def fir_ReferenceType : FIR_Type<"Reference", "ref"> {
   }];
 
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
 }
 
 def fir_ShapeType : FIR_Type<"Shape", "shape"> {
@@ -363,6 +375,7 @@ def fir_ShapeType : FIR_Type<"Shape", "shape"> {
   }];
 
   let parameters = (ins "unsigned":$rank);
+  let hasCustomAssemblyFormat = 1;
 }
 
 def fir_ShapeShiftType : FIR_Type<"ShapeShift", "shapeshift"> {
@@ -376,6 +389,7 @@ def fir_ShapeShiftType : FIR_Type<"ShapeShift", "shapeshift"> {
   }];
 
   let parameters = (ins "unsigned":$rank);
+  let hasCustomAssemblyFormat = 1;
 }
 
 def fir_ShiftType : FIR_Type<"Shift", "shift"> {
@@ -388,6 +402,7 @@ def fir_ShiftType : FIR_Type<"Shift", "shift"> {
   }];
 
   let parameters = (ins "unsigned":$rank);
+  let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
     using KindTy = unsigned;
@@ -417,6 +432,7 @@ def fir_SequenceType : FIR_Type<"Sequence", "array"> {
   );
 
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
 
   let builders = [
     TypeBuilderWithInferredContext<(ins
@@ -470,6 +486,7 @@ def fir_SliceType : FIR_Type<"Slice", "slice"> {
   }];
 
   let parameters = (ins "unsigned":$rank);
+  let hasCustomAssemblyFormat = 1;
 }
 
 def fir_TypeDescType : FIR_Type<"TypeDesc", "tdesc"> {
@@ -483,6 +500,7 @@ def fir_TypeDescType : FIR_Type<"TypeDesc", "tdesc"> {
   let parameters = (ins "mlir::Type":$ofTy);
 
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
 
   let skipDefaultBuilders = 1;
 
@@ -505,6 +523,7 @@ def fir_VectorType : FIR_Type<"Vector", "vector"> {
   let parameters = (ins "uint64_t":$len, "mlir::Type":$eleTy);
 
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
     static bool isValidElementType(mlir::Type t);

diff  --git a/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td
index 9d3adf28c8a79..a9dced68aae52 100644
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td
@@ -47,6 +47,7 @@ def Async_ValueType : Async_Type<"Value", "value"> {
       return $_get(valueType.getContext(), valueType);
     }]>
   ];
+  let hasCustomAssemblyFormat = 1;
   let skipDefaultBuilders = 1;
 }
 

diff  --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
index bbe6dc51a2369..d9f724ae5b06a 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCAttributes.td
@@ -41,6 +41,8 @@ def EmitC_OpaqueAttr : EmitC_Attr<"Opaque", "opaque"> {
   }];
 
   let parameters = (ins StringRefParameter<"the opaque value">:$value);
+  
+  let hasCustomAssemblyFormat = 1;
 }
 
 #endif // MLIR_DIALECT_EMITC_IR_EMITCATTRIBUTES

diff  --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
index 1604f7dd37808..9a73f6ff3b1dc 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
@@ -42,6 +42,7 @@ def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
   }];
 
   let parameters = (ins StringRefParameter<"the opaque value">:$value);
+  let hasCustomAssemblyFormat = 1;
 }
 
 def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {

diff  --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
index f29ea92db3b69..f5b97002b5a15 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
@@ -23,6 +23,7 @@ def FastmathFlagsAttr : LLVM_Attr<"FMF"> {
   let parameters = (ins
     "FastmathFlags":$flags
   );
+  let hasCustomAssemblyFormat = 1;
 }
 
 // Attribute definition for the LLVM Linkage enum.
@@ -31,6 +32,7 @@ def LinkageAttr : LLVM_Attr<"Linkage"> {
   let parameters = (ins
     "linkage::Linkage":$linkage
   );
+  let hasCustomAssemblyFormat = 1;
 }
 
 def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
@@ -63,6 +65,7 @@ def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
     AttrBuilder<(ins "ArrayRef<std::pair<LoopOptionCase, int64_t>>":$sortedOptions)>,
     AttrBuilder<(ins "LoopOptionsAttrBuilder &":$optionBuilders)>
   ];
+  let hasCustomAssemblyFormat = 1;
   let skipDefaultBuilders = 1;
 }
 

diff  --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
index cfe3dd0b50edb..d690fca937a9f 100644
--- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
+++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td
@@ -67,6 +67,7 @@ def PDL_Range : PDL_Type<"Range", "range"> {
     }]>,
   ];
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
   let skipDefaultBuilders = 1;
 }
 

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 94ca03808e0ba..2db552f5fe039 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -80,6 +80,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
   );
 
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
 
   let extraClassDeclaration = [{
     // Dimension level types that define sparse tensors:

diff  --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 8e263b646697b..ead971f01bbc1 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -188,20 +188,20 @@ class AttrOrTypeDef<string valueType, string name, list<Trait> defTraits,
 
   // Use the lowercased name as the keyword for parsing/printing. Specify only
   // if you want tblgen to generate declarations and/or definitions of
-  // the printer/parser.
+  // the printer/parser. If specified and the Attribute or Type contains
+  // parameters, `assemblyFormat` or `hasCustomAssemblyFormat` must also be
+  // specified.
   string mnemonic = ?;
 
-  // If 'mnemonic' specified,
-  //   If null, generate just the declarations.
-  //   If a non-empty code block, just use that code as the definition code.
-  //   Error if an empty code block.
-  code printer = ?;
-  code parser = ?;
-
   // Custom assembly format. Requires 'mnemonic' to be specified. Cannot be
-  // specified at the same time as either 'printer' or 'parser'. The generated
+  // specified at the same time as 'hasCustomAssemblyFormat'. The generated
   // printer requires 'genAccessors' to be true.
   string assemblyFormat = ?;
+  /// This field indicates that the attribute or type has a custom assembly format
+  /// implemented in C++. When set to `1` a `parse` and `print` method are generated
+  /// on the generated class. The attribute or type should implement these methods to
+  /// support the custom format.
+  bit hasCustomAssemblyFormat = 0;
 
   // If set, generate accessors for each parameter.
   bit genAccessors = 1;

diff  --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 557f60b2686eb..c2aafca0831b0 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -175,30 +175,13 @@ class AttrOrTypeDef {
   /// supposed to auto-generate them.
   Optional<StringRef> getMnemonic() const;
 
-  /// Returns the code to use as the types printer method. If not specified,
-  /// return a non-value. Otherwise, return the contents of that code block.
-  Optional<StringRef> getPrinterCode() const;
-
-  /// Returns the code to use as the parser method. If not specified, returns
-  /// None. Otherwise, returns the contents of that code block.
-  Optional<StringRef> getParserCode() const;
+  /// Returns if the attribute or type has a custom assembly format implemented
+  /// in C++. Corresponds to the `hasCustomAssemblyFormat` field.
+  bool hasCustomAssemblyFormat() const;
 
   /// Returns the custom assembly format, if one was specified.
   Optional<StringRef> getAssemblyFormat() const;
 
-  /// An attribute or type with parameters needs a parser.
-  bool needsParserPrinter() const { return getNumParameters() != 0; }
-
-  /// Returns true if this attribute or type has a generated parser.
-  bool hasGeneratedParser() const {
-    return getParserCode() || getAssemblyFormat();
-  }
-
-  /// Returns true if this attribute or type has a generated printer.
-  bool hasGeneratedPrinter() const {
-    return getPrinterCode() || getAssemblyFormat();
-  }
-
   /// Returns true if the accessors based on the parameters should be generated.
   bool genAccessors() const;
 

diff  --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
index 3c0cfa029c700..444db742bd32e 100644
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -62,6 +62,30 @@ AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) {
     for (unsigned i = 0, e = parametersDag->getNumArgs(); i < e; ++i)
       parameters.push_back(AttrOrTypeParameter(parametersDag, i));
   }
+
+  // Verify the use of the mnemonic field.
+  bool hasCppFormat = hasCustomAssemblyFormat();
+  bool hasDeclarativeFormat = getAssemblyFormat().hasValue();
+  if (getMnemonic()) {
+    if (hasCppFormat && hasDeclarativeFormat) {
+      PrintFatalError(getLoc(), "cannot specify both 'assemblyFormat' "
+                                "and 'hasCustomAssemblyFormat'");
+    }
+    if (!parameters.empty() && !hasCppFormat && !hasDeclarativeFormat) {
+      PrintFatalError(getLoc(),
+                      "must specify either 'assemblyFormat' or "
+                      "'hasCustomAssemblyFormat' when 'mnemonic' is set");
+    }
+  } else if (hasCppFormat || hasDeclarativeFormat) {
+    PrintFatalError(getLoc(),
+                    "'assemblyFormat' or 'hasCustomAssemblyFormat' can only be "
+                    "used when 'mnemonic' is set");
+  }
+  // Assembly format requires accessors to be generated.
+  if (hasDeclarativeFormat && !genAccessors()) {
+    PrintFatalError(getLoc(),
+                    "'assemblyFormat' requires 'genAccessors' to be true");
+  }
 }
 
 Dialect AttrOrTypeDef::getDialect() const {
@@ -122,12 +146,8 @@ Optional<StringRef> AttrOrTypeDef::getMnemonic() const {
   return def->getValueAsOptionalString("mnemonic");
 }
 
-Optional<StringRef> AttrOrTypeDef::getPrinterCode() const {
-  return def->getValueAsOptionalString("printer");
-}
-
-Optional<StringRef> AttrOrTypeDef::getParserCode() const {
-  return def->getValueAsOptionalString("parser");
+bool AttrOrTypeDef::hasCustomAssemblyFormat() const {
+  return def->getValueAsBit("hasCustomAssemblyFormat");
 }
 
 Optional<StringRef> AttrOrTypeDef::getAssemblyFormat() const {

diff  --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 99ad93533c4ca..fc2ce63a22b67 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -43,6 +43,7 @@ def CompoundAttrA : Test_Attr<"CompoundA"> {
       "An example of an array of ints" // Parameter description.
       >: $arrayOfInts
   );
+  let hasCustomAssemblyFormat = 1;
 }
 def CompoundAttrNested : Test_Attr<"CompoundAttrNested"> {
   let mnemonic = "cmpnd_nested";
@@ -54,6 +55,7 @@ def CompoundAttrNested : Test_Attr<"CompoundAttrNested"> {
 def AttrWithSelfTypeParam : Test_Attr<"AttrWithSelfTypeParam"> {
   let mnemonic = "attr_with_self_type_param";
   let parameters = (ins AttributeSelfTypeParameter<"">:$type);
+  let hasCustomAssemblyFormat = 1;
 }
 
 // An attribute testing AttributeSelfTypeParameter.
@@ -61,6 +63,7 @@ def AttrWithTypeBuilder : Test_Attr<"AttrWithTypeBuilder"> {
   let mnemonic = "attr_with_type_builder";
   let parameters = (ins "::mlir::IntegerAttr":$attr);
   let typeBuilder = "$_attr.getType()";
+  let hasCustomAssemblyFormat = 1;
 }
 
 def TestAttrTrait : NativeAttrTrait<"TestAttrTrait">;
@@ -68,7 +71,6 @@ def TestAttrTrait : NativeAttrTrait<"TestAttrTrait">;
 // The definition of a singleton attribute that has a trait.
 def AttrWithTrait : Test_Attr<"AttrWithTrait", [TestAttrTrait]> {
   let mnemonic = "attr_with_trait";
-  let parameters = (ins );
 }
 
 // Test support for ElementsAttrInterface.
@@ -106,6 +108,7 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
     }
   }];
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
 }
 
 def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
@@ -120,6 +123,7 @@ def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
     "::mlir::Attribute":$second,
     "::mlir::Attribute":$third
   );
+  let hasCustomAssemblyFormat = 1;
 }
 
 // A more complex parameterized attribute with multiple level of nesting.

diff  --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 75cccdaf8fd45..fd50dd3e0e90b 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -48,6 +48,7 @@ def CompoundTypeA : Test_Type<"CompoundA"> {
   let extraClassDeclaration = [{
     struct SomeCppStruct {};
   }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 // A more complex and nested parameterized type.
@@ -92,12 +93,8 @@ def IntegerType : Test_Type<"TestInteger"> {
     "::test::TestIntegerType::SignednessSemantics":$signedness
   );
 
-  // We define the printer inline.
-  let printer = [{
-    $_printer << "<";
-    printSignedness($_printer, getImpl()->signedness);
-    $_printer << ", " << getImpl()->width << ">";
-  }];
+  // Indicate we use a custom format.
+  let hasCustomAssemblyFormat = 1;
 
   // Define custom builder methods.
   let builders = [
@@ -108,19 +105,6 @@ def IntegerType : Test_Type<"TestInteger"> {
   ];
   let skipDefaultBuilders = 1;
 
-  // The parser is defined here also.
-  let parser = [{
-    if ($_parser.parseLess()) return Type();
-    SignednessSemantics signedness;
-    if (parseSignedness($_parser, signedness)) return Type();
-    if ($_parser.parseComma()) return Type();
-    int width;
-    if ($_parser.parseInteger(width)) return Type();
-    if ($_parser.parseGreater()) return Type();
-    Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
-    return getChecked(loc, loc.getContext(), width, signedness);
-  }];
-
   // Any extra code one wants in the type's class declaration.
   let extraClassDeclaration = [{
     /// Signedness semantics.
@@ -150,37 +134,7 @@ class FieldInfo_Type<string name> : Test_Type<name> {
       "::test::FieldInfo", // FieldInfo is defined/declared in TestTypes.h.
       "Models struct fields">: $fields
   );
-
-  // Prints the type in this format:
-  //   struct<[{field1Name, field1Type}, {field2Name, field2Type}]
-  let printer = [{
-    $_printer << "<";
-    for (size_t i=0, e = getImpl()->fields.size(); i < e; i++) {
-      const auto& field = getImpl()->fields[i];
-      $_printer << "{" << field.name << "," << field.type << "}";
-      if (i < getImpl()->fields.size() - 1)
-          $_printer << ",";
-    }
-    $_printer << ">";
-  }];
-
-  // Parses the above format
-  let parser = [{
-    llvm::SmallVector<FieldInfo, 4> parameters;
-    if ($_parser.parseLess()) return Type();
-    while (mlir::succeeded($_parser.parseOptionalLBrace())) {
-      llvm::StringRef name;
-      if ($_parser.parseKeyword(&name)) return Type();
-      if ($_parser.parseComma()) return Type();
-      Type type;
-      if ($_parser.parseType(type)) return Type();
-      if ($_parser.parseRBrace()) return Type();
-      parameters.push_back(FieldInfo {name, type});
-      if ($_parser.parseOptionalComma()) break;
-    }
-    if ($_parser.parseGreater()) return Type();
-    return get($_ctxt, parameters);
-  }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 def StructType : FieldInfo_Type<"Struct"> {
@@ -208,6 +162,7 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
 
   public:
   }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 def TestMemRefElementType : Test_Type<"TestMemRefElementType",

diff  --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 5bcf62a3fa272..833e90de2162a 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -83,6 +83,65 @@ static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT
   return llvm::hash_combine(fi.name, fi.type);
 }
 
+//===----------------------------------------------------------------------===//
+// TestCustomType
+//===----------------------------------------------------------------------===//
+
+static LogicalResult parseCustomTypeA(AsmParser &parser,
+                                      FailureOr<int> &a_result) {
+  a_result.emplace();
+  return parser.parseInteger(*a_result);
+}
+
+static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
+
+static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
+                                      FailureOr<Optional<int>> &b_result) {
+  if (a < 0)
+    return success();
+  for (int i : llvm::seq(0, a))
+    if (failed(parser.parseInteger(i)))
+      return failure();
+  b_result.emplace(0);
+  return parser.parseInteger(**b_result);
+}
+
+static void printCustomTypeB(AsmPrinter &printer, int a, Optional<int> b) {
+  if (a < 0)
+    return;
+  printer << ' ';
+  for (int i : llvm::seq(0, a))
+    printer << i << ' ';
+  printer << *b;
+}
+
+static LogicalResult parseFooString(AsmParser &parser,
+                                    FailureOr<std::string> &foo) {
+  std::string result;
+  if (parser.parseString(&result))
+    return failure();
+  foo = std::move(result);
+  return success();
+}
+
+static void printFooString(AsmPrinter &printer, StringRef foo) {
+  printer << '"' << foo << '"';
+}
+
+static LogicalResult parseBarString(AsmParser &parser, StringRef foo) {
+  return parser.parseKeyword(foo);
+}
+
+static void printBarString(AsmPrinter &printer, StringRef foo) {
+  printer << ' ' << foo;
+}
+//===----------------------------------------------------------------------===//
+// Tablegen Generated Definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "TestTypeDefs.cpp.inc"
+
 //===----------------------------------------------------------------------===//
 // CompoundAType
 //===----------------------------------------------------------------------===//
@@ -129,6 +188,54 @@ TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
+Type TestIntegerType::parse(AsmParser &parser) {
+  SignednessSemantics signedness;
+  int width;
+  if (parser.parseLess() || parseSignedness(parser, signedness) ||
+      parser.parseComma() || parser.parseInteger(width) ||
+      parser.parseGreater())
+    return Type();
+  Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
+  return getChecked(loc, loc.getContext(), width, signedness);
+}
+
+void TestIntegerType::print(AsmPrinter &p) const {
+  p << "<";
+  printSignedness(p, getSignedness());
+  p << ", " << getWidth() << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// TestStructType
+//===----------------------------------------------------------------------===//
+
+Type StructType::parse(AsmParser &p) {
+  SmallVector<FieldInfo, 4> parameters;
+  if (p.parseLess())
+    return Type();
+  while (succeeded(p.parseOptionalLBrace())) {
+    Type type;
+    StringRef name;
+    if (p.parseKeyword(&name) || p.parseComma() || p.parseType(type) ||
+        p.parseRBrace())
+      return Type();
+    parameters.push_back(FieldInfo{name, type});
+    if (p.parseOptionalComma())
+      break;
+  }
+  if (p.parseGreater())
+    return Type();
+  return get(p.getContext(), parameters);
+}
+
+void StructType::print(AsmPrinter &p) const {
+  p << "<";
+  llvm::interleaveComma(getFields(), p, [&](const FieldInfo &field) {
+    p << "{" << field.name << "," << field.type << "}";
+  });
+  p << ">";
+}
+
 //===----------------------------------------------------------------------===//
 // TestType
 //===----------------------------------------------------------------------===//
@@ -208,66 +315,6 @@ unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params,
   return 1;
 }
 
-//===----------------------------------------------------------------------===//
-// TestCustomType
-//===----------------------------------------------------------------------===//
-
-static LogicalResult parseCustomTypeA(AsmParser &parser,
-                                      FailureOr<int> &a_result) {
-  a_result.emplace();
-  return parser.parseInteger(*a_result);
-}
-
-static void printCustomTypeA(AsmPrinter &printer, int a) { printer << a; }
-
-static LogicalResult parseCustomTypeB(AsmParser &parser, int a,
-                                      FailureOr<Optional<int>> &b_result) {
-  if (a < 0)
-    return success();
-  for (int i : llvm::seq(0, a))
-    if (failed(parser.parseInteger(i)))
-      return failure();
-  b_result.emplace(0);
-  return parser.parseInteger(**b_result);
-}
-
-static void printCustomTypeB(AsmPrinter &printer, int a, Optional<int> b) {
-  if (a < 0)
-    return;
-  printer << ' ';
-  for (int i : llvm::seq(0, a))
-    printer << i << ' ';
-  printer << *b;
-}
-
-static LogicalResult parseFooString(AsmParser &parser,
-                                    FailureOr<std::string> &foo) {
-  std::string result;
-  if (parser.parseString(&result))
-    return failure();
-  foo = std::move(result);
-  return success();
-}
-
-static void printFooString(AsmPrinter &printer, StringRef foo) {
-  printer << '"' << foo << '"';
-}
-
-static LogicalResult parseBarString(AsmParser &parser, StringRef foo) {
-  return parser.parseKeyword(foo);
-}
-
-static void printBarString(AsmPrinter &printer, StringRef foo) {
-  printer << ' ' << foo;
-}
-
-//===----------------------------------------------------------------------===//
-// Tablegen Generated Definitions
-//===----------------------------------------------------------------------===//
-
-#define GET_TYPEDEF_CLASSES
-#include "TestTypeDefs.cpp.inc"
-
 //===----------------------------------------------------------------------===//
 // TestDialect
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td
index 70a021ebb385b..3737ce77948e2 100644
--- a/mlir/test/mlir-tblgen/attrdefs.td
+++ b/mlir/test/mlir-tblgen/attrdefs.td
@@ -60,6 +60,7 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
   );
 
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
 
 // DECL-LABEL: class CompoundAAttr : public ::mlir::Attribute
 // DECL: static CompoundAAttr getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
@@ -102,6 +103,7 @@ def C_IndexAttr : TestAttr<"Index"> {
       ins
       StringRefParameter<"Label for index">:$label
     );
+  let hasCustomAssemblyFormat = 1;
 
 // DECL-LABEL: class IndexAttr : public ::mlir::Attribute
 // DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
@@ -127,6 +129,7 @@ def E_AttrWithTypeBuilder : TestAttr<"AttrWithTypeBuilder"> {
   let mnemonic = "attr_with_type_builder";
   let parameters = (ins "::mlir::IntegerAttr":$attr);
   let typeBuilder = "$_attr.getType()";
+  let hasCustomAssemblyFormat = 1;
 }
 
 // DEF-LABEL: struct AttrWithTypeBuilderAttrStorage

diff  --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
index dc8aea52559b4..1b568f197f1da 100644
--- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
@@ -38,7 +38,7 @@ func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<
   return
 }
 
-// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int<none, 3>}>)
+// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla}, {field2,!test.int<none, 3>}>)
 func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int<none, 3>} > ) {
   return
 }

diff  --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td
index af8b70a6dfec4..833195530f9b0 100644
--- a/mlir/test/mlir-tblgen/typedefs.td
+++ b/mlir/test/mlir-tblgen/typedefs.td
@@ -35,8 +35,8 @@ include "mlir/IR/OpBase.td"
 
 def Test_Dialect: Dialect {
 // DECL-NOT: TestDialect
-    let name = "TestDialect";
-    let cppNamespace = "::test";
+  let name = "TestDialect";
+  let cppNamespace = "::test";
 }
 
 class TestType<string name> : TypeDef<Test_Dialect, name> { }
@@ -54,16 +54,16 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
   let summary = "A more complex parameterized type";
   let description = "This type is to test a reasonably complex type";
   let mnemonic = "cmpnd_a";
-  let parameters = (
-      ins
-      "int":$widthOfSomething,
-      "::test::SimpleTypeA": $exampleTdType,
-      "SomeCppStruct": $exampleCppType,
-      ArrayRefParameter<"int", "Matrix dimensions">:$dims,
-      RTLValueType:$inner
+  let parameters = (ins
+    "int":$widthOfSomething,
+    "::test::SimpleTypeA": $exampleTdType,
+    "SomeCppStruct": $exampleCppType,
+    ArrayRefParameter<"int", "Matrix dimensions">:$dims,
+    RTLValueType:$inner
   );
 
   let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
 
 // DECL-LABEL: class CompoundAType : public ::mlir::Type
 // DECL: static CompoundAType getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner);
@@ -79,12 +79,12 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
 }
 
 def C_IndexType : TestType<"Index"> {
-    let mnemonic = "index";
+  let mnemonic = "index";
 
-    let parameters = (
-      ins
-      StringRefParameter<"Label for index">:$label
-    );
+  let parameters = (ins
+    StringRefParameter<"Label for index">:$label
+  );
+  let hasCustomAssemblyFormat = 1;
 
 // DECL-LABEL: class IndexType : public ::mlir::Type
 // DECL: static constexpr ::llvm::StringLiteral getMnemonic() {
@@ -95,8 +95,7 @@ def C_IndexType : TestType<"Index"> {
 }
 
 def D_SingleParameterType : TestType<"SingleParameter"> {
-  let parameters = (
-    ins
+  let parameters = (ins
     "int": $num
   );
 // DECL-LABEL: struct SingleParameterTypeStorage;
@@ -105,17 +104,17 @@ def D_SingleParameterType : TestType<"SingleParameter"> {
 }
 
 def E_IntegerType : TestType<"Integer"> {
-    let mnemonic = "int";
-    let genVerifyDecl = 1;
-    let parameters = (
-        ins
-        "SignednessSemantics":$signedness,
-        TypeParameter<"unsigned", "Bitwidth of integer">:$width
-    );
+  let mnemonic = "int";
+  let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
+  let parameters = (ins
+      "SignednessSemantics":$signedness,
+      TypeParameter<"unsigned", "Bitwidth of integer">:$width
+  );
 
 // DECL-LABEL: IntegerType : public ::mlir::Type
 
-    let extraClassDeclaration = [{
+  let extraClassDeclaration = [{
   /// Signedness semantics.
   enum SignednessSemantics {
     Signless, /// No signedness semantics
@@ -132,7 +131,7 @@ def E_IntegerType : TestType<"Integer"> {
   bool isSigned() const { return getSignedness() == Signed; }
   /// Return true if this is an unsigned integer type.
   bool isUnsigned() const { return getSignedness() == Unsigned; }
-    }];
+  }];
 
 // DECL: /// Signedness semantics.
 // DECL-NEXT: enum SignednessSemantics {

diff  --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
index 5153c6e507851..c9a512ae18e58 100644
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -121,10 +121,6 @@ class DefGen {
   /// Emit a checked custom builder.
   void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder);
 
-  //===--------------------------------------------------------------------===//
-  // Parser and Printer Emission
-  void emitParserPrinterBody(MethodBody &parser, MethodBody &printer);
-
   //===--------------------------------------------------------------------===//
   // Interface Method Emission
 
@@ -264,9 +260,10 @@ void DefGen::emitParserPrinter() {
   auto *mnemonic = defCls.addStaticMethod<Method::Constexpr>(
       "::llvm::StringLiteral", "getMnemonic");
   mnemonic->body().indent() << strfmt("return {\"{0}\"};", *def.getMnemonic());
+
   // Declare the parser and printer, if needed.
-  if (!def.needsParserPrinter() && !def.hasGeneratedParser() &&
-      !def.hasGeneratedPrinter())
+  bool hasAssemblyFormat = def.getAssemblyFormat().hasValue();
+  if (!def.hasCustomAssemblyFormat() && !hasAssemblyFormat)
     return;
 
   // Declare the parser.
@@ -274,18 +271,18 @@ void DefGen::emitParserPrinter() {
   parserParams.emplace_back("::mlir::AsmParser &", "odsParser");
   if (isa<AttrDef>(&def))
     parserParams.emplace_back("::mlir::Type", "odsType");
-  auto *parser = defCls.addMethod(
-      strfmt("::mlir::{0}", valueType), "parse",
-      def.hasGeneratedParser() ? Method::Static : Method::StaticDeclaration,
-      std::move(parserParams));
+  auto *parser = defCls.addMethod(strfmt("::mlir::{0}", valueType), "parse",
+                                  hasAssemblyFormat ? Method::Static
+                                                    : Method::StaticDeclaration,
+                                  std::move(parserParams));
   // Declare the printer.
-  auto props =
-      def.hasGeneratedPrinter() ? Method::Const : Method::ConstDeclaration;
+  auto props = hasAssemblyFormat ? Method::Const : Method::ConstDeclaration;
   Method *printer =
       defCls.addMethod("void", "print", props,
                        MethodParameter("::mlir::AsmPrinter &", "odsPrinter"));
-  // Emit the bodies.
-  emitParserPrinterBody(parser->body(), printer->body());
+  // Emit the bodies if we are using the declarative format.
+  if (hasAssemblyFormat)
+    return generateAttrOrTypeFormat(def, parser->body(), printer->body());
 }
 
 void DefGen::emitAccessors() {
@@ -406,50 +403,6 @@ void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) {
   m->body().indent().getStream().printReindented(bodyStr);
 }
 
-//===----------------------------------------------------------------------===//
-// Parser and Printer Emission
-
-void DefGen::emitParserPrinterBody(MethodBody &parser, MethodBody &printer) {
-  Optional<StringRef> parserCode = def.getParserCode();
-  Optional<StringRef> printerCode = def.getPrinterCode();
-  Optional<StringRef> asmFormat = def.getAssemblyFormat();
-  // Verify the parser-printer specification first.
-  if (asmFormat && (parserCode || printerCode)) {
-    PrintFatalError(def.getLoc(),
-                    def.getName() + ": assembly format cannot be specified at "
-                                    "the same time as printer or parser code");
-  }
-  // Specified code cannot be empty.
-  if (parserCode && parserCode->empty())
-    PrintFatalError(def.getLoc(), def.getName() + ": parser cannot be empty");
-  if (printerCode && printerCode->empty())
-    PrintFatalError(def.getLoc(), def.getName() + ": printer cannot be empty");
-  // Assembly format requires accessors to be generated.
-  if (asmFormat && !def.genAccessors()) {
-    PrintFatalError(def.getLoc(),
-                    def.getName() +
-                        ": the generated printer from 'assemblyFormat' "
-                        "requires 'genAccessors' to be true");
-  }
-
-  // Generate the parser and printer bodies.
-  if (asmFormat)
-    return generateAttrOrTypeFormat(def, parser, printer);
-
-  FmtContext ctx = FmtContext({{"_parser", "odsParser"},
-                               {"_printer", "odsPrinter"},
-                               {"_type", "odsType"}});
-  if (parserCode) {
-    ctx.addSubst("_ctxt", "odsParser.getContext()");
-    parser.indent().getStream().printReindented(tgfmt(*parserCode, &ctx).str());
-  }
-  if (printerCode) {
-    ctx.addSubst("_ctxt", "odsPrinter.getContext()");
-    printer.indent().getStream().printReindented(
-        tgfmt(*printerCode, &ctx).str());
-  }
-}
-
 //===----------------------------------------------------------------------===//
 // Interface Method Emission
 
@@ -829,18 +782,21 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
   for (auto &def : defs) {
     if (!def.getMnemonic())
       continue;
+    bool hasParserPrinterDecl =
+        def.hasCustomAssemblyFormat() || def.getAssemblyFormat();
     std::string defClass = strfmt(
         "{0}::{1}", def.getDialect().getCppNamespace(), def.getCppClassName());
+
     // If the def has no parameters or parser code, invoke a normal `get`.
     std::string parseOrGet =
-        def.needsParserPrinter() || def.hasGeneratedParser()
+        hasParserPrinterDecl
             ? strfmt("parse(parser{0})", isAttrGenerator ? ", type" : "")
             : "get(parser.getContext())";
     parse.body() << llvm::formatv(getValueForMnemonic, defClass, parseOrGet);
 
     // If the def has no parameters and no printer, just print the mnemonic.
     StringRef printDef = "";
-    if (def.needsParserPrinter() || def.hasGeneratedPrinter())
+    if (hasParserPrinterDecl)
       printDef = "\nt.print(printer);";
     printer.body() << llvm::formatv(printValue, defClass, printDef);
   }

diff  --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp
index b3be48176d454..9b9c64884a360 100644
--- a/mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -253,8 +253,7 @@ static void emitAttrOrTypeDefDoc(const AttrOrTypeDef &def, raw_ostream &os) {
     os << "\n" << def.getSummary() << "\n";
 
   // Emit the syntax if present.
-  if (def.getMnemonic() && def.getPrinterCode() == StringRef() &&
-      def.getParserCode() == StringRef())
+  if (def.getMnemonic() && !def.hasCustomAssemblyFormat())
     emitAttrOrTypeDefAssemblyFormat(def, os);
 
   // Emit the description if present.


        


More information about the Mlir-commits mailing list