[Mlir-commits] [mlir] e5924d6 - [mlir][sparse] Implement parsing n out of m (#79935)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 8 11:38:46 PST 2024


Author: Yinying Li
Date: 2024-02-08T14:38:42-05:00
New Revision: e5924d64991abb4da111317ff5e8d9147265354a

URL: https://github.com/llvm/llvm-project/commit/e5924d64991abb4da111317ff5e8d9147265354a
DIFF: https://github.com/llvm/llvm-project/commit/e5924d64991abb4da111317ff5e8d9147265354a.diff

LOG: [mlir][sparse] Implement parsing n out of m (#79935)

1. Add parsing methods for block[n, m].
2. Encode n and m with the newly extended 64-bit LevelType enum.
3. Update 2:4 methods names/comments to n:m.

Added: 
    

Modified: 
    mlir/include/mlir-c/Dialect/SparseTensor.h
    mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
    mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
    mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
    mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
    mlir/lib/Bindings/Python/DialectSparseTensor.cpp
    mlir/lib/CAPI/Dialect/SparseTensor.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
    mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
    mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
    mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
    mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
    mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
    mlir/test/CAPI/sparse_tensor.c
    mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
    mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
    mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
    mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
    mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
    mlir/test/python/dialects/sparse_tensor/dialect.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 42d8400cb5e958..2c71b0008ad16a 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -28,20 +28,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 typedef uint64_t MlirSparseTensorLevelType;
 
 enum MlirBaseSparseTensorLevelType {
-  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4,                   // 0b00001_00
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8,              // 0b00010_00
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9,           // 0b00010_01
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10,          // 0b00010_10
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11,       // 0b00010_11
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16,              // 0b00100_00
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17,           // 0b00100_01
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18,           // 0b00100_10
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19,        // 0b00100_11
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32,       // 0b01000_00
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33,    // 0b01000_01
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34,    // 0b01000_10
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11
-  MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64,        // 0b10000_00
+  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 0x000000020002,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 0x000000020003,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 0x000000040001,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 0x000000040002,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 0x000000040003,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 0x000000080001,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003,
+  MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 86c52bfc651ef1..e940d203be9ed5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -154,9 +154,10 @@ enum class Action : uint32_t {
 
 /// This enum defines all the sparse representations supportable by
 /// the SparseTensor dialect. We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton, loose_compressed,
-/// two-out-of-four) as well as the "properties" (ordered, unique). The
-/// encoding is chosen for performance of the runtime library, and thus may
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
+/// the format is NOutOfM.
+/// The encoding is chosen for performance of the runtime library, and thus may
 /// change in future versions; consequently, client code should use the
 /// predicate functions defined below, rather than relying on knowledge
 /// about the particular binary encoding.
@@ -165,41 +166,75 @@ enum class Action : uint32_t {
 /// where we need to store an undefined or indeterminate `LevelType`.
 /// It should not be used externally, since it does not indicate an
 /// actual/representable format.
+///
+/// Bit manipulations for LevelType:
+///
+/// | 8-bit m | 8-bit n | 16-bit LevelFormat | 16-bit LevelProperty |
+///
 enum class LevelType : uint64_t {
-  Undef = 0,                // 0b00000_00
-  Dense = 4,                // 0b00001_00
-  Compressed = 8,           // 0b00010_00
-  CompressedNu = 9,         // 0b00010_01
-  CompressedNo = 10,        // 0b00010_10
-  CompressedNuNo = 11,      // 0b00010_11
-  Singleton = 16,           // 0b00100_00
-  SingletonNu = 17,         // 0b00100_01
-  SingletonNo = 18,         // 0b00100_10
-  SingletonNuNo = 19,       // 0b00100_11
-  LooseCompressed = 32,     // 0b01000_00
-  LooseCompressedNu = 33,   // 0b01000_01
-  LooseCompressedNo = 34,   // 0b01000_10
-  LooseCompressedNuNo = 35, // 0b01000_11
-  TwoOutOfFour = 64,        // 0b10000_00
+  Undef = 0x000000000000,
+  Dense = 0x000000010000,
+  Compressed = 0x000000020000,
+  CompressedNu = 0x000000020001,
+  CompressedNo = 0x000000020002,
+  CompressedNuNo = 0x000000020003,
+  Singleton = 0x000000040000,
+  SingletonNu = 0x000000040001,
+  SingletonNo = 0x000000040002,
+  SingletonNuNo = 0x000000040003,
+  LooseCompressed = 0x000000080000,
+  LooseCompressedNu = 0x000000080001,
+  LooseCompressedNo = 0x000000080002,
+  LooseCompressedNuNo = 0x000000080003,
+  NOutOfM = 0x000000100000,
 };
 
 /// This enum defines all supported storage format without the level properties.
 enum class LevelFormat : uint64_t {
-  Dense = 4,            // 0b00001_00
-  Compressed = 8,       // 0b00010_00
-  Singleton = 16,       // 0b00100_00
-  LooseCompressed = 32, // 0b01000_00
-  TwoOutOfFour = 64,    // 0b10000_00
+  Dense = 0x00010000,
+  Compressed = 0x00020000,
+  Singleton = 0x00040000,
+  LooseCompressed = 0x00080000,
+  NOutOfM = 0x00100000,
 };
 
 /// This enum defines all the nondefault properties for storage formats.
 enum class LevelPropertyNondefault : uint64_t {
-  Nonunique = 1,  // 0b00000_01
-  Nonordered = 2, // 0b00000_10
+  Nonunique = 0x0001,
+  Nonordered = 0x0002,
 };
 
+/// Get N of NOutOfM level type.
+constexpr uint64_t getN(LevelType lt) {
+  return (static_cast<uint64_t>(lt) >> 32) & 0xff;
+}
+
+/// Get M of NOutOfM level type.
+constexpr uint64_t getM(LevelType lt) {
+  return (static_cast<uint64_t>(lt) >> 40) & 0xff;
+}
+
+/// Convert N of NOutOfM level type to the stored bits.
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+
+/// Convert M of NOutOfM level type to the stored bits.
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+
+/// Check if the `LevelType` is NOutOfM (regardless of
+/// properties and block sizes).
+constexpr bool isNOutOfMLT(LevelType lt) {
+  return ((static_cast<uint64_t>(lt) & 0x100000) ==
+          static_cast<uint64_t>(LevelType::NOutOfM));
+}
+
+/// Check if the `LevelType` is NOutOfM with the correct block sizes.
+constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
+  return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
+}
+
 /// Returns string representation of the given dimension level type.
-constexpr const char *toMLIRString(LevelType lt) {
+constexpr const char *toMLIRString(LevelType lvlType) {
+  auto lt = static_cast<LevelType>(static_cast<uint64_t>(lvlType) & 0xffffffff);
   switch (lt) {
   case LevelType::Undef:
     return "undef";
@@ -229,21 +264,22 @@ constexpr const char *toMLIRString(LevelType lt) {
     return "loose_compressed(nonordered)";
   case LevelType::LooseCompressedNuNo:
     return "loose_compressed(nonunique, nonordered)";
-  case LevelType::TwoOutOfFour:
-    return "block2_4";
+  case LevelType::NOutOfM:
+    return "structured";
   }
   return "";
 }
 
 /// Check that the `LevelType` contains a valid (possibly undefined) value.
 constexpr bool isValidLT(LevelType lt) {
-  const uint64_t formatBits = static_cast<uint64_t>(lt) >> 2;
-  const uint64_t propertyBits = static_cast<uint64_t>(lt) & 3;
-  // If undefined or dense, then must be unique and ordered.
+  const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
+  const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
+  // If undefined/dense/NOutOfM, then must be unique and ordered.
   // Otherwise, the format must be one of the known ones.
-  return (formatBits <= 1 || formatBits == 16)
+  return (formatBits <= 0x10000 || formatBits == 0x100000)
              ? (propertyBits == 0)
-             : (formatBits == 2 || formatBits == 4 || formatBits == 8);
+             : (formatBits == 0x20000 || formatBits == 0x40000 ||
+                formatBits == 0x80000);
 }
 
 /// Check if the `LevelType` is the special undefined value.
@@ -251,34 +287,28 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
 
 /// Check if the `LevelType` is dense (regardless of properties).
 constexpr bool isDenseLT(LevelType lt) {
-  return (static_cast<uint64_t>(lt) & ~3) ==
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
          static_cast<uint64_t>(LevelType::Dense);
 }
 
 /// Check if the `LevelType` is compressed (regardless of properties).
 constexpr bool isCompressedLT(LevelType lt) {
-  return (static_cast<uint64_t>(lt) & ~3) ==
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
          static_cast<uint64_t>(LevelType::Compressed);
 }
 
 /// Check if the `LevelType` is singleton (regardless of properties).
 constexpr bool isSingletonLT(LevelType lt) {
-  return (static_cast<uint64_t>(lt) & ~3) ==
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
          static_cast<uint64_t>(LevelType::Singleton);
 }
 
 /// Check if the `LevelType` is loose compressed (regardless of properties).
 constexpr bool isLooseCompressedLT(LevelType lt) {
-  return (static_cast<uint64_t>(lt) & ~3) ==
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
          static_cast<uint64_t>(LevelType::LooseCompressed);
 }
 
-/// Check if the `LevelType` is 2OutOf4 (regardless of properties).
-constexpr bool is2OutOf4LT(LevelType lt) {
-  return (static_cast<uint64_t>(lt) & ~3) ==
-         static_cast<uint64_t>(LevelType::TwoOutOfFour);
-}
-
 /// Check if the `LevelType` needs positions array.
 constexpr bool isWithPosLT(LevelType lt) {
   return isCompressedLT(lt) || isLooseCompressedLT(lt);
@@ -287,17 +317,19 @@ constexpr bool isWithPosLT(LevelType lt) {
 /// Check if the `LevelType` needs coordinates array.
 constexpr bool isWithCrdLT(LevelType lt) {
   return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-         is2OutOf4LT(lt);
+         isNOutOfMLT(lt);
 }
 
 /// Check if the `LevelType` is ordered (regardless of storage format).
 constexpr bool isOrderedLT(LevelType lt) {
   return !(static_cast<uint64_t>(lt) & 2);
+  return !(static_cast<uint64_t>(lt) & 2);
 }
 
 /// Check if the `LevelType` is unique (regardless of storage format).
 constexpr bool isUniqueLT(LevelType lt) {
   return !(static_cast<uint64_t>(lt) & 1);
+  return !(static_cast<uint64_t>(lt) & 1);
 }
 
 /// Convert a LevelType to its corresponding LevelFormat.
@@ -305,21 +337,25 @@ constexpr bool isUniqueLT(LevelType lt) {
 constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
   if (lt == LevelType::Undef)
     return std::nullopt;
-  return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & ~3);
+  return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
 }
 
 /// Convert a LevelFormat to its corresponding LevelType with the given
 /// properties. Returns std::nullopt when the properties are not applicable
 /// for the input level format.
 constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
-                                                  bool unique) {
-  auto lt = static_cast<LevelType>(static_cast<uint64_t>(lf) |
-                                   (ordered ? 0 : 2) | (unique ? 0 : 1));
+                                                  bool unique, uint64_t n = 0,
+                                                  uint64_t m = 0) {
+  uint64_t newN = n << 32;
+  uint64_t newM = m << 40;
+  auto lt =
+      static_cast<LevelType>(static_cast<uint64_t>(lf) | (ordered ? 0 : 2) |
+                             (unique ? 0 : 1) | newN | newM);
   return isValidLT(lt) ? std::optional(lt) : std::nullopt;
 }
 
 //
-// Ensure the above methods work as indended.
+// Ensure the above methods work as intended.
 //
 
 static_assert(
@@ -341,7 +377,7 @@ static_assert(
          LevelFormat::LooseCompressed &&
      *getLevelFormat(LevelType::LooseCompressedNuNo) ==
          LevelFormat::LooseCompressed &&
-     *getLevelFormat(LevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
+     *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
     "getLevelFormat conversion is broken");
 
 static_assert(
@@ -373,13 +409,28 @@ static_assert(
          LevelType::LooseCompressedNo &&
      *buildLevelType(LevelFormat::LooseCompressed, false, false) ==
          LevelType::LooseCompressedNuNo &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
-     *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
-         LevelType::TwoOutOfFour),
+     buildLevelType(LevelFormat::NOutOfM, false, true) == std::nullopt &&
+     buildLevelType(LevelFormat::NOutOfM, true, false) == std::nullopt &&
+     buildLevelType(LevelFormat::NOutOfM, false, false) == std::nullopt &&
+     *buildLevelType(LevelFormat::NOutOfM, true, true) == LevelType::NOutOfM),
     "buildLevelType conversion is broken");
 
+static_assert(
+    (getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 2 &&
+     getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 4 &&
+     getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 8 &&
+     getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 10),
+    "getN/M conversion is broken");
+
+static_assert(
+    (isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4),
+                      2, 4) &&
+     isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10),
+                      8, 10) &&
+     !isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 3, 4),
+                       2, 4)),
+    "isValidNOutOfMLT definition is broken");
+
 static_assert(
     (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
      isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
@@ -391,7 +442,7 @@ static_assert(
      isValidLT(LevelType::LooseCompressedNu) &&
      isValidLT(LevelType::LooseCompressedNo) &&
      isValidLT(LevelType::LooseCompressedNuNo) &&
-     isValidLT(LevelType::TwoOutOfFour)),
+     isValidLT(LevelType::NOutOfM)),
     "isValidLT definition is broken");
 
 static_assert((isDenseLT(LevelType::Dense) &&
@@ -407,7 +458,7 @@ static_assert((isDenseLT(LevelType::Dense) &&
                !isDenseLT(LevelType::LooseCompressedNu) &&
                !isDenseLT(LevelType::LooseCompressedNo) &&
                !isDenseLT(LevelType::LooseCompressedNuNo) &&
-               !isDenseLT(LevelType::TwoOutOfFour)),
+               !isDenseLT(LevelType::NOutOfM)),
               "isDenseLT definition is broken");
 
 static_assert((!isCompressedLT(LevelType::Dense) &&
@@ -423,7 +474,7 @@ static_assert((!isCompressedLT(LevelType::Dense) &&
                !isCompressedLT(LevelType::LooseCompressedNu) &&
                !isCompressedLT(LevelType::LooseCompressedNo) &&
                !isCompressedLT(LevelType::LooseCompressedNuNo) &&
-               !isCompressedLT(LevelType::TwoOutOfFour)),
+               !isCompressedLT(LevelType::NOutOfM)),
               "isCompressedLT definition is broken");
 
 static_assert((!isSingletonLT(LevelType::Dense) &&
@@ -439,7 +490,7 @@ static_assert((!isSingletonLT(LevelType::Dense) &&
                !isSingletonLT(LevelType::LooseCompressedNu) &&
                !isSingletonLT(LevelType::LooseCompressedNo) &&
                !isSingletonLT(LevelType::LooseCompressedNuNo) &&
-               !isSingletonLT(LevelType::TwoOutOfFour)),
+               !isSingletonLT(LevelType::NOutOfM)),
               "isSingletonLT definition is broken");
 
 static_assert((!isLooseCompressedLT(LevelType::Dense) &&
@@ -455,24 +506,24 @@ static_assert((!isLooseCompressedLT(LevelType::Dense) &&
                isLooseCompressedLT(LevelType::LooseCompressedNu) &&
                isLooseCompressedLT(LevelType::LooseCompressedNo) &&
                isLooseCompressedLT(LevelType::LooseCompressedNuNo) &&
-               !isLooseCompressedLT(LevelType::TwoOutOfFour)),
+               !isLooseCompressedLT(LevelType::NOutOfM)),
               "isLooseCompressedLT definition is broken");
 
-static_assert((!is2OutOf4LT(LevelType::Dense) &&
-               !is2OutOf4LT(LevelType::Compressed) &&
-               !is2OutOf4LT(LevelType::CompressedNu) &&
-               !is2OutOf4LT(LevelType::CompressedNo) &&
-               !is2OutOf4LT(LevelType::CompressedNuNo) &&
-               !is2OutOf4LT(LevelType::Singleton) &&
-               !is2OutOf4LT(LevelType::SingletonNu) &&
-               !is2OutOf4LT(LevelType::SingletonNo) &&
-               !is2OutOf4LT(LevelType::SingletonNuNo) &&
-               !is2OutOf4LT(LevelType::LooseCompressed) &&
-               !is2OutOf4LT(LevelType::LooseCompressedNu) &&
-               !is2OutOf4LT(LevelType::LooseCompressedNo) &&
-               !is2OutOf4LT(LevelType::LooseCompressedNuNo) &&
-               is2OutOf4LT(LevelType::TwoOutOfFour)),
-              "is2OutOf4LT definition is broken");
+static_assert((!isNOutOfMLT(LevelType::Dense) &&
+               !isNOutOfMLT(LevelType::Compressed) &&
+               !isNOutOfMLT(LevelType::CompressedNu) &&
+               !isNOutOfMLT(LevelType::CompressedNo) &&
+               !isNOutOfMLT(LevelType::CompressedNuNo) &&
+               !isNOutOfMLT(LevelType::Singleton) &&
+               !isNOutOfMLT(LevelType::SingletonNu) &&
+               !isNOutOfMLT(LevelType::SingletonNo) &&
+               !isNOutOfMLT(LevelType::SingletonNuNo) &&
+               !isNOutOfMLT(LevelType::LooseCompressed) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNu) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNo) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNuNo) &&
+               isNOutOfMLT(LevelType::NOutOfM)),
+              "isNOutOfMLT definition is broken");
 
 static_assert((isOrderedLT(LevelType::Dense) &&
                isOrderedLT(LevelType::Compressed) &&
@@ -487,7 +538,7 @@ static_assert((isOrderedLT(LevelType::Dense) &&
                isOrderedLT(LevelType::LooseCompressedNu) &&
                !isOrderedLT(LevelType::LooseCompressedNo) &&
                !isOrderedLT(LevelType::LooseCompressedNuNo) &&
-               isOrderedLT(LevelType::TwoOutOfFour)),
+               isOrderedLT(LevelType::NOutOfM)),
               "isOrderedLT definition is broken");
 
 static_assert((isUniqueLT(LevelType::Dense) &&
@@ -503,7 +554,7 @@ static_assert((isUniqueLT(LevelType::Dense) &&
                !isUniqueLT(LevelType::LooseCompressedNu) &&
                isUniqueLT(LevelType::LooseCompressedNo) &&
                !isUniqueLT(LevelType::LooseCompressedNuNo) &&
-               isUniqueLT(LevelType::TwoOutOfFour)),
+               isUniqueLT(LevelType::NOutOfM)),
               "isUniqueLT definition is broken");
 
 /// Bit manipulations for affine encoding.

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 12c1068ae1f546..5b3b971f9a7f23 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -145,7 +145,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     - **compressed** : only nonzeros along this level are stored
     - **loose_compressed** : as compressed, but allows for free space between regions
     - **singleton** : a variant of the compressed format, where coordinates have no siblings
-    - **block2_4** : the compression uses a 2:4 encoding per 1x4 block
+    - **structured[n, m]** : the compression uses a n:m encoding
+      (viz. n out of m consecutive elements are nonzero)
 
     For a compressed level, each position interval is represented in a compact
     way with a lowerbound `pos(i)` and an upperbound `pos(i+1) - 1`, which implies
@@ -374,7 +375,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedLT(getLvlType(l)); }
     bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonLT(getLvlType(l)); }
     bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedLT(getLvlType(l)); }
-    bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4LT(getLvlType(l)); }
+    bool isNOutOfMLvl(::mlir::sparse_tensor::Level l) const { return isNOutOfMLT(getLvlType(l)); }
     bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedLT(getLvlType(l)); }
     bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueLT(getLvlType(l)); }
 

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4c98129744bcd9..4e2b85d35c1ac1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -291,7 +291,7 @@ class SparseTensorType {
     return isLooseCompressedLT(getLvlType(l));
   }
   bool isSingletonLvl(Level l) const { return isSingletonLT(getLvlType(l)); }
-  bool is2OutOf4Lvl(Level l) const { return is2OutOf4LT(getLvlType(l)); }
+  bool isNOutOfMLvl(Level l) const { return isNOutOfMLT(getLvlType(l)); }
   bool isOrderedLvl(Level l) const { return isOrderedLT(getLvlType(l)); }
   bool isUniqueLvl(Level l) const { return isUniqueLT(getLvlType(l)); }
   bool isWithPos(Level l) const { return isWithPosLT(getLvlType(l)); }

diff  --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 4a34bb2e003e88..490ef3071af1b7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -510,7 +510,7 @@ class Merger {
     if (isLvlWithNonTrivialIdxExp(b)) {
       auto lt = getLoopDependentLevelType(b);
       return isCompressedLT(lt) || isSingletonLT(lt) ||
-             isLooseCompressedLT(lt) || is2OutOf4LT(lt);
+             isLooseCompressedLT(lt) || isNOutOfMLT(lt);
     }
     return false;
   }

diff  --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 01c5f2382ffe69..14182172f4f622 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -123,8 +123,8 @@ class SparseTensorStorageBase {
   /// Safely checks if the level uses singleton storage.
   bool isSingletonLvl(uint64_t l) const { return isSingletonLT(getLvlType(l)); }
 
-  /// Safely checks if the level uses 2 out of 4 storage.
-  bool is2OutOf4Lvl(uint64_t l) const { return is2OutOf4LT(getLvlType(l)); }
+  /// Safely checks if the level uses n out of m storage.
+  bool isNOutOfMLvl(uint64_t l) const { return isNOutOfMLT(getLvlType(l)); }
 
   /// Safely checks if the level is ordered.
   bool isOrderedLvl(uint64_t l) const { return isOrderedLT(getLvlType(l)); }
@@ -450,7 +450,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   void appendCrd(uint64_t lvl, uint64_t full, uint64_t crd) {
     if (!isDenseLvl(lvl)) {
       assert(isCompressedLvl(lvl) || isLooseCompressedLvl(lvl) ||
-             isSingletonLvl(lvl) || is2OutOf4Lvl(lvl));
+             isSingletonLvl(lvl) || isNOutOfMLvl(lvl));
       coordinates[lvl].push_back(detail::checkOverflowCast<C>(crd));
     } else { // Dense level.
       assert(crd >= full && "Coordinate was already filled");
@@ -473,7 +473,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
       return positions[l][parentSz];
     if (isLooseCompressedLvl(l))
       return positions[l][2 * parentSz - 1];
-    if (isSingletonLvl(l) || is2OutOf4Lvl(l))
+    if (isSingletonLvl(l) || isNOutOfMLvl(l))
       return parentSz; // new size same as the parent
     assert(isDenseLvl(l));
     return parentSz * getLvlSize(l);
@@ -527,7 +527,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
       uint64_t pos = coordinates[l].size();
       positions[l].insert(positions[l].end(), 2 * count,
                           detail::checkOverflowCast<P>(pos));
-    } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
+    } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
       return; // Nothing to finalize.
     } else {  // Dense dimension.
       assert(isDenseLvl(l));
@@ -624,7 +624,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
         lvlCursor[l] = static_cast<uint64_t>(coordinatesL[pos]);
         toCOO(pos, l + 1, dimCoords);
       }
-    } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
+    } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
       assert(parentPos < coordinates[l].size());
       lvlCursor[l] = static_cast<uint64_t>(coordinates[l][parentPos]);
       toCOO(parentPos, l + 1, dimCoords);
@@ -721,8 +721,8 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
     } else if (isSingletonLvl(l)) {
       coordinates[l].reserve(sz);
       sz = 1;
-    } else if (is2OutOf4Lvl(l)) {
-      assert(l == lvlRank - 1 && "unexpected 2:4 usage");
+    } else if (isNOutOfMLvl(l)) {
+      assert(l == lvlRank - 1 && "unexpected n:m usage");
       sz = detail::checkedMul(sz, lvlSizes[l]) / 2;
       coordinates[l].reserve(sz);
       values.reserve(sz);
@@ -791,8 +791,8 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
       }
     } else if (isSingletonLvl(l)) {
       assert(0 && "general singleton not supported yet");
-    } else if (is2OutOf4Lvl(l)) {
-      assert(0 && "2Out4 not supported yet");
+    } else if (isNOutOfMLvl(l)) {
+      assert(0 && "n ouf of m not supported yet");
     } else {
       assert(isDenseLvl(l));
     }

diff  --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 698367a1addaff..607534c6156439 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -25,7 +25,7 @@ using namespace mlir::python::adaptors;
 static void populateDialectSparseTensorSubmodule(const py::module &m) {
   py::enum_<MlirBaseSparseTensorLevelType>(m, "LevelType", py::module_local())
       .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
-      .value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR)
+      .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
       .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
       .value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU)
       .value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO)

diff  --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index e4534ad132385f..a34b9a29b0e90a 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -20,25 +20,36 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
                                       mlir::sparse_tensor::SparseTensorDialect)
 
 // Ensure the C-API enums are int-castable to C++ equivalents.
-static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
-                      static_cast<int>(LevelType::Dense) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
-                      static_cast<int>(LevelType::Compressed) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
-                      static_cast<int>(LevelType::CompressedNu) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
-                      static_cast<int>(LevelType::CompressedNo) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
-                      static_cast<int>(LevelType::CompressedNuNo) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
-                      static_cast<int>(LevelType::Singleton) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
-                      static_cast<int>(LevelType::SingletonNu) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
-                      static_cast<int>(LevelType::SingletonNo) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
-                      static_cast<int>(LevelType::SingletonNuNo),
-              "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
+static_assert(
+    static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
+            static_cast<int>(LevelType::Dense) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
+            static_cast<int>(LevelType::Compressed) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
+            static_cast<int>(LevelType::CompressedNu) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
+            static_cast<int>(LevelType::CompressedNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
+            static_cast<int>(LevelType::CompressedNuNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
+            static_cast<int>(LevelType::Singleton) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
+            static_cast<int>(LevelType::SingletonNu) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
+            static_cast<int>(LevelType::SingletonNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
+            static_cast<int>(LevelType::SingletonNuNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) ==
+            static_cast<int>(LevelType::LooseCompressed) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) ==
+            static_cast<int>(LevelType::LooseCompressedNu) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) ==
+            static_cast<int>(LevelType::LooseCompressedNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO) ==
+            static_cast<int>(LevelType::LooseCompressedNuNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) ==
+            static_cast<int>(LevelType::NOutOfM),
+    "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
 
 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
   return isa<SparseTensorEncodingAttr>(unwrap(attr));

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index eb7ea63a4e88b8..752d6e6481dfee 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -29,12 +29,21 @@ using namespace mlir::sparse_tensor::ir_detail;
 // `LvlTypeParser` implementation.
 //===----------------------------------------------------------------------===//
 
-FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
+FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   StringRef base;
   const auto loc = parser.getCurrentLocation();
   ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
            "expected valid level format (e.g. dense, compressed or singleton)")
-  uint8_t properties = 0;
+  uint64_t properties = 0;
+  SmallVector<unsigned> structure;
+
+  if (base.compare("structured") == 0) {
+    ParseResult res = parser.parseCommaSeparatedList(
+        mlir::OpAsmParser::Delimiter::OptionalSquare,
+        [&]() -> ParseResult { return parseStructure(parser, &structure); },
+        " in block n out of m");
+    FAILURE_IF_FAILED(res)
+  }
 
   ParseResult res = parser.parseCommaSeparatedList(
       mlir::OpAsmParser::Delimiter::OptionalParen,
@@ -44,15 +53,20 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
 
   // Set the base bit for properties.
   if (base.compare("dense") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::Dense);
+    properties |= static_cast<uint64_t>(LevelFormat::Dense);
   } else if (base.compare("compressed") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::Compressed);
-  } else if (base.compare("block2_4") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::TwoOutOfFour);
+    properties |= static_cast<uint64_t>(LevelFormat::Compressed);
+  } else if (base.compare("structured") == 0) {
+    if (structure.size() != 2) {
+      parser.emitError(loc, "expected exactly 2 structure sizes");
+      return failure();
+    }
+    properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
+    properties |= nToBits(structure[0]) | mToBits(structure[1]);
   } else if (base.compare("loose_compressed") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::LooseCompressed);
+    properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
   } else if (base.compare("singleton") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::Singleton);
+    properties |= static_cast<uint64_t>(LevelFormat::Singleton);
   } else {
     parser.emitError(loc, "unknown level format: ") << base;
     return failure();
@@ -64,15 +78,15 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
 }
 
 ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
-                                         uint8_t *properties) const {
+                                         uint64_t *properties) const {
   StringRef strVal;
   auto loc = parser.getCurrentLocation();
   ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
            "expected valid level property (e.g. nonordered, nonunique or high)")
   if (strVal.compare("nonunique") == 0) {
-    *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonunique);
+    *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique);
   } else if (strVal.compare("nonordered") == 0) {
-    *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonordered);
+    *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered);
   } else {
     parser.emitError(loc, "unknown level property: ") << strVal;
     return failure();
@@ -80,4 +94,22 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
   return success();
 }
 
+ParseResult
+LvlTypeParser::parseStructure(AsmParser &parser,
+                              SmallVector<unsigned> *structure) const {
+  int intVal;
+  auto loc = parser.getCurrentLocation();
+  OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
+  if (intValParseResult.has_value()) {
+    if (failed(*intValParseResult)) {
+      parser.emitError(loc, "failed to parse block size");
+      return failure();
+    }
+    structure->push_back(intVal);
+    return success();
+  }
+  parser.emitError(loc, "expected valid integer for block size");
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
index 5e2f11b75d4da6..6a13112195d440 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
@@ -18,10 +18,12 @@ namespace ir_detail {
 class LvlTypeParser {
 public:
   LvlTypeParser() = default;
-  FailureOr<uint8_t> parseLvlType(AsmParser &parser) const;
+  FailureOr<uint64_t> parseLvlType(AsmParser &parser) const;
 
 private:
-  ParseResult parseProperty(AsmParser &parser, uint8_t *properties) const;
+  ParseResult parseProperty(AsmParser &parser, uint64_t *properties) const;
+  ParseResult parseStructure(AsmParser &parser,
+                             SmallVector<unsigned> *structure) const;
 };
 
 } // namespace ir_detail

diff  --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 27125bc7ed45e3..67b1d7974fa259 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -613,16 +613,28 @@ void SparseTensorEncodingAttr::printDimensions(
   }
 }
 
+std::string getNOutOfMString(LevelType lt) {
+  if (isNOutOfMLT(lt)) {
+    unsigned n = getN(lt);
+    unsigned m = getM(lt);
+    auto output = "[" + std::to_string(n) + ", " + std::to_string(m) + "]";
+    return output;
+  }
+  return "";
+}
+
 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                            ArrayRef<LevelType> lvlTypes) const {
   for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
     map.getResult(i).print(printer.getStream());
-    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
+    printer << " : " << toMLIRString(lvlTypes[i])
+            << getNOutOfMString(lvlTypes[i]) << ", ";
   }
   if (map.getNumResults() >= 1) {
     auto lastIndex = map.getNumResults() - 1;
     map.getResult(lastIndex).print(printer.getStream());
-    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
+    printer << " : " << toMLIRString(lvlTypes[lastIndex])
+            << getNOutOfMString(lvlTypes[lastIndex]);
   }
 }
 

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index dd3af9d8354123..3f352c868467fe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -451,7 +451,7 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) {
 /// Test for 2:4 matrix with suitable metadata.
 static bool isAdmissible24(SparseTensorType &aTp) {
   return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
-         aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp);
+         aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp);
 }
 
 /// Test for conversion into 2:4 matrix.

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 491501a3381b9c..d4459c6ea1e521 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -130,7 +130,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
       createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                      /*value=*/posZero, /*repeat=*/linear);
       return;
-    } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
+    } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
       return; // nothing to do
     }
     // Keep compounding the size, but nothing needs to be initialized
@@ -409,7 +409,7 @@ static void genEndInsert(OpBuilder &builder, Location loc,
       }
     } else {
       assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
-             is2OutOf4LT(lt));
+             isNOutOfMLT(lt));
     }
   }
 }
@@ -488,7 +488,7 @@ class SparseInsertGenerator
         }
         parentPos =
             genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
-      } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
+      } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
         // Create:
         //   coordinates[lvl].push_back(coords[lvl])
         //   positions[lvl] = positions[lvl-1]

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index ab38ab5cc3f78f..8f2ae60b311f7c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -891,7 +891,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
         assert(curr == env.merger().loop(b));
         Value clause;
         if (isCompressedLT(lt) || isSingletonLT(lt) ||
-            isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
+            isLooseCompressedLT(lt) || isNOutOfMLT(lt)) {
           assert(lvl.has_value());
           const Value crd = env.emitter().getCoord(tid, *lvl);
           const Value lvar = env.getLoopVar(curr);

diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 4ba9ecbe03c72d..c85f8204ba7527 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -139,18 +139,19 @@ class SingletonLevel : public SparseLevel {
   }
 };
 
-class TwoOutFourLevel : public SparseLevel {
+class NOutOfMLevel : public SparseLevel {
 public:
-  TwoOutFourLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
-                  Value crdBuffer)
+  NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+               Value crdBuffer)
       : SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
                         Value max) const override {
-    assert(max == nullptr && isUnique() && "2:4 level can not be non-unique.");
-    // Each 2:4 blk has exactly two specified elements.
-    Value posLo = MULI(p, C_IDX(2));
-    return {posLo, ADDI(posLo, C_IDX(2))};
+    assert(max == nullptr && isUnique() && "n:m level can not be non-unique.");
+    // Each n:m blk has exactly n specified elements.
+    auto n = getN(lt);
+    Value posLo = MULI(p, C_IDX(n));
+    return {posLo, ADDI(posLo, C_IDX(n))};
   }
 };
 
@@ -1291,9 +1292,9 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
     Value crd = genToCoordinates(b, l, t, lvl);
     return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
   }
-  case LevelFormat::TwoOutOfFour: {
+  case LevelFormat::NOutOfM: {
     Value crd = genToCoordinates(b, l, t, lvl);
-    return std::make_unique<TwoOutFourLevel>(tid, lvl, lt, sz, crd);
+    return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
   }
   }
   llvm_unreachable("unrecognizable level format");

diff  --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 6cdf5f8c0168be..96537cbb0c4836 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -489,7 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
     if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
       const auto lt = getLvlType(b);
       if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
-          !isLooseCompressedLT(lt) && !is2OutOf4LT(lt)) {
+          !isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
         if (reset)
           simple.reset(b);
         reset = true;
@@ -670,7 +670,7 @@ bool Merger::hasAnySparse(const BitVector &bits) const {
   for (TensorLoopId b : bits.set_bits()) {
     const auto lt = getLvlType(b);
     if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-        is2OutOf4LT(lt))
+        isNOutOfMLT(lt))
       return true;
   }
   return hasSparseIdxReduction(bits);

diff  --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 0c7b3a228a65cf..9e8b240899d808 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -45,7 +45,7 @@ SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
   for (uint64_t l = 0; l < lvlRank; l++) {
     assert(lvlSizes[l] > 0 && "Level size zero has trivial storage");
     assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) ||
-           isSingletonLvl(l) || is2OutOf4Lvl(l));
+           isSingletonLvl(l) || isNOutOfMLvl(l));
   }
 }
 

diff  --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index 2c6ad559f19d0c..a8b9f9048d5912 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -38,9 +38,9 @@ static int testRoundtripEncoding(MlirContext ctx) {
       mlirSparseTensorEncodingAttrGetDimToLvl(originalAttr);
   // CHECK: (d0, d1)[s0] -> (s0, d0, d1)
   mlirAffineMapDump(dimToLvl);
-  // CHECK: level_type: 4
-  // CHECK: level_type: 8
-  // CHECK: level_type: 8
+  // CHECK: level_type: 65536
+  // CHECK: level_type: 131072
+  // CHECK: level_type: 131072
   MlirAffineMap lvlToDim =
       mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
   int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);

diff  --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
index 6fe7ec906f30e9..8293169049ca61 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
@@ -4,7 +4,7 @@
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block2_4
+    j mod 4      : structured[2, 4]
   )
 }>
 

diff  --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 20702bb9850284..64520638b253df 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -207,12 +207,12 @@ func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block2_4
+    j mod 4      : structured[2, 4]
   ),
   crdWidth = 8  // we would even like just 2-bits
 }>
 
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : block2_4), crdWidth = 8 }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]), crdWidth = 8 }>
 // CHECK-LABEL: func private @NV_24(
 // CHECK-SAME: tensor<?x?xf64, #[[$NV_24]]>
 func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
@@ -226,11 +226,11 @@ func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
   ( i            : dense,
     j            : dense,
     k floordiv 4 : dense,
-    k mod 4      : block2_4
+    k mod 4      : structured[2, 4]
   )
 }>
 
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block2_4) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : structured[2, 4]) }>
 // CHECK-LABEL: func private @NV_24(
 // CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
 func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
@@ -244,13 +244,31 @@ func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
   ( i            : dense,
     k floordiv 4 : dense,
     j            : dense,
-    k mod 4      : block2_4
+    k mod 4      : structured[2, 4]
   )
 }>
 
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block2_4) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : structured[2, 4]) }>
 // CHECK-LABEL: func private @NV_24(
 // CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
 func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
   return
 }
+
+// -----
+
+#NOutOfM = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i            : dense,
+    k floordiv 8 : dense,
+    j            : dense,
+    k mod 8      : structured[5, 8]
+  )
+}>
+
+// CHECK-DAG: #[[$NOutOfM:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 8 : dense, d1 : dense, d2 mod 8 : structured[5, 8]) }>
+// CHECK-LABEL: func private @NOutOfM(
+// CHECK-SAME: tensor<?x?x?xf64, #[[$NOutOfM]]>
+func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
+  return
+}

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 7c494b2bcfe1d1..d04fbe8ed5c220 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,7 +14,7 @@
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 100 : index
 // CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 8 : i64
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 131072 : i64
 // CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
 // CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
 // CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 4bc080fc538fc6..e47ac46597b77a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -59,7 +59,7 @@
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block2_4
+    j mod 4      : structured[2, 4]
   ),
 }>
 

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
index df5b48a3b6ece8..ec5c7580657cd7 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
@@ -41,7 +41,7 @@
 #NV_24 = #sparse_tensor.encoding<{
   map = ( i, j ) -> ( i            : dense,
                       j floordiv 4 : dense,
-                      j mod 4      : block2_4),
+                      j mod 4      : structured[2, 4]),
   crdWidth = 8
 }>
 

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
index 17b50b46d073ae..b0f63f12c2d579 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
@@ -20,7 +20,7 @@
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block2_4
+    j mod 4      : structured[2, 4]
   )
 }>
 

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
index eb99a027a88600..311cb607b4293c 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
@@ -20,7 +20,7 @@
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block2_4
+    j mod 4      : structured[2, 4]
   )
 }>
 

diff  --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 946a224dab064a..412c5797067b7a 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [8]
+        # CHECK: lvl_types: [131072]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0) -> (d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -70,7 +70,7 @@ def testEncodingAttr2D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [4, 8]
+        # CHECK: lvl_types: [65536, 131072]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")


        


More information about the Mlir-commits mailing list