[Mlir-commits] [mlir] 51bf5d3 - [mlir][Parser] Update DenseElementsAttr to print in hex when the number of elements is over a certain threshold.

River Riddle llvmlistbot at llvm.org
Thu Feb 20 14:45:45 PST 2020


Author: River Riddle
Date: 2020-02-20T14:40:58-08:00
New Revision: 51bf5d3cc19ac113de2ff185fb5bc2b99b8d89bc

URL: https://github.com/llvm/llvm-project/commit/51bf5d3cc19ac113de2ff185fb5bc2b99b8d89bc
DIFF: https://github.com/llvm/llvm-project/commit/51bf5d3cc19ac113de2ff185fb5bc2b99b8d89bc.diff

LOG: [mlir][Parser] Update DenseElementsAttr to print in hex when the number of elements is over a certain threshold.

Summary: DenseElementsAttr is used to store tensor data, which in some cases can become extremely large (hundreds of megabytes). In these cases it is much more efficient to format the data as a string of hex values instead of printing each element individually.

Differential Revision: https://reviews.llvm.org/D74922

Added: 
    mlir/test/IR/dense-elements-hex.mlir

Modified: 
    mlir/include/mlir/IR/Attributes.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/IR/Attributes.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/test/IR/invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 8a3b6cfa98b5..d5c063f326ce 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -723,6 +723,13 @@ class DenseElementsAttr
     return get(type, ArrayRef<T>(list));
   }
 
+  /// Construct a dense elements attribute from a raw buffer representing the
+  /// data for this attribute. Users should generally not use this method as
+  /// the expected buffer format may not be a form the user expects.
+  static DenseElementsAttr getFromRawBuffer(ShapedType type,
+                                            ArrayRef<char> rawBuffer,
+                                            bool isSplatBuffer);
+
   //===--------------------------------------------------------------------===//
   // Iterators
   //===--------------------------------------------------------------------===//
@@ -918,6 +925,11 @@ class DenseElementsAttr
   FloatElementIterator float_value_begin() const;
   FloatElementIterator float_value_end() const;
 
+  /// Return the raw storage data held by this attribute. Users should generally
+  /// not use this directly, as the internal storage format is not always in the
+  /// form the user might expect.
+  ArrayRef<char> getRawData() const;
+
   //===--------------------------------------------------------------------===//
   // Mutation Utilities
   //===--------------------------------------------------------------------===//
@@ -941,9 +953,6 @@ class DenseElementsAttr
             function_ref<APInt(const APFloat &)> mapping) const;
 
 protected:
-  /// Return the raw storage data held by this attribute.
-  ArrayRef<char> getRawData() const;
-
   /// Get iterators to the raw APInt values for each element in this attribute.
   IntElementIterator raw_int_begin() const {
     return IntElementIterator(*this, 0);

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f3d884f75781..05e3645151ea 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -63,6 +63,13 @@ OpAsmPrinter::~OpAsmPrinter() {}
 // OpPrintingFlags
 //===----------------------------------------------------------------------===//
 
+static llvm::cl::opt<int> printElementsAttrWithHexIfLarger(
+    "mlir-print-elementsattrs-with-hex-if-larger",
+    llvm::cl::desc(
+        "Print DenseElementsAttrs with a hex string that have "
+        "more elements than the given upper limit (use -1 to disable)"),
+    llvm::cl::init(100));
+
 static llvm::cl::opt<unsigned> elideElementsAttrIfLarger(
     "mlir-elide-elementsattrs-if-larger",
     llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
@@ -887,7 +894,10 @@ class ModulePrinter {
                              bool withKeyword = false);
   void printTrailingLocation(Location loc);
   void printLocationInternal(LocationAttr loc, bool pretty = false);
-  void printDenseElementsAttr(DenseElementsAttr attr);
+
+  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
+  /// used instead of individual elements when the elements attr is large.
+  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
 
   void printDialectAttribute(Attribute attr);
   void printDialectType(Type type);
@@ -1321,7 +1331,7 @@ void ModulePrinter::printAttribute(Attribute attr,
       break;
     }
     os << "dense<";
-    printDenseElementsAttr(eltsAttr);
+    printDenseElementsAttr(eltsAttr, /*allowHex=*/true);
     os << '>';
     break;
   }
@@ -1333,9 +1343,9 @@ void ModulePrinter::printAttribute(Attribute attr,
       break;
     }
     os << "sparse<";
-    printDenseElementsAttr(elementsAttr.getIndices());
+    printDenseElementsAttr(elementsAttr.getIndices(), /*allowHex=*/false);
     os << ", ";
-    printDenseElementsAttr(elementsAttr.getValues());
+    printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
     os << '>';
     break;
   }
@@ -1375,7 +1385,8 @@ static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
   printFloatValue(value, os);
 }
 
-void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
+void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
+                                           bool allowHex) {
   auto type = attr.getType();
   auto shape = type.getShape();
   auto rank = type.getRank();
@@ -1401,6 +1412,15 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
     return;
   }
 
+  // Check to see if we should format this attribute as a hex string.
+  if (allowHex && printElementsAttrWithHexIfLarger != -1 &&
+      numElements > printElementsAttrWithHexIfLarger) {
+    ArrayRef<char> rawData = attr.getRawData();
+    os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size()))
+       << "\"";
+    return;
+  }
+
   // We use a mixed-radix counter to iterate through the shape. When we bump a
   // non-least-significant digit, we emit a close bracket. When we next emit an
   // element we re-open all closed brackets.

diff  --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 4833dee5a951..5691184fefda 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -664,9 +664,18 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
   return getRaw(type, intValues);
 }
 
-// Constructs a dense elements attribute from an array of raw APInt values.
-// Each APInt value is expected to have the same bitwidth as the element type
-// of 'type'.
+/// Construct a dense elements attribute from a raw buffer representing the
+/// data for this attribute. Users should generally not use this method as
+/// the expected buffer format may not be a form the user expects.
+DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
+                                                      ArrayRef<char> rawBuffer,
+                                                      bool isSplatBuffer) {
+  return getRaw(type, rawBuffer, isSplatBuffer);
+}
+
+/// Constructs a dense elements attribute from an array of raw APInt values.
+/// Each APInt value is expected to have the same bitwidth as the element type
+/// of 'type'.
 DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
                                             ArrayRef<APInt> values) {
   assert(hasSameElementsOrSplat(type, values));
@@ -727,11 +736,6 @@ bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize,
   return ::isValidIntOrFloat(getType(), dataEltSize, isInt);
 }
 
-/// Return the raw storage data held by this attribute.
-ArrayRef<char> DenseElementsAttr::getRawData() const {
-  return static_cast<ImplType *>(impl)->data;
-}
-
 /// Returns if this attribute corresponds to a splat, i.e. if all element
 /// values are the same.
 bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
@@ -795,6 +799,11 @@ auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
   return getFloatValues().end();
 }
 
+/// Return the raw storage data held by this attribute.
+ArrayRef<char> DenseElementsAttr::getRawData() const {
+  return static_cast<ImplType *>(impl)->data;
+}
+
 /// Return a new DenseElementsAttr that has the same data as the current
 /// attribute, but has been reshaped to 'newType'. The new type must have the
 /// same total number of elements as well as element type.

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 2a2219c4202f..9962fbf8c055 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1808,6 +1808,24 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
   return builder.getIntegerAttr(type, isNegative ? -apInt : apInt);
 }
 
+/// Parse elements values stored within a hex string. On success, the values are
+/// stored into 'result'.
+static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
+                                             std::string &result) {
+  std::string val = tok.getStringValue();
+  if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
+    return parser.emitError(tok.getLoc(),
+                            "elements hex string should start with '0x'");
+
+  StringRef hexValues = StringRef(val).drop_front(2);
+  if (!llvm::all_of(hexValues, llvm::isHexDigit))
+    return parser.emitError(tok.getLoc(),
+                            "elements hex string only contains hex digits");
+
+  result = llvm::fromHex(hexValues);
+  return success();
+}
+
 /// Parse an opaque elements attribute.
 Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
   consumeToken(Token::kw_opaque);
@@ -1825,31 +1843,23 @@ Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
   if (!dialect)
     return (emitError("no registered dialect with namespace '" + name + "'"),
             nullptr);
-
   consumeToken(Token::string);
+
   if (parseToken(Token::comma, "expected ','"))
     return nullptr;
 
-  if (getToken().getKind() != Token::string)
-    return (emitError("opaque string should start with '0x'"), nullptr);
-
-  auto val = getToken().getStringValue();
-  if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
-    return (emitError("opaque string should start with '0x'"), nullptr);
-
-  val = val.substr(2);
-  if (!llvm::all_of(val, llvm::isHexDigit))
-    return (emitError("opaque string only contains hex digits"), nullptr);
-
-  consumeToken(Token::string);
-  if (parseToken(Token::greater, "expected '>'"))
+  Token hexTok = getToken();
+  if (parseToken(Token::string, "elements hex string should start with '0x'") ||
+      parseToken(Token::greater, "expected '>'"))
     return nullptr;
-
   auto type = parseElementsLiteralType(attrType);
   if (!type)
     return nullptr;
 
-  return OpaqueElementsAttr::get(dialect, type, llvm::fromHex(val));
+  std::string data;
+  if (parseElementAttrHexValues(*this, hexTok, data))
+    return nullptr;
+  return OpaqueElementsAttr::get(dialect, type, data);
 }
 
 namespace {
@@ -1857,11 +1867,9 @@ class TensorLiteralParser {
 public:
   TensorLiteralParser(Parser &p) : p(p) {}
 
-  ParseResult parse() {
-    if (p.getToken().is(Token::l_square))
-      return parseList(shape);
-    return parseElement();
-  }
+  /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
+  /// may also parse a tensor literal that is stored as a hex string.
+  ParseResult parse(bool allowHex);
 
   /// Build a dense attribute instance with the parsed elements and the given
   /// shaped type.
@@ -1893,6 +1901,9 @@ class TensorLiteralParser {
   DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
                                  FloatType eltTy);
 
+  /// Build a Dense attribute with hex data for the given type.
+  DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);
+
   /// Parse a single element, returning failure if it isn't a valid element
   /// literal. For example:
   /// parseElement(1) -> Success, 1
@@ -1907,6 +1918,9 @@ class TensorLiteralParser {
   ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
   ParseResult parseList(SmallVectorImpl<int64_t> &dims);
 
+  /// Parse a literal that was printed as a hex string.
+  ParseResult parseHexElements();
+
   Parser &p;
 
   /// The shape inferred from the parsed elements.
@@ -1917,13 +1931,35 @@ class TensorLiteralParser {
 
   /// A flag that indicates the type of elements that have been parsed.
   Optional<ElementKind> knownEltKind;
+
+  /// Storage used when parsing elements that were stored as hex values.
+  Optional<Token> hexStorage;
 };
 } // namespace
 
+/// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
+/// may also parse a tensor literal that is stored as a hex string.
+ParseResult TensorLiteralParser::parse(bool allowHex) {
+  // If hex is allowed, check for a string literal.
+  if (allowHex && p.getToken().is(Token::string)) {
+    hexStorage = p.getToken();
+    p.consumeToken(Token::string);
+    return success();
+  }
+  // Otherwise, parse a list or an individual element.
+  if (p.getToken().is(Token::l_square))
+    return parseList(shape);
+  return parseElement();
+}
+
 /// Build a dense attribute instance with the parsed elements and the given
 /// shaped type.
 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
                                                ShapedType type) {
+  // Check to see if we parsed the literal from a hex string.
+  if (hexStorage.hasValue())
+    return getHexAttr(loc, type);
+
   // Check that the parsed storage size has the same number of elements to the
   // type, or is a known splat.
   if (!shape.empty() && getShape() != type.getShape()) {
@@ -2045,6 +2081,33 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
   return DenseElementsAttr::get(type, floatValues);
 }
 
+/// Build a Dense attribute with hex data for the given type.
+DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
+                                                  ShapedType type) {
+  Type elementType = type.getElementType();
+  if (!elementType.isIntOrFloat()) {
+    p.emitError(loc) << "expected floating-point or integer element type, got "
+                     << elementType;
+    return nullptr;
+  }
+
+  std::string data;
+  if (parseElementAttrHexValues(p, hexStorage.getValue(), data))
+    return nullptr;
+
+  // Check that the size of the hex data corresponds to the full size of the
+  // type; a splat-sized buffer is rejected (isSplatBuffer is always false).
+  if (static_cast<int64_t>(data.size() * CHAR_BIT) !=
+      (type.getNumElements() * elementType.getIntOrFloatBitWidth())) {
+    p.emitError(loc) << "elements hex data size is invalid for provided type: "
+                     << type;
+    return nullptr;
+  }
+
+  return DenseElementsAttr::getFromRawBuffer(
+      type, ArrayRef<char>(data.data(), data.size()), /*isSplatBuffer=*/false);
+}
+
 ParseResult TensorLiteralParser::parseElement() {
   switch (p.getToken().getKind()) {
   // Parse a boolean element.
@@ -2125,7 +2188,7 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
 
   // Parse the literal data.
   TensorLiteralParser literalParser(*this);
-  if (literalParser.parse())
+  if (literalParser.parse(/*allowHex=*/true))
     return nullptr;
 
   if (parseToken(Token::greater, "expected '>'"))
@@ -2170,19 +2233,20 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   if (parseToken(Token::less, "Expected '<' after 'sparse'"))
     return nullptr;
 
-  /// Parse indices
+  /// Parse the indices. We don't allow hex values here as we may need to use
+  /// the inferred shape.
   auto indicesLoc = getToken().getLoc();
   TensorLiteralParser indiceParser(*this);
-  if (indiceParser.parse())
+  if (indiceParser.parse(/*allowHex=*/false))
     return nullptr;
 
   if (parseToken(Token::comma, "expected ','"))
     return nullptr;
 
-  /// Parse values.
+  /// Parse the values.
   auto valuesLoc = getToken().getLoc();
   TensorLiteralParser valuesParser(*this);
-  if (valuesParser.parse())
+  if (valuesParser.parse(/*allowHex=*/true))
     return nullptr;
 
   if (parseToken(Token::greater, "expected '>'"))

diff  --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir
new file mode 100644
index 000000000000..b3f9539e0111
--- /dev/null
+++ b/mlir/test/IR/dense-elements-hex.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file -mlir-print-elementsattrs-with-hex-if-larger=1 | FileCheck %s --check-prefix=HEX
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// HEX: dense<"0x00000000000024400000000000001440"> : tensor<2xf64>
+"foo.op"() {dense.attr = dense<[10.0, 5.0]> : tensor<2xf64>} : () -> ()
+
+// CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf64>
+"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xf64>} : () -> ()
+
+// -----
+
+// expected-error@+1 {{elements hex string should start with '0x'}}
+"foo.op"() {dense.attr = dense<"00000000000024400000000000001440"> : tensor<2xf64>} : () -> ()
+
+// -----
+
+// expected-error@+1 {{elements hex string only contains hex digits}}
+"foo.op"() {dense.attr = dense<"0x0000000000002440000000000000144X"> : tensor<2xf64>} : () -> ()
+
+// -----
+
+// expected-error@+1 {{expected floating-point or integer element type, got '!unknown<"">'}}
+"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2x!unknown<"">>} : () -> ()
+
+// -----
+
+// expected-error@+1 {{elements hex data size is invalid for provided type}}
+"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<4xf64>} : () -> ()

diff  --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 058744a78fcf..25756735df06 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -703,14 +703,14 @@ func @elementsattr_malformed_opaque() -> () {
 
 func @elementsattr_malformed_opaque1() -> () {
 ^bb0:
-  "foo"(){bar = opaque<"", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{opaque string only contains hex digits}}
+  "foo"(){bar = opaque<"", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{elements hex string only contains hex digits}}
 }
 
 // -----
 
 func @elementsattr_malformed_opaque2() -> () {
 ^bb0:
-  "foo"(){bar = opaque<"", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{opaque string should start with '0x'}}
+  "foo"(){bar = opaque<"", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{elements hex string should start with '0x'}}
 }
 
 // -----


        


More information about the Mlir-commits mailing list