[Mlir-commits] [mlir] Revert "[mlir][sparse] remove LevelType enum, construct LevelType from LevelF…" (PR #81923)

Mehdi Amini llvmlistbot at llvm.org
Thu Feb 15 13:26:36 PST 2024


https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/81923

Reverts llvm/llvm-project#81799; this broke the mlir gcc7 bot.

>From c11e879dec122a027ca9ab897fa9c6517cc3f33d Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 15 Feb 2024 13:26:07 -0800
Subject: [PATCH] =?UTF-8?q?Revert=20"[mlir][sparse]=20remove=20LevelType?=
 =?UTF-8?q?=20enum,=20construct=20LevelType=20from=20LevelF=E2=80=A6"?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit 235ec0f791749d94ac1ca1441b8b06d4ba09792c.
---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      | 527 +++++++++++-------
 mlir/lib/CAPI/Dialect/SparseTensor.cpp        |  10 +-
 .../SparseTensor/IR/Detail/LvlTypeParser.cpp  |   4 +-
 .../SparseTensor/IR/SparseTensorDialect.cpp   |  16 +-
 .../Transforms/SparseTensorRewriting.cpp      |   2 +-
 .../Transforms/Utils/SparseTensorLevel.cpp    |   6 +-
 .../lib/Dialect/SparseTensor/Utils/Merger.cpp |   3 +-
 .../Dialect/SparseTensor/MergerTest.cpp       |  34 +-
 8 files changed, 357 insertions(+), 245 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index a20a7906189d01..74cc0dee554a17 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -153,9 +153,45 @@ enum class Action : uint32_t {
   kSortCOOInPlace = 8,
 };
 
+/// This enum defines all the sparse representations supportable by
+/// the SparseTensor dialect. We use a lightweight encoding to encode
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
+/// the format is NOutOfM.
+/// The encoding is chosen for performance of the runtime library, and thus may
+/// change in future versions; consequently, client code should use the
+/// predicate functions defined below, rather than relying on knowledge
+/// about the particular binary encoding.
+///
+/// The `Undef` "format" is a special value used internally for cases
+/// where we need to store an undefined or indeterminate `LevelType`.
+/// It should not be used externally, since it does not indicate an
+/// actual/representable format.
+///
+/// Bit manipulations for LevelType:
+///
+/// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
+///
+enum class LevelType : uint64_t {
+  Undef = 0x000000000000,
+  Dense = 0x000000010000,
+  Compressed = 0x000000020000,
+  CompressedNu = 0x000000020001,
+  CompressedNo = 0x000000020002,
+  CompressedNuNo = 0x000000020003,
+  Singleton = 0x000000040000,
+  SingletonNu = 0x000000040001,
+  SingletonNo = 0x000000040002,
+  SingletonNuNo = 0x000000040003,
+  LooseCompressed = 0x000000080000,
+  LooseCompressedNu = 0x000000080001,
+  LooseCompressedNo = 0x000000080002,
+  LooseCompressedNuNo = 0x000000080003,
+  NOutOfM = 0x000000100000,
+};
+
 /// This enum defines all supported storage format without the level properties.
 enum class LevelFormat : uint64_t {
-  Undef = 0x00000000,
   Dense = 0x00010000,
   Compressed = 0x00020000,
   Singleton = 0x00040000,
@@ -163,240 +199,327 @@ enum class LevelFormat : uint64_t {
   NOutOfM = 0x00100000,
 };
 
-template <LevelFormat... targets>
-constexpr bool isAnyOfFmt(LevelFormat fmt) {
-  return (... || (targets == fmt));
-}
-
-/// Returns string representation of the given level format.
-constexpr const char *toFormatString(LevelFormat lvlFmt) {
-  switch (lvlFmt) {
-  case LevelFormat::Undef:
-    return "undef";
-  case LevelFormat::Dense:
-    return "dense";
-  case LevelFormat::Compressed:
-    return "compressed";
-  case LevelFormat::Singleton:
-    return "singleton";
-  case LevelFormat::LooseCompressed:
-    return "loose_compressed";
-  case LevelFormat::NOutOfM:
-    return "structured";
-  }
-  return "";
-}
-
 /// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropNonDefault : uint64_t {
+enum class LevelPropertyNondefault : uint64_t {
   Nonunique = 0x0001,
   Nonordered = 0x0002,
 };
 
-/// Returns string representation of the given level properties.
-constexpr const char *toPropString(LevelPropNonDefault lvlProp) {
-  switch (lvlProp) {
-  case LevelPropNonDefault::Nonunique:
-    return "nonunique";
-  case LevelPropNonDefault::Nonordered:
-    return "nonordered";
-  }
-  return "";
+/// Get N of NOutOfM level type.
+constexpr uint64_t getN(LevelType lt) {
+  return (static_cast<uint64_t>(lt) >> 32) & 0xff;
 }
 
-/// This enum defines all the sparse representations supportable by
-/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// the "format" per se (dense, compressed, singleton, loose_compressed,
-/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
-/// the format is NOutOfM.
-/// The encoding is chosen for performance of the runtime library, and thus may
-/// change in future versions; consequently, client code should use the
-/// predicate functions defined below, rather than relying on knowledge
-/// about the particular binary encoding.
-///
-/// The `Undef` "format" is a special value used internally for cases
-/// where we need to store an undefined or indeterminate `LevelType`.
-/// It should not be used externally, since it does not indicate an
-/// actual/representable format.
-
-struct LevelType {
-public:
-  /// Check that the `LevelType` contains a valid (possibly undefined) value.
-  static constexpr bool isValidLvlBits(uint64_t lvlBits) {
-    auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000);
-    const uint64_t propertyBits = lvlBits & 0xffff;
-    // If undefined/dense/NOutOfM, then must be unique and ordered.
-    // Otherwise, the format must be one of the known ones.
-    return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense,
-                       LevelFormat::NOutOfM>(fmt))
-               ? (propertyBits == 0)
-               : (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton,
-                             LevelFormat::LooseCompressed>(fmt));
-  }
+/// Get M of NOutOfM level type.
+constexpr uint64_t getM(LevelType lt) {
+  return (static_cast<uint64_t>(lt) >> 40) & 0xff;
+}
 
-  /// Convert a LevelFormat to its corresponding LevelType with the given
-  /// properties. Returns std::nullopt when the properties are not applicable
-  /// for the input level format.
-  static std::optional<LevelType>
-  buildLvlType(LevelFormat lf,
-               const std::vector<LevelPropNonDefault> &properties,
-               uint64_t n = 0, uint64_t m = 0) {
-    assert((n & 0xff) == n && (m & 0xff) == m);
-    uint64_t newN = n << 32;
-    uint64_t newM = m << 40;
-    uint64_t ltBits = static_cast<uint64_t>(lf) | newN | newM;
-    for (auto p : properties)
-      ltBits |= static_cast<uint64_t>(p);
-
-    return isValidLvlBits(ltBits) ? std::optional(LevelType(ltBits))
-                                  : std::nullopt;
-  }
-  static std::optional<LevelType> buildLvlType(LevelFormat lf, bool ordered,
-                                               bool unique, uint64_t n = 0,
-                                               uint64_t m = 0) {
-    std::vector<LevelPropNonDefault> properties;
-    if (!ordered)
-      properties.push_back(LevelPropNonDefault::Nonordered);
-    if (!unique)
-      properties.push_back(LevelPropNonDefault::Nonunique);
-    return buildLvlType(lf, properties, n, m);
-  }
+/// Convert N of NOutOfM level type to the stored bits.
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
 
-  /// Explicit conversion from uint64_t.
-  constexpr explicit LevelType(uint64_t bits) : lvlBits(bits) {
-    assert(isValidLvlBits(bits));
-  };
+/// Convert M of NOutOfM level type to the stored bits.
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
 
-  /// Constructs a LevelType with the given format using all default properties.
-  /*implicit*/ LevelType(LevelFormat f) : lvlBits(static_cast<uint64_t>(f)) {
-    assert(isValidLvlBits(lvlBits) && !isa<LevelFormat::NOutOfM>());
-  };
+/// Check if the `LevelType` is NOutOfM (regardless of
+/// properties and block sizes).
+constexpr bool isNOutOfMLT(LevelType lt) {
+  return ((static_cast<uint64_t>(lt) & 0x100000) ==
+          static_cast<uint64_t>(LevelType::NOutOfM));
+}
 
-  /// Converts to uint64_t
-  explicit operator uint64_t() const { return lvlBits; }
+/// Check if the `LevelType` is NOutOfM with the correct block sizes.
+constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
+  return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
+}
 
-  bool operator==(const LevelType lhs) const {
-    return static_cast<uint64_t>(lhs) == lvlBits;
+/// Returns string representation of the given dimension level type.
+constexpr const char *toMLIRString(LevelType lvlType) {
+  auto lt = static_cast<LevelType>(static_cast<uint64_t>(lvlType) & 0xffffffff);
+  switch (lt) {
+  case LevelType::Undef:
+    return "undef";
+  case LevelType::Dense:
+    return "dense";
+  case LevelType::Compressed:
+    return "compressed";
+  case LevelType::CompressedNu:
+    return "compressed(nonunique)";
+  case LevelType::CompressedNo:
+    return "compressed(nonordered)";
+  case LevelType::CompressedNuNo:
+    return "compressed(nonunique, nonordered)";
+  case LevelType::Singleton:
+    return "singleton";
+  case LevelType::SingletonNu:
+    return "singleton(nonunique)";
+  case LevelType::SingletonNo:
+    return "singleton(nonordered)";
+  case LevelType::SingletonNuNo:
+    return "singleton(nonunique, nonordered)";
+  case LevelType::LooseCompressed:
+    return "loose_compressed";
+  case LevelType::LooseCompressedNu:
+    return "loose_compressed(nonunique)";
+  case LevelType::LooseCompressedNo:
+    return "loose_compressed(nonordered)";
+  case LevelType::LooseCompressedNuNo:
+    return "loose_compressed(nonunique, nonordered)";
+  case LevelType::NOutOfM:
+    return "structured";
   }
-  bool operator!=(const LevelType lhs) const { return !(*this == lhs); }
+  return "";
+}
 
-  LevelType stripProperties() const { return LevelType(lvlBits & ~0xffff); }
+/// Check that the `LevelType` contains a valid (possibly undefined) value.
+constexpr bool isValidLT(LevelType lt) {
+  const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
+  const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
+  // If undefined/dense/NOutOfM, then must be unique and ordered.
+  // Otherwise, the format must be one of the known ones.
+  return (formatBits <= 0x10000 || formatBits == 0x100000)
+             ? (propertyBits == 0)
+             : (formatBits == 0x20000 || formatBits == 0x40000 ||
+                formatBits == 0x80000);
+}
 
-  /// Get N of NOutOfM level type.
-  constexpr uint64_t getN() const {
-    assert(isa<LevelFormat::NOutOfM>());
-    return (lvlBits >> 32) & 0xff;
-  }
+/// Check if the `LevelType` is the special undefined value.
+constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
 
-  /// Get M of NOutOfM level type.
-  constexpr uint64_t getM() const {
-    assert(isa<LevelFormat::NOutOfM>());
-    return (lvlBits >> 40) & 0xff;
-  }
+/// Check if the `LevelType` is dense (regardless of properties).
+constexpr bool isDenseLT(LevelType lt) {
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::Dense);
+}
 
-  /// Get the `LevelFormat` of the `LevelType`.
-  LevelFormat getLvlFmt() const {
-    return static_cast<LevelFormat>(lvlBits & 0xffff0000);
-  }
+/// Check if the `LevelType` is compressed (regardless of properties).
+constexpr bool isCompressedLT(LevelType lt) {
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::Compressed);
+}
 
-  /// Check if the `LevelType` is in the `LevelFormat`.
-  template <LevelFormat fmt>
-  bool isa() const {
-    return getLvlFmt() == fmt;
-  }
+/// Check if the `LevelType` is singleton (regardless of properties).
+constexpr bool isSingletonLT(LevelType lt) {
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::Singleton);
+}
 
-  /// Check if the `LevelType` has the properties
-  template <LevelPropNonDefault p>
-  bool isa() const {
-    return lvlBits & static_cast<uint64_t>(p);
-  }
+/// Check if the `LevelType` is loose compressed (regardless of properties).
+constexpr bool isLooseCompressedLT(LevelType lt) {
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::LooseCompressed);
+}
 
-  /// Check if the `LevelType` needs positions array.
-  bool isWithPosLT() const {
-    return isa<LevelFormat::Compressed>() ||
-           isa<LevelFormat::LooseCompressed>();
-  }
+/// Check if the `LevelType` needs positions array.
+constexpr bool isWithPosLT(LevelType lt) {
+  return isCompressedLT(lt) || isLooseCompressedLT(lt);
+}
 
-  /// Check if the `LevelType` needs coordinates array.
-  constexpr bool isWithCrdLT() const {
-    // All sparse levels has coordinate array.
-    return !isa<LevelFormat::Dense>();
-  }
+/// Check if the `LevelType` needs coordinates array.
+constexpr bool isWithCrdLT(LevelType lt) {
+  return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
+         isNOutOfMLT(lt);
+}
 
-  std::string toMLIRString() const {
-    std::string lvlStr = toFormatString(getLvlFmt());
-    std::string propStr = "";
-    if (isa<LevelPropNonDefault::Nonunique>())
-      propStr += toPropString(LevelPropNonDefault::Nonunique);
-
-    if (isa<LevelPropNonDefault::Nonordered>()) {
-      if (!propStr.empty())
-        propStr += ", ";
-      propStr += toPropString(LevelPropNonDefault::Nonordered);
-    }
-    if (!propStr.empty())
-      lvlStr += ("(" + propStr + ")");
-    return lvlStr;
-  }
+/// Check if the `LevelType` is ordered (regardless of storage format).
+constexpr bool isOrderedLT(LevelType lt) {
+  // Bit 0x2 is LevelPropertyNondefault::Nonordered; ordered iff it is clear.
+  return !(static_cast<uint64_t>(lt) & 2);
+}
 
-private:
-  /// Bit manipulations for LevelType:
-  ///
-  /// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
-  ///
-  uint64_t lvlBits;
-};
+/// Check if the `LevelType` is unique (regardless of storage format).
+constexpr bool isUniqueLT(LevelType lt) {
+  // Bit 0x1 is LevelPropertyNondefault::Nonunique; unique iff it is clear.
+  return !(static_cast<uint64_t>(lt) & 1);
+}
 
-// For backward-compatibility. TODO: remove below after fully migration.
-constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
-constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+/// Convert a LevelType to its corresponding LevelFormat.
+/// Returns std::nullopt when input lt is Undef.
+constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
+  if (lt == LevelType::Undef)
+    return std::nullopt;
+  return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
+}
 
+/// Convert a LevelFormat to its corresponding LevelType with the given
+/// properties. Returns std::nullopt when the properties are not applicable
+/// for the input level format.
 inline std::optional<LevelType>
 buildLevelType(LevelFormat lf,
-               const std::vector<LevelPropNonDefault> &properties,
+               const std::vector<LevelPropertyNondefault> &properties,
                uint64_t n = 0, uint64_t m = 0) {
-  return LevelType::buildLvlType(lf, properties, n, m);
+  uint64_t newN = n << 32;
+  uint64_t newM = m << 40;
+  uint64_t ltInt = static_cast<uint64_t>(lf) | newN | newM;
+  for (auto p : properties) {
+    ltInt |= static_cast<uint64_t>(p);
+  }
+  auto lt = static_cast<LevelType>(ltInt);
+  return isValidLT(lt) ? std::optional(lt) : std::nullopt;
 }
+
 inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
                                                bool unique, uint64_t n = 0,
                                                uint64_t m = 0) {
-  return LevelType::buildLvlType(lf, ordered, unique, n, m);
+  std::vector<LevelPropertyNondefault> properties;
+  if (!ordered)
+    properties.push_back(LevelPropertyNondefault::Nonordered);
+  if (!unique)
+    properties.push_back(LevelPropertyNondefault::Nonunique);
+  return buildLevelType(lf, properties, n, m);
 }
-inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); }
-inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); }
-inline bool isCompressedLT(LevelType lt) {
-  return lt.isa<LevelFormat::Compressed>();
-}
-inline bool isLooseCompressedLT(LevelType lt) {
-  return lt.isa<LevelFormat::LooseCompressed>();
-}
-inline bool isSingletonLT(LevelType lt) {
-  return lt.isa<LevelFormat::Singleton>();
-}
-inline bool isNOutOfMLT(LevelType lt) { return lt.isa<LevelFormat::NOutOfM>(); }
-inline bool isOrderedLT(LevelType lt) {
-  return !lt.isa<LevelPropNonDefault::Nonordered>();
-}
-inline bool isUniqueLT(LevelType lt) {
-  return !lt.isa<LevelPropNonDefault::Nonunique>();
-}
-inline bool isWithCrdLT(LevelType lt) { return lt.isWithCrdLT(); }
-inline bool isWithPosLT(LevelType lt) { return lt.isWithPosLT(); }
-inline bool isValidLT(LevelType lt) {
-  return LevelType::isValidLvlBits(static_cast<uint64_t>(lt));
-}
-inline std::optional<LevelFormat> getLevelFormat(LevelType lt) {
-  LevelFormat fmt = lt.getLvlFmt();
-  if (fmt == LevelFormat::Undef)
-    return std::nullopt;
-  return fmt;
-}
-inline uint64_t getN(LevelType lt) { return lt.getN(); }
-inline uint64_t getM(LevelType lt) { return lt.getM(); }
-inline bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
-  return isNOutOfMLT(lt) && lt.getN() == n && lt.getM() == m;
-}
-inline std::string toMLIRString(LevelType lt) { return lt.toMLIRString(); }
+
+//
+// Ensure the above methods work as intended.
+//
+
+static_assert(
+    (getLevelFormat(LevelType::Undef) == std::nullopt &&
+     *getLevelFormat(LevelType::Dense) == LevelFormat::Dense &&
+     *getLevelFormat(LevelType::Compressed) == LevelFormat::Compressed &&
+     *getLevelFormat(LevelType::CompressedNu) == LevelFormat::Compressed &&
+     *getLevelFormat(LevelType::CompressedNo) == LevelFormat::Compressed &&
+     *getLevelFormat(LevelType::CompressedNuNo) == LevelFormat::Compressed &&
+     *getLevelFormat(LevelType::Singleton) == LevelFormat::Singleton &&
+     *getLevelFormat(LevelType::SingletonNu) == LevelFormat::Singleton &&
+     *getLevelFormat(LevelType::SingletonNo) == LevelFormat::Singleton &&
+     *getLevelFormat(LevelType::SingletonNuNo) == LevelFormat::Singleton &&
+     *getLevelFormat(LevelType::LooseCompressed) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(LevelType::LooseCompressedNu) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(LevelType::LooseCompressedNo) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(LevelType::LooseCompressedNuNo) ==
+         LevelFormat::LooseCompressed &&
+     *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
+    "getLevelFormat conversion is broken");
+
+static_assert(
+    (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
+     isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
+     isValidLT(LevelType::CompressedNo) &&
+     isValidLT(LevelType::CompressedNuNo) && isValidLT(LevelType::Singleton) &&
+     isValidLT(LevelType::SingletonNu) && isValidLT(LevelType::SingletonNo) &&
+     isValidLT(LevelType::SingletonNuNo) &&
+     isValidLT(LevelType::LooseCompressed) &&
+     isValidLT(LevelType::LooseCompressedNu) &&
+     isValidLT(LevelType::LooseCompressedNo) &&
+     isValidLT(LevelType::LooseCompressedNuNo) &&
+     isValidLT(LevelType::NOutOfM)),
+    "isValidLT definition is broken");
+
+static_assert((isDenseLT(LevelType::Dense) &&
+               !isDenseLT(LevelType::Compressed) &&
+               !isDenseLT(LevelType::CompressedNu) &&
+               !isDenseLT(LevelType::CompressedNo) &&
+               !isDenseLT(LevelType::CompressedNuNo) &&
+               !isDenseLT(LevelType::Singleton) &&
+               !isDenseLT(LevelType::SingletonNu) &&
+               !isDenseLT(LevelType::SingletonNo) &&
+               !isDenseLT(LevelType::SingletonNuNo) &&
+               !isDenseLT(LevelType::LooseCompressed) &&
+               !isDenseLT(LevelType::LooseCompressedNu) &&
+               !isDenseLT(LevelType::LooseCompressedNo) &&
+               !isDenseLT(LevelType::LooseCompressedNuNo) &&
+               !isDenseLT(LevelType::NOutOfM)),
+              "isDenseLT definition is broken");
+
+static_assert((!isCompressedLT(LevelType::Dense) &&
+               isCompressedLT(LevelType::Compressed) &&
+               isCompressedLT(LevelType::CompressedNu) &&
+               isCompressedLT(LevelType::CompressedNo) &&
+               isCompressedLT(LevelType::CompressedNuNo) &&
+               !isCompressedLT(LevelType::Singleton) &&
+               !isCompressedLT(LevelType::SingletonNu) &&
+               !isCompressedLT(LevelType::SingletonNo) &&
+               !isCompressedLT(LevelType::SingletonNuNo) &&
+               !isCompressedLT(LevelType::LooseCompressed) &&
+               !isCompressedLT(LevelType::LooseCompressedNu) &&
+               !isCompressedLT(LevelType::LooseCompressedNo) &&
+               !isCompressedLT(LevelType::LooseCompressedNuNo) &&
+               !isCompressedLT(LevelType::NOutOfM)),
+              "isCompressedLT definition is broken");
+
+static_assert((!isSingletonLT(LevelType::Dense) &&
+               !isSingletonLT(LevelType::Compressed) &&
+               !isSingletonLT(LevelType::CompressedNu) &&
+               !isSingletonLT(LevelType::CompressedNo) &&
+               !isSingletonLT(LevelType::CompressedNuNo) &&
+               isSingletonLT(LevelType::Singleton) &&
+               isSingletonLT(LevelType::SingletonNu) &&
+               isSingletonLT(LevelType::SingletonNo) &&
+               isSingletonLT(LevelType::SingletonNuNo) &&
+               !isSingletonLT(LevelType::LooseCompressed) &&
+               !isSingletonLT(LevelType::LooseCompressedNu) &&
+               !isSingletonLT(LevelType::LooseCompressedNo) &&
+               !isSingletonLT(LevelType::LooseCompressedNuNo) &&
+               !isSingletonLT(LevelType::NOutOfM)),
+              "isSingletonLT definition is broken");
+
+static_assert((!isLooseCompressedLT(LevelType::Dense) &&
+               !isLooseCompressedLT(LevelType::Compressed) &&
+               !isLooseCompressedLT(LevelType::CompressedNu) &&
+               !isLooseCompressedLT(LevelType::CompressedNo) &&
+               !isLooseCompressedLT(LevelType::CompressedNuNo) &&
+               !isLooseCompressedLT(LevelType::Singleton) &&
+               !isLooseCompressedLT(LevelType::SingletonNu) &&
+               !isLooseCompressedLT(LevelType::SingletonNo) &&
+               !isLooseCompressedLT(LevelType::SingletonNuNo) &&
+               isLooseCompressedLT(LevelType::LooseCompressed) &&
+               isLooseCompressedLT(LevelType::LooseCompressedNu) &&
+               isLooseCompressedLT(LevelType::LooseCompressedNo) &&
+               isLooseCompressedLT(LevelType::LooseCompressedNuNo) &&
+               !isLooseCompressedLT(LevelType::NOutOfM)),
+              "isLooseCompressedLT definition is broken");
+
+static_assert((!isNOutOfMLT(LevelType::Dense) &&
+               !isNOutOfMLT(LevelType::Compressed) &&
+               !isNOutOfMLT(LevelType::CompressedNu) &&
+               !isNOutOfMLT(LevelType::CompressedNo) &&
+               !isNOutOfMLT(LevelType::CompressedNuNo) &&
+               !isNOutOfMLT(LevelType::Singleton) &&
+               !isNOutOfMLT(LevelType::SingletonNu) &&
+               !isNOutOfMLT(LevelType::SingletonNo) &&
+               !isNOutOfMLT(LevelType::SingletonNuNo) &&
+               !isNOutOfMLT(LevelType::LooseCompressed) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNu) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNo) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNuNo) &&
+               isNOutOfMLT(LevelType::NOutOfM)),
+              "isNOutOfMLT definition is broken");
+
+static_assert((isOrderedLT(LevelType::Dense) &&
+               isOrderedLT(LevelType::Compressed) &&
+               isOrderedLT(LevelType::CompressedNu) &&
+               !isOrderedLT(LevelType::CompressedNo) &&
+               !isOrderedLT(LevelType::CompressedNuNo) &&
+               isOrderedLT(LevelType::Singleton) &&
+               isOrderedLT(LevelType::SingletonNu) &&
+               !isOrderedLT(LevelType::SingletonNo) &&
+               !isOrderedLT(LevelType::SingletonNuNo) &&
+               isOrderedLT(LevelType::LooseCompressed) &&
+               isOrderedLT(LevelType::LooseCompressedNu) &&
+               !isOrderedLT(LevelType::LooseCompressedNo) &&
+               !isOrderedLT(LevelType::LooseCompressedNuNo) &&
+               isOrderedLT(LevelType::NOutOfM)),
+              "isOrderedLT definition is broken");
+
+static_assert((isUniqueLT(LevelType::Dense) &&
+               isUniqueLT(LevelType::Compressed) &&
+               !isUniqueLT(LevelType::CompressedNu) &&
+               isUniqueLT(LevelType::CompressedNo) &&
+               !isUniqueLT(LevelType::CompressedNuNo) &&
+               isUniqueLT(LevelType::Singleton) &&
+               !isUniqueLT(LevelType::SingletonNu) &&
+               isUniqueLT(LevelType::SingletonNo) &&
+               !isUniqueLT(LevelType::SingletonNuNo) &&
+               isUniqueLT(LevelType::LooseCompressed) &&
+               !isUniqueLT(LevelType::LooseCompressedNu) &&
+               isUniqueLT(LevelType::LooseCompressedNo) &&
+               !isUniqueLT(LevelType::LooseCompressedNuNo) &&
+               isUniqueLT(LevelType::NOutOfM)),
+              "isUniqueLT definition is broken");
 
 /// Bit manipulations for affine encoding.
 ///
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index 3ae06f220c5281..55af8becbba20e 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -34,9 +34,9 @@ static_assert(
     "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch");
 
 static_assert(static_cast<int>(MLIR_SPARSE_PROPERTY_NON_ORDERED) ==
-                      static_cast<int>(LevelPropNonDefault::Nonordered) &&
+                      static_cast<int>(LevelPropertyNondefault::Nonordered) &&
                   static_cast<int>(MLIR_SPARSE_PROPERTY_NON_UNIQUE) ==
-                      static_cast<int>(LevelPropNonDefault::Nonunique),
+                      static_cast<int>(LevelPropertyNondefault::Nonunique),
               "MlirSparseTensorLevelProperty (C-API) and "
               "LevelPropertyNondefault (C++) mismatch");
 
@@ -80,7 +80,7 @@ enum MlirSparseTensorLevelFormat
 mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) {
   LevelType lt =
       static_cast<LevelType>(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl));
-  return static_cast<MlirSparseTensorLevelFormat>(lt.getLvlFmt());
+  return static_cast<MlirSparseTensorLevelFormat>(*getLevelFormat(lt));
 }
 
 int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) {
@@ -96,9 +96,9 @@ MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType(
     const enum MlirSparseTensorLevelPropertyNondefault *properties,
     unsigned size, unsigned n, unsigned m) {
 
-  std::vector<LevelPropNonDefault> props;
+  std::vector<LevelPropertyNondefault> props;
   for (unsigned i = 0; i < size; i++)
-    props.push_back(static_cast<LevelPropNonDefault>(properties[i]));
+    props.push_back(static_cast<LevelPropertyNondefault>(properties[i]));
 
   return static_cast<MlirSparseTensorLevelType>(
       *buildLevelType(static_cast<LevelFormat>(lvlFmt), props, n, m));
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 380cccc989ec6a..0fb0d2761054b5 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -88,9 +88,9 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
   ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
            "expected valid level property (e.g. nonordered, nonunique or high)")
   if (strVal.compare("nonunique") == 0) {
-    *properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonunique);
+    *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique);
   } else if (strVal.compare("nonordered") == 0) {
-    *properties |= static_cast<uint64_t>(LevelPropNonDefault::Nonordered);
+    *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered);
   } else {
     parser.emitError(loc, "unknown level property: ") << strVal;
     return failure();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6d02645d860e96..aed43f26d54f11 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -35,14 +35,6 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
-// Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as
-// well.
-namespace mlir::sparse_tensor {
-llvm::hash_code hash_value(LevelType lt) {
-  return llvm::hash_value(static_cast<uint64_t>(lt));
-}
-} // namespace mlir::sparse_tensor
-
 //===----------------------------------------------------------------------===//
 // Local Convenience Methods.
 //===----------------------------------------------------------------------===//
@@ -91,11 +83,11 @@ void StorageLayout::foreachField(
   }
   // The values array.
   if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
-                 LevelFormat::Undef)))
+                 LevelType::Undef)))
     return;
   // Put metadata at the end.
   if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
-                 LevelFormat::Undef)))
+                 LevelType::Undef)))
     return;
 }
 
@@ -349,7 +341,7 @@ Level SparseTensorEncodingAttr::getLvlRank() const {
 
 LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
   if (!getImpl())
-    return LevelFormat::Dense;
+    return LevelType::Dense;
   assert(l < getLvlRank() && "Level is out of bounds");
   return getLvlTypes()[l];
 }
@@ -983,7 +975,7 @@ static SparseTensorEncodingAttr
 getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
   SmallVector<LevelType> lts;
   for (auto lt : enc.getLvlTypes())
-    lts.push_back(lt.stripProperties());
+    lts.push_back(*buildLevelType(*getLevelFormat(lt), true, true));
 
   return SparseTensorEncodingAttr::get(
       enc.getContext(), lts,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 7326a6a3811284..235c5453f9cc98 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -46,7 +46,7 @@ static bool isZeroValue(Value val) {
 static bool isSparseTensor(Value v) {
   auto enc = getSparseTensorEncoding(v.getType());
   return enc && !llvm::all_of(enc.getLvlTypes(),
-                              [](auto lt) { return lt == LevelFormat::Dense; });
+                              [](auto lt) { return lt == LevelType::Dense; });
 }
 static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 61a3703b73bf07..c85f8204ba7527 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -63,7 +63,7 @@ class SparseLevel : public SparseTensorLevel {
 class DenseLevel : public SparseTensorLevel {
 public:
   DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
-      : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize),
+      : SparseTensorLevel(tid, lvl, LevelType::Dense, lvlSize),
         encoded(encoded) {}
 
   Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
@@ -1275,7 +1275,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
   Value sz = stt.hasEncoding() ? b.create<LvlOp>(l, t, lvl).getResult()
                                : b.create<tensor::DimOp>(l, t, lvl).getResult();
 
-  switch (lt.getLvlFmt()) {
+  switch (*getLevelFormat(lt)) {
   case LevelFormat::Dense:
     return std::make_unique<DenseLevel>(tid, lvl, sz, stt.hasEncoding());
   case LevelFormat::Compressed: {
@@ -1296,8 +1296,6 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
     Value crd = genToCoordinates(b, l, t, lvl);
     return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
   }
-  case LevelFormat::Undef:
-    llvm_unreachable("undefined level format");
   }
   llvm_unreachable("unrecognizable level format");
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 731cd79a1e3b4b..96537cbb0c4836 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -226,8 +226,7 @@ Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
       syntheticTensor(numInputOutputTensors),
       numTensors(numInputOutputTensors + 1), numLoops(numLoops),
       hasSparseOut(false),
-      lvlTypes(numTensors,
-               std::vector<LevelType>(numLoops, LevelFormat::Undef)),
+      lvlTypes(numTensors, std::vector<LevelType>(numLoops, LevelType::Undef)),
       loopToLvl(numTensors,
                 std::vector<std::optional<Level>>(numLoops, std::nullopt)),
       lvlToLoop(numTensors,
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
index 62a19c084cac0f..ce9c0e39b31b95 100644
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -313,11 +313,11 @@ class MergerTest3T1L : public MergerTestBase {
   MergerTest3T1L() : MergerTestBase(3, 1) {
     EXPECT_TRUE(merger.getOutTensorID() == tid(2));
     // Tensor 0: sparse input vector.
-    merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
+    merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
     // Tensor 1: sparse input vector.
-    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
+    merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
     // Tensor 2: dense output vector.
-    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
   }
 };
 
@@ -327,13 +327,13 @@ class MergerTest4T1L : public MergerTestBase {
   MergerTest4T1L() : MergerTestBase(4, 1) {
     EXPECT_TRUE(merger.getOutTensorID() == tid(3));
     // Tensor 0: sparse input vector.
-    merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
+    merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
     // Tensor 1: sparse input vector.
-    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
+    merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Compressed);
     // Tensor 2: sparse input vector
-    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
+    merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
     // Tensor 3: dense output vector
-    merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
   }
 };
 
@@ -347,11 +347,11 @@ class MergerTest3T1LD : public MergerTestBase {
   MergerTest3T1LD() : MergerTestBase(3, 1) {
     EXPECT_TRUE(merger.getOutTensorID() == tid(2));
     // Tensor 0: sparse input vector.
-    merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
+    merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Compressed);
     // Tensor 1: dense input vector.
-    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
     // Tensor 2: dense output vector.
-    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Dense);
   }
 };
 
@@ -365,13 +365,13 @@ class MergerTest4T1LU : public MergerTestBase {
   MergerTest4T1LU() : MergerTestBase(4, 1) {
     EXPECT_TRUE(merger.getOutTensorID() == tid(3));
     // Tensor 0: undef input vector.
-    merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
+    merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
     // Tensor 1: dense input vector.
-    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Dense);
     // Tensor 2: undef input vector.
-    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
+    merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Undef);
     // Tensor 3: dense output vector.
-    merger.setLevelAndType(tid(3), lid(0), 0, LevelFormat::Dense);
+    merger.setLevelAndType(tid(3), lid(0), 0, LevelType::Dense);
   }
 };
 
@@ -387,11 +387,11 @@ class MergerTest3T1LSo : public MergerTestBase {
     EXPECT_TRUE(merger.getSynTensorID() == tid(3));
     merger.setHasSparseOut(true);
     // Tensor 0: undef input vector.
-    merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
+    merger.setLevelAndType(tid(0), lid(0), 0, LevelType::Undef);
     // Tensor 1: undef input vector.
-    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Undef);
+    merger.setLevelAndType(tid(1), lid(0), 0, LevelType::Undef);
     // Tensor 2: sparse output vector.
-    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
+    merger.setLevelAndType(tid(2), lid(0), 0, LevelType::Compressed);
   }
 };
 



More information about the Mlir-commits mailing list