[Mlir-commits] [mlir] [mlir][sparse] Expand LevelType to 64 bits and implement parsing n out of m (PR #79935)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 30 12:14:28 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Yinying Li (yinying-lisa-li)
<details>
<summary>Changes</summary>
---
Patch is 57.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79935.diff
28 Files Affected:
- (modified) mlir/include/mlir-c/Dialect/SparseTensor.h (+14-14)
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h (+135-85)
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td (+2-2)
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h (+1-1)
- (modified) mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h (+1-1)
- (modified) mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h (+9-9)
- (modified) mlir/lib/Bindings/Python/DialectSparseTensor.cpp (+1-1)
- (modified) mlir/lib/CAPI/Dialect/SparseTensor.cpp (+30-19)
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp (+43-11)
- (modified) mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h (+4-2)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+14-2)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+3-3)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp (+2-2)
- (modified) mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp (+1-1)
- (modified) mlir/test/CAPI/sparse_tensor.c (+3-3)
- (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir (+1-1)
- (modified) mlir/test/Dialect/SparseTensor/conversion.mlir (+8-8)
- (modified) mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir (+24-6)
- (modified) mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir (+6-6)
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir (+1-1)
- (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir (+1-1)
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir (+1-1)
- (modified) mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir (+1-1)
- (modified) mlir/test/python/dialects/sparse_tensor/dialect.py (+2-2)
``````````diff
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 41d024db04964..5fc1f51452482 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -26,20 +26,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
/// If updating, keep them in sync and update the static_assert in the impl
/// file.
enum MlirSparseTensorLevelType {
- MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4, // 0b00001_00
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8, // 0b00010_00
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9, // 0b00010_01
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10, // 0b00010_10
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16, // 0b00100_00
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17, // 0b00100_01
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18, // 0b00100_10
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32, // 0b01000_00
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33, // 0b01000_01
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34, // 0b01000_10
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11
- MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00
+ MLIR_SPARSE_TENSOR_LEVEL_DENSE = 65536, // 0x00_00_0001_0000
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 131072, // 0x00_00_0002_0000
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 131073, // 0x00_00_0002_0001
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 131074, // 0x00_00_0002_0002
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 131075, // 0x00_00_0002_0003
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 262144, // 0x00_00_0004_0000
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 262145, // 0x00_00_0004_0001
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 262146, // 0x00_00_0004_0002
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 262147, // 0x00_00_0004_0003
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 524288, // 0x00_00_0008_0000
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 524289, // 0x00_00_0008_0001
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 524290, // 0x00_00_0008_0002
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 524291, // 0x00_00_0008_0003
+ MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 1048576, // 0x00_00_0010_0000
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index ac91bfa5ae622..99443957d01d5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -154,9 +154,10 @@ enum class Action : uint32_t {
/// This enum defines all the sparse representations supportable by
/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton, loose_compressed,
-/// two-out-of-four) as well as the "properties" (ordered, unique). The
-/// encoding is chosen for performance of the runtime library, and thus may
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique) as well as n and m for
+/// NOutOfM level type.
+/// The encoding is chosen for performance of the runtime library, and thus may
/// change in future versions; consequently, client code should use the
/// predicate functions defined below, rather than relying on knowledge
/// about the particular binary encoding.
@@ -165,39 +166,72 @@ enum class Action : uint32_t {
/// where we need to store an undefined or indeterminate `LevelType`.
/// It should not be used externally, since it does not indicate an
/// actual/representable format.
-enum class LevelType : uint8_t {
- Undef = 0, // 0b00000_00
- Dense = 4, // 0b00001_00
- Compressed = 8, // 0b00010_00
- CompressedNu = 9, // 0b00010_01
- CompressedNo = 10, // 0b00010_10
- CompressedNuNo = 11, // 0b00010_11
- Singleton = 16, // 0b00100_00
- SingletonNu = 17, // 0b00100_01
- SingletonNo = 18, // 0b00100_10
- SingletonNuNo = 19, // 0b00100_11
- LooseCompressed = 32, // 0b01000_00
- LooseCompressedNu = 33, // 0b01000_01
- LooseCompressedNo = 34, // 0b01000_10
- LooseCompressedNuNo = 35, // 0b01000_11
- TwoOutOfFour = 64, // 0b10000_00
+///
+/// Bit manipulations for LevelType:
+///
+/// | 8-bit m | 8-bit n | 16-bit LevelFormat | 16-bit LevelProperty |
+///
+enum class LevelType : uint64_t {
+ Undef = 0, // 0x00_00_0000_0000
+ Dense = 65536, // 0x00_00_0001_0000
+ Compressed = 131072, // 0x00_00_0002_0000
+ CompressedNu = 131073, // 0x00_00_0002_0001
+ CompressedNo = 131074, // 0x00_00_0002_0002
+ CompressedNuNo = 131075, // 0x00_00_0002_0003
+ Singleton = 262144, // 0x00_00_0004_0000
+ SingletonNu = 262145, // 0x00_00_0004_0001
+ SingletonNo = 262146, // 0x00_00_0004_0002
+ SingletonNuNo = 262147, // 0x00_00_0004_0003
+ LooseCompressed = 524288, // 0x00_00_0008_0000
+ LooseCompressedNu = 524289, // 0x00_00_0008_0001
+ LooseCompressedNo = 524290, // 0x00_00_0008_0002
+ LooseCompressedNuNo = 524291, // 0x00_00_0008_0003
+ NOutOfM = 1048576, // 0x00_00_0010_0000
};
/// This enum defines all supported storage format without the level properties.
-enum class LevelFormat : uint8_t {
- Dense = 4, // 0b00001_00
- Compressed = 8, // 0b00010_00
- Singleton = 16, // 0b00100_00
- LooseCompressed = 32, // 0b01000_00
- TwoOutOfFour = 64, // 0b10000_00
+enum class LevelFormat : uint64_t {
+ Dense = 65536, // 0x0001_0000
+ Compressed = 131072, // 0x0002_0000
+ Singleton = 262144, // 0x0004_0000
+ LooseCompressed = 524288, // 0x0008_0000
+ NOutOfM = 1048576, // 0x0010_0000
};
/// This enum defines all the nondefault properties for storage formats.
-enum class LevelPropertyNondefault : uint8_t {
- Nonunique = 1, // 0b00000_01
- Nonordered = 2, // 0b00000_10
+enum class LevelPropertyNondefault : uint64_t {
+ Nonunique = 1, // 0x0001
+ Nonordered = 2, // 0x0002
};
+/// Get N of NOutOfM level type.
+constexpr uint64_t getN(LevelType lt) {
+ return (static_cast<uint64_t>(lt) >> 32) & 0xff;
+}
+
+/// Get M of NOutOfM level type.
+constexpr uint64_t getM(LevelType lt) {
+ return (static_cast<uint64_t>(lt) >> 40) & 0xff;
+}
+
+/// Convert N of NOutOfM level type to the stored bits.
+constexpr uint64_t nToBits(uint64_t n) { return n << 32; }
+
+/// Convert M of NOutOfM level type to the stored bits.
+constexpr uint64_t mToBits(uint64_t m) { return m << 40; }
+
+/// Check if the `LevelType` is NOutOfM (regardless of
+/// properties and block sizes).
+constexpr bool isNOutOfMLT(LevelType lt) {
+ return ((static_cast<uint64_t>(lt) & 0x100000) ==
+ static_cast<uint64_t>(LevelType::NOutOfM));
+}
+
+/// Check if the `LevelType` is NOutOfM with the correct block sizes.
+constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
+ return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
+}
+
/// Returns string representation of the given dimension level type.
constexpr const char *toMLIRString(LevelType lt) {
switch (lt) {
@@ -229,21 +263,24 @@ constexpr const char *toMLIRString(LevelType lt) {
return "loose_compressed(nonordered)";
case LevelType::LooseCompressedNuNo:
return "loose_compressed(nonunique, nonordered)";
- case LevelType::TwoOutOfFour:
- return "block2_4";
+ default:
+ if (isNOutOfMLT(lt)) {
+ return "block";
+ }
}
return "";
}
/// Check that the `LevelType` contains a valid (possibly undefined) value.
constexpr bool isValidLT(LevelType lt) {
- const uint8_t formatBits = static_cast<uint8_t>(lt) >> 2;
- const uint8_t propertyBits = static_cast<uint8_t>(lt) & 3;
- // If undefined or dense, then must be unique and ordered.
+ const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
+ const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
+ // If undefined/dense/NOutOfM, then must be unique and ordered.
// Otherwise, the format must be one of the known ones.
- return (formatBits <= 1 || formatBits == 16)
+ return (formatBits <= 0x10000 || formatBits == 0x100000)
? (propertyBits == 0)
- : (formatBits == 2 || formatBits == 4 || formatBits == 8);
+ : (formatBits == 0x20000 || formatBits == 0x40000 ||
+ formatBits == 0x80000);
}
/// Check if the `LevelType` is the special undefined value.
@@ -251,32 +288,26 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
/// Check if the `LevelType` is dense (regardless of properties).
constexpr bool isDenseLT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::Dense);
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::Dense);
}
/// Check if the `LevelType` is compressed (regardless of properties).
constexpr bool isCompressedLT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::Compressed);
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::Compressed);
}
/// Check if the `LevelType` is singleton (regardless of properties).
constexpr bool isSingletonLT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::Singleton);
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::Singleton);
}
/// Check if the `LevelType` is loose compressed (regardless of properties).
constexpr bool isLooseCompressedLT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::LooseCompressed);
-}
-
-/// Check if the `LevelType` is 2OutOf4 (regardless of properties).
-constexpr bool is2OutOf4LT(LevelType lt) {
- return (static_cast<uint8_t>(lt) & ~3) ==
- static_cast<uint8_t>(LevelType::TwoOutOfFour);
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
+ static_cast<uint64_t>(LevelType::LooseCompressed);
}
/// Check if the `LevelType` needs positions array.
@@ -287,17 +318,17 @@ constexpr bool isWithPosLT(LevelType lt) {
/// Check if the `LevelType` needs coordinates array.
constexpr bool isWithCrdLT(LevelType lt) {
return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- is2OutOf4LT(lt);
+ isNOutOfMLT(lt);
}
/// Check if the `LevelType` is ordered (regardless of storage format).
constexpr bool isOrderedLT(LevelType lt) {
- return !(static_cast<uint8_t>(lt) & 2);
+ return !(static_cast<uint64_t>(lt) & 2);
}
/// Check if the `LevelType` is unique (regardless of storage format).
constexpr bool isUniqueLT(LevelType lt) {
- return !(static_cast<uint8_t>(lt) & 1);
+ return !(static_cast<uint64_t>(lt) & 1);
}
/// Convert a LevelType to its corresponding LevelFormat.
@@ -305,21 +336,25 @@ constexpr bool isUniqueLT(LevelType lt) {
constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
if (lt == LevelType::Undef)
return std::nullopt;
- return static_cast<LevelFormat>(static_cast<uint8_t>(lt) & ~3);
+ return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
}
/// Convert a LevelFormat to its corresponding LevelType with the given
/// properties. Returns std::nullopt when the properties are not applicable
/// for the input level format.
constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
- bool unique) {
- auto lt = static_cast<LevelType>(static_cast<uint8_t>(lf) |
- (ordered ? 0 : 2) | (unique ? 0 : 1));
+ bool unique, uint64_t n = 0,
+ uint64_t m = 0) {
+ uint64_t newN = n << 32;
+ uint64_t newM = m << 40;
+ auto lt =
+ static_cast<LevelType>(static_cast<uint64_t>(lf) | (ordered ? 0 : 2) |
+ (unique ? 0 : 1) | newN | newM);
return isValidLT(lt) ? std::optional(lt) : std::nullopt;
}
//
-// Ensure the above methods work as indended.
+// Ensure the above methods work as intended.
//
static_assert(
@@ -341,7 +376,7 @@ static_assert(
LevelFormat::LooseCompressed &&
*getLevelFormat(LevelType::LooseCompressedNuNo) ==
LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
+ *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
"getLevelFormat conversion is broken");
static_assert(
@@ -373,13 +408,28 @@ static_assert(
LevelType::LooseCompressedNo &&
*buildLevelType(LevelFormat::LooseCompressed, false, false) ==
LevelType::LooseCompressedNuNo &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
- LevelType::TwoOutOfFour),
+ buildLevelType(LevelFormat::NOutOfM, false, true) == std::nullopt &&
+ buildLevelType(LevelFormat::NOutOfM, true, false) == std::nullopt &&
+ buildLevelType(LevelFormat::NOutOfM, false, false) == std::nullopt &&
+ *buildLevelType(LevelFormat::NOutOfM, true, true) == LevelType::NOutOfM),
"buildLevelType conversion is broken");
+static_assert(
+ (getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 2 &&
+ getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 4 &&
+ getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 8 &&
+ getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 10),
+ "getN/M conversion is broken");
+
+static_assert(
+ (isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4),
+ 2, 4) &&
+ isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10),
+ 8, 10) &&
+ !isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 3, 4),
+ 2, 4)),
+ "isValidNOutOfMLT definition is broken");
+
static_assert(
(isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
@@ -391,7 +441,7 @@ static_assert(
isValidLT(LevelType::LooseCompressedNu) &&
isValidLT(LevelType::LooseCompressedNo) &&
isValidLT(LevelType::LooseCompressedNuNo) &&
- isValidLT(LevelType::TwoOutOfFour)),
+ isValidLT(LevelType::NOutOfM)),
"isValidLT definition is broken");
static_assert((isDenseLT(LevelType::Dense) &&
@@ -407,7 +457,7 @@ static_assert((isDenseLT(LevelType::Dense) &&
!isDenseLT(LevelType::LooseCompressedNu) &&
!isDenseLT(LevelType::LooseCompressedNo) &&
!isDenseLT(LevelType::LooseCompressedNuNo) &&
- !isDenseLT(LevelType::TwoOutOfFour)),
+ !isDenseLT(LevelType::NOutOfM)),
"isDenseLT definition is broken");
static_assert((!isCompressedLT(LevelType::Dense) &&
@@ -423,7 +473,7 @@ static_assert((!isCompressedLT(LevelType::Dense) &&
!isCompressedLT(LevelType::LooseCompressedNu) &&
!isCompressedLT(LevelType::LooseCompressedNo) &&
!isCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isCompressedLT(LevelType::TwoOutOfFour)),
+ !isCompressedLT(LevelType::NOutOfM)),
"isCompressedLT definition is broken");
static_assert((!isSingletonLT(LevelType::Dense) &&
@@ -439,7 +489,7 @@ static_assert((!isSingletonLT(LevelType::Dense) &&
!isSingletonLT(LevelType::LooseCompressedNu) &&
!isSingletonLT(LevelType::LooseCompressedNo) &&
!isSingletonLT(LevelType::LooseCompressedNuNo) &&
- !isSingletonLT(LevelType::TwoOutOfFour)),
+ !isSingletonLT(LevelType::NOutOfM)),
"isSingletonLT definition is broken");
static_assert((!isLooseCompressedLT(LevelType::Dense) &&
@@ -455,24 +505,24 @@ static_assert((!isLooseCompressedLT(LevelType::Dense) &&
isLooseCompressedLT(LevelType::LooseCompressedNu) &&
isLooseCompressedLT(LevelType::LooseCompressedNo) &&
isLooseCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isLooseCompressedLT(LevelType::TwoOutOfFour)),
+ !isLooseCompressedLT(LevelType::NOutOfM)),
"isLooseCompressedLT definition is broken");
-static_assert((!is2OutOf4LT(LevelType::Dense) &&
- !is2OutOf4LT(LevelType::Compressed) &&
- !is2OutOf4LT(LevelType::CompressedNu) &&
- !is2OutOf4LT(LevelType::CompressedNo) &&
- !is2OutOf4LT(LevelType::CompressedNuNo) &&
- !is2OutOf4LT(LevelType::Singleton) &&
- !is2OutOf4LT(LevelType::SingletonNu) &&
- !is2OutOf4LT(LevelType::SingletonNo) &&
- !is2OutOf4LT(LevelType::SingletonNuNo) &&
- !is2OutOf4LT(LevelType::LooseCompressed) &&
- !is2OutOf4LT(LevelType::LooseCompressedNu) &&
- !is2OutOf4LT(LevelType::LooseCompressedNo) &&
- !is2OutOf4LT(LevelType::LooseCompressedNuNo) &&
- is2OutOf4LT(LevelType::TwoOutOfFour)),
- "is2OutOf4LT definition is broken");
+static_assert((!isNOutOfMLT(LevelType::Dense) &&
+ !isNOutOfMLT(LevelType::Compressed) &&
+ !isNOutOfMLT(LevelType::CompressedNu) &&
+ !isNOutOfMLT(LevelType::CompressedNo) &&
+ !isNOutOfMLT(LevelType::CompressedNuNo) &&
+ !isNOutOfMLT(LevelType::Singleton) &&
+ !isNOutOfMLT(LevelType::SingletonNu) &&
+ !isNOutOfMLT(LevelType::SingletonNo) &&
+ !isNOutOfMLT(LevelType::SingletonNuNo) &&
+ !isNOutOfMLT(LevelType::LooseCompressed) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNu) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNo) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNuNo) &&
+ isNOutOfMLT(LevelType::NOutOfM)),
+ "isNOutOfMLT definition is broken");
static_assert((isOrderedLT(LevelType::Dense) &&
isOrderedLT(LevelType::Compressed) &&
@@ -487,7 +537,7 @@ static_assert((isOrderedLT(LevelType::Dense) &&
isOrderedLT(LevelType::LooseCompressedNu) &&
!isOrderedLT(LevelType::LooseCompressedNo) &&
!isOrderedLT(LevelType::LooseCompressedNuNo) &&
- isOrderedLT(LevelType::TwoOutOfFour)),
+ isOrderedLT(LevelType::NOutOfM)),
"isOrderedLT definition is broken");
static_assert((isUniqueLT(LevelType::Dense) &&
@@ -503,7 +553,7 @@ static_assert((isUniqueLT(LevelType::Dense) &&
!isUniqueLT(LevelType::LooseCompressedNu) &&
isUniqueLT(LevelType::LooseCompressedNo) &&
!isUniqueLT(LevelType::LooseCompressedNuNo) &&
- isUniqueLT(LevelType::TwoOutOfFour)),
+ isUniqueLT(LevelType::NOutOfM)),
"isUniqueLT definition is broken");
/// Bit manipulations for affine encoding.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 12c1068ae1f54..08ba96d437045 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -145,7 +145,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
- **compressed** : only nonzeros along...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/79935
More information about the Mlir-commits
mailing list