[Mlir-commits] [mlir] [mlir][Parser] Add `nan` and `inf` keywords (PR #116176)

Matthias Springer llvmlistbot at llvm.org
Mon Nov 18 02:10:36 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/116176

>From 80171c91f69739b697b8022b5317856e50cfc5be Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Thu, 14 Nov 2024 08:47:06 +0100
Subject: [PATCH] [mlir][Parser] Add `nan` and `inf` keywords

---
 mlir/lib/AsmParser/AttributeParser.cpp        | 32 +++++++----
 mlir/lib/AsmParser/Parser.cpp                 | 22 ++++++++
 mlir/lib/AsmParser/TokenKinds.def             |  2 +
 .../SuperVectorize/vectorize_reduction.mlir   |  4 +-
 mlir/test/Dialect/Arith/canonicalize.mlir     | 10 ++--
 .../Linalg/fusion-elementwise-ops.mlir        |  2 +-
 mlir/test/Dialect/Tosa/canonicalize.mlir      |  6 +--
 .../Tosa/constant-reciprocal-fold.mlir        |  2 +-
 mlir/test/IR/attribute.mlir                   | 54 +++++++++++++++++++
 .../tile-and-fuse-using-interface.mlir        |  2 +-
 .../tile-and-fuse-using-scfforall.mlir        |  2 +-
 mlir/test/Transforms/constant-fold.mlir       |  2 +-
 .../math-polynomial-approx.mlir               | 36 ++++++-------
 .../test-expand-math-approx.mlir              | 10 ++--
 14 files changed, 137 insertions(+), 49 deletions(-)

diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index ff616dac9625b4..4191e49a90f04b 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -21,8 +21,10 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/IntegerSet.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Endian.h"
+#include <cmath>
 #include <optional>
 
 using namespace mlir;
@@ -121,6 +123,8 @@ Attribute Parser::parseAttribute(Type type) {
 
   // Parse floating point and integer attributes.
   case Token::floatliteral:
+  case Token::kw_inf:
+  case Token::kw_nan:
     return parseFloatAttr(type, /*isNegative=*/false);
   case Token::integer:
     return parseDecOrHexAttr(type, /*isNegative=*/false);
@@ -128,7 +132,8 @@ Attribute Parser::parseAttribute(Type type) {
     consumeToken(Token::minus);
     if (getToken().is(Token::integer))
       return parseDecOrHexAttr(type, /*isNegative=*/true);
-    if (getToken().is(Token::floatliteral))
+    if (getToken().is(Token::floatliteral) || getToken().is(Token::kw_inf) ||
+        getToken().is(Token::kw_nan))
       return parseFloatAttr(type, /*isNegative=*/true);
 
     return (emitWrongTokenError(
@@ -342,10 +347,8 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
 
 /// Parse a float attribute.
 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
-  auto val = getToken().getFloatingPointValue();
-  if (!val)
-    return (emitError("floating point value too large for attribute"), nullptr);
-  consumeToken(Token::floatliteral);
+  const Token tok = getToken();
+  consumeToken();
   if (!type) {
     // Default to F64 when no type is specified.
     if (!consumeIf(Token::colon))
@@ -353,10 +356,16 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
     else if (!(type = parseType()))
       return nullptr;
   }
-  if (!isa<FloatType>(type))
-    return (emitError("floating point value not valid for specified type"),
+  auto floatType = dyn_cast<FloatType>(type);
+  if (!floatType)
+    return (emitError(tok.getLoc(),
+                      "floating point value not valid for specified type"),
             nullptr);
-  return FloatAttr::get(type, isNegative ? -*val : *val);
+  std::optional<APFloat> apResult;
+  if (failed(parseFloatFromLiteral(apResult, tok, isNegative,
+                                   floatType.getFloatSemantics())))
+    return Attribute();
+  return FloatAttr::get(floatType, *apResult);
 }
 
 /// Construct an APint from a parsed value, a known attribute type and
@@ -622,7 +631,7 @@ TensorLiteralParser::getIntAttrElements(SMLoc loc, Type eltTy,
     }
 
     // Check to see if floating point values were parsed.
-    if (token.is(Token::floatliteral)) {
+    if (token.isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan)) {
       return p.emitError(tokenLoc)
              << "expected integer elements, but parsed floating-point";
     }
@@ -729,6 +738,8 @@ ParseResult TensorLiteralParser::parseElement() {
   // Parse a boolean element.
   case Token::kw_true:
   case Token::kw_false:
+  case Token::kw_inf:
+  case Token::kw_nan:
   case Token::floatliteral:
   case Token::integer:
     storage.emplace_back(/*isNegative=*/false, p.getToken());
@@ -738,7 +749,8 @@ ParseResult TensorLiteralParser::parseElement() {
   // Parse a signed integer or a negative floating-point element.
   case Token::minus:
     p.consumeToken(Token::minus);
-    if (!p.getToken().isAny(Token::floatliteral, Token::integer))
+    if (!p.getToken().isAny(Token::floatliteral, Token::kw_inf, Token::kw_nan,
+                            Token::integer))
       return p.emitError("expected integer or floating point literal");
     storage.emplace_back(/*isNegative=*/true, p.getToken());
     p.consumeToken();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index e3db248164672c..3783dd95f4b61c 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -350,11 +350,33 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
 ParseResult Parser::parseFloatFromLiteral(std::optional<APFloat> &result,
                                           const Token &tok, bool isNegative,
                                           const llvm::fltSemantics &semantics) {
+  // Check for inf keyword.
+  if (tok.is(Token::kw_inf)) {
+    if (!APFloat::semanticsHasInf(semantics))
+      return emitError(tok.getLoc())
+             << "floating point type does not support infinity";
+    result = APFloat::getInf(semantics, isNegative);
+    return success();
+  }
+
+  // Check for NaN keyword.
+  if (tok.is(Token::kw_nan)) {
+    if (!APFloat::semanticsHasNaN(semantics))
+      return emitError(tok.getLoc())
+             << "floating point type does not support NaN";
+    result = APFloat::getNaN(semantics, isNegative);
+    return success();
+  }
+
   // Check for a floating point value.
   if (tok.is(Token::floatliteral)) {
     auto val = tok.getFloatingPointValue();
     if (!val)
       return emitError(tok.getLoc()) << "floating point value too large";
+    if (std::fpclassify(*val) == FP_ZERO &&
+        !APFloat::semanticsHasZero(semantics))
+      return emitError(tok.getLoc())
+             << "floating point type does not support zero";
 
     result.emplace(isNegative ? -*val : *val);
     bool unused;
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 49da8c3dea5fa5..9208c8adddcfce 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -111,11 +111,13 @@ TOK_KEYWORD(floordiv)
 TOK_KEYWORD(for)
 TOK_KEYWORD(func)
 TOK_KEYWORD(index)
+TOK_KEYWORD(inf)
 TOK_KEYWORD(loc)
 TOK_KEYWORD(max)
 TOK_KEYWORD(memref)
 TOK_KEYWORD(min)
 TOK_KEYWORD(mod)
+TOK_KEYWORD(nan)
 TOK_KEYWORD(none)
 TOK_KEYWORD(offset)
 TOK_KEYWORD(size)
diff --git a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
index 29c42fcd50bd74..0fef5183e3405f 100644
--- a/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
+++ b/mlir/test/Dialect/Affine/SuperVectorize/vectorize_reduction.mlir
@@ -30,7 +30,7 @@ func.func @vecdim_reduction(%in: memref<256x512xf32>, %out: memref<256xf32>) {
 // -----
 
 func.func @vecdim_reduction_minf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
- %cst = arith.constant 0x7F800000 : f32
+ %cst = arith.constant inf : f32
  affine.for %i = 0 to 256 {
    %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
      %ld = affine.load %in[%i, %j] : memref<256x512xf32>
@@ -57,7 +57,7 @@ func.func @vecdim_reduction_minf(%in: memref<256x512xf32>, %out: memref<256xf32>
 // -----
 
 func.func @vecdim_reduction_maxf(%in: memref<256x512xf32>, %out: memref<256xf32>) {
- %cst = arith.constant 0xFF800000 : f32
+ %cst = arith.constant -inf : f32
  affine.for %i = 0 to 256 {
    %final_red = affine.for %j = 0 to 512 iter_args(%red_iter = %cst) -> (f32) {
      %ld = affine.load %in[%i, %j] : memref<256x512xf32>
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index a386a178b78995..c86b2b5f63f016 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -1880,7 +1880,7 @@ func.func @test_minimumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-NEXT:  %[[X:.+]] = arith.minimumf %arg0, %[[C0]]
   // CHECK-NEXT:  return %[[X]], %arg0, %arg0
   %c0 = arith.constant 0.0 : f32
-  %inf = arith.constant 0x7F800000 : f32
+  %inf = arith.constant inf : f32
   %0 = arith.minimumf %c0, %arg0 : f32
   %1 = arith.minimumf %arg0, %arg0 : f32
   %2 = arith.minimumf %inf, %arg0 : f32
@@ -1895,7 +1895,7 @@ func.func @test_maximumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-NEXT:  %[[X:.+]] = arith.maximumf %arg0, %[[C0]]
   // CHECK-NEXT:   return %[[X]], %arg0, %arg0
   %c0 = arith.constant 0.0 : f32
-  %-inf = arith.constant 0xFF800000 : f32
+  %-inf = arith.constant -inf : f32
   %0 = arith.maximumf %c0, %arg0 : f32
   %1 = arith.maximumf %arg0, %arg0 : f32
   %2 = arith.maximumf %-inf, %arg0 : f32
@@ -1910,7 +1910,7 @@ func.func @test_minnumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-NEXT:  %[[X:.+]] = arith.minnumf %arg0, %[[C0]]
   // CHECK-NEXT:  return %[[X]], %arg0, %arg0
   %c0 = arith.constant 0.0 : f32
-  %inf = arith.constant 0x7F800000 : f32
+  %inf = arith.constant inf : f32
   %0 = arith.minnumf %c0, %arg0 : f32
   %1 = arith.minnumf %arg0, %arg0 : f32
   %2 = arith.minnumf %inf, %arg0 : f32
@@ -1925,7 +1925,7 @@ func.func @test_maxnumf(%arg0 : f32) -> (f32, f32, f32) {
   // CHECK-NEXT:  %[[X:.+]] = arith.maxnumf %arg0, %[[C0]]
   // CHECK-NEXT:   return %[[X]], %arg0, %arg0
   %c0 = arith.constant 0.0 : f32
-  %-inf = arith.constant 0xFF800000 : f32
+  %-inf = arith.constant -inf : f32
   %0 = arith.maxnumf %c0, %arg0 : f32
   %1 = arith.maxnumf %arg0, %arg0 : f32
   %2 = arith.maxnumf %-inf, %arg0 : f32
@@ -2024,7 +2024,7 @@ func.func @test_cmpf(%arg0 : f32) -> (i1, i1, i1, i1) {
 //   CHECK-DAG:   %[[T:.*]] = arith.constant true
 //   CHECK-DAG:   %[[F:.*]] = arith.constant false
 //       CHECK:   return %[[F]], %[[F]], %[[T]], %[[T]]
-  %nan = arith.constant 0x7fffffff : f32
+  %nan = arith.constant nan : f32
   %0 = arith.cmpf olt, %nan, %arg0 : f32
   %1 = arith.cmpf olt, %arg0, %nan : f32
   %2 = arith.cmpf ugt, %nan, %arg0 : f32
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 28e1291bce1fad..13846954492d07 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -791,7 +791,7 @@ func.func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, te
 // CHECK: linalg.generic
 // CHECK: linalg.generic
 func.func @no_fusion_missing_reduction_shape(%arg0: tensor<f32>, %arg1: index) -> tensor<?xf32> {
-  %cst = arith.constant 0xFF800000 : f32
+  %cst = arith.constant -inf : f32
   %4 = tensor.empty(%arg1, %arg1) : tensor<?x?xf32>
   %5 = linalg.generic {
     indexing_maps = [#map0, #map1],
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 67cd01f62f0bdf..c96fccd0d7b3b6 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -80,8 +80,7 @@ func.func @clamp_f32_not_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
 func.func @clamp_f16_is_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
   // CHECK: return %arg0
   // CHECK-NOT: "tosa.clamp"
-  // 0xFF800000 and 0x7F800000 are respectively negative and positive F32 infinity.
-  %0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = 0xFF800000 : f32, max_fp = 0x7F800000 : f32} : (tensor<4xf16>) -> tensor<4xf16>
+  %0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = -inf : f32, max_fp = inf : f32} : (tensor<4xf16>) -> tensor<4xf16>
   return %0 : tensor<4xf16>
 }
 
@@ -91,8 +90,7 @@ func.func @clamp_f16_is_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
 func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: "tosa.clamp"
-  // 0xFF800000 and 0x7F800000 are respectively negative and positive F32 infinity.
-  %0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = 0xFF800000 : f32, max_fp = 0x7F800000 : f32} : (tensor<4xf32>) -> tensor<4xf32>
+  %0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = -inf : f32, max_fp = inf : f32} : (tensor<4xf32>) -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 
diff --git a/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir b/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir
index cc71c43d53ce29..2a2098d33e10fd 100644
--- a/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-reciprocal-fold.mlir
@@ -58,7 +58,7 @@ func.func @reciprocal_div_infinity() -> tensor<f32> {
   // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}<0.{{0*}}e+00>
   // CHECK-NOT: tosa.reciprocal
   // CHECK: return [[RES]]
-  %0 = "tosa.const"() {value = dense<0x7F800000> : tensor<f32>} : () -> tensor<f32>
+  %0 = "tosa.const"() {value = dense<inf> : tensor<f32>} : () -> tensor<f32>
   %1 = "tosa.reciprocal"(%0) : (tensor<f32>) -> tensor<f32>
   return %1 : tensor<f32>
 }
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
index a62de3f5004d73..1fbb4986ab2ea7 100644
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -108,9 +108,63 @@ func.func @float_attrs_pass() {
     // CHECK: float_attr = 2.000000e+00 : f128
     float_attr = 2. : f128
   } : () -> ()
+  "test.float_attrs"() {
+    // Note: nan/inf are printed in binary format because there may be multiple
+    // nan/inf representations.
+    // CHECK: float_attr = 0x7FC00000 : f32
+    float_attr = nan : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0x7C : f8E4M3
+    float_attr = nan : f8E4M3
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0xFFC00000 : f32
+    float_attr = -nan : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0xFC : f8E4M3
+    float_attr = -nan : f8E4M3
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0x7F800000 : f32
+    float_attr = inf : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0x78 : f8E4M3
+    float_attr = inf : f8E4M3
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0xFF800000 : f32
+    float_attr = -inf : f32
+  } : () -> ()
+  "test.float_attrs"() {
+    // CHECK: float_attr = 0xF8 : f8E4M3
+    float_attr = -inf : f8E4M3
+  } : () -> ()
   return
 }
 
+// -----
+
+func.func @float_nan_unsupported() {
+  "test.float_attrs"() {
+    // expected-error @below{{floating point type does not support NaN}}
+    float_attr = nan : f4E2M1FN
+  } : () -> ()
+}
+
+// -----
+
+func.func @float_inf_unsupported() {
+  "test.float_attrs"() {
+    // expected-error @below{{floating point type does not support infinity}}
+    float_attr = inf : f4E2M1FN
+  } : () -> ()
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // Test integer attributes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 5f7663af773a4a..dfd26ef9264cf0 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -470,7 +470,7 @@ module attributes {transform.with_named_sequence} {
 
 func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
   %cst = arith.constant 0.000000e+00 : f32
-  %cst_0 = arith.constant 0xFF800000 : f32
+  %cst_0 = arith.constant -inf : f32
   %0 = tensor.empty() : tensor<30xf32>
   %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
   %2 = linalg.generic {
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir
index 0bd2546e082b5a..5dd622e6fd51df 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir
@@ -101,7 +101,7 @@ module attributes {transform.with_named_sequence} {
 
 func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
   %cst = arith.constant 0.000000e+00 : f32
-  %cst_0 = arith.constant 0xFF800000 : f32
+  %cst_0 = arith.constant -inf : f32
   %0 = tensor.empty() : tensor<30xf32>
   %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32>
   %2 = linalg.generic {
diff --git a/mlir/test/Transforms/constant-fold.mlir b/mlir/test/Transforms/constant-fold.mlir
index 981757aed9b1d6..f6d1873dc237a0 100644
--- a/mlir/test/Transforms/constant-fold.mlir
+++ b/mlir/test/Transforms/constant-fold.mlir
@@ -752,7 +752,7 @@ func.func @cmpf_nan() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
 // CHECK-LABEL: func @cmpf_inf
 func.func @cmpf_inf() -> (i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1, i1) {
   %c42 = arith.constant 42. : f32
-  %cpinf = arith.constant 0x7F800000 : f32
+  %cpinf = arith.constant inf : f32
   // CHECK-DAG: [[F:%.+]] = arith.constant false
   // CHECK-DAG: [[T:%.+]] = arith.constant true
   // CHECK-NEXT: return [[F]],
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
index b8861198d596b0..28b656b0da5f1a 100644
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -41,7 +41,7 @@ func.func @tanh() {
   call @tanh_8xf32(%v2) : (vector<8xf32>) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @tanh_f32(%nan) : (f32) -> ()
 
  return
@@ -87,15 +87,15 @@ func.func @log() {
   call @log_f32(%zero) : (f32) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @log_f32(%nan) : (f32) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @log_f32(%inf) : (f32) -> ()
 
   // CHECK: -inf, nan, inf, 0.693147
-  %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 2.0]> : vector<4xf32>
+  %special_vec = arith.constant dense<[0.0, -1.0, inf, 2.0]> : vector<4xf32>
   call @log_4xf32(%special_vec) : (vector<4xf32>) -> ()
 
   return
@@ -141,11 +141,11 @@ func.func @log2() {
   call @log2_f32(%neg_one) : (f32) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @log2_f32(%inf) : (f32) -> ()
 
   // CHECK: -inf, nan, inf, 1.58496
-  %special_vec = arith.constant dense<[0.0, -1.0, 0x7f800000, 3.0]> : vector<4xf32>
+  %special_vec = arith.constant dense<[0.0, -1.0, inf, 3.0]> : vector<4xf32>
   call @log2_4xf32(%special_vec) : (vector<4xf32>) -> ()
 
   return
@@ -192,11 +192,11 @@ func.func @log1p() {
   call @log1p_f32(%neg_two) : (f32) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @log1p_f32(%inf) : (f32) -> ()
 
   // CHECK: -inf, nan, inf, 9.99995e-06
-  %special_vec = arith.constant dense<[-1.0, -1.1, 0x7f800000, 0.00001]> : vector<4xf32>
+  %special_vec = arith.constant dense<[-1.0, -1.1, inf, 0.00001]> : vector<4xf32>
   call @log1p_4xf32(%special_vec) : (vector<4xf32>) -> ()
 
   return
@@ -247,7 +247,7 @@ func.func @erf() {
   call @erf_f32(%val7) : (f32) -> ()
 
   // CHECK: -1
-  %negativeInf = arith.constant 0xff800000 : f32
+  %negativeInf = arith.constant -inf : f32
   call @erf_f32(%negativeInf) : (f32) -> ()
 
   // CHECK: -1, -1, -0.913759, -0.731446
@@ -263,11 +263,11 @@ func.func @erf() {
   call @erf_4xf32(%vecVals3) : (vector<4xf32>) -> ()
 
   // CHECK: 1
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @erf_f32(%inf) : (f32) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @erf_f32(%nan) : (f32) -> ()
 
   return
@@ -306,15 +306,15 @@ func.func @exp() {
   call @exp_4xf32(%special_vec) : (vector<4xf32>) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @exp_f32(%inf) : (f32) -> ()
 
   // CHECK: 0
-  %negative_inf = arith.constant 0xff800000 : f32
+  %negative_inf = arith.constant -inf : f32
   call @exp_f32(%negative_inf) : (f32) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @exp_f32(%nan) : (f32) -> ()
 
   return
@@ -358,19 +358,19 @@ func.func @expm1() {
   call @expm1_8xf32(%v2) : (vector<8xf32>) -> ()
 
   // CHECK: -1
-  %neg_inf = arith.constant 0xff800000 : f32
+  %neg_inf = arith.constant -inf : f32
   call @expm1_f32(%neg_inf) : (f32) -> ()
 
   // CHECK: inf
-  %inf = arith.constant 0x7f800000 : f32
+  %inf = arith.constant inf : f32
   call @expm1_f32(%inf) : (f32) -> ()
 
   // CHECK: -1, inf, 1e-10
-  %special_vec = arith.constant dense<[0xff800000, 0x7f800000, 1.0e-10]> : vector<3xf32>
+  %special_vec = arith.constant dense<[-inf, inf, 1.0e-10]> : vector<3xf32>
   call @expm1_3xf32(%special_vec) : (vector<3xf32>) -> ()
 
   // CHECK: nan
-  %nan = arith.constant 0x7fc00000 : f32
+  %nan = arith.constant nan : f32
   call @expm1_f32(%nan) : (f32) -> ()
 
   return
diff --git a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
index 80d559cc6f730b..fd076d47cad955 100644
--- a/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/test-expand-math-approx.mlir
@@ -47,7 +47,7 @@ func.func @exp2f() {
   call @func_exp2f(%g) : (f64) -> ()
 
   // CHECK-NEXT: 0
-  %neg_inf = arith.constant 0xff80000000000000 : f64
+  %neg_inf = arith.constant -inf : f64
   call @func_exp2f(%neg_inf) : (f64) -> ()
 
   // CHECK-NEXT: inf
@@ -114,7 +114,7 @@ func.func @roundf() {
   // Special values: 0, -0, inf, -inf, nan, -nan
   %cNeg0 = arith.constant -0.0 : f32
   %c0 = arith.constant 0.0 : f32
-  %cInfInt = arith.constant 0x7f800000 : i32
+  %cInfInt = arith.constant 0x7f800000 : i32
   %cInf = arith.bitcast %cInfInt : i32 to f32
   %cNegInfInt = arith.constant 0xff800000 : i32
   %cNegInf = arith.bitcast %cNegInfInt : i32 to f32
@@ -229,7 +229,7 @@ func.func @powf() {
 
   // CHECK-NEXT: nan
   %i = arith.constant 1.0 : f64
-  %h = arith.constant 0x7fffffffffffffff : f64
+  %h = arith.constant nan : f64
   call @func_powff64(%i, %h) : (f64, f64) -> ()
 
   // CHECK-NEXT: inf
@@ -370,7 +370,7 @@ func.func @roundeven32() {
   // Special values: 0, -0, inf, -inf, nan, -nan
   %cNeg0 = arith.constant -0.0 : f32
   %c0 = arith.constant 0.0 : f32
-  %cInfInt = arith.constant 0x7f800000 : i32
+  %cInfInt = arith.constant 0x7f800000 : i32
   %cInf = arith.bitcast %cInfInt : i32 to f32
   %cNegInfInt = arith.constant 0xff800000 : i32
   %cNegInf = arith.bitcast %cNegInfInt : i32 to f32
@@ -711,7 +711,7 @@ func.func @tanh_8xf32(%a : vector<8xf32>) {
 
 func.func @tanh() {
   // CHECK: -1, -0.761594, -0.291313, 0, 0.291313, 0.761594, 1, 1
-  %v3 = arith.constant dense<[0xff800000, -1.0, -0.3, 0.0, 0.3, 1.0, 10.0, 0x7f800000]> : vector<8xf32>
+  %v3 = arith.constant dense<[-inf, -1.0, -0.3, 0.0, 0.3, 1.0, 10.0, inf]> : vector<8xf32>
   call @tanh_8xf32(%v3) : (vector<8xf32>) -> ()
 
  return



More information about the Mlir-commits mailing list