[Mlir-commits] [mlir] 51bf5d3 - [mlir][Parser] Update DenseElementsAttr to print in hex when the number of elements is over a certain threshold.
River Riddle
llvmlistbot at llvm.org
Thu Feb 20 14:45:45 PST 2020
Author: River Riddle
Date: 2020-02-20T14:40:58-08:00
New Revision: 51bf5d3cc19ac113de2ff185fb5bc2b99b8d89bc
URL: https://github.com/llvm/llvm-project/commit/51bf5d3cc19ac113de2ff185fb5bc2b99b8d89bc
DIFF: https://github.com/llvm/llvm-project/commit/51bf5d3cc19ac113de2ff185fb5bc2b99b8d89bc.diff
LOG: [mlir][Parser] Update DenseElementsAttr to print in hex when the number of elements is over a certain threshold.
Summary: DenseElementsAttr is used to store tensor data, which in some cases can become extremely large (hundreds of MB). In these cases it is much more efficient to format the data as a string of hex values instead.
Differential Revision: https://reviews.llvm.org/D74922
Added:
mlir/test/IR/dense-elements-hex.mlir
Modified:
mlir/include/mlir/IR/Attributes.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Attributes.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/invalid.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 8a3b6cfa98b5..d5c063f326ce 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -723,6 +723,13 @@ class DenseElementsAttr
return get(type, ArrayRef<T>(list));
}
+ /// Construct a dense elements attribute from a raw buffer representing the
+ /// data for this attribute. Users should generally not use this method, as
+ /// the expected buffer format may not be in the form the user expects.
+ static DenseElementsAttr getFromRawBuffer(ShapedType type,
+ ArrayRef<char> rawBuffer,
+ bool isSplatBuffer);
+
//===--------------------------------------------------------------------===//
// Iterators
//===--------------------------------------------------------------------===//
@@ -918,6 +925,11 @@ class DenseElementsAttr
FloatElementIterator float_value_begin() const;
FloatElementIterator float_value_end() const;
+ /// Return the raw storage data held by this attribute. Users should generally
+ /// not use this directly, as the internal storage format is not always in the
+ /// form the user might expect.
+ ArrayRef<char> getRawData() const;
+
//===--------------------------------------------------------------------===//
// Mutation Utilities
//===--------------------------------------------------------------------===//
@@ -941,9 +953,6 @@ class DenseElementsAttr
function_ref<APInt(const APFloat &)> mapping) const;
protected:
- /// Return the raw storage data held by this attribute.
- ArrayRef<char> getRawData() const;
-
/// Get iterators to the raw APInt values for each element in this attribute.
IntElementIterator raw_int_begin() const {
return IntElementIterator(*this, 0);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f3d884f75781..05e3645151ea 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -63,6 +63,13 @@ OpAsmPrinter::~OpAsmPrinter() {}
// OpPrintingFlags
//===----------------------------------------------------------------------===//
+static llvm::cl::opt<int> printElementsAttrWithHexIfLarger(
+ "mlir-print-elementsattrs-with-hex-if-larger",
+ llvm::cl::desc(
+ "Print DenseElementsAttrs with a hex string that have "
+ "more elements than the given upper limit (use -1 to disable)"),
+ llvm::cl::init(100));
+
static llvm::cl::opt<unsigned> elideElementsAttrIfLarger(
"mlir-elide-elementsattrs-if-larger",
llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
@@ -887,7 +894,10 @@ class ModulePrinter {
bool withKeyword = false);
void printTrailingLocation(Location loc);
void printLocationInternal(LocationAttr loc, bool pretty = false);
- void printDenseElementsAttr(DenseElementsAttr attr);
+
+ /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
+ /// used instead of individual elements when the elements attr is large.
+ void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
void printDialectAttribute(Attribute attr);
void printDialectType(Type type);
@@ -1321,7 +1331,7 @@ void ModulePrinter::printAttribute(Attribute attr,
break;
}
os << "dense<";
- printDenseElementsAttr(eltsAttr);
+ printDenseElementsAttr(eltsAttr, /*allowHex=*/true);
os << '>';
break;
}
@@ -1333,9 +1343,9 @@ void ModulePrinter::printAttribute(Attribute attr,
break;
}
os << "sparse<";
- printDenseElementsAttr(elementsAttr.getIndices());
+ printDenseElementsAttr(elementsAttr.getIndices(), /*allowHex=*/false);
os << ", ";
- printDenseElementsAttr(elementsAttr.getValues());
+ printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
os << '>';
break;
}
@@ -1375,7 +1385,8 @@ static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os,
printFloatValue(value, os);
}
-void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
+void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
+ bool allowHex) {
auto type = attr.getType();
auto shape = type.getShape();
auto rank = type.getRank();
@@ -1401,6 +1412,15 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
return;
}
+ // Check to see if we should format this attribute as a hex string.
+ if (allowHex && printElementsAttrWithHexIfLarger != -1 &&
+ numElements > printElementsAttrWithHexIfLarger) {
+ ArrayRef<char> rawData = attr.getRawData();
+ os << '"' << "0x" << llvm::toHex(StringRef(rawData.data(), rawData.size()))
+ << "\"";
+ return;
+ }
+
// We use a mixed-radix counter to iterate through the shape. When we bump a
// non-least-significant digit, we emit a close bracket. When we next emit an
// element we re-open all closed brackets.
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
index 4833dee5a951..5691184fefda 100644
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -664,9 +664,18 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
return getRaw(type, intValues);
}
-// Constructs a dense elements attribute from an array of raw APInt values.
-// Each APInt value is expected to have the same bitwidth as the element type
-// of 'type'.
+/// Construct a dense elements attribute from a raw buffer representing the
+/// data for this attribute. Users should generally not use this method, as
+/// the expected buffer format may not be in the form the user expects.
+DenseElementsAttr DenseElementsAttr::getFromRawBuffer(ShapedType type,
+ ArrayRef<char> rawBuffer,
+ bool isSplatBuffer) {
+ return getRaw(type, rawBuffer, isSplatBuffer);
+}
+
+/// Constructs a dense elements attribute from an array of raw APInt values.
+/// Each APInt value is expected to have the same bitwidth as the element type
+/// of 'type'.
DenseElementsAttr DenseElementsAttr::getRaw(ShapedType type,
ArrayRef<APInt> values) {
assert(hasSameElementsOrSplat(type, values));
@@ -727,11 +736,6 @@ bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize,
return ::isValidIntOrFloat(getType(), dataEltSize, isInt);
}
-/// Return the raw storage data held by this attribute.
-ArrayRef<char> DenseElementsAttr::getRawData() const {
- return static_cast<ImplType *>(impl)->data;
-}
-
/// Returns if this attribute corresponds to a splat, i.e. if all element
/// values are the same.
bool DenseElementsAttr::isSplat() const { return getImpl()->isSplat; }
@@ -795,6 +799,11 @@ auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
return getFloatValues().end();
}
+/// Return the raw storage data held by this attribute.
+ArrayRef<char> DenseElementsAttr::getRawData() const {
+ return static_cast<ImplType *>(impl)->data;
+}
+
/// Return a new DenseElementsAttr that has the same data as the current
/// attribute, but has been reshaped to 'newType'. The new type must have the
/// same total number of elements as well as element type.
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 2a2219c4202f..9962fbf8c055 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1808,6 +1808,24 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
return builder.getIntegerAttr(type, isNegative ? -apInt : apInt);
}
+/// Parse element values stored within a hex string. On success, the values are
+/// stored into 'result'.
+static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
+ std::string &result) {
+ std::string val = tok.getStringValue();
+ if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
+ return parser.emitError(tok.getLoc(),
+ "elements hex string should start with '0x'");
+
+ StringRef hexValues = StringRef(val).drop_front(2);
+ if (!llvm::all_of(hexValues, llvm::isHexDigit))
+ return parser.emitError(tok.getLoc(),
+ "elements hex string only contains hex digits");
+
+ result = llvm::fromHex(hexValues);
+ return success();
+}
+
/// Parse an opaque elements attribute.
Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
consumeToken(Token::kw_opaque);
@@ -1825,31 +1843,23 @@ Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
if (!dialect)
return (emitError("no registered dialect with namespace '" + name + "'"),
nullptr);
-
consumeToken(Token::string);
+
if (parseToken(Token::comma, "expected ','"))
return nullptr;
- if (getToken().getKind() != Token::string)
- return (emitError("opaque string should start with '0x'"), nullptr);
-
- auto val = getToken().getStringValue();
- if (val.size() < 2 || val[0] != '0' || val[1] != 'x')
- return (emitError("opaque string should start with '0x'"), nullptr);
-
- val = val.substr(2);
- if (!llvm::all_of(val, llvm::isHexDigit))
- return (emitError("opaque string only contains hex digits"), nullptr);
-
- consumeToken(Token::string);
- if (parseToken(Token::greater, "expected '>'"))
+ Token hexTok = getToken();
+ if (parseToken(Token::string, "elements hex string should start with '0x'") ||
+ parseToken(Token::greater, "expected '>'"))
return nullptr;
-
auto type = parseElementsLiteralType(attrType);
if (!type)
return nullptr;
- return OpaqueElementsAttr::get(dialect, type, llvm::fromHex(val));
+ std::string data;
+ if (parseElementAttrHexValues(*this, hexTok, data))
+ return nullptr;
+ return OpaqueElementsAttr::get(dialect, type, data);
}
namespace {
@@ -1857,11 +1867,9 @@ class TensorLiteralParser {
public:
TensorLiteralParser(Parser &p) : p(p) {}
- ParseResult parse() {
- if (p.getToken().is(Token::l_square))
- return parseList(shape);
- return parseElement();
- }
+ /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
+ /// may also parse a tensor literal that is stored as a hex string.
+ ParseResult parse(bool allowHex);
/// Build a dense attribute instance with the parsed elements and the given
/// shaped type.
@@ -1893,6 +1901,9 @@ class TensorLiteralParser {
DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
FloatType eltTy);
+ /// Build a Dense attribute with hex data for the given type.
+ DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);
+
/// Parse a single element, returning failure if it isn't a valid element
/// literal. For example:
/// parseElement(1) -> Success, 1
@@ -1907,6 +1918,9 @@ class TensorLiteralParser {
/// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
ParseResult parseList(SmallVectorImpl<int64_t> &dims);
+ /// Parse a literal that was printed as a hex string.
+ ParseResult parseHexElements();
+
Parser &p;
/// The shape inferred from the parsed elements.
@@ -1917,13 +1931,35 @@ class TensorLiteralParser {
/// A flag that indicates the type of elements that have been parsed.
Optional<ElementKind> knownEltKind;
+
+ /// Storage used when parsing elements that were stored as hex values.
+ Optional<Token> hexStorage;
};
} // namespace
+/// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
+/// may also parse a tensor literal that is stored as a hex string.
+ParseResult TensorLiteralParser::parse(bool allowHex) {
+ // If hex is allowed, check for a string literal.
+ if (allowHex && p.getToken().is(Token::string)) {
+ hexStorage = p.getToken();
+ p.consumeToken(Token::string);
+ return success();
+ }
+ // Otherwise, parse a list or an individual element.
+ if (p.getToken().is(Token::l_square))
+ return parseList(shape);
+ return parseElement();
+}
+
/// Build a dense attribute instance with the parsed elements and the given
/// shaped type.
DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
ShapedType type) {
+ // Check to see if we parsed the literal from a hex string.
+ if (hexStorage.hasValue())
+ return getHexAttr(loc, type);
+
// Check that the parsed storage size has the same number of elements to the
// type, or is a known splat.
if (!shape.empty() && getShape() != type.getShape()) {
@@ -2045,6 +2081,33 @@ DenseElementsAttr TensorLiteralParser::getFloatAttr(llvm::SMLoc loc,
return DenseElementsAttr::get(type, floatValues);
}
+/// Build a Dense attribute with hex data for the given type.
+DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
+ ShapedType type) {
+ Type elementType = type.getElementType();
+ if (!elementType.isIntOrFloat()) {
+ p.emitError(loc) << "expected floating-point or integer element type, got "
+ << elementType;
+ return nullptr;
+ }
+
+ std::string data;
+ if (parseElementAttrHexValues(p, hexStorage.getValue(), data))
+ return nullptr;
+
+ // Check that the size of the hex data corresponds to the size of the type.
+ // Splat buffers are not accepted here: the data must cover every element.
+ if (static_cast<int64_t>(data.size() * CHAR_BIT) !=
+ (type.getNumElements() * elementType.getIntOrFloatBitWidth())) {
+ p.emitError(loc) << "elements hex data size is invalid for provided type: "
+ << type;
+ return nullptr;
+ }
+
+ return DenseElementsAttr::getFromRawBuffer(
+ type, ArrayRef<char>(data.data(), data.size()), /*isSplatBuffer=*/false);
+}
+
ParseResult TensorLiteralParser::parseElement() {
switch (p.getToken().getKind()) {
// Parse a boolean element.
@@ -2125,7 +2188,7 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
// Parse the literal data.
TensorLiteralParser literalParser(*this);
- if (literalParser.parse())
+ if (literalParser.parse(/*allowHex=*/true))
return nullptr;
if (parseToken(Token::greater, "expected '>'"))
@@ -2170,19 +2233,20 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
return nullptr;
- /// Parse indices
+ /// Parse the indices. We don't allow hex values here as we may need to use
+ /// the inferred shape.
auto indicesLoc = getToken().getLoc();
TensorLiteralParser indiceParser(*this);
- if (indiceParser.parse())
+ if (indiceParser.parse(/*allowHex=*/false))
return nullptr;
if (parseToken(Token::comma, "expected ','"))
return nullptr;
- /// Parse values.
+ /// Parse the values.
auto valuesLoc = getToken().getLoc();
TensorLiteralParser valuesParser(*this);
- if (valuesParser.parse())
+ if (valuesParser.parse(/*allowHex=*/true))
return nullptr;
if (parseToken(Token::greater, "expected '>'"))
diff --git a/mlir/test/IR/dense-elements-hex.mlir b/mlir/test/IR/dense-elements-hex.mlir
new file mode 100644
index 000000000000..b3f9539e0111
--- /dev/null
+++ b/mlir/test/IR/dense-elements-hex.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file -mlir-print-elementsattrs-with-hex-if-larger=1 | FileCheck %s --check-prefix=HEX
+// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// HEX: dense<"0x00000000000024400000000000001440"> : tensor<2xf64>
+"foo.op"() {dense.attr = dense<[10.0, 5.0]> : tensor<2xf64>} : () -> ()
+
+// CHECK: dense<[1.000000e+01, 5.000000e+00]> : tensor<2xf64>
+"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2xf64>} : () -> ()
+
+// -----
+
+// expected-error@+1 {{elements hex string should start with '0x'}}
+"foo.op"() {dense.attr = dense<"00000000000024400000000000001440"> : tensor<2xf64>} : () -> ()
+
+// -----
+
+// expected-error@+1 {{elements hex string only contains hex digits}}
+"foo.op"() {dense.attr = dense<"0x0000000000002440000000000000144X"> : tensor<2xf64>} : () -> ()
+
+// -----
+
+// expected-error@+1 {{expected floating-point or integer element type, got '!unknown<"">'}}
+"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<2x!unknown<"">>} : () -> ()
+
+// -----
+
+// expected-error@+1 {{elements hex data size is invalid for provided type}}
+"foo.op"() {dense.attr = dense<"0x00000000000024400000000000001440"> : tensor<4xf64>} : () -> ()
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 058744a78fcf..25756735df06 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -703,14 +703,14 @@ func @elementsattr_malformed_opaque() -> () {
func @elementsattr_malformed_opaque1() -> () {
^bb0:
- "foo"(){bar = opaque<"", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{opaque string only contains hex digits}}
+ "foo"(){bar = opaque<"", "0xQZz123"> : tensor<1xi8>} : () -> () // expected-error {{elements hex string only contains hex digits}}
}
// -----
func @elementsattr_malformed_opaque2() -> () {
^bb0:
- "foo"(){bar = opaque<"", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{opaque string should start with '0x'}}
+ "foo"(){bar = opaque<"", "00abc"> : tensor<1xi8>} : () -> () // expected-error {{elements hex string should start with '0x'}}
}
// -----
More information about the Mlir-commits
mailing list