[Mlir-commits] [mlir] fb674e3 - [mlir] Add support for sparse DenseStringElements.
River Riddle
llvmlistbot at llvm.org
Sat Apr 25 01:23:09 PDT 2020
Author: Rob Suderman
Date: 2020-04-25T01:21:40-07:00
New Revision: fb674e3329d8fa694d8e5ce179081890fb918556
URL: https://github.com/llvm/llvm-project/commit/fb674e3329d8fa694d8e5ce179081890fb918556
DIFF: https://github.com/llvm/llvm-project/commit/fb674e3329d8fa694d8e5ce179081890fb918556.diff
LOG: [mlir] Add support for sparse DenseStringElements.
Summary: Added support for sparse string elements. This is a follow-up to the original DenseStringElements change.
Differential Revision: https://reviews.llvm.org/D78844
Added:
Modified:
mlir/include/mlir/IR/Attributes.h
mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/parser.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 656c28ba4e8a..b4f1d025be85 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -1274,6 +1274,14 @@ class SparseElementsAttr
getZeroValue() const {
return getZeroAPFloat();
}
+
+ /// Get a zero for a StringRef.
+ template <typename T>
+ typename std::enable_if<std::is_same<StringRef, T>::value, T>::type
+ getZeroValue() const {
+ return StringRef();
+ }
+
/// Get a zero for an C++ integer or float type.
template <typename T>
typename std::enable_if<std::numeric_limits<T>::is_integer ||
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index bdaf15c6e5c5..f17d8fde6a82 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -976,6 +976,11 @@ class ModulePrinter {
/// Print a dense string elements attribute.
void printDenseStringElementsAttr(DenseStringElementsAttr attr);
+ /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
+ /// used instead of individual elements when the elements attr is large.
+ void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
+ bool allowHex);
+
void printDialectAttribute(Attribute attr);
void printDialectType(Type type);
@@ -1396,13 +1401,13 @@ void ModulePrinter::printAttribute(Attribute attr,
break;
}
case StandardAttributes::DenseIntOrFPElements: {
- auto eltsAttr = attr.cast<DenseElementsAttr>();
+ auto eltsAttr = attr.cast<DenseIntOrFPElementsAttr>();
if (printerFlags.shouldElideElementsAttr(eltsAttr)) {
printElidedElementsAttr(os);
break;
}
os << "dense<";
- printDenseElementsAttr(eltsAttr, /*allowHex=*/true);
+ printDenseIntOrFPElementsAttr(eltsAttr, /*allowHex=*/true);
os << '>';
break;
}
@@ -1425,7 +1430,8 @@ void ModulePrinter::printAttribute(Attribute attr,
break;
}
os << "sparse<";
- printDenseElementsAttr(elementsAttr.getIndices(), /*allowHex=*/false);
+ printDenseIntOrFPElementsAttr(elementsAttr.getIndices(),
+ /*allowHex=*/false);
os << ", ";
printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
os << '>';
@@ -1477,6 +1483,17 @@ static void printDenseStringElement(DenseStringElementsAttr attr,
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
bool allowHex) {
+ if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
+ printDenseStringElementsAttr(stringAttr);
+ return;
+ }
+
+ printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
+ allowHex);
+}
+
+void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
+ bool allowHex) {
auto type = attr.getType();
auto shape = type.getShape();
auto rank = type.getRank();
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index 7a02d776c172..2170927df927 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -764,6 +764,11 @@ func @sparsetensorattr() -> () {
"foof320"(){bar = sparse<[], []> : tensor<0xf32>} : () -> ()
// CHECK: "foof321"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<f32>} : () -> ()
"foof321"(){bar = sparse<[], []> : tensor<f32>} : () -> ()
+
+// CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
+ "foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
+// CHECK: "foostr"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}"a", "b", "c"]> : tensor<2x2x2x!unknown<"">>} : () -> ()
+ "foostr"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], ["a", "b", "c"]> : tensor<2x2x2x!unknown<"">>} : () -> ()
return
}
More information about the Mlir-commits mailing list