[Mlir-commits] [mlir] d70185e - [mlir][IR] Support parsing hex float values in the DialectSymbolParser
River Riddle
llvmlistbot at llvm.org
Wed Mar 17 14:01:13 PDT 2021
Author: River Riddle
Date: 2021-03-17T13:52:32-07:00
New Revision: d70185ec4821f7960c4fe10961479d97e816da68
URL: https://github.com/llvm/llvm-project/commit/d70185ec4821f7960c4fe10961479d97e816da68
DIFF: https://github.com/llvm/llvm-project/commit/d70185ec4821f7960c4fe10961479d97e816da68.diff
LOG: [mlir][IR] Support parsing hex float values in the DialectSymbolParser
This has been a TODO for a while; supporting it prevents breakages for attributes/types that contain floats that can't round-trip outside of the hex format.
Differential Revision: https://reviews.llvm.org/D98808
Added:
Modified:
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/DialectSymbolParser.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/Parser.h
mlir/test/Dialect/Quant/parse-uniform.mlir
mlir/test/IR/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index f71f2a21669a..a90c65ff1fb3 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -307,20 +307,6 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
}
-/// Construct a float attribute bitwise equivalent to the integer literal.
-static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type,
- uint64_t value) {
- if (type.isF64())
- return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
-
- APInt apInt(type.getWidth(), value);
- if (apInt != value) {
- p->emitError("hexadecimal float constant out of range for type");
- return llvm::None;
- }
- return APFloat(type.getFloatSemantics(), apInt);
-}
-
/// Construct an APint from a parsed value, a known attribute type and
/// sign.
static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
@@ -369,10 +355,9 @@ static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
/// Parse a decimal or a hexadecimal literal, which can be either an integer
/// or a float attribute.
Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
- // Remember if the literal is hexadecimal.
- StringRef spelling = getToken().getSpelling();
- auto loc = state.curToken.getLoc();
- bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+ Token tok = getToken();
+ StringRef spelling = tok.getSpelling();
+ llvm::SMLoc loc = tok.getLoc();
consumeToken(Token::integer);
if (!type) {
@@ -384,26 +369,12 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
}
if (auto floatType = type.dyn_cast<FloatType>()) {
- if (isNegative)
- return emitError(
- loc,
- "hexadecimal float literal should not have a leading minus"),
- nullptr;
- if (!isHex) {
- emitError(loc, "unexpected decimal integer literal for a float attribute")
- .attachNote()
- << "add a trailing dot to make the literal a float";
- return nullptr;
- }
-
- auto val = Token::getUInt64IntegerValue(spelling);
- if (!val.hasValue())
- return emitError("integer constant out of range for attribute"), nullptr;
-
- // Construct a float attribute bitwise equivalent to the integer literal.
- Optional<APFloat> apVal =
- buildHexadecimalFloatLiteral(this, floatType, *val);
- return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
+ Optional<APFloat> result;
+ if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
+ floatType.getFloatSemantics(),
+ floatType.getWidth())))
+ return Attribute();
+ return FloatAttr::get(floatType, *result);
}
if (!type.isa<IntegerType, IndexType>())
@@ -638,19 +609,13 @@ TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
// Handle hexadecimal float literals.
if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
- if (isNegative) {
- return p.emitError(token.getLoc())
- << "hexadecimal float literal should not have a leading minus";
- }
- auto val = token.getUInt64IntegerValue();
- if (!val.hasValue()) {
- return p.emitError(
- "hexadecimal float constant out of range for attribute");
- }
- Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val);
- if (!apVal)
+ Optional<APFloat> result;
+ if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
+ eltTy.getFloatSemantics(),
+ eltTy.getWidth())))
return failure();
- floatValues.push_back(*apVal);
+
+ floatValues.push_back(*result);
continue;
}
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index 46096a59f8ac..851f40bb0fe3 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -63,21 +63,34 @@ class CustomDialectAsmParser : public DialectAsmParser {
/// Parse a floating point value from the stream.
ParseResult parseFloat(double &result) override {
- bool negative = parser.consumeIf(Token::minus);
+ bool isNegative = parser.consumeIf(Token::minus); // optional leading '-'; the hex path below rejects it
Token curTok = parser.getToken();
+ llvm::SMLoc loc = curTok.getLoc(); // location used for every diagnostic below
// Check for a floating point value.
if (curTok.is(Token::floatliteral)) {
auto val = curTok.getFloatingPointValue();
if (!val.hasValue())
- return emitError(curTok.getLoc(), "floating point value too large");
+ return emitError(loc, "floating point value too large");
parser.consumeToken(Token::floatliteral);
- result = negative ? -*val : *val;
+ result = isNegative ? -*val : *val; // sign was consumed as a separate token
return success();
}
- // TODO: support hex floating point values.
- return emitError(getCurrentLocation(), "expected floating point literal");
+ // Check for a hexadecimal float value: an integer token whose bit pattern is reinterpreted as a 64-bit IEEE double.
+ if (curTok.is(Token::integer)) {
+ Optional<APFloat> apResult;
+ if (failed(parser.parseFloatFromIntegerLiteral(
+ apResult, curTok, isNegative, APFloat::IEEEdouble(),
+ /*typeSizeInBits=*/64)))
+ return failure();
+
+ parser.consumeToken(Token::integer); // consume only after a successful conversion
+ result = apResult->convertToDouble();
+ return success();
+ }
+
+ return emitError(loc, "expected floating point literal"); // neither a float literal nor an integer token
}
/// Parse an optional integer value from the stream.
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 736522415b49..7b3cd158b2ad 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -112,6 +112,41 @@ OptionalParseResult Parser::parseOptionalInteger(uint64_t &result) {
return success();
}
+/// Parse a floating point value from a hexadecimal integer literal token by reinterpreting its bits under `semantics`; decimal literals, negated literals, and literals too wide for `typeSizeInBits` are diagnosed. `tok` is not consumed; callers consume it on success.
+ParseResult Parser::parseFloatFromIntegerLiteral(
+ Optional<APFloat> &result, const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics, size_t typeSizeInBits) {
+ llvm::SMLoc loc = tok.getLoc();
+ StringRef spelling = tok.getSpelling();
+ bool isHex = spelling.size() > 1 && spelling[1] == 'x'; // "0x" spelling marks a hexadecimal literal
+ if (!isHex) { // decimal integers are rejected, not converted
+ return emitError(loc, "unexpected decimal integer literal for a "
+ "floating point value")
+ .attachNote()
+ << "add a trailing dot to make the literal a float";
+ }
+ if (isNegative) { // any sign must be encoded in the bit pattern itself
+ return emitError(loc, "hexadecimal float literal should not have a "
+ "leading minus");
+ }
+
+ Optional<uint64_t> value = tok.getUInt64IntegerValue(); // empty if the literal does not fit in 64 bits
+ if (!value.hasValue())
+ return emitError(loc, "hexadecimal float constant out of range for type");
+
+ if (&semantics == &APFloat::IEEEdouble()) { // callers in this patch pass 64 bits for double, so no width check is needed
+ result = APFloat(semantics, APInt(typeSizeInBits, *value));
+ return success();
+ }
+
+ APInt apInt(typeSizeInBits, *value); // truncates bits above typeSizeInBits
+ if (apInt != *value) // set bits were dropped: literal is too wide for the type
+ return emitError(loc, "hexadecimal float constant out of range for type");
+ result = APFloat(semantics, apInt); // reinterpret the bit pattern as a value of `semantics`
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// OperationParser
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index 0e9e4caff440..8d1910e73843 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -130,6 +130,12 @@ class Parser {
/// Parse an optional integer value from the stream.
OptionalParseResult parseOptionalInteger(uint64_t &result);
+ /// Parse a floating point value from a hexadecimal integer literal token by reinterpreting its bits under `semantics`; emits an error for decimal, negated, or too-wide (relative to `typeSizeInBits`) literals. Does not consume `tok`.
+ ParseResult parseFloatFromIntegerLiteral(Optional<APFloat> &result,
+ const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics,
+ size_t typeSizeInBits);
+
//===--------------------------------------------------------------------===//
// Type Parsing
//===--------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Quant/parse-uniform.mlir b/mlir/test/Dialect/Quant/parse-uniform.mlir
index 0e609a77d1fc..3f9dad898361 100644
--- a/mlir/test/Dialect/Quant/parse-uniform.mlir
+++ b/mlir/test/Dialect/Quant/parse-uniform.mlir
@@ -92,6 +92,15 @@ func @parse() -> !qalias {
return %0 : !qalias
}
+// -----
+// Expressed type: f32, with the float parameter written as a 64-bit hexadecimal literal (presumably read via DialectAsmParser::parseFloat as an IEEE double bit pattern — confirm against the quant parser).
+// CHECK: !quant.uniform<u8:f32, 0x41646ABBA0000000:128>
+!qalias = type !quant.uniform<u8:f32, 0x41646ABBA0000000:128>
+func @parse() -> !qalias {
+ %0 = "foo"() : () -> !qalias
+ return %0 : !qalias
+}
+
// -----
// Expressed type: f16
// CHECK: !quant.uniform<u8:f16, 2.000000e+02>
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 4c4df915167a..2909416771fc 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -1191,7 +1191,7 @@ func @hexadecimal_float_literal_overflow() {
// -----
func @decimal_float_literal() {
- // expected-error @+2 {{unexpected decimal integer literal for a float attribute}}
+ // expected-error @+2 {{unexpected decimal integer literal for a floating point value}}
// expected-note @+1 {{add a trailing dot to make the literal a float}}
"foo"() {value = 42 : f32} : () -> ()
}
@@ -1244,7 +1244,7 @@ func @hexadecimal_float_too_wide_for_type_in_tensor() {
// Check that we report an error when a value is too wide to be parsed.
func @hexadecimal_float_too_wide_in_tensor() {
- // expected-error @+1 {{hexadecimal float constant out of range for attribute}}
+ // expected-error @+1 {{hexadecimal float constant out of range for type}}
"foo"() {bar = dense<0x7FFFFFF0000000000000> : tensor<2xf32>} : () -> ()
}
More information about the Mlir-commits mailing list