[Mlir-commits] [mlir] [mlir][sparse] remove LevelType enum, construct LevelType from LevelF… (PR #81799)
Peiming Liu
llvmlistbot at llvm.org
Thu Feb 15 12:07:22 PST 2024
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/81799
>From 5cba884e9f66608bea6a19e38fd298e101fd0214 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Wed, 14 Feb 2024 22:24:24 +0000
Subject: [PATCH 1/2] [mlir][sparse] remove LevelType enum, construct LevelType
from LevelFormat and properties instead.
---
.../mlir/Dialect/SparseTensor/IR/Enums.h | 516 +++++++-----------
mlir/lib/CAPI/Dialect/SparseTensor.cpp | 10 +-
.../SparseTensor/IR/Detail/LvlTypeParser.cpp | 4 +-
.../SparseTensor/IR/SparseTensorDialect.cpp | 16 +-
.../Transforms/SparseTensorRewriting.cpp | 2 +-
.../Transforms/Utils/SparseTensorLevel.cpp | 6 +-
.../lib/Dialect/SparseTensor/Utils/Merger.cpp | 3 +-
.../Dialect/SparseTensor/MergerTest.cpp | 34 +-
8 files changed, 237 insertions(+), 354 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 74cc0dee554a17..079899a147476e 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -153,45 +153,9 @@ enum class Action : uint32_t {
kSortCOOInPlace = 8,
};
-/// This enum defines all the sparse representations supportable by
-/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// the "format" per se (dense, compressed, singleton, loose_compressed,
-/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
-/// the format is NOutOfM.
-/// The encoding is chosen for performance of the runtime library, and thus may
-/// change in future versions; consequently, client code should use the
-/// predicate functions defined below, rather than relying on knowledge
-/// about the particular binary encoding.
-///
-/// The `Undef` "format" is a special value used internally for cases
-/// where we need to store an undefined or indeterminate `LevelType`.
-/// It should not be used externally, since it does not indicate an
-/// actual/representable format.
-///
-/// Bit manipulations for LevelType:
-///
-/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
-///
-enum class LevelType : uint64_t {
- Undef = 0x000000000000,
- Dense = 0x000000010000,
- Compressed = 0x000000020000,
- CompressedNu = 0x000000020001,
- CompressedNo = 0x000000020002,
- CompressedNuNo = 0x000000020003,
- Singleton = 0x000000040000,
- SingletonNu = 0x000000040001,
- SingletonNo = 0x000000040002,
- SingletonNuNo = 0x000000040003,
- LooseCompressed = 0x000000080000,
- LooseCompressedNu = 0x000000080001,
- LooseCompressedNo = 0x000000080002,
- LooseCompressedNuNo = 0x000000080003,
- NOutOfM = 0x000000100000,
-};
-
/// This enum defines all supported storage format without the level properties.
enum class LevelFormat : uint64_t {
+ Undef = 0x00000000,
Dense = 0x00010000,
Compressed = 0x00020000,
Singleton = 0x00040000,
@@ -199,328 +163,236 @@ enum class LevelFormat : uint64_t {
NOutOfM = 0x00100000,
};
+/// Returns string representation of the given level format.
+/// Every enumerator maps to a fixed keyword; note that `NOutOfM` is printed
+/// as "structured". A value that is not a `LevelFormat` enumerator yields the
+/// empty string.
+constexpr const char *toFormatString(LevelFormat lvlFmt) {
+ switch (lvlFmt) {
+ case LevelFormat::Undef:
+ return "undef";
+ case LevelFormat::Dense:
+ return "dense";
+ case LevelFormat::Compressed:
+ return "compressed";
+ case LevelFormat::Singleton:
+ return "singleton";
+ case LevelFormat::LooseCompressed:
+ return "loose_compressed";
+ case LevelFormat::NOutOfM:
+ return "structured";
+ }
+ // Only reachable for a value outside the enumeration.
+ return "";
+}
+
/// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint64_t {
+enum class LevelPropNonDefault : uint64_t {
Nonunique = 0x0001,
Nonordered = 0x0002,
};
-/// Get N of NOutOfM level type.
-constexpr uint64_t getN(LevelType lt) {
- return (static_cast<uint64_t>(lt) >> 32) & 0xff;
+/// Returns string representation of the given level properties.
+constexpr const char *toPropString(LevelPropNonDefault lvlProp) {
+ switch (lvlProp) {
+ case LevelPropNonDefault::Nonunique:
+ return "nonunique";
+ case LevelPropNonDefault::Nonordered:
+ return "nonordered";
+ }
+ return "";
}
-/// Get M of NOutOfM level type.
-constexpr uint64_t getM(LevelType lt) {
- return (static_cast<uint64_t>(lt) >> 40) & 0xff;
-}
+/// This enum defines all the sparse representations supportable by
+/// the SparseTensor dialect. We use a lightweight encoding to encode
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
+/// the format is NOutOfM.
+/// The encoding is chosen for performance of the runtime library, and thus may
+/// change in future versions; consequently, client code should use the
+/// predicate functions defined below, rather than relying on knowledge
+/// about the particular binary encoding.
+///
+/// The `Undef` "format" is a special value used internally for cases
+/// where we need to store an undefined or indeterminate `LevelType`.
+/// It should not be used externally, since it does not indicate an
+/// actual/representable format.
-/// Convert N of NOutOfM level type to the stored bits.
-constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+struct LevelType {
+public:
+ /// Returns true when `lvlBits` encodes a valid (possibly undefined) level
+ /// type. Undef, dense and n-out-of-m formats admit no non-default property
+ /// bits; compressed, singleton and loose_compressed admit any property bits.
+ /// Only the low 32 bits (format + properties) are inspected.
+ static constexpr bool isValidLvlBits(uint64_t lvlBits) {
+ const uint64_t fmt = lvlBits & 0xffff0000;
+ const bool hasProps = (lvlBits & 0xffff) != 0;
+ // Formats that must be unique and ordered.
+ if (fmt == static_cast<uint64_t>(LevelFormat::Undef) ||
+ fmt == static_cast<uint64_t>(LevelFormat::Dense) ||
+ fmt == static_cast<uint64_t>(LevelFormat::NOutOfM))
+ return !hasProps;
+ // Otherwise the format must be one of the remaining known ones.
+ return fmt == static_cast<uint64_t>(LevelFormat::Compressed) ||
+ fmt == static_cast<uint64_t>(LevelFormat::Singleton) ||
+ fmt == static_cast<uint64_t>(LevelFormat::LooseCompressed);
+ }
-/// Convert M of NOutOfM level type to the stored bits.
-constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+ /// Convert a LevelFormat to its corresponding LevelType with the given
+ /// properties. Returns std::nullopt when the properties are not applicable
+ /// for the input level format.
+ static std::optional<LevelType>
+ buildLvlType(LevelFormat lf,
+ const std::vector<LevelPropNonDefault> &properties,
+ uint64_t n = 0, uint64_t m = 0) {
+ uint64_t newN = n << 32;
+ uint64_t newM = m << 40;
+ uint64_t ltBits = static_cast<uint64_t>(lf) | newN | newM;
+ for (auto p : properties)
+ ltBits |= static_cast<uint64_t>(p);
+
+ return isValidLvlBits(ltBits) ? std::optional(LevelType(ltBits))
+ : std::nullopt;
+ }
+ static std::optional<LevelType> buildLvlType(LevelFormat lf, bool ordered,
+ bool unique, uint64_t n = 0,
+ uint64_t m = 0) {
+ std::vector<LevelPropNonDefault> properties;
+ if (!ordered)
+ properties.push_back(LevelPropNonDefault::Nonordered);
+ if (!unique)
+ properties.push_back(LevelPropNonDefault::Nonunique);
+ return buildLvlType(lf, properties, n, m);
+ }
-/// Check if the `LevelType` is NOutOfM (regardless of
-/// properties and block sizes).
-constexpr bool isNOutOfMLT(LevelType lt) {
- return ((static_cast<uint64_t>(lt) & 0x100000) ==
- static_cast<uint64_t>(LevelType::NOutOfM));
-}
+ /// Explicit conversion from the raw uint64_t encoding. In debug builds,
+ /// asserts that `bits` is a valid encoding (see `isValidLvlBits`).
+ constexpr explicit LevelType(uint64_t bits) : lvlBits(bits) {
+ assert(isValidLvlBits(bits));
+ }
-/// Check if the `LevelType` is NOutOfM with the correct block sizes.
-constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
- return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
-}
+ /// Constructs a LevelType with the given format using all default properties.
+ /// NOutOfM is rejected because this constructor leaves the n and m fields
+ /// zero; use `buildLvlType` with explicit n/m for that format.
+ /*implicit*/ LevelType(LevelFormat f) : lvlBits(static_cast<uint64_t>(f)) {
+ assert(isValidLvlBits(lvlBits) && !isa<LevelFormat::NOutOfM>());
+ }
-/// Returns string representation of the given dimension level type.
-constexpr const char *toMLIRString(LevelType lvlType) {
- auto lt = static_cast<LevelType>(static_cast<uint64_t>(lvlType) & 0xffffffff);
- switch (lt) {
- case LevelType::Undef:
- return "undef";
- case LevelType::Dense:
- return "dense";
- case LevelType::Compressed:
- return "compressed";
- case LevelType::CompressedNu:
- return "compressed(nonunique)";
- case LevelType::CompressedNo:
- return "compressed(nonordered)";
- case LevelType::CompressedNuNo:
- return "compressed(nonunique, nonordered)";
- case LevelType::Singleton:
- return "singleton";
- case LevelType::SingletonNu:
- return "singleton(nonunique)";
- case LevelType::SingletonNo:
- return "singleton(nonordered)";
- case LevelType::SingletonNuNo:
- return "singleton(nonunique, nonordered)";
- case LevelType::LooseCompressed:
- return "loose_compressed";
- case LevelType::LooseCompressedNu:
- return "loose_compressed(nonunique)";
- case LevelType::LooseCompressedNo:
- return "loose_compressed(nonordered)";
- case LevelType::LooseCompressedNuNo:
- return "loose_compressed(nonunique, nonordered)";
- case LevelType::NOutOfM:
- return "structured";
- }
- return "";
-}
+ /// Converts to uint64_t
+ explicit operator uint64_t() const { return lvlBits; }
-/// Check that the `LevelType` contains a valid (possibly undefined) value.
-constexpr bool isValidLT(LevelType lt) {
- const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
- const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
- // If undefined/dense/NOutOfM, then must be unique and ordered.
- // Otherwise, the format must be one of the known ones.
- return (formatBits <= 0x10000 || formatBits == 0x100000)
- ? (propertyBits == 0)
- : (formatBits == 0x20000 || formatBits == 0x40000 ||
- formatBits == 0x80000);
-}
+ /// Two level types are equal iff their full encodings (format, properties
+ /// and n/m) are identical.
+ constexpr bool operator==(const LevelType rhs) const {
+ return lvlBits == rhs.lvlBits;
+ }
+ constexpr bool operator!=(const LevelType rhs) const {
+ return !(*this == rhs);
+ }
-/// Check if the `LevelType` is the special undefined value.
-constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
+ /// Returns a copy with all property bits (low 16 bits) cleared; the format
+ /// and the n/m fields are kept. The mask is spelled as a 64-bit value rather
+ /// than relying on sign-extension of the int `~0xffff`.
+ constexpr LevelType stripProperties() const {
+ return LevelType(lvlBits & ~static_cast<uint64_t>(0xffff));
+ }
-/// Check if the `LevelType` is dense (regardless of properties).
-constexpr bool isDenseLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~0xffff) ==
- static_cast<uint64_t>(LevelType::Dense);
-}
+ /// Get N/M of an NOutOfM level type (stored in bits 32..39 and 40..47
+ /// respectively). Must only be called on NOutOfM levels.
+ ///
+ /// The precondition is checked with a direct bit comparison instead of the
+ /// non-constexpr `isa<>()`, so these functions stay usable in constant
+ /// expressions even when assertions are enabled.
+ constexpr uint64_t getN() const {
+ assert((lvlBits & 0xffff0000) ==
+ static_cast<uint64_t>(LevelFormat::NOutOfM));
+ return (lvlBits >> 32) & 0xff;
+ }
+ constexpr uint64_t getM() const {
+ assert((lvlBits & 0xffff0000) ==
+ static_cast<uint64_t>(LevelFormat::NOutOfM));
+ return (lvlBits >> 40) & 0xff;
+ }
-/// Check if the `LevelType` is compressed (regardless of properties).
-constexpr bool isCompressedLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~0xffff) ==
- static_cast<uint64_t>(LevelType::Compressed);
-}
+ /// Get the `LevelFormat` of the `LevelType`.
+ LevelFormat getLvlFmt() const {
+ return static_cast<LevelFormat>(lvlBits & 0xffff0000);
+ }
-/// Check if the `LevelType` is singleton (regardless of properties).
-constexpr bool isSingletonLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~0xffff) ==
- static_cast<uint64_t>(LevelType::Singleton);
-}
+ /// Check if the `LevelType` is in the `LevelFormat`.
+ template <LevelFormat fmt>
+ bool isa() const {
+ return getLvlFmt() == fmt;
+ }
-/// Check if the `LevelType` is loose compressed (regardless of properties).
-constexpr bool isLooseCompressedLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~0xffff) ==
- static_cast<uint64_t>(LevelType::LooseCompressed);
-}
+ /// Check if the `LevelType` has the properties
+ template <LevelPropNonDefault p>
+ bool isa() const {
+ return lvlBits & static_cast<uint64_t>(p);
+ }
-/// Check if the `LevelType` needs positions array.
-constexpr bool isWithPosLT(LevelType lt) {
- return isCompressedLT(lt) || isLooseCompressedLT(lt);
-}
+ /// Check if the `LevelType` needs positions array.
+ bool isWithPosLT() const {
+ return isa<LevelFormat::Compressed>() ||
+ isa<LevelFormat::LooseCompressed>();
+ }
-/// Check if the `LevelType` needs coordinates array.
-constexpr bool isWithCrdLT(LevelType lt) {
- return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- isNOutOfMLT(lt);
-}
+ /// Check if the `LevelType` needs coordinates array: true for every sparse
+ /// format (compressed, singleton, loose_compressed, n-out-of-m), false for
+ /// dense and undefined levels.
+ constexpr bool isWithCrdLT() const {
+ // Compare the format bits directly: `isa<>()` is not constexpr, and an
+ // undefined level must not be reported as storing coordinates.
+ const uint64_t fmtBits = lvlBits & 0xffff0000;
+ return fmtBits != static_cast<uint64_t>(LevelFormat::Dense) &&
+ fmtBits != static_cast<uint64_t>(LevelFormat::Undef);
+ }
-/// Check if the `LevelType` is ordered (regardless of storage format).
-constexpr bool isOrderedLT(LevelType lt) {
- return !(static_cast<uint64_t>(lt) & 2);
- return !(static_cast<uint64_t>(lt) & 2);
-}
+ /// Returns the textual form of this level type: the format keyword,
+ /// followed by the set non-default properties in parentheses, e.g.
+ /// "compressed" or "compressed(nonunique, nonordered)".
+ std::string toMLIRString() const {
+ std::string result = toFormatString(getLvlFmt());
+ const bool nonUnique = isa<LevelPropNonDefault::Nonunique>();
+ const bool nonOrdered = isa<LevelPropNonDefault::Nonordered>();
+ if (!nonUnique && !nonOrdered)
+ return result;
+ result += "(";
+ if (nonUnique)
+ result += toPropString(LevelPropNonDefault::Nonunique);
+ if (nonUnique && nonOrdered)
+ result += ", ";
+ if (nonOrdered)
+ result += toPropString(LevelPropNonDefault::Nonordered);
+ result += ")";
+ return result;
+ }
-/// Check if the `LevelType` is unique (regardless of storage format).
-constexpr bool isUniqueLT(LevelType lt) {
- return !(static_cast<uint64_t>(lt) & 1);
- return !(static_cast<uint64_t>(lt) & 1);
-}
+private:
+ /// Bit manipulations for LevelType (most significant field first):
+ ///
+ /// | 8-bit m | 8-bit n | 16-bit LevelFormat | 16-bit LevelProperty |
+ ///   [47:40]   [39:32]   [31:16]              [15:0]
+ ///
+ /// n and m are only meaningful for the NOutOfM format.
+ uint64_t lvlBits;
+};
-/// Convert a LevelType to its corresponding LevelFormat.
-/// Returns std::nullopt when input lt is Undef.
-constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
- if (lt == LevelType::Undef)
- return std::nullopt;
- return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
-}
+// Free-function API kept for backward compatibility with code written against
+// the former `LevelType` enum.
+// TODO: remove everything below once all callers have migrated to the
+// `LevelType` member functions.
+// Shift N / M of an NOutOfM level type into their bit positions (32 and 40).
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
-/// Convert a LevelFormat to its corresponding LevelType with the given
-/// properties. Returns std::nullopt when the properties are not applicable
-/// for the input level format.
inline std::optional<LevelType>
buildLevelType(LevelFormat lf,
- const std::vector<LevelPropertyNondefault> &properties,
+ const std::vector<LevelPropNonDefault> &properties,
uint64_t n = 0, uint64_t m = 0) {
- uint64_t newN = n << 32;
- uint64_t newM = m << 40;
- uint64_t ltInt = static_cast<uint64_t>(lf) | newN | newM;
- for (auto p : properties) {
- ltInt |= static_cast<uint64_t>(p);
- }
- auto lt = static_cast<LevelType>(ltInt);
- return isValidLT(lt) ? std::optional(lt) : std::nullopt;
+ return LevelType::buildLvlType(lf, properties, n, m);
}
-
inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
bool unique, uint64_t n = 0,
uint64_t m = 0) {
- std::vector<LevelPropertyNondefault> properties;
- if (!ordered)
- properties.push_back(LevelPropertyNondefault::Nonordered);
- if (!unique)
- properties.push_back(LevelPropertyNondefault::Nonunique);
- return buildLevelType(lf, properties, n, m);
+ return LevelType::buildLvlType(lf, ordered, unique, n, m);
+}
+inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); }
+inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); }
+inline bool isCompressedLT(LevelType lt) {
+ return lt.isa<LevelFormat::Compressed>();
+}
+inline bool isLooseCompressedLT(LevelType lt) {
+ return lt.isa<LevelFormat::LooseCompressed>();
+}
+inline bool isSingletonLT(LevelType lt) {
+ return lt.isa<LevelFormat::Singleton>();
+}
+inline bool isNOutOfMLT(LevelType lt) { return lt.isa<LevelFormat::NOutOfM>(); }
+inline bool isOrderedLT(LevelType lt) {
+ return !lt.isa<LevelPropNonDefault::Nonordered>();
}
+inline bool isUniqueLT(LevelType lt) {
+ return !lt.isa<LevelPropNonDefault::Nonunique>();
+}
+inline bool isWithCrdLT(LevelType lt) { return lt.isWithCrdLT(); }
+inline bool isWithPosLT(LevelType lt) { return lt.isWithPosLT(); }
+inline bool isValidLT(LevelType lt) {
+ return LevelType::isValidLvlBits(static_cast<uint64_t>(lt));
+}
+inline std::optional<LevelFormat> getLevelFormat(LevelType lt) {
+ LevelFormat fmt = lt.getLvlFmt();
+ if (fmt == LevelFormat::Undef)
+ return std::nullopt;
+ return fmt;
+}
+inline uint64_t getN(LevelType lt) { return lt.getN(); }
+inline uint64_t getM(LevelType lt) { return lt.getM(); }
+inline bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
+ return isNOutOfMLT(lt) && lt.getN() == n && lt.getM() == m;
+}
+inline std::string toMLIRString(LevelType lt) { return lt.toMLIRString(); }
//
// Ensure the above methods work as intended.
//
-static_assert(
- (getLevelFormat(LevelType::Undef) == std::nullopt &&
- *getLevelFormat(LevelType::Dense) == LevelFormat::Dense &&
- *getLevelFormat(LevelType::Compressed) == LevelFormat::Compressed &&
- *getLevelFormat(LevelType::CompressedNu) == LevelFormat::Compressed &&
- *getLevelFormat(LevelType::CompressedNo) == LevelFormat::Compressed &&
- *getLevelFormat(LevelType::CompressedNuNo) == LevelFormat::Compressed &&
- *getLevelFormat(LevelType::Singleton) == LevelFormat::Singleton &&
- *getLevelFormat(LevelType::SingletonNu) == LevelFormat::Singleton &&
- *getLevelFormat(LevelType::SingletonNo) == LevelFormat::Singleton &&
- *getLevelFormat(LevelType::SingletonNuNo) == LevelFormat::Singleton &&
- *getLevelFormat(LevelType::LooseCompressed) ==
- LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::LooseCompressedNu) ==
- LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::LooseCompressedNo) ==
- LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::LooseCompressedNuNo) ==
- LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
- "getLevelFormat conversion is broken");
-
-static_assert(
- (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
- isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
- isValidLT(LevelType::CompressedNo) &&
- isValidLT(LevelType::CompressedNuNo) && isValidLT(LevelType::Singleton) &&
- isValidLT(LevelType::SingletonNu) && isValidLT(LevelType::SingletonNo) &&
- isValidLT(LevelType::SingletonNuNo) &&
- isValidLT(LevelType::LooseCompressed) &&
- isValidLT(LevelType::LooseCompressedNu) &&
- isValidLT(LevelType::LooseCompressedNo) &&
- isValidLT(LevelType::LooseCompressedNuNo) &&
- isValidLT(LevelType::NOutOfM)),
- "isValidLT definition is broken");
-
-static_assert((isDenseLT(LevelType::Dense) &&
- !isDenseLT(LevelType::Compressed) &&
- !isDenseLT(LevelType::CompressedNu) &&
- !isDenseLT(LevelType::CompressedNo) &&
- !isDenseLT(LevelType::CompressedNuNo) &&
- !isDenseLT(LevelType::Singleton) &&
- !isDenseLT(LevelType::SingletonNu) &&
- !isDenseLT(LevelType::SingletonNo) &&
- !isDenseLT(LevelType::SingletonNuNo) &&
- !isDenseLT(LevelType::LooseCompressed) &&
- !isDenseLT(LevelType::LooseCompressedNu) &&
- !isDenseLT(LevelType::LooseCompressedNo) &&
- !isDenseLT(LevelType::LooseCompressedNuNo) &&
- !isDenseLT(LevelType::NOutOfM)),
- "isDenseLT definition is broken");
-
-static_assert((!isCompressedLT(LevelType::Dense) &&
- isCompressedLT(LevelType::Compressed) &&
- isCompressedLT(LevelType::CompressedNu) &&
- isCompressedLT(LevelType::CompressedNo) &&
- isCompressedLT(LevelType::CompressedNuNo) &&
- !isCompressedLT(LevelType::Singleton) &&
- !isCompressedLT(LevelType::SingletonNu) &&
- !isCompressedLT(LevelType::SingletonNo) &&
- !isCompressedLT(LevelType::SingletonNuNo) &&
- !isCompressedLT(LevelType::LooseCompressed) &&
- !isCompressedLT(LevelType::LooseCompressedNu) &&
- !isCompressedLT(LevelType::LooseCompressedNo) &&
- !isCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isCompressedLT(LevelType::NOutOfM)),
- "isCompressedLT definition is broken");
-
-static_assert((!isSingletonLT(LevelType::Dense) &&
- !isSingletonLT(LevelType::Compressed) &&
- !isSingletonLT(LevelType::CompressedNu) &&
- !isSingletonLT(LevelType::CompressedNo) &&
- !isSingletonLT(LevelType::CompressedNuNo) &&
- isSingletonLT(LevelType::Singleton) &&
- isSingletonLT(LevelType::SingletonNu) &&
- isSingletonLT(LevelType::SingletonNo) &&
- isSingletonLT(LevelType::SingletonNuNo) &&
- !isSingletonLT(LevelType::LooseCompressed) &&
- !isSingletonLT(LevelType::LooseCompressedNu) &&
- !isSingletonLT(LevelType::LooseCompressedNo) &&
- !isSingletonLT(LevelType::LooseCompressedNuNo) &&
- !isSingletonLT(LevelType::NOutOfM)),
- "isSingletonLT definition is broken");
-
-static_assert((!isLooseCompressedLT(LevelType::Dense) &&
- !isLooseCompressedLT(LevelType::Compressed) &&
- !isLooseCompressedLT(LevelType::CompressedNu) &&
- !isLooseCompressedLT(LevelType::CompressedNo) &&
- !isLooseCompressedLT(LevelType::CompressedNuNo) &&
- !isLooseCompressedLT(LevelType::Singleton) &&
- !isLooseCompressedLT(LevelType::SingletonNu) &&
- !isLooseCompressedLT(LevelType::SingletonNo) &&
- !isLooseCompressedLT(LevelType::SingletonNuNo) &&
- isLooseCompressedLT(LevelType::LooseCompressed) &&
- isLooseCompressedLT(LevelType::LooseCompressedNu) &&
- isLooseCompressedLT(LevelType::LooseCompressedNo) &&
- isLooseCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isLooseCompressedLT(LevelType::NOutOfM)),
- "isLooseCompressedLT definition is broken");
-
-static_assert((!isNOutOfMLT(LevelType::Dense) &&
- !isNOutOfMLT(LevelType::Compressed) &&
- !isNOutOfMLT(LevelType::CompressedNu) &&
- !isNOutOfMLT(LevelType::CompressedNo) &&
- !isNOutOfMLT(LevelType::CompressedNuNo) &&
- !isNOutOfMLT(LevelType::Singleton) &&
- !isNOutOfMLT(LevelType::SingletonNu) &&
- !isNOutOfMLT(LevelType::SingletonNo) &&
- !isNOutOfMLT(LevelType::SingletonNuNo) &&
- !isNOutOfMLT(LevelType::LooseCompressed) &&
- !isNOutOfMLT(LevelType::LooseCompressedNu) &&
- !isNOutOfMLT(LevelType::LooseCompressedNo) &&
- !isNOutOfMLT(LevelType::LooseCompressedNuNo) &&
- isNOutOfMLT(LevelType::NOutOfM)),
- "isNOutOfMLT definition is broken");
-
-static_assert((isOrderedLT(LevelType::Dense) &&
- isOrderedLT(LevelType::Compressed) &&
- isOrderedLT(LevelType::CompressedNu) &&
- !isOrderedLT(LevelType::CompressedNo) &&
- !isOrderedLT(LevelType::CompressedNuNo) &&
- isOrderedLT(LevelType::Singleton) &&
- isOrderedLT(LevelType::SingletonNu) &&
- !isOrderedLT(LevelType::SingletonNo) &&
- !isOrderedLT(LevelType::SingletonNuNo) &&
- isOrderedLT(LevelType::LooseCompressed) &&
- isOrderedLT(LevelType::LooseCompressedNu) &&
- !isOrderedLT(LevelType::LooseCompressedNo) &&
- !isOrderedLT(LevelType::LooseCompressedNuNo) &&
- isOrderedLT(LevelType::NOutOfM)),
- "isOrderedLT definition is broken");
-
-static_assert((isUniqueLT(LevelType::Dense) &&
- isUniqueLT(LevelType::Compressed) &&
- !isUniqueLT(LevelType::CompressedNu) &&
- isUniqueLT(LevelType::CompressedNo) &&
- !isUniqueLT(LevelType::CompressedNuNo) &&
- isUniqueLT(LevelType::Singleton) &&
- !isUniqueLT(LevelType::SingletonNu) &&
- isUniqueLT(LevelType::SingletonNo) &&
- !isUniqueLT(LevelType::SingletonNuNo) &&
- isUniqueLT(LevelType::LooseCompressed) &&
- !isUniqueLT(LevelType::LooseCompressedNu) &&
- isUniqueLT(LevelType::LooseCompressedNo) &&
- !isUniqueLT(LevelType::LooseCompressedNuNo) &&
- isUniqueLT(LevelType::NOutOfM)),
- "isUniqueLT definition is broken");
-
/// Bit manipulations for affine encoding.
///
/// Note that because the indices in the mappings refer to dimensions
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 55af8becbba20e..3ae06f220c5281 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -34,9 +34,9 @@ static_assert(
"MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch");
static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) ==
- static_cast<int>(LevelPropertyNondefault::Nonordered) &&
+ static_cast<int>(LevelPropNonDefault::Nonordered) &&
static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) ==
- static_cast<int>(LevelPropertyNondefault::Nonunique),
+ static_cast<int>(LevelPropNonDefault::Nonunique),
"MlirSparseTensorLevelProperty (C-API) and "
- "LevelPropertyNondefault (C++) mismatch");
+ "LevelPropNonDefault (C++) mismatch");
@@ -80,7 +80,7 @@ enum MlirSparseTensorLevelFormat
mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) {
LevelType lt =
static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl));
- return static_cast<MlirSparseTensorLevelFormat>(*getLevelFormat(lt));
+ return static_cast<MlirSparseTensorLevelFormat>(lt.getLvlFmt());
}
int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
@@ -96,9 +96,9 @@ MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType(
const enum MlirSparseTensorLevelPropertyNondefault *properties,
unsigned size, unsigned n, unsigned m) {
- std::vector<LevelPropertyNondefault> props;
+ std::vector<LevelPropNonDefault> props;
for (unsigned i = 0; i < size; i++)
- props.push_back(static_cast<LevelPropertyNondefault>(properties[i]));
+ props.push_back(static_cast<LevelPropNonDefault>(properties[i]));
return static_cast<MlirSparseTensorLevelType>(
*buildLevelType(static_cast<LevelFormat>(lvlFmt), props, n, m));
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 0fb0d2761054b5..380cccc989ec6a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -88,9 +88,9 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
"expected valid level property (e.g. nonordered, nonunique or high)")
if (strVal.compare("nonunique") == 0) {
- *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique);
+ *properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonunique);
} else if (strVal.compare("nonordered") == 0) {
- *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered);
+ *properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonordered);
} else {
parser.emitError(loc, "unknown level property: ") << strVal;
return failure();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index aed43f26d54f11..6d02645d860e96 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -35,6 +35,14 @@
using namespace mlir;
using namespace mlir::sparse_tensor;
+// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
+// well.
+namespace mlir::sparse_tensor {
+llvm::hash_code hash_value(LevelType lt) {
+ return llvm::hash_value(static_cast<uint64_t>(lt));
+}
+} // namespace mlir::sparse_tensor
+
//===----------------------------------------------------------------------===//
// Local Convenience Methods.
//===----------------------------------------------------------------------===//
@@ -83,11 +91,11 @@ void StorageLayout::foreachField(
}
// The values array.
if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
- LevelType::Undef)))
+ LevelFormat::Undef)))
return;
// Put metadata at the end.
if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
- LevelType::Undef)))
+ LevelFormat::Undef)))
return;
}
@@ -341,7 +349,7 @@ Level SparseTensorEncodingAttr::getLvlRank() const {
LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
if (!getImpl())
- return LevelType::Dense;
+ return LevelFormat::Dense;
assert(l < getLvlRank() && "Level is out of bounds");
return getLvlTypes()[l];
}
@@ -975,7 +983,7 @@ static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
SmallVector<LevelType> lts;
for (auto lt : enc.getLvlTypes())
- lts.push_back(*buildLevelType(*getLevelFormat(lt), true, true));
+ lts.push_back(lt.stripProperties());
return SparseTensorEncodingAttr::get(
enc.getContext(), lts,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 235c5453f9cc98..7326a6a3811284 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -46,7 +46,7 @@ static bool isZeroValue(Value val) {
static bool isSparseTensor(Value v) {
auto enc = getSparseTensorEncoding(v.getType());
return enc && !llvm::all_of(enc.getLvlTypes(),
- [](auto lt) { return lt == LevelType::Dense; });
+ [](auto lt) { return lt == LevelFormat::Dense; });
}
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index c85f8204ba7527..61a3703b73bf07 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -63,7 +63,7 @@ class SparseLevel : public SparseTensorLevel {
class DenseLevel : public SparseTensorLevel {
public:
DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
- : SparseTensorLevel(tid, lvl, LevelType::Dense, lvlSize),
+ : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize),
encoded(encoded) {}
Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
@@ -1275,7 +1275,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
: b.create<tensor::DimOp>(l, t, lvl).getResult();
- switch (*getLevelFormat(lt)) {
+ switch (lt.getLvlFmt()) {
case LevelFormat::Dense:
return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
case LevelFormat::Compressed: {
@@ -1296,6 +1296,8 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
Value crd = genToCoordinates(b, l, t, lvl);
return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
}
+ case LevelFormat::Undef:
+ llvm_unreachable("undefined level format");
}
llvm_unreachable("unrecognizable level format");
}
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 96537cbb0c4836..731cd79a1e3b4b 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -226,7 +226,8 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
syntheticTensor(numInputOutputTensors),
numTensors(numInputOutputTensors + 1), numLoops(numLoops),
hasSparseOut(false),
- lvlTypes(numTensors, std::vector<LevelType>(numLoops, LevelType::Undef)),
+ lvlTypes(numTensors,
+ std::vector<LevelType>(numLoops, LevelFormat::Undef)),
loopToLvl(numTensors,
std::vector<std::optional<Level>>(numLoops, std::nullopt)),
lvlToLoop(numTensors,
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index ce9c0e39b31b95..62a19c084cac0f 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -313,11 +313,11 @@ class MergerTest3T1L : public MergerTestBase {
MergerTest3T1L() : MergerTestBase(3, 1) {
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
// Tensor 0: sparse input vector.
- merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
+ merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
// Tensor 1: sparse input vector.
- merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
+ merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
// Tensor 2: dense output vector.
- merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
+ merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
}
};
@@ -327,13 +327,13 @@ class MergerTest4T1L : public MergerTestBase {
MergerTest4T1L() : MergerTestBase(4, 1) {
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
// Tensor 0: sparse input vector.
- merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
+ merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
// Tensor 1: sparse input vector.
- merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
+ merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
// Tensor 2: sparse input vector
- merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
+ merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
// Tensor 3: dense output vector
- merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
+ merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
}
};
@@ -347,11 +347,11 @@ class MergerTest3T1LD : public MergerTestBase {
MergerTest3T1LD() : MergerTestBase(3, 1) {
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
// Tensor 0: sparse input vector.
- merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
+ merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
// Tensor 1: dense input vector.
- merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
+ merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
// Tensor 2: dense output vector.
- merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
+ merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
}
};
@@ -365,13 +365,13 @@ class MergerTest4T1LU : public MergerTestBase {
MergerTest4T1LU() : MergerTestBase(4, 1) {
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
// Tensor 0: undef input vector.
- merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
+ merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
// Tensor 1: dense input vector.
- merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
+ merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
// Tensor 2: undef input vector.
- merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Undef);
+ merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
// Tensor 3: dense output vector.
- merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
+ merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
}
};
@@ -387,11 +387,11 @@ class MergerTest3T1LSo : public MergerTestBase {
EXPECT_TRUE(merger.getSynTensorID() == tid(3));
merger.setHasSparseOut(true);
// Tensor 0: undef input vector.
- merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
+ merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
// Tensor 1: undef input vector.
- merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Undef);
+ merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Undef);
// Tensor 2: sparse output vector.
- merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
+ merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
}
};
>From 0f92b30f92a81331600d2d9cf99d83cde31f3f63 Mon Sep 17 00:00:00 2001
From: Peiming Liu <peiming at google.com>
Date: Thu, 15 Feb 2024 20:07:06 +0000
Subject: [PATCH 2/2] address comments
---
.../mlir/Dialect/SparseTensor/IR/Enums.h | 23 +++++++++++--------
1 file changed, 14 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 079899a147476e..a20a7906189d01 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -163,6 +163,11 @@ enum class LevelFormat : uint64_t {
NOutOfM = 0x00100000,
};
+template <LevelFormat... targets>
+constexpr bool isAnyOfFmt(LevelFormat fmt) {
+ return (... || (targets == fmt));
+}
+
/// Returns string representation of the given level format.
constexpr const char *toFormatString(LevelFormat lvlFmt) {
switch (lvlFmt) {
@@ -218,14 +223,15 @@ struct LevelType {
public:
/// Check that the `LevelType` contains a valid (possibly undefined) value.
static constexpr bool isValidLvlBits(uint64_t lvlBits) {
- const uint64_t formatBits = lvlBits & 0xffff0000;
+ auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000);
const uint64_t propertyBits = lvlBits & 0xffff;
// If undefined/dense/NOutOfM, then must be unique and ordered.
// Otherwise, the format must be one of the known ones.
- return (formatBits <= 0x10000 || formatBits == 0x100000)
+ return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
+ LevelFormat::NOutOfM>(fmt))
? (propertyBits == 0)
- : (formatBits == 0x20000 || formatBits == 0x40000 ||
- formatBits == 0x80000);
+ : (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
+ LevelFormat::LooseCompressed>(fmt));
}
/// Convert a LevelFormat to its corresponding LevelType with the given
@@ -235,6 +241,7 @@ struct LevelType {
buildLvlType(LevelFormat lf,
const std::vector<LevelPropNonDefault> &properties,
uint64_t n = 0, uint64_t m = 0) {
+ assert((n & 0xff) == n && (m & 0xff) == m);
uint64_t newN = n << 32;
uint64_t newM = m << 40;
uint64_t ltBits = static_cast<uint64_t>(lf) | newN | newM;
@@ -275,11 +282,13 @@ struct LevelType {
LevelType stripProperties() const { return LevelType(lvlBits & ~0xffff); }
- /// Get N/M of NOutOfM level type.
+ /// Get N of NOutOfM level type.
constexpr uint64_t getN() const {
assert(isa<LevelFormat::NOutOfM>());
return (lvlBits >> 32) & 0xff;
}
+
+ /// Get M of NOutOfM level type.
constexpr uint64_t getM() const {
assert(isa<LevelFormat::NOutOfM>());
return (lvlBits >> 40) & 0xff;
@@ -389,10 +398,6 @@ inline bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
}
inline std::string toMLIRString(LevelType lt) { return lt.toMLIRString(); }
-//
-// Ensure the above methods work as intended.
-//
-
/// Bit manipulations for affine encoding.
///
/// Note that because the indices in the mappings refer to dimensions
More information about the Mlir-commits
mailing list