[Mlir-commits] [mlir] [mlir][sparse] Expand LevelType to 64 bits and implement parsing n out of m (PR #79935)

Yinying Li llvmlistbot at llvm.org
Thu Feb 1 01:05:41 PST 2024


https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/79935

>From 32280ec068f51ccc44977c9ce296b7b83c3da58f Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 30 Jan 2024 01:01:52 +0000
Subject: [PATCH 01/10] [mlir][sparse] Expand LevelType to 64 bit and implement
 n out of m

---
 mlir/include/mlir-c/Dialect/SparseTensor.h    |  28 +--
 .../mlir/Dialect/SparseTensor/IR/Enums.h      | 225 +++++++++++-------
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |   4 +-
 .../SparseTensor/IR/SparseTensorType.h        |   2 +-
 .../mlir/Dialect/SparseTensor/Utils/Merger.h  |   2 +-
 .../ExecutionEngine/SparseTensor/Storage.h    |  14 +-
 .../Bindings/Python/DialectSparseTensor.cpp   |   2 +-
 mlir/lib/CAPI/Dialect/SparseTensor.cpp        |  49 ++--
 .../IR/Detail/DimLvlMapParser.cpp             |   2 +
 .../SparseTensor/IR/Detail/LvlTypeParser.cpp  |  55 ++++-
 .../SparseTensor/IR/Detail/LvlTypeParser.h    |   6 +-
 .../Transforms/SparseGPUCodegen.cpp           |   2 +-
 .../Transforms/SparseTensorCodegen.cpp        |   6 +-
 .../Transforms/Sparsification.cpp             |   2 +-
 .../Transforms/Utils/CodegenUtils.h           |   2 +-
 .../Transforms/Utils/SparseTensorLevel.cpp    |   2 +-
 .../lib/Dialect/SparseTensor/Utils/Merger.cpp |   4 +-
 .../ExecutionEngine/SparseTensor/Storage.cpp  |   2 +-
 mlir/test/CAPI/sparse_tensor.c                |   6 +-
 .../SparseTensor/GPU/gpu_matmul24_lib.mlir    |   2 +-
 .../test/Dialect/SparseTensor/conversion.mlir |  16 +-
 .../SparseTensor/roundtrip_encoding.mlir      |  12 +-
 .../SparseTensor/sparse_fill_zero.mlir        |  12 +-
 .../SparseTensor/CPU/sparse_block_matmul.mlir |   2 +-
 .../Dialect/SparseTensor/CPU/sparse_ds.mlir   |   2 +-
 .../CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir   |   2 +-
 .../CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir |   2 +-
 .../python/dialects/sparse_tensor/dialect.py  | 148 ++++++------
 28 files changed, 358 insertions(+), 255 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 41d024db04964..5fc1f51452482 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -26,20 +26,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 /// If updating, keep them in sync and update the static_assert in the impl
 /// file.
 enum MlirSparseTensorLevelType {
-  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4,                   // 0b00001_00
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8,              // 0b00010_00
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9,           // 0b00010_01
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10,          // 0b00010_10
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11,       // 0b00010_11
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16,              // 0b00100_00
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17,           // 0b00100_01
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18,           // 0b00100_10
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19,        // 0b00100_11
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32,       // 0b01000_00
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33,    // 0b01000_01
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34,    // 0b01000_10
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11
-  MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64,        // 0b10000_00
+  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 65536,                   // 0x00_00_0001_0000
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 131072,             // 0x00_00_0002_0000
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 131073,          // 0x00_00_0002_0001
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 131074,          // 0x00_00_0002_0002
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 131075,       // 0x00_00_0002_0003
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 262144,              // 0x00_00_0004_0000
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 262145,           // 0x00_00_0004_0001
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 262146,           // 0x00_00_0004_0002
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 262147,        // 0x00_00_0004_0003
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 524288,       // 0x00_00_0008_0000
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 524289,    // 0x00_00_0008_0001
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 524290,    // 0x00_00_0008_0002
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 524291, // 0x00_00_0008_0003
+  MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 1048576,            // 0x00_00_0010_0000
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index ac91bfa5ae622..6ddc9326179fe 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -154,9 +154,10 @@ enum class Action : uint32_t {
 
 /// This enum defines all the sparse representations supportable by
 /// the SparseTensor dialect. We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton, loose_compressed,
-/// two-out-of-four) as well as the "properties" (ordered, unique). The
-/// encoding is chosen for performance of the runtime library, and thus may
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique), as well as n and m for
+/// the NOutOfM level type.
+/// The encoding is chosen for performance of the runtime library, and thus may
 /// change in future versions; consequently, client code should use the
 /// predicate functions defined below, rather than relying on knowledge
 /// about the particular binary encoding.
@@ -165,41 +166,74 @@ enum class Action : uint32_t {
 /// where we need to store an undefined or indeterminate `LevelType`.
 /// It should not be used externally, since it does not indicate an
 /// actual/representable format.
-enum class LevelType : uint8_t {
-  Undef = 0,                // 0b00000_00
-  Dense = 4,                // 0b00001_00
-  Compressed = 8,           // 0b00010_00
-  CompressedNu = 9,         // 0b00010_01
-  CompressedNo = 10,        // 0b00010_10
-  CompressedNuNo = 11,      // 0b00010_11
-  Singleton = 16,           // 0b00100_00
-  SingletonNu = 17,         // 0b00100_01
-  SingletonNo = 18,         // 0b00100_10
-  SingletonNuNo = 19,       // 0b00100_11
-  LooseCompressed = 32,     // 0b01000_00
-  LooseCompressedNu = 33,   // 0b01000_01
-  LooseCompressedNo = 34,   // 0b01000_10
-  LooseCompressedNuNo = 35, // 0b01000_11
-  TwoOutOfFour = 64,        // 0b10000_00
+///
+/// Bit manipulations for LevelType:
+///
+/// | 8-bit m | 8-bit n | 16-bit LevelFormat | 16-bit LevelProperty |
+///
+enum class LevelType : uint64_t {
+  Undef = 0,                    // 0x00_00_0000_0000
+  Dense = 65536,                // 0x00_00_0001_0000
+  Compressed = 131072,          // 0x00_00_0002_0000
+  CompressedNu = 131073,        // 0x00_00_0002_0001
+  CompressedNo = 131074,        // 0x00_00_0002_0002
+  CompressedNuNo = 131075,      // 0x00_00_0002_0003
+  Singleton = 262144,           // 0x00_00_0004_0000
+  SingletonNu = 262145,         // 0x00_00_0004_0001
+  SingletonNo = 262146,         // 0x00_00_0004_0002
+  SingletonNuNo = 262147,       // 0x00_00_0004_0003
+  LooseCompressed = 524288,     // 0x00_00_0008_0000
+  LooseCompressedNu = 524289,   // 0x00_00_0008_0001
+  LooseCompressedNo = 524290,   // 0x00_00_0008_0002
+  LooseCompressedNuNo = 524291, // 0x00_00_0008_0003
+  NOutOfM = 1048576,            // 0x00_00_0010_0000
 };
 
 /// This enum defines all supported storage format without the level properties.
-enum class LevelFormat : uint8_t {
-  Dense = 4,            // 0b00001_00
-  Compressed = 8,       // 0b00010_00
-  Singleton = 16,       // 0b00100_00
-  LooseCompressed = 32, // 0b01000_00
-  TwoOutOfFour = 64,    // 0b10000_00
+enum class LevelFormat : uint64_t {
+  Dense = 65536,            // 0x0001_0000
+  Compressed = 131072,      // 0x0002_0000
+  Singleton = 262144,       // 0x0004_0000
+  LooseCompressed = 524288, // 0x0008_0000
+  NOutOfM = 1048576,        // 0x0010_0000
 };
 
 /// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint8_t {
-  Nonunique = 1,  // 0b00000_01
-  Nonordered = 2, // 0b00000_10
+enum class LevelPropertyNondefault : uint64_t {
+  Nonunique = 1,  // 0x0001
+  Nonordered = 2, // 0x0002
 };
 
+/// Get N of NOutOfM level type.
+constexpr uint64_t getN(LevelType lt) {
+  return (static_cast<uint64_t>(lt) >> 32) & 0xff;
+}
+
+/// Get M of NOutOfM level type.
+constexpr uint64_t getM(LevelType lt) {
+  return (static_cast<uint64_t>(lt) >> 40) & 0xff;
+}
+
+/// Convert N of NOutOfM level type to the stored bits.
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+
+/// Convert M of NOutOfM level type to the stored bits.
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+
+/// Check if the `LevelType` is NOutOfM (regardless of
+/// properties and block sizes).
+constexpr bool isNOutOfMLT(LevelType lt) {
+  return ((static_cast<uint64_t>(lt) & 0x100000) ==
+          static_cast<uint64_t>(LevelType::NOutOfM));
+}
+
+/// Check if the `LevelType` is NOutOfM with the correct block sizes.
+constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
+  return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
+}
+
 /// Returns string representation of the given dimension level type.
-constexpr const char *toMLIRString(LevelType lt) {
+std::string toMLIRString(LevelType lt) {
   switch (lt) {
   case LevelType::Undef:
     return "undef";
@@ -229,21 +263,28 @@ constexpr const char *toMLIRString(LevelType lt) {
     return "loose_compressed(nonordered)";
   case LevelType::LooseCompressedNuNo:
     return "loose_compressed(nonunique, nonordered)";
-  case LevelType::TwoOutOfFour:
-    return "block2_4";
+  default:
+    // If NOutOfM bit is set, print the [n, m] sizes.
+    if (isNOutOfMLT(lt)) {
+      unsigned n = getN(lt);
+      unsigned m = getM(lt);
+      return std::string("block[") + std::to_string(n) + ", " +
+             std::to_string(m) + "]";
+    }
   }
   return "";
 }
 
 /// Check that the `LevelType` contains a valid (possibly undefined) value.
 constexpr bool isValidLT(LevelType lt) {
-  const uint8_t formatBits = static_cast<uint8_t>(lt) >> 2;
-  const uint8_t propertyBits = static_cast<uint8_t>(lt) & 3;
-  // If undefined or dense, then must be unique and ordered.
+  const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
+  const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
+  // If undefined/dense/NOutOfM, then must be unique and ordered.
   // Otherwise, the format must be one of the known ones.
-  return (formatBits <= 1 || formatBits == 16)
+  return (formatBits <= 0x10000 || formatBits == 0x100000)
              ? (propertyBits == 0)
-             : (formatBits == 2 || formatBits == 4 || formatBits == 8);
+             : (formatBits == 0x20000 || formatBits == 0x40000 ||
+                formatBits == 0x80000);
 }
 
 /// Check if the `LevelType` is the special undefined value.
@@ -251,33 +292,28 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
 
 /// Check if the `LevelType` is dense (regardless of properties).
 constexpr bool isDenseLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Dense);
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::Dense);
 }
 
 /// Check if the `LevelType` is compressed (regardless of properties).
 constexpr bool isCompressedLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Compressed);
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::Compressed);
 }
 
 /// Check if the `LevelType` is singleton (regardless of properties).
 constexpr bool isSingletonLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::Singleton);
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::Singleton);
 }
 
 /// Check if the `LevelType` is loose compressed (regardless of properties).
 constexpr bool isLooseCompressedLT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::LooseCompressed);
+  return (static_cast<uint64_t>(lt) & ~0xffff) ==
+         static_cast<uint64_t>(LevelType::LooseCompressed);
 }
 
-/// Check if the `LevelType` is 2OutOf4 (regardless of properties).
-constexpr bool is2OutOf4LT(LevelType lt) {
-  return (static_cast<uint8_t>(lt) & ~3) ==
-         static_cast<uint8_t>(LevelType::TwoOutOfFour);
-}
 
 /// Check if the `LevelType` needs positions array.
 constexpr bool isWithPosLT(LevelType lt) {
@@ -287,17 +323,17 @@ constexpr bool isWithPosLT(LevelType lt) {
 /// Check if the `LevelType` needs coordinates array.
 constexpr bool isWithCrdLT(LevelType lt) {
   return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-         is2OutOf4LT(lt);
+         isNOutOfMLT(lt);
 }
 
 /// Check if the `LevelType` is ordered (regardless of storage format).
 constexpr bool isOrderedLT(LevelType lt) {
-  return !(static_cast<uint8_t>(lt) & 2);
+  return !(static_cast<uint64_t>(lt) & 2);
 }
 
 /// Check if the `LevelType` is unique (regardless of storage format).
 constexpr bool isUniqueLT(LevelType lt) {
-  return !(static_cast<uint8_t>(lt) & 1);
+  return !(static_cast<uint64_t>(lt) & 1);
 }
 
 /// Convert a LevelType to its corresponding LevelFormat.
@@ -305,21 +341,25 @@ constexpr bool isUniqueLT(LevelType lt) {
 constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
   if (lt == LevelType::Undef)
     return std::nullopt;
-  return static_cast<LevelFormat>(static_cast<uint8_t>(lt) & ~3);
+  return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
 }
 
 /// Convert a LevelFormat to its corresponding LevelType with the given
 /// properties. Returns std::nullopt when the properties are not applicable
 /// for the input level format.
 constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
-                                                  bool unique) {
-  auto lt = static_cast<LevelType>(static_cast<uint8_t>(lf) |
-                                   (ordered ? 0 : 2) | (unique ? 0 : 1));
+                                                  bool unique, uint64_t n = 0,
+                                                  uint64_t m = 0) {
+  uint64_t newN = n << 32;
+  uint64_t newM = m << 40;
+  auto lt =
+      static_cast<LevelType>(static_cast<uint64_t>(lf) | (ordered ? 0 : 2) |
+                             (unique ? 0 : 1) | newN | newM);
   return isValidLT(lt) ? std::optional(lt) : std::nullopt;
 }
 
 //
-// Ensure the above methods work as indended.
+// Ensure the above methods work as intended.
 //
 
 static_assert(
@@ -341,7 +381,7 @@ static_assert(
          LevelFormat::LooseCompressed &&
      *getLevelFormat(LevelType::LooseCompressedNuNo) ==
          LevelFormat::LooseCompressed &&
-     *getLevelFormat(LevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
+     *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
     "getLevelFormat conversion is broken");
 
 static_assert(
@@ -373,13 +413,28 @@ static_assert(
          LevelType::LooseCompressedNo &&
      *buildLevelType(LevelFormat::LooseCompressed, false, false) ==
          LevelType::LooseCompressedNuNo &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
-     buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
-     *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
-         LevelType::TwoOutOfFour),
+     buildLevelType(LevelFormat::NOutOfM, false, true) == std::nullopt &&
+     buildLevelType(LevelFormat::NOutOfM, true, false) == std::nullopt &&
+     buildLevelType(LevelFormat::NOutOfM, false, false) == std::nullopt &&
+     *buildLevelType(LevelFormat::NOutOfM, true, true) == LevelType::NOutOfM),
     "buildLevelType conversion is broken");
 
+static_assert(
+    (getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 2 &&
+     getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 4 &&
+     getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 8 &&
+     getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 10),
+    "getN/M conversion is broken");
+
+static_assert(
+    (isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4),
+                      2, 4) &&
+     isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10),
+                      8, 10) &&
+     !isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 3, 4),
+                       2, 4)),
+    "isValidNOutOfMLT definition is broken");
+
 static_assert(
     (isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
      isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
@@ -391,7 +446,7 @@ static_assert(
      isValidLT(LevelType::LooseCompressedNu) &&
      isValidLT(LevelType::LooseCompressedNo) &&
      isValidLT(LevelType::LooseCompressedNuNo) &&
-     isValidLT(LevelType::TwoOutOfFour)),
+     isValidLT(LevelType::NOutOfM)),
     "isValidLT definition is broken");
 
 static_assert((isDenseLT(LevelType::Dense) &&
@@ -407,7 +462,7 @@ static_assert((isDenseLT(LevelType::Dense) &&
                !isDenseLT(LevelType::LooseCompressedNu) &&
                !isDenseLT(LevelType::LooseCompressedNo) &&
                !isDenseLT(LevelType::LooseCompressedNuNo) &&
-               !isDenseLT(LevelType::TwoOutOfFour)),
+               !isDenseLT(LevelType::NOutOfM)),
               "isDenseLT definition is broken");
 
 static_assert((!isCompressedLT(LevelType::Dense) &&
@@ -423,7 +478,7 @@ static_assert((!isCompressedLT(LevelType::Dense) &&
                !isCompressedLT(LevelType::LooseCompressedNu) &&
                !isCompressedLT(LevelType::LooseCompressedNo) &&
                !isCompressedLT(LevelType::LooseCompressedNuNo) &&
-               !isCompressedLT(LevelType::TwoOutOfFour)),
+               !isCompressedLT(LevelType::NOutOfM)),
               "isCompressedLT definition is broken");
 
 static_assert((!isSingletonLT(LevelType::Dense) &&
@@ -439,7 +494,7 @@ static_assert((!isSingletonLT(LevelType::Dense) &&
                !isSingletonLT(LevelType::LooseCompressedNu) &&
                !isSingletonLT(LevelType::LooseCompressedNo) &&
                !isSingletonLT(LevelType::LooseCompressedNuNo) &&
-               !isSingletonLT(LevelType::TwoOutOfFour)),
+               !isSingletonLT(LevelType::NOutOfM)),
               "isSingletonLT definition is broken");
 
 static_assert((!isLooseCompressedLT(LevelType::Dense) &&
@@ -455,24 +510,24 @@ static_assert((!isLooseCompressedLT(LevelType::Dense) &&
                isLooseCompressedLT(LevelType::LooseCompressedNu) &&
                isLooseCompressedLT(LevelType::LooseCompressedNo) &&
                isLooseCompressedLT(LevelType::LooseCompressedNuNo) &&
-               !isLooseCompressedLT(LevelType::TwoOutOfFour)),
+               !isLooseCompressedLT(LevelType::NOutOfM)),
               "isLooseCompressedLT definition is broken");
 
-static_assert((!is2OutOf4LT(LevelType::Dense) &&
-               !is2OutOf4LT(LevelType::Compressed) &&
-               !is2OutOf4LT(LevelType::CompressedNu) &&
-               !is2OutOf4LT(LevelType::CompressedNo) &&
-               !is2OutOf4LT(LevelType::CompressedNuNo) &&
-               !is2OutOf4LT(LevelType::Singleton) &&
-               !is2OutOf4LT(LevelType::SingletonNu) &&
-               !is2OutOf4LT(LevelType::SingletonNo) &&
-               !is2OutOf4LT(LevelType::SingletonNuNo) &&
-               !is2OutOf4LT(LevelType::LooseCompressed) &&
-               !is2OutOf4LT(LevelType::LooseCompressedNu) &&
-               !is2OutOf4LT(LevelType::LooseCompressedNo) &&
-               !is2OutOf4LT(LevelType::LooseCompressedNuNo) &&
-               is2OutOf4LT(LevelType::TwoOutOfFour)),
-              "is2OutOf4LT definition is broken");
+static_assert((!isNOutOfMLT(LevelType::Dense) &&
+               !isNOutOfMLT(LevelType::Compressed) &&
+               !isNOutOfMLT(LevelType::CompressedNu) &&
+               !isNOutOfMLT(LevelType::CompressedNo) &&
+               !isNOutOfMLT(LevelType::CompressedNuNo) &&
+               !isNOutOfMLT(LevelType::Singleton) &&
+               !isNOutOfMLT(LevelType::SingletonNu) &&
+               !isNOutOfMLT(LevelType::SingletonNo) &&
+               !isNOutOfMLT(LevelType::SingletonNuNo) &&
+               !isNOutOfMLT(LevelType::LooseCompressed) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNu) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNo) &&
+               !isNOutOfMLT(LevelType::LooseCompressedNuNo) &&
+               isNOutOfMLT(LevelType::NOutOfM)),
+              "isNOutOfMLT definition is broken");
 
 static_assert((isOrderedLT(LevelType::Dense) &&
                isOrderedLT(LevelType::Compressed) &&
@@ -487,7 +542,7 @@ static_assert((isOrderedLT(LevelType::Dense) &&
                isOrderedLT(LevelType::LooseCompressedNu) &&
                !isOrderedLT(LevelType::LooseCompressedNo) &&
                !isOrderedLT(LevelType::LooseCompressedNuNo) &&
-               isOrderedLT(LevelType::TwoOutOfFour)),
+               isOrderedLT(LevelType::NOutOfM)),
               "isOrderedLT definition is broken");
 
 static_assert((isUniqueLT(LevelType::Dense) &&
@@ -503,7 +558,7 @@ static_assert((isUniqueLT(LevelType::Dense) &&
                !isUniqueLT(LevelType::LooseCompressedNu) &&
                isUniqueLT(LevelType::LooseCompressedNo) &&
                !isUniqueLT(LevelType::LooseCompressedNuNo) &&
-               isUniqueLT(LevelType::TwoOutOfFour)),
+               isUniqueLT(LevelType::NOutOfM)),
               "isUniqueLT definition is broken");
 
 /// Bit manipulations for affine encoding.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 12c1068ae1f54..299ba0e603089 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -145,7 +145,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     - **compressed** : only nonzeros along this level are stored
     - **loose_compressed** : as compressed, but allows for free space between regions
     - **singleton** : a variant of the compressed format, where coordinates have no siblings
-    - **block2_4** : the compression uses a 2:4 encoding per 1x4 block
+    - **block[2, 4]** : the compression uses a 2:4 encoding per 1x4 block
 
     For a compressed level, each position interval is represented in a compact
     way with a lowerbound `pos(i)` and an upperbound `pos(i+1) - 1`, which implies
@@ -374,7 +374,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedLT(getLvlType(l)); }
     bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonLT(getLvlType(l)); }
     bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedLT(getLvlType(l)); }
-    bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4LT(getLvlType(l)); }
+    bool isNOutOfMLvl(::mlir::sparse_tensor::Level l) const { return isNOutOfMLT(getLvlType(l)); }
     bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedLT(getLvlType(l)); }
     bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueLT(getLvlType(l)); }
 
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4c98129744bcd..4e2b85d35c1ac 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -291,7 +291,7 @@ class SparseTensorType {
     return isLooseCompressedLT(getLvlType(l));
   }
   bool isSingletonLvl(Level l) const { return isSingletonLT(getLvlType(l)); }
-  bool is2OutOf4Lvl(Level l) const { return is2OutOf4LT(getLvlType(l)); }
+  bool isNOutOfMLvl(Level l) const { return isNOutOfMLT(getLvlType(l)); }
   bool isOrderedLvl(Level l) const { return isOrderedLT(getLvlType(l)); }
   bool isUniqueLvl(Level l) const { return isUniqueLT(getLvlType(l)); }
   bool isWithPos(Level l) const { return isWithPosLT(getLvlType(l)); }
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 4a34bb2e003e8..490ef3071af1b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -510,7 +510,7 @@ class Merger {
     if (isLvlWithNonTrivialIdxExp(b)) {
       auto lt = getLoopDependentLevelType(b);
       return isCompressedLT(lt) || isSingletonLT(lt) ||
-             isLooseCompressedLT(lt) || is2OutOf4LT(lt);
+             isLooseCompressedLT(lt) || isNOutOfMLT(lt);
     }
     return false;
   }
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 01c5f2382ffe6..1d8d9bcfb3b2c 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -124,7 +124,7 @@ class SparseTensorStorageBase {
   bool isSingletonLvl(uint64_t l) const { return isSingletonLT(getLvlType(l)); }
 
   /// Safely checks if the level uses 2 out of 4 storage.
-  bool is2OutOf4Lvl(uint64_t l) const { return is2OutOf4LT(getLvlType(l)); }
+  bool isNOutOfMLvl(uint64_t l) const { return isNOutOfMLT(getLvlType(l)); }
 
   /// Safely checks if the level is ordered.
   bool isOrderedLvl(uint64_t l) const { return isOrderedLT(getLvlType(l)); }
@@ -450,7 +450,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
   void appendCrd(uint64_t lvl, uint64_t full, uint64_t crd) {
     if (!isDenseLvl(lvl)) {
       assert(isCompressedLvl(lvl) || isLooseCompressedLvl(lvl) ||
-             isSingletonLvl(lvl) || is2OutOf4Lvl(lvl));
+             isSingletonLvl(lvl) || isNOutOfMLvl(lvl));
       coordinates[lvl].push_back(detail::checkOverflowCast<C>(crd));
     } else { // Dense level.
       assert(crd >= full && "Coordinate was already filled");
@@ -473,7 +473,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
       return positions[l][parentSz];
     if (isLooseCompressedLvl(l))
       return positions[l][2 * parentSz - 1];
-    if (isSingletonLvl(l) || is2OutOf4Lvl(l))
+    if (isSingletonLvl(l) || isNOutOfMLvl(l))
       return parentSz; // new size same as the parent
     assert(isDenseLvl(l));
     return parentSz * getLvlSize(l);
@@ -527,7 +527,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
       uint64_t pos = coordinates[l].size();
       positions[l].insert(positions[l].end(), 2 * count,
                           detail::checkOverflowCast<P>(pos));
-    } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
+    } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
       return; // Nothing to finalize.
     } else {  // Dense dimension.
       assert(isDenseLvl(l));
@@ -624,7 +624,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
         lvlCursor[l] = static_cast<uint64_t>(coordinatesL[pos]);
         toCOO(pos, l + 1, dimCoords);
       }
-    } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
+    } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
       assert(parentPos < coordinates[l].size());
       lvlCursor[l] = static_cast<uint64_t>(coordinates[l][parentPos]);
       toCOO(parentPos, l + 1, dimCoords);
@@ -721,7 +721,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
     } else if (isSingletonLvl(l)) {
       coordinates[l].reserve(sz);
       sz = 1;
-    } else if (is2OutOf4Lvl(l)) {
+    } else if (isNOutOfMLvl(l)) {
       assert(l == lvlRank - 1 && "unexpected 2:4 usage");
       sz = detail::checkedMul(sz, lvlSizes[l]) / 2;
       coordinates[l].reserve(sz);
@@ -791,7 +791,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
       }
     } else if (isSingletonLvl(l)) {
       assert(0 && "general singleton not supported yet");
-    } else if (is2OutOf4Lvl(l)) {
+    } else if (isNOutOfMLvl(l)) {
       assert(0 && "2Out4 not supported yet");
     } else {
       assert(isDenseLvl(l));
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 8706c523988b1..f68d77dc129ad 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -25,7 +25,7 @@ using namespace mlir::python::adaptors;
 static void populateDialectSparseTensorSubmodule(const py::module &m) {
   py::enum_<MlirSparseTensorLevelType>(m, "LevelType", py::module_local())
       .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
-      .value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR)
+      .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
       .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
       .value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU)
       .value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO)
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index e4534ad132385..a34b9a29b0e90 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -20,25 +20,36 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
                                       mlir::sparse_tensor::SparseTensorDialect)
 
 // Ensure the C-API enums are int-castable to C++ equivalents.
-static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
-                      static_cast<int>(LevelType::Dense) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
-                      static_cast<int>(LevelType::Compressed) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
-                      static_cast<int>(LevelType::CompressedNu) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
-                      static_cast<int>(LevelType::CompressedNo) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
-                      static_cast<int>(LevelType::CompressedNuNo) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
-                      static_cast<int>(LevelType::Singleton) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
-                      static_cast<int>(LevelType::SingletonNu) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
-                      static_cast<int>(LevelType::SingletonNo) &&
-                  static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
-                      static_cast<int>(LevelType::SingletonNuNo),
-              "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
+static_assert(
+    static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
+            static_cast<int>(LevelType::Dense) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
+            static_cast<int>(LevelType::Compressed) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
+            static_cast<int>(LevelType::CompressedNu) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
+            static_cast<int>(LevelType::CompressedNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
+            static_cast<int>(LevelType::CompressedNuNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
+            static_cast<int>(LevelType::Singleton) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
+            static_cast<int>(LevelType::SingletonNu) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
+            static_cast<int>(LevelType::SingletonNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
+            static_cast<int>(LevelType::SingletonNuNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) ==
+            static_cast<int>(LevelType::LooseCompressed) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) ==
+            static_cast<int>(LevelType::LooseCompressedNu) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) ==
+            static_cast<int>(LevelType::LooseCompressedNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO) ==
+            static_cast<int>(LevelType::LooseCompressedNuNo) &&
+        static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) ==
+            static_cast<int>(LevelType::NOutOfM),
+    "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
 
 bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
   return isa<SparseTensorEncodingAttr>(unwrap(attr));
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 56b435c57d30a..95874d4857fc8 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -299,6 +299,8 @@ ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
   FAILURE_IF_FAILED(type)
 
+  // Record the level spec; `*type` has been checked for failure above and is
+  // stored as a LevelType.
   lvlSpecs.emplace_back(var, expr, static_cast<LevelType>(*type));
   return success();
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index eb7ea63a4e88b..14ebe14b49f64 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -29,12 +29,21 @@ using namespace mlir::sparse_tensor::ir_detail;
 // `LvlTypeParser` implementation.
 //===----------------------------------------------------------------------===//
 
-FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
+FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   StringRef base;
   const auto loc = parser.getCurrentLocation();
   ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
            "expected valid level format (e.g. dense, compressed or singleton)")
-  uint8_t properties = 0;
+  uint64_t properties = 0;
+  SmallVector<unsigned> blockSizes;
+
+  if (base.compare("block") == 0) {
+    ParseResult res = parser.parseCommaSeparatedList(
+        mlir::OpAsmParser::Delimiter::OptionalSquare,
+        [&]() -> ParseResult { return parseBlockSize(parser, &blockSizes); },
+        " in block n out of m");
+    FAILURE_IF_FAILED(res)
+  }
 
   ParseResult res = parser.parseCommaSeparatedList(
       mlir::OpAsmParser::Delimiter::OptionalParen,
@@ -44,15 +53,21 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
 
   // Set the base bit for properties.
   if (base.compare("dense") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::Dense);
+    properties |= static_cast<uint64_t>(LevelFormat::Dense);
   } else if (base.compare("compressed") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::Compressed);
-  } else if (base.compare("block2_4") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::TwoOutOfFour);
+    properties |= static_cast<uint64_t>(LevelFormat::Compressed);
+  } else if (base.compare("block") == 0) {
+    if (blockSizes.size() != 2) {
+      parser.emitError(loc, "expected exactly 2 block sizes");
+      return failure();
+    }
+    // Encode the format together with the n and m block sizes.
+    properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
+    properties |= nToBits(blockSizes[0]) | mToBits(blockSizes[1]);
   } else if (base.compare("loose_compressed") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::LooseCompressed);
+    properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
   } else if (base.compare("singleton") == 0) {
-    properties |= static_cast<uint8_t>(LevelFormat::Singleton);
+    properties |= static_cast<uint64_t>(LevelFormat::Singleton);
   } else {
     parser.emitError(loc, "unknown level format: ") << base;
     return failure();
@@ -64,15 +79,15 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
 }
 
 ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
-                                         uint8_t *properties) const {
+                                         uint64_t *properties) const {
   StringRef strVal;
   auto loc = parser.getCurrentLocation();
   ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
            "expected valid level property (e.g. nonordered, nonunique or high)")
   if (strVal.compare("nonunique") == 0) {
-    *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonunique);
+    *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique);
   } else if (strVal.compare("nonordered") == 0) {
-    *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonordered);
+    *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered);
   } else {
     parser.emitError(loc, "unknown level property: ") << strVal;
     return failure();
@@ -80,4 +95,31 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
   return success();
 }
 
+/// Parses a single integer block size and appends it to `blockSizes`.
+/// Emits a diagnostic and returns failure when no integer is present, when
+/// the integer cannot be parsed, or when it is negative.
+ParseResult
+LvlTypeParser::parseBlockSize(AsmParser &parser,
+                              SmallVector<unsigned> *blockSizes) const {
+  int intVal;
+  auto loc = parser.getCurrentLocation();
+  OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
+  if (intValParseResult.has_value()) {
+    if (failed(*intValParseResult)) {
+      parser.emitError(loc, "failed to parse block size");
+      return failure();
+    }
+    // Reject negative values: they would silently wrap when stored as
+    // `unsigned` below.
+    if (intVal < 0) {
+      parser.emitError(loc, "expected block size to be >= 0");
+      return failure();
+    }
+    blockSizes->push_back(intVal);
+    return success();
+  }
+  parser.emitError(loc, "expected valid integer for block size");
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
index 5e2f11b75d4da..78ae667f97923 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
@@ -18,10 +18,12 @@ namespace ir_detail {
 class LvlTypeParser {
 public:
   LvlTypeParser() = default;
-  FailureOr<uint8_t> parseLvlType(AsmParser &parser) const;
+  FailureOr<uint64_t> parseLvlType(AsmParser &parser) const;
 
 private:
-  ParseResult parseProperty(AsmParser &parser, uint8_t *properties) const;
+  ParseResult parseProperty(AsmParser &parser, uint64_t *properties) const;
+  ParseResult parseBlockSize(AsmParser &parser,
+                             SmallVector<unsigned> *blockSizes) const;
 };
 
 } // namespace ir_detail
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 87a37a7926e9e..23676eccdfb28 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -451,7 +451,7 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) {
 /// Test for 2:4 matrix with suitable metadata.
 static bool isAdmissible24(SparseTensorType &aTp) {
   return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
-         aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp);
+         aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp);
 }
 
 /// Test for conversion into 2:4 matrix.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 491501a3381b9..d4459c6ea1e52 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -130,7 +130,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
       createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
                      /*value=*/posZero, /*repeat=*/linear);
       return;
-    } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
+    } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
       return; // nothing to do
     }
     // Keep compounding the size, but nothing needs to be initialized
@@ -409,7 +409,7 @@ static void genEndInsert(OpBuilder &builder, Location loc,
       }
     } else {
       assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
-             is2OutOf4LT(lt));
+             isNOutOfMLT(lt));
     }
   }
 }
@@ -488,7 +488,7 @@ class SparseInsertGenerator
         }
         parentPos =
             genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
-      } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
+      } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
         // Create:
         //   coordinates[lvl].push_back(coords[lvl])
         //   positions[lvl] = positions[lvl-1]
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 5266ca7213bfc..cc39f21001168 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -891,7 +891,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
         assert(curr == env.merger().loop(b));
         Value clause;
         if (isCompressedLT(lt) || isSingletonLT(lt) ||
-            isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
+            isLooseCompressedLT(lt) || isNOutOfMLT(lt)) {
           assert(lvl.has_value());
           const Value crd = env.emitter().getCoord(tid, *lvl);
           const Value lvar = env.getLoopVar(curr);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index 8d54b5959d871..cc119bc704559 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -423,7 +423,7 @@ inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
 /// Generates a constant of the internal dimension level type encoding.
 inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
                                        LevelType lt) {
-  return constantI8(builder, loc, static_cast<uint8_t>(lt));
+  return constantI64(builder, loc, static_cast<uint64_t>(lt));
 }
 
 inline bool isZeroRankedTensorOrScalar(Type type) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 98323c2195461..39e9a532eb5e8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -1205,7 +1205,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
     Value crd = genToCoordinates(b, l, t, lvl);
     return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
   }
-  case LevelFormat::TwoOutOfFour: {
+  case LevelFormat::NOutOfM: {
     Value crd = genToCoordinates(b, l, t, lvl);
     return std::make_unique<TwoOutFourLevel>(tid, lvl, lt, sz, crd);
   }
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 6cdf5f8c0168b..96537cbb0c483 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -489,7 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
     if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
       const auto lt = getLvlType(b);
       if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
-          !isLooseCompressedLT(lt) && !is2OutOf4LT(lt)) {
+          !isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
         if (reset)
           simple.reset(b);
         reset = true;
@@ -670,7 +670,7 @@ bool Merger::hasAnySparse(const BitVector &bits) const {
   for (TensorLoopId b : bits.set_bits()) {
     const auto lt = getLvlType(b);
     if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-        is2OutOf4LT(lt))
+        isNOutOfMLT(lt))
       return true;
   }
   return hasSparseIdxReduction(bits);
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 0c7b3a228a65c..9e8b240899d80 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -45,7 +45,7 @@ SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
   for (uint64_t l = 0; l < lvlRank; l++) {
     assert(lvlSizes[l] > 0 && "Level size zero has trivial storage");
     assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) ||
-           isSingletonLvl(l) || is2OutOf4Lvl(l));
+           isSingletonLvl(l) || isNOutOfMLvl(l));
   }
 }
 
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index b0bc9bb6e881a..ea4c56b7ec0c5 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -37,9 +37,9 @@ static int testRoundtripEncoding(MlirContext ctx) {
       mlirSparseTensorEncodingAttrGetDimToLvl(originalAttr);
   // CHECK: (d0, d1)[s0] -> (s0, d0, d1)
   mlirAffineMapDump(dimToLvl);
-  // CHECK: level_type: 4
-  // CHECK: level_type: 8
-  // CHECK: level_type: 8
+  // CHECK: level_type: 65536
+  // CHECK: level_type: 131072
+  // CHECK: level_type: 131072
   MlirAffineMap lvlToDim =
       mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
   int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
index 6fe7ec906f30e..e4884a9bf393f 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
@@ -4,7 +4,7 @@
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block2_4
+    j mod 4      : block[2, 4]
   )
 }>
 
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index e4e825bf85043..465f210862660 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -78,8 +78,8 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind
 //   CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi64>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimShape]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -96,8 +96,8 @@ func.func @sparse_new1d(%arg0: !llvm.ptr) -> tensor<128xf64, #SparseVector> {
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 //       CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
 //       CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -114,8 +114,8 @@ func.func @sparse_new2d(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CSR> {
 //   CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
 //       CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
 //       CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi64>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
 //   CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
@@ -136,10 +136,10 @@ func.func @sparse_new3d(%arg0: !llvm.ptr) -> tensor<?x?x?xf32, #SparseTensor> {
 //   CHECK-DAG: %[[Empty:.*]] = arith.constant 0 : i32
 //   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 //   CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
+//   CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
 //   CHECK-DAG: %[[Sizes0:.*]] = memref.alloca() : memref<2xindex>
 //   CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
-//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+//   CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
 //   CHECK-DAG: %[[Sizes:.*]] = memref.cast %[[Sizes0]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
 //   CHECK-DAG: memref.store %[[I]], %[[Sizes0]][%[[C0]]] : memref<2xindex>
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 20702bb985028..966a7ff2d38e1 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -207,12 +207,12 @@ func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block2_4
+    j mod 4      : block[2, 4]
   ),
   crdWidth = 8  // we would even like just 2-bits
 }>
 
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : block2_4), crdWidth = 8 }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : block[2, 4]), crdWidth = 8 }>
 // CHECK-LABEL: func private @NV_24(
 // CHECK-SAME: tensor<?x?xf64, #[[$NV_24]]>
 func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
@@ -226,11 +226,11 @@ func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
   ( i            : dense,
     j            : dense,
     k floordiv 4 : dense,
-    k mod 4      : block2_4
+    k mod 4      : block[2, 4]
   )
 }>
 
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block2_4) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block[2, 4]) }>
 // CHECK-LABEL: func private @NV_24(
 // CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
 func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
@@ -244,11 +244,11 @@ func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
   ( i            : dense,
     k floordiv 4 : dense,
     j            : dense,
-    k mod 4      : block2_4
+    k mod 4      : block[2, 4]
   )
 }>
 
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block2_4) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block[2, 4]) }>
 // CHECK-LABEL: func private @NV_24(
 // CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
 func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 40367f12f85a4..d04fbe8ed5c22 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,11 +14,11 @@
 // CHECK-DAG:       %[[VAL_8:.*]] = arith.constant true
 // CHECK-DAG:       %[[VAL_9:.*]] = arith.constant 100 : index
 // CHECK-DAG:       %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 8 : i8
-// CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
-// CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8>
-// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>
-// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi8>
+// CHECK-DAG:       %[[VAL_11:.*]] = arith.constant 131072 : i64
+// CHECK:           %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
+// CHECK:           %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
+// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
+// CHECK:           memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi64>
 // CHECK:           %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
 // CHECK:           %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
 // CHECK:           memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
@@ -28,7 +28,7 @@
 // CHECK:           memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
 // CHECK:           memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
 // CHECK:           %[[VAL_18:.*]] = llvm.mlir.zero : !llvm.ptr
-// CHECK:           %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
+// CHECK:           %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
 // CHECK:           %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
 // CHECK:           %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64>
 // CHECK:           %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 4bc080fc538fc..554d6207aef7e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -59,7 +59,7 @@
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block2_4
+    j mod 4      : block[2, 4]
   ),
 }>
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
index df5b48a3b6ece..9935d7c69e63a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
@@ -41,7 +41,7 @@
 #NV_24 = #sparse_tensor.encoding<{
   map = ( i, j ) -> ( i            : dense,
                       j floordiv 4 : dense,
-                      j mod 4      : block2_4),
+                      j mod 4      : block[2, 4]),
   crdWidth = 8
 }>
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
index 17b50b46d073a..25454f5c06b45 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
@@ -20,7 +20,7 @@
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block2_4
+    j mod 4      : block[2, 4]
   )
 }>
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
index eb99a027a8860..da735b4a3b58a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
@@ -20,7 +20,7 @@
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block2_4
+    j mod 4      : block[2, 4]
   )
 }>
 
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 88a5595d75aea..e9296b961e7fe 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -13,85 +13,85 @@ def run(f):
 # CHECK-LABEL: TEST: testEncodingAttr1D
 @run
 def testEncodingAttr1D():
-    with Context() as ctx:
-        parsed = Attribute.parse(
-            "#sparse_tensor.encoding<{"
-            "  map = (d0) -> (d0 : compressed),"
-            "  posWidth = 16,"
-            "  crdWidth = 32"
-            "}>"
-        )
-        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
-        print(parsed)
-
-        casted = st.EncodingAttr(parsed)
-        # CHECK: equal: True
-        print(f"equal: {casted == parsed}")
-
-        # CHECK: lvl_types: [<LevelType.compressed: 8>]
-        print(f"lvl_types: {casted.lvl_types}")
-        # CHECK: dim_to_lvl: (d0) -> (d0)
-        print(f"dim_to_lvl: {casted.dim_to_lvl}")
-        # CHECK: lvl_to_dim: (d0) -> (d0)
-        print(f"lvl_to_dim: {casted.lvl_to_dim}")
-        # CHECK: pos_width: 16
-        print(f"pos_width: {casted.pos_width}")
-        # CHECK: crd_width: 32
-        print(f"crd_width: {casted.crd_width}")
-
-        created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
-        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
-        print(created)
-        # CHECK: created_equal: False
-        print(f"created_equal: {created == casted}")
-
-        # Verify that the factory creates an instance of the proper type.
-        # CHECK: is_proper_instance: True
-        print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
-        # CHECK: created_pos_width: 0
-        print(f"created_pos_width: {created.pos_width}")
+  with Context() as ctx:
+    parsed = Attribute.parse(
+        "#sparse_tensor.encoding<{"
+        "  map = (d0) -> (d0 : compressed),"
+        "  posWidth = 16,"
+        "  crdWidth = 32"
+        "}>"
+    )
+    # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
+    print(parsed)
+
+    casted = st.EncodingAttr(parsed)
+    # CHECK: equal: True
+    print(f"equal: {casted == parsed}")
+
+    # CHECK: lvl_types: [<LevelType.compressed: 131072>]
+    print(f"lvl_types: {casted.lvl_types}")
+    # CHECK: dim_to_lvl: (d0) -> (d0)
+    print(f"dim_to_lvl: {casted.dim_to_lvl}")
+    # CHECK: lvl_to_dim: (d0) -> (d0)
+    print(f"lvl_to_dim: {casted.lvl_to_dim}")
+    # CHECK: pos_width: 16
+    print(f"pos_width: {casted.pos_width}")
+    # CHECK: crd_width: 32
+    print(f"crd_width: {casted.crd_width}")
+
+    created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
+    # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+    print(created)
+    # CHECK: created_equal: False
+    print(f"created_equal: {created == casted}")
+
+    # Verify that the factory creates an instance of the proper type.
+    # CHECK: is_proper_instance: True
+    print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+    # CHECK: created_pos_width: 0
+    print(f"created_pos_width: {created.pos_width}")
 
 
 # CHECK-LABEL: TEST: testEncodingAttr2D
 @run
 def testEncodingAttr2D():
-    with Context() as ctx:
-        parsed = Attribute.parse(
-            "#sparse_tensor.encoding<{"
-            "  map = (d0, d1) -> (d1 : dense, d0 : compressed),"
-            "  posWidth = 8,"
-            "  crdWidth = 32"
-            "}>"
-        )
-        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
-        print(parsed)
-
-        casted = st.EncodingAttr(parsed)
-        # CHECK: equal: True
-        print(f"equal: {casted == parsed}")
-
-        # CHECK: lvl_types: [<LevelType.dense: 4>, <LevelType.compressed: 8>]
-        print(f"lvl_types: {casted.lvl_types}")
-        # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
-        print(f"dim_to_lvl: {casted.dim_to_lvl}")
-        # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
-        print(f"lvl_to_dim: {casted.lvl_to_dim}")
-        # CHECK: pos_width: 8
-        print(f"pos_width: {casted.pos_width}")
-        # CHECK: crd_width: 32
-        print(f"crd_width: {casted.crd_width}")
-
-        created = st.EncodingAttr.get(
-            casted.lvl_types,
-            casted.dim_to_lvl,
-            casted.lvl_to_dim,
-            8,
-            32,
-        )
-        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
-        print(created)
-        # CHECK: created_equal: True
-        print(f"created_equal: {created == casted}")
+  with Context() as ctx:
+    parsed = Attribute.parse(
+        "#sparse_tensor.encoding<{"
+        "  map = (d0, d1) -> (d1 : dense, d0 : compressed),"
+        "  posWidth = 8,"
+        "  crdWidth = 32"
+        "}>"
+    )
+    # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+    print(parsed)
+
+    casted = st.EncodingAttr(parsed)
+    # CHECK: equal: True
+    print(f"equal: {casted == parsed}")
+
+    # CHECK: lvl_types: [<LevelType.dense: 65536>, <LevelType.compressed: 131072>]
+    print(f"lvl_types: {casted.lvl_types}")
+    # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
+    print(f"dim_to_lvl: {casted.dim_to_lvl}")
+    # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
+    print(f"lvl_to_dim: {casted.lvl_to_dim}")
+    # CHECK: pos_width: 8
+    print(f"pos_width: {casted.pos_width}")
+    # CHECK: crd_width: 32
+    print(f"crd_width: {casted.crd_width}")
+
+    created = st.EncodingAttr.get(
+        casted.lvl_types,
+        casted.dim_to_lvl,
+        casted.lvl_to_dim,
+        8,
+        32,
+    )
+    # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+    print(created)
+    # CHECK: created_equal: True
+    print(f"created_equal: {created == casted}")
 
 
 # CHECK-LABEL: TEST: testEncodingAttrOnTensorType

>From 8756b765741d661a4519d2d7cdb00b5cd440b540 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 30 Jan 2024 02:58:18 +0000
Subject: [PATCH 02/10] format

---
 .../include/mlir/Dialect/SparseTensor/IR/Enums.h |  8 ++------
 .../SparseTensor/IR/Detail/DimLvlMapParser.cpp   |  2 --
 .../SparseTensor/IR/Detail/LvlTypeParser.cpp     |  1 -
 .../SparseTensor/IR/SparseTensorDialect.cpp      | 16 ++++++++++++++--
 4 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 6ddc9326179fe..15802a5ad3563 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -233,7 +233,7 @@ constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
 }
 
 /// Returns string representation of the given dimension level type.
-std::string toMLIRString(LevelType lt) {
+constexpr const char *toMLIRString(LevelType lt) {
   switch (lt) {
   case LevelType::Undef:
     return "undef";
@@ -264,12 +264,8 @@ std::string toMLIRString(LevelType lt) {
   case LevelType::LooseCompressedNuNo:
     return "loose_compressed(nonunique, nonordered)";
   default:
-    // If NOutOfM bit is set, print the [n, m] sizes.
     if (isNOutOfMLT(lt)) {
-      unsigned n = getN(lt);
-      unsigned m = getM(lt);
-      return std::string("block[") + std::to_string(n) + ", " +
-             std::to_string(m) + "]";
+      return "block";
     }
   }
   return "";
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 95874d4857fc8..56b435c57d30a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -299,8 +299,6 @@ ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
   FAILURE_IF_FAILED(type)
 
   lvlSpecs.emplace_back(var, expr, static_cast<LevelType>(*type));
-  llvm::errs() << "type = " << toMLIRString(static_cast<LevelType>(*type))
-               << "\n";
   return success();
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 14ebe14b49f64..993ad9be8a012 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -63,7 +63,6 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
     }
     properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
     properties |= nToBits(blockSizes[0]) | mToBits(blockSizes[1]);
-    llvm::errs() << "properties1: " << properties << "\n";
   } else if (base.compare("loose_compressed") == 0) {
     properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
   } else if (base.compare("singleton") == 0) {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6033ebf6897ce..d56d90a2d6130 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -613,16 +613,28 @@ void SparseTensorEncodingAttr::printDimensions(
   }
 }
 
+std::string getNOutOfMString(LevelType lt) {
+  if (isNOutOfMLT(lt)) {
+    unsigned n = getN(lt);
+    unsigned m = getM(lt);
+    auto output = "[" + std::to_string(n) + ", " + std::to_string(m) + "]";
+    return output;
+  }
+  return "";
+}
+
 void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
                                            ArrayRef<LevelType> lvlTypes) const {
   for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
     map.getResult(i).print(printer.getStream());
-    printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
+    printer << " : " << toMLIRString(lvlTypes[i])
+            << getNOutOfMString(lvlTypes[i]) << ", ";
   }
   if (map.getNumResults() >= 1) {
     auto lastIndex = map.getNumResults() - 1;
     map.getResult(lastIndex).print(printer.getStream());
-    printer << " : " << toMLIRString(lvlTypes[lastIndex]);
+    printer << " : " << toMLIRString(lvlTypes[lastIndex])
+            << getNOutOfMString(lvlTypes[lastIndex]);
   }
 }
 

>From 4e59b51b6d89ca3947934d52334ebe92a7f8e028 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 30 Jan 2024 03:13:30 +0000
Subject: [PATCH 03/10] python format

---
 .../mlir/Dialect/SparseTensor/IR/Enums.h      |   1 -
 .../python/dialects/sparse_tensor/dialect.py  | 148 +++++++++---------
 2 files changed, 74 insertions(+), 75 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 15802a5ad3563..99443957d01d5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -310,7 +310,6 @@ constexpr bool isLooseCompressedLT(LevelType lt) {
          static_cast<uint64_t>(LevelType::LooseCompressed);
 }
 
-
 /// Check if the `LevelType` needs positions array.
 constexpr bool isWithPosLT(LevelType lt) {
   return isCompressedLT(lt) || isLooseCompressedLT(lt);
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index e9296b961e7fe..75c47a57f78af 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -13,85 +13,85 @@ def run(f):
 # CHECK-LABEL: TEST: testEncodingAttr1D
 @run
 def testEncodingAttr1D():
-  with Context() as ctx:
-    parsed = Attribute.parse(
-        "#sparse_tensor.encoding<{"
-        "  map = (d0) -> (d0 : compressed),"
-        "  posWidth = 16,"
-        "  crdWidth = 32"
-        "}>"
-    )
-    # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
-    print(parsed)
-
-    casted = st.EncodingAttr(parsed)
-    # CHECK: equal: True
-    print(f"equal: {casted == parsed}")
-
-    # CHECK: lvl_types: [<LevelType.compressed: 131072>]
-    print(f"lvl_types: {casted.lvl_types}")
-    # CHECK: dim_to_lvl: (d0) -> (d0)
-    print(f"dim_to_lvl: {casted.dim_to_lvl}")
-    # CHECK: lvl_to_dim: (d0) -> (d0)
-    print(f"lvl_to_dim: {casted.lvl_to_dim}")
-    # CHECK: pos_width: 16
-    print(f"pos_width: {casted.pos_width}")
-    # CHECK: crd_width: 32
-    print(f"crd_width: {casted.crd_width}")
-
-    created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
-    # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
-    print(created)
-    # CHECK: created_equal: False
-    print(f"created_equal: {created == casted}")
-
-    # Verify that the factory creates an instance of the proper type.
-    # CHECK: is_proper_instance: True
-    print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
-    # CHECK: created_pos_width: 0
-    print(f"created_pos_width: {created.pos_width}")
+    with Context() as ctx:
+        parsed = Attribute.parse(
+            "#sparse_tensor.encoding<{"
+            "  map = (d0) -> (d0 : compressed),"
+            "  posWidth = 16,"
+            "  crdWidth = 32"
+            "}>"
+        )
+        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
+        print(parsed)
+
+        casted = st.EncodingAttr(parsed)
+        # CHECK: equal: True
+        print(f"equal: {casted == parsed}")
+
+        # CHECK: lvl_types: [<LevelType.compressed: 131072>]
+        print(f"lvl_types: {casted.lvl_types}")
+        # CHECK: dim_to_lvl: (d0) -> (d0)
+        print(f"dim_to_lvl: {casted.dim_to_lvl}")
+        # CHECK: lvl_to_dim: (d0) -> (d0)
+        print(f"lvl_to_dim: {casted.lvl_to_dim}")
+        # CHECK: pos_width: 16
+        print(f"pos_width: {casted.pos_width}")
+        # CHECK: crd_width: 32
+        print(f"crd_width: {casted.crd_width}")
+
+        created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
+        # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+        print(created)
+        # CHECK: created_equal: False
+        print(f"created_equal: {created == casted}")
+
+        # Verify that the factory creates an instance of the proper type.
+        # CHECK: is_proper_instance: True
+        print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+        # CHECK: created_pos_width: 0
+        print(f"created_pos_width: {created.pos_width}")
 
 
 # CHECK-LABEL: TEST: testEncodingAttr2D
 @run
 def testEncodingAttr2D():
-  with Context() as ctx:
-    parsed = Attribute.parse(
-        "#sparse_tensor.encoding<{"
-        "  map = (d0, d1) -> (d1 : dense, d0 : compressed),"
-        "  posWidth = 8,"
-        "  crdWidth = 32"
-        "}>"
-    )
-    # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
-    print(parsed)
-
-    casted = st.EncodingAttr(parsed)
-    # CHECK: equal: True
-    print(f"equal: {casted == parsed}")
-
-    # CHECK: lvl_types: [<LevelType.dense: 65536>, <LevelType.compressed: 131072>]
-    print(f"lvl_types: {casted.lvl_types}")
-    # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
-    print(f"dim_to_lvl: {casted.dim_to_lvl}")
-    # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
-    print(f"lvl_to_dim: {casted.lvl_to_dim}")
-    # CHECK: pos_width: 8
-    print(f"pos_width: {casted.pos_width}")
-    # CHECK: crd_width: 32
-    print(f"crd_width: {casted.crd_width}")
-
-    created = st.EncodingAttr.get(
-        casted.lvl_types,
-        casted.dim_to_lvl,
-        casted.lvl_to_dim,
-        8,
-        32,
-    )
-    # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
-    print(created)
-    # CHECK: created_equal: True
-    print(f"created_equal: {created == casted}")
+    with Context() as ctx:
+        parsed = Attribute.parse(
+            "#sparse_tensor.encoding<{"
+            "  map = (d0, d1) -> (d1 : dense, d0 : compressed),"
+            "  posWidth = 8,"
+            "  crdWidth = 32"
+            "}>"
+        )
+        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+        print(parsed)
+
+        casted = st.EncodingAttr(parsed)
+        # CHECK: equal: True
+        print(f"equal: {casted == parsed}")
+
+        # CHECK: lvl_types: [<LevelType.dense: 65536>, <LevelType.compressed: 131072>]
+        print(f"lvl_types: {casted.lvl_types}")
+        # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
+        print(f"dim_to_lvl: {casted.dim_to_lvl}")
+        # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
+        print(f"lvl_to_dim: {casted.lvl_to_dim}")
+        # CHECK: pos_width: 8
+        print(f"pos_width: {casted.pos_width}")
+        # CHECK: crd_width: 32
+        print(f"crd_width: {casted.crd_width}")
+
+        created = st.EncodingAttr.get(
+            casted.lvl_types,
+            casted.dim_to_lvl,
+            casted.lvl_to_dim,
+            8,
+            32,
+        )
+        # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+        print(created)
+        # CHECK: created_equal: True
+        print(f"created_equal: {created == casted}")
 
 
 # CHECK-LABEL: TEST: testEncodingAttrOnTensorType

>From 4ba9d91ae04004519d451e953f208b9a0343a850 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 30 Jan 2024 19:37:09 +0000
Subject: [PATCH 04/10] more edits for n:m

---
 .../SparseTensor/IR/SparseTensorAttrDefs.td    |  2 +-
 .../ExecutionEngine/SparseTensor/Storage.h     |  4 ++--
 .../SparseTensor/IR/Detail/LvlTypeParser.cpp   |  6 +++---
 .../SparseTensor/IR/Detail/LvlTypeParser.h     |  4 ++--
 .../SparseTensor/roundtrip_encoding.mlir       | 18 ++++++++++++++++++
 5 files changed, 26 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 299ba0e603089..08ba96d437045 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -145,7 +145,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     - **compressed** : only nonzeros along this level are stored
     - **loose_compressed** : as compressed, but allows for free space between regions
     - **singleton** : a variant of the compressed format, where coordinates have no siblings
-    - **block[2, 4]** : the compression uses a 2:4 encoding per 1x4 block
+    - **block[n, m]** : the compression uses a n:m encoding per 1xm block
 
     For a compressed level, each position interval is represented in a compact
     way with a lowerbound `pos(i)` and an upperbound `pos(i+1) - 1`, which implies
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 1d8d9bcfb3b2c..14809381517e7 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -123,7 +123,7 @@ class SparseTensorStorageBase {
   /// Safely checks if the level uses singleton storage.
   bool isSingletonLvl(uint64_t l) const { return isSingletonLT(getLvlType(l)); }
 
-  /// Safely checks if the level uses 2 out of 4 storage.
+  /// Safely checks if the level uses n out of m storage.
   bool isNOutOfMLvl(uint64_t l) const { return isNOutOfMLT(getLvlType(l)); }
 
   /// Safely checks if the level is ordered.
@@ -792,7 +792,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
     } else if (isSingletonLvl(l)) {
       assert(0 && "general singleton not supported yet");
     } else if (isNOutOfMLvl(l)) {
-      assert(0 && "2Out4 not supported yet");
+      assert(0 && "n ouf of m not supported yet");
     } else {
       assert(isDenseLvl(l));
     }
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 993ad9be8a012..022cc37615f13 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -40,7 +40,7 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   if (base.compare("block") == 0) {
     ParseResult res = parser.parseCommaSeparatedList(
         mlir::OpAsmParser::Delimiter::OptionalSquare,
-        [&]() -> ParseResult { return parseBlockSize(parser, &blockSizes); },
+        [&]() -> ParseResult { return parseBlockSizes(parser, &blockSizes); },
         " in block n out of m");
     FAILURE_IF_FAILED(res)
   }
@@ -95,8 +95,8 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
 }
 
 ParseResult
-LvlTypeParser::parseBlockSize(AsmParser &parser,
-                              SmallVector<unsigned> *blockSizes) const {
+LvlTypeParser::parseBlockSizes(AsmParser &parser,
+                               SmallVector<unsigned> *blockSizes) const {
   int intVal;
   auto loc = parser.getCurrentLocation();
   OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
index 78ae667f97923..250c98fb8702b 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
@@ -22,8 +22,8 @@ class LvlTypeParser {
 
 private:
   ParseResult parseProperty(AsmParser &parser, uint64_t *properties) const;
-  ParseResult parseBlockSize(AsmParser &parser,
-                             SmallVector<unsigned> *blockSizes) const;
+  ParseResult parseBlockSizes(AsmParser &parser,
+                              SmallVector<unsigned> *blockSizes) const;
 };
 
 } // namespace ir_detail
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 966a7ff2d38e1..367b9c736095a 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -254,3 +254,21 @@ func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
 func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
   return
 }
+
+// -----
+
+#NOutOfM = #sparse_tensor.encoding<{
+  map = ( i, j, k ) ->
+  ( i            : dense,
+    k floordiv 8 : dense,
+    j            : dense,
+    k mod 8      : block[5, 8]
+  )
+}>
+
+// CHECK-DAG: #[[$NOutOfM:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 8 : dense, d1 : dense, d2 mod 8 : block[5, 8]) }>
+// CHECK-LABEL: func private @NOutOfM(
+// CHECK-SAME: tensor<?x?x?xf64, #[[$NOutOfM]]>
+func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
+  return
+}

>From fd2f19eae9df3dad88008fccc77519a04cd9f4ec Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 30 Jan 2024 23:08:20 +0000
Subject: [PATCH 05/10] address comments

---
 mlir/include/mlir-c/Dialect/SparseTensor.h    | 28 +++++------
 .../mlir/Dialect/SparseTensor/IR/Enums.h      | 46 +++++++++----------
 .../SparseTensor/IR/SparseTensorAttrDefs.td   |  3 +-
 .../SparseTensor/IR/Detail/LvlTypeParser.cpp  | 20 ++++----
 .../SparseTensor/IR/Detail/LvlTypeParser.h    |  4 +-
 mlir/test/CAPI/sparse_tensor.c                |  2 +-
 .../SparseTensor/GPU/gpu_matmul24_lib.mlir    |  2 +-
 .../SparseTensor/roundtrip_encoding.mlir      | 16 +++----
 .../SparseTensor/CPU/sparse_block_matmul.mlir |  2 +-
 .../Dialect/SparseTensor/CPU/sparse_ds.mlir   |  2 +-
 .../CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir   |  2 +-
 .../CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir |  2 +-
 12 files changed, 65 insertions(+), 64 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 5fc1f51452482..3a501056fbae3 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -26,20 +26,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 /// If updating, keep them in sync and update the static_assert in the impl
 /// file.
 enum MlirSparseTensorLevelType {
-  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 65536,                   // 0x00_00_0001_0000
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 131072,             // 0x00_00_0002_0000
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 131073,          // 0x00_00_0002_0001
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 131074,          // 0x00_00_0002_0002
-  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 131075,       // 0x00_00_0002_0003
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 262144,              // 0x00_00_0004_0000
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 262145,           // 0x00_00_0004_0001
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 262146,           // 0x00_00_0004_0002
-  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 262147,        // 0x00_00_0004_0003
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 524288,       // 0x00_00_0008_0000
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 524289,    // 0x00_00_0008_0001
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 524290,    // 0x00_00_0008_0002
-  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 524291, // 0x00_00_0008_0003
-  MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 1048576,            // 0x00_00_0010_0000
+  MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 0x000000020002,
+  MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 0x000000020003,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 0x000000040001,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 0x000000040002,
+  MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 0x000000040003,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 0x000000080001,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002,
+  MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003,
+  MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 99443957d01d5..bd50804ee1d3d 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -172,36 +172,36 @@ enum class Action : uint32_t {
 /// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty |
 ///
 enum class LevelType : uint64_t {
-  Undef = 0,                    // 0x00_00_0000_0000
-  Dense = 65536,                // 0x00_00_0001_0000
-  Compressed = 131072,          // 0x00_00_0002_0000
-  CompressedNu = 131073,        // 0x00_00_0002_0001
-  CompressedNo = 131074,        // 0x00_00_0002_0002
-  CompressedNuNo = 131075,      // 0x00_00_0002_0003
-  Singleton = 262144,           // 0x00_00_0004_0000
-  SingletonNu = 262145,         // 0x00_00_0004_0001
-  SingletonNo = 262146,         // 0x00_00_0004_0002
-  SingletonNuNo = 262147,       // 0x00_00_0004_0003
-  LooseCompressed = 524288,     // 0x00_00_0008_0000
-  LooseCompressedNu = 524289,   // 0x00_00_0008_0001
-  LooseCompressedNo = 524290,   // 0x00_00_0008_0002
-  LooseCompressedNuNo = 524291, // 0x00_00_0008_0003
-  NOutOfM = 1048576,            // 0x00_00_0010_0000
+  Undef = 0x000000000000,
+  Dense = 0x000000010000,
+  Compressed = 0x000000020000,
+  CompressedNu = 0x000000020001,
+  CompressedNo = 0x000000020002,
+  CompressedNuNo = 0x000000020003,
+  Singleton = 0x000000040000,
+  SingletonNu = 0x000000040001,
+  SingletonNo = 0x000000040002,
+  SingletonNuNo = 0x000000040003,
+  LooseCompressed = 0x000000080000,
+  LooseCompressedNu = 0x000000080001,
+  LooseCompressedNo = 0x000000080002,
+  LooseCompressedNuNo = 0x000000080003,
+  NOutOfM = 0x000000100000,
 };
 
 /// This enum defines all supported storage format without the level properties.
 enum class LevelFormat : uint64_t {
-  Dense = 65536,            // 0x0001_0000
-  Compressed = 131072,      // 0x0002_0000
-  Singleton = 262144,       // 0x0004_0000
-  LooseCompressed = 524288, // 0x0008_0000
-  NOutOfM = 1048576,        // 0x0010_0000
+  Dense = 0x00010000,
+  Compressed = 0x00020000,
+  Singleton = 0x00040000,
+  LooseCompressed = 0x00080000,
+  NOutOfM = 0x00100000,
 };
 
 /// This enum defines all the nondefault properties for storage formats.
 enum class LevelPropertyNondefault : uint64_t {
-  Nonunique = 1,  // 0x0001
-  Nonordered = 2, // 0x0002
+  Nonunique = 0x0001,
+  Nonordered = 0x0002,
 };
 
 /// Get N of NOutOfM level type.
@@ -265,7 +265,7 @@ constexpr const char *toMLIRString(LevelType lt) {
     return "loose_compressed(nonunique, nonordered)";
   default:
     if (isNOutOfMLT(lt)) {
-      return "block";
+      return "structured";
     }
   }
   return "";
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 08ba96d437045..5b3b971f9a7f2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -145,7 +145,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     - **compressed** : only nonzeros along this level are stored
     - **loose_compressed** : as compressed, but allows for free space between regions
     - **singleton** : a variant of the compressed format, where coordinates have no siblings
-    - **block[n, m]** : the compression uses a n:m encoding per 1xm block
+    - **structured[n, m]** : the compression uses a n:m encoding
+      (viz. n out of m consecutive elements are nonzero)
 
     For a compressed level, each position interval is represented in a compact
     way with a lowerbound `pos(i)` and an upperbound `pos(i+1) - 1`, which implies
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 022cc37615f13..752d6e6481dfe 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -35,12 +35,12 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
   ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
            "expected valid level format (e.g. dense, compressed or singleton)")
   uint64_t properties = 0;
-  SmallVector<unsigned> blockSizes;
+  SmallVector<unsigned> structure;
 
-  if (base.compare("block") == 0) {
+  if (base.compare("structured") == 0) {
     ParseResult res = parser.parseCommaSeparatedList(
         mlir::OpAsmParser::Delimiter::OptionalSquare,
-        [&]() -> ParseResult { return parseBlockSizes(parser, &blockSizes); },
+        [&]() -> ParseResult { return parseStructure(parser, &structure); },
         " in block n out of m");
     FAILURE_IF_FAILED(res)
   }
@@ -56,13 +56,13 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
     properties |= static_cast<uint64_t>(LevelFormat::Dense);
   } else if (base.compare("compressed") == 0) {
     properties |= static_cast<uint64_t>(LevelFormat::Compressed);
-  } else if (base.compare("block") == 0) {
-    if (blockSizes.size() != 2) {
-      parser.emitError(loc, "expected exactly 2 block sizes");
+  } else if (base.compare("structured") == 0) {
+    if (structure.size() != 2) {
+      parser.emitError(loc, "expected exactly 2 structure sizes");
       return failure();
     }
     properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
-    properties |= nToBits(blockSizes[0]) | mToBits(blockSizes[1]);
+    properties |= nToBits(structure[0]) | mToBits(structure[1]);
   } else if (base.compare("loose_compressed") == 0) {
     properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
   } else if (base.compare("singleton") == 0) {
@@ -95,8 +95,8 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
 }
 
 ParseResult
-LvlTypeParser::parseBlockSizes(AsmParser &parser,
-                               SmallVector<unsigned> *blockSizes) const {
+LvlTypeParser::parseStructure(AsmParser &parser,
+                              SmallVector<unsigned> *structure) const {
   int intVal;
   auto loc = parser.getCurrentLocation();
   OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
@@ -105,7 +105,7 @@ LvlTypeParser::parseBlockSizes(AsmParser &parser,
       parser.emitError(loc, "failed to parse block size");
       return failure();
     }
-    blockSizes->push_back(intVal);
+    structure->push_back(intVal);
     return success();
   }
   parser.emitError(loc, "expected valid integer for block size");
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
index 250c98fb8702b..6a13112195d44 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
@@ -22,8 +22,8 @@ class LvlTypeParser {
 
 private:
   ParseResult parseProperty(AsmParser &parser, uint64_t *properties) const;
-  ParseResult parseBlockSizes(AsmParser &parser,
-                              SmallVector<unsigned> *blockSizes) const;
+  ParseResult parseStructure(AsmParser &parser,
+                             SmallVector<unsigned> *structure) const;
 };
 
 } // namespace ir_detail
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index ea4c56b7ec0c5..b7cff2cf9998e 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -47,7 +47,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
       malloc(sizeof(enum MlirSparseTensorLevelType) * lvlRank);
   for (int l = 0; l < lvlRank; ++l) {
     lvlTypes[l] = mlirSparseTensorEncodingAttrGetLvlType(originalAttr, l);
-    fprintf(stderr, "level_type: %d\n", lvlTypes[l]);
+    fprintf(stderr, "level_type: %u\n", lvlTypes[l]);
   }
   // CHECK: posWidth: 32
   int posWidth = mlirSparseTensorEncodingAttrGetPosWidth(originalAttr);
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
index e4884a9bf393f..8293169049ca6 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
@@ -4,7 +4,7 @@
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block[2, 4]
+    j mod 4      : structured[2, 4]
   )
 }>
 
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 367b9c736095a..64520638b253d 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -207,12 +207,12 @@ func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block[2, 4]
+    j mod 4      : structured[2, 4]
   ),
   crdWidth = 8  // we would even like just 2-bits
 }>
 
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : block[2, 4]), crdWidth = 8 }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]), crdWidth = 8 }>
 // CHECK-LABEL: func private @NV_24(
 // CHECK-SAME: tensor<?x?xf64, #[[$NV_24]]>
 func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
@@ -226,11 +226,11 @@ func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
   ( i            : dense,
     j            : dense,
     k floordiv 4 : dense,
-    k mod 4      : block[2, 4]
+    k mod 4      : structured[2, 4]
   )
 }>
 
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block[2, 4]) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : structured[2, 4]) }>
 // CHECK-LABEL: func private @NV_24(
 // CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
 func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
@@ -244,11 +244,11 @@ func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
   ( i            : dense,
     k floordiv 4 : dense,
     j            : dense,
-    k mod 4      : block[2, 4]
+    k mod 4      : structured[2, 4]
   )
 }>
 
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block[2, 4]) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : structured[2, 4]) }>
 // CHECK-LABEL: func private @NV_24(
 // CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
 func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
@@ -262,11 +262,11 @@ func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
   ( i            : dense,
     k floordiv 8 : dense,
     j            : dense,
-    k mod 8      : block[5, 8]
+    k mod 8      : structured[5, 8]
   )
 }>
 
-// CHECK-DAG: #[[$NOutOfM:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 8 : dense, d1 : dense, d2 mod 8 : block[5, 8]) }>
+// CHECK-DAG: #[[$NOutOfM:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 8 : dense, d1 : dense, d2 mod 8 : structured[5, 8]) }>
 // CHECK-LABEL: func private @NOutOfM(
 // CHECK-SAME: tensor<?x?x?xf64, #[[$NOutOfM]]>
 func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 554d6207aef7e..e47ac46597b77 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -59,7 +59,7 @@
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block[2, 4]
+    j mod 4      : structured[2, 4]
   ),
 }>
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
index 9935d7c69e63a..ec5c7580657cd 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
@@ -41,7 +41,7 @@
 #NV_24 = #sparse_tensor.encoding<{
   map = ( i, j ) -> ( i            : dense,
                       j floordiv 4 : dense,
-                      j mod 4      : block[2, 4]),
+                      j mod 4      : structured[2, 4]),
   crdWidth = 8
 }>
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
index 25454f5c06b45..b0f63f12c2d57 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
@@ -20,7 +20,7 @@
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block[2, 4]
+    j mod 4      : structured[2, 4]
   )
 }>
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
index da735b4a3b58a..311cb607b4293 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
@@ -20,7 +20,7 @@
   map = ( i, j ) ->
   ( i            : dense,
     j floordiv 4 : dense,
-    j mod 4      : block[2, 4]
+    j mod 4      : structured[2, 4]
   )
 }>
 

>From 2655039a0f14ddae45265c11222b0b1917e55fec Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 1 Feb 2024 05:49:08 +0000
Subject: [PATCH 06/10] ensure 64 bits

---
 mlir/include/mlir-c/Dialect/SparseTensor.h | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 3a501056fbae3..8759f89952cda 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -12,6 +12,7 @@
 
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
+#include <assert.h>
 
 #ifdef __cplusplus
 extern "C" {
@@ -25,7 +26,11 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 /// These correspond to SparseTensorEncodingAttr::LevelType in the C++ API.
 /// If updating, keep them in sync and update the static_assert in the impl
 /// file.
-enum MlirSparseTensorLevelType {
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wfixed-enum-extension"
+
+enum MlirSparseTensorLevelType : uint64_t {
   MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
   MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
   MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001,
@@ -42,6 +47,11 @@ enum MlirSparseTensorLevelType {
   MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
 };
 
+static_assert((sizeof(enum MlirSparseTensorLevelType) == 8),
+              "MlirSparseTensorLevelType must be 8 bytes");
+
+#pragma GCC diagnostic pop
+
 //===----------------------------------------------------------------------===//
 // SparseTensorEncodingAttr
 //===----------------------------------------------------------------------===//

>From 403acc555bd47b90b37afa28e3af4fac674e85c7 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 1 Feb 2024 06:24:56 +0000
Subject: [PATCH 07/10] windows

---
 mlir/include/mlir-c/Dialect/SparseTensor.h | 9 +++++++--
 mlir/test/CAPI/sparse_tensor.c             | 2 +-
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 8759f89952cda..eba3cf51093bc 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -27,10 +27,12 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 /// If updating, keep them in sync and update the static_assert in the impl
 /// file.
 
+#if !defined(_MSC_VER)
 #pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wfixed-enum-extension"
+#pragma GCC diagnostic ignored "-Wpedantic"
+#endif
 
-enum MlirSparseTensorLevelType : uint64_t {
+enum MlirSparseTensorLevelType {
   MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
   MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
   MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001,
@@ -45,12 +47,15 @@ enum MlirSparseTensorLevelType : uint64_t {
   MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002,
   MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003,
   MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
+  MLIR_SPARSE_TENSOR_LEVEL_2_OUT_OF_4 = 0x020400100000,
 };
 
 static_assert((sizeof(enum MlirSparseTensorLevelType) == 8),
               "MlirSparseTensorLevelType must be 8 bytes");
 
+#if !defined(_MSC_VER)
 #pragma GCC diagnostic pop
+#endif
 
 //===----------------------------------------------------------------------===//
 // SparseTensorEncodingAttr
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index b7cff2cf9998e..612f26e00a64d 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -47,7 +47,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
       malloc(sizeof(enum MlirSparseTensorLevelType) * lvlRank);
   for (int l = 0; l < lvlRank; ++l) {
     lvlTypes[l] = mlirSparseTensorEncodingAttrGetLvlType(originalAttr, l);
-    fprintf(stderr, "level_type: %u\n", lvlTypes[l]);
+    fprintf(stderr, "level_type: %lu\n", lvlTypes[l]);
   }
   // CHECK: posWidth: 32
   int posWidth = mlirSparseTensorEncodingAttrGetPosWidth(originalAttr);

>From d726251a0089d84bccf7b957b75159a542ed7665 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 1 Feb 2024 07:12:14 +0000
Subject: [PATCH 08/10] windows set enum width

---
 mlir/include/mlir-c/Dialect/SparseTensor.h | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index eba3cf51093bc..9fa93e678b733 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -29,10 +29,10 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 
 #if !defined(_MSC_VER)
 #pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wpedantic"
+#pragma GCC diagnostic ignored "-Wfixed-enum-extension"
 #endif
 
-enum MlirSparseTensorLevelType {
+enum MlirSparseTensorLevelType : uint64_t {
   MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
   MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
   MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001,
@@ -47,7 +47,6 @@ enum MlirSparseTensorLevelType {
   MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002,
   MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003,
   MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
-  MLIR_SPARSE_TENSOR_LEVEL_2_OUT_OF_4 = 0x020400100000,
 };
 
 static_assert((sizeof(enum MlirSparseTensorLevelType) == 8),

>From 1bcdd7267c24164f581e82e8b916995ae555a2bb Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 1 Feb 2024 07:58:51 +0000
Subject: [PATCH 09/10] use typedef

---
 mlir/include/mlir-c/Dialect/SparseTensor.h | 22 +++++-----------------
 mlir/lib/CAPI/Dialect/SparseTensor.cpp     | 16 ++++++++--------
 mlir/test/CAPI/sparse_tensor.c             |  3 +--
 3 files changed, 14 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 9fa93e678b733..73af8d68a1177 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -12,7 +12,6 @@
 
 #include "mlir-c/AffineMap.h"
 #include "mlir-c/IR.h"
-#include <assert.h>
 
 #ifdef __cplusplus
 extern "C" {
@@ -27,12 +26,9 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
 /// If updating, keep them in sync and update the static_assert in the impl
 /// file.
 
-#if !defined(_MSC_VER)
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wfixed-enum-extension"
-#endif
+typedef uint64_t level_type;
 
-enum MlirSparseTensorLevelType : uint64_t {
+enum MlirSparseTensorLevelType {
   MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
   MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
   MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001,
@@ -49,13 +45,6 @@ enum MlirSparseTensorLevelType : uint64_t {
   MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
 };
 
-static_assert((sizeof(enum MlirSparseTensorLevelType) == 8),
-              "MlirSparseTensorLevelType must be 8 bytes");
-
-#if !defined(_MSC_VER)
-#pragma GCC diagnostic pop
-#endif
-
 //===----------------------------------------------------------------------===//
 // SparseTensorEncodingAttr
 //===----------------------------------------------------------------------===//
@@ -66,16 +55,15 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr);
 
 /// Creates a `sparse_tensor.encoding` attribute with the given parameters.
 MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet(
-    MlirContext ctx, intptr_t lvlRank,
-    enum MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl,
-    MlirAffineMap lvlTodim, int posWidth, int crdWidth);
+    MlirContext ctx, intptr_t lvlRank, level_type const *lvlTypes,
+    MlirAffineMap dimToLvl, MlirAffineMap lvlTodim, int posWidth, int crdWidth);
 
 /// Returns the level-rank of the `sparse_tensor.encoding` attribute.
 MLIR_CAPI_EXPORTED intptr_t
 mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr);
 
 /// Returns a specified level-type of the `sparse_tensor.encoding` attribute.
-MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelType
+MLIR_CAPI_EXPORTED level_type
 mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl);
 
 /// Returns the dimension-to-level mapping of the `sparse_tensor.encoding`
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index a34b9a29b0e90..24e41c037a8dd 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -55,11 +55,11 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
   return isa<SparseTensorEncodingAttr>(unwrap(attr));
 }
 
-MlirAttribute
-mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
-                                MlirSparseTensorLevelType const *lvlTypes,
-                                MlirAffineMap dimToLvl, MlirAffineMap lvlToDim,
-                                int posWidth, int crdWidth) {
+MlirAttribute mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank,
+                                              level_type const *lvlTypes,
+                                              MlirAffineMap dimToLvl,
+                                              MlirAffineMap lvlToDim,
+                                              int posWidth, int crdWidth) {
   SmallVector<LevelType> cppLvlTypes;
   cppLvlTypes.reserve(lvlRank);
   for (intptr_t l = 0; l < lvlRank; ++l)
@@ -81,9 +81,9 @@ intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) {
   return cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlRank();
 }
 
-MlirSparseTensorLevelType
-mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) {
-  return static_cast<MlirSparseTensorLevelType>(
+level_type mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr,
+                                                  intptr_t lvl) {
+  return static_cast<level_type>(
       cast<SparseTensorEncodingAttr>(unwrap(attr)).getLvlType(lvl));
 }
 
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index 612f26e00a64d..48c5e1e3b9aa8 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -43,8 +43,7 @@ static int testRoundtripEncoding(MlirContext ctx) {
   MlirAffineMap lvlToDim =
       mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
   int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
-  enum MlirSparseTensorLevelType *lvlTypes =
-      malloc(sizeof(enum MlirSparseTensorLevelType) * lvlRank);
+  level_type *lvlTypes = malloc(sizeof(level_type) * lvlRank);
   for (int l = 0; l < lvlRank; ++l) {
     lvlTypes[l] = mlirSparseTensorEncodingAttrGetLvlType(originalAttr, l);
     fprintf(stderr, "level_type: %lu\n", lvlTypes[l]);

>From dfc43edefdc754002f7401d98508c4a938cfad3f Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Thu, 1 Feb 2024 09:05:22 +0000
Subject: [PATCH 10/10] level_type for pybind

---
 mlir/lib/Bindings/Python/DialectSparseTensor.cpp   | 4 ++--
 mlir/test/python/dialects/sparse_tensor/dialect.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index f68d77dc129ad..b84c2867a9abc 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -46,7 +46,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
                           mlirAttributeIsASparseTensorEncodingAttr)
       .def_classmethod(
           "get",
-          [](py::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes,
+          [](py::object cls, std::vector<level_type> lvlTypes,
              std::optional<MlirAffineMap> dimToLvl,
              std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth,
              MlirContext context) {
@@ -64,7 +64,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) {
           "lvl_types",
           [](MlirAttribute self) {
             const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self);
-            std::vector<MlirSparseTensorLevelType> ret;
+            std::vector<level_type> ret;
             ret.reserve(lvlRank);
             for (int l = 0; l < lvlRank; ++l)
               ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l));
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 75c47a57f78af..412c5797067b7 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [<LevelType.compressed: 131072>]
+        # CHECK: lvl_types: [131072]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0) -> (d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -70,7 +70,7 @@ def testEncodingAttr2D():
         # CHECK: equal: True
         print(f"equal: {casted == parsed}")
 
-        # CHECK: lvl_types: [<LevelType.dense: 65536>, <LevelType.compressed: 131072>]
+        # CHECK: lvl_types: [65536, 131072]
         print(f"lvl_types: {casted.lvl_types}")
         # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
         print(f"dim_to_lvl: {casted.dim_to_lvl}")



More information about the Mlir-commits mailing list