[Mlir-commits] [mlir] 6bc767c - [mlir] Add a DialectAsmParser::getChecked method
River Riddle
llvmlistbot at llvm.org
Thu Mar 4 12:00:37 PST 2021
Author: River Riddle
Date: 2021-03-04T11:53:24-08:00
New Revision: 6bc767cd071ccdb41b5532f7d9cae22999e0fac4
URL: https://github.com/llvm/llvm-project/commit/6bc767cd071ccdb41b5532f7d9cae22999e0fac4
DIFF: https://github.com/llvm/llvm-project/commit/6bc767cd071ccdb41b5532f7d9cae22999e0fac4.diff
LOG: [mlir] Add a DialectAsmParser::getChecked method
This method simplifies calling the `getChecked` methods on Attributes and Types from within the parser. Because it takes an `llvm::SMLoc` instead of an `mlir::Location`, it removes any need to call `getEncodedSourceLoc` (or the parser-internal `getEncodedSourceLocation`) just to invoke these methods. This is much cheaper than using an `mlir::Location`: encoding an `mlir::Location` is costly and undesirable during parsing, since locations used while parsing should not persist afterwards unless otherwise necessary.
Differential Revision: https://reviews.llvm.org/D97900
Added:
Modified:
mlir/include/mlir/IR/DialectImplementation.h
mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
mlir/lib/Dialect/Quant/IR/TypeParser.cpp
mlir/lib/Parser/DialectSymbolParser.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index ba6f9df3f1de..f1a53df3c2ca 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -121,6 +121,10 @@ class DialectAsmParser {
virtual llvm::SMLoc getNameLoc() const = 0;
/// Re-encode the given source location as an MLIR location and return it.
+ /// Note: This method should only be used when a `Location` is necessary, as
+ /// the encoding process is not efficient. In other cases a more suitable
+ /// alternative should be used, such as the `getChecked` methods defined
+ /// below.
virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
/// Returns the full specification of the symbol being parsed. This allows for
@@ -163,6 +167,22 @@ class DialectAsmParser {
return success();
}
+ /// Invoke the `getChecked` method of the given Attribute or Type class, using
+ /// the provided location to emit errors in the case of failure. Note that
+ /// unlike `OpBuilder::getType`, this method does not implicitly insert a
+ /// context parameter.
+ template <typename T, typename... ParamsT>
+ T getChecked(llvm::SMLoc loc, ParamsT &&...params) {
+ return T::getChecked([&] { return emitError(loc); },
+ std::forward<ParamsT>(params)...);
+ }
+ /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
+ /// errors.
+ template <typename T, typename... ParamsT> T getChecked(ParamsT &&...params) {
+ return T::getChecked([&] { return emitError(getNameLoc()); },
+ std::forward<ParamsT>(params)...);
+ }
+
//===--------------------------------------------------------------------===//
// Token Parsing
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index f32137a78479..921926692a99 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -178,7 +178,7 @@ static ParseResult dispatchParse(DialectAsmParser &parser, Type &type);
/// Parses an LLVM dialect function type.
/// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
- Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
+ llvm::SMLoc loc = parser.getCurrentLocation();
Type returnType;
if (parser.parseLess() || dispatchParse(parser, returnType) ||
parser.parseLParen())
@@ -187,8 +187,8 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
// Function type without arguments.
if (succeeded(parser.parseOptionalRParen())) {
if (succeeded(parser.parseGreater()))
- return LLVMFunctionType::getChecked(loc, returnType, llvm::None,
- /*isVarArg=*/false);
+ return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None,
+ /*isVarArg=*/false);
return LLVMFunctionType();
}
@@ -198,8 +198,8 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
if (succeeded(parser.parseOptionalEllipsis())) {
if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
return LLVMFunctionType();
- return LLVMFunctionType::getChecked(loc, returnType, argTypes,
- /*isVarArg=*/true);
+ return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
+ /*isVarArg=*/true);
}
Type arg;
@@ -210,14 +210,14 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
return LLVMFunctionType();
- return LLVMFunctionType::getChecked(loc, returnType, argTypes,
- /*isVarArg=*/false);
+ return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
+ /*isVarArg=*/false);
}
/// Parses an LLVM dialect pointer type.
/// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
- Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
+ llvm::SMLoc loc = parser.getCurrentLocation();
Type elementType;
if (parser.parseLess() || dispatchParse(parser, elementType))
return LLVMPointerType();
@@ -228,7 +228,7 @@ static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
return LLVMPointerType();
if (failed(parser.parseGreater()))
return LLVMPointerType();
- return LLVMPointerType::getChecked(loc, elementType, addressSpace);
+ return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace);
}
/// Parses an LLVM dialect vector type.
@@ -238,7 +238,7 @@ static Type parseVectorType(DialectAsmParser &parser) {
SmallVector<int64_t, 2> dims;
llvm::SMLoc dimPos, typePos;
Type elementType;
- Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
+ llvm::SMLoc loc = parser.getCurrentLocation();
if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
parser.getCurrentLocation(&typePos) ||
@@ -259,13 +259,13 @@ static Type parseVectorType(DialectAsmParser &parser) {
bool isScalable = dims.size() == 2;
if (isScalable)
- return LLVMScalableVectorType::getChecked(loc, elementType, dims[1]);
+ return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
if (elementType.isSignlessIntOrFloat()) {
parser.emitError(typePos)
<< "cannot use !llvm.vec for built-in primitives, use 'vector' instead";
return Type();
}
- return LLVMFixedVectorType::getChecked(loc, elementType, dims[0]);
+ return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]);
}
/// Parses an LLVM dialect array type.
@@ -274,7 +274,7 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
SmallVector<int64_t, 1> dims;
llvm::SMLoc sizePos;
Type elementType;
- Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
+ llvm::SMLoc loc = parser.getCurrentLocation();
if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
dispatchParse(parser, elementType) || parser.parseGreater())
@@ -285,7 +285,7 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
return LLVMArrayType();
}
- return LLVMArrayType::getChecked(loc, elementType, dims[0]);
+ return parser.getChecked<LLVMArrayType>(loc, elementType, dims[0]);
}
/// Attempts to set the body of an identified structure type. Reports a parsing
diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 636cf7ddb96c..16fe1f0ebdee 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -117,7 +117,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
/// storage-range ::= integer-literal `:` integer-literal
/// storage-type ::= (`i` | `u`) integer-literal
/// expressed-type-spec ::= `:` `f` integer-literal
-static Type parseAnyType(DialectAsmParser &parser, Location loc) {
+static Type parseAnyType(DialectAsmParser &parser) {
IntegerType storageType;
FloatType expressedType;
unsigned typeFlags = 0;
@@ -155,9 +155,8 @@ static Type parseAnyType(DialectAsmParser &parser, Location loc) {
return nullptr;
}
- return AnyQuantizedType::getChecked(loc, typeFlags, storageType,
- expressedType, storageTypeMin,
- storageTypeMax);
+ return parser.getChecked<AnyQuantizedType>(
+ typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
}
static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
@@ -192,7 +191,7 @@ static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
/// axis-spec ::= `:` integer-literal
/// scale-zero ::= float-literal `:` integer-literal
/// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
-static Type parseUniformType(DialectAsmParser &parser, Location loc) {
+static Type parseUniformType(DialectAsmParser &parser) {
IntegerType storageType;
FloatType expressedType;
unsigned typeFlags = 0;
@@ -279,14 +278,14 @@ static Type parseUniformType(DialectAsmParser &parser, Location loc) {
if (isPerAxis) {
ArrayRef<double> scalesRef(scales.begin(), scales.end());
ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
- return UniformQuantizedPerAxisType::getChecked(
- loc, typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
+ return parser.getChecked<UniformQuantizedPerAxisType>(
+ typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
quantizedDimension, storageTypeMin, storageTypeMax);
}
- return UniformQuantizedType::getChecked(
- loc, typeFlags, storageType, expressedType, scales.front(),
- zeroPoints.front(), storageTypeMin, storageTypeMax);
+ return parser.getChecked<UniformQuantizedType>(
+ typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
+ storageTypeMin, storageTypeMax);
}
/// Parses an CalibratedQuantizedType.
@@ -295,7 +294,7 @@ static Type parseUniformType(DialectAsmParser &parser, Location loc) {
/// expressed-spec ::= expressed-type `<` calibrated-range `>`
/// expressed-type ::= `f` integer-literal
/// calibrated-range ::= float-literal `:` float-literal
-static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
+static Type parseCalibratedType(DialectAsmParser &parser) {
FloatType expressedType;
double min;
double max;
@@ -314,24 +313,22 @@ static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
return nullptr;
}
- return CalibratedQuantizedType::getChecked(loc, expressedType, min, max);
+ return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
}
/// Parse a type registered to this dialect.
Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
- Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
-
// All types start with an identifier that we switch on.
StringRef typeNameSpelling;
if (failed(parser.parseKeyword(&typeNameSpelling)))
return nullptr;
if (typeNameSpelling == "uniform")
- return parseUniformType(parser, loc);
+ return parseUniformType(parser);
if (typeNameSpelling == "any")
- return parseAnyType(parser, loc);
+ return parseAnyType(parser);
if (typeNameSpelling == "calibrated")
- return parseCalibratedType(parser, loc);
+ return parseCalibratedType(parser);
parser.emitError(parser.getNameLoc(),
"unknown quantized type " + typeNameSpelling);
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 6993b8eb543a..46096a59f8ac 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -524,7 +524,7 @@ Attribute Parser::parseExtendedAttr(Type type) {
// Otherwise, form a new opaque attribute.
return OpaqueAttr::getChecked(
- getEncodedSourceLocation(loc),
+ [&] { return emitError(loc); },
Identifier::get(dialectName, state.context), symbolData,
attrType ? attrType : NoneType::get(state.context));
});
@@ -563,7 +563,7 @@ Type Parser::parseExtendedType() {
// Otherwise, form a new opaque type.
return OpaqueType::getChecked(
- getEncodedSourceLocation(loc),
+ [&] { return emitError(loc); },
Identifier::get(dialectName, state.context), symbolData);
});
}
More information about the Mlir-commits
mailing list