[Mlir-commits] [mlir] [MLIR] Extend floating point parsing support (PR #90442)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 1 00:02:15 PDT 2024
https://github.com/orbiri updated https://github.com/llvm/llvm-project/pull/90442
>From 75938303a06972023aac87021c9ff320afc917fd Mon Sep 17 00:00:00 2001
From: Or Biri <orzivh at gmail.com>
Date: Sun, 28 Apr 2024 15:27:43 +0300
Subject: [PATCH] [MLIR] Extend floating point parsing support
Parsing support for floating point types was missing a few features:
1. Parsing floating point attributes from integer literals was supported only
for types with a bitwidth less than or equal to 64.
2. Downstream users could not use `AsmParser::parseFloat` to parse values of
floating point types other than `double` that are printed as integer literals.
This commit addresses both points. It extends
`Parser::parseFloatFromIntegerLiteral` to support arbitrary bitwidths, and
exposes a new `AsmParser::parseFloat` overload that parses a floating point
value with arbitrary semantics, given an `llvm::fltSemantics` as input. The
new API is exercised in the Test Dialect.
---
mlir/include/mlir/IR/OpImplementation.h | 4 ++
mlir/lib/AsmParser/AsmParserImpl.h | 28 +++++++--
mlir/lib/AsmParser/Parser.cpp | 16 ++----
mlir/test/IR/custom-float-attr-roundtrip.mlir | 57 +++++++++++++++++++
mlir/test/IR/parser.mlir | 24 ++++++++
mlir/test/lib/Dialect/Test/TestAttrDefs.td | 11 ++++
mlir/test/lib/Dialect/Test/TestAttributes.cpp | 41 +++++++++++++
7 files changed, 165 insertions(+), 16 deletions(-)
create mode 100644 mlir/test/IR/custom-float-attr-roundtrip.mlir
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5333d7446df5ca..fa435cb3155ed4 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -700,6 +700,10 @@ class AsmParser {
/// Parse a floating point value from the stream.
virtual ParseResult parseFloat(double &result) = 0;
+ /// Parse a floating point value from the stream into an APFloat with the
+ /// given `semantics`. Besides float literals, hexadecimal integer literals
+ /// are accepted and interpreted as the raw bit pattern of the value; such a
+ /// literal must fit within the bit width of `semantics`.
+ virtual ParseResult parseFloat(const llvm::fltSemantics &semantics,
+ APFloat &result) = 0;
+
/// Parse an integer value from the stream.
template <typename IntT>
ParseResult parseInteger(IntT &result) {
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index 30c0079cda0861..8f22be80865bf8 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -269,8 +269,12 @@ class AsmParserImpl : public BaseT {
return success();
}
- /// Parse a floating point value from the stream.
- ParseResult parseFloat(double &result) override {
+ /// Parse a floating point value with given semantics from the stream. Since
+ /// this implementation parses the string as double precision and only
+ /// afterwards converts the value to the requested semantic, precision may be
+ /// lost.
+ ParseResult parseFloat(const llvm::fltSemantics &semantics,
+ APFloat &result) override {
bool isNegative = parser.consumeIf(Token::minus);
Token curTok = parser.getToken();
SMLoc loc = curTok.getLoc();
@@ -281,7 +285,9 @@ class AsmParserImpl : public BaseT {
if (!val)
return emitError(loc, "floating point value too large");
parser.consumeToken(Token::floatliteral);
- result = isNegative ? -*val : *val;
+ result = APFloat(isNegative ? -*val : *val);
+ bool losesInfo;
+ result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
return success();
}
@@ -289,18 +295,28 @@ class AsmParserImpl : public BaseT {
if (curTok.is(Token::integer)) {
std::optional<APFloat> apResult;
if (failed(parser.parseFloatFromIntegerLiteral(
- apResult, curTok, isNegative, APFloat::IEEEdouble(),
- /*typeSizeInBits=*/64)))
+ apResult, curTok, isNegative, semantics,
+ APFloat::semanticsSizeInBits(semantics))))
return failure();
+ result = *apResult;
parser.consumeToken(Token::integer);
- result = apResult->convertToDouble();
return success();
}
return emitError(loc, "expected floating point literal");
}
+  /// Parse a floating point value from the stream as an IEEE double.
+  ///
+  /// Thin wrapper over the fltSemantics-based overload: the value is parsed
+  /// with IEEE double semantics and then extracted as a native `double`.
+  ParseResult parseFloat(double &result) override {
+    APFloat parsed(0.0);
+    if (failed(parseFloat(APFloat::IEEEdouble(), parsed)))
+      return failure();
+    result = parsed.convertToDouble();
+    return success();
+  }
+
/// Parse an optional integer value from the stream.
OptionalParseResult parseOptionalInteger(APInt &result) override {
return parser.parseOptionalInteger(result);
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 00f2b0c0c2f12f..1b8b4bac1821e9 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -326,19 +326,15 @@ ParseResult Parser::parseFloatFromIntegerLiteral(
"leading minus");
}
- std::optional<uint64_t> value = tok.getUInt64IntegerValue();
- if (!value)
+ APInt intValue;
+ tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
+ if (intValue.getActiveBits() > typeSizeInBits)
return emitError(loc, "hexadecimal float constant out of range for type");
- if (&semantics == &APFloat::IEEEdouble()) {
- result = APFloat(semantics, APInt(typeSizeInBits, *value));
- return success();
- }
+ APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
+ intValue.getRawData());
- APInt apInt(typeSizeInBits, *value);
- if (apInt != *value)
- return emitError(loc, "hexadecimal float constant out of range for type");
- result = APFloat(semantics, apInt);
+ result.emplace(semantics, truncatedValue);
return success();
}
diff --git a/mlir/test/IR/custom-float-attr-roundtrip.mlir b/mlir/test/IR/custom-float-attr-roundtrip.mlir
new file mode 100644
index 00000000000000..a8da89ba7372d0
--- /dev/null
+++ b/mlir/test/IR/custom-float-attr-roundtrip.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @test_custom_float_attr_roundtrip
+func.func @test_custom_float_attr_roundtrip() -> () {
+  // CHECK: attr = #test.custom_float<"float" : 2.000000e+00>
+  "test.op"() {attr = #test.custom_float<"float" : 2.>} : () -> ()
+  // CHECK: attr = #test.custom_float<"double" : 2.000000e+00>
+  "test.op"() {attr = #test.custom_float<"double" : 2.>} : () -> ()
+  // CHECK: attr = #test.custom_float<"fp80" : 2.000000e+00>
+  "test.op"() {attr = #test.custom_float<"fp80" : 2.>} : () -> ()
+  // CHECK: attr = #test.custom_float<"float" : 0x7FC00000>
+  "test.op"() {attr = #test.custom_float<"float" : 0x7FC00000>} : () -> ()
+  // CHECK: attr = #test.custom_float<"double" : 0x7FF0000001000000>
+  "test.op"() {attr = #test.custom_float<"double" : 0x7FF0000001000000>} : () -> ()
+  // CHECK: attr = #test.custom_float<"fp80" : 0x7FFFC000000000100000>
+  "test.op"() {attr = #test.custom_float<"fp80" : 0x7FFFC000000000100000>} : () -> ()
+  return
+}
+
+// -----
+
+// A decimal integer literal is rejected: use a hex bit pattern or a float.
+
+// expected-error @below {{unexpected decimal integer literal for a floating point value}}
+// expected-note @below {{add a trailing dot to make the literal a float}}
+"test.op"() {attr = #test.custom_float<"float" : 42>} : () -> ()
+
+// -----
+
+// A hexadecimal literal wider than 32 bits does not fit an IEEE single.
+
+// expected-error @below {{hexadecimal float constant out of range for type}}
+"test.op"() {attr = #test.custom_float<"float" : 0x7FC000000>} : () -> ()
+
+
+// -----
+
+// A hexadecimal literal wider than 64 bits does not fit an IEEE double.
+
+// expected-error @below {{hexadecimal float constant out of range for type}}
+"test.op"() {attr = #test.custom_float<"double" : 0x7FC000007FC0000000>} : () -> ()
+
+
+// -----
+
+// A hexadecimal literal wider than 80 bits does not fit an x87 fp80.
+
+// expected-error @below {{hexadecimal float constant out of range for type}}
+"test.op"() {attr = #test.custom_float<"fp80" : 0x7FC0000007FC0000007FC000000>} : () -> ()
+
+// -----
+
+// The value must be a floating point literal or an integer literal.
+
+// expected-error @below {{expected floating point literal}}
+"test.op"() {attr = #test.custom_float<"float" : "blabla">} : () -> ()
+
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index bebbb876391d07..020942e7f4c11b 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1105,6 +1105,30 @@ func.func @bfloat16_special_values() {
return
}
+// CHECK-LABEL: @f80_special_values
+func.func @f80_special_values() {
+  // F80 signaling NaNs (integer bit 63 set, quiet bit 62 clear).
+  // CHECK: arith.constant 0x7FFFA000000000000001 : f80
+  %0 = arith.constant 0x7FFFA000000000000001 : f80
+  // CHECK: arith.constant 0x7FFFB000000000000011 : f80
+  %1 = arith.constant 0x7FFFB000000000000011 : f80
+
+  // F80 quiet NaNs (integer bit 63 and quiet bit 62 both set).
+  // CHECK: arith.constant 0x7FFFC000000000100000 : f80
+  %2 = arith.constant 0x7FFFC000000000100000 : f80
+  // CHECK: arith.constant 0x7FFFE000000001000000 : f80
+  %3 = arith.constant 0x7FFFE000000001000000 : f80
+
+  // F80 positive infinity.
+  // CHECK: arith.constant 0x7FFF8000000000000000 : f80
+  %4 = arith.constant 0x7FFF8000000000000000 : f80
+  // F80 negative infinity.
+  // CHECK: arith.constant 0xFFFF8000000000000000 : f80
+  %5 = arith.constant 0xFFFF8000000000000000 : f80
+
+  return
+}
+
// We want to print floats in exponential notation with 6 significant digits,
// but it may lead to precision loss when parsing back, in which case we print
// the decimal form instead.
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 40f035a3e3a4e5..12635e107bd42c 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -340,4 +340,15 @@ def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> {
}];
}
+// Test the AsmParser::parseFloat(const fltSemantics&, APFloat&) API through a
+// custom parser and printer. The attribute stores the name of the float type
+// ("float", "double" or "fp80") alongside the value parsed with it.
+def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
+  let mnemonic = "custom_float";
+  let parameters = (ins "mlir::StringAttr":$type_str,
+                        APFloatParameter<"the parsed floating point value">:$value);
+
+  let assemblyFormat = [{
+    `<` custom<CustomFloatAttr>($type_str, $value) `>`
+  }];
+}
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
index 2cc051e664beec..d7e40d35238d91 100644
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -240,6 +241,46 @@ static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
p.printKeywordOrString(value);
}
+//===----------------------------------------------------------------------===//
+// Custom Float Attribute
+//===----------------------------------------------------------------------===//
+
+/// Print the custom float attribute body as `"<type>" : <value>`, for example
+/// `"float" : 2.000000e+00`. The value is taken by const reference so that
+/// multi-word APFloats (e.g. fp80) are not copied on every print.
+static void printCustomFloatAttr(AsmPrinter &p, StringAttr typeStrAttr,
+                                 const APFloat &value) {
+  p << typeStrAttr << " : " << value;
+}
+
+/// Parse the custom float attribute body `"<type>" : <value>`, where <type> is
+/// one of "float", "double" or "fp80" and selects the floating point semantics
+/// used to parse <value> via AsmParser::parseFloat(semantics, result).
+static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
+                                        FailureOr<APFloat> &value) {
+  // Record where the type string starts so that an unknown type is reported
+  // at the type string itself rather than at the token that follows it.
+  llvm::SMLoc typeLoc = p.getCurrentLocation();
+  std::string str;
+  if (p.parseString(&str))
+    return failure();
+
+  const llvm::fltSemantics *semantics = nullptr;
+  if (str == "float")
+    semantics = &llvm::APFloat::IEEEsingle();
+  else if (str == "double")
+    semantics = &llvm::APFloat::IEEEdouble();
+  else if (str == "fp80")
+    semantics = &llvm::APFloat::x87DoubleExtended();
+  else
+    return p.emitError(typeLoc, "unknown float type, expected "
+                                "'float', 'double' or 'fp80'");
+
+  typeStrAttr = StringAttr::get(p.getContext(), str);
+
+  if (p.parseColon())
+    return failure();
+
+  APFloat parsedValue(0.0);
+  if (p.parseFloat(*semantics, parsedValue))
+    return failure();
+
+  value.emplace(std::move(parsedValue));
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// Tablegen Generated Definitions
//===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list