[Mlir-commits] [mlir] 2b09a89 - [OpAsmParser] Refactor parseOptionalInteger to support wide integers, NFC.

Chris Lattner llvmlistbot at llvm.org
Mon May 10 22:47:05 PDT 2021


Author: Chris Lattner
Date: 2021-05-10T22:35:42-07:00
New Revision: 2b09a89daf956795d82076d983c3d78b96e1af4b

URL: https://github.com/llvm/llvm-project/commit/2b09a89daf956795d82076d983c3d78b96e1af4b
DIFF: https://github.com/llvm/llvm-project/commit/2b09a89daf956795d82076d983c3d78b96e1af4b.diff

LOG: [OpAsmParser] Refactor parseOptionalInteger to support wide integers, NFC.

OpAsmParser (and DialectAsmParser) supports a pair of
parseInteger/parseOptionalInteger methods, which allow parsing a bare
integer into a C type of your choice (e.g. int8_t) using templates.  It
was implemented in terms of a virtual method call that is hard-coded to
uint64_t because "that should be big enough".

Change the virtual method hook to produce an APInt instead.  This allows
asm parsers for custom ops to parse large integers if they want to, without
changing any of the clients of the fixed-size C API.

Differential Revision: https://reviews.llvm.org/D102120

Added: 
    

Modified: 
    mlir/include/mlir/IR/DialectImplementation.h
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/Parser/DialectSymbolParser.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Parser/Parser.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
index 2815ccdaa6f08..f99d695a50017 100644
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -139,7 +139,8 @@ class DialectAsmParser {
   virtual ParseResult parseFloat(double &result) = 0;
 
   /// Parse an integer value from the stream.
-  template <typename IntT> ParseResult parseInteger(IntT &result) {
+  template <typename IntT>
+  ParseResult parseInteger(IntT &result) {
     auto loc = getCurrentLocation();
     OptionalParseResult parseResult = parseOptionalInteger(result);
     if (!parseResult.hasValue())
@@ -148,21 +149,24 @@ class DialectAsmParser {
   }
 
   /// Parse an optional integer value from the stream.
-  virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0;
+  virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
 
   template <typename IntT>
   OptionalParseResult parseOptionalInteger(IntT &result) {
     auto loc = getCurrentLocation();
 
     // Parse the unsigned variant.
-    uint64_t uintResult;
+    APInt uintResult;
     OptionalParseResult parseResult = parseOptionalInteger(uintResult);
     if (!parseResult.hasValue() || failed(*parseResult))
       return parseResult;
 
-    // Try to convert to the provided integer type.
-    result = IntT(uintResult);
-    if (uint64_t(result) != uintResult)
+    // Try to convert to the provided integer type.  sextOrTrunc is correct even
+    // for unsigned types because parseOptionalInteger ensures the sign bit is
+    // zero for non-negated integers.
+    result =
+        (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
+    if (APInt(uintResult.getBitWidth(), result) != uintResult)
       return emitError(loc, "integer value too large");
     return success();
   }
@@ -172,13 +176,14 @@ class DialectAsmParser {
   /// unlike `OpBuilder::getType`, this method does not implicitly insert a
   /// context parameter.
   template <typename T, typename... ParamsT>
-  T getChecked(llvm::SMLoc loc, ParamsT &&...params) {
+  T getChecked(llvm::SMLoc loc, ParamsT &&... params) {
     return T::getChecked([&] { return emitError(loc); },
                          std::forward<ParamsT>(params)...);
   }
   /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
   /// errors.
-  template <typename T, typename... ParamsT> T getChecked(ParamsT &&...params) {
+  template <typename T, typename... ParamsT>
+  T getChecked(ParamsT &&... params) {
     return T::getChecked([&] { return emitError(getNameLoc()); },
                          std::forward<ParamsT>(params)...);
   }
@@ -331,7 +336,8 @@ class DialectAsmParser {
   virtual ParseResult parseType(Type &result) = 0;
 
   /// Parse a type of a specific kind, e.g. a FunctionType.
-  template <typename TypeType> ParseResult parseType(TypeType &result) {
+  template <typename TypeType>
+  ParseResult parseType(TypeType &result) {
     llvm::SMLoc loc = getCurrentLocation();
 
     // Parse any kind of type.

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 22274795fa47e..fb4f1a85077c0 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -435,21 +435,24 @@ class OpAsmParser {
   }
 
   /// Parse an optional integer value from the stream.
-  virtual OptionalParseResult parseOptionalInteger(uint64_t &result) = 0;
+  virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
 
   template <typename IntT>
   OptionalParseResult parseOptionalInteger(IntT &result) {
     auto loc = getCurrentLocation();
 
     // Parse the unsigned variant.
-    uint64_t uintResult;
+    APInt uintResult;
     OptionalParseResult parseResult = parseOptionalInteger(uintResult);
     if (!parseResult.hasValue() || failed(*parseResult))
       return parseResult;
 
-    // Try to convert to the provided integer type.
-    result = IntT(uintResult);
-    if (uint64_t(result) != uintResult)
+    // Try to convert to the provided integer type.  sextOrTrunc is correct even
+    // for unsigned types because parseOptionalInteger ensures the sign bit is
+    // zero for non-negated integers.
+    result =
+        (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
+    if (APInt(uintResult.getBitWidth(), result) != uintResult)
       return emitError(loc, "integer value too large");
     return success();
   }

diff  --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index e9d79c6913095..62a03795b8415 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -94,7 +94,7 @@ class CustomDialectAsmParser : public DialectAsmParser {
   }
 
   /// Parse an optional integer value from the stream.
-  OptionalParseResult parseOptionalInteger(uint64_t &result) override {
+  OptionalParseResult parseOptionalInteger(APInt &result) override {
     return parser.parseOptionalInteger(result);
   }
 

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 5c78cefa2cc46..b3c66b2fa1916 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -96,7 +96,7 @@ ParseResult Parser::parseToken(Token::Kind expectedToken,
 }
 
 /// Parse an optional integer value from the stream.
-OptionalParseResult Parser::parseOptionalInteger(uint64_t &result) {
+OptionalParseResult Parser::parseOptionalInteger(APInt &result) {
   Token curToken = getToken();
   if (curToken.isNot(Token::integer, Token::minus))
     return llvm::None;
@@ -106,10 +106,19 @@ OptionalParseResult Parser::parseOptionalInteger(uint64_t &result) {
   if (parseToken(Token::integer, "expected integer value"))
     return failure();
 
-  auto val = curTok.getUInt64IntegerValue();
-  if (!val)
+  StringRef spelling = curTok.getSpelling();
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+  if (spelling.getAsInteger(isHex ? 0 : 10, result))
     return emitError(curTok.getLoc(), "integer value too large");
-  result = negative ? -*val : *val;
+
+  // Make sure we have a zero at the top so we return the right signedness.
+  if (result.isNegative())
+    result = result.zext(result.getBitWidth() + 1);
+
+  // Process the negative sign if present.
+  if (negative)
+    result.negate();
+
   return success();
 }
 
@@ -1217,7 +1226,7 @@ class CustomOpAsmParser : public OpAsmParser {
   }
 
   /// Parse an optional integer value from the stream.
-  OptionalParseResult parseOptionalInteger(uint64_t &result) override {
+  OptionalParseResult parseOptionalInteger(APInt &result) override {
     return parser.parseOptionalInteger(result);
   }
 

diff  --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index 6fe07e9246a4f..2f41f99a51859 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -128,7 +128,7 @@ class Parser {
   ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
 
   /// Parse an optional integer value from the stream.
-  OptionalParseResult parseOptionalInteger(uint64_t &result);
+  OptionalParseResult parseOptionalInteger(APInt &result);
 
   /// Parse a floating point value from an integer literal token.
   ParseResult parseFloatFromIntegerLiteral(Optional<APFloat> &result,


        


More information about the Mlir-commits mailing list