[Mlir-commits] [mlir] [mlir][sparse] cleanup of enums header (PR #71090)
Aart Bik
llvmlistbot at llvm.org
Thu Nov 2 12:04:51 PDT 2023
https://github.com/aartbik updated https://github.com/llvm/llvm-project/pull/71090
>From f067914fdb8c746b4ec94aa573e020c0a285d324 Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 2 Nov 2023 11:29:20 -0700
Subject: [PATCH 1/3] [mlir][sparse] cleanup of enums header
Some DLT-related methods leaked into SparseTensor.h; this moves them
back to the right header (Enums.h). Also, the static asserts were
incomplete, and some DLT methods were duplicated.
---
.../mlir/Dialect/SparseTensor/IR/Enums.h | 129 +++++++++++-------
.../Dialect/SparseTensor/IR/SparseTensor.h | 10 --
.../SparseTensor/IR/SparseTensorAttrDefs.td | 2 +-
.../SparseTensor/IR/SparseTensorDialect.cpp | 1 +
4 files changed, 81 insertions(+), 61 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 1e9aa2bdf45dbdb..a867b99c3bfa5ba 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -153,24 +153,18 @@ enum class Action : uint32_t {
};
/// This enum defines all the sparse representations supportable by
-/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton) as well as
-/// the "properties" (ordered, unique). The encoding is chosen for
-/// performance of the runtime library, and thus may change in future
-/// versions; consequently, client code should use the predicate functions
-/// defined below, rather than relying on knowledge about the particular
-/// binary encoding.
+/// the SparseTensor dialect. We use a lightweight encoding to encode
+/// both the "format" per se (dense, compressed, singleton, loose_compressed,
+/// two-out-of-four) as well as the "properties" (ordered, unique). The
+/// encoding is chosen for performance of the runtime library, and thus may
+/// change in future versions; consequently, client code should use the
+/// predicate functions defined below, rather than relying on knowledge
+/// about the particular binary encoding.
///
/// The `Undef` "format" is a special value used internally for cases
/// where we need to store an undefined or indeterminate `DimLevelType`.
/// It should not be used externally, since it does not indicate an
/// actual/representable format.
-///
-// TODO: We should generalize TwoOutOfFour to N out of M and use property to
-// encode the value of N and M.
-// TODO: Update DimLevelType to use lower 8 bits for storage formats and the
-// higher 4 bits to store level properties. Consider LooseCompressed and
-// TwoOutOfFour as properties instead of formats.
enum class DimLevelType : uint8_t {
Undef = 0, // 0b00000_00
Dense = 4, // 0b00001_00
@@ -257,20 +251,12 @@ constexpr bool isUndefDLT(DimLevelType dlt) {
return dlt == DimLevelType::Undef;
}
-/// Check if the `DimLevelType` is dense.
+/// Check if the `DimLevelType` is dense (regardless of properties).
constexpr bool isDenseDLT(DimLevelType dlt) {
- return dlt == DimLevelType::Dense;
-}
-
-/// Check if the `DimLevelType` is 2:4
-constexpr bool isTwoOutOfFourDLT(DimLevelType dlt) {
- return dlt == DimLevelType::TwoOutOfFour;
+ return (static_cast<uint8_t>(dlt) & ~3) ==
+ static_cast<uint8_t>(DimLevelType::Dense);
}
-// We use the idiom `(dlt & ~3) == format` in order to only return true
-// for valid DLTs. Whereas the `dlt & format` idiom is a bit faster but
-// can return false-positives on invalid DLTs.
-
/// Check if the `DimLevelType` is compressed (regardless of properties).
constexpr bool isCompressedDLT(DimLevelType dlt) {
return (static_cast<uint8_t>(dlt) & ~3) ==
@@ -295,6 +281,17 @@ constexpr bool is2OutOf4DLT(DimLevelType dlt) {
static_cast<uint8_t>(DimLevelType::TwoOutOfFour);
}
+/// Check if the `DimLevelType` needs positions array.
+constexpr bool isDLTWithPos(DimLevelType dlt) {
+ return isLooseCompressedDLT(dlt) || isCompressedDLT(dlt);
+}
+
+/// Check if the `DimLevelType` needs coordinates array.
+constexpr bool isDLTWithCrd(DimLevelType dlt) {
+ return isSingletonDLT(dlt) || isLooseCompressedDLT(dlt) ||
+ isCompressedDLT(dlt);
+}
+
/// Check if the `DimLevelType` is ordered (regardless of storage format).
constexpr bool isOrderedDLT(DimLevelType dlt) {
return !(static_cast<uint8_t>(dlt) & 2);
@@ -336,35 +333,52 @@ static_assert(
*getLevelFormat(DimLevelType::Singleton) == LevelFormat::Singleton &&
*getLevelFormat(DimLevelType::SingletonNu) == LevelFormat::Singleton &&
*getLevelFormat(DimLevelType::SingletonNo) == LevelFormat::Singleton &&
- *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton),
+ *getLevelFormat(DimLevelType::SingletonNuNo) == LevelFormat::Singleton &&
+ *getLevelFormat(DimLevelType::LooseCompressed) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(DimLevelType::LooseCompressedNu) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(DimLevelType::LooseCompressedNo) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(DimLevelType::LooseCompressedNuNo) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(DimLevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
"getLevelFormat conversion is broken");
static_assert(
(buildLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
- DimLevelType::TwoOutOfFour &&
- *buildLevelType(LevelFormat::Compressed, true, true) ==
+ buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
+ buildLevelType(LevelFormat::Compressed, true, true) ==
DimLevelType::Compressed &&
- *buildLevelType(LevelFormat::Compressed, true, false) ==
+ buildLevelType(LevelFormat::Compressed, true, false) ==
DimLevelType::CompressedNu &&
- *buildLevelType(LevelFormat::Compressed, false, true) ==
+ buildLevelType(LevelFormat::Compressed, false, true) ==
DimLevelType::CompressedNo &&
- *buildLevelType(LevelFormat::Compressed, false, false) ==
+ buildLevelType(LevelFormat::Compressed, false, false) ==
DimLevelType::CompressedNuNo &&
- *buildLevelType(LevelFormat::Singleton, true, true) ==
+ buildLevelType(LevelFormat::Singleton, true, true) ==
DimLevelType::Singleton &&
- *buildLevelType(LevelFormat::Singleton, true, false) ==
+ buildLevelType(LevelFormat::Singleton, true, false) ==
DimLevelType::SingletonNu &&
- *buildLevelType(LevelFormat::Singleton, false, true) ==
+ buildLevelType(LevelFormat::Singleton, false, true) ==
DimLevelType::SingletonNo &&
- *buildLevelType(LevelFormat::Singleton, false, false) ==
- DimLevelType::SingletonNuNo),
+ buildLevelType(LevelFormat::Singleton, false, false) ==
+ DimLevelType::SingletonNuNo &&
+ buildLevelType(LevelFormat::LooseCompressed, true, true) ==
+ DimLevelType::LooseCompressed &&
+ buildLevelType(LevelFormat::LooseCompressed, true, false) ==
+ DimLevelType::LooseCompressedNu &&
+ buildLevelType(LevelFormat::LooseCompressed, false, true) ==
+ DimLevelType::LooseCompressedNo &&
+ buildLevelType(LevelFormat::LooseCompressed, false, false) ==
+ DimLevelType::LooseCompressedNuNo &&
+ buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
+ buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
+ buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
+ buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
+ DimLevelType::TwoOutOfFour),
"buildLevelType conversion is broken");
// Ensure the above predicates work as intended.
@@ -393,18 +407,28 @@ static_assert((!isCompressedDLT(DimLevelType::Dense) &&
!isCompressedDLT(DimLevelType::Singleton) &&
!isCompressedDLT(DimLevelType::SingletonNu) &&
!isCompressedDLT(DimLevelType::SingletonNo) &&
- !isCompressedDLT(DimLevelType::SingletonNuNo)),
+ !isCompressedDLT(DimLevelType::SingletonNuNo) &&
+ !isCompressedDLT(DimLevelType::LooseCompressed) &&
+ !isCompressedDLT(DimLevelType::LooseCompressedNu) &&
+ !isCompressedDLT(DimLevelType::LooseCompressedNo) &&
+ !isCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
+ !isCompressedDLT(DimLevelType::TwoOutOfFour)),
"isCompressedDLT definition is broken");
static_assert((!isLooseCompressedDLT(DimLevelType::Dense) &&
+ !isLooseCompressedDLT(DimLevelType::Compressed) &&
+ !isLooseCompressedDLT(DimLevelType::CompressedNu) &&
+ !isLooseCompressedDLT(DimLevelType::CompressedNo) &&
+ !isLooseCompressedDLT(DimLevelType::CompressedNuNo) &&
+ !isLooseCompressedDLT(DimLevelType::Singleton) &&
+ !isLooseCompressedDLT(DimLevelType::SingletonNu) &&
+ !isLooseCompressedDLT(DimLevelType::SingletonNo) &&
+ !isLooseCompressedDLT(DimLevelType::SingletonNuNo) &&
isLooseCompressedDLT(DimLevelType::LooseCompressed) &&
isLooseCompressedDLT(DimLevelType::LooseCompressedNu) &&
isLooseCompressedDLT(DimLevelType::LooseCompressedNo) &&
isLooseCompressedDLT(DimLevelType::LooseCompressedNuNo) &&
- !isLooseCompressedDLT(DimLevelType::Singleton) &&
- !isLooseCompressedDLT(DimLevelType::SingletonNu) &&
- !isLooseCompressedDLT(DimLevelType::SingletonNo) &&
- !isLooseCompressedDLT(DimLevelType::SingletonNuNo)),
+ !isLooseCompressedDLT(DimLevelType::TwoOutOfFour)),
"isLooseCompressedDLT definition is broken");
static_assert((!isSingletonDLT(DimLevelType::Dense) &&
@@ -415,11 +439,15 @@ static_assert((!isSingletonDLT(DimLevelType::Dense) &&
isSingletonDLT(DimLevelType::Singleton) &&
isSingletonDLT(DimLevelType::SingletonNu) &&
isSingletonDLT(DimLevelType::SingletonNo) &&
- isSingletonDLT(DimLevelType::SingletonNuNo)),
+ isSingletonDLT(DimLevelType::SingletonNuNo) &&
+ !isSingletonDLT(DimLevelType::LooseCompressed) &&
+ !isSingletonDLT(DimLevelType::LooseCompressedNu) &&
+ !isSingletonDLT(DimLevelType::LooseCompressedNo) &&
+ !isSingletonDLT(DimLevelType::LooseCompressedNuNo) &&
+ !isSingletonDLT(DimLevelType::TwoOutOfFour)),
"isSingletonDLT definition is broken");
static_assert((isOrderedDLT(DimLevelType::Dense) &&
- isOrderedDLT(DimLevelType::TwoOutOfFour) &&
isOrderedDLT(DimLevelType::Compressed) &&
isOrderedDLT(DimLevelType::CompressedNu) &&
!isOrderedDLT(DimLevelType::CompressedNo) &&
@@ -431,11 +459,11 @@ static_assert((isOrderedDLT(DimLevelType::Dense) &&
isOrderedDLT(DimLevelType::LooseCompressed) &&
isOrderedDLT(DimLevelType::LooseCompressedNu) &&
!isOrderedDLT(DimLevelType::LooseCompressedNo) &&
- !isOrderedDLT(DimLevelType::LooseCompressedNuNo)),
+ !isOrderedDLT(DimLevelType::LooseCompressedNuNo) &&
+ isOrderedDLT(DimLevelType::TwoOutOfFour)),
"isOrderedDLT definition is broken");
static_assert((isUniqueDLT(DimLevelType::Dense) &&
- isUniqueDLT(DimLevelType::TwoOutOfFour) &&
isUniqueDLT(DimLevelType::Compressed) &&
!isUniqueDLT(DimLevelType::CompressedNu) &&
isUniqueDLT(DimLevelType::CompressedNo) &&
@@ -447,7 +475,8 @@ static_assert((isUniqueDLT(DimLevelType::Dense) &&
isUniqueDLT(DimLevelType::LooseCompressed) &&
!isUniqueDLT(DimLevelType::LooseCompressedNu) &&
isUniqueDLT(DimLevelType::LooseCompressedNo) &&
- !isUniqueDLT(DimLevelType::LooseCompressedNuNo)),
+ !isUniqueDLT(DimLevelType::LooseCompressedNuNo) &&
+ isUniqueDLT(DimLevelType::TwoOutOfFour)),
"isUniqueDLT definition is broken");
/// Bit manipulations for affine encoding.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 94e7d12b9ee915f..241d90a87165928 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -89,16 +89,6 @@ inline MemRefType getMemRefType(T &&t) {
/// Returns null-attribute for any type without an encoding.
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
-/// Convenience method to query whether a given DLT needs both position and
-/// coordinates array or only coordinates array.
-constexpr inline bool isDLTWithPos(DimLevelType dlt) {
- return isLooseCompressedDLT(dlt) || isCompressedDLT(dlt);
-}
-constexpr inline bool isDLTWithCrd(DimLevelType dlt) {
- return isSingletonDLT(dlt) || isLooseCompressedDLT(dlt) ||
- isCompressedDLT(dlt);
-}
-
/// Returns true iff the given sparse tensor encoding attribute has a trailing
/// COO region starting at the given level.
bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 3c73b19319e588c..e7c6435e997ca00 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -371,7 +371,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
::mlir::sparse_tensor::DimLevelType getLvlType(::mlir::sparse_tensor::Level l) const;
bool isDenseLvl(::mlir::sparse_tensor::Level l) const { return isDenseDLT(getLvlType(l)); }
- bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return isTwoOutOfFourDLT(getLvlType(l)); }
+ bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4DLT(getLvlType(l)); }
bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedDLT(getLvlType(l)); }
bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedDLT(getLvlType(l)); }
bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonDLT(getLvlType(l)); }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6080317d07a64e0..97ef753aacf35b1 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -10,6 +10,7 @@
#include "Detail/DimLvlMapParser.h"
+#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
>From 8c3006b96b81818c8ba49f8d68245972f68076be Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 2 Nov 2023 11:45:02 -0700
Subject: [PATCH 2/3] 2:4 levels need a coordinates array (isDLTWithCrd)
---
mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index a867b99c3bfa5ba..4f18bcd815230a3 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -289,7 +289,7 @@ constexpr bool isDLTWithPos(DimLevelType dlt) {
/// Check if the `DimLevelType` needs coordinates array.
constexpr bool isDLTWithCrd(DimLevelType dlt) {
return isSingletonDLT(dlt) || isLooseCompressedDLT(dlt) ||
- isCompressedDLT(dlt);
+ isCompressedDLT(dlt) || is2OutOf4DLT(dlt);
}
/// Check if the `DimLevelType` is ordered (regardless of storage format).
>From 50eb4ed88f60f5689f81ca2ddfc776c1ac3d926b Mon Sep 17 00:00:00 2001
From: Aart Bik <ajcbik at google.com>
Date: Thu, 2 Nov 2023 12:04:05 -0700
Subject: [PATCH 3/3] typo: restore `*` dereferences of buildLevelType results in static_assert
---
.../mlir/Dialect/SparseTensor/IR/Enums.h | 28 +++++++++----------
1 file changed, 14 insertions(+), 14 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 4f18bcd815230a3..272f9b5b1756c23 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -349,35 +349,35 @@ static_assert(
(buildLevelType(LevelFormat::Dense, false, true) == std::nullopt &&
buildLevelType(LevelFormat::Dense, true, false) == std::nullopt &&
buildLevelType(LevelFormat::Dense, false, false) == std::nullopt &&
- buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
- buildLevelType(LevelFormat::Compressed, true, true) ==
+ *buildLevelType(LevelFormat::Dense, true, true) == DimLevelType::Dense &&
+ *buildLevelType(LevelFormat::Compressed, true, true) ==
DimLevelType::Compressed &&
- buildLevelType(LevelFormat::Compressed, true, false) ==
+ *buildLevelType(LevelFormat::Compressed, true, false) ==
DimLevelType::CompressedNu &&
- buildLevelType(LevelFormat::Compressed, false, true) ==
+ *buildLevelType(LevelFormat::Compressed, false, true) ==
DimLevelType::CompressedNo &&
- buildLevelType(LevelFormat::Compressed, false, false) ==
+ *buildLevelType(LevelFormat::Compressed, false, false) ==
DimLevelType::CompressedNuNo &&
- buildLevelType(LevelFormat::Singleton, true, true) ==
+ *buildLevelType(LevelFormat::Singleton, true, true) ==
DimLevelType::Singleton &&
- buildLevelType(LevelFormat::Singleton, true, false) ==
+ *buildLevelType(LevelFormat::Singleton, true, false) ==
DimLevelType::SingletonNu &&
- buildLevelType(LevelFormat::Singleton, false, true) ==
+ *buildLevelType(LevelFormat::Singleton, false, true) ==
DimLevelType::SingletonNo &&
- buildLevelType(LevelFormat::Singleton, false, false) ==
+ *buildLevelType(LevelFormat::Singleton, false, false) ==
DimLevelType::SingletonNuNo &&
- buildLevelType(LevelFormat::LooseCompressed, true, true) ==
+ *buildLevelType(LevelFormat::LooseCompressed, true, true) ==
DimLevelType::LooseCompressed &&
- buildLevelType(LevelFormat::LooseCompressed, true, false) ==
+ *buildLevelType(LevelFormat::LooseCompressed, true, false) ==
DimLevelType::LooseCompressedNu &&
- buildLevelType(LevelFormat::LooseCompressed, false, true) ==
+ *buildLevelType(LevelFormat::LooseCompressed, false, true) ==
DimLevelType::LooseCompressedNo &&
- buildLevelType(LevelFormat::LooseCompressed, false, false) ==
+ *buildLevelType(LevelFormat::LooseCompressed, false, false) ==
DimLevelType::LooseCompressedNuNo &&
buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
+ *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
DimLevelType::TwoOutOfFour),
"buildLevelType conversion is broken");
More information about the Mlir-commits
mailing list