[Mlir-commits] [mlir] [mlir][Parser] Deduplicate floating-point parsing functionality (PR #116172)
Matthias Springer
llvmlistbot at llvm.org
Mon Nov 18 00:21:20 PST 2024
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116172
>From 093da29f4a33c9baca81c03cb7482d34a5283f69 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 14 Nov 2024 06:55:21 +0100
Subject: [PATCH 1/4] [mlir][Parser][NFC] Make `parseFloatFromIntegerLiteral` a
standalone function
---
mlir/lib/AsmParser/AsmParserImpl.h | 13 +++---
mlir/lib/AsmParser/AttributeParser.cpp | 24 +++++-----
mlir/lib/AsmParser/Parser.cpp | 63 +++++++++++++-------------
mlir/lib/AsmParser/Parser.h | 12 ++---
4 files changed, 57 insertions(+), 55 deletions(-)
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 04250f63dcd253..1e6cbc0ec51beb 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -287,13 +287,13 @@ class AsmParserImpl : public BaseT {
APFloat &result) override {
bool isNegative = parser.consumeIf(Token::minus);
Token curTok = parser.getToken();
- SMLoc loc = curTok.getLoc();
+ auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); };
// Check for a floating point value.
if (curTok.is(Token::floatliteral)) {
auto val = curTok.getFloatingPointValue();
if (!val)
- return emitError(loc, "floating point value too large");
+ return emitErrorAtTok() << "floating point value too large";
parser.consumeToken(Token::floatliteral);
result = APFloat(isNegative ? -*val : *val);
bool losesInfo;
@@ -303,10 +303,9 @@ class AsmParserImpl : public BaseT {
// Check for a hexadecimal float value.
if (curTok.is(Token::integer)) {
- std::optional<APFloat> apResult;
- if (failed(parser.parseFloatFromIntegerLiteral(
- apResult, curTok, isNegative, semantics,
- APFloat::semanticsSizeInBits(semantics))))
+ FailureOr<APFloat> apResult = parseFloatFromIntegerLiteral(
+ emitErrorAtTok, curTok, isNegative, semantics);
+ if (failed(apResult))
return failure();
result = *apResult;
@@ -314,7 +313,7 @@ class AsmParserImpl : public BaseT {
return success();
}
- return emitError(loc, "expected floating point literal");
+ return emitErrorAtTok() << "expected floating point literal";
}
/// Parse a floating point value from the stream.
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index efa65e49abc33b..ba9be3b030453a 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -422,10 +422,10 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
}
if (auto floatType = dyn_cast<FloatType>(type)) {
- std::optional<APFloat> result;
- if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
- floatType.getFloatSemantics(),
- floatType.getWidth())))
+ auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); };
+ FailureOr<APFloat> result = parseFloatFromIntegerLiteral(
+ emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics());
+ if (failed(result))
return Attribute();
return FloatAttr::get(floatType, *result);
}
@@ -661,10 +661,10 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
// Handle hexadecimal float literals.
if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) {
- std::optional<APFloat> result;
- if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
- eltTy.getFloatSemantics(),
- eltTy.getWidth())))
+ auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+ FailureOr<APFloat> result = parseFloatFromIntegerLiteral(
+ emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
+ if (failed(result))
return failure();
floatValues.push_back(*result);
@@ -911,10 +911,12 @@ ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
auto floatType = cast<FloatType>(type);
if (p.consumeIf(Token::integer)) {
// Parse an integer literal as a float.
- if (p.parseFloatFromIntegerLiteral(result, token, isNegative,
- floatType.getFloatSemantics(),
- floatType.getWidth()))
+ auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+ FailureOr<APFloat> fromIntLit = parseFloatFromIntegerLiteral(
+ emitErrorAtTok, token, isNegative, floatType.getFloatSemantics());
+ if (failed(fromIntLit))
return failure();
+ result = *fromIntLit;
} else if (p.consumeIf(Token::floatliteral)) {
// Parse a floating point literal.
std::optional<double> val = token.getFloatingPointValue();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 8f19487d80fa39..ac7eec931b1250 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -67,6 +67,38 @@
using namespace mlir;
using namespace mlir::detail;
+/// Parse a floating point value from an integer literal token.
+FailureOr<APFloat> detail::parseFloatFromIntegerLiteral(
+ function_ref<InFlightDiagnostic()> emitError, const Token &tok,
+ bool isNegative, const llvm::fltSemantics &semantics) {
+ StringRef spelling = tok.getSpelling();
+ bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+ if (!isHex) {
+ auto error = emitError();
+ error << "unexpected decimal integer literal for a "
+ "floating point value";
+ error.attachNote() << "add a trailing dot to make the literal a float";
+ return failure();
+ }
+ if (isNegative) {
+ emitError() << "hexadecimal float literal should not have a "
+ "leading minus";
+ return failure();
+ }
+
+ APInt intValue;
+ tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
+ auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics);
+ if (intValue.getActiveBits() > typeSizeInBits) {
+ return emitError() << "hexadecimal float constant out of range for type";
+ return failure();
+ }
+
+ APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
+ intValue.getRawData());
+ return APFloat(semantics, truncatedValue);
+}
+
//===----------------------------------------------------------------------===//
// CodeComplete
//===----------------------------------------------------------------------===//
@@ -347,37 +379,6 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
return success();
}
-/// Parse a floating point value from an integer literal token.
-ParseResult Parser::parseFloatFromIntegerLiteral(
- std::optional<APFloat> &result, const Token &tok, bool isNegative,
- const llvm::fltSemantics &semantics, size_t typeSizeInBits) {
- SMLoc loc = tok.getLoc();
- StringRef spelling = tok.getSpelling();
- bool isHex = spelling.size() > 1 && spelling[1] == 'x';
- if (!isHex) {
- return emitError(loc, "unexpected decimal integer literal for a "
- "floating point value")
- .attachNote()
- << "add a trailing dot to make the literal a float";
- }
- if (isNegative) {
- return emitError(loc, "hexadecimal float literal should not have a "
- "leading minus");
- }
-
- APInt intValue;
- tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
- if (intValue.getActiveBits() > typeSizeInBits)
- return emitError(loc, "hexadecimal float constant out of range for type");
-
- APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
- intValue.getRawData());
-
- result.emplace(semantics, truncatedValue);
-
- return success();
-}
-
ParseResult Parser::parseOptionalKeyword(StringRef *keyword) {
// Check that the current token is a keyword.
if (!isCurrentTokenAKeyword())
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index bf91831798056b..fa29264ffe506a 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -16,6 +16,12 @@
namespace mlir {
namespace detail {
+/// Parse a floating point value from an integer literal token.
+FailureOr<APFloat>
+parseFloatFromIntegerLiteral(function_ref<InFlightDiagnostic()> emitError,
+ const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics);
+
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
@@ -151,12 +157,6 @@ class Parser {
/// Parse an optional integer value only in decimal format from the stream.
OptionalParseResult parseOptionalDecimalInteger(APInt &result);
- /// Parse a floating point value from an integer literal token.
- ParseResult parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
- const Token &tok, bool isNegative,
- const llvm::fltSemantics &semantics,
- size_t typeSizeInBits);
-
/// Returns true if the current token corresponds to a keyword.
bool isCurrentTokenAKeyword() const {
return getToken().isAny(Token::bare_identifier, Token::inttype) ||
>From 94bcf6f9dd075c314dfc6d2bb5dbde00366f52e0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 14 Nov 2024 07:43:08 +0100
Subject: [PATCH 2/4] [mlir][Parser] Deduplicate fp parsing functionality
---
mlir/lib/AsmParser/AsmParserImpl.h | 33 ++-------
mlir/lib/AsmParser/AttributeParser.cpp | 71 ++++----------------
mlir/lib/AsmParser/Parser.cpp | 23 +++++++
mlir/lib/AsmParser/Parser.h | 6 ++
mlir/test/IR/invalid-builtin-attributes.mlir | 10 +--
5 files changed, 56 insertions(+), 87 deletions(-)
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 1e6cbc0ec51beb..bbd70d5980f8fe 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -288,32 +288,13 @@ class AsmParserImpl : public BaseT {
bool isNegative = parser.consumeIf(Token::minus);
Token curTok = parser.getToken();
auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); };
-
- // Check for a floating point value.
- if (curTok.is(Token::floatliteral)) {
- auto val = curTok.getFloatingPointValue();
- if (!val)
- return emitErrorAtTok() << "floating point value too large";
- parser.consumeToken(Token::floatliteral);
- result = APFloat(isNegative ? -*val : *val);
- bool losesInfo;
- result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
- return success();
- }
-
- // Check for a hexadecimal float value.
- if (curTok.is(Token::integer)) {
- FailureOr<APFloat> apResult = parseFloatFromIntegerLiteral(
- emitErrorAtTok, curTok, isNegative, semantics);
- if (failed(apResult))
- return failure();
-
- result = *apResult;
- parser.consumeToken(Token::integer);
- return success();
- }
-
- return emitErrorAtTok() << "expected floating point literal";
+ FailureOr<APFloat> apResult =
+ parseFloatFromLiteral(emitErrorAtTok, curTok, isNegative, semantics);
+ if (failed(apResult))
+ return failure();
+ parser.consumeToken();
+ result = *apResult;
+ return success();
}
/// Parse a floating point value from the stream.
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index ba9be3b030453a..9ebada076cd042 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -658,36 +658,12 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first;
const Token &token = signAndToken.second;
-
- // Handle hexadecimal float literals.
- if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) {
- auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
- FailureOr<APFloat> result = parseFloatFromIntegerLiteral(
- emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
- if (failed(result))
- return failure();
-
- floatValues.push_back(*result);
- continue;
- }
-
- // Check to see if any decimal integers or booleans were parsed.
- if (!token.is(Token::floatliteral))
- return p.emitError()
- << "expected floating-point elements, but parsed integer";
-
- // Build the float values from tokens.
- auto val = token.getFloatingPointValue();
- if (!val)
- return p.emitError("floating point value too large for attribute");
-
- APFloat apVal(isNegative ? -*val : *val);
- if (!eltTy.isF64()) {
- bool unused;
- apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
- &unused);
- }
- floatValues.push_back(apVal);
+ auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+ FailureOr<APFloat> result = parseFloatFromLiteral(
+ emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
+ if (failed(result))
+ return failure();
+ floatValues.push_back(*result);
}
return success();
}
@@ -905,34 +881,15 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
bool isNegative = p.consumeIf(Token::minus);
-
Token token = p.getToken();
- std::optional<APFloat> result;
- auto floatType = cast<FloatType>(type);
- if (p.consumeIf(Token::integer)) {
- // Parse an integer literal as a float.
- auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
- FailureOr<APFloat> fromIntLit = parseFloatFromIntegerLiteral(
- emitErrorAtTok, token, isNegative, floatType.getFloatSemantics());
- if (failed(fromIntLit))
- return failure();
- result = *fromIntLit;
- } else if (p.consumeIf(Token::floatliteral)) {
- // Parse a floating point literal.
- std::optional<double> val = token.getFloatingPointValue();
- if (!val)
- return failure();
- result = APFloat(isNegative ? -*val : *val);
- if (!type.isF64()) {
- bool unused;
- result->convert(floatType.getFloatSemantics(),
- APFloat::rmNearestTiesToEven, &unused);
- }
- } else {
- return p.emitError("expected integer or floating point literal");
- }
-
- append(result->bitcastToAPInt());
+ auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+ FailureOr<APFloat> fromIntLit =
+ parseFloatFromLiteral(emitErrorAtTok, token, isNegative,
+ cast<FloatType>(type).getFloatSemantics());
+ if (failed(fromIntLit))
+ return failure();
+ p.consumeToken();
+ append(fromIntLit->bitcastToAPInt());
return success();
}
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index ac7eec931b1250..15f3dd7a66c358 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -99,6 +99,29 @@ FailureOr<APFloat> detail::parseFloatFromIntegerLiteral(
return APFloat(semantics, truncatedValue);
}
+FailureOr<APFloat>
+detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
+ const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics) {
+ // Check for a floating point value.
+ if (tok.is(Token::floatliteral)) {
+ auto val = tok.getFloatingPointValue();
+ if (!val)
+ return emitError() << "floating point value too large";
+
+ APFloat result(isNegative ? -*val : *val);
+ bool unused;
+ result.convert(semantics, APFloat::rmNearestTiesToEven, &unused);
+ return result;
+ }
+
+ // Check for a hexadecimal float value.
+ if (tok.is(Token::integer))
+ return parseFloatFromIntegerLiteral(emitError, tok, isNegative, semantics);
+
+ return emitError() << "expected floating point literal";
+}
+
//===----------------------------------------------------------------------===//
// CodeComplete
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index fa29264ffe506a..ab445476a91923 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -22,6 +22,12 @@ parseFloatFromIntegerLiteral(function_ref<InFlightDiagnostic()> emitError,
const Token &tok, bool isNegative,
const llvm::fltSemantics &semantics);
+/// Parse a floating point value from a literal.
+FailureOr<APFloat>
+parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
+ const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics);
+
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index 431c7b12b8f5fe..5098fe751fd01f 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -45,7 +45,8 @@ func.func @elementsattr_floattype1() -> () {
// -----
func.func @elementsattr_floattype2() -> () {
- // expected-error@+1 {{expected floating-point elements, but parsed integer}}
+ // expected-error@below {{unexpected decimal integer literal for a floating point value}}
+ // expected-note@below {{add a trailing dot to make the literal a float}}
"foo"(){bar = dense<[4]> : tensor<1xf32>} : () -> ()
}
@@ -138,21 +139,22 @@ func.func @float_in_int_tensor() {
// -----
func.func @float_in_bool_tensor() {
- // expected-error @+1 {{expected integer elements, but parsed floating-point}}
+ // expected-error@below {{expected integer elements, but parsed floating-point}}
"foo"() {bar = dense<[true, 42.0]> : tensor<2xi1>} : () -> ()
}
// -----
func.func @decimal_int_in_float_tensor() {
- // expected-error @+1 {{expected floating-point elements, but parsed integer}}
+ // expected-error@below {{unexpected decimal integer literal for a floating point value}}
+ // expected-note@below {{add a trailing dot to make the literal a float}}
"foo"() {bar = dense<[42, 42.0]> : tensor<2xf32>} : () -> ()
}
// -----
func.func @bool_in_float_tensor() {
- // expected-error @+1 {{expected floating-point elements, but parsed integer}}
+ // expected-error @+1 {{expected floating point literal}}
"foo"() {bar = dense<[42.0, true]> : tensor<2xf32>} : () -> ()
}
>From 6a59007385b335a54739d4b060103384b6c54d20 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 18 Nov 2024 09:06:22 +0100
Subject: [PATCH 3/4] address comments
---
mlir/lib/AsmParser/AsmParserImpl.h | 3 +-
mlir/lib/AsmParser/AttributeParser.cpp | 14 ++--
mlir/lib/AsmParser/Parser.cpp | 110 ++++++++++++-------------
mlir/lib/AsmParser/Parser.h | 20 ++---
4 files changed, 70 insertions(+), 77 deletions(-)
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index bbd70d5980f8fe..d9d49d53a407d0 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -287,9 +287,8 @@ class AsmParserImpl : public BaseT {
APFloat &result) override {
bool isNegative = parser.consumeIf(Token::minus);
Token curTok = parser.getToken();
- auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); };
FailureOr<APFloat> apResult =
- parseFloatFromLiteral(emitErrorAtTok, curTok, isNegative, semantics);
+ parser.parseFloatFromLiteral(curTok, isNegative, semantics);
if (failed(apResult))
return failure();
parser.consumeToken();
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 9ebada076cd042..0df3d492f411ec 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -422,9 +422,8 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
}
if (auto floatType = dyn_cast<FloatType>(type)) {
- auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); };
FailureOr<APFloat> result = parseFloatFromIntegerLiteral(
- emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics());
+ tok, isNegative, floatType.getFloatSemantics());
if (failed(result))
return Attribute();
return FloatAttr::get(floatType, *result);
@@ -658,9 +657,8 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first;
const Token &token = signAndToken.second;
- auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
- FailureOr<APFloat> result = parseFloatFromLiteral(
- emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
+ FailureOr<APFloat> result =
+ p.parseFloatFromLiteral(token, isNegative, eltTy.getFloatSemantics());
if (failed(result))
return failure();
floatValues.push_back(*result);
@@ -882,10 +880,8 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
bool isNegative = p.consumeIf(Token::minus);
Token token = p.getToken();
- auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
- FailureOr<APFloat> fromIntLit =
- parseFloatFromLiteral(emitErrorAtTok, token, isNegative,
- cast<FloatType>(type).getFloatSemantics());
+ FailureOr<APFloat> fromIntLit = p.parseFloatFromLiteral(
+ token, isNegative, cast<FloatType>(type).getFloatSemantics());
if (failed(fromIntLit))
return failure();
p.consumeToken();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 15f3dd7a66c358..e10b87ad43b139 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -67,61 +67,6 @@
using namespace mlir;
using namespace mlir::detail;
-/// Parse a floating point value from an integer literal token.
-FailureOr<APFloat> detail::parseFloatFromIntegerLiteral(
- function_ref<InFlightDiagnostic()> emitError, const Token &tok,
- bool isNegative, const llvm::fltSemantics &semantics) {
- StringRef spelling = tok.getSpelling();
- bool isHex = spelling.size() > 1 && spelling[1] == 'x';
- if (!isHex) {
- auto error = emitError();
- error << "unexpected decimal integer literal for a "
- "floating point value";
- error.attachNote() << "add a trailing dot to make the literal a float";
- return failure();
- }
- if (isNegative) {
- emitError() << "hexadecimal float literal should not have a "
- "leading minus";
- return failure();
- }
-
- APInt intValue;
- tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
- auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics);
- if (intValue.getActiveBits() > typeSizeInBits) {
- return emitError() << "hexadecimal float constant out of range for type";
- return failure();
- }
-
- APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
- intValue.getRawData());
- return APFloat(semantics, truncatedValue);
-}
-
-FailureOr<APFloat>
-detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
- const Token &tok, bool isNegative,
- const llvm::fltSemantics &semantics) {
- // Check for a floating point value.
- if (tok.is(Token::floatliteral)) {
- auto val = tok.getFloatingPointValue();
- if (!val)
- return emitError() << "floating point value too large";
-
- APFloat result(isNegative ? -*val : *val);
- bool unused;
- result.convert(semantics, APFloat::rmNearestTiesToEven, &unused);
- return result;
- }
-
- // Check for a hexadecimal float value.
- if (tok.is(Token::integer))
- return parseFloatFromIntegerLiteral(emitError, tok, isNegative, semantics);
-
- return emitError() << "expected floating point literal";
-}
-
//===----------------------------------------------------------------------===//
// CodeComplete
//===----------------------------------------------------------------------===//
@@ -402,6 +347,61 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
return success();
}
+FailureOr<APFloat>
+Parser::parseFloatFromLiteral(const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics) {
+ // Check for a floating point value.
+ if (tok.is(Token::floatliteral)) {
+ auto val = tok.getFloatingPointValue();
+ if (!val)
+ return emitError(tok.getLoc()) << "floating point value too large";
+
+ APFloat result(isNegative ? -*val : *val);
+ bool unused;
+ result.convert(semantics, APFloat::rmNearestTiesToEven, &unused);
+ return result;
+ }
+
+ // Check for a hexadecimal float value.
+ if (tok.is(Token::integer))
+ return parseFloatFromIntegerLiteral(tok, isNegative, semantics);
+
+ return emitError(tok.getLoc()) << "expected floating point literal";
+}
+
+/// Parse a floating point value from an integer literal token.
+FailureOr<APFloat>
+Parser::parseFloatFromIntegerLiteral(const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics) {
+ StringRef spelling = tok.getSpelling();
+ bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+ if (!isHex) {
+ auto error = emitError(tok.getLoc());
+ error << "unexpected decimal integer literal for a "
+ "floating point value";
+ error.attachNote() << "add a trailing dot to make the literal a float";
+ return failure();
+ }
+ if (isNegative) {
+ emitError(tok.getLoc()) << "hexadecimal float literal should not have a "
+ "leading minus";
+ return failure();
+ }
+
+ APInt intValue;
+ tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
+ auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics);
+ if (intValue.getActiveBits() > typeSizeInBits) {
+ return emitError(tok.getLoc())
+ << "hexadecimal float constant out of range for type";
+ return failure();
+ }
+
+ APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
+ intValue.getRawData());
+ return APFloat(semantics, truncatedValue);
+}
+
ParseResult Parser::parseOptionalKeyword(StringRef *keyword) {
// Check that the current token is a keyword.
if (!isCurrentTokenAKeyword())
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index ab445476a91923..15c4990de58344 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -16,17 +16,6 @@
namespace mlir {
namespace detail {
-/// Parse a floating point value from an integer literal token.
-FailureOr<APFloat>
-parseFloatFromIntegerLiteral(function_ref<InFlightDiagnostic()> emitError,
- const Token &tok, bool isNegative,
- const llvm::fltSemantics &semantics);
-
-/// Parse a floating point value from a literal.
-FailureOr<APFloat>
-parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
- const Token &tok, bool isNegative,
- const llvm::fltSemantics &semantics);
//===----------------------------------------------------------------------===//
// Parser
@@ -163,6 +152,15 @@ class Parser {
/// Parse an optional integer value only in decimal format from the stream.
OptionalParseResult parseOptionalDecimalInteger(APInt &result);
+ /// Parse a floating point value from a literal.
+ FailureOr<APFloat> parseFloatFromLiteral(const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics);
+
+ /// Parse a floating point value from an integer literal token.
+ FailureOr<APFloat>
+ parseFloatFromIntegerLiteral(const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics);
+
/// Returns true if the current token corresponds to a keyword.
bool isCurrentTokenAKeyword() const {
return getToken().isAny(Token::bare_identifier, Token::inttype) ||
>From 1f81180b91597f0813d46cb2f5b422d7678a2a3c Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 18 Nov 2024 09:16:19 +0100
Subject: [PATCH 4/4] address comments 2
---
mlir/lib/AsmParser/AsmParserImpl.h | 6 ++--
mlir/lib/AsmParser/AttributeParser.cpp | 19 ++++++------
mlir/lib/AsmParser/Parser.cpp | 42 +++++++++++++-------------
mlir/lib/AsmParser/Parser.h | 11 ++++---
4 files changed, 40 insertions(+), 38 deletions(-)
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index d9d49d53a407d0..d5b72d63813a4e 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -287,9 +287,9 @@ class AsmParserImpl : public BaseT {
APFloat &result) override {
bool isNegative = parser.consumeIf(Token::minus);
Token curTok = parser.getToken();
- FailureOr<APFloat> apResult =
- parser.parseFloatFromLiteral(curTok, isNegative, semantics);
- if (failed(apResult))
+ std::optional<APFloat> apResult;
+ if (failed(parser.parseFloatFromLiteral(apResult, curTok, isNegative,
+ semantics)))
return failure();
parser.consumeToken();
result = *apResult;
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 0df3d492f411ec..ff616dac9625b4 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -422,9 +422,9 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
}
if (auto floatType = dyn_cast<FloatType>(type)) {
- FailureOr<APFloat> result = parseFloatFromIntegerLiteral(
- tok, isNegative, floatType.getFloatSemantics());
- if (failed(result))
+ std::optional<APFloat> result;
+ if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
+ floatType.getFloatSemantics())))
return Attribute();
return FloatAttr::get(floatType, *result);
}
@@ -657,9 +657,9 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first;
const Token &token = signAndToken.second;
- FailureOr<APFloat> result =
- p.parseFloatFromLiteral(token, isNegative, eltTy.getFloatSemantics());
- if (failed(result))
+ std::optional<APFloat> result;
+ if (failed(p.parseFloatFromLiteral(result, token, isNegative,
+ eltTy.getFloatSemantics())))
return failure();
floatValues.push_back(*result);
}
@@ -880,9 +880,10 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
bool isNegative = p.consumeIf(Token::minus);
Token token = p.getToken();
- FailureOr<APFloat> fromIntLit = p.parseFloatFromLiteral(
- token, isNegative, cast<FloatType>(type).getFloatSemantics());
- if (failed(fromIntLit))
+ std::optional<APFloat> fromIntLit;
+ if (failed(
+ p.parseFloatFromLiteral(fromIntLit, token, isNegative,
+ cast<FloatType>(type).getFloatSemantics())))
return failure();
p.consumeToken();
append(fromIntLit->bitcastToAPInt());
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index e10b87ad43b139..e3db248164672c 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -347,59 +347,59 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
return success();
}
-FailureOr<APFloat>
-Parser::parseFloatFromLiteral(const Token &tok, bool isNegative,
- const llvm::fltSemantics &semantics) {
+ParseResult Parser::parseFloatFromLiteral(std::optional<APFloat> &result,
+ const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics) {
// Check for a floating point value.
if (tok.is(Token::floatliteral)) {
auto val = tok.getFloatingPointValue();
if (!val)
return emitError(tok.getLoc()) << "floating point value too large";
- APFloat result(isNegative ? -*val : *val);
+ result.emplace(isNegative ? -*val : *val);
bool unused;
- result.convert(semantics, APFloat::rmNearestTiesToEven, &unused);
- return result;
+ result->convert(semantics, APFloat::rmNearestTiesToEven, &unused);
+ return success();
}
// Check for a hexadecimal float value.
if (tok.is(Token::integer))
- return parseFloatFromIntegerLiteral(tok, isNegative, semantics);
+ return parseFloatFromIntegerLiteral(result, tok, isNegative, semantics);
return emitError(tok.getLoc()) << "expected floating point literal";
}
/// Parse a floating point value from an integer literal token.
-FailureOr<APFloat>
-Parser::parseFloatFromIntegerLiteral(const Token &tok, bool isNegative,
+ParseResult
+Parser::parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
+ const Token &tok, bool isNegative,
const llvm::fltSemantics &semantics) {
StringRef spelling = tok.getSpelling();
bool isHex = spelling.size() > 1 && spelling[1] == 'x';
if (!isHex) {
- auto error = emitError(tok.getLoc());
- error << "unexpected decimal integer literal for a "
- "floating point value";
- error.attachNote() << "add a trailing dot to make the literal a float";
- return failure();
+ return emitError(tok.getLoc(), "unexpected decimal integer literal for a "
+ "floating point value")
+ .attachNote()
+ << "add a trailing dot to make the literal a float";
}
if (isNegative) {
- emitError(tok.getLoc()) << "hexadecimal float literal should not have a "
- "leading minus";
- return failure();
+ return emitError(tok.getLoc(),
+ "hexadecimal float literal should not have a "
+ "leading minus");
}
APInt intValue;
tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics);
if (intValue.getActiveBits() > typeSizeInBits) {
- return emitError(tok.getLoc())
- << "hexadecimal float constant out of range for type";
- return failure();
+ return emitError(tok.getLoc(),
+ "hexadecimal float constant out of range for type");
}
APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
intValue.getRawData());
- return APFloat(semantics, truncatedValue);
+ result.emplace(semantics, truncatedValue);
+ return success();
}
ParseResult Parser::parseOptionalKeyword(StringRef *keyword) {
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 15c4990de58344..4979cfc6e69e41 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -153,13 +153,14 @@ class Parser {
OptionalParseResult parseOptionalDecimalInteger(APInt &result);
/// Parse a floating point value from a literal.
- FailureOr<APFloat> parseFloatFromLiteral(const Token &tok, bool isNegative,
- const llvm::fltSemantics &semantics);
+ ParseResult parseFloatFromLiteral(std::optional<APFloat> &result,
+ const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics);
/// Parse a floating point value from an integer literal token.
- FailureOr<APFloat>
- parseFloatFromIntegerLiteral(const Token &tok, bool isNegative,
- const llvm::fltSemantics &semantics);
+ ParseResult parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
+ const Token &tok, bool isNegative,
+ const llvm::fltSemantics &semantics);
/// Returns true if the current token corresponds to a keyword.
bool isCurrentTokenAKeyword() const {
More information about the Mlir-commits
mailing list