[Mlir-commits] [mlir] c8177f8 - [mlir][sparse] Factoring out SparseTensorFile::canReadAs predicate
wren romano
llvmlistbot at llvm.org
Thu Sep 29 14:46:56 PDT 2022
Author: wren romano
Date: 2022-09-29T14:46:45-07:00
New Revision: c8177f845b4132f2838d169cee04270051235140
URL: https://github.com/llvm/llvm-project/commit/c8177f845b4132f2838d169cee04270051235140
DIFF: https://github.com/llvm/llvm-project/commit/c8177f845b4132f2838d169cee04270051235140.diff
LOG: [mlir][sparse] Factoring out SparseTensorFile::canReadAs predicate
This is a followup to the refactoring of D133462, D133830, D133831, and D133833.
Depends On D133833
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D133835
Added:
Modified:
mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
mlir/lib/ExecutionEngine/SparseTensor/File.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index 3d82b4908ad5..5dd6c17364a8 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -86,6 +86,10 @@ class SparseTensorFile final {
/// Checks if a header has been successfully read.
bool isValid() const { return valueKind_ != ValueKind::kInvalid; }
+ /// Checks if the file's ValueKind can be converted into the given tensor
+ /// PrimaryType. Must only be called after readHeader() has succeeded.
+ bool canReadAs(PrimaryType valTy) const;
+
/// Gets the MME "pattern" property setting. Is only valid after
/// parsing the header.
bool isPattern() const {
@@ -208,16 +212,10 @@ openSparseTensorCOO(const char *filename, uint64_t rank, const uint64_t *shape,
stfile.openFile();
stfile.readHeader();
// Check tensor element type against the value type in the input file.
- SparseTensorFile::ValueKind valueKind = stfile.getValueKind();
- bool tensorIsInteger =
- (valTp >= PrimaryType::kI64 && valTp <= PrimaryType::kI8);
- bool tensorIsReal = (valTp >= PrimaryType::kF64 && valTp <= PrimaryType::kI8);
- if ((valueKind == SparseTensorFile::ValueKind::kReal && tensorIsInteger) ||
- (valueKind == SparseTensorFile::ValueKind::kComplex && tensorIsReal)) {
+ if (!stfile.canReadAs(valTp))
MLIR_SPARSETENSOR_FATAL(
"Tensor element type %d not compatible with values in file %s\n",
static_cast<int>(valTp), filename);
- }
stfile.assertMatchesShape(rank, shape);
// Prepare sparse tensor object with per-dimension sizes
// and the number of nonzeros as initial capacity.
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/File.cpp b/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
index b105f93cdd8b..5a01b89f2afa 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
@@ -83,6 +83,33 @@ void SparseTensorFile::assertMatchesShape(uint64_t rank,
"Dimension size mismatch");
}
+bool SparseTensorFile::canReadAs(PrimaryType valTy) const {
+ switch (valueKind_) {
+ case ValueKind::kInvalid:
+ assert(false && "Must readHeader() before calling canReadAs()");
+ return false; // In case assertions are disabled.
+ case ValueKind::kPattern:
+ return true;
+ case ValueKind::kInteger:
+ // When the file is specified to store integer values, we still
+ // allow implicitly converting those to floating primary-types.
+ return isRealPrimaryType(valTy);
+ case ValueKind::kReal:
+ // When the file is specified to store real/floating values, then
+ // we disallow implicit conversion to integer primary-types.
+ return isFloatingPrimaryType(valTy);
+ case ValueKind::kComplex:
+ // Complex values stored in the file require a complex primary-type.
+ return isComplexPrimaryType(valTy);
+ case ValueKind::kUndefined:
+ // The "extended" FROSTT format doesn't specify a ValueKind, so we allow
+ // implicitly converting the stored values to integer or floating types.
+ return isRealPrimaryType(valTy);
+ }
+ // Not reached for any enumerated ValueKind; never fall off a non-void fn.
+ MLIR_SPARSETENSOR_FATAL("Unknown ValueKind: %d\n", static_cast<int>(valueKind_));
+}
+
/// Helper to convert C-style strings (i.e., '\0' terminated) to lower case.
static inline char *toLower(char *token) {
for (char *c = token; *c; ++c)
More information about the Mlir-commits
mailing list