[Mlir-commits] [mlir] 4a7364f - [mlir][Parser] Use APFloat instead of FloatAttr when parsing DenseElementsAttrs.

River Riddle llvmlistbot at llvm.org
Wed Feb 19 10:35:23 PST 2020


Author: River Riddle
Date: 2020-02-19T10:30:07-08:00
New Revision: 4a7364f1c2ef0c45d7e603799fe0b7662d4c4078

URL: https://github.com/llvm/llvm-project/commit/4a7364f1c2ef0c45d7e603799fe0b7662d4c4078
DIFF: https://github.com/llvm/llvm-project/commit/4a7364f1c2ef0c45d7e603799fe0b7662d4c4078.diff

LOG: [mlir][Parser] Use APFloat instead of FloatAttr when parsing DenseElementsAttrs.

Summary: DenseElementsAttr stores float values internally as raw bits, so creating an intermediate FloatAttr for every element only to unwrap it again is extremely inefficient.

Differential Revision: https://reviews.llvm.org/D74818

Added: 
    

Modified: 
    mlir/lib/Parser/Parser.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 8bd57a11888c..2a2219c4202f 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1734,22 +1734,19 @@ Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
 }
 
 /// Construct a float attribute bitwise equivalent to the integer literal.
-static FloatAttr buildHexadecimalFloatLiteral(Parser *p, FloatType type,
-                                              uint64_t value) {
+static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type,
+                                                      uint64_t value) {
   // FIXME: bfloat is currently stored as a double internally because it doesn't
   // have valid APFloat semantics.
-  if (type.isF64() || type.isBF16()) {
-    APFloat apFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
-    return p->builder.getFloatAttr(type, apFloat);
-  }
+  if (type.isF64() || type.isBF16())
+    return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
 
   APInt apInt(type.getWidth(), value);
   if (apInt != value) {
     p->emitError("hexadecimal float constant out of range for type");
-    return nullptr;
+    return llvm::None;
   }
-  APFloat apFloat(type.getFloatSemantics(), apInt);
-  return p->builder.getFloatAttr(type, apFloat);
+  return APFloat(type.getFloatSemantics(), apInt);
 }
 
 /// Parse a decimal or a hexadecimal literal, which can be either an integer
@@ -1787,7 +1784,9 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
     }
 
     // Construct a float attribute bitwise equivalent to the integer literal.
-    return buildHexadecimalFloatLiteral(this, floatType, *val);
+    Optional<APFloat> apVal =
+        buildHexadecimalFloatLiteral(this, floatType, *val);
+    return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
   }
 
   if (!type.isIntOrIndex())
@@ -1996,7 +1995,7 @@ DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
 DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
                                                     ShapedType type,
                                                     FloatType eltTy) {
-  std::vector<Attribute> floatValues;
+  std::vector<APFloat> floatValues;
   floatValues.reserve(storage.size());
   for (const auto &signAndToken : storage) {
     bool isNegative = signAndToken.first;
@@ -2014,10 +2013,10 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
         p.emitError("hexadecimal float constant out of range for attribute");
         return nullptr;
       }
-      FloatAttr attr = buildHexadecimalFloatLiteral(&p, eltTy, *val);
-      if (!attr)
+      Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val);
+      if (!apVal)
         return nullptr;
-      floatValues.push_back(attr);
+      floatValues.push_back(*apVal);
       continue;
     }
 
@@ -2033,7 +2032,14 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
       p.emitError("floating point value too large for attribute");
       return nullptr;
     }
-    floatValues.push_back(FloatAttr::get(eltTy, isNegative ? -*val : *val));
+    // Treat BF16 as double because it is not supported in LLVM's APFloat.
+    APFloat apVal(isNegative ? -*val : *val);
+    if (!eltTy.isBF16() && !eltTy.isF64()) {
+      bool unused;
+      apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
+                    &unused);
+    }
+    floatValues.push_back(apVal);
   }
 
   return DenseElementsAttr::get(type, floatValues);


        


More information about the Mlir-commits mailing list