[Mlir-commits] [mlir] [mlir][Parser] Deduplicate floating-point parsing functionality (PR #116172)

Matthias Springer llvmlistbot at llvm.org
Mon Nov 18 00:21:20 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116172

>From 093da29f4a33c9baca81c03cb7482d34a5283f69 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 14 Nov 2024 06:55:21 +0100
Subject: [PATCH 1/4] [mlir][Parser][NFC] Make `parseFloatFromIntegerLiteral` a
 standalone function

---
 mlir/lib/AsmParser/AsmParserImpl.h     | 13 +++---
 mlir/lib/AsmParser/AttributeParser.cpp | 24 +++++-----
 mlir/lib/AsmParser/Parser.cpp          | 63 +++++++++++++-------------
 mlir/lib/AsmParser/Parser.h            | 12 ++---
 4 files changed, 57 insertions(+), 55 deletions(-)

diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 04250f63dcd253..1e6cbc0ec51beb 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -287,13 +287,13 @@ class AsmParserImpl : public BaseT {
                          APFloat &result) override {
     bool isNegative = parser.consumeIf(Token::minus);
     Token curTok = parser.getToken();
-    SMLoc loc = curTok.getLoc();
+    auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); };
 
     // Check for a floating point value.
     if (curTok.is(Token::floatliteral)) {
       auto val = curTok.getFloatingPointValue();
       if (!val)
-        return emitError(loc, "floating point value too large");
+        return emitErrorAtTok() << "floating point value too large";
       parser.consumeToken(Token::floatliteral);
       result = APFloat(isNegative ? -*val : *val);
       bool losesInfo;
@@ -303,10 +303,9 @@ class AsmParserImpl : public BaseT {
 
     // Check for a hexadecimal float value.
     if (curTok.is(Token::integer)) {
-      std::optional<APFloat> apResult;
-      if (failed(parser.parseFloatFromIntegerLiteral(
-              apResult, curTok, isNegative, semantics,
-              APFloat::semanticsSizeInBits(semantics))))
+      FailureOr<APFloat> apResult = parseFloatFromIntegerLiteral(
+          emitErrorAtTok, curTok, isNegative, semantics);
+      if (failed(apResult))
         return failure();
 
       result = *apResult;
@@ -314,7 +313,7 @@ class AsmParserImpl : public BaseT {
       return success();
     }
 
-    return emitError(loc, "expected floating point literal");
+    return emitErrorAtTok() << "expected floating point literal";
   }
 
   /// Parse a floating point value from the stream.
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index efa65e49abc33b..ba9be3b030453a 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -422,10 +422,10 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
   }
 
   if (auto floatType = dyn_cast<FloatType>(type)) {
-    std::optional<APFloat> result;
-    if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
-                                            floatType.getFloatSemantics(),
-                                            floatType.getWidth())))
+    auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); };
+    FailureOr<APFloat> result = parseFloatFromIntegerLiteral(
+        emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics());
+    if (failed(result))
       return Attribute();
     return FloatAttr::get(floatType, *result);
   }
@@ -661,10 +661,10 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
 
     // Handle hexadecimal float literals.
     if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) {
-      std::optional<APFloat> result;
-      if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
-                                                eltTy.getFloatSemantics(),
-                                                eltTy.getWidth())))
+      auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+      FailureOr<APFloat> result = parseFloatFromIntegerLiteral(
+          emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
+      if (failed(result))
         return failure();
 
       floatValues.push_back(*result);
@@ -911,10 +911,12 @@ ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
   auto floatType = cast<FloatType>(type);
   if (p.consumeIf(Token::integer)) {
     // Parse an integer literal as a float.
-    if (p.parseFloatFromIntegerLiteral(result, token, isNegative,
-                                       floatType.getFloatSemantics(),
-                                       floatType.getWidth()))
+    auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+    FailureOr<APFloat> fromIntLit = parseFloatFromIntegerLiteral(
+        emitErrorAtTok, token, isNegative, floatType.getFloatSemantics());
+    if (failed(fromIntLit))
       return failure();
+    result = *fromIntLit;
   } else if (p.consumeIf(Token::floatliteral)) {
     // Parse a floating point literal.
     std::optional<double> val = token.getFloatingPointValue();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 8f19487d80fa39..ac7eec931b1250 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -67,6 +67,38 @@
 using namespace mlir;
 using namespace mlir::detail;
 
+/// Parse a floating point value from an integer literal token.
+FailureOr<APFloat> detail::parseFloatFromIntegerLiteral(
+    function_ref<InFlightDiagnostic()> emitError, const Token &tok,
+    bool isNegative, const llvm::fltSemantics &semantics) {
+  StringRef spelling = tok.getSpelling();
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+  if (!isHex) {
+    auto error = emitError();
+    error << "unexpected decimal integer literal for a "
+             "floating point value";
+    error.attachNote() << "add a trailing dot to make the literal a float";
+    return failure();
+  }
+  if (isNegative) {
+    emitError() << "hexadecimal float literal should not have a "
+                   "leading minus";
+    return failure();
+  }
+
+  APInt intValue;
+  tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
+  auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics);
+  if (intValue.getActiveBits() > typeSizeInBits) {
+    return emitError() << "hexadecimal float constant out of range for type";
+    return failure();
+  }
+
+  APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
+                       intValue.getRawData());
+  return APFloat(semantics, truncatedValue);
+}
+
 //===----------------------------------------------------------------------===//
 // CodeComplete
 //===----------------------------------------------------------------------===//
@@ -347,37 +379,6 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
   return success();
 }
 
-/// Parse a floating point value from an integer literal token.
-ParseResult Parser::parseFloatFromIntegerLiteral(
-    std::optional<APFloat> &result, const Token &tok, bool isNegative,
-    const llvm::fltSemantics &semantics, size_t typeSizeInBits) {
-  SMLoc loc = tok.getLoc();
-  StringRef spelling = tok.getSpelling();
-  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
-  if (!isHex) {
-    return emitError(loc, "unexpected decimal integer literal for a "
-                          "floating point value")
-               .attachNote()
-           << "add a trailing dot to make the literal a float";
-  }
-  if (isNegative) {
-    return emitError(loc, "hexadecimal float literal should not have a "
-                          "leading minus");
-  }
-
-  APInt intValue;
-  tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
-  if (intValue.getActiveBits() > typeSizeInBits)
-    return emitError(loc, "hexadecimal float constant out of range for type");
-
-  APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
-                       intValue.getRawData());
-
-  result.emplace(semantics, truncatedValue);
-
-  return success();
-}
-
 ParseResult Parser::parseOptionalKeyword(StringRef *keyword) {
   // Check that the current token is a keyword.
   if (!isCurrentTokenAKeyword())
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index bf91831798056b..fa29264ffe506a 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -16,6 +16,12 @@
 
 namespace mlir {
 namespace detail {
+/// Parse a floating point value from an integer literal token.
+FailureOr<APFloat>
+parseFloatFromIntegerLiteral(function_ref<InFlightDiagnostic()> emitError,
+                             const Token &tok, bool isNegative,
+                             const llvm::fltSemantics &semantics);
+
 //===----------------------------------------------------------------------===//
 // Parser
 //===----------------------------------------------------------------------===//
@@ -151,12 +157,6 @@ class Parser {
   /// Parse an optional integer value only in decimal format from the stream.
   OptionalParseResult parseOptionalDecimalInteger(APInt &result);
 
-  /// Parse a floating point value from an integer literal token.
-  ParseResult parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
-                                           const Token &tok, bool isNegative,
-                                           const llvm::fltSemantics &semantics,
-                                           size_t typeSizeInBits);
-
   /// Returns true if the current token corresponds to a keyword.
   bool isCurrentTokenAKeyword() const {
     return getToken().isAny(Token::bare_identifier, Token::inttype) ||

>From 94bcf6f9dd075c314dfc6d2bb5dbde00366f52e0 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 14 Nov 2024 07:43:08 +0100
Subject: [PATCH 2/4] [mlir][Parser] Deduplicate fp parsing functionality

---
 mlir/lib/AsmParser/AsmParserImpl.h           | 33 ++-------
 mlir/lib/AsmParser/AttributeParser.cpp       | 71 ++++----------------
 mlir/lib/AsmParser/Parser.cpp                | 23 +++++++
 mlir/lib/AsmParser/Parser.h                  |  6 ++
 mlir/test/IR/invalid-builtin-attributes.mlir | 10 +--
 5 files changed, 56 insertions(+), 87 deletions(-)

diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 1e6cbc0ec51beb..bbd70d5980f8fe 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -288,32 +288,13 @@ class AsmParserImpl : public BaseT {
     bool isNegative = parser.consumeIf(Token::minus);
     Token curTok = parser.getToken();
     auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); };
-
-    // Check for a floating point value.
-    if (curTok.is(Token::floatliteral)) {
-      auto val = curTok.getFloatingPointValue();
-      if (!val)
-        return emitErrorAtTok() << "floating point value too large";
-      parser.consumeToken(Token::floatliteral);
-      result = APFloat(isNegative ? -*val : *val);
-      bool losesInfo;
-      result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
-      return success();
-    }
-
-    // Check for a hexadecimal float value.
-    if (curTok.is(Token::integer)) {
-      FailureOr<APFloat> apResult = parseFloatFromIntegerLiteral(
-          emitErrorAtTok, curTok, isNegative, semantics);
-      if (failed(apResult))
-        return failure();
-
-      result = *apResult;
-      parser.consumeToken(Token::integer);
-      return success();
-    }
-
-    return emitErrorAtTok() << "expected floating point literal";
+    FailureOr<APFloat> apResult =
+        parseFloatFromLiteral(emitErrorAtTok, curTok, isNegative, semantics);
+    if (failed(apResult))
+      return failure();
+    parser.consumeToken();
+    result = *apResult;
+    return success();
   }
 
   /// Parse a floating point value from the stream.
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index ba9be3b030453a..9ebada076cd042 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -658,36 +658,12 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
   for (const auto &signAndToken : storage) {
     bool isNegative = signAndToken.first;
     const Token &token = signAndToken.second;
-
-    // Handle hexadecimal float literals.
-    if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) {
-      auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
-      FailureOr<APFloat> result = parseFloatFromIntegerLiteral(
-          emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
-      if (failed(result))
-        return failure();
-
-      floatValues.push_back(*result);
-      continue;
-    }
-
-    // Check to see if any decimal integers or booleans were parsed.
-    if (!token.is(Token::floatliteral))
-      return p.emitError()
-             << "expected floating-point elements, but parsed integer";
-
-    // Build the float values from tokens.
-    auto val = token.getFloatingPointValue();
-    if (!val)
-      return p.emitError("floating point value too large for attribute");
-
-    APFloat apVal(isNegative ? -*val : *val);
-    if (!eltTy.isF64()) {
-      bool unused;
-      apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
-                    &unused);
-    }
-    floatValues.push_back(apVal);
+    auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+    FailureOr<APFloat> result = parseFloatFromLiteral(
+        emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
+    if (failed(result))
+      return failure();
+    floatValues.push_back(*result);
   }
   return success();
 }
@@ -905,34 +881,15 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
 
 ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
   bool isNegative = p.consumeIf(Token::minus);
-
   Token token = p.getToken();
-  std::optional<APFloat> result;
-  auto floatType = cast<FloatType>(type);
-  if (p.consumeIf(Token::integer)) {
-    // Parse an integer literal as a float.
-    auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
-    FailureOr<APFloat> fromIntLit = parseFloatFromIntegerLiteral(
-        emitErrorAtTok, token, isNegative, floatType.getFloatSemantics());
-    if (failed(fromIntLit))
-      return failure();
-    result = *fromIntLit;
-  } else if (p.consumeIf(Token::floatliteral)) {
-    // Parse a floating point literal.
-    std::optional<double> val = token.getFloatingPointValue();
-    if (!val)
-      return failure();
-    result = APFloat(isNegative ? -*val : *val);
-    if (!type.isF64()) {
-      bool unused;
-      result->convert(floatType.getFloatSemantics(),
-                      APFloat::rmNearestTiesToEven, &unused);
-    }
-  } else {
-    return p.emitError("expected integer or floating point literal");
-  }
-
-  append(result->bitcastToAPInt());
+  auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+  FailureOr<APFloat> fromIntLit =
+      parseFloatFromLiteral(emitErrorAtTok, token, isNegative,
+                            cast<FloatType>(type).getFloatSemantics());
+  if (failed(fromIntLit))
+    return failure();
+  p.consumeToken();
+  append(fromIntLit->bitcastToAPInt());
   return success();
 }
 
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index ac7eec931b1250..15f3dd7a66c358 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -99,6 +99,29 @@ FailureOr<APFloat> detail::parseFloatFromIntegerLiteral(
   return APFloat(semantics, truncatedValue);
 }
 
+FailureOr<APFloat>
+detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
+                              const Token &tok, bool isNegative,
+                              const llvm::fltSemantics &semantics) {
+  // Check for a floating point value.
+  if (tok.is(Token::floatliteral)) {
+    auto val = tok.getFloatingPointValue();
+    if (!val)
+      return emitError() << "floating point value too large";
+
+    APFloat result(isNegative ? -*val : *val);
+    bool unused;
+    result.convert(semantics, APFloat::rmNearestTiesToEven, &unused);
+    return result;
+  }
+
+  // Check for a hexadecimal float value.
+  if (tok.is(Token::integer))
+    return parseFloatFromIntegerLiteral(emitError, tok, isNegative, semantics);
+
+  return emitError() << "expected floating point literal";
+}
+
 //===----------------------------------------------------------------------===//
 // CodeComplete
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index fa29264ffe506a..ab445476a91923 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -22,6 +22,12 @@ parseFloatFromIntegerLiteral(function_ref<InFlightDiagnostic()> emitError,
                              const Token &tok, bool isNegative,
                              const llvm::fltSemantics &semantics);
 
+/// Parse a floating point value from a literal.
+FailureOr<APFloat>
+parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
+                      const Token &tok, bool isNegative,
+                      const llvm::fltSemantics &semantics);
+
 //===----------------------------------------------------------------------===//
 // Parser
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index 431c7b12b8f5fe..5098fe751fd01f 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -45,7 +45,8 @@ func.func @elementsattr_floattype1() -> () {
 // -----
 
 func.func @elementsattr_floattype2() -> () {
-  // expected-error@+1 {{expected floating-point elements, but parsed integer}}
+  // expected-error@below {{unexpected decimal integer literal for a floating point value}}
+  // expected-note@below {{add a trailing dot to make the literal a float}}
   "foo"(){bar = dense<[4]> : tensor<1xf32>} : () -> ()
 }
 
@@ -138,21 +139,22 @@ func.func @float_in_int_tensor() {
 // -----
 
 func.func @float_in_bool_tensor() {
-  // expected-error @+1 {{expected integer elements, but parsed floating-point}}
+  // expected-error@below {{expected integer elements, but parsed floating-point}}
   "foo"() {bar = dense<[true, 42.0]> : tensor<2xi1>} : () -> ()
 }
 
 // -----
 
 func.func @decimal_int_in_float_tensor() {
-  // expected-error @+1 {{expected floating-point elements, but parsed integer}}
+  // expected-error@below {{unexpected decimal integer literal for a floating point value}}
+  // expected-note@below {{add a trailing dot to make the literal a float}}
   "foo"() {bar = dense<[42, 42.0]> : tensor<2xf32>} : () -> ()
 }
 
 // -----
 
 func.func @bool_in_float_tensor() {
-  // expected-error @+1 {{expected floating-point elements, but parsed integer}}
+  // expected-error @+1 {{expected floating point literal}}
   "foo"() {bar = dense<[42.0, true]> : tensor<2xf32>} : () -> ()
 }
 

>From 6a59007385b335a54739d4b060103384b6c54d20 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 18 Nov 2024 09:06:22 +0100
Subject: [PATCH 3/4] address comments

---
 mlir/lib/AsmParser/AsmParserImpl.h     |   3 +-
 mlir/lib/AsmParser/AttributeParser.cpp |  14 ++--
 mlir/lib/AsmParser/Parser.cpp          | 110 ++++++++++++-------------
 mlir/lib/AsmParser/Parser.h            |  20 ++---
 4 files changed, 70 insertions(+), 77 deletions(-)

diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index bbd70d5980f8fe..d9d49d53a407d0 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -287,9 +287,8 @@ class AsmParserImpl : public BaseT {
                          APFloat &result) override {
     bool isNegative = parser.consumeIf(Token::minus);
     Token curTok = parser.getToken();
-    auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); };
     FailureOr<APFloat> apResult =
-        parseFloatFromLiteral(emitErrorAtTok, curTok, isNegative, semantics);
+        parser.parseFloatFromLiteral(curTok, isNegative, semantics);
     if (failed(apResult))
       return failure();
     parser.consumeToken();
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 9ebada076cd042..0df3d492f411ec 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -422,9 +422,8 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
   }
 
   if (auto floatType = dyn_cast<FloatType>(type)) {
-    auto emitErrorAtTok = [&]() { return emitError(tok.getLoc()); };
     FailureOr<APFloat> result = parseFloatFromIntegerLiteral(
-        emitErrorAtTok, tok, isNegative, floatType.getFloatSemantics());
+        tok, isNegative, floatType.getFloatSemantics());
     if (failed(result))
       return Attribute();
     return FloatAttr::get(floatType, *result);
@@ -658,9 +657,8 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
   for (const auto &signAndToken : storage) {
     bool isNegative = signAndToken.first;
     const Token &token = signAndToken.second;
-    auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
-    FailureOr<APFloat> result = parseFloatFromLiteral(
-        emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
+    FailureOr<APFloat> result =
+        p.parseFloatFromLiteral(token, isNegative, eltTy.getFloatSemantics());
     if (failed(result))
       return failure();
     floatValues.push_back(*result);
@@ -882,10 +880,8 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
 ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
   bool isNegative = p.consumeIf(Token::minus);
   Token token = p.getToken();
-  auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
-  FailureOr<APFloat> fromIntLit =
-      parseFloatFromLiteral(emitErrorAtTok, token, isNegative,
-                            cast<FloatType>(type).getFloatSemantics());
+  FailureOr<APFloat> fromIntLit = p.parseFloatFromLiteral(
+      token, isNegative, cast<FloatType>(type).getFloatSemantics());
   if (failed(fromIntLit))
     return failure();
   p.consumeToken();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 15f3dd7a66c358..e10b87ad43b139 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -67,61 +67,6 @@
 using namespace mlir;
 using namespace mlir::detail;
 
-/// Parse a floating point value from an integer literal token.
-FailureOr<APFloat> detail::parseFloatFromIntegerLiteral(
-    function_ref<InFlightDiagnostic()> emitError, const Token &tok,
-    bool isNegative, const llvm::fltSemantics &semantics) {
-  StringRef spelling = tok.getSpelling();
-  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
-  if (!isHex) {
-    auto error = emitError();
-    error << "unexpected decimal integer literal for a "
-             "floating point value";
-    error.attachNote() << "add a trailing dot to make the literal a float";
-    return failure();
-  }
-  if (isNegative) {
-    emitError() << "hexadecimal float literal should not have a "
-                   "leading minus";
-    return failure();
-  }
-
-  APInt intValue;
-  tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
-  auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics);
-  if (intValue.getActiveBits() > typeSizeInBits) {
-    return emitError() << "hexadecimal float constant out of range for type";
-    return failure();
-  }
-
-  APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
-                       intValue.getRawData());
-  return APFloat(semantics, truncatedValue);
-}
-
-FailureOr<APFloat>
-detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
-                              const Token &tok, bool isNegative,
-                              const llvm::fltSemantics &semantics) {
-  // Check for a floating point value.
-  if (tok.is(Token::floatliteral)) {
-    auto val = tok.getFloatingPointValue();
-    if (!val)
-      return emitError() << "floating point value too large";
-
-    APFloat result(isNegative ? -*val : *val);
-    bool unused;
-    result.convert(semantics, APFloat::rmNearestTiesToEven, &unused);
-    return result;
-  }
-
-  // Check for a hexadecimal float value.
-  if (tok.is(Token::integer))
-    return parseFloatFromIntegerLiteral(emitError, tok, isNegative, semantics);
-
-  return emitError() << "expected floating point literal";
-}
-
 //===----------------------------------------------------------------------===//
 // CodeComplete
 //===----------------------------------------------------------------------===//
@@ -402,6 +347,61 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
   return success();
 }
 
+FailureOr<APFloat>
+Parser::parseFloatFromLiteral(const Token &tok, bool isNegative,
+                              const llvm::fltSemantics &semantics) {
+  // Check for a floating point value.
+  if (tok.is(Token::floatliteral)) {
+    auto val = tok.getFloatingPointValue();
+    if (!val)
+      return emitError(tok.getLoc()) << "floating point value too large";
+
+    APFloat result(isNegative ? -*val : *val);
+    bool unused;
+    result.convert(semantics, APFloat::rmNearestTiesToEven, &unused);
+    return result;
+  }
+
+  // Check for a hexadecimal float value.
+  if (tok.is(Token::integer))
+    return parseFloatFromIntegerLiteral(tok, isNegative, semantics);
+
+  return emitError(tok.getLoc()) << "expected floating point literal";
+}
+
+/// Parse a floating point value from an integer literal token.
+FailureOr<APFloat>
+Parser::parseFloatFromIntegerLiteral(const Token &tok, bool isNegative,
+                                     const llvm::fltSemantics &semantics) {
+  StringRef spelling = tok.getSpelling();
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+  if (!isHex) {
+    auto error = emitError(tok.getLoc());
+    error << "unexpected decimal integer literal for a "
+             "floating point value";
+    error.attachNote() << "add a trailing dot to make the literal a float";
+    return failure();
+  }
+  if (isNegative) {
+    emitError(tok.getLoc()) << "hexadecimal float literal should not have a "
+                               "leading minus";
+    return failure();
+  }
+
+  APInt intValue;
+  tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
+  auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics);
+  if (intValue.getActiveBits() > typeSizeInBits) {
+    return emitError(tok.getLoc())
+           << "hexadecimal float constant out of range for type";
+    return failure();
+  }
+
+  APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
+                       intValue.getRawData());
+  return APFloat(semantics, truncatedValue);
+}
+
 ParseResult Parser::parseOptionalKeyword(StringRef *keyword) {
   // Check that the current token is a keyword.
   if (!isCurrentTokenAKeyword())
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index ab445476a91923..15c4990de58344 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -16,17 +16,6 @@
 
 namespace mlir {
 namespace detail {
-/// Parse a floating point value from an integer literal token.
-FailureOr<APFloat>
-parseFloatFromIntegerLiteral(function_ref<InFlightDiagnostic()> emitError,
-                             const Token &tok, bool isNegative,
-                             const llvm::fltSemantics &semantics);
-
-/// Parse a floating point value from a literal.
-FailureOr<APFloat>
-parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
-                      const Token &tok, bool isNegative,
-                      const llvm::fltSemantics &semantics);
 
 //===----------------------------------------------------------------------===//
 // Parser
@@ -163,6 +152,15 @@ class Parser {
   /// Parse an optional integer value only in decimal format from the stream.
   OptionalParseResult parseOptionalDecimalInteger(APInt &result);
 
+  /// Parse a floating point value from a literal.
+  FailureOr<APFloat> parseFloatFromLiteral(const Token &tok, bool isNegative,
+                                           const llvm::fltSemantics &semantics);
+
+  /// Parse a floating point value from an integer literal token.
+  FailureOr<APFloat>
+  parseFloatFromIntegerLiteral(const Token &tok, bool isNegative,
+                               const llvm::fltSemantics &semantics);
+
   /// Returns true if the current token corresponds to a keyword.
   bool isCurrentTokenAKeyword() const {
     return getToken().isAny(Token::bare_identifier, Token::inttype) ||

>From 1f81180b91597f0813d46cb2f5b422d7678a2a3c Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Mon, 18 Nov 2024 09:16:19 +0100
Subject: [PATCH 4/4] address comments 2

---
 mlir/lib/AsmParser/AsmParserImpl.h     |  6 ++--
 mlir/lib/AsmParser/AttributeParser.cpp | 19 ++++++------
 mlir/lib/AsmParser/Parser.cpp          | 42 +++++++++++++-------------
 mlir/lib/AsmParser/Parser.h            | 11 ++++---
 4 files changed, 40 insertions(+), 38 deletions(-)

diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index d9d49d53a407d0..d5b72d63813a4e 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -287,9 +287,9 @@ class AsmParserImpl : public BaseT {
                          APFloat &result) override {
     bool isNegative = parser.consumeIf(Token::minus);
     Token curTok = parser.getToken();
-    FailureOr<APFloat> apResult =
-        parser.parseFloatFromLiteral(curTok, isNegative, semantics);
-    if (failed(apResult))
+    std::optional<APFloat> apResult;
+    if (failed(parser.parseFloatFromLiteral(apResult, curTok, isNegative,
+                                            semantics)))
       return failure();
     parser.consumeToken();
     result = *apResult;
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index 0df3d492f411ec..ff616dac9625b4 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -422,9 +422,9 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
   }
 
   if (auto floatType = dyn_cast<FloatType>(type)) {
-    FailureOr<APFloat> result = parseFloatFromIntegerLiteral(
-        tok, isNegative, floatType.getFloatSemantics());
-    if (failed(result))
+    std::optional<APFloat> result;
+    if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
+                                            floatType.getFloatSemantics())))
       return Attribute();
     return FloatAttr::get(floatType, *result);
   }
@@ -657,9 +657,9 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
   for (const auto &signAndToken : storage) {
     bool isNegative = signAndToken.first;
     const Token &token = signAndToken.second;
-    FailureOr<APFloat> result =
-        p.parseFloatFromLiteral(token, isNegative, eltTy.getFloatSemantics());
-    if (failed(result))
+    std::optional<APFloat> result;
+    if (failed(p.parseFloatFromLiteral(result, token, isNegative,
+                                       eltTy.getFloatSemantics())))
       return failure();
     floatValues.push_back(*result);
   }
@@ -880,9 +880,10 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
 ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
   bool isNegative = p.consumeIf(Token::minus);
   Token token = p.getToken();
-  FailureOr<APFloat> fromIntLit = p.parseFloatFromLiteral(
-      token, isNegative, cast<FloatType>(type).getFloatSemantics());
-  if (failed(fromIntLit))
+  std::optional<APFloat> fromIntLit;
+  if (failed(
+          p.parseFloatFromLiteral(fromIntLit, token, isNegative,
+                                  cast<FloatType>(type).getFloatSemantics())))
     return failure();
   p.consumeToken();
   append(fromIntLit->bitcastToAPInt());
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index e10b87ad43b139..e3db248164672c 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -347,59 +347,59 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
   return success();
 }
 
-FailureOr<APFloat>
-Parser::parseFloatFromLiteral(const Token &tok, bool isNegative,
-                              const llvm::fltSemantics &semantics) {
+ParseResult Parser::parseFloatFromLiteral(std::optional<APFloat> &result,
+                                          const Token &tok, bool isNegative,
+                                          const llvm::fltSemantics &semantics) {
   // Check for a floating point value.
   if (tok.is(Token::floatliteral)) {
     auto val = tok.getFloatingPointValue();
     if (!val)
       return emitError(tok.getLoc()) << "floating point value too large";
 
-    APFloat result(isNegative ? -*val : *val);
+    result.emplace(isNegative ? -*val : *val);
     bool unused;
-    result.convert(semantics, APFloat::rmNearestTiesToEven, &unused);
-    return result;
+    result->convert(semantics, APFloat::rmNearestTiesToEven, &unused);
+    return success();
   }
 
   // Check for a hexadecimal float value.
   if (tok.is(Token::integer))
-    return parseFloatFromIntegerLiteral(tok, isNegative, semantics);
+    return parseFloatFromIntegerLiteral(result, tok, isNegative, semantics);
 
   return emitError(tok.getLoc()) << "expected floating point literal";
 }
 
 /// Parse a floating point value from an integer literal token.
-FailureOr<APFloat>
-Parser::parseFloatFromIntegerLiteral(const Token &tok, bool isNegative,
+ParseResult
+Parser::parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
+                                     const Token &tok, bool isNegative,
                                      const llvm::fltSemantics &semantics) {
   StringRef spelling = tok.getSpelling();
   bool isHex = spelling.size() > 1 && spelling[1] == 'x';
   if (!isHex) {
-    auto error = emitError(tok.getLoc());
-    error << "unexpected decimal integer literal for a "
-             "floating point value";
-    error.attachNote() << "add a trailing dot to make the literal a float";
-    return failure();
+    return emitError(tok.getLoc(), "unexpected decimal integer literal for a "
+                                   "floating point value")
+               .attachNote()
+           << "add a trailing dot to make the literal a float";
   }
   if (isNegative) {
-    emitError(tok.getLoc()) << "hexadecimal float literal should not have a "
-                               "leading minus";
-    return failure();
+    return emitError(tok.getLoc(),
+                     "hexadecimal float literal should not have a "
+                     "leading minus");
   }
 
   APInt intValue;
   tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
   auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics);
   if (intValue.getActiveBits() > typeSizeInBits) {
-    return emitError(tok.getLoc())
-           << "hexadecimal float constant out of range for type";
-    return failure();
+    return emitError(tok.getLoc(),
+                     "hexadecimal float constant out of range for type");
   }
 
   APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
                        intValue.getRawData());
-  return APFloat(semantics, truncatedValue);
+  result.emplace(semantics, truncatedValue);
+  return success();
 }
 
 ParseResult Parser::parseOptionalKeyword(StringRef *keyword) {
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 15c4990de58344..4979cfc6e69e41 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -153,13 +153,14 @@ class Parser {
   OptionalParseResult parseOptionalDecimalInteger(APInt &result);
 
   /// Parse a floating point value from a literal.
-  FailureOr<APFloat> parseFloatFromLiteral(const Token &tok, bool isNegative,
-                                           const llvm::fltSemantics &semantics);
+  ParseResult parseFloatFromLiteral(std::optional<APFloat> &result,
+                                    const Token &tok, bool isNegative,
+                                    const llvm::fltSemantics &semantics);
 
   /// Parse a floating point value from an integer literal token.
-  FailureOr<APFloat>
-  parseFloatFromIntegerLiteral(const Token &tok, bool isNegative,
-                               const llvm::fltSemantics &semantics);
+  ParseResult parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
+                                           const Token &tok, bool isNegative,
+                                           const llvm::fltSemantics &semantics);
 
   /// Returns true if the current token corresponds to a keyword.
   bool isCurrentTokenAKeyword() const {



More information about the Mlir-commits mailing list