[Mlir-commits] [mlir] d70185e - [mlir][IR] Support parsing hex float values in the DialectSymbolParser

River Riddle llvmlistbot at llvm.org
Wed Mar 17 14:01:13 PDT 2021


Author: River Riddle
Date: 2021-03-17T13:52:32-07:00
New Revision: d70185ec4821f7960c4fe10961479d97e816da68

URL: https://github.com/llvm/llvm-project/commit/d70185ec4821f7960c4fe10961479d97e816da68
DIFF: https://github.com/llvm/llvm-project/commit/d70185ec4821f7960c4fe10961479d97e816da68.diff

LOG: [mlir][IR] Support parsing hex float values in the DialectSymbolParser

This has been a TODO for a while; supporting it prevents breakages for attributes/types that contain float values which can only round-trip in hexadecimal format.

Differential Revision: https://reviews.llvm.org/D98808

Added: 
    

Modified: 
    mlir/lib/Parser/AttributeParser.cpp
    mlir/lib/Parser/DialectSymbolParser.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Parser/Parser.h
    mlir/test/Dialect/Quant/parse-uniform.mlir
    mlir/test/IR/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index f71f2a21669a..a90c65ff1fb3 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -307,20 +307,6 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
   return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
 }
 
-/// Construct a float attribute bitwise equivalent to the integer literal.
-static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type,
-                                                      uint64_t value) {
-  if (type.isF64())
-    return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
-
-  APInt apInt(type.getWidth(), value);
-  if (apInt != value) {
-    p->emitError("hexadecimal float constant out of range for type");
-    return llvm::None;
-  }
-  return APFloat(type.getFloatSemantics(), apInt);
-}
-
 /// Construct an APint from a parsed value, a known attribute type and
 /// sign.
 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
@@ -369,10 +355,9 @@ static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
 /// Parse a decimal or a hexadecimal literal, which can be either an integer
 /// or a float attribute.
 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
-  // Remember if the literal is hexadecimal.
-  StringRef spelling = getToken().getSpelling();
-  auto loc = state.curToken.getLoc();
-  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+  Token tok = getToken();
+  StringRef spelling = tok.getSpelling();
+  llvm::SMLoc loc = tok.getLoc();
 
   consumeToken(Token::integer);
   if (!type) {
@@ -384,26 +369,12 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
   }
 
   if (auto floatType = type.dyn_cast<FloatType>()) {
-    if (isNegative)
-      return emitError(
-                 loc,
-                 "hexadecimal float literal should not have a leading minus"),
-             nullptr;
-    if (!isHex) {
-      emitError(loc, "unexpected decimal integer literal for a float attribute")
-              .attachNote()
-          << "add a trailing dot to make the literal a float";
-      return nullptr;
-    }
-
-    auto val = Token::getUInt64IntegerValue(spelling);
-    if (!val.hasValue())
-      return emitError("integer constant out of range for attribute"), nullptr;
-
-    // Construct a float attribute bitwise equivalent to the integer literal.
-    Optional<APFloat> apVal =
-        buildHexadecimalFloatLiteral(this, floatType, *val);
-    return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
+    Optional<APFloat> result;
+    if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
+                                            floatType.getFloatSemantics(),
+                                            floatType.getWidth())))
+      return Attribute();
+    return FloatAttr::get(floatType, *result);
   }
 
   if (!type.isa<IntegerType, IndexType>())
@@ -638,19 +609,13 @@ TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
 
     // Handle hexadecimal float literals.
     if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
-      if (isNegative) {
-        return p.emitError(token.getLoc())
-               << "hexadecimal float literal should not have a leading minus";
-      }
-      auto val = token.getUInt64IntegerValue();
-      if (!val.hasValue()) {
-        return p.emitError(
-            "hexadecimal float constant out of range for attribute");
-      }
-      Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val);
-      if (!apVal)
+      Optional<APFloat> result;
+      if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
+                                                eltTy.getFloatSemantics(),
+                                                eltTy.getWidth())))
         return failure();
-      floatValues.push_back(*apVal);
+
+      floatValues.push_back(*result);
       continue;
     }
 

diff  --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 46096a59f8ac..851f40bb0fe3 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -63,21 +63,34 @@ class CustomDialectAsmParser : public DialectAsmParser {
 
   /// Parse a floating point value from the stream.
   ParseResult parseFloat(double &result) override {
-    bool negative = parser.consumeIf(Token::minus);
+    bool isNegative = parser.consumeIf(Token::minus);
     Token curTok = parser.getToken();
+    llvm::SMLoc loc = curTok.getLoc();
 
     // Check for a floating point value.
     if (curTok.is(Token::floatliteral)) {
       auto val = curTok.getFloatingPointValue();
       if (!val.hasValue())
-        return emitError(curTok.getLoc(), "floating point value too large");
+        return emitError(loc, "floating point value too large");
       parser.consumeToken(Token::floatliteral);
-      result = negative ? -*val : *val;
+      result = isNegative ? -*val : *val;
       return success();
     }
 
-    // TODO: support hex floating point values.
-    return emitError(getCurrentLocation(), "expected floating point literal");
+    // Check for a hex integer literal holding the raw bits of an f64.
+    if (curTok.is(Token::integer)) {
+      Optional<APFloat> apResult;
+      if (failed(parser.parseFloatFromIntegerLiteral(
+              apResult, curTok, isNegative, APFloat::IEEEdouble(),
+              /*typeSizeInBits=*/64)))
+        return failure();
+      // The token is consumed only after it decoded successfully.
+      parser.consumeToken(Token::integer);
+      result = apResult->convertToDouble();
+      return success();
+    }
+    // Neither a float literal nor an integer literal.
+    return emitError(loc, "expected floating point literal");
   }
 
   /// Parse an optional integer value from the stream.

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 736522415b49..7b3cd158b2ad 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -112,6 +112,41 @@ OptionalParseResult Parser::parseOptionalInteger(uint64_t &result) {
   return success();
 }
 
+/// Parse a hex integer literal token as the bit pattern of a float value.
+ParseResult Parser::parseFloatFromIntegerLiteral(
+    Optional<APFloat> &result, const Token &tok, bool isNegative,
+    const llvm::fltSemantics &semantics, size_t typeSizeInBits) {
+  llvm::SMLoc loc = tok.getLoc();
+  StringRef spelling = tok.getSpelling();
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x'; // "0x" prefix.
+  if (!isHex) {
+    return emitError(loc, "unexpected decimal integer literal for a "
+                          "floating point value")
+               .attachNote()
+           << "add a trailing dot to make the literal a float";
+  }
+  if (isNegative) {
+    return emitError(loc, "hexadecimal float literal should not have a "
+                          "leading minus");
+  }
+  // Only literals whose value fits in a uint64_t are supported.
+  Optional<uint64_t> value = tok.getUInt64IntegerValue();
+  if (!value.hasValue())
+    return emitError(loc, "hexadecimal float constant out of range for type");
+  // f64: any uint64_t fits, assuming typeSizeInBits == 64 (not checked).
+  if (&semantics == &APFloat::IEEEdouble()) {
+    result = APFloat(semantics, APInt(typeSizeInBits, *value));
+    return success();
+  }
+  // Otherwise, reject literals with bits set beyond typeSizeInBits.
+  APInt apInt(typeSizeInBits, *value);
+  if (apInt != *value)
+    return emitError(loc, "hexadecimal float constant out of range for type");
+  result = APFloat(semantics, apInt);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // OperationParser
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index 0e9e4caff440..8d1910e73843 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -130,6 +130,12 @@ class Parser {
   /// Parse an optional integer value from the stream.
   OptionalParseResult parseOptionalInteger(uint64_t &result);
 
+  /// Interpret hex integer literal `tok` as the raw bits of a float value.
+  ParseResult parseFloatFromIntegerLiteral(Optional<APFloat> &result,
+                                           const Token &tok, bool isNegative,
+                                           const llvm::fltSemantics &semantics,
+                                           size_t typeSizeInBits);
+
   //===--------------------------------------------------------------------===//
   // Type Parsing
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Quant/parse-uniform.mlir b/mlir/test/Dialect/Quant/parse-uniform.mlir
index 0e609a77d1fc..3f9dad898361 100644
--- a/mlir/test/Dialect/Quant/parse-uniform.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform.mlir
@@ -92,6 +92,15 @@ func @parse() -> !qalias {
   return %0 : !qalias
 }
 
+// -----
+// Expressed type: f32
+// CHECK: !quant.uniform<u8:f32, 0x41646ABBA0000000:128>
+!qalias = type !quant.uniform<u8:f32, 0x41646ABBA0000000:128>
+func @parse() -> !qalias {
+  %0 = "foo"() : () -> !qalias
+  return %0 : !qalias
+}
+
 // -----
 // Expressed type: f16
 // CHECK: !quant.uniform<u8:f16, 2.000000e+02>

diff  --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 4c4df915167a..2909416771fc 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -1191,7 +1191,7 @@ func @hexadecimal_float_literal_overflow() {
 // -----
 
 func @decimal_float_literal() {
-  // expected-error @+2 {{unexpected decimal integer literal for a float attribute}}
+  // expected-error @+2 {{unexpected decimal integer literal for a floating point value}}
   // expected-note @+1 {{add a trailing dot to make the literal a float}}
   "foo"() {value = 42 : f32} : () -> ()
 }
@@ -1244,7 +1244,7 @@ func @hexadecimal_float_too_wide_for_type_in_tensor() {
 
 // Check that we report an error when a value is too wide to be parsed.
 func @hexadecimal_float_too_wide_in_tensor() {
-  // expected-error @+1 {{hexadecimal float constant out of range for attribute}}
+  // expected-error @+1 {{hexadecimal float constant out of range for type}}
   "foo"() {bar = dense<0x7FFFFFF0000000000000> : tensor<2xf32>} : () -> ()
 }
 


        


More information about the Mlir-commits mailing list