[llvm-branch-commits] [mlir] [mlir][Parser] Deduplicate fp parsing functionality (PR #116172)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Nov 13 22:45:59 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/116172

Parsing an APFloat from either a floating point literal or an integer literal in hexadecimal representation (interpreted as the bit pattern) is duplicated in multiple places. Move this logic into a common helper function, `parseFloatFromLiteral`.

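NFC apart from the slightly changed error messages. One small behavior difference: the dense array parser now emits a "floating point value too large" diagnostic in the case where it previously failed without reporting any error.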

Depends on #116171.

>From 51530aeea8c18804034881c87236d1ab5ceb274f Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 14 Nov 2024 07:43:08 +0100
Subject: [PATCH] [mlir][Parser] Deduplicate fp parsing functionality

---
 mlir/lib/AsmParser/AsmParserImpl.h           | 33 ++-------
 mlir/lib/AsmParser/AttributeParser.cpp       | 71 ++++----------------
 mlir/lib/AsmParser/Parser.cpp                | 23 +++++++
 mlir/lib/AsmParser/Parser.h                  |  6 ++
 mlir/test/IR/invalid-builtin-attributes.mlir | 10 +--
 5 files changed, 56 insertions(+), 87 deletions(-)

diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 1e6cbc0ec51beb..bbd70d5980f8fe 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -288,32 +288,13 @@ class AsmParserImpl : public BaseT {
     bool isNegative = parser.consumeIf(Token::minus);
     Token curTok = parser.getToken();
     auto emitErrorAtTok = [&]() { return emitError(curTok.getLoc(), ""); };
-
-    // Check for a floating point value.
-    if (curTok.is(Token::floatliteral)) {
-      auto val = curTok.getFloatingPointValue();
-      if (!val)
-        return emitErrorAtTok() << "floating point value too large";
-      parser.consumeToken(Token::floatliteral);
-      result = APFloat(isNegative ? -*val : *val);
-      bool losesInfo;
-      result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
-      return success();
-    }
-
-    // Check for a hexadecimal float value.
-    if (curTok.is(Token::integer)) {
-      FailureOr<APFloat> apResult = parseFloatFromIntegerLiteral(
-          emitErrorAtTok, curTok, isNegative, semantics);
-      if (failed(apResult))
-        return failure();
-
-      result = *apResult;
-      parser.consumeToken(Token::integer);
-      return success();
-    }
-
-    return emitErrorAtTok() << "expected floating point literal";
+    FailureOr<APFloat> apResult =
+        parseFloatFromLiteral(emitErrorAtTok, curTok, isNegative, semantics);
+    if (failed(apResult))
+      return failure();
+    parser.consumeToken();
+    result = *apResult;
+    return success();
   }
 
   /// Parse a floating point value from the stream.
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index ba9be3b030453a..9ebada076cd042 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -658,36 +658,12 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
   for (const auto &signAndToken : storage) {
     bool isNegative = signAndToken.first;
     const Token &token = signAndToken.second;
-
-    // Handle hexadecimal float literals.
-    if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) {
-      auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
-      FailureOr<APFloat> result = parseFloatFromIntegerLiteral(
-          emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
-      if (failed(result))
-        return failure();
-
-      floatValues.push_back(*result);
-      continue;
-    }
-
-    // Check to see if any decimal integers or booleans were parsed.
-    if (!token.is(Token::floatliteral))
-      return p.emitError()
-             << "expected floating-point elements, but parsed integer";
-
-    // Build the float values from tokens.
-    auto val = token.getFloatingPointValue();
-    if (!val)
-      return p.emitError("floating point value too large for attribute");
-
-    APFloat apVal(isNegative ? -*val : *val);
-    if (!eltTy.isF64()) {
-      bool unused;
-      apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
-                    &unused);
-    }
-    floatValues.push_back(apVal);
+    auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+    FailureOr<APFloat> result = parseFloatFromLiteral(
+        emitErrorAtTok, token, isNegative, eltTy.getFloatSemantics());
+    if (failed(result))
+      return failure();
+    floatValues.push_back(*result);
   }
   return success();
 }
@@ -905,34 +881,15 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
 
 ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
   bool isNegative = p.consumeIf(Token::minus);
-
   Token token = p.getToken();
-  std::optional<APFloat> result;
-  auto floatType = cast<FloatType>(type);
-  if (p.consumeIf(Token::integer)) {
-    // Parse an integer literal as a float.
-    auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
-    FailureOr<APFloat> fromIntLit = parseFloatFromIntegerLiteral(
-        emitErrorAtTok, token, isNegative, floatType.getFloatSemantics());
-    if (failed(fromIntLit))
-      return failure();
-    result = *fromIntLit;
-  } else if (p.consumeIf(Token::floatliteral)) {
-    // Parse a floating point literal.
-    std::optional<double> val = token.getFloatingPointValue();
-    if (!val)
-      return failure();
-    result = APFloat(isNegative ? -*val : *val);
-    if (!type.isF64()) {
-      bool unused;
-      result->convert(floatType.getFloatSemantics(),
-                      APFloat::rmNearestTiesToEven, &unused);
-    }
-  } else {
-    return p.emitError("expected integer or floating point literal");
-  }
-
-  append(result->bitcastToAPInt());
+  const auto &semantics = cast<FloatType>(type).getFloatSemantics();
+  auto emitErrorAtTok = [&]() { return p.emitError(token.getLoc()); };
+  FailureOr<APFloat> result =
+      parseFloatFromLiteral(emitErrorAtTok, token, isNegative, semantics);
+  if (failed(result))
+    return failure();
+  p.consumeToken(); // Consume the literal only after it parsed successfully.
+  append(result->bitcastToAPInt());
   return success();
 }
 
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index ac7eec931b1250..15f3dd7a66c358 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -99,6 +99,29 @@ FailureOr<APFloat> detail::parseFloatFromIntegerLiteral(
   return APFloat(semantics, truncatedValue);
 }
 
+FailureOr<APFloat>
+detail::parseFloatFromLiteral(function_ref<InFlightDiagnostic()> emitError,
+                              const Token &tok, bool isNegative,
+                              const llvm::fltSemantics &semantics) {
+  // Hexadecimal integer literals encode the float's bit pattern.
+  if (tok.is(Token::integer))
+    return parseFloatFromIntegerLiteral(emitError, tok, isNegative, semantics);
+
+  // Anything else that is not a float literal (e.g. `true`) is rejected.
+  if (!tok.is(Token::floatliteral))
+    return emitError() << "expected floating point literal";
+
+  // Float literals are read as a double, negated if requested, and then
+  // rounded (nearest, ties to even) into the requested semantics.
+  auto parsed = tok.getFloatingPointValue();
+  if (!parsed)
+    return emitError() << "floating point value too large";
+  APFloat value(isNegative ? -*parsed : *parsed);
+  bool losesInfo;
+  value.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
+  return value;
+}
+
 //===----------------------------------------------------------------------===//
 // CodeComplete
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index fa29264ffe506a..ab445476a91923 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -22,6 +22,12 @@ parseFloatFromIntegerLiteral(function_ref<InFlightDiagnostic()> emitError,
                              const Token &tok, bool isNegative,
                              const llvm::fltSemantics &semantics);
 
+/// Parse a float with `semantics` from a float literal or integer literal (bit
+/// pattern) token; `isNegative` marks a leading minus. Doesn't consume `tok`.
+FailureOr<APFloat> parseFloatFromLiteral(
+    function_ref<InFlightDiagnostic()> emitError, const Token &tok,
+    bool isNegative, const llvm::fltSemantics &semantics);
+
 //===----------------------------------------------------------------------===//
 // Parser
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index 431c7b12b8f5fe..5098fe751fd01f 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -45,7 +45,8 @@ func.func @elementsattr_floattype1() -> () {
 // -----
 
 func.func @elementsattr_floattype2() -> () {
-  // expected-error@+1 {{expected floating-point elements, but parsed integer}}
+  // expected-error@below {{unexpected decimal integer literal for a floating point value}}
+  // expected-note@below {{add a trailing dot to make the literal a float}}
   "foo"(){bar = dense<[4]> : tensor<1xf32>} : () -> ()
 }
 
@@ -138,21 +139,22 @@ func.func @float_in_int_tensor() {
 // -----
 
 func.func @float_in_bool_tensor() {
-  // expected-error @+1 {{expected integer elements, but parsed floating-point}}
+  // expected-error@below {{expected integer elements, but parsed floating-point}}
   "foo"() {bar = dense<[true, 42.0]> : tensor<2xi1>} : () -> ()
 }
 
 // -----
 
 func.func @decimal_int_in_float_tensor() {
-  // expected-error @+1 {{expected floating-point elements, but parsed integer}}
+  // expected-error@below {{unexpected decimal integer literal for a floating point value}}
+  // expected-note@below {{add a trailing dot to make the literal a float}}
   "foo"() {bar = dense<[42, 42.0]> : tensor<2xf32>} : () -> ()
 }
 
 // -----
 
 func.func @bool_in_float_tensor() {
-  // expected-error @+1 {{expected floating-point elements, but parsed integer}}
+  // expected-error@below {{expected floating point literal}}
   "foo"() {bar = dense<[42.0, true]> : tensor<2xf32>} : () -> ()
 }
 



More information about the llvm-branch-commits mailing list