[Mlir-commits] [mlir] 8840891 - [mlir][sparse] introduce LevelFormat, which does not encode the level properties of a sparse tensor level.
Peiming Liu
llvmlistbot at llvm.org
Tue Dec 20 16:47:12 PST 2022
Author: Peiming Liu
Date: 2022-12-21T00:47:06Z
New Revision: 8840891debf196d63b531b69fbb65138ecd50d02
URL: https://github.com/llvm/llvm-project/commit/8840891debf196d63b531b69fbb65138ecd50d02
DIFF: https://github.com/llvm/llvm-project/commit/8840891debf196d63b531b69fbb65138ecd50d02.diff
LOG: [mlir][sparse] introduce LevelFormat, which does not encode the level properties of a sparse tensor level.
Reviewed By: aartbik, wrengr
Differential Revision: https://reviews.llvm.org/D140442
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index d46e325433a5d..3283a0f88b0c5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -33,6 +33,7 @@
#include <cinttypes>
#include <complex>
+#include <optional>
namespace mlir {
namespace sparse_tensor {
@@ -157,6 +158,14 @@ enum class DimLevelType : uint8_t {
SingletonNuNo = 19, // 0b100_11
};
+/// Storage formats understood by the sparse compiler, stripped of the
+/// per-level properties (ordered/unique) that DimLevelType also encodes.
+enum class LevelFormat : uint8_t {
+  Dense = 0b001'00,      // 4
+  Compressed = 0b010'00, // 8
+  Singleton = 0b100'00,  // 16
+};
+
/// Returns string representation of the given dimension level type.
inline std::string toMLIRString(DimLevelType dlt) {
switch (dlt) {
@@ -231,6 +240,63 @@ constexpr bool isUniqueDLT(DimLevelType dlt) {
return !(static_cast<uint8_t>(dlt) & 1);
}
+/// Maps a DimLevelType onto its LevelFormat by clearing the two low property
+/// bits (bit 0 = not unique, bit 1 = not ordered); Undef yields std::nullopt.
+constexpr std::optional<LevelFormat> getLevelFormat(DimLevelType dlt) {
+  if (dlt != DimLevelType::Undef)
+    return static_cast<LevelFormat>(static_cast<uint8_t>(dlt) & ~0b11);
+  return std::nullopt;
+}
+
+/// Builds the DimLevelType for a LevelFormat plus the requested level
+/// properties. Yields std::nullopt when the combination is rejected by
+/// isValidDLT (e.g. a dense level that is not both ordered and unique).
+/// TODO: factor out a new LevelProperties type so we can add new properties
+/// without changing this function's signature.
+constexpr std::optional<DimLevelType>
+getDimLevelType(LevelFormat lf, bool ordered, bool unique) {
+  const uint8_t notUniqueBit = unique ? 0 : 1;
+  const uint8_t notOrderedBit = ordered ? 0 : 2;
+  const auto dlt = static_cast<DimLevelType>(static_cast<uint8_t>(lf) |
+                                             notOrderedBit | notUniqueBit);
+  if (isValidDLT(dlt))
+    return dlt;
+  return std::nullopt;
+}
+
+/// Ensure the above conversion works as intended. The optionals are compared
+/// to the expected values directly (not dereferenced), so an unexpected
+/// std::nullopt reports the message below rather than a non-constant error.
+static_assert(
+    (getLevelFormat(DimLevelType::Undef) == std::nullopt &&
+     getLevelFormat(DimLevelType::Dense) == LevelFormat::Dense &&
+     getLevelFormat(DimLevelType::Compressed) == LevelFormat::Compressed &&
+     getLevelFormat(DimLevelType::CompressedNu) == LevelFormat::Compressed &&
+     getLevelFormat(DimLevelType::CompressedNo) == LevelFormat::Compressed &&
+     getLevelFormat(DimLevelType::CompressedNuNo) == LevelFormat::Compressed &&
+     getLevelFormat(DimLevelType::Singleton) == LevelFormat::Singleton &&
+     getLevelFormat(DimLevelType::SingletonNu) == LevelFormat::Singleton &&
+     getLevelFormat(DimLevelType::SingletonNo) == LevelFormat::Singleton &&
+     getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton),
+    "getLevelFormat conversion is broken");
+
+/// Ensure getDimLevelType works as intended: Dense accepts only the
+/// ordered+unique combination, while Compressed/Singleton accept all four.
+/// The optionals are compared to the expected values directly (not
+/// dereferenced), so an unexpected std::nullopt reports the message below.
+static_assert(
+    (getDimLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
+     getDimLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
+     getDimLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
+     getDimLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
+     getDimLevelType(LevelFormat::Compressed, true, true) ==
+         DimLevelType::Compressed &&
+     getDimLevelType(LevelFormat::Compressed, true, false) ==
+         DimLevelType::CompressedNu &&
+     getDimLevelType(LevelFormat::Compressed, false, true) ==
+         DimLevelType::CompressedNo &&
+     getDimLevelType(LevelFormat::Compressed, false, false) ==
+         DimLevelType::CompressedNuNo &&
+     getDimLevelType(LevelFormat::Singleton, true, true) ==
+         DimLevelType::Singleton &&
+     getDimLevelType(LevelFormat::Singleton, true, false) ==
+         DimLevelType::SingletonNu &&
+     getDimLevelType(LevelFormat::Singleton, false, true) ==
+         DimLevelType::SingletonNo &&
+     getDimLevelType(LevelFormat::Singleton, false, false) ==
+         DimLevelType::SingletonNuNo),
+    "getDimLevelType conversion is broken");
+
// Ensure the above predicates work as intended.
static_assert((isValidDLT(DimLevelType::Undef) &&
isValidDLT(DimLevelType::Dense) &&
More information about the Mlir-commits
mailing list