[Mlir-commits] [mlir] [MLIR] Extend floating point parsing support (PR #90442)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 1 00:02:15 PDT 2024


https://github.com/orbiri updated https://github.com/llvm/llvm-project/pull/90442

>From 75938303a06972023aac87021c9ff320afc917fd Mon Sep 17 00:00:00 2001
From: Or Biri <orzivh at gmail.com>
Date: Sun, 28 Apr 2024 15:27:43 +0300
Subject: [PATCH] [MLIR] Extend floating point parsing support

Parsing support for floating point types was missing a few features:
1. Parsing floating point attributes from integer literals was supported only
   for types with a bitwidth less than or equal to 64.
2. Downstream users could not use `AsmParser::parseFloat` to parse float types
   that are printed as integer literals.

This commit addresses both points. It extends
`Parser::parseFloatFromIntegerLiteral` to support arbitrary bitwidths, and
exposes a new API to parse an arbitrary floating point value given an
fltSemantics as input. The new API is exercised in the Test Dialect.
---
 mlir/include/mlir/IR/OpImplementation.h       |  4 ++
 mlir/lib/AsmParser/AsmParserImpl.h            | 28 +++++++--
 mlir/lib/AsmParser/Parser.cpp                 | 16 ++----
 mlir/test/IR/custom-float-attr-roundtrip.mlir | 57 +++++++++++++++++++
 mlir/test/IR/parser.mlir                      | 24 ++++++++
 mlir/test/lib/Dialect/Test/TestAttrDefs.td    | 11 ++++
 mlir/test/lib/Dialect/Test/TestAttributes.cpp | 41 +++++++++++++
 7 files changed, 165 insertions(+), 16 deletions(-)
 create mode 100644 mlir/test/IR/custom-float-attr-roundtrip.mlir

diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5333d7446df5ca..fa435cb3155ed4 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -700,6 +700,10 @@ class AsmParser {
   /// Parse a floating point value from the stream.
   virtual ParseResult parseFloat(double &result) = 0;
 
+  /// Parse a floating point value with the given semantics from the stream.
+  virtual ParseResult parseFloat(const llvm::fltSemantics &semantics,
+                                 APFloat &result) = 0;
+
   /// Parse an integer value from the stream.
   template <typename IntT>
   ParseResult parseInteger(IntT &result) {
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 30c0079cda0861..8f22be80865bf8 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -269,8 +269,12 @@ class AsmParserImpl : public BaseT {
     return success();
   }
 
-  /// Parse a floating point value from the stream.
-  ParseResult parseFloat(double &result) override {
+  /// Parse a floating point value with the given semantics from the stream.
+  /// Float literals are parsed as double and then converted to `semantics`,
+  /// so precision may be lost; hexadecimal integer literals are interpreted
+  /// as the raw bit pattern of the value.
+  ParseResult parseFloat(const llvm::fltSemantics &semantics,
+                         APFloat &result) override {
     bool isNegative = parser.consumeIf(Token::minus);
     Token curTok = parser.getToken();
     SMLoc loc = curTok.getLoc();
@@ -281,7 +285,9 @@ class AsmParserImpl : public BaseT {
       if (!val)
         return emitError(loc, "floating point value too large");
       parser.consumeToken(Token::floatliteral);
-      result = isNegative ? -*val : *val;
+      result = APFloat(isNegative ? -*val : *val);
+      bool losesInfo; // Ignored: precision loss is documented above.
+      result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
       return success();
     }
 
@@ -289,18 +295,28 @@ class AsmParserImpl : public BaseT {
     if (curTok.is(Token::integer)) {
       std::optional<APFloat> apResult;
       if (failed(parser.parseFloatFromIntegerLiteral(
-              apResult, curTok, isNegative, APFloat::IEEEdouble(),
-              /*typeSizeInBits=*/64)))
+              apResult, curTok, isNegative, semantics,
+              APFloat::semanticsSizeInBits(semantics))))
         return failure();
 
+      result = *apResult;
       parser.consumeToken(Token::integer);
-      result = apResult->convertToDouble();
       return success();
     }
 
     return emitError(loc, "expected floating point literal");
   }
 
+  /// Parse a floating point value from the stream as an IEEE double.
+  ParseResult parseFloat(double &result) override {
+    llvm::APFloat apResult(0.0);
+    if (parseFloat(APFloat::IEEEdouble(), apResult))
+      return failure();
+
+    result = apResult.convertToDouble();
+    return success();
+  }
+
   /// Parse an optional integer value from the stream.
   OptionalParseResult parseOptionalInteger(APInt &result) override {
     return parser.parseOptionalInteger(result);
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 00f2b0c0c2f12f..1b8b4bac1821e9 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -326,19 +326,15 @@ ParseResult Parser::parseFloatFromIntegerLiteral(
                           "leading minus");
   }
 
-  std::optional<uint64_t> value = tok.getUInt64IntegerValue();
-  if (!value)
+  APInt intValue;
+  if (tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue) ||
+      intValue.getActiveBits() > typeSizeInBits)
     return emitError(loc, "hexadecimal float constant out of range for type");
 
-  if (&semantics == &APFloat::IEEEdouble()) {
-    result = APFloat(semantics, APInt(typeSizeInBits, *value));
-    return success();
-  }
+  // The range check above guarantees only leading zero bits are dropped.
+  APInt truncatedValue = intValue.zextOrTrunc(typeSizeInBits);
 
-  APInt apInt(typeSizeInBits, *value);
-  if (apInt != *value)
-    return emitError(loc, "hexadecimal float constant out of range for type");
-  result = APFloat(semantics, apInt);
+  result.emplace(semantics, truncatedValue);
 
   return success();
 }
diff --git a/mlir/test/IR/custom-float-attr-roundtrip.mlir b/mlir/test/IR/custom-float-attr-roundtrip.mlir
new file mode 100644
index 00000000000000..a8da89ba7372d0
--- /dev/null
+++ b/mlir/test/IR/custom-float-attr-roundtrip.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @test_custom_float_attr_roundtrip
+func.func @test_custom_float_attr_roundtrip() -> () {
+  // CHECK: attr = #test.custom_float<"float" : 2.000000e+00>
+  "test.op"() {attr = #test.custom_float<"float" : 2.>} : () -> ()
+  // CHECK: attr = #test.custom_float<"double" : 2.000000e+00>
+  "test.op"() {attr = #test.custom_float<"double" : 2.>} : () -> ()
+  // CHECK: attr = #test.custom_float<"fp80" : 2.000000e+00>
+  "test.op"() {attr = #test.custom_float<"fp80" : 2.>} : () -> ()
+  // CHECK: attr = #test.custom_float<"float" : 0x7FC00000>
+  "test.op"() {attr = #test.custom_float<"float" : 0x7FC00000>} : () -> ()
+  // CHECK: attr = #test.custom_float<"double" : 0x7FF0000001000000>
+  "test.op"() {attr = #test.custom_float<"double" : 0x7FF0000001000000>} : () -> ()
+  // CHECK: attr = #test.custom_float<"fp80" : 0x7FFFC000000000100000>
+  "test.op"() {attr = #test.custom_float<"fp80" : 0x7FFFC000000000100000>} : () -> ()
+  return
+}
+
+// -----
+
+// Verify the literal must be a hexadecimal integer or a float
+
+// expected-error @below {{unexpected decimal integer literal for a floating point value}}
+// expected-note @below {{add a trailing dot to make the literal a float}}
+"test.op"() {attr = #test.custom_float<"float" : 42>} : () -> ()
+
+// -----
+
+// Integer value must fit in the bit width of the floating point type
+
+// expected-error @below {{hexadecimal float constant out of range for type}}
+"test.op"() {attr = #test.custom_float<"float" : 0x7FC000000>} : () -> ()
+
+
+// -----
+
+// Integer value must fit in the bit width of the floating point type
+
+// expected-error @below {{hexadecimal float constant out of range for type}}
+"test.op"() {attr = #test.custom_float<"double" : 0x7FC000007FC0000000>} : () -> ()
+
+
+// -----
+
+// Integer value must fit in the bit width of the floating point type
+
+// expected-error @below {{hexadecimal float constant out of range for type}}
+"test.op"() {attr = #test.custom_float<"fp80" : 0x7FC0000007FC0000007FC000000>} : () -> ()
+
+// -----
+
+// Value must be a floating point literal or integer literal
+
+// expected-error @below {{expected floating point literal}}
+"test.op"() {attr = #test.custom_float<"float" : "blabla">} : () -> ()
+
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index bebbb876391d07..020942e7f4c11b 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1105,6 +1105,30 @@ func.func @bfloat16_special_values() {
   return
 }
 
+// CHECK-LABEL: @f80_special_values
+func.func @f80_special_values() {
+  // F80 signaling NaNs (quiet bit 62 clear, non-zero payload).
+  // CHECK: arith.constant 0x7FFFA000000000000001 : f80
+  %0 = arith.constant 0x7FFFA000000000000001 : f80
+  // CHECK: arith.constant 0x7FFFB000000000000011 : f80
+  %1 = arith.constant 0x7FFFB000000000000011 : f80
+
+  // F80 quiet NaNs (quiet bit 62 set).
+  // CHECK: arith.constant 0x7FFFC000000000100000 : f80
+  %2 = arith.constant 0x7FFFC000000000100000 : f80
+  // CHECK: arith.constant 0x7FFFE000000001000000 : f80
+  %3 = arith.constant 0x7FFFE000000001000000 : f80
+
+  // F80 positive infinity.
+  // CHECK: arith.constant 0x7FFF8000000000000000 : f80
+  %4 = arith.constant 0x7FFF8000000000000000 : f80
+  // F80 negative infinity.
+  // CHECK: arith.constant 0xFFFF8000000000000000 : f80
+  %5 = arith.constant 0xFFFF8000000000000000 : f80
+
+  return
+}
+
 // We want to print floats in exponential notation with 6 significant digits,
 // but it may lead to precision loss when parsing back, in which case we print
 // the decimal form instead.
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 40f035a3e3a4e5..12635e107bd42c 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -340,4 +340,15 @@ def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> {
   }];
 }
 
+// Test AsmParser::parseFloat(const fltSemantics&, APFloat&) API through the
+// custom parser and printer.
+def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
+  let mnemonic = "custom_float";
+  let parameters = (ins "mlir::StringAttr":$type_str, APFloatParameter<"">:$value);
+
+  let assemblyFormat = [{
+    `<` custom<CustomFloatAttr>($type_str, $value) `>`
+  }];
+}
+
 #endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 2cc051e664beec..d7e40d35238d91 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -240,6 +241,46 @@ static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
   p.printKeywordOrString(value);
 }
 
+//===----------------------------------------------------------------------===//
+// Custom Float Attribute (syntax: `"<type>" : <float-literal>`)
+//===----------------------------------------------------------------------===//
+
+static void printCustomFloatAttr(AsmPrinter &p, StringAttr typeStrAttr,
+                                 const APFloat &value) {
+  p << typeStrAttr << " : " << value;
+}
+
+static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
+                                        FailureOr<APFloat> &value) {
+  SMLoc typeLoc = p.getCurrentLocation(); // Anchor type-string diagnostics.
+  std::string str;
+  if (p.parseString(&str))
+    return failure();
+
+  typeStrAttr = StringAttr::get(p.getContext(), str);
+
+  if (p.parseColon())
+    return failure();
+  // Map the type string to the semantics the value is parsed with.
+  const llvm::fltSemantics *semantics;
+  if (str == "float")
+    semantics = &llvm::APFloat::IEEEsingle();
+  else if (str == "double")
+    semantics = &llvm::APFloat::IEEEdouble();
+  else if (str == "fp80")
+    semantics = &llvm::APFloat::x87DoubleExtended();
+  else
+    return p.emitError(typeLoc, "unknown float type, expected "
+                                "'float', 'double' or 'fp80'");
+
+  APFloat parsedValue(0.0);
+  if (p.parseFloat(*semantics, parsedValue))
+    return failure();
+
+  value.emplace(parsedValue);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list