[Mlir-commits] [mlir] Revert "[mlir][sparse] remove LevelType enum, construct LevelType from LevelF…" (PR #81923)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 15 13:27:04 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Mehdi Amini (joker-eph)
<details>
<summary>Changes</summary>
Reverts llvm/llvm-project#<!-- -->81799; this broke the mlir gcc7 bot.
---
Patch is 36.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81923.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (+325-202)
- (modified) mlir/lib/CAPI/Dialect/SparseTensor.cpp (+5-5)
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp (+2-2)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+4-12)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp (+2-4)
- (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+1-2)
- (modified) mlir/unittests/Dialect/SparseTensor/MergerTest.cpp (+17-17)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index a20a7906189d01..74cc0dee554a17 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -153,9 +153,45 @@ enum class Action : uint32_t {
kSortCOOInPlace = 8,
};
+/// This enum defines all the sparse representations supportable by
+/// the SparseTensor dialect. We use a lightweight encoding to encode
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
+/// the format is NOutOfM.
+/// The encoding is chosen for performance of the runtime library, and thus may
+/// change in future versions; consequently, client code should use the
+/// predicate functions defined below, rather than relying on knowledge
+/// about the particular binary encoding.
+///
+/// The `Undef` "format" is a special value used internally for cases
+/// where we need to store an undefined or indeterminate `LevelType`.
+/// It should not be used externally, since it does not indicate an
+/// actual/representable format.
+///
+/// Bit manipulations for LevelType:
+///
+/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
+///
+enum class LevelType : uint64_t {
+ Undef = 0x000000000000,
+ Dense = 0x000000010000,
+ Compressed = 0x000000020000,
+ CompressedNu = 0x000000020001,
+ CompressedNo = 0x000000020002,
+ CompressedNuNo = 0x000000020003,
+ Singleton = 0x000000040000,
+ SingletonNu = 0x000000040001,
+ SingletonNo = 0x000000040002,
+ SingletonNuNo = 0x000000040003,
+ LooseCompressed = 0x000000080000,
+ LooseCompressedNu = 0x000000080001,
+ LooseCompressedNo = 0x000000080002,
+ LooseCompressedNuNo = 0x000000080003,
+ NOutOfM = 0x000000100000,
+};
+
/// This enum defines all supported storage format without the level properties.
enum class LevelFormat : uint64_t {
- Undef = 0x00000000,
Dense = 0x00010000,
Compressed = 0x00020000,
Singleton = 0x00040000,
@@ -163,240 +199,327 @@ enum class LevelFormat : uint64_t {
NOutOfM = 0x00100000,
};
-template <LevelFormat... targets>
-constexpr bool isAnyOfFmt(LevelFormat fmt) {
- return (... || (targets == fmt));
-}
-
-/// Returns string representation of the given level format.
-constexpr const char *toFormatString(LevelFormat lvlFmt) {
- switch (lvlFmt) {
- case LevelFormat::Undef:
- return "undef";
- case LevelFormat::Dense:
- return "dense";
- case LevelFormat::Compressed:
- return "compressed";
- case LevelFormat::Singleton:
- return "singleton";
- case LevelFormat::LooseCompressed:
- return "loose_compressed";
- case LevelFormat::NOutOfM:
- return "structured";
- }
- return "";
-}
-
/// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropNonDefault : uint64_t {
+enum class LevelPropertyNondefault : uint64_t {
Nonunique = 0x0001,
Nonordered = 0x0002,
};
-/// Returns string representation of the given level properties.
-constexpr const char *toPropString(LevelPropNonDefault lvlProp) {
- switch (lvlProp) {
- case LevelPropNonDefault::Nonunique:
- return "nonunique";
- case LevelPropNonDefault::Nonordered:
- return "nonordered";
- }
- return "";
+/// Get N of NOutOfM level type.
+constexpr uint64_t getN(LevelType lt) {
+ return (static_cast<uint64_t>(lt) >> 32) & 0xff;
}
-/// This enum defines all the sparse representations supportable by
-/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// the "format" per se (dense, compressed, singleton, loose_compressed,
-/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
-/// the format is NOutOfM.
-/// The encoding is chosen for performance of the runtime library, and thus may
-/// change in future versions; consequently, client code should use the
-/// predicate functions defined below, rather than relying on knowledge
-/// about the particular binary encoding.
-///
-/// The `Undef` "format" is a special value used internally for cases
-/// where we need to store an undefined or indeterminate `LevelType`.
-/// It should not be used externally, since it does not indicate an
-/// actual/representable format.
-
-struct LevelType {
-public:
- /// Check that the `LevelType` contains a valid (possibly undefined) value.
- static constexpr bool isValidLvlBits(uint64_t lvlBits) {
- auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000);
- const uint64_t propertyBits = lvlBits & 0xffff;
- // If undefined/dense/NOutOfM, then must be unique and ordered.
- // Otherwise, the format must be one of the known ones.
- return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
- LevelFormat::NOutOfM>(fmt))
- ? (propertyBits == 0)
- : (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
- LevelFormat::LooseCompressed>(fmt));
- }
+/// Get M of NOutOfM level type.
+constexpr uint64_t getM(LevelType lt) {
+ return (static_cast<uint64_t>(lt) >> 40) & 0xff;
+}
- /// Convert a LevelFormat to its corresponding LevelType with the given
- /// properties. Returns std::nullopt when the properties are not applicable
- /// for the input level format.
- static std::optional<LevelType>
- buildLvlType(LevelFormat lf,
- const std::vector<LevelPropNonDefault> &properties,
- uint64_t n = 0, uint64_t m = 0) {
- assert((n & 0xff) == n && (m & 0xff) == m);
- uint64_t newN = n << 32;
- uint64_t newM = m << 40;
- uint64_t ltBits = static_cast<uint64_t>(lf) | newN | newM;
- for (auto p : properties)
- ltBits |= static_cast<uint64_t>(p);
-
- return isValidLvlBits(ltBits) ? std::optional(LevelType(ltBits))
- : std::nullopt;
- }
- static std::optional<LevelType> buildLvlType(LevelFormat lf, bool ordered,
- bool unique, uint64_t n = 0,
- uint64_t m = 0) {
- std::vector<LevelPropNonDefault> properties;
- if (!ordered)
- properties.push_back(LevelPropNonDefault::Nonordered);
- if (!unique)
- properties.push_back(LevelPropNonDefault::Nonunique);
- return buildLvlType(lf, properties, n, m);
- }
+/// Convert N of NOutOfM level type to the stored bits.
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
- /// Explicit conversion from uint64_t.
- constexpr explicit LevelType(uint64_t bits) : lvlBits(bits) {
- assert(isValidLvlBits(bits));
- };
+/// Convert M of NOutOfM level type to the stored bits.
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
- /// Constructs a LevelType with the given format using all default properties.
- /*implicit*/ LevelType(LevelFormat f) : lvlBits(static_cast<uint64_t>(f)) {
- assert(isValidLvlBits(lvlBits) && !isa<LevelFormat::NOutOfM>());
- };
+/// Check if the `LevelType` is NOutOfM (regardless of
+/// properties and block sizes).
+constexpr bool isNOutOfMLT(LevelType lt) {
+ return ((static_cast<uint64_t>(lt) & 0x100000) ==
+ static_cast<uint64_t>(LevelType::NOutOfM));
+}
- /// Converts to uint64_t
- explicit operator uint64_t() const { return lvlBits; }
+/// Check if the `LevelType` is NOutOfM with the correct block sizes.
+constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
+ return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
+}
- bool operator==(const LevelType lhs) const {
- return static_cast<uint64_t>(lhs) == lvlBits;
+/// Returns string representation of the given dimension level type.
+constexpr const char *toMLIRString(LevelType lvlType) {
+ auto lt = static_cast<LevelType>(static_cast<uint64_t>(lvlType) & 0xffffffff);
+ switch (lt) {
+ case LevelType::Undef:
+ return "undef";
+ case LevelType::Dense:
+ return "dense";
+ case LevelType::Compressed:
+ return "compressed";
+ case LevelType::CompressedNu:
+ return "compressed(nonunique)";
+ case LevelType::CompressedNo:
+ return "compressed(nonordered)";
+ case LevelType::CompressedNuNo:
+ return "compressed(nonunique, nonordered)";
+ case LevelType::Singleton:
+ return "singleton";
+ case LevelType::SingletonNu:
+ return "singleton(nonunique)";
+ case LevelType::SingletonNo:
+ return "singleton(nonordered)";
+ case LevelType::SingletonNuNo:
+ return "singleton(nonunique, nonordered)";
+ case LevelType::LooseCompressed:
+ return "loose_compressed";
+ case LevelType::LooseCompressedNu:
+ return "loose_compressed(nonunique)";
+ case LevelType::LooseCompressedNo:
+ return "loose_compressed(nonordered)";
+ case LevelType::LooseCompressedNuNo:
+ return "loose_compressed(nonunique, nonordered)";
+ case LevelType::NOutOfM:
+ return "structured";
}
- bool operator!=(const LevelType lhs) const { return !(*this == lhs); }
+ return "";
+}
- LevelType stripProperties() const { return LevelType(lvlBits & ~0xffff); }
+/// Check that the `LevelType` contains a valid (possibly undefined) value.
+constexpr bool isValidLT(LevelType lt) {
+ const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
+ const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
+ // If undefined/dense/NOutOfM, then must be unique and ordered.
+ // Otherwise, the format must be one of the known ones.
+ return (formatBits <= 0x10000 || formatBits == 0x100000)
+ ? (propertyBits == 0)
+ : (formatBits == 0x20000 || formatBits == 0x40000 ||
+ formatBits == 0x80000);
+}
- /// Get N of NOutOfM level type.
- constexpr uint64_t getN() const {
- assert(isa<LevelFormat::NOutOfM>());
- return (lvlBits >> 32) & 0xff;
- }
+/// Check if the `LevelType` is the special undefined value.
+constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
- /// Get M of NOutOfM level type.
- constexpr uint64_t getM() const {
- assert(isa<LevelFormat::NOutOfM>());
- return (lvlBits >> 40) & 0xff;
- }
+/// Check if the `LevelType` is dense (regardless of properties).
+constexpr bool isDenseLT(LevelType lt) {
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::Dense);
+}
- /// Get the `LevelFormat` of the `LevelType`.
- LevelFormat getLvlFmt() const {
- return static_cast<LevelFormat>(lvlBits & 0xffff0000);
- }
+/// Check if the `LevelType` is compressed (regardless of properties).
+constexpr bool isCompressedLT(LevelType lt) {
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::Compressed);
+}
- /// Check if the `LevelType` is in the `LevelFormat`.
- template <LevelFormat fmt>
- bool isa() const {
- return getLvlFmt() == fmt;
- }
+/// Check if the `LevelType` is singleton (regardless of properties).
+constexpr bool isSingletonLT(LevelType lt) {
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::Singleton);
+}
- /// Check if the `LevelType` has the properties
- template <LevelPropNonDefault p>
- bool isa() const {
- return lvlBits & static_cast<uint64_t>(p);
- }
+/// Check if the `LevelType` is loose compressed (regardless of properties).
+constexpr bool isLooseCompressedLT(LevelType lt) {
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::LooseCompressed);
+}
- /// Check if the `LevelType` needs positions array.
- bool isWithPosLT() const {
- return isa<LevelFormat::Compressed>() ||
- isa<LevelFormat::LooseCompressed>();
- }
+/// Check if the `LevelType` needs positions array.
+constexpr bool isWithPosLT(LevelType lt) {
+ return isCompressedLT(lt) || isLooseCompressedLT(lt);
+}
- /// Check if the `LevelType` needs coordinates array.
- constexpr bool isWithCrdLT() const {
- // All sparse levels has coordinate array.
- return !isa<LevelFormat::Dense>();
- }
+/// Check if the `LevelType` needs coordinates array.
+constexpr bool isWithCrdLT(LevelType lt) {
+ return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
+ isNOutOfMLT(lt);
+}
- std::string toMLIRString() const {
- std::string lvlStr = toFormatString(getLvlFmt());
- std::string propStr = "";
- if (isa<LevelPropNonDefault::Nonunique>())
- propStr += toPropString(LevelPropNonDefault::Nonunique);
-
- if (isa<LevelPropNonDefault::Nonordered>()) {
- if (!propStr.empty())
- propStr += ", ";
- propStr += toPropString(LevelPropNonDefault::Nonordered);
- }
- if (!propStr.empty())
- lvlStr += ("(" + propStr + ")");
- return lvlStr;
- }
+/// Check if the `LevelType` is ordered (regardless of storage format).
+constexpr bool isOrderedLT(LevelType lt) {
+ return !(static_cast<uint64_t>(lt) & 2);
+}
-private:
- /// Bit manipulations for LevelType:
- ///
- /// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
- ///
- uint64_t lvlBits;
-};
+/// Check if the `LevelType` is unique (regardless of storage format).
+constexpr bool isUniqueLT(LevelType lt) {
+ return !(static_cast<uint64_t>(lt) & 1);
+}
-// For backward-compatibility. TODO: remove below after fully migration.
-constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
-constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+/// Convert a LevelType to its corresponding LevelFormat.
+/// Returns std::nullopt when input lt is Undef.
+constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
+ if (lt == LevelType::Undef)
+ return std::nullopt;
+ return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
+}
+/// Convert a LevelFormat to its corresponding LevelType with the given
+/// properties. Returns std::nullopt when the properties are not applicable
+/// for the input level format.
inline std::optional<LevelType>
buildLevelType(LevelFormat lf,
- const std::vector<LevelPropNonDefault> &properties,
+ const std::vector<LevelPropertyNondefault> &properties,
uint64_t n = 0, uint64_t m = 0) {
- return LevelType::buildLvlType(lf, properties, n, m);
+ uint64_t newN = n << 32;
+ uint64_t newM = m << 40;
+ uint64_t ltInt = static_cast<uint64_t>(lf) | newN | newM;
+ for (auto p : properties) {
+ ltInt |= static_cast<uint64_t>(p);
+ }
+ auto lt = static_cast<LevelType>(ltInt);
+ return isValidLT(lt) ? std::optional(lt) : std::nullopt;
}
+
inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
bool unique, uint64_t n = 0,
uint64_t m = 0) {
- return LevelType::buildLvlType(lf, ordered, unique, n, m);
+ std::vector<LevelPropertyNondefault> properties;
+ if (!ordered)
+ properties.push_back(LevelPropertyNondefault::Nonordered);
+ if (!unique)
+ properties.push_back(LevelPropertyNondefault::Nonunique);
+ return buildLevelType(lf, properties, n, m);
}
-inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); }
-inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); }
-inline bool isCompressedLT(LevelType lt) {
- return lt.isa<LevelFormat::Compressed>();
-}
-inline bool isLooseCompressedLT(LevelType lt) {
- return lt.isa<LevelFormat::LooseCompressed>();
-}
-inline bool isSingletonLT(LevelType lt) {
- return lt.isa<LevelFormat::Singleton>();
-}
-inline bool isNOutOfMLT(LevelType lt) { return lt.isa<LevelFormat::NOutOfM>(); }
-inline bool isOrderedLT(LevelType lt) {
- return !lt.isa<LevelPropNonDefault::Nonordered>();
-}
-inline bool isUniqueLT(LevelType lt) {
- return !lt.isa<LevelPropNonDefault::Nonunique>();
-}
-inline bool isWithCrdLT(LevelType lt) { return lt.isWithCrdLT(); }
-inline bool isWithPosLT(LevelType lt) { return lt.isWithPosLT(); }
-inline bool isValidLT(LevelType lt) {
- return LevelType::isValidLvlBits(static_cast<uint64_t>(lt));
-}
-inline std::optional<LevelFormat> getLevelFormat(LevelType lt) {
- LevelFormat fmt = lt.getLvlFmt();
- if (fmt == LevelFormat::Undef)
- return std::nullopt;
- return fmt;
-}
-inline uint64_t getN(LevelType lt) { return lt.getN(); }
-inline uint64_t getM(LevelType lt) { return lt.getM(); }
-inline bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
- return isNOutOfMLT(lt) && lt.getN() == n && lt.getM() == m;
-}
-inline std::string toMLIRString(LevelType lt) { return lt.toMLIRString(); }
+
+//
+// Ensure the above methods work as intended.
+//
+
+static_assert(
+ (getLevelFormat(LevelType::Undef) == std::nullopt &&
+ *getLevelFormat(LevelType::Dense) == LevelFormat::Dense &&
+ *getLevelFormat(LevelType::Compressed) == LevelFormat::Compressed &&
+ *getLevelFormat(LevelType::CompressedNu) == LevelFormat::Compressed &&
+ *getLevelFormat(LevelType::CompressedNo) == LevelFormat::Compressed &&
+ *getLevelFormat(LevelType::CompressedNuNo) == LevelFormat::Compressed &&
+ *getLevelFormat(LevelType::Singleton) == LevelFormat::Singleton &&
+ *getLevelFormat(LevelType::SingletonNu) == LevelFormat::Singleton &&
+ *getLevelFormat(LevelType::SingletonNo) == LevelFormat::Singleton &&
+ *getLevelFormat(LevelType::SingletonNuNo) == LevelFormat::Singleton &&
+ *getLevelFormat(LevelType::LooseCompressed) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(LevelType::LooseCompressedNu) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(LevelType::LooseCompressedNo) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(LevelType::LooseCompressedNuNo) ==
+ LevelFormat::LooseCompressed &&
+ *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
+ "getLevelFormat conversion is broken");
+
+static_assert(
+ (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
+ isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
+ isValidLT(LevelType::CompressedNo) &&
+ isValidLT(LevelType::CompressedNuNo) && isValidLT(LevelType::Singleton) &&
+ isValidLT(LevelType::SingletonNu) && isValidLT(LevelType::SingletonNo) &&
+ isValidLT(LevelType::SingletonNuNo) &&
+ isValidLT(LevelType::LooseCompressed) &&
+ isValidLT(LevelType::LooseCompressedNu) &&
+ isValidLT(LevelType::LooseCompressedNo) &&
+ isValidLT(LevelType::LooseCompressedNuNo) &&
+ isValidLT(LevelType::NOutOfM)),
+ "isValidLT definition is broken");
+
+static_assert((isDenseLT(LevelType::Dense) &&
+ !isDenseLT(LevelType::Compressed) &&
+ !isDenseLT(LevelType::CompressedNu) &&
+ !isDenseLT(LevelType::CompressedNo) &&
+ !isDenseLT(LevelType::CompressedNuNo) &&
+ !isDenseLT(LevelType::Singleton) &&
+ !isDenseLT(LevelType::SingletonNu) &&
+ !isDenseLT(LevelType::SingletonNo) &&
+ !isDenseLT(LevelType::SingletonNuNo) &&
+ !isDenseLT(LevelType::LooseCompressed) &&
+ !isDenseLT(LevelType::LooseCompressedNu) &&
+ !isDenseLT(LevelType::LooseCompressedNo) &&
+ !isDenseLT(LevelType::LooseCompressedNuNo) &&
+ !isDenseLT(LevelType::NOutOfM)),
+ "isDenseLT definition is broken");
+
+static_assert((!isCompressedLT(LevelType::Dense) &&
+ isCompressedLT(LevelType::Compressed) &&
+ isCompressedLT(LevelType::CompressedNu) &&
+ isCompressedLT(LevelType::CompressedNo) &&
+ isCompressedLT(LevelType::CompressedNuNo) &&
+ !isCompressedLT(LevelType::Singleton) &&
+ !isCompressedLT(LevelType::SingletonNu) &&
+ !isCompressedLT(LevelType::SingletonNo) &&
+ !isCompressedLT(LevelType::SingletonNuNo) &&
+ !isCompressedLT(LevelType::LooseCompressed) &&
+ !isCompressedLT(LevelType::LooseCompressedNu) &&
+ !isCompressedLT(LevelType::LooseCompressedNo) &&
+ !isCompressedLT(LevelType::LooseCompressedNuNo) &&
+ !isCompressedLT(LevelType::NOutOfM)),
+ "isCompressedLT definition is broken");
+
+static_assert((!isSingletonLT(LevelType::Dense) &&
+ !isSingletonLT(LevelType::Co...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/81923
More information about the Mlir-commits mailing list