[Mlir-commits] [mlir] 9b07512 - [mlir][Parser][NFC] Remove several usages of getEncodedSourceLocation

River Riddle llvmlistbot at llvm.org
Sat Feb 15 23:52:50 PST 2020


Author: River Riddle
Date: 2020-02-15T23:52:23-08:00
New Revision: 9b07512fd3cd4541872dd2e96a697172f3f7a243

URL: https://github.com/llvm/llvm-project/commit/9b07512fd3cd4541872dd2e96a697172f3f7a243
DIFF: https://github.com/llvm/llvm-project/commit/9b07512fd3cd4541872dd2e96a697172f3f7a243.diff

LOG: [mlir][Parser][NFC] Remove several usages of getEncodedSourceLocation

Summary: getEncodedSourceLocation can be very costly to compute, especially when the input line is very long. This revision inlines the verification performed by a few `getChecked` methods so that an encoded source location no longer needs to be materialized.

Differential Revision: https://reviews.llvm.org/D74587

Added: 
    

Modified: 
    mlir/lib/Parser/Parser.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index bff9ab6b81c0..8bd57a11888c 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -970,13 +970,16 @@ Type Parser::parseComplexType() {
   if (parseToken(Token::less, "expected '<' in complex type"))
     return nullptr;
 
-  auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
+  llvm::SMLoc elementTypeLoc = getToken().getLoc();
   auto elementType = parseType();
   if (!elementType ||
       parseToken(Token::greater, "expected '>' in complex type"))
     return nullptr;
+  if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
+    return emitError(elementTypeLoc, "invalid element type for complex"),
+           nullptr;
 
-  return ComplexType::getChecked(elementType, typeLocation);
+  return ComplexType::get(elementType);
 }
 
 /// Parse an extended type.
@@ -1097,69 +1100,79 @@ Type Parser::parseMemRefType() {
   if (!elementType)
     return nullptr;
 
+  // Check that memref is formed from allowed types.
+  if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
+      !elementType.isa<ComplexType>())
+    return emitError(typeLoc, "invalid memref element type"), nullptr;
+
   // Parse semi-affine-map-composition.
   SmallVector<AffineMap, 2> affineMapComposition;
-  unsigned memorySpace = 0;
-  bool parsedMemorySpace = false;
+  Optional<unsigned> memorySpace;
+  unsigned numDims = dimensions.size();
 
   auto parseElt = [&]() -> ParseResult {
+    // Check for the memory space.
     if (getToken().is(Token::integer)) {
-      // Parse memory space.
-      if (parsedMemorySpace)
+      if (memorySpace)
         return emitError("multiple memory spaces specified in memref type");
-      auto v = getToken().getUnsignedIntegerValue();
-      if (!v.hasValue())
+      memorySpace = getToken().getUnsignedIntegerValue();
+      if (!memorySpace.hasValue())
         return emitError("invalid memory space in memref type");
-      memorySpace = v.getValue();
       consumeToken(Token::integer);
-      parsedMemorySpace = true;
+      return success();
+    }
+    if (isUnranked)
+      return emitError("cannot have affine map for unranked memref type");
+    if (memorySpace)
+      return emitError("expected memory space to be last in memref type");
+
+    AffineMap map;
+    llvm::SMLoc mapLoc = getToken().getLoc();
+    if (getToken().is(Token::kw_offset)) {
+      int64_t offset;
+      SmallVector<int64_t, 4> strides;
+      if (failed(parseStridedLayout(offset, strides)))
+        return failure();
+      // Construct strided affine map.
+      map = makeStridedLinearLayoutMap(strides, offset, state.context);
     } else {
-      if (isUnranked)
-        return emitError("cannot have affine map for unranked memref type");
-      if (parsedMemorySpace)
-        return emitError("expected memory space to be last in memref type");
-      if (getToken().is(Token::kw_offset)) {
-        int64_t offset;
-        SmallVector<int64_t, 4> strides;
-        if (failed(parseStridedLayout(offset, strides)))
-          return failure();
-        // Construct strided affine map.
-        auto map = makeStridedLinearLayoutMap(strides, offset,
-                                              elementType.getContext());
-        affineMapComposition.push_back(map);
-      } else {
-        // Parse affine map.
-        auto affineMap = parseAttribute();
-        if (!affineMap)
-          return failure();
-        // Verify that the parsed attribute is an affine map.
-        if (auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>())
-          affineMapComposition.push_back(affineMapAttr.getValue());
-        else
-          return emitError("expected affine map in memref type");
-      }
+      // Parse an affine map attribute.
+      auto affineMap = parseAttribute();
+      if (!affineMap)
+        return failure();
+      auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>();
+      if (!affineMapAttr)
+        return emitError("expected affine map in memref type");
+      map = affineMapAttr.getValue();
+    }
+
+    if (map.getNumDims() != numDims) {
+      size_t i = affineMapComposition.size();
+      return emitError(mapLoc, "memref affine map dimension mismatch between ")
+             << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
+             << " and affine map" << i + 1 << ": " << numDims
+             << " != " << map.getNumDims();
     }
+    numDims = map.getNumResults();
+    affineMapComposition.push_back(map);
     return success();
   };
 
   // Parse a list of mappings and address space if present.
-  if (consumeIf(Token::comma)) {
+  if (!consumeIf(Token::greater)) {
     // Parse comma separated list of affine maps, followed by memory space.
-    if (parseCommaSeparatedListUntil(Token::greater, parseElt,
+    if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
+        parseCommaSeparatedListUntil(Token::greater, parseElt,
                                      /*allowEmptyList=*/false)) {
       return nullptr;
     }
-  } else {
-    if (parseToken(Token::greater, "expected ',' or '>' in memref type"))
-      return nullptr;
   }
 
   if (isUnranked)
-    return UnrankedMemRefType::getChecked(elementType, memorySpace,
-                                          getEncodedSourceLocation(typeLoc));
+    return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0));
 
-  return MemRefType::getChecked(dimensions, elementType, affineMapComposition,
-                                memorySpace, getEncodedSourceLocation(typeLoc));
+  return MemRefType::get(dimensions, elementType, affineMapComposition,
+                         memorySpace.getValueOr(0));
 }
 
 /// Parse any type except the function type.
@@ -1198,9 +1211,14 @@ Type Parser::parseNonFunctionType() {
     auto width = getToken().getIntTypeBitwidth();
     if (!width.hasValue())
       return (emitError("invalid integer width"), nullptr);
-    auto loc = getEncodedSourceLocation(getToken().getLoc());
+    if (width.getValue() > IntegerType::kMaxWidth) {
+      emitError(getToken().getLoc(), "integer bitwidth is limited to ")
+          << IntegerType::kMaxWidth << " bits";
+      return nullptr;
+    }
+
     consumeToken(Token::inttype);
-    return IntegerType::getChecked(width.getValue(), builder.getContext(), loc);
+    return IntegerType::get(width.getValue(), builder.getContext());
   }
 
   // float-type
@@ -1261,14 +1279,16 @@ Type Parser::parseTensorType() {
   }
 
   // Parse the element type.
-  auto typeLocation = getEncodedSourceLocation(getToken().getLoc());
+  auto elementTypeLoc = getToken().getLoc();
   auto elementType = parseType();
   if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
     return nullptr;
+  if (!TensorType::isValidElementType(elementType))
+    return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
 
   if (isUnranked)
-    return UnrankedTensorType::getChecked(elementType, typeLocation);
-  return RankedTensorType::getChecked(dimensions, elementType, typeLocation);
+    return UnrankedTensorType::get(elementType);
+  return RankedTensorType::get(dimensions, elementType);
 }
 
 /// Parse a tuple type.
@@ -1313,15 +1333,21 @@ VectorType Parser::parseVectorType() {
     return nullptr;
   if (dimensions.empty())
     return (emitError("expected dimension size in vector type"), nullptr);
+  if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
+    return emitError(getToken().getLoc(),
+                     "vector types must have positive constant sizes"),
+           nullptr;
 
   // Parse the element type.
   auto typeLoc = getToken().getLoc();
   auto elementType = parseType();
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
+  if (!VectorType::isValidElementType(elementType))
+    return emitError(typeLoc, "vector elements must be int or float type"),
+           nullptr;
 
-  return VectorType::getChecked(dimensions, elementType,
-                                getEncodedSourceLocation(typeLoc));
+  return VectorType::get(dimensions, elementType);
 }
 
 /// Parse a dimension list of a tensor or memref type.  This populates the


        


More information about the Mlir-commits mailing list