[Mlir-commits] [mlir] 1d4d1c9 - [mlir][sparse] Factoring out predicates on DimLevelTypes

wren romano llvmlistbot at llvm.org
Fri Sep 30 11:15:46 PDT 2022


Author: wren romano
Date: 2022-09-30T11:15:34-07:00
New Revision: 1d4d1c99c550bc8d8df237c8edacda41a29acbe6

URL: https://github.com/llvm/llvm-project/commit/1d4d1c99c550bc8d8df237c8edacda41a29acbe6
DIFF: https://github.com/llvm/llvm-project/commit/1d4d1c99c550bc8d8df237c8edacda41a29acbe6.diff

LOG: [mlir][sparse] Factoring out predicates on DimLevelTypes

Moving these predicates out of `SparseTensorStorageBase` and into `Enums.h` lets them be reused elsewhere, and makes them easier to keep in sync with changes to the `DimLevelType` enum.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D134926

Added: 
    

Modified: 
    mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h
    mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h
index 6da5d472da0ac..bd665ab313a4d 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Enums.h
@@ -157,6 +157,63 @@ enum class MLIR_SPARSETENSOR_EXPORT DimLevelType : uint8_t {
   kSingletonNuNo = 8,
 };
 
+/// Check if the `DimLevelType` is dense (i.e., exactly `kDense`).
+constexpr MLIR_SPARSETENSOR_EXPORT bool isDenseDLT(DimLevelType dlt) {
+  return dlt == DimLevelType::kDense;
+}
+
+/// Check if the `DimLevelType` is any compressed variant (regardless of Nu/No).
+constexpr MLIR_SPARSETENSOR_EXPORT bool isCompressedDLT(DimLevelType dlt) {
+  switch (dlt) {
+  case DimLevelType::kCompressed:
+  case DimLevelType::kCompressedNu:
+  case DimLevelType::kCompressedNo:
+  case DimLevelType::kCompressedNuNo:
+    return true;
+  default: // Dense, singleton, and any other value are not compressed.
+    return false;
+  }
+}
+
+/// Check if the `DimLevelType` is any singleton variant (regardless of Nu/No).
+constexpr MLIR_SPARSETENSOR_EXPORT bool isSingletonDLT(DimLevelType dlt) {
+  switch (dlt) {
+  case DimLevelType::kSingleton:
+  case DimLevelType::kSingletonNu:
+  case DimLevelType::kSingletonNo:
+  case DimLevelType::kSingletonNuNo:
+    return true;
+  default: // Dense, compressed, and any other value are not singleton.
+    return false;
+  }
+}
+
+/// Check if the `DimLevelType` is ordered (any format); only `No` variants are not.
+constexpr MLIR_SPARSETENSOR_EXPORT bool isOrderedDLT(DimLevelType dlt) {
+  switch (dlt) {
+  case DimLevelType::kCompressedNo:
+  case DimLevelType::kCompressedNuNo:
+  case DimLevelType::kSingletonNo:
+  case DimLevelType::kSingletonNuNo:
+    return false;
+  default: // NOTE(review): any other value (incl. future enumerators) is ordered.
+    return true;
+  }
+}
+
+/// Check if the `DimLevelType` is unique (any format); only `Nu` variants are not.
+constexpr MLIR_SPARSETENSOR_EXPORT bool isUniqueDLT(DimLevelType dlt) {
+  switch (dlt) {
+  case DimLevelType::kCompressedNu:
+  case DimLevelType::kCompressedNuNo:
+  case DimLevelType::kSingletonNu:
+  case DimLevelType::kSingletonNuNo:
+    return false;
+  default: // NOTE(review): any other value (incl. future enumerators) is unique.
+    return true;
+  }
+}
+
 } // namespace sparse_tensor
 } // namespace mlir
 

diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 54729231a0a90..cfe7805e1b0d8 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -102,67 +102,30 @@ class SparseTensorStorageBase {
   /// Get the dimension-types array, in storage-order.
   const std::vector<DimLevelType> &getDimTypes() const { return dimTypes; }
 
-  /// Safely check if the (storage-order) dimension uses dense storage.
-  bool isDenseDim(uint64_t d) const {
+  /// Safely look up the level-type of the given (storage-order) dimension (`d` is checked by ASSERT_VALID_DIM).
+  DimLevelType getDimType(uint64_t d) const {
     ASSERT_VALID_DIM(d);
-    return dimTypes[d] == DimLevelType::kDense;
+    return dimTypes[d];
   }
 
+  /// Safely check if the (storage-order) dimension uses dense storage.
+  bool isDenseDim(uint64_t d) const { return isDenseDLT(getDimType(d)); }
+
   /// Safely check if the (storage-order) dimension uses compressed storage.
   bool isCompressedDim(uint64_t d) const {
-    ASSERT_VALID_DIM(d);
-    switch (dimTypes[d]) {
-    case DimLevelType::kCompressed:
-    case DimLevelType::kCompressedNu:
-    case DimLevelType::kCompressedNo:
-    case DimLevelType::kCompressedNuNo:
-      return true;
-    default:
-      return false;
-    }
+    return isCompressedDLT(getDimType(d));
   }
 
   /// Safely check if the (storage-order) dimension uses singleton storage.
   bool isSingletonDim(uint64_t d) const {
-    ASSERT_VALID_DIM(d);
-    switch (dimTypes[d]) {
-    case DimLevelType::kSingleton:
-    case DimLevelType::kSingletonNu:
-    case DimLevelType::kSingletonNo:
-    case DimLevelType::kSingletonNuNo:
-      return true;
-    default:
-      return false;
-    }
+    return isSingletonDLT(getDimType(d));
   }
 
   /// Safely check if the (storage-order) dimension is ordered.
-  bool isOrderedDim(uint64_t d) const {
-    ASSERT_VALID_DIM(d);
-    switch (dimTypes[d]) {
-    case DimLevelType::kCompressedNo:
-    case DimLevelType::kCompressedNuNo:
-    case DimLevelType::kSingletonNo:
-    case DimLevelType::kSingletonNuNo:
-      return false;
-    default:
-      return true;
-    }
-  }
+  bool isOrderedDim(uint64_t d) const { return isOrderedDLT(getDimType(d)); }
 
   /// Safely check if the (storage-order) dimension is unique.
-  bool isUniqueDim(uint64_t d) const {
-    ASSERT_VALID_DIM(d);
-    switch (dimTypes[d]) {
-    case DimLevelType::kCompressedNu:
-    case DimLevelType::kCompressedNuNo:
-    case DimLevelType::kSingletonNu:
-    case DimLevelType::kSingletonNuNo:
-      return false;
-    default:
-      return true;
-    }
-  }
+  bool isUniqueDim(uint64_t d) const { return isUniqueDLT(getDimType(d)); }
 
   /// Allocate a new enumerator.
 #define DECL_NEWENUMERATOR(VNAME, V)                                           \


        


More information about the Mlir-commits mailing list