[llvm-branch-commits] [mlir] [mlir] Start moving some builtin type formats to the dialect (PR #80421)
Markus Böck via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Feb 2 04:15:02 PST 2024
https://github.com/zero9178 updated https://github.com/llvm/llvm-project/pull/80421
>From e67e980cd077de77bb1683f4a9ad948f13ad4289 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Markus=20B=C3=B6ck?= <markus.boeck02 at gmail.com>
Date: Fri, 2 Feb 2024 12:47:05 +0100
Subject: [PATCH] [mlir] Start moving some builtin type formats to the dialect
Most types and attributes in the builtin dialect are parsed and printed using special-purpose printers and parsers written for each individual type or attribute. They also use the low-level `Printer` rather than the `AsmPrinter`, making their implementations inconsistent with those of all other dialects in MLIR.
This PR starts moving some builtin types to be parsed and printed using the usual `parse` and `print` methods, like the types of all other MLIR dialects. This has the following advantages:
* The implementation now looks like any other dialect's types
* It is now possible to use `assemblyFormat` for builtin types and attributes
* The code can be easily moved to other dialects if desired
* Arguably better layering and less code
* As a side effect, it is now also possible to write `!builtin.<type>` for any of the types that have been moved
A future benefit would be the ability to print types and attributes in stripped format as well (e.g. `<f32>` vs `complex<f32>`), just like all other dialect types and attributes. This is currently explicitly disabled because it causes a LOT of changes in IR syntax and, I believe, some ambiguities in the parser.
For the purpose of reviewing and incremental development, this PR only moves `tuple`, `tensor`, `none`, `memref` and `complex`. The plan is to eventually move all attributes and types where the current syntax can be implemented within the dialect.
For backwards compatibility with the existing syntax, the builtin dialect is special-cased in the printer so that the `builtin.` prefix is omitted.
---
mlir/include/mlir/IR/BuiltinDialect.td | 2 +-
mlir/include/mlir/IR/BuiltinTypes.td | 25 ++-
mlir/include/mlir/IR/OpImplementation.h | 5 +
mlir/lib/AsmParser/DialectSymbolParser.cpp | 24 +++
mlir/lib/AsmParser/Parser.h | 24 ++-
mlir/lib/AsmParser/TypeParser.cpp | 211 +--------------------
mlir/lib/IR/AsmPrinter.cpp | 72 ++-----
mlir/lib/IR/BuiltinTypes.cpp | 150 +++++++++++++++
mlir/test/IR/invalid-builtin-types.mlir | 10 +-
mlir/test/IR/invalid.mlir | 4 +-
mlir/test/IR/qualified-builtin.mlir | 11 ++
11 files changed, 249 insertions(+), 289 deletions(-)
create mode 100644 mlir/test/IR/qualified-builtin.mlir
diff --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td
index c131107634b44..a8627170288c9 100644
--- a/mlir/include/mlir/IR/BuiltinDialect.td
+++ b/mlir/include/mlir/IR/BuiltinDialect.td
@@ -22,7 +22,7 @@ def Builtin_Dialect : Dialect {
let name = "builtin";
let cppNamespace = "::mlir";
let useDefaultAttributePrinterParser = 0;
- let useDefaultTypePrinterParser = 0;
+ let useDefaultTypePrinterParser = 1;
let extraClassDeclaration = [{
private:
// Register the builtin Attributes.
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c32..f3a51d2155040 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -25,7 +25,8 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
// Base class for Builtin dialect types.
class Builtin_Type<string name, string typeMnemonic, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
- : TypeDef<Builtin_Dialect, name, traits, baseCppClass> {
+ : TypeDef<Builtin_Dialect, name, !listconcat(traits, [PrintTypeQualified]),
+ baseCppClass> {
let mnemonic = ?;
let typeName = "builtin." # typeMnemonic;
}
@@ -62,6 +63,9 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
+
+ let mnemonic = "complex";
+ let assemblyFormat = "`<` $elementType `>`";
}
//===----------------------------------------------------------------------===//
@@ -668,6 +672,9 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
+
+ let mnemonic = "memref";
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -698,6 +705,8 @@ def Builtin_None : Builtin_Type<"None", "none"> {
let extraClassDeclaration = [{
static NoneType get(MLIRContext *context);
}];
+
+ let mnemonic = "none";
}
//===----------------------------------------------------------------------===//
@@ -849,6 +858,9 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
+
+ let mnemonic = "tensor";
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -884,7 +896,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
tuple<i32, f32, tensor<i1>, i5>
```
}];
- let parameters = (ins "ArrayRef<Type>":$types);
+ let parameters = (ins OptionalArrayRefParameter<"Type">:$types);
let builders = [
TypeBuilder<(ins "TypeRange":$elementTypes), [{
return $_get($_ctxt, elementTypes);
@@ -916,6 +928,9 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
return getTypes()[index];
}
}];
+
+ let mnemonic = "tuple";
+ let assemblyFormat = "`<` (`>`) : ($types^ `>`)?";
}
//===----------------------------------------------------------------------===//
@@ -994,6 +1009,9 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
+
+ let mnemonic = "memref";
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -1043,6 +1061,9 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
+
+ let mnemonic = "tensor";
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 50e6cc59ca459..2a5587d43901a 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -187,6 +187,11 @@ class AsmPrinter {
/// provide a valid type for the attribute.
virtual void printAttributeWithoutType(Attribute attr);
+ /// Print the given attribute without its type if and only if the type is the
+ /// default type for the given attribute.
+ /// E.g. '1 : i64' is printed as just '1'.
+ virtual void printAttributeWithoutDefaultType(Attribute attr);
+
/// Print the alias for the given attribute, return failure if no alias could
/// be printed.
virtual LogicalResult printAlias(Attribute attr);
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 80cce7e6ae43f..400d26398afc3 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -322,6 +322,30 @@ Type Parser::parseExtendedType() {
});
}
+Type Parser::parseExtendedBuiltinType() {
+ // Initially set to just the mnemonic of the type.
+ llvm::StringRef symbolData = getToken().getSpelling();
+ const char *startOfTypePos = symbolData.data();
+ consumeToken();
+ // Extend 'symbolData' to include the body if it is not a singleton type.
+ // Note that all types in the builtin type always use the pretty dialect form
+ // aka 'dialect.mnemonic<body>'.
+ if (getToken().is(Token::less))
+ if (failed(parseDialectSymbolBody(symbolData)))
+ return nullptr;
+
+ const char *endOfTypePos = getToken().getLoc().getPointer();
+
+ // With the body of the type captured, hand it off to the dialect parser.
+ resetToken(startOfTypePos);
+ CustomDialectAsmParser customParser(symbolData, *this);
+ Type type = builtinDialect->parseType(customParser);
+
+ // Move the lexer past the type.
+ resetToken(endOfTypePos);
+ return type;
+}
+
//===----------------------------------------------------------------------===//
// mlir::parseAttribute/parseType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index b959e67b8e258..73080c88ff6b0 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -11,6 +11,7 @@
#include "ParserState.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/OpImplementation.h"
#include <optional>
@@ -28,9 +29,14 @@ class Parser {
using Delimiter = OpAsmParser::Delimiter;
Builder builder;
+ /// Cached instance of the builtin dialect for parsing builtins.
+ Dialect *builtinDialect;
Parser(ParserState &state)
- : builder(state.config.getContext()), state(state) {}
+ : builder(state.config.getContext()),
+ builtinDialect(
+ builder.getContext()->getLoadedDialect<BuiltinDialect>()),
+ state(state) {}
// Helper methods to get stuff from the parser-global state.
ParserState &getState() const { return state; }
@@ -192,27 +198,19 @@ class Parser {
/// Parse an arbitrary type.
Type parseType();
- /// Parse a complex type.
- Type parseComplexType();
-
/// Parse an extended type.
Type parseExtendedType();
+ /// Parse an extended type from the builtin dialect where the '!builtin'
+ /// prefix is missing.
+ Type parseExtendedBuiltinType();
+
/// Parse a function type.
Type parseFunctionType();
- /// Parse a memref type.
- Type parseMemRefType();
-
/// Parse a non function type.
Type parseNonFunctionType();
- /// Parse a tensor type.
- Type parseTensorType();
-
- /// Parse a tuple type.
- Type parseTupleType();
-
/// Parse a vector type.
VectorType parseVectorType();
ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 5da931b77b3be..95df69b899b8a 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -11,12 +11,9 @@
//===----------------------------------------------------------------------===//
#include "Parser.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
@@ -123,29 +120,6 @@ ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
return success();
}
-/// Parse a complex type.
-///
-/// complex-type ::= `complex` `<` type `>`
-///
-Type Parser::parseComplexType() {
- consumeToken(Token::kw_complex);
-
- // Parse the '<'.
- if (parseToken(Token::less, "expected '<' in complex type"))
- return nullptr;
-
- SMLoc elementTypeLoc = getToken().getLoc();
- auto elementType = parseType();
- if (!elementType ||
- parseToken(Token::greater, "expected '>' in complex type"))
- return nullptr;
- if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
- return emitError(elementTypeLoc, "invalid element type for complex"),
- nullptr;
-
- return ComplexType::get(elementType);
-}
-
/// Parse a function type.
///
/// function-type ::= type-list-parens `->` function-result-type
@@ -162,95 +136,6 @@ Type Parser::parseFunctionType() {
return builder.getFunctionType(arguments, results);
}
-/// Parse a memref type.
-///
-/// memref-type ::= ranked-memref-type | unranked-memref-type
-///
-/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
-/// (`,` layout-specification)? (`,` memory-space)? `>`
-///
-/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
-///
-/// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
-/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
-/// layout-specification ::= semi-affine-map | strided-layout | attribute
-/// memory-space ::= integer-literal | attribute
-///
-Type Parser::parseMemRefType() {
- SMLoc loc = getToken().getLoc();
- consumeToken(Token::kw_memref);
-
- if (parseToken(Token::less, "expected '<' in memref type"))
- return nullptr;
-
- bool isUnranked;
- SmallVector<int64_t, 4> dimensions;
-
- if (consumeIf(Token::star)) {
- // This is an unranked memref type.
- isUnranked = true;
- if (parseXInDimensionList())
- return nullptr;
-
- } else {
- isUnranked = false;
- if (parseDimensionListRanked(dimensions))
- return nullptr;
- }
-
- // Parse the element type.
- auto typeLoc = getToken().getLoc();
- auto elementType = parseType();
- if (!elementType)
- return nullptr;
-
- // Check that memref is formed from allowed types.
- if (!BaseMemRefType::isValidElementType(elementType))
- return emitError(typeLoc, "invalid memref element type"), nullptr;
-
- MemRefLayoutAttrInterface layout;
- Attribute memorySpace;
-
- auto parseElt = [&]() -> ParseResult {
- // Either it is MemRefLayoutAttrInterface or memory space attribute.
- Attribute attr = parseAttribute();
- if (!attr)
- return failure();
-
- if (isa<MemRefLayoutAttrInterface>(attr)) {
- layout = cast<MemRefLayoutAttrInterface>(attr);
- } else if (memorySpace) {
- return emitError("multiple memory spaces specified in memref type");
- } else {
- memorySpace = attr;
- return success();
- }
-
- if (isUnranked)
- return emitError("cannot have affine map for unranked memref type");
- if (memorySpace)
- return emitError("expected memory space to be last in memref type");
-
- return success();
- };
-
- // Parse a list of mappings and address space if present.
- if (!consumeIf(Token::greater)) {
- // Parse comma separated list of affine maps, followed by memory space.
- if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
- parseCommaSeparatedListUntil(Token::greater, parseElt,
- /*allowEmptyList=*/false)) {
- return nullptr;
- }
- }
-
- if (isUnranked)
- return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
-
- return getChecked<MemRefType>(loc, dimensions, elementType, layout,
- memorySpace);
-}
-
/// Parse any type except the function type.
///
/// non-function-type ::= integer-type
@@ -272,14 +157,12 @@ Type Parser::parseNonFunctionType() {
switch (getToken().getKind()) {
default:
return (emitWrongTokenError("expected non-function type"), nullptr);
- case Token::kw_memref:
- return parseMemRefType();
case Token::kw_tensor:
- return parseTensorType();
+ case Token::kw_memref:
case Token::kw_complex:
- return parseComplexType();
case Token::kw_tuple:
- return parseTupleType();
+ case Token::kw_none:
+ return parseExtendedBuiltinType();
case Token::kw_vector:
return parseVectorType();
// integer-type
@@ -344,11 +227,6 @@ Type Parser::parseNonFunctionType() {
consumeToken(Token::kw_index);
return builder.getIndexType();
- // none-type
- case Token::kw_none:
- consumeToken(Token::kw_none);
- return builder.getNoneType();
-
// extended type
case Token::exclamation_identifier:
return parseExtendedType();
@@ -361,89 +239,6 @@ Type Parser::parseNonFunctionType() {
}
}
-/// Parse a tensor type.
-///
-/// tensor-type ::= `tensor` `<` dimension-list type `>`
-/// dimension-list ::= dimension-list-ranked | `*x`
-///
-Type Parser::parseTensorType() {
- consumeToken(Token::kw_tensor);
-
- if (parseToken(Token::less, "expected '<' in tensor type"))
- return nullptr;
-
- bool isUnranked;
- SmallVector<int64_t, 4> dimensions;
-
- if (consumeIf(Token::star)) {
- // This is an unranked tensor type.
- isUnranked = true;
-
- if (parseXInDimensionList())
- return nullptr;
-
- } else {
- isUnranked = false;
- if (parseDimensionListRanked(dimensions))
- return nullptr;
- }
-
- // Parse the element type.
- auto elementTypeLoc = getToken().getLoc();
- auto elementType = parseType();
-
- // Parse an optional encoding attribute.
- Attribute encoding;
- if (consumeIf(Token::comma)) {
- auto parseResult = parseOptionalAttribute(encoding);
- if (parseResult.has_value()) {
- if (failed(parseResult.value()))
- return nullptr;
- if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
- if (failed(v.verifyEncoding(dimensions, elementType,
- [&] { return emitError(); })))
- return nullptr;
- }
- }
- }
-
- if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
- return nullptr;
- if (!TensorType::isValidElementType(elementType))
- return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
-
- if (isUnranked) {
- if (encoding)
- return emitError("cannot apply encoding to unranked tensor"), nullptr;
- return UnrankedTensorType::get(elementType);
- }
- return RankedTensorType::get(dimensions, elementType, encoding);
-}
-
-/// Parse a tuple type.
-///
-/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
-///
-Type Parser::parseTupleType() {
- consumeToken(Token::kw_tuple);
-
- // Parse the '<'.
- if (parseToken(Token::less, "expected '<' in tuple type"))
- return nullptr;
-
- // Check for an empty tuple by directly parsing '>'.
- if (consumeIf(Token::greater))
- return TupleType::get(getContext());
-
- // Parse the element types and the '>'.
- SmallVector<Type, 4> types;
- if (parseTypeListNoParens(types) ||
- parseToken(Token::greater, "expected '>' in tuple type"))
- return nullptr;
-
- return TupleType::get(getContext(), types);
-}
-
/// Parse a vector type.
///
/// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6b8b7473bf0f8..0679d4135048a 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2132,6 +2132,13 @@ static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
/// Print the given dialect symbol to the stream.
static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
StringRef dialectName, StringRef symString) {
+ // Treat the builtin dialect special by eliding the '<symPrefix>builtin'
+ // prefix.
+ if (dialectName == "builtin") {
+ os << symString;
+ return;
+ }
+
os << symPrefix << dialectName;
// If this symbol name is simple enough, print it directly in pretty form,
@@ -2599,64 +2606,6 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
printType(vectorTy.getElementType());
os << '>';
})
- .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
- os << "tensor<";
- printDimensionList(tensorTy.getShape());
- if (!tensorTy.getShape().empty())
- os << 'x';
- printType(tensorTy.getElementType());
- // Only print the encoding attribute value if set.
- if (tensorTy.getEncoding()) {
- os << ", ";
- printAttribute(tensorTy.getEncoding());
- }
- os << '>';
- })
- .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
- os << "tensor<*x";
- printType(tensorTy.getElementType());
- os << '>';
- })
- .Case<MemRefType>([&](MemRefType memrefTy) {
- os << "memref<";
- printDimensionList(memrefTy.getShape());
- if (!memrefTy.getShape().empty())
- os << 'x';
- printType(memrefTy.getElementType());
- MemRefLayoutAttrInterface layout = memrefTy.getLayout();
- if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
- os << ", ";
- printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
- }
- // Only print the memory space if it is the non-default one.
- if (memrefTy.getMemorySpace()) {
- os << ", ";
- printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
- }
- os << '>';
- })
- .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
- os << "memref<*x";
- printType(memrefTy.getElementType());
- // Only print the memory space if it is the non-default one.
- if (memrefTy.getMemorySpace()) {
- os << ", ";
- printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
- }
- os << '>';
- })
- .Case<ComplexType>([&](ComplexType complexTy) {
- os << "complex<";
- printType(complexTy.getElementType());
- os << '>';
- })
- .Case<TupleType>([&](TupleType tupleTy) {
- os << "tuple<";
- interleaveComma(tupleTy.getTypes(),
- [&](Type type) { printType(type); });
- os << '>';
- })
- .Case<NoneType>([&](Type) { os << "none"; })
.Default([&](Type type) { return printDialectType(type); });
}
@@ -2799,6 +2748,13 @@ void AsmPrinter::printAttributeWithoutType(Attribute attr) {
impl->printAttribute(attr, Impl::AttrTypeElision::Must);
}
+void AsmPrinter::printAttributeWithoutDefaultType(Attribute attr) {
+ assert(
+ impl &&
+ "expected AsmPrinter::printAttributeWithoutDefaultType to be overriden");
+ impl->printAttribute(attr, Impl::AttrTypeElision::May);
+}
+
void AsmPrinter::printKeywordOrString(StringRef keyword) {
assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
::printKeywordOrString(keyword, impl->getStream());
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 9b8ee3d452803..e160c0ff4c33d 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -10,10 +10,13 @@
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
@@ -25,6 +28,52 @@
using namespace mlir;
using namespace mlir::detail;
+//===----------------------------------------------------------------------===//
+// Custom printing and parsing
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseMemRefDimension(AsmParser &parser,
+ SmallVectorImpl<int64_t> &dimension,
+ bool &isUnranked) {
+ if (succeeded(parser.parseOptionalStar())) {
+ isUnranked = true;
+ return parser.parseXInDimensionList();
+ }
+
+ isUnranked = false;
+ return parser.parseDimensionList(dimension);
+}
+
+static ParseResult parseMemRefSpaceAndLayout(AsmParser &parser,
+ MemRefLayoutAttrInterface &layout,
+ Attribute &memorySpace,
+ bool isUnranked) {
+ while (succeeded(parser.parseOptionalComma())) {
+ SMLoc loc = parser.getCurrentLocation();
+ Attribute attr;
+ if (parser.parseAttribute(attr))
+ return failure();
+
+ if (auto memRefLayout = dyn_cast<MemRefLayoutAttrInterface>(attr)) {
+ layout = memRefLayout;
+ } else if (memorySpace) {
+ return parser.emitError(
+ loc, "multiple memory spaces specified in memref type");
+ } else {
+ memorySpace = attr;
+ continue;
+ }
+
+ if (isUnranked)
+ return parser.emitError(
+ loc, "cannot have affine map for unranked memref type");
+ if (memorySpace)
+ return parser.emitError(
+ loc, "expected memory space to be last in memref type");
+ }
+ return success();
+}
+
//===----------------------------------------------------------------------===//
/// Tablegen Type Definitions
//===----------------------------------------------------------------------===//
@@ -340,6 +389,46 @@ RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
return checkTensorElementType(emitError, elementType);
}
+Type RankedTensorType::parse(AsmParser &parser) {
+ SmallVector<int64_t> dimension;
+ Type elementType;
+ bool isUnranked;
+ if (parser.parseLess() ||
+ parseMemRefDimension(parser, dimension, isUnranked) ||
+ parser.parseType(elementType))
+ return nullptr;
+
+ Attribute encoding;
+ if (succeeded(parser.parseOptionalComma())) {
+ SMLoc loc = parser.getCurrentLocation();
+ if (parser.parseAttribute(encoding))
+ return nullptr;
+
+ if (isUnranked) {
+ parser.emitError(loc, "cannot apply encoding to unranked tensor");
+ return nullptr;
+ }
+ }
+
+ if (failed(parser.parseGreater()))
+ return nullptr;
+
+ if (isUnranked)
+ return parser.getChecked<UnrankedTensorType>(elementType);
+ return parser.getChecked<RankedTensorType>(dimension, elementType, encoding);
+}
+
+void RankedTensorType::print(AsmPrinter &printer) const {
+ printer << '<';
+ printer.printDimensionList(getShape());
+ if (!getShape().empty())
+ printer << 'x';
+ printer << getElementType();
+ if (getEncoding())
+ printer << ", " << getEncoding();
+ printer << '>';
+}
+
//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//
@@ -350,6 +439,14 @@ UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
return checkTensorElementType(emitError, elementType);
}
+Type UnrankedTensorType::parse(AsmParser &parser) {
+ return RankedTensorType::parse(parser);
+}
+
+void UnrankedTensorType::print(AsmPrinter &printer) const {
+ printer << "<*x" << getElementType() << ">";
+}
+
//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//
@@ -652,6 +749,44 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+Type MemRefType::parse(AsmParser &parser) {
+ SmallVector<int64_t> dimension;
+ Type elementType;
+ MemRefLayoutAttrInterface layout;
+ Attribute memorySpace;
+ bool isUnranked;
+ if (parser.parseLess() ||
+ parseMemRefDimension(parser, dimension, isUnranked) ||
+ parser.parseType(elementType) ||
+ parseMemRefSpaceAndLayout(parser, layout, memorySpace, isUnranked) ||
+ parser.parseGreater())
+ return nullptr;
+
+ if (isUnranked)
+ return parser.getChecked<UnrankedMemRefType>(elementType, memorySpace);
+ return parser.getChecked<MemRefType>(dimension, elementType, layout,
+ memorySpace);
+}
+
+void MemRefType::print(AsmPrinter &printer) const {
+ printer << '<';
+ printer.printDimensionList(getShape());
+ if (!getShape().empty())
+ printer << 'x';
+ printer << getElementType();
+ MemRefLayoutAttrInterface layout = getLayout();
+ if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
+ printer << ", ";
+ printer.printAttributeWithoutDefaultType(getLayout());
+ }
+ // Only print the memory space if it is the non-default one.
+ if (getMemorySpace()) {
+ printer << ", ";
+ printer.printAttributeWithoutDefaultType(getMemorySpace());
+ }
+ printer << '>';
+}
+
//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//
@@ -672,6 +807,21 @@ UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+Type UnrankedMemRefType::parse(AsmParser &parser) {
+ return MemRefType::parse(parser);
+}
+
+void UnrankedMemRefType::print(AsmPrinter &printer) const {
+ printer << "<*x";
+ printer << getElementType();
+ // Only print the memory space if it is the non-default one.
+ if (getMemorySpace()) {
+ printer << ", ";
+ printer.printAttributeWithoutDefaultType(getMemorySpace());
+ }
+ printer << '>';
+}
+
// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
// i.e. single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 9884212e916c1..04995bf7338aa 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -27,7 +27,7 @@ func.func @illegalunrankedmemrefelementtype(memref<*xtensor<i8>>) -> () // expec
// -----
// Test no map in memref type.
-func.func @memrefs(memref<2x4xi8, >) // expected-error {{expected list element}}
+func.func @memrefs(memref<2x4xi8, >) // expected-error {{expected attribute}}
// -----
// Test non-existent map in memref type.
@@ -74,7 +74,7 @@ func.func private @memref_unfinished_strided() -> memref<?x?xf32, strided<>>
// -----
-// expected-error @below {{expected a 64-bit signed integer or '?'}}
+// expected-error @below {{unbalanced '[' character in pretty dialect name}}
func.func private @memref_unfinished_stride_list() -> memref<?x?xf32, strided<[>>
// -----
@@ -94,7 +94,7 @@ func.func private @memref_missing_offset_value() -> memref<?x?xf32, strided<[],
// -----
-// expected-error @below {{expected '>'}}
+// expected-error @below {{unbalanced '<' character in pretty dialect name}}
func.func private @memref_incorrect_strided_ending() -> memref<?x?xf32, strided<[], offset: 32)>
// -----
@@ -170,12 +170,12 @@ func.func @bad_complex(complex<memref<2x4xi8>>)
// -----
-// expected-error @+1 {{expected '<' in complex type}}
+// expected-error @+1 {{expected '<'}}
func.func @bad_complex(complex memref<2x4xi8>>)
// -----
-// expected-error @+1 {{expected '>' in complex type}}
+// expected-error @+1 {{unbalanced '<' character in pretty dialect name}}
func.func @bad_complex(complex<i32)
// -----
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 861f4ef6c020d..1e01b477b1adb 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -419,12 +419,12 @@ func.func @invalid_unknown_type_dialect_name() -> !invalid.dialect<!x@#]!@#>
// -----
-// expected-error @+1 {{expected '<' in tuple type}}
+// expected-error @+1 {{expected '<'}}
func.func @invalid_tuple_missing_less(tuple i32>)
// -----
-// expected-error @+1 {{expected '>' in tuple type}}
+// expected-error @+1 {{unbalanced '<' character in pretty dialect name}}
func.func @invalid_tuple_missing_greater(tuple<i32)
// -----
diff --git a/mlir/test/IR/qualified-builtin.mlir b/mlir/test/IR/qualified-builtin.mlir
new file mode 100644
index 0000000000000..a2f9e63ea66ba
--- /dev/null
+++ b/mlir/test/IR/qualified-builtin.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK-LABEL: @test1
+// CHECK: -> tuple<>
+func.func private @test1() -> !builtin.tuple<>
+
+// CHECK-LABEL: @test2
+// CHECK: -> none
+func.func private @test2() -> !builtin.none
+
+
More information about the llvm-branch-commits
mailing list