[Mlir-commits] [mlir] 6bc767c - [mlir] Add a DialectAsmParser::getChecked method

River Riddle llvmlistbot at llvm.org
Thu Mar 4 12:00:37 PST 2021


Author: River Riddle
Date: 2021-03-04T11:53:24-08:00
New Revision: 6bc767cd071ccdb41b5532f7d9cae22999e0fac4

URL: https://github.com/llvm/llvm-project/commit/6bc767cd071ccdb41b5532f7d9cae22999e0fac4
DIFF: https://github.com/llvm/llvm-project/commit/6bc767cd071ccdb41b5532f7d9cae22999e0fac4.diff

LOG: [mlir] Add a DialectAsmParser::getChecked method

This method simplifies calling the `getChecked` methods of Attributes and Types from within a parser. It also removes the need to call `getEncodedSourceLoc` (or, inside the core parser, `getEncodedSourceLocation`) for these methods, because an `SMLoc` is used instead. This is much more efficient than using an `mlir::Location`: encoding an `mlir::Location` is expensive and undesirable during parsing, since locations created while parsing should not persist afterwards unless they are actually needed.

Differential Revision: https://reviews.llvm.org/D97900

Added: 
    

Modified: 
    mlir/include/mlir/IR/DialectImplementation.h
    mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
    mlir/lib/Dialect/Quant/IR/TypeParser.cpp
    mlir/lib/Parser/DialectSymbolParser.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index ba6f9df3f1de..f1a53df3c2ca 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -121,6 +121,10 @@ class DialectAsmParser {
   virtual llvm::SMLoc getNameLoc() const = 0;
 
   /// Re-encode the given source location as an MLIR location and return it.
+  /// Note: This method should only be used when a `Location` is necessary, as
+  /// the encoding process is not efficient. In other cases a more suitable
+  /// alternative should be used, such as the `getChecked` methods defined
+  /// below.
   virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
 
   /// Returns the full specification of the symbol being parsed. This allows for
@@ -163,6 +167,22 @@ class DialectAsmParser {
     return success();
   }
 
+  /// Invoke the `getChecked` method of the given Attribute or Type class, using
+  /// the provided location to emit errors in the case of failure. Note that
+  /// unlike `OpBuilder::getType`, this method does not implicitly insert a
+  /// context parameter.
+  template <typename T, typename... ParamsT>
+  T getChecked(llvm::SMLoc loc, ParamsT &&...params) {
+    return T::getChecked([&] { return emitError(loc); },
+                         std::forward<ParamsT>(params)...);
+  }
+  /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
+  /// errors.
+  template <typename T, typename... ParamsT> T getChecked(ParamsT &&...params) {
+    return T::getChecked([&] { return emitError(getNameLoc()); },
+                         std::forward<ParamsT>(params)...);
+  }
+
   //===--------------------------------------------------------------------===//
   // Token Parsing
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
index f32137a78479..921926692a99 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -178,7 +178,7 @@ static ParseResult dispatchParse(DialectAsmParser &parser, Type &type);
 /// Parses an LLVM dialect function type.
 ///   llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
 static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
-  Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
+  llvm::SMLoc loc = parser.getCurrentLocation();
   Type returnType;
   if (parser.parseLess() || dispatchParse(parser, returnType) ||
       parser.parseLParen())
@@ -187,8 +187,8 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
   // Function type without arguments.
   if (succeeded(parser.parseOptionalRParen())) {
     if (succeeded(parser.parseGreater()))
-      return LLVMFunctionType::getChecked(loc, returnType, llvm::None,
-                                          /*isVarArg=*/false);
+      return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None,
+                                                 /*isVarArg=*/false);
     return LLVMFunctionType();
   }
 
@@ -198,8 +198,8 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
     if (succeeded(parser.parseOptionalEllipsis())) {
       if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
         return LLVMFunctionType();
-      return LLVMFunctionType::getChecked(loc, returnType, argTypes,
-                                          /*isVarArg=*/true);
+      return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
+                                                 /*isVarArg=*/true);
     }
 
     Type arg;
@@ -210,14 +210,14 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
 
   if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
     return LLVMFunctionType();
-  return LLVMFunctionType::getChecked(loc, returnType, argTypes,
-                                      /*isVarArg=*/false);
+  return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
+                                             /*isVarArg=*/false);
 }
 
 /// Parses an LLVM dialect pointer type.
 ///   llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
 static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
-  Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
+  llvm::SMLoc loc = parser.getCurrentLocation();
   Type elementType;
   if (parser.parseLess() || dispatchParse(parser, elementType))
     return LLVMPointerType();
@@ -228,7 +228,7 @@ static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
     return LLVMPointerType();
   if (failed(parser.parseGreater()))
     return LLVMPointerType();
-  return LLVMPointerType::getChecked(loc, elementType, addressSpace);
+  return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace);
 }
 
 /// Parses an LLVM dialect vector type.
@@ -238,7 +238,7 @@ static Type parseVectorType(DialectAsmParser &parser) {
   SmallVector<int64_t, 2> dims;
   llvm::SMLoc dimPos, typePos;
   Type elementType;
-  Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
+  llvm::SMLoc loc = parser.getCurrentLocation();
   if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
       parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
       parser.getCurrentLocation(&typePos) ||
@@ -259,13 +259,13 @@ static Type parseVectorType(DialectAsmParser &parser) {
 
   bool isScalable = dims.size() == 2;
   if (isScalable)
-    return LLVMScalableVectorType::getChecked(loc, elementType, dims[1]);
+    return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
   if (elementType.isSignlessIntOrFloat()) {
     parser.emitError(typePos)
         << "cannot use !llvm.vec for built-in primitives, use 'vector' instead";
     return Type();
   }
-  return LLVMFixedVectorType::getChecked(loc, elementType, dims[0]);
+  return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]);
 }
 
 /// Parses an LLVM dialect array type.
@@ -274,7 +274,7 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
   SmallVector<int64_t, 1> dims;
   llvm::SMLoc sizePos;
   Type elementType;
-  Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
+  llvm::SMLoc loc = parser.getCurrentLocation();
   if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
       parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
       dispatchParse(parser, elementType) || parser.parseGreater())
@@ -285,7 +285,7 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
     return LLVMArrayType();
   }
 
-  return LLVMArrayType::getChecked(loc, elementType, dims[0]);
+  return parser.getChecked<LLVMArrayType>(loc, elementType, dims[0]);
 }
 
 /// Attempts to set the body of an identified structure type. Reports a parsing

diff  --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
index 636cf7ddb96c..16fe1f0ebdee 100644
--- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
+++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
@@ -117,7 +117,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
 ///   storage-range ::= integer-literal `:` integer-literal
 ///   storage-type ::= (`i` | `u`) integer-literal
 ///   expressed-type-spec ::= `:` `f` integer-literal
-static Type parseAnyType(DialectAsmParser &parser, Location loc) {
+static Type parseAnyType(DialectAsmParser &parser) {
   IntegerType storageType;
   FloatType expressedType;
   unsigned typeFlags = 0;
@@ -155,9 +155,8 @@ static Type parseAnyType(DialectAsmParser &parser, Location loc) {
     return nullptr;
   }
 
-  return AnyQuantizedType::getChecked(loc, typeFlags, storageType,
-                                      expressedType, storageTypeMin,
-                                      storageTypeMax);
+  return parser.getChecked<AnyQuantizedType>(
+      typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
 }
 
 static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
@@ -192,7 +191,7 @@ static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
 ///   axis-spec ::= `:` integer-literal
 ///   scale-zero ::= float-literal `:` integer-literal
 ///   scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
-static Type parseUniformType(DialectAsmParser &parser, Location loc) {
+static Type parseUniformType(DialectAsmParser &parser) {
   IntegerType storageType;
   FloatType expressedType;
   unsigned typeFlags = 0;
@@ -279,14 +278,14 @@ static Type parseUniformType(DialectAsmParser &parser, Location loc) {
   if (isPerAxis) {
     ArrayRef<double> scalesRef(scales.begin(), scales.end());
     ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
-    return UniformQuantizedPerAxisType::getChecked(
-        loc, typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
+    return parser.getChecked<UniformQuantizedPerAxisType>(
+        typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
         quantizedDimension, storageTypeMin, storageTypeMax);
   }
 
-  return UniformQuantizedType::getChecked(
-      loc, typeFlags, storageType, expressedType, scales.front(),
-      zeroPoints.front(), storageTypeMin, storageTypeMax);
+  return parser.getChecked<UniformQuantizedType>(
+      typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
+      storageTypeMin, storageTypeMax);
 }
 
 /// Parses an CalibratedQuantizedType.
@@ -295,7 +294,7 @@ static Type parseUniformType(DialectAsmParser &parser, Location loc) {
 ///   expressed-spec ::= expressed-type `<` calibrated-range `>`
 ///   expressed-type ::= `f` integer-literal
 ///   calibrated-range ::= float-literal `:` float-literal
-static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
+static Type parseCalibratedType(DialectAsmParser &parser) {
   FloatType expressedType;
   double min;
   double max;
@@ -314,24 +313,22 @@ static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
     return nullptr;
   }
 
-  return CalibratedQuantizedType::getChecked(loc, expressedType, min, max);
+  return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
 }
 
 /// Parse a type registered to this dialect.
 Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
-  Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
-
   // All types start with an identifier that we switch on.
   StringRef typeNameSpelling;
   if (failed(parser.parseKeyword(&typeNameSpelling)))
     return nullptr;
 
   if (typeNameSpelling == "uniform")
-    return parseUniformType(parser, loc);
+    return parseUniformType(parser);
   if (typeNameSpelling == "any")
-    return parseAnyType(parser, loc);
+    return parseAnyType(parser);
   if (typeNameSpelling == "calibrated")
-    return parseCalibratedType(parser, loc);
+    return parseCalibratedType(parser);
 
   parser.emitError(parser.getNameLoc(),
                    "unknown quantized type " + typeNameSpelling);

diff  --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 6993b8eb543a..46096a59f8ac 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -524,7 +524,7 @@ Attribute Parser::parseExtendedAttr(Type type) {
 
         // Otherwise, form a new opaque attribute.
         return OpaqueAttr::getChecked(
-            getEncodedSourceLocation(loc),
+            [&] { return emitError(loc); },
             Identifier::get(dialectName, state.context), symbolData,
             attrType ? attrType : NoneType::get(state.context));
       });
@@ -563,7 +563,7 @@ Type Parser::parseExtendedType() {
 
         // Otherwise, form a new opaque type.
         return OpaqueType::getChecked(
-            getEncodedSourceLocation(loc),
+            [&] { return emitError(loc); },
             Identifier::get(dialectName, state.context), symbolData);
       });
 }


        


More information about the Mlir-commits mailing list