[Mlir-commits] [mlir] c8177f8 - [mlir][sparse] Factoring out SparseTensorFile::canReadAs predicate

wren romano llvmlistbot at llvm.org
Thu Sep 29 14:46:56 PDT 2022


Author: wren romano
Date: 2022-09-29T14:46:45-07:00
New Revision: c8177f845b4132f2838d169cee04270051235140

URL: https://github.com/llvm/llvm-project/commit/c8177f845b4132f2838d169cee04270051235140
DIFF: https://github.com/llvm/llvm-project/commit/c8177f845b4132f2838d169cee04270051235140.diff

LOG: [mlir][sparse] Factoring out SparseTensorFile::canReadAs predicate

This is a follow-up to the refactoring of D133462, D133830, D133831, and D133833.

Depends On D133833

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D133835

Added: 
    

Modified: 
    mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
    mlir/lib/ExecutionEngine/SparseTensor/File.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
index 3d82b4908ad5..5dd6c17364a8 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/File.h
@@ -86,6 +86,10 @@ class SparseTensorFile final {
   /// Checks if a header has been successfully read.
   bool isValid() const { return valueKind_ != ValueKind::kInvalid; }
 
+  /// Checks if the file's ValueKind can be converted into the given
+  /// tensor PrimaryType.  Is only valid after parsing the header.
+  bool canReadAs(PrimaryType valTy) const;
+
   /// Gets the MME "pattern" property setting.  Is only valid after
   /// parsing the header.
   bool isPattern() const {
@@ -208,16 +212,10 @@ openSparseTensorCOO(const char *filename, uint64_t rank, const uint64_t *shape,
   stfile.openFile();
   stfile.readHeader();
   // Check tensor element type against the value type in the input file.
-  SparseTensorFile::ValueKind valueKind = stfile.getValueKind();
-  bool tensorIsInteger =
-      (valTp >= PrimaryType::kI64 && valTp <= PrimaryType::kI8);
-  bool tensorIsReal = (valTp >= PrimaryType::kF64 && valTp <= PrimaryType::kI8);
-  if ((valueKind == SparseTensorFile::ValueKind::kReal && tensorIsInteger) ||
-      (valueKind == SparseTensorFile::ValueKind::kComplex && tensorIsReal)) {
+  if (!stfile.canReadAs(valTp))
     MLIR_SPARSETENSOR_FATAL(
         "Tensor element type %d not compatible with values in file %s\n",
         static_cast<int>(valTp), filename);
-  }
   stfile.assertMatchesShape(rank, shape);
   // Prepare sparse tensor object with per-dimension sizes
   // and the number of nonzeros as initial capacity.

diff  --git a/mlir/lib/ExecutionEngine/SparseTensor/File.cpp b/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
index b105f93cdd8b..5a01b89f2afa 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/File.cpp
@@ -83,6 +83,41 @@ void SparseTensorFile::assertMatchesShape(uint64_t rank,
            "Dimension size mismatch");
 }
 
+/// Decides whether the values stored in the file (as described by the
+/// header's ValueKind) may be read into a tensor whose element type is
+/// `valTy`.  Must only be called after `readHeader()`.
+bool SparseTensorFile::canReadAs(PrimaryType valTy) const {
+  switch (valueKind_) {
+  case ValueKind::kInvalid:
+    assert(false && "Must readHeader() before calling canReadAs()");
+    return false; // In case assertions are disabled.
+  case ValueKind::kPattern:
+    return true;
+  case ValueKind::kInteger:
+    // When the file is specified to store integer values, we still
+    // allow implicitly converting those to floating primary-types.
+    return isRealPrimaryType(valTy);
+  case ValueKind::kReal:
+    // When the file is specified to store real/floating values, then
+    // we disallow implicit conversion to integer primary-types.
+    return isFloatingPrimaryType(valTy);
+  case ValueKind::kComplex:
+    // When the file is specified to store complex values, then we
+    // require a complex primary-type.
+    return isComplexPrimaryType(valTy);
+  case ValueKind::kUndefined:
+    // The "extended" FROSTT format doesn't specify a ValueKind.
+    // So we allow implicitly converting the stored values to both
+    // integer and floating primary-types.
+    return isRealPrimaryType(valTy);
+  }
+  // Every named enumerator returns above; we only get here if `valueKind_`
+  // holds an out-of-range value.  Report it rather than falling off the end
+  // of a non-void function (UB; GCC also flags it with -Wreturn-type).
+  MLIR_SPARSETENSOR_FATAL("Unsupported ValueKind: %d\n",
+                          static_cast<int>(valueKind_));
+}
+
 /// Helper to convert C-style strings (i.e., '\0' terminated) to lower case.
 static inline char *toLower(char *token) {
   for (char *c = token; *c; ++c)


        


More information about the Mlir-commits mailing list