[libcxx-commits] [libc] [compiler-rt] [mlir] [clang] [libcxx] [flang] [libcxxabi] [lld] [lldb] [llvm] [mlir][sparse] Expand LevelType to 64 bits and implement n out of m (PR #79935)
Yinying Li via libcxx-commits
libcxx-commits at lists.llvm.org
Tue Jan 30 10:03:29 PST 2024
https://github.com/yinying-lisa-li updated https://github.com/llvm/llvm-project/pull/79935
>From fa5210448dea1f88d8e0a242543ad1be655087e0 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 30 Jan 2024 01:01:52 +0000
Subject: [PATCH 1/3] [mlir][sparse] Expand LevelType to 64 bit and implement n
out of m
---
mlir/include/mlir-c/Dialect/SparseTensor.h | 28 +--
.../mlir/Dialect/SparseTensor/IR/Enums.h | 225 +++++++++++-------
.../SparseTensor/IR/SparseTensorAttrDefs.td | 4 +-
.../SparseTensor/IR/SparseTensorType.h | 2 +-
.../mlir/Dialect/SparseTensor/Utils/Merger.h | 2 +-
.../ExecutionEngine/SparseTensor/Storage.h | 14 +-
.../Bindings/Python/DialectSparseTensor.cpp | 2 +-
mlir/lib/CAPI/Dialect/SparseTensor.cpp | 49 ++--
.../IR/Detail/DimLvlMapParser.cpp | 2 +
.../SparseTensor/IR/Detail/LvlTypeParser.cpp | 55 ++++-
.../SparseTensor/IR/Detail/LvlTypeParser.h | 6 +-
.../Transforms/SparseGPUCodegen.cpp | 2 +-
.../Transforms/SparseTensorCodegen.cpp | 6 +-
.../Transforms/Sparsification.cpp | 2 +-
.../Transforms/Utils/CodegenUtils.h | 2 +-
.../Transforms/Utils/SparseTensorLevel.cpp | 2 +-
.../lib/Dialect/SparseTensor/Utils/Merger.cpp | 4 +-
.../ExecutionEngine/SparseTensor/Storage.cpp | 2 +-
mlir/test/CAPI/sparse_tensor.c | 6 +-
.../SparseTensor/GPU/gpu_matmul24_lib.mlir | 2 +-
.../test/Dialect/SparseTensor/conversion.mlir | 16 +-
.../SparseTensor/roundtrip_encoding.mlir | 12 +-
.../SparseTensor/sparse_fill_zero.mlir | 12 +-
.../SparseTensor/CPU/sparse_block_matmul.mlir | 2 +-
.../Dialect/SparseTensor/CPU/sparse_ds.mlir | 2 +-
.../CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir | 2 +-
.../CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir | 2 +-
.../python/dialects/sparse_tensor/dialect.py | 148 ++++++------
28 files changed, 358 insertions(+), 255 deletions(-)
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 41d024db04964..5fc1f51452482 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -26,20 +26,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
/// If updating, keep them in sync and update the static_assert in the impl
/// file.
enum MlirSparseTensorLevelType {
- MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4, // 0b00001_00
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8, // 0b00010_00
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9, // 0b00010_01
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10, // 0b00010_10
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16, // 0b00100_00
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17, // 0b00100_01
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18, // 0b00100_10
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32, // 0b01000_00
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33, // 0b01000_01
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34, // 0b01000_10
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11
- MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00
+ MLIR_SPARSE_TENSOR_LEVEL_DENSE = 65536, // 0x00_00_0001_0000
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 131072, // 0x00_00_0002_0000
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 131073, // 0x00_00_0002_0001
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 131074, // 0x00_00_0002_0002
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 131075, // 0x00_00_0002_0003
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 262144, // 0x00_00_0004_0000
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 262145, // 0x00_00_0004_0001
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 262146, // 0x00_00_0004_0002
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 262147, // 0x00_00_0004_0003
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 524288, // 0x00_00_0008_0000
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 524289, // 0x00_00_0008_0001
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 524290, // 0x00_00_0008_0002
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 524291, // 0x00_00_0008_0003
+ MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 1048576, // 0x00_00_0010_0000
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index ac91bfa5ae622..6ddc9326179fe 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -154,9 +154,10 @@ enum class Action : uint32_t {
/// This enum defines all the sparse representations supportable by
/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton, loose_compressed,
-/// two-out-of-four) as well as the "properties" (ordered, unique). The
-/// encoding is chosen for performance of the runtime library, and thus may
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique) as well as n and m for
+/// the NOutOfM level type.
+/// The encoding is chosen for performance of the runtime library, and thus may
/// change in future versions; consequently, client code should use the
/// predicate functions defined below, rather than relying on knowledge
/// about the particular binary encoding.
@@ -165,41 +166,74 @@ enum class Action : uint32_t {
/// where we need to store an undefined or indeterminate `LevelType`.
/// It should not be used externally, since it does not indicate an
/// actual/representable format.
-enum class LevelType : uint8_t {
- Undef = 0, // 0b00000_00
- Dense = 4, // 0b00001_00
- Compressed = 8, // 0b00010_00
- CompressedNu = 9, // 0b00010_01
- CompressedNo = 10, // 0b00010_10
- CompressedNuNo = 11, // 0b00010_11
- Singleton = 16, // 0b00100_00
- SingletonNu = 17, // 0b00100_01
- SingletonNo = 18, // 0b00100_10
- SingletonNuNo = 19, // 0b00100_11
- LooseCompressed = 32, // 0b01000_00
- LooseCompressedNu = 33, // 0b01000_01
- LooseCompressedNo = 34, // 0b01000_10
- LooseCompressedNuNo = 35, // 0b01000_11
- TwoOutOfFour = 64, // 0b10000_00
+///
+/// Bit manipulations for LevelType:
+///
+/// | 8-bit m | 8-bit n | 16-bit LevelFormat | 16-bit LevelProperty |
+///
+enum class LevelType : uint64_t {
+ Undef = 0, // 0x00_00_0000_0000
+ Dense = 65536, // 0x00_00_0001_0000
+ Compressed = 131072, // 0x00_00_0002_0000
+ CompressedNu = 131073, // 0x00_00_0002_0001
+ CompressedNo = 131074, // 0x00_00_0002_0002
+ CompressedNuNo = 131075, // 0x00_00_0002_0003
+ Singleton = 262144, // 0x00_00_0004_0000
+ SingletonNu = 262145, // 0x00_00_0004_0001
+ SingletonNo = 262146, // 0x00_00_0004_0002
+ SingletonNuNo = 262147, // 0x00_00_0004_0003
+ LooseCompressed = 524288, // 0x00_00_0008_0000
+ LooseCompressedNu = 524289, // 0x00_00_0008_0001
+ LooseCompressedNo = 524290, // 0x00_00_0008_0002
+ LooseCompressedNuNo = 524291, // 0x00_00_0008_0003
+ NOutOfM = 1048576, // 0x00_00_0010_0000
};
/// This enum defines all supported storage format without the level properties.
-enum class LevelFormat : uint8_t {
- Dense = 4, // 0b00001_00
- Compressed = 8, // 0b00010_00
- Singleton = 16, // 0b00100_00
- LooseCompressed = 32, // 0b01000_00
- TwoOutOfFour = 64, // 0b10000_00
+enum class LevelFormat : uint64_t {
+ Dense = 65536, // 0x0001_0000
+ Compressed = 131072, // 0x0002_0000
+ Singleton = 262144, // 0x0004_0000
+ LooseCompressed = 524288, // 0x0008_0000
+ NOutOfM = 1048576, // 0x0010_0000
};
/// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint8_t {
- Nonunique = 1, // 0b00000_01
- Nonordered = 2, // 0b00000_10
+enum class LevelPropertyNondefault : uint64_t {
+ Nonunique = 1, // 0x0001
+ Nonordered = 2, // 0x0002
};
+/// Get N of NOutOfM level type.
+constexpr uint64_t getN(LevelType lt) {
+ return (static_cast<uint64_t>(lt) >> 32) & 0xff;
+}
+
+/// Get M of NOutOfM level type.
+constexpr uint64_t getM(LevelType lt) {
+ return (static_cast<uint64_t>(lt) >> 40) & 0xff;
+}
+
+/// Convert N of NOutOfM level type to the stored bits.
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+
+/// Convert M of NOutOfM level type to the stored bits.
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+
+/// Check if the `LevelType` is NOutOfM (regardless of
+/// properties and block sizes).
+constexpr bool isNOutOfMLT(LevelType lt) {
+ return ((static_cast<uint64_t>(lt) & 0x100000) ==
+ static_cast<uint64_t>(LevelType::NOutOfM));
+}
+
+/// Check if the `LevelType` is NOutOfM with exactly the given n and m.
+constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
+ return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
+}
+
/// Returns string representation of the given dimension level type.
-constexpr const char *toMLIRString(LevelType lt) {
+std::string toMLIRString(LevelType lt) {
switch (lt) {
case LevelType::Undef:
return "undef";
@@ -229,21 +263,28 @@ constexpr const char *toMLIRString(LevelType lt) {
return "loose_compressed(nonordered)";
case LevelType::LooseCompressedNuNo:
return "loose_compressed(nonunique, nonordered)";
- case LevelType::TwoOutOfFour:
- return "block2_4";
+ default:
+ // If NOutOfM bit is set, print the [n, m] sizes.
+ if (isNOutOfMLT(lt)) {
+ unsigned n = getN(lt);
+ unsigned m = getM(lt);
+ return std::string("block[") + std::to_string(n) + ", " +
+ std::to_string(m) + "]";
+ }
}
return "";
}
/// Check that the `LevelType` contains a valid (possibly undefined) value.
constexpr bool isValidLT(LevelType lt) {
- const uint8_t formatBits = static_cast<uint8_t>(lt) >> 2;
- const uint8_t propertyBits = static_cast<uint8_t>(lt) & 3;
- // If undefined or dense, then must be unique and ordered.
+ const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
+ const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
+ // If undefined/dense/NOutOfM, then must be unique and ordered.
// Otherwise, the format must be one of the known ones.
- return (formatBits <= 1 || formatBits == 16)
+ return (formatBits <= 0x10000 || formatBits == 0x100000)
? (propertyBits == 0)
- : (formatBits == 2 || formatBits == 4 || formatBits == 8);
+ : (formatBits == 0x20000 || formatBits == 0x40000 ||
+ formatBits == 0x80000);
}
/// Check if the `LevelType` is the special undefined value.
@@ -251,33 +292,28 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
/// Check if the `LevelType` is dense (regardless of properties).
constexpr bool isDenseLT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::Dense);
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::Dense);
}
/// Check if the `LevelType` is compressed (regardless of properties).
constexpr bool isCompressedLT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::Compressed);
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::Compressed);
}
/// Check if the `LevelType` is singleton (regardless of properties).
constexpr bool isSingletonLT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::Singleton);
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::Singleton);
}
/// Check if the `LevelType` is loose compressed (regardless of properties).
constexpr bool isLooseCompressedLT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::LooseCompressed);
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::LooseCompressed);
}
-/// Check if the `LevelType` is 2OutOf4 (regardless of properties).
-constexpr bool is2OutOf4LT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::TwoOutOfFour);
-}
/// Check if the `LevelType` needs positions array.
constexpr bool isWithPosLT(LevelType lt) {
@@ -287,17 +323,17 @@ constexpr bool isWithPosLT(LevelType lt) {
/// Check if the `LevelType` needs coordinates array.
constexpr bool isWithCrdLT(LevelType lt) {
return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- is2OutOf4LT(lt);
+ isNOutOfMLT(lt);
}
/// Check if the `LevelType` is ordered (regardless of storage format).
constexpr bool isOrderedLT(LevelType lt) {
- return !(static_cast<uint8_t>(lt) & 2);
+ return !(static_cast<uint64_t>(lt) & 2);
}
/// Check if the `LevelType` is unique (regardless of storage format).
constexpr bool isUniqueLT(LevelType lt) {
- return !(static_cast<uint8_t>(lt) & 1);
+ return !(static_cast<uint64_t>(lt) & 1);
}
/// Convert a LevelType to its corresponding LevelFormat.
@@ -305,21 +341,25 @@ constexpr bool isUniqueLT(LevelType lt) {
constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
if (lt == LevelType::Undef)
return std::nullopt;
- return static_cast<LevelFormat>(static_cast<uint8_t>(lt) & ~3);
+ return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
}
/// Convert a LevelFormat to its corresponding LevelType with the given
/// properties. Returns std::nullopt when the properties are not applicable
/// for the input level format.
constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
- bool unique) {
- auto lt = static_cast<LevelType>(static_cast<uint8_t>(lf) |
- (ordered ? 0 : 2) | (unique ? 0 : 1));
+ bool unique, uint64_t n = 0,
+ uint64_t m = 0) {
+ uint64_t newN = n << 32;
+ uint64_t newM = m << 40;
+ auto lt =
+ static_cast<LevelType>(static_cast<uint64_t>(lf) | (ordered ? 0 : 2) |
+ (unique ? 0 : 1) | newN | newM);
return isValidLT(lt) ? std::optional(lt) : std::nullopt;
}
//
-// Ensure the above methods work as indended.
+// Ensure the above methods work as intended.
//
static_assert(
@@ -341,7 +381,7 @@ static_assert(
LevelFormat::LooseCompressed &&
*getLevelFormat(LevelType::LooseCompressedNuNo) ==
LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
+ *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
"getLevelFormat conversion is broken");
static_assert(
@@ -373,13 +413,28 @@ static_assert(
LevelType::LooseCompressedNo &&
*buildLevelType(LevelFormat::LooseCompressed, false, false) ==
LevelType::LooseCompressedNuNo &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
- LevelType::TwoOutOfFour),
+ buildLevelType(LevelFormat::NOutOfM, false, true) == std::nullopt &&
+ buildLevelType(LevelFormat::NOutOfM, true, false) == std::nullopt &&
+ buildLevelType(LevelFormat::NOutOfM, false, false) == std::nullopt &&
+ *buildLevelType(LevelFormat::NOutOfM, true, true) == LevelType::NOutOfM),
"buildLevelType conversion is broken");
+static_assert(
+ (getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 2 &&
+ getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 4 &&
+ getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 8 &&
+ getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 10),
+ "getN/M conversion is broken");
+
+static_assert(
+ (isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4),
+ 2, 4) &&
+ isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10),
+ 8, 10) &&
+ !isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 3, 4),
+ 2, 4)),
+ "isValidNOutOfMLT definition is broken");
+
static_assert(
(isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
@@ -391,7 +446,7 @@ static_assert(
isValidLT(LevelType::LooseCompressedNu) &&
isValidLT(LevelType::LooseCompressedNo) &&
isValidLT(LevelType::LooseCompressedNuNo) &&
- isValidLT(LevelType::TwoOutOfFour)),
+ isValidLT(LevelType::NOutOfM)),
"isValidLT definition is broken");
static_assert((isDenseLT(LevelType::Dense) &&
@@ -407,7 +462,7 @@ static_assert((isDenseLT(LevelType::Dense) &&
!isDenseLT(LevelType::LooseCompressedNu) &&
!isDenseLT(LevelType::LooseCompressedNo) &&
!isDenseLT(LevelType::LooseCompressedNuNo) &&
- !isDenseLT(LevelType::TwoOutOfFour)),
+ !isDenseLT(LevelType::NOutOfM)),
"isDenseLT definition is broken");
static_assert((!isCompressedLT(LevelType::Dense) &&
@@ -423,7 +478,7 @@ static_assert((!isCompressedLT(LevelType::Dense) &&
!isCompressedLT(LevelType::LooseCompressedNu) &&
!isCompressedLT(LevelType::LooseCompressedNo) &&
!isCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isCompressedLT(LevelType::TwoOutOfFour)),
+ !isCompressedLT(LevelType::NOutOfM)),
"isCompressedLT definition is broken");
static_assert((!isSingletonLT(LevelType::Dense) &&
@@ -439,7 +494,7 @@ static_assert((!isSingletonLT(LevelType::Dense) &&
!isSingletonLT(LevelType::LooseCompressedNu) &&
!isSingletonLT(LevelType::LooseCompressedNo) &&
!isSingletonLT(LevelType::LooseCompressedNuNo) &&
- !isSingletonLT(LevelType::TwoOutOfFour)),
+ !isSingletonLT(LevelType::NOutOfM)),
"isSingletonLT definition is broken");
static_assert((!isLooseCompressedLT(LevelType::Dense) &&
@@ -455,24 +510,24 @@ static_assert((!isLooseCompressedLT(LevelType::Dense) &&
isLooseCompressedLT(LevelType::LooseCompressedNu) &&
isLooseCompressedLT(LevelType::LooseCompressedNo) &&
isLooseCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isLooseCompressedLT(LevelType::TwoOutOfFour)),
+ !isLooseCompressedLT(LevelType::NOutOfM)),
"isLooseCompressedLT definition is broken");
-static_assert((!is2OutOf4LT(LevelType::Dense) &&
- !is2OutOf4LT(LevelType::Compressed) &&
- !is2OutOf4LT(LevelType::CompressedNu) &&
- !is2OutOf4LT(LevelType::CompressedNo) &&
- !is2OutOf4LT(LevelType::CompressedNuNo) &&
- !is2OutOf4LT(LevelType::Singleton) &&
- !is2OutOf4LT(LevelType::SingletonNu) &&
- !is2OutOf4LT(LevelType::SingletonNo) &&
- !is2OutOf4LT(LevelType::SingletonNuNo) &&
- !is2OutOf4LT(LevelType::LooseCompressed) &&
- !is2OutOf4LT(LevelType::LooseCompressedNu) &&
- !is2OutOf4LT(LevelType::LooseCompressedNo) &&
- !is2OutOf4LT(LevelType::LooseCompressedNuNo) &&
- is2OutOf4LT(LevelType::TwoOutOfFour)),
- "is2OutOf4LT definition is broken");
+static_assert((!isNOutOfMLT(LevelType::Dense) &&
+ !isNOutOfMLT(LevelType::Compressed) &&
+ !isNOutOfMLT(LevelType::CompressedNu) &&
+ !isNOutOfMLT(LevelType::CompressedNo) &&
+ !isNOutOfMLT(LevelType::CompressedNuNo) &&
+ !isNOutOfMLT(LevelType::Singleton) &&
+ !isNOutOfMLT(LevelType::SingletonNu) &&
+ !isNOutOfMLT(LevelType::SingletonNo) &&
+ !isNOutOfMLT(LevelType::SingletonNuNo) &&
+ !isNOutOfMLT(LevelType::LooseCompressed) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNu) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNo) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNuNo) &&
+ isNOutOfMLT(LevelType::NOutOfM)),
+ "isNOutOfMLT definition is broken");
static_assert((isOrderedLT(LevelType::Dense) &&
isOrderedLT(LevelType::Compressed) &&
@@ -487,7 +542,7 @@ static_assert((isOrderedLT(LevelType::Dense) &&
isOrderedLT(LevelType::LooseCompressedNu) &&
!isOrderedLT(LevelType::LooseCompressedNo) &&
!isOrderedLT(LevelType::LooseCompressedNuNo) &&
- isOrderedLT(LevelType::TwoOutOfFour)),
+ isOrderedLT(LevelType::NOutOfM)),
"isOrderedLT definition is broken");
static_assert((isUniqueLT(LevelType::Dense) &&
@@ -503,7 +558,7 @@ static_assert((isUniqueLT(LevelType::Dense) &&
!isUniqueLT(LevelType::LooseCompressedNu) &&
isUniqueLT(LevelType::LooseCompressedNo) &&
!isUniqueLT(LevelType::LooseCompressedNuNo) &&
- isUniqueLT(LevelType::TwoOutOfFour)),
+ isUniqueLT(LevelType::NOutOfM)),
"isUniqueLT definition is broken");
/// Bit manipulations for affine encoding.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 12c1068ae1f54..299ba0e603089 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -145,7 +145,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
- **compressed** : only nonzeros along this level are stored
- **loose_compressed** : as compressed, but allows for free space between regions
- **singleton** : a variant of the compressed format, where coordinates have no siblings
- - **block2_4** : the compression uses a 2:4 encoding per 1x4 block
+ - **block[2, 4]** : the compression uses a 2:4 encoding per 1x4 block
For a compressed level, each position interval is represented in a compact
way with a lowerbound `pos(i)` and an upperbound `pos(i+1) - 1`, which implies
@@ -374,7 +374,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedLT(getLvlType(l)); }
bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonLT(getLvlType(l)); }
bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedLT(getLvlType(l)); }
- bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4LT(getLvlType(l)); }
+ bool isNOutOfMLvl(::mlir::sparse_tensor::Level l) const { return isNOutOfMLT(getLvlType(l)); }
bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedLT(getLvlType(l)); }
bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueLT(getLvlType(l)); }
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4c98129744bcd..4e2b85d35c1ac 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -291,7 +291,7 @@ class SparseTensorType {
return isLooseCompressedLT(getLvlType(l));
}
bool isSingletonLvl(Level l) const { return isSingletonLT(getLvlType(l)); }
- bool is2OutOf4Lvl(Level l) const { return is2OutOf4LT(getLvlType(l)); }
+ bool isNOutOfMLvl(Level l) const { return isNOutOfMLT(getLvlType(l)); }
bool isOrderedLvl(Level l) const { return isOrderedLT(getLvlType(l)); }
bool isUniqueLvl(Level l) const { return isUniqueLT(getLvlType(l)); }
bool isWithPos(Level l) const { return isWithPosLT(getLvlType(l)); }
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 4a34bb2e003e8..490ef3071af1b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -510,7 +510,7 @@ class Merger {
if (isLvlWithNonTrivialIdxExp(b)) {
auto lt = getLoopDependentLevelType(b);
return isCompressedLT(lt) || isSingletonLT(lt) ||
- isLooseCompressedLT(lt) || is2OutOf4LT(lt);
+ isLooseCompressedLT(lt) || isNOutOfMLT(lt);
}
return false;
}
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 01c5f2382ffe6..1d8d9bcfb3b2c 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -124,7 +124,7 @@ class SparseTensorStorageBase {
bool isSingletonLvl(uint64_t l) const { return isSingletonLT(getLvlType(l)); }
/// Safely checks if the level uses 2 out of 4 storage.
- bool is2OutOf4Lvl(uint64_t l) const { return is2OutOf4LT(getLvlType(l)); }
+ bool isNOutOfMLvl(uint64_t l) const { return isNOutOfMLT(getLvlType(l)); }
/// Safely checks if the level is ordered.
bool isOrderedLvl(uint64_t l) const { return isOrderedLT(getLvlType(l)); }
@@ -450,7 +450,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
void appendCrd(uint64_t lvl, uint64_t full, uint64_t crd) {
if (!isDenseLvl(lvl)) {
assert(isCompressedLvl(lvl) || isLooseCompressedLvl(lvl) ||
- isSingletonLvl(lvl) || is2OutOf4Lvl(lvl));
+ isSingletonLvl(lvl) || isNOutOfMLvl(lvl));
coordinates[lvl].push_back(detail::checkOverflowCast<C>(crd));
} else { // Dense level.
assert(crd >= full && "Coordinate was already filled");
@@ -473,7 +473,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
return positions[l][parentSz];
if (isLooseCompressedLvl(l))
return positions[l][2 * parentSz - 1];
- if (isSingletonLvl(l) || is2OutOf4Lvl(l))
+ if (isSingletonLvl(l) || isNOutOfMLvl(l))
return parentSz; // new size same as the parent
assert(isDenseLvl(l));
return parentSz * getLvlSize(l);
@@ -527,7 +527,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
uint64_t pos = coordinates[l].size();
positions[l].insert(positions[l].end(), 2 * count,
detail::checkOverflowCast<P>(pos));
- } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
+ } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
return; // Nothing to finalize.
} else { // Dense dimension.
assert(isDenseLvl(l));
@@ -624,7 +624,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
lvlCursor[l] = static_cast<uint64_t>(coordinatesL[pos]);
toCOO(pos, l + 1, dimCoords);
}
- } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
+ } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
assert(parentPos < coordinates[l].size());
lvlCursor[l] = static_cast<uint64_t>(coordinates[l][parentPos]);
toCOO(parentPos, l + 1, dimCoords);
@@ -721,7 +721,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
} else if (isSingletonLvl(l)) {
coordinates[l].reserve(sz);
sz = 1;
- } else if (is2OutOf4Lvl(l)) {
+ } else if (isNOutOfMLvl(l)) {
assert(l == lvlRank - 1 && "unexpected 2:4 usage");
sz = detail::checkedMul(sz, lvlSizes[l]) / 2;
coordinates[l].reserve(sz);
@@ -791,7 +791,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
}
} else if (isSingletonLvl(l)) {
assert(0 && "general singleton not supported yet");
- } else if (is2OutOf4Lvl(l)) {
+ } else if (isNOutOfMLvl(l)) {
assert(0 && "2Out4 not supported yet");
} else {
assert(isDenseLvl(l));
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 8706c523988b1..f68d77dc129ad 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -25,7 +25,7 @@ using namespace mlir::python::adaptors;
static void populateDialectSparseTensorSubmodule(const py::module &m) {
py::enum_<MlirSparseTensorLevelType>(m, "LevelType", py::module_local())
.value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
- .value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR)
+ .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
.value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
.value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU)
.value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO)
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index e4534ad132385..a34b9a29b0e90 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -20,25 +20,36 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
mlir::sparse_tensor::SparseTensorDialect)
// Ensure the C-API enums are int-castable to C++ equivalents.
-static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
- static_cast<int>(LevelType::Dense) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
- static_cast<int>(LevelType::Compressed) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
- static_cast<int>(LevelType::CompressedNu) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
- static_cast<int>(LevelType::CompressedNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
- static_cast<int>(LevelType::CompressedNuNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
- static_cast<int>(LevelType::Singleton) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
- static_cast<int>(LevelType::SingletonNu) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
- static_cast<int>(LevelType::SingletonNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
- static_cast<int>(LevelType::SingletonNuNo),
- "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
+static_assert(
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
+ static_cast<int>(LevelType::Dense) &&
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
+ static_cast<int>(LevelType::Compressed) &&
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
+ static_cast<int>(LevelType::CompressedNu) &&
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
+ static_cast<int>(LevelType::CompressedNo) &&
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
+ static_cast<int>(LevelType::CompressedNuNo) &&
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
+ static_cast<int>(LevelType::Singleton) &&
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
+ static_cast<int>(LevelType::SingletonNu) &&
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
+ static_cast<int>(LevelType::SingletonNo) &&
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
+ static_cast<int>(LevelType::SingletonNuNo) &&
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) ==
+ static_cast<int>(LevelType::LooseCompressed) &&
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) ==
+ static_cast<int>(LevelType::LooseCompressedNu) &&
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) ==
+ static_cast<int>(LevelType::LooseCompressedNo) &&
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO) ==
+ static_cast<int>(LevelType::LooseCompressedNuNo) &&
+ static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) ==
+ static_cast<int>(LevelType::NOutOfM),
+ "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
return isa<SparseTensorEncodingAttr>(unwrap(attr));
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 56b435c57d30a..95874d4857fc8 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -299,6 +299,8 @@ ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
FAILURE_IF_FAILED(type)
lvlSpecs.emplace_back(var, expr, static_cast<LevelType>(*type));
+ llvm::errs() << "type = " << toMLIRString(static_cast<LevelType>(*type))
+ << "\n";
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index eb7ea63a4e88b..14ebe14b49f64 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -29,12 +29,21 @@ using namespace mlir::sparse_tensor::ir_detail;
// `LvlTypeParser` implementation.
//===----------------------------------------------------------------------===//
-FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
+FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
StringRef base;
const auto loc = parser.getCurrentLocation();
ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
"expected valid level format (e.g. dense, compressed or singleton)")
- uint8_t properties = 0;
+ uint64_t properties = 0;
+ SmallVector<unsigned> blockSizes;
+
+ if (base.compare("block") == 0) {
+ ParseResult res = parser.parseCommaSeparatedList(
+ mlir::OpAsmParser::Delimiter::OptionalSquare,
+ [&]() -> ParseResult { return parseBlockSize(parser, &blockSizes); },
+ " in block n out of m");
+ FAILURE_IF_FAILED(res)
+ }
ParseResult res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::OptionalParen,
@@ -44,15 +53,21 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
// Set the base bit for properties.
if (base.compare("dense") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::Dense);
+ properties |= static_cast<uint64_t>(LevelFormat::Dense);
} else if (base.compare("compressed") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::Compressed);
- } else if (base.compare("block2_4") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::TwoOutOfFour);
+ properties |= static_cast<uint64_t>(LevelFormat::Compressed);
+ } else if (base.compare("block") == 0) {
+ if (blockSizes.size() != 2) {
+ parser.emitError(loc, "expected exactly 2 block sizes");
+ return failure();
+ }
+ properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
+ properties |= nToBits(blockSizes[0]) | mToBits(blockSizes[1]);
+ llvm::errs() << "properties1: " << properties << "\n";
} else if (base.compare("loose_compressed") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::LooseCompressed);
+ properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
} else if (base.compare("singleton") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::Singleton);
+ properties |= static_cast<uint64_t>(LevelFormat::Singleton);
} else {
parser.emitError(loc, "unknown level format: ") << base;
return failure();
@@ -64,15 +79,15 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
}
ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
- uint8_t *properties) const {
+ uint64_t *properties) const {
StringRef strVal;
auto loc = parser.getCurrentLocation();
ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
"expected valid level property (e.g. nonordered, nonunique or high)")
if (strVal.compare("nonunique") == 0) {
- *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonunique);
+ *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique);
} else if (strVal.compare("nonordered") == 0) {
- *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonordered);
+ *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered);
} else {
parser.emitError(loc, "unknown level property: ") << strVal;
return failure();
@@ -80,4 +95,22 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
return success();
}
+ParseResult
+LvlTypeParser::parseBlockSize(AsmParser &parser,
+ SmallVector<unsigned> *blockSizes) const {
+ int intVal;
+ auto loc = parser.getCurrentLocation();
+ OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
+ if (intValParseResult.has_value()) {
+ if (failed(*intValParseResult)) {
+ parser.emitError(loc, "failed to parse block size");
+ return failure();
+ }
+ blockSizes->push_back(intVal);
+ return success();
+ }
+ parser.emitError(loc, "expected valid integer for block size");
+ return failure();
+}
+
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
index 5e2f11b75d4da..78ae667f97923 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
@@ -18,10 +18,12 @@ namespace ir_detail {
class LvlTypeParser {
public:
LvlTypeParser() = default;
- FailureOr<uint8_t> parseLvlType(AsmParser &parser) const;
+ FailureOr<uint64_t> parseLvlType(AsmParser &parser) const;
private:
- ParseResult parseProperty(AsmParser &parser, uint8_t *properties) const;
+ ParseResult parseProperty(AsmParser &parser, uint64_t *properties) const;
+ ParseResult parseBlockSize(AsmParser &parser,
+ SmallVector<unsigned> *blockSizes) const;
};
} // namespace ir_detail
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 87a37a7926e9e..23676eccdfb28 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -451,7 +451,7 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) {
/// Test for 2:4 matrix with suitable metadata.
static bool isAdmissible24(SparseTensorType &aTp) {
return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
- aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp);
+ aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) && isAdmissibleMetaData(aTp);
}
/// Test for conversion into 2:4 matrix.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 491501a3381b9..d4459c6ea1e52 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -130,7 +130,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
/*value=*/posZero, /*repeat=*/linear);
return;
- } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
+ } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
return; // nothing to do
}
// Keep compounding the size, but nothing needs to be initialized
@@ -409,7 +409,7 @@ static void genEndInsert(OpBuilder &builder, Location loc,
}
} else {
assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
- is2OutOf4LT(lt));
+ isNOutOfMLT(lt));
}
}
}
@@ -488,7 +488,7 @@ class SparseInsertGenerator
}
parentPos =
genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
- } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
+ } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
// Create:
// coordinates[lvl].push_back(coords[lvl])
// positions[lvl] = positions[lvl-1]
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 5266ca7213bfc..cc39f21001168 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -891,7 +891,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
assert(curr == env.merger().loop(b));
Value clause;
if (isCompressedLT(lt) || isSingletonLT(lt) ||
- isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
+ isLooseCompressedLT(lt) || isNOutOfMLT(lt)) {
assert(lvl.has_value());
const Value crd = env.emitter().getCoord(tid, *lvl);
const Value lvar = env.getLoopVar(curr);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index 8d54b5959d871..cc119bc704559 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -423,7 +423,7 @@ inline Value constantPrimaryTypeEncoding(OpBuilder &builder, Location loc,
/// Generates a constant of the internal dimension level type encoding.
inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
LevelType lt) {
- return constantI8(builder, loc, static_cast<uint8_t>(lt));
+ return constantI64(builder, loc, static_cast<uint64_t>(lt));
}
inline bool isZeroRankedTensorOrScalar(Type type) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index e43896942d7fe..14051fe631f09 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -1221,7 +1221,7 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
Value crd = genToCoordinates(b, l, t, lvl);
return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
}
- case LevelFormat::TwoOutOfFour: {
+ case LevelFormat::NOutOfM: {
Value crd = genToCoordinates(b, l, t, lvl);
return std::make_unique<TwoOutFourLevel>(tid, lvl, lt, sz, crd);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 6cdf5f8c0168b..96537cbb0c483 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -489,7 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
const auto lt = getLvlType(b);
if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
- !isLooseCompressedLT(lt) && !is2OutOf4LT(lt)) {
+ !isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
if (reset)
simple.reset(b);
reset = true;
@@ -670,7 +670,7 @@ bool Merger::hasAnySparse(const BitVector &bits) const {
for (TensorLoopId b : bits.set_bits()) {
const auto lt = getLvlType(b);
if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- is2OutOf4LT(lt))
+ isNOutOfMLT(lt))
return true;
}
return hasSparseIdxReduction(bits);
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 0c7b3a228a65c..9e8b240899d80 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -45,7 +45,7 @@ SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
for (uint64_t l = 0; l < lvlRank; l++) {
assert(lvlSizes[l] > 0 && "Level size zero has trivial storage");
assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) ||
- isSingletonLvl(l) || is2OutOf4Lvl(l));
+ isSingletonLvl(l) || isNOutOfMLvl(l));
}
}
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index b0bc9bb6e881a..ea4c56b7ec0c5 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -37,9 +37,9 @@ static int testRoundtripEncoding(MlirContext ctx) {
mlirSparseTensorEncodingAttrGetDimToLvl(originalAttr);
// CHECK: (d0, d1)[s0] -> (s0, d0, d1)
mlirAffineMapDump(dimToLvl);
- // CHECK: level_type: 4
- // CHECK: level_type: 8
- // CHECK: level_type: 8
+ // CHECK: level_type: 65536
+ // CHECK: level_type: 131072
+ // CHECK: level_type: 131072
MlirAffineMap lvlToDim =
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
index 6fe7ec906f30e..e4884a9bf393f 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
@@ -4,7 +4,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : block[2, 4]
)
}>
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index e4e825bf85043..465f210862660 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -78,8 +78,8 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind
// CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi64>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi64> to memref<?xi64>
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimShape]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -96,8 +96,8 @@ func.func @sparse_new1d(%arg0: !llvm.ptr) -> tensor<128xf64, #SparseVector> {
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
@@ -114,8 +114,8 @@ func.func @sparse_new2d(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CSR> {
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi64>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi64> to memref<?xi64>
// CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
// CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
// CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
@@ -136,10 +136,10 @@ func.func @sparse_new3d(%arg0: !llvm.ptr) -> tensor<?x?x?xf32, #SparseTensor> {
// CHECK-DAG: %[[Empty:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi64>
// CHECK-DAG: %[[Sizes0:.*]] = memref.alloca() : memref<2xindex>
// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi64> to memref<?xi64>
// CHECK-DAG: %[[Sizes:.*]] = memref.cast %[[Sizes0]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: memref.store %[[I]], %[[Sizes0]][%[[C0]]] : memref<2xindex>
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 20702bb985028..966a7ff2d38e1 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -207,12 +207,12 @@ func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : block[2, 4]
),
crdWidth = 8 // we would even like just 2-bits
}>
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : block2_4), crdWidth = 8 }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : block[2, 4]), crdWidth = 8 }>
// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?xf64, #[[$NV_24]]>
func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
@@ -226,11 +226,11 @@ func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
( i : dense,
j : dense,
k floordiv 4 : dense,
- k mod 4 : block2_4
+ k mod 4 : block[2, 4]
)
}>
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block2_4) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block[2, 4]) }>
// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
@@ -244,11 +244,11 @@ func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
( i : dense,
k floordiv 4 : dense,
j : dense,
- k mod 4 : block2_4
+ k mod 4 : block[2, 4]
)
}>
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block2_4) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block[2, 4]) }>
// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 40367f12f85a4..d04fbe8ed5c22 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,11 +14,11 @@
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant true
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 8 : i8
-// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi8>
-// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi8> to memref<?xi8>
-// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi8>
-// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi8>
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 131072 : i64
+// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
+// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
+// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
+// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_6]]] : memref<2xi64>
// CHECK: %[[VAL_14:.*]] = memref.alloca() : memref<2xindex>
// CHECK: %[[VAL_15:.*]] = memref.cast %[[VAL_14]] : memref<2xindex> to memref<?xindex>
// CHECK: memref.store %[[VAL_9]], %[[VAL_14]]{{\[}}%[[VAL_5]]] : memref<2xindex>
@@ -28,7 +28,7 @@
// CHECK: memref.store %[[VAL_5]], %[[VAL_16]]{{\[}}%[[VAL_5]]] : memref<2xindex>
// CHECK: memref.store %[[VAL_6]], %[[VAL_16]]{{\[}}%[[VAL_6]]] : memref<2xindex>
// CHECK: %[[VAL_18:.*]] = llvm.mlir.zero : !llvm.ptr
-// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi8>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
+// CHECK: %[[VAL_19:.*]] = call @newSparseTensor(%[[VAL_15]], %[[VAL_15]], %[[VAL_13]], %[[VAL_17]], %[[VAL_17]], %[[VAL_4]], %[[VAL_4]], %[[VAL_3]], %[[VAL_4]], %[[VAL_18]]) : (memref<?xindex>, memref<?xindex>, memref<?xi64>, memref<?xindex>, memref<?xindex>, i32, i32, i32, i32, !llvm.ptr) -> !llvm.ptr
// CHECK: %[[VAL_20:.*]] = memref.alloc() : memref<300xf64>
// CHECK: %[[VAL_21:.*]] = memref.cast %[[VAL_20]] : memref<300xf64> to memref<?xf64>
// CHECK: %[[VAL_22:.*]] = memref.alloc() : memref<300xi1>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 4bc080fc538fc..554d6207aef7e 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -59,7 +59,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : block[2, 4]
),
}>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
index df5b48a3b6ece..9935d7c69e63a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
@@ -41,7 +41,7 @@
#NV_24 = #sparse_tensor.encoding<{
map = ( i, j ) -> ( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4),
+ j mod 4 : block[2, 4]),
crdWidth = 8
}>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
index 17b50b46d073a..25454f5c06b45 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
@@ -20,7 +20,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : block[2, 4]
)
}>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
index eb99a027a8860..da735b4a3b58a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
@@ -20,7 +20,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : block[2, 4]
)
}>
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 88a5595d75aea..e9296b961e7fe 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -13,85 +13,85 @@ def run(f):
# CHECK-LABEL: TEST: testEncodingAttr1D
@run
def testEncodingAttr1D():
- with Context() as ctx:
- parsed = Attribute.parse(
- "#sparse_tensor.encoding<{"
- " map = (d0) -> (d0 : compressed),"
- " posWidth = 16,"
- " crdWidth = 32"
- "}>"
- )
- # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
- print(parsed)
-
- casted = st.EncodingAttr(parsed)
- # CHECK: equal: True
- print(f"equal: {casted == parsed}")
-
- # CHECK: lvl_types: [<LevelType.compressed: 8>]
- print(f"lvl_types: {casted.lvl_types}")
- # CHECK: dim_to_lvl: (d0) -> (d0)
- print(f"dim_to_lvl: {casted.dim_to_lvl}")
- # CHECK: lvl_to_dim: (d0) -> (d0)
- print(f"lvl_to_dim: {casted.lvl_to_dim}")
- # CHECK: pos_width: 16
- print(f"pos_width: {casted.pos_width}")
- # CHECK: crd_width: 32
- print(f"crd_width: {casted.crd_width}")
-
- created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
- # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
- print(created)
- # CHECK: created_equal: False
- print(f"created_equal: {created == casted}")
-
- # Verify that the factory creates an instance of the proper type.
- # CHECK: is_proper_instance: True
- print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
- # CHECK: created_pos_width: 0
- print(f"created_pos_width: {created.pos_width}")
+ with Context() as ctx:
+ parsed = Attribute.parse(
+ "#sparse_tensor.encoding<{"
+ " map = (d0) -> (d0 : compressed),"
+ " posWidth = 16,"
+ " crdWidth = 32"
+ "}>"
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
+ print(parsed)
+
+ casted = st.EncodingAttr(parsed)
+ # CHECK: equal: True
+ print(f"equal: {casted == parsed}")
+
+ # CHECK: lvl_types: [<LevelType.compressed: 131072>]
+ print(f"lvl_types: {casted.lvl_types}")
+ # CHECK: dim_to_lvl: (d0) -> (d0)
+ print(f"dim_to_lvl: {casted.dim_to_lvl}")
+ # CHECK: lvl_to_dim: (d0) -> (d0)
+ print(f"lvl_to_dim: {casted.lvl_to_dim}")
+ # CHECK: pos_width: 16
+ print(f"pos_width: {casted.pos_width}")
+ # CHECK: crd_width: 32
+ print(f"crd_width: {casted.crd_width}")
+
+ created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
+ # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+ print(created)
+ # CHECK: created_equal: False
+ print(f"created_equal: {created == casted}")
+
+ # Verify that the factory creates an instance of the proper type.
+ # CHECK: is_proper_instance: True
+ print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+ # CHECK: created_pos_width: 0
+ print(f"created_pos_width: {created.pos_width}")
# CHECK-LABEL: TEST: testEncodingAttr2D
@run
def testEncodingAttr2D():
- with Context() as ctx:
- parsed = Attribute.parse(
- "#sparse_tensor.encoding<{"
- " map = (d0, d1) -> (d1 : dense, d0 : compressed),"
- " posWidth = 8,"
- " crdWidth = 32"
- "}>"
- )
- # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
- print(parsed)
-
- casted = st.EncodingAttr(parsed)
- # CHECK: equal: True
- print(f"equal: {casted == parsed}")
-
- # CHECK: lvl_types: [<LevelType.dense: 4>, <LevelType.compressed: 8>]
- print(f"lvl_types: {casted.lvl_types}")
- # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
- print(f"dim_to_lvl: {casted.dim_to_lvl}")
- # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
- print(f"lvl_to_dim: {casted.lvl_to_dim}")
- # CHECK: pos_width: 8
- print(f"pos_width: {casted.pos_width}")
- # CHECK: crd_width: 32
- print(f"crd_width: {casted.crd_width}")
-
- created = st.EncodingAttr.get(
- casted.lvl_types,
- casted.dim_to_lvl,
- casted.lvl_to_dim,
- 8,
- 32,
- )
- # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
- print(created)
- # CHECK: created_equal: True
- print(f"created_equal: {created == casted}")
+ with Context() as ctx:
+ parsed = Attribute.parse(
+ "#sparse_tensor.encoding<{"
+ " map = (d0, d1) -> (d1 : dense, d0 : compressed),"
+ " posWidth = 8,"
+ " crdWidth = 32"
+ "}>"
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+ print(parsed)
+
+ casted = st.EncodingAttr(parsed)
+ # CHECK: equal: True
+ print(f"equal: {casted == parsed}")
+
+ # CHECK: lvl_types: [<LevelType.dense: 65536>, <LevelType.compressed: 131072>]
+ print(f"lvl_types: {casted.lvl_types}")
+ # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
+ print(f"dim_to_lvl: {casted.dim_to_lvl}")
+ # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
+ print(f"lvl_to_dim: {casted.lvl_to_dim}")
+ # CHECK: pos_width: 8
+ print(f"pos_width: {casted.pos_width}")
+ # CHECK: crd_width: 32
+ print(f"crd_width: {casted.crd_width}")
+
+ created = st.EncodingAttr.get(
+ casted.lvl_types,
+ casted.dim_to_lvl,
+ casted.lvl_to_dim,
+ 8,
+ 32,
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+ print(created)
+ # CHECK: created_equal: True
+ print(f"created_equal: {created == casted}")
# CHECK-LABEL: TEST: testEncodingAttrOnTensorType
>From a91ee4a2701822a50dc048563740a60cc00caf05 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 30 Jan 2024 02:58:18 +0000
Subject: [PATCH 2/3] format
---
.../include/mlir/Dialect/SparseTensor/IR/Enums.h | 8 ++------
.../SparseTensor/IR/Detail/DimLvlMapParser.cpp | 2 --
.../SparseTensor/IR/Detail/LvlTypeParser.cpp | 1 -
.../SparseTensor/IR/SparseTensorDialect.cpp | 16 ++++++++++++++--
4 files changed, 16 insertions(+), 11 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 6ddc9326179fe..15802a5ad3563 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -233,7 +233,7 @@ constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
}
/// Returns string representation of the given dimension level type.
-std::string toMLIRString(LevelType lt) {
+constexpr const char *toMLIRString(LevelType lt) {
switch (lt) {
case LevelType::Undef:
return "undef";
@@ -264,12 +264,8 @@ std::string toMLIRString(LevelType lt) {
case LevelType::LooseCompressedNuNo:
return "loose_compressed(nonunique, nonordered)";
default:
- // If NOutOfM bit is set, print the [n, m] sizes.
if (isNOutOfMLT(lt)) {
- unsigned n = getN(lt);
- unsigned m = getM(lt);
- return std::string("block[") + std::to_string(n) + ", " +
- std::to_string(m) + "]";
+ return "block";
}
}
return "";
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
index 95874d4857fc8..56b435c57d30a 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp
@@ -299,8 +299,6 @@ ParseResult DimLvlMapParser::parseLvlSpec(bool requireLvlVarBinding) {
FAILURE_IF_FAILED(type)
lvlSpecs.emplace_back(var, expr, static_cast<LevelType>(*type));
- llvm::errs() << "type = " << toMLIRString(static_cast<LevelType>(*type))
- << "\n";
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index 14ebe14b49f64..993ad9be8a012 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -63,7 +63,6 @@ FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
}
properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
properties |= nToBits(blockSizes[0]) | mToBits(blockSizes[1]);
- llvm::errs() << "properties1: " << properties << "\n";
} else if (base.compare("loose_compressed") == 0) {
properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
} else if (base.compare("singleton") == 0) {
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 6033ebf6897ce..d56d90a2d6130 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -613,16 +613,28 @@ void SparseTensorEncodingAttr::printDimensions(
}
}
+std::string getNOutOfMString(LevelType lt) {
+ if (isNOutOfMLT(lt)) {
+ unsigned n = getN(lt);
+ unsigned m = getM(lt);
+ auto output = "[" + std::to_string(n) + ", " + std::to_string(m) + "]";
+ return output;
+ }
+ return "";
+}
+
void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
ArrayRef<LevelType> lvlTypes) const {
for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
map.getResult(i).print(printer.getStream());
- printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
+ printer << " : " << toMLIRString(lvlTypes[i])
+ << getNOutOfMString(lvlTypes[i]) << ", ";
}
if (map.getNumResults() >= 1) {
auto lastIndex = map.getNumResults() - 1;
map.getResult(lastIndex).print(printer.getStream());
- printer << " : " << toMLIRString(lvlTypes[lastIndex]);
+ printer << " : " << toMLIRString(lvlTypes[lastIndex])
+ << getNOutOfMString(lvlTypes[lastIndex]);
}
}
>From 8f669440ab9e817cf628f81f274e45859cec7f68 Mon Sep 17 00:00:00 2001
From: Yinying Li <yinyingli at google.com>
Date: Tue, 30 Jan 2024 03:13:30 +0000
Subject: [PATCH 3/3] python format
---
.../mlir/Dialect/SparseTensor/IR/Enums.h | 1 -
.../python/dialects/sparse_tensor/dialect.py | 148 +++++++++---------
2 files changed, 74 insertions(+), 75 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 15802a5ad3563..99443957d01d5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -310,7 +310,6 @@ constexpr bool isLooseCompressedLT(LevelType lt) {
static_cast<uint64_t>(LevelType::LooseCompressed);
}
-
/// Check if the `LevelType` needs positions array.
constexpr bool isWithPosLT(LevelType lt) {
return isCompressedLT(lt) || isLooseCompressedLT(lt);
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index e9296b961e7fe..75c47a57f78af 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -13,85 +13,85 @@ def run(f):
# CHECK-LABEL: TEST: testEncodingAttr1D
@run
def testEncodingAttr1D():
- with Context() as ctx:
- parsed = Attribute.parse(
- "#sparse_tensor.encoding<{"
- " map = (d0) -> (d0 : compressed),"
- " posWidth = 16,"
- " crdWidth = 32"
- "}>"
- )
- # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
- print(parsed)
-
- casted = st.EncodingAttr(parsed)
- # CHECK: equal: True
- print(f"equal: {casted == parsed}")
-
- # CHECK: lvl_types: [<LevelType.compressed: 131072>]
- print(f"lvl_types: {casted.lvl_types}")
- # CHECK: dim_to_lvl: (d0) -> (d0)
- print(f"dim_to_lvl: {casted.dim_to_lvl}")
- # CHECK: lvl_to_dim: (d0) -> (d0)
- print(f"lvl_to_dim: {casted.lvl_to_dim}")
- # CHECK: pos_width: 16
- print(f"pos_width: {casted.pos_width}")
- # CHECK: crd_width: 32
- print(f"crd_width: {casted.crd_width}")
-
- created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
- # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
- print(created)
- # CHECK: created_equal: False
- print(f"created_equal: {created == casted}")
-
- # Verify that the factory creates an instance of the proper type.
- # CHECK: is_proper_instance: True
- print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
- # CHECK: created_pos_width: 0
- print(f"created_pos_width: {created.pos_width}")
+ with Context() as ctx:
+ parsed = Attribute.parse(
+ "#sparse_tensor.encoding<{"
+ " map = (d0) -> (d0 : compressed),"
+ " posWidth = 16,"
+ " crdWidth = 32"
+ "}>"
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32 }>
+ print(parsed)
+
+ casted = st.EncodingAttr(parsed)
+ # CHECK: equal: True
+ print(f"equal: {casted == parsed}")
+
+ # CHECK: lvl_types: [<LevelType.compressed: 131072>]
+ print(f"lvl_types: {casted.lvl_types}")
+ # CHECK: dim_to_lvl: (d0) -> (d0)
+ print(f"dim_to_lvl: {casted.dim_to_lvl}")
+ # CHECK: lvl_to_dim: (d0) -> (d0)
+ print(f"lvl_to_dim: {casted.lvl_to_dim}")
+ # CHECK: pos_width: 16
+ print(f"pos_width: {casted.pos_width}")
+ # CHECK: crd_width: 32
+ print(f"crd_width: {casted.crd_width}")
+
+ created = st.EncodingAttr.get(casted.lvl_types, None, None, 0, 0)
+ # CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+ print(created)
+ # CHECK: created_equal: False
+ print(f"created_equal: {created == casted}")
+
+ # Verify that the factory creates an instance of the proper type.
+ # CHECK: is_proper_instance: True
+ print(f"is_proper_instance: {isinstance(created, st.EncodingAttr)}")
+ # CHECK: created_pos_width: 0
+ print(f"created_pos_width: {created.pos_width}")
# CHECK-LABEL: TEST: testEncodingAttr2D
@run
def testEncodingAttr2D():
- with Context() as ctx:
- parsed = Attribute.parse(
- "#sparse_tensor.encoding<{"
- " map = (d0, d1) -> (d1 : dense, d0 : compressed),"
- " posWidth = 8,"
- " crdWidth = 32"
- "}>"
- )
- # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
- print(parsed)
-
- casted = st.EncodingAttr(parsed)
- # CHECK: equal: True
- print(f"equal: {casted == parsed}")
-
- # CHECK: lvl_types: [<LevelType.dense: 65536>, <LevelType.compressed: 131072>]
- print(f"lvl_types: {casted.lvl_types}")
- # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
- print(f"dim_to_lvl: {casted.dim_to_lvl}")
- # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
- print(f"lvl_to_dim: {casted.lvl_to_dim}")
- # CHECK: pos_width: 8
- print(f"pos_width: {casted.pos_width}")
- # CHECK: crd_width: 32
- print(f"crd_width: {casted.crd_width}")
-
- created = st.EncodingAttr.get(
- casted.lvl_types,
- casted.dim_to_lvl,
- casted.lvl_to_dim,
- 8,
- 32,
- )
- # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
- print(created)
- # CHECK: created_equal: True
- print(f"created_equal: {created == casted}")
+ with Context() as ctx:
+ parsed = Attribute.parse(
+ "#sparse_tensor.encoding<{"
+ " map = (d0, d1) -> (d1 : dense, d0 : compressed),"
+ " posWidth = 8,"
+ " crdWidth = 32"
+ "}>"
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+ print(parsed)
+
+ casted = st.EncodingAttr(parsed)
+ # CHECK: equal: True
+ print(f"equal: {casted == parsed}")
+
+ # CHECK: lvl_types: [<LevelType.dense: 65536>, <LevelType.compressed: 131072>]
+ print(f"lvl_types: {casted.lvl_types}")
+ # CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
+ print(f"dim_to_lvl: {casted.dim_to_lvl}")
+ # CHECK: lvl_to_dim: (d0, d1) -> (d1, d0)
+ print(f"lvl_to_dim: {casted.lvl_to_dim}")
+ # CHECK: pos_width: 8
+ print(f"pos_width: {casted.pos_width}")
+ # CHECK: crd_width: 32
+ print(f"crd_width: {casted.crd_width}")
+
+ created = st.EncodingAttr.get(
+ casted.lvl_types,
+ casted.dim_to_lvl,
+ casted.lvl_to_dim,
+ 8,
+ 32,
+ )
+ # CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
+ print(created)
+ # CHECK: created_equal: True
+ print(f"created_equal: {created == casted}")
# CHECK-LABEL: TEST: testEncodingAttrOnTensorType
More information about the libcxx-commits mailing list