[Mlir-commits] [mlir] 8840891 - [mlir][sparse] introduce LevelFormat, which does not encode the level properties of a sparse tensor level.

Peiming Liu llvmlistbot at llvm.org
Tue Dec 20 16:47:12 PST 2022


Author: Peiming Liu
Date: 2022-12-21T00:47:06Z
New Revision: 8840891debf196d63b531b69fbb65138ecd50d02

URL: https://github.com/llvm/llvm-project/commit/8840891debf196d63b531b69fbb65138ecd50d02
DIFF: https://github.com/llvm/llvm-project/commit/8840891debf196d63b531b69fbb65138ecd50d02.diff

LOG: [mlir][sparse] introduce LevelFormat, which does not encode the level properties of a sparse tensor level.

Reviewed By: aartbik, wrengr

Differential Revision: https://reviews.llvm.org/D140442

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index d46e325433a5d..3283a0f88b0c5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -33,6 +33,7 @@
 
 #include <cinttypes>
 #include <complex>
+#include <optional>
 
 namespace mlir {
 namespace sparse_tensor {
@@ -157,6 +158,14 @@ enum class DimLevelType : uint8_t {
   SingletonNuNo = 19,  // 0b100_11
 };
 
+/// Storage formats supported by the sparse compiler, without level properties.
+/// Each value is a format bit with the two low (property) bits left as zero.
+enum class LevelFormat : uint8_t {
+  Dense = 4,      // 0b001_00
+  Compressed = 8, // 0b010_00
+  Singleton = 16, // 0b100_00
+};
+
 /// Returns string representation of the given dimension level type.
 inline std::string toMLIRString(DimLevelType dlt) {
   switch (dlt) {
@@ -231,6 +240,63 @@ constexpr bool isUniqueDLT(DimLevelType dlt) {
   return !(static_cast<uint8_t>(dlt) & 1);
 }
 
+/// Convert a DimLevelType to its LevelFormat by clearing the two low bits.
+/// Returns std::nullopt when the input dlt is Undef.
+constexpr std::optional<LevelFormat> getLevelFormat(DimLevelType dlt) {
+  if (dlt == DimLevelType::Undef)
+    return std::nullopt;
+  return static_cast<LevelFormat>(static_cast<uint8_t>(dlt) & ~3);
+}
+
+/// Convert a LevelFormat to the DimLevelType with the given properties: a
+/// non-ordered level ORs 2 and a non-unique level ORs 1 into the format value.
+/// Returns std::nullopt when the combination is not valid per isValidDLT.
+/// TODO: factor out a new LevelProperties type so we can add new properties
+/// without changing this function's signature.
+constexpr std::optional<DimLevelType>
+getDimLevelType(LevelFormat lf, bool ordered, bool unique) {
+  auto dlt = static_cast<DimLevelType>(static_cast<uint8_t>(lf) |
+                                       (ordered ? 0 : 2) | (unique ? 0 : 1));
+  return isValidDLT(dlt) ? std::optional(dlt) : std::nullopt;
+}
+
+/// Ensure the above conversions work as intended.
+static_assert(
+    (getLevelFormat(DimLevelType::Undef) == std::nullopt &&
+     *getLevelFormat(DimLevelType::Dense) == LevelFormat::Dense &&
+     *getLevelFormat(DimLevelType::Compressed) == LevelFormat::Compressed &&
+     *getLevelFormat(DimLevelType::CompressedNu) == LevelFormat::Compressed &&
+     *getLevelFormat(DimLevelType::CompressedNo) == LevelFormat::Compressed &&
+     *getLevelFormat(DimLevelType::CompressedNuNo) == LevelFormat::Compressed &&
+     *getLevelFormat(DimLevelType::Singleton) == LevelFormat::Singleton &&
+     *getLevelFormat(DimLevelType::SingletonNu) == LevelFormat::Singleton &&
+     *getLevelFormat(DimLevelType::SingletonNo) == LevelFormat::Singleton &&
+     *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton),
+    "getLevelFormat conversion is broken");
+
+static_assert(
+    (getDimLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
+     getDimLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
+     getDimLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
+     *getDimLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
+     *getDimLevelType(LevelFormat::Compressed, true, true) ==
+         DimLevelType::Compressed &&
+     *getDimLevelType(LevelFormat::Compressed, true, false) ==
+         DimLevelType::CompressedNu &&
+     *getDimLevelType(LevelFormat::Compressed, false, true) ==
+         DimLevelType::CompressedNo &&
+     *getDimLevelType(LevelFormat::Compressed, false, false) ==
+         DimLevelType::CompressedNuNo &&
+     *getDimLevelType(LevelFormat::Singleton, true, true) ==
+         DimLevelType::Singleton &&
+     *getDimLevelType(LevelFormat::Singleton, true, false) ==
+         DimLevelType::SingletonNu &&
+     *getDimLevelType(LevelFormat::Singleton, false, true) ==
+         DimLevelType::SingletonNo &&
+     *getDimLevelType(LevelFormat::Singleton, false, false) ==
+         DimLevelType::SingletonNuNo),
+    "getDimLevelType conversion is broken");
+
 // Ensure the above predicates work as intended.
 static_assert((isValidDLT(DimLevelType::Undef) &&
                isValidDLT(DimLevelType::Dense) &&


        


More information about the Mlir-commits mailing list