[Mlir-commits] [mlir] 774674c - [mlir][sparse] Factored out a "FATAL" macro for unrecoverable assertion failure

wren romano llvmlistbot at llvm.org
Thu May 19 15:26:26 PDT 2022


Author: wren romano
Date: 2022-05-19T15:26:20-07:00
New Revision: 774674ce9abbae6538d404ae1187d7b8a0fd120d

URL: https://github.com/llvm/llvm-project/commit/774674ce9abbae6538d404ae1187d7b8a0fd120d
DIFF: https://github.com/llvm/llvm-project/commit/774674ce9abbae6538d404ae1187d7b8a0fd120d.diff

LOG: [mlir][sparse] Factored out a "FATAL" macro for unrecoverable assertion failure

Depends On D126019

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D126022

Added: 
    

Modified: 
    mlir/lib/ExecutionEngine/SparseTensorUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
index 53d26d5af689..d278e4e18e9d 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorUtils.cpp
@@ -84,6 +84,17 @@ static inline uint64_t checkedMul(uint64_t lhs, uint64_t rhs) {
   return lhs * rhs;
 }
 
+// Prints "SparseTensorUtils: " plus a printf-style message to stderr, then
+// exits with status 1.  The prefix shows the error came from this file rather
+// than elsewhere in MLIR (`fprintf` gives no stacktrace).  The first argument
+// must be a string literal so it can be pasted after the prefix.  The
+// `do { ... } while (0)` wrapper makes `FATAL(...);` a single statement.
+#define FATAL(...)                                                             \
+  do {                                                                         \
+    fprintf(stderr, "SparseTensorUtils: " __VA_ARGS__);                        \
+    exit(1);                                                                   \
+  } while (0)
+
 // TODO: adjust this so it can be used by `openSparseTensorCOO` too.
 // That version doesn't have the permutation, and the `dimSizes` are
 // a pointer/C-array rather than `std::vector`.
@@ -262,6 +273,11 @@ struct SparseTensorCOO final {
 template <typename V>
 class SparseTensorEnumeratorBase;
 
+// Helper macro for the error raised when a `SparseTensorStorage<P,I,V>` cast
+// to `SparseTensorStorageBase` gets the wrong "partial method specialization"
+// called.  `NAME` must be a `%`-free string literal (pasted, not stringized).
+#define FATAL_PIV(NAME) FATAL("<P,I,V> type mismatch for: " NAME "\n")
+
 /// Abstract base class for `SparseTensorStorage<P,I,V>`.  This class
 /// takes responsibility for all the `<P,I,V>`-independent aspects
 /// of the tensor (e.g., shape, sparsity, permutation).  In addition,
@@ -325,37 +341,53 @@ class SparseTensorStorageBase {
 #define DECL_NEWENUMERATOR(VNAME, V)                                           \
   virtual void newEnumerator(SparseTensorEnumeratorBase<V> **, uint64_t,       \
                              const uint64_t *) const {                         \
-    fatal("newEnumerator" #VNAME);                                             \
+    FATAL_PIV("newEnumerator" #VNAME);                                         \
   }
   FOREVERY_V(DECL_NEWENUMERATOR)
 #undef DECL_NEWENUMERATOR
 
   /// Overhead storage.
-  virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
-  virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
-  virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
-  virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
-  virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
-  virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
-  virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
-  virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }
+  virtual void getPointers(std::vector<uint64_t> **, uint64_t) {
+    FATAL_PIV("p64");
+  }
+  virtual void getPointers(std::vector<uint32_t> **, uint64_t) {
+    FATAL_PIV("p32");
+  }
+  virtual void getPointers(std::vector<uint16_t> **, uint64_t) {
+    FATAL_PIV("p16");
+  }
+  virtual void getPointers(std::vector<uint8_t> **, uint64_t) {
+    FATAL_PIV("p8");
+  }
+  virtual void getIndices(std::vector<uint64_t> **, uint64_t) {
+    FATAL_PIV("i64");
+  }
+  virtual void getIndices(std::vector<uint32_t> **, uint64_t) {
+    FATAL_PIV("i32");
+  }
+  virtual void getIndices(std::vector<uint16_t> **, uint64_t) {
+    FATAL_PIV("i16");
+  }
+  virtual void getIndices(std::vector<uint8_t> **, uint64_t) {
+    FATAL_PIV("i8");
+  }
 
   /// Primary storage.
 #define DECL_GETVALUES(VNAME, V)                                               \
-  virtual void getValues(std::vector<V> **) { fatal("getValues" #VNAME); }
+  virtual void getValues(std::vector<V> **) { FATAL_PIV("getValues" #VNAME); }
   FOREVERY_V(DECL_GETVALUES)
 #undef DECL_GETVALUES
 
   /// Element-wise insertion in lexicographic index order.
 #define DECL_LEXINSERT(VNAME, V)                                               \
-  virtual void lexInsert(const uint64_t *, V) { fatal("lexInsert" #VNAME); }
+  virtual void lexInsert(const uint64_t *, V) { FATAL_PIV("lexInsert" #VNAME); }
   FOREVERY_V(DECL_LEXINSERT)
 #undef DECL_LEXINSERT
 
   /// Expanded insertion.
 #define DECL_EXPINSERT(VNAME, V)                                               \
   virtual void expInsert(uint64_t *, V *, bool *, uint64_t *, uint64_t) {      \
-    fatal("expInsert" #VNAME);                                                 \
+    FATAL_PIV("expInsert" #VNAME);                                             \
   }
   FOREVERY_V(DECL_EXPINSERT)
 #undef DECL_EXPINSERT
@@ -374,16 +406,13 @@ class SparseTensorStorageBase {
   SparseTensorStorageBase &operator=(const SparseTensorStorageBase &) = delete;
 
 private:
-  static void fatal(const char *tp) {
-    fprintf(stderr, "unsupported %s\n", tp);
-    exit(1);
-  }
-
   const std::vector<uint64_t> dimSizes;
   std::vector<uint64_t> rev;
   const std::vector<DimLevelType> dimTypes;
 };
 
+#undef FATAL_PIV
+
 // Forward.
 template <typename P, typename I, typename V>
 class SparseTensorEnumerator;
@@ -1122,10 +1151,8 @@ static void readMMEHeader(FILE *file, char *filename, char *line,
   char symmetry[64];
   // Read header line.
   if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
-             symmetry) != 5) {
-    fprintf(stderr, "Corrupt header in %s\n", filename);
-    exit(1);
-  }
+             symmetry) != 5)
+    FATAL("Corrupt header in %s\n", filename);
   // Set properties
   *isPattern = (strcmp(toLower(field), "pattern") == 0);
   *isSymmetric = (strcmp(toLower(symmetry), "symmetric") == 0);
@@ -1134,26 +1161,20 @@ static void readMMEHeader(FILE *file, char *filename, char *line,
       strcmp(toLower(object), "matrix") ||
       strcmp(toLower(format), "coordinate") ||
       (strcmp(toLower(field), "real") && !(*isPattern)) ||
-      (strcmp(toLower(symmetry), "general") && !(*isSymmetric))) {
-    fprintf(stderr, "Cannot find a general sparse matrix in %s\n", filename);
-    exit(1);
-  }
+      (strcmp(toLower(symmetry), "general") && !(*isSymmetric)))
+    FATAL("Cannot find a general sparse matrix in %s\n", filename);
   // Skip comments.
   while (true) {
-    if (!fgets(line, kColWidth, file)) {
-      fprintf(stderr, "Cannot find data in %s\n", filename);
-      exit(1);
-    }
+    if (!fgets(line, kColWidth, file))
+      FATAL("Cannot find data in %s\n", filename);
     if (line[0] != '%')
       break;
   }
   // Next line contains M N NNZ.
   idata[0] = 2; // rank
   if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
-             idata + 1) != 3) {
-    fprintf(stderr, "Cannot find size in %s\n", filename);
-    exit(1);
-  }
+             idata + 1) != 3)
+    FATAL("Cannot find size in %s\n", filename);
 }
 
 /// Read the "extended" FROSTT header. Although not part of the documented
@@ -1164,25 +1185,18 @@ static void readExtFROSTTHeader(FILE *file, char *filename, char *line,
                                 uint64_t *idata) {
   // Skip comments.
   while (true) {
-    if (!fgets(line, kColWidth, file)) {
-      fprintf(stderr, "Cannot find data in %s\n", filename);
-      exit(1);
-    }
+    if (!fgets(line, kColWidth, file))
+      FATAL("Cannot find data in %s\n", filename);
     if (line[0] != '#')
       break;
   }
   // Next line contains RANK and NNZ.
-  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
-    fprintf(stderr, "Cannot find metadata in %s\n", filename);
-    exit(1);
-  }
+  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2)
+    FATAL("Cannot find metadata in %s\n", filename);
   // Followed by a line with the dimension sizes (one per rank).
-  for (uint64_t r = 0; r < idata[0]; r++) {
-    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
-      fprintf(stderr, "Cannot find dimension size %s\n", filename);
-      exit(1);
-    }
-  }
+  for (uint64_t r = 0; r < idata[0]; r++)
+    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1)
+      FATAL("Cannot find dimension size %s\n", filename);
   fgets(line, kColWidth, file); // end of line
 }
 
@@ -1193,12 +1207,10 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
                                                const uint64_t *shape,
                                                const uint64_t *perm) {
   // Open the file.
+  assert(filename && "Received nullptr for filename");
   FILE *file = fopen(filename, "r");
-  if (!file) {
-    assert(filename && "Received nullptr for filename");
-    fprintf(stderr, "Cannot find file %s\n", filename);
-    exit(1);
-  }
+  if (!file)
+    FATAL("Cannot find file %s\n", filename);
   // Perform some file format dependent set up.
   char line[kColWidth];
   uint64_t idata[512];
@@ -1209,8 +1221,7 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
   } else if (strstr(filename, ".tns")) {
     readExtFROSTTHeader(file, filename, line, idata);
   } else {
-    fprintf(stderr, "Unknown format %s\n", filename);
-    exit(1);
+    FATAL("Unknown format %s\n", filename);
   }
   // Prepare sparse tensor object with per-dimension sizes
   // and the number of nonzeros as initial capacity.
@@ -1224,10 +1235,8 @@ static SparseTensorCOO<V> *openSparseTensorCOO(char *filename, uint64_t rank,
   // Read all nonzero elements.
   std::vector<uint64_t> indices(rank);
   for (uint64_t k = 0; k < nnz; k++) {
-    if (!fgets(line, kColWidth, file)) {
-      fprintf(stderr, "Cannot find next line of data in %s\n", filename);
-      exit(1);
-    }
+    if (!fgets(line, kColWidth, file))
+      FATAL("Cannot find next line of data in %s\n", filename);
     char *linePtr = line;
     for (uint64_t r = 0; r < rank; r++) {
       uint64_t idx = strtoul(linePtr, &linePtr, 10);
@@ -1290,22 +1299,15 @@ toMLIRSparseTensor(uint64_t rank, uint64_t nse, uint64_t *shape, V *values,
   // Verify that perm is a permutation of 0..(rank-1).
   std::vector<uint64_t> order(perm, perm + rank);
   std::sort(order.begin(), order.end());
-  for (uint64_t i = 0; i < rank; ++i) {
-    if (i != order[i]) {
-      fprintf(stderr, "Not a permutation of 0..%" PRIu64 "\n", rank);
-      exit(1);
-    }
-  }
+  for (uint64_t i = 0; i < rank; ++i)
+    if (i != order[i])
+      FATAL("Not a permutation of 0..%" PRIu64 "\n", rank);
 
   // Verify that the sparsity values are supported.
-  for (uint64_t i = 0; i < rank; ++i) {
+  for (uint64_t i = 0; i < rank; ++i)
     if (sparsity[i] != DimLevelType::kDense &&
-        sparsity[i] != DimLevelType::kCompressed) {
-      fprintf(stderr, "Unsupported sparsity value %d\n",
-              static_cast<int>(sparsity[i]));
-      exit(1);
-    }
-  }
+        sparsity[i] != DimLevelType::kCompressed)
+      FATAL("Unsupported sparsity value %d\n", static_cast<int>(sparsity[i]));
 #endif
 
   // Convert external format to internal COO.
@@ -1539,8 +1541,10 @@ _mlir_ciface_newSparseTensor(StridedMemRefType<DimLevelType, 1> *aref, // NOLINT
   CASE_SECSAME(OverheadType::kU64, PrimaryType::kC32, uint64_t, complex32);
 
   // Unsupported case (add above if needed).
-  fputs("unsupported combination of types\n", stderr);
-  exit(1);
+  // TODO: better pretty-printing of enum values!
+  FATAL("unsupported combination of types: <P=%d, I=%d, V=%d>\n",
+        static_cast<int>(ptrTp), static_cast<int>(indTp),
+        static_cast<int>(valTp));
 }
 #undef CASE
 #undef CASE_SECSAME
@@ -1704,10 +1708,8 @@ char *getTensorFilename(index_type id) {
   char var[80];
   sprintf(var, "TENSOR%" PRIu64, id);
   char *env = getenv(var);
-  if (!env) {
-    fprintf(stderr, "Environment variable %s is not set\n", var);
-    exit(1);
-  }
+  if (!env)
+    FATAL("Environment variable %s is not set\n", var);
   return env;
 }
 


        


More information about the Mlir-commits mailing list