[Mlir-commits] [mlir] [mlir][sparse] Expand LevelType to 64 bits and implement parsing n out of m (PR #79935)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 30 12:14:28 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Yinying Li (yinying-lisa-li)

<details>
<summary>Changes</summary>



---

Patch is 57.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79935.diff


28 Files Affected:

- (modified) mlir/include/mlir-c/Dialect/SparseTensor.h (+14-14) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (+135-85) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+2-2) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h (+1-1) 
- (modified) mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (+1-1) 
- (modified) mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h (+9-9) 
- (modified) mlir/lib/Bindings/Python/DialectSparseTensor.cpp (+1-1) 
- (modified) mlir/lib/CAPI/Dialect/SparseTensor.cpp (+30-19) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp (+43-11) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h (+4-2) 
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+14-2) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp (+1-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+3-3) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+1-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h (+1-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp (+1-1) 
- (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-2) 
- (modified) mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp (+1-1) 
- (modified) mlir/test/CAPI/sparse_tensor.c (+3-3) 
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir (+1-1) 
- (modified) mlir/test/Dialect/SparseTensor/conversion.mlir (+8-8) 
- (modified) mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir (+24-6) 
- (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+6-6) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir (+1-1) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir (+1-1) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir (+1-1) 
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir (+1-1) 
- (modified) mlir/test/python/dialects/sparse_tensor/dialect.py (+2-2) 


``````````diff
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 41d024db04964..5fc1f51452482 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -26,20 +26,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 /// If updating, keep them in sync and update the static_assert in the impl
 /// file.
 enum MlirSparseTensorLevelType {
-  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4,                   // 0b00001_00
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8,              // 0b00010_00
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9,           // 0b00010_01
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10,          // 0b00010_10
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11,       // 0b00010_11
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16,              // 0b00100_00
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17,           // 0b00100_01
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18,           // 0b00100_10
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19,        // 0b00100_11
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32,       // 0b01000_00
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33,    // 0b01000_01
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34,    // 0b01000_10
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11
-  MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64,        // 0b10000_00
+  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 65536,                   // 0x00_00_0001_0000
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 131072,             // 0x00_00_0002_0000
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 131073,          // 0x00_00_0002_0001
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 131074,          // 0x00_00_0002_0002
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 131075,       // 0x00_00_0002_0003
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 262144,              // 0x00_00_0004_0000
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 262145,           // 0x00_00_0004_0001
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 262146,           // 0x00_00_0004_0002
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 262147,        // 0x00_00_0004_0003
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 524288,       // 0x00_00_0008_0000
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 524289,    // 0x00_00_0008_0001
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 524290,    // 0x00_00_0008_0002
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 524291, // 0x00_00_0008_0003
+  MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 1048576,            // 0x00_00_0010_0000
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index ac91bfa5ae622..99443957d01d5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -154,9 +154,10 @@ enum class Action : uint32_t {
 
 /// This enum defines all the sparse representations supportable by
 /// the SparseTensor dialect. We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton, loose_compressed,
-/// two-out-of-four) as well as the "properties" (ordered, unique). The
-/// encoding is chosen for performance of the runtime library, and thus may
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique), as well as n and m for
+/// the NOutOfM level type.
+/// The encoding is chosen for performance of the runtime library, and thus may
 /// change in future versions; consequently, client code should use the
 /// predicate functions defined below, rather than relying on knowledge
 /// about the particular binary encoding.
@@ -165,39 +166,72 @@ enum class Action : uint32_t {
 /// where we need to store an undefined or indeterminate `LevelType`.
 /// It should not be used externally, since it does not indicate an
 /// actual/representable format.
-enum class LevelType : uint8_t {
-  Undef = 0,                // 0b00000_00
-  Dense = 4,                // 0b00001_00
-  Compressed = 8,           // 0b00010_00
-  CompressedNu = 9,         // 0b00010_01
-  CompressedNo = 10,        // 0b00010_10
-  CompressedNuNo = 11,      // 0b00010_11
-  Singleton = 16,           // 0b00100_00
-  SingletonNu = 17,         // 0b00100_01
-  SingletonNo = 18,         // 0b00100_10
-  SingletonNuNo = 19,       // 0b00100_11
-  LooseCompressed = 32,     // 0b01000_00
-  LooseCompressedNu = 33,   // 0b01000_01
-  LooseCompressedNo = 34,   // 0b01000_10
-  LooseCompressedNuNo = 35, // 0b01000_11
-  TwoOutOfFour = 64,        // 0b10000_00
+///
+/// Bit layout of LevelType (most significant bits first; bits 48-63 unused):
+///
+/// | 8-bit m | 8-bit n | 16-bit LevelFormat | 16-bit LevelProperty |
+///
+enum class LevelType : uint64_t {
+  Undef = 0,                    // 0x00_00_0000_0000
+  Dense = 65536,                // 0x00_00_0001_0000
+  Compressed = 131072,          // 0x00_00_0002_0000
+  CompressedNu = 131073,        // 0x00_00_0002_0001
+  CompressedNo = 131074,        // 0x00_00_0002_0002
+  CompressedNuNo = 131075,      // 0x00_00_0002_0003
+  Singleton = 262144,           // 0x00_00_0004_0000
+  SingletonNu = 262145,         // 0x00_00_0004_0001
+  SingletonNo = 262146,         // 0x00_00_0004_0002
+  SingletonNuNo = 262147,       // 0x00_00_0004_0003
+  LooseCompressed = 524288,     // 0x00_00_0008_0000
+  LooseCompressedNu = 524289,   // 0x00_00_0008_0001
+  LooseCompressedNo = 524290,   // 0x00_00_0008_0002
+  LooseCompressedNuNo = 524291, // 0x00_00_0008_0003
+  NOutOfM = 1048576,            // 0x00_00_0010_0000
 };
 
 /// This enum defines all supported storage format without the level properties.
-enum class LevelFormat : uint8_t {
-  Dense = 4,            // 0b00001_00
-  Compressed = 8,       // 0b00010_00
-  Singleton = 16,       // 0b00100_00
-  LooseCompressed = 32, // 0b01000_00
-  TwoOutOfFour = 64,    // 0b10000_00
+enum class LevelFormat : uint64_t {
+  Dense = 65536,            // 0x0001_0000
+  Compressed = 131072,      // 0x0002_0000
+  Singleton = 262144,       // 0x0004_0000
+  LooseCompressed = 524288, // 0x0008_0000
+  NOutOfM = 1048576,        // 0x0010_0000
 };
 
 /// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint8_t {
-  Nonunique = 1,  // 0b00000_01
-  Nonordered = 2, // 0b00000_10
+enum class LevelPropertyNondefault : uint64_t {
+  Nonunique = 1,  // 0x0001
+  Nonordered = 2, // 0x0002
 };
 
+/// Get N of NOutOfM level type.
+constexpr uint64_t getN(LevelType lt) {
+  return (static_cast<uint64_t>(lt) >> 32) & 0xff;
+}
+
+/// Get M of NOutOfM level type.
+constexpr uint64_t getM(LevelType lt) {
+  return (static_cast<uint64_t>(lt) >> 40) & 0xff;
+}
+
+/// Convert N of NOutOfM level type to the stored bits.
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+
+/// Convert M of NOutOfM level type to the stored bits.
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+
+/// Check if the `LevelType` is NOutOfM (regardless of
+/// properties and block sizes).
+constexpr bool isNOutOfMLT(LevelType lt) {
+  return ((static_cast<uint64_t>(lt) & 0x100000) ==
+          static_cast<uint64_t>(LevelType::NOutOfM));
+}
+
+/// Check if the `LevelType` is NOutOfM with block sizes `n` and `m`.
+constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
+  return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
+}
+
 /// Returns string representation of the given dimension level type.
 constexpr const char *toMLIRString(LevelType lt) {
   switch (lt) {
@@ -229,21 +263,24 @@ constexpr const char *toMLIRString(LevelType lt) {
     return "loose_compressed(nonordered)";
   case LevelType::LooseCompressedNuNo:
     return "loose_compressed(nonunique, nonordered)";
-  case LevelType::TwoOutOfFour:
-    return "block2_4";
+  default:
+    if (isNOutOfMLT(lt)) {
+      return "block";
+    }
   }
   return "";
 }
 
 /// Check that the `LevelType` contains a valid (possibly undefined) value.
 constexpr bool isValidLT(LevelType lt) {
-  const uint8_t formatBits = static_cast<uint8_t>(lt) >> 2;
-  const uint8_t propertyBits = static_cast<uint8_t>(lt) & 3;
-  // If undefined or dense, then must be unique and ordered.
+  const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
+  const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
+  // If undefined/dense/NOutOfM, then must be unique and ordered.
   // Otherwise, the format must be one of the known ones.
-  return (formatBits <= 1 || formatBits == 16)
+  return (formatBits <= 0x10000 || formatBits == 0x100000)
              ? (propertyBits == 0)
-             : (formatBits == 2 || formatBits == 4 || formatBits == 8);
+             : (formatBits == 0x20000 || formatBits == 0x40000 ||
+                formatBits == 0x80000);
 }
 
 /// Check if the `LevelType` is the special undefined value.
@@ -251,32 +288,26 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
 
 /// Check if the `LevelType` is dense (regardless of properties).
 constexpr bool isDenseLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Dense);
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::Dense);
 }
 
 /// Check if the `LevelType` is compressed (regardless of properties).
 constexpr bool isCompressedLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Compressed);
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::Compressed);
 }
 
 /// Check if the `LevelType` is singleton (regardless of properties).
 constexpr bool isSingletonLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Singleton);
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::Singleton);
 }
 
 /// Check if the `LevelType` is loose compressed (regardless of properties).
 constexpr bool isLooseCompressedLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::LooseCompressed);
-}
-
-/// Check if the `LevelType` is 2OutOf4 (regardless of properties).
-constexpr bool is2OutOf4LT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::TwoOutOfFour);
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::LooseCompressed);
 }
 
 /// Check if the `LevelType` needs positions array.
@@ -287,17 +318,17 @@ constexpr bool isWithPosLT(LevelType lt) {
 /// Check if the `LevelType` needs coordinates array.
 constexpr bool isWithCrdLT(LevelType lt) {
   return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-         is2OutOf4LT(lt);
+         isNOutOfMLT(lt);
 }
 
 /// Check if the `LevelType` is ordered (regardless of storage format).
 constexpr bool isOrderedLT(LevelType lt) {
-  return !(static_cast<uint8_t>(lt) & 2);
+  return !(static_cast<uint64_t>(lt) & 2);
 }
 
 /// Check if the `LevelType` is unique (regardless of storage format).
 constexpr bool isUniqueLT(LevelType lt) {
-  return !(static_cast<uint8_t>(lt) & 1);
+  return !(static_cast<uint64_t>(lt) & 1);
 }
 
 /// Convert a LevelType to its corresponding LevelFormat.
@@ -305,21 +336,25 @@ constexpr bool isUniqueLT(LevelType lt) {
 constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
   if (lt == LevelType::Undef)
     return std::nullopt;
-  return static_cast<LevelFormat>(static_cast<uint8_t>(lt) & ~3);
+  return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
 }
 
 /// Convert a LevelFormat to its corresponding LevelType with the given
 /// properties. Returns std::nullopt when the properties are not applicable
 /// for the input level format.
 constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
-                                                  bool unique) {
-  auto lt = static_cast<LevelType>(static_cast<uint8_t>(lf) |
-                                   (ordered ? 0 : 2) | (unique ? 0 : 1));
+                                                  bool unique, uint64_t n = 0,
+                                                  uint64_t m = 0) {
+  uint64_t newN = n << 32;
+  uint64_t newM = m << 40;
+  auto lt =
+      static_cast<LevelType>(static_cast<uint64_t>(lf) | (ordered ? 0 : 2) |
+                             (unique ? 0 : 1) | newN | newM);
   return isValidLT(lt) ? std::optional(lt) : std::nullopt;
 }
 
 //
-// Ensure the above methods work as indended.
+// Ensure the above methods work as intended.
 //
 
 static_assert(
@@ -341,7 +376,7 @@ static_assert(
          LevelFormat::LooseCompressed &&
      *getLevelFormat(LevelType::LooseCompressedNuNo) ==
          LevelFormat::LooseCompressed &&
-     *getLevelFormat(LevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
+     *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
     "getLevelFormat conversion is broken");
 
 static_assert(
@@ -373,13 +408,28 @@ static_assert(
          LevelType::LooseCompressedNo &&
      *buildLevelType(LevelFormat::LooseCompressed, false, false) ==
          LevelType::LooseCompressedNuNo &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
-     *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
-         LevelType::TwoOutOfFour),
+     buildLevelType(LevelFormat::NOutOfM, false, true) == std::nullopt &&
+     buildLevelType(LevelFormat::NOutOfM, true, false) == std::nullopt &&
+     buildLevelType(LevelFormat::NOutOfM, false, false) == std::nullopt &&
+     *buildLevelType(LevelFormat::NOutOfM, true, true) == LevelType::NOutOfM),
     "buildLevelType conversion is broken");
 
+static_assert(
+    (getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 2 &&
+     getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 4 &&
+     getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 8 &&
+     getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 10),
+    "getN/M conversion is broken");
+
+static_assert(
+    (isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4),
+                      2, 4) &&
+     isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10),
+                      8, 10) &&
+     !isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 3, 4),
+                       2, 4)),
+    "isValidNOutOfMLT definition is broken");
+
 static_assert(
     (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
      isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
@@ -391,7 +441,7 @@ static_assert(
      isValidLT(LevelType::LooseCompressedNu) &&
      isValidLT(LevelType::LooseCompressedNo) &&
      isValidLT(LevelType::LooseCompressedNuNo) &&
-     isValidLT(LevelType::TwoOutOfFour)),
+     isValidLT(LevelType::NOutOfM)),
     "isValidLT definition is broken");
 
 static_assert((isDenseLT(LevelType::Dense) &&
@@ -407,7 +457,7 @@ static_assert((isDenseLT(LevelType::Dense) &&
                !isDenseLT(LevelType::LooseCompressedNu) &&
                !isDenseLT(LevelType::LooseCompressedNo) &&
                !isDenseLT(LevelType::LooseCompressedNuNo) &&
-               !isDenseLT(LevelType::TwoOutOfFour)),
+               !isDenseLT(LevelType::NOutOfM)),
               "isDenseLT definition is broken");
 
 static_assert((!isCompressedLT(LevelType::Dense) &&
@@ -423,7 +473,7 @@ static_assert((!isCompressedLT(LevelType::Dense) &&
                !isCompressedLT(LevelType::LooseCompressedNu) &&
                !isCompressedLT(LevelType::LooseCompressedNo) &&
                !isCompressedLT(LevelType::LooseCompressedNuNo) &&
-               !isCompressedLT(LevelType::TwoOutOfFour)),
+               !isCompressedLT(LevelType::NOutOfM)),
               "isCompressedLT definition is broken");
 
 static_assert((!isSingletonLT(LevelType::Dense) &&
@@ -439,7 +489,7 @@ static_assert((!isSingletonLT(LevelType::Dense) &&
                !isSingletonLT(LevelType::LooseCompressedNu) &&
                !isSingletonLT(LevelType::LooseCompressedNo) &&
                !isSingletonLT(LevelType::LooseCompressedNuNo) &&
-               !isSingletonLT(LevelType::TwoOutOfFour)),
+               !isSingletonLT(LevelType::NOutOfM)),
               "isSingletonLT definition is broken");
 
 static_assert((!isLooseCompressedLT(LevelType::Dense) &&
@@ -455,24 +505,24 @@ static_assert((!isLooseCompressedLT(LevelType::Dense) &&
                isLooseCompressedLT(LevelType::LooseCompressedNu) &&
                isLooseCompressedLT(LevelType::LooseCompressedNo) &&
                isLooseCompressedLT(LevelType::LooseCompressedNuNo) &&
-               !isLooseCompressedLT(LevelType::TwoOutOfFour)),
+               !isLooseCompressedLT(LevelType::NOutOfM)),
               "isLooseCompressedLT definition is broken");
 
-static_assert((!is2OutOf4LT(LevelType::Dense) &&
-               !is2OutOf4LT(LevelType::Compressed) &&
-               !is2OutOf4LT(LevelType::CompressedNu) &&
-               !is2OutOf4LT(LevelType::CompressedNo) &&
-               !is2OutOf4LT(LevelType::CompressedNuNo) &&
-               !is2OutOf4LT(LevelType::Singleton) &&
-               !is2OutOf4LT(LevelType::SingletonNu) &&
-               !is2OutOf4LT(LevelType::SingletonNo) &&
-               !is2OutOf4LT(LevelType::SingletonNuNo) &&
-               !is2OutOf4LT(LevelType::LooseCompressed) &&
-               !is2OutOf4LT(LevelType::LooseCompressedNu) &&
-               !is2OutOf4LT(LevelType::LooseCompressedNo) &&
-               !is2OutOf4LT(LevelType::LooseCompressedNuNo) &&
-               is2OutOf4LT(LevelType::TwoOutOfFour)),
-              "is2OutOf4LT definition is broken");
+static_assert((!isNOutOfMLT(LevelType::Dense) &&
+               !isNOutOfMLT(LevelType::Compressed) &&
+               !isNOutOfMLT(LevelType::CompressedNu) &&
+               !isNOutOfMLT(LevelType::CompressedNo) &&
+               !isNOutOfMLT(LevelType::CompressedNuNo) &&
+               !isNOutOfMLT(LevelType::Singleton) &&
+               !isNOutOfMLT(LevelType::SingletonNu) &&
+               !isNOutOfMLT(LevelType::SingletonNo) &&
+               !isNOutOfMLT(LevelType::SingletonNuNo) &&
+               !isNOutOfMLT(LevelType::LooseCompressed) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNu) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNo) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNuNo) &&
+               isNOutOfMLT(LevelType::NOutOfM)),
+              "isNOutOfMLT definition is broken");
 
 static_assert((isOrderedLT(LevelType::Dense) &&
                isOrderedLT(LevelType::Compressed) &&
@@ -487,7 +537,7 @@ static_assert((isOrderedLT(LevelType::Dense) &&
                isOrderedLT(LevelType::LooseCompressedNu) &&
                !isOrderedLT(LevelType::LooseCompressedNo) &&
                !isOrderedLT(LevelType::LooseCompressedNuNo) &&
-               isOrderedLT(LevelType::TwoOutOfFour)),
+               isOrderedLT(LevelType::NOutOfM)),
               "isOrderedLT definition is broken");
 
 static_assert((isUniqueLT(LevelType::Dense) &&
@@ -503,7 +553,7 @@ static_assert((isUniqueLT(LevelType::Dense) &&
                !isUniqueLT(LevelType::LooseCompressedNu) &&
                isUniqueLT(LevelType::LooseCompressedNo) &&
                !isUniqueLT(LevelType::LooseCompressedNuNo) &&
-               isUniqueLT(LevelType::TwoOutOfFour)),
+               isUniqueLT(LevelType::NOutOfM)),
               "isUniqueLT definition is broken");
 
 /// Bit manipulations for affine encoding.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 12c1068ae1f54..08ba96d437045 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -145,7 +145,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     - **compressed** : only nonzeros along...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/79935


More information about the Mlir-commits mailing list