[Mlir-commits] [mlir] ee3ee13 - [mlir][sparse] cleanup of enums header (#71090)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 2 13:00:32 PDT 2023
Author: Aart Bik
Date: 2023-11-02T13:00:27-07:00
New Revision: ee3ee1315a757345dfc50eed34b5074c6e87df2d
URL: https://github.com/llvm/llvm-project/commit/ee3ee1315a757345dfc50eed34b5074c6e87df2d
DIFF: https://github.com/llvm/llvm-project/commit/ee3ee1315a757345dfc50eed34b5074c6e87df2d.diff
LOG: [mlir][sparse] cleanup of enums header (#71090)
Some DLT-related methods leaked into SparseTensor.h; this change moves them
back to the right header (Enums.h). Also, the asserts were incomplete and
some DLT methods were duplicated.
Added:
Modified:
mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 1e9aa2bdf45dbdb..9c277a0b23633d8 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -153,24 +153,18 @@ enum class Action : uint32_t {
};
/// This enum defines all the sparse representations supportable by
-/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton) as well as
-/// the "properties" (ordered, unique). The encoding is chosen for
-/// performance of the runtime library, and thus may change in future
-/// versions; consequently, client code should use the predicate functions
-/// defined below, rather than relying on knowledge about the particular
-/// binary encoding.
+/// the SparseTensor dialect. We use a lightweight encoding to encode
+/// both the "format" per se (dense, compressed, singleton, loose_compressed,
+/// two-out-of-four) as well as the "properties" (ordered, unique). The
+/// encoding is chosen for performance of the runtime library, and thus may
+/// change in future versions; consequently, client code should use the
+/// predicate functions defined below, rather than relying on knowledge
+/// about the particular binary encoding.
///
/// The `Undef` "format" is a special value used internally for cases
/// where we need to store an undefined or indeterminate `DimLevelType`.
/// It should not be used externally, since it does not indicate an
/// actual/representable format.
-///
-// TODO: We should generalize TwoOutOfFour to N out of M and use property to
-// encode the value of N and M.
-// TODO: Update DimLevelType to use lower 8 bits for storage formats and the
-// higher 4 bits to store level properties. Consider LooseCompressed and
-// TwoOutOfFour as properties instead of formats.
enum class DimLevelType : uint8_t {
Undef = 0, // 0b00000_00
Dense = 4, // 0b00001_00
@@ -257,44 +251,47 @@ constexpr bool isUndefDLT(DimLevelType dlt) {
return dlt == DimLevelType::Undef;
}
-/// Check if the `DimLevelType` is dense.
+/// Check if the `DimLevelType` is dense (regardless of properties).
constexpr bool isDenseDLT(DimLevelType dlt) {
- return dlt == DimLevelType::Dense;
-}
-
-/// Check if the `DimLevelType` is 2:4
-constexpr bool isTwoOutOfFourDLT(DimLevelType dlt) {
- return dlt == DimLevelType::TwoOutOfFour;
+ return (static_cast<uint8_t>(dlt) & ~3) ==
+ static_cast<uint8_t>(DimLevelType::Dense);
}
-// We use the idiom `(dlt & ~3) == format` in order to only return true
-// for valid DLTs. Whereas the `dlt & format` idiom is a bit faster but
-// can return false-positives on invalid DLTs.
-
/// Check if the `DimLevelType` is compressed (regardless of properties).
constexpr bool isCompressedDLT(DimLevelType dlt) {
return (static_cast<uint8_t>(dlt) & ~3) ==
static_cast<uint8_t>(DimLevelType::Compressed);
}
-/// Check if the `DimLevelType` is loose compressed (regardless of properties).
-constexpr bool isLooseCompressedDLT(DimLevelType dlt) {
- return (static_cast<uint8_t>(dlt) & ~3) ==
- static_cast<uint8_t>(DimLevelType::LooseCompressed);
-}
-
/// Check if the `DimLevelType` is singleton (regardless of properties).
constexpr bool isSingletonDLT(DimLevelType dlt) {
return (static_cast<uint8_t>(dlt) & ~3) ==
static_cast<uint8_t>(DimLevelType::Singleton);
}
+/// Check if the `DimLevelType` is loose compressed (regardless of properties).
+constexpr bool isLooseCompressedDLT(DimLevelType dlt) {
+ return (static_cast<uint8_t>(dlt) & ~3) ==
+ static_cast<uint8_t>(DimLevelType::LooseCompressed);
+}
+
/// Check if the `DimLevelType` is 2OutOf4 (regardless of properties).
constexpr bool is2OutOf4DLT(DimLevelType dlt) {
return (static_cast<uint8_t>(dlt) & ~3) ==
static_cast<uint8_t>(DimLevelType::TwoOutOfFour);
}
+/// Check if the `DimLevelType` needs positions array.
+constexpr bool isDLTWithPos(DimLevelType dlt) {
+ return isCompressedDLT(dlt) || isLooseCompressedDLT(dlt);
+}
+
+/// Check if the `DimLevelType` needs coordinates array.
+constexpr bool isDLTWithCrd(DimLevelType dlt) {
+ return isCompressedDLT(dlt) || isSingletonDLT(dlt) ||
+ isLooseCompressedDLT(dlt) || is2OutOf4DLT(dlt);
+}
+
/// Check if the `DimLevelType` is ordered (regardless of storage format).
constexpr bool isOrderedDLT(DimLevelType dlt) {
return !(static_cast<uint8_t>(dlt) & 2);
@@ -325,7 +322,10 @@ buildLevelType(LevelFormat lf, bool ordered, bool unique) {
return isValidDLT(dlt) ? std::optional(dlt) : std::nullopt;
}
-/// Ensure the above conversion works as intended.
+//
+// Ensure the above methods work as intended.
+//
+
static_assert(
(getLevelFormat(DimLevelType::Undef) == std::nullopt &&
*getLevelFormat(DimLevelType::Dense) == LevelFormat::Dense &&
@@ -336,7 +336,16 @@ static_assert(
*getLevelFormat(DimLevelType::Singleton) == LevelFormat::Singleton &&
*getLevelFormat(DimLevelType::SingletonNu) == LevelFormat::Singleton &&
*getLevelFormat(DimLevelType::SingletonNo) == LevelFormat::Singleton &&
- *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton),
+ *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton &&
+ *getLevelFormat(DimLevelType::LooseCompressed) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(DimLevelType::LooseCompressedNu) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(DimLevelType::LooseCompressedNo) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(DimLevelType::LooseCompressedNuNo) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(DimLevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
"getLevelFormat conversion is broken");
static_assert(
@@ -344,11 +353,6 @@ static_assert(
buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
*buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
- DimLevelType::TwoOutOfFour &&
*buildLevelType(LevelFormat::Compressed, true, true) ==
DimLevelType::Compressed &&
*buildLevelType(LevelFormat::Compressed, true, false) ==
@@ -364,10 +368,22 @@ static_assert(
*buildLevelType(LevelFormat::Singleton, false, true) ==
DimLevelType::SingletonNo &&
*buildLevelType(LevelFormat::Singleton, false, false) ==
- DimLevelType::SingletonNuNo),
+ DimLevelType::SingletonNuNo &&
+ *buildLevelType(LevelFormat::LooseCompressed, true, true) ==
+ DimLevelType::LooseCompressed &&
+ *buildLevelType(LevelFormat::LooseCompressed, true, false) ==
+ DimLevelType::LooseCompressedNu &&
+ *buildLevelType(LevelFormat::LooseCompressed, false, true) ==
+ DimLevelType::LooseCompressedNo &&
+ *buildLevelType(LevelFormat::LooseCompressed, false, false) ==
+ DimLevelType::LooseCompressedNuNo &&
+ buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
+ buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
+ buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
+ *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
+ DimLevelType::TwoOutOfFour),
"buildLevelType conversion is broken");
-// Ensure the above predicates work as intended.
static_assert((isValidDLT(DimLevelType::Undef) &&
isValidDLT(DimLevelType::Dense) &&
isValidDLT(DimLevelType::Compressed) &&
@@ -385,6 +401,22 @@ static_assert((isValidDLT(DimLevelType::Undef) &&
isValidDLT(DimLevelType::TwoOutOfFour)),
"isValidDLT definition is broken");
+static_assert((isDenseDLT(DimLevelType::Dense) &&
+ !isDenseDLT(DimLevelType::Compressed) &&
+ !isDenseDLT(DimLevelType::CompressedNu) &&
+ !isDenseDLT(DimLevelType::CompressedNo) &&
+ !isDenseDLT(DimLevelType::CompressedNuNo) &&
+ !isDenseDLT(DimLevelType::Singleton) &&
+ !isDenseDLT(DimLevelType::SingletonNu) &&
+ !isDenseDLT(DimLevelType::SingletonNo) &&
+ !isDenseDLT(DimLevelType::SingletonNuNo) &&
+ !isDenseDLT(DimLevelType::LooseCompressed) &&
+ !isDenseDLT(DimLevelType::LooseCompressedNu) &&
+ !isDenseDLT(DimLevelType::LooseCompressedNo) &&
+ !isDenseDLT(DimLevelType::LooseCompressedNuNo) &&
+ !isDenseDLT(DimLevelType::TwoOutOfFour)),
+ "isDenseDLT definition is broken");
+
static_assert((!isCompressedDLT(DimLevelType::Dense) &&
isCompressedDLT(DimLevelType::Compressed) &&
isCompressedDLT(DimLevelType::CompressedNu) &&
@@ -393,20 +425,14 @@ static_assert((!isCompressedDLT(DimLevelType::Dense) &&
!isCompressedDLT(DimLevelType::Singleton) &&
!isCompressedDLT(DimLevelType::SingletonNu) &&
!isCompressedDLT(DimLevelType::SingletonNo) &&
- !isCompressedDLT(DimLevelType::SingletonNuNo)),
+ !isCompressedDLT(DimLevelType::SingletonNuNo) &&
+ !isCompressedDLT(DimLevelType::LooseCompressed) &&
+ !isCompressedDLT(DimLevelType::LooseCompressedNu) &&
+ !isCompressedDLT(DimLevelType::LooseCompressedNo) &&
+ !isCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
+ !isCompressedDLT(DimLevelType::TwoOutOfFour)),
"isCompressedDLT definition is broken");
-static_assert((!isLooseCompressedDLT(DimLevelType::Dense) &&
- isLooseCompressedDLT(DimLevelType::LooseCompressed) &&
- isLooseCompressedDLT(DimLevelType::LooseCompressedNu) &&
- isLooseCompressedDLT(DimLevelType::LooseCompressedNo) &&
- isLooseCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
- !isLooseCompressedDLT(DimLevelType::Singleton) &&
- !isLooseCompressedDLT(DimLevelType::SingletonNu) &&
- !isLooseCompressedDLT(DimLevelType::SingletonNo) &&
- !isLooseCompressedDLT(DimLevelType::SingletonNuNo)),
- "isLooseCompressedDLT definition is broken");
-
static_assert((!isSingletonDLT(DimLevelType::Dense) &&
!isSingletonDLT(DimLevelType::Compressed) &&
!isSingletonDLT(DimLevelType::CompressedNu) &&
@@ -415,11 +441,47 @@ static_assert((!isSingletonDLT(DimLevelType::Dense) &&
isSingletonDLT(DimLevelType::Singleton) &&
isSingletonDLT(DimLevelType::SingletonNu) &&
isSingletonDLT(DimLevelType::SingletonNo) &&
- isSingletonDLT(DimLevelType::SingletonNuNo)),
+ isSingletonDLT(DimLevelType::SingletonNuNo) &&
+ !isSingletonDLT(DimLevelType::LooseCompressed) &&
+ !isSingletonDLT(DimLevelType::LooseCompressedNu) &&
+ !isSingletonDLT(DimLevelType::LooseCompressedNo) &&
+ !isSingletonDLT(DimLevelType::LooseCompressedNuNo) &&
+ !isSingletonDLT(DimLevelType::TwoOutOfFour)),
"isSingletonDLT definition is broken");
+static_assert((!isLooseCompressedDLT(DimLevelType::Dense) &&
+ !isLooseCompressedDLT(DimLevelType::Compressed) &&
+ !isLooseCompressedDLT(DimLevelType::CompressedNu) &&
+ !isLooseCompressedDLT(DimLevelType::CompressedNo) &&
+ !isLooseCompressedDLT(DimLevelType::CompressedNuNo) &&
+ !isLooseCompressedDLT(DimLevelType::Singleton) &&
+ !isLooseCompressedDLT(DimLevelType::SingletonNu) &&
+ !isLooseCompressedDLT(DimLevelType::SingletonNo) &&
+ !isLooseCompressedDLT(DimLevelType::SingletonNuNo) &&
+ isLooseCompressedDLT(DimLevelType::LooseCompressed) &&
+ isLooseCompressedDLT(DimLevelType::LooseCompressedNu) &&
+ isLooseCompressedDLT(DimLevelType::LooseCompressedNo) &&
+ isLooseCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
+ !isLooseCompressedDLT(DimLevelType::TwoOutOfFour)),
+ "isLooseCompressedDLT definition is broken");
+
+static_assert((!is2OutOf4DLT(DimLevelType::Dense) &&
+ !is2OutOf4DLT(DimLevelType::Compressed) &&
+ !is2OutOf4DLT(DimLevelType::CompressedNu) &&
+ !is2OutOf4DLT(DimLevelType::CompressedNo) &&
+ !is2OutOf4DLT(DimLevelType::CompressedNuNo) &&
+ !is2OutOf4DLT(DimLevelType::Singleton) &&
+ !is2OutOf4DLT(DimLevelType::SingletonNu) &&
+ !is2OutOf4DLT(DimLevelType::SingletonNo) &&
+ !is2OutOf4DLT(DimLevelType::SingletonNuNo) &&
+ !is2OutOf4DLT(DimLevelType::LooseCompressed) &&
+ !is2OutOf4DLT(DimLevelType::LooseCompressedNu) &&
+ !is2OutOf4DLT(DimLevelType::LooseCompressedNo) &&
+ !is2OutOf4DLT(DimLevelType::LooseCompressedNuNo) &&
+ is2OutOf4DLT(DimLevelType::TwoOutOfFour)),
+ "is2OutOf4DLT definition is broken");
+
static_assert((isOrderedDLT(DimLevelType::Dense) &&
- isOrderedDLT(DimLevelType::TwoOutOfFour) &&
isOrderedDLT(DimLevelType::Compressed) &&
isOrderedDLT(DimLevelType::CompressedNu) &&
!isOrderedDLT(DimLevelType::CompressedNo) &&
@@ -431,11 +493,11 @@ static_assert((isOrderedDLT(DimLevelType::Dense) &&
isOrderedDLT(DimLevelType::LooseCompressed) &&
isOrderedDLT(DimLevelType::LooseCompressedNu) &&
!isOrderedDLT(DimLevelType::LooseCompressedNo) &&
- !isOrderedDLT(DimLevelType::LooseCompressedNuNo)),
+ !isOrderedDLT(DimLevelType::LooseCompressedNuNo) &&
+ isOrderedDLT(DimLevelType::TwoOutOfFour)),
"isOrderedDLT definition is broken");
static_assert((isUniqueDLT(DimLevelType::Dense) &&
- isUniqueDLT(DimLevelType::TwoOutOfFour) &&
isUniqueDLT(DimLevelType::Compressed) &&
!isUniqueDLT(DimLevelType::CompressedNu) &&
isUniqueDLT(DimLevelType::CompressedNo) &&
@@ -447,7 +509,8 @@ static_assert((isUniqueDLT(DimLevelType::Dense) &&
isUniqueDLT(DimLevelType::LooseCompressed) &&
!isUniqueDLT(DimLevelType::LooseCompressedNu) &&
isUniqueDLT(DimLevelType::LooseCompressedNo) &&
- !isUniqueDLT(DimLevelType::LooseCompressedNuNo)),
+ !isUniqueDLT(DimLevelType::LooseCompressedNuNo) &&
+ isUniqueDLT(DimLevelType::TwoOutOfFour)),
"isUniqueDLT definition is broken");
/// Bit manipulations for affine encoding.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 94e7d12b9ee915f..241d90a87165928 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,16 +89,6 @@ inline MemRefType getMemRefType(T &&t) {
/// Returns null-attribute for any type without an encoding.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
-/// Convenience method to query whether a given DLT needs both position and
-/// coordinates array or only coordinates array.
-constexpr inline bool isDLTWithPos(DimLevelType dlt) {
- return isLooseCompressedDLT(dlt) || isCompressedDLT(dlt);
-}
-constexpr inline bool isDLTWithCrd(DimLevelType dlt) {
- return isSingletonDLT(dlt) || isLooseCompressedDLT(dlt) ||
- isCompressedDLT(dlt);
-}
-
/// Returns true iff the given sparse tensor encoding attribute has a trailing
/// COO region starting at the given level.
bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 3c73b19319e588c..a254f52aa86e7db 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -371,10 +371,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
::mlir::sparse_tensor::DimLevelType getLvlType(::mlir::sparse_tensor::Level l) const;
bool isDenseLvl(::mlir::sparse_tensor::Level l) const { return isDenseDLT(getLvlType(l)); }
- bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return isTwoOutOfFourDLT(getLvlType(l)); }
bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedDLT(getLvlType(l)); }
- bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedDLT(getLvlType(l)); }
bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonDLT(getLvlType(l)); }
+ bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedDLT(getLvlType(l)); }
+ bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4DLT(getLvlType(l)); }
bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedDLT(getLvlType(l)); }
bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueDLT(getLvlType(l)); }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 99214fadf4ba3db..c727b8d05c26d7d 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -10,6 +10,7 @@
#include "Detail/DimLvlMapParser.h"
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
More information about the Mlir-commits
mailing list