[Mlir-commits] [mlir] 7d59f49 - [mlir] Fix representation of BF16 constants

Diego Caballero llvmlistbot at llvm.org
Fri Jun 5 17:46:25 PDT 2020


Author: Diego Caballero
Date: 2020-06-05T17:43:06-07:00
New Revision: 7d59f49bdaddf053d74de9ef57c7ec64bdf4fa25

URL: https://github.com/llvm/llvm-project/commit/7d59f49bdaddf053d74de9ef57c7ec64bdf4fa25
DIFF: https://github.com/llvm/llvm-project/commit/7d59f49bdaddf053d74de9ef57c7ec64bdf4fa25.diff

LOG: [mlir] Fix representation of BF16 constants

This patch is a follow-up to https://reviews.llvm.org/D81127

BF16 constants were represented as 64-bit floating-point values because APFloat
lacked support for BF16. APFloat was recently extended to support BF16, so this
patch fixes the BF16 constant representation to be 16 bits wide.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D81218

Added: 
    

Modified: 
    mlir/lib/IR/AttributeDetail.h
    mlir/lib/IR/StandardTypes.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/IR/dense-elements-hex.mlir
    mlir/test/IR/parser.mlir
    mlir/test/Target/llvmir.mlir
    mlir/unittests/IR/AttributeTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index f0b7794c2fe2..e6c9ae5ed59c 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -132,9 +132,7 @@ struct FloatAttributeStorage final
 
   /// Construct a key with a type and double.
   static KeyTy getKey(Type type, double value) {
-    // Treat BF16 as double because it is not supported in LLVM's APFloat.
-    // TODO(b/121118307): add BF16 support to APFloat?
-    if (type.isBF16() || type.isF64())
+    if (type.isF64())
       return KeyTy(type, APFloat(value));
 
     // This handles, e.g., F16 because there is no APFloat constructor for it.
@@ -355,10 +353,6 @@ inline size_t getDenseElementBitWidth(Type eltType) {
   // Align the width for complex to 8 to make storage and interpretation easier.
   if (ComplexType comp = eltType.dyn_cast<ComplexType>())
     return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2;
-  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
-  // with double semantics.
-  if (eltType.isBF16())
-    return 64;
   if (eltType.isIndex())
     return IndexType::kInternalStorageBitWidth;
   return eltType.getIntOrFloatBitWidth();

diff  --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
index 808b4fc910d2..7a823ee49cfd 100644
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -157,12 +157,7 @@ unsigned FloatType::getWidth() {
 /// Returns the floating semantics for the given type.
 const llvm::fltSemantics &FloatType::getFloatSemantics() {
   if (isBF16())
-    // Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is
-    // not defined in LLVM.
-    // TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc
-    // else one could add it.
-    //  static const fltSemantics semBF16 = {127, -126, 8, 16};
-    return APFloat::IEEEdouble();
+    return APFloat::BFloat();
   if (isF16())
     return APFloat::IEEEhalf();
   if (isF32())

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index d5108a4ed29e..88e576c36df7 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1774,9 +1774,7 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
 /// Construct a float attribute bitwise equivalent to the integer literal.
 static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type,
                                                       uint64_t value) {
-  // FIXME: bfloat is currently stored as a double internally because it doesn't
-  // have valid APFloat semantics.
-  if (type.isF64() || type.isBF16())
+  if (type.isF64())
     return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
 
   APInt apInt(type.getWidth(), value);
@@ -2153,9 +2151,8 @@ TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
     if (!val.hasValue())
       return p.emitError("floating point value too large for attribute");
 
-    // Treat BF16 as double because it is not supported in LLVM's APFloat.
     APFloat apVal(isNegative ? -*val : *val);
-    if (!eltTy.isBF16() && !eltTy.isF64()) {
+    if (!eltTy.isF64()) {
       bool unused;
       apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
                     &unused);

diff  --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir
index e0e12418e1d5..4b53467d8285 100644
--- a/mlir/test/IR/dense-elements-hex.mlir
+++ b/mlir/test/IR/dense-elements-hex.mlir
@@ -11,7 +11,7 @@
 "foo.op"() {dense.attr = dense<"0x0000000000002440000000000000144000000000000024400000000000001440"> : tensor<2xcomplex<f64>>} : () -> ()
 
 // CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xbf16>
-"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xbf16>} : () -> ()
+"foo.op"() {dense.attr = dense<"0x2041A040"> : tensor<2xbf16>} : () -> ()
 
 // -----
 

diff  --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index ba44f992093f..733f04d3a690 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1073,28 +1073,26 @@ func @f64_special_values() {
   return
 }
 
-// FIXME: bfloat16 currently uses f64 as a storage format. This test should be
-// changed when that gets fixed.
 // CHECK-LABEL: @bfloat16_special_values
 func @bfloat16_special_values() {
   // bfloat16 signaling NaNs.
-  // CHECK: constant 0x7FF0000000000001 : bf16
-  %0 = constant 0x7FF0000000000001 : bf16
-  // CHECK: constant 0x7FF8000000000000 : bf16
-  %1 = constant 0x7FF8000000000000 : bf16
+  // CHECK: constant 0x7F81 : bf16
+  %0 = constant 0x7F81 : bf16
+  // CHECK: constant 0xFF81 : bf16
+  %1 = constant 0xFF81 : bf16
 
   // bfloat16 quiet NaNs.
-  // CHECK: constant 0x7FF0000001000000 : bf16
-  %2 = constant 0x7FF0000001000000 : bf16
-  // CHECK: constant 0xFFF0000001000000 : bf16
-  %3 = constant 0xFFF0000001000000 : bf16
+  // CHECK: constant 0x7FC0 : bf16
+  %2 = constant 0x7FC0 : bf16
+  // CHECK: constant 0xFFC0 : bf16
+  %3 = constant 0xFFC0 : bf16
 
   // bfloat16 positive infinity.
-  // CHECK: constant 0x7FF0000000000000 : bf16
-  %4 = constant 0x7FF0000000000000 : bf16
+  // CHECK: constant 0x7F80 : bf16
+  %4 = constant 0x7F80 : bf16
   // bfloat16 negative infinity.
-  // CHECK: constant 0xFFF0000000000000 : bf16
-  %5 = constant 0xFFF0000000000000 : bf16
+  // CHECK: constant 0xFF80 : bf16
+  %5 = constant 0xFF80 : bf16
 
   return
 }
@@ -1215,12 +1213,12 @@ func @pretty_names() {
   %x = test.string_attr_pretty_name
   // CHECK: %x = test.string_attr_pretty_name
   // CHECK-NOT: attributes
-  
+
   // This specifies an explicit name, which should override the result.
   %YY = test.string_attr_pretty_name attributes { names = ["y"] }
   // CHECK: %y = test.string_attr_pretty_name
   // CHECK-NOT: attributes
-  
+
   // Conflicts with the 'y' name, so need an explicit attribute.
   %0 = "test.string_attr_pretty_name"() { names = ["y"]} : () -> i32
   // CHECK: %y_0 = test.string_attr_pretty_name attributes {names = ["y"]}

diff  --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
index 88f0e2f47ca6..a052203c8ba8 100644
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
 
 // CHECK: @i32_global = internal global i32 42
 llvm.mlir.global internal @i32_global(42: i32) : !llvm.i32
@@ -1214,3 +1214,14 @@ llvm.func @passthrough() attributes {passthrough = ["noinline", ["alignstack", "
 // CHECK-DAG: alignstack=4
 // CHECK-DAG: null_pointer_is_valid
 // CHECK-DAG: "foo"="bar"
+
+// -----
+
+// CHECK-LABEL: @constant_bf16
+llvm.func @constant_bf16() -> !llvm<"bfloat"> {
+  %0 = llvm.mlir.constant(1.000000e+01 : bf16) : !llvm<"bfloat">
+  llvm.return %0 : !llvm<"bfloat">
+}
+
+// CHECK: ret bfloat 0xR4120
+

diff  --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 8fda2a2e73b6..df449a0da75c 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -134,7 +134,7 @@ TEST(DenseSplatTest, F64Splat) {
 
 TEST(DenseSplatTest, FloatAttrSplat) {
   MLIRContext context;
-  FloatType floatTy = FloatType::getBF16(&context);
+  FloatType floatTy = FloatType::getF32(&context);
   Attribute value = FloatAttr::get(floatTy, 10.0);
 
   testSplat(floatTy, value);
@@ -143,8 +143,7 @@ TEST(DenseSplatTest, FloatAttrSplat) {
 TEST(DenseSplatTest, BF16Splat) {
   MLIRContext context;
   FloatType floatTy = FloatType::getBF16(&context);
-  // Note: We currently use double to represent bfloat16.
-  double value = 10.0;
+  Attribute value = FloatAttr::get(floatTy, 10.0);
 
   testSplat(floatTy, value);
 }


        


More information about the Mlir-commits mailing list