[Mlir-commits] [mlir] e5924d6 - [mlir][sparse] Implement parsing n out of m (#79935)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 8 11:38:46 PST 2024
Author: Yinying Li
Date: 2024-02-08T14:38:42-05:00
New Revision: e5924d64991abb4da111317ff5e8d9147265354a
URL: https://github.com/llvm/llvm-project/commit/e5924d64991abb4da111317ff5e8d9147265354a
DIFF: https://github.com/llvm/llvm-project/commit/e5924d64991abb4da111317ff5e8d9147265354a.diff
LOG: [mlir][sparse] Implement parsing n out of m (#79935)
1. Add parsing methods for block[n, m].
2. Encode n and m with the newly extended 64-bit LevelType enum.
3. Update 2:4 methods names/comments to n:m.
Added:
Modified:
mlir/include/mlir-c/Dialect/SparseTensor.h
mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
mlir/lib/Bindings/Python/DialectSparseTensor.cpp
mlir/lib/CAPI/Dialect/SparseTensor.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
mlir/test/CAPI/sparse_tensor.c
mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
mlir/test/python/dialects/sparse_tensor/dialect.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h
index 42d8400cb5e958..2c71b0008ad16a 100644
--- a/mlir/include/mlir-c/Dialect/SparseTensor.h
+++ b/mlir/include/mlir-c/Dialect/SparseTensor.h
@@ -28,20 +28,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor);
typedef uint64_t MlirSparseTensorLevelType;
enum MlirBaseSparseTensorLevelType {
- MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4, // 0b00001_00
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8, // 0b00010_00
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9, // 0b00010_01
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10, // 0b00010_10
- MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16, // 0b00100_00
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17, // 0b00100_01
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18, // 0b00100_10
- MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32, // 0b01000_00
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33, // 0b01000_01
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34, // 0b01000_10
- MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11
- MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00
+ // Bits 16-31 hold the level format and bits 0-15 the properties
+ // (0x1 = nonunique, 0x2 = nonordered). For N_OUT_OF_M the block sizes
+ // n and m are not part of these constants; the C++ encoding keeps them in
+ // bits 32-47 of the 64-bit MlirSparseTensorLevelType value.
+ MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000,
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000,
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001,
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 0x000000020002,
+ MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 0x000000020003,
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000,
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 0x000000040001,
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 0x000000040002,
+ MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 0x000000040003,
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000,
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 0x000000080001,
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002,
+ MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003,
+ MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000,
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 86c52bfc651ef1..e940d203be9ed5 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -154,9 +154,10 @@ enum class Action : uint32_t {
/// This enum defines all the sparse representations supportable by
/// the SparseTensor dialect. We use a lightweight encoding to encode
-/// both the "format" per se (dense, compressed, singleton, loose_compressed,
-/// two-out-of-four) as well as the "properties" (ordered, unique). The
-/// encoding is chosen for performance of the runtime library, and thus may
+/// the "format" per se (dense, compressed, singleton, loose_compressed,
+/// n-out-of-m), the "properties" (ordered, unique) as well as n and m when
+/// the format is NOutOfM.
+/// The encoding is chosen for performance of the runtime library, and thus may
/// change in future versions; consequently, client code should use the
/// predicate functions defined below, rather than relying on knowledge
/// about the particular binary encoding.
@@ -165,41 +166,75 @@ enum class Action : uint32_t {
/// where we need to store an undefined or indeterminate `LevelType`.
/// It should not be used externally, since it does not indicate an
/// actual/representable format.
+///
+/// Bit manipulations for LevelType (most significant bits on the left):
+///
+/// | 16 bits (unused by the helpers below) | 8-bit m | 8-bit n | 16-bit LevelFormat | 16-bit LevelProperty |
+/// | bits 63-48                            | 47-40   | 39-32   | 31-16              | 15-0                 |
+///
+/// The n and m fields are only meaningful for NOutOfM. None of the
+/// enumerators below set them; use nToBits/mToBits to add them.
+///
enum class LevelType : uint64_t {
- Undef = 0, // 0b00000_00
- Dense = 4, // 0b00001_00
- Compressed = 8, // 0b00010_00
- CompressedNu = 9, // 0b00010_01
- CompressedNo = 10, // 0b00010_10
- CompressedNuNo = 11, // 0b00010_11
- Singleton = 16, // 0b00100_00
- SingletonNu = 17, // 0b00100_01
- SingletonNo = 18, // 0b00100_10
- SingletonNuNo = 19, // 0b00100_11
- LooseCompressed = 32, // 0b01000_00
- LooseCompressedNu = 33, // 0b01000_01
- LooseCompressedNo = 34, // 0b01000_10
- LooseCompressedNuNo = 35, // 0b01000_11
- TwoOutOfFour = 64, // 0b10000_00
+ Undef = 0x000000000000,
+ Dense = 0x000000010000,
+ Compressed = 0x000000020000,
+ CompressedNu = 0x000000020001,
+ CompressedNo = 0x000000020002,
+ CompressedNuNo = 0x000000020003,
+ Singleton = 0x000000040000,
+ SingletonNu = 0x000000040001,
+ SingletonNo = 0x000000040002,
+ SingletonNuNo = 0x000000040003,
+ LooseCompressed = 0x000000080000,
+ LooseCompressedNu = 0x000000080001,
+ LooseCompressedNo = 0x000000080002,
+ LooseCompressedNuNo = 0x000000080003,
+ NOutOfM = 0x000000100000,
};
/// This enum defines all supported storage format without the level properties.
enum class LevelFormat : uint64_t {
- Dense = 4, // 0b00001_00
- Compressed = 8, // 0b00010_00
- Singleton = 16, // 0b00100_00
- LooseCompressed = 32, // 0b01000_00
- TwoOutOfFour = 64, // 0b10000_00
+ // Each format is a single bit within the LevelFormat field (bits 16-31).
+ Dense = 0x00010000,
+ Compressed = 0x00020000,
+ Singleton = 0x00040000,
+ LooseCompressed = 0x00080000,
+ NOutOfM = 0x00100000,
};
/// This enum defines all the nondefault properties for storage formats.
enum class LevelPropertyNondefault : uint64_t {
- Nonunique = 1, // 0b00000_01
- Nonordered = 2, // 0b00000_10
+ // Bits within the LevelProperty field (bits 0-15).
+ Nonunique = 0x0001,
+ Nonordered = 0x0002,
};
+/// Get N of NOutOfM level type (bits 32-39 of the encoding).
+constexpr uint64_t getN(LevelType lt) {
+ return (static_cast<uint64_t>(lt) >> 32) & 0xff;
+}
+
+/// Get M of NOutOfM level type (bits 40-47 of the encoding).
+constexpr uint64_t getM(LevelType lt) {
+ return (static_cast<uint64_t>(lt) >> 40) & 0xff;
+}
+
+/// Convert N of NOutOfM level type to the stored bits. The N field is 8 bits
+/// wide, so `n` is masked to stop larger values from overwriting the M field.
+constexpr uint64_t nToBits(uint64_t n) { return (n & 0xff) << 32; }
+
+/// Convert M of NOutOfM level type to the stored bits. The M field is 8 bits
+/// wide, so `m` is masked to stop larger values from spilling past it.
+constexpr uint64_t mToBits(uint64_t m) { return (m & 0xff) << 40; }
+
+/// Check if the `LevelType` is NOutOfM (regardless of
+/// properties and block sizes).
+constexpr bool isNOutOfMLT(LevelType lt) {
+ // Isolate the whole 16-bit format field (bits 16-31), as getLevelFormat
+ // does, so that both the property bits and the n/m bits are ignored.
+ return (static_cast<uint64_t>(lt) & 0xffff0000) ==
+ static_cast<uint64_t>(LevelType::NOutOfM);
+}
+
+/// Check if the `LevelType` is NOutOfM with the correct block sizes.
+constexpr bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) {
+ return isNOutOfMLT(lt) && getN(lt) == n && getM(lt) == m;
+}
+
/// Returns string representation of the given dimension level type.
-constexpr const char *toMLIRString(LevelType lt) {
+constexpr const char *toMLIRString(LevelType lvlType) {
+ auto lt = static_cast<LevelType>(static_cast<uint64_t>(lvlType) & 0xffffffff);
switch (lt) {
case LevelType::Undef:
return "undef";
@@ -229,21 +264,22 @@ constexpr const char *toMLIRString(LevelType lt) {
return "loose_compressed(nonordered)";
case LevelType::LooseCompressedNuNo:
return "loose_compressed(nonunique, nonordered)";
- case LevelType::TwoOutOfFour:
- return "block2_4";
+ case LevelType::NOutOfM:
+ return "structured";
}
return "";
}
/// Check that the `LevelType` contains a valid (possibly undefined) value.
constexpr bool isValidLT(LevelType lt) {
- const uint64_t formatBits = static_cast<uint64_t>(lt) >> 2;
- const uint64_t propertyBits = static_cast<uint64_t>(lt) & 3;
- // If undefined or dense, then must be unique and ordered.
+ // Split out the 16-bit format field (bits 16-31) and the 16-bit property
+ // field (bits 0-15). NOTE(review): the n/m bits (32-47) are not inspected,
+ // so e.g. a Dense value with n/m bits set passes here even though
+ // isDenseLT would reject it.
+ const uint64_t formatBits = static_cast<uint64_t>(lt) & 0xffff0000;
+ const uint64_t propertyBits = static_cast<uint64_t>(lt) & 0xffff;
+ // If undefined/dense/NOutOfM, then must be unique and ordered.
// Otherwise, the format must be one of the known ones.
+ // Since formatBits is a multiple of 0x10000, `formatBits <= 0x10000`
+ // admits exactly Undef (0) and Dense (0x10000).
- return (formatBits <= 1 || formatBits == 16)
+ return (formatBits <= 0x10000 || formatBits == 0x100000)
? (propertyBits == 0)
- : (formatBits == 2 || formatBits == 4 || formatBits == 8);
+ : (formatBits == 0x20000 || formatBits == 0x40000 ||
+ formatBits == 0x80000);
}
/// Check if the `LevelType` is the special undefined value.
@@ -251,34 +287,28 @@ constexpr bool isUndefLT(LevelType lt) { return lt == LevelType::Undef; }
/// Check if the `LevelType` is dense (regardless of properties).
constexpr bool isDenseLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~3) ==
+ // `~0xffff` is the int -65536; converting it to uint64_t yields
+ // 0xffffffffffff0000. Only the property bits are cleared, so any n/m bits
+ // (32-47) still take part in the comparison. The same holds for the
+ // predicates below.
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
static_cast<uint64_t>(LevelType::Dense);
}
/// Check if the `LevelType` is compressed (regardless of properties).
constexpr bool isCompressedLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~3) ==
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
static_cast<uint64_t>(LevelType::Compressed);
}
/// Check if the `LevelType` is singleton (regardless of properties).
constexpr bool isSingletonLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~3) ==
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
static_cast<uint64_t>(LevelType::Singleton);
}
/// Check if the `LevelType` is loose compressed (regardless of properties).
constexpr bool isLooseCompressedLT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~3) ==
+ return (static_cast<uint64_t>(lt) & ~0xffff) ==
static_cast<uint64_t>(LevelType::LooseCompressed);
}
-/// Check if the `LevelType` is 2OutOf4 (regardless of properties).
-constexpr bool is2OutOf4LT(LevelType lt) {
- return (static_cast<uint64_t>(lt) & ~3) ==
- static_cast<uint64_t>(LevelType::TwoOutOfFour);
-}
-
/// Check if the `LevelType` needs positions array.
constexpr bool isWithPosLT(LevelType lt) {
return isCompressedLT(lt) || isLooseCompressedLT(lt);
@@ -287,17 +317,19 @@ constexpr bool isWithPosLT(LevelType lt) {
/// Check if the `LevelType` needs coordinates array.
constexpr bool isWithCrdLT(LevelType lt) {
return isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- is2OutOf4LT(lt);
+ isNOutOfMLT(lt);
}
/// Check if the `LevelType` is ordered (regardless of storage format).
constexpr bool isOrderedLT(LevelType lt) {
+ // Bit 0x2 is LevelPropertyNondefault::Nonordered.
return !(static_cast<uint64_t>(lt) & 2);
}
/// Check if the `LevelType` is unique (regardless of storage format).
constexpr bool isUniqueLT(LevelType lt) {
+ // Bit 0x1 is LevelPropertyNondefault::Nonunique.
return !(static_cast<uint64_t>(lt) & 1);
}
/// Convert a LevelType to its corresponding LevelFormat.
@@ -305,21 +337,25 @@ constexpr bool isUniqueLT(LevelType lt) {
constexpr std::optional<LevelFormat> getLevelFormat(LevelType lt) {
if (lt == LevelType::Undef)
return std::nullopt;
- return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & ~3);
+ // Keep only the 16-bit format field (bits 16-31); property bits and the
+ // n/m bits are dropped.
+ return static_cast<LevelFormat>(static_cast<uint64_t>(lt) & 0xffff0000);
}
/// Convert a LevelFormat to its corresponding LevelType with the given
/// properties. Returns std::nullopt when the properties are not applicable
/// for the input level format.
constexpr std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered,
- bool unique) {
- auto lt = static_cast<LevelType>(static_cast<uint64_t>(lf) |
- (ordered ? 0 : 2) | (unique ? 0 : 1));
+ bool unique, uint64_t n = 0,
+ uint64_t m = 0) {
+ uint64_t newN = n << 32;
+ uint64_t newM = m << 40;
+ auto lt =
+ static_cast<LevelType>(static_cast<uint64_t>(lf) | (ordered ? 0 : 2) |
+ (unique ? 0 : 1) | newN | newM);
return isValidLT(lt) ? std::optional(lt) : std::nullopt;
}
//
-// Ensure the above methods work as indended.
+// Ensure the above methods work as intended.
//
static_assert(
@@ -341,7 +377,7 @@ static_assert(
LevelFormat::LooseCompressed &&
*getLevelFormat(LevelType::LooseCompressedNuNo) ==
LevelFormat::LooseCompressed &&
- *getLevelFormat(LevelType::TwoOutOfFour) == LevelFormat::TwoOutOfFour),
+ *getLevelFormat(LevelType::NOutOfM) == LevelFormat::NOutOfM),
"getLevelFormat conversion is broken");
static_assert(
@@ -373,13 +409,28 @@ static_assert(
LevelType::LooseCompressedNo &&
*buildLevelType(LevelFormat::LooseCompressed, false, false) ==
LevelType::LooseCompressedNuNo &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, true) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, true, false) == std::nullopt &&
- buildLevelType(LevelFormat::TwoOutOfFour, false, false) == std::nullopt &&
- *buildLevelType(LevelFormat::TwoOutOfFour, true, true) ==
- LevelType::TwoOutOfFour),
+ buildLevelType(LevelFormat::NOutOfM, false, true) == std::nullopt &&
+ buildLevelType(LevelFormat::NOutOfM, true, false) == std::nullopt &&
+ buildLevelType(LevelFormat::NOutOfM, false, false) == std::nullopt &&
+ *buildLevelType(LevelFormat::NOutOfM, true, true) == LevelType::NOutOfM),
"buildLevelType conversion is broken");
+// Block sizes written by buildLevelType must be read back by getN/getM.
+static_assert(
+ (getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 2 &&
+ getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4)) == 4 &&
+ getN(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 8 &&
+ getM(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10)) == 10),
+ "getN/M conversion is broken");
+
+// isValidNOutOfMLT must accept only the exact (n, m) pair that was encoded.
+static_assert(
+ (isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 2, 4),
+ 2, 4) &&
+ isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 8, 10),
+ 8, 10) &&
+ !isValidNOutOfMLT(*buildLevelType(LevelFormat::NOutOfM, true, true, 3, 4),
+ 2, 4)),
+ "isValidNOutOfMLT definition is broken");
+
static_assert(
(isValidLT(LevelType::Undef) && isValidLT(LevelType::Dense) &&
isValidLT(LevelType::Compressed) && isValidLT(LevelType::CompressedNu) &&
@@ -391,7 +442,7 @@ static_assert(
isValidLT(LevelType::LooseCompressedNu) &&
isValidLT(LevelType::LooseCompressedNo) &&
isValidLT(LevelType::LooseCompressedNuNo) &&
- isValidLT(LevelType::TwoOutOfFour)),
+ isValidLT(LevelType::NOutOfM)),
"isValidLT definition is broken");
static_assert((isDenseLT(LevelType::Dense) &&
@@ -407,7 +458,7 @@ static_assert((isDenseLT(LevelType::Dense) &&
!isDenseLT(LevelType::LooseCompressedNu) &&
!isDenseLT(LevelType::LooseCompressedNo) &&
!isDenseLT(LevelType::LooseCompressedNuNo) &&
- !isDenseLT(LevelType::TwoOutOfFour)),
+ !isDenseLT(LevelType::NOutOfM)),
"isDenseLT definition is broken");
static_assert((!isCompressedLT(LevelType::Dense) &&
@@ -423,7 +474,7 @@ static_assert((!isCompressedLT(LevelType::Dense) &&
!isCompressedLT(LevelType::LooseCompressedNu) &&
!isCompressedLT(LevelType::LooseCompressedNo) &&
!isCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isCompressedLT(LevelType::TwoOutOfFour)),
+ !isCompressedLT(LevelType::NOutOfM)),
"isCompressedLT definition is broken");
static_assert((!isSingletonLT(LevelType::Dense) &&
@@ -439,7 +490,7 @@ static_assert((!isSingletonLT(LevelType::Dense) &&
!isSingletonLT(LevelType::LooseCompressedNu) &&
!isSingletonLT(LevelType::LooseCompressedNo) &&
!isSingletonLT(LevelType::LooseCompressedNuNo) &&
- !isSingletonLT(LevelType::TwoOutOfFour)),
+ !isSingletonLT(LevelType::NOutOfM)),
"isSingletonLT definition is broken");
static_assert((!isLooseCompressedLT(LevelType::Dense) &&
@@ -455,24 +506,24 @@ static_assert((!isLooseCompressedLT(LevelType::Dense) &&
isLooseCompressedLT(LevelType::LooseCompressedNu) &&
isLooseCompressedLT(LevelType::LooseCompressedNo) &&
isLooseCompressedLT(LevelType::LooseCompressedNuNo) &&
- !isLooseCompressedLT(LevelType::TwoOutOfFour)),
+ !isLooseCompressedLT(LevelType::NOutOfM)),
"isLooseCompressedLT definition is broken");
-static_assert((!is2OutOf4LT(LevelType::Dense) &&
- !is2OutOf4LT(LevelType::Compressed) &&
- !is2OutOf4LT(LevelType::CompressedNu) &&
- !is2OutOf4LT(LevelType::CompressedNo) &&
- !is2OutOf4LT(LevelType::CompressedNuNo) &&
- !is2OutOf4LT(LevelType::Singleton) &&
- !is2OutOf4LT(LevelType::SingletonNu) &&
- !is2OutOf4LT(LevelType::SingletonNo) &&
- !is2OutOf4LT(LevelType::SingletonNuNo) &&
- !is2OutOf4LT(LevelType::LooseCompressed) &&
- !is2OutOf4LT(LevelType::LooseCompressedNu) &&
- !is2OutOf4LT(LevelType::LooseCompressedNo) &&
- !is2OutOf4LT(LevelType::LooseCompressedNuNo) &&
- is2OutOf4LT(LevelType::TwoOutOfFour)),
- "is2OutOf4LT definition is broken");
+// Of all the enumerated level types, only NOutOfM is classified as n-out-of-m.
+static_assert((!isNOutOfMLT(LevelType::Dense) &&
+ !isNOutOfMLT(LevelType::Compressed) &&
+ !isNOutOfMLT(LevelType::CompressedNu) &&
+ !isNOutOfMLT(LevelType::CompressedNo) &&
+ !isNOutOfMLT(LevelType::CompressedNuNo) &&
+ !isNOutOfMLT(LevelType::Singleton) &&
+ !isNOutOfMLT(LevelType::SingletonNu) &&
+ !isNOutOfMLT(LevelType::SingletonNo) &&
+ !isNOutOfMLT(LevelType::SingletonNuNo) &&
+ !isNOutOfMLT(LevelType::LooseCompressed) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNu) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNo) &&
+ !isNOutOfMLT(LevelType::LooseCompressedNuNo) &&
+ isNOutOfMLT(LevelType::NOutOfM)),
+ "isNOutOfMLT definition is broken");
static_assert((isOrderedLT(LevelType::Dense) &&
isOrderedLT(LevelType::Compressed) &&
@@ -487,7 +538,7 @@ static_assert((isOrderedLT(LevelType::Dense) &&
isOrderedLT(LevelType::LooseCompressedNu) &&
!isOrderedLT(LevelType::LooseCompressedNo) &&
!isOrderedLT(LevelType::LooseCompressedNuNo) &&
- isOrderedLT(LevelType::TwoOutOfFour)),
+ isOrderedLT(LevelType::NOutOfM)),
"isOrderedLT definition is broken");
static_assert((isUniqueLT(LevelType::Dense) &&
@@ -503,7 +554,7 @@ static_assert((isUniqueLT(LevelType::Dense) &&
!isUniqueLT(LevelType::LooseCompressedNu) &&
isUniqueLT(LevelType::LooseCompressedNo) &&
!isUniqueLT(LevelType::LooseCompressedNuNo) &&
- isUniqueLT(LevelType::TwoOutOfFour)),
+ isUniqueLT(LevelType::NOutOfM)),
"isUniqueLT definition is broken");
/// Bit manipulations for affine encoding.
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 12c1068ae1f546..5b3b971f9a7f23 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -145,7 +145,8 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
- **compressed** : only nonzeros along this level are stored
- **loose_compressed** : as compressed, but allows for free space between regions
- **singleton** : a variant of the compressed format, where coordinates have no siblings
- - **block2_4** : the compression uses a 2:4 encoding per 1x4 block
+ - **structured[n, m]** : the compression uses an n:m encoding
+ (viz. n out of m consecutive elements are nonzero)
For a compressed level, each position interval is represented in a compact
way with a lowerbound `pos(i)` and an upperbound `pos(i+1) - 1`, which implies
@@ -374,7 +375,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
bool isCompressedLvl(::mlir::sparse_tensor::Level l) const { return isCompressedLT(getLvlType(l)); }
bool isSingletonLvl(::mlir::sparse_tensor::Level l) const { return isSingletonLT(getLvlType(l)); }
bool isLooseCompressedLvl(::mlir::sparse_tensor::Level l) const { return isLooseCompressedLT(getLvlType(l)); }
- bool isTwoOutOfFourLvl(::mlir::sparse_tensor::Level l) const { return is2OutOf4LT(getLvlType(l)); }
+ bool isNOutOfMLvl(::mlir::sparse_tensor::Level l) const { return isNOutOfMLT(getLvlType(l)); }
bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedLT(getLvlType(l)); }
bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueLT(getLvlType(l)); }
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index 4c98129744bcd9..4e2b85d35c1ac1 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -291,7 +291,7 @@ class SparseTensorType {
return isLooseCompressedLT(getLvlType(l));
}
bool isSingletonLvl(Level l) const { return isSingletonLT(getLvlType(l)); }
- bool is2OutOf4Lvl(Level l) const { return is2OutOf4LT(getLvlType(l)); }
+ bool isNOutOfMLvl(Level l) const { return isNOutOfMLT(getLvlType(l)); }
bool isOrderedLvl(Level l) const { return isOrderedLT(getLvlType(l)); }
bool isUniqueLvl(Level l) const { return isUniqueLT(getLvlType(l)); }
bool isWithPos(Level l) const { return isWithPosLT(getLvlType(l)); }
diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
index 4a34bb2e003e88..490ef3071af1b7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -510,7 +510,7 @@ class Merger {
if (isLvlWithNonTrivialIdxExp(b)) {
auto lt = getLoopDependentLevelType(b);
return isCompressedLT(lt) || isSingletonLT(lt) ||
- isLooseCompressedLT(lt) || is2OutOf4LT(lt);
+ isLooseCompressedLT(lt) || isNOutOfMLT(lt);
}
return false;
}
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
index 01c5f2382ffe69..14182172f4f622 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -123,8 +123,8 @@ class SparseTensorStorageBase {
/// Safely checks if the level uses singleton storage.
bool isSingletonLvl(uint64_t l) const { return isSingletonLT(getLvlType(l)); }
- /// Safely checks if the level uses 2 out of 4 storage.
- bool is2OutOf4Lvl(uint64_t l) const { return is2OutOf4LT(getLvlType(l)); }
+ /// Safely checks if the level uses n out of m storage.
+ bool isNOutOfMLvl(uint64_t l) const { return isNOutOfMLT(getLvlType(l)); }
/// Safely checks if the level is ordered.
bool isOrderedLvl(uint64_t l) const { return isOrderedLT(getLvlType(l)); }
@@ -450,7 +450,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
void appendCrd(uint64_t lvl, uint64_t full, uint64_t crd) {
if (!isDenseLvl(lvl)) {
assert(isCompressedLvl(lvl) || isLooseCompressedLvl(lvl) ||
- isSingletonLvl(lvl) || is2OutOf4Lvl(lvl));
+ isSingletonLvl(lvl) || isNOutOfMLvl(lvl));
coordinates[lvl].push_back(detail::checkOverflowCast<C>(crd));
} else { // Dense level.
assert(crd >= full && "Coordinate was already filled");
@@ -473,7 +473,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
return positions[l][parentSz];
if (isLooseCompressedLvl(l))
return positions[l][2 * parentSz - 1];
- if (isSingletonLvl(l) || is2OutOf4Lvl(l))
+ if (isSingletonLvl(l) || isNOutOfMLvl(l))
return parentSz; // new size same as the parent
assert(isDenseLvl(l));
return parentSz * getLvlSize(l);
@@ -527,7 +527,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
uint64_t pos = coordinates[l].size();
positions[l].insert(positions[l].end(), 2 * count,
detail::checkOverflowCast<P>(pos));
- } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
+ } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
return; // Nothing to finalize.
} else { // Dense dimension.
assert(isDenseLvl(l));
@@ -624,7 +624,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
lvlCursor[l] = static_cast<uint64_t>(coordinatesL[pos]);
toCOO(pos, l + 1, dimCoords);
}
- } else if (isSingletonLvl(l) || is2OutOf4Lvl(l)) {
+ } else if (isSingletonLvl(l) || isNOutOfMLvl(l)) {
assert(parentPos < coordinates[l].size());
lvlCursor[l] = static_cast<uint64_t>(coordinates[l][parentPos]);
toCOO(parentPos, l + 1, dimCoords);
@@ -721,8 +721,8 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
} else if (isSingletonLvl(l)) {
coordinates[l].reserve(sz);
sz = 1;
- } else if (is2OutOf4Lvl(l)) {
- assert(l == lvlRank - 1 && "unexpected 2:4 usage");
+ } else if (isNOutOfMLvl(l)) {
+ assert(l == lvlRank - 1 && "unexpected n:m usage");
sz = detail::checkedMul(sz, lvlSizes[l]) / 2;
coordinates[l].reserve(sz);
values.reserve(sz);
@@ -791,8 +791,8 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
}
} else if (isSingletonLvl(l)) {
assert(0 && "general singleton not supported yet");
- } else if (is2OutOf4Lvl(l)) {
- assert(0 && "2Out4 not supported yet");
+ } else if (isNOutOfMLvl(l)) {
+ assert(0 && "n ouf of m not supported yet");
} else {
assert(isDenseLvl(l));
}
diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
index 698367a1addaff..607534c6156439 100644
--- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp
@@ -25,7 +25,7 @@ using namespace mlir::python::adaptors;
static void populateDialectSparseTensorSubmodule(const py::module &m) {
py::enum_<MlirBaseSparseTensorLevelType>(m, "LevelType", py::module_local())
.value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE)
- .value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR)
+ .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M)
.value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED)
.value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU)
.value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO)
diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
index e4534ad132385f..a34b9a29b0e90a 100644
--- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp
+++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp
@@ -20,25 +20,36 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor,
mlir::sparse_tensor::SparseTensorDialect)
// Ensure the C-API enums are int-castable to C++ equivalents.
-static_assert(static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
- static_cast<int>(LevelType::Dense) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
- static_cast<int>(LevelType::Compressed) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
- static_cast<int>(LevelType::CompressedNu) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
- static_cast<int>(LevelType::CompressedNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
- static_cast<int>(LevelType::CompressedNuNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
- static_cast<int>(LevelType::Singleton) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
- static_cast<int>(LevelType::SingletonNu) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
- static_cast<int>(LevelType::SingletonNo) &&
- static_cast<int>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
- static_cast<int>(LevelType::SingletonNuNo),
- "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
+// Compare as uint64_t (the underlying type of both LevelType and
+// MlirSparseTensorLevelType) rather than int, so that no encoding bits above
+// bit 31 can be truncated away by the cast.
+static_assert(
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_DENSE) ==
+ static_cast<uint64_t>(LevelType::Dense) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) ==
+ static_cast<uint64_t>(LevelType::Compressed) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) ==
+ static_cast<uint64_t>(LevelType::CompressedNu) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) ==
+ static_cast<uint64_t>(LevelType::CompressedNo) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) ==
+ static_cast<uint64_t>(LevelType::CompressedNuNo) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) ==
+ static_cast<uint64_t>(LevelType::Singleton) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) ==
+ static_cast<uint64_t>(LevelType::SingletonNu) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) ==
+ static_cast<uint64_t>(LevelType::SingletonNo) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) ==
+ static_cast<uint64_t>(LevelType::SingletonNuNo) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) ==
+ static_cast<uint64_t>(LevelType::LooseCompressed) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) ==
+ static_cast<uint64_t>(LevelType::LooseCompressedNu) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) ==
+ static_cast<uint64_t>(LevelType::LooseCompressedNo) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO) ==
+ static_cast<uint64_t>(LevelType::LooseCompressedNuNo) &&
+ static_cast<uint64_t>(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) ==
+ static_cast<uint64_t>(LevelType::NOutOfM),
+ "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch");
bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) {
return isa<SparseTensorEncodingAttr>(unwrap(attr));
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
index eb7ea63a4e88b8..752d6e6481dfee 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.cpp
@@ -29,12 +29,21 @@ using namespace mlir::sparse_tensor::ir_detail;
// `LvlTypeParser` implementation.
//===----------------------------------------------------------------------===//
-FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
+FailureOr<uint64_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
StringRef base;
const auto loc = parser.getCurrentLocation();
ERROR_IF(failed(parser.parseOptionalKeyword(&base)),
"expected valid level format (e.g. dense, compressed or singleton)")
- uint8_t properties = 0;
+ uint64_t properties = 0;
+ SmallVector<unsigned> structure;
+
+ if (base.compare("structured") == 0) {
+ ParseResult res = parser.parseCommaSeparatedList(
+ mlir::OpAsmParser::Delimiter::OptionalSquare,
+ [&]() -> ParseResult { return parseStructure(parser, &structure); },
+ " in block n out of m");
+ FAILURE_IF_FAILED(res)
+ }
ParseResult res = parser.parseCommaSeparatedList(
mlir::OpAsmParser::Delimiter::OptionalParen,
@@ -44,15 +53,25 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
// Set the base bit for properties.
if (base.compare("dense") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::Dense);
+ properties |= static_cast<uint64_t>(LevelFormat::Dense);
} else if (base.compare("compressed") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::Compressed);
- } else if (base.compare("block2_4") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::TwoOutOfFour);
+ properties |= static_cast<uint64_t>(LevelFormat::Compressed);
+ } else if (base.compare("structured") == 0) {
+ if (structure.size() != 2) {
+ parser.emitError(loc, "expected exactly 2 structure sizes");
+ return failure();
+ }
+ // n entries are kept out of every block of m, so n cannot exceed m.
+ if (structure[0] > structure[1]) {
+ parser.emitError(loc, "expected n <= m in structured[n, m]");
+ return failure();
+ }
+ properties |= static_cast<uint64_t>(LevelFormat::NOutOfM);
+ properties |= nToBits(structure[0]) | mToBits(structure[1]);
} else if (base.compare("loose_compressed") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::LooseCompressed);
+ properties |= static_cast<uint64_t>(LevelFormat::LooseCompressed);
} else if (base.compare("singleton") == 0) {
- properties |= static_cast<uint8_t>(LevelFormat::Singleton);
+ properties |= static_cast<uint64_t>(LevelFormat::Singleton);
} else {
parser.emitError(loc, "unknown level format: ") << base;
return failure();
@@ -64,15 +83,15 @@ FailureOr<uint8_t> LvlTypeParser::parseLvlType(AsmParser &parser) const {
}
ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
- uint8_t *properties) const {
+ uint64_t *properties) const {
StringRef strVal;
auto loc = parser.getCurrentLocation();
ERROR_IF(failed(parser.parseOptionalKeyword(&strVal)),
"expected valid level property (e.g. nonordered, nonunique or high)")
if (strVal.compare("nonunique") == 0) {
- *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonunique);
+ *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonunique);
} else if (strVal.compare("nonordered") == 0) {
- *properties |= static_cast<uint8_t>(LevelPropertyNondefault::Nonordered);
+ *properties |= static_cast<uint64_t>(LevelPropertyNondefault::Nonordered);
} else {
parser.emitError(loc, "unknown level property: ") << strVal;
return failure();
@@ -80,4 +99,29 @@ ParseResult LvlTypeParser::parseProperty(AsmParser &parser,
return success();
}
+/// Parses one block size of `structured[n, m]` and appends it to
+/// `*structure`. Sizes must be positive: a zero or negative size describes no
+/// block, and a negative `int` would silently wrap once stored as unsigned.
+ParseResult
+LvlTypeParser::parseStructure(AsmParser &parser,
+ SmallVector<unsigned> *structure) const {
+ int intVal;
+ auto loc = parser.getCurrentLocation();
+ OptionalParseResult intValParseResult = parser.parseOptionalInteger(intVal);
+ if (intValParseResult.has_value()) {
+ if (failed(*intValParseResult)) {
+ parser.emitError(loc, "failed to parse block size");
+ return failure();
+ }
+ if (intVal <= 0) {
+ parser.emitError(loc, "expected positive integer for block size");
+ return failure();
+ }
+ structure->push_back(intVal);
+ return success();
+ }
+ parser.emitError(loc, "expected valid integer for block size");
+ return failure();
+}
+
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
index 5e2f11b75d4da6..6a13112195d440 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
+++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/LvlTypeParser.h
@@ -18,10 +18,17 @@ namespace ir_detail {
class LvlTypeParser {
public:
LvlTypeParser() = default;
- FailureOr<uint8_t> parseLvlType(AsmParser &parser) const;
+ /// Parses a level type such as `dense` or `structured[2, 4]` and returns
+ /// its 64-bit encoding.
+ FailureOr<uint64_t> parseLvlType(AsmParser &parser) const;
private:
- ParseResult parseProperty(AsmParser &parser, uint8_t *properties) const;
+ /// Parses one level property keyword and ORs its bit into `*properties`.
+ ParseResult parseProperty(AsmParser &parser, uint64_t *properties) const;
+ /// Parses one block size of `structured[n, m]` and appends it to
+ /// `*structure`.
+ ParseResult parseStructure(AsmParser &parser,
+ SmallVector<unsigned> *structure) const;
};
} // namespace ir_detail
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 27125bc7ed45e3..67b1d7974fa259 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -613,16 +613,28 @@ void SparseTensorEncodingAttr::printDimensions(
}
}
+/// Returns the "[n, m]" suffix printed after an n:m level type, or an empty
+/// string for every other level type.
+static std::string getNOutOfMString(LevelType lt) {
+ if (!isNOutOfMLT(lt))
+ return "";
+ const unsigned n = getN(lt);
+ const unsigned m = getM(lt);
+ return "[" + std::to_string(n) + ", " + std::to_string(m) + "]";
+}
+
void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer,
ArrayRef<LevelType> lvlTypes) const {
for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) {
map.getResult(i).print(printer.getStream());
- printer << " : " << toMLIRString(lvlTypes[i]) << ", ";
+ printer << " : " << toMLIRString(lvlTypes[i])
+ << getNOutOfMString(lvlTypes[i]) << ", ";
}
if (map.getNumResults() >= 1) {
auto lastIndex = map.getNumResults() - 1;
map.getResult(lastIndex).print(printer.getStream());
- printer << " : " << toMLIRString(lvlTypes[lastIndex]);
+ printer << " : " << toMLIRString(lvlTypes[lastIndex])
+ << getNOutOfMString(lvlTypes[lastIndex]);
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index dd3af9d8354123..3f352c868467fe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -451,7 +451,10 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) {
/// Test for 2:4 matrix with suitable metadata.
static bool isAdmissible24(SparseTensorType &aTp) {
return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
- aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp);
+ aTp.isDenseLvl(1) && aTp.isNOutOfMLvl(2) &&
+ // isNOutOfMLvl accepts every n:m; only the 2:4 instance qualifies here.
+ getN(aTp.getLvlType(2)) == 2 && getM(aTp.getLvlType(2)) == 4 &&
+ isAdmissibleMetaData(aTp);
}
/// Test for conversion into 2:4 matrix.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 491501a3381b9c..d4459c6ea1e521 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -130,7 +130,7 @@ static void allocSchemeForRank(OpBuilder &builder, Location loc,
createPushback(builder, loc, desc, SparseTensorFieldKind::PosMemRef, lvl,
/*value=*/posZero, /*repeat=*/linear);
return;
- } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
+ } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
return; // nothing to do
}
// Keep compounding the size, but nothing needs to be initialized
@@ -409,7 +409,7 @@ static void genEndInsert(OpBuilder &builder, Location loc,
}
} else {
assert(isDenseLT(lt) || isLooseCompressedLT(lt) || isSingletonLT(lt) ||
- is2OutOf4LT(lt));
+ isNOutOfMLT(lt));
}
}
}
@@ -488,7 +488,7 @@ class SparseInsertGenerator
}
parentPos =
genCompressed(builder, loc, desc, coords, value, parentPos, lvl);
- } else if (isSingletonLT(lt) || is2OutOf4LT(lt)) {
+ } else if (isSingletonLT(lt) || isNOutOfMLT(lt)) {
// Create:
// coordinates[lvl].push_back(coords[lvl])
// positions[lvl] = positions[lvl-1]
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index ab38ab5cc3f78f..8f2ae60b311f7c 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -891,7 +891,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
assert(curr == env.merger().loop(b));
Value clause;
if (isCompressedLT(lt) || isSingletonLT(lt) ||
- isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
+ isLooseCompressedLT(lt) || isNOutOfMLT(lt)) {
assert(lvl.has_value());
const Value crd = env.emitter().getCoord(tid, *lvl);
const Value lvar = env.getLoopVar(curr);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 4ba9ecbe03c72d..c85f8204ba7527 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -139,18 +139,19 @@ class SingletonLevel : public SparseLevel {
}
};
-class TwoOutFourLevel : public SparseLevel {
+class NOutOfMLevel : public SparseLevel {
public:
- TwoOutFourLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
- Value crdBuffer)
+ NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
+ Value crdBuffer)
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
Value max) const override {
- assert(max == nullptr && isUnique() && "2:4 level can not be non-unique.");
- // Each 2:4 blk has exactly two specified elements.
- Value posLo = MULI(p, C_IDX(2));
- return {posLo, ADDI(posLo, C_IDX(2))};
+ assert(max == nullptr && isUnique() && "n:m level can not be non-unique.");
+ // Each n:m block stores exactly n entries, so block p spans [p*n, p*n + n).
+ auto n = getN(lt);
+ Value posLo = MULI(p, C_IDX(n));
+ return {posLo, ADDI(posLo, C_IDX(n))};
}
};
@@ -1291,9 +1292,9 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
Value crd = genToCoordinates(b, l, t, lvl);
return std::make_unique<SingletonLevel>(tid, lvl, lt, sz, crd);
}
- case LevelFormat::TwoOutOfFour: {
+ case LevelFormat::NOutOfM: {
Value crd = genToCoordinates(b, l, t, lvl);
- return std::make_unique<TwoOutFourLevel>(tid, lvl, lt, sz, crd);
+ return std::make_unique<NOutOfMLevel>(tid, lvl, lt, sz, crd);
}
}
llvm_unreachable("unrecognizable level format");
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
index 6cdf5f8c0168be..96537cbb0c4836 100644
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -489,7 +489,7 @@ BitVector Merger::simplifyCond(LatSetId s0, LatPointId p0) {
if (simple[b] && !isSparseLvlWithNonTrivialIdxExp(b)) {
const auto lt = getLvlType(b);
if (!isCompressedLT(lt) && !isSingletonLT(lt) &&
- !isLooseCompressedLT(lt) && !is2OutOf4LT(lt)) {
+ !isLooseCompressedLT(lt) && !isNOutOfMLT(lt)) {
if (reset)
simple.reset(b);
reset = true;
@@ -670,7 +670,7 @@ bool Merger::hasAnySparse(const BitVector &bits) const {
for (TensorLoopId b : bits.set_bits()) {
const auto lt = getLvlType(b);
if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- is2OutOf4LT(lt))
+ isNOutOfMLT(lt))
return true;
}
return hasSparseIdxReduction(bits);
diff --git a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
index 0c7b3a228a65cf..9e8b240899d808 100644
--- a/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -45,7 +45,7 @@ SparseTensorStorageBase::SparseTensorStorageBase( // NOLINT
for (uint64_t l = 0; l < lvlRank; l++) {
assert(lvlSizes[l] > 0 && "Level size zero has trivial storage");
assert(isDenseLvl(l) || isCompressedLvl(l) || isLooseCompressedLvl(l) ||
- isSingletonLvl(l) || is2OutOf4Lvl(l));
+ isSingletonLvl(l) || isNOutOfMLvl(l));
}
}
diff --git a/mlir/test/CAPI/sparse_tensor.c b/mlir/test/CAPI/sparse_tensor.c
index 2c6ad559f19d0c..a8b9f9048d5912 100644
--- a/mlir/test/CAPI/sparse_tensor.c
+++ b/mlir/test/CAPI/sparse_tensor.c
@@ -38,9 +38,9 @@ static int testRoundtripEncoding(MlirContext ctx) {
mlirSparseTensorEncodingAttrGetDimToLvl(originalAttr);
// CHECK: (d0, d1)[s0] -> (s0, d0, d1)
mlirAffineMapDump(dimToLvl);
- // CHECK: level_type: 4
- // CHECK: level_type: 8
- // CHECK: level_type: 8
+ // CHECK: level_type: 65536
+ // CHECK: level_type: 131072
+ // CHECK: level_type: 131072
MlirAffineMap lvlToDim =
mlirSparseTensorEncodingAttrGetLvlToDim(originalAttr);
int lvlRank = mlirSparseTensorEncodingGetLvlRank(originalAttr);
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
index 6fe7ec906f30e9..8293169049ca61 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
@@ -4,7 +4,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : structured[2, 4]
)
}>
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 20702bb9850284..64520638b253df 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -207,12 +207,12 @@ func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : structured[2, 4]
),
crdWidth = 8 // we would even like just 2-bits
}>
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : block2_4), crdWidth = 8 }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]), crdWidth = 8 }>
// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?xf64, #[[$NV_24]]>
func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
@@ -226,11 +226,11 @@ func.func private @NV_24(%arg0: tensor<?x?xf64, #NV_24>) {
( i : dense,
j : dense,
k floordiv 4 : dense,
- k mod 4 : block2_4
+ k mod 4 : structured[2, 4]
)
}>
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : block2_4) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 floordiv 4 : dense, d2 mod 4 : structured[2, 4]) }>
// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
@@ -244,13 +244,31 @@ func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
( i : dense,
k floordiv 4 : dense,
j : dense,
- k mod 4 : block2_4
+ k mod 4 : structured[2, 4]
)
}>
-// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : block2_4) }>
+// CHECK-DAG: #[[$NV_24:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 4 : dense, d1 : dense, d2 mod 4 : structured[2, 4]) }>
// CHECK-LABEL: func private @NV_24(
// CHECK-SAME: tensor<?x?x?xf64, #[[$NV_24]]>
func.func private @NV_24(%arg0: tensor<?x?x?xf64, #NV_24>) {
return
}
+
+// -----
+// A general n:m level (5 out of 8) must round-trip as structured[5, 8].
+#NOutOfM = #sparse_tensor.encoding<{
+ map = ( i, j, k ) ->
+ ( i : dense,
+ k floordiv 8 : dense,
+ j : dense,
+ k mod 8 : structured[5, 8]
+ )
+}>
+
+// CHECK-DAG: #[[$NOutOfM:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d2 floordiv 8 : dense, d1 : dense, d2 mod 8 : structured[5, 8]) }>
+// CHECK-LABEL: func private @NOutOfM(
+// CHECK-SAME: tensor<?x?x?xf64, #[[$NOutOfM]]>
+func.func private @NOutOfM(%arg0: tensor<?x?x?xf64, #NOutOfM>) {
+ return
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
index 7c494b2bcfe1d1..d04fbe8ed5c220 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_fill_zero.mlir
@@ -14,7 +14,7 @@
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant true
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 100 : index
// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 300 : index
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 8 : i64
+// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 131072 : i64
// CHECK: %[[VAL_12:.*]] = memref.alloca() : memref<2xi64>
// CHECK: %[[VAL_13:.*]] = memref.cast %[[VAL_12]] : memref<2xi64> to memref<?xi64>
// CHECK: memref.store %[[VAL_11]], %[[VAL_12]]{{\[}}%[[VAL_5]]] : memref<2xi64>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
index 4bc080fc538fc6..e47ac46597b77a 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -59,7 +59,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : structured[2, 4]
),
}>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
index df5b48a3b6ece8..ec5c7580657cd7 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_ds.mlir
@@ -41,7 +41,8 @@
+// The last level is structured[2, 4]: 2 entries kept per block of 4 along j.
#NV_24 = #sparse_tensor.encoding<{
map = ( i, j ) -> ( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4),
+ j mod 4 : structured[2, 4]),
crdWidth = 8
}>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
index 17b50b46d073ae..b0f63f12c2d579 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
@@ -20,7 +20,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : structured[2, 4]
)
}>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
index eb99a027a88600..311cb607b4293c 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
@@ -20,7 +20,7 @@
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
- j mod 4 : block2_4
+ j mod 4 : structured[2, 4]
)
}>
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
index 946a224dab064a..412c5797067b7a 100644
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -28,7 +28,7 @@ def testEncodingAttr1D():
# CHECK: equal: True
print(f"equal: {casted == parsed}")
- # CHECK: lvl_types: [8]
+ # CHECK: lvl_types: [131072]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0) -> (d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
@@ -70,7 +70,7 @@ def testEncodingAttr2D():
# CHECK: equal: True
print(f"equal: {casted == parsed}")
- # CHECK: lvl_types: [4, 8]
+ # CHECK: lvl_types: [65536, 131072]
print(f"lvl_types: {casted.lvl_types}")
# CHECK: dim_to_lvl: (d0, d1) -> (d1, d0)
print(f"dim_to_lvl: {casted.dim_to_lvl}")
More information about the Mlir-commits
mailing list