[Mlir-commits] [mlir] Reapply "[mlir][sparse] remove LevelType enum, construct LevelType from LevelFormat and Properties" (#81923) (PR #81934)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 15 14:27:53 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

<details>
<summary>Changes</summary>



---

Patch is 36.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81934.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (+202-325) 
- (modified) mlir/lib/CAPI/Dialect/SparseTensor.cpp (+5-5) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp (+2-2) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+12-4) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+1-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp (+4-2) 
- (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-1) 
- (modified) mlir/unittests/Dialect/SparseTensor/MergerTest.cpp (+17-17) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 74cc0dee554a17..c7db5beb2015a6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -153,45 +153,9 @@ enum class Action : uint32_t {
   kSortCOOInPlace = 8,
 };
 
-/// This enum defines all the sparse representations supportable by
-/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// the "format" per se (dense, compressed, singleton, loose_compressed,
-/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
-/// the format is NOutOfM.
-/// The encoding is chosen for performance of the runtime library, and thus may
-/// change in future versions; consequently, client code should use the
-/// predicate functions defined below, rather than relying on knowledge
-/// about the particular binary encoding.
-///
-/// The `Undef` "format" is a special value used internally for cases
-/// where we need to store an undefined or indeterminate `LevelType`.
-/// It should not be used externally, since it does not indicate an
-/// actual/representable format.
-///
-/// Bit manipulations for LevelType:
-///
-/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
-///
-enum class LevelType : uint64_t {
-  Undef = 0x000000000000,
-  Dense = 0x000000010000,
-  Compressed = 0x000000020000,
-  CompressedNu = 0x000000020001,
-  CompressedNo = 0x000000020002,
-  CompressedNuNo = 0x000000020003,
-  Singleton = 0x000000040000,
-  SingletonNu = 0x000000040001,
-  SingletonNo = 0x000000040002,
-  SingletonNuNo = 0x000000040003,
-  LooseCompressed = 0x000000080000,
-  LooseCompressedNu = 0x000000080001,
-  LooseCompressedNo = 0x000000080002,
-  LooseCompressedNuNo = 0x000000080003,
-  NOutOfM = 0x000000100000,
-};
-
 /// This enum defines all supported storage format without the level properties.
 enum class LevelFormat : uint64_t {
+  Undef = 0x00000000,
   Dense = 0x00010000,
   Compressed = 0x00020000,
   Singleton = 0x00040000,
@@ -199,327 +163,240 @@ enum class LevelFormat : uint64_t {
   NOutOfM = 0x00100000,
 };
 
+template <LevelFormat... targets>
+constexpr bool isAnyOfFmt(LevelFormat fmt) {
+  return (... || (targets == fmt));
+}
+
+/// Returns string representation of the given level format.
+constexpr const char *toFormatString(LevelFormat lvlFmt) {
+  switch (lvlFmt) {
+  case LevelFormat::Undef:
+    return "undef";
+  case LevelFormat::Dense:
+    return "dense";
+  case LevelFormat::Compressed:
+    return "compressed";
+  case LevelFormat::Singleton:
+    return "singleton";
+  case LevelFormat::LooseCompressed:
+    return "loose_compressed";
+  case LevelFormat::NOutOfM:
+    return "structured";
+  }
+  return "";
+}
+
 /// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint64_t {
+enum class LevelPropNonDefault : uint64_t {
   Nonunique = 0x0001,
   Nonordered = 0x0002,
 };
 
-/// Get N of NOutOfM level type.
-constexpr uint64_t getN(LevelType lt) {
-  return (static_cast<uint64_t>(lt) >> 32) & 0xff;
+/// Returns string representation of the given level properties.
+constexpr const char *toPropString(LevelPropNonDefault lvlProp) {
+  switch (lvlProp) {
+  case LevelPropNonDefault::Nonunique:
+    return "nonunique";
+  case LevelPropNonDefault::Nonordered:
+    return "nonordered";
+  }
+  return "";
 }
 
-/// Get M of NOutOfM level type.
-constexpr uint64_t getM(LevelType lt) {
-  return (static_cast<uint64_t>(lt) >> 40) & 0xff;
-}
+/// This enum defines all the sparse representations supportable by
+/// the SparseTensor dialect. We use a lightweight encoding to encode
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
+/// the format is NOutOfM.
+/// The encoding is chosen for performance of the runtime library, and thus may
+/// change in future versions; consequently, client code should use the
+/// predicate functions defined below, rather than relying on knowledge
+/// about the particular binary encoding.
+///
+/// The `Undef` "format" is a special value used internally for cases
+/// where we need to store an undefined or indeterminate `LevelType`.
+/// It should not be used externally, since it does not indicate an
+/// actual/representable format.
 
-/// Convert N of NOutOfM level type to the stored bits.
-constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+struct LevelType {
+public:
+  /// Check that the given bits encode a valid (possibly undefined) `LevelType`.
+  static constexpr bool isValidLvlBits(uint64_t lvlBits) {
+    auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000);
+    const uint64_t propertyBits = lvlBits & 0xffff;
+    // If undefined/dense/NOutOfM, then must be unique and ordered.
+    // Otherwise, the format must be one of the known ones.
+    return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
+                       LevelFormat::NOutOfM>(fmt))
+               ? (propertyBits == 0)
+               : (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
+                             LevelFormat::LooseCompressed>(fmt));
+  }
 
-/// Convert M of NOutOfM level type to the stored bits.
-constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+  /// Convert a LevelFormat to its corresponding LevelType with the given
+  /// properties. Returns std::nullopt when the properties are not applicable
+  /// for the input level format.
+  static std::optional<LevelType>
+  buildLvlType(LevelFormat lf,
+               const std::vector<LevelPropNonDefault> &properties,
+               uint64_t n = 0, uint64_t m = 0) {
+    assert((n & 0xff) == n && (m & 0xff) == m);
+    uint64_t newN = n << 32;
+    uint64_t newM = m << 40;
+    uint64_t ltBits = static_cast<uint64_t>(lf) | newN | newM;
+    for (auto p : properties)
+      ltBits |= static_cast<uint64_t>(p);
+
+    return isValidLvlBits(ltBits) ? std::optional(LevelType(ltBits))
+                                  : std::nullopt;
+  }
+  static std::optional<LevelType> buildLvlType(LevelFormat lf, bool ordered,
+                                               bool unique, uint64_t n = 0,
+                                               uint64_t m = 0) {
+    std::vector<LevelPropNonDefault> properties;
+    if (!ordered)
+      properties.push_back(LevelPropNonDefault::Nonordered);
+    if (!unique)
+      properties.push_back(LevelPropNonDefault::Nonunique);
+    return buildLvlType(lf, properties, n, m);
+  }
 
-/// Check if the `LevelType` is NOutOfM (regardless of
-/// properties and block sizes).
-constexpr bool isNOutOfMLT(LevelType lt) {
-  return ((static_cast<uint64_t>(lt) & 0x100000) ==
-          static_cast<uint64_t>(LevelType::NOutOfM));
-}
+  /// Explicit conversion from uint64_t.
+  constexpr explicit LevelType(uint64_t bits) : lvlBits(bits) {
+    assert(isValidLvlBits(bits));
+  };
 
-/// Check if the `LevelType` is NOutOfM with the correct block sizes.
-constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
-  return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
-}
+  /// Constructs a LevelType with the given format using all default properties.
+  /*implicit*/ LevelType(LevelFormat f) : lvlBits(static_cast<uint64_t>(f)) {
+    assert(isValidLvlBits(lvlBits) && !isa<LevelFormat::NOutOfM>());
+  };
 
-/// Returns string representation of the given dimension level type.
-constexpr const char *toMLIRString(LevelType lvlType) {
-  auto lt = static_cast<LevelType>(static_cast<uint64_t>(lvlType) & 0xffffffff);
-  switch (lt) {
-  case LevelType::Undef:
-    return "undef";
-  case LevelType::Dense:
-    return "dense";
-  case LevelType::Compressed:
-    return "compressed";
-  case LevelType::CompressedNu:
-    return "compressed(nonunique)";
-  case LevelType::CompressedNo:
-    return "compressed(nonordered)";
-  case LevelType::CompressedNuNo:
-    return "compressed(nonunique, nonordered)";
-  case LevelType::Singleton:
-    return "singleton";
-  case LevelType::SingletonNu:
-    return "singleton(nonunique)";
-  case LevelType::SingletonNo:
-    return "singleton(nonordered)";
-  case LevelType::SingletonNuNo:
-    return "singleton(nonunique, nonordered)";
-  case LevelType::LooseCompressed:
-    return "loose_compressed";
-  case LevelType::LooseCompressedNu:
-    return "loose_compressed(nonunique)";
-  case LevelType::LooseCompressedNo:
-    return "loose_compressed(nonordered)";
-  case LevelType::LooseCompressedNuNo:
-    return "loose_compressed(nonunique, nonordered)";
-  case LevelType::NOutOfM:
-    return "structured";
+  /// Converts to uint64_t
+  explicit operator uint64_t() const { return lvlBits; }
+
+  bool operator==(const LevelType lhs) const {
+    return static_cast<uint64_t>(lhs) == lvlBits;
   }
-  return "";
-}
+  bool operator!=(const LevelType lhs) const { return !(*this == lhs); }
 
-/// Check that the `LevelType` contains a valid (possibly undefined) value.
-constexpr bool isValidLT(LevelType lt) {
-  const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
-  const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
-  // If undefined/dense/NOutOfM, then must be unique and ordered.
-  // Otherwise, the format must be one of the known ones.
-  return (formatBits <= 0x10000 || formatBits == 0x100000)
-             ? (propertyBits == 0)
-             : (formatBits == 0x20000 || formatBits == 0x40000 ||
-                formatBits == 0x80000);
-}
+  LevelType stripProperties() const { return LevelType(lvlBits & ~0xffff); }
 
-/// Check if the `LevelType` is the special undefined value.
-constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
+  /// Get N of NOutOfM level type.
+  constexpr uint64_t getN() const {
+    assert(isa<LevelFormat::NOutOfM>());
+    return (lvlBits >> 32) & 0xff;
+  }
 
-/// Check if the `LevelType` is dense (regardless of properties).
-constexpr bool isDenseLT(LevelType lt) {
-  return (static_cast<uint64_t>(lt) & ~0xffff) ==
-         static_cast<uint64_t>(LevelType::Dense);
-}
+  /// Get M of NOutOfM level type.
+  constexpr uint64_t getM() const {
+    assert(isa<LevelFormat::NOutOfM>());
+    return (lvlBits >> 40) & 0xff;
+  }
 
-/// Check if the `LevelType` is compressed (regardless of properties).
-constexpr bool isCompressedLT(LevelType lt) {
-  return (static_cast<uint64_t>(lt) & ~0xffff) ==
-         static_cast<uint64_t>(LevelType::Compressed);
-}
+  /// Get the `LevelFormat` of the `LevelType`.
+  LevelFormat getLvlFmt() const {
+    return static_cast<LevelFormat>(lvlBits & 0xffff0000);
+  }
 
-/// Check if the `LevelType` is singleton (regardless of properties).
-constexpr bool isSingletonLT(LevelType lt) {
-  return (static_cast<uint64_t>(lt) & ~0xffff) ==
-         static_cast<uint64_t>(LevelType::Singleton);
-}
+  /// Check if the `LevelType` has the given `LevelFormat`.
+  template <LevelFormat fmt>
+  constexpr bool isa() const {
+    return getLvlFmt() == fmt;
+  }
 
-/// Check if the `LevelType` is loose compressed (regardless of properties).
-constexpr bool isLooseCompressedLT(LevelType lt) {
-  return (static_cast<uint64_t>(lt) & ~0xffff) ==
-         static_cast<uint64_t>(LevelType::LooseCompressed);
-}
+  /// Check if the `LevelType` has the given non-default property.
+  template <LevelPropNonDefault p>
+  constexpr bool isa() const {
+    return lvlBits & static_cast<uint64_t>(p);
+  }
 
-/// Check if the `LevelType` needs positions array.
-constexpr bool isWithPosLT(LevelType lt) {
-  return isCompressedLT(lt) || isLooseCompressedLT(lt);
-}
+  /// Check if the `LevelType` needs positions array.
+  constexpr bool isWithPosLT() const {
+    return isa<LevelFormat::Compressed>() ||
+           isa<LevelFormat::LooseCompressed>();
+  }
 
-/// Check if the `LevelType` needs coordinates array.
-constexpr bool isWithCrdLT(LevelType lt) {
-  return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-         isNOutOfMLT(lt);
-}
+  /// Check if the `LevelType` needs coordinates array.
+  constexpr bool isWithCrdLT() const {
+    // Every non-dense level has a coordinates array.
+    return !isa<LevelFormat::Dense>();
+  }
 
-/// Check if the `LevelType` is ordered (regardless of storage format).
-constexpr bool isOrderedLT(LevelType lt) {
-  return !(static_cast<uint64_t>(lt) & 2);
-  return !(static_cast<uint64_t>(lt) & 2);
-}
+  std::string toMLIRString() const {
+    std::string lvlStr = toFormatString(getLvlFmt());
+    std::string propStr = "";
+    if (isa<LevelPropNonDefault::Nonunique>())
+      propStr += toPropString(LevelPropNonDefault::Nonunique);
+
+    if (isa<LevelPropNonDefault::Nonordered>()) {
+      if (!propStr.empty())
+        propStr += ", ";
+      propStr += toPropString(LevelPropNonDefault::Nonordered);
+    }
+    if (!propStr.empty())
+      lvlStr += ("(" + propStr + ")");
+    return lvlStr;
+  }
 
-/// Check if the `LevelType` is unique (regardless of storage format).
-constexpr bool isUniqueLT(LevelType lt) {
-  return !(static_cast<uint64_t>(lt) & 1);
-  return !(static_cast<uint64_t>(lt) & 1);
-}
+private:
+  /// Bit manipulations for LevelType:
+  ///
+  /// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
+  ///
+  uint64_t lvlBits;
+};
 
-/// Convert a LevelType to its corresponding LevelFormat.
-/// Returns std::nullopt when input lt is Undef.
-constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
-  if (lt == LevelType::Undef)
-    return std::nullopt;
-  return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
-}
+// For backward compatibility. TODO: remove the functions below once migration is complete.
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
 
-/// Convert a LevelFormat to its corresponding LevelType with the given
-/// properties. Returns std::nullopt when the properties are not applicable
-/// for the input level format.
 inline std::optional<LevelType>
 buildLevelType(LevelFormat lf,
-               const std::vector<LevelPropertyNondefault> &properties,
+               const std::vector<LevelPropNonDefault> &properties,
                uint64_t n = 0, uint64_t m = 0) {
-  uint64_t newN = n << 32;
-  uint64_t newM = m << 40;
-  uint64_t ltInt = static_cast<uint64_t>(lf) | newN | newM;
-  for (auto p : properties) {
-    ltInt |= static_cast<uint64_t>(p);
-  }
-  auto lt = static_cast<LevelType>(ltInt);
-  return isValidLT(lt) ? std::optional(lt) : std::nullopt;
+  return LevelType::buildLvlType(lf, properties, n, m);
 }
-
 inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
                                                bool unique, uint64_t n = 0,
                                                uint64_t m = 0) {
-  std::vector<LevelPropertyNondefault> properties;
-  if (!ordered)
-    properties.push_back(LevelPropertyNondefault::Nonordered);
-  if (!unique)
-    properties.push_back(LevelPropertyNondefault::Nonunique);
-  return buildLevelType(lf, properties, n, m);
+  return LevelType::buildLvlType(lf, ordered, unique, n, m);
 }
-
-//
-// Ensure the above methods work as intended.
-//
-
-static_assert(
-    (getLevelFormat(LevelType::Undef) == std::nullopt &&
-     *getLevelFormat(LevelType::Dense) == LevelFormat::Dense &&
-     *getLevelFormat(LevelType::Compressed) == LevelFormat::Compressed &&
-     *getLevelFormat(LevelType::CompressedNu) == LevelFormat::Compressed &&
-     *getLevelFormat(LevelType::CompressedNo) == LevelFormat::Compressed &&
-     *getLevelFormat(LevelType::CompressedNuNo) == LevelFormat::Compressed &&
-     *getLevelFormat(LevelType::Singleton) == LevelFormat::Singleton &&
-     *getLevelFormat(LevelType::SingletonNu) == LevelFormat::Singleton &&
-     *getLevelFormat(LevelType::SingletonNo) == LevelFormat::Singleton &&
-     *getLevelFormat(LevelType::SingletonNuNo) == LevelFormat::Singleton &&
-     *getLevelFormat(LevelType::LooseCompressed) ==
-         LevelFormat::LooseCompressed &&
-     *getLevelFormat(LevelType::LooseCompressedNu) ==
-         LevelFormat::LooseCompressed &&
-     *getLevelFormat(LevelType::LooseCompressedNo) ==
-         LevelFormat::LooseCompressed &&
-     *getLevelFormat(LevelType::LooseCompressedNuNo) ==
-         LevelFormat::LooseCompressed &&
-     *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
-    "getLevelFormat conversion is broken");
-
-static_assert(
-    (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
-     isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
-     isValidLT(LevelType::CompressedNo) &&
-     isValidLT(LevelType::CompressedNuNo) && isValidLT(LevelType::Singleton) &&
-     isValidLT(LevelType::SingletonNu) && isValidLT(LevelType::SingletonNo) &&
-     isValidLT(LevelType::SingletonNuNo) &&
-     isValidLT(LevelType::LooseCompressed) &&
-     isValidLT(LevelType::LooseCompressedNu) &&
-     isValidLT(LevelType::LooseCompressedNo) &&
-     isValidLT(LevelType::LooseCompressedNuNo) &&
-     isValidLT(LevelType::NOutOfM)),
-    "isValidLT definition is broken");
-
-static_assert((isDenseLT(LevelType::Dense) &&
-               !isDenseLT(LevelType::Compressed) &&
-               !isDenseLT(LevelType::CompressedNu) &&
-               !isDenseLT(LevelType::CompressedNo) &&
-               !isDenseLT(LevelType::CompressedNuNo) &&
-               !isDenseLT(LevelType::Singleton) &&
-               !isDenseLT(LevelType::SingletonNu) &&
-               !isDenseLT(LevelType::SingletonNo) &&
-               !isDenseLT(LevelType::SingletonNuNo) &&
-               !isDenseLT(LevelType::LooseCompressed) &&
-               !isDenseLT(LevelType::LooseCompressedNu) &&
-               !isDenseLT(LevelType::LooseCompressedNo) &&
-               !isDenseLT(LevelType::LooseCompressedNuNo) &&
-               !isDenseLT(LevelType::NOutOfM)),
-              "isDenseLT definition is broken");
-
-static_assert((!isCompressedLT(LevelType::Dense) &&
-               isCompressedLT(LevelType::Compressed) &&
-               isCompressedLT(LevelType::CompressedNu) &&
-               isCompressedLT(LevelType::CompressedNo) &&
-               isCompressedLT(LevelType::CompressedNuNo) &&
-               !isCompressedLT(LevelType::Singleton) &&
-               !isCompressedLT(LevelType::SingletonNu) &&
-               !isCompressedLT(LevelType::SingletonNo) &&
-               !isCompressedLT(LevelType::SingletonNuNo) &&
-               !isCompressedLT(LevelType::LooseCompressed) &&
-               !isCompressedLT(LevelType::LooseCompressedNu) &&
-               !isCompressedLT(LevelType::LooseCompressedNo) &&
-               !isCompressedLT(LevelType::LooseCompressedNuNo) &&
-               !isCompressedLT(LevelType::NOutOfM)),
-              "isCompressedLT definition is broken");
-
-static_assert((!isSingletonLT(LevelType::Dense) &&
-               !isSingletonLT(LevelType::Compressed) &&
-               !isSingletonLT(LevelType::CompressedNu) &&
-               !isSingletonLT(LevelType::CompressedNo) &&
-               !isSingletonLT(LevelType::CompressedNuNo) &&
-               isSingletonLT(LevelType::Singleton) &&
-               isSingletonLT(LevelType::SingletonNu) &&
-               isSingletonLT(LevelType::SingletonNo) &&
-               isSingletonLT(LevelType::SingletonNuNo) &&
-               !isSingletonLT(LevelType::LooseCompressed) &&
-               !isSingletonLT(LevelType::LooseCompressedNu) &&
-               !isSingletonLT(LevelType::LooseCompressedNo) &&
-               !isSingletonLT(LevelType::LooseCompressedNuNo) &&
-               !isSingletonLT(LevelType::NOutOfM)),
-              "isSingletonLT definition is broken");
-
-static_assert((!isLooseCompressedLT(LevelType::Dense) &&
-               !isLooseCompressedLT(LevelType::Compressed) &&
-               !isLooseCompressedLT(LevelType::CompressedNu) &&
-               !isLooseCompressedLT(LevelType::CompressedNo) &&
-               !isLooseCompressedLT(LevelType::CompressedNuNo) &&
-               !isLooseCompressedLT(LevelType::Singleton) &&
-               !isLooseCompressedLT(LevelType::SingletonNu) &&
-               !isLooseCompressedLT(LevelType::SingletonNo) &&
-               !isLooseCompressedLT(LevelType::SingletonNuNo) &&
-               isLooseCompressedLT(LevelType::LooseCompressed) &&
-           ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/81934


More information about the Mlir-commits mailing list